diff --git a/.github/workflows/manaul.yml b/.github/workflows/manaul.yml
index 9ade94a..394b4e3 100644
--- a/.github/workflows/manaul.yml
+++ b/.github/workflows/manaul.yml
@@ -40,7 +40,7 @@ jobs:
       - name: Install Model
         shell: cmd
         run: |
-          .\resource\aria2c.exe https://github.com/neavo/KeywordGachaModel/releases/download/kg_ner_20240819/kg_ner_cpu.zip -o kg_ner_cpu.zip
+          .\resource\aria2c.exe https://github.com/neavo/KeywordGachaModel/releases/download/kg_ner_20240826/kg_ner_cpu.zip -o kg_ner_cpu.zip
           powershell -Command "Expand-Archive -Path 'kg_ner_cpu.zip' -DestinationPath 'dist\KeywordGacha\resource\kg_ner_cpu'"
           powershell -Command "Remove-Item -Path 'kg_ner_cpu.zip' -Recurse -Force -ErrorAction SilentlyContinue"
@@ -71,7 +71,7 @@ jobs:
           .\dist\KeywordGacha\env\python.exe -m pip install torch --index-url https://download.pytorch.org/whl/cu121
           .\dist\KeywordGacha\env\python.exe -m pip cache purge
-          .\resource\aria2c.exe https://github.com/neavo/KeywordGachaModel/releases/download/kg_ner_20240819/kg_ner_gpu.zip -o kg_ner_gpu.zip
+          .\resource\aria2c.exe https://github.com/neavo/KeywordGachaModel/releases/download/kg_ner_20240826/kg_ner_gpu.zip -o kg_ner_gpu.zip
           powershell -Command "Expand-Archive -Path 'kg_ner_gpu.zip' -DestinationPath 'dist\KeywordGacha\resource\kg_ner_gpu'"
           powershell -Command "Remove-Item -Path 'kg_ner_gpu.zip' -Recurse -Force -ErrorAction SilentlyContinue"
           echo > .\dist\KeywordGacha\gpuboost.txt
diff --git a/.gitignore b/.gitignore
index 056952d..30c5d8a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,13 +1,14 @@
+__pycache__
+
 /env*
 /dist*
 /build*
 /input*
 /output*
-__pycache__
 /*log*
 /debug*
 /gpuboost*
-/words_dict*
+/words_all*
 /config_dev*
-/resource/kg_ner_*
\ No newline at end of file
+/resource/*ner_*
\ No newline at end of file
diff --git a/README.md b/README.md
index db3026b..994f0b2 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,13 @@
 - See [Wiki - Supported File Formats](https://github.com/neavo/KeywordGacha/wiki/%E6%94%AF%E6%8C%81%E7%9A%84%E6%96%87%E4%BB%B6%E6%A0%BC%E5%BC%8F) for details
 
 ## Recent Updates 📅
+- 20240826 v0.4.0
+  - Added - Initial support for `Korean`
+    - The author does not know `Korean` at all, so the quality cannot be assessed
+    - Users who do know `Korean` are invited to help test it
+  - Changed - Optimized the execution speed of the NER entity recognition step
+    - Both the `CPU` and `GPU` versions are roughly twice as fast
+
 - 20240820 v0.3.0
   - Changed - Updated the NER model to 20240819
   - Changed - Removed some no-longer-needed steps to save processing time
@@ -126,7 +133,7 @@
 - [X] Add support for other noun types such as organizations, items, and locations
 - [X] Add support for `English content`
 - [X] Add support for `Chinese content`
-- [ ] Add support for `Korean content`
+- [X] Add support for `Korean content`
 - [ ] Add support for `Russian content`
 - [X] Add support for GPU acceleration
 - [X] Add a fully automatic generation mode
diff --git a/helper/TestHelper.py b/helper/TestHelper.py
index f09cd2a..1ceeffd 100644
--- a/helper/TestHelper.py
+++ b/helper/TestHelper.py
@@ -1,17 +1,17 @@
 from helper.LogHelper import LogHelper
 from helper.TextHelper import TextHelper
 
+
 class TestHelper:
-    @staticmethod
-    def check_duplicates(*args):
-        a = {
-            "ブルースライム": "蓝史莱姆",
-            "ワイバーン": "双足飞龙",
-            "ミスリル": "秘银",
-            "ポーション": "药水",
-            "ハイポーション": "高级药水",
-            "魔付き": "魔附者",
+    x = set(
+        {
+            # "ブルースライム": "蓝史莱姆",
+            # "ワイバーン": "双足飞龙",
+            # "ミスリル": "秘银",
+            # "ポーション": "药水",
+            # "ハイポーション": "高级药水",
+            # "魔付き": "魔附者",
             "オルディネ": "奥迪涅",
             "イシュラナ": "伊修拉纳",
             "エリルキア": "艾利尔齐亚",
@@ -155,23 +155,51 @@ def check_duplicates(*args):
             "ノワルスール": "诺瓦尔苏尔",
             "ヴェントルジェント": "文特尔金特",
         }
+    )
 
-        b = {}
-
-        if len(a) == 0 or len(b) == 0:
-            return
+    @staticmethod
+    def check_score_threshold(words, path):
+        thresholds = [
+            0.50,
+            0.55,
+            0.60,
+            0.65,
+            0.70,
+            0.75,
+            0.80,
+            0.85,
+            0.90,
+            0.95,
+        ]
 
-        keys_a = set(a.keys())
-        keys_b = set(b.keys())
+        with open(path, "w", encoding="utf-8") as writer:
+            for threshold in thresholds:
+                y = {
+                    word.surface
+                    for word in words
+                    if word.ner_type == "PER" and word.score > threshold
+                }
 
-        LogHelper.print(f"第一个词典独有的键 - {len(keys_a - keys_b)}")
-        LogHelper.print(f"{keys_a - keys_b}")
-        LogHelper.print(f"")
+                writer.write(f"当置信度阈值设置为 {threshold:.4f} 时:\n")
+                writer.write(f"第一个词典独有的键 - {len(TestHelper.x - y)}\n")
+                writer.write(f"{TestHelper.x - y}\n")
+                writer.write(f"第二个词典独有的键 - {len(y - TestHelper.x)}\n")
+                writer.write(f"{y - TestHelper.x}\n")
+                writer.write(f"两个字典共有的键 - {len(TestHelper.x & y)}\n")
+                writer.write(f"{TestHelper.x & y}\n")
+                writer.write(f"\n")
+                writer.write(f"\n")
 
-        LogHelper.print(f"第二个词典独有的键 - {len(keys_b - keys_a)}")
-        LogHelper.print(f"{keys_b - keys_a}")
-        LogHelper.print(f"")
+    @staticmethod
+    def check_result_duplication(words, path):
+        with open(path, "w", encoding="utf-8") as writer:
+            y = {word.surface for word in words if word.ner_type == "PER"}
 
-        LogHelper.print(f"两个字典共有的键 - {len(keys_a & keys_b)}")
-        LogHelper.print(f"{keys_a & keys_b}")
-        LogHelper.print(f"")
\ No newline at end of file
+            writer.write(f"第一个词典独有的键 - {len(TestHelper.x - y)}\n")
+            writer.write(f"{TestHelper.x - y}\n")
+            writer.write(f"第二个词典独有的键 - {len(y - TestHelper.x)}\n")
+            writer.write(f"{y - TestHelper.x}\n")
+            writer.write(f"两个字典共有的键 - {len(TestHelper.x & y)}\n")
+            writer.write(f"{TestHelper.x & y}\n")
+            writer.write(f"\n")
+            writer.write(f"\n")
\ No newline at end of file
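`TestHelper` now keeps the reference glossary as a class-level set of keys (`x`) and writes its reports to files instead of the console: `check_score_threshold` re-runs the set comparison at ten confidence thresholds so the best `PER` cutoff can be read off a single log. A self-contained sketch of that sweep, using a hypothetical three-name reference set in place of the full glossary:

```python
from dataclasses import dataclass

@dataclass
class Word:
    surface: str
    ner_type: str
    score: float

# Hypothetical stand-in for TestHelper.x; the real set holds the full name glossary.
REFERENCE = {"オルディネ", "イシュラナ", "エリルキア"}

def sweep_thresholds(words, reference):
    # Same comparison as check_score_threshold: at each threshold, report the
    # reference-only keys, the prediction-only keys, and the intersection.
    for threshold in [t / 100 for t in range(50, 100, 5)]:
        predicted = {w.surface for w in words if w.ner_type == "PER" and w.score > threshold}
        print(
            f"threshold={threshold:.2f} "
            f"missed={len(reference - predicted)} "
            f"extra={len(predicted - reference)} "
            f"shared={len(reference & predicted)}"
        )

sweep_thresholds(
    [
        Word("オルディネ", "PER", 0.92),
        Word("イシュラナ", "PER", 0.78),
        Word("ミスリル", "MISC", 0.88),
    ],
    REFERENCE,
)
```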
diff --git a/helper/TextHelper.py b/helper/TextHelper.py
index 4e960ce..f819e11 100644
--- a/helper/TextHelper.py
+++ b/helper/TextHelper.py
@@ -14,6 +14,21 @@ class TextHelper:
     # Voiced and semi-voiced sound marks
     VOICED_SOUND_MARKS = ("\u309B", "\u309C")
 
+    # Hangul Jamo
+    HANGUL_JAMO = ("\u1100", "\u11FF")
+
+    # Hangul Jamo Extended-A
+    HANGUL_JAMO_EXTENDED_A = ("\uA960", "\uA97F")
+
+    # Hangul Jamo Extended-B
+    HANGUL_JAMO_EXTENDED_B = ("\uD7B0", "\uD7FF")
+
+    # Hangul Syllables
+    HANGUL_SYLLABLES = ("\uAC00", "\uD7AF")
+
+    # Hangul Compatibility Jamo
+    HANGUL_COMPATIBILITY_JAMO = ("\u3130", "\u318F")
+
     # CJK Unified Ideographs
     CJK = ("\u4E00", "\u9FFF")
 
@@ -247,4 +262,39 @@ def strip_not_latin(text):
         while text and not TextHelper.is_latin(text[-1]):
             text = text[:-1]
 
-        return text.strip()
\ No newline at end of file
+        return text.strip()
+
+    # Checks whether a character is a Korean character
+    @staticmethod
+    def is_korean(ch):
+        return (
+            TextHelper.CJK[0] <= ch <= TextHelper.CJK[1]
+            or TextHelper.HANGUL_JAMO[0] <= ch <= TextHelper.HANGUL_JAMO[1]
+            or TextHelper.HANGUL_JAMO_EXTENDED_A[0] <= ch <= TextHelper.HANGUL_JAMO_EXTENDED_A[1]
+            or TextHelper.HANGUL_JAMO_EXTENDED_B[0] <= ch <= TextHelper.HANGUL_JAMO_EXTENDED_B[1]
+            or TextHelper.HANGUL_SYLLABLES[0] <= ch <= TextHelper.HANGUL_SYLLABLES[1]
+            or TextHelper.HANGUL_COMPATIBILITY_JAMO[0] <= ch <= TextHelper.HANGUL_COMPATIBILITY_JAMO[1]
+        )
+
+    # Checks whether the input string consists entirely of Korean characters
+    @staticmethod
+    def is_all_korean(text):
+        return all(TextHelper.is_korean(ch) for ch in text)
+
+    # Checks whether the string contains at least one Korean character
+    @staticmethod
+    def has_any_korean(text):
+        return any(TextHelper.is_korean(ch) for ch in text)
+
+    # Strips non-Korean characters from both ends of the string
+    @staticmethod
+    def strip_not_korean(text):
+        text = text.strip()
+
+        while text and not TextHelper.is_korean(text[0]):
+            text = text[1:]
+
+        while text and not TextHelper.is_korean(text[-1]):
+            text = text[:-1]
+
+        return text.strip()
\ No newline at end of file
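The Korean detection is plain Unicode-range arithmetic: a character counts as Korean when it falls in one of the five Hangul blocks, or in the CJK Unified Ideographs block, so Hanja passes the test as well. A standalone sketch of the same check:

```python
# The five Hangul blocks plus CJK Unified Ideographs, mirroring
# the ranges added to TextHelper (inclusive bounds).
HANGUL_RANGES = [
    ("\u1100", "\u11FF"),  # Hangul Jamo
    ("\uA960", "\uA97F"),  # Hangul Jamo Extended-A
    ("\uD7B0", "\uD7FF"),  # Hangul Jamo Extended-B
    ("\uAC00", "\uD7AF"),  # Hangul Syllables
    ("\u3130", "\u318F"),  # Hangul Compatibility Jamo
    ("\u4E00", "\u9FFF"),  # CJK Unified Ideographs (Hanja)
]

def is_korean(ch: str) -> bool:
    return any(lo <= ch <= hi for lo, hi in HANGUL_RANGES)

assert all(is_korean(ch) for ch in "안녕하세요")  # Hangul syllables
assert is_korean("韓")                            # Hanja is accepted too
assert not is_korean("か")                        # kana is not
```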
diff --git a/image/01.jpg b/image/01.jpg
index 8ece131..b29fa46 100644
Binary files a/image/01.jpg and b/image/01.jpg differ
diff --git a/main.py b/main.py
index 09dcabf..51cab85 100644
--- a/main.py
+++ b/main.py
@@ -188,13 +188,16 @@ def read_input_file(language):
         if language == NER.LANGUAGE.JP and not TextHelper.has_any_japanese(line):
             continue
 
+        if language == NER.LANGUAGE.KR and not TextHelper.has_any_korean(line):
+            continue
+
         input_lines_filtered.append(line.strip())
 
     LogHelper.info(f"已读取到文本 {len(input_lines)} 行,其中有效文本 {len(input_lines_filtered)} 行, 角色名 {len(input_names)} 个...")
     return input_lines_filtered, input_names
 
 # Merge, count, and filter by confidence
-def merge_and_count(words, full_lines, language):
+def merge_and_count(words, full_lines, language, x = None, y = None):
     words_unique = {}
     for v in words:
         if (v.surface, v.ner_type) not in words_unique:
@@ -202,19 +205,22 @@ def merge_and_count(words, full_lines, language):
         words_unique[(v.surface, v.ner_type)].append(v)
 
     threshold = {
-        NER.LANGUAGE.JP : (0.75, 0.80),
-        NER.LANGUAGE.ZH : (0.75, 0.80),
-        NER.LANGUAGE.EN : (0.75, 0.80),
+        NER.LANGUAGE.ZH : (0.80, 0.80),
+        NER.LANGUAGE.EN : (0.80, 0.80),
+        NER.LANGUAGE.JP : (0.80, 0.80),
+        NER.LANGUAGE.KR : (0.80, 0.80),
     }
+    threshold_x = x if x != None else threshold[language][0]
+    threshold_y = y if y != None else threshold[language][1]
 
     words_merged = []
     for k, v in words_unique.items():
         word = v[0]
-        word.score = min(1.00, sum(w.score for w in v) / len(v)) # take the average score
+        word.score = min(0.9999, sum(w.score for w in v) / len(v)) # take the average score
 
         if (
-            word.ner_type == "PER" and word.score > threshold[language][0] or
-            word.ner_type != "PER" and word.score > threshold[language][1]
+            (word.ner_type == "PER" and word.score > threshold_x)
+            or (word.ner_type != "PER" and word.score > threshold_y)
         ):
             words_merged.append(word)
 
@@ -256,25 +262,8 @@ def replace_words_by_ner_type(words, in_words, ner_type):
     words.extend(in_words)
     return words
 
-# Write the word dictionary to a file
-def write_words_dict_to_file(words, path):
-    words_dict = {}
-    for k, word in enumerate(words):
-        if word.ner_type not in words_dict:
-            words_dict[word.ner_type] = []
-
-        t = {}
-        t["score"] = float(word.score)
-        t["count"] = word.count
-        t["surface"] = word.surface
-        t["ner_type"] = word.ner_type
-        words_dict[word.ner_type].append(t)
-
-    with open(path, "w", encoding = "utf-8") as file:
-        file.write(json.dumps(words_dict, indent = 4, ensure_ascii = False))
-
 # Write the word log to a file
-def write_words_log_to_file(words, path):
+def write_words_log_to_file(words, path, language):
     with open(path, "w", encoding = "utf-8") as file:
         for k, word in enumerate(words):
             if getattr(word, "surface", "") != "":
@@ -284,7 +273,8 @@ def write_words_log_to_file(words, path):
                 file.write(f"置信度 : {word.score:.4f}\n")
 
             if getattr(word, "surface_romaji", "") != "":
-                file.write(f"罗马音 : {word.surface_romaji}\n")
+                if language == NER.LANGUAGE.JP:
+                    file.write(f"罗马音 : {word.surface_romaji}\n")
 
             if getattr(word, "count", int(-1)) >= 0:
                 file.write(f"出现次数 : {word.count}\n")
@@ -320,7 +310,7 @@ def write_words_log_to_file(words, path):
     LogHelper.info(f"结果已写入 - [green]{path}[/]")
 
 # Write the word list to a file
-def write_words_list_to_file(words, path):
+def write_words_list_to_file(words, path, language):
     with open(path, "w", encoding = "utf-8") as file:
         data = {}
         for k, word in enumerate(words):
@@ -333,7 +323,7 @@ def write_words_list_to_file(words, path):
     LogHelper.info(f"结果已写入 - [green]{path}[/]")
 
 # Write the AiNiee dictionary to a file
-def write_ainiee_dict_to_file(words, path):
+def write_ainiee_dict_to_file(words, path, language):
     type_map = {
         "PER": "角色", # person names, e.g. "张三", "约翰·多伊"
         "ORG": "组织", # organizations, e.g. "联合国", "苹果公司"
@@ -367,7 +357,7 @@ def write_ainiee_dict_to_file(words, path):
     LogHelper.info(f"结果已写入 - [green]{path}[/]")
 
 # Write the GalTransl dictionary to a file
-def write_galtransl_dict_to_file(words, path):
+def write_galtransl_dict_to_file(words, path, language):
     type_map = {
         "PER": "角色", # person names, e.g. "张三", "约翰·多伊"
         "ORG": "组织", # organizations, e.g. "联合国", "苹果公司"
@@ -405,10 +395,11 @@ async def process_text(language):
     words = []
     words = G.ner.search_for_entity(input_lines, input_names, language)
 
+    # In debug mode, check the confidence thresholds
     if LogHelper.is_debug():
-        with LogHelper.status(f"正在将实体字典写入文件 ..."):
-            words = merge_and_count(words, input_lines, language)
-            write_words_dict_to_file(words, "words_dict.json")
+        with LogHelper.status(f"正在检查置信度阈值 ..."):
+            words = merge_and_count(words, input_lines, language, 0.00, 0.00)
+            TestHelper.check_score_threshold(words, "check_score_threshold.log")
 
     # Find contexts
     LogHelper.info("即将开始执行 [查找上下文] ...")
@@ -448,13 +439,21 @@ async def process_text(language):
     words_person = remove_words_by_ner_type(words_person, "")
     words = replace_words_by_ner_type(words, words_person, "PER")
 
+    # In debug mode, check result duplication
+    if LogHelper.is_debug():
+        with LogHelper.status(f"正在检查结果重复度..."):
+            TestHelper.check_result_duplication(words, "check_result_duplication.log")
+
     # Wait for the results of the duplication check task
     LogHelper.info("即将开始执行 [重复性校验] ...")
     words = G.ner.validate_words_by_duplication(words)
     words = remove_words_by_ner_type(words, "")
 
     # Wait for the results of the word translation task
-    if language != NER.LANGUAGE.ZH and G.config.translate_surface == 1:
+    if (
+        G.config.translate_surface == 1
+        and (language == NER.LANGUAGE.EN or language == NER.LANGUAGE.JP or language == NER.LANGUAGE.KR)
+    ):
         LogHelper.info("即将开始执行 [词语翻译] ...")
         words = await G.llm.translate_surface_batch(words)
 
@@ -467,17 +466,18 @@ async def process_text(language):
     }
 
     # Wait for the results of the context translation tasks
-    for k, v in ner_type.items():
-        if (
-            (language != NER.LANGUAGE.ZH and k == "PER" and G.config.translate_context_per == 1)
-            or
-            (language != NER.LANGUAGE.ZH and k != "PER" and G.config.translate_context_other == 1)
-        ):
-            LogHelper.info(f"即将开始执行 [上下文翻译 - {v}] ...")
-            word_type = get_words_by_ner_type(words, k)
-            word_type = await G.llm.translate_context_batch(word_type)
-            words = replace_words_by_ner_type(words, word_type, k)
-
+    if language == NER.LANGUAGE.EN or language == NER.LANGUAGE.JP or language == NER.LANGUAGE.KR:
+        for k, v in ner_type.items():
+            if (
+                (k == "PER" and G.config.translate_context_per == 1)
+                or (k != "PER" and G.config.translate_context_other == 1)
+            ):
+                LogHelper.info(f"即将开始执行 [上下文翻译 - {v}] ...")
+                word_type = get_words_by_ner_type(words, k)
+                word_type = await G.llm.translate_context_batch(word_type)
+                words = replace_words_by_ner_type(words, word_type, k)
+
+    # Write the results to files
     dir_name, file_name_with_extension = os.path.split(G.config.input_file_path)
     file_name, extension = os.path.splitext(file_name_with_extension)
 
@@ -491,10 +491,10 @@ async def process_text(language):
             os.remove(f"output/{file_name}_{v}_galtransl.txt") if os.path.exists(f"output/{file_name}_{v}_galtransl.txt") else None
 
         if len(words_ner_type) > 0:
-            write_words_log_to_file(words_ner_type, f"output/{file_name}_{v}_日志.txt")
-            write_words_list_to_file(words_ner_type, f"output/{file_name}_{v}_列表.json")
-            write_ainiee_dict_to_file(words_ner_type, f"output/{file_name}_{v}_ainiee.json")
-            write_galtransl_dict_to_file(words_ner_type, f"output/{file_name}_{v}_galtransl.txt")
+            write_words_log_to_file(words_ner_type, f"output/{file_name}_{v}_日志.txt", language)
+            write_words_list_to_file(words_ner_type, f"output/{file_name}_{v}_列表.json", language)
+            write_ainiee_dict_to_file(words_ner_type, f"output/{file_name}_{v}_ainiee.json", language)
+            write_galtransl_dict_to_file(words_ner_type, f"output/{file_name}_{v}_galtransl.txt", language)
 
     # Wait for the user to exit
     LogHelper.info("")
@@ -503,6 +503,7 @@ async def process_text(language):
     LogHelper.info("")
     os.system("pause")
 
+# API test
 async def test_api():
     if await G.llm.api_test():
         LogHelper.print("")
@@ -556,14 +557,15 @@ def print_app_info():
 def print_menu_main():
     LogHelper.print(f"请选择:")
     LogHelper.print(f"")
-    LogHelper.print(f"\t--> 1. 开始处理 [green]日文文本[/]")
-    LogHelper.print(f"\t--> 2. 开始处理 [green]中文文本(测试版)[/]")
-    LogHelper.print(f"\t--> 3. 开始处理 [green]英文文本(测试版)[/]")
-    LogHelper.print(f"\t--> 4. 开始执行 [green]接口测试[/]")
+    LogHelper.print(f"\t--> 1. 开始处理 [green]中文文本[/]")
+    LogHelper.print(f"\t--> 2. 开始处理 [green]英文文本[/]")
+    LogHelper.print(f"\t--> 3. 开始处理 [green]日文文本[/]")
+    LogHelper.print(f"\t--> 4. 开始处理 [green]韩文文本(初步支持)[/]")
+    LogHelper.print(f"\t--> 5. 开始执行 [green]接口测试[/]")
     LogHelper.print(f"")
 
-    choice = int(Prompt.ask("请输入选项前的 [green]数字序号[/] 来使用对应的功能,默认为 [green][1][/] ",
-        choices = ["1", "2", "3", "4"],
-        default = "1",
+    choice = int(Prompt.ask("请输入选项前的 [green]数字序号[/] 来使用对应的功能,默认为 [green][3][/] ",
+        choices = ["1", "2", "3", "4", "5"],
+        default = "3",
         show_choices = False,
         show_default = False
     ))
@@ -574,25 +576,24 @@ def print_menu_main():
 # Main function
 async def begin():
     choice = -1
-    while choice not in [1, 2, 3]:
+    while choice not in [1, 2, 3, 4]:
         print_app_info()
         choice = print_menu_main()
 
     if choice == 1:
-        await process_text(NER.LANGUAGE.JP)
-    elif choice == 2:
         await process_text(NER.LANGUAGE.ZH)
-    elif choice == 3:
+    elif choice == 2:
         await process_text(NER.LANGUAGE.EN)
+    elif choice == 3:
+        await process_text(NER.LANGUAGE.JP)
     elif choice == 4:
+        await process_text(NER.LANGUAGE.KR)
+    elif choice == 5:
         await test_api()
 
 # Some initialization steps
 def init():
     with LogHelper.status(f"正在初始化 [green]KG[/] 引擎 ..."):
-        if LogHelper.is_debug():
-            TestHelper.check_duplicates()
-
         # Register the global exception tracker
         rich.traceback.install()
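In `merge_and_count`, duplicate `(surface, ner_type)` pairs are merged, their scores averaged and capped just below 1.0, and the result filtered against either the per-language defaults or the explicit `x`/`y` overrides; debug mode passes `0.00` so that `check_score_threshold` sees every candidate. A condensed sketch with a hypothetical `Word` dataclass (the real code also counts occurrences against the full text; here `count` is just the number of merged duplicates):

```python
from dataclasses import dataclass

@dataclass
class Word:
    surface: str
    ner_type: str
    score: float
    count: int = 1

def merge_and_count(words, threshold_per = 0.80, threshold_other = 0.80):
    # Group duplicate (surface, ner_type) pairs.
    groups = {}
    for w in words:
        groups.setdefault((w.surface, w.ner_type), []).append(w)

    merged = []
    for group in groups.values():
        word = group[0]
        word.count = len(group)
        # Average the scores, capped just below 1.0.
        word.score = min(0.9999, sum(w.score for w in group) / len(group))
        threshold = threshold_per if word.ner_type == "PER" else threshold_other
        if word.score > threshold:
            merged.append(word)
    return merged

words = [
    Word("オルディネ", "PER", 0.95),
    Word("オルディネ", "PER", 0.70),
    Word("ポーション", "MISC", 0.60),
]
print([(w.surface, w.count, round(w.score, 4)) for w in merge_and_count(words)])
# [('オルディネ', 2, 0.825)]
```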
diff --git a/model/LLM.py b/model/LLM.py
index 0261f48..f8d6547 100644
--- a/model/LLM.py
+++ b/model/LLM.py
@@ -26,27 +26,27 @@ class LLM:
     # Request parameter configuration - API test
     LLMCONFIG[TASK_TYPE_API_TEST] = type("GClass", (), {})()
     LLMCONFIG[TASK_TYPE_API_TEST].TEMPERATURE = 0.05
-    LLMCONFIG[TASK_TYPE_API_TEST].TOP_P = 0.95
+    LLMCONFIG[TASK_TYPE_API_TEST].TOP_P = 0.85
     LLMCONFIG[TASK_TYPE_API_TEST].MAX_TOKENS = 768
     LLMCONFIG[TASK_TYPE_API_TEST].FREQUENCY_PENALTY = 0
 
     # Request parameter configuration - semantic analysis
     LLMCONFIG[TASK_TYPE_SUMMAIRZE_CONTEXT] = type("GClass", (), {})()
     LLMCONFIG[TASK_TYPE_SUMMAIRZE_CONTEXT].TEMPERATURE = 0.05
-    LLMCONFIG[TASK_TYPE_SUMMAIRZE_CONTEXT].TOP_P = 0.95
+    LLMCONFIG[TASK_TYPE_SUMMAIRZE_CONTEXT].TOP_P = 0.85
     LLMCONFIG[TASK_TYPE_SUMMAIRZE_CONTEXT].MAX_TOKENS = 768
     LLMCONFIG[TASK_TYPE_SUMMAIRZE_CONTEXT].FREQUENCY_PENALTY = 0
 
     # Request parameter configuration - word translation
     LLMCONFIG[TASK_TYPE_TRANSLATE_SURFACE] = type("GClass", (), {})()
     LLMCONFIG[TASK_TYPE_TRANSLATE_SURFACE].TEMPERATURE = 0.05
-    LLMCONFIG[TASK_TYPE_TRANSLATE_SURFACE].TOP_P = 0.95
+    LLMCONFIG[TASK_TYPE_TRANSLATE_SURFACE].TOP_P = 0.85
     LLMCONFIG[TASK_TYPE_TRANSLATE_SURFACE].MAX_TOKENS = 768
     LLMCONFIG[TASK_TYPE_TRANSLATE_SURFACE].FREQUENCY_PENALTY = 0
 
     # Request parameter configuration - context translation
     LLMCONFIG[TASK_TYPE_TRANSLATE_CONTEXT] = type("GClass", (), {})()
-    LLMCONFIG[TASK_TYPE_TRANSLATE_CONTEXT].TEMPERATURE = 0.25
+    LLMCONFIG[TASK_TYPE_TRANSLATE_CONTEXT].TEMPERATURE = 0.75
     LLMCONFIG[TASK_TYPE_TRANSLATE_CONTEXT].TOP_P = 0.95
     LLMCONFIG[TASK_TYPE_TRANSLATE_CONTEXT].MAX_TOKENS = 1024
     LLMCONFIG[TASK_TYPE_TRANSLATE_CONTEXT].FREQUENCY_PENALTY = 0
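The sampling parameters get a tighter `TOP_P` for the near-deterministic tasks and a much higher `TEMPERATURE` for context translation. How these values reach the API is up to `LLM.py`; purely as an illustration, here is how such a per-task table could be applied with the standard `openai` client (the client setup, task names, and model name are assumptions, not the project's actual wiring):

```python
from openai import OpenAI

# (temperature, top_p, max_tokens, frequency_penalty) per task type,
# matching the values set in the diff above.
TASK_PARAMS = {
    "api_test":          (0.05, 0.85, 768, 0),
    "summarize_context": (0.05, 0.85, 768, 0),
    "translate_surface": (0.05, 0.85, 768, 0),
    "translate_context": (0.75, 0.95, 1024, 0),
}

def request(client: OpenAI, task: str, prompt: str, model: str = "placeholder-model"):
    temperature, top_p, max_tokens, frequency_penalty = TASK_PARAMS[task]
    return client.chat.completions.create(
        model = model,
        messages = [{"role": "user", "content": prompt}],
        temperature = temperature,
        top_p = top_p,
        max_tokens = max_tokens,
        frequency_penalty = frequency_penalty,
    )
```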
diff --git a/model/NER.py b/model/NER.py
index 123c5f5..370bb57 100644
--- a/model/NER.py
+++ b/model/NER.py
@@ -23,8 +23,8 @@ class NER:
     TASK_MODES.ACCURACY = 20
 
     GPU_BOOST = torch.cuda.is_available() and LogHelper.is_gpu_boost()
+    BATCH_SIZE = 32 if GPU_BOOST else 1
     MODEL_PATH = "resource/kg_ner_gpu" if GPU_BOOST else "resource/kg_ner_cpu"
-    LINE_SIZE_PER_GROUP = 128
 
     RE_SPLIT_BY_PUNCTUATION = re.compile(
         rf"[" +
@@ -52,6 +52,7 @@ class NER:
     LANGUAGE.ZH = "ZH"
     LANGUAGE.EN = "EN"
     LANGUAGE.JP = "JP"
+    LANGUAGE.KR = "KR"
 
     def __init__(self):
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -85,7 +86,6 @@ def __init__(self):
             model = self.model,
             device = "cuda" if self.GPU_BOOST else "cpu",
             tokenizer = self.tokenizer,
-            batch_size = 64 if self.GPU_BOOST else min(10, os.cpu_count()),
             aggregation_strategy = "simple",
         )
 
@@ -106,6 +106,11 @@ def release(self):
             LogHelper.debug(f"显存保留量 - {torch.cuda.memory_reserved()/1024/1024:>8.2f} MB")
             LogHelper.debug(f"显存分配量 - {torch.cuda.memory_allocated()/1024/1024:>8.2f} MB")
 
+    # Generator
+    def generator(self, data):
+        for v in data:
+            yield v
+
     # Load the blacklist file from the given path
     def load_blacklist(self, filepath):
         try:
@@ -173,6 +178,50 @@ def is_valid_japanese_word(self, surface, blacklist):
 
         return flag
 
+    # Checks whether the word is a meaningful Korean word
+    def is_valid_korean_word(self, surface, blacklist):
+        flag = True
+
+        if len(surface) <= 1:
+            return False
+
+        if surface in blacklist:
+            return False
+
+        if not TextHelper.has_any_korean(surface):
+            return False
+
+        return flag
+
+    # Generate chunks
+    def generate_chunks(self, input_lines, chunk_size):
+        chunks = []
+
+        chunk = ""
+        chunk_length = 0
+        for line in input_lines:
+            encoding = self.tokenizer(
+                line,
+                padding = False,
+                truncation = True,
+                max_length = chunk_size - 3,
+            )
+            length = len(encoding.input_ids)
+
+            if chunk_length + length > chunk_size - 3:
+                chunks.append(chunk)
+                chunk = ""
+                chunk_length = 0
+
+            chunk = chunk + "\n" + line
+            chunk_length = chunk_length + length + 1
+
+        # Append the final chunk after the loop ends
+        if len(chunk) > 0:
+            chunks.append(chunk)
+
+        return chunks
+
     # Generate words
     def generate_words(self, text, score, ner_type, language, unique_words):
         words = []
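Rather than fixed 128-line groups, the input is now packed by token count: each line is tokenized, and a chunk is flushed once adding the next line would exceed the budget (`chunk_size - 3`, leaving room for special tokens). A standalone sketch of the same greedy packing, with a whitespace token counter standing in for the Hugging Face tokenizer:

```python
def generate_chunks(input_lines, chunk_size, count_tokens = lambda line: len(line.split())):
    # Greedily pack lines into chunks of at most (chunk_size - 3) tokens,
    # reserving room for special tokens such as [CLS]/[SEP].
    chunks = []
    chunk, chunk_length = "", 0

    for line in input_lines:
        length = count_tokens(line)
        if chunk_length + length > chunk_size - 3:
            chunks.append(chunk)
            chunk, chunk_length = "", 0
        chunk = chunk + "\n" + line   # +1 token for the newline separator
        chunk_length = chunk_length + length + 1

    if chunk:                         # flush the final partial chunk
        chunks.append(chunk)
    return chunks

print(generate_chunks(["one two", "three four five", "six"], chunk_size = 10))
# ['\none two\nthree four five', '\nsix']
```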
@@ -206,6 +255,12 @@ def generate_words(self, text, score, ner_type, language, unique_words):
             if not self.is_valid_japanese_word(surface, self.blacklist):
                 continue
 
+            # Korean word check
+            if language == NER.LANGUAGE.KR:
+                surface = TextHelper.strip_not_korean(surface)
+                if not self.is_valid_korean_word(surface, self.blacklist):
+                    continue
+
             word = Word()
             word.count = 1
             word.score = score
@@ -231,43 +286,42 @@ def get_english_lemma(self, surface):
     def search_for_entity(self, input_lines, input_names, language):
         words = []
 
-        input_lines_chunked = [
-            input_lines[i : i + self.LINE_SIZE_PER_GROUP]
-            for i in range(0, len(input_lines), self.LINE_SIZE_PER_GROUP)
-        ]
-
         if LogHelper.is_gpu_boost() and torch.cuda.is_available():
             LogHelper.info("启用 [green]GPU[/] 加速成功 ...")
         if LogHelper.is_gpu_boost() and not torch.cuda.is_available():
             LogHelper.warning("启用 [green]GPU[/] 加速失败 ...")
         LogHelper.print(f"")
 
+        with LogHelper.status("正在对文本进行预处理 ..."):
+            chunks = self.generate_chunks(input_lines, 512)
+
         with ProgressHelper.get_progress() as progress:
             pid = progress.add_task("查找 NER 实体", total = None)
 
+            i = 0
             seen = set()
             unique_words = None
-            for k, lines in enumerate(input_lines_chunked):
-                self.classifier.call_count = 0 # suppress the hint that a dataset should be used
-
-                # Join the text
-                line_joined = "\n".join(lines)
-
-                # For English, build a deduplicated vocabulary, then compute and add every lemma to it
+            for result in self.classifier(
+                self.generator(chunks),
+                batch_size = self.BATCH_SIZE,
+            ):
+                # Get the current text
+                line = chunks[i]
+
+                # For English, build a deduplicated vocabulary, then compute and add every lemma to it for later word filtering
                 if language == NER.LANGUAGE.EN:
-                    unique_words = set(re.findall(r"\b\w+\b", line_joined))
+                    unique_words = set(re.findall(r"\b\w+\b", line))
                     unique_words.update(set(self.get_english_lemma(v) for v in unique_words))
 
-                # Extract entities with the NER model
-                for i, doc in enumerate(self.classifier(lines)):
-                    for token in doc:
-                        text = token.get("word")
-                        score = token.get("score")
-                        entity_group = token.get("entity_group")
-                        words.extend(self.generate_words(text, score, entity_group, language, unique_words))
-
+                # Process the NER model's recognition results
+                for token in result:
+                    text = token.get("word")
+                    score = token.get("score")
+                    entity_group = token.get("entity_group")
+                    words.extend(self.generate_words(text, score, entity_group, language, unique_words))
+
                 # Match strings inside 【】
-                for name in re.findall(r"【(.*?)】", line_joined):
+                for name in re.findall(r"【(.*?)】", line):
                     if len(name) <= 12:
                         text = name
                         score = 65535
@@ -278,14 +332,13 @@ def search_for_entity(self, input_lines, input_names, language):
                             continue
                         else:
                             words.append(word)
-                            seen.add(word.surface)
+                        seen.add(word.surface)
 
-                # Reserved VRAM balloons during English and Chinese tasks for unknown reasons; as a workaround, manually release it once it exceeds 2 GB
-                torch.cuda.empty_cache() if torch.cuda.memory_reserved() > 2 * 1024 * 1024 * 1024 else None
-                progress.update(pid, advance = 1, total = len(input_lines_chunked))
+                i = i + 1
+                progress.update(pid, advance = 1, total = len(chunks))
 
         # Print the character entities captured by pattern matching
-        LogHelper.print()
+        LogHelper.print(f"")
         LogHelper.info(f"[查找 NER 实体] 已完成 ...")
         if len(seen) > 0:
             LogHelper.info(f"[查找 NER 实体] 通过模式 [green]【(.*?)】[/] 抓取到角色实体 - {", ".join(seen)}")
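The speedup in `search_for_entity` comes from handing the Hugging Face pipeline a generator plus an explicit `batch_size`, so it streams and batches the chunks internally instead of being invoked once per line group. A minimal sketch of that pattern; the checkpoint name is a placeholder, since the project loads its own `kg_ner_*` model:

```python
from transformers import pipeline

# Placeholder checkpoint standing in for the project's kg_ner_* model.
classifier = pipeline(
    "token-classification",
    model = "dslim/bert-base-NER",
    aggregation_strategy = "simple",
)

chunks = ["Alice met Bob in Paris.", "Carol flew to Tokyo."]

def generator(data):
    for v in data:
        yield v

# Feeding a generator lets the pipeline build batches internally;
# results stream back one per input chunk, in order.
for i, result in enumerate(classifier(generator(chunks), batch_size = 2)):
    for token in result:
        print(i, token["entity_group"], token["word"], round(float(token["score"]), 4))
```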
diff --git a/prompt/prompt_translate_surface_common.txt b/prompt/prompt_translate_surface_common.txt
index b3e9b22..48d26f5 100644
--- a/prompt/prompt_translate_surface_common.txt
+++ b/prompt/prompt_translate_surface_common.txt
@@ -1,4 +1,4 @@
-将接收到的专有名词翻译成两种不同的中文译文,并标注原词语的罗马音。
+将接收到的专有名词翻译成两种不同的中文译文,并使用罗马音标注词语的读音。
 
 专有名词:
 {surface}
diff --git a/prompt/prompt_translate_surface_person.txt b/prompt/prompt_translate_surface_person.txt
index 29b290f..b9de25a 100644
--- a/prompt/prompt_translate_surface_person.txt
+++ b/prompt/prompt_translate_surface_person.txt
@@ -1,4 +1,4 @@
-将接收到的性别为{attribute}的角色的名字翻译成两种不同的中文译文,并标注原词语的罗马音。
+将接收到的性别为{attribute}的角色的名字翻译成两种不同的中文译文,并使用罗马音标注词语的读音。
 
 角色的名字:
 {surface}
diff --git a/version.txt b/version.txt
index d4dfa56..01e994d 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-v0.3.0
\ No newline at end of file
+v0.4.0
\ No newline at end of file
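The prompt tweak asks the model to annotate the word's reading in romaji rather than "the original word's romaji". The templates are plain text with `{surface}` and `{attribute}` placeholders; a hypothetical fill-in is sketched below (how `LLM.py` actually performs the substitution is not shown in this diff):

```python
from pathlib import Path

# Hypothetical illustration of filling in the {attribute}/{surface}
# placeholders before sending the prompt to the LLM.
template = Path("prompt/prompt_translate_surface_person.txt").read_text(encoding = "utf-8")
prompt = template.replace("{attribute}", "女性").replace("{surface}", "イシュラナ")
print(prompt)
```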