From 9aebb2b3c166e0f9029c73fab5b5060f06eb4851 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 5 Aug 2024 00:41:40 +0200 Subject: [PATCH] feat: support validating callable objects with @annotated decorator --- build.py | 5 +- koerce/patterns.py | 435 +++++++++++++++++++++------------- koerce/sugar.py | 110 ++++++++- koerce/tests/test_patterns.py | 90 ++++++- koerce/tests/test_sugar.py | 277 +++++++++++++++++++++- koerce/utils.py | 39 +++ 6 files changed, 778 insertions(+), 178 deletions(-) diff --git a/build.py b/build.py index b062950..dfb6687 100644 --- a/build.py +++ b/build.py @@ -14,18 +14,15 @@ BUILD_DIR = Path("cython_build") - cythonized_modules = cythonize( [ Extension( "koerce.builders", ["koerce/builders.py"], - # extra_compile_args=["-O3"] ), Extension( "koerce.patterns", ["koerce/patterns.py"], - # extra_compile_args=["-O3"] ), ], build_dir=BUILD_DIR, @@ -39,9 +36,11 @@ # "wraparound": False, # "nonecheck": False, # "profile": True, + # "annotation_typing": False }, # always rebuild, even if files untouched force=False, + # emit_linenums=True ) dist = Distribution({"ext_modules": cythonized_modules}) diff --git a/koerce/patterns.py b/koerce/patterns.py index 29b80a1..f35096f 100644 --- a/koerce/patterns.py +++ b/koerce/patterns.py @@ -3,7 +3,7 @@ import importlib from collections.abc import Callable, Mapping, Sequence from enum import Enum -from inspect import Parameter, Signature +from inspect import Parameter from types import UnionType from typing import ( Annotated, @@ -21,8 +21,10 @@ from .builders import Builder, Deferred, Variable, builder from .utils import ( RewindableIterator, + Signature, get_type_args, get_type_boundvars, + get_type_hints, get_type_origin, get_type_params, ) @@ -35,6 +37,11 @@ class CoercionError(Exception): Context = dict[str, Any] +@cython.final +@cython.cclass +class NoMatchError(Exception): + pass + @cython.cclass class NoMatch: @@ -170,19 +177,89 @@ def from_typehint(annot: Any, allow_coercion: bool = True) -> Pattern: f"Cannot create validator from annotation {annot!r} {origin!r}" ) + @staticmethod + def from_callable( + fn: Callable, sig=None, args=None, return_=None + ) -> tuple[Pattern, Pattern]: + """Create patterns from a callable object. + + Two patterns are created, one for the arguments and one for the return value. + + Parameters + ---------- + fn : Callable + Callable to create a signature from. + sig : Signature, default None + Signature to use for the callable. If None, a signature is created + from the callable. + args : list or dict, default None + Pass patterns to add missing or override existing argument type + annotations. + return_ : Pattern, default None + Pattern for the return value of the callable. + + Returns + ------- + Tuple of patterns for the arguments and the return value. + """ + sig: Signature = sig or Signature.from_callable(fn) + typehints: dict[str, Any] = get_type_hints(fn) + + if args is None: + args = {} + elif isinstance(args, (list, tuple)): + # create a mapping of parameter name to pattern + args = dict(zip(sig.parameters.keys(), args)) + elif not isinstance(args, dict): + raise TypeError(f"patterns must be a list or dict, got {type(args)}") + + retpat: Pattern + argpat: Pattern + argpats: dict[str, Pattern] = {} + for param in sig.parameters.values(): + name: str = param.name + kind = param.kind + default = param.default + typehint = typehints.get(name) + + if name in args: + argpat = pattern(args[name]) + elif typehint is not None: + argpat = Pattern.from_typehint(typehint) + else: + argpat = _any + + if kind is Parameter.VAR_POSITIONAL: + argpat = TupleOf(argpat) + elif kind is Parameter.VAR_KEYWORD: + argpat = DictOf(_any, argpat) + elif default is not Parameter.empty: + argpat = Option(argpat, default=default) + + argpats[name] = argpat + + if return_ is not None: + retpat = pattern(return_) + elif (typehint := typehints.get("return")) is not None: + retpat = Pattern.from_typehint(typehint) + else: + retpat = _any + + return (PatternMap(argpats), retpat) + def apply(self, value, context=None): if context is None: context = {} - return self.match(value, context) + try: + return self.match(value, context) + except NoMatchError: + return NoMatch @cython.cfunc def match(self, value, ctx: Context): ... def __repr__(self) -> str: ... - # def __str__(self) -> str: - # return repr(self) - def __eq__(self, other) -> bool: return type(self) is type(other) and self.equals(other) @@ -296,7 +373,7 @@ def equals(self, other: Nothing) -> bool: @cython.cfunc @cython.inline def match(self, value, ctx: Context): - return NoMatch + raise NoMatchError() @cython.final @@ -319,7 +396,8 @@ def match(self, value, ctx: Context): if value is self.value: return value else: - return NoMatch + raise NoMatchError() + @cython.ccall def Eq(value) -> Pattern: @@ -349,7 +427,7 @@ def match(self, value, ctx: Context): if value == self.value: return value else: - return NoMatch + raise NoMatchError() @cython.final @@ -380,7 +458,7 @@ def match(self, value, ctx: Context): if value == self.value.apply(ctx): return value else: - return NoMatch + raise NoMatchError() @cython.final @@ -403,7 +481,7 @@ def match(self, value, ctx: Context): if type(value) is self.type_: return value else: - return NoMatch + raise NoMatchError() @cython.final @@ -429,7 +507,7 @@ def match(self, value, ctx: Context): if isinstance(value, self.type_): return value else: - return NoMatch + raise NoMatchError() @cython.final @@ -472,7 +550,10 @@ def _import_type(self): @cython.cfunc def match(self, value, ctx: Context): if self.type_ is not None: - return value if isinstance(value, self.type_) else NoMatch + if isinstance(value, self.type_): + return value + else: + raise NoMatchError() klass: type package: str @@ -480,9 +561,12 @@ def match(self, value, ctx: Context): package = klass.__module__.split(".", 1)[0] if package == self.package: self._import_type() - return value if isinstance(value, self.type_) else NoMatch + if isinstance(value, self.type_): + return value + else: + raise NoMatchError() - return NoMatch + raise NoMatchError() @cython.ccall @@ -526,11 +610,10 @@ def equals(self, other: GenericInstanceOf1) -> bool: @cython.cfunc def match(self, value, ctx: Context): if not isinstance(value, self.origin): - return NoMatch + raise NoMatchError() attr1 = getattr(value, self.name1) - if self.pattern1.match(attr1, ctx) is NoMatch: - return NoMatch + self.pattern1.match(attr1, ctx) return value @@ -569,15 +652,13 @@ def equals(self, other: GenericInstanceOf2) -> bool: @cython.cfunc def match(self, value, ctx: Context): if not isinstance(value, self.origin): - return NoMatch + raise NoMatchError() attr1 = getattr(value, self.name1) - if self.pattern1.match(attr1, ctx) is NoMatch: - return NoMatch + self.pattern1.match(attr1, ctx) attr2 = getattr(value, self.name2) - if self.pattern2.match(attr2, ctx) is NoMatch: - return NoMatch + self.pattern2.match(attr2, ctx) return value @@ -608,14 +689,13 @@ def equals(self, other: GenericInstanceOfN) -> bool: @cython.cfunc def match(self, value, ctx: Context): if not isinstance(value, self.origin): - return NoMatch + raise NoMatchError() name: str pattern: Pattern for name, pattern in self.fields.items(): attr = getattr(value, name) - if pattern.match(attr, ctx) is NoMatch: - return NoMatch + pattern.match(attr, ctx) return value @@ -640,7 +720,7 @@ def match(self, value, ctx: Context): if issubclass(value, self.type_): return value else: - return NoMatch + raise NoMatchError() # @cython.ccall @@ -677,7 +757,7 @@ def match(self, value, ctx: Context): try: return self.type_(value) except ValueError: - return NoMatch + raise NoMatchError() @cython.final @@ -704,12 +784,12 @@ def match(self, value, ctx: Context): try: value = self.type_.__coerce__(value) except CoercionError: - return NoMatch + raise NoMatchError() if isinstance(value, self.type_): return value else: - return NoMatch + raise NoMatchError() @cython.final @@ -745,10 +825,9 @@ def match(self, value, ctx: Context): try: value = self.origin.__coerce__(value, **self.params) except CoercionError: - return NoMatch + raise NoMatchError() - if self.checker.match(value, ctx) is NoMatch: - return NoMatch + self.checker.match(value, ctx) return value @@ -769,10 +848,12 @@ def equals(self, other: Not) -> bool: @cython.cfunc def match(self, value, ctx: Context): - if self.inner.match(value, ctx) is NoMatch: + try: + self.inner.match(value, ctx) + except NoMatchError: return value else: - return NoMatch + raise NoMatchError() @cython.final @@ -793,9 +874,11 @@ def equals(self, other: AnyOf) -> bool: def match(self, value, ctx: Context): inner: Pattern for inner in self.inners: - if inner.match(value, ctx) is not NoMatch: - return value - return NoMatch + try: + return inner.match(value, ctx) + except NoMatchError: + pass + raise NoMatchError() @cython.final @@ -817,8 +900,6 @@ def match(self, value, ctx: Context): inner: Pattern for inner in self.inners: value = inner.match(value, ctx) - if value is NoMatch: - return NoMatch return value @@ -901,7 +982,7 @@ def match(self, value, ctx: Context): if self.predicate(value): return value else: - return NoMatch + raise NoMatchError() @cython.final @@ -935,7 +1016,7 @@ def match(self, value, ctx: Context): if self.builder.apply(ctx): return value else: - return NoMatch + raise NoMatchError() @cython.final @@ -966,7 +1047,7 @@ def match(self, value, ctx: Context): if value in self.haystack: return value else: - return NoMatch + raise NoMatchError() @cython.final @@ -1011,7 +1092,7 @@ def equals(self, other: SequenceOf) -> bool: @cython.cfunc def match(self, values, ctx: Context): if isinstance(values, (str, bytes)): - return NoMatch + raise NoMatchError() # optimization to avoid unnecessary iteration if isinstance(self.item, Anything): @@ -1020,14 +1101,11 @@ def match(self, values, ctx: Context): try: it = iter(values) except TypeError: - return NoMatch + raise NoMatchError() result: list = [] for item in it: - res = self.item.match(item, ctx) - if res is NoMatch: - return NoMatch - result.append(res) + result.append(self.item.match(item, ctx)) return self.type_.match(result, ctx) @@ -1084,14 +1162,12 @@ def equals(self, other: MappingOf) -> bool: @cython.cfunc def match(self, value, ctx: Context): if not isinstance(value, Mapping): - return NoMatch + raise NoMatchError() result = {} for k, v in value.items(): - if (k := self.key.match(k, ctx)) is NoMatch: - return NoMatch - if (v := self.value.match(v, ctx)) is NoMatch: - return NoMatch + k = self.key.match(k, ctx) + v = self.value.match(v, ctx) result[k] = v return self.type_.match(result, ctx) @@ -1126,11 +1202,7 @@ def equals(self, other: Custom) -> bool: @cython.cfunc def match(self, value, ctx: Context): - result = self.func(value, **ctx) - if result is NoMatch: - return NoMatch - else: - return result + return self.func(value, **ctx) @cython.final @@ -1169,8 +1241,6 @@ def equals(self, other: Capture) -> bool: @cython.cfunc def match(self, value, ctx: Context): value = self.what.match(value, ctx) - if value is NoMatch: - return NoMatch ctx[self.key] = value return value @@ -1202,8 +1272,6 @@ def __repr__(self) -> str: @cython.cfunc def match(self, value, ctx: Context): value = self.searcher.match(value, ctx) - if value is NoMatch: - return NoMatch # use the `_` reserved variable to record the value being replaced # in the context, so that it can be used in the replacer pattern ctx["_"] = value @@ -1233,7 +1301,6 @@ def ObjectOf(type_, *args, **kwargs) -> Pattern: return ObjectOfX(type_, *args, **kwargs) - @cython.cfunc @cython.inline def _reconstruct(value: Any, changed: dict[str, Any]): @@ -1256,25 +1323,28 @@ class ObjectOf1(Pattern): def __init__(self, type_: type, **kwargs): assert len(kwargs) == 1 self.type_ = type_ - (self.field1, pattern1), = kwargs.items() + ((self.field1, pattern1),) = kwargs.items() self.pattern1 = pattern(pattern1) def __repr__(self) -> str: return f"ObjectOf1({self.type_!r}, {self.field1!r}={self.pattern1!r})" def equals(self, other: ObjectOf1) -> bool: - return self.type_ == other.type_ and self.field1 == other.field1 and self.pattern1 == other.pattern1 + return ( + self.type_ == other.type_ + and self.field1 == other.field1 + and self.pattern1 == other.pattern1 + ) @cython.cfunc def match(self, value, ctx: Context): if not isinstance(value, self.type_): - return NoMatch + raise NoMatchError() attr1 = getattr(value, self.field1) result1 = self.pattern1.match(attr1, ctx) - if result1 is NoMatch: - return NoMatch - elif result1 is not attr1: + + if result1 is not attr1: changed: dict = {self.field1: result1} return _reconstruct(value, changed) else: @@ -1312,17 +1382,13 @@ def equals(self, other: ObjectOf2) -> bool: @cython.cfunc def match(self, value, ctx: Context): if not isinstance(value, self.type_): - return NoMatch + raise NoMatchError() attr1 = getattr(value, self.field1) result1 = self.pattern1.match(attr1, ctx) - if result1 is NoMatch: - return NoMatch attr2 = getattr(value, self.field2) result2 = self.pattern2.match(attr2, ctx) - if result2 is NoMatch: - return NoMatch if result1 is not attr1 or result2 is not attr2: changed: dict = {self.field1: result1, self.field2: result2} @@ -1330,6 +1396,7 @@ def match(self, value, ctx: Context): else: return value + @cython.final @cython.cclass class ObjectOf3(Pattern): @@ -1344,7 +1411,9 @@ class ObjectOf3(Pattern): def __init__(self, type_: type, **kwargs): assert len(kwargs) == 3 self.type_ = type_ - (self.field1, pattern1), (self.field2, pattern2), (self.field3, pattern3) = kwargs.items() + (self.field1, pattern1), (self.field2, pattern2), (self.field3, pattern3) = ( + kwargs.items() + ) self.pattern1 = pattern(pattern1) self.pattern2 = pattern(pattern2) self.pattern3 = pattern(pattern3) @@ -1366,31 +1435,28 @@ def equals(self, other: ObjectOf3) -> bool: @cython.cfunc def match(self, value, ctx: Context): if not isinstance(value, self.type_): - return NoMatch + raise NoMatchError() attr1 = getattr(value, self.field1) result1 = self.pattern1.match(attr1, ctx) - if result1 is NoMatch: - return NoMatch attr2 = getattr(value, self.field2) result2 = self.pattern2.match(attr2, ctx) - if result2 is NoMatch: - return NoMatch attr3 = getattr(value, self.field3) result3 = self.pattern3.match(attr3, ctx) - if result3 is NoMatch: - return NoMatch if result1 is not attr1 or result2 is not attr2 or result3 is not attr3: - changed: dict = {self.field1: result1, self.field2: result2, self.field3: result3} + changed: dict = { + self.field1: result1, + self.field2: result2, + self.field3: result3, + } return _reconstruct(value, changed) else: return value - @cython.final @cython.cclass class ObjectOfN(Pattern): @@ -1426,17 +1492,14 @@ def equals(self, other: ObjectOfN) -> bool: @cython.cfunc def match(self, value, ctx: Context): if not isinstance(value, self.type_): - return NoMatch + raise NoMatchError() - name: str pattern: Pattern changed: dict[str, Any] = {} for name, pattern in self.fields.items(): attr = getattr(value, name) result = pattern.match(attr, ctx) - if result is NoMatch: - return NoMatch - elif result is not attr: + if result is not attr: changed[name] = result if changed: @@ -1469,12 +1532,11 @@ def equals(self, other: ObjectOfX) -> bool: @cython.cfunc def match(self, value, ctx: Context): - if self.type_.match(value, ctx) is NoMatch: - return NoMatch + self.type_.match(value, ctx) # the pattern requirest more positional arguments than the object has if len(value.__match_args__) < len(self.args): - return NoMatch + raise NoMatchError() patterns: dict[str, Pattern] = dict(zip(value.__match_args__, self.args)) patterns.update(self.kwargs) @@ -1486,12 +1548,10 @@ def match(self, value, ctx: Context): try: attr = getattr(value, name) except AttributeError: - return NoMatch + raise NoMatchError() result = pattern.match(attr, ctx) - if result is NoMatch: - return NoMatch - elif result is not attr: + if result is not attr: changed[name] = result if changed: @@ -1510,10 +1570,16 @@ def __init__(self, args, return_=_any): self.args = [pattern(arg) for arg in args] self.return_ = pattern(return_) + def __repr__(self) -> str: + return f"CallableWith({self.args!r}, return_={self.return_!r})" + + def equals(self, other: CallableWith) -> bool: + return self.args == other.args and self.return_ == other.return_ + @cython.cfunc def match(self, value, ctx: Context): if not callable(value): - return NoMatch + raise NoMatchError() sig = Signature.from_callable(value) @@ -1534,13 +1600,14 @@ def match(self, value, ctx: Context): if len(required_positional) > len(self.args): # Callable has more positional arguments than expected") - return NoMatch + raise NoMatchError() elif len(positional) < len(self.args) and not has_varargs: # Callable has less positional arguments than expected") - return NoMatch + raise NoMatchError() else: return value + @cython.final @cython.cclass class WithLength(Pattern): @@ -1585,9 +1652,9 @@ def equals(self, other: WithLength) -> bool: def match(self, value, ctx: Context): length = len(value) if self.at_least is not None and length < self.at_least: - return NoMatch + raise NoMatchError() if self.at_most is not None and length > self.at_most: - return NoMatch + raise NoMatchError() return value @@ -1619,9 +1686,6 @@ def equals(self, other: SomeItemsOf) -> bool: @cython.cfunc def match(self, values, ctx: Context): result = self.pattern.match(values, ctx) - if result is NoMatch: - return NoMatch - return self.length.match(result, ctx) @@ -1652,7 +1716,9 @@ def equals(self, other: SomeChunksOf) -> bool: def chunk(self, values, context): chunk: list = [] for item in values: - if self.delimiter.match(item, context) is NoMatch: + try: + self.delimiter.match(item, context) + except NoMatchError: chunk.append(item) else: if chunk: # only yield if there are items in the chunk @@ -1665,13 +1731,7 @@ def chunk(self, values, context): def match(self, values, ctx: Context): chunks = self.chunk(values, ctx) result = self.pattern.match(chunks, ctx) - if result is NoMatch: - return NoMatch - result = self.length.match(result, ctx) - if result is NoMatch: - return NoMatch - return [el for lst in result for el in lst] @@ -1727,22 +1787,20 @@ def delimiter(self) -> Pattern: @cython.cfunc def match(self, values, ctx: Context): if isinstance(values, (str, bytes)): - return NoMatch + raise NoMatchError() try: values = list(values) except TypeError: - return NoMatch + raise NoMatchError() if len(values) != len(self.patterns): - return NoMatch + raise NoMatchError() result = [] pattern: Pattern for pattern, value in zip(self.patterns, values): value = pattern.match(value, ctx) - if value is NoMatch: - return NoMatch result.append(value) return self.type_(result) @@ -1771,7 +1829,10 @@ def delimiter(self) -> Pattern: @cython.cfunc def match(self, value, ctx: Context): if not self.patterns: - return NoMatch if value else [] + if value: + raise NoMatchError() + else: + return self.type_(value) it = RewindableIterator(value) @@ -1799,32 +1860,28 @@ def match(self, value, ctx: Context): except StopIteration: break - res = following.match(item, ctx) - if res is NoMatch: + try: + res = following.match(item, ctx) + except NoMatchError: matches.append(item) else: it.rewind() break res = original.match(matches, ctx) - if res is NoMatch: - return NoMatch - else: - result.extend(res) + result.extend(res) else: try: item = next(it) except StopIteration: - return NoMatch + raise NoMatchError() res = original.match(item, ctx) - if res is NoMatch: - return NoMatch - else: - result.append(res) + result.append(res) return self.type_(result) + @cython.ccall def PatternMap(fields) -> Pattern: if len(fields) == 1: @@ -1842,7 +1899,7 @@ class PatternMap1(Pattern): pattern1: Pattern def __init__(self, fields): - (self.field1, pattern1), = fields.items() + ((self.field1, pattern1),) = fields.items() self.pattern1 = pattern(pattern1) def __repr__(self) -> str: @@ -1853,21 +1910,25 @@ def equals(self, other: PatternMap1) -> bool: @cython.cfunc def match(self, value, ctx: Context): - if not isinstance(value, Mapping): - return NoMatch + if not isinstance(value, dict): + raise NoMatchError() + + if len(value) != 1: + raise NoMatchError() try: item1 = value[self.field1] except KeyError: - return NoMatch + raise NoMatchError() + result1 = self.pattern1.match(item1, ctx) - if result1 is NoMatch: - return NoMatch - elif result1 is not item1: + + if result1 is not item1: return type(value)({**value, self.field1: result1}) else: return value + @cython.final @cython.cclass class PatternMap2(Pattern): @@ -1887,31 +1948,27 @@ def __repr__(self) -> str: def equals(self, other: PatternMap2) -> bool: return ( self.field1 == other.field1 - and self.pattern1 == other.pattern1 and self.field2 == other.field2 + and self.pattern1 == other.pattern1 and self.pattern2 == other.pattern2 ) @cython.cfunc def match(self, value, ctx: Context): - if not isinstance(value, Mapping): - return NoMatch + if not isinstance(value, dict): + raise NoMatchError() - try: - item1 = value[self.field1] - except KeyError: - return NoMatch - result1 = self.pattern1.match(item1, ctx) - if result1 is NoMatch: - return NoMatch + if len(value) != 2: + raise NoMatchError() try: + item1 = value[self.field1] item2 = value[self.field2] except KeyError: - return NoMatch + raise NoMatchError() + + result1 = self.pattern1.match(item1, ctx) result2 = self.pattern2.match(item2, ctx) - if result2 is NoMatch: - return NoMatch if result1 is not item1 or result2 is not item2: return type(value)({**value, self.field1: result1, self.field2: result2}) @@ -1919,6 +1976,68 @@ def match(self, value, ctx: Context): return value +@cython.final +@cython.cclass +class PatternMap3(Pattern): + field1: str + field2: str + field3: str + pattern1: Pattern + pattern2: Pattern + pattern3: Pattern + + def __init__(self, fields): + (self.field1, pattern1), (self.field2, pattern2), (self.field3, pattern3) = ( + fields.items() + ) + self.pattern1 = pattern(pattern1) + self.pattern2 = pattern(pattern2) + self.pattern3 = pattern(pattern3) + + def __repr__(self) -> str: + return f"PatternMap3({self.field1!r}={self.pattern1!r}, {self.field2!r}={self.pattern2!r}, {self.field3!r}={self.pattern3!r})" + + def equals(self, other: PatternMap3) -> bool: + return ( + self.field1 == other.field1 + and self.field2 == other.field2 + and self.field3 == other.field3 + and self.pattern1 == other.pattern1 + and self.pattern2 == other.pattern2 + and self.pattern3 == other.pattern3 + ) + + @cython.cfunc + def match(self, value, ctx: Context): + if not isinstance(value, dict): + raise NoMatchError() + + if len(value) != 3: + raise NoMatchError() + + try: + item1 = value[self.field1] + item2 = value[self.field2] + item3 = value[self.field3] + except KeyError: + raise NoMatchError() + + result1 = self.pattern1.match(item1, ctx) + result2 = self.pattern2.match(item2, ctx) + result3 = self.pattern3.match(item3, ctx) + + if result1 is not item1 or result2 is not item2 or result3 is not item3: + return type(value)( + { + **value, + self.field1: result1, + self.field2: result2, + self.field3: result3, + } + ) + else: + return value + @cython.final @cython.cclass @@ -1936,8 +2055,11 @@ def equals(self, other: PatternMapN) -> bool: @cython.cfunc def match(self, value, ctx: Context): - if not isinstance(value, Mapping): - return NoMatch + if not isinstance(value, dict): # check for __getitem__ + raise NoMatchError() + + if len(value) != len(self.fields): + raise NoMatchError() name: str pattern: Pattern @@ -1946,11 +2068,9 @@ def match(self, value, ctx: Context): try: item = value[name] except KeyError: - return NoMatch + raise NoMatchError() result = pattern.match(item, ctx) - if result is NoMatch: - return NoMatch - elif result is not item: + if result is not item: changed[name] = result if changed: @@ -1959,8 +2079,6 @@ def match(self, value, ctx: Context): return value - - @cython.ccall def pattern(obj: Any, allow_custom: bool = True) -> Pattern: """Create a pattern from various types. @@ -1998,10 +2116,9 @@ def pattern(obj: Any, allow_custom: bool = True) -> Pattern: elif isinstance(obj, Pattern): return obj elif isinstance(obj, (Deferred, Builder)): - # return Capture(obj) return EqDeferred(obj) elif isinstance(obj, Mapping): - return EqValue(obj) + return PatternMap(obj) elif isinstance(obj, Sequence): if isinstance(obj, (str, bytes)): return EqValue(obj) @@ -2015,5 +2132,3 @@ def pattern(obj: Any, allow_custom: bool = True) -> Pattern: return Custom(obj) else: return EqValue(obj) - - diff --git a/koerce/sugar.py b/koerce/sugar.py index 3cc7d7b..4b48a11 100644 --- a/koerce/sugar.py +++ b/koerce/sugar.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools import sys from typing import Any @@ -9,10 +10,15 @@ Context, Eq, If, - NoMatch, # noqa: F401 + NoMatch, Pattern, pattern, ) +from .utils import Signature + + +class ValidationError(Exception): + pass class Namespace: @@ -92,7 +98,107 @@ def match(pat: Pattern, value: Any, context: Context = None) -> Any: return pat.apply(value, context) +def annotated(_1=None, _2=None, _3=None, **kwargs): + """Create functions with arguments validated at runtime. + + There are various ways to apply this decorator: + + 1. With type annotations + + >>> @annotated + ... def foo(x: int, y: str) -> float: + ... return float(x) + float(y) + + 2. With argument patterns passed as keyword arguments + + >>> from ibis.common.patterns import InstanceOf as instance_of + >>> @annotated(x=instance_of(int), y=instance_of(str)) + ... def foo(x, y): + ... return float(x) + float(y) + + 3. With mixing type annotations and patterns where the latter takes precedence + + >>> @annotated(x=instance_of(float)) + ... def foo(x: int, y: str) -> float: + ... return float(x) + float(y) + + 4. With argument patterns passed as a list and/or an optional return pattern + + >>> @annotated([instance_of(int), instance_of(str)], instance_of(float)) + ... def foo(x, y): + ... return float(x) + float(y) + + Parameters + ---------- + *args : Union[ + tuple[Callable], + tuple[list[Pattern], Callable], + tuple[list[Pattern], Pattern, Callable] + ] + Positional arguments. + - If a single callable is passed, it's wrapped with the signature + - If two arguments are passed, the first one is a list of patterns for the + arguments and the second one is the callable to wrap + - If three arguments are passed, the first one is a list of patterns for the + arguments, the second one is a pattern for the return value and the third + one is the callable to wrap + **kwargs : dict[str, Pattern] + Patterns for the arguments. + + Returns + ------- + Callable + + """ + if _1 is None: + return functools.partial(annotated, **kwargs) + elif _2 is None: + if callable(_1): + func, patterns, return_pattern = _1, None, None + else: + return functools.partial(annotated, _1, **kwargs) + elif _3 is None: + if not isinstance(_2, Pattern): + func, patterns, return_pattern = _2, _1, None + else: + return functools.partial(annotated, _1, _2, **kwargs) + else: + func, patterns, return_pattern = _3, _1, _2 + + sig = Signature.from_callable(func) + argpats, retpat = Pattern.from_callable( + func, sig=sig, args=patterns or kwargs, return_=return_pattern + ) + + @functools.wraps(func) + def wrapped(*args, **kwargs): + # 0. Bind the arguments to the signature + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + # 1. Validate the passed arguments + values = argpats.apply(bound.arguments) + if values is NoMatch: + raise ValidationError() + + # 2. Reconstruction of the original arguments + args, kwargs = sig.unbind(values) + + # 3. Call the function with the validated arguments + result = func(*args, **kwargs) + + # 4. Validate the return value + result = retpat.apply(result) + if result is NoMatch: + raise ValidationError() + + return result + + wrapped.__signature__ = sig + + return wrapped + + if_ = If eq = Eq _ = var("_") - diff --git a/koerce/tests/test_patterns.py b/koerce/tests/test_patterns.py index 2add403..8acc7b5 100644 --- a/koerce/tests/test_patterns.py +++ b/koerce/tests/test_patterns.py @@ -55,6 +55,7 @@ Option, Pattern, PatternList, + PatternMap, Replace, SequenceOf, SomeChunksOf, @@ -85,6 +86,11 @@ def equals(self, other): return self.min == other.min +class FrozenDict(dict): + def __setitem__(self, key: Any, value: Any) -> None: + raise TypeError("Cannot modify a frozen dict") + + various_values = [1, "1", 1.0, object, False, None] @@ -98,6 +104,23 @@ def test_nothing(value): assert Nothing().apply(value) is NoMatch +@pytest.mark.parametrize( + ("inner", "default", "value", "expected"), + [ + (Anything(), None, None, None), + (Anything(), None, "three", "three"), + (Anything(), 1, None, 1), + (AsType(int), 11, None, 11), + (AsType(int), None, None, None), + (AsType(int), None, 18, 18), + (AsType(str), None, "caracal", "caracal"), + ], +) +def test_option(inner, default, value, expected): + p = Option(inner, default=default) + assert p.apply(value) == expected + + @pytest.mark.parametrize("value", various_values) def test_identical_to(value): assert IdenticalTo(value).apply(value) == value @@ -489,7 +512,7 @@ def negative(x): assert p.apply(1.0, context={}) == 1.0 assert p.apply(-1.0, context={}) is NoMatch assert p.apply(1, context={}) is NoMatch - #assert p.describe() == "anything except an int or a value that satisfies negative()" + # assert p.describe() == "anything except an int or a value that satisfies negative()" def test_generic_sequence_of(): @@ -564,7 +587,14 @@ def __init__(self, e, f, g, h, i): self.i = i def __eq__(self, other): - return type(self) is type(other) and self.e == other.e and self.f == other.f and self.g == other.g and self.h == other.h and self.i == other.i + return ( + type(self) is type(other) + and self.e == other.e + and self.f == other.f + and self.g == other.g + and self.h == other.h + and self.i == other.i + ) def test_object_pattern(): @@ -740,8 +770,8 @@ def test_pattern_list(): assert p.apply([1, 2, 3.0, 4.0, 5.0], context={}) == [1, 2, 3, 4.0, 5.0] # subpattern is a sequence - # p = PatternList([1, 2, 3, SomeOf(AsType(int), at_least=1)]) - # assert p.apply([1, 2, 3, 4.0, 5.0], context={}) == [1, 2, 3, 4, 5] + p = PatternList([1, 2, 3, SomeOf(AsType(int), at_least=1)]) + assert p.apply([1, 2, 3, 4.0, 5.0], context={}) == [1, 2, 3, 4, 5] def test_pattern_list_from_tuple_typehint(): @@ -968,7 +998,6 @@ def test_matching_sequence_complicated(): assert match([0, SomeOf([1, 2, str]), 3], v) == v - def test_pattern_sequence_with_nested_some_of(): assert SomeChunksOf(1, 2) == SomeOf(1, 2) @@ -1121,12 +1150,12 @@ def test_pattern_decorator(): dict[str, float], DictOf(InstanceOf(str), InstanceOf(float)), ), - # (FrozenDict[str, int], FrozenDictOf(InstanceOf(str), InstanceOf(int))), + (FrozenDict[str, int], MappingOf(InstanceOf(str), InstanceOf(int), FrozenDict)), (Literal["alpha", "beta", "gamma"], IsIn(("alpha", "beta", "gamma"))), - # ( - # Callable[[str, int], str], - # CallableWith((InstanceOf(str), InstanceOf(int)), InstanceOf(str)), - # ), + ( + Callable[[str, int], str], + CallableWith((InstanceOf(str), InstanceOf(int)), InstanceOf(str)), + ), # (Callable, InstanceOf(CallableABC)), ], ) @@ -1278,8 +1307,7 @@ def f(x): # assert pattern(f) == Custom(f) # matching mapping values - assert pattern({"a": 1, "b": 2}) == EqValue({"a": 1, "b": 2}) - + assert pattern({"a": 1, "b": 2}) == PatternMap({"a": EqValue(1), "b": EqValue(2)}) def test_callable_with(): @@ -1339,3 +1367,41 @@ def g(a: int, b: str, c: str = "0"): assert p.apply(f) is NoMatch assert p.apply(g) == g assert p.apply(h) == h + + +def test_pattern_from_callable(): + def func(a: int, b: str) -> str: ... + + args, ret = Pattern.from_callable(func) + assert args == PatternMap({"a": InstanceOf(int), "b": InstanceOf(str)}) + assert ret == InstanceOf(str) + + def func(a: int, b: str, c: str = "0") -> str: ... + + args, ret = Pattern.from_callable(func) + assert args == PatternMap( + {"a": InstanceOf(int), "b": InstanceOf(str), "c": Option(InstanceOf(str), "0")} + ) + assert ret == InstanceOf(str) + + def func(a: int, b: str, *args): ... + + args, ret = Pattern.from_callable(func) + assert args == PatternMap( + {"a": InstanceOf(int), "b": InstanceOf(str), "args": TupleOf(Anything())} + ) + assert ret == Anything() + + def func(a: int, b: str, c: str = "0", *args, **kwargs: int) -> float: ... + + args, ret = Pattern.from_callable(func) + assert args == PatternMap( + { + "a": InstanceOf(int), + "b": InstanceOf(str), + "c": Option(InstanceOf(str), "0"), + "args": TupleOf(Anything()), + "kwargs": MappingOf(Anything(), InstanceOf(int)), + } + ) + assert ret == InstanceOf(float) diff --git a/koerce/tests/test_sugar.py b/koerce/tests/test_sugar.py index 0369866..5aa9010 100644 --- a/koerce/tests/test_sugar.py +++ b/koerce/tests/test_sugar.py @@ -1,6 +1,12 @@ from __future__ import annotations -from koerce.sugar import NoMatch, match, var +from typing import Annotated, Union + +import pytest + +from koerce.patterns import InstanceOf, NoMatchError, pattern +from koerce.sugar import NoMatch, ValidationError, annotated, match, var +from koerce.utils import Signature def test_capture_shorthand(): @@ -26,3 +32,272 @@ def test_capture_shorthand(): def test_namespace(): pass + + +def test_signature_unbind_from_callable(): + def test(a: int, b: int, c: int = 1): ... + + sig = Signature.from_callable(test) + bound = sig.bind(2, 3) + bound.apply_defaults() + + assert bound.arguments == {"a": 2, "b": 3, "c": 1} + + args, kwargs = sig.unbind(bound.arguments) + assert args == (2, 3, 1) + assert kwargs == {} + + +def test_signature_unbind_from_callable_with_varargs(): + def test(a: int, b: int, *args: int): ... + + sig = Signature.from_callable(test) + bound = sig.bind(2, 3) + bound.apply_defaults() + + assert bound.arguments == {"a": 2, "b": 3, "args": ()} + args, kwargs = sig.unbind(bound.arguments) + assert args == (2, 3) + assert kwargs == {} + + bound = sig.bind(2, 3, 4, 5) + bound.apply_defaults() + assert bound.arguments == {"a": 2, "b": 3, "args": (4, 5)} + args, kwargs = sig.unbind(bound.arguments) + assert args == (2, 3, 4, 5) + assert kwargs == {} + + +def test_signature_unbind_from_callable_with_positional_only_arguments(): + def test(a: int, b: int, /, c: int = 1): ... + + sig = Signature.from_callable(test) + bound = sig.bind(2, 3) + bound.apply_defaults() + assert bound.arguments == {"a": 2, "b": 3, "c": 1} + + args, kwargs = sig.unbind(bound.arguments) + assert args == (2, 3, 1) + assert kwargs == {} + + bound = sig.bind(2, 3, 4) + bound.apply_defaults() + assert bound.arguments == {"a": 2, "b": 3, "c": 4} + + args, kwargs = sig.unbind(bound.arguments) + assert args == (2, 3, 4) + assert kwargs == {} + + +def test_signature_unbind_from_callable_with_keyword_only_arguments(): + def test(a: int, b: int, *, c: float, d: float = 0.0): ... + + sig = Signature.from_callable(test) + bound = sig.bind(2, 3, c=4.0) + bound.apply_defaults() + assert bound.arguments == {"a": 2, "b": 3, "c": 4.0, "d": 0.0} + + args, kwargs = sig.unbind(bound.arguments) + assert args == (2, 3) + assert kwargs == {"c": 4.0, "d": 0.0} + + +def test_signature_unbind(): + def func(a, b, c=1): ... + + sig = Signature.from_callable(func) + bound = sig.bind(1, 2) + bound.apply_defaults() + + assert bound.arguments == {"a": 1, "b": 2, "c": 1} + + args, kwargs = sig.unbind(bound.arguments) + assert args == (1, 2, 1) + assert kwargs == {} + + +@pytest.mark.parametrize("d", [(), (5, 6, 7)]) +def test_signature_unbind_with_empty_variadic(d): + def func(a, b, c, *args, e=None): + return a, b, c, args, e + + sig = Signature.from_callable(func) + bound = sig.bind(1, 2, 3, *d, e=4) + bound.apply_defaults() + assert bound.arguments == {"a": 1, "b": 2, "c": 3, "args": d, "e": 4} + + args, kwargs = sig.unbind(bound.arguments) + assert args == (1, 2, 3, *d) + assert kwargs == {"e": 4} + + +def test_annotated_function(): + @annotated(a=InstanceOf(int), b=InstanceOf(int), c=InstanceOf(int)) + def test(a, b, c=1): + return a + b + c + + assert test(2, 3) == 6 + assert test(2, 3, 4) == 9 + assert test(2, 3, c=4) == 9 + assert test(a=2, b=3, c=4) == 9 + + with pytest.raises(ValidationError): + test(2, 3, c="4") + + @annotated(a=InstanceOf(int)) + def test(a, b, c=1): + return (a, b, c) + + assert test(2, "3") == (2, "3", 1) + + +def test_annotated_function_with_type_annotations(): + @annotated() + def test(a: int, b: int, c: int = 1): + return a + b + c + + assert test(2, 3) == 6 + + @annotated + def test(a: int, b: int, c: int = 1): + return a + b + c + + assert test(2, 3) == 6 + + @annotated + def test(a: int, b, c=1): + return (a, b, c) + + assert test(2, 3, "4") == (2, 3, "4") + + +def test_annotated_function_with_return_type_annotation(): + @annotated + def test_ok(a: int, b: int, c: int = 1) -> int: + return a + b + c + + @annotated + def test_wrong(a: int, b: int, c: int = 1) -> int: + return "invalid result" + + assert test_ok(2, 3) == 6 + with pytest.raises(ValidationError): + test_wrong(2, 3) + + +def test_annotated_function_with_keyword_overrides(): + @annotated(b=InstanceOf(float)) + def test(a: int, b: int, c: int = 1): + return a + b + c + + with pytest.raises(ValidationError): + test(2, 3) + + assert test(2, 3.0) == 6.0 + + +def test_annotated_function_with_list_overrides(): + @annotated([InstanceOf(int), InstanceOf(int), InstanceOf(float)]) + def test(a: int, b: int, c: int = 1): + return a + b + c + + with pytest.raises(ValidationError): + test(2, 3, 4) + + +def test_annotated_function_with_list_overrides_and_return_override(): + @annotated([InstanceOf(int), InstanceOf(int), InstanceOf(float)], InstanceOf(float)) + def test(a: int, b: int, c: int = 1): + return a + b + c + + with pytest.raises(ValidationError): + test(2, 3, 4) + + assert test(2, 3, 4.0) == 9.0 + + +@pattern +def short_str(x, **context): + if len(x) > 3: + return x + else: + raise NoMatchError() + + +@pattern +def endswith_d(x, **context): + if x.endswith("d"): + return x + else: + raise NoMatchError() + + +def test_annotated_function_with_complex_type_annotations(): + @annotated + def test(a: Annotated[str, short_str, endswith_d], b: Union[int, float]): + return a, b + + assert test("abcd", 1) == ("abcd", 1) + assert test("---d", 1.0) == ("---d", 1.0) + + with pytest.raises(ValidationError): + test("---c", 1) + with pytest.raises(ValidationError): + test("123", 1) + with pytest.raises(ValidationError): + test("abcd", "qweqwe") + + +def test_annotated_function_without_annotations(): + @annotated + def test(a, b, c): + return a, b, c + + assert test(1, 2, 3) == (1, 2, 3) + assert test.__signature__.parameters.keys() == {"a", "b", "c"} + + +# def test_annotated_function_without_decoration(snapshot): +# def test(a, b, c): +# return a + b + c + +# func = annotated(test) +# with pytest.raises(ValidationError) as excinfo: +# func(1, 2) +# snapshot.assert_match(str(excinfo.value), "error.txt") + +# assert func(1, 2, c=3) == 6 + + +def test_annotated_function_with_varargs(): + @annotated + def test(a: float, b: float, *args: int): + return sum((a, b) + args) + + assert test(1.0, 2.0, 3, 4) == 10.0 + assert test(1.0, 2.0, 3, 4, 5) == 15.0 + + with pytest.raises(ValidationError): + test(1.0, 2.0, 3, 4, 5, 6.0) + + +def test_annotated_function_with_varkwargs(): + @annotated + def test(a: float, b: float, **kwargs: int): + return sum((a, b) + tuple(kwargs.values())) + + assert test(1.0, 2.0, c=3, d=4) == 10.0 + assert test(1.0, 2.0, c=3, d=4, e=5) == 15.0 + + with pytest.raises(ValidationError): + test(1.0, 2.0, c=3, d=4, e=5, f=6.0) + + +# def test_multiple_validation_failures(): +# @annotated +# def test(a: float, b: float, *args: int, **kwargs: int): ... + +# with pytest.raises(ValidationError) as excinfo: +# test(1.0, 2.0, 3.0, 4, c=5.0, d=6) + +# assert len(excinfo.value.errors) == 2 diff --git a/koerce/utils.py b/koerce/utils.py index fbf60b0..cf19e8a 100644 --- a/koerce/utils.py +++ b/koerce/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import itertools import typing from typing import Any, TypeVar @@ -177,3 +178,41 @@ def rewind(self): def checkpoint(self): """Create a checkpoint of the current iterator state.""" self._iterator, self._checkpoint = itertools.tee(self._iterator) + + +class Signature(inspect.Signature): + def unbind(self, this: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]]: + """Reverse bind of the parameters. + + Attempts to reconstructs the original arguments as keyword only arguments. + + Parameters + ---------- + this : Any + Object with attributes matching the signature parameters. + + Returns + ------- + args : (args, kwargs) + Tuple of positional and keyword arguments. + + """ + # does the reverse of bind, but doesn't apply defaults + args: list = [] + kwargs: dict = {} + for name, param in self.parameters.items(): + value = this[name] + if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: + args.append(value) + elif param.kind is inspect.Parameter.VAR_POSITIONAL: + args.extend(value) + elif param.kind is inspect.Parameter.VAR_KEYWORD: + kwargs.update(value) + elif param.kind is inspect.Parameter.KEYWORD_ONLY: + kwargs[name] = value + elif param.kind is inspect.Parameter.POSITIONAL_ONLY: + args.append(value) + else: + raise TypeError(f"unsupported parameter kind {param.kind}") + + return tuple(args), kwargs