diff --git a/docsite/docs/guides/grammar.md b/docsite/docs/guides/grammar.md index fc227457..22c8dd20 100644 --- a/docsite/docs/guides/grammar.md +++ b/docsite/docs/guides/grammar.md @@ -26,10 +26,12 @@ unless otherwise indicated. | `{...}`[^1] | 1 | Quotes python operations, as a more convenient way to do Python operations than `I(...)`, e.g. `` {`my|col`**2} `` | ✓ | ✗ | ✗ | | `(...)`[^1] | 1 | Python transform on column, e.g. `my_func(x)` which is equivalent to `{my_func(x)}` | ✓[^2] | ✓ | ✗ | |-----| -| `(...)` | 1 | Groups operations, overriding normal precedence rules. All operations with the parentheses are performed before the result of these operations is permitted to be operated upon by its peers. | ✓ | ✓ | ✓ | +| `(...)` | 1 | Groups operations, overriding normal precedence rules. All operations within the parentheses are performed before the result of these operations is permitted to be operated upon by its peers. | ✓ | ✗ | ✓ | |-----| -| ** | 2 | Includes all n-th order interactions of the terms in the left operand, where n is the (integral) value of the right operand, e.g. `(a+b+c)**2` is equivalent to `a + b + c + a:b + a:c + b:c`. | ✓ | ✓ | ✓ | -| ^ | 2 | Alias for `**`. | ✓ | ✗[^3] | ✓ | +| `.`[^9] | 0 | Stands in as a wild-card for the sum of variables in the data not used on the left-hand side of a formula. | ✓ | ✗ | ✓ | +|-----| +| `**` | 2 | Includes all n-th order interactions of the terms in the left operand, where n is the (integral) value of the right operand, e.g. `(a+b+c)**2` is equivalent to `a + b + c + a:b + a:c + b:c`. | ✓ | ✓ | ✓ | +| `^` | 2 | Alias for `**`. | ✓ | ✗[^3] | ✓ | |-----| | `:` | 2 | Adds a new term that corresponds to the interaction of its operands (i.e. their elementwise product). | ✓[^4] | ✓ | ✓ | |-----| @@ -123,4 +125,5 @@ and conventions of which you should be aware. [^5]: This somewhat confusing operator is useful when you want to include hierarchical features in your data, and where certain interaction terms do not make sense (particularly in ANOVA contexts). For example, if `a` represents countries, and `b` represents cities, then the full product of terms from `a * b === a + b + a:b` does not make sense, because any value of `b` is guaranteed to coincide with a value in `a`, and does not independently add value. Thus, the operation `a / b === a + a:b` results in a more sensible dataset. As a result, the `/` operator is right-distributive, since if `b` and `c` were both nested in `a`, you would want `a/(b+c) === a + a:b + a:c`. Likewise, the operator is not left-distributive, since if `c` is nested under both `a` and `b` separately, then you want `(a + b)/c === a + b + a:b:c`. Lastly, if `c` is nested in `b`, and `b` is nested in `a`, then you would want `a/b/c === a + a:(b/c) === a + a:b + a:b:c`. [^6]: Implemented by an R package called [Formula](https://cran.r-project.org/web/packages/Formula/index.html) that extends the default formula syntax. [^7]: Patsy uses the `rescale` keyword rather than `scale`, but provides the same functionality. -[^8]: For increased compatibility with patsy, we use patsy's signature for `standardize`. \ No newline at end of file +[^8]: For increased compatibility with patsy, we use patsy's signature for `standardize`. +[^9]: Requires additional context to be passed in when directly using the `Formula` constructor. e.g. 
`Formula("y ~ .", context={"__formulaic_variables_available__": ["x", "y", "z"]})`; or you can use `model_matrix`, `ModelSpec.get_model_matrix()`, or `FormulaMaterializer.get_model_matrix()` without further specification. diff --git a/formulaic/formula.py b/formulaic/formula.py index a06611dd..723c4b85 100644 --- a/formulaic/formula.py +++ b/formulaic/formula.py @@ -74,6 +74,7 @@ def __call__( _ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE, _parser: Optional[FormulaParser] = None, _nested_parser: Optional[FormulaParser] = None, + _context: Optional[Mapping[str, Any]] = None, **structure: FormulaSpec, ) -> Formula: """ @@ -82,7 +83,7 @@ def __call__( `SimpleFormula` instance will be returned; otherwise, a `StructuredFormula`. - Some arguments a prefixed with underscores to prevent collision with + Some arguments are prefixed with underscores to prevent collision with formula structure. Args: @@ -108,6 +109,7 @@ def __call__( _ordering=_ordering, _parser=_parser, _nested_parser=_nested_parser, + _context=_context, **structure, ) return self @@ -120,13 +122,15 @@ def __call__( _parser=_parser, _nested_parser=_nested_parser, _ordering=_ordering, - **structure, - ) + _context=_context, + **structure, # type: ignore[arg-type] + )._simplify() return cls.from_spec( cast(FormulaSpec, root), ordering=_ordering, parser=_parser, nested_parser=_nested_parser, + context=_context, ) def from_spec( @@ -136,6 +140,7 @@ def from_spec( ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE, parser: Optional[FormulaParser] = None, nested_parser: Optional[FormulaParser] = None, + context: Optional[Mapping[str, Any]] = None, ) -> Union[SimpleFormula, StructuredFormula]: """ Construct a `SimpleFormula` or `StructuredFormula` instance from a @@ -164,18 +169,25 @@ def from_spec( if isinstance(spec, str): spec = cast( FormulaSpec, - (parser or DefaultFormulaParser()).get_terms(spec)._simplify(), + (parser or DefaultFormulaParser()) + .get_terms(spec, context=context) + ._simplify(), ) if isinstance(spec, dict): return StructuredFormula( - _parser=parser, _nested_parser=nested_parser, _ordering=ordering, **spec + _parser=parser, + _nested_parser=nested_parser, + _ordering=ordering, + _context=context, + **spec, # type: ignore[arg-type] ) if isinstance(spec, Structured): return StructuredFormula( _ordering=ordering, _parser=nested_parser, _nested_parser=nested_parser, + _context=context, **spec._structure, )._simplify() if isinstance(spec, tuple): @@ -184,13 +196,14 @@ def from_spec( _ordering=ordering, _parser=parser, _nested_parser=nested_parser, + _context=context, )._simplify() if isinstance(spec, (list, set, OrderedSet)): terms = [ term for value in spec for term in ( - nested_parser.get_terms(value) # type: ignore[attr-defined] + nested_parser.get_terms(value, context=context) # type: ignore[attr-defined] if isinstance(value, str) else [value] ) @@ -248,9 +261,11 @@ class Formula(metaclass=_FormulaMeta): def __init__( self, root: Union[FormulaSpec, _MissingType] = MISSING, + *, _parser: Optional[FormulaParser] = None, _nested_parser: Optional[FormulaParser] = None, _ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE, + _context: Optional[Mapping[str, Any]] = None, **structure: FormulaSpec, ): """ @@ -288,7 +303,7 @@ def get_model_matrix( @abstractmethod def required_variables(self) -> Set[Variable]: """ - The set of variables required in the data order to materialize this + The set of variables required to be in the data to materialize this formula. 
Attempts are made to restrict these variables only to those expected in @@ -354,6 +369,7 @@ def __init__( _ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE, _parser: Optional[FormulaParser] = None, _nested_parser: Optional[FormulaParser] = None, + _context: Optional[Mapping[str, Any]] = None, **structure: FormulaSpec, ): if root is MISSING: @@ -667,19 +683,22 @@ class StructuredFormula(Structured[SimpleFormula], Formula): formula specifications. Can be: "none", "degree" (default), or "sort". """ - __slots__ = ("_parser", "_nested_parser", "_ordering") + __slots__ = ("_parser", "_nested_parser", "_ordering", "_context") def __init__( self, root: Union[FormulaSpec, _MissingType] = MISSING, + *, + _ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE, _parser: Optional[FormulaParser] = None, _nested_parser: Optional[FormulaParser] = None, - _ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE, + _context: Optional[Mapping[str, Any]] = None, **structure: FormulaSpec, ): + self._ordering = OrderingMethod(_ordering) self._parser = _parser or DEFAULT_PARSER self._nested_parser = _nested_parser or _parser or DEFAULT_NESTED_PARSER - self._ordering = OrderingMethod(_ordering) + self._context = _context super().__init__(root, **structure) # type: ignore self._simplify(unwrap=False, inplace=True) @@ -704,6 +723,7 @@ def _prepare_item( # type: ignore[override] ordering=self._ordering, parser=(self._parser if key == "root" else self._nested_parser), nested_parser=self._nested_parser, + context=self._context, ) def get_model_matrix( @@ -782,3 +802,14 @@ def differentiate( # pylint: disable=redefined-builtin SimpleFormula, self._map(lambda formula: formula.differentiate(*wrt, use_sympy=use_sympy)), ) + + # Ensure pickling never includes context + def __getstate__(self) -> Tuple[None, Dict[str, Any]]: + slots = self.__slots__ + Structured.__slots__ + return ( + None, + { + slot: getattr(self, slot) if slot != "_context" else None + for slot in slots + }, + ) diff --git a/formulaic/materializers/base.py b/formulaic/materializers/base.py index 0eb2cf4d..b6a45aa9 100644 --- a/formulaic/materializers/base.py +++ b/formulaic/materializers/base.py @@ -163,7 +163,9 @@ def get_model_matrix( from formulaic import ModelSpec # Prepare ModelSpec(s) - spec: Union[ModelSpec, ModelSpecs] = ModelSpec.from_spec(spec, **spec_overrides) + spec: Union[ModelSpec, ModelSpecs] = ModelSpec.from_spec( + spec, context=self.layered_context, **spec_overrides + ) should_simplify = isinstance(spec, ModelSpec) model_specs: ModelSpecs = self._prepare_model_specs(spec) diff --git a/formulaic/model_spec.py b/formulaic/model_spec.py index f60dfc67..824e35a5 100644 --- a/formulaic/model_spec.py +++ b/formulaic/model_spec.py @@ -78,6 +78,8 @@ class ModelSpec: def from_spec( cls, spec: Union[FormulaSpec, ModelMatrix, ModelMatrices, ModelSpec, ModelSpecs], + *, + context: Optional[Mapping[str, Any]] = None, **attrs: Any, ) -> Union[ModelSpec, ModelSpecs]: """ @@ -90,6 +92,11 @@ instance or structured set of `ModelSpec` instances. attrs: Any `ModelSpec` attributes to set and/or override on all generated `ModelSpec` instances. + context: Optional additional context to pass through to the formula + parsing algorithms. This is not normally required; if the operators + involved place additional constraints on the type and/or structure + of this context, they will raise exceptions (with instructions on + how to fix the problem) when those constraints are not satisfied. 
""" from .model_matrix import ModelMatrix @@ -98,7 +105,7 @@ def prepare_model_spec(obj: Any) -> Union[ModelSpec, ModelSpecs]: obj = obj.model_spec if isinstance(obj, ModelSpec): return obj.update(**attrs) - formula = Formula.from_spec(obj) + formula = Formula.from_spec(obj, context=context) if isinstance(formula, StructuredFormula): return cast( ModelSpecs, formula._map(prepare_model_spec, as_type=ModelSpecs) @@ -417,6 +424,21 @@ def variables_by_source(self) -> Dict[Optional[str], Set[Variable]]: variables_by_source[variable.source].add(variable) return dict(variables_by_source) + @property + def required_variables(self) -> Set[Variable]: + """ + The set of variables required to be in the data to materialize this + model specification. + + If `.structure` has not been populated (which contains metadata about + which columns where ultimate drawn from the data during + materialization), then this will fallback to the variables inferred to + be required by `.formula`. + """ + if self.structure is None: + return self.formula.required_variables + return self.variables_by_source.get("data", set()) + def get_slice(self, columns_identifier: Union[int, str, Term, slice]) -> slice: """ Generate a `slice` instance corresponding to the columns associated with @@ -459,6 +481,24 @@ def get_slice(self, columns_identifier: Union[int, str, Term, slice]) -> slice: # Utility methods + def get_materializer( + self, data: Any, context: Optional[Mapping[str, Any]] = None + ) -> FormulaMaterializer: + """ + Construct a `FormulaMaterializer` instance for `data` that can be used + to generate model matrices consistent with this model specification. + + Args: + data: The data for which to build the materializer. + context: An additional mapping object of names to make available in + when evaluating formula term factors. + """ + if self.materializer is None: + materializer = FormulaMaterializer.for_data(data) + else: + materializer = FormulaMaterializer.for_materializer(self.materializer) + return materializer(data, context=context, **(self.materializer_params or {})) + def get_model_matrix( self, data: Any, @@ -484,13 +524,12 @@ def get_model_matrix( """ if attr_overrides: return self.update(**attr_overrides).get_model_matrix(data, context=context) - if self.materializer is None: - materializer = FormulaMaterializer.for_data(data) - else: - materializer = FormulaMaterializer.for_materializer(self.materializer) - return materializer( - data, context=context, **(self.materializer_params or {}) - ).get_model_matrix(self, drop_rows=drop_rows) + return cast( + "ModelMatrix", + self.get_materializer(data, context=context).get_model_matrix( + self, drop_rows=drop_rows + ), + ) def get_linear_constraints(self, spec: LinearConstraintSpec) -> LinearConstraints: """ @@ -632,6 +671,16 @@ def _prepare_item(self, key: str, item: Any) -> Any: ) return item + @property + def required_variables(self) -> Set[Variable]: + """ + The set of variables required to be in the data to materialize all of + the model specifications in this `ModelSpecs` instance. 
+ """ + variables: Set[Variable] = set() + self._map(lambda ms: variables.update(ms.required_variables)) + return variables + def get_model_matrix( self, data: Any, diff --git a/formulaic/parser/algos/sanitize_tokens.py b/formulaic/parser/algos/sanitize_tokens.py index 1e16d863..13336150 100644 --- a/formulaic/parser/algos/sanitize_tokens.py +++ b/formulaic/parser/algos/sanitize_tokens.py @@ -15,6 +15,8 @@ def sanitize_tokens(tokens: Iterable[Token]) -> Iterable[Token]: - possible more in the future """ for token in tokens: + if token.token == ".": # noqa: S105 + token.kind = Token.Kind.OPERATOR if token.kind is Token.Kind.PYTHON: token.token = sanitize_python_code(token.token) yield token diff --git a/formulaic/parser/algos/tokens_to_ast.py b/formulaic/parser/algos/tokens_to_ast.py index ed6c7f99..6631fb00 100644 --- a/formulaic/parser/algos/tokens_to_ast.py +++ b/formulaic/parser/algos/tokens_to_ast.py @@ -1,5 +1,5 @@ from collections import namedtuple -from typing import Iterable, List, Union +from typing import Iterable, List, Set, Union from ..types import ASTNode, Operator, OperatorResolver, Token from ..utils import exc_for_missing_operator, exc_for_token @@ -42,6 +42,7 @@ def tokens_to_ast( """ output_queue: List[Union[Token, ASTNode]] = [] operator_stack: List[OrderedOperator] = [] + disabled_operators: Set[Token] = set() def stack_operator(operator: Union[Token, Operator], token: Token) -> None: operator_stack.append(OrderedOperator(operator, token, len(output_queue))) @@ -98,30 +99,55 @@ def operate( f"Context token `{token.token}` is unrecognized.", ) elif token.kind is token.Kind.OPERATOR: - max_prefix_arity = ( - len(output_queue) - operator_stack[-1].index - if operator_stack - else len(output_queue) - ) - operators = operator_resolver.resolve( - token, - max_prefix_arity=max_prefix_arity, - context=[s.operator for s in operator_stack], - ) - - for operator in operators: - while ( - operator_stack - and operator_stack[-1].token.kind is not Token.Kind.CONTEXT - and ( - operator_stack[-1].operator.precedence > operator.precedence - or operator_stack[-1].operator.precedence == operator.precedence - and operator.associativity is Operator.Associativity.LEFT + for operator_token, operators in operator_resolver.resolve(token): + for operator in operators: + if not operator.accepts_context( + [s.operator for s in operator_stack] + ): + continue + if operator.disabled: + disabled_operators.add(operator_token) + continue + # Apply all operators with precedence greater than the current operator + while ( + operator_stack + and operator_stack[-1].token.kind is not Token.Kind.CONTEXT + and ( + operator_stack[-1].operator.precedence > operator.precedence + or operator_stack[-1].operator.precedence + == operator.precedence + and operator.associativity is Operator.Associativity.LEFT + ) + ): + output_queue = operate(operator_stack.pop(), output_queue) + + # Determine maximum number of postfix arguments + max_postfix_arity = ( + len(output_queue) - operator_stack[-1].index + if operator_stack + else len(output_queue) ) - ): - output_queue = operate(operator_stack.pop(), output_queue) - stack_operator(operator, token) + # Check if operator is valid in current context + if ( + operator.arity == 0 + or operator.fixity is Operator.Fixity.PREFIX + or max_postfix_arity == 1 + and operator.fixity is Operator.Fixity.INFIX + or max_postfix_arity >= operator.arity + and operator.fixity is Operator.Fixity.POSTFIX + ): + stack_operator(operator, token) + break + else: + if operator_token in 
disabled_operators: + raise exc_for_token( + token, + f"Operator `{operator_token}` is at least partially disabled by parser configuration, and/or is incorrectly used.", + ) + raise exc_for_token( + token, f"Operator `{operator_token}` is incorrectly used." + ) else: output_queue.append(token) @@ -134,7 +160,16 @@ def operate( if output_queue: if len(output_queue) > 1: - raise exc_for_missing_operator(output_queue[0], output_queue[1]) + raise exc_for_missing_operator( + output_queue[0], + output_queue[1], + extra=( + "This may be due to the following operators being at least " + f"partially disabled by parser configuration: {disabled_operators}." + if disabled_operators + else None + ), + ) return output_queue[0] return None diff --git a/formulaic/parser/parser.py b/formulaic/parser/parser.py index e85a7f1f..0c8dc895 100644 --- a/formulaic/parser/parser.py +++ b/formulaic/parser/parser.py @@ -6,13 +6,28 @@ import re from dataclasses import dataclass, field from enum import Flag, auto -from typing import Iterable, List, Sequence, Set, Tuple, Union, cast +from typing import ( + Any, + Generator, + Iterable, + List, + Mapping, + MutableMapping, + Set, + Tuple, + Union, + cast, +) from typing_extensions import Self +from formulaic.errors import FormulaParsingError +from formulaic.utils.layered_mapping import LayeredMapping + from .algos.sanitize_tokens import sanitize_tokens from .algos.tokenize import tokenize from .types import ( + ASTNode, Factor, FormulaParser, Operator, @@ -97,13 +112,17 @@ def set_feature_flags(self, flags: DefaultParserFeatureFlag | Set[str]) -> Self: self.__post_init__() return self - def get_tokens(self, formula: str) -> Iterable[Token]: + def get_tokens_from_formula( + self, formula: str, *, context: MutableMapping[str, Any] + ) -> Iterable[Token]: """ Return an iterable of `Token` instances for the nominated `formula` string. Args: formula: The formula string to be tokenized. + context: An optional context which may be used during the evaluation + of operators. """ # Transform formula to add intercepts and replace 0 with -1. We do this @@ -133,6 +152,7 @@ def get_tokens(self, formula: str) -> Iterable[Token]: [token_one], kind=Token.Kind.OPERATOR, join_operator="+", + no_join_for_operators={"+", "-"}, ) ) @@ -175,15 +195,24 @@ def find_rhs_index(tokens: List[Token]) -> int: [token_one], kind=Token.Kind.OPERATOR, join_operator="+", + no_join_for_operators={"+", "-"}, ), ] + context["__formulaic_variables_used_lhs__"] = [ + variable + for token in tokens[:rhs_index] + for variable in token.required_variables + ] + # Collapse inserted "+" and "-" operators to prevent unary issues. tokens = merge_operator_tokens(tokens, symbols={"+", "-"}) return tokens - def get_terms(self, formula: str) -> Structured[List[Term]]: + def get_terms_from_ast( + self, ast: Union[None, Token, ASTNode], *, context: MutableMapping[str, Any] + ) -> Structured[OrderedSet[Term]]: """ Assemble the `Term` instances for a formula string. Depending on the operators involved, this may be an iterable of `Term` instances, or @@ -195,8 +224,11 @@ def get_terms(self, formula: str) -> Structured[List[Term]]: Args: formula: The formula for which an AST should be generated. + context: An optional context which may be used during the evaluation + of operators. 
""" - terms = super().get_terms(formula) + + terms = super().get_terms_from_ast(ast, context=context) def check_terms(terms: Iterable[Term]) -> None: seen_terms = set() @@ -209,11 +241,13 @@ def check_terms(terms: Iterable[Term]) -> None: ): raise exc_for_token( factor.token or Token(), - "Numeric literals other than `1` can only be used " - "to scale other terms. (tip: Use `:` rather than " - "`*` when scaling terms)" - if factor.expr.replace(".", "", 1).isnumeric() - else "String literals are not valid in formulae.", + ( + "Numeric literals other than `1` can only be used " + "to scale other terms. (tip: Use `:` rather than " + "`*` when scaling terms)" + if factor.expr.replace(".", "", 1).isnumeric() + else "String literals are not valid in formulae." + ), ) else: for factor in term.factors: @@ -333,6 +367,37 @@ def get_terms(terms: OrderedSet[Term]) -> List[Term]: return Structured(get_terms(lhs), deps=(Structured(lhs=lhs, rhs=rhs),)) + def insert_unused_terms(context: Mapping[str, Any]) -> OrderedSet[Term]: + available_variables: OrderedSet[str] + used_variables: Set[str] = set(context["__formulaic_variables_used_lhs__"]) + + # Populate `available_variables` or raise. + if "__formulaic_variables_available__" in context: + available_variables = OrderedSet( + context["__formulaic_variables_available__"] + ) + elif isinstance(context, LayeredMapping) and "data" in context.named_layers: + available_variables = OrderedSet(context.named_layers["data"]) + else: + raise FormulaParsingError( + "The `.` operator requires additional context about which " + "variables are available to use. This can be provided by " + "passing in a value for `__formulaic_variables_available__`" + "in the context while parsing the formula; by passing the " + "formula to the materializer's `.get_model_matrix()` method; " + "or by passing a `LayeredMapping` instance as the context " + "with a `data` layer containing the available variables " + "(such as the `.layered_context` from a " + "`FormulaMaterializer` instance)." 
+ ) + + unused_variables = available_variables - used_variables + + return OrderedSet( + Term([Factor(variable, eval_method="lookup")]) + for variable in unused_variables + ) + return [ Operator( "~", @@ -451,13 +516,22 @@ def get_terms(terms: OrderedSet[Term]) -> List[Term]: Operator( "^", arity=2, precedence=500, associativity="right", to_terms=power ), + Operator( + ".", + arity=0, + precedence=1000, + fixity="postfix", + to_terms=insert_unused_terms, + ), ] def resolve( - self, token: Token, max_prefix_arity: int, context: List[Union[Token, Operator]] - ) -> Sequence[Operator]: + self, + token: Token, + ) -> Generator[Tuple[Token, Iterable[Operator]], None, None]: if token.token in self.operator_table: - return super().resolve(token, max_prefix_arity, context) + yield from super().resolve(token) + return symbol = token.token @@ -474,9 +548,8 @@ def resolve( ) if symbol in self.operator_table: - return [self._resolve(token, symbol, max_prefix_arity, context)] + yield self._resolve(token, symbol) + return - return [ - self._resolve(token, sym, max_prefix_arity if i == 0 else 0, context) - for i, sym in enumerate(symbol) - ] + for sym in symbol: + yield self._resolve(token, sym) diff --git a/formulaic/parser/types/ast_node.py b/formulaic/parser/types/ast_node.py index cf93f222..256fb75b 100644 --- a/formulaic/parser/types/ast_node.py +++ b/formulaic/parser/types/ast_node.py @@ -1,9 +1,22 @@ from __future__ import annotations +import functools import graphlib -from typing import Any, Dict, Generic, Iterable, List, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, +) from .operator import Operator +from .ordered_set import OrderedSet from .structured import Structured from .term import Term @@ -28,13 +41,19 @@ def __init__(self, operator: Operator, args: Iterable[Any]): self.operator = operator self.args = args - def to_terms(self) -> Union[List[Term], Structured[List[Term]], Tuple]: + def to_terms( + self, *, context: Optional[Mapping[str, Any]] = None + ) -> Union[OrderedSet[Term], Structured[OrderedSet[Term]], Tuple]: """ Evaluate this AST node and return the resulting set of `Term` instances. Note: We use topological evaluation here to avoid recursion issues for long formula (exceeding ~700 terms, though this depends on the recursion limit set in the interpreter). + + Args: + context: An optional context mapping that can be used by operators + to modify their behaviour (e.g. the `.` operator). 
""" g = graphlib.TopologicalSorter(self.__generate_evaluation_graph()) g.prepare() @@ -43,16 +62,18 @@ def to_terms(self) -> Union[List[Term], Structured[List[Term]], Tuple]: while g.is_active(): for node in g.get_ready(): - node_args = ( + node_args = tuple( (results[arg] if isinstance(arg, ASTNode) else arg.to_terms()) for arg in node.args ) - if node.operator.structural: - results[node] = node.operator.to_terms(*node_args) + if node.operator.structural or not node_args: + results[node] = node.operator.to_terms(*node_args, context=context) else: results[node] = Structured._merge( *node_args, - merger=node.operator.to_terms, + merger=functools.partial( + node.operator.to_terms, context=context + ), ) g.done(node) @@ -78,9 +99,11 @@ def flatten(self, str_args: bool = False) -> List[Any]: return [ str(self.operator) if str_args else self.operator, *[ - arg.flatten(str_args=str_args) - if isinstance(arg, ASTNode) - else (str(arg) if str_args else arg) + ( + arg.flatten(str_args=str_args) + if isinstance(arg, ASTNode) + else (str(arg) if str_args else arg) + ) for arg in self.args ], ] diff --git a/formulaic/parser/types/factor.py b/formulaic/parser/types/factor.py index d6cad25c..dc4ed4a4 100644 --- a/formulaic/parser/types/factor.py +++ b/formulaic/parser/types/factor.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union from .ordered_set import OrderedSet from .term import Term @@ -91,7 +91,9 @@ def __lt__(self, other: Any) -> bool: return self.expr < other.expr return NotImplemented - def to_terms(self) -> OrderedSet[Term]: + def to_terms( + self, *, context: Optional[Mapping[str, Any]] = None + ) -> OrderedSet[Term]: """ Convert this `Factor` instance into a `Term` instance, and expose it as a single-element ordered set. diff --git a/formulaic/parser/types/formula_parser.py b/formulaic/parser/types/formula_parser.py index 7cedc883..4da93290 100644 --- a/formulaic/parser/types/formula_parser.py +++ b/formulaic/parser/types/formula_parser.py @@ -1,5 +1,22 @@ +from __future__ import annotations + from dataclasses import dataclass -from typing import Iterable, List, Union +from enum import IntEnum +from typing import ( + Any, + Iterable, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, + overload, +) + +from typing_extensions import Literal + +from formulaic.parser.types.ordered_set import OrderedSet +from formulaic.utils.layered_mapping import LayeredMapping from .ast_node import ASTNode from .operator_resolver import OperatorResolver @@ -31,9 +48,91 @@ class FormulaParser: Only the `get_terms()` method is essential from an API perspective. """ + class Target(IntEnum): + FORMULA = 0 + TOKENS = 1 + AST = 2 + TERMS = 3 + operator_resolver: OperatorResolver + context: Optional[Mapping[str, Any]] = None + + @overload + def parse( + self, + formula: str, + *, + target: Literal[FormulaParser.Target.FORMULA, "formula", 0], + context: Optional[Mapping[str, Any]] = None, + ) -> str: ... + + @overload + def parse( + self, + formula: str, + *, + target: Literal[FormulaParser.Target.TOKENS, "tokens", 1], + context: Optional[Mapping[str, Any]] = None, + ) -> Iterable[Token]: ... + + @overload + def parse( + self, + formula: str, + *, + target: Literal[FormulaParser.Target.AST, "ast", 2], + context: Optional[Mapping[str, Any]] = None, + ) -> Union[None, Token, ASTNode]: ... 
+ + @overload + def parse( + self, + formula: str, + *, + target: Literal[FormulaParser.Target.TERMS, "terms", 3], + context: Optional[Mapping[str, Any]] = None, + ) -> Structured[OrderedSet[Term]]: ... + + def parse( + self, + formula: str, + *, + target: Union[Target, str, int] = Target.TERMS, + context: Optional[Mapping[str, Any]] = None, + ) -> Union[ + str, Iterable[Token], Union[None, Token, ASTNode], Structured[OrderedSet[Term]] + ]: + """ + Parse the nominated `formula` string to the nominated `target`. - def get_tokens(self, formula: str) -> Iterable[Token]: + Args: + formula: The formula string to be parsed. + target: The parsing stage to return: one of `Target.FORMULA`, + `Target.TOKENS`, `Target.AST` or `Target.TERMS` (the default), + or an equivalent (case-insensitive) string or integer value. + context: An optional context which may be used during the evaluation + of operators. + """ + if isinstance(target, int): + target = self.Target(target) + elif isinstance(target, str): + target = self.Target[target.upper()] + + out: Union[ + str, + Iterable[Token], + Union[None, Token, ASTNode], + Structured[OrderedSet[Term]], + ] = formula + context = LayeredMapping(context or {}, self.context) + if target >= self.Target.TOKENS: + out = tokens = self.get_tokens_from_formula(formula, context=context) + if target >= self.Target.AST: + out = ast = self.get_ast_from_tokens(tokens, context=context) + if target >= self.Target.TERMS: + out = self.get_terms_from_ast(ast, context=context) + return out + + def get_tokens_from_formula( + self, formula: str, *, context: MutableMapping[str, Any] + ) -> Iterable[Token]: """ Return an iterable of `Token` instances for the nominated `formula` string. @@ -46,9 +145,11 @@ return sanitize_tokens(tokenize(formula)) - def get_ast(self, formula: str) -> Union[None, Token, ASTNode]: + def get_ast_from_tokens( + self, tokens: Iterable[Token], *, context: MutableMapping[str, Any] + ) -> Union[None, Token, ASTNode]: """ - Assemble an abstract syntax tree for the nominated `formula` string. + Assemble an abstract syntax tree for the nominated `tokens`. Args: tokens: The tokens from which an AST should be generated. """ from ..algos.tokens_to_ast import tokens_to_ast return tokens_to_ast( - self.get_tokens(formula), + tokens, operator_resolver=self.operator_resolver, ) - def get_terms(self, formula: str) -> Structured[List[Term]]: + def get_terms_from_ast( + self, + ast: Union[None, Token, ASTNode], + *, + context: MutableMapping[str, Any], + ) -> Structured[OrderedSet[Term]]: """ - Assemble the `Term` instances for a formula string. Depending on the - operators involved, this may be an iterable of `Term` instances, or - an iterable of iterables of `Term`s, etc. + Assemble the structured `Term` instances for the nominated AST. A + `Structured` instance will always be returned even if the structure is + trivial. Args: ast: The AST for which terms should be generated. + context: An optional context which may be used during the evaluation + of operators. """ - ast = self.get_ast(formula) if ast is None: return Structured([]) - terms = ast.to_terms() + terms: Union[ + OrderedSet[Term], Tuple[OrderedSet[Term]], Structured[OrderedSet[Term]] + ] = ast.to_terms(context=context) if not isinstance(terms, Structured): - terms = Structured[List[Term]](terms) + terms = Structured[OrderedSet[Term]](terms) return terms + + # Convenience methods for common use-cases. 
+ + def get_tokens( + self, formula: str, *, context: Optional[Mapping[str, Any]] = None + ) -> Iterable[Token]: + """ + Parse the nominated `formula` string and return the resulting tokens. + + Args: + formula: The formula string to be parsed. + context: An optional context which may be used during the evaluation + of operators. + """ + return self.parse(formula, target=self.Target.TOKENS, context=context) + + def get_ast( + self, formula: str, *, context: Optional[Mapping[str, Any]] = None + ) -> Union[None, Token, ASTNode]: + """ + Assemble an abstract syntax tree for the nominated `formula` string. + + Args: + formula: The formula for which an AST should be generated. + context: An optional context which may be used during the evaluation + of operators. + """ + return self.parse(formula, target=self.Target.AST, context=context) + + def get_terms( + self, formula: str, *, context: Optional[Mapping[str, Any]] = None + ) -> Structured[OrderedSet[Term]]: + """ + Parse the nominated `formula` string and return the resulting terms. + + Args: + formula: The formula string to be parsed. + context: An optional context which may be used during the evaluation + of operators. + """ + return self.parse(formula, target=self.Target.TERMS, context=context) diff --git a/formulaic/parser/types/operator.py b/formulaic/parser/types/operator.py index a0f7f804..844bc4c0 100644 --- a/formulaic/parser/types/operator.py +++ b/formulaic/parser/types/operator.py @@ -1,7 +1,8 @@ from __future__ import annotations +import inspect from enum import Enum -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Mapping, Optional, Union from .token import Token @@ -91,9 +92,11 @@ def fixity(self) -> Operator.Fixity: def fixity(self, fixity: Union[str, Operator.Fixity]) -> None: self._fixity = Operator.Fixity(fixity) - def to_terms(self, *args: Any) -> Any: + def to_terms(self, *args: Any, context: Optional[Mapping[str, Any]] = None) -> Any: if self._to_terms is None: raise RuntimeError(f"`to_terms` is not implemented for '{self.symbol}'.") + if inspect.signature(self._to_terms).parameters.get("context"): + return self._to_terms(*args, context=context or {}) return self._to_terms(*args) def accepts_context(self, context: List[Union[Token, Operator]]) -> bool: diff --git a/formulaic/parser/types/operator_resolver.py b/formulaic/parser/types/operator_resolver.py index 1c712803..8d4f8d14 100644 --- a/formulaic/parser/types/operator_resolver.py +++ b/formulaic/parser/types/operator_resolver.py @@ -1,6 +1,6 @@ import abc from collections import defaultdict -from typing import Dict, List, Sequence, Union +from typing import Dict, Generator, Iterable, List, Tuple from ..utils import exc_for_token from .operator import Operator @@ -53,11 +53,15 @@ def operator_table(self) -> Dict[str, List[Operator]]: return operator_table def resolve( - self, token: Token, max_prefix_arity: int, context: List[Union[Token, Operator]] - ) -> Sequence[Operator]: + self, token: Token + ) -> Generator[Tuple[Token, Iterable[Operator]], None, None]: """ - Return a list of operators to apply for a given token in the AST - generation. + Generate the sets of operator candidates that may be viable for the + given token (which may include multiple adjacent operators concatenated + together). Each item generated must be a tuple of the token associated + with the operator and an iterable of `Operator` instances which should + be considered by the AST generator. 
These `Operator` instances *MUST* be + sorted in descending order of precedence and arity. Args: token: The operator `Token` instance for which `Operator`(s) should @@ -68,45 +72,19 @@ resolved will be placed. This will be a list of `Operator` instances or tokens (tokens are returned for grouping operators). """ - return [self._resolve(token, token.token, max_prefix_arity, context)] + yield self._resolve(token, token.token) def _resolve( self, token: Token, symbol: str, - max_prefix_arity: int, - context: List[Union[Token, Operator]], - ) -> Operator: + ) -> Tuple[Token, Iterable[Operator]]: """ The default operator resolving logic. """ if symbol not in self.operator_table: raise exc_for_token(token, f"Unknown operator '{symbol}'.") - candidates = [ - candidate - for candidate in self.operator_table[symbol] - if ( - max_prefix_arity == 0 - and candidate.fixity is Operator.Fixity.PREFIX - or max_prefix_arity > 0 - and candidate.fixity is not Operator.Fixity.PREFIX - ) - and candidate.accepts_context(context) - ] - if not candidates: - raise exc_for_token(token, f"Operator `{symbol}` is incorrectly used.") - candidates = [candidate for candidate in candidates if not candidate.disabled] - if not candidates: - raise exc_for_token( - token, - f"Operator `{symbol}` has been disabled in this context via parser configuration.", - ) - if len(candidates) > 1: - raise exc_for_token( - token, - f"Ambiguous operator `{symbol}`. This is not usually a user error. Please report this!", - ) - return candidates[0] + return token, self.operator_table[symbol] # The operator table cache may not be pickleable, so let's drop it. def __getstate__(self) -> Dict: diff --git a/formulaic/parser/types/term.py b/formulaic/parser/types/term.py index 7157689d..a409c699 100644 --- a/formulaic/parser/types/term.py +++ b/formulaic/parser/types/term.py @@ -1,7 +1,9 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Any, Iterable, Optional +from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional + +from .ordered_set import OrderedSet if TYPE_CHECKING: from .factor import Factor # pragma: no cover @@ -65,5 +67,13 @@ def __lt__(self, other: Any) -> bool: return False return NotImplemented + def to_terms( + self, *, context: Optional[Mapping[str, Any]] = None + ) -> OrderedSet[Term]: + """ + Convert this `Term` instance into a set of `Term`s. + """ + return OrderedSet((self,)) + def __repr__(self) -> str: return ":".join(repr(factor) for factor in self.factors) diff --git a/formulaic/parser/types/token.py b/formulaic/parser/types/token.py index 36793966..331d0e5a 100644 --- a/formulaic/parser/types/token.py +++ b/formulaic/parser/types/token.py @@ -3,7 +3,9 @@ import copy import re from enum import Enum -from typing import Any, Iterable, Optional, Tuple, Union +from typing import Any, Iterable, Mapping, Optional, Set, Tuple, Union + +from formulaic.utils.variables import Variable, get_expression_variables from .factor import Factor from .ordered_set import OrderedSet @@ -140,7 +142,9 @@ def to_factor(self) -> Factor: token=self, ) - def to_terms(self) -> OrderedSet[Term]: + def to_terms( + self, *, context: Optional[Mapping[str, Any]] = None + ) -> OrderedSet[Term]: """ An ordered set of `Term` instances for this token. 
This will just be an iterable with one `Term` having one `Factor` (that generated by @@ -177,6 +181,40 @@ def get_source_context(self, colorize: bool = False) -> Optional[str]: return f"{self.source[:self.source_start]}⧛{RED_BOLD}{self.source[self.source_start:self.source_end+1]}{RESET}⧚{self.source[self.source_end+1:]}" return f"{self.source[:self.source_start]}⧛{self.source[self.source_start:self.source_end+1]}⧚{self.source[self.source_end+1:]}" + @property + def required_variables(self) -> Set[Variable]: + """ + The set of variables required to evaluate this token. + + If this is a Python token, and the code is malformed and unable to be + parsed, an empty set is returned. The code will fail more gracefully + later on. + + Attempts are made to restrict these variables only to those expected in + the data, and not, for example, those associated with transforms and/or + values present in the evaluation namespace by default (e.g. `y ~ C(x)` + would include only `y` and `x`). This may not always be possible for + more advanced formulae that insert constants into the formula via the + evaluation context rather than the data context. + """ + if self.kind is Token.Kind.NAME: + return {Variable(self.token)} + if self.kind is Token.Kind.PYTHON: + try: + # Filter out constants like `contr` that are already present in the + # TRANSFORMS namespace. + from formulaic.transforms import TRANSFORMS + + return set( + filter( + lambda variable: variable.split(".", 1)[0] not in TRANSFORMS, + get_expression_variables(self.token), + ) + ) + except Exception: # noqa: S110 + pass + return set() + def __repr__(self) -> str: return self.token diff --git a/formulaic/parser/utils.py b/formulaic/parser/utils.py index b26dec3f..e000454a 100644 --- a/formulaic/parser/utils.py +++ b/formulaic/parser/utils.py @@ -34,6 +34,7 @@ def exc_for_missing_operator( lhs: Union[Token, ASTNode], rhs: Union[Token, ASTNode], errcls: Type[Exception] = FormulaSyntaxError, + extra: Optional[str] = None, ) -> Exception: """ Return an exception ready to be raised about a missing operator token @@ -45,11 +46,12 @@ def exc_for_missing_operator( rhs: The `Token` or `ASTNode` instance to the right of where an operator should be placed. errcls: The type of the exception to be returned. + extra: Any additional information to be included in the exception message. 
""" lhs_token, rhs_token, error_token = __get_tokens_for_gap(lhs, rhs) return exc_for_token( error_token, - f"Missing operator between `{lhs_token.token}` and `{rhs_token.token}`.", + f"Missing operator between `{lhs_token.token}` and `{rhs_token.token}`.{f' {extra}' if extra else ''}", errcls=errcls, ) @@ -69,9 +71,11 @@ def __get_token_for_ast(ast: Union[Token, ASTNode]) -> Token: # pragma: no cove while isinstance(rhs_token, ASTNode): rhs_token = rhs_token.args[-1] # type: ignore return Token( - token=lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1] - if lhs_token.source - else "", + token=( + lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1] + if lhs_token.source + else "" + ), source=lhs_token.source, source_start=lhs_token.source_start, source_end=rhs_token.source_end, @@ -91,19 +95,29 @@ def __get_tokens_for_gap( """ lhs_token = lhs while isinstance(lhs_token, ASTNode): - lhs_token = lhs_token.args[-1] # type: ignore + lhs_token = ( + lhs_token.args[-1] # type: ignore + if lhs_token.args + else Token(lhs_token.operator.symbol) + ) rhs_token = rhs or lhs while isinstance(rhs_token, ASTNode): - rhs_token = rhs_token.args[0] # type: ignore + rhs_token = ( + rhs_token.args[0] # type: ignore + if rhs_token.args + else Token(rhs_token.operator.symbol) + ) return ( lhs_token, rhs_token, Token( - lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1] - if lhs_token.source - and lhs_token.source_start is not None - and rhs_token.source_end is not None - else "", + ( + lhs_token.source[lhs_token.source_start : rhs_token.source_end + 1] + if lhs_token.source + and lhs_token.source_start is not None + and rhs_token.source_end is not None + else "" + ), source=lhs_token.source, source_start=lhs_token.source_start, source_end=rhs_token.source_end, @@ -152,6 +166,7 @@ def insert_tokens_after( *, kind: Optional[Token.Kind] = None, join_operator: Optional[str] = None, + no_join_for_operators: Union[bool, Set[str]] = True, ) -> Iterable[Token]: """ Insert additional tokens into a sequence of tokens after (within token) @@ -175,6 +190,10 @@ def insert_tokens_after( the added tokens with existing tokens, the value set here will be used to create a joining operator token. If not provided, not additional operators are added. + no_join_for_operators: Whether to use the join operator when the next + token is an operator token; or a set of operator symbols for which + to skip adding the join token. + """ tokens = list(tokens) @@ -203,9 +222,11 @@ def insert_tokens_after( next_token = split_tokens[j + 1] elif i < len(tokens) - 1: next_token = tokens[i + 1] - if ( - next_token is not None - and next_token.kind is not Token.Kind.OPERATOR + if next_token is not None and ( + next_token.kind is not Token.Kind.OPERATOR + or no_join_for_operators is False + or isinstance(no_join_for_operators, set) + and next_token.token not in no_join_for_operators ): yield Token(join_operator, kind=Token.Kind.OPERATOR) diff --git a/formulaic/sugar.py b/formulaic/sugar.py index 844f060e..09201434 100644 --- a/formulaic/sugar.py +++ b/formulaic/sugar.py @@ -19,7 +19,10 @@ def model_matrix( This method is syntactic sugar for: ``` - Formula(spec).get_model_matrix(data, context=LayeredMapping(locals(), globals()), **kwargs) + Formula( + spec, + context={"__formulaic_variables_available__": ...}, # used for the `.` operator + ).get_model_matrix(data, context=LayeredMapping(locals(), globals()), **kwargs) ``` or ``` @@ -52,6 +55,12 @@ def model_matrix( nominated structure. 
""" _context = capture_context(context + 1) if isinstance(context, int) else context - return ModelSpec.from_spec(spec, **spec_overrides).get_model_matrix( - data, context=_context, drop_rows=drop_rows + _spec_context = ( # use materializer context for parser context + ModelSpec.from_spec([], **spec_overrides) + .get_materializer(data, context=_context) + .layered_context ) + + return ModelSpec.from_spec( + spec, context=_spec_context, **spec_overrides + ).get_model_matrix(data, context=_context, drop_rows=drop_rows) diff --git a/formulaic/utils/constraints.py b/formulaic/utils/constraints.py index cf95bc9f..8dbf00af 100644 --- a/formulaic/utils/constraints.py +++ b/formulaic/utils/constraints.py @@ -9,6 +9,7 @@ Dict, Iterable, List, + Mapping, Optional, Sequence, Set, @@ -320,7 +321,9 @@ def for_token(cls, token: Token) -> ConstraintToken: } ) - def to_terms(self) -> Set[ScaledFactor]: # type: ignore[override] + def to_terms( # type: ignore[override] + self, *, context: Optional[Mapping[str, Any]] = None + ) -> Set[ScaledFactor]: if self.kind is Token.Kind.VALUE: factor = ast.literal_eval(self.token) if isinstance(factor, (int, float)): diff --git a/formulaic/utils/variables.py b/formulaic/utils/variables.py index e075e400..748b8bee 100644 --- a/formulaic/utils/variables.py +++ b/formulaic/utils/variables.py @@ -46,7 +46,9 @@ def union(cls, *variable_sets: Iterable[Variable]) -> Set[Variable]: def get_expression_variables( - expr: Union[str, ast.AST], context: Mapping, aliases: Optional[Mapping] = None + expr: Union[str, ast.AST], + context: Optional[Mapping] = None, + aliases: Optional[Mapping] = None, ) -> Set[Variable]: """ Extract the variables that are used in the nominated Python expression. diff --git a/tests/materializers/test_pandas.py b/tests/materializers/test_pandas.py index f2e1abcc..9c60b972 100644 --- a/tests/materializers/test_pandas.py +++ b/tests/materializers/test_pandas.py @@ -67,6 +67,23 @@ ["Intercept"], 1, ), + ".": ( + ["Intercept", "a", "b", "A[T.b]", "A[T.c]", "B[T.b]", "B[T.c]"], + [ + "Intercept", + "a", + "b", + "A[a]", + "A[b]", + "A[c]", + "B[a]", + "B[b]", + "B[c]", + "D[a]", + ], + ["Intercept", "a", "b"], + 1, + ), } @@ -86,7 +103,13 @@ def data(self): @pytest.fixture def data_with_nulls(self): return pandas.DataFrame( - {"a": [1, 2, None], "A": ["a", None, "c"], "B": ["a", "b", None]} + { + "a": [1, 2, None], + "b": [1, 2, 3], + "A": ["a", None, "c"], + "B": ["a", "b", None], + "D": ["a", "a", "a"], + } ) @pytest.fixture @@ -182,7 +205,10 @@ def test_na_handling(self, data_with_nulls, formula, tests, output): formula, na_action="ignore" ) assert isinstance(mm, pandas.DataFrame) - assert mm.shape == (3, len(tests[0]) + (-1 if "A" in formula else 0)) + if formula == ".": + assert mm.shape == (3, 5) + else: + assert mm.shape == (3, len(tests[0]) + (-1 if "A" in formula else 0)) if formula != "C(A)": # C(A) pre-encodes the data, stripping out nulls. 
with pytest.raises(ValueError): diff --git a/tests/parser/test_parser.py b/tests/parser/test_parser.py index d574d0d2..6e6d299a 100644 --- a/tests/parser/test_parser.py +++ b/tests/parser/test_parser.py @@ -10,6 +10,7 @@ from formulaic.parser import DefaultFormulaParser, DefaultOperatorResolver from formulaic.parser.types import Structured, Token from formulaic.parser.types.term import Term +from formulaic.utils.layered_mapping import LayeredMapping FORMULA_TO_TOKENS = { "": ["1"], @@ -133,12 +134,21 @@ # Quoting "`a|b~c*d`": ["1", "a|b~c*d"], "{a | b | c}": ["1", "a | b | c"], + # Wildcards + ".": ["1", "a", "b", "c"], + ".^2": ["1", "a", "a:b", "a:c", "b", "b:c", "c"], + ".^2 - a:b": ["1", "a", "a:c", "b", "b:c", "c"], + "a ~ .": { + "lhs": ["a"], + "rhs": ["1", "b", "c"], + }, } PARSER = DefaultFormulaParser(feature_flags={"all"}) PARSER_NO_INTERCEPT = DefaultFormulaParser( include_intercept=False, feature_flags={"all"} ) +PARSER_CONTEXT = {"__formulaic_variables_available__": ["a", "b", "c"]} class TestFormulaParser: @@ -148,7 +158,9 @@ class TestFormulaParser: @pytest.mark.parametrize("formula,terms", FORMULA_TO_TERMS.items()) def test_to_terms(self, formula, terms): - generated_terms: Structured[List[Term]] = PARSER.get_terms(formula) + generated_terms: Structured[List[Term]] = PARSER.get_terms( + formula, context=PARSER_CONTEXT + ) if generated_terms._has_keys: comp = generated_terms._map(list)._to_dict() elif generated_terms._has_root and isinstance(generated_terms.root, tuple): @@ -250,7 +262,7 @@ def test_feature_flags(self): with pytest.raises( FormulaSyntaxError, match=re.escape( - "Operator `~` has been disabled in this context via parser configuration." + "Missing operator between `y` and `1`. This may be due to the following operators being at least partially disabled by parser configuration: {~}." ), ): DefaultFormulaParser(feature_flags={}).get_terms("y ~ x") @@ -266,7 +278,7 @@ def test_feature_flags(self): with pytest.raises( FormulaSyntaxError, match=re.escape( - "Operator `|` has been disabled in this context via parser configuration." + "Operator `|` is at least partially disabled by parser configuration, and/or is incorrectly used." 
), ): DefaultFormulaParser().set_feature_flags({}).get_terms("x | y") @@ -280,6 +292,19 @@ def test_invalid_multistage_formula(self): ): DefaultFormulaParser(feature_flags={"all"}).get_terms("[[a ~ b] ~ c]") + def test_alternative_wildcard_usage(self): + PARSER.get_terms( + ".", context=LayeredMapping({"a": 1, "b": 2}, name="data") + ) == ["1", "a", "b"] + + with pytest.raises( + FormulaParsingError, + match=re.escape( + "The `.` operator requires additional context about which " + ), + ): + PARSER.get_terms(".") + class TestDefaultOperatorResolver: @pytest.fixture @@ -287,32 +312,24 @@ def resolver(self): return DefaultOperatorResolver() def test_resolve(self, resolver): - assert len(resolver.resolve(Token("+++++"), 1, [])) == 1 - assert resolver.resolve(Token("+++++"), 1, [])[0].symbol == "+" - assert resolver.resolve(Token("+++++"), 1, [])[0].arity == 2 - - assert len(resolver.resolve(Token("+++-+"), 1, [])) == 1 - assert resolver.resolve(Token("+++-+"), 1, [])[0].symbol == "-" - assert resolver.resolve(Token("+++-+"), 1, [])[0].arity == 2 - - assert len(resolver.resolve(Token("*+++-+"), 1, [])) == 2 - assert resolver.resolve(Token("*+++-+"), 1, [])[0].symbol == "*" - assert resolver.resolve(Token("*+++-+"), 1, [])[0].arity == 2 - assert resolver.resolve(Token("*+++-+"), 1, [])[1].symbol == "-" - assert resolver.resolve(Token("*+++-+"), 1, [])[1].arity == 1 - - with pytest.raises( - FormulaSyntaxError, match="Operator `/` is incorrectly used." - ): - resolver.resolve(Token("*/"), 2, []) - - def test_accepts_context(self, resolver): - tilde_operator = resolver.resolve(Token("~"), 1, [])[0] - - with pytest.raises( - FormulaSyntaxError, match=re.escape("Operator `~` is incorrectly used.") - ): - resolver.resolve(Token("~"), 1, [tilde_operator]) + resolved = list(resolver.resolve(Token("+++++"))) + assert len(resolved) == 1 + assert resolved[0][1][0].symbol == "+" + assert resolved[0][1][0].arity == 2 + + resolved = list(resolver.resolve(Token("+++-+"))) + assert len(resolved) == 1 + assert resolved[0][1][0].symbol == "-" + assert resolved[0][1][0].arity == 2 + + resolved = list(resolver.resolve(Token("*+++-+"))) + assert len(resolved) == 2 + assert resolved[0][1][0].symbol == "*" + assert resolved[0][1][0].arity == 2 + assert resolved[1][1][0].symbol == "-" + assert resolved[1][1][0].arity == 2 + assert resolved[1][1][1].symbol == "-" + assert resolved[1][1][1].arity == 1 def test_pickleable(self, resolver): o = BytesIO() diff --git a/tests/parser/test_utils.py b/tests/parser/test_utils.py index 62c28d54..38459302 100644 --- a/tests/parser/test_utils.py +++ b/tests/parser/test_utils.py @@ -1,9 +1,9 @@ -from ntpath import join - import pytest +from formulaic.errors import FormulaSyntaxError from formulaic.parser.types import Token from formulaic.parser.utils import ( + exc_for_token, insert_tokens_after, merge_operator_tokens, replace_tokens, @@ -29,6 +29,15 @@ def test_replace_tokens(tokens): ] +def test_exc_for_token(tokens): + with pytest.raises(FormulaSyntaxError, match="Hello World"): + raise exc_for_token(tokens[0], "Hello World") + with pytest.raises(FormulaSyntaxError, match="Hello World"): + raise exc_for_token( + Token("h", source="hi", source_start=0, source_end=1), "Hello World" + ) + + def test_insert_tokens_after(tokens): assert list( insert_tokens_after( @@ -50,6 +59,32 @@ def test_insert_tokens_after(tokens): join_operator="+", ) ) == ["1", "+|", "hi", "-", "field"] + assert list( + insert_tokens_after( + [ + Token("1", kind=Token.Kind.VALUE), + Token("+|-", 
kind=Token.Kind.OPERATOR), + Token("field", kind=Token.Kind.NAME), + ], + r"\|", + [Token("hi", kind=Token.Kind.NAME)], + join_operator="+", + no_join_for_operators=False, + ) + ) == ["1", "+|", "hi", "+", "-", "field"] + assert list( + insert_tokens_after( + [ + Token("1", kind=Token.Kind.VALUE), + Token("+|-", kind=Token.Kind.OPERATOR), + Token("field", kind=Token.Kind.NAME), + ], + r"\|", + [Token("hi", kind=Token.Kind.NAME)], + join_operator="+", + no_join_for_operators={"+", "-"}, + ) + ) == ["1", "+|", "hi", "-", "field"] assert list( insert_tokens_after( [ diff --git a/tests/parser/types/test_formula_parser.py b/tests/parser/types/test_formula_parser.py index 19bf6e9d..34448742 100644 --- a/tests/parser/types/test_formula_parser.py +++ b/tests/parser/types/test_formula_parser.py @@ -28,3 +28,12 @@ class TestFormulaParser: @pytest.mark.parametrize("formula,tokens", FORMULA_TO_TOKENS.items()) def test_get_tokens(self, formula, tokens): assert list(PARSER.get_tokens(formula)) == tokens + + def test_parse_target_equivalence(self): + target = FormulaParser.Target.TOKENS + assert ( + list(PARSER.parse("a ~ b", target=target)) + == list(PARSER.parse("a ~ b", target=target.value)) + == list(PARSER.parse("a ~ b", target=target.name)) + == list(PARSER.parse("a ~ b", target=target.name.lower())) + ) diff --git a/tests/parser/types/test_operator_resolver.py b/tests/parser/types/test_operator_resolver.py index 0335f67c..aca53366 100644 --- a/tests/parser/types/test_operator_resolver.py +++ b/tests/parser/types/test_operator_resolver.py @@ -26,17 +26,5 @@ def resolver(self): return DummyOperatorResolver() def test_resolve(self, resolver): - assert resolver.resolve(Token("+"), 1, [])[0] is OPERATOR_PLUS - assert resolver.resolve(Token("-"), 0, [])[0] is OPERATOR_UNARY_MINUS - - with pytest.raises(FormulaSyntaxError): - resolver.resolve(Token("@"), 0, []) - - with pytest.raises(FormulaSyntaxError): - resolver.resolve(Token("+"), 0, []) - - with pytest.raises(FormulaSyntaxError): - resolver.resolve(Token("-"), 1, []) - - with pytest.raises(FormulaParsingError, match="Ambiguous operator `:`"): - resolver.resolve(Token(":"), 1, []) + assert list(resolver.resolve(Token("+")))[0][1][0] is OPERATOR_PLUS + assert list(resolver.resolve(Token("-")))[0][1][0] is OPERATOR_UNARY_MINUS diff --git a/tests/parser/types/test_term.py b/tests/parser/types/test_term.py index 0ae26f9c..3631be9d 100644 --- a/tests/parser/types/test_term.py +++ b/tests/parser/types/test_term.py @@ -1,6 +1,7 @@ import pytest from formulaic.parser.types import Factor, Term +from formulaic.parser.types.ordered_set import OrderedSet class TestTerm: @@ -48,3 +49,6 @@ def test_degree(self, term1, term3): assert term3.degree == 3 assert Term([Factor("1", eval_method="literal")]).degree == 0 assert Term([Factor("1", eval_method="literal"), Factor("x")]).degree == 1 + + def test_to_terms(self, term1): + assert term1.to_terms() == OrderedSet((term1,)) diff --git a/tests/parser/types/test_token.py b/tests/parser/types/test_token.py index e929b77b..244d386b 100644 --- a/tests/parser/types/test_token.py +++ b/tests/parser/types/test_token.py @@ -95,3 +95,9 @@ def test_split(self, token_a): Token("b"), Token("c"), ] + + def test_required_variables(self, token_a, token_b): + assert token_a.required_variables == {"a"} + assert token_b.required_variables == {"x"} + assert Token("malformed((python", kind="python").required_variables == set() + assert Token("xyz", kind="value").required_variables == set() diff --git a/tests/test_formula.py 
b/tests/test_formula.py index 139fa0fc..c8abb860 100644 --- a/tests/test_formula.py +++ b/tests/test_formula.py @@ -82,6 +82,14 @@ def test_constructor(self): assert Formula.from_spec(f) is f assert Formula.from_spec(["a"]) == f + # Test wildcards + assert Formula( + ".", _context={"__formulaic_variables_available__": ["a", "b"]} + ) == ["1", "a", "b"] + assert Formula( + "a ~ .", _context={"__formulaic_variables_available__": ["a", "b"]} + )._to_dict() == {"lhs": ["a"], "rhs": ["1", "b"]} + def test_terms(self, formula_expr): assert [str(t) for t in formula_expr] == [ "1", @@ -379,3 +387,11 @@ def test_deprecated_methods(self): with pytest.warns(DeprecationWarning): assert f._update(nested="a") == StructuredFormula(f, nested="a") + + +class TestStructuredFormula: + def test_pickling(self): + s = StructuredFormula("a + b", _context={}) + s2 = pickle.loads(pickle.dumps(s)) + assert s == s2 + assert s2._context is None diff --git a/tests/test_model_spec.py b/tests/test_model_spec.py index 97dbae88..083b1395 100644 --- a/tests/test_model_spec.py +++ b/tests/test_model_spec.py @@ -123,6 +123,13 @@ def test_get_variable_indices(self, model_spec): assert model_spec.get_variable_indices("a") == [1, 4, 5] assert model_spec.get_variable_indices("A") == [2, 3, 4, 5] + def test_required_variables(self, model_spec): + assert model_spec.structure + assert model_spec.required_variables == {"a", "A"} + + # Derived using formula instead of structure + assert model_spec.update(structure=None).required_variables == {"a", "A"} + def test_get_slice(self, model_spec): s = slice(0, 1) assert model_spec.get_slice(s) is s @@ -274,6 +281,7 @@ def test_model_specs(self, model_spec, data2): assert numpy.all( model_specs.get_model_matrix(data2).a == model_spec.get_model_matrix(data2) ) + assert model_specs.required_variables == {"a", "A"} sparse_matrices = model_specs.get_model_matrix(data2, output="sparse") assert isinstance(sparse_matrices, ModelMatrices) assert isinstance(sparse_matrices.a, scipy.sparse.spmatrix) diff --git a/tests/test_sugar.py b/tests/test_sugar.py index cb1b1552..3186ffe0 100644 --- a/tests/test_sugar.py +++ b/tests/test_sugar.py @@ -29,3 +29,7 @@ def local_test(x): with pytest.raises(FactorEvaluationError): model_matrix("0 + global_test(a) + local_test(b)", data, context=None) + + # test wild-cards + r3 = model_matrix("a ~ .", data) + assert r3.rhs.model_spec.column_names == ("Intercept", "b", "c")
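
For reviewers, a quick sketch of the new `.` wildcard operator in action, mirroring the behaviour exercised by the tests above (the data frame and its values are illustrative):

```python
import pandas
from formulaic import Formula, model_matrix

df = pandas.DataFrame({"y": [0, 1, 2], "x": [1.0, 2.0, 3.0], "z": [2.0, 4.0, 8.0]})

# `model_matrix` seeds the parser context from the materializer, so `.`
# expands to every variable in `df` not already used on the left-hand side.
mm = model_matrix("y ~ .", df)
print(mm.rhs.model_spec.column_names)  # ('Intercept', 'x', 'z')

# When constructing a `Formula` directly there is no data to inspect, so the
# available variables must be supplied via the parsing context instead.
f = Formula("y ~ .", _context={"__formulaic_variables_available__": ["y", "x", "z"]})
```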
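The parser refactor also consolidates tokenization, AST generation, and term evaluation behind a single `FormulaParser.parse()` entry point; a sketch of the `target=` spellings accepted, per the new `test_parse_target_equivalence` test:

```python
from formulaic.parser import DefaultFormulaParser
from formulaic.parser.types import FormulaParser

parser = DefaultFormulaParser()

# The target stage can be spelled as an enum member, its (case-insensitive)
# name, or its integer value.
tokens = parser.parse("y ~ x + z", target=FormulaParser.Target.TOKENS)
ast = parser.parse("y ~ x + z", target="ast")
terms = parser.parse("y ~ x + z", target=3)  # Target.TERMS; also the default

# The pre-existing helpers are now thin wrappers around `parse()`.
assert list(parser.get_tokens("y ~ x + z")) == list(tokens)
```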
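Similarly, the new `ModelSpec.required_variables` and `ModelSpec.get_materializer()` additions compose as below; a sketch assuming a pandas frame, with `x` and `A` as illustrative column names:

```python
import pandas
from formulaic import ModelSpec

data = pandas.DataFrame({"x": [1.0, 2.0, 3.0], "A": ["a", "b", "c"]})
spec = ModelSpec.from_spec("x + C(A)")

# With `.structure` not yet populated, this falls back to the variables
# inferred from the formula; transform names such as `C` are filtered out.
print(spec.required_variables)  # {'x', 'A'} (as `Variable` instances)

# Equivalent to `spec.get_model_matrix(data)`, but exposing the materializer.
mm = spec.get_materializer(data).get_model_matrix(spec)
```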
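Finally, the reworked operator resolution changes how disabled operators surface to users; a sketch of the new diagnostics, paraphrasing the updated `test_feature_flags` expectations:

```python
from formulaic.parser import DefaultFormulaParser

# With no feature flags enabled, the `~` operator is disabled rather than
# unknown, and parse failures now call this out explicitly.
try:
    DefaultFormulaParser(feature_flags=set()).get_terms("y ~ x")
except Exception as error:
    print(error)
    # Missing operator between `y` and `1`. This may be due to the following
    # operators being at least partially disabled by parser configuration: {~}.
```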