Skip to content

Commit

Permalink
refactor and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark committed Sep 23, 2024
1 parent 0cc631a commit 9397fd8
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 9 deletions.
21 changes: 19 additions & 2 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
BitVec,
BitVecRef,
BoolVal,
CheckSatResult,
Concat,
Extract,
Function,
Expand Down Expand Up @@ -995,9 +996,18 @@ def dump(self, print_mem=False) -> str:
def advance_pc(self) -> None:
self.pc = self.pgm.next_pc(self.pc)

def check(self, cond: Any) -> Any:
cond = simplify(cond)
def quick_custom_check(self, cond: BitVecRef) -> CheckSatResult | None:
"""
Quick custom checker for specific known patterns.
This method checks for certain common conditions that can be evaluated
quickly without invoking the full SMT solver.
Returns:
sat if the condition is satisfiable
unsat if the condition is unsatisfiable
None if the condition requires full SMT solving
"""
if is_true(cond):
return sat

Expand All @@ -1008,6 +1018,13 @@ def check(self, cond: Any) -> Any:
if match_dynamic_array_overflow_condition(cond):
return unsat

def check(self, cond: Any) -> Any:
cond = simplify(cond)

# use quick custom checker for common patterns before falling back to SMT solver
if result := self.quick_custom_check(cond):
return result

return self.path.check(cond)

def select(
Expand Down
9 changes: 2 additions & 7 deletions src/halmos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,10 @@ def match_dynamic_array_overflow_condition(cond: BitVecRef) -> bool:
# Not(ULE(f_sha3_256(slot), offset + base))
if not (left.decl().name() == "f_sha3_256" and is_app_of(right, Z3_OP_BADD)):
return False
slot = left.arg(0)
offset, base = right.arg(0), right.arg(1)

# Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot)))
if not (base.decl().name() == "f_sha3_256" and eq(base.arg(0), slot)):
return False

# offset < 2**64
return is_bv_value(offset) and offset.as_long() < 2**64
# Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))) and offset < 2**64
return eq(left, base) and is_bv_value(offset) and offset.as_long() < 2**64


def stripped(hexstring: str) -> str:
Expand Down
56 changes: 56 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from z3 import (
ULE,
BitVec,
BitVecSort,
BitVecVal,
Function,
Not,
simplify,
)

from halmos.utils import match_dynamic_array_overflow_condition


def test_match_dynamic_array_overflow_condition():
# Create Z3 objects
f_sha3_256 = Function("f_sha3_256", BitVecSort(256), BitVecSort(256))
slot = BitVec("slot", 256)
offset = BitVecVal(1000, 256) # Less than 2**64

# Test the function
cond = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot)))
assert match_dynamic_array_overflow_condition(cond)

# Test with opposite order of addition
opposite_order_cond = Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset))
assert not match_dynamic_array_overflow_condition(opposite_order_cond)

# Test with opposite order after simplification
simplified_opposite_order_cond = simplify(
Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset))
)
assert match_dynamic_array_overflow_condition(simplified_opposite_order_cond)

# Test with offset = 2**64 - 1 (should match)
max_valid_offset = BitVecVal(2**64 - 1, 256)
max_valid_cond = Not(ULE(f_sha3_256(slot), max_valid_offset + f_sha3_256(slot)))
assert match_dynamic_array_overflow_condition(max_valid_cond)

# Test with offset >= 2**64
large_offset = BitVecVal(2**64, 256)
large_offset_cond = Not(ULE(f_sha3_256(slot), large_offset + f_sha3_256(slot)))
assert not match_dynamic_array_overflow_condition(large_offset_cond)

# Test with a different function
different_func = Function("different_func", BitVecSort(256), BitVecSort(256))
non_matching_cond = Not(ULE(different_func(slot), offset + different_func(slot)))
assert not match_dynamic_array_overflow_condition(non_matching_cond)

# Test with just ULE, not Not(ULE(...))
ule_only = ULE(f_sha3_256(slot), offset + f_sha3_256(slot))
assert not match_dynamic_array_overflow_condition(ule_only)

# Test with mismatched slots
slot2 = BitVec("slot2", 256)
mismatched_slots = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot2)))
assert not match_dynamic_array_overflow_condition(mismatched_slots)

0 comments on commit 9397fd8

Please sign in to comment.