diff --git a/docs/images/chat-tab-demo.png b/docs/images/chat-tab-demo.png index 9730b366c..19bac86aa 100644 Binary files a/docs/images/chat-tab-demo.png and b/docs/images/chat-tab-demo.png differ diff --git a/docs/images/chat-tab.png b/docs/images/chat-tab.png index 233908439..6f21f7dd6 100644 Binary files a/docs/images/chat-tab.png and b/docs/images/chat-tab.png differ diff --git a/docs/usage.md b/docs/usage.md index 0e7d8e6d0..e976f8c32 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -128,9 +128,12 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions: 1. Conversation Settings Panel - Here you can select, create, rename, and delete conversations. - By default, a new conversation is created automatically if no conversation is selected. - - Below that you have the file index, where you can select which files to retrieve references from. - - These are the files you have uploaded to the application from the `File Index` tab. - - If no file is selected, all files will be used. + - Below that you have the file index, where you can choose whether to disable, select all files, or select which files to retrieve references from. + - If you choose "Disabled", no files will be considered as context during chat. + - If you choose "Search All", all files will be considered during chat. + - If you choose "Select", a dropdown will appear for you to select the + files to be considered during chat. If no files are selected, then no + files will be considered during chat. 2. Chat Panel - This is where you can chat with the chatbot. 3. Information panel diff --git a/libs/ktem/ktem/assets/md/usage.md b/libs/ktem/ktem/assets/md/usage.md index 0e7d8e6d0..e976f8c32 100644 --- a/libs/ktem/ktem/assets/md/usage.md +++ b/libs/ktem/ktem/assets/md/usage.md @@ -128,9 +128,12 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions: 1. Conversation Settings Panel - Here you can select, create, rename, and delete conversations. - By default, a new conversation is created automatically if no conversation is selected. - - Below that you have the file index, where you can select which files to retrieve references from. - - These are the files you have uploaded to the application from the `File Index` tab. - - If no file is selected, all files will be used. + - Below that you have the file index, where you can choose whether to disable, select all files, or select which files to retrieve references from. + - If you choose "Disabled", no files will be considered as context during chat. + - If you choose "Search All", all files will be considered during chat. + - If you choose "Select", a dropdown will appear for you to select the + files to be considered during chat. If no files are selected, then no + files will be considered during chat. 2. Chat Panel - This is where you can chat with the chatbot. 3. Information panel diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index 78e66a1f0..375f3ddb5 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -67,58 +67,63 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever): documents get_extra_table: if True, for each retrieved document, the pipeline will look for surrounding tables (e.g. within the page) + top_k: number of documents to retrieve + mmr: whether to use mmr to re-rank the documents """ vector_retrieval: VectorRetrieval = VectorRetrieval.withx() reranker: BaseReranking get_extra_table: bool = False + mmr: bool = False + top_k: int = 5 def run( self, text: str, - top_k: int = 5, - mmr: bool = False, doc_ids: Optional[list[str]] = None, + *args, + **kwargs, ) -> list[RetrievedDocument]: """Retrieve document excerpts similar to the text Args: text: the text to retrieve similar documents - top_k: number of documents to retrieve - mmr: whether to use mmr to re-rank the documents doc_ids: list of document ids to constraint the retrieval """ + if not doc_ids: + logger.info(f"Skip retrieval because of no selected files: {self}") + return [] + Index = self._Index - kwargs = {} - if doc_ids: - with Session(engine) as session: - stmt = select(Index).where( - Index.relation_type == "vector", - Index.source_id.in_(doc_ids), # type: ignore - ) - results = session.execute(stmt) - vs_ids = [r[0].target_id for r in results.all()] - - kwargs["filters"] = MetadataFilters( - filters=[ - MetadataFilter( - key="doc_id", - value=vs_id, - operator=FilterOperator.EQ, - ) - for vs_id in vs_ids - ], - condition=FilterCondition.OR, + retrieval_kwargs = {} + with Session(engine) as session: + stmt = select(Index).where( + Index.relation_type == "vector", + Index.source_id.in_(doc_ids), # type: ignore ) + results = session.execute(stmt) + vs_ids = [r[0].target_id for r in results.all()] + + retrieval_kwargs["filters"] = MetadataFilters( + filters=[ + MetadataFilter( + key="doc_id", + value=vs_id, + operator=FilterOperator.EQ, + ) + for vs_id in vs_ids + ], + condition=FilterCondition.OR, + ) - if mmr: + if self.mmr: # TODO: double check that llama-index MMR works correctly - kwargs["mode"] = VectorStoreQueryMode.MMR - kwargs["mmr_threshold"] = 0.5 + retrieval_kwargs["mode"] = VectorStoreQueryMode.MMR + retrieval_kwargs["mmr_threshold"] = 0.5 # rerank - docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs) + docs = self.vector_retrieval(text=text, top_k=self.top_k, **retrieval_kwargs) if docs and self.get_from_path("reranker"): docs = self.reranker(docs, query=text) @@ -221,6 +226,8 @@ def get_pipeline(cls, user_settings, index_settings, selected): retriever = cls( get_extra_table=user_settings["prioritize_table"], reranker=user_settings["reranking_llm"], + top_k=user_settings["num_retrieval"], + mmr=user_settings["mmr"], ) if not user_settings["use_reranking"]: retriever.reranker = None # type: ignore @@ -228,11 +235,7 @@ def get_pipeline(cls, user_settings, index_settings, selected): retriever.vector_retrieval.embedding = embedding_models_manager[ index_settings.get("embedding", embedding_models_manager.get_default_name()) ] - kwargs = { - ".top_k": int(user_settings["num_retrieval"]), - ".mmr": user_settings["mmr"], - ".doc_ids": selected, - } + kwargs = {".doc_ids": selected} retriever.set_run(kwargs, temp=True) return retriever diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py index c166123de..5ca461681 100644 --- a/libs/ktem/ktem/index/file/ui.py +++ b/libs/ktem/ktem/index/file/ui.py @@ -512,20 +512,55 @@ def __init__(self, app, index): self._index = index self.on_building_ui() + def default(self): + return "disabled", [] + def on_building_ui(self): + default_mode, default_selector = self.default() + + self.mode = gr.Radio( + value=default_mode, + choices=[ + ("Disabled", "disabled"), + ("Search All", "all"), + ("Select", "select"), + ], + container=False, + ) self.selector = gr.Dropdown( label="Files", - choices=[], + choices=default_selector, multiselect=True, container=False, interactive=True, + visible=False, + ) + + def on_register_events(self): + self.mode.change( + fn=lambda mode: gr.update(visible=mode == "select"), + inputs=[self.mode], + outputs=[self.selector], ) def as_gradio_component(self): - return self.selector + return [self.mode, self.selector] + + def get_selected_ids(self, components): + mode, selected = components[0], components[1] + if mode == "disabled": + return [] + elif mode == "select": + return selected + + file_ids = [] + with Session(engine) as session: + statement = select(self._index._resources["Source"].id) + results = session.execute(statement).all() + for (id,) in results: + file_ids.append(id) - def get_selected_ids(self, selected): - return selected + return file_ids def load_files(self, selected_files): options = [] diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index f95faf686..b8be90a8c 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -52,9 +52,11 @@ def on_building_ui(self): len(self._indices_input) + len(gr_index), ) ) + index.default_selector = index_ui.default() self._indices_input.extend(gr_index) else: index.selector = len(self._indices_input) + index.default_selector = index_ui.default() self._indices_input.append(gr_index) setattr(self, f"_index_{index.id}", index_ui) diff --git a/libs/ktem/ktem/pages/chat/control.py b/libs/ktem/ktem/pages/chat/control.py index 818420fbf..5e369b53f 100644 --- a/libs/ktem/ktem/pages/chat/control.py +++ b/libs/ktem/ktem/pages/chat/control.py @@ -156,9 +156,9 @@ def select_conv(self, conversation_id): if index.selector is None: continue if isinstance(index.selector, int): - indices.append(selected.get(str(index.id), [])) + indices.append(selected.get(str(index.id), index.default_selector)) if isinstance(index.selector, tuple): - indices.extend(selected.get(str(index.id), [[]] * len(index.selector))) + indices.extend(selected.get(str(index.id), index.default_selector)) return id_, id_, name, chats, info_panel, state, *indices