diff --git a/koerce/annots.py b/koerce/annots.py index a9a99f1..c7f907b 100644 --- a/koerce/annots.py +++ b/koerce/annots.py @@ -19,7 +19,7 @@ _any, pattern, ) -from .utils import get_type_hints, get_type_origin +from .utils import PseudoHashable, get_type_hints, get_type_origin EMPTY = inspect.Parameter.empty _ensure_pattern = pattern @@ -38,6 +38,9 @@ def __init__(self, pattern: Any = _any, default: Any = EMPTY): def __repr__(self): return f"<{self.__class__.__name__} pattern={self.pattern!r} default={self.default_!r}>" + def __hash__(self) -> int: + return hash((self.__class__, self.pattern, PseudoHashable(self.default_))) + def __eq__(self, other: Any) -> bool: if not isinstance(other, Attribute): return NotImplemented @@ -124,6 +127,17 @@ def __eq__(self, other: Any) -> bool: and self.typehint == right.typehint ) + def __hash__(self) -> int: + return hash( + ( + self.__class__, + self.kind, + self.pattern, + PseudoHashable(self.default_), + self.typehint, + ) + ) + @cython.final @cython.cclass @@ -266,6 +280,16 @@ def __eq__(self, other: Any) -> bool: and self.return_typehint == right.return_typehint ) + def __hash__(self) -> int: + return hash( + ( + self.__class__, + PseudoHashable(self.parameters), + self.return_pattern, + self.return_typehint, + ) + ) + def __call__(self, /, *args, **kwargs): return self.bind(args, kwargs) diff --git a/koerce/builders.py b/koerce/builders.py index 35b2b83..af57a27 100644 --- a/koerce/builders.py +++ b/koerce/builders.py @@ -4,10 +4,12 @@ import functools import inspect import operator -from typing import Any +from typing import Any, Optional import cython +from .utils import PseudoHashable + Context = dict[str, Any] @@ -64,9 +66,6 @@ def __getitem__(self, name): def __call__(self, *args, **kwargs): return Deferred(Call(self, *args, **kwargs)) - # def __contains__(self, item): - # return Deferred(Binop(operator.contains, self, item)) - def __invert__(self) -> Deferred: return Deferred(Unop(operator.invert, self)) @@ -187,6 +186,12 @@ def build(self, ctx: Context): ... def __eq__(self, other: Any) -> bool: return type(self) is type(other) and self.equals(other) + def __hash__(self): + return self._hash() + + def __repr__(self): + raise NotImplementedError(f"{self.__class__.__name__} is not reprable") + def _deferred_repr(obj): try: @@ -223,6 +228,9 @@ def __init__(self, func: Any): def __repr__(self): return _deferred_repr(self.func) + def _hash(self): + return hash((self.__class__, self.func)) + def equals(self, other: Func) -> bool: return self.func == other.func @@ -247,12 +255,17 @@ class Just(Builder): def __init__(self, value: Any): if isinstance(value, Just): self.value = cython.cast(Just, value).value + elif isinstance(value, (Builder, Deferred)): + raise TypeError(f"`{value}` cannot be used as a Just value") else: self.value = value def __repr__(self): return _deferred_repr(self.value) + def _hash(self): + return hash((self.__class__, PseudoHashable(self.value))) + def equals(self, other: Just) -> bool: return self.value == other.value @@ -279,6 +292,9 @@ def __init__(self, name: str): def __repr__(self): return f"${self.name}" + def _hash(self): + return hash((self.__class__, self.name)) + def equals(self, other: Var) -> bool: return self.name == other.name @@ -332,6 +348,9 @@ def __init__(self, func): def __repr__(self): return f"{self.func!r}()" + def _hash(self): + return hash((self.__class__, self.func)) + def equals(self, other: Call0) -> bool: return self.func == other.func @@ -364,6 +383,9 @@ def __init__(self, func, arg): def __repr__(self): return f"{self.func!r}({self.arg!r})" + def _hash(self): + return hash((self.__class__, self.func, self.arg)) + def equals(self, other: Call1) -> bool: return self.func == other.func and self.arg == other.arg @@ -401,6 +423,9 @@ def __init__(self, func, arg1, arg2): def __repr__(self): return f"{self.func!r}({self.arg1!r}, {self.arg2!r})" + def _hash(self): + return hash((self.__class__, self.func, self.arg1, self.arg2)) + def equals(self, other: Call2) -> bool: return ( self.func == other.func @@ -447,6 +472,9 @@ def __init__(self, func, arg1, arg2, arg3): def __repr__(self): return f"{self.func!r}({self.arg1!r}, {self.arg2!r}, {self.arg3!r})" + def _hash(self): + return hash((self.__class__, self.func, self.arg1, self.arg2, self.arg3)) + def equals(self, other: Call3) -> bool: return ( self.func == other.func @@ -482,12 +510,12 @@ class CallN(Builder): """ func: Builder - args: list[Builder] + args: tuple[Builder, ...] kwargs: dict[str, Builder] def __init__(self, func, *args, **kwargs): self.func = builder(func) - self.args = [builder(arg) for arg in args] + self.args = tuple(builder(arg) for arg in args) self.kwargs = {k: builder(v) for k, v in kwargs.items()} def __repr__(self): @@ -502,6 +530,9 @@ def __repr__(self): else: return f"{self.func!r}()" + def _hash(self): + return hash((self.__class__, self.func, self.args, PseudoHashable(self.kwargs))) + def equals(self, other: CallN) -> bool: return ( self.func == other.func @@ -573,6 +604,9 @@ def __repr__(self): symbol = _operator_symbols[self.op] return f"{symbol}{self.arg!r}" + def _hash(self): + return hash((self.__class__, self.op, self.arg)) + def equals(self, other: Unop) -> bool: return self.op == other.op and self.arg == other.arg @@ -610,6 +644,9 @@ def __repr__(self): symbol = _operator_symbols[self.op] return f"({self.arg1!r} {symbol} {self.arg2!r})" + def _hash(self): + return hash((self.__class__, self.op, self.arg1, self.arg2)) + def equals(self, other: Binop) -> bool: return ( self.op == other.op and self.arg1 == other.arg1 and self.arg2 == other.arg2 @@ -645,6 +682,9 @@ def __init__(self, obj, key): def __repr__(self): return f"{self.obj!r}[{self.key!r}]" + def _hash(self): + return hash((self.__class__, self.obj, self.key)) + def equals(self, other: Item) -> bool: return self.obj == other.obj and self.key == other.key @@ -678,6 +718,9 @@ def __init__(self, obj: Any, attr: str): def __repr__(self): return f"{self.obj!r}.{self.attr}" + def _hash(self): + return hash((self.__class__, self.obj, self.attr)) + def equals(self, other: Attr) -> bool: return self.obj == other.obj and self.attr == other.attr @@ -699,11 +742,11 @@ class Seq(Builder): """ type_: Any - items: list[Builder] + items: tuple[Builder, ...] def __init__(self, items): self.type_ = type(items) - self.items = [builder(item) for item in items] + self.items = tuple(builder(item) for item in items) def __repr__(self): elems = ", ".join(map(repr, self.items)) @@ -714,6 +757,9 @@ def __repr__(self): else: return f"{self.type_.__name__}({elems})" + def _hash(self): + return hash((self.__class__, self.type_, self.items)) + def equals(self, other: Seq) -> bool: return self.type_ == other.type_ and self.items == other.items @@ -751,6 +797,9 @@ def __repr__(self): else: return f"{self.type_.__name__}({{{items}}})" + def _hash(self): + return hash((self.__class__, self.type_, PseudoHashable(self.items))) + def equals(self, other: Map) -> bool: return self.type_ == other.type_ and self.items == other.items diff --git a/koerce/patterns.py b/koerce/patterns.py index fd199e1..0ae8958 100644 --- a/koerce/patterns.py +++ b/koerce/patterns.py @@ -24,6 +24,7 @@ # TODO(kszucs): would be nice to cimport Signature and Builder from .builders import Builder, Deferred, Var, builder from .utils import ( + PseudoHashable, RewindableIterator, frozendict, get_type_args, @@ -206,7 +207,10 @@ def match(self, value, ctx: Context): ... def describe(self, value, reason) -> str: ... def __repr__(self) -> str: - return f"{self.__class__.__name__}()" + raise NotImplementedError(f"{self.__class__.__name__} is not reprable") + + def __hash__(self) -> int: + return self._hash() def __eq__(self, other) -> bool: return type(self) is type(other) and self.equals(other) @@ -286,6 +290,12 @@ def __iter__(self) -> SomeOf: @cython.final @cython.cclass class Anything(Pattern): + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + def _hash(self) -> int: + return hash(self.__class__) + def equals(self, other: Anything) -> bool: return True @@ -304,6 +314,12 @@ def match(self, value, ctx: Context): @cython.final @cython.cclass class Nothing(Pattern): + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + def _hash(self) -> int: + return hash(self.__class__) + def equals(self, other: Nothing) -> bool: return True @@ -327,6 +343,9 @@ def __init__(self, value): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value!r})" + def _hash(self) -> int: + return hash((self.__class__, self.value)) + def equals(self, other: IdenticalTo) -> bool: return self.value == other.value @@ -361,6 +380,9 @@ def __init__(self, value: Any): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value!r})" + def _hash(self) -> int: + return hash((self.__class__, PseudoHashable(self.value))) + def equals(self, other: EqValue) -> bool: return self.value == other.value @@ -396,6 +418,9 @@ def __init__(self, value): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value!r})" + def _hash(self) -> int: + return hash((self.__class__, self.value)) + def equals(self, other: EqDeferred) -> bool: return self.value == other.value @@ -425,6 +450,9 @@ def __init__(self, type_): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.type_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.type_)) + def equals(self, other: TypeOf) -> bool: return self.type_ == other.type_ @@ -461,6 +489,9 @@ def __init__(self, type_: Any): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.type_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.type_)) + def __call__(self, *args, **kwargs): return ObjectOf(self.type_, args, kwargs) @@ -503,6 +534,9 @@ def __init__(self, qualname: str): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.qualname!r})" + def _hash(self) -> int: + return hash((self.__class__, self.qualname)) + def __call__(self, *args, **kwargs): return ObjectOf(self, args, kwargs) @@ -577,6 +611,9 @@ def __repr__(self) -> str: def __call__(self, *args, **kwargs): return ObjectOf(self, args, kwargs) + def _hash(self) -> int: + return hash((self.__class__, self.origin, self.name1, self.pattern1)) + def equals(self, other: IsGeneric1) -> bool: return ( self.origin == other.origin @@ -624,6 +661,18 @@ def __repr__(self) -> str: def __call__(self, *args, **kwargs): return ObjectOf(self, *args, **kwargs) + def _hash(self) -> int: + return hash( + ( + self.__class__, + self.origin, + self.name1, + self.name2, + self.pattern1, + self.pattern2, + ) + ) + def describe(self, value, reason) -> str: return f"{value!r} is not an instance of {self.origin!r}" @@ -670,6 +719,9 @@ def __repr__(self) -> str: def __call__(self, *args, **kwargs): return ObjectOf(self, args, kwargs) + def _hash(self) -> int: + return hash((self.__class__, self.origin, PseudoHashable(self.fields))) + def equals(self, other: IsGenericN) -> bool: return self.origin == other.origin and self.fields == other.fields @@ -698,6 +750,9 @@ def __init__(self, type_: Any): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.type_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.type_)) + def equals(self, other: SubclassOf) -> bool: return self.type_ == other.type_ @@ -721,6 +776,12 @@ def __new__(cls, type_) -> Self: @cython.final @cython.cclass class AsBool(Pattern): + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + def _hash(self) -> int: + return hash(self.__class__) + def equals(self, other: AsBool) -> bool: return True @@ -752,6 +813,12 @@ def match(self, value, ctx: Context): @cython.final @cython.cclass class AsInt(Pattern): + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + def _hash(self) -> int: + return hash(self.__class__) + def equals(self, other: AsInt) -> bool: return True @@ -819,6 +886,9 @@ def lookup(cls, type_: Any): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.type_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.type_)) + def equals(self, other: AsType) -> bool: return self.type_ == other.type_ and self.func == other.func @@ -846,6 +916,9 @@ def __init__(self, type_: Any): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.type_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.type_)) + def equals(self, other: AsBuiltin) -> bool: return self.type_ == other.type_ @@ -877,6 +950,9 @@ def __init__(self, type_: Any): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.type_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.type_)) + def __call__(self, *args, **kwargs): return ObjectOf(self, args, kwargs) @@ -929,6 +1005,9 @@ def __init__(self, typ): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.origin!r}, params={self.params!r})" + def _hash(self) -> int: + return hash((self.__class__, self.origin, PseudoHashable(self.params))) + def __call__(self, *args: Any, **kwds: Any) -> Any: return ObjectOf(self, args, kwds) @@ -964,6 +1043,9 @@ def __init__(self, inner, **options): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.inner!r})" + def _hash(self) -> int: + return hash((self.__class__, self.inner)) + def equals(self, other: Not) -> bool: return self.inner == other.inner @@ -983,14 +1065,17 @@ def match(self, value, ctx: Context): @cython.final @cython.cclass class AnyOf(Pattern): - inners: list[Pattern] + inners: tuple[Pattern] def __init__(self, *inners: Pattern, **options): - self.inners = [pattern(inner, **options) for inner in inners] + self.inners = tuple(pattern(inner, **options) for inner in inners) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.inners!r})" + def _hash(self) -> int: + return hash((self.__class__, self.inners)) + def equals(self, other: AnyOf) -> bool: return self.inners == other.inners @@ -1028,14 +1113,17 @@ def __or__(self, other: Pattern) -> AnyOf: @cython.final @cython.cclass class AllOf(Pattern): - inners: list[Pattern] + inners: tuple[Pattern] def __init__(self, *inners: Pattern, **options): - self.inners = [pattern(inner, **options) for inner in inners] + self.inners = tuple(pattern(inner, **options) for inner in inners) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.inners!r})" + def _hash(self) -> int: + return hash((self.__class__, self.inners)) + def equals(self, other: AllOf) -> bool: return self.inners == other.inners @@ -1093,6 +1181,9 @@ def __init__(self, pat, default=None, **options): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.pattern!r}, default={self.default!r})" + def _hash(self) -> int: + return hash((self.__class__, self.pattern, PseudoHashable(self.default))) + def equals(self, other: Option) -> bool: return self.pattern == other.pattern and self.default == other.default @@ -1134,6 +1225,9 @@ def __init__(self, predicate): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.predicate!r})" + def _hash(self) -> int: + return hash((self.__class__, self.predicate)) + def equals(self, other: IfFunction) -> bool: return self.predicate == other.predicate @@ -1168,6 +1262,9 @@ def __init__(self, obj): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.builder!r})" + def _hash(self) -> int: + return hash((self.__class__, self.builder)) + def equals(self, other: IfDeferred) -> bool: return self.builder == other.builder @@ -1205,6 +1302,9 @@ def __init__(self, haystack): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.haystack})" + def _hash(self) -> int: + return hash((self.__class__, self.haystack)) + def equals(self, other: IsIn) -> bool: return self.haystack == other.haystack @@ -1247,6 +1347,9 @@ def __init__(self, item: Any, type_: Any = list, **options): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.item!r}, type_={self.type_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.item, self.type_)) + def equals(self, other: SequenceOf) -> bool: return self.item == other.item and self.type_ == other.type_ @@ -1313,6 +1416,9 @@ def __repr__(self) -> str: f"{self.__class__.__name__}({self.key!r}, {self.value!r}, {self.type_!r})" ) + def _hash(self) -> int: + return hash((self.__class__, self.key, self.value, self.type_)) + def equals(self, other: MappingOf) -> bool: return ( self.key == other.key @@ -1365,6 +1471,9 @@ def __init__(self, func): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.func!r})" + def _hash(self) -> int: + return hash((self.__class__, self.func)) + def equals(self, other: Custom) -> bool: return self.func == other.func @@ -1406,6 +1515,9 @@ def __init__(self, key: Any, what=_any, **options): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.key!r}, {self.what!r})" + def _hash(self) -> int: + return hash((self.__class__, self.key, self.what)) + def equals(self, other: Capture) -> bool: return self.key == other.key and self.what == other.what @@ -1440,6 +1552,9 @@ def __init__(self, searcher, replacer, **options): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.searcher!r}, {self.replacer!r})" + def _hash(self) -> int: + return hash((self.__class__, self.searcher, self.replacer)) + @cython.cfunc def match(self, value, ctx: Context): value = self.searcher.match(value, ctx) @@ -1503,7 +1618,13 @@ def __init__(self, type_: Any, fields, **options): self.pattern1 = pattern(pattern1, **options) def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.type_!r}, {self.field1!r}={self.pattern1!r})" + return ( + f"{self.__class__.__name__}({self.type_!r}, " + f"{self.field1!r}={self.pattern1!r})" + ) + + def _hash(self) -> int: + return hash((self.__class__, self.type_, self.field1, self.pattern1)) def equals(self, other: ObjectOf1) -> bool: return ( @@ -1547,7 +1668,23 @@ def __init__(self, type_: Any, fields, **options): self.pattern2 = pattern(pattern2, **options) def __repr__(self) -> str: - return f"ObjectOf2({self.type_!r}, {self.field1!r}={self.pattern1!r}, {self.field2!r}={self.pattern2!r})" + return ( + f"ObjectOf2({self.type_!r}, " + f"{self.field1!r}={self.pattern1!r}, " + f"{self.field2!r}={self.pattern2!r})" + ) + + def _hash(self) -> int: + return hash( + ( + self.__class__, + self.type_, + self.field1, + self.field2, + self.pattern1, + self.pattern2, + ) + ) def equals(self, other: ObjectOf2) -> bool: return ( @@ -1608,6 +1745,20 @@ def __repr__(self) -> str: f"{self.field3!r}={self.pattern3!r})" ) + def _hash(self) -> int: + return hash( + ( + self.__class__, + self.type_, + self.field1, + self.field2, + self.field3, + self.pattern1, + self.pattern2, + self.pattern3, + ) + ) + def equals(self, other: ObjectOf3) -> bool: return ( self.type_ == other.type_ @@ -1676,6 +1827,9 @@ def __init__(self, type_: Any, fields, **options): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.type_!r}, {self.fields!r})" + def _hash(self) -> int: + return hash((self.__class__, self.type_, PseudoHashable(self.fields))) + def equals(self, other: ObjectOfN) -> bool: return self.type_ == other.type_ and self.fields == other.fields @@ -1705,12 +1859,12 @@ def match(self, value, ctx: Context): @cython.cclass class ObjectOfX(Pattern): type_: Pattern - args: list[Pattern] + args: tuple[Pattern, ...] kwargs: dict[str, Pattern] def __init__(self, type_, args, kwargs, **options): self.type_ = pattern(type_, **options) - self.args = [pattern(arg, **options) for arg in args] + self.args = tuple(pattern(arg, **options) for arg in args) self.kwargs = {k: pattern(v, **options) for k, v in kwargs.items()} def __repr__(self) -> str: @@ -1718,6 +1872,11 @@ def __repr__(self) -> str: f"{self.__class__.__name__}({self.type_!r}, {self.args!r}, {self.kwargs!r})" ) + def _hash(self) -> int: + return hash( + (self.__class__, self.type_, self.args, PseudoHashable(self.kwargs)) + ) + def equals(self, other: ObjectOfX) -> bool: return ( self.type_ == self.type_ @@ -1767,16 +1926,19 @@ def match(self, value, ctx: Context): @cython.final @cython.cclass class CallableWith(Pattern): - args: list[Pattern] + args: tuple[Pattern, ...] return_: Pattern def __init__(self, args, return_=_any, **options): - self.args = [pattern(arg, **options) for arg in args] + self.args = tuple(pattern(arg, **options) for arg in args) self.return_ = pattern(return_, **options) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.args!r}, return_={self.return_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.args, self.return_)) + def equals(self, other: CallableWith) -> bool: return self.args == other.args and self.return_ == other.return_ @@ -1871,6 +2033,9 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}(at_least={self.at_least}, at_most={self.at_most})" + def _hash(self) -> int: + return hash((self.__class__, self.at_least, self.at_most)) + def equals(self, other: Length) -> bool: return self.at_least == other.at_least and self.at_most == other.at_most @@ -1916,6 +2081,9 @@ def __init__(self, item, type_=list, **kwargs): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.pattern!r})" + def _hash(self) -> int: + return hash((self.__class__, self.pattern, self.delimiter, self.length)) + def equals(self, other: SomeItemsOf) -> bool: return self.pattern == other.pattern @@ -1946,6 +2114,9 @@ def __init__(self, *args, type_=list, **kwargs): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.pattern!r}, {self.delimiter!r})" + def _hash(self) -> int: + return hash((self.__class__, self.pattern, self.delimiter, self.length)) + def equals(self, other: SomeChunksOf) -> bool: return self.pattern == other.pattern and self.delimiter == other.delimiter @@ -2004,15 +2175,18 @@ class FixedPatternList(Pattern): """ type_: type - patterns: list[Pattern] + patterns: tuple[Pattern, ...] def __init__(self, patterns, type_=list, **options): self.type_ = type_ - self.patterns = [pattern(p, **options) for p in patterns] + self.patterns = tuple(pattern(p, **options) for p in patterns) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.patterns!r}, type_={self.type_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.type_, self.patterns)) + def equals(self, other: FixedPatternList) -> bool: return self.patterns == other.patterns and self.type_ == other.type_ @@ -2056,15 +2230,18 @@ def match(self, values, ctx: Context): @cython.cclass class VariadicPatternList(Pattern): type_: type - patterns: list[Pattern] + patterns: tuple[Pattern, ...] def __init__(self, patterns, type_=list, **options): self.type_ = type_ - self.patterns = [pattern(p, **options) for p in patterns] + self.patterns = tuple(pattern(p, **options) for p in patterns) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.patterns!r}, {self.type_!r})" + def _hash(self) -> int: + return hash((self.__class__, self.type_, self.pattern)) + def equals(self, other: VariadicPatternList) -> bool: return self.patterns == other.patterns and self.type_ == other.type_ @@ -2089,7 +2266,7 @@ def match(self, value, ctx: Context): current: Pattern original: Pattern following: Pattern - following_patterns = self.patterns[1:] + [Nothing()] + following_patterns = self.patterns[1:] + (Nothing(),) for current, following in zip(self.patterns, following_patterns): original = current current = _maybe_unwrap_capture(current) @@ -2155,6 +2332,9 @@ def __init__(self, fields, **options): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.field1!r}={self.pattern1!r})" + def _hash(self) -> int: + return hash((self.__class__, self.field1, self.pattern1)) + def equals(self, other: PatternMap1) -> bool: return self.field1 == other.field1 and self.pattern1 == other.pattern1 @@ -2205,6 +2385,11 @@ def __repr__(self) -> str: f"{self.field2!r}={self.pattern2!r})" ) + def _hash(self) -> int: + return hash( + (self.__class__, self.field1, self.field2, self.pattern1, self.pattern2) + ) + def equals(self, other: PatternMap2) -> bool: return ( self.field1 == other.field1 @@ -2265,6 +2450,19 @@ def __repr__(self) -> str: f"{self.field3!r}={self.pattern3!r})" ) + def _hash(self) -> int: + return hash( + ( + self.__class__, + self.field1, + self.field2, + self.field3, + self.pattern1, + self.pattern2, + self.pattern3, + ) + ) + def equals(self, other: PatternMap3) -> bool: return ( self.field1 == other.field1 @@ -2321,6 +2519,9 @@ def __init__(self, fields, **options): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.fields!r})" + def _hash(self) -> int: + return hash((self.__class__, PseudoHashable(self.fields))) + def equals(self, other: PatternMapN) -> bool: return self.fields == other.fields diff --git a/koerce/tests/test_annots.py b/koerce/tests/test_annots.py index 49323c1..6d4d8ec 100644 --- a/koerce/tests/test_annots.py +++ b/koerce/tests/test_annots.py @@ -28,6 +28,7 @@ AnnotableMeta, Anything, As, + Attribute, FrozenDictOf, Hashable, Immutable, @@ -53,12 +54,14 @@ def test_parameter(): p = Parameter(Parameter.POSITIONAL_OR_KEYWORD, typehint=int) assert p.kind is Parameter.POSITIONAL_OR_KEYWORD assert p.format("x") == "x: int" + assert hash(p) == hash(p) p = Parameter(Parameter.POSITIONAL_OR_KEYWORD, default=1) assert p.kind is Parameter.POSITIONAL_OR_KEYWORD assert p.default_ == 1 assert p.format("x") == "x=1" assert p.pattern == Anything() + assert hash(p) == hash(p) p = Parameter( Parameter.POSITIONAL_OR_KEYWORD, typehint=int, default=1, pattern=is_int @@ -68,18 +71,45 @@ def test_parameter(): assert p.typehint is int assert p.format("x") == "x: int = 1" assert p.pattern == is_int + assert hash(p) == hash(p) p = Parameter(Parameter.VAR_POSITIONAL, typehint=int, pattern=is_int) assert p.kind is Parameter.VAR_POSITIONAL assert p.typehint is int assert p.format("y") == "*y: int" assert p.pattern == TupleOf(is_int) + assert hash(p) == hash(p) p = Parameter(Parameter.VAR_KEYWORD, typehint=int, pattern=is_int) assert p.kind is Parameter.VAR_KEYWORD assert p.typehint is int assert p.format("z") == "**z: int" assert p.pattern == FrozenDictOf(Anything(), is_int) + assert hash(p) == hash(p) + + p = Parameter(Parameter.VAR_KEYWORD, typehint=int, pattern=is_int, default={}) + assert p.kind is Parameter.VAR_KEYWORD + assert p.typehint is int + assert p.format("z") == "**z: int = {}" + assert p.pattern == FrozenDictOf(Anything(), is_int) + assert hash(p) == hash(p) + + +def test_attribute(): + a = Attribute(int, default=1) + assert a.pattern == Is(int) + assert a.default_ == 1 + assert hash(a) == hash(a) + + a = Attribute(dict, default={}) + assert a.pattern == Is(dict) + assert a.default_ == {} + assert hash(a) == hash(a) + + a = Attribute(dict[str, float], default={"a": 1.0}) + assert a.pattern == Is(dict[str, float]) + assert a.default_ == {"a": 1.0} + assert hash(a) == hash(a) def test_signature_contruction(): @@ -92,6 +122,7 @@ def test_signature_contruction(): assert sig.parameters == {"a": a, "b": b, "c": c, "d": d} assert sig.return_typehint is EMPTY assert sig.return_pattern == Anything() + assert hash(sig) == hash(sig) def test_signature_equality_comparison(): @@ -103,9 +134,11 @@ def test_signature_equality_comparison(): sig1 = Signature({"a": a, "b": b, "c": c}) sig2 = Signature({"a": a, "b": b, "c": c}) assert sig1 == sig2 + assert hash(sig1) == hash(sig2) sig3 = Signature({"a": a, "c": c, "b": b}) assert sig1 != sig3 + assert hash(sig1) != hash(sig3) def test_signature_from_callable(): diff --git a/koerce/tests/test_builders.py b/koerce/tests/test_builders.py index 552deb4..107bb39 100644 --- a/koerce/tests/test_builders.py +++ b/koerce/tests/test_builders.py @@ -57,18 +57,18 @@ def test_builder_just(): # unwrap subsequently nested Just instances assert Just(p) == p + assert hash(p) == hash(p) # disallow creating a Just builder from other builders or deferreds - # with pytest.raises(TypeError, match="cannot be used as a Just value"): - # Just(_) - # with pytest.raises(TypeError, match="cannot be used as a Just value"): - # Just(Factory(lambda _: _)) + with pytest.raises(TypeError, match="cannot be used as a Just value"): + Just(_) def test_builder_Var(): p = Var("other") context = {"other": 10} assert p.apply(context) == 10 + assert hash(p) == hash(p) def test_builder_func(): @@ -82,6 +82,7 @@ def fn(**kwargs): f = Func(fn) assert f.apply({"_": 10, "a": 5}) == -1 + assert hash(f) == hash(f) def test_builder_call(): @@ -93,19 +94,24 @@ def func(a, b, c=1): c = Call0(Just(fn)) assert c.apply({}) == () + assert hash(c) == hash(c) c = Call1(Just(fn), Just(1)) assert c.apply({}) == (1,) + assert hash(c) == hash(c) c = Call2(Just(fn), Just(1), Just(2)) assert c.apply({}) == (1, 2) + assert hash(c) == hash(c) c = Call3(Just(fn), Just(1), Just(2), Just(3)) assert c.apply({}) == (1, 2, 3) + assert hash(c) == hash(c) c = Call(Just(func), Just(1), Just(2), c=Just(3)) assert isinstance(c, CallN) assert c.apply({}) == 6 + assert hash(c) == hash(c) c = Call(Just(func), Just(-1), Just(-2)) assert isinstance(c, Call2) @@ -127,6 +133,7 @@ def func(a, b, c=1): def test_builder_unop(): b = Unop(operator.neg, Just(1)) assert b.apply({}) == -1 + assert hash(b) == hash(b) b = Unop(operator.abs, Just(-1)) assert b.apply({}) == 1 @@ -135,6 +142,7 @@ def test_builder_unop(): def test_builder_binop(): b = Binop(operator.add, Just(1), Just(2)) assert b.apply({}) == 3 + assert hash(b) == hash(b) b = Binop(operator.mul, Just(2), Just(3)) assert b.apply({}) == 6 @@ -152,6 +160,7 @@ def __hash__(self): v = Var("v") b = Attr(v, "b") assert b.apply({"v": MyType(1, 2)}) == 2 + assert hash(b) == hash(b) b = Attr(Just(MyType(1, 2)), "a") assert b.apply({}) == 1 @@ -165,6 +174,7 @@ def test_builder_item(): v = Var("v") b = Item(v, Just(1)) assert b.apply({"v": [1, 2, 3]}) == 2 + assert hash(b) == hash(b) b = Item(Just(dict(a=1, b=2)), Just("a")) assert b.apply({}) == 1 @@ -178,6 +188,7 @@ def test_builder_item(): def test_builder_seq(): b = Seq([Just(1), Just(2), Just(3)]) assert b.apply({}) == [1, 2, 3] + assert hash(b) == hash(b) b = Seq((Just(1), Just(2), Just(3))) assert b.apply({}) == (1, 2, 3) @@ -186,6 +197,7 @@ def test_builder_seq(): def test_builder_map(): b = Map({"a": Just(1), "b": Just(2)}) assert b.apply({}) == {"a": 1, "b": 2} + assert hash(b) == hash(b) b = Map({"a": Just(1), "b": Just(2)}) assert b.apply({}) == {"a": 1, "b": 2} diff --git a/koerce/tests/test_patterns.py b/koerce/tests/test_patterns.py index e4f021f..ac5ae67 100644 --- a/koerce/tests/test_patterns.py +++ b/koerce/tests/test_patterns.py @@ -566,9 +566,8 @@ def test_any_of(): assert p == p1 assert p.apply(1) == 1 assert p.apply("foo") == "foo" - msg = re.escape( - "`1.0` does not match any of [IsType(), IsType()]" - ) + msg = "`1.0` does not match any of [IsType(), IsType()]" + with pytest.raises(MatchError, match=msg): p.apply(1.0) @@ -1117,15 +1116,6 @@ def test_replace_in_nested_object_pattern(): assert h1.b.b == 3 -# def test_replace_decorator(): -# @replace(int) -# def sub(_): -# return _ - 1 - -# assert match(sub, 1) == 0 -# assert match(sub, 2) == 1 - - def test_replace_using_deferred(): x = Deferred(Var("x")) y = Deferred(Var("y")) @@ -1328,7 +1318,10 @@ def test_pattern_sequence_with_nested_some_of(): ], ) def test_various_patterns(pattern, value, expected): - assert pattern.apply(value, context={}) == expected + assert pattern.apply(value) == expected + assert hash(pattern) == hash(pattern) + assert repr(pattern) == repr(pattern) + assert pattern == pattern @pytest.mark.parametrize( @@ -1432,6 +1425,9 @@ def test_pattern_from_typehint_no_coercion(annot, expected): ) def test_pattern_from_typehint_with_coercion(annot, expected): assert Pattern.from_typehint(annot, allow_coercion=True) == expected + assert hash(pattern) == hash(pattern) + assert repr(pattern) == repr(pattern) + assert pattern == pattern def test_pattern_from_annotated(): diff --git a/koerce/tests/test_utils.py b/koerce/tests/test_utils.py index 66612ef..bd00784 100644 --- a/koerce/tests/test_utils.py +++ b/koerce/tests/test_utils.py @@ -9,6 +9,7 @@ from koerce.utils import ( FrozenDict, + PseudoHashable, RewindableIterator, get_type_boundvars, get_type_hints, @@ -187,3 +188,88 @@ def test_frozendict(): assert hash(FrozenDict(a=1, b=2)) != hash(d) assert d == pickle.loads(pickle.dumps(d)) + + +def test_pseudo_hashable(): + class Unhashable: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return isinstance(other, Unhashable) and self.value == other.value + + class MyList(list): + pass + + class MyMap(dict): + pass + + for obj in [1, "a", b"a", 2.0, object(), (), frozenset()]: + assert PseudoHashable(obj) is obj + + # test unhashable sequences + lst1 = [1, 2, 3] + lst2 = [1, 2, 4] + lst3 = MyList([1, 2, 3]) + ph1 = PseudoHashable(lst1) + ph2 = PseudoHashable(lst2) + ph3 = PseudoHashable(lst3) + ph4 = PseudoHashable(lst1.copy()) + + assert hash(ph1) == hash(ph1) + assert hash(ph1) != hash(ph2) + assert hash(ph1) != hash(ph3) + assert hash(ph2) != hash(ph3) + assert hash(ph3) == hash(ph3) + assert hash(ph1) == hash(ph4) + assert ph1 == ph1 + assert ph1 != ph2 + assert ph1 == ph3 + assert ph2 != ph3 + assert ph3 == ph3 + assert ph1 == ph4 + + # test unhashable mappings + dct1 = {"a": 1, "b": 2} + dct2 = {"a": 1, "b": 3} + dct3 = MyMap({"a": 1, "b": 2}) + ph1 = PseudoHashable(dct1) + ph2 = PseudoHashable(dct2) + ph3 = PseudoHashable(dct3) + ph4 = PseudoHashable(dct1.copy()) + + assert hash(ph1) == hash(ph1) + assert hash(ph1) != hash(ph2) + assert hash(ph1) != hash(ph3) + assert hash(ph2) != hash(ph3) + assert hash(ph3) == hash(ph3) + assert hash(ph1) == hash(ph4) + assert ph1 == ph1 + assert ph1 != ph2 + assert ph1 == ph3 + assert ph2 != ph3 + assert ph3 == ph3 + assert ph1 == ph4 + + # test unhashable objects + obj1 = Unhashable(1) + obj2 = Unhashable(1) + obj3 = Unhashable(2) + obj4 = Unhashable(1) + ph1 = PseudoHashable(obj1) + ph2 = PseudoHashable(obj2) + ph3 = PseudoHashable(obj3) + ph4 = PseudoHashable(obj4) + + assert hash(ph1) == hash(ph1) + assert hash(ph1) != hash(ph2) + assert hash(ph1) != hash(ph3) + assert hash(ph2) != hash(ph3) + assert hash(ph3) == hash(ph3) + assert hash(ph1) != hash(ph4) + assert ph1 == ph1 + assert ph1 == ph2 + assert ph1 != ph3 + assert ph2 != ph3 + assert ph3 == ph3 + assert ph1 == ph4 diff --git a/koerce/utils.py b/koerce/utils.py index ebc80a2..f93f756 100644 --- a/koerce/utils.py +++ b/koerce/utils.py @@ -3,7 +3,7 @@ import itertools import sys import typing -from collections.abc import Hashable +from collections.abc import Hashable, Mapping, Sequence, Set from typing import Any, ClassVar, ForwardRef, Optional, TypeVar from typing_extensions import Self @@ -257,6 +257,40 @@ def checkpoint(self): self._iterator, self._checkpoint = itertools.tee(self._iterator) +class PseudoHashable: + """A wrapper that provides a best effort precomputed hash.""" + + __slots__ = ("obj", "hash") + + def __new__(cls, obj): + if isinstance(obj, Hashable): + return obj + else: + return super().__new__(cls) + + def __init__(self, obj): + if isinstance(obj, Sequence): + hashable_obj = tuple(obj) + elif isinstance(obj, Mapping): + hashable_obj = tuple(obj.items()) + elif isinstance(obj, Set): + hashable_obj = frozenset(obj) + else: + hashable_obj = id(obj) + + self.obj = obj + self.hash = hash((type(obj), hashable_obj)) + + def __hash__(self): + return self.hash + + def __eq__(self, other): + if isinstance(other, PseudoHashable): + return self.obj == other.obj + else: + return NotImplemented + + # def format_typehint(typ: Any) -> str: # if isinstance(typ, type): # return typ.__name__