mirror of
https://github.com/FiloSottile/age.git
synced 2026-03-11 08:55:41 +00:00
age: add DecryptReaderAt
This commit is contained in:
parent
abe371e157
commit
2ff5d341f6
6 changed files with 1110 additions and 128 deletions
61
age.go
61
age.go
|
|
@ -252,13 +252,12 @@ func (e *NoIdentityMatchError) Unwrap() []error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt decrypts a file encrypted to one or more identities.
|
// Decrypt decrypts a file encrypted to one or more identities.
|
||||||
//
|
// All identities will be tried until one successfully decrypts the file.
|
||||||
// It returns a Reader reading the decrypted plaintext of the age file read
|
|
||||||
// from src. All identities will be tried until one successfully decrypts the file.
|
|
||||||
// Native, non-interactive identities are tried before any other identities.
|
// Native, non-interactive identities are tried before any other identities.
|
||||||
//
|
//
|
||||||
// If no identity matches the encrypted file, the returned error will be of type
|
// Decrypt returns a Reader reading the decrypted plaintext of the age file read
|
||||||
// [NoIdentityMatchError].
|
// from src. If no identity matches the encrypted file, the returned error will
|
||||||
|
// be of type [NoIdentityMatchError].
|
||||||
func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
|
func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
|
||||||
hdr, payload, err := format.Parse(src)
|
hdr, payload, err := format.Parse(src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -278,6 +277,58 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
|
||||||
return stream.NewDecryptReader(streamKey(fileKey, nonce), payload)
|
return stream.NewDecryptReader(streamKey(fileKey, nonce), payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DecryptReaderAt decrypts a file encrypted to one or more identities.
|
||||||
|
// All identities will be tried until one successfully decrypts the file.
|
||||||
|
// Native, non-interactive identities are tried before any other identities.
|
||||||
|
//
|
||||||
|
// DecryptReaderAt takes an underlying [io.ReaderAt] and its total encrypted
|
||||||
|
// size, and returns a ReaderAt of the decrypted plaintext and the plaintext
|
||||||
|
// size. These can be used for example to instantiate an [io.SectionReader],
|
||||||
|
// which implements [io.Reader] and [io.Seeker]. Note that ReaderAt by
|
||||||
|
// definition disregards the seek position of src.
|
||||||
|
//
|
||||||
|
// The ReadAt method of the returned ReaderAt can be called concurrently.
|
||||||
|
// The ReaderAt will internally cache the most recently decrypted chunk.
|
||||||
|
// DecryptReaderAt reads and decrypts the final chunk before returning,
|
||||||
|
// to authenticate the plaintext size.
|
||||||
|
//
|
||||||
|
// If no identity matches the encrypted file, the returned error will be of
|
||||||
|
// type [NoIdentityMatchError].
|
||||||
|
func DecryptReaderAt(src io.ReaderAt, encryptedSize int64, identities ...Identity) (io.ReaderAt, int64, error) {
|
||||||
|
srcReader := io.NewSectionReader(src, 0, encryptedSize)
|
||||||
|
hdr, payload, err := format.Parse(srcReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to read header: %w", err)
|
||||||
|
}
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
if err := hdr.Marshal(buf); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to serialize header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileKey, err := decryptHdr(hdr, identities...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := make([]byte, streamNonceSize)
|
||||||
|
if _, err := io.ReadFull(payload, nonce); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to read nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadOffset := int64(buf.Len()) + int64(len(nonce))
|
||||||
|
payloadSize := encryptedSize - payloadOffset
|
||||||
|
plaintextSize, err := stream.PlaintextSize(payloadSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
payloadReaderAt := io.NewSectionReader(src, payloadOffset, payloadSize)
|
||||||
|
r, err := stream.NewDecryptReaderAt(streamKey(fileKey, nonce), payloadReaderAt, payloadSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return r, plaintextSize, nil
|
||||||
|
}
|
||||||
|
|
||||||
func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {
|
func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {
|
||||||
if len(identities) == 0 {
|
if len(identities) == 0 {
|
||||||
return nil, errors.New("no identities specified")
|
return nil, errors.New("no identities specified")
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ import (
|
||||||
"filippo.io/age/armor"
|
"filippo.io/age/armor"
|
||||||
"filippo.io/age/internal/format"
|
"filippo.io/age/internal/format"
|
||||||
"filippo.io/age/internal/stream"
|
"filippo.io/age/internal/stream"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Metadata struct {
|
type Metadata struct {
|
||||||
|
|
@ -88,9 +87,9 @@ func Inspect(r io.Reader, fileSize int64) (*Metadata, error) {
|
||||||
}
|
}
|
||||||
data.Sizes.Armor = tr.count - fileSize
|
data.Sizes.Armor = tr.count - fileSize
|
||||||
}
|
}
|
||||||
data.Sizes.Overhead = streamOverhead(fileSize - data.Sizes.Header)
|
data.Sizes.Overhead, err = streamOverhead(fileSize - data.Sizes.Header)
|
||||||
if data.Sizes.Overhead > fileSize-data.Sizes.Header {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("payload too small to be a valid age file")
|
return nil, fmt.Errorf("failed to compute stream overhead: %w", err)
|
||||||
}
|
}
|
||||||
data.Sizes.MinPayload = fileSize - data.Sizes.Header - data.Sizes.Overhead
|
data.Sizes.MinPayload = fileSize - data.Sizes.Header - data.Sizes.Overhead
|
||||||
data.Sizes.MaxPayload = data.Sizes.MinPayload
|
data.Sizes.MaxPayload = data.Sizes.MinPayload
|
||||||
|
|
@ -114,13 +113,15 @@ func (tr *trackReader) Read(p []byte) (int, error) {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamOverhead(payloadSize int64) int64 {
|
func streamOverhead(payloadSize int64) (int64, error) {
|
||||||
const streamNonceSize = 16
|
const streamNonceSize = 16
|
||||||
const encChunkSize = stream.ChunkSize + chacha20poly1305.Overhead
|
if payloadSize < streamNonceSize {
|
||||||
payloadSize -= streamNonceSize
|
return 0, fmt.Errorf("encrypted size too small: %d", payloadSize)
|
||||||
if payloadSize <= 0 {
|
|
||||||
return streamNonceSize
|
|
||||||
}
|
}
|
||||||
chunks := (payloadSize + encChunkSize - 1) / encChunkSize
|
encryptedSize := payloadSize - streamNonceSize
|
||||||
return streamNonceSize + chunks*chacha20poly1305.Overhead
|
plaintextSize, err := stream.PlaintextSize(encryptedSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return payloadSize - plaintextSize, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
46
internal/inspect/inspect_test.go
Normal file
46
internal/inspect/inspect_test.go
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
package inspect
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"filippo.io/age/internal/stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStreamOverhead(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
payloadSize int64
|
||||||
|
want int64
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{payloadSize: 0, wantErr: true},
|
||||||
|
{payloadSize: 15, wantErr: true},
|
||||||
|
{payloadSize: 16, wantErr: true},
|
||||||
|
{payloadSize: 16 + 15, wantErr: true},
|
||||||
|
{payloadSize: 16 + 16, want: 16 + 16}, // empty plaintext
|
||||||
|
{payloadSize: 16 + 1 + 16, want: 16 + 16},
|
||||||
|
{payloadSize: 16 + stream.ChunkSize + 16, want: 16 + 16},
|
||||||
|
{payloadSize: 16 + stream.ChunkSize + 16 + 1, wantErr: true},
|
||||||
|
{payloadSize: 16 + stream.ChunkSize + 16 + 15, wantErr: true},
|
||||||
|
{payloadSize: 16 + stream.ChunkSize + 16 + 16, wantErr: true}, // empty final chunk
|
||||||
|
{payloadSize: 16 + stream.ChunkSize + 16 + 1 + 16, want: 16 + 16 + 16},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
name := "payloadSize=" + fmt.Sprint(tt.payloadSize)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
got, gotErr := streamOverhead(tt.payloadSize)
|
||||||
|
if gotErr != nil {
|
||||||
|
if !tt.wantErr {
|
||||||
|
t.Errorf("streamOverhead() failed: %v", gotErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr {
|
||||||
|
t.Fatal("streamOverhead() succeeded unexpectedly")
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("streamOverhead() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -8,15 +8,42 @@ package stream
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ChunkSize = 64 * 1024
|
const ChunkSize = 64 * 1024
|
||||||
|
|
||||||
|
func EncryptedChunkCount(encryptedSize int64) (int64, error) {
|
||||||
|
chunks := (encryptedSize + encChunkSize - 1) / encChunkSize
|
||||||
|
|
||||||
|
plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead
|
||||||
|
expChunks := (plaintextSize + ChunkSize - 1) / ChunkSize
|
||||||
|
// Empty plaintext, the only case that allows (and requires) an empty chunk.
|
||||||
|
if plaintextSize == 0 {
|
||||||
|
expChunks = 1
|
||||||
|
}
|
||||||
|
if expChunks != chunks {
|
||||||
|
return 0, fmt.Errorf("invalid encrypted payload size: %d", encryptedSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PlaintextSize(encryptedSize int64) (int64, error) {
|
||||||
|
chunks, err := EncryptedChunkCount(encryptedSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead
|
||||||
|
return plaintextSize, nil
|
||||||
|
}
|
||||||
|
|
||||||
type DecryptReader struct {
|
type DecryptReader struct {
|
||||||
a cipher.AEAD
|
a cipher.AEAD
|
||||||
src io.Reader
|
src io.Reader
|
||||||
|
|
@ -135,6 +162,12 @@ func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
|
||||||
panic("stream: chunk counter wrapped around")
|
panic("stream: chunk counter wrapped around")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func nonceForChunk(chunkIndex int64) *[chacha20poly1305.NonceSize]byte {
|
||||||
|
var nonce [chacha20poly1305.NonceSize]byte
|
||||||
|
binary.BigEndian.PutUint64(nonce[3:11], uint64(chunkIndex))
|
||||||
|
return &nonce
|
||||||
|
}
|
||||||
|
|
||||||
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
|
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
|
||||||
nonce[len(nonce)-1] = lastChunkFlag
|
nonce[len(nonce)-1] = lastChunkFlag
|
||||||
}
|
}
|
||||||
|
|
@ -312,3 +345,102 @@ func (r *EncryptReader) feedBuffer() error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DecryptReaderAt struct {
|
||||||
|
a cipher.AEAD
|
||||||
|
src io.ReaderAt
|
||||||
|
size int64
|
||||||
|
chunks int64
|
||||||
|
cache atomic.Pointer[cachedChunk]
|
||||||
|
}
|
||||||
|
|
||||||
|
type cachedChunk struct {
|
||||||
|
off int64
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDecryptReaderAt(key []byte, src io.ReaderAt, size int64) (*DecryptReaderAt, error) {
|
||||||
|
aead, err := chacha20poly1305.New(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that size is valid by decrypting the final chunk.
|
||||||
|
chunks, err := EncryptedChunkCount(size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
finalChunkIndex := chunks - 1
|
||||||
|
finalChunkOff := finalChunkIndex * encChunkSize
|
||||||
|
finalChunkSize := size - finalChunkOff
|
||||||
|
finalChunk := make([]byte, finalChunkSize)
|
||||||
|
if _, err := src.ReadAt(finalChunk, finalChunkOff); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read final chunk: %w", err)
|
||||||
|
}
|
||||||
|
nonce := nonceForChunk(finalChunkIndex)
|
||||||
|
setLastChunkFlag(nonce)
|
||||||
|
plaintext, err := aead.Open(finalChunk[:0], nonce[:], finalChunk, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decrypt and authenticate final chunk: %w", err)
|
||||||
|
}
|
||||||
|
cache := &cachedChunk{off: finalChunkOff, data: plaintext}
|
||||||
|
|
||||||
|
plaintextSize := size - chunks*chacha20poly1305.Overhead
|
||||||
|
r := &DecryptReaderAt{a: aead, src: src, size: plaintextSize, chunks: chunks}
|
||||||
|
r.cache.Store(cache)
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||||
|
if off < 0 || off > r.size {
|
||||||
|
return 0, fmt.Errorf("offset out of range [0:%d]: %d", r.size, off)
|
||||||
|
}
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
chunk := make([]byte, encChunkSize)
|
||||||
|
for len(p) > 0 && off < r.size {
|
||||||
|
chunkIndex := off / ChunkSize
|
||||||
|
chunkOff := chunkIndex * encChunkSize
|
||||||
|
encSize := r.size + r.chunks*chacha20poly1305.Overhead
|
||||||
|
chunkSize := min(encSize-chunkOff, encChunkSize)
|
||||||
|
|
||||||
|
cached := r.cache.Load()
|
||||||
|
var plaintext []byte
|
||||||
|
if cached != nil && cached.off == chunkOff {
|
||||||
|
plaintext = cached.data
|
||||||
|
} else {
|
||||||
|
nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff)
|
||||||
|
if err == io.EOF {
|
||||||
|
if int64(nn) != chunkSize {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
} else {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return n, fmt.Errorf("failed to read chunk at offset %d: %w", chunkOff, err)
|
||||||
|
}
|
||||||
|
nonce := nonceForChunk(chunkIndex)
|
||||||
|
if chunkIndex == r.chunks-1 {
|
||||||
|
setLastChunkFlag(nonce)
|
||||||
|
}
|
||||||
|
plaintext, err = r.a.Open(chunk[:0], nonce[:], chunk[:chunkSize], nil)
|
||||||
|
if err != nil {
|
||||||
|
return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
|
||||||
|
}
|
||||||
|
r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext})
|
||||||
|
}
|
||||||
|
|
||||||
|
plainChunkOff := int(off - chunkIndex*ChunkSize)
|
||||||
|
copySize := min(len(plaintext)-plainChunkOff, len(p))
|
||||||
|
copy(p, plaintext[plainChunkOff:plainChunkOff+copySize])
|
||||||
|
p = p[copySize:]
|
||||||
|
off += int64(copySize)
|
||||||
|
n += copySize
|
||||||
|
}
|
||||||
|
if off == r.size {
|
||||||
|
return n, io.EOF
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
"testing/iotest"
|
||||||
|
|
||||||
"filippo.io/age/internal/stream"
|
"filippo.io/age/internal/stream"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
|
@ -20,13 +21,16 @@ const cs = stream.ChunkSize
|
||||||
func TestRoundTrip(t *testing.T) {
|
func TestRoundTrip(t *testing.T) {
|
||||||
for _, length := range []int{0, 1000, cs - 1, cs, cs + 1, cs + 100, 2 * cs, 2*cs + 500} {
|
for _, length := range []int{0, 1000, cs - 1, cs, cs + 1, cs + 100, 2 * cs, 2*cs + 500} {
|
||||||
for _, stepSize := range []int{512, 600, 1000, cs - 1, cs, cs + 1} {
|
for _, stepSize := range []int{512, 600, 1000, cs - 1, cs, cs + 1} {
|
||||||
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
|
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) {
|
||||||
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
|
testRoundTrip(t, stepSize, length)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
length, stepSize := 2*cs+500, 1
|
length, stepSize := 2*cs+500, 1
|
||||||
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
|
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) {
|
||||||
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
|
testRoundTrip(t, stepSize, length)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testRoundTrip(t *testing.T, stepSize, length int) {
|
func testRoundTrip(t *testing.T, stepSize, length int) {
|
||||||
|
|
@ -34,85 +38,753 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
|
||||||
if _, err := rand.Read(src); err != nil {
|
if _, err := rand.Read(src); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
buf := &bytes.Buffer{}
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var ciphertext []byte
|
||||||
|
|
||||||
|
t.Run("EncryptWriter", func(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int
|
||||||
|
for n < length {
|
||||||
|
b := min(length-n, stepSize)
|
||||||
|
nn, err := w.Write(src[n : n+b])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if nn != b {
|
||||||
|
t.Errorf("Write returned %d, expected %d", nn, b)
|
||||||
|
}
|
||||||
|
n += nn
|
||||||
|
|
||||||
|
nn, err = w.Write(src[n:n])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if nn != 0 {
|
||||||
|
t.Errorf("Write returned %d, expected 0", nn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Error("Close returned an error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext = buf.Bytes()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DecryptReader", func(t *testing.T) {
|
||||||
|
r, err := stream.NewDecryptReader(key, bytes.NewReader(ciphertext))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int
|
||||||
|
readBuf := make([]byte, stepSize)
|
||||||
|
for n < length {
|
||||||
|
nn, err := r.Read(readBuf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Read error at index %d: %v", n, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
|
||||||
|
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
|
||||||
|
}
|
||||||
|
|
||||||
|
n += nn
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("TestReader", func(t *testing.T) {
|
||||||
|
if length > 1000 && testing.Short() {
|
||||||
|
t.Skip("skipping slow iotest.TestReader on long input")
|
||||||
|
}
|
||||||
|
r, _ := stream.NewDecryptReader(key, bytes.NewReader(ciphertext))
|
||||||
|
if err := iotest.TestReader(r, src); err != nil {
|
||||||
|
t.Error("iotest.TestReader error on DecryptReader:", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DecryptReaderAt", func(t *testing.T) {
|
||||||
|
rAt, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
rr := io.NewSectionReader(rAt, 0, int64(len(ciphertext)))
|
||||||
|
|
||||||
|
var n int
|
||||||
|
readBuf := make([]byte, stepSize)
|
||||||
|
for n < length {
|
||||||
|
nn, err := rr.Read(readBuf)
|
||||||
|
if n+nn == length && err == io.EOF {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAt error at index %d: %v", n, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
|
||||||
|
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
|
||||||
|
}
|
||||||
|
|
||||||
|
n += nn
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("TestReader", func(t *testing.T) {
|
||||||
|
if length > 1000 && testing.Short() {
|
||||||
|
t.Skip("skipping slow iotest.TestReader on long input")
|
||||||
|
}
|
||||||
|
rr := io.NewSectionReader(rAt, 0, int64(len(src)))
|
||||||
|
if err := iotest.TestReader(rr, src); err != nil {
|
||||||
|
t.Error("iotest.TestReader error on DecryptReaderAt:", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EncryptReader", func(t *testing.T) {
|
||||||
|
er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int
|
||||||
|
readBuf := make([]byte, stepSize)
|
||||||
|
for {
|
||||||
|
nn, err := er.Read(readBuf)
|
||||||
|
if nn == 0 && err == io.EOF {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
t.Fatalf("EncryptReader Read error at index %d: %v", n, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) {
|
||||||
|
t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn)
|
||||||
|
}
|
||||||
|
|
||||||
|
n += nn
|
||||||
|
}
|
||||||
|
if n != len(ciphertext) {
|
||||||
|
t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext))
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("TestReader", func(t *testing.T) {
|
||||||
|
if length > 1000 && testing.Short() {
|
||||||
|
t.Skip("skipping slow iotest.TestReader on long input")
|
||||||
|
}
|
||||||
|
er, _ := stream.NewEncryptReader(key, bytes.NewReader(src))
|
||||||
|
if err := iotest.TestReader(er, ciphertext); err != nil {
|
||||||
|
t.Error("iotest.TestReader error on EncryptReader:", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// trackingReaderAt wraps an io.ReaderAt and tracks whether ReadAt was called.
|
||||||
|
type trackingReaderAt struct {
|
||||||
|
r io.ReaderAt
|
||||||
|
called bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *trackingReaderAt) ReadAt(p []byte, off int64) (int, error) {
|
||||||
|
t.called = true
|
||||||
|
return t.r.ReadAt(p, off)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *trackingReaderAt) reset() {
|
||||||
|
t.called = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAt(t *testing.T) {
|
||||||
key := make([]byte, chacha20poly1305.KeySize)
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
if _, err := rand.Read(key); err != nil {
|
if _, err := rand.Read(key); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create plaintext spanning exactly 3 chunks: 2 full chunks + partial third
|
||||||
|
// Chunk 0: [0, cs)
|
||||||
|
// Chunk 1: [cs, 2*cs)
|
||||||
|
// Chunk 2: [2*cs, 2*cs+500)
|
||||||
|
plaintextSize := 2*cs + 500
|
||||||
|
plaintext := make([]byte, plaintextSize)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
w, err := stream.NewEncryptWriter(key, buf)
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
var n int
|
t.Fatal(err)
|
||||||
for n < length {
|
|
||||||
b := min(length-n, stepSize)
|
|
||||||
nn, err := w.Write(src[n : n+b])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if nn != b {
|
|
||||||
t.Errorf("Write returned %d, expected %d", nn, b)
|
|
||||||
}
|
|
||||||
n += nn
|
|
||||||
|
|
||||||
nn, err = w.Write(src[n:n])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if nn != 0 {
|
|
||||||
t.Errorf("Write returned %d, expected 0", nn)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := w.Close(); err != nil {
|
if err := w.Close(); err != nil {
|
||||||
t.Error("Close returned an error:", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
ciphertext := buf.Bytes()
|
||||||
|
|
||||||
t.Logf("buffer size: %d", buf.Len())
|
// Create tracking ReaderAt
|
||||||
ciphertext := bytes.Clone(buf.Bytes())
|
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
|
||||||
|
|
||||||
r, err := stream.NewDecryptReader(key, buf)
|
// Create DecryptReaderAt (this reads and caches the final chunk)
|
||||||
|
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
tracker.reset()
|
||||||
|
|
||||||
n = 0
|
// Helper to check reads
|
||||||
readBuf := make([]byte, stepSize)
|
checkRead := func(name string, off int64, size int, wantN int, wantEOF bool, wantSrcRead bool) {
|
||||||
for n < length {
|
t.Helper()
|
||||||
nn, err := r.Read(readBuf)
|
tracker.reset()
|
||||||
if err != nil {
|
p := make([]byte, size)
|
||||||
t.Fatalf("Read error at index %d: %v", n, err)
|
n, err := ra.ReadAt(p, off)
|
||||||
|
|
||||||
|
if wantEOF {
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Errorf("%s: got err=%v, want EOF", name, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%s: got err=%v, want nil", name, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
|
if n != wantN {
|
||||||
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
|
t.Errorf("%s: got n=%d, want %d", name, n, wantN)
|
||||||
}
|
}
|
||||||
|
|
||||||
n += nn
|
if tracker.called != wantSrcRead {
|
||||||
|
t.Errorf("%s: src.ReadAt called=%v, want %v", name, tracker.called, wantSrcRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify data correctness
|
||||||
|
if n > 0 && off >= 0 && off < int64(plaintextSize) {
|
||||||
|
end := int(off) + n
|
||||||
|
if end > plaintextSize {
|
||||||
|
end = plaintextSize
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p[:n], plaintext[off:end]) {
|
||||||
|
t.Errorf("%s: data mismatch", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
|
// Test 1: Read from final chunk (cached by constructor)
|
||||||
|
checkRead("final chunk (cached)", int64(2*cs+100), 100, 100, false, false)
|
||||||
|
|
||||||
|
// Test 2: Read spanning second and third chunk
|
||||||
|
checkRead("span chunks 1-2", int64(cs+cs-50), 100, 100, false, true)
|
||||||
|
|
||||||
|
// Test 3: Read from final chunk again (cached from test 2)
|
||||||
|
// When reading across chunks 1-2 in test 2, the loop processes chunk 1 then chunk 2,
|
||||||
|
// so chunk 2 ends up in the cache.
|
||||||
|
checkRead("final chunk after span", int64(2*cs+200), 100, 100, false, false)
|
||||||
|
|
||||||
|
// Test 4: Read from final chunk again (now cached)
|
||||||
|
checkRead("final chunk (cached again)", int64(2*cs+50), 50, 50, false, false)
|
||||||
|
|
||||||
|
// Test 5: Read from first chunk (not cached)
|
||||||
|
checkRead("first chunk", 0, 100, 100, false, true)
|
||||||
|
|
||||||
|
// Test 6: Read from first chunk again (now cached)
|
||||||
|
checkRead("first chunk (cached)", 50, 100, 100, false, false)
|
||||||
|
|
||||||
|
// Test 7: Read spanning all chunks
|
||||||
|
tracker.reset()
|
||||||
|
p := make([]byte, plaintextSize)
|
||||||
|
n, err := ra.ReadAt(p, 0)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Errorf("span all: got err=%v, want EOF", err)
|
||||||
|
}
|
||||||
|
if n != plaintextSize {
|
||||||
|
t.Errorf("span all: got n=%d, want %d", n, plaintextSize)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p, plaintext) {
|
||||||
|
t.Errorf("span all: data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 8: Read beyond the end (offset > size)
|
||||||
|
tracker.reset()
|
||||||
|
p = make([]byte, 100)
|
||||||
|
n, err = ra.ReadAt(p, int64(plaintextSize+100))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("beyond end: expected error, got nil")
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("beyond end: got n=%d, want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 9: Read with off = size (should return 0, EOF)
|
||||||
|
tracker.reset()
|
||||||
|
p = make([]byte, 100)
|
||||||
|
n, err = ra.ReadAt(p, int64(plaintextSize))
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Errorf("off=size: got err=%v, want EOF", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("off=size: got n=%d, want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 10: Read spanning last chunk and beyond
|
||||||
|
tracker.reset()
|
||||||
|
p = make([]byte, 1000) // request more than available
|
||||||
|
n, err = ra.ReadAt(p, int64(2*cs+400))
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Errorf("span last+beyond: got err=%v, want EOF", err)
|
||||||
|
}
|
||||||
|
wantN := 500 - 400 // only 100 bytes available from offset 2*cs+400
|
||||||
|
if n != wantN {
|
||||||
|
t.Errorf("span last+beyond: got n=%d, want %d", n, wantN)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p[:n], plaintext[2*cs+400:]) {
|
||||||
|
t.Error("span last+beyond: data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 11: Read spanning second+last chunk and beyond
|
||||||
|
tracker.reset()
|
||||||
|
p = make([]byte, cs+1000) // request more than available
|
||||||
|
n, err = ra.ReadAt(p, int64(cs+100))
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Errorf("span 1-2+beyond: got err=%v, want EOF", err)
|
||||||
|
}
|
||||||
|
wantN = plaintextSize - (cs + 100)
|
||||||
|
if n != wantN {
|
||||||
|
t.Errorf("span 1-2+beyond: got n=%d, want %d", n, wantN)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p[:n], plaintext[cs+100:]) {
|
||||||
|
t.Error("span 1-2+beyond: data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 12: Negative offset
|
||||||
|
tracker.reset()
|
||||||
|
p = make([]byte, 100)
|
||||||
|
n, err = ra.ReadAt(p, -1)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("negative offset: expected error, got nil")
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("negative offset: got n=%d, want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 13: Zero-length read in the middle
|
||||||
|
tracker.reset()
|
||||||
|
p = make([]byte, 0)
|
||||||
|
n, err = ra.ReadAt(p, 100)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("zero-length middle: got err=%v, want nil", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("zero-length middle: got n=%d, want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 14: Zero-length read at end
|
||||||
|
tracker.reset()
|
||||||
|
p = make([]byte, 0)
|
||||||
|
n, err = ra.ReadAt(p, int64(plaintextSize))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("zero-length end: got err=%v, want nil", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("zero-length end: got n=%d, want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 15: Read exactly one chunk at chunk boundary
|
||||||
|
checkRead("exact chunk at boundary", int64(cs), cs, cs, false, true)
|
||||||
|
|
||||||
|
// Test 16: Read one byte at each chunk boundary
|
||||||
|
checkRead("one byte at start", 0, 1, 1, false, true)
|
||||||
|
checkRead("one byte at cs-1", int64(cs-1), 1, 1, false, false) // cached from test 15
|
||||||
|
checkRead("one byte at cs", int64(cs), 1, 1, false, true)
|
||||||
|
checkRead("one byte at 2*cs-1", int64(2*cs-1), 1, 1, false, false) // same chunk
|
||||||
|
checkRead("one byte at 2*cs", int64(2*cs), 1, 1, false, true)
|
||||||
|
checkRead("last byte", int64(plaintextSize-1), 1, 1, true, false) // same chunk, EOF because we reach end
|
||||||
|
|
||||||
|
// Test 17: Read crossing exactly one chunk boundary
|
||||||
|
checkRead("cross boundary 0-1", int64(cs-50), 100, 100, false, true)
|
||||||
|
checkRead("cross boundary 1-2", int64(2*cs-50), 100, 100, false, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAtEmpty(t *testing.T) {
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create empty encrypted file
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
n = 0
|
if err := w.Close(); err != nil {
|
||||||
for {
|
t.Fatal(err)
|
||||||
nn, err := er.Read(readBuf)
|
|
||||||
if nn == 0 && err == io.EOF {
|
|
||||||
break
|
|
||||||
} else if err != nil {
|
|
||||||
t.Fatalf("EncryptReader Read error at index %d: %v", n, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) {
|
|
||||||
t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn)
|
|
||||||
}
|
|
||||||
|
|
||||||
n += nn
|
|
||||||
}
|
}
|
||||||
if n != len(ciphertext) {
|
ciphertext := buf.Bytes()
|
||||||
t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext))
|
|
||||||
|
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
|
||||||
|
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tracker.reset()
|
||||||
|
|
||||||
|
// Test 1: Read from empty file at offset 0
|
||||||
|
p := make([]byte, 100)
|
||||||
|
n, err := ra.ReadAt(p, 0)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Errorf("empty read: got err=%v, want EOF", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("empty read: got n=%d, want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 2: Zero-length read from empty file
|
||||||
|
p = make([]byte, 0)
|
||||||
|
n, err = ra.ReadAt(p, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("empty zero-length: got err=%v, want nil", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("empty zero-length: got n=%d, want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 3: Read beyond empty file
|
||||||
|
p = make([]byte, 100)
|
||||||
|
n, err = ra.ReadAt(p, 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("empty beyond: expected error, got nil")
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("empty beyond: got n=%d, want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAtSingleChunk(t *testing.T) {
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single chunk, not full
|
||||||
|
plaintext := make([]byte, 1000)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ciphertext := buf.Bytes()
|
||||||
|
|
||||||
|
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
|
||||||
|
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tracker.reset()
|
||||||
|
|
||||||
|
// All reads should use cache (final chunk = only chunk)
|
||||||
|
p := make([]byte, 100)
|
||||||
|
n, err := ra.ReadAt(p, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("single chunk start: got err=%v, want nil", err)
|
||||||
|
}
|
||||||
|
if n != 100 {
|
||||||
|
t.Errorf("single chunk start: got n=%d, want 100", n)
|
||||||
|
}
|
||||||
|
if tracker.called {
|
||||||
|
t.Error("single chunk start: unexpected src.ReadAt call")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p[:n], plaintext[:100]) {
|
||||||
|
t.Error("single chunk start: data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read at end
|
||||||
|
n, err = ra.ReadAt(p, 900)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Errorf("single chunk end: got err=%v, want EOF", err)
|
||||||
|
}
|
||||||
|
if n != 100 {
|
||||||
|
t.Errorf("single chunk end: got n=%d, want 100", n)
|
||||||
|
}
|
||||||
|
if tracker.called {
|
||||||
|
t.Error("single chunk end: unexpected src.ReadAt call")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p[:n], plaintext[900:]) {
|
||||||
|
t.Error("single chunk end: data mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAtFullChunks(t *testing.T) {
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exactly 2 full chunks
|
||||||
|
plaintext := make([]byte, 2*cs)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ciphertext := buf.Bytes()
|
||||||
|
|
||||||
|
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
|
||||||
|
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tracker.reset()
|
||||||
|
|
||||||
|
// Read last byte of second chunk (cached)
|
||||||
|
p := make([]byte, 1)
|
||||||
|
n, err := ra.ReadAt(p, int64(2*cs-1))
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Errorf("last byte: got err=%v, want EOF", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Errorf("last byte: got n=%d, want 1", n)
|
||||||
|
}
|
||||||
|
if tracker.called {
|
||||||
|
t.Error("last byte: unexpected src.ReadAt call (should be cached)")
|
||||||
|
}
|
||||||
|
if p[0] != plaintext[2*cs-1] {
|
||||||
|
t.Error("last byte: data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read at exactly the boundary between chunks
|
||||||
|
p = make([]byte, 100)
|
||||||
|
n, err = ra.ReadAt(p, int64(cs-50))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("boundary: got err=%v, want nil", err)
|
||||||
|
}
|
||||||
|
if n != 100 {
|
||||||
|
t.Errorf("boundary: got n=%d, want 100", n)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p, plaintext[cs-50:cs+50]) {
|
||||||
|
t.Error("boundary: data mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAtWrongKey(t *testing.T) {
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, 1000)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ciphertext := buf.Bytes()
|
||||||
|
|
||||||
|
// Try to decrypt with wrong key
|
||||||
|
wrongKey := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(wrongKey); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = stream.NewDecryptReaderAt(wrongKey, bytes.NewReader(ciphertext), int64(len(ciphertext)))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("wrong key: expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAtInvalidSize(t *testing.T) {
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, 1000)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ciphertext := buf.Bytes()
|
||||||
|
|
||||||
|
// Wrong size (too small)
|
||||||
|
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)-1))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("wrong size (small): expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrong size (too large)
|
||||||
|
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)+1))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("wrong size (large): expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size that would imply empty final chunk (invalid)
|
||||||
|
// This would be: one full encrypted chunk + just overhead
|
||||||
|
invalidSize := int64(cs + chacha20poly1305.Overhead + chacha20poly1305.Overhead)
|
||||||
|
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(make([]byte, invalidSize)), invalidSize)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("invalid size (empty final chunk): expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAtTruncated(t *testing.T) {
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, 2*cs+500)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ciphertext := buf.Bytes()
|
||||||
|
|
||||||
|
// Truncate ciphertext but lie about size
|
||||||
|
truncated := ciphertext[:len(ciphertext)-100]
|
||||||
|
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(truncated), int64(len(ciphertext)))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("truncated: expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAtTruncatedChunk(t *testing.T) {
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 4 chunks: 3 full + 1 partial
|
||||||
|
plaintext := make([]byte, 3*cs+500)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ciphertext := buf.Bytes()
|
||||||
|
|
||||||
|
// Truncate to 3 chunks (remove the actual final chunk)
|
||||||
|
// The third chunk was NOT encrypted with the last chunk flag,
|
||||||
|
// so decryption should fail when we try to use it as the final chunk.
|
||||||
|
encChunkSize := cs + 16 // ChunkSize + Overhead
|
||||||
|
truncatedSize := int64(3 * encChunkSize)
|
||||||
|
truncated := ciphertext[:truncatedSize]
|
||||||
|
|
||||||
|
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(truncated), truncatedSize)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("truncated at chunk boundary: expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAtCorrupted(t *testing.T) {
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, 2*cs+500)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ciphertext := bytes.Clone(buf.Bytes())
|
||||||
|
|
||||||
|
// Corrupt final chunk - should fail in constructor
|
||||||
|
corruptedFinal := bytes.Clone(ciphertext)
|
||||||
|
corruptedFinal[len(corruptedFinal)-10] ^= 0xFF
|
||||||
|
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFinal), int64(len(corruptedFinal)))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("corrupted final: expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Corrupt first chunk - should fail on read
|
||||||
|
corruptedFirst := bytes.Clone(ciphertext)
|
||||||
|
corruptedFirst[10] ^= 0xFF
|
||||||
|
ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFirst), int64(len(corruptedFirst)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("corrupted first constructor: unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
p := make([]byte, 100)
|
||||||
|
_, err = ra.ReadAt(p, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("corrupted first read: expected error, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
192
testkit_test.go
192
testkit_test.go
|
|
@ -140,10 +140,16 @@ func parseVector(t *testing.T, test []byte) *vector {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVectors(t *testing.T) {
|
func TestVectors(t *testing.T) {
|
||||||
forEachVector(t, testVector)
|
forEachVector(t, func(t *testing.T, v *vector) {
|
||||||
|
var plaintext []byte
|
||||||
|
t.Run("Decrypt", func(t *testing.T) { plaintext = testDecrypt(t, v) })
|
||||||
|
t.Run("DecryptReaderAt", func(t *testing.T) { testDecryptReaderAt(t, v, plaintext) })
|
||||||
|
t.Run("Inspect", func(t *testing.T) { testInspect(t, v, plaintext) })
|
||||||
|
t.Run("RoundTrip", func(t *testing.T) { testVectorRoundTrip(t, v) })
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testVector(t *testing.T, v *vector) {
|
func testDecrypt(t *testing.T, v *vector) []byte {
|
||||||
var in io.Reader = bytes.NewReader(v.file)
|
var in io.Reader = bytes.NewReader(v.file)
|
||||||
if v.armored {
|
if v.armored {
|
||||||
in = armor.NewReader(in)
|
in = armor.NewReader(in)
|
||||||
|
|
@ -152,25 +158,25 @@ func testVector(t *testing.T, v *vector) {
|
||||||
if err != nil && strings.HasSuffix(err.Error(), "bad header MAC") {
|
if err != nil && strings.HasSuffix(err.Error(), "bad header MAC") {
|
||||||
if v.expect == "HMAC failure" {
|
if v.expect == "HMAC failure" {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
t.Fatalf("expected %s, got HMAC error", v.expect)
|
t.Fatalf("expected %s, got HMAC error", v.expect)
|
||||||
} else if e := new(armor.Error); errors.As(err, &e) {
|
} else if e := new(armor.Error); errors.As(err, &e) {
|
||||||
if v.expect == "armor failure" {
|
if v.expect == "armor failure" {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
t.Fatalf("expected %s, got: %v", v.expect, err)
|
t.Fatalf("expected %s, got: %v", v.expect, err)
|
||||||
} else if _, ok := err.(*age.NoIdentityMatchError); ok {
|
} else if _, ok := err.(*age.NoIdentityMatchError); ok {
|
||||||
if v.expect == "no match" {
|
if v.expect == "no match" {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
t.Fatalf("expected %s, got: %v", v.expect, err)
|
t.Fatalf("expected %s, got: %v", v.expect, err)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
if v.expect == "header failure" {
|
if v.expect == "header failure" {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
t.Fatalf("expected %s, got: %v", v.expect, err)
|
t.Fatalf("expected %s, got: %v", v.expect, err)
|
||||||
} else if v.expect != "success" && v.expect != "payload failure" &&
|
} else if v.expect != "success" && v.expect != "payload failure" &&
|
||||||
|
|
@ -188,15 +194,77 @@ func testVector(t *testing.T, v *vector) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v.payloadHash != nil && sha256.Sum256(out) != *v.payloadHash {
|
if v.payloadHash != nil && sha256.Sum256(out) != *v.payloadHash {
|
||||||
t.Error("partial payload hash mismatch")
|
t.Errorf("partial payload hash mismatch, read %d bytes", len(out))
|
||||||
}
|
}
|
||||||
return
|
return out
|
||||||
} else if v.expect != "success" {
|
} else if v.expect != "success" {
|
||||||
t.Fatalf("expected %s, got success", v.expect)
|
t.Fatalf("expected %s, got success", v.expect)
|
||||||
}
|
}
|
||||||
if sha256.Sum256(out) != *v.payloadHash {
|
if sha256.Sum256(out) != *v.payloadHash {
|
||||||
t.Error("payload hash mismatch")
|
t.Error("payload hash mismatch")
|
||||||
}
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func testDecryptReaderAt(t *testing.T, v *vector, plaintext []byte) {
|
||||||
|
if v.armored {
|
||||||
|
t.Skip("armor.NewReader does not implement ReaderAt")
|
||||||
|
}
|
||||||
|
rAt, s, err := age.DecryptReaderAt(bytes.NewReader(v.file), int64(len(v.file)), v.identities...)
|
||||||
|
switch v.expect {
|
||||||
|
case "success":
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected success, got: %v", err)
|
||||||
|
}
|
||||||
|
if int64(len(plaintext)) != s {
|
||||||
|
t.Errorf("unexpected size: got %d, want %d", s, len(plaintext))
|
||||||
|
}
|
||||||
|
case "payload failure":
|
||||||
|
// DecryptReaderAt detects some (but not all) payload failures upfront,
|
||||||
|
// either from the size of the payload, or by decrypting the last chunk
|
||||||
|
// to authenticate its size.
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("expected %s, got success", v.expect)
|
||||||
|
}
|
||||||
|
out, err := io.ReadAll(io.NewSectionReader(rAt, 0, s))
|
||||||
|
if v.expect == "success" {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected success, got: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected %s, got success", v.expect)
|
||||||
|
}
|
||||||
|
t.Log(err)
|
||||||
|
// We can't check the partial payload hash, because the ReaderAt will
|
||||||
|
// notice errors that a linearly scanning Reader could not. For example,
|
||||||
|
// if there are two final chunks, the linear Reader will decrypt the
|
||||||
|
// first one and then error out on the second, while the ReaderAt will
|
||||||
|
// decrypt the second one to check the size, and then know that the
|
||||||
|
// first chunk could not be the last one. Instead, check that the
|
||||||
|
// prefix, if any, matches.
|
||||||
|
if !bytes.HasPrefix(plaintext, out) {
|
||||||
|
t.Errorf("partial payload prefix mismatch, read %d bytes", len(out))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if sha256.Sum256(out) != *v.payloadHash {
|
||||||
|
t.Error("payload hash mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInspect(t *testing.T, v *vector, plaintext []byte) {
|
||||||
|
if v.expect != "success" {
|
||||||
|
t.Skip("invalid file, can't inspect")
|
||||||
|
}
|
||||||
for _, fileSize := range []int64{int64(len(v.file)), -1} {
|
for _, fileSize := range []int64{int64(len(v.file)), -1} {
|
||||||
metadata, err := inspect.Inspect(bytes.NewReader(v.file), fileSize)
|
metadata, err := inspect.Inspect(bytes.NewReader(v.file), fileSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -211,8 +279,8 @@ func testVector(t *testing.T, v *vector) {
|
||||||
if metadata.Sizes.Armor+metadata.Sizes.Header+metadata.Sizes.Overhead+metadata.Sizes.MinPayload != int64(len(v.file)) {
|
if metadata.Sizes.Armor+metadata.Sizes.Header+metadata.Sizes.Overhead+metadata.Sizes.MinPayload != int64(len(v.file)) {
|
||||||
t.Errorf("size breakdown does not add up to file size")
|
t.Errorf("size breakdown does not add up to file size")
|
||||||
}
|
}
|
||||||
if metadata.Sizes.MinPayload != int64(len(out)) {
|
if metadata.Sizes.MinPayload != int64(len(plaintext)) {
|
||||||
t.Errorf("unexpected payload size: got %d, want %d", metadata.Sizes.MinPayload, len(out))
|
t.Errorf("unexpected payload size: got %d, want %d", metadata.Sizes.MinPayload, len(plaintext))
|
||||||
}
|
}
|
||||||
if metadata.Sizes.MaxPayload != metadata.Sizes.MinPayload {
|
if metadata.Sizes.MaxPayload != metadata.Sizes.MinPayload {
|
||||||
t.Errorf("unexpected max payload size: got %d, want %d", metadata.Sizes.MaxPayload, metadata.Sizes.MinPayload)
|
t.Errorf("unexpected max payload size: got %d, want %d", metadata.Sizes.MaxPayload, metadata.Sizes.MinPayload)
|
||||||
|
|
@ -223,16 +291,12 @@ func testVector(t *testing.T, v *vector) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestVectorsRoundTrip checks that any (valid) armor, header, and/or STREAM
|
// testVectorsRoundTrip checks that any (valid) armor, header, and/or STREAM
|
||||||
// payload in the test vectors re-encodes identically.
|
// payload in the test vectors re-encodes identically.
|
||||||
func TestVectorsRoundTrip(t *testing.T) {
|
|
||||||
forEachVector(t, testVectorRoundTrip)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testVectorRoundTrip(t *testing.T, v *vector) {
|
func testVectorRoundTrip(t *testing.T, v *vector) {
|
||||||
if v.armored {
|
if v.armored {
|
||||||
if v.expect == "armor failure" {
|
if v.expect == "armor failure" {
|
||||||
t.SkipNow()
|
t.Skip("invalid armor, nothing to round-trip")
|
||||||
}
|
}
|
||||||
t.Run("armor", func(t *testing.T) {
|
t.Run("armor", func(t *testing.T) {
|
||||||
payload, err := io.ReadAll(armor.NewReader(bytes.NewReader(v.file)))
|
payload, err := io.ReadAll(armor.NewReader(bytes.NewReader(v.file)))
|
||||||
|
|
@ -261,7 +325,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.expect == "header failure" {
|
if v.expect == "header failure" {
|
||||||
t.SkipNow()
|
t.Skip("invalid header, nothing to round-trip")
|
||||||
}
|
}
|
||||||
hdr, p, err := format.Parse(bytes.NewReader(v.file))
|
hdr, p, err := format.Parse(bytes.NewReader(v.file))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -283,46 +347,62 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
if v.expect == "success" {
|
if v.expect != "success" {
|
||||||
t.Run("STREAM", func(t *testing.T) {
|
return
|
||||||
nonce, payload := payload[:16], payload[16:]
|
|
||||||
key := streamKey(v.fileKey[:], nonce)
|
|
||||||
r, err := stream.NewDecryptReader(key, bytes.NewReader(payload))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
plaintext, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
w, err := stream.NewEncryptWriter(key, buf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if _, err := w.Write(plaintext); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := w.Close(); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(buf.Bytes(), payload) {
|
|
||||||
t.Error("got a different STREAM ciphertext")
|
|
||||||
}
|
|
||||||
buf.Reset()
|
|
||||||
er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
ciphertext, err := io.ReadAll(er)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(ciphertext, payload) {
|
|
||||||
t.Error("got a different STREAM ciphertext from EncryptReader")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("STREAM", func(t *testing.T) {
|
||||||
|
nonce, payload := payload[:16], payload[16:]
|
||||||
|
key := streamKey(v.fileKey[:], nonce)
|
||||||
|
|
||||||
|
r, err := stream.NewDecryptReader(key, bytes.NewReader(payload))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
plaintext, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rAt, err := stream.NewDecryptReaderAt(key, bytes.NewReader(payload), int64(len(payload)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
plaintextAt, err := io.ReadAll(io.NewSectionReader(rAt, 0, int64(len(plaintext))))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(plaintextAt, plaintext) {
|
||||||
|
t.Errorf("got a different plaintext from DecryptReaderAt")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(buf.Bytes(), payload) {
|
||||||
|
t.Error("got a different STREAM ciphertext")
|
||||||
|
}
|
||||||
|
|
||||||
|
er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ciphertext, err := io.ReadAll(er)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(ciphertext, payload) {
|
||||||
|
t.Error("got a different STREAM ciphertext from EncryptReader")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamKey(fileKey, nonce []byte) []byte {
|
func streamKey(fileKey, nonce []byte) []byte {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue