Skip to content

Commit e0528fc

Browse files
committed
feat: Implement groups into parser, AST
1 parent d585999 commit e0528fc

File tree

6 files changed

+82
-20
lines changed

6 files changed

+82
-20
lines changed

src/regex_automata/automata/nfa.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def is_trivial(self) -> bool:
2020
class Transition:
2121
predicates: tuple[TransitionPredicate, ...]
2222
consume_char: bool = True
23+
begin_group: int | None = None
24+
end_group: int | None = None
2325
label: str = ""
2426

2527
def matches(self, c_previous: int, c_next: int) -> bool:
@@ -33,12 +35,27 @@ def matches(self, c_previous: int, c_next: int) -> bool:
3335

3436
@property
def is_trivial_epsilon(self) -> bool:
    """Tell whether this transition is a trivial epsilon move.

    True only when the transition consumes no character, carries neither a
    begin-group nor an end-group marker, and all of its predicates are
    trivial.
    """
    if self.consume_char:
        return False
    # A transition tagged with a group marker is never treated as trivial.
    if self.begin_group is not None or self.end_group is not None:
        return False
    return all(pred.is_trivial for pred in self.predicates)
3744

3845
@classmethod
def make_trivial_epsilon(cls) -> Self:
    """Build an epsilon transition.

    The result has a single default-constructed predicate, consumes no
    character and is labelled "ε".
    """
    default_predicate = TransitionPredicate()
    return cls(
        predicates=(default_predicate,),
        consume_char=False,
        label="ε",
    )
4148

49+
@classmethod
def make_begin_group(cls, number: int) -> Self:
    """Build a non-consuming transition that marks the start of a group.

    :param number: group number stored in ``begin_group`` and shown in the label.
    :return: a transition with one default-constructed predicate and
        ``consume_char=False``.
    """
    default_predicate = TransitionPredicate()
    return cls(
        predicates=(default_predicate,),
        consume_char=False,
        label=f"⟨begin group {number}⟩",
        begin_group=number,
    )
53+
54+
@classmethod
def make_end_group(cls, number: int) -> Self:
    """Build a non-consuming transition that marks the end of a group.

    :param number: group number stored in ``end_group`` and shown in the label.
    :return: a transition with one default-constructed predicate and
        ``consume_char=False``.
    """
    default_predicate = TransitionPredicate()
    return cls(
        predicates=(default_predicate,),
        consume_char=False,
        label=f"⟨end group {number}⟩",
        end_group=number,
    )
58+
4259

