From 89b1cfb5d061fbaed4dbc155196c082dbfd73523 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 13 Nov 2023 15:46:45 -0500 Subject: [PATCH] more --- clvm/CLVMObject.py | 4 ++- clvm/SExp.py | 14 ++++---- clvm/as_python.py | 22 +++++++++---- clvm/core_ops.py | 6 ++-- clvm/more_ops.py | 18 +++++++---- clvm/operators.py | 65 +++++++++++++++++++++++++------------- clvm/run_program.py | 25 ++++++++------- clvm/serialize.py | 32 +++++++++++++++---- setup.py | 2 +- tests/operatordict_test.py | 16 +++++----- 10 files changed, 130 insertions(+), 74 deletions(-) diff --git a/clvm/CLVMObject.py b/clvm/CLVMObject.py index d4dfcc04..e9ea3d2e 100644 --- a/clvm/CLVMObject.py +++ b/clvm/CLVMObject.py @@ -44,7 +44,9 @@ def __new__( raise ValueError("tuples must be of size 2, cannot create CLVMObject from: %s" % str(v)) self.pair = v self.atom = None - else: + elif isinstance(v, bytes): self.atom = v self.pair = None + else: + raise ValueError(f"cannot create CLVMObject from: {v!r}") return self diff --git a/clvm/SExp.py b/clvm/SExp.py index 7875b46e..c955b6f1 100644 --- a/clvm/SExp.py +++ b/clvm/SExp.py @@ -178,18 +178,18 @@ def as_bin(self) -> bytes: return f.getvalue() @classmethod - def to(class_, v: CastableType) -> "SExp": - if isinstance(v, class_): + def to(cls: typing.Type[_T_SExp], v: CastableType) -> _T_SExp: + if isinstance(v, cls): return v if looks_like_clvm_object(v): # TODO: maybe this can be done more cleanly - return class_(typing.cast(CLVMObjectLike, v)) + return cls(typing.cast(CLVMObjectLike, v)) # this will lazily convert elements - return class_(to_sexp_type(v)) + return cls(to_sexp_type(v)) - def cons(self: _T_SExp, right) -> _T_SExp: + def cons(self: _T_SExp, right: _T_SExp) -> _T_SExp: return self.to((self, right)) def first(self: _T_SExp) -> _T_SExp: @@ -214,9 +214,9 @@ def as_iter(self: _T_SExp) -> typing.Iterable[_T_SExp]: yield v.first() v = v.rest() - def __eq__(self, other: CastableType) -> bool: + def __eq__(self, other: object) -> bool: try: - other = self.to(other) + other = self.to(typing.cast(CastableType, other)) to_compare_stack = [(self, other)] while to_compare_stack: s1, s2 = to_compare_stack.pop() diff --git a/clvm/as_python.py b/clvm/as_python.py index 9370d21d..a72c9274 100644 --- a/clvm/as_python.py +++ b/clvm/as_python.py @@ -1,19 +1,27 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Callable, List, Tuple, TYPE_CHECKING, Union if TYPE_CHECKING: from clvm.SExp import SExp +OpCallable = Callable[["OpStackType", "ValStackType"], None] + +ValStackType = List[SExp] +OpStackType = List[OpCallable] + +# TODO: hum... +PythonType = Union[int, bytes, str, List["PythonType"], Tuple["PythonType", "PythonType"]] + def as_python(sexp: SExp): - def _roll(op_stack, val_stack): + def _roll(op_stack: OpStackType, value_stack: ValStackType) -> None: v1 = val_stack.pop() v2 = val_stack.pop() val_stack.append(v1) val_stack.append(v2) - def _make_tuple(op_stack, val_stack): + def _make_tuple(op_stack: OpStackType, value_stack: ValStackType) -> None: left = val_stack.pop() right = val_stack.pop() if right == b"": @@ -24,7 +32,7 @@ def _make_tuple(op_stack, val_stack): else: val_stack.append((left, right)) - def _as_python(op_stack, val_stack): + def _as_python(op_stack: OpStackType, value_stack: ValStackType) -> None: t = val_stack.pop() pair = t.as_pair() if pair: @@ -36,10 +44,10 @@ def _as_python(op_stack, val_stack): val_stack.append(left) val_stack.append(right) else: - val_stack.append(t.as_atom()) + val_stack.append(t.atom) - op_stack = [_as_python] - val_stack = [sexp] + op_stack: OpStackType = [_as_python] + val_stack: ValStackType = [sexp] while op_stack: op_f = op_stack.pop() op_f(op_stack, val_stack) diff --git a/clvm/core_ops.py b/clvm/core_ops.py index 3a114125..d8a69c09 100644 --- a/clvm/core_ops.py +++ b/clvm/core_ops.py @@ -46,7 +46,7 @@ def op_rest(args: _T_SExp) -> Tuple[int, _T_SExp]: return REST_COST, args.first().rest() -def op_listp(args: _T_SExp) -> Tuple[int, _T_SExp]: +def op_listp(args: _T_SExp) -> Tuple[int, SExp]: if args.list_len() != 1: raise EvalError("l takes exactly 1 argument", args) return LISTP_COST, args.true if args.first().listp() else args.false @@ -59,7 +59,7 @@ def op_raise(args: _T_SExp) -> Tuple[int, _T_SExp]: raise EvalError("clvm raise", args) -def op_eq(args: _T_SExp) -> Tuple[int, _T_SExp]: +def op_eq(args: _T_SExp) -> Tuple[int, SExp]: if args.list_len() != 2: raise EvalError("= takes exactly 2 arguments", args) a0 = args.first() @@ -67,7 +67,9 @@ def op_eq(args: _T_SExp) -> Tuple[int, _T_SExp]: if a0.pair or a1.pair: raise EvalError("= on list", a0 if a0.pair else a1) b0 = a0.as_atom() + assert b0 is not None b1 = a1.as_atom() + assert b1 is not None cost = EQ_BASE_COST cost += (len(b0) + len(b1)) * EQ_COST_PER_BYTE return cost, (args.true if b0 == b1 else args.false) diff --git a/clvm/more_ops.py b/clvm/more_ops.py index 55690ba3..e6b48b7e 100644 --- a/clvm/more_ops.py +++ b/clvm/more_ops.py @@ -224,7 +224,7 @@ def op_gr_bytes(args: SExp) -> typing.Tuple[int, SExp]: return cost, args.true if b0 > b1 else args.false -def op_pubkey_for_exp(args: _T_SExp) -> typing.Tuple[_T_SExp, _T_SExp]: +def op_pubkey_for_exp(args: _T_SExp) -> typing.Tuple[int, _T_SExp]: ((i0, l0),) = args_as_int_list("pubkey_for_exp", args, 1) i0 %= 0x73EDA753299D7D483339D80809A1D80553BDA402FFFE5BFEFFFFFFFF00000001 exponent = PrivateKey.from_bytes(i0.to_bytes(32, "big")) @@ -258,7 +258,8 @@ def op_strlen(args: _T_SExp) -> typing.Tuple[int, _T_SExp]: a0 = args.first() if a0.pair: raise EvalError("strlen on list", a0) - size = len(a0.as_atom()) + assert a0.atom is not None + size = len(a0.atom) cost = STRLEN_BASE_COST + size * STRLEN_COST_PER_BYTE return malloc_cost(cost, args.to(size)) @@ -272,6 +273,7 @@ def op_substr(args: _T_SExp) -> typing.Tuple[int, _T_SExp]: raise EvalError("substr on list", a0) s0 = a0.as_atom() + assert s0 is not None if arg_count == 2: i1, = list(args_as_int32("substr", args.rest())) @@ -292,7 +294,8 @@ def op_concat(args: _T_SExp) -> typing.Tuple[int, _T_SExp]: for arg in args.as_iter(): if arg.pair: raise EvalError("concat on list", arg) - s.write(arg.as_atom()) + assert arg.atom is not None + s.write(arg.atom) cost += CONCAT_COST_PER_ARG r = s.getvalue() cost += len(r) * CONCAT_COST_PER_BYTE @@ -322,6 +325,7 @@ def op_lsh(args: _T_SExp) -> typing.Tuple[int, _T_SExp]: raise EvalError("shift too large", args.to(i1)) # we actually want i0 to be an *unsigned* int a0 = args.first().as_atom() + assert a0 is not None i0 = int.from_bytes(a0, "big", signed=False) if i1 >= 0: r = i0 << i1 @@ -350,7 +354,7 @@ def binop_reduction( def op_logand(args: _T_SExp) -> typing.Tuple[int, _T_SExp]: - def binop(a, b): + def binop(a: int, b: int) -> int: a &= b return a @@ -358,7 +362,7 @@ def binop(a, b): def op_logior(args: _T_SExp) -> typing.Tuple[int, _T_SExp]: - def binop(a, b): + def binop(a: int, b: int) -> int: a |= b return a @@ -366,7 +370,7 @@ def binop(a, b): def op_logxor(args: _T_SExp) -> typing.Tuple[int, _T_SExp]: - def binop(a, b): + def binop(a: int, b: int) -> int: a ^= b return a @@ -411,7 +415,7 @@ def op_all(args: _T_SExp) -> typing.Tuple[int, _T_SExp]: return cost, args.to(r) -def op_softfork(args: SExp) -> typing.Tuple[int, bool]: +def op_softfork(args: SExp) -> typing.Tuple[int, SExp]: if args.list_len() < 1: raise EvalError("softfork takes at least 1 argument", args) a = args.first() diff --git a/clvm/operators.py b/clvm/operators.py index 4bdef192..c1ce92d9 100644 --- a/clvm/operators.py +++ b/clvm/operators.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Tuple, Type, TypeVar +from typing import Dict, Iterator, Optional, Tuple, Type, TypeVar from typing_extensions import Protocol @@ -28,22 +28,16 @@ KEYWORDS = ( # core opcodes 0x01-x08 ". q a i c f r l x " - # opcodes on atoms as strings 0x09-0x0f "= >s sha256 substr strlen concat . " - # opcodes on atoms as ints 0x10-0x17 "+ - * / divmod > ash lsh " - # opcodes on atoms as vectors of bools 0x18-0x1c "logand logior logxor lognot . " - # opcodes for bls 1381 0x1d-0x1f "point_add pubkey_for_exp . " - # bool opcodes 0x20-0x23 "not any all . " - # misc 0x24 "softfork " ).split() @@ -68,11 +62,12 @@ } -def args_len(op_name, args: SExp): +def args_len(op_name: str, args: SExp) -> Iterator[int]: for arg in args.as_iter(): if arg.pair: raise EvalError("%s requires int args" % op_name, arg) - yield len(arg.as_atom()) + assert arg.atom is not None + yield len(arg.atom) # unknown ops are reserved if they start with 0xffff @@ -102,6 +97,7 @@ def args_len(op_name, args: SExp): # this means that unknown ops where cost_function is 1, 2, or 3, may still be # fatal errors if the arguments passed are not atoms. + def default_unknown_op(op: bytes, args: SExp) -> Tuple[int, SExp]: # any opcode starting with ffff is reserved (i.e. fatal error) # opcodes are not allowed to be empty @@ -158,6 +154,7 @@ def default_unknown_op(op: bytes, args: SExp) -> Tuple[int, SExp]: if arg.pair: raise EvalError("unknown op on list", arg) cost += CONCAT_COST_PER_ARG + assert arg.atom is not None length += len(arg.atom) cost += length * CONCAT_COST_PER_BYTE @@ -169,7 +166,13 @@ def default_unknown_op(op: bytes, args: SExp) -> Tuple[int, SExp]: class OperatorProtocol(Protocol): - def __call__(self, op: bytes, args: SExp) -> Tuple[int, SExp]: ... + def __call__(self, args: SExp) -> Tuple[int, SExp]: + ... + + +class UnknownOperatorProtocol(Protocol): + def __call__(self, op: bytes, args: SExp) -> Tuple[int, SExp]: + ... _T_OperatorDict = TypeVar("_T_OperatorDict", bound="OperatorDict") @@ -181,12 +184,21 @@ class OperatorDict(Dict[bytes, OperatorProtocol]): operators can be added dynamically. """ - unknown_op_handler: OperatorProtocol - quote_atom: int - apply_atom: int - - # TODO: how do you create an instance if that requires passing in an instance? - def __new__(cls: Type[_T_OperatorDict], d: Dict[bytes, OperatorProtocol], *args: object, **kwargs) -> _T_OperatorDict: + unknown_op_handler: UnknownOperatorProtocol + quote_atom: bytes + apply_atom: bytes + + # TODO: can we remove the args and kwargs? + # TODO: hint the overloads + def __new__( + cls: Type[_T_OperatorDict], + d: Dict[bytes, OperatorProtocol], + *args: object, + quote: Optional[bytes] = None, + apply: Optional[bytes] = None, + unknown_op_handler: UnknownOperatorProtocol = default_unknown_op, + **kwargs: object, + ) -> _T_OperatorDict: """ `quote_atom` and `apply_atom` must be set `unknown_op_handler` has a default implementation @@ -194,12 +206,19 @@ def __new__(cls: Type[_T_OperatorDict], d: Dict[bytes, OperatorProtocol], *args: We do not check if the opcode values for quote and apply exist in the passed-in dict """ self = super().__new__(cls, d) - self.quote_atom = kwargs["quote"] if "quote" in kwargs else d.quote_atom - self.apply_atom = kwargs["apply"] if "apply" in kwargs else d.apply_atom - if "unknown_op_handler" in kwargs: - self.unknown_op_handler = kwargs["unknown_op_handler"] + + if quote is None: + assert isinstance(d, OperatorDict) + self.quote_atom = d.quote_atom else: - self.unknown_op_handler = default_unknown_op + self.quote_atom = quote + + if apply is None: + assert isinstance(d, OperatorDict) + self.apply_atom = d.apply_atom + else: + self.apply_atom = apply + return self def __call__(self, op: bytes, arguments: SExp) -> Tuple[int, SExp]: @@ -214,6 +233,8 @@ def __call__(self, op: bytes, arguments: SExp) -> Tuple[int, SExp]: APPLY_ATOM = KEYWORD_TO_ATOM["a"] OPERATOR_LOOKUP = OperatorDict( - operators_for_module(KEYWORD_TO_ATOM, core_ops, OP_REWRITE), quote=QUOTE_ATOM, apply=APPLY_ATOM + operators_for_module(KEYWORD_TO_ATOM, core_ops, OP_REWRITE), + quote=QUOTE_ATOM, + apply=APPLY_ATOM, ) OPERATOR_LOOKUP.update(operators_for_module(KEYWORD_TO_ATOM, more_ops, OP_REWRITE)) diff --git a/clvm/run_program.py b/clvm/run_program.py index a8b5fac0..5db7a00b 100644 --- a/clvm/run_program.py +++ b/clvm/run_program.py @@ -1,8 +1,8 @@ -from typing import Any, Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple from .CLVMObject import CLVMObject from .EvalError import EvalError -from .SExp import SExp +from .SExp import CastableType, SExp from .operators import OperatorDict from .costs import ( @@ -13,16 +13,14 @@ PATH_LOOKUP_COST_PER_ZERO_BYTE ) -# the "Any" below should really be "OpStackType" but -# recursive types aren't supported by mypy - -OpCallable = Callable[[Any, "ValStackType"], int] +OpCallable = Callable[["OpStackType", "ValStackType"], int] +PreOpCallable = Callable[["OpStackType", "ValStackType"], None] ValStackType = List[SExp] OpStackType = List[OpCallable] -def to_pre_eval_op(pre_eval_f, to_sexp_f) -> Callable[[OpStackType, ValStackType], None]: +def to_pre_eval_op(pre_eval_f: Callable[[SExp, SExp], Optional[Callable[[SExp], object]]], to_sexp_f: Callable[[CastableType], SExp]) -> PreOpCallable: def my_pre_eval_op(op_stack: OpStackType, value_stack: ValStackType) -> None: v = to_sexp_f(value_stack[-1]) context = pre_eval_f(v.first(), v.rest()) @@ -39,7 +37,7 @@ def invoke_context_op( return my_pre_eval_op -def msb_mask(byte): +def msb_mask(byte: int) -> int: byte |= byte >> 1 byte |= byte >> 2 byte |= byte >> 4 @@ -48,14 +46,14 @@ def msb_mask(byte): def run_program( program: CLVMObject, - args: CLVMObject, + args: SExp, operator_lookup: OperatorDict, max_cost: Optional[int] = None, - pre_eval_f=None, + pre_eval_f: Optional[PreOpCallable] = None, ) -> Tuple[int, SExp]: _program = SExp.to(program) - if pre_eval_f: + if pre_eval_f is not None: pre_eval_op = to_pre_eval_op(pre_eval_f, _program.to) else: pre_eval_op = None @@ -129,7 +127,9 @@ def eval_op(op_stack: OpStackType, value_stack: ValStackType) -> int: operator = sexp.first() if operator.pair: - new_operator, must_be_nil = operator.as_pair() + from_as_pair = operator.as_pair() + assert from_as_pair is not None + new_operator, must_be_nil = from_as_pair if new_operator.pair or must_be_nil.atom != b"": raise EvalError("in ((X)...) syntax X must be lone atom", sexp) new_operand_list = sexp.rest() @@ -163,6 +163,7 @@ def apply_op(op_stack: OpStackType, value_stack: ValStackType) -> int: raise EvalError("internal error", operator) op = operator.as_atom() + assert op is not None if op == operator_lookup.apply_atom: if operand_list.list_len() != 2: raise EvalError("apply requires exactly 2 parameters", operand_list) diff --git a/clvm/serialize.py b/clvm/serialize.py index aec94141..04fb4a6d 100644 --- a/clvm/serialize.py +++ b/clvm/serialize.py @@ -31,6 +31,14 @@ T = typing.TypeVar("T") +OpCallable = typing.Callable[ + ["OpStackType", "ValStackType", typing.BinaryIO, typing.Type], None +] + +ValStackType = typing.List[SExp] +OpStackType = typing.List[OpCallable] + + def sexp_to_byte_iterator(sexp: SExp) -> typing.Iterator[bytes]: todo_stack = [sexp] while todo_stack: @@ -41,7 +49,8 @@ def sexp_to_byte_iterator(sexp: SExp) -> typing.Iterator[bytes]: todo_stack.append(pair[1]) todo_stack.append(pair[0]) else: - yield from atom_to_byte_iterator(sexp.as_atom()) + assert sexp.atom is not None + yield from atom_to_byte_iterator(sexp.atom) def atom_to_byte_iterator(as_atom: bytes) -> typing.Iterator[bytes]: @@ -90,7 +99,9 @@ def sexp_to_stream(sexp: SExp, f: typing.BinaryIO) -> None: f.write(b) -def _op_read_sexp(op_stack, val_stack, f: typing.BinaryIO, to_sexp) -> None: +def _op_read_sexp( + op_stack: OpStackType, val_stack: ValStackType, f: typing.BinaryIO, to_sexp: typing.Callable[[bytes], SExp], +) -> None: blob = f.read(1) if len(blob) == 0: raise ValueError("bad encoding") @@ -103,15 +114,20 @@ def _op_read_sexp(op_stack, val_stack, f: typing.BinaryIO, to_sexp) -> None: val_stack.append(_atom_from_stream(f, b, to_sexp)) -def _op_cons(op_stack, val_stack, f: typing.BinaryIO, to_sexp) -> None: +def _op_cons( + op_stack: OpStackType, + val_stack: ValStackType, + f: typing.BinaryIO, + to_sexp: typing.Callable[[typing.Tuple[SExp, SExp]], SExp], +) -> None: right = val_stack.pop() left = val_stack.pop() val_stack.append(to_sexp((left, right))) -def sexp_from_stream(f: typing.BinaryIO, to_sexp: typing.Callable[..., T]) -> T: - op_stack = [_op_read_sexp] - val_stack = [] +def sexp_from_stream(f: typing.BinaryIO, to_sexp: typing.Callable[[SExp], T]) -> T: + op_stack: OpStackType = [_op_read_sexp] + val_stack: ValStackType = [] while op_stack: func = op_stack.pop() @@ -171,7 +187,9 @@ def sexp_buffer_from_stream(f: typing.BinaryIO) -> bytes: return ret.getvalue() -def _atom_from_stream(f: typing.BinaryIO, b: int, to_sexp: typing.Callable[..., T]) -> T: +def _atom_from_stream( + f: typing.BinaryIO, b: int, to_sexp: typing.Callable[[bytes], T] +) -> T: if b == 0x80: return to_sexp(b"") if b <= MAX_SINGLE_BYTE: diff --git a/setup.py b/setup.py index 56118da1..3b5e0240 100755 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ dependencies = [ "blspy>=0.9", - "typing-extensions~=4.0.0", # Backports of new typing module features + "typing-extensions~=4.0", # Backports of new typing module features ] dev_dependencies = [ diff --git a/tests/operatordict_test.py b/tests/operatordict_test.py index cee1bfe6..89a60950 100644 --- a/tests/operatordict_test.py +++ b/tests/operatordict_test.py @@ -12,21 +12,21 @@ def test_operatordict_constructor(self) -> None: Note that they cannot be specified in the operator dictionary itself. """ # ignoring because apparently it doesn't matter for this test that the types are all wrong - d: Dict[bytes, OperatorProtocol] = {1: "hello", 2: "goodbye"} # type: ignore [dict-item] + d: Dict[bytes, OperatorProtocol] = {b"\01": "hello", b"\02": "goodbye"} # type: ignore [dict-item] with self.assertRaises(AttributeError): o = OperatorDict(d) with self.assertRaises(AttributeError): - o = OperatorDict(d, apply=1) + o = OperatorDict(d, apply=b"\01") with self.assertRaises(AttributeError): - o = OperatorDict(d, quote=1) - o = OperatorDict(d, apply=1, quote=2) + o = OperatorDict(d, quote=b"\01") + o = OperatorDict(d, apply=b"\01", quote=b"\02") print(o) # Why does the constructed Operator dict contain entries for "apply":1 and "quote":2 ? # assert d == o - self.assertEqual(o.apply_atom, 1) - self.assertEqual(o.quote_atom, 2) + self.assertEqual(o.apply_atom, b"\01") + self.assertEqual(o.quote_atom, b"\02") # Test construction from an already existing OperatorDict o2 = OperatorDict(o) - self.assertEqual(o2.apply_atom, 1) - self.assertEqual(o2.quote_atom, 2) + self.assertEqual(o2.apply_atom, b"\01") + self.assertEqual(o2.quote_atom, b"\02")