Skip to content

Commit

Permalink
feat: add some pattern matching sugar
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Jul 30, 2024
1 parent 13bb3de commit c7f1153
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 30 deletions.
2 changes: 1 addition & 1 deletion build.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import shutil
from pathlib import Path

from Cython.Build import cythonize, build_ext
from Cython.Build import build_ext, cythonize
from setuptools import Distribution

# import Cython.Compiler.Options
Expand Down
9 changes: 3 additions & 6 deletions koerce/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Context = dict[str, Any]


@cython.final
@cython.cclass
class Deferred:
"""The user facing wrapper object providing syntactic sugar for deferreds.
Expand Down Expand Up @@ -55,6 +54,9 @@ def __invert__(self) -> Deferred:
def __neg__(self) -> Deferred:
return Deferred(Unop(operator.neg, self))

def __pos__(self) -> Deferred:
return Deferred(Unop(operator.pos, self))

def __add__(self, other: Any) -> Deferred:
return Deferred(Binop(operator.add, self, other))

Expand Down Expand Up @@ -749,8 +751,3 @@ def builder(obj, allow_custom=False) -> Builder:
else:
# the object is used as a constant value
return Just(obj)


@cython.ccall
def variable(name):
return Deferred(Variable(name))
31 changes: 19 additions & 12 deletions koerce/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def from_typehint(annot: Any, allow_coercion: bool = True) -> Pattern:
return _any
elif isinstance(annot, Enum):
# for enums we check the value against the enum values
return EqualTo(annot)
return EqValue(annot)
elif isinstance(annot, str):
# for strings and forward references we check in a lazy way
return LazyInstanceOf(annot)
Expand Down Expand Up @@ -319,19 +319,26 @@ def match(self, value, ctx: Context):
else:
return NoMatch

@cython.ccall
def Eq(value) -> Pattern:
if isinstance(value, (Deferred, Builder)):
return EqDeferred(value)
else:
return EqValue(value)


@cython.final
@cython.cclass
class EqualTo(Pattern):
class EqValue(Pattern):
value: Any

def __init__(self, value: Any):
self.value = value

def __repr__(self) -> str:
return f"EqualTo({self.value!r})"
return f"EqValue({self.value!r})"

def equals(self, other: EqualTo) -> bool:
def equals(self, other: EqValue) -> bool:
return self.value == other.value

@cython.cfunc
Expand All @@ -345,7 +352,7 @@ def match(self, value, ctx: Context):

@cython.final
@cython.cclass
class DeferredEqualTo(Pattern):
class EqDeferred(Pattern):
"""Pattern that checks a value equals to the given value.
Parameters
Expand All @@ -361,11 +368,11 @@ def __init__(self, value):
self.value = builder(value)

def __repr__(self) -> str:
return f"DeferredEqualTo({self.value!r})"
return f"EqDeferred({self.value!r})"

@cython.cfunc
def match(self, value, ctx):
ctx["_"] = value
# ctx["_"] = value
# TODO(kszucs): Builder is not cimported so self.value.build() cannot be
# used, hence using .apply() instead
if value == self.value.apply(ctx):
Expand Down Expand Up @@ -1487,7 +1494,7 @@ def _maybe_unwrap_capture(obj):

def PatternList(patterns, type=list):
if patterns == ():
return EqualTo(patterns)
return EqValue(patterns)

patterns = tuple(map(pattern, patterns))
for pat in patterns:
Expand Down Expand Up @@ -1669,12 +1676,12 @@ def pattern(obj: Any, allow_custom: bool = True) -> Pattern:
return obj
elif isinstance(obj, (Deferred, Builder)):
# return Capture(obj)
return DeferredEqualTo(obj)
return EqDeferred(obj)
elif isinstance(obj, Mapping):
return EqualTo(obj)
return EqValue(obj)
elif isinstance(obj, Sequence):
if isinstance(obj, (str, bytes)):
return EqualTo(obj)
return EqValue(obj)
else:
return PatternList(obj, type(obj))
elif isinstance(obj, type):
Expand All @@ -1684,7 +1691,7 @@ def pattern(obj: Any, allow_custom: bool = True) -> Pattern:
elif callable(obj) and allow_custom:
return Custom(obj)
else:
return EqualTo(obj)
return EqValue(obj)


# barhol ahol callback-et lehet hasznalni oda kell egy deferred verzio is
34 changes: 28 additions & 6 deletions koerce/sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
import sys
from typing import Any

from .patterns import Context, Pattern, pattern

# if_
# isa
# eq
# as_
from .builders import Deferred, Variable
from .patterns import (
Capture,
Context,
Eq,
If,
NoMatch, # noqa: F401
Pattern,
pattern,
)


class Namespace:
Expand Down Expand Up @@ -41,6 +45,18 @@ def __getattr__(self, name: str):
return self._factory(obj)


class Var(Deferred):
def __init__(self, name: str):
builder = Variable(name)
super().__init__(builder)

def __invert__(self):
return Capture(self)


var = Var


def match(pat: Pattern, value: Any, context: Context = None) -> Any:
"""Match a value against a pattern.
Expand Down Expand Up @@ -74,3 +90,9 @@ def match(pat: Pattern, value: Any, context: Context = None) -> Any:
"""
pat = pattern(pat)
return pat.apply(value, context)


if_ = If
eq = Eq
_ = var("_")

10 changes: 5 additions & 5 deletions koerce/tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
CoercedTo,
CoercionError,
DictOf,
EqualTo,
EqValue,
GenericCoercedTo,
GenericInstanceOf,
GenericInstanceOf1,
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_identical_to(value):
],
)
def test_equal_to(a, b, expected):
pattern = EqualTo(a)
pattern = EqValue(a)
if expected:
assert pattern.apply(b) is b
else:
Expand Down Expand Up @@ -1192,7 +1192,7 @@ def f(x):
# ... is treated the same as Any()
assert pattern(...) == Anything()
assert pattern(Anything()) == Anything()
assert pattern(True) == EqualTo(True)
assert pattern(True) == EqValue(True)

# plain types are converted to InstanceOf patterns
assert pattern(int) == InstanceOf(int)
Expand All @@ -1209,11 +1209,11 @@ def f(x):

# spelled out sequences construct a more advanced pattern sequence
assert pattern([int, str, 1]) == PatternList(
[InstanceOf(int), InstanceOf(str), EqualTo(1)]
[InstanceOf(int), InstanceOf(str), EqValue(1)]
)

# matching deferred to user defined functions
# assert pattern(f) == Custom(f)

# matching mapping values
assert pattern({"a": 1, "b": 2}) == EqualTo({"a": 1, "b": 2})
assert pattern({"a": 1, "b": 2}) == EqValue({"a": 1, "b": 2})
28 changes: 28 additions & 0 deletions koerce/tests/test_sugar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from koerce.sugar import NoMatch, match, var


def test_capture_shorthand():
a = var("a")
b = var("b")

ctx = {}
assert match((~a, ~b), (1, 2), ctx) == (1, 2)
assert ctx == {"a": 1, "b": 2}

ctx = {}
assert match((~a, a, a), (1, 2, 3), ctx) is NoMatch
assert ctx == {"a": 1}

ctx = {}
assert match((~a, a, a), (1, 1, 1), ctx) == (1, 1, 1)
assert ctx == {"a": 1}

ctx = {}
assert match((~a, a, a), (1, 1, 2), ctx) is NoMatch
assert ctx == {"a": 1}


def test_namespace():
pass

0 comments on commit c7f1153

Please sign in to comment.