Skip to content

Commit

Permalink
feat: add CallableWith pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Jul 31, 2024
1 parent 9a55735 commit 857df1f
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 15 deletions.
71 changes: 56 additions & 15 deletions koerce/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib
from collections.abc import Callable, Mapping, Sequence
from enum import Enum
from inspect import Parameter, Signature
from types import UnionType
from typing import (
Annotated,
Expand Down Expand Up @@ -121,21 +122,21 @@ def from_typehint(annot: Any, allow_coercion: bool = True) -> Pattern:
# is used for isinstance checks, the rest are applied in conjunction
annot, *extras = args
return AllOf(Pattern.from_typehint(annot), *extras)
# elif origin is Callable:
# # the Callable typehint is used to annotate functions, e.g. the
# # following typehint annotates a function that takes two integers
# # and returns a string: Callable[[int, int], str]
# if args:
# # callable with args and return typehints construct a special
# # CallableWith validator
# arg_hints, return_hint = args
# arg_patterns = tuple(map(cls.from_typehint, arg_hints))
# return_pattern = cls.from_typehint(return_hint)
# return CallableWith(arg_patterns, return_pattern)
# else:
# # in case of Callable without args we check for the Callable
# # protocol only
# return InstanceOf(Callable)
elif origin is Callable:
# the Callable typehint is used to annotate functions, e.g. the
# following typehint annotates a function that takes two integers
# and returns a string: Callable[[int, int], str]
if args:
# callable with args and return typehints construct a special
# CallableWith validator
arg_hints, return_hint = args
arg_patterns = list(map(Pattern.from_typehint, arg_hints))
return_pattern = Pattern.from_typehint(return_hint)
return CallableWith(arg_patterns, return_pattern)
else:
# in case of Callable without args we check for the Callable
# protocol only
return InstanceOf(Callable)
elif issubclass(origin, tuple):
# construct validators for the tuple elements, but need to treat
# variadic tuples differently, e.g. tuple[int, ...] is a variadic
Expand Down Expand Up @@ -1492,6 +1493,46 @@ def match(self, value, context):
else:
return value

@cython.final
@cython.cclass
class CallableWith(Pattern):
args: list[Pattern]
return_: Pattern

def __init__(self, args, return_=_any):
self.args = [pattern(arg) for arg in args]
self.return_ = pattern(return_)

@cython.cfunc
def match(self, value, ctx):
if not callable(value):
return NoMatch

sig = Signature.from_callable(value)

has_varargs: bool = False
positional: list = []
required_positional: list = []
for p in sig.parameters.values():
if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
positional.append(p)
if p.default is Parameter.empty:
required_positional.append(p)
elif p.kind is Parameter.KEYWORD_ONLY and p.default is Parameter.empty:
raise TypeError(
"Callable has mandatory keyword-only arguments which cannot be specified"
)
elif p.kind is Parameter.VAR_POSITIONAL:
has_varargs = True

if len(required_positional) > len(self.args):
# Callable has more positional arguments than expected")
return NoMatch
elif len(positional) < len(self.args) and not has_varargs:
# Callable has less positional arguments than expected")
return NoMatch
else:
return value

@cython.final
@cython.cclass
Expand Down
63 changes: 63 additions & 0 deletions koerce/tests/test_patterns.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import functools
import sys
from dataclasses import dataclass
from typing import (
Annotated,
Any,
Callable,
Generic,
List,
Literal,
Expand All @@ -22,6 +24,7 @@
AnyOf,
Anything,
AsType,
CallableWith,
Capture,
CoercedTo,
CoercionError,
Expand Down Expand Up @@ -1217,3 +1220,63 @@ def f(x):

# matching mapping values
assert pattern({"a": 1, "b": 2}) == EqValue({"a": 1, "b": 2})



def test_callable_with():
def func(a, b):
return str(a) + b

def func_with_args(a, b, *args):
return sum((a, b) + args)

def func_with_kwargs(a, b, c=1, **kwargs):
return str(a) + b + str(c)

def func_with_optional_keyword_only_kwargs(a, *, c=1):
return a + c

def func_with_required_keyword_only_kwargs(*, c):
return c

p = CallableWith([InstanceOf(int), InstanceOf(str)])
assert p.apply(10, context={}) is NoMatch

msg = "Callable has mandatory keyword-only arguments which cannot be specified"
with pytest.raises(TypeError, match=msg):
p.apply(func_with_required_keyword_only_kwargs, context={})

# Callable has more positional arguments than expected
p = CallableWith([InstanceOf(int)] * 2)
assert p.apply(func_with_kwargs, context={}) is func_with_kwargs

# Callable has less positional arguments than expected
p = CallableWith([InstanceOf(int)] * 4)
assert p.apply(func_with_kwargs, context={}) is NoMatch

p = CallableWith([InstanceOf(int)] * 4, InstanceOf(int))
wrapped = p.apply(func_with_args, context={})
assert wrapped(1, 2, 3, 4) == 10

p = CallableWith([InstanceOf(int), InstanceOf(str)], InstanceOf(str))
wrapped = p.apply(func, context={})
assert wrapped(1, "st") == "1st"

p = CallableWith([InstanceOf(int)])
wrapped = p.apply(func_with_optional_keyword_only_kwargs, context={})
assert wrapped(1) == 2


def test_callable_with_default_arguments():
def f(a: int, b: str, c: str):
return a + int(b) + int(c)

def g(a: int, b: str, c: str = "0"):
return a + int(b) + int(c)

h = functools.partial(f, c="0")

p = Pattern.from_typehint(Callable[[int, str], int])
assert p.apply(f) is NoMatch
assert p.apply(g) == g
assert p.apply(h) == h

0 comments on commit 857df1f

Please sign in to comment.