diff --git a/koerce/patterns.py b/koerce/patterns.py index 97c1abc..39ad093 100644 --- a/koerce/patterns.py +++ b/koerce/patterns.py @@ -3,6 +3,7 @@ import importlib from collections.abc import Callable, Mapping, Sequence from enum import Enum +from inspect import Parameter, Signature from types import UnionType from typing import ( Annotated, @@ -121,21 +122,21 @@ def from_typehint(annot: Any, allow_coercion: bool = True) -> Pattern: # is used for isinstance checks, the rest are applied in conjunction annot, *extras = args return AllOf(Pattern.from_typehint(annot), *extras) - # elif origin is Callable: - # # the Callable typehint is used to annotate functions, e.g. the - # # following typehint annotates a function that takes two integers - # # and returns a string: Callable[[int, int], str] - # if args: - # # callable with args and return typehints construct a special - # # CallableWith validator - # arg_hints, return_hint = args - # arg_patterns = tuple(map(cls.from_typehint, arg_hints)) - # return_pattern = cls.from_typehint(return_hint) - # return CallableWith(arg_patterns, return_pattern) - # else: - # # in case of Callable without args we check for the Callable - # # protocol only - # return InstanceOf(Callable) + elif origin is Callable: + # the Callable typehint is used to annotate functions, e.g. the + # following typehint annotates a function that takes two integers + # and returns a string: Callable[[int, int], str] + if args: + # callable with args and return typehints construct a special + # CallableWith validator + arg_hints, return_hint = args + arg_patterns = list(map(Pattern.from_typehint, arg_hints)) + return_pattern = Pattern.from_typehint(return_hint) + return CallableWith(arg_patterns, return_pattern) + else: + # in case of Callable without args we check for the Callable + # protocol only + return InstanceOf(Callable) elif issubclass(origin, tuple): # construct validators for the tuple elements, but need to treat # variadic tuples differently, e.g. tuple[int, ...] is a variadic @@ -1492,6 +1493,46 @@ def match(self, value, context): else: return value +@cython.final +@cython.cclass +class CallableWith(Pattern): + args: list[Pattern] + return_: Pattern + + def __init__(self, args, return_=_any): + self.args = [pattern(arg) for arg in args] + self.return_ = pattern(return_) + + @cython.cfunc + def match(self, value, ctx): + if not callable(value): + return NoMatch + + sig = Signature.from_callable(value) + + has_varargs: bool = False + positional: list = [] + required_positional: list = [] + for p in sig.parameters.values(): + if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): + positional.append(p) + if p.default is Parameter.empty: + required_positional.append(p) + elif p.kind is Parameter.KEYWORD_ONLY and p.default is Parameter.empty: + raise TypeError( + "Callable has mandatory keyword-only arguments which cannot be specified" + ) + elif p.kind is Parameter.VAR_POSITIONAL: + has_varargs = True + + if len(required_positional) > len(self.args): + # Callable has more positional arguments than expected") + return NoMatch + elif len(positional) < len(self.args) and not has_varargs: + # Callable has less positional arguments than expected") + return NoMatch + else: + return value @cython.final @cython.cclass diff --git a/koerce/tests/test_patterns.py b/koerce/tests/test_patterns.py index dfd930b..5aa5f11 100644 --- a/koerce/tests/test_patterns.py +++ b/koerce/tests/test_patterns.py @@ -1,10 +1,12 @@ from __future__ import annotations +import functools import sys from dataclasses import dataclass from typing import ( Annotated, Any, + Callable, Generic, List, Literal, @@ -22,6 +24,7 @@ AnyOf, Anything, AsType, + CallableWith, Capture, CoercedTo, CoercionError, @@ -1217,3 +1220,63 @@ def f(x): # matching mapping values assert pattern({"a": 1, "b": 2}) == EqValue({"a": 1, "b": 2}) + + + +def test_callable_with(): + def func(a, b): + return str(a) + b + + def func_with_args(a, b, *args): + return sum((a, b) + args) + + def func_with_kwargs(a, b, c=1, **kwargs): + return str(a) + b + str(c) + + def func_with_optional_keyword_only_kwargs(a, *, c=1): + return a + c + + def func_with_required_keyword_only_kwargs(*, c): + return c + + p = CallableWith([InstanceOf(int), InstanceOf(str)]) + assert p.apply(10, context={}) is NoMatch + + msg = "Callable has mandatory keyword-only arguments which cannot be specified" + with pytest.raises(TypeError, match=msg): + p.apply(func_with_required_keyword_only_kwargs, context={}) + + # Callable has more positional arguments than expected + p = CallableWith([InstanceOf(int)] * 2) + assert p.apply(func_with_kwargs, context={}) is func_with_kwargs + + # Callable has less positional arguments than expected + p = CallableWith([InstanceOf(int)] * 4) + assert p.apply(func_with_kwargs, context={}) is NoMatch + + p = CallableWith([InstanceOf(int)] * 4, InstanceOf(int)) + wrapped = p.apply(func_with_args, context={}) + assert wrapped(1, 2, 3, 4) == 10 + + p = CallableWith([InstanceOf(int), InstanceOf(str)], InstanceOf(str)) + wrapped = p.apply(func, context={}) + assert wrapped(1, "st") == "1st" + + p = CallableWith([InstanceOf(int)]) + wrapped = p.apply(func_with_optional_keyword_only_kwargs, context={}) + assert wrapped(1) == 2 + + +def test_callable_with_default_arguments(): + def f(a: int, b: str, c: str): + return a + int(b) + int(c) + + def g(a: int, b: str, c: str = "0"): + return a + int(b) + int(c) + + h = functools.partial(f, c="0") + + p = Pattern.from_typehint(Callable[[int, str], int]) + assert p.apply(f) is NoMatch + assert p.apply(g) == g + assert p.apply(h) == h