Skip to content

Commit 251010a

Browse files
committed
Implement Redis integration for client-user mapping and enhance wake word processing
- Added asynchronous Redis support in ClientManager for tracking client-user relationships.
- Introduced `initialize_redis_for_client_manager` to set up Redis for cross-container mapping.
- Updated `create_client_state` to use asynchronous tracking for client-user relationships.
- Enhanced wake word processing in PluginRouter with normalization and command extraction.
- Refactored DeepgramStreamingConsumer to utilize async Redis lookups for user ID retrieval.
- Set TTL on Redis streams during client state cleanup for better resource management.
1 parent 32d541f commit 251010a

File tree

6 files changed

+162
-20
lines changed

6 files changed

+162
-20
lines changed

backends/advanced/src/advanced_omi_backend/app_factory.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ async def lifespan(app: FastAPI):
111111
from advanced_omi_backend.services.audio_stream import AudioStreamProducer
112112
app.state.audio_stream_producer = AudioStreamProducer(app.state.redis_audio_stream)
113113
application_logger.info("✅ Redis client for audio streaming producer initialized")
114+
115+
# Initialize ClientManager Redis for cross-container client→user mapping
116+
from advanced_omi_backend.client_manager import initialize_redis_for_client_manager
117+
initialize_redis_for_client_manager(config.redis_url)
118+
114119
except Exception as e:
115120
application_logger.error(f"Failed to initialize Redis client for audio streaming: {e}", exc_info=True)
116121
application_logger.warning("Audio streaming producer will not be available")

backends/advanced/src/advanced_omi_backend/client_manager.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import uuid
1111
from typing import TYPE_CHECKING, Dict, Optional
12+
import redis.asyncio as redis
1213

1314
if TYPE_CHECKING:
1415
from advanced_omi_backend.client import ClientState
@@ -21,6 +22,9 @@
2122
_client_to_user_mapping: Dict[str, str] = {} # Active clients only
2223
_all_client_user_mappings: Dict[str, str] = {} # All clients including disconnected
2324

25+
# Redis client for cross-container client→user mapping
26+
_redis_client: Optional[redis.Redis] = None
27+
2428

2529
class ClientManager:
2630
"""
@@ -372,9 +376,33 @@ def unregister_client_user_mapping(client_id: str):
372376
logger.warning(f"⚠️ Attempted to unregister non-existent client {client_id}")
373377

374378

379+
async def track_client_user_relationship_async(client_id: str, user_id: str, ttl: int = 86400):
    """
    Record which user owns a client, mirroring the mapping into Redis when available.

    The process-local mapping is always updated. When a Redis client has been
    configured, the same mapping is also written under ``client:owner:<client_id>``
    with an expiry so other containers can resolve the owner. A failed Redis
    write is logged and otherwise ignored (best effort).

    Args:
        client_id: The client ID
        user_id: The user ID that owns this client
        ttl: Lifetime of the Redis entry in seconds (default 24 hours)
    """
    # Local map first: it is the fallback whenever Redis is absent or failing.
    _all_client_user_mappings[client_id] = user_id

    if not _redis_client:
        logger.debug(f"Tracked client {client_id} relationship to user {user_id} (in-memory only)")
        return

    try:
        owner_key = f"client:owner:{client_id}"
        await _redis_client.setex(owner_key, ttl, user_id)
        logger.debug(f"✅ Tracked client {client_id} → user {user_id} in Redis (TTL: {ttl}s)")
    except Exception as e:
        # Deliberately non-fatal: the in-memory entry above is still in place.
        logger.warning(f"Failed to track client in Redis: {e}")
398+
399+
375400
def track_client_user_relationship(client_id: str, user_id: str):
376401
"""
377-
Track that a client belongs to a user (persists after disconnection for database queries).
402+
Track that a client belongs to a user (sync version for backward compatibility).
403+
404+
WARNING: This is synchronous and cannot use Redis. Use track_client_user_relationship_async()
405+
instead in async contexts for cross-container support.
378406
379407
Args:
380408
client_id: The client ID
@@ -444,9 +472,45 @@ def get_user_clients_active(user_id: str) -> list[str]:
444472
return user_clients
445473

446474

475+
def initialize_redis_for_client_manager(redis_url: str):
    """
    Initialize the module-level Redis client used for cross-container client→user mapping.

    The client is created with ``decode_responses=True`` and stored in the
    module global ``_redis_client``. The URL written to the log has any
    ``user:password@`` portion masked so credentials never reach log output.

    Args:
        redis_url: Redis connection URL (may contain credentials)
    """
    global _redis_client

    from urllib.parse import urlsplit, urlunsplit

    _redis_client = redis.from_url(redis_url, decode_responses=True)

    # Redis URLs commonly embed a password (redis://:secret@host:6379/0);
    # mask everything before the "@" in the network location before logging.
    parts = urlsplit(redis_url)
    if "@" in parts.netloc:
        host_info = parts.netloc.rpartition("@")[2]
        safe_url = urlunsplit(parts._replace(netloc=f"***@{host_info}"))
    else:
        safe_url = redis_url

    logger.info(f"✅ ClientManager Redis initialized: {safe_url}")
485+
486+
487+
async def get_client_owner_async(client_id: str) -> Optional[str]:
    """
    Get the user ID that owns a specific client (async Redis lookup with local fallback).

    Redis is consulted first when a client has been configured. If the key is
    missing there (for example it expired, or the earlier write failed) or the
    lookup raises, the process-local mapping is used instead, so a mapping that
    is known locally is never reported as missing.

    Args:
        client_id: The client ID to look up

    Returns:
        User ID if found in Redis or in the in-memory mapping, None otherwise
    """
    if _redis_client:
        try:
            user_id = await _redis_client.get(f"client:owner:{client_id}")
        except Exception as e:
            # Best effort: a Redis failure must not break the lookup.
            logger.warning(f"Redis lookup failed for client {client_id}: {e}")
        else:
            if user_id is not None:
                return user_id
            # Redis miss: fall through to the local mapping below.

    # Fallback to in-memory mapping
    return _all_client_user_mappings.get(client_id)
506+
507+
447508
def get_client_owner(client_id: str) -> Optional[str]:
448509
"""
449-
Get the user ID that owns a specific client.
510+
Get the user ID that owns a specific client (sync version for backward compatibility).
511+
512+
WARNING: This is synchronous and cannot use Redis. Use get_client_owner_async() instead
513+
in async contexts for cross-container support.
450514
451515
Args:
452516
client_id: The client ID to look up

backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,9 @@ async def create_client_state(client_id: str, user, device_name: Optional[str] =
189189
client_id, CHUNK_DIR, user.user_id, user.email
190190
)
191191

192-
# Also track in persistent mapping (for database queries)
193-
from advanced_omi_backend.client_manager import track_client_user_relationship
194-
track_client_user_relationship(client_id, user.user_id)
192+
# Also track in persistent mapping (for database queries + cross-container Redis)
193+
from advanced_omi_backend.client_manager import track_client_user_relationship_async
194+
await track_client_user_relationship_async(client_id, user.user_id)
195195

196196
# Register client in user model (persistent)
197197
from advanced_omi_backend.users import register_client_to_user
@@ -265,12 +265,12 @@ async def cleanup_client_state(client_id: str):
265265
if sessions_closed > 0:
266266
logger.info(f"✅ Closed {sessions_closed} active session(s) for client {client_id}")
267267

268-
# Delete Redis Streams for this client
268+
# Set TTL on Redis Streams for this client (allows consumer groups to finish processing)
269269
stream_pattern = f"audio:stream:{client_id}"
270270
stream_key = await async_redis.exists(stream_pattern)
271271
if stream_key:
272-
await async_redis.delete(stream_pattern)
273-
logger.info(f"🧹 Deleted Redis stream: {stream_pattern}")
272+
await async_redis.expire(stream_pattern, 60) # 60 second TTL for consumer group fan-out
273+
logger.info(f"⏰ Set 60s TTL on Redis stream: {stream_pattern}")
274274
else:
275275
logger.debug(f"No Redis stream found for client {client_id}")
276276

backends/advanced/src/advanced_omi_backend/plugins/router.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,80 @@
55
"""
66

77
import logging
8+
import re
9+
import string
810
from typing import Dict, List, Optional
911

1012
from .base import BasePlugin, PluginContext, PluginResult
1113

1214
logger = logging.getLogger(__name__)
1315

1416

17+
# Deletes every ASCII punctuation character; built once instead of on each call.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

# Whitespace and separator punctuation that may sit between the spoken wake
# word and the command itself (comma, period, !, ?, ;, :, hyphen, en/em dash, ellipsis).
_LEADING_SEPARATORS_RE = re.compile(r'^[\s,.!?;:\-\u2013\u2014\u2026]+')


def normalize_text_for_wake_word(text: str) -> str:
    """
    Normalize text for wake word matching.

    Steps, in order: lowercase, delete ASCII punctuation, collapse runs of
    whitespace into a single space, strip leading/trailing whitespace.

    Example:
        "Hey, Vivi!" -> "hey vivi"
        "HEY  VIVI" -> "hey vivi"

    Args:
        text: Raw text (transcript or configured wake word)

    Returns:
        The normalized text
    """
    text = text.lower().translate(_PUNCTUATION_TABLE)
    return re.sub(r'\s+', ' ', text).strip()


def extract_command_after_wake_word(transcript: str, wake_word: str) -> str:
    """
    Extract the command that follows the wake word in the original transcript.

    The wake word is normalized and split into words; a regex is then built
    that tolerates any non-word characters (punctuation, whitespace, underscore)
    before the first word and between words. Separator punctuation directly
    after the wake word is removed from the returned command.

    Example:
        transcript: "Hey, Vivi, turn off lights"
        wake_word: "hey vivi"
        -> extracts: "turn off lights"

    Args:
        transcript: Original transcript text with punctuation
        wake_word: Configured wake word (will be normalized)

    Returns:
        Command text after wake word, or the stripped full transcript if the
        wake word is empty or its boundary cannot be found
    """
    wake_word_parts = normalize_text_for_wake_word(wake_word).split()

    if not wake_word_parts:
        return transcript.strip()

    # Example: "hey" + "vivi" -> r"[\W_]*hey[\W_]*vivi".
    # [\W_] already includes whitespace, so one class replaces the previous
    # \s*[\W_]*\s* form, which matched the same text with redundant backtracking.
    separator = r'[\W_]*'
    pattern = separator + separator.join(re.escape(part) for part in wake_word_parts)

    # re.match anchors at the start of the transcript; matching is case-insensitive.
    match = re.match(pattern, transcript, re.IGNORECASE)

    if match is None:
        # Fallback: couldn't find wake word boundary, return full transcript
        logger.warning(f"Could not find wake word boundary for '{wake_word}' in '{transcript}', using full transcript")
        return transcript.strip()

    # Drop the separators that follow the wake word (", turn off" -> "turn off").
    remainder = transcript[match.end():]
    return _LEADING_SEPARATORS_RE.sub('', remainder).strip()
80+
81+
1582
class PluginRouter:
1683
"""Routes pipeline events to appropriate plugins based on access level and triggers"""
1784

@@ -113,9 +180,9 @@ async def _should_trigger(self, plugin: BasePlugin, data: Dict) -> bool:
113180
return True
114181

115182
elif trigger_type == 'wake_word':
116-
# Check if transcript starts with wake word(s)
183+
# Normalize transcript for matching (handles punctuation and spacing)
117184
transcript = data.get('transcript', '')
118-
transcript_lower = transcript.lower().strip()
185+
normalized_transcript = normalize_text_for_wake_word(transcript)
119186

120187
# Support both singular 'wake_word' and plural 'wake_words' (list)
121188
wake_words = plugin.trigger.get('wake_words', [])
@@ -125,14 +192,15 @@ async def _should_trigger(self, plugin: BasePlugin, data: Dict) -> bool:
125192
if wake_word:
126193
wake_words = [wake_word]
127194

128-
# Check if transcript starts with any wake word
195+
# Check if transcript starts with any wake word (after normalization)
129196
for wake_word in wake_words:
130-
wake_word_lower = wake_word.lower()
131-
if wake_word_lower and transcript_lower.startswith(wake_word_lower):
132-
# Extract command (remove wake word)
133-
command = transcript[len(wake_word):].strip()
197+
normalized_wake_word = normalize_text_for_wake_word(wake_word)
198+
if normalized_wake_word and normalized_transcript.startswith(normalized_wake_word):
199+
# Smart extraction: find where wake word actually ends in original text
200+
command = extract_command_after_wake_word(transcript, wake_word)
134201
data['command'] = command
135202
data['original_transcript'] = transcript
203+
logger.debug(f"Wake word '{wake_word}' detected. Original: '{transcript}', Command: '{command}'")
136204
return True
137205

138206
return False

backends/advanced/src/advanced_omi_backend/services/transcription/deepgram_stream_consumer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from advanced_omi_backend.plugins.router import PluginRouter
2121
from advanced_omi_backend.services.transcription import get_transcription_provider
22-
from advanced_omi_backend.client_manager import get_client_owner
22+
from advanced_omi_backend.client_manager import get_client_owner_async
2323

2424
logger = logging.getLogger(__name__)
2525

@@ -257,20 +257,20 @@ async def store_final_result(self, session_id: str, result: Dict, chunk_id: str
257257

258258
async def _get_user_id_from_client_id(self, client_id: str) -> Optional[str]:
259259
"""
260-
Look up user_id from client_id using ClientManager.
260+
Look up user_id from client_id using ClientManager (async Redis lookup).
261261
262262
Args:
263263
client_id: Client ID to search for
264264
265265
Returns:
266266
user_id if found, None otherwise
267267
"""
268-
user_id = get_client_owner(client_id)
268+
user_id = await get_client_owner_async(client_id)
269269

270270
if user_id:
271-
logger.debug(f"Found user_id {user_id} for client_id {client_id}")
271+
logger.debug(f"Found user_id {user_id} for client_id {client_id} via Redis")
272272
else:
273-
logger.warning(f"No user_id found for client_id {client_id}")
273+
logger.warning(f"No user_id found for client_id {client_id} in Redis")
274274

275275
return user_id
276276

backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_streaming_worker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from advanced_omi_backend.services.plugin_service import init_plugin_router
2020
from advanced_omi_backend.services.transcription.deepgram_stream_consumer import DeepgramStreamingConsumer
21+
from advanced_omi_backend.client_manager import initialize_redis_for_client_manager
2122

2223
logging.basicConfig(
2324
level=logging.INFO,
@@ -48,6 +49,10 @@ async def main():
4849
decode_responses=False
4950
)
5051
logger.info(f"✅ Connected to Redis: {redis_url}")
52+
53+
# Initialize ClientManager Redis for cross-container client→user mapping
54+
initialize_redis_for_client_manager(redis_url)
55+
5156
except Exception as e:
5257
logger.error(f"Failed to connect to Redis: {e}", exc_info=True)
5358
sys.exit(1)

0 commit comments

Comments
 (0)