Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement event-driven persona updates #107

Open
wants to merge 11 commits into
base: alpha
Choose a base branch
from
Open
19 changes: 17 additions & 2 deletions chat_server/blueprints/personas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json

from fastapi import APIRouter
from starlette.responses import JSONResponse

from chat_server.server_utils.api_dependencies import CurrentUserModel
from chat_server.server_utils.enums import RequestModelType, UserRoles
from chat_server.server_utils.http_exceptions import (
ItemNotFoundException,
Expand All @@ -47,7 +50,7 @@
PersonaData,
)
from chat_server.server_utils.api_dependencies.validators import permitted_access

from chat_server.sio.server import sio
from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators
from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI

Expand All @@ -61,7 +64,7 @@
async def list_personas(
current_user: CurrentUserData,
request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel),
):
) -> JSONResponse:
"""Lists personas matching query params"""
filters = []
if request_model.llms:
Expand Down Expand Up @@ -112,6 +115,7 @@ async def add_persona(
if existing_model:
raise DuplicatedItemException
MongoDocumentsAPI.PERSONAS.add_item(data=request_model.model_dump())
await _notify_personas_changed()
return KlatAPIResponse.OK


Expand All @@ -131,6 +135,7 @@ async def set_persona(
MongoDocumentsAPI.PERSONAS.update_item(
filters=mongo_filter, data=request_model.model_dump()
)
await _notify_personas_changed()
return KlatAPIResponse.OK


Expand All @@ -140,6 +145,7 @@ async def delete_persona(
):
"""Deletes persona"""
MongoDocumentsAPI.PERSONAS.delete_item(item_id=request_model.persona_id)
await _notify_personas_changed()
return KlatAPIResponse.OK


Expand All @@ -157,4 +163,13 @@ async def toggle_persona_state(
)
if updated_data.matched_count == 0:
raise ItemNotFoundException
await _notify_personas_changed()
return KlatAPIResponse.OK


async def _notify_personas_changed():
NeonDaniel marked this conversation as resolved.
Show resolved Hide resolved
response = await list_personas(CurrentUserModel(_id="", nickname="",
first_name="", last_name=""),
ListPersonasQueryModel(only_enabled=True))
enabled_personas = json.loads(response.body.decode())
sio.emit("configured_personas_changed", enabled_personas)
33 changes: 29 additions & 4 deletions services/klatchat_observer/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import json
import re
import time
from typing import Optional

import cachetools.func

from threading import Event, Timer
Expand Down Expand Up @@ -304,7 +306,7 @@ def get_neon_service(self, wait_timeout: int = 10) -> None:
LOG.info("Joining sync consumer")
sync_consumer.join()
if not self.neon_service_event.is_set():
LOG.warning(f"Failed to get neon_service in {wait_timeout} seconds")
LOG.warning(f"Failed to get neon response in {wait_timeout} seconds")
self.__neon_service_id = ""

def register_sio_handlers(self):
Expand All @@ -327,6 +329,8 @@ def register_sio_handlers(self):
handler=self.request_revoke_submind_ban_from_conversation,
)
self._sio.on("auth_expired", handler=self._handle_auth_expired)
self._sio.on("configured_personas_changed",
handler=self._handle_personas_changed)

def connect_sio(self):
"""
Expand Down Expand Up @@ -396,6 +400,7 @@ def get_neon_request_structure(msg_data: dict):
if requested_skill == "tts":
utterance = msg_data.pop("utterance", "") or msg_data.pop("text", "")
request_dict = {
"msg_type": "neon.get_tts",
"data": {
"utterance": utterance,
"text": utterance,
Expand All @@ -404,12 +409,14 @@ def get_neon_request_structure(msg_data: dict):
}
elif requested_skill == "stt":
request_dict = {
"msg_type": "neon.get_stt",
"data": {
"audio_data": msg_data.pop("audio_data", msg_data["message_body"]),
}
}
else:
request_dict = {
"msg_type": "recognizer_loop:utterance",
"data": {
"utterances": [msg_data["message_body"]],
},
Expand All @@ -419,13 +426,17 @@ def get_neon_request_structure(msg_data: dict):
return request_dict

def __handle_neon_recipient(self, recipient_data: dict, msg_data: dict):
"""
Handle a chat message intended for Neon.
"""
msg_data.setdefault("message_body", msg_data.pop("messageText", ""))
msg_data.setdefault("message_id", msg_data.pop("messageID", ""))
recipient_data.setdefault("context", {})
pattern = re.compile("Neon", re.IGNORECASE)
msg_data["message_body"] = (
pattern.sub("", msg_data["message_body"], 1).strip("<>@,.:|- ").capitalize()
pattern.sub("", msg_data["message_body"], 1).strip("<>@,.:|- \n")
)
# This is really referencing an MQ endpoint (i.e. stt, tts), not a skill
NeonDaniel marked this conversation as resolved.
Show resolved Hide resolved
msg_data.setdefault(
"requested_skill", recipient_data["context"].pop("service", "recognizer")
)
Expand Down Expand Up @@ -774,6 +785,7 @@ def on_subminds_state(self, body: dict):

@create_mq_callback()
def on_get_configured_personas(self, body: dict):
# Handles request to get all defined personas
response_data = self._fetch_persona_api(user_id=body.get("user_id"))
response_data["items"] = [
item
Expand All @@ -791,7 +803,7 @@ def on_get_configured_personas(self, body: dict):
)

@cachetools.func.ttl_cache(ttl=15)
def _fetch_persona_api(self, user_id: str) -> dict:
def _fetch_persona_api(self, user_id: Optional[str]) -> dict:
query_string = self._build_persona_api_query(user_id=user_id)
url = f"{self.server_url}/personas/list?{query_string}"
try:
Expand All @@ -803,12 +815,25 @@ def _fetch_persona_api(self, user_id: str) -> dict:
data = {"items": []}
return data

def _handle_personas_changed(self, data: dict):
"""
SIO handler called when configured personas are modified. This emits an
MQ message to allow any connected listeners to maintain a set of known
personas.
"""
self.send_message(
request_data=data,
vhost=self.get_vhost("llm"),
queue="configured_personas_changed",
expiration=5000,
)

def _refresh_default_persona_llms(self, data):
for item in data["items"]:
if default_llm := item.get("default_llm"):
self.default_persona_llms[item["id"]] = item["id"] + "_" + default_llm

def _build_persona_api_query(self, user_id: str) -> str:
def _build_persona_api_query(self, user_id: Optional[str]) -> str:
url_query_params = f"only_enabled=true"
if user_id:
url_query_params += f"&user_id={user_id}"
Expand Down