diff --git a/src/halmos/calldata.py b/src/halmos/calldata.py index 5dd28dd0..dcfdc361 100644 --- a/src/halmos/calldata.py +++ b/src/halmos/calldata.py @@ -103,19 +103,10 @@ class FunctionInfo: class Calldata: - args: HalmosConfig - - # `arrlen` holds the parsed value of --array-lengths, specifying the sizes for certain dynamic parameters. - # `default_bytes_lengths` holds the parsed value of --default-bytes-lengths. - # # For dynamic parameters not explicitly listed in --array-lengths, default sizes are used: # - For dynamic arrays: the size ranges from 0 to the value of --loop (inclusive). # - For bytes or strings: the size candidates are given by --default-bytes-lengths. - # - # TODO: Extend `args` to include `arrlen` and `default_bytes_lengths` - # to prevent re-parsing --array-lengths and --default-bytes-lengths multiple times. - arrlen: dict[str, list[int]] - default_bytes_lengths: list[int] + args: HalmosConfig # `dyn_params` will be updated to include the fully resolved size information for all dynamic parameters. dyn_params: list[DynamicParam] @@ -130,10 +121,6 @@ def __init__( new_symbol_id: Callable | None, ) -> None: self.args = args - self.arrlen = mk_arrlen(args) - self.default_bytes_lengths = [ - int(x.strip()) for x in self.args.default_bytes_lengths.split(",") - ] self.dyn_params = [] self.new_symbol_id = new_symbol_id if new_symbol_id else lambda: "" @@ -144,13 +131,13 @@ def get_dyn_sizes(self, name: str, typ: Type) -> tuple[list[int], BitVecRef]: The candidates are derived from --array_lengths if provided; otherwise, default values are used. """ - sizes = self.arrlen.get(name) + sizes = self.args.array_lengths.get(name) if sizes is None: sizes = ( list(range(self.args.loop + 1)) if isinstance(typ, DynamicArrayType) - else self.default_bytes_lengths # bytes or string + else self.args.default_bytes_lengths # bytes or string ) if self.args.debug: print( @@ -345,15 +332,3 @@ def mk_calldata( new_symbol_id: Callable = None, ) -> tuple[ByteVec, list[DynamicParam]]: return Calldata(args, new_symbol_id).create(abi, fun_info) - - -def mk_arrlen(args: HalmosConfig) -> dict[str, list[int]]: - if not args.array_lengths: - return {} - - # TODO: update syntax: name1=size1,size2; name2=size3,...; ... - name_sizes_pairs = args.array_lengths.split(",") - return { - name.strip(): [int(x.strip()) for x in sizes.split(";")] - for name, sizes in [x.split("=") for x in name_sizes_pairs] - } diff --git a/src/halmos/config.py b/src/halmos/config.py index 738fdba6..fad3b0d3 100644 --- a/src/halmos/config.py +++ b/src/halmos/config.py @@ -2,6 +2,7 @@ import os import sys from collections import OrderedDict +from collections.abc import Callable from dataclasses import MISSING, dataclass, fields from dataclasses import field as dataclass_field from typing import Any @@ -33,6 +34,7 @@ def arg( short: str | None = None, countable: bool = False, global_default_str: str | None = None, + action: Callable = None, ): return dataclass_field( default=None, @@ -45,10 +47,39 @@ def arg( "short": short, "countable": countable, "global_default_str": global_default_str, + "action": action, }, ) +class ParseCSV(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + values = ParseCSV.parse(values) + setattr(namespace, self.dest, values) + + @staticmethod + def parse(values: str) -> list[int]: + return [int(x.strip()) for x in values.split(",")] + + +class ParseArrayLengths(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + values = ParseArrayLengths.parse(values) + setattr(namespace, self.dest, values) + + @staticmethod + def parse(values: str | None) -> dict[str, list[int]]: + if not values: + return {} + + # TODO: update syntax: name1=size1,size2; name2=size3,...; ... + name_sizes_pairs = values.split(",") + return { + name.strip(): [int(x.strip()) for x in sizes.split(";")] + for name, sizes in [x.split("=") for x in name_sizes_pairs] + } + + # TODO: add kw_only=True when we support Python>=3.10 @dataclass(frozen=True) class Config: @@ -152,12 +183,14 @@ class Config: help="set the length of dynamic-sized arrays including bytes and string (default: loop unrolling bound)", global_default=None, metavar="NAME1=LENGTH1,NAME2=LENGTH2,...", + action=ParseArrayLengths, ) default_bytes_lengths: str = arg( help="set the default length candidates for bytes and string not specified in --array-lengths", global_default="0,32,1024,65", # 65 is ECDSA signature size metavar="LENGTH1,LENGTH2,...", + action=ParseCSV, ) storage_layout: str = arg( @@ -528,7 +561,12 @@ def _create_default_config() -> "Config": if default == MISSING: continue - values[field.name] = default() if callable(default) else default + # retrieve the default value + raw_value = default() if callable(default) else default + + # parse the default value, if a custom parser is provided + action = field.metadata.get("action", None) + values[field.name] = action.parse(raw_value) if action else raw_value return Config(_parent=None, _source="default", **values) @@ -583,6 +621,8 @@ def _create_arg_parser() -> argparse.ArgumentParser: } if choices := field_info.metadata.get("choices", None): kwargs["choices"] = choices + if action := field_info.metadata.get("action", None): + kwargs["action"] = action group.add_argument(*names, **kwargs) return parser