diff --git a/pkg/fileservice/bytes.go b/pkg/fileservice/bytes.go index 57a3bbcd957b9..871fb23550d54 100644 --- a/pkg/fileservice/bytes.go +++ b/pkg/fileservice/bytes.go @@ -25,9 +25,7 @@ import ( type Bytes struct { bytes []byte deallocator malloc.Deallocator - deallocated uint32 - _refs atomic.Int32 - refs *atomic.Int32 + refs atomic.Int32 } func (b *Bytes) Size() int64 { @@ -35,6 +33,9 @@ func (b *Bytes) Size() int64 { } func (b *Bytes) Bytes() []byte { + if b.refs.Load() <= 0 { + panic("Bytes.Bytes: memory was already deallocated.") + } return b.bytes } @@ -44,25 +45,29 @@ func (b *Bytes) Slice(length int) fscache.Data { } func (b *Bytes) Retain() { - if b.refs != nil { - b.refs.Add(1) - } + b.refs.Add(1) } func (b *Bytes) Release() { - if b.refs != nil { - if n := b.refs.Add(-1); n == 0 { - if b.deallocator != nil && - atomic.CompareAndSwapUint32(&b.deallocated, 0, 1) { - b.deallocator.Deallocate(malloc.NoHints) - } - } - } else { - if b.deallocator != nil && - atomic.CompareAndSwapUint32(&b.deallocated, 0, 1) { + n := b.refs.Add(-1) + if n == 0 { + // set bytes to nil + b.bytes = nil + if b.deallocator != nil { b.deallocator.Deallocate(malloc.NoHints) + b.deallocator = nil } + } else if n < 0 { + panic("Bytes.Release: double free") + } +} + +func NewBytes(data []byte) *Bytes { + bytes := &Bytes{ + bytes: data, } + bytes.refs.Store(1) + return bytes } type bytesAllocator struct { @@ -80,8 +85,7 @@ func (b *bytesAllocator) allocateCacheData(size int, hints malloc.Hints) fscache bytes: slice, deallocator: dec, } - bytes._refs.Store(1) - bytes.refs = &bytes._refs + bytes.refs.Store(1) return bytes } diff --git a/pkg/fileservice/bytes_test.go b/pkg/fileservice/bytes_test.go index 0859bdbb0445b..0040003636a61 100644 --- a/pkg/fileservice/bytes_test.go +++ b/pkg/fileservice/bytes_test.go @@ -15,7 +15,9 @@ package fileservice import ( + "sync" "testing" + "time" "github.com/matrixorigin/matrixone/pkg/common/malloc" "github.com/stretchr/testify/assert" @@ -29,6 +31,77 @@ func TestBytes(t *testing.T) { bytes: bytes, deallocator: deallocator, } + bs.refs.Store(1) bs.Release() }) } + +func TestBytesError(t *testing.T) { + t.Run("Bytes get invalid memory", func(t *testing.T) { + bytes, deallocator, err := ioAllocator().Allocate(42, malloc.NoHints) + assert.Nil(t, err) + bs := &Bytes{ + bytes: bytes, + deallocator: deallocator, + } + bs.refs.Store(1) + + // deallocate memory + bs.Release() + + // nil pointer + assert.Panics(t, func() { bs.Bytes() }, "get invalid memory") + }) + + t.Run("Bytes double free", func(t *testing.T) { + bytes, deallocator, err := ioAllocator().Allocate(42, malloc.NoHints) + assert.Nil(t, err) + bs := &Bytes{ + bytes: bytes, + deallocator: deallocator, + } + bs.refs.Store(1) + + // deallocate memory + bs.Release() + + // double free + assert.Panics(t, func() { bs.Release() }, "double free") + }) + + t.Run("Bytes nil deallocator", func(t *testing.T) { + data := []byte("123") + bs := NewBytes(data) + + // deallocate memory + bs.Release() + + assert.Panics(t, func() { bs.Release() }, "double free") + }) +} + +func TestBytesConcurrent(t *testing.T) { + data := []byte("123") + bs := NewBytes(data) + nthread := 5 + var wg sync.WaitGroup + for i := 0; i < nthread; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + bs.Retain() + + time.Sleep(1 * time.Millisecond) + + bs.Release() + }(i) + } + + wg.Wait() + + bs.Release() + + // double free + assert.Panics(t, func() { bs.Release() }, "double free") +} diff --git a/pkg/fileservice/remote_cache.go b/pkg/fileservice/remote_cache.go index 2d32f45baba79..5be6c45f79fec 100644 --- a/pkg/fileservice/remote_cache.go +++ b/pkg/fileservice/remote_cache.go @@ -130,7 +130,7 @@ func (r *RemoteCache) Read(ctx context.Context, vector *IOVector) error { idx := int(cacheData.Index) if cacheData.Hit { vector.Entries[idx].done = true - vector.Entries[idx].CachedData = &Bytes{bytes: cacheData.Data} + vector.Entries[idx].CachedData = NewBytes(cacheData.Data) vector.Entries[idx].fromCache = r numHit++ } diff --git a/pkg/fileservice/remote_cache_test.go b/pkg/fileservice/remote_cache_test.go index e9819c9a77ec3..e1850b30e90db 100644 --- a/pkg/fileservice/remote_cache_test.go +++ b/pkg/fileservice/remote_cache_test.go @@ -89,7 +89,7 @@ func TestRemoteCache(t *testing.T) { err = sf2.rc.Read(ctx, ioVec2) assert.NoError(t, err) assert.Equal(t, 1, len(ioVec2.Entries)) - assert.Equal(t, &Bytes{bytes: []byte{1, 2}}, ioVec2.Entries[0].CachedData) + assert.Equal(t, NewBytes([]byte{1, 2}), ioVec2.Entries[0].CachedData) assert.Equal(t, true, ioVec2.Entries[0].done) assert.NotNil(t, ioVec2.Entries[0].fromCache)