diff --git a/src/lgdo/types/table.py b/src/lgdo/types/table.py index 34cdac5..bf59153 100644 --- a/src/lgdo/types/table.py +++ b/src/lgdo/types/table.py @@ -7,6 +7,7 @@ import logging from collections.abc import Mapping +from types import ModuleType from typing import Any from warnings import warn @@ -266,6 +267,7 @@ def eval( self, expr: str, parameters: Mapping[str, str] | None = None, + modules: Mapping[str, ModuleType] | None = None, ) -> LGDO: """Apply column operations to the table and return a new LGDO. @@ -299,6 +301,10 @@ def eval( a dictionary of function parameters. Passed to :func:`numexpr.evaluate`` as `local_dict` argument or to :func:`eval` as `locals` argument. + modules + a dictionary of additional modules used by the expression. If this is not `None` + then :func:`eval`is used and the expression can depend on any modules from this dictionary in + addition to awkward and numpy. These are passed to :func:`eval` as `globals` argument. Examples -------- @@ -339,8 +345,8 @@ def eval( msg = f"evaluating {expr!r} with locals={(self_unwrap | parameters)} and {has_ak=}" log.debug(msg) - # use numexpr if we are only dealing with numpy data types - if not has_ak: + # use numexpr if we are only dealing with numpy data types (and no global dictionary) + if not has_ak and modules is None: out_data = ne.evaluate( expr, local_dict=(self_unwrap | parameters), @@ -366,6 +372,9 @@ def eval( # resort to good ol' eval() globs = {"ak": ak, "np": np} + if modules is not None: + globs = globs | modules + out_data = eval(expr, globs, (self_unwrap | parameters)) msg = f"...the result is {out_data!r}" @@ -380,6 +389,10 @@ def eval( if np.isscalar(out_data): return Scalar(out_data) + # if out_data is already an LGDO just return it + if isinstance(out_data, LGDO): + return out_data + msg = ( f"evaluation resulted in a {type(out_data)} object, " "I don't know which LGDO this corresponds to" diff --git a/tests/types/test_table_eval.py b/tests/types/test_table_eval.py index 1fb837e..1c4a180 100644 --- a/tests/types/test_table_eval.py +++ b/tests/types/test_table_eval.py @@ -1,6 +1,8 @@ from __future__ import annotations +import hist import numpy as np +import pytest import lgdo @@ -31,6 +33,7 @@ def test_eval_dependency(): ), } ) + r = obj.eval("sum(a)") assert isinstance(r, lgdo.Scalar) @@ -77,3 +80,11 @@ def test_eval_dependency(): assert isinstance(r, lgdo.Scalar) assert obj.eval("np.sum(a) + ak.sum(e)") + + # test with modules argument, the simplest is using directly lgdo + res = obj.eval("lgdo.Array([1,2,3])", {}, modules={"lgdo": lgdo}) + assert res == lgdo.Array([1, 2, 3]) + + # check bad type + with pytest.raises(RuntimeError): + obj.eval("hist.Hist()", modules={"hist": hist})