diff --git a/Makefile b/Makefile index 8ff9d81..3c704dd 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,11 @@ # Unit test test: - go test -v -race -timeout=60s ./... + go test -v -race -timeout=60s -count=1 ./... # Linter lint: - golangci-lint --config .golangci.yml run \ No newline at end of file + golangci-lint --config .golangci.yml run + +# Fix linter +fix-lint: + golangci-lint --config .golangci.yml run --fix diff --git a/db/memory/memory_test.go b/db/memory/memory_test.go index 794cefd..3fb0822 100644 --- a/db/memory/memory_test.go +++ b/db/memory/memory_test.go @@ -9,9 +9,7 @@ import ( ) func TestMemoryStorageInterface(t *testing.T) { - var db merkletree.Storage //nolint:gosimple - - db = NewMemoryStorage() + db := NewMemoryStorage() require.NotNil(t, db) } diff --git a/db/test/test.go b/db/test/test.go index 2d961f4..054b386 100644 --- a/db/test/test.go +++ b/db/test/test.go @@ -109,6 +109,30 @@ func TestAll(t *testing.T, sb StorageBuilder) { t.Run("TestTypesMarshalers", func(t *testing.T) { TestTypesMarshalers(t, sb.NewStorage(t)) }) + t.Run("TestRemoveLeafNearMiddleNodeRightFork", func(t *testing.T) { + TestRemoveLeafNearMiddleNodeRightFork(t, sb.NewStorage(t)) + }) + t.Run("TestRemoveLeafNearMiddleNodeRightForkDeep", func(t *testing.T) { + TestRemoveLeafNearMiddleNodeRightForkDeep(t, sb.NewStorage(t)) + }) + t.Run("TestRemoveLeafNearMiddleLeftFork", func(t *testing.T) { + TestRemoveLeafNearMiddleNodeLeftFork(t, sb.NewStorage(t)) + }) + t.Run("TestRemoveLeafNearMiddleLeftForkDeep", func(t *testing.T) { + TestRemoveLeafNearMiddleNodeLeftForkDeep(t, sb.NewStorage(t)) + }) + t.Run("TestUpToRootAfterDeleteRightFork", func(t *testing.T) { + TestUpToRootAfterDeleteRightFork(t, sb.NewStorage(t)) + }) + t.Run("TestUpToRootAfterDeleteLeftFork", func(t *testing.T) { + TestUpToRootAfterDeleteLeftFork(t, sb.NewStorage(t)) + }) + t.Run("TestCalculatingOfNewRootRightFork", func(t *testing.T) { + TestCalculatingOfNewRootRightFork(t, sb.NewStorage(t)) + }) + t.Run("TestCalculatingOfNewRootLeftFork", func(t *testing.T) { + TestCalculatingOfNewRootLeftFork(t, sb.NewStorage(t)) + }) } // TestReturnKnownErrIfNotExists checks that the implementation of the @@ -927,3 +951,232 @@ func TestTypesMarshalers(t *testing.T, sto merkletree.Storage) { assert.Nil(t, err) assert.Equal(t, cpp, cpp2) } + +func TestRemoveLeafNearMiddleNodeRightFork(t *testing.T, sto merkletree.Storage) { + ctx := context.Background() + mt, err := merkletree.NewMerkleTree(ctx, sto, 140) + require.Nil(t, err) + + values := []*big.Int{big.NewInt(7), big.NewInt(1), big.NewInt(5)} + + for _, v := range values { + err = mt.Add(ctx, v, v) + require.NoError(t, err) + } + + for _, v := range values { + err = mt.Delete(ctx, v) + require.NoError(t, err) + existProof, _, err := mt.GenerateProof(ctx, v, mt.Root()) + require.NoError(t, err) + require.False(t, existProof.Existence) + } +} + +func TestRemoveLeafNearMiddleNodeRightForkDeep(t *testing.T, sto merkletree.Storage) { + ctx := context.Background() + mt, err := merkletree.NewMerkleTree(ctx, sto, 140) + require.Nil(t, err) + + values := []*big.Int{big.NewInt(3), big.NewInt(7), big.NewInt(15)} + + for _, v := range values { + err = mt.Add(ctx, v, v) + require.NoError(t, err) + } + + for _, v := range values { + err = mt.Delete(ctx, v) + require.NoError(t, err) + existProof, _, err := mt.GenerateProof(ctx, v, mt.Root()) + require.NoError(t, err) + require.False(t, existProof.Existence) + } +} + +func TestRemoveLeafNearMiddleNodeLeftFork(t *testing.T, sto merkletree.Storage) { + ctx := context.Background() + mt, err := merkletree.NewMerkleTree(ctx, sto, 140) + require.NoError(t, err) + + // 110 / 6 + // 100 / 4 + // 010 / 2 + values := []*big.Int{big.NewInt(6), big.NewInt(4), big.NewInt(2)} + + for _, v := range values { + err = mt.Add(ctx, v, v) + require.NoError(t, err) + } + + for _, v := range values { + err = mt.Delete(ctx, v) + require.NoError(t, err) + existProof, _, err := mt.GenerateProof(ctx, v, mt.Root()) + require.NoError(t, err) + require.False(t, existProof.Existence) + } +} + +func TestRemoveLeafNearMiddleNodeLeftForkDeep(t *testing.T, sto merkletree.Storage) { + ctx := context.Background() + mt, err := merkletree.NewMerkleTree(ctx, sto, 140) + require.Nil(t, err) + + values := []*big.Int{big.NewInt(4), big.NewInt(8), big.NewInt(16)} + + for _, v := range values { + err = mt.Add(ctx, v, v) + require.NoError(t, err) + } + + for _, v := range values { + err = mt.Delete(ctx, v) + require.NoError(t, err) + existProof, _, err := mt.GenerateProof(ctx, v, mt.Root()) + require.NoError(t, err) + require.False(t, existProof.Existence) + } +} + +// Checking whether the last leaf will be moved to the root position +// +// root +// / \ +// 0 MiddleNode +// / \ +// 01 11 +// +// Up to: +// +// root(11) +func TestUpToRootAfterDeleteRightFork(t *testing.T, sto merkletree.Storage) { + ctx := context.Background() + mt, err := merkletree.NewMerkleTree(ctx, sto, 140) + require.NoError(t, err) + + err = mt.Add(ctx, big.NewInt(1), big.NewInt(1)) + require.NoError(t, err) + err = mt.Add(ctx, big.NewInt(3), big.NewInt(3)) + require.NoError(t, err) + + err = mt.Delete(ctx, big.NewInt(1)) + require.NoError(t, err) + + leaf, err := mt.GetNode(ctx, mt.Root()) + require.NoError(t, err) + require.Equal(t, merkletree.NodeTypeLeaf, leaf.Type) + + require.Equal(t, big.NewInt(3), leaf.Entry[0].BigInt()) +} + +// Checking whether the last leaf will be moved to the root position +// +// root +// / \ +// MiddleNode 0 +// / \ +// 100 010 +// +// Up to: +// +// root(100) +func TestUpToRootAfterDeleteLeftFork(t *testing.T, sto merkletree.Storage) { + ctx := context.Background() + mt, err := merkletree.NewMerkleTree(ctx, sto, 140) + require.NoError(t, err) + + err = mt.Add(ctx, big.NewInt(2), big.NewInt(2)) + require.NoError(t, err) + err = mt.Add(ctx, big.NewInt(4), big.NewInt(4)) + require.NoError(t, err) + + err = mt.Delete(ctx, big.NewInt(2)) + require.NoError(t, err) + + leaf, err := mt.GetNode(ctx, mt.Root()) + require.NoError(t, err) + require.Equal(t, merkletree.NodeTypeLeaf, leaf.Type) + + require.Equal(t, big.NewInt(4), leaf.Entry[0].BigInt()) +} + +// Checking whether the new root will be calculated from to leafs +// +// root +// / \ +// 10 MiddleNode +// / \ +// 01 11 +// +// Up to: +// +// root +// / \ +// 10 11 +func TestCalculatingOfNewRootRightFork(t *testing.T, sto merkletree.Storage) { + ctx := context.Background() + mt, err := merkletree.NewMerkleTree(ctx, sto, 140) + require.NoError(t, err) + + err = mt.Add(ctx, big.NewInt(1), big.NewInt(1)) + require.NoError(t, err) + err = mt.Add(ctx, big.NewInt(3), big.NewInt(3)) + require.NoError(t, err) + err = mt.Add(ctx, big.NewInt(2), big.NewInt(2)) + require.NoError(t, err) + + err = mt.Delete(ctx, big.NewInt(1)) + require.NoError(t, err) + + root, err := mt.GetNode(ctx, mt.Root()) + require.NoError(t, err) + + lLeaf, err := mt.GetNode(ctx, root.ChildL) + require.NoError(t, err) + rLeaf, err := mt.GetNode(ctx, root.ChildR) + require.NoError(t, err) + + require.Equal(t, big.NewInt(2), lLeaf.Entry[0].BigInt()) + require.Equal(t, big.NewInt(3), rLeaf.Entry[0].BigInt()) +} + +// Checking whether the new root will be calculated from to leafs +// +// root +// / \ +// MiddleNode 01 +// / \ +// 100 010 +// +// Up to: +// +// root +// / \ +// 100 001 +func TestCalculatingOfNewRootLeftFork(t *testing.T, sto merkletree.Storage) { + ctx := context.Background() + mt, err := merkletree.NewMerkleTree(ctx, sto, 140) + require.NoError(t, err) + + err = mt.Add(ctx, big.NewInt(1), big.NewInt(1)) + require.NoError(t, err) + err = mt.Add(ctx, big.NewInt(2), big.NewInt(2)) + require.NoError(t, err) + err = mt.Add(ctx, big.NewInt(4), big.NewInt(4)) + require.NoError(t, err) + + err = mt.Delete(ctx, big.NewInt(2)) + require.NoError(t, err) + + root, err := mt.GetNode(ctx, mt.Root()) + require.NoError(t, err) + + lLeaf, err := mt.GetNode(ctx, root.ChildL) + require.NoError(t, err) + rLeaf, err := mt.GetNode(ctx, root.ChildR) + require.NoError(t, err) + + require.Equal(t, big.NewInt(4), lLeaf.Entry[0].BigInt()) + require.Equal(t, big.NewInt(1), rLeaf.Entry[0].BigInt()) +} diff --git a/merkletree.go b/merkletree.go index 1478215..e13557e 100644 --- a/merkletree.go +++ b/merkletree.go @@ -558,6 +558,37 @@ func (mt *MerkleTree) rmAndUpload(ctx context.Context, path []bool, kHash *Hash, return err } } + + //When deleting a leaf node that is on the same level as middleNode, + //need to nullify the leaf node instead of removing it from the tree. + nearestSibling, err := mt.db.Get(ctx, toUpload[:]) + if err != nil { + return err + } + if nearestSibling.Type == NodeTypeMiddle { + var newNode *Node + if path[len(siblings)-1] { + newNode = NewNodeMiddle(toUpload, &HashZero) + } else { + newNode = NewNodeMiddle(&HashZero, toUpload) + } + _, err = mt.addNode(ctx, newNode) + if err != nil { + return err + } + newRootKey, err := mt.recalculatePathUntilRoot(path, newNode, + siblings[:len(siblings)-1]) + if err != nil { + return err + } + mt.rootKey = newRootKey + err = mt.db.SetRoot(ctx, mt.rootKey) + if err != nil { + return err + } + return nil + } + for i := len(siblings) - 2; i >= 0; i-- { if !bytes.Equal(siblings[i][:], HashZero[:]) { var newNode *Node