Skip to content

Commit

Permalink
feat(builders): support deferrable functions and custom deferred repr
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Sep 11, 2024
1 parent 644eb4c commit 4385bda
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 9 deletions.
90 changes: 81 additions & 9 deletions koerce/builders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import collections.abc
import functools
import inspect
import operator
from typing import Any

Expand All @@ -20,15 +22,30 @@ class Deferred:
Its sole purpose is to provide a nicer syntax for constructing deferred
expressions, thus it gets unwrapped to the underlying deferred expression
when used by the rest of the library.
Parameters
----------
deferred
The deferred object to provide syntax sugar for.
repr
An optional fixed string to use when repr-ing the deferred expression,
instead of the default. This is useful for complex deferred expressions
where the arguments don't necessarily make sense to be user facing in
the repr.
"""

_repr: str
_builder: Builder

def __init__(self, builder: Builder):
def __init__(self, builder: Builder, repr: Optional[str] = None):
self._repr = repr
self._builder = builder

def __repr__(self):
return repr(self._builder)
if self._repr is None:
return repr(self._builder)
else:
return self._repr

def __getattr__(self, name):
return Deferred(Attr(self, name))
Expand Down Expand Up @@ -171,6 +188,16 @@ def __eq__(self, other: Any) -> bool:
return type(self) is type(other) and self.equals(other)


def _deferred_repr(obj):
try:
return obj.__deferred_repr__()
except (AttributeError, TypeError):
if callable(obj) and hasattr(obj, "__name__"):
return obj.__name__
else:
return repr(obj)


@cython.final
@cython.cclass
class Func(Builder):
Expand All @@ -194,7 +221,7 @@ def __init__(self, func: Any):
self.func = func

def __repr__(self):
return f"{self.func.__name__}(...)"
return _deferred_repr(self.func)

def equals(self, other: Func) -> bool:
return self.func == other.func
Expand Down Expand Up @@ -224,12 +251,7 @@ def __init__(self, value: Any):
self.value = value

def __repr__(self):
if hasattr(self.value, "__deferred_repr__"):
return self.value.__deferred_repr__()
elif callable(self.value):
return getattr(self.value, "__name__", repr(self.value))
else:
return repr(self.value)
return _deferred_repr(self.value)

def equals(self, other: Just) -> bool:
return self.value == other.value
Expand Down Expand Up @@ -772,3 +794,53 @@ def deferred(obj, allow_custom=False) -> Deferred:
def resolve(obj, **context):
bldr: Builder = builder(obj)
return bldr.build(context)


def _contains_deferred(obj: Any) -> bool:
if isinstance(obj, (Builder, Deferred)):
return True
elif (typ := type(obj)) in (tuple, list, set):
return any(_contains_deferred(o) for o in obj)
elif typ is dict:
return any(_contains_deferred(o) for o in obj.values())
return False


def deferrable(func=None, *, repr=None):
"""Wrap a top-level expr function to support deferred arguments.
When a deferrable function is called, the args & kwargs are traversed to
look for `Deferred` values (through builtin collections like
`list`/`tuple`/`set`/`dict`). If any `Deferred` arguments are found, then
the result is also `Deferred`. Otherwise the function is called directly.
Parameters
----------
func
A callable to make deferrable
repr
An optional fixed string to use when repr-ing the deferred expression,
instead of the usual. This is useful for complex deferred expressions
where the arguments don't necessarily make sense to be user facing
in the repr.
"""

def wrapper(func):
# Parse the signature of func so we can validate deferred calls eagerly,
# erroring for invalid/missing arguments at call time not resolve time.
sig = inspect.signature(func)

@functools.wraps(func)
def inner(*args, **kwargs):
if _contains_deferred((args, kwargs)):
# Try to bind the arguments now, raising a nice error
# immediately if the function was called incorrectly
sig.bind(*args, **kwargs)
builder = Call(func, *args, **kwargs)
return Deferred(builder, repr=repr)
return func(*args, **kwargs)

return inner # type: ignore

return wrapper if func is None else wrapper(func)
87 changes: 87 additions & 0 deletions koerce/tests/test_builders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import operator
import pickle

import pytest

Expand All @@ -23,6 +24,7 @@
Unop,
Var,
builder,
deferrable,
resolve,
)

Expand Down Expand Up @@ -202,6 +204,11 @@ def test_deferred_object_are_not_hashable():
hash(_.a)


def test_deferred_set_raises():
with pytest.raises(TypeError, match="unhashable type"):
{_.a, _.b} # noqa: B018


@pytest.mark.parametrize(
("value", "expected"),
[
Expand Down Expand Up @@ -529,3 +536,83 @@ def test_builder_coercion():
def test_resolve():
deferred = _["a"] + 1
assert resolve(deferred, _={"a": 1}) == 2


def test_deferrable(table):
@deferrable
def f(a, b, c=3):
return a + b + c

assert f(table.a, table.b) == table.a + table.b + 3
assert f(table.a, table.b, c=4) == table.a + table.b + 4

expr = f(_.a, _.b)
sol = table.a + table.b + 3
res = resolve(expr, _=table)
assert res == sol
assert repr(expr) == "f($_.a, $_.b)"

expr = f(1, 2, c=_.a)
sol = 3 + table.a
res = resolve(expr, _=table)
assert res == sol
assert repr(expr) == "f(1, 2, c=$_.a)"

with pytest.raises(TypeError, match="unknown"):
f(_.a, _.b, unknown=3) # invalid calls caught at call time


def test_deferrable_repr():
@deferrable(repr="<test>")
def myfunc(x):
return x + 1

assert repr(myfunc(_.a)) == "<test>"


@pytest.mark.parametrize(
"case",
[
pytest.param(lambda: ([1, _], [1, 2]), id="list"),
pytest.param(lambda: ((1, _), (1, 2)), id="tuple"),
pytest.param(lambda: ({"x": 1, "y": _}, {"x": 1, "y": 2}), id="dict"),
pytest.param(
lambda: ({"x": 1, "y": [_, 3]}, {"x": 1, "y": [2, 3]}), id="nested"
),
],
)
def test_deferrable_nested_args(case):
arg, sol = case()

@deferrable
def identity(x):
return x

expr = identity(arg)
assert resolve(expr, _=2) == sol
assert identity(sol) is sol
assert repr(expr) == f"identity({arg!r})"


@pytest.mark.parametrize(
"func",
[
pytest.param(lambda _: _, id="root"),
pytest.param(lambda _: _.a, id="getattr"),
pytest.param(lambda _: _["a"], id="getitem"),
pytest.param(lambda _: _.a.log(), id="method"),
pytest.param(lambda _: _.a.log(_.b), id="method-with-args"),
pytest.param(lambda _: _.a.log(base=_.b), id="method-with-kwargs"),
pytest.param(lambda _: _.a + _.b, id="binary-op"),
pytest.param(lambda _: ~_.a, id="unary-op"),
],
)
def test_deferred_is_pickleable(func, table):
expr1 = func(_)
builder1 = builder(expr1)
builder2 = pickle.loads(pickle.dumps(builder1))

r1 = resolve(builder1, _=table)
r2 = resolve(builder2, _=table)

assert r1 == r2

0 comments on commit 4385bda

Please sign in to comment.