Commit 0d93416

Add support for the . operator.
1 parent 6fdce0c commit 0d93416

File tree

9 files changed (+194, -23 lines)

docsite/docs/guides/grammar.md

Lines changed: 7 additions & 4 deletions

@@ -26,10 +26,12 @@ unless otherwise indicated.
 | `{...}`[^1] | 1 | Quotes python operations, as a more convenient way to do Python operations than `I(...)`, e.g. `` {`my|col`**2} `` ||||
 | `<function>(...)`[^1] | 1 | Python transform on column, e.g. `my_func(x)` which is equivalent to `{my_func(x)}` |[^2] |||
 |-----|
-| `(...)` | 1 | Groups operations, overriding normal precedence rules. All operations within the parentheses are performed before the result of these operations is permitted to be operated upon by its peers. || ||
+| `(...)` | 1 | Groups operations, overriding normal precedence rules. All operations within the parentheses are performed before the result of these operations is permitted to be operated upon by its peers. || ||
 |-----|
-| ** | 2 | Includes all n-th order interactions of the terms in the left operand, where n is the (integral) value of the right operand, e.g. `(a+b+c)**2` is equivalent to `a + b + c + a:b + a:c + b:c`. ||||
-| ^ | 2 | Alias for `**`. ||[^3] ||
+| `.`[^9] | 0 | Stands in as a wild-card for the sum of variables in the data not used on the left-hand side of a formula. ||||
+|-----|
+| `**` | 2 | Includes all n-th order interactions of the terms in the left operand, where n is the (integral) value of the right operand, e.g. `(a+b+c)**2` is equivalent to `a + b + c + a:b + a:c + b:c`. ||||
+| `^` | 2 | Alias for `**`. ||[^3] ||
 |-----|
 | `:` | 2 | Adds a new term that corresponds to the interaction of its operands (i.e. their elementwise product). |[^4] |||
 |-----|
@@ -123,4 +125,5 @@ and conventions of which you should be aware.
 [^5]: This somewhat confusing operator is useful when you want to include hierarchical features in your data, and where certain interaction terms do not make sense (particularly in ANOVA contexts). For example, if `a` represents countries, and `b` represents cities, then the full product of terms from `a * b === a + b + a:b` does not make sense, because any value of `b` is guaranteed to coincide with a value in `a`, and does not independently add value. Thus, the operation `a / b === a + a:b` results in a more sensible dataset. As a result, the `/` operator is right-distributive, since if `b` and `c` were both nested in `a`, you would want `a/(b+c) === a + a:b + a:c`. Likewise, the operator is not left-distributive, since if `c` is nested under both `a` and `b` separately, then you want `(a + b)/c === a + b + a:b:c`. Lastly, if `c` is nested in `b`, and `b` is nested in `a`, then you would want `a/b/c === a + a:(b/c) === a + a:b + a:b:c`.
 [^6]: Implemented by an R package called [Formula](https://cran.r-project.org/web/packages/Formula/index.html) that extends the default formula syntax.
 [^7]: Patsy uses the `rescale` keyword rather than `scale`, but provides the same functionality.
-[^8]: For increased compatibility with patsy, we use patsy's signature for `standardize`.
+[^8]: For increased compatibility with patsy, we use patsy's signature for `standardize`.
+[^9]: Requires additional context to be passed in when directly using the `Formula` constructor, e.g. `Formula("y ~ .", context={"__formulaic_variables_available__": ["x", "y", "z"]})`; or you can use `model_matrix`, `ModelSpec.get_model_matrix()`, or `FormulaMaterializer.get_model_matrix()` without further specification.
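To illustrate the new wildcard end to end, here is a minimal usage sketch based on the examples in footnote [^9]; the data frame and its column names are hypothetical:

```python
import pandas
from formulaic import Formula, model_matrix

df = pandas.DataFrame({"y": [0, 1, 2], "x": [1, 2, 3], "z": [4, 5, 6]})

# With `model_matrix`, the available variables are inferred from the data,
# so `.` expands to every column not used on the left-hand side:
mm = model_matrix("y ~ .", df)  # equivalent to "y ~ x + z"

# When using the `Formula` constructor directly, the available variables
# must be passed in explicitly via the parsing context:
formula = Formula(
    "y ~ .", context={"__formulaic_variables_available__": ["x", "y", "z"]}
)
```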

formulaic/parser/algos/sanitize_tokens.py

Lines changed: 2 additions & 0 deletions

@@ -15,6 +15,8 @@ def sanitize_tokens(tokens: Iterable[Token]) -> Iterable[Token]:
     - possibly more in the future
     """
     for token in tokens:
+        if token.token == ".":  # noqa: S105
+            token.kind = Token.Kind.OPERATOR
         if token.kind is Token.Kind.PYTHON:
             token.token = sanitize_python_code(token.token)
         yield token
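A quick sketch of what this re-tagging buys us (assumed behavior, inferred from the change above): the tokenizer emits `.` as a regular token, and the sanitizer promotes it to an operator token so that the operator resolver can match the new nullary `.` operator.

```python
from formulaic.parser.algos.sanitize_tokens import sanitize_tokens
from formulaic.parser.algos.tokenize import tokenize
from formulaic.parser.types import Token

# After sanitization, the trailing "." is an operator token rather than a
# name/value token, and so participates in operator resolution.
tokens = list(sanitize_tokens(tokenize("y ~ .")))
assert tokens[-1].token == "."
assert tokens[-1].kind is Token.Kind.OPERATOR
```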

formulaic/parser/parser.py

Lines changed: 50 additions & 1 deletion

@@ -11,6 +11,7 @@
     Generator,
     Iterable,
     List,
+    Mapping,
     MutableMapping,
     Set,
     Tuple,
@@ -20,11 +21,13 @@

 from typing_extensions import Self

-from formulaic.parser.types.ast_node import ASTNode
+from formulaic.errors import FormulaParsingError
+from formulaic.utils.layered_mapping import LayeredMapping

 from .algos.sanitize_tokens import sanitize_tokens
 from .algos.tokenize import tokenize
 from .types import (
+    ASTNode,
     Factor,
     FormulaParser,
     Operator,
@@ -149,6 +152,7 @@ def get_tokens_from_formula(
                 [token_one],
                 kind=Token.Kind.OPERATOR,
                 join_operator="+",
+                no_join_for_operators={"+", "-"},
             )
         )

@@ -191,9 +195,16 @@ def find_rhs_index(tokens: List[Token]) -> int:
                 [token_one],
                 kind=Token.Kind.OPERATOR,
                 join_operator="+",
+                no_join_for_operators={"+", "-"},
             ),
         ]

+        context["__formulaic_variables_used_lhs__"] = [
+            variable
+            for token in tokens[:rhs_index]
+            for variable in token.required_variables
+        ]
+
         # Collapse inserted "+" and "-" operators to prevent unary issues.
         tokens = merge_operator_tokens(tokens, symbols={"+", "-"})

@@ -356,6 +367,37 @@ def get_terms(terms: OrderedSet[Term]) -> List[Term]:

             return Structured(get_terms(lhs), deps=(Structured(lhs=lhs, rhs=rhs),))

+        def insert_unused_terms(context: Mapping[str, Any]) -> OrderedSet[Term]:
+            available_variables: OrderedSet[str]
+            used_variables: Set[str] = set(context["__formulaic_variables_used_lhs__"])
+
+            # Populate `available_variables` or raise.
+            if "__formulaic_variables_available__" in context:
+                available_variables = OrderedSet(
+                    context["__formulaic_variables_available__"]
+                )
+            elif isinstance(context, LayeredMapping) and "data" in context.named_layers:
+                available_variables = OrderedSet(context.named_layers["data"])
+            else:
+                raise FormulaParsingError(
+                    "The `.` operator requires additional context about which "
+                    "variables are available to use. This can be provided by "
+                    "passing in a value for `__formulaic_variables_available__` "
+                    "in the context while parsing the formula; by passing the "
+                    "formula to the materializer's `.get_model_matrix()` method; "
+                    "or by passing a `LayeredMapping` instance as the context "
+                    "with a `data` layer containing the available variables "
+                    "(such as the `.layered_context` from a "
+                    "`FormulaMaterializer` instance)."
+                )
+
+            unused_variables = available_variables - used_variables
+
+            return OrderedSet(
+                Term([Factor(variable, eval_method="lookup")])
+                for variable in unused_variables
+            )
+
         return [
             Operator(
                 "~",
@@ -474,6 +516,13 @@ def get_terms(terms: OrderedSet[Term]) -> List[Term]:
             Operator(
                 "^", arity=2, precedence=500, associativity="right", to_terms=power
             ),
+            Operator(
+                ".",
+                arity=0,
+                precedence=1000,
+                fixity="postfix",
+                to_terms=insert_unused_terms,
+            ),
         ]

     def resolve(
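To summarize the three resolution paths implemented by `insert_unused_terms` above, here is a hedged usage sketch (mirroring the tests later in this commit; the variable names are illustrative):

```python
from formulaic.errors import FormulaParsingError
from formulaic.parser import DefaultFormulaParser
from formulaic.utils.layered_mapping import LayeredMapping

parser = DefaultFormulaParser(feature_flags={"all"})

# 1. An explicit list of available variables in the parsing context:
parser.get_terms(
    "y ~ .", context={"__formulaic_variables_available__": ["x", "y", "z"]}
)  # the right-hand side expands to x + z

# 2. A `LayeredMapping` context with a "data" layer, as materializers supply:
parser.get_terms(".", context=LayeredMapping({"x": 1, "z": 2}, name="data"))

# 3. No usable context: parsing raises FormulaParsingError.
try:
    parser.get_terms(".")
except FormulaParsingError:
    pass  # "The `.` operator requires additional context ..."
```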

formulaic/parser/utils.py

Lines changed: 32 additions & 13 deletions

@@ -71,9 +71,11 @@ def __get_token_for_ast(ast: Union[Token, ASTNode]) -> Token: # pragma: no cove
     while isinstance(rhs_token, ASTNode):
         rhs_token = rhs_token.args[-1]  # type: ignore
     return Token(
-        token=lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1]
-        if lhs_token.source
-        else "",
+        token=(
+            lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1]
+            if lhs_token.source
+            else ""
+        ),
         source=lhs_token.source,
         source_start=lhs_token.source_start,
         source_end=rhs_token.source_end,
@@ -93,19 +95,29 @@ def __get_tokens_for_gap(
     """
     lhs_token = lhs
     while isinstance(lhs_token, ASTNode):
-        lhs_token = lhs_token.args[-1]  # type: ignore
+        lhs_token = (
+            lhs_token.args[-1]  # type: ignore
+            if lhs_token.args
+            else Token(lhs_token.operator.symbol)
+        )
     rhs_token = rhs or lhs
     while isinstance(rhs_token, ASTNode):
-        rhs_token = rhs_token.args[0]  # type: ignore
+        rhs_token = (
+            rhs_token.args[0]  # type: ignore
+            if rhs_token.args
+            else Token(rhs_token.operator.symbol)
+        )
     return (
         lhs_token,
         rhs_token,
         Token(
-            lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1]
-            if lhs_token.source
-            and lhs_token.source_start is not None
-            and rhs_token.source_end is not None
-            else "",
+            (
+                lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1]
+                if lhs_token.source
+                and lhs_token.source_start is not None
+                and rhs_token.source_end is not None
+                else ""
+            ),
             source=lhs_token.source,
             source_start=lhs_token.source_start,
             source_end=rhs_token.source_end,
@@ -154,6 +166,7 @@ def insert_tokens_after(
     *,
     kind: Optional[Token.Kind] = None,
     join_operator: Optional[str] = None,
+    no_join_for_operators: Union[bool, Set[str]] = True,
 ) -> Iterable[Token]:
     """
     Insert additional tokens into a sequence of tokens after (within token)
@@ -177,6 +190,10 @@ def insert_tokens_after(
             the added tokens with existing tokens, the value set here will be
             used to create a joining operator token. If not provided, no
             additional operators are added.
+        no_join_for_operators: Whether to suppress the join operator when the
+            next token is an operator token; or a set of operator symbols for
+            which the join operator should be skipped.
+
     """
     tokens = list(tokens)

@@ -205,9 +222,11 @@ def insert_tokens_after(
                 next_token = split_tokens[j + 1]
             elif i < len(tokens) - 1:
                 next_token = tokens[i + 1]
-            if (
-                next_token is not None
-                and next_token.kind is not Token.Kind.OPERATOR
+            if next_token is not None and (
+                next_token.kind is not Token.Kind.OPERATOR
+                or no_join_for_operators is False
+                or isinstance(no_join_for_operators, set)
+                and next_token.token not in no_join_for_operators
             ):
                 yield Token(join_operator, kind=Token.Kind.OPERATOR)

tests/materializers/test_pandas.py

Lines changed: 28 additions & 2 deletions

@@ -67,6 +67,23 @@
         ["Intercept"],
         1,
     ),
+    ".": (
+        ["Intercept", "a", "b", "A[T.b]", "A[T.c]", "B[T.b]", "B[T.c]"],
+        [
+            "Intercept",
+            "a",
+            "b",
+            "A[a]",
+            "A[b]",
+            "A[c]",
+            "B[a]",
+            "B[b]",
+            "B[c]",
+            "D[a]",
+        ],
+        ["Intercept", "a", "b"],
+        1,
+    ),
 }


@@ -86,7 +103,13 @@ def data(self):
     @pytest.fixture
     def data_with_nulls(self):
         return pandas.DataFrame(
-            {"a": [1, 2, None], "A": ["a", None, "c"], "B": ["a", "b", None]}
+            {
+                "a": [1, 2, None],
+                "b": [1, 2, 3],
+                "A": ["a", None, "c"],
+                "B": ["a", "b", None],
+                "D": ["a", "a", "a"],
+            }
         )

     @pytest.fixture
@@ -182,7 +205,10 @@ def test_na_handling(self, data_with_nulls, formula, tests, output):
             formula, na_action="ignore"
         )
         assert isinstance(mm, pandas.DataFrame)
-        assert mm.shape == (3, len(tests[0]) + (-1 if "A" in formula else 0))
+        if formula == ".":
+            assert mm.shape == (3, 5)
+        else:
+            assert mm.shape == (3, len(tests[0]) + (-1 if "A" in formula else 0))

         if formula != "C(A)":  # C(A) pre-encodes the data, stripping out nulls.
             with pytest.raises(ValueError):

tests/parser/test_parser.py

Lines changed: 26 additions & 1 deletion

@@ -10,6 +10,7 @@
 from formulaic.parser import DefaultFormulaParser, DefaultOperatorResolver
 from formulaic.parser.types import Structured, Token
 from formulaic.parser.types.term import Term
+from formulaic.utils.layered_mapping import LayeredMapping

 FORMULA_TO_TOKENS = {
     "": ["1"],
@@ -133,12 +134,21 @@
     # Quoting
     "`a|b~c*d`": ["1", "a|b~c*d"],
     "{a | b | c}": ["1", "a | b | c"],
+    # Wildcards
+    ".": ["1", "a", "b", "c"],
+    ".^2": ["1", "a", "a:b", "a:c", "b", "b:c", "c"],
+    ".^2 - a:b": ["1", "a", "a:c", "b", "b:c", "c"],
+    "a ~ .": {
+        "lhs": ["a"],
+        "rhs": ["1", "b", "c"],
+    },
 }

 PARSER = DefaultFormulaParser(feature_flags={"all"})
 PARSER_NO_INTERCEPT = DefaultFormulaParser(
     include_intercept=False, feature_flags={"all"}
 )
+PARSER_CONTEXT = {"__formulaic_variables_available__": ["a", "b", "c"]}


 class TestFormulaParser:
@@ -148,7 +158,9 @@ class TestFormulaParser:

     @pytest.mark.parametrize("formula,terms", FORMULA_TO_TERMS.items())
     def test_to_terms(self, formula, terms):
-        generated_terms: Structured[List[Term]] = PARSER.get_terms(formula)
+        generated_terms: Structured[List[Term]] = PARSER.get_terms(
+            formula, context=PARSER_CONTEXT
+        )
         if generated_terms._has_keys:
             comp = generated_terms._map(list)._to_dict()
         elif generated_terms._has_root and isinstance(generated_terms.root, tuple):
@@ -280,6 +292,19 @@ def test_invalid_multistage_formula(self):
         ):
             DefaultFormulaParser(feature_flags={"all"}).get_terms("[[a ~ b] ~ c]")

+    def test_alternative_wildcard_usage(self):
+        assert PARSER.get_terms(
+            ".", context=LayeredMapping({"a": 1, "b": 2}, name="data")
+        ) == ["1", "a", "b"]
+
+        with pytest.raises(
+            FormulaParsingError,
+            match=re.escape(
+                "The `.` operator requires additional context about which "
+            ),
+        ):
+            PARSER.get_terms(".")
+

 class TestDefaultOperatorResolver:
     @pytest.fixture

tests/parser/test_utils.py

Lines changed: 37 additions & 2 deletions

@@ -1,9 +1,9 @@
-from ntpath import join
-
 import pytest

+from formulaic.errors import FormulaSyntaxError
 from formulaic.parser.types import Token
 from formulaic.parser.utils import (
+    exc_for_token,
     insert_tokens_after,
     merge_operator_tokens,
     replace_tokens,
@@ -29,6 +29,15 @@ def test_replace_tokens(tokens):
     ]


+def test_exc_for_token(tokens):
+    with pytest.raises(FormulaSyntaxError, match="Hello World"):
+        raise exc_for_token(tokens[0], "Hello World")
+    with pytest.raises(FormulaSyntaxError, match="Hello World"):
+        raise exc_for_token(
+            Token("h", source="hi", source_start=0, source_end=1), "Hello World"
+        )
+
+
 def test_insert_tokens_after(tokens):
     assert list(
         insert_tokens_after(
@@ -50,6 +59,32 @@ def test_insert_tokens_after(tokens):
             join_operator="+",
         )
     ) == ["1", "+|", "hi", "-", "field"]
+    assert list(
+        insert_tokens_after(
+            [
+                Token("1", kind=Token.Kind.VALUE),
+                Token("+|-", kind=Token.Kind.OPERATOR),
+                Token("field", kind=Token.Kind.NAME),
+            ],
+            r"\|",
+            [Token("hi", kind=Token.Kind.NAME)],
+            join_operator="+",
+            no_join_for_operators=False,
+        )
+    ) == ["1", "+|", "hi", "+", "-", "field"]
+    assert list(
+        insert_tokens_after(
+            [
+                Token("1", kind=Token.Kind.VALUE),
+                Token("+|-", kind=Token.Kind.OPERATOR),
+                Token("field", kind=Token.Kind.NAME),
+            ],
+            r"\|",
+            [Token("hi", kind=Token.Kind.NAME)],
+            join_operator="+",
+            no_join_for_operators={"+", "-"},
+        )
+    ) == ["1", "+|", "hi", "-", "field"]
     assert list(
         insert_tokens_after(
             [
