From 0417610d3e97edc1f23ded22361a80bf2e39f967 Mon Sep 17 00:00:00 2001 From: "Duc Nguyen (john)" Date: Sat, 13 Apr 2024 23:13:04 +0700 Subject: [PATCH] Refactor reasoning pipeline (#31) * Move the text rendering out for reusability * Refactor common operations in the reasoning pipeline * Add run method * Provide dedicated method for invoke --- libs/kotaemon/kotaemon/base/component.py | 2 +- libs/kotaemon/kotaemon/llms/chats/openai.py | 3 +- libs/ktem/ktem/llms/manager.py | 16 +- libs/ktem/ktem/reasoning/simple.py | 408 ++++++++------------ libs/ktem/ktem/utils/__init__.py | 0 libs/ktem/ktem/utils/render.py | 21 + 6 files changed, 192 insertions(+), 258 deletions(-) create mode 100644 libs/ktem/ktem/utils/__init__.py create mode 100644 libs/ktem/ktem/utils/render.py diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py index 6936b2a8f..230ce9ddc 100644 --- a/libs/kotaemon/kotaemon/base/component.py +++ b/libs/kotaemon/kotaemon/base/component.py @@ -39,7 +39,7 @@ def set_output_queue(self, queue): if isinstance(node, BaseComponent): node.set_output_queue(queue) - def report_output(self, output: Optional[dict]): + def report_output(self, output: Optional[Document]): if self._queue is not None: self._queue.put_nowait(output) diff --git a/libs/kotaemon/kotaemon/llms/chats/openai.py b/libs/kotaemon/kotaemon/llms/chats/openai.py index 6f492c7ad..1a31e24f6 100644 --- a/libs/kotaemon/kotaemon/llms/chats/openai.py +++ b/libs/kotaemon/kotaemon/llms/chats/openai.py @@ -270,7 +270,7 @@ def prepare_client(self, async_version: bool = False): def openai_response(self, client, **kwargs): """Get the openai response""" - params = { + params_ = { "model": self.model, "temperature": self.temperature, "max_tokens": self.max_tokens, @@ -285,6 +285,7 @@ def openai_response(self, client, **kwargs): "top_logprobs": self.top_logprobs, "top_p": self.top_p, } + params = {k: v for k, v in params_.items() if v is not None} params.update(kwargs) return client.chat.completions.create(**params) diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py index 0ef64e002..71ad42565 100644 --- a/libs/ktem/ktem/llms/manager.py +++ b/libs/ktem/ktem/llms/manager.py @@ -5,7 +5,7 @@ from theflow.settings import settings as flowsettings from theflow.utils.modules import deserialize -from kotaemon.base import BaseComponent +from kotaemon.llms import ChatLLM from .db import LLMTable, engine @@ -14,7 +14,7 @@ class LLMManager: """Represent a pool of models""" def __init__(self): - self._models: dict[str, BaseComponent] = {} + self._models: dict[str, ChatLLM] = {} self._info: dict[str, dict] = {} self._default: str = "" self._vendors: list[Type] = [] @@ -63,7 +63,7 @@ def load_vendors(self): self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM] - def __getitem__(self, key: str) -> BaseComponent: + def __getitem__(self, key: str) -> ChatLLM: """Get model by name""" return self._models[key] @@ -71,9 +71,7 @@ def __contains__(self, key: str) -> bool: """Check if model exists""" return key in self._models - def get( - self, key: str, default: Optional[BaseComponent] = None - ) -> Optional[BaseComponent]: + def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]: """Get model by name with default value""" return self._models.get(key, default) @@ -119,18 +117,18 @@ def get_default_name(self) -> str: return self._default - def get_random(self) -> BaseComponent: + def get_random(self) -> ChatLLM: """Get random model""" return self._models[self.get_random_name()] - def get_default(self) -> BaseComponent: + def get_default(self) -> ChatLLM: """Get default model In case there is no default model, choose random model from pool. In case there are multiple default models, choose random from them. Returns: - BaseComponent: model + ChatLLM: model """ return self._models[self.get_default_name()] diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 3397250de..d4881d8f9 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -8,6 +8,7 @@ import tiktoken from ktem.llms.manager import llms +from ktem.utils.render import Render from kotaemon.base import ( BaseComponent, @@ -20,7 +21,7 @@ from kotaemon.indices.qa.citation import CitationPipeline from kotaemon.indices.splitters import TokenSplitter from kotaemon.llms import ChatLLM, PromptTemplate -from kotaemon.loaders.utils.gpt4v import stream_gpt4v +from kotaemon.loaders.utils.gpt4v import generate_gpt4v, stream_gpt4v from .base import BaseReasoning @@ -205,31 +206,10 @@ class AnswerWithContextPipeline(BaseComponent): system_prompt: str = "" lang: str = "English" # support English and Japanese - async def run( # type: ignore - self, question: str, evidence: str, evidence_mode: int = 0, **kwargs - ) -> Document: - """Answer the question based on the evidence - - In addition to the question and the evidence, this method also take into - account evidence_mode. The evidence_mode tells which kind of evidence is. - The kind of evidence affects: - 1. How the evidence is represented. - 2. The prompt to generate the answer. - - By default, the evidence_mode is 0, which means the evidence is plain text with - no particular semantic representation. The evidence_mode can be: - 1. "table": There will be HTML markup telling that there is a table - within the evidence. - 2. "chatbot": There will be HTML markup telling that there is a chatbot. - This chatbot is a scenario, extracted from an Excel file, where each - row corresponds to an interaction. + def get_prompt(self, question, evidence, evidence_mode: int): + """Prepare the prompt and other information for LLM""" + images = [] - Args: - question: the original question posed by user - evidence: the text that contain relevant information to answer the question - (determined by retrieval pipeline) - evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot - """ if evidence_mode == EVIDENCE_MODE_TEXT: prompt_template = PromptTemplate(self.qa_template) elif evidence_mode == EVIDENCE_MODE_TABLE: @@ -239,7 +219,6 @@ async def run( # type: ignore else: prompt_template = PromptTemplate(self.qa_chatbot_template) - images = [] if evidence_mode == EVIDENCE_MODE_FIGURE: # isolate image from evidence evidence, images = self.extract_evidence_images(evidence) @@ -255,6 +234,66 @@ async def run( # type: ignore lang=self.lang, ) + return prompt, images + + def run( + self, question: str, evidence: str, evidence_mode: int = 0, **kwargs + ) -> Document: + return self.invoke(question, evidence, evidence_mode, **kwargs) + + def invoke( + self, question: str, evidence: str, evidence_mode: int = 0, **kwargs + ) -> Document: + prompt, images = self.get_prompt(question, evidence, evidence_mode) + + output = "" + if evidence_mode == EVIDENCE_MODE_FIGURE: + output = generate_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768) + else: + messages = [] + if self.system_prompt: + messages.append(SystemMessage(content=self.system_prompt)) + messages.append(HumanMessage(content=prompt)) + output = self.llm(messages).text + + # retrieve the citation + citation = None + if evidence and self.enable_citation: + citation = self.citation_pipeline.invoke( + context=evidence, question=question + ) + + answer = Document(text=output, metadata={"citation": citation}) + + return answer + + async def ainvoke( # type: ignore + self, question: str, evidence: str, evidence_mode: int = 0, **kwargs + ) -> Document: + """Answer the question based on the evidence + + In addition to the question and the evidence, this method also take into + account evidence_mode. The evidence_mode tells which kind of evidence is. + The kind of evidence affects: + 1. How the evidence is represented. + 2. The prompt to generate the answer. + + By default, the evidence_mode is 0, which means the evidence is plain text with + no particular semantic representation. The evidence_mode can be: + 1. "table": There will be HTML markup telling that there is a table + within the evidence. + 2. "chatbot": There will be HTML markup telling that there is a chatbot. + This chatbot is a scenario, extracted from an Excel file, where each + row corresponds to an interaction. + + Args: + question: the original question posed by user + evidence: the text that contain relevant information to answer the question + (determined by retrieval pipeline) + evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot + """ + prompt, images = self.get_prompt(question, evidence, evidence_mode) + citation_task = None if evidence and self.enable_citation: citation_task = asyncio.create_task( @@ -266,7 +305,7 @@ async def run( # type: ignore if evidence_mode == EVIDENCE_MODE_FIGURE: for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768): output += text - self.report_output({"output": text}) + self.report_output(Document(channel="chat", content=text)) await asyncio.sleep(0) else: messages = [] @@ -279,12 +318,12 @@ async def run( # type: ignore print("Trying LLM streaming") for text in self.llm.stream(messages): output += text.text - self.report_output({"output": text.text}) + self.report_output(Document(content=text.text, channel="chat")) await asyncio.sleep(0) except NotImplementedError: print("Streaming is not supported, falling back to normal processing") output = self.llm(messages).text - self.report_output({"output": output}) + self.report_output(Document(content=output, channel="chat")) # retrieve the citation print("Waiting for citation task") @@ -300,52 +339,7 @@ async def run( # type: ignore def stream( # type: ignore self, question: str, evidence: str, evidence_mode: int = 0, **kwargs ) -> Generator[Document, None, Document]: - """Answer the question based on the evidence - - In addition to the question and the evidence, this method also take into - account evidence_mode. The evidence_mode tells which kind of evidence is. - The kind of evidence affects: - 1. How the evidence is represented. - 2. The prompt to generate the answer. - - By default, the evidence_mode is 0, which means the evidence is plain text with - no particular semantic representation. The evidence_mode can be: - 1. "table": There will be HTML markup telling that there is a table - within the evidence. - 2. "chatbot": There will be HTML markup telling that there is a chatbot. - This chatbot is a scenario, extracted from an Excel file, where each - row corresponds to an interaction. - - Args: - question: the original question posed by user - evidence: the text that contain relevant information to answer the question - (determined by retrieval pipeline) - evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot - """ - if evidence_mode == EVIDENCE_MODE_TEXT: - prompt_template = PromptTemplate(self.qa_template) - elif evidence_mode == EVIDENCE_MODE_TABLE: - prompt_template = PromptTemplate(self.qa_table_template) - elif evidence_mode == EVIDENCE_MODE_FIGURE: - prompt_template = PromptTemplate(self.qa_figure_template) - else: - prompt_template = PromptTemplate(self.qa_chatbot_template) - - images = [] - if evidence_mode == EVIDENCE_MODE_FIGURE: - # isolate image from evidence - evidence, images = self.extract_evidence_images(evidence) - prompt = prompt_template.populate( - context=evidence, - question=question, - lang=self.lang, - ) - else: - prompt = prompt_template.populate( - context=evidence, - question=question, - lang=self.lang, - ) + prompt, images = self.get_prompt(question, evidence, evidence_mode) output = "" if evidence_mode == EVIDENCE_MODE_FIGURE: @@ -425,51 +419,35 @@ class Config: rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx() use_rewrite: bool = False - async def ainvoke( # type: ignore - self, message: str, conv_id: str, history: list, **kwargs # type: ignore - ) -> Document: # type: ignore - import markdown - - docs = [] - doc_ids = [] - if self.use_rewrite: - rewrite = await self.rewrite_pipeline(question=message) - message = rewrite.text - + def retrieve(self, message: str) -> tuple[list[RetrievedDocument], list[Document]]: + """Retrieve the documents based on the message""" + docs, doc_ids = [], [] for retriever in self.retrievers: for doc in retriever(text=message): if doc.doc_id not in doc_ids: docs.append(doc) doc_ids.append(doc.doc_id) + + info = [] for doc in docs: - # TODO: a better approach to show the information - text = markdown.markdown( - doc.text, extensions=["markdown.extensions.tables"] - ) - self.report_output( - { - "evidence": ( - "
" - f"{doc.metadata['file_name']}" - f"{text}" - "

" - ) - } + info.append( + Document( + channel="info", + content=Render.collapsible( + header=doc.metadata["file_name"], + content=Render.table(doc.text), + open=True, + ), + ) ) - await asyncio.sleep(0.1) - evidence_mode, evidence = self.evidence_pipeline(docs).content - answer = await self.answering_pipeline( - question=message, - history=history, - evidence=evidence, - evidence_mode=evidence_mode, - conv_id=conv_id, - **kwargs, - ) + return docs, info - # prepare citation + def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]: + """Prepare the citations to show on the UI""" + with_citation, without_citation = [], [] spans = defaultdict(list) + if answer.metadata["citation"] is not None: for fact_with_evidence in answer.metadata["citation"].answer: for quote in fact_with_evidence.substring_quote: @@ -500,9 +478,7 @@ async def ainvoke( # type: ignore break id2docs = {doc.doc_id: doc for doc in docs} - lack_evidence = True not_detected = set(id2docs.keys()) - set(spans.keys()) - self.report_output({"evidence": None}) for id, ss in spans.items(): if not ss: not_detected.add(id) @@ -510,48 +486,74 @@ async def ainvoke( # type: ignore ss = sorted(ss, key=lambda x: x["start"]) text = id2docs[id].text[: ss[0]["start"]] for idx, span in enumerate(ss): - text += ( - "" + id2docs[id].text[span["start"] : span["end"]] + "" - ) + text += Render.highlight(id2docs[id].text[span["start"] : span["end"]]) if idx < len(ss) - 1: text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] text += id2docs[id].text[ss[-1]["end"] :] - text_out = markdown.markdown( - text, extensions=["markdown.extensions.tables"] + with_citation.append( + Document( + channel="info", + content=Render.collapsible( + header=id2docs[id].metadata["file_name"], + content=Render.table(text), + open=True, + ), + ) ) - self.report_output( - { - "evidence": ( - "
" - f"{id2docs[id].metadata['file_name']}" - f"{text_out}" - "

" - ) - } + + without_citation = [ + Document( + channel="info", + content=Render.collapsible( + header=id2docs[id].metadata["file_name"], + content=Render.table(id2docs[id].text), + open=False, + ), ) - lack_evidence = False + for id in list(not_detected) + ] - if lack_evidence: - self.report_output({"evidence": "No evidence found.\n"}) + return with_citation, without_citation - if not_detected: - self.report_output( - {"evidence": "Retrieved segments without matching evidence:\n"} - ) - for id in list(not_detected): - text_out = markdown.markdown( - id2docs[id].text, extensions=["markdown.extensions.tables"] - ) + async def ainvoke( # type: ignore + self, message: str, conv_id: str, history: list, **kwargs # type: ignore + ) -> Document: # type: ignore + if self.use_rewrite: + rewrite = await self.rewrite_pipeline(question=message) + message = rewrite.text + + docs, infos = self.retrieve(message) + for _ in infos: + self.report_output(_) + await asyncio.sleep(0.1) + + evidence_mode, evidence = self.evidence_pipeline(docs).content + answer = await self.answering_pipeline( + question=message, + history=history, + evidence=evidence, + evidence_mode=evidence_mode, + conv_id=conv_id, + **kwargs, + ) + + # show the evidence + with_citation, without_citation = self.prepare_citations(answer, docs) + if not with_citation and not without_citation: + self.report_output(Document(channel="info", content="No evidence found.\n")) + else: + self.report_output(Document(channel="info", content=None)) + for _ in with_citation: + self.report_output(_) + if without_citation: self.report_output( - { - "evidence": ( - "
" - f"{id2docs[id].metadata['file_name']}" - f"{text_out}" - "

" - ) - } + Document( + channel="info", + content="Retrieved segments without matching evidence:\n", + ) ) + for _ in without_citation: + self.report_output(_) self.report_output(None) return answer @@ -559,32 +561,12 @@ async def ainvoke( # type: ignore def stream( # type: ignore self, message: str, conv_id: str, history: list, **kwargs # type: ignore ) -> Generator[Document, None, Document]: - import markdown - - docs = [] - doc_ids = [] if self.use_rewrite: message = self.rewrite_pipeline(question=message).text - for retriever in self.retrievers: - for doc in retriever(text=message): - if doc.doc_id not in doc_ids: - docs.append(doc) - doc_ids.append(doc.doc_id) - for doc in docs: - # TODO: a better approach to show the information - text = markdown.markdown( - doc.text, extensions=["markdown.extensions.tables"] - ) - yield Document( - content=( - "
" - f"{doc.metadata['file_name']}" - f"{text}" - "

" - ), - channel="info", - ) + docs, infos = self.retrieve(message) + for _ in infos: + yield _ evidence_mode, evidence = self.evidence_pipeline(docs).content answer = yield from self.answering_pipeline.stream( @@ -596,89 +578,21 @@ def stream( # type: ignore **kwargs, ) - # prepare citation - spans = defaultdict(list) - if answer.metadata["citation"] is not None: - for fact_with_evidence in answer.metadata["citation"].answer: - for quote in fact_with_evidence.substring_quote: - for doc in docs: - start_idx = doc.text.find(quote) - if start_idx == -1: - continue - - end_idx = start_idx + len(quote) - - current_idx = start_idx - if "|" not in doc.text[start_idx:end_idx]: - spans[doc.doc_id].append( - {"start": start_idx, "end": end_idx} - ) - else: - while doc.text[current_idx:end_idx].find("|") != -1: - match_idx = doc.text[current_idx:end_idx].find("|") - spans[doc.doc_id].append( - { - "start": current_idx, - "end": current_idx + match_idx, - } - ) - current_idx += match_idx + 2 - if current_idx > end_idx: - break - break - - id2docs = {doc.doc_id: doc for doc in docs} - lack_evidence = True - not_detected = set(id2docs.keys()) - set(spans.keys()) - yield Document(channel="info", content=None) - for id, ss in spans.items(): - if not ss: - not_detected.add(id) - continue - ss = sorted(ss, key=lambda x: x["start"]) - text = id2docs[id].text[: ss[0]["start"]] - for idx, span in enumerate(ss): - text += ( - "" + id2docs[id].text[span["start"] : span["end"]] + "" - ) - if idx < len(ss) - 1: - text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] - text += id2docs[id].text[ss[-1]["end"] :] - text_out = markdown.markdown( - text, extensions=["markdown.extensions.tables"] - ) - yield Document( - content=( - "
" - f"{id2docs[id].metadata['file_name']}" - f"{text_out}" - "

" - ), - channel="info", - ) - lack_evidence = False - - if lack_evidence: + # show the evidence + with_citation, without_citation = self.prepare_citations(answer, docs) + if not with_citation and not without_citation: yield Document(channel="info", content="No evidence found.\n") - - if not_detected: - yield Document( - channel="info", - content="Retrieved segments without matching evidence:\n", - ) - for id in list(not_detected): - text_out = markdown.markdown( - id2docs[id].text, extensions=["markdown.extensions.tables"] - ) + else: + yield Document(channel="info", content=None) + for _ in with_citation: + yield _ + if without_citation: yield Document( - content=( - "
" - f"{id2docs[id].metadata['file_name']}" - f"{text_out}" - "

" - ), channel="info", + content="Retrieved segments without matching evidence:\n", ) + for _ in without_citation: + yield _ return answer diff --git a/libs/ktem/ktem/utils/__init__.py b/libs/ktem/ktem/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/ktem/ktem/utils/render.py b/libs/ktem/ktem/utils/render.py new file mode 100644 index 000000000..5890d3327 --- /dev/null +++ b/libs/ktem/ktem/utils/render.py @@ -0,0 +1,21 @@ +import markdown + + +class Render: + """Default text rendering into HTML for the UI""" + + @staticmethod + def collapsible(header, content, open: bool = False) -> str: + """Render an HTML friendly collapsible section""" + o = " open" if open else "" + return f"{header}{content}
" + + @staticmethod + def table(text: str) -> str: + """Render table from markdown format into HTML""" + return markdown.markdown(text, extensions=["markdown.extensions.tables"]) + + @staticmethod + def highlight(text: str) -> str: + """Highlight text""" + return f"{text}"