diff --git a/trie/ctrie/ctrie.go b/trie/ctrie/ctrie.go index 0f21ce2..f146b36 100644 --- a/trie/ctrie/ctrie.go +++ b/trie/ctrie/ctrie.go @@ -306,28 +306,35 @@ func (c *Ctrie) Remove(key []byte) (interface{}, bool) { return c.remove(&Entry{Key: key, hash: c.hash(key)}) } -// Snapshot returns a stable, point-in-time snapshot of the Ctrie. +// Snapshot returns a stable, point-in-time snapshot of the Ctrie. If the Ctrie +// is read-only, the returned Ctrie will also be read-only. func (c *Ctrie) Snapshot() *Ctrie { - for { - root := c.readRoot() - main := gcasRead(root, c) - if c.rdcssRoot(root, main, root.copyToGen(&generation{}, c)) { - return newCtrie(c.readRoot().copyToGen(&generation{}, c), c.hashFactory, c.readOnly) - } - } + return c.snapshot(c.readOnly) } // ReadOnlySnapshot returns a stable, point-in-time snapshot of the Ctrie which // is read-only. Write operations on a read-only snapshot will panic. func (c *Ctrie) ReadOnlySnapshot() *Ctrie { - if c.readOnly { + return c.snapshot(true) +} + +// snapshot wraps up the CAS logic to make a snapshot or a read-only snapshot. +func (c *Ctrie) snapshot(readOnly bool) *Ctrie { + if readOnly && c.readOnly { return c } for { root := c.readRoot() main := gcasRead(root, c) if c.rdcssRoot(root, main, root.copyToGen(&generation{}, c)) { - return newCtrie(c.readRoot(), c.hashFactory, true) + if readOnly { + // For a read-only snapshot, we can share the old generation + // root. + return newCtrie(root, c.hashFactory, readOnly) + } + // For a read-write snapshot, we need to take a copy of the root + // in the new generation. + return newCtrie(c.readRoot().copyToGen(&generation{}, c), c.hashFactory, readOnly) } } } diff --git a/trie/ctrie/ctrie_test.go b/trie/ctrie/ctrie_test.go index 0163984..e62701e 100644 --- a/trie/ctrie/ctrie_test.go +++ b/trie/ctrie/ctrie_test.go @@ -226,6 +226,7 @@ func TestSnapshot(t *testing.T) { assert.Equal(i, val) } + // Now remove the values from the original. for i := 0; i < 100; i++ { ctrie.Remove([]byte(strconv.Itoa(i))) } @@ -237,6 +238,7 @@ func TestSnapshot(t *testing.T) { assert.Equal(i, val) } + // New Ctrie and snapshot. ctrie = New(nil) for i := 0; i < 100; i++ { ctrie.Insert([]byte(strconv.Itoa(i)), i) @@ -266,7 +268,45 @@ func TestSnapshot(t *testing.T) { _, ok = ctrie.Lookup([]byte("bat")) assert.False(ok) - snapshot = ctrie.ReadOnlySnapshot() + // Ensure snapshots-of-snapshots work as expected. + snapshot2 := snapshot.Snapshot() + for i := 0; i < 100; i++ { + _, ok := snapshot2.Lookup([]byte(strconv.Itoa(i))) + assert.False(ok) + } + val, ok = snapshot2.Lookup([]byte("bat")) + assert.True(ok) + assert.Equal("man", val) + + snapshot2.Remove([]byte("bat")) + _, ok = snapshot2.Lookup([]byte("bat")) + assert.False(ok) + val, ok = snapshot.Lookup([]byte("bat")) + assert.True(ok) + assert.Equal("man", val) +} + +func TestReadOnlySnapshot(t *testing.T) { + assert := assert.New(t) + ctrie := New(nil) + for i := 0; i < 100; i++ { + ctrie.Insert([]byte(strconv.Itoa(i)), i) + } + + snapshot := ctrie.ReadOnlySnapshot() + + // Ensure snapshot contains expected keys. + for i := 0; i < 100; i++ { + val, ok := snapshot.Lookup([]byte(strconv.Itoa(i))) + assert.True(ok) + assert.Equal(i, val) + } + + for i := 0; i < 50; i++ { + ctrie.Remove([]byte(strconv.Itoa(i))) + } + + // Ensure snapshot was unaffected by removals. for i := 0; i < 100; i++ { val, ok := snapshot.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) @@ -274,24 +314,31 @@ func TestSnapshot(t *testing.T) { } // Ensure read-only snapshots panic on writes. - defer func() { - assert.NotNil(recover()) + func() { + defer func() { + assert.NotNil(recover()) + }() + snapshot.Remove([]byte("blah")) }() - snapshot.Remove([]byte("blah")) // Ensure snapshots-of-snapshots work as expected. snapshot2 := snapshot.Snapshot() + for i := 50; i < 100; i++ { + ctrie.Remove([]byte(strconv.Itoa(i))) + } for i := 0; i < 100; i++ { val, ok := snapshot2.Lookup([]byte(strconv.Itoa(i))) assert.True(ok) assert.Equal(i, val) } - snapshot2.Remove([]byte("0")) - _, ok = snapshot2.Lookup([]byte("0")) - assert.False(ok) - val, ok = snapshot.Lookup([]byte("0")) - assert.True(ok) - assert.Equal(0, val) + + // Ensure snapshots of read-only snapshots panic on writes. + func() { + defer func() { + assert.NotNil(recover()) + }() + snapshot2.Remove([]byte("blah")) + }() } func TestIterator(t *testing.T) {