Skip to content

Commit

Permalink
Improve rag processor and adding rag hash to re-summarize if folder f…
Browse files Browse the repository at this point in the history
…iles changes
  • Loading branch information
yorevs committed Oct 8, 2024
1 parent 0940d72 commit ef79f0c
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 90 deletions.
44 changes: 23 additions & 21 deletions src/main/askai/core/askai_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
Copyright (c) 2024, HomeSetup
"""
from hspylib.core.metaclass.classpath import AnyPath

from askai.core.askai_configs import configs
from askai.language.ai_translator import AITranslator
from askai.language.language import Language
Expand Down Expand Up @@ -70,7 +72,7 @@ def t(self, text: AnyStr) -> str:

# Informational

def welcome(self, username: str) -> str:
def welcome(self, username: AnyStr) -> str:
return f"Welcome back {username}, How can I assist you today ?"

def wait(self) -> str:
Expand All @@ -91,7 +93,7 @@ def goodbye(self) -> str:
def smile(self, countdown: int) -> str:
return f"\nSmile {str(countdown)} "

def cmd_success(self, command_line: str) -> str:
def cmd_success(self, command_line: AnyStr) -> str:
return f"OK, command `{command_line}` succeeded"

def searching(self) -> str:
Expand All @@ -100,10 +102,10 @@ def searching(self) -> str:
def scrapping(self) -> str:
return f"Scrapping web site…"

def summarizing(self, path: str | None) -> str:
return f"Summarizing documents'{'at: ' + path if path else ''}'…"
def summarizing(self, path: AnyPath | None = None) -> str:
return f"Summarizing documents{' at: ' + str(path) if path else ''}…"

def summary_succeeded(self, path: str, glob: str) -> str:
def summary_succeeded(self, path: AnyPath, glob: str) -> str:
return f"Summarization of docs at: **{path}/{glob}** succeeded !"

def enter_qna(self) -> str:
Expand All @@ -130,24 +132,24 @@ def leave_chat(self) -> str:
def press_esc_enter(self) -> str:
return "Type [exit] to exit Q & A mode"

def device_switch(self, device_info: str) -> str:
def device_switch(self, device_info: AnyStr) -> str:
return f"\nSwitching to Audio Input device: `{device_info}`\n"

# Debug messages

def photo_captured(self, photo: str) -> str:
def photo_captured(self, photo: AnyStr) -> str:
return f"~~[DEBUG]~~ WebCam photo captured: `{photo}`"

def screenshot_saved(self, screenshot: str) -> str:
def screenshot_saved(self, screenshot: AnyStr) -> str:
return f"~~[DEBUG]~~ Screenshot saved: `{screenshot}`"

def executing(self, command_line: str) -> str:
def executing(self, command_line: AnyStr) -> str:
return f"~~[DEBUG]~~ Executing: `{command_line}`…"

def analysis(self, result: str) -> str:
def analysis(self, result: AnyStr) -> str:
return f"~~[DEBUG]~~ Analysis result => {result}"

def assert_acc(self, status: str, details: str) -> str:
def assert_acc(self, status: AnyStr, details: AnyStr) -> str:
match status.casefold():
case "red":
cl = "%RED%"
Expand All @@ -161,25 +163,25 @@ def assert_acc(self, status: str, details: str) -> str:
cl = ""
return f"~~[DEBUG]~~ Accuracy result => {cl}{status}:%NC% {details}"

def action_plan(self, plan_text: str) -> str:
def action_plan(self, plan_text: AnyStr) -> str:
return f"~~[DEBUG]~~ Action plan > {plan_text}"

def x_reference(self, pathname: str) -> str:
def x_reference(self, pathname: AnyPath) -> str:
return f"~~[DEBUG]~~ Resolving X-References: `{pathname}`…"

def describe_image(self, image_path: str) -> str:
def describe_image(self, image_path: AnyPath) -> str:
return f"~~[DEBUG]~~ Describing image: `{image_path}`…"

def model_select(self, model: str) -> str:
def model_select(self, model: AnyStr) -> str:
return f"~~[DEBUG]~~ Using routing model: `{model}`"

def task(self, task: str) -> str:
def task(self, task: AnyStr) -> str:
return f"~~[DEBUG]~~ > `Task:` {task}"

def final_query(self, query: str) -> str:
def final_query(self, query: AnyStr) -> str:
return f"~~[DEBUG]~~ > Final query: `{query}`"

def refine_answer(self, answer: str) -> str:
def refine_answer(self, answer: AnyStr) -> str:
return f"~~[DEBUG]~~ > Refining answer: `{answer}`"

def no_caption(self) -> str:
Expand All @@ -190,7 +192,7 @@ def no_good_result(self) -> str:

# Warnings and alerts

def no_output(self, source: str) -> str:
def no_output(self, source: AnyStr) -> str:
return f"The {source} didn't produce an output !"

def access_grant(self) -> str:
Expand All @@ -210,10 +212,10 @@ def invalid_response(self, response_text: AnyStr) -> str:
def invalid_command(self, response_text: AnyStr) -> str:
return f"Invalid **AskAI** command => '{response_text}' !"

def cmd_no_exist(self, command: str) -> str:
def cmd_no_exist(self, command: AnyStr) -> str:
return f"Command: `{command}' does not exist !"

def cmd_failed(self, cmd_line: str, error_msg: str) -> str:
def cmd_failed(self, cmd_line: AnyStr, error_msg: AnyStr) -> str:
return f"Command: `{cmd_line}' failed to execute -> {error_msg}!"

def camera_not_open(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/main/askai/core/askai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""
from askai.__classpath__ import classpath
from askai.core.askai_configs import configs
from askai.core.support.platform import get_os, get_shell, get_user, SupportedPlatforms, SupportedShells
from askai.core.support.os_utils import get_os, get_shell, get_user, SupportedPlatforms, SupportedShells
from askai.core.support.utilities import read_resource
from functools import lru_cache
from hspylib.core.metaclass.singleton import Singleton
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.classpath import AnyPath
from hspylib.core.preconditions import check_state
from hspylib.core.tools.commons import file_is_not_empty
from hspylib.core.tools.commons import file_is_not_empty, dirname
from hspylib.core.tools.text_tools import ensure_endswith, hash_text
from langchain_community.document_loaders import CSVLoader
from langchain_community.vectorstores import FAISS
Expand All @@ -43,36 +43,50 @@ class RAGProvider:

RAG_DIR: Path = Path(os.path.join(classpath.resource_path, "rag"))

@staticmethod
def copy_rag(path_name: AnyPath, dest_name: AnyPath | None = None) -> bool:
"""Copy the RAG documents into the AskAI RAG directory.
@classmethod
def copy_rag(cls, path_name: AnyPath, dest_name: AnyPath | None = None, rag_dir: AnyPath = RAG_EXT_DIR) -> bool:
"""Copy the RAG documents into the specified RAG directory.
:param path_name: The path of the RAG documents to copy.
:param dest_name: The destination, within the RAG directory, where the documents will be copied to. If None,
defaults to a hashed directory based on the source path.
:param rag_dir: The directory where RAG documents will be copied.
:return: True if the copy operation was successful, False otherwise.
"""
src_path: PathObject = PathObject.of(path_name)
with open(f"{RAG_EXT_DIR}/rag-documents.txt", "w") as f_docs:
docs: list[str] = list()
if src_path.exists and src_path.is_file:
file: str = f"{RAG_EXT_DIR}/{src_path.filename}"
copyfile(str(src_path), file)
elif src_path.exists and src_path.is_dir:
shutil.copytree(
str(src_path),
str(RAG_EXT_DIR / (dest_name or hash_text(str(src_path))[:8])),
dirs_exist_ok=True,
symlinks=True,
)
else:
return False
files: list[str] = glob.glob(f"{str(RAG_EXT_DIR)}/**/*.*", recursive=True)
list(map(docs.append, files))
f_docs.write("Available documents for RAG:" + os.linesep * 2)
f_docs.writelines(set(ensure_endswith(d, os.linesep) for d in docs))
if src_path.exists and src_path.is_file:
file: str = f"{rag_dir}/{src_path.filename}"
copyfile(str(src_path), file)
elif src_path.exists and src_path.is_dir:
shutil.copytree(
str(src_path),
str(rag_dir / (dest_name or hash_text(str(src_path))[:8])),
dirs_exist_ok=True,
symlinks=True,
)
else:
return False
files: list[str] = sorted(glob.glob(f"{str(rag_dir)}/**/*.*", recursive=True))
rag_files: str = ''.join(list(ensure_endswith(d, os.linesep) for d in files))
rag_docs_file: Path = Path(os.path.join(rag_dir), "rag-documents.txt")
rag_docs_file.write_text(rag_files)

