Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(patterns): support As/Is annotations like Is[int] and As[int] #28

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,19 +383,30 @@ assert match(As(MyNumber[float]), 8).value == 8.0
first argument to a pattern using the `koerce.pattern()` function:

```py
from koerce import pattern
from koerce import pattern, As, Is

assert pattern(int, allow_coercion=False) == Is(int)
assert pattern(int, allow_coercion=True) == As(int)

assert match(int, 1, allow_coercion=False) == 1
assert match(int, 1.1, allow_coercion=False) is NoMatch
assert match(int, 1.1, allow_coercion=True) == 1
# lossy coercion is not allowed
assert match(int, 1.1, allow_coercion=True) is NoMatch

# default is allow_coercion=False
assert match(int, 1.1) is NoMatch
```

`As[typehint]` and `Is[typehint]` can be used to create patterns:

```py
from koerce import Pattern, As, Is

assert match(As[int], '1') == 1
assert match(Is[int], 1) == 1
assert match(Is[int], '1') is NoMatch
```

### `If` patterns for conditionals

Allows conditional matching based on the value of the object,
Expand Down
9 changes: 9 additions & 0 deletions koerce/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ def namespace(module):
return p, d


def replace(matcher):
"""More convenient syntax for replacing a value with the output of a function."""

def decorator(replacer):
return Replace(matcher, replacer)

return decorator


class NoMatch:
__slots__ = ()

Expand Down
149 changes: 94 additions & 55 deletions koerce/annots.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ def __init__(
if kind is _VAR_POSITIONAL:
self.pattern = TupleOf(pattern)
elif kind is _VAR_KEYWORD:
# TODO(kszucs): remove FrozenDict?
self.pattern = FrozenDictOf(_any, pattern)
else:
self.pattern = _ensure_pattern(pattern)

# validate that the default value matches the pattern
if default is not EMPTY:
# TODO(kszucs): try/except MatchError raise an error indicating that the default value doesn't match the pattern
self.default_ = self.pattern.match(default, {})
else:
self.default_ = default
Expand Down Expand Up @@ -475,7 +477,7 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):
func,
arg_patterns=patterns or kwargs,
return_pattern=return_pattern,
allow_coercion=False,
allow_coercion=True,
)
pat: Pattern = PatternMap(
{name: param.pattern for name, param in sig.parameters.items()}
Expand Down Expand Up @@ -550,12 +552,12 @@ def varkwargs(pattern=_any, typehint=EMPTY):
@cython.cclass
class AnnotableSpec:
# make them readonly
initable: cython.bint
hashable: cython.bint
immutable: cython.bint
signature: Signature
attributes: dict[str, Attribute]
hasattribs: cython.bint
initable = cython.declare(cython.bint, visibility="readonly")
hashable = cython.declare(cython.bint, visibility="readonly")
immutable = cython.declare(cython.bint, visibility="readonly")
signature = cython.declare(Signature, visibility="readonly")
attributes = cython.declare(dict[str, Attribute], visibility="readonly")
hasattribs = cython.declare(cython.bint, visibility="readonly")

def __init__(
self,
Expand Down Expand Up @@ -594,10 +596,11 @@ def new(self, cls: type, args: tuple[Any, ...], kwargs: dict[str, Any]):
this = cls.__new__(cls)
for name, param in self.signature.parameters.items():
__setattr__(this, name, param.pattern.match(bound[name], ctx))
if self.hasattribs:
self.init_attributes(this)
# TODO(kszucs): test order ot precomputes and attributes calculations
if self.hashable:
self.init_precomputes(this)
if self.hasattribs:
self.init_attributes(this)
return this

@cython.cfunc
Expand All @@ -621,48 +624,97 @@ def init_precomputes(self, this) -> cython.void:
__setattr__(this, "__precomputed_hash__", hashvalue)


class AnnotableMeta(type):
class AbstractMeta(type):
"""Base metaclass for many of the ibis core classes.

Enforce the subclasses to define a `__slots__` attribute and provide a
`__create__` classmethod to change the instantiation behavior of the class.

Support abstract methods without extending `abc.ABCMeta`. While it provides
a reduced feature set compared to `abc.ABCMeta` (no way to register virtual
subclasses) but avoids expensive instance checks by enforcing explicit
subclassing.
"""

__slots__ = ()

def __new__(metacls, clsname, bases, dct, **kwargs):
# # enforce slot definitions
# dct.setdefault("__slots__", ())

# construct the class object
cls = super().__new__(metacls, clsname, bases, dct, **kwargs)

# calculate abstract methods existing in the class
abstracts = {
name
for name, value in dct.items()
if getattr(value, "__isabstractmethod__", False)
}
for parent in bases:
for name in getattr(parent, "__abstractmethods__", set()):
value = getattr(cls, name, None)
if getattr(value, "__isabstractmethod__", False):
abstracts.add(name)

# set the abstract methods for the class
cls.__abstractmethods__ = frozenset(abstracts)

return cls


class AnnotableMeta(AbstractMeta):
def __new__(
metacls,
clsname,
bases,
dct,
initable=None,
hashable=False,
immutable=False,
hashable=None,
immutable=None,
allow_coercion=True,
**kwargs,
):
traits = []
if initable is None:
# this flag is handled in AnnotableSpec
initable = "__init__" in dct or "__new__" in dct
if hashable:
if not immutable:
raise ValueError("Only immutable classes can be hashable")
traits.append(Hashable)
if immutable:
traits.append(Immutable)

# inherit signature from parent classes
abstracts: set = set()
# inherit annotable specifications from parent classes
spec: AnnotableSpec
signatures: list = []
attributes: dict[str, Attribute] = {}
is_initable: cython.bint
is_hashable: cython.bint = hashable is True
is_immutable: cython.bint = immutable is True
if initable is None:
is_initable = "__init__" in dct or "__new__" in dct
else:
is_initable = initable
for parent in bases:
try: # noqa: SIM105
signatures.append(parent.__signature__)
except AttributeError:
pass
try: # noqa: SIM105
attributes.update(parent.__attributes__)
spec = parent.__spec__
except AttributeError:
pass
try: # noqa: SIM105
abstracts.update(parent.__abstractmethods__)
except AttributeError:
pass
continue
is_initable |= spec.initable
is_hashable |= spec.hashable
is_immutable |= spec.immutable
signatures.append(spec.signature)
attributes.update(spec.attributes)

# create the base classes for the new class
traits: list[type] = []
if is_immutable and immutable is False:
raise TypeError(
"One of the base classes is immutable so the child class cannot be mutable"
)
if is_hashable and hashable is False:
raise TypeError(
"One of the base classes is hashable so this child class must be hashable"
)
if is_hashable and not is_immutable:
raise TypeError("Only immutable classes can be hashable")
if hashable:
traits.append(Hashable)
if immutable:
traits.append(Immutable)

# collection type annotations and convert them to patterns
# collect type annotations and convert them to patterns
slots: list[str] = list(dct.pop("__slots__", []))
module: str | None = dct.pop("__module__", None)
qualname: str = dct.pop("__qualname__", clsname)
Expand All @@ -688,7 +740,6 @@ def __new__(

namespace: dict[str, Any] = {}
parameters: dict[str, Parameter] = {}
abstractmethods: set = set()
for name, value in dct.items():
if isinstance(value, Parameter):
parameters[name] = value
Expand All @@ -697,18 +748,16 @@ def __new__(
attributes[name] = value
slots.append(name)
else:
if getattr(value, "__isabstractmethod__", False):
abstractmethods.add(name)
namespace[name] = value

# merge the annotations with the parent annotations
signature = Signature.merge(signatures, parameters)
argnames = tuple(signature.parameters.keys())
bases = tuple(traits) + bases
spec = AnnotableSpec(
initable=initable,
hashable=hashable,
immutable=immutable,
initable=is_initable,
hashable=is_hashable,
immutable=is_immutable,
signature=signature,
attributes=attributes,
)
Expand All @@ -722,17 +771,7 @@ def __new__(
__slots__=tuple(slots),
__spec__=spec,
)
klass = super().__new__(metacls, clsname, bases, namespace, **kwargs)

# check whether the inherited abstract methods are implemented by
# any of the parent classes, basically recalculating the abstractmethods
for name in abstracts:
value = getattr(klass, name, None)
if getattr(value, "__isabstractmethod__", False):
abstractmethods.add(name)
klass.__abstractmethods__ = frozenset(abstractmethods)

return klass
return super().__new__(metacls, clsname, bases, namespace, **kwargs)

def __call__(cls, *args, **kwargs):
spec: AnnotableSpec = cython.cast(AnnotableSpec, cls.__spec__)
Expand Down Expand Up @@ -781,10 +820,10 @@ def __init__(self, **kwargs):
spec: AnnotableSpec = self.__spec__
for name, value in kwargs.items():
__setattr__(self, name, value)
if spec.hasattribs:
spec.init_attributes(self)
if spec.hashable:
spec.init_precomputes(self)
if spec.hasattribs:
spec.init_attributes(self)

def __setattr__(self, name, value) -> None:
spec: AnnotableSpec = self.__spec__
Expand Down
13 changes: 12 additions & 1 deletion koerce/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def __getitem__(self, name):
def __call__(self, *args, **kwargs):
return Deferred(Call(self, *args, **kwargs))

# def __contains__(self, item):
# return Deferred(Binop(operator.contains, self, item))

def __invert__(self) -> Deferred:
return Deferred(Unop(operator.invert, self))

Expand Down Expand Up @@ -149,6 +152,15 @@ def __rxor__(self, other: Any) -> Deferred:

@cython.cclass
class Builder:
@staticmethod
def __coerce__(value) -> Builder:
if isinstance(value, Builder):
return value
elif isinstance(value, Deferred):
return cython.cast(Deferred, value)._builder
else:
raise ValueError(f"Cannot coerce {type(value).__name__!r} to Builder")

def apply(self, ctx: Context):
return self.build(ctx)

Expand Down Expand Up @@ -225,7 +237,6 @@ def build(self, ctx: Context):
return self.value


@cython.final
@cython.cclass
class Var(Builder):
"""Retrieve a value from the context.
Expand Down
Loading
Loading