4360
@dataclass
4461
class NFA:
@@ -48,18 +65,18 @@ class NFA:
4865
"""
4966
states: list[int]
5067
initial_state: int
51-
final_states: list[int]
68+
final_states: set[int]
5269
transitions: dict[int, dict[Transition, set[int]]]
5370

5471
def copy(self) -> "NFA":
    """Return a deep copy of this automaton."""
    duplicate = deepcopy(self)
    return duplicate
5673

57-
def renumber_states(self, x0: int) -> "NFA":
74+
def renumber_states(self, x0: int = 0) -> "NFA":
5875
f = dict(zip(self.states, count(x0)))
5976
return NFA(
6077
states=[f[x] for x in self.states],
6178
initial_state=f[self.initial_state],
62-
final_states=[f[x] for x in self.final_states],
79+
final_states={f[x] for x in self.final_states},
6380
transitions={
6481
f[x]: {p: {f[y] for y in ys} for p, ys in d.items()}
6582
for x, d in self.transitions.items()
@@ -133,6 +150,6 @@ def get_trivial_epsilon_free_nfa(self) -> "NFA":
133150
return NFA(
134151
states=list(sorted(reachable_states)),
135152
initial_state=initial_state,
136-
final_states=list(sorted(final_states)),
153+
final_states=final_states,
137154
transitions=transitions,
138-
)
155+
).renumber_states()

src/regex_automata/parser/ast.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,15 @@ def iter_children(self) -> Iterator["AstNode"]:
105105
class AstEmpty(AstNode):
106106
def get_label(self) -> str:
107107
return "ε"
108+
109+
110+
@dataclass
class AstGroup(AstNode):
    """AST node for a numbered group wrapping a single sub-expression."""

    # Number identifying this group.
    number: int
    # The sub-expression enclosed by the group.
    u: AstNode

    def get_label(self) -> str:
        """Return the display label of the node, e.g. ``group 1``."""
        return f"group {self.number}"

    def iter_children(self) -> Iterator["AstNode"]:
        """Yield the only child: the wrapped sub-expression."""
        yield from (self.u,)

src/regex_automata/parser/ast_processor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11

22
from .ast import AstNode, AstEmpty, AstConcatenation, AstUnion, AstRepetition, AstCharacterSet, AstIteration, \
3-
AstBoundaryAssertion
3+
AstBoundaryAssertion, AstGroup
44

55

66
class ASTProcessor:
77
def __init__(self, raw_ast: AstNode) -> None:
88
self.raw_ast = raw_ast
99

1010
def get_processed_ast(self) -> AstNode:
    """Convert the raw AST and wrap the result in group number 0.

    NOTE(review): group 0 presumably stands for the whole match — confirm
    against the matcher.
    """
    return AstGroup(0, self.convert(self.raw_ast))
1213

1314
def convert(self, node: AstNode) -> AstNode:
1415
match node:
@@ -26,6 +27,8 @@ def convert(self, node: AstNode) -> AstNode:
2627
return self.convert_AstConcatenation(node)
2728
case AstBoundaryAssertion():
2829
return self.convert_AstBoundaryAssertion(node)
30+
case AstGroup():
31+
return self.convert_AstGroup(node)
2932
case _:
3033
return node
3134

@@ -88,6 +91,9 @@ def convert_AstConcatenation(self, node: AstConcatenation) -> AstNode:
8891
def convert_AstBoundaryAssertion(self, node: AstBoundaryAssertion) -> AstNode:
8992
return node
9093

94+
def convert_AstGroup(self, node: AstGroup) -> AstNode:
    """Return a new group with the same number whose child has been converted."""
    converted_child = self.convert(node.u)
    return AstGroup(node.number, converted_child)
96+
9197
@staticmethod
9298
def iterated_concatenation(node: AstNode, n: int) -> AstNode:
9399
# 0 -> AstEmpty == ""

src/regex_automata/parser/parser.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from typing import Type, TypeVar, NoReturn, ParamSpec, Callable
33

44
from .tokens import Token, LPar, RPar, Repetition, Pipe, CharacterSet, BoundaryAssertion
5-
from .ast import AstNode, AstUnion, AstRepetition, AstCharacterSet, AstConcatenation, AstEmpty, AstBoundaryAssertion
5+
from .ast import AstNode, AstUnion, AstRepetition, AstCharacterSet, AstConcatenation, AstEmpty, AstBoundaryAssertion, \
6+
AstGroup
67
from ..errors import ParserError
78

89

@@ -24,6 +25,7 @@ def __init__(self, tokens: list[Token]):
2425
self.tokens = list(tokens)
2526
self.pos = -1
2627
self.string_pos = -1
28+
self.group_number = 1
2729

2830
def read(self, cls: Type[TToken]) -> TToken:
2931
self.pos += 1
@@ -54,6 +56,11 @@ def parse(self) -> AstNode:
5456
self.error("unread input remaining (expected end of input)")
5557
return root
5658

59+
def make_group(self, u: AstNode) -> AstGroup:
    """Wrap ``u`` in a group carrying the next free group number.

    The parser's ``group_number`` counter is advanced by one on every call.

    NOTE(review): p10 calls this after reading the closing parenthesis, so
    nested groups are numbered in closing order (inner before outer). Most
    regex engines number groups by opening parenthesis — confirm this is
    intended.
    """
    number = self.group_number
    self.group_number = number + 1
    return AstGroup(number, u)
63+
5764
@rule
5865
def p1(self) -> AstNode:
5966
"""
@@ -232,7 +239,7 @@ def p10(self) -> AstNode:
232239
# rpar
233240
_ = self.read(RPar)
234241

235-
return E
242+
return self.make_group(E)
236243

237244
@rule
238245
def p11(self) -> AstNode:

src/regex_automata/regex/nfa_builder.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ..automata.rangeset import RangeSet, WORD_RANGESET, NONWORD_RANGESET
22
from ..parser.ast import AstNode, AstCharacterSet, AstConcatenation, AstUnion, AstEmpty, AstIteration, \
3-
AstBoundaryAssertion
3+
AstBoundaryAssertion, AstGroup
44
from ..automata.nfa import NFA, Transition, TransitionPredicate
55
from ..parser.tokens import BoundaryAssertionSemantic
66

@@ -30,22 +30,24 @@ def convert(self, node: AstNode) -> NFA:
3030
return self.convert_AstConcatenation(node)
3131
case AstBoundaryAssertion():
3232
return self.covert_AstBoundaryAssertion(node)
33+
case AstGroup():
34+
return self.convert_AstGroup(node)
3335
case _:
3436
raise NotImplementedError(f"Cannot convert node {node!r}")
3537

