Skip to content

Commit e56b635

Browse files
committed
Defines an SIO and MQ configured_personas_changed event that is emitted any time the server makes a change to the database
Leaves existing logic untouched to allow clients to request an update from the server at any time Relates to NeonGeckoCom/neon-llm-core#8
1 parent ab5bd5c commit e56b635

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

chat_server/blueprints/personas.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@
2525
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
2626
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
2727
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
import json
29+
2830
from fastapi import APIRouter
2931
from starlette.responses import JSONResponse
3032

33+
from chat_server.server_utils.api_dependencies import CurrentUserModel
3134
from chat_server.server_utils.enums import RequestModelType, UserRoles
3235
from chat_server.server_utils.http_exceptions import (
3336
ItemNotFoundException,
@@ -47,7 +50,7 @@
4750
PersonaData,
4851
)
4952
from chat_server.server_utils.api_dependencies.validators import permitted_access
50-
53+
from chat_server.sio.server import sio
5154
from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators
5255
from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI
5356

@@ -61,7 +64,7 @@
6164
async def list_personas(
6265
current_user: CurrentUserData,
6366
request_model: ListPersonasQueryModel = permitted_access(ListPersonasQueryModel),
64-
):
67+
) -> JSONResponse:
6568
"""Lists personas matching query params"""
6669
filters = []
6770
if request_model.llms:
@@ -112,6 +115,7 @@ async def add_persona(
112115
if existing_model:
113116
raise DuplicatedItemException
114117
MongoDocumentsAPI.PERSONAS.add_item(data=request_model.model_dump())
118+
await _notify_personas_changed()
115119
return KlatAPIResponse.OK
116120

117121

@@ -131,6 +135,7 @@ async def set_persona(
131135
MongoDocumentsAPI.PERSONAS.update_item(
132136
filters=mongo_filter, data=request_model.model_dump()
133137
)
138+
await _notify_personas_changed()
134139
return KlatAPIResponse.OK
135140

136141

@@ -140,6 +145,7 @@ async def delete_persona(
140145
):
141146
"""Deletes persona"""
142147
MongoDocumentsAPI.PERSONAS.delete_item(item_id=request_model.persona_id)
148+
await _notify_personas_changed()
143149
return KlatAPIResponse.OK
144150

145151

@@ -157,4 +163,13 @@ async def toggle_persona_state(
157163
)
158164
if updated_data.matched_count == 0:
159165
raise ItemNotFoundException
166+
await _notify_personas_changed()
160167
return KlatAPIResponse.OK
168+
169+
170+
async def _notify_personas_changed():
171+
response = await list_personas(CurrentUserModel(_id="", nickname="",
172+
first_name="", last_name=""),
173+
ListPersonasQueryModel(only_enabled=True))
174+
enabled_personas = json.loads(response.body.decode())
175+
sio.emit("configured_personas_changed", enabled_personas)

services/klatchat_observer/controller.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
import json
2929
import re
3030
import time
31+
from typing import Optional
32+
3133
import cachetools.func
3234

3335
from threading import Event, Timer
@@ -303,7 +305,7 @@ def get_neon_service(self, wait_timeout: int = 10) -> None:
303305
LOG.info("Joining sync consumer")
304306
sync_consumer.join()
305307
if not self.neon_service_event.is_set():
306-
LOG.warning(f"Failed to get neon_service in {wait_timeout} seconds")
308+
LOG.warning(f"Failed to get neon response in {wait_timeout} seconds")
307309
self.__neon_service_id = ""
308310

309311
def register_sio_handlers(self):
@@ -326,6 +328,8 @@ def register_sio_handlers(self):
326328
handler=self.request_revoke_submind_ban_from_conversation,
327329
)
328330
self._sio.on("auth_expired", handler=self._handle_auth_expired)
331+
self._sio.on("configured_personas_changed",
332+
handler=self._handle_personas_changed)
329333

330334
def connect_sio(self):
331335
"""
@@ -396,6 +400,7 @@ def get_neon_request_structure(msg_data: dict):
396400
if requested_skill == "tts":
397401
utterance = msg_data.pop("utterance", "") or msg_data.pop("text", "")
398402
request_dict = {
403+
"msg_type": "neon.get_tts",
399404
"data": {
400405
"utterance": utterance,
401406
"text": utterance,
@@ -404,12 +409,14 @@ def get_neon_request_structure(msg_data: dict):
404409
}
405410
elif requested_skill == "stt":
406411
request_dict = {
412+
"msg_type": "neon.get_stt",
407413
"data": {
408414
"audio_data": msg_data.pop("audio_data", msg_data["message_body"]),
409415
}
410416
}
411417
else:
412418
request_dict = {
419+
"msg_type": "recognizer_loop:utterance",
413420
"data": {
414421
"utterances": [msg_data["message_body"]],
415422
},
@@ -419,13 +426,17 @@ def get_neon_request_structure(msg_data: dict):
419426
return request_dict
420427

421428
def __handle_neon_recipient(self, recipient_data: dict, msg_data: dict):
429+
"""
430+
Handle a chat message intended for Neon.
431+
"""
422432
msg_data.setdefault("message_body", msg_data.pop("messageText", ""))
423433
msg_data.setdefault("message_id", msg_data.pop("messageID", ""))
424434
recipient_data.setdefault("context", {})
425435
pattern = re.compile("Neon", re.IGNORECASE)
426436
msg_data["message_body"] = (
427-
pattern.sub("", msg_data["message_body"], 1).strip("<>@,.:|- ").capitalize()
437+
pattern.sub("", msg_data["message_body"], 1).strip("<>@,.:|- \n")
428438
)
439+
# This is really referencing an MQ endpoint (i.e. stt, tts), not a skill
429440
msg_data.setdefault(
430441
"requested_skill", recipient_data["context"].pop("service", "recognizer")
431442
)
@@ -774,6 +785,7 @@ def on_subminds_state(self, body: dict):
774785

775786
@create_mq_callback()
776787
def on_get_configured_personas(self, body: dict):
788+
# Handles request to get all defined personas
777789
response_data = self._fetch_persona_api(user_id=body.get("user_id"))
778790
response_data["items"] = [
779791
item
@@ -791,7 +803,7 @@ def on_get_configured_personas(self, body: dict):
791803
)
792804

793805
@cachetools.func.ttl_cache(ttl=15)
794-
def _fetch_persona_api(self, user_id: str) -> dict:
806+
def _fetch_persona_api(self, user_id: Optional[str]) -> dict:
795807
query_string = self._build_persona_api_query(user_id=user_id)
796808
url = f"{self.server_url}/personas/list?{query_string}"
797809
try:
@@ -803,12 +815,25 @@ def _fetch_persona_api(self, user_id: str) -> dict:
803815
data = {"items": []}
804816
return data
805817

818+
def _handle_personas_changed(self, data: dict):
819+
"""
820+
SIO handler called when configured personas are modified. This emits an
821+
MQ message to allow any connected listeners to maintain a set of known
822+
personas.
823+
"""
824+
self.send_message(
825+
request_data=data,
826+
vhost=self.get_vhost("llm"),
827+
queue="configured_personas_changed",
828+
expiration=5000,
829+
)
830+
806831
def _refresh_default_persona_llms(self, data):
807832
for item in data["items"]:
808833
if default_llm := item.get("default_llm"):
809834
self.default_persona_llms[item["id"]] = item["id"] + "_" + default_llm
810835

811-
def _build_persona_api_query(self, user_id: str) -> str:
836+
def _build_persona_api_query(self, user_id: Optional[str]) -> str:
812837
url_query_params = f"only_enabled=true"
813838
if user_id:
814839
url_query_params += f"&user_id={user_id}"

0 commit comments

Comments
 (0)