From 9fc7aeeb7a9cd25d5dd7539ee73158d7a5402190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 11 Sep 2024 11:25:24 +0200 Subject: [PATCH] feat(builders): support deferrable functions and custom deferred repr --- koerce/builders.py | 90 +++++++++++++++++++++++++++++++---- koerce/tests/test_builders.py | 87 +++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 9 deletions(-) diff --git a/koerce/builders.py b/koerce/builders.py index 2968eb8..35b2b83 100644 --- a/koerce/builders.py +++ b/koerce/builders.py @@ -1,6 +1,8 @@ from __future__ import annotations import collections.abc +import functools +import inspect import operator from typing import Any @@ -20,15 +22,30 @@ class Deferred: Its sole purpose is to provide a nicer syntax for constructing deferred expressions, thus it gets unwrapped to the underlying deferred expression when used by the rest of the library. + + Parameters + ---------- + deferred + The deferred object to provide syntax sugar for. + repr + An optional fixed string to use when repr-ing the deferred expression, + instead of the default. This is useful for complex deferred expressions + where the arguments don't necessarily make sense to be user facing in + the repr. """ + _repr: str _builder: Builder - def __init__(self, builder: Builder): + def __init__(self, builder: Builder, repr: Optional[str] = None): + self._repr = repr self._builder = builder def __repr__(self): - return repr(self._builder) + if self._repr is None: + return repr(self._builder) + else: + return self._repr def __getattr__(self, name): return Deferred(Attr(self, name)) @@ -171,6 +188,16 @@ def __eq__(self, other: Any) -> bool: return type(self) is type(other) and self.equals(other) +def _deferred_repr(obj): + try: + return obj.__deferred_repr__() + except (AttributeError, TypeError): + if callable(obj) and hasattr(obj, "__name__"): + return obj.__name__ + else: + return repr(obj) + + @cython.final @cython.cclass class Func(Builder): @@ -194,7 +221,7 @@ def __init__(self, func: Any): self.func = func def __repr__(self): - return f"{self.func.__name__}(...)" + return _deferred_repr(self.func) def equals(self, other: Func) -> bool: return self.func == other.func @@ -224,12 +251,7 @@ def __init__(self, value: Any): self.value = value def __repr__(self): - if hasattr(self.value, "__deferred_repr__"): - return self.value.__deferred_repr__() - elif callable(self.value): - return getattr(self.value, "__name__", repr(self.value)) - else: - return repr(self.value) + return _deferred_repr(self.value) def equals(self, other: Just) -> bool: return self.value == other.value @@ -772,3 +794,53 @@ def deferred(obj, allow_custom=False) -> Deferred: def resolve(obj, **context): bldr: Builder = builder(obj) return bldr.build(context) + + +def _contains_deferred(obj: Any) -> bool: + if isinstance(obj, (Builder, Deferred)): + return True + elif (typ := type(obj)) in (tuple, list, set): + return any(_contains_deferred(o) for o in obj) + elif typ is dict: + return any(_contains_deferred(o) for o in obj.values()) + return False + + +def deferrable(func=None, *, repr=None): + """Wrap a top-level expr function to support deferred arguments. + + When a deferrable function is called, the args & kwargs are traversed to + look for `Deferred` values (through builtin collections like + `list`/`tuple`/`set`/`dict`). If any `Deferred` arguments are found, then + the result is also `Deferred`. Otherwise the function is called directly. + + Parameters + ---------- + func + A callable to make deferrable + repr + An optional fixed string to use when repr-ing the deferred expression, + instead of the usual. This is useful for complex deferred expressions + where the arguments don't necessarily make sense to be user facing + in the repr. + + """ + + def wrapper(func): + # Parse the signature of func so we can validate deferred calls eagerly, + # erroring for invalid/missing arguments at call time not resolve time. + sig = inspect.signature(func) + + @functools.wraps(func) + def inner(*args, **kwargs): + if _contains_deferred((args, kwargs)): + # Try to bind the arguments now, raising a nice error + # immediately if the function was called incorrectly + sig.bind(*args, **kwargs) + builder = Call(func, *args, **kwargs) + return Deferred(builder, repr=repr) + return func(*args, **kwargs) + + return inner # type: ignore + + return wrapper if func is None else wrapper(func) diff --git a/koerce/tests/test_builders.py b/koerce/tests/test_builders.py index 048ff3e..552deb4 100644 --- a/koerce/tests/test_builders.py +++ b/koerce/tests/test_builders.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +import pickle import pytest @@ -23,6 +24,7 @@ Unop, Var, builder, + deferrable, resolve, ) @@ -202,6 +204,11 @@ def test_deferred_object_are_not_hashable(): hash(_.a) +def test_deferred_set_raises(): + with pytest.raises(TypeError, match="unhashable type"): + {_.a, _.b} # noqa: B018 + + @pytest.mark.parametrize( ("value", "expected"), [ @@ -529,3 +536,83 @@ def test_builder_coercion(): def test_resolve(): deferred = _["a"] + 1 assert resolve(deferred, _={"a": 1}) == 2 + + +def test_deferrable(table): + @deferrable + def f(a, b, c=3): + return a + b + c + + assert f(table.a, table.b) == table.a + table.b + 3 + assert f(table.a, table.b, c=4) == table.a + table.b + 4 + + expr = f(_.a, _.b) + sol = table.a + table.b + 3 + res = resolve(expr, _=table) + assert res == sol + assert repr(expr) == "f($_.a, $_.b)" + + expr = f(1, 2, c=_.a) + sol = 3 + table.a + res = resolve(expr, _=table) + assert res == sol + assert repr(expr) == "f(1, 2, c=$_.a)" + + with pytest.raises(TypeError, match="unknown"): + f(_.a, _.b, unknown=3) # invalid calls caught at call time + + +def test_deferrable_repr(): + @deferrable(repr="") + def myfunc(x): + return x + 1 + + assert repr(myfunc(_.a)) == "" + + +@pytest.mark.parametrize( + "case", + [ + pytest.param(lambda: ([1, _], [1, 2]), id="list"), + pytest.param(lambda: ((1, _), (1, 2)), id="tuple"), + pytest.param(lambda: ({"x": 1, "y": _}, {"x": 1, "y": 2}), id="dict"), + pytest.param( + lambda: ({"x": 1, "y": [_, 3]}, {"x": 1, "y": [2, 3]}), id="nested" + ), + ], +) +def test_deferrable_nested_args(case): + arg, sol = case() + + @deferrable + def identity(x): + return x + + expr = identity(arg) + assert resolve(expr, _=2) == sol + assert identity(sol) is sol + assert repr(expr) == f"identity({arg!r})" + + +@pytest.mark.parametrize( + "func", + [ + pytest.param(lambda _: _, id="root"), + pytest.param(lambda _: _.a, id="getattr"), + pytest.param(lambda _: _["a"], id="getitem"), + pytest.param(lambda _: _.a.log(), id="method"), + pytest.param(lambda _: _.a.log(_.b), id="method-with-args"), + pytest.param(lambda _: _.a.log(base=_.b), id="method-with-kwargs"), + pytest.param(lambda _: _.a + _.b, id="binary-op"), + pytest.param(lambda _: ~_.a, id="unary-op"), + ], +) +def test_deferred_is_pickleable(func, table): + expr1 = func(_) + builder1 = builder(expr1) + builder2 = pickle.loads(pickle.dumps(builder1)) + + r1 = resolve(builder1, _=table) + r2 = resolve(builder2, _=table) + + assert r1 == r2