diff --git a/core/state/statedb.go b/core/state/statedb.go index 4bedebe8b3c3..7348fbc79e43 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -649,6 +649,10 @@ func (s *StateDB) clearJournalAndRefund() { // Commit writes the state to the underlying in-memory trie database. func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) { + if s.dbErr != nil { + return common.Hash{}, fmt.Errorf("commit aborted due to earlier error: %v", s.dbErr) + } + defer s.clearJournalAndRefund() // Commit objects to the trie. diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 7e7f5343c34a..ed364042eed5 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -20,7 +20,6 @@ import ( "bytes" "encoding/binary" "fmt" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math" "math/big" "math/rand" @@ -29,6 +28,8 @@ import ( "testing" "testing/quick" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" + check "gopkg.in/check.v1" "github.com/XinFinOrg/XDPoSChain/common" @@ -427,3 +428,51 @@ func (s *StateSuite) TestTouchDelete(c *check.C) { c.Fatal("expected no dirty state object") } } + +// TestMissingTrieNodes tests that if the statedb fails to load parts of the trie, +// the Commit operation fails with an error +// If we are missing trie nodes, we should not continue writing to the trie +func TestMissingTrieNodes(t *testing.T) { + + // Create an initial state with a few accounts + memDb := rawdb.NewMemoryDatabase() + db := NewDatabase(memDb) + var root common.Hash + state, _ := New(common.Hash{}, db) + addr := toAddr([]byte("so")) + { + state.SetBalance(addr, big.NewInt(1)) + state.SetCode(addr, []byte{1, 2, 3}) + a2 := toAddr([]byte("another")) + state.SetBalance(a2, big.NewInt(100)) + state.SetCode(a2, []byte{1, 2, 4}) + root, _ = state.Commit(false) + t.Logf("root: %x", root) + // force-flush + state.Database().TrieDB().Cap(0) + } + // Create a new state on the old root + state, _ = New(root, db) + //state, _ = New(root, db, nil) + // Now we clear out the memdb + it := memDb.NewIterator(nil, nil) + for it.Next() { + k := it.Key() + // Leave the root intact + if !bytes.Equal(k, root[:]) { + t.Logf("key: %x", k) + memDb.Delete(k) + } + } + balance := state.GetBalance(addr) + // The removed elem should lead to it returning zero balance + if exp, got := uint64(0), balance.Uint64(); got != exp { + t.Errorf("expected %d, got %d", exp, got) + } + // Modify the state + state.SetBalance(addr, big.NewInt(2)) + root, err := state.Commit(false) + if err == nil { + t.Fatalf("expected error, got root :%x", root) + } +}