Skip to content

Commit

Permalink
tokens_to_ast: Improve generality of algorithm, and simplify contract…
Browse files Browse the repository at this point in the history
… with `OperatorResolver` instances.
  • Loading branch information
matthewwardrop committed Dec 1, 2024
1 parent ee1c3a2 commit 89a4164
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 110 deletions.
83 changes: 59 additions & 24 deletions formulaic/parser/algos/tokens_to_ast.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import namedtuple
from typing import Iterable, List, Union
from typing import Iterable, List, Set, Union

from ..types import ASTNode, Operator, OperatorResolver, Token
from ..utils import exc_for_missing_operator, exc_for_token
Expand Down Expand Up @@ -42,6 +42,7 @@ def tokens_to_ast(
"""
output_queue: List[Union[Token, ASTNode]] = []
operator_stack: List[OrderedOperator] = []
disabled_operators: Set[Token] = set()

def stack_operator(operator: Union[Token, Operator], token: Token) -> None:
operator_stack.append(OrderedOperator(operator, token, len(output_queue)))
Expand Down Expand Up @@ -98,30 +99,55 @@ def operate(
f"Context token `{token.token}` is unrecognized.",
)
elif token.kind is token.Kind.OPERATOR:
max_prefix_arity = (
len(output_queue) - operator_stack[-1].index
if operator_stack
else len(output_queue)
)
operators = operator_resolver.resolve(
token,
max_prefix_arity=max_prefix_arity,
context=[s.operator for s in operator_stack],
)

for operator in operators:
while (
operator_stack
and operator_stack[-1].token.kind is not Token.Kind.CONTEXT
and (
operator_stack[-1].operator.precedence > operator.precedence
or operator_stack[-1].operator.precedence == operator.precedence
and operator.associativity is Operator.Associativity.LEFT
for operator_token, operators in operator_resolver.resolve(token):
for operator in operators:
if not operator.accepts_context(
[s.operator for s in operator_stack]
):
continue
if operator.disabled:
disabled_operators.add(operator_token)
continue
# Apply all operators with precedence greater than the current operator
while (
operator_stack
and operator_stack[-1].token.kind is not Token.Kind.CONTEXT
and (
operator_stack[-1].operator.precedence > operator.precedence
or operator_stack[-1].operator.precedence
== operator.precedence
and operator.associativity is Operator.Associativity.LEFT
)
):
output_queue = operate(operator_stack.pop(), output_queue)

# Determine maximum number of postfix arguments
max_postfix_arity = (
len(output_queue) - operator_stack[-1].index
if operator_stack
else len(output_queue)
)
):
output_queue = operate(operator_stack.pop(), output_queue)

stack_operator(operator, token)
# Check if operator is valid in current context
if (
operator.arity == 0
or operator.fixity is Operator.Fixity.PREFIX
or max_postfix_arity == 1
and operator.fixity is Operator.Fixity.INFIX
or max_postfix_arity >= operator.arity
and operator.fixity is Operator.Fixity.POSTFIX
):
stack_operator(operator, token)
break
else:
if operator_token in disabled_operators:
raise exc_for_token(
token,
f"Operator `{operator_token}` is at least partially disabled by parser configuration, and/or is incorrectly used.",
)
raise exc_for_token(
token, f"Operator `{operator_token}` is incorrectly used."
)
else:
output_queue.append(token)

Expand All @@ -134,7 +160,16 @@ def operate(

if output_queue:
if len(output_queue) > 1:
raise exc_for_missing_operator(output_queue[0], output_queue[1])
raise exc_for_missing_operator(
output_queue[0],
output_queue[1],
extra=(
"This may be due to the following operators being at least "
f"partially disabled by parser configuration: {disabled_operators}."
if disabled_operators
else None
),
)
return output_queue[0]

return None
19 changes: 10 additions & 9 deletions formulaic/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from dataclasses import dataclass, field
from enum import Flag, auto
from typing import Iterable, List, Sequence, Set, Tuple, Union, cast
from typing import Generator, Iterable, List, Set, Tuple, Union, cast

from typing_extensions import Self

Expand Down Expand Up @@ -454,10 +454,12 @@ def get_terms(terms: OrderedSet[Term]) -> List[Term]:
]

def resolve(
self, token: Token, max_prefix_arity: int, context: List[Union[Token, Operator]]
) -> Sequence[Operator]:
self,
token: Token,
) -> Generator[Tuple[Token, Iterable[Operator]], None, None]:
if token.token in self.operator_table:
return super().resolve(token, max_prefix_arity, context)
yield from super().resolve(token)
return

symbol = token.token

Expand All @@ -474,9 +476,8 @@ def resolve(
)

if symbol in self.operator_table:
return [self._resolve(token, symbol, max_prefix_arity, context)]
yield self._resolve(token, symbol)
return

return [
self._resolve(token, sym, max_prefix_arity if i == 0 else 0, context)
for i, sym in enumerate(symbol)
]
for sym in symbol:
yield self._resolve(token, sym)
46 changes: 12 additions & 34 deletions formulaic/parser/types/operator_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from collections import defaultdict
from typing import Dict, List, Sequence, Union
from typing import Dict, Generator, Iterable, List, Tuple

from ..utils import exc_for_token
from .operator import Operator
Expand Down Expand Up @@ -53,11 +53,15 @@ def operator_table(self) -> Dict[str, List[Operator]]:
return operator_table

def resolve(
self, token: Token, max_prefix_arity: int, context: List[Union[Token, Operator]]
) -> Sequence[Operator]:
self, token: Token
) -> Generator[Tuple[Token, Iterable[Operator]], None, None]:
"""
Return a list of operators to apply for a given token in the AST
generation.
Generate the sets of operator candidates that may be viable for the
given token (which may include multiple adjacent operators concatenated
together). Each generated item must be a tuple of the token associated
with the operator and an iterable of `Operator` instances which should
be considered by the AST generator. These `Operator` instances *MUST* be
sorted in descending order of precedence and arity.
Args:
token: The operator `Token` instance for which `Operator`(s) should
Expand All @@ -68,45 +72,19 @@ def resolve(
resolved will be placed. This will be a list of `Operator`
instances or tokens (tokens are returned for grouping operators).
"""
return [self._resolve(token, token.token, max_prefix_arity, context)]
yield self._resolve(token, token.token)

def _resolve(
self,
token: Token,
symbol: str,
max_prefix_arity: int,
context: List[Union[Token, Operator]],
) -> Operator:
) -> Tuple[Token, Iterable[Operator]]:
"""
The default operator resolving logic.
"""
if symbol not in self.operator_table:
raise exc_for_token(token, f"Unknown operator '{symbol}'.")
candidates = [
candidate
for candidate in self.operator_table[symbol]
if (
max_prefix_arity == 0
and candidate.fixity is Operator.Fixity.PREFIX
or max_prefix_arity > 0
and candidate.fixity is not Operator.Fixity.PREFIX
)
and candidate.accepts_context(context)
]
if not candidates:
raise exc_for_token(token, f"Operator `{symbol}` is incorrectly used.")
candidates = [candidate for candidate in candidates if not candidate.disabled]
if not candidates:
raise exc_for_token(
token,
f"Operator `{symbol}` has been disabled in this context via parser configuration.",
)
if len(candidates) > 1:
raise exc_for_token(
token,
f"Ambiguous operator `{symbol}`. This is not usually a user error. Please report this!",
)
return candidates[0]
return token, self.operator_table[symbol]

# The operator table cache may not be pickleable, so let's drop it.
def __getstate__(self) -> Dict:
Expand Down
4 changes: 3 additions & 1 deletion formulaic/parser/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def exc_for_missing_operator(
lhs: Union[Token, ASTNode],
rhs: Union[Token, ASTNode],
errcls: Type[Exception] = FormulaSyntaxError,
extra: Optional[str] = None,
) -> Exception:
"""
Return an exception ready to be raised about a missing operator token
Expand All @@ -45,11 +46,12 @@ def exc_for_missing_operator(
rhs: The `Token` or `ASTNode` instance to the right of where an operator
should be placed.
errcls: The type of the exception to be returned.
extra: Any additional information to be included in the exception message.
"""
lhs_token, rhs_token, error_token = __get_tokens_for_gap(lhs, rhs)
return exc_for_token(
error_token,
f"Missing operator between `{lhs_token.token}` and `{rhs_token.token}`.",
f"Missing operator between `{lhs_token.token}` and `{rhs_token.token}`.{f' {extra}' if extra else ''}",
errcls=errcls,
)

Expand Down
48 changes: 20 additions & 28 deletions tests/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_feature_flags(self):
with pytest.raises(
FormulaSyntaxError,
match=re.escape(
"Operator `~` has been disabled in this context via parser configuration."
"Missing operator between `y` and `1`. This may be due to the following operators being at least partially disabled by parser configuration: {~}."
),
):
DefaultFormulaParser(feature_flags={}).get_terms("y ~ x")
Expand All @@ -266,7 +266,7 @@ def test_feature_flags(self):
with pytest.raises(
FormulaSyntaxError,
match=re.escape(
"Operator `|` has been disabled in this context via parser configuration."
"Operator `|` is at least partially disabled by parser configuration, and/or is incorrectly used."
),
):
DefaultFormulaParser().set_feature_flags({}).get_terms("x | y")
Expand All @@ -287,32 +287,24 @@ def resolver(self):
return DefaultOperatorResolver()

def test_resolve(self, resolver):
assert len(resolver.resolve(Token("+++++"), 1, [])) == 1
assert resolver.resolve(Token("+++++"), 1, [])[0].symbol == "+"
assert resolver.resolve(Token("+++++"), 1, [])[0].arity == 2

assert len(resolver.resolve(Token("+++-+"), 1, [])) == 1
assert resolver.resolve(Token("+++-+"), 1, [])[0].symbol == "-"
assert resolver.resolve(Token("+++-+"), 1, [])[0].arity == 2

assert len(resolver.resolve(Token("*+++-+"), 1, [])) == 2
assert resolver.resolve(Token("*+++-+"), 1, [])[0].symbol == "*"
assert resolver.resolve(Token("*+++-+"), 1, [])[0].arity == 2
assert resolver.resolve(Token("*+++-+"), 1, [])[1].symbol == "-"
assert resolver.resolve(Token("*+++-+"), 1, [])[1].arity == 1

with pytest.raises(
FormulaSyntaxError, match="Operator `/` is incorrectly used."
):
resolver.resolve(Token("*/"), 2, [])

def test_accepts_context(self, resolver):
tilde_operator = resolver.resolve(Token("~"), 1, [])[0]

with pytest.raises(
FormulaSyntaxError, match=re.escape("Operator `~` is incorrectly used.")
):
resolver.resolve(Token("~"), 1, [tilde_operator])
resolved = list(resolver.resolve(Token("+++++")))
assert len(resolved) == 1
assert resolved[0][1][0].symbol == "+"
assert resolved[0][1][0].arity == 2

resolved = list(resolver.resolve(Token("+++-+")))
assert len(resolved) == 1
assert resolved[0][1][0].symbol == "-"
assert resolved[0][1][0].arity == 2

resolved = list(resolver.resolve(Token("*+++-+")))
assert len(resolved) == 2
assert resolved[0][1][0].symbol == "*"
assert resolved[0][1][0].arity == 2
assert resolved[1][1][0].symbol == "-"
assert resolved[1][1][0].arity == 2
assert resolved[1][1][1].symbol == "-"
assert resolved[1][1][1].arity == 1

def test_pickleable(self, resolver):
o = BytesIO()
Expand Down
16 changes: 2 additions & 14 deletions tests/parser/types/test_operator_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,5 @@ def resolver(self):
return DummyOperatorResolver()

def test_resolve(self, resolver):
assert resolver.resolve(Token("+"), 1, [])[0] is OPERATOR_PLUS
assert resolver.resolve(Token("-"), 0, [])[0] is OPERATOR_UNARY_MINUS

with pytest.raises(FormulaSyntaxError):
resolver.resolve(Token("@"), 0, [])

with pytest.raises(FormulaSyntaxError):
resolver.resolve(Token("+"), 0, [])

with pytest.raises(FormulaSyntaxError):
resolver.resolve(Token("-"), 1, [])

with pytest.raises(FormulaParsingError, match="Ambiguous operator `:`"):
resolver.resolve(Token(":"), 1, [])
assert list(resolver.resolve(Token("+")))[0][1][0] is OPERATOR_PLUS
assert list(resolver.resolve(Token("-")))[0][1][0] is OPERATOR_UNARY_MINUS

0 comments on commit 89a4164

Please sign in to comment.