Include Spinners for all processors but QNA
yorevs committed Nov 14, 2024
1 parent d93e984 commit 1f0e790
Showing 4 changed files with 76 additions and 43 deletions.
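
The diff applies the same pattern in chat.py, qstring.py, and rag.py: the blocking model call is wrapped in a rich Live context that renders a Spinner until the call returns, and cursor.erase_line() clears the spinner's leftover line afterwards. A minimal, standalone sketch of that pattern follows; slow_call() and the local Console instance are illustrative placeholders, not part of this commit.

from time import sleep

from rich.console import Console
from rich.live import Live
from rich.spinner import Spinner

console = Console()


def slow_call() -> str:
    """Stand-in for a long-running LLM invocation (hypothetical placeholder)."""
    sleep(2)
    return "done"


# Render an animated "dots" spinner while the blocking call runs,
# mirroring the wrapper added to each processor in this commit.
with Live(Spinner("dots", "[green]Waiting...[/green]", style="green"), console=console):
    result = slow_call()

print(result)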
63 changes: 38 additions & 25 deletions src/main/askai/core/processors/chat.py
@@ -12,20 +12,26 @@
Copyright 2024, HSPyLib team
"""
from typing import Any, Optional

from clitt.core.term.cursor import cursor
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.singleton import Singleton
from hspylib.core.tools.dict_tools import get_or_default_by_key
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables import RunnableWithMessageHistory
from rich.live import Live
from rich.spinner import Spinner

from askai.core.askai_events import events
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.engine.openai.temperature import Temperature
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
from askai.core.support.shared_instances import shared
from askai.core.support.text_formatter import text_formatter as tf
from askai.exception.exceptions import TerminatingQuery
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.singleton import Singleton
from hspylib.core.tools.dict_tools import get_or_default_by_key
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables import RunnableWithMessageHistory
from typing import Any, Optional


class ChatProcessor(metaclass=Singleton):
@@ -59,25 +65,32 @@ def process(self, question: str, **kwargs) -> Optional[str]:
events.mode_changed.emit(mode="DEFAULT")
return None

response = None
prompt_file: str = get_or_default_by_key(kwargs, "prompt_file", None)
history_ctx: Any | None = get_or_default_by_key(kwargs, "history_ctx", "HISTORY")
ctx: str = get_or_default_by_key(kwargs, "context", "")
inputs: list[str] = get_or_default_by_key(kwargs, "inputs", [])
args: dict[str, Any] = get_or_default_by_key(kwargs, "args", {})
inputs = inputs or ["user", "idiom", "context", "question"]
args = args or {"user": prompt.user.title(), "idiom": shared.idiom, "context": ctx, "question": question}
prompt_file: PathObject = PathObject.of(prompt_file or prompt.append_path(f"taius/taius-jarvis"))
prompt_str: str = prompt.read_prompt(prompt_file.filename, prompt_file.abs_dir)

template = self.template(prompt_str, *inputs, **args)
runnable = template | lc_llm.create_chat_model(Temperature.COLDEST.temp)
runnable = RunnableWithMessageHistory(
runnable, shared.context.flat, input_messages_key="input", history_messages_key="chat_history"
)

if output := runnable.invoke({"input": question}, config={"configurable": {"session_id": history_ctx or ""}}):
response = output.content
with Live(
Spinner("dots", f"[green]{msg.wait()}[/green]", style="green"), console=tf.console
):
response = None
prompt_file: str = get_or_default_by_key(kwargs, "prompt_file", None)
history_ctx: Any | None = get_or_default_by_key(kwargs, "history_ctx", "HISTORY")
ctx: str = get_or_default_by_key(kwargs, "context", "")
inputs: list[str] = get_or_default_by_key(kwargs, "inputs", [])
args: dict[str, Any] = get_or_default_by_key(kwargs, "args", {})
inputs = inputs or ["user", "idiom", "context", "question"]
args = args or {"user": prompt.user.title(), "idiom": shared.idiom, "context": ctx, "question": question}
prompt_file: PathObject = PathObject.of(prompt_file or prompt.append_path(f"taius/taius-jarvis"))
prompt_str: str = prompt.read_prompt(prompt_file.filename, prompt_file.abs_dir)

template = self.template(prompt_str, *inputs, **args)
runnable = template | lc_llm.create_chat_model(Temperature.COLDEST.temp)
runnable = RunnableWithMessageHistory(
runnable, shared.context.flat, input_messages_key="input", history_messages_key="chat_history"
)

if output := runnable.invoke(
input={"input": question},
config={"configurable": {"session_id": history_ctx or ""}}):
response = output.content

cursor.erase_line()

return response

39 changes: 25 additions & 14 deletions src/main/askai/core/processors/qstring.py
@@ -12,11 +12,17 @@
Copyright (c) 2024, HomeSetup
"""
from clitt.core.term.cursor import cursor
from rich.live import Live
from rich.spinner import Spinner

from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.component.cache_service import cache
from askai.core.engine.openai.temperature import Temperature
from askai.core.support.langchain_support import lc_llm
from askai.core.support.utilities import find_file
from askai.core.support.text_formatter import text_formatter as tf
from askai.exception.exceptions import TerminatingQuery
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.singleton import Singleton
@@ -45,20 +51,25 @@ def process(self, question: str, **kwargs) -> Optional[str]:
if question.casefold() in ["exit", "leave", "quit", "q"]:
return None

output = None
query_prompt: str | None = find_file(kwargs["query_prompt"]) if "query_prompt" in kwargs else None
context: str | None = kwargs["context"] if "context" in kwargs else None
temperature: int = kwargs["temperature"] if "temperature" in kwargs else None

dir_name, file_name = PathObject.split(query_prompt or self.DEFAULT_PROMPT)
template = PromptTemplate(
input_variables=["context", "question"], template=prompt.read_prompt(file_name, dir_name)
)
final_prompt: str = template.format(context=context or self.DEFAULT_CONTEXT, question=question)
llm = lc_llm.create_chat_model(temperature or self.DEFAULT_TEMPERATURE)

if (response := llm.invoke(final_prompt)) and (output := response.content):
cache.save_input_history()
with Live(
Spinner("dots", f"[green]{msg.wait()}[/green]", style="green"), console=tf.console
):
output = None
query_prompt: str | None = find_file(kwargs["query_prompt"]) if "query_prompt" in kwargs else None
context: str | None = kwargs["context"] if "context" in kwargs else None
temperature: int = kwargs["temperature"] if "temperature" in kwargs else None

dir_name, file_name = PathObject.split(query_prompt or self.DEFAULT_PROMPT)
template = PromptTemplate(
input_variables=["context", "question"], template=prompt.read_prompt(file_name, dir_name)
)
final_prompt: str = template.format(context=context or self.DEFAULT_CONTEXT, question=question)
llm = lc_llm.create_chat_model(temperature or self.DEFAULT_TEMPERATURE)

if (response := llm.invoke(final_prompt)) and (output := response.content):
cache.save_input_history()

cursor.erase_line()

return output

14 changes: 12 additions & 2 deletions src/main/askai/core/processors/rag.py
@@ -12,6 +12,10 @@
Copyright (c) 2024, HomeSetup
"""
from clitt.core.term.cursor import cursor
from rich.live import Live
from rich.spinner import Spinner

from askai.core.askai_configs import configs
from askai.core.askai_events import events
from askai.core.askai_messages import msg
@@ -22,6 +26,7 @@
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
from askai.exception.exceptions import DocumentsNotFound, TerminatingQuery
from askai.core.support.text_formatter import text_formatter as tf
from functools import lru_cache
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.classpath import AnyPath
@@ -88,8 +93,13 @@ def process(self, question: str, **_) -> Optional[str]:
# FIXME Include kwargs to specify rag dir and glob
self.generate()

if not (output := self._rag_chain.invoke(question)):
output = msg.invalid_response(output)
with Live(
Spinner("dots", f"[green]{msg.wait()}[/green]", style="green"), console=tf.console
):
if not (output := self._rag_chain.invoke(question)):
output = msg.invalid_response(output)

cursor.erase_line()

return output

3 changes: 1 addition & 2 deletions src/main/askai/core/processors/splitter/splitter_executor.py
@@ -18,7 +18,6 @@
from askai.core.enums.acc_color import AccColor
from askai.core.processors.splitter.splitter_pipeline import SplitterPipeline
from askai.core.processors.splitter.splitter_states import States
from askai.core.support.text_formatter import text_formatter
from askai.core.support.text_formatter import text_formatter as tf
from askai.exception.exceptions import InaccurateResponse
from clitt.core.term.cursor import cursor
@@ -52,7 +51,7 @@ def display(self, text: str, force: bool = False) -> None:
:param force: Force displaying the message regardless of the debug flag.
"""
if force or is_debugging():
text_formatter.console.print(Text.from_markup(text))
tf.console.print(Text.from_markup(text))

def interrupt(self, ev: Event) -> None:
"""Interrupt the active execution pipeline.
