Skip to content

Commit cbd47b3

Browse files
Add .required_variables property to ModelSpec[s] and Token instances.
1 parent a7ee19f commit cbd47b3

File tree

6 files changed

+83
-2
lines changed

6 files changed

+83
-2
lines changed

formulaic/model_spec.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,21 @@ def variables_by_source(self) -> Dict[Optional[str], Set[Variable]]:
424424
variables_by_source[variable.source].add(variable)
425425
return dict(variables_by_source)
426426

427+
@property
428+
def required_variables(self) -> Set[Variable]:
429+
"""
430+
The set of variables required to be in the data to materialize this
431+
model specification.
432+
433+
If `.structure` has not been populated (which contains metadata about
434+
which columns were ultimately drawn from the data during
435+
materialization), then this will fall back to the variables inferred to
436+
be required by `.formula`.
437+
"""
438+
if self.structure is None:
439+
return self.formula.required_variables
440+
return self.variables_by_source.get("data", set())
441+
427442
def get_slice(self, columns_identifier: Union[int, str, Term, slice]) -> slice:
428443
"""
429444
Generate a `slice` instance corresponding to the columns associated with
@@ -656,6 +671,16 @@ def _prepare_item(self, key: str, item: Any) -> Any:
656671
)
657672
return item
658673

674+
@property
675+
def required_variables(self) -> Set[Variable]:
676+
"""
677+
The set of variables required to be in the data to materialize all of
678+
the model specifications in this `ModelSpecs` instance.
679+
"""
680+
variables: Set[Variable] = set()
681+
self._map(lambda ms: variables.update(ms.required_variables))
682+
return variables
683+
659684
def get_model_matrix(
660685
self,
661686
data: Any,

formulaic/parser/types/token.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import copy
44
import re
55
from enum import Enum
6-
from typing import Any, Iterable, Mapping, Optional, Tuple, Union
6+
from typing import Any, Iterable, Mapping, Optional, Set, Tuple, Union
7+
8+
from formulaic.utils.variables import Variable, get_expression_variables
79

810
from .factor import Factor
911
from .ordered_set import OrderedSet
@@ -179,6 +181,40 @@ def get_source_context(self, colorize: bool = False) -> Optional[str]:
179181
return f"{self.source[:self.source_start]}{RED_BOLD}{self.source[self.source_start:self.source_end+1]}{RESET}{self.source[self.source_end+1:]}"
180182
return f"{self.source[:self.source_start]}{self.source[self.source_start:self.source_end+1]}{self.source[self.source_end+1:]}"
181183

184+
@property
185+
def required_variables(self) -> Set[Variable]:
186+
"""
187+
The set of variables required to evaluate this token.
188+
189+
If this is a Python token, and the code is malformed and unable to be
190+
parsed, an empty set is returned. The code will fail more gracefully
191+
later on.
192+
193+
Attempts are made to restrict these variables only to those expected in
194+
the data, and not, for example, those associated with transforms and/or
195+
values present in the evaluation namespace by default (e.g. `y ~ C(x)`
196+
would include only `y` and `x`). This may not always be possible for
197+
more advanced formulae that insert constants into the formula via the
198+
evaluation context rather than the data context.
199+
"""
200+
if self.kind is Token.Kind.NAME:
201+
return {Variable(self.token)}
202+
if self.kind is Token.Kind.PYTHON:
203+
try:
204+
# Filter out constants like `contr` that are already present in the
205+
# TRANSFORMS namespace.
206+
from formulaic.transforms import TRANSFORMS
207+
208+
return set(
209+
filter(
210+
lambda variable: variable.split(".", 1)[0] not in TRANSFORMS,
211+
get_expression_variables(self.token),
212+
)
213+
)
214+
except Exception: # noqa: S110
215+
pass
216+
return set()
217+
182218
def __repr__(self) -> str:
183219
return self.token
184220

formulaic/utils/variables.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def union(cls, *variable_sets: Iterable[Variable]) -> Set[Variable]:
4646

4747

4848
def get_expression_variables(
49-
expr: Union[str, ast.AST], context: Mapping, aliases: Optional[Mapping] = None
49+
expr: Union[str, ast.AST],
50+
context: Optional[Mapping] = None,
51+
aliases: Optional[Mapping] = None,
5052
) -> Set[Variable]:
5153
"""
5254
Extract the variables that are used in the nominated Python expression.

tests/parser/types/test_term.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
from formulaic.parser.types import Factor, Term
4+
from formulaic.parser.types.ordered_set import OrderedSet
45

56

67
class TestTerm:
@@ -48,3 +49,6 @@ def test_degree(self, term1, term3):
4849
assert term3.degree == 3
4950
assert Term([Factor("1", eval_method="literal")]).degree == 0
5051
assert Term([Factor("1", eval_method="literal"), Factor("x")]).degree == 1
52+
53+
def test_to_terms(self, term1):
54+
assert term1.to_terms() == OrderedSet((term1,))

tests/parser/types/test_token.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,9 @@ def test_split(self, token_a):
9595
Token("b"),
9696
Token("c"),
9797
]
98+
99+
def test_required_variables(self, token_a, token_b):
100+
assert token_a.required_variables == {"a"}
101+
assert token_b.required_variables == {"x"}
102+
assert Token("malformed((python", kind="python").required_variables == set()
103+
assert Token("xyz", kind="value").required_variables == set()

tests/test_model_spec.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,13 @@ def test_get_variable_indices(self, model_spec):
123123
assert model_spec.get_variable_indices("a") == [1, 4, 5]
124124
assert model_spec.get_variable_indices("A") == [2, 3, 4, 5]
125125

126+
def test_required_variables(self, model_spec):
127+
assert model_spec.structure
128+
assert model_spec.required_variables == {"a", "A"}
129+
130+
# Derived using formula instead of structure
131+
assert model_spec.update(structure=None).required_variables == {"a", "A"}
132+
126133
def test_get_slice(self, model_spec):
127134
s = slice(0, 1)
128135
assert model_spec.get_slice(s) is s
@@ -274,6 +281,7 @@ def test_model_specs(self, model_spec, data2):
274281
assert numpy.all(
275282
model_specs.get_model_matrix(data2).a == model_spec.get_model_matrix(data2)
276283
)
284+
assert model_specs.required_variables == {"a", "A"}
277285
sparse_matrices = model_specs.get_model_matrix(data2, output="sparse")
278286
assert isinstance(sparse_matrices, ModelMatrices)
279287
assert isinstance(sparse_matrices.a, scipy.sparse.spmatrix)

0 commit comments

Comments
 (0)