diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index 8c86756a..dc12753f 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -24,6 +24,7 @@ BitVec, BitVecRef, BoolVal, + CheckSatResult, Concat, Extract, Function, @@ -995,9 +996,18 @@ def dump(self, print_mem=False) -> str: def advance_pc(self) -> None: self.pc = self.pgm.next_pc(self.pc) - def check(self, cond: Any) -> Any: - cond = simplify(cond) + def quick_custom_check(self, cond: BitVecRef) -> CheckSatResult | None: + """ + Quick custom checker for specific known patterns. + This method checks for certain common conditions that can be evaluated + quickly without invoking the full SMT solver. + + Returns: + sat if the condition is satisfiable + unsat if the condition is unsatisfiable + None if the condition requires full SMT solving + """ if is_true(cond): return sat @@ -1008,6 +1018,13 @@ def check(self, cond: Any) -> Any: if match_dynamic_array_overflow_condition(cond): return unsat + def check(self, cond: Any) -> Any: + cond = simplify(cond) + + # use quick custom checker for common patterns before falling back to SMT solver + if result := self.quick_custom_check(cond): + return result + return self.path.check(cond) def select( diff --git a/src/halmos/utils.py b/src/halmos/utils.py index 43e037c3..54a54dbf 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -381,15 +381,10 @@ def match_dynamic_array_overflow_condition(cond: BitVecRef) -> bool: # Not(ULE(f_sha3_256(slot), offset + base)) if not (left.decl().name() == "f_sha3_256" and is_app_of(right, Z3_OP_BADD)): return False - slot = left.arg(0) offset, base = right.arg(0), right.arg(1) - # Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))) - if not (base.decl().name() == "f_sha3_256" and eq(base.arg(0), slot)): - return False - - # offset < 2**64 - return is_bv_value(offset) and offset.as_long() < 2**64 + # Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))) and offset < 2**64 + return eq(left, base) and is_bv_value(offset) and offset.as_long() < 2**64 def stripped(hexstring: str) -> str: diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..e986eeec --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,56 @@ +from z3 import ( + ULE, + BitVec, + BitVecSort, + BitVecVal, + Function, + Not, + simplify, +) + +from halmos.utils import match_dynamic_array_overflow_condition + + +def test_match_dynamic_array_overflow_condition(): + # Create Z3 objects + f_sha3_256 = Function("f_sha3_256", BitVecSort(256), BitVecSort(256)) + slot = BitVec("slot", 256) + offset = BitVecVal(1000, 256) # Less than 2**64 + + # Test the function + cond = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))) + assert match_dynamic_array_overflow_condition(cond) + + # Test with opposite order of addition + opposite_order_cond = Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset)) + assert not match_dynamic_array_overflow_condition(opposite_order_cond) + + # Test with opposite order after simplification + simplified_opposite_order_cond = simplify( + Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset)) + ) + assert match_dynamic_array_overflow_condition(simplified_opposite_order_cond) + + # Test with offset = 2**64 - 1 (should match) + max_valid_offset = BitVecVal(2**64 - 1, 256) + max_valid_cond = Not(ULE(f_sha3_256(slot), max_valid_offset + f_sha3_256(slot))) + assert match_dynamic_array_overflow_condition(max_valid_cond) + + # Test with offset >= 2**64 + large_offset = BitVecVal(2**64, 256) + large_offset_cond = Not(ULE(f_sha3_256(slot), large_offset + f_sha3_256(slot))) + assert not match_dynamic_array_overflow_condition(large_offset_cond) + + # Test with a different function + different_func = Function("different_func", BitVecSort(256), BitVecSort(256)) + non_matching_cond = Not(ULE(different_func(slot), offset + different_func(slot))) + assert not match_dynamic_array_overflow_condition(non_matching_cond) + + # Test with just ULE, not Not(ULE(...)) + ule_only = ULE(f_sha3_256(slot), offset + f_sha3_256(slot)) + assert not match_dynamic_array_overflow_condition(ule_only) + + # Test with mismatched slots + slot2 = BitVec("slot2", 256) + mismatched_slots = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot2))) + assert not match_dynamic_array_overflow_condition(mismatched_slots)