Skip to content

Commit

Permalink
feat: support CREATE2 (#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark authored Aug 8, 2023
1 parent 1bafa77 commit c5a3f0d
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def run(
#
log=deepcopy(setup_ex.log),
cnts=deepcopy(setup_ex.cnts),
sha3s=deepcopy(setup_ex.sha3s),
sha3s=deepcopy(setup_ex.sha3s), # TODO: shallow copy
storages=deepcopy(setup_ex.storages),
balances=deepcopy(setup_ex.balances),
calls=deepcopy(setup_ex.calls),
Expand Down
78 changes: 51 additions & 27 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@

magic_address: int = 0xAAAA0000

create2_magic_address: int = 0xBBBB0000

new_address_offset: int = 1


Expand Down Expand Up @@ -626,7 +628,7 @@ class Exec: # an execution path
# logs
log: List[Tuple[List[Word], Any]] # event logs emitted
cnts: Dict[str, Dict[int, int]] # opcode -> frequency; counters
sha3s: List[Tuple[Word, Word]] # sha3 hashes generated
sha3s: Dict[Word, int] # sha3 hashes generated
storages: Dict[Any, Any] # storage updates
balances: Dict[Any, Any] # balance updates
calls: List[Any] # external calls
Expand Down Expand Up @@ -728,7 +730,7 @@ def __str__(self) -> str:
)
),
f"SHA3 hashes:\n",
"".join(map(lambda x: f"- {x}\n", self.sha3s)),
"".join(map(lambda x: f"- {self.sha3s[x]}: {x}\n", self.sha3s)),
f"External calls:\n",
"".join(map(lambda x: f"- {x}\n", self.calls)),
# f"Calldata: {self.calldata}\n",
Expand Down Expand Up @@ -919,43 +921,50 @@ def normalize(expr: Any) -> Any:
def sha3(self) -> None:
loc: int = self.st.mloc()
size: int = int_of(self.st.pop(), "symbolic SHA3 data size")
self.sha3_data(wload(self.st.memory, loc, size), size)
self.st.push(self.sha3_data(wload(self.st.memory, loc, size), size))

def sha3_data(self, data: Bytes, size: int) -> None:
def sha3_data(self, data: Bytes, size: int) -> Word:
f_sha3 = Function(
"sha3_" + str(size * 8), BitVecSort(size * 8), BitVecSort(256)
)
sha3_expr = f_sha3(data)
sha3_output = BitVec(f"sha3_output_{len(self.sha3s):>02}", 256)
self.solver.add(sha3_output == sha3_expr)

# assume hash values are sufficiently smaller than the uint max
self.solver.add(ULE(sha3_output, con(2**256 - 2**64)))
self.assume_sha3_distinct(sha3_output, sha3_expr)
if size == 64 or size == 32: # for storage hashed location
self.st.push(sha3_expr)
self.solver.add(ULE(sha3_expr, con(2**256 - 2**64)))

# assume no hash collision
self.assume_sha3_distinct(sha3_expr)

# handle create2 hash
if size == 85 and eq(extract_bytes(data, 0, 1), con(0xFF, 8)):
return con(create2_magic_address + self.sha3s[sha3_expr])
else:
self.st.push(sha3_output)
return sha3_expr

def assume_sha3_distinct(self, sha3_expr) -> None:
# skip if already exist
if sha3_expr in self.sha3s:
return

def assume_sha3_distinct(self, sha3_output, sha3_expr) -> None:
# we expect sha3_expr to be `sha3_<input-bitsize>(input_expr)`
sha3_decl_name = sha3_expr.decl().name()

for prev_sha3_output, prev_sha3_expr in self.sha3s:
for prev_sha3_expr in self.sha3s:
if prev_sha3_expr.decl().name() == sha3_decl_name:
# inputs have the same size: assume different inputs
# lead to different outputs
self.solver.add(
Implies(
sha3_expr.arg(0) != prev_sha3_expr.arg(0),
sha3_output != prev_sha3_output,
sha3_expr != prev_sha3_expr,
)
)
else:
# inputs have different sizes: assume the outputs are different
self.solver.add(sha3_output != prev_sha3_output)
self.solver.add(sha3_expr != prev_sha3_expr)

