Skip to content

Commit

Permalink
Feat/regenerate answer (Cinnamon#7)
Browse files Browse the repository at this point in the history
* Add regen button and repharasing question on regen

* Stop appending regen messages to history, allow only one

* Add dynamic conversation state

* Allow reasoning pipeline to manipulate state

---------

Co-authored-by: albert <albert@cinnamon.is>
Co-authored-by: Duc Nguyen (john) <trungduc1992@gmail.com>
  • Loading branch information
3 people authored and phv2312 committed Sep 11, 2024
1 parent 6e3be69 commit 28a02f4
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 24 deletions.
2 changes: 0 additions & 2 deletions libs/kotaemon/kotaemon/loaders/utils/adobe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

from kotaemon.loaders.utils.gpt4v import generate_gpt4v

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))


def request_adobe_service(file_path: str, output_path: str = "") -> str:
"""Main function to call the adobe service, and unzip the results.
Expand Down
1 change: 1 addition & 0 deletions libs/ktem/ktem/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class BaseApp:
The main application contains app-level information:
- setting state
- dynamic conversation state
- user id
Also contains registering methods for:
Expand Down
89 changes: 77 additions & 12 deletions libs/ktem/ktem/pages/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlmodel import Session, select

from .chat_panel import ChatPanel
from .common import STATE
from .control import ConversationControl
from .report import ReportIssue

Expand All @@ -21,6 +22,7 @@ def __init__(self, app):

def on_building_ui(self):
with gr.Row():
self.chat_state = gr.State(STATE)
with gr.Column(scale=1):
self.chat_control = ConversationControl(self._app)

Expand Down Expand Up @@ -62,19 +64,47 @@ def on_register_events(self):
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.chat_state,
]
+ self._indices_input,
outputs=[
self.chat_panel.text_input,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
],
show_progress="minimal",
).then(
fn=self.update_data_source,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self.chat_state,
]
+ self._indices_input,
outputs=None,
)

self.chat_panel.regen_btn.click(
fn=self.regen_fn,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
self.chat_state,
]
+ self._indices_input,
outputs=[
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
],
show_progress="minimal",
).then(
fn=self.update_data_source,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self.chat_state,
]
+ self._indices_input,
outputs=None,
Expand All @@ -94,6 +124,7 @@ def on_register_events(self):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
Expand All @@ -109,12 +140,13 @@ def on_register_events(self):
self.chat_panel.chatbot,
self._app.settings_state,
self._app.user_id,
self.chat_state,
]
+ self._indices_input,
outputs=None,
)

def update_data_source(self, convo_id, messages, *selecteds):
def update_data_source(self, convo_id, messages, state, *selecteds):
"""Update the data source"""
if not convo_id:
gr.Warning("No conversation selected")
Expand All @@ -133,6 +165,7 @@ def update_data_source(self, convo_id, messages, *selecteds):
result.data_source = {
"selected": selecteds_,
"messages": messages,
"state": state,
"likes": deepcopy(data_source.get("likes", [])),
}
session.add(result)
Expand All @@ -152,17 +185,22 @@ def is_liked(self, convo_id, liked: gr.LikeData):
session.add(result)
session.commit()

def create_pipeline(self, settings: dict, *selecteds):
def create_pipeline(self, settings: dict, state: dict, *selecteds):
"""Create the pipeline from settings
Args:
settings: the settings of the app
is_regen: whether the regen button is clicked
selected: the list of file ids that will be served as context. If None, then
consider using all files
Returns:
the pipeline objects
- the pipeline objects
"""
reasoning_mode = settings["reasoning.use"]
reasoning_cls = reasonings[reasoning_mode]
reasoning_id = reasoning_cls.get_info()["id"]

# get retrievers
retrievers = []
for index in self._app.index_manager.indices:
Expand All @@ -172,21 +210,25 @@ def create_pipeline(self, settings: dict, *selecteds):
iretrievers = index.get_retriever_pipelines(settings, index_selected)
retrievers += iretrievers

reasoning_mode = settings["reasoning.use"]
reasoning_cls = reasonings[reasoning_mode]
pipeline = reasoning_cls.get_pipeline(settings, retrievers)
# prepare states
reasoning_state = {
"app": deepcopy(state["app"]),
"pipeline": deepcopy(state.get(reasoning_id, {})),
}

return pipeline
pipeline = reasoning_cls.get_pipeline(settings, reasoning_state, retrievers)

async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
return pipeline, reasoning_state

async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]

queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()

# construct the pipeline
pipeline = self.create_pipeline(settings, *selecteds)
pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
pipeline.set_output_queue(queue)

asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
Expand All @@ -198,7 +240,8 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
try:
response = queue.get_nowait()
except Exception:
yield "", chat_history + [(chat_input, text or "Thinking ...")], refs
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield chat_history + [(chat_input, text or "Thinking ...")], refs, state
continue

if response is None:
Expand All @@ -208,6 +251,7 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):

if "output" in response:
text += response["output"]

if "evidence" in response:
if response["evidence"] is None:
refs = ""
Expand All @@ -218,4 +262,25 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
print(f"Len refs: {len(refs)}")
len_ref = len(refs)

yield "", chat_history + [(chat_input, text)], refs
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield chat_history + [(chat_input, text)], refs, state

async def regen_fn(
self, conversation_id, chat_history, settings, state, *selecteds
):
"""Regen function"""
if not chat_history:
gr.Warning("Empty chat")
yield chat_history, "", state
return

state["app"]["regen"] = True
async for chat, refs, state in self.chat_fn(
conversation_id, chat_history, settings, state, *selecteds
):
new_state = deepcopy(state)
new_state["app"]["regen"] = False
yield chat, refs, new_state
else:
state["app"]["regen"] = False
yield chat_history, "", state
1 change: 1 addition & 0 deletions libs/ktem/ktem/pages/chat/chat_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def on_building_ui(self):
placeholder="Chat input", scale=15, container=False
)
self.submit_btn = gr.Button(value="Send", scale=1, min_width=10)
self.regen_btn = gr.Button(value="Regen", scale=1, min_width=10)

def submit_msg(self, chat_input, chat_history):
"""Submit a message to the chatbot"""
Expand Down
4 changes: 4 additions & 0 deletions libs/ktem/ktem/pages/chat/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
DEFAULT_APPLICATION_STATE = {"regen": False}
STATE = {
"app": DEFAULT_APPLICATION_STATE,
}
6 changes: 5 additions & 1 deletion libs/ktem/ktem/pages/chat/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select

from .common import STATE

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -159,12 +161,14 @@ def select_conv(self, conversation_id):
name = result.name
selected = result.data_source.get("selected", {})
chats = result.data_source.get("messages", [])
state = result.data_source.get("state", STATE)
except Exception as e:
logger.warning(e)
id_ = ""
name = ""
selected = {}
chats = []
state = STATE

indices = []
for index in self._app.index_manager.indices:
Expand All @@ -173,7 +177,7 @@ def select_conv(self, conversation_id):
continue
indices.append(selected.get(str(index.id), []))

return id_, id_, name, chats, *indices
return id_, id_, name, chats, state, *indices

def rename_conv(self, conversation_id, new_name, user_id):
"""Rename the conversation"""
Expand Down
2 changes: 2 additions & 0 deletions libs/ktem/ktem/pages/chat/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def report(
chat_history: list,
settings: dict,
user_id: Optional[int],
chat_state: dict,
*selecteds
):
selecteds_ = {}
Expand All @@ -65,6 +66,7 @@ def report(
chat={
"conv_id": conv_id,
"chat_history": chat_history,
"chat_state": chat_state,
"selecteds": selecteds_,
},
settings=settings,
Expand Down
70 changes: 61 additions & 9 deletions libs/ktem/ktem/reasoning/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import tiktoken
from ktem.components import llms
from ktem.reasoning.base import BaseReasoning
from theflow.settings import settings as flowsettings

from kotaemon.base import (
Expand Down Expand Up @@ -164,6 +163,15 @@ def run(self, docs: list[RetrievedDocument]) -> Document:
"Answer: "
)

DEFAULT_REWRITE_PROMPT = (
"Given the following question, rephrase and expand it "
"to help you do better answering. Maintain all information "
"in the original question. Keep the question as concise as possible. "
"Give answer in {lang}\n"
"Original question: {question}\n"
"Rephrased question: "
)


class AnswerWithContextPipeline(BaseComponent):
"""Answer the question based on the evidence
Expand Down Expand Up @@ -287,15 +295,48 @@ async def run( # type: ignore

return answer

def extract_evidence_images(self, evidence: str):
"""Util function to extract and isolate images from context/evidence"""
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
matches = re.findall(image_pattern, evidence)
context = re.sub(image_pattern, "", evidence)
return context, matches

def extract_evidence_images(self, evidence: str):
"""Util function to extract and isolate images from context/evidence"""
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
matches = re.findall(image_pattern, evidence)
context = re.sub(image_pattern, "", evidence)
return context, matches


class RewriteQuestionPipeline(BaseComponent):
"""Rewrite user question
Args:
llm: the language model to rewrite question
rewrite_template: the prompt template for llm to paraphrase a text input
lang: the language of the answer. Currently support English and Japanese
"""

llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost())
rewrite_template: str = DEFAULT_REWRITE_PROMPT

lang: str = "English"

async def run(self, question: str) -> Document: # type: ignore
prompt_template = PromptTemplate(self.rewrite_template)
prompt = prompt_template.populate(question=question, lang=self.lang)
messages = [
SystemMessage(content="You are a helpful assistant"),
HumanMessage(content=prompt),
]
output = ""
for text in self.llm(messages):
if "content" in text:
output += text[1]
self.report_output({"chat_input": text[1]})
break
await asyncio.sleep(0)

return Document(text=output)

class FullQAPipeline(BaseReasoning):

class FullQAPipeline(BaseComponent):
"""Question answering pipeline. Handle from question to answer"""

class Config:
Expand All @@ -305,12 +346,18 @@ class Config:

evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
use_rewrite: bool = False

async def run( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Document: # type: ignore
docs = []
doc_ids = []
if self.use_rewrite:
rewrite = await self.rewrite_pipeline(question=message)
message = rewrite.text

for retriever in self.retrievers:
for doc in retriever(text=message):
if doc.doc_id not in doc_ids:
Expand Down Expand Up @@ -402,7 +449,7 @@ async def run( # type: ignore
return answer

@classmethod
def get_pipeline(cls, settings, retrievers):
def get_pipeline(cls, settings, states, retrievers):
"""Get the reasoning pipeline
Args:
Expand Down Expand Up @@ -430,6 +477,11 @@ def get_pipeline(cls, settings, retrievers):
pipeline.answering_pipeline.qa_template = settings[
f"reasoning.options.{_id}.qa_prompt"
]
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
pipeline.rewrite_pipeline.llm = llms.get_lowest_cost()
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English"
)
return pipeline

@classmethod
Expand Down

0 comments on commit 28a02f4

Please sign in to comment.