From df3db2d9658304fb7dd71b709235c213af3d6d57 Mon Sep 17 00:00:00 2001 From: "Matt, Park" <45252226+mattverse@users.noreply.github.com> Date: Mon, 4 Sep 2023 18:20:46 +0900 Subject: [PATCH] chore: Backport IAVL Concurrency fix for v0.20 (#828) --- iterator_test.go | 21 +++++++---- mutable_tree.go | 78 ++++++++++++++++++++++++---------------- unsaved_fast_iterator.go | 74 +++++++++++++++++++++----------------- 3 files changed, 103 insertions(+), 70 deletions(-) diff --git a/iterator_test.go b/iterator_test.go index dff2a05a9..403e0b74c 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -3,10 +3,10 @@ package iavl import ( "math/rand" "sort" + "sync" "testing" dbm "github.com/cometbft/cometbft-db" - "github.com/cosmos/iavl/fastnode" "github.com/stretchr/testify/require" ) @@ -36,7 +36,7 @@ func TestIterator_NewIterator_NilTree_Failure(t *testing.T) { }) t.Run("Unsaved Fast Iterator", func(t *testing.T) { - itr := NewUnsavedFastIterator(start, end, ascending, nil, map[string]*fastnode.Node{}, map[string]interface{}{}) + itr := NewUnsavedFastIterator(start, end, ascending, nil, &sync.Map{}, &sync.Map{}) performTest(t, itr) require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error()) }) @@ -297,14 +297,14 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite require.NoError(t, err) // No unsaved additions or removals should be present after saving - require.Equal(t, 0, len(tree.unsavedFastNodeAdditions)) - require.Equal(t, 0, len(tree.unsavedFastNodeRemovals)) + require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeAdditions)) + require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals)) // Ensure that there are unsaved additions and removals present secondHalfMirror := setupMirrorForIterator(t, &secondHalfConfig, tree) - require.True(t, len(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror)) - require.Equal(t, 0, len(tree.unsavedFastNodeRemovals)) + require.True(t, syncMapCount(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror)) + require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals)) // Merge the two halves if config.ascending { @@ -331,6 +331,15 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite return itr, mirror } +func syncMapCount(m *sync.Map) int { + count := 0 + m.Range(func(_, _ interface{}) bool { + count++ + return true + }) + return count +} + func TestNodeIterator_WithEmptyRoot(t *testing.T) { itr, err := NewNodeIterator(nil, newNodeDB(dbm.NewMemDB(), 0, nil)) require.NoError(t, err) diff --git a/mutable_tree.go b/mutable_tree.go index 940f3dc89..4f953b9ea 100644 --- a/mutable_tree.go +++ b/mutable_tree.go @@ -30,13 +30,13 @@ var ErrVersionDoesNotExist = errors.New("version does not exist") // // The inner ImmutableTree should not be used directly by callers. type MutableTree struct { - *ImmutableTree // The current, working tree. - lastSaved *ImmutableTree // The most recently saved tree. - orphans map[string]int64 // Nodes removed by changes to working tree. - versions map[int64]bool // The previous, saved versions of the tree. - allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion) - unsavedFastNodeAdditions map[string]*fastnode.Node // FastNodes that have not yet been saved to disk - unsavedFastNodeRemovals map[string]interface{} // FastNodes that have not yet been removed from disk + *ImmutableTree // The current, working tree. + lastSaved *ImmutableTree // The most recently saved tree. + orphans map[string]int64 // Nodes removed by changes to working tree. + versions map[int64]bool // The previous, saved versions of the tree. + allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion) + unsavedFastNodeAdditions *sync.Map // map[string]*FastNode FastNodes that have not yet been saved to disk + unsavedFastNodeRemovals *sync.Map // map[string]interface{} FastNodes that have not yet been removed from disk ndb *nodeDB skipFastStorageUpgrade bool // If true, the tree will work like no fast storage and always not upgrade fast storage @@ -59,8 +59,8 @@ func NewMutableTreeWithOpts(db dbm.DB, cacheSize int, opts *Options, skipFastSto orphans: map[string]int64{}, versions: map[int64]bool{}, allRootLoaded: false, - unsavedFastNodeAdditions: make(map[string]*fastnode.Node), - unsavedFastNodeRemovals: make(map[string]interface{}), + unsavedFastNodeAdditions: &sync.Map{}, + unsavedFastNodeRemovals: &sync.Map{}, ndb: ndb, skipFastStorageUpgrade: skipFastStorageUpgrade, }, nil @@ -152,11 +152,11 @@ func (tree *MutableTree) Get(key []byte) ([]byte, error) { } if !tree.skipFastStorageUpgrade { - if fastNode, ok := tree.unsavedFastNodeAdditions[ibytes.UnsafeBytesToStr(key)]; ok { - return fastNode.GetValue(), nil + if fastNode, ok := tree.unsavedFastNodeAdditions.Load(ibytes.UnsafeBytesToStr(key)); ok { + return fastNode.(*fastnode.Node).GetValue(), nil } // check if node was deleted - if _, ok := tree.unsavedFastNodeRemovals[string(key)]; ok { + if _, ok := tree.unsavedFastNodeRemovals.Load(string(key)); ok { return nil, nil } } @@ -816,8 +816,8 @@ func (tree *MutableTree) Rollback() { } tree.orphans = map[string]int64{} if !tree.skipFastStorageUpgrade { - tree.unsavedFastNodeAdditions = map[string]*fastnode.Node{} - tree.unsavedFastNodeRemovals = map[string]interface{}{} + tree.unsavedFastNodeAdditions = &sync.Map{} + tree.unsavedFastNodeRemovals = &sync.Map{} } } @@ -936,8 +936,8 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) { tree.lastSaved = tree.ImmutableTree.clone() tree.orphans = map[string]int64{} if !tree.skipFastStorageUpgrade { - tree.unsavedFastNodeAdditions = make(map[string]*fastnode.Node) - tree.unsavedFastNodeRemovals = make(map[string]interface{}) + tree.unsavedFastNodeAdditions = &sync.Map{} + tree.unsavedFastNodeRemovals = &sync.Map{} } hash, err := tree.Hash() @@ -958,48 +958,64 @@ func (tree *MutableTree) saveFastNodeVersion() error { return tree.ndb.setFastStorageVersionToBatch() } +// nolint: unused func (tree *MutableTree) getUnsavedFastNodeAdditions() map[string]*fastnode.Node { - return tree.unsavedFastNodeAdditions + additions := make(map[string]*fastnode.Node) + tree.unsavedFastNodeAdditions.Range(func(key, value interface{}) bool { + additions[key.(string)] = value.(*fastnode.Node) + return true + }) + return additions } // getUnsavedFastNodeRemovals returns unsaved FastNodes to remove func (tree *MutableTree) getUnsavedFastNodeRemovals() map[string]interface{} { - return tree.unsavedFastNodeRemovals + removals := make(map[string]interface{}) + tree.unsavedFastNodeRemovals.Range(func(key, value interface{}) bool { + removals[key.(string)] = value + return true + }) + return removals } +// addUnsavedAddition stores an addition into the unsaved additions map func (tree *MutableTree) addUnsavedAddition(key []byte, node *fastnode.Node) { skey := ibytes.UnsafeBytesToStr(key) - delete(tree.unsavedFastNodeRemovals, skey) - tree.unsavedFastNodeAdditions[skey] = node + tree.unsavedFastNodeRemovals.Delete(skey) + tree.unsavedFastNodeAdditions.Store(skey, node) } func (tree *MutableTree) saveFastNodeAdditions() error { - keysToSort := make([]string, 0, len(tree.unsavedFastNodeAdditions)) - for key := range tree.unsavedFastNodeAdditions { - keysToSort = append(keysToSort, key) - } + keysToSort := make([]string, 0) + tree.unsavedFastNodeAdditions.Range(func(k, v interface{}) bool { + keysToSort = append(keysToSort, k.(string)) + return true + }) sort.Strings(keysToSort) for _, key := range keysToSort { - if err := tree.ndb.SaveFastNode(tree.unsavedFastNodeAdditions[key]); err != nil { + val, _ := tree.unsavedFastNodeAdditions.Load(key) + if err := tree.ndb.SaveFastNode(val.(*fastnode.Node)); err != nil { return err } } return nil } +// addUnsavedRemoval adds a removal to the unsaved removals map func (tree *MutableTree) addUnsavedRemoval(key []byte) { skey := ibytes.UnsafeBytesToStr(key) - delete(tree.unsavedFastNodeAdditions, skey) - tree.unsavedFastNodeRemovals[skey] = true + tree.unsavedFastNodeAdditions.Delete(skey) + tree.unsavedFastNodeRemovals.Store(skey, true) } func (tree *MutableTree) saveFastNodeRemovals() error { - keysToSort := make([]string, 0, len(tree.unsavedFastNodeRemovals)) - for key := range tree.unsavedFastNodeRemovals { - keysToSort = append(keysToSort, key) - } + keysToSort := make([]string, 0) + tree.unsavedFastNodeRemovals.Range(func(k, v interface{}) bool { + keysToSort = append(keysToSort, k.(string)) + return true + }) sort.Strings(keysToSort) for _, key := range keysToSort { diff --git a/unsaved_fast_iterator.go b/unsaved_fast_iterator.go index 9ed5881dd..47c62b2d4 100644 --- a/unsaved_fast_iterator.go +++ b/unsaved_fast_iterator.go @@ -4,10 +4,12 @@ import ( "bytes" "errors" "sort" + "sync" dbm "github.com/cometbft/cometbft-db" - "github.com/cosmos/iavl/fastnode" ibytes "github.com/cosmos/iavl/internal/bytes" + + "github.com/cosmos/iavl/fastnode" ) var ( @@ -30,14 +32,14 @@ type UnsavedFastIterator struct { fastIterator dbm.Iterator nextUnsavedNodeIdx int - unsavedFastNodeAdditions map[string]*fastnode.Node - unsavedFastNodeRemovals map[string]interface{} + unsavedFastNodeAdditions *sync.Map // map[string]*FastNode + unsavedFastNodeRemovals *sync.Map // map[string]interface{} unsavedFastNodesToSort []string } var _ dbm.Iterator = (*UnsavedFastIterator)(nil) -func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions map[string]*fastnode.Node, unsavedFastNodeRemovals map[string]interface{}) *UnsavedFastIterator { +func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions, unsavedFastNodeRemovals *sync.Map) *UnsavedFastIterator { iter := &UnsavedFastIterator{ start: start, end: end, @@ -51,28 +53,6 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa fastIterator: NewFastIterator(start, end, ascending, ndb), } - // We need to ensure that we iterate over saved and unsaved state in order. - // The strategy is to sort unsaved nodes, the fast node on disk are already sorted. - // Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently. - for _, fastNode := range unsavedFastNodeAdditions { - if start != nil && bytes.Compare(fastNode.GetKey(), start) < 0 { - continue - } - - if end != nil && bytes.Compare(fastNode.GetKey(), end) >= 0 { - continue - } - - iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, ibytes.UnsafeBytesToStr(fastNode.GetKey())) - } - - sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool { - if ascending { - return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j] - } - return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j] - }) - if iter.ndb == nil { iter.err = errFastIteratorNilNdbGiven iter.valid = false @@ -91,7 +71,33 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa return iter } - // Move to the first elemenet + // We need to ensure that we iterate over saved and unsaved state in order. + // The strategy is to sort unsaved nodes, the fast node on disk are already sorted. + // Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently. + unsavedFastNodeAdditions.Range(func(k, v interface{}) bool { + fastNode := v.(*fastnode.Node) + + if start != nil && bytes.Compare(fastNode.GetKey(), start) < 0 { + return true + } + + if end != nil && bytes.Compare(fastNode.GetKey(), end) >= 0 { + return true + } + + iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, k.(string)) + + return true + }) + + sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool { + if ascending { + return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j] + } + return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j] + }) + + // Move to the first element iter.Next() return iter @@ -136,8 +142,8 @@ func (iter *UnsavedFastIterator) Next() { diskKeyStr := ibytes.UnsafeBytesToStr(iter.fastIterator.Key()) if iter.fastIterator.Valid() && iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) { - - if iter.unsavedFastNodeRemovals[diskKeyStr] != nil { + value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr) + if ok && value != nil { // If next fast node from disk is to be removed, skip it. iter.fastIterator.Next() iter.Next() @@ -145,7 +151,8 @@ func (iter *UnsavedFastIterator) Next() { } nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx] - nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey] + nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey) + nextUnsavedNode := nextUnsavedNodeVal.(*fastnode.Node) var isUnsavedNext bool if iter.ascending { @@ -156,7 +163,6 @@ func (iter *UnsavedFastIterator) Next() { if isUnsavedNext { // Unsaved node is next - if diskKeyStr == nextUnsavedKey { // Unsaved update prevails over saved copy so we skip the copy from disk iter.fastIterator.Next() @@ -178,7 +184,8 @@ func (iter *UnsavedFastIterator) Next() { // if only nodes on disk are left, we return them if iter.fastIterator.Valid() { - if iter.unsavedFastNodeRemovals[diskKeyStr] != nil { + value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr) + if ok && value != nil { // If next fast node from disk is to be removed, skip it. iter.fastIterator.Next() iter.Next() @@ -195,7 +202,8 @@ func (iter *UnsavedFastIterator) Next() { // if only unsaved nodes are left, we can just iterate if iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) { nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx] - nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey] + nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey) + nextUnsavedNode := nextUnsavedNodeVal.(*fastnode.Node) iter.nextKey = nextUnsavedNode.GetKey() iter.nextVal = nextUnsavedNode.GetValue()