return True

@staticmethod
def requires_update(rag_dir: AnyPath = RAG_EXT_DIR) -> bool:
    """Check whether the RAG directory contents changed since the last check.
    The hash of the 'rag-documents.txt' manifest is compared against the hash stored in the sibling '.rag-hash'
    file; the stored hash is refreshed so that the next call compares against the current state.
    :param rag_dir: The RAG directory that holds the 'rag-documents.txt' manifest.
    :return: True if an update is required (manifest missing, no stored hash yet, or hash mismatch),
             False otherwise.
    """
    rag_docs_file: Path = Path(os.path.join(rag_dir), "rag-documents.txt")
    # Without a manifest there is nothing to hash. Check this BEFORE reading the file, otherwise
    # read_text() raises FileNotFoundError and the "missing manifest" case is never reported.
    if not rag_docs_file.exists():
        return True
    rag_hash_file: Path = Path(os.path.join(dirname(str(rag_docs_file)), ".rag-hash"))
    files_hash: str = hash_text(rag_docs_file.read_text())
    # No stored hash yet means this is the first check for this directory.
    previous_hash: str | None = rag_hash_file.read_text() if rag_hash_file.exists() else None
    if files_hash != previous_hash:
        rag_hash_file.write_text(files_hash)
        return True
    return False