3638
def convert_AstEmpty(self, _: AstEmpty) -> NFA:
    """Build the NFA for the empty expression.

    The automaton has one state, 0, which is both initial and final, and no
    transitions.
    """
    only_state = 0
    return NFA(
        states=[only_state],
        initial_state=only_state,
        final_states={only_state},
        transitions={},
    )
4345

4446
def convert_AstCharacter(self, node: AstCharacterSet) -> NFA:
    """Build the two-state NFA for a character set.

    State 0 is initial and state 1 is final. The single transition from 0
    to 1 carries one predicate built with ``next=node.rs`` and reuses the
    node's label.
    """
    predicate = TransitionPredicate(next=node.rs)
    step = Transition(predicates=(predicate,), label=node.label)
    return NFA(
        states=[0, 1],
        initial_state=0,
        final_states={1},
        transitions={0: {step: {1}}},
    )
5153

@@ -55,7 +57,7 @@ def convert_AstIteration(self, node: AstIteration) -> NFA:
5557

5658
for x in nfa.final_states:
5759
nfa.transitions.setdefault(x, {}).setdefault(Transition.make_trivial_epsilon(), set()).add(nfa.initial_state)
58-
nfa.final_states = list(sorted(nfa.trivial_epsilon_closure(set(nfa.final_states))))
60+
nfa.final_states = nfa.trivial_epsilon_closure(set(nfa.final_states))
5961
return nfa
6062

6163
def convert_AstUnion(self, node: AstUnion) -> NFA:
@@ -66,7 +68,7 @@ def convert_AstUnion(self, node: AstUnion) -> NFA:
6668

6769
nfa = nfa_u.copy()
6870
nfa.states += nfa_v.states
69-
nfa.final_states += nfa_v.final_states
71+
nfa.final_states |= nfa_v.final_states
7072
nfa.transitions.update(nfa_v.transitions)
7173
new_initial_state = max(nfa.states) + 1
7274
nfa.states.append(new_initial_state)
@@ -144,6 +146,21 @@ def covert_AstBoundaryAssertion(self, node: AstBoundaryAssertion) -> NFA:
144146
return NFA(
145147
states=[0, 1],
146148
initial_state=0,
147-
final_states=[1],
149+
final_states={1},
148150
transitions={0: {transition: {1}}}
149151
)
152+
153+
def convert_AstGroup(self, node: AstGroup) -> NFA:
    """Build the NFA for a group by wrapping the NFA of its child.

    The child automaton is renumbered from 1 so that state 0 can serve as
    the new initial state, and one new final state is placed after the
    largest child state. The new initial state reaches the child's initial
    state through a begin-group transition. Every final state of the child
    reaches the new final state through an end-group transition. Both
    transitions are tagged with ``node.number``.
    """
    inner = self.convert(node.u).renumber_states(1)
    entry = 0
    exit_state = max(inner.states) + 1

    wrapped = inner.copy()
    wrapped.states.extend((entry, exit_state))
    wrapped.initial_state = entry
    # The child's final states stop being accepting; only the new one is.
    wrapped.final_states = {exit_state}

    opening = Transition.make_begin_group(node.number)
    wrapped.transitions[entry] = {opening: {inner.initial_state}}

    for accepting in inner.final_states:
        outgoing = wrapped.transitions.setdefault(accepting, {})
        closing = Transition.make_end_group(node.number)
        outgoing.setdefault(closing, set()).add(exit_state)
    return wrapped

tests/test_parser.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from regex_automata.automata.rangeset import RangeSet
44
from regex_automata.errors import TokenizerError
5-
from regex_automata.parser.ast import AstUnion, AstCharacterSet, AstConcatenation
5+
from regex_automata.parser.ast import AstUnion, AstCharacterSet, AstConcatenation, AstGroup
66
from regex_automata.parser.parser import Parser
77
from regex_automata.parser.tokenizer import Tokenizer
88

@@ -31,9 +31,12 @@ def test_parse_tree_union_parens():
3131
assert Parser(tokens).parse() == AstUnion(
3232
AstConcatenation(_ast_character_set("a"), _ast_character_set("b")),
3333
AstUnion(
34-
AstUnion(
35-
AstConcatenation(_ast_character_set("c"), _ast_character_set("d")),
36-
AstConcatenation(_ast_character_set("e"), _ast_character_set("f")),
34+
AstGroup(
35+
1,
36+
AstUnion(
37+
AstConcatenation(_ast_character_set("c"), _ast_character_set("d")),
38+
AstConcatenation(_ast_character_set("e"), _ast_character_set("f")),
39+
)
3740
),
3841
AstConcatenation(_ast_character_set("g"), _ast_character_set("h")),
3942
)

0 commit comments

Comments
 (0)