From bc97c96279f28756f0fe214dd1794240637ceed0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 7 Aug 2024 01:00:03 +0200 Subject: [PATCH] feat: cythonize inspect signature --- build.py | 6 +- koerce/patterns.py | 20 +- koerce/sugar.py | 3 +- koerce/tests/test_sugar.py | 42 ++-- koerce/tests/test_utils.py | 418 ++++++++++++++++++++++++++++++++++++- koerce/utils.py | 203 +++++++++++++++++- 6 files changed, 643 insertions(+), 49 deletions(-) diff --git a/build.py b/build.py index b6e5835..b25514a 100644 --- a/build.py +++ b/build.py @@ -25,6 +25,10 @@ "koerce.patterns", ["koerce/patterns.py"], ), + Extension( + "koerce.utils", + ["koerce/utils.py"], + ), ], build_dir=BUILD_DIR, # generate anannotated .html output files. @@ -41,7 +45,7 @@ # "annotation_typing": False }, # always rebuild, even if files untouched - force=True, + force=False, # emit_linenums=True ) diff --git a/koerce/patterns.py b/koerce/patterns.py index d24df22..77e8b05 100644 --- a/koerce/patterns.py +++ b/koerce/patterns.py @@ -3,7 +3,6 @@ import importlib from collections.abc import Callable, Mapping, Sequence from enum import Enum -from inspect import Parameter from types import UnionType from typing import ( Annotated, @@ -18,8 +17,11 @@ import cython from typing_extensions import GenericMeta, get_original_bases +# TODO(kszucs): would be nice to cimport Signature and Builder from .builders import Builder, Deferred, Variable, builder from .utils import ( + EMPTY, + Parameter, RewindableIterator, Signature, get_type_args, @@ -209,17 +211,17 @@ def from_callable( args = {} elif isinstance(args, (list, tuple)): # create a mapping of parameter name to pattern - args = dict(zip(sig.parameters.keys(), args)) + args = {param.name: arg for param, arg in zip(sig.parameters, args)} elif not isinstance(args, dict): raise TypeError(f"patterns must be a list or dict, got {type(args)}") retpat: Pattern argpat: Pattern argpats: dict[str, Pattern] = {} - for param in sig.parameters.values(): + for param in sig.parameters: name: str = param.name - kind = param.kind - default = param.default + kind: int = param.kind + default = param.default_ typehint = typehints.get(name) if name in args: @@ -233,7 +235,7 @@ def from_callable( argpat = TupleOf(argpat) elif kind is Parameter.VAR_KEYWORD: argpat = DictOf(_any, argpat) - elif default is not Parameter.empty: + elif default is not EMPTY: argpat = Option(argpat, default=default) argpats[name] = argpat @@ -1615,12 +1617,12 @@ def match(self, value, ctx: Context): has_varargs: bool = False positional: list = [] required_positional: list = [] - for p in sig.parameters.values(): + for p in sig.parameters: if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): positional.append(p) - if p.default is Parameter.empty: + if p.default_ is EMPTY: required_positional.append(p) - elif p.kind is Parameter.KEYWORD_ONLY and p.default is Parameter.empty: + elif p.kind is Parameter.KEYWORD_ONLY and p.default_ is EMPTY: raise TypeError( "Callable has mandatory keyword-only arguments which cannot be specified" ) diff --git a/koerce/sugar.py b/koerce/sugar.py index 4b48a11..09d41b9 100644 --- a/koerce/sugar.py +++ b/koerce/sugar.py @@ -174,10 +174,9 @@ def annotated(_1=None, _2=None, _3=None, **kwargs): def wrapped(*args, **kwargs): # 0. Bind the arguments to the signature bound = sig.bind(*args, **kwargs) - bound.apply_defaults() # 1. Validate the passed arguments - values = argpats.apply(bound.arguments) + values = argpats.apply(bound) if values is NoMatch: raise ValidationError() diff --git a/koerce/tests/test_sugar.py b/koerce/tests/test_sugar.py index 5aa9010..e9dea66 100644 --- a/koerce/tests/test_sugar.py +++ b/koerce/tests/test_sugar.py @@ -39,11 +39,10 @@ def test(a: int, b: int, c: int = 1): ... sig = Signature.from_callable(test) bound = sig.bind(2, 3) - bound.apply_defaults() - assert bound.arguments == {"a": 2, "b": 3, "c": 1} + assert bound == {"a": 2, "b": 3, "c": 1} - args, kwargs = sig.unbind(bound.arguments) + args, kwargs = sig.unbind(bound) assert args == (2, 3, 1) assert kwargs == {} @@ -53,17 +52,15 @@ def test(a: int, b: int, *args: int): ... sig = Signature.from_callable(test) bound = sig.bind(2, 3) - bound.apply_defaults() - assert bound.arguments == {"a": 2, "b": 3, "args": ()} - args, kwargs = sig.unbind(bound.arguments) + assert bound == {"a": 2, "b": 3, "args": ()} + args, kwargs = sig.unbind(bound) assert args == (2, 3) assert kwargs == {} bound = sig.bind(2, 3, 4, 5) - bound.apply_defaults() - assert bound.arguments == {"a": 2, "b": 3, "args": (4, 5)} - args, kwargs = sig.unbind(bound.arguments) + assert bound == {"a": 2, "b": 3, "args": (4, 5)} + args, kwargs = sig.unbind(bound) assert args == (2, 3, 4, 5) assert kwargs == {} @@ -73,18 +70,16 @@ def test(a: int, b: int, /, c: int = 1): ... sig = Signature.from_callable(test) bound = sig.bind(2, 3) - bound.apply_defaults() - assert bound.arguments == {"a": 2, "b": 3, "c": 1} + assert bound == {"a": 2, "b": 3, "c": 1} - args, kwargs = sig.unbind(bound.arguments) + args, kwargs = sig.unbind(bound) assert args == (2, 3, 1) assert kwargs == {} bound = sig.bind(2, 3, 4) - bound.apply_defaults() - assert bound.arguments == {"a": 2, "b": 3, "c": 4} + assert bound == {"a": 2, "b": 3, "c": 4} - args, kwargs = sig.unbind(bound.arguments) + args, kwargs = sig.unbind(bound) assert args == (2, 3, 4) assert kwargs == {} @@ -94,10 +89,9 @@ def test(a: int, b: int, *, c: float, d: float = 0.0): ... sig = Signature.from_callable(test) bound = sig.bind(2, 3, c=4.0) - bound.apply_defaults() - assert bound.arguments == {"a": 2, "b": 3, "c": 4.0, "d": 0.0} + assert bound == {"a": 2, "b": 3, "c": 4.0, "d": 0.0} - args, kwargs = sig.unbind(bound.arguments) + args, kwargs = sig.unbind(bound) assert args == (2, 3) assert kwargs == {"c": 4.0, "d": 0.0} @@ -107,11 +101,10 @@ def func(a, b, c=1): ... sig = Signature.from_callable(func) bound = sig.bind(1, 2) - bound.apply_defaults() - assert bound.arguments == {"a": 1, "b": 2, "c": 1} + assert bound == {"a": 1, "b": 2, "c": 1} - args, kwargs = sig.unbind(bound.arguments) + args, kwargs = sig.unbind(bound) assert args == (1, 2, 1) assert kwargs == {} @@ -123,10 +116,9 @@ def func(a, b, c, *args, e=None): sig = Signature.from_callable(func) bound = sig.bind(1, 2, 3, *d, e=4) - bound.apply_defaults() - assert bound.arguments == {"a": 1, "b": 2, "c": 3, "args": d, "e": 4} + assert bound == {"a": 1, "b": 2, "c": 3, "args": d, "e": 4} - args, kwargs = sig.unbind(bound.arguments) + args, kwargs = sig.unbind(bound) assert args == (1, 2, 3, *d) assert kwargs == {"e": 4} @@ -254,7 +246,7 @@ def test(a, b, c): return a, b, c assert test(1, 2, 3) == (1, 2, 3) - assert test.__signature__.parameters.keys() == {"a", "b", "c"} + assert [p.name for p in test.__signature__.parameters] == ["a", "b", "c"] # def test_annotated_function_without_decoration(snapshot): diff --git a/koerce/tests/test_utils.py b/koerce/tests/test_utils.py index 07180de..ff20e46 100644 --- a/koerce/tests/test_utils.py +++ b/koerce/tests/test_utils.py @@ -4,7 +4,14 @@ import pytest -from koerce.utils import get_type_boundvars, get_type_hints, get_type_params +from koerce.utils import ( + EMPTY, + Parameter, + Signature, + get_type_boundvars, + get_type_hints, + get_type_params, +) T = TypeVar("T", covariant=True) S = TypeVar("S", covariant=True) @@ -108,3 +115,412 @@ def test_get_type_boundvars_unable_to_deduce() -> None: msg = "Unable to deduce corresponding type attributes..." with pytest.raises(ValueError, match=msg): get_type_boundvars(MyDict[int, str]) + + +def test_parameter(): + p = Parameter("x", Parameter.POSITIONAL_OR_KEYWORD, annotation=int) + assert p.name == "x" + assert p.kind is Parameter.POSITIONAL_OR_KEYWORD + assert str(p) == "x: int" + + p = Parameter("x", Parameter.POSITIONAL_OR_KEYWORD, default=1) + assert p.name == "x" + assert p.kind is Parameter.POSITIONAL_OR_KEYWORD + assert p.default_ == 1 + assert str(p) == "x=1" + + p = Parameter("x", Parameter.POSITIONAL_OR_KEYWORD, annotation=int, default=1) + assert p.name == "x" + assert p.kind is Parameter.POSITIONAL_OR_KEYWORD + assert p.default_ == 1 + assert p.annotation is int + assert str(p) == "x: int = 1" + + p = Parameter("y", Parameter.VAR_POSITIONAL, annotation=int) + assert p.name == "y" + assert p.kind is Parameter.VAR_POSITIONAL + assert p.annotation is int + assert str(p) == "*y: int" + + p = Parameter("z", Parameter.VAR_KEYWORD, annotation=int) + assert p.name == "z" + assert p.kind is Parameter.VAR_KEYWORD + assert p.annotation is int + assert str(p) == "**z: int" + + +def test_signature_contruction(): + a = Parameter("a", Parameter.POSITIONAL_OR_KEYWORD, annotation=int) + b = Parameter("b", Parameter.POSITIONAL_OR_KEYWORD, annotation=str) + c = Parameter("c", Parameter.POSITIONAL_OR_KEYWORD, annotation=int, default=1) + d = Parameter("d", Parameter.VAR_POSITIONAL, annotation=int) + + sig = Signature([a, b, c, d]) + assert sig.parameters == [a, b, c, d] + assert sig.return_annotation is EMPTY + + +def test_signature_from_callable(): + def func(a: int, b: str, *args, c=1, **kwargs) -> float: ... + + sig = Signature.from_callable(func) + assert sig.parameters == [ + Parameter("a", Parameter.POSITIONAL_OR_KEYWORD, annotation="int"), + Parameter("b", Parameter.POSITIONAL_OR_KEYWORD, annotation="str"), + Parameter("args", Parameter.VAR_POSITIONAL), + Parameter("c", Parameter.KEYWORD_ONLY, default=1), + Parameter("kwargs", Parameter.VAR_KEYWORD), + ] + assert sig.return_annotation == "float" + + +def test_signature_bind_various(): + # with positional or keyword default + def func(a: int, b: str, c=1) -> float: ... + + sig = Signature.from_callable(func) + bound = sig.bind(1, "2") + assert bound == {"a": 1, "b": "2", "c": 1} + + # with variable positional arguments + def func(a: int, b: str, *args: int, c=1) -> float: ... + + sig = Signature.from_callable(func) + bound = sig.bind(1, "2", 3, 4) + assert bound == {"a": 1, "b": "2", "args": (3, 4), "c": 1} + + # with both variadic positional and variadic keyword arguments + def func(a: int, b: str, *args: int, c=1, **kwargs: int) -> float: ... + + sig = Signature.from_callable(func) + bound = sig.bind(1, "2", 3, 4, x=5, y=6) + assert bound == { + "a": 1, + "b": "2", + "args": (3, 4), + "c": 1, + "kwargs": {"x": 5, "y": 6}, + } + + # with positional only arguments + def func(a: int, b: str, /, c=1) -> float: ... + + sig = Signature.from_callable(func) + bound = sig.bind(1, "2") + assert bound == {"a": 1, "b": "2", "c": 1} + + with pytest.raises(TypeError, match="passed as keyword argument"): + sig.bind(a=1, b="2", c=3) + + # with keyword only arguments + def func(a: int, b: str, *, c=1) -> float: ... + + sig = Signature.from_callable(func) + bound = sig.bind(1, "2", c=3) + assert bound == {"a": 1, "b": "2", "c": 3} + + with pytest.raises(TypeError, match="too many positional arguments"): + sig.bind(1, "2", 3) + + def func(a, *args, b, z=100, **kwargs): ... + + sig = Signature.from_callable(func) + bound = sig.bind(10, 20, b=30, c=40, args=50, kwargs=60) + assert bound == { + "a": 10, + "args": (20,), + "b": 30, + "z": 100, + "kwargs": {"c": 40, "args": 50, "kwargs": 60}, + } + + +def call(func, *args, **kwargs): + # it also tests the unbind method + sig = Signature.from_callable(func) + bound = sig.bind(*args, **kwargs) + ubargs, ubkwargs = sig.unbind(bound) + return func(*ubargs, **ubkwargs) + + +def test_signature_bind_no_arguments(): + def func(): ... + + sig = Signature.from_callable(func) + assert sig.bind() == {} + + with pytest.raises(TypeError, match="too many positional arguments"): + sig.bind(1) + with pytest.raises(TypeError, match="too many positional arguments"): + sig.bind(1, keyword=2) + with pytest.raises(TypeError, match="got an unexpected keyword argument 'keyword'"): + sig.bind(keyword=1) + + +def test_signature_bind_positional_or_keyword_arguments(): + def func(a, b, c): + return a, b, c + + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + call(func) + with pytest.raises(TypeError, match="missing a required argument: 'b'"): + call(func, 1) + with pytest.raises(TypeError, match="missing a required argument: 'c'"): + call(func, 1, 2) + assert call(func, 1, 2, 3) == (1, 2, 3) + + # one optional argument + def func(a, b, c=0): + return a, b, c + + assert call(func, 1, 2, 3) == (1, 2, 3) + assert call(func, 1, 2) == (1, 2, 0) + + # two optional arguments + def func(a, b=0, c=0): + return a, b, c + + assert call(func, 1, 2, 3) == (1, 2, 3) + assert call(func, 1, 2) == (1, 2, 0) + assert call(func, 1) == (1, 0, 0) + + # three optional arguments + def func(a=0, b=0, c=0): + return a, b, c + + assert call(func, 1, 2, 3) == (1, 2, 3) + assert call(func, 1, 2) == (1, 2, 0) + assert call(func, 1) == (1, 0, 0) + assert call(func) == (0, 0, 0) + + +def test_signature_bind_varargs(): + def func(*args): + return args + + assert call(func) == () + assert call(func, 1) == (1,) + assert call(func, 1, 2) == (1, 2) + assert call(func, 1, 2, 3) == (1, 2, 3) + + def func(a, b, c=3, *args): + return a, b, c, args + + assert call(func, 1, 2) == (1, 2, 3, ()) + assert call(func, 1, 2, 3) == (1, 2, 3, ()) + assert call(func, 1, 2, 3, 4) == (1, 2, 3, (4,)) + assert call(func, 1, 2, 3, 4, 5) == (1, 2, 3, (4, 5)) + assert call(func, 1, 2, 4) == (1, 2, 4, ()) + assert call(func, a=1, b=2, c=3) == (1, 2, 3, ()) + assert call(func, c=3, a=1, b=2) == (1, 2, 3, ()) + + with pytest.raises(TypeError, match="multiple values for argument 'c'"): + call(func, 1, 2, 3, c=4) + + def func(a, *args): + return a, args + + with pytest.raises(TypeError, match="got an unexpected keyword argument 'args'"): + call(func, a=0, args=1) + + def func(*args, **kwargs): + return args, kwargs + + assert call(func, args=1) == ((), {"args": 1}) + + sig = Signature.from_callable(func) + ba = sig.bind(args=1) + assert ba == {"args": (), "kwargs": {"args": 1}} + + +def test_signature_bind_varkwargs(): + def func(**kwargs): + return kwargs + + assert call(func) == {} + assert call(func, foo="bar") == {"foo": "bar"} + assert call(func, foo="bar", spam="ham") == {"foo": "bar", "spam": "ham"} + + def func(a, b, c=3, **kwargs): + return a, b, c, kwargs + + assert call(func, 1, 2) == (1, 2, 3, {}) + assert call(func, 1, 2, foo="bar") == (1, 2, 3, {"foo": "bar"}) + assert call(func, 1, 2, foo="bar", spam="ham") == ( + 1, + 2, + 3, + {"foo": "bar", "spam": "ham"}, + ) + assert call(func, 1, 2, foo="bar", spam="ham", c=4) == ( + 1, + 2, + 4, + {"foo": "bar", "spam": "ham"}, + ) + assert call(func, 1, 2, c=4, foo="bar", spam="ham") == ( + 1, + 2, + 4, + {"foo": "bar", "spam": "ham"}, + ) + assert call(func, 1, 2, c=4, foo="bar", spam="ham", args=10) == ( + 1, + 2, + 4, + {"foo": "bar", "spam": "ham", "args": 10}, + ) + assert call(func, b=2, a=1, c=4, foo="bar", spam="ham") == ( + 1, + 2, + 4, + {"foo": "bar", "spam": "ham"}, + ) + + +def test_signature_bind_varargs_and_varkwargs(): + def func(*args, **kwargs): + return args, kwargs + + assert call(func) == ((), {}) + assert call(func, 1) == ((1,), {}) + assert call(func, 1, 2) == ((1, 2), {}) + assert call(func, foo="bar") == ((), {"foo": "bar"}) + assert call(func, 1, foo="bar") == ((1,), {"foo": "bar"}) + assert call(func, args=10), () == {"args": 10} + assert call(func, 1, 2, foo="bar") == ((1, 2), {"foo": "bar"}) + assert call(func, 1, 2, foo="bar", spam="ham") == ( + (1, 2), + {"foo": "bar", "spam": "ham"}, + ) + assert call(func, foo="bar", spam="ham", args=10) == ( + (), + {"foo": "bar", "spam": "ham", "args": 10}, + ) + + +def test_signature_bind_positional_only_arguments(): + def func(a, b, /, c=3): + return a, b, c + + assert call(func, 1, 2) == (1, 2, 3) + assert call(func, 1, 2, 4) == (1, 2, 4) + assert call(func, 1, 2, c=4) == (1, 2, 4) + with pytest.raises(TypeError, match="multiple values for argument 'c'"): + call(func, 1, 2, 3, c=4) + + def func(a, b=2, /, c=3, *args): + return a, b, c, args + + assert call(func, 1, 2) == (1, 2, 3, ()) + assert call(func, 1, 2, 4) == (1, 2, 4, ()) + assert call(func, 1, c=3) == (1, 2, 3, ()) + + def func(a, b, c=3, /, foo=42, *, bar=50, **kwargs): + return a, b, c, foo, bar, kwargs + + assert call(func, 1, 2, 4, 5, bar=6) == (1, 2, 4, 5, 6, {}) + assert call(func, 1, 2) == (1, 2, 3, 42, 50, {}) + assert call(func, 1, 2, foo=4, bar=5) == (1, 2, 3, 4, 5, {}) + assert call(func, 1, 2, foo=4, bar=5, c=10) == (1, 2, 3, 4, 5, {"c": 10}) + assert call(func, 1, 2, 30, c=31, foo=4, bar=5) == (1, 2, 30, 4, 5, {"c": 31}) + assert call(func, 1, 2, 30, foo=4, bar=5, c=31) == (1, 2, 30, 4, 5, {"c": 31}) + assert call(func, 1, 2, c=4) == (1, 2, 3, 42, 50, {"c": 4}) + assert call(func, 1, 2, c=4, foo=5) == (1, 2, 3, 5, 50, {"c": 4}) + + with pytest.raises( + TypeError, match="positional only argument 'a' passed as keyword argument" + ): + call(func, a=1, b=2) + + def func(a=1, b=2, /): + return a, b + + with pytest.raises(TypeError, match="got an unexpected keyword argument 'a'"): + call(func, a=3, b=4) + + def func(a, /, **kwargs): + return a, kwargs + + assert call(func, "pos-only", bar="keyword") == ("pos-only", {"bar": "keyword"}) + + +def test_signature_bind_keyword_only_arguments(): + def func(*, a, b, c=3): + return a, b, c + + with pytest.raises(TypeError, match="too many positional arguments"): + call(func, 1) + + assert call(func, a=1, b=2) == (1, 2, 3) + assert call(func, a=1, b=2, c=4) == (1, 2, 4) + + def func(a, *, b, c=3, **kwargs): + return a, b, c, kwargs + + with pytest.raises(TypeError, match="missing a required argument: 'b'"): + call(func, 1) + + assert call(func, 1, b=2) == (1, 2, 3, {}) + assert call(func, 1, b=2, c=4) == (1, 2, 4, {}) + + def func(*, a, b, c=3, foo=42, **kwargs): + return a, b, c, foo, kwargs + + assert call(func, a=1, b=2) == (1, 2, 3, 42, {}) + assert call(func, a=1, b=2, foo=4) == (1, 2, 3, 4, {}) + assert call(func, a=1, b=2, foo=4, bar=5) == (1, 2, 3, 4, {"bar": 5}) + assert call(func, a=1, b=2, foo=4, bar=5, c=10) == (1, 2, 10, 4, {"bar": 5}) + assert call(func, a=1, b=2, foo=4, bar=5, c=10, spam=6) == ( + 1, + 2, + 10, + 4, + {"bar": 5, "spam": 6}, + ) + + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + call(func, b=2) + with pytest.raises(TypeError, match="missing a required argument: 'b'"): + call(func, a=1) + + def func(a, *, b): + return a, b + + assert call(func, 1, b=2) == (1, 2) + with pytest.raises(TypeError, match="missing a required argument: 'b'"): + call(func, 1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + call(func, b=2) + with pytest.raises(TypeError, match="got an unexpected keyword argument 'c'"): + call(func, a=1, b=2, c=3) + with pytest.raises(TypeError, match="too many positional arguments"): + call(func, 1, 2) + with pytest.raises(TypeError, match="too many positional arguments"): + call(func, 1, 2, c=3) + + def func(a, *, b, **kwargs): + return a, b, kwargs + + assert call(func, 1, b=2) == (1, 2, {}) + assert call(func, 1, b=2, c=3) == (1, 2, {"c": 3}) + assert call(func, 1, b=2, c=3, d=4) == (1, 2, {"c": 3, "d": 4}) + assert call(func, a=1, b=2) == (1, 2, {}) + assert call(func, c=3, a=1, b=2) == (1, 2, {"c": 3}) + with pytest.raises(TypeError, match="missing a required argument: 'b'"): + call(func, a=1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + call(func, c=3, b=2) + + +def test_signature_bind_with_arg_named_self(): + def test(a, self, b): + pass + + sig = Signature.from_callable(test) + ba = sig.bind(1, 2, 3) + args, _ = sig.unbind(ba) + assert args == (1, 2, 3) + ba = sig.bind(1, self=2, b=3) + args, _ = sig.unbind(ba) + assert args == (1, 2, 3) diff --git a/koerce/utils.py b/koerce/utils.py index a2e77d5..fd26177 100644 --- a/koerce/utils.py +++ b/koerce/utils.py @@ -5,6 +5,8 @@ import typing from typing import Any, TypeVar +import cython + get_type_args = typing.get_args get_type_origin = typing.get_origin @@ -180,15 +182,193 @@ def checkpoint(self): self._iterator, self._checkpoint = itertools.tee(self._iterator) -class Signature(inspect.Signature): - def unbind(self, this: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]]: +EMPTY = inspect.Parameter.empty + +POSITIONAL_ONLY = cython.declare(cython.int, int(inspect.Parameter.POSITIONAL_ONLY)) +POSITIONAL_OR_KEYWORD = cython.declare( + cython.int, int(inspect.Parameter.POSITIONAL_OR_KEYWORD) +) +VAR_POSITIONAL = cython.declare(cython.int, int(inspect.Parameter.VAR_POSITIONAL)) +KEYWORD_ONLY = cython.declare(cython.int, int(inspect.Parameter.KEYWORD_ONLY)) +VAR_KEYWORD = cython.declare(cython.int, int(inspect.Parameter.VAR_KEYWORD)) + + +@cython.final +@cython.cclass +class Parameter: + POSITIONAL_ONLY: typing.ClassVar[int] = int(inspect.Parameter.POSITIONAL_ONLY) + POSITIONAL_OR_KEYWORD: typing.ClassVar[int] = int( + inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + VAR_POSITIONAL: typing.ClassVar[int] = int(inspect.Parameter.VAR_POSITIONAL) + KEYWORD_ONLY: typing.ClassVar[int] = int(inspect.Parameter.KEYWORD_ONLY) + VAR_KEYWORD: typing.ClassVar[int] = int(inspect.Parameter.VAR_KEYWORD) + + name = cython.declare(str, visibility="readonly") + kind = cython.declare(cython.int, visibility="readonly") + # Cannot use C reserved keyword 'default' here + default_ = cython.declare(object, visibility="readonly") + annotation = cython.declare(object, visibility="readonly") + + def __init__( + self, name: str, kind: int, default: Any = EMPTY, annotation: Any = EMPTY + ): + self.name = name + self.kind = kind + self.default_ = default + self.annotation = annotation + + def __str__(self) -> str: + result: str = self.name + if self.annotation is not EMPTY: + if hasattr(self.annotation, "__qualname__"): + result += f": {self.annotation.__qualname__}" + else: + result += f": {self.annotation}" + if self.default_ is not EMPTY: + if self.annotation is EMPTY: + result = f"{result}={self.default_}" + else: + result = f"{result} = {self.default_!r}" + if self.kind == VAR_POSITIONAL: + result = f"*{result}" + elif self.kind == VAR_KEYWORD: + result = f"**{result}" + return result + + def __repr__(self): + return f'<{self.__class__.__name__} "{self}">' + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Parameter): + return NotImplemented + right: Parameter = cython.cast(Parameter, other) + return ( + self.name == right.name + and self.kind == right.kind + and self.default_ == right.default_ + and self.annotation == right.annotation + ) + + +@cython.final +@cython.cclass +class Signature: + parameters = cython.declare(list[Parameter], visibility="readonly") + return_annotation = cython.declare(object, visibility="readonly") + + def __init__(self, parameters: list[Parameter], return_annotation: Any = EMPTY): + self.parameters = parameters + self.return_annotation = return_annotation + + @staticmethod + def from_callable(func: Any) -> Signature: + sig = inspect.signature(func) + params: list[Parameter] = [ + Parameter(p.name, int(p.kind), p.default, p.annotation) + for p in sig.parameters.values() + ] + return Signature(params, sig.return_annotation) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Signature): + return NotImplemented + right: Signature = cython.cast(Signature, other) + return ( + self.parameters == right.parameters + and self.return_annotation == right.return_annotation + ) + + def bind(self, /, *args: Any, **kwargs: Any) -> dict[str, Any]: + """Bind the arguments to the signature. + + Parameters + ---------- + args : Any + Positional arguments. + kwargs : Any + Keyword arguments. + + Returns + ------- + dict + Mapping of parameter names to argument values. + + """ + i: cython.int = 0 + kind: cython.int + param: Parameter + bound: dict[str, Any] = {} + + # 1. HANDLE ARGS + for i in range(len(args)): + if i >= len(self.parameters): + raise TypeError("too many positional arguments") + + param = self.parameters[i] + kind = param.kind + if kind is POSITIONAL_OR_KEYWORD: + if param.name in kwargs: + raise TypeError(f"multiple values for argument '{param.name}'") + bound[param.name] = args[i] + elif kind is VAR_KEYWORD or kind is KEYWORD_ONLY: + raise TypeError("too many positional arguments") + elif kind is VAR_POSITIONAL: + bound[param.name] = args[i:] + break + elif kind is POSITIONAL_ONLY: + bound[param.name] = args[i] + else: + raise TypeError("unreachable code") + + # 2. INCREMENT PARAMETER INDEX + if args: + i += 1 + + # 3. HANDLE KWARGS + for param in self.parameters[i:]: + if param.kind is POSITIONAL_OR_KEYWORD or param.kind is KEYWORD_ONLY: + if param.name in kwargs: + bound[param.name] = kwargs.pop(param.name) + elif param.default_ is EMPTY: + raise TypeError(f"missing a required argument: '{param.name}'") + else: + bound[param.name] = param.default_ + elif param.kind is VAR_POSITIONAL: + bound[param.name] = () + elif param.kind is VAR_KEYWORD: + bound[param.name] = kwargs + break + elif param.kind is POSITIONAL_ONLY: + if param.default_ is EMPTY: + if param.name in kwargs: + raise TypeError( + f"positional only argument '{param.name}' passed as keyword argument" + ) + else: + raise TypeError( + f"missing required positional argument {param.name}" + ) + else: + bound[param.name] = param.default_ + else: + raise TypeError("unreachable code") + else: + if kwargs: + raise TypeError( + f"got an unexpected keyword argument '{next(iter(kwargs))}'" + ) + + return bound + + def unbind(self, bound: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]]: """Reverse bind of the parameters. Attempts to reconstructs the original arguments as keyword only arguments. Parameters ---------- - this : Any + bound Object with attributes matching the signature parameters. Returns @@ -200,17 +380,18 @@ def unbind(self, this: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]] # does the reverse of bind, but doesn't apply defaults args: list = [] kwargs: dict = {} - for name, param in self.parameters.items(): - value = this[name] - if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: + param: Parameter + for param in self.parameters: + value = bound[param.name] + if param.kind is POSITIONAL_OR_KEYWORD: args.append(value) - elif param.kind is inspect.Parameter.VAR_POSITIONAL: + elif param.kind is VAR_POSITIONAL: args.extend(value) - elif param.kind is inspect.Parameter.VAR_KEYWORD: + elif param.kind is VAR_KEYWORD: kwargs.update(value) - elif param.kind is inspect.Parameter.KEYWORD_ONLY: - kwargs[name] = value - elif param.kind is inspect.Parameter.POSITIONAL_ONLY: + elif param.kind is KEYWORD_ONLY: + kwargs[param.name] = value + elif param.kind is POSITIONAL_ONLY: args.append(value) else: raise TypeError(f"unsupported parameter kind {param.kind}")