Skip to content

Commit

Permalink
Fixed the emit with spinner. Improved translations
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Oct 30, 2024
1 parent 94c26e2 commit 6ab84cd
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 63 deletions.
14 changes: 9 additions & 5 deletions src/main/askai/core/component/internet_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.tools import Tool
from langchain_google_community import GoogleSearchAPIWrapper
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from openai import APIError
from typing import List

Expand Down Expand Up @@ -152,10 +152,9 @@ def _build_google_query(search: SearchResult) -> str:
return f"{google_query} {sites}"

def __init__(self):
API_KEYS.ensure("GOOGLE_API_KEY", "google_search")
self._google = GoogleSearchAPIWrapper(k=10, google_api_key=API_KEYS.GOOGLE_API_KEY)
self._tool = Tool(name="google_search", description="Search Google for recent results.", func=self._google.run)
self._text_splitter = RecursiveCharacterTextSplitter(
self._google: GoogleSearchAPIWrapper | None = None
self._tool: Tool | None = None
self._text_splitter: TextSplitter = RecursiveCharacterTextSplitter(
chunk_size=configs.chunk_size, chunk_overlap=configs.chunk_overlap
)

Expand All @@ -170,6 +169,11 @@ def google_search(self, search: SearchResult) -> str:
:param search: The AI search parameters encapsulated in a SearchResult object.
:return: A refined string containing the search results.
"""
# Lazy initialization to allow GOOGLE_API_KEY to be optional.
if not self._google:
API_KEYS.ensure("GOOGLE_API_KEY", "google_search")
self._google = GoogleSearchAPIWrapper(k=10, google_api_key=API_KEYS.GOOGLE_API_KEY)
self._tool = Tool(name="google_search", description="Search Google for recent results.", func=self._google.run)
events.reply.emit(reply=AIReply.info(msg.searching()))
terms: str = self._build_google_query(search).strip()
question: str = re.sub(r"(\w+:)*|((\w+\.\w+)*)", "", terms, flags=re.DOTALL | re.MULTILINE)
Expand Down
32 changes: 12 additions & 20 deletions src/main/askai/core/model/api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
Copyright (c) 2024, HomeSetup
"""
from askai.exception.exceptions import MissingApiKeyError
import os
from pathlib import Path
from typing import Optional

import dotenv
from clitt.core.tui.minput.input_validator import InputValidator
from clitt.core.tui.minput.menu_input import MenuInput
from clitt.core.tui.minput.minput import minput
from hspylib.core.enums.charset import Charset
from pathlib import Path
from pydantic.v1 import BaseSettings, Field, validator
from typing import AnyStr
from pydantic.v1 import BaseSettings, Field

import dotenv
import os
from askai.exception.exceptions import MissingApiKeyError

API_KEY_FILE: str = os.environ.get("HHS_ENV_FILE", str(os.path.join(Path.home(), ".env")))

Expand All @@ -37,21 +38,10 @@ class ApiKeys(BaseSettings):

# fmt: off
OPENAI_API_KEY: str = Field(..., description="Open AI Api Key")
GOOGLE_API_KEY: str = Field(..., description="Google Api Key")
DEEPL_API_KEY: str = Field(..., description="DeepL Api Key")
GOOGLE_API_KEY: Optional[str] = Field(None, description="Google Api Key")
DEEPL_API_KEY: Optional[str] = Field(None, description="DeepL Api Key")
# fmt: on

@validator("OPENAI_API_KEY", "GOOGLE_API_KEY", "DEEPL_API_KEY")
def not_empty(cls, value: AnyStr) -> AnyStr:
"""Pydantic validator to ensure that API key fields are not empty.
:param value: The value of the API key being validated.
:return: The value if it is not empty.
:raises ValueError: If the value is empty or None.
"""
if not value or not value.strip():
raise ValueError("must not be empty or blank")
return value

def has_key(self, key_name: str) -> bool:
"""Check if the specified API key exists and is not empty.
:param key_name: The name of the API key to check.
Expand Down Expand Up @@ -84,16 +74,18 @@ def prompt() -> bool:
.label('GOOGLE_API_KEY') \
.value(os.environ.get("GOOGLE_API_KEY")) \
.min_max_length(39, 39) \
.validator(InputValidator.anything()) \
.build() \
.field() \
.label('DEEPL_API_KEY') \
.value(os.environ.get("DEEPL_API_KEY")) \
.min_max_length(39, 39) \
.validator(InputValidator.anything()) \
.build() \
.build()
# fmt: on

if result := minput(form_fields, "Please fill all required Api Keys"):
if result := minput(form_fields, "Please fill the required ApiKeys"):
with open(API_KEY_FILE, "r+", encoding=Charset.UTF_8.val) as f_envs:
envs = f_envs.readlines()
with open(API_KEY_FILE, "w", encoding=Charset.UTF_8.val) as f_envs:
Expand Down
6 changes: 4 additions & 2 deletions src/main/askai/core/processors/splitter/splitter_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from textwrap import indent
from threading import Thread

from clitt.core.term.cursor import cursor
from hspylib.core.tools.commons import is_debugging
from rich.live import Live
from rich.spinner import Spinner
Expand All @@ -39,10 +40,10 @@ def __init__(self, query: str):
def pipeline(self) -> SplitterPipeline:
return self._pipeline

def display(self, reply: str) -> None:
def display(self, text: str) -> None:
"""Print the given markup text on the console, only when `is_debugging()` returns true."""
if is_debugging():
text_formatter.console.print(Text.from_markup(reply))
text_formatter.console.print(Text.from_markup(text))

def run(self) -> None:
with Live(Spinner("dots", f"[green]{self.pipeline.state}…[/green]", style="green"), console=tf.console) as live:
Expand Down Expand Up @@ -105,6 +106,7 @@ def run(self) -> None:
self.display(f"[green]{execution_status_str}[/green]")
live.update(Spinner("dots", f"[green]{self.pipeline.state}…[/green]", style="green"))
self.pipeline.iteractions += 1
cursor.erase_line()

if configs.is_debug:
final_state: States = self.pipeline.state
Expand Down
5 changes: 5 additions & 0 deletions src/main/askai/core/processors/splitter/splitter_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"""
from hspylib.core.enums.enumeration import Enumeration

from askai.core.askai_messages import msg


class States(Enumeration):
"""Enumeration of possible task splitter states."""
Expand All @@ -36,3 +38,6 @@ class States(Enumeration):

COMPLETE = 'Completed'
# fmt: on

def __str__(self) -> str:
return msg.t(self.value)
48 changes: 43 additions & 5 deletions src/main/askai/language/ai_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,58 @@
Copyright (c) 2024, HomeSetup
"""
import re
from abc import ABC
from typing import AnyStr

from askai.core.askai_events import events
from askai.core.model.ai_reply import AIReply
from askai.language.language import Language
from functools import lru_cache
from typing import Protocol


class AITranslator(Protocol):
class AITranslator(ABC):
"""Provides a base class for multilingual offline translation engines. Various implementations can be used."""

def __init__(self, source_lang: Language, target_lang: Language):
self._source_lang: Language = source_lang
self._target_lang: Language = target_lang

@lru_cache
def translate(self, text: str, **kwargs) -> str:
def translate(self, text: AnyStr) -> str:
"""Translates text, leaving the [TAG]...[/TAG] and %TAG%...%TAG% formatting markers themselves untranslated.
:param text: The input text with [TAG]...[/TAG] and %TAG%...%TAG% formatting.
:returns str: The translated text with original tags preserved.
"""
if self._source_lang == self._target_lang:
return text

# Regex to match the individual markers [TAG], [/TAG], %TAG%, %/TAG%, and a single word wrapped in quotes ("word" or 'word')
tag_pattern = re.compile(r'(\[/?\w+]|%/?\w+%|["\']\w+["\'])', re.IGNORECASE)
parts = tag_pattern.split(text) # Split the text into parts: tags and non-tags
texts_to_translate = []
indices = []
# Collect texts to translate
for i, part in enumerate(parts):
if not tag_pattern.fullmatch(part) and part.strip() != '':
texts_to_translate.append(part)
indices.append(i)

try:
# Perform batch translation
translated_texts: list[str] = list(map(self.translate_text, texts_to_translate))
except Exception as err:
events.reply.emit(reply=AIReply.debug(f"Error during batch translation: {err}"))
return text

# Replace translated texts in the parts list
for idx, translated in zip(indices, translated_texts):
parts[idx] = translated

# Reassemble the translated text with tags
translated_text = ''.join(parts)

return translated_text

def translate_text(self, text: AnyStr, **kwargs) -> str:
"""Translate text from the source language to the target language.
:param text: Text to translate.
:return: The translated text.
Expand Down
17 changes: 9 additions & 8 deletions src/main/askai/language/translators/argos_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@
Copyright (c) 2024, HomeSetup
"""

import logging as log
import os
import sys
from functools import lru_cache
from typing import Optional, AnyStr

from argostranslate import package, translate
from argostranslate.translate import ITranslation

from askai.exception.exceptions import TranslationPackageError
from askai.language.ai_translator import AITranslator
from askai.language.language import Language
from functools import lru_cache
from typing import Optional

import logging as log
import os
import sys


class ArgosTranslator(AITranslator):
Expand Down Expand Up @@ -65,12 +66,12 @@ def __init__(self, from_idiom: Language, to_idiom: Language):
self._argos_model = argos_model

@lru_cache
def translate(self, text: str, **kwargs) -> str:
def translate_text(self, text: AnyStr, **kwargs) -> str:
"""Translate text from the source language to the target language.
:param text: Text to translate.
:return: The translated text.
"""
return text if self._source_lang == self._target_lang else self._argos_model.translate(text)
return self._argos_model.translate(text)

def name(self) -> str:
return "Argos"
Expand Down
32 changes: 18 additions & 14 deletions src/main/askai/language/translators/deepl_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
Copyright (c) 2024, HomeSetup
"""
from askai.__classpath__ import API_KEYS
from askai.language.ai_translator import AITranslator
from askai.language.language import Language
from functools import lru_cache
from typing import AnyStr

import deepl
from deepl import Translator

from askai.__classpath__ import API_KEYS
from askai.language.ai_translator import AITranslator
from askai.language.language import Language


class DeepLTranslator(AITranslator):
Expand All @@ -27,23 +30,24 @@ class DeepLTranslator(AITranslator):

def __init__(self, source_lang: Language, target_lang: Language):
super().__init__(source_lang, target_lang)
API_KEYS.ensure("DEEPL_API_KEY", "DeepLTranslator")
self._translator = deepl.Translator(API_KEYS.DEEPL_API_KEY)
self._translator: Translator | None = None

@lru_cache
def translate(self, text: str, **kwargs) -> str:
def translate_text(self, text: AnyStr, **kwargs) -> str:
"""Translate text from the source language to the target language.
:param text: Text to translate.
:return: The translated text.
"""
if self._source_lang != self._target_lang:
kwargs["preserve_formatting"] = True
lang = self._from_locale()
result: deepl.TextResult = self._translator.translate_text(
text, source_lang=lang[0], target_lang=lang[1], **kwargs
)
return str(result)
return text
# Lazy initialization to allow DEEPL_API_KEY to be optional.
if not self._translator:
API_KEYS.ensure("DEEPL_API_KEY", "DeepLTranslator")
self._translator = deepl.Translator(API_KEYS.DEEPL_API_KEY)
kwargs["preserve_formatting"] = True
lang = self._from_locale()
result: deepl.TextResult = self._translator.translate_text(
text, source_lang=lang[0], target_lang=lang[1], **kwargs
)
return str(result)

def name(self) -> str:
return "DeepL"
Expand Down
17 changes: 8 additions & 9 deletions src/main/askai/language/translators/marian_translator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from askai.language.ai_translator import AITranslator
from askai.language.language import Language
import re
from functools import lru_cache
from typing import AnyStr

from transformers import MarianMTModel, MarianTokenizer

import re
from askai.language.ai_translator import AITranslator
from askai.language.language import Language


class MarianTranslator(AITranslator):
Expand All @@ -21,20 +23,17 @@ def __init__(self, from_idiom: Language, to_idiom: Language):
self._tokenizer = MarianTokenizer.from_pretrained(self.MODEL_NAME)

@lru_cache
def translate(self, text: str, **kwargs) -> str:
def translate_text(self, text: AnyStr, **kwargs) -> str:
"""Translate text from the source language to the target language.
:param text: Text to translate.
:return: The translated text.
"""
if self._source_lang == self._target_lang:
return text

kwargs["return_tensors"] = "pt"
kwargs["padding"] = True

return self._translate(f">>{self._target_lang.idiom}<<{text}", **kwargs)
return self._decode(f">>{self._target_lang.idiom}<<{text}", **kwargs)

def _translate(self, text, **kwargs) -> str:
def _decode(self, text, **kwargs) -> str:
"""Wrapper function that is going to provide the translation of the text.
:param text: The text to be translated.
:return: The translated text.
Expand Down

0 comments on commit 6ab84cd

Please sign in to comment.