Skip to content

Commit

Permalink
feat: cythonize inspect signature
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Aug 7, 2024
1 parent 49549bf commit 58e63e3
Show file tree
Hide file tree
Showing 6 changed files with 633 additions and 53 deletions.
6 changes: 5 additions & 1 deletion build.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
"koerce.patterns",
["koerce/patterns.py"],
),
Extension(
"koerce.utils",
["koerce/utils.py"],
),
],
build_dir=BUILD_DIR,
# generate anannotated .html output files.
Expand All @@ -41,7 +45,7 @@
# "annotation_typing": False
},
# always rebuild, even if files untouched
force=True,
force=False,
# emit_linenums=True
)

Expand Down
32 changes: 19 additions & 13 deletions koerce/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import importlib
from collections.abc import Callable, Mapping, Sequence
from enum import Enum
from inspect import Parameter
from types import UnionType
from typing import (
Annotated,
Expand All @@ -18,6 +17,7 @@
import cython
from typing_extensions import GenericMeta, get_original_bases

# TODO(kszucs): would be nice to cimport Signature and Builder
from .builders import Builder, Deferred, Variable, builder
from .utils import (
RewindableIterator,
Expand All @@ -27,6 +27,12 @@
get_type_hints,
get_type_origin,
get_type_params,
EMPTY,
POSITIONAL_ONLY,
POSITIONAL_OR_KEYWORD,
VAR_KEYWORD,
VAR_POSITIONAL,
KEYWORD_ONLY,
)


Expand Down Expand Up @@ -209,17 +215,17 @@ def from_callable(
args = {}
elif isinstance(args, (list, tuple)):
# create a mapping of parameter name to pattern
args = dict(zip(sig.parameters.keys(), args))
args = {param.name: arg for param, arg in zip(sig.parameters, args)}
elif not isinstance(args, dict):
raise TypeError(f"patterns must be a list or dict, got {type(args)}")

retpat: Pattern
argpat: Pattern
argpats: dict[str, Pattern] = {}
for param in sig.parameters.values():
for param in sig.parameters:
name: str = param.name
kind = param.kind
default = param.default
kind: int = param.kind
default = param.default_
typehint = typehints.get(name)

if name in args:
Expand All @@ -229,11 +235,11 @@ def from_callable(
else:
argpat = _any

if kind is Parameter.VAR_POSITIONAL:
if kind is VAR_POSITIONAL:
argpat = TupleOf(argpat)
elif kind is Parameter.VAR_KEYWORD:
elif kind is VAR_KEYWORD:
argpat = DictOf(_any, argpat)
elif default is not Parameter.empty:
elif default is not EMPTY:
argpat = Option(argpat, default=default)

argpats[name] = argpat
Expand Down Expand Up @@ -1615,16 +1621,16 @@ def match(self, value, ctx: Context):
has_varargs: bool = False
positional: list = []
required_positional: list = []
for p in sig.parameters.values():
if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
for p in sig.parameters:
if p.kind in (POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD):
positional.append(p)
if p.default is Parameter.empty:
if p.default_ is EMPTY:
required_positional.append(p)
elif p.kind is Parameter.KEYWORD_ONLY and p.default is Parameter.empty:
elif p.kind is KEYWORD_ONLY and p.default_ is EMPTY:
raise TypeError(
"Callable has mandatory keyword-only arguments which cannot be specified"
)
elif p.kind is Parameter.VAR_POSITIONAL:
elif p.kind is VAR_POSITIONAL:
has_varargs = True

if len(required_positional) > len(self.args):
Expand Down
3 changes: 1 addition & 2 deletions koerce/sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,9 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):
def wrapped(*args, **kwargs):
# 0. Bind the arguments to the signature
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

# 1. Validate the passed arguments
values = argpats.apply(bound.arguments)
values = argpats.apply(bound)
if values is NoMatch:
raise ValidationError()

Expand Down
42 changes: 17 additions & 25 deletions koerce/tests/test_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@ def test(a: int, b: int, c: int = 1): ...

sig = Signature.from_callable(test)
bound = sig.bind(2, 3)
bound.apply_defaults()

assert bound.arguments == {"a": 2, "b": 3, "c": 1}
assert bound == {"a": 2, "b": 3, "c": 1}

args, kwargs = sig.unbind(bound.arguments)
args, kwargs = sig.unbind(bound)
assert args == (2, 3, 1)
assert kwargs == {}

Expand All @@ -53,17 +52,15 @@ def test(a: int, b: int, *args: int): ...

sig = Signature.from_callable(test)
bound = sig.bind(2, 3)
bound.apply_defaults()

assert bound.arguments == {"a": 2, "b": 3, "args": ()}
args, kwargs = sig.unbind(bound.arguments)
assert bound == {"a": 2, "b": 3, "args": ()}
args, kwargs = sig.unbind(bound)
assert args == (2, 3)
assert kwargs == {}

bound = sig.bind(2, 3, 4, 5)
bound.apply_defaults()
assert bound.arguments == {"a": 2, "b": 3, "args": (4, 5)}
args, kwargs = sig.unbind(bound.arguments)
assert bound == {"a": 2, "b": 3, "args": (4, 5)}
args, kwargs = sig.unbind(bound)
assert args == (2, 3, 4, 5)
assert kwargs == {}

Expand All @@ -73,18 +70,16 @@ def test(a: int, b: int, /, c: int = 1): ...

sig = Signature.from_callable(test)
bound = sig.bind(2, 3)
bound.apply_defaults()
assert bound.arguments == {"a": 2, "b": 3, "c": 1}
assert bound == {"a": 2, "b": 3, "c": 1}

args, kwargs = sig.unbind(bound.arguments)
args, kwargs = sig.unbind(bound)
assert args == (2, 3, 1)
assert kwargs == {}

bound = sig.bind(2, 3, 4)
bound.apply_defaults()
assert bound.arguments == {"a": 2, "b": 3, "c": 4}
assert bound == {"a": 2, "b": 3, "c": 4}

args, kwargs = sig.unbind(bound.arguments)
args, kwargs = sig.unbind(bound)
assert args == (2, 3, 4)
assert kwargs == {}

Expand All @@ -94,10 +89,9 @@ def test(a: int, b: int, *, c: float, d: float = 0.0): ...

sig = Signature.from_callable(test)
bound = sig.bind(2, 3, c=4.0)
bound.apply_defaults()
assert bound.arguments == {"a": 2, "b": 3, "c": 4.0, "d": 0.0}
assert bound == {"a": 2, "b": 3, "c": 4.0, "d": 0.0}

args, kwargs = sig.unbind(bound.arguments)
args, kwargs = sig.unbind(bound)
assert args == (2, 3)
assert kwargs == {"c": 4.0, "d": 0.0}

Expand All @@ -107,11 +101,10 @@ def func(a, b, c=1): ...

sig = Signature.from_callable(func)
bound = sig.bind(1, 2)
bound.apply_defaults()

assert bound.arguments == {"a": 1, "b": 2, "c": 1}
assert bound == {"a": 1, "b": 2, "c": 1}

args, kwargs = sig.unbind(bound.arguments)
args, kwargs = sig.unbind(bound)
assert args == (1, 2, 1)
assert kwargs == {}

Expand All @@ -123,10 +116,9 @@ def func(a, b, c, *args, e=None):

sig = Signature.from_callable(func)
bound = sig.bind(1, 2, 3, *d, e=4)
bound.apply_defaults()
assert bound.arguments == {"a": 1, "b": 2, "c": 3, "args": d, "e": 4}
assert bound == {"a": 1, "b": 2, "c": 3, "args": d, "e": 4}

args, kwargs = sig.unbind(bound.arguments)
args, kwargs = sig.unbind(bound)
assert args == (1, 2, 3, *d)
assert kwargs == {"e": 4}

Expand Down Expand Up @@ -254,7 +246,7 @@ def test(a, b, c):
return a, b, c

assert test(1, 2, 3) == (1, 2, 3)
assert test.__signature__.parameters.keys() == {"a", "b", "c"}
assert [p.name for p in test.__signature__.parameters] == ["a", "b", "c"]


# def test_annotated_function_without_decoration(snapshot):
Expand Down
Loading

0 comments on commit 58e63e3

Please sign in to comment.