diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py
index 230ce9ddc..5a38747f1 100644
--- a/libs/kotaemon/kotaemon/base/component.py
+++ b/libs/kotaemon/kotaemon/base/component.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import AsyncGenerator, Iterator, Optional
+from typing import Any, AsyncGenerator, Iterator, Optional
 
 from theflow import Function, Node, Param, lazy
 
@@ -58,7 +58,7 @@ def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
     @abstractmethod
     def run(
         self, *args, **kwargs
-    ) -> Document | list[Document] | Iterator[Document] | None:
+    ) -> Document | list[Document] | Iterator[Document] | None | Any:
         """Run the component."""
         ...
diff --git a/libs/kotaemon/kotaemon/base/schema.py b/libs/kotaemon/kotaemon/base/schema.py
index 0769b6c89..a153ed30f 100644
--- a/libs/kotaemon/kotaemon/base/schema.py
+++ b/libs/kotaemon/kotaemon/base/schema.py
@@ -32,12 +32,13 @@ class Document(BaseDocument):
         channel: the channel to show the document. Optional.:
             - chat: show in chat message
             - info: show in information panel
+            - index: show in index panel
             - debug: show in debug panel
     """
 
     content: Any = None
     source: Optional[str] = None
-    channel: Optional[Literal["chat", "info", "debug"]] = None
+    channel: Optional[Literal["chat", "info", "index", "debug"]] = None
 
     def __init__(self, content: Optional[Any] = None, *args, **kwargs):
         if content is None:
diff --git a/libs/kotaemon/kotaemon/indices/ingests/files.py b/libs/kotaemon/kotaemon/indices/ingests/files.py
index e5e325e02..3eb536110 100644
--- a/libs/kotaemon/kotaemon/indices/ingests/files.py
+++ b/libs/kotaemon/kotaemon/indices/ingests/files.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 from typing import Type
 
+from llama_index.readers import PDFReader
 from llama_index.readers.base import BaseReader
 
 from kotaemon.base import BaseComponent, Document, Param
@@ -17,18 +18,20 @@
     UnstructuredReader,
 )
 
-KH_DEFAULT_FILE_EXTRACTORS: dict[str, Type[BaseReader]] = {
-    ".xlsx": PandasExcelReader,
-    ".docx": UnstructuredReader,
-    ".xls": UnstructuredReader,
-    ".doc": UnstructuredReader,
-    ".html": HtmlReader,
-    ".mhtml": MhtmlReader,
-    ".png": UnstructuredReader,
-    ".jpeg": UnstructuredReader,
-    ".jpg": UnstructuredReader,
-    ".tiff": UnstructuredReader,
-    ".tif": UnstructuredReader,
+unstructured = UnstructuredReader()
+KH_DEFAULT_FILE_EXTRACTORS: dict[str, BaseReader] = {
+    ".xlsx": PandasExcelReader(),
+    ".docx": unstructured,
+    ".xls": unstructured,
+    ".doc": unstructured,
+    ".html": HtmlReader(),
+    ".mhtml": MhtmlReader(),
+    ".png": unstructured,
+    ".jpeg": unstructured,
+    ".jpg": unstructured,
+    ".tiff": unstructured,
+    ".tif": unstructured,
+    ".pdf": PDFReader(),
 }
 
@@ -64,7 +67,7 @@ class DocumentIngestor(BaseComponent):
     def _get_reader(self, input_files: list[str | Path]):
         """Get appropriate readers for the input files based on file extension"""
         file_extractors: dict[str, BaseReader] = {
-            ext: cls() for ext, cls in KH_DEFAULT_FILE_EXTRACTORS.items()
+            ext: reader for ext, reader in KH_DEFAULT_FILE_EXTRACTORS.items()
         }
         for ext, cls in self.override_file_extractors.items():
             file_extractors[ext] = cls()
diff --git a/libs/kotaemon/kotaemon/loaders/base.py b/libs/kotaemon/kotaemon/loaders/base.py
index 2e52f7292..52bef490f 100644
--- a/libs/kotaemon/kotaemon/loaders/base.py
+++ b/libs/kotaemon/kotaemon/loaders/base.py
@@ -8,6 +8,8 @@
 
 
 class BaseReader(BaseComponent):
+    """The base class for all readers"""
+
     ...
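For context on the kotaemon-side changes above (the new "index" value for Document.channel and the instantiated readers in KH_DEFAULT_FILE_EXTRACTORS), a minimal sketch of how an indexing pipeline is expected to emit progress messages on the new channel. This is illustrative only and not part of the patch; the helper name emit_progress is hypothetical.

    from kotaemon.base import Document

    def emit_progress(file_name: str, ok: bool, message: str = ""):
        # "index" messages carry structured per-file status for the UI file list;
        # "debug" messages carry free-form text for the debug panel.
        status = "success" if ok else "failed"
        yield Document(
            content={"file_path": file_name, "status": status, "message": message},
            channel="index",
        )
        yield Document(f" => {file_name}: {status}", channel="debug")
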
diff --git a/libs/ktem/ktem/index/base.py b/libs/ktem/ktem/index/base.py
index 002b765c4..c3e666d7b 100644
--- a/libs/ktem/ktem/index/base.py
+++ b/libs/ktem/ktem/index/base.py
@@ -126,7 +126,7 @@ def get_indexing_pipeline(
         ...
 
     def get_retriever_pipelines(
-        self, settings: dict, selected: Any = None
+        self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseComponent"]:
         """Return the retriever pipelines to retrieve the entity from the index"""
         return []
diff --git a/libs/ktem/ktem/index/file/base.py b/libs/ktem/ktem/index/file/base.py
index 4f28f51ac..a489a8e51 100644
--- a/libs/ktem/ktem/index/file/base.py
+++ b/libs/ktem/ktem/index/file/base.py
@@ -1,10 +1,18 @@
 from pathlib import Path
-from typing import Optional
+from typing import Generator, Optional
 
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, Document, Param
 
 
 class BaseFileIndexRetriever(BaseComponent):
+
+    Source = Param(help="The SQLAlchemy Source table")
+    Index = Param(help="The SQLAlchemy Index table")
+    VS = Param(help="The VectorStore")
+    DS = Param(help="The DocStore")
+    FSPath = Param(help="The file storage path")
+    user_id = Param(help="The user id")
+
     @classmethod
     def get_user_settings(cls) -> dict:
         """Get the user settings for indexing
@@ -24,20 +32,6 @@ def get_pipeline(
     ) -> "BaseFileIndexRetriever":
         raise NotImplementedError
 
-    def set_resources(self, resources: dict):
-        """Set the resources for the indexing pipeline
-
-        This will setup the tables, the vector store and docstore.
-
-        Args:
-            resources (dict): the resources for the indexing pipeline
-        """
-        self._Source = resources["Source"]
-        self._Index = resources["Index"]
-        self._VS = resources["VectorStore"]
-        self._DS = resources["DocStore"]
-        self._fs_path = resources["FileStoragePath"]
-
 
 class BaseFileIndexIndexing(BaseComponent):
     """The pipeline to index information into the data store
@@ -54,11 +48,45 @@ class BaseFileIndexIndexing(BaseComponent):
         - self._DS: the docstore
     """
 
-    def run(self, file_paths: str | Path | list[str | Path], *args, **kwargs):
+    Source = Param(help="The SQLAlchemy Source table")
+    Index = Param(help="The SQLAlchemy Index table")
+    VS = Param(help="The VectorStore")
+    DS = Param(help="The DocStore")
+    FSPath = Param(help="The file storage path")
+    user_id = Param(help="The user id")
+
+    def run(
+        self, file_paths: str | Path | list[str | Path], *args, **kwargs
+    ) -> tuple[list[str | None], list[str | None]]:
         """Run the indexing pipeline
 
         Args:
             file_paths (str | Path | list[str | Path]): the file paths to index
+
+        Returns:
+            - the indexed file ids (each file id corresponds to an input file path, or
+            None if the indexing failed for that file path)
+            - the error messages (each error message corresponds to an input file path,
+            or None if the indexing was successful for that file path)
+        """
+        raise NotImplementedError
+
+    def stream(
+        self, file_paths: str | Path | list[str | Path], *args, **kwargs
+    ) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
+        """Stream the indexing pipeline
+
+        Args:
+            file_paths (str | Path | list[str | Path]): the file paths to index
+
+        Yields:
+            Document: the output message to the UI, must have channel == index or debug
+
+        Returns:
+            - the indexed file ids (each file id corresponds to an input file path, or
+            None if the indexing failed for that file path)
+            - the error messages (each error message corresponds to an input file path,
+            or None if the indexing was successful for that file path)
         """
         raise NotImplementedError
 
@@ -78,20 +106,6 @@ def get_user_settings(cls) -> dict:
         """
         return {}
 
-    def set_resources(self, resources: dict):
-        """Set the resources for the indexing pipeline
-
-        This will setup the tables, the vector store and docstore.
-
-        Args:
-            resources (dict): the resources for the indexing pipeline
-        """
-        self._Source = resources["Source"]
-        self._Index = resources["Index"]
-        self._VS = resources["VectorStore"]
-        self._DS = resources["DocStore"]
-        self._fs_path = resources["FileStoragePath"]
-
     def copy_to_filestorage(
         self, file_paths: str | Path | list[str | Path]
     ) -> list[str]:
@@ -113,7 +127,7 @@ def copy_to_filestorage(
         for file_path in file_paths:
             with open(file_path, "rb") as f:
                 paths.append(sha256(f.read()).hexdigest())
-                shutil.copy(file_path, self._fs_path / paths[-1])
+                shutil.copy(file_path, self.FSPath / paths[-1])
 
         return paths
diff --git a/libs/ktem/ktem/index/file/index.py b/libs/ktem/ktem/index/file/index.py
index 0d6838d96..e3d4405d4 100644
--- a/libs/ktem/ktem/index/file/index.py
+++ b/libs/ktem/ktem/index/file/index.py
@@ -362,13 +362,17 @@ def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
             stripped_settings[key] = value
 
         obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
-        obj.set_resources(resources=self._resources)
-        obj._user_id = user_id
+        obj.Source = self._resources["Source"]
+        obj.Index = self._resources["Index"]
+        obj.VS = self._vs
+        obj.DS = self._docstore
+        obj.FSPath = self._fs_path
+        obj.user_id = user_id
 
         return obj
 
     def get_retriever_pipelines(
-        self, settings: dict, selected: Any = None
+        self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
         # retrieval settings
         prefix = f"index.options.{self.id}."
@@ -387,7 +391,12 @@ def get_retriever_pipelines(
             obj = cls.get_pipeline(stripped_settings, self.config, selected_ids)
             if obj is None:
                 continue
-            obj.set_resources(self._resources)
+            obj.Source = self._resources["Source"]
+            obj.Index = self._resources["Index"]
+            obj.VS = self._vs
+            obj.DS = self._docstore
+            obj.FSPath = self._fs_path
+            obj.user_id = user_id
             retrievers.append(obj)
 
         return retrievers
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index 558d2a900..450bb3a1e 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -7,13 +7,13 @@
 from functools import lru_cache
 from hashlib import sha256
 from pathlib import Path
-from typing import Optional
+from typing import Generator, Optional
 
-import gradio as gr
-from ktem.components import filestorage_path
 from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager
 from ktem.llms.manager import llms
+from llama_index.readers.base import BaseReader
+from llama_index.readers.file.base import default_file_metadata_func
 from llama_index.vector_stores import (
     FilterCondition,
     FilterOperator,
@@ -26,10 +26,12 @@
 from theflow.settings import settings
 from theflow.utils.modules import import_dotted_string
 
-from kotaemon.base import RetrievedDocument
+from kotaemon.base import BaseComponent, Document, Node, Param, RetrievedDocument
+from kotaemon.embeddings import BaseEmbeddings
 from kotaemon.indices import VectorIndexing, VectorRetrieval
-from kotaemon.indices.ingests import DocumentIngestor
+from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
 from kotaemon.indices.rankings import BaseReranking, LLMReranking
+from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 
 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -43,7 +45,7 @@ def dev_settings():
 
     if hasattr(settings, "FILE_INDEX_PIPELINE_FILE_EXTRACTORS"):
         file_extractors = {
-            key: import_dotted_string(value, safe=False)
+            key: import_dotted_string(value, safe=False)()
            for key, value in settings.FILE_INDEX_PIPELINE_FILE_EXTRACTORS.items()
         }
 
@@ -72,12 +74,20 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
         mmr: whether to use mmr to re-rank the documents
     """
 
-    vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
+    embedding: BaseEmbeddings
     reranker: BaseReranking = LLMReranking.withx()
     get_extra_table: bool = False
     mmr: bool = False
     top_k: int = 5
 
+    @Node.auto(depends_on=["embedding", "VS", "DS"])
+    def vector_retrieval(self) -> VectorRetrieval:
+        return VectorRetrieval(
+            embedding=self.embedding,
+            vector_store=self.VS,
+            doc_store=self.DS,
+        )
+
     def run(
         self,
         text: str,
@@ -95,13 +105,11 @@ def run(
             logger.info(f"Skip retrieval because of no selected files: {self}")
             return []
 
-        Index = self._Index
-
         retrieval_kwargs = {}
         with Session(engine) as session:
-            stmt = select(Index).where(
-                Index.relation_type == "vector",
-                Index.source_id.in_(doc_ids),  # type: ignore
+            stmt = select(self.Index).where(
+                self.Index.relation_type == "vector",
+                self.Index.source_id.in_(doc_ids),
             )
             results = session.execute(stmt)
             vs_ids = [r[0].target_id for r in results.all()]
@@ -186,7 +194,7 @@ def get_user_settings(cls) -> dict:
                 "component": "dropdown",
             },
             "num_retrieval": {
-                "name": "Number of documents to retrieve",
+                "name": "Number of document chunks to retrieve",
                 "value": 3,
                 "component": "number",
             },
@@ -228,6 +236,11 @@ def get_pipeline(cls, user_settings, index_settings, selected):
             get_extra_table=user_settings["prioritize_table"],
             top_k=user_settings["num_retrieval"],
             mmr=user_settings["mmr"],
+            embedding=embedding_models_manager[
+                index_settings.get(
+                    "embedding", embedding_models_manager.get_default_name()
+                )
+            ],
         )
         if not user_settings["use_reranking"]:
             retriever.reranker = None  # type: ignore
@@ -236,226 +249,346 @@ def get_pipeline(cls, user_settings, index_settings, selected):
                 user_settings["reranking_llm"], llms.get_default()
             )
 
-        retriever.vector_retrieval.embedding = embedding_models_manager[
-            index_settings.get("embedding", embedding_models_manager.get_default_name())
-        ]
         kwargs = {".doc_ids": selected}
         retriever.set_run(kwargs, temp=True)
         return retriever
 
-    def set_resources(self, resources: dict):
-        super().set_resources(resources)
-        self.vector_retrieval.vector_store = self._VS
-        self.vector_retrieval.doc_store = self._DS
 
+class IndexPipeline(BaseComponent):
+    """Index a single file"""
 
-class IndexDocumentPipeline(BaseFileIndexIndexing):
-    """Store the documents and index the content into vector store and doc store
+    loader: BaseReader
+    splitter: BaseSplitter
+    chunk_batch_size: int = 50
 
-    Args:
-        indexing_vector_pipeline: pipeline to index the documents
-        file_ingestor: ingestor to ingest the documents
-    """
+    Source = Param(help="The SQLAlchemy Source table")
+    Index = Param(help="The SQLAlchemy Index table")
+    VS = Param(help="The VectorStore")
+    DS = Param(help="The DocStore")
+    FSPath = Param(help="The file storage path")
+    user_id = Param(help="The user id")
+    embedding: BaseEmbeddings
 
-    indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx()
-    file_ingestor: DocumentIngestor = DocumentIngestor.withx()
+    @Node.auto(depends_on=["Source", "Index", "embedding"])
+    def vector_indexing(self) -> VectorIndexing:
+        return VectorIndexing(
+            vector_store=self.VS, doc_store=self.DS, embedding=self.embedding
+        )
 
-    def run(
-        self,
-        file_paths: str | Path | list[str | Path],
-        reindex: bool = False,
-        **kwargs,  # type: ignore
-    ):
-        """Index the list of documents
+    def handle_docs(self, docs, file_id, file_name) -> Generator[Document, None, int]:
+        chunks = []
+        n_chunks = 0
+        for cidx, chunk in enumerate(self.splitter(docs)):
+            chunks.append(chunk)
+            if cidx % self.chunk_batch_size == 0:
+                self.handle_chunks(chunks, file_id)
+                n_chunks += len(chunks)
+                chunks = []
+                yield Document(
+                    f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
+                )
 
-        This function will extract the files, persist the files to storage,
-        index the files.
+        if chunks:
+            self.handle_chunks(chunks, file_id)
+            n_chunks += len(chunks)
+            yield Document(
+                f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
+            )
 
-        Args:
-            file_paths: list of file paths to index
-            reindex: whether to force reindexing the files if they exist
+        return n_chunks
 
-        Returns:
-            list of split nodes
-        """
-        Source = self._Source
-        Index = self._Index
+    def handle_chunks(self, chunks, file_id):
+        """Run chunks"""
+        # run embedding, add to both vector store and doc store
+        self.vector_indexing(chunks)
 
-        if not isinstance(file_paths, list):
-            file_paths = [file_paths]
+        # record in the index
+        with Session(engine) as session:
+            nodes = []
+            for chunk in chunks:
+                nodes.append(
+                    self.Index(
+                        source_id=file_id,
+                        target_id=chunk.doc_id,
+                        relation_type="document",
+                    )
+                )
+                nodes.append(
+                    self.Index(
+                        source_id=file_id,
+                        target_id=chunk.doc_id,
+                        relation_type="vector",
+                    )
+                )
+            session.add_all(nodes)
+            session.commit()
 
-        to_index: list[str] = []
-        file_to_hash: dict[str, str] = {}
-        errors = []
-        to_update = []
+    def get_id_if_exists(self, file_path: Path) -> Optional[str]:
+        """Check if the file is already indexed
 
-        for file_path in file_paths:
-            abs_path = str(Path(file_path).resolve())
-            with open(abs_path, "rb") as fi:
-                file_hash = sha256(fi.read()).hexdigest()
-
-            file_to_hash[abs_path] = file_hash
-
-            with Session(engine) as session:
-                statement = select(Source).where(Source.name == Path(abs_path).name)
-                item = session.execute(statement).first()
-
-                if item:
-                    if not reindex:
-                        errors.append(Path(abs_path).name)
-                        continue
-                    else:
-                        to_update.append(Path(abs_path).name)
-
-            to_index.append(abs_path)
-
-        if errors:
-            error_files = ", ".join(errors)
-            if len(error_files) > 100:
-                error_files = error_files[:80] + "..."
-            print(
-                "Skip these files already exist. Please rename/remove them or "
-                f"enable reindex:\n{errors}"
-            )
-            self.warning(
-                "Skip these files already exist. Please rename/remove them or "
-                f"enable reindex:\n{error_files}"
-            )
+        Args:
+            file_path: the path to the file
 
-        if not to_index:
-            return [], []
+        Returns:
+            the file id if the file is indexed, otherwise None
+        """
+        with Session(engine) as session:
+            stmt = select(self.Source).where(self.Source.name == file_path.name)
+            item = session.execute(stmt).first()
+            if item:
+                return item[0].id
 
-        # persist the files to storage
-        for path in to_index:
-            shutil.copy(path, filestorage_path / file_to_hash[path])
+        return None
 
-        # extract the file & prepare record info
-        file_to_source: dict = {}
-        extraction_errors = []
-        nodes = []
-        for file_path, file_hash in file_to_hash.items():
-            if str(Path(file_path).resolve()) not in to_index:
-                continue
+    def store_file(self, file_path: Path) -> str:
+        """Store file into the database and storage, return the file id
 
-            extraction_result = self.file_ingestor(file_path)
-            if not extraction_result:
-                extraction_errors.append(Path(file_path).name)
-                continue
-            nodes.extend(extraction_result)
-            source = Source(
-                name=Path(file_path).name,
-                path=file_hash,
-                size=Path(file_path).stat().st_size,
-                user=self._user_id,  # type: ignore
-            )
-            file_to_source[file_path] = source
+        Args:
+            file_path: the path to the file
 
-        if extraction_errors:
-            msg = "Failed to extract these files: {}".format(
-                ", ".join(extraction_errors)
-            )
-            print(msg)
-            self.warning(msg)
-
-        if not nodes:
-            return [], []
-
-        print(
-            "Extracted",
-            len(to_index) - len(extraction_errors),
-            "files into",
-            len(nodes),
-            "nodes",
+        Returns:
+            the file id
+        """
+        with file_path.open("rb") as fi:
+            file_hash = sha256(fi.read()).hexdigest()
+
+        shutil.copy(file_path, self.FSPath / file_hash)
+        source = self.Source(
+            name=file_path.name,
+            path=file_hash,
+            size=file_path.stat().st_size,
+            user=self.user_id,  # type: ignore
         )
-
-        # index the files
-        print("Indexing the files into vector store")
-        self.indexing_vector_pipeline(nodes)
-        print("Finishing indexing the files into vector store")
-
-        # persist to the index
-        print("Persisting the vector and the document into index")
-        file_ids = []
-        to_update = list(set(to_update))
         with Session(engine) as session:
-            if to_update:
-                session.execute(delete(Source).where(Source.name.in_(to_update)))
-
-            for source in file_to_source.values():
-                session.add(source)
+            session.add(source)
             session.commit()
-            for source in file_to_source.values():
-                file_ids.append(source.id)
+            file_id = source.id
 
-            for node in nodes:
-                file_path = str(node.metadata["file_path"])
-                node.source = str(file_to_source[file_path].id)
-                file_to_source[file_path].text_length += len(node.text)
+        return file_id
 
-            session.flush()
-            session.commit()
+    def finish(self, file_id: str, file_path: Path) -> str:
+        """Finish the indexing"""
+        with Session(engine) as session:
+            stmt = select(self.Index.target_id).where(self.Index.source_id == file_id)
+            doc_ids = [_[0] for _ in session.execute(stmt)]
+            if doc_ids:
+                docs = self.DS.get(doc_ids)
+                stmt = select(self.Source).where(self.Source.id == file_id)
+                result = session.execute(stmt).first()
+                if result:
+                    item = result[0]
+                    item.text_length = sum([len(doc.text) for doc in docs])
+                    session.add(item)
+                    session.commit()
+
+        return file_id
+
+    def delete_file(self, file_id: str):
+        """Delete a file from the db, including its chunks in docstore and vectorstore
+
+        Args:
+            file_id: the file id
+        """
         with Session(engine) as session:
+            session.execute(delete(self.Source).where(self.Source.id == file_id))
+            vs_ids, ds_ids = [], []
+            index = session.execute(
+                select(self.Index).where(self.Index.source_id == file_id)
+            ).all()
+            for each in index:
+                if each[0].relation_type == "vector":
+                    vs_ids.append(each[0].target_id)
+                else:
+                    ds_ids.append(each[0].target_id)
+                session.delete(each[0])
+            session.commit()
+        self.VS.delete(vs_ids)
+        self.DS.delete(ds_ids)
+
+    def run(self, file_path: str | Path, reindex: bool, **kwargs) -> str:
+        """Index the file and return the file id"""
+        # check for duplication
+        file_path = Path(file_path).resolve()
+        file_id = self.get_id_if_exists(file_path)
+        if file_id is not None:
+            if not reindex:
+                raise ValueError(
+                    f"File {file_path.name} already indexed. Please rerun with "
+                    "reindex=True to force reindexing."
                 )
-            session.add(index)
-            for node in nodes:
-                index = Index(
-                    source_id=node.source,
-                    target_id=node.doc_id,
-                    relation_type="vector",
+            else:
+                # remove the existing records
+                self.delete_file(file_id)
+                file_id = self.store_file(file_path)
+        else:
+            # add record to db
+            file_id = self.store_file(file_path)
+
+        # extract the file
+        extra_info = default_file_metadata_func(str(file_path))
+        docs = self.loader.load_data(file_path, extra_info=extra_info)
+        for _ in self.handle_docs(docs, file_id, file_path.name):
+            continue
+        self.finish(file_id, file_path)
+
+        return file_id
+
+    def stream(
+        self, file_path: str | Path, reindex: bool, **kwargs
+    ) -> Generator[Document, None, str]:
+        # check for duplication
+        file_path = Path(file_path).resolve()
+        file_id = self.get_id_if_exists(file_path)
+        if file_id is not None:
+            if not reindex:
+                raise ValueError(
+                    f"File {file_path.name} already indexed. Please rerun with "
+                    "reindex=True to force reindexing."
                 )
-            session.add(index)
-            session.commit()
+            else:
+                # remove the existing records
+                yield Document(f" => Removing old {file_path.name}", channel="debug")
+                self.delete_file(file_id)
+                file_id = self.store_file(file_path)
+        else:
+            # add record to db
+            file_id = self.store_file(file_path)
 
-        print("Finishing persisting the vector and the document into index")
-        print(f"{len(nodes)} nodes are indexed")
-        return nodes, file_ids
+        # extract the file
+        extra_info = default_file_metadata_func(str(file_path))
+        yield Document(f" => Converting {file_path.name} to text", channel="debug")
+        docs = self.loader.load_data(file_path, extra_info=extra_info)
+        yield Document(f" => Converted {file_path.name} to text", channel="debug")
+        yield from self.handle_docs(docs, file_id, file_path.name)
 
-    @classmethod
-    def get_user_settings(cls) -> dict:
-        return {
-            "index_parser": {
-                "name": "Index parser",
-                "value": "normal",
-                "choices": [
-                    ("PDF text parser", "normal"),
-                    ("Mathpix", "mathpix"),
-                    ("Advanced ocr", "ocr"),
-                    ("Multimodal parser", "multimodal"),
-                ],
-                "component": "dropdown",
-            },
-        }
+        self.finish(file_id, file_path)
 
-    @classmethod
-    def get_pipeline(cls, user_settings, index_settings) -> "IndexDocumentPipeline":
-        """Get the pipeline based on the setting"""
-        obj = cls()
-        obj.file_ingestor.pdf_mode = user_settings["index_parser"]
-
-        file_extractors, chunk_size, chunk_overlap = dev_settings()
-        if file_extractors:
-            obj.file_ingestor.override_file_extractors = file_extractors
-        if chunk_size:
-            obj.file_ingestor.text_splitter.chunk_size = chunk_size
-        if chunk_overlap:
-            obj.file_ingestor.text_splitter.chunk_overlap = chunk_overlap
-
-        obj.indexing_vector_pipeline.embedding = embedding_models_manager[
-            index_settings.get("embedding", embedding_models_manager.get_default_name())
-        ]
+        yield Document(f" => Finished indexing {file_path.name}", channel="debug")
+        return file_id
+
+
+class IndexDocumentPipeline(BaseFileIndexIndexing):
+    """Index the file. Decide which pipeline based on the file type.
+
+    This class is essentially a factory that decides which indexing pipeline to use.
+    We can decide the pipeline programmatically, and/or automatically based on an LLM.
+    If we base it on the LLM, essentially we will log the LLM thought process in a file,
+    and then during the indexing, we will read that file to decide which pipeline
+    to use, and then log the operation in that file. Over time, the LLM can learn to
+    decide which pipeline should be used.
+    """
+
+    embedding: BaseEmbeddings
+
+    @classmethod
+    def get_pipeline(cls, user_settings, index_settings) -> BaseFileIndexIndexing:
+        obj = cls(
+            embedding=embedding_models_manager[
+                index_settings.get(
+                    "embedding", embedding_models_manager.get_default_name()
+                )
+            ]
+        )
         return obj
 
-    def set_resources(self, resources: dict):
-        super().set_resources(resources)
-        self.indexing_vector_pipeline.vector_store = self._VS
-        self.indexing_vector_pipeline.doc_store = self._DS
+    def route(self, file_path: Path) -> IndexPipeline:
+        """Decide the pipeline based on the file type
+
+        Can subclass this method for a more elaborate pipeline routing strategy.
+        """
+        readers, chunk_size, chunk_overlap = dev_settings()
+
+        ext = file_path.suffix
+        reader = readers.get(ext, KH_DEFAULT_FILE_EXTRACTORS.get(ext, None))
+        if reader is None:
+            raise NotImplementedError(
+                f"No supported pipeline to index {file_path.name}. Please specify "
+                "the suitable pipeline for this file type in the settings."
+            )
+
+        pipeline: IndexPipeline = IndexPipeline(
+            loader=reader,
+            splitter=TokenSplitter(
+                chunk_size=chunk_size or 1024,
+                chunk_overlap=chunk_overlap or 256,
+                separator="\n\n",
+                backup_separators=["\n", ".", "\u200B"],
+            ),
+            Source=self.Source,
+            Index=self.Index,
+            VS=self.VS,
+            DS=self.DS,
+            FSPath=self.FSPath,
+            user_id=self.user_id,
+            embedding=self.embedding,
+        )
+
+        return pipeline
+
+    def run(
+        self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
+    ) -> tuple[list[str | None], list[str | None]]:
+        """Return a list of indexed file ids, and a list of errors"""
+        if not isinstance(file_paths, list):
+            file_paths = [file_paths]
+
+        file_ids: list[str | None] = []
+        errors: list[str | None] = []
+        for file_path in file_paths:
+            file_path = Path(file_path)
+
+            try:
+                pipeline = self.route(file_path)
+                file_id = pipeline.run(file_path, reindex=reindex, **kwargs)
+                file_ids.append(file_id)
+                errors.append(None)
+            except Exception as e:
+                logger.error(e)
+                file_ids.append(None)
+                errors.append(str(e))
+
+        return file_ids, errors
+
+    def stream(
+        self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
+    ) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
+        """Return a list of indexed file ids, and a list of errors"""
+        if not isinstance(file_paths, list):
+            file_paths = [file_paths]
+
+        file_ids: list[str | None] = []
+        errors: list[str | None] = []
+        n_files = len(file_paths)
+        for idx, file_path in enumerate(file_paths):
+            file_path = Path(file_path)
+            yield Document(
+                content=f"Indexing [{idx+1}/{n_files}]: {file_path.name}",
+                channel="debug",
+            )
+
+            try:
+                pipeline = self.route(file_path)
+                file_id = yield from pipeline.stream(
+                    file_path, reindex=reindex, **kwargs
+                )
+                file_ids.append(file_id)
+                errors.append(None)
+                yield Document(
+                    content={"file_path": file_path, "status": "success"},
+                    channel="index",
+                )
+            except Exception as e:
+                logger.error(e)
+                file_ids.append(None)
+                errors.append(str(e))
+                yield Document(
+                    content={
+                        "file_path": file_path,
+                        "status": "failed",
+                        "message": str(e),
+                    },
+                    channel="index",
+                )
 
-    def warning(self, msg):
-        gr.Warning(msg)
+        return file_ids, errors
diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py
index 43fe4c511..d46e0725d 100644
--- a/libs/ktem/ktem/index/file/ui.py
+++ b/libs/ktem/ktem/index/file/ui.py
@@ -1,6 +1,7 @@
 import os
 import tempfile
 from pathlib import Path
+from typing import Generator
 
 import gradio as gr
 import pandas as pd
@@ -63,9 +64,6 @@ def on_building_ui(self):
                 )
                 self.upload_button = gr.Button("Upload and Index")
 
-        self.file_output = gr.File(
-            visible=False, label="Output files (debug purpose)"
-        )
 
 class FileIndexPage(BasePage):
@@ -127,11 +125,23 @@ def on_building_ui(self):
                     self.upload_button = gr.Button(
                         "Upload and Index", variant="primary"
                     )
-                    self.file_output = gr.File(
-                        visible=False, label="Output files (debug purpose)"
-                    )
 
             with gr.Column(scale=4):
+                with gr.Column(visible=False) as self.upload_progress_panel:
+                    gr.Markdown("## Upload Progress")
+                    with gr.Row():
+                        self.upload_result = gr.Textbox(
+                            lines=1, max_lines=20, label="Upload result"
+                        )
+                        self.upload_info = gr.Textbox(
+                            lines=1, max_lines=20, label="Upload info"
+                        )
+                    self.btn_close_upload_progress_panel = gr.Button(
+                        "Clear Upload Info and Close",
+                        variant="secondary",
+                        elem_classes=["right-button"],
+                    )
+
                 gr.Markdown("## File List")
                 self.file_list_state = gr.State(value=None)
                 self.file_list = gr.DataFrame(
@@ -261,6 +271,9 @@ def on_register_events(self):
         )
 
         onUploaded = self.upload_button.click(
+            fn=lambda: gr.update(visible=True),
+            outputs=[self.upload_progress_panel],
+        ).then(
             fn=self.index_fn,
             inputs=[
                 self.files,
@@ -268,16 +281,28 @@ def on_register_events(self):
                 self._app.settings_state,
                 self._app.user_id,
             ],
-            outputs=[self.file_output],
+            outputs=[self.upload_result, self.upload_info],
             concurrency_limit=20,
-        ).then(
+        )
+
+        uploadedEvent = onUploaded.then(
             fn=self.list_file,
             inputs=[self._app.user_id],
             outputs=[self.file_list_state, self.file_list],
             concurrency_limit=20,
         )
         for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
-            onUploaded = onUploaded.then(**event)
+            uploadedEvent = uploadedEvent.then(**event)
+
+        _ = onUploaded.success(
+            fn=lambda: None,
+            outputs=[self.files],
+        )
+
+        self.btn_close_upload_progress_panel.click(
+            fn=lambda: (gr.update(visible=False), "", ""),
+            outputs=[self.upload_progress_panel, self.upload_result, self.upload_info],
+        )
 
         self.file_list.select(
             fn=self.interact_file_list,
@@ -294,7 +319,9 @@ def _on_app_created(self):
             outputs=[self.file_list_state, self.file_list],
         )
 
-    def index_fn(self, files, reindex: bool, settings, user_id):
+    def index_fn(
+        self, files, reindex: bool, settings, user_id
+    ) -> Generator[tuple[str, str], None, None]:
         """Upload and index the files
 
         Args:
@@ -305,35 +332,56 @@ def index_fn(self, files, reindex: bool, settings, user_id):
         """
         if not files:
             gr.Info("No uploaded file")
-            return gr.update()
+            yield "", ""
+            return
 
         errors = self.validate(files)
         if errors:
             gr.Warning(", ".join(errors))
-            return gr.update()
+            yield "", ""
+            return
 
         gr.Info(f"Start indexing {len(files)} files...")
 
         # get the pipeline
         indexing_pipeline = self._index.get_indexing_pipeline(settings, user_id)
 
-        result = indexing_pipeline(files, reindex=reindex)
-        if result is None:
-            gr.Info("Finish indexing")
+        outputs, debugs = [], []
+        # stream the output
+        output_stream = indexing_pipeline.stream(files, reindex=reindex)
+        try:
+            while True:
+                response = next(output_stream)
+                if response is None:
+                    continue
+                if response.channel == "index":
+                    if response.content["status"] == "success":
+                        outputs.append(f"\u2705 | {response.content['file_path'].name}")
+                    elif response.content["status"] == "failed":
+                        outputs.append(
+                            f"\u274c | {response.content['file_path'].name}: "
+                            f"{response.content['message']}"
+                        )
+                elif response.channel == "debug":
+                    debugs.append(response.text)
+                yield "\n".join(outputs), "\n".join(debugs)
+        except StopIteration as e:
+            result, errors = e.value
+        except Exception as e:
+            debugs.append(f"Error: {e}")
+            yield "\n".join(outputs), "\n".join(debugs)
             return
 
-        output_nodes, _ = result
-        gr.Info(f"Finish indexing into {len(output_nodes)} chunks")
-
-        # download the file
-        text = "\n\n".join([each.text for each in output_nodes])
-        handler, file_path = tempfile.mkstemp(suffix=".txt")
-        with open(file_path, "w", encoding="utf-8") as f:
-            f.write(text)
-        os.close(handler)
-        return gr.update(value=file_path, visible=True)
+        n_successes = len([_ for _ in result if _])
+        if n_successes:
+            gr.Info(f"Successfully indexed {n_successes} files")
+        n_errors = len([_ for _ in errors if _])
+        if n_errors:
+            gr.Warning(f"Failed to index {n_errors} files")
 
-    def index_files_from_dir(self, folder_path, reindex, settings, user_id):
+    def index_files_from_dir(
+        self, folder_path, reindex, settings, user_id
+    ) -> Generator[tuple[str, str], None, None]:
         """This should be constructable by users
 
         It means that the users can build their own index.
@@ -363,6 +411,7 @@ def index_files_from_dir(self, folder_path, reindex, settings, user_id):
         2. Implement the transformation from artifacts to UI
         """
         if not folder_path:
+            yield "", ""
             return
 
         import fnmatch
@@ -401,7 +450,7 @@ def index_files_from_dir(self, folder_path, reindex, settings, user_id):
         for p in exclude_patterns:
             files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]
 
-        return self.index_fn(files, reindex, settings, user_id)
+        yield from self.index_fn(files, reindex, settings, user_id)
 
     def list_file(self, user_id):
         if user_id is None:
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index b21e85575..d9826e0db 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -99,6 +99,7 @@ def on_register_events(self):
                 self.chat_panel.chatbot,
                 self._app.settings_state,
                 self.chat_state,
+                self._app.user_id,
             ]
             + self._indices_input,
             outputs=[
@@ -127,6 +128,7 @@ def on_register_events(self):
                 self.chat_panel.chatbot,
                 self._app.settings_state,
                 self.chat_state,
+                self._app.user_id,
             ]
             + self._indices_input,
             outputs=[
@@ -360,7 +362,7 @@ def is_liked(self, convo_id, liked: gr.LikeData):
             session.add(result)
             session.commit()
 
-    def create_pipeline(self, settings: dict, state: dict, *selecteds):
+    def create_pipeline(self, settings: dict, state: dict, user_id: int, *selecteds):
         """Create the pipeline from settings
 
         Args:
@@ -385,7 +387,9 @@ def create_pipeline(self, settings: dict, state: dict, *selecteds):
             if isinstance(index.selector, tuple):
                 for i in index.selector:
                     index_selected.append(selecteds[i])
-            iretrievers = index.get_retriever_pipelines(settings, index_selected)
+            iretrievers = index.get_retriever_pipelines(
+                settings, user_id, index_selected
+            )
             retrievers += iretrievers
 
         # prepare states
@@ -398,7 +402,9 @@ def create_pipeline(self, settings: dict, state: dict, *selecteds):
 
         return pipeline, reasoning_state
 
-    def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
+    def chat_fn(
+        self, conversation_id, chat_history, settings, state, user_id, *selecteds
+    ):
         """Chat function"""
         chat_input = chat_history[-1][0]
         chat_history = chat_history[:-1]
@@ -406,7 +412,9 @@ def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
         queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
 
         # construct the pipeline
-        pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
+        pipeline, reasoning_state = self.create_pipeline(
+            settings, state, user_id, *selecteds
+        )
         pipeline.set_output_queue(queue)
 
         text, refs = "", ""
@@ -452,7 +460,9 @@ def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
                 print(f"Generate nothing: {empty_msg}")
         yield chat_history + [(chat_input, text or empty_msg)], refs, state
 
-    def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds):
+    def regen_fn(
+        self, conversation_id, chat_history, settings, state, user_id, *selecteds
+    ):
         """Regen function"""
         if not chat_history:
             gr.Warning("Empty chat")
@@ -461,7 +471,7 @@ def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds):
         state["app"]["regen"] = True
         for chat, refs, state in self.chat_fn(
-            conversation_id, chat_history, settings, state, *selecteds
+            conversation_id, chat_history, settings, state, user_id, *selecteds
         ):
             new_state = deepcopy(state)
             new_state["app"]["regen"] = False
diff --git a/libs/ktem/ktem/reasoning/react.py b/libs/ktem/ktem/reasoning/react.py
index afafcef55..9f9202332 100644
--- a/libs/ktem/ktem/reasoning/react.py
+++ b/libs/ktem/ktem/reasoning/react.py
@@ -107,8 +107,8 @@ def prepare_evidence(self, docs, trim_len: int = 4000):
                 separator=" ",
                 model_name="gpt-3.5-turbo",
             )
-            texts = text_splitter.split_text(evidence)
-            evidence = texts[0]
+        texts = text_splitter.split_text(evidence)
+        evidence = texts[0]
 
         return Document(content=evidence)
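Taken together, these changes replace the old return-the-nodes indexing flow with a streaming generator protocol. Below is a minimal consumption sketch, mirroring index_fn in ui.py above; it assumes `index` is a ktem file index instance, that `settings` and `user_id` come from the running app, and the file name is only an example.

    from pathlib import Path

    pipeline = index.get_indexing_pipeline(settings, user_id)
    stream = pipeline.stream([Path("report.pdf")], reindex=False)
    try:
        while True:
            doc = next(stream)
            if doc is None:
                continue
            if doc.channel == "index":
                # structured per-file status: {"file_path": ..., "status": ..., "message": ...}
                print(doc.content)
            elif doc.channel == "debug":
                print(doc.text)
    except StopIteration as e:
        # the generator's return value carries the indexed file ids and error messages
        file_ids, errors = e.value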