def __init__(self, rag_filepath: str):
self._rag_db = None
self._rag_path: str = os.path.join(str(self.RAG_DIR), rag_filepath)
Expand Down
11 changes: 6 additions & 5 deletions src/main/askai/core/component/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,20 @@ def generate(self, folder: AnyPath, glob: str) -> bool:
self._glob: str = glob.strip()
events.reply.emit(reply=AIReply.info(msg.summarizing(self.sum_path)))
embeddings: Embeddings = lc_llm.create_embeddings()
v_store: Chroma | None = None

try:
if self.persist_dir.exists():
log.info("Recovering vector store from: '%s'", self.persist_dir)
v_store = Chroma(persist_directory=str(self.persist_dir), embedding_function=embeddings)
else:
log.info("Summarizing documents from '%s'", self.sum_path)
with Status(f'[green]{msg.summarizing()}[/green]'):
with Status(f'[green]{msg.summarizing(self.folder)}[/green]'):
documents: list[Document] = DirectoryLoader(self.folder, glob=self.glob).load()
if len(documents) <= 0:
raise DocumentsNotFound(f"Unable to find any document to summarize at: '{self.sum_path}'")
texts: list[Document] = self._text_splitter.split_documents(documents)
v_store = Chroma.from_documents(texts, embeddings, persist_directory=str(self.persist_dir))
if len(documents) <= 0:
raise DocumentsNotFound(f"Unable to find any document to summarize at: '{self.sum_path}'")
texts: list[Document] = self._text_splitter.split_documents(documents)
v_store = Chroma.from_documents(texts, embeddings, persist_directory=str(self.persist_dir))

self._retriever = RetrievalQA.from_chain_type(
llm=lc_llm.create_chat_model(), chain_type="stuff", retriever=v_store.as_retriever()
Expand Down
70 changes: 41 additions & 29 deletions src/main/askai/core/features/processors/rag.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from askai.core.askai_configs import configs
from askai.core.askai_events import events
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.component.cache_service import PERSIST_DIR
from askai.core.engine.openai.temperature import Temperature
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
from askai.core.support.rag_provider import RAG_EXT_DIR
from askai.exception.exceptions import DocumentsNotFound, TerminatingQuery
import logging as log
import os
import shutil
from functools import lru_cache
from pathlib import Path
from typing import Optional

from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.classpath import AnyPath
from hspylib.core.metaclass.singleton import Singleton
Expand All @@ -19,12 +16,18 @@
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pathlib import Path
from rich.status import Status
from typing import Optional

import logging as log
import os
from askai.core.askai_configs import configs
from askai.core.askai_events import events
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.component.cache_service import PERSIST_DIR
from askai.core.component.rag_provider import RAGProvider, RAG_EXT_DIR
from askai.core.engine.openai.temperature import Temperature
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
from askai.exception.exceptions import DocumentsNotFound, TerminatingQuery


class Rag(metaclass=Singleton):
Expand All @@ -51,10 +54,11 @@ def rag_template(self) -> BasePromptTemplate:
])
# fmt: on

