From ab6b3fc5294d675bcc571bd2616d8326413acfcd Mon Sep 17 00:00:00 2001 From: "Tuan Anh Nguyen Dang (Tadashi_Cin)" Date: Thu, 28 Nov 2024 21:12:56 +0700 Subject: [PATCH] feat: add quick file selection upon tagging on Chat input (#533) bump:patch * fix: improve inline citation logics without rag * fix: improve explanation for citation options * feat: add quick file selection on Chat input --- .../kotaemon/indices/qa/citation_qa_inline.py | 59 ++++++++++------ libs/kotaemon/kotaemon/indices/qa/utils.py | 3 + libs/ktem/ktem/app.py | 4 ++ libs/ktem/ktem/assets/css/main.css | 17 +++++ libs/ktem/ktem/index/file/ui.py | 67 +++++++++++++------ libs/ktem/ktem/pages/chat/__init__.py | 63 +++++++++++++++-- libs/ktem/ktem/pages/chat/chat_panel.py | 2 +- libs/ktem/ktem/reasoning/simple.py | 6 +- libs/ktem/ktem/utils/__init__.py | 3 +- libs/ktem/ktem/utils/conversation.py | 13 ++++ 10 files changed, 186 insertions(+), 51 deletions(-) diff --git a/libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py b/libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py index 9770b90fb..17e94e0e3 100644 --- a/libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py +++ b/libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py @@ -152,6 +152,20 @@ def answer_to_citations(self, answer) -> list[InlineEvidence]: def replace_citation_with_link(self, answer: str): # Define the regex pattern to match 【number】 pattern = r"【\d+】" + + # Regular expression to match merged citations + multi_pattern = r"【([\d,\s]+)】" + + # Function to replace merged citations with independent ones + def split_citations(match): + # Extract the numbers, split by comma, and create individual citations + numbers = match.group(1).split(",") + return "".join(f"【{num.strip()}】" for num in numbers) + + # Replace merged citations in the text + answer = re.sub(multi_pattern, split_citations, answer) + + # Find all citations in the answer matches = re.finditer(pattern, answer) matched_citations = set() @@ -240,25 +254,30 @@ def mindmap_call(): # try streaming first print("Trying LLM streaming") for out_msg in self.llm.stream(messages): - if START_ANSWER in output: - if not final_answer: - try: - left_over_answer = output.split(START_ANSWER)[1].lstrip() - except IndexError: - left_over_answer = "" - if left_over_answer: - out_msg.text = left_over_answer + out_msg.text - - final_answer += ( - out_msg.text.lstrip() if not final_answer else out_msg.text - ) + if evidence: + if START_ANSWER in output: + if not final_answer: + try: + left_over_answer = output.split(START_ANSWER)[ + 1 + ].lstrip() + except IndexError: + left_over_answer = "" + if left_over_answer: + out_msg.text = left_over_answer + out_msg.text + + final_answer += ( + out_msg.text.lstrip() if not final_answer else out_msg.text + ) + yield Document(channel="chat", content=out_msg.text) + + # check for the edge case of citation list is repeated + # with smaller LLMs + if START_CITATION in out_msg.text: + break + else: yield Document(channel="chat", content=out_msg.text) - # check for the edge case of citation list is repeated - # with smaller LLMs - if START_CITATION in out_msg.text: - break - output += out_msg.text logprobs += out_msg.logprobs except NotImplementedError: @@ -289,8 +308,10 @@ def mindmap_call(): # yield the final answer final_answer = self.replace_citation_with_link(final_answer) - yield Document(channel="chat", content=None) - yield Document(channel="chat", content=final_answer) + + if final_answer: + yield Document(channel="chat", content=None) + yield Document(channel="chat", content=final_answer) return answer diff --git a/libs/kotaemon/kotaemon/indices/qa/utils.py b/libs/kotaemon/kotaemon/indices/qa/utils.py index 4b6495a3d..51602b805 100644 --- a/libs/kotaemon/kotaemon/indices/qa/utils.py +++ b/libs/kotaemon/kotaemon/indices/qa/utils.py @@ -26,6 +26,9 @@ def find_start_end_phrase( matches = [] matched_length = 0 for sentence in [start_phrase, end_phrase]: + if sentence is None: + continue + match = SequenceMatcher( None, sentence, context, autojunk=False ).find_longest_match() diff --git a/libs/ktem/ktem/app.py b/libs/ktem/ktem/app.py index b56612b60..7142377e1 100644 --- a/libs/ktem/ktem/app.py +++ b/libs/ktem/ktem/app.py @@ -177,6 +177,10 @@ def make(self): "" + "" + "" # noqa ) with gr.Blocks( diff --git a/libs/ktem/ktem/assets/css/main.css b/libs/ktem/ktem/assets/css/main.css index ae6fe7daf..dba11efe9 100644 --- a/libs/ktem/ktem/assets/css/main.css +++ b/libs/ktem/ktem/assets/css/main.css @@ -365,3 +365,20 @@ details.evidence { color: #10b981; text-decoration: none; } + +/* pop-up for file tag in chat input*/ +.tribute-container ul { + background-color: var(--background-fill-primary) !important; + color: var(--body-text-color) !important; + font-family: var(--font); + font-size: var(--text-md); +} + +.tribute-container li.highlight { + background-color: var(--border-color-primary) !important; +} + +/* a fix for flickering background in Gradio DataFrame */ +tbody:not(.row_odd) { + background: var(--table-even-background-fill); +} diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py index 7a97aaeb0..6b1fdf473 100644 --- a/libs/ktem/ktem/index/file/ui.py +++ b/libs/ktem/ktem/index/file/ui.py @@ -29,6 +29,25 @@ } """ +update_file_list_js = """ +function(file_list) { + var values = []; + for (var i = 0; i < file_list.length; i++) { + values.push({ + key: file_list[i][0], + value: '"' + file_list[i][0] + '"', + }); + } + var tribute = new Tribute({ + values: values, + noMatchTemplate: "", + allowSpaces: true, + }) + input_box = document.querySelector('#chat-input textarea'); + tribute.attach(input_box); +} +""" + class File(gr.File): """Subclass from gr.File to maintain the original filename @@ -1429,6 +1448,10 @@ def on_building_ui(self): visible=False, ) self.selector_user_id = gr.State(value=user_id) + self.selector_choices = gr.JSON( + value=[], + visible=False, + ) def on_register_events(self): self.mode.change( @@ -1436,6 +1459,14 @@ def on_register_events(self): inputs=[self.mode, self._app.user_id], outputs=[self.selector, self.selector_user_id], ) + # attach special event for the first index + if self._index.id == 1: + self.selector_choices.change( + fn=None, + inputs=[self.selector_choices], + js=update_file_list_js, + show_progress="hidden", + ) def as_gradio_component(self): return [self.mode, self.selector, self.selector_user_id] @@ -1468,7 +1499,7 @@ def load_files(self, selected_files, user_id): available_ids = [] if user_id is None: # not signed in - return gr.update(value=selected_files, choices=options) + return gr.update(value=selected_files, choices=options), options with Session(engine) as session: # get file list from Source table @@ -1501,13 +1532,13 @@ def load_files(self, selected_files, user_id): each for each in selected_files if each in available_ids_set ] - return gr.update(value=selected_files, choices=options) + return gr.update(value=selected_files, choices=options), options def _on_app_created(self): self._app.app.load( self.load_files, inputs=[self.selector, self._app.user_id], - outputs=[self.selector], + outputs=[self.selector, self.selector_choices], ) def on_subscribe_public_events(self): @@ -1516,26 +1547,18 @@ def on_subscribe_public_events(self): definition={ "fn": self.load_files, "inputs": [self.selector, self._app.user_id], - "outputs": [self.selector], + "outputs": [self.selector, self.selector_choices], "show_progress": "hidden", }, ) if self._app.f_user_management: - self._app.subscribe_event( - name="onSignIn", - definition={ - "fn": self.load_files, - "inputs": [self.selector, self._app.user_id], - "outputs": [self.selector], - "show_progress": "hidden", - }, - ) - self._app.subscribe_event( - name="onSignOut", - definition={ - "fn": self.load_files, - "inputs": [self.selector, self._app.user_id], - "outputs": [self.selector], - "show_progress": "hidden", - }, - ) + for event_name in ["onSignIn", "onSignOut"]: + self._app.subscribe_event( + name=event_name, + definition={ + "fn": self.load_files, + "inputs": [self.selector, self._app.user_id], + "outputs": [self.selector, self.selector_choices], + "show_progress": "hidden", + }, + ) diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index 55f04b3b9..045358735 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -8,7 +8,7 @@ from ktem.app import BasePage from ktem.components import reasonings from ktem.db.models import Conversation, engine -from ktem.index.file.ui import File +from ktem.index.file.ui import File, chat_input_focus_js from ktem.reasoning.prompt_optimization.suggest_conversation_name import ( SuggestConvNamePipeline, ) @@ -22,7 +22,7 @@ from kotaemon.base import Document from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS -from ...utils import SUPPORTED_LANGUAGE_MAP +from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex from .chat_panel import ChatPanel from .common import STATE from .control import ConversationControl @@ -113,6 +113,7 @@ def on_building_ui(self): self.state_plot_history = gr.State([]) self.state_plot_panel = gr.State(None) self.state_follow_up = gr.State(None) + self.first_selector_choices = gr.State(None) with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column: self.chat_control = ConversationControl(self._app) @@ -130,6 +131,11 @@ def on_building_ui(self): ): index_ui.render() gr_index = index_ui.as_gradio_component() + + # get the file selector choices for the first index + if index_id == 0: + self.first_selector_choices = index_ui.selector_choices + if gr_index: if isinstance(gr_index, list): index.selector = tuple( @@ -272,6 +278,7 @@ def on_register_events(self): self.chat_control.conversation_id, self.chat_control.conversation_rn, self.state_follow_up, + self.first_selector_choices, ], outputs=[ self.chat_panel.text_input, @@ -280,6 +287,9 @@ def on_register_events(self): self.chat_control.conversation, self.chat_control.conversation_rn, self.state_follow_up, + # file selector from the first index + self._indices_input[0], + self._indices_input[1], ], concurrency_limit=20, show_progress="hidden", @@ -426,6 +436,10 @@ def on_register_events(self): fn=self._json_to_plot, inputs=self.state_plot_panel, outputs=self.plot_panel, + ).then( + fn=None, + inputs=None, + js=chat_input_focus_js, ) self.chat_control.btn_del.click( @@ -516,7 +530,12 @@ def on_register_events(self): lambda: self.toggle_delete(""), outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], ).then( - fn=None, inputs=None, outputs=None, js=pdfview_js + fn=lambda: True, + inputs=None, + outputs=[self._preview_links], + js=pdfview_js, + ).then( + fn=None, inputs=None, outputs=None, js=chat_input_focus_js ) # evidence display on message selection @@ -535,7 +554,12 @@ def on_register_events(self): inputs=self.state_plot_panel, outputs=self.plot_panel, ).then( - fn=None, inputs=None, outputs=None, js=pdfview_js + fn=lambda: True, + inputs=None, + outputs=[self._preview_links], + js=pdfview_js, + ).then( + fn=None, inputs=None, outputs=None, js=chat_input_focus_js ) self.chat_control.cb_is_public.change( @@ -585,7 +609,14 @@ def on_register_events(self): ) def submit_msg( - self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest + self, + chat_input, + chat_history, + user_id, + conv_id, + conv_name, + chat_suggest, + first_selector_choices, ): """Submit a message to the chatbot""" if not chat_input: @@ -593,6 +624,24 @@ def submit_msg( chat_input_text = chat_input.get("text", "") + # get all file names with pattern @"filename" in input_str + file_names, chat_input_text = get_file_names_regex(chat_input_text) + first_selector_choices_map = { + item[0]: item[1] for item in first_selector_choices + } + file_ids = [] + + if file_names: + for file_name in file_names: + file_id = first_selector_choices_map.get(file_name) + if file_id: + file_ids.append(file_id) + + if file_ids: + selector_output = ["select", file_ids] + else: + selector_output = [gr.update(), gr.update()] + # check if regen mode is active if chat_input_text: chat_history = chat_history + [(chat_input_text, None)] @@ -620,14 +669,14 @@ def submit_msg( new_conv_name = conv_name new_chat_suggestion = chat_suggest - return ( + return [ {}, chat_history, new_conv_id, conv_update, new_conv_name, new_chat_suggestion, - ) + ] + selector_output def toggle_delete(self, conv_id): if conv_id: diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py index b83c5d154..2adc52f01 100644 --- a/libs/ktem/ktem/pages/chat/chat_panel.py +++ b/libs/ktem/ktem/pages/chat/chat_panel.py @@ -25,7 +25,7 @@ def on_building_ui(self): interactive=True, scale=20, file_count="multiple", - placeholder="Chat input", + placeholder="Type a message (or tag a file with @filename)", container=False, show_label=False, elem_id="chat-input", diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index fbd861f29..ab7ebf204 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -410,7 +410,11 @@ def get_user_settings(cls) -> dict: "name": "Citation style", "value": "highlight", "component": "radio", - "choices": ["highlight", "inline", "off"], + "choices": [ + ("highlight (long answer)", "highlight"), + ("inline (precise answer)", "inline"), + ("off", "off"), + ], }, "create_mindmap": { "name": "Create Mindmap", diff --git a/libs/ktem/ktem/utils/__init__.py b/libs/ktem/ktem/utils/__init__.py index 009c60a62..8865bd328 100644 --- a/libs/ktem/ktem/utils/__init__.py +++ b/libs/ktem/ktem/utils/__init__.py @@ -1,3 +1,4 @@ +from .conversation import get_file_names_regex from .lang import SUPPORTED_LANGUAGE_MAP -__all__ = ["SUPPORTED_LANGUAGE_MAP"] +__all__ = ["SUPPORTED_LANGUAGE_MAP", "get_file_names_regex"] diff --git a/libs/ktem/ktem/utils/conversation.py b/libs/ktem/ktem/utils/conversation.py index 2550aa7f3..2dc95b13a 100644 --- a/libs/ktem/ktem/utils/conversation.py +++ b/libs/ktem/ktem/utils/conversation.py @@ -1,3 +1,6 @@ +import re + + def sync_retrieval_n_message( messages: list[list[str]], retrievals: list[str], @@ -16,5 +19,15 @@ def sync_retrieval_n_message( return retrievals +def get_file_names_regex(input_str: str) -> tuple[list[str], str]: + # get all file names with pattern @"filename" in input_str + # also remove these file names from input_str + pattern = r'@"([^"]*)"' + matches = re.findall(pattern, input_str) + input_str = re.sub(pattern, "", input_str).strip() + + return matches, input_str + + if __name__ == "__main__": print(sync_retrieval_n_message([[""], [""], [""]], []))