From 0b6506b51a86d8ce94c1fd6408d6b1c50ba22052 Mon Sep 17 00:00:00 2001 From: Dmitry Petukhov Date: Wed, 25 Oct 2023 01:31:12 +0500 Subject: [PATCH] Add ability to set assertions and assumptions Values on top of the stack can be constrained with assertions Data placeholders can be constrained with assumptions --- README.md | 112 +++ bsst/__init__.py | 986 +++++++++++++++++------ release-notes.md | 7 +- tests/runtests.sh | 1 + tests/test_assertions_and_assumptions.py | 535 ++++++++++++ tests/test_data_placeholders.py | 25 +- tests/test_elements_script_tests.py | 37 +- tests/test_scripts.py | 22 +- tests/test_util/__init__.py | 14 + tests/test_varnames.py | 26 +- 10 files changed, 1450 insertions(+), 315 deletions(-) create mode 100755 tests/test_assertions_and_assumptions.py create mode 100644 tests/test_util/__init__.py diff --git a/README.md b/README.md index f5e6c0d..cf21e15 100644 --- a/README.md +++ b/README.md @@ -76,10 +76,14 @@ Syntax parser is rather basic: * ScriptNum values are represented with normal base10 integers. * Data (but not opcodes) can be enclosed in angle brackets (like this: `<0x1234>`), and these angle brackets will be ignored (for compatibilty with ScriptWiz IDE syntax) +### Data placeholders + Identifiers starting with `$` are recognized as data placeholders: `$some_var` `//` marks the start of the comment, that spans to end of line. +### Data references + A special format of comment is recoginzed: `OP_ADD // =>add_result` will mark the value on the stack after `OP_ADD` with @@ -91,6 +95,114 @@ an apostrophe <<'>> will be appended to the identifier with different value. The data reference identifiers will be prepended with `&` in the report +### Assertions + +Specially-formatted comments can be used to put constraints on value on top of the stack: `// bsst-assert:` and `// bsst-assert-size:`. The difference is that the former puts constraints on the value itself, while the latter constraints the +data size instead of value. 
+ +The expressiveness of these assertions is limited, as their primary purpose is to help the solver by reducing the range of values to be considered. + +For the value on top of the stack constrained via assertion, B'SST will check +if the value can happen to be outside the range defined by the assertion +expression. If it can, the currently analyzed execution path will be deemed +as failed, and in the report the failure will be shown as +`assertion failed at line <N>` or `check_assertion_at_line_<N>` where `<N>` +would be the line at which the failing `// bsst-assert:` comment is located. + +After the assertion check is successfully passed, the value will be assumed to be +constrained by the assertion expression. + +The difference between `assertion failed at line <N>` and `check_assertion_at_line_<N>` is that the former is detected at the time the assertion is applied at the +position in the script it resides on, while the latter is detected afterwards, when other constraints are imposed on values, and that may cause the assertion constraints to be violated. + +A data reference can be supplied as argument, like `// bsst-assert(&ref):` or `// bsst-assert-size(&ref):`, and then the target of the assertion will be this data reference instead of the top of the stack. The assertion will be checked at the place where the assertion itself is declared, not at the place where the data reference is declared. + +A witness name in the form of `wit<N>` where `<N>` is a number, is also accepted as assertion argument. The witness must be referenced by the script at the time when assertion is checked, otherwise the assertion will be ignored (with a warning). + +#### Assertion expression syntax + +After `:`, a whitespace-separated list of expressions is expected, +finished at end of line.
The following is recognized in expressions: + +- decimal number: scriptnum equal to the number, for example `1`, `-33` +- `le64(<number>)`: LE64 value equal to the number, for example `le64(0)`, `le64(125)` +- bytes in hex (either as `0x1234` or `x('1234')`): bytes equal to the hex-encoded data +- string in single quotes: bytes equal to utf-8 encoding of the string + +Before these, `!=` can be placed to express non-equality. `=` can also be placed +before these, for readability: `=42` is the same as `42`. `!=le64(0)` means +"not equal to 64-bit zero", `!='abc'` means "not equal to the string 'abc'" + +Before decimal number or le64 number, `>`, `<`, `>=`, `<=` can be placed +to express "greater than", "less than", "greater or equal", "less or equal", +for example `>0`, `<=-44`, `>=le64(999)` + +For decimal or le64 numbers, a range expression is recognized: `1..456` means +from 1 to 456. Likewise, `le64(1)..le64(456)` + +Within one `// bsst-assert:` or `// bsst-assert-size:`, space-separated +expressions are combined with `OR` logic. +For example, `OP_ADD // bsst-assert: >1 !=8 <=-3` would express +that "result of `OP_ADD` must be above 1 or not equal to 8 or less than or equal +to -3". Note that this particular combination does not constrain anything: +(`>1` OR `!=8`) is true for every value, because 8 itself satisfies `>1`. +Without the `!=8`, the expression (`>1 <=-3`) would constrain the value to +"any representable scriptnum, except -2, -1, 0, 1" + +If more than one `// bsst-assert:` or `// bsst-assert-size:` is placed without any +script opcode or data between them, the expressions of the asserts since +last opcode or data will be combined with `AND` logic. For example, + +``` +OP_ADD // bsst-assert: >1 <=-3 -44..55 + // bsst-assert: 'a' + // bsst-assert-size: 1 +``` + +Will express + +``` +((value above 1) OR (value below -2) OR (value between -44 and 55 inclusive)) +AND (value equal to 'a') +AND (data size equal 1) +``` + +Note that expressions other than "value equal to 'a'" here are meaningless, but are included for illustration purposes.
+ +Integers in asserts also impose scriptnum-encoding constraints on their targets. +That is, `// bsst-assert: 3 0x00` is the same as `// bsst-assert: 3`, unless +`--minimaldata-flag=false`, because `0x00` is not a minimal-encoded scriptnum, +and, given that values are combined with `OR` logic, it will just be ignored. + +Combining `3` and `0x00` with two separate asserts on the same target value with +the same minimaldata flag setting will result in the assertion always being triggered, +because then these two will be combined with `AND` logic, and the result will be +an empty set. + +Asserts with LE64 integers also impose a constraint of 'size is exactly 8 bytes' +on their targets. + +Mixing scriptnum and LE64 values in assert on the same target value is not +allowed, although mixing scriptnums with arbitrary byte expressions is allowed. + +### Assumptions + +Specially-formatted comments can be used to put unconditional constraints +on data placeholders: `// bsst-assume($name):` and `// bsst-assume-size($name):` +to apply assumption to the data placeholder `$name`. + +Assumptions differ from assertions in the following: + +- Only work with data placeholders +- Applied to corresponding data placeholder regardless of where the assumption or the data placeholder resides in the source file +- No check is performed to determine if the value can be outside of the range defined by the expression. The constraints defined by the expression are simply assumed to apply to the corresponding data placeholder + +In other aspects, assumptions work similarly to assertions. The syntax for expressions is the same, different assumptions with the same data placeholder are combined with `AND` logic, `// bsst-assume-size($name):` works with data size instead of value, etc.
+ +Note that if conflicting assumptions are placed on a data placeholder, +or an assumed constraint on a data placeholder might possibly relate to a script +failure, you can still see error code `check_assumption_at_line_<N>` where `<N>` +points to the line with the assumption. + ## Reports The reports show: diff --git a/bsst/__init__.py b/bsst/__init__.py index 3fdc6bf..fbcbd71 100755 --- a/bsst/__init__.py +++ b/bsst/__init__.py @@ -111,7 +111,7 @@ import importlib.util import multiprocessing -from typing import TextIO +from typing import TextIO, Mapping from multiprocessing.pool import AsyncResult from copy import deepcopy from dataclasses import dataclass @@ -138,10 +138,34 @@ PLUGIN_NAME_PREFIX = 'op_plugin' +class BSSTError(Exception): + ... + + +class BSSTPluginLoadError(BSSTError): + ... + + +class BSSTParsingError(BSSTError): + ... + + +class BSSTSolvingError(BSSTError): + ... + + +class BSSTInitializationError(BSSTError): + ... + + class SymEnvironment: _nulldummy_flag: Optional[bool] + post_finalize_hook: Optional[ + Callable[['ExecContext', 'SymEnvironment'], None] + ] = None + @property def input_file(self) -> str: """The file of the script to analyze.
The dash "-" means STDIN @@ -562,8 +586,8 @@ def skip_immediately_failed_branches_on( assert isinstance(value, str) with CurrentEnvironment(self.__class__()): - script_body, _, _ = get_opcodes([value]) - self._skip_immediately_failed_branches_on = script_body + si = get_opcodes([value]) + self._skip_immediately_failed_branches_on = si.body @property def is_miner(self) -> bool: @@ -998,6 +1022,12 @@ def __init__(self, *, is_for_usage_message: bool = False) -> None: self.z3_current_constraints_frame: list[ tuple['z3.BoolRef', str, Optional[tuple['SymData', int]]]] = [] + self.script_info = ScriptInfo() + + self.dummyexpr_counter = 0 + self.stack_symdata_index: int | None = None + self.data_placeholders: dict[str, 'SymData'] = {} + self._root_branch: Optional['Branchpoint'] = None self._enabled_opcodes: list['OpCode'] = [] self._solver: Optional['z3.Solver'] = None @@ -1140,17 +1170,15 @@ def load_plugin_modules(self) -> None: if not g_optional_modules_register.get(name): spec = importlib.util.spec_from_file_location(module_name, ppath) if spec is None: - sys.stderr.write(f'cannot load plugin \'{ppath}\': spec_from_file_location failed') - sys.exit(-1) + raise BSSTPluginLoadError( + f'cannot load plugin \'{ppath}\': spec_from_file_location failed') if spec.loader is None: - sys.stderr.write(f'cannot load plugin \'{ppath}\': spec.loader is None') - sys.exit(-1) + raise BSSTPluginLoadError(f'cannot load plugin \'{ppath}\': spec.loader is None') plugin_module = importlib.util.module_from_spec(spec) if plugin_module is None: - sys.stderr.write(f'cannot load plugin \'{ppath}\': module_from_spec failed') - sys.exit(-1) + raise BSSTPluginLoadError(f'cannot load plugin \'{ppath}\': module_from_spec failed') sys.modules[module_name] = plugin_module spec.loader.exec_module(plugin_module) @@ -1312,10 +1340,10 @@ def Int(v: str) -> 'z3.ArithRef': def FreshInt(prefix: str) -> 'z3.ArithRef': - if not cur_env().z3_enabled: - global g_dummyexpr_counter - g_dummyexpr_counter += 1 - 
return DummyExpr('INT', '!{prefix}_{g_dummyexpr_counter}') + env = cur_env() + if not env.z3_enabled: + env.dummyexpr_counter += 1 + return DummyExpr('INT', '!{prefix}_{env.dummyexpr_counter}') return z3.FreshInt(prefix) @@ -1328,10 +1356,10 @@ def Const(v: str, sort: Any) -> 'z3.ExprRef': def FreshConst(sort: Any, prefix: str) -> 'z3.ExprRef': - if not cur_env().z3_enabled: - global g_dummyexpr_counter - g_dummyexpr_counter += 1 - return DummyExpr('CONST', '!{prefix}_{g_dummyexpr_counter}', sort) + env = cur_env() + if not env.z3_enabled: + env.dummyexpr_counter += 1 + return DummyExpr('CONST', '!{prefix}_{env.dummyexpr_counter}', sort) return z3.FreshConst(sort, prefix) @@ -1355,16 +1383,10 @@ def get_name_suffix(self) -> str: g_is_in_processing = False g_do_process_data_reference_names = False - -g_dummyexpr_counter = 0 -g_stack_symdata_index: int | None = None g_current_exec_context: Optional['ExecContext'] = None g_current_op: Optional['OpCode'] = None g_skip_assertion_for_enforcement_condition: Optional[tuple['SymData', int]] = None -g_data_placeholders: dict[str, 'SymData'] = {} - g_check_op_start_time = 0.0 - g_mode_name_for_opcodes = '' g_opcodes_for_mode: dict[str, list['OpCode']] = {} @@ -1558,6 +1580,7 @@ def __call__(self) -> 'FailureCodeDispatcher': def parse_failcodes(errstr: str) -> list[tuple[str, int]]: + assert errstr.startswith(SCRIPT_FAILURE_PREFIX_SOLVER) info_set: set[tuple[str, int]] = set() plen = len(SCRIPT_FAILURE_PREFIX_SOLVER) for code in errstr[plen:].split(','): @@ -1578,7 +1601,7 @@ def parse_failcodes(errstr: str) -> list[tuple[str, int]]: lpos = code[atpos+1:].find('L') if lpos < 0: assert code[atpos+1:atpos+5] in ('END', 'END~') - pc = len(g_script_body) + pc = len(cur_env().script_info.body) else: pc = int(code[atpos+1:atpos+1+lpos]) @@ -2184,7 +2207,7 @@ def z3check( # noqa if env.log_progress: env.write(' ') if env.exit_on_solver_result_unknown: - sys.exit(-1) + raise BSSTSolvingError() assert 
isinstance(model_values_or_fail_reason, str) cur_context().add_warning( @@ -2379,8 +2402,35 @@ def getval(v: SymData) -> int | bytes: a.set_static(getval(b)) +def collect_model_values( + values: Iterable['SymData'], + cb: Callable[[Optional[dict[str, 'ConstrainedValue']]], + Optional['z3.BoolRef']], + *, preferred_rtype: Optional[SymDataRType] = None +) -> None: + z3_push_context() + + mvdict_req: dict[str, tuple[str, SymDataRType]] = {} + mvnamemap: dict[str, 'SymData'] = {} + + for v in values: + v.update_model_values_request_dict(mvdict_req, mvnamemap, + preferred_rtype=preferred_rtype) + + mvdict: Optional[dict[str, 'ConstrainedValue']] = None + + while cb(mvdict): + try: + mvdict = z3check(force_check=True, + model_values_to_retrieve=mvdict_req) + except ScriptFailure: + break + + z3_pop_context() + + def is_cond_possible( # noqa - cond: Union[bool, 'z3.BoolRef'], sd: Optional['SymData'] = None, + cond: Union[bool, 'z3.BoolRef'], sd: 'SymData', *, name: str = '', fail_msg: str = '', ) -> bool: @@ -2388,21 +2438,17 @@ def is_cond_possible( # noqa z3_push_context() - if (name or sd) and env.log_progress: - env.write(f'check {name or sd} ') + if env.log_progress: + env.write(f'checking {name or sd} ') if env.log_solving_attempts_to_stderr: - env.solving_log(f' check {name or sd} ') + env.solving_log(f' checking {name or sd} ') - if sd: - sd_failcode = sd.get_failcode_dispatcher('possible') + sd_failcode = sd.get_failcode_dispatcher('possible') try: - if sd: - Check(cond, sd_failcode()) - else: - Check(cond) + Check(cond, sd_failcode()) except ScriptFailure: - if sd and env.log_progress and fail_msg: + if env.log_progress and fail_msg: env.ensure_newline() env.write_line(f'{fail_msg}, because condition is static') @@ -2411,11 +2457,7 @@ def is_cond_possible( # noqa failcodes: list[tuple[int, str]] = [] try: - if sd: - Check(cond, sd_failcode()) - else: - Check(cond) - + Check(cond, sd_failcode()) z3check(force_check=True) check_ok = True except ScriptFailure as 
sf: @@ -2427,14 +2469,14 @@ def is_cond_possible( # noqa else: failcodes.append((pc, code)) - if sd and ignored_code: + if ignored_code: assert ignored_code == f'check_{sd_failcode.name_prefix}' check_ok = False z3_pop_context() - if sd and env.log_progress: + if env.log_progress: maybe_report_elapsed_time() env.ensure_newline() if not check_ok and fail_msg: @@ -2983,7 +3025,7 @@ def CurrentOp(op_or_sd: Optional[Union['OpCode', 'ScriptData']] if env.do_progressive_z3_checks and \ (op_or_sd is None or isinstance(op_or_sd, OpCode)): ctx = cur_context() - assert (op is None) == (ctx.pc == len(g_script_body)) + assert (op is None) == (ctx.pc == len(env.script_info.body)) if op is not None: env.solving_log(f'{op} @ {op_pos_info(ctx.pc)} ') @@ -3231,9 +3273,54 @@ def __repr__(self) -> str: return f'{clsname}(name={repr(self.name)}, value={value_common_repr(self.value)})' -g_script_body: tuple[Union[OpCode, 'ScriptData'], ...] = () -g_line_no_table: list[int] = [] -g_var_save_positions: dict[int, str] = {} +@dataclass +class BsstAssertion: + fun: Callable[['SymData'], 'z3.BoolRef'] + is_for_size: bool + line_no: int + text: str + dref_name: str + + +@dataclass +class BsstAssumption: + fun: Callable[['SymData'], 'z3.BoolRef'] + is_for_size: bool + line_no: int + text: str + + +class ScriptInfo: + body: tuple[OpCode | ScriptData, ...] + line_no_table: tuple[int, ...] 
+ _data_reference_positions: dict[int, str] + _assertion_positions: dict[int, tuple[BsstAssertion, ...]] + _assumption_table: dict[str, tuple[BsstAssumption, ...]] + + def __init__(self, *, body: Iterable[OpCode | ScriptData] = (), + line_no_table: Iterable[int] = (), + data_reference_positions: Mapping[int, str] = {}, + assertion_positions: Mapping[int, Iterable[BsstAssertion]] = {}, + assumption_table: Mapping[str, Iterable[BsstAssumption]] = {} + ): + self.body = tuple(body) + self.line_no_table = tuple(line_no_table) + self._data_reference_positions = {k: v for k, v in data_reference_positions.items()} + self._assertion_positions = {k: tuple(v) for k, v in assertion_positions.items()} + self._assumption_table = {k: tuple(v) for k, v in assumption_table.items()} + + def data_reference_at(self, line_no: int) -> str | None: + return self._data_reference_positions.get(line_no) + + def bsst_assertions_at(self, line_no: int + ) -> tuple[BsstAssertion, ...] | None: + return self._assertion_positions.get(line_no) + + def bsst_assumptions_for(self, dph_name: str + ) -> tuple[BsstAssumption, ...] 
| None: + return self._assumption_table.get(dph_name) + + g_data_reference_names_table: dict[str, dict[str, tuple['SymData', 'ExecContext']]] = {} g_seen_named_values: set[str] = set() @@ -3243,15 +3330,20 @@ class ScriptFailure(Exception): def op_pos_repr(pc: int) -> str: - return str(g_script_body[pc]) if pc < len(g_script_body) else 'FINAL_CHECKS' + env = cur_env() + if pc < len(env.script_info.body): + return str(env.script_info.body[pc]) + + return 'FINAL_CHECKS' def op_pos_info(pc: int, separator: str = ':') -> str: - if pc >= len(g_script_body): - assert pc == len(g_script_body) + env = cur_env() + if pc >= len(env.script_info.body): + assert pc == len(env.script_info.body) return 'END' - return f'{pc}{separator}L{g_line_no_table[pc]}' + return f'{pc}{separator}L{env.script_info.line_no_table[pc]}' def non_static_value_error(msg: str) -> NoReturn: @@ -3334,10 +3426,12 @@ def get_valid_branches(self) -> tuple['Branchpoint', ...]: @property def is_if_branch(self) -> bool: - if not isinstance(g_script_body[self.pc], OpCode): + env = cur_env() + + if not isinstance(env.script_info.body[self.pc], OpCode): return False - return g_script_body[self.pc] in (OP_IF, OP_NOTIF) + return env.script_info.body[self.pc] in (OP_IF, OP_NOTIF) def get_path(self, *, skip_failed_branches: bool = True ) -> tuple['Branchpoint', ...]: @@ -3359,7 +3453,8 @@ def get_path(self, *, skip_failed_branches: bool = True def repr_for_path(self) -> str: with CurrentExecContext(self.cond_context): cond = f' {self.cond}' if self.cond else '' - return (f'{g_script_body[self.pc]}{cond} @ {op_pos_info(self.pc)} : ' + + return (f'{cur_env().script_info.body[self.pc]}{cond} @ {op_pos_info(self.pc)} : ' f'{self.designation}') def get_timeline_strings(self, *, skip_failed_branches: bool = True @@ -4149,6 +4244,7 @@ class ExecContext(SupportsFailureCodeCallbacks): _data_refcount_neighbors: dict[str, set['SymData']] unused_values: set['SymData'] skip_enforcement_in_region: tuple[int, int] | None = None + 
data_placeholders_with_assumptions_applied: set[str] def __init__( self, *, @@ -4185,6 +4281,8 @@ def __init__( self._data_refcounts = {} self._data_refcount_neighbors = {} self.unused_values = set() + self.data_placeholders_with_assumptions_applied = set() + self.data_references: dict[str, 'SymData'] = {} @property def stack(self) -> list['SymData']: @@ -4233,6 +4331,8 @@ def clone(self: T_ExecContext) -> T_ExecContext: inst._data_refcounts = self._data_refcounts.copy() inst._data_refcount_neighbors = deepcopy(self._data_refcount_neighbors) inst.unused_values = self.unused_values.copy() + inst.data_placeholders_with_assumptions_applied = \ + self.data_placeholders_with_assumptions_applied.copy() for e in self.enforcements: inst.enforcements.append(e.clone(context=inst)) @@ -4272,12 +4372,14 @@ def on_start(self) -> None: "on-start routines are possible only for branches" return + env = cur_env() + # run on-start routines as if we're still at the opcode # that created the branch self.pc -= 1 try: - with CurrentOp(g_script_body[self.pc]): + with CurrentOp(env.script_info.body[self.pc]): for fun in self._run_on_start: fun() for c, c_name in self._z3_on_start: @@ -4452,6 +4554,7 @@ def __init__(self, *, name: str | None = None, self._wit_no = witness_number self._data_reference_aliases = set() + env = cur_env() ctx = cur_context() pc = ctx.pc @@ -4461,16 +4564,14 @@ def __init__(self, *, name: str | None = None, bpc = ctx.branchpoint.pc branch_index = ctx.branchpoint.branch_index - global g_stack_symdata_index - assert g_stack_symdata_index is not None - - sd_idx = g_stack_symdata_index - line_no_str = f'L{g_line_no_table[pc]}' - branch_line_no_str = f'L{g_line_no_table[bpc]}' + env = cur_env() + line_no_str = f'L{env.script_info.line_no_table[pc]}' + branch_line_no_str = f'L{env.script_info.line_no_table[bpc]}' self._unique_name = \ (f'{self._name or "_"}_{pc}{line_no_str}_{bpc}' - f'{branch_line_no_str}_{branch_index}_{sd_idx}') - g_stack_symdata_index += 1 + 
f'{branch_line_no_str}_{branch_index}_{env.stack_symdata_index}') + assert env.stack_symdata_index is not None + env.stack_symdata_index += 1 else: self._unique_name = unique_name @@ -4607,9 +4708,9 @@ def set_possible_sizes(self, *_sizes: int, value_name: str = '', self.update_solver_for_constrained_value(cv) def set_data_reference(self, data_reference: str) -> None: + ctx = cur_context() if self._data_reference is not None or self._data_reference_was_reset: if not self._data_reference_was_reset: - ctx = cur_context() ctx.warnings.append( (ctx.pc, f'Tried to replace data_reference {self._data_reference} with data_reference ' @@ -4620,6 +4721,10 @@ def set_data_reference(self, data_reference: str) -> None: self._data_reference = None else: self._data_reference = data_reference + assert data_reference not in ctx.data_references, \ + ("duplicate data reference names are not allowed, so within " + "a single context, data references must be unique") + ctx.data_references[data_reference] = self @property def is_static(self) -> bool: @@ -5095,44 +5200,39 @@ def update_solver_for_constrained_value(self, cv: ConstrainedValue) -> None: def update_model_values_request_dict( self, mvdict_req: dict[str, tuple[str, SymDataRType]], - namemap: dict[str, 'SymData'] + namemap: dict[str, 'SymData'], + *, preferred_rtype: Optional[SymDataRType] = None, ) -> None: - if self.was_used_as_Int: - name = self._name_Int - elif self.was_used_as_Int64: - name = self._name_Int64 - elif self.was_used_as_ByteSeq: - name = self._name_ByteSeq - elif self.was_used_as_Length: - name = self._name_Length - else: - return + rtype_table = ( + (SymDataRType.INT, self.was_used_as_Int, self._name_Int, self._Int), + (SymDataRType.INT64, self.was_used_as_Int64, self._name_Int64, self._Int64), + (SymDataRType.BYTESEQ, self.was_used_as_ByteSeq, self._name_ByteSeq, self._ByteSeq), + (SymDataRType.LENGTH, self.was_used_as_Length, self._name_Length, self._Length), + ) - if not self.is_static: - if 
self.was_used_as_Int: - assert self._Int is not None - dname = self._Int.decl().name() - rtype = SymDataRType.INT - elif self.was_used_as_Int64: - assert self._Int64 is not None - dname = self._Int64.decl().name() - rtype = SymDataRType.INT64 - elif self.was_used_as_ByteSeq: - assert self._ByteSeq is not None - dname = self._ByteSeq.decl().name() - rtype = SymDataRType.BYTESEQ - elif self.was_used_as_Length: - assert self._Length is not None - dname = self._Length.decl().name() - rtype = SymDataRType.LENGTH - else: - raise AssertionError("unreachable") + name = '' + dname = '' + rtype: SymDataRType | None = None + for varrtype, wasused, varname, z3var in rtype_table: + assert varname != '' + if wasused and (not name or preferred_rtype == varrtype): + name = varname + if not self.is_static: + rtype = varrtype + assert z3var is not None + dname = z3var.decl().name() + + if preferred_rtype == rtype: + break - assert name not in mvdict_req - mvdict_req[name] = (dname, rtype) + if name: + if not self.is_static: + assert rtype is not None + assert name not in mvdict_req + mvdict_req[name] = (dname, rtype) - assert name not in namemap - namemap[name] = self + assert name not in namemap + namemap[name] = self def check_only_one_value_possible(self, *, name: str = '') -> None: if cv := cur_context().model_values.get(self._unique_name): @@ -5163,6 +5263,76 @@ def set_known_bool(self, value: bool, set_size: bool = False) -> None: cur_context().known_bool_values[self._unique_name] = value + def collect_integer_model_values(self, max_count: int) -> list[int]: + + if not self.was_used_as_Int: + raise ValueError(f'{self} was not used as scriptnum yet') + + result: list[int] = [] + + def collect(mvdict: Optional[dict[str, 'ConstrainedValue']]) -> bool: + + if mvdict is None: # init call + # Add dummy check to make sure the solver knows about our value + # Note that the check must not be reduced to True by simplifying + # For that, we introduce a dummy unconstrained variable + 
dummy_value = SymData(unique_name=f'_dummy_{self._unique_name}') + Check(self.as_Int() == dummy_value.as_Int()) + return True + + if self._name_Int not in mvdict: + return False + + v = mvdict[self._name_Int].as_scriptnum_int() + + result.append(v) + + if len(result) == max_count: + return False + + Check(self.as_Int() != v) + + return True + + collect_model_values([self], collect, preferred_rtype=SymDataRType.INT) + + return result + + def collect_byte_model_values(self, max_count: int) -> list[bytes]: + + if not self.was_used_as_ByteSeq: + raise ValueError(f'{self} was not used as ByteSeq yet') + + result: list[bytes] = [] + + def collect(mvdict: Optional[dict[str, 'ConstrainedValue']]) -> bool: + + if mvdict is None: # init call + # Add dummy check to make sure the solver knows about our value + # Note that the check must not be reduced to True by simplifying + # For that, we introduce a dummy unconstrained variable + dummy_value = SymData(unique_name=f'_dummy_{self._unique_name}') + Check(self.as_ByteSeq() == dummy_value.as_ByteSeq()) + return True + + if self._name_ByteSeq not in mvdict: + return False + + v = mvdict[self._name_ByteSeq].as_bytes() + + result.append(v) + + if len(result) == max_count: + return False + + Check(self.as_ByteSeq() != IntSeqVal(v)) + + return True + + collect_model_values([self], collect, preferred_rtype=SymDataRType.BYTESEQ) + + return result + class SymDepth(SymData): @@ -5234,18 +5404,113 @@ def should_skip_immediately_failed_branch() -> bool: ctx = cur_context() start = ctx.pc + 1 end = start + len(env.skip_immediately_failed_branches_on) - if end <= len(g_script_body) and \ - g_script_body[start:end] == env.skip_immediately_failed_branches_on: + if end <= len(env.script_info.body) and \ + env.script_info.body[start:end] == env.skip_immediately_failed_branches_on: ctx.skip_enforcement_in_region = (start, end) return True return False +def apply_bsst_assn(ctx: ExecContext, assn: BsstAssertion | BsstAssumption, + top: SymData | 
None) -> None: + + env = cur_env() + + if isinstance(assn, BsstAssumption): + is_assumption = True + dref_name = '' + else: + assert isinstance(assn, BsstAssertion) + is_assumption = False + dref_name = assn.dref_name + + def ign_assertion_warning(cause: str) -> None: + env.solving_log_ensure_newline() + env.solving_log(f'WARNING: ignored assertion @L{assn.line_no}: "{assn.text}" ' + f'because {cause}') + ctx.add_warning(f"Assertion at line {assn.line_no} ignored because {cause}") + + if not top and not dref_name: + ign_assertion_warning('stack was empty') + return + + if dref_name: + if dref_name.startswith('&'): + if dref_name[1:] not in ctx.data_references: + ign_assertion_warning('data reference was not found') + return + + target = ctx.data_references[dref_name[1:]] + target_txt = f'{dref_name} = {target}' + else: + m = re.match('wit(\\d+)$', dref_name) + assert m, 'only witnesses must be without "&" prefix' + wit_no = int(m.group(1)) + if wit_no >= len(ctx.used_witnesses): + ign_assertion_warning(f'witness {dref_name} was not used at this point') + return + + target = ctx.used_witnesses[wit_no] + assert target.name == dref_name + target_txt = target.name + + else: + assert top is not None + target = top + target_txt = f'{target}' + + cond = assn.fun(target) + + amsg = f'@L{assn.line_no} for {target_txt}: "{assn.text}"' + + env.solving_log_ensure_newline() + + atype = 'size' if assn.is_for_size else 'value' + + if is_assumption: + env.solving_log(f'applying {atype} assumption {amsg}\n') + fc = failcode(f'assumption_at_line_{assn.line_no}') + else: + if is_cond_possible(Not(cond), target, + name=f'{atype} assertion {amsg}'): + raise ScriptFailure( + f'assertion failed at line {assn.line_no}') + + fc = failcode(f'assertion_at_line_{assn.line_no}') + + Check(cond, fc()) + z3check() + + env.solving_log_ensure_newline() + + +def check_bsst_assertions_and_assumptions(ctx: ExecContext) -> None: # noqa + env = cur_env() + + if len(ctx.stack): + top = ctx.stack[-1] + 
else: + top = None + + if top and top.name and top.name.startswith('$'): + if top.name not in ctx.data_placeholders_with_assumptions_applied: + if assumptions := env.script_info.bsst_assumptions_for(top.name): + for assumption in assumptions: + apply_bsst_assn(ctx, assumption, top) + + ctx.data_placeholders_with_assumptions_applied.add(top.name) + + for assertion in env.script_info.bsst_assertions_at(ctx.pc) or (): + apply_bsst_assn(ctx, assertion, top) + + def symex_op(ctx: ExecContext, op_or_sd: OpCode | ScriptData) -> bool: try: with CurrentOp(op_or_sd): was_executed = _symex_op(ctx, op_or_sd) + if was_executed and all(ctx.vfExec): + check_bsst_assertions_and_assumptions(ctx) except ScriptFailure as sf: ctx.register_failure(ctx.pc, str(sf)) was_executed = False @@ -5374,11 +5639,11 @@ def scope() -> None: 'minimaldata flag handling is strict') if sd.name and sd.name.startswith('$'): - if sd.name not in g_data_placeholders: - g_data_placeholders[sd.name] = SymData( + if sd.name not in env.data_placeholders: + env.data_placeholders[sd.name] = SymData( name=sd.name, unique_name=f'_dph_{sd.name}') - data = g_data_placeholders[sd.name] + data = env.data_placeholders[sd.name] else: data = SymData(name=sd.name, static_value=sd.value) @@ -7602,15 +7867,22 @@ def scope() -> None: return True -def symex_script() -> None: # noqa +def symex_script() -> None: + global g_is_in_processing + + g_is_in_processing = True + try: + _symex_script() + finally: + g_is_in_processing = False + + +def _symex_script() -> None: # noqa env = cur_env() def symex_context(ctx: ExecContext) -> None: - global g_stack_symdata_index - global g_var_save_positions - if ctx.is_finalized: return @@ -7622,14 +7894,14 @@ def symex_context(ctx: ExecContext) -> None: z3_push_context() - g_stack_symdata_index = 0 + env.stack_symdata_index = 0 ctx.on_start() - while ctx.pc < len(g_script_body) and not ctx.failure: + while ctx.pc < len(env.script_info.body) and not ctx.failure: pre_op_state = 
ctx.exec_state.clone() - op_or_sd = g_script_body[ctx.pc] + op_or_sd = env.script_info.body[ctx.pc] if env.sigversion in (SigVersion.BASE, SigVersion.WITNESS_V0) and \ isinstance(op_or_sd, OpCode) and \ @@ -7644,7 +7916,7 @@ def symex_context(ctx: ExecContext) -> None: num_pre_op_used_witnesses = len(ctx.used_witnesses) if symex_op(ctx, op_or_sd): - if data_reference := g_var_save_positions.get(ctx.pc): + if data_reference := env.script_info.data_reference_at(ctx.pc): if len(ctx.stack) > 0: ctx.stack[-1].set_data_reference(data_reference) @@ -7659,44 +7931,266 @@ def symex_context(ctx: ExecContext) -> None: if not ctx.failure: ctx.pc += 1 - g_stack_symdata_index = 0 + env.stack_symdata_index = 0 if not ctx.failure: ctx.exec_state_log[ctx.pc] = ctx.exec_state.clone() with CurrentOp(None): finalize(ctx) - g_stack_symdata_index = None + env.stack_symdata_index = None z3_pop_context() env.get_root_branch().walk_contexts(symex_context, is_executing=True) +def parse_bsst_assn( # noqa + text: str, *, die: Callable[[str], NoReturn], + env: SymEnvironment, is_for_size: bool, + types_used: set[type] +) -> Callable[['SymData'], 'z3.BoolRef']: + + cond_list: list[tuple[str, + int | bytes | IntLE64, + dict[str, int | bytes | IntLE64]]] = [] + + def append_sd(cmp_op: str, sd: ScriptData) -> None: + if isinstance(sd.value, int): + cond_list.append((cmp_op, sd.value, {'v': sd.value})) + elif isinstance(sd.value, IntLE64): + cond_list.append((cmp_op, sd.value, {'v': sd.value.as_int()})) + elif isinstance(sd.value, str): + v = sd.value.encode('utf-8') + cond_list.append((cmp_op, v, {'v': IntSeqVal(v)})) + else: + assert isinstance(sd.value, bytes) + cond_list.append((cmp_op, sd.value, {'v': IntSeqVal(sd.value)})) + + def parse_data(val_str: str, *, allow_bytes: bool) -> ScriptData: + sd = parse_script_data(val_str, die=die, env=env, + allow_nonstandard_size_scriptnums=True) + if sd is None: + die('unrecoginzed data format in assertion/assumption') + + if (isinstance(sd.value, int) 
and IntLE64 in types_used) \ + or (isinstance(sd.value, IntLE64) and int in types_used): + die('mixed ScriptNum and LE64 types in assertion/assumption') + + types_used.add(type(sd.value)) + + # check for str last, so that it would not be converted to numeric + if isinstance(sd.value, str): + sd = ScriptData(value=sd.value.encode('utf-8')) + + if is_for_size and not isinstance(sd.value, int): + die('only simple integers allowed for size assertions/assumptions') + + if not allow_bytes and not isinstance(sd.value, (int, IntLE64)): + die('raw data can only be compared compared for equality') + + return sd + + for valcmp_str in text.split(): + val_str = valcmp_str + if valcmp_str[0] in ('<', '>', '!', '='): + if len(valcmp_str) < 2: + die('empty value in assertion/assumption') + + if valcmp_str[0] != '=' and valcmp_str[1] == '=': + cmp_op = valcmp_str[:2] + val_str = valcmp_str[2:] + elif valcmp_str[0] == '!': + die('the "!" by itself has no meaning in assertion/assumption') + else: + cmp_op = valcmp_str[:1] + val_str = valcmp_str[1:] + + sd = parse_data(val_str, + allow_bytes=(valcmp_str[0] in ('=', '!'))) + + append_sd(cmp_op, sd) + + elif m := re.match("^([^\\'\\.]+)\\.\\.([^\\'\\.]+)$", valcmp_str): + lower = parse_data(m.group(1), allow_bytes=False) + if lower is None: + die('unrecoginzed data format for lower bound in range') + + upper = parse_data(m.group(2), allow_bytes=False) + if upper is None: + die('unrecoginzed data format for upper bound in range') + + if not isinstance(lower.value, (int, IntLE64)) or \ + not isinstance(upper.value, (int, IntLE64)): + die('non-numeric data can only be compared for equality') + + if isinstance(lower.value, int) != isinstance(upper.value, int): + die('types for lower and upper range values do not match') + + if isinstance(lower.value, int): + assert isinstance(upper.value, int) + lv = lower.value + uv = upper.value + else: + assert isinstance(lower.value, IntLE64) + assert isinstance(upper.value, IntLE64) + lv = 
lower.value.as_int() + uv = upper.value.as_int() + + if lv >= uv: + die('lower value must be >= upper value for a range') + + assert isinstance(lower.value, int | IntLE64) + cond_list.append(('..', lower.value, {'lower': lv, 'upper': uv})) + else: + sd = parse_data(valcmp_str, allow_bytes=True) + if sd is None: + die('unrecoginzed data format in value assertion/assumption') + + append_sd('=', sd) + + cmp_op_table: dict[ + str, Callable[[Union['z3.ArithRef', 'z3.SeqSortRef'], + dict[str, int | bytes | IntLE64]], + 'z3.BoolRef'] + ] = { + '=': lambda v, d: v == d['v'], + '<': lambda v, d: v < d['v'], + '>': lambda v, d: v > d['v'], + '<=': lambda v, d: v <= d['v'], + '>=': lambda v, d: v >= d['v'], + '!=': lambda v, d: v != d['v'], + '..': lambda v, d: And(d['lower'] <= v, v <= d['upper']), + } + + def get_value_for_type(v: SymData, ref_v: int | IntLE64 | bytes + ) -> Union['z3.ArithRef', 'z3.SeqSortRef']: + + if is_for_size: + return v.use_as_Length() + + if isinstance(ref_v, int): + return v.use_as_Int(max_size=5) + + if isinstance(ref_v, IntLE64): + return v.use_as_Int64() + + assert isinstance(ref_v, bytes) + return v.use_as_ByteSeq() + + def apply_conds(v: SymData) -> 'z3.BoolRef': + return Or(*[cmp_op_table[cmp_op](get_value_for_type(v, rv), d) + for cmp_op, rv, d in cond_list]) + + return apply_conds + + +def parse_script_data(text: str, *, die: Callable[[str], NoReturn], # noqa + env: SymEnvironment, + allow_nonstandard_size_scriptnums: bool + ) -> Optional[ScriptData]: + + if len(text) >= 2 and text[0] == "'" and text[-1] == "'": + if "'" in text[1:-1]: + die('ambiguous quotes. 
you have to use hex encoding ' + 'if you want to include single quote (0x27) in data') + + return ScriptData(name=None, value=text[1:-1], + do_check_non_minimal=env.minimaldata_flag_strict) + + if env.is_elements and \ + text.lower().startswith('le64(') and text.endswith(')'): + + text = text[5:-1] + + sign = 1 + if text.startswith('-'): + sign = -1 + text = text[1:] + + if not text.isdigit(): + die('incorrect argument to le64()') + + if text.startswith('0') and len(text) > 1: + die('no leading zeroes allowed') + + v = int(text)*sign + + return ScriptData(name=None, value=IntLE64.from_int(v)) + + if text.isdigit() or (text.startswith('-') and text[1:].isdigit()): + sign = 1 + if text.startswith('-'): + sign = -1 + text = text[1:] + if text.startswith('0') and len(text) > 1: + die('no leading zeroes allowed') + + v = int(text)*sign + + if not allow_nonstandard_size_scriptnums: + vch = integer_to_scriptnum(v) + if len(vch) > SCRIPTNUM_DEFAULT_SIZE: + die(f'the number {v}, when converted to ' + f'CScriptNum will be {len(vch)} bytes in length, ' + f'which is above the limit of ' + f'{SCRIPTNUM_DEFAULT_SIZE} bytes') + + return ScriptData(name=None, value=v) + + if text.lower().startswith("x('") and text.endswith("')"): + data_str = text[3:-2] + try: + return ScriptData( + name=None, value=bytes.fromhex(data_str), + do_check_non_minimal=env.minimaldata_flag_strict) + except ValueError: + die(f'cannot decode data: {data_str}') + + if text.lower().startswith("0x"): + data_str = text[2:] + try: + return ScriptData( + name=None, value=bytes.fromhex(data_str), + do_check_non_minimal=env.minimaldata_flag_strict) + except ValueError: + die(f'cannot decode data: {data_str}') + + return None + def get_opcodes(script_lines: Iterable[str], # noqa allow_nonstandard_size_scriptnums: bool = False - ) -> tuple[tuple[OpCode | ScriptData, ...], - list[int], dict[int, str]]: + ) -> ScriptInfo: env = cur_env() - opcodes: list[OpCode | ScriptData] = [] + body: list[OpCode | ScriptData] = [] 
line_no_table: list[int] = [] - var_save_positions: dict[int, str] = {} + data_reference_positions: dict[int, str] = {} + assertion_positions: dict[int, list[BsstAssertion]] = {} + assumption_table: dict[str, list[BsstAssumption]] = {} seen_data_reference_names: dict[str, int] = {} line_no = -1 + types_used_in_assertions: tuple[set[type], set[type]] = (set(), set()) + types_used_in_assumptions: tuple[dict[str, set[type]], + dict[str, set[type]]] = ({}, {}) + for l_idx, line in enumerate(script_lines): line_no = l_idx + 1 def die(msg: str) -> NoReturn: msg = re.sub('[\\x00-\x1F]', '?', msg) - sys.stderr.write(f'ERROR at line {line_no}: {msg}\n') - sys.exit(-1) + raise BSSTParsingError(f'ERROR at line {line_no}: {msg}') + assn_check_fun: Optional[Callable[['SymData'], 'z3.BoolRef']] = None + is_for_size = False + assn_text = '' + assn_dph_name = '' + assn_dref_name = '' data_reference = '' # remove '//' comments if m := re.search('//', line): @@ -7705,78 +8199,67 @@ def die(msg: str) -> NoReturn: if m := re.match('\\s*=>(\\S+)', comment): data_reference = m.group(1) + if m := re.match('\\s*bsst-(assert|assume)(-size)?([^:]*):(.*)', comment): + is_for_size = bool(m.group(2)) + assn_arg = m.group(3) + assn_text = m.group(4).strip() + if m.group(1) == 'assume': + if not (assn_arg.startswith('(') and assn_arg.endswith(')')): + die('unexpected format for bsst-assume') + + assn_dph_name = assn_arg[1:-1] + + if not assn_dph_name.startswith('$') or \ + not assn_dph_name[1:].isidentifier(): + die('bsst-assume argument must be a valid data placeholder') + + typedict = types_used_in_assumptions[int(is_for_size)] + if assn_dph_name not in typedict: + typedict[assn_dph_name] = set() + + types_used = typedict[assn_dph_name] + else: + if assn_arg != '': + if not (assn_arg.startswith('(') and assn_arg.endswith(')')): + die('unexpected format for bsst-assert') + assn_dref_name = assn_arg[1:-1] + if not assn_dref_name.startswith('&'): + if not re.match('wit(\\d+)$', assn_dref_name): + 
die('only data references and witnesses are recognized ' + 'as arguments to bsst-assert, data reference names ' + 'must be prefixed with "&", and witness names must have ' + 'format "wit" where N is a number') + else: + assn_dref_name = '' + + types_used = types_used_in_assertions[int(is_for_size)] + + if len(body) and body[-1].name and \ + body[-1].name.startswith('$'): + types_used.update( + types_used_in_assumptions[int(is_for_size)].get( + body[-1].name) or set()) + + assn_check_fun = parse_bsst_assn( + assn_text, die=die, env=env, + types_used=types_used, is_for_size=is_for_size) + for op_str in line.split(): got_angle_brackets = False if op_str.startswith('<') and op_str.endswith('>'): op_str = op_str[1:-1] got_angle_brackets = True - op: OpCode | ScriptData - if op_str.startswith('$') and op_str[1:].isidentifier(): - op = ScriptData(name=op_str, value=None) - elif len(op_str) >= 2 and op_str[0] == "'" and op_str[-1] == "'": - if "'" in op_str[1:-1]: - die('ambiguous quotes. you have to use hex encoding ' - 'if you want to include single quote (0x27) in data') - - op = ScriptData(name=None, value=op_str[1:-1], - do_check_non_minimal=env.minimaldata_flag_strict) - elif (env.is_elements and - op_str.lower().startswith('le64(') and op_str.endswith(')')): - - op_str = op_str[5:-1] - - sign = 1 - if op_str.startswith('-'): - sign = -1 - op_str = op_str[1:] - - if not op_str.isdigit(): - die('incorrect argument to le64()') - - if op_str.startswith('0') and len(op_str) > 1: - die('no leading zeroes allowed') - - v = int(op_str)*sign - op = ScriptData(name=None, value=IntLE64.from_int(v)) - - elif (op_str.isdigit() or (op_str.startswith('-') - and op_str[1:].isdigit())): - sign = 1 - if op_str.startswith('-'): - sign = -1 - op_str = op_str[1:] - if op_str.startswith('0') and len(op_str) > 1: - die('no leading zeroes allowed') - - v = int(op_str)*sign - - if not allow_nonstandard_size_scriptnums: - vch = integer_to_scriptnum(v) - if len(vch) > SCRIPTNUM_DEFAULT_SIZE: 
- die(f'the number {v}, when converted to ' - f'CScriptNum will be {len(vch)} bytes in length, ' - f'which is above the limit of ' - f'{SCRIPTNUM_DEFAULT_SIZE} bytes') - - op = ScriptData(name=None, value=v) - - elif op_str.lower().startswith("x('") and op_str.endswith("')"): - data_str = op_str[3:-2] - try: - op = ScriptData( - name=None, value=bytes.fromhex(data_str), - do_check_non_minimal=env.minimaldata_flag_strict) - except ValueError: - die(f'cannot decode data: {data_str}') - elif op_str.lower().startswith("0x"): - data_str = op_str[2:] - try: - op = ScriptData( - name=None, value=bytes.fromhex(data_str), - do_check_non_minimal=env.minimaldata_flag_strict) - except ValueError: - die(f'cannot decode data: {data_str}') + op_or_sd: OpCode | ScriptData + if op_str.startswith('$'): + if not op_str[1:].isidentifier(): + die('data placeholder name must be an identifier') + op_or_sd = ScriptData(name=op_str, value=None) + elif maybe_sd := parse_script_data( + op_str, die=die, env=env, + allow_nonstandard_size_scriptnums=allow_nonstandard_size_scriptnums + ): + op_or_sd = maybe_sd elif got_angle_brackets: die(f'unexpected value in angle brackets: {op_str}') else: @@ -7789,7 +8272,7 @@ def die(msg: str) -> NoReturn: if maybe_op is None: die(f'unknown opcode {op_str}') - op = maybe_op + op_or_sd = maybe_op mode = 'elements' if env.is_elements else 'bitcoin' if env.sigversion == SigVersion.TAPSCRIPT: @@ -7797,11 +8280,11 @@ def die(msg: str) -> NoReturn: else: mode = f'{mode} (non-tapscript)' - if op not in env.get_enabled_opcodes(): + if op_or_sd not in env.get_enabled_opcodes(): die(f'opcode {op_str} is not valid for {mode}') line_no_table.append(line_no) - opcodes.append(op) + body.append(op_or_sd) if data_reference and env.restrict_data_reference_names: if data_reference and not data_reference.isidentifier(): @@ -7810,7 +8293,12 @@ def die(msg: str) -> NoReturn: f"{line_no}\n") data_reference = '' + op_pos = len(body)-1 + if data_reference: + if op_pos < 0: + 
die('data reference before any opcode or value in the script') + if data_reference in seen_data_reference_names: die(f'data_reference at line {line_no} was already used at line ' f'{seen_data_reference_names[data_reference]}') @@ -7818,30 +8306,61 @@ def die(msg: str) -> NoReturn: if "'" in data_reference: die("apostrophe <<'>> is not allowed in data reference names") - seen_data_reference_names[data_reference] = line_no + if re.match('wit(\\d+)$', data_reference): + die('cannot use the name "wit" (where is a number) as ' + 'data reference, because this name is reserved for witnesses') - var_save_positions[len(opcodes)-1] = data_reference + seen_data_reference_names[data_reference] = line_no + data_reference_positions[op_pos] = data_reference + + if assn_check_fun: + if assn_dph_name: + assn_funcs = assumption_table.get(assn_dph_name, []) + assn_funcs.append(BsstAssumption( + fun=assn_check_fun, is_for_size=is_for_size, + line_no=line_no, text=assn_text)) + assumption_table[assn_dph_name] = assn_funcs + else: + if op_pos < 0: + die('assertion before any opcode or value in the script') + + vac_funcs = assertion_positions.get(op_pos, []) + vac_funcs.append(BsstAssertion( + fun=assn_check_fun, is_for_size=is_for_size, + line_no=line_no, text=assn_text, + dref_name=assn_dref_name)) + assertion_positions[op_pos] = vac_funcs + else: + types_used_in_assertions[0].clear() + types_used_in_assertions[1].clear() line_no_table.append(line_no+1) - return tuple(opcodes), line_no_table, var_save_positions + return ScriptInfo(body=body, line_no_table=line_no_table, + data_reference_positions=data_reference_positions, + assertion_positions=assertion_positions, + assumption_table=assumption_table) def finalize(ctx: ExecContext) -> None: # noqa + env = cur_env() + assert not ctx.failure - assert ctx.pc == len(g_script_body) + assert ctx.pc == len(env.script_info.body) try: - _finalize(ctx) + _finalize(ctx, env) except ScriptFailure as sf: ctx.register_failure(ctx.pc, str(sf)) 
ctx.is_finalized = True + if hook := env.post_finalize_hook: + hook(ctx, env) -def _finalize(ctx: ExecContext) -> None: # noqa - env = cur_env() + +def _finalize(ctx: ExecContext, env: SymEnvironment) -> None: # noqa assert not ctx.failure - assert ctx.pc == len(g_script_body) + assert ctx.pc == len(env.script_info.body) env.solving_log_ensure_empty_line() @@ -7915,7 +8434,7 @@ def _finalize(ctx: ExecContext) -> None: # noqa txval.update_model_values_request_dict(mvdict_req, mvnamemap) processed.append(txval) - for val in g_data_placeholders.values(): + for val in env.data_placeholders.values(): if val not in processed: val.update_model_values_request_dict(mvdict_req, mvnamemap) processed.append(val) @@ -7977,10 +8496,10 @@ def _finalize(ctx: ExecContext) -> None: # noqa verify_targets: list[Enforcement] = [] if not env.use_z3_incremental_mode: for e in ctx.enforcements: - if e.pc >= len(g_script_body): + if e.pc >= len(env.script_info.body): op = None else: - op = g_script_body[e.pc] + op = env.script_info.body[e.pc] is_verify_target = (op is None or op in (OP_VERIFY, OP_EQUALVERIFY, @@ -7994,7 +8513,7 @@ def _finalize(ctx: ExecContext) -> None: # noqa txvalues = ctx.tx.values() got_model_values = (env.produce_model_values and (ctx.used_witnesses or txvalues or ctx.stack - or g_data_placeholders)) + or env.data_placeholders)) if env.produce_model_values: @@ -8016,9 +8535,10 @@ def _finalize(ctx: ExecContext) -> None: # noqa txval.check_only_one_value_possible() processed.append(txval) - for val in g_data_placeholders.values(): + for val in env.data_placeholders.values(): if val in processed: - env.write_line(f'skip checking {val}: already checked') + if env.log_progress: + env.write_line(f'skip checking {val}: already checked') else: val.check_only_one_value_possible() processed.append(val) @@ -8031,7 +8551,8 @@ def _finalize(ctx: ExecContext) -> None: # noqa valname = f'stack[{pos}]' if val in processed: - env.write_line(f'skip checking {valname}: already 
checked') + if env.log_progress: + env.write_line(f'skip checking {valname}: already checked') else: val.check_only_one_value_possible(name=valname) processed.append(val) @@ -8141,8 +8662,8 @@ def report() -> None: # noqa enforcements_by_path: dict[tuple['Branchpoint', ...], set['Enforcement']] = {} - model_values_map: dict[tuple[str, ...], - tuple[int, list['ExecContext']]] = {} + model_values_map: dict[tuple[int, tuple[str, ...]], + list['ExecContext']] = {} nonmodel_stack: list[SymData] @@ -8194,7 +8715,7 @@ def get_val_str(v: SymData) -> str: for txval in txvalues: mvals_list.append(f'{txval} {get_val_str(txval)}') - for vname, val in g_data_placeholders.items(): + for vname, val in env.data_placeholders.items(): mvals_list.append(f'{vname} {get_val_str(val)}') for w in bp.context.used_witnesses: @@ -8224,14 +8745,12 @@ def get_val_str(v: SymData) -> str: assert env.is_incomplete_script, \ "context should have failure set otherwise" - mvals = tuple(mvals_list) - - num_witnesses = len(bp.context.used_witnesses) + mvmap_key = (len(bp.context.used_witnesses), tuple(mvals_list)) - if mvals not in model_values_map: - model_values_map[mvals] = (num_witnesses, [bp.context]) + if mvmap_key not in model_values_map: + model_values_map[mvmap_key] = [bp.context] else: - model_values_map[mvals][1].append(bp.context) + model_values_map[mvmap_key].append(bp.context) for e in bp.unique_enforcements or (): path = bp.get_enforcement_path(e) @@ -8308,14 +8827,15 @@ def collect_valid_paths(bp: Branchpoint, level: int) -> None: else: print_as_header(f'Witness usage {path_msg}:') - for mvals, (num_witnesses, ctx_list) in model_values_map.items(): - assert len(ctx_list) > 0 + for mvmap_key, contexts in model_values_map.items(): + assert len(contexts) > 0 if len(model_values_map) > 1: with VarnamesDisplay(): - for ctx in ctx_list: + for ctx in contexts: print_as_header('\n'.join(ctx.get_timeline_strings()), level=1) + num_witnesses, mvals = mvmap_key env.write_line(f"Witnesses used: 
{num_witnesses}") env.ensure_empty_line() @@ -8401,7 +8921,7 @@ def report_failures(ctx: ExecContext) -> None: else: assert poi.startswith('L') line_no = int(poi[1:]) - for pc, lno in enumerate(g_line_no_table): + for pc, lno in enumerate(env.script_info.line_no_table): if line_no == lno: pc_list.append(pc) break @@ -8411,16 +8931,14 @@ def report_failures(ctx: ExecContext) -> None: level=1) def report_poi(ctx: ExecContext) -> None: - global g_script_body - print_as_header( (ctx.get_timeline_strings(skip_failed_branches=False) or "All paths"), level=1) for pc in sorted(pc_list): if pc in ctx.exec_state_log: - if pc < len(g_script_body): - op_str = f' ({g_script_body[pc]})' + if pc < len(env.script_info.body): + op_str = f' ({env.script_info.body[pc]})' else: op_str = '' @@ -8908,10 +9426,6 @@ def compress(h0: int, h1: int, h2: int, h3: int, h4: int, block: bytes def main() -> None: # noqa - global g_script_body - global g_line_no_table - global g_var_save_positions - env = cur_env() if env.z3_enabled: maybe_randomize_z3_seeds() @@ -8929,15 +9443,10 @@ def main() -> None: # noqa with open(env.input_file) as f: lines = f.readlines() - g_script_body, g_line_no_table, g_var_save_positions = \ - get_opcodes(lines) + env.script_info = get_opcodes(lines) - if g_script_body: - global g_is_in_processing - - g_is_in_processing = True + if env.script_info.body: symex_script() - g_is_in_processing = False report() @@ -8961,8 +9470,7 @@ def try_import_optional_modules() -> None: global z3 import z3 except ImportError as e: - sys.stderr.write(f'ERROR: Failed to import z3: {e}\n') - sys.exit(-1) + raise BSSTInitializationError(f'ERROR: Failed to import z3: {e}\n') g_optional_modules_register['z3'] = True @@ -9179,10 +9687,14 @@ def main_cli() -> None: if pid == 0: signal.signal(signal.SIGINT, sigint_handler) - with CurrentEnvironment(SymEnvironment()): - parse_cmdline_args() - try_import_optional_modules() - main() + try: + with CurrentEnvironment(SymEnvironment()): + 
parse_cmdline_args() + try_import_optional_modules() + main() + except BSSTError as e: + print(e) + sys.exit(-1) else: signal.signal(signal.SIGCHLD, sigchld_handler) signal.signal(signal.SIGINT, signal.SIG_IGN) diff --git a/release-notes.md b/release-notes.md index cdd70b0..31b0bd2 100644 --- a/release-notes.md +++ b/release-notes.md @@ -2,8 +2,11 @@ Version 0.1.2.dev0: -* Fix: scriptnum decoding was not imposing "0 >= x => 255" bound on the byte sequence if its size was 1. - This was causing problems with `bsst-assume` tests, but likely that this could have caused problems elsewhere, too +* Add ability to set assertions on stack values and witnesses, and assumptions for data placeholders. Please see newly added "Assertions" and "Assumptions" sections in README. You might also look at `tests/test_assertions_and_assumptions.py` for examples of usage + +* Fix: scriptnum decoding was not imposing "0 <= x <= 255" bound on the byte sequence if its size was 1. This was causing problems with `bsst-assume` tests, but it is likely that this could have caused problems elsewhere, too + +* To avoid confusion, data reference names cannot be "wit<N>" (where <N> is a number), because such names are reserved for witnesses * Fixes in parser: quotes within quotes were allowed, but should not; angle brackets were sometimes not ignored diff --git a/tests/runtests.sh b/tests/runtests.sh index 5f418df..323efc5 100755 --- a/tests/runtests.sh +++ b/tests/runtests.sh @@ -5,6 +5,7 @@ set -ex ./test_integer_conversion.py ./test_data_placeholders.py ./test_varnames.py +./test_assertions_and_assumptions.py ./test_elements_script_tests.py ./script_tests_tapscript_opcodes.json tapscript if [ ! 
-e script_tests.json ]; then wget https://raw.githubusercontent.com/ElementsProject/elements/master/src/test/data/script_tests.json diff --git a/tests/test_assertions_and_assumptions.py b/tests/test_assertions_and_assumptions.py new file mode 100755 index 0000000..2ed66c1 --- /dev/null +++ b/tests/test_assertions_and_assumptions.py @@ -0,0 +1,535 @@ +#!/usr/bin/env python3 + +import re +from contextlib import contextmanager +from typing import Generator, Sequence + +import bsst + + +@contextmanager +def FreshEnv(*, z3_enabled: bool + ) -> Generator[bsst.SymEnvironment, None, None]: + env = bsst.SymEnvironment() + env.z3_enabled = z3_enabled + with bsst.CurrentEnvironment(env): + bsst.try_import_optional_modules() + bp = bsst.Branchpoint(pc=0, branch_index=0) + with bsst.CurrentExecContext(bp.context): + yield env + + +testcases_normal: list[tuple[str, set[int | bytes]]] = [ + ( + """ + // bsst-assume($a): 1 2 3 + $a + """, + {1, 2, 3} + ), + ( + # NOTE: -4839433545 is beyond 4-byte scriptnum range, but will be + # included if no arithmetic is performed + """ + // bsst-assume($a): 100 1000 43345 -245 -3344 -48394 -4839433545 + $a + """, + {b'\x64', b'\xe8\x03', b'\x51\xA9\x00', + b'\xf5\x80', b'\x10\x8d', b'\x0a\xbd\x80', b'\x49\xe5\x73\x20\x81'} + ), + ( + # NOTE: -4839433545 is beyond 4-byte scriptnum range, will not + # be included after arithmetic is done + """ + // bsst-assume($a): 100 -4839433545 + $a 0 ADD + """, + {100} + ), + ( + # NOTE: 0x00 must be ignored because it is not valid scriptnum, + # and we have scriptnums for assume($a), and minimaldat_flag is False + """ + // bsst-assume($a): >-3 0x00 + // bsst-assume($a): <2 + // bsst-assume($a): !='' + $a + """, + {b'\x81', b'\x82', b'\x01'} + ), + ( + """ + // bsst-assume($a): >=-2 + // bsst-assume($a): <=1 + // bsst-assume($a): !='' + $a + """, + {b'\x81', b'\x82', b'\x01'} + ), + ( + # combined via AND, and thus '' must be the only possible value + """ + // bsst-assume($a): >=-2 + // bsst-assume($a): <=1 
+ // bsst-assume($a): ='' + $a + // bsst-assert: 0 + """, + {b''} + ), + ( + """ + // bsst-assume($a): -1..2 + $a + """, + {-1, 0, 1, 2} + ), + ( + """ + // bsst-assume($a): le64(-1)..le64(2) + $a + """, + set(bsst.IntLE64.from_int(v) for v in (-1, 0, 1, 2)) + ), + ( + """ + // bsst-assume($a): x('efcdab99') + $a 0x78563412 CAT + // bsst-assert: le64(1311768467445894639) + """, + {bsst.IntLE64.from_int(0x1234567899ABCDEF)} + ), + ( + """ + // bsst-assume($a): 'abc' + $a 'def' CAT + // bsst-assert: 'abcdef' + """, + {b'abcdef'} + ), + ( + """ + // bsst-assume($a): 19 + // bsst-assume($b): 0 1 2 3 + $a $b ADD DUP 22 NUMEQUAL NOT VERIFY DUP 19 NUMEQUAL NOT VERIFY + // bsst-assert: 20 21 + // bsst-assert: <22 + // bsst-assert: !=19 + """, + {20, 21} + ), + ( + """ + // bsst-assume($a): le64(19) le64(21) + $a le64(1) ADD64 VERIFY le64(2) DIV64 VERIFY + // bsst-assert: le64(10) le64(11) !=le64(12) + """, + {bsst.IntLE64.from_int(10), bsst.IntLE64.from_int(11)} + ), + ( + """ + // bsst-assume($a): le64(20) le64(21) + $a DUP + le64(1) ADD64 VERIFY le64(3) DIV64 VERIFY + SWAP // =>remainder + le64(0) EQUAL VERIFY + // bsst-assert(&remainder): le64(0) + // bsst-assert-size(&remainder): 8 + // bsst-assert: le64(7) + // bsst-assert-size: 8 + DROP + """, + {bsst.IntLE64.from_int(20)} + ), + ( + """ + 10 5 DUP ADD SUB + // bsst-assert-size: 0 + 1 ADD + // bsst-assert: 1 + """, + {1} + ), + ( + """ + $a // =>a + // bsst-assume-size($a): 1 + DUP x('01') CAT + // bsst-assert-size: 2 + 1 ADD 258 NUMEQUALVERIFY + // bsst-assert(&a): 1 + """, + {1} + ), + ( + """ + DUP TOALTSTACK + CHECKSIGVERIFY + // bsst-assert-size(wit0): 32 + // bsst-assert-size(wit1): 64 65 + FROMALTSTACK SIZE + """, + {32} + ), +] + +testcases_assnfail: list[tuple[str, set[int], bool]] = [ + ( + # conflicting assumption constraints + """ + // bsst-assume($a): 100 1000 43345 -245 -3344 -48394 -4839433545 + // bsst-assume($a): 3 + $a + """, + {2, 3}, False + ), + ( + """ + // bsst-assume($a): 100 -4839433545 + $a 
0 ADD + // bsst-assert: 1 + """, + {4}, False + ), + ( + """ + // bsst-assume($a): >-3 0x00 + // bsst-assume($a): <2 + // bsst-assume($a): !='' + $a + // bsst-assert-size: 0 + """, + {6}, False + ), + ( + """ + // bsst-assume($a): >=-2 + // bsst-assume($a): <=1 + // bsst-assume($a): ='' + $a + // bsst-assert: !=0 + """, + {6}, False + ), + ( + """ + // bsst-assume($a): -1..2 + $a + // bsst-assert: -3..-2 + """, + {4}, False + ), + ( + """ + // bsst-assume($a): le64(-1)..le64(2) + $a + // bsst-assert: le64(-4)..le64(-2) + """, + {4}, False + ), + ( + """ + // bsst-assume($a): x('efcdab99') + $a 0x78563412 CAT + // bsst-assert: le64(123) + """, + {4}, False + ), + ( + """ + // bsst-assume($a): 'abc' + $a 'def' CAT + // bsst-assert: 'ABCDEF' + """, + {4}, False + ), + ( + """ + // bsst-assume($a): 19 + // bsst-assume($b): 0 1 2 3 + $a $b ADD DUP 22 NUMEQUAL NOT VERIFY DUP 19 NUMEQUAL NOT VERIFY + // bsst-assert: 20 21 + // bsst-assert: <22 + // bsst-assert: =19 + """, + {7}, False + ), + ( + """ + // bsst-assume($a): le64(19) le64(21) + $a le64(1) ADD64 VERIFY le64(2) DIV64 VERIFY + // bsst-assert: le64(10) le64(11) !=le64(12) + // bsst-assert: =le64(12) + """, + {5}, False + ), + ( + """ + // bsst-assume($a): le64(20) le64(21) + $a DUP + le64(1) ADD64 VERIFY le64(3) DIV64 VERIFY + SWAP // =>remainder + le64(0) EQUAL VERIFY + // bsst-assert(&remainder): le64(0) + // bsst-assert-size(&remainder): 8 + // bsst-assert: le64(4) + // bsst-assert-size: 8 + DROP + """, + {9}, False + ), + ( + """ + 10 5 DUP ADD SUB + // bsst-assert-size: 1 + 1 ADD + // bsst-assert: 1 + """, + {3}, False + ), + ( + """ + $a // =>a + // bsst-assume-size($a): 1 + DUP x('01') CAT + // bsst-assert-size: 2 + 1 ADD 258 NUMEQUALVERIFY + // bsst-assert(&a): 0x00 + """, + {7}, False + ), + ( + """ + DUP + IF + // bsst-assert-size: >0 + DUP 1 GREATERTHANOREQUAL VERIFY + ELSE + // bsst-assert-size: 1 + SIZE 10 NUMEQUALVERIFY + ENDIF + // bsst-assert(wit0): 1 + // bsst-assert: >10 + """, + {7, 11}, True + 
), + ( + """ + DUP + IF + // bsst-assert-size: >0 + DUP 1 GREATERTHANOREQUAL VERIFY + ELSE + // bsst-assert-size: 1 + SIZE 10 NUMEQUALVERIFY + ENDIF + // bsst-assert: >0 + // bsst-assert(wit0): 0 + """, + {7, 11}, True + ), + ( + # tapscript signature is 32 bytes, so first assertion musr fail + """ + CHECKSIGVERIFY + // bsst-assert-size(wit0): 33 + // bsst-assert-size(wit1): 64 + """, + {3}, False + ), + ( + # signature can be 64 or 65 bytes, so second assertion musr fail + """ + CHECKSIGVERIFY + // bsst-assert-size(wit0): 32 + // bsst-assert-size(wit1): 64 + """, + {4}, False + ), +] + +testcases_otherfail: list[str] = [ + # no mixing of le64 and scriptnums + """ + // bsst-assume($a): 100 101 + // bsst-assume($a): le64(100) + $a + """, + """ + // bsst-assume($a): le64(100) + $a + // bsst-assert: 100 + """, + # no asserts against empty stack + """ + // bsst-assert: 0 + 1 + """, + # no spaces allowed inside expression + """ + 1 + // bsst-assert: >= 0 + """, + """ + le64(1) + // bsst-assert: le64( 1 ) + """, + # range end must be > start + """ + 1 + // bsst-assert: 1..1 + """, + """ + 1 + // bsst-assert: 2..1 + """, + # non-witness name for bsst-assert name without & + """ + 1 // =>wit + // bsst-assert(wit): 1 + """ +] + + +def test_assn_normal( + tc_no: int, tc_text: str, tc_expected_values: set[int | bytes] +) -> None: + + if tc_expected_values: + assert all(isinstance(v, type(tuple(tc_expected_values)[0])) + for v in tc_expected_values), "mixed types are not allowed" + + is_ok = False + + with FreshEnv(z3_enabled=True) as env: + env.use_parallel_solving = False + env.log_progress = False + env.is_elements = True + env.is_incomplete_script = True + env.solver_timeout_seconds = 0 + + def post_finalize_hook(ctx: bsst.ExecContext, + env: bsst.SymEnvironment) -> None: + nonlocal is_ok + + mvals: Sequence[int | bytes] + + top = ctx.stack[-1] + + if tc_expected_values and \ + isinstance(tuple(tc_expected_values)[0], int): + if top.is_static: + mvals = 
[top.as_scriptnum_int()] + else: + top.use_as_Int(max_size=5) + mvals = ctx.stack[-1].collect_integer_model_values( + max_count=len(tc_expected_values)+1) + else: + if top.is_static: + mvals = [top.as_bytes()] + else: + ctx.stack[-1].use_as_ByteSeq() + mvals = ctx.stack[-1].collect_byte_model_values( + max_count=len(tc_expected_values)+1) + + assert len(mvals) == len(set(mvals)), \ + ("no duplicates expected", mvals) + assert set(mvals) == tc_expected_values, \ + ("model values must match expected values", mvals) + + print("OK") + + is_ok = True + + env.script_info = bsst.get_opcodes(tc_text.split('\n')) + env.post_finalize_hook = post_finalize_hook + + bsst.symex_script() + + assert is_ok, "post_finalize_hook must run and successfully return" + + +def test_assn_failing( + tc_no: int, tc_text: str, failure_lines: set[int], is_exact_match: bool +) -> None: + + with FreshEnv(z3_enabled=True) as env: + env.use_parallel_solving = False + env.log_progress = False + env.is_elements = True + env.is_incomplete_script = True + env.solver_timeout_seconds = 0 + + env.script_info = bsst.get_opcodes(tc_text.split('\n')) + + bsst.symex_script() + + seen_failure_lines: set[int] = set() + + def search_failures(ctx: bsst.ExecContext) -> None: + if not ctx.failure: + return + + flines: set[int] = set() + pc, errstr = ctx.failure + if errstr.startswith(bsst.SCRIPT_FAILURE_PREFIX_SOLVER): + for code, pc in bsst.parse_failcodes(errstr): + m = re.match('check_(assumption|assertion)_at_line_(\\d+)', + code) + assert m, (f'assertion or assumption failure expected, ' + f'but got "{code}"') + fl = int(m.group(2)) + flines.add(fl) + else: + m = re.match('assertion failed at line (\\d+)', errstr) + assert m, (f'assertion or assumption failure expected, ' + f'but got "{errstr}"') + fl = int(m.group(1)) + flines.add(fl) + + assert flines.issubset(failure_lines), (flines, failure_lines) + seen_failure_lines.update(flines) + + env.get_root_branch().walk_contexts(search_failures, + 
include_failed=True) + + assert seen_failure_lines, "at least one failure must be detected" + assert seen_failure_lines.issubset(failure_lines) + if is_exact_match: + assert seen_failure_lines == failure_lines, \ + ("exact match expected", seen_failure_lines, failure_lines) + + print("OK") + + +def test_other_failing(tc_no: int, tc_text: str) -> None: + + with FreshEnv(z3_enabled=True) as env: + env.use_parallel_solving = False + env.log_progress = False + env.is_elements = True + env.is_incomplete_script = True + env.solver_timeout_seconds = 0 + + try: + env.script_info = bsst.get_opcodes(tc_text.split('\n')) + except bsst.BSSTParsingError: + print("OK") + return + + assert False, "BSSTParsingError expected" + + +if __name__ == '__main__': + for tc_no, (tc_text, tc_expected_values) in enumerate(testcases_normal): + print("TESTCASE ASSN NORMAL", tc_no+1, end=' ') + test_assn_normal(tc_no, tc_text, tc_expected_values) + + for tc_no, (tc_text, failure_lines, is_exact_match) in \ + enumerate(testcases_assnfail): + print("TESTCASE ASSN FAILING", tc_no+1, end=' ') + test_assn_failing(tc_no, tc_text, failure_lines, is_exact_match) + + for tc_no, tc_text in enumerate(testcases_otherfail): + print("TESTCASE OTHER FAILING", tc_no+1, end=' ') + test_other_failing(tc_no, tc_text) diff --git a/tests/test_data_placeholders.py b/tests/test_data_placeholders.py index 731b8d0..0064688 100755 --- a/tests/test_data_placeholders.py +++ b/tests/test_data_placeholders.py @@ -2,13 +2,13 @@ import sys -from io import StringIO - from contextlib import contextmanager from typing import Generator import bsst +from test_util import CaptureStdout + testcase = """ $a 1 add $a 2 add 1 sub @@ -42,16 +42,7 @@ @contextmanager -def CaptureStdout() -> Generator[StringIO, None, None]: - save_stdout = sys.stdout - out = StringIO() - sys.stdout = out - yield out - sys.stdout = save_stdout - - -@contextmanager -def FreshEnv() -> Generator[None, None, None]: +def FreshEnv() -> 
Generator[bsst.SymEnvironment, None, None]: env = bsst.SymEnvironment() env.use_parallel_solving = False env.log_progress = False @@ -63,20 +54,16 @@ def FreshEnv() -> Generator[None, None, None]: bsst.try_import_optional_modules() bp = bsst.Branchpoint(pc=0, branch_index=0) with bsst.CurrentExecContext(bp.context): - yield + yield env def test() -> None: - with FreshEnv(): - (bsst.g_script_body, - bsst.g_line_no_table, - bsst.g_var_save_positions) = bsst.get_opcodes(testcase.split('\n')) + with FreshEnv() as env: + env.script_info = bsst.get_opcodes(testcase.split('\n')) out: str = '' with CaptureStdout() as output: - bsst.g_is_in_processing = True bsst.symex_script() - bsst.g_is_in_processing = False bsst.report() out = output.getvalue() diff --git a/tests/test_elements_script_tests.py b/tests/test_elements_script_tests.py index f3fa575..2a331e0 100755 --- a/tests/test_elements_script_tests.py +++ b/tests/test_elements_script_tests.py @@ -97,7 +97,7 @@ def convert_script(line: str, flags: list[str], elif len(op_str) >= 2 and op_str[0] == "'" and op_str[-1] == "'": script_bytes.append(CScript([op_str[1:-1].encode('utf-8')])) else: - ops, _, _ = bsst.get_opcodes([maybe_subst_with_nop(op_str, flags)]) + ops = bsst.get_opcodes([maybe_subst_with_nop(op_str, flags)]).body assert len(ops) == 1 assert isinstance(ops[0], bsst.OpCode) script_bytes.append(CScript(bytes([ops[0].code]))) @@ -128,7 +128,7 @@ def convert_script(line: str, flags: list[str], script_lines.append(maybe_subst_with_nop(op_str, flags)) - return bsst.get_opcodes(script_lines)[0] + return bsst.get_opcodes(script_lines).body supported_flags = {'DISCOURAGE_UPGRADEABLE_PUBKEY_TYPE', 'STRICTENC', @@ -182,13 +182,15 @@ def do_processing(bp: 'bsst.Branchpoint', level: int) -> None: def set_script_body(script_body: tuple[bsst.OpCode | bsst.ScriptData, ...] 
) -> None: - bsst.g_script_body = script_body - bsst.g_var_save_positions = {} - bsst.g_line_no_table = [] - for pc in range(len(bsst.g_script_body)): - bsst.g_line_no_table.append(pc) + line_no_table = [] + for pc in range(len(script_body)): + line_no_table.append(pc) - bsst.g_line_no_table.append(len(bsst.g_script_body)) + line_no_table.append(len(script_body)) + + env = bsst.cur_env() + env.script_info = bsst.ScriptInfo(body=script_body, + line_no_table=line_no_table) def process_testcase_single( @@ -202,8 +204,6 @@ def process_testcase_single( use_nonstatic_witnesses: bool = False, flags_were_altered: bool = False ) -> None: - set_script_body(scriptPubKey) - if use_nonstatic_witnesses: assert z3_enabled @@ -216,6 +216,7 @@ def check_p2sh() -> None: assert scriptPubKey[2] == bsst.OP_EQUAL with FreshEnv(z3_enabled=z3_enabled) as env: + set_script_body(scriptPubKey) common_env_settings(env, flags) env.z3_enabled = z3_enabled env.do_progressive_z3_checks = True @@ -246,9 +247,7 @@ def check_p2sh() -> None: data.increase_refcount() try: - bsst.g_is_in_processing = True bsst.symex_script() - bsst.g_is_in_processing = False except ValueError as e: if str(e).startswith('non-static value:'): if expected_result == 'INVALID_STACK_OPERATION': @@ -345,7 +344,7 @@ def check_p2sh() -> None: pass elif expected_result == 'UNBALANCED_CONDITIONAL': assert ( - bsst.g_script_body[ctx.used_witnesses[0].src_pc] + env.script_info.body[ctx.used_witnesses[0].src_pc] in (bsst.OP_IF, bsst.OP_NOTIF) ) elif (expected_result == 'UNSATISFIED_LOCKTIME' and @@ -402,16 +401,16 @@ def check_p2sh() -> None: (len(invalid_contexts) == 1 and len(failures) == 1 and ((failures[0] == 'check_length_mismatch' and - bsst.g_script_body[invalid_contexts[0].pc] in (bsst.OP_AND, - bsst.OP_OR)) + env.script_info.body[invalid_contexts[0].pc] in (bsst.OP_AND, + bsst.OP_OR)) or (failures[0] in ('check_negative_argument', 'check_argument_above_bounds') and - bsst.g_script_body[invalid_contexts[0].pc] in 
(bsst.OP_SUBSTR,)) + env.script_info.body[invalid_contexts[0].pc] in (bsst.OP_SUBSTR,)) or (failures[0] == 'check_data_too_long' and - bsst.g_script_body[invalid_contexts[0].pc] in (bsst.OP_CAT,))))) + env.script_info.body[invalid_contexts[0].pc] in (bsst.OP_CAT,))))) elif expected_result == 'PUSH_SIZE': assert any(f == 'check_data_too_long' for f in failures) elif expected_result == 'OP_COUNT': @@ -496,16 +495,14 @@ def process_testcase( clean_contexts() if scriptSig: - set_script_body(scriptSig) with FreshEnv() as env: + set_script_body(scriptSig) common_env_settings(env, flags) env.is_incomplete_script = True print("Sym-exec SSig") - bsst.g_is_in_processing = True bsst.symex_script() - bsst.g_is_in_processing = False bsst.report() sys.stdout.flush() diff --git a/tests/test_scripts.py b/tests/test_scripts.py index 383f590..205d541 100755 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -1,26 +1,17 @@ #!/usr/bin/env python3 import os -import sys import struct from contextlib import contextmanager from typing import Generator, Iterable -from io import StringIO from bitcointx.core.key import CKey, XOnlyPubKey import bsst -# pylama:ignore=E501 - +from test_util import CaptureStdout -@contextmanager -def CaptureStdout() -> Generator[StringIO, None, None]: - save_stdout = sys.stdout - out = StringIO() - sys.stdout = out - yield out - sys.stdout = save_stdout +# pylama:ignore=E501 @contextmanager @@ -91,13 +82,10 @@ def do_test_single(script: str, *, assume_no_160bit_hash_collisions=assume_no_160bit_hash_collisions, nullfail_flag=nullfail_flag ) as env: - (bsst.g_script_body, - bsst.g_line_no_table, - bsst.g_var_save_positions) = bsst.get_opcodes(script.split('\n')) - bsst.g_is_in_processing = True + env.script_info = bsst.get_opcodes(script.split('\n')) + bsst.symex_script() - bsst.g_is_in_processing = False bsst.report() process_contexts(env) @@ -107,7 +95,7 @@ def do_test_single(script: str, *, assert len(invalid_contexts) == 0 else: assert 
import sys

from io import StringIO
from typing import Generator
from contextlib import contextmanager


@contextmanager
def CaptureStdout() -> Generator[StringIO, None, None]:
    """Redirect `sys.stdout` into an in-memory buffer for a `with` block.

    Yields:
        The `StringIO` that receives everything written to `sys.stdout`
        while the block runs; read it with `getvalue()`.

    The previous `sys.stdout` is put back when the block exits, including
    when the block raises, so that a failing test does not leave output
    redirected for the rest of the process.
    """
    save_stdout = sys.stdout
    out = StringIO()
    sys.stdout = out
    try:
        yield out
    finally:
        # Restore unconditionally: without this, an exception inside the
        # block would leave sys.stdout pointing at the buffer.
        sys.stdout = save_stdout
output: - bsst.g_is_in_processing = True bsst.symex_script() - bsst.g_is_in_processing = False bsst.report() out = output.getvalue()