Skip to content

Commit

Permalink
Improve RAG mode
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Sep 27, 2024
1 parent 0c0149b commit 43710f7
Show file tree
Hide file tree
Showing 13 changed files with 308 additions and 61 deletions.
30 changes: 30 additions & 0 deletions src/demo/others/spinner_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os

import pause
from askai.core.support.spinner import Spinner
from hspylib.core.tools.commons import sysout
from hspylib.modules.cli.vt100.vt_color import VtColor


def echo(message: str, prefix: str | None = None, end=os.linesep) -> None:
    """Print a message, optionally preceded by a bracketed, right-aligned prefix.
    When `end` is empty the message is rendered in the prefixed, column-aligned
    form (used for in-progress task lines that a spinner later completes);
    otherwise the message is printed verbatim, followed by `end`.
    :param message: The message to print.
    :param prefix: Optional prefix to prepend to the message (truncated to 22 chars).
    :param end: The string appended after the message (default is a newline character).
    """
    # Idiomatic emptiness test: `not end` instead of `not len(end)`.
    sysout(f"%CYAN%[{prefix[:22] if prefix else '':>23}]%NC% {message:<80}" if not end else message, end=end)


if __name__ == "__main__":
    # Example usage of the humanfriendly Spinner
    # Demonstrates wrapping two sequential "tasks" with a spinner: each task
    # prints a prefixed status line (no newline), spins while the work runs
    # (simulated by pause.seconds), then prints a green OK marker.
    with Spinner.DEFAULT.run(suffix="Wait", color=VtColor.CYAN) as spinner:
        echo("Preparing the docs", "TASK", end="")
        spinner.start()
        pause.seconds(5)  # simulated work
        spinner.stop()
        echo("%GREEN%OK%NC%")
        echo("Running the jobs", "TASK", end="")
        spinner.start()
        pause.seconds(5)  # simulated work
        spinner.stop()
        echo("%GREEN%OK%NC%")
10 changes: 7 additions & 3 deletions src/main/askai/core/askai.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ def __init__(
configs.model = model_name

self._session_id = now("%Y%m%d")[:8]
self._engine: AIEngine = shared.create_engine(engine_name, model_name)
self._engine: AIEngine = shared.create_engine(engine_name, model_name, RouterMode.RAG)
self._context: ChatContext = shared.create_context(self._engine.ai_token_limit())
self._mode: RouterMode = shared.mode
self._console_path = Path(f"{CACHE_DIR}/askai-{self.session_id}.md")
self._query_prompt: str | None = None
self._mode: RouterMode = RouterMode.default()

if not self._console_path.exists():
self._console_path.touch()
Expand All @@ -115,7 +115,11 @@ def context(self) -> ChatContext:

@property
def mode(self) -> RouterMode:
return self._mode
return shared.mode

@mode.setter
def mode(self, value: RouterMode):
shared.mode = value

@property
def query_prompt(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions src/main/askai/core/askai_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def _cb_mode_changed_event(self, ev: Event) -> None:
"""Callback to handle mode change events.
:param ev: The event object representing the mode change.
"""
self._mode: RouterMode = RouterMode.of_name(ev.args.mode)
if not self._mode.is_default:
self.mode: RouterMode = RouterMode.of_name(ev.args.mode)
if self.mode == RouterMode.QNA:
sum_msg: str = (
f"{msg.enter_qna()} \n"
f"```\nContext:  {ev.args.sum_path},  {ev.args.glob} \n```\n"
Expand Down
3 changes: 3 additions & 0 deletions src/main/askai/core/askai_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def welcome(self, username: str) -> str:
def wait(self) -> str:
    """Return the transient status message shown while the AI processes a request."""
    return "I'm thinking…"

def loading(self, what: str) -> str:
    """Return the transient status message shown while *what* is being loaded.
    :param what: Short description of the resource being loaded.
    """
    return "Loading " + what + "…"

def welcome_back(self) -> str:
    """Return the greeting shown when resuming an existing session."""
    return "How may I further assist you ?"

Expand Down
2 changes: 1 addition & 1 deletion src/main/askai/core/component/internet_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class InternetService(metaclass=Singleton):
# fmt: off
CATEGORY_ICONS = {
"Weather": "",
"Sports": "",
"Sports": "",
"News": "",
"Celebrities": "",
"People": "",
Expand Down
7 changes: 6 additions & 1 deletion src/main/askai/core/component/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
from askai.core.model.ai_reply import AIReply
from askai.core.model.summary_result import SummaryResult
from askai.core.support.langchain_support import lc_llm
from askai.core.support.spinner import Spinner
from askai.exception.exceptions import DocumentsNotFound
from functools import lru_cache
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.classpath import AnyPath
from hspylib.core.metaclass.singleton import Singleton
from hspylib.core.tools.text_tools import ensure_endswith, hash_text
from hspylib.modules.cli.vt100.vt_color import VtColor
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores.chroma import Chroma
Expand Down Expand Up @@ -116,7 +118,10 @@ def generate(self, folder: AnyPath, glob: str) -> bool:
v_store = Chroma(persist_directory=str(self.persist_dir), embedding_function=embeddings)
else:
log.info("Summarizing documents from '%s'", self.sum_path)
documents: list[Document] = DirectoryLoader(self.folder, glob=self.glob).load()
with Spinner.COLIMA.run(suffix=msg.loading("documents"), color=VtColor.GREEN) as spinner:
spinner.start()
documents: list[Document] = DirectoryLoader(self.folder, glob=self.glob).load()
spinner.stop()
if len(documents) <= 0:
raise DocumentsNotFound(f"Unable to find any document to summarize at: '{self.sum_path}'")
texts: list[Document] = self._text_splitter.split_documents(documents)
Expand Down
19 changes: 13 additions & 6 deletions src/main/askai/core/enums/router_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ class RouterMode(Enumeration):

# fmt: on

TASK_SPLIT = "Task Splitter", splitter
TASK_SPLIT = "Task Splitter", "", splitter

QNA = "Questions and Answers", qna
QNA = "Questions and Answers", "", qna

QSTRING = "Non-Interactive", qstring
QSTRING = "Non-Interactive", "", qstring

RAG = "Retrieval-Augmented Generation", rag
RAG = "Retrieval-Augmented Generation", "ﮐ", rag

# fmt: off

Expand Down Expand Up @@ -63,16 +63,23 @@ def of_name(cls, name: str) -> 'RouterMode':
return cls[name] if name.casefold() != 'default' else cls.default()

def __str__(self):
return self.value[0]
return f"{self.icon} {self.name}"

def __eq__(self, other: object) -> bool:
    """Compare modes by their display name rather than by identity.
    Returns NotImplemented for non-RouterMode operands so Python can fall back
    to the reflected comparison instead of raising AttributeError.
    """
    if not isinstance(other, RouterMode):
        return NotImplemented
    return self.name == other.name

# Defining __eq__ suppresses the inherited __hash__ (making members unhashable);
# restore it consistently with __eq__ so modes keep working as dict keys/set members.
def __hash__(self) -> int:
    return hash(self.name)

@property
def name(self) -> str:
    # Deliberately shadows Enum's built-in `name`: exposes the human-readable
    # display name stored as the first element of the value tuple.
    return self.value[0]

@property
def processor(self) -> AIProcessor:
def icon(self) -> str:
return self.value[1]

@property
def processor(self) -> AIProcessor:
    # Third element of the value tuple: (display name, icon, processor).
    return self.value[2]

@property
def is_default(self) -> bool:
return self == RouterMode.default()
Expand Down
102 changes: 71 additions & 31 deletions src/main/askai/core/features/processors/rag.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
import logging as log
import os
from pathlib import Path

from askai.core.askai_configs import configs
from askai.core.askai_events import events
from askai.core.askai_messages import msg
from askai.core.component.cache_service import RAG_DIR
from askai.core.askai_prompt import prompt
from askai.core.component.cache_service import RAG_DIR, PERSIST_DIR
from askai.core.engine.openai.temperature import Temperature
from askai.core.support.langchain_support import lc_llm
from functools import lru_cache
from askai.core.support.spinner import Spinner
from askai.exception.exceptions import DocumentsNotFound
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.classpath import AnyPath
from hspylib.core.metaclass.singleton import Singleton
from langchain import hub
from hspylib.core.tools.text_tools import hash_text
from hspylib.modules.cli.vt100.vt_color import VtColor
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pathlib import Path


class Rag(metaclass=Singleton):
Expand All @@ -22,49 +31,80 @@ class Rag(metaclass=Singleton):
INSTANCE: "Rag"

def __init__(self):
self._rag_chain = None
self._vectorstore = None
self._rag_chain: Runnable | None = None
self._vectorstore: Chroma | None = None
self._text_splitter = RecursiveCharacterTextSplitter(
chunk_size=configs.chunk_size, chunk_overlap=configs.chunk_overlap
)

@property
def rag_template(self) -> BasePromptTemplate:
    """Assemble the RAG chat prompt: the system prompt read from the
    'langchain/rag-prompt' resource, followed by the user question and
    the retrieved context.
    :return: The chat prompt template used by the RAG chain.
    """
    # Plain string literal (was an f-string with no placeholders).
    prompt_file: PathObject = PathObject.of(prompt.append_path("langchain/rag-prompt"))
    final_prompt: str = prompt.read_prompt(prompt_file.filename, prompt_file.abs_dir)
    # fmt: off
    return ChatPromptTemplate.from_messages([
        ("system", final_prompt),
        ("user", "{question}"),
        ("user", "{context}")
    ])
    # fmt: on

def persist_dir(self, file_glob: AnyPath) -> Path:
    """Resolve the vector-store persistence directory for a given file glob.
    The glob is hashed so each distinct document selection gets its own store.
    :param file_glob: The glob pattern identifying the selected documents.
    :return: The path under PERSIST_DIR dedicated to this selection.
    """
    digest: str = hash_text(file_glob)
    return Path(f"{PERSIST_DIR}/{digest}")

def process(self, question: str, **_) -> str:
"""Process the user question to retrieve the final response.
:param question: The user question to process.
:return: The final response after processing the question.
"""
self._generate()
self.generate()

if question.casefold() == "exit":
events.mode_changed.emit(mode="DEFAULT")
output = msg.leave_rag()
elif not (output := self._rag_chain.invoke(question)):
output = msg.invalid_response(output)

self._vectorstore.delete_collection()

return output

@lru_cache(maxsize=8)
def _generate(self, rag_dir: str | Path = RAG_DIR) -> None:
loader: DirectoryLoader = DirectoryLoader(str(rag_dir))
rag_docs: list[Document] = loader.load()
llm = lc_llm.create_model(temperature=Temperature.DATA_ANALYSIS.temp)
embeddings = lc_llm.create_embeddings()
splits = self._text_splitter.split_documents(rag_docs)

self._vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
retriever = self._vectorstore.as_retriever()
rag_prompt = hub.pull("rlm/rag-prompt")

def _format_docs_(docs):
return "\n\n".join(doc.page_content for doc in docs)

self._rag_chain = (
{"context": retriever | _format_docs_, "question": RunnablePassthrough()}
| rag_prompt
| llm
| StrOutputParser()
)
def generate(self, file_glob: str = "**/*.md") -> None:
    """Generate (or recover) the RAG vector store and build the retrieval chain.
    Lazy: once `_rag_chain` exists, subsequent calls are no-ops. An existing
    persisted Chroma store for this glob is reloaded; otherwise the documents
    are loaded, split and embedded into a new persisted store.
    :param file_glob: The files from which to generate the RAG database.
    :raises DocumentsNotFound: If no document matches the glob under RAG_DIR.
    """
    if not self._rag_chain:
        embeddings = lc_llm.create_embeddings()
        llm = lc_llm.create_chat_model(temperature=Temperature.DATA_ANALYSIS.temp)
        persist_dir: str = str(self.persist_dir(file_glob))
        if os.path.exists(persist_dir):
            log.info("Recovering vector store from: '%s'", persist_dir)
            self._vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        else:
            with Spinner.COLIMA.run(suffix=msg.loading("documents"), color=VtColor.GREEN) as spinner:
                spinner.start()
                rag_docs: list[Document] = DirectoryLoader(str(RAG_DIR), glob=file_glob, recursive=True).load()
                spinner.stop()
            if not rag_docs:
                # Message fixed: previously read "…any document to at: …".
                raise DocumentsNotFound(f"Unable to find any document at: '{persist_dir}'")
            self._vectorstore = Chroma.from_documents(
                persist_directory=persist_dir,
                documents=self._text_splitter.split_documents(rag_docs),
                embedding=embeddings,
            )

        retriever = self._vectorstore.as_retriever()
        rag_prompt = self.rag_template

        def _format_docs_(docs):
            # Join retrieved document contents into a single context string.
            return "\n\n".join(doc.page_content for doc in docs)

        self._rag_chain = (
            {"context": retriever | _format_docs_, "question": RunnablePassthrough()}
            | rag_prompt
            | llm
            | StrOutputParser()
        )
    # No return value: the previous stray `return self._rag_chain` contradicted
    # the `-> None` annotation and was dropped (no caller used it).


assert (rag := Rag().INSTANCE) is not None
Loading

0 comments on commit 43710f7

Please sign in to comment.