Skip to content

Commit

Permalink
Add docstrings and annotations
Browse files Browse the repository at this point in the history
Refactor user configuration handling to support multiple users and updates from Neon
  • Loading branch information
NeonDaniel committed Apr 24, 2024
1 parent 7847c30 commit a391dc5
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 16 deletions.
7 changes: 6 additions & 1 deletion neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,12 @@ def check_refresh_request(self, access_token: str, refresh_token: str,
new_auth = self._create_tokens(encode_data)
return new_auth

def get_client_id(self, token: str):
def get_client_id(self, token: str) -> str:
"""
Extract the client_id from a JWT token
@param token: JWT token to parse
@return: client_id associated with token
"""
auth = jwt.decode(token, self._access_secret, self._jwt_algo)
return auth['client_id']

Expand Down
8 changes: 7 additions & 1 deletion neon_hana/auth/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,15 @@

@dataclass
class ClientPermissions:
"""
Data class representing permissions of a particular client connection.
"""
assist: bool = True
backend: bool = True
node: bool = False

def as_dict(self):
def as_dict(self) -> dict:
"""
Get a dict representation of this instance.
"""
return asdict(self)
96 changes: 82 additions & 14 deletions neon_hana/mq_websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,86 +45,154 @@ def __init__(self, config: dict):
self._client = "neon_node_websocket"

def new_connection(self, ws: WebSocket, session_id: str):
"""
Record a new client connection to associate the WebSocket with the
session_id for response routing.
@param ws: Client WebSocket object
@param session_id: Session ID of the client
"""
self._sessions[session_id] = {"session": {"session_id": session_id},
"socket": ws}
"socket": ws,
"user": self.user_config}

def get_session(self, session_id: str):
def get_session(self, session_id: str) -> dict:
"""
Get the latest session context for the given session_id.
@param session_id: Session ID to get context for
@return: dict context for the given session_id (may be empty)
"""
with self._session_lock:
sess = dict(self._sessions.get(session_id, {}).get("session"))
sess = dict(self._sessions.get(session_id, {}).get("session", {}))
return sess

@property
def user_config(self) -> dict:
# TODO: Handle per-session config
return super().user_config
def get_user_config(self, session_id: str) -> dict:
"""
Get a dict user configuration for the given session_id
@param session_id: Session to get user configuration for
@return: dict user configuration
"""
with self._session_lock:
config = dict(self._sessions.get(session_id, {}).get("user") or
self.user_config)
return config

def _get_message_context(self, message):
def _get_message_context(self, message: Message, session_id: str) -> dict:
"""
Build message context for a Node input message.
@param message: Input message to include context from
@param session_id: Session ID associated with the message
@return: dict context for this input
"""
user_config = self.get_user_config(session_id)
default_context = {"client_name": self.client_name,
"client": self._client,
"ident": str(time()),
"username": self.user_config['user']['username'],
"user_profiles": [self.user_config],
"username": user_config['user']['username'],
"user_profiles": [user_config],
"neon_should_respond": True,
"timing": dict(),
"mq": {"routing_key": self.uid,
"message_id": self.connection.
create_unique_id()}}
return {**message.context, **default_context}

def _update_session_data(self, message):
def _update_session_data(self, message: Message):
"""
Update the local session data from the latest response message's context
Update the local session data and user profile from the latest response
message's context.
@param message: Response message containing updated context
"""
session_data = message.context.get('session')
if session_data:
user_config = message.context.get('user_profiles', [None])[0]
session_id = session_data.get('session_id')
with self._session_lock:
self._sessions[session_id]['session'] = session_data
if user_config:
self._sessions[session_id]['user'] = user_config

def handle_client_input(self, data: dict, session_id: str):
"""
Handle some client input
Handle some client input data.
@param data: Decoded input from client WebSocket
@param session_id: Session ID associated with the client connection
"""
# Handle `Message.serialize` data sent over WS in addition to proper
# dict representations
data['msg_type'] = data.pop("type", data.get("msg_type"))
message = Message(**data)
message.context = self._get_message_context(message)
message.context = self._get_message_context(message, session_id)
message.context["session"] = self.get_session(session_id)
# Send raw message, skipping any validation by iris
self._send_message(message)

def handle_klat_response(self, message: Message):
"""
Handle a Neon text+audio response to a user input.
@param message: `klat.response` message from Neon
"""
self._update_session_data(message)
run(self.send_to_client(message))
LOG.debug(message.context.get("timing"))

def handle_complete_intent_failure(self, message: Message):
"""
Handle a Neon error response to a user input.
@param message: `complete.intent.failure` message from Neon
"""
self._update_session_data(message)
run(self.send_to_client(message))

def handle_api_response(self, message: Message):
"""
Handle a Neon API response to an input.
@param message: `<msg_type>.response` message from Neon
"""
if message.msg_type == "neon.audio_input.response":
LOG.info(message.data.get("transcripts"))
LOG.debug(message.context.get("timing"))

def handle_error_response(self, message: Message):
"""
Handle an MQ error response to a user input.
@param message: `klat.error` response message
"""
run(self.send_to_client(message))

def clear_caches(self, message: Message):
"""
Handle a Neon request to clear cached data.
@param message: `neon.clear_data` message from Neon
"""
run(self.send_to_client(message))

def clear_media(self, message: Message):
"""
Handle a Neon request to clear media data.
@param message: `neon.clear_data` message from Neon
"""
run(self.send_to_client(message))

def handle_alert(self, message: Message):
"""
Handle an expired alert from Neon.
@param message: `neon.alert_expired` message from Neon
"""
run(self.send_to_client(message))

async def send_to_client(self, message: Message):
"""
Asynchronously forward a message from Neon/MQ to a WebSocket client.
@param message: Message to forward to a WebSocket client
"""
# TODO: Drop context?
session_id = message.context["session"]["session_id"]
await self._sessions[session_id]["socket"].send_text(message.serialize())

def shutdown(self, *_, **__):
"""
Shutdown the event loop and prepare this object for destruction.
"""
loop = get_event_loop()
loop.call_soon_threadsafe(loop.stop)
LOG.info("Stopped Event Loop")
Expand Down

0 comments on commit a391dc5

Please sign in to comment.