Skip to content

Commit 02d8349

Browse files
Pass parsing context through Formula, ModelSpec and FormulaMaterializer objects, as well as model_matrix.
1 parent c31ada0 commit 02d8349

File tree

4 files changed

+83
-22
lines changed

4 files changed

+83
-22
lines changed

formulaic/formula.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __call__(
7474
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
7575
_parser: Optional[FormulaParser] = None,
7676
_nested_parser: Optional[FormulaParser] = None,
77+
_context: Optional[Mapping[str, Any]] = None,
7778
**structure: FormulaSpec,
7879
) -> Formula:
7980
"""
@@ -82,7 +83,7 @@ def __call__(
8283
`SimpleFormula` instance will be returned; otherwise, a
8384
`StructuredFormula`.
8485
85-
Some arguments a prefixed with underscores to prevent collision with
86+
Some arguments are prefixed with underscores to prevent collision with
8687
formula structure.
8788
8889
Args:
@@ -108,6 +109,7 @@ def __call__(
108109
_ordering=_ordering,
109110
_parser=_parser,
110111
_nested_parser=_nested_parser,
112+
_context=_context,
111113
**structure,
112114
)
113115
return self
@@ -120,13 +122,15 @@ def __call__(
120122
_parser=_parser,
121123
_nested_parser=_nested_parser,
122124
_ordering=_ordering,
123-
**structure,
124-
)
125+
_context=_context,
126+
**structure, # type: ignore[arg-type]
127+
)._simplify()
125128
return cls.from_spec(
126129
cast(FormulaSpec, root),
127130
ordering=_ordering,
128131
parser=_parser,
129132
nested_parser=_nested_parser,
133+
context=_context,
130134
)
131135

132136
def from_spec(
@@ -136,6 +140,7 @@ def from_spec(
136140
ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
137141
parser: Optional[FormulaParser] = None,
138142
nested_parser: Optional[FormulaParser] = None,
143+
context: Optional[Mapping[str, Any]] = None,
139144
) -> Union[SimpleFormula, StructuredFormula]:
140145
"""
141146
Construct a `SimpleFormula` or `StructuredFormula` instance from a
@@ -164,18 +169,25 @@ def from_spec(
164169
if isinstance(spec, str):
165170
spec = cast(
166171
FormulaSpec,
167-
(parser or DefaultFormulaParser()).get_terms(spec)._simplify(),
172+
(parser or DefaultFormulaParser())
173+
.get_terms(spec, context=context)
174+
._simplify(),
168175
)
169176

170177
if isinstance(spec, dict):
171178
return StructuredFormula(
172-
_parser=parser, _nested_parser=nested_parser, _ordering=ordering, **spec
179+
_parser=parser,
180+
_nested_parser=nested_parser,
181+
_ordering=ordering,
182+
_context=context,
183+
**spec, # type: ignore[arg-type]
173184
)
174185
if isinstance(spec, Structured):
175186
return StructuredFormula(
176187
_ordering=ordering,
177188
_parser=nested_parser,
178189
_nested_parser=nested_parser,
190+
_context=context,
179191
**spec._structure,
180192
)._simplify()
181193
if isinstance(spec, tuple):
@@ -184,13 +196,14 @@ def from_spec(
184196
_ordering=ordering,
185197
_parser=parser,
186198
_nested_parser=nested_parser,
199+
_context=context,
187200
)._simplify()
188201
if isinstance(spec, (list, set, OrderedSet)):
189202
terms = [
190203
term
191204
for value in spec
192205
for term in (
193-
nested_parser.get_terms(value) # type: ignore[attr-defined]
206+
nested_parser.get_terms(value, context=context) # type: ignore[attr-defined]
194207
if isinstance(value, str)
195208
else [value]
196209
)
@@ -248,9 +261,11 @@ class Formula(metaclass=_FormulaMeta):
248261
def __init__(
249262
self,
250263
root: Union[FormulaSpec, _MissingType] = MISSING,
264+
*,
251265
_parser: Optional[FormulaParser] = None,
252266
_nested_parser: Optional[FormulaParser] = None,
253267
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
268+
_context: Optional[Mapping[str, Any]] = None,
254269
**structure: FormulaSpec,
255270
):
256271
"""
@@ -288,7 +303,7 @@ def get_model_matrix(
288303
@abstractmethod
289304
def required_variables(self) -> Set[Variable]:
290305
"""
291-
The set of variables required in the data order to materialize this
306+
The set of variables required to be in the data to materialize this
292307
formula.
293308
294309
Attempts are made to restrict these variables only to those expected in
@@ -354,6 +369,7 @@ def __init__(
354369
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
355370
_parser: Optional[FormulaParser] = None,
356371
_nested_parser: Optional[FormulaParser] = None,
372+
_context: Optional[Mapping[str, Any]] = None,
357373
**structure: FormulaSpec,
358374
):
359375
if root is MISSING:
@@ -667,19 +683,22 @@ class StructuredFormula(Structured[SimpleFormula], Formula):
667683
formula specifications. Can be: "none", "degree" (default), or "sort".
668684
"""
669685

670-
__slots__ = ("_parser", "_nested_parser", "_ordering")
686+
__slots__ = ("_parser", "_nested_parser", "_ordering", "_context")
671687

672688
def __init__(
673689
self,
674690
root: Union[FormulaSpec, _MissingType] = MISSING,
691+
*,
692+
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
675693
_parser: Optional[FormulaParser] = None,
676694
_nested_parser: Optional[FormulaParser] = None,
677-
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
695+
_context: Optional[Mapping[str, Any]] = None,
678696
**structure: FormulaSpec,
679697
):
698+
self._ordering = OrderingMethod(_ordering)
680699
self._parser = _parser or DEFAULT_PARSER
681700
self._nested_parser = _nested_parser or _parser or DEFAULT_NESTED_PARSER
682-
self._ordering = OrderingMethod(_ordering)
701+
self._context = _context
683702
super().__init__(root, **structure) # type: ignore
684703
self._simplify(unwrap=False, inplace=True)
685704

@@ -704,6 +723,7 @@ def _prepare_item( # type: ignore[override]
704723
ordering=self._ordering,
705724
parser=(self._parser if key == "root" else self._nested_parser),
706725
nested_parser=self._nested_parser,
726+
context=self._context,
707727
)
708728

709729
def get_model_matrix(
@@ -782,3 +802,9 @@ def differentiate( # pylint: disable=redefined-builtin
782802
SimpleFormula,
783803
self._map(lambda formula: formula.differentiate(*wrt, use_sympy=use_sympy)),
784804
)
805+
806+
# Ensure pickling never includes context
807+
def __getstate__(self) -> Tuple[None, Dict[str, Any]]:
808+
state = cast(Tuple[None, Dict[str, Any]], super().__getstate__())
809+
state[1]["_context"] = None
810+
return state

formulaic/materializers/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,9 @@ def get_model_matrix(
163163
from formulaic import ModelSpec
164164

165165
# Prepare ModelSpec(s)
166-
spec: Union[ModelSpec, ModelSpecs] = ModelSpec.from_spec(spec, **spec_overrides)
166+
spec: Union[ModelSpec, ModelSpecs] = ModelSpec.from_spec(
167+
spec, context=self.layered_context, **spec_overrides
168+
)
167169
should_simplify = isinstance(spec, ModelSpec)
168170
model_specs: ModelSpecs = self._prepare_model_specs(spec)
169171

formulaic/model_spec.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ class ModelSpec:
7878
def from_spec(
7979
cls,
8080
spec: Union[FormulaSpec, ModelMatrix, ModelMatrices, ModelSpec, ModelSpecs],
81+
*,
82+
context: Optional[Mapping[str, Any]] = None,
8183
**attrs: Any,
8284
) -> Union[ModelSpec, ModelSpecs]:
8385
"""
@@ -90,6 +92,11 @@ def from_spec(
9092
instance or structured set of `ModelSpec` instances.
9193
attrs: Any `ModelSpec` attributes to set and/or override on all
9294
generated `ModelSpec` instances.
95+
context: Optional additional context to pass through to the formula
96+
parsing algorithms. This is not normally required, and if
97+
involved operators place additional constraints on the type
98+
and/or structure of this context, they will raise exceptions (with
99+
instructions for how to fix it) when those constraints are not satisfied.
93100
"""
94101
from .model_matrix import ModelMatrix
95102

@@ -98,7 +105,7 @@ def prepare_model_spec(obj: Any) -> Union[ModelSpec, ModelSpecs]:
98105
obj = obj.model_spec
99106
if isinstance(obj, ModelSpec):
100107
return obj.update(**attrs)
101-
formula = Formula.from_spec(obj)
108+
formula = Formula.from_spec(obj, context=context)
102109
if isinstance(formula, StructuredFormula):
103110
return cast(
104111
ModelSpecs, formula._map(prepare_model_spec, as_type=ModelSpecs)
@@ -459,6 +466,24 @@ def get_slice(self, columns_identifier: Union[int, str, Term, slice]) -> slice:
459466

460467
# Utility methods
461468

469+
def get_materializer(
470+
self, data: Any, context: Optional[Mapping[str, Any]] = None
471+
) -> FormulaMaterializer:
472+
"""
473+
Construct a `FormulaMaterializer` instance for `data` that can be used
474+
to generate model matrices consistent with this model specification.
475+
476+
Args:
477+
data: The data for which to build the materializer.
478+
context: An additional mapping object of names to make available
479+
when evaluating formula term factors.
480+
"""
481+
if self.materializer is None:
482+
materializer = FormulaMaterializer.for_data(data)
483+
else:
484+
materializer = FormulaMaterializer.for_materializer(self.materializer)
485+
return materializer(data, context=context, **(self.materializer_params or {}))
486+
462487
def get_model_matrix(
463488
self,
464489
data: Any,
@@ -484,13 +509,12 @@ def get_model_matrix(
484509
"""
485510
if attr_overrides:
486511
return self.update(**attr_overrides).get_model_matrix(data, context=context)
487-
if self.materializer is None:
488-
materializer = FormulaMaterializer.for_data(data)
489-
else:
490-
materializer = FormulaMaterializer.for_materializer(self.materializer)
491-
return materializer(
492-
data, context=context, **(self.materializer_params or {})
493-
).get_model_matrix(self, drop_rows=drop_rows)
512+
return cast(
513+
"ModelMatrix",
514+
self.get_materializer(data, context=context).get_model_matrix(
515+
self, drop_rows=drop_rows
516+
),
517+
)
494518

495519
def get_linear_constraints(self, spec: LinearConstraintSpec) -> LinearConstraints:
496520
"""

formulaic/sugar.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ def model_matrix(
1919
2020
This method is syntactic sugar for:
2121
```
22-
Formula(spec).get_model_matrix(data, context=LayeredMapping(locals(), globals()), **kwargs)
22+
Formula(
23+
spec,
24+
_context={"__formulaic_variables_available__": ...}, # used for the `.` operator
25+
).get_model_matrix(data, context=LayeredMapping(locals(), globals()), **kwargs)
2326
```
2427
or
2528
```
@@ -52,6 +55,12 @@ def model_matrix(
5255
nominated structure.
5356
"""
5457
_context = capture_context(context + 1) if isinstance(context, int) else context
55-
return ModelSpec.from_spec(spec, **spec_overrides).get_model_matrix(
56-
data, context=_context, drop_rows=drop_rows
58+
_spec_context = ( # use materializer context for parser context
59+
ModelSpec.from_spec([], **spec_overrides)
60+
.get_materializer(data, context=_context)
61+
.layered_context
5762
)
63+
64+
return ModelSpec.from_spec(
65+
spec, context=_spec_context, **spec_overrides
66+
).get_model_matrix(data, context=_context, drop_rows=drop_rows)

0 commit comments

Comments
 (0)