diff --git a/src/main/askai/core/processors/chat.py b/src/main/askai/core/processors/chat.py
index f97a612a..be06c8e1 100644
--- a/src/main/askai/core/processors/chat.py
+++ b/src/main/askai/core/processors/chat.py
@@ -12,6 +12,17 @@
    Copyright 2024, HSPyLib team
 """

+from typing import Any, Optional
+
+from clitt.core.term.cursor import cursor
+from hspylib.core.config.path_object import PathObject
+from hspylib.core.metaclass.singleton import Singleton
+from hspylib.core.tools.dict_tools import get_or_default_by_key
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
+from langchain_core.runnables import RunnableWithMessageHistory
+from rich.live import Live
+from rich.spinner import Spinner
+
 from askai.core.askai_events import events
 from askai.core.askai_messages import msg
 from askai.core.askai_prompt import prompt
@@ -19,13 +30,8 @@
 from askai.core.model.ai_reply import AIReply
 from askai.core.support.langchain_support import lc_llm
 from askai.core.support.shared_instances import shared
+from askai.core.support.text_formatter import text_formatter as tf
 from askai.exception.exceptions import TerminatingQuery
-from hspylib.core.config.path_object import PathObject
-from hspylib.core.metaclass.singleton import Singleton
-from hspylib.core.tools.dict_tools import get_or_default_by_key
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
-from langchain_core.runnables import RunnableWithMessageHistory
-from typing import Any, Optional


 class ChatProcessor(metaclass=Singleton):
@@ -59,25 +65,32 @@ def process(self, question: str, **kwargs) -> Optional[str]:
             events.mode_changed.emit(mode="DEFAULT")
             return None

-        response = None
-        prompt_file: str = get_or_default_by_key(kwargs, "prompt_file", None)
-        history_ctx: Any | None = get_or_default_by_key(kwargs, "history_ctx", "HISTORY")
-        ctx: str = get_or_default_by_key(kwargs, "context", "")
-        inputs: list[str] = get_or_default_by_key(kwargs, "inputs", [])
-        args: dict[str, Any] = get_or_default_by_key(kwargs, "args", {})
-        inputs = inputs or ["user", "idiom", "context", "question"]
-        args = args or {"user": prompt.user.title(), "idiom": shared.idiom, "context": ctx, "question": question}
-        prompt_file: PathObject = PathObject.of(prompt_file or prompt.append_path(f"taius/taius-jarvis"))
-        prompt_str: str = prompt.read_prompt(prompt_file.filename, prompt_file.abs_dir)
-
-        template = self.template(prompt_str, *inputs, **args)
-        runnable = template | lc_llm.create_chat_model(Temperature.COLDEST.temp)
-        runnable = RunnableWithMessageHistory(
-            runnable, shared.context.flat, input_messages_key="input", history_messages_key="chat_history"
-        )
-
-        if output := runnable.invoke({"input": question}, config={"configurable": {"session_id": history_ctx or ""}}):
-            response = output.content
+        with Live(
+            Spinner("dots", f"[green]{msg.wait()}[/green]", style="green"), console=tf.console
+        ):
+            response = None
+            prompt_file: str = get_or_default_by_key(kwargs, "prompt_file", None)
+            history_ctx: Any | None = get_or_default_by_key(kwargs, "history_ctx", "HISTORY")
+            ctx: str = get_or_default_by_key(kwargs, "context", "")
+            inputs: list[str] = get_or_default_by_key(kwargs, "inputs", [])
+            args: dict[str, Any] = get_or_default_by_key(kwargs, "args", {})
+            inputs = inputs or ["user", "idiom", "context", "question"]
+            args = args or {"user": prompt.user.title(), "idiom": shared.idiom, "context": ctx, "question": question}
+            prompt_file: PathObject = PathObject.of(prompt_file or prompt.append_path(f"taius/taius-jarvis"))
+            prompt_str: str = prompt.read_prompt(prompt_file.filename, prompt_file.abs_dir)
+
+            template = self.template(prompt_str, *inputs, **args)
+            runnable = template | lc_llm.create_chat_model(Temperature.COLDEST.temp)
+            runnable = RunnableWithMessageHistory(
+                runnable, shared.context.flat, input_messages_key="input", history_messages_key="chat_history"
+            )
+
+            if output := runnable.invoke(
+                input={"input": question},
+                config={"configurable": {"session_id": history_ctx or ""}}):
+                response = output.content
+
+        cursor.erase_line()

         return response

diff --git a/src/main/askai/core/processors/qstring.py b/src/main/askai/core/processors/qstring.py
index 8c48addf..a838c2ea 100644
--- a/src/main/askai/core/processors/qstring.py
+++ b/src/main/askai/core/processors/qstring.py
@@ -12,11 +12,17 @@
    Copyright (c) 2024, HomeSetup
 """

+from clitt.core.term.cursor import cursor
+from rich.live import Live
+from rich.spinner import Spinner
+
+from askai.core.askai_messages import msg
 from askai.core.askai_prompt import prompt
 from askai.core.component.cache_service import cache
 from askai.core.engine.openai.temperature import Temperature
 from askai.core.support.langchain_support import lc_llm
 from askai.core.support.utilities import find_file
+from askai.core.support.text_formatter import text_formatter as tf
 from askai.exception.exceptions import TerminatingQuery
 from hspylib.core.config.path_object import PathObject
 from hspylib.core.metaclass.singleton import Singleton
@@ -45,20 +51,25 @@ def process(self, question: str, **kwargs) -> Optional[str]:
         if question.casefold() in ["exit", "leave", "quit", "q"]:
             return None

-        output = None
-        query_prompt: str | None = find_file(kwargs["query_prompt"]) if "query_prompt" in kwargs else None
-        context: str | None = kwargs["context"] if "context" in kwargs else None
-        temperature: int = kwargs["temperature"] if "temperature" in kwargs else None
-
-        dir_name, file_name = PathObject.split(query_prompt or self.DEFAULT_PROMPT)
-        template = PromptTemplate(
-            input_variables=["context", "question"], template=prompt.read_prompt(file_name, dir_name)
-        )
-        final_prompt: str = template.format(context=context or self.DEFAULT_CONTEXT, question=question)
-        llm = lc_llm.create_chat_model(temperature or self.DEFAULT_TEMPERATURE)
-
-        if (response := llm.invoke(final_prompt)) and (output := response.content):
-            cache.save_input_history()
+        with Live(
+            Spinner("dots", f"[green]{msg.wait()}[/green]", style="green"), console=tf.console
+        ):
+            output = None
+            query_prompt: str | None = find_file(kwargs["query_prompt"]) if "query_prompt" in kwargs else None
+            context: str | None = kwargs["context"] if "context" in kwargs else None
+            temperature: int = kwargs["temperature"] if "temperature" in kwargs else None
+
+            dir_name, file_name = PathObject.split(query_prompt or self.DEFAULT_PROMPT)
+            template = PromptTemplate(
+                input_variables=["context", "question"], template=prompt.read_prompt(file_name, dir_name)
+            )
+            final_prompt: str = template.format(context=context or self.DEFAULT_CONTEXT, question=question)
+            llm = lc_llm.create_chat_model(temperature or self.DEFAULT_TEMPERATURE)
+
+            if (response := llm.invoke(final_prompt)) and (output := response.content):
+                cache.save_input_history()
+
+        cursor.erase_line()

         return output

diff --git a/src/main/askai/core/processors/rag.py b/src/main/askai/core/processors/rag.py
index d18305fd..a7f9bad9 100644
--- a/src/main/askai/core/processors/rag.py
+++ b/src/main/askai/core/processors/rag.py
@@ -12,6 +12,10 @@
    Copyright (c) 2024, HomeSetup
 """

+from clitt.core.term.cursor import cursor
+from rich.live import Live
+from rich.spinner import Spinner
+
 from askai.core.askai_configs import configs
 from askai.core.askai_events import events
 from askai.core.askai_messages import msg
@@ -22,6 +26,7 @@
 from askai.core.model.ai_reply import AIReply
 from askai.core.support.langchain_support import lc_llm
 from askai.exception.exceptions import DocumentsNotFound, TerminatingQuery
+from askai.core.support.text_formatter import text_formatter as tf
 from functools import lru_cache
 from hspylib.core.config.path_object import PathObject
 from hspylib.core.metaclass.classpath import AnyPath
@@ -88,8 +93,13 @@ def process(self, question: str, **_) -> Optional[str]:
         # FIXME Include kwargs to specify rag dir and glob
         self.generate()

-        if not (output := self._rag_chain.invoke(question)):
-            output = msg.invalid_response(output)
+        with Live(
+            Spinner("dots", f"[green]{msg.wait()}[/green]", style="green"), console=tf.console
+        ):
+            if not (output := self._rag_chain.invoke(question)):
+                output = msg.invalid_response(output)
+
+        cursor.erase_line()

         return output

diff --git a/src/main/askai/core/processors/splitter/splitter_executor.py b/src/main/askai/core/processors/splitter/splitter_executor.py
index 2f34d3e7..4b6151f8 100644
--- a/src/main/askai/core/processors/splitter/splitter_executor.py
+++ b/src/main/askai/core/processors/splitter/splitter_executor.py
@@ -18,7 +18,6 @@
 from askai.core.enums.acc_color import AccColor
 from askai.core.processors.splitter.splitter_pipeline import SplitterPipeline
 from askai.core.processors.splitter.splitter_states import States
-from askai.core.support.text_formatter import text_formatter
 from askai.core.support.text_formatter import text_formatter as tf
 from askai.exception.exceptions import InaccurateResponse
 from clitt.core.term.cursor import cursor
@@ -52,7 +51,7 @@ def display(self, text: str, force: bool = False) -> None:
         :param force: Force displaying the message regardless of the debug flag.
         """
         if force or is_debugging():
-            text_formatter.console.print(Text.from_markup(text))
+            tf.console.print(Text.from_markup(text))

     def interrupt(self, ev: Event) -> None:
         """Interrupt the active execution pipeline.
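
Note: every hunk above applies the same pattern — the blocking LLM call is wrapped in a Rich `Live` display driven by a `Spinner` rendered on the shared `text_formatter` console, and the spinner line is cleared with clitt's `cursor.erase_line()` once the call returns. Below is a minimal, self-contained sketch of that pattern; the `slow_llm_call` stub and the plain `Console()` are illustrative stand-ins, not code from this repository.

```python
import time

from rich.console import Console
from rich.live import Live
from rich.spinner import Spinner

console = Console()


def slow_llm_call() -> str:
    """Illustrative stand-in for a blocking model invocation (e.g. runnable.invoke or llm.invoke)."""
    time.sleep(2)
    return "model response"


# Render an animated spinner while the blocking call runs; transient=True drops
# the spinner from the screen when the context manager exits.
with Live(Spinner("dots", "[green]Waiting...[/green]", style="green"), console=console, transient=True):
    answer = slow_llm_call()

console.print(answer)
```

Rich can remove the live display automatically via `transient=True`; the patch instead clears the line explicitly with `cursor.erase_line()` after the `with` block.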