self.solver.add(sha3_output != con(0))
self.sha3s.append((sha3_output, sha3_expr))
self.solver.add(sha3_expr != con(0))
self.sha3s[sha3_expr] = len(self.sha3s)

def new_gas_id(self) -> int:
self.cnts["fresh"]["gas"] += 1
Expand Down Expand Up @@ -1738,6 +1747,7 @@ def call_unknown() -> None:
def create(
self,
ex: Exec,
op: int,
stack: List[Tuple[Exec, int]],
step_id: int,
out: List[Exec],
Expand All @@ -1747,15 +1757,32 @@ def create(
loc: int = int_of(ex.st.pop(), "symbolic CREATE offset")
size: int = int_of(ex.st.pop(), "symbolic CREATE size")

if op == EVM.CREATE2:
salt = ex.st.pop()

# lookup prank
caller = ex.prank.lookup(ex.this, con_addr(0))

# contract creation code
create_hexcode = wload(ex.st.memory, loc, size, prefer_concrete=True)
create_code = Contract(create_hexcode)

# new account address
new_addr = ex.new_address()
if op == EVM.CREATE:
new_addr = ex.new_address()
else: # EVM.CREATE2
if isinstance(create_hexcode, bytes):
create_hexcode = con(
int.from_bytes(create_hexcode, "big"), len(create_hexcode) * 8
)
code_hash = ex.sha3_data(create_hexcode, create_hexcode.size() // 8)
hash_data = simplify(Concat(con(0xFF, 8), caller, salt, code_hash))
new_addr = uint160(ex.sha3_data(hash_data, 85))

if new_addr in ex.code:
raise ValueError(f"existing address: {new_addr}")
ex.error = f"existing address: {hexify(new_addr)}"
out.append(ex)
return

for addr in ex.code:
ex.solver.add(new_addr != addr) # ensure new address is fresh
Expand All @@ -1764,9 +1791,6 @@ def create(
ex.code[new_addr] = Contract(b"") # existing code must be empty
ex.storage[new_addr] = {} # existing storage may not be empty and reset here

# lookup prank
caller = ex.prank.lookup(ex.this, new_addr)

# transfer value
# assume balance is enough; otherwise ignore this path
ex.solver.add(UGE(ex.balance_of(caller), value))
Expand Down Expand Up @@ -1975,7 +1999,7 @@ def create_branch(self, ex: Exec, cond: BitVecRef, target: int) -> Exec:
#
log=deepcopy(ex.log),
cnts=deepcopy(ex.cnts),
sha3s=deepcopy(ex.sha3s),
sha3s=deepcopy(ex.sha3s), # TODO: shallow copy
storages=deepcopy(ex.storages),
balances=deepcopy(ex.balances),
calls=deepcopy(ex.calls),
Expand Down Expand Up @@ -2229,8 +2253,8 @@ def run(self, ex0: Exec) -> Tuple[List[Exec], Steps]:
elif opcode == EVM.SHA3:
ex.sha3()

elif opcode == EVM.CREATE:
self.create(ex, stack, step_id, out, bounded_loops)
elif opcode in [EVM.CREATE, EVM.CREATE2]:
self.create(ex, opcode, stack, step_id, out, bounded_loops)
continue

elif opcode == EVM.POP:
Expand Down Expand Up @@ -2349,7 +2373,7 @@ def run(self, ex0: Exec) -> Tuple[List[Exec], Steps]:
val = int_of(insn.operand)
if opcode == EVM.PUSH32 and val in sha3_inv:
# restore precomputed hashes
ex.sha3_data(con(sha3_inv[val]), 32)
ex.st.push(ex.sha3_data(con(sha3_inv[val]), 32))
else:
ex.st.push(con(val))
else:
Expand Down Expand Up @@ -2444,7 +2468,7 @@ def mk_exec(
#
log=[],
cnts=defaultdict(lambda: defaultdict(int)),
sha3s=[],
sha3s={},
storages={},
balances={},
calls=[],
Expand Down
58 changes: 58 additions & 0 deletions tests/expected/all.json
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,64 @@
"num_bounded_loops": null
}
],
"test/Create2.t.sol:Create2Test": [
{
"name": "check_create2(uint256,uint256,bytes32)",
"exitcode": 0,
"num_models": 0,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_create2_caller(address,uint256,uint256,bytes32)",
"exitcode": 0,
"num_models": 0,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_create2_collision(uint256,uint256,bytes32)",
"exitcode": 1,
"num_models": 0,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_create2_collision_alias(uint256,uint256,bytes32)",
"exitcode": 1,
"num_models": 1,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_create2_concrete()",
"exitcode": 0,
"num_models": 0,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_create2_no_collision_1(uint256,uint256,bytes32,bytes32)",
"exitcode": 0,
"num_models": 0,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_create2_no_collision_2(uint256,uint256,bytes32)",
"exitcode": 0,
"num_models": 0,
"num_paths": null,
"time": null,
"num_bounded_loops": null
}
],
"test/Deal.t.sol:DealTest": [
{
"name": "check_deal_1(address,uint256)",
Expand Down
107 changes: 107 additions & 0 deletions tests/test/Create2.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// SPDX-License-Identifier: AGPL-3.0
pragma solidity >=0.8.0 <0.9.0;

import "forge-std/Test.sol";

contract C {
uint public num1;
uint public num2;

constructor(uint x, uint y) {
set(x, y);
}

function set(uint x, uint y) public {
num1 = x;
num2 = y;
}
}

contract Create2Test is Test {
function check_create2(uint x, uint y, bytes32 salt) public {
C c1 = new C{salt: salt}(x, y);

bytes32 codeHash = keccak256(abi.encodePacked(type(C).creationCode, abi.encode(x, y)));
bytes32 hash = keccak256(abi.encodePacked(bytes1(0xff), address(this), salt, codeHash));
address c2 = address(uint160(uint(hash)));

assert(address(c1) == c2);

assert(C(c2).num1() == x);
assert(C(c2).num2() == y);

c1.set(y, x);

assert(C(c2).num1() == y);
assert(C(c2).num2() == x);
}

function check_create2_caller(address caller, uint x, uint y, bytes32 salt) public {
vm.prank(caller);
C c1 = new C{salt: salt}(x, y);

bytes32 codeHash = keccak256(abi.encodePacked(type(C).creationCode, abi.encode(x, y)));
bytes32 hash = keccak256(abi.encodePacked(bytes1(0xff), caller, salt, codeHash));
address c2 = address(uint160(uint(hash)));

assert(address(c1) == c2);

assert(C(c2).num1() == x);
assert(C(c2).num2() == y);

c1.set(y, x);

assert(C(c2).num1() == y);
assert(C(c2).num2() == x);
}

function check_create2_concrete() public {
uint x = 1;
uint y = 2;
bytes32 salt = bytes32(uint(3));

C c1 = new C{salt: salt}(x, y);

bytes32 codeHash = keccak256(abi.encodePacked(type(C).creationCode, abi.encode(x, y)));
bytes32 hash = keccak256(abi.encodePacked(bytes1(0xff), address(this), salt, codeHash));
address c2 = address(uint160(uint(hash)));

assert(address(c1) == c2);

assert(C(c2).num1() == x);
assert(C(c2).num2() == y);

c1.set(y, x);

assert(C(c2).num1() == y);
assert(C(c2).num2() == x);
}

function check_create2_collision(uint x, uint y, bytes32 salt) public {
C c1 = new C{salt: salt}(x, y);
C c2 = new C{salt: salt}(x, y); // expected to fail
assert(c1 == c2); // deadcode
}

function check_create2_no_collision_1(uint x, uint y, bytes32 salt1, bytes32 salt2) public {
C c1 = new C{salt: salt1}(x, y);
C c2 = new C{salt: salt2}(x, y);
assert(c1 != c2);
}

function check_create2_no_collision_2(uint x, uint y, bytes32 salt) public {
vm.assume(x != y);

C c1 = new C{salt: salt}(x, y);
C c2 = new C{salt: salt}(y, x);
assert(c1 != c2);
}

function check_create2_collision_alias(uint x, uint y, bytes32 salt) public {
vm.assume(x == y);

C c1 = new C{salt: salt}(x, y);
C c2 = new C{salt: salt}(y, x);
assert(c1 == c2); // currently fail // TODO: support symbolic alias for hash
}
}

0 comments on commit c5a3f0d

Please sign in to comment.