diff --git a/chat_server/blueprints/personas.py b/chat_server/blueprints/personas.py index 6c0c8abd..8d11f210 100644 --- a/chat_server/blueprints/personas.py +++ b/chat_server/blueprints/personas.py @@ -25,9 +25,12 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import json + from fastapi import APIRouter from starlette.responses import JSONResponse +from chat_server.server_utils.api_dependencies import CurrentUserModel from chat_server.server_utils.enums import RequestModelType, UserRoles from chat_server.server_utils.http_exceptions import ( ItemNotFoundException, @@ -47,7 +50,7 @@ PersonaData, ) from chat_server.server_utils.api_dependencies.validators import permitted_access - +from chat_server.sio.server import sio from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI @@ -61,7 +64,7 @@ async def list_personas( current_user: CurrentUserData, request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel), -): +) -> JSONResponse: """Lists personas matching query params""" filters = [] if request_model.llms: @@ -112,6 +115,7 @@ async def add_persona( if existing_model: raise DuplicatedItemException MongoDocumentsAPI.PERSONAS.add_item(data=request_model.model_dump()) + await _notify_personas_changed() return KlatAPIResponse.OK @@ -131,6 +135,7 @@ async def set_persona( MongoDocumentsAPI.PERSONAS.update_item( filters=mongo_filter, data=request_model.model_dump() ) + await _notify_personas_changed() return KlatAPIResponse.OK @@ -140,6 +145,7 @@ async def delete_persona( ): """Deletes persona""" MongoDocumentsAPI.PERSONAS.delete_item(item_id=request_model.persona_id) + await _notify_personas_changed() return KlatAPIResponse.OK @@ -157,4 +163,13 @@ async def toggle_persona_state( ) if updated_data.matched_count == 0: raise ItemNotFoundException + await _notify_personas_changed() return KlatAPIResponse.OK + + +async def _notify_personas_changed(): + response = await list_personas(CurrentUserModel(_id="", nickname="", + first_name="", last_name=""), + ListPersonasQueryModel(only_enabled=True)) + enabled_personas = json.loads(response.body.decode()) + sio.emit("configured_personas_changed", enabled_personas) diff --git a/services/klatchat_observer/controller.py b/services/klatchat_observer/controller.py index ef959c1c..873eabcb 100644 --- a/services/klatchat_observer/controller.py +++ b/services/klatchat_observer/controller.py @@ -28,6 +28,8 @@ import json import re import time +from typing import Optional + import cachetools.func from threading import Event, Timer @@ -304,7 +306,7 @@ def get_neon_service(self, wait_timeout: int = 10) -> None: LOG.info("Joining sync consumer") sync_consumer.join() if not self.neon_service_event.is_set(): - LOG.warning(f"Failed to get neon_service in {wait_timeout} seconds") + LOG.warning(f"Failed to get neon response in {wait_timeout} seconds") self.__neon_service_id = "" def register_sio_handlers(self): @@ -327,6 +329,8 @@ def register_sio_handlers(self): handler=self.request_revoke_submind_ban_from_conversation, ) self._sio.on("auth_expired", handler=self._handle_auth_expired) + self._sio.on("configured_personas_changed", + handler=self._handle_personas_changed) def connect_sio(self): """ @@ -396,6 +400,7 @@ def get_neon_request_structure(msg_data: dict): if requested_skill == "tts": utterance = msg_data.pop("utterance", "") or msg_data.pop("text", "") request_dict = { + "msg_type": "neon.get_tts", "data": { "utterance": utterance, "text": utterance, @@ -404,12 +409,14 @@ def get_neon_request_structure(msg_data: dict): } elif requested_skill == "stt": request_dict = { + "msg_type": "neon.get_stt", "data": { "audio_data": msg_data.pop("audio_data", msg_data["message_body"]), } } else: request_dict = { + "msg_type": "recognizer_loop:utterance", "data": { "utterances": [msg_data["message_body"]], }, @@ -419,13 +426,17 @@ def get_neon_request_structure(msg_data: dict): return request_dict def __handle_neon_recipient(self, recipient_data: dict, msg_data: dict): + """ + Handle a chat message intended for Neon. + """ msg_data.setdefault("message_body", msg_data.pop("messageText", "")) msg_data.setdefault("message_id", msg_data.pop("messageID", "")) recipient_data.setdefault("context", {}) pattern = re.compile("Neon", re.IGNORECASE) msg_data["message_body"] = ( - pattern.sub("", msg_data["message_body"], 1).strip("<>@,.:|- ").capitalize() + pattern.sub("", msg_data["message_body"], 1).strip("<>@,.:|- \n") ) + # This is really referencing an MQ endpoint (i.e. stt, tts), not a skill msg_data.setdefault( "requested_skill", recipient_data["context"].pop("service", "recognizer") ) @@ -774,6 +785,7 @@ def on_subminds_state(self, body: dict): @create_mq_callback() def on_get_configured_personas(self, body: dict): + # Handles request to get all defined personas response_data = self._fetch_persona_api(user_id=body.get("user_id")) response_data["items"] = [ item @@ -791,7 +803,7 @@ def on_get_configured_personas(self, body: dict): ) @cachetools.func.ttl_cache(ttl=15) - def _fetch_persona_api(self, user_id: str) -> dict: + def _fetch_persona_api(self, user_id: Optional[str]) -> dict: query_string = self._build_persona_api_query(user_id=user_id) url = f"{self.server_url}/personas/list?{query_string}" try: @@ -803,12 +815,25 @@ def _fetch_persona_api(self, user_id: str) -> dict: data = {"items": []} return data + def _handle_personas_changed(self, data: dict): + """ + SIO handler called when configured personas are modified. This emits an + MQ message to allow any connected listeners to maintain a set of known + personas. + """ + self.send_message( + request_data=data, + vhost=self.get_vhost("llm"), + queue="configured_personas_changed", + expiration=5000, + ) + def _refresh_default_persona_llms(self, data): for item in data["items"]: if default_llm := item.get("default_llm"): self.default_persona_llms[item["id"]] = item["id"] + "_" + default_llm - def _build_persona_api_query(self, user_id: str) -> str: + def _build_persona_api_query(self, user_id: Optional[str]) -> str: url_query_params = f"only_enabled=true" if user_id: url_query_params += f"&user_id={user_id}"