def persist_dir(self, file_glob: AnyPath) -> Path:
@lru_cache
def persist_dir(self, rag_dir: AnyPath, file_glob: AnyPath) -> Path:
"""TODO"""
summary_hash = hash_text(file_glob)
return Path(f"{PERSIST_DIR}/{summary_hash}")
summary_hash = hash_text(os.path.join(rag_dir, file_glob))
return Path(os.path.join(str(PERSIST_DIR), summary_hash))

def process(self, question: str, **_) -> Optional[str]:
"""Process the user question to retrieve the final response.
Expand All @@ -68,34 +72,42 @@ def process(self, question: str, **_) -> Optional[str]:
events.mode_changed.emit(mode="DEFAULT")
return None

# FIXME Include kwargs to specify rag dir and glob
self.generate()

if not (output := self._rag_chain.invoke(question)):
output = msg.invalid_response(output)

return output

def generate(self, file_glob: str = "**/*.md") -> None:
def generate(self, rag_dir: AnyPath = RAG_EXT_DIR, file_glob: AnyPath = "**/*.md") -> None:
"""Generates RAG data from the specified directory.
:param rag_dir: The directory containing the files for RAG data generation
:param file_glob: The files from which to generate the RAG database.
:return: None
"""
if not self._rag_chain:
# Check whether the RAG directory requires an update.
if RAGProvider.requires_update(RAG_EXT_DIR):
rag_db_dir: Path = self.persist_dir(rag_dir, file_glob)
shutil.rmtree(str(rag_db_dir), ignore_errors=True)

embeddings = lc_llm.create_embeddings()
llm = lc_llm.create_chat_model(temperature=Temperature.DATA_ANALYSIS.temp)
persist_dir: str = str(self.persist_dir(file_glob))
if os.path.exists(persist_dir):
persist_dir: Path = self.persist_dir(rag_dir, file_glob)
if persist_dir.exists() and persist_dir.is_dir():
log.info("Recovering vector store from: '%s'", persist_dir)
self._vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
self._vectorstore = Chroma(persist_directory=str(persist_dir), embedding_function=embeddings)
else:
with Status(f'[green]{msg.summarizing()}[/green]'):
rag_docs: list[Document] = DirectoryLoader(str(RAG_EXT_DIR), glob=file_glob, recursive=True).load()
if len(rag_docs) <= 0:
raise DocumentsNotFound(f"Unable to find any document to at: '{persist_dir}'")
self._vectorstore = Chroma.from_documents(
persist_directory=persist_dir,
documents=self._text_splitter.split_documents(rag_docs),
embedding=embeddings,
)
rag_docs: list[Document] = DirectoryLoader(str(rag_dir), glob=file_glob, recursive=True).load()
if len(rag_docs) <= 0:
raise DocumentsNotFound(f"Unable to find any document to at: '{persist_dir}'")
self._vectorstore = Chroma.from_documents(
persist_directory=str(persist_dir),
documents=self._text_splitter.split_documents(rag_docs),
embedding=embeddings,
)

retriever = self._vectorstore.as_retriever()
rag_prompt = self.rag_template
Expand Down
2 changes: 1 addition & 1 deletion src/main/askai/core/features/processors/task_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.component.geo_location import geo_location
from askai.core.component.rag_provider import RAGProvider
from askai.core.engine.openai.temperature import Temperature
from askai.core.enums.acc_color import AccColor
from askai.core.enums.response_model import ResponseModel
Expand All @@ -29,7 +30,6 @@
from askai.core.model.ai_reply import AIReply
from askai.core.model.model_result import ModelResult
from askai.core.support.langchain_support import lc_llm
from askai.core.support.rag_provider import RAGProvider
from askai.core.support.shared_instances import shared
from askai.exception.exceptions import InaccurateResponse, InterruptionRequest, TerminatingQuery
from hspylib.core.exception.exceptions import InvalidArgumentError
Expand Down
2 changes: 1 addition & 1 deletion src/main/askai/core/features/router/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from askai.core.askai_events import events
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.component.rag_provider import RAGProvider
from askai.core.engine.openai.temperature import Temperature
from askai.core.enums.acc_color import AccColor
from askai.core.model.acc_response import AccResponse
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
from askai.core.support.rag_provider import RAGProvider
from askai.core.support.shared_instances import shared
from askai.exception.exceptions import InaccurateResponse, InterruptionRequest, TerminatingQuery
from langchain_core.messages import AIMessage
Expand Down
Loading

0 comments on commit ef79f0c

Please sign in to comment.