Skip to content

Commit

Permalink
w3types: added State.Merge (#151)
Browse files Browse the repository at this point in the history
Co-authored-by: lmittmann <lmittmann@users.noreply.github.com>
  • Loading branch information
lmittmann and lmittmann authored Jun 10, 2024
1 parent 706fe33 commit 40f3516
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 0 deletions.
64 changes: 64 additions & 0 deletions w3types/state.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package w3types

import (
"bytes"
"encoding/json"
"maps"
"math/big"
"sync/atomic"

Expand All @@ -28,6 +30,27 @@ func (s State) SetGenesisAlloc(alloc types.GenesisAlloc) State {
return s
}

// Merge returns a new state that is the result of merging the called state with the given state.
// All state in other state will overwrite the state in the called state.
func (s State) Merge(other State) (merged State) {
merged = make(State, len(s))

// copy all accounts from s
for addr, acc := range s {
merged[addr] = acc.deepCopy()
}

// merge all accounts from other
for addr, acc := range other {
if mergedAcc, ok := merged[addr]; ok {
mergedAcc.merge(acc)
} else {
merged[addr] = acc.deepCopy()
}
}
return merged
}

type Account struct {
Nonce uint64
Balance *big.Int
Expand All @@ -37,6 +60,47 @@ type Account struct {
codeHash atomic.Pointer[common.Hash] // caches the code hash
}

// deepCopy returns a deep copy of the account.
func (acc *Account) deepCopy() *Account {
newAcc := &Account{Nonce: acc.Nonce}
if acc.Balance != nil {
newAcc.Balance = new(big.Int).Set(acc.Balance)
}
if len(acc.Code) > 0 {
newAcc.Code = bytes.Clone(acc.Code)
}
if len(acc.Storage) > 0 {
newAcc.Storage = maps.Clone(acc.Storage)
}
return newAcc
}

// merge merges the given account into the called account.
func (dst *Account) merge(src *Account) {
// merge account fields
srcIsZero := src.Nonce == 0 && src.Balance == nil && len(src.Code) == 0
if !srcIsZero {
dst.Nonce = src.Nonce
if src.Balance != nil {
dst.Balance = new(big.Int).Set(src.Balance)
} else {
dst.Balance = nil
}
if len(src.Code) > 0 {
dst.Code = bytes.Clone(src.Code)
} else {
dst.Code = nil
}
}

// merge storage
if dst.Storage == nil && len(src.Storage) > 0 {
dst.Storage = maps.Clone(src.Storage)
} else if len(src.Storage) > 0 {
maps.Copy(dst.Storage, src.Storage)
}
}

// CodeHash returns the hash of the account's code.
func (acc *Account) CodeHash() common.Hash {
if codeHash := acc.codeHash.Load(); codeHash != nil {
Expand Down
79 changes: 79 additions & 0 deletions w3types/state_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package w3types_test

import (
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/lmittmann/w3/w3types"
)

func TestStateMerge(t *testing.T) {
tests := []struct {
Name string
StateDst w3types.State
StateSrc w3types.State
Want w3types.State
}{
{
Name: "empty",
StateDst: w3types.State{},
StateSrc: w3types.State{},
Want: w3types.State{},
},
{
Name: "empty-dst",
StateDst: w3types.State{},
StateSrc: w3types.State{common.Address{}: {}},
Want: w3types.State{common.Address{}: {}},
},
{
Name: "empty-src",
StateDst: w3types.State{common.Address{}: {}},
StateSrc: w3types.State{},
Want: w3types.State{common.Address{}: {}},
},
{
Name: "simple",
StateDst: w3types.State{common.Address{0x01}: {}},
StateSrc: w3types.State{common.Address{0x02}: {}},
Want: w3types.State{
common.Address{0x01}: {},
common.Address{0x02}: {},
},
},
{
Name: "simple-conflict",
StateDst: w3types.State{common.Address{}: {Nonce: 1}},
StateSrc: w3types.State{common.Address{}: {Nonce: 2}},
Want: w3types.State{common.Address{}: {Nonce: 2}},
},
{
Name: "storage-simple",
StateDst: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{0x01}: common.Hash{0x01}}}},
StateSrc: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{0x02}: common.Hash{0x02}}}},
Want: w3types.State{common.Address{}: {Storage: w3types.Storage{
common.Hash{0x01}: common.Hash{0x01},
common.Hash{0x02}: common.Hash{0x02},
}}},
},
{
Name: "storage-conflict",
StateDst: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{}: common.Hash{0x01}}}},
StateSrc: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{}: common.Hash{0x02}}}},
Want: w3types.State{common.Address{}: {Storage: w3types.Storage{common.Hash{}: common.Hash{0x02}}}},
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
got := test.StateDst.Merge(test.StateSrc)
if diff := cmp.Diff(test.Want, got,
cmpopts.IgnoreUnexported(w3types.Account{}),
); diff != "" {
t.Fatalf("(-want, +got)\n%s", diff)
}
})
}
}

0 comments on commit 40f3516

Please sign in to comment.