@@ -74,6 +74,7 @@ def __call__(
74
74
_ordering : Union [OrderingMethod , str ] = OrderingMethod .DEGREE ,
75
75
_parser : Optional [FormulaParser ] = None ,
76
76
_nested_parser : Optional [FormulaParser ] = None ,
77
+ _context : Optional [Mapping [str , Any ]] = None ,
77
78
** structure : FormulaSpec ,
78
79
) -> Formula :
79
80
"""
@@ -82,7 +83,7 @@ def __call__(
82
83
`SimpleFormula` instance will be returned; otherwise, a
83
84
`StructuredFormula`.
84
85
85
- Some arguments a prefixed with underscores to prevent collision with
86
+ Some arguments are prefixed with underscores to prevent collision with
86
87
formula structure.
87
88
88
89
Args:
@@ -108,6 +109,7 @@ def __call__(
108
109
_ordering = _ordering ,
109
110
_parser = _parser ,
110
111
_nested_parser = _nested_parser ,
112
+ _context = _context ,
111
113
** structure ,
112
114
)
113
115
return self
@@ -120,13 +122,15 @@ def __call__(
120
122
_parser = _parser ,
121
123
_nested_parser = _nested_parser ,
122
124
_ordering = _ordering ,
123
- ** structure ,
124
- )
125
+ _context = _context ,
126
+ ** structure , # type: ignore[arg-type]
127
+ )._simplify ()
125
128
return cls .from_spec (
126
129
cast (FormulaSpec , root ),
127
130
ordering = _ordering ,
128
131
parser = _parser ,
129
132
nested_parser = _nested_parser ,
133
+ context = _context ,
130
134
)
131
135
132
136
def from_spec (
@@ -136,6 +140,7 @@ def from_spec(
136
140
ordering : Union [OrderingMethod , str ] = OrderingMethod .DEGREE ,
137
141
parser : Optional [FormulaParser ] = None ,
138
142
nested_parser : Optional [FormulaParser ] = None ,
143
+ context : Optional [Mapping [str , Any ]] = None ,
139
144
) -> Union [SimpleFormula , StructuredFormula ]:
140
145
"""
141
146
Construct a `SimpleFormula` or `StructuredFormula` instance from a
@@ -164,18 +169,25 @@ def from_spec(
164
169
if isinstance (spec , str ):
165
170
spec = cast (
166
171
FormulaSpec ,
167
- (parser or DefaultFormulaParser ()).get_terms (spec )._simplify (),
172
+ (parser or DefaultFormulaParser ())
173
+ .get_terms (spec , context = context )
174
+ ._simplify (),
168
175
)
169
176
170
177
if isinstance (spec , dict ):
171
178
return StructuredFormula (
172
- _parser = parser , _nested_parser = nested_parser , _ordering = ordering , ** spec
179
+ _parser = parser ,
180
+ _nested_parser = nested_parser ,
181
+ _ordering = ordering ,
182
+ _context = context ,
183
+ ** spec , # type: ignore[arg-type]
173
184
)
174
185
if isinstance (spec , Structured ):
175
186
return StructuredFormula (
176
187
_ordering = ordering ,
177
188
_parser = nested_parser ,
178
189
_nested_parser = nested_parser ,
190
+ _context = context ,
179
191
** spec ._structure ,
180
192
)._simplify ()
181
193
if isinstance (spec , tuple ):
@@ -184,13 +196,14 @@ def from_spec(
184
196
_ordering = ordering ,
185
197
_parser = parser ,
186
198
_nested_parser = nested_parser ,
199
+ _context = context ,
187
200
)._simplify ()
188
201
if isinstance (spec , (list , set , OrderedSet )):
189
202
terms = [
190
203
term
191
204
for value in spec
192
205
for term in (
193
- nested_parser .get_terms (value ) # type: ignore[attr-defined]
206
+ nested_parser .get_terms (value , context = context ) # type: ignore[attr-defined]
194
207
if isinstance (value , str )
195
208
else [value ]
196
209
)
@@ -248,9 +261,11 @@ class Formula(metaclass=_FormulaMeta):
248
261
def __init__ (
249
262
self ,
250
263
root : Union [FormulaSpec , _MissingType ] = MISSING ,
264
+ * ,
251
265
_parser : Optional [FormulaParser ] = None ,
252
266
_nested_parser : Optional [FormulaParser ] = None ,
253
267
_ordering : Union [OrderingMethod , str ] = OrderingMethod .DEGREE ,
268
+ _context : Optional [Mapping [str , Any ]] = None ,
254
269
** structure : FormulaSpec ,
255
270
):
256
271
"""
@@ -288,7 +303,7 @@ def get_model_matrix(
288
303
@abstractmethod
289
304
def required_variables (self ) -> Set [Variable ]:
290
305
"""
291
- The set of variables required in the data order to materialize this
306
+ The set of variables required to be in the data to materialize this
292
307
formula.
293
308
294
309
Attempts are made to restrict these variables only to those expected in
@@ -354,6 +369,7 @@ def __init__(
354
369
_ordering : Union [OrderingMethod , str ] = OrderingMethod .DEGREE ,
355
370
_parser : Optional [FormulaParser ] = None ,
356
371
_nested_parser : Optional [FormulaParser ] = None ,
372
+ _context : Optional [Mapping [str , Any ]] = None ,
357
373
** structure : FormulaSpec ,
358
374
):
359
375
if root is MISSING :
@@ -667,19 +683,22 @@ class StructuredFormula(Structured[SimpleFormula], Formula):
667
683
formula specifications. Can be: "none", "degree" (default), or "sort".
668
684
"""
669
685
670
- __slots__ = ("_parser" , "_nested_parser" , "_ordering" )
686
+ __slots__ = ("_parser" , "_nested_parser" , "_ordering" , "_context" )
671
687
672
688
def __init__ (
673
689
self ,
674
690
root : Union [FormulaSpec , _MissingType ] = MISSING ,
691
+ * ,
692
+ _ordering : Union [OrderingMethod , str ] = OrderingMethod .DEGREE ,
675
693
_parser : Optional [FormulaParser ] = None ,
676
694
_nested_parser : Optional [FormulaParser ] = None ,
677
- _ordering : Union [ OrderingMethod , str ] = OrderingMethod . DEGREE ,
695
+ _context : Optional [ Mapping [ str , Any ]] = None ,
678
696
** structure : FormulaSpec ,
679
697
):
698
+ self ._ordering = OrderingMethod (_ordering )
680
699
self ._parser = _parser or DEFAULT_PARSER
681
700
self ._nested_parser = _nested_parser or _parser or DEFAULT_NESTED_PARSER
682
- self ._ordering = OrderingMethod ( _ordering )
701
+ self ._context = _context
683
702
super ().__init__ (root , ** structure ) # type: ignore
684
703
self ._simplify (unwrap = False , inplace = True )
685
704
@@ -704,6 +723,7 @@ def _prepare_item( # type: ignore[override]
704
723
ordering = self ._ordering ,
705
724
parser = (self ._parser if key == "root" else self ._nested_parser ),
706
725
nested_parser = self ._nested_parser ,
726
+ context = self ._context ,
707
727
)
708
728
709
729
def get_model_matrix (
@@ -782,3 +802,9 @@ def differentiate( # pylint: disable=redefined-builtin
782
802
SimpleFormula ,
783
803
self ._map (lambda formula : formula .differentiate (* wrt , use_sympy = use_sympy )),
784
804
)
805
+
806
+ # Ensure pickling never includes context
807
+ def __getstate__ (self ) -> Tuple [None , Dict [str , Any ]]:
808
+ state = cast (Tuple [None , Dict [str , Any ]], super ().__getstate__ ())
809
+ state [1 ]["_context" ] = None
810
+ return state
0 commit comments