From 033febe8e9bfe0b248959118156e7b564d81d8bd Mon Sep 17 00:00:00 2001 From: "luca.gobbi" Date: Thu, 3 Oct 2024 08:24:09 +0200 Subject: [PATCH 1/2] run stray in threadpool to allow tool execution on http message endpoint --- core/cat/routes/base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/cat/routes/base.py b/core/cat/routes/base.py index d9d021e5..fae243c6 100644 --- a/core/cat/routes/base.py +++ b/core/cat/routes/base.py @@ -1,4 +1,5 @@ from fastapi import APIRouter, Depends, Body +from fastapi.concurrency import run_in_threadpool from typing import Dict import tomli from cat.auth.permissions import AuthPermission, AuthResource @@ -27,5 +28,11 @@ async def message_with_cat( stray=Depends(HTTPAuth(AuthResource.CONVERSATION, AuthPermission.WRITE)), ) -> Dict: """Get a response from the Cat""" - answer = await stray({"user_id": stray.user_id, **payload}) + answer = await run_in_threadpool(stray_run, stray, {"user_id": stray.user_id, **payload}) return answer + + +def stray_run(stray, user_message_json): + cat_message = stray.loop.run_until_complete(stray(user_message_json)) + return cat_message + From a494d205d7f9c85b3c20bd0dbf0d083af76634c3 Mon Sep 17 00:00:00 2001 From: "luca.gobbi" Date: Sat, 5 Oct 2024 08:58:48 +0200 Subject: [PATCH 2/2] hide complexity in stray.run --- core/cat/looking_glass/stray_cat.py | 17 +++++++++++------ core/cat/routes/base.py | 9 ++------- core/cat/routes/websocket.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py index 817cf669..f18bf220 100644 --- a/core/cat/looking_glass/stray_cat.py +++ b/core/cat/looking_glass/stray_cat.py @@ -450,17 +450,22 @@ async def __call__(self, message_dict): return final_output - def run(self, user_message_json): + def run(self, user_message_json, return_message=False): try: cat_message = self.loop.run_until_complete(self.__call__(user_message_json)) - # send message back to client - self.send_chat_message(cat_message) + if return_message: + # return the message for HTTP usage + return cat_message + else: + # send message back to client via WS + self.send_chat_message(cat_message) except Exception as e: - # Log any unexpected errors log.error(e) traceback.print_exc() - # Send error as websocket message - self.send_error(e) + if return_message: + return {"error": str(e)} + else: + self.send_error(e) def classify( self, sentence: str, labels: List[str] | Dict[str, List[str]] diff --git a/core/cat/routes/base.py b/core/cat/routes/base.py index fae243c6..e24fafbf 100644 --- a/core/cat/routes/base.py +++ b/core/cat/routes/base.py @@ -28,11 +28,6 @@ async def message_with_cat( stray=Depends(HTTPAuth(AuthResource.CONVERSATION, AuthPermission.WRITE)), ) -> Dict: """Get a response from the Cat""" - answer = await run_in_threadpool(stray_run, stray, {"user_id": stray.user_id, **payload}) + user_message_json = {"user_id": stray.user_id, **payload} + answer = await run_in_threadpool(stray.run, user_message_json, True) return answer - - -def stray_run(stray, user_message_json): - cat_message = stray.loop.run_until_complete(stray(user_message_json)) - return cat_message - diff --git a/core/cat/routes/websocket.py b/core/cat/routes/websocket.py index e415e259..8031bf54 100644 --- a/core/cat/routes/websocket.py +++ b/core/cat/routes/websocket.py @@ -22,7 +22,7 @@ async def receive_message(websocket: WebSocket, stray: StrayCat): user_message["user_id"] = stray.user_id # Run the `stray` object's method in a threadpool since it might be a CPU-bound operation. - await run_in_threadpool(stray.run, user_message) + await run_in_threadpool(stray.run, user_message, return_message=False) @router.websocket("/ws")