|
1 | 1 | package contract_test
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "errors" |
4 | 5 | "testing"
|
5 | 6 |
|
6 | 7 | "github.com/ethereum/go-ethereum/common"
|
7 | 8 | "github.com/ethereum/go-ethereum/core/state"
|
8 | 9 | "github.com/ethereum/go-ethereum/precompile/contract"
|
| 10 | + "github.com/stretchr/testify/assert" |
9 | 11 | "github.com/stretchr/testify/require"
|
10 | 12 | )
|
11 | 13 |
|
@@ -144,3 +146,68 @@ func TestPrecompileInvalidCalls(t *testing.T) {
|
144 | 146 | })
|
145 | 147 | }
|
146 | 148 | }
|
| 149 | + |
| 150 | +type stateDBWithExt struct { |
| 151 | + contract.StateDB |
| 152 | + |
| 153 | + customVal []byte |
| 154 | +} |
| 155 | + |
| 156 | +func newStateDBWithExt(t *testing.T, customVal []byte) *stateDBWithExt { |
| 157 | + return &stateDBWithExt{ |
| 158 | + StateDB: state.NewTestStateDB(t), |
| 159 | + customVal: customVal, |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +func (s *stateDBWithExt) CustomValueExt() []byte { |
| 164 | + return s.customVal |
| 165 | +} |
| 166 | + |
| 167 | +func TestPrecompileCustomStateDBExtension(t *testing.T) { |
| 168 | + unsupportedErr := errors.New("unsupported statedb") |
| 169 | + |
| 170 | + // test function that asserts a concrete type on the stateDB, |
| 171 | + // and returns the return value of the extended type |
| 172 | + testFunc := func( |
| 173 | + accessibleState contract.AccessibleState, |
| 174 | + caller common.Address, |
| 175 | + addr common.Address, |
| 176 | + input []byte, |
| 177 | + suppliedGas uint64, |
| 178 | + readOnly bool, |
| 179 | + ) ([]byte, uint64, error) { |
| 180 | + customStateDB, ok := accessibleState.GetStateDB().(*stateDBWithExt) |
| 181 | + if !ok { |
| 182 | + return nil, suppliedGas, unsupportedErr |
| 183 | + } |
| 184 | + |
| 185 | + return customStateDB.CustomValueExt(), 0, nil |
| 186 | + } |
| 187 | + |
| 188 | + // create the precompiled contract with the testFunc |
| 189 | + funcSelector := contract.MustCalculateFunctionSelector("getStateDBCustomValue()") |
| 190 | + functions := []*contract.StatefulPrecompileFunction{ |
| 191 | + contract.NewStatefulPrecompileFunction(funcSelector, testFunc), |
| 192 | + } |
| 193 | + precompiledContract, err := contract.NewStatefulPrecompileContract(functions) |
| 194 | + require.NoError(t, err) |
| 195 | + |
| 196 | + // build the accessible state with the extended statedb |
| 197 | + customVal := []byte("value from statedb extension") |
| 198 | + accessibleState := newAccessibleState(newStateDBWithExt(t, customVal)) |
| 199 | + |
| 200 | + // run and see the returned value from calling CustomValueExt() on the injected statedb |
| 201 | + retVal, remainingGas, err := precompiledContract.Run(accessibleState, callerAddr, contractAddr, funcSelector, 1, true) |
| 202 | + require.NoError(t, err) |
| 203 | + assert.Equal(t, uint64(0), remainingGas) |
| 204 | + assert.Equal(t, customVal, retVal, "expected contract to access and return customVal from extended statedb") |
| 205 | + |
| 206 | + // test contract with unsupported statedb |
| 207 | + accessibleState = newAccessibleState(state.NewTestStateDB(t)) |
| 208 | + retVal, remainingGas, err = precompiledContract.Run(accessibleState, callerAddr, contractAddr, funcSelector, 1, true) |
| 209 | + require.Error(t, err) |
| 210 | + assert.Equal(t, unsupportedErr, err) |
| 211 | + assert.Equal(t, uint64(1), remainingGas) |
| 212 | + assert.Nil(t, retVal) |
| 213 | +} |
0 commit comments