From c7f1153fcfe43511ecf23b5ec1212712903dc273 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 30 Jul 2024 20:09:39 +0200 Subject: [PATCH] feat: add some pattern matching sugar --- build.py | 2 +- koerce/builders.py | 9 +++------ koerce/patterns.py | 31 +++++++++++++++++++------------ koerce/sugar.py | 34 ++++++++++++++++++++++++++++------ koerce/tests/test_patterns.py | 10 +++++----- koerce/tests/test_sugar.py | 28 ++++++++++++++++++++++++++++ 6 files changed, 84 insertions(+), 30 deletions(-) create mode 100644 koerce/tests/test_sugar.py diff --git a/build.py b/build.py index a773570..6541920 100644 --- a/build.py +++ b/build.py @@ -4,7 +4,7 @@ import shutil from pathlib import Path -from Cython.Build import cythonize, build_ext +from Cython.Build import build_ext, cythonize from setuptools import Distribution # import Cython.Compiler.Options diff --git a/koerce/builders.py b/koerce/builders.py index 1fc502f..7de35e1 100644 --- a/koerce/builders.py +++ b/koerce/builders.py @@ -10,7 +10,6 @@ Context = dict[str, Any] -@cython.final @cython.cclass class Deferred: """The user facing wrapper object providing syntactic sugar for deferreds. @@ -55,6 +54,9 @@ def __invert__(self) -> Deferred: def __neg__(self) -> Deferred: return Deferred(Unop(operator.neg, self)) + def __pos__(self) -> Deferred: + return Deferred(Unop(operator.pos, self)) + def __add__(self, other: Any) -> Deferred: return Deferred(Binop(operator.add, self, other)) @@ -749,8 +751,3 @@ def builder(obj, allow_custom=False) -> Builder: else: # the object is used as a constant value return Just(obj) - - -@cython.ccall -def variable(name): - return Deferred(Variable(name)) diff --git a/koerce/patterns.py b/koerce/patterns.py index b9084d9..8e3ff35 100644 --- a/koerce/patterns.py +++ b/koerce/patterns.py @@ -87,7 +87,7 @@ def from_typehint(annot: Any, allow_coercion: bool = True) -> Pattern: return _any elif isinstance(annot, Enum): # for enums we check the value against the enum values - return EqualTo(annot) + return EqValue(annot) elif isinstance(annot, str): # for strings and forward references we check in a lazy way return LazyInstanceOf(annot) @@ -319,19 +319,26 @@ def match(self, value, ctx: Context): else: return NoMatch +@cython.ccall +def Eq(value) -> Pattern: + if isinstance(value, (Deferred, Builder)): + return EqDeferred(value) + else: + return EqValue(value) + @cython.final @cython.cclass -class EqualTo(Pattern): +class EqValue(Pattern): value: Any def __init__(self, value: Any): self.value = value def __repr__(self) -> str: - return f"EqualTo({self.value!r})" + return f"EqValue({self.value!r})" - def equals(self, other: EqualTo) -> bool: + def equals(self, other: EqValue) -> bool: return self.value == other.value @cython.cfunc @@ -345,7 +352,7 @@ def match(self, value, ctx: Context): @cython.final @cython.cclass -class DeferredEqualTo(Pattern): +class EqDeferred(Pattern): """Pattern that checks a value equals to the given value. Parameters @@ -361,11 +368,11 @@ def __init__(self, value): self.value = builder(value) def __repr__(self) -> str: - return f"DeferredEqualTo({self.value!r})" + return f"EqDeferred({self.value!r})" @cython.cfunc def match(self, value, ctx): - ctx["_"] = value + # ctx["_"] = value # TODO(kszucs): Builder is not cimported so self.value.build() cannot be # used, hence using .apply() instead if value == self.value.apply(ctx): @@ -1487,7 +1494,7 @@ def _maybe_unwrap_capture(obj): def PatternList(patterns, type=list): if patterns == (): - return EqualTo(patterns) + return EqValue(patterns) patterns = tuple(map(pattern, patterns)) for pat in patterns: @@ -1669,12 +1676,12 @@ def pattern(obj: Any, allow_custom: bool = True) -> Pattern: return obj elif isinstance(obj, (Deferred, Builder)): # return Capture(obj) - return DeferredEqualTo(obj) + return EqDeferred(obj) elif isinstance(obj, Mapping): - return EqualTo(obj) + return EqValue(obj) elif isinstance(obj, Sequence): if isinstance(obj, (str, bytes)): - return EqualTo(obj) + return EqValue(obj) else: return PatternList(obj, type(obj)) elif isinstance(obj, type): @@ -1684,7 +1691,7 @@ def pattern(obj: Any, allow_custom: bool = True) -> Pattern: elif callable(obj) and allow_custom: return Custom(obj) else: - return EqualTo(obj) + return EqValue(obj) # barhol ahol callback-et lehet hasznalni oda kell egy deferred verzio is diff --git a/koerce/sugar.py b/koerce/sugar.py index d171ee8..3cc7d7b 100644 --- a/koerce/sugar.py +++ b/koerce/sugar.py @@ -3,12 +3,16 @@ import sys from typing import Any -from .patterns import Context, Pattern, pattern - -# if_ -# isa -# eq -# as_ +from .builders import Deferred, Variable +from .patterns import ( + Capture, + Context, + Eq, + If, + NoMatch, # noqa: F401 + Pattern, + pattern, +) class Namespace: @@ -41,6 +45,18 @@ def __getattr__(self, name: str): return self._factory(obj) +class Var(Deferred): + def __init__(self, name: str): + builder = Variable(name) + super().__init__(builder) + + def __invert__(self): + return Capture(self) + + +var = Var + + def match(pat: Pattern, value: Any, context: Context = None) -> Any: """Match a value against a pattern. @@ -74,3 +90,9 @@ def match(pat: Pattern, value: Any, context: Context = None) -> Any: """ pat = pattern(pat) return pat.apply(value, context) + + +if_ = If +eq = Eq +_ = var("_") + diff --git a/koerce/tests/test_patterns.py b/koerce/tests/test_patterns.py index bf49781..dfd930b 100644 --- a/koerce/tests/test_patterns.py +++ b/koerce/tests/test_patterns.py @@ -26,7 +26,7 @@ CoercedTo, CoercionError, DictOf, - EqualTo, + EqValue, GenericCoercedTo, GenericInstanceOf, GenericInstanceOf1, @@ -119,7 +119,7 @@ def test_identical_to(value): ], ) def test_equal_to(a, b, expected): - pattern = EqualTo(a) + pattern = EqValue(a) if expected: assert pattern.apply(b) is b else: @@ -1192,7 +1192,7 @@ def f(x): # ... is treated the same as Any() assert pattern(...) == Anything() assert pattern(Anything()) == Anything() - assert pattern(True) == EqualTo(True) + assert pattern(True) == EqValue(True) # plain types are converted to InstanceOf patterns assert pattern(int) == InstanceOf(int) @@ -1209,11 +1209,11 @@ def f(x): # spelled out sequences construct a more advanced pattern sequence assert pattern([int, str, 1]) == PatternList( - [InstanceOf(int), InstanceOf(str), EqualTo(1)] + [InstanceOf(int), InstanceOf(str), EqValue(1)] ) # matching deferred to user defined functions # assert pattern(f) == Custom(f) # matching mapping values - assert pattern({"a": 1, "b": 2}) == EqualTo({"a": 1, "b": 2}) + assert pattern({"a": 1, "b": 2}) == EqValue({"a": 1, "b": 2}) diff --git a/koerce/tests/test_sugar.py b/koerce/tests/test_sugar.py new file mode 100644 index 0000000..0369866 --- /dev/null +++ b/koerce/tests/test_sugar.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from koerce.sugar import NoMatch, match, var + + +def test_capture_shorthand(): + a = var("a") + b = var("b") + + ctx = {} + assert match((~a, ~b), (1, 2), ctx) == (1, 2) + assert ctx == {"a": 1, "b": 2} + + ctx = {} + assert match((~a, a, a), (1, 2, 3), ctx) is NoMatch + assert ctx == {"a": 1} + + ctx = {} + assert match((~a, a, a), (1, 1, 1), ctx) == (1, 1, 1) + assert ctx == {"a": 1} + + ctx = {} + assert match((~a, a, a), (1, 1, 2), ctx) is NoMatch + assert ctx == {"a": 1} + + +def test_namespace(): + pass