From 43710f7818b2f7a30a58a5b623ee567c9f087122 Mon Sep 17 00:00:00 2001 From: Hugo Saporetti Junior Date: Fri, 27 Sep 2024 15:03:59 -0300 Subject: [PATCH] Improve RAG mode --- src/demo/others/spinner_demo.py | 30 +++++ src/main/askai/core/askai.py | 10 +- src/main/askai/core/askai_cli.py | 4 +- src/main/askai/core/askai_messages.py | 3 + .../askai/core/component/internet_service.py | 2 +- src/main/askai/core/component/summarizer.py | 7 +- src/main/askai/core/enums/router_mode.py | 19 ++- .../askai/core/features/processors/rag.py | 102 +++++++++----- .../askai/core/support/shared_instances.py | 43 ++++-- src/main/askai/core/support/spinner.py | 124 ++++++++++++++++++ src/main/askai/core/support/text_formatter.py | 2 +- .../prompts/langchain/rag-prompt.txt | 19 +++ src/main/askai/tui/app_icons.py | 4 +- 13 files changed, 308 insertions(+), 61 deletions(-) create mode 100644 src/demo/others/spinner_demo.py create mode 100644 src/main/askai/core/support/spinner.py create mode 100644 src/main/askai/resources/prompts/langchain/rag-prompt.txt diff --git a/src/demo/others/spinner_demo.py b/src/demo/others/spinner_demo.py new file mode 100644 index 00000000..1241ad37 --- /dev/null +++ b/src/demo/others/spinner_demo.py @@ -0,0 +1,30 @@ +import os + +import pause +from askai.core.support.spinner import Spinner +from hspylib.core.tools.commons import sysout +from hspylib.modules.cli.vt100.vt_color import VtColor + + +def echo(message: str, prefix: str | None = None, end=os.linesep) -> None: + """Prints a message with a prefix followed by the specified end character. + :param message: The message to print. + :param prefix: Optional prefix to prepend to the message. + :param end: The string appended after the message (default is a newline character). + """ + sysout(f"%CYAN%[{prefix[:22] if prefix else '':>23}]%NC% {message:<80}" if not len(end) else message, end=end) + + +if __name__ == "__main__": + # Example usage of the humanfriendly Spinner + with Spinner.DEFAULT.run(suffix="Wait", color=VtColor.CYAN) as spinner: + echo("Preparing the docs", "TASK", end="") + spinner.start() + pause.seconds(5) + spinner.stop() + echo("%GREEN%OK%NC%") + echo("Running the jobs", "TASK", end="") + spinner.start() + pause.seconds(5) + spinner.stop() + echo("%GREEN%OK%NC%") diff --git a/src/main/askai/core/askai.py b/src/main/askai/core/askai.py index c7c4c552..af2d2596 100644 --- a/src/main/askai/core/askai.py +++ b/src/main/askai/core/askai.py @@ -93,11 +93,11 @@ def __init__( configs.model = model_name self._session_id = now("%Y%m%d")[:8] - self._engine: AIEngine = shared.create_engine(engine_name, model_name) + self._engine: AIEngine = shared.create_engine(engine_name, model_name, RouterMode.RAG) self._context: ChatContext = shared.create_context(self._engine.ai_token_limit()) + self._mode: RouterMode = shared.mode self._console_path = Path(f"{CACHE_DIR}/askai-{self.session_id}.md") self._query_prompt: str | None = None - self._mode: RouterMode = RouterMode.default() if not self._console_path.exists(): self._console_path.touch() @@ -115,7 +115,11 @@ def context(self) -> ChatContext: @property def mode(self) -> RouterMode: - return self._mode + return shared.mode + + @mode.setter + def mode(self, value: RouterMode): + shared.mode = value @property def query_prompt(self) -> str: diff --git a/src/main/askai/core/askai_cli.py b/src/main/askai/core/askai_cli.py index 4e282f3f..8dc80406 100644 --- a/src/main/askai/core/askai_cli.py +++ b/src/main/askai/core/askai_cli.py @@ -136,8 +136,8 @@ def _cb_mode_changed_event(self, ev: Event) -> None: """Callback to handle mode change events. :param ev: The event object representing the mode change. """ - self._mode: RouterMode = RouterMode.of_name(ev.args.mode) - if not self._mode.is_default: + self.mode: RouterMode = RouterMode.of_name(ev.args.mode) + if self.mode == RouterMode.QNA: sum_msg: str = ( f"{msg.enter_qna()} \n" f"```\nContext:  {ev.args.sum_path},  {ev.args.glob} \n```\n" diff --git a/src/main/askai/core/askai_messages.py b/src/main/askai/core/askai_messages.py index b344f0bf..95cd9c7a 100644 --- a/src/main/askai/core/askai_messages.py +++ b/src/main/askai/core/askai_messages.py @@ -76,6 +76,9 @@ def welcome(self, username: str) -> str: def wait(self) -> str: return "I'm thinking…" + def loading(self, what: str) -> str: + return f"Loading {what}…" + def welcome_back(self) -> str: return "How may I further assist you ?" diff --git a/src/main/askai/core/component/internet_service.py b/src/main/askai/core/component/internet_service.py index 5223926f..1c866683 100644 --- a/src/main/askai/core/component/internet_service.py +++ b/src/main/askai/core/component/internet_service.py @@ -58,7 +58,7 @@ class InternetService(metaclass=Singleton): # fmt: off CATEGORY_ICONS = { "Weather": "", - "Sports": "", + "Sports": "醴", "News": "", "Celebrities": "", "People": "", diff --git a/src/main/askai/core/component/summarizer.py b/src/main/askai/core/component/summarizer.py index 833b4a23..d9bf6ccc 100644 --- a/src/main/askai/core/component/summarizer.py +++ b/src/main/askai/core/component/summarizer.py @@ -19,12 +19,14 @@ from askai.core.model.ai_reply import AIReply from askai.core.model.summary_result import SummaryResult from askai.core.support.langchain_support import lc_llm +from askai.core.support.spinner import Spinner from askai.exception.exceptions import DocumentsNotFound from functools import lru_cache from hspylib.core.config.path_object import PathObject from hspylib.core.metaclass.classpath import AnyPath from hspylib.core.metaclass.singleton import Singleton from hspylib.core.tools.text_tools import ensure_endswith, hash_text +from hspylib.modules.cli.vt100.vt_color import VtColor from langchain.chains import RetrievalQA from langchain_community.document_loaders import DirectoryLoader from langchain_community.vectorstores.chroma import Chroma @@ -116,7 +118,10 @@ def generate(self, folder: AnyPath, glob: str) -> bool: v_store = Chroma(persist_directory=str(self.persist_dir), embedding_function=embeddings) else: log.info("Summarizing documents from '%s'", self.sum_path) - documents: list[Document] = DirectoryLoader(self.folder, glob=self.glob).load() + with Spinner.COLIMA.run(suffix=msg.loading("documents"), color=VtColor.GREEN) as spinner: + spinner.start() + documents: list[Document] = DirectoryLoader(self.folder, glob=self.glob).load() + spinner.stop() if len(documents) <= 0: raise DocumentsNotFound(f"Unable to find any document to summarize at: '{self.sum_path}'") texts: list[Document] = self._text_splitter.split_documents(documents) diff --git a/src/main/askai/core/enums/router_mode.py b/src/main/askai/core/enums/router_mode.py index ebd0698f..b01c9995 100644 --- a/src/main/askai/core/enums/router_mode.py +++ b/src/main/askai/core/enums/router_mode.py @@ -29,13 +29,13 @@ class RouterMode(Enumeration): # fmt: on - TASK_SPLIT = "Task Splitter", splitter + TASK_SPLIT = "Task Splitter", "", splitter - QNA = "Questions and Answers", qna + QNA = "Questions and Answers", "", qna - QSTRING = "Non-Interactive", qstring + QSTRING = "Non-Interactive", "", qstring - RAG = "Retrieval-Augmented Generation", rag + RAG = "Retrieval-Augmented Generation", "ﮐ", rag # fmt: off @@ -63,16 +63,23 @@ def of_name(cls, name: str) -> 'RouterMode': return cls[name] if name.casefold() != 'default' else cls.default() def __str__(self): - return self.value[0] + return f"{self.icon} {self.name}" + + def __eq__(self, other: 'RouterMode') -> bool: + return self.name == other.name @property def name(self) -> str: return self.value[0] @property - def processor(self) -> AIProcessor: + def icon(self) -> str: return self.value[1] + @property + def processor(self) -> AIProcessor: + return self.value[2] + @property def is_default(self) -> bool: return self == RouterMode.default() diff --git a/src/main/askai/core/features/processors/rag.py b/src/main/askai/core/features/processors/rag.py index 1fcfcd71..aa4fd292 100644 --- a/src/main/askai/core/features/processors/rag.py +++ b/src/main/askai/core/features/processors/rag.py @@ -1,19 +1,28 @@ +import logging as log +import os +from pathlib import Path + from askai.core.askai_configs import configs from askai.core.askai_events import events from askai.core.askai_messages import msg -from askai.core.component.cache_service import RAG_DIR +from askai.core.askai_prompt import prompt +from askai.core.component.cache_service import RAG_DIR, PERSIST_DIR from askai.core.engine.openai.temperature import Temperature from askai.core.support.langchain_support import lc_llm -from functools import lru_cache +from askai.core.support.spinner import Spinner +from askai.exception.exceptions import DocumentsNotFound +from hspylib.core.config.path_object import PathObject +from hspylib.core.metaclass.classpath import AnyPath from hspylib.core.metaclass.singleton import Singleton -from langchain import hub +from hspylib.core.tools.text_tools import hash_text +from hspylib.modules.cli.vt100.vt_color import VtColor from langchain_community.document_loaders import DirectoryLoader from langchain_community.vectorstores import Chroma from langchain_core.documents import Document from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import RunnablePassthrough +from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate +from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_text_splitters import RecursiveCharacterTextSplitter -from pathlib import Path class Rag(metaclass=Singleton): @@ -22,17 +31,34 @@ class Rag(metaclass=Singleton): INSTANCE: "Rag" def __init__(self): - self._rag_chain = None - self._vectorstore = None + self._rag_chain: Runnable | None = None + self._vectorstore: Chroma | None = None self._text_splitter = RecursiveCharacterTextSplitter( chunk_size=configs.chunk_size, chunk_overlap=configs.chunk_overlap ) + @property + def rag_template(self) -> BasePromptTemplate: + prompt_file: PathObject = PathObject.of(prompt.append_path(f"langchain/rag-prompt")) + final_prompt: str = prompt.read_prompt(prompt_file.filename, prompt_file.abs_dir) + # fmt: off + return ChatPromptTemplate.from_messages([ + ("system", final_prompt), + ("user", "{question}"), + ("user", "{context}") + ]) + # fmt: on + + def persist_dir(self, file_glob: AnyPath) -> Path: + summary_hash = hash_text(file_glob) + return Path(f"{PERSIST_DIR}/{summary_hash}") + def process(self, question: str, **_) -> str: """Process the user question to retrieve the final response. :param question: The user question to process. + :return: The final response after processing the question. """ - self._generate() + self.generate() if question.casefold() == "exit": events.mode_changed.emit(mode="DEFAULT") @@ -40,31 +66,45 @@ def process(self, question: str, **_) -> str: elif not (output := self._rag_chain.invoke(question)): output = msg.invalid_response(output) - self._vectorstore.delete_collection() - return output - @lru_cache(maxsize=8) - def _generate(self, rag_dir: str | Path = RAG_DIR) -> None: - loader: DirectoryLoader = DirectoryLoader(str(rag_dir)) - rag_docs: list[Document] = loader.load() - llm = lc_llm.create_model(temperature=Temperature.DATA_ANALYSIS.temp) - embeddings = lc_llm.create_embeddings() - splits = self._text_splitter.split_documents(rag_docs) - - self._vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings) - retriever = self._vectorstore.as_retriever() - rag_prompt = hub.pull("rlm/rag-prompt") - - def _format_docs_(docs): - return "\n\n".join(doc.page_content for doc in docs) - - self._rag_chain = ( - {"context": retriever | _format_docs_, "question": RunnablePassthrough()} - | rag_prompt - | llm - | StrOutputParser() - ) + def generate(self, file_glob: str = "**/*.md") -> None: + """Generates RAG data from the specified directory. + :param file_glob: The files from which to generate the RAG database. + """ + if not self._rag_chain: + embeddings = lc_llm.create_embeddings() + llm = lc_llm.create_chat_model(temperature=Temperature.DATA_ANALYSIS.temp) + persist_dir: str = str(self.persist_dir(file_glob)) + if os.path.exists(persist_dir): + log.info("Recovering vector store from: '%s'", persist_dir) + self._vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embeddings) + else: + with Spinner.COLIMA.run(suffix=msg.loading("documents"), color=VtColor.GREEN) as spinner: + spinner.start() + rag_docs: list[Document] = DirectoryLoader(str(RAG_DIR), glob=file_glob, recursive=True).load() + spinner.stop() + if len(rag_docs) <= 0: + raise DocumentsNotFound(f"Unable to find any document to at: '{persist_dir}'") + self._vectorstore = Chroma.from_documents( + persist_directory=persist_dir, + documents=self._text_splitter.split_documents(rag_docs), + embedding=embeddings, + ) + + retriever = self._vectorstore.as_retriever() + rag_prompt = self.rag_template + + def _format_docs_(docs): + return "\n\n".join(doc.page_content for doc in docs) + + self._rag_chain = ( + {"context": retriever | _format_docs_, "question": RunnablePassthrough()} + | rag_prompt + | llm + | StrOutputParser() + ) + return self._rag_chain assert (rag := Rag().INSTANCE) is not None diff --git a/src/main/askai/core/support/shared_instances.py b/src/main/askai/core/support/shared_instances.py index 8018873a..642f5437 100644 --- a/src/main/askai/core/support/shared_instances.py +++ b/src/main/askai/core/support/shared_instances.py @@ -13,6 +13,10 @@ Copyright (c) 2024, HomeSetup """ +import os +from pathlib import Path +from typing import Optional, Any + from askai.__classpath__ import classpath from askai.core.askai_configs import configs from askai.core.askai_messages import msg @@ -33,10 +37,6 @@ from hspylib.modules.cli.keyboard import Keyboard from langchain.memory import ConversationBufferWindowMemory from langchain.memory.chat_memory import BaseChatMemory -from pathlib import Path -from typing import Optional - -import os class SharedInstances(metaclass=Singleton): @@ -48,12 +48,22 @@ class SharedInstances(metaclass=Singleton): UNCERTAIN_ID: str = "bde6f44d-c1a0-4b0c-bd74-8278e468e50c" def __init__(self) -> None: - self._engine: AIEngine | None = None self._context: ChatContext | None = None + self._engine: AIEngine | None = None + self._mode: Any | None = None self._memory: ConversationBufferWindowMemory | None = None self._idiom: str = configs.language.idiom self._max_iteractions: int = configs.max_iteractions + @property + def context(self) -> Optional[ChatContext]: + return self._context + + @context.setter + def context(self, value: ChatContext) -> None: + check_state(self._context is None, "Once set, this instance is immutable.") + self._context = value + @property def engine(self) -> Optional[AIEngine]: return self._engine @@ -64,17 +74,19 @@ def engine(self, value: AIEngine) -> None: self._engine = value @property - def context(self) -> Optional[ChatContext]: - return self._context + def mode(self) -> Any: + return self._mode - @context.setter - def context(self, value: ChatContext) -> None: - check_state(self._context is None, "Once set, this instance is immutable.") - self._context = value + @mode.setter + def mode(self, value: Any) -> None: + self._mode = value + + def mode_icon(self) -> str: + return self._mode.icon @property def nickname(self) -> str: - return f"%GREEN% Taius:%NC% " + return f"%GREEN%{self.mode.icon} Taius:%NC% " @property def username(self) -> str: @@ -82,7 +94,7 @@ def username(self) -> str: @property def nickname_md(self) -> str: - return f"* Taius:* " + return f"*{self.mode.icon} Taius:* " @property def username_md(self) -> str: @@ -114,6 +126,7 @@ def app_info(self) -> str: f"{dtm.center(80, '=')} %EOL%" f" Language: {configs.language} {translator} %EOL%" f" Engine: {shared.engine} %EOL%" + f" Mode: {self.mode} %EOL%" f" Dir: {cur_dir} %EOL%" f" OS: {prompt.os_type}/{prompt.shell} %EOL%" f"{'-' * 80} %EOL%" @@ -124,14 +137,16 @@ def app_info(self) -> str: f"{'=' * 80}%EOL%%NC%" ) - def create_engine(self, engine_name: str, model_name: str) -> AIEngine: + def create_engine(self, engine_name: str, model_name: str, mode: Any) -> AIEngine: """Create or retrieve an AI engine instance based on the specified engine and model names. :param engine_name: The name of the AI engine to create or retrieve. :param model_name: The name of the model to use with the AI engine. + :param mode: The engine routing mode. :return: An instance of the AIEngine configured with the specified engine and model. """ if self._engine is None: self._engine = EngineFactory.create_engine(engine_name, model_name) + self._mode = mode return self._engine def create_context( diff --git a/src/main/askai/core/support/spinner.py b/src/main/askai/core/support/spinner.py new file mode 100644 index 00000000..bc6f3eef --- /dev/null +++ b/src/main/askai/core/support/spinner.py @@ -0,0 +1,124 @@ +import contextlib +import itertools +import threading +import time +from threading import Thread + +import pause +from clitt.core.term.cursor import cursor +from clitt.core.term.terminal import Terminal +from hspylib.core.enums.enumeration import Enumeration +from hspylib.core.tools.commons import to_bool, sysout +from hspylib.modules.cli.vt100.vt_color import VtColor +from rich.console import Console + + +class Spinner(Enumeration): + """TODO""" + + # fmt: off + + DEFAULT = 200, ["●○○", "●●○", "●●●", "○●●", "○○●"] + + ARROW_BAR = 220, ["▹▹▹", "▸▹▹", "▸▸▹", "▸▸▸"] + + BAR = 50, ["▁", "▃", "▄", "▅", "▆", "▇", "█", "▇", "▆", "▅", "▄", "▃"] + + BOUNCE = 220, ['⠁', '⠂', '⠄', '⠂'] + + BOX = 220, ["▖", "▘", "▝", "▗"] + + BULLET = 250, ["○", "●"] + + CIRCLE = 220, ["◐", "◓", "◑", "◒"] + + COLIMA = 70, ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + + HEXAGON = 250, ["⬢", "⬡"] + + LINE_GROW = 60, [" ", "= ", "== ", "=== ", " ===", " ==", " ="] + + LINE_BOUNCE = 60, [" ", "= ", "== ", "=== ", "====", "=== ", "== ", "= "] + + STAR = 180, ["✶", "✸", "✹", "✺", "✹", "✷"] + + # fmt: on + + def __init__(self, interval: int, symbols: list[str]): + self._started: bool = False + self._worker: Thread | None = None + + @property + def started(self) -> bool: + return self._started + + @started.setter + def started(self, value: bool): + self._started = to_bool(value) + + @contextlib.contextmanager + def run( + self, + interval: int = None, + prefix: str = None, + suffix: str = None, + color: VtColor = VtColor.WHITE + ) -> None: + """TODO""" + spinner = itertools.cycle(self.symbols) + + def _work_(): + try: + Terminal.set_show_cursor(False) + while threading.main_thread().is_alive(): + pause.milliseconds(interval if interval else self.interval) + while not self.started and threading.main_thread().is_alive(): + pause.milliseconds(interval if interval else self.interval) + cursor.write(color.placeholder) + smb: str = next(spinner) + cursor.write(f"{prefix + ' 'if prefix else ''}{smb}{' ' + suffix if suffix else ''}%NC%%ED0%") + cursor.restore() + except InterruptedError: + pass + Terminal.set_show_cursor() + + self._worker = Thread(daemon=True, target=_work_) + self._worker.start() + yield self + + def start(self) -> None: + """TODO""" + self.started = True + cursor.save() + + def stop(self) -> None: + """TODO""" + self.started = False + sysout(f"%CUB({len(self.symbols[0])})%%EL0%", end="") + + def wait(self, timeout: int = None) -> None: + """TODO""" + self._worker.join(timeout) + + @property + def interval(self) -> int: + return self.value[0] + + @property + def symbols(self) -> list[str]: + return self.value[1] + + +if __name__ == '__main__': + # Initialize the console + console = Console() + + # Example usage of a spinner + with console.status("[bold green]Working on tasks...", spinner="bouncingBar") as status: + for i in range(3): + time.sleep(2) # Simulate work + # Update the status to control spinner message + status.update(f"[bold green]Task {i + 1}/3 in progress...") + # Print task progress on a new line + console.print(f"[bold green]√ Task {i + 1}/3 is done") + console.print("[bold green] All tasks completed.") diff --git a/src/main/askai/core/support/text_formatter.py b/src/main/askai/core/support/text_formatter.py index 062d759d..e2ad66bb 100644 --- a/src/main/askai/core/support/text_formatter.py +++ b/src/main/askai/core/support/text_formatter.py @@ -40,7 +40,7 @@ class TextFormatter(metaclass=Singleton): RE_MD_CODE_BLOCK = r"(```.+```)" CHAT_ICONS = { - "": ">  Oops: ", + "": " Oops! %NC%", "": "\n>  *Tip:* ", "": "\n>  *Analysis:* ", "": "\n>  *Summary:* ", diff --git a/src/main/askai/resources/prompts/langchain/rag-prompt.txt b/src/main/askai/resources/prompts/langchain/rag-prompt.txt new file mode 100644 index 00000000..6e5ccbf8 --- /dev/null +++ b/src/main/askai/resources/prompts/langchain/rag-prompt.txt @@ -0,0 +1,19 @@ +You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. + +Keep the answer concise. + +Give examples. + +Don't express your opinion. + +Be helpful. + +Mention where in the docs to find out more information. + +If you don't know the answer, just say that you don't know. + +Question: {question} + +Context: {context} + +Answer: diff --git a/src/main/askai/tui/app_icons.py b/src/main/askai/tui/app_icons.py index ec695736..1093b25d 100644 --- a/src/main/askai/tui/app_icons.py +++ b/src/main/askai/tui/app_icons.py @@ -19,7 +19,7 @@ class AppIcons(Enumeration): """Enumerated icons of the new AskAI UI application.""" - # icons:                      鬒  穀           + # icons:                      鬒  穀          DEFAULT = "" STARTED = "" @@ -29,7 +29,7 @@ class AppIcons(Enumeration): HELP = "" SETTINGS = "" INFO = "" - CONSOLE = "" + CONSOLE = "" DEBUG_ON = "" DEBUG_OFF = "" SPEAKING_ON = "墳"