diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 00d22179665d..a702ce44620a 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -150,6 +150,7 @@ var ( utils.ScrollAlphaFlag, utils.ScrollSepoliaFlag, utils.ScrollFlag, + utils.ScrollMPTFlag, utils.VMEnableDebugFlag, utils.NetworkIdFlag, utils.EthStatsURLFlag, diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index 26818ddfd14d..00111c9956f8 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -50,6 +50,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{ utils.ScrollAlphaFlag, utils.ScrollSepoliaFlag, utils.ScrollFlag, + utils.ScrollMPTFlag, utils.SyncModeFlag, utils.ExitWhenSyncedFlag, utils.GCModeFlag, diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index bccd6017b36e..0967f1000b6a 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -183,6 +183,10 @@ var ( Name: "scroll", Usage: "Scroll mainnet", } + ScrollMPTFlag = cli.BoolFlag{ + Name: "scroll-mpt", + Usage: "Use MPT trie for state storage", + } DeveloperFlag = cli.BoolFlag{ Name: "dev", Usage: "Ephemeral proof-of-authority network with a pre-funded developer account, mining enabled", @@ -1879,12 +1883,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) { stack.Config().L1Confirmations = rpc.FinalizedBlockNumber log.Info("Setting flag", "--l1.sync.startblock", "4038000") stack.Config().L1DeploymentBlock = 4038000 - // disable pruning - if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive { - log.Crit("Must use --gcmode=archive") + cfg.Genesis.Config.Scroll.UseZktrie = !ctx.GlobalBool(ScrollMPTFlag.Name) + if cfg.Genesis.Config.Scroll.UseZktrie { + // disable pruning + if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive { + log.Crit("Must use --gcmode=archive") + } + log.Info("Pruning disabled") + cfg.NoPruning = true } - log.Info("Pruning disabled") - cfg.NoPruning = true case ctx.GlobalBool(ScrollFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 534352 @@ -1895,12 +1902,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) { stack.Config().L1Confirmations = rpc.FinalizedBlockNumber log.Info("Setting flag", "--l1.sync.startblock", "18306000") stack.Config().L1DeploymentBlock = 18306000 - // disable pruning - if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive { - log.Crit("Must use --gcmode=archive") + cfg.Genesis.Config.Scroll.UseZktrie = !ctx.GlobalBool(ScrollMPTFlag.Name) + if cfg.Genesis.Config.Scroll.UseZktrie { + // disable pruning + if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive { + log.Crit("Must use --gcmode=archive") + } + log.Info("Pruning disabled") + cfg.NoPruning = true } - log.Info("Pruning disabled") - cfg.NoPruning = true case ctx.GlobalBool(DeveloperFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 1337 diff --git a/core/block_validator.go b/core/block_validator.go index fdc845682d96..9eb3accacdb4 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -226,7 +226,8 @@ func (v *BlockValidator) ValidateState(block *types.Block, statedb *state.StateD } // Validate the state root against the received state root and throw // an error if they don't match. - if root := statedb.IntermediateRoot(v.config.IsEIP158(header.Number)); header.Root != root { + shouldValidateStateRoot := v.config.Scroll.UseZktrie != v.config.IsEuclid(header.Time) + if root := statedb.IntermediateRoot(v.config.IsEIP158(header.Number)); shouldValidateStateRoot && header.Root != root { return fmt.Errorf("invalid merkle root (remote: %x local: %x)", header.Root, root) } return nil diff --git a/core/blockchain.go b/core/blockchain.go index a0bc05924531..57d82bc118a4 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -1318,6 +1318,9 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types. return NonStatTy, err } triedb := bc.stateCache.TrieDB() + if block.Root() != root { + rawdb.WriteDiskStateRoot(bc.db, block.Root(), root) + } // If we're running an archive node, always flush if bc.cacheConfig.TrieDirtyDisabled { @@ -1677,7 +1680,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er } // Enable prefetching to pull in trie node paths while processing transactions - statedb.StartPrefetcher("chain") + statedb.StartPrefetcher("chain", nil) activeState = statedb // If we have a followup block, run that against the current state to pre-cache @@ -1814,7 +1817,7 @@ func (bc *BlockChain) BuildAndWriteBlock(parentBlock *types.Block, header *types return NonStatTy, err } - statedb.StartPrefetcher("l1sync") + statedb.StartPrefetcher("l1sync", nil) defer statedb.StopPrefetcher() header.ParentHash = parentBlock.Hash() diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 13b75622b169..43169492333c 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -3032,15 +3032,16 @@ func TestPoseidonCodeHash(t *testing.T) { var callCreate2Code = common.Hex2Bytes("f4754f660000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000005c6080604052348015600f57600080fd5b50603f80601d6000396000f3fe6080604052600080fdfea2646970667358221220707985753fcb6578098bb16f3709cf6d012993cba6dd3712661cf8f57bbc0d4d64736f6c6343000807003300000000") var ( - key1, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") - addr1 = crypto.PubkeyToAddress(key1.PublicKey) - db = rawdb.NewMemoryDatabase() - gspec = &Genesis{Config: params.TestChainConfig, Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000000)}}} - genesis = gspec.MustCommit(db) - signer = types.LatestSigner(gspec.Config) - engine = ethash.NewFaker() - blockchain, _ = NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil, nil) + key1, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + addr1 = crypto.PubkeyToAddress(key1.PublicKey) + db = rawdb.NewMemoryDatabase() + gspec = &Genesis{Config: params.TestChainConfig, Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000000)}}} + signer = types.LatestSigner(gspec.Config) + engine = ethash.NewFaker() ) + gspec.Config.Scroll.UseZktrie = true + genesis := gspec.MustCommit(db) + blockchain, _ := NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil, nil) defer blockchain.Stop() @@ -3724,6 +3725,7 @@ func TestCurieTransition(t *testing.T) { config.CurieBlock = big.NewInt(2) config.DarwinTime = nil config.DarwinV2Time = nil + config.Scroll.UseZktrie = true var ( db = rawdb.NewMemoryDatabase() @@ -3748,7 +3750,7 @@ func TestCurieTransition(t *testing.T) { number := block.Number().Uint64() baseFee := block.BaseFee() - statedb, _ := state.New(block.Root(), state.NewDatabase(db), nil) + statedb, _ := state.New(block.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Zktrie: gspec.Config.Scroll.UseZktrie}), nil) code := statedb.GetCode(rcfg.L1GasPriceOracleAddress) codeSize := statedb.GetCodeSize(rcfg.L1GasPriceOracleAddress) diff --git a/core/chain_makers.go b/core/chain_makers.go index b79d92bf71e4..2b0a5856061b 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -29,6 +29,7 @@ import ( "github.com/scroll-tech/go-ethereum/ethdb" "github.com/scroll-tech/go-ethereum/params" "github.com/scroll-tech/go-ethereum/rollup/fees" + "github.com/scroll-tech/go-ethereum/trie" ) // BlockGen creates blocks for testing. @@ -264,7 +265,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse return nil, nil } for i := 0; i < n; i++ { - statedb, err := state.New(parent.Root(), state.NewDatabase(db), nil) + statedb, err := state.New(parent.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Zktrie: config.Scroll.ZktrieEnabled()}), nil) if err != nil { panic(err) } diff --git a/core/genesis.go b/core/genesis.go index 50a8e8843a4d..9554be0ac304 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -322,7 +322,10 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block { } statedb.Commit(false) statedb.Database().TrieDB().Commit(root, true, nil) - + if g.Config != nil && g.Config.Scroll.GenesisStateRoot != nil { + head.Root = *g.Config.Scroll.GenesisStateRoot + rawdb.WriteDiskStateRoot(db, head.Root, root) + } return types.NewBlock(head, nil, nil, nil, trie.NewStackTrie(nil)) } diff --git a/core/rawdb/accessors_state.go b/core/rawdb/accessors_state.go index f153af69f942..2738d9424668 100644 --- a/core/rawdb/accessors_state.go +++ b/core/rawdb/accessors_state.go @@ -94,3 +94,17 @@ func DeleteTrieNode(db ethdb.KeyValueWriter, hash common.Hash) { log.Crit("Failed to delete trie node", "err", err) } } + +func WriteDiskStateRoot(db ethdb.KeyValueWriter, headerRoot, diskRoot common.Hash) { + if err := db.Put(diskStateRootKey(headerRoot), diskRoot.Bytes()); err != nil { + log.Crit("Failed to store disk state root", "err", err) + } +} + +func ReadDiskStateRoot(db ethdb.KeyValueReader, headerRoot common.Hash) (common.Hash, error) { + data, err := db.Get(diskStateRootKey(headerRoot)) + if err != nil { + return common.Hash{}, err + } + return common.BytesToHash(data), nil +} diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go index b4a51935b4ff..47b29c77d840 100644 --- a/core/rawdb/schema.go +++ b/core/rawdb/schema.go @@ -127,6 +127,8 @@ var ( // Scroll da syncer store daSyncedL1BlockNumberKey = []byte("LastDASyncedL1BlockNumber") + + diskStateRootPrefix = []byte("disk-state-root") ) // Use the updated "L1" prefix on all new networks @@ -312,3 +314,7 @@ func batchMetaKey(batchIndex uint64) []byte { func committedBatchMetaKey(batchIndex uint64) []byte { return append(committedBatchMetaPrefix, encodeBigEndian(batchIndex)...) } + +func diskStateRootKey(headerRoot common.Hash) []byte { + return append(diskStateRootPrefix, headerRoot.Bytes()...) +} diff --git a/core/state/database.go b/core/state/database.go index bb73fcecd216..9a58bc72246f 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -105,6 +105,9 @@ type Trie interface { // nodes of the longest existing prefix of the key (at least the root), ending // with the node that proves the absence of the key. Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error + + // Witness returns a set containing all trie nodes that have been accessed. + Witness() map[string]struct{} } // NewDatabase creates a backing store for state. The returned database is safe for @@ -136,6 +139,9 @@ type cachingDB struct { // OpenTrie opens the main account trie at a specific root hash. func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { + if diskRoot, err := rawdb.ReadDiskStateRoot(db.db.DiskDB(), root); err == nil { + root = diskRoot + } if db.zktrie { tr, err := trie.NewZkTrie(root, trie.NewZktrieDatabaseFromTriedb(db.db)) if err != nil { diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go index e5e2b420018a..72e6d134d59b 100644 --- a/core/state/snapshot/generate.go +++ b/core/state/snapshot/generate.go @@ -618,8 +618,8 @@ func (dl *diskLayer) generate(stats *generatorStats) { Balance *big.Int Root common.Hash KeccakCodeHash []byte - PoseidonCodeHash []byte - CodeSize uint64 + PoseidonCodeHash []byte `rlp:"-"` + CodeSize uint64 `rlp:"-"` } if err := rlp.DecodeBytes(val, &acc); err != nil { log.Crit("Invalid account encountered during snapshot creation", "err", err) diff --git a/core/state/state_object.go b/core/state/state_object.go index f9213a0a31d7..4fb9e82c2ed7 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -500,8 +500,18 @@ func (s *stateObject) Code(db Database) []byte { // CodeSize returns the size of the contract code associated with this object, // or zero if none. This method is an almost mirror of Code, but uses a cache // inside the database to avoid loading codes seen recently. -func (s *stateObject) CodeSize() uint64 { - return s.data.CodeSize +func (s *stateObject) CodeSize(db Database) uint64 { + if s.code != nil { + return uint64(len(s.code)) + } + if bytes.Equal(s.KeccakCodeHash(), emptyKeccakCodeHash) { + return 0 + } + size, err := db.ContractCodeSize(s.addrHash, common.BytesToHash(s.KeccakCodeHash())) + if err != nil { + s.setError(fmt.Errorf("can't load code size %x: %v", s.KeccakCodeHash(), err)) + } + return uint64(size) } func (s *stateObject) SetCode(code []byte) { @@ -534,6 +544,9 @@ func (s *stateObject) setNonce(nonce uint64) { } func (s *stateObject) PoseidonCodeHash() []byte { + if !s.db.IsZktrie() { + panic("PoseidonCodeHash is only available in zktrie mode") + } return s.data.PoseidonCodeHash } diff --git a/core/state/state_test.go b/core/state/state_test.go index ea98b2dab833..8b53cb2eca2a 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -155,7 +155,8 @@ func TestSnapshotEmpty(t *testing.T) { } func TestSnapshot2(t *testing.T) { - state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + stateDb := NewDatabase(rawdb.NewMemoryDatabase()) + state, _ := New(common.Hash{}, stateDb, nil) stateobjaddr0 := common.BytesToAddress([]byte("so0")) stateobjaddr1 := common.BytesToAddress([]byte("so1")) @@ -201,7 +202,7 @@ func TestSnapshot2(t *testing.T) { so0Restored.GetState(state.db, storageaddr) so0Restored.Code(state.db) // non-deleted is equal (restored) - compareStateObjects(so0Restored, so0, t) + compareStateObjects(so0Restored, so0, stateDb, t) // deleted should be nil, both before and after restore of state copy so1Restored := state.getStateObject(stateobjaddr1) @@ -210,7 +211,7 @@ func TestSnapshot2(t *testing.T) { } } -func compareStateObjects(so0, so1 *stateObject, t *testing.T) { +func compareStateObjects(so0, so1 *stateObject, db Database, t *testing.T) { if so0.Address() != so1.Address() { t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address) } @@ -229,8 +230,8 @@ func compareStateObjects(so0, so1 *stateObject, t *testing.T) { if !bytes.Equal(so0.PoseidonCodeHash(), so1.PoseidonCodeHash()) { t.Fatalf("PoseidonCodeHash mismatch: have %v, want %v", so0.PoseidonCodeHash(), so1.PoseidonCodeHash()) } - if so0.CodeSize() != so1.CodeSize() { - t.Fatalf("CodeSize mismatch: have %v, want %v", so0.CodeSize(), so1.CodeSize()) + if so0.CodeSize(db) != so1.CodeSize(db) { + t.Fatalf("CodeSize mismatch: have %v, want %v", so0.CodeSize(db), so1.CodeSize(db)) } if !bytes.Equal(so0.code, so1.code) { t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code) diff --git a/core/state/statedb.go b/core/state/statedb.go index 6629a50eae57..7affd81fc409 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -29,6 +29,7 @@ import ( "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/core/rawdb" "github.com/scroll-tech/go-ethereum/core/state/snapshot" + "github.com/scroll-tech/go-ethereum/core/stateless" "github.com/scroll-tech/go-ethereum/core/types" "github.com/scroll-tech/go-ethereum/crypto" "github.com/scroll-tech/go-ethereum/log" @@ -106,6 +107,9 @@ type StateDB struct { validRevisions []revision nextRevisionId int + // State witness if cross validation is needed + witness *stateless.Witness + // Measurements gathered during execution for debugging purposes AccountReads time.Duration AccountHashes time.Duration @@ -159,11 +163,15 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) // StartPrefetcher initializes a new trie prefetcher to pull in nodes from the // state trie concurrently while the state is mutated so that when we reach the // commit phase, most of the needed data is already hot. -func (s *StateDB) StartPrefetcher(namespace string) { +func (s *StateDB) StartPrefetcher(namespace string, witness *stateless.Witness) { if s.prefetcher != nil { s.prefetcher.close() s.prefetcher = nil } + + // Enable witness collection if requested + s.witness = witness + if s.snap != nil { s.prefetcher = newTriePrefetcher(s.db, s.originalRoot, namespace) } @@ -289,6 +297,9 @@ func (s *StateDB) TxIndex() int { func (s *StateDB) GetCode(addr common.Address) []byte { stateObject := s.getStateObject(addr) if stateObject != nil { + if s.witness != nil { + s.witness.AddCode(stateObject.Code(s.db)) + } return stateObject.Code(s.db) } return nil @@ -297,7 +308,10 @@ func (s *StateDB) GetCode(addr common.Address) []byte { func (s *StateDB) GetCodeSize(addr common.Address) uint64 { stateObject := s.getStateObject(addr) if stateObject != nil { - return stateObject.CodeSize() + if s.witness != nil { + s.witness.AddCode(stateObject.Code(s.db)) + } + return stateObject.CodeSize(s.db) } return 0 } @@ -725,6 +739,9 @@ func (s *StateDB) Copy() *StateDB { journal: newJournal(), hasher: crypto.NewKeccakState(), } + if s.witness != nil { + state.witness = s.witness.Copy() + } // Copy the dirty states, logs, and preimages for addr := range s.journal.dirties { // As documented [here](https://github.com/scroll-tech/go-ethereum/pull/16485#issuecomment-380438527), @@ -913,7 +930,33 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { // to pull useful data from disk. for addr := range s.stateObjectsPending { if obj := s.stateObjects[addr]; !obj.deleted { + + // If witness building is enabled and the state object has a trie, + // gather the witnesses for its specific storage trie + if s.witness != nil && obj.trie != nil { + s.witness.AddState(obj.trie.Witness()) + } + obj.updateRoot(s.db) + + // If witness building is enabled and the state object has a trie, + // gather the witnesses for its specific storage trie + if s.witness != nil && obj.trie != nil { + s.witness.AddState(obj.trie.Witness()) + } + } + } + + if s.witness != nil { + // If witness building is enabled, gather the account trie witness for read-only operations + for _, obj := range s.stateObjects { + if len(obj.originStorage) == 0 { + continue + } + + if trie := obj.getTrie(s.db); trie != nil { + s.witness.AddState(trie.Witness()) + } } } // Now we're about to start to write changes to the trie. The trie is so far @@ -945,7 +988,13 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { if metrics.EnabledExpensive { defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now()) } - return s.trie.Hash() + + hash := s.trie.Hash() + // If witness building is enabled, gather the account trie witness + if s.witness != nil { + s.witness.AddState(s.trie.Witness()) + } + return hash } // SetTxContext sets the current transaction hash and index which are diff --git a/core/stateless/database.go b/core/stateless/database.go new file mode 100644 index 000000000000..e6278a98f872 --- /dev/null +++ b/core/stateless/database.go @@ -0,0 +1,67 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package stateless + +import ( + "github.com/scroll-tech/go-ethereum/common" + "github.com/scroll-tech/go-ethereum/core/rawdb" + "github.com/scroll-tech/go-ethereum/crypto" + "github.com/scroll-tech/go-ethereum/ethdb" +) + +// MakeHashDB imports tries, codes and block hashes from a witness into a new +// hash-based memory db. We could eventually rewrite this into a pathdb, but +// simple is better for now. +// +// Note, this hashdb approach is quite strictly self-validating: +// - Headers are persisted keyed by hash, so blockhash will error on junk +// - Codes are persisted keyed by hash, so bytecode lookup will error on junk +// - Trie nodes are persisted keyed by hash, so trie expansion will error on junk +// +// Acceleration structures built would need to explicitly validate the witness. +func (w *Witness) MakeHashDB() ethdb.Database { + var ( + memdb = rawdb.NewMemoryDatabase() + hasher = crypto.NewKeccakState() + hash = make([]byte, 32) + ) + // Inject all the "block hashes" (i.e. headers) into the ephemeral database + for _, header := range w.Headers { + rawdb.WriteHeader(memdb, header) + } + // Inject all the bytecodes into the ephemeral database + for code := range w.Codes { + blob := []byte(code) + + hasher.Reset() + hasher.Write(blob) + hasher.Read(hash) + + rawdb.WriteCode(memdb, common.BytesToHash(hash), blob) + } + // Inject all the MPT trie nodes into the ephemeral database + for node := range w.State { + blob := []byte(node) + + hasher.Reset() + hasher.Write(blob) + hasher.Read(hash) + + rawdb.WriteTrieNode(memdb, common.BytesToHash(hash), blob) + } + return memdb +} diff --git a/core/stateless/encoding.go b/core/stateless/encoding.go new file mode 100644 index 000000000000..b67b7460924a --- /dev/null +++ b/core/stateless/encoding.go @@ -0,0 +1,76 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package stateless + +import ( + "io" + + "github.com/scroll-tech/go-ethereum/core/types" + "github.com/scroll-tech/go-ethereum/rlp" +) + +// toExtWitness converts our internal witness representation to the consensus one. +func (w *Witness) toExtWitness() *extWitness { + ext := &extWitness{ + Headers: w.Headers, + } + ext.Codes = make([][]byte, 0, len(w.Codes)) + for code := range w.Codes { + ext.Codes = append(ext.Codes, []byte(code)) + } + ext.State = make([][]byte, 0, len(w.State)) + for node := range w.State { + ext.State = append(ext.State, []byte(node)) + } + return ext +} + +// fromExtWitness converts the consensus witness format into our internal one. +func (w *Witness) fromExtWitness(ext *extWitness) error { + w.Headers = ext.Headers + + w.Codes = make(map[string]struct{}, len(ext.Codes)) + for _, code := range ext.Codes { + w.Codes[string(code)] = struct{}{} + } + w.State = make(map[string]struct{}, len(ext.State)) + for _, node := range ext.State { + w.State[string(node)] = struct{}{} + } + return nil +} + +// EncodeRLP serializes a witness as RLP. +func (w *Witness) EncodeRLP(wr io.Writer) error { + return rlp.Encode(wr, w.toExtWitness()) +} + +// DecodeRLP decodes a witness from RLP. +func (w *Witness) DecodeRLP(s *rlp.Stream) error { + var ext extWitness + if err := s.Decode(&ext); err != nil { + return err + } + return w.fromExtWitness(&ext) +} + +// extWitness is a witness RLP encoding for transferring across clients. +type extWitness struct { + Headers []*types.Header + Codes [][]byte + State [][]byte +} diff --git a/core/stateless/witness.go b/core/stateless/witness.go new file mode 100644 index 000000000000..10e4b08ca1cb --- /dev/null +++ b/core/stateless/witness.go @@ -0,0 +1,122 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package stateless + +import ( + "errors" + "maps" + "slices" + "sync" + + "github.com/scroll-tech/go-ethereum/common" + "github.com/scroll-tech/go-ethereum/core/types" +) + +// HeaderReader is an interface to pull in headers in place of block hashes for +// the witness. +type HeaderReader interface { + // GetHeader retrieves a block header from the database by hash and number, + GetHeader(hash common.Hash, number uint64) *types.Header +} + +// Witness encompasses the state required to apply a set of transactions and +// derive a post state/receipt root. +type Witness struct { + context *types.Header // Header to which this witness belongs to, with rootHash and receiptHash zeroed out + + Headers []*types.Header // Past headers in reverse order (0=parent, 1=parent's-parent, etc). First *must* be set. + Codes map[string]struct{} // Set of bytecodes ran or accessed + State map[string]struct{} // Set of MPT state trie nodes (account and storage together) + + chain HeaderReader // Chain reader to convert block hash ops to header proofs + lock sync.Mutex // Lock to allow concurrent state insertions +} + +// NewWitness creates an empty witness ready for population. +func NewWitness(context *types.Header, chain HeaderReader) (*Witness, error) { + // When building witnesses, retrieve the parent header, which will *always* + // be included to act as a trustless pre-root hash container + var headers []*types.Header + if chain != nil { + parent := chain.GetHeader(context.ParentHash, context.Number.Uint64()-1) + if parent == nil { + return nil, errors.New("failed to retrieve parent header") + } + headers = append(headers, parent) + } + // Create the wtness with a reconstructed gutted out block + return &Witness{ + context: context, + Headers: headers, + Codes: make(map[string]struct{}), + State: make(map[string]struct{}), + chain: chain, + }, nil +} + +// AddBlockHash adds a "blockhash" to the witness with the designated offset from +// chain head. Under the hood, this method actually pulls in enough headers from +// the chain to cover the block being added. +func (w *Witness) AddBlockHash(number uint64) { + // Keep pulling in headers until this hash is populated + for int(w.context.Number.Uint64()-number) > len(w.Headers) { + tail := w.Headers[len(w.Headers)-1] + w.Headers = append(w.Headers, w.chain.GetHeader(tail.ParentHash, tail.Number.Uint64()-1)) + } +} + +// AddCode adds a bytecode blob to the witness. +func (w *Witness) AddCode(code []byte) { + if len(code) == 0 { + return + } + w.Codes[string(code)] = struct{}{} +} + +// AddState inserts a batch of MPT trie nodes into the witness. +func (w *Witness) AddState(nodes map[string]struct{}) { + if len(nodes) == 0 { + return + } + w.lock.Lock() + defer w.lock.Unlock() + + maps.Copy(w.State, nodes) +} + +// Copy deep-copies the witness object. Witness.Block isn't deep-copied as it +// is never mutated by Witness +func (w *Witness) Copy() *Witness { + cpy := &Witness{ + Headers: slices.Clone(w.Headers), + Codes: maps.Clone(w.Codes), + State: maps.Clone(w.State), + chain: w.chain, + } + if w.context != nil { + cpy.context = types.CopyHeader(w.context) + } + return cpy +} + +// Root returns the pre-state root from the first header. +// +// Note, this method will panic in case of a bad witness (but RLP decoding will +// sanitize it and fail before that). +func (w *Witness) Root() common.Hash { + return w.Headers[0].Root +} diff --git a/core/types/state_account.go b/core/types/state_account.go index bb396d439d9e..bde0331d00b3 100644 --- a/core/types/state_account.go +++ b/core/types/state_account.go @@ -31,6 +31,6 @@ type StateAccount struct { KeccakCodeHash []byte // StateAccount Scroll extensions - PoseidonCodeHash []byte - CodeSize uint64 + PoseidonCodeHash []byte `rlp:"-"` + CodeSize uint64 `rlp:"-"` } diff --git a/eth/api.go b/eth/api.go index 74672347823c..07f5874037c8 100644 --- a/eth/api.go +++ b/eth/api.go @@ -34,11 +34,14 @@ import ( "github.com/scroll-tech/go-ethereum/core" "github.com/scroll-tech/go-ethereum/core/rawdb" "github.com/scroll-tech/go-ethereum/core/state" + "github.com/scroll-tech/go-ethereum/core/stateless" "github.com/scroll-tech/go-ethereum/core/types" + "github.com/scroll-tech/go-ethereum/crypto" "github.com/scroll-tech/go-ethereum/internal/ethapi" "github.com/scroll-tech/go-ethereum/log" "github.com/scroll-tech/go-ethereum/rlp" "github.com/scroll-tech/go-ethereum/rollup/ccc" + "github.com/scroll-tech/go-ethereum/rollup/rcfg" "github.com/scroll-tech/go-ethereum/rpc" "github.com/scroll-tech/go-ethereum/trie" ) @@ -321,6 +324,109 @@ func (api *PublicDebugAPI) DumpBlock(blockNr rpc.BlockNumber) (state.Dump, error return stateDb.RawDump(opts), nil } +func (api *PublicDebugAPI) ExecutionWitness(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*ExecutionWitness, error) { + block, err := api.eth.APIBackend.BlockByNumberOrHash(ctx, blockNrOrHash) + if err != nil { + return nil, fmt.Errorf("failed to retrieve block: %w", err) + } + if block == nil { + return nil, fmt.Errorf("block not found: %s", blockNrOrHash.String()) + } + + witness, err := generateWitness(api.eth.blockchain, block) + return ToExecutionWitness(witness), err +} + +func generateWitness(blockchain *core.BlockChain, block *types.Block) (*stateless.Witness, error) { + witness, err := stateless.NewWitness(block.Header(), blockchain) + if err != nil { + return nil, fmt.Errorf("failed to create witness: %w", err) + } + + parentHeader := witness.Headers[0] + statedb, err := blockchain.StateAt(parentHeader.Root) + if err != nil { + return nil, fmt.Errorf("failed to retrieve parent state: %w", err) + } + + // Collect storage locations that prover needs but sequencer might not touch necessarily + statedb.GetState(rcfg.L2MessageQueueAddress, rcfg.WithdrawTrieRootSlot) + + statedb.StartPrefetcher("debug_execution_witness", witness) + defer statedb.StopPrefetcher() + + receipts, _, usedGas, err := blockchain.Processor().Process(block, statedb, *blockchain.GetVMConfig()) + if err != nil { + return nil, fmt.Errorf("failed to process block %d: %w", block.Number(), err) + } + + if err := blockchain.Validator().ValidateState(block, statedb, receipts, usedGas); err != nil { + return nil, fmt.Errorf("failed to validate block %d: %w", block.Number(), err) + } + return witness, testWitness(blockchain, block, witness) +} + +func testWitness(blockchain *core.BlockChain, block *types.Block, witness *stateless.Witness) error { + stateRoot := witness.Root() + if diskRoot, _ := rawdb.ReadDiskStateRoot(blockchain.Database(), stateRoot); diskRoot != (common.Hash{}) { + stateRoot = diskRoot + } + + // Create and populate the state database to serve as the stateless backend + statedb, err := state.New(stateRoot, state.NewDatabase(witness.MakeHashDB()), nil) + if err != nil { + return fmt.Errorf("failed to create state database: %w", err) + } + + receipts, _, usedGas, err := blockchain.Processor().Process(block, statedb, *blockchain.GetVMConfig()) + if err != nil { + return fmt.Errorf("failed to process block %d: %w", block.Number(), err) + } + + if err := blockchain.Validator().ValidateState(block, statedb, receipts, usedGas); err != nil { + return fmt.Errorf("failed to validate block %d: %w", block.Number(), err) + } + + postStateRoot := block.Root() + if diskRoot, _ := rawdb.ReadDiskStateRoot(blockchain.Database(), postStateRoot); diskRoot != (common.Hash{}) { + postStateRoot = diskRoot + } + if statedb.GetRootHash() != postStateRoot { + return fmt.Errorf("failed to commit statelessly %d: %w", block.Number(), err) + } + return nil +} + +// ExecutionWitness is a witness json encoding for transferring across the network. +// In the future, we'll probably consider using the extWitness format instead for less overhead if performance becomes an issue. +// Currently using this format for ease of reading, parsing and compatibility across clients. +type ExecutionWitness struct { + Headers []*types.Header `json:"headers"` + Codes map[string]string `json:"codes"` + State map[string]string `json:"state"` +} + +func transformMap(in map[string]struct{}) map[string]string { + out := make(map[string]string, len(in)) + for item := range in { + bytes := []byte(item) + key := crypto.Keccak256Hash(bytes).Hex() + out[key] = hexutil.Encode(bytes) + } + return out +} + +// ToExecutionWitness converts a witness to an execution witness format that is compatible with reth. +// keccak(node) => node +// keccak(bytecodes) => bytecodes +func ToExecutionWitness(w *stateless.Witness) *ExecutionWitness { + return &ExecutionWitness{ + Headers: w.Headers, + Codes: transformMap(w.Codes), + State: transformMap(w.State), + } +} + // PrivateDebugAPI is the collection of Ethereum full node APIs exposed over // the private debugging endpoint. type PrivateDebugAPI struct { @@ -859,3 +965,30 @@ func (api *ScrollAPI) CalculateRowConsumptionByBlockNumber(ctx context.Context, asyncChecker.Wait() return rawdb.ReadBlockRowConsumption(api.eth.ChainDb(), block.Hash()), checkErr } + +type DiskAndHeaderRoot struct { + DiskRoot common.Hash `json:"diskRoot"` + HeaderRoot common.Hash `json:"headerRoot"` +} + +// CalculateRowConsumptionByBlockNumber +func (api *ScrollAPI) DiskRoot(ctx context.Context, blockNrOrHash *rpc.BlockNumberOrHash) (DiskAndHeaderRoot, error) { + block, err := api.eth.APIBackend.BlockByNumberOrHash(ctx, *blockNrOrHash) + if err != nil { + return DiskAndHeaderRoot{}, fmt.Errorf("failed to retrieve block: %w", err) + } + if block == nil { + return DiskAndHeaderRoot{}, fmt.Errorf("block not found: %s", blockNrOrHash.String()) + } + + if diskRoot, _ := rawdb.ReadDiskStateRoot(api.eth.ChainDb(), block.Root()); diskRoot != (common.Hash{}) { + return DiskAndHeaderRoot{ + DiskRoot: diskRoot, + HeaderRoot: block.Root(), + }, nil + } + return DiskAndHeaderRoot{ + DiskRoot: block.Root(), + HeaderRoot: block.Root(), + }, nil +} diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index d43e01467ee3..8de02ae6f57e 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -482,6 +482,13 @@ web3._extend({ params: 2, inputFormatter:[web3._extend.formatters.inputBlockNumberFormatter, web3._extend.formatters.inputBlockNumberFormatter], }), + new web3._extend.Method({ + name: 'executionWitness', + call: 'debug_executionWitness', + params: 1, + inputFormatter: [null] + }), + ], properties: [] }); @@ -942,6 +949,13 @@ web3._extend({ params: 1, inputFormatter: [web3._extend.formatters.inputBlockNumberFormatter] }), + new web3._extend.Method({ + name: 'diskRoot', + call: 'scroll_diskRoot', + params: 1, + inputFormatter: [web3._extend.formatters.inputDefaultBlockNumberFormatter], + }), + ], properties: [ diff --git a/light/trie.go b/light/trie.go index 1947c314090b..148814f3fbc5 100644 --- a/light/trie.go +++ b/light/trie.go @@ -181,6 +181,11 @@ func (t *odrTrie) do(key []byte, fn func() error) error { } } +// Witness returns a set containing all trie nodes that have been accessed. +func (t *odrTrie) Witness() map[string]struct{} { + panic("not implemented") +} + type nodeIterator struct { trie.NodeIterator t *odrTrie diff --git a/miner/scroll_worker.go b/miner/scroll_worker.go index e152878d40e6..a6aa7c21c207 100644 --- a/miner/scroll_worker.go +++ b/miner/scroll_worker.go @@ -372,7 +372,7 @@ func (w *worker) mainLoop() { select { case <-w.startCh: idleTimer.UpdateSince(idleStart) - if w.isRunning() { + if w.isRunning() && w.chainConfig.Scroll.UseZktrie { if err := w.checkHeadRowConsumption(); err != nil { log.Error("failed to start head checkers", "err", err) return @@ -490,9 +490,10 @@ func (w *worker) newWork(now time.Time, parentHash common.Hash, reorging bool, r vmConfig := *w.chain.GetVMConfig() cccLogger := ccc.NewLogger() - vmConfig.Debug = true - vmConfig.Tracer = cccLogger - + if w.chainConfig.Scroll.UseZktrie { + vmConfig.Debug = true + vmConfig.Tracer = cccLogger + } deadline := time.Unix(int64(header.Time), 0) if w.chainConfig.Clique != nil && w.chainConfig.Clique.RelaxedPeriod { // clique with relaxed period uses time.Now() as the header.Time, calculate the deadline @@ -566,6 +567,11 @@ func (w *worker) handleForks() (bool, error) { misc.ApplyCurieHardFork(w.current.state) return true, nil } + + if w.chainConfig.IsEuclid(w.current.header.Time) { + parent := w.chain.GetBlockByHash(w.current.header.ParentHash) + return parent != nil && !w.chainConfig.IsEuclid(parent.Time()), nil + } return false, nil } @@ -809,7 +815,10 @@ func (w *worker) commit() (common.Hash, error) { }(time.Now()) w.updateSnapshot() - if !w.isRunning() && !w.current.reorging { + // Since clocks of mpt-sequencer and zktrie-sequencer can be slightly out of sync, + // this might result in a reorg at the Euclid fork block. But it will be resolved shortly after. + canCommitState := w.chainConfig.Scroll.UseZktrie != w.chainConfig.IsEuclid(w.current.header.Time) + if !canCommitState || (!w.isRunning() && !w.current.reorging) { return common.Hash{}, nil } @@ -886,7 +895,7 @@ func (w *worker) commit() (common.Hash, error) { currentHeight := w.current.header.Number.Uint64() maxReorgDepth := uint64(w.config.CCCMaxWorkers + 1) - if !w.current.reorging && currentHeight > maxReorgDepth { + if w.chainConfig.Scroll.UseZktrie && !w.current.reorging && currentHeight > maxReorgDepth { ancestorHeight := currentHeight - maxReorgDepth ancestorHash := w.chain.GetHeaderByNumber(ancestorHeight).Hash() if rawdb.ReadBlockRowConsumption(w.chain.Database(), ancestorHash) == nil { @@ -914,8 +923,10 @@ func (w *worker) commit() (common.Hash, error) { w.mux.Post(core.NewMinedBlockEvent{Block: block}) checkStart := time.Now() - if err = w.asyncChecker.Check(block); err != nil { - log.Error("failed to launch CCC background task", "err", err) + if w.chainConfig.Scroll.UseZktrie { + if err = w.asyncChecker.Check(block); err != nil { + log.Error("failed to launch CCC background task", "err", err) + } } cccStallTimer.UpdateSince(checkStart) diff --git a/miner/scroll_worker_test.go b/miner/scroll_worker_test.go index 5f79902f0e15..3e8468348f53 100644 --- a/miner/scroll_worker_test.go +++ b/miner/scroll_worker_test.go @@ -232,6 +232,7 @@ func testGenerateBlockAndImport(t *testing.T, isClique bool) { engine = ethash.NewFaker() } chainConfig.Scroll.FeeVaultAddress = &common.Address{} + chainConfig.Scroll.UseZktrie = true chainConfig.LondonBlock = big.NewInt(0) w, b := newTestWorker(t, chainConfig, engine, db, 0) @@ -296,6 +297,7 @@ func testGenerateBlockWithL1Msg(t *testing.T, isClique bool) { NumL1MessagesPerBlock: 1, } chainConfig.Scroll.FeeVaultAddress = &common.Address{} + chainConfig.Scroll.UseZktrie = true chainConfig.LondonBlock = big.NewInt(0) w, b := newTestWorker(t, chainConfig, engine, db, 0) @@ -344,6 +346,7 @@ func TestAcceptableTxlimit(t *testing.T) { chainConfig = params.AllCliqueProtocolChanges chainConfig.Clique = ¶ms.CliqueConfig{Period: 1, Epoch: 30000} chainConfig.Scroll.FeeVaultAddress = &common.Address{} + chainConfig.Scroll.UseZktrie = true engine = clique.New(chainConfig.Clique, db) // Set maxTxPerBlock = 4, which >= non-l1msg + non-skipped l1msg txs @@ -404,6 +407,7 @@ func TestUnacceptableTxlimit(t *testing.T) { chainConfig = params.AllCliqueProtocolChanges chainConfig.Clique = ¶ms.CliqueConfig{Period: 1, Epoch: 30000} chainConfig.Scroll.FeeVaultAddress = &common.Address{} + chainConfig.Scroll.UseZktrie = true engine = clique.New(chainConfig.Clique, db) // Set maxTxPerBlock = 3, which < non-l1msg + l1msg txs @@ -463,6 +467,7 @@ func TestL1MsgCorrectOrder(t *testing.T) { chainConfig = params.AllCliqueProtocolChanges chainConfig.Clique = ¶ms.CliqueConfig{Period: 1, Epoch: 30000} chainConfig.Scroll.FeeVaultAddress = &common.Address{} + chainConfig.Scroll.UseZktrie = true engine = clique.New(chainConfig.Clique, db) maxTxPerBlock := 4 @@ -525,6 +530,7 @@ func l1MessageTest(t *testing.T, msgs []types.L1MessageTx, withL2Tx bool, callba chainConfig = params.AllCliqueProtocolChanges chainConfig.Clique = ¶ms.CliqueConfig{Period: 1, Epoch: 30000} + chainConfig.Scroll.UseZktrie = true engine = clique.New(chainConfig.Clique, db) maxTxPerBlock := 4 chainConfig.Scroll.MaxTxPerBlock = &maxTxPerBlock @@ -879,6 +885,7 @@ func TestPrioritizeOverflowTx(t *testing.T) { chainConfig.Clique = ¶ms.CliqueConfig{Period: 1, Epoch: 30000} chainConfig.LondonBlock = big.NewInt(0) chainConfig.Scroll.FeeVaultAddress = &common.Address{} + chainConfig.Scroll.UseZktrie = true engine := clique.New(chainConfig.Clique, db) w, b := newTestWorker(t, chainConfig, engine, db, 0) @@ -1036,6 +1043,7 @@ func TestPending(t *testing.T) { chainConfig = params.AllCliqueProtocolChanges chainConfig.Clique = ¶ms.CliqueConfig{Period: 1, Epoch: 30000} chainConfig.Scroll.FeeVaultAddress = &common.Address{} + chainConfig.Scroll.UseZktrie = true engine = clique.New(chainConfig.Clique, db) w, b := newTestWorker(t, chainConfig, engine, db, 0) defer w.close() @@ -1080,6 +1088,7 @@ func TestReorg(t *testing.T) { chainConfig = params.AllCliqueProtocolChanges chainConfig.Clique = ¶ms.CliqueConfig{Period: 1, Epoch: 30000, RelaxedPeriod: true} chainConfig.Scroll.FeeVaultAddress = &common.Address{} + chainConfig.Scroll.UseZktrie = true engine = clique.New(chainConfig.Clique, db) maxTxPerBlock := 2 @@ -1194,6 +1203,7 @@ func TestRestartHeadCCC(t *testing.T) { chainConfig = params.AllCliqueProtocolChanges chainConfig.Clique = ¶ms.CliqueConfig{Period: 1, Epoch: 30000, RelaxedPeriod: true} chainConfig.Scroll.FeeVaultAddress = &common.Address{} + chainConfig.Scroll.UseZktrie = true engine = clique.New(chainConfig.Clique, db) maxTxPerBlock := 2 diff --git a/params/config.go b/params/config.go index eb2213fdd360..dc14a8b0769c 100644 --- a/params/config.go +++ b/params/config.go @@ -29,14 +29,16 @@ import ( // Genesis hashes to enforce below configs on. var ( - MainnetGenesisHash = common.HexToHash("0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3") - RopstenGenesisHash = common.HexToHash("0x41941023680923e0fe4d74a34bdac8141f2540e3ae90623718e47d66d1ca4a2d") - SepoliaGenesisHash = common.HexToHash("0x25a5cc106eea7138acab33231d7160d69cb777ee0c2c553fcddf5138993e6dd9") - RinkebyGenesisHash = common.HexToHash("0x6341fd3daf94b748c72ced5a5b26028f2474f5f00d824504e4fa37a75767e177") - GoerliGenesisHash = common.HexToHash("0xbf7e331f7f7c1dd2e05159666b3bf8bc7a8a3a9eb1d518969eab529dd9b88c1a") - ScrollAlphaGenesisHash = common.HexToHash("0xa4fc62b9b0643e345bdcebe457b3ae898bef59c7203c3db269200055e037afda") - ScrollSepoliaGenesisHash = common.HexToHash("0xaa62d1a8b2bffa9e5d2368b63aae0d98d54928bd713125e3fd9e5c896c68592c") - ScrollMainnetGenesisHash = common.HexToHash("0xbbc05efd412b7cd47a2ed0e5ddfcf87af251e414ea4c801d78b6784513180a80") + MainnetGenesisHash = common.HexToHash("0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3") + RopstenGenesisHash = common.HexToHash("0x41941023680923e0fe4d74a34bdac8141f2540e3ae90623718e47d66d1ca4a2d") + SepoliaGenesisHash = common.HexToHash("0x25a5cc106eea7138acab33231d7160d69cb777ee0c2c553fcddf5138993e6dd9") + RinkebyGenesisHash = common.HexToHash("0x6341fd3daf94b748c72ced5a5b26028f2474f5f00d824504e4fa37a75767e177") + GoerliGenesisHash = common.HexToHash("0xbf7e331f7f7c1dd2e05159666b3bf8bc7a8a3a9eb1d518969eab529dd9b88c1a") + ScrollAlphaGenesisHash = common.HexToHash("0xa4fc62b9b0643e345bdcebe457b3ae898bef59c7203c3db269200055e037afda") + ScrollSepoliaGenesisHash = common.HexToHash("0xaa62d1a8b2bffa9e5d2368b63aae0d98d54928bd713125e3fd9e5c896c68592c") + ScrollMainnetGenesisHash = common.HexToHash("0xbbc05efd412b7cd47a2ed0e5ddfcf87af251e414ea4c801d78b6784513180a80") + ScrollSepoliaGenesisState = common.HexToHash("0x20695989e9038823e35f0e88fbc44659ffdbfa1fe89fbeb2689b43f15fa64cb5") + ScrollMainnetGenesisState = common.HexToHash("0x08d535cc60f40af5dd3b31e0998d7567c2d568b224bed2ba26070aeb078d1339") ) func newUint64(val uint64) *uint64 { return &val } @@ -340,6 +342,7 @@ var ( NumL1MessagesPerBlock: 10, ScrollChainAddress: common.HexToAddress("0x2D567EcE699Eabe5afCd141eDB7A4f2D0D6ce8a0"), }, + GenesisStateRoot: &ScrollSepoliaGenesisState, }, } @@ -380,6 +383,7 @@ var ( NumL1MessagesPerBlock: 10, ScrollChainAddress: common.HexToAddress("0xa13BAF47339d63B743e7Da8741db5456DAc1E556"), }, + GenesisStateRoot: &ScrollMainnetGenesisState, }, } @@ -633,6 +637,7 @@ type ChainConfig struct { CurieBlock *big.Int `json:"curieBlock,omitempty"` // Curie switch block (nil = no fork, 0 = already on curie) DarwinTime *uint64 `json:"darwinTime,omitempty"` // Darwin switch time (nil = no fork, 0 = already on darwin) DarwinV2Time *uint64 `json:"darwinv2Time,omitempty"` // DarwinV2 switch time (nil = no fork, 0 = already on darwinv2) + EuclidTime *uint64 `json:"euclidTime,omitempty"` // Euclid switch time (nil = no fork, 0 = already on euclid) // TerminalTotalDifficulty is the amount of total difficulty reached by // the network that triggers the consensus upgrade. @@ -661,6 +666,9 @@ type ScrollConfig struct { // L1 config L1Config *L1Config `json:"l1Config,omitempty"` + + // Genesis State Root for MPT clients + GenesisStateRoot *common.Hash `json:"genesisStateRoot,omitempty"` } // L1Config contains the l1 parameters needed to sync l1 contract events (e.g., l1 messages, commit/revert/finalize batches) in the sequencer @@ -888,6 +896,11 @@ func (c *ChainConfig) IsDarwinV2(now uint64) bool { return isForkedTime(now, c.DarwinV2Time) } +// IsEuclid returns whether num is either equal to the Darwin fork block or greater. +func (c *ChainConfig) IsEuclid(now uint64) bool { + return isForkedTime(now, c.EuclidTime) +} + // IsTerminalPoWBlock returns whether the given block is the last block of PoW stage. func (c *ChainConfig) IsTerminalPoWBlock(parentTotalDiff *big.Int, totalDiff *big.Int) bool { if c.TerminalTotalDifficulty == nil { @@ -1100,7 +1113,7 @@ type Rules struct { IsHomestead, IsEIP150, IsEIP155, IsEIP158 bool IsByzantium, IsConstantinople, IsPetersburg, IsIstanbul bool IsBerlin, IsLondon, IsArchimedes, IsShanghai bool - IsBernoulli, IsCurie, IsDarwin bool + IsBernoulli, IsCurie, IsDarwin, IsEuclid bool } // Rules ensures c's ChainID is not nil. @@ -1126,5 +1139,6 @@ func (c *ChainConfig) Rules(num *big.Int, time uint64) Rules { IsBernoulli: c.IsBernoulli(num), IsCurie: c.IsCurie(num), IsDarwin: c.IsDarwin(time), + IsEuclid: c.IsEuclid(time), } } diff --git a/rollup/ccc/async_checker.go b/rollup/ccc/async_checker.go index 1cb9b7d78768..b5815cb58403 100644 --- a/rollup/ccc/async_checker.go +++ b/rollup/ccc/async_checker.go @@ -98,6 +98,11 @@ func (c *AsyncChecker) Wait() { // Check spawns an async CCC verification task. func (c *AsyncChecker) Check(block *types.Block) error { + if c.bc.Config().IsEuclid(block.Time()) { + // Euclid blocks use MPT and CCC doesn't support them + return nil + } + if block.NumberU64() > c.currentHead.Number.Uint64()+1 { log.Warn("non continuous chain observed in AsyncChecker", "prev", c.currentHead, "got", block.Header()) } diff --git a/rollup/pipeline/pipeline.go b/rollup/pipeline/pipeline.go index 90c6149b3858..77ac3aee24a0 100644 --- a/rollup/pipeline/pipeline.go +++ b/rollup/pipeline/pipeline.go @@ -228,7 +228,7 @@ func sendCancellable[T any, C comparable](resCh chan T, msg T, cancelCh <-chan C } func (p *Pipeline) traceAndApplyStage(txsIn <-chan *types.Transaction) (<-chan error, <-chan *BlockCandidate, error) { - p.state.StartPrefetcher("miner") + p.state.StartPrefetcher("miner", nil) downstreamCh := make(chan *BlockCandidate, p.downstreamChCapacity()) resCh := make(chan error) p.wg.Add(1) diff --git a/trie/database.go b/trie/database.go index 1c5b7f805aea..1301a83a5e98 100644 --- a/trie/database.go +++ b/trie/database.go @@ -667,6 +667,9 @@ func (db *Database) Commit(node common.Hash, report bool, callback func(common.H } batch.Reset() + if diskRoot, err := rawdb.ReadDiskStateRoot(db.diskdb, node); err == nil { + node = diskRoot + } if (node == common.Hash{}) { return nil } @@ -782,7 +785,7 @@ func (c *cleaner) Put(key []byte, rlp []byte) error { delete(c.db.dirties, hash) c.db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size)) if node.children != nil { - c.db.dirtiesSize -= common.StorageSize(cachedNodeChildrenSize + len(node.children)*(common.HashLength+2)) + c.db.childrenSize -= common.StorageSize(cachedNodeChildrenSize + len(node.children)*(common.HashLength+2)) } // Move the flushed node into the clean cache to prevent insta-reloads if c.db.cleans != nil { diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 761d812bdfcc..7516d2879010 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -296,7 +296,7 @@ func TestUnionIterator(t *testing.T) { } func TestIteratorNoDups(t *testing.T) { - var tr Trie + tr := newEmpty() for _, val := range testdata1 { tr.Update([]byte(val.k), []byte(val.v)) } @@ -530,7 +530,7 @@ func TestNodeIteratorLargeTrie(t *testing.T) { trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885")) // master: 24 get operations // this pr: 5 get operations - if have, want := logDb.getCount, uint64(5); have != want { + if have, want := logDb.getCount, uint64(10); have != want { t.Fatalf("Too many lookups during seek, have %d want %d", have, want) } } diff --git a/trie/proof.go b/trie/proof.go index 58fb4c3cc78a..3362512b20c7 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -559,7 +559,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key } // Rebuild the trie with the leaf stream, the shape of trie // should be same with the original one. - tr := &Trie{root: root, db: NewDatabase(memorydb.New())} + tr := &Trie{root: root, db: NewDatabase(memorydb.New()), tracer: newTracer()} if empty { tr.root = nil } diff --git a/trie/proof_test.go b/trie/proof_test.go index 2155ae0fbd6a..5c5304261d2d 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -79,7 +79,7 @@ func TestProof(t *testing.T) { } func TestOneElementProof(t *testing.T) { - trie := new(Trie) + trie := newEmpty() updateString(trie, "k", "v") for i, prover := range makeProvers(trie) { proof := prover([]byte("k")) @@ -130,7 +130,7 @@ func TestBadProof(t *testing.T) { // Tests that missing keys can also be proven. The test explicitly uses a single // entry trie and checks for missing keys both before and after the single entry. func TestMissingKeyProof(t *testing.T) { - trie := new(Trie) + trie := newEmpty() updateString(trie, "k", "v") for i, key := range []string{"a", "j", "l", "z"} { @@ -386,7 +386,7 @@ func TestOneElementRangeProof(t *testing.T) { } // Test the mini trie with only a single element. - tinyTrie := new(Trie) + tinyTrie := newEmpty() entry := &kv{randBytes(32), randBytes(20), false} tinyTrie.Update(entry.k, entry.v) @@ -458,7 +458,7 @@ func TestAllElementsProof(t *testing.T) { // TestSingleSideRangeProof tests the range starts from zero. func TestSingleSideRangeProof(t *testing.T) { for i := 0; i < 64; i++ { - trie := new(Trie) + trie := newEmpty() var entries entrySlice for i := 0; i < 4096; i++ { value := &kv{randBytes(32), randBytes(20), false} @@ -493,7 +493,7 @@ func TestSingleSideRangeProof(t *testing.T) { // TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff. func TestReverseSingleSideRangeProof(t *testing.T) { for i := 0; i < 64; i++ { - trie := new(Trie) + trie := newEmpty() var entries entrySlice for i := 0; i < 4096; i++ { value := &kv{randBytes(32), randBytes(20), false} @@ -600,7 +600,7 @@ func TestBadRangeProof(t *testing.T) { // TestGappedRangeProof focuses on the small trie with embedded nodes. // If the gapped node is embedded in the trie, it should be detected too. func TestGappedRangeProof(t *testing.T) { - trie := new(Trie) + trie := newEmpty() var entries []*kv // Sorted entries for i := byte(0); i < 10; i++ { value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} @@ -674,7 +674,7 @@ func TestSameSideProofs(t *testing.T) { } func TestHasRightElement(t *testing.T) { - trie := new(Trie) + trie := newEmpty() var entries entrySlice for i := 0; i < 4096; i++ { value := &kv{randBytes(32), randBytes(20), false} @@ -1027,7 +1027,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) { } func randomTrie(n int) (*Trie, map[string]*kv) { - trie := new(Trie) + trie := newEmpty() vals := make(map[string]*kv) for i := byte(0); i < 100; i++ { value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} @@ -1052,7 +1052,7 @@ func randBytes(n int) []byte { } func nonRandomTrie(n int) (*Trie, map[string]*kv) { - trie := new(Trie) + trie := newEmpty() vals := make(map[string]*kv) max := uint64(0xffffffffffffffff) for i := uint64(0); i < uint64(n); i++ { diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 253b8d780ad3..a3529c2fc68f 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -190,6 +190,7 @@ func (t *SecureTrie) Hash() common.Hash { // Copy returns a copy of SecureTrie. func (t *SecureTrie) Copy() *SecureTrie { cpy := *t + cpy.trie.tracer = t.trie.tracer.copy() return &cpy } @@ -221,3 +222,8 @@ func (t *SecureTrie) getSecKeyCache() map[string][]byte { } return t.secKeyCache } + +// Witness returns a set containing all trie nodes that have been accessed. +func (t *SecureTrie) Witness() map[string]struct{} { + return t.trie.Witness() +} diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index b81b4e1ad5b8..9baaa2e266ed 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -112,8 +112,7 @@ func TestSecureTrieConcurrency(t *testing.T) { threads := runtime.NumCPU() tries := make([]*SecureTrie, threads) for i := 0; i < threads; i++ { - cpy := *trie - tries[i] = &cpy + tries[i] = trie.Copy() } // Start a batch of goroutines interactng with the trie pend := new(sync.WaitGroup) diff --git a/trie/tracer.go b/trie/tracer.go new file mode 100644 index 000000000000..99cda0706f7d --- /dev/null +++ b/trie/tracer.go @@ -0,0 +1,122 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "maps" + + "github.com/scroll-tech/go-ethereum/common" +) + +// tracer tracks the changes of trie nodes. During the trie operations, +// some nodes can be deleted from the trie, while these deleted nodes +// won't be captured by trie.Hasher or trie.Committer. Thus, these deleted +// nodes won't be removed from the disk at all. Tracer is an auxiliary tool +// used to track all insert and delete operations of trie and capture all +// deleted nodes eventually. +// +// The changed nodes can be mainly divided into two categories: the leaf +// node and intermediate node. The former is inserted/deleted by callers +// while the latter is inserted/deleted in order to follow the rule of trie. +// This tool can track all of them no matter the node is embedded in its +// parent or not, but valueNode is never tracked. +// +// Besides, it's also used for recording the original value of the nodes +// when they are resolved from the disk. The pre-value of the nodes will +// be used to construct trie history in the future. +// +// Note tracer is not thread-safe, callers should be responsible for handling +// the concurrency issues by themselves. +type tracer struct { + inserts map[string]struct{} + deletes map[string]struct{} + accessList map[string][]byte +} + +// newTracer initializes the tracer for capturing trie changes. +func newTracer() *tracer { + return &tracer{ + inserts: make(map[string]struct{}), + deletes: make(map[string]struct{}), + accessList: make(map[string][]byte), + } +} + +// onRead tracks the newly loaded trie node and caches the rlp-encoded +// blob internally. Don't change the value outside of function since +// it's not deep-copied. +func (t *tracer) onRead(path []byte, val []byte) { + t.accessList[string(path)] = val +} + +// onInsert tracks the newly inserted trie node. If it's already +// in the deletion set (resurrected node), then just wipe it from +// the deletion set as it's "untouched". +func (t *tracer) onInsert(path []byte) { + if _, present := t.deletes[string(path)]; present { + delete(t.deletes, string(path)) + return + } + t.inserts[string(path)] = struct{}{} +} + +// onDelete tracks the newly deleted trie node. If it's already +// in the addition set, then just wipe it from the addition set +// as it's untouched. +func (t *tracer) onDelete(path []byte) { + if _, present := t.inserts[string(path)]; present { + delete(t.inserts, string(path)) + return + } + t.deletes[string(path)] = struct{}{} +} + +// reset clears the content tracked by tracer. +func (t *tracer) reset() { + t.inserts = make(map[string]struct{}) + t.deletes = make(map[string]struct{}) + t.accessList = make(map[string][]byte) +} + +// copy returns a deep copied tracer instance. +func (t *tracer) copy() *tracer { + accessList := make(map[string][]byte, len(t.accessList)) + for path, blob := range t.accessList { + accessList[path] = common.CopyBytes(blob) + } + return &tracer{ + inserts: maps.Clone(t.inserts), + deletes: maps.Clone(t.deletes), + accessList: accessList, + } +} + +// deletedNodes returns a list of node paths which are deleted from the trie. +func (t *tracer) deletedNodes() []string { + var paths []string + for path := range t.deletes { + // It's possible a few deleted nodes were embedded + // in their parent before, the deletions can be no + // effect by deleting nothing, filter them out. + _, ok := t.accessList[path] + if !ok { + continue + } + paths = append(paths, path) + } + return paths +} diff --git a/trie/trie.go b/trie/trie.go index 81cdd1627745..fe521d269074 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -62,6 +62,9 @@ type Trie struct { // hashing operation. This number will not directly map to the number of // actually unhashed nodes unhashed int + + // tracer is the tool to track the trie changes. + tracer *tracer } // newFlag returns the cache flag value for a newly created node. @@ -80,7 +83,8 @@ func New(root common.Hash, db *Database) (*Trie, error) { panic("trie.New called without a database") } trie := &Trie{ - db: db, + db: db, + tracer: newTracer(), } if root != (common.Hash{}) && root != emptyRoot { rootnode, err := trie.resolveHash(root[:], nil) @@ -313,6 +317,11 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error if matchlen == 0 { return true, branch, nil } + // New branch node is created as a child of the original short node. + // Track the newly inserted node in the tracer. The node identifier + // passed is the path from the root node. + t.tracer.onInsert(append(prefix, key[:matchlen]...)) + // Otherwise, replace it with a short node leading up to the branch. return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil @@ -327,6 +336,11 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error return true, n, nil case nil: + // New short node is created and track it in the tracer. The node identifier + // passed is the path from the root node. Note the valueNode won't be tracked + // since it's always embedded in its parent. + t.tracer.onInsert(prefix) + return true, &shortNode{key, value, t.newFlag()}, nil case hashNode: @@ -379,6 +393,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { return false, n, nil // don't replace n on mismatch } if matchlen == len(key) { + // The matched short node is deleted entirely and track + // it in the deletion set. The same the valueNode doesn't + // need to be tracked at all since it's always embedded. + t.tracer.onDelete(prefix) + return true, nil, nil // remove n entirely for whole matches } // The key is longer than n.Key. Remove the remaining suffix @@ -391,6 +410,10 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { } switch child := child.(type) { case *shortNode: + // The child shortNode is merged into its parent, track + // is deleted as well. + t.tracer.onDelete(append(prefix, n.Key...)) + // Deleting from the subtrie reduced it to another // short node. Merge the nodes to avoid creating a // shortNode{..., shortNode{...}}. Use concat (which @@ -452,6 +475,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { return false, nil, err } if cnode, ok := cnode.(*shortNode); ok { + // Replace the entire full node with the short node. + // Mark the original short node as deleted since the + // value is embedded into the parent now. + t.tracer.onDelete(append(prefix, byte(pos))) + k := append([]byte{byte(pos)}, cnode.Key...) return true, &shortNode{k, cnode.Val, t.newFlag()}, nil } @@ -505,6 +533,11 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) { func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { hash := common.BytesToHash(n) if node := t.db.node(hash); node != nil { + rlp, err := t.db.Node(hash) + if err != nil { + return nil, err + } + t.tracer.onRead(prefix, rlp) return node, nil } return nil, &MissingNodeError{NodeHash: hash, Path: prefix} @@ -582,4 +615,18 @@ func (t *Trie) hashRoot() (node, node, error) { func (t *Trie) Reset() { t.root = nil t.unhashed = 0 + t.tracer.reset() +} + +// Witness returns a set containing all trie nodes that have been accessed. +func (t *Trie) Witness() map[string]struct{} { + if len(t.tracer.accessList) == 0 { + return nil + } + + witness := make(map[string]struct{}, len(t.tracer.accessList)) + for _, node := range t.tracer.accessList { + witness[string(node)] = struct{}{} + } + return witness } diff --git a/trie/trie_test.go b/trie/trie_test.go index 23780a5ff807..e8f6293d0376 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -64,7 +64,7 @@ func TestEmptyTrie(t *testing.T) { } func TestNull(t *testing.T) { - var trie Trie + trie := newEmpty() key := make([]byte, 32) value := []byte("test") trie.Update(key, value) @@ -593,15 +593,15 @@ func TestTinyTrie(t *testing.T) { _, accounts := makeAccounts(5) trie := newEmpty() trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), accounts[3]) - if exp, root := common.HexToHash("fc516c51c03bf9f1a0eec6ed6f6f5da743c2745dcd5670007519e6ec056f95a8"), trie.Hash(); exp != root { + if exp, root := common.HexToHash("8c6a85a4d9fda98feff88450299e574e5378e32391f75a055d470ac0653f1005"), trie.Hash(); exp != root { t.Errorf("1: got %x, exp %x", root, exp) } trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001338"), accounts[4]) - if exp, root := common.HexToHash("5070d3f144546fd13589ad90cd153954643fa4ca6c1a5f08683cbfbbf76e960c"), trie.Hash(); exp != root { + if exp, root := common.HexToHash("ec63b967e98a5720e7f720482151963982890d82c9093c0d486b7eb8883a66b1"), trie.Hash(); exp != root { t.Errorf("2: got %x, exp %x", root, exp) } trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001339"), accounts[4]) - if exp, root := common.HexToHash("aa3fba77e50f6e931d8aacde70912be5bff04c7862f518ae06f3418dd4d37be3"), trie.Hash(); exp != root { + if exp, root := common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), trie.Hash(); exp != root { t.Errorf("3: got %x, exp %x", root, exp) } checktr, _ := New(common.Hash{}, trie.db) @@ -625,7 +625,7 @@ func TestCommitAfterHash(t *testing.T) { trie.Hash() trie.Commit(nil) root := trie.Hash() - exp := common.HexToHash("f0c0681648c93b347479cd58c61995557f01294425bd031ce1943c2799bbd4ec") + exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6") if exp != root { t.Errorf("got %x, exp %x", root, exp) } @@ -725,12 +725,12 @@ func TestCommitSequence(t *testing.T) { expWriteSeqHash []byte expCallbackSeqHash []byte }{ - {20, common.FromHex("7b908cce3bc16abb3eac5dff6c136856526f15225f74ce860a2bec47912a5492"), - common.FromHex("fac65cd2ad5e301083d0310dd701b5faaff1364cbe01cdbfaf4ec3609bb4149e")}, - {200, common.FromHex("55791f6ec2f83fee512a2d3d4b505784fdefaea89974e10440d01d62a18a298a"), - common.FromHex("5ab775b64d86a8058bb71c3c765d0f2158c14bbeb9cb32a65eda793a7e95e30f")}, - {2000, common.FromHex("ccb464abf67804538908c62431b3a6788e8dc6dee62aff9bfe6b10136acfceac"), - common.FromHex("b908adff17a5aa9d6787324c39014a74b04cef7fba6a92aeb730f48da1ca665d")}, + {20, common.FromHex("873c78df73d60e59d4a2bcf3716e8bfe14554549fea2fc147cb54129382a8066"), + common.FromHex("ff00f91ac05df53b82d7f178d77ada54fd0dca64526f537034a5dbe41b17df2a")}, + {200, common.FromHex("ba03d891bb15408c940eea5ee3d54d419595102648d02774a0268d892add9c8e"), + common.FromHex("f3cd509064c8d319bbdd1c68f511850a902ad275e6ed5bea11547e23d492a926")}, + {2000, common.FromHex("f7a184f20df01c94f09537401d11e68d97ad0c00115233107f51b9c287ce60c7"), + common.FromHex("ff795ea898ba1e4cfed4a33b4cf5535a347a02cf931f88d88719faf810f9a1c9")}, } { addresses, accounts := makeAccounts(tc.count) // This spongeDb is used to check the sequence of disk-db-writes diff --git a/trie/zk_trie.go b/trie/zk_trie.go index 044e18ad66ba..a98ae474ddff 100644 --- a/trie/zk_trie.go +++ b/trie/zk_trie.go @@ -233,3 +233,33 @@ func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueRead return nil, fmt.Errorf("bad proof node %v", proof) } } + +func (t *ZkTrie) CountLeaves() uint64 { + root, err := t.ZkTrie.Tree().Root() + if err != nil { + panic("CountLeaves cannot get root") + } + return t.countLeaves(root) +} + +func (t *ZkTrie) countLeaves(root *zkt.Hash) uint64 { + if root == nil { + return 0 + } + + rootNode, err := t.ZkTrie.Tree().GetNode(root) + if err != nil { + panic("countLeaves cannot get rootNode") + } + + if rootNode.Type == zktrie.NodeTypeLeaf_New { + return 1 + } else { + return t.countLeaves(rootNode.ChildL) + t.countLeaves(rootNode.ChildR) + } +} + +// Witness returns a set containing all trie nodes that have been accessed. +func (t *ZkTrie) Witness() map[string]struct{} { + panic("not implemented") +} diff --git a/trie/zk_trie_test.go b/trie/zk_trie_test.go index 6c23abc2764b..d1700a75899d 100644 --- a/trie/zk_trie_test.go +++ b/trie/zk_trie_test.go @@ -26,12 +26,16 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" zkt "github.com/scroll-tech/zktrie/types" "github.com/scroll-tech/go-ethereum/common" + "github.com/scroll-tech/go-ethereum/core/rawdb" + "github.com/scroll-tech/go-ethereum/core/types" "github.com/scroll-tech/go-ethereum/ethdb/leveldb" "github.com/scroll-tech/go-ethereum/ethdb/memorydb" + "github.com/scroll-tech/go-ethereum/rlp" ) func newEmptyZkTrie() *ZkTrie { @@ -264,3 +268,72 @@ func TestZkTrieDelete(t *testing.T) { assert.Equal(t, hashes[i].Hex(), hash.Hex()) } } + +func TestEquivalence(t *testing.T) { + t.Skip() + + zkDb, err := leveldb.New("/Users/omer/Documents/go-ethereum/l2geth-datadir/geth/chaindata", 0, 0, "", true) + require.NoError(t, err) + mptDb, err := leveldb.New("/Users/omer/Documents/go-ethereum/l2geth-datadir-mpt/geth/chaindata", 0, 0, "", true) + require.NoError(t, err) + + zkRoot := common.HexToHash("0x294b458b5b571bb634dbe9a81331dd2aabb5ef40cdc0328b075a9666d5df55d0") + mptRoot, err := rawdb.ReadDiskStateRoot(mptDb, zkRoot) + require.NoError(t, err) + + checkTrieEquality(t, &dbs{ + zkDb: zkDb, + mptDb: mptDb, + }, zkRoot, mptRoot, checkAccountEquality) +} + +type dbs struct { + zkDb *leveldb.Database + mptDb *leveldb.Database +} + +var accountsLeft = -1 + +func checkTrieEquality(t *testing.T, dbs *dbs, zkRoot, mptRoot common.Hash, leafChecker func(*testing.T, *dbs, []byte, []byte)) { + zkTrie, err := NewZkTrie(zkRoot, NewZktrieDatabase(dbs.zkDb)) + require.NoError(t, err) + + mptTrie, err := NewSecure(mptRoot, NewDatabaseWithConfig(dbs.mptDb, &Config{Preimages: true})) + require.NoError(t, err) + + expectedLeaves := zkTrie.CountLeaves() + trieIt := NewIterator(mptTrie.NodeIterator(nil)) + if accountsLeft == -1 { + accountsLeft = int(expectedLeaves) + } + + for trieIt.Next() { + expectedLeaves-- + preimageKey := mptTrie.GetKey(trieIt.Key) + require.NotEmpty(t, preimageKey) + leafChecker(t, dbs, zkTrie.Get(preimageKey), mptTrie.Get(preimageKey)) + } + require.Zero(t, expectedLeaves) +} + +func checkAccountEquality(t *testing.T, dbs *dbs, zkAccountBytes, mptAccountBytes []byte) { + mptAccount := &types.StateAccount{} + require.NoError(t, rlp.DecodeBytes(mptAccountBytes, mptAccount)) + zkAccount, err := types.UnmarshalStateAccount(zkAccountBytes) + require.NoError(t, err) + + require.Equal(t, mptAccount.Nonce, zkAccount.Nonce) + require.True(t, mptAccount.Balance.Cmp(zkAccount.Balance) == 0) + require.Equal(t, mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash) + checkTrieEquality(t, dbs, common.BytesToHash(zkAccount.Root[:]), common.BytesToHash(mptAccount.Root[:]), checkStorageEquality) + accountsLeft-- + t.Log("Accounts left:", accountsLeft) +} + +func checkStorageEquality(t *testing.T, _ *dbs, zkStorageBytes, mptStorageBytes []byte) { + zkValue := common.BytesToHash(zkStorageBytes) + _, content, _, err := rlp.Split(mptStorageBytes) + require.NoError(t, err) + mptValue := common.BytesToHash(content) + require.Equal(t, zkValue, mptValue) +}