Skip to content

Commit 310941c

Browse files
further edits to chunking logic
1 parent 4cec811 commit 310941c

File tree

2 files changed

+2382
-1052
lines changed

2 files changed

+2382
-1052
lines changed

parser/chunking.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,22 @@
4646
SCHEME_PREFIXES = ("SCHEME",)
4747
CODE_BLOCK_TAGS = {"SNIPPET", "PROGRAMLISTING", "CODE", "DISPLAY"}
4848
INLINE_CODE_TAGS = {"JAVASCRIPTINLINE"}
49+
SENTENCE_END_CHARS = {
50+
".",
51+
"?",
52+
"!",
53+
"\"",
54+
"'",
55+
"\u00bb",
56+
"\u201d",
57+
"\u2019",
58+
":",
59+
";",
60+
")",
61+
"]",
62+
"}",
63+
"`",
64+
}
4965

5066

5167
def num_tokens(text: str) -> int:
@@ -95,10 +111,23 @@ def normalize_code(text: str) -> str:
95111
return "\n".join(lines)
96112

97113

114+
def _append_tail(parent: ET.Element, siblings: List[ET.Element], idx: int, tail: Optional[str]) -> None:
115+
if not tail:
116+
return
117+
for prev_idx in range(idx - 1, -1, -1):
118+
prev = siblings[prev_idx]
119+
if prev in parent:
120+
prev.tail = (prev.tail or "") + tail
121+
return
122+
parent.text = (parent.text or "") + tail
123+
124+
98125
def prune_tree(node: ET.Element) -> None:
99-
for child in list(node):
126+
children = list(node)
127+
for idx, child in enumerate(children):
100128
tag = child.tag.upper()
101129
if tag in DROP_TAGS or tag.startswith(SCHEME_PREFIXES):
130+
_append_tail(node, children, idx, child.tail)
102131
node.remove(child)
103132
continue
104133
prune_tree(child)
@@ -111,8 +140,6 @@ def visit(el: ET.Element) -> None:
111140
tag = el.tag.upper()
112141
if tag in DROP_TAGS or tag.startswith(SCHEME_PREFIXES):
113142
return
114-
if tag == "JAVASCRIPT_OUTPUT":
115-
return
116143
if el.text:
117144
parts.append(el.text)
118145
for child in el:
@@ -144,8 +171,9 @@ def flush_text() -> None:
144171
segments.append({"type": "text", "content": text})
145172
buffer.clear()
146173

147-
def walk(el: ET.Element) -> None:
174+
def walk(el: ET.Element, parent_tag: Optional[str] = None) -> None:
148175
tag = el.tag.upper()
176+
parent_upper = parent_tag.upper() if parent_tag else None
149177

150178
if tag in DROP_TAGS or tag.startswith(SCHEME_PREFIXES):
151179
return
@@ -173,6 +201,12 @@ def walk(el: ET.Element) -> None:
173201
snippet = gather_code(el)
174202
if not snippet:
175203
return
204+
inline_context = parent_upper not in CODE_BLOCK_TAGS
205+
if inline_context:
206+
inline_snippet = normalize_whitespace(snippet.replace("\n", " "))
207+
if inline_snippet:
208+
append_text(inline_snippet)
209+
return
176210
if "\n" in snippet or re.search(r"[;{}=]", snippet):
177211
flush_text()
178212
segments.append({"type": "code", "content": snippet})
@@ -184,15 +218,15 @@ def walk(el: ET.Element) -> None:
184218
append_text(el.text)
185219

186220
for child in el:
187-
walk(child)
221+
walk(child, tag)
188222
if child.tail:
189223
append_text(child.tail)
190224

191225
if node.text:
192226
append_text(node.text)
193227

194228
for child in node:
195-
walk(child)
229+
walk(child, node.tag)
196230
if child.tail:
197231
append_text(child.tail)
198232

@@ -349,6 +383,13 @@ def join_units(units: List[str], types: List[str]) -> str:
349383
return out.strip()
350384

351385

386+
def has_sentence_ending(text: str) -> bool:
387+
stripped = text.rstrip()
388+
if not stripped:
389+
return True
390+
return stripped[-1] in SENTENCE_END_CHARS
391+
392+
352393
def chunk_units(
353394
units: List[Dict[str, str]],
354395
meta: Dict[str, Optional[str]],
@@ -408,14 +449,23 @@ def flush() -> None:
408449
if carry_pending and not buffer:
409450
seed_carry()
410451

411-
if current_tokens + unit_tokens > MAX_TOKENS and buffer:
452+
overflow = current_tokens + unit_tokens > MAX_TOKENS and buffer
453+
allow_overflow = (
454+
overflow
455+
and ttype == "code"
456+
and types
457+
and types[-1] == "text"
458+
and not has_sentence_ending(buffer[-1])
459+
)
460+
461+
if overflow and not allow_overflow:
412462
flush()
413463
if carry_pending and not buffer:
414464
seed_carry()
415-
416-
buffer.append(text)
417-
types.append(ttype)
418-
current_tokens += unit_tokens
465+
else:
466+
buffer.append(text)
467+
types.append(ttype)
468+
current_tokens = current_tokens + unit_tokens
419469

420470
flush()
421471
return chunks

0 commit comments

Comments
 (0)