diff --git a/koerce/patterns.py b/koerce/patterns.py index e23c9ea..f52efb9 100644 --- a/koerce/patterns.py +++ b/koerce/patterns.py @@ -279,12 +279,8 @@ def __or__(self, other: Pattern) -> AnyOf: ------- New pattern that matches if either of the patterns match. """ - if isinstance(self, AnyOf) and isinstance(other, AnyOf): - return AnyOf(*self.inners, *other.inners) - elif isinstance(self, AnyOf): - return AnyOf(*self.inners, other) - elif isinstance(other, AnyOf): - return AnyOf(self, *other.inners) + if isinstance(other, AnyOf): + return AnyOf(self, *cython.cast(AnyOf, other).inners) else: return AnyOf(self, other) @@ -300,12 +296,8 @@ def __and__(self, other: Pattern) -> AllOf: ------- New pattern that matches if both of the patterns match. """ - if isinstance(self, AllOf) and isinstance(other, AllOf): - return AllOf(*self.inners, *other.inners) - elif isinstance(self, AllOf): - return AllOf(*self.inners, other) - elif isinstance(other, AllOf): - return AllOf(self, *other.inners) + if isinstance(other, AllOf): + return AllOf(self, *cython.cast(AllOf, other).inners) else: return AllOf(self, other) @@ -881,6 +873,23 @@ def match(self, value, ctx: Context): pass raise NoMatchError() + def __or__(self, other: Pattern) -> AnyOf: + """Syntax sugar for matching either of the patterns. + + Parameters + ---------- + other + The other pattern to match against. + + Returns + ------- + New pattern that matches if either of the patterns match. + """ + if isinstance(other, AnyOf): + return AnyOf(*self.inners, *cython.cast(AnyOf, other).inners) + else: + return AnyOf(*self.inners, other) + @cython.final @cython.cclass @@ -903,6 +912,23 @@ def match(self, value, ctx: Context): value = inner.match(value, ctx) return value + def __and__(self, other: Pattern) -> AllOf: + """Syntax sugar for matching both of the patterns. + + Parameters + ---------- + other + The other pattern to match against. + + Returns + ------- + New pattern that matches if both of the patterns match. + """ + if isinstance(other, AllOf): + return AllOf(*self.inners, *cython.cast(AllOf, other).inners) + else: + return AllOf(*self.inners, other) + def NoneOf(*args) -> Pattern: """Match none of the passed patterns."""