diff --git a/itn/chinese/inverse_normalizer.py b/itn/chinese/inverse_normalizer.py index 3861f4a..41bd4a2 100644 --- a/itn/chinese/inverse_normalizer.py +++ b/itn/chinese/inverse_normalizer.py @@ -46,39 +46,37 @@ def __init__(self, self.build_fst('zh_itn', cache_dir, overwrite_cache) def build_tagger(self): - tagger = ( - add_weight(Date().tagger, 1.02) - | add_weight(Whitelist().tagger, 1.01) - | add_weight(Fraction().tagger, 1.05) - | add_weight( - Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05) # noqa - | add_weight(Money(enable_0_to_9=self.enable_0_to_9).tagger, - 1.04) # noqa - | add_weight(Time().tagger, 1.05) - | add_weight( - Cardinal(self.convert_number, self.enable_0_to_9, - self.enable_million).tagger, 1.06) # noqa - | add_weight(Math().tagger, 1.10) - | add_weight(LicensePlate().tagger, 1.0) - | add_weight(Char().tagger, 100)).optimize() + tagger = (add_weight(Date().tagger, 1.02) + | add_weight(Whitelist().tagger, 1.01) + | add_weight(Fraction().tagger, 1.05) + | add_weight( + Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05) + | add_weight( + Money(enable_0_to_9=self.enable_0_to_9).tagger, 1.04) + | add_weight(Time().tagger, 1.05) + | add_weight( + Cardinal(self.convert_number, self.enable_0_to_9, + self.enable_million).tagger, 1.06) + | add_weight(Math().tagger, 1.10) + | add_weight(LicensePlate().tagger, 1.0) + | add_weight(Char().tagger, 100)).optimize() tagger = tagger.star # remove the last space self.tagger = tagger @ self.build_rule(delete(' '), '', '[EOS]') def build_verbalizer(self): - verbalizer = ( - Cardinal(self.convert_number, self.enable_0_to_9, - self.enable_million).verbalizer # noqa - | Char().verbalizer - | Date().verbalizer - | Fraction().verbalizer - | Math().verbalizer - | Measure(enable_0_to_9=self.enable_0_to_9).verbalizer - | Money(enable_0_to_9=self.enable_0_to_9).verbalizer - | Time().verbalizer - | LicensePlate().verbalizer - | Whitelist().verbalizer).optimize() + verbalizer = (Cardinal(self.convert_number, self.enable_0_to_9, + self.enable_million).verbalizer + | Char().verbalizer + | Date().verbalizer + | Fraction().verbalizer + | Math().verbalizer + | Measure(enable_0_to_9=self.enable_0_to_9).verbalizer + | Money(enable_0_to_9=self.enable_0_to_9).verbalizer + | Time().verbalizer + | LicensePlate().verbalizer + | Whitelist().verbalizer).optimize() postprocessor = PostProcessor(remove_interjections=True).processor self.verbalizer = (verbalizer @ postprocessor).star diff --git a/itn/chinese/rules/cardinal.py b/itn/chinese/rules/cardinal.py index 85cdab9..dbcd491 100644 --- a/itn/chinese/rules/cardinal.py +++ b/itn/chinese/rules/cardinal.py @@ -39,9 +39,9 @@ def build_tagger(self): sign = string_file('itn/chinese/data/number/sign.tsv') # + - dot = string_file('itn/chinese/data/number/dot.tsv') # . + # 0. 基础数字 addzero = insert('0') digits = zero | digit # 0 ~ 9 - # 十一 => 11, 十二 => 12 teen = cross('十', '1') + (digit | add_weight(addzero, 0.1)) # 一十一 => 11, 二十一 => 21, 三十 => 30 @@ -81,6 +81,8 @@ def build_tagger(self): | add_weight(addzero**4, 1.0))) ten_thousand |= (thousand | hundred) + accep("万") + delete( "零").ques + (thousand | hundred | tens | teen | digits).ques + + # 1. 利用基础数字所构建的包含0~9的完整数字 # 个/十/百/千/万 number = digits | teen | tens | hundred | thousand | ten_thousand # 兆/亿 @@ -106,6 +108,7 @@ def build_tagger(self): self.special_2number = special_2number.optimize() self.special_3number = special_3number.optimize() + # 2. 利用基础数字所构建的不包含0~9的完整数字 # 十/百/千/万 number_exclude_0_to_9 = teen | tens | hundred | thousand | ten_thousand # 兆/亿 @@ -124,8 +127,9 @@ def build_tagger(self): number_exclude_0_to_9 |= add_weight(special_3number, -100.0) self.number_exclude_0_to_9 = (sign.ques + - number_exclude_0_to_9).optimize() # noqa + number_exclude_0_to_9).optimize() + # 3. 特殊格式的数字 # cardinal string like 127.0.0.1, used in ID, IP, etc. cardinal = digits.plus + (dot + digits.plus).plus # float number like 1.11 @@ -134,6 +138,8 @@ def build_tagger(self): # 340621199806051223, used in ID card cardinal |= (digits**3 | digits**4 | digits**5 | digits**11 | digits**18) + + # 4. 特殊格式的数字 + 包含或不包含0~9的完整数字 # cardinal string like 23 if self.enable_standalone_number: if self.enable_0_to_9: diff --git a/itn/main.py b/itn/main.py index a188938..d14ee8f 100644 --- a/itn/main.py +++ b/itn/main.py @@ -42,11 +42,11 @@ def main(): parser.add_argument('--enable_standalone_number', type=str, default='True', - help='enable standalone number') + help='一百 = 100 if True else 一百') parser.add_argument('--enable_0_to_9', type=str, default='False', - help='enable convert number 0 to 9') + help='零和九 = 0和9 if True else 零和九') parser.add_argument('--enable_million', type=str, default='False',