Skip to content

Commit

Permalink
Add .required_variables property to ModelSpec[s] and Token inst…
Browse files Browse the repository at this point in the history
…ances.
  • Loading branch information
matthewwardrop committed Dec 2, 2024
1 parent 02d8349 commit 52ea23f
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 2 deletions.
25 changes: 25 additions & 0 deletions formulaic/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,21 @@ def variables_by_source(self) -> Dict[Optional[str], Set[Variable]]:
variables_by_source[variable.source].add(variable)
return dict(variables_by_source)

@property
def required_variables(self) -> Set[Variable]:
    """
    The set of variables that must be present in the data in order to
    materialize this model specification.

    When `.structure` has been populated (it carries metadata about which
    columns were ultimately drawn from the data during materialization),
    the variables recorded under the `"data"` source are returned, or an
    empty set if there are none. Otherwise this falls back to the
    variables inferred to be required by `.formula`.
    """
    if self.structure is not None:
        # Materialization metadata is available: use what was actually
        # sourced from the data.
        return self.variables_by_source.get("data", set())
    # No structure yet: rely on the formula's own inference.
    return self.formula.required_variables

def get_slice(self, columns_identifier: Union[int, str, Term, slice]) -> slice:
"""
Generate a `slice` instance corresponding to the columns associated with
Expand Down Expand Up @@ -656,6 +671,16 @@ def _prepare_item(self, key: str, item: Any) -> Any:
)
return item

@property
def required_variables(self) -> Set[Variable]:
    """
    The set of variables that must be present in the data in order to
    materialize every model specification held by this `ModelSpecs`
    instance.

    The result is the union of the `.required_variables` of each model
    spec that `self._map` passes to the collector below.
    """
    collected: Set[Variable] = set()

    def _gather(model_spec: Any) -> None:
        # Accumulate in place; the value returned by `_map` is not needed.
        collected.update(model_spec.required_variables)

    self._map(_gather)
    return collected

def get_model_matrix(
self,
data: Any,
Expand Down
38 changes: 37 additions & 1 deletion formulaic/parser/types/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import copy
import re
from enum import Enum
from typing import Any, Iterable, Mapping, Optional, Tuple, Union
from typing import Any, Iterable, Mapping, Optional, Set, Tuple, Union

from formulaic.utils.variables import Variable, get_expression_variables

from .factor import Factor
from .ordered_set import OrderedSet
Expand Down Expand Up @@ -179,6 +181,40 @@ def get_source_context(self, colorize: bool = False) -> Optional[str]:
return f"{self.source[:self.source_start]}{RED_BOLD}{self.source[self.source_start:self.source_end+1]}{RESET}{self.source[self.source_end+1:]}"
return f"{self.source[:self.source_start]}{self.source[self.source_start:self.source_end+1]}{self.source[self.source_end+1:]}"

@property
def required_variables(self) -> Set[Variable]:
    """
    The set of variables required to evaluate this token.

    Name tokens require exactly the variable they name. Python tokens
    require the variables found in their expression, excluding those whose
    leading (pre-`.`) component is already present in the `TRANSFORMS`
    namespace. Every other kind of token requires nothing.

    If this is a Python token and the code is malformed and unable to be
    parsed, an empty set is returned. The code will fail more gracefully
    later on.

    Attempts are made to restrict these variables only to those expected
    in the data, and not, for example, those associated with transforms
    and/or values present in the evaluation namespace by default (e.g.
    `y ~ C(x)` would include only `y` and `x`). This may not always be
    possible for more advanced formulae that insert constants into the
    formula via the evaluation context rather than the data context.
    """
    token_kind = self.kind
    if token_kind is Token.Kind.NAME:
        return {Variable(self.token)}
    if token_kind is not Token.Kind.PYTHON:
        return set()
    try:
        # Filter out constants like `contr` that are already present in the
        # TRANSFORMS namespace.
        from formulaic.transforms import TRANSFORMS

        return {
            variable
            for variable in get_expression_variables(self.token)
            if variable.split(".", 1)[0] not in TRANSFORMS
        }
    except Exception:  # noqa: S110
        # Deliberate best-effort: unparseable code yields no variables here.
        return set()

def __repr__(self) -> str:
return self.token

Expand Down
4 changes: 3 additions & 1 deletion formulaic/utils/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def union(cls, *variable_sets: Iterable[Variable]) -> Set[Variable]:


def get_expression_variables(
expr: Union[str, ast.AST], context: Mapping, aliases: Optional[Mapping] = None
expr: Union[str, ast.AST],
context: Optional[Mapping] = None,
aliases: Optional[Mapping] = None,
) -> Set[Variable]:
"""
Extract the variables that are used in the nominated Python expression.
Expand Down
4 changes: 4 additions & 0 deletions tests/parser/types/test_term.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from formulaic.parser.types import Factor, Term
from formulaic.parser.types.ordered_set import OrderedSet


class TestTerm:
Expand Down Expand Up @@ -48,3 +49,6 @@ def test_degree(self, term1, term3):
assert term3.degree == 3
assert Term([Factor("1", eval_method="literal")]).degree == 0
assert Term([Factor("1", eval_method="literal"), Factor("x")]).degree == 1

def test_to_terms(self, term1):
    """`Term.to_terms()` yields an ordered set holding just that term."""
    terms = term1.to_terms()
    assert terms == OrderedSet((term1,))
6 changes: 6 additions & 0 deletions tests/parser/types/test_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,9 @@ def test_split(self, token_a):
Token("b"),
Token("c"),
]

def test_required_variables(self, token_a, token_b):
    """Check `Token.required_variables` for fixture, malformed and value tokens."""
    assert token_a.required_variables == {"a"}
    assert token_b.required_variables == {"x"}

    # Unparseable Python code yields no variables rather than raising.
    malformed = Token("malformed((python", kind="python")
    assert malformed.required_variables == set()

    # Value tokens never require any variables.
    literal = Token("xyz", kind="value")
    assert literal.required_variables == set()
8 changes: 8 additions & 0 deletions tests/test_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,13 @@ def test_get_variable_indices(self, model_spec):
assert model_spec.get_variable_indices("a") == [1, 4, 5]
assert model_spec.get_variable_indices("A") == [2, 3, 4, 5]

def test_required_variables(self, model_spec):
    """`required_variables` agrees whether derived from structure or formula."""
    # Derived using the populated structure
    assert model_spec.structure
    assert model_spec.required_variables == {"a", "A"}

    # Derived using formula instead of structure
    without_structure = model_spec.update(structure=None)
    assert without_structure.required_variables == {"a", "A"}

def test_get_slice(self, model_spec):
s = slice(0, 1)
assert model_spec.get_slice(s) is s
Expand Down Expand Up @@ -274,6 +281,7 @@ def test_model_specs(self, model_spec, data2):
assert numpy.all(
model_specs.get_model_matrix(data2).a == model_spec.get_model_matrix(data2)
)
assert model_specs.required_variables == {"a", "A"}
sparse_matrices = model_specs.get_model_matrix(data2, output="sparse")
assert isinstance(sparse_matrices, ModelMatrices)
assert isinstance(sparse_matrices.a, scipy.sparse.spmatrix)
Expand Down

0 comments on commit 52ea23f

Please sign in to comment.