Skip to content

Commit 0946e7b

Browse files
committed
feat: Implement groups into NFAEvaluator (overlapping matches)
1 parent e0528fc commit 0946e7b

File tree

1 file changed

+117
-66
lines changed

1 file changed

+117
-66
lines changed
Lines changed: 117 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,145 @@
1-
from typing import Iterator
1+
from dataclasses import dataclass
2+
from typing import Iterator, Self, Iterable, Set
23

3-
from regex_automata.automata.nfa import NFA
4+
from regex_automata.automata.nfa import NFA, Transition
45
from regex_automata.regex.flags import PatternFlag
56
from regex_automata.regex.match import Match
67

78

8-
class NFAEvaluator:
9-
class Head:
10-
def __init__(self, evaluator: "NFAEvaluator", start: int) -> None:
11-
self.start = start
12-
self.evaluator = evaluator
13-
self.states: set[int] = set(self.evaluator.initial_states)
14-
self.entered_final = bool(self.states & self.evaluator.final_states)
15-
self.left_final = False
16-
17-
def step_epsilon(self, c_previous: int, c_next: int) -> None:
18-
self.states = new_states = self._step_epsilon(c_previous, c_next, self.states)
19-
new_in_final = bool(new_states & self.evaluator.final_states)
20-
self.entered_final = self.entered_final or new_in_final
21-
self.left_final = self.entered_final and not new_in_final
22-
23-
def _step_epsilon(self, c_previous: int, c_next: int, states: set[int]) -> set[int]:
24-
return self.evaluator.nfa.epsilon_closure(states, c_previous, c_next)
25-
26-
def step_read(self, c_previous: int, c_next: int) -> None:
27-
self.states = new_states = self._step_read(c_previous, c_next, self.states)
28-
new_in_final = bool(new_states & self.evaluator.final_states)
29-
self.entered_final = self.entered_final or new_in_final
30-
self.left_final = self.entered_final and not new_in_final
31-
32-
def _step_read(self, c_previous: int, c_next: int, states: set[int]) -> set[int]:
33-
assert c_next != -1
34-
new_states = set()
35-
for u in states:
36-
u_transitions = self.evaluator.nfa.transitions.get(u, {})
37-
for p, vs in u_transitions.items():
38-
if p.consume_char and p.matches(c_previous, c_next):
39-
new_states.update(vs)
40-
41-
return new_states
42-
43-
def __repr__(self) -> str:
44-
return f"<Head {self.start=} {self.states=} {self.entered_final=} {self.left_final=}>"
9+
@dataclass(frozen=True)
10+
class GroupMatch:
11+
start: int
12+
end: int
13+
14+
15+
@dataclass(frozen=True)
16+
class Head:
17+
state: int
18+
start: int
19+
position: int
20+
groups: tuple[GroupMatch | None, ...] = ()
21+
22+
def apply_transition(self, transition: Transition, next_state: int) -> Self:
23+
head = self
24+
if transition.begin_group is not None:
25+
head = head._begin_group(transition.begin_group)
26+
if transition.end_group is not None:
27+
head = head._end_group(transition.end_group)
28+
return Head(
29+
next_state,
30+
head.start,
31+
head.position + (1 if transition.consume_char else 0),
32+
head.groups
33+
)
34+
35+
def _begin_group(self, number: int) -> Self:
36+
groups = list(self.groups)
37+
while len(groups) <= number:
38+
groups.append(None)
39+
m = groups[number]
40+
if m is not None and m.end == -1:
41+
raise ValueError(f"Group {number} has not been closed yet, cannot begin it again")
42+
groups[number] = GroupMatch(self.position, -1)
43+
return Head(
44+
self.state,
45+
self.start,
46+
self.position,
47+
tuple(groups)
48+
)
49+
50+
def _end_group(self, number: int) -> Self:
51+
groups = list(self.groups)
52+
if len(groups) < number:
53+
raise ValueError(f"Group {number} has never been opened, cannot close it")
54+
m = groups[number]
55+
if m is None:
56+
raise ValueError(f"Group {number} has never been opened, cannot close it")
57+
elif m.end != -1:
58+
raise ValueError(f"Group {number} has already been closed, cannot close it again")
59+
groups[number] = GroupMatch(m.start, self.position)
60+
return Head(
61+
self.state,
62+
self.start,
63+
self.position,
64+
tuple(groups)
65+
)
66+
4567

68+
class NFAEvaluator:
4669
def __init__(self, nfa: NFA, flags: PatternFlag = PatternFlag.NOFLAG) -> None:
4770
self.nfa = nfa
48-
self.initial_states = self.nfa.trivial_epsilon_closure({self.nfa.initial_state})
49-
self.heads: list["NFAEvaluator.Head"] = []
5071
self.flags = flags
51-
self.final_states = set(self.nfa.final_states)
72+
if len(nfa.final_states) != 1:
73+
raise ValueError("Expected NFA with exactly one final state (end of group 0)")
74+
self.final_state = next(iter(nfa.final_states))
5275

5376
def finditer(self, text: str, start: int = 0, end: int | None = None, search: bool = True) -> Iterator[Match]:
77+
original_text = text
5478
if self.flags & PatternFlag.IGNORECASE:
5579
text = text.lower()
5680

5781
end_ = end if end is not None else len(text)
5882

59-
self.heads.append(self.Head(self, min(len(text), start)))
83+
heads = {self.init_head(min(len(text), start))}
6084

6185
c_previous = -1
6286
for char_no, i in enumerate(range(min(len(text), start), min(len(text), end_))):
63-
match_at_position = False
87+
print("finditer", char_no, i)
88+
print(heads)
6489
if search and char_no > 0:
65-
self.heads.append(self.Head(self, i))
90+
heads.add(self.init_head(i))
6691

6792
c_next = ord(text[i])
6893

69-
for head in self.heads:
70-
head.step_epsilon(c_previous, c_next)
71-
if not match_at_position and head.left_final:
72-
self.purge_heads(i-1)
73-
yield Match.from_span_and_text(head.start, i-1, text)
74-
match_at_position = True # avoid returning multiple matches
94+
# do epsilon transitions
95+
print("do epsilon transitions")
96+
heads = self.apply_epsilon_transitions(heads, c_previous, c_next)
97+
print(heads)
98+
yield from self.iter_matches_from_heads(heads, original_text)
7599

76-
for head in self.heads:
77-
head.step_read(c_previous, c_next)
78-
if not match_at_position and head.left_final:
79-
self.purge_heads(i)
80-
yield Match.from_span_and_text(head.start, i, text)
81-
match_at_position = True
100+
# do character transitions
101+
print("do character transitions")
102+
heads = self.apply_character_transitions(heads, c_previous, c_next)
103+
print(heads)
104+
yield from self.iter_matches_from_heads(heads, original_text)
82105

83106
c_previous = c_next
84107

108+
print("finished reading input, doing final epsilon transitions")
85109
c_next = -1
86-
for head in self.heads:
87-
head.step_epsilon(c_previous, c_next)
88-
89-
if head.entered_final and not head.left_final:
90-
yield Match.from_span_and_text(head.start, end_, text)
91-
return
92-
93-
def purge_heads(self, start_min: int) -> None:
94-
self.heads = [h for h in self.heads if h.start >= start_min]
110+
# do epsilon transitions
111+
heads = self.apply_epsilon_transitions(heads, c_previous, c_next)
112+
print(heads)
113+
yield from self.iter_matches_from_heads(heads, original_text)
114+
115+
def init_head(self, position: int) -> Head:
116+
return Head(self.nfa.initial_state, position, position)
117+
118+
def apply_epsilon_transitions(self, heads: Iterable[Head], c_previous: int, c_next: int) -> Set[Head]:
119+
closure = set(heads)
120+
while True:
121+
new_closure = closure
122+
for head in closure:
123+
for transition, next_states in self.nfa.transitions.get(head.state, {}).items():
124+
if not transition.consume_char and transition.matches(c_previous, c_next):
125+
new_closure = new_closure | {head.apply_transition(transition, next_state) for next_state in next_states}
126+
if len(closure) == len(new_closure):
127+
break
128+
closure = new_closure
129+
return closure
130+
131+
def apply_character_transitions(self, heads: Iterable[Head], c_previous: int, c_next: int) -> Set[Head]:
132+
new_heads = set()
133+
for head in heads:
134+
for transition, next_states in self.nfa.transitions.get(head.state, {}).items():
135+
if transition.consume_char and transition.matches(c_previous, c_next):
136+
for next_state in next_states:
137+
new_head = head.apply_transition(transition, next_state)
138+
new_heads.add(new_head)
139+
140+
return new_heads
141+
142+
def iter_matches_from_heads(self, heads: Iterable[Head], text: str) -> Iterator[Match]:
143+
for head in sorted(heads, key=lambda h: (h.start, -h.position)):
144+
if head.state == self.final_state:
145+
yield Match.from_span_and_text(head.start, head.position, text)

0 commit comments

Comments
 (0)