diff --git a/w3types/state.go b/w3types/state.go index fcc1be73..ae8747e9 100644 --- a/w3types/state.go +++ b/w3types/state.go @@ -1,7 +1,9 @@ package w3types import ( + "bytes" "encoding/json" + "maps" "math/big" "sync/atomic" @@ -28,6 +30,27 @@ func (s State) SetGenesisAlloc(alloc types.GenesisAlloc) State { return s } +// Merge returns a new state that is the result of merging the called state with the given state. +// All state in other state will overwrite the state in the called state. +func (s State) Merge(other State) (merged State) { + merged = make(State, len(s)) + + // copy all accounts from s + for addr, acc := range s { + merged[addr] = acc.deepCopy() + } + + // merge all accounts from other + for addr, acc := range other { + if mergedAcc, ok := merged[addr]; ok { + mergedAcc.merge(acc) + } else { + merged[addr] = acc.deepCopy() + } + } + return merged +} + type Account struct { Nonce uint64 Balance *big.Int @@ -37,6 +60,47 @@ type Account struct { codeHash atomic.Pointer[common.Hash] // caches the code hash } +// deepCopy returns a deep copy of the account. +func (acc *Account) deepCopy() *Account { + newAcc := &Account{Nonce: acc.Nonce} + if acc.Balance != nil { + newAcc.Balance = new(big.Int).Set(acc.Balance) + } + if len(acc.Code) > 0 { + newAcc.Code = bytes.Clone(acc.Code) + } + if len(acc.Storage) > 0 { + newAcc.Storage = maps.Clone(acc.Storage) + } + return newAcc +} + +// merge merges the given account into the called account. +func (dst *Account) merge(src *Account) { + // merge account fields + srcIsZero := src.Nonce == 0 && src.Balance == nil && len(src.Code) == 0 + if !srcIsZero { + dst.Nonce = src.Nonce + if src.Balance != nil { + dst.Balance = new(big.Int).Set(src.Balance) + } else { + dst.Balance = nil + } + if len(src.Code) > 0 { + dst.Code = bytes.Clone(src.Code) + } else { + dst.Code = nil + } + } + + // merge storage + if dst.Storage == nil && len(src.Storage) > 0 { + dst.Storage = maps.Clone(src.Storage) + } else if len(src.Storage) > 0 { + maps.Copy(dst.Storage, src.Storage) + } +} + // CodeHash returns the hash of the account's code. func (acc *Account) CodeHash() common.Hash { if codeHash := acc.codeHash.Load(); codeHash != nil { diff --git a/w3types/state_test.go b/w3types/state_test.go new file mode 100644 index 00000000..6123edde --- /dev/null +++ b/w3types/state_test.go @@ -0,0 +1,79 @@ +package w3types_test + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/lmittmann/w3/w3types" +) + +func TestStateMerge(t *testing.T) { + tests := []struct { + Name string + StateDst w3types.State + StateSrc w3types.State + Want w3types.State + }{ + { + Name: "empty", + StateDst: w3types.State{}, + StateSrc: w3types.State{}, + Want: w3types.State{}, + }, + { + Name: "empty-dst", + StateDst: w3types.State{}, + StateSrc: w3types.State{common.Address{}: {}}, + Want: w3types.State{common.Address{}: {}}, + }, + { + Name: "empty-src", + StateDst: w3types.State{common.Address{}: {}}, + StateSrc: w3types.State{}, + Want: w3types.State{common.Address{}: {}}, + }, + { + Name: "simple", + StateDst: w3types.State{common.Address{0x01}: {}}, + StateSrc: w3types.State{common.Address{0x02}: {}}, + Want: w3types.State{ + common.Address{0x01}: {}, + common.Address{0x02}: {}, + }, + }, + { + Name: "simple-conflict", + StateDst: w3types.State{common.Address{}: {Nonce: 1}}, + StateSrc: w3types.State{common.Address{}: {Nonce: 2}}, + Want: w3types.State{common.Address{}: {Nonce: 2}}, + }, + { + Name: "storage-simple", + StateDst: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{0x01}: common.Hash{0x01}}}}, + StateSrc: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{0x02}: common.Hash{0x02}}}}, + Want: w3types.State{common.Address{}: {Storage: w3types.Storage{ + common.Hash{0x01}: common.Hash{0x01}, + common.Hash{0x02}: common.Hash{0x02}, + }}}, + }, + { + Name: "storage-conflict", + StateDst: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{}: common.Hash{0x01}}}}, + StateSrc: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{}: common.Hash{0x02}}}}, + Want: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{}: common.Hash{0x02}}}}, + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + got := test.StateDst.Merge(test.StateSrc) + if diff := cmp.Diff(test.Want, got, + cmpopts.IgnoreUnexported(w3types.Account{}), + ); diff != "" { + t.Fatalf("(-want, +got)\n%s", diff) + } + }) + } +}