Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: manually solve dynamic array overflow conditions #366

Merged
merged 7 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
BitVec,
BitVecRef,
BoolVal,
CheckSatResult,
Concat,
Extract,
Function,
Expand Down Expand Up @@ -94,14 +95,19 @@
debug,
extract_bytes,
f_ecrecover,
f_sha3_256_name,
f_sha3_512_name,
f_sha3_name,
hexify,
int_of,
is_bool,
is_bv,
is_bv_value,
is_concrete,
is_f_sha3_name,
is_non_zero,
is_zero,
match_dynamic_array_overflow_condition,
restore_precomputed_hashes,
sha3_inv,
str_opcode,
Expand Down Expand Up @@ -994,15 +1000,35 @@ def dump(self, print_mem=False) -> str:
def advance_pc(self) -> None:
self.pc = self.pgm.next_pc(self.pc)

def check(self, cond: Any) -> Any:
cond = simplify(cond)
def quick_custom_check(self, cond: BitVecRef) -> CheckSatResult | None:
daejunpark marked this conversation as resolved.
Show resolved Hide resolved
"""
Quick custom checker for specific known patterns.

This method checks for certain common conditions that can be evaluated
quickly without invoking the full SMT solver.

Returns:
sat if the condition is satisfiable
unsat if the condition is unsatisfiable
None if the condition requires full SMT solving
"""
if is_true(cond):
return sat

if is_false(cond):
return unsat

# Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64
if match_dynamic_array_overflow_condition(cond):
return unsat

def check(self, cond: Any) -> Any:
cond = simplify(cond)

# use quick custom checker for common patterns before falling back to SMT solver
if result := self.quick_custom_check(cond):
return result

return self.path.check(cond)

def select(
Expand Down Expand Up @@ -1063,7 +1089,7 @@ def sha3_data(self, data: Bytes) -> Word:
data = bytes_to_bv_value(data)

f_sha3 = Function(
f"f_sha3_{size * 8}", BitVecSorts[size * 8], BitVecSort256
f_sha3_name(size * 8), BitVecSorts[size * 8], BitVecSort256
)
sha3_expr = f_sha3(data)
else:
Expand Down Expand Up @@ -1288,17 +1314,17 @@ def get_key_structure(cls, loc) -> tuple:
def decode(cls, loc: Any) -> Any:
loc = normalize(loc)
# m[k] : hash(k.m)
if loc.decl().name() == "f_sha3_512":
if loc.decl().name() == f_sha3_512_name:
args = loc.arg(0)
offset = simplify(Extract(511, 256, args))
base = simplify(Extract(255, 0, args))
return cls.decode(base) + (offset, ZERO)
# a[i] : hash(a) + i
elif loc.decl().name() == "f_sha3_256":
elif loc.decl().name() == f_sha3_256_name:
base = loc.arg(0)
return cls.decode(base) + (ZERO,)
# m[k] : hash(k.m) where |k| != 256-bit
elif loc.decl().name().startswith("f_sha3_"):
elif is_f_sha3_name(loc.decl().name()):
sha3_input = normalize(loc.arg(0))
if sha3_input.decl().name() == "concat" and sha3_input.num_args() == 2:
offset = simplify(sha3_input.arg(0))
Expand Down Expand Up @@ -1417,12 +1443,12 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:
@classmethod
def decode(cls, loc: Any) -> Any:
loc = normalize(loc)
if loc.decl().name() == "f_sha3_512": # hash(hi,lo), recursively
if loc.decl().name() == f_sha3_512_name: # hash(hi,lo), recursively
args = loc.arg(0)
hi = cls.decode(simplify(Extract(511, 256, args)))
lo = cls.decode(simplify(Extract(255, 0, args)))
return cls.simple_hash(Concat(hi, lo))
elif loc.decl().name().startswith("f_sha3_"):
elif is_f_sha3_name(loc.decl().name()):
sha3_input = normalize(loc.arg(0))
if sha3_input.decl().name() == "concat":
decoded_sha3_input_args = [
Expand Down Expand Up @@ -2359,6 +2385,12 @@ def jumpi(
follow_false = visited[False] < self.options.loop
if not (follow_true and follow_false):
self.logs.bounded_loops.append(jid)
if self.options.debug:
debug(f"\nloop id: {jid}")
debug(f"loop condition: {cond}")
debug(f"calldata: {ex.calldata()}")
debug("path condition:")
debug(ex.path)
else:
# for constant-bounded loops
follow_true = potential_true
Expand Down
50 changes: 50 additions & 0 deletions src/halmos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from typing import Any

from z3 import (
Z3_OP_BADD,
Z3_OP_CONCAT,
Z3_OP_ULEQ,
BitVecNumRef,
BitVecRef,
BitVecSort,
Expand All @@ -21,11 +23,13 @@
SignExt,
SolverFor,
ZeroExt,
eq,
is_app,
is_app_of,
is_bool,
is_bv,
is_bv_value,
is_not,
simplify,
)

Expand Down Expand Up @@ -94,6 +98,18 @@ def __getitem__(self, size: int) -> BitVecSort:
)


def is_f_sha3_name(name: str) -> bool:
return name.startswith("f_sha3_")


def f_sha3_name(bitsize: int) -> str:
return f"f_sha3_{bitsize}"


f_sha3_256_name = f_sha3_name(256)
f_sha3_512_name = f_sha3_name(512)


def wrap(x: Any) -> Word:
if is_bv(x):
return x
Expand Down Expand Up @@ -349,6 +365,40 @@ def byte_length(x: Any, strict=True) -> int:
raise TypeError(f"byte_length({x}) of type {type(x)}")


def match_dynamic_array_overflow_condition(cond: BitVecRef) -> bool:
"""
Check if `cond` matches the following pattern:
Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64

This condition is satisfied when a dynamic array at `slot` exceeds the storage limit.
Since such an overflow is highly unlikely in practice, we assume that this condition is unsat.

Note: we already assume that any sha3 hash output is smaller than 2**256 - 2**64 (see SEVM.sha3_data()).
However, the smt solver may not be able to solve this condition within the branching timeout.
In such cases, this explicit pattern serves as a fallback to avoid exploring practically infeasible paths.

We don't need to handle the negation of this condition, because unknown conditions are conservatively assumed to be sat.
"""

# Not(ule)
if not is_not(cond):
return False
ule = cond.arg(0)

# Not(ULE(left, right)
if not is_app_of(ule, Z3_OP_ULEQ):
return False
left, right = ule.arg(0), ule.arg(1)

# Not(ULE(f_sha3_N(slot), offset + base))
if not (is_f_sha3_name(left.decl().name()) and is_app_of(right, Z3_OP_BADD)):
return False
offset, base = right.arg(0), right.arg(1)

# Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))) and offset < 2**64
return eq(left, base) and is_bv_value(offset) and offset.as_long() < 2**64


def stripped(hexstring: str) -> str:
"""Remove 0x prefix from hexstring"""
return hexstring[2:] if hexstring.startswith("0x") else hexstring
Expand Down
11 changes: 11 additions & 0 deletions tests/expected/all.json
Original file line number Diff line number Diff line change
Expand Up @@ -2365,6 +2365,17 @@
"num_bounded_loops": null
}
],
"test/Solver.t.sol:SolverTest": [
{
"name": "check_dynamic_array_overflow()",
"exitcode": 0,
"num_models": 0,
"models": null,
"num_paths": null,
"time": null,
"num_bounded_loops": null
}
],
"test/StaticContexts.t.sol:StaticContextsTest": [
{
"name": "check_create2_fails()",
Expand Down
13 changes: 13 additions & 0 deletions tests/regression/test/Solver.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-License-Identifier: AGPL-3.0
pragma solidity >=0.8.0 <0.9.0;

import "forge-std/Test.sol";
import {SymTest} from "halmos-cheatcodes/SymTest.sol";

contract SolverTest is SymTest, Test {
uint[] numbers;

function check_dynamic_array_overflow() public {
numbers = new uint[](5); // shouldn't generate loop bounds warning
}
}
56 changes: 56 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from z3 import (
ULE,
BitVec,
BitVecSort,
BitVecVal,
Function,
Not,
simplify,
)

from halmos.utils import f_sha3_256_name, match_dynamic_array_overflow_condition


def test_match_dynamic_array_overflow_condition():
# Create Z3 objects
f_sha3_256 = Function(f_sha3_256_name, BitVecSort(256), BitVecSort(256))
slot = BitVec("slot", 256)
offset = BitVecVal(1000, 256) # Less than 2**64

# Test the function
cond = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot)))
assert match_dynamic_array_overflow_condition(cond)

# Test with opposite order of addition
opposite_order_cond = Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset))
assert not match_dynamic_array_overflow_condition(opposite_order_cond)

# Test with opposite order after simplification
simplified_opposite_order_cond = simplify(
Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset))
)
assert match_dynamic_array_overflow_condition(simplified_opposite_order_cond)

# Test with offset = 2**64 - 1 (should match)
max_valid_offset = BitVecVal(2**64 - 1, 256)
max_valid_cond = Not(ULE(f_sha3_256(slot), max_valid_offset + f_sha3_256(slot)))
assert match_dynamic_array_overflow_condition(max_valid_cond)

# Test with offset >= 2**64
large_offset = BitVecVal(2**64, 256)
large_offset_cond = Not(ULE(f_sha3_256(slot), large_offset + f_sha3_256(slot)))
assert not match_dynamic_array_overflow_condition(large_offset_cond)

# Test with a different function
different_func = Function("different_func", BitVecSort(256), BitVecSort(256))
non_matching_cond = Not(ULE(different_func(slot), offset + different_func(slot)))
assert not match_dynamic_array_overflow_condition(non_matching_cond)

# Test with just ULE, not Not(ULE(...))
ule_only = ULE(f_sha3_256(slot), offset + f_sha3_256(slot))
assert not match_dynamic_array_overflow_condition(ule_only)

# Test with mismatched slots
slot2 = BitVec("slot2", 256)
mismatched_slots = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot2)))
assert not match_dynamic_array_overflow_condition(mismatched_slots)
Loading