diff --git a/backends/advanced/Docs/plugin-development-guide.md b/backends/advanced/Docs/plugin-development-guide.md index 17c53b4a..a7361469 100644 --- a/backends/advanced/Docs/plugin-development-guide.md +++ b/backends/advanced/Docs/plugin-development-guide.md @@ -24,11 +24,6 @@ Chronicle's plugin system allows you to extend functionality by subscribing to e - **Configurable**: YAML-based configuration with environment variable support - **Isolated**: Each plugin runs independently with proper error handling -### Plugin Types - -- **Core Plugins**: Built-in plugins (`homeassistant`, `test_event`) -- **Community Plugins**: Auto-discovered plugins in `plugins/` directory - ## Quick Start ### 1. Generate Plugin Boilerplate @@ -207,6 +202,84 @@ async def on_memory_processed(self, context: PluginContext): await self.index_memory(memory) ``` +### 4. Button Events (`button.single_press`, `button.double_press`) + +**When**: OMI device button is pressed +**Context Data**: +- `state` (str): Button state (`SINGLE_TAP`, `DOUBLE_TAP`) +- `timestamp` (float): Unix timestamp of the event +- `audio_uuid` (str): Current audio session UUID (may be None) +- `session_id` (str): Streaming session ID (for conversation close) +- `client_id` (str): Client device identifier + +**Data Flow**: +``` +OMI Device (BLE) + → Button press on physical device + → BLE characteristic notifies with 8-byte payload + ↓ +friend-lite-sdk (extras/friend-lite-sdk/) + → parse_button_event() converts payload → ButtonState IntEnum + ↓ +BLE Client (extras/local-omi-bt/ or mobile app) + → Formats as Wyoming protocol: {"type": "button-event", "data": {"state": "SINGLE_TAP"}} + → Sends over WebSocket + ↓ +Backend (websocket_controller.py) + → _handle_button_event() stores marker on client_state + → Maps ButtonState → PluginEvent using enums (plugins/events.py) + → Dispatches granular event to plugin system + ↓ +Plugin System + → Routed to subscribed plugins (e.g., test_button_actions) + → Plugins use PluginServices for system actions and cross-plugin calls +``` + +**Use Cases**: +- Close current conversation (single press) +- Toggle smart home devices (double press) +- Custom actions via cross-plugin communication + +**Example**: +```python +async def on_button_event(self, context: PluginContext): + if context.event == PluginEvent.BUTTON_SINGLE_PRESS: + session_id = context.data.get('session_id') + await context.services.close_conversation(session_id) +``` + +### 5. Plugin Action Events (`plugin_action`) + +**When**: Another plugin calls `context.services.call_plugin()` +**Context Data**: +- `action` (str): Action name (e.g., `toggle_lights`) +- Plus any additional data from the calling plugin + +**Use Cases**: +- Cross-plugin communication (button press → toggle lights) +- Service orchestration between plugins + +**Example**: +```python +async def on_plugin_action(self, context: PluginContext): + action = context.data.get('action') + if action == 'toggle_lights': + # Handle the action + ... 
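+    elif action == 'set_scene':
+        # Hypothetical second action, shown only to illustrate dispatching
+        # on the 'action' field; real action names are defined by the
+        # calling plugin.
+        ...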
+``` + +### PluginServices + +Plugins receive a `services` object on the context for system and cross-plugin interaction: + +```python +# Close the current conversation (triggers post-processing) +await context.services.close_conversation(session_id, reason) + +# Call another plugin's on_plugin_action() handler +result = await context.services.call_plugin("homeassistant", "toggle_lights", data) +``` + ## Creating Your First Plugin ### Step 1: Generate Boilerplate @@ -225,7 +298,7 @@ import logging import re from typing import Any, Dict, List, Optional -from ..base import BasePlugin, PluginContext, PluginResult +from advanced_omi_backend.plugins.base import BasePlugin, PluginContext, PluginResult logger = logging.getLogger(__name__) @@ -671,7 +744,7 @@ async def on_conversation_complete(self, context): **Solution**: - Restart backend after adding dependencies - Verify imports are from correct modules -- Check relative imports use `..base` for base classes +- Use absolute imports for framework classes: `from advanced_omi_backend.plugins.base import BasePlugin` ### Database Connection Issues @@ -749,12 +822,13 @@ class ExternalServicePlugin(BasePlugin): ## Resources -- **Base Plugin Class**: `backends/advanced/src/advanced_omi_backend/plugins/base.py` -- **Example Plugins**: +- **Plugin Framework**: `backends/advanced/src/advanced_omi_backend/plugins/` (base.py, router.py, events.py, services.py) +- **Plugin Implementations**: `plugins/` at repo root - Email Summarizer: `plugins/email_summarizer/` - Home Assistant: `plugins/homeassistant/` - Test Event: `plugins/test_event/` -- **Plugin Generator**: `scripts/create_plugin.py` + - Test Button Actions: `plugins/test_button_actions/` +- **Plugin Generator**: `backends/advanced/scripts/create_plugin.py` - **Configuration**: `config/plugins.yml.template` ## Contributing Plugins diff --git a/backends/advanced/docker-compose-test.yml b/backends/advanced/docker-compose-test.yml index a18b0493..86bb8325 100644 --- a/backends/advanced/docker-compose-test.yml +++ b/backends/advanced/docker-compose-test.yml @@ -21,6 +21,7 @@ services: - ../../config:/app/config # Mount config directory with defaults.yml - ../../tests/configs:/app/test-configs:ro # Mount test-specific configs - ${PLUGINS_CONFIG:-../../tests/config/plugins.test.yml}:/app/config/plugins.yml # Mount test plugins config to correct location + - ../../plugins:/app/plugins # External plugins directory environment: # Override with test-specific settings - MONGODB_URI=mongodb://mongo-test:27017/test_db @@ -223,6 +224,7 @@ services: - ../../config:/app/config # Mount config directory with defaults.yml - ../../tests/configs:/app/test-configs:ro # Mount test-specific configs - ${PLUGINS_CONFIG:-../../tests/config/plugins.test.yml}:/app/config/plugins.yml # Mount test plugins config to correct location + - ../../plugins:/app/plugins # External plugins directory environment: # Same environment as backend - MONGODB_URI=mongodb://mongo-test:27017/test_db diff --git a/backends/advanced/docker-compose.yml b/backends/advanced/docker-compose.yml index 95cc4cab..d8e33c7e 100644 --- a/backends/advanced/docker-compose.yml +++ b/backends/advanced/docker-compose.yml @@ -40,6 +40,7 @@ services: - ./data/debug_dir:/app/debug_dir - ./data:/app/data - ../../config:/app/config # Mount entire config directory (includes config.yml, defaults.yml, plugins.yml) + - ../../plugins:/app/plugins # External plugins directory environment: - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} - PARAKEET_ASR_URL=${PARAKEET_ASR_URL} @@ -95,6 
+96,7 @@ services: - ./data/audio_chunks:/app/audio_chunks - ./data:/app/data - ../../config:/app/config # Mount entire config directory (includes config.yml, defaults.yml, plugins.yml) + - ../../plugins:/app/plugins # External plugins directory environment: - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} - PARAKEET_ASR_URL=${PARAKEET_ASR_URL} @@ -212,8 +214,8 @@ services: - "6033:6033" # gRPC - "6034:6034" # HTTP volumes: - - ./data/qdrant_data:/qdrant/storage - + - ./data/qdrant_data:/qdrant/storage + restart: unless-stopped mongo: image: mongo:8.0.14 @@ -227,6 +229,7 @@ services: timeout: 5s retries: 5 start_period: 10s + restart: unless-stopped redis: image: redis:7-alpine @@ -235,6 +238,7 @@ services: volumes: - ./data/redis_data:/data command: redis-server --appendonly yes + restart: unless-stopped healthcheck: test: ["CMD", "redis-cli", "ping"] interval: 5s @@ -267,9 +271,6 @@ services: timeout: 10s retries: 5 start_period: 30s - profiles: - - obsidian - - knowledge-graph # ollama: # image: ollama/ollama:latest diff --git a/backends/advanced/init.py b/backends/advanced/init.py index ff57242d..39561c71 100644 --- a/backends/advanced/init.py +++ b/backends/advanced/init.py @@ -444,41 +444,44 @@ def setup_optional_services(self): self.config["TS_AUTHKEY"] = self.args.ts_authkey self.console.print(f"[green][SUCCESS][/green] Tailscale auth key configured (Docker integration enabled)") + def setup_neo4j(self): + """Configure Neo4j credentials (always required - used by Knowledge Graph)""" + neo4j_password = getattr(self.args, 'neo4j_password', None) + + if neo4j_password: + self.console.print(f"[green]✅[/green] Neo4j: password configured via wizard") + else: + # Interactive prompt (standalone init.py run) + self.console.print() + self.console.print("[bold cyan]Neo4j Configuration[/bold cyan]") + self.console.print("Neo4j is used for Knowledge Graph (entity/relationship extraction)") + self.console.print() + neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") + + self.config["NEO4J_HOST"] = "neo4j" + self.config["NEO4J_USER"] = "neo4j" + self.config["NEO4J_PASSWORD"] = neo4j_password + self.console.print("[green][SUCCESS][/green] Neo4j credentials configured") + def setup_obsidian(self): - """Configure Obsidian/Neo4j integration""" - # Check if enabled via command line + """Configure Obsidian integration (optional feature flag only - Neo4j credentials handled by setup_neo4j)""" if hasattr(self.args, 'enable_obsidian') and self.args.enable_obsidian: enable_obsidian = True - neo4j_password = getattr(self.args, 'neo4j_password', None) - - if not neo4j_password: - self.console.print("[yellow][WARNING][/yellow] --enable-obsidian provided but no password") - neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") - - self.console.print(f"[green]✅[/green] Obsidian/Neo4j: enabled (configured via wizard)") + self.console.print(f"[green]✅[/green] Obsidian: enabled (configured via wizard)") else: # Interactive prompt (fallback) self.console.print() - self.console.print("[bold cyan]Obsidian/Neo4j Integration[/bold cyan]") + self.console.print("[bold cyan]Obsidian Integration (Optional)[/bold cyan]") self.console.print("Enable graph-based knowledge management for Obsidian vault notes") self.console.print() try: - enable_obsidian = Confirm.ask("Enable Obsidian/Neo4j integration?", default=False) + enable_obsidian = Confirm.ask("Enable Obsidian integration?", default=False) except EOFError: self.console.print("Using default: No") enable_obsidian = False - if enable_obsidian: - 
neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") - if enable_obsidian: - # Update .env with credentials only (secrets, not feature flags) - self.config["NEO4J_HOST"] = "neo4j" - self.config["NEO4J_USER"] = "neo4j" - self.config["NEO4J_PASSWORD"] = neo4j_password - - # Update config.yml with feature flag (source of truth) - auto-saves via ConfigManager self.config_manager.update_memory_config({ "obsidian": { "enabled": True, @@ -486,11 +489,8 @@ def setup_obsidian(self): "timeout": 30 } }) - - self.console.print("[green][SUCCESS][/green] Obsidian/Neo4j configured") - self.console.print("[blue][INFO][/blue] Neo4j will start automatically with --profile obsidian") + self.console.print("[green][SUCCESS][/green] Obsidian integration enabled") else: - # Explicitly disable Obsidian in config.yml when not enabled self.config_manager.update_memory_config({ "obsidian": { "enabled": False, @@ -498,52 +498,25 @@ def setup_obsidian(self): "timeout": 30 } }) - self.console.print("[blue][INFO][/blue] Obsidian/Neo4j integration disabled") + self.console.print("[blue][INFO][/blue] Obsidian integration disabled") def setup_knowledge_graph(self): - """Configure Knowledge Graph (Neo4j-based entity/relationship extraction)""" - # Check if enabled via command line + """Configure Knowledge Graph (Neo4j-based entity/relationship extraction - enabled by default)""" if hasattr(self.args, 'enable_knowledge_graph') and self.args.enable_knowledge_graph: enable_kg = True - neo4j_password = getattr(self.args, 'neo4j_password', None) - - if not neo4j_password: - # Check if already set from obsidian setup - neo4j_password = self.config.get("NEO4J_PASSWORD") - if not neo4j_password: - self.console.print("[yellow][WARNING][/yellow] --enable-knowledge-graph provided but no password") - neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") else: - # Interactive prompt (fallback) self.console.print() self.console.print("[bold cyan]Knowledge Graph (Entity Extraction)[/bold cyan]") - self.console.print("Enable graph-based entity and relationship extraction from conversations") - self.console.print("Extracts: People, Places, Organizations, Events, Promises/Tasks") + self.console.print("Extract people, places, organizations, events, and tasks from conversations") self.console.print() try: - enable_kg = Confirm.ask("Enable Knowledge Graph?", default=False) + enable_kg = Confirm.ask("Enable Knowledge Graph?", default=True) except EOFError: - self.console.print("Using default: No") - enable_kg = False - - if enable_kg: - # Check if Neo4j password already set from obsidian setup - existing_password = self.config.get("NEO4J_PASSWORD") - if existing_password: - self.console.print("[blue][INFO][/blue] Using Neo4j password from Obsidian configuration") - neo4j_password = existing_password - else: - neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") + self.console.print("Using default: Yes") + enable_kg = True if enable_kg: - # Update .env with credentials only (secrets, not feature flags) - self.config["NEO4J_HOST"] = "neo4j" - self.config["NEO4J_USER"] = "neo4j" - if neo4j_password: - self.config["NEO4J_PASSWORD"] = neo4j_password - - # Update config.yml with feature flag (source of truth) - auto-saves via ConfigManager self.config_manager.update_memory_config({ "knowledge_graph": { "enabled": True, @@ -551,12 +524,9 @@ def setup_knowledge_graph(self): "timeout": 30 } }) - - self.console.print("[green][SUCCESS][/green] Knowledge Graph configured") - 
self.console.print("[blue][INFO][/blue] Neo4j will start automatically with --profile knowledge-graph") + self.console.print("[green][SUCCESS][/green] Knowledge Graph enabled") self.console.print("[blue][INFO][/blue] Entities and relationships will be extracted from conversations") else: - # Explicitly disable Knowledge Graph in config.yml when not enabled self.config_manager.update_memory_config({ "knowledge_graph": { "enabled": False, @@ -577,12 +547,14 @@ def setup_langfuse(self): if langfuse_pub and langfuse_sec: # Auto-configure from wizard — no prompts needed - self.config["LANGFUSE_HOST"] = "http://langfuse-web:3000" + langfuse_host = getattr(self.args, 'langfuse_host', None) or "http://langfuse-web:3000" + self.config["LANGFUSE_HOST"] = langfuse_host self.config["LANGFUSE_PUBLIC_KEY"] = langfuse_pub self.config["LANGFUSE_SECRET_KEY"] = langfuse_sec - self.config["LANGFUSE_BASE_URL"] = "http://langfuse-web:3000" - self.console.print("[green][SUCCESS][/green] LangFuse auto-configured from wizard") - self.console.print(f"[blue][INFO][/blue] Host: http://langfuse-web:3000") + self.config["LANGFUSE_BASE_URL"] = langfuse_host + source = "external" if "langfuse-web" not in langfuse_host else "local" + self.console.print(f"[green][SUCCESS][/green] LangFuse auto-configured ({source})") + self.console.print(f"[blue][INFO][/blue] Host: {langfuse_host}") self.console.print(f"[blue][INFO][/blue] Public key: {self.mask_api_key(langfuse_pub)}") return @@ -842,25 +814,7 @@ def show_next_steps(self): config_yml = self.config_manager.get_full_config() self.console.print("1. Start the main services:") - # Include --profile obsidian/knowledge-graph if enabled (read from config.yml) - obsidian_enabled = config_yml.get("memory", {}).get("obsidian", {}).get("enabled", False) - kg_enabled = config_yml.get("memory", {}).get("knowledge_graph", {}).get("enabled", False) - - profiles = [] - profile_notes = [] - if obsidian_enabled: - profiles.append("obsidian") - profile_notes.append("Obsidian integration") - if kg_enabled: - profiles.append("knowledge-graph") - profile_notes.append("Knowledge Graph") - - if profiles: - profile_args = " ".join([f"--profile {p}" for p in profiles]) - self.console.print(f" [cyan]docker compose {profile_args} up --build -d[/cyan]") - self.console.print(f" [dim](Includes Neo4j for: {', '.join(profile_notes)})[/dim]") - else: - self.console.print(" [cyan]docker compose up --build -d[/cyan]") + self.console.print(" [cyan]docker compose up --build -d[/cyan]") self.console.print() # Auto-determine URLs for next steps @@ -908,6 +862,7 @@ def run(self): self.setup_llm() self.setup_memory() self.setup_optional_services() + self.setup_neo4j() self.setup_obsidian() self.setup_knowledge_graph() self.setup_langfuse() @@ -967,9 +922,11 @@ def main(): parser.add_argument("--ts-authkey", help="Tailscale auth key for Docker integration (default: prompt user)") parser.add_argument("--langfuse-public-key", - help="LangFuse project public key (from langfuse init)") + help="LangFuse project public key (from langfuse init or external)") parser.add_argument("--langfuse-secret-key", - help="LangFuse project secret key (from langfuse init)") + help="LangFuse project secret key (from langfuse init or external)") + parser.add_argument("--langfuse-host", + help="LangFuse host URL (default: http://langfuse-web:3000 for local)") args = parser.parse_args() diff --git a/backends/advanced/pyproject.toml b/backends/advanced/pyproject.toml index 23c736d7..e0e964c0 100644 --- a/backends/advanced/pyproject.toml 
+++ b/backends/advanced/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "google-auth-oauthlib>=1.0.0", "google-auth-httplib2>=0.2.0", "websockets>=12.0", + "croniter>=1.3.0", ] [project.optional-dependencies] diff --git a/backends/advanced/scripts/create_plugin.py b/backends/advanced/scripts/create_plugin.py index 41b93c83..f24427ad 100755 --- a/backends/advanced/scripts/create_plugin.py +++ b/backends/advanced/scripts/create_plugin.py @@ -37,10 +37,10 @@ def create_plugin(plugin_name: str, force: bool = False): # Convert to class name class_name = snake_to_pascal(plugin_name) + 'Plugin' - # Get plugins directory + # Get plugins directory (repo root plugins/) script_dir = Path(__file__).parent backend_dir = script_dir.parent - plugins_dir = backend_dir / 'src' / 'advanced_omi_backend' / 'plugins' + plugins_dir = backend_dir.parent.parent / 'plugins' plugin_dir = plugins_dir / plugin_name # Check if plugin already exists @@ -83,7 +83,7 @@ def create_plugin(plugin_name: str, force: bool = False): import logging from typing import Any, Dict, List, Optional -from ..base import BasePlugin, PluginContext, PluginResult +from advanced_omi_backend.plugins.base import BasePlugin, PluginContext, PluginResult logger = logging.getLogger(__name__) diff --git a/backends/advanced/src/advanced_omi_backend/app_factory.py b/backends/advanced/src/advanced_omi_backend/app_factory.py index 3c0417eb..6a99c841 100644 --- a/backends/advanced/src/advanced_omi_backend/app_factory.py +++ b/backends/advanced/src/advanced_omi_backend/app_factory.py @@ -221,6 +221,23 @@ async def lifespan(app: FastAPI): # Register OpenMemory user if using openmemory_mcp provider await initialize_openmemory_user() + # Start cron scheduler (requires Redis to be available) + try: + from advanced_omi_backend.cron_scheduler import get_scheduler, register_cron_job + from advanced_omi_backend.workers.finetuning_jobs import ( + run_asr_jargon_extraction_job, + run_speaker_finetuning_job, + ) + + register_cron_job("speaker_finetuning", run_speaker_finetuning_job) + register_cron_job("asr_jargon_extraction", run_asr_jargon_extraction_job) + + scheduler = get_scheduler() + await scheduler.start() + application_logger.info("Cron scheduler started") + except Exception as e: + application_logger.warning(f"Cron scheduler failed to start: {e}") + # SystemTracker is used for monitoring and debugging application_logger.info("Using SystemTracker for monitoring and debugging") @@ -319,6 +336,16 @@ async def lifespan(app: FastAPI): except Exception as e: application_logger.error(f"Error shutting down plugins: {e}") + # Shutdown cron scheduler + try: + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + await scheduler.stop() + application_logger.info("Cron scheduler stopped") + except Exception as e: + application_logger.error(f"Error stopping cron scheduler: {e}") + # Shutdown memory service and speaker service shutdown_memory_service() application_logger.info("Memory and speaker services shut down.") diff --git a/backends/advanced/src/advanced_omi_backend/client.py b/backends/advanced/src/advanced_omi_backend/client.py index a92fbc10..79ee2957 100644 --- a/backends/advanced/src/advanced_omi_backend/client.py +++ b/backends/advanced/src/advanced_omi_backend/client.py @@ -51,6 +51,9 @@ def __init__( # NOTE: Removed in-memory transcript storage for single source of truth # Transcripts are stored only in MongoDB via TranscriptionManager + # Markers (e.g., button events) collected during the session + self.markers: 
List[dict] = [] + # Track if conversation has been closed self.conversation_closed: bool = False @@ -102,6 +105,10 @@ def update_transcript_received(self): """Update timestamp when transcript is received (for timeout detection).""" self.last_transcript_time = time.time() + def add_marker(self, marker: dict) -> None: + """Add a marker (e.g., button event) to the current session.""" + self.markers.append(marker) + def should_start_new_conversation(self) -> bool: """Check if we should start a new conversation based on timeout.""" if self.last_transcript_time is None: @@ -114,8 +121,7 @@ def should_start_new_conversation(self) -> bool: return time_since_last_transcript > timeout_seconds async def close_current_conversation(self): - """Close the current conversation and queue necessary processing.""" - # Prevent double closure + """Clean up in-memory speech segments for the current conversation.""" if self.conversation_closed: audio_logger.debug( f"🔒 Conversation already closed for client {self.client_id}, skipping" @@ -125,23 +131,15 @@ async def close_current_conversation(self): self.conversation_closed = True if not self.current_audio_uuid: - audio_logger.info(f"🔒 No active conversation to close for client {self.client_id}") return - # NOTE: ClientState is legacy V1 code. In V2 architecture, conversation closure - # is handled by the websocket controllers using RQ jobs directly. - # This method is kept minimal for backward compatibility. + audio_logger.info(f"🔒 Closing conversation state for client {self.client_id}") - audio_logger.info(f"🔒 Closing conversation for client {self.client_id}, audio_uuid: {self.current_audio_uuid}") - - # Clean up speech segments for this conversation if self.current_audio_uuid in self.speech_segments: del self.speech_segments[self.current_audio_uuid] if self.current_audio_uuid in self.current_speech_start: del self.current_speech_start[self.current_audio_uuid] - audio_logger.info(f"✅ Cleaned up state for {self.current_audio_uuid}") - async def start_new_conversation(self): """Start a new conversation by closing current and resetting state.""" await self.close_current_conversation() @@ -151,11 +149,9 @@ async def start_new_conversation(self): self.conversation_start_time = time.time() self.last_transcript_time = None self.conversation_closed = False + self.markers = [] - audio_logger.info( - f"Client {self.client_id}: Started new conversation due to " - f"{NEW_CONVERSATION_TIMEOUT_MINUTES}min timeout" - ) + audio_logger.info(f"Client {self.client_id}: Started new conversation") async def disconnect(self): """Clean disconnect of client state.""" diff --git a/backends/advanced/src/advanced_omi_backend/config.py b/backends/advanced/src/advanced_omi_backend/config.py index 77a842ce..63b1dcd7 100644 --- a/backends/advanced/src/advanced_omi_backend/config.py +++ b/backends/advanced/src/advanced_omi_backend/config.py @@ -198,9 +198,14 @@ def get_misc_settings() -> dict: transcription_cfg = get_backend_config('transcription') transcription_settings = OmegaConf.to_container(transcription_cfg, resolve=True) if transcription_cfg else {} + # Get speaker recognition settings for per_segment_speaker_id + speaker_cfg = get_backend_config('speaker_recognition') + speaker_settings = OmegaConf.to_container(speaker_cfg, resolve=True) if speaker_cfg else {} + return { 'always_persist_enabled': audio_settings.get('always_persist_enabled', False), - 'use_provider_segments': transcription_settings.get('use_provider_segments', False) + 'use_provider_segments': 
transcription_settings.get('use_provider_segments', False), + 'per_segment_speaker_id': speaker_settings.get('per_segment_speaker_id', False), } @@ -228,4 +233,10 @@ def save_misc_settings(settings: dict) -> bool: if not save_config_section('backend.transcription', transcription_settings): success = False + # Save speaker recognition settings if per_segment_speaker_id is provided + if 'per_segment_speaker_id' in settings: + speaker_settings = {'per_segment_speaker_id': settings['per_segment_speaker_id']} + if not save_config_section('backend.speaker_recognition', speaker_settings): + success = False + return success \ No newline at end of file diff --git a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py index 734df6ed..b1646a8e 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py @@ -8,6 +8,7 @@ """ import logging +import os import time import uuid @@ -24,7 +25,10 @@ from advanced_omi_backend.services.transcription import is_transcription_available from advanced_omi_backend.utils.audio_chunk_utils import convert_audio_to_chunks from advanced_omi_backend.utils.audio_utils import ( + SUPPORTED_AUDIO_EXTENSIONS, + VIDEO_EXTENSIONS, AudioValidationError, + convert_any_to_wav, validate_and_prepare_audio, ) from advanced_omi_backend.workers.transcription_jobs import ( @@ -71,22 +75,38 @@ async def upload_and_process_audio_files( for file_index, file in enumerate(files): try: - # Validate file type (only WAV for now) - if not file.filename or not file.filename.lower().endswith(".wav"): + # Validate file type + filename = file.filename or "unknown" + _, ext = os.path.splitext(filename.lower()) + if not ext or ext not in SUPPORTED_AUDIO_EXTENSIONS: + supported = ", ".join(sorted(SUPPORTED_AUDIO_EXTENSIONS)) processed_files.append({ - "filename": file.filename or "unknown", + "filename": filename, "status": "error", - "error": "Only WAV files are currently supported", + "error": f"Unsupported format '{ext}'. Supported: {supported}", }) continue + is_video_source = ext in VIDEO_EXTENSIONS + audio_logger.info( - f"📁 Uploading file {file_index + 1}/{len(files)}: {file.filename}" + f"📁 Uploading file {file_index + 1}/{len(files)}: {filename}" ) # Read file content content = await file.read() + # Convert non-WAV files to WAV via FFmpeg + if ext != ".wav": + try: + content = await convert_any_to_wav(content, ext) + except AudioValidationError as e: + processed_files.append({ + "filename": filename, + "status": "error", + "error": str(e), + }) + continue # Track external source for deduplication (Google Drive, etc.) 
external_source_id = None @@ -95,7 +115,7 @@ async def upload_and_process_audio_files( external_source_id = getattr(file, "file_id", None) or getattr(file, "audio_uuid", None) external_source_type = "gdrive" if not external_source_id: - audio_logger.warning(f"Missing file_id for gdrive file: {file.filename}") + audio_logger.warning(f"Missing file_id for gdrive file: {filename}") timestamp = int(time.time() * 1000) # Validate and prepare audio (read format from WAV file) @@ -108,21 +128,18 @@ async def upload_and_process_audio_files( ) except AudioValidationError as e: processed_files.append({ - "filename": file.filename, + "filename": filename, "status": "error", "error": str(e), }) continue audio_logger.info( - f"📊 {file.filename}: {duration:.1f}s ({sample_rate}Hz, {channels}ch, {sample_width} bytes/sample)" + f"📊 {filename}: {duration:.1f}s ({sample_rate}Hz, {channels}ch, {sample_width} bytes/sample)" ) - # Create conversation immediately for uploaded files (conversation_id auto-generated) - version_id = str(uuid.uuid4()) - # Generate title from filename - title = file.filename.rsplit('.', 1)[0][:50] if file.filename else "Uploaded Audio" + title = filename.rsplit('.', 1)[0][:50] if filename != "unknown" else "Uploaded Audio" conversation = create_conversation( user_id=user.user_id, @@ -154,7 +171,7 @@ async def upload_and_process_audio_files( # Handle validation errors (e.g., file too long) audio_logger.error(f"Audio validation failed: {val_error}") processed_files.append({ - "filename": file.filename, + "filename": filename, "status": "error", "error": str(val_error), }) @@ -167,7 +184,7 @@ async def upload_and_process_audio_files( exc_info=True ) processed_files.append({ - "filename": file.filename, + "filename": filename, "status": "error", "error": f"Audio conversion failed: {str(chunk_error)}", }) @@ -187,7 +204,7 @@ async def upload_and_process_audio_files( conversation_id, version_id, "batch", # trigger - job_timeout=1800, # 30 minutes + job_timeout=900, # 15 minutes result_ttl=JOB_RESULT_TTL, job_id=transcribe_job_id, description=f"Transcribe uploaded file {conversation_id[:8]}", @@ -209,15 +226,18 @@ async def upload_and_process_audio_files( client_id=client_id # Pass client_id for UI tracking ) - processed_files.append({ - "filename": file.filename, + file_result = { + "filename": filename, "status": "started", # RQ standard: job has been enqueued "conversation_id": conversation_id, "transcript_job_id": transcription_job.id if transcription_job else None, "speaker_job_id": job_ids['speaker_recognition'], "memory_job_id": job_ids['memory'], "duration_seconds": round(duration, 2), - }) + } + if is_video_source: + file_result["note"] = "Audio extracted from video file" + processed_files.append(file_result) # Build job chain description job_chain = [] @@ -229,23 +249,23 @@ async def upload_and_process_audio_files( job_chain.append(job_ids['memory']) audio_logger.info( - f"✅ Processed {file.filename} → conversation {conversation_id}, " + f"✅ Processed {filename} → conversation {conversation_id}, " f"jobs: {' → '.join(job_chain) if job_chain else 'none'}" ) except (OSError, IOError) as e: # File I/O errors during audio processing - audio_logger.exception(f"File I/O error processing {file.filename}") + audio_logger.exception(f"File I/O error processing {filename}") processed_files.append({ - "filename": file.filename or "unknown", + "filename": filename, "status": "error", "error": str(e), }) except Exception as e: # Unexpected errors during file processing - 
audio_logger.exception(f"Unexpected error processing file {file.filename}") + audio_logger.exception(f"Unexpected error processing file {filename}") processed_files.append({ - "filename": file.filename or "unknown", + "filename": filename, "status": "error", "error": str(e), }) diff --git a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py index f327a545..13f2620d 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py @@ -3,16 +3,18 @@ """ import logging +import os import time import uuid from datetime import datetime from pathlib import Path +import redis.asyncio as aioredis from fastapi.responses import JSONResponse from advanced_omi_backend.client_manager import ( - ClientManager, client_belongs_to_user, + get_client_manager, ) from advanced_omi_backend.config_loader import get_service_config from advanced_omi_backend.controllers.queue_controller import ( @@ -21,9 +23,13 @@ memory_queue, transcription_queue, ) +from advanced_omi_backend.controllers.session_controller import ( + request_conversation_close, +) from advanced_omi_backend.models.audio_chunk import AudioChunkDocument from advanced_omi_backend.models.conversation import Conversation from advanced_omi_backend.models.job import JobPriority +from advanced_omi_backend.plugins.events import ConversationCloseReason from advanced_omi_backend.users import User from advanced_omi_backend.workers.conversation_jobs import generate_title_summary_job from advanced_omi_backend.workers.memory_jobs import ( @@ -36,8 +42,12 @@ audio_logger = logging.getLogger("audio_processing") -async def close_current_conversation(client_id: str, user: User, client_manager: ClientManager): - """Close the current conversation for a specific client. Users can only close their own conversations.""" +async def close_current_conversation(client_id: str, user: User): + """Close the current conversation for a specific client. + + Signals the open_conversation_job to close the current conversation + and trigger post-processing. The session stays active for new conversations. 
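+
+    See request_conversation_close() in session_controller for the Redis
+    signal used under the hood.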
+ """ # Validate client ownership if not user.is_superuser and not client_belongs_to_user(client_id, user.user_id): logger.warning( @@ -51,50 +61,47 @@ async def close_current_conversation(client_id: str, user: User, client_manager: status_code=403, ) - if not client_manager.has_client(client_id): - return JSONResponse( - content={"error": f"Client '{client_id}' not found or not connected"}, - status_code=404, - ) - + client_manager = get_client_manager() client_state = client_manager.get_client(client_id) - if client_state is None: + if client_state is None or not client_state.connected: return JSONResponse( content={"error": f"Client '{client_id}' not found or not connected"}, status_code=404, ) - if not client_state.connected: + session_id = getattr(client_state, 'stream_session_id', None) + if not session_id: return JSONResponse( - content={"error": f"Client '{client_id}' is not connected"}, status_code=400 + content={"error": "No active session"}, + status_code=400, ) + # Signal the conversation job to close and trigger post-processing + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + r = aioredis.from_url(redis_url) try: - # Close the current conversation - await client_state.close_current_conversation() - - # Reset conversation state but keep client connected - client_state.current_audio_uuid = None - client_state.conversation_start_time = time.time() - client_state.last_transcript_time = None - - logger.info(f"Manually closed conversation for client {client_id} by user {user.id}") - - return JSONResponse( - content={ - "message": f"Successfully closed current conversation for client '{client_id}'", - "client_id": client_id, - "timestamp": int(time.time()), - } + success = await request_conversation_close( + r, session_id, reason=ConversationCloseReason.USER_REQUESTED.value ) + finally: + await r.aclose() - except Exception as e: - logger.error(f"Error closing conversation for client {client_id}: {e}") + if not success: return JSONResponse( - content={"error": f"Failed to close conversation: {str(e)}"}, - status_code=500, + content={"error": "Session not found in Redis"}, + status_code=404, ) + logger.info(f"Conversation close requested for client {client_id} by user {user.user_id}") + + return JSONResponse( + content={ + "message": f"Conversation close requested for client '{client_id}'", + "client_id": client_id, + "timestamp": int(time.time()), + } + ) + async def get_conversation(conversation_id: str, user: User): """Get a single conversation with full transcript details.""" @@ -150,40 +157,85 @@ async def get_conversation(conversation_id: str, user: User): return JSONResponse(status_code=500, content={"error": "Error fetching conversation"}) -async def get_conversations(user: User, include_deleted: bool = False): - """Get conversations with speech only (speech-driven architecture).""" +async def get_conversations( + user: User, + include_deleted: bool = False, + include_unprocessed: bool = False, + limit: int = 200, + offset: int = 0, +): + """Get conversations with speech only (speech-driven architecture). + + Uses a single consolidated query with ``$or`` when ``include_unprocessed`` + is True, eliminating multiple round-trips and Python-side merge/sort. + Results are paginated with ``limit``/``offset``. 
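+
+    Returns a dict with ``conversations`` (the current page), ``total``,
+    ``limit`` and ``offset``.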
+ """ try: - # Build query based on user permissions using Beanie - if not user.is_superuser: - # Regular users can only see their own conversations - # Filter by deleted status - if not include_deleted: - user_conversations = ( - await Conversation.find( - Conversation.user_id == str(user.user_id), Conversation.deleted == False - ) - .sort(-Conversation.created_at) - .to_list() - ) - else: - user_conversations = ( - await Conversation.find(Conversation.user_id == str(user.user_id)) - .sort(-Conversation.created_at) - .to_list() - ) + user_filter = {} if user.is_superuser else {"user_id": str(user.user_id)} + + # Build query conditions — single $or when orphans are requested + conditions = [] + + # Condition 1: normal (non-deleted or all) conversations + if include_deleted: + conditions.append({}) # no filter on deleted else: - # Admins see all conversations - # Filter by deleted status - if not include_deleted: - user_conversations = ( - await Conversation.find(Conversation.deleted == False) - .sort(-Conversation.created_at) - .to_list() + conditions.append({"deleted": False}) + + if include_unprocessed: + # Orphan type 1: always_persist stuck in pending/failed (not deleted) + conditions.append({ + "always_persist": True, + "processing_status": {"$in": ["pending_transcription", "transcription_failed"]}, + "deleted": False, + }) + # Orphan type 2: soft-deleted due to no speech but have audio data + conditions.append({ + "deleted": True, + "deletion_reason": {"$in": [ + "no_meaningful_speech", + "audio_file_not_ready", + "no_meaningful_speech_batch_transcription", + ]}, + "audio_chunks_count": {"$gt": 0}, + }) + + # Assemble final query + if len(conditions) == 1: + query = {**user_filter, **conditions[0]} + else: + query = {**user_filter, "$or": conditions} + + total = await Conversation.find(query).count() + + user_conversations = ( + await Conversation.find(query) + .sort(-Conversation.created_at) + .skip(offset) + .limit(limit) + .to_list() + ) + + # Mark orphans in results (lightweight in-memory check on the page) + orphan_ids: set = set() + if include_unprocessed: + for conv in user_conversations: + is_orphan_type1 = ( + conv.always_persist + and conv.processing_status in ("pending_transcription", "transcription_failed") + and not conv.deleted ) - else: - user_conversations = ( - await Conversation.find_all().sort(-Conversation.created_at).to_list() + is_orphan_type2 = ( + conv.deleted + and conv.deletion_reason in ( + "no_meaningful_speech", + "audio_file_not_ready", + "no_meaningful_speech_batch_transcription", + ) + and (conv.audio_chunks_count or 0) > 0 ) + if is_orphan_type1 or is_orphan_type2: + orphan_ids.add(conv.conversation_id) # Build response with explicit curated fields - minimal for list view conversations = [] @@ -215,10 +267,16 @@ async def get_conversations(user: User, include_deleted: bool = False): "memory_version_count": conv.memory_version_count, "active_transcript_version_number": conv.active_transcript_version_number, "active_memory_version_number": conv.active_memory_version_number, + "is_orphan": conv.conversation_id in orphan_ids, } ) - return {"conversations": conversations} + return { + "conversations": conversations, + "total": total, + "limit": limit, + "offset": offset, + } except Exception as e: logger.exception(f"Error fetching conversations: {e}") @@ -440,6 +498,134 @@ async def restore_conversation(conversation_id: str, user: User) -> JSONResponse ) +async def reprocess_orphan(conversation_id: str, user: User): + """Reprocess an orphan audio session - 
restore if deleted and enqueue full processing chain.""" + try: + conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + if not conversation: + return JSONResponse(status_code=404, content={"error": "Conversation not found"}) + + # Check ownership + if not user.is_superuser and conversation.user_id != str(user.user_id): + return JSONResponse(status_code=403, content={"error": "Access forbidden"}) + + # Verify audio chunks exist (check both deleted and non-deleted) + total_chunks = await AudioChunkDocument.find( + AudioChunkDocument.conversation_id == conversation_id + ).count() + + if total_chunks == 0: + return JSONResponse( + status_code=400, + content={"error": "No audio data found for this conversation"}, + ) + + # If conversation is soft-deleted, restore it and its chunks + if conversation.deleted: + await AudioChunkDocument.find( + AudioChunkDocument.conversation_id == conversation_id, + AudioChunkDocument.deleted == True, + ).update_many({"$set": {"deleted": False, "deleted_at": None}}) + + conversation.deleted = False + conversation.deletion_reason = None + conversation.deleted_at = None + + # Set processing status and update title + conversation.processing_status = "reprocessing" + conversation.title = "Reprocessing..." + conversation.summary = None + conversation.detailed_summary = None + await conversation.save() + + # Create new transcript version ID + version_id = str(uuid.uuid4()) + + # Enqueue the same 4-job chain as reprocess_transcript + from advanced_omi_backend.workers.transcription_jobs import ( + transcribe_full_audio_job, + ) + + # Job 1: Transcribe audio + transcript_job = transcription_queue.enqueue( + transcribe_full_audio_job, + conversation_id, + version_id, + "reprocess_orphan", + job_timeout=900, + result_ttl=JOB_RESULT_TTL, + job_id=f"orphan_transcribe_{conversation_id[:8]}", + description=f"Transcribe orphan audio for {conversation_id[:8]}", + meta={"conversation_id": conversation_id}, + ) + + # Job 2: Speaker recognition (conditional) + speaker_config = get_service_config("speaker_recognition") + speaker_enabled = speaker_config.get("enabled", True) + speaker_dependency = transcript_job + speaker_job = None + + if speaker_enabled: + speaker_job = transcription_queue.enqueue( + recognise_speakers_job, + conversation_id, + version_id, + depends_on=transcript_job, + job_timeout=600, + result_ttl=JOB_RESULT_TTL, + job_id=f"orphan_speaker_{conversation_id[:8]}", + description=f"Recognize speakers for orphan {conversation_id[:8]}", + meta={"conversation_id": conversation_id}, + ) + speaker_dependency = speaker_job + + # Job 3: Extract memories + memory_job = memory_queue.enqueue( + process_memory_job, + conversation_id, + depends_on=speaker_dependency, + job_timeout=1800, + result_ttl=JOB_RESULT_TTL, + job_id=f"orphan_memory_{conversation_id[:8]}", + description=f"Extract memories for orphan {conversation_id[:8]}", + meta={"conversation_id": conversation_id}, + ) + + # Job 4: Generate title/summary + title_summary_job = default_queue.enqueue( + generate_title_summary_job, + conversation_id, + job_timeout=300, + result_ttl=JOB_RESULT_TTL, + depends_on=memory_job, + job_id=f"orphan_title_{conversation_id[:8]}", + description=f"Generate title/summary for orphan {conversation_id[:8]}", + meta={"conversation_id": conversation_id, "trigger": "reprocess_orphan"}, + ) + + logger.info( + f"Enqueued orphan reprocessing chain for {conversation_id}: " + f"transcribe={transcript_job.id} → speaker={'skipped' if not speaker_job else 
speaker_job.id} " + f"→ memory={memory_job.id} → title={title_summary_job.id}" + ) + + return JSONResponse( + content={ + "message": f"Orphan reprocessing started for conversation {conversation_id}", + "job_id": transcript_job.id, + "title_summary_job_id": title_summary_job.id, + "version_id": version_id, + "status": "queued", + } + ) + + except Exception as e: + logger.error(f"Error starting orphan reprocessing for {conversation_id}: {e}") + return JSONResponse( + status_code=500, content={"error": "Error starting orphan reprocessing"} + ) + + async def reprocess_transcript(conversation_id: str, user: User): """Reprocess transcript for a conversation. Users can only reprocess their own conversations.""" try: @@ -488,7 +674,7 @@ async def reprocess_transcript(conversation_id: str, user: User): conversation_id, version_id, "reprocess", - job_timeout=600, + job_timeout=900, # 15 minutes result_ttl=JOB_RESULT_TTL, job_id=f"reprocess_{conversation_id[:8]}", description=f"Transcribe audio for {conversation_id[:8]}", @@ -722,14 +908,24 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u provider_capabilities.get("diarization", False) or source_version.diarization_source == "provider" ) + has_words = bool(source_version.words) + has_segments = bool(source_version.segments) - if not source_version.words and not (provider_has_diarization and source_version.segments): + if not has_words and not has_segments: return JSONResponse( status_code=400, content={ - "error": "Cannot re-diarize transcript without word timings. Words are required for diarization." + "error": ( + "Cannot re-diarize transcript without word timings or segments. " + "Word timestamps or provider segments are required." + ) }, ) + if not has_words and has_segments and not provider_has_diarization: + logger.warning( + "Reprocessing speakers without word timings; " + "falling back to segment-based identification only." + ) # 5. 
Check if speaker recognition is enabled speaker_config = get_service_config("speaker_recognition") @@ -752,10 +948,13 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u "reprocessing_type": "speaker_diarization", "source_version_id": source_version_id, "trigger": "manual_reprocess", + "provider_capabilities": provider_capabilities, } - if provider_has_diarization: + use_segments = provider_has_diarization or not has_words + if use_segments: new_segments = source_version.segments # COPY provider segments - new_metadata["provider_capabilities"] = provider_capabilities + if not has_words and not provider_has_diarization: + new_metadata["segments_only"] = True else: new_segments = [] # Empty - will be populated by speaker job @@ -772,7 +971,7 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u ) # Carry over diarization_source so speaker job knows to use segment identification - if provider_has_diarization: + if provider_has_diarization or (not has_words and has_segments): new_version.diarization_source = "provider" # Save conversation with new version diff --git a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py index 7d7d5f2e..9b3a2de9 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py @@ -65,6 +65,37 @@ async def mark_session_complete( logger.info(f"✅ Session {session_id[:12]} marked finished: {reason} [TIME: {mark_time:.3f}]") +async def request_conversation_close( + redis_client, + session_id: str, + reason: str = "user_requested", +) -> bool: + """ + Request closing the current conversation without killing the session. + + Unlike mark_session_complete() which finalizes the entire session, + this signals open_conversation_job to close just the current conversation + and trigger post-processing. The session stays active for new conversations. + + Sets 'conversation_close_requested' field on the session hash. + The open_conversation_job checks this field every poll iteration. + + Args: + redis_client: Redis async client + session_id: Session UUID + reason: Why the conversation is being closed + + Returns: + True if the close request was set, False if session not found + """ + session_key = f"audio:session:{session_id}" + if not await redis_client.exists(session_key): + return False + await redis_client.hset(session_key, "conversation_close_requested", reason) + logger.info(f"🔒 Conversation close requested for session {session_id[:12]}: {reason}") + return True + + async def get_session_info(redis_client, session_id: str) -> Optional[Dict]: """ Get detailed information about a specific session. 
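
For orientation, a minimal sketch of the consumer side of this handshake: per the docstring above, `open_conversation_job` checks the `conversation_close_requested` field on each poll iteration. The key layout and field name come from `request_conversation_close()`; the helper below is illustrative, not the actual job code, and clearing the flag after acting on it is an assumption.

```python
from typing import Optional

import redis.asyncio as aioredis


async def check_close_request(redis_client: aioredis.Redis, session_id: str) -> Optional[str]:
    """Return the requested close reason, if any, and clear the flag.

    Assumes a client created with decode_responses=True so hget returns str.
    """
    session_key = f"audio:session:{session_id}"
    reason = await redis_client.hget(session_key, "conversation_close_requested")
    if reason is None:
        return None
    # Clearing the field here is assumed; the real job may acknowledge differently.
    await redis_client.hdel(session_key, "conversation_close_requested")
    return reason
```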
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py index 53e8ff95..c4794a40 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py @@ -353,7 +353,7 @@ async def save_misc_settings_controller(settings: dict): """Save miscellaneous settings.""" try: # Validate settings - valid_keys = {"always_persist_enabled", "use_provider_segments"} + valid_keys = {"always_persist_enabled", "use_provider_segments", "per_segment_speaker_id"} # Filter to only valid keys filtered_settings = {} @@ -1102,8 +1102,7 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict: Success message with list of updated files """ try: - import advanced_omi_backend.plugins - from advanced_omi_backend.services.plugin_service import discover_plugins + from advanced_omi_backend.services.plugin_service import _get_plugins_dir, discover_plugins # Validate plugin exists discovered_plugins = discover_plugins() @@ -1151,7 +1150,7 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict: # 2. Update plugins/{plugin_id}/config.yml (settings with env var references) if 'settings' in config: - plugins_dir = Path(advanced_omi_backend.plugins.__file__).parent + plugins_dir = _get_plugins_dir() plugin_config_path = plugins_dir / plugin_id / "config.yml" # Load current config.yml diff --git a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py index fcf80de4..0dc09396 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py @@ -528,6 +528,14 @@ async def _finalize_streaming_session( # Mark session as finalizing with user_stopped reason (audio-stop event) await audio_stream_producer.finalize_session(session_id, completion_reason="user_stopped") + # Store markers in Redis so open_conversation_job can persist them + if client_state.markers: + session_key = f"audio:session:{session_id}" + await audio_stream_producer.redis_client.hset( + session_key, "markers", json.dumps(client_state.markers) + ) + client_state.markers.clear() + # NOTE: Finalize job disabled - open_conversation_job now handles everything # The open_conversation_job will: # 1. Detect the "finalizing" status @@ -945,6 +953,75 @@ async def _handle_audio_session_stop( return False # Switch back to control mode +async def _handle_button_event( + client_state, + button_state: str, + user_id: str, + client_id: str, +) -> None: + """Handle a button event from the device. + + Stores a marker on the client state and dispatches granular events + to the plugin system using typed enums. 
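+
+    Stored markers are later copied onto the conversation document (see
+    _process_batch_audio_complete and _finalize_streaming_session).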
+ + Args: + client_state: Client state object + button_state: Button state string (e.g., "SINGLE_TAP", "DOUBLE_TAP") + user_id: User ID + client_id: Client ID + """ + from advanced_omi_backend.plugins.events import ( + BUTTON_STATE_TO_EVENT, + ButtonState, + ) + from advanced_omi_backend.services.plugin_service import get_plugin_router + + timestamp = time.time() + audio_uuid = client_state.current_audio_uuid + + application_logger.info( + f"🔘 Button event from {client_id}: {button_state} " + f"(audio_uuid={audio_uuid})" + ) + + # Store marker on client state for later persistence to conversation + marker = { + "type": "button_event", + "state": button_state, + "timestamp": timestamp, + "audio_uuid": audio_uuid, + "client_id": client_id, + } + client_state.add_marker(marker) + + # Map device button state to typed plugin event + try: + button_state_enum = ButtonState(button_state) + except ValueError: + application_logger.warning(f"Unknown button state: {button_state}") + return + + event = BUTTON_STATE_TO_EVENT.get(button_state_enum) + if not event: + application_logger.debug(f"No plugin event mapped for {button_state_enum}") + return + + # Dispatch granular event to plugin system + router = get_plugin_router() + if router: + await router.dispatch_event( + event=event.value, + user_id=user_id, + data={ + "state": button_state_enum.value, + "timestamp": timestamp, + "audio_uuid": audio_uuid, + "session_id": getattr(client_state, 'stream_session_id', None), + "client_id": client_id, + }, + ) + + async def _process_rolling_batch( client_state, user_id: str, @@ -1021,7 +1098,7 @@ async def _process_rolling_batch( conversation_id, version_id, f"rolling_batch_{batch_number}", # trigger - job_timeout=1800, # 30 minutes + job_timeout=900, # 15 minutes result_ttl=JOB_RESULT_TTL, job_id=transcribe_job_id, description=f"Transcribe rolling batch #{batch_number} {conversation_id[:8]}", @@ -1094,6 +1171,10 @@ async def _process_batch_audio_complete( title="Batch Recording", summary="Processing batch audio..." 
) + # Attach any markers (e.g., button events) captured during the session + if client_state.markers: + conversation.markers = list(client_state.markers) + client_state.markers.clear() await conversation.insert() conversation_id = conversation.conversation_id # Get the auto-generated ID @@ -1137,7 +1218,7 @@ async def _process_batch_audio_complete( conversation_id, version_id, "batch", # trigger - job_timeout=1800, # 30 minutes + job_timeout=900, # 15 minutes result_ttl=JOB_RESULT_TTL, job_id=transcribe_job_id, description=f"Transcribe batch audio {conversation_id[:8]}", @@ -1385,7 +1466,15 @@ async def handle_pcm_websocket( # Handle keepalive ping from frontend application_logger.debug(f"🏓 Received ping from {client_id}") continue - + + elif header["type"] == "button-event": + button_data = header.get("data", {}) + button_state = button_data.get("state", "unknown") + await _handle_button_event( + client_state, button_state, user.user_id, client_id + ) + continue + else: # Unknown control message type application_logger.debug( @@ -1466,10 +1555,17 @@ async def handle_pcm_websocket( else: application_logger.warning(f"audio-chunk missing payload_length: {payload_length}") continue + elif control_header.get("type") == "button-event": + button_data = control_header.get("data", {}) + button_state = button_data.get("state", "unknown") + await _handle_button_event( + client_state, button_state, user.user_id, client_id + ) + continue else: application_logger.warning(f"Unknown control message during streaming: {control_header.get('type')}") continue - + except json.JSONDecodeError: application_logger.warning(f"Invalid control message during streaming for {client_id}") continue diff --git a/backends/advanced/src/advanced_omi_backend/cron_scheduler.py b/backends/advanced/src/advanced_omi_backend/cron_scheduler.py new file mode 100644 index 00000000..a496516f --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/cron_scheduler.py @@ -0,0 +1,285 @@ +""" +Config-driven asyncio cron scheduler for Chronicle. + +Reads job definitions from config.yml ``cron_jobs`` section, uses ``croniter`` +to compute next-run times, and dispatches registered job functions. State +(last_run / next_run) is persisted in Redis so it survives restarts. 
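+
+Jobs are declared under ``cron_jobs`` in config.yml. A minimal sketch of the
+expected shape (ids must match those passed to ``register_cron_job``; the
+schedule value is illustrative)::
+
+    cron_jobs:
+      speaker_finetuning:
+        enabled: true
+        schedule: "0 3 * * *"
+        description: "Nightly speaker model finetuning"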
+ +Usage: + scheduler = get_scheduler() + await scheduler.start() # call during FastAPI lifespan startup + await scheduler.stop() # call during shutdown +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Callable, Coroutine, Dict, List, Optional + +import redis.asyncio as aioredis +from croniter import croniter + +from advanced_omi_backend.config_loader import load_config, save_config_section + +logger = logging.getLogger(__name__) + +# Redis key prefixes +_LAST_RUN_KEY = "cron:last_run:{job_id}" +_NEXT_RUN_KEY = "cron:next_run:{job_id}" + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + +@dataclass +class CronJobConfig: + job_id: str + enabled: bool + schedule: str + description: str + next_run: Optional[datetime] = None + last_run: Optional[datetime] = None + running: bool = False + last_error: Optional[str] = None + + +# --------------------------------------------------------------------------- +# Job registry – maps job_id → async callable +# --------------------------------------------------------------------------- + +JobFunc = Callable[[], Coroutine[Any, Any, dict]] + +_JOB_REGISTRY: Dict[str, JobFunc] = {} + + +def register_cron_job(job_id: str, func: JobFunc) -> None: + """Register a job function so the scheduler can dispatch it.""" + _JOB_REGISTRY[job_id] = func + + +def _get_job_func(job_id: str) -> Optional[JobFunc]: + return _JOB_REGISTRY.get(job_id) + + +# --------------------------------------------------------------------------- +# Scheduler +# --------------------------------------------------------------------------- + +class CronScheduler: + def __init__(self) -> None: + self.jobs: Dict[str, CronJobConfig] = {} + self._running = False + self._task: Optional[asyncio.Task] = None + self._redis: Optional[aioredis.Redis] = None + self._active_tasks: set[asyncio.Task] = set() + + # -- lifecycle ----------------------------------------------------------- + + async def start(self) -> None: + """Load config, restore state from Redis, and start the scheduler loop.""" + import os + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + self._redis = aioredis.from_url(redis_url, decode_responses=True) + + self._load_jobs_from_config() + await self._restore_state() + + self._running = True + self._task = asyncio.create_task(self._loop()) + logger.info(f"Cron scheduler started with {len(self.jobs)} jobs") + + async def stop(self) -> None: + """Cancel the scheduler loop and close Redis.""" + self._running = False + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + if self._redis: + await self._redis.close() + logger.info("Cron scheduler stopped") + + # -- public API ---------------------------------------------------------- + + async def run_job_now(self, job_id: str) -> dict: + """Manually trigger a job regardless of schedule.""" + if job_id not in self.jobs: + raise ValueError(f"Unknown cron job: {job_id}") + if self.jobs[job_id].running: + return {"error": f"Job '{job_id}' is already running"} + return await self._execute_job(job_id) + + async def update_job( + self, + job_id: str, + enabled: Optional[bool] = None, + schedule: Optional[str] = None, + ) -> None: + """Update a job's config and persist to config.yml.""" + if job_id not in self.jobs: + raise 
ValueError(f"Unknown cron job: {job_id}") + + cfg = self.jobs[job_id] + + if schedule is not None: + # Validate cron expression + if not croniter.is_valid(schedule): + raise ValueError(f"Invalid cron expression: {schedule}") + cfg.schedule = schedule + cfg.next_run = croniter(schedule, datetime.now(timezone.utc)).get_next(datetime) + + if enabled is not None: + cfg.enabled = enabled + + # Persist changes to config.yml + save_config_section( + f"cron_jobs.{job_id}", + {"enabled": cfg.enabled, "schedule": cfg.schedule, "description": cfg.description}, + ) + + # Update next_run in Redis + if self._redis and cfg.next_run: + await self._redis.set( + _NEXT_RUN_KEY.format(job_id=job_id), + cfg.next_run.isoformat(), + ) + + logger.info(f"Updated cron job '{job_id}': enabled={cfg.enabled}, schedule={cfg.schedule}") + + async def get_all_jobs_status(self) -> List[dict]: + """Return status of all registered cron jobs.""" + result = [] + for job_id, cfg in self.jobs.items(): + result.append({ + "job_id": job_id, + "enabled": cfg.enabled, + "schedule": cfg.schedule, + "description": cfg.description, + "last_run": cfg.last_run.isoformat() if cfg.last_run else None, + "next_run": cfg.next_run.isoformat() if cfg.next_run else None, + "running": cfg.running, + "last_error": cfg.last_error, + }) + return result + + # -- internals ----------------------------------------------------------- + + def _load_jobs_from_config(self) -> None: + """Read cron_jobs section from config.yml.""" + cfg = load_config() + cron_section = cfg.get("cron_jobs", {}) + + for job_id, job_cfg in cron_section.items(): + schedule = str(job_cfg.get("schedule", "0 * * * *")) + if not croniter.is_valid(schedule): + logger.warning(f"Invalid cron expression for job '{job_id}': {schedule} — skipping") + continue + now = datetime.now(timezone.utc) + self.jobs[job_id] = CronJobConfig( + job_id=job_id, + enabled=bool(job_cfg.get("enabled", False)), + schedule=schedule, + description=str(job_cfg.get("description", "")), + next_run=croniter(schedule, now).get_next(datetime), + ) + + async def _restore_state(self) -> None: + """Restore last_run / next_run from Redis.""" + if not self._redis: + return + for job_id, cfg in self.jobs.items(): + try: + lr = await self._redis.get(_LAST_RUN_KEY.format(job_id=job_id)) + if lr: + cfg.last_run = datetime.fromisoformat(lr) + nr = await self._redis.get(_NEXT_RUN_KEY.format(job_id=job_id)) + if nr: + cfg.next_run = datetime.fromisoformat(nr) + except Exception as e: + logger.warning(f"Failed to restore state for job '{job_id}': {e}") + + async def _persist_state(self, job_id: str) -> None: + """Write last_run / next_run to Redis.""" + if not self._redis: + return + cfg = self.jobs[job_id] + try: + if cfg.last_run: + await self._redis.set( + _LAST_RUN_KEY.format(job_id=job_id), + cfg.last_run.isoformat(), + ) + if cfg.next_run: + await self._redis.set( + _NEXT_RUN_KEY.format(job_id=job_id), + cfg.next_run.isoformat(), + ) + except Exception as e: + logger.warning(f"Failed to persist state for job '{job_id}': {e}") + + async def _execute_job(self, job_id: str) -> dict: + """Run the job function and update state.""" + cfg = self.jobs[job_id] + func = _get_job_func(job_id) + if func is None: + msg = f"No function registered for cron job '{job_id}'" + logger.error(msg) + cfg.last_error = msg + return {"error": msg} + + cfg.running = True + cfg.last_error = None + now = datetime.now(timezone.utc) + logger.info(f"Executing cron job '{job_id}'") + + try: + result = await func() + cfg.last_run = now + cfg.next_run = 
croniter(cfg.schedule, now).get_next(datetime) + await self._persist_state(job_id) + logger.info(f"Cron job '{job_id}' completed: {result}") + return result or {} + except Exception as e: + cfg.last_error = str(e) + logger.error(f"Cron job '{job_id}' failed: {e}", exc_info=True) + # Still advance next_run so we don't spin on failures + cfg.last_run = now + cfg.next_run = croniter(cfg.schedule, now).get_next(datetime) + await self._persist_state(job_id) + return {"error": str(e)} + finally: + cfg.running = False + + async def _loop(self) -> None: + """Main scheduler loop – checks every 30s for due jobs.""" + while self._running: + try: + now = datetime.now(timezone.utc) + for job_id, cfg in self.jobs.items(): + if not cfg.enabled or cfg.running: + continue + if cfg.next_run and now >= cfg.next_run: + task = asyncio.create_task(self._execute_job(job_id)) + self._active_tasks.add(task) + task.add_done_callback(self._active_tasks.discard) + except Exception as e: + logger.error(f"Error in cron scheduler loop: {e}", exc_info=True) + await asyncio.sleep(30) + + +# --------------------------------------------------------------------------- +# Singleton +# --------------------------------------------------------------------------- + +_scheduler: Optional[CronScheduler] = None + + +def get_scheduler() -> CronScheduler: + """Get (or create) the global CronScheduler singleton.""" + global _scheduler + if _scheduler is None: + _scheduler = CronScheduler() + return _scheduler diff --git a/backends/advanced/src/advanced_omi_backend/models/annotation.py b/backends/advanced/src/advanced_omi_backend/models/annotation.py index ac8ceefe..8eecf81a 100644 --- a/backends/advanced/src/advanced_omi_backend/models/annotation.py +++ b/backends/advanced/src/advanced_omi_backend/models/annotation.py @@ -19,6 +19,7 @@ class AnnotationType(str, Enum): MEMORY = "memory" TRANSCRIPT = "transcript" DIARIZATION = "diarization" # Speaker identification corrections + ENTITY = "entity" # Knowledge graph entity corrections (name/details edits) class AnnotationSource(str, Enum): @@ -70,6 +71,12 @@ class Annotation(Document): corrected_speaker: Optional[str] = None # Speaker label after correction segment_start_time: Optional[float] = None # Time offset for reference + # For ENTITY annotations: + # Dual purpose: feeds both the jargon pipeline (entity name corrections = domain vocabulary + # the ASR should know) and the entity extraction pipeline (corrections improve future accuracy). 
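A sketch of how a downstream consumer might drain these annotations; the jargon-cron step itself is hypothetical, but the query style mirrors the annotation routes later in this diff, and the fields used are defined just below:

```python
from advanced_omi_backend.models.annotation import Annotation, AnnotationType


async def collect_entity_vocabulary() -> list[str]:
    """Hypothetical jargon-cron step: gather unprocessed entity name corrections."""
    annotations = await Annotation.find(
        Annotation.annotation_type == AnnotationType.ENTITY,
        Annotation.processed == False,  # noqa: E712 (Beanie query operator)
    ).to_list()
    # Corrected names double as ASR hot words; details edits feed extraction tuning.
    return [a.corrected_text for a in annotations if a.entity_field == "name"]
```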
+ entity_id: Optional[str] = None # Neo4j entity ID + entity_field: Optional[str] = None # Which field was changed ("name" or "details") + # Processed tracking (applies to ALL annotation types) processed: bool = Field(default=False) # Whether annotation has been applied/sent to training processed_at: Optional[datetime] = None # When annotation was processed @@ -88,11 +95,12 @@ class Settings: # Create indexes on commonly queried fields # Note: Enum fields and Optional fields don't use Indexed() wrapper indexes = [ - "annotation_type", # Query by type (memory vs transcript vs diarization) + "annotation_type", # Query by type (memory vs transcript vs diarization vs entity) "user_id", # User-scoped queries "status", # Filter by status (pending/accepted/rejected) "memory_id", # Lookup annotations for specific memory "conversation_id", # Lookup annotations for specific conversation + "entity_id", # Lookup annotations for specific entity "processed", # Query unprocessed annotations ] @@ -108,6 +116,10 @@ def is_diarization_annotation(self) -> bool: """Check if this is a diarization annotation.""" return self.annotation_type == AnnotationType.DIARIZATION + def is_entity_annotation(self) -> bool: + """Check if this is an entity annotation.""" + return self.annotation_type == AnnotationType.ENTITY + def is_pending_suggestion(self) -> bool: """Check if this is a pending AI suggestion.""" return ( @@ -151,6 +163,18 @@ class DiarizationAnnotationCreate(BaseModel): status: AnnotationStatus = AnnotationStatus.ACCEPTED +class EntityAnnotationCreate(BaseModel): + """Create entity annotation request. + + Dual purpose: feeds both the jargon pipeline (entity name corrections = domain vocabulary + the ASR should know) and the entity extraction pipeline (corrections improve future accuracy). + """ + entity_id: str + entity_field: str # "name" or "details" + original_text: str + corrected_text: str + + class AnnotationResponse(BaseModel): """Annotation response for API.""" id: str @@ -164,6 +188,8 @@ class AnnotationResponse(BaseModel): original_speaker: Optional[str] = None corrected_speaker: Optional[str] = None segment_start_time: Optional[float] = None + entity_id: Optional[str] = None + entity_field: Optional[str] = None processed: bool = False processed_at: Optional[datetime] = None processed_by: Optional[str] = None diff --git a/backends/advanced/src/advanced_omi_backend/models/conversation.py b/backends/advanced/src/advanced_omi_backend/models/conversation.py index 2ec45f33..23fae946 100644 --- a/backends/advanced/src/advanced_omi_backend/models/conversation.py +++ b/backends/advanced/src/advanced_omi_backend/models/conversation.py @@ -39,6 +39,7 @@ class EndReason(str, Enum): INACTIVITY_TIMEOUT = "inactivity_timeout" # No speech detected for threshold period WEBSOCKET_DISCONNECT = "websocket_disconnect" # Connection lost (Bluetooth, network, etc.) 
MAX_DURATION = "max_duration" # Hit maximum conversation duration + CLOSE_REQUESTED = "close_requested" # External close signal (API, plugin, button) ERROR = "error" # Processing error forced conversation end UNKNOWN = "unknown" # Unknown or legacy reason @@ -122,6 +123,12 @@ class MemoryVersion(BaseModel): description="Compression ratio (compressed_size / original_size), typically ~0.047 for Opus" ) + # Markers (e.g., button events) captured during the session + markers: List[Dict[str, Any]] = Field( + default_factory=list, + description="Markers captured during audio session (button events, bookmarks, etc.)" + ) + # Creation metadata created_at: Indexed(datetime) = Field(default_factory=datetime.utcnow, description="When the conversation was created") @@ -377,7 +384,7 @@ class Settings: "conversation_id", "user_id", "created_at", - [("user_id", 1), ("created_at", -1)], # Compound index for user queries + [("user_id", 1), ("deleted", 1), ("created_at", -1)], # Compound index for paginated list queries IndexModel([("external_source_id", 1)], sparse=True) # Sparse index for deduplication ] diff --git a/backends/advanced/src/advanced_omi_backend/plugins/__init__.py b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py index 3ccea7dc..90c47460 100644 --- a/backends/advanced/src/advanced_omi_backend/plugins/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py @@ -5,6 +5,8 @@ - transcript: When new transcript segment arrives - conversation: When conversation processing completes - memory: After memory extraction finishes +- button: When device button events are received +- plugin_action: Cross-plugin communication Trigger types control when plugins execute: - wake_word: Only when transcript starts with specified wake word @@ -13,6 +15,18 @@ """ from .base import BasePlugin, PluginContext, PluginResult +from .events import ButtonActionType, ButtonState, ConversationCloseReason, PluginEvent from .router import PluginRouter +from .services import PluginServices -__all__ = ['BasePlugin', 'PluginContext', 'PluginResult', 'PluginRouter'] +__all__ = [ + 'BasePlugin', + 'ButtonActionType', + 'ButtonState', + 'ConversationCloseReason', + 'PluginContext', + 'PluginEvent', + 'PluginResult', + 'PluginRouter', + 'PluginServices', +] diff --git a/backends/advanced/src/advanced_omi_backend/plugins/base.py b/backends/advanced/src/advanced_omi_backend/plugins/base.py index bb55128a..2bfe3609 100644 --- a/backends/advanced/src/advanced_omi_backend/plugins/base.py +++ b/backends/advanced/src/advanced_omi_backend/plugins/base.py @@ -18,6 +18,7 @@ class PluginContext: event: str # Event name (e.g., "transcript.streaming", "conversation.complete") data: Dict[str, Any] # Event-specific data metadata: Dict[str, Any] = field(default_factory=dict) + services: Optional[Any] = None # PluginServices instance for system/cross-plugin calls @dataclass @@ -56,24 +57,10 @@ def __init__(self, config: Dict[str, Any]): config: Plugin configuration from config/plugins.yml Contains: enabled, events, condition, and plugin-specific config """ - import logging - logger = logging.getLogger(__name__) - self.config = config self.enabled = config.get('enabled', False) - - # NEW terminology with backward compatibility - self.events = config.get('events') or config.get('subscriptions', []) - self.condition = config.get('condition') or config.get('trigger', {'type': 'always'}) - - # Deprecation warnings - plugin_name = config.get('name', 'unknown') - if 'subscriptions' in config: - 
logger.warning(f"Plugin '{plugin_name}': 'subscriptions' is deprecated, use 'events' instead") - if 'trigger' in config: - logger.warning(f"Plugin '{plugin_name}': 'condition' is deprecated, use 'condition' instead") - if 'access_level' in config: - logger.warning(f"Plugin '{plugin_name}': 'access_level' is deprecated and ignored") + self.events = config.get('events', []) + self.condition = config.get('condition', {'type': 'always'}) def register_prompts(self, registry) -> None: """Register plugin prompts with the prompt registry. @@ -154,3 +141,30 @@ async def on_memory_processed(self, context: PluginContext) -> Optional[PluginRe PluginResult with success status, optional message, and should_continue flag """ pass + + async def on_button_event(self, context: PluginContext) -> Optional[PluginResult]: + """ + Called when a device button event is received. + + Context data contains: + - state: str - Button state (e.g., "SINGLE_TAP", "DOUBLE_TAP", "LONG_PRESS") + - timestamp: float - Unix timestamp of the event + - audio_uuid: str - Current audio session UUID (may be None) + + Returns: + PluginResult with success status, optional message, and should_continue flag + """ + pass + + async def on_plugin_action(self, context: PluginContext) -> Optional[PluginResult]: + """ + Called when another plugin dispatches an action to this plugin via PluginServices.call_plugin(). + + Context data contains: + - action: str - Action name (e.g., "toggle_lights", "call_service") + - Plus any additional data from the calling plugin + + Returns: + PluginResult with success status, optional message, and should_continue flag + """ + pass diff --git a/backends/advanced/src/advanced_omi_backend/plugins/events.py b/backends/advanced/src/advanced_omi_backend/plugins/events.py new file mode 100644 index 00000000..210c8fd6 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/events.py @@ -0,0 +1,57 @@ +""" +Single source of truth for all plugin event types, button states, and action types. + +All event names, button states, and action types live here. No raw strings anywhere else. +Using str, Enum so values work directly as strings in Redis, YAML, JSON — but code +always references the enum member, never a raw string. 
+""" + +from enum import Enum +from typing import Dict + + +class PluginEvent(str, Enum): + """All events that can trigger plugins.""" + + # Conversation lifecycle + CONVERSATION_COMPLETE = "conversation.complete" + TRANSCRIPT_STREAMING = "transcript.streaming" + TRANSCRIPT_BATCH = "transcript.batch" + MEMORY_PROCESSED = "memory.processed" + + # Button events (from OMI device) + BUTTON_SINGLE_PRESS = "button.single_press" + BUTTON_DOUBLE_PRESS = "button.double_press" + + # Cross-plugin communication (dispatched by PluginServices.call_plugin) + PLUGIN_ACTION = "plugin_action" + + +class ButtonState(str, Enum): + """Raw button states from OMI device firmware.""" + + SINGLE_TAP = "SINGLE_TAP" + DOUBLE_TAP = "DOUBLE_TAP" + LONG_PRESS = "LONG_PRESS" + + +# Maps device button states to plugin events +BUTTON_STATE_TO_EVENT: Dict[ButtonState, PluginEvent] = { + ButtonState.SINGLE_TAP: PluginEvent.BUTTON_SINGLE_PRESS, + ButtonState.DOUBLE_TAP: PluginEvent.BUTTON_DOUBLE_PRESS, +} + + +class ButtonActionType(str, Enum): + """Types of actions a button press can trigger (from test_button_actions plugin config).""" + + CLOSE_CONVERSATION = "close_conversation" + CALL_PLUGIN = "call_plugin" + + +class ConversationCloseReason(str, Enum): + """Reasons for requesting a conversation close.""" + + USER_REQUESTED = "user_requested" + PLUGIN_REQUESTED = "plugin_requested" + BUTTON_CLOSE = "button_close" diff --git a/backends/advanced/src/advanced_omi_backend/plugins/router.py b/backends/advanced/src/advanced_omi_backend/plugins/router.py index 523fe3ed..422a97da 100644 --- a/backends/advanced/src/advanced_omi_backend/plugins/router.py +++ b/backends/advanced/src/advanced_omi_backend/plugins/router.py @@ -4,12 +4,18 @@ Routes pipeline events to appropriate plugins based on access level and triggers. 
""" +import json import logging +import os import re import string +import time from typing import Dict, List, Optional +import redis + from .base import BasePlugin, PluginContext, PluginResult +from .events import PluginEvent logger = logging.getLogger(__name__) @@ -86,10 +92,26 @@ def extract_command_after_wake_word(transcript: str, wake_word: str) -> str: class PluginRouter: """Routes pipeline events to appropriate plugins based on event subscriptions""" + _EVENT_LOG_KEY = "system:event_log" + _EVENT_LOG_MAX = 1000 + def __init__(self): self.plugins: Dict[str, BasePlugin] = {} # Index plugins by event for fast lookup self._plugins_by_event: Dict[str, List[str]] = {} + self._services = None + + # Sync Redis for event logging (works from both FastAPI and RQ workers) + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + try: + self._event_redis = redis.from_url(redis_url, decode_responses=True) + except Exception: + logger.warning("Could not connect to Redis for event logging") + self._event_redis = None + + def set_services(self, services) -> None: + """Attach PluginServices instance for injection into plugin contexts.""" + self._services = services def register_plugin(self, plugin_id: str, plugin: BasePlugin): """Register a plugin with the router""" @@ -126,16 +148,15 @@ async def dispatch_event( logger.info(f"🔌 ROUTER: Dispatching '{event}' event (user={user_id})") results = [] + executed = [] # Track per-plugin outcomes for event log # Get plugins subscribed to this event plugin_ids = self._plugins_by_event.get(event, []) - # Add subscription check if not plugin_ids: - logger.warning(f"🔌 ROUTER: No plugins subscribed to event '{event}'") - return results - - logger.info(f"🔌 ROUTER: Found {len(plugin_ids)} subscribed plugin(s): {plugin_ids}") + logger.info(f"🔌 ROUTER: No plugins subscribed to event '{event}'") + else: + logger.info(f"🔌 ROUTER: Found {len(plugin_ids)} subscribed plugin(s): {plugin_ids}") for plugin_id in plugin_ids: plugin = self.plugins[plugin_id] @@ -157,7 +178,8 @@ async def dispatch_event( user_id=user_id, event=event, data=data, - metadata=metadata or {} + metadata=metadata or {}, + services=self._services, ) result = await self._execute_plugin(plugin, event, context) @@ -169,6 +191,7 @@ async def dispatch_event( f"success={result.success}, message={result.message}" ) results.append(result) + executed.append({"plugin_id": plugin_id, "success": result.success, "message": result.message}) # If plugin says stop processing, break if not result.should_continue: @@ -181,6 +204,7 @@ async def dispatch_event( f" ✗ Plugin '{plugin_id}' FAILED with exception: {e}", exc_info=True ) + executed.append({"plugin_id": plugin_id, "success": False, "message": str(e)}) # Add at end logger.info( @@ -188,6 +212,14 @@ async def dispatch_event( f"{len(results)} plugin(s) executed successfully" ) + self._log_event( + event=event, + user_id=user_id, + plugins_subscribed=plugin_ids, + plugins_executed=executed, + metadata=metadata, + ) + return results async def _should_execute(self, plugin: BasePlugin, data: Dict) -> bool: @@ -236,16 +268,66 @@ async def _execute_plugin( context: PluginContext ) -> Optional[PluginResult]: """Execute plugin method for specified event""" - # Map events to plugin callback methods - if event.startswith('transcript.'): + # Map events to plugin callback methods using enums + # str(Enum) comparisons work because PluginEvent inherits from str + if event in (PluginEvent.TRANSCRIPT_STREAMING, PluginEvent.TRANSCRIPT_BATCH): return await 
plugin.on_transcript(context) - elif event.startswith('conversation.'): + elif event in (PluginEvent.CONVERSATION_COMPLETE,): return await plugin.on_conversation_complete(context) - elif event.startswith('memory.'): + elif event in (PluginEvent.MEMORY_PROCESSED,): return await plugin.on_memory_processed(context) + elif event in (PluginEvent.BUTTON_SINGLE_PRESS, PluginEvent.BUTTON_DOUBLE_PRESS): + return await plugin.on_button_event(context) + elif event == PluginEvent.PLUGIN_ACTION: + return await plugin.on_plugin_action(context) + # Fallback for any unrecognized events (forward compatibility) + logger.warning(f"No handler mapping for event '{event}'") return None + def _log_event( + self, + event: str, + user_id: str, + plugins_subscribed: List[str], + plugins_executed: List[Dict], + metadata: Optional[Dict] = None, + ) -> None: + """Append an event record to the Redis event log (capped list).""" + if not self._event_redis: + return + try: + record = json.dumps({ + "timestamp": time.time(), + "event": event, + "user_id": user_id, + "plugins_subscribed": plugins_subscribed, + "plugins_executed": plugins_executed, + "metadata": metadata or {}, + }) + pipe = self._event_redis.pipeline() + pipe.lpush(self._EVENT_LOG_KEY, record) + pipe.ltrim(self._EVENT_LOG_KEY, 0, self._EVENT_LOG_MAX - 1) + pipe.execute() + except Exception: + logger.debug("Failed to log event to Redis", exc_info=True) + + def get_recent_events(self, limit: int = 50, event_type: Optional[str] = None) -> List[Dict]: + """Read recent events from the Redis log.""" + if not self._event_redis: + return [] + try: + # Fetch more than needed when filtering by type + fetch_count = self._EVENT_LOG_MAX if event_type else limit + raw = self._event_redis.lrange(self._EVENT_LOG_KEY, 0, fetch_count - 1) + events = [json.loads(r) for r in raw] + if event_type: + events = [e for e in events if e.get("event") == event_type][:limit] + return events + except Exception: + logger.debug("Failed to read events from Redis", exc_info=True) + return [] + async def cleanup_all(self): """Clean up all registered plugins""" for plugin_id, plugin in self.plugins.items(): diff --git a/backends/advanced/src/advanced_omi_backend/plugins/services.py b/backends/advanced/src/advanced_omi_backend/plugins/services.py new file mode 100644 index 00000000..0322265b --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/services.py @@ -0,0 +1,99 @@ +""" +PluginServices — typed interface for plugin-to-system and plugin-to-plugin communication. + +Plugins use this interface (via context.services) to interact with the core system +(e.g., close a conversation) or with other plugins (e.g., call Home Assistant to toggle lights). +""" + +import logging +from typing import TYPE_CHECKING, Optional + +from .base import PluginContext, PluginResult +from .events import ConversationCloseReason, PluginEvent + +if TYPE_CHECKING: + from .router import PluginRouter + +logger = logging.getLogger(__name__) + + +class PluginServices: + """Typed interface for plugin-to-system and plugin-to-plugin communication.""" + + def __init__(self, router: "PluginRouter", redis_url: str): + self._router = router + self._redis_url = redis_url + + async def close_conversation( + self, + session_id: str, + reason: ConversationCloseReason = ConversationCloseReason.PLUGIN_REQUESTED, + ) -> bool: + """Request closing the current conversation for a session. + + Signals the open_conversation_job to close the current conversation + and trigger post-processing. 
The session stays active for new conversations. + + Args: + session_id: The streaming session ID (typically same as client_id) + reason: Why the conversation is being closed + + Returns: + True if the close request was set successfully + """ + import redis.asyncio as aioredis + + from advanced_omi_backend.controllers.session_controller import ( + request_conversation_close, + ) + + r = aioredis.from_url(self._redis_url) + try: + return await request_conversation_close(r, session_id, reason=reason.value) + finally: + await r.aclose() + + async def call_plugin( + self, + plugin_id: str, + action: str, + data: dict, + user_id: str = "system", + ) -> Optional[PluginResult]: + """Dispatch an action to another plugin's on_plugin_action() handler. + + Args: + plugin_id: Target plugin identifier (e.g., "homeassistant") + action: Action name (e.g., "toggle_lights") + data: Action-specific data + user_id: User context for the action + + Returns: + PluginResult from the target plugin, or error result if plugin not found + """ + plugin = self._router.plugins.get(plugin_id) + if not plugin: + logger.warning(f"Plugin '{plugin_id}' not found for cross-plugin call") + return PluginResult(success=False, message=f"Plugin '{plugin_id}' not found") + if not plugin.enabled: + logger.warning(f"Plugin '{plugin_id}' is disabled, cannot call") + return PluginResult(success=False, message=f"Plugin '{plugin_id}' is disabled") + + context = PluginContext( + user_id=user_id, + event=PluginEvent.PLUGIN_ACTION, + data={**data, "action": action}, + services=self, + ) + + try: + result = await plugin.on_plugin_action(context) + if result: + logger.info( + f"Cross-plugin call {plugin_id}.{action}: " + f"success={result.success}, message={result.message}" + ) + return result + except Exception as e: + logger.error(f"Cross-plugin call to {plugin_id}.{action} failed: {e}", exc_info=True) + return PluginResult(success=False, message=f"Plugin action failed: {e}") diff --git a/backends/advanced/src/advanced_omi_backend/prompt_defaults.py b/backends/advanced/src/advanced_omi_backend/prompt_defaults.py index eca71cfc..58a1fdd0 100644 --- a/backends/advanced/src/advanced_omi_backend/prompt_defaults.py +++ b/backends/advanced/src/advanced_omi_backend/prompt_defaults.py @@ -260,6 +260,60 @@ def register_all_defaults(registry: PromptRegistry) -> None: category="memory", ) + # ------------------------------------------------------------------ + # memory.reprocess_speaker_update + # ------------------------------------------------------------------ + registry.register_default( + "memory.reprocess_speaker_update", + template="""\ +You are a memory correction system. A conversation's transcript has been reprocessed with \ +updated speaker identification. The words spoken are the same, but speakers have been \ +re-identified more accurately. Your job is to update the existing memories so they \ +correctly attribute information to the right people. + +## Rules + +1. **UPDATE** — If a memory attributes information to a speaker whose label changed, \ +rewrite it with the correct speaker name. Keep the same `id`. +2. **NONE** — If the memory is unaffected by the speaker changes, leave it unchanged. +3. **DELETE** — If a memory is now nonsensical or completely wrong because the speaker \ +was misidentified (e.g., personal traits wrongly attributed), remove it. +4. **ADD** — If the corrected transcript reveals important new facts that become clear \ +only with the correct speaker attribution, add them. 
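A concrete instance of rule 1, using the JSON shape specified further down (values illustrative): if reprocessing relabels "Speaker 0" as "John", a memory that referenced "Speaker 0" should come back as:

```python
# Expected model output for a single re-attributed memory (illustrative):
expected = {
    "memory": [
        {
            "id": "mem-123",  # same id as the existing memory
            "event": "UPDATE",
            "text": "John enjoys rock climbing",
            "old_memory": "Speaker 0 enjoys rock climbing",
        }
    ]
}
```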
+ +## Important guidelines + +- Focus on **speaker attribution corrections**. This is the primary reason for reprocessing. +- A change from "Speaker 0" to "John" means memories referencing "Speaker 0" must now \ +reference "John". +- A change from "Alice" to "Bob" means facts previously attributed to "Alice" must be \ +attributed to "Bob" instead — this is critical because it changes *who* said or did something. +- Preserve the factual content when only the speaker name changes. +- Do NOT add memories that duplicate existing ones. +- When you UPDATE, always include `old_memory` with the previous text. + +## Output format (strict JSON only) + +Return ONLY a valid JSON object with this structure: + +{ + "memory": [ + { + "id": "<existing memory id>", + "event": "UPDATE|NONE|DELETE|ADD", + "text": "<new or updated memory text>", + "old_memory": "<previous memory text, required for UPDATE>" + } + ] +} + +Do not output any text outside the JSON object. +""", + name="Reprocess Speaker Update", + description="Updates existing memories after speaker re-identification to correct speaker attribution.", + category="memory", + ) + # ------------------------------------------------------------------ # memory.temporal_extraction # ------------------------------------------------------------------ @@ -343,44 +397,23 @@ def register_all_defaults(registry: PromptRegistry) -> None: ) # ------------------------------------------------------------------ - # conversation.title + # conversation.title_summary # ------------------------------------------------------------------ registry.register_default( - "conversation.title", + "conversation.title_summary", template="""\ -Generate a concise, descriptive title (3-6 words) for this conversation transcript. - -Rules: -- Maximum 6 words -- Capture the main topic or theme -- Do NOT include speaker names or participants -- No quotes or special characters -- Examples: "Planning Weekend Trip", "Work Project Discussion", "Medical Appointment" - -Title:""", - name="Conversation Title", - description="Generates a short title for a conversation from its transcript.", - category="conversation", - ) +Based on the full conversation transcript below, generate a concise title and a brief summary. - # ------------------------------------------------------------------ - # conversation.short_summary - # ------------------------------------------------------------------ - registry.register_default( - "conversation.short_summary", - template="""\ -Generate a brief, informative summary (1-2 sentences, max 120 characters) for this conversation.
+Respond in this exact format: +Title: <title> +Summary: <summary> Rules: -- Maximum 120 characters -- 1-2 complete sentences -{{speaker_instruction}}- Capture key topics and outcomes -- Use present tense -- Be specific and informative - -Summary:""", - name="Conversation Short Summary", - description="Generates a brief 1-2 sentence summary of a conversation.", +- Title: Maximum 6 words, capture the main topic/theme, no quotes or special characters +- Summary: Maximum 120 characters, capture key topics and outcomes, use present tense +{{speaker_instruction}}""", + name="Conversation Title & Summary", + description="Generates both title and short summary from full conversation context in one LLM call.", category="conversation", variables=["speaker_instruction"], is_dynamic=True, @@ -486,6 +519,42 @@ def register_all_defaults(registry: PromptRegistry) -> None: category="knowledge_graph", ) + # ------------------------------------------------------------------ + # asr.hot_words + # ------------------------------------------------------------------ + registry.register_default( + "asr.hot_words", + template="hey vivi, chronicle, omi", + name="ASR Hot Words", + description="Comma-separated hot words for speech recognition. " + "For Deepgram: boosts keyword recognition via keyterm. " + "For VibeVoice: passed as context_info to guide the LLM backbone. " + "Supports names, technical terms, and domain-specific vocabulary.", + category="asr", + ) + + # ------------------------------------------------------------------ + # asr.jargon_extraction + # ------------------------------------------------------------------ + registry.register_default( + "asr.jargon_extraction", + template="""\ +Extract up to 20 key jargon terms, names, and technical vocabulary from these memory facts. +Return ONLY a comma-separated list of words or short phrases (1-3 words each). +Focus on: proper nouns, technical terms, domain-specific vocabulary, names of people/places/products. +Skip generic everyday words.
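A quick illustration of the intended behavior, with hypothetical memory facts rendered into the `{{memories}}` slot below:

```python
# Illustrative value for the {{memories}} template variable:
memories = (
    "- Uses Neo4j for the knowledge graph\n"
    "- Discussed the VibeVoice rollout with Priya"  # names are hypothetical
)
# Expected completion: a bare comma-separated list, e.g.
#   Neo4j, knowledge graph, VibeVoice, Priya
```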
+ +Memory facts: +{{memories}} + +Jargon:""", + name="ASR Jargon Extraction", + description="Extracts key jargon terms from user memories for ASR context boosting.", + category="asr", + variables=["memories"], + is_dynamic=True, + ) + # ------------------------------------------------------------------ # transcription.title_summary # ------------------------------------------------------------------ diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py index f85a99ed..c4e49ce1 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py @@ -19,10 +19,12 @@ AnnotationStatus, AnnotationType, DiarizationAnnotationCreate, + EntityAnnotationCreate, MemoryAnnotationCreate, TranscriptAnnotationCreate, ) from advanced_omi_backend.models.conversation import Conversation +from advanced_omi_backend.services.knowledge_graph import get_knowledge_graph_service from advanced_omi_backend.services.memory import get_memory_service from advanced_omi_backend.users import User @@ -266,6 +268,25 @@ async def update_annotation_status( except Exception as e: logger.error(f"Error applying transcript suggestion: {e}") # Don't fail the status update if segment update fails + elif annotation.is_entity_annotation(): + # Update entity in Neo4j + try: + kg_service = get_knowledge_graph_service() + update_kwargs = {} + if annotation.entity_field == "name": + update_kwargs["name"] = annotation.corrected_text + elif annotation.entity_field == "details": + update_kwargs["details"] = annotation.corrected_text + if update_kwargs: + await kg_service.update_entity( + entity_id=annotation.entity_id, + user_id=annotation.user_id, + **update_kwargs, + ) + logger.info(f"Applied entity suggestion to entity {annotation.entity_id}") + except Exception as e: + logger.error(f"Error applying entity suggestion: {e}") + # Don't fail the status update if entity update fails await annotation.save() logger.info(f"Updated annotation {annotation_id} status to {status}") @@ -282,6 +303,113 @@ async def update_annotation_status( ) +# === Entity Annotation Routes === + + +@router.post("/entity", response_model=AnnotationResponse) +async def create_entity_annotation( + annotation_data: EntityAnnotationCreate, + current_user: User = Depends(current_active_user), +): + """ + Create annotation for entity edit (name or details correction). + + - Validates user owns the entity + - Creates annotation record for jargon/finetuning pipeline + - Applies correction to Neo4j immediately + - Marked as processed=False for downstream cron consumption + + Dual purpose: entity name corrections feed both the jargon pipeline + (domain vocabulary for ASR) and the entity extraction pipeline + (improving future extraction accuracy). 
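A client-side sketch of this endpoint. The request body mirrors `EntityAnnotationCreate`; the host, port, mount prefix, and auth cookie name are assumptions:

```python
import httpx

payload = {
    "entity_id": "3f2a9c71-...",       # Neo4j entity UUID (illustrative)
    "entity_field": "name",            # "name" or "details"
    "original_text": "Cronicle",       # what extraction produced
    "corrected_text": "Chronicle",     # the user's fix; also future ASR vocabulary
}
resp = httpx.post(
    "http://localhost:8000/api/annotations/entity",  # assumed mount point
    json=payload,
    cookies={"fastapiusersauth": "..."},  # hypothetical session cookie
)
resp.raise_for_status()
```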
+ """ + try: + # Validate entity_field + if annotation_data.entity_field not in ("name", "details"): + raise HTTPException( + status_code=400, + detail="entity_field must be 'name' or 'details'", + ) + + # Verify entity exists and belongs to user + kg_service = get_knowledge_graph_service() + entity = await kg_service.get_entity( + entity_id=annotation_data.entity_id, + user_id=current_user.user_id, + ) + if not entity: + raise HTTPException(status_code=404, detail="Entity not found") + + # Create annotation + annotation = Annotation( + annotation_type=AnnotationType.ENTITY, + user_id=current_user.user_id, + entity_id=annotation_data.entity_id, + entity_field=annotation_data.entity_field, + original_text=annotation_data.original_text, + corrected_text=annotation_data.corrected_text, + status=AnnotationStatus.ACCEPTED, + processed=False, # Unprocessed — jargon/finetuning cron will consume later + ) + await annotation.save() + logger.info( + f"Created entity annotation {annotation.id} for entity {annotation_data.entity_id} " + f"field={annotation_data.entity_field}" + ) + + # Apply correction to Neo4j immediately + try: + update_kwargs = {} + if annotation_data.entity_field == "name": + update_kwargs["name"] = annotation_data.corrected_text + elif annotation_data.entity_field == "details": + update_kwargs["details"] = annotation_data.corrected_text + + await kg_service.update_entity( + entity_id=annotation_data.entity_id, + user_id=current_user.user_id, + **update_kwargs, + ) + logger.info(f"Applied entity correction to Neo4j for entity {annotation_data.entity_id}") + except Exception as e: + logger.error(f"Error applying entity correction to Neo4j: {e}") + # Annotation is saved but Neo4j update failed — log but don't fail the request + + return AnnotationResponse.model_validate(annotation) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error creating entity annotation: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Failed to create entity annotation: {str(e)}", + ) + + +@router.get("/entity/{entity_id}", response_model=List[AnnotationResponse]) +async def get_entity_annotations( + entity_id: str, + current_user: User = Depends(current_active_user), +): + """Get all annotations for an entity.""" + try: + annotations = await Annotation.find( + Annotation.annotation_type == AnnotationType.ENTITY, + Annotation.entity_id == entity_id, + Annotation.user_id == current_user.user_id, + ).to_list() + + return [AnnotationResponse.model_validate(a) for a in annotations] + + except Exception as e: + logger.error(f"Error fetching entity annotations: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Failed to fetch entity annotations: {str(e)}", + ) + + # === Diarization Annotation Routes === diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py index 2de13ae7..c4424904 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py @@ -32,10 +32,13 @@ async def close_current_conversation( @router.get("") async def get_conversations( include_deleted: bool = Query(False, description="Include soft-deleted conversations"), + include_unprocessed: bool = Query(False, description="Include orphan audio sessions (always_persist with failed/pending transcription)"), + limit: int = Query(200, 
ge=1, le=500, description="Max conversations to return"), + offset: int = Query(0, ge=0, description="Number of conversations to skip"), current_user: User = Depends(current_active_user) ): """Get conversations. Admins see all conversations, users see only their own.""" - return await conversation_controller.get_conversations(current_user, include_deleted) + return await conversation_controller.get_conversations(current_user, include_deleted, include_unprocessed, limit, offset) @router.get("/{conversation_id}") @@ -48,6 +51,14 @@ async def get_conversation_detail( # New reprocessing endpoints +@router.post("/{conversation_id}/reprocess-orphan") +async def reprocess_orphan( + conversation_id: str, current_user: User = Depends(current_active_user) +): + """Reprocess an orphan audio session (always_persist conversation with failed/pending transcription).""" + return await conversation_controller.reprocess_orphan(conversation_id, current_user) + + @router.post("/{conversation_id}/reprocess-transcript") async def reprocess_transcript( conversation_id: str, current_user: User = Depends(current_active_user) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py index f3792e0b..72c271b6 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py @@ -1,7 +1,8 @@ """ Fine-tuning routes for Chronicle API. -Handles sending annotation corrections to speaker recognition service for training. +Handles sending annotation corrections to speaker recognition service for training +and cron job management for automated tasks. """ import logging @@ -10,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import JSONResponse +from pydantic import BaseModel from advanced_omi_backend.auth import current_active_user from advanced_omi_backend.models.annotation import Annotation, AnnotationType @@ -56,7 +58,7 @@ async def process_annotations_for_training( # Filter out already trained annotations (processed_by contains "training") ready_for_training = [ a for a in annotations - if a.processed_by and "training" not in a.processed_by + if not a.processed_by or "training" not in a.processed_by ] if not ready_for_training: @@ -96,16 +98,14 @@ async def process_annotations_for_training( conversation = await Conversation.find_one( Conversation.conversation_id == annotation.conversation_id ) - + if not conversation or not conversation.active_transcript: - logger.warning(f"Conversation {annotation.conversation_id} not found or has no transcript") failed_count += 1 errors.append(f"Conversation {annotation.conversation_id[:8]} not found") continue # Validate segment index if annotation.segment_index >= len(conversation.active_transcript.segments): - logger.warning(f"Invalid segment index {annotation.segment_index} for conversation {annotation.conversation_id}") failed_count += 1 errors.append(f"Invalid segment index {annotation.segment_index}") continue @@ -198,7 +198,7 @@ async def process_annotations_for_training( "appended_to_existing": appended_count, "total_processed": total_processed, "failed_count": failed_count, - "errors": errors[:10] if errors else [], # Limit error list + "errors": errors[:10] if errors else [], "status": "success" if total_processed > 0 else "partial_failure" }) @@ -227,48 +227,111 @@ async def get_finetuning_status( - 
cron_status: Cron job schedule and last run info """ try: - # Count annotations by status - pending_count = await Annotation.find( - Annotation.annotation_type == AnnotationType.DIARIZATION, - Annotation.processed == False, - ).count() - - # Get all processed annotations - all_processed = await Annotation.find( - Annotation.annotation_type == AnnotationType.DIARIZATION, - Annotation.processed == True, - ).to_list() - - # Split into trained vs not-yet-trained - trained_annotations = [ - a for a in all_processed - if a.processed_by and "training" in a.processed_by - ] - applied_not_trained = [ - a for a in all_processed - if not a.processed_by or "training" not in a.processed_by - ] - - applied_count = len(applied_not_trained) - trained_count = len(trained_annotations) + # ------------------------------------------------------------------ + # Per-type annotation counts (with orphan detection) + # ------------------------------------------------------------------ + from advanced_omi_backend.models.conversation import Conversation - # Get last training run timestamp + annotation_counts: dict[str, dict] = {} + trained_diarization_list: list = [] + + # Collect all annotations to batch-check for orphans + all_annotations_by_type: dict[AnnotationType, list] = {} + for ann_type in AnnotationType: + all_annotations_by_type[ann_type] = await Annotation.find( + Annotation.annotation_type == ann_type, + ).to_list() + + # Batch-check which conversation_ids still exist + conv_annotation_types = {AnnotationType.DIARIZATION, AnnotationType.TRANSCRIPT} + all_conv_ids: set[str] = set() + for ann_type in conv_annotation_types: + for a in all_annotations_by_type.get(ann_type, []): + if a.conversation_id: + all_conv_ids.add(a.conversation_id) + + existing_conv_ids: set[str] = set() + if all_conv_ids: + existing_convs = await Conversation.find( + {"conversation_id": {"$in": list(all_conv_ids)}}, + ).to_list() + existing_conv_ids = {c.conversation_id for c in existing_convs} + + orphaned_conv_ids = all_conv_ids - existing_conv_ids + + total_orphaned = 0 + for ann_type in AnnotationType: + annotations = all_annotations_by_type[ann_type] + + # Identify orphaned annotations for conversation-based types + if ann_type in conv_annotation_types: + orphaned = [a for a in annotations if a.conversation_id in orphaned_conv_ids] + non_orphaned = [a for a in annotations if a.conversation_id not in orphaned_conv_ids] + else: + # Memory/entity orphan detection is placeholder for now + orphaned = [] + non_orphaned = annotations + + pending = [a for a in non_orphaned if not a.processed] + processed = [a for a in non_orphaned if a.processed] + trained = [a for a in processed if a.processed_by and "training" in a.processed_by] + applied_not_trained = [ + a for a in processed + if not a.processed_by or "training" not in a.processed_by + ] + + orphan_count = len(orphaned) + total_orphaned += orphan_count + + annotation_counts[ann_type.value] = { + "total": len(non_orphaned), + "pending": len(pending), + "applied": len(applied_not_trained), + "trained": len(trained), + "orphaned": orphan_count, + } + + if ann_type == AnnotationType.DIARIZATION: + trained_diarization_list = trained + + # ------------------------------------------------------------------ + # Diarization-specific fields (backward compat) + # ------------------------------------------------------------------ + diarization = annotation_counts.get("diarization", {}) + pending_count = diarization.get("pending", 0) + applied_count = diarization.get("applied", 0) + 
trained_count = diarization.get("trained", 0) + + # Get last training run timestamp from diarization annotations last_training_run = None - if trained_annotations: - # Find most recent trained annotation + if trained_diarization_list: latest_trained = max( - trained_annotations, + trained_diarization_list, key=lambda a: a.updated_at if a.updated_at else datetime.min.replace(tzinfo=timezone.utc) ) last_training_run = latest_trained.updated_at.isoformat() if latest_trained.updated_at else None - # TODO: Get cron job status from scheduler - cron_status = { - "enabled": False, # Placeholder - "schedule": "0 2 * * *", # Example: daily at 2 AM - "last_run": None, - "next_run": None, - } + # Get cron job status from scheduler + try: + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + all_jobs = await scheduler.get_all_jobs_status() + # Find speaker finetuning job for backward compat + speaker_job = next((j for j in all_jobs if j["job_id"] == "speaker_finetuning"), None) + cron_status = { + "enabled": speaker_job["enabled"] if speaker_job else False, + "schedule": speaker_job["schedule"] if speaker_job else "0 2 * * *", + "last_run": speaker_job["last_run"] if speaker_job else None, + "next_run": speaker_job["next_run"] if speaker_job else None, + } + except Exception: + cron_status = { + "enabled": False, + "schedule": "0 2 * * *", + "last_run": None, + "next_run": None, + } return JSONResponse(content={ "pending_annotation_count": pending_count, @@ -276,6 +339,8 @@ async def get_finetuning_status( "trained_annotation_count": trained_count, "last_training_run": last_training_run, "cron_status": cron_status, + "annotation_counts": annotation_counts, + "orphaned_annotation_count": total_orphaned, }) except Exception as e: @@ -284,3 +349,154 @@ async def get_finetuning_status( status_code=500, detail=f"Failed to fetch fine-tuning status: {str(e)}", ) + + +# --------------------------------------------------------------------------- +# Orphaned Annotation Management Endpoints +# --------------------------------------------------------------------------- + + +@router.delete("/orphaned-annotations") +async def delete_orphaned_annotations( + current_user: User = Depends(current_active_user), + annotation_type: Optional[str] = Query(None, description="Filter by annotation type (e.g. 'diarization')"), +): + """ + Find and delete orphaned annotations whose referenced conversation no longer exists. + + Only handles conversation-based annotation types (diarization, transcript). 
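The endpoint reports what it removed, grouped by type; a successful call returns something like this (counts illustrative):

```python
# Example response from DELETE /orphaned-annotations?annotation_type=diarization
response_body = {
    "deleted_count": 3,
    "by_type": {"diarization": 3},
}
```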
+ """ + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + from advanced_omi_backend.models.conversation import Conversation + + conv_annotation_types = {AnnotationType.DIARIZATION, AnnotationType.TRANSCRIPT} + + # Filter to requested type if specified + if annotation_type: + try: + requested_type = AnnotationType(annotation_type) + except ValueError: + raise HTTPException(status_code=400, detail=f"Unknown annotation type: {annotation_type}") + if requested_type not in conv_annotation_types: + return JSONResponse(content={"deleted_count": 0, "by_type": {}, "message": "Orphan detection not supported for this type"}) + types_to_check = {requested_type} + else: + types_to_check = conv_annotation_types + + # Collect all conversation_ids referenced by these annotation types + all_conv_ids: set[str] = set() + annotations_by_type: dict[AnnotationType, list] = {} + for ann_type in types_to_check: + annotations = await Annotation.find( + Annotation.annotation_type == ann_type, + ).to_list() + annotations_by_type[ann_type] = annotations + for a in annotations: + if a.conversation_id: + all_conv_ids.add(a.conversation_id) + + if not all_conv_ids: + return JSONResponse(content={"deleted_count": 0, "by_type": {}}) + + # Batch-check which conversations still exist + existing_convs = await Conversation.find( + {"conversation_id": {"$in": list(all_conv_ids)}}, + ).to_list() + existing_conv_ids = {c.conversation_id for c in existing_convs} + orphaned_conv_ids = all_conv_ids - existing_conv_ids + + if not orphaned_conv_ids: + return JSONResponse(content={"deleted_count": 0, "by_type": {}}) + + # Delete orphaned annotations + deleted_by_type: dict[str, int] = {} + total_deleted = 0 + for ann_type, annotations in annotations_by_type.items(): + orphaned = [a for a in annotations if a.conversation_id in orphaned_conv_ids] + for a in orphaned: + await a.delete() + if orphaned: + deleted_by_type[ann_type.value] = len(orphaned) + total_deleted += len(orphaned) + + logger.info(f"Deleted {total_deleted} orphaned annotations: {deleted_by_type}") + return JSONResponse(content={ + "deleted_count": total_deleted, + "by_type": deleted_by_type, + }) + + +@router.post("/orphaned-annotations/reattach") +async def reattach_orphaned_annotations( + current_user: User = Depends(current_active_user), +): + """Placeholder for reattaching orphaned annotations to a different conversation.""" + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + raise HTTPException(status_code=501, detail="Reattach functionality coming soon") + + +# --------------------------------------------------------------------------- +# Cron Job Management Endpoints +# --------------------------------------------------------------------------- + + +class CronJobUpdate(BaseModel): + enabled: Optional[bool] = None + schedule: Optional[str] = None + + +@router.get("/cron-jobs") +async def get_cron_jobs(current_user: User = Depends(current_active_user)): + """List all cron jobs with status, schedule, last/next run.""" + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + return await scheduler.get_all_jobs_status() + + +@router.put("/cron-jobs/{job_id}") +async def update_cron_job( + job_id: str, + body: CronJobUpdate, + current_user: User = Depends(current_active_user), +): + """Update a cron job's 
schedule or enabled state.""" + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + try: + await scheduler.update_job(job_id, enabled=body.enabled, schedule=body.schedule) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + return {"message": f"Job '{job_id}' updated", "job_id": job_id} + + +@router.post("/cron-jobs/{job_id}/run") +async def run_cron_job_now( + job_id: str, + current_user: User = Depends(current_active_user), +): + """Manually trigger a cron job.""" + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + try: + result = await scheduler.run_job_now(job_id) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + return result diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py index d4680951..1b8ae1cf 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py @@ -13,6 +13,11 @@ from pydantic import BaseModel from advanced_omi_backend.auth import current_active_user +from advanced_omi_backend.models.annotation import ( + Annotation, + AnnotationStatus, + AnnotationType, +) from advanced_omi_backend.services.knowledge_graph import ( KnowledgeGraphService, PromiseStatus, @@ -30,6 +35,13 @@ # ============================================================================= +class UpdateEntityRequest(BaseModel): + """Request model for updating entity fields.""" + name: Optional[str] = None + details: Optional[str] = None + icon: Optional[str] = None + + class UpdatePromiseRequest(BaseModel): """Request model for updating promise status.""" status: str # pending, in_progress, completed, cancelled @@ -144,6 +156,78 @@ async def get_entity_relationships( ) +@router.patch("/entities/{entity_id}") +async def update_entity( + entity_id: str, + request: UpdateEntityRequest, + current_user: User = Depends(current_active_user), +): + """Update an entity's name, details, or icon. + + Also creates entity annotations as a side effect for each changed field. + These annotations feed the jargon and entity extraction pipelines. + """ + try: + if request.name is None and request.details is None and request.icon is None: + raise HTTPException( + status_code=400, + detail="At least one field (name, details, icon) must be provided", + ) + + service = get_knowledge_graph_service() + + # Get current entity for annotation original values + existing = await service.get_entity( + entity_id=entity_id, + user_id=str(current_user.id), + ) + if not existing: + raise HTTPException(status_code=404, detail="Entity not found") + + # Apply update to Neo4j + updated = await service.update_entity( + entity_id=entity_id, + user_id=str(current_user.id), + name=request.name, + details=request.details, + icon=request.icon, + ) + if not updated: + raise HTTPException(status_code=404, detail="Entity not found") + + # Create annotations for changed text fields (name, details) + # These feed the jargon pipeline and entity extraction pipeline. + # Icon changes don't create annotations (not text corrections). 
+ for field in ("name", "details"): + new_value = getattr(request, field) + if new_value is not None: + old_value = getattr(existing, field) or "" + annotation = Annotation( + annotation_type=AnnotationType.ENTITY, + user_id=str(current_user.id), + entity_id=entity_id, + entity_field=field, + original_text=old_value, + corrected_text=new_value, + status=AnnotationStatus.ACCEPTED, + processed=False, + ) + await annotation.save() + logger.info( + f"Created entity annotation for {field} change on entity {entity_id}" + ) + + return {"entity": updated.to_dict()} + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating entity {entity_id}: {e}") + return JSONResponse( + status_code=500, + content={"message": f"Error updating entity: {str(e)}"}, + ) + + @router.delete("/entities/{entity_id}") async def delete_entity( entity_id: str, diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py index 934cf0b1..ff3d460a 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py @@ -321,6 +321,30 @@ def process_job_and_dependents(job, queue_name, base_status): raise HTTPException(status_code=500, detail=f"Failed to get jobs for client: {str(e)}") +@router.get("/events") +async def get_events( + limit: int = Query(50, ge=1, le=200, description="Number of recent events"), + event_type: str = Query(None, description="Filter by event type"), + current_user: User = Depends(current_active_user), +): + """Get recent system events from the event log (admin only).""" + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + try: + from advanced_omi_backend.services.plugin_service import get_plugin_router + + router_instance = get_plugin_router() + if not router_instance: + return {"events": [], "total": 0} + + events = router_instance.get_recent_events(limit=limit, event_type=event_type or None) + return {"events": events, "total": len(events)} + except Exception as e: + logger.error(f"Failed to get events: {e}") + return {"events": [], "total": 0} + + @router.get("/stats") async def get_queue_stats_endpoint( current_user: User = Depends(current_active_user) @@ -1034,6 +1058,21 @@ def get_job_status(job): logger.error(f"Error fetching jobs for client {client_id}: {e}") return {"client_id": client_id, "jobs": []} + async def fetch_events(): + """Fetch recent system events from the event log (admin only).""" + if not current_user.is_superuser: + return [] + try: + from advanced_omi_backend.services.plugin_service import get_plugin_router + + router_instance = get_plugin_router() + if not router_instance: + return [] + return router_instance.get_recent_events(limit=50) + except Exception as e: + logger.error(f"Error fetching events: {e}") + return [] + # Execute all fetches in parallel (using RQ standard status names) queued_jobs_task = fetch_jobs_by_status("queued", limit=100) started_jobs_task = fetch_jobs_by_status("started", limit=100) # RQ standard, not "processing" @@ -1041,6 +1080,7 @@ def get_job_status(job): failed_jobs_task = fetch_jobs_by_status("failed", limit=50) stats_task = fetch_stats() streaming_status_task = fetch_streaming_status() + events_task = fetch_events() client_jobs_tasks = [fetch_client_jobs(cid) for cid in expanded_client_ids] results = await asyncio.gather( @@ -1050,6 +1090,7 @@ def 
get_job_status(job): failed_jobs_task, stats_task, streaming_status_task, + events_task, *client_jobs_tasks, return_exceptions=True ) @@ -1060,8 +1101,9 @@ def get_job_status(job): failed_jobs = results[3] if not isinstance(results[3], Exception) else [] stats = results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0} streaming_status = results[5] if not isinstance(results[5], Exception) else {"active_sessions": []} + events = results[6] if not isinstance(results[6], Exception) else [] recent_conversations = [] - client_jobs_results = results[6:] if len(results) > 6 else [] + client_jobs_results = results[7:] if len(results) > 7 else [] # Convert client jobs list to dict client_jobs = {} @@ -1092,6 +1134,7 @@ def get_job_status(job): "streaming_status": streaming_status, "recent_conversations": conversations_list, "client_jobs": client_jobs, + "events": events, "timestamp": asyncio.get_event_loop().time() } diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py index e7fae522..dc5e9b27 100644 --- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py +++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/producer.py @@ -4,6 +4,7 @@ import logging import time +import json import redis.asyncio as redis @@ -139,6 +140,19 @@ async def send_session_end_signal(self, session_id: str): stream_name = buffer["stream_name"] # Send special "end" message to signal workers to flush + # Read audio format from Redis session metadata (stored at audio-start time) + sample_rate, channels, sample_width = 16000, 1, 2 + try: + session_key = f"audio:session:{session_id}" + audio_format_raw = await self.redis_client.hget(session_key, "audio_format") + if audio_format_raw: + audio_format = json.loads(audio_format_raw) + sample_rate = int(audio_format.get("rate", 16000)) + channels = int(audio_format.get("channels", 1)) + sample_width = int(audio_format.get("width", 2)) + except Exception: + pass # Fall back to defaults + end_signal = { b"audio_data": b"", # Empty audio data b"session_id": session_id.encode(), @@ -146,9 +160,9 @@ async def send_session_end_signal(self, session_id: str): b"user_id": buffer["user_id"].encode(), b"client_id": buffer["client_id"].encode(), b"timestamp": str(time.time()).encode(), - b"sample_rate": b"16000", - b"channels": b"1", - b"sample_width": b"2", + b"sample_rate": str(sample_rate).encode(), + b"channels": str(channels).encode(), + b"sample_width": str(sample_width).encode(), } await self.redis_client.xadd( diff --git a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py index 6562dccc..5f13508d 100644 --- a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py +++ b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py @@ -535,6 +535,44 @@ async def search_entities( return entities + async def update_entity( + self, + entity_id: str, + user_id: str, + name: str = None, + details: str = None, + icon: str = None, + ) -> Optional[Entity]: + """Update an entity's fields (partial update via COALESCE). 
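+
+        Example (illustrative; None keeps the existing value):
+
+            entity = await service.update_entity(
+                entity_id=entity_id, user_id=user_id, name="Acme Corp"
+            )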
+ + Args: + entity_id: Entity UUID + user_id: User ID for permission check + name: New name (None keeps existing) + details: New details (None keeps existing) + icon: New icon (None keeps existing) + + Returns: + Updated Entity object or None if not found + """ + self._ensure_initialized() + + results = self._write.run( + queries.UPDATE_ENTITY, + id=entity_id, + user_id=user_id, + name=name, + details=details, + icon=icon, + metadata=None, + ) + + if not results: + return None + + entity_data = dict(results[0]["e"]) + return self._row_to_entity(entity_data) + async def delete_entity( self, entity_id: str, diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/base.py b/backends/advanced/src/advanced_omi_backend/services/memory/base.py index 7df7748e..277a92f8 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/base.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/base.py @@ -205,6 +205,45 @@ async def update_memory( """ return False + async def reprocess_memory( + self, + transcript: str, + client_id: str, + source_id: str, + user_id: str, + user_email: str, + transcript_diff: Optional[List[Dict[str, Any]]] = None, + previous_transcript: Optional[str] = None, + ) -> Tuple[bool, List[str]]: + """Reprocess memories after transcript or speaker changes. + + This method is called when a conversation's transcript has been + reprocessed (e.g., speaker re-identification) and memories need + to be updated to reflect the changes. + + The default implementation falls back to normal ``add_memory`` + with ``allow_update=True``. Providers that support diff-aware + reprocessing should override this method. + + Args: + transcript: Updated full transcript text (with corrected speakers) + client_id: Client identifier + source_id: Conversation/source identifier + user_id: User identifier + user_email: User email address + transcript_diff: List of dicts describing what changed between + the old and new transcript (speaker changes, text changes). + Each dict has keys like ``type``, ``old_speaker``, + ``new_speaker``, ``text``, ``start``, ``end``. + previous_transcript: The previous transcript text (before changes) + + Returns: + Tuple of (success: bool, affected_memory_ids: List[str]) + """ + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, allow_update=True + ) + @abstractmethod async def delete_memory( self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None @@ -331,6 +370,37 @@ async def propose_memory_actions( """ pass + async def propose_reprocess_actions( + self, + existing_memories: List[Dict[str, str]], + diff_context: str, + new_transcript: str, + custom_prompt: Optional[str] = None, + ) -> Dict[str, Any]: + """Propose memory updates after transcript reprocessing (e.g., speaker changes). + + Uses the LLM to review existing conversation memories in light of + specific transcript changes (speaker re-identification, text corrections) + and propose targeted ADD/UPDATE/DELETE/NONE actions. + + Default implementation raises NotImplementedError. Providers that + support diff-aware reprocessing should override this method. 
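+
+        Expected return shape (illustrative values):
+
+            {"memory": [{"id": "0", "event": "UPDATE",
+                         "text": "John prefers tea",
+                         "old_memory": "Speaker 0 prefers tea"}]}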
+ + Args: + existing_memories: List of existing memories for the conversation + (each dict has ``id`` and ``text`` keys) + diff_context: Formatted string describing what changed in the + transcript (e.g., speaker relabelling details) + new_transcript: The updated full transcript text + custom_prompt: Optional custom system prompt + + Returns: + Dictionary containing proposed actions in ``{"memory": [...]}`` format + """ + raise NotImplementedError( + f"{type(self).__name__} does not support propose_reprocess_actions" + ) + @abstractmethod async def test_connection(self) -> bool: """Test connection to the LLM provider. @@ -415,6 +485,24 @@ async def count_memories(self, user_id: str) -> Optional[int]: """ return None + async def get_memories_by_source( + self, user_id: str, source_id: str, limit: int = 100 + ) -> List["MemoryEntry"]: + """Get all memories for a specific source (conversation) for a user. + + Default implementation returns empty list. Vector stores should + override to filter by metadata.source_id. + + Args: + user_id: User identifier + source_id: Source/conversation identifier + limit: Maximum number of memories to return + + Returns: + List of MemoryEntry objects for the specified source + """ + return [] + @abstractmethod async def update_memory( self, diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py b/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py index 3e4f4535..0e704be3 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py @@ -197,6 +197,80 @@ """ +REPROCESS_SPEAKER_UPDATE_PROMPT = """ +You are a memory correction system. A conversation's transcript has been reprocessed with \ +updated speaker identification. The words spoken are the same, but speakers have been \ +re-identified more accurately. Your job is to update the existing memories so they \ +correctly attribute information to the right people. + +## Rules + +1. **UPDATE** — If a memory attributes information to a speaker whose label changed, \ +rewrite it with the correct speaker name. Keep the same `id`. +2. **NONE** — If the memory is unaffected by the speaker changes, leave it unchanged. +3. **DELETE** — If a memory is now nonsensical or completely wrong because the speaker \ +was misidentified (e.g., personal traits wrongly attributed), remove it. +4. **ADD** — If the corrected transcript reveals important new facts that become clear \ +only with the correct speaker attribution, add them. + +## Important guidelines + +- Focus on **speaker attribution corrections**. This is the primary reason for reprocessing. +- A change from "Speaker 0" to "John" means memories referencing "Speaker 0" must now \ +reference "John". +- A change from "Alice" to "Bob" means facts previously attributed to "Alice" must be \ +attributed to "Bob" instead — this is critical because it changes *who* said or did something. +- Preserve the factual content when only the speaker name changes. +- Do NOT add memories that duplicate existing ones. +- When you UPDATE, always include `old_memory` with the previous text. + +## Output format (strict JSON only) + +Return ONLY a valid JSON object with this structure: + +{ + "memory": [ + { + "id": "", + "event": "UPDATE|NONE|DELETE|ADD", + "text": "", + "old_memory": "" + } + ] +} + +Do not output any text outside the JSON object. 
+""" + + +def build_reprocess_speaker_messages( + existing_memories: list, + diff_context: str, + new_transcript: str, +) -> str: + """Build the user message for the reprocess-after-speaker-change LLM call. + + Args: + existing_memories: List of dicts with ``id`` and ``text`` keys + diff_context: Formatted string of speaker changes + new_transcript: Full updated transcript with corrected speakers + + Returns: + Formatted user message string + """ + memories_json = json.dumps(existing_memories, ensure_ascii=False) + + return ( + "## Existing Memories for This Conversation\n" + f"{memories_json}\n\n" + "## Speaker Changes in Transcript\n" + f"{diff_context}\n\n" + "## Updated Full Transcript (with corrected speakers)\n" + f"{new_transcript}\n\n" + "Output:" + ) + + PROCEDURAL_MEMORY_SYSTEM_PROMPT = """ You are a memory summarization system that records and preserves the complete interaction history between a human and an AI agent. You are provided with the agent's execution history over the past N steps. Your task is to produce a comprehensive summary of the agent's output history that contains every detail necessary for the agent to continue the task without ambiguity. **Every output produced by the agent must be recorded verbatim as part of the summary.** diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py index 1eddae93..7d9eecc4 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py @@ -476,6 +476,202 @@ def shutdown(self) -> None: self.vector_store = None memory_logger.info("Memory service shut down") + async def reprocess_memory( + self, + transcript: str, + client_id: str, + source_id: str, + user_id: str, + user_email: str, + transcript_diff: Optional[list] = None, + previous_transcript: Optional[str] = None, + ) -> Tuple[bool, List[str]]: + """Reprocess memories after speaker re-identification. + + Instead of extracting fresh facts from scratch, this method: + 1. Fetches existing memories for this specific conversation + 2. Computes what changed (speaker labels) between old and new transcript + 3. Asks the LLM to make targeted updates to the existing memories + + Falls back to normal ``add_memory`` when there are no existing + memories or no meaningful diff. + + Args: + transcript: Updated full transcript (with corrected speakers) + client_id: Client identifier + source_id: Conversation identifier + user_id: User identifier + user_email: User email + transcript_diff: List of dicts describing speaker changes + previous_transcript: Previous transcript text (before changes) + + Returns: + Tuple of (success, affected_memory_ids) + """ + await self._ensure_initialized() + + try: + # 1. Get existing memories for this conversation + existing_memories = await self.vector_store.get_memories_by_source( + user_id, source_id + ) + + # 2. If no existing memories, fall back to normal extraction + if not existing_memories: + memory_logger.info( + f"🔄 Reprocess: no existing memories for {source_id}, " + f"falling back to normal extraction" + ) + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + + # 3. 
If no diff provided, fall back to normal extraction + if not transcript_diff: + memory_logger.info( + f"🔄 Reprocess: no transcript diff for {source_id}, " + f"falling back to normal extraction" + ) + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + + # 4. Format the diff for the LLM + diff_text = self._format_speaker_diff(transcript_diff) + + memory_logger.info( + f"🔄 Reprocess: {len(existing_memories)} existing memories, " + f"{len(transcript_diff)} speaker changes for {source_id}" + ) + + # 5. Build temp ID mapping (avoid hallucinated UUIDs) + temp_uuid_mapping = {} + existing_memory_dicts = [] + for idx, mem in enumerate(existing_memories): + temp_uuid_mapping[str(idx)] = mem.id + existing_memory_dicts.append({"id": str(idx), "text": mem.content}) + + # 6. Ask LLM for targeted update actions + try: + actions_obj = await self.llm_provider.propose_reprocess_actions( + existing_memories=existing_memory_dicts, + diff_context=diff_text, + new_transcript=transcript, + ) + memory_logger.info( + f"🔄 Reprocess LLM returned actions: {actions_obj}" + ) + except NotImplementedError: + memory_logger.warning( + "LLM provider does not support propose_reprocess_actions, " + "falling back to normal extraction" + ) + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + except Exception as e: + memory_logger.error(f"Reprocess LLM call failed: {e}") + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + + # 7. Normalize and pre-generate embeddings for ADD/UPDATE actions + actions_list = self._normalize_actions(actions_obj) + + texts_needing_embeddings = [ + action.get("text") + for action in actions_list + if action.get("event") in ("ADD", "UPDATE") + and action.get("text") + and isinstance(action.get("text"), str) + ] + + text_to_embedding = {} + if texts_needing_embeddings: + try: + embeddings = await asyncio.wait_for( + self.llm_provider.generate_embeddings(texts_needing_embeddings), + timeout=self.config.timeout_seconds, + ) + text_to_embedding = dict( + zip(texts_needing_embeddings, embeddings, strict=True) + ) + except Exception as e: + memory_logger.warning( + f"Batch embedding generation failed for reprocess: {e}" + ) + + # 8. Apply the actions (reuses existing infrastructure) + created_ids = await self._apply_memory_actions( + actions_list, + text_to_embedding, + temp_uuid_mapping, + client_id, + source_id, + user_id, + user_email, + ) + + memory_logger.info( + f"✅ Reprocess complete for {source_id}: " + f"{len(created_ids)} memories affected" + ) + return True, created_ids + + except Exception as e: + memory_logger.error( + f"❌ Reprocess memory failed for {source_id}: {e}" + ) + # Fall back to normal extraction on any unexpected error + memory_logger.info( + f"🔄 Falling back to normal extraction after reprocess error" + ) + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + + @staticmethod + def _format_speaker_diff(transcript_diff: list) -> str: + """Format a transcript diff into a human-readable string for the LLM. + + Args: + transcript_diff: List of change dicts from + ``compute_speaker_diff`` + + Returns: + Formatted multi-line string describing the changes + """ + if not transcript_diff: + return "No changes detected." 
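+
+        # Illustrative mapping: a diff entry such as
+        #   {"type": "speaker_change", "old_speaker": "Speaker 0",
+        #    "new_speaker": "John", "text": "I'll send the report"}
+        # is rendered below as:
+        #   - "I'll send the report" was spoken by "Speaker 0" but is now identified as "John"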
+ + lines = [] + for change in transcript_diff: + change_type = change.get("type", "unknown") + if change_type == "speaker_change": + lines.append( + f"- \"{change.get('text', '')}\" " + f"was spoken by \"{change.get('old_speaker', '?')}\" " + f"but is now identified as \"{change.get('new_speaker', '?')}\"" + ) + elif change_type == "text_change": + lines.append( + f"- Segment by {change.get('speaker', '?')}: " + f"text changed from \"{change.get('old_text', '')}\" " + f"to \"{change.get('new_text', '')}\"" + ) + elif change_type == "new_segment": + lines.append( + f"- New segment: {change.get('speaker', '?')}: " + f"\"{change.get('text', '')}\"" + ) + + return "\n".join(lines) + # Private helper methods def _deduplicate_memories(self, memories_text: List[str]) -> List[str]: diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py index 9b00e8b1..a3f68c5f 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py @@ -20,9 +20,9 @@ from ..base import LLMProviderBase from ..prompts import ( - FACT_RETRIEVAL_PROMPT, + REPROCESS_SPEAKER_UPDATE_PROMPT, + build_reprocess_speaker_messages, build_update_memory_messages, - get_update_memory_messages, ) from ..update_memory_utils import ( extract_assistant_xml_from_openai_response, @@ -357,6 +357,97 @@ async def propose_memory_actions( return {} + async def propose_reprocess_actions( + self, + existing_memories: List[Dict[str, str]], + diff_context: str, + new_transcript: str, + custom_prompt: Optional[str] = None, + ) -> Dict[str, Any]: + """Propose memory updates after speaker re-identification. + + Sends the existing conversation memories, the speaker change diff, + and the corrected transcript to the LLM. Returns JSON with + ADD/UPDATE/DELETE/NONE actions. + + The system prompt is resolved in priority order: + 1. ``custom_prompt`` argument (if provided) + 2. Langfuse override via the prompt registry + (prompt id ``memory.reprocess_speaker_update``) + 3. Registered default from ``prompt_defaults.py`` + + Args: + existing_memories: List of {id, text} dicts for this conversation + diff_context: Formatted string of speaker changes + new_transcript: Full updated transcript with corrected speakers + custom_prompt: Optional custom system prompt + + Returns: + Dictionary with ``memory`` key containing action list + """ + try: + # Resolve prompt: explicit arg → Langfuse/registry → hardcoded fallback + if custom_prompt and custom_prompt.strip(): + system_prompt = custom_prompt + else: + try: + registry = get_prompt_registry() + system_prompt = await registry.get_prompt( + "memory.reprocess_speaker_update" + ) + except Exception as e: + memory_logger.debug( + f"Registry prompt fetch failed for " + f"memory.reprocess_speaker_update: {e}, " + f"using hardcoded fallback" + ) + system_prompt = REPROCESS_SPEAKER_UPDATE_PROMPT + + user_content = build_reprocess_speaker_messages( + existing_memories, diff_context, new_transcript + ) + + messages = [ + {"role": "system", "content": system_prompt.strip()}, + {"role": "user", "content": user_content}, + ] + + memory_logger.info( + f"🔄 Reprocess: asking LLM with {len(existing_memories)} existing memories " + f"and speaker diff" + ) + memory_logger.debug( + f"🔄 Reprocess user content (first 300 chars): {user_content[:300]}..." 
+ ) + + client = _get_openai_client( + api_key=self.api_key, base_url=self.base_url, is_async=True + ) + response = await client.chat.completions.create( + model=self.model, + messages=messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + response_format={"type": "json_object"}, + ) + content = (response.choices[0].message.content or "").strip() + + if not content: + memory_logger.warning("Reprocess LLM returned empty content") + return {} + + result = json.loads(content) + memory_logger.info(f"🔄 Reprocess LLM returned: {result}") + return result + + except json.JSONDecodeError as e: + memory_logger.error(f"Reprocess LLM returned invalid JSON: {e}") + return {} + except Exception as e: + memory_logger.error(f"propose_reprocess_actions failed: {e}") + return {} + + class OllamaProvider(LLMProviderBase): """Ollama LLM provider implementation. diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py index 9fed0126..50678642 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py @@ -19,6 +19,7 @@ FilterSelector, MatchValue, PointStruct, + Range, VectorParams, ) @@ -445,7 +446,112 @@ async def get_memory(self, memory_id: str, user_id: Optional[str] = None) -> Opt memory_logger.error(f"Qdrant get memory failed for {memory_id}: {e}") return None + async def get_memories_by_source( + self, user_id: str, source_id: str, limit: int = 100 + ) -> List[MemoryEntry]: + """Get all memories for a specific source (conversation) for a user. + Filters by both user_id and source_id in metadata to return only + memories extracted from a particular conversation. + Args: + user_id: User identifier + source_id: Source/conversation identifier + limit: Maximum number of memories to return + + Returns: + List of MemoryEntry objects for the specified source + """ + try: + search_filter = Filter( + must=[ + FieldCondition( + key="metadata.user_id", + match=MatchValue(value=user_id), + ), + FieldCondition( + key="metadata.source_id", + match=MatchValue(value=source_id), + ), + ] + ) + + results = await self.client.scroll( + collection_name=self.collection_name, + scroll_filter=search_filter, + limit=limit, + ) + + memories = [] + for point in results[0]: + memory = MemoryEntry( + id=str(point.id), + content=point.payload.get("content", ""), + metadata=point.payload.get("metadata", {}), + created_at=point.payload.get("created_at"), + updated_at=point.payload.get("updated_at"), + ) + memories.append(memory) + + memory_logger.info( + f"Found {len(memories)} memories for source {source_id} (user {user_id})" + ) + return memories + + except Exception as e: + memory_logger.error(f"Qdrant get memories by source failed: {e}") + return [] + + async def get_recent_memories( + self, user_id: str, since_timestamp: int, limit: int = 100 + ) -> List[MemoryEntry]: + """Get memories created after a given unix timestamp for a user. 
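+
+        Example (illustrative):
+
+            cutoff = int(time.time()) - 3600  # last hour
+            recent = await store.get_recent_memories(user_id, since_timestamp=cutoff)
+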
+ Args: + user_id: User identifier + since_timestamp: Unix timestamp; only memories at or after this time are returned + limit: Maximum number of memories to return + + Returns: + List of MemoryEntry objects + """ + try: + search_filter = Filter( + must=[ + FieldCondition( + key="metadata.user_id", + match=MatchValue(value=user_id), + ), + FieldCondition( + key="metadata.timestamp", + range=Range(gte=since_timestamp), + ), + ] + ) + + results = await self.client.scroll( + collection_name=self.collection_name, + scroll_filter=search_filter, + limit=limit, + ) + + memories = [] + for point in results[0]: + memory = MemoryEntry( + id=str(point.id), + content=point.payload.get("content", ""), + metadata=point.payload.get("metadata", {}), + created_at=point.payload.get("created_at"), + updated_at=point.payload.get("updated_at"), + ) + memories.append(memory) + + memory_logger.info( + f"Found {len(memories)} recent memories since {since_timestamp} for user {user_id}" + ) + return memories + + except Exception as e: + memory_logger.error(f"Qdrant get recent memories failed: {e}") + return [] diff --git a/backends/advanced/src/advanced_omi_backend/services/plugin_service.py b/backends/advanced/src/advanced_omi_backend/services/plugin_service.py index 2a69e860..e71422f8 100644 --- a/backends/advanced/src/advanced_omi_backend/services/plugin_service.py +++ b/backends/advanced/src/advanced_omi_backend/services/plugin_service.py @@ -9,6 +9,7 @@ import logging import os import re +import sys from pathlib import Path from typing import Any, Dict, List, Optional, Type @@ -16,6 +17,7 @@ from advanced_omi_backend.config_loader import get_plugins_yml_path from advanced_omi_backend.plugins import BasePlugin, PluginRouter +from advanced_omi_backend.plugins.services import PluginServices logger = logging.getLogger(__name__) @@ -23,6 +25,22 @@ _plugin_router: Optional[PluginRouter] = None +def _get_plugins_dir() -> Path: + """Get external plugins directory. + + Priority: PLUGINS_DIR env var > Docker path > local dev path. + """ + env_dir = os.getenv("PLUGINS_DIR") + if env_dir: + return Path(env_dir) + docker_path = Path("/app/plugins") + if docker_path.is_dir(): + return docker_path + # Local dev: plugin_service.py is at /backends/advanced/src/advanced_omi_backend/services/ + repo_root = Path(__file__).resolve().parents[5] + return repo_root / "plugins" + + def expand_env_vars(value: Any) -> Any: """ Recursively expand environment variables in configuration values. @@ -105,9 +123,7 @@ def load_plugin_config(plugin_id: str, orchestration_config: Dict[str, Any]) -> # 1. 
Load plugin-specific config.yml if it exists try: - import advanced_omi_backend.plugins - - plugins_dir = Path(advanced_omi_backend.plugins.__file__).parent + plugins_dir = _get_plugins_dir() plugin_config_path = plugins_dir / plugin_id / "config.yml" if plugin_config_path.exists(): @@ -284,9 +300,7 @@ def load_schema_yml(plugin_id: str) -> Optional[Dict[str, Any]]: Schema dictionary if schema.yml exists, None otherwise """ try: - import advanced_omi_backend.plugins - - plugins_dir = Path(advanced_omi_backend.plugins.__file__).parent + plugins_dir = _get_plugins_dir() schema_path = plugins_dir / plugin_id / "schema.yml" if schema_path.exists(): @@ -396,9 +410,7 @@ def get_plugin_metadata( """ # Load plugin config.yml try: - import advanced_omi_backend.plugins - - plugins_dir = Path(advanced_omi_backend.plugins.__file__).parent + plugins_dir = _get_plugins_dir() plugin_config_path = plugins_dir / plugin_id / "config.yml" config_dict = {} @@ -469,23 +481,21 @@ def discover_plugins() -> Dict[str, Type[BasePlugin]]: """ discovered_plugins = {} - # Get the plugins directory path - try: - import advanced_omi_backend.plugins - - plugins_dir = Path(advanced_omi_backend.plugins.__file__).parent - except Exception as e: - logger.error(f"Failed to locate plugins directory: {e}") + plugins_dir = _get_plugins_dir() + if not plugins_dir.is_dir(): + logger.warning(f"Plugins directory not found: {plugins_dir}") return discovered_plugins - logger.info(f"🔍 Scanning for plugins in: {plugins_dir}") + # Add plugins dir to sys.path so plugin packages can be imported directly + plugins_dir_str = str(plugins_dir) + if plugins_dir_str not in sys.path: + sys.path.insert(0, plugins_dir_str) - # Skip these known system directories/files - skip_items = {"__pycache__", "__init__.py", "base.py", "router.py"} + logger.info(f"Scanning for plugins in: {plugins_dir}") - # Scan for plugin directories + # Scan for plugin directories (skip hidden/underscore dirs) for item in plugins_dir.iterdir(): - if not item.is_dir() or item.name in skip_items: + if not item.is_dir() or item.name.startswith("_"): continue plugin_id = item.name @@ -500,12 +510,9 @@ def discover_plugins() -> Dict[str, Type[BasePlugin]]: # e.g., email_summarizer -> EmailSummarizerPlugin class_name = "".join(word.capitalize() for word in plugin_id.split("_")) + "Plugin" - # Import the plugin module - module_path = f"advanced_omi_backend.plugins.{plugin_id}" - logger.debug(f"Attempting to import plugin from: {module_path}") - - # Import the plugin package (which should export the class in __init__.py) - plugin_module = importlib.import_module(module_path) + # Import the plugin package directly (it's on sys.path now) + logger.debug(f"Attempting to import plugin: {plugin_id}") + plugin_module = importlib.import_module(plugin_id) # Try to get the plugin class if not hasattr(plugin_module, class_name): @@ -530,14 +537,14 @@ def discover_plugins() -> Dict[str, Type[BasePlugin]]: # Successfully discovered plugin discovered_plugins[plugin_id] = plugin_class - logger.info(f"✅ Discovered plugin: '{plugin_id}' ({class_name})") + logger.info(f"Discovered plugin: '{plugin_id}' ({class_name})") except ImportError as e: logger.warning(f"Failed to import plugin '{plugin_id}': {e}") except Exception as e: logger.error(f"Error discovering plugin '{plugin_id}': {e}", exc_info=True) - logger.info(f"🎉 Plugin discovery complete: {len(discovered_plugins)} plugin(s) found") + logger.info(f"Plugin discovery complete: {len(discovered_plugins)} plugin(s) found") return discovered_plugins 
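Given the discovery contract above — the external `plugins/` directory is put on `sys.path`, each package is imported by its directory name, and the class is resolved as the CamelCase of that name plus `Plugin` — a minimal discoverable plugin might look like the sketch below. The hook name `on_transcript_streaming` and the `PluginResult(success=True)` call are assumptions drawn from the guide, not confirmed by this diff:

```python
# plugins/hello_world/__init__.py  (hypothetical example plugin)
from advanced_omi_backend.plugins.base import BasePlugin, PluginContext, PluginResult


class HelloWorldPlugin(BasePlugin):
    """Directory 'hello_world' -> expected class name 'HelloWorldPlugin'."""

    async def on_transcript_streaming(self, context: PluginContext) -> PluginResult:
        # Log the streaming transcript chunk dispatched by the plugin router.
        print(context.data.get("transcript", ""))
        return PluginResult(success=True)  # assumed constructor
```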
@@ -577,9 +584,6 @@ def init_plugin_router() -> Optional[PluginRouter]: # Discover all plugins via auto-discovery discovered_plugins = discover_plugins() - # Core plugin names (for informational logging only) - CORE_PLUGIN_NAMES = {"homeassistant", "test_event"} - # Initialize each plugin listed in config/plugins.yml for plugin_id, orchestration_config in plugins_data.items(): logger.info( @@ -602,7 +606,6 @@ def init_plugin_router() -> Optional[PluginRouter]: # Get plugin class from discovered plugins plugin_class = discovered_plugins[plugin_id] - plugin_type = "core" if plugin_id in CORE_PLUGIN_NAMES else "community" # Instantiate and register the plugin plugin = plugin_class(plugin_config) @@ -616,7 +619,7 @@ def init_plugin_router() -> Optional[PluginRouter]: # Note: async initialization happens in app_factory lifespan _plugin_router.register_plugin(plugin_id, plugin) - logger.info(f"✅ Plugin '{plugin_id}' registered successfully ({plugin_type})") + logger.info(f"Plugin '{plugin_id}' registered successfully") except Exception as e: logger.error(f"Failed to register plugin '{plugin_id}': {e}", exc_info=True) @@ -627,6 +630,11 @@ def init_plugin_router() -> Optional[PluginRouter]: else: logger.info("No plugins.yml found, plugins disabled") + # Attach PluginServices for cross-plugin and system interaction + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + services = PluginServices(router=_plugin_router, redis_url=redis_url) + _plugin_router.set_services(services) + return _plugin_router except Exception as e: diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py index 71b213b8..48637b13 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py @@ -17,6 +17,7 @@ from advanced_omi_backend.config_loader import get_backend_config from advanced_omi_backend.model_registry import get_models_registry +from advanced_omi_backend.prompt_registry import get_prompt_registry from .base import ( BaseTranscriptionProvider, @@ -27,6 +28,26 @@ logger = logging.getLogger(__name__) +def _parse_hot_words_to_keyterm(hot_words_str: str) -> str: + """Convert comma-separated hot words to Deepgram keyterm format. + + Input: "hey vivi, chronicle, omi" + Output: "hey vivi Hey Vivi chronicle Chronicle omi Omi" + """ + if not hot_words_str or not hot_words_str.strip(): + return "" + terms = [] + for word in hot_words_str.split(","): + word = word.strip() + if not word: + continue + terms.append(word) + capitalized = word.title() + if capitalized != word: + terms.append(capitalized) + return " ".join(terms) + + def _dotted_get(d: dict | list | None, dotted: Optional[str]): """Safely extract a value from nested dict/list using dotted paths. 
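
    Example (illustrative): _dotted_get({"results": {"text": "hi"}}, "results.text") -> "hi"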
@@ -99,7 +120,7 @@ def get_capabilities_dict(self) -> dict: """ return {cap: True for cap in self._capabilities} - async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False) -> dict: + async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False, context_info: Optional[str] = None, **kwargs) -> dict: # Special handling for mock provider (no HTTP server needed) if self.model.model_provider == "mock": from .mock_provider import MockTranscriptionProvider @@ -120,7 +141,13 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = # Build headers (skip Content-Type for multipart as httpx will set it) headers = {} if not use_multipart: - headers["Content-Type"] = "audio/raw" + # Auto-detect WAV format from RIFF header and use correct Content-Type. + # Sending WAV data as audio/raw can cause Deepgram to silently return + # empty transcripts because it tries to decode the WAV header as raw PCM. + if audio_data[:4] == b"RIFF": + headers["Content-Type"] = "audio/wav" + else: + headers["Content-Type"] = "audio/raw" if self.model.api_key: # Allow templated header, otherwise fallback to Bearer/Token conventions by config @@ -148,14 +175,34 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = if "diarize" in query: query["diarize"] = "true" if diarize else "false" + # Use caller-provided context or fall back to LangFuse prompt store + if context_info: + hot_words_str = context_info + else: + hot_words_str = "" + try: + registry = get_prompt_registry() + hot_words_str = await registry.get_prompt("asr.hot_words") + except Exception as e: + logger.debug(f"Failed to fetch asr.hot_words prompt: {e}") + + # For Deepgram: inject as keyterm query param + if self.model.model_provider == "deepgram" and hot_words_str.strip(): + keyterm = _parse_hot_words_to_keyterm(hot_words_str) + if keyterm: + query["keyterm"] = keyterm + timeout = op.get("timeout", 300) try: async with httpx.AsyncClient(timeout=timeout) as client: if method == "POST": if use_multipart: - # Send as multipart file upload (for Parakeet) + # Send as multipart file upload (for Parakeet/VibeVoice) files = {"file": ("audio.wav", audio_data, "audio/wav")} - resp = await client.post(url, headers=headers, params=query, files=files) + data = {} + if hot_words_str and hot_words_str.strip(): + data["context_info"] = hot_words_str.strip() + resp = await client.post(url, headers=headers, params=query, files=files, data=data) else: # Send as raw audio data (for Deepgram) resp = await client.post(url, headers=headers, params=query, content=audio_data) @@ -240,6 +287,18 @@ async def start_stream(self, client_id: str, sample_rate: int = 16000, diarize: if diarize and "diarize" in query_dict: query_dict["diarize"] = "true" + # Inject hot words for streaming (Deepgram only) + if self.model.model_provider == "deepgram": + try: + registry = get_prompt_registry() + hot_words_str = await registry.get_prompt("asr.hot_words") + if hot_words_str and hot_words_str.strip(): + keyterm = _parse_hot_words_to_keyterm(hot_words_str) + if keyterm: + query_dict["keyterm"] = keyterm + except Exception as e: + logger.debug(f"Failed to fetch asr.hot_words for streaming: {e}") + # Normalize boolean values to lowercase strings (Deepgram expects "true"/"false", not "True"/"False") normalized_query = {} for k, v in query_dict.items(): diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/base.py 
b/backends/advanced/src/advanced_omi_backend/services/transcription/base.py index 7d0f2306..bc5cf6f7 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/base.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/base.py @@ -122,12 +122,14 @@ def mode(self) -> str: return "batch" @abc.abstractmethod - async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False) -> dict: + async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False, context_info: Optional[str] = None, **kwargs) -> dict: """Transcribe audio data. Args: audio_data: Raw audio bytes sample_rate: Audio sample rate diarize: Whether to enable speaker diarization (provider-dependent) + context_info: Optional ASR context (hot words, jargon) to boost recognition + **kwargs: Additional parameters (e.g. Langfuse trace IDs) """ pass diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/context.py b/backends/advanced/src/advanced_omi_backend/services/transcription/context.py new file mode 100644 index 00000000..5eda20fd --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/context.py @@ -0,0 +1,94 @@ +""" +ASR context builder for transcription. + +Combines static hot words from the prompt registry with per-user dynamic +jargon cached in Redis by the ``asr_jargon_extraction`` cron job. +""" + +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import redis.asyncio as aioredis + +logger = logging.getLogger(__name__) + +REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") + + +@dataclass +class TranscriptionContext: + """Structured context gathered before transcription. + + Holds the individual context components so callers can inspect them + and Langfuse spans can log them as structured metadata. + """ + + hot_words: str = "" + user_jargon: str = "" + user_id: Optional[str] = None + + @property + def combined(self) -> str: + """Comma-separated string suitable for passing to ASR providers.""" + parts = [p.strip() for p in [self.hot_words, self.user_jargon] if p and p.strip()] + return ", ".join(parts) + + def to_metadata(self) -> dict: + """Return a dict suitable for Langfuse span metadata.""" + return { + "hot_words": self.hot_words[:200] if self.hot_words else "", + "user_jargon": self.user_jargon[:200] if self.user_jargon else "", + "user_id": self.user_id, + "combined_length": len(self.combined), + } + + +async def gather_transcription_context(user_id: Optional[str] = None) -> TranscriptionContext: + """Build structured transcription context: static hot words + cached user jargon. + + Args: + user_id: If provided, also look up per-user jargon from Redis. + + Returns: + TranscriptionContext with individual components. 
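+
+    Example (illustrative):
+
+        ctx = await gather_transcription_context(user_id="abc123")
+        asr_context = ctx.combined  # e.g. "hey vivi, chronicle, kubernetes"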
+ """ + from advanced_omi_backend.prompt_registry import get_prompt_registry + + registry = get_prompt_registry() + try: + hot_words = await registry.get_prompt("asr.hot_words") + except Exception: + logger.debug("Failed to fetch asr.hot_words prompt, using empty default") + hot_words = "" + + user_jargon = "" + if user_id: + try: + redis_client = aioredis.from_url(REDIS_URL, decode_responses=True) + try: + user_jargon = await redis_client.get(f"asr:jargon:{user_id}") or "" + finally: + await redis_client.close() + except Exception: + pass # Redis unavailable → skip dynamic jargon + + return TranscriptionContext( + hot_words=hot_words or "", + user_jargon=user_jargon, + user_id=user_id, + ) + + +async def get_asr_context(user_id: Optional[str] = None) -> str: + """Build combined ASR context string (backward-compatible alias). + + Args: + user_id: If provided, also look up per-user jargon from Redis. + + Returns: + Comma-separated string of context terms for the ASR provider. + """ + ctx = await gather_transcription_context(user_id) + return ctx.combined diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py b/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py index 052680d2..749558ce 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py @@ -19,6 +19,8 @@ import redis.asyncio as redis from redis import exceptions as redis_exceptions +from advanced_omi_backend.plugins.events import PluginEvent + from advanced_omi_backend.client_manager import get_client_owner_async from advanced_omi_backend.plugins.router import PluginRouter from advanced_omi_backend.services.transcription import get_transcription_provider @@ -357,7 +359,7 @@ async def trigger_plugins(self, session_id: str, result: Dict): logger.info(f"🎯 Dispatching transcript.streaming event for user {user_id}, transcript: {plugin_data['transcript'][:50]}...") plugin_results = await self.plugin_router.dispatch_event( - event='transcript.streaming', + event=PluginEvent.TRANSCRIPT_STREAMING, user_id=user_id, data=plugin_data, metadata={'client_id': session_id} @@ -387,8 +389,20 @@ async def process_stream(self, stream_name: str): "started_at": time.time() } + # Read actual sample rate from the session's audio_format stored in Redis + sample_rate = 16000 + session_key = f"audio:session:{session_id}" + try: + audio_format_raw = await self.redis_client.hget(session_key, "audio_format") + if audio_format_raw: + audio_format = json.loads(audio_format_raw) + sample_rate = int(audio_format.get("rate", 16000)) + logger.info(f"📊 Read sample rate {sample_rate}Hz from session {session_id}") + except Exception as e: + logger.warning(f"Failed to read audio_format from Redis for {session_id}: {e}") + # Start WebSocket connection to Deepgram - await self.start_session_stream(session_id) + await self.start_session_stream(session_id, sample_rate=sample_rate) last_id = "0" # Start from beginning stream_ended = False diff --git a/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py b/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py index 7c14cccd..ea8510fe 100644 --- a/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py +++ b/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py @@ -5,6 +5,10 @@ to enhance transcripts with actual speaker names instead of 
generic labels. Configuration is managed via config.yml (speaker_recognition section). + +NOTE: user_id is currently hardcoded to "1" throughout this client because only +a single admin user is supported at this time. Update when multi-user support +is implemented. """ import asyncio @@ -177,7 +181,7 @@ async def diarize_identify_match( form_data.add_field("transcript_data", json.dumps(transcript_data)) form_data.add_field("user_id", "1") # TODO: Implement proper user mapping - form_data.add_field("similarity_threshold", str(config.get("similarity_threshold", 0.15))) + form_data.add_field("similarity_threshold", str(config.get("similarity_threshold", 0.45))) form_data.add_field("min_duration", str(config.get("min_duration", 0.5))) # Use /v1/diarize-identify-match endpoint as fallback @@ -190,7 +194,7 @@ async def diarize_identify_match( # Send existing transcript for diarization and speaker matching form_data.add_field("transcript_data", json.dumps(transcript_data)) form_data.add_field("user_id", "1") # TODO: Implement proper user mapping - form_data.add_field("similarity_threshold", str(config.get("similarity_threshold", 0.15))) + form_data.add_field("similarity_threshold", str(config.get("similarity_threshold", 0.45))) # Add pyannote diarization parameters form_data.add_field("min_duration", str(config.get("min_duration", 0.5))) @@ -274,8 +278,10 @@ async def identify_segment( form_data.add_field( "file", audio_wav_bytes, filename="segment.wav", content_type="audio/wav" ) + # TODO: Implement proper user mapping between MongoDB ObjectIds and speaker service integer IDs + # Speaker service expects integer user_id, not MongoDB ObjectId strings if user_id is not None: - form_data.add_field("user_id", str(user_id)) + form_data.add_field("user_id", "1") if similarity_threshold is not None: form_data.add_field("similarity_threshold", str(similarity_threshold)) @@ -309,24 +315,32 @@ async def identify_provider_segments( conversation_id: str, segments: List[Dict], user_id: Optional[str] = None, + per_segment: bool = False, + min_segment_duration: float = 1.5, ) -> Dict: """ - Identify speakers in provider-diarized segments using majority-vote per label. + Identify speakers in provider-diarized segments. + + Default mode: majority-vote per label. Picks top 3 longest segments per label, + identifies each, and majority-votes to map labels to names. - For each unique speaker label, picks the top 3 longest segments (min 1.5s), - extracts audio, calls /identify, and majority-votes to map labels to names. + Per-segment mode (per_segment=True): identifies every segment individually. + Used during reprocessing so that fine-tuned embeddings benefit each segment. 
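+
+        Example (illustrative):
+
+            result = await client.identify_provider_segments(
+                conversation_id, segments, user_id=user_id, per_segment=True
+            )
+            # Each result["segments"][i] carries "identified_as" (a name or None)
+            # and a "status" of identified / unknown / too_short / non_speech.
+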
Args: conversation_id: Conversation ID for audio extraction from MongoDB segments: List of dicts with keys: start, end, text, speaker user_id: Optional user ID for speaker identification + per_segment: If True, identify each segment individually instead of majority-vote + min_segment_duration: Minimum segment duration in seconds for identification Returns: Dict with 'segments' list matching diarize_identify_match() format """ if hasattr(self, "_mock_client"): return await self._mock_client.identify_provider_segments( - conversation_id, segments, user_id + conversation_id, segments, user_id, + per_segment=per_segment, min_segment_duration=min_segment_duration, ) if not self.enabled: @@ -338,9 +352,8 @@ async def identify_provider_segments( ) config = get_diarization_settings() - similarity_threshold = config.get("similarity_threshold", 0.15) + similarity_threshold = config.get("similarity_threshold", 0.45) - MIN_SEGMENT_DURATION = 1.5 MAX_SAMPLES_PER_LABEL = 3 # Detect non-speech segments (e.g. [Music], [Environmental Sounds], [Human Sounds]) @@ -379,14 +392,26 @@ def _is_non_speech(seg: Dict) -> bool: f"{len(label_groups)} unique labels: {list(label_groups.keys())}" ) - # For each label, pick top N longest segments >= MIN_SEGMENT_DURATION + # Per-segment mode: identify every segment individually (used during reprocess) + if per_segment: + return await self._identify_per_segment( + conversation_id=conversation_id, + segments=segments, + speech_segments=speech_segments, + non_speech_indices=non_speech_indices, + user_id=user_id, + similarity_threshold=similarity_threshold, + min_segment_duration=min_segment_duration, + ) + + # For each label, pick top N longest segments >= min_segment_duration label_samples: Dict[str, List[Dict]] = {} for label, segs in label_groups.items(): - eligible = [s for s in segs if (s["end"] - s["start"]) >= MIN_SEGMENT_DURATION] + eligible = [s for s in segs if (s["end"] - s["start"]) >= min_segment_duration] eligible.sort(key=lambda s: s["end"] - s["start"], reverse=True) label_samples[label] = eligible[:MAX_SAMPLES_PER_LABEL] if not label_samples[label]: - logger.info(f"🎤 Label '{label}': no segments >= {MIN_SEGMENT_DURATION}s, skipping identification") + logger.info(f"🎤 Label '{label}': no segments >= {min_segment_duration}s, skipping identification") # Extract audio and identify concurrently with semaphore semaphore = asyncio.Semaphore(3) @@ -398,7 +423,7 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: conversation_id, seg["start"], seg["end"] ) result = await self.identify_segment( - wav_bytes, user_id="1", similarity_threshold=similarity_threshold + wav_bytes, user_id=user_id, similarity_threshold=similarity_threshold ) return result except Exception as e: @@ -471,7 +496,7 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: "end": seg["end"], "text": seg.get("text", ""), "speaker": label, - "identified_as": mapped[0] if mapped else label, + "identified_as": mapped[0] if mapped else None, "confidence": mapped[1] if mapped else 0.0, "status": "identified" if mapped else "unknown", }) @@ -484,6 +509,150 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: return {"segments": result_segments} + async def _identify_per_segment( + self, + conversation_id: str, + segments: List[Dict], + speech_segments: List[Dict], + non_speech_indices: set, + user_id: Optional[str], + similarity_threshold: float, + min_segment_duration: float, + ) -> Dict: + """ + Identify every speech segment individually (no majority vote). 
+ + Used during reprocessing so that fine-tuned speaker embeddings + benefit each segment directly. + + Args: + conversation_id: Conversation ID for audio extraction + segments: All segments (speech + non-speech) in original order + speech_segments: Only the speech segments + non_speech_indices: Indices of non-speech segments in the original list + user_id: User ID for speaker identification + similarity_threshold: Similarity threshold for identification + min_segment_duration: Minimum duration for identification attempt + + Returns: + Dict with 'segments' list matching diarize_identify_match() format + """ + from advanced_omi_backend.utils.audio_chunk_utils import ( + reconstruct_audio_segment, + ) + + logger.info( + f"🎤 Per-segment identification: {len(speech_segments)} speech segments " + f"(min_duration={min_segment_duration}s)" + ) + + semaphore = asyncio.Semaphore(3) + + async def _identify_one(seg: Dict) -> Optional[Dict]: + async with semaphore: + try: + wav_bytes = await reconstruct_audio_segment( + conversation_id, seg["start"], seg["end"] + ) + return await self.identify_segment( + wav_bytes, user_id=user_id, similarity_threshold=similarity_threshold + ) + except Exception as e: + logger.warning( + f"🎤 Failed to identify segment [{seg['start']:.1f}-{seg['end']:.1f}]: {e}" + ) + return None + + # Build tasks for speech segments that meet the duration threshold + seg_tasks: List[tuple] = [] # (original_index, task_or_None) + all_tasks = [] + for i, seg in enumerate(segments): + if i in non_speech_indices: + seg_tasks.append((i, None)) + continue + duration = seg["end"] - seg["start"] + if duration >= min_segment_duration: + task = asyncio.create_task(_identify_one(seg)) + seg_tasks.append((i, task)) + all_tasks.append(task) + else: + seg_tasks.append((i, None)) # too short + + if all_tasks: + await asyncio.gather(*all_tasks, return_exceptions=True) + + # Build result segments + result_segments = [] + identified_count = 0 + for i, seg in enumerate(segments): + label = seg.get("speaker", "Unknown") + + if i in non_speech_indices: + result_segments.append({ + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": label, + "confidence": 0.0, + "status": "non_speech", + }) + continue + + # Find the matching task entry + task_entry = seg_tasks[i] + task = task_entry[1] + + if task is None: + # Too short for identification + result_segments.append({ + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": None, + "confidence": 0.0, + "status": "too_short", + }) + continue + + try: + result = task.result() + except Exception: + result = None + + if result and result.get("found"): + name = result.get("speaker_name", label) + confidence = result.get("confidence", 0.0) + result_segments.append({ + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": name, + "confidence": confidence, + "status": "identified", + }) + identified_count += 1 + else: + result_segments.append({ + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": None, + "confidence": 0.0, + "status": "unknown", + }) + + logger.info( + f"🎤 Per-segment identification complete: " + f"{identified_count}/{len(speech_segments)} segments identified, " + f"{len(result_segments)} total segments" + ) + + return {"segments": result_segments} + async def diarize_and_identify( self, audio_data: bytes, 
words: None, user_id: Optional[str] = None # NOT IMPLEMENTED ) -> Dict: @@ -531,7 +700,7 @@ async def diarize_and_identify( # Add all diarization parameters for the diarize-and-identify endpoint min_duration = diarization_settings.get("min_duration", 0.5) - similarity_threshold = diarization_settings.get("similarity_threshold", 0.15) + similarity_threshold = diarization_settings.get("similarity_threshold", 0.45) collar = diarization_settings.get("collar", 2.0) min_duration_off = diarization_settings.get("min_duration_off", 1.5) @@ -660,7 +829,7 @@ async def identify_speakers(self, audio_path: str, segments: List[Dict]) -> Dict # Add all diarization parameters for the diarize-and-identify endpoint form_data.add_field("min_duration", str(_diarization_settings.get("min_duration", 0.5))) - form_data.add_field("similarity_threshold", str(_diarization_settings.get("similarity_threshold", 0.15))) + form_data.add_field("similarity_threshold", str(_diarization_settings.get("similarity_threshold", 0.45))) form_data.add_field("collar", str(_diarization_settings.get("collar", 2.0))) form_data.add_field("min_duration_off", str(_diarization_settings.get("min_duration_off", 1.5))) if _diarization_settings.get("min_speakers"): diff --git a/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py b/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py index 8ba68adf..0e9f4cae 100644 --- a/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py +++ b/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py @@ -181,6 +181,8 @@ async def identify_provider_segments( conversation_id: str, segments: List[Dict], user_id: Optional[str] = None, + per_segment: bool = False, + min_segment_duration: float = 1.5, ) -> Dict: """Mock identify_provider_segments - returns segments with original labels.""" logger.info(f"🎤 Mock identify_provider_segments: {len(segments)} segments") diff --git a/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py b/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py index 5b5fa992..4abb1d5d 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py @@ -27,6 +27,9 @@ MIN_SPEECH_SEGMENT_DURATION = float(os.getenv("MIN_SPEECH_SEGMENT_DURATION", "1.0")) # seconds CROPPING_CONTEXT_PADDING = float(os.getenv("CROPPING_CONTEXT_PADDING", "0.1")) # seconds +SUPPORTED_AUDIO_EXTENSIONS = {".wav", ".mp3", ".mp4", ".m4a", ".flac", ".ogg", ".webm"} +VIDEO_EXTENSIONS = {".mp4", ".webm"} + class AudioValidationError(Exception): """Exception raised when audio validation fails.""" @@ -107,6 +110,59 @@ async def resample_audio_with_ffmpeg( return stdout +async def convert_any_to_wav(file_data: bytes, file_extension: str) -> bytes: + """ + Convert any supported audio/video file to 16kHz mono WAV using FFmpeg. + + For .wav input, returns the data as-is. + For everything else, runs FFmpeg to extract audio and convert to WAV. + + Args: + file_data: Raw file bytes + file_extension: File extension including dot (e.g. 
".mp3", ".mp4") + + Returns: + WAV file bytes (16kHz, mono, 16-bit PCM) + + Raises: + AudioValidationError: If FFmpeg conversion fails + """ + ext = file_extension.lower() + if ext == ".wav": + return file_data + + cmd = [ + "ffmpeg", + "-i", "pipe:0", + "-vn", # Strip video track (no-op for audio-only files) + "-acodec", "pcm_s16le", + "-ar", "16000", + "-ac", "1", + "-f", "wav", + "pipe:1", + ] + + process = await asyncio.create_subprocess_exec( + *cmd, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + stdout, stderr = await process.communicate(input=file_data) + + if process.returncode != 0: + error_msg = stderr.decode() if stderr else "Unknown error" + audio_logger.error(f"FFmpeg conversion failed for {ext}: {error_msg}") + raise AudioValidationError(f"Failed to convert {ext} file to WAV: {error_msg}") + + audio_logger.info( + f"Converted {ext} to WAV: {len(file_data)} → {len(stdout)} bytes" + ) + + return stdout + + async def validate_and_prepare_audio( audio_data: bytes, expected_sample_rate: int = 16000, @@ -207,6 +263,9 @@ async def write_audio_file( timestamp: int, chunk_dir: Optional[Path] = None, validate: bool = True, + pcm_sample_rate: int = 16000, + pcm_channels: int = 1, + pcm_sample_width: int = 2, ) -> tuple[str, str, float]: """ Validate, write audio data to WAV file, and create AudioSession database entry. @@ -223,6 +282,9 @@ async def write_audio_file( timestamp: Timestamp in milliseconds chunk_dir: Optional directory path (defaults to CHUNK_DIR from config) validate: Whether to validate and prepare audio (default: True for uploads, False for WebSocket) + pcm_sample_rate: Sample rate for raw PCM data when validate=False (default: 16000) + pcm_channels: Channel count for raw PCM data when validate=False (default: 1) + pcm_sample_width: Sample width in bytes for raw PCM data when validate=False (default: 2) Returns: Tuple of (relative_audio_path, absolute_file_path, duration) @@ -242,11 +304,11 @@ async def write_audio_file( audio_data, sample_rate, sample_width, channels, duration = \ await validate_and_prepare_audio(raw_audio_data) else: - # For WebSocket path - audio is already processed PCM + # For WebSocket/streaming path - audio is already processed PCM audio_data = raw_audio_data - sample_rate = 16000 # WebSocket always uses 16kHz - sample_width = 2 - channels = 1 + sample_rate = pcm_sample_rate + sample_width = pcm_sample_width + channels = pcm_channels duration = len(audio_data) / (sample_rate * sample_width * channels) # Use provided chunk_dir or default from config diff --git a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py index 89991327..4ceaee51 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py @@ -159,65 +159,20 @@ def analyze_speech(transcript_data: dict) -> dict: } -async def generate_title(text: str, segments: Optional[list] = None) -> str: +async def generate_title_and_summary( + text: str, segments: Optional[list] = None +) -> tuple[str, str]: """ - Generate an LLM-powered title from conversation text. + Generate title and short summary in a single LLM call using full conversation context. Args: text: Conversation transcript (used if segments not provided) segments: Optional list of speaker segments with structure: [{"speaker": str, "text": str, "start": float, "end": float}, ...] 
- If provided, uses speaker-aware conversation formatting + If provided, uses speaker-formatted text for richer context Returns: - str: Generated title (3-6 words) or fallback - - Note: - Title intentionally does NOT include speaker names - focuses on topic/theme only. - """ - # Format conversation text from segments if provided - if segments: - conversation_text = "" - for segment in segments[:10]: # Use first 10 segments for title generation - segment_text = segment.text.strip() if segment.text else "" - if segment_text: - conversation_text += f"{segment_text}\n" - text = conversation_text if conversation_text.strip() else text - - if not text or len(text.strip()) < 10: - return "Conversation" - - try: - registry = get_prompt_registry() - prompt_template = await registry.get_prompt("conversation.title") - prompt = f"""{prompt_template} - -"{text[:500]}" -""" - - title = await async_generate(prompt, temperature=0.3) - return title.strip().strip('"').strip("'") or "Conversation" - - except Exception as e: - logger.warning(f"Failed to generate LLM title: {e}") - # Fallback to simple title generation - words = text.split()[:6] - title = " ".join(words) - return title[:40] + "..." if len(title) > 40 else title or "Conversation" - - -async def generate_short_summary(text: str, segments: Optional[list] = None) -> str: - """ - Generate a brief LLM-powered summary from conversation text. - - Args: - text: Conversation transcript (used if segments not provided) - segments: Optional list of speaker segments with structure: - [{"speaker": str, "text": str, "start": float, "end": float}, ...] - If provided, includes speaker context in summary - - Returns: - str: Generated short summary (1-2 sentences, max 120 chars) or fallback + Tuple of (title, short_summary) """ # Format conversation text from segments if provided conversation_text = text @@ -241,37 +196,52 @@ async def generate_short_summary(text: str, segments: Optional[list] = None) -> include_speakers = len(speakers_in_conv) > 0 if not conversation_text or len(conversation_text.strip()) < 10: - return "No content" + return "Conversation", "No content" try: speaker_instruction = ( - '- Include speaker names when relevant (e.g., "John discusses X with Sarah")\n' + '- Include speaker names when relevant in the summary (e.g., "John discusses X with Sarah")\n' if include_speakers else "" ) registry = get_prompt_registry() prompt_text = await registry.get_prompt( - "conversation.short_summary", + "conversation.title_summary", speaker_instruction=speaker_instruction, ) prompt = f"""{prompt_text} -"{conversation_text[:1000]}" +TRANSCRIPT: +"{conversation_text}" """ - summary = await async_generate(prompt, temperature=0.3) - return summary.strip().strip('"').strip("'") or "No content" + response = await async_generate(prompt, temperature=0.3) + + # Parse response for Title: and Summary: lines + title = None + summary = None + for line in response.strip().split("\n"): + line = line.strip() + if line.startswith("Title:"): + title = line.replace("Title:", "").strip().strip('"').strip("'") + elif line.startswith("Summary:"): + summary = line.replace("Summary:", "").strip().strip('"').strip("'") + + title = title or "Conversation" + summary = summary or "No content" + + return title, summary except Exception as e: - logger.warning(f"Failed to generate LLM short summary: {e}") - # Fallback to simple summary generation - return ( - conversation_text[:120] + "..." 
- if len(conversation_text) > 120 - else conversation_text or "No content" - ) + logger.warning(f"Failed to generate title and summary: {e}") + # Fallback + words = text.split()[:6] + fallback_title = " ".join(words) + fallback_title = fallback_title[:40] + "..." if len(fallback_title) > 40 else fallback_title + fallback_summary = text[:120] + "..." if len(text) > 120 else text + return fallback_title or "Conversation", fallback_summary or "No content" diff --git a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py index 18420ddf..9b5077e6 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py @@ -5,6 +5,7 @@ """ import asyncio +import json import logging import os import time @@ -20,6 +21,7 @@ ) from advanced_omi_backend.controllers.session_controller import mark_session_complete from advanced_omi_backend.models.job import async_job +from advanced_omi_backend.plugins.events import PluginEvent from advanced_omi_backend.services.plugin_service import ( ensure_plugin_router, get_plugin_router, @@ -303,6 +305,18 @@ async def open_conversation_job( conversation_id = conversation.conversation_id logger.info(f"✅ Created streaming conversation {conversation_id} for session {session_id}") + # Attach markers from Redis session (e.g., button events captured during streaming) + session_key = f"audio:session:{session_id}" + markers_json = await redis_client.hget(session_key, "markers") + if markers_json: + try: + markers_data = markers_json if isinstance(markers_json, str) else markers_json.decode() + conversation.markers = json.loads(markers_data) + await conversation.save() + logger.info(f"📌 Attached {len(conversation.markers)} markers to conversation {conversation_id}") + except Exception as marker_err: + logger.warning(f"⚠️ Failed to parse markers from Redis: {marker_err}") + # Link job metadata to conversation (cascading updates) current_job.meta["conversation_id"] = conversation_id current_job.save_meta() @@ -361,6 +375,7 @@ async def open_conversation_job( 0.0 # Initialize with audio time 0 (will be updated with first speech) ) timeout_triggered = False # Track if closure was due to timeout + close_requested_reason = None # Track if closure was requested via API/plugin/button last_inactivity_log_time = ( time.time() ) # Track when we last logged inactivity (wall-clock for logging) @@ -410,6 +425,17 @@ async def open_conversation_job( ) break # Exit immediately when finalize signal received + # Check for conversation close request (set by API, plugins, button press) + if not finalize_received: + close_reason = await redis_client.hget(session_key, "conversation_close_requested") + if close_reason: + await redis_client.hdel(session_key, "conversation_close_requested") + close_requested_reason = close_reason.decode() if isinstance(close_reason, bytes) else close_reason + logger.info(f"🔒 Conversation close requested: {close_requested_reason}") + timeout_triggered = True # Session stays active (same restart behavior as inactivity timeout) + finalize_received = True + break + # Check max runtime timeout if time.time() - start_time > max_runtime: logger.warning(f"⏱️ Max runtime reached for {conversation_id}") @@ -564,7 +590,7 @@ async def open_conversation_job( ) plugin_results = await plugin_router.dispatch_event( - event="transcript.streaming", + event=PluginEvent.TRANSCRIPT_STREAMING, 
user_id=user_id, data=plugin_data, metadata={"client_id": client_id}, @@ -602,12 +628,16 @@ async def open_conversation_job( # Determine end_reason with proper precedence: # 1. completion_reason from Redis (set by WebSocket controller: websocket_disconnect, user_stopped) - # 2. inactivity_timeout (no speech for SPEECH_INACTIVITY_THRESHOLD_SECONDS) - # 3. max_duration (conversation exceeded max runtime) - # 4. user_stopped (fallback for any other exit condition) + # 2. close_requested (via API, plugin, or button press) + # 3. inactivity_timeout (no speech for SPEECH_INACTIVITY_THRESHOLD_SECONDS) + # 4. max_duration (conversation exceeded max runtime) + # 5. user_stopped (fallback for any other exit condition) if completion_reason_str: end_reason = completion_reason_str logger.info(f"📊 Using completion_reason from session: {end_reason}") + elif close_requested_reason: + end_reason = "close_requested" + logger.info(f"📊 Conversation closed by request: {close_requested_reason}") elif timeout_triggered: end_reason = "inactivity_timeout" elif time.time() - start_time > max_runtime: @@ -655,30 +685,6 @@ async def open_conversation_job( f"(waited {max_wait_streaming}s), proceeding with available transcript" ) - # Wait for streaming transcription consumer to complete before reading transcript - # This fixes the race condition where conversation job reads transcript before - # streaming consumer stores all final results (seen as 24+ second delay in logs) - completion_key = f"transcription:complete:{session_id}" - max_wait_streaming = 30 # seconds - waited_streaming = 0.0 - while waited_streaming < max_wait_streaming: - completion_status = await redis_client.get(completion_key) - if completion_status: - status_str = completion_status.decode() if isinstance(completion_status, bytes) else completion_status - if status_str == "error": - logger.warning(f"⚠️ Streaming transcription ended with error for {session_id}, proceeding anyway") - else: - logger.info(f"✅ Streaming transcription confirmed complete for {session_id}") - break - await asyncio.sleep(0.5) - waited_streaming += 0.5 - - if waited_streaming >= max_wait_streaming: - logger.warning( - f"⚠️ Timed out waiting for streaming completion signal for {session_id} " - f"(waited {max_wait_streaming}s), proceeding with available transcript" - ) - # Wait for audio_streaming_persistence_job to complete and write MongoDB chunks from advanced_omi_backend.utils.audio_chunk_utils import wait_for_audio_chunks @@ -840,8 +846,7 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) from advanced_omi_backend.models.conversation import Conversation from advanced_omi_backend.utils.conversation_utils import ( generate_detailed_summary, - generate_short_summary, - generate_title, + generate_title_and_summary, ) logger.info(f"📝 Starting title/summary generation for conversation {conversation_id}") @@ -893,12 +898,11 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) except Exception as mem_error: logger.warning(f"⚠️ Could not fetch memory context (continuing without): {mem_error}") - # Generate all three summaries in parallel for efficiency + # Generate title+summary (one call) and detailed summary in parallel import asyncio - title, short_summary, detailed_summary = await asyncio.gather( - generate_title(transcript_text, segments=segments), - generate_short_summary(transcript_text, segments=segments), + (title, short_summary), detailed_summary = await asyncio.gather( + 
generate_title_and_summary(transcript_text, segments=segments), generate_detailed_summary( transcript_text, segments=segments, memory_context=memory_context ), @@ -912,8 +916,8 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) logger.info(f"✅ Generated summary: '{conversation.summary}'") logger.info(f"✅ Generated detailed summary: {len(conversation.detailed_summary)} chars") - # Update processing status for placeholder conversations - if getattr(conversation, "processing_status", None) == "pending_transcription": + # Update processing status for placeholder/reprocessing conversations + if getattr(conversation, "processing_status", None) in ["pending_transcription", "reprocessing"]: conversation.processing_status = "completed" logger.info( f"✅ Updated placeholder conversation {conversation_id} " @@ -923,8 +927,8 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) except Exception as gen_error: logger.error(f"❌ Title/summary generation failed: {gen_error}") - # Mark placeholder conversation as failed - if getattr(conversation, "processing_status", None) == "pending_transcription": + # Mark placeholder/reprocessing conversation as failed + if getattr(conversation, "processing_status", None) in ["pending_transcription", "reprocessing"]: conversation.title = "Audio Recording (Transcription Failed)" conversation.summary = f"Title/summary generation failed: {str(gen_error)}" conversation.processing_status = "transcription_failed" @@ -1082,7 +1086,7 @@ async def dispatch_conversation_complete_event_job( ) plugin_results = await plugin_router.dispatch_event( - event="conversation.complete", + event=PluginEvent.CONVERSATION_COMPLETE, user_id=user_id, data=plugin_data, metadata={"end_reason": actual_end_reason}, diff --git a/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py new file mode 100644 index 00000000..83c82bfa --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py @@ -0,0 +1,236 @@ +""" +Cron job implementations for the Chronicle scheduler. + +Jobs: + - speaker_finetuning: sends applied diarization annotations to speaker service + - asr_jargon_extraction: extracts jargon from recent memories, caches in Redis +""" + +import logging +import os +import time +from datetime import datetime, timezone +from typing import Optional + +import redis.asyncio as aioredis + +from advanced_omi_backend.llm_client import async_generate +from advanced_omi_backend.prompt_registry import get_prompt_registry + +logger = logging.getLogger(__name__) + +REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") + +# TTL for cached jargon: 2 hours (job runs every 30 min, so always refreshed) +JARGON_CACHE_TTL = 7200 + +# Maximum number of recent memories to pull per user +MAX_RECENT_MEMORIES = 50 + +# How far back to look for memories (24 hours in seconds) +MEMORY_LOOKBACK_SECONDS = 86400 + + +# --------------------------------------------------------------------------- +# Job 1: Speaker Fine-tuning +# --------------------------------------------------------------------------- + +async def run_speaker_finetuning_job() -> dict: + """Process applied diarization annotations and send to speaker recognition service. + + This mirrors the logic in ``finetuning_routes.process_annotations_for_training`` + but is invocable from the cron scheduler without an HTTP request. 
+ """ + from advanced_omi_backend.models.annotation import Annotation, AnnotationType + from advanced_omi_backend.models.conversation import Conversation + from advanced_omi_backend.speaker_recognition_client import SpeakerRecognitionClient + from advanced_omi_backend.utils.audio_chunk_utils import reconstruct_audio_segment + + # Find annotations ready for training + annotations = await Annotation.find( + Annotation.annotation_type == AnnotationType.DIARIZATION, + Annotation.processed == True, + ).to_list() + + ready_for_training = [ + a for a in annotations if not a.processed_by or "training" not in a.processed_by + ] + + if not ready_for_training: + logger.info("Speaker finetuning: no annotations ready for training") + return {"processed": 0, "message": "No annotations ready for training"} + + speaker_client = SpeakerRecognitionClient() + if not speaker_client.enabled: + logger.warning("Speaker finetuning: speaker recognition service is not enabled") + return {"processed": 0, "message": "Speaker recognition service not enabled"} + + enrolled = 0 + appended = 0 + failed = 0 + cleaned = 0 + + for annotation in ready_for_training: + try: + conversation = await Conversation.find_one( + Conversation.conversation_id == annotation.conversation_id + ) + if not conversation or not conversation.active_transcript: + logger.warning( + f"Conversation {annotation.conversation_id} not found — " + f"deleting orphaned annotation {annotation.id}" + ) + await annotation.delete() + cleaned += 1 + continue + + if annotation.segment_index >= len(conversation.active_transcript.segments): + logger.warning( + f"Invalid segment index {annotation.segment_index} for " + f"conversation {annotation.conversation_id} — " + f"deleting orphaned annotation {annotation.id}" + ) + await annotation.delete() + cleaned += 1 + continue + + segment = conversation.active_transcript.segments[annotation.segment_index] + + wav_bytes = await reconstruct_audio_segment( + conversation_id=annotation.conversation_id, + start_time=segment.start, + end_time=segment.end, + ) + if not wav_bytes: + failed += 1 + continue + + # Intentional: only single admin user (user_id=1) is supported currently + existing_speaker = await speaker_client.get_speaker_by_name( + speaker_name=annotation.corrected_speaker, + user_id=1, + ) + + if existing_speaker: + result = await speaker_client.append_to_speaker( + speaker_id=existing_speaker["id"], audio_data=wav_bytes + ) + if "error" in result: + failed += 1 + continue + appended += 1 + else: + result = await speaker_client.enroll_new_speaker( + speaker_name=annotation.corrected_speaker, + audio_data=wav_bytes, + user_id=1, + ) + if "error" in result: + failed += 1 + continue + enrolled += 1 + + # Mark as trained + annotation.processed_by = ( + f"{annotation.processed_by},training" if annotation.processed_by else "training" + ) + annotation.updated_at = datetime.now(timezone.utc) + await annotation.save() + + except Exception as e: + logger.error(f"Speaker finetuning: error processing annotation {annotation.id}: {e}") + failed += 1 + + total = enrolled + appended + logger.info( + f"Speaker finetuning complete: {total} processed " + f"({enrolled} new, {appended} appended, {failed} failed, {cleaned} orphaned cleaned)" + ) + return {"enrolled": enrolled, "appended": appended, "failed": failed, "cleaned": cleaned, "processed": total} + + +# --------------------------------------------------------------------------- +# Job 2: ASR Jargon Extraction +# 
--------------------------------------------------------------------------- + +async def run_asr_jargon_extraction_job() -> dict: + """Extract jargon from recent memories for all users and cache in Redis.""" + from advanced_omi_backend.models.user import User + + users = await User.find_all().to_list() + processed = 0 + skipped = 0 + errors = 0 + + redis_client = aioredis.from_url(REDIS_URL, decode_responses=True) + try: + for user in users: + user_id = str(user.id) + try: + jargon = await _extract_jargon_for_user(user_id) + if jargon: + await redis_client.set(f"asr:jargon:{user_id}", jargon, ex=JARGON_CACHE_TTL) + processed += 1 + logger.debug(f"Cached jargon for user {user_id}: {jargon[:80]}...") + else: + skipped += 1 + except Exception as e: + logger.error(f"Jargon extraction failed for user {user_id}: {e}") + errors += 1 + finally: + await redis_client.close() + + logger.info( + f"ASR jargon extraction complete: {processed} users processed, " + f"{skipped} skipped, {errors} errors" + ) + return {"users_processed": processed, "skipped": skipped, "errors": errors} + + +async def _extract_jargon_for_user(user_id: str) -> Optional[str]: + """Pull recent memories from Qdrant, call LLM to extract jargon terms. + + Returns a comma-separated string of jargon terms, or None if nothing found. + """ + from advanced_omi_backend.services.memory import get_memory_service + from advanced_omi_backend.services.memory.providers.chronicle import MemoryService + + memory_service = get_memory_service() + + # Only works with Chronicle provider (has Qdrant vector store) + if not isinstance(memory_service, MemoryService): + logger.debug("Jargon extraction requires Chronicle memory provider, skipping") + return None + + if memory_service.vector_store is None: + return None + + since_ts = int(time.time()) - MEMORY_LOOKBACK_SECONDS + + memories = await memory_service.vector_store.get_recent_memories( + user_id=user_id, + since_timestamp=since_ts, + limit=MAX_RECENT_MEMORIES, + ) + + if not memories: + return None + + # Concatenate memory content + memory_text = "\n".join(m.content for m in memories if m.content) + if not memory_text.strip(): + return None + + # Use LLM to extract jargon + registry = get_prompt_registry() + prompt_template = await registry.get_prompt("asr.jargon_extraction", memories=memory_text) + + result = await async_generate(prompt_template) + + # Clean up: strip whitespace, remove empty items + if result: + terms = [t.strip() for t in result.split(",") if t.strip()] + if terms: + return ", ".join(terms) + + return None diff --git a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py index 9c227bd9..8bf6d27b 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py @@ -2,18 +2,27 @@ Memory-related RQ job functions. This module contains jobs related to memory extraction and processing. + +Supports two processing pathways: +1. **Normal extraction**: Extracts fresh facts from transcript, deduplicates + against existing user memories, and proposes ADD/UPDATE/DELETE actions. +2. **Speaker reprocess**: When triggered after speaker re-identification, + computes a diff between old and new speaker labels, fetches existing + conversation-specific memories, and asks the LLM to make targeted + corrections to speaker attribution in those memories. 
""" import logging import time import uuid -from typing import Any, Dict +from typing import Any, Dict, List, Optional from advanced_omi_backend.controllers.queue_controller import ( JOB_RESULT_TTL, memory_queue, ) from advanced_omi_backend.models.job import JobPriority, async_job +from advanced_omi_backend.plugins.events import PluginEvent from advanced_omi_backend.services.plugin_service import ensure_plugin_router logger = logging.getLogger(__name__) @@ -21,6 +30,85 @@ MIN_CONVERSATION_LENGTH = 10 +def compute_speaker_diff( + old_segments: list, + new_segments: list, +) -> List[Dict[str, Any]]: + """Compare old and new transcript segments to identify speaker changes. + + Matches segments by time overlap and detects where speaker labels differ. + + Args: + old_segments: Segments from the previous transcript version + new_segments: Segments from the new (active) transcript version + + Returns: + List of change dicts, each with keys: + - ``type``: "speaker_change", "text_change", or "new_segment" + - ``text``: The segment text + - ``old_speaker`` / ``new_speaker``: For speaker changes + - ``old_text`` / ``new_text``: For text changes + - ``start`` / ``end``: Time boundaries + """ + changes: List[Dict[str, Any]] = [] + + for new_seg in new_segments: + new_start = new_seg.start + new_end = new_seg.end + + # Find best matching old segment by time overlap + best_match = None + best_overlap = 0.0 + + for old_seg in old_segments: + overlap_start = max(old_seg.start, new_start) + overlap_end = min(old_seg.end, new_end) + overlap = max(0.0, overlap_end - overlap_start) + + if overlap > best_overlap: + best_overlap = overlap + best_match = old_seg + + if best_match: + # Check for speaker change + if best_match.speaker != new_seg.speaker: + changes.append( + { + "type": "speaker_change", + "text": new_seg.text.strip(), + "old_speaker": best_match.speaker, + "new_speaker": new_seg.speaker, + "start": new_start, + "end": new_end, + } + ) + # Check for text change (less common in speaker reprocessing) + if best_match.text.strip() != new_seg.text.strip(): + changes.append( + { + "type": "text_change", + "old_text": best_match.text.strip(), + "new_text": new_seg.text.strip(), + "speaker": new_seg.speaker, + "start": new_start, + "end": new_end, + } + ) + else: + # No matching old segment found + changes.append( + { + "type": "new_segment", + "text": new_seg.text.strip(), + "speaker": new_seg.speaker, + "start": new_start, + "end": new_end, + } + ) + + return changes + + @async_job(redis=True, beanie=True) async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict[str, Any]: """ @@ -113,17 +201,42 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict ) return {"success": True, "skipped": True, "reason": "No primary speakers"} - # Process memory - memory_service = get_memory_service() - memory_result = await memory_service.add_memory( - full_conversation, - client_id, - conversation_id, - user_id, - user_email, - allow_update=True, + # Detect reprocess trigger from RQ job metadata + from rq import get_current_job as _get_current_job + + current_rq_job = _get_current_job() + trigger = ( + current_rq_job.meta.get("trigger") + if current_rq_job and current_rq_job.meta + else None ) + # Process memory — choose pathway based on trigger + memory_service = get_memory_service() + + if trigger == "reprocess_after_speaker": + # === Speaker reprocess pathway === + # Compute diff between old and new transcript versions + memory_result = await 
_process_speaker_reprocess( + memory_service=memory_service, + conversation_model=conversation_model, + full_conversation=full_conversation, + client_id=client_id, + conversation_id=conversation_id, + user_id=user_id, + user_email=user_email, + ) + else: + # === Normal extraction pathway === + memory_result = await memory_service.add_memory( + full_conversation, + client_id, + conversation_id, + user_id, + user_email, + allow_update=True, + ) + if memory_result: success, created_memory_ids = memory_result @@ -267,7 +380,7 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict ) plugin_results = await plugin_router.dispatch_event( - event="memory.processed", + event=PluginEvent.MEMORY_PROCESSED, user_id=user_id, data=plugin_data, metadata={ @@ -301,6 +414,119 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict return {"success": False, "error": "Memory service returned False"} +async def _process_speaker_reprocess( + memory_service, + conversation_model, + full_conversation: str, + client_id: str, + conversation_id: str, + user_id: str, + user_email: str, +): + """Handle memory reprocessing after speaker re-identification. + + Computes the diff between the previous and current transcript versions + (specifically speaker label changes), then delegates to the memory + service's ``reprocess_memory`` method for targeted updates. + + Falls back to normal ``add_memory`` if diff computation fails or + no meaningful changes are detected. + + Args: + memory_service: Active memory service instance + conversation_model: Conversation Beanie document + full_conversation: New transcript as dialogue lines + client_id: Client identifier + conversation_id: Conversation identifier + user_id: User identifier + user_email: User email + + Returns: + Tuple of (success, memory_ids) matching ``add_memory`` return type + """ + active_version = conversation_model.active_transcript + + if not active_version: + logger.warning( + f"🔄 Reprocess: no active transcript version for {conversation_id}, " + f"falling back to normal extraction" + ) + return await memory_service.add_memory( + full_conversation, client_id, conversation_id, user_id, user_email, + allow_update=True, + ) + + # Find the source (previous) transcript version from metadata + source_version_id = active_version.metadata.get("source_version_id") + + if not source_version_id: + logger.warning( + f"🔄 Reprocess: no source_version_id in active transcript metadata " + f"for {conversation_id}, falling back to normal extraction" + ) + return await memory_service.add_memory( + full_conversation, client_id, conversation_id, user_id, user_email, + allow_update=True, + ) + + # Find the source version's segments + source_version = None + for v in conversation_model.transcript_versions: + if v.version_id == source_version_id: + source_version = v + break + + if not source_version or not source_version.segments: + logger.warning( + f"🔄 Reprocess: source version {source_version_id} not found or has no segments " + f"for {conversation_id}, falling back to normal extraction" + ) + return await memory_service.add_memory( + full_conversation, client_id, conversation_id, user_id, user_email, + allow_update=True, + ) + + # Compute the speaker diff + transcript_diff = compute_speaker_diff( + source_version.segments, + active_version.segments, + ) + + if not transcript_diff: + logger.info( + f"🔄 Reprocess: no speaker changes detected between versions " + f"for {conversation_id}, falling back to normal extraction" + ) + 
return await memory_service.add_memory( + full_conversation, client_id, conversation_id, user_id, user_email, + allow_update=True, + ) + + # Build the previous transcript for context + previous_lines = [] + for seg in source_version.segments: + text = seg.text.strip() + if text: + previous_lines.append(f"{seg.speaker}: {text}") + previous_transcript = "\n".join(previous_lines) + + logger.info( + f"🔄 Reprocess: detected {len(transcript_diff)} changes " + f"(speakers reprocessed) for {conversation_id}" + ) + + # Use the reprocess pathway + return await memory_service.reprocess_memory( + transcript=full_conversation, + client_id=client_id, + source_id=conversation_id, + user_id=user_id, + user_email=user_email, + transcript_diff=transcript_diff, + previous_transcript=previous_transcript, + ) + + def enqueue_memory_processing( conversation_id: str, priority: JobPriority = JobPriority.NORMAL, diff --git a/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py index 8c90701e..bfe38c62 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py @@ -279,25 +279,12 @@ async def recognise_speakers_job( can_run_pyannote = bool(actual_words) and not provider_has_diarization if not actual_words and not provider_has_diarization: - # No words AND provider didn't diarize - we have a problem - # This can happen with VibeVoice if it fails to return segments - logger.warning( - f"🎤 No word timestamps available and provider didn't diarize. " - f"Speaker recognition cannot improve segments." - ) - # Keep existing segments and return success (we can't do better) - if transcript_version.segments: - return { - "success": True, - "conversation_id": conversation_id, - "version_id": version_id, - "speaker_recognition_enabled": True, - "identified_speakers": [], - "segment_count": len(transcript_version.segments), - "skip_reason": "No word timestamps available for pyannote, keeping provider segments", - "processing_time_seconds": time.time() - start_time - } - else: + if not transcript_version.segments: + # No words, no provider diarization, no existing segments - nothing we can do + logger.warning( + f"🎤 No word timestamps available, provider didn't diarize, " + f"and no existing segments to identify." + ) return { "success": False, "conversation_id": conversation_id, @@ -305,11 +292,37 @@ async def recognise_speakers_job( "error": "No word timestamps and no segments available", "processing_time_seconds": time.time() - start_time } + # Has existing segments - fall through to run identification on them + logger.info( + f"🎤 No word timestamps for pyannote re-diarization, but " + f"{len(transcript_version.segments)} existing segments found. " + f"Running speaker identification on existing segments." + ) + + # Determine speaker identification mode: + # 1. Config toggle (per_segment_speaker_id) enables per-segment globally + # 2. 
Manual reprocess trigger also enables per-segment for that run + from advanced_omi_backend.config import get_misc_settings + misc_config = get_misc_settings() + per_segment_config = misc_config.get("per_segment_speaker_id", False) + + trigger = transcript_version.metadata.get("trigger", "") + is_reprocess = trigger == "manual_reprocess" + + use_per_segment = per_segment_config or is_reprocess + if use_per_segment: + reason = [] + if per_segment_config: + reason.append("config toggle enabled") + if is_reprocess: + reason.append("manual reprocess") + logger.info(f"🎤 Per-segment identification mode active ({', '.join(reason)})") try: - if provider_has_diarization and transcript_version.segments: - # Provider already diarized (e.g. VibeVoice) - use segment-level identification - logger.info(f"🎤 Using segment-level speaker identification for provider-diarized segments") + if transcript_version.segments and not can_run_pyannote: + # Have existing segments and can't/shouldn't run pyannote - do identification only + # Covers: provider already diarized, no word timestamps but segments exist, etc. + logger.info(f"🎤 Using segment-level speaker identification on {len(transcript_version.segments)} existing segments") segments_data = [ {"start": s.start, "end": s.end, "text": s.text, "speaker": s.speaker} for s in transcript_version.segments @@ -318,6 +331,8 @@ async def recognise_speakers_job( conversation_id=conversation_id, segments=segments_data, user_id=user_id, + per_segment=use_per_segment, + min_segment_duration=0.5 if use_per_segment else 1.5, ) else: # Standard path: full diarization + identification via speaker service @@ -437,6 +452,20 @@ async def recognise_speakers_job( speaker_segments = speaker_result["segments"] logger.info(f"🎤 Speaker recognition returned {len(speaker_segments)} segments") + # Build mapping for unknown speakers: diarization_label -> "Unknown Speaker N" + unknown_label_map = {} + unknown_counter = 1 + for seg in speaker_segments: + identified_as = seg.get("identified_as") + if not identified_as: + label = seg.get("speaker", "Unknown") + if label not in unknown_label_map: + unknown_label_map[label] = f"Unknown Speaker {unknown_counter}" + unknown_counter += 1 + + if unknown_label_map: + logger.info(f"🎤 Unknown speaker mapping: {unknown_label_map}") + # Update the transcript version segments with identified speakers # Filter out empty segments (diarization sometimes creates segments with no text) updated_segments = [] @@ -457,7 +486,7 @@ async def recognise_speakers_job( logger.debug(f"Filtered segment with invalid timing: {seg}") continue - speaker_name = seg.get("identified_as") or seg.get("speaker", "Unknown") + speaker_name = seg.get("identified_as") or unknown_label_map.get(seg.get("speaker", "Unknown"), "Unknown Speaker") # Extract words from speaker service response (already matched to this segment) words_data = seg.get("words", []) @@ -502,6 +531,7 @@ async def recognise_speakers_job( transcript_version.metadata["speaker_recognition"] = { "enabled": True, + "identification_mode": "per_segment" if use_per_segment else "majority_vote", "identified_speakers": list(identified_speakers), "speaker_count": len(identified_speakers), "total_segments": len(speaker_segments), diff --git a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py index 19483f55..f2dda207 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py +++ 
b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py @@ -5,10 +5,13 @@ """ import asyncio +import io +import json import logging import os import time import uuid +import wave from datetime import datetime from pathlib import Path from typing import Any, Dict @@ -30,6 +33,7 @@ from advanced_omi_backend.models.conversation import Conversation from advanced_omi_backend.models.job import BaseRQJob, JobPriority, async_job from advanced_omi_backend.services.audio_stream import TranscriptionResultsAggregator +from advanced_omi_backend.plugins.events import PluginEvent from advanced_omi_backend.services.plugin_service import ensure_plugin_router from advanced_omi_backend.services.transcription import ( get_transcription_provider, @@ -215,13 +219,32 @@ async def transcribe_full_audio_job( logger.error(f"Failed to reconstruct audio from MongoDB: {e}", exc_info=True) raise RuntimeError(f"Audio reconstruction failed: {e}") + # Build ASR context (static hot words + per-user cached jargon) + try: + from advanced_omi_backend.services.transcription.context import get_asr_context + + context_info = await get_asr_context(user_id=user_id) + except Exception as e: + logger.warning(f"Failed to build ASR context: {e}") + context_info = None + + # Read actual sample rate from WAV header + try: + with wave.open(io.BytesIO(wav_data), "rb") as wf: + actual_sample_rate = wf.getframerate() + except Exception: + actual_sample_rate = 16000 + try: # Transcribe the audio directly from memory (no disk I/O needed) - transcription_result = await provider.transcribe( - audio_data=wav_data, # Pass bytes directly, already in memory - sample_rate=16000, - diarize=True, - ) + transcribe_kwargs: Dict[str, Any] = { + "audio_data": wav_data, + "sample_rate": actual_sample_rate, + "diarize": True, + } + if context_info: + transcribe_kwargs["context_info"] = context_info + transcription_result = await provider.transcribe(**transcribe_kwargs) except ConnectionError as e: logger.exception(f"Transcription service unreachable for {conversation_id}") raise RuntimeError(str(e)) @@ -267,7 +290,7 @@ async def transcribe_full_audio_job( ) plugin_results = await plugin_router.dispatch_event( - event="transcript.batch", + event=PluginEvent.TRANSCRIPT_BATCH, user_id=user_id, data=plugin_data, metadata={"client_id": client_id}, @@ -653,7 +676,7 @@ async def create_audio_only_conversation( @async_job(redis=True, beanie=True) async def transcription_fallback_check_job( - session_id: str, user_id: str, client_id: str, timeout_seconds: int = 1800, *, redis_client=None + session_id: str, user_id: str, client_id: str, timeout_seconds: int = 900, *, redis_client=None ) -> Dict[str, Any]: """ Check if streaming transcription succeeded, fallback to batch if needed. 
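# --- Illustrative sketch (not part of the diff) -----------------------------
# The hunk above replaces the hardcoded 16 kHz assumption with a probe of the
# WAV header, and only forwards `context_info` when one could be built. A
# minimal standalone version of that pattern; the `provider.transcribe`
# keyword signature is taken from the hunk above, everything else here is
# illustrative:
#
#     import io
#     import wave
#     from typing import Any, Dict, Optional
#
#     def probe_sample_rate(wav_data: bytes, default: int = 16000) -> int:
#         """Read the frame rate from the WAV header, falling back to `default`."""
#         try:
#             with wave.open(io.BytesIO(wav_data), "rb") as wf:
#                 return wf.getframerate()
#         except (wave.Error, EOFError):
#             return default
#
#     async def run_batch_transcription(provider, wav_data: bytes,
#                                       context_info: Optional[Any] = None):
#         kwargs: Dict[str, Any] = {
#             "audio_data": wav_data,
#             "sample_rate": probe_sample_rate(wav_data),
#             "diarize": True,
#         }
#         if context_info:  # optional ASR hot words / cached jargon
#             kwargs["context_info"] = context_info
#         return await provider.transcribe(**kwargs)
# -----------------------------------------------------------------------------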
@@ -666,7 +689,7 @@ async def transcription_fallback_check_job( session_id: Stream session ID user_id: User ID client_id: Client ID - timeout_seconds: Max wait time for batch transcription (default 30 minutes) + timeout_seconds: Max wait time for batch transcription (default 15 minutes) redis_client: Redis client (injected by decorator) Returns: @@ -781,9 +804,23 @@ async def transcription_fallback_check_job( sorted_chunks = sorted(audio_chunks.items()) combined_audio = b"".join(data for _, data in sorted_chunks) + # Read audio format from Redis session metadata + sample_rate, channels, sample_width = 16000, 1, 2 + session_key = f"audio:session:{session_id}" + try: + audio_format_raw = await redis_client.hget(session_key, "audio_format") + if audio_format_raw: + audio_format = json.loads(audio_format_raw) + sample_rate = int(audio_format.get("rate", 16000)) + channels = int(audio_format.get("channels", 1)) + sample_width = int(audio_format.get("width", 2)) + except Exception as e: + logger.warning(f"Failed to read audio_format from Redis for {session_id}: {e}") + + bytes_per_second = sample_rate * channels * sample_width logger.info( f"✅ Extracted {len(sorted_chunks)} audio chunks from Redis stream " - f"({len(combined_audio)} bytes, ~{len(combined_audio)/32000:.1f}s)" + f"({len(combined_audio)} bytes, ~{len(combined_audio)/bytes_per_second:.1f}s)" ) # Create conversation placeholder @@ -793,9 +830,9 @@ async def transcription_fallback_check_job( num_chunks = await convert_audio_to_chunks( conversation_id=conversation.conversation_id, audio_data=combined_audio, - sample_rate=16000, - channels=1, - sample_width=2, + sample_rate=sample_rate, + channels=channels, + sample_width=sample_width, ) logger.info( @@ -822,7 +859,7 @@ async def transcription_fallback_check_job( conversation.conversation_id, version_id, "batch_fallback", - job_timeout=1800, + job_timeout=900, # 15 minutes job_id=f"transcribe_{conversation.conversation_id[:12]}", description=f"Batch transcription fallback for {session_id[:8]}", meta={"session_id": session_id, "client_id": client_id}, @@ -1248,8 +1285,8 @@ async def stream_speech_detection_job( session_id, user_id, client_id, - timeout_seconds=1800, # 30 minutes for batch transcription - job_timeout=2400, # 40 minutes job timeout + timeout_seconds=900, # 15 minutes for batch transcription + job_timeout=1200, # 20 minutes job timeout (includes overhead for fallback check) job_id=f"fallback_check_{session_id[:12]}", description=f"Transcription fallback check for {session_id[:8]} (no speech)", meta={"session_id": session_id, "client_id": client_id, "no_speech": True}, diff --git a/backends/advanced/uv.lock b/backends/advanced/uv.lock index 98dc39d1..9c9a1b1c 100644 --- a/backends/advanced/uv.lock +++ b/backends/advanced/uv.lock @@ -13,6 +13,7 @@ version = "0.1.0" source = { editable = "." 
} dependencies = [ { name = "aiohttp" }, + { name = "croniter" }, { name = "easy-audio-interfaces" }, { name = "en-core-web-sm" }, { name = "fastapi" }, @@ -72,6 +73,7 @@ test = [ [package.metadata] requires-dist = [ { name = "aiohttp", specifier = ">=3.8.0" }, + { name = "croniter", specifier = ">=1.3.0" }, { name = "deepgram-sdk", marker = "extra == 'deepgram'", specifier = ">=4.0.0" }, { name = "easy-audio-interfaces", specifier = ">=0.7.1" }, { name = "easy-audio-interfaces", extras = ["local-audio"], marker = "extra == 'local-audio'", specifier = ">=0.7.1" }, diff --git a/backends/advanced/webui/package-lock.json b/backends/advanced/webui/package-lock.json index ead72812..7fe0c6d6 100644 --- a/backends/advanced/webui/package-lock.json +++ b/backends/advanced/webui/package-lock.json @@ -10,6 +10,7 @@ "dependencies": { "axios": "^1.6.2", "clsx": "^2.0.0", + "cronstrue": "^2.50.0", "d3": "^7.8.5", "frappe-gantt": "^1.0.4", "lucide-react": "^0.294.0", @@ -2719,6 +2720,15 @@ "dev": true, "license": "MIT" }, + "node_modules/cronstrue": { + "version": "2.59.0", + "resolved": "https://registry.npmjs.org/cronstrue/-/cronstrue-2.59.0.tgz", + "integrity": "sha512-YKGmAy84hKH+hHIIER07VCAHf9u0Ldelx1uU6EBxsRPDXIA1m5fsKmJfyC3xBhw6cVC/1i83VdbL4PvepTrt8A==", + "license": "MIT", + "bin": { + "cronstrue": "bin/cli.js" + } + }, "node_modules/cross-spawn": { "version": "7.0.6", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", diff --git a/backends/advanced/webui/package.json b/backends/advanced/webui/package.json index b933d8db..c6c1f771 100644 --- a/backends/advanced/webui/package.json +++ b/backends/advanced/webui/package.json @@ -11,6 +11,7 @@ }, "dependencies": { "axios": "^1.6.2", + "cronstrue": "^2.50.0", "clsx": "^2.0.0", "d3": "^7.8.5", "frappe-gantt": "^1.0.4", diff --git a/backends/advanced/webui/src/components/knowledge-graph/EntityCard.tsx b/backends/advanced/webui/src/components/knowledge-graph/EntityCard.tsx index 76d28cf9..bb48fef1 100644 --- a/backends/advanced/webui/src/components/knowledge-graph/EntityCard.tsx +++ b/backends/advanced/webui/src/components/knowledge-graph/EntityCard.tsx @@ -1,4 +1,6 @@ -import { User, MapPin, Building, Calendar, Package, Link2 } from 'lucide-react' +import { useState } from 'react' +import { User, MapPin, Building, Calendar, Package, Link2, Pencil, Check, X } from 'lucide-react' +import { knowledgeGraphApi } from '../../services/api' export interface Entity { id: string @@ -20,6 +22,7 @@ export interface Entity { interface EntityCardProps { entity: Entity onClick?: (entity: Entity) => void + onEntityUpdated?: (entity: Entity) => void compact?: boolean } @@ -39,7 +42,12 @@ const typeColors: Record = { thing: 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-300', } -export default function EntityCard({ entity, onClick, compact = false }: EntityCardProps) { +export default function EntityCard({ entity, onClick, onEntityUpdated, compact = false }: EntityCardProps) { + const [isEditing, setIsEditing] = useState(false) + const [editName, setEditName] = useState(entity.name) + const [editDetails, setEditDetails] = useState(entity.details || '') + const [saving, setSaving] = useState(false) + const icon = typeIcons[entity.type] || const colorClass = typeColors[entity.type] || typeColors.thing @@ -52,6 +60,41 @@ export default function EntityCard({ entity, onClick, compact = false }: EntityC } } + const handleEditClick = (e: React.MouseEvent) => { + e.stopPropagation() + setEditName(entity.name) + setEditDetails(entity.details 
|| '') + setIsEditing(true) + } + + const handleCancel = (e: React.MouseEvent) => { + e.stopPropagation() + setIsEditing(false) + } + + const handleSave = async (e: React.MouseEvent) => { + e.stopPropagation() + const updates: { name?: string; details?: string } = {} + if (editName.trim() !== entity.name) updates.name = editName.trim() + if (editDetails.trim() !== (entity.details || '')) updates.details = editDetails.trim() + + if (Object.keys(updates).length === 0) { + setIsEditing(false) + return + } + + try { + setSaving(true) + const response = await knowledgeGraphApi.updateEntity(entity.id, updates) + setIsEditing(false) + onEntityUpdated?.(response.data.entity) + } catch (err) { + console.error('Failed to update entity:', err) + } finally { + setSaving(false) + } + } + if (compact) { return (
       <div
-        onClick={() => onClick?.(entity)}
-        className="bg-white dark:bg-gray-800 rounded-lg border border-gray-200 dark:border-gray-700 p-4 hover:border-blue-400 dark:hover:border-blue-500 transition-colors cursor-pointer group"
+        onClick={() => !isEditing && onClick?.(entity)}
+        className={`bg-white dark:bg-gray-800 rounded-lg border border-gray-200 dark:border-gray-700 p-4 hover:border-blue-400 dark:hover:border-blue-500 transition-colors ${isEditing ? '' : 'cursor-pointer'} group`}
       >
-        <div className="flex items-start justify-between">
-          <div className="flex items-center gap-3">
+        <div className="flex items-start justify-between gap-2">
+          <div className="flex items-center gap-3 flex-1 min-w-0">
             <div className={`p-2 rounded-lg ${colorClass}`}>
               {entity.icon ? (
                 <span>{entity.icon}</span>
               ) : (
                 icon
               )}
             </div>
-            <div>
-              <h3 className="font-semibold text-gray-900 dark:text-gray-100">{entity.name}</h3>
+            <div className="flex-1 min-w-0">
+              {isEditing ? (
+                <input
+                  type="text"
+                  value={editName}
+                  onChange={(e) => setEditName(e.target.value)}
+                  onClick={(e) => e.stopPropagation()}
+                  className="w-full px-2 py-1 text-sm font-semibold border border-gray-300 dark:border-gray-600 rounded bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500"
+                  autoFocus
+                />
+              ) : (
+                <h3 className="font-semibold text-gray-900 dark:text-gray-100 truncate">{entity.name}</h3>
+              )}
               <span className={`text-xs px-2 py-0.5 rounded-full ${colorClass}`}>{entity.type}</span>
             </div>
           </div>
-          {entity.relationship_count != null && entity.relationship_count > 0 && (
-            <div className="flex items-center gap-1 text-xs text-gray-500 dark:text-gray-400">
-              <Link2 className="w-3 h-3" />
-              {entity.relationship_count}
-            </div>
-          )}
+          <div className="flex items-center gap-1">
+            {isEditing ? (
+              <>
+                <button onClick={handleSave} disabled={saving}>
+                  <Check className="w-4 h-4" />
+                </button>
+                <button onClick={handleCancel}>
+                  <X className="w-4 h-4" />
+                </button>
+              </>
+            ) : (
+              <>
+                <button onClick={handleEditClick}>
+                  <Pencil className="w-4 h-4" />
+                </button>
+                {entity.relationship_count != null && entity.relationship_count > 0 && (
+                  <div className="flex items-center gap-1 text-xs text-gray-500 dark:text-gray-400">
+                    <Link2 className="w-3 h-3" />
+                    {entity.relationship_count}
+                  </div>
+                )}
+              </>
+            )}
+          </div>
         </div>
-        {entity.details && (
-          <p className="mt-2 text-sm text-gray-600 dark:text-gray-400 line-clamp-2">{entity.details}</p>
-        )}
+        {isEditing ? (