diff --git a/cmd/geth/main.go b/cmd/geth/main.go
index 00d22179665d..a702ce44620a 100644
--- a/cmd/geth/main.go
+++ b/cmd/geth/main.go
@@ -150,6 +150,7 @@ var (
utils.ScrollAlphaFlag,
utils.ScrollSepoliaFlag,
utils.ScrollFlag,
+ utils.ScrollMPTFlag,
utils.VMEnableDebugFlag,
utils.NetworkIdFlag,
utils.EthStatsURLFlag,
diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go
index 26818ddfd14d..00111c9956f8 100644
--- a/cmd/geth/usage.go
+++ b/cmd/geth/usage.go
@@ -50,6 +50,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{
utils.ScrollAlphaFlag,
utils.ScrollSepoliaFlag,
utils.ScrollFlag,
+ utils.ScrollMPTFlag,
utils.SyncModeFlag,
utils.ExitWhenSyncedFlag,
utils.GCModeFlag,
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index bccd6017b36e..0967f1000b6a 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -183,6 +183,10 @@ var (
Name: "scroll",
Usage: "Scroll mainnet",
}
+ ScrollMPTFlag = cli.BoolFlag{
+ Name: "scroll-mpt",
+ Usage: "Use MPT trie for state storage",
+ }
DeveloperFlag = cli.BoolFlag{
Name: "dev",
Usage: "Ephemeral proof-of-authority network with a pre-funded developer account, mining enabled",
@@ -1879,12 +1883,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) {
stack.Config().L1Confirmations = rpc.FinalizedBlockNumber
log.Info("Setting flag", "--l1.sync.startblock", "4038000")
stack.Config().L1DeploymentBlock = 4038000
- // disable pruning
- if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
- log.Crit("Must use --gcmode=archive")
+ cfg.Genesis.Config.Scroll.UseZktrie = !ctx.GlobalBool(ScrollMPTFlag.Name)
+ if cfg.Genesis.Config.Scroll.UseZktrie {
+ // disable pruning
+ if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
+ log.Crit("Must use --gcmode=archive")
+ }
+ log.Info("Pruning disabled")
+ cfg.NoPruning = true
}
- log.Info("Pruning disabled")
- cfg.NoPruning = true
case ctx.GlobalBool(ScrollFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 534352
@@ -1895,12 +1902,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) {
stack.Config().L1Confirmations = rpc.FinalizedBlockNumber
log.Info("Setting flag", "--l1.sync.startblock", "18306000")
stack.Config().L1DeploymentBlock = 18306000
- // disable pruning
- if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
- log.Crit("Must use --gcmode=archive")
+ cfg.Genesis.Config.Scroll.UseZktrie = !ctx.GlobalBool(ScrollMPTFlag.Name)
+ if cfg.Genesis.Config.Scroll.UseZktrie {
+ // disable pruning
+ if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
+ log.Crit("Must use --gcmode=archive")
+ }
+ log.Info("Pruning disabled")
+ cfg.NoPruning = true
}
- log.Info("Pruning disabled")
- cfg.NoPruning = true
case ctx.GlobalBool(DeveloperFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 1337
diff --git a/core/block_validator.go b/core/block_validator.go
index fdc845682d96..9eb3accacdb4 100644
--- a/core/block_validator.go
+++ b/core/block_validator.go
@@ -226,7 +226,8 @@ func (v *BlockValidator) ValidateState(block *types.Block, statedb *state.StateD
}
// Validate the state root against the received state root and throw
// an error if they don't match.
- if root := statedb.IntermediateRoot(v.config.IsEIP158(header.Number)); header.Root != root {
+ shouldValidateStateRoot := v.config.Scroll.UseZktrie != v.config.IsEuclid(header.Time)
+ if root := statedb.IntermediateRoot(v.config.IsEIP158(header.Number)); shouldValidateStateRoot && header.Root != root {
return fmt.Errorf("invalid merkle root (remote: %x local: %x)", header.Root, root)
}
return nil
diff --git a/core/blockchain.go b/core/blockchain.go
index a0bc05924531..57d82bc118a4 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -1318,6 +1318,9 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
return NonStatTy, err
}
triedb := bc.stateCache.TrieDB()
+ if block.Root() != root {
+ rawdb.WriteDiskStateRoot(bc.db, block.Root(), root)
+ }
// If we're running an archive node, always flush
if bc.cacheConfig.TrieDirtyDisabled {
@@ -1677,7 +1680,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
}
// Enable prefetching to pull in trie node paths while processing transactions
- statedb.StartPrefetcher("chain")
+ statedb.StartPrefetcher("chain", nil)
activeState = statedb
// If we have a followup block, run that against the current state to pre-cache
@@ -1814,7 +1817,7 @@ func (bc *BlockChain) BuildAndWriteBlock(parentBlock *types.Block, header *types
return NonStatTy, err
}
- statedb.StartPrefetcher("l1sync")
+ statedb.StartPrefetcher("l1sync", nil)
defer statedb.StopPrefetcher()
header.ParentHash = parentBlock.Hash()
diff --git a/core/blockchain_test.go b/core/blockchain_test.go
index 13b75622b169..43169492333c 100644
--- a/core/blockchain_test.go
+++ b/core/blockchain_test.go
@@ -3032,15 +3032,16 @@ func TestPoseidonCodeHash(t *testing.T) {
var callCreate2Code = common.Hex2Bytes("f4754f660000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000005c6080604052348015600f57600080fd5b50603f80601d6000396000f3fe6080604052600080fdfea2646970667358221220707985753fcb6578098bb16f3709cf6d012993cba6dd3712661cf8f57bbc0d4d64736f6c6343000807003300000000")
var (
- key1, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
- addr1 = crypto.PubkeyToAddress(key1.PublicKey)
- db = rawdb.NewMemoryDatabase()
- gspec = &Genesis{Config: params.TestChainConfig, Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000000)}}}
- genesis = gspec.MustCommit(db)
- signer = types.LatestSigner(gspec.Config)
- engine = ethash.NewFaker()
- blockchain, _ = NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil, nil)
+ key1, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+ addr1 = crypto.PubkeyToAddress(key1.PublicKey)
+ db = rawdb.NewMemoryDatabase()
+ gspec = &Genesis{Config: params.TestChainConfig, Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000000)}}}
+ signer = types.LatestSigner(gspec.Config)
+ engine = ethash.NewFaker()
)
+ gspec.Config.Scroll.UseZktrie = true
+ genesis := gspec.MustCommit(db)
+ blockchain, _ := NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil, nil)
defer blockchain.Stop()
@@ -3724,6 +3725,7 @@ func TestCurieTransition(t *testing.T) {
config.CurieBlock = big.NewInt(2)
config.DarwinTime = nil
config.DarwinV2Time = nil
+ config.Scroll.UseZktrie = true
var (
db = rawdb.NewMemoryDatabase()
@@ -3748,7 +3750,7 @@ func TestCurieTransition(t *testing.T) {
number := block.Number().Uint64()
baseFee := block.BaseFee()
- statedb, _ := state.New(block.Root(), state.NewDatabase(db), nil)
+ statedb, _ := state.New(block.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Zktrie: gspec.Config.Scroll.UseZktrie}), nil)
code := statedb.GetCode(rcfg.L1GasPriceOracleAddress)
codeSize := statedb.GetCodeSize(rcfg.L1GasPriceOracleAddress)
diff --git a/core/chain_makers.go b/core/chain_makers.go
index b79d92bf71e4..2b0a5856061b 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -29,6 +29,7 @@ import (
"github.com/scroll-tech/go-ethereum/ethdb"
"github.com/scroll-tech/go-ethereum/params"
"github.com/scroll-tech/go-ethereum/rollup/fees"
+ "github.com/scroll-tech/go-ethereum/trie"
)
// BlockGen creates blocks for testing.
@@ -264,7 +265,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
return nil, nil
}
for i := 0; i < n; i++ {
- statedb, err := state.New(parent.Root(), state.NewDatabase(db), nil)
+ statedb, err := state.New(parent.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Zktrie: config.Scroll.ZktrieEnabled()}), nil)
if err != nil {
panic(err)
}
diff --git a/core/genesis.go b/core/genesis.go
index 50a8e8843a4d..9554be0ac304 100644
--- a/core/genesis.go
+++ b/core/genesis.go
@@ -322,7 +322,10 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block {
}
statedb.Commit(false)
statedb.Database().TrieDB().Commit(root, true, nil)
-
+ if g.Config != nil && g.Config.Scroll.GenesisStateRoot != nil {
+ head.Root = *g.Config.Scroll.GenesisStateRoot
+ rawdb.WriteDiskStateRoot(db, head.Root, root)
+ }
return types.NewBlock(head, nil, nil, nil, trie.NewStackTrie(nil))
}
diff --git a/core/rawdb/accessors_state.go b/core/rawdb/accessors_state.go
index f153af69f942..2738d9424668 100644
--- a/core/rawdb/accessors_state.go
+++ b/core/rawdb/accessors_state.go
@@ -94,3 +94,17 @@ func DeleteTrieNode(db ethdb.KeyValueWriter, hash common.Hash) {
log.Crit("Failed to delete trie node", "err", err)
}
}
+
+func WriteDiskStateRoot(db ethdb.KeyValueWriter, headerRoot, diskRoot common.Hash) {
+ if err := db.Put(diskStateRootKey(headerRoot), diskRoot.Bytes()); err != nil {
+ log.Crit("Failed to store disk state root", "err", err)
+ }
+}
+
+func ReadDiskStateRoot(db ethdb.KeyValueReader, headerRoot common.Hash) (common.Hash, error) {
+ data, err := db.Get(diskStateRootKey(headerRoot))
+ if err != nil {
+ return common.Hash{}, err
+ }
+ return common.BytesToHash(data), nil
+}
diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go
index b4a51935b4ff..47b29c77d840 100644
--- a/core/rawdb/schema.go
+++ b/core/rawdb/schema.go
@@ -127,6 +127,8 @@ var (
// Scroll da syncer store
daSyncedL1BlockNumberKey = []byte("LastDASyncedL1BlockNumber")
+
+ diskStateRootPrefix = []byte("disk-state-root")
)
// Use the updated "L1" prefix on all new networks
@@ -312,3 +314,7 @@ func batchMetaKey(batchIndex uint64) []byte {
func committedBatchMetaKey(batchIndex uint64) []byte {
return append(committedBatchMetaPrefix, encodeBigEndian(batchIndex)...)
}
+
+func diskStateRootKey(headerRoot common.Hash) []byte {
+ return append(diskStateRootPrefix, headerRoot.Bytes()...)
+}
diff --git a/core/state/database.go b/core/state/database.go
index bb73fcecd216..9a58bc72246f 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -105,6 +105,9 @@ type Trie interface {
// nodes of the longest existing prefix of the key (at least the root), ending
// with the node that proves the absence of the key.
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error
+
+ // Witness returns a set containing all trie nodes that have been accessed.
+ Witness() map[string]struct{}
}
// NewDatabase creates a backing store for state. The returned database is safe for
@@ -136,6 +139,9 @@ type cachingDB struct {
// OpenTrie opens the main account trie at a specific root hash.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
+ if diskRoot, err := rawdb.ReadDiskStateRoot(db.db.DiskDB(), root); err == nil {
+ root = diskRoot
+ }
if db.zktrie {
tr, err := trie.NewZkTrie(root, trie.NewZktrieDatabaseFromTriedb(db.db))
if err != nil {
diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go
index e5e2b420018a..72e6d134d59b 100644
--- a/core/state/snapshot/generate.go
+++ b/core/state/snapshot/generate.go
@@ -618,8 +618,8 @@ func (dl *diskLayer) generate(stats *generatorStats) {
Balance *big.Int
Root common.Hash
KeccakCodeHash []byte
- PoseidonCodeHash []byte
- CodeSize uint64
+ PoseidonCodeHash []byte `rlp:"-"`
+ CodeSize uint64 `rlp:"-"`
}
if err := rlp.DecodeBytes(val, &acc); err != nil {
log.Crit("Invalid account encountered during snapshot creation", "err", err)
diff --git a/core/state/state_object.go b/core/state/state_object.go
index f9213a0a31d7..4fb9e82c2ed7 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -500,8 +500,18 @@ func (s *stateObject) Code(db Database) []byte {
// CodeSize returns the size of the contract code associated with this object,
// or zero if none. This method is an almost mirror of Code, but uses a cache
// inside the database to avoid loading codes seen recently.
-func (s *stateObject) CodeSize() uint64 {
- return s.data.CodeSize
+func (s *stateObject) CodeSize(db Database) uint64 {
+ if s.code != nil {
+ return uint64(len(s.code))
+ }
+ if bytes.Equal(s.KeccakCodeHash(), emptyKeccakCodeHash) {
+ return 0
+ }
+ size, err := db.ContractCodeSize(s.addrHash, common.BytesToHash(s.KeccakCodeHash()))
+ if err != nil {
+ s.setError(fmt.Errorf("can't load code size %x: %v", s.KeccakCodeHash(), err))
+ }
+ return uint64(size)
}
func (s *stateObject) SetCode(code []byte) {
@@ -534,6 +544,9 @@ func (s *stateObject) setNonce(nonce uint64) {
}
func (s *stateObject) PoseidonCodeHash() []byte {
+ if !s.db.IsZktrie() {
+ panic("PoseidonCodeHash is only available in zktrie mode")
+ }
return s.data.PoseidonCodeHash
}
diff --git a/core/state/state_test.go b/core/state/state_test.go
index ea98b2dab833..8b53cb2eca2a 100644
--- a/core/state/state_test.go
+++ b/core/state/state_test.go
@@ -155,7 +155,8 @@ func TestSnapshotEmpty(t *testing.T) {
}
func TestSnapshot2(t *testing.T) {
- state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+ stateDb := NewDatabase(rawdb.NewMemoryDatabase())
+ state, _ := New(common.Hash{}, stateDb, nil)
stateobjaddr0 := common.BytesToAddress([]byte("so0"))
stateobjaddr1 := common.BytesToAddress([]byte("so1"))
@@ -201,7 +202,7 @@ func TestSnapshot2(t *testing.T) {
so0Restored.GetState(state.db, storageaddr)
so0Restored.Code(state.db)
// non-deleted is equal (restored)
- compareStateObjects(so0Restored, so0, t)
+ compareStateObjects(so0Restored, so0, stateDb, t)
// deleted should be nil, both before and after restore of state copy
so1Restored := state.getStateObject(stateobjaddr1)
@@ -210,7 +211,7 @@ func TestSnapshot2(t *testing.T) {
}
}
-func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
+func compareStateObjects(so0, so1 *stateObject, db Database, t *testing.T) {
if so0.Address() != so1.Address() {
t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address)
}
@@ -229,8 +230,8 @@ func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
if !bytes.Equal(so0.PoseidonCodeHash(), so1.PoseidonCodeHash()) {
t.Fatalf("PoseidonCodeHash mismatch: have %v, want %v", so0.PoseidonCodeHash(), so1.PoseidonCodeHash())
}
- if so0.CodeSize() != so1.CodeSize() {
- t.Fatalf("CodeSize mismatch: have %v, want %v", so0.CodeSize(), so1.CodeSize())
+ if so0.CodeSize(db) != so1.CodeSize(db) {
+ t.Fatalf("CodeSize mismatch: have %v, want %v", so0.CodeSize(db), so1.CodeSize(db))
}
if !bytes.Equal(so0.code, so1.code) {
t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code)
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 6629a50eae57..7affd81fc409 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -29,6 +29,7 @@ import (
"github.com/scroll-tech/go-ethereum/common"
"github.com/scroll-tech/go-ethereum/core/rawdb"
"github.com/scroll-tech/go-ethereum/core/state/snapshot"
+ "github.com/scroll-tech/go-ethereum/core/stateless"
"github.com/scroll-tech/go-ethereum/core/types"
"github.com/scroll-tech/go-ethereum/crypto"
"github.com/scroll-tech/go-ethereum/log"
@@ -106,6 +107,9 @@ type StateDB struct {
validRevisions []revision
nextRevisionId int
+ // State witness if cross validation is needed
+ witness *stateless.Witness
+
// Measurements gathered during execution for debugging purposes
AccountReads time.Duration
AccountHashes time.Duration
@@ -159,11 +163,15 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error)
// StartPrefetcher initializes a new trie prefetcher to pull in nodes from the
// state trie concurrently while the state is mutated so that when we reach the
// commit phase, most of the needed data is already hot.
-func (s *StateDB) StartPrefetcher(namespace string) {
+func (s *StateDB) StartPrefetcher(namespace string, witness *stateless.Witness) {
if s.prefetcher != nil {
s.prefetcher.close()
s.prefetcher = nil
}
+
+ // Enable witness collection if requested
+ s.witness = witness
+
if s.snap != nil {
s.prefetcher = newTriePrefetcher(s.db, s.originalRoot, namespace)
}
@@ -289,6 +297,9 @@ func (s *StateDB) TxIndex() int {
func (s *StateDB) GetCode(addr common.Address) []byte {
stateObject := s.getStateObject(addr)
if stateObject != nil {
+ if s.witness != nil {
+ s.witness.AddCode(stateObject.Code(s.db))
+ }
return stateObject.Code(s.db)
}
return nil
@@ -297,7 +308,10 @@ func (s *StateDB) GetCode(addr common.Address) []byte {
func (s *StateDB) GetCodeSize(addr common.Address) uint64 {
stateObject := s.getStateObject(addr)
if stateObject != nil {
- return stateObject.CodeSize()
+ if s.witness != nil {
+ s.witness.AddCode(stateObject.Code(s.db))
+ }
+ return stateObject.CodeSize(s.db)
}
return 0
}
@@ -725,6 +739,9 @@ func (s *StateDB) Copy() *StateDB {
journal: newJournal(),
hasher: crypto.NewKeccakState(),
}
+ if s.witness != nil {
+ state.witness = s.witness.Copy()
+ }
// Copy the dirty states, logs, and preimages
for addr := range s.journal.dirties {
// As documented [here](https://github.com/scroll-tech/go-ethereum/pull/16485#issuecomment-380438527),
@@ -913,7 +930,33 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
// to pull useful data from disk.
for addr := range s.stateObjectsPending {
if obj := s.stateObjects[addr]; !obj.deleted {
+
+ // If witness building is enabled and the state object has a trie,
+ // gather the witnesses for its specific storage trie
+ if s.witness != nil && obj.trie != nil {
+ s.witness.AddState(obj.trie.Witness())
+ }
+
obj.updateRoot(s.db)
+
+ // If witness building is enabled and the state object has a trie,
+ // gather the witnesses for its specific storage trie
+ if s.witness != nil && obj.trie != nil {
+ s.witness.AddState(obj.trie.Witness())
+ }
+ }
+ }
+
+ if s.witness != nil {
+ // If witness building is enabled, gather the account trie witness for read-only operations
+ for _, obj := range s.stateObjects {
+ if len(obj.originStorage) == 0 {
+ continue
+ }
+
+ if trie := obj.getTrie(s.db); trie != nil {
+ s.witness.AddState(trie.Witness())
+ }
}
}
// Now we're about to start to write changes to the trie. The trie is so far
@@ -945,7 +988,13 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
if metrics.EnabledExpensive {
defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now())
}
- return s.trie.Hash()
+
+ hash := s.trie.Hash()
+ // If witness building is enabled, gather the account trie witness
+ if s.witness != nil {
+ s.witness.AddState(s.trie.Witness())
+ }
+ return hash
}
// SetTxContext sets the current transaction hash and index which are
diff --git a/core/stateless/database.go b/core/stateless/database.go
new file mode 100644
index 000000000000..e6278a98f872
--- /dev/null
+++ b/core/stateless/database.go
@@ -0,0 +1,67 @@
+// Copyright 2024 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package stateless
+
+import (
+ "github.com/scroll-tech/go-ethereum/common"
+ "github.com/scroll-tech/go-ethereum/core/rawdb"
+ "github.com/scroll-tech/go-ethereum/crypto"
+ "github.com/scroll-tech/go-ethereum/ethdb"
+)
+
+// MakeHashDB imports tries, codes and block hashes from a witness into a new
+// hash-based memory db. We could eventually rewrite this into a pathdb, but
+// simple is better for now.
+//
+// Note, this hashdb approach is quite strictly self-validating:
+// - Headers are persisted keyed by hash, so blockhash will error on junk
+// - Codes are persisted keyed by hash, so bytecode lookup will error on junk
+// - Trie nodes are persisted keyed by hash, so trie expansion will error on junk
+//
+// Acceleration structures built would need to explicitly validate the witness.
+func (w *Witness) MakeHashDB() ethdb.Database {
+ var (
+ memdb = rawdb.NewMemoryDatabase()
+ hasher = crypto.NewKeccakState()
+ hash = make([]byte, 32)
+ )
+ // Inject all the "block hashes" (i.e. headers) into the ephemeral database
+ for _, header := range w.Headers {
+ rawdb.WriteHeader(memdb, header)
+ }
+ // Inject all the bytecodes into the ephemeral database
+ for code := range w.Codes {
+ blob := []byte(code)
+
+ hasher.Reset()
+ hasher.Write(blob)
+ hasher.Read(hash)
+
+ rawdb.WriteCode(memdb, common.BytesToHash(hash), blob)
+ }
+ // Inject all the MPT trie nodes into the ephemeral database
+ for node := range w.State {
+ blob := []byte(node)
+
+ hasher.Reset()
+ hasher.Write(blob)
+ hasher.Read(hash)
+
+ rawdb.WriteTrieNode(memdb, common.BytesToHash(hash), blob)
+ }
+ return memdb
+}
diff --git a/core/stateless/encoding.go b/core/stateless/encoding.go
new file mode 100644
index 000000000000..b67b7460924a
--- /dev/null
+++ b/core/stateless/encoding.go
@@ -0,0 +1,76 @@
+// Copyright 2024 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package stateless
+
+import (
+ "io"
+
+ "github.com/scroll-tech/go-ethereum/core/types"
+ "github.com/scroll-tech/go-ethereum/rlp"
+)
+
+// toExtWitness converts our internal witness representation to the consensus one.
+func (w *Witness) toExtWitness() *extWitness {
+ ext := &extWitness{
+ Headers: w.Headers,
+ }
+ ext.Codes = make([][]byte, 0, len(w.Codes))
+ for code := range w.Codes {
+ ext.Codes = append(ext.Codes, []byte(code))
+ }
+ ext.State = make([][]byte, 0, len(w.State))
+ for node := range w.State {
+ ext.State = append(ext.State, []byte(node))
+ }
+ return ext
+}
+
+// fromExtWitness converts the consensus witness format into our internal one.
+func (w *Witness) fromExtWitness(ext *extWitness) error {
+ w.Headers = ext.Headers
+
+ w.Codes = make(map[string]struct{}, len(ext.Codes))
+ for _, code := range ext.Codes {
+ w.Codes[string(code)] = struct{}{}
+ }
+ w.State = make(map[string]struct{}, len(ext.State))
+ for _, node := range ext.State {
+ w.State[string(node)] = struct{}{}
+ }
+ return nil
+}
+
+// EncodeRLP serializes a witness as RLP.
+func (w *Witness) EncodeRLP(wr io.Writer) error {
+ return rlp.Encode(wr, w.toExtWitness())
+}
+
+// DecodeRLP decodes a witness from RLP.
+func (w *Witness) DecodeRLP(s *rlp.Stream) error {
+ var ext extWitness
+ if err := s.Decode(&ext); err != nil {
+ return err
+ }
+ return w.fromExtWitness(&ext)
+}
+
+// extWitness is a witness RLP encoding for transferring across clients.
+type extWitness struct {
+ Headers []*types.Header
+ Codes [][]byte
+ State [][]byte
+}
diff --git a/core/stateless/witness.go b/core/stateless/witness.go
new file mode 100644
index 000000000000..10e4b08ca1cb
--- /dev/null
+++ b/core/stateless/witness.go
@@ -0,0 +1,122 @@
+// Copyright 2024 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package stateless
+
+import (
+ "errors"
+ "maps"
+ "slices"
+ "sync"
+
+ "github.com/scroll-tech/go-ethereum/common"
+ "github.com/scroll-tech/go-ethereum/core/types"
+)
+
+// HeaderReader is an interface to pull in headers in place of block hashes for
+// the witness.
+type HeaderReader interface {
+ // GetHeader retrieves a block header from the database by hash and number,
+ GetHeader(hash common.Hash, number uint64) *types.Header
+}
+
+// Witness encompasses the state required to apply a set of transactions and
+// derive a post state/receipt root.
+type Witness struct {
+ context *types.Header // Header to which this witness belongs to, with rootHash and receiptHash zeroed out
+
+ Headers []*types.Header // Past headers in reverse order (0=parent, 1=parent's-parent, etc). First *must* be set.
+ Codes map[string]struct{} // Set of bytecodes ran or accessed
+ State map[string]struct{} // Set of MPT state trie nodes (account and storage together)
+
+ chain HeaderReader // Chain reader to convert block hash ops to header proofs
+ lock sync.Mutex // Lock to allow concurrent state insertions
+}
+
+// NewWitness creates an empty witness ready for population.
+func NewWitness(context *types.Header, chain HeaderReader) (*Witness, error) {
+ // When building witnesses, retrieve the parent header, which will *always*
+ // be included to act as a trustless pre-root hash container
+ var headers []*types.Header
+ if chain != nil {
+ parent := chain.GetHeader(context.ParentHash, context.Number.Uint64()-1)
+ if parent == nil {
+ return nil, errors.New("failed to retrieve parent header")
+ }
+ headers = append(headers, parent)
+ }
+ // Create the witness with a reconstructed gutted out block
+ return &Witness{
+ context: context,
+ Headers: headers,
+ Codes: make(map[string]struct{}),
+ State: make(map[string]struct{}),
+ chain: chain,
+ }, nil
+}
+
+// AddBlockHash adds a "blockhash" to the witness with the designated offset from
+// chain head. Under the hood, this method actually pulls in enough headers from
+// the chain to cover the block being added.
+func (w *Witness) AddBlockHash(number uint64) {
+ // Keep pulling in headers until this hash is populated
+ for int(w.context.Number.Uint64()-number) > len(w.Headers) {
+ tail := w.Headers[len(w.Headers)-1]
+ w.Headers = append(w.Headers, w.chain.GetHeader(tail.ParentHash, tail.Number.Uint64()-1))
+ }
+}
+
+// AddCode adds a bytecode blob to the witness.
+func (w *Witness) AddCode(code []byte) {
+ if len(code) == 0 {
+ return
+ }
+ w.Codes[string(code)] = struct{}{}
+}
+
+// AddState inserts a batch of MPT trie nodes into the witness.
+func (w *Witness) AddState(nodes map[string]struct{}) {
+ if len(nodes) == 0 {
+ return
+ }
+ w.lock.Lock()
+ defer w.lock.Unlock()
+
+ maps.Copy(w.State, nodes)
+}
+
+// Copy deep-copies the witness object. The chain reader isn't deep-copied as
+// it is never mutated by Witness.
+func (w *Witness) Copy() *Witness {
+ cpy := &Witness{
+ Headers: slices.Clone(w.Headers),
+ Codes: maps.Clone(w.Codes),
+ State: maps.Clone(w.State),
+ chain: w.chain,
+ }
+ if w.context != nil {
+ cpy.context = types.CopyHeader(w.context)
+ }
+ return cpy
+}
+
+// Root returns the pre-state root from the first header.
+//
+// Note, this method will panic in case of a bad witness (but RLP decoding will
+// sanitize it and fail before that).
+func (w *Witness) Root() common.Hash {
+ return w.Headers[0].Root
+}
diff --git a/core/types/state_account.go b/core/types/state_account.go
index bb396d439d9e..bde0331d00b3 100644
--- a/core/types/state_account.go
+++ b/core/types/state_account.go
@@ -31,6 +31,6 @@ type StateAccount struct {
KeccakCodeHash []byte
// StateAccount Scroll extensions
- PoseidonCodeHash []byte
- CodeSize uint64
+ PoseidonCodeHash []byte `rlp:"-"`
+ CodeSize uint64 `rlp:"-"`
}
diff --git a/eth/api.go b/eth/api.go
index 74672347823c..07f5874037c8 100644
--- a/eth/api.go
+++ b/eth/api.go
@@ -34,11 +34,14 @@ import (
"github.com/scroll-tech/go-ethereum/core"
"github.com/scroll-tech/go-ethereum/core/rawdb"
"github.com/scroll-tech/go-ethereum/core/state"
+ "github.com/scroll-tech/go-ethereum/core/stateless"
"github.com/scroll-tech/go-ethereum/core/types"
+ "github.com/scroll-tech/go-ethereum/crypto"
"github.com/scroll-tech/go-ethereum/internal/ethapi"
"github.com/scroll-tech/go-ethereum/log"
"github.com/scroll-tech/go-ethereum/rlp"
"github.com/scroll-tech/go-ethereum/rollup/ccc"
+ "github.com/scroll-tech/go-ethereum/rollup/rcfg"
"github.com/scroll-tech/go-ethereum/rpc"
"github.com/scroll-tech/go-ethereum/trie"
)
@@ -321,6 +324,109 @@ func (api *PublicDebugAPI) DumpBlock(blockNr rpc.BlockNumber) (state.Dump, error
return stateDb.RawDump(opts), nil
}
+func (api *PublicDebugAPI) ExecutionWitness(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*ExecutionWitness, error) {
+ block, err := api.eth.APIBackend.BlockByNumberOrHash(ctx, blockNrOrHash)
+ if err != nil {
+ return nil, fmt.Errorf("failed to retrieve block: %w", err)
+ }
+ if block == nil {
+ return nil, fmt.Errorf("block not found: %s", blockNrOrHash.String())
+ }
+
+ witness, err := generateWitness(api.eth.blockchain, block)
+ return ToExecutionWitness(witness), err
+}
+
+func generateWitness(blockchain *core.BlockChain, block *types.Block) (*stateless.Witness, error) {
+ witness, err := stateless.NewWitness(block.Header(), blockchain)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create witness: %w", err)
+ }
+
+ parentHeader := witness.Headers[0]
+ statedb, err := blockchain.StateAt(parentHeader.Root)
+ if err != nil {
+ return nil, fmt.Errorf("failed to retrieve parent state: %w", err)
+ }
+
+ // Collect storage locations that prover needs but sequencer might not touch necessarily
+ statedb.GetState(rcfg.L2MessageQueueAddress, rcfg.WithdrawTrieRootSlot)
+
+ statedb.StartPrefetcher("debug_execution_witness", witness)
+ defer statedb.StopPrefetcher()
+
+ receipts, _, usedGas, err := blockchain.Processor().Process(block, statedb, *blockchain.GetVMConfig())
+ if err != nil {
+ return nil, fmt.Errorf("failed to process block %d: %w", block.Number(), err)
+ }
+
+ if err := blockchain.Validator().ValidateState(block, statedb, receipts, usedGas); err != nil {
+ return nil, fmt.Errorf("failed to validate block %d: %w", block.Number(), err)
+ }
+ return witness, testWitness(blockchain, block, witness)
+}
+
+func testWitness(blockchain *core.BlockChain, block *types.Block, witness *stateless.Witness) error {
+ stateRoot := witness.Root()
+ if diskRoot, _ := rawdb.ReadDiskStateRoot(blockchain.Database(), stateRoot); diskRoot != (common.Hash{}) {
+ stateRoot = diskRoot
+ }
+
+ // Create and populate the state database to serve as the stateless backend
+ statedb, err := state.New(stateRoot, state.NewDatabase(witness.MakeHashDB()), nil)
+ if err != nil {
+ return fmt.Errorf("failed to create state database: %w", err)
+ }
+
+ receipts, _, usedGas, err := blockchain.Processor().Process(block, statedb, *blockchain.GetVMConfig())
+ if err != nil {
+ return fmt.Errorf("failed to process block %d: %w", block.Number(), err)
+ }
+
+ if err := blockchain.Validator().ValidateState(block, statedb, receipts, usedGas); err != nil {
+ return fmt.Errorf("failed to validate block %d: %w", block.Number(), err)
+ }
+
+ postStateRoot := block.Root()
+ if diskRoot, _ := rawdb.ReadDiskStateRoot(blockchain.Database(), postStateRoot); diskRoot != (common.Hash{}) {
+ postStateRoot = diskRoot
+ }
+ if statedb.GetRootHash() != postStateRoot {
+ return fmt.Errorf("failed to commit statelessly %d: %w", block.Number(), err)
+ }
+ return nil
+}
+
+// ExecutionWitness is a witness json encoding for transferring across the network.
+// In the future, we'll probably consider using the extWitness format instead for less overhead if performance becomes an issue.
+// Currently using this format for ease of reading, parsing and compatibility across clients.
+type ExecutionWitness struct {
+ Headers []*types.Header `json:"headers"`
+ Codes map[string]string `json:"codes"`
+ State map[string]string `json:"state"`
+}
+
+func transformMap(in map[string]struct{}) map[string]string {
+ out := make(map[string]string, len(in))
+ for item := range in {
+ bytes := []byte(item)
+ key := crypto.Keccak256Hash(bytes).Hex()
+ out[key] = hexutil.Encode(bytes)
+ }
+ return out
+}
+
+// ToExecutionWitness converts a witness to an execution witness format that is compatible with reth.
+// keccak(node) => node
+// keccak(bytecodes) => bytecodes
+func ToExecutionWitness(w *stateless.Witness) *ExecutionWitness {
+ return &ExecutionWitness{
+ Headers: w.Headers,
+ Codes: transformMap(w.Codes),
+ State: transformMap(w.State),
+ }
+}
+
// PrivateDebugAPI is the collection of Ethereum full node APIs exposed over
// the private debugging endpoint.
type PrivateDebugAPI struct {
@@ -859,3 +965,30 @@ func (api *ScrollAPI) CalculateRowConsumptionByBlockNumber(ctx context.Context,
asyncChecker.Wait()
return rawdb.ReadBlockRowConsumption(api.eth.ChainDb(), block.Hash()), checkErr
}
+
+type DiskAndHeaderRoot struct {
+ DiskRoot common.Hash `json:"diskRoot"`
+ HeaderRoot common.Hash `json:"headerRoot"`
+}
+
+// DiskRoot returns the state root stored on disk for the given block, together with the block header's state root.
+func (api *ScrollAPI) DiskRoot(ctx context.Context, blockNrOrHash *rpc.BlockNumberOrHash) (DiskAndHeaderRoot, error) {
+ block, err := api.eth.APIBackend.BlockByNumberOrHash(ctx, *blockNrOrHash)
+ if err != nil {
+ return DiskAndHeaderRoot{}, fmt.Errorf("failed to retrieve block: %w", err)
+ }
+ if block == nil {
+ return DiskAndHeaderRoot{}, fmt.Errorf("block not found: %s", blockNrOrHash.String())
+ }
+
+ if diskRoot, _ := rawdb.ReadDiskStateRoot(api.eth.ChainDb(), block.Root()); diskRoot != (common.Hash{}) {
+ return DiskAndHeaderRoot{
+ DiskRoot: diskRoot,
+ HeaderRoot: block.Root(),
+ }, nil
+ }
+ return DiskAndHeaderRoot{
+ DiskRoot: block.Root(),
+ HeaderRoot: block.Root(),
+ }, nil
+}
diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go
index d43e01467ee3..8de02ae6f57e 100644
--- a/internal/web3ext/web3ext.go
+++ b/internal/web3ext/web3ext.go
@@ -482,6 +482,13 @@ web3._extend({
params: 2,
inputFormatter:[web3._extend.formatters.inputBlockNumberFormatter, web3._extend.formatters.inputBlockNumberFormatter],
}),
+ new web3._extend.Method({
+ name: 'executionWitness',
+ call: 'debug_executionWitness',
+ params: 1,
+ inputFormatter: [null]
+ }),
+
],
properties: []
});
@@ -942,6 +949,13 @@ web3._extend({
params: 1,
inputFormatter: [web3._extend.formatters.inputBlockNumberFormatter]
}),
+ new web3._extend.Method({
+ name: 'diskRoot',
+ call: 'scroll_diskRoot',
+ params: 1,
+ inputFormatter: [web3._extend.formatters.inputDefaultBlockNumberFormatter],
+ }),
+
],
properties:
[
diff --git a/light/trie.go b/light/trie.go
index 1947c314090b..148814f3fbc5 100644
--- a/light/trie.go
+++ b/light/trie.go
@@ -181,6 +181,11 @@ func (t *odrTrie) do(key []byte, fn func() error) error {
}
}
+// Witness returns a set containing all trie nodes that have been accessed.
+func (t *odrTrie) Witness() map[string]struct{} {
+ panic("not implemented")
+}
+
type nodeIterator struct {
trie.NodeIterator
t *odrTrie
diff --git a/miner/scroll_worker.go b/miner/scroll_worker.go
index e152878d40e6..a6aa7c21c207 100644
--- a/miner/scroll_worker.go
+++ b/miner/scroll_worker.go
@@ -372,7 +372,7 @@ func (w *worker) mainLoop() {
select {
case <-w.startCh:
idleTimer.UpdateSince(idleStart)
- if w.isRunning() {
+ if w.isRunning() && w.chainConfig.Scroll.UseZktrie {
if err := w.checkHeadRowConsumption(); err != nil {
log.Error("failed to start head checkers", "err", err)
return
@@ -490,9 +490,10 @@ func (w *worker) newWork(now time.Time, parentHash common.Hash, reorging bool, r
vmConfig := *w.chain.GetVMConfig()
cccLogger := ccc.NewLogger()
- vmConfig.Debug = true
- vmConfig.Tracer = cccLogger
-
+ if w.chainConfig.Scroll.UseZktrie {
+ vmConfig.Debug = true
+ vmConfig.Tracer = cccLogger
+ }
deadline := time.Unix(int64(header.Time), 0)
if w.chainConfig.Clique != nil && w.chainConfig.Clique.RelaxedPeriod {
// clique with relaxed period uses time.Now() as the header.Time, calculate the deadline
@@ -566,6 +567,11 @@ func (w *worker) handleForks() (bool, error) {
misc.ApplyCurieHardFork(w.current.state)
return true, nil
}
+
+ if w.chainConfig.IsEuclid(w.current.header.Time) {
+ parent := w.chain.GetBlockByHash(w.current.header.ParentHash)
+ return parent != nil && !w.chainConfig.IsEuclid(parent.Time()), nil
+ }
return false, nil
}
@@ -809,7 +815,10 @@ func (w *worker) commit() (common.Hash, error) {
}(time.Now())
w.updateSnapshot()
- if !w.isRunning() && !w.current.reorging {
+ // Since clocks of mpt-sequencer and zktrie-sequencer can be slightly out of sync,
+ // this might result in a reorg at the Euclid fork block. But it will be resolved shortly after.
+ canCommitState := w.chainConfig.Scroll.UseZktrie != w.chainConfig.IsEuclid(w.current.header.Time)
+ if !canCommitState || (!w.isRunning() && !w.current.reorging) {
return common.Hash{}, nil
}
@@ -886,7 +895,7 @@ func (w *worker) commit() (common.Hash, error) {
currentHeight := w.current.header.Number.Uint64()
maxReorgDepth := uint64(w.config.CCCMaxWorkers + 1)
- if !w.current.reorging && currentHeight > maxReorgDepth {
+ if w.chainConfig.Scroll.UseZktrie && !w.current.reorging && currentHeight > maxReorgDepth {
ancestorHeight := currentHeight - maxReorgDepth
ancestorHash := w.chain.GetHeaderByNumber(ancestorHeight).Hash()
if rawdb.ReadBlockRowConsumption(w.chain.Database(), ancestorHash) == nil {
@@ -914,8 +923,10 @@ func (w *worker) commit() (common.Hash, error) {
w.mux.Post(core.NewMinedBlockEvent{Block: block})
checkStart := time.Now()
- if err = w.asyncChecker.Check(block); err != nil {
- log.Error("failed to launch CCC background task", "err", err)
+ if w.chainConfig.Scroll.UseZktrie {
+ if err = w.asyncChecker.Check(block); err != nil {
+ log.Error("failed to launch CCC background task", "err", err)
+ }
}
cccStallTimer.UpdateSince(checkStart)
diff --git a/miner/scroll_worker_test.go b/miner/scroll_worker_test.go
index 5f79902f0e15..3e8468348f53 100644
--- a/miner/scroll_worker_test.go
+++ b/miner/scroll_worker_test.go
@@ -232,6 +232,7 @@ func testGenerateBlockAndImport(t *testing.T, isClique bool) {
engine = ethash.NewFaker()
}
chainConfig.Scroll.FeeVaultAddress = &common.Address{}
+ chainConfig.Scroll.UseZktrie = true
chainConfig.LondonBlock = big.NewInt(0)
w, b := newTestWorker(t, chainConfig, engine, db, 0)
@@ -296,6 +297,7 @@ func testGenerateBlockWithL1Msg(t *testing.T, isClique bool) {
NumL1MessagesPerBlock: 1,
}
chainConfig.Scroll.FeeVaultAddress = &common.Address{}
+ chainConfig.Scroll.UseZktrie = true
chainConfig.LondonBlock = big.NewInt(0)
w, b := newTestWorker(t, chainConfig, engine, db, 0)
@@ -344,6 +346,7 @@ func TestAcceptableTxlimit(t *testing.T) {
chainConfig = params.AllCliqueProtocolChanges
chainConfig.Clique = &params.CliqueConfig{Period: 1, Epoch: 30000}
chainConfig.Scroll.FeeVaultAddress = &common.Address{}
+ chainConfig.Scroll.UseZktrie = true
engine = clique.New(chainConfig.Clique, db)
// Set maxTxPerBlock = 4, which >= non-l1msg + non-skipped l1msg txs
@@ -404,6 +407,7 @@ func TestUnacceptableTxlimit(t *testing.T) {
chainConfig = params.AllCliqueProtocolChanges
chainConfig.Clique = &params.CliqueConfig{Period: 1, Epoch: 30000}
chainConfig.Scroll.FeeVaultAddress = &common.Address{}
+ chainConfig.Scroll.UseZktrie = true
engine = clique.New(chainConfig.Clique, db)
// Set maxTxPerBlock = 3, which < non-l1msg + l1msg txs
@@ -463,6 +467,7 @@ func TestL1MsgCorrectOrder(t *testing.T) {
chainConfig = params.AllCliqueProtocolChanges
chainConfig.Clique = &params.CliqueConfig{Period: 1, Epoch: 30000}
chainConfig.Scroll.FeeVaultAddress = &common.Address{}
+ chainConfig.Scroll.UseZktrie = true
engine = clique.New(chainConfig.Clique, db)
maxTxPerBlock := 4
@@ -525,6 +530,7 @@ func l1MessageTest(t *testing.T, msgs []types.L1MessageTx, withL2Tx bool, callba
chainConfig = params.AllCliqueProtocolChanges
chainConfig.Clique = &params.CliqueConfig{Period: 1, Epoch: 30000}
+ chainConfig.Scroll.UseZktrie = true
engine = clique.New(chainConfig.Clique, db)
maxTxPerBlock := 4
chainConfig.Scroll.MaxTxPerBlock = &maxTxPerBlock
@@ -879,6 +885,7 @@ func TestPrioritizeOverflowTx(t *testing.T) {
chainConfig.Clique = &params.CliqueConfig{Period: 1, Epoch: 30000}
chainConfig.LondonBlock = big.NewInt(0)
chainConfig.Scroll.FeeVaultAddress = &common.Address{}
+ chainConfig.Scroll.UseZktrie = true
engine := clique.New(chainConfig.Clique, db)
w, b := newTestWorker(t, chainConfig, engine, db, 0)
@@ -1036,6 +1043,7 @@ func TestPending(t *testing.T) {
chainConfig = params.AllCliqueProtocolChanges
chainConfig.Clique = &params.CliqueConfig{Period: 1, Epoch: 30000}
chainConfig.Scroll.FeeVaultAddress = &common.Address{}
+ chainConfig.Scroll.UseZktrie = true
engine = clique.New(chainConfig.Clique, db)
w, b := newTestWorker(t, chainConfig, engine, db, 0)
defer w.close()
@@ -1080,6 +1088,7 @@ func TestReorg(t *testing.T) {
chainConfig = params.AllCliqueProtocolChanges
chainConfig.Clique = &params.CliqueConfig{Period: 1, Epoch: 30000, RelaxedPeriod: true}
chainConfig.Scroll.FeeVaultAddress = &common.Address{}
+ chainConfig.Scroll.UseZktrie = true
engine = clique.New(chainConfig.Clique, db)
maxTxPerBlock := 2
@@ -1194,6 +1203,7 @@ func TestRestartHeadCCC(t *testing.T) {
chainConfig = params.AllCliqueProtocolChanges
chainConfig.Clique = &params.CliqueConfig{Period: 1, Epoch: 30000, RelaxedPeriod: true}
chainConfig.Scroll.FeeVaultAddress = &common.Address{}
+ chainConfig.Scroll.UseZktrie = true
engine = clique.New(chainConfig.Clique, db)
maxTxPerBlock := 2
diff --git a/params/config.go b/params/config.go
index eb2213fdd360..dc14a8b0769c 100644
--- a/params/config.go
+++ b/params/config.go
@@ -29,14 +29,16 @@ import (
// Genesis hashes to enforce below configs on.
var (
- MainnetGenesisHash = common.HexToHash("0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3")
- RopstenGenesisHash = common.HexToHash("0x41941023680923e0fe4d74a34bdac8141f2540e3ae90623718e47d66d1ca4a2d")
- SepoliaGenesisHash = common.HexToHash("0x25a5cc106eea7138acab33231d7160d69cb777ee0c2c553fcddf5138993e6dd9")
- RinkebyGenesisHash = common.HexToHash("0x6341fd3daf94b748c72ced5a5b26028f2474f5f00d824504e4fa37a75767e177")
- GoerliGenesisHash = common.HexToHash("0xbf7e331f7f7c1dd2e05159666b3bf8bc7a8a3a9eb1d518969eab529dd9b88c1a")
- ScrollAlphaGenesisHash = common.HexToHash("0xa4fc62b9b0643e345bdcebe457b3ae898bef59c7203c3db269200055e037afda")
- ScrollSepoliaGenesisHash = common.HexToHash("0xaa62d1a8b2bffa9e5d2368b63aae0d98d54928bd713125e3fd9e5c896c68592c")
- ScrollMainnetGenesisHash = common.HexToHash("0xbbc05efd412b7cd47a2ed0e5ddfcf87af251e414ea4c801d78b6784513180a80")
+ MainnetGenesisHash = common.HexToHash("0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3")
+ RopstenGenesisHash = common.HexToHash("0x41941023680923e0fe4d74a34bdac8141f2540e3ae90623718e47d66d1ca4a2d")
+ SepoliaGenesisHash = common.HexToHash("0x25a5cc106eea7138acab33231d7160d69cb777ee0c2c553fcddf5138993e6dd9")
+ RinkebyGenesisHash = common.HexToHash("0x6341fd3daf94b748c72ced5a5b26028f2474f5f00d824504e4fa37a75767e177")
+ GoerliGenesisHash = common.HexToHash("0xbf7e331f7f7c1dd2e05159666b3bf8bc7a8a3a9eb1d518969eab529dd9b88c1a")
+ ScrollAlphaGenesisHash = common.HexToHash("0xa4fc62b9b0643e345bdcebe457b3ae898bef59c7203c3db269200055e037afda")
+ ScrollSepoliaGenesisHash = common.HexToHash("0xaa62d1a8b2bffa9e5d2368b63aae0d98d54928bd713125e3fd9e5c896c68592c")
+ ScrollMainnetGenesisHash = common.HexToHash("0xbbc05efd412b7cd47a2ed0e5ddfcf87af251e414ea4c801d78b6784513180a80")
+ ScrollSepoliaGenesisState = common.HexToHash("0x20695989e9038823e35f0e88fbc44659ffdbfa1fe89fbeb2689b43f15fa64cb5")
+ ScrollMainnetGenesisState = common.HexToHash("0x08d535cc60f40af5dd3b31e0998d7567c2d568b224bed2ba26070aeb078d1339")
)
func newUint64(val uint64) *uint64 { return &val }
@@ -340,6 +342,7 @@ var (
NumL1MessagesPerBlock: 10,
ScrollChainAddress: common.HexToAddress("0x2D567EcE699Eabe5afCd141eDB7A4f2D0D6ce8a0"),
},
+ GenesisStateRoot: &ScrollSepoliaGenesisState,
},
}
@@ -380,6 +383,7 @@ var (
NumL1MessagesPerBlock: 10,
ScrollChainAddress: common.HexToAddress("0xa13BAF47339d63B743e7Da8741db5456DAc1E556"),
},
+ GenesisStateRoot: &ScrollMainnetGenesisState,
},
}
@@ -633,6 +637,7 @@ type ChainConfig struct {
CurieBlock *big.Int `json:"curieBlock,omitempty"` // Curie switch block (nil = no fork, 0 = already on curie)
DarwinTime *uint64 `json:"darwinTime,omitempty"` // Darwin switch time (nil = no fork, 0 = already on darwin)
DarwinV2Time *uint64 `json:"darwinv2Time,omitempty"` // DarwinV2 switch time (nil = no fork, 0 = already on darwinv2)
+ EuclidTime *uint64 `json:"euclidTime,omitempty"` // Euclid switch time (nil = no fork, 0 = already on euclid)
// TerminalTotalDifficulty is the amount of total difficulty reached by
// the network that triggers the consensus upgrade.
@@ -661,6 +666,9 @@ type ScrollConfig struct {
// L1 config
L1Config *L1Config `json:"l1Config,omitempty"`
+
+ // Genesis State Root for MPT clients
+ GenesisStateRoot *common.Hash `json:"genesisStateRoot,omitempty"`
}
// L1Config contains the l1 parameters needed to sync l1 contract events (e.g., l1 messages, commit/revert/finalize batches) in the sequencer
@@ -888,6 +896,11 @@ func (c *ChainConfig) IsDarwinV2(now uint64) bool {
return isForkedTime(now, c.DarwinV2Time)
}
+// IsEuclid returns whether time is either equal to the Euclid fork time or greater.
+func (c *ChainConfig) IsEuclid(now uint64) bool {
+ return isForkedTime(now, c.EuclidTime)
+}
+
// IsTerminalPoWBlock returns whether the given block is the last block of PoW stage.
func (c *ChainConfig) IsTerminalPoWBlock(parentTotalDiff *big.Int, totalDiff *big.Int) bool {
if c.TerminalTotalDifficulty == nil {
@@ -1100,7 +1113,7 @@ type Rules struct {
IsHomestead, IsEIP150, IsEIP155, IsEIP158 bool
IsByzantium, IsConstantinople, IsPetersburg, IsIstanbul bool
IsBerlin, IsLondon, IsArchimedes, IsShanghai bool
- IsBernoulli, IsCurie, IsDarwin bool
+ IsBernoulli, IsCurie, IsDarwin, IsEuclid bool
}
// Rules ensures c's ChainID is not nil.
@@ -1126,5 +1139,6 @@ func (c *ChainConfig) Rules(num *big.Int, time uint64) Rules {
IsBernoulli: c.IsBernoulli(num),
IsCurie: c.IsCurie(num),
IsDarwin: c.IsDarwin(time),
+ IsEuclid: c.IsEuclid(time),
}
}
diff --git a/rollup/ccc/async_checker.go b/rollup/ccc/async_checker.go
index 1cb9b7d78768..b5815cb58403 100644
--- a/rollup/ccc/async_checker.go
+++ b/rollup/ccc/async_checker.go
@@ -98,6 +98,11 @@ func (c *AsyncChecker) Wait() {
// Check spawns an async CCC verification task.
func (c *AsyncChecker) Check(block *types.Block) error {
+ if c.bc.Config().IsEuclid(block.Time()) {
+ // Euclid blocks use MPT and CCC doesn't support them
+ return nil
+ }
+
if block.NumberU64() > c.currentHead.Number.Uint64()+1 {
log.Warn("non continuous chain observed in AsyncChecker", "prev", c.currentHead, "got", block.Header())
}
diff --git a/rollup/pipeline/pipeline.go b/rollup/pipeline/pipeline.go
index 90c6149b3858..77ac3aee24a0 100644
--- a/rollup/pipeline/pipeline.go
+++ b/rollup/pipeline/pipeline.go
@@ -228,7 +228,7 @@ func sendCancellable[T any, C comparable](resCh chan T, msg T, cancelCh <-chan C
}
func (p *Pipeline) traceAndApplyStage(txsIn <-chan *types.Transaction) (<-chan error, <-chan *BlockCandidate, error) {
- p.state.StartPrefetcher("miner")
+ p.state.StartPrefetcher("miner", nil)
downstreamCh := make(chan *BlockCandidate, p.downstreamChCapacity())
resCh := make(chan error)
p.wg.Add(1)
diff --git a/trie/database.go b/trie/database.go
index 1c5b7f805aea..1301a83a5e98 100644
--- a/trie/database.go
+++ b/trie/database.go
@@ -667,6 +667,9 @@ func (db *Database) Commit(node common.Hash, report bool, callback func(common.H
}
batch.Reset()
+ if diskRoot, err := rawdb.ReadDiskStateRoot(db.diskdb, node); err == nil {
+ node = diskRoot
+ }
if (node == common.Hash{}) {
return nil
}
@@ -782,7 +785,7 @@ func (c *cleaner) Put(key []byte, rlp []byte) error {
delete(c.db.dirties, hash)
c.db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size))
if node.children != nil {
- c.db.dirtiesSize -= common.StorageSize(cachedNodeChildrenSize + len(node.children)*(common.HashLength+2))
+ c.db.childrenSize -= common.StorageSize(cachedNodeChildrenSize + len(node.children)*(common.HashLength+2))
}
// Move the flushed node into the clean cache to prevent insta-reloads
if c.db.cleans != nil {
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 761d812bdfcc..7516d2879010 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -296,7 +296,7 @@ func TestUnionIterator(t *testing.T) {
}
func TestIteratorNoDups(t *testing.T) {
- var tr Trie
+ tr := newEmpty()
for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v))
}
@@ -530,7 +530,7 @@ func TestNodeIteratorLargeTrie(t *testing.T) {
trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
// master: 24 get operations
// this pr: 5 get operations
- if have, want := logDb.getCount, uint64(5); have != want {
+ if have, want := logDb.getCount, uint64(10); have != want {
t.Fatalf("Too many lookups during seek, have %d want %d", have, want)
}
}
diff --git a/trie/proof.go b/trie/proof.go
index 58fb4c3cc78a..3362512b20c7 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -559,7 +559,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
}
// Rebuild the trie with the leaf stream, the shape of trie
// should be same with the original one.
- tr := &Trie{root: root, db: NewDatabase(memorydb.New())}
+ tr := &Trie{root: root, db: NewDatabase(memorydb.New()), tracer: newTracer()}
if empty {
tr.root = nil
}
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 2155ae0fbd6a..5c5304261d2d 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -79,7 +79,7 @@ func TestProof(t *testing.T) {
}
func TestOneElementProof(t *testing.T) {
- trie := new(Trie)
+ trie := newEmpty()
updateString(trie, "k", "v")
for i, prover := range makeProvers(trie) {
proof := prover([]byte("k"))
@@ -130,7 +130,7 @@ func TestBadProof(t *testing.T) {
// Tests that missing keys can also be proven. The test explicitly uses a single
// entry trie and checks for missing keys both before and after the single entry.
func TestMissingKeyProof(t *testing.T) {
- trie := new(Trie)
+ trie := newEmpty()
updateString(trie, "k", "v")
for i, key := range []string{"a", "j", "l", "z"} {
@@ -386,7 +386,7 @@ func TestOneElementRangeProof(t *testing.T) {
}
// Test the mini trie with only a single element.
- tinyTrie := new(Trie)
+ tinyTrie := newEmpty()
entry := &kv{randBytes(32), randBytes(20), false}
tinyTrie.Update(entry.k, entry.v)
@@ -458,7 +458,7 @@ func TestAllElementsProof(t *testing.T) {
// TestSingleSideRangeProof tests the range starts from zero.
func TestSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
- trie := new(Trie)
+ trie := newEmpty()
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
@@ -493,7 +493,7 @@ func TestSingleSideRangeProof(t *testing.T) {
// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff.
func TestReverseSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
- trie := new(Trie)
+ trie := newEmpty()
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
@@ -600,7 +600,7 @@ func TestBadRangeProof(t *testing.T) {
// TestGappedRangeProof focuses on the small trie with embedded nodes.
// If the gapped node is embedded in the trie, it should be detected too.
func TestGappedRangeProof(t *testing.T) {
- trie := new(Trie)
+ trie := newEmpty()
var entries []*kv // Sorted entries
for i := byte(0); i < 10; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
@@ -674,7 +674,7 @@ func TestSameSideProofs(t *testing.T) {
}
func TestHasRightElement(t *testing.T) {
- trie := new(Trie)
+ trie := newEmpty()
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
@@ -1027,7 +1027,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
}
func randomTrie(n int) (*Trie, map[string]*kv) {
- trie := new(Trie)
+ trie := newEmpty()
vals := make(map[string]*kv)
for i := byte(0); i < 100; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
@@ -1052,7 +1052,7 @@ func randBytes(n int) []byte {
}
func nonRandomTrie(n int) (*Trie, map[string]*kv) {
- trie := new(Trie)
+ trie := newEmpty()
vals := make(map[string]*kv)
max := uint64(0xffffffffffffffff)
for i := uint64(0); i < uint64(n); i++ {
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 253b8d780ad3..a3529c2fc68f 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -190,6 +190,7 @@ func (t *SecureTrie) Hash() common.Hash {
// Copy returns a copy of SecureTrie.
func (t *SecureTrie) Copy() *SecureTrie {
cpy := *t
+ cpy.trie.tracer = t.trie.tracer.copy()
return &cpy
}
@@ -221,3 +222,8 @@ func (t *SecureTrie) getSecKeyCache() map[string][]byte {
}
return t.secKeyCache
}
+
+// Witness returns a set containing all trie nodes that have been accessed.
+func (t *SecureTrie) Witness() map[string]struct{} {
+ return t.trie.Witness()
+}
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
index b81b4e1ad5b8..9baaa2e266ed 100644
--- a/trie/secure_trie_test.go
+++ b/trie/secure_trie_test.go
@@ -112,8 +112,7 @@ func TestSecureTrieConcurrency(t *testing.T) {
threads := runtime.NumCPU()
tries := make([]*SecureTrie, threads)
for i := 0; i < threads; i++ {
- cpy := *trie
- tries[i] = &cpy
+ tries[i] = trie.Copy()
}
// Start a batch of goroutines interactng with the trie
pend := new(sync.WaitGroup)
diff --git a/trie/tracer.go b/trie/tracer.go
new file mode 100644
index 000000000000..99cda0706f7d
--- /dev/null
+++ b/trie/tracer.go
@@ -0,0 +1,122 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "maps"
+
+ "github.com/scroll-tech/go-ethereum/common"
+)
+
+// tracer tracks the changes of trie nodes. During the trie operations,
+// some nodes can be deleted from the trie, while these deleted nodes
+// won't be captured by trie.Hasher or trie.Committer. Thus, these deleted
+// nodes won't be removed from the disk at all. Tracer is an auxiliary tool
+// used to track all insert and delete operations of trie and capture all
+// deleted nodes eventually.
+//
+// The changed nodes can be mainly divided into two categories: the leaf
+// node and intermediate node. The former is inserted/deleted by callers
+// while the latter is inserted/deleted in order to follow the rule of trie.
+// This tool can track all of them no matter the node is embedded in its
+// parent or not, but valueNode is never tracked.
+//
+// Besides, it's also used for recording the original value of the nodes
+// when they are resolved from the disk. The pre-value of the nodes will
+// be used to construct trie history in the future.
+//
+// Note tracer is not thread-safe, callers should be responsible for handling
+// the concurrency issues by themselves.
+type tracer struct {
+ inserts map[string]struct{}
+ deletes map[string]struct{}
+ accessList map[string][]byte
+}
+
+// newTracer initializes the tracer for capturing trie changes.
+func newTracer() *tracer {
+ return &tracer{
+ inserts: make(map[string]struct{}),
+ deletes: make(map[string]struct{}),
+ accessList: make(map[string][]byte),
+ }
+}
+
+// onRead tracks the newly loaded trie node and caches the rlp-encoded
+// blob internally. Don't change the value outside of function since
+// it's not deep-copied.
+func (t *tracer) onRead(path []byte, val []byte) {
+ t.accessList[string(path)] = val
+}
+
+// onInsert tracks the newly inserted trie node. If it's already
+// in the deletion set (resurrected node), then just wipe it from
+// the deletion set as it's "untouched".
+func (t *tracer) onInsert(path []byte) {
+ if _, present := t.deletes[string(path)]; present {
+ delete(t.deletes, string(path))
+ return
+ }
+ t.inserts[string(path)] = struct{}{}
+}
+
+// onDelete tracks the newly deleted trie node. If it's already
+// in the addition set, then just wipe it from the addition set
+// as it's untouched.
+func (t *tracer) onDelete(path []byte) {
+ if _, present := t.inserts[string(path)]; present {
+ delete(t.inserts, string(path))
+ return
+ }
+ t.deletes[string(path)] = struct{}{}
+}
+
+// reset clears the content tracked by tracer.
+func (t *tracer) reset() {
+ t.inserts = make(map[string]struct{})
+ t.deletes = make(map[string]struct{})
+ t.accessList = make(map[string][]byte)
+}
+
+// copy returns a deep copied tracer instance.
+func (t *tracer) copy() *tracer {
+ accessList := make(map[string][]byte, len(t.accessList))
+ for path, blob := range t.accessList {
+ accessList[path] = common.CopyBytes(blob)
+ }
+ return &tracer{
+ inserts: maps.Clone(t.inserts),
+ deletes: maps.Clone(t.deletes),
+ accessList: accessList,
+ }
+}
+
+// deletedNodes returns a list of node paths which are deleted from the trie.
+func (t *tracer) deletedNodes() []string {
+ var paths []string
+ for path := range t.deletes {
+ // It's possible a few deleted nodes were embedded
+ // in their parent before, the deletions can be no
+ // effect by deleting nothing, filter them out.
+ _, ok := t.accessList[path]
+ if !ok {
+ continue
+ }
+ paths = append(paths, path)
+ }
+ return paths
+}
diff --git a/trie/trie.go b/trie/trie.go
index 81cdd1627745..fe521d269074 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -62,6 +62,9 @@ type Trie struct {
// hashing operation. This number will not directly map to the number of
// actually unhashed nodes
unhashed int
+
+ // tracer is the tool to track the trie changes.
+ tracer *tracer
}
// newFlag returns the cache flag value for a newly created node.
@@ -80,7 +83,8 @@ func New(root common.Hash, db *Database) (*Trie, error) {
panic("trie.New called without a database")
}
trie := &Trie{
- db: db,
+ db: db,
+ tracer: newTracer(),
}
if root != (common.Hash{}) && root != emptyRoot {
rootnode, err := trie.resolveHash(root[:], nil)
@@ -313,6 +317,11 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
if matchlen == 0 {
return true, branch, nil
}
+ // New branch node is created as a child of the original short node.
+ // Track the newly inserted node in the tracer. The node identifier
+ // passed is the path from the root node.
+ t.tracer.onInsert(append(prefix, key[:matchlen]...))
+
// Otherwise, replace it with a short node leading up to the branch.
return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
@@ -327,6 +336,11 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return true, n, nil
case nil:
+ // New short node is created and track it in the tracer. The node identifier
+ // passed is the path from the root node. Note the valueNode won't be tracked
+ // since it's always embedded in its parent.
+ t.tracer.onInsert(prefix)
+
return true, &shortNode{key, value, t.newFlag()}, nil
case hashNode:
@@ -379,6 +393,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, n, nil // don't replace n on mismatch
}
if matchlen == len(key) {
+ // The matched short node is deleted entirely and track
+ // it in the deletion set. The same the valueNode doesn't
+ // need to be tracked at all since it's always embedded.
+ t.tracer.onDelete(prefix)
+
return true, nil, nil // remove n entirely for whole matches
}
// The key is longer than n.Key. Remove the remaining suffix
@@ -391,6 +410,10 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
}
switch child := child.(type) {
case *shortNode:
+ // The child shortNode is merged into its parent, track
+ // is deleted as well.
+ t.tracer.onDelete(append(prefix, n.Key...))
+
// Deleting from the subtrie reduced it to another
// short node. Merge the nodes to avoid creating a
// shortNode{..., shortNode{...}}. Use concat (which
@@ -452,6 +475,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, nil, err
}
if cnode, ok := cnode.(*shortNode); ok {
+ // Replace the entire full node with the short node.
+ // Mark the original short node as deleted since the
+ // value is embedded into the parent now.
+ t.tracer.onDelete(append(prefix, byte(pos)))
+
k := append([]byte{byte(pos)}, cnode.Key...)
return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
}
@@ -505,6 +533,11 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) {
func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
hash := common.BytesToHash(n)
if node := t.db.node(hash); node != nil {
+ rlp, err := t.db.Node(hash)
+ if err != nil {
+ return nil, err
+ }
+ t.tracer.onRead(prefix, rlp)
return node, nil
}
return nil, &MissingNodeError{NodeHash: hash, Path: prefix}
@@ -582,4 +615,18 @@ func (t *Trie) hashRoot() (node, node, error) {
func (t *Trie) Reset() {
t.root = nil
t.unhashed = 0
+ t.tracer.reset()
+}
+
+// Witness returns a set containing all trie nodes that have been accessed.
+func (t *Trie) Witness() map[string]struct{} {
+ if len(t.tracer.accessList) == 0 {
+ return nil
+ }
+
+ witness := make(map[string]struct{}, len(t.tracer.accessList))
+ for _, node := range t.tracer.accessList {
+ witness[string(node)] = struct{}{}
+ }
+ return witness
}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 23780a5ff807..e8f6293d0376 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -64,7 +64,7 @@ func TestEmptyTrie(t *testing.T) {
}
func TestNull(t *testing.T) {
- var trie Trie
+ trie := newEmpty()
key := make([]byte, 32)
value := []byte("test")
trie.Update(key, value)
@@ -593,15 +593,15 @@ func TestTinyTrie(t *testing.T) {
_, accounts := makeAccounts(5)
trie := newEmpty()
trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), accounts[3])
- if exp, root := common.HexToHash("fc516c51c03bf9f1a0eec6ed6f6f5da743c2745dcd5670007519e6ec056f95a8"), trie.Hash(); exp != root {
+ if exp, root := common.HexToHash("8c6a85a4d9fda98feff88450299e574e5378e32391f75a055d470ac0653f1005"), trie.Hash(); exp != root {
t.Errorf("1: got %x, exp %x", root, exp)
}
trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001338"), accounts[4])
- if exp, root := common.HexToHash("5070d3f144546fd13589ad90cd153954643fa4ca6c1a5f08683cbfbbf76e960c"), trie.Hash(); exp != root {
+ if exp, root := common.HexToHash("ec63b967e98a5720e7f720482151963982890d82c9093c0d486b7eb8883a66b1"), trie.Hash(); exp != root {
t.Errorf("2: got %x, exp %x", root, exp)
}
trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001339"), accounts[4])
- if exp, root := common.HexToHash("aa3fba77e50f6e931d8aacde70912be5bff04c7862f518ae06f3418dd4d37be3"), trie.Hash(); exp != root {
+ if exp, root := common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), trie.Hash(); exp != root {
t.Errorf("3: got %x, exp %x", root, exp)
}
checktr, _ := New(common.Hash{}, trie.db)
@@ -625,7 +625,7 @@ func TestCommitAfterHash(t *testing.T) {
trie.Hash()
trie.Commit(nil)
root := trie.Hash()
- exp := common.HexToHash("f0c0681648c93b347479cd58c61995557f01294425bd031ce1943c2799bbd4ec")
+ exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6")
if exp != root {
t.Errorf("got %x, exp %x", root, exp)
}
@@ -725,12 +725,12 @@ func TestCommitSequence(t *testing.T) {
expWriteSeqHash []byte
expCallbackSeqHash []byte
}{
- {20, common.FromHex("7b908cce3bc16abb3eac5dff6c136856526f15225f74ce860a2bec47912a5492"),
- common.FromHex("fac65cd2ad5e301083d0310dd701b5faaff1364cbe01cdbfaf4ec3609bb4149e")},
- {200, common.FromHex("55791f6ec2f83fee512a2d3d4b505784fdefaea89974e10440d01d62a18a298a"),
- common.FromHex("5ab775b64d86a8058bb71c3c765d0f2158c14bbeb9cb32a65eda793a7e95e30f")},
- {2000, common.FromHex("ccb464abf67804538908c62431b3a6788e8dc6dee62aff9bfe6b10136acfceac"),
- common.FromHex("b908adff17a5aa9d6787324c39014a74b04cef7fba6a92aeb730f48da1ca665d")},
+ {20, common.FromHex("873c78df73d60e59d4a2bcf3716e8bfe14554549fea2fc147cb54129382a8066"),
+ common.FromHex("ff00f91ac05df53b82d7f178d77ada54fd0dca64526f537034a5dbe41b17df2a")},
+ {200, common.FromHex("ba03d891bb15408c940eea5ee3d54d419595102648d02774a0268d892add9c8e"),
+ common.FromHex("f3cd509064c8d319bbdd1c68f511850a902ad275e6ed5bea11547e23d492a926")},
+ {2000, common.FromHex("f7a184f20df01c94f09537401d11e68d97ad0c00115233107f51b9c287ce60c7"),
+ common.FromHex("ff795ea898ba1e4cfed4a33b4cf5535a347a02cf931f88d88719faf810f9a1c9")},
} {
addresses, accounts := makeAccounts(tc.count)
// This spongeDb is used to check the sequence of disk-db-writes
diff --git a/trie/zk_trie.go b/trie/zk_trie.go
index 044e18ad66ba..a98ae474ddff 100644
--- a/trie/zk_trie.go
+++ b/trie/zk_trie.go
@@ -233,3 +233,33 @@ func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueRead
return nil, fmt.Errorf("bad proof node %v", proof)
}
}
+
+// CountLeaves returns the number of leaf nodes reachable from the current
+// root of the underlying zktrie. It panics if the root cannot be resolved,
+// which indicates a corrupted or unavailable trie database.
+func (t *ZkTrie) CountLeaves() uint64 {
+	root, err := t.ZkTrie.Tree().Root()
+	if err != nil {
+		// Include the cause: a bare message makes the failure undiagnosable.
+		panic("CountLeaves cannot get root: " + err.Error())
+	}
+	return t.countLeaves(root)
+}
+
+// countLeaves recursively walks the subtree rooted at root and counts its
+// leaf nodes. A nil root hash denotes an empty subtree and contributes zero.
+func (t *ZkTrie) countLeaves(root *zkt.Hash) uint64 {
+	if root == nil {
+		return 0
+	}
+
+	node, err := t.ZkTrie.Tree().GetNode(root)
+	if err != nil {
+		// Include the cause: a bare message makes the failure undiagnosable.
+		panic("countLeaves cannot get node: " + err.Error())
+	}
+
+	// Leaves terminate the recursion; every other node type (inner or empty)
+	// is counted through its children. Empty/nil children contribute zero.
+	if node.Type == zktrie.NodeTypeLeaf_New {
+		return 1
+	}
+	return t.countLeaves(node.ChildL) + t.countLeaves(node.ChildR)
+}
+
+// Witness returns a set containing all trie nodes that have been accessed.
+//
+// Not implemented for ZkTrie: any call panics. Callers must not route zk
+// tries into witness-collecting code paths until this is implemented.
+func (t *ZkTrie) Witness() map[string]struct{} {
+	panic("not implemented")
+}
diff --git a/trie/zk_trie_test.go b/trie/zk_trie_test.go
index 6c23abc2764b..d1700a75899d 100644
--- a/trie/zk_trie_test.go
+++ b/trie/zk_trie_test.go
@@ -26,12 +26,17 @@ import (
+	"os"
 	"testing"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	zkt "github.com/scroll-tech/zktrie/types"
 	"github.com/scroll-tech/go-ethereum/common"
+	"github.com/scroll-tech/go-ethereum/core/rawdb"
+	"github.com/scroll-tech/go-ethereum/core/types"
 	"github.com/scroll-tech/go-ethereum/ethdb/leveldb"
 	"github.com/scroll-tech/go-ethereum/ethdb/memorydb"
+	"github.com/scroll-tech/go-ethereum/rlp"
 )
func newEmptyZkTrie() *ZkTrie {
@@ -264,3 +268,72 @@ func TestZkTrieDelete(t *testing.T) {
assert.Equal(t, hashes[i].Hex(), hash.Hex())
}
}
+
+// TestEquivalence cross-checks a zkTrie state database against an MPT state
+// database built from the same chain: every account and every storage slot
+// must carry identical values in both tries.
+//
+// It needs two pre-populated chaindata directories and a zk state root,
+// supplied via environment variables, and is skipped when they are absent —
+// this replaces the previously committed developer-local absolute paths.
+func TestEquivalence(t *testing.T) {
+	zkDbPath := os.Getenv("ZK_TRIE_CHAINDATA")
+	mptDbPath := os.Getenv("MPT_TRIE_CHAINDATA")
+	zkRootHex := os.Getenv("ZK_TRIE_ROOT")
+	if zkDbPath == "" || mptDbPath == "" || zkRootHex == "" {
+		t.Skip("set ZK_TRIE_CHAINDATA, MPT_TRIE_CHAINDATA and ZK_TRIE_ROOT to run")
+	}
+
+	zkDb, err := leveldb.New(zkDbPath, 0, 0, "", true)
+	require.NoError(t, err)
+	mptDb, err := leveldb.New(mptDbPath, 0, 0, "", true)
+	require.NoError(t, err)
+
+	// NOTE(review): the MPT root is looked up keyed by the zk root —
+	// presumably recorded during the zk→MPT migration; confirm with rawdb.
+	zkRoot := common.HexToHash(zkRootHex)
+	mptRoot, err := rawdb.ReadDiskStateRoot(mptDb, zkRoot)
+	require.NoError(t, err)
+
+	checkTrieEquality(t, &dbs{
+		zkDb:  zkDb,
+		mptDb: mptDb,
+	}, zkRoot, mptRoot, checkAccountEquality)
+}
+
+// dbs bundles the two chaindata key-value stores under comparison:
+// the zkTrie database and its corresponding MPT database.
+type dbs struct {
+	zkDb  *leveldb.Database
+	mptDb *leveldb.Database
+}
+
+// accountsLeft tracks how many top-level accounts remain to be verified
+// across recursive checkTrieEquality calls; -1 means "not yet initialized"
+// and is replaced by the account count on the first (account-level) call.
+var accountsLeft = -1
+
+func checkTrieEquality(t *testing.T, dbs *dbs, zkRoot, mptRoot common.Hash, leafChecker func(*testing.T, *dbs, []byte, []byte)) {
+ zkTrie, err := NewZkTrie(zkRoot, NewZktrieDatabase(dbs.zkDb))
+ require.NoError(t, err)
+
+ mptTrie, err := NewSecure(mptRoot, NewDatabaseWithConfig(dbs.mptDb, &Config{Preimages: true}))
+ require.NoError(t, err)
+
+ expectedLeaves := zkTrie.CountLeaves()
+ trieIt := NewIterator(mptTrie.NodeIterator(nil))
+ if accountsLeft == -1 {
+ accountsLeft = int(expectedLeaves)
+ }
+
+ for trieIt.Next() {
+ expectedLeaves--
+ preimageKey := mptTrie.GetKey(trieIt.Key)
+ require.NotEmpty(t, preimageKey)
+ leafChecker(t, dbs, zkTrie.Get(preimageKey), mptTrie.Get(preimageKey))
+ }
+ require.Zero(t, expectedLeaves)
+}
+
+// checkAccountEquality asserts that the zk and MPT encodings describe the
+// same account (nonce, balance, keccak code hash), then recurses into the
+// account's storage trie and updates the package-level progress counter.
+func checkAccountEquality(t *testing.T, dbs *dbs, zkAccountBytes, mptAccountBytes []byte) {
+	mptAccount := &types.StateAccount{}
+	require.NoError(t, rlp.DecodeBytes(mptAccountBytes, mptAccount))
+	zkAccount, err := types.UnmarshalStateAccount(zkAccountBytes)
+	require.NoError(t, err)
+
+	require.Equal(t, mptAccount.Nonce, zkAccount.Nonce)
+	// Compare via Cmp (numeric equality) and report both values on failure;
+	// require.True would only print "should be true" with no diagnostics.
+	require.Zero(t, mptAccount.Balance.Cmp(zkAccount.Balance),
+		"balance mismatch: mpt=%s zk=%s", mptAccount.Balance, zkAccount.Balance)
+	require.Equal(t, mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash)
+	checkTrieEquality(t, dbs, common.BytesToHash(zkAccount.Root[:]), common.BytesToHash(mptAccount.Root[:]), checkStorageEquality)
+	accountsLeft--
+	t.Log("Accounts left:", accountsLeft)
+}
+
+// checkStorageEquality asserts that a single storage slot holds the same
+// value in both tries. The zkTrie stores the raw 32-byte value while the
+// MPT stores it RLP-encoded, so the MPT side is decoded before comparing.
+func checkStorageEquality(t *testing.T, _ *dbs, zkStorageBytes, mptStorageBytes []byte) {
+	_, decoded, _, err := rlp.Split(mptStorageBytes)
+	require.NoError(t, err)
+	require.Equal(t, common.BytesToHash(zkStorageBytes), common.BytesToHash(decoded))
+}