Skip to content

Commit

Permalink
tests: fix goroutine leak related to state snapshot generation (#28974)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Felix Lange <fjl@twurst.com>
  • Loading branch information
holiman and fjl authored Feb 14, 2024
1 parent 55a46c3 commit 8321fe2
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 124 deletions.
18 changes: 9 additions & 9 deletions cmd/evm/staterunner.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/tracers/logger"
"github.com/ethereum/go-ethereum/tests"
Expand Down Expand Up @@ -90,26 +89,27 @@ func runStateTest(fname string, cfg vm.Config, jsonOut, dump bool) error {
if err != nil {
return err
}
var tests map[string]tests.StateTest
if err := json.Unmarshal(src, &tests); err != nil {
var testsByName map[string]tests.StateTest
if err := json.Unmarshal(src, &testsByName); err != nil {
return err
}

// Iterate over all the tests, run them and aggregate the results
results := make([]StatetestResult, 0, len(tests))
for key, test := range tests {
results := make([]StatetestResult, 0, len(testsByName))
for key, test := range testsByName {
for _, st := range test.Subtests() {
// Run the test and aggregate the result
result := &StatetestResult{Name: key, Fork: st.Fork, Pass: true}
test.Run(st, cfg, false, rawdb.HashScheme, func(err error, snaps *snapshot.Tree, statedb *state.StateDB) {
test.Run(st, cfg, false, rawdb.HashScheme, func(err error, tstate *tests.StateTestState) {
var root common.Hash
if statedb != nil {
root = statedb.IntermediateRoot(false)
if tstate.StateDB != nil {
root = tstate.StateDB.IntermediateRoot(false)
result.Root = &root
if jsonOut {
fmt.Fprintf(os.Stderr, "{\"stateRoot\": \"%#x\"}\n", root)
}
if dump { // Dump any state to aid debugging
cpy, _ := state.New(root, statedb.Database(), nil)
cpy, _ := state.New(root, tstate.StateDB.Database(), nil)
dump := cpy.RawDump(nil)
result.State = &dump
}
Expand Down
8 changes: 8 additions & 0 deletions core/state/snapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,14 @@ func (t *Tree) Disable() {
for _, layer := range t.layers {
switch layer := layer.(type) {
case *diskLayer:

layer.lock.RLock()
generating := layer.genMarker != nil
layer.lock.RUnlock()
if !generating {
// Generator is already aborted or finished
break
}
// If the base layer is generating, abort it
if layer.genAbort != nil {
abort := make(chan *generatorStats)
Expand Down
22 changes: 11 additions & 11 deletions eth/tracers/internal/tracetest/calltrace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ func testCallTracer(tracerName string, dirPath string, t *testing.T) {
GasLimit: uint64(test.Context.GasLimit),
BaseFee: test.Genesis.BaseFee,
}
triedb, _, statedb = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
state = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
)
triedb.Close()
state.Close()

tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig)
if err != nil {
Expand All @@ -145,7 +145,7 @@ func testCallTracer(tracerName string, dirPath string, t *testing.T) {
if err != nil {
t.Fatalf("failed to prepare transaction for tracing: %v", err)
}
evm := vm.NewEVM(context, core.NewEVMTxContext(msg), statedb, test.Genesis.Config, vm.Config{Tracer: tracer})
evm := vm.NewEVM(context, core.NewEVMTxContext(msg), state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer})
vmRet, err := core.ApplyMessage(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
if err != nil {
t.Fatalf("failed to execute transaction: %v", err)
Expand Down Expand Up @@ -235,8 +235,8 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) {
if err != nil {
b.Fatalf("failed to prepare transaction for tracing: %v", err)
}
triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
defer triedb.Close()
state := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
defer state.Close()

b.ReportAllocs()
b.ResetTimer()
Expand All @@ -245,16 +245,16 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) {
if err != nil {
b.Fatalf("failed to create call tracer: %v", err)
}
evm := vm.NewEVM(context, txContext, statedb, test.Genesis.Config, vm.Config{Tracer: tracer})
snap := statedb.Snapshot()
evm := vm.NewEVM(context, txContext, state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer})
snap := state.StateDB.Snapshot()
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
if _, err = st.TransitionDb(); err != nil {
b.Fatalf("failed to execute transaction: %v", err)
}
if _, err = tracer.GetResult(); err != nil {
b.Fatal(err)
}
statedb.RevertToSnapshot(snap)
state.StateDB.RevertToSnapshot(snap)
}
}

Expand Down Expand Up @@ -362,7 +362,7 @@ func TestInternals(t *testing.T) {
},
} {
t.Run(tc.name, func(t *testing.T) {
triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(),
state := tests.MakePreState(rawdb.NewMemoryDatabase(),
core.GenesisAlloc{
to: core.GenesisAccount{
Code: tc.code,
Expand All @@ -371,9 +371,9 @@ func TestInternals(t *testing.T) {
Balance: big.NewInt(500000000000000),
},
}, false, rawdb.HashScheme)
defer triedb.Close()
defer state.Close()

evm := vm.NewEVM(context, txContext, statedb, params.MainnetChainConfig, vm.Config{Tracer: tc.tracer})
evm := vm.NewEVM(context, txContext, state.StateDB, params.MainnetChainConfig, vm.Config{Tracer: tc.tracer})
msg := &core.Message{
To: &to,
From: origin,
Expand Down
6 changes: 3 additions & 3 deletions eth/tracers/internal/tracetest/flat_calltrace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ func flatCallTracerTestRunner(tracerName string, filename string, dirPath string
Difficulty: (*big.Int)(test.Context.Difficulty),
GasLimit: uint64(test.Context.GasLimit),
}
triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
defer triedb.Close()
state := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
defer state.Close()

// Create the tracer, the EVM environment and run it
tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig)
Expand All @@ -107,7 +107,7 @@ func flatCallTracerTestRunner(tracerName string, filename string, dirPath string
if err != nil {
return fmt.Errorf("failed to prepare transaction for tracing: %v", err)
}
evm := vm.NewEVM(context, core.NewEVMTxContext(msg), statedb, test.Genesis.Config, vm.Config{Tracer: tracer})
evm := vm.NewEVM(context, core.NewEVMTxContext(msg), state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer})
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))

if _, err = st.TransitionDb(); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions eth/tracers/internal/tracetest/prestate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ func testPrestateDiffTracer(tracerName string, dirPath string, t *testing.T) {
GasLimit: uint64(test.Context.GasLimit),
BaseFee: test.Genesis.BaseFee,
}
triedb, _, statedb = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
state = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
)
defer triedb.Close()
defer state.Close()

tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig)
if err != nil {
Expand All @@ -115,7 +115,7 @@ func testPrestateDiffTracer(tracerName string, dirPath string, t *testing.T) {
if err != nil {
t.Fatalf("failed to prepare transaction for tracing: %v", err)
}
evm := vm.NewEVM(context, core.NewEVMTxContext(msg), statedb, test.Genesis.Config, vm.Config{Tracer: tracer})
evm := vm.NewEVM(context, core.NewEVMTxContext(msg), state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer})
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
if _, err = st.TransitionDb(); err != nil {
t.Fatalf("failed to execute transaction: %v", err)
Expand Down
10 changes: 5 additions & 5 deletions eth/tracers/tracers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ func BenchmarkTransactionTrace(b *testing.B) {
Code: []byte{},
Balance: big.NewInt(500000000000000),
}
triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), alloc, false, rawdb.HashScheme)
defer triedb.Close()
state := tests.MakePreState(rawdb.NewMemoryDatabase(), alloc, false, rawdb.HashScheme)
defer state.Close()

// Create the tracer, the EVM environment and run it
tracer := logger.NewStructLogger(&logger.Config{
Expand All @@ -89,7 +89,7 @@ func BenchmarkTransactionTrace(b *testing.B) {
//EnableMemory: false,
//EnableReturnData: false,
})
evm := vm.NewEVM(context, txContext, statedb, params.AllEthashProtocolChanges, vm.Config{Tracer: tracer})
evm := vm.NewEVM(context, txContext, state.StateDB, params.AllEthashProtocolChanges, vm.Config{Tracer: tracer})
msg, err := core.TransactionToMessage(tx, signer, context.BaseFee)
if err != nil {
b.Fatalf("failed to prepare transaction for tracing: %v", err)
Expand All @@ -98,13 +98,13 @@ func BenchmarkTransactionTrace(b *testing.B) {
b.ReportAllocs()

for i := 0; i < b.N; i++ {
snap := statedb.Snapshot()
snap := state.StateDB.Snapshot()
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
_, err = st.TransitionDb()
if err != nil {
b.Fatal(err)
}
statedb.RevertToSnapshot(snap)
state.StateDB.RevertToSnapshot(snap)
if have, want := len(tracer.StructLogs()), 244752; have != want {
b.Fatalf("trace wrong, want %d steps, have %d", want, have)
}
Expand Down
32 changes: 15 additions & 17 deletions tests/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ import (

"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/tracers/logger"
Expand Down Expand Up @@ -82,7 +80,7 @@ func TestState(t *testing.T) {
t.Run(key+"/hash/trie", func(t *testing.T) {
withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
var result error
test.Run(subtest, vmconfig, false, rawdb.HashScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) {
test.Run(subtest, vmconfig, false, rawdb.HashScheme, func(err error, state *StateTestState) {
result = st.checkFailure(t, err)
})
return result
Expand All @@ -91,9 +89,9 @@ func TestState(t *testing.T) {
t.Run(key+"/hash/snap", func(t *testing.T) {
withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
var result error
test.Run(subtest, vmconfig, true, rawdb.HashScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) {
if snaps != nil && state != nil {
if _, err := snaps.Journal(state.IntermediateRoot(false)); err != nil {
test.Run(subtest, vmconfig, true, rawdb.HashScheme, func(err error, state *StateTestState) {
if state.Snapshots != nil && state.StateDB != nil {
if _, err := state.Snapshots.Journal(state.StateDB.IntermediateRoot(false)); err != nil {
result = err
return
}
Expand All @@ -106,7 +104,7 @@ func TestState(t *testing.T) {
t.Run(key+"/path/trie", func(t *testing.T) {
withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
var result error
test.Run(subtest, vmconfig, false, rawdb.PathScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) {
test.Run(subtest, vmconfig, false, rawdb.PathScheme, func(err error, state *StateTestState) {
result = st.checkFailure(t, err)
})
return result
Expand All @@ -115,9 +113,9 @@ func TestState(t *testing.T) {
t.Run(key+"/path/snap", func(t *testing.T) {
withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
var result error
test.Run(subtest, vmconfig, true, rawdb.PathScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) {
if snaps != nil && state != nil {
if _, err := snaps.Journal(state.IntermediateRoot(false)); err != nil {
test.Run(subtest, vmconfig, true, rawdb.PathScheme, func(err error, state *StateTestState) {
if state.Snapshots != nil && state.StateDB != nil {
if _, err := state.Snapshots.Journal(state.StateDB.IntermediateRoot(false)); err != nil {
result = err
return
}
Expand Down Expand Up @@ -222,8 +220,8 @@ func runBenchmark(b *testing.B, t *StateTest) {

vmconfig.ExtraEips = eips
block := t.genesis(config).ToBlock()
triedb, _, statedb := MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, false, rawdb.HashScheme)
defer triedb.Close()
state := MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, false, rawdb.HashScheme)
defer state.Close()

var baseFee *big.Int
if rules.IsLondon {
Expand Down Expand Up @@ -261,7 +259,7 @@ func runBenchmark(b *testing.B, t *StateTest) {
context := core.NewEVMBlockContext(block.Header(), nil, &t.json.Env.Coinbase)
context.GetHash = vmTestBlockHash
context.BaseFee = baseFee
evm := vm.NewEVM(context, txContext, statedb, config, vmconfig)
evm := vm.NewEVM(context, txContext, state.StateDB, config, vmconfig)

// Create "contract" for sender to cache code analysis.
sender := vm.NewContract(vm.AccountRef(msg.From), vm.AccountRef(msg.From),
Expand All @@ -274,8 +272,8 @@ func runBenchmark(b *testing.B, t *StateTest) {
)
b.ResetTimer()
for n := 0; n < b.N; n++ {
snapshot := statedb.Snapshot()
statedb.Prepare(rules, msg.From, context.Coinbase, msg.To, vm.ActivePrecompiles(rules), msg.AccessList)
snapshot := state.StateDB.Snapshot()
state.StateDB.Prepare(rules, msg.From, context.Coinbase, msg.To, vm.ActivePrecompiles(rules), msg.AccessList)
b.StartTimer()
start := time.Now()

Expand All @@ -288,10 +286,10 @@ func runBenchmark(b *testing.B, t *StateTest) {

b.StopTimer()
elapsed += uint64(time.Since(start))
refund += statedb.GetRefund()
refund += state.StateDB.GetRefund()
gasUsed += msg.GasLimit - leftOverGas

statedb.RevertToSnapshot(snapshot)
state.StateDB.RevertToSnapshot(snapshot)
}
if elapsed < 1 {
elapsed = 1
Expand Down
Loading

0 comments on commit 8321fe2

Please sign in to comment.