diff --git a/libs/kotaemon/kotaemon/llms/chats/openai.py b/libs/kotaemon/kotaemon/llms/chats/openai.py index b1ae8726..df46da24 100644 --- a/libs/kotaemon/kotaemon/llms/chats/openai.py +++ b/libs/kotaemon/kotaemon/llms/chats/openai.py @@ -196,6 +196,10 @@ def openai_response(self, client, **kwargs): """Get the openai response""" raise NotImplementedError + async def aopenai_response(self, client, **kwargs): + """Get the openai response""" + raise NotImplementedError + def invoke( self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs ) -> LLMInterface: @@ -211,8 +215,10 @@ async def ainvoke( ) -> LLMInterface: client = self.prepare_client(async_version=True) input_messages = self.prepare_message(messages) - resp = await self.openai_response( - client, messages=input_messages, stream=False, **kwargs + resp = ( + await self.aopenai_response( + client, messages=input_messages, stream=False, **kwargs + ) ).dict() return self.prepare_output(resp) @@ -290,8 +296,7 @@ def prepare_client(self, async_version: bool = False): return OpenAI(**params) - def openai_response(self, client, **kwargs): - """Get the openai response""" + def prepare_params(self, **kwargs): if "tools_pydantic" in kwargs: kwargs.pop("tools_pydantic") @@ -313,8 +318,17 @@ def openai_response(self, client, **kwargs): params = {k: v for k, v in params_.items() if v is not None} params.update(kwargs) + return params + + def openai_response(self, client, **kwargs): + """Get the openai response""" + params = self.prepare_params(**kwargs) return client.chat.completions.create(**params) + async def aopenai_response(self, client, **kwargs): + params = self.prepare_params(**kwargs) + return await client.chat.completions.create(**params) + class AzureChatOpenAI(BaseChatOpenAI): """OpenAI chat model provided by Microsoft Azure""" @@ -361,8 +375,7 @@ def prepare_client(self, async_version: bool = False): return AzureOpenAI(**params) - def openai_response(self, client, **kwargs): - """Get the openai response""" + def prepare_params(self, **kwargs): if "tools_pydantic" in kwargs: kwargs.pop("tools_pydantic") @@ -384,4 +397,13 @@ def openai_response(self, client, **kwargs): params = {k: v for k, v in params_.items() if v is not None} params.update(kwargs) + return params + + def openai_response(self, client, **kwargs): + """Get the openai response""" + params = self.prepare_params(**kwargs) return client.chat.completions.create(**params) + + async def aopenai_response(self, client, **kwargs): + params = self.prepare_params(**kwargs) + return await client.chat.completions.create(**params) diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py index 3db13ed8..4b546486 100644 --- a/libs/ktem/ktem/pages/chat/chat_panel.py +++ b/libs/ktem/ktem/pages/chat/chat_panel.py @@ -26,7 +26,7 @@ def on_building_ui(self): scale=20, file_count="multiple", placeholder=( - "Type a message, or search the @web, " "tag a file with @filename" + "Type a message, search the @web, or tag a file with @filename" ), container=False, show_label=False,