diff --git a/.gitignore b/.gitignore index 043f0b1..bf48101 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +_internal.pyx diff --git a/build.py b/build.py index fcd2b5b..b2c35bb 100644 --- a/build.py +++ b/build.py @@ -3,6 +3,7 @@ import os import shutil import sys +import ast from pathlib import Path # setuptools *must* come before Cython, otherwise Cython's distutils hacking @@ -16,17 +17,69 @@ BUILD_DIR = Path("cython_build") -extensions = [ - Extension("koerce.annots", ["koerce/annots.py"]), - Extension("koerce.builders", ["koerce/builders.py"]), - Extension("koerce.patterns", ["koerce/patterns.py"]), - # Extension("koerce.utils", ["koerce/utils.py"]), -] +def extract_imports_and_code(path): + """Extracts the import statements and other code from python source.""" + with path.open("r") as file: + tree = ast.parse(file.read(), filename=path.name) + + code = [] + imports = [] + for node in tree.body: + if isinstance(node, (ast.Import, ast.ImportFrom)): + imports.append(node) + else: + code.append(node) + + return imports, code + + +def ignore_import(imp, modules): + absolute_names = ["koerce.{name}" for name in modules] + if isinstance(imp, ast.ImportFrom): + return imp.module in modules + elif isinstance(imp, ast.Import): + return imp.names[0].name in absolute_names + else: + raise TypeError(imp) + + +def concatenate_files(file_paths, output_file): + all_imports = [] + all_code = [] + modules = [] + + for file_path in file_paths: + path = Path(SOURCE_DIR / file_path) + imports, code = extract_imports_and_code(path) + all_imports.extend(imports) + all_code.extend(code) + modules.append(path.stem) + + # Deduplicate imports by their unparsed code + unique_imports = {ast.unparse(stmt): stmt for stmt in all_imports} + + # Write to the output file + with (SOURCE_DIR / output_file).open("w") as out: + # Write unique imports + for code, stmt in unique_imports.items(): + if not ignore_import(stmt, modules): + out.write(code) + out.write("\n") + + # Write the rest of the code + for stmt in all_code: + out.write(ast.unparse(stmt)) + out.write("\n\n\n") + + +concatenate_files(["builders.py", "patterns.py", "annots.py"], "_internal.pyx") +extension = Extension("koerce._internal", ["koerce/_internal.pyx"]) cythonized_modules = cythonize( - extensions, + [extension], build_dir=BUILD_DIR, - # generate anannotated .html output files. + cache=True, + show_all_warnings=False, annotate=True, compiler_directives={ "language_level": "3", diff --git a/koerce/__init__.py b/koerce/__init__.py index 7d7e347..6a93b40 100644 --- a/koerce/__init__.py +++ b/koerce/__init__.py @@ -1,6 +1,87 @@ from __future__ import annotations -from .patterns import NoMatch, Pattern -from .sugar import match, var +import sys -__all__ = ["NoMatch", "Pattern", "match", "var"] +from ._internal import * + + +class _Variable(Deferred): + def __init__(self, name: str): + builder = Var(name) + super().__init__(builder) + + def __invert__(self): + return Capture(self) + + +class _Namespace: + """Convenience class for creating patterns for various types from a module. + + Useful to reduce boilerplate when creating patterns for various types from + a module. + + Parameters + ---------- + factory + The pattern to construct with the looked up types. + module + The module object or name to look up the types. + + """ + + __slots__ = ("_factory", "_module") + + def __init__(self, factory, module): + if isinstance(module, str): + module = sys.modules[module] + self._module = module + self._factory = factory + + def __getattr__(self, name: str): + obj = getattr(self._module, name) + return self._factory(obj) + + +def var(name): + return _Variable(name) + + +def match(pat: Pattern, value: Any, context: Context = None) -> Any: + """Match a value against a pattern. + + Parameters + ---------- + pat + The pattern to match against. + value + The value to match. + context + Arbitrary mapping of values to be used while matching. + + Returns + ------- + The matched value if the pattern matches, otherwise :obj:`NoMatch`. + + Examples + -------- + >>> assert match(Any(), 1) == 1 + >>> assert match(1, 1) == 1 + >>> assert match(1, 2) is NoMatch + >>> assert match(1, 1, context={"x": 1}) == 1 + >>> assert match(1, 2, context={"x": 1}) is NoMatch + >>> assert match([1, int], [1, 2]) == [1, 2] + >>> assert match([1, int, "a" @ InstanceOf(str)], [1, 2, "three"]) == [ + ... 1, + ... 2, + ... "three", + ... ] + + """ + pat = pattern(pat) + return pat.apply(value, context) + + +_ = var("_") + + +# define __all__ diff --git a/koerce/_internal.py b/koerce/_internal.py new file mode 100644 index 0000000..65a4d01 --- /dev/null +++ b/koerce/_internal.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from .annots import * +from .builders import * +from .patterns import * + +compiled = False diff --git a/koerce/annots.py b/koerce/annots.py index 8807857..f2c4497 100644 --- a/koerce/annots.py +++ b/koerce/annots.py @@ -3,12 +3,17 @@ import functools import inspect import typing -from typing import Any +from abc import ABC +from collections.abc import Callable, Mapping, Sequence +from copy import copy +from types import FunctionType, MethodType +from typing import Any, ClassVar, Optional import cython from .patterns import ( DictOf, + FrozenDictOf, NoMatch, Option, Pattern, @@ -17,9 +22,36 @@ _any, pattern, ) -from .utils import get_type_hints +from .utils import get_type_hints, get_type_origin EMPTY = inspect.Parameter.empty +_ensure_pattern = pattern + + +@cython.final +@cython.cclass +class Attribute: + pattern = cython.declare(Pattern, visibility="readonly") + default_ = cython.declare(object, visibility="readonly") + + def __init__(self, pattern: Any = _any, default: Any = EMPTY): + self.pattern = _ensure_pattern(pattern) + self.default_ = default + + def __repr__(self): + return f"<{self.__class__.__name__} pattern={self.pattern!r} default={self.default_!r}>" + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Attribute): + return NotImplemented + right: Attribute = cython.cast(Attribute, other) + return self.pattern == right.pattern and self.default_ == right.default_ + + def __call__(self, default): + """Needed to support the decorator syntax.""" + return self.__class__(self.pattern, default) + + _POSITIONAL_ONLY = cython.declare(cython.int, int(inspect.Parameter.POSITIONAL_ONLY)) _POSITIONAL_OR_KEYWORD = cython.declare( cython.int, int(inspect.Parameter.POSITIONAL_OR_KEYWORD) @@ -38,22 +70,35 @@ class Parameter: KEYWORD_ONLY: typing.ClassVar[int] = _KEYWORD_ONLY VAR_KEYWORD: typing.ClassVar[int] = _VAR_KEYWORD - name = cython.declare(str, visibility="readonly") kind = cython.declare(cython.int, visibility="readonly") - # Cannot use C reserved keyword 'default' here + pattern = cython.declare(Pattern, visibility="readonly") default_ = cython.declare(object, visibility="readonly") typehint = cython.declare(object, visibility="readonly") def __init__( - self, name: str, kind: int, default: Any = EMPTY, typehint: Any = EMPTY + self, + kind: int, + pattern: Any = _any, + default: Any = EMPTY, + typehint: Any = EMPTY, ): - self.name = name self.kind = kind - self.default_ = default self.typehint = typehint + if kind is _VAR_POSITIONAL: + self.pattern = TupleOf(pattern) + elif kind is _VAR_KEYWORD: + self.pattern = FrozenDictOf(_any, pattern) + else: + self.pattern = _ensure_pattern(pattern) + + # validate that the default value matches the pattern + if default is not EMPTY: + self.default_ = self.pattern.match(default, {}) + else: + self.default_ = default - def __str__(self) -> str: - result: str = self.name + def format(self, name) -> str: + result: str = name if self.typehint is not EMPTY: if hasattr(self.typehint, "__qualname__"): result += f": {self.typehint.__qualname__}" @@ -70,16 +115,12 @@ def __str__(self) -> str: result = f"**{result}" return result - def __repr__(self): - return f'<{self.__class__.__name__} "{self}">' - def __eq__(self, other: Any) -> bool: if not isinstance(other, Parameter): return NotImplemented right: Parameter = cython.cast(Parameter, other) return ( - self.name == right.name - and self.kind == right.kind + self.kind == right.kind and self.default_ == right.default_ and self.typehint == right.typehint ) @@ -88,25 +129,71 @@ def __eq__(self, other: Any) -> bool: @cython.final @cython.cclass class Signature: - parameters = cython.declare(list[Parameter], visibility="readonly") + length = cython.declare(cython.int, visibility="readonly") + parameters = cython.declare(dict[str, Parameter], visibility="readonly") + return_pattern = cython.declare(Pattern, visibility="readonly") return_typehint = cython.declare(object, visibility="readonly") - def __init__(self, parameters: list[Parameter], return_typehint: Any = EMPTY): + def __init__( + self, + parameters: dict[str, Parameter], + return_pattern: Pattern = _any, + return_typehint: Any = EMPTY, + ): + self.length = len(parameters) self.parameters = parameters + self.return_pattern = return_pattern self.return_typehint = return_typehint @staticmethod - def from_callable(func: Any) -> Signature: + def from_callable( + func: Any, + arg_patterns: Sequence[Any] | Mapping[str, Any] = None, + return_pattern: Any = None, + ) -> Signature: sig = inspect.signature(func) hints = get_type_hints(func) - params: list[Parameter] = [ - Parameter(p.name, int(p.kind), p.default, hints.get(p.name, EMPTY)) - for p in sig.parameters.values() - ] - return Signature(params, return_typehint=hints.get("return", EMPTY)) + params: dict[str, Parameter] = {} + + if arg_patterns is None: + arg_patterns = {} + elif isinstance(arg_patterns, Sequence): + # create a mapping of parameter name to pattern + arg_patterns = dict(zip(sig.parameters.keys(), arg_patterns)) + elif not isinstance(arg_patterns, Mapping): + raise TypeError("arg_patterns must be a sequence or a mapping") + + for name, param in sig.parameters.items(): + typehint = hints.get(name, EMPTY) + if name in arg_patterns: + argpat = pattern(arg_patterns[name]) + elif typehint is not EMPTY: + argpat = Pattern.from_typehint(typehint) + else: + argpat = _any + + params[name] = Parameter( + kind=int(param.kind), + default=param.default, + pattern=argpat, + typehint=typehint, + ) + + return_typehint = hints.get("return", EMPTY) + if return_pattern is not None: + retpat = pattern(return_pattern) + elif return_typehint is not EMPTY: + retpat = Pattern.from_typehint(return_typehint) + else: + retpat = _any + + return Signature(params, return_typehint=return_typehint, return_pattern=retpat) @staticmethod - def merge(*signatures: Signature, **annotations): + def merge( + signatures: Sequence[Signature], + parameters: Optional[dict[str, Parameter]] = None, + ): """Merge multiple signatures. In addition to concatenating the parameters, it also reorders the @@ -114,10 +201,10 @@ def merge(*signatures: Signature, **annotations): Parameters ---------- - *signatures : Signature + signatures : Signature instances to merge. - **annotations : dict - Annotations to add to the merged signature. + parameters : + Parameters to add to the merged signature. Returns ------- @@ -128,12 +215,12 @@ def merge(*signatures: Signature, **annotations): param: Parameter params: dict[str, Parameter] = {} for sig in signatures: - for param in sig.parameters: - params[param.name] = param + params.update(sig.parameters) inherited: set[str] = set(params.keys()) - for name, annot in annotations.items(): - params[name] = Parameter(name, annotation=annot) + if parameters: + for name, param in parameters.items(): + params[name] = param # mandatory fields without default values must precede the optional # ones in the function signature, the partial ordering will be kept @@ -147,37 +234,52 @@ def merge(*signatures: Signature, **annotations): if param.kind == _VAR_POSITIONAL: if var_args: raise TypeError("only one variadic *args parameter is allowed") - var_args.append(param) + var_args.append(name) elif param.kind == _VAR_KEYWORD: if var_kwargs: raise TypeError("only one variadic **kwargs parameter is allowed") - var_kwargs.append(param) + var_kwargs.append(name) elif name in inherited: if param.default_ is EMPTY: - old_args.append(param) + old_args.append(name) else: - old_kwargs.append(param) + old_kwargs.append(name) elif param.default_ is EMPTY: - new_args.append(param) + new_args.append(name) else: - new_kwargs.append(param) + new_kwargs.append(name) - return Signature( + order: list[str] = ( old_args + new_args + var_args + new_kwargs + old_kwargs + var_kwargs ) + return Signature({name: params[name] for name in order}) def __eq__(self, other: Any) -> bool: if not isinstance(other, Signature): return NotImplemented right: Signature = cython.cast(Signature, other) return ( - self.parameters == right.parameters - and self.return_annotation == right.return_annotation + tuple(self.parameters.items()) == tuple(right.parameters.items()) + and self.return_pattern == right.return_pattern + and self.return_typehint == right.return_typehint ) def __call__(self, /, *args, **kwargs): return self.bind(args, kwargs) + def __len__(self) -> int: + return self.length + + def __str__(self): + params_str = ", ".join( + param.format(name) for name, param in self.parameters.items() + ) + if self.return_typehint is not EMPTY: + return_str = f" -> {self.return_typehint}" + else: + return_str = "" + return f"({params_str}){return_str}" + @cython.ccall def bind(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]: """Bind the arguments to the signature. @@ -199,65 +301,66 @@ def bind(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]: kind: cython.int param: Parameter bound: dict[str, Any] = {} + params = iter(self.parameters.items()) # 1. HANDLE ARGS for i in range(len(args)): - if i >= len(self.parameters): + try: + name, param = next(params) + except StopIteration: raise TypeError("too many positional arguments") - param = self.parameters[i] kind = param.kind if kind is _POSITIONAL_OR_KEYWORD: - if param.name in kwargs: - raise TypeError(f"multiple values for argument '{param.name}'") - bound[param.name] = args[i] + if name in kwargs: + raise TypeError(f"multiple values for argument '{name}'") + bound[name] = args[i] elif kind is _VAR_KEYWORD or kind is _KEYWORD_ONLY: raise TypeError("too many positional arguments") elif kind is _VAR_POSITIONAL: - bound[param.name] = args[i:] + bound[name] = args[i:] break elif kind is _POSITIONAL_ONLY: - bound[param.name] = args[i] + bound[name] = args[i] else: raise TypeError("unreachable code") - # 2. INCREMENT PARAMETER INDEX - if args: - i += 1 + # 2. HANDLE KWARGS + while True: + try: + name, param = next(params) + except StopIteration: + if kwargs: + raise TypeError( + f"got an unexpected keyword argument '{next(iter(kwargs))}'" + ) + break - # 3. HANDLE KWARGS - for param in self.parameters[i:]: - if param.kind is _POSITIONAL_OR_KEYWORD or param.kind is _KEYWORD_ONLY: - if param.name in kwargs: - bound[param.name] = kwargs.pop(param.name) + kind = param.kind + if kind is _POSITIONAL_OR_KEYWORD or kind is _KEYWORD_ONLY: + if name in kwargs: + bound[name] = kwargs.pop(name) elif param.default_ is EMPTY: - raise TypeError(f"missing a required argument: '{param.name}'") + raise TypeError(f"missing a required argument: '{name}'") else: - bound[param.name] = param.default_ - elif param.kind is _VAR_POSITIONAL: - bound[param.name] = () - elif param.kind is _VAR_KEYWORD: - bound[param.name] = kwargs + bound[name] = param.default_ + elif kind is _VAR_POSITIONAL: + bound[name] = () + elif kind is _VAR_KEYWORD: + bound[name] = kwargs break - elif param.kind is _POSITIONAL_ONLY: + elif kind is _POSITIONAL_ONLY: if param.default_ is EMPTY: - if param.name in kwargs: + if name in kwargs: raise TypeError( - f"positional only argument '{param.name}' passed as keyword argument" + f"positional only argument '{name}' passed as keyword argument" ) else: - raise TypeError( - f"missing required positional argument {param.name}" - ) + raise TypeError(f"missing required positional argument {name}") else: - bound[param.name] = param.default_ + bound[name] = param.default_ else: raise TypeError("unreachable code") - else: - if kwargs: - raise TypeError( - f"got an unexpected keyword argument '{next(iter(kwargs))}'" - ) return bound @@ -280,94 +383,27 @@ def unbind(self, bound: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any] """ # does the reverse of bind, but doesn't apply defaults args: list = [] + kind: cython.int kwargs: dict = {} param: Parameter - for param in self.parameters: - value = bound[param.name] - if param.kind is _POSITIONAL_OR_KEYWORD: + for name, param in self.parameters.items(): + value = bound[name] + kind = param.kind + if kind is _POSITIONAL_OR_KEYWORD: args.append(value) - elif param.kind is _VAR_POSITIONAL: + elif kind is _VAR_POSITIONAL: args.extend(value) - elif param.kind is _VAR_KEYWORD: + elif kind is _VAR_KEYWORD: kwargs.update(value) - elif param.kind is _KEYWORD_ONLY: - kwargs[param.name] = value - elif param.kind is _POSITIONAL_ONLY: + elif kind is _KEYWORD_ONLY: + kwargs[name] = value + elif kind is _POSITIONAL_ONLY: args.append(value) else: - raise TypeError(f"unsupported parameter kind {param.kind}") + raise TypeError(f"unsupported parameter kind {kind}") return tuple(args), kwargs - def to_pattern( - self, - overrides: dict[str, Any] | list[Any] | None = None, - return_override: Any = None, - ) -> Pattern: - """Create patterns from a Signature. - - Two patterns are created, one for the arguments and one for the return value. - - Parameters - ---------- - overrides : dict, default None - Pass patterns to add missing or override existing argument type - annotations. - return_override : Pattern, default None - Pattern for the return value of the callable. - - Returns - ------- - Tuple of patterns for the arguments and the return value. - """ - arg_overrides: dict[str, Any] - if overrides is None: - arg_overrides = {} - elif isinstance(overrides, (list, tuple)): - # create a mapping of parameter name to pattern - arg_overrides = { - param.name: arg for param, arg in zip(self.parameters, overrides) - } - elif isinstance(overrides, dict): - arg_overrides = overrides - else: - raise TypeError(f"patterns must be a list or dict, got {type(overrides)}") - - retpat: Pattern - argpat: Pattern - argpats: dict[str, Pattern] = {} - for param in self.parameters: - name: str = param.name - - if name in arg_overrides: - argpat = pattern(arg_overrides[name]) - elif param.typehint is not EMPTY: - argpat = Pattern.from_typehint(param.typehint) - else: - argpat = _any - - if param.kind is _VAR_POSITIONAL: - argpat = TupleOf(argpat) - elif param.kind is _VAR_KEYWORD: - argpat = DictOf(_any, argpat) - elif param.default_ is not EMPTY: - argpat = Option(argpat, default=param.default_) - - argpats[name] = argpat - - if return_override is not None: - retpat = pattern(return_override) - elif self.return_typehint is not EMPTY: - retpat = Pattern.from_typehint(self.return_typehint) - else: - retpat = _any - - return (PatternMap(argpats), retpat) - - -class ValidationError(Exception): - pass - def annotated(_1=None, _2=None, _3=None, **kwargs): """Create functions with arguments validated at runtime. @@ -436,9 +472,11 @@ def annotated(_1=None, _2=None, _3=None, **kwargs): else: func, patterns, return_pattern = _3, _1, _2 - sig: Signature = Signature.from_callable(func) - argpats, retpat = sig.to_pattern( - overrides=patterns or kwargs, return_override=return_pattern + sig: Signature = Signature.from_callable( + func, arg_patterns=patterns or kwargs, return_pattern=return_pattern + ) + pat: Pattern = PatternMap( + {name: param.pattern for name, param in sig.parameters.items()} ) @functools.wraps(func) @@ -447,9 +485,7 @@ def wrapped(*args, **kwargs): bound: dict[str, Any] = sig.bind(args, kwargs) # 1. Validate the passed arguments - values: Any = argpats.apply(bound) - if values is NoMatch: - raise ValidationError() + values: Any = pat.match(bound, {}) # 2. Reconstruction of the original arguments args, kwargs = sig.unbind(values) @@ -458,12 +494,334 @@ def wrapped(*args, **kwargs): result = func(*args, **kwargs) # 4. Validate the return value - result = retpat.apply(result) - if result is NoMatch: - raise ValidationError() + result = sig.return_pattern.match(result, {}) return result wrapped.__signature__ = sig return wrapped + + +def attribute(pattern=_any, default=EMPTY): + """Annotation to mark a field in a class.""" + if default is EMPTY and isinstance(pattern, (FunctionType, MethodType)): + return Attribute(pattern=_any, default=pattern) + else: + return Attribute(pattern=pattern, default=default) + + +def argument(pattern=_any, default=EMPTY, typehint=EMPTY): + """Annotation type for all fields which should be passed as arguments.""" + return Parameter( + kind=_POSITIONAL_OR_KEYWORD, default=default, pattern=pattern, typehint=typehint + ) + + +def optional(pattern=_any, default=None, typehint=EMPTY): + """Annotation to allow and treat `None` values as missing arguments.""" + if default is None: + pattern = Option(pattern) + return Parameter( + kind=_POSITIONAL_OR_KEYWORD, default=default, pattern=pattern, typehint=typehint + ) + + +def varargs(pattern=_any, typehint=EMPTY): + """Annotation to mark a variable length positional arguments.""" + return Parameter(kind=_VAR_POSITIONAL, pattern=pattern, typehint=typehint) + + +def varkwargs(pattern=_any, typehint=EMPTY): + """Annotation to mark a variable length keyword arguments.""" + return Parameter(kind=_VAR_KEYWORD, pattern=pattern, typehint=typehint) + + +__create__ = cython.declare(object, type.__call__) +if cython.compiled: + from cython.cimports.cpython.object import PyObject_GenericSetAttr as __setattr__ +else: + __setattr__ = object.__setattr__ + + +@cython.final +@cython.cclass +class AnnotableSpec: + # make them readonly + initable: cython.bint + hashable: cython.bint + immutable: cython.bint + signature: Signature + attributes: dict[str, Attribute] + + def __init__( + self, + initable: bool, + hashable: bool, + immutable: bool, + signature: Signature, + attributes: dict[str, Attribute], + ): + self.initable = initable + self.hashable = hashable + self.immutable = immutable + self.signature = signature + self.attributes = attributes + + @cython.cfunc + @cython.inline + def new(self, cls: type, args: tuple[Any, ...], kwargs: dict[str, Any]): + ctx: dict[str, Any] = {} + bound: dict[str, Any] + param: Parameter + + if not args and len(kwargs) == self.signature.length: + bound = kwargs + else: + bound = self.signature.bind(args, kwargs) + + if self.initable: + # slow initialization calling __init__ + for name, param in self.signature.parameters.items(): + bound[name] = param.pattern.match(bound[name], ctx) + return __create__(cls, **bound) + else: + # fast initialization directly setting the arguments + this = cls.__new__(cls) + for name, param in self.signature.parameters.items(): + __setattr__(this, name, param.pattern.match(bound[name], ctx)) + if self.attributes: + self.init_attributes(this) + if self.hashable: + self.init_precomputes(this) + return this + + @cython.cfunc + @cython.inline + def init_attributes(self, this) -> cython.void: + attr: Attribute + for name, attr in self.attributes.items(): + if attr.default_ is not EMPTY: + if callable(attr.default_): + value = attr.default_(this) + else: + value = attr.default_ + __setattr__(this, name, value) + + @cython.cfunc + @cython.inline + def init_precomputes(self, this) -> cython.void: + arguments = tuple(getattr(this, name) for name in self.signature.parameters) + hashvalue = hash((this.__class__, arguments)) + __setattr__(this, "__args__", arguments) + __setattr__(this, "__precomputed_hash__", hashvalue) + + +class AnnotableMeta(type): + def __new__( + metacls, + clsname, + bases, + dct, + initable=None, + hashable=False, + immutable=False, + **kwargs, + ): + traits = [] + if initable is None: + # this flag is handled in AnnotableSpec + initable = "__init__" in dct or "__new__" in dct + if hashable: + if not immutable: + raise ValueError("Only immutable classes can be hashable") + traits.append(Hashable) + if immutable: + traits.append(Immutable) + + # inherit signature from parent classes + abstracts: set = set() + signatures: list = [] + attributes: dict[str, Attribute] = {} + for parent in bases: + try: # noqa: SIM105 + signatures.append(parent.__signature__) + except AttributeError: + pass + try: # noqa: SIM105 + attributes.update(parent.__attributes__) + except AttributeError: + pass + try: # noqa: SIM105 + abstracts.update(parent.__abstractmethods__) + except AttributeError: + pass + + # collection type annotations and convert them to patterns + slots: list[str] = list(dct.pop("__slots__", [])) + module: str | None = dct.pop("__module__", None) + qualname: str = dct.pop("__qualname__", clsname) + annotations: dict[str, Any] = dct.pop("__annotations__", {}) + if module is None: + self_qualname = None + else: + self_qualname = f"{module}.{qualname}" + + # TODO(kszucs): pass dct as localns to evaluate_annotations + typehints = get_type_hints(annotations, module=module) + for name, typehint in typehints.items(): + if get_type_origin(typehint) is ClassVar: + continue + dct[name] = Parameter( + kind=_POSITIONAL_OR_KEYWORD, + pattern=Pattern.from_typehint(typehint, self_qualname=self_qualname), + default=dct.get(name, EMPTY), + typehint=typehint, + ) + + namespace: dict[str, Any] = {} + parameters: dict[str, Parameter] = {} + for name, value in dct.items(): + if isinstance(value, Parameter): + parameters[name] = value + slots.append(name) + elif isinstance(value, Attribute): + attributes[name] = value + slots.append(name) + else: + if getattr(value, "__isabstractmethod__", False): + abstracts.add(name) + else: + abstracts.discard(name) + namespace[name] = value + + # merge the annotations with the parent annotations + signature = Signature.merge(signatures, parameters) + argnames = tuple(signature.parameters.keys()) + bases = tuple(traits) + bases + spec = AnnotableSpec( + initable=initable, + hashable=hashable, + immutable=immutable, + signature=signature, + attributes=attributes, + ) + + namespace.update( + __argnames__=argnames, + __match_args__=argnames, + __module__=module, + __qualname__=qualname, + __signature__=signature, + __slots__=tuple(slots), + __spec__=spec, + ) + klass = super().__new__(metacls, clsname, bases, namespace, **kwargs) + klass.__abstractmethods__ = frozenset(abstracts) + return klass + + def __call__(cls, *args, **kwargs): + spec: AnnotableSpec = cython.cast(AnnotableSpec, cls.__spec__) + return spec.new(cython.cast(type, cls), args, kwargs) + + +class Immutable: + __slots__ = () + + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self + + def __setattr__(self, name: str, _: Any) -> None: + raise AttributeError( + f"Attribute {name!r} cannot be assigned to immutable instance of " + f"type {type(self)}" + ) + + +class Hashable: + __slots__ = ("__args__", "__precomputed_hash__") + + def __hash__(self) -> int: + return self.__precomputed_hash__ + + def __eq__(self, other) -> bool: + if type(self) is not type(other): + return NotImplemented + return ( + self.__precomputed_hash__ == other.__precomputed_hash__ + and self.__args__ == other.__args__ + ) + + +class Annotable(metaclass=AnnotableMeta, initable=False): + __argnames__: ClassVar[tuple[str, ...]] + __match_args__: ClassVar[tuple[str, ...]] + __signature__: ClassVar[Signature] + + __slots__ = ("__weakref__",) + + def __init__(self, **kwargs): + spec: AnnotableSpec = self.__spec__ + for name, value in kwargs.items(): + __setattr__(self, name, value) + if spec.attributes: + spec.init_attributes(self) + if spec.hashable: + spec.init_precomputes(self) + + def __setattr__(self, name, value) -> None: + spec: AnnotableSpec = self.__spec__ + attr: Attribute + param: Parameter + if param := spec.signature.parameters.get(name): + # try to look up the parameter + value = param.pattern.match(value, {}) + elif attr := spec.attributes.get(name): + # try to look up the attribute + value = attr.pattern.match(value, {}) + __setattr__(self, name, value) + + def __eq__(self, other): + spec: AnnotableSpec = self.__spec__ + # compare types + if type(self) is not type(other): + return NotImplemented + # compare parameters + for name in spec.signature.parameters: + if getattr(self, name) != getattr(other, name): + return False + # compare attributes + for name in spec.attributes: + if getattr(self, name, EMPTY) != getattr(other, name, EMPTY): + return False + return True + + def __getstate__(self): + spec: AnnotableSpec = self.__spec__ + state: dict[str, Any] = {} + for name in spec.signature.parameters: + state[name] = getattr(self, name) + for name in spec.attributes: + value = getattr(self, name, EMPTY) + if value is not EMPTY: + state[name] = value + return state + + def __setstate__(self, state): + spec: AnnotableSpec = self.__spec__ + for name, value in state.items(): + __setattr__(self, name, value) + if spec.hashable: + spec.init_precomputes(self) + + def __repr__(self) -> str: + args = (f"{n}={getattr(self, n)!r}" for n in self.__argnames__) + argstring = ", ".join(args) + return f"{self.__class__.__name__}({argstring})" + + @property + def __args__(self) -> tuple[Any, ...]: + return tuple(getattr(self, name) for name in self.__argnames__) diff --git a/koerce/builders.py b/koerce/builders.py index 564fd95..42adc7a 100644 --- a/koerce/builders.py +++ b/koerce/builders.py @@ -2,7 +2,6 @@ import collections.abc import operator -from collections.abc import Callable from typing import Any import cython @@ -162,7 +161,7 @@ def __eq__(self, other: Any) -> bool: @cython.final @cython.cclass -class Custom(Builder): +class Func(Builder): """Construct a value by calling a function. The function is called with two positional arguments: @@ -177,20 +176,20 @@ class Custom(Builder): The function to apply. """ - func: Callable + func: Any - def __init__(self, func): + def __init__(self, func: Any): self.func = func def __repr__(self): return f"{self.func.__name__}(...)" - def equals(self, other: Custom) -> bool: + def equals(self, other: Func) -> bool: return self.func == other.func @cython.cfunc - def build(self, context): - return self.func(**context) + def build(self, ctx: Context): + return self.func(**ctx) @cython.final @@ -228,7 +227,7 @@ def build(self, ctx: Context): @cython.final @cython.cclass -class Variable(Builder): +class Var(Builder): """Retrieve a value from the context. Parameters @@ -245,7 +244,7 @@ def __init__(self, name: str): def __repr__(self): return f"${self.name}" - def equals(self, other: Variable) -> bool: + def equals(self, other: Var) -> bool: return self.name == other.name @cython.cfunc @@ -528,10 +527,10 @@ class Unop(Builder): The argument to apply the operator to. """ - op: Callable + op: Any arg: Builder - def __init__(self, op: Callable, arg): + def __init__(self, op: Any, arg: Any): self.op = op self.arg = builder(arg) @@ -563,11 +562,11 @@ class Binop(Builder): The right-hand side argument. """ - op: Callable + op: Any arg1: Builder arg2: Builder - def __init__(self, op: Callable, arg1, arg2): + def __init__(self, op: Any, arg1: Any, arg2: Any): self.op = op self.arg1 = builder(arg1) self.arg2 = builder(arg2) @@ -655,7 +654,7 @@ def build(self, ctx: Context): @cython.final @cython.cclass -class Sequence(Builder): +class Seq(Builder): """Pattern that constructs a sequence from the given items. Parameters @@ -680,7 +679,7 @@ def __repr__(self): else: return f"{self.type_.__name__}({elems})" - def equals(self, other: Sequence) -> bool: + def equals(self, other: Seq) -> bool: return self.type_ == other.type_ and self.items == other.items @cython.cfunc @@ -694,7 +693,7 @@ def build(self, ctx: Context): @cython.final @cython.cclass -class Mapping(Builder): +class Map(Builder): """Pattern that constructs a mapping from the given items. Parameters @@ -717,7 +716,7 @@ def __repr__(self): else: return f"{self.type_.__name__}({{{items}}})" - def equals(self, other: Mapping) -> bool: + def equals(self, other: Map) -> bool: return self.type_ == other.type_ and self.items == other.items @cython.cfunc @@ -738,16 +737,16 @@ def builder(obj, allow_custom=False) -> Builder: return obj elif isinstance(obj, collections.abc.Mapping): # allow nesting deferred patterns in dicts - return Mapping(obj) + return Map(obj) elif isinstance(obj, collections.abc.Sequence): # allow nesting deferred patterns in tuples/lists if isinstance(obj, (str, bytes)): return Just(obj) else: - return Sequence(obj) + return Seq(obj) elif callable(obj) and allow_custom: # the object is used as a custom builder function - return Custom(obj) + return Func(obj) else: # the object is used as a constant value return Just(obj) diff --git a/koerce/patterns.py b/koerce/patterns.py index fe4ce58..2f88a8a 100644 --- a/koerce/patterns.py +++ b/koerce/patterns.py @@ -1,9 +1,9 @@ from __future__ import annotations import importlib +import inspect from collections.abc import Callable, Mapping, Sequence from enum import Enum -from inspect import Parameter, Signature from types import UnionType from typing import ( Annotated, @@ -19,9 +19,10 @@ from typing_extensions import GenericMeta, Self, get_original_bases # TODO(kszucs): would be nice to cimport Signature and Builder -from .builders import Builder, Deferred, Variable, builder +from .builders import Builder, Deferred, Var, builder from .utils import ( RewindableIterator, + frozendict, get_type_args, get_type_boundvars, get_type_origin, @@ -36,9 +37,14 @@ class CoercionError(Exception): Context = dict[str, Any] +@cython.cclass +class MatchError(Exception): + pass + + @cython.final @cython.cclass -class NoMatchError(Exception): +class NoMatchError(MatchError): pass @@ -265,7 +271,7 @@ def __rmatmul__(self, name) -> Capture: """ return Capture(name, self) - def __rshift__(self, other: Deferred) -> Replace: + def __rshift__(self, other) -> Replace: """Syntax sugar for replacing a value. Parameters @@ -700,6 +706,8 @@ def equals(self, other: AsType) -> bool: @cython.cfunc def match(self, value, ctx: Context): + if isinstance(value, self.type_): + return value try: return self.type_(value) except ValueError: @@ -904,8 +912,10 @@ class Option(Pattern): default: Any def __init__(self, pat, default=None): - self.pattern = pattern(pat) self.default = default + self.pattern = pattern(pat) + if isinstance(self.pattern, Option): + self.pattern = cython.cast(Option, self.pattern).pattern def __repr__(self) -> str: return f"Option({self.pattern!r}, default={self.default!r})" @@ -916,10 +926,7 @@ def equals(self, other: Option) -> bool: @cython.cfunc def match(self, value, ctx: Context): if value is None: - if self.default is None: - return None - else: - return self.default + return self.default else: return self.pattern.match(value, ctx) @@ -1074,19 +1081,14 @@ def match(self, values, ctx: Context): if isinstance(values, (str, bytes)): raise NoMatchError() - # optimization to avoid unnecessary iteration - if isinstance(self.item, Anything): - return values - + # could initialize the result list with the length of values + result: list = [] try: - it = iter(values) + for item in values: + result.append(self.item.match(item, ctx)) except TypeError: raise NoMatchError() - result: list = [] - for item in it: - result.append(self.item.match(item, ctx)) - return self.type_.match(result, ctx) @@ -1122,12 +1124,15 @@ class MappingOf(Pattern): def __init__(self, key: Pattern, value: Pattern, type_=dict): self.key = pattern(key) self.value = pattern(value) - if isinstance(type_, type): - self.type_ = AsType(type_) - elif hasattr(type_, "__coerce__"): + if hasattr(type_, "__coerce__"): self.type_ = CoercedTo(type_) else: - raise TypeError(f"Cannot coerce to container type {type_}") + try: + type_({}) + except TypeError: + self.type_ = AsType(dict) + else: + self.type_ = AsType(type_) def __repr__(self) -> str: return f"MappingOf({self.key!r}, {self.value!r}, {self.type_!r})" @@ -1157,6 +1162,10 @@ def DictOf(key, value) -> Pattern: return MappingOf(key, value, dict) +def FrozenDictOf(key, value) -> Pattern: + return MappingOf(key, value, frozendict) + + @cython.final @cython.cclass class Custom(Pattern): @@ -1202,10 +1211,10 @@ class Capture(Pattern): key: str what: Pattern - def __init__(self, key: str | Deferred | Builder, what=_any): + def __init__(self, key: Any, what=_any): if isinstance(key, (Deferred, Builder)): key = builder(key) - if isinstance(key, Variable): + if isinstance(key, Var): key = key.name else: raise TypeError("Only variables can be used as capture keys") @@ -1336,8 +1345,8 @@ def match(self, value, ctx: Context): class ObjectOf2(Pattern): type_: Any field1: str - pattern1: Pattern field2: str + pattern1: Pattern pattern2: Pattern def __init__(self, type_: Any, **kwargs): @@ -1354,8 +1363,8 @@ def equals(self, other: ObjectOf2) -> bool: return ( self.type_ == other.type_ and self.field1 == other.field1 - and self.pattern1 == other.pattern1 and self.field2 == other.field2 + and self.pattern1 == other.pattern1 and self.pattern2 == other.pattern2 ) @@ -1405,10 +1414,10 @@ def equals(self, other: ObjectOf3) -> bool: return ( self.type_ == other.type_ and self.field1 == other.field1 - and self.pattern1 == other.pattern1 and self.field2 == other.field2 - and self.pattern2 == other.pattern2 and self.field3 == other.field3 + and self.pattern1 == other.pattern1 + and self.pattern2 == other.pattern2 and self.pattern3 == other.pattern3 ) @@ -1561,21 +1570,27 @@ def match(self, value, ctx: Context): if not callable(value): raise NoMatchError() - sig = Signature.from_callable(value) + sig = inspect.signature(value) has_varargs: bool = False positional: list = [] required_positional: list = [] for p in sig.parameters.values(): - if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): + if p.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): positional.append(p) - if p.default is Parameter.empty: + if p.default is inspect.Parameter.empty: required_positional.append(p) - elif p.kind is Parameter.KEYWORD_ONLY and p.default is Parameter.empty: + elif ( + p.kind is inspect.Parameter.KEYWORD_ONLY + and p.default is inspect.Parameter.empty + ): raise TypeError( "Callable has mandatory keyword-only arguments which cannot be specified" ) - elif p.kind is Parameter.VAR_POSITIONAL: + elif p.kind is inspect.Parameter.VAR_POSITIONAL: has_varargs = True if len(required_positional) > len(self.args): @@ -1868,6 +1883,8 @@ def PatternMap(fields) -> Pattern: return PatternMap1(fields) elif len(fields) == 2: return PatternMap2(fields) + elif len(fields) == 3: + return PatternMap3(fields) else: return PatternMapN(fields) diff --git a/koerce/sugar.py b/koerce/sugar.py deleted file mode 100644 index 7292297..0000000 --- a/koerce/sugar.py +++ /dev/null @@ -1,97 +0,0 @@ -from __future__ import annotations - -import sys -from typing import Any - -from .builders import Deferred, Variable -from .patterns import ( - Capture, - Context, - Eq, - If, - NoMatch, # noqa: F401 - Pattern, - pattern, -) - - -class Namespace: - """Convenience class for creating patterns for various types from a module. - - Useful to reduce boilerplate when creating patterns for various types from - a module. - - Parameters - ---------- - factory - The pattern to construct with the looked up types. - module - The module object or name to look up the types. - - """ - - __slots__ = ("_factory", "_module") - # _factory: Callable - # _module: ModuleType - - def __init__(self, factory, module): - if isinstance(module, str): - module = sys.modules[module] - self._module = module - self._factory = factory - - def __getattr__(self, name: str): - obj = getattr(self._module, name) - return self._factory(obj) - - -class Var(Deferred): - def __init__(self, name: str): - builder = Variable(name) - super().__init__(builder) - - def __invert__(self): - return Capture(self) - - -var = Var - - -def match(pat: Pattern, value: Any, context: Context = None) -> Any: - """Match a value against a pattern. - - Parameters - ---------- - pat - The pattern to match against. - value - The value to match. - context - Arbitrary mapping of values to be used while matching. - - Returns - ------- - The matched value if the pattern matches, otherwise :obj:`NoMatch`. - - Examples - -------- - >>> assert match(Any(), 1) == 1 - >>> assert match(1, 1) == 1 - >>> assert match(1, 2) is NoMatch - >>> assert match(1, 1, context={"x": 1}) == 1 - >>> assert match(1, 2, context={"x": 1}) is NoMatch - >>> assert match([1, int], [1, 2]) == [1, 2] - >>> assert match([1, int, "a" @ InstanceOf(str)], [1, 2, "three"]) == [ - ... 1, - ... 2, - ... "three", - ... ] - - """ - pat = pattern(pat) - return pat.apply(value, context) - - -if_ = If -eq = Eq -_ = var("_") diff --git a/koerce/tests/test_annots.py b/koerce/tests/test_annots.py index ab3e34f..0f020fb 100644 --- a/koerce/tests/test_annots.py +++ b/koerce/tests/test_annots.py @@ -1,87 +1,122 @@ from __future__ import annotations +import copy +import pickle +import weakref +from abc import ABCMeta, abstractmethod from dataclasses import dataclass from typing import ( Annotated, + Callable, + Generic, + Mapping, Optional, + Sequence, + TypeVar, Union, ) import pytest +from typing_extensions import Self -from koerce.annots import ( +from koerce._internal import ( EMPTY, - Parameter, - Signature, - ValidationError, - annotated, -) -from koerce.patterns import ( + Annotable, + AnnotableMeta, Anything, + AsType, + FrozenDictOf, + Hashable, + Immutable, InstanceOf, MappingOf, + MatchError, NoMatchError, Option, - PatternMap, + Parameter, + Pattern, + Signature, TupleOf, + annotated, + argument, + attribute, + optional, pattern, + varargs, + varkwargs, ) def test_parameter(): - p = Parameter("x", Parameter.POSITIONAL_OR_KEYWORD, typehint=int) - assert p.name == "x" + p = Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=int) assert p.kind is Parameter.POSITIONAL_OR_KEYWORD - assert str(p) == "x: int" + assert p.format("x") == "x: int" - p = Parameter("x", Parameter.POSITIONAL_OR_KEYWORD, default=1) - assert p.name == "x" + p = Parameter(Parameter.POSITIONAL_OR_KEYWORD, default=1) assert p.kind is Parameter.POSITIONAL_OR_KEYWORD assert p.default_ == 1 - assert str(p) == "x=1" + assert p.format("x") == "x=1" + assert p.pattern == Anything() - p = Parameter("x", Parameter.POSITIONAL_OR_KEYWORD, typehint=int, default=1) - assert p.name == "x" + p = Parameter( + Parameter.POSITIONAL_OR_KEYWORD, typehint=int, default=1, pattern=is_int + ) assert p.kind is Parameter.POSITIONAL_OR_KEYWORD assert p.default_ == 1 assert p.typehint is int - assert str(p) == "x: int = 1" + assert p.format("x") == "x: int = 1" + assert p.pattern == is_int - p = Parameter("y", Parameter.VAR_POSITIONAL, typehint=int) - assert p.name == "y" + p = Parameter(Parameter.VAR_POSITIONAL, typehint=int, pattern=is_int) assert p.kind is Parameter.VAR_POSITIONAL assert p.typehint is int - assert str(p) == "*y: int" + assert p.format("y") == "*y: int" + assert p.pattern == TupleOf(is_int) - p = Parameter("z", Parameter.VAR_KEYWORD, typehint=int) - assert p.name == "z" + p = Parameter(Parameter.VAR_KEYWORD, typehint=int, pattern=is_int) assert p.kind is Parameter.VAR_KEYWORD assert p.typehint is int - assert str(p) == "**z: int" + assert p.format("z") == "**z: int" + assert p.pattern == FrozenDictOf(Anything(), is_int) def test_signature_contruction(): - a = Parameter("a", Parameter.POSITIONAL_OR_KEYWORD, typehint=int) - b = Parameter("b", Parameter.POSITIONAL_OR_KEYWORD, typehint=str) - c = Parameter("c", Parameter.POSITIONAL_OR_KEYWORD, typehint=int, default=1) - d = Parameter("d", Parameter.VAR_POSITIONAL, typehint=int) + a = Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=int) + b = Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=str) + c = Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=int, default=1) + d = Parameter(Parameter.VAR_POSITIONAL, typehint=int, pattern=is_int) - sig = Signature([a, b, c, d]) - assert sig.parameters == [a, b, c, d] + sig = Signature({"a": a, "b": b, "c": c, "d": d}) + assert sig.parameters == {"a": a, "b": b, "c": c, "d": d} assert sig.return_typehint is EMPTY + assert sig.return_pattern == Anything() + + +def test_signature_equality_comparison(): + # order of parameters matters + a = Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=int) + b = Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=str) + c = Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=int, default=1) + + sig1 = Signature({"a": a, "b": b, "c": c}) + sig2 = Signature({"a": a, "b": b, "c": c}) + assert sig1 == sig2 + + sig3 = Signature({"a": a, "c": c, "b": b}) + assert sig1 != sig3 def test_signature_from_callable(): def func(a: int, b: str, *args, c=1, **kwargs) -> float: ... sig = Signature.from_callable(func) - assert sig.parameters == [ - Parameter("a", Parameter.POSITIONAL_OR_KEYWORD, typehint=int), - Parameter("b", Parameter.POSITIONAL_OR_KEYWORD, typehint=str), - Parameter("args", Parameter.VAR_POSITIONAL), - Parameter("c", Parameter.KEYWORD_ONLY, default=1), - Parameter("kwargs", Parameter.VAR_KEYWORD), - ] + assert sig.parameters == { + "a": Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=int), + "b": Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=str), + "args": Parameter(Parameter.VAR_POSITIONAL), + "c": Parameter(Parameter.KEYWORD_ONLY, default=1), + "kwargs": Parameter(Parameter.VAR_KEYWORD), + } assert sig.return_typehint is float @@ -537,24 +572,26 @@ def f3(d, a=1, **kwargs): ... sig2 = Signature.from_callable(f2) sig3 = Signature.from_callable(f3) - sig12 = Signature.merge(sig1, sig2) - assert sig12.parameters == [ - Parameter("a", Parameter.POSITIONAL_OR_KEYWORD), - Parameter("b", Parameter.POSITIONAL_OR_KEYWORD), - Parameter("d", Parameter.POSITIONAL_OR_KEYWORD), - Parameter("e", Parameter.POSITIONAL_OR_KEYWORD), - Parameter("c", Parameter.POSITIONAL_OR_KEYWORD, default=1), - Parameter("f", Parameter.POSITIONAL_OR_KEYWORD, default=2), - ] - - sig13 = Signature.merge(sig1, sig3) - assert sig13.parameters == [ - Parameter("b", Parameter.POSITIONAL_OR_KEYWORD), - Parameter("d", Parameter.POSITIONAL_OR_KEYWORD), - Parameter("a", Parameter.POSITIONAL_OR_KEYWORD, default=1), - Parameter("c", Parameter.POSITIONAL_OR_KEYWORD, default=1), - Parameter("kwargs", Parameter.VAR_KEYWORD), - ] + sig12 = Signature.merge([sig1, sig2]) + assert sig12.parameters == { + "a": Parameter(Parameter.POSITIONAL_OR_KEYWORD), + "b": Parameter(Parameter.POSITIONAL_OR_KEYWORD), + "d": Parameter(Parameter.POSITIONAL_OR_KEYWORD), + "e": Parameter(Parameter.POSITIONAL_OR_KEYWORD), + "c": Parameter(Parameter.POSITIONAL_OR_KEYWORD, default=1), + "f": Parameter(Parameter.POSITIONAL_OR_KEYWORD, default=2), + } + assert tuple(sig12.parameters.keys()) == ("a", "b", "d", "e", "c", "f") + + sig13 = Signature.merge([sig1, sig3]) + assert sig13.parameters == { + "b": Parameter(Parameter.POSITIONAL_OR_KEYWORD), + "d": Parameter(Parameter.POSITIONAL_OR_KEYWORD), + "a": Parameter(Parameter.POSITIONAL_OR_KEYWORD, default=1), + "c": Parameter(Parameter.POSITIONAL_OR_KEYWORD, default=1), + "kwargs": Parameter(Parameter.VAR_KEYWORD), + } + assert tuple(sig13.parameters.keys()) == ("b", "d", "a", "c", "kwargs") def test_annotated_function(): @@ -567,7 +604,7 @@ def test(a, b, c=1): assert test(2, 3, c=4) == 9 assert test(a=2, b=3, c=4) == 9 - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): test(2, 3, c="4") @annotated(a=InstanceOf(int)) @@ -607,7 +644,7 @@ def test_wrong(a: int, b: int, c: int = 1) -> int: return "invalid result" assert test_ok(2, 3) == 6 - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): test_wrong(2, 3) @@ -616,27 +653,44 @@ def test_annotated_function_with_keyword_overrides(): def test(a: int, b: int, c: int = 1): return a + b + c - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): test(2, 3) assert test(2, 3.0) == 6.0 def test_annotated_function_with_list_overrides(): - @annotated([InstanceOf(int), InstanceOf(int), InstanceOf(str)]) - def test(a: int, b: int, c: int = 1): + with pytest.raises(NoMatchError): + + @annotated([InstanceOf(int), InstanceOf(int), InstanceOf(str)]) + def test(a: int, b: int, c: int = 1): + return a + b + c + + @annotated([InstanceOf(int), InstanceOf(int), InstanceOf(float)]) + def test(a: int, b: int, c: int = 1.0): return a + b + c - with pytest.raises(ValidationError): + assert test(2, 3) == 6.0 + assert isinstance(test(2, 3), float) + with pytest.raises(NoMatchError): test(2, 3, 4) def test_annotated_function_with_list_overrides_and_return_override(): + with pytest.raises(NoMatchError): + + @annotated( + [InstanceOf(int), InstanceOf(int), InstanceOf(float)], InstanceOf(float) + ) + def test(a: int, b: int, c: int = 1): + return a + b + c + @annotated([InstanceOf(int), InstanceOf(int), InstanceOf(float)], InstanceOf(float)) - def test(a: int, b: int, c: int = 1): + def test(a: int, b: int, c: int = 1.1): return a + b + c - with pytest.raises(ValidationError): + assert test(2, 3) == 6.1 + with pytest.raises(NoMatchError): test(2, 3, 4) assert test(2, 3, 4.0) == 9.0 @@ -666,11 +720,11 @@ def test(a: Annotated[str, short_str, endswith_d], b: Union[int, float]): assert test("abcd", 1) == ("abcd", 1) assert test("---d", 1.0) == ("---d", 1.0) - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): test("---c", 1) - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): test("123", 1) - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): test("abcd", "qweqwe") @@ -680,7 +734,7 @@ def test(a, b, c): return a, b, c assert test(1, 2, 3) == (1, 2, 3) - assert [p.name for p in test.__signature__.parameters] == ["a", "b", "c"] + assert list(test.__signature__.parameters.keys()) == ["a", "b", "c"] def test_annotated_function_without_decoration(): @@ -702,7 +756,7 @@ def test(a: float, b: float, *args: int): assert test(1.0, 2.0, 3, 4) == 10.0 assert test(1.0, 2.0, 3, 4, 5) == 15.0 - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): test(1.0, 2.0, 3, 4, 5, 6.0) @@ -714,7 +768,7 @@ def test(a: float, b: float, **kwargs: int): assert test(1.0, 2.0, c=3, d=4) == 10.0 assert test(1.0, 2.0, c=3, d=4, e=5) == 15.0 - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): test(1.0, 2.0, c=3, d=4, e=5, f=6.0) @@ -728,42 +782,39 @@ def test(a: float, b: float, **kwargs: int): # assert len(excinfo.value.errors) == 2 -def test_signature_to_pattern(): +def test_signature_patterns(): def func(a: int, b: str) -> str: ... - args, ret = Signature.from_callable(func).to_pattern() - assert args == PatternMap({"a": InstanceOf(int), "b": InstanceOf(str)}) - assert ret == InstanceOf(str) + sig = Signature.from_callable(func) + assert sig.parameters["a"].pattern == InstanceOf(int) + assert sig.parameters["b"].pattern == InstanceOf(str) + assert sig.return_pattern == InstanceOf(str) def func(a: int, b: str, c: str = "0") -> str: ... - args, ret = Signature.from_callable(func).to_pattern() - assert args == PatternMap( - {"a": InstanceOf(int), "b": InstanceOf(str), "c": Option(InstanceOf(str), "0")} - ) - assert ret == InstanceOf(str) + sig = Signature.from_callable(func) + assert sig.parameters["a"].pattern == InstanceOf(int) + assert sig.parameters["b"].pattern == InstanceOf(str) + assert sig.parameters["c"].pattern == InstanceOf(str) + assert sig.return_pattern == InstanceOf(str) def func(a: int, b: str, *args): ... - args, ret = Signature.from_callable(func).to_pattern() - assert args == PatternMap( - {"a": InstanceOf(int), "b": InstanceOf(str), "args": TupleOf(Anything())} - ) - assert ret == Anything() + sig = Signature.from_callable(func) + assert sig.parameters["a"].pattern == InstanceOf(int) + assert sig.parameters["b"].pattern == InstanceOf(str) + assert sig.parameters["args"].pattern == TupleOf(Anything()) + assert sig.return_pattern == Anything() def func(a: int, b: str, c: str = "0", *args, **kwargs: int) -> float: ... - args, ret = Signature.from_callable(func).to_pattern() - assert args == PatternMap( - { - "a": InstanceOf(int), - "b": InstanceOf(str), - "c": Option(InstanceOf(str), "0"), - "args": TupleOf(Anything()), - "kwargs": MappingOf(Anything(), InstanceOf(int)), - } - ) - assert ret == InstanceOf(float) + sig = Signature.from_callable(func) + assert sig.parameters["a"].pattern == InstanceOf(int) + assert sig.parameters["b"].pattern == InstanceOf(str) + assert sig.parameters["c"].pattern == InstanceOf(str) + assert sig.parameters["args"].pattern == TupleOf(Anything()) + assert sig.parameters["kwargs"].pattern == FrozenDictOf(Anything(), InstanceOf(int)) + assert sig.return_pattern == InstanceOf(float) def test_annotated_with_class(): @@ -780,7 +831,7 @@ def __init__(self, a, b, c, d=1): self.c = c self.d = d - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): A(1, "2", "d") @@ -802,8 +853,1176 @@ class InventoryItem: assert item.unit_price == 3.0 assert item.quantity_on_hand == 0 - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): InventoryItem("widget", "3.0", 10) - with pytest.raises(ValidationError): + with pytest.raises(NoMatchError): InventoryItem("widget", 3.0, "10") + + +################################################## + + +is_any = InstanceOf(object) +is_bool = InstanceOf(bool) +is_float = InstanceOf(float) +is_int = InstanceOf(int) +is_str = InstanceOf(str) +is_list = InstanceOf(list) + + +class Op(Annotable): + pass + + +class Value(Op): + arg = argument(InstanceOf(object)) + + +class StringOp(Value): + arg = argument(InstanceOf(str)) + + +class BetweenSimple(Annotable): + value = argument(is_int) + lower = optional(is_int, default=0) + upper = optional(is_int, default=None) + + +class BetweenWithExtra(Annotable): + extra = attribute(is_int) + value = argument(is_int) + lower = optional(is_int, default=0) + upper = optional(is_int, default=None) + + +class BetweenWithCalculated(Annotable, immutable=True, hashable=True): + value = argument(is_int) + lower = optional(is_int, default=0) + upper = optional(is_int, default=None) + + @attribute + def calculated(self): + return self.value + self.lower + + +class VariadicArgs(Annotable, immutable=True, hashable=True): + args = varargs(is_int) + + +class VariadicKeywords(Annotable, immutable=True, hashable=True): + kwargs = varkwargs(is_int) + + +class VariadicArgsAndKeywords(Annotable, immutable=True, hashable=True): + args = varargs(is_int) + kwargs = varkwargs(is_int) + + +T = TypeVar("T", covariant=True) +K = TypeVar("K", covariant=True) +V = TypeVar("V", covariant=True) + + +class List(Annotable, Generic[T]): + @classmethod + def __coerce__(self, values, T=None): + values = tuple(values) + if values: + head, *rest = values + return ConsList(head, rest) + else: + return EmptyList() + + def __eq__(self, other) -> bool: + if not isinstance(other, List): + return NotImplemented + if len(self) != len(other): + return False + for a, b in zip(self, other): + if a != b: + return False + return True + + +# AnnotableMeta doesn't extend ABCMeta, so we need to register the class +# this is due to performance reasons since ABCMeta overrides +# __instancecheck__ and __subclasscheck__ which makes +# issubclass and isinstance way slower +Sequence.register(List) + + +class EmptyList(List[T]): + def __getitem__(self, key): + raise IndexError(key) + + def __len__(self): + return 0 + + +class ConsList(List[T]): + head: T + rest: List[T] + + def __getitem__(self, key): + if key == 0: + return self.head + else: + return self.rest[key - 1] + + def __len__(self): + return len(self.rest) + 1 + + +class Map(Annotable, Generic[K, V]): + @classmethod + def __coerce__(self, pairs, K=None, V=None): + pairs = dict(pairs) + if pairs: + head_key = next(iter(pairs)) + head_value = pairs.pop(head_key) + rest = pairs + return ConsMap((head_key, head_value), rest) + else: + return EmptyMap() + + def __eq__(self, other) -> bool: + if not isinstance(other, Map): + return NotImplemented + if len(self) != len(other): + return False + for key in self: + if self[key] != other[key]: + return False + return True + + def items(self): + for key in self: + yield key, self[key] + + +# AnnotableMeta doesn't extend ABCMeta, so we need to register the class +# this is due to performance reasons since ABCMeta overrides +# __instancecheck__ and __subclasscheck__ which makes +# issubclass and isinstance way slower +Mapping.register(Map) + + +class EmptyMap(Map[K, V]): + def __getitem__(self, key): + raise KeyError(key) + + def __iter__(self): + return iter(()) + + def __len__(self): + return 0 + + +class ConsMap(Map[K, V]): + head: tuple[K, V] + rest: Map[K, V] + + def __getitem__(self, key): + if key == self.head[0]: + return self.head[1] + else: + return self.rest[key] + + def __iter__(self): + yield self.head[0] + yield from self.rest + + def __len__(self): + return len(self.rest) + 1 + + +class Integer(int): + @classmethod + def __coerce__(cls, value): + return Integer(value) + + +class Float(float): + @classmethod + def __coerce__(cls, value): + return Float(value) + + +class MyExpr(Annotable): + a: Integer + b: List[Float] + c: Map[str, Integer] + + +class MyInt(int): + @classmethod + def __coerce__(cls, value): + return cls(value) + + +class MyFloat(float): + @classmethod + def __coerce__(cls, value): + return cls(value) + + +J = TypeVar("J", bound=MyInt, covariant=True) +F = TypeVar("F", bound=MyFloat, covariant=True) +N = TypeVar("N", bound=Union[MyInt, MyFloat], covariant=True) + + +class MyValue(Annotable, Generic[J, F]): + integer: J + floating: F + numeric: N + + +def test_annotable(): + class Between(BetweenSimple): + pass + + assert not issubclass(type(Between), ABCMeta) + assert type(Between) is AnnotableMeta + + argnames = ("value", "lower", "upper") + signature = BetweenSimple.__signature__ + assert isinstance(signature, Signature) + paramnames = tuple(signature.parameters.keys()) + + assert BetweenSimple.__slots__ == argnames + + obj = BetweenSimple(10, lower=2) + assert obj.value == 10 + assert obj.lower == 2 + assert obj.upper is None + assert obj.__argnames__ == argnames + assert obj.__slots__ == ("value", "lower", "upper") + assert not hasattr(obj, "__dict__") + assert obj.__module__ == __name__ + assert type(obj).__qualname__ == "BetweenSimple" + + # test that a child without additional arguments doesn't have __dict__ + obj = Between(10, lower=2) + assert obj.__slots__ == tuple() + assert not hasattr(obj, "__dict__") + assert obj.__module__ == __name__ + assert type(obj).__qualname__ == "test_annotable..Between" + + copied = copy.copy(obj) + assert obj == copied + assert obj is not copied + + # copied = obj.copy() + # assert obj == copied + # assert obj is not copied + + # obj2 = Between(10, lower=8) + # assert obj.copy(lower=8) == obj2 + + +def test_annotable_with_bound_typevars_properly_coerce_values(): + v = MyValue(1.1, 2.2, 3.3) + assert isinstance(v.integer, MyInt) + assert v.integer == 1 + assert isinstance(v.floating, MyFloat) + assert v.floating == 2.2 + assert isinstance(v.numeric, MyInt) + assert v.numeric == 3 + + +def test_annotable_picklable_with_additional_attributes(): + a = BetweenWithExtra(10, lower=2) + b = BetweenWithExtra(10, lower=2) + assert a == b + assert a is not b + + a.extra = 1 + assert a.extra == 1 + assert a != b + + assert a == pickle.loads(pickle.dumps(a)) + + +def test_annotable_is_mutable_by_default(): + # TODO(kszucs): more exhaustive testing of mutability, e.g. setting + # optional value to None doesn't set to the default value + class Op(Annotable): + __slots__ = ("custom",) + + a = argument(is_int) + b = argument(is_int) + + p = Op(1, 2) + assert p.a == 1 + p.a = 3 + assert p.a == 3 + assert p == Op(a=3, b=2) + + # test that non-annotable attributes can be set as well + p.custom = 1 + assert p.custom == 1 + + +def test_annotable_with_type_annotations() -> None: + class Op1(Annotable): + foo: int + bar: str = "" + + p = Op1(1) + assert p.foo == 1 + assert p.bar == "" + + with pytest.raises(MatchError): + + class Op2(Annotable): + bar: str = None + + class Op2(Annotable): + bar: str | None = None + + op = Op2() + assert op.bar is None + + +class RecursiveNode(Annotable): + child: Optional[Self] = None + + +def test_annotable_with_self_typehint(): + node = RecursiveNode(RecursiveNode(RecursiveNode(None))) + assert isinstance(node, RecursiveNode) + assert isinstance(node.child, RecursiveNode) + assert isinstance(node.child.child, RecursiveNode) + assert node.child.child.child is None + + with pytest.raises(NoMatchError): + RecursiveNode(1) + + +def test_annotable_with_recursive_generic_type_annotations(): + # testing cons list + pattern = Pattern.from_typehint(List[Integer]) + values = ["1", 2.0, 3] + result = pattern.apply(values, {}) + expected = ConsList(1, ConsList(2, ConsList(3, EmptyList()))) + assert result == expected + assert result[0] == 1 + assert result[1] == 2 + assert result[2] == 3 + assert len(result) == 3 + with pytest.raises(IndexError): + result[3] + + # testing cons map + pattern = Pattern.from_typehint(Map[Integer, Float]) + values = {"1": 2, 3: "4.0", 5: 6.0} + result = pattern.apply(values, {}) + expected = ConsMap((1, 2.0), ConsMap((3, 4.0), ConsMap((5, 6.0), EmptyMap()))) + assert result == expected + assert result[1] == 2.0 + assert result[3] == 4.0 + assert result[5] == 6.0 + assert len(result) == 3 + with pytest.raises(KeyError): + result[7] + + # testing both encapsulated in a class + expr = MyExpr(a="1", b=["2.0", 3, True], c={"a": "1", "b": 2, "c": 3.0}) + assert expr.a == 1 + assert expr.b == ConsList(2.0, ConsList(3.0, ConsList(1.0, EmptyList()))) + assert expr.c == ConsMap(("a", 1), ConsMap(("b", 2), ConsMap(("c", 3), EmptyMap()))) + + +def test_annotable_as_immutable(): + class AnnImm(Annotable, immutable=True): + value = argument(is_int) + lower = optional(is_int, default=0) + upper = optional(is_int, default=None) + + assert AnnImm.__mro__ == (AnnImm, Immutable, Annotable, object) + + obj = AnnImm(3, lower=0, upper=4) + with pytest.raises(AttributeError): + obj.value = 1 + + +def test_annotable_equality_checks(): + class Between(Annotable): + value = argument(is_int) + lower = optional(is_int, default=0) + upper = optional(is_int, default=None) + + a = Between(3, lower=0, upper=4) + b = Between(3, lower=0, upper=4) + c = Between(2, lower=0, upper=4) + + assert a == b + assert b == a + assert a != c + assert c != a + assert a.__eq__(b) + assert not a.__eq__(c) + + +def test_maintain_definition_order(): + class Between(Annotable): + value = argument(is_int) + lower = optional(is_int, default=0) + upper = optional(is_int, default=None) + + assert Between.__argnames__ == ("value", "lower", "upper") + + +def test_signature_inheritance(): + class IntBinop(Annotable): + left = argument(is_int) + right = argument(is_int) + + class FloatAddRhs(IntBinop): + right = argument(is_float) + + class FloatAddClip(FloatAddRhs): + left = argument(is_float) + clip_lower = optional(is_int, default=0) + clip_upper = optional(is_int, default=10) + + class IntAddClip(FloatAddClip, IntBinop): + pass + + assert IntBinop.__signature__ == Signature( + { + "left": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_int), + "right": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_int), + } + ) + + assert FloatAddRhs.__signature__ == Signature( + { + "left": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_int), + "right": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_float), + } + ) + + assert FloatAddClip.__signature__ == Signature( + { + "left": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_float), + "right": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_float), + "clip_lower": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_int, default=0), + "clip_upper": Parameter( + Parameter.POSITIONAL_OR_KEYWORD, is_int, default=10 + ), + } + ) + + assert IntAddClip.__signature__ == Signature( + { + "left": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_float), + "right": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_float), + "clip_lower": Parameter(Parameter.POSITIONAL_OR_KEYWORD, is_int, default=0), + "clip_upper": Parameter( + Parameter.POSITIONAL_OR_KEYWORD, is_int, default=10 + ), + } + ) + + +def test_positional_argument_reordering(): + class Farm(Annotable): + ducks = argument(is_int) + donkeys = argument(is_int) + horses = argument(is_int) + goats = argument(is_int) + chickens = argument(is_int) + + class NoHooves(Farm): + horses = optional(is_int, default=0) + goats = optional(is_int, default=0) + donkeys = optional(is_int, default=0) + + f1 = Farm(1, 2, 3, 4, 5) + f2 = Farm(1, 2, goats=4, chickens=5, horses=3) + f3 = Farm(1, 0, 0, 0, 100) + assert f1 == f2 + assert f1 != f3 + + g1 = NoHooves(1, 2, donkeys=-1) + assert g1.ducks == 1 + assert g1.chickens == 2 + assert g1.donkeys == -1 + assert g1.horses == 0 + assert g1.goats == 0 + + +def test_keyword_argument_reordering(): + class Alpha(Annotable): + a = argument(is_int) + b = argument(is_int) + + class Beta(Alpha): + c = argument(is_int) + d = optional(is_int, default=0) + e = argument(is_int) + + obj = Beta(1, 2, 3, 4) + assert obj.a == 1 + assert obj.b == 2 + assert obj.c == 3 + assert obj.e == 4 + assert obj.d == 0 + + obj = Beta(1, 2, 3, 4, 5) + assert obj.d == 5 + assert obj.e == 4 + + +def test_variadic_argument_reordering(): + class Test(Annotable): + a = argument(is_int) + b = argument(is_int) + args = varargs(is_int) + + class Test2(Test): + c = argument(is_int) + args = varargs(is_int) + + with pytest.raises(TypeError, match="missing a required argument: 'c'"): + Test2(1, 2) + + a = Test2(1, 2, 3) + assert a.a == 1 + assert a.b == 2 + assert a.c == 3 + assert a.args == () + + b = Test2(*range(5)) + assert b.a == 0 + assert b.b == 1 + assert b.c == 2 + assert b.args == (3, 4) + + msg = "only one variadic \\*args parameter is allowed" + with pytest.raises(TypeError, match=msg): + + class Test3(Test): + another_args = varargs(is_int) + + +def test_variadic_keyword_argument_reordering(): + class Test(Annotable): + a = argument(is_int) + b = argument(is_int) + options = varkwargs(is_int) + + class Test2(Test): + c = argument(is_int) + options = varkwargs(is_int) + + with pytest.raises(TypeError, match="missing a required argument: 'c'"): + Test2(1, 2) + + a = Test2(1, 2, c=3) + assert a.a == 1 + assert a.b == 2 + assert a.c == 3 + assert a.options == {} + + b = Test2(1, 2, c=3, d=4, e=5) + assert b.a == 1 + assert b.b == 2 + assert b.c == 3 + assert b.options == {"d": 4, "e": 5} + + msg = "only one variadic \\*\\*kwargs parameter is allowed" + with pytest.raises(TypeError, match=msg): + + class Test3(Test): + another_options = varkwargs(is_int) + + +def test_variadic_argument(): + class Test(Annotable): + a = argument(is_int) + b = argument(is_int) + args = varargs(is_int) + + assert Test(1, 2).args == () + assert Test(1, 2, 3).args == (3,) + assert Test(1, 2, 3, 4, 5).args == (3, 4, 5) + + +def test_variadic_keyword_argument(): + class Test(Annotable): + first = argument(is_int) + second = argument(is_int) + options = varkwargs(is_int) + + assert Test(1, 2).options == {} + assert Test(1, 2, a=3).options == {"a": 3} + assert Test(1, 2, a=3, b=4, c=5).options == {"a": 3, "b": 4, "c": 5} + + +# def test_copy_with_variadic_argument(): +# class Foo(Annotable): +# a = is_int +# b = is_int +# args = varargs(is_int) + +# class Bar(Concrete): +# a = is_int +# b = is_int +# args = varargs(is_int) + +# for t in [Foo(1, 2, 3, 4, 5), Bar(1, 2, 3, 4, 5)]: +# assert t.a == 1 +# assert t.b == 2 +# assert t.args == (3, 4, 5) + +# u = t.copy(a=6, args=(8, 9, 10)) +# assert u.a == 6 +# assert u.b == 2 +# assert u.args == (8, 9, 10) + + +# def test_concrete_copy_with_unknown_argument_raise(): +# class Bar(Concrete): +# a = is_int +# b = is_int + +# t = Bar(1, 2) +# assert t.a == 1 +# assert t.b == 2 + +# with pytest.raises(AttributeError, match="Unexpected arguments"): +# t.copy(c=3, d=4) + + +def test_immutable_pickling_variadic_arguments(): + v = VariadicArgs(1, 2, 3, 4, 5) + assert v.args == (1, 2, 3, 4, 5) + assert v == pickle.loads(pickle.dumps(v)) + + v = VariadicKeywords(a=3, b=4, c=5) + assert v.kwargs == {"a": 3, "b": 4, "c": 5} + assert v == pickle.loads(pickle.dumps(v)) + + v = VariadicArgsAndKeywords(1, 2, 3, 4, 5, a=3, b=4, c=5) + assert v.args == (1, 2, 3, 4, 5) + assert v.kwargs == {"a": 3, "b": 4, "c": 5} + assert v == pickle.loads(pickle.dumps(v)) + + +def test_dont_copy_default_argument(): + default = tuple() + + class Op(Annotable): + arg = optional(InstanceOf(tuple), default=default) + + op = Op() + assert op.arg is default + + +# def test_copy_mutable_with_default_attribute(): +# class Test(Annotable): +# a = attribute(InstanceOf(dict), default={}) +# b = argument(InstanceOf(str)) # required argument + +# @attribute +# def c(self): +# return self.b.upper() + +# t = Test("t") +# assert t.a == {} +# assert t.b == "t" +# assert t.c == "T" + +# with pytest.raises(ValidationError): +# t.a = 1 +# t.a = {"map": "ping"} +# assert t.a == {"map": "ping"} + +# assert t.copy() == t + +# u = t.copy(b="u") +# assert u.b == "u" +# assert u.c == "T" +# assert u.a == {"map": "ping"} + +# x = t.copy(a={"emp": "ty"}) +# assert x.a == {"emp": "ty"} +# assert x.b == "t" + + +def test_slots_are_inherited_and_overridable(): + class Op(Annotable): + __slots__ = ("_cache",) # first definition + arg = argument(Anything()) + + class StringOp(Op): + arg = argument(AsType(str)) # new overridden slot + + class StringSplit(StringOp): + sep = argument(AsType(str)) # new slot + + class StringJoin(StringOp): + __slots__ = ("_memoize",) # new slot + sep = argument(AsType(str)) # new overridden slot + + assert Op.__slots__ == ("_cache", "arg") + assert StringOp.__slots__ == ("arg",) + assert StringSplit.__slots__ == ("sep",) + assert StringJoin.__slots__ == ("_memoize", "sep") + + +def test_multiple_inheritance(): + # multiple inheritance is allowed only if one of the parents has non-empty + # __slots__ definition, otherwise python will raise lay-out conflict + + class Op(Annotable): + __slots__ = ("_hash",) + + class Value(Annotable): + arg = argument(InstanceOf(object)) + + class Reduction(Value): + pass + + class UDF(Value): + func = argument(InstanceOf(Callable)) + + class UDAF(UDF, Reduction): + arity = argument(is_int) + + class A(Annotable): + a = argument(is_int) + + class B(Annotable): + b = argument(is_int) + + msg = "multiple bases have instance lay-out conflict" + with pytest.raises(TypeError, match=msg): + + class AB(A, B): + ab = argument(is_int) + + assert UDAF.__slots__ == ("arity",) + strlen = UDAF(arg=2, func=lambda value: len(str(value)), arity=1) + assert strlen.arg == 2 + assert strlen.arity == 1 + + +@pytest.mark.parametrize( + "obj", + [ + StringOp("something"), + StringOp(arg="something"), + ], +) +def test_pickling_support(obj): + assert obj == pickle.loads(pickle.dumps(obj)) + + +def test_multiple_inheritance_argument_order(): + class Value(Annotable): + arg = argument(is_any) + + class VersionedOp(Value): + version = argument(is_int) + + class Reduction(Annotable): + pass + + class Sum(VersionedOp, Reduction): + where = optional(is_bool, default=False) + + assert tuple(Sum.__signature__.parameters.keys()) == ("arg", "version", "where") + + +def test_multiple_inheritance_optional_argument_order(): + class Value(Annotable): + pass + + class ConditionalOp(Annotable): + where = optional(is_bool, default=False) + + class Between(Value, ConditionalOp): + min = argument(is_int) + max = argument(is_int) + how = optional(is_str, default="strict") + + assert tuple(Between.__signature__.parameters.keys()) == ( + "min", + "max", + "how", + "where", + ) + + +def test_immutability(): + class Value(Annotable, immutable=True): + a = argument(is_int) + + op = Value(1) + with pytest.raises(AttributeError): + op.a = 3 + + +class BaseValue(Annotable): + i = argument(is_int) + j = attribute(is_int) + + +class Value2(BaseValue): + @attribute + def k(self): + return 3 + + +class Value3(BaseValue): + k = attribute(is_int, default=3) + + +class Value4(BaseValue): + k = attribute(Option(is_int), default=None) + + +def test_annotable_with_dict_slot(): + class Flexible(Annotable): + __slots__ = ("__dict__",) + + v = Flexible() + v.a = 1 + v.b = 2 + assert v.a == 1 + assert v.b == 2 + + +def test_annotable_attribute(): + with pytest.raises(TypeError, match="too many positional arguments"): + BaseValue(1, 2) + + v = BaseValue(1) + assert v.__slots__ == ("i", "j") + assert v.i == 1 + assert not hasattr(v, "j") + v.j = 2 + assert v.j == 2 + + # TODO(kszucs) + # with pytest.raises(TypeError): + # v.j = "foo" + + +def test_annotable_attribute_init(): + assert Value2.__slots__ == ("k",) + v = Value2(1) + + assert v.i == 1 + assert not hasattr(v, "j") + v.j = 2 + assert v.j == 2 + assert v.k == 3 + + v = Value3(1) + assert v.k == 3 + + v = Value4(1) + assert v.k is None + + +def test_annotable_mutability_and_serialization(): + v_ = BaseValue(1) + v_.j = 2 + v = BaseValue(1) + v.j = 2 + assert v_ == v + assert v_.j == v.j == 2 + + assert repr(v) == "BaseValue(i=1)" + w = pickle.loads(pickle.dumps(v)) + assert w.i == 1 + assert w.j == 2 + assert v == w + + v.j = 4 + assert v_ != v + w = pickle.loads(pickle.dumps(v)) + assert w == v + assert repr(w) == "BaseValue(i=1)" + + +def test_initialized_attribute_basics(): + class Value(Annotable): + a = argument(is_int) + + @attribute + def double_a(self): + return 2 * self.a + + op = Value(1) + assert op.a == 1 + assert op.double_a == 2 + assert "double_a" in Value.__slots__ + + +def test_initialized_attribute_with_validation(): + class Value(Annotable): + a = argument(is_int) + + @attribute(int) + def double_a(self): + return 2 * self.a + + op = Value(1) + assert op.a == 1 + assert op.double_a == 2 + assert "double_a" in Value.__slots__ + + op.double_a = 3 + assert op.double_a == 3 + + with pytest.raises(NoMatchError): + op.double_a = "foo" + + +def test_initialized_attribute_mixed_with_classvar(): + class Value(Annotable): + arg = argument(is_int) + + shape = "like-arg" + dtype = "like-arg" + + class Reduction(Value): + shape = "scalar" + + class Variadic(Value): + @attribute + def shape(self): + if self.arg > 10: + return "columnar" + else: + return "scalar" + + r = Reduction(1) + assert r.shape == "scalar" + assert "shape" not in r.__slots__ + + v = Variadic(1) + assert v.shape == "scalar" + assert "shape" in v.__slots__ + + v = Variadic(100) + assert v.shape == "columnar" + assert "shape" in v.__slots__ + + +# def test_composition_of_annotable_and_singleton() -> None: +# class AnnSing(Annotable, Singleton): +# value = CoercedTo(int) + +# class SingAnn(Singleton, Annotable): +# # this is the preferable method resolution order +# value = CoercedTo(int) + +# # arguments looked up after validation +# obj1 = AnnSing("3") +# assert AnnSing("3") is obj1 +# assert AnnSing(3) is obj1 +# assert AnnSing(3.0) is obj1 + +# # arguments looked up before validation +# obj2 = SingAnn("3") +# assert SingAnn("3") is obj2 +# obj3 = SingAnn(3) +# assert obj3 is not obj2 +# assert SingAnn(3) is obj3 + + +def test_hashable(): + assert BetweenWithCalculated.__mro__ == ( + BetweenWithCalculated, + Hashable, + Immutable, + Annotable, + object, + ) + + assert BetweenWithCalculated.__eq__ is Hashable.__eq__ + assert BetweenWithCalculated.__argnames__ == ("value", "lower", "upper") + + # annotable + obj = BetweenWithCalculated(10, lower=5, upper=15) + obj2 = BetweenWithCalculated(10, lower=5, upper=15) + assert obj.value == 10 + assert obj.lower == 5 + assert obj.upper == 15 + assert obj.calculated == 15 + assert obj == obj2 + assert obj is not obj2 + assert obj != (10, 5, 15) + assert obj.__args__ == (10, 5, 15) + # assert obj.args == (10, 5, 15) + # assert obj.argnames == ("value", "lower", "upper") + + # immutable + with pytest.raises(AttributeError): + obj.value = 11 + + # hashable + assert {obj: 1}.get(obj) == 1 + + # weakrefable + ref = weakref.ref(obj) + assert ref() == obj + + # serializable + assert pickle.loads(pickle.dumps(obj)) == obj + + +# def test_composition_of_concrete_and_singleton(): +# class ConcSing(Concrete, Singleton): +# value = CoercedTo(int) + +# class SingConc(Singleton, Concrete): +# value = CoercedTo(int) + +# # arguments looked up after validation +# obj = ConcSing("3") +# assert ConcSing("3") is obj +# assert ConcSing(3) is obj +# assert ConcSing(3.0) is obj + +# # arguments looked up before validation +# obj = SingConc("3") +# assert SingConc("3") is obj +# obj2 = SingConc(3) +# assert obj2 is not obj +# assert SingConc(3) is obj2 + + +def test_init_subclass_keyword_arguments(): + class Test(Annotable): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__() + cls.kwargs = kwargs + + class Test2(Test, something="value", value="something"): + pass + + assert Test2.kwargs == {"something": "value", "value": "something"} + + +def test_argument_order_using_optional_annotations(): + class Case1(Annotable): + results: Optional[tuple[int, ...]] = () + default: Optional[int] = None + + class SimpleCase1(Case1): + base: int + cases: Optional[tuple[int, ...]] = () + + class Case2(Annotable): + results = optional(TupleOf(is_int), default=()) + default = optional(is_int) + + class SimpleCase2(Case2): + base = argument(is_int) + cases = optional(TupleOf(is_int), default=()) + + assert ( + SimpleCase1.__argnames__ + == SimpleCase2.__argnames__ + == ("base", "cases", "results", "default") + ) + + +def test_annotable_with_optional_coercible_typehint(): + class Example(Annotable): + value: Optional[MyInt] = None + + assert Example().value is None + assert Example(None).value is None + assert Example(1).value == 1 + assert isinstance(Example(1).value, MyInt) + + +# def test_error_message(snapshot): +# class Example(Annotable): +# a: int +# b: int = 0 +# c: str = "foo" +# d: Optional[float] = None +# e: tuple[int, ...] = (1, 2, 3) +# f: As[int] = 1 + +# with pytest.raises(ValidationError) as exc_info: +# Example("1", "2", "3", "4", "5", []) + +# # assert "Failed" in str(exc_info.value) + +# if sys.version_info >= (3, 11): +# target = "error_message_py311.txt" +# else: +# target = "error_message.txt" +# snapshot.assert_match(str(exc_info.value), target) + + +def test_annotable_supports_abstractmethods(): + class Foo(Annotable): + @abstractmethod + def foo(self): ... + + @property + @abstractmethod + def bar(self): ... + + assert not issubclass(type(Foo), ABCMeta) + assert issubclass(type(Foo), AnnotableMeta) + assert Foo.__abstractmethods__ == frozenset({"foo", "bar"}) + + with pytest.raises(TypeError, match="Can't instantiate abstract class .*Foo.*"): + Foo() + + class Bar(Foo): + def foo(self): + return 1 + + @property + def bar(self): + return 2 + + bar = Bar() + assert bar.foo() == 1 + assert bar.bar == 2 + assert isinstance(bar, Foo) + assert isinstance(bar, Annotable) + assert Bar.__abstractmethods__ == frozenset() + + +def test_annotable_with_custom_init(): + called_with = None + + class MyInit(Annotable): + a = argument(int) + b = argument(AsType(str)) + c = optional(float, default=0.0) + + def __init__(self, a, b, c): + nonlocal called_with + called_with = (a, b, c) + super().__init__(a=a, b=b, c=c) + + @attribute + def called_with(self): + return (self.a, self.b, self.c) + + with pytest.raises(MatchError): + MyInit(1, 2, 3) + + mi = MyInit(1, 2, 3.3) + assert called_with == (1, "2", 3.3) + assert isinstance(mi, MyInit) + assert mi.a == 1 + assert mi.b == "2" + assert mi.c == 3.3 + assert mi.called_with == called_with diff --git a/koerce/tests/test_builders.py b/koerce/tests/test_builders.py index f4df9b6..33ea2b8 100644 --- a/koerce/tests/test_builders.py +++ b/koerce/tests/test_builders.py @@ -4,7 +4,7 @@ import pytest -from koerce.builders import ( +from koerce._internal import ( Attr, Binop, Call, @@ -13,18 +13,18 @@ Call2, Call3, CallN, - Custom, Deferred, + Func, Item, Just, - Mapping, - Sequence, + Map, + Seq, Unop, - Variable, + Var, builder, ) -_ = Deferred(Variable("_")) +_ = Deferred(Var("_")) def test_builder(): @@ -39,15 +39,12 @@ def fn(x): assert builder(Just(Just(1))) == Just(1) assert builder(MyClass) == Just(MyClass) assert builder(fn) == Just(fn) - assert builder(()) == Sequence(()) - assert builder((1, 2, _)) == Sequence((Just(1), Just(2), _)) - assert builder({}) == Mapping({}) - assert builder({"a": 1, "b": _}) == Mapping({"a": Just(1), "b": _}) + assert builder(()) == Seq(()) + assert builder((1, 2, _)) == Seq((Just(1), Just(2), _)) + assert builder({}) == Map({}) + assert builder({"a": 1, "b": _}) == Map({"a": Just(1), "b": _}) assert builder("string") == Just("string") - # assert builder(var("x")) == Variable("x") - # assert builder(Variable("x")) == Variable("x") - def test_builder_just(): p = Just(1) @@ -64,14 +61,14 @@ def test_builder_just(): # Just(Factory(lambda _: _)) -def test_builder_variable(): - p = Variable("other") +def test_builder_Var(): + p = Var("other") context = {"other": 10} assert p.apply(context) == 10 -def test_builder_custom(): - f = Custom(lambda _: _ + 1) +def test_builder_func(): + f = Func(lambda _: _ + 1) assert f.apply({"_": 1}) == 2 assert f.apply({"_": 2}) == 3 @@ -79,7 +76,7 @@ def fn(**kwargs): assert kwargs == {"_": 10, "a": 5} return -1 - f = Custom(fn) + f = Func(fn) assert f.apply({"_": 10, "a": 5}) == -1 @@ -148,7 +145,7 @@ def __init__(self, a, b): def __hash__(self): return hash((type(self), self.a, self.b)) - v = Variable("v") + v = Var("v") b = Attr(v, "b") assert b.apply({"v": MyType(1, 2)}) == 2 @@ -161,32 +158,32 @@ def __hash__(self): def test_builder_item(): - v = Variable("v") + v = Var("v") b = Item(v, Just(1)) assert b.apply({"v": [1, 2, 3]}) == 2 b = Item(Just(dict(a=1, b=2)), Just("a")) assert b.apply({}) == 1 - name = Variable("name") + name = Var("name") # test that name can be a deferred as well b = Item(v, name) assert b.apply({"v": {"a": 1, "b": 2}, "name": "b"}) == 2 -def test_builder_sequence(): - b = Sequence([Just(1), Just(2), Just(3)]) +def test_builder_Seq(): + b = Seq([Just(1), Just(2), Just(3)]) assert b.apply({}) == [1, 2, 3] - b = Sequence((Just(1), Just(2), Just(3))) + b = Seq((Just(1), Just(2), Just(3))) assert b.apply({}) == (1, 2, 3) -def test_builder_mapping(): - b = Mapping({"a": Just(1), "b": Just(2)}) +def test_builder_Map(): + b = Map({"a": Just(1), "b": Just(2)}) assert b.apply({}) == {"a": 1, "b": 2} - b = Mapping({"a": Just(1), "b": Just(2)}) + b = Map({"a": Just(1), "b": Just(2)}) assert b.apply({}) == {"a": 1, "b": 2} @@ -220,14 +217,14 @@ def test_deferred_builds(value, expected): def test_deferred_supports_string_arguments(): # deferred() is applied on all arguments of Call() except the first one and - # sequences are transparently handled, the check far sequences was incorrect + # Seqs are transparently handled, the check far Seqs was incorrect # for strings causing infinite recursion b = builder("3.14") assert b.apply({}) == "3.14" -def test_deferred_variable_getattr(): - v = Deferred(Variable("v")) +def test_deferred_Var_getattr(): + v = Deferred(Var("v")) p = v.copy assert builder(p) == Attr(v, "copy") assert builder(p).apply({"v": [1, 2, 3]})() == [1, 2, 3] diff --git a/koerce/tests/test_patterns.py b/koerce/tests/test_patterns.py index a7c63fb..bb53d4d 100644 --- a/koerce/tests/test_patterns.py +++ b/koerce/tests/test_patterns.py @@ -19,16 +19,18 @@ import pytest from typing_extensions import Self -from koerce.builders import Call, Deferred, Variable -from koerce.patterns import ( +from koerce import match +from koerce._internal import ( AllOf, AnyOf, Anything, AsType, + Call, CallableWith, Capture, CoercedTo, CoercionError, + Deferred, DictOf, EqValue, GenericCoercedTo, @@ -63,9 +65,9 @@ SomeOf, TupleOf, TypeOf, + Var, pattern, ) -from koerce.sugar import match class Min(Pattern): @@ -540,7 +542,7 @@ def test_capture(): @pytest.mark.parametrize( - "x", [Deferred(Variable("x")), Variable("x")], ids=["deferred", "builder"] + "x", [Deferred(Var("x")), Var("x")], ids=["deferred", "builder"] ) def test_capture_with_deferred_and_builder(x): ctx = {} @@ -762,7 +764,7 @@ def __eq__(self, other): and self.c == other.c ) - a = Variable("a") + a = Var("a") p = ObjectOf(Foo, Capture(a, InstanceOf(int)), c=a) assert p.apply(Foo(1, 2, 3)) is NoMatch @@ -898,14 +900,14 @@ def test_matching(): def test_replace_in_nested_object_pattern(): # simple example using reference to replace a value - b = Variable("b") + b = Var("b") p = ObjectOf(Foo, 1, b=Replace(Anything(), b)) f = p.apply(Foo(1, 2), {"b": 3}) assert f.a == 1 assert f.b == 3 # nested example using reference to replace a value - d = Variable("d") + d = Var("d") p = ObjectOf(Foo, 1, b=ObjectOf(Bar, 2, d=Replace(Anything(), d))) g = p.apply(Foo(1, Bar(2, 3)), {"d": 4}) assert g.b.c == 2 @@ -923,7 +925,7 @@ def test_replace_in_nested_object_pattern(): assert isinstance(h.b, Foo) assert h.b.b == 3 - d = Variable("d") + d = Var("d") p = ObjectOf(Foo, 1, b=ObjectOf(Bar, 2, d=d @ Anything()) >> Call(Foo, -1, b=d)) h1 = p.apply(Foo(1, Bar(2, 3)), {}) assert isinstance(h1, Foo) @@ -942,8 +944,8 @@ def test_replace_in_nested_object_pattern(): def test_replace_using_deferred(): - x = Deferred(Variable("x")) - y = Deferred(Variable("y")) + x = Deferred(Var("x")) + y = Deferred(Var("y")) pat = ObjectOf(Foo, Capture(x), b=Capture(y)) >> Call(Foo, x, b=y) assert pat.apply(Foo(1, 2)) == Foo(1, 2) @@ -982,7 +984,7 @@ def test_matching_sequence_pattern_keeps_original_type(): def test_matching_sequence_with_captures(): - x = Deferred(Variable("x")) + x = Deferred(Var("x")) v = list(range(1, 9)) assert match([1, 2, 3, 4, SomeOf(...)], v) == v diff --git a/koerce/tests/test_sugar.py b/koerce/tests/test_sugar.py index 0369866..6c4fc70 100644 --- a/koerce/tests/test_sugar.py +++ b/koerce/tests/test_sugar.py @@ -1,6 +1,6 @@ from __future__ import annotations -from koerce.sugar import NoMatch, match, var +from koerce import NoMatch, match, var def test_capture_shorthand(): diff --git a/koerce/tests/test_utils.py b/koerce/tests/test_utils.py index a09635d..a2499e6 100644 --- a/koerce/tests/test_utils.py +++ b/koerce/tests/test_utils.py @@ -1,12 +1,15 @@ from __future__ import annotations import inspect -from typing import Dict, Generic, List, Optional, TypeVar, Union +import pickle +from typing import Dict, Generic, List, Mapping, Optional, TypeVar, Union import pytest from typing_extensions import Self from koerce.utils import ( + FrozenDict, + RewindableIterator, get_type_boundvars, get_type_hints, get_type_params, @@ -128,3 +131,59 @@ def test_get_type_boundvars_unable_to_deduce() -> None: msg = "Unable to deduce corresponding type attributes..." with pytest.raises(ValueError, match=msg): get_type_boundvars(MyDict[int, str]) + + +def test_rewindable_iterator(): + it = RewindableIterator(range(10)) + assert next(it) == 0 + assert next(it) == 1 + with pytest.raises(ValueError, match="No checkpoint to rewind to"): + it.rewind() + + it.checkpoint() + assert next(it) == 2 + assert next(it) == 3 + it.rewind() + assert next(it) == 2 + assert next(it) == 3 + assert next(it) == 4 + it.checkpoint() + assert next(it) == 5 + assert next(it) == 6 + it.rewind() + assert next(it) == 5 + assert next(it) == 6 + assert next(it) == 7 + it.rewind() + assert next(it) == 5 + assert next(it) == 6 + assert next(it) == 7 + assert next(it) == 8 + assert next(it) == 9 + with pytest.raises(StopIteration): + next(it) + + +def test_frozendict(): + d = FrozenDict({"a": 1, "b": 2, "c": 3}) + e = FrozenDict(a=1, b=2, c=3) + f = FrozenDict(a=1, b=2, c=3, d=4) + + assert isinstance(d, Mapping) + + assert d == e + assert d != f + + assert d["a"] == 1 + assert d["b"] == 2 + + msg = "'FrozenDict' object does not support item assignment" + with pytest.raises(TypeError, match=msg): + d["a"] = 2 + with pytest.raises(TypeError, match=msg): + d["d"] = 4 + + assert hash(FrozenDict(a=1, b=2)) == hash(FrozenDict(b=2, a=1)) + assert hash(FrozenDict(a=1, b=2)) != hash(d) + + assert d == pickle.loads(pickle.dumps(d)) diff --git a/koerce/tests/test_y.py b/koerce/tests/test_y.py new file mode 100644 index 0000000..2f46c13 --- /dev/null +++ b/koerce/tests/test_y.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +from dataclasses import dataclass +from inspect import Signature as InspectSignature +from typing import Generic + +import pytest + +pydantic = pytest.importorskip("pydantic") + +from ibis.common.grounds import Annotable as IAnnotable +from pydantic import BaseModel, validate_call +from pydantic_core import SchemaValidator +from typing_extensions import TypeVar + +from koerce import ( + Annotable, + PatternMap, + Signature, + annotated, +) + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") + + +class A(Generic[T, S, U]): + a: int + b: str + + t: T + s: S + + @property + def u(self) -> U: # type: ignore + ... + + +@dataclass +class Person: + name: str + age: int + is_developer: bool = True + has_children: bool = False + + +v = SchemaValidator( + { + "type": "typed-dict", + "fields": { + "name": { + "type": "typed-dict-field", + "schema": { + "type": "str", + }, + }, + "age": { + "type": "typed-dict-field", + "schema": { + "type": "int", + }, + }, + "is_developer": { + "type": "typed-dict-field", + "schema": { + "type": "bool", + }, + }, + "has_children": { + "type": "typed-dict-field", + "schema": { + "type": "bool", + }, + }, + }, + } +) + +p = Person(name="Samuel", age=35, is_developer=True, has_children=False) + +data = {"name": "Samuel", "age": 35, "is_developer": True, "has_children": False} + + +ITS = 50 + + +def test_patternmap_pydantic(benchmark): + r1 = benchmark.pedantic( + v.validate_python, args=(data,), iterations=ITS, rounds=20000 + ) + assert r1 == data + + +def test_patternmap_koerce(benchmark): + pat = PatternMap( + {"name": str, "age": int, "is_developer": bool, "has_children": bool} + ) + r2 = benchmark.pedantic(pat.apply, args=(data, {}), iterations=ITS, rounds=20000) + assert r2 == data + + +def func(x: int, y: str, *args: int, z: float = 3.14, **kwargs) -> float: ... + + +args = (1, "a", 2, 3, 4) +kwargs = dict(z=3.14, w=5, q=6) +expected = {"x": 1, "y": "a", "args": (2, 3, 4), "z": 3.14, "kwargs": {"w": 5, "q": 6}} + + +def test_signature_stdlib(benchmark): + sig = InspectSignature.from_callable(func) + r = benchmark.pedantic( + sig.bind, args=args, kwargs=kwargs, iterations=ITS, rounds=20000 + ) + assert r.arguments == expected + + +def test_signature_koerce(benchmark): + sig = Signature.from_callable(func) + r = benchmark.pedantic(sig.bind, args=(args, kwargs), iterations=ITS, rounds=20000) + assert r == expected + + +@validate_call +def prepeat(s: str, count: int, *, separator: bytes = b"") -> bytes: + return b"" + + +@annotated +def krepeat(s: str, count: int, *, separator: bytes = b"") -> bytes: + return b"" + + +def test_validated_call_pydantic(benchmark): + r1 = benchmark.pedantic( + prepeat, + args=("hello", 3), + kwargs={"separator": b" "}, + iterations=ITS, + rounds=20000, + ) + assert r1 == b"" + + +def test_validated_call_annotated(benchmark): + r2 = benchmark.pedantic( + krepeat, + args=("hello", 3), + kwargs={"separator": b" "}, + iterations=ITS, + rounds=20000, + ) + assert r2 == b"" + + +class PUser(BaseModel): + id: int + name: str = "Jane Doe" + age: int | None = None + children: list[str] = [] + + +class KUser(Annotable): + id: int + name: str = "Jane Doe" + age: int | None = None + children: list[str] = [] + + +class IUser(IAnnotable): + id: int + name: str = "Jane Doe" + age: int | None = None + children: list[str] = [] + + +ch = ["Alice", "Bob", "Charlie"] +ch = [] + + +def test_pydantic(benchmark): + r1 = benchmark.pedantic( + PUser, + args=(), + kwargs={"id": 1, "name": "Jane Doe", "age": None, "children": []}, + iterations=ITS, + rounds=20000, + ) + assert r1 == PUser(id=1, name="Jane Doe", age=None, children=[]) + + +def test_annotated(benchmark): + r2 = benchmark.pedantic( + KUser, + args=(), + kwargs={"id": 1, "name": "Jane Doe", "age": None, "children": ()}, + iterations=ITS, + rounds=20000, + ) + assert r2 == KUser(id=1, name="Jane Doe", age=None, children=[]) + + +# def test_ibis(benchmark): +# r2 = benchmark.pedantic( +# IUser, +# args=(), +# kwargs={"id": 1, "name": "Jane Doe", "age": None, "children": ()}, +# iterations=ITS, +# rounds=20000, +# ) +# assert r2 == IUser(id=1, name="Jane Doe", age=None, children=[]) diff --git a/koerce/utils.py b/koerce/utils.py index 362e260..19fb051 100644 --- a/koerce/utils.py +++ b/koerce/utils.py @@ -3,8 +3,12 @@ import itertools import sys import typing +from collections.abc import Hashable from typing import Any, Optional, TypeVar +K = TypeVar("K") +V = TypeVar("V") + get_type_args = typing.get_args get_type_origin = typing.get_origin @@ -163,6 +167,33 @@ def get_type_boundvars(typ: Any) -> dict[TypeVar, tuple[str, type]]: return result +class FrozenDict(dict[K, V], Hashable): + __slots__ = ("__precomputed_hash__",) + __precomputed_hash__: int + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + hashable = frozenset(self.items()) + object.__setattr__(self, "__precomputed_hash__", hash(hashable)) + + def __hash__(self) -> int: + return self.__precomputed_hash__ + + def __setitem__(self, key: K, value: V) -> None: + raise TypeError( + f"'{self.__class__.__name__}' object does not support item assignment" + ) + + def __setattr__(self, name: str, _: Any) -> None: + raise TypeError(f"Attribute {name!r} cannot be assigned to frozendict") + + def __reduce__(self) -> tuple: + return (self.__class__, (dict(self),)) + + +frozendict = FrozenDict + + class RewindableIterator: """Iterator that can be rewound to a checkpoint.