Skip to content

Commit

Permalink
Defines an SIO and MQ configured_personas_changed event that is emi…
Browse files Browse the repository at this point in the history
…tted any time the server makes a change to the database

Leaves existing logic untouched to allow clients to request an update from the server at any time
Relates to NeonGeckoCom/neon-llm-core#8
  • Loading branch information
NeonDaniel committed Nov 26, 2024
1 parent be65e64 commit e03408f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
19 changes: 17 additions & 2 deletions chat_server/blueprints/personas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json

from fastapi import APIRouter
from starlette.responses import JSONResponse

from chat_server.server_utils.api_dependencies import CurrentUserModel
from chat_server.server_utils.enums import RequestModelType, UserRoles
from chat_server.server_utils.http_exceptions import (
ItemNotFoundException,
Expand All @@ -47,7 +50,7 @@
PersonaData,
)
from chat_server.server_utils.api_dependencies.validators import permitted_access

from chat_server.sio.server import sio
from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators
from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI

Expand All @@ -61,7 +64,7 @@
async def list_personas(
current_user: CurrentUserData,
request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel),
):
) -> JSONResponse:
"""Lists personas matching query params"""
filters = []
if request_model.llms:
Expand Down Expand Up @@ -112,6 +115,7 @@ async def add_persona(
if existing_model:
raise DuplicatedItemException
MongoDocumentsAPI.PERSONAS.add_item(data=request_model.model_dump())
await _notify_personas_changed()
return KlatAPIResponse.OK


Expand All @@ -131,6 +135,7 @@ async def set_persona(
MongoDocumentsAPI.PERSONAS.update_item(
filters=mongo_filter, data=request_model.model_dump()
)
await _notify_personas_changed()
return KlatAPIResponse.OK


Expand All @@ -140,6 +145,7 @@ async def delete_persona(
):
"""Deletes persona"""
MongoDocumentsAPI.PERSONAS.delete_item(item_id=request_model.persona_id)
await _notify_personas_changed()
return KlatAPIResponse.OK


Expand All @@ -157,4 +163,13 @@ async def toggle_persona_state(
)
if updated_data.matched_count == 0:
raise ItemNotFoundException
await _notify_personas_changed()
return KlatAPIResponse.OK


async def _notify_personas_changed():
response = await list_personas(CurrentUserModel(_id="", nickname="",
first_name="", last_name=""),
ListPersonasQueryModel(only_enabled=True))
enabled_personas = json.loads(response.body.decode())
sio.emit("configured_personas_changed", enabled_personas)
33 changes: 29 additions & 4 deletions services/klatchat_observer/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import json
import re
import time
from typing import Optional

import cachetools.func

from threading import Event, Timer
Expand Down Expand Up @@ -304,7 +306,7 @@ def get_neon_service(self, wait_timeout: int = 10) -> None:
LOG.info("Joining sync consumer")
sync_consumer.join()
if not self.neon_service_event.is_set():
LOG.warning(f"Failed to get neon_service in {wait_timeout} seconds")
LOG.warning(f"Failed to get neon response in {wait_timeout} seconds")
self.__neon_service_id = ""

def register_sio_handlers(self):
Expand All @@ -327,6 +329,8 @@ def register_sio_handlers(self):
handler=self.request_revoke_submind_ban_from_conversation,
)
self._sio.on("auth_expired", handler=self._handle_auth_expired)
self._sio.on("configured_personas_changed",
handler=self._handle_personas_changed)

def connect_sio(self):
"""
Expand Down Expand Up @@ -396,6 +400,7 @@ def get_neon_request_structure(msg_data: dict):
if requested_skill == "tts":
utterance = msg_data.pop("utterance", "") or msg_data.pop("text", "")
request_dict = {
"msg_type": "neon.get_tts",
"data": {
"utterance": utterance,
"text": utterance,
Expand All @@ -404,12 +409,14 @@ def get_neon_request_structure(msg_data: dict):
}
elif requested_skill == "stt":
request_dict = {
"msg_type": "neon.get_stt",
"data": {
"audio_data": msg_data.pop("audio_data", msg_data["message_body"]),
}
}
else:
request_dict = {
"msg_type": "recognizer_loop:utterance",
"data": {
"utterances": [msg_data["message_body"]],
},
Expand All @@ -419,13 +426,17 @@ def get_neon_request_structure(msg_data: dict):
return request_dict

def __handle_neon_recipient(self, recipient_data: dict, msg_data: dict):
"""
Handle a chat message intended for Neon.
"""
msg_data.setdefault("message_body", msg_data.pop("messageText", ""))
msg_data.setdefault("message_id", msg_data.pop("messageID", ""))
recipient_data.setdefault("context", {})
pattern = re.compile("Neon", re.IGNORECASE)
msg_data["message_body"] = (
pattern.sub("", msg_data["message_body"], 1).strip("<>@,.:|- ").capitalize()
pattern.sub("", msg_data["message_body"], 1).strip("<>@,.:|- \n")
)
# This is really referencing an MQ endpoint (i.e. stt, tts), not a skill
msg_data.setdefault(
"requested_skill", recipient_data["context"].pop("service", "recognizer")
)
Expand Down Expand Up @@ -774,6 +785,7 @@ def on_subminds_state(self, body: dict):

@create_mq_callback()
def on_get_configured_personas(self, body: dict):
# Handles request to get all defined personas
response_data = self._fetch_persona_api(user_id=body.get("user_id"))
response_data["items"] = [
item
Expand All @@ -791,7 +803,7 @@ def on_get_configured_personas(self, body: dict):
)

@cachetools.func.ttl_cache(ttl=15)
def _fetch_persona_api(self, user_id: str) -> dict:
def _fetch_persona_api(self, user_id: Optional[str]) -> dict:
query_string = self._build_persona_api_query(user_id=user_id)
url = f"{self.server_url}/personas/list?{query_string}"
try:
Expand All @@ -803,12 +815,25 @@ def _fetch_persona_api(self, user_id: str) -> dict:
data = {"items": []}
return data

def _handle_personas_changed(self, data: dict):
"""
SIO handler called when configured personas are modified. This emits an
MQ message to allow any connected listeners to maintain a set of known
personas.
"""
self.send_message(
request_data=data,
vhost=self.get_vhost("llm"),
queue="configured_personas_changed",
expiration=5000,
)

def _refresh_default_persona_llms(self, data):
for item in data["items"]:
if default_llm := item.get("default_llm"):
self.default_persona_llms[item["id"]] = item["id"] + "_" + default_llm

def _build_persona_api_query(self, user_id: str) -> str:
def _build_persona_api_query(self, user_id: Optional[str]) -> str:
url_query_params = f"only_enabled=true"
if user_id:
url_query_params += f"&user_id={user_id}"
Expand Down

0 comments on commit e03408f

Please sign in to comment.