Skip to content

Commit

Permalink
add support for prank(sender, origin) and startPrank(sender, origin) …
Browse files Browse the repository at this point in the history
…cheatcodes (#336)
  • Loading branch information
karmacoma-eth committed Aug 13, 2024
1 parent 5292bd2 commit f029418
Show file tree
Hide file tree
Showing 14 changed files with 722 additions and 143 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ jobs:
run: python -m pip install -e .

- name: Run pytest
run: pytest -n 4 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="-st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0 ${{ inputs.halmos-options }}" ${{ inputs.pytest-options }}
run: pytest -n 1 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="--debug -st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0 ${{ inputs.halmos-options }}" ${{ inputs.pytest-options }}
9 changes: 0 additions & 9 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
[submodule "tests/lib/forge-std"]
path = tests/lib/forge-std
url = https://github.com/foundry-rs/forge-std
shallow = true
[submodule "tests/lib/halmos-cheatcodes"]
path = tests/lib/halmos-cheatcodes
url = https://github.com/a16z/halmos-cheatcodes
shallow = true
[submodule "tests/lib/openzeppelin-contracts"]
path = tests/lib/openzeppelin-contracts
url = https://github.com/OpenZeppelin/openzeppelin-contracts
shallow = true
[submodule "tests/lib/solmate"]
path = tests/lib/solmate
url = https://github.com/transmissions11/solmate
shallow = true
[submodule "tests/lib/solady"]
path = tests/lib/solady
url = https://github.com/Vectorized/solady
shallow = true
[submodule "tests/lib/multicaller"]
path = tests/lib/multicaller
url = https://github.com/Vectorized/multicaller
shallow = true
2 changes: 1 addition & 1 deletion examples/simple/remappings.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
multicaller/=../../tests/lib/multicaller/src/
multicaller/=src/multicaller/
151 changes: 151 additions & 0 deletions examples/simple/src/multicaller/MulticallerWithSender.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;

/// from Vectorized/multicaller@v1.3.2

/**
* @title MulticallerWithSender
* @author vectorized.eth
* @notice Contract that allows for efficient aggregation of multiple calls
* in a single transaction, while "forwarding" the `msg.sender`.
*/
contract MulticallerWithSender {
// =============================================================
// ERRORS
// =============================================================

/**
* @dev The lengths of the input arrays are not the same.
*/
error ArrayLengthsMismatch();

/**
* @dev This function does not support reentrancy.
*/
error Reentrancy();

// =============================================================
// CONSTRUCTOR
// =============================================================

constructor() payable {
assembly {
// Throughout this code, we will abuse returndatasize
// in place of zero anywhere before a call to save a bit of gas.
// We will use storage slot zero to store the caller at
// bits [0..159] and reentrancy guard flag at bit 160.
sstore(returndatasize(), shl(160, 1))
}
}

// =============================================================
// AGGREGATION OPERATIONS
// =============================================================

/**
* @dev Returns the address that called `aggregateWithSender` on this contract.
* The value is always the zero address outside a transaction.
*/
receive() external payable {
assembly {
mstore(returndatasize(), and(sub(shl(160, 1), 1), sload(returndatasize())))
return(returndatasize(), 0x20)
}
}

/**
* @dev Aggregates multiple calls in a single transaction.
* This method will set `sender` to the `msg.sender` temporarily
* for the span of its execution.
* This method does not support reentrancy.
* @param targets An array of addresses to call.
* @param data An array of calldata to forward to the targets.
* @param values How much ETH to forward to each target.
* @return An array of the returndata from each call.
*/
function aggregateWithSender(
address[] calldata targets,
bytes[] calldata data,
uint256[] calldata values
) external payable returns (bytes[] memory) {
assembly {
if iszero(and(eq(targets.length, data.length), eq(data.length, values.length))) {
// Store the function selector of `ArrayLengthsMismatch()`.
mstore(returndatasize(), 0x3b800a46)
// Revert with (offset, size).
revert(0x1c, 0x04)
}

if iszero(and(sload(returndatasize()), shl(160, 1))) {
// Store the function selector of `Reentrancy()`.
mstore(returndatasize(), 0xab143c06)
// Revert with (offset, size).
revert(0x1c, 0x04)
}

mstore(returndatasize(), 0x20) // Store the memory offset of the `results`.
mstore(0x20, data.length) // Store `data.length` into `results`.
// Early return if no data.
if iszero(data.length) { return(returndatasize(), 0x40) }

// Set the sender slot temporarily for the span of this transaction.
sstore(returndatasize(), caller())

let results := 0x40
// Left shift by 5 is equivalent to multiplying by 0x20.
data.length := shl(5, data.length)
// Copy the offsets from calldata into memory.
calldatacopy(results, data.offset, data.length)
// Offset into `results`.
let resultsOffset := data.length
// Pointer to the end of `results`.
// Recycle `data.length` to avoid stack too deep.
data.length := add(results, data.length)

for {} 1 {} {
// The offset of the current bytes in the calldata.
let o := add(data.offset, mload(results))
let memPtr := add(resultsOffset, 0x40)
// Copy the current bytes from calldata to the memory.
calldatacopy(
memPtr,
add(o, 0x20), // The offset of the current bytes' bytes.
calldataload(o) // The length of the current bytes.
)
if iszero(
call(
gas(), // Remaining gas.
calldataload(targets.offset), // Address to call.
calldataload(values.offset), // ETH to send.
memPtr, // Start of input calldata in memory.
calldataload(o), // Size of input calldata.
0x00, // We will use returndatacopy instead.
0x00 // We will use returndatacopy instead.
)
) {
// Bubble up the revert if the call reverts.
returndatacopy(0x00, 0x00, returndatasize())
revert(0x00, returndatasize())
}
// Advance the `targets.offset`.
targets.offset := add(targets.offset, 0x20)
// Advance the `values.offset`.
values.offset := add(values.offset, 0x20)
// Append the current `resultsOffset` into `results`.
mstore(results, resultsOffset)
results := add(results, 0x20)
// Append the returndatasize, and the returndata.
mstore(memPtr, returndatasize())
returndatacopy(add(memPtr, 0x20), 0x00, returndatasize())
// Advance the `resultsOffset` by `returndatasize() + 0x20`,
// rounded up to the next multiple of 0x20.
resultsOffset := and(add(add(resultsOffset, returndatasize()), 0x3f), not(0x1f))
if iszero(lt(results, data.length)) { break }
}
// Restore the `sender` slot.
sstore(0, shl(160, 1))
// Direct return.
return(0x00, add(resultsOffset, 0x40))
}
}
}
34 changes: 13 additions & 21 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,7 @@ def mk_addr(name: str) -> Address:


def mk_caller(args: HalmosConfig) -> Address:
if args.symbolic_msg_sender:
return mk_addr("msg_sender")
else:
return magic_address
return mk_addr("msg_sender") if args.symbolic_msg_sender else magic_address


def mk_this() -> Address:
Expand Down Expand Up @@ -347,27 +344,24 @@ def render_trace(context: CallContext, file=sys.stdout) -> None:

def run_bytecode(hexcode: str, args: HalmosConfig) -> List[Exec]:
solver = mk_solver(args)
contract = Contract.from_hexcode(hexcode)
balance = mk_balance()
block = mk_block()
this = mk_this()

message = Message(
target=this,
caller=mk_caller(args),
origin=mk_addr("tx_origin"),
value=mk_callvalue(),
data=ByteVec(),
call_scheme=EVM.CALL,
)

contract = Contract.from_hexcode(hexcode)
sevm = SEVM(args)
ex = sevm.mk_exec(
code={this: contract},
storage={this: {}},
balance=balance,
block=block,
balance=mk_balance(),
block=mk_block(),
context=CallContext(message=message),
this=this,
pgm=contract,
symbolic=args.symbolic_storage,
path=Path(solver),
Expand All @@ -377,7 +371,6 @@ def run_bytecode(hexcode: str, args: HalmosConfig) -> List[Exec]:

for idx, ex in enumerate(exs):
result_exs.append(ex)

opcode = ex.current_opcode()
error = ex.context.output.error
returndata = ex.context.output.data
Expand Down Expand Up @@ -414,6 +407,7 @@ def deploy_test(
message = Message(
target=this,
caller=mk_caller(args),
origin=mk_addr("tx_origin"),
value=0,
data=ByteVec(),
call_scheme=EVM.CREATE,
Expand All @@ -425,7 +419,6 @@ def deploy_test(
balance=mk_balance(),
block=mk_block(),
context=CallContext(message=message),
this=this,
pgm=None, # to be added
symbolic=False,
path=Path(mk_solver(args)),
Expand Down Expand Up @@ -474,7 +467,6 @@ def deploy_test(
ex.st = State()
ex.context.output = CallOutput()
ex.jumpis = {}
ex.prank = Prank()

return ex

Expand Down Expand Up @@ -503,18 +495,19 @@ def setup(
dyn_param_size = [] # TODO: propagate to run
mk_calldata(abi, setup_info, calldata, dyn_param_size, args)

parent_message = setup_ex.message()
setup_ex.context = CallContext(
message=Message(
target=setup_ex.message().target,
caller=setup_ex.message().caller,
target=parent_message.target,
caller=parent_message.caller,
origin=parent_message.origin,
value=0,
data=calldata,
call_scheme=EVM.CALL,
),
)

setup_exs_all = sevm.run(setup_ex)

setup_exs_no_error = []

for idx, setup_ex in enumerate(setup_exs_all):
Expand Down Expand Up @@ -651,8 +644,9 @@ def run(
mk_calldata(abi, fun_info, cd, dyn_param_size, args)

message = Message(
target=setup_ex.this,
target=setup_ex.this(),
caller=setup_ex.caller(),
origin=setup_ex.origin(),
value=0,
data=cd,
call_scheme=EVM.CALL,
Expand Down Expand Up @@ -680,14 +674,12 @@ def run(
#
context=CallContext(message=message),
callback=None,
this=setup_ex.this,
#
pgm=setup_ex.code[setup_ex.this],
pgm=setup_ex.code[setup_ex.this()],
pc=0,
st=State(),
jumpis={},
symbolic=args.symbolic_storage,
prank=Prank(), # prank is reset after setUp()
#
path=path,
alias=setup_ex.alias.copy(),
Expand Down
Loading

0 comments on commit f029418

Please sign in to comment.