@@ -25,10 +25,9 @@ def __init__(
25
25
default_lang : str = DEFAULT_LANG ,
26
26
punctuation : str = PUNCTUATION ,
27
27
not_merge_punctuation : str = "" ,
28
- special_merge_for_zh_ja : bool = True ,
29
28
merge_across_punctuation : bool = True ,
30
- merge_across_newline : bool = True ,
31
29
merge_across_digit : bool = True ,
30
+ merge_across_newline : bool = True ,
32
31
debug : bool = True ,
33
32
log_level : int = logging .INFO ,
34
33
) -> None :
@@ -51,10 +50,9 @@ def __init__(
51
50
self .debug = debug
52
51
self .punctuation = punctuation
53
52
self .not_merge_punctuation = not_merge_punctuation
54
- self .merge_across_newline = merge_across_newline
55
- self .special_merge_for_zh_ja = special_merge_for_zh_ja
56
53
self .merge_across_punctuation = merge_across_punctuation
57
54
self .merge_across_digit = merge_across_digit
55
+ self .merge_across_newline = merge_across_newline
58
56
self .log_level = log_level
59
57
logging .basicConfig (
60
58
level = self .log_level ,
@@ -79,17 +77,18 @@ def split_by_lang(
79
77
sections = self ._split (pre_split_section = pre_split_section )
80
78
81
79
if self .merge_across_punctuation : # 合并跨标点符号的 SubString
82
- after_merge_punctuation_sections = (
83
- self ._merge_substrings_across_punctuation_based_on_sections (
84
- sections = sections
85
- )
80
+ sections = self ._merge_substrings_across_punctuation_based_on_sections (
81
+ sections = sections
86
82
)
87
83
88
84
if self .merge_across_digit : # 合并跨数字的 SubString
89
- after_merge_digit_sections = (
90
- self ._merge_substrings_across_digit_based_on_sections (
91
- sections = after_merge_punctuation_sections
92
- )
85
+ sections = self ._merge_substrings_across_digit_based_on_sections (
86
+ sections = sections
87
+ )
88
+
89
+ if self .merge_across_newline :
90
+ sections = self ._merge_substrings_across_newline_based_on_sections (
91
+ sections = sections
93
92
)
94
93
95
94
substrings : List [SubString ] = []
@@ -148,8 +147,12 @@ def add_substring(lang_section_type: LangSectionType):
148
147
current_lang = LangSectionType .PUNCTUATION
149
148
elif char .isspace ():
150
149
# concat space to current text
151
- add_substring (current_lang )
152
- current_lang = LangSectionType .PUNCTUATION
150
+ if char == "\n " :
151
+ add_substring (current_lang )
152
+ current_lang = LangSectionType .NEWLINE
153
+ else :
154
+ add_substring (current_lang )
155
+ current_lang = LangSectionType .PUNCTUATION
153
156
else :
154
157
if current_lang != LangSectionType .OTHERS :
155
158
add_substring (current_lang )
@@ -193,6 +196,16 @@ def _split(
193
196
length = section_len ,
194
197
)
195
198
)
199
+ elif section .lang_section_type is LangSectionType .NEWLINE :
200
+ # NOTE: 换行作为单独的 SubString
201
+ section .substrings .append (
202
+ SubString (
203
+ text = section .text ,
204
+ lang = "newline" ,
205
+ index = section_index ,
206
+ length = section_len ,
207
+ )
208
+ )
196
209
else :
197
210
substrings_with_lang : List [SubString ] = []
198
211
if section .lang_section_type is LangSectionType .ZH_JA :
@@ -212,6 +225,7 @@ def _split(
212
225
length = section_len ,
213
226
)
214
227
]
228
+
215
229
else :
216
230
temp_substrings = self ._parse_without_zh_ja (section .text )
217
231
substrings_with_lang = self ._init_substr_lang (
@@ -232,7 +246,10 @@ def _split(
232
246
233
247
# MARK: smart merge substring together
234
248
for section in pre_split_section :
235
- if section .lang_section_type is LangSectionType .PUNCTUATION :
249
+ if (
250
+ section .lang_section_type is LangSectionType .PUNCTUATION
251
+ or section .lang_section_type is LangSectionType .NEWLINE
252
+ ):
236
253
# print(section.text)
237
254
continue
238
255
smart_concat_result = self ._smart_merge (
@@ -660,6 +677,98 @@ def _special_merge_for_zh_ja(
660
677
new_substrings = self ._merge_substrings (substrings = new_substrings )
661
678
return new_substrings
662
679
680
def _merge_substrings_across_newline_based_on_sections(
    self,
    sections: List[SubStringSection],
) -> List[SubStringSection]:
    """Fold newline sections into their neighbouring non-punctuation sections.

    The work is done in three passes:

    1. A NEWLINE section is absorbed by the section that follows it (the
       merged section takes the follower's type) or, failing that, by the
       section that precedes it, as long as the neighbour is not a
       PUNCTUATION section.
    2. Adjacent sections that now share the same ``lang_section_type`` are
       concatenated.
    3. Every substring index is recomputed so that each substring starts
       where the previous one ends; the very first substring keeps the
       index it came in with.

    The section objects (and their substrings) passed in are mutated in
    place and reused in the returned list.

    Args:
        sections: Sections in text order. May be empty.

    Returns:
        The merged sections, or an empty list when ``sections`` is empty.
    """
    if not sections:
        # Guard: indexing sections[0] on an empty list would raise IndexError.
        return []

    newline = LangSectionType.NEWLINE
    punctuation = LangSectionType.PUNCTUATION

    # Pass 1: attach newline sections to an adjacent non-punctuation section.
    absorbed: List[SubStringSection] = [sections[0]]
    for current in sections[1:]:
        previous = absorbed[-1]
        if (
            current.lang_section_type != punctuation
            and previous.lang_section_type == newline
        ):
            # The pending section is a newline: glue the current section
            # onto it and let the result adopt the current section's type.
            previous.lang_section_type = current.lang_section_type
            previous.text += current.text
            previous.substrings.extend(current.substrings)
        elif (
            current.lang_section_type == newline
            and previous.lang_section_type != punctuation
        ):
            # A newline following a non-punctuation section is appended to
            # it; the section keeps its own type.
            previous.text += current.text
            previous.substrings.extend(current.substrings)
        else:
            absorbed.append(current)

    # Pass 2: collapse neighbours that ended up with the same section type.
    merged: List[SubStringSection] = [absorbed[0]]
    for candidate in absorbed[1:]:
        tail = merged[-1]
        if tail.lang_section_type == candidate.lang_section_type:
            tail.text += candidate.text
            tail.substrings.extend(candidate.substrings)
        else:
            merged.append(candidate)

    # Pass 3: rebuild substring indices as one running chain across all
    # sections. Doing it once here makes per-merge index fix-ups unnecessary
    # and tolerates sections that carry no substrings.
    last_substring = None
    for section in merged:
        for substring in section.substrings:
            if last_substring is not None:
                substring.index = last_substring.index + last_substring.length
            last_substring = substring

    if self.debug:
        logger.debug(
            "---------------------------------after_merge_newline_sections:"
        )
        for section in merged:
            logger.debug(section)
    return merged
771
+
663
772
# MARK: _merge_substrings_across_digit_based_on_sections
664
773
def _merge_substrings_across_digit_based_on_sections (
665
774
self ,
@@ -799,7 +908,11 @@ def _merge_substrings_across_punctuation_based_on_sections(
799
908
# 如果前一个 section 和当前的 section 类型不同,且其中一个是 punctuation,则合并
800
909
if current_section .lang_section_type != prev_section .lang_section_type :
801
910
# 如果前一个 section 是 punctuation,且第一个元素不是 not_merge_punctuation,则合并
802
- if prev_section .lang_section_type == LangSectionType .PUNCTUATION :
911
+ if (
912
+ prev_section .lang_section_type == LangSectionType .PUNCTUATION
913
+ and prev_section .substrings [0 ].text
914
+ not in self .not_merge_punctuation
915
+ ):
803
916
# 将前一个 punctuation section 和当前的 section 合并
804
917
prev_section .text += current_section .text
805
918
prev_section .lang_section_type = current_section .lang_section_type
0 commit comments