Skip to content

Commit 4f7718c

Browse files
committed
fix(splitter): merge across newline (\n) on section stage
1 parent b94cf69 commit 4f7718c

File tree

3 files changed

+155
-19
lines changed

3 files changed

+155
-19
lines changed

split_lang/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class LangSectionType(Enum):
88
ZH_JA = "zh_ja"
99
KO = "ko"
1010
PUNCTUATION = "punctuation"
11+
NEWLINE = "newline"
1112
DIGIT = "digit"
1213
OTHERS = "others"
1314
ALL = "all"

split_lang/split/splitter.py

Lines changed: 129 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ def __init__(
2525
default_lang: str = DEFAULT_LANG,
2626
punctuation: str = PUNCTUATION,
2727
not_merge_punctuation: str = "",
28-
special_merge_for_zh_ja: bool = True,
2928
merge_across_punctuation: bool = True,
30-
merge_across_newline: bool = True,
3129
merge_across_digit: bool = True,
30+
merge_across_newline: bool = True,
3231
debug: bool = True,
3332
log_level: int = logging.INFO,
3433
) -> None:
@@ -51,10 +50,9 @@ def __init__(
5150
self.debug = debug
5251
self.punctuation = punctuation
5352
self.not_merge_punctuation = not_merge_punctuation
54-
self.merge_across_newline = merge_across_newline
55-
self.special_merge_for_zh_ja = special_merge_for_zh_ja
5653
self.merge_across_punctuation = merge_across_punctuation
5754
self.merge_across_digit = merge_across_digit
55+
self.merge_across_newline = merge_across_newline
5856
self.log_level = log_level
5957
logging.basicConfig(
6058
level=self.log_level,
@@ -79,17 +77,18 @@ def split_by_lang(
7977
sections = self._split(pre_split_section=pre_split_section)
8078

8179
if self.merge_across_punctuation: # 合并跨标点符号的 SubString
82-
after_merge_punctuation_sections = (
83-
self._merge_substrings_across_punctuation_based_on_sections(
84-
sections=sections
85-
)
80+
sections = self._merge_substrings_across_punctuation_based_on_sections(
81+
sections=sections
8682
)
8783

8884
if self.merge_across_digit: # 合并跨数字的 SubString
89-
after_merge_digit_sections = (
90-
self._merge_substrings_across_digit_based_on_sections(
91-
sections=after_merge_punctuation_sections
92-
)
85+
sections = self._merge_substrings_across_digit_based_on_sections(
86+
sections=sections
87+
)
88+
89+
if self.merge_across_newline:
90+
sections = self._merge_substrings_across_newline_based_on_sections(
91+
sections=sections
9392
)
9493

9594
substrings: List[SubString] = []
@@ -148,8 +147,12 @@ def add_substring(lang_section_type: LangSectionType):
148147
current_lang = LangSectionType.PUNCTUATION
149148
elif char.isspace():
150149
# concat space to current text
151-
add_substring(current_lang)
152-
current_lang = LangSectionType.PUNCTUATION
150+
if char == "\n":
151+
add_substring(current_lang)
152+
current_lang = LangSectionType.NEWLINE
153+
else:
154+
add_substring(current_lang)
155+
current_lang = LangSectionType.PUNCTUATION
153156
else:
154157
if current_lang != LangSectionType.OTHERS:
155158
add_substring(current_lang)
@@ -193,6 +196,16 @@ def _split(
193196
length=section_len,
194197
)
195198
)
199+
elif section.lang_section_type is LangSectionType.NEWLINE:
200+
# NOTE: 换行作为单独的 SubString
201+
section.substrings.append(
202+
SubString(
203+
text=section.text,
204+
lang="newline",
205+
index=section_index,
206+
length=section_len,
207+
)
208+
)
196209
else:
197210
substrings_with_lang: List[SubString] = []
198211
if section.lang_section_type is LangSectionType.ZH_JA:
@@ -212,6 +225,7 @@ def _split(
212225
length=section_len,
213226
)
214227
]
228+
215229
else:
216230
temp_substrings = self._parse_without_zh_ja(section.text)
217231
substrings_with_lang = self._init_substr_lang(
@@ -232,7 +246,10 @@ def _split(
232246

233247
# MARK: smart merge substring together
234248
for section in pre_split_section:
235-
if section.lang_section_type is LangSectionType.PUNCTUATION:
249+
if (
250+
section.lang_section_type is LangSectionType.PUNCTUATION
251+
or section.lang_section_type is LangSectionType.NEWLINE
252+
):
236253
# print(section.text)
237254
continue
238255
smart_concat_result = self._smart_merge(
@@ -660,6 +677,98 @@ def _special_merge_for_zh_ja(
660677
new_substrings = self._merge_substrings(substrings=new_substrings)
661678
return new_substrings
662679

680+
def _merge_substrings_across_newline_based_on_sections(
681+
self,
682+
sections: List[SubStringSection],
683+
) -> List[SubStringSection]:
684+
new_sections: List[SubStringSection] = [sections[0]]
685+
# NOTE: 将 sections 中的 newline 合并到临近的非 punctuation 的 section
686+
for index, _ in enumerate(sections):
687+
if index == 0:
688+
continue
689+
if index >= len(sections):
690+
break
691+
692+
prev_section = new_sections[-1]
693+
current_section = sections[index]
694+
if (
695+
current_section.lang_section_type != LangSectionType.PUNCTUATION
696+
and prev_section.lang_section_type == LangSectionType.NEWLINE
697+
):
698+
# NOTE: 如果前一个 section 是 newline,则合并
699+
prev_section.lang_section_type = current_section.lang_section_type
700+
prev_section.text += current_section.text
701+
prev_section.substrings.extend(current_section.substrings)
702+
for index, substr in enumerate(prev_section.substrings):
703+
if index == 0:
704+
continue
705+
else:
706+
substr.index = (
707+
new_sections[-1].substrings[index - 1].index
708+
+ new_sections[-1].substrings[index - 1].length
709+
)
710+
711+
elif (
712+
current_section.lang_section_type == LangSectionType.NEWLINE
713+
and prev_section.lang_section_type != LangSectionType.PUNCTUATION
714+
):
715+
# NOTE: 如果前一个 section 不是 punctuation,则合并
716+
prev_section.text += current_section.text
717+
prev_section.substrings.extend(current_section.substrings)
718+
prev_section.substrings[-1].index = (
719+
prev_section.substrings[-2].index
720+
+ prev_section.substrings[-2].length
721+
)
722+
else:
723+
new_sections.append(current_section)
724+
# NOTE: 将相同类型的 section 合并
725+
new_sections_merged: List[SubStringSection] = [new_sections[0]]
726+
for index, _ in enumerate(new_sections):
727+
if index == 0:
728+
continue
729+
if (
730+
new_sections_merged[-1].lang_section_type
731+
== new_sections[index].lang_section_type
732+
):
733+
new_sections_merged[-1].text += new_sections[index].text
734+
new_sections_merged[-1].substrings.extend(
735+
new_sections[index].substrings
736+
)
737+
else:
738+
new_sections_merged.append(new_sections[index])
739+
# NOTE: 重新计算 index
740+
for section_index, section in enumerate(new_sections_merged):
741+
if section_index == 0:
742+
for substr_index, substr in enumerate(section.substrings):
743+
if substr_index == 0:
744+
continue
745+
else:
746+
substr.index = (
747+
section.substrings[substr_index - 1].index
748+
+ section.substrings[substr_index - 1].length
749+
)
750+
else:
751+
for substr_index, substr in enumerate(section.substrings):
752+
if substr_index == 0:
753+
substr.index = (
754+
new_sections_merged[section_index - 1].substrings[-1].index
755+
+ new_sections_merged[section_index - 1]
756+
.substrings[-1]
757+
.length
758+
)
759+
else:
760+
substr.index = (
761+
section.substrings[substr_index - 1].index
762+
+ section.substrings[substr_index - 1].length
763+
)
764+
if self.debug:
765+
logger.debug(
766+
"---------------------------------after_merge_newline_sections:"
767+
)
768+
for section in new_sections_merged:
769+
logger.debug(section)
770+
return new_sections_merged
771+
663772
# MARK: _merge_substrings_across_digit_based_on_sections
664773
def _merge_substrings_across_digit_based_on_sections(
665774
self,
@@ -799,7 +908,11 @@ def _merge_substrings_across_punctuation_based_on_sections(
799908
# 如果前一个 section 和当前的 section 类型不同,且其中一个是 punctuation,则合并
800909
if current_section.lang_section_type != prev_section.lang_section_type:
801910
# 如果前一个 section 是 punctuation,且第一个元素不是 not_merge_punctuation,则合并
802-
if prev_section.lang_section_type == LangSectionType.PUNCTUATION:
911+
if (
912+
prev_section.lang_section_type == LangSectionType.PUNCTUATION
913+
and prev_section.substrings[0].text
914+
not in self.not_merge_punctuation
915+
):
803916
# 将前一个 punctuation section 和当前的 section 合并
804917
prev_section.text += current_section.text
805918
prev_section.lang_section_type = current_section.lang_section_type

tests/test_split.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
from split_lang import LangSplitter
24

35
texts = [
@@ -113,10 +115,10 @@
113115
そして、この先、私達3人の関係は壊れていくことにこの時は気づかなかった…""",
114116
]
115117

116-
lang_splitter = LangSplitter()
118+
lang_splitter = LangSplitter(log_level=logging.DEBUG)
117119

118120

119-
def test_split():
121+
def test_split_step_by_step():
120122
for text in texts:
121123
pre_split_sections = lang_splitter.pre_split(
122124
text=text,
@@ -143,12 +145,32 @@ def test_split():
143145
sections=after_merge_punctuation_sections,
144146
)
145147
)
148+
146149
# for section in after_merge_digit_sections:
147150
# print(section)
148151

152+
after_merge_newline_sections = (
153+
lang_splitter._merge_substrings_across_newline_based_on_sections(
154+
sections=after_merge_digit_sections,
155+
)
156+
)
157+
# for section in after_merge_newline_sections:
158+
# print(section)
159+
160+
161+
def test_split():
162+
print("===========test_split===========")
163+
lang_splitter.merge_across_punctuation = True
164+
lang_splitter.not_merge_punctuation = ["。"]
165+
for text in texts:
166+
substrings = lang_splitter.split_by_lang(text=text)
167+
for substr in substrings:
168+
print(substr)
169+
149170

150171
def main():
151-
test_split()
172+
test_split_step_by_step()
173+
# test_split()
152174
pass
153175

154176

0 commit comments

Comments
 (0)