Add support for the . operator.
matthewwardrop committed Dec 2, 2024
1 parent cbd47b3 commit d6c3e4c
Showing 9 changed files with 194 additions and 23 deletions.
11 changes: 7 additions & 4 deletions docsite/docs/guides/grammar.md
@@ -26,10 +26,12 @@ unless otherwise indicated.
| `{...}`[^1] | 1 | Quotes python operations, as a more convenient way to do Python operations than `I(...)`, e.g. `` {`my|col`**2} `` ||||
| `<function>(...)`[^1] | 1 | Python transform on column, e.g. `my_func(x)` which is equivalent to `{my_func(x)}` |[^2] |||
|-----|
| `(...)` | 1 | Groups operations, overriding normal precedence rules. All operations with the parentheses are performed before the result of these operations is permitted to be operated upon by its peers. || ||
|-----|
| ** | 2 | Includes all n-th order interactions of the terms in the left operand, where n is the (integral) value of the right operand, e.g. `(a+b+c)**2` is equivalent to `a + b + c + a:b + a:c + b:c`. ||||
| ^ | 2 | Alias for `**`. ||[^3] ||
| `.`[^9] | 0 | Stands in as a wildcard for the sum of variables in the data not used on the left-hand side of a formula. ||||
|-----|
| `**` | 2 | Includes all n-th order interactions of the terms in the left operand, where n is the (integral) value of the right operand, e.g. `(a+b+c)**2` is equivalent to `a + b + c + a:b + a:c + b:c`. ||||
| `^` | 2 | Alias for `**`. ||[^3] ||
|-----|
| `:` | 2 | Adds a new term that corresponds to the interaction of its operands (i.e. their elementwise product). |[^4] |||
|-----|
@@ -123,4 +125,5 @@ and conventions of which you should be aware.
[^5]: This somewhat confusing operator is useful when you want to include hierarchical features in your data, and where certain interaction terms do not make sense (particularly in ANOVA contexts). For example, if `a` represents countries, and `b` represents cities, then the full product of terms from `a * b === a + b + a:b` does not make sense, because any value of `b` is guaranteed to coincide with a value in `a`, and does not independently add value. Thus, the operation `a / b === a + a:b` results in a more sensible dataset. As a result, the `/` operator is right-distributive, since if `b` and `c` were both nested in `a`, you would want `a/(b+c) === a + a:b + a:c`. Likewise, the operator is not left-distributive, since if `c` is nested under both `a` and `b` separately, then you want `(a + b)/c === a + b + a:b:c`. Lastly, if `c` is nested in `b`, and `b` is nested in `a`, then you would want `a/b/c === a + a:(b/c) === a + a:b + a:b:c`.
[^6]: Implemented by an R package called [Formula](https://cran.r-project.org/web/packages/Formula/index.html) that extends the default formula syntax.
[^7]: Patsy uses the `rescale` keyword rather than `scale`, but provides the same functionality.
[^8]: For increased compatibility with patsy, we use patsy's signature for `standardize`.
[^9]: Requires additional context to be passed in when directly using the `Formula` constructor, e.g. `Formula("y ~ .", context={"__formulaic_variables_available__": ["x", "y", "z"]})`; alternatively, you can use `model_matrix`, `ModelSpec.get_model_matrix()`, or `FormulaMaterializer.get_model_matrix()` without further specification.
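For illustration, a minimal sketch of the two usage modes described in [^9] (the DataFrame and its column names here are hypothetical):

```python
import pandas
from formulaic import Formula, model_matrix

data = pandas.DataFrame({"y": [0, 1, 0], "x": [1.0, 2.0, 3.0], "z": [3, 2, 1]})

# With `model_matrix`, the available variables are inferred from the data, so
# no extra context is needed; `.` expands to the unused columns (x and z).
mm = model_matrix("y ~ .", data)

# With the `Formula` constructor, the available variables must be passed in
# explicitly via the parsing context.
formula = Formula(
    "y ~ .", context={"__formulaic_variables_available__": ["x", "y", "z"]}
)
```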
2 changes: 2 additions & 0 deletions formulaic/parser/algos/sanitize_tokens.py
@@ -15,6 +15,8 @@ def sanitize_tokens(tokens: Iterable[Token]) -> Iterable[Token]:
- possibly more in the future
"""
for token in tokens:
if token.token == ".": # noqa: S105
token.kind = Token.Kind.OPERATOR
if token.kind is Token.Kind.PYTHON:
token.token = sanitize_python_code(token.token)
yield token
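A rough sketch of the effect of this change, assuming the internal `tokenize`/`sanitize_tokens` helpers are invoked directly (these are internal APIs, not part of the public interface):

```python
from formulaic.parser.algos.sanitize_tokens import sanitize_tokens
from formulaic.parser.algos.tokenize import tokenize
from formulaic.parser.types import Token

# The tokenizer does not treat a bare "." as an operator; the sanitization
# pass above re-tags it so the operator resolver can handle it.
tokens = list(sanitize_tokens(tokenize("y ~ .")))
assert tokens[-1].token == "."
assert tokens[-1].kind is Token.Kind.OPERATOR
```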
51 changes: 50 additions & 1 deletion formulaic/parser/parser.py
@@ -11,6 +11,7 @@
Generator,
Iterable,
List,
Mapping,
MutableMapping,
Set,
Tuple,
@@ -20,11 +21,13 @@

from typing_extensions import Self

from formulaic.parser.types.ast_node import ASTNode
from formulaic.errors import FormulaParsingError
from formulaic.utils.layered_mapping import LayeredMapping

from .algos.sanitize_tokens import sanitize_tokens
from .algos.tokenize import tokenize
from .types import (
ASTNode,
Factor,
FormulaParser,
Operator,
@@ -149,6 +152,7 @@ def get_tokens_from_formula(
[token_one],
kind=Token.Kind.OPERATOR,
join_operator="+",
no_join_for_operators={"+", "-"},
)
)

@@ -191,9 +195,16 @@ def find_rhs_index(tokens: List[Token]) -> int:
[token_one],
kind=Token.Kind.OPERATOR,
join_operator="+",
no_join_for_operators={"+", "-"},
),
]

context["__formulaic_variables_used_lhs__"] = [
variable
for token in tokens[:rhs_index]
for variable in token.required_variables
]

# Collapse inserted "+" and "-" operators to prevent unary issues.
tokens = merge_operator_tokens(tokens, symbols={"+", "-"})

@@ -356,6 +367,37 @@ def get_terms(terms: OrderedSet[Term]) -> List[Term]:

return Structured(get_terms(lhs), deps=(Structured(lhs=lhs, rhs=rhs),))

def insert_unused_terms(context: Mapping[str, Any]) -> OrderedSet[Term]:
available_variables: OrderedSet[str]
used_variables: Set[str] = set(context["__formulaic_variables_used_lhs__"])

# Populate `available_variables` or raise.
if "__formulaic_variables_available__" in context:
available_variables = OrderedSet(
context["__formulaic_variables_available__"]
)
elif isinstance(context, LayeredMapping) and "data" in context.named_layers:
available_variables = OrderedSet(context.named_layers["data"])
else:
raise FormulaParsingError(
"The `.` operator requires additional context about which "
"variables are available to use. This can be provided by "
"passing in a value for `__formulaic_variables_available__`"
"in the context while parsing the formula; by passing the "
"formula to the materializer's `.get_model_matrix()` method; "
"or by passing a `LayeredMapping` instance as the context "
"with a `data` layer containing the available variables "
"(such as the `.layered_context` from a "
"`FormulaMaterializer` instance)."
)

unused_variables = available_variables - used_variables

return OrderedSet(
Term([Factor(variable, eval_method="lookup")])
for variable in unused_variables
)

return [
Operator(
"~",
@@ -474,6 +516,13 @@ def get_terms(terms: OrderedSet[Term]) -> List[Term]:
Operator(
"^", arity=2, precedence=500, associativity="right", to_terms=power
),
Operator(
".",
arity=0,
precedence=1000,
fixity="postfix",
to_terms=insert_unused_terms,
),
]

def resolve(
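A usage sketch of the new operator at the parser level, mirroring the tests later in this commit (`feature_flags={"all"}` is copied from those tests and may not be strictly required here):

```python
from formulaic.parser import DefaultFormulaParser
from formulaic.utils.layered_mapping import LayeredMapping

parser = DefaultFormulaParser(feature_flags={"all"})

# Enumerate the available variables explicitly...
terms = parser.get_terms(
    "y ~ .", context={"__formulaic_variables_available__": ["x", "y", "z"]}
)  # rhs expands to: 1 + x + z

# ... or supply a LayeredMapping with a "data" layer (as materializers do).
terms = parser.get_terms(
    ".", context=LayeredMapping({"x": 1, "z": 2}, name="data")
)  # expands to: 1 + x + z
```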
45 changes: 32 additions & 13 deletions formulaic/parser/utils.py
@@ -71,9 +71,11 @@ def __get_token_for_ast(ast: Union[Token, ASTNode]) -> Token:  # pragma: no cover
while isinstance(rhs_token, ASTNode):
rhs_token = rhs_token.args[-1] # type: ignore
return Token(
token=lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1]
if lhs_token.source
else "",
token=(
lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1]
if lhs_token.source
else ""
),
source=lhs_token.source,
source_start=lhs_token.source_start,
source_end=rhs_token.source_end,
@@ -93,19 +95,29 @@ def __get_tokens_for_gap(
"""
lhs_token = lhs
while isinstance(lhs_token, ASTNode):
lhs_token = lhs_token.args[-1] # type: ignore
lhs_token = (
lhs_token.args[-1] # type: ignore
if lhs_token.args
else Token(lhs_token.operator.symbol)
)
rhs_token = rhs or lhs
while isinstance(rhs_token, ASTNode):
rhs_token = rhs_token.args[0] # type: ignore
rhs_token = (
rhs_token.args[0] # type: ignore
if rhs_token.args
else Token(rhs_token.operator.symbol)
)
return (
lhs_token,
rhs_token,
Token(
lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1]
if lhs_token.source
and lhs_token.source_start is not None
and rhs_token.source_end is not None
else "",
(
lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1]
if lhs_token.source
and lhs_token.source_start is not None
and rhs_token.source_end is not None
else ""
),
source=lhs_token.source,
source_start=lhs_token.source_start,
source_end=rhs_token.source_end,
@@ -154,6 +166,7 @@ def insert_tokens_after(
*,
kind: Optional[Token.Kind] = None,
join_operator: Optional[str] = None,
no_join_for_operators: Union[bool, Set[str]] = True,
) -> Iterable[Token]:
"""
Insert additional tokens into a sequence of tokens after (within token)
@@ -177,6 +190,10 @@
the added tokens with existing tokens, the value set here will be
used to create a joining operator token. If not provided, no
additional operators are added.
no_join_for_operators: Whether to skip adding the join operator when the
next token is an operator token; or a set of operator symbols before
which the join operator should not be added.
"""
tokens = list(tokens)

@@ -205,9 +222,11 @@
next_token = split_tokens[j + 1]
elif i < len(tokens) - 1:
next_token = tokens[i + 1]
if (
next_token is not None
and next_token.kind is not Token.Kind.OPERATOR
if next_token is not None and (
next_token.kind is not Token.Kind.OPERATOR
or no_join_for_operators is False
or isinstance(no_join_for_operators, set)
and next_token.token not in no_join_for_operators
):
yield Token(join_operator, kind=Token.Kind.OPERATOR)

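Distilled from the tests below, a quick sketch of how `no_join_for_operators` changes the joining behaviour:

```python
from formulaic.parser.types import Token
from formulaic.parser.utils import insert_tokens_after

tokens = [
    Token("1", kind=Token.Kind.VALUE),
    Token("+|-", kind=Token.Kind.OPERATOR),
    Token("field", kind=Token.Kind.NAME),
]
new = [Token("hi", kind=Token.Kind.NAME)]

# Default (True): never insert the join operator before an operator token.
assert list(
    insert_tokens_after(tokens, r"\|", new, join_operator="+")
) == ["1", "+|", "hi", "-", "field"]

# False: always insert it, even before the following "-" operator.
assert list(
    insert_tokens_after(
        tokens, r"\|", new, join_operator="+", no_join_for_operators=False
    )
) == ["1", "+|", "hi", "+", "-", "field"]
```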
30 changes: 28 additions & 2 deletions tests/materializers/test_pandas.py
@@ -67,6 +67,23 @@
["Intercept"],
1,
),
".": (
["Intercept", "a", "b", "A[T.b]", "A[T.c]", "B[T.b]", "B[T.c]"],
[
"Intercept",
"a",
"b",
"A[a]",
"A[b]",
"A[c]",
"B[a]",
"B[b]",
"B[c]",
"D[a]",
],
["Intercept", "a", "b"],
1,
),
}


@@ -86,7 +103,13 @@ def data(self):
@pytest.fixture
def data_with_nulls(self):
return pandas.DataFrame(
{"a": [1, 2, None], "A": ["a", None, "c"], "B": ["a", "b", None]}
{
"a": [1, 2, None],
"b": [1, 2, 3],
"A": ["a", None, "c"],
"B": ["a", "b", None],
"D": ["a", "a", "a"],
}
)

@pytest.fixture
@@ -182,7 +205,10 @@ def test_na_handling(self, data_with_nulls, formula, tests, output):
formula, na_action="ignore"
)
assert isinstance(mm, pandas.DataFrame)
assert mm.shape == (3, len(tests[0]) + (-1 if "A" in formula else 0))
if formula == ".":
assert mm.shape == (3, 5)
else:
assert mm.shape == (3, len(tests[0]) + (-1 if "A" in formula else 0))

if formula != "C(A)": # C(A) pre-encodes the data, stripping out nulls.
with pytest.raises(ValueError):
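End to end, the new test entry implies behaviour along these lines (a sketch with hypothetical data; the exact column names assume formulaic's usual reduced-rank categorical coding):

```python
import pandas
from formulaic import model_matrix

data = pandas.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "A": ["a", "b", "c"]})

# "." expands to the variables unused on the left-hand side: b and A.
mm = model_matrix("a ~ .", data)
# Expected rhs columns (roughly): Intercept, b, A[T.b], A[T.c]
```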
27 changes: 26 additions & 1 deletion tests/parser/test_parser.py
@@ -10,6 +10,7 @@
from formulaic.parser import DefaultFormulaParser, DefaultOperatorResolver
from formulaic.parser.types import Structured, Token
from formulaic.parser.types.term import Term
from formulaic.utils.layered_mapping import LayeredMapping

FORMULA_TO_TOKENS = {
"": ["1"],
@@ -133,12 +134,21 @@
# Quoting
"`a|b~c*d`": ["1", "a|b~c*d"],
"{a | b | c}": ["1", "a | b | c"],
# Wildcards
".": ["1", "a", "b", "c"],
".^2": ["1", "a", "a:b", "a:c", "b", "b:c", "c"],
".^2 - a:b": ["1", "a", "a:c", "b", "b:c", "c"],
"a ~ .": {
"lhs": ["a"],
"rhs": ["1", "b", "c"],
},
}

PARSER = DefaultFormulaParser(feature_flags={"all"})
PARSER_NO_INTERCEPT = DefaultFormulaParser(
include_intercept=False, feature_flags={"all"}
)
PARSER_CONTEXT = {"__formulaic_variables_available__": ["a", "b", "c"]}


class TestFormulaParser:
@@ -148,7 +158,9 @@ class TestFormulaParser:

@pytest.mark.parametrize("formula,terms", FORMULA_TO_TERMS.items())
def test_to_terms(self, formula, terms):
generated_terms: Structured[List[Term]] = PARSER.get_terms(formula)
generated_terms: Structured[List[Term]] = PARSER.get_terms(
formula, context=PARSER_CONTEXT
)
if generated_terms._has_keys:
comp = generated_terms._map(list)._to_dict()
elif generated_terms._has_root and isinstance(generated_terms.root, tuple):
@@ -280,6 +292,19 @@ def test_invalid_multistage_formula(self):
):
DefaultFormulaParser(feature_flags={"all"}).get_terms("[[a ~ b] ~ c]")

def test_alternative_wildcard_usage(self):
assert PARSER.get_terms(
".", context=LayeredMapping({"a": 1, "b": 2}, name="data")
) == ["1", "a", "b"]

with pytest.raises(
FormulaParsingError,
match=re.escape(
"The `.` operator requires additional context about which "
),
):
PARSER.get_terms(".")


class TestDefaultOperatorResolver:
@pytest.fixture
39 changes: 37 additions & 2 deletions tests/parser/test_utils.py
@@ -1,9 +1,9 @@
from ntpath import join

import pytest

from formulaic.errors import FormulaSyntaxError
from formulaic.parser.types import Token
from formulaic.parser.utils import (
exc_for_token,
insert_tokens_after,
merge_operator_tokens,
replace_tokens,
@@ -29,6 +29,15 @@ def test_replace_tokens(tokens):
]


def test_exc_for_token(tokens):
with pytest.raises(FormulaSyntaxError, match="Hello World"):
raise exc_for_token(tokens[0], "Hello World")
with pytest.raises(FormulaSyntaxError, match="Hello World"):
raise exc_for_token(
Token("h", source="hi", source_start=0, source_end=1), "Hello World"
)


def test_insert_tokens_after(tokens):
assert list(
insert_tokens_after(
@@ -50,6 +59,32 @@ def test_insert_tokens_after(tokens):
join_operator="+",
)
) == ["1", "+|", "hi", "-", "field"]
assert list(
insert_tokens_after(
[
Token("1", kind=Token.Kind.VALUE),
Token("+|-", kind=Token.Kind.OPERATOR),
Token("field", kind=Token.Kind.NAME),
],
r"\|",
[Token("hi", kind=Token.Kind.NAME)],
join_operator="+",
no_join_for_operators=False,
)
) == ["1", "+|", "hi", "+", "-", "field"]
assert list(
insert_tokens_after(
[
Token("1", kind=Token.Kind.VALUE),
Token("+|-", kind=Token.Kind.OPERATOR),
Token("field", kind=Token.Kind.NAME),
],
r"\|",
[Token("hi", kind=Token.Kind.NAME)],
join_operator="+",
no_join_for_operators={"+", "-"},
)
) == ["1", "+|", "hi", "-", "field"]
assert list(
insert_tokens_after(
[