Skip to content

Commit 49549bf

Browse files
cpcloudkszucs
andauthored
fix(patterns): fix composing AnyOf and AllOf patterns (#7)
* test: add test for and-ing and or-ing allof/anyof * fix(patterns): reference the correct attribute in sugar * chore: split up and/or and add cython casts * test: check for AnyOf and AllOf unnesting Co-authored-by: Krisztián Szűcs <szucs.krisztian@gmail.com>
1 parent 15dba59 commit 49549bf

File tree

2 files changed

+65
-13
lines changed

2 files changed

+65
-13
lines changed

koerce/patterns.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -279,12 +279,8 @@ def __or__(self, other: Pattern) -> AnyOf:
279279
-------
280280
New pattern that matches if either of the patterns match.
281281
"""
282-
if isinstance(self, AnyOf) and isinstance(other, AnyOf):
283-
return AnyOf(*self.patterns, *other.patterns)
284-
elif isinstance(self, AnyOf):
285-
return AnyOf(*self.patterns, other)
286-
elif isinstance(other, AnyOf):
287-
return AnyOf(self, *other.patterns)
282+
if isinstance(other, AnyOf):
283+
return AnyOf(self, *cython.cast(AnyOf, other).inners)
288284
else:
289285
return AnyOf(self, other)
290286

@@ -300,12 +296,8 @@ def __and__(self, other: Pattern) -> AllOf:
300296
-------
301297
New pattern that matches if both of the patterns match.
302298
"""
303-
if isinstance(self, AllOf) and isinstance(other, AllOf):
304-
return AllOf(*self.patterns, *other.patterns)
305-
elif isinstance(self, AllOf):
306-
return AllOf(*self.patterns, other)
307-
elif isinstance(other, AllOf):
308-
return AllOf(self, *other.patterns)
299+
if isinstance(other, AllOf):
300+
return AllOf(self, *cython.cast(AllOf, other).inners)
309301
else:
310302
return AllOf(self, other)
311303

@@ -883,6 +875,23 @@ def match(self, value, ctx: Context):
883875
pass
884876
raise NoMatchError()
885877

878+
def __or__(self, other: Pattern) -> AnyOf:
879+
"""Syntax sugar for matching either of the patterns.
880+
881+
Parameters
882+
----------
883+
other
884+
The other pattern to match against.
885+
886+
Returns
887+
-------
888+
New pattern that matches if either of the patterns match.
889+
"""
890+
if isinstance(other, AnyOf):
891+
return AnyOf(*self.inners, *cython.cast(AnyOf, other).inners)
892+
else:
893+
return AnyOf(*self.inners, other)
894+
886895

887896
@cython.final
888897
@cython.cclass
@@ -905,6 +914,23 @@ def match(self, value, ctx: Context):
905914
value = inner.match(value, ctx)
906915
return value
907916

917+
def __and__(self, other: Pattern) -> AllOf:
918+
"""Syntax sugar for matching both of the patterns.
919+
920+
Parameters
921+
----------
922+
other
923+
The other pattern to match against.
924+
925+
Returns
926+
-------
927+
New pattern that matches if both of the patterns match.
928+
"""
929+
if isinstance(other, AllOf):
930+
return AllOf(*self.inners, *cython.cast(AllOf, other).inners)
931+
else:
932+
return AllOf(*self.inners, other)
933+
908934

909935
def NoneOf(*args) -> Pattern:
910936
"""Match none of the passed patterns."""

koerce/tests/test_patterns.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,24 @@ def test_any_of():
440440
# assert p.describe() == "an int, a str or a float"
441441

442442

443+
def test_any_all_of_operator_overloading():
444+
is_int = InstanceOf(int)
445+
is_str = InstanceOf(str)
446+
is_float = InstanceOf(float)
447+
448+
assert (is_int | is_str) == AnyOf(is_int, is_str)
449+
assert (is_int & is_str) == AllOf(is_int, is_str)
450+
assert (is_int & is_str & is_float) == AllOf(is_int, is_str, is_float)
451+
assert (is_int | is_str | is_float) == AnyOf(is_int, is_str, is_float)
452+
assert (is_int | is_str & is_float) == AnyOf(is_int, AllOf(is_str, is_float))
453+
assert ((is_int | is_str) | (is_float | is_int)) == AnyOf(
454+
is_int, is_str, is_float, is_int
455+
)
456+
assert ((is_int & is_str) & (is_float & is_int)) == AllOf(
457+
is_int, is_str, is_float, is_int
458+
)
459+
460+
443461
def test_all_of():
444462
def negative(_):
445463
return _ < 0
@@ -1117,6 +1135,12 @@ def test_pattern_sequence_with_nested_some_of():
11171135
{"a": 1, "b": 2},
11181136
{"a": 1, "b": 2},
11191137
),
1138+
(AnyOf(InstanceOf(str)) | InstanceOf(int), 7, 7),
1139+
(AllOf(InstanceOf(int)) & InstanceOf(int), 7, 7),
1140+
(InstanceOf(int) | AnyOf(InstanceOf(str)), 7, 7),
1141+
(InstanceOf(int) & AllOf(InstanceOf(int)), 7, 7),
1142+
(AnyOf(InstanceOf(str)) | AnyOf(InstanceOf(int)), 7, 7),
1143+
(AllOf(InstanceOf(int)) & AllOf(InstanceOf(int)), 7, 7),
11201144
],
11211145
)
11221146
def test_various_patterns(pattern, value, expected):
@@ -1451,7 +1475,9 @@ class OtherClass(metaclass=OtherMeta): ...
14511475
my_other_instance = OtherClass()
14521476

14531477
assert InstanceOf(Class).apply(my_instance, context={}) == my_instance
1454-
assert InstanceOf(OtherClass).apply(my_other_instance, context={}) == my_other_instance
1478+
assert (
1479+
InstanceOf(OtherClass).apply(my_other_instance, context={}) == my_other_instance
1480+
)
14551481

14561482
assert InstanceOf(Class).apply(my_other_instance, context={}) == NoMatch
14571483
assert InstanceOf(OtherClass).apply(my_instance, context={}) == NoMatch

0 commit comments

Comments
 (0)