mirror of
https://github.com/FiloSottile/age.git
synced 2026-03-11 08:55:41 +00:00
internal/stream: fix DecryptReaderAt concurrency
This commit is contained in:
parent
da2191789a
commit
420273952a
2 changed files with 168 additions and 1 deletions
|
|
@ -398,6 +398,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||||
if len(p) == 0 {
|
if len(p) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
var cacheUpdate *cachedChunk
|
||||||
chunk := make([]byte, encChunkSize)
|
chunk := make([]byte, encChunkSize)
|
||||||
for len(p) > 0 && off < r.size {
|
for len(p) > 0 && off < r.size {
|
||||||
chunkIndex := off / ChunkSize
|
chunkIndex := off / ChunkSize
|
||||||
|
|
@ -409,6 +410,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||||
var plaintext []byte
|
var plaintext []byte
|
||||||
if cached != nil && cached.off == chunkOff {
|
if cached != nil && cached.off == chunkOff {
|
||||||
plaintext = cached.data
|
plaintext = cached.data
|
||||||
|
cacheUpdate = nil
|
||||||
} else {
|
} else {
|
||||||
nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff)
|
nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
|
|
@ -429,7 +431,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
|
return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
|
||||||
}
|
}
|
||||||
r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext})
|
cacheUpdate = &cachedChunk{off: chunkOff, data: plaintext}
|
||||||
}
|
}
|
||||||
|
|
||||||
plainChunkOff := int(off - chunkIndex*ChunkSize)
|
plainChunkOff := int(off - chunkIndex*ChunkSize)
|
||||||
|
|
@ -439,6 +441,9 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||||
off += int64(copySize)
|
off += int64(copySize)
|
||||||
n += copySize
|
n += copySize
|
||||||
}
|
}
|
||||||
|
if cacheUpdate != nil {
|
||||||
|
r.cache.Store(cacheUpdate)
|
||||||
|
}
|
||||||
if off == r.size {
|
if off == r.size {
|
||||||
return n, io.EOF
|
return n, io.EOF
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -743,6 +743,168 @@ func TestDecryptReaderAtTruncatedChunk(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecryptReaderAtConcurrent(t *testing.T) {
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create plaintext spanning 3 chunks: 2 full + partial
|
||||||
|
plaintextSize := 2*cs + 500
|
||||||
|
plaintext := make([]byte, plaintextSize)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, err := stream.NewEncryptWriter(key, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ciphertext := buf.Bytes()
|
||||||
|
|
||||||
|
ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("same chunk", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
const goroutines = 10
|
||||||
|
const iterations = 100
|
||||||
|
errc := make(chan error, goroutines)
|
||||||
|
|
||||||
|
for g := range goroutines {
|
||||||
|
go func(id int) {
|
||||||
|
for i := range iterations {
|
||||||
|
off := int64((id*iterations + i) % 500)
|
||||||
|
p := make([]byte, 100)
|
||||||
|
n, err := ra.ReadAt(p, off)
|
||||||
|
if err != nil {
|
||||||
|
errc <- fmt.Errorf("goroutine %d iter %d: %v", id, i, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n != 100 {
|
||||||
|
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want 100", id, i, n)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p, plaintext[off:off+100]) {
|
||||||
|
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errc <- nil
|
||||||
|
}(g)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range goroutines {
|
||||||
|
if err := <-errc; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different chunks", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
const goroutines = 10
|
||||||
|
const iterations = 100
|
||||||
|
errc := make(chan error, goroutines)
|
||||||
|
|
||||||
|
for g := range goroutines {
|
||||||
|
go func(id int) {
|
||||||
|
for i := range iterations {
|
||||||
|
// Each goroutine reads from a different chunk based on id
|
||||||
|
chunkIdx := id % 3
|
||||||
|
off := int64(chunkIdx*cs + (i % 400))
|
||||||
|
size := 100
|
||||||
|
if off+int64(size) > int64(plaintextSize) {
|
||||||
|
size = plaintextSize - int(off)
|
||||||
|
}
|
||||||
|
p := make([]byte, size)
|
||||||
|
n, err := ra.ReadAt(p, off)
|
||||||
|
if n == size && err == io.EOF {
|
||||||
|
err = nil // EOF at end is acceptable
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
errc <- fmt.Errorf("goroutine %d iter %d: off=%d: %v", id, i, off, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n != size {
|
||||||
|
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
|
||||||
|
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errc <- nil
|
||||||
|
}(g)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range goroutines {
|
||||||
|
if err := <-errc; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("across chunks", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
const goroutines = 10
|
||||||
|
const iterations = 100
|
||||||
|
errc := make(chan error, goroutines)
|
||||||
|
|
||||||
|
for g := range goroutines {
|
||||||
|
go func(id int) {
|
||||||
|
for i := range iterations {
|
||||||
|
// Read across chunk boundaries
|
||||||
|
boundary := (id%2 + 1) * cs // either cs or 2*cs
|
||||||
|
off := int64(boundary - 50 + (i % 30))
|
||||||
|
size := 100
|
||||||
|
if off+int64(size) > int64(plaintextSize) {
|
||||||
|
size = plaintextSize - int(off)
|
||||||
|
}
|
||||||
|
if size <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p := make([]byte, size)
|
||||||
|
n, err := ra.ReadAt(p, off)
|
||||||
|
if n == size && err == io.EOF {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
errc <- fmt.Errorf("goroutine %d iter %d: off=%d size=%d: %v", id, i, off, size, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n != size {
|
||||||
|
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
|
||||||
|
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errc <- nil
|
||||||
|
}(g)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range goroutines {
|
||||||
|
if err := <-errc; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestDecryptReaderAtCorrupted(t *testing.T) {
|
func TestDecryptReaderAtCorrupted(t *testing.T) {
|
||||||
key := make([]byte, chacha20poly1305.KeySize)
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
if _, err := rand.Read(key); err != nil {
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue