-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #40 from zh-plus/dev
Add context reviewer agent into translation workflow.
- Loading branch information
Showing
11 changed files
with
524 additions
and
507 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# Copyright (C) 2024. Hao Zheng | ||
# All rights reserved. | ||
import abc | ||
import re | ||
from typing import Optional, Tuple, List | ||
|
||
from openlrc.chatbot import route_chatbot | ||
from openlrc.context import TranslationContext, TranslateInfo | ||
from openlrc.logger import logger | ||
from openlrc.prompter import BaseTranslatePrompter, ContextReviewPrompter, potential_prefix_combo, \ | ||
ProofreaderPrompter, proofread_prefix | ||
|
||
|
||
class Agent(abc.ABC): | ||
TEMPERATURE = 0.5 | ||
""" | ||
Base class for all agents. | ||
""" | ||
|
||
def _initialize_chatbot(self, chatbot_model: str, fee_limit: float, proxy: str, base_url_config: Optional[dict]): | ||
chatbot_cls, model_name = route_chatbot(chatbot_model) | ||
return chatbot_cls(model=model_name, fee_limit=fee_limit, proxy=proxy, retry=3, | ||
temperature=self.TEMPERATURE, base_url_config=base_url_config) | ||
|
||
|
||
class ChunkedTranslatorAgent(Agent): | ||
""" | ||
Translate the well-defined chunked text to the target language and send it to the chatbot for further processing. | ||
""" | ||
|
||
TEMPERATURE = 0.9 | ||
|
||
def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(), | ||
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, proxy: str = None, | ||
base_url_config: Optional[dict] = None): | ||
super().__init__() | ||
self.chatbot_model = chatbot_model | ||
self.info = info | ||
self.chatbot = self._initialize_chatbot(chatbot_model, fee_limit, proxy, base_url_config) | ||
self.prompter = BaseTranslatePrompter(src_lang, target_lang, info) | ||
self.cost = 0 | ||
|
||
def __str__(self): | ||
return f'Translator Agent ({self.chatbot_model})' | ||
|
||
def _parse_responses(self, resp) -> Tuple[List[str], str, str]: | ||
""" | ||
Parse the response from the chatbot API. | ||
Args: | ||
resp: The response from the chatbot API. | ||
Returns: | ||
Tuple[List[str], str, str]: Parsed translations, summary, and scene from the response. | ||
""" | ||
content = self.chatbot.get_content(resp) | ||
|
||
try: | ||
summary = self._extract_tag_content(content, 'summary') | ||
scene = self._extract_tag_content(content, 'scene') | ||
translations = self._extract_translations(content) | ||
|
||
return [t.strip() for t in translations], summary.strip(), scene.strip() | ||
except Exception as e: | ||
logger.error(f'Failed to extract contents from response: {content}') | ||
raise e | ||
|
||
def _extract_tag_content(self, content: str, tag: str) -> str: | ||
match = re.search(rf'<{tag}>(.*?)</{tag}>', content) | ||
return match.group(1) if match else '' | ||
|
||
def _extract_translations(self, content: str) -> List[str]: | ||
for _, trans_prefix in potential_prefix_combo: | ||
translations = re.findall(f'{trans_prefix}\n*(.*?)(?:#\d+|<summary>|\n*$)', content, re.DOTALL) | ||
if translations: | ||
return self._clean_translations(translations, content) | ||
return [] | ||
|
||
def _clean_translations(self, translations: List[str], content: str) -> List[str]: | ||
if any(re.search(r'(<.*?>|</.*?>)', t) for t in translations): | ||
logger.warning(f'The extracted translation from response contains tags: {content}, tags removed') | ||
return [re.sub(r'(<.*?>|</.*?>).*', '', t, flags=re.DOTALL) for t in translations] | ||
return translations | ||
|
||
def translate_chunk(self, chunk_id: int, chunk: List[Tuple[int, str]], | ||
context: TranslationContext = TranslationContext(), | ||
use_glossary: bool = True) -> Tuple[List[str], TranslationContext]: | ||
user_input = self.prompter.format_texts(chunk) | ||
guideline = context.guideline if use_glossary else context.non_glossary_guideline | ||
messages_list = [ | ||
{'role': 'system', 'content': self.prompter.system()}, | ||
{'role': 'user', 'content': self.prompter.user(chunk_id, user_input, context.summary, guideline)}, | ||
] | ||
resp = self.chatbot.message(messages_list, output_checker=self.prompter.check_format)[0] | ||
translations, summary, scene = self._parse_responses(resp) | ||
self.cost += self.chatbot.api_fees[-1] | ||
context.update(summary=summary, scene=scene, model=self.chatbot_model) | ||
|
||
return translations, context | ||
|
||
|
||
class ContextReviewerAgent(Agent): | ||
""" | ||
Review the context of the subtitles to ensure accuracy and completeness. | ||
""" | ||
|
||
TEMPERATURE = 0.8 | ||
|
||
def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(), | ||
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, proxy: str = None, | ||
base_url_config: Optional[dict] = None): | ||
super().__init__() | ||
self.src_lang = src_lang | ||
self.target_lang = target_lang | ||
self.info = info | ||
self.chatbot_model = chatbot_model | ||
self.prompter = ContextReviewPrompter(src_lang, target_lang) | ||
self.chatbot = self._initialize_chatbot(chatbot_model, fee_limit, proxy, base_url_config) | ||
|
||
def __str__(self): | ||
return f'Context Reviewer Agent ({self.chatbot_model})' | ||
|
||
def build_context(self, texts, title='', glossary: Optional[dict] = None) -> str: | ||
text_content = '\n'.join(texts) | ||
messages_list = [ | ||
{'role': 'system', 'content': self.prompter.system()}, | ||
{'role': 'user', 'content': self.prompter.user(text_content, title=title, given_glossary=glossary)}, | ||
] | ||
resp = self.chatbot.message(messages_list, output_checker=self.prompter.check_format)[0] | ||
context = self.chatbot.get_content(resp) | ||
return context | ||
|
||
|
||
class ProofreaderAgent(Agent): | ||
""" | ||
Adapt subtitles to ensure cultural relevance and appropriateness. | ||
""" | ||
TEMPERATURE = 0.8 | ||
|
||
def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(), | ||
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, proxy: str = None, | ||
base_url_config: Optional[dict] = None): | ||
super().__init__() | ||
self.src_lang = src_lang | ||
self.target_lang = target_lang | ||
self.info = info | ||
self.prompter = ProofreaderPrompter(src_lang, target_lang) | ||
self.chatbot = self._initialize_chatbot(chatbot_model, fee_limit, proxy, base_url_config) | ||
|
||
def _parse_responses(self, resp) -> List[str]: | ||
content = self.chatbot.get_content(resp) | ||
revised = re.findall(proofread_prefix + r'\s*(.*)', content, re.MULTILINE) | ||
|
||
return revised | ||
|
||
def proofread(self, texts: List[str], translations, context: TranslationContext) -> List[str]: | ||
messages_list = [ | ||
{'role': 'system', 'content': self.prompter.system()}, | ||
{'role': 'user', 'content': self.prompter.user(texts, translations, context.guideline)}, | ||
] | ||
resp = self.chatbot.message(messages_list, output_checker=self.prompter.check_format)[0] | ||
revised = self._parse_responses(resp) | ||
return revised |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,29 @@ | ||
# Copyright (C) 2024. Hao Zheng | ||
# All rights reserved. | ||
from difflib import get_close_matches | ||
from pathlib import Path | ||
from typing import Union | ||
import re | ||
from typing import Optional | ||
|
||
import yaml | ||
from pydantic import BaseModel | ||
|
||
from openlrc.logger import logger | ||
|
||
class TranslationContext(BaseModel): | ||
summary: Optional[str] = '' | ||
scene: Optional[str] = '' | ||
model: Optional[str] = None | ||
guideline: Optional[str] = None | ||
|
||
class Context: | ||
def __init__(self, background='', description_map=None, audio_type='Anime', config_path=None): | ||
""" | ||
Context(optional) for translation. | ||
def update(self, **args): | ||
for key, value in args.items(): | ||
if hasattr(self, key): | ||
setattr(self, key, value) | ||
|
||
Args: | ||
background (str): Providing background information for establishing context for the translation. | ||
description_map (dict, optional): {"name(without extension)": "description", ...} | ||
audio_type (str, optional): Audio type, default to Anime. | ||
config_path (str, optional): Path to config file. | ||
@property | ||
def non_glossary_guideline(self) -> str: | ||
cleaned_text = re.sub(r'Glossary:\n(.*?\n)*?\nCharacters:', 'Characters:', self.guideline, flags=re.DOTALL) | ||
return cleaned_text | ||
|
||
Raises: | ||
FileNotFoundError: If the config file specified by config_path does not exist. | ||
|
||
""" | ||
self.config_path = None | ||
self.background = background | ||
self.audio_type = audio_type | ||
self.description_map = description_map if description_map else dict() | ||
|
||
# if config_path exist, load yaml file | ||
if config_path: | ||
config_path = Path(config_path) | ||
if config_path.exists(): | ||
self.load_config(config_path) | ||
else: | ||
raise FileNotFoundError(f'Config file {config_path} not found.') | ||
|
||
def load_config(self, config_path: Union[str, Path]): | ||
config_path = Path(config_path) | ||
if not config_path.exists(): | ||
raise FileNotFoundError(f'Config file {config_path} not found.') | ||
|
||
with open(config_path, 'r', encoding='utf-8') as f: | ||
config: dict = yaml.safe_load(f) | ||
|
||
if config.get('background'): | ||
self.background = config['background'] | ||
|
||
if config.get('audio_type'): | ||
self.audio_type = config['audio_type'] | ||
|
||
if config.get('description_map'): | ||
self.description_map = config['description_map'] | ||
|
||
self.config_path = config_path | ||
|
||
def save_config(self): | ||
with open(self.config_path, 'w') as f: | ||
yaml.dump({ | ||
'background': self.background, | ||
'audio_type': self.audio_type, | ||
'description_map': self.description_map, | ||
}, f) | ||
|
||
def get_description(self, audio_name): | ||
value = '' | ||
if self.description_map: | ||
matches = get_close_matches(audio_name, self.description_map.keys()) | ||
if matches: | ||
key = matches[0] | ||
value = self.description_map.get(key) | ||
logger.info(f'Found description map: {key} -> {value}') | ||
else: | ||
logger.info(f'No description map for {audio_name} found.') | ||
|
||
return value | ||
|
||
def __str__(self): | ||
return f'Context(background={self.background}, audio_type={self.audio_type}, description_map={self.description_map})' | ||
class TranslateInfo(BaseModel): | ||
title: Optional[str] = '' | ||
audio_type: str = 'Movie' | ||
glossary: Optional[dict] = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.