Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Nov 27, 2024
1 parent e8baef2 commit 29aaeae
Show file tree
Hide file tree
Showing 13 changed files with 290 additions and 45 deletions.
30 changes: 26 additions & 4 deletions formulaic/formula.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Union, cast

Expand Down Expand Up @@ -84,7 +85,7 @@ class Formula(Structured[List[Term]]):
DEFAULT_PARSER = DefaultFormulaParser()
DEFAULT_NESTED_PARSER = DefaultFormulaParser(include_intercept=False)

__slots__ = ("_parser", "_nested_parser", "_ordering")
__slots__ = ("_parser", "_nested_parser", "_ordering", "_context")

@classmethod
def from_spec(
Expand All @@ -94,6 +95,7 @@ def from_spec(
parser: Optional[FormulaParser] = None,
nested_parser: Optional[FormulaParser] = None,
ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
context: Optional[Mapping[str, Any]] = None,
) -> Formula:
"""
Construct a `Formula` instance from a formula specification.
Expand All @@ -115,7 +117,11 @@ def from_spec(
if isinstance(spec, Formula):
return spec
return Formula(
spec, _parser=parser, _nested_parser=nested_parser, _ordering=ordering
spec,
_parser=parser,
_nested_parser=nested_parser,
_ordering=ordering,
_context=context,
)

def __init__(
Expand All @@ -124,11 +130,13 @@ def __init__(
_parser: Optional[FormulaParser] = None,
_nested_parser: Optional[FormulaParser] = None,
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
_context: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
):
self._parser = _parser or self.DEFAULT_PARSER
self._nested_parser = _nested_parser or _parser or self.DEFAULT_NESTED_PARSER
self._ordering = OrderingMethod(_ordering)
self._context = _context
super().__init__(*args, **kwargs)
self._simplify(unwrap=False, inplace=True)

Expand All @@ -151,7 +159,7 @@ def _prepare_item(self, key: str, item: FormulaSpec) -> Union[List[Term], Formul
item = cast(
FormulaSpec,
(self._parser if key == "root" else self._nested_parser)
.get_terms(item)
.get_terms(item, context=self._context)
._simplify(),
)

Expand All @@ -169,7 +177,7 @@ def _prepare_item(self, key: str, item: FormulaSpec) -> Union[List[Term], Formul
term
for value in item
for term in (
self._nested_parser.get_terms(value) # type: ignore[attr-defined]
self._nested_parser.get_terms(value, context=self._context) # type: ignore[attr-defined]
if isinstance(value, str)
else [value]
)
Expand Down Expand Up @@ -267,6 +275,7 @@ def required_variables(self) -> Set[Variable]:
variable
for term in terms
for factor in term.factors
if factor != "."
for variable in get_expression_variables(factor.expr, {})
if "value" in variable.roles
)
Expand Down Expand Up @@ -331,3 +340,16 @@ def __repr__(self, to_str: Callable[..., str] = repr) -> str:
if not self._has_structure and self._has_root:
return " + ".join([str(t) for t in self])
return str(self._map(lambda terms: " + ".join([str(t) for t in terms])))

# Ensure pickling never includes context
def __getstate__(self):
if self._context is not None:
warnings.warn(
"Dropping context from Formula instance during pickling.",
RuntimeWarning,
stacklevel=2,
)

state = super().__getstate__()
state[1]["_context"] = None
return state
33 changes: 25 additions & 8 deletions formulaic/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class ModelSpec:
def from_spec(
cls,
spec: Union[FormulaSpec, ModelMatrix, ModelMatrices, ModelSpec, ModelSpecs],
*,
context: Optional[Mapping[str, Any]] = None,
**attrs: Any,
) -> Union[ModelSpec, ModelSpecs]:
"""
Expand All @@ -97,7 +99,7 @@ def prepare_model_spec(obj: Any) -> ModelSpec:
obj = obj.model_spec
if isinstance(obj, ModelSpec):
return obj.update(**attrs)
formula = Formula.from_spec(obj)
formula = Formula.from_spec(obj, context=context)
if not formula._has_root or formula._has_structure:
return cast(
ModelSpec, formula._map(prepare_model_spec, as_type=ModelSpecs)
Expand Down Expand Up @@ -366,6 +368,12 @@ def variables_by_source(self) -> Dict[Optional[str], Set[Variable]]:
variables_by_source[variable.source].add(variable)
return dict(variables_by_source)

@property
def required_variables(self):
if self.structure is None:
return self.formula.required_variables
return self.variables_by_source.get("data", set())

# Transforms

def update(self, **kwargs: Any) -> ModelSpec:
Expand Down Expand Up @@ -397,6 +405,13 @@ def differentiate(self, *wrt: str, use_sympy: bool = False) -> ModelSpec:

# Utility methods

def get_materializer(self, data: Any, context: Optional[Mapping[str, Any]] = None):
if self.materializer is None:
materializer = FormulaMaterializer.for_data(data)
else:
materializer = FormulaMaterializer.for_materializer(self.materializer)
return materializer(data, context=context, **(self.materializer_params or {}))

def get_model_matrix(
self,
data: Any,
Expand All @@ -422,13 +437,9 @@ def get_model_matrix(
"""
if attr_overrides:
return self.update(**attr_overrides).get_model_matrix(data, context=context)
if self.materializer is None:
materializer = FormulaMaterializer.for_data(data)
else:
materializer = FormulaMaterializer.for_materializer(self.materializer)
return materializer(
data, context=context, **(self.materializer_params or {})
).get_model_matrix(self, drop_rows=drop_rows)
return self.get_materializer(data, context=context).get_model_matrix(
self, drop_rows=drop_rows
)

def get_linear_constraints(self, spec: LinearConstraintSpec) -> LinearConstraints:
"""
Expand Down Expand Up @@ -501,6 +512,12 @@ def _prepare_item(self, key: str, item: Any) -> Any:
)
return item

@property
def required_variables(self) -> Set[Variable]:
variables: Set[Variable] = set()
self._map(lambda ms: variables.update(ms.required_variables))
return variables

def get_model_matrix(
self,
data: Any,
Expand Down
45 changes: 42 additions & 3 deletions formulaic/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

from typing_extensions import Self

from formulaic.parser.types.ast_node import ASTNode
from formulaic.utils.layered_mapping import LayeredMapping

from .algos.sanitize_tokens import sanitize_tokens
from .algos.tokenize import tokenize
from .types import (
Expand Down Expand Up @@ -97,7 +100,9 @@ def set_feature_flags(self, flags: DefaultParserFeatureFlag | Set[str]) -> Self:
self.__post_init__()
return self

def get_tokens(self, formula: str) -> Iterable[Token]:
def get_tokens_from_formula(
self, formula: str, *, context: MutableMapping[str, Any]
) -> Iterable[Token]:
"""
Return an iterable of `Token` instances for the nominated `formula`
string.
Expand Down Expand Up @@ -178,12 +183,20 @@ def find_rhs_index(tokens: List[Token]) -> int:
),
]

context["__formulaic_used_variables__"] = [
variable
for token in tokens[:rhs_index]
for variable in token.required_variables
]

# Collapse inserted "+" and "-" operators to prevent unary issues.
tokens = merge_operator_tokens(tokens, symbols={"+", "-"})

return tokens

def get_terms(self, formula: str) -> Structured[List[Term]]:
def get_terms_from_ast(
self, ast: ASTNode, *, context: Optional[Mapping[str, Any]] = None
) -> Structured[OrderedSet[Term]]:
"""
Assemble the `Term` instances for a formula string. Depending on the
operators involved, this may be an iterable of `Term` instances, or
Expand All @@ -195,8 +208,11 @@ def get_terms(self, formula: str) -> Structured[List[Term]]:
Args:
formula: The formula for which an AST should be generated.
context: An optional context which may be used during the evaluation
of operators.
"""
terms = super().get_terms(formula)

terms = super().get_terms_from_ast(ast, context=context)

def check_terms(terms: Iterable[Term]) -> None:
seen_terms = set()
Expand Down Expand Up @@ -333,6 +349,22 @@ def get_terms(terms: OrderedSet[Term]) -> List[Term]:

return Structured(get_terms(lhs), deps=(Structured(lhs=lhs, rhs=rhs),))

def insert_unused_terms(context: Mapping[str, Any]) -> OrderedSet[Term]:
if (
not isinstance(context, LayeredMapping)
or "data" not in context.named_layers
):
raise ValueError(
"Context must be a layered mapping with a named 'data' layer."
)
return OrderedSet(
Term([Factor(variable, eval_method="lookup")])
for variable in (
set(context.named_layers["data"])
- set(context["__formulaic_used_variables__"])
)
)

return [
Operator(
"~",
Expand Down Expand Up @@ -451,6 +483,13 @@ def get_terms(terms: OrderedSet[Term]) -> List[Term]:
Operator(
"^", arity=2, precedence=500, associativity="right", to_terms=power
),
Operator(
".",
arity=0,
precedence=1000,
fixity="postfix",
to_terms=insert_unused_terms,
),
]

def resolve(
Expand Down
5 changes: 5 additions & 0 deletions formulaic/parser/types/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ def kind(self) -> Kind:
def kind(self, kind: Union[str, Factor.Kind]) -> None:
self._kind = Factor.Kind(kind or "unknown")

def __mul__(self, other: Any) -> Term:
if isinstance(other, Factor):
return Term([self, other])
return NotImplemented

def __eq__(self, other: Any) -> bool:
if isinstance(other, str):
return self.expr == other
Expand Down
Loading

0 comments on commit 29aaeae

Please sign in to comment.