Commit

Update the Citation pipeline according to new OpenAI function call interface (#40)
trducng authored Apr 19, 2024
1 parent 1b2082a commit c6045bc
Showing 2 changed files with 33 additions and 38 deletions.
19 changes: 7 additions & 12 deletions libs/kotaemon/kotaemon/indices/qa/citation.py
@@ -75,8 +75,8 @@ def prepare_llm(self, context: str, question: str):
             "parameters": schema,
         }
         llm_kwargs = {
-            "functions": [function],
-            "function_call": {"name": function["name"]},
+            "tools": [{"type": "function", "function": function}],
+            "tool_choice": "auto",
         }
         messages = [
             SystemMessage(
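
For reference, a minimal sketch of how the new-style kwargs reach the OpenAI v1 chat-completions endpoint (the client, model name, and message content here are assumptions for illustration; `function` is the spec built earlier in prepare_llm):

    # Sketch only, assuming an openai>=1.0 client and the `function` dict
    # constructed above in prepare_llm.
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",  # hypothetical model name
        messages=[{"role": "user", "content": "..."}],
        tools=[{"type": "function", "function": function}],
        tool_choice="auto",
    )

Note one semantic shift: the legacy kwargs forced the named function via `function_call={"name": ...}`, while `tool_choice="auto"` lets the model decide whether to call it. The forced equivalent in the new interface would be `tool_choice={"type": "function", "function": {"name": function["name"]}}`.
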
@@ -99,14 +99,13 @@ def prepare_llm(self, context: str, question: str):
 
     def invoke(self, context: str, question: str):
         messages, llm_kwargs = self.prepare_llm(context, question)
-
         try:
             print("CitationPipeline: invoking LLM")
             llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
             print("CitationPipeline: finish invoking LLM")
             if not llm_output.messages:
                 return None
-            function_output = llm_output.messages[0].additional_kwargs["function_call"][
+            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
                 "arguments"
             ]
             output = QuestionAnswer.parse_raw(function_output)
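
The new access path follows the tools response shape, where the reply carries a list of tool calls and the arguments arrive as a JSON string. Roughly (a sketch; the exact payload is an assumption based on the OpenAI v1 API, with placeholder values):

    # Assumed shape of llm_output.additional_kwargs under the new interface:
    # {
    #     "tool_calls": [
    #         {
    #             "id": "call_0",
    #             "type": "function",
    #             "function": {"name": "...", "arguments": "{...json...}"},
    #         }
    #     ]
    # }
    # "arguments" is a JSON string; QuestionAnswer.parse_raw() decodes and
    # validates it in one step (pydantic v1-style API).
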
@@ -123,16 +122,12 @@ async def ainvoke(self, context: str, question: str):
             print("CitationPipeline: async invoking LLM")
             llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
             print("CitationPipeline: finish async invoking LLM")
+            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
+                "arguments"
+            ]
+            output = QuestionAnswer.parse_raw(function_output)
         except Exception as e:
             print(e)
             return None
-
-        if not llm_output.messages:
-            return None
-
-        function_output = llm_output.messages[0].additional_kwargs["function_call"][
-            "arguments"
-        ]
-        output = QuestionAnswer.parse_raw(function_output)
 
         return output
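
With this hunk the tool-call extraction sits inside the `try` block, so a response without `tool_calls` is caught, printed, and turned into a `None` return instead of raising an unhandled `KeyError` out of `ainvoke`.
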
52 changes: 26 additions & 26 deletions libs/kotaemon/kotaemon/llms/chats/openai.py
@@ -152,6 +152,28 @@ def prepare_message(
 
         return output_
 
+    def prepare_output(self, resp: dict) -> LLMInterface:
+        """Convert the OpenAI response into LLMInterface"""
+        additional_kwargs = {}
+        if "tool_calls" in resp["choices"][0]["message"]:
+            additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
+                "tool_calls"
+            ]
+        output = LLMInterface(
+            candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
+            content=resp["choices"][0]["message"]["content"] or "",
+            total_tokens=resp["usage"]["total_tokens"],
+            prompt_tokens=resp["usage"]["prompt_tokens"],
+            completion_tokens=resp["usage"]["completion_tokens"],
+            additional_kwargs=additional_kwargs,
+            messages=[
+                AIMessage(content=(_["message"]["content"]) or "")
+                for _ in resp["choices"]
+            ],
+        )
+
+        return output
+
     def prepare_client(self, async_version: bool = False):
         """Get the OpenAI client
@@ -172,19 +194,7 @@ def invoke(
         resp = self.openai_response(
             client, messages=input_messages, stream=False, **kwargs
         ).dict()
-
-        output = LLMInterface(
-            candidates=[_["message"]["content"] for _ in resp["choices"]],
-            content=resp["choices"][0]["message"]["content"],
-            total_tokens=resp["usage"]["total_tokens"],
-            prompt_tokens=resp["usage"]["prompt_tokens"],
-            completion_tokens=resp["usage"]["completion_tokens"],
-            messages=[
-                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
-            ],
-        )
-
-        return output
+        return self.prepare_output(resp)
 
     async def ainvoke(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
@@ -195,18 +205,7 @@ async def ainvoke(
             client, messages=input_messages, stream=False, **kwargs
         ).dict()
 
-        output = LLMInterface(
-            candidates=[_["message"]["content"] for _ in resp["choices"]],
-            content=resp["choices"][0]["message"]["content"],
-            total_tokens=resp["usage"]["total_tokens"],
-            prompt_tokens=resp["usage"]["prompt_tokens"],
-            completion_tokens=resp["usage"]["completion_tokens"],
-            messages=[
-                AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
-            ],
-        )
-
-        return output
+        return self.prepare_output(resp)
 
     def stream(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
@@ -338,7 +337,7 @@ def prepare_client(self, async_version: bool = False):
 
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
-        params = {
+        params_ = {
             "model": self.azure_deployment,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
@@ -353,6 +352,7 @@ def openai_response(self, client, **kwargs):
             "top_logprobs": self.top_logprobs,
             "top_p": self.top_p,
         }
+        params = {k: v for k, v in params_.items() if v is not None}
         params.update(kwargs)
 
         return client.chat.completions.create(**params)
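
The rename to `params_` plus the dict-comprehension filter drops options that were never configured, so they are omitted from the request rather than sent as explicit nulls. In miniature:

    # Minimal sketch of the filtering pattern used above.
    params_ = {"model": "my-deployment", "temperature": None, "max_tokens": 512}
    params = {k: v for k, v in params_.items() if v is not None}
    assert params == {"model": "my-deployment", "max_tokens": 512}

Caller-supplied kwargs are applied after the filter, so an explicit per-call override still wins.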
