Skip to content

Commit

Permalink
perf: support for custom action in config parser (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark committed Sep 16, 2024
1 parent d1dea3f commit 1e850a6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 29 deletions.
31 changes: 3 additions & 28 deletions src/halmos/calldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,10 @@ class FunctionInfo:


class Calldata:
args: HalmosConfig

# `arrlen` holds the parsed value of --array-lengths, specifying the sizes for certain dynamic parameters.
# `default_bytes_lengths` holds the parsed value of --default-bytes-lengths.
#
# For dynamic parameters not explicitly listed in --array-lengths, default sizes are used:
# - For dynamic arrays: the size ranges from 0 to the value of --loop (inclusive).
# - For bytes or strings: the size candidates are given by --default-bytes-lengths.
#
# TODO: Extend `args` to include `arrlen` and `default_bytes_lengths`
# to prevent re-parsing --array-lengths and --default-bytes-lengths multiple times.
arrlen: dict[str, list[int]]
default_bytes_lengths: list[int]
args: HalmosConfig

# `dyn_params` will be updated to include the fully resolved size information for all dynamic parameters.
dyn_params: list[DynamicParam]
Expand All @@ -130,10 +121,6 @@ def __init__(
new_symbol_id: Callable | None,
) -> None:
self.args = args
self.arrlen = mk_arrlen(args)
self.default_bytes_lengths = [
int(x.strip()) for x in self.args.default_bytes_lengths.split(",")
]
self.dyn_params = []
self.new_symbol_id = new_symbol_id if new_symbol_id else lambda: ""

Expand All @@ -144,13 +131,13 @@ def get_dyn_sizes(self, name: str, typ: Type) -> tuple[list[int], BitVecRef]:
The candidates are derived from --array_lengths if provided; otherwise, default values are used.
"""

sizes = self.arrlen.get(name)
sizes = self.args.array_lengths.get(name)

if sizes is None:
sizes = (
list(range(self.args.loop + 1))
if isinstance(typ, DynamicArrayType)
else self.default_bytes_lengths # bytes or string
else self.args.default_bytes_lengths # bytes or string
)
if self.args.debug:
print(
Expand Down Expand Up @@ -345,15 +332,3 @@ def mk_calldata(
new_symbol_id: Callable = None,
) -> tuple[ByteVec, list[DynamicParam]]:
return Calldata(args, new_symbol_id).create(abi, fun_info)


def mk_arrlen(args: HalmosConfig) -> dict[str, list[int]]:
if not args.array_lengths:
return {}

# TODO: update syntax: name1=size1,size2; name2=size3,...; ...
name_sizes_pairs = args.array_lengths.split(",")
return {
name.strip(): [int(x.strip()) for x in sizes.split(";")]
for name, sizes in [x.split("=") for x in name_sizes_pairs]
}
42 changes: 41 additions & 1 deletion src/halmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import MISSING, dataclass, fields
from dataclasses import field as dataclass_field
from typing import Any
Expand Down Expand Up @@ -33,6 +34,7 @@ def arg(
short: str | None = None,
countable: bool = False,
global_default_str: str | None = None,
action: Callable = None,
):
return dataclass_field(
default=None,
Expand All @@ -45,10 +47,39 @@ def arg(
"short": short,
"countable": countable,
"global_default_str": global_default_str,
"action": action,
},
)


class ParseCSV(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
values = ParseCSV.parse(values)
setattr(namespace, self.dest, values)

@staticmethod
def parse(values: str) -> list[int]:
return [int(x.strip()) for x in values.split(",")]


class ParseArrayLengths(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
values = ParseArrayLengths.parse(values)
setattr(namespace, self.dest, values)

@staticmethod
def parse(values: str | None) -> dict[str, list[int]]:
if not values:
return {}

# TODO: update syntax: name1=size1,size2; name2=size3,...; ...
name_sizes_pairs = values.split(",")
return {
name.strip(): [int(x.strip()) for x in sizes.split(";")]
for name, sizes in [x.split("=") for x in name_sizes_pairs]
}


# TODO: add kw_only=True when we support Python>=3.10
@dataclass(frozen=True)
class Config:
Expand Down Expand Up @@ -152,12 +183,14 @@ class Config:
help="set the length of dynamic-sized arrays including bytes and string (default: loop unrolling bound)",
global_default=None,
metavar="NAME1=LENGTH1,NAME2=LENGTH2,...",
action=ParseArrayLengths,
)

default_bytes_lengths: str = arg(
help="set the default length candidates for bytes and string not specified in --array-lengths",
global_default="0,32,1024,65", # 65 is ECDSA signature size
metavar="LENGTH1,LENGTH2,...",
action=ParseCSV,
)

storage_layout: str = arg(
Expand Down Expand Up @@ -528,7 +561,12 @@ def _create_default_config() -> "Config":
if default == MISSING:
continue

values[field.name] = default() if callable(default) else default
# retrieve the default value
raw_value = default() if callable(default) else default

# parse the default value, if a custom parser is provided
action = field.metadata.get("action", None)
values[field.name] = action.parse(raw_value) if action else raw_value

return Config(_parent=None, _source="default", **values)

Expand Down Expand Up @@ -583,6 +621,8 @@ def _create_arg_parser() -> argparse.ArgumentParser:
}
if choices := field_info.metadata.get("choices", None):
kwargs["choices"] = choices
if action := field_info.metadata.get("action", None):
kwargs["action"] = action
group.add_argument(*names, **kwargs)

return parser
Expand Down

0 comments on commit 1e850a6

Please sign in to comment.