From c6045bcb9f0e4e1ef2a01c98fcc2c2f2cc990aff Mon Sep 17 00:00:00 2001 From: "Duc Nguyen (john)" Date: Sat, 20 Apr 2024 01:12:23 +0700 Subject: [PATCH] Update the Citation pipeline according to new OpenAI function call interface (#40) --- libs/kotaemon/kotaemon/indices/qa/citation.py | 19 +++---- libs/kotaemon/kotaemon/llms/chats/openai.py | 52 +++++++++---------- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/libs/kotaemon/kotaemon/indices/qa/citation.py b/libs/kotaemon/kotaemon/indices/qa/citation.py index 3192a07fa..bf8def015 100644 --- a/libs/kotaemon/kotaemon/indices/qa/citation.py +++ b/libs/kotaemon/kotaemon/indices/qa/citation.py @@ -75,8 +75,8 @@ def prepare_llm(self, context: str, question: str): "parameters": schema, } llm_kwargs = { - "functions": [function], - "function_call": {"name": function["name"]}, + "tools": [{"type": "function", "function": function}], + "tool_choice": "auto", } messages = [ SystemMessage( @@ -99,14 +99,13 @@ def prepare_llm(self, context: str, question: str): def invoke(self, context: str, question: str): messages, llm_kwargs = self.prepare_llm(context, question) - try: print("CitationPipeline: invoking LLM") llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs) print("CitationPipeline: finish invoking LLM") if not llm_output.messages: return None - function_output = llm_output.messages[0].additional_kwargs["function_call"][ + function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][ "arguments" ] output = QuestionAnswer.parse_raw(function_output) @@ -123,16 +122,12 @@ async def ainvoke(self, context: str, question: str): print("CitationPipeline: async invoking LLM") llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs) print("CitationPipeline: finish async invoking LLM") + function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][ + "arguments" + ] + output = QuestionAnswer.parse_raw(function_output) except Exception as e: print(e) return None - if not llm_output.messages: - return None - - function_output = llm_output.messages[0].additional_kwargs["function_call"][ - "arguments" - ] - output = QuestionAnswer.parse_raw(function_output) - return output diff --git a/libs/kotaemon/kotaemon/llms/chats/openai.py b/libs/kotaemon/kotaemon/llms/chats/openai.py index e0a3a855e..b12567d9c 100644 --- a/libs/kotaemon/kotaemon/llms/chats/openai.py +++ b/libs/kotaemon/kotaemon/llms/chats/openai.py @@ -152,6 +152,28 @@ def prepare_message( return output_ + def prepare_output(self, resp: dict) -> LLMInterface: + """Convert the OpenAI response into LLMInterface""" + additional_kwargs = {} + if "tool_calls" in resp["choices"][0]["message"]: + additional_kwargs["tool_calls"] = resp["choices"][0]["message"][ + "tool_calls" + ] + output = LLMInterface( + candidates=[(_["message"]["content"] or "") for _ in resp["choices"]], + content=resp["choices"][0]["message"]["content"] or "", + total_tokens=resp["usage"]["total_tokens"], + prompt_tokens=resp["usage"]["prompt_tokens"], + completion_tokens=resp["usage"]["completion_tokens"], + additional_kwargs=additional_kwargs, + messages=[ + AIMessage(content=(_["message"]["content"]) or "") + for _ in resp["choices"] + ], + ) + + return output + def prepare_client(self, async_version: bool = False): """Get the OpenAI client @@ -172,19 +194,7 @@ def invoke( resp = self.openai_response( client, messages=input_messages, stream=False, **kwargs ).dict() - - output = LLMInterface( - candidates=[_["message"]["content"] for _ in resp["choices"]], - content=resp["choices"][0]["message"]["content"], - total_tokens=resp["usage"]["total_tokens"], - prompt_tokens=resp["usage"]["prompt_tokens"], - completion_tokens=resp["usage"]["completion_tokens"], - messages=[ - AIMessage(content=_["message"]["content"]) for _ in resp["choices"] - ], - ) - - return output + return self.prepare_output(resp) async def ainvoke( self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs @@ -195,18 +205,7 @@ async def ainvoke( client, messages=input_messages, stream=False, **kwargs ).dict() - output = LLMInterface( - candidates=[_["message"]["content"] for _ in resp["choices"]], - content=resp["choices"][0]["message"]["content"], - total_tokens=resp["usage"]["total_tokens"], - prompt_tokens=resp["usage"]["prompt_tokens"], - completion_tokens=resp["usage"]["completion_tokens"], - messages=[ - AIMessage(content=_["message"]["content"]) for _ in resp["choices"] - ], - ) - - return output + return self.prepare_output(resp) def stream( self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs @@ -338,7 +337,7 @@ def prepare_client(self, async_version: bool = False): def openai_response(self, client, **kwargs): """Get the openai response""" - params = { + params_ = { "model": self.azure_deployment, "temperature": self.temperature, "max_tokens": self.max_tokens, @@ -353,6 +352,7 @@ def openai_response(self, client, **kwargs): "top_logprobs": self.top_logprobs, "top_p": self.top_p, } + params = {k: v for k, v in params_.items() if v is not None} params.update(kwargs) return client.chat.completions.create(**params)