Skip to content

Commit

Permalink
Allow passing context to Operator.t_terms (and also for other obj…
Browse files Browse the repository at this point in the history
…ects that may or may not appear in the AST tree).
  • Loading branch information
matthewwardrop committed Dec 1, 2024
1 parent 8787e2f commit ee1c3a2
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 14 deletions.
32 changes: 26 additions & 6 deletions formulaic/parser/types/ast_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
from __future__ import annotations

import functools
import graphlib
from typing import Any, Dict, Generic, Iterable, List, Tuple, TypeVar, Union
from typing import (
Any,
Dict,
Generic,
Iterable,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

from .operator import Operator
from .structured import Structured
Expand All @@ -28,13 +40,19 @@ def __init__(self, operator: Operator, args: Iterable[Any]):
self.operator = operator
self.args = args

def to_terms(self) -> Union[List[Term], Structured[List[Term]], Tuple]:
def to_terms(
self, *, context: Optional[Mapping[str, Any]] = None
) -> Union[List[Term], Structured[List[Term]], Tuple]:
"""
Evaluate this AST node and return the resulting set of `Term` instances.
Note: We use topological evaluation here to avoid recursion issues for
long formula (exceeding ~700 terms, though this depends on the recursion
limit set in the interpreter).
Args:
context: An optional context mapping that can be used by operators
to modify their behaviour (e.g. the `.` operator).
"""
g = graphlib.TopologicalSorter(self.__generate_evaluation_graph())
g.prepare()
Expand All @@ -43,16 +61,18 @@ def to_terms(self) -> Union[List[Term], Structured[List[Term]], Tuple]:

while g.is_active():
for node in g.get_ready():
node_args = (
node_args = tuple(
(results[arg] if isinstance(arg, ASTNode) else arg.to_terms())
for arg in node.args
)
if node.operator.structural:
results[node] = node.operator.to_terms(*node_args)
if node.operator.structural or not node_args:
results[node] = node.operator.to_terms(*node_args, context=context)
else:
results[node] = Structured._merge(
*node_args,
merger=node.operator.to_terms,
merger=functools.partial(
node.operator.to_terms, context=context
),
)
g.done(node)

Expand Down
6 changes: 4 additions & 2 deletions formulaic/parser/types/factor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union

from .ordered_set import OrderedSet
from .term import Term
Expand Down Expand Up @@ -91,7 +91,9 @@ def __lt__(self, other: Any) -> bool:
return self.expr < other.expr
return NotImplemented

def to_terms(self) -> OrderedSet[Term]:
def to_terms(
self, *, context: Optional[Mapping[str, Any]] = None
) -> OrderedSet[Term]:
"""
Convert this `Factor` instance into a `Term` instance, and expose it as
a single-element ordered set.
Expand Down
7 changes: 5 additions & 2 deletions formulaic/parser/types/operator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import inspect
from enum import Enum
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Mapping, Optional, Union

from .token import Token

Expand Down Expand Up @@ -91,9 +92,11 @@ def fixity(self) -> Operator.Fixity:
def fixity(self, fixity: Union[str, Operator.Fixity]) -> None:
self._fixity = Operator.Fixity(fixity)

def to_terms(self, *args: Any) -> Any:
def to_terms(self, *args: Any, context: Optional[Mapping[str, Any]] = None) -> Any:
if self._to_terms is None:
raise RuntimeError(f"`to_terms` is not implemented for '{self.symbol}'.")
if inspect.signature(self._to_terms).parameters.get("context"):
return self._to_terms(*args, context=context or {})
return self._to_terms(*args)

def accepts_context(self, context: List[Union[Token, Operator]]) -> bool:
Expand Down
12 changes: 11 additions & 1 deletion formulaic/parser/types/term.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Iterable, Optional
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional

from .ordered_set import OrderedSet

if TYPE_CHECKING:
from .factor import Factor # pragma: no cover
Expand Down Expand Up @@ -65,5 +67,13 @@ def __lt__(self, other: Any) -> bool:
return False
return NotImplemented

def to_terms(
self, *, context: Optional[Mapping[str, Any]] = None
) -> OrderedSet[Term]:
"""
Convert this `Term` instance into set of `Term`s.
"""
return OrderedSet((self,))

def __repr__(self) -> str:
return ":".join(repr(factor) for factor in self.factors)
6 changes: 4 additions & 2 deletions formulaic/parser/types/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import re
from enum import Enum
from typing import Any, Iterable, Optional, Tuple, Union
from typing import Any, Iterable, Mapping, Optional, Tuple, Union

from .factor import Factor
from .ordered_set import OrderedSet
Expand Down Expand Up @@ -140,7 +140,9 @@ def to_factor(self) -> Factor:
token=self,
)

def to_terms(self) -> OrderedSet[Term]:
def to_terms(
self, *, context: Optional[Mapping[str, Any]] = None
) -> OrderedSet[Term]:
"""
An order set of `Term` instances for this token. This will just be
an iterable with one `Term` having one `Factor` (that generated by
Expand Down
5 changes: 4 additions & 1 deletion formulaic/utils/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
Expand Down Expand Up @@ -320,7 +321,9 @@ def for_token(cls, token: Token) -> ConstraintToken:
}
)

def to_terms(self) -> Set[ScaledFactor]: # type: ignore[override]
def to_terms( # type: ignore[override]
self, *, context: Optional[Mapping[str, Any]] = None
) -> Set[ScaledFactor]:
if self.kind is Token.Kind.VALUE:
factor = ast.literal_eval(self.token)
if isinstance(factor, (int, float)):
Expand Down

0 comments on commit ee1c3a2

Please sign in to comment.