From b8fc804ab66fd98205481fc6a17a42efb9da3a3f Mon Sep 17 00:00:00 2001
From: Ankush Malaker <43288948+AnkushMalaker@users.noreply.github.com>
Date: Sun, 22 Feb 2026 06:23:48 +0000
Subject: [PATCH 1/5] Refactor pre-commit configuration and enhance development
scripts
- Removed local hooks for Robot Framework tests and cleanup from `.pre-commit-config.yaml`, streamlining the pre-commit setup.
- Updated `Makefile` to install pre-commit using the `uv` tool, improving dependency management.
- Enhanced `restart.sh`, `start.sh`, `status.sh`, and `stop.sh` scripts to source a new `check_uv.sh` script for better environment validation.
- Added new environment variables for Galileo observability in `.env.template`, improving observability setup.
- Introduced OpenTelemetry initialization in `app_factory.py` for enhanced observability during application runtime.
---
.pre-commit-config.yaml | 29 +-
Makefile | 3 +-
backends/advanced/.env.template | 6 +
.../src/advanced_omi_backend/app_factory.py | 128 ++-
.../clients/audio_stream_client.py | 66 +-
.../src/advanced_omi_backend/config_loader.py | 40 +-
.../controllers/websocket_controller.py | 628 ++++++++------
.../observability/__init__.py | 0
.../observability/otel_setup.py | 105 +++
.../advanced_omi_backend/plugins/router.py | 203 +++--
.../advanced_omi_backend/plugins/services.py | 43 +-
.../advanced_omi_backend/prompt_defaults.py | 2 +-
.../advanced_omi_backend/prompt_registry.py | 25 +-
.../routers/modules/queue_routes.py | 643 +++++++++-----
.../services/audio_stream/producer.py | 140 ++--
.../services/transcription/__init__.py | 236 +++++-
.../workers/conversation_jobs.py | 190 +++--
.../workers/memory_jobs.py | 84 +-
.../workers/rq_worker_entry.py | 24 +-
.../workers/transcription_jobs.py | 189 ++++-
config/config.yml.template | 8 +
config/defaults.yml | 4 +
config/plugins.yml.template | 7 +-
restart.sh | 1 +
start.sh | 2 +
status.sh | 1 +
stop.sh | 2 +
tests/config/plugins.test.yml | 11 +
.../websocket_streaming_tests.robot | 41 +-
.../websocket_transcription_e2e_test.robot | 23 +-
tests/libs/audio_stream_library.py | 18 +-
tests/resources/websocket_keywords.robot | 6 +
wizard.py | 782 ++++++++++++------
wizard.sh | 2 +
34 files changed, 2570 insertions(+), 1122 deletions(-)
create mode 100644 backends/advanced/src/advanced_omi_backend/observability/__init__.py
create mode 100644 backends/advanced/src/advanced_omi_backend/observability/otel_setup.py
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6ebb6573..adf40dcb 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,42 +1,21 @@
repos:
- # Local hooks (project-specific checks)
- - repo: local
- hooks:
- # Run Robot Framework endpoint tests before push
- - id: robot-framework-tests
- name: Robot Framework Tests (Endpoints)
- entry: bash -c 'cd tests && make endpoints OUTPUTDIR=.pre-commit-results'
- language: system
- pass_filenames: false
- stages: [push]
- verbose: true
-
- # Clean up test results after hook runs
- - id: cleanup-test-results
- name: Cleanup Test Results
- entry: bash -c 'cd tests && rm -rf .pre-commit-results'
- language: system
- pass_filenames: false
- stages: [push]
- always_run: true
-
# Code formatting
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
- files: ^backends/advanced-backend/src/.*\.py$
+ exclude: \.venv/
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
- files: ^backends/advanced-backend/src/.*\.py$
+ exclude: \.venv/
# File hygiene
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- files: ^backends/advanced-backend/src/.*
+ exclude: \.venv/
- id: end-of-file-fixer
- files: ^backends/advanced-backend/src/.*
\ No newline at end of file
+ exclude: \.venv/
diff --git a/Makefile b/Makefile
index d821819e..ae9e784d 100644
--- a/Makefile
+++ b/Makefile
@@ -122,8 +122,9 @@ help: ## Show detailed help for all targets
setup-dev: ## Setup development environment (git hooks, pre-commit)
@echo "🛠️ Setting up development environment..."
@echo ""
+ @bash scripts/check_uv.sh
@echo "📦 Installing pre-commit..."
- @pip install pre-commit 2>/dev/null || pip3 install pre-commit
+ @uv tool install pre-commit
@echo ""
@echo "🔧 Installing git hooks..."
@pre-commit install --hook-type pre-push
diff --git a/backends/advanced/.env.template b/backends/advanced/.env.template
index 6de583fd..293b299f 100644
--- a/backends/advanced/.env.template
+++ b/backends/advanced/.env.template
@@ -62,6 +62,12 @@ LANGFUSE_PUBLIC_KEY=
LANGFUSE_SECRET_KEY=
LANGFUSE_BASE_URL=http://langfuse-web:3000
+# Galileo (OTEL-based LLM observability)
+GALILEO_API_KEY=
+GALILEO_PROJECT=chronicle
+GALILEO_LOG_STREAM=default
+# GALILEO_CONSOLE_URL=https://app.galileo.ai # Default; override for self-hosted
+
# Qwen3-ASR (offline ASR via vLLM)
# QWEN3_ASR_URL=host.docker.internal:8767
# QWEN3_ASR_STREAM_URL=host.docker.internal:8769
diff --git a/backends/advanced/src/advanced_omi_backend/app_factory.py b/backends/advanced/src/advanced_omi_backend/app_factory.py
index 74cddd49..6083de97 100644
--- a/backends/advanced/src/advanced_omi_backend/app_factory.py
+++ b/backends/advanced/src/advanced_omi_backend/app_factory.py
@@ -71,26 +71,48 @@ async def initialize_openmemory_user() -> None:
# Get configured user_id and server_url
openmemory_config = memory_provider_config.openmemory_config
- user_id = openmemory_config.get("user_id", "openmemory") if openmemory_config else "openmemory"
- server_url = openmemory_config.get("server_url", "http://host.docker.internal:8765") if openmemory_config else "http://host.docker.internal:8765"
- client_name = openmemory_config.get("client_name", "chronicle") if openmemory_config else "chronicle"
+ user_id = (
+ openmemory_config.get("user_id", "openmemory")
+ if openmemory_config
+ else "openmemory"
+ )
+ server_url = (
+ openmemory_config.get("server_url", "http://host.docker.internal:8765")
+ if openmemory_config
+ else "http://host.docker.internal:8765"
+ )
+ client_name = (
+ openmemory_config.get("client_name", "chronicle")
+ if openmemory_config
+ else "chronicle"
+ )
- application_logger.info(f"Registering OpenMemory user: {user_id} at {server_url}")
+ application_logger.info(
+ f"Registering OpenMemory user: {user_id} at {server_url}"
+ )
# Make a lightweight registration call (create and delete dummy memory)
- async with MCPClient(server_url=server_url, client_name=client_name, user_id=user_id) as client:
+ async with MCPClient(
+ server_url=server_url, client_name=client_name, user_id=user_id
+ ) as client:
# Test connection first
is_connected = await client.test_connection()
if is_connected:
# Create and immediately delete a dummy memory to trigger user creation
- memory_ids = await client.add_memories("Chronicle initialization - user registration test")
+ memory_ids = await client.add_memories(
+ "Chronicle initialization - user registration test"
+ )
if memory_ids:
# Delete the test memory
await client.delete_memory(memory_ids[0])
application_logger.info(f"✅ Registered OpenMemory user: {user_id}")
else:
- application_logger.warning(f"⚠️ OpenMemory MCP not reachable at {server_url}")
- application_logger.info("User will be auto-created on first memory operation")
+ application_logger.warning(
+ f"⚠️ OpenMemory MCP not reachable at {server_url}"
+ )
+ application_logger.info(
+ "User will be auto-created on first memory operation"
+ )
except Exception as e:
application_logger.warning(f"⚠️ Could not register OpenMemory user: {e}")
application_logger.info("User will be auto-created on first memory operation")
@@ -116,7 +138,13 @@ async def lifespan(app: FastAPI):
await init_beanie(
database=config.db,
- document_models=[User, Conversation, AudioChunkDocument, WaveformData, Annotation],
+ document_models=[
+ User,
+ Conversation,
+ AudioChunkDocument,
+ WaveformData,
+ Annotation,
+ ],
)
application_logger.info("Beanie initialized for all document models")
except Exception as e:
@@ -133,12 +161,17 @@ async def lifespan(app: FastAPI):
# Initialize Redis connection for RQ
try:
from advanced_omi_backend.controllers.queue_controller import redis_conn
+
redis_conn.ping()
application_logger.info("Redis connection established for RQ")
- application_logger.info("RQ workers can be started with: rq worker transcription memory default")
+ application_logger.info(
+ "RQ workers can be started with: rq worker transcription memory default"
+ )
except Exception as e:
application_logger.error(f"Failed to connect to Redis for RQ: {e}")
- application_logger.warning("RQ queue system will not be available - check Redis connection")
+ application_logger.warning(
+ "RQ queue system will not be available - check Redis connection"
+ )
# Initialize BackgroundTaskManager (must happen before any code path uses it)
try:
@@ -153,6 +186,14 @@ async def lifespan(app: FastAPI):
get_client_manager()
application_logger.info("ClientManager initialized")
+ # Initialize OTEL/Galileo if configured (before LLM client so instrumentor patches OpenAI first)
+ try:
+ from advanced_omi_backend.observability.otel_setup import init_otel
+
+ init_otel()
+ except Exception as e:
+ application_logger.warning(f"OTEL initialization skipped: {e}")
+
# Initialize prompt registry with defaults; seed into LangFuse in background
try:
from advanced_omi_backend.prompt_defaults import register_all_defaults
@@ -176,6 +217,7 @@ async def _deferred_seed():
# Initialize LLM client eagerly (catch config errors at startup, not on first request)
try:
from advanced_omi_backend.llm_client import get_llm_client
+
get_llm_client()
application_logger.info("LLM client initialized from config.yml")
except Exception as e:
@@ -186,35 +228,47 @@ async def _deferred_seed():
audio_service = get_audio_stream_service()
await audio_service.connect()
application_logger.info("Audio stream service connected to Redis Streams")
- application_logger.info("Audio stream workers can be started with: python -m advanced_omi_backend.workers.audio_stream_worker")
+ application_logger.info(
+ "Audio stream workers can be started with: python -m advanced_omi_backend.workers.audio_stream_worker"
+ )
except Exception as e:
application_logger.error(f"Failed to connect audio stream service: {e}")
- application_logger.warning("Redis Streams audio processing will not be available")
+ application_logger.warning(
+ "Redis Streams audio processing will not be available"
+ )
# Initialize Redis client for audio streaming producer (used by WebSocket handlers)
try:
app.state.redis_audio_stream = await redis.from_url(
- config.redis_url,
- encoding="utf-8",
- decode_responses=False
+ config.redis_url, encoding="utf-8", decode_responses=False
)
from advanced_omi_backend.services.audio_stream import AudioStreamProducer
- app.state.audio_stream_producer = AudioStreamProducer(app.state.redis_audio_stream)
- application_logger.info("✅ Redis client for audio streaming producer initialized")
+
+ app.state.audio_stream_producer = AudioStreamProducer(
+ app.state.redis_audio_stream
+ )
+ application_logger.info(
+ "✅ Redis client for audio streaming producer initialized"
+ )
# Initialize ClientManager Redis for cross-container clientβuser mapping
from advanced_omi_backend.client_manager import (
initialize_redis_for_client_manager,
)
+
initialize_redis_for_client_manager(config.redis_url)
except Exception as e:
- application_logger.error(f"Failed to initialize Redis client for audio streaming: {e}", exc_info=True)
+ application_logger.error(
+ f"Failed to initialize Redis client for audio streaming: {e}", exc_info=True
+ )
application_logger.warning("Audio streaming producer will not be available")
# Skip memory service pre-initialization to avoid blocking FastAPI startup
# Memory service will be lazily initialized when first used
- application_logger.info("Memory service will be initialized on first use (lazy loading)")
+ application_logger.info(
+ "Memory service will be initialized on first use (lazy loading)"
+ )
# Register OpenMemory user if using openmemory_mcp provider
await initialize_openmemory_user()
@@ -264,12 +318,15 @@ async def _deferred_seed():
application_logger.info(f"✅ Plugin '{plugin_id}' initialized")
except Exception as e:
plugin_router.mark_plugin_failed(plugin_id, str(e))
- application_logger.error(f"Failed to initialize plugin '{plugin_id}': {e}", exc_info=True)
+ application_logger.error(
+ f"Failed to initialize plugin '{plugin_id}': {e}",
+ exc_info=True,
+ )
health = plugin_router.get_health_summary()
application_logger.info(
f"Plugins initialized: {health['initialized']}/{health['total']} active"
- + (f", {health['failed']} failed" if health['failed'] else "")
+ + (f", {health['failed']} failed" if health["failed"] else "")
)
# Store in app state for API access
@@ -281,10 +338,14 @@ async def _deferred_seed():
app.state.plugin_router = None
except Exception as e:
- application_logger.error(f"Failed to initialize plugin system: {e}", exc_info=True)
+ application_logger.error(
+ f"Failed to initialize plugin system: {e}", exc_info=True
+ )
app.state.plugin_router = None
- application_logger.info("Application ready - using application-level processing architecture.")
+ application_logger.info(
+ "Application ready - using application-level processing architecture."
+ )
logger.info("App ready")
try:
@@ -300,6 +361,7 @@ async def _deferred_seed():
from advanced_omi_backend.controllers.websocket_controller import (
cleanup_client_state,
)
+
await cleanup_client_state(client_id)
except Exception as e:
application_logger.error(f"Error cleaning up client {client_id}: {e}")
@@ -327,9 +389,14 @@ async def _deferred_seed():
# Close Redis client for audio streaming producer
try:
- if hasattr(app.state, 'redis_audio_stream') and app.state.redis_audio_stream:
+ if (
+ hasattr(app.state, "redis_audio_stream")
+ and app.state.redis_audio_stream
+ ):
await app.state.redis_audio_stream.close()
- application_logger.info("Redis client for audio streaming producer closed")
+ application_logger.info(
+ "Redis client for audio streaming producer closed"
+ )
except Exception as e:
application_logger.error(f"Error closing Redis audio streaming client: {e}")
@@ -341,6 +408,7 @@ async def _deferred_seed():
from advanced_omi_backend.services.plugin_service import (
cleanup_plugin_router,
)
+
await cleanup_plugin_router()
application_logger.info("Plugins shut down")
except Exception as e:
@@ -380,7 +448,7 @@ def create_app() -> FastAPI:
# Add WebSocket router at root level (not under /api prefix)
app.include_router(websocket_router)
- # Add authentication routers
+ # Add authentication routers
app.include_router(
fastapi_users.get_auth_router(cookie_backend),
prefix="/auth/cookie",
@@ -403,6 +471,8 @@ def create_app() -> FastAPI:
CHUNK_DIR = Path("/app/audio_chunks")
app.mount("/audio", StaticFiles(directory=CHUNK_DIR), name="audio")
- logger.info("FastAPI application created with all routers and middleware configured")
+ logger.info(
+ "FastAPI application created with all routers and middleware configured"
+ )
- return app
\ No newline at end of file
+ return app
diff --git a/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py b/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py
index 0595f3a4..38fc2333 100644
--- a/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py
+++ b/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py
@@ -114,7 +114,11 @@ async def connect(self, wait_for_ready: bool = True) -> WebSocketClientProtocol:
if wait_for_ready and "codec=pcm" in self.endpoint:
# PCM codec sends "ready" message after auth (line 261-268 in websocket_controller.py)
ready_msg = await self.ws.recv()
- ready = json.loads(ready_msg.strip() if isinstance(ready_msg, str) else ready_msg.decode().strip())
+ ready = json.loads(
+ ready_msg.strip()
+ if isinstance(ready_msg, str)
+ else ready_msg.decode().strip()
+ )
if ready.get("type") != "ready":
raise RuntimeError(f"Expected 'ready' message, got: {ready}")
logger.info("Received ready message from server")
@@ -154,10 +158,8 @@ async def send_audio_start(
},
"payload_length": None,
}
- print(f"🎵 CLIENT: Sending audio-start message: {header}")
logger.info(f"🎵 CLIENT: Sending audio-start message: {header}")
await self.ws.send(json.dumps(header) + "\n")
- print(f"✅ CLIENT: Sent audio-start with mode={recording_mode}")
logger.info(f"✅ CLIENT: Sent audio-start with mode={recording_mode}")
async def send_audio_chunk_wyoming(
@@ -197,7 +199,9 @@ async def send_audio_chunk_wyoming(
self.total_bytes += len(audio_data)
if self.chunk_count <= 3 or self.chunk_count % 100 == 0:
- logger.debug(f"Sent audio chunk #{self.chunk_count}: {len(audio_data)} bytes")
+ logger.debug(
+ f"Sent audio chunk #{self.chunk_count}: {len(audio_data)} bytes"
+ )
async def send_audio_chunk_raw(self, audio_data: bytes) -> None:
"""Send raw binary audio without Wyoming header (legacy mode).
@@ -222,7 +226,22 @@ async def send_audio_stop(self) -> None:
header = {"type": "audio-stop"}
await self.ws.send(json.dumps(header) + "\n")
- logger.info(f"Sent audio-stop (total: {self.chunk_count} chunks, {self.total_bytes} bytes)")
+ logger.info(
+ f"Sent audio-stop (total: {self.chunk_count} chunks, {self.total_bytes} bytes)"
+ )
+
+ async def send_button_event(self, button_state: str = "SINGLE_PRESS") -> None:
+ """Send button event via Wyoming protocol.
+
+ Args:
+ button_state: Button state ("SINGLE_PRESS" or "DOUBLE_PRESS")
+ """
+ if not self.ws:
+ raise RuntimeError("Not connected. Call connect() first.")
+
+ header = {"type": "button-event", "data": {"state": button_state}}
+ await self.ws.send(json.dumps(header) + "\n")
+ logger.info(f"Sent button event: {button_state}")
async def send_ping(self) -> None:
"""Send keepalive ping."""
@@ -305,7 +324,9 @@ async def stream_wav_file(
# Send audio-stop
await self.send_audio_stop()
- logger.info(f"Finished streaming: {self.chunk_count} chunks, {self.total_bytes} bytes")
+ logger.info(
+ f"Finished streaming: {self.chunk_count} chunks, {self.total_bytes} bytes"
+ )
return self.chunk_count
async def close(self) -> None:
@@ -318,7 +339,7 @@ async def close(self) -> None:
except asyncio.TimeoutError:
logger.warning("WebSocket close timed out after 2s, forcing close")
# Force close without waiting for handshake
- if hasattr(self.ws, 'transport') and self.ws.transport:
+ if hasattr(self.ws, "transport") and self.ws.transport:
self.ws.transport.close()
except Exception as e:
logger.error(f"Error during WebSocket close: {e}")
@@ -448,10 +469,14 @@ def run_loop():
# Connect and send audio-start
async def _connect_and_start():
try:
- logger.info(f"🎵 CLIENT: Stream {stream_id} connecting for {device_name}...")
+ logger.info(
+ f"🎵 CLIENT: Stream {stream_id} connecting for {device_name}..."
+ )
await client.connect()
session.connected = True
- logger.info(f"✅ CLIENT: Stream {stream_id} connected, sending audio-start...")
+ logger.info(
+ f"✅ CLIENT: Stream {stream_id} connected, sending audio-start..."
+ )
await client.send_audio_start(recording_mode=recording_mode)
session.audio_started = True
logger.info(f"✅ CLIENT: Stream {stream_id} started for {device_name}")
@@ -595,9 +620,30 @@ async def _close_abruptly():
total_chunks = session.chunk_count
del self._sessions[stream_id]
- logger.info(f"Stream {stream_id} closed abruptly (no audio-stop), sent {total_chunks} chunks")
+ logger.info(
+ f"Stream {stream_id} closed abruptly (no audio-stop), sent {total_chunks} chunks"
+ )
return total_chunks
+ def send_button_event(
+ self, stream_id: str, button_state: str = "SINGLE_TAP"
+ ) -> None:
+ """Send a button event to an open stream.
+
+ Args:
+ stream_id: Stream session ID
+ button_state: Button state ("SINGLE_TAP" or "DOUBLE_TAP")
+ """
+ session = self._sessions.get(stream_id)
+ if not session:
+ raise ValueError(f"Unknown stream_id: {stream_id}")
+
+ async def _send():
+ await session.client.send_button_event(button_state)
+
+ future = asyncio.run_coroutine_threadsafe(_send(), session.loop)
+ future.result(timeout=5)
+
def get_session(self, stream_id: str) -> Optional[StreamSession]:
"""Get session info for a stream."""
return self._sessions.get(stream_id)
diff --git a/backends/advanced/src/advanced_omi_backend/config_loader.py b/backends/advanced/src/advanced_omi_backend/config_loader.py
index 1b8be9ee..bd1f9de4 100644
--- a/backends/advanced/src/advanced_omi_backend/config_loader.py
+++ b/backends/advanced/src/advanced_omi_backend/config_loader.py
@@ -26,7 +26,7 @@ def get_config_dir() -> Path:
def get_plugins_yml_path() -> Path:
"""
Get path to plugins.yml file (single source of truth).
-
+
Returns:
Path to plugins.yml
"""
@@ -86,20 +86,34 @@ def load_config(force_reload: bool = False) -> DictConfig:
# OmegaConf.merge replaces lists entirely, so we need custom merge
# for the 'models' list: merge by name so defaults models that aren't
# in user config are still available.
- default_models = OmegaConf.to_container(defaults.get("models", []) or [], resolve=False) if defaults else []
- user_models = OmegaConf.to_container(user_config.get("models", []) or [], resolve=False) if user_config else []
+ default_models = (
+ OmegaConf.to_container(defaults.get("models", []) or [], resolve=False)
+ if defaults
+ else []
+ )
+ user_models = (
+ OmegaConf.to_container(user_config.get("models", []) or [], resolve=False)
+ if user_config
+ else []
+ )
merged = OmegaConf.merge(defaults, user_config)
# Name-based merge: user models override defaults, but default-only models are kept
if default_models and user_models:
user_model_names = {m.get("name") for m in user_models if isinstance(m, dict)}
- extra_defaults = [m for m in default_models if isinstance(m, dict) and m.get("name") not in user_model_names]
+ extra_defaults = [
+ m
+ for m in default_models
+ if isinstance(m, dict) and m.get("name") not in user_model_names
+ ]
if extra_defaults:
all_models = user_models + extra_defaults
merged["models"] = OmegaConf.create(all_models)
- logger.info(f"Merged {len(extra_defaults)} default-only models into config: "
- f"{[m.get('name') for m in extra_defaults]}")
+ logger.info(
+ f"Merged {len(extra_defaults)} default-only models into config: "
+ f"{[m.get('name') for m in extra_defaults]}"
+ )
# Cache result
_config_cache = merged
@@ -126,7 +140,7 @@ def get_backend_config(section: Optional[str] = None) -> DictConfig:
DictConfig for backend section or subsection
"""
cfg = load_config()
- if 'backend' not in cfg:
+ if "backend" not in cfg:
return OmegaConf.create({})
backend_cfg = cfg.backend
@@ -153,6 +167,9 @@ def save_config_section(section_path: str, values: dict) -> bool:
"""
Update a config section and save to config.yml.
+ Also updates the in-memory config cache so changes take effect immediately,
+ even when CONFIG_FILE points to a different file (e.g., in test environments).
+
Args:
section_path: Dot-separated path (e.g., 'backend.diarization')
values: Dict with new values
@@ -174,8 +191,13 @@ def save_config_section(section_path: str, values: dict) -> bool:
# Save back to file
OmegaConf.save(existing_config, config_path)
- # Invalidate cache
- reload_config()
+ # Reload config from the primary config file (CONFIG_FILE env var)
+ merged = reload_config()
+
+ # Also apply the values to the in-memory cache directly.
+ # This is needed when CONFIG_FILE points to a different file than config.yml
+ # (e.g., test configs), so the saved values still take effect at runtime.
+ OmegaConf.update(merged, section_path, values, merge=True)
logger.info(f"Saved config section '{section_path}' to {config_path}")
return True
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py
index 4bb1088a..d7fc12e9 100644
--- a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py
+++ b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py
@@ -1,4 +1,3 @@
-
"""
WebSocket controller for Chronicle backend.
@@ -32,7 +31,6 @@
from advanced_omi_backend.services.audio_stream.producer import (
get_audio_stream_producer,
)
-from advanced_omi_backend.utils.audio_utils import process_audio_chunk
# Thread pool executors for audio decoding
_DEC_IO_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
@@ -82,29 +80,34 @@ async def subscribe_to_interim_results(websocket: WebSocket, session_id: str) ->
# Listen for messages
while True:
try:
- message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
+ message = await pubsub.get_message(
+ ignore_subscribe_messages=True, timeout=1.0
+ )
- if message and message['type'] == 'message':
+ if message and message["type"] == "message":
# Parse result data
try:
- result_data = json.loads(message['data'])
+ result_data = json.loads(message["data"])
# Forward to client WebSocket
- await websocket.send_json({
- "type": "interim_transcript",
- "data": result_data
- })
+ await websocket.send_json(
+ {"type": "interim_transcript", "data": result_data}
+ )
# Log for debugging
is_final = result_data.get("is_final", False)
text_preview = result_data.get("text", "")[:50]
result_type = "FINAL" if is_final else "interim"
- logger.debug(f"βοΈ Forwarded {result_type} result to client {session_id}: {text_preview}...")
+ logger.debug(
+ f"βοΈ Forwarded {result_type} result to client {session_id}: {text_preview}..."
+ )
except json.JSONDecodeError as e:
logger.error(f"Failed to parse interim result JSON: {e}")
except Exception as send_error:
- logger.error(f"Failed to send interim result to client {session_id}: {send_error}")
+ logger.error(
+ f"Failed to send interim result to client {session_id}: {send_error}"
+ )
# WebSocket might be closed, exit loop
break
@@ -112,14 +115,22 @@ async def subscribe_to_interim_results(websocket: WebSocket, session_id: str) ->
# No message received, continue waiting
continue
except asyncio.CancelledError:
- logger.info(f"Interim results subscriber cancelled for session {session_id}")
+ logger.info(
+ f"Interim results subscriber cancelled for session {session_id}"
+ )
break
except Exception as e:
- logger.error(f"Error in interim results subscriber for {session_id}: {e}", exc_info=True)
+ logger.error(
+ f"Error in interim results subscriber for {session_id}: {e}",
+ exc_info=True,
+ )
break
except Exception as e:
- logger.error(f"Failed to initialize interim results subscriber for {session_id}: {e}", exc_info=True)
+ logger.error(
+ f"Failed to initialize interim results subscriber for {session_id}: {e}",
+ exc_info=True,
+ )
finally:
try:
# Unsubscribe and close connections
@@ -128,7 +139,9 @@ async def subscribe_to_interim_results(websocket: WebSocket, session_id: str) ->
await redis_client.aclose()
logger.info(f"π Unsubscribed from interim results channel: {channel}")
except Exception as cleanup_error:
- logger.error(f"Error cleaning up interim results subscriber: {cleanup_error}")
+ logger.error(
+ f"Error cleaning up interim results subscriber: {cleanup_error}"
+ )
async def parse_wyoming_protocol(ws: WebSocket) -> tuple[dict, Optional[bytes]]:
@@ -140,14 +153,18 @@ async def parse_wyoming_protocol(ws: WebSocket) -> tuple[dict, Optional[bytes]]:
# Read data from WebSocket
logger.debug(f"parse_wyoming_protocol: About to call ws.receive()")
message = await ws.receive()
- logger.debug(f"parse_wyoming_protocol: Received message with keys: {message.keys() if message else 'None'}")
+ logger.debug(
+ f"parse_wyoming_protocol: Received message with keys: {message.keys() if message else 'None'}"
+ )
# Handle WebSocket close frame
if "type" in message and message["type"] == "websocket.disconnect":
# This is a normal WebSocket close event
code = message.get("code", 1000)
reason = message.get("reason", "")
- logger.info(f"🔴 WebSocket disconnect received in parse_wyoming_protocol. Code: {code}, Reason: {reason}")
+ logger.info(
+ f"🔴 WebSocket disconnect received in parse_wyoming_protocol. Code: {code}, Reason: {reason}"
+ )
raise WebSocketDisconnect(code=code, reason=reason)
# Handle text message (JSON header)
@@ -190,7 +207,10 @@ async def create_client_state(client_id: str, user, device_name: Optional[str] =
# Directory where WAV chunks are written
from pathlib import Path
- CHUNK_DIR = Path("./audio_chunks") # This will be mounted to ./data/audio_chunks by Docker
+
+ CHUNK_DIR = Path(
+ "./audio_chunks"
+ ) # This will be mounted to ./data/audio_chunks by Docker
# Use ClientManager for atomic client creation and registration
client_state = client_manager.create_client(
@@ -199,10 +219,12 @@ async def create_client_state(client_id: str, user, device_name: Optional[str] =
# Also track in persistent mapping (for database queries + cross-container Redis)
from advanced_omi_backend.client_manager import track_client_user_relationship_async
+
await track_client_user_relationship_async(client_id, user.user_id)
# Register client in user model (persistent)
from advanced_omi_backend.users import register_client_to_user
+
await register_client_to_user(user, client_id, device_name)
return client_state
@@ -224,7 +246,9 @@ async def cleanup_client_state(client_id: str):
# The speech detection job now monitors session status and completes naturally.
import redis.asyncio as redis
- logger.info(f"π Letting speech detection job complete naturally for client {client_id} (if running)")
+ logger.info(
+ f"π Letting speech detection job complete naturally for client {client_id} (if running)"
+ )
# Mark all active sessions for this client as complete AND delete Redis streams
try:
@@ -236,6 +260,7 @@ async def cleanup_client_state(client_id: str):
from advanced_omi_backend.services.audio_stream.producer import (
get_audio_stream_producer,
)
+
audio_stream_producer = get_audio_stream_producer()
# Find all session keys for this client and mark them complete
@@ -258,18 +283,26 @@ async def cleanup_client_state(client_id: str):
# If session is still active, finalize it first (sets status + completion_reason atomically)
if status in ["active", None]:
- logger.info(f"π Finalizing active session {session_id[:12]} due to WebSocket disconnect")
- await audio_stream_producer.finalize_session(session_id, completion_reason="websocket_disconnect")
+ logger.info(
+ f"π Finalizing active session {session_id[:12]} due to WebSocket disconnect"
+ )
+ await audio_stream_producer.finalize_session(
+ session_id, completion_reason="websocket_disconnect"
+ )
# Mark session as complete (WebSocket disconnected)
- await mark_session_complete(async_redis, session_id, "websocket_disconnect")
+ await mark_session_complete(
+ async_redis, session_id, "websocket_disconnect"
+ )
sessions_closed += 1
if cursor == 0:
break
if sessions_closed > 0:
- logger.info(f"✅ Closed {sessions_closed} active session(s) for client {client_id}")
+ logger.info(
+ f"✅ Closed {sessions_closed} active session(s) for client {client_id}"
+ )
# Set TTL on Redis Streams for this client (allows consumer groups to finish processing)
stream_pattern = f"audio:stream:{client_id}"
@@ -282,9 +315,11 @@ async def cleanup_client_state(client_id: str):
pending_count = 0
try:
# Check streaming-transcription consumer group for pending messages
- pending_info = await async_redis.xpending(stream_pattern, "streaming-transcription")
+ pending_info = await async_redis.xpending(
+ stream_pattern, "streaming-transcription"
+ )
if pending_info:
- pending_count = pending_info.get('pending', 0)
+ pending_count = pending_info.get("pending", 0)
except Exception as e:
# Consumer group might not exist yet - that's ok
logger.debug(f"No consumer group for {stream_pattern}: {e}")
@@ -295,7 +330,9 @@ async def cleanup_client_state(client_id: str):
f"{stream_length} messages in stream, {pending_count} pending in consumer group"
)
- await async_redis.expire(stream_pattern, 60) # 60 second TTL for consumer group fan-out
+ await async_redis.expire(
+ stream_pattern, 60
+ ) # 60 second TTL for consumer group fan-out
logger.info(f"⏰ Set 60s TTL on Redis stream: {stream_pattern}")
else:
logger.debug(f"No Redis stream found for client {client_id}")
@@ -303,7 +340,9 @@ async def cleanup_client_state(client_id: str):
await async_redis.close()
except Exception as session_error:
- logger.warning(f"β οΈ Error marking sessions complete for client {client_id}: {session_error}")
+ logger.warning(
+ f"β οΈ Error marking sessions complete for client {client_id}: {session_error}"
+ )
# Use ClientManager for atomic client removal with cleanup
client_manager = get_client_manager()
@@ -321,7 +360,7 @@ async def _setup_websocket_connection(
token: Optional[str],
device_name: Optional[str],
pending_client_id: str,
- connection_type: str
+ connection_type: str,
) -> tuple[Optional[str], Optional[object], Optional[object]]:
"""
Setup WebSocket connection: accept, authenticate, create client state.
@@ -344,12 +383,17 @@ async def _setup_websocket_connection(
if not user:
# Send error message to client before closing
try:
- error_msg = json.dumps({
- "type": "error",
- "error": "authentication_failed",
- "message": "Authentication failed. Please log in again and ensure your token is valid.",
- "code": 1008
- }) + "\n"
+ error_msg = (
+ json.dumps(
+ {
+ "type": "error",
+ "error": "authentication_failed",
+ "message": "Authentication failed. Please log in again and ensure your token is valid.",
+ "code": 1008,
+ }
+ )
+ + "\n"
+ )
await ws.send_text(error_msg)
application_logger.info("Sent authentication error message to client")
except Exception as send_error:
@@ -370,7 +414,10 @@ async def _setup_websocket_connection(
# Send ready message to confirm connection is established
try:
- ready_msg = json.dumps({"type": "ready", "message": "WebSocket connection established"}) + "\n"
+ ready_msg = (
+ json.dumps({"type": "ready", "message": "WebSocket connection established"})
+ + "\n"
+ )
await ws.send_text(ready_msg)
application_logger.debug(f"β
Sent ready message to {client_id}")
except Exception as e:
@@ -389,7 +436,7 @@ async def _initialize_streaming_session(
user_email: str,
client_id: str,
audio_format: dict,
- websocket: Optional[WebSocket] = None
+ websocket: Optional[WebSocket] = None,
) -> Optional[asyncio.Task]:
"""
Initialize streaming session with Redis and enqueue processing jobs.
@@ -410,30 +457,38 @@ async def _initialize_streaming_session(
f"π΄ BACKEND: _initialize_streaming_session called for {client_id}"
)
- if hasattr(client_state, 'stream_session_id'):
+ if hasattr(client_state, "stream_session_id"):
application_logger.debug(f"Session already initialized for {client_id}")
return None
# Initialize stream session - use client_id as session_id for predictable lookup
# All other session metadata goes to Redis (single source of truth)
client_state.stream_session_id = client_state.client_id
- application_logger.info(f"π Created stream session: {client_state.stream_session_id}")
+ application_logger.info(
+ f"π Created stream session: {client_state.stream_session_id}"
+ )
# Determine transcription provider from config.yml
from advanced_omi_backend.model_registry import get_models_registry
registry = get_models_registry()
if not registry:
- raise ValueError("config.yml not found - cannot determine transcription provider")
+ raise ValueError(
+ "config.yml not found - cannot determine transcription provider"
+ )
stt_model = registry.get_default("stt")
if not stt_model:
raise ValueError("No default STT model configured in config.yml (defaults.stt)")
# Use model_provider for session tracking (generic, not validated against hardcoded list)
- provider = stt_model.model_provider.lower() if stt_model.model_provider else stt_model.name
+ provider = (
+ stt_model.model_provider.lower() if stt_model.model_provider else stt_model.name
+ )
- application_logger.info(f"π Using STT provider: {provider} (model: {stt_model.name})")
+ application_logger.info(
+ f"π Using STT provider: {provider} (model: {stt_model.name})"
+ )
# Initialize session tracking in Redis (SINGLE SOURCE OF TRUTH for session metadata)
# This includes user_email, connection info, audio format, chunk counters, job IDs, etc.
@@ -445,7 +500,7 @@ async def _initialize_streaming_session(
user_email=user_email,
connection_id=connection_id,
mode="streaming",
- provider=provider
+ provider=provider,
)
# Store audio format in Redis session (not in ClientState)
@@ -454,6 +509,7 @@ async def _initialize_streaming_session(
from advanced_omi_backend.services.audio_stream.producer import (
get_audio_stream_producer,
)
+
session_key = f"audio:session:{client_state.stream_session_id}"
redis_client = audio_stream_producer.redis_client
await redis_client.hset(session_key, "audio_format", json.dumps(audio_format))
@@ -462,16 +518,14 @@ async def _initialize_streaming_session(
from advanced_omi_backend.controllers.queue_controller import start_streaming_jobs
job_ids = start_streaming_jobs(
- session_id=client_state.stream_session_id,
- user_id=user_id,
- client_id=client_id
+ session_id=client_state.stream_session_id, user_id=user_id, client_id=client_id
)
# Store job IDs in Redis session (not in ClientState)
await audio_stream_producer.update_session_job_ids(
session_id=client_state.stream_session_id,
- speech_detection_job_id=job_ids['speech_detection'],
- audio_persistence_job_id=job_ids['audio_persistence']
+ speech_detection_job_id=job_ids["speech_detection"],
+ audio_persistence_job_id=job_ids["audio_persistence"],
)
# Note: Placeholder conversation creation is handled by the audio persistence job,
@@ -483,17 +537,15 @@ async def _initialize_streaming_session(
subscriber_task = asyncio.create_task(
subscribe_to_interim_results(websocket, client_state.stream_session_id)
)
- application_logger.info(f"π‘ Launched interim results subscriber for session {client_state.stream_session_id}")
+ application_logger.info(
+ f"π‘ Launched interim results subscriber for session {client_state.stream_session_id}"
+ )
return subscriber_task
async def _finalize_streaming_session(
- client_state,
- audio_stream_producer,
- user_id: str,
- user_email: str,
- client_id: str
+ client_state, audio_stream_producer, user_id: str, user_email: str, client_id: str
) -> None:
"""
Finalize streaming session: flush buffer, signal workers, enqueue finalize job, cleanup.
@@ -505,7 +557,7 @@ async def _finalize_streaming_session(
user_email: User email
client_id: Client ID
"""
- if not hasattr(client_state, 'stream_session_id'):
+ if not hasattr(client_state, "stream_session_id"):
application_logger.debug(f"No active session to finalize for {client_id}")
return
@@ -513,19 +565,21 @@ async def _finalize_streaming_session(
try:
# Flush any remaining buffered audio
- audio_format = getattr(client_state, 'stream_audio_format', {})
+ audio_format = getattr(client_state, "stream_audio_format", {})
await audio_stream_producer.flush_session_buffer(
session_id=session_id,
sample_rate=audio_format.get("rate", 16000),
channels=audio_format.get("channels", 1),
- sample_width=audio_format.get("width", 2)
+ sample_width=audio_format.get("width", 2),
)
# Send end-of-session signal to workers
await audio_stream_producer.send_session_end_signal(session_id)
# Mark session as finalizing with user_stopped reason (audio-stop event)
- await audio_stream_producer.finalize_session(session_id, completion_reason="user_stopped")
+ await audio_stream_producer.finalize_session(
+ session_id, completion_reason="user_stopped"
+ )
# Store markers in Redis so open_conversation_job can persist them
if client_state.markers:
@@ -556,13 +610,12 @@ async def _finalize_streaming_session(
# Clear session state from ClientState (only stream_session_id is stored there now)
# All other session metadata lives in Redis (single source of truth)
- if hasattr(client_state, 'stream_session_id'):
- delattr(client_state, 'stream_session_id')
+ if hasattr(client_state, "stream_session_id"):
+ delattr(client_state, "stream_session_id")
except Exception as finalize_error:
application_logger.error(
- f"β Failed to finalize streaming session: {finalize_error}",
- exc_info=True
+ f"β Failed to finalize streaming session: {finalize_error}", exc_info=True
)
@@ -574,7 +627,7 @@ async def _publish_audio_to_stream(
client_id: str,
sample_rate: int,
channels: int,
- sample_width: int
+ sample_width: int,
) -> None:
"""
Publish audio chunk to Redis Stream with chunk tracking.
@@ -589,28 +642,23 @@ async def _publish_audio_to_stream(
channels: Number of channels
sample_width: Bytes per sample
"""
- if not hasattr(client_state, 'stream_session_id'):
- application_logger.warning(f"β οΈ Received audio chunk before session initialized for {client_id}")
+ if not hasattr(client_state, "stream_session_id"):
+ application_logger.warning(
+ f"β οΈ Received audio chunk before session initialized for {client_id}"
+ )
return
session_id = client_state.stream_session_id
- # Increment chunk count in Redis (single source of truth) and format chunk ID
- session_key = f"audio:session:{session_id}"
- redis_client = audio_stream_producer.redis_client
- chunk_count = await redis_client.hincrby(session_key, "chunks_published", 1)
- chunk_id = f"{chunk_count:05d}"
-
- # Publish to Redis Stream using producer
+ # Publish to Redis Stream using producer (producer owns chunk counting)
await audio_stream_producer.add_audio_chunk(
audio_data=audio_data,
session_id=session_id,
- chunk_id=chunk_id,
user_id=user_id,
client_id=client_id,
sample_rate=sample_rate,
channels=channels,
- sample_width=sample_width
+ sample_width=sample_width,
)
@@ -621,7 +669,7 @@ async def _handle_omi_audio_chunk(
decode_packet_fn,
user_id: str,
client_id: str,
- packet_count: int
+ packet_count: int,
) -> None:
"""
Handle OMI audio chunk: decode Opus to PCM, then publish to stream.
@@ -638,7 +686,9 @@ async def _handle_omi_audio_chunk(
# Decode Opus to PCM
start_time = time.time()
loop = asyncio.get_running_loop()
- pcm_data = await loop.run_in_executor(_DEC_IO_EXECUTOR, decode_packet_fn, opus_payload)
+ pcm_data = await loop.run_in_executor(
+ _DEC_IO_EXECUTOR, decode_packet_fn, opus_payload
+ )
decode_time = time.time() - start_time
if pcm_data:
@@ -657,7 +707,7 @@ async def _handle_omi_audio_chunk(
client_id,
OMI_SAMPLE_RATE,
OMI_CHANNELS,
- OMI_SAMPLE_WIDTH
+ OMI_SAMPLE_WIDTH,
)
else:
# Log decode failures for first 5 packets
@@ -675,7 +725,7 @@ async def _handle_streaming_mode_audio(
user_id: str,
user_email: str,
client_id: str,
- websocket: Optional[WebSocket] = None
+ websocket: Optional[WebSocket] = None,
) -> Optional[asyncio.Task]:
"""
Handle audio chunk in streaming mode.
@@ -695,7 +745,7 @@ async def _handle_streaming_mode_audio(
"""
# Initialize session if needed
subscriber_task = None
- if not hasattr(client_state, 'stream_session_id'):
+ if not hasattr(client_state, "stream_session_id"):
subscriber_task = await _initialize_streaming_session(
client_state,
audio_stream_producer,
@@ -703,7 +753,7 @@ async def _handle_streaming_mode_audio(
user_email,
client_id,
audio_format,
- websocket=websocket # Pass WebSocket to launch interim results subscriber
+ websocket=websocket, # Pass WebSocket to launch interim results subscriber
)
# Publish to Redis Stream
@@ -715,17 +765,14 @@ async def _handle_streaming_mode_audio(
client_id,
audio_format.get("rate", 16000),
audio_format.get("channels", 1),
- audio_format.get("width", 2)
+ audio_format.get("width", 2),
)
return subscriber_task
async def _handle_batch_mode_audio(
- client_state,
- audio_data: bytes,
- audio_format: dict,
- client_id: str
+ client_state, audio_data: bytes, audio_format: dict, client_id: str
) -> None:
"""
Handle audio chunk in batch mode with rolling 30-minute limit.
@@ -737,7 +784,7 @@ async def _handle_batch_mode_audio(
client_id: Client ID
"""
# Initialize batch accumulator if needed
- if not hasattr(client_state, 'batch_audio_chunks'):
+ if not hasattr(client_state, "batch_audio_chunks"):
client_state.batch_audio_chunks = []
client_state.batch_audio_format = audio_format
client_state.batch_audio_bytes = 0 # Track total bytes
@@ -774,7 +821,7 @@ async def _handle_batch_mode_audio(
user_id=client_state.user_id, # Need to store these on session start
user_email=client_state.user_email,
client_id=client_state.client_id,
- batch_number=client_state.batch_chunks_processed + 1
+ batch_number=client_state.batch_chunks_processed + 1,
)
# Clear buffer for next batch
@@ -796,7 +843,7 @@ async def _handle_audio_chunk(
user_id: str,
user_email: str,
client_id: str,
- websocket: Optional[WebSocket] = None
+ websocket: Optional[WebSocket] = None,
) -> Optional[asyncio.Task]:
"""
Route audio chunk to appropriate mode handler (streaming or batch).
@@ -814,13 +861,18 @@ async def _handle_audio_chunk(
Returns:
Interim results subscriber task if websocket provided and streaming mode, None otherwise
"""
- recording_mode = getattr(client_state, 'recording_mode', 'batch')
+ recording_mode = getattr(client_state, "recording_mode", "batch")
if recording_mode == "streaming":
return await _handle_streaming_mode_audio(
- client_state, audio_stream_producer, audio_data,
- audio_format, user_id, user_email, client_id,
- websocket=websocket
+ client_state,
+ audio_stream_producer,
+ audio_data,
+ audio_format,
+ user_id,
+ user_email,
+ client_id,
+ websocket=websocket,
)
else:
await _handle_batch_mode_audio(
@@ -833,7 +885,7 @@ async def _handle_audio_session_start(
client_state,
audio_format: dict,
client_id: str,
- websocket: Optional[WebSocket] = None
+ websocket: Optional[WebSocket] = None,
) -> tuple[bool, str]:
"""
Handle audio-start event - validate mode and set recording mode.
@@ -878,14 +930,20 @@ async def _handle_audio_session_start(
"type": "error",
"error": "streaming_not_configured",
"message": error_msg,
- "code": 400
+ "code": 400,
}
await websocket.send_json(error_response)
- application_logger.info(f"π€ Sent streaming error to WebUI client {client_id}")
+ application_logger.info(
+ f"π€ Sent streaming error to WebUI client {client_id}"
+ )
# Close the websocket connection after sending error
- await websocket.close(code=1008, reason="Streaming transcription not configured")
- application_logger.info(f"π Closed WebSocket connection for {client_id} due to streaming config error")
+ await websocket.close(
+ code=1008, reason="Streaming transcription not configured"
+ )
+ application_logger.info(
+ f"π Closed WebSocket connection for {client_id} due to streaming config error"
+ )
# Raise ValueError to exit the handler completely
raise ValueError(error_msg)
@@ -917,11 +975,7 @@ async def _handle_audio_session_start(
async def _handle_audio_session_stop(
- client_state,
- audio_stream_producer,
- user_id: str,
- user_email: str,
- client_id: str
+ client_state, audio_stream_producer, user_id: str, user_email: str, client_id: str
) -> bool:
"""
Handle audio-stop event - finalize session based on mode.
@@ -936,13 +990,14 @@ async def _handle_audio_session_stop(
Returns:
False to switch back to control mode
"""
- recording_mode = getattr(client_state, 'recording_mode', 'batch')
- application_logger.info(f"π Audio session stopped for {client_id} (mode: {recording_mode})")
+ recording_mode = getattr(client_state, "recording_mode", "batch")
+ application_logger.info(
+ f"π Audio session stopped for {client_id} (mode: {recording_mode})"
+ )
if recording_mode == "streaming":
await _finalize_streaming_session(
- client_state, audio_stream_producer,
- user_id, user_email, client_id
+ client_state, audio_stream_producer, user_id, user_email, client_id
)
else:
await _process_batch_audio_complete(
@@ -969,10 +1024,7 @@ async def _handle_button_event(
user_id: User ID
client_id: Client ID
"""
- from advanced_omi_backend.plugins.events import (
- BUTTON_STATE_TO_EVENT,
- ButtonState,
- )
+ from advanced_omi_backend.plugins.events import BUTTON_STATE_TO_EVENT, ButtonState
from advanced_omi_backend.services.plugin_service import get_plugin_router
timestamp = time.time()
@@ -993,7 +1045,6 @@ async def _handle_button_event(
}
client_state.add_marker(marker)
-
# Map device button state to typed plugin event
try:
button_state_enum = ButtonState(button_state)
@@ -1016,18 +1067,14 @@ async def _handle_button_event(
"state": button_state_enum.value,
"timestamp": timestamp,
"audio_uuid": audio_uuid,
- "session_id": getattr(client_state, 'stream_session_id', None),
+ "session_id": getattr(client_state, "stream_session_id", None),
"client_id": client_id,
},
)
async def _process_rolling_batch(
- client_state,
- user_id: str,
- user_email: str,
- client_id: str,
- batch_number: int
+ client_state, user_id: str, user_email: str, client_id: str, batch_number: int
) -> None:
"""
Process accumulated batch audio as a rolling segment.
@@ -1041,7 +1088,10 @@ async def _process_rolling_batch(
client_id: Client ID
batch_number: Sequential batch number (1, 2, 3...)
"""
- if not hasattr(client_state, 'batch_audio_chunks') or not client_state.batch_audio_chunks:
+ if (
+ not hasattr(client_state, "batch_audio_chunks")
+ or not client_state.batch_audio_chunks
+ ):
application_logger.warning(f"β οΈ No audio chunks to process for rolling batch")
return
@@ -1050,14 +1100,14 @@ async def _process_rolling_batch(
from advanced_omi_backend.utils.audio_chunk_utils import convert_audio_to_chunks
# Combine chunks
- complete_audio = b''.join(client_state.batch_audio_chunks)
+ complete_audio = b"".join(client_state.batch_audio_chunks)
application_logger.info(
f"π¦ Rolling batch #{batch_number}: Combined {len(client_state.batch_audio_chunks)} chunks "
f"into {len(complete_audio)} bytes"
)
# Get audio format
- audio_format = getattr(client_state, 'batch_audio_format', {})
+ audio_format = getattr(client_state, "batch_audio_format", {})
sample_rate = audio_format.get("rate", 16000)
width = audio_format.get("width", 2)
channels = audio_format.get("channels", 1)
@@ -1067,7 +1117,7 @@ async def _process_rolling_batch(
user_id=user_id,
client_id=client_id,
title=f"Recording Part {batch_number}",
- summary="Rolling batch processing..."
+ summary="Rolling batch processing...",
)
await conversation.insert()
conversation_id = conversation.conversation_id # Get the auto-generated ID
@@ -1078,7 +1128,7 @@ async def _process_rolling_batch(
audio_data=complete_audio,
sample_rate=sample_rate,
channels=channels,
- sample_width=width
+ sample_width=width,
)
# Enqueue transcription job
@@ -1104,7 +1154,11 @@ async def _process_rolling_batch(
result_ttl=JOB_RESULT_TTL,
job_id=transcribe_job_id,
description=f"Transcribe rolling batch #{batch_number} {conversation_id[:8]}",
- meta={'conversation_id': conversation_id, 'client_id': client_id, 'batch_number': batch_number}
+ meta={
+ "conversation_id": conversation_id,
+ "client_id": client_id,
+ "batch_number": batch_number,
+ },
)
application_logger.info(
@@ -1114,16 +1168,12 @@ async def _process_rolling_batch(
except Exception as e:
application_logger.error(
- f"β Failed to process rolling batch #{batch_number}: {e}",
- exc_info=True
+ f"β Failed to process rolling batch #{batch_number}: {e}", exc_info=True
)
async def _process_batch_audio_complete(
- client_state,
- user_id: str,
- user_email: str,
- client_id: str
+ client_state, user_id: str, user_email: str, client_id: str
) -> None:
"""
Process completed batch audio: write file, create conversation, enqueue jobs.
@@ -1134,8 +1184,13 @@ async def _process_batch_audio_complete(
user_email: User email
client_id: Client ID
"""
- if not hasattr(client_state, 'batch_audio_chunks') or not client_state.batch_audio_chunks:
- application_logger.warning(f"β οΈ Batch mode: No audio chunks accumulated for {client_id}")
+ if (
+ not hasattr(client_state, "batch_audio_chunks")
+ or not client_state.batch_audio_chunks
+ ):
+ application_logger.warning(
+ f"β οΈ Batch mode: No audio chunks accumulated for {client_id}"
+ )
return
try:
@@ -1143,7 +1198,7 @@ async def _process_batch_audio_complete(
from advanced_omi_backend.utils.audio_chunk_utils import convert_audio_to_chunks
# Combine all chunks
- complete_audio = b''.join(client_state.batch_audio_chunks)
+ complete_audio = b"".join(client_state.batch_audio_chunks)
application_logger.info(
f"π¦ Batch mode: Combined {len(client_state.batch_audio_chunks)} chunks into {len(complete_audio)} bytes"
)
@@ -1152,17 +1207,15 @@ async def _process_batch_audio_complete(
timestamp = int(time.time() * 1000)
# Get audio format from batch metadata (set during audio-start)
- audio_format = getattr(client_state, 'batch_audio_format', {})
- sample_rate = audio_format.get('rate', OMI_SAMPLE_RATE)
- sample_width = audio_format.get('width', OMI_SAMPLE_WIDTH)
- channels = audio_format.get('channels', OMI_CHANNELS)
+ audio_format = getattr(client_state, "batch_audio_format", {})
+ sample_rate = audio_format.get("rate", OMI_SAMPLE_RATE)
+ sample_width = audio_format.get("width", OMI_SAMPLE_WIDTH)
+ channels = audio_format.get("channels", OMI_CHANNELS)
# Calculate audio duration
duration = len(complete_audio) / (sample_rate * sample_width * channels)
- application_logger.info(
- f"β
Batch mode: Processing audio ({duration:.1f}s)"
- )
+ application_logger.info(f"β
Batch mode: Processing audio ({duration:.1f}s)")
# Create conversation immediately for batch audio (conversation_id auto-generated)
version_id = str(uuid.uuid4())
@@ -1171,7 +1224,7 @@ async def _process_batch_audio_complete(
user_id=user_id,
client_id=client_id,
title="Batch Recording",
- summary="Processing batch audio..."
+ summary="Processing batch audio...",
)
# Attach any markers (e.g., button events) captured during the session
if client_state.markers:
@@ -1180,7 +1233,9 @@ async def _process_batch_audio_complete(
await conversation.insert()
conversation_id = conversation.conversation_id # Get the auto-generated ID
- application_logger.info(f"π Batch mode: Created conversation {conversation_id}")
+ application_logger.info(
+ f"π Batch mode: Created conversation {conversation_id}"
+ )
# Convert audio directly to MongoDB chunks (no disk intermediary)
try:
@@ -1197,8 +1252,7 @@ async def _process_batch_audio_complete(
)
except Exception as chunk_error:
application_logger.error(
- f"Failed to convert batch audio to chunks: {chunk_error}",
- exc_info=True
+ f"Failed to convert batch audio to chunks: {chunk_error}", exc_info=True
)
# Continue anyway - transcription job will handle it
@@ -1226,17 +1280,19 @@ async def _process_batch_audio_complete(
result_ttl=JOB_RESULT_TTL,
job_id=transcribe_job_id,
description=f"Transcribe batch audio {conversation_id[:8]}",
- meta={'conversation_id': conversation_id, 'client_id': client_id}
+ meta={"conversation_id": conversation_id, "client_id": client_id},
)
- application_logger.info(f"π₯ Batch mode: Enqueued transcription job {transcription_job.id}")
+ application_logger.info(
+ f"π₯ Batch mode: Enqueued transcription job {transcription_job.id}"
+ )
# Enqueue post-conversation processing job chain (depends on transcription)
job_ids = start_post_conversation_jobs(
conversation_id=conversation_id,
user_id=None, # Will be read from conversation in DB by jobs
depends_on_job=transcription_job, # Wait for transcription to complete
- client_id=client_id # Pass client_id for UI tracking
+ client_id=client_id, # Pass client_id for UI tracking
)
application_logger.info(
@@ -1251,11 +1307,54 @@ async def _process_batch_audio_complete(
except Exception as batch_error:
application_logger.error(
- f"β Batch mode processing failed: {batch_error}",
- exc_info=True
+ f"β Batch mode processing failed: {batch_error}", exc_info=True
)
+async def _cleanup_websocket_connection(
+ client_id: Optional[str],
+ pending_client_id: str,
+ interim_subscriber_task: Optional[asyncio.Task],
+) -> None:
+ """
+ Shared cleanup for WebSocket handlers (OMI and PCM).
+
+ Cancels the interim results subscriber, removes the pending connection
+ tracking entry, and tears down client state.
+
+ Args:
+ client_id: Actual client ID (may be None if auth failed)
+ pending_client_id: Temporary tracking ID to discard
+ interim_subscriber_task: Background task forwarding interim transcripts
+ """
+ # Cancel interim results subscriber task if running
+ if interim_subscriber_task and not interim_subscriber_task.done():
+ interim_subscriber_task.cancel()
+ try:
+ await interim_subscriber_task
+ except asyncio.CancelledError:
+ application_logger.info(
+ f"Interim subscriber task cancelled for {client_id}"
+ )
+ except Exception as task_error:
+ application_logger.error(
+ f"Error cancelling interim subscriber task: {task_error}"
+ )
+
+ # Clean up pending connection tracking
+ pending_connections.discard(pending_client_id)
+
+ # Ensure cleanup happens even if client_id is None
+ if client_id:
+ try:
+ await cleanup_client_state(client_id)
+ except Exception as cleanup_error:
+ application_logger.error(
+ f"Error during cleanup for client {client_id}: {cleanup_error}",
+ exc_info=True,
+ )
+
+
async def handle_omi_websocket(
ws: WebSocket,
token: Optional[str] = None,
@@ -1294,7 +1393,9 @@ async def handle_omi_websocket(
if header["type"] == "audio-start":
# Handle audio session start
- application_logger.info(f"π΄ BACKEND: Received audio-start in OMI MODE for {client_id} (header={header})")
+ application_logger.info(
+ f"π΄ BACKEND: Received audio-start in OMI MODE for {client_id} (header={header})"
+ )
application_logger.info(f"ποΈ OMI audio session started for {client_id}")
# Store user context on client state
@@ -1308,8 +1409,15 @@ async def handle_omi_websocket(
user.user_id,
user.email,
client_id,
- header.get("data", {"rate": OMI_SAMPLE_RATE, "width": OMI_SAMPLE_WIDTH, "channels": OMI_CHANNELS}),
- websocket=ws # Pass WebSocket to launch interim results subscriber
+ header.get(
+ "data",
+ {
+ "rate": OMI_SAMPLE_RATE,
+ "width": OMI_SAMPLE_WIDTH,
+ "channels": OMI_CHANNELS,
+ },
+ ),
+ websocket=ws, # Pass WebSocket to launch interim results subscriber
)
elif header["type"] == "audio-chunk" and payload:
@@ -1330,7 +1438,7 @@ async def handle_omi_websocket(
_decode_packet,
user.user_id,
client_id,
- packet_count
+ packet_count,
)
# Log progress every 1000th packet
@@ -1352,7 +1460,7 @@ async def handle_omi_websocket(
audio_stream_producer,
user.user_id,
user.email,
- client_id
+ client_id,
)
# Reset counters for next session
@@ -1377,36 +1485,17 @@ async def handle_omi_websocket(
f"π WebSocket disconnected - Client: {client_id}, Packets: {packet_count}, Total bytes: {total_bytes}"
)
except Exception as e:
- application_logger.error(f"β WebSocket error for client {client_id}: {e}", exc_info=True)
+ application_logger.error(
+ f"β WebSocket error for client {client_id}: {e}", exc_info=True
+ )
finally:
- # Cancel interim results subscriber task if running
- if interim_subscriber_task and not interim_subscriber_task.done():
- interim_subscriber_task.cancel()
- try:
- await interim_subscriber_task
- except asyncio.CancelledError:
- application_logger.info(f"Interim subscriber task cancelled for {client_id}")
- except Exception as task_error:
- application_logger.error(f"Error cancelling interim subscriber task: {task_error}")
-
- # Clean up pending connection tracking
- pending_connections.discard(pending_client_id)
-
- # Ensure cleanup happens even if client_id is None
- if client_id:
- try:
- # Clean up client state
- await cleanup_client_state(client_id)
- except Exception as cleanup_error:
- application_logger.error(
- f"Error during cleanup for client {client_id}: {cleanup_error}", exc_info=True
- )
+ await _cleanup_websocket_connection(
+ client_id, pending_client_id, interim_subscriber_task
+ )
async def handle_pcm_websocket(
- ws: WebSocket,
- token: Optional[str] = None,
- device_name: Optional[str] = None
+ ws: WebSocket, token: Optional[str] = None, device_name: Optional[str] = None
):
"""Handle PCM WebSocket connections with batch and streaming mode support."""
# Generate pending client_id to track connection even if auth fails
@@ -1436,14 +1525,24 @@ async def handle_pcm_websocket(
try:
if not audio_streaming:
# Control message mode - parse Wyoming protocol
- application_logger.debug(f"π Control mode for {client_id}, WebSocket state: {ws.client_state if hasattr(ws, 'client_state') else 'unknown'}")
- application_logger.debug(f"π¨ About to receive control message for {client_id}")
+ application_logger.debug(
+ f"π Control mode for {client_id}, WebSocket state: {ws.client_state if hasattr(ws, 'client_state') else 'unknown'}"
+ )
+ application_logger.debug(
+ f"π¨ About to receive control message for {client_id}"
+ )
header, payload = await parse_wyoming_protocol(ws)
- application_logger.debug(f"β
Received message type: {header.get('type')} for {client_id}")
+ application_logger.debug(
+ f"β
Received message type: {header.get('type')} for {client_id}"
+ )
if header["type"] == "audio-start":
- application_logger.info(f"π΄ BACKEND: Received audio-start in CONTROL MODE for {client_id}")
- application_logger.debug(f"ποΈ Processing audio-start for {client_id}")
+ application_logger.info(
+ f"π΄ BACKEND: Received audio-start in CONTROL MODE for {client_id}"
+ )
+ application_logger.debug(
+ f"ποΈ Processing audio-start for {client_id}"
+ )
# Store user context on client state for rolling batch processing
client_state.user_id = user.user_id
@@ -1451,28 +1550,34 @@ async def handle_pcm_websocket(
client_state.client_id = client_id
# Handle audio session start using helper function (pass websocket for error handling)
- audio_streaming, recording_mode = await _handle_audio_session_start(
- client_state,
- header.get("data", {}),
- client_id,
- websocket=ws # Pass websocket for WebUI error display
+ audio_streaming, recording_mode = (
+ await _handle_audio_session_start(
+ client_state,
+ header.get("data", {}),
+ client_id,
+ websocket=ws, # Pass websocket for WebUI error display
+ )
)
# Initialize streaming session
if recording_mode == "streaming":
- application_logger.info(f"π΄ BACKEND: Initializing streaming session for {client_id}")
- interim_subscriber_task = await _initialize_streaming_session(
- client_state,
- audio_stream_producer,
- user.user_id,
- user.email,
- client_id,
- header.get("data", {}),
- websocket=ws
+ application_logger.info(
+ f"π΄ BACKEND: Initializing streaming session for {client_id}"
+ )
+ interim_subscriber_task = (
+ await _initialize_streaming_session(
+ client_state,
+ audio_stream_producer,
+ user.user_id,
+ user.email,
+ client_id,
+ header.get("data", {}),
+ websocket=ws,
+ )
)
continue # Continue to audio streaming mode
-
+
elif header["type"] == "ping":
# Handle keepalive ping from frontend
application_logger.debug(f"π Received ping from {client_id}")
@@ -1492,23 +1597,29 @@ async def handle_pcm_websocket(
f"Ignoring Wyoming control event type '{header['type']}' for {client_id}"
)
continue
-
+
else:
# Audio streaming mode - receive raw bytes (like speaker recognition)
- application_logger.debug(f"π΅ Audio streaming mode for {client_id} - waiting for audio data")
-
+ application_logger.debug(
+ f"π΅ Audio streaming mode for {client_id} - waiting for audio data"
+ )
+
try:
# Receive raw audio bytes or check for control messages
message = await ws.receive()
-
-
+
# Check if it's a disconnect
- if "type" in message and message["type"] == "websocket.disconnect":
+ if (
+ "type" in message
+ and message["type"] == "websocket.disconnect"
+ ):
code = message.get("code", 1000)
reason = message.get("reason", "")
- application_logger.info(f"π WebSocket disconnect during audio streaming for {client_id}. Code: {code}, Reason: {reason}")
+ application_logger.info(
+ f"π WebSocket disconnect during audio streaming for {client_id}. Code: {code}, Reason: {reason}"
+ )
break
-
+
# Check if it's a text message (control message like audio-stop)
if "text" in message:
try:
@@ -1520,22 +1631,28 @@ async def handle_pcm_websocket(
audio_stream_producer,
user.user_id,
user.email,
- client_id
+ client_id,
)
# Reset counters for next session
packet_count = 0
total_bytes = 0
continue
elif control_header.get("type") == "ping":
- application_logger.debug(f"π Received ping during streaming from {client_id}")
+ application_logger.debug(
+ f"π Received ping during streaming from {client_id}"
+ )
continue
elif control_header.get("type") == "audio-start":
# Handle duplicate audio-start messages gracefully (idempotent behavior)
- application_logger.info(f"π Ignoring duplicate audio-start message during streaming for {client_id}")
+ application_logger.info(
+ f"π Ignoring duplicate audio-start message during streaming for {client_id}"
+ )
continue
elif control_header.get("type") == "audio-chunk":
# Handle Wyoming protocol audio-chunk with binary payload
- payload_length = control_header.get("payload_length")
+ payload_length = control_header.get(
+ "payload_length"
+ )
if payload_length and payload_length > 0:
# Receive the binary audio data
payload_msg = await ws.receive()
@@ -1544,10 +1661,14 @@ async def handle_pcm_websocket(
packet_count += 1
total_bytes += len(audio_data)
- application_logger.debug(f"π΅ Received audio chunk #{packet_count}: {len(audio_data)} bytes")
+ application_logger.debug(
+ f"π΅ Received audio chunk #{packet_count}: {len(audio_data)} bytes"
+ )
# Route to appropriate mode handler
- audio_format = control_header.get("data", {})
+ audio_format = control_header.get(
+ "data", {}
+ )
task = await _handle_audio_chunk(
client_state,
audio_stream_producer,
@@ -1556,31 +1677,42 @@ async def handle_pcm_websocket(
user.user_id,
user.email,
client_id,
- websocket=ws
+ websocket=ws,
)
# Store subscriber task if it was created (first streaming chunk)
if task and not interim_subscriber_task:
interim_subscriber_task = task
else:
- application_logger.warning(f"Expected binary payload for audio-chunk, got: {payload_msg.keys()}")
+ application_logger.warning(
+ f"Expected binary payload for audio-chunk, got: {payload_msg.keys()}"
+ )
else:
- application_logger.warning(f"audio-chunk missing payload_length: {payload_length}")
+ application_logger.warning(
+ f"audio-chunk missing payload_length: {payload_length}"
+ )
continue
elif control_header.get("type") == "button-event":
button_data = control_header.get("data", {})
button_state = button_data.get("state", "unknown")
await _handle_button_event(
- client_state, button_state, user.user_id, client_id
+ client_state,
+ button_state,
+ user.user_id,
+ client_id,
)
continue
else:
- application_logger.warning(f"Unknown control message during streaming: {control_header.get('type')}")
+ application_logger.warning(
+ f"Unknown control message during streaming: {control_header.get('type')}"
+ )
continue
except json.JSONDecodeError:
- application_logger.warning(f"Invalid control message during streaming for {client_id}")
+ application_logger.warning(
+ f"Invalid control message during streaming for {client_id}"
+ )
continue
-
+
# Check if it's binary data (raw audio without Wyoming protocol)
elif "bytes" in message:
# Raw binary audio data (legacy support)
@@ -1588,7 +1720,9 @@ async def handle_pcm_websocket(
packet_count += 1
total_bytes += len(audio_data)
- application_logger.debug(f"π΅ Received raw audio chunk #{packet_count}: {len(audio_data)} bytes")
+ application_logger.debug(
+ f"π΅ Received raw audio chunk #{packet_count}: {len(audio_data)} bytes"
+ )
# Route to appropriate mode handler with default format
default_format = {"rate": 16000, "width": 2, "channels": 1}
@@ -1600,18 +1734,22 @@ async def handle_pcm_websocket(
user.user_id,
user.email,
client_id,
- websocket=ws
+ websocket=ws,
)
# Store subscriber task if it was created (first streaming chunk)
if task and not interim_subscriber_task:
interim_subscriber_task = task
-
+
else:
- application_logger.warning(f"Unexpected message format in streaming mode: {message.keys()}")
+ application_logger.warning(
+ f"Unexpected message format in streaming mode: {message.keys()}"
+ )
continue
-
+
except Exception as streaming_error:
- application_logger.error(f"Error in audio streaming mode: {streaming_error}")
+ application_logger.error(
+ f"Error in audio streaming mode: {streaming_error}"
+ )
if "disconnect" in str(streaming_error).lower():
break
continue
@@ -1628,9 +1766,7 @@ async def handle_pcm_websocket(
)
continue # Skip this message but don't disconnect
except ValueError as e:
- application_logger.error(
- f"β Protocol error for {client_id}: {e}"
- )
+ application_logger.error(f"β Protocol error for {client_id}: {e}")
continue # Skip this message but don't disconnect
except RuntimeError as e:
# Handle "Cannot call receive once a disconnect message has been received"
@@ -1646,18 +1782,23 @@ async def handle_pcm_websocket(
continue
except Exception as e:
application_logger.error(
- f"β Unexpected error processing message for {client_id}: {e}", exc_info=True
+ f"β Unexpected error processing message for {client_id}: {e}",
+ exc_info=True,
)
# Check if it's a connection-related error
error_msg = str(e).lower()
- if "disconnect" in error_msg or "closed" in error_msg or "receive" in error_msg:
+ if (
+ "disconnect" in error_msg
+ or "closed" in error_msg
+ or "receive" in error_msg
+ ):
application_logger.info(
f"π Connection issue detected for {client_id}, exiting loop"
)
break
else:
continue # Skip this message for other errors
-
+
except WebSocketDisconnect:
application_logger.info(
f"π PCM WebSocket disconnected - Client: {client_id}, Packets: {packet_count}, Total bytes: {total_bytes}"
@@ -1667,25 +1808,6 @@ async def handle_pcm_websocket(
f"β PCM WebSocket error for client {client_id}: {e}", exc_info=True
)
finally:
- # Cancel interim results subscriber task if running
- if interim_subscriber_task and not interim_subscriber_task.done():
- interim_subscriber_task.cancel()
- try:
- await interim_subscriber_task
- except asyncio.CancelledError:
- application_logger.info(f"Interim subscriber task cancelled for {client_id}")
- except Exception as task_error:
- application_logger.error(f"Error cancelling interim subscriber task: {task_error}")
-
- # Clean up pending connection tracking
- pending_connections.discard(pending_client_id)
-
- # Ensure cleanup happens even if client_id is None
- if client_id:
- try:
- # Clean up client state
- await cleanup_client_state(client_id)
- except Exception as cleanup_error:
- application_logger.error(
- f"Error during cleanup for client {client_id}: {cleanup_error}", exc_info=True
- )
+ await _cleanup_websocket_connection(
+ client_id, pending_client_id, interim_subscriber_task
+ )
diff --git a/backends/advanced/src/advanced_omi_backend/observability/__init__.py b/backends/advanced/src/advanced_omi_backend/observability/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py b/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py
new file mode 100644
index 00000000..70a647b2
--- /dev/null
+++ b/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py
@@ -0,0 +1,105 @@
+"""OpenTelemetry setup with Galileo span processor."""
+
+import logging
+import os
+from functools import lru_cache
+
+logger = logging.getLogger(__name__)
+
+
+@lru_cache(maxsize=1)
+def is_galileo_enabled() -> bool:
+ """Check if Galileo OTEL is configured."""
+ return bool(os.getenv("GALILEO_API_KEY"))
+
+
+_session_token = None
+
+
+def set_galileo_session(session_id: str) -> None:
+ """Set Galileo session ID so subsequent traces are grouped together."""
+ global _session_token
+ if not is_galileo_enabled():
+ return
+ try:
+ from galileo.otel import _session_id_context
+
+ _session_token = _session_id_context.set(session_id)
+ except ImportError:
+ pass
+
+
+def clear_galileo_session() -> None:
+ """Clear the Galileo session ID."""
+ global _session_token
+ if _session_token is None:
+ return
+ try:
+ from galileo.otel import _session_id_context
+
+ _session_id_context.reset(_session_token)
+ _session_token = None
+ except ImportError:
+ pass
+
+
+def init_otel() -> None:
+ """Initialize OTEL with Galileo exporter and OpenAI instrumentor.
+
+ Call once at app startup. Safe to call if Galileo is not configured (no-op).
+ Filters out embedding spans — only LLM (chat completion) calls are exported.
+ """
+ if not is_galileo_enabled():
+ logger.info("Galileo not configured, skipping OTEL initialization")
+ return
+
+ try:
+ from galileo import otel
+ from openinference.instrumentation.openai import OpenAIInstrumentor
+ from opentelemetry import context
+ from opentelemetry.sdk import trace as trace_sdk
+ from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
+
+ project = os.getenv("GALILEO_PROJECT", "chronicle")
+ logstream = os.getenv("GALILEO_LOG_STREAM", "default")
+
+ class _LLMOnlyProcessor(SpanProcessor):
+ """Wraps GalileoSpanProcessor, dropping EMBEDDING spans."""
+
+ def __init__(self, inner: SpanProcessor):
+ self._inner = inner
+
+ def on_start(
+ self, span: Span, parent_context: context.Context | None = None
+ ) -> None:
+ self._inner.on_start(span, parent_context)
+
+ def on_end(self, span: ReadableSpan) -> None:
+ kind = span.attributes.get("openinference.span.kind", "")
+ if kind == "EMBEDDING":
+ return # drop
+ self._inner.on_end(span)
+
+ def shutdown(self) -> None:
+ self._inner.shutdown()
+
+ def force_flush(self, timeout_millis: int = 30000) -> bool:
+ return self._inner.force_flush(timeout_millis)
+
+ tracer_provider = trace_sdk.TracerProvider()
+ galileo_processor = otel.GalileoSpanProcessor(
+ project=project, logstream=logstream
+ )
+ tracer_provider.add_span_processor(_LLMOnlyProcessor(galileo_processor))
+
+ # Auto-instrument all OpenAI SDK calls
+ OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)
+
+ logger.info("OTEL initialized with Galileo exporter + OpenAI instrumentor")
+ except ImportError:
+ logger.warning(
+ "Galileo/OTEL packages not installed. "
+ "Install with: uv pip install '.[galileo]'"
+ )
+ except Exception as e:
+ logger.error(f"Failed to initialize OTEL: {e}")
diff --git a/backends/advanced/src/advanced_omi_backend/plugins/router.py b/backends/advanced/src/advanced_omi_backend/plugins/router.py
index c70ad73b..e06d3043 100644
--- a/backends/advanced/src/advanced_omi_backend/plugins/router.py
+++ b/backends/advanced/src/advanced_omi_backend/plugins/router.py
@@ -37,13 +37,59 @@ def normalize_text_for_wake_word(text: str) -> str:
# Lowercase
text = text.lower()
# Replace punctuation with spaces (instead of removing, to preserve word boundaries)
- text = text.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation)))
+ text = text.translate(
+ str.maketrans(string.punctuation, " " * len(string.punctuation))
+ )
# Normalize whitespace (collapse multiple spaces to single space)
- text = re.sub(r'\s+', ' ', text)
+ text = re.sub(r"\s+", " ", text)
# Strip leading/trailing whitespace
return text.strip()
+def extract_command_around_keyword(transcript: str, keyword: str) -> str:
+ """
+ Extract command by removing a keyword from anywhere in the transcript.
+
+ Handles punctuation and spacing around the keyword gracefully.
+
+ Example:
+ transcript: "Turn off the lights, Vivi"
+ keyword: "vivi"
+ -> "Turn off the lights"
+
+ transcript: "Vivi, turn off the lights in the hall"
+ keyword: "vivi"
+ -> "turn off the lights in the hall"
+
+ transcript: "Turn off the hall lights, Vivi, please"
+ keyword: "vivi"
+ -> "Turn off the hall lights please"
+
+ Args:
+ transcript: Original transcript text
+ keyword: Keyword to remove (will be normalized)
+
+ Returns:
+ Command text with keyword removed
+ """
+ keyword_parts = normalize_text_for_wake_word(keyword).split()
+ if not keyword_parts:
+ return transcript.strip()
+
+ pattern_parts = [re.escape(part) for part in keyword_parts]
+ # Match keyword with optional surrounding punctuation/whitespace
+ kw_pattern = r"[\s,.\-!?;:]*".join(pattern_parts)
+ # Consume adjacent punctuation/whitespace on both sides
+ full_pattern = r"[\s,.\-!?;:]*" + kw_pattern + r"[\s,.\-!?;:]*"
+
+ command = re.sub(
+ full_pattern, " ", transcript, count=1, flags=re.IGNORECASE
+ ).strip()
+ # Collapse any doubled spaces left behind
+ command = re.sub(r"\s{2,}", " ", command)
+ return command
+
+
def extract_command_after_wake_word(transcript: str, wake_word: str) -> str:
"""
Intelligently extract command after wake word in original transcript.
@@ -73,25 +119,28 @@ def extract_command_after_wake_word(transcript: str, wake_word: str) -> str:
# The pattern matches the wake word parts with optional punctuation/whitespace between and after
pattern_parts = [re.escape(part) for part in wake_word_parts]
# Allow optional punctuation/whitespace between parts
- pattern = r'[\s,.\-!?;:]*'.join(pattern_parts)
+ pattern = r"[\s,.\-!?;:]*".join(pattern_parts)
# Add trailing punctuation/whitespace consumption after last wake word part
- pattern = '^' + pattern + r'[\s,.\-!?;:]*'
+ pattern = "^" + pattern + r"[\s,.\-!?;:]*"
# Try to match wake word at start of transcript (case-insensitive)
match = re.match(pattern, transcript, re.IGNORECASE)
if match:
# Extract everything after the matched wake word (including trailing punctuation)
- command = transcript[match.end():].strip()
+ command = transcript[match.end() :].strip()
return command
else:
# Fallback: couldn't find wake word boundary, return full transcript
- logger.warning(f"Could not find wake word boundary for '{wake_word}' in '{transcript}', using full transcript")
+ logger.warning(
+ f"Could not find wake word boundary for '{wake_word}' in '{transcript}', using full transcript"
+ )
return transcript.strip()
class ConditionResult(NamedTuple):
"""Result of a plugin condition check."""
+
execute: bool
extra: Dict[str, Any] = {}
@@ -100,9 +149,9 @@ class PluginHealth:
"""Health status for a single plugin."""
# Possible status values
- REGISTERED = "registered" # Registered but not yet initialized
- INITIALIZED = "initialized" # Successfully initialized
- FAILED = "failed" # initialize() raised an exception
+ REGISTERED = "registered" # Registered but not yet initialized
+ INITIALIZED = "initialized" # Successfully initialized
+ FAILED = "failed" # initialize() raised an exception
def __init__(self, plugin_id: str):
self.plugin_id = plugin_id
@@ -182,11 +231,7 @@ def get_health_summary(self) -> Dict[str, Any]:
}
async def dispatch_event(
- self,
- event: str,
- user_id: str,
- data: Dict,
- metadata: Optional[Dict] = None
+ self, event: str, user_id: str, data: Dict, metadata: Optional[Dict] = None
) -> List[PluginResult]:
"""
Dispatch event to all subscribed plugins.
@@ -212,7 +257,9 @@ async def dispatch_event(
if not plugin_ids:
logger.info(f"π ROUTER: No plugins subscribed to event '{event}'")
else:
- logger.info(f"π ROUTER: Found {len(plugin_ids)} subscribed plugin(s): {plugin_ids}")
+ logger.info(
+ f"π ROUTER: Found {len(plugin_ids)} subscribed plugin(s): {plugin_ids}"
+ )
for plugin_id in plugin_ids:
plugin = self.plugins[plugin_id]
@@ -251,22 +298,34 @@ async def dispatch_event(
f"success={result.success}, message={result.message}"
)
results.append(result)
- executed.append({"plugin_id": plugin_id, "success": result.success, "message": result.message})
+ executed.append(
+ {
+ "plugin_id": plugin_id,
+ "success": result.success,
+ "message": result.message,
+ }
+ )
# If plugin says stop processing, break
if not result.should_continue:
- logger.info(f" β Plugin '{plugin_id}' stopped further processing")
+ logger.info(
+ f" β Plugin '{plugin_id}' stopped further processing"
+ )
break
else:
- logger.info(f" β Plugin '{plugin_id}' returned no result for '{event}'")
+ logger.info(
+ f" β Plugin '{plugin_id}' returned no result for '{event}'"
+ )
except Exception as e:
# CRITICAL: Log exception details
logger.error(
f" β Plugin '{plugin_id}' FAILED with exception: {e}",
- exc_info=True
+ exc_info=True,
+ )
+ executed.append(
+ {"plugin_id": plugin_id, "success": False, "message": str(e)}
)
- executed.append({"plugin_id": plugin_id, "success": False, "message": str(e)})
# Add at end
logger.info(
@@ -287,7 +346,9 @@ async def dispatch_event(
_SKIP = ConditionResult(execute=False)
_PASS = ConditionResult(execute=True)
- async def _should_execute(self, plugin: BasePlugin, data: Dict, event: Optional[str] = None) -> ConditionResult:
+ async def _should_execute(
+ self, plugin: BasePlugin, data: Dict, event: Optional[str] = None
+ ) -> ConditionResult:
"""Check if plugin should be executed based on condition configuration.
Returns a ConditionResult. The ``extra`` dict contains per-plugin data
@@ -297,53 +358,85 @@ async def _should_execute(self, plugin: BasePlugin, data: Dict, event: Optional[
Button events bypass transcript-based conditions (wake_word) since they
have no transcript to match against.
"""
- condition_type = plugin.condition.get('type', 'always')
+ condition_type = plugin.condition.get("type", "always")
- if condition_type == 'always':
+ if condition_type == "always":
return self._PASS
# Button and starred events bypass transcript-based conditions (no transcript to match)
- if event and event in (PluginEvent.BUTTON_SINGLE_PRESS, PluginEvent.BUTTON_DOUBLE_PRESS, PluginEvent.CONVERSATION_STARRED):
+ if event and event in (
+ PluginEvent.BUTTON_SINGLE_PRESS,
+ PluginEvent.BUTTON_DOUBLE_PRESS,
+ PluginEvent.CONVERSATION_STARRED,
+ ):
return self._PASS
- elif condition_type == 'wake_word':
+ elif condition_type == "wake_word":
# Normalize transcript for matching (handles punctuation and spacing)
- transcript = data.get('transcript', '')
+ transcript = data.get("transcript", "")
normalized_transcript = normalize_text_for_wake_word(transcript)
# Support both singular 'wake_word' and plural 'wake_words' (list)
- wake_words = plugin.condition.get('wake_words', [])
+ wake_words = plugin.condition.get("wake_words", [])
if not wake_words:
# Fallback to singular wake_word for backward compatibility
- wake_word = plugin.condition.get('wake_word', '')
+ wake_word = plugin.condition.get("wake_word", "")
if wake_word:
wake_words = [wake_word]
# Check if transcript starts with any wake word (after normalization)
for wake_word in wake_words:
normalized_wake_word = normalize_text_for_wake_word(wake_word)
- if normalized_wake_word and normalized_transcript.startswith(normalized_wake_word):
+ if normalized_wake_word and normalized_transcript.startswith(
+ normalized_wake_word
+ ):
# Smart extraction: find where wake word actually ends in original text
command = extract_command_after_wake_word(transcript, wake_word)
- logger.debug(f"Wake word '{wake_word}' detected. Original: '{transcript}', Command: '{command}'")
+ logger.debug(
+ f"Wake word '{wake_word}' detected. Original: '{transcript}', Command: '{command}'"
+ )
return ConditionResult(
execute=True,
- extra={'command': command, 'original_transcript': transcript},
+ extra={"command": command, "original_transcript": transcript},
)
return self._SKIP
- elif condition_type == 'conditional':
+ elif condition_type == "keyword_anywhere":
+ # Trigger when keyword appears anywhere in the transcript.
+ # Command is the transcript with the keyword removed.
+ transcript = data.get("transcript", "")
+ normalized_transcript = normalize_text_for_wake_word(transcript)
+
+ keywords = plugin.condition.get("keywords", [])
+ if not keywords:
+ keyword = plugin.condition.get("keyword", "")
+ if keyword:
+ keywords = [keyword]
+
+ for keyword in keywords:
+ normalized_keyword = normalize_text_for_wake_word(keyword)
+ if normalized_keyword and normalized_keyword in normalized_transcript:
+ command = extract_command_around_keyword(transcript, keyword)
+ logger.debug(
+ f"Keyword '{keyword}' found in transcript. "
+ f"Original: '{transcript}', Command: '{command}'"
+ )
+ return ConditionResult(
+ execute=True,
+ extra={"command": command, "original_transcript": transcript},
+ )
+
+ return self._SKIP
+
+ elif condition_type == "conditional":
# Future: Custom condition checking
return self._PASS
return self._SKIP
async def _execute_plugin(
- self,
- plugin: BasePlugin,
- event: str,
- context: PluginContext
+ self, plugin: BasePlugin, event: str, context: PluginContext
) -> Optional[PluginResult]:
"""Execute plugin method for specified event"""
# Map events to plugin callback methods using enums
@@ -356,7 +449,10 @@ async def _execute_plugin(
return await plugin.on_memory_processed(context)
elif event == PluginEvent.CONVERSATION_STARRED:
return await plugin.on_conversation_starred(context)
- elif event in (PluginEvent.BUTTON_SINGLE_PRESS, PluginEvent.BUTTON_DOUBLE_PRESS):
+ elif event in (
+ PluginEvent.BUTTON_SINGLE_PRESS,
+ PluginEvent.BUTTON_DOUBLE_PRESS,
+ ):
return await plugin.on_button_event(context)
elif event == PluginEvent.PLUGIN_ACTION:
return await plugin.on_plugin_action(context)
@@ -377,14 +473,16 @@ def _log_event(
if not self._event_redis:
return
try:
- record = json.dumps({
- "timestamp": time.time(),
- "event": event,
- "user_id": user_id,
- "plugins_subscribed": plugins_subscribed,
- "plugins_executed": plugins_executed,
- "metadata": metadata or {},
- })
+ record = json.dumps(
+ {
+ "timestamp": time.time(),
+ "event": event,
+ "user_id": user_id,
+ "plugins_subscribed": plugins_subscribed,
+ "plugins_executed": plugins_executed,
+ "metadata": metadata or {},
+ }
+ )
pipe = self._event_redis.pipeline()
pipe.lpush(self._EVENT_LOG_KEY, record)
pipe.ltrim(self._EVENT_LOG_KEY, 0, self._EVENT_LOG_MAX - 1)
@@ -404,8 +502,9 @@ def clear_events(self) -> int:
logger.debug("Failed to clear events from Redis", exc_info=True)
return 0
-
- def get_recent_events(self, limit: int = 50, event_type: Optional[str] = None) -> List[Dict]:
+ def get_recent_events(
+ self, limit: int = 50, event_type: Optional[str] = None
+ ) -> List[Dict]:
"""Read recent events from the Redis log."""
if not self._event_redis:
return []
@@ -439,9 +538,15 @@ async def check_connectivity(self) -> Dict[str, Dict[str, Any]]:
result = await asyncio.wait_for(plugin.health_check(), timeout=10.0)
results[plugin_id] = result
except asyncio.TimeoutError:
- results[plugin_id] = {"ok": False, "message": "Health check timed out (10s)"}
+ results[plugin_id] = {
+ "ok": False,
+ "message": "Health check timed out (10s)",
+ }
except Exception as e:
- results[plugin_id] = {"ok": False, "message": f"Health check error: {e}"}
+ results[plugin_id] = {
+ "ok": False,
+ "message": f"Health check error: {e}",
+ }
return results
diff --git a/backends/advanced/src/advanced_omi_backend/plugins/services.py b/backends/advanced/src/advanced_omi_backend/plugins/services.py
index dbddfb21..5dc09a98 100644
--- a/backends/advanced/src/advanced_omi_backend/plugins/services.py
+++ b/backends/advanced/src/advanced_omi_backend/plugins/services.py
@@ -10,7 +10,6 @@
import redis.asyncio as aioredis
-
from .base import PluginContext, PluginResult
from .events import ConversationCloseReason, PluginEvent
@@ -44,18 +43,36 @@ async def close_conversation(
Signals the open_conversation_job to close the current conversation
and trigger post-processing. The session stays active for new conversations.
+ Only succeeds when open_conversation_job is actively running and polling
+ (indicated by the conversation:current:{session_id} Redis key). During
+ speech detection phase, no conversation is open — the flag would go unread.
+
Args:
session_id: The streaming session ID (typically same as client_id)
reason: Why the conversation is being closed
Returns:
- True if the close request was set successfully
+ True if the close request was set successfully, False if no
+ conversation is currently open for this session
"""
+ # Gate: only set the flag when open_conversation_job is running and will read it.
+ # The conversation:current key is set right before the polling loop starts.
+ conversation_id = await self._async_redis.get(
+ f"conversation:current:{session_id}"
+ )
+ if not conversation_id:
+ logger.warning(
+ f"No open conversation for session {session_id} β close request ignored"
+ )
+ return False
+
from advanced_omi_backend.controllers.session_controller import (
request_conversation_close,
)
- return await request_conversation_close(self._async_redis, session_id, reason=reason.value)
+ return await request_conversation_close(
+ self._async_redis, session_id, reason=reason.value
+ )
async def star_conversation(self, session_id: str) -> bool:
"""Toggle the star on the current conversation for a session.
@@ -73,13 +90,17 @@ async def star_conversation(self, session_id: str) -> bool:
from advanced_omi_backend.users import User
# Look up current conversation_id from Redis
- conversation_id = await self._async_redis.get(f"conversation:current:{session_id}")
+ conversation_id = await self._async_redis.get(
+ f"conversation:current:{session_id}"
+ )
if not conversation_id:
logger.warning(f"No current conversation for session {session_id}")
return False
# Find conversation to get user_id
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
logger.warning(f"Conversation {conversation_id} not found for starring")
return False
@@ -115,10 +136,14 @@ async def call_plugin(
plugin = self._router.plugins.get(plugin_id)
if not plugin:
logger.warning(f"Plugin '{plugin_id}' not found for cross-plugin call")
- return PluginResult(success=False, message=f"Plugin '{plugin_id}' not found")
+ return PluginResult(
+ success=False, message=f"Plugin '{plugin_id}' not found"
+ )
if not plugin.enabled:
logger.warning(f"Plugin '{plugin_id}' is disabled, cannot call")
- return PluginResult(success=False, message=f"Plugin '{plugin_id}' is disabled")
+ return PluginResult(
+ success=False, message=f"Plugin '{plugin_id}' is disabled"
+ )
context = PluginContext(
user_id=user_id,
@@ -136,5 +161,7 @@ async def call_plugin(
)
return result
except Exception as e:
- logger.error(f"Cross-plugin call to {plugin_id}.{action} failed: {e}", exc_info=True)
+ logger.error(
+ f"Cross-plugin call to {plugin_id}.{action} failed: {e}", exc_info=True
+ )
return PluginResult(success=False, message=f"Plugin action failed: {e}")
diff --git a/backends/advanced/src/advanced_omi_backend/prompt_defaults.py b/backends/advanced/src/advanced_omi_backend/prompt_defaults.py
index 75d89eee..94b03d31 100644
--- a/backends/advanced/src/advanced_omi_backend/prompt_defaults.py
+++ b/backends/advanced/src/advanced_omi_backend/prompt_defaults.py
@@ -524,7 +524,7 @@ def register_all_defaults(registry: PromptRegistry) -> None:
# ------------------------------------------------------------------
registry.register_default(
"asr.hot_words",
- template="hey vivi, chronicle, omi",
+ template="vivi, chronicle, omi",
name="ASR Hot Words",
description="Comma-separated hot words for speech recognition. "
"For Deepgram: boosts keyword recognition via keyterm. "
diff --git a/backends/advanced/src/advanced_omi_backend/prompt_registry.py b/backends/advanced/src/advanced_omi_backend/prompt_registry.py
index eae9c248..5aacb151 100644
--- a/backends/advanced/src/advanced_omi_backend/prompt_registry.py
+++ b/backends/advanced/src/advanced_omi_backend/prompt_registry.py
@@ -39,6 +39,7 @@ def _get_client(self):
if self._langfuse is None:
try:
from langfuse import Langfuse
+
self._langfuse = Langfuse()
except Exception as e:
logger.warning(f"LangFuse client init failed: {e}")
@@ -77,9 +78,10 @@ async def get_prompt(self, prompt_id: str, **variables) -> str:
return template_text
async def seed_prompts(self) -> None:
- """Create prompts in LangFuse if they don't already exist.
+ """Create or update prompts in LangFuse, skipping unchanged ones.
Called once at startup after all defaults have been registered.
+ Only creates a new version when the prompt text has actually changed.
"""
client = self._get_client()
if client is None:
@@ -90,6 +92,19 @@ async def seed_prompts(self) -> None:
skipped = 0
for prompt_id, template_text in self._defaults.items():
try:
+ # Check if the prompt already exists with the same text
+ existing = None
+ try:
+ existing = client.get_prompt(prompt_id)
+ except Exception:
+ pass # Prompt doesn't exist yet
+
+ if existing is not None:
+ existing_text = getattr(existing, "prompt", None)
+ if existing_text == template_text:
+ skipped += 1
+ continue
+
client.create_prompt(
name=prompt_id,
type="text",
@@ -98,13 +113,9 @@ async def seed_prompts(self) -> None:
)
seeded += 1
except Exception as e:
- err_msg = str(e).lower()
- if "already exists" in err_msg or "409" in err_msg:
- skipped += 1
- else:
- logger.warning(f"Failed to seed prompt '{prompt_id}': {e}")
+ logger.warning(f"Failed to seed prompt '{prompt_id}': {e}")
- logger.info(f"Prompt seeding complete: {seeded} created, {skipped} already existed")
+ logger.info(f"Prompt seeding complete: {seeded} created, {skipped} unchanged")
# ---------------------------------------------------------------------------
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py
index 4410665a..d45513fd 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py
@@ -32,11 +32,17 @@ async def list_jobs(
queue_name: str = Query(None, description="Filter by queue name"),
job_type: str = Query(None, description="Filter by job type (matches func_name)"),
client_id: str = Query(None, description="Filter by client_id in meta"),
- current_user: User = Depends(current_active_user)
+ current_user: User = Depends(current_active_user),
):
"""List jobs with pagination and filtering."""
try:
- result = get_jobs(limit=limit, offset=offset, queue_name=queue_name, job_type=job_type, client_id=client_id)
+ result = get_jobs(
+ limit=limit,
+ offset=offset,
+ queue_name=queue_name,
+ job_type=job_type,
+ client_id=client_id,
+ )
# Filter jobs by user if not admin
if not current_user.is_superuser:
@@ -54,13 +60,21 @@ async def list_jobs(
except Exception as e:
logger.error(f"Failed to list jobs: {e}")
- return {"error": "Failed to list jobs", "jobs": [], "pagination": {"total": 0, "limit": limit, "offset": offset, "has_more": False}}
+ return {
+ "error": "Failed to list jobs",
+ "jobs": [],
+ "pagination": {
+ "total": 0,
+ "limit": limit,
+ "offset": offset,
+ "has_more": False,
+ },
+ }
@router.get("/jobs/{job_id}/status")
async def get_job_status(
- job_id: str,
- current_user: User = Depends(current_active_user)
+ job_id: str, current_user: User = Depends(current_active_user)
):
"""Get just the status of a specific job (lightweight endpoint)."""
try:
@@ -79,10 +93,7 @@ async def get_job_status(
logger.error(f"Failed to determine status for job {job_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
- response = {
- "job_id": job.id,
- "status": status
- }
+ response = {"job_id": job.id, "status": status}
# Include error information for failed jobs
if status == "failed" and job.exc_info:
@@ -100,10 +111,7 @@ async def get_job_status(
@router.get("/jobs/{job_id}")
-async def get_job(
- job_id: str,
- current_user: User = Depends(current_active_user)
-):
+async def get_job(job_id: str, current_user: User = Depends(current_active_user)):
"""Get detailed job information including result."""
try:
job = Job.fetch(job_id, connection=redis_conn)
@@ -128,7 +136,7 @@ async def get_job(
"started_at": job.started_at.isoformat() if job.started_at else None,
"ended_at": job.ended_at.isoformat() if job.ended_at else None,
"description": job.description or "",
- "func_name": job.func_name if hasattr(job, 'func_name') else "",
+ "func_name": job.func_name if hasattr(job, "func_name") else "",
"args": job.args,
"kwargs": job.kwargs,
"meta": job.meta if job.meta else {},
@@ -145,10 +153,7 @@ async def get_job(
@router.delete("/jobs/{job_id}")
-async def cancel_job(
- job_id: str,
- current_user: User = Depends(current_active_user)
-):
+async def cancel_job(job_id: str, current_user: User = Depends(current_active_user)):
"""Cancel or delete a job."""
try:
job = Job.fetch(job_id, connection=redis_conn)
@@ -167,7 +172,7 @@ async def cancel_job(
return {
"job_id": job_id,
"action": "canceled",
- "message": f"Job {job_id} has been canceled"
+ "message": f"Job {job_id} has been canceled",
}
else:
# Delete finished/failed jobs
@@ -176,7 +181,7 @@ async def cancel_job(
return {
"job_id": job_id,
"action": "deleted",
- "message": f"Job {job_id} has been deleted"
+ "message": f"Job {job_id} has been deleted",
}
except HTTPException:
@@ -184,13 +189,14 @@ async def cancel_job(
raise
except Exception as e:
logger.error(f"Failed to cancel/delete job {job_id}: {e}")
- raise HTTPException(status_code=404, detail=f"Job not found or could not be canceled: {str(e)}")
+ raise HTTPException(
+ status_code=404, detail=f"Job not found or could not be canceled: {str(e)}"
+ )
@router.get("/jobs/by-client/{client_id}")
async def get_jobs_by_client(
- client_id: str,
- current_user: User = Depends(current_active_user)
+ client_id: str, current_user: User = Depends(current_active_user)
):
"""Get all jobs associated with a specific client device."""
try:
@@ -237,27 +243,37 @@ def process_job_and_dependents(job, queue_name, base_status):
status = get_job_status(job, {})
# Add this job to results
- all_jobs.append({
- "job_id": job.id,
- "job_type": job.func_name.split('.')[-1] if job.func_name else "unknown",
- "queue": queue_name,
- "status": status,
- "created_at": job.created_at.isoformat() if job.created_at else None,
- "started_at": job.started_at.isoformat() if job.started_at else None,
- "ended_at": job.ended_at.isoformat() if job.ended_at else None,
- "description": job.description or "",
- "result": job.result,
- "meta": job.meta if job.meta else {},
- "args": job.args,
- "kwargs": job.kwargs if job.kwargs else {},
- "error_message": str(job.exc_info) if job.exc_info else None,
- })
+ all_jobs.append(
+ {
+ "job_id": job.id,
+ "job_type": (
+ job.func_name.split(".")[-1] if job.func_name else "unknown"
+ ),
+ "queue": queue_name,
+ "status": status,
+ "created_at": (
+ job.created_at.isoformat() if job.created_at else None
+ ),
+ "started_at": (
+ job.started_at.isoformat() if job.started_at else None
+ ),
+ "ended_at": job.ended_at.isoformat() if job.ended_at else None,
+ "description": job.description or "",
+ "result": job.result,
+ "meta": job.meta if job.meta else {},
+ "args": job.args,
+ "kwargs": job.kwargs if job.kwargs else {},
+ "error_message": str(job.exc_info) if job.exc_info else None,
+ }
+ )
# Check for dependent jobs (jobs that depend on this one)
try:
dependent_ids = job.dependent_ids
if dependent_ids:
- logger.debug(f"Job {job.id} has {len(dependent_ids)} dependents: {dependent_ids}")
+ logger.debug(
+ f"Job {job.id} has {len(dependent_ids)} dependents: {dependent_ids}"
+ )
for dep_id in dependent_ids:
try:
@@ -276,12 +292,21 @@ def process_job_and_dependents(job, queue_name, base_status):
# Check all registries (using RQ standard status names)
registries = [
("queued", queue.job_ids),
- ("started", StartedJobRegistry(queue=queue).get_job_ids()), # RQ standard
- ("finished", FinishedJobRegistry(queue=queue).get_job_ids()), # RQ standard
+ (
+ "started",
+ StartedJobRegistry(queue=queue).get_job_ids(),
+ ), # RQ standard
+ (
+ "finished",
+ FinishedJobRegistry(queue=queue).get_job_ids(),
+ ), # RQ standard
("failed", FailedJobRegistry(queue=queue).get_job_ids()),
- ("canceled", CanceledJobRegistry(queue=queue).get_job_ids()), # RQ standard (US spelling)
+ (
+ "canceled",
+ CanceledJobRegistry(queue=queue).get_job_ids(),
+ ), # RQ standard (US spelling)
("deferred", DeferredJobRegistry(queue=queue).get_job_ids()),
- ("scheduled", ScheduledJobRegistry(queue=queue).get_job_ids())
+ ("scheduled", ScheduledJobRegistry(queue=queue).get_job_ids()),
]
for status_name, job_ids in registries:
@@ -293,8 +318,8 @@ def process_job_and_dependents(job, queue_name, base_status):
matches_client = False
# Check job.meta for client_id (current standard)
- if job.meta and 'client_id' in job.meta:
- if job.meta['client_id'] == client_id:
+ if job.meta and "client_id" in job.meta:
+ if job.meta["client_id"] == client_id:
matches_client = True
if matches_client:
@@ -308,17 +333,17 @@ def process_job_and_dependents(job, queue_name, base_status):
# Sort by created_at
all_jobs.sort(key=lambda x: x["created_at"] or "", reverse=False)
- logger.info(f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)")
+ logger.info(
+ f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)"
+ )
- return {
- "client_id": client_id,
- "jobs": all_jobs,
- "total": len(all_jobs)
- }
+ return {"client_id": client_id, "jobs": all_jobs, "total": len(all_jobs)}
except Exception as e:
logger.error(f"Failed to get jobs for client {client_id}: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to get jobs for client: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to get jobs for client: {str(e)}"
+ )
@router.get("/events")
@@ -338,7 +363,9 @@ async def get_events(
if not router_instance:
return {"events": [], "total": 0}
- events = router_instance.get_recent_events(limit=limit, event_type=event_type or None)
+ events = router_instance.get_recent_events(
+ limit=limit, event_type=event_type or None
+ )
return {"events": events, "total": len(events)}
except Exception as e:
logger.error(f"Failed to get events: {e}")
@@ -408,11 +435,8 @@ async def clear_events(
raise HTTPException(status_code=500, detail=f"Failed to clear events: {str(e)}")
-
@router.get("/stats")
-async def get_queue_stats_endpoint(
- current_user: User = Depends(current_active_user)
-):
+async def get_queue_stats_endpoint(current_user: User = Depends(current_active_user)):
"""Get queue statistics."""
try:
stats = get_job_stats()
@@ -420,13 +444,19 @@ async def get_queue_stats_endpoint(
except Exception as e:
logger.error(f"Failed to get queue stats: {e}")
- return {"total_jobs": 0, "queued_jobs": 0, "started_jobs": 0, "finished_jobs": 0, "failed_jobs": 0, "canceled_jobs": 0, "deferred_jobs": 0}
+ return {
+ "total_jobs": 0,
+ "queued_jobs": 0,
+ "started_jobs": 0,
+ "finished_jobs": 0,
+ "failed_jobs": 0,
+ "canceled_jobs": 0,
+ "deferred_jobs": 0,
+ }
@router.get("/worker-details")
-async def get_queue_worker_details(
- current_user: User = Depends(current_active_user)
-):
+async def get_queue_worker_details(current_user: User = Depends(current_active_user)):
"""Get detailed queue and worker status including task manager health."""
try:
import time
@@ -443,34 +473,34 @@ async def get_queue_worker_details(
"total": queue_health.get("total_workers", 0),
"active": queue_health.get("active_workers", 0),
"idle": queue_health.get("idle_workers", 0),
- "details": queue_health.get("workers", [])
+ "details": queue_health.get("workers", []),
},
"queues": queue_health.get("queues", {}),
- "redis_connection": queue_health.get("redis_connection", "unknown")
+ "redis_connection": queue_health.get("redis_connection", "unknown"),
}
return status
except Exception as e:
logger.error(f"Failed to get queue worker details: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to get worker details: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to get worker details: {str(e)}"
+ )
@router.get("/streams")
async def get_stream_stats(
limit: int = Query(default=10, ge=1, le=100), # Max 100 streams to prevent timeouts
- current_user: User = Depends(current_active_user)
+ current_user: User = Depends(current_active_user),
):
"""Get Redis Streams statistics with consumer group information."""
try:
from advanced_omi_backend.services.audio_service import get_audio_stream_service
+
audio_service = get_audio_stream_service()
if not audio_service.redis:
- return {
- "error": "Audio stream service not connected",
- "streams": []
- }
+ return {"error": "Audio stream service not connected", "streams": []}
# Get audio streams with limit
stream_keys = []
@@ -479,14 +509,16 @@ async def get_stream_stats(
cursor, keys = await audio_service.redis.scan(
cursor, match=f"{audio_service.audio_stream_prefix}*", count=limit
)
- stream_keys.extend(keys[:limit - len(stream_keys)])
+ stream_keys.extend(keys[: limit - len(stream_keys)])
# Use asyncio.gather to fetch stream info in parallel
import asyncio
async def get_stream_info(stream_key):
try:
- stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ stream_name = (
+ stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ )
# Get basic stream info
info = await audio_service.redis.xinfo_stream(stream_name)
@@ -499,9 +531,13 @@ async def get_stream_info(stream_key):
group_dict = {}
# Parse group info (alternating key-value pairs)
for i in range(0, len(group), 2):
- if i+1 < len(group):
- key = group[i].decode() if isinstance(group[i], bytes) else str(group[i])
- value = group[i+1]
+ if i + 1 < len(group):
+ key = (
+ group[i].decode()
+ if isinstance(group[i], bytes)
+ else str(group[i])
+ )
+ value = group[i + 1]
if isinstance(value, bytes):
try:
value = value.decode()
@@ -512,13 +548,19 @@ async def get_stream_info(stream_key):
# Get consumers for this group
consumers = []
try:
- consumers_raw = await audio_service.redis.xinfo_consumers(stream_name, group_dict.get('name', ''))
+ consumers_raw = await audio_service.redis.xinfo_consumers(
+ stream_name, group_dict.get("name", "")
+ )
for consumer in consumers_raw:
consumer_dict = {}
for i in range(0, len(consumer), 2):
- if i+1 < len(consumer):
- key = consumer[i].decode() if isinstance(consumer[i], bytes) else str(consumer[i])
- value = consumer[i+1]
+ if i + 1 < len(consumer):
+ key = (
+ consumer[i].decode()
+ if isinstance(consumer[i], bytes)
+ else str(consumer[i])
+ )
+ value = consumer[i + 1]
if isinstance(value, bytes):
try:
value = value.decode()
@@ -527,46 +569,56 @@ async def get_stream_info(stream_key):
consumer_dict[key] = value
consumers.append(consumer_dict)
except Exception as ce:
- logger.debug(f"Could not fetch consumers for group {group_dict.get('name')}: {ce}")
-
- groups_info.append({
- "name": group_dict.get('name', 'unknown'),
- "consumers": group_dict.get('consumers', 0),
- "pending": group_dict.get('pending', 0),
- "last_delivered_id": group_dict.get('last-delivered-id', 'N/A'),
- "consumer_details": consumers
- })
+ logger.debug(
+ f"Could not fetch consumers for group {group_dict.get('name')}: {ce}"
+ )
+
+ groups_info.append(
+ {
+ "name": group_dict.get("name", "unknown"),
+ "consumers": group_dict.get("consumers", 0),
+ "pending": group_dict.get("pending", 0),
+ "last_delivered_id": group_dict.get(
+ "last-delivered-id", "N/A"
+ ),
+ "consumer_details": consumers,
+ }
+ )
except Exception as ge:
logger.debug(f"No consumer groups for stream {stream_name}: {ge}")
return {
"stream_name": stream_name,
"length": info[b"length"],
- "first_entry_id": info[b"first-entry"][0].decode() if info[b"first-entry"] else None,
- "last_entry_id": info[b"last-entry"][0].decode() if info[b"last-entry"] else None,
- "groups": groups_info
+ "first_entry_id": (
+ info[b"first-entry"][0].decode()
+ if info[b"first-entry"]
+ else None
+ ),
+ "last_entry_id": (
+ info[b"last-entry"][0].decode() if info[b"last-entry"] else None
+ ),
+ "groups": groups_info,
}
except Exception as e:
logger.error(f"Error getting info for stream {stream_key}: {e}")
return None
# Fetch all stream info in parallel
- streams_info_results = await asyncio.gather(*[get_stream_info(key) for key in stream_keys])
+ streams_info_results = await asyncio.gather(
+ *[get_stream_info(key) for key in stream_keys]
+ )
streams_info = [info for info in streams_info_results if info is not None]
return {
"total_streams": len(streams_info),
"streams": streams_info,
- "limited": len(stream_keys) >= limit
+ "limited": len(stream_keys) >= limit,
}
except Exception as e:
logger.error(f"Failed to get stream stats: {e}", exc_info=True)
- return {
- "error": str(e),
- "total_streams": 0,
- "streams": []
- }
+ return {"error": str(e), "total_streams": 0, "streams": []}
class FlushJobsRequest(BaseModel):
@@ -582,8 +634,7 @@ class FlushAllJobsRequest(BaseModel):
@router.post("/flush")
async def flush_jobs(
- request: FlushJobsRequest,
- current_user: User = Depends(current_active_user)
+ request: FlushJobsRequest, current_user: User = Depends(current_active_user)
):
"""Flush old inactive jobs based on age and status."""
if not current_user.is_superuser:
@@ -600,7 +651,9 @@ async def flush_jobs(
from advanced_omi_backend.controllers.queue_controller import get_queue
- cutoff_time = datetime.now(timezone.utc) - timedelta(hours=request.older_than_hours)
+ cutoff_time = datetime.now(timezone.utc) - timedelta(
+ hours=request.older_than_hours
+ )
total_removed = 0
# Get all queues
@@ -632,7 +685,9 @@ async def flush_jobs(
except Exception as e:
logger.error(f"Error deleting job {job_id}: {e}")
- if "canceled" in request.statuses: # RQ standard (US spelling), not "cancelled"
+ if (
+ "canceled" in request.statuses
+ ): # RQ standard (US spelling), not "cancelled"
registry = CanceledJobRegistry(queue=queue)
for job_id in registry.get_job_ids():
try:
@@ -646,7 +701,7 @@ async def flush_jobs(
return {
"total_removed": total_removed,
"cutoff_time": cutoff_time.isoformat(),
- "statuses": request.statuses
+ "statuses": request.statuses,
}
except Exception as e:
@@ -656,8 +711,7 @@ async def flush_jobs(
@router.post("/flush-all")
async def flush_all_jobs(
- request: FlushAllJobsRequest,
- current_user: User = Depends(current_active_user)
+ request: FlushAllJobsRequest, current_user: User = Depends(current_active_user)
):
"""
Flush jobs from queues and registries.
@@ -696,10 +750,16 @@ async def flush_all_jobs(
# Build list of registries to flush based on request parameters
registries = [
- ("started", StartedJobRegistry(queue=queue)), # Always flush in-progress
+ (
+ "started",
+ StartedJobRegistry(queue=queue),
+ ), # Always flush in-progress
("deferred", DeferredJobRegistry(queue=queue)), # Always flush deferred
- ("scheduled", ScheduledJobRegistry(queue=queue)), # Always flush scheduled
- ("canceled", CanceledJobRegistry(queue=queue)) # Always flush canceled
+ (
+ "scheduled",
+ ScheduledJobRegistry(queue=queue),
+ ), # Always flush scheduled
+ ("canceled", CanceledJobRegistry(queue=queue)), # Always flush canceled
]
# Conditionally add failed and finished registries
@@ -709,8 +769,12 @@ async def flush_all_jobs(
registries.append(("finished", FinishedJobRegistry(queue=queue)))
for registry_name, registry in registries:
- job_ids = list(registry.get_job_ids()) # Convert to list to avoid iterator issues
- logger.info(f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}")
+ job_ids = list(
+ registry.get_job_ids()
+ ) # Convert to list to avoid iterator issues
+ logger.info(
+ f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}"
+ )
for job_id in job_ids:
try:
@@ -720,7 +784,9 @@ async def flush_all_jobs(
# Skip session-level jobs (e.g., speech_detection, audio_persistence)
# These run for the entire session and should not be killed by test cleanup
if job.meta and job.meta.get("session_level"):
- logger.info(f"Skipping session-level job {job_id} ({job.description})")
+ logger.info(
+ f"Skipping session-level job {job_id} ({job.description})"
+ )
continue
# Handle running jobs differently to avoid worker deadlock
@@ -729,19 +795,28 @@ async def flush_all_jobs(
# This lets the worker clean up gracefully and prevents deadlock
try:
from rq.command import send_stop_job_command
+
send_stop_job_command(redis_conn, job_id)
- logger.info(f"Sent stop command to worker for job {job_id}")
+ logger.info(
+ f"Sent stop command to worker for job {job_id}"
+ )
# Don't delete yet - let worker move it to canceled/failed registry
# It will be cleaned up on next flush or by worker cleanup
continue
except Exception as stop_error:
- logger.warning(f"Could not send stop command to job {job_id}: {stop_error}")
+ logger.warning(
+ f"Could not send stop command to job {job_id}: {stop_error}"
+ )
# If stop fails, try to cancel it (may already be finishing)
try:
job.cancel()
- logger.info(f"Cancelled job {job_id} after stop failed")
+ logger.info(
+ f"Cancelled job {job_id} after stop failed"
+ )
except Exception as cancel_error:
- logger.warning(f"Could not cancel job {job_id}: {cancel_error}")
+ logger.warning(
+ f"Could not cancel job {job_id}: {cancel_error}"
+ )
# For non-running jobs, safe to delete immediately
job.delete()
@@ -752,22 +827,29 @@ async def flush_all_jobs(
logger.warning(f"Error deleting job {job_id}: {e}")
try:
registry.remove(job_id)
- logger.info(f"Removed stale job reference {job_id} from {registry_name} registry")
+ logger.info(
+ f"Removed stale job reference {job_id} from {registry_name} registry"
+ )
except Exception as reg_error:
- logger.error(f"Could not remove {job_id} from registry: {reg_error}")
+ logger.error(
+ f"Could not remove {job_id} from registry: {reg_error}"
+ )
# Also clean up audio streams and consumer locks
deleted_keys = 0
# Get async Redis connection for scanning
from advanced_omi_backend.controllers.queue_controller import REDIS_URL
+
async_redis = await aioredis.from_url(REDIS_URL)
try:
# Delete audio streams
cursor = 0
while True:
- cursor, keys = await async_redis.scan(cursor, match="audio:*", count=1000)
+ cursor, keys = await async_redis.scan(
+ cursor, match="audio:*", count=1000
+ )
if keys:
await async_redis.delete(*keys)
deleted_keys += len(keys)
@@ -777,7 +859,9 @@ async def flush_all_jobs(
# Delete consumer locks
cursor = 0
while True:
- cursor, keys = await async_redis.scan(cursor, match="consumer:*", count=1000)
+ cursor, keys = await async_redis.scan(
+ cursor, match="consumer:*", count=1000
+ )
if keys:
await async_redis.delete(*keys)
deleted_keys += len(keys)
@@ -793,24 +877,28 @@ async def flush_all_jobs(
preserved.append("finished jobs")
preserved_msg = f" (preserved {', '.join(preserved)})" if preserved else ""
- logger.info(f"Flushed {total_removed} jobs and {deleted_keys} Redis keys from all queues{preserved_msg}")
+ logger.info(
+ f"Flushed {total_removed} jobs and {deleted_keys} Redis keys from all queues{preserved_msg}"
+ )
return {
"total_removed": total_removed,
"deleted_keys": deleted_keys,
"preserved": preserved,
- "message": f"Flushed {total_removed} jobs{preserved_msg}"
+ "message": f"Flushed {total_removed} jobs{preserved_msg}",
}
except Exception as e:
logger.error(f"Failed to flush all jobs: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to flush all jobs: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to flush all jobs: {str(e)}"
+ )
@router.get("/sessions")
async def get_redis_sessions(
limit: int = Query(default=20, ge=1, le=100),
- current_user: User = Depends(current_active_user)
+ current_user: User = Depends(current_active_user),
):
"""Get Redis session tracking information."""
try:
@@ -827,7 +915,7 @@ async def get_redis_sessions(
cursor, keys = await redis_client.scan(
cursor, match="audio:session:*", count=limit
)
- session_keys.extend(keys[:limit - len(session_keys)])
+ session_keys.extend(keys[: limit - len(session_keys)])
# Get session info
sessions = []
@@ -839,31 +927,43 @@ async def get_redis_sessions(
# Get conversation count for this session
conversation_count_key = f"session:conversation_count:{session_id}"
- conversation_count_bytes = await redis_client.get(conversation_count_key)
- conversation_count = int(conversation_count_bytes.decode()) if conversation_count_bytes else 0
-
- sessions.append({
- "session_id": session_id,
- "user_id": session_data.get(b"user_id", b"").decode(),
- "client_id": session_data.get(b"client_id", b"").decode(),
- "stream_name": session_data.get(b"stream_name", b"").decode(),
- "provider": session_data.get(b"provider", b"").decode(),
- "mode": session_data.get(b"mode", b"").decode(),
- "status": session_data.get(b"status", b"").decode(),
- "started_at": session_data.get(b"started_at", b"").decode(),
- "chunks_published": int(session_data.get(b"chunks_published", b"0").decode() or 0),
- "last_chunk_at": session_data.get(b"last_chunk_at", b"").decode(),
- "conversation_count": conversation_count
- })
+ conversation_count_bytes = await redis_client.get(
+ conversation_count_key
+ )
+ conversation_count = (
+ int(conversation_count_bytes.decode())
+ if conversation_count_bytes
+ else 0
+ )
+
+ sessions.append(
+ {
+ "session_id": session_id,
+ "user_id": session_data.get(b"user_id", b"").decode(),
+ "client_id": session_data.get(b"client_id", b"").decode(),
+ "stream_name": session_data.get(
+ b"stream_name", b""
+ ).decode(),
+ "provider": session_data.get(b"provider", b"").decode(),
+ "mode": session_data.get(b"mode", b"").decode(),
+ "status": session_data.get(b"status", b"").decode(),
+ "started_at": session_data.get(b"started_at", b"").decode(),
+ "chunks_published": int(
+ session_data.get(b"chunks_published", b"0").decode()
+ or 0
+ ),
+ "last_chunk_at": session_data.get(
+ b"last_chunk_at", b""
+ ).decode(),
+ "conversation_count": conversation_count,
+ }
+ )
except Exception as e:
logger.error(f"Error getting session info for {key}: {e}")
await redis_client.close()
- return {
- "total_sessions": len(sessions),
- "sessions": sessions
- }
+ return {"total_sessions": len(sessions), "sessions": sessions}
except Exception as e:
logger.error(f"Failed to get sessions: {e}", exc_info=True)
@@ -872,8 +972,10 @@ async def get_redis_sessions(
@router.post("/sessions/clear")
async def clear_old_sessions(
- older_than_seconds: int = Query(default=3600, description="Clear sessions older than N seconds"),
- current_user: User = Depends(current_active_user)
+ older_than_seconds: int = Query(
+ default=3600, description="Clear sessions older than N seconds"
+ ),
+ current_user: User = Depends(current_active_user),
):
"""Clear old Redis sessions that are stuck or inactive."""
if not current_user.is_superuser:
@@ -894,7 +996,9 @@ async def clear_old_sessions(
session_keys = []
cursor = b"0"
while cursor:
- cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=100)
+ cursor, keys = await redis_client.scan(
+ cursor, match="audio:session:*", count=100
+ )
session_keys.extend(keys)
# Check each session and delete if old
@@ -915,21 +1019,22 @@ async def clear_old_sessions(
await redis_client.close()
- return {
- "deleted_count": deleted_count,
- "cutoff_seconds": older_than_seconds
- }
+ return {"deleted_count": deleted_count, "cutoff_seconds": older_than_seconds}
except Exception as e:
logger.error(f"Failed to clear sessions: {e}", exc_info=True)
- raise HTTPException(status_code=500, detail=f"Failed to clear sessions: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to clear sessions: {str(e)}"
+ )
@router.get("/dashboard")
async def get_dashboard_data(
request: Request,
- expanded_clients: str = Query(default="", description="Comma-separated list of client IDs to fetch jobs for"),
- current_user: User = Depends(current_active_user)
+ expanded_clients: str = Query(
+ default="", description="Comma-separated list of client IDs to fetch jobs for"
+ ),
+ current_user: User = Depends(current_active_user),
):
"""Get all data needed for the Queue dashboard in a single API call.
@@ -950,7 +1055,11 @@ async def get_dashboard_data(
from advanced_omi_backend.controllers.queue_controller import get_queue
# Parse expanded clients list
- expanded_client_ids = [c.strip() for c in expanded_clients.split(",") if c.strip()] if expanded_clients else []
+ expanded_client_ids = (
+ [c.strip() for c in expanded_clients.split(",") if c.strip()]
+ if expanded_clients
+ else []
+ )
# Fetch all data in parallel
import asyncio
@@ -968,11 +1077,17 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100):
if status_name == "queued":
job_ids = queue.job_ids[:limit]
elif status_name == "started": # RQ standard, not "processing"
- job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[:limit]
+ job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[
+ :limit
+ ]
elif status_name == "finished": # RQ standard, not "completed"
- job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[:limit]
+ job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[
+ :limit
+ ]
elif status_name == "failed":
- job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[:limit]
+ job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[
+ :limit
+ ]
else:
continue
@@ -983,31 +1098,61 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100):
# Check user permission
if not current_user.is_superuser:
- job_user_id = job.kwargs.get("user_id") if job.kwargs else None
+ job_user_id = (
+ job.kwargs.get("user_id") if job.kwargs else None
+ )
if job_user_id != str(current_user.user_id):
continue
# Add job with metadata
- all_jobs.append({
- "job_id": job.id,
- "job_type": job.func_name.split('.')[-1] if job.func_name else "unknown",
- "user_id": job.kwargs.get("user_id") if job.kwargs else None,
- "status": status_name,
- "priority": "normal", # RQ doesn't have priority concept
- "data": {"description": job.description or ""},
- "result": job.result,
- "meta": job.meta if job.meta else {},
- "kwargs": job.kwargs if job.kwargs else {},
- "error_message": str(job.exc_info) if job.exc_info else None,
- "created_at": job.created_at.isoformat() if job.created_at else None,
- "started_at": job.started_at.isoformat() if job.started_at else None,
- "ended_at": job.ended_at.isoformat() if job.ended_at else None,
- "retry_count": 0, # RQ doesn't track this by default
- "max_retries": 0,
- "progress_percent": 0,
- "progress_message": "",
- "queue": queue_name
- })
+ all_jobs.append(
+ {
+ "job_id": job.id,
+ "job_type": (
+ job.func_name.split(".")[-1]
+ if job.func_name
+ else "unknown"
+ ),
+ "user_id": (
+ job.kwargs.get("user_id")
+ if job.kwargs
+ else None
+ ),
+ "status": status_name,
+ "priority": "normal", # RQ doesn't have priority concept
+ "data": {"description": job.description or ""},
+ "result": job.result,
+ "meta": job.meta if job.meta else {},
+ "kwargs": job.kwargs if job.kwargs else {},
+ "error_message": (
+ str(job.exc_info) if job.exc_info else None
+ ),
+ "created_at": (
+ job.created_at.isoformat()
+ if job.created_at
+ else None
+ ),
+ "started_at": (
+ job.started_at.isoformat()
+ if job.started_at
+ else None
+ ),
+ "ended_at": (
+ job.ended_at.isoformat()
+ if job.ended_at
+ else None
+ ),
+ "retry_count": 0, # RQ doesn't track this by default
+ "max_retries": 0,
+ "progress_percent": (job.meta or {})
+ .get("batch_progress", {})
+ .get("percent", 0),
+ "progress_message": (job.meta or {})
+ .get("batch_progress", {})
+ .get("message", ""),
+ "queue": queue_name,
+ }
+ )
except Exception as e:
logger.debug(f"Error fetching job {job_id}: {e}")
continue
@@ -1023,7 +1168,13 @@ async def fetch_stats():
return get_job_stats()
except Exception as e:
logger.error(f"Error fetching stats: {e}")
- return {"total_jobs": 0, "queued_jobs": 0, "started_jobs": 0, "finished_jobs": 0, "failed_jobs": 0}
+ return {
+ "total_jobs": 0,
+ "queued_jobs": 0,
+ "started_jobs": 0,
+ "finished_jobs": 0,
+ "failed_jobs": 0,
+ }
async def fetch_streaming_status():
"""Fetch streaming status."""
@@ -1072,9 +1223,15 @@ def get_job_status(job):
registries = [
("queued", queue.job_ids),
- ("started", StartedJobRegistry(queue=queue).get_job_ids()), # RQ standard
- ("finished", FinishedJobRegistry(queue=queue).get_job_ids()), # RQ standard
- ("failed", FailedJobRegistry(queue=queue).get_job_ids())
+ (
+ "started",
+ StartedJobRegistry(queue=queue).get_job_ids(),
+ ), # RQ standard
+ (
+ "finished",
+ FinishedJobRegistry(queue=queue).get_job_ids(),
+ ), # RQ standard
+ ("failed", FailedJobRegistry(queue=queue).get_job_ids()),
]
for status_name, job_ids in registries:
@@ -1087,7 +1244,11 @@ def get_job_status(job):
# Check if job belongs to this client
matches_client = False
- if job.meta and 'client_id' in job.meta and job.meta['client_id'] == client_id:
+ if (
+ job.meta
+ and "client_id" in job.meta
+ and job.meta["client_id"] == client_id
+ ):
matches_client = True
if not matches_client:
@@ -1095,24 +1256,48 @@ def get_job_status(job):
# Check user permission
if not current_user.is_superuser:
- job_user_id = job.kwargs.get("user_id") if job.kwargs else None
+ job_user_id = (
+ job.kwargs.get("user_id")
+ if job.kwargs
+ else None
+ )
if job_user_id != str(current_user.user_id):
continue
processed_job_ids.add(job_id)
- all_jobs.append({
- "job_id": job.id,
- "job_type": job.func_name.split('.')[-1] if job.func_name else "unknown",
- "queue": queue_name,
- "status": get_job_status(job),
- "created_at": job.created_at.isoformat() if job.created_at else None,
- "started_at": job.started_at.isoformat() if job.started_at else None,
- "ended_at": job.ended_at.isoformat() if job.ended_at else None,
- "description": job.description or "",
- "result": job.result,
- "meta": job.meta if job.meta else {},
- "error_message": str(job.exc_info) if job.exc_info else None
- })
+ all_jobs.append(
+ {
+ "job_id": job.id,
+ "job_type": (
+ job.func_name.split(".")[-1]
+ if job.func_name
+ else "unknown"
+ ),
+ "queue": queue_name,
+ "status": get_job_status(job),
+ "created_at": (
+ job.created_at.isoformat()
+ if job.created_at
+ else None
+ ),
+ "started_at": (
+ job.started_at.isoformat()
+ if job.started_at
+ else None
+ ),
+ "ended_at": (
+ job.ended_at.isoformat()
+ if job.ended_at
+ else None
+ ),
+ "description": job.description or "",
+ "result": job.result,
+ "meta": job.meta if job.meta else {},
+ "error_message": (
+ str(job.exc_info) if job.exc_info else None
+ ),
+ }
+ )
except Exception as e:
logger.debug(f"Error fetching job {job_id}: {e}")
continue
@@ -1127,7 +1312,9 @@ async def fetch_events():
if not current_user.is_superuser:
return []
try:
- from advanced_omi_backend.services.plugin_service import get_plugin_router
+ from advanced_omi_backend.services.plugin_service import (
+ get_plugin_router,
+ )
router_instance = get_plugin_router()
if not router_instance:
@@ -1139,8 +1326,12 @@ async def fetch_events():
# Execute all fetches in parallel (using RQ standard status names)
queued_jobs_task = fetch_jobs_by_status("queued", limit=100)
- started_jobs_task = fetch_jobs_by_status("started", limit=100) # RQ standard, not "processing"
- finished_jobs_task = fetch_jobs_by_status("finished", limit=50) # RQ standard, not "completed"
+ started_jobs_task = fetch_jobs_by_status(
+ "started", limit=100
+ ) # RQ standard, not "processing"
+ finished_jobs_task = fetch_jobs_by_status(
+ "finished", limit=50
+ ) # RQ standard, not "completed"
failed_jobs_task = fetch_jobs_by_status("failed", limit=50)
stats_task = fetch_stats()
streaming_status_task = fetch_streaming_status()
@@ -1156,15 +1347,25 @@ async def fetch_events():
streaming_status_task,
events_task,
*client_jobs_tasks,
- return_exceptions=True
+ return_exceptions=True,
)
queued_jobs = results[0] if not isinstance(results[0], Exception) else []
- started_jobs = results[1] if not isinstance(results[1], Exception) else [] # RQ standard
- finished_jobs = results[2] if not isinstance(results[2], Exception) else [] # RQ standard
+ started_jobs = (
+ results[1] if not isinstance(results[1], Exception) else []
+ ) # RQ standard
+ finished_jobs = (
+ results[2] if not isinstance(results[2], Exception) else []
+ ) # RQ standard
failed_jobs = results[3] if not isinstance(results[3], Exception) else []
- stats = results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0}
- streaming_status = results[5] if not isinstance(results[5], Exception) else {"active_sessions": []}
+ stats = (
+ results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0}
+ )
+ streaming_status = (
+ results[5]
+ if not isinstance(results[5], Exception)
+ else {"active_sessions": []}
+ )
events = results[6] if not isinstance(results[6], Exception) else []
recent_conversations = []
client_jobs_results = results[7:] if len(results) > 7 else []
@@ -1178,30 +1379,40 @@ async def fetch_events():
# Convert conversations to dict format for frontend
conversations_list = []
for conv in recent_conversations:
- conversations_list.append({
- "conversation_id": conv.conversation_id,
- "user_id": str(conv.user_id) if conv.user_id else None,
- "created_at": conv.created_at.isoformat() if conv.created_at else None,
- "title": conv.title,
- "summary": conv.summary,
- "transcript_text": conv.get_active_transcript_text() if hasattr(conv, 'get_active_transcript_text') else None,
- })
+ conversations_list.append(
+ {
+ "conversation_id": conv.conversation_id,
+ "user_id": str(conv.user_id) if conv.user_id else None,
+ "created_at": (
+ conv.created_at.isoformat() if conv.created_at else None
+ ),
+ "title": conv.title,
+ "summary": conv.summary,
+ "transcript_text": (
+ conv.get_active_transcript_text()
+ if hasattr(conv, "get_active_transcript_text")
+ else None
+ ),
+ }
+ )
return {
"jobs": {
"queued": queued_jobs,
"started": started_jobs, # RQ standard status name
"finished": finished_jobs, # RQ standard status name
- "failed": failed_jobs
+ "failed": failed_jobs,
},
"stats": stats,
"streaming_status": streaming_status,
"recent_conversations": conversations_list,
"client_jobs": client_jobs,
"events": events,
- "timestamp": asyncio.get_event_loop().time()
+ "timestamp": asyncio.get_event_loop().time(),
}
except Exception as e:
logger.error(f"Failed to get dashboard data: {e}", exc_info=True)
- raise HTTPException(status_code=500, detail=f"Failed to get dashboard data: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to get dashboard data: {str(e)}"
+ )
diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py
index dc5e9b27..61244f43 100644
--- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py
+++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py
@@ -2,14 +2,12 @@
Audio stream producer - publishes audio chunks to Redis Streams.
"""
+import json
import logging
import time
-import json
import redis.asyncio as redis
-from advanced_omi_backend.services.transcription.base import TranscriptionProvider
-
logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ async def init_session(
user_email: str = "",
connection_id: str = "",
mode: str = "streaming",
- provider: str = "deepgram"
+ provider: str = "deepgram",
):
"""
Initialize session tracking metadata in Redis.
@@ -66,35 +64,32 @@ async def init_session(
stream_name = f"audio:stream:{client_id}"
session_key = f"audio:session:{session_id}"
- await self.redis_client.hset(session_key, mapping={
- # User & Client tracking
- "user_id": user_id,
- "user_email": user_email,
- "client_id": client_id,
- "connection_id": connection_id,
-
- # Stream configuration
- "stream_name": stream_name,
- "provider": provider,
- "mode": mode,
-
- # Timestamps
- "started_at": str(time.time()),
- "last_chunk_at": str(time.time()),
-
- # Counters
- "chunks_published": "0",
-
- # Job tracking (populated by queue_controller when jobs start)
- "speech_detection_job_id": "",
- "audio_persistence_job_id": "",
-
- # Connection state
- "websocket_connected": "true",
-
- # Session status
- "status": "active"
- })
+ await self.redis_client.hset(
+ session_key,
+ mapping={
+ # User & Client tracking
+ "user_id": user_id,
+ "user_email": user_email,
+ "client_id": client_id,
+ "connection_id": connection_id,
+ # Stream configuration
+ "stream_name": stream_name,
+ "provider": provider,
+ "mode": mode,
+ # Timestamps
+ "started_at": str(time.time()),
+ "last_chunk_at": str(time.time()),
+ # Counters
+ "chunks_published": "0",
+ # Job tracking (populated by queue_controller when jobs start)
+ "speech_detection_job_id": "",
+ "audio_persistence_job_id": "",
+ # Connection state
+ "websocket_connected": "true",
+ # Session status
+ "status": "active",
+ },
+ )
# Set TTL of 1 hour
await self.redis_client.expire(session_key, 3600)
@@ -106,10 +101,12 @@ async def init_session(
"user_id": user_id,
"client_id": client_id,
"stream_name": stream_name,
- "provider": provider
+ "provider": provider,
}
- logger.info(f"π Initialized session {session_id} → stream {stream_name} (provider: {provider})")
+ logger.info(
+ f"π Initialized session {session_id} → stream {stream_name} (provider: {provider})"
+ )
async def update_session_chunk_count(self, session_id: str):
"""
@@ -166,10 +163,7 @@ async def send_session_end_signal(self, session_id: str):
}
await self.redis_client.xadd(
- stream_name,
- end_signal,
- maxlen=25000,
- approximate=True
+ stream_name, end_signal, maxlen=25000, approximate=True
)
logger.info(f"π‘ Sent end-of-session signal for {session_id} to {stream_name}")
@@ -187,14 +181,22 @@ async def get_session(self, session_id: str) -> dict:
session_data = await self.redis_client.hgetall(session_key)
# Convert bytes to strings for easier handling
- return {k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
- for k, v in session_data.items()} if session_data else {}
+ return (
+ {
+ k.decode() if isinstance(k, bytes) else k: (
+ v.decode() if isinstance(v, bytes) else v
+ )
+ for k, v in session_data.items()
+ }
+ if session_data
+ else {}
+ )
async def update_session_job_ids(
self,
session_id: str,
speech_detection_job_id: str = None,
- audio_persistence_job_id: str = None
+ audio_persistence_job_id: str = None,
):
"""
Update job IDs in session metadata.
@@ -224,10 +226,13 @@ async def mark_websocket_disconnected(self, session_id: str):
session_id: Session identifier
"""
session_key = f"audio:session:{session_id}"
- await self.redis_client.hset(session_key, mapping={
- "websocket_connected": "false",
- "disconnected_at": str(time.time())
- })
+ await self.redis_client.hset(
+ session_key,
+ mapping={
+ "websocket_connected": "false",
+ "disconnected_at": str(time.time()),
+ },
+ )
logger.info(f"π Marked websocket disconnected for session {session_id}")
async def finalize_session(self, session_id: str, completion_reason: str = None):
@@ -242,15 +247,14 @@ async def finalize_session(self, session_id: str, completion_reason: str = None)
session_key = f"audio:session:{session_id}"
# Build mapping with status and optional completion_reason
- mapping = {
- "status": "finalizing",
- "finalized_at": str(time.time())
- }
+ mapping = {"status": "finalizing", "finalized_at": str(time.time())}
# Set completion_reason atomically with status to prevent race conditions
if completion_reason:
mapping["completion_reason"] = completion_reason
- logger.info(f"π Finalizing session {session_id} with reason: {completion_reason}")
+ logger.info(
+ f"π Finalizing session {session_id} with reason: {completion_reason}"
+ )
await self.redis_client.hset(session_key, mapping=mapping)
@@ -269,10 +273,7 @@ async def finalize_session(self, session_id: str, completion_reason: str = None)
}
await self.redis_client.xadd(
- stream_name,
- end_marker_data,
- maxlen=25000,
- approximate=True
+ stream_name, end_marker_data, maxlen=25000, approximate=True
)
logger.info(f"π‘ Sent end_marker to {stream_name} for session {session_id}")
@@ -286,12 +287,11 @@ async def add_audio_chunk(
self,
audio_data: bytes,
session_id: str,
- chunk_id: str,
user_id: str,
client_id: str,
sample_rate: int = 16000,
channels: int = 1,
- sample_width: int = 2
+ sample_width: int = 2,
) -> list[str]:
"""
Add audio data to session buffer and publish fixed-size chunks.
@@ -302,7 +302,6 @@ async def add_audio_chunk(
Args:
audio_data: Raw PCM audio bytes (arbitrary size from WebSocket)
session_id: Session identifier
- chunk_id: Base chunk identifier (will increment for multiple chunks)
user_id: User identifier
client_id: Client identifier (used for stream naming)
sample_rate: Audio sample rate (Hz)
@@ -321,7 +320,7 @@ async def add_audio_chunk(
"user_id": user_id,
"client_id": client_id,
"stream_name": stream_name,
- "provider": "deepgram"
+ "provider": "deepgram",
}
session_buffer = self.session_buffers[session_id]
@@ -366,7 +365,7 @@ async def add_audio_chunk(
stream_name,
chunk_data,
maxlen=25000, # Keep max 25k chunks (~104 minutes at 250ms/chunk)
- approximate=True
+ approximate=True,
)
message_ids.append(message_id.decode())
@@ -374,7 +373,10 @@ async def add_audio_chunk(
await self.update_session_chunk_count(session_id)
# Log every 10th chunk to avoid spam
- if session_buffer["chunk_count"] % 10 == 0 or session_buffer["chunk_count"] <= 5:
+ if (
+ session_buffer["chunk_count"] % 10 == 0
+ or session_buffer["chunk_count"] <= 5
+ ):
logger.debug(
f"π€ Added fixed-size chunk {chunk_id_formatted} to {stream_name} "
f"({len(chunk_audio)} bytes = {len(chunk_audio)/bytes_per_second:.3f}s, "
@@ -396,7 +398,7 @@ async def flush_session_buffer(
session_id: str,
sample_rate: int = 16000,
channels: int = 1,
- sample_width: int = 2
+ sample_width: int = 2,
) -> str | None:
"""
Flush any remaining audio in session buffer.
@@ -443,10 +445,7 @@ async def flush_session_buffer(
# Add to stream with MAXLEN limit
message_id = await self.redis_client.xadd(
- stream_name,
- chunk_data,
- maxlen=25000,
- approximate=True
+ stream_name, chunk_data, maxlen=25000, approximate=True
)
# Update session tracking
@@ -463,7 +462,6 @@ async def flush_session_buffer(
return None
-
# Singleton instance
_producer_instance = None
@@ -486,12 +484,12 @@ def get_audio_stream_producer() -> AudioStreamProducer:
# Create async Redis client (synchronous call, connection happens on first use)
redis_client = redis_async.from_url(
- redis_url,
- encoding="utf-8",
- decode_responses=False
+ redis_url, encoding="utf-8", decode_responses=False
)
_producer_instance = AudioStreamProducer(redis_client)
- logger.info(f"Created AudioStreamProducer singleton with Redis URL: {redis_url}")
+ logger.info(
+ f"Created AudioStreamProducer singleton with Redis URL: {redis_url}"
+ )
return _producer_instance
diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py
index 84e26f11..804077f4 100644
--- a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py
+++ b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py
@@ -42,13 +42,10 @@ def _parse_hot_words_to_keyterm(hot_words_str: str) -> str:
terms = []
for word in re.split(r"[,\n]+", hot_words_str):
- word = word.strip()
+ word = word.strip().lower()
if not word:
continue
terms.append(word)
- capitalized = word.title()
- if capitalized != word:
- terms.append(capitalized)
return " ".join(terms)
@@ -61,11 +58,11 @@ def _dotted_get(d: dict | list | None, dotted: Optional[str]):
if d is None or not dotted:
return None
cur = d
- for part in dotted.split('.'):
+ for part in dotted.split("."):
if not part:
continue
- if '[' in part and part.endswith(']'):
- name, idx_str = part[:-1].split('[', 1)
+ if "[" in part and part.endswith("]"):
+ name, idx_str = part[:-1].split("[", 1)
if name:
cur = cur.get(name, {}) if isinstance(cur, dict) else {}
try:
@@ -83,6 +80,40 @@ def _dotted_get(d: dict | list | None, dotted: Optional[str]):
return cur
+def _normalize_provider_segments(segments: list) -> list:
+ """Normalize provider-specific segment formats to a standard shape.
+
+ Handles Deepgram paragraph format where:
+ - text is nested in ``sentences[].text`` instead of a top-level ``text`` field
+ - ``speaker`` is an integer (0, 1) instead of a string ("Speaker 0")
+
+ After normalization every segment dict will have:
+ - ``text`` (str): combined sentence text
+ - ``speaker`` (str): "Speaker N" label
+ - ``start`` / ``end`` (float): time span (preserved from original)
+ """
+ if not segments:
+ return segments
+
+ for seg in segments:
+ if not isinstance(seg, dict):
+ continue
+
+ # Deepgram paragraphs: text lives inside sentences[], not top-level
+ if "text" not in seg and "sentences" in seg:
+ sentences = seg.get("sentences", [])
+ seg["text"] = " ".join(
+ s.get("text", "") for s in sentences if isinstance(s, dict)
+ )
+
+ # Normalise integer speaker IDs to "Speaker N" strings
+ speaker = seg.get("speaker")
+ if isinstance(speaker, (int, float)):
+ seg["speaker"] = f"Speaker {int(speaker)}"
+
+ return segments
+
+
class RegistryBatchTranscriptionProvider(BatchTranscriptionProvider):
"""Batch transcription provider driven by config.yml."""
@@ -124,24 +155,33 @@ def get_capabilities_dict(self) -> dict:
"""
return {cap: True for cap in self._capabilities}
- async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False, context_info: Optional[str] = None, **kwargs) -> dict:
+ async def transcribe(
+ self,
+ audio_data: bytes,
+ sample_rate: int,
+ diarize: bool = False,
+ context_info: Optional[str] = None,
+ progress_callback=None,
+ **kwargs,
+ ) -> dict:
# Special handling for mock provider (no HTTP server needed)
if self.model.model_provider == "mock":
from .mock_provider import MockTranscriptionProvider
+
mock = MockTranscriptionProvider(fail_mode=False)
return await mock.transcribe(audio_data, sample_rate, diarize)
op = (self.model.operations or {}).get("stt_transcribe") or {}
method = (op.get("method") or "POST").upper()
- path = (op.get("path") or "/listen")
+ path = op.get("path") or "/listen"
# Build URL
base = self.model.model_url.rstrip("/")
url = base + ("/" + path.lstrip("/"))
-
+
# Check if we should use multipart file upload (for Parakeet)
content_type = op.get("content_type", "audio/raw")
use_multipart = content_type == "multipart/form-data"
-
+
# Build headers (skip Content-Type for multipart as httpx will set it)
headers = {}
if not use_multipart:
@@ -152,7 +192,7 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool =
headers["Content-Type"] = "audio/wav"
else:
headers["Content-Type"] = "audio/raw"
-
+
if self.model.api_key:
# Allow templated header, otherwise fallback to Bearer/Token conventions by config
hdrs = op.get("headers") or {}
@@ -167,7 +207,10 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool =
hdrs = op.get("headers") or {}
for k, v in hdrs.items():
# Skip Authorization headers with empty/invalid values
- if k.lower() == "authorization" and (not v or v.strip().lower() in ["token", "token ", "bearer", "bearer "]):
+ if k.lower() == "authorization" and (
+ not v
+ or v.strip().lower() in ["token", "token ", "bearer", "bearer "]
+ ):
continue
headers[k] = v
@@ -196,24 +239,72 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool =
if keyterm:
query["keyterm"] = keyterm
+ # NOTE: PULSE (smallest.ai) does NOT support keywords on WebSocket or
+ # batch HTTP — any `keywords` query param causes 0 responses or HTTP 400.
+ # Hot-word boosting for PULSE is not injected here.
+
timeout = op.get("timeout", 300)
+ # Use a longer read timeout for NDJSON progress responses β each
+ # batch window can take minutes but the service keeps sending
+ # progress lines between windows.
+ read_timeout = op.get("read_timeout", timeout)
try:
- async with httpx.AsyncClient(timeout=timeout) as client:
+ timeouts = httpx.Timeout(timeout, read=read_timeout)
+ async with httpx.AsyncClient(timeout=timeouts) as client:
if method == "POST":
if use_multipart:
# Send as multipart file upload (for Parakeet/VibeVoice)
files = {"file": ("audio.wav", audio_data, "audio/wav")}
- data = {}
+ form_data = {}
if hot_words_str and hot_words_str.strip():
- data["context_info"] = hot_words_str.strip()
- resp = await client.post(url, headers=headers, params=query, files=files, data=data)
+ form_data["context_info"] = hot_words_str.strip()
+
+ # Use streaming to handle NDJSON progress responses
+ async with client.stream(
+ "POST",
+ url,
+ headers=headers,
+ params=query,
+ files=files,
+ data=form_data,
+ ) as resp:
+ resp.raise_for_status()
+ content_type = resp.headers.get("content-type", "")
+
+ if "application/x-ndjson" in content_type:
+ # Batch progress: read events line by line
+ data = None
+ async for line in resp.aiter_lines():
+ line = line.strip()
+ if not line:
+ continue
+ event = json.loads(line)
+ if (
+ event.get("type") == "progress"
+ and progress_callback
+ ):
+ progress_callback(event)
+ elif event.get("type") == "result":
+ data = event
+ if data is None:
+ raise RuntimeError(
+ f"NDJSON stream from '{self._name}' ended without a result event"
+ )
+ else:
+ # Normal JSON response
+ await resp.aread()
+ data = resp.json()
else:
# Send as raw audio data (for Deepgram)
- resp = await client.post(url, headers=headers, params=query, content=audio_data)
+ resp = await client.post(
+ url, headers=headers, params=query, content=audio_data
+ )
+ resp.raise_for_status()
+ data = resp.json()
else:
resp = await client.get(url, headers=headers, params=query)
- resp.raise_for_status()
- data = resp.json()
+ resp.raise_for_status()
+ data = resp.json()
except httpx.ConnectError as e:
raise ConnectionError(
f"Cannot reach transcription service '{self._name}' at {url}. "
@@ -233,7 +324,9 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool =
channels = data["results"]["channels"]
if channels and "alternatives" in channels[0]:
alt = channels[0]["alternatives"][0]
- logger.debug(f"DEBUG Registry: Deepgram alternative keys: {list(alt.keys())}")
+ logger.debug(
+ f"DEBUG Registry: Deepgram alternative keys: {list(alt.keys())}"
+ )
# Extract normalized shape
text, words, segments = "", [], []
@@ -242,26 +335,36 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool =
text = _dotted_get(data, extract.get("text")) or ""
words = _dotted_get(data, extract.get("words")) or []
segments = _dotted_get(data, extract.get("segments")) or []
+ segments = _normalize_provider_segments(segments)
# Check config to decide whether to keep or discard provider segments
transcription_config = get_backend_config("transcription")
- use_provider_segments = transcription_config.get("use_provider_segments", False)
+ use_provider_segments = transcription_config.get(
+ "use_provider_segments", False
+ )
if not use_provider_segments:
segments = []
- logger.debug(f"Transcription: Extracted {len(words)} words, ignoring provider segments (use_provider_segments=false)")
+ logger.debug(
+ f"Transcription: Extracted {len(words)} words, ignoring provider segments (use_provider_segments=false)"
+ )
else:
- logger.debug(f"Transcription: Extracted {len(words)} words, keeping {len(segments)} provider segments (use_provider_segments=true)")
+ logger.debug(
+ f"Transcription: Extracted {len(words)} words, keeping {len(segments)} provider segments (use_provider_segments=true)"
+ )
return {"text": text, "words": words, "segments": segments}
+
class RegistryStreamingTranscriptionProvider(StreamingTranscriptionProvider):
"""Streaming transcription provider using a config-driven WebSocket template."""
def __init__(self):
registry = get_models_registry()
if not registry:
- raise RuntimeError("config.yml not found; cannot configure streaming STT provider")
+ raise RuntimeError(
+ "config.yml not found; cannot configure streaming STT provider"
+ )
model = registry.get_default("stt_stream")
if not model:
raise RuntimeError("No default stt_stream model defined in config.yml")
@@ -281,9 +384,13 @@ def capabilities(self) -> set:
async def transcribe(self, audio_data: bytes, sample_rate: int, **kwargs) -> dict:
"""Not used for streaming providers - use start_stream/process_audio_chunk/end_stream instead."""
- raise NotImplementedError("Streaming providers do not support batch transcription")
+ raise NotImplementedError(
+ "Streaming providers do not support batch transcription"
+ )
- async def start_stream(self, client_id: str, sample_rate: int = 16000, diarize: bool = False):
+ async def start_stream(
+ self, client_id: str, sample_rate: int = 16000, diarize: bool = False
+ ):
base_url = self.model.model_url
ops = self.model.operations or {}
@@ -309,6 +416,9 @@ async def start_stream(self, client_id: str, sample_rate: int = 16000, diarize:
except Exception as e:
logger.debug(f"Failed to fetch asr.hot_words for streaming: {e}")
+ # NOTE: PULSE/wave (smallest.ai) does NOT support keywords on WebSocket β
+ # any `keywords` query param causes 0 responses or HTTP 400.
+
# Normalize boolean values to lowercase strings (Deepgram expects "true"/"false", not "True"/"False")
normalized_query = {}
for k, v in query_dict.items():
@@ -351,9 +461,16 @@ async def start_stream(self, client_id: str, sample_rate: int = 16000, diarize:
except Exception:
pass
- self._streams[client_id] = {"ws": ws, "sample_rate": sample_rate, "final": None, "interim": []}
+ self._streams[client_id] = {
+ "ws": ws,
+ "sample_rate": sample_rate,
+ "final": None,
+ "interim": [],
+ }
- async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> dict | None:
+ async def process_audio_chunk(
+ self, client_id: str, audio_chunk: bytes
+ ) -> dict | None:
if client_id not in self._streams:
return None
ws = self._streams[client_id]["ws"]
@@ -372,7 +489,7 @@ async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> dict
await ws.send(audio_chunk)
# Non-blocking read for results
- expect = (ops.get("expect", {}) or {})
+ expect = ops.get("expect", {}) or {}
extract = expect.get("extract", {})
interim_type = expect.get("interim_type")
final_type = expect.get("final_type")
@@ -392,16 +509,33 @@ async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> dict
# Fallback: check is_final directly (for providers that don't use a type field)
is_final = data.get("is_final", False)
- # Extract result data
- text = _dotted_get(data, extract.get("text")) if extract.get("text") else data.get("text", "")
- words = _dotted_get(data, extract.get("words")) if extract.get("words") else data.get("words", [])
- segments = _dotted_get(data, extract.get("segments")) if extract.get("segments") else data.get("segments", [])
+ # Extract result data (guard against None from _dotted_get)
+ text = (
+ _dotted_get(data, extract.get("text"))
+ if extract.get("text")
+ else data.get("text", "")
+ ) or ""
+ words = (
+ _dotted_get(data, extract.get("words"))
+ if extract.get("words")
+ else data.get("words", [])
+ ) or []
+ segments = (
+ _dotted_get(data, extract.get("segments"))
+ if extract.get("segments")
+ else data.get("segments", [])
+ ) or []
+ segments = _normalize_provider_segments(segments)
# Calculate confidence if available
confidence = data.get("confidence", 0.0)
if not confidence and words and isinstance(words, list):
# Calculate average word confidence
- confidences = [w.get("confidence", 0.0) for w in words if isinstance(w, dict) and "confidence" in w]
+ confidences = [
+ w.get("confidence", 0.0)
+ for w in words
+ if isinstance(w, dict) and "confidence" in w
+ ]
if confidences:
confidence = sum(confidences) / len(confidences)
@@ -412,7 +546,7 @@ async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> dict
"words": words,
"segments": segments,
"is_final": is_final,
- "confidence": confidence
+ "confidence": confidence,
}
except asyncio.TimeoutError:
@@ -430,7 +564,7 @@ async def end_stream(self, client_id: str) -> dict:
end_msg = (ops.get("end", {}) or {}).get("message", {"type": "stop"})
await ws.send(json.dumps(end_msg))
- expect = (ops.get("expect", {}) or {})
+ expect = ops.get("expect", {}) or {}
final_type = expect.get("final_type")
extract = expect.get("extract", {})
@@ -457,14 +591,29 @@ async def end_stream(self, client_id: str) -> dict:
if not isinstance(final, dict):
return {"text": "", "words": [], "segments": []}
+ segments = (
+ _dotted_get(final, extract.get("segments"))
+ if extract
+ else final.get("segments", [])
+ ) or []
return {
- "text": _dotted_get(final, extract.get("text")) if extract else final.get("text", ""),
- "words": _dotted_get(final, extract.get("words")) if extract else final.get("words", []),
- "segments": _dotted_get(final, extract.get("segments")) if extract else final.get("segments", []),
+ "text": (
+ _dotted_get(final, extract.get("text"))
+ if extract
+ else final.get("text", "")
+ ),
+ "words": (
+ _dotted_get(final, extract.get("words"))
+ if extract
+ else final.get("words", [])
+ ),
+ "segments": _normalize_provider_segments(segments),
}
-def get_transcription_provider(provider_name: Optional[str] = None, mode: Optional[str] = None) -> Optional[BaseTranscriptionProvider]:
+def get_transcription_provider(
+ provider_name: Optional[str] = None, mode: Optional[str] = None
+) -> Optional[BaseTranscriptionProvider]:
"""Return a registry-driven transcription provider.
- mode="batch": HTTP-based STT (default)
@@ -503,7 +652,9 @@ def is_transcription_available(mode: str = "batch") -> bool:
return provider is not None
-def get_mock_transcription_provider(fail_mode: bool = False) -> BaseTranscriptionProvider:
+def get_mock_transcription_provider(
+ fail_mode: bool = False,
+) -> BaseTranscriptionProvider:
"""Return a mock transcription provider (for testing only).
Args:
@@ -513,6 +664,7 @@ def get_mock_transcription_provider(fail_mode: bool = False) -> BaseTranscriptio
MockTranscriptionProvider instance
"""
from .mock_provider import MockTranscriptionProvider
+
return MockTranscriptionProvider(fail_mode=fail_mode)
diff --git a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py
index b1bbb8cd..7c741662 100644
--- a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py
+++ b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py
@@ -21,6 +21,7 @@
)
from advanced_omi_backend.controllers.session_controller import mark_session_complete
from advanced_omi_backend.models.job import async_job
+from advanced_omi_backend.observability.otel_setup import set_galileo_session
from advanced_omi_backend.plugins.events import PluginEvent
from advanced_omi_backend.services.plugin_service import (
ensure_plugin_router,
@@ -106,7 +107,9 @@ async def handle_end_of_conversation(
from advanced_omi_backend.models.conversation import Conversation
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if conversation:
# Convert string to enum
try:
@@ -121,7 +124,9 @@ async def handle_end_of_conversation(
f"πΎ Saved conversation {conversation_id[:12]} end_reason: {conversation.end_reason}"
)
else:
- logger.warning(f"⚠️ Conversation {conversation_id} not found for end reason tracking")
+ logger.warning(
+ f"⚠️ Conversation {conversation_id} not found for end reason tracking"
+ )
# Increment conversation count for this session
conversation_count_key = f"session:conversation_count:{session_id}"
@@ -134,7 +139,9 @@ async def handle_end_of_conversation(
session_status = await redis_client.hget(session_key, "status")
if session_status:
status_str = (
- session_status.decode() if isinstance(session_status, bytes) else session_status
+ session_status.decode()
+ if isinstance(session_status, bytes)
+ else session_status
)
if status_str == "active":
@@ -257,7 +264,9 @@ async def open_conversation_job(
conversation = None
if existing_conversation_id_bytes:
existing_conversation_id = existing_conversation_id_bytes.decode()
- logger.info(f"π Found Redis key with conversation_id={existing_conversation_id}")
+ logger.info(
+ f"π Found Redis key with conversation_id={existing_conversation_id}"
+ )
# Try to fetch the existing conversation by conversation_id
conversation = await Conversation.find_one(
@@ -272,13 +281,16 @@ async def open_conversation_job(
f"processing_status={processing_status}"
)
else:
- logger.warning(f"⚠️ Conversation {existing_conversation_id} not found in database!")
+ logger.warning(
+ f"⚠️ Conversation {existing_conversation_id} not found in database!"
+ )
# Verify it's a placeholder conversation (always_persist=True, processing_status='pending_transcription')
if (
conversation
and getattr(conversation, "always_persist", False)
- and getattr(conversation, "processing_status", None) == "pending_transcription"
+ and getattr(conversation, "processing_status", None)
+ == "pending_transcription"
):
logger.info(
f"π Reusing placeholder conversation {conversation.conversation_id} for session {session_id}"
@@ -297,7 +309,9 @@ async def open_conversation_job(
)
conversation = None
else:
- logger.info(f"π No Redis key found for {conversation_key}, creating new conversation")
+ logger.info(
+ f"π No Redis key found for {conversation_key}, creating new conversation"
+ )
# If no valid placeholder found, create new conversation
if not conversation:
@@ -309,17 +323,23 @@ async def open_conversation_job(
)
await conversation.insert()
conversation_id = conversation.conversation_id
- logger.info(f"✅ Created streaming conversation {conversation_id} for session {session_id}")
+ logger.info(
+ f"✅ Created streaming conversation {conversation_id} for session {session_id}"
+ )
# Attach markers from Redis session (e.g., button events captured during streaming)
session_key = f"audio:session:{session_id}"
markers_json = await redis_client.hget(session_key, "markers")
if markers_json:
try:
- markers_data = markers_json if isinstance(markers_json, str) else markers_json.decode()
+ markers_data = (
+ markers_json if isinstance(markers_json, str) else markers_json.decode()
+ )
conversation.markers = json.loads(markers_data)
await conversation.save()
- logger.info(f"π Attached {len(conversation.markers)} markers to conversation {conversation_id}")
+ logger.info(
+ f"π Attached {len(conversation.markers)} markers to conversation {conversation_id}"
+ )
except Exception as marker_err:
logger.warning(f"⚠️ Failed to parse markers from Redis: {marker_err}")
@@ -334,7 +354,9 @@ async def open_conversation_job(
speaker_check_job_id = speech_job.meta.get("speaker_check_job_id")
if speaker_check_job_id:
try:
- speaker_check_job = Job.fetch(speaker_check_job_id, connection=redis_conn)
+ speaker_check_job = Job.fetch(
+ speaker_check_job_id, connection=redis_conn
+ )
speaker_check_job.meta["conversation_id"] = conversation_id
speaker_check_job.save_meta()
except Exception as e:
@@ -358,7 +380,9 @@ async def open_conversation_job(
# Signal audio persistence job to rotate to this conversation's file
rotation_signal_key = f"conversation:current:{session_id}"
- await redis_client.set(rotation_signal_key, conversation_id, ex=86400) # 24 hour TTL
+ await redis_client.set(
+ rotation_signal_key, conversation_id, ex=86400
+ ) # 24 hour TTL
logger.info(
f"π Signaled audio persistence to rotate file for conversation {conversation_id[:12]}"
)
@@ -368,20 +392,26 @@ async def open_conversation_job(
# Job control
session_key = f"audio:session:{session_id}"
- max_runtime = 10740 # 3 hours - 60 seconds (single conversations shouldn't exceed 3 hours)
+ max_runtime = (
+ 10740 # 3 hours - 60 seconds (single conversations shouldn't exceed 3 hours)
+ )
start_time = time.time()
last_result_count = 0
finalize_received = False
# Inactivity timeout configuration
- inactivity_timeout_seconds = float(os.getenv("SPEECH_INACTIVITY_THRESHOLD_SECONDS", "60"))
+ inactivity_timeout_seconds = float(
+ os.getenv("SPEECH_INACTIVITY_THRESHOLD_SECONDS", "60")
+ )
inactivity_timeout_minutes = inactivity_timeout_seconds / 60
last_meaningful_speech_time = (
0.0 # Initialize with audio time 0 (will be updated with first speech)
)
timeout_triggered = False # Track if closure was due to timeout
- close_requested_reason = None # Track if closure was requested via API/plugin/button
+ close_requested_reason = (
+ None # Track if closure was requested via API/plugin/button
+ )
last_inactivity_log_time = (
time.time()
) # Track when we last logged inactivity (wall-clock for logging)
@@ -389,7 +419,9 @@ async def open_conversation_job(
# Test mode: wait for audio queue to drain before timing out
# In real usage, ambient noise keeps connection alive. In tests, chunks arrive in bursts.
- wait_for_queue_drain = os.getenv("WAIT_FOR_AUDIO_QUEUE_DRAIN", "false").lower() == "true"
+ wait_for_queue_drain = (
+ os.getenv("WAIT_FOR_AUDIO_QUEUE_DRAIN", "false").lower() == "true"
+ )
logger.info(
f"π Conversation timeout configured: {inactivity_timeout_minutes} minutes ({inactivity_timeout_seconds}s)"
@@ -413,7 +445,9 @@ async def open_conversation_job(
finalize_received = True
# Get completion reason (guaranteed to exist with unified API)
- completion_reason = await redis_client.hget(session_key, "completion_reason")
+ completion_reason = await redis_client.hget(
+ session_key, "completion_reason"
+ )
completion_reason_str = (
completion_reason.decode() if completion_reason else "unknown"
)
@@ -433,11 +467,19 @@ async def open_conversation_job(
# Check for conversation close request (set by API, plugins, button press)
if not finalize_received:
- close_reason = await redis_client.hget(session_key, "conversation_close_requested")
+ close_reason = await redis_client.hget(
+ session_key, "conversation_close_requested"
+ )
if close_reason:
await redis_client.hdel(session_key, "conversation_close_requested")
- close_requested_reason = close_reason.decode() if isinstance(close_reason, bytes) else close_reason
- logger.info(f"π Conversation close requested: {close_requested_reason}")
+ close_requested_reason = (
+ close_reason.decode()
+ if isinstance(close_reason, bytes)
+ else close_reason
+ )
+ logger.info(
+ f"π Conversation close requested: {close_requested_reason}"
+ )
timeout_triggered = True # Session stays active (same restart behavior as inactivity timeout)
finalize_received = True
break
@@ -484,9 +526,12 @@ async def open_conversation_job(
estimated_duration = len(text.split()) * 0.5 # ~0.5 seconds per word
seg["end"] = start + estimated_duration
- # Ensure speaker field exists
- if "speaker" not in seg or not seg["speaker"]:
+ # Ensure speaker field exists and is a string
+ speaker = seg.get("speaker")
+ if speaker is None or speaker == "":
seg["speaker"] = "SPEAKER_00"
+ elif isinstance(speaker, (int, float)):
+ seg["speaker"] = f"Speaker {int(speaker)}"
validated_segments.append(seg)
@@ -528,7 +573,9 @@ async def open_conversation_job(
# Can't reliably detect inactivity, so skip timeout check this iteration
inactivity_duration = 0
if speech_analysis.get("fallback", False):
- logger.debug("β οΈ Skipping inactivity check (no audio timestamps available)")
+ logger.debug(
+ "β οΈ Skipping inactivity check (no audio timestamps available)"
+ )
current_time = time.time()
@@ -651,7 +698,9 @@ async def open_conversation_job(
else:
end_reason = "user_stopped"
- logger.info(f"π Conversation {conversation_id[:12]} end_reason determined: {end_reason}")
+ logger.info(
+ f"π Conversation {conversation_id[:12]} end_reason determined: {end_reason}"
+ )
# Wrap all post-processing in try/finally to guarantee handle_end_of_conversation()
# is always called, even if an exception occurs during transcript saving, job
@@ -686,7 +735,9 @@ async def open_conversation_job(
f"β οΈ Streaming transcription ended with error for {session_id}, proceeding anyway"
)
else:
- logger.info(f"β
Streaming transcription confirmed complete for {session_id}")
+ logger.info(
+ f"β
Streaming transcription confirmed complete for {session_id}"
+ )
break
await asyncio.sleep(0.5)
waited_streaming += 0.5
@@ -725,14 +776,20 @@ async def open_conversation_job(
end_reason=end_reason,
)
- logger.info(f"π¦ MongoDB audio chunks ready for conversation {conversation_id[:12]}")
+ logger.info(
+ f"π¦ MongoDB audio chunks ready for conversation {conversation_id[:12]}"
+ )
# Get final streaming transcript and save to conversation
- logger.info(f"π Retrieving final streaming transcript for conversation {conversation_id[:12]}")
+ logger.info(
+ f"π Retrieving final streaming transcript for conversation {conversation_id[:12]}"
+ )
final_transcript = await aggregator.get_combined_results(session_id)
# Fetch conversation from database to ensure we have latest state
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
logger.error(f"β Conversation {conversation_id} not found in database")
raise ValueError(f"Conversation {conversation_id} not found")
@@ -810,7 +867,8 @@ async def open_conversation_job(
# Update placeholder conversation if it exists
if (
getattr(conversation, "always_persist", False)
- and getattr(conversation, "processing_status", None) == "pending_transcription"
+ and getattr(conversation, "processing_status", None)
+ == "pending_transcription"
):
# Keep placeholder status - will be updated by title_summary_job
logger.info(
@@ -837,12 +895,13 @@ async def open_conversation_job(
# Check if always_batch_retranscribe is enabled
from advanced_omi_backend.config_loader import get_backend_config
- transcription_cfg = get_backend_config('transcription')
+ transcription_cfg = get_backend_config("transcription")
batch_retranscribe = False
if transcription_cfg:
from omegaconf import OmegaConf
+
cfg_dict = OmegaConf.to_container(transcription_cfg, resolve=True)
- batch_retranscribe = cfg_dict.get('always_batch_retranscribe', False)
+ batch_retranscribe = cfg_dict.get("always_batch_retranscribe", False)
if batch_retranscribe:
# BATCH PATH: Streaming transcript saved as preview β user sees it immediately
@@ -852,7 +911,9 @@ async def open_conversation_job(
JOB_RESULT_TTL,
transcription_queue,
)
- from advanced_omi_backend.workers.transcription_jobs import transcribe_full_audio_job
+ from advanced_omi_backend.workers.transcription_jobs import (
+ transcribe_full_audio_job,
+ )
batch_version_id = f"batch_{conversation_id[:12]}"
batch_job = transcription_queue.enqueue(
@@ -864,7 +925,7 @@ async def open_conversation_job(
result_ttl=JOB_RESULT_TTL,
job_id=f"batch_retranscribe_{conversation_id[:12]}",
description=f"Batch re-transcription for {conversation_id[:8]}",
- meta={'conversation_id': conversation_id, 'client_id': client_id},
+ meta={"conversation_id": conversation_id, "client_id": client_id},
)
logger.info(
@@ -948,7 +1009,9 @@ async def open_conversation_job(
@async_job(redis=True, beanie=True)
-async def generate_title_summary_job(conversation_id: str, *, redis_client=None) -> Dict[str, Any]:
+async def generate_title_summary_job(
+ conversation_id: str, *, redis_client=None
+) -> Dict[str, Any]:
"""
Generate title, short summary, and detailed summary for a conversation using LLM.
@@ -971,12 +1034,17 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None)
generate_title_and_summary,
)
- logger.info(f"π Starting title/summary generation for conversation {conversation_id}")
+ set_galileo_session(conversation_id)
+ logger.info(
+ f"π Starting title/summary generation for conversation {conversation_id}"
+ )
start_time = time.time()
# Get the conversation
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
logger.error(f"Conversation {conversation_id} not found")
return {"success": False, "error": "Conversation not found"}
@@ -986,7 +1054,9 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None)
segments = conversation.segments or []
if not transcript_text and (not segments or len(segments) == 0):
- logger.warning(f"β οΈ No transcript or segments available for conversation {conversation_id}")
+ logger.warning(
+ f"β οΈ No transcript or segments available for conversation {conversation_id}"
+ )
return {
"success": False,
"error": "No transcript or segments available",
@@ -1018,18 +1088,24 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None)
else:
logger.info(f"π No memories found for context enrichment")
except Exception as mem_error:
- logger.warning(f"β οΈ Could not fetch memory context (continuing without): {mem_error}")
+ logger.warning(
+ f"β οΈ Could not fetch memory context (continuing without): {mem_error}"
+ )
# Generate title+summary (one call) and detailed summary in parallel
import asyncio
(title, short_summary), detailed_summary = await asyncio.gather(
generate_title_and_summary(
- transcript_text, segments=segments, user_id=conversation.user_id,
+ transcript_text,
+ segments=segments,
+ user_id=conversation.user_id,
langfuse_session_id=conversation_id,
),
generate_detailed_summary(
- transcript_text, segments=segments, memory_context=memory_context,
+ transcript_text,
+ segments=segments,
+ memory_context=memory_context,
langfuse_session_id=conversation_id,
),
)
@@ -1040,10 +1116,15 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None)
logger.info(f"β
Generated title: '{conversation.title}'")
logger.info(f"β
Generated summary: '{conversation.summary}'")
- logger.info(f"β
Generated detailed summary: {len(conversation.detailed_summary)} chars")
+ logger.info(
+ f"β
Generated detailed summary: {len(conversation.detailed_summary)} chars"
+ )
# Update processing status for placeholder/reprocessing conversations
- if getattr(conversation, "processing_status", None) in ["pending_transcription", "reprocessing"]:
+ if getattr(conversation, "processing_status", None) in [
+ "pending_transcription",
+ "reprocessing",
+ ]:
conversation.processing_status = "completed"
logger.info(
f"β
Updated placeholder conversation {conversation_id} "
@@ -1054,7 +1135,10 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None)
logger.error(f"β Title/summary generation failed: {gen_error}")
# Mark placeholder/reprocessing conversation as failed
- if getattr(conversation, "processing_status", None) in ["pending_transcription", "reprocessing"]:
+ if getattr(conversation, "processing_status", None) in [
+ "pending_transcription",
+ "reprocessing",
+ ]:
conversation.title = "Audio Recording (Transcription Failed)"
conversation.summary = f"Title/summary generation failed: {str(gen_error)}"
conversation.processing_status = "transcription_failed"
@@ -1089,7 +1173,9 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None)
"title": conversation.title,
"summary": conversation.summary,
"detailed_summary_length": (
- len(conversation.detailed_summary) if conversation.detailed_summary else 0
+ len(conversation.detailed_summary)
+ if conversation.detailed_summary
+ else 0
),
"segment_count": len(segments),
"processing_time": processing_time,
@@ -1140,12 +1226,16 @@ async def dispatch_conversation_complete_event_job(
"""
from advanced_omi_backend.models.conversation import Conversation
- logger.info(f"π Dispatching conversation.complete event for conversation {conversation_id}")
+ logger.info(
+ f"π Dispatching conversation.complete event for conversation {conversation_id}"
+ )
start_time = time.time()
# Get the conversation to include in event data
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
logger.error(f"Conversation {conversation_id} not found")
return {"success": False, "error": "Conversation not found"}
@@ -1222,9 +1312,13 @@ async def dispatch_conversation_complete_event_job(
f"π RESULT: conversation.complete dispatched to {len(plugin_results) if plugin_results else 0} plugins"
)
if plugin_results:
- logger.info(f"π Triggered {len(plugin_results)} conversation-level plugins")
+ logger.info(
+ f"π Triggered {len(plugin_results)} conversation-level plugins"
+ )
for result in plugin_results:
- logger.info(f" Plugin result: success={result.success}, message={result.message}")
+ logger.info(
+ f" Plugin result: success={result.success}, message={result.message}"
+ )
if result.message:
logger.info(f" Plugin result: {result.message}")
diff --git a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py
index 53f0a103..3c0cedd8 100644
--- a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py
+++ b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py
@@ -22,6 +22,10 @@
memory_queue,
)
from advanced_omi_backend.models.job import JobPriority, async_job
+from advanced_omi_backend.observability.otel_setup import (
+ clear_galileo_session,
+ set_galileo_session,
+)
from advanced_omi_backend.plugins.events import PluginEvent
from advanced_omi_backend.services.plugin_service import ensure_plugin_router
@@ -110,7 +114,9 @@ def compute_speaker_diff(
@async_job(redis=True, beanie=True)
-async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict[str, Any]:
+async def process_memory_job(
+ conversation_id: str, *, redis_client=None
+) -> Dict[str, Any]:
"""
RQ job function for memory extraction and processing from conversations.
@@ -135,6 +141,7 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict
from advanced_omi_backend.services.memory import get_memory_service
from advanced_omi_backend.users import get_user_by_id
+ set_galileo_session(conversation_id)
start_time = time.time()
logger.info(f"π Starting memory processing for conversation {conversation_id}")
@@ -169,11 +176,13 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict
for segment in segments:
text = segment.text.strip()
speaker = segment.speaker
- seg_type = getattr(segment, 'segment_type', 'speech')
+ seg_type = getattr(segment, "segment_type", "speech")
if text:
if seg_type == "event":
# Non-speech event: include as context marker without speaker prefix
- dialogue_lines.append(f"[{text}]" if not text.startswith("[") else text)
+ dialogue_lines.append(
+ f"[{text}]" if not text.startswith("[") else text
+ )
elif seg_type == "note":
# User-inserted note: include as distinct context
dialogue_lines.append(f"[Note: {text}]")
@@ -197,14 +206,20 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict
full_conversation = conversation_model.transcript
if len(full_conversation) < MIN_CONVERSATION_LENGTH:
- logger.warning(f"Conversation too short for memory processing: {conversation_id}")
+ logger.warning(
+ f"Conversation too short for memory processing: {conversation_id}"
+ )
return {"success": False, "error": "Conversation too short"}
# Check primary speakers filter (reuse `user` from above β no duplicate DB call)
if user and user.primary_speakers:
- primary_speaker_names = {ps["name"].strip().lower() for ps in user.primary_speakers}
+ primary_speaker_names = {
+ ps["name"].strip().lower() for ps in user.primary_speakers
+ }
- if transcript_speakers and not transcript_speakers.intersection(primary_speaker_names):
+ if transcript_speakers and not transcript_speakers.intersection(
+ primary_speaker_names
+ ):
logger.info(
f"Skipping memory - no primary speakers found in conversation {conversation_id}"
)
@@ -301,11 +316,18 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict
# Fetch memory details to display in UI
memory_details = []
try:
- for memory_id in created_memory_ids[:5]: # Limit to first 5 for display
- memory_entry = await memory_service.get_memory(memory_id, user_id)
+ for memory_id in created_memory_ids[
+ :5
+ ]: # Limit to first 5 for display
+ memory_entry = await memory_service.get_memory(
+ memory_id, user_id
+ )
if memory_entry:
memory_details.append(
- {"memory_id": memory_id, "text": memory_entry.content[:200]}
+ {
+ "memory_id": memory_id,
+ "text": memory_entry.content[:200],
+ }
)
except Exception as e:
logger.warning(f"Failed to fetch memory details for UI: {e}")
@@ -335,7 +357,9 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict
config = get_config()
kg_enabled = (
- config.get("memory", {}).get("knowledge_graph", {}).get("enabled", False)
+ config.get("memory", {})
+ .get("knowledge_graph", {})
+ .get("enabled", False)
)
if kg_enabled:
@@ -379,7 +403,9 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict
"user_id": user_id,
"user_email": user_email,
},
- "memory_count": len(created_memory_ids) if created_memory_ids else 0,
+ "memory_count": (
+ len(created_memory_ids) if created_memory_ids else 0
+ ),
"conversation_id": conversation_id,
}
@@ -403,7 +429,9 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict
)
if plugin_results:
- logger.info(f"π Triggered {len(plugin_results)} memory-level plugins")
+ logger.info(
+ f"π Triggered {len(plugin_results)} memory-level plugins"
+ )
for result in plugin_results:
if result.message:
logger.info(f" Plugin result: {result.message}")
@@ -413,7 +441,9 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict
return {
"success": True,
- "memories_created": len(created_memory_ids) if created_memory_ids else 0,
+ "memories_created": (
+ len(created_memory_ids) if created_memory_ids else 0
+ ),
"processing_time": processing_time,
}
else:
@@ -461,7 +491,11 @@ async def _process_speaker_reprocess(
f"falling back to normal extraction"
)
return await memory_service.add_memory(
- full_conversation, client_id, conversation_id, user_id, user_email,
+ full_conversation,
+ client_id,
+ conversation_id,
+ user_id,
+ user_email,
allow_update=True,
)
@@ -474,7 +508,11 @@ async def _process_speaker_reprocess(
f"for {conversation_id}, falling back to normal extraction"
)
return await memory_service.add_memory(
- full_conversation, client_id, conversation_id, user_id, user_email,
+ full_conversation,
+ client_id,
+ conversation_id,
+ user_id,
+ user_email,
allow_update=True,
)
@@ -491,7 +529,11 @@ async def _process_speaker_reprocess(
f"for {conversation_id}, falling back to normal extraction"
)
return await memory_service.add_memory(
- full_conversation, client_id, conversation_id, user_id, user_email,
+ full_conversation,
+ client_id,
+ conversation_id,
+ user_id,
+ user_email,
allow_update=True,
)
@@ -507,7 +549,11 @@ async def _process_speaker_reprocess(
f"for {conversation_id}, falling back to normal extraction"
)
return await memory_service.add_memory(
- full_conversation, client_id, conversation_id, user_id, user_email,
+ full_conversation,
+ client_id,
+ conversation_id,
+ user_id,
+ user_email,
allow_update=True,
)
@@ -564,5 +610,7 @@ def enqueue_memory_processing(
description=f"Process memory for conversation {conversation_id[:8]}",
)
- logger.info(f"π₯ RQ: Enqueued memory job {job.id} for conversation {conversation_id}")
+ logger.info(
+ f"π₯ RQ: Enqueued memory job {job.id} for conversation {conversation_id}"
+ )
return job
diff --git a/backends/advanced/src/advanced_omi_backend/workers/rq_worker_entry.py b/backends/advanced/src/advanced_omi_backend/workers/rq_worker_entry.py
index d9da1c6a..a3c07fbc 100755
--- a/backends/advanced/src/advanced_omi_backend/workers/rq_worker_entry.py
+++ b/backends/advanced/src/advanced_omi_backend/workers/rq_worker_entry.py
@@ -14,7 +14,7 @@
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
- stream=sys.stdout
+ stream=sys.stdout,
)
logger = logging.getLogger(__name__)
@@ -22,14 +22,24 @@
def main():
"""Start RQ worker with proper logging configuration."""
+ # Initialize OTEL/Galileo if configured (patches OpenAI before any job imports)
+ try:
+ from advanced_omi_backend.observability.otel_setup import init_otel
+
+ init_otel()
+    except Exception:  # observability is optional — never block worker startup
+        logger.warning("OTEL/Galileo init failed; continuing without observability", exc_info=True)
+
from redis import Redis
from rq import Worker
# Get Redis URL from environment
- redis_url = os.getenv('REDIS_URL', 'redis://localhost:6379/0')
+ redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
# Get queue names from command line arguments
- queue_names = sys.argv[1:] if len(sys.argv) > 1 else ['transcription', 'memory', 'default']
+ queue_names = (
+ sys.argv[1:] if len(sys.argv) > 1 else ["transcription", "memory", "default"]
+ )
logger.info(f"π Starting RQ worker for queues: {', '.join(queue_names)}")
logger.info(f"π‘ Redis URL: {redis_url}")
@@ -38,16 +48,12 @@ def main():
redis_conn = Redis.from_url(redis_url)
# Create and start worker
- worker = Worker(
- queue_names,
- connection=redis_conn,
- log_job_description=True
- )
+ worker = Worker(queue_names, connection=redis_conn, log_job_description=True)
logger.info("β
RQ worker ready")
# This blocks until worker is stopped
- worker.work(logging_level='INFO')
+ worker.work(logging_level="INFO")
if __name__ == "__main__":
diff --git a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py
index 5d6e592e..fb644ec1 100644
--- a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py
+++ b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py
@@ -21,7 +21,10 @@
from rq.exceptions import NoSuchJobError
from rq.job import Job
-from advanced_omi_backend.config import get_backend_config, get_transcription_job_timeout
+from advanced_omi_backend.config import (
+ get_backend_config,
+ get_transcription_job_timeout,
+)
from advanced_omi_backend.controllers.queue_controller import (
JOB_RESULT_TTL,
REDIS_URL,
@@ -32,8 +35,8 @@
from advanced_omi_backend.models.audio_chunk import AudioChunkDocument
from advanced_omi_backend.models.conversation import Conversation
from advanced_omi_backend.models.job import BaseRQJob, JobPriority, async_job
-from advanced_omi_backend.services.audio_stream import TranscriptionResultsAggregator
from advanced_omi_backend.plugins.events import PluginEvent
+from advanced_omi_backend.services.audio_stream import TranscriptionResultsAggregator
from advanced_omi_backend.services.plugin_service import ensure_plugin_router
from advanced_omi_backend.services.transcription import (
get_transcription_provider,
@@ -82,7 +85,9 @@ async def apply_speaker_recognition(
speaker_client = SpeakerRecognitionClient()
if not speaker_client.enabled:
- logger.info(f"π€ Speaker recognition disabled, using original speaker labels")
+ logger.info(
+ f"π€ Speaker recognition disabled, using original speaker labels"
+ )
return segments
logger.info(
@@ -123,7 +128,9 @@ def get_speaker_at_time(timestamp: float, speaker_segments: list) -> str:
updated_count = 0
for seg in segments:
seg_mid = (seg.start + seg.end) / 2.0
- identified_speaker = get_speaker_at_time(seg_mid, speaker_identified_segments)
+ identified_speaker = get_speaker_at_time(
+ seg_mid, speaker_identified_segments
+ )
if identified_speaker and identified_speaker != "Unknown":
original_speaker = seg.speaker
@@ -186,7 +193,9 @@ async def transcribe_full_audio_job(
start_time = time.time()
# Get the conversation
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
raise ValueError(f"Conversation {conversation_id} not found")
@@ -203,18 +212,23 @@ async def transcribe_full_audio_job(
logger.info(f"Using transcription provider: {provider_name}")
# Reconstruct audio from MongoDB chunks
- logger.info(f"π¦ Reconstructing audio from MongoDB chunks for conversation {conversation_id}")
+ logger.info(
+ f"π¦ Reconstructing audio from MongoDB chunks for conversation {conversation_id}"
+ )
try:
# Reconstruct WAV from MongoDB chunks (already in memory as bytes)
wav_data = await reconstruct_wav_from_conversation(conversation_id)
logger.info(
- f"π¦ Reconstructed audio from MongoDB chunks: " f"{len(wav_data) / 1024 / 1024:.2f} MB"
+ f"π¦ Reconstructed audio from MongoDB chunks: "
+ f"{len(wav_data) / 1024 / 1024:.2f} MB"
)
except ValueError as e:
# No chunks found for conversation
- raise FileNotFoundError(f"No audio chunks found for conversation {conversation_id}: {e}")
+            raise FileNotFoundError(
+                f"No audio chunks found for conversation {conversation_id}: {e}"
+            ) from e
except Exception as e:
logger.error(f"Failed to reconstruct audio from MongoDB: {e}", exc_info=True)
raise RuntimeError(f"Audio reconstruction failed: {e}")
@@ -236,11 +250,27 @@ async def transcribe_full_audio_job(
actual_sample_rate = 16000
try:
+ # Progress callback: writes batch progress to RQ job.meta so the
+ # queue API and UI can show "Transcribing segment X of Y".
+ def _on_batch_progress(event: dict) -> None:
+ job = get_current_job()
+ if job:
+ current = event.get("current", 0)
+ total = event.get("total", 0)
+ job.meta["batch_progress"] = {
+ "current": current,
+ "total": total,
+ "percent": int(current / total * 100) if total else 0,
+ "message": f"Transcribing segment {current} of {total}",
+ }
+ job.save_meta()
+
# Transcribe the audio directly from memory (no disk I/O needed)
transcribe_kwargs: Dict[str, Any] = {
"audio_data": wav_data,
"sample_rate": actual_sample_rate,
"diarize": True,
+ "progress_callback": _on_batch_progress,
}
if context_info:
transcribe_kwargs["context_info"] = context_info
@@ -308,7 +338,9 @@ async def transcribe_full_audio_job(
if result.message:
logger.info(f" Plugin: {result.message}")
except Exception as e:
- logger.exception(f"β οΈ Error triggering transcript plugins in batch mode: {e}")
+ logger.exception(
+ f"β οΈ Error triggering transcript plugins in batch mode: {e}"
+ )
logger.info(f"π DEBUG: Plugin processing complete, moving to speech validation")
@@ -363,7 +395,9 @@ async def transcribe_full_audio_job(
f"Job {job_id} hash not found (likely already completed or expired)"
)
else:
- logger.debug(f"Job {job_id} not found or already completed: {e}")
+ logger.debug(
+ f"Job {job_id} not found or already completed: {e}"
+ )
if cancelled_jobs:
logger.info(
@@ -586,7 +620,10 @@ async def create_audio_only_conversation(
placeholder_conversation = await Conversation.find_one(
Conversation.client_id == session_id,
Conversation.always_persist == True,
- In(Conversation.processing_status, ["pending_transcription", "transcription_failed"]),
+ In(
+ Conversation.processing_status,
+ ["pending_transcription", "transcription_failed"],
+ ),
)
if placeholder_conversation:
@@ -597,7 +634,9 @@ async def create_audio_only_conversation(
# Update status to show batch transcription is starting
placeholder_conversation.processing_status = "batch_transcription"
placeholder_conversation.title = "Audio Recording (Batch Transcription...)"
- placeholder_conversation.summary = "Processing audio with offline transcription..."
+ placeholder_conversation.summary = (
+ "Processing audio with offline transcription..."
+ )
await placeholder_conversation.save()
# Audio chunks are already linked to this conversation_id
@@ -624,13 +663,20 @@ async def create_audio_only_conversation(
)
await conversation.insert()
- logger.info(f"β
Created batch transcription conversation {session_id[:12]} for fallback")
+ logger.info(
+ f"β
Created batch transcription conversation {session_id[:12]} for fallback"
+ )
return conversation
@async_job(redis=True, beanie=True)
async def transcription_fallback_check_job(
- session_id: str, user_id: str, client_id: str, timeout_seconds: int = None, *, redis_client=None
+ session_id: str,
+ user_id: str,
+ client_id: str,
+ timeout_seconds: int = None,
+ *,
+ redis_client=None,
) -> Dict[str, Any]:
"""
Check if streaming transcription succeeded, fallback to batch if needed.
@@ -765,14 +811,18 @@ async def transcription_fallback_check_job(
sample_rate, channels, sample_width = 16000, 1, 2
session_key = f"audio:session:{session_id}"
try:
- audio_format_raw = await redis_client.hget(session_key, "audio_format")
+ audio_format_raw = await redis_client.hget(
+ session_key, "audio_format"
+ )
if audio_format_raw:
audio_format = json.loads(audio_format_raw)
sample_rate = int(audio_format.get("rate", 16000))
channels = int(audio_format.get("channels", 1))
sample_width = int(audio_format.get("width", 2))
except Exception as e:
- logger.warning(f"Failed to read audio_format from Redis for {session_id}: {e}")
+ logger.warning(
+ f"Failed to read audio_format from Redis for {session_id}: {e}"
+ )
bytes_per_second = sample_rate * channels * sample_width
logger.info(
@@ -781,7 +831,9 @@ async def transcription_fallback_check_job(
)
# Create conversation placeholder
- conversation = await create_audio_only_conversation(session_id, user_id, client_id)
+ conversation = await create_audio_only_conversation(
+ session_id, user_id, client_id
+ )
# Save audio to MongoDB chunks for batch transcription
num_chunks = await convert_audio_to_chunks(
@@ -798,7 +850,9 @@ async def transcription_fallback_check_job(
)
except Exception as e:
- logger.error(f"β Failed to extract audio from Redis stream: {e}", exc_info=True)
+ logger.error(
+ f"β Failed to extract audio from Redis stream: {e}", exc_info=True
+ )
raise
else:
logger.info(
@@ -807,7 +861,9 @@ async def transcription_fallback_check_job(
)
# Create conversation placeholder for batch transcription
- conversation = await create_audio_only_conversation(session_id, user_id, client_id)
+ conversation = await create_audio_only_conversation(
+ session_id, user_id, client_id
+ )
# Enqueue batch transcription job
version_id = f"batch_fallback_{session_id[:12]}"
@@ -903,10 +959,14 @@ async def stream_speech_detection_job(
# Get conversation count
conversation_count_key = f"session:conversation_count:{session_id}"
conversation_count_bytes = await redis_client.get(conversation_count_key)
- conversation_count = int(conversation_count_bytes) if conversation_count_bytes else 0
+ conversation_count = (
+ int(conversation_count_bytes) if conversation_count_bytes else 0
+ )
# Check if speaker filtering is enabled
- speaker_filter_enabled = os.getenv("RECORD_ONLY_ENROLLED_SPEAKERS", "false").lower() == "true"
+ speaker_filter_enabled = (
+ os.getenv("RECORD_ONLY_ENROLLED_SPEAKERS", "false").lower() == "true"
+ )
logger.info(
f"π Conversation #{conversation_count + 1}, Speaker filter: {'enabled' if speaker_filter_enabled else 'disabled'}"
)
@@ -942,7 +1002,10 @@ async def stream_speech_detection_job(
# Check if session has closed
session_status = await redis_client.hget(session_key, "status")
- session_closed = session_status and session_status.decode() in ["finalizing", "finished"]
+ session_closed = session_status and session_status.decode() in [
+ "finalizing",
+ "finished",
+ ]
if session_closed and session_closed_at is None:
# Session just closed - start grace period for final transcription
@@ -952,10 +1015,30 @@ async def stream_speech_detection_job(
)
# Exit if grace period expired without speech
- if session_closed_at and (time.time() - session_closed_at) > final_check_grace_period:
+ if (
+ session_closed_at
+ and (time.time() - session_closed_at) > final_check_grace_period
+ ):
logger.info(f"β
Session ended without speech (grace period expired)")
break
+ # Consume any stale conversation close request (defensive β shouldn't normally
+ # appear since services.py gates on conversation:current, but handles race conditions)
+ close_reason = await redis_client.hget(
+ session_key, "conversation_close_requested"
+ )
+ if close_reason:
+ await redis_client.hdel(session_key, "conversation_close_requested")
+ close_reason_str = (
+ close_reason.decode()
+ if isinstance(close_reason, bytes)
+ else close_reason
+ )
+ logger.info(
+ f"π Conversation close requested ({close_reason_str}) during speech detection β "
+ f"no open conversation to close, flag consumed"
+ )
+
if time.time() - start_time > max_runtime:
logger.warning(f"β±οΈ Max runtime reached, exiting")
break
@@ -966,11 +1049,15 @@ async def stream_speech_detection_job(
# Health check: detect transcription errors early during grace period
if session_closed_at:
# Check for streaming consumer errors in session metadata
- error_status = await redis_client.hget(session_key, "transcription_error")
+ error_status = await redis_client.hget(
+ session_key, "transcription_error"
+ )
if error_status:
error_msg = error_status.decode()
logger.error(f"❌ Transcription service error: {error_msg}")
- logger.error(f"β Session failed - transcription service unavailable")
+ logger.error(
+ f"β Session failed - transcription service unavailable"
+ )
break
# Check if we've been waiting too long with no results at all
@@ -980,7 +1067,9 @@ async def stream_speech_detection_job(
logger.error(
f"β No transcription activity after {grace_elapsed:.1f}s - possible API key or connectivity issue"
)
- logger.error(f"β Session failed - check transcription service configuration")
+ logger.error(
+ f"β Session failed - check transcription service configuration"
+ )
break
await asyncio.sleep(2)
@@ -1016,9 +1105,13 @@ async def stream_speech_detection_job(
from datetime import datetime
await redis_client.hset(
- session_key, "last_event", f"speech_detected:{datetime.utcnow().isoformat()}"
+ session_key,
+ "last_event",
+ f"speech_detected:{datetime.utcnow().isoformat()}",
+ )
+ await redis_client.hset(
+ session_key, "speech_detected_at", datetime.utcnow().isoformat()
)
- await redis_client.hset(session_key, "speech_detected_at", datetime.utcnow().isoformat())
# Step 2: If speaker filter enabled, check for enrolled speakers
identified_speakers = []
@@ -1028,7 +1121,9 @@ async def stream_speech_detection_job(
# Add session event for speaker check starting
await redis_client.hset(
- session_key, "last_event", f"speaker_check_starting:{datetime.utcnow().isoformat()}"
+ session_key,
+ "last_event",
+ f"speaker_check_starting:{datetime.utcnow().isoformat()}",
)
await redis_client.hset(session_key, "speaker_check_status", "checking")
from .speaker_jobs import check_enrolled_speakers_job
@@ -1068,7 +1163,9 @@ async def stream_speech_detection_job(
result = speaker_check_job.result
enrolled_present = result.get("enrolled_present", False)
identified_speakers = result.get("identified_speakers", [])
- logger.info(f"✅ Speaker check completed: enrolled={enrolled_present}")
+ logger.info(
+ f"✅ Speaker check completed: enrolled={enrolled_present}"
+ )
# Update session event for speaker check complete
await redis_client.hset(
@@ -1083,7 +1180,9 @@ async def stream_speech_detection_job(
)
if identified_speakers:
await redis_client.hset(
- session_key, "identified_speakers", ",".join(identified_speakers)
+ session_key,
+ "identified_speakers",
+ ",".join(identified_speakers),
)
break
elif speaker_check_job.is_failed:
@@ -1095,7 +1194,9 @@ async def stream_speech_detection_job(
"last_event",
f"speaker_check_failed:{datetime.utcnow().isoformat()}",
)
- await redis_client.hset(session_key, "speaker_check_status", "failed")
+ await redis_client.hset(
+ session_key, "speaker_check_status", "failed"
+ )
break
await asyncio.sleep(poll_interval)
waited += poll_interval
@@ -1148,7 +1249,9 @@ async def stream_speech_detection_job(
)
# Track the job
- await redis_client.set(open_job_key, open_job.id, ex=10800) # 3 hours to match job timeout
+ await redis_client.set(
+ open_job_key, open_job.id, ex=10800
+ ) # 3 hours to match job timeout
# Store metadata in speech detection job
if current_job:
@@ -1161,23 +1264,31 @@ async def stream_speech_detection_job(
current_job.meta.update(
{
"conversation_job_id": open_job.id,
- "speaker_check_job_id": speaker_check_job.id if speaker_check_job else None,
+ "speaker_check_job_id": (
+ speaker_check_job.id if speaker_check_job else None
+ ),
"detected_speakers": identified_speakers,
- "speech_detected_at": datetime.fromtimestamp(speech_detected_at).isoformat(),
+ "speech_detected_at": datetime.fromtimestamp(
+ speech_detected_at
+ ).isoformat(),
"session_id": session_id,
"client_id": client_id, # For job grouping
}
)
current_job.save_meta()
- logger.info(f"✅ Started conversation job {open_job.id}, exiting speech detection")
+ logger.info(
+ f"✅ Started conversation job {open_job.id}, exiting speech detection"
+ )
return {
"session_id": session_id,
"user_id": user_id,
"client_id": client_id,
"conversation_job_id": open_job.id,
- "speech_detected_at": datetime.fromtimestamp(speech_detected_at).isoformat(),
+ "speech_detected_at": datetime.fromtimestamp(
+ speech_detected_at
+ ).isoformat(),
"runtime_seconds": time.time() - start_time,
}
@@ -1205,7 +1316,9 @@ async def stream_speech_detection_job(
# Check if this is an always_persist conversation that needs to be marked as failed
# NOTE: We check MongoDB directly because the conversation:current Redis key might have been
# deleted by the audio persistence job cleanup (which runs in parallel).
- logger.info(f"π Checking MongoDB for always_persist conversation with client_id: {client_id}")
+ logger.info(
+ f"π Checking MongoDB for always_persist conversation with client_id: {client_id}"
+ )
# Find conversation by client_id that matches this session
# session_id == client_id for streaming sessions (set in _initialize_streaming_session)
diff --git a/config/config.yml.template b/config/config.yml.template
index c4aaa5ae..e292a7f3 100644
--- a/config/config.yml.template
+++ b/config/config.yml.template
@@ -174,6 +174,11 @@ models:
method: POST
path: /transcribe
content_type: multipart/form-data
+ timeout: 300
+ # Per-read timeout for NDJSON batch progress responses.
+ # Each batch window can take minutes; this timeout covers
+ # the gap between successive progress lines.
+ read_timeout: 600
response:
type: json
extract:
@@ -319,12 +324,15 @@ models:
api_family: websocket
model_url: wss://api.deepgram.com/v1/listen
api_key: ${oc.env:DEEPGRAM_API_KEY,''}
+ capabilities:
+ - diarization
operations:
query:
model: nova-3
language: multi
smart_format: 'true'
punctuate: 'true'
+ diarize: 'true'
encoding: linear16
sample_rate: 16000
channels: '1'
diff --git a/config/defaults.yml b/config/defaults.yml
index a631f486..b32ec1c1 100644
--- a/config/defaults.yml
+++ b/config/defaults.yml
@@ -207,12 +207,15 @@ models:
api_family: websocket
model_url: wss://api.deepgram.com/v1/listen
api_key: ${oc.env:DEEPGRAM_API_KEY,''}
+ capabilities:
+ - diarization
operations:
query:
model: nova-3
language: multi
smart_format: 'true'
punctuate: 'true'
+ diarize: 'true'
encoding: linear16
sample_rate: 16000
channels: '1'
@@ -298,6 +301,7 @@ models:
sample_rate: 16000
word_timestamps: 'true'
diarize: 'true'
+ sentence_timestamps: 'true'
end:
message:
type: finalize
diff --git a/config/plugins.yml.template b/config/plugins.yml.template
index 789cd9ed..2eec468d 100644
--- a/config/plugins.yml.template
+++ b/config/plugins.yml.template
@@ -27,10 +27,9 @@ plugins:
# - transcript.batch # Uncomment to also handle batch transcription
# - conversation.complete # Uncomment to handle completed conversations
condition:
- type: wake_word
- wake_words: # Support multiple wake words
- - hey vivi # Example: "hey vivi, turn off the lights"
- - hey jarvis # Example: "hey jarvis, what's the temperature"
+ type: keyword_anywhere # Trigger when keyword appears anywhere in transcript
+ keywords: # Support multiple keywords
+ - vivi # Example: "turn off the lights, vivi"
ha_url: http://host.docker.internal:8123 # Your Home Assistant URL
ha_token: ${HA_TOKEN} # ALWAYS use env var - never paste actual token here!
# To get a long-lived token:
diff --git a/restart.sh b/restart.sh
index 019518c4..9ac7c54d 100755
--- a/restart.sh
+++ b/restart.sh
@@ -1,2 +1,3 @@
#!/bin/bash
+source "$(dirname "$0")/scripts/check_uv.sh"
uv run --with-requirements setup-requirements.txt python services.py restart --all
diff --git a/start.sh b/start.sh
index b01ef87a..723fb20f 100755
--- a/start.sh
+++ b/start.sh
@@ -1 +1,3 @@
+#!/bin/bash
+source "$(dirname "$0")/scripts/check_uv.sh"
uv run --with-requirements setup-requirements.txt python services.py start --all "$@"
diff --git a/status.sh b/status.sh
index a66fe459..6ded4c94 100755
--- a/status.sh
+++ b/status.sh
@@ -1,2 +1,3 @@
#!/bin/bash
+source "$(dirname "$0")/scripts/check_uv.sh"
uv run --with-requirements setup-requirements.txt python status.py "$@"
diff --git a/stop.sh b/stop.sh
index 0f49add7..0fc033d4 100755
--- a/stop.sh
+++ b/stop.sh
@@ -1 +1,3 @@
+#!/bin/bash
+source "$(dirname "$0")/scripts/check_uv.sh"
uv run --with-requirements setup-requirements.txt python services.py stop --all
diff --git a/tests/config/plugins.test.yml b/tests/config/plugins.test.yml
index 89772a56..7ba31157 100644
--- a/tests/config/plugins.test.yml
+++ b/tests/config/plugins.test.yml
@@ -12,3 +12,14 @@ plugins:
condition:
type: always # Capture all events without filtering
db_path: /app/debug/test_plugin_events.db
+
+ test_button_actions:
+ enabled: true
+ events:
+ - button.single_press
+ - button.double_press
+ condition:
+ type: always
+ actions:
+ single_press:
+ type: close_conversation
diff --git a/tests/integration/websocket_streaming_tests.robot b/tests/integration/websocket_streaming_tests.robot
index 63baadf8..8c0ac647 100644
--- a/tests/integration/websocket_streaming_tests.robot
+++ b/tests/integration/websocket_streaming_tests.robot
@@ -34,12 +34,12 @@ Streaming jobs created on stream start
Sleep 2s
# Check speech detection job
${jobs}= Get Jobs By Type speech_detection
- Should Not Be Empty ${jobs}
+ Should Not Be Empty ${jobs}
${speech_job}= Find Job For Client ${jobs} ${device_name}
Should Not Be Equal ${speech_job} ${None} Speech detection job not created
# Check audio persistence job
- ${persist_job}= Find Job For Client ${jobs} ${device_name}
+ ${persist_job}= Find Job For Client ${jobs} ${device_name}
Should Not Be Equal ${persist_job} ${None} Audio persistence job not created
Log Both jobs active during streaming
@@ -102,6 +102,40 @@ Conversation Job Created After Speech Detection
Log Closed stream, sent ${total_chunks} total chunks
+Button Press Should Close Active Conversation
+ [Documentation] Verify that a button single press during an active conversation
+ ... closes it with end_reason=close_requested and triggers post-processing
+ [Tags] audio-streaming conversation
+ [Timeout] 120s
+
+ # Arrange: Open stream and get enough speech to start a conversation
+ ${device_name}= Set Variable ws-button-close
+ ${stream_id}= Open Audio Stream device_name=${device_name}
+ ${client_id}= Get Client ID From Device Name ${device_name}
+
+ # Get baseline conversation jobs
+ ${baseline_jobs}= Get Jobs By Type And Client open_conversation ${client_id}
+ ${baseline_count}= Get Length ${baseline_jobs}
+
+ # Send audio with speech (realtime pacing for Deepgram to finalize segments)
+ Send Audio Chunks To Stream ${stream_id} ${TEST_AUDIO_FILE} num_chunks=200 realtime_pacing=True
+
+ # Wait for open_conversation_job to start (speech detected -> conversation opened)
+ ${jobs}= Wait Until Keyword Succeeds 60s 3s
+ ... Wait For New Job To Appear open_conversation ${client_id} ${baseline_count}
+ ${conversation_id}= Evaluate $jobs[0]['meta'].get('conversation_id', '')
+ Should Not Be Empty ${conversation_id} msg=Conversation ID not found in job meta
+
+ # Act: Send button press to close the conversation
+ Send Button Event To Stream ${stream_id} SINGLE_TAP
+
+ # Assert: Conversation should close with end_reason=close_requested
+ Wait Until Keyword Succeeds 30s 2s
+ ... Conversation Should Have End Reason ${conversation_id} close_requested
+
+ # Cleanup
+ [Teardown] Run Keyword And Ignore Error Close Audio Stream ${stream_id}
+
Conversation Closes On Inactivity Timeout And Restarts Speech Detection
[Documentation] Verify that after SPEECH_INACTIVITY_THRESHOLD_SECONDS of silence (audio time),
... the open_conversation job closes with timeout_triggered=True,
@@ -168,6 +202,3 @@ Conversation Closes On Inactivity Timeout And Restarts Speech Detection
# Memory extraction job should be created
${memory_jobs}= Get Jobs By Type And Conversation process_memory_job ${conversation_id}
Log To Console Memory jobs found: ${memory_jobs.__len__()}
-
-
-
diff --git a/tests/integration/websocket_transcription_e2e_test.robot b/tests/integration/websocket_transcription_e2e_test.robot
index 3711ac54..951f3aa7 100644
--- a/tests/integration/websocket_transcription_e2e_test.robot
+++ b/tests/integration/websocket_transcription_e2e_test.robot
@@ -53,16 +53,13 @@ WebSocket Stream Produces Final Transcripts In Redis
Log Closing stream - should trigger: end_marker β CloseStream β final results
Close Audio Stream ${stream_id}
- # Allow time for streaming consumer to process end_marker and get final results
- Sleep 5s
-
- # Verify Redis stream transcription:results:{client_id} has entries
+ # Wait for streaming consumer to process end_marker and write final results to Redis
+ # Use retry loop instead of fixed sleep - consumer processing time varies
${stream_name}= Set Variable transcription:results:${client_id}
- ${stream_length}= Redis Command XLEN ${stream_name}
-
- Should Be True ${stream_length} > 0
- ... Redis stream ${stream_name} is empty - no final transcripts received! This means end_marker was not sent or CloseStream failed.
+ Wait Until Keyword Succeeds 30s 2s
+ ... Redis Stream Should Not Be Empty ${stream_name}
+ ${stream_length}= Redis Command XLEN ${stream_name}
Log ✅ Redis stream has ${stream_length} final transcript(s)
@@ -387,3 +384,13 @@ Streaming Completion Signal Is Set Before Transcript Read
Log ✅ Completion signal ${completion_key} = ${signal_value} (consumer completed before job reads)
+*** Keywords ***
+
+Redis Stream Should Not Be Empty
+ [Documentation] Assert that a Redis stream has at least one entry.
+ ... Used with Wait Until Keyword Succeeds for retry-based checks.
+ [Arguments] ${stream_name}
+
+ ${stream_length}= Redis Command XLEN ${stream_name}
+ Should Be True ${stream_length} > 0
+ ... Redis stream ${stream_name} is empty - no final transcripts received! This means end_marker was not sent or CloseStream failed.
diff --git a/tests/libs/audio_stream_library.py b/tests/libs/audio_stream_library.py
index e14a174e..bf870639 100644
--- a/tests/libs/audio_stream_library.py
+++ b/tests/libs/audio_stream_library.py
@@ -27,7 +27,10 @@
sys.path.insert(0, str(backend_src))
from advanced_omi_backend.clients import AudioStreamClient
-from advanced_omi_backend.clients.audio_stream_client import StreamManager, stream_audio_file as _stream_audio_file
+from advanced_omi_backend.clients.audio_stream_client import StreamManager
+from advanced_omi_backend.clients.audio_stream_client import (
+ stream_audio_file as _stream_audio_file,
+)
# Module-level manager for non-blocking streams
_manager = StreamManager()
@@ -37,6 +40,7 @@
# Blocking Mode (simple, streams entire file)
# =============================================================================
+
def stream_audio_file(
base_url: str,
token: str,
@@ -60,6 +64,7 @@ def stream_audio_file(
# Non-blocking Mode (for testing during stream)
# =============================================================================
+
def start_audio_stream(
base_url: str,
token: str,
@@ -145,6 +150,16 @@ def close_audio_stream_without_stop(stream_id: str) -> int:
return _manager.close_stream_without_stop(stream_id)
+def send_button_event(stream_id: str, button_state: str = "SINGLE_TAP") -> None:
+ """Send a button event to an open stream.
+
+ Args:
+ stream_id: Stream session ID
+ button_state: Button state ("SINGLE_TAP" or "DOUBLE_TAP")
+ """
+ _manager.send_button_event(stream_id, button_state)
+
+
def cleanup_all_streams():
"""Stop all active streams."""
_manager.cleanup_all()
@@ -154,6 +169,7 @@ def cleanup_all_streams():
# Advanced Usage
# =============================================================================
+
def get_audio_stream_client(
base_url: str,
token: str,
diff --git a/tests/resources/websocket_keywords.robot b/tests/resources/websocket_keywords.robot
index 2eb9381b..1ca2b164 100644
--- a/tests/resources/websocket_keywords.robot
+++ b/tests/resources/websocket_keywords.robot
@@ -158,6 +158,12 @@ Close Audio Stream
Log Stopped stream ${stream_id}, total chunks: ${total_chunks}
RETURN ${total_chunks}
+Send Button Event To Stream
+ [Documentation] Send a button event (SINGLE_TAP, DOUBLE_TAP) to an open stream
+ [Arguments] ${stream_id} ${button_state}=SINGLE_TAP
+ Send Button Event ${stream_id} ${button_state}
+ Log Sent button event ${button_state} to stream ${stream_id}
+
Close Audio Stream Without Stop Event
[Documentation] Close WebSocket connection without sending audio-stop event.
... This simulates abrupt disconnection (network failure, client crash)
diff --git a/wizard.py b/wizard.py
index b04f028c..e4d52395 100755
--- a/wizard.py
+++ b/wizard.py
@@ -26,38 +26,67 @@
console = Console()
SERVICES = {
- 'backend': {
- 'advanced': {
- 'path': 'backends/advanced',
- 'cmd': ['uv', 'run', '--with-requirements', '../../setup-requirements.txt', 'python', 'init.py'],
- 'description': 'Advanced AI backend with full feature set',
- 'required': True
+ "backend": {
+ "advanced": {
+ "path": "backends/advanced",
+ "cmd": [
+ "uv",
+ "run",
+ "--with-requirements",
+ "../../setup-requirements.txt",
+ "python",
+ "init.py",
+ ],
+ "description": "Advanced AI backend with full feature set",
+ "required": True,
}
},
- 'extras': {
- 'speaker-recognition': {
- 'path': 'extras/speaker-recognition',
- 'cmd': ['uv', 'run', '--with-requirements', '../../setup-requirements.txt', 'python', 'init.py'],
- 'description': 'Speaker identification and enrollment'
+ "extras": {
+ "speaker-recognition": {
+ "path": "extras/speaker-recognition",
+ "cmd": [
+ "uv",
+ "run",
+ "--with-requirements",
+ "../../setup-requirements.txt",
+ "python",
+ "init.py",
+ ],
+ "description": "Speaker identification and enrollment",
},
- 'asr-services': {
- 'path': 'extras/asr-services',
- 'cmd': ['uv', 'run', '--with-requirements', '../../setup-requirements.txt', 'python', 'init.py'],
- 'description': 'Offline speech-to-text'
+ "asr-services": {
+ "path": "extras/asr-services",
+ "cmd": [
+ "uv",
+ "run",
+ "--with-requirements",
+ "../../setup-requirements.txt",
+ "python",
+ "init.py",
+ ],
+ "description": "Offline speech-to-text",
},
- 'openmemory-mcp': {
- 'path': 'extras/openmemory-mcp',
- 'cmd': ['./setup.sh'],
- 'description': 'OpenMemory MCP server'
+ "openmemory-mcp": {
+ "path": "extras/openmemory-mcp",
+ "cmd": ["./setup.sh"],
+ "description": "OpenMemory MCP server",
},
- 'langfuse': {
- 'path': 'extras/langfuse',
- 'cmd': ['uv', 'run', '--with-requirements', '../../setup-requirements.txt', 'python', 'init.py'],
- 'description': 'LLM observability and prompt management (local)'
- }
- }
+ "langfuse": {
+ "path": "extras/langfuse",
+ "cmd": [
+ "uv",
+ "run",
+ "--with-requirements",
+ "../../setup-requirements.txt",
+ "python",
+ "init.py",
+ ],
+ "description": "LLM observability and prompt management (local)",
+ },
+ },
}
+
def discover_available_plugins():
"""
Discover plugins by scanning plugins directory.
@@ -75,11 +104,13 @@ def discover_available_plugins():
plugins_dir = Path("backends/advanced/src/advanced_omi_backend/plugins")
if not plugins_dir.exists():
- console.print(f"[yellow]Warning: Plugins directory not found: {plugins_dir}[/yellow]")
+ console.print(
+ f"[yellow]Warning: Plugins directory not found: {plugins_dir}[/yellow]"
+ )
return {}
discovered = {}
- skip_dirs = {'__pycache__', '__init__.py', 'base.py', 'router.py'}
+ skip_dirs = {"__pycache__", "__init__.py", "base.py", "router.py"}
for plugin_dir in plugins_dir.iterdir():
if not plugin_dir.is_dir() or plugin_dir.name in skip_dirs:
@@ -89,32 +120,37 @@ def discover_available_plugins():
setup_script = plugin_dir / "setup.py"
discovered[plugin_id] = {
- 'has_setup': setup_script.exists(),
- 'setup_path': setup_script if setup_script.exists() else None,
- 'dir': plugin_dir
+ "has_setup": setup_script.exists(),
+ "setup_path": setup_script if setup_script.exists() else None,
+ "dir": plugin_dir,
}
return discovered
+
def check_service_exists(service_name, service_config):
"""Check if service directory and script exist"""
- service_path = Path(service_config['path'])
+ service_path = Path(service_config["path"])
if not service_path.exists():
return False, f"Directory {service_path} does not exist"
# For services with Python init scripts, check if init.py exists
- if service_name in ['advanced', 'speaker-recognition', 'asr-services', 'langfuse']:
- script_path = service_path / 'init.py'
+ if service_name in ["advanced", "speaker-recognition", "asr-services", "langfuse"]:
+ script_path = service_path / "init.py"
if not script_path.exists():
return False, f"Script {script_path} does not exist"
else:
# For other extras, check if setup.sh exists
- script_path = service_path / 'setup.sh'
+ script_path = service_path / "setup.sh"
if not script_path.exists():
- return False, f"Script {script_path} does not exist (will be created in Phase 2)"
+ return (
+ False,
+ f"Script {script_path} does not exist (will be created in Phase 2)",
+ )
return True, "OK"
+
def select_services(transcription_provider=None):
"""Let user select which services to setup"""
console.print("π [bold cyan]Chronicle Service Setup[/bold cyan]")
@@ -125,24 +161,30 @@ def select_services(transcription_provider=None):
# Backend is required
console.print("π± [bold]Backend (Required):[/bold]")
console.print(" ✅ Advanced Backend - Full AI features")
- selected.append('advanced')
+ selected.append("advanced")
# Services that will be auto-added based on transcription provider choice
auto_added = set()
if transcription_provider in ("parakeet", "vibevoice", "qwen3-asr"):
- auto_added.add('asr-services')
+ auto_added.add("asr-services")
# Optional extras
console.print("\nπ§ [bold]Optional Services:[/bold]")
- for service_name, service_config in SERVICES['extras'].items():
+ for service_name, service_config in SERVICES["extras"].items():
# Skip services that will be auto-added based on earlier choices
if service_name in auto_added:
- provider_label = {"vibevoice": "VibeVoice", "parakeet": "Parakeet", "qwen3-asr": "Qwen3-ASR"}.get(transcription_provider, transcription_provider)
- console.print(f" ✅ {service_config['description']} ({provider_label}) [dim](auto-selected)[/dim]")
+ provider_label = {
+ "vibevoice": "VibeVoice",
+ "parakeet": "Parakeet",
+ "qwen3-asr": "Qwen3-ASR",
+ }.get(transcription_provider, transcription_provider)
+ console.print(
+ f" ✅ {service_config['description']} ({provider_label}) [dim](auto-selected)[/dim]"
+ )
continue
# LangFuse is handled separately via setup_langfuse_choice()
- if service_name == 'langfuse':
+ if service_name == "langfuse":
continue
# Check if service exists
@@ -152,10 +194,12 @@ def select_services(transcription_provider=None):
continue
# Speaker recognition is recommended by default
- default_enable = service_name == 'speaker-recognition'
+ default_enable = service_name == "speaker-recognition"
try:
- enable_service = Confirm.ask(f" Setup {service_config['description']}?", default=default_enable)
+ enable_service = Confirm.ask(
+ f" Setup {service_config['description']}?", default=default_enable
+ )
except EOFError:
console.print(f"Using default: {'Yes' if default_enable else 'No'}")
enable_service = default_enable
@@ -165,213 +209,268 @@ def select_services(transcription_provider=None):
return selected
+
def cleanup_unselected_services(selected_services):
"""Backup and remove .env files from services that weren't selected"""
-
- all_services = list(SERVICES['backend'].keys()) + list(SERVICES['extras'].keys())
-
+
+ all_services = list(SERVICES["backend"].keys()) + list(SERVICES["extras"].keys())
+
for service_name in all_services:
if service_name not in selected_services:
- if service_name == 'advanced':
- service_path = Path(SERVICES['backend'][service_name]['path'])
+ if service_name == "advanced":
+ service_path = Path(SERVICES["backend"][service_name]["path"])
else:
- service_path = Path(SERVICES['extras'][service_name]['path'])
-
- env_file = service_path / '.env'
+ service_path = Path(SERVICES["extras"][service_name]["path"])
+
+ env_file = service_path / ".env"
if env_file.exists():
# Create backup with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- backup_file = service_path / f'.env.backup.{timestamp}.unselected'
+ backup_file = service_path / f".env.backup.{timestamp}.unselected"
env_file.rename(backup_file)
- console.print(f"π§Ή [dim]Backed up {service_name} configuration to {backup_file.name} (service not selected)[/dim]")
-
-def run_service_setup(service_name, selected_services, https_enabled=False, server_ip=None,
- obsidian_enabled=False, neo4j_password=None, hf_token=None,
- transcription_provider='deepgram', admin_email=None, admin_password=None,
- langfuse_public_key=None, langfuse_secret_key=None, langfuse_host=None,
- streaming_provider=None):
+ console.print(
+ f"π§Ή [dim]Backed up {service_name} configuration to {backup_file.name} (service not selected)[/dim]"
+ )
+
+
+def run_service_setup(
+ service_name,
+ selected_services,
+ https_enabled=False,
+ server_ip=None,
+ obsidian_enabled=False,
+ neo4j_password=None,
+ hf_token=None,
+ transcription_provider="deepgram",
+ admin_email=None,
+ admin_password=None,
+ langfuse_public_key=None,
+ langfuse_secret_key=None,
+ langfuse_host=None,
+ streaming_provider=None,
+):
"""Execute individual service setup script"""
- if service_name == 'advanced':
- service = SERVICES['backend'][service_name]
+ if service_name == "advanced":
+ service = SERVICES["backend"][service_name]
# For advanced backend, pass URLs of other selected services and HTTPS config
- cmd = service['cmd'].copy()
- if 'speaker-recognition' in selected_services:
- cmd.extend(['--speaker-service-url', 'http://speaker-service:8085'])
- if 'asr-services' in selected_services:
- cmd.extend(['--parakeet-asr-url', 'http://host.docker.internal:8767'])
+ cmd = service["cmd"].copy()
+ if "speaker-recognition" in selected_services:
+ cmd.extend(["--speaker-service-url", "http://speaker-service:8085"])
+ if "asr-services" in selected_services:
+ cmd.extend(["--parakeet-asr-url", "http://host.docker.internal:8767"])
# Pass transcription provider choice from wizard
if transcription_provider:
- cmd.extend(['--transcription-provider', transcription_provider])
+ cmd.extend(["--transcription-provider", transcription_provider])
# Pass streaming provider (different from batch) for re-transcription setup
if streaming_provider:
- cmd.extend(['--streaming-provider', streaming_provider])
+ cmd.extend(["--streaming-provider", streaming_provider])
# Add HTTPS configuration
if https_enabled and server_ip:
- cmd.extend(['--enable-https', '--server-ip', server_ip])
+ cmd.extend(["--enable-https", "--server-ip", server_ip])
# Always pass Neo4j password (neo4j is a required service)
if neo4j_password:
- cmd.extend(['--neo4j-password', neo4j_password])
+ cmd.extend(["--neo4j-password", neo4j_password])
# Add Obsidian configuration
if obsidian_enabled:
- cmd.extend(['--enable-obsidian'])
+ cmd.extend(["--enable-obsidian"])
# Pass LangFuse keys from langfuse init or external config
if langfuse_public_key and langfuse_secret_key:
- cmd.extend(['--langfuse-public-key', langfuse_public_key])
- cmd.extend(['--langfuse-secret-key', langfuse_secret_key])
+ cmd.extend(["--langfuse-public-key", langfuse_public_key])
+ cmd.extend(["--langfuse-secret-key", langfuse_secret_key])
if langfuse_host:
- cmd.extend(['--langfuse-host', langfuse_host])
+ cmd.extend(["--langfuse-host", langfuse_host])
else:
- service = SERVICES['extras'][service_name]
- cmd = service['cmd'].copy()
-
+ service = SERVICES["extras"][service_name]
+ cmd = service["cmd"].copy()
+
# Add HTTPS configuration for services that support it
- if service_name == 'speaker-recognition' and https_enabled and server_ip:
- cmd.extend(['--enable-https', '--server-ip', server_ip])
+ if service_name == "speaker-recognition" and https_enabled and server_ip:
+ cmd.extend(["--enable-https", "--server-ip", server_ip])
# For speaker-recognition, pass HF_TOKEN from centralized configuration
- if service_name == 'speaker-recognition':
+ if service_name == "speaker-recognition":
# Define the speaker env path
- speaker_env_path = 'extras/speaker-recognition/.env'
+ speaker_env_path = "extras/speaker-recognition/.env"
# HF Token should have been provided via setup_hf_token_if_needed()
if hf_token:
- cmd.extend(['--hf-token', hf_token])
+ cmd.extend(["--hf-token", hf_token])
else:
- console.print("[yellow][WARNING][/yellow] No HF_TOKEN provided - speaker recognition may fail to download models")
+ console.print(
+ "[yellow][WARNING][/yellow] No HF_TOKEN provided - speaker recognition may fail to download models"
+ )
# Pass Deepgram API key from backend if available
- backend_env_path = 'backends/advanced/.env'
- deepgram_key = read_env_value(backend_env_path, 'DEEPGRAM_API_KEY')
- if deepgram_key and not is_placeholder(deepgram_key, 'your_deepgram_api_key_here', 'your-deepgram-api-key-here'):
- cmd.extend(['--deepgram-api-key', deepgram_key])
- console.print("[blue][INFO][/blue] Found existing DEEPGRAM_API_KEY from backend config, reusing")
+ backend_env_path = "backends/advanced/.env"
+ deepgram_key = read_env_value(backend_env_path, "DEEPGRAM_API_KEY")
+ if deepgram_key and not is_placeholder(
+ deepgram_key, "your_deepgram_api_key_here", "your-deepgram-api-key-here"
+ ):
+ cmd.extend(["--deepgram-api-key", deepgram_key])
+ console.print(
+ "[blue][INFO][/blue] Found existing DEEPGRAM_API_KEY from backend config, reusing"
+ )
# Pass compute mode from existing .env if available
- compute_mode = read_env_value(speaker_env_path, 'COMPUTE_MODE')
- if compute_mode in ['cpu', 'gpu']:
- cmd.extend(['--compute-mode', compute_mode])
- console.print(f"[blue][INFO][/blue] Found existing COMPUTE_MODE ({compute_mode}), reusing")
-
+ compute_mode = read_env_value(speaker_env_path, "COMPUTE_MODE")
+ if compute_mode in ["cpu", "gpu"]:
+ cmd.extend(["--compute-mode", compute_mode])
+ console.print(
+ f"[blue][INFO][/blue] Found existing COMPUTE_MODE ({compute_mode}), reusing"
+ )
+
# For asr-services, pass provider from wizard's transcription choice and reuse CUDA version
- if service_name == 'asr-services':
+ if service_name == "asr-services":
# Map wizard transcription provider to asr-services provider name
wizard_to_asr_provider = {
- 'vibevoice': 'vibevoice',
- 'parakeet': 'nemo',
- 'qwen3-asr': 'qwen3-asr',
+ "vibevoice": "vibevoice",
+ "parakeet": "nemo",
+ "qwen3-asr": "qwen3-asr",
}
asr_provider = wizard_to_asr_provider.get(transcription_provider)
if asr_provider:
- cmd.extend(['--provider', asr_provider])
- console.print(f"[blue][INFO][/blue] Pre-selecting ASR provider: {asr_provider} (from wizard choice: {transcription_provider})")
-
- speaker_env_path = 'extras/speaker-recognition/.env'
- cuda_version = read_env_value(speaker_env_path, 'PYTORCH_CUDA_VERSION')
- if cuda_version and cuda_version in ['cu121', 'cu126', 'cu128']:
- cmd.extend(['--pytorch-cuda-version', cuda_version])
- console.print(f"[blue][INFO][/blue] Found existing PYTORCH_CUDA_VERSION ({cuda_version}) from speaker-recognition, reusing")
+ cmd.extend(["--provider", asr_provider])
+ console.print(
+ f"[blue][INFO][/blue] Pre-selecting ASR provider: {asr_provider} (from wizard choice: {transcription_provider})"
+ )
+
+ speaker_env_path = "extras/speaker-recognition/.env"
+ cuda_version = read_env_value(speaker_env_path, "PYTORCH_CUDA_VERSION")
+ if cuda_version and cuda_version in ["cu121", "cu126", "cu128"]:
+ cmd.extend(["--pytorch-cuda-version", cuda_version])
+ console.print(
+ f"[blue][INFO][/blue] Found existing PYTORCH_CUDA_VERSION ({cuda_version}) from speaker-recognition, reusing"
+ )
# For langfuse, pass admin credentials from backend
- if service_name == 'langfuse':
+ if service_name == "langfuse":
if admin_email:
- cmd.extend(['--admin-email', admin_email])
+ cmd.extend(["--admin-email", admin_email])
if admin_password:
- cmd.extend(['--admin-password', admin_password])
+ cmd.extend(["--admin-password", admin_password])
# For openmemory-mcp, try to pass OpenAI API key from backend if available
- if service_name == 'openmemory-mcp':
- backend_env_path = 'backends/advanced/.env'
- openai_key = read_env_value(backend_env_path, 'OPENAI_API_KEY')
- if openai_key and not is_placeholder(openai_key, 'your_openai_api_key_here', 'your-openai-api-key-here', 'your_openai_key_here', 'your-openai-key-here'):
- cmd.extend(['--openai-api-key', openai_key])
- console.print("[blue][INFO][/blue] Found existing OPENAI_API_KEY from backend config, reusing")
-
+ if service_name == "openmemory-mcp":
+ backend_env_path = "backends/advanced/.env"
+ openai_key = read_env_value(backend_env_path, "OPENAI_API_KEY")
+ if openai_key and not is_placeholder(
+ openai_key,
+ "your_openai_api_key_here",
+ "your-openai-api-key-here",
+ "your_openai_key_here",
+ "your-openai-key-here",
+ ):
+ cmd.extend(["--openai-api-key", openai_key])
+ console.print(
+ "[blue][INFO][/blue] Found existing OPENAI_API_KEY from backend config, reusing"
+ )
+
console.print(f"\nπ§ [bold]Setting up {service_name}...[/bold]")
-
+
# Check if service exists before running
exists, msg = check_service_exists(service_name, service)
if not exists:
console.print(f"β {service_name} setup failed: {msg}")
return False
-
+
try:
result = subprocess.run(
- cmd,
- cwd=service['path'],
+ cmd,
+ cwd=service["path"],
check=True,
- timeout=300 # 5 minute timeout for service setup
+ timeout=300, # 5 minute timeout for service setup
)
-
+
console.print(f"✅ {service_name} setup completed")
return True
-
+
except FileNotFoundError as e:
console.print(f"β {service_name} setup failed: {e}")
- console.print(f"[yellow] Check that the service directory exists: {service['path']}[/yellow]")
- console.print(f"[yellow] And that 'uv' is installed and on your PATH[/yellow]")
+ console.print(
+ f"[yellow] Check that the service directory exists: {service['path']}[/yellow]"
+ )
+ console.print(
+ f"[yellow] And that 'uv' is installed and on your PATH[/yellow]"
+ )
return False
except subprocess.TimeoutExpired as e:
console.print(f"β {service_name} setup timed out after {e.timeout}s")
console.print(f"[yellow] Configuration may be partially written.[/yellow]")
console.print(f"[yellow] To retry just this service:[/yellow]")
- console.print(f"[yellow] cd {service['path']} && {' '.join(service['cmd'])}[/yellow]")
+ console.print(
+ f"[yellow] cd {service['path']} && {' '.join(service['cmd'])}[/yellow]"
+ )
return False
except subprocess.CalledProcessError as e:
console.print(f"β {service_name} setup failed with exit code {e.returncode}")
console.print(f"[yellow] Check the error output above for details.[/yellow]")
console.print(f"[yellow] To retry just this service:[/yellow]")
- console.print(f"[yellow] cd {service['path']} && {' '.join(service['cmd'])}[/yellow]")
+ console.print(
+ f"[yellow] cd {service['path']} && {' '.join(service['cmd'])}[/yellow]"
+ )
return False
except Exception as e:
console.print(f"β {service_name} setup failed: {e}")
return False
+
def show_service_status():
"""Show which services are available"""
console.print("\nπ [bold]Service Status:[/bold]")
-
+
# Check backend
- exists, msg = check_service_exists('advanced', SERVICES['backend']['advanced'])
+ exists, msg = check_service_exists("advanced", SERVICES["backend"]["advanced"])
status = "✅" if exists else "❌"
console.print(f" {status} Advanced Backend - {msg}")
-
+
# Check extras
- for service_name, service_config in SERVICES['extras'].items():
+ for service_name, service_config in SERVICES["extras"].items():
exists, msg = check_service_exists(service_name, service_config)
status = "✅" if exists else "⏸️"
console.print(f" {status} {service_config['description']} - {msg}")
+
def run_plugin_setup(plugin_id, plugin_info):
"""Run a plugin's setup.py script"""
- setup_path = plugin_info['setup_path']
+ setup_path = plugin_info["setup_path"]
try:
# Run plugin setup script interactively (don't capture output)
# This allows the plugin to prompt for user input
result = subprocess.run(
- ['uv', 'run', '--with-requirements', 'setup-requirements.txt', 'python', str(setup_path)],
- cwd=str(Path.cwd())
+ [
+ "uv",
+ "run",
+ "--with-requirements",
+ "setup-requirements.txt",
+ "python",
+ str(setup_path),
+ ],
+ cwd=str(Path.cwd()),
)
if result.returncode == 0:
console.print(f"\n[green]✅ {plugin_id} configured successfully[/green]")
return True
else:
-        console.print(f"\n[red]❌ {plugin_id} setup failed with exit code {result.returncode}[/red]")
+ console.print(
f"\n[red]❌ {plugin_id} setup failed with exit code {result.returncode}[/red]"
+ )
return False
except Exception as e:
console.print(f"[red]❌ Error running {plugin_id} setup: {e}[/red]")
return False
+
def setup_plugins():
"""Discover and setup plugins via delegation"""
console.print("\nπ [bold cyan]Plugin Configuration[/bold cyan]")
@@ -386,10 +485,7 @@ def setup_plugins():
# Ask about enabling community plugins
try:
- enable_plugins = Confirm.ask(
- "Enable community plugins?",
- default=True
- )
+ enable_plugins = Confirm.ask("Enable community plugins?", default=True)
except EOFError:
console.print("Using default: Yes")
enable_plugins = True
@@ -401,16 +497,15 @@ def setup_plugins():
# For each plugin with setup script
configured_count = 0
for plugin_id, plugin_info in available_plugins.items():
- if not plugin_info['has_setup']:
- console.print(f"[dim] {plugin_id}: No setup wizard available (configure manually)[/dim]")
+ if not plugin_info["has_setup"]:
+ console.print(
+ f"[dim] {plugin_id}: No setup wizard available (configure manually)[/dim]"
+ )
continue
# Ask if user wants to configure this plugin
try:
- configure = Confirm.ask(
- f" Configure {plugin_id} plugin?",
- default=False
- )
+ configure = Confirm.ask(f" Configure {plugin_id} plugin?", default=False)
except EOFError:
configure = False
@@ -423,36 +518,53 @@ def setup_plugins():
console.print(f"\n[green]✅ Configured {configured_count} plugin(s)[/green]")
+
def setup_git_hooks():
"""Setup pre-commit hooks for development"""
console.print("\n🔧 [bold]Setting up development environment...[/bold]")
+ # Check if git is available
+ if not shutil.which("git"):
+ console.print(
+ "⚠️ [yellow]git not found, skipping git hooks setup (optional)[/yellow]"
+ )
+ return
+
try:
- # Install pre-commit if not already installed
- subprocess.run(['pip', 'install', 'pre-commit'],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- check=False)
+ # Install pre-commit via uv tool (uv is our package manager)
+ subprocess.run(
+ ["uv", "tool", "install", "pre-commit"],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ check=False,
+ )
# Install git hooks
- result = subprocess.run(['pre-commit', 'install', '--hook-type', 'pre-push'],
- capture_output=True,
- text=True)
+ result = subprocess.run(
+ ["pre-commit", "install", "--hook-type", "pre-push"],
+ capture_output=True,
+ text=True,
+ )
if result.returncode == 0:
- console.print("✅ [green]Git hooks installed (tests will run before push)[/green]")
+ console.print(
+ "✅ [green]Git hooks installed (tests will run before push)[/green]"
+ )
else:
console.print("⚠️ [yellow]Could not install git hooks (optional)[/yellow]")
# Also install pre-commit hook
- subprocess.run(['pre-commit', 'install', '--hook-type', 'pre-commit'],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- check=False)
+ subprocess.run(
+ ["pre-commit", "install", "--hook-type", "pre-commit"],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ check=False,
+ )
except Exception as e:
console.print(f"⚠️ [yellow]Could not setup git hooks: {e} (optional)[/yellow]")
+
def setup_hf_token_if_needed(selected_services):
"""Prompt for Hugging Face token if needed by selected services.
@@ -463,40 +575,54 @@ def setup_hf_token_if_needed(selected_services):
HF_TOKEN string if provided, None otherwise
"""
# Check if any selected services need HF_TOKEN
- needs_hf_token = 'speaker-recognition' in selected_services
+ needs_hf_token = "speaker-recognition" in selected_services
if not needs_hf_token:
return None
console.print("\nπ€ [bold cyan]Hugging Face Token Configuration[/bold cyan]")
console.print("Required for speaker recognition (PyAnnote models)")
- console.print("\n[blue][INFO][/blue] Get your token from: https://huggingface.co/settings/tokens")
+ console.print(
+ "\n[blue][INFO][/blue] Get your token from: https://huggingface.co/settings/tokens"
+ )
console.print()
- console.print("[yellow]β οΈ You must also accept the model agreements for these gated models:[/yellow]")
+ console.print(
+ "[yellow]β οΈ You must also accept the model agreements for these gated models:[/yellow]"
+ )
console.print(" 1. [cyan]Speaker Diarization[/cyan]")
- console.print(" https://huggingface.co/pyannote/speaker-diarization-community-1")
+ console.print(
+ " https://huggingface.co/pyannote/speaker-diarization-community-1"
+ )
console.print(" 2. [cyan]Segmentation Model[/cyan]")
console.print(" https://huggingface.co/pyannote/segmentation-3.0")
console.print(" 3. [cyan]Segmentation Model[/cyan]")
console.print(" https://huggingface.co/pyannote/segmentation-3.1")
console.print(" 4. [cyan]Embedding Model[/cyan]")
- console.print(" https://huggingface.co/pyannote/wespeaker-voxceleb-resnet34-LM")
+ console.print(
+ " https://huggingface.co/pyannote/wespeaker-voxceleb-resnet34-LM"
+ )
console.print()
- console.print("[yellow]β[/yellow] Open each link and click 'Agree and access repository'")
+ console.print(
+ "[yellow]β[/yellow] Open each link and click 'Agree and access repository'"
+ )
console.print("[yellow]β[/yellow] Use the same Hugging Face account as your token")
console.print()
# Check for existing token from speaker-recognition service
- speaker_env_path = 'extras/speaker-recognition/.env'
- existing_token = read_env_value(speaker_env_path, 'HF_TOKEN')
+ speaker_env_path = "extras/speaker-recognition/.env"
+ existing_token = read_env_value(speaker_env_path, "HF_TOKEN")
# Use the masked prompt function
hf_token = prompt_with_existing_masked(
prompt_text="Hugging Face Token",
existing_value=existing_token,
- placeholders=['your_huggingface_token_here', 'your-huggingface-token-here', 'hf_xxxxx'],
+ placeholders=[
+ "your_huggingface_token_here",
+ "your-huggingface-token-here",
+ "hf_xxxxx",
+ ],
is_password=True,
- default=""
+ default="",
)
if hf_token:
@@ -504,9 +630,12 @@ def setup_hf_token_if_needed(selected_services):
console.print(f"[green]✅ HF_TOKEN configured: {masked}[/green]\n")
return hf_token
else:
- console.print("[yellow]⚠️ No HF_TOKEN provided - speaker recognition may fail[/yellow]\n")
+ console.print(
+ "[yellow]⚠️ No HF_TOKEN provided - speaker recognition may fail[/yellow]\n"
+ )
return None
+
def setup_config_file():
"""Setup config/config.yml from template if it doesn't exist"""
config_file = Path("config/config.yml")
@@ -519,9 +648,14 @@ def setup_config_file():
shutil.copy(config_template, config_file)
console.print("✅ [green]Created config/config.yml from template[/green]")
else:
- console.print("⚠️ [yellow]config/config.yml.template not found, skipping config setup[/yellow]")
+ console.print(
+ "⚠️ [yellow]config/config.yml.template not found, skipping config setup[/yellow]"
+ )
else:
- console.print("ℹ️ [blue]config/config.yml already exists, keeping existing configuration[/blue]")
+ console.print(
+ "ℹ️ [blue]config/config.yml already exists, keeping existing configuration[/blue]"
+ )
+
# Providers that support real-time streaming
STREAMING_CAPABLE = {"deepgram", "smallest", "qwen3-asr"}
@@ -530,8 +664,12 @@ def setup_config_file():
def select_transcription_provider():
"""Ask user which transcription provider they want (batch/primary)."""
console.print("\nπ€ [bold cyan]Transcription Provider[/bold cyan]")
- console.print("Choose your speech-to-text provider (used for [bold]batch[/bold]/high-quality transcription):")
- console.print("[dim]If it also supports streaming, it will be used for real-time too by default.[/dim]")
+ console.print(
+ "Choose your speech-to-text provider (used for [bold]batch[/bold]/high-quality transcription):"
+ )
+ console.print(
+ "[dim]If it also supports streaming, it will be used for real-time too by default.[/dim]"
+ )
console.print()
choices = {
@@ -540,7 +678,7 @@ def select_transcription_provider():
"3": "VibeVoice ASR (offline, batch only, built-in diarization, GPU)",
"4": "Qwen3-ASR (offline, streaming + batch, 52 languages, GPU)",
"5": "Smallest.ai Pulse (cloud, streaming + batch)",
- "6": "None (skip transcription setup)"
+ "6": "None (skip transcription setup)",
}
for key, desc in choices.items():
@@ -563,7 +701,9 @@ def select_transcription_provider():
return "smallest"
elif choice == "6":
return "none"
- console.print(f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]")
+ console.print(
+ f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]"
+ )
except EOFError:
console.print("Using default: Deepgram")
return "deepgram"
@@ -586,7 +726,9 @@ def select_streaming_provider(batch_provider):
console.print(f"\nπ [bold cyan]Streaming[/bold cyan]")
console.print(f"{batch_provider} supports both batch and streaming.")
try:
- use_different = Confirm.ask("Use a different provider for real-time streaming?", default=False)
+ use_different = Confirm.ask(
+ "Use a different provider for real-time streaming?", default=False
+ )
except EOFError:
return None
if not use_different:
@@ -594,7 +736,9 @@ def select_streaming_provider(batch_provider):
else:
# Batch-only provider β need to pick a streaming provider
console.print(f"\nπ [bold cyan]Streaming[/bold cyan]")
- console.print(f"{batch_provider} is batch-only. Pick a streaming provider for real-time transcription:")
+ console.print(
+ f"{batch_provider} is batch-only. Pick a streaming provider for real-time transcription:"
+ )
# Show streaming-capable providers (excluding the batch provider)
streaming_choices = {}
@@ -625,9 +769,13 @@ def select_streaming_provider(batch_provider):
if choice in streaming_choices:
result = provider_map[choice]
if result:
- console.print(f"[green]✅[/green] Streaming: {result}, Batch: {batch_provider}")
+ console.print(
+ f"[green]✅[/green] Streaming: {result}, Batch: {batch_provider}"
+ )
return result
- console.print(f"[red]Invalid choice. Please select from {list(streaming_choices.keys())}[/red]")
+ console.print(
+ f"[red]Invalid choice. Please select from {list(streaming_choices.keys())}[/red]"
+ )
except EOFError:
return None
@@ -648,31 +796,36 @@ def setup_langfuse_choice():
console.print()
try:
- has_existing = Confirm.ask("Use an existing external LangFuse instance instead of local?", default=False)
+ has_existing = Confirm.ask(
+ "Use an existing external LangFuse instance instead of local?",
+ default=False,
+ )
except EOFError:
console.print("Using default: No (will set up locally)")
has_existing = False
if not has_existing:
# Check if the local langfuse directory exists
- exists, msg = check_service_exists('langfuse', SERVICES['extras']['langfuse'])
+ exists, msg = check_service_exists("langfuse", SERVICES["extras"]["langfuse"])
if exists:
console.print("[green]✅[/green] Will set up local LangFuse instance")
- return 'local', {}
+ return "local", {}
else:
console.print(f"[yellow]⚠️ Local LangFuse not available: {msg}[/yellow]")
- console.print("[yellow] Will proceed without LangFuse β add it later when available[/yellow]")
- return 'local', {}
+ console.print(
+ "[yellow] Will proceed without LangFuse β add it later when available[/yellow]"
+ )
+ return "local", {}
# External LangFuse β collect connection details
console.print()
console.print("[bold]Enter your external LangFuse connection details:[/bold]")
- backend_env_path = 'backends/advanced/.env'
+ backend_env_path = "backends/advanced/.env"
- existing_host = read_env_value(backend_env_path, 'LANGFUSE_HOST')
+ existing_host = read_env_value(backend_env_path, "LANGFUSE_HOST")
# Don't treat the local docker host as an existing external value
- if existing_host and 'langfuse-web' in existing_host:
+ if existing_host and "langfuse-web" in existing_host:
existing_host = None
host = prompt_with_existing_masked(
@@ -680,36 +833,38 @@ def setup_langfuse_choice():
existing_value=existing_host,
placeholders=[""],
is_password=False,
- default="https://cloud.langfuse.com"
+ default="https://cloud.langfuse.com",
)
- existing_pub = read_env_value(backend_env_path, 'LANGFUSE_PUBLIC_KEY')
+ existing_pub = read_env_value(backend_env_path, "LANGFUSE_PUBLIC_KEY")
public_key = prompt_with_existing_masked(
prompt_text="LangFuse public key",
existing_value=existing_pub,
placeholders=[""],
is_password=False,
- default=""
+ default="",
)
- existing_sec = read_env_value(backend_env_path, 'LANGFUSE_SECRET_KEY')
+ existing_sec = read_env_value(backend_env_path, "LANGFUSE_SECRET_KEY")
secret_key = prompt_with_existing_masked(
prompt_text="LangFuse secret key",
existing_value=existing_sec,
placeholders=[""],
is_password=True,
- default=""
+ default="",
)
if not (host and public_key and secret_key):
- console.print("[yellow]⚠️ Incomplete LangFuse configuration — skipping[/yellow]")
+ console.print(
+ "[yellow]⚠️ Incomplete LangFuse configuration — skipping[/yellow]"
+ )
return None, {}
console.print(f"[green]✅[/green] External LangFuse configured: {host}")
- return 'external', {
- 'host': host,
- 'public_key': public_key,
- 'secret_key': secret_key,
+ return "external", {
+ "host": host,
+ "public_key": public_key,
+ "secret_key": secret_key,
}
@@ -717,8 +872,12 @@ def main():
"""Main orchestration logic"""
console.print("π [bold green]Welcome to Chronicle![/bold green]\n")
console.print("[dim]This wizard is safe to run as many times as you like.[/dim]")
- console.print("[dim]It backs up your existing config and preserves previously entered values.[/dim]")
- console.print("[dim]When unsure, just press Enter β the defaults will work.[/dim]\n")
+ console.print(
+ "[dim]It backs up your existing config and preserves previously entered values.[/dim]"
+ )
+ console.print(
+ "[dim]When unsure, just press Enter β the defaults will work.[/dim]\n"
+ )
# Setup config file from template
setup_config_file()
@@ -740,11 +899,19 @@ def main():
# Auto-add asr-services if any local ASR was chosen (batch or streaming)
local_asr_providers = ("parakeet", "vibevoice", "qwen3-asr")
- needs_asr = transcription_provider in local_asr_providers or (streaming_provider and streaming_provider in local_asr_providers)
- if needs_asr and 'asr-services' not in selected_services:
- reason = transcription_provider if transcription_provider in local_asr_providers else streaming_provider
- console.print(f"[blue][INFO][/blue] Auto-adding ASR services for {reason} transcription")
- selected_services.append('asr-services')
+ needs_asr = transcription_provider in local_asr_providers or (
+ streaming_provider and streaming_provider in local_asr_providers
+ )
+ if needs_asr and "asr-services" not in selected_services:
+ reason = (
+ transcription_provider
+ if transcription_provider in local_asr_providers
+ else streaming_provider
+ )
+ console.print(
+ f"[blue][INFO][/blue] Auto-adding ASR services for {reason} transcription"
+ )
+ selected_services.append("asr-services")
if not selected_services:
console.print("\n[yellow]No services selected. Exiting.[/yellow]")
@@ -752,8 +919,8 @@ def main():
# LangFuse Configuration (before service setup so keys can be passed to backend)
langfuse_mode, langfuse_external = setup_langfuse_choice()
- if langfuse_mode == 'local' and 'langfuse' not in selected_services:
- selected_services.append('langfuse')
+ if langfuse_mode == "local" and "langfuse" not in selected_services:
+ selected_services.append("langfuse")
# HF Token Configuration (if services require it)
hf_token = setup_hf_token_if_needed(selected_services)
@@ -761,17 +928,24 @@ def main():
# HTTPS Configuration (for services that need it)
https_enabled = False
server_ip = None
-
+
# Check if we have services that benefit from HTTPS
- https_services = {'advanced', 'speaker-recognition'} # advanced will always need https then
+ https_services = {
+ "advanced",
+ "speaker-recognition",
+ } # advanced will always need https then
needs_https = bool(https_services.intersection(selected_services))
-
+
if needs_https:
console.print("\nπ [bold cyan]HTTPS Configuration[/bold cyan]")
- console.print("HTTPS enables microphone access in browsers and secure connections")
+ console.print(
+ "HTTPS enables microphone access in browsers and secure connections"
+ )
try:
- https_enabled = Confirm.ask("Enable HTTPS for selected services?", default=False)
+ https_enabled = Confirm.ask(
+ "Enable HTTPS for selected services?", default=False
+ )
except EOFError:
console.print("Using default: No")
https_enabled = False
@@ -781,24 +955,30 @@ def main():
ts_dns, ts_ip = detect_tailscale_info()
if ts_dns:
- console.print(f"\n[green][AUTO-DETECTED][/green] Tailscale DNS: {ts_dns}")
+ console.print(
+ f"\n[green][AUTO-DETECTED][/green] Tailscale DNS: {ts_dns}"
+ )
if ts_ip:
- console.print(f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}")
+ console.print(
+ f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}"
+ )
default_address = ts_dns
elif ts_ip:
console.print(f"\n[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}")
default_address = ts_ip
else:
console.print("\n[blue][INFO][/blue] Tailscale not detected")
- console.print("[blue][INFO][/blue] To find your Tailscale address: tailscale status --json | jq -r '.Self.DNSName'")
+ console.print(
+ "[blue][INFO][/blue] To find your Tailscale address: tailscale status --json | jq -r '.Self.DNSName'"
+ )
default_address = None
console.print("[blue][INFO][/blue] For local-only access, use 'localhost'")
console.print("Examples: localhost, myhost.tail1234.ts.net, 100.64.1.2")
# Check for existing SERVER_IP from backend .env
- backend_env_path = 'backends/advanced/.env'
- existing_ip = read_env_value(backend_env_path, 'SERVER_IP')
+ backend_env_path = "backends/advanced/.env"
+ existing_ip = read_env_value(backend_env_path, "SERVER_IP")
# Use existing value, or auto-detected address, or localhost as default
effective_default = default_address or "localhost"
@@ -806,9 +986,9 @@ def main():
server_ip = prompt_with_existing_masked(
prompt_text="Server IP/Domain for SSL certificates",
existing_value=existing_ip,
- placeholders=['localhost', 'your-server-ip-here'],
+ placeholders=["localhost", "your-server-ip-here"],
is_password=False,
- default=effective_default
+ default=effective_default,
)
console.print(f"[green]✅[/green] HTTPS configured for: {server_ip}")
@@ -817,14 +997,18 @@ def main():
neo4j_password = None
obsidian_enabled = False
- if 'advanced' in selected_services:
+ if "advanced" in selected_services:
console.print("\nποΈ [bold cyan]Neo4j Configuration[/bold cyan]")
- console.print("Neo4j is used for Knowledge Graph (entity/relationship extraction from conversations)")
+ console.print(
+ "Neo4j is used for Knowledge Graph (entity/relationship extraction from conversations)"
+ )
console.print()
# Always prompt for Neo4j password (masked input)
try:
- console.print("Neo4j password (min 8 chars) [leave empty for default: neo4jpassword]")
+ console.print(
+ "Neo4j password (min 8 chars) [leave empty for default: neo4jpassword]"
+ )
neo4j_password = prompt_password("Neo4j password", min_length=8)
except (EOFError, KeyboardInterrupt):
neo4j_password = "neo4jpassword"
@@ -836,11 +1020,15 @@ def main():
# Obsidian is optional (graph-based knowledge management for vault notes)
console.print("\nποΈ [bold cyan]Obsidian Integration (Optional)[/bold cyan]")
- console.print("Enable graph-based knowledge management for Obsidian vault notes")
+ console.print(
+ "Enable graph-based knowledge management for Obsidian vault notes"
+ )
console.print()
try:
- obsidian_enabled = Confirm.ask("Enable Obsidian integration?", default=False)
+ obsidian_enabled = Confirm.ask(
+ "Enable Obsidian integration?", default=False
+ )
except EOFError:
console.print("Using default: No")
obsidian_enabled = False
@@ -858,40 +1046,59 @@ def main():
failed_services = []
# Pre-populate langfuse keys from external config (if user chose external mode)
- langfuse_public_key = langfuse_external.get('public_key')
- langfuse_secret_key = langfuse_external.get('secret_key')
- langfuse_host = langfuse_external.get('host') # None for local (backend defaults to langfuse-web)
+ langfuse_public_key = langfuse_external.get("public_key")
+ langfuse_secret_key = langfuse_external.get("secret_key")
+ langfuse_host = langfuse_external.get(
+ "host"
+ ) # None for local (backend defaults to langfuse-web)
# Determine setup order: langfuse first (to get API keys), then backend (with langfuse keys), then others
setup_order = []
- if 'langfuse' in selected_services:
- setup_order.append('langfuse')
- if 'advanced' in selected_services:
- setup_order.append('advanced')
+ if "langfuse" in selected_services:
+ setup_order.append("langfuse")
+ if "advanced" in selected_services:
+ setup_order.append("advanced")
for service in selected_services:
if service not in setup_order:
setup_order.append(service)
# Read admin credentials from existing backend .env (for langfuse init reuse)
- backend_env_path = 'backends/advanced/.env'
- wizard_admin_email = read_env_value(backend_env_path, 'ADMIN_EMAIL')
- wizard_admin_password = read_env_value(backend_env_path, 'ADMIN_PASSWORD')
+ backend_env_path = "backends/advanced/.env"
+ wizard_admin_email = read_env_value(backend_env_path, "ADMIN_EMAIL")
+ wizard_admin_password = read_env_value(backend_env_path, "ADMIN_PASSWORD")
for service in setup_order:
- if run_service_setup(service, selected_services, https_enabled, server_ip,
- obsidian_enabled, neo4j_password, hf_token, transcription_provider,
- admin_email=wizard_admin_email, admin_password=wizard_admin_password,
- langfuse_public_key=langfuse_public_key, langfuse_secret_key=langfuse_secret_key,
- langfuse_host=langfuse_host, streaming_provider=streaming_provider):
+ if run_service_setup(
+ service,
+ selected_services,
+ https_enabled,
+ server_ip,
+ obsidian_enabled,
+ neo4j_password,
+ hf_token,
+ transcription_provider,
+ admin_email=wizard_admin_email,
+ admin_password=wizard_admin_password,
+ langfuse_public_key=langfuse_public_key,
+ langfuse_secret_key=langfuse_secret_key,
+ langfuse_host=langfuse_host,
+ streaming_provider=streaming_provider,
+ ):
success_count += 1
# After local langfuse setup, read generated API keys for backend
- if service == 'langfuse':
- langfuse_env_path = 'extras/langfuse/.env'
- langfuse_public_key = read_env_value(langfuse_env_path, 'LANGFUSE_INIT_PROJECT_PUBLIC_KEY')
- langfuse_secret_key = read_env_value(langfuse_env_path, 'LANGFUSE_INIT_PROJECT_SECRET_KEY')
+ if service == "langfuse":
+ langfuse_env_path = "extras/langfuse/.env"
+ langfuse_public_key = read_env_value(
+ langfuse_env_path, "LANGFUSE_INIT_PROJECT_PUBLIC_KEY"
+ )
+ langfuse_secret_key = read_env_value(
+ langfuse_env_path, "LANGFUSE_INIT_PROJECT_SECRET_KEY"
+ )
if langfuse_public_key and langfuse_secret_key:
- console.print("[blue][INFO][/blue] LangFuse API keys will be passed to backend configuration")
+ console.print(
+ "[blue][INFO][/blue] LangFuse API keys will be passed to backend configuration"
+ )
else:
failed_services.append(service)
@@ -902,11 +1109,13 @@ def main():
# Final Summary
console.print(f"\nπ [bold green]Setup Complete![/bold green]")
- console.print(f"✅ {success_count}/{len(selected_services)} services configured successfully")
+ console.print(
+ f"✅ {success_count}/{len(selected_services)} services configured successfully"
+ )
if failed_services:
console.print(f"❌ Failed services: {', '.join(failed_services)}")
-
+
# Next Steps
console.print("\nπ [bold]Next Steps:[/bold]")
@@ -914,68 +1123,101 @@ def main():
console.print("")
console.print("π [bold cyan]Configuration Files Updated:[/bold cyan]")
console.print(" β’ [green].env files[/green] - API keys and service URLs")
- console.print(" β’ [green]config.yml[/green] - Model definitions and memory provider settings")
+ console.print(
+ " β’ [green]config.yml[/green] - Model definitions and memory provider settings"
+ )
console.print("")
# Development Environment Setup
console.print("1. Setup development environment (git hooks, testing):")
console.print(" [cyan]make setup-dev[/cyan]")
- console.print(" [dim]This installs pre-commit hooks to run tests before pushing[/dim]")
+ console.print(
+ " [dim]This installs pre-commit hooks to run tests before pushing[/dim]"
+ )
console.print("")
# Service Management Commands
console.print("2. Start all configured services:")
console.print(" [cyan]./start.sh[/cyan]")
- console.print(" [dim]Or: uv run --with-requirements setup-requirements.txt python services.py start --all --build[/dim]")
+ console.print(
+ " [dim]Or: uv run --with-requirements setup-requirements.txt python services.py start --all --build[/dim]"
+ )
console.print("")
console.print("3. Or start individual services:")
-
+
configured_services = []
- if 'advanced' in selected_services and 'advanced' not in failed_services:
+ if "advanced" in selected_services and "advanced" not in failed_services:
configured_services.append("backend")
- if 'speaker-recognition' in selected_services and 'speaker-recognition' not in failed_services:
- configured_services.append("speaker-recognition")
- if 'asr-services' in selected_services and 'asr-services' not in failed_services:
+ if (
+ "speaker-recognition" in selected_services
+ and "speaker-recognition" not in failed_services
+ ):
+ configured_services.append("speaker-recognition")
+ if "asr-services" in selected_services and "asr-services" not in failed_services:
configured_services.append("asr-services")
- if 'openmemory-mcp' in selected_services and 'openmemory-mcp' not in failed_services:
+ if (
+ "openmemory-mcp" in selected_services
+ and "openmemory-mcp" not in failed_services
+ ):
configured_services.append("openmemory-mcp")
- if 'langfuse' in selected_services and 'langfuse' not in failed_services:
+ if "langfuse" in selected_services and "langfuse" not in failed_services:
configured_services.append("langfuse")
# LangFuse prompt management info
- if langfuse_mode == 'local' and 'langfuse' not in failed_services:
+ if langfuse_mode == "local" and "langfuse" not in failed_services:
console.print("")
- console.print("[bold cyan]Prompt Management:[/bold cyan] Once services are running, edit AI prompts at:")
+ console.print(
+ "[bold cyan]Prompt Management:[/bold cyan] Once services are running, edit AI prompts at:"
+ )
if https_enabled and server_ip:
- console.print(f" [link=https://{server_ip}:3443/project/chronicle/prompts]https://{server_ip}:3443/project/chronicle/prompts[/link]")
+ console.print(
+ f" [link=https://{server_ip}:3443/project/chronicle/prompts]https://{server_ip}:3443/project/chronicle/prompts[/link]"
+ )
else:
- console.print(" [link=http://localhost:3002/project/chronicle/prompts]http://localhost:3002/project/chronicle/prompts[/link]")
- elif langfuse_mode == 'external' and langfuse_host:
+ console.print(
+ " [link=http://localhost:3002/project/chronicle/prompts]http://localhost:3002/project/chronicle/prompts[/link]"
+ )
+ elif langfuse_mode == "external" and langfuse_host:
console.print("")
- console.print(f"[bold cyan]Prompt Management:[/bold cyan] Edit AI prompts at your LangFuse instance:")
+ console.print(
+ f"[bold cyan]Prompt Management:[/bold cyan] Edit AI prompts at your LangFuse instance:"
+ )
console.print(f" {langfuse_host}")
if configured_services:
service_list = " ".join(configured_services)
- console.print(f" [cyan]uv run --with-requirements setup-requirements.txt python services.py start {service_list}[/cyan]")
-
+ console.print(
+ f" [cyan]uv run --with-requirements setup-requirements.txt python services.py start {service_list}[/cyan]"
+ )
+
console.print("")
console.print("3. Check service status:")
console.print(" [cyan]./status.sh[/cyan]")
- console.print(" [dim]Or: uv run --with-requirements setup-requirements.txt python services.py status[/dim]")
+ console.print(
+ " [dim]Or: uv run --with-requirements setup-requirements.txt python services.py status[/dim]"
+ )
console.print("")
console.print("4. Stop services when done:")
console.print(" [cyan]./stop.sh[/cyan]")
- console.print(" [dim]Or: uv run --with-requirements setup-requirements.txt python services.py stop --all[/dim]")
-
+ console.print(
+ " [dim]Or: uv run --with-requirements setup-requirements.txt python services.py stop --all[/dim]"
+ )
+
console.print(f"\nπ [bold]Enjoy Chronicle![/bold]")
-
+
# Show individual service usage
console.print(f"\nπ‘ [dim]Tip: You can also setup services individually:[/dim]")
- console.print(f"[dim] cd backends/advanced && uv run --with-requirements ../../setup-requirements.txt python init.py[/dim]")
- console.print(f"[dim] cd extras/speaker-recognition && uv run --with-requirements ../../setup-requirements.txt python init.py[/dim]")
- console.print(f"[dim] cd extras/asr-services && uv run --with-requirements ../../setup-requirements.txt python init.py[/dim]")
+ console.print(
+ f"[dim] cd backends/advanced && uv run --with-requirements ../../setup-requirements.txt python init.py[/dim]"
+ )
+ console.print(
+ f"[dim] cd extras/speaker-recognition && uv run --with-requirements ../../setup-requirements.txt python init.py[/dim]"
+ )
+ console.print(
+ f"[dim] cd extras/asr-services && uv run --with-requirements ../../setup-requirements.txt python init.py[/dim]"
+ )
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/wizard.sh b/wizard.sh
index 02942349..fb749554 100755
--- a/wizard.sh
+++ b/wizard.sh
@@ -1 +1,3 @@
+#!/bin/bash
+source "$(dirname "$0")/scripts/check_uv.sh"
uv run --with-requirements setup-requirements.txt wizard.py
From 9fe9794011ac9c46ab9a4ca0430eca3070d7e681 Mon Sep 17 00:00:00 2001
From: Ankush Malaker <43288948+AnkushMalaker@users.noreply.github.com>
Date: Sun, 22 Feb 2026 07:45:11 +0000
Subject: [PATCH 2/5] Update button event handling and plugin architecture
- Changed button state terminology from `SINGLE_TAP` and `DOUBLE_TAP` to `SINGLE_PRESS` and `DOUBLE_PRESS` across various files, including documentation and code implementations.
- Enhanced the `send_button_event` method to reflect the updated button state values, ensuring consistency in event handling.
- Introduced new methods for managing button events in the plugin architecture, improving the overall interaction with device buttons.
- Updated tests to align with the new button state definitions, ensuring robust coverage for the updated functionality.
---
.../advanced/Docs/plugin-development-guide.md | 4 +-
.../clients/audio_stream_client.py | 4 +-
.../controllers/conversation_controller.py | 555 +++++++++++-------
.../controllers/websocket_controller.py | 501 +++++++---------
.../observability/otel_setup.py | 58 +-
.../src/advanced_omi_backend/plugins/base.py | 27 +-
.../advanced_omi_backend/plugins/events.py | 38 +-
.../workers/conversation_jobs.py | 4 +-
.../workers/memory_jobs.py | 6 +-
extras/friend-lite-sdk/friend_lite/button.py | 4 +-
plugins/homeassistant/command_parser.py | 22 +-
plugins/homeassistant/plugin.py | 254 ++++++--
.../websocket_streaming_tests.robot | 2 +-
tests/libs/audio_stream_library.py | 4 +-
tests/resources/websocket_keywords.robot | 4 +-
15 files changed, 870 insertions(+), 617 deletions(-)
diff --git a/backends/advanced/Docs/plugin-development-guide.md b/backends/advanced/Docs/plugin-development-guide.md
index d5ddf3fa..32fa700d 100644
--- a/backends/advanced/Docs/plugin-development-guide.md
+++ b/backends/advanced/Docs/plugin-development-guide.md
@@ -206,7 +206,7 @@ async def on_memory_processed(self, context: PluginContext):
**When**: OMI device button is pressed
**Context Data**:
-- `state` (str): Button state (`SINGLE_TAP`, `DOUBLE_TAP`)
+- `state` (str): Button state (`SINGLE_PRESS`, `DOUBLE_PRESS`)
- `timestamp` (float): Unix timestamp of the event
- `audio_uuid` (str): Current audio session UUID (may be None)
- `session_id` (str): Streaming session ID (for conversation close)
@@ -222,7 +222,7 @@ friend-lite-sdk (extras/friend-lite-sdk/)
↓ parse_button_event() converts payload → ButtonState IntEnum
↓
BLE Client (extras/local-wearable-client/ or mobile app)
- ↓ Formats as Wyoming protocol: {"type": "button-event", "data": {"state": "SINGLE_TAP"}}
+ ↓ Formats as Wyoming protocol: {"type": "button-event", "data": {"state": "SINGLE_PRESS"}}
↓ Sends over WebSocket
↓
Backend (websocket_controller.py)
diff --git a/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py b/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py
index 38fc2333..b374d8fe 100644
--- a/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py
+++ b/backends/advanced/src/advanced_omi_backend/clients/audio_stream_client.py
@@ -626,13 +626,13 @@ async def _close_abruptly():
return total_chunks
def send_button_event(
- self, stream_id: str, button_state: str = "SINGLE_TAP"
+ self, stream_id: str, button_state: str = "SINGLE_PRESS"
) -> None:
"""Send a button event to an open stream.
Args:
stream_id: Stream session ID
- button_state: Button state ("SINGLE_TAP" or "DOUBLE_TAP")
+ button_state: Button state ("SINGLE_PRESS" or "DOUBLE_PRESS")
"""
session = self._sessions.get(stream_id)
if not session:
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py
index 0fcb9d9c..1bf41dfc 100644
--- a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py
+++ b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py
@@ -17,6 +17,7 @@
client_belongs_to_user,
get_client_manager,
)
+from advanced_omi_backend.config import get_transcription_job_timeout
from advanced_omi_backend.config_loader import get_service_config
from advanced_omi_backend.controllers.queue_controller import (
JOB_RESULT_TTL,
@@ -32,15 +33,14 @@
from advanced_omi_backend.models.conversation import Conversation
from advanced_omi_backend.models.job import JobPriority
from advanced_omi_backend.plugins.events import ConversationCloseReason, PluginEvent
+from advanced_omi_backend.services.memory import get_memory_service
from advanced_omi_backend.users import User
from advanced_omi_backend.workers.conversation_jobs import generate_title_summary_job
-from advanced_omi_backend.services.memory import get_memory_service
from advanced_omi_backend.workers.memory_jobs import (
enqueue_memory_processing,
process_memory_job,
)
from advanced_omi_backend.workers.speaker_jobs import recognise_speakers_job
-from advanced_omi_backend.config import get_transcription_job_timeout
logger = logging.getLogger(__name__)
audio_logger = logging.getLogger("audio_processing")
@@ -73,7 +73,7 @@ async def close_current_conversation(client_id: str, user: User):
status_code=404,
)
- session_id = getattr(client_state, 'stream_session_id', None)
+ session_id = getattr(client_state, "stream_session_id", None)
if not session_id:
return JSONResponse(
content={"error": "No active session"},
@@ -96,7 +96,9 @@ async def close_current_conversation(client_id: str, user: User):
status_code=404,
)
- logger.info(f"Conversation close requested for client {client_id} by user {user.user_id}")
+ logger.info(
+ f"Conversation close requested for client {client_id} by user {user.user_id}"
+ )
return JSONResponse(
content={
@@ -111,9 +113,13 @@ async def get_conversation(conversation_id: str, user: User):
"""Get a single conversation with full transcript details."""
try:
# Find the conversation using Beanie
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
# Check ownership for non-admin users
if not user.is_superuser and conversation.user_id != str(user.user_id):
@@ -127,15 +133,23 @@ async def get_conversation(conversation_id: str, user: User):
"audio_chunks_count": conversation.audio_chunks_count,
"audio_total_duration": conversation.audio_total_duration,
"audio_compression_ratio": conversation.audio_compression_ratio,
- "created_at": conversation.created_at.isoformat() if conversation.created_at else None,
+ "created_at": (
+ conversation.created_at.isoformat() if conversation.created_at else None
+ ),
"deleted": conversation.deleted,
"deletion_reason": conversation.deletion_reason,
- "deleted_at": conversation.deleted_at.isoformat() if conversation.deleted_at else None,
+ "deleted_at": (
+ conversation.deleted_at.isoformat() if conversation.deleted_at else None
+ ),
"processing_status": conversation.processing_status,
"always_persist": conversation.always_persist,
- "end_reason": conversation.end_reason.value if conversation.end_reason else None,
+ "end_reason": (
+ conversation.end_reason.value if conversation.end_reason else None
+ ),
"completed_at": (
- conversation.completed_at.isoformat() if conversation.completed_at else None
+ conversation.completed_at.isoformat()
+ if conversation.completed_at
+ else None
),
"title": conversation.title,
"summary": conversation.summary,
@@ -153,14 +167,18 @@ async def get_conversation(conversation_id: str, user: User):
"active_transcript_version_number": conversation.active_transcript_version_number,
"active_memory_version_number": conversation.active_memory_version_number,
"starred": conversation.starred,
- "starred_at": conversation.starred_at.isoformat() if conversation.starred_at else None,
+ "starred_at": (
+ conversation.starred_at.isoformat() if conversation.starred_at else None
+ ),
}
return {"conversation": response}
except Exception as e:
logger.error(f"Error fetching conversation {conversation_id}: {e}")
- return JSONResponse(status_code=500, content={"error": "Error fetching conversation"})
+ return JSONResponse(
+ status_code=500, content={"error": "Error fetching conversation"}
+ )
async def get_conversation_memories(conversation_id: str, user: User, limit: int = 100):
@@ -170,7 +188,9 @@ async def get_conversation_memories(conversation_id: str, user: User, limit: int
Conversation.conversation_id == conversation_id
)
if not conversation:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
if not user.is_superuser and conversation.user_id != str(user.user_id):
return JSONResponse(status_code=403, content={"error": "Access forbidden"})
@@ -364,21 +384,29 @@ async def get_conversations(
if include_unprocessed:
# Orphan type 1: always_persist stuck in pending/failed (not deleted)
- conditions.append({
- "always_persist": True,
- "processing_status": {"$in": ["pending_transcription", "transcription_failed"]},
- "deleted": False,
- })
+ conditions.append(
+ {
+ "always_persist": True,
+ "processing_status": {
+ "$in": ["pending_transcription", "transcription_failed"]
+ },
+ "deleted": False,
+ }
+ )
# Orphan type 2: soft-deleted due to no speech but have audio data
- conditions.append({
- "deleted": True,
- "deletion_reason": {"$in": [
- "no_meaningful_speech",
- "audio_file_not_ready",
- "no_meaningful_speech_batch_transcription",
- ]},
- "audio_chunks_count": {"$gt": 0},
- })
+ conditions.append(
+ {
+ "deleted": True,
+ "deletion_reason": {
+ "$in": [
+ "no_meaningful_speech",
+ "audio_file_not_ready",
+ "no_meaningful_speech_batch_transcription",
+ ]
+ },
+ "audio_chunks_count": {"$gt": 0},
+ }
+ )
# Assemble final query
if len(conditions) == 1:
@@ -406,12 +434,14 @@ async def get_conversations(
conv_id = doc.get("conversation_id")
is_orphan_type1 = (
doc.get("always_persist")
- and doc.get("processing_status") in ("pending_transcription", "transcription_failed")
+ and doc.get("processing_status")
+ in ("pending_transcription", "transcription_failed")
and not doc.get("deleted")
)
is_orphan_type2 = (
doc.get("deleted")
- and doc.get("deletion_reason") in (
+ and doc.get("deletion_reason")
+ in (
"no_meaningful_speech",
"audio_file_not_ready",
"no_meaningful_speech_batch_transcription",
@@ -437,7 +467,9 @@ async def get_conversations(
except Exception as e:
logger.exception(f"Error fetching conversations: {e}")
- return JSONResponse(status_code=500, content={"error": "Error fetching conversations"})
+ return JSONResponse(
+ status_code=500, content={"error": "Error fetching conversations"}
+ )
async def search_conversations(
@@ -513,10 +545,14 @@ async def search_conversations(
except Exception as e:
logger.exception(f"Error searching conversations: {e}")
- return JSONResponse(status_code=500, content={"error": "Error searching conversations"})
+ return JSONResponse(
+ status_code=500, content={"error": "Error searching conversations"}
+ )
-async def _soft_delete_conversation(conversation: Conversation, user: User) -> JSONResponse:
+async def _soft_delete_conversation(
+ conversation: Conversation, user: User
+) -> JSONResponse:
"""Mark conversation and chunks as deleted (soft delete).
Chunks are soft-deleted first so that a crash between the two writes
@@ -533,7 +569,9 @@ async def _soft_delete_conversation(conversation: Conversation, user: User) -> J
).update_many({"$set": {"deleted": True, "deleted_at": deleted_at}})
deleted_chunks = result.modified_count
- logger.info(f"Soft deleted {deleted_chunks} audio chunks for conversation {conversation_id}")
+ logger.info(
+ f"Soft deleted {deleted_chunks} audio chunks for conversation {conversation_id}"
+ )
# 2. Mark conversation as deleted
conversation.deleted = True
@@ -561,7 +599,9 @@ async def _soft_delete_conversation(conversation: Conversation, user: User) -> J
"deleted_chunks": deleted_chunks,
"conversation_id": conversation_id,
"client_id": conversation.client_id,
- "deleted_at": conversation.deleted_at.isoformat() if conversation.deleted_at else None,
+ "deleted_at": (
+ conversation.deleted_at.isoformat() if conversation.deleted_at else None
+ ),
},
)
@@ -582,7 +622,9 @@ async def _hard_delete_conversation(conversation: Conversation) -> JSONResponse:
).delete()
deleted_chunks = result.deleted_count
- logger.info(f"Hard deleted {deleted_chunks} audio chunks for conversation {conversation_id}")
+ logger.info(
+ f"Hard deleted {deleted_chunks} audio chunks for conversation {conversation_id}"
+ )
# 2. Delete conversation document
try:
@@ -607,7 +649,9 @@ async def _hard_delete_conversation(conversation: Conversation) -> JSONResponse:
)
-async def delete_conversation(conversation_id: str, user: User, permanent: bool = False):
+async def delete_conversation(
+ conversation_id: str, user: User, permanent: bool = False
+):
"""
Soft delete a conversation (mark as deleted but keep data).
@@ -628,11 +672,14 @@ async def delete_conversation(conversation_id: str, user: User, permanent: bool
)
# Find the conversation using Beanie
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
return JSONResponse(
- status_code=404, content={"error": f"Conversation '{conversation_id}' not found"}
+ status_code=404,
+ content={"error": f"Conversation '{conversation_id}' not found"},
)
# Check ownership for non-admin users
@@ -658,7 +705,8 @@ async def delete_conversation(conversation_id: str, user: User, permanent: bool
except Exception as e:
logger.error(f"Error deleting conversation {conversation_id}: {e}")
return JSONResponse(
- status_code=500, content={"error": f"Failed to delete conversation: {str(e)}"}
+ status_code=500,
+ content={"error": f"Failed to delete conversation: {str(e)}"},
)
@@ -671,17 +719,23 @@ async def restore_conversation(conversation_id: str, user: User) -> JSONResponse
user: Requesting user
"""
try:
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
# Permission check
if not user.is_superuser and conversation.user_id != str(user.user_id):
return JSONResponse(status_code=403, content={"error": "Access denied"})
if not conversation.deleted:
- return JSONResponse(status_code=400, content={"error": "Conversation is not deleted"})
+ return JSONResponse(
+ status_code=400, content={"error": "Conversation is not deleted"}
+ )
# 1. Restore audio chunks FIRST (safe failure mode: restored chunks, conversation still deleted)
original_deleted_at = conversation.deleted_at
@@ -707,7 +761,9 @@ async def restore_conversation(conversation_id: str, user: User) -> JSONResponse
await AudioChunkDocument.find(
AudioChunkDocument.conversation_id == conversation_id,
AudioChunkDocument.deleted == False,
- ).update_many({"$set": {"deleted": True, "deleted_at": original_deleted_at}})
+ ).update_many(
+ {"$set": {"deleted": True, "deleted_at": original_deleted_at}}
+ )
raise
logger.info(
@@ -727,16 +783,163 @@ async def restore_conversation(conversation_id: str, user: User) -> JSONResponse
except Exception as e:
logger.error(f"Error restoring conversation {conversation_id}: {e}")
return JSONResponse(
- status_code=500, content={"error": f"Failed to restore conversation: {str(e)}"}
+ status_code=500,
+ content={"error": f"Failed to restore conversation: {str(e)}"},
+ )
+
+
+def _enqueue_transcript_reprocessing(
+ conversation_id: str,
+ user_id: str,
+ source: str,
+ job_id_prefix: str,
+ end_reason: str,
+) -> tuple:
+ """Enqueue transcribe job + post-conversation chain.
+
+ Returns (version_id, transcript_job, post_jobs dict).
+ """
+ from advanced_omi_backend.workers.transcription_jobs import (
+ transcribe_full_audio_job,
+ )
+
+ version_id = str(uuid.uuid4())
+
+ transcript_job = transcription_queue.enqueue(
+ transcribe_full_audio_job,
+ conversation_id,
+ version_id,
+ source,
+ job_timeout=get_transcription_job_timeout(),
+ result_ttl=JOB_RESULT_TTL,
+ job_id=f"{job_id_prefix}_{conversation_id[:8]}",
+ description=f"Transcribe audio for {conversation_id[:8]}",
+ meta={"conversation_id": conversation_id},
+ )
+
+ post_jobs = start_post_conversation_jobs(
+ conversation_id=conversation_id,
+ user_id=user_id,
+ transcript_version_id=version_id,
+ depends_on_job=transcript_job,
+ end_reason=end_reason,
+ )
+
+ return version_id, transcript_job, post_jobs
+
+
+def _resolve_transcript_version(conversation: Conversation, version_id: str) -> tuple:
+ """Resolve 'active' to real version ID and find the version object.
+
+ Returns (error_response_or_None, resolved_version_id, version_object).
+ If error_response is not None, the caller should return it immediately.
+ """
+ resolved_id = version_id
+ if resolved_id == "active":
+ active_id = conversation.active_transcript_version
+ if not active_id:
+ return (
+ JSONResponse(
+ status_code=404,
+ content={"error": "No active transcript version found"},
+ ),
+ None,
+ None,
+ )
+ resolved_id = active_id
+
+ version_obj = None
+ for v in conversation.transcript_versions:
+ if v.version_id == resolved_id:
+ version_obj = v
+ break
+
+ if not version_obj:
+ return (
+ JSONResponse(
+ status_code=404,
+ content={"error": f"Transcript version '{resolved_id}' not found"},
+ ),
+ None,
+ None,
)
+ return None, resolved_id, version_obj
+
+
+def _enqueue_speaker_reprocessing_chain(
+ conversation_id: str,
+ version_id: str,
+ source_version_id: str,
+) -> dict:
+ """Enqueue speaker -> memory -> title_summary chain.
+
+ Returns dict with keys: speaker, memory, title_summary (job IDs).
+ """
+ speaker_job = transcription_queue.enqueue(
+ recognise_speakers_job,
+ conversation_id,
+ version_id,
+ job_timeout=1200,
+ result_ttl=JOB_RESULT_TTL,
+ job_id=f"reprocess_speaker_{conversation_id[:12]}",
+ description=f"Re-diarize speakers for {conversation_id[:8]}",
+ meta={
+ "conversation_id": conversation_id,
+ "version_id": version_id,
+ "source_version_id": source_version_id,
+ "trigger": "reprocess",
+ },
+ )
+ logger.info(
+ f"Enqueued speaker reprocessing job {speaker_job.id} for version {version_id}"
+ )
+
+ memory_job = memory_queue.enqueue(
+ process_memory_job,
+ conversation_id,
+ depends_on=speaker_job,
+ job_timeout=1800,
+ result_ttl=JOB_RESULT_TTL,
+ job_id=f"memory_{conversation_id[:12]}",
+ description=f"Extract memories for {conversation_id[:8]}",
+ meta={"conversation_id": conversation_id, "trigger": "reprocess_after_speaker"},
+ )
+ logger.info(
+ f"Chained memory job {memory_job.id} after speaker job {speaker_job.id}"
+ )
+
+ title_summary_job = default_queue.enqueue(
+ generate_title_summary_job,
+ conversation_id,
+ job_timeout=300,
+ result_ttl=JOB_RESULT_TTL,
+ depends_on=memory_job,
+ job_id=f"title_summary_{conversation_id[:12]}",
+ description=f"Regenerate title/summary for {conversation_id[:8]}",
+ meta={"conversation_id": conversation_id, "trigger": "reprocess_after_speaker"},
+ )
+ logger.info(
+ f"Chained title/summary job {title_summary_job.id} after memory job {memory_job.id}"
+ )
+
+ return {
+ "speaker": speaker_job.id,
+ "memory": memory_job.id,
+ "title_summary": title_summary_job.id,
+ }
+
async def toggle_star(conversation_id: str, user: User):
"""Toggle the starred/favorite status of a conversation."""
try:
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
if not user.is_superuser and conversation.user_id != str(user.user_id):
return JSONResponse(status_code=403, content={"error": "Access forbidden"})
@@ -763,7 +966,11 @@ async def toggle_star(conversation_id: str, user: User):
data={
"conversation_id": conversation_id,
"starred": conversation.starred,
- "starred_at": conversation.starred_at.isoformat() if conversation.starred_at else None,
+ "starred_at": (
+ conversation.starred_at.isoformat()
+ if conversation.starred_at
+ else None
+ ),
"title": conversation.title,
},
)
@@ -773,7 +980,9 @@ async def toggle_star(conversation_id: str, user: User):
return {
"conversation_id": conversation_id,
"starred": conversation.starred,
- "starred_at": conversation.starred_at.isoformat() if conversation.starred_at else None,
+ "starred_at": (
+ conversation.starred_at.isoformat() if conversation.starred_at else None
+ ),
}
except Exception as e:
@@ -784,9 +993,13 @@ async def toggle_star(conversation_id: str, user: User):
async def reprocess_orphan(conversation_id: str, user: User):
"""Reprocess an orphan audio session - restore if deleted and enqueue full processing chain."""
try:
- conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id)
+ conversation = await Conversation.find_one(
+ Conversation.conversation_id == conversation_id
+ )
if not conversation:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
# Check ownership
if not user.is_superuser and conversation.user_id != str(user.user_id):
@@ -821,33 +1034,12 @@ async def reprocess_orphan(conversation_id: str, user: User):
conversation.detailed_summary = None
await conversation.save()
- # Create new transcript version ID
- version_id = str(uuid.uuid4())
-
- # Enqueue the same 4-job chain as reprocess_transcript
- from advanced_omi_backend.workers.transcription_jobs import (
- transcribe_full_audio_job,
- )
-
- # Job 1: Transcribe audio
- transcript_job = transcription_queue.enqueue(
- transcribe_full_audio_job,
- conversation_id,
- version_id,
- "reprocess_orphan",
- job_timeout=get_transcription_job_timeout(),
- result_ttl=JOB_RESULT_TTL,
- job_id=f"orphan_transcribe_{conversation_id[:8]}",
- description=f"Transcribe orphan audio for {conversation_id[:8]}",
- meta={"conversation_id": conversation_id},
- )
-
- # Chain post-transcription jobs (speaker recognition → memory → title/summary → event dispatch)
- post_jobs = start_post_conversation_jobs(
+ # Enqueue the same job chain as reprocess_transcript
+ version_id, transcript_job, post_jobs = _enqueue_transcript_reprocessing(
conversation_id=conversation_id,
user_id=str(user.user_id),
- transcript_version_id=version_id,
- depends_on_job=transcript_job,
+ source="reprocess_orphan",
+ job_id_prefix="orphan_transcribe",
end_reason="reprocess_orphan",
)
@@ -881,7 +1073,9 @@ async def reprocess_transcript(conversation_id: str, user: User):
Conversation.conversation_id == conversation_id
)
if not conversation_model:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
# Check ownership for non-admin users
if not user.is_superuser and conversation_model.user_id != str(user.user_id):
@@ -907,34 +1101,12 @@ async def reprocess_transcript(conversation_id: str, user: User):
},
)
- # Create new transcript version ID
- version_id = str(uuid.uuid4())
-
- # Enqueue job chain with RQ (transcription -> speaker recognition -> memory)
- from advanced_omi_backend.workers.transcription_jobs import (
- transcribe_full_audio_job,
- )
-
- # Job 1: Transcribe audio to text (reconstructs from MongoDB chunks)
- transcript_job = transcription_queue.enqueue(
- transcribe_full_audio_job,
- conversation_id,
- version_id,
- "reprocess",
- job_timeout=get_transcription_job_timeout(),
- result_ttl=JOB_RESULT_TTL,
- job_id=f"reprocess_{conversation_id[:8]}",
- description=f"Transcribe audio for {conversation_id[:8]}",
- meta={"conversation_id": conversation_id},
- )
- logger.info(f"π₯ RQ: Enqueued transcription job {transcript_job.id}")
-
- # Chain post-transcription jobs (speaker recognition → memory → title/summary → event dispatch)
- post_jobs = start_post_conversation_jobs(
+ # Enqueue transcription + post-conversation job chain
+ version_id, transcript_job, post_jobs = _enqueue_transcript_reprocessing(
conversation_id=conversation_id,
user_id=str(user.user_id),
- transcript_version_id=version_id,
- depends_on_job=transcript_job,
+ source="reprocess",
+ job_id_prefix="reprocess",
end_reason="reprocess_transcript",
)
@@ -960,7 +1132,9 @@ async def reprocess_transcript(conversation_id: str, user: User):
)
-async def reprocess_memory(conversation_id: str, transcript_version_id: str, user: User):
+async def reprocess_memory(
+ conversation_id: str, transcript_version_id: str, user: User
+):
"""Reprocess memory extraction for a specific transcript version. Users can only reprocess their own conversations."""
try:
# Find the conversation using Beanie
@@ -968,7 +1142,9 @@ async def reprocess_memory(conversation_id: str, transcript_version_id: str, use
Conversation.conversation_id == conversation_id
)
if not conversation_model:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
# Check ownership for non-admin users
if not user.is_superuser and conversation_model.user_id != str(user.user_id):
@@ -979,28 +1155,12 @@ async def reprocess_memory(conversation_id: str, transcript_version_id: str, use
},
)
- # Resolve transcript version ID
- # Handle special "active" version ID
- if transcript_version_id == "active":
- active_version_id = conversation_model.active_transcript_version
- if not active_version_id:
- return JSONResponse(
- status_code=404, content={"error": "No active transcript version found"}
- )
- transcript_version_id = active_version_id
-
- # Find the specific transcript version
- transcript_version = None
- for version in conversation_model.transcript_versions:
- if version.version_id == transcript_version_id:
- transcript_version = version
- break
-
- if not transcript_version:
- return JSONResponse(
- status_code=404,
- content={"error": f"Transcript version '{transcript_version_id}' not found"},
- )
+ # Resolve transcript version ID (handle "active" special case)
+ error, transcript_version_id, transcript_version = _resolve_transcript_version(
+ conversation_model, transcript_version_id
+ )
+ if error:
+ return error
# Create new memory version ID
version_id = str(uuid.uuid4())
@@ -1033,7 +1193,9 @@ async def reprocess_memory(conversation_id: str, transcript_version_id: str, use
)
-async def reprocess_speakers(conversation_id: str, transcript_version_id: str, user: User):
+async def reprocess_speakers(
+ conversation_id: str, transcript_version_id: str, user: User
+):
"""
Reprocess speaker identification for a specific transcript version.
Users can only reprocess their own conversations.
@@ -1047,7 +1209,9 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u
Conversation.conversation_id == conversation_id
)
if not conversation_model:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
# Check ownership for non-admin users
if not user.is_superuser and conversation_model.user_id != str(user.user_id):
@@ -1058,28 +1222,12 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u
},
)
- # 2. Resolve source transcript version ID (handle "active" special case)
- source_version_id = transcript_version_id
- if source_version_id == "active":
- active_version_id = conversation_model.active_transcript_version
- if not active_version_id:
- return JSONResponse(
- status_code=404, content={"error": "No active transcript version found"}
- )
- source_version_id = active_version_id
-
- # 3. Find and validate the source transcript version
- source_version = None
- for version in conversation_model.transcript_versions:
- if version.version_id == source_version_id:
- source_version = version
- break
-
- if not source_version:
- return JSONResponse(
- status_code=404,
- content={"error": f"Transcript version '{source_version_id}' not found"},
- )
+ # 2-3. Resolve source transcript version ID and find version object
+ error, source_version_id, source_version = _resolve_transcript_version(
+ conversation_model, transcript_version_id
+ )
+ if error:
+ return error
# 4. Validate transcript has content and words (or provider-diarized segments)
if not source_version.transcript:
@@ -1169,72 +1317,22 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u
f"for conversation {conversation_id}"
)
- # 7. Enqueue speaker recognition job with NEW version_id
- speaker_job = transcription_queue.enqueue(
- recognise_speakers_job,
+ # 7-8. Enqueue speaker → memory → title/summary chain
+ job_ids = _enqueue_speaker_reprocessing_chain(
conversation_id,
- new_version_id, # NEW version (not source)
- job_timeout=1200, # 20 minutes
- result_ttl=JOB_RESULT_TTL,
- job_id=f"reprocess_speaker_{conversation_id[:12]}",
- description=f"Re-diarize speakers for {conversation_id[:8]}",
- meta={
- "conversation_id": conversation_id,
- "version_id": new_version_id,
- "source_version_id": source_version_id,
- "trigger": "reprocess",
- },
- )
-
- logger.info(
- f"Enqueued speaker reprocessing job {speaker_job.id} "
- f"for new version {new_version_id}"
- )
-
- # 8. Chain memory reprocessing (speaker changes affect memory context)
- memory_job = memory_queue.enqueue(
- process_memory_job,
- conversation_id,
- depends_on=speaker_job,
- job_timeout=1800, # 30 minutes
- result_ttl=JOB_RESULT_TTL,
- job_id=f"memory_{conversation_id[:12]}",
- description=f"Extract memories for {conversation_id[:8]}",
- meta={"conversation_id": conversation_id, "trigger": "reprocess_after_speaker"},
- )
-
- logger.info(
- f"Chained memory reprocessing job {memory_job.id} "
- f"after speaker job {speaker_job.id}"
- )
-
- # 8b. Chain title/summary regeneration after memory job
- # Depends on memory_job to avoid race condition (both save conversation document)
- # and to ensure fresh memories are available for context-enriched summaries
- title_summary_job = default_queue.enqueue(
- generate_title_summary_job,
- conversation_id,
- job_timeout=300,
- result_ttl=JOB_RESULT_TTL,
- depends_on=memory_job,
- job_id=f"title_summary_{conversation_id[:12]}",
- description=f"Regenerate title/summary for {conversation_id[:8]}",
- meta={"conversation_id": conversation_id, "trigger": "reprocess_after_speaker"},
- )
-
- logger.info(
- f"Chained title/summary job {title_summary_job.id} " f"after memory job {memory_job.id}"
+ new_version_id,
+ source_version_id,
)
# 9. Return job information
return JSONResponse(
content={
"message": "Speaker reprocessing started",
- "job_id": speaker_job.id,
- "memory_job_id": memory_job.id,
- "title_summary_job_id": title_summary_job.id,
- "version_id": new_version_id, # NEW version ID
- "source_version_id": source_version_id, # Original version used as source
+ "job_id": job_ids["speaker"],
+ "memory_job_id": job_ids["memory"],
+ "title_summary_job_id": job_ids["title_summary"],
+ "version_id": new_version_id,
+ "source_version_id": source_version_id,
"status": "queued",
}
)
@@ -1246,7 +1344,9 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u
)
-async def activate_transcript_version(conversation_id: str, version_id: str, user: User):
+async def activate_transcript_version(
+ conversation_id: str, version_id: str, user: User
+):
"""Activate a specific transcript version. Users can only modify their own conversations."""
try:
# Find the conversation using Beanie
@@ -1254,20 +1354,25 @@ async def activate_transcript_version(conversation_id: str, version_id: str, use
Conversation.conversation_id == conversation_id
)
if not conversation_model:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
# Check ownership for non-admin users
if not user.is_superuser and conversation_model.user_id != str(user.user_id):
return JSONResponse(
status_code=403,
- content={"error": "Access forbidden. You can only modify your own conversations."},
+ content={
+ "error": "Access forbidden. You can only modify your own conversations."
+ },
)
# Activate the transcript version using Beanie model method
success = conversation_model.set_active_transcript_version(version_id)
if not success:
return JSONResponse(
- status_code=400, content={"error": "Failed to activate transcript version"}
+ status_code=400,
+ content={"error": "Failed to activate transcript version"},
)
await conversation_model.save()
@@ -1301,13 +1406,17 @@ async def activate_memory_version(conversation_id: str, version_id: str, user: U
Conversation.conversation_id == conversation_id
)
if not conversation_model:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
# Check ownership for non-admin users
if not user.is_superuser and conversation_model.user_id != str(user.user_id):
return JSONResponse(
status_code=403,
- content={"error": "Access forbidden. You can only modify your own conversations."},
+ content={
+ "error": "Access forbidden. You can only modify your own conversations."
+ },
)
# Activate the memory version using Beanie model method
@@ -1332,7 +1441,9 @@ async def activate_memory_version(conversation_id: str, version_id: str, user: U
except Exception as e:
logger.error(f"Error activating memory version: {e}")
- return JSONResponse(status_code=500, content={"error": "Error activating memory version"})
+ return JSONResponse(
+ status_code=500, content={"error": "Error activating memory version"}
+ )
async def get_conversation_version_history(conversation_id: str, user: User):
@@ -1343,13 +1454,17 @@ async def get_conversation_version_history(conversation_id: str, user: User):
Conversation.conversation_id == conversation_id
)
if not conversation_model:
- return JSONResponse(status_code=404, content={"error": "Conversation not found"})
+ return JSONResponse(
+ status_code=404, content={"error": "Conversation not found"}
+ )
# Check ownership for non-admin users
if not user.is_superuser and conversation_model.user_id != str(user.user_id):
return JSONResponse(
status_code=403,
- content={"error": "Access forbidden. You can only access your own conversations."},
+ content={
+ "error": "Access forbidden. You can only access your own conversations."
+ },
)
# Get version history from model
@@ -1380,4 +1495,6 @@ async def get_conversation_version_history(conversation_id: str, user: User):
except Exception as e:
logger.error(f"Error fetching version history: {e}")
- return JSONResponse(status_code=500, content={"error": "Error fetching version history"})
+ return JSONResponse(
+ status_code=500, content={"error": "Error fetching version history"}
+ )
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py
index d7fc12e9..bab956ec 100644
--- a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py
+++ b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py
@@ -1020,7 +1020,7 @@ async def _handle_button_event(
Args:
client_state: Client state object
- button_state: Button state string (e.g., "SINGLE_TAP", "DOUBLE_TAP")
+ button_state: Button state string (e.g., "SINGLE_PRESS", "DOUBLE_PRESS")
user_id: User ID
client_id: Client ID
"""
@@ -1073,99 +1073,143 @@ async def _handle_button_event(
)
-async def _process_rolling_batch(
- client_state, user_id: str, user_email: str, client_id: str, batch_number: int
-) -> None:
- """
- Process accumulated batch audio as a rolling segment.
-
- Creates conversation titled "Recording Part {batch_number}" and enqueues transcription.
+async def _create_batch_conversation_and_enqueue(
+ client_state,
+ user_id: str,
+ client_id: str,
+ title: str,
+ trigger: str,
+ job_id_prefix: str,
+ enqueue_post_jobs: bool = False,
+ attach_markers: bool = False,
+) -> Optional[str]:
+ """Create conversation from batch audio, store chunks, enqueue transcription.
Args:
client_state: Client state with batch_audio_chunks
user_id: User ID
- user_email: User email
client_id: Client ID
- batch_number: Sequential batch number (1, 2, 3...)
- """
- if (
- not hasattr(client_state, "batch_audio_chunks")
- or not client_state.batch_audio_chunks
- ):
- application_logger.warning(f"β οΈ No audio chunks to process for rolling batch")
- return
+ title: Conversation title
+ trigger: Trigger string for transcription job
+ job_id_prefix: Prefix for the transcription job ID
+ enqueue_post_jobs: If True, chain post-conversation jobs after transcription
+ attach_markers: If True, copy client_state.markers to conversation
- try:
- from advanced_omi_backend.models.conversation import create_conversation
- from advanced_omi_backend.utils.audio_chunk_utils import convert_audio_to_chunks
+ Returns:
+ conversation_id on success, None on failure.
+ """
+ from advanced_omi_backend.config import get_transcription_job_timeout
+ from advanced_omi_backend.controllers.queue_controller import (
+ JOB_RESULT_TTL,
+ transcription_queue,
+ )
+ from advanced_omi_backend.models.conversation import create_conversation
+ from advanced_omi_backend.utils.audio_chunk_utils import convert_audio_to_chunks
+ from advanced_omi_backend.workers.transcription_jobs import (
+ transcribe_full_audio_job,
+ )
- # Combine chunks
- complete_audio = b"".join(client_state.batch_audio_chunks)
- application_logger.info(
- f"π¦ Rolling batch #{batch_number}: Combined {len(client_state.batch_audio_chunks)} chunks "
- f"into {len(complete_audio)} bytes"
- )
+ complete_audio = b"".join(client_state.batch_audio_chunks)
+ audio_format = getattr(client_state, "batch_audio_format", {})
+ sample_rate = audio_format.get("rate", 16000)
+ sample_width = audio_format.get("width", 2)
+ channels = audio_format.get("channels", 1)
- # Get audio format
- audio_format = getattr(client_state, "batch_audio_format", {})
- sample_rate = audio_format.get("rate", 16000)
- width = audio_format.get("width", 2)
- channels = audio_format.get("channels", 1)
+ application_logger.info(
+ f"π¦ Batch: Combined {len(client_state.batch_audio_chunks)} chunks "
+ f"into {len(complete_audio)} bytes (title={title})"
+ )
- # Create conversation with batch number in title
- conversation = create_conversation(
- user_id=user_id,
- client_id=client_id,
- title=f"Recording Part {batch_number}",
- summary="Rolling batch processing...",
- )
- await conversation.insert()
- conversation_id = conversation.conversation_id # Get the auto-generated ID
+ # Create conversation
+ conversation = create_conversation(
+ user_id=user_id,
+ client_id=client_id,
+ title=title,
+ summary="Processing batch audio...",
+ )
+ if attach_markers and client_state.markers:
+ conversation.markers = list(client_state.markers)
+ client_state.markers.clear()
+ await conversation.insert()
+ conversation_id = conversation.conversation_id
- # Convert to MongoDB chunks
+ # Convert audio to MongoDB chunks
+ try:
num_chunks = await convert_audio_to_chunks(
conversation_id=conversation_id,
audio_data=complete_audio,
sample_rate=sample_rate,
channels=channels,
- sample_width=width,
+ sample_width=sample_width,
)
-
- # Enqueue transcription job
- from advanced_omi_backend.controllers.queue_controller import (
- JOB_RESULT_TTL,
- transcription_queue,
+ application_logger.info(
+ f"π¦ Batch: Converted to {num_chunks} MongoDB chunks ({conversation_id[:12]})"
)
- from advanced_omi_backend.workers.transcription_jobs import (
- transcribe_full_audio_job,
+ except Exception as chunk_error:
+ application_logger.error(
+ f"Failed to convert batch audio to chunks: {chunk_error}", exc_info=True
)
- version_id = str(uuid.uuid4())
- transcribe_job_id = f"transcribe_rolling_{conversation_id[:12]}_{batch_number}"
-
- from advanced_omi_backend.config import get_transcription_job_timeout
-
- transcription_job = transcription_queue.enqueue(
- transcribe_full_audio_job,
- conversation_id,
- version_id,
- f"rolling_batch_{batch_number}", # trigger
- job_timeout=get_transcription_job_timeout(),
- result_ttl=JOB_RESULT_TTL,
- job_id=transcribe_job_id,
- description=f"Transcribe rolling batch #{batch_number} {conversation_id[:8]}",
- meta={
- "conversation_id": conversation_id,
- "client_id": client_id,
- "batch_number": batch_number,
- },
+ # Enqueue transcription job
+ version_id = str(uuid.uuid4())
+ transcription_job = transcription_queue.enqueue(
+ transcribe_full_audio_job,
+ conversation_id,
+ version_id,
+ trigger,
+ job_timeout=get_transcription_job_timeout(),
+ result_ttl=JOB_RESULT_TTL,
+ job_id=f"{job_id_prefix}_{conversation_id[:12]}",
+ description=f"Transcribe {title.lower()} {conversation_id[:8]}",
+ meta={"conversation_id": conversation_id, "client_id": client_id},
+ )
+
+ application_logger.info(
+ f"π₯ Batch: Enqueued transcription job {transcription_job.id}"
+ )
+
+ # Optionally chain post-conversation jobs
+ if enqueue_post_jobs:
+ from advanced_omi_backend.controllers.queue_controller import (
+ start_post_conversation_jobs,
)
+ job_ids = start_post_conversation_jobs(
+ conversation_id=conversation_id,
+ user_id=None,
+ depends_on_job=transcription_job,
+ client_id=client_id,
+ )
application_logger.info(
-            f"✅ Rolling batch #{batch_number} created conversation {conversation_id}, "
- f"enqueued transcription job {transcription_job.id}"
+            f"✅ Batch: Enqueued job chain for {conversation_id} → "
+            f"transcription ({transcription_job.id}) → "
+            f"speaker ({job_ids['speaker_recognition']}) → "
+            f"memory ({job_ids['memory']})"
)
+ return conversation_id
+
+
+async def _process_rolling_batch(
+ client_state, user_id: str, user_email: str, client_id: str, batch_number: int
+) -> None:
+ """Process accumulated batch audio as a rolling segment."""
+ if (
+ not hasattr(client_state, "batch_audio_chunks")
+ or not client_state.batch_audio_chunks
+ ):
+ application_logger.warning(f"β οΈ No audio chunks to process for rolling batch")
+ return
+
+ try:
+ await _create_batch_conversation_and_enqueue(
+ client_state,
+ user_id=user_id,
+ client_id=client_id,
+ title=f"Recording Part {batch_number}",
+ trigger=f"rolling_batch_{batch_number}",
+ job_id_prefix=f"transcribe_rolling_{batch_number}",
+ )
except Exception as e:
application_logger.error(
f"β Failed to process rolling batch #{batch_number}: {e}", exc_info=True
@@ -1175,15 +1219,7 @@ async def _process_rolling_batch(
async def _process_batch_audio_complete(
client_state, user_id: str, user_email: str, client_id: str
) -> None:
- """
- Process completed batch audio: write file, create conversation, enqueue jobs.
-
- Args:
- client_state: Client state with batch_audio_chunks
- user_id: User ID
- user_email: User email
- client_id: Client ID
- """
+ """Process completed batch audio: create conversation, enqueue full job chain."""
if (
not hasattr(client_state, "batch_audio_chunks")
or not client_state.batch_audio_chunks
@@ -1194,117 +1230,17 @@ async def _process_batch_audio_complete(
return
try:
- from advanced_omi_backend.models.conversation import create_conversation
- from advanced_omi_backend.utils.audio_chunk_utils import convert_audio_to_chunks
-
- # Combine all chunks
- complete_audio = b"".join(client_state.batch_audio_chunks)
- application_logger.info(
- f"π¦ Batch mode: Combined {len(client_state.batch_audio_chunks)} chunks into {len(complete_audio)} bytes"
- )
-
- # Timestamp for logging
- timestamp = int(time.time() * 1000)
-
- # Get audio format from batch metadata (set during audio-start)
- audio_format = getattr(client_state, "batch_audio_format", {})
- sample_rate = audio_format.get("rate", OMI_SAMPLE_RATE)
- sample_width = audio_format.get("width", OMI_SAMPLE_WIDTH)
- channels = audio_format.get("channels", OMI_CHANNELS)
-
- # Calculate audio duration
- duration = len(complete_audio) / (sample_rate * sample_width * channels)
-
-        application_logger.info(f"✅ Batch mode: Processing audio ({duration:.1f}s)")
-
- # Create conversation immediately for batch audio (conversation_id auto-generated)
- version_id = str(uuid.uuid4())
-
- conversation = create_conversation(
+ await _create_batch_conversation_and_enqueue(
+ client_state,
user_id=user_id,
client_id=client_id,
title="Batch Recording",
- summary="Processing batch audio...",
- )
- # Attach any markers (e.g., button events) captured during the session
- if client_state.markers:
- conversation.markers = list(client_state.markers)
- client_state.markers.clear()
- await conversation.insert()
- conversation_id = conversation.conversation_id # Get the auto-generated ID
-
- application_logger.info(
- f"π Batch mode: Created conversation {conversation_id}"
+ trigger="batch",
+ job_id_prefix="transcribe",
+ enqueue_post_jobs=True,
+ attach_markers=True,
)
-
- # Convert audio directly to MongoDB chunks (no disk intermediary)
- try:
- num_chunks = await convert_audio_to_chunks(
- conversation_id=conversation_id,
- audio_data=complete_audio,
- sample_rate=sample_rate,
- channels=channels,
- sample_width=sample_width,
- )
- application_logger.info(
- f"π¦ Batch mode: Converted to {num_chunks} MongoDB chunks "
- f"(conversation {conversation_id[:12]})"
- )
- except Exception as chunk_error:
- application_logger.error(
- f"Failed to convert batch audio to chunks: {chunk_error}", exc_info=True
- )
- # Continue anyway - transcription job will handle it
-
- # Enqueue batch transcription job first (file uploads need transcription)
- from advanced_omi_backend.controllers.queue_controller import (
- JOB_RESULT_TTL,
- start_post_conversation_jobs,
- transcription_queue,
- )
- from advanced_omi_backend.workers.transcription_jobs import (
- transcribe_full_audio_job,
- )
-
- version_id = str(uuid.uuid4())
- transcribe_job_id = f"transcribe_{conversation_id[:12]}"
-
- from advanced_omi_backend.config import get_transcription_job_timeout
-
- transcription_job = transcription_queue.enqueue(
- transcribe_full_audio_job,
- conversation_id,
- version_id,
- "batch", # trigger
- job_timeout=get_transcription_job_timeout(),
- result_ttl=JOB_RESULT_TTL,
- job_id=transcribe_job_id,
- description=f"Transcribe batch audio {conversation_id[:8]}",
- meta={"conversation_id": conversation_id, "client_id": client_id},
- )
-
- application_logger.info(
- f"π₯ Batch mode: Enqueued transcription job {transcription_job.id}"
- )
-
- # Enqueue post-conversation processing job chain (depends on transcription)
- job_ids = start_post_conversation_jobs(
- conversation_id=conversation_id,
- user_id=None, # Will be read from conversation in DB by jobs
- depends_on_job=transcription_job, # Wait for transcription to complete
- client_id=client_id, # Pass client_id for UI tracking
- )
-
- application_logger.info(
-            f"✅ Batch mode: Enqueued job chain for {conversation_id} - "
-            f"transcription ({transcription_job.id}) → "
-            f"speaker ({job_ids['speaker_recognition']}) → "
-            f"memory ({job_ids['memory']})"
- )
-
- # Clear accumulated chunks
client_state.batch_audio_chunks = []
-
except Exception as batch_error:
application_logger.error(
f"β Batch mode processing failed: {batch_error}", exc_info=True
@@ -1355,35 +1291,70 @@ async def _cleanup_websocket_connection(
)
-async def handle_omi_websocket(
- ws: WebSocket,
- token: Optional[str] = None,
- device_name: Optional[str] = None,
-):
- """Handle OMI WebSocket connections with Opus decoding."""
- # Generate pending client_id to track connection even if auth fails
+from contextlib import asynccontextmanager
+
+
+@asynccontextmanager
+async def _websocket_session(ws, token, device_name, connection_type):
+ """Lifecycle wrapper: pending tracking, auth, client setup, cleanup.
+
+ Yields (client_id, client_state, user, audio_stream_producer, interim_holder)
+ on success, or None if auth failed.
+    interim_holder is a mutable list — the inner loop sets interim_holder[0] = task.
+ """
pending_client_id = f"pending_{uuid.uuid4()}"
pending_connections.add(pending_client_id)
client_id = None
- client_state = None
- interim_subscriber_task = None
+ interim_holder = [None] # mutable so inner loop can update
try:
- # Setup connection (accept, auth, create client state)
client_id, client_state, user = await _setup_websocket_connection(
- ws, token, device_name, pending_client_id, "OMI"
+ ws, token, device_name, pending_client_id, connection_type
)
if not user:
+ yield None
+ return
+
+ # Store user context on client state up front (shared by all handlers)
+ client_state.user_id = user.user_id
+ client_state.user_email = user.email
+ client_state.client_id = client_id
+
+ audio_stream_producer = get_audio_stream_producer()
+
+ yield (client_id, client_state, user, audio_stream_producer, interim_holder)
+
+ except WebSocketDisconnect:
+ application_logger.info(
+ f"π {connection_type} WebSocket disconnected β Client: {client_id}"
+ )
+ except Exception as e:
+ application_logger.error(
+ f"β {connection_type} WebSocket error for client {client_id}: {e}",
+ exc_info=True,
+ )
+ finally:
+ await _cleanup_websocket_connection(
+ client_id, pending_client_id, interim_holder[0]
+ )
+
+
+async def handle_omi_websocket(
+ ws: WebSocket,
+ token: Optional[str] = None,
+ device_name: Optional[str] = None,
+):
+ """Handle OMI WebSocket connections with Opus decoding."""
+ async with _websocket_session(ws, token, device_name, "OMI") as session:
+ if session is None:
return
+ client_id, client_state, user, audio_stream_producer, interim_holder = session
# OMI-specific: Setup Opus decoder
decoder = OmiOpusDecoder()
_decode_packet = partial(decoder.decode_packet, strip_header=False)
- # Get singleton audio stream producer
- audio_stream_producer = get_audio_stream_producer()
-
packet_count = 0
total_bytes = 0
@@ -1392,18 +1363,12 @@ async def handle_omi_websocket(
header, payload = await parse_wyoming_protocol(ws)
if header["type"] == "audio-start":
- # Handle audio session start
application_logger.info(
f"π΄ BACKEND: Received audio-start in OMI MODE for {client_id} (header={header})"
)
application_logger.info(f"ποΈ OMI audio session started for {client_id}")
- # Store user context on client state
- client_state.user_id = user.user_id
- client_state.user_email = user.email
- client_state.client_id = client_id
-
- interim_subscriber_task = await _initialize_streaming_session(
+ interim_holder[0] = await _initialize_streaming_session(
client_state,
audio_stream_producer,
user.user_id,
@@ -1417,20 +1382,18 @@ async def handle_omi_websocket(
"channels": OMI_CHANNELS,
},
),
- websocket=ws, # Pass WebSocket to launch interim results subscriber
+ websocket=ws,
)
elif header["type"] == "audio-chunk" and payload:
packet_count += 1
total_bytes += len(payload)
- # Log progress
if packet_count <= 5 or packet_count % 1000 == 0:
application_logger.info(
f"π΅ Received OMI audio chunk #{packet_count}: {len(payload)} bytes"
)
- # Handle OMI audio chunk (Opus decode + publish to stream)
await _handle_omi_audio_chunk(
client_state,
audio_stream_producer,
@@ -1441,20 +1404,17 @@ async def handle_omi_websocket(
packet_count,
)
- # Log progress every 1000th packet
if packet_count % 1000 == 0:
application_logger.info(
f"π Processed {packet_count} OMI packets ({total_bytes} bytes total)"
)
elif header["type"] == "audio-stop":
- # Handle audio session stop
application_logger.info(
f"π OMI audio session stopped for {client_id} - "
f"Total chunks: {packet_count}, Total bytes: {total_bytes}"
)
- # Finalize session using helper function
await _finalize_streaming_session(
client_state,
audio_stream_producer,
@@ -1463,7 +1423,6 @@ async def handle_omi_websocket(
client_id,
)
- # Reset counters for next session
packet_count = 0
total_bytes = 0
@@ -1475,51 +1434,23 @@ async def handle_omi_websocket(
)
else:
- # Unknown event type
application_logger.debug(
f"Ignoring Wyoming event type '{header['type']}' for OMI client {client_id}"
)
- except WebSocketDisconnect:
- application_logger.info(
- f"π WebSocket disconnected - Client: {client_id}, Packets: {packet_count}, Total bytes: {total_bytes}"
- )
- except Exception as e:
- application_logger.error(
- f"β WebSocket error for client {client_id}: {e}", exc_info=True
- )
- finally:
- await _cleanup_websocket_connection(
- client_id, pending_client_id, interim_subscriber_task
- )
-
async def handle_pcm_websocket(
ws: WebSocket, token: Optional[str] = None, device_name: Optional[str] = None
):
"""Handle PCM WebSocket connections with batch and streaming mode support."""
- # Generate pending client_id to track connection even if auth fails
- pending_client_id = f"pending_{uuid.uuid4()}"
- pending_connections.add(pending_client_id)
-
- client_id = None
- client_state = None
- interim_subscriber_task = None
-
- try:
- # Setup connection (accept, auth, create client state)
- client_id, client_state, user = await _setup_websocket_connection(
- ws, token, device_name, pending_client_id, "PCM"
- )
- if not user:
+ async with _websocket_session(ws, token, device_name, "PCM") as session:
+ if session is None:
return
-
- # Get singleton audio stream producer
- audio_stream_producer = get_audio_stream_producer()
+ client_id, client_state, user, audio_stream_producer, interim_holder = session
packet_count = 0
total_bytes = 0
- audio_streaming = False # Track if audio session is active
+ audio_streaming = False
while True:
try:
@@ -1544,18 +1475,13 @@ async def handle_pcm_websocket(
f"ποΈ Processing audio-start for {client_id}"
)
- # Store user context on client state for rolling batch processing
- client_state.user_id = user.user_id
- client_state.user_email = user.email
- client_state.client_id = client_id
-
- # Handle audio session start using helper function (pass websocket for error handling)
+ # Handle audio session start (pass websocket for error handling)
audio_streaming, recording_mode = (
await _handle_audio_session_start(
client_state,
header.get("data", {}),
client_id,
- websocket=ws, # Pass websocket for WebUI error display
+ websocket=ws,
)
)
@@ -1564,22 +1490,19 @@ async def handle_pcm_websocket(
application_logger.info(
f"π΄ BACKEND: Initializing streaming session for {client_id}"
)
- interim_subscriber_task = (
- await _initialize_streaming_session(
- client_state,
- audio_stream_producer,
- user.user_id,
- user.email,
- client_id,
- header.get("data", {}),
- websocket=ws,
- )
+ interim_holder[0] = await _initialize_streaming_session(
+ client_state,
+ audio_stream_producer,
+ user.user_id,
+ user.email,
+ client_id,
+ header.get("data", {}),
+ websocket=ws,
)
- continue # Continue to audio streaming mode
+ continue
elif header["type"] == "ping":
- # Handle keepalive ping from frontend
application_logger.debug(f"π Received ping from {client_id}")
continue
@@ -1592,23 +1515,20 @@ async def handle_pcm_websocket(
continue
else:
- # Unknown control message type
application_logger.debug(
f"Ignoring Wyoming control event type '{header['type']}' for {client_id}"
)
continue
else:
- # Audio streaming mode - receive raw bytes (like speaker recognition)
+ # Audio streaming mode
application_logger.debug(
f"π΅ Audio streaming mode for {client_id} - waiting for audio data"
)
try:
- # Receive raw audio bytes or check for control messages
message = await ws.receive()
- # Check if it's a disconnect
if (
"type" in message
and message["type"] == "websocket.disconnect"
@@ -1620,12 +1540,10 @@ async def handle_pcm_websocket(
)
break
- # Check if it's a text message (control message like audio-stop)
if "text" in message:
try:
control_header = json.loads(message["text"].strip())
if control_header.get("type") == "audio-stop":
- # Handle audio session stop using helper function
audio_streaming = await _handle_audio_session_stop(
client_state,
audio_stream_producer,
@@ -1633,7 +1551,6 @@ async def handle_pcm_websocket(
user.email,
client_id,
)
- # Reset counters for next session
packet_count = 0
total_bytes = 0
continue
@@ -1643,18 +1560,15 @@ async def handle_pcm_websocket(
)
continue
elif control_header.get("type") == "audio-start":
- # Handle duplicate audio-start messages gracefully (idempotent behavior)
application_logger.info(
f"π Ignoring duplicate audio-start message during streaming for {client_id}"
)
continue
elif control_header.get("type") == "audio-chunk":
- # Handle Wyoming protocol audio-chunk with binary payload
payload_length = control_header.get(
"payload_length"
)
if payload_length and payload_length > 0:
- # Receive the binary audio data
payload_msg = await ws.receive()
if "bytes" in payload_msg:
audio_data = payload_msg["bytes"]
@@ -1665,7 +1579,6 @@ async def handle_pcm_websocket(
f"π΅ Received audio chunk #{packet_count}: {len(audio_data)} bytes"
)
- # Route to appropriate mode handler
audio_format = control_header.get(
"data", {}
)
@@ -1679,9 +1592,8 @@ async def handle_pcm_websocket(
client_id,
websocket=ws,
)
- # Store subscriber task if it was created (first streaming chunk)
- if task and not interim_subscriber_task:
- interim_subscriber_task = task
+ if task and not interim_holder[0]:
+ interim_holder[0] = task
else:
application_logger.warning(
f"Expected binary payload for audio-chunk, got: {payload_msg.keys()}"
@@ -1713,9 +1625,7 @@ async def handle_pcm_websocket(
)
continue
- # Check if it's binary data (raw audio without Wyoming protocol)
elif "bytes" in message:
- # Raw binary audio data (legacy support)
audio_data = message["bytes"]
packet_count += 1
total_bytes += len(audio_data)
@@ -1724,7 +1634,6 @@ async def handle_pcm_websocket(
f"π΅ Received raw audio chunk #{packet_count}: {len(audio_data)} bytes"
)
- # Route to appropriate mode handler with default format
default_format = {"rate": 16000, "width": 2, "channels": 1}
task = await _handle_audio_chunk(
client_state,
@@ -1736,9 +1645,8 @@ async def handle_pcm_websocket(
client_id,
websocket=ws,
)
- # Store subscriber task if it was created (first streaming chunk)
- if task and not interim_subscriber_task:
- interim_subscriber_task = task
+ if task and not interim_holder[0]:
+ interim_holder[0] = task
else:
application_logger.warning(
@@ -1759,22 +1667,21 @@ async def handle_pcm_websocket(
f"π WebSocket disconnected during message processing for {client_id}. "
f"Code: {e.code}, Reason: {e.reason}"
)
- break # Exit the loop on disconnect
+ break
except json.JSONDecodeError as e:
application_logger.error(
f"β JSON decode error in Wyoming protocol for {client_id}: {e}"
)
- continue # Skip this message but don't disconnect
+ continue
except ValueError as e:
application_logger.error(f"β Protocol error for {client_id}: {e}")
- continue # Skip this message but don't disconnect
+ continue
except RuntimeError as e:
- # Handle "Cannot call receive once a disconnect message has been received"
if "disconnect" in str(e).lower():
application_logger.info(
f"π WebSocket already disconnected for {client_id}: {e}"
)
- break # Exit the loop on disconnect
+ break
else:
application_logger.error(
f"β Runtime error for {client_id}: {e}", exc_info=True
@@ -1785,7 +1692,6 @@ async def handle_pcm_websocket(
f"β Unexpected error processing message for {client_id}: {e}",
exc_info=True,
)
- # Check if it's a connection-related error
error_msg = str(e).lower()
if (
"disconnect" in error_msg
@@ -1797,17 +1703,4 @@ async def handle_pcm_websocket(
)
break
else:
- continue # Skip this message for other errors
-
- except WebSocketDisconnect:
- application_logger.info(
- f"π PCM WebSocket disconnected - Client: {client_id}, Packets: {packet_count}, Total bytes: {total_bytes}"
- )
- except Exception as e:
- application_logger.error(
- f"β PCM WebSocket error for client {client_id}: {e}", exc_info=True
- )
- finally:
- await _cleanup_websocket_connection(
- client_id, pending_client_id, interim_subscriber_task
- )
+ continue
diff --git a/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py b/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py
index 70a647b2..9c891aeb 100644
--- a/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py
+++ b/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py
@@ -1,11 +1,24 @@
-"""OpenTelemetry setup with Galileo span processor."""
+"""OpenTelemetry setup and session management.
+Uses OpenInference semantic conventions (session.id) so that any
+compatible observability backend (Galileo, Arize Phoenix, Langfuse, etc.)
+can group traces by session.
+"""
+
+import contextvars
import logging
import os
from functools import lru_cache
logger = logging.getLogger(__name__)
+_otel_initialised = False
+
+# Per-task/thread token so concurrent conversations don't clobber each other.
+_session_token_var: contextvars.ContextVar[object | None] = contextvars.ContextVar(
+ "_otel_session_token", default=None
+)
+
@lru_cache(maxsize=1)
def is_galileo_enabled() -> bool:
@@ -13,34 +26,43 @@ def is_galileo_enabled() -> bool:
return bool(os.getenv("GALILEO_API_KEY"))
-_session_token = None
+def is_otel_enabled() -> bool:
+ """Check if any OTel exporter has been initialised."""
+ return _otel_initialised
-def set_galileo_session(session_id: str) -> None:
- """Set Galileo session ID so subsequent traces are grouped together."""
- global _session_token
- if not is_galileo_enabled():
+def set_otel_session(session_id: str) -> None:
+ """Attach *session_id* to the OTel context (OpenInference ``session.id``).
+
+ All subsequent spans on this thread/context will carry the session ID,
+ regardless of which observability backend is consuming them.
+ Safe to call concurrently from different asyncio tasks or threads.
+ """
+ if not is_otel_enabled():
return
try:
- from galileo.otel import _session_id_context
+ from openinference.semconv.trace import SpanAttributes
+ from opentelemetry.context import attach, get_current, set_value
- _session_token = _session_id_context.set(session_id)
+ clear_otel_session()
+ ctx = set_value(SpanAttributes.SESSION_ID, session_id, get_current())
+ _session_token_var.set(attach(ctx))
except ImportError:
pass
-def clear_galileo_session() -> None:
- """Clear the Galileo session ID."""
- global _session_token
- if _session_token is None:
+def clear_otel_session() -> None:
+ """Detach the current session from the OTel context."""
+ token = _session_token_var.get()
+ if token is None:
return
try:
- from galileo.otel import _session_id_context
+ from opentelemetry.context import detach
- _session_id_context.reset(_session_token)
- _session_token = None
- except ImportError:
- pass
+ detach(token)
+ _session_token_var.set(None)
+ except Exception:
+ _session_token_var.set(None)
def init_otel() -> None:
@@ -95,6 +117,8 @@ def force_flush(self, timeout_millis: int = 30000) -> bool:
# Auto-instrument all OpenAI SDK calls
OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)
+ global _otel_initialised
+ _otel_initialised = True
logger.info("OTEL initialized with Galileo exporter + OpenAI instrumentor")
except ImportError:
logger.warning(
diff --git a/backends/advanced/src/advanced_omi_backend/plugins/base.py b/backends/advanced/src/advanced_omi_backend/plugins/base.py
index 5c9b668d..92e926f6 100644
--- a/backends/advanced/src/advanced_omi_backend/plugins/base.py
+++ b/backends/advanced/src/advanced_omi_backend/plugins/base.py
@@ -6,6 +6,7 @@
- PluginResult: Result from plugin execution
- BasePlugin: Abstract base class for all plugins
"""
+
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@@ -14,16 +15,20 @@
@dataclass
class PluginContext:
"""Context passed to plugin execution"""
+
user_id: str
event: str # Event name (e.g., "transcript.streaming", "conversation.complete")
data: Dict[str, Any] # Event-specific data
metadata: Dict[str, Any] = field(default_factory=dict)
- services: Optional[Any] = None # PluginServices instance for system/cross-plugin calls
+ services: Optional[Any] = (
+ None # PluginServices instance for system/cross-plugin calls
+ )
@dataclass
class PluginResult:
"""Result from plugin execution"""
+
success: bool
data: Optional[Dict[str, Any]] = None
message: Optional[str] = None
@@ -58,9 +63,9 @@ def __init__(self, config: Dict[str, Any]):
Contains: enabled, events, condition, and plugin-specific config
"""
self.config = config
- self.enabled = config.get('enabled', False)
- self.events = config.get('events', [])
- self.condition = config.get('condition', {'type': 'always'})
+ self.enabled = config.get("enabled", False)
+ self.events = config.get("events", [])
+ self.condition = config.get("condition", {"type": "always"})
def register_prompts(self, registry) -> None:
"""Register plugin prompts with the prompt registry.
@@ -122,7 +127,9 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
"""
pass
- async def on_conversation_complete(self, context: PluginContext) -> Optional[PluginResult]:
+ async def on_conversation_complete(
+ self, context: PluginContext
+ ) -> Optional[PluginResult]:
"""
Called when conversation processing completes.
@@ -137,7 +144,9 @@ async def on_conversation_complete(self, context: PluginContext) -> Optional[Plu
"""
pass
- async def on_memory_processed(self, context: PluginContext) -> Optional[PluginResult]:
+ async def on_memory_processed(
+ self, context: PluginContext
+ ) -> Optional[PluginResult]:
"""
Called after memory extraction finishes.
@@ -152,7 +161,9 @@ async def on_memory_processed(self, context: PluginContext) -> Optional[PluginRe
"""
pass
- async def on_conversation_starred(self, context: PluginContext) -> Optional[PluginResult]:
+ async def on_conversation_starred(
+ self, context: PluginContext
+ ) -> Optional[PluginResult]:
"""
Called when a conversation is starred or unstarred.
@@ -172,7 +183,7 @@ async def on_button_event(self, context: PluginContext) -> Optional[PluginResult
Called when a device button event is received.
Context data contains:
- - state: str - Button state (e.g., "SINGLE_TAP", "DOUBLE_TAP", "LONG_PRESS")
+ - state: str - Button state (e.g., "SINGLE_PRESS", "DOUBLE_PRESS", "LONG_PRESS")
- timestamp: float - Unix timestamp of the event
- audio_uuid: str - Current audio session UUID (may be None)
diff --git a/backends/advanced/src/advanced_omi_backend/plugins/events.py b/backends/advanced/src/advanced_omi_backend/plugins/events.py
index 3d7ec284..7a732a04 100644
--- a/backends/advanced/src/advanced_omi_backend/plugins/events.py
+++ b/backends/advanced/src/advanced_omi_backend/plugins/events.py
@@ -24,32 +24,50 @@ def __new__(cls, value: str, description: str = ""):
return obj
# Conversation lifecycle
- CONVERSATION_COMPLETE = ("conversation.complete", "Fires when conversation processing finishes (transcript ready)")
- TRANSCRIPT_STREAMING = ("transcript.streaming", "Real-time transcript segments during a live conversation")
- TRANSCRIPT_BATCH = ("transcript.batch", "Batch transcript from file upload processing")
- MEMORY_PROCESSED = ("memory.processed", "After memories are extracted from a conversation")
- CONVERSATION_STARRED = ("conversation.starred", "Fires when a conversation is starred or unstarred")
+ CONVERSATION_COMPLETE = (
+ "conversation.complete",
+ "Fires when conversation processing finishes (transcript ready)",
+ )
+ TRANSCRIPT_STREAMING = (
+ "transcript.streaming",
+ "Real-time transcript segments during a live conversation",
+ )
+ TRANSCRIPT_BATCH = (
+ "transcript.batch",
+ "Batch transcript from file upload processing",
+ )
+ MEMORY_PROCESSED = (
+ "memory.processed",
+ "After memories are extracted from a conversation",
+ )
+ CONVERSATION_STARRED = (
+ "conversation.starred",
+ "Fires when a conversation is starred or unstarred",
+ )
# Button events (from OMI device)
BUTTON_SINGLE_PRESS = ("button.single_press", "OMI device button single press")
BUTTON_DOUBLE_PRESS = ("button.double_press", "OMI device button double press")
# Cross-plugin communication (dispatched by PluginServices.call_plugin)
- PLUGIN_ACTION = ("plugin_action", "Cross-plugin dispatch via PluginServices.call_plugin()")
+ PLUGIN_ACTION = (
+ "plugin_action",
+ "Cross-plugin dispatch via PluginServices.call_plugin()",
+ )
class ButtonState(str, Enum):
"""Raw button states from OMI device firmware."""
- SINGLE_TAP = "SINGLE_TAP"
- DOUBLE_TAP = "DOUBLE_TAP"
+ SINGLE_PRESS = "SINGLE_PRESS"
+ DOUBLE_PRESS = "DOUBLE_PRESS"
LONG_PRESS = "LONG_PRESS"
# Maps device button states to plugin events
BUTTON_STATE_TO_EVENT: Dict[ButtonState, PluginEvent] = {
- ButtonState.SINGLE_TAP: PluginEvent.BUTTON_SINGLE_PRESS,
- ButtonState.DOUBLE_TAP: PluginEvent.BUTTON_DOUBLE_PRESS,
+ ButtonState.SINGLE_PRESS: PluginEvent.BUTTON_SINGLE_PRESS,
+ ButtonState.DOUBLE_PRESS: PluginEvent.BUTTON_DOUBLE_PRESS,
}
diff --git a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py
index 7c741662..2142ce07 100644
--- a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py
+++ b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py
@@ -21,7 +21,7 @@
)
from advanced_omi_backend.controllers.session_controller import mark_session_complete
from advanced_omi_backend.models.job import async_job
-from advanced_omi_backend.observability.otel_setup import set_galileo_session
+from advanced_omi_backend.observability.otel_setup import set_otel_session
from advanced_omi_backend.plugins.events import PluginEvent
from advanced_omi_backend.services.plugin_service import (
ensure_plugin_router,
@@ -1034,7 +1034,7 @@ async def generate_title_summary_job(
generate_title_and_summary,
)
- set_galileo_session(conversation_id)
+ set_otel_session(conversation_id)
logger.info(
f"π Starting title/summary generation for conversation {conversation_id}"
)
diff --git a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py
index 3c0cedd8..492dc650 100644
--- a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py
+++ b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py
@@ -23,8 +23,8 @@
)
from advanced_omi_backend.models.job import JobPriority, async_job
from advanced_omi_backend.observability.otel_setup import (
- clear_galileo_session,
- set_galileo_session,
+ clear_otel_session,
+ set_otel_session,
)
from advanced_omi_backend.plugins.events import PluginEvent
from advanced_omi_backend.services.plugin_service import ensure_plugin_router
@@ -141,7 +141,7 @@ async def process_memory_job(
from advanced_omi_backend.services.memory import get_memory_service
from advanced_omi_backend.users import get_user_by_id
- set_galileo_session(conversation_id)
+ set_otel_session(conversation_id)
start_time = time.time()
logger.info(f"π Starting memory processing for conversation {conversation_id}")
diff --git a/extras/friend-lite-sdk/friend_lite/button.py b/extras/friend-lite-sdk/friend_lite/button.py
index 421a87b1..11e6ecd6 100644
--- a/extras/friend-lite-sdk/friend_lite/button.py
+++ b/extras/friend-lite-sdk/friend_lite/button.py
@@ -6,8 +6,8 @@
class ButtonState(IntEnum):
IDLE = 0
- SINGLE_TAP = 1
- DOUBLE_TAP = 2
+ SINGLE_PRESS = 1
+ DOUBLE_PRESS = 2
LONG_PRESS = 3
PRESS = 4
RELEASE = 5
diff --git a/plugins/homeassistant/command_parser.py b/plugins/homeassistant/command_parser.py
index cc73626d..de56ce97 100644
--- a/plugins/homeassistant/command_parser.py
+++ b/plugins/homeassistant/command_parser.py
@@ -51,16 +51,16 @@ class ParsedCommand:
- set_color: Set color
TARGET_TYPE (choose one):
-- area: Targeting all entities of a type in an area (e.g., "study lights")
-- all_in_area: Targeting ALL entities in an area (e.g., "everything in study")
-- entity: Targeting a specific entity by name (e.g., "desk lamp")
+- area: Target entities in a specific area OR label. Labels are groups of areas (e.g., "hall" label might cover dining_room + living_room). Both areas and labels are valid targets.
+- all: Target entities across ALL areas (e.g., "all lights", "every light")
+- entity: Target a specific entity by name (e.g., "desk lamp")
ENTITY_TYPE (optional, use null if not specified):
- light: Light entities
- switch: Switch entities
- fan: Fan entities
- cover: Covers/blinds
-- null: All entity types (when target_type is "all_in_area")
+- null: All entity types
PARAMETERS (optional, empty dict if none):
- brightness_pct: Brightness percentage (0-100)
@@ -71,8 +71,8 @@ class ParsedCommand:
Command: "turn off study lights"
Response: {"action": "turn_off", "target_type": "area", "target": "study", "entity_type": "light", "parameters": {}}
-Command: "turn off everything in study"
-Response: {"action": "turn_off", "target_type": "all_in_area", "target": "study", "entity_type": null, "parameters": {}}
+Command: "turn off hall lights"
+Response: {"action": "turn_off", "target_type": "area", "target": "hall", "entity_type": "light", "parameters": {}}
Command: "turn on desk lamp"
Response: {"action": "turn_on", "target_type": "entity", "target": "desk lamp", "entity_type": null, "parameters": {}}
@@ -80,11 +80,11 @@ class ParsedCommand:
Command: "set study lights to 50%"
Response: {"action": "set_brightness", "target_type": "area", "target": "study", "entity_type": "light", "parameters": {"brightness_pct": 50}}
-Command: "turn on living room fan"
-Response: {"action": "turn_on", "target_type": "area", "target": "living room", "entity_type": "fan", "parameters": {}}
-
Command: "turn off all lights"
-Response: {"action": "turn_off", "target_type": "entity", "target": "all", "entity_type": "light", "parameters": {}}
+Response: {"action": "turn_off", "target_type": "all", "target": "all", "entity_type": "light", "parameters": {}}
+
+Command: "turn off everything"
+Response: {"action": "turn_off", "target_type": "all", "target": "all", "entity_type": null, "parameters": {}}
Command: "toggle hallway light"
Response: {"action": "toggle", "target_type": "entity", "target": "hallway light", "entity_type": null, "parameters": {}}
@@ -94,4 +94,6 @@ class ParsedCommand:
2. Use lowercase for action, target_type, target, entity_type
3. Use null (not "null" string) for missing entity_type
4. Always include all 5 fields: action, target_type, target, entity_type, parameters
+5. The "target" for target_type "area" MUST be an area name or label name from the provided context
+6. Use target_type "all" when the user says "all lights", "every light", "all the lights", etc.
"""
diff --git a/plugins/homeassistant/plugin.py b/plugins/homeassistant/plugin.py
index b94f8ae4..f6e2d973 100644
--- a/plugins/homeassistant/plugin.py
+++ b/plugins/homeassistant/plugin.py
@@ -2,7 +2,7 @@
Home Assistant plugin for Chronicle.
Enables control of Home Assistant devices through natural language commands
-triggered by a wake word.
+triggered by a keyword anywhere in the transcript.
"""
import json
@@ -10,6 +10,7 @@
from typing import Any, Dict, List, Optional
from advanced_omi_backend.plugins.base import BasePlugin, PluginContext, PluginResult
+
from .entity_cache import EntityCache
from .mcp_client import HAMCPClient, MCPError
@@ -18,20 +19,20 @@
class HomeAssistantPlugin(BasePlugin):
"""
- Plugin for controlling Home Assistant devices via wake word commands.
+ Plugin for controlling Home Assistant devices via keyword commands.
Example:
- User says: "Vivi, turn off the hall lights"
- -> Wake word "vivi" detected by router
- -> Command "turn off the hall lights" passed to on_transcript()
- -> Plugin parses command and calls HA MCP to execute
+ User says: "Turn off the hall lights, VV"
+ -> Keyword "vv" detected anywhere in transcript by router
+ -> Command "Turn off the hall lights" passed to on_transcript()
+ -> Plugin parses command and calls HA to execute
-> Returns: PluginResult with "I've turned off the hall light"
"""
SUPPORTED_ACCESS_LEVELS: List[str] = ["transcript", "button"]
name = "Home Assistant"
- description = "Wake word device control with Home Assistant integration"
+ description = "Keyword-triggered device control with Home Assistant integration"
def __init__(self, config: Dict[str, Any]):
"""
@@ -155,7 +156,9 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
command = context.data.get("command", "")
if not command:
- return PluginResult(success=False, message="No command provided", should_continue=True)
+ return PluginResult(
+ success=False, message="No command provided", should_continue=True
+ )
if not self.mcp_client:
logger.error("MCP client not initialized")
@@ -166,9 +169,21 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
)
try:
+ conversation_id = context.data.get("conversation_id")
+
+ # Step 0: Extract just the HA command from mixed transcript
+ extracted = await self._extract_ha_command(
+ command, conversation_id=conversation_id
+ )
+ if extracted:
+ logger.info(f"Extracted HA command: '{extracted}' (from: '{command}')")
+ command = extracted
+
# Step 1: Parse command using hybrid LLM + fallback parsing
logger.info(f"Processing HA command: '{command}'")
- parsed = await self._parse_command_hybrid(command)
+ parsed = await self._parse_command_hybrid(
+ command, conversation_id=conversation_id
+ )
if not parsed:
return PluginResult(
@@ -199,15 +214,25 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
service = service_map.get(parsed.action, "turn_on")
# Step 4: Call Home Assistant service
- logger.info(f"Calling {domain}.{service} for {len(entity_ids)} entities: {entity_ids}")
+ logger.info(
+ f"Calling {domain}.{service} for {len(entity_ids)} entities: {entity_ids}"
+ )
result = await self.mcp_client.call_service(
- domain=domain, service=service, entity_ids=entity_ids, **parsed.parameters
+ domain=domain,
+ service=service,
+ entity_ids=entity_ids,
+ **parsed.parameters,
)
# Step 5: Format user-friendly response
entity_type_name = parsed.entity_type or domain
- if parsed.target_type == "area":
+ if parsed.target_type == "all":
+ message = (
+ f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} "
+ f"{entity_type_name}{'s' if len(entity_ids) != 1 else ''} everywhere"
+ )
+ elif parsed.target_type == "area":
message = (
f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} "
f"{entity_type_name}{'s' if len(entity_ids) != 1 else ''} "
@@ -339,9 +364,10 @@ async def on_button_event(self, context: PluginContext) -> Optional[PluginResult
using the button_actions config. Reuses the same entity resolution and
service call logic as on_plugin_action().
"""
- from .command_parser import ParsedCommand
from advanced_omi_backend.plugins.events import PluginEvent
+ from .command_parser import ParsedCommand
+
# Map event to config key
if context.event == PluginEvent.BUTTON_DOUBLE_PRESS:
action_key = "double_press"
@@ -369,7 +395,9 @@ async def on_button_event(self, context: PluginContext) -> Optional[PluginResult
entity_type = action_config.get("entity_type", "light")
if not target:
- return PluginResult(success=False, message="No target in button_actions config")
+ return PluginResult(
+ success=False, message="No target in button_actions config"
+ )
parsed = ParsedCommand(
action=service,
@@ -430,7 +458,11 @@ async def health_check(self) -> dict:
latency_ms = int((time.time() - start) * 1000)
if str(result).strip() == "2":
return {"ok": True, "message": "Connected", "latency_ms": latency_ms}
- return {"ok": False, "message": f"Unexpected result: {result}", "latency_ms": latency_ms}
+ return {
+ "ok": False,
+ "message": f"Unexpected result: {result}",
+ "latency_ms": latency_ms,
+ }
except Exception as e:
return {"ok": False, "message": str(e)}
@@ -503,14 +535,96 @@ async def _refresh_cache(self):
)
logger.info(
- f"Entity cache refreshed: {len(areas)} areas, " f"{len(entity_details)} entities"
+ f"Entity cache refreshed: {len(areas)} areas, "
+ f"{len(entity_details)} entities"
)
except Exception as e:
logger.error(f"Failed to refresh entity cache: {e}", exc_info=True)
raise
- async def _parse_command_with_llm(self, command: str) -> Optional["ParsedCommand"]:
+ async def _extract_ha_command(
+ self, transcript: str, *, conversation_id: Optional[str] = None
+ ) -> Optional[str]:
+ """
+ Use a lightweight LLM call to extract only the Home Assistant command
+ from a transcript that may contain mixed conversation.
+
+ When the keyword (e.g. "vivi") is detected anywhere in a long transcript,
+ the surrounding text often includes unrelated speech. This method asks the
+ LLM to return just the smart-home command portion.
+
+ Args:
+ transcript: Transcript text with keyword already stripped.
+ conversation_id: Optional conversation ID for Langfuse session grouping.
+
+ Returns:
+ Extracted command string, or None to fall back to the raw text.
+ """
+ # Short transcripts are likely already just the command
+ if len(transcript.split()) <= 8:
+ return None
+
+ try:
+ from advanced_omi_backend.llm_client import get_llm_client
+ from advanced_omi_backend.openai_factory import is_langfuse_enabled
+
+ llm_client = get_llm_client()
+
+ system_prompt = (
+ "Extract ONLY the smart home / home assistant command from the "
+ "transcript below. The transcript is from a conversation and may "
+ "contain unrelated speech mixed in. Return ONLY the command text, "
+ "nothing else. If no smart home command is found, return NONE.\n\n"
+ "Examples:\n"
+ 'Input: "so anyway I was saying turn off the hall lights and then we went to dinner"\n'
+ "Output: turn off the hall lights\n\n"
+ 'Input: "turn on bedroom lights"\n'
+ "Output: turn on bedroom lights\n\n"
+ 'Input: "yeah the meeting was great oh and set living room brightness '
+ 'to 50 percent and also the deadline is tomorrow"\n'
+ "Output: set living room brightness to 50 percent\n\n"
+ 'Input: "so I told him about the project and then toggle the kitchen fan '
+ 'and after that we discussed lunch plans"\n'
+ "Output: toggle the kitchen fan"
+ )
+
+ params = {
+ "model": llm_client.model,
+ "messages": [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": transcript},
+ ],
+ "temperature": 0.0,
+ "max_tokens": 100,
+ }
+ if is_langfuse_enabled():
+ params["name"] = "ha-command-extraction"
+ params["metadata"] = {
+ "plugin": "homeassistant",
+ "step": "extract_command",
+ }
+ if conversation_id:
+ params["langfuse_session_id"] = conversation_id
+
+ response = llm_client.client.chat.completions.create(**params)
+
+ result = response.choices[0].message.content.strip()
+
+ if not result or result.upper() == "NONE":
+ logger.info("LLM extraction found no HA command in transcript")
+ return None
+
+ logger.info(f"LLM extracted HA command: '{result}'")
+ return result
+
+ except Exception as e:
+ logger.warning(f"LLM command extraction failed: {e}, using raw text")
+ return None
+
+ async def _parse_command_with_llm(
+ self, command: str, *, conversation_id: Optional[str] = None
+ ) -> Optional["ParsedCommand"]:
"""
Parse command using LLM with structured system prompt.
@@ -532,26 +646,60 @@ async def _parse_command_with_llm(self, command: str) -> Optional["ParsedCommand
"""
try:
from advanced_omi_backend.llm_client import get_llm_client
+ from advanced_omi_backend.openai_factory import is_langfuse_enabled
from advanced_omi_backend.prompt_registry import get_prompt_registry
from .command_parser import ParsedCommand
llm_client = get_llm_client()
registry = get_prompt_registry()
- system_prompt = await registry.get_prompt("plugin.homeassistant.command_parser")
+ system_prompt = await registry.get_prompt(
+ "plugin.homeassistant.command_parser"
+ )
logger.debug(f"Parsing command with LLM: '{command}'")
+ # Build context from entity cache so the LLM knows valid targets
+ entity_context = ""
+ await self._ensure_cache_initialized()
+ if self.entity_cache:
+ areas = (
+ ", ".join(self.entity_cache.areas)
+ if self.entity_cache.areas
+ else "none"
+ )
+ labels = ""
+ if self.entity_cache.label_areas:
+ label_parts = [
+ f"{lbl} (covers: {', '.join(areas_list)})"
+ for lbl, areas_list in self.entity_cache.label_areas.items()
+ ]
+ labels = "\nAvailable labels: " + ", ".join(label_parts)
+ entity_context = f"\n\nAvailable areas: {areas}{labels}\nUse target_type 'area' with an area/label name above, or target_type 'all' for everything."
+
# Use OpenAI chat format with system + user messages
- response = llm_client.client.chat.completions.create(
- model=llm_client.model,
- messages=[
+ params = {
+ "model": llm_client.model,
+ "messages": [
{"role": "system", "content": system_prompt},
- {"role": "user", "content": f'Command: "{command}"\n\nReturn JSON only.'},
+ {
+ "role": "user",
+ "content": f'Command: "{command}"{entity_context}\n\nReturn JSON only.',
+ },
],
- temperature=0.1,
- max_tokens=150,
- )
+ "temperature": 0.1,
+ "max_tokens": 150,
+ }
+ if is_langfuse_enabled():
+ params["name"] = "ha-command-parser"
+ params["metadata"] = {
+ "plugin": "homeassistant",
+ "step": "parse_command",
+ }
+ if conversation_id:
+ params["langfuse_session_id"] = conversation_id
+
+ response = llm_client.client.chat.completions.create(**params)
result_text = response.choices[0].message.content.strip()
logger.debug(f"LLM response: {result_text}")
@@ -588,7 +736,9 @@ async def _parse_command_with_llm(self, command: str) -> Optional["ParsedCommand
return parsed
except json.JSONDecodeError as e:
- logger.error(f"Failed to parse LLM JSON response: {e}\nResponse: {result_text}")
+ logger.error(
+ f"Failed to parse LLM JSON response: {e}\nResponse: {result_text}"
+ )
return None
except Exception as e:
logger.error(f"LLM command parsing failed: {e}", exc_info=True)
@@ -631,8 +781,12 @@ async def _resolve_entities(self, parsed: "ParsedCommand") -> List[str]:
)
if not entities:
- entity_desc = f"{parsed.entity_type}s" if parsed.entity_type else "entities"
- available = list(self.entity_cache.areas) + list(self.entity_cache.label_areas.keys())
+ entity_desc = (
+ f"{parsed.entity_type}s" if parsed.entity_type else "entities"
+ )
+ available = list(self.entity_cache.areas) + list(
+ self.entity_cache.label_areas.keys()
+ )
raise ValueError(
f"No {entity_desc} found in area/label '{parsed.target}'. "
f"Available: {', '.join(available)}"
@@ -644,9 +798,33 @@ async def _resolve_entities(self, parsed: "ParsedCommand") -> List[str]:
)
return entities
+ elif parsed.target_type == "all":
+ # Get entities across ALL areas, optionally filtered by type
+ entities = []
+ for area in self.entity_cache.areas:
+ entities.extend(
+ self.entity_cache.get_entities_in_area(
+ area=area, entity_type=parsed.entity_type
+ )
+ )
+
+ if not entities:
+ entity_desc = (
+ f"{parsed.entity_type}s" if parsed.entity_type else "entities"
+ )
+ raise ValueError(f"No {entity_desc} found in any area")
+
+ logger.info(
+ f"Resolved 'all' to {len(entities)} "
+ f"{parsed.entity_type or 'entity'}(s) across {len(self.entity_cache.areas)} areas"
+ )
+ return entities
+
elif parsed.target_type == "all_in_area":
# Get ALL entities in area (no filter)
- entities = self.entity_cache.get_entities_in_area(area=parsed.target, entity_type=None)
+ entities = self.entity_cache.get_entities_in_area(
+ area=parsed.target, entity_type=None
+ )
if not entities:
raise ValueError(
@@ -654,7 +832,9 @@ async def _resolve_entities(self, parsed: "ParsedCommand") -> List[str]:
f"Available areas: {', '.join(self.entity_cache.areas)}"
)
- logger.info(f"Resolved 'all in {parsed.target}' to {len(entities)} entities")
+ logger.info(
+ f"Resolved 'all in {parsed.target}' to {len(entities)} entities"
+ )
return entities
elif parsed.target_type == "entity":
@@ -726,7 +906,9 @@ async def _parse_command_fallback(self, command: str) -> Optional[Dict[str, Any]
"action_desc": action_desc,
}
- async def _parse_command_hybrid(self, command: str) -> Optional["ParsedCommand"]:
+ async def _parse_command_hybrid(
+ self, command: str, *, conversation_id: Optional[str] = None
+ ) -> Optional["ParsedCommand"]:
"""
Hybrid command parser: Try LLM first, fallback to keywords.
@@ -736,6 +918,7 @@ async def _parse_command_hybrid(self, command: str) -> Optional["ParsedCommand"]
Args:
command: Natural language command
+ conversation_id: Optional conversation ID for Langfuse session grouping.
Returns:
ParsedCommand if successful, None otherwise
@@ -751,7 +934,10 @@ async def _parse_command_hybrid(self, command: str) -> Optional["ParsedCommand"]
# Try LLM parsing with timeout
try:
logger.debug("Attempting LLM-based command parsing...")
- parsed = await asyncio.wait_for(self._parse_command_with_llm(command), timeout=5.0)
+ parsed = await asyncio.wait_for(
+ self._parse_command_with_llm(command, conversation_id=conversation_id),
+ timeout=5.0,
+ )
if parsed:
logger.info("LLM parsing succeeded")
@@ -823,7 +1009,9 @@ async def test_connection(config: Dict[str, Any]) -> Dict[str, Any]:
try:
# Validate required config fields
required_fields = ["ha_url", "ha_token"]
- missing_fields = [field for field in required_fields if not config.get(field)]
+ missing_fields = [
+ field for field in required_fields if not config.get(field)
+ ]
if missing_fields:
return {
diff --git a/tests/integration/websocket_streaming_tests.robot b/tests/integration/websocket_streaming_tests.robot
index 8c0ac647..5161b816 100644
--- a/tests/integration/websocket_streaming_tests.robot
+++ b/tests/integration/websocket_streaming_tests.robot
@@ -127,7 +127,7 @@ Button Press Should Close Active Conversation
Should Not Be Empty ${conversation_id} msg=Conversation ID not found in job meta
# Act: Send button press to close the conversation
- Send Button Event To Stream ${stream_id} SINGLE_TAP
+ Send Button Event To Stream ${stream_id} SINGLE_PRESS
# Assert: Conversation should close with end_reason=close_requested
Wait Until Keyword Succeeds 30s 2s
diff --git a/tests/libs/audio_stream_library.py b/tests/libs/audio_stream_library.py
index bf870639..f1651dc1 100644
--- a/tests/libs/audio_stream_library.py
+++ b/tests/libs/audio_stream_library.py
@@ -150,12 +150,12 @@ def close_audio_stream_without_stop(stream_id: str) -> int:
return _manager.close_stream_without_stop(stream_id)
-def send_button_event(stream_id: str, button_state: str = "SINGLE_TAP") -> None:
+def send_button_event(stream_id: str, button_state: str = "SINGLE_PRESS") -> None:
"""Send a button event to an open stream.
Args:
stream_id: Stream session ID
- button_state: Button state ("SINGLE_TAP" or "DOUBLE_TAP")
+ button_state: Button state ("SINGLE_PRESS" or "DOUBLE_PRESS")
"""
_manager.send_button_event(stream_id, button_state)
diff --git a/tests/resources/websocket_keywords.robot b/tests/resources/websocket_keywords.robot
index 1ca2b164..58014de5 100644
--- a/tests/resources/websocket_keywords.robot
+++ b/tests/resources/websocket_keywords.robot
@@ -159,8 +159,8 @@ Close Audio Stream
RETURN ${total_chunks}
Send Button Event To Stream
- [Documentation] Send a button event (SINGLE_TAP, DOUBLE_TAP) to an open stream
- [Arguments] ${stream_id} ${button_state}=SINGLE_TAP
+ [Documentation] Send a button event (SINGLE_PRESS, DOUBLE_PRESS) to an open stream
+ [Arguments] ${stream_id} ${button_state}=SINGLE_PRESS
Send Button Event ${stream_id} ${button_state}
Log Sent button event ${button_state} to stream ${stream_id}
From 65e20c39d7e14c22f173f1987c08939f56f00c65 Mon Sep 17 00:00:00 2001
From: Ankush Malaker <43288948+AnkushMalaker@users.noreply.github.com>
Date: Sun, 22 Feb 2026 11:23:46 +0000
Subject: [PATCH 3/5] Add unit tests for Qwen3-ASR output parsing and
repetition detection
- Introduced a new test file `test_qwen3_asr_parsing.py` to validate the functionality of the `_parse_qwen3_output` and `detect_and_fix_repetitions` methods.
- Implemented various test cases covering standard and edge cases for ASR output parsing, including language detection, handling of empty inputs, and unexpected text.
- Added tests for repetition detection to ensure proper functionality based on specified thresholds.
- Enhanced the `Makefile` to include a new target for running specific tests by name, tag, or file, improving test execution flexibility.
- Created a shared prerequisite check script `check_uv.sh` to ensure the `uv` package manager is installed before running scripts, enhancing setup reliability.
---
scripts/check_uv.sh | 15 ++++++++++++++
tests/Makefile | 15 +++++++++++++-
.../audio_streaming_integration_tests.robot | 20 +++++++++++--------
3 files changed, 41 insertions(+), 9 deletions(-)
create mode 100755 scripts/check_uv.sh
diff --git a/scripts/check_uv.sh b/scripts/check_uv.sh
new file mode 100755
index 00000000..8325bfa3
--- /dev/null
+++ b/scripts/check_uv.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+# Shared prerequisite check for Chronicle scripts.
+# Source this at the top of any shell script that needs uv.
+
+if ! command -v uv &> /dev/null; then
+ echo "❌ 'uv' is not installed."
+ echo ""
+ echo "Chronicle requires 'uv' (Python package manager) to run."
+ echo "Install it with:"
+ echo ""
+ echo " curl -LsSf https://astral.sh/uv/install.sh | sh"
+ echo ""
+ echo "Then restart your terminal and try again."
+ exit 1
+fi
diff --git a/tests/Makefile b/tests/Makefile
index 2992c3a1..b4b3e1b2 100644
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -5,7 +5,7 @@
containers-start containers-stop containers-restart containers-rebuild \
containers-start-rebuild containers-clean containers-status containers-logs \
start stop restart rebuild start-rebuild status logs \
- test test-quick test-slow test-sdk test-no-api test-with-api-keys test-all-with-slow-and-sdk clean-all \
+ test test-quick test-custom test-slow test-sdk test-no-api test-with-api-keys test-all-with-slow-and-sdk clean-all \
test-asr test-asr-gpu \
results results-path results-detailed
@@ -56,6 +56,11 @@ help:
@echo " make test-asr - Run ASR protocol tests (no GPU required)"
@echo " make test-asr-gpu - Run ASR GPU tests (requires NVIDIA GPU)"
@echo ""
+ @echo "Running Specific Tests:"
+ @echo " make test-custom T=\"Test Name\" - Run a single test by name"
+ @echo " make test-custom TAG=audio-streaming - Run tests by tag"
+ @echo " make test-custom F=integration/foo.robot - Run a specific file"
+ @echo ""
@echo "Special Test Tags:"
@echo " make test-slow - Run ONLY slow tests (backend restarts)"
@echo " make test-sdk - Run ONLY SDK tests (unreleased)"
@@ -232,6 +237,14 @@ test:
# Quick workflow: run tests on existing containers (ignores CONFIG changes)
test-quick: all
+# Run specific test by name, tag, or file
+# Usage:
+# make test-custom T="Chunk Count Increments In Redis Session"
+# make test-custom TAG=audio-streaming
+# make test-custom F=integration/audio_streaming_integration_tests.robot
+test-custom:
+ @./run-custom.sh $(if $(T),--test "$(T)") $(if $(TAG),--tag $(TAG)) $(if $(F),$(F))
+
# Run ONLY slow tests (backend restarts, long timeouts)
test-slow:
@echo "Running slow tests only..."
diff --git a/tests/integration/audio_streaming_integration_tests.robot b/tests/integration/audio_streaming_integration_tests.robot
index f34b7984..3ae29331 100644
--- a/tests/integration/audio_streaming_integration_tests.robot
+++ b/tests/integration/audio_streaming_integration_tests.robot
@@ -65,29 +65,33 @@ Redis Session Schema Contains All Required Fields
Chunk Count Increments In Redis Session
- [Documentation] Verify chunk count is tracked in Redis (not ClientState)
+ [Documentation] Verify chunk count is tracked in Redis (not ClientState).
+ ... Note: The producer re-chunks client audio into 250ms fixed-size chunks
+ ... (8000 bytes at 16kHz/16-bit/mono). Client sends 100ms chunks (3200 bytes).
+ ... So N client chunks produce floor(N * 3200 / 8000) published chunks.
[Tags] infra audio-streaming
${device_name}= Set Variable chunk-count-test
${stream_id}= Open Audio Stream device_name=${device_name}
${client_id}= Get Client ID From Device Name ${device_name}
- # Send chunks and verify count increases
- Send Audio Chunks To Stream ${stream_id} ${TEST_AUDIO_FILE} num_chunks=3
+ # Send first batch: 10 client chunks (10 * 3200 = 32000 bytes → 4 published chunks)
+ Send Audio Chunks To Stream ${stream_id} ${TEST_AUDIO_FILE} num_chunks=10
Sleep 1s # Allow chunk counter to update
${session1}= Get Redis Session Data ${client_id}
${count1}= Convert To Integer ${session1}[chunks_published]
+ Should Be True ${count1} > 0 First batch should produce at least 1 published chunk
- Send Audio Chunks To Stream ${stream_id} ${TEST_AUDIO_FILE} num_chunks=5
+ # Send second batch: 10 more client chunks
+ Send Audio Chunks To Stream ${stream_id} ${TEST_AUDIO_FILE} num_chunks=10
Sleep 1s # Allow chunk counter to update
${session2}= Get Redis Session Data ${client_id}
${count2}= Convert To Integer ${session2}[chunks_published]
- # Verify count increased (should be at least 8)
- Should Be True ${count2} > ${count1}
- Should Be True ${count2} >= 8
+ # Verify count increased between batches
+ Should Be True ${count2} > ${count1} Chunk count should increase after sending more audio (${count1} → ${count2})
- Log β
Chunk count tracked in Redis: ${count1} β ${count2}
+ Log Chunk count tracked in Redis: ${count1} → ${count2}
# Close stream after test completes
${total_chunks}= Close Audio Stream ${stream_id}
From 034defa0b75ed0a60cb00837c896f98a5b3ee2a1 Mon Sep 17 00:00:00 2001
From: Ankush Malaker <43288948+AnkushMalaker@users.noreply.github.com>
Date: Sun, 22 Feb 2026 11:25:52 +0000
Subject: [PATCH 4/5] Add unit tests for Qwen3-ASR output parsing and
repetition detection
- Introduced a new test file `test_qwen3_asr_parsing.py` to validate the functionality of the `_parse_qwen3_output` and `detect_and_fix_repetitions` methods.
- Implemented various test cases covering standard and edge cases for ASR output parsing, including language detection, handling of empty inputs, and unexpected text.
- Added tests for repetition detection to ensure proper functionality based on specified thresholds.
---
.pre-commit-config.yaml | 1 +
.../tests/test_qwen3_asr_parsing.py | 125 ++++++++++++++++++
2 files changed, 126 insertions(+)
create mode 100644 extras/asr-services/tests/test_qwen3_asr_parsing.py
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index adf40dcb..859eca3c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,6 +9,7 @@ repos:
rev: 5.13.2
hooks:
- id: isort
+ args: ["--profile", "black"]
exclude: \.venv/
# File hygiene
diff --git a/extras/asr-services/tests/test_qwen3_asr_parsing.py b/extras/asr-services/tests/test_qwen3_asr_parsing.py
new file mode 100644
index 00000000..1b8f4c2f
--- /dev/null
+++ b/extras/asr-services/tests/test_qwen3_asr_parsing.py
@@ -0,0 +1,125 @@
+"""
+Tests for Qwen3-ASR output parsing and repetition detection.
+
+Pure function tests — no GPU, no vLLM, no network required.
+
+Run:
+ cd extras/asr-services
+ uv run pytest tests/test_qwen3_asr_parsing.py -v
+"""
+
+import sys
+from pathlib import Path
+
+import pytest
+
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
+
+from providers.qwen3_asr.transcriber import (
+ _parse_qwen3_output,
+ detect_and_fix_repetitions,
+)
+
+# ---------------------------------------------------------------------------
+# _parse_qwen3_output tests
+# ---------------------------------------------------------------------------
+
+
+class TestParseQwen3Output:
+ """Tests for _parse_qwen3_output(raw) → (language, text)."""
+
+ def test_standard_english(self):
+ lang, text = _parse_qwen3_output(
+ "language Englishhello world"
+ )
+ assert lang == "English"
+ assert text == "hello world"
+
+ def test_standard_chinese(self):
+ lang, text = _parse_qwen3_output(
+ "language Chinese你好世界"
+ )
+ assert lang == "Chinese"
+ assert text == "你好世界"
+
+ def test_silent_audio_language_none(self):
+ lang, text = _parse_qwen3_output("language None")
+ assert lang == ""
+ assert text == ""
+
+ def test_silent_with_unexpected_text(self):
+ lang, text = _parse_qwen3_output("language Nonehmm")
+ assert lang == ""
+ assert text == "hmm"
+
+ def test_plain_text_no_tags(self):
+ lang, text = _parse_qwen3_output("just plain text")
+ assert lang == ""
+ assert text == "just plain text"
+
+ def test_empty_string(self):
+ lang, text = _parse_qwen3_output("")
+ assert lang == ""
+ assert text == ""
+
+ def test_none_input(self):
+ lang, text = _parse_qwen3_output(None)
+ assert lang == ""
+ assert text == ""
+
+ def test_whitespace_only(self):
+ lang, text = _parse_qwen3_output(" ")
+ assert lang == ""
+ assert text == ""
+
+ def test_missing_closing_tag(self):
+ lang, text = _parse_qwen3_output("language Englishhello world")
+ assert lang == "English"
+ assert text == "hello world"
+
+ def test_multiline_metadata(self):
+ raw = "language English\nsome extra\ntext here"
+ lang, text = _parse_qwen3_output(raw)
+ assert lang == "English"
+ assert text == "text here"
+
+ def test_whitespace_around_text(self):
+ lang, text = _parse_qwen3_output(
+ "language English hello "
+ )
+ assert lang == "English"
+ assert text == "hello"
+
+
+# ---------------------------------------------------------------------------
+# detect_and_fix_repetitions tests
+# ---------------------------------------------------------------------------
+
+
+class TestDetectAndFixRepetitions:
+ """Tests for detect_and_fix_repetitions(text, threshold)."""
+
+ def test_normal_text_unchanged(self):
+ text = "Hello, how are you?"
+ assert detect_and_fix_repetitions(text) == text
+
+ def test_single_char_repeated_above_threshold(self):
+ result = detect_and_fix_repetitions("a" * 50)
+ assert result == "a"
+
+ def test_single_char_repeated_below_threshold(self):
+ text = "a" * 10
+ assert detect_and_fix_repetitions(text) == text
+
+ def test_pattern_repeated_above_threshold(self):
+ result = detect_and_fix_repetitions("ha" * 30)
+ assert result == "ha"
+
+ def test_short_text_unchanged(self):
+ assert detect_and_fix_repetitions("hi") == "hi"
+
+ def test_mixed_content_with_repeating_tail(self):
+ result = detect_and_fix_repetitions("Hello " + "x" * 50)
+ assert result.startswith("Hello ")
+ # The long run of x's should be collapsed
+ assert len(result) < len("Hello " + "x" * 50)
From 75cec50e5eb28dd573153d34a0439c13059b177c Mon Sep 17 00:00:00 2001
From: Ankush Malaker <43288948+AnkushMalaker@users.noreply.github.com>
Date: Sun, 22 Feb 2026 11:29:14 +0000
Subject: [PATCH 5/5] Refactor Redis session handling and enhance error
management
- Updated session retrieval logic in `queue_routes.py` to ensure proper closure of Redis connections using `await redis_client.aclose()`, improving resource management.
- Enhanced error handling during session data retrieval, providing clearer logging for issues encountered while fetching session information.
- Streamlined the session key scanning process, maintaining existing functionality while improving code readability and maintainability.
- Added optional parameters to the `transcribe` method in `mock_provider.py` for better flexibility in handling context information and progress callbacks during transcription tasks.
---
.../routers/modules/queue_routes.py | 354 ++++++------------
.../services/transcription/mock_provider.py | 28 +-
2 files changed, 137 insertions(+), 245 deletions(-)
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py
index d45513fd..29719566 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py
@@ -73,9 +73,7 @@ async def list_jobs(
@router.get("/jobs/{job_id}/status")
-async def get_job_status(
- job_id: str, current_user: User = Depends(current_active_user)
-):
+async def get_job_status(job_id: str, current_user: User = Depends(current_active_user)):
"""Get just the status of a specific job (lightweight endpoint)."""
try:
job = Job.fetch(job_id, connection=redis_conn)
@@ -195,9 +193,7 @@ async def cancel_job(job_id: str, current_user: User = Depends(current_active_us
@router.get("/jobs/by-client/{client_id}")
-async def get_jobs_by_client(
- client_id: str, current_user: User = Depends(current_active_user)
-):
+async def get_jobs_by_client(client_id: str, current_user: User = Depends(current_active_user)):
"""Get all jobs associated with a specific client device."""
try:
from rq.registry import (
@@ -246,17 +242,11 @@ def process_job_and_dependents(job, queue_name, base_status):
all_jobs.append(
{
"job_id": job.id,
- "job_type": (
- job.func_name.split(".")[-1] if job.func_name else "unknown"
- ),
+ "job_type": (job.func_name.split(".")[-1] if job.func_name else "unknown"),
"queue": queue_name,
"status": status,
- "created_at": (
- job.created_at.isoformat() if job.created_at else None
- ),
- "started_at": (
- job.started_at.isoformat() if job.started_at else None
- ),
+ "created_at": (job.created_at.isoformat() if job.created_at else None),
+ "started_at": (job.started_at.isoformat() if job.started_at else None),
"ended_at": job.ended_at.isoformat() if job.ended_at else None,
"description": job.description or "",
"result": job.result,
@@ -333,17 +323,13 @@ def process_job_and_dependents(job, queue_name, base_status):
# Sort by created_at
all_jobs.sort(key=lambda x: x["created_at"] or "", reverse=False)
- logger.info(
- f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)"
- )
+ logger.info(f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)")
return {"client_id": client_id, "jobs": all_jobs, "total": len(all_jobs)}
except Exception as e:
logger.error(f"Failed to get jobs for client {client_id}: {e}")
- raise HTTPException(
- status_code=500, detail=f"Failed to get jobs for client: {str(e)}"
- )
+ raise HTTPException(status_code=500, detail=f"Failed to get jobs for client: {str(e)}")
@router.get("/events")
@@ -363,9 +349,7 @@ async def get_events(
if not router_instance:
return {"events": [], "total": 0}
- events = router_instance.get_recent_events(
- limit=limit, event_type=event_type or None
- )
+ events = router_instance.get_recent_events(limit=limit, event_type=event_type or None)
return {"events": events, "total": len(events)}
except Exception as e:
logger.error(f"Failed to get events: {e}")
@@ -483,9 +467,7 @@ async def get_queue_worker_details(current_user: User = Depends(current_active_u
except Exception as e:
logger.error(f"Failed to get queue worker details: {e}")
- raise HTTPException(
- status_code=500, detail=f"Failed to get worker details: {str(e)}"
- )
+ raise HTTPException(status_code=500, detail=f"Failed to get worker details: {str(e)}")
@router.get("/streams")
@@ -516,9 +498,7 @@ async def get_stream_stats(
async def get_stream_info(stream_key):
try:
- stream_name = (
- stream_key.decode() if isinstance(stream_key, bytes) else stream_key
- )
+ stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key
# Get basic stream info
info = await audio_service.redis.xinfo_stream(stream_name)
@@ -578,9 +558,7 @@ async def get_stream_info(stream_key):
"name": group_dict.get("name", "unknown"),
"consumers": group_dict.get("consumers", 0),
"pending": group_dict.get("pending", 0),
- "last_delivered_id": group_dict.get(
- "last-delivered-id", "N/A"
- ),
+ "last_delivered_id": group_dict.get("last-delivered-id", "N/A"),
"consumer_details": consumers,
}
)
@@ -591,9 +569,7 @@ async def get_stream_info(stream_key):
"stream_name": stream_name,
"length": info[b"length"],
"first_entry_id": (
- info[b"first-entry"][0].decode()
- if info[b"first-entry"]
- else None
+ info[b"first-entry"][0].decode() if info[b"first-entry"] else None
),
"last_entry_id": (
info[b"last-entry"][0].decode() if info[b"last-entry"] else None
@@ -605,9 +581,7 @@ async def get_stream_info(stream_key):
return None
# Fetch all stream info in parallel
- streams_info_results = await asyncio.gather(
- *[get_stream_info(key) for key in stream_keys]
- )
+ streams_info_results = await asyncio.gather(*[get_stream_info(key) for key in stream_keys])
streams_info = [info for info in streams_info_results if info is not None]
return {
@@ -633,9 +607,7 @@ class FlushAllJobsRequest(BaseModel):
@router.post("/flush")
-async def flush_jobs(
- request: FlushJobsRequest, current_user: User = Depends(current_active_user)
-):
+async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(current_active_user)):
"""Flush old inactive jobs based on age and status."""
if not current_user.is_superuser:
raise HTTPException(status_code=403, detail="Admin access required")
@@ -651,9 +623,7 @@ async def flush_jobs(
from advanced_omi_backend.controllers.queue_controller import get_queue
- cutoff_time = datetime.now(timezone.utc) - timedelta(
- hours=request.older_than_hours
- )
+ cutoff_time = datetime.now(timezone.utc) - timedelta(hours=request.older_than_hours)
total_removed = 0
# Get all queues
@@ -685,9 +655,7 @@ async def flush_jobs(
except Exception as e:
logger.error(f"Error deleting job {job_id}: {e}")
- if (
- "canceled" in request.statuses
- ): # RQ standard (US spelling), not "cancelled"
+ if "canceled" in request.statuses: # RQ standard (US spelling), not "cancelled"
registry = CanceledJobRegistry(queue=queue)
for job_id in registry.get_job_ids():
try:
@@ -769,12 +737,8 @@ async def flush_all_jobs(
registries.append(("finished", FinishedJobRegistry(queue=queue)))
for registry_name, registry in registries:
- job_ids = list(
- registry.get_job_ids()
- ) # Convert to list to avoid iterator issues
- logger.info(
- f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}"
- )
+ job_ids = list(registry.get_job_ids()) # Convert to list to avoid iterator issues
+ logger.info(f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}")
for job_id in job_ids:
try:
@@ -784,9 +748,7 @@ async def flush_all_jobs(
# Skip session-level jobs (e.g., speech_detection, audio_persistence)
# These run for the entire session and should not be killed by test cleanup
if job.meta and job.meta.get("session_level"):
- logger.info(
- f"Skipping session-level job {job_id} ({job.description})"
- )
+ logger.info(f"Skipping session-level job {job_id} ({job.description})")
continue
# Handle running jobs differently to avoid worker deadlock
@@ -797,9 +759,7 @@ async def flush_all_jobs(
from rq.command import send_stop_job_command
send_stop_job_command(redis_conn, job_id)
- logger.info(
- f"Sent stop command to worker for job {job_id}"
- )
+ logger.info(f"Sent stop command to worker for job {job_id}")
# Don't delete yet - let worker move it to canceled/failed registry
# It will be cleaned up on next flush or by worker cleanup
continue
@@ -810,13 +770,9 @@ async def flush_all_jobs(
# If stop fails, try to cancel it (may already be finishing)
try:
job.cancel()
- logger.info(
- f"Cancelled job {job_id} after stop failed"
- )
+ logger.info(f"Cancelled job {job_id} after stop failed")
except Exception as cancel_error:
- logger.warning(
- f"Could not cancel job {job_id}: {cancel_error}"
- )
+ logger.warning(f"Could not cancel job {job_id}: {cancel_error}")
# For non-running jobs, safe to delete immediately
job.delete()
@@ -831,9 +787,7 @@ async def flush_all_jobs(
f"Removed stale job reference {job_id} from {registry_name} registry"
)
except Exception as reg_error:
- logger.error(
- f"Could not remove {job_id} from registry: {reg_error}"
- )
+ logger.error(f"Could not remove {job_id} from registry: {reg_error}")
# Also clean up audio streams and consumer locks
deleted_keys = 0
@@ -847,9 +801,7 @@ async def flush_all_jobs(
# Delete audio streams
cursor = 0
while True:
- cursor, keys = await async_redis.scan(
- cursor, match="audio:*", count=1000
- )
+ cursor, keys = await async_redis.scan(cursor, match="audio:*", count=1000)
if keys:
await async_redis.delete(*keys)
deleted_keys += len(keys)
@@ -859,9 +811,7 @@ async def flush_all_jobs(
# Delete consumer locks
cursor = 0
while True:
- cursor, keys = await async_redis.scan(
- cursor, match="consumer:*", count=1000
- )
+ cursor, keys = await async_redis.scan(cursor, match="consumer:*", count=1000)
if keys:
await async_redis.delete(*keys)
deleted_keys += len(keys)
@@ -890,9 +840,7 @@ async def flush_all_jobs(
except Exception as e:
logger.error(f"Failed to flush all jobs: {e}")
- raise HTTPException(
- status_code=500, detail=f"Failed to flush all jobs: {str(e)}"
- )
+ raise HTTPException(status_code=500, detail=f"Failed to flush all jobs: {str(e)}")
@router.get("/sessions")
@@ -907,63 +855,54 @@ async def get_redis_sessions(
from advanced_omi_backend.controllers.queue_controller import REDIS_URL
redis_client = aioredis.from_url(REDIS_URL)
+ try:
+ # Get session keys
+ session_keys = []
+ cursor = b"0"
+ while cursor and len(session_keys) < limit:
+ cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=limit)
+ session_keys.extend(keys[: limit - len(session_keys)])
+
+ # Get session info
+ sessions = []
+ for key in session_keys:
+ try:
+ session_data = await redis_client.hgetall(key)
+ if session_data:
+ session_id = key.decode().replace("audio:session:", "")
+
+ # Get conversation count for this session
+ conversation_count_key = f"session:conversation_count:{session_id}"
+ conversation_count_bytes = await redis_client.get(conversation_count_key)
+ conversation_count = (
+ int(conversation_count_bytes.decode())
+ if conversation_count_bytes
+ else 0
+ )
- # Get session keys
- session_keys = []
- cursor = b"0"
- while cursor and len(session_keys) < limit:
- cursor, keys = await redis_client.scan(
- cursor, match="audio:session:*", count=limit
- )
- session_keys.extend(keys[: limit - len(session_keys)])
-
- # Get session info
- sessions = []
- for key in session_keys:
- try:
- session_data = await redis_client.hgetall(key)
- if session_data:
- session_id = key.decode().replace("audio:session:", "")
-
- # Get conversation count for this session
- conversation_count_key = f"session:conversation_count:{session_id}"
- conversation_count_bytes = await redis_client.get(
- conversation_count_key
- )
- conversation_count = (
- int(conversation_count_bytes.decode())
- if conversation_count_bytes
- else 0
- )
-
- sessions.append(
- {
- "session_id": session_id,
- "user_id": session_data.get(b"user_id", b"").decode(),
- "client_id": session_data.get(b"client_id", b"").decode(),
- "stream_name": session_data.get(
- b"stream_name", b""
- ).decode(),
- "provider": session_data.get(b"provider", b"").decode(),
- "mode": session_data.get(b"mode", b"").decode(),
- "status": session_data.get(b"status", b"").decode(),
- "started_at": session_data.get(b"started_at", b"").decode(),
- "chunks_published": int(
- session_data.get(b"chunks_published", b"0").decode()
- or 0
- ),
- "last_chunk_at": session_data.get(
- b"last_chunk_at", b""
- ).decode(),
- "conversation_count": conversation_count,
- }
- )
- except Exception as e:
- logger.error(f"Error getting session info for {key}: {e}")
-
- await redis_client.close()
+ sessions.append(
+ {
+ "session_id": session_id,
+ "user_id": session_data.get(b"user_id", b"").decode(),
+ "client_id": session_data.get(b"client_id", b"").decode(),
+ "stream_name": session_data.get(b"stream_name", b"").decode(),
+ "provider": session_data.get(b"provider", b"").decode(),
+ "mode": session_data.get(b"mode", b"").decode(),
+ "status": session_data.get(b"status", b"").decode(),
+ "started_at": session_data.get(b"started_at", b"").decode(),
+ "chunks_published": int(
+ session_data.get(b"chunks_published", b"0").decode() or 0
+ ),
+ "last_chunk_at": session_data.get(b"last_chunk_at", b"").decode(),
+ "conversation_count": conversation_count,
+ }
+ )
+ except Exception as e:
+ logger.error(f"Error getting session info for {key}: {e}")
- return {"total_sessions": len(sessions), "sessions": sessions}
+ return {"total_sessions": len(sessions), "sessions": sessions}
+ finally:
+ await redis_client.aclose()
except Exception as e:
logger.error(f"Failed to get sessions: {e}", exc_info=True)
@@ -989,43 +928,40 @@ async def clear_old_sessions(
from advanced_omi_backend.controllers.queue_controller import REDIS_URL
redis_client = aioredis.from_url(REDIS_URL)
- current_time = time.time()
- cutoff_time = current_time - older_than_seconds
-
- # Get all session keys
- session_keys = []
- cursor = b"0"
- while cursor:
- cursor, keys = await redis_client.scan(
- cursor, match="audio:session:*", count=100
- )
- session_keys.extend(keys)
-
- # Check each session and delete if old
- deleted_count = 0
- for key in session_keys:
- try:
- session_data = await redis_client.hgetall(key)
- if session_data:
- last_chunk_at = session_data.get(b"last_chunk_at", b"").decode()
- if last_chunk_at:
- last_chunk_time = float(last_chunk_at)
- if last_chunk_time < cutoff_time:
- await redis_client.delete(key)
- deleted_count += 1
- logger.info(f"Deleted old session: {key.decode()}")
- except Exception as e:
- logger.error(f"Error processing session {key}: {e}")
-
- await redis_client.close()
-
- return {"deleted_count": deleted_count, "cutoff_seconds": older_than_seconds}
+ try:
+ current_time = time.time()
+ cutoff_time = current_time - older_than_seconds
+
+ # Get all session keys
+ session_keys = []
+ cursor = b"0"
+ while cursor:
+ cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=100)
+ session_keys.extend(keys)
+
+ # Check each session and delete if old
+ deleted_count = 0
+ for key in session_keys:
+ try:
+ session_data = await redis_client.hgetall(key)
+ if session_data:
+ last_chunk_at = session_data.get(b"last_chunk_at", b"").decode()
+ if last_chunk_at:
+ last_chunk_time = float(last_chunk_at)
+ if last_chunk_time < cutoff_time:
+ await redis_client.delete(key)
+ deleted_count += 1
+ logger.info(f"Deleted old session: {key.decode()}")
+ except Exception as e:
+ logger.error(f"Error processing session {key}: {e}")
+
+ return {"deleted_count": deleted_count, "cutoff_seconds": older_than_seconds}
+ finally:
+ await redis_client.aclose()
except Exception as e:
logger.error(f"Failed to clear sessions: {e}", exc_info=True)
- raise HTTPException(
- status_code=500, detail=f"Failed to clear sessions: {str(e)}"
- )
+ raise HTTPException(status_code=500, detail=f"Failed to clear sessions: {str(e)}")
@router.get("/dashboard")
@@ -1077,17 +1013,11 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100):
if status_name == "queued":
job_ids = queue.job_ids[:limit]
elif status_name == "started": # RQ standard, not "processing"
- job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[
- :limit
- ]
+ job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[:limit]
elif status_name == "finished": # RQ standard, not "completed"
- job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[
- :limit
- ]
+ job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[:limit]
elif status_name == "failed":
- job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[
- :limit
- ]
+ job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[:limit]
else:
continue
@@ -1098,9 +1028,7 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100):
# Check user permission
if not current_user.is_superuser:
- job_user_id = (
- job.kwargs.get("user_id") if job.kwargs else None
- )
+ job_user_id = job.kwargs.get("user_id") if job.kwargs else None
if job_user_id != str(current_user.user_id):
continue
@@ -1109,38 +1037,24 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100):
{
"job_id": job.id,
"job_type": (
- job.func_name.split(".")[-1]
- if job.func_name
- else "unknown"
- ),
- "user_id": (
- job.kwargs.get("user_id")
- if job.kwargs
- else None
+ job.func_name.split(".")[-1] if job.func_name else "unknown"
),
+ "user_id": (job.kwargs.get("user_id") if job.kwargs else None),
"status": status_name,
"priority": "normal", # RQ doesn't have priority concept
"data": {"description": job.description or ""},
"result": job.result,
"meta": job.meta if job.meta else {},
"kwargs": job.kwargs if job.kwargs else {},
- "error_message": (
- str(job.exc_info) if job.exc_info else None
- ),
+ "error_message": (str(job.exc_info) if job.exc_info else None),
"created_at": (
- job.created_at.isoformat()
- if job.created_at
- else None
+ job.created_at.isoformat() if job.created_at else None
),
"started_at": (
- job.started_at.isoformat()
- if job.started_at
- else None
+ job.started_at.isoformat() if job.started_at else None
),
"ended_at": (
- job.ended_at.isoformat()
- if job.ended_at
- else None
+ job.ended_at.isoformat() if job.ended_at else None
),
"retry_count": 0, # RQ doesn't track this by default
"max_retries": 0,
@@ -1256,11 +1170,7 @@ def get_job_status(job):
# Check user permission
if not current_user.is_superuser:
- job_user_id = (
- job.kwargs.get("user_id")
- if job.kwargs
- else None
- )
+ job_user_id = job.kwargs.get("user_id") if job.kwargs else None
if job_user_id != str(current_user.user_id):
continue
@@ -1276,19 +1186,13 @@ def get_job_status(job):
"queue": queue_name,
"status": get_job_status(job),
"created_at": (
- job.created_at.isoformat()
- if job.created_at
- else None
+ job.created_at.isoformat() if job.created_at else None
),
"started_at": (
- job.started_at.isoformat()
- if job.started_at
- else None
+ job.started_at.isoformat() if job.started_at else None
),
"ended_at": (
- job.ended_at.isoformat()
- if job.ended_at
- else None
+ job.ended_at.isoformat() if job.ended_at else None
),
"description": job.description or "",
"result": job.result,
@@ -1351,20 +1255,12 @@ async def fetch_events():
)
queued_jobs = results[0] if not isinstance(results[0], Exception) else []
- started_jobs = (
- results[1] if not isinstance(results[1], Exception) else []
- ) # RQ standard
- finished_jobs = (
- results[2] if not isinstance(results[2], Exception) else []
- ) # RQ standard
+ started_jobs = results[1] if not isinstance(results[1], Exception) else [] # RQ standard
+ finished_jobs = results[2] if not isinstance(results[2], Exception) else [] # RQ standard
failed_jobs = results[3] if not isinstance(results[3], Exception) else []
- stats = (
- results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0}
- )
+ stats = results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0}
streaming_status = (
- results[5]
- if not isinstance(results[5], Exception)
- else {"active_sessions": []}
+ results[5] if not isinstance(results[5], Exception) else {"active_sessions": []}
)
events = results[6] if not isinstance(results[6], Exception) else []
recent_conversations = []
@@ -1383,9 +1279,7 @@ async def fetch_events():
{
"conversation_id": conv.conversation_id,
"user_id": str(conv.user_id) if conv.user_id else None,
- "created_at": (
- conv.created_at.isoformat() if conv.created_at else None
- ),
+ "created_at": (conv.created_at.isoformat() if conv.created_at else None),
"title": conv.title,
"summary": conv.summary,
"transcript_text": (
@@ -1413,6 +1307,4 @@ async def fetch_events():
except Exception as e:
logger.error(f"Failed to get dashboard data: {e}", exc_info=True)
- raise HTTPException(
- status_code=500, detail=f"Failed to get dashboard data: {str(e)}"
- )
+ raise HTTPException(status_code=500, detail=f"Failed to get dashboard data: {str(e)}")
diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py b/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py
index f596214f..02e5b37c 100644
--- a/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py
+++ b/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py
@@ -33,7 +33,15 @@ def name(self) -> str:
"""Return the provider name for logging."""
return "mock"
- async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False) -> dict:
+ async def transcribe(
+ self,
+ audio_data: bytes,
+ sample_rate: int,
+ diarize: bool = False,
+ context_info=None,
+ progress_callback=None,
+ **kwargs,
+ ) -> dict:
"""
Return a predefined mock transcript or raise exception in fail mode.
@@ -41,6 +49,9 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool =
audio_data: Raw audio bytes (ignored in mock)
sample_rate: Audio sample rate (ignored in mock)
diarize: Whether to enable speaker diarization (ignored in mock)
+ context_info: Optional ASR context (ignored in mock)
+ progress_callback: Optional callback for batch progress (ignored in mock)
+ **kwargs: Additional parameters (ignored in mock)
Returns:
Dictionary containing predefined transcript with words and segments
@@ -88,20 +99,9 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool =
]
# Mock segments (single speaker for simplicity)
- segments = [
- {
- "speaker": 0,
- "start": 0.0,
- "end": 6.5,
- "text": mock_transcript
- }
- ]
+ segments = [{"speaker": 0, "start": 0.0, "end": 6.5, "text": mock_transcript}]
- return {
- "text": mock_transcript,
- "words": words,
- "segments": segments if diarize else []
- }
+ return {"text": mock_transcript, "words": words, "segments": segments if diarize else []}
async def connect(self, client_id: Optional[str] = None):
"""Initialize the mock provider (no-op)."""