diff --git a/backends/advanced/init.py b/backends/advanced/init.py index 4ea037b2..f26958eb 100644 --- a/backends/advanced/init.py +++ b/backends/advanced/init.py @@ -34,21 +34,31 @@ def __init__(self, args=None): self.console = Console() self.config: Dict[str, Any] = {} self.args = args or argparse.Namespace() - self.config_yml_path = Path("../../config/config.yml") # Main config at config/config.yml + self.config_yml_path = Path( + "../../config/config.yml" + ) # Main config at config/config.yml # Check if we're in the right directory if not Path("pyproject.toml").exists() or not Path("src").exists(): - self.console.print("[red][ERROR][/red] Please run this script from the backends/advanced directory") + self.console.print( + "[red][ERROR][/red] Please run this script from the backends/advanced directory" + ) sys.exit(1) # Initialize ConfigManager (single source of truth for config.yml) self.config_manager = ConfigManager(service_path="backends/advanced") - self.console.print(f"[blue][INFO][/blue] Using config.yml at: {self.config_manager.config_yml_path}") + self.console.print( + f"[blue][INFO][/blue] Using config.yml at: {self.config_manager.config_yml_path}" + ) # Verify config.yml exists - fail fast if missing if not self.config_manager.config_yml_path.exists(): - self.console.print(f"[red][ERROR][/red] config.yml not found at {self.config_manager.config_yml_path}") - self.console.print("[red][ERROR][/red] Run wizard.py from project root to create config.yml") + self.console.print( + f"[red][ERROR][/red] config.yml not found at {self.config_manager.config_yml_path}" + ) + self.console.print( + "[red][ERROR][/red] Run wizard.py from project root to create config.yml" + ) sys.exit(1) # Ensure plugins.yml exists (copy from template if missing) @@ -57,11 +67,7 @@ def __init__(self, args=None): def print_header(self, title: str): """Print a colorful header""" self.console.print() - panel = Panel( - Text(title, style="cyan bold"), - style="cyan", - expand=False - ) 
+ panel = Panel(Text(title, style="cyan bold"), style="cyan", expand=False) self.console.print(panel) self.console.print() @@ -84,19 +90,23 @@ def prompt_password(self, prompt: str) -> str: """Prompt for password (delegates to shared utility)""" return util_prompt_password(prompt, min_length=8, allow_generated=True) - def prompt_choice(self, prompt: str, choices: Dict[str, str], default: str = "1") -> str: + def prompt_choice( + self, prompt: str, choices: Dict[str, str], default: str = "1" + ) -> str: """Prompt for a choice from options""" self.console.print(prompt) for key, desc in choices.items(): self.console.print(f" {key}) {desc}") self.console.print() - + while True: try: choice = Prompt.ask("Enter choice", default=default) if choice in choices: return choice - self.console.print(f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]") + self.console.print( + f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]" + ) except EOFError: self.console.print(f"Using default choice: {default}") return default @@ -108,11 +118,19 @@ def _ensure_plugins_yml_exists(self): if not plugins_yml.exists(): if plugins_template.exists(): - self.console.print("[blue][INFO][/blue] plugins.yml not found, creating from template...") + self.console.print( + "[blue][INFO][/blue] plugins.yml not found, creating from template..." 
+ ) shutil.copy2(plugins_template, plugins_yml) - self.console.print(f"[green]✅[/green] Created {plugins_yml} from template") - self.console.print("[yellow][NOTE][/yellow] Edit config/plugins.yml to configure plugins") - self.console.print("[yellow][NOTE][/yellow] Set HA_TOKEN in .env for Home Assistant integration") + self.console.print( + f"[green]✅[/green] Created {plugins_yml} from template" + ) + self.console.print( + "[yellow][NOTE][/yellow] Edit config/plugins.yml to configure plugins" + ) + self.console.print( + "[yellow][NOTE][/yellow] Set HA_TOKEN in .env for Home Assistant integration" + ) else: raise RuntimeError( f"Template file not found: {plugins_template}\n" @@ -128,7 +146,9 @@ def backup_existing_env(self): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_path = f".env.backup.{timestamp}" shutil.copy2(env_path, backup_path) - self.console.print(f"[blue][INFO][/blue] Backed up existing .env file to {backup_path}") + self.console.print( + f"[blue][INFO][/blue] Backed up existing .env file to {backup_path}" + ) def read_existing_env_value(self, key: str) -> str: """Read a value from existing .env file (delegates to shared utility)""" @@ -138,8 +158,14 @@ def mask_api_key(self, key: str, show_chars: int = 5) -> str: """Mask API key (delegates to shared utility)""" return mask_value(key, show_chars) - def prompt_with_existing_masked(self, prompt_text: str, env_key: str, placeholders: list, - is_password: bool = False, default: str = "") -> str: + def prompt_with_existing_masked( + self, + prompt_text: str, + env_key: str, + placeholders: list, + is_password: bool = False, + default: str = "", + ) -> str: """ Prompt for a value, showing masked existing value from .env if present. Delegates to shared utility from setup_utils. 
@@ -161,10 +187,9 @@ def prompt_with_existing_masked(self, prompt_text: str, env_key: str, placeholde env_key=env_key, placeholders=placeholders, is_password=is_password, - default=default + default=default, ) - def setup_authentication(self): """Configure authentication settings""" self.print_section("Authentication Setup") @@ -186,13 +211,17 @@ def setup_authentication(self): ) self.config["ADMIN_PASSWORD"] = password else: - self.config["ADMIN_PASSWORD"] = self.prompt_password("Admin password (min 8 chars)") + self.config["ADMIN_PASSWORD"] = self.prompt_password( + "Admin password (min 8 chars)" + ) # Preserve existing AUTH_SECRET_KEY to avoid invalidating JWTs existing_secret = self.read_existing_env_value("AUTH_SECRET_KEY") if existing_secret: self.config["AUTH_SECRET_KEY"] = existing_secret - self.console.print("[blue][INFO][/blue] Reusing existing AUTH_SECRET_KEY (existing JWT tokens remain valid)") + self.console.print( + "[blue][INFO][/blue] Reusing existing AUTH_SECRET_KEY (existing JWT tokens remain valid)" + ) else: self.config["AUTH_SECRET_KEY"] = secrets.token_hex(32) @@ -201,9 +230,14 @@ def setup_authentication(self): def setup_transcription(self): """Configure transcription provider - updates config.yml and .env""" # Check if transcription provider was provided via command line - if hasattr(self.args, 'transcription_provider') and self.args.transcription_provider: + if ( + hasattr(self.args, "transcription_provider") + and self.args.transcription_provider + ): provider = self.args.transcription_provider - self.console.print(f"[green]✅[/green] Transcription: {provider} (configured via wizard)") + self.console.print( + f"[green]✅[/green] Transcription: {provider} (configured via wizard)" + ) # Map provider to choice if provider == "deepgram": @@ -223,21 +257,27 @@ def setup_transcription(self): else: self.print_section("Speech-to-Text Configuration") - self.console.print("[blue][INFO][/blue] Provider selection is configured in config.yml 
(defaults.stt)") + self.console.print( + "[blue][INFO][/blue] Provider selection is configured in config.yml (defaults.stt)" + ) self.console.print("[blue][INFO][/blue] API keys are stored in .env") self.console.print() # Interactive prompt - is_macos = platform.system() == 'Darwin' + is_macos = platform.system() == "Darwin" if is_macos: parakeet_desc = "Offline (Parakeet ASR - CPU-based, runs locally)" vibevoice_desc = "Offline (VibeVoice - CPU-based, built-in diarization)" else: parakeet_desc = "Offline (Parakeet ASR - GPU recommended, runs locally)" - vibevoice_desc = "Offline (VibeVoice - GPU recommended, built-in diarization)" + vibevoice_desc = ( + "Offline (VibeVoice - GPU recommended, built-in diarization)" + ) - qwen3_desc = "Offline (Qwen3-ASR - GPU required, 52 languages, streaming + batch)" + qwen3_desc = ( + "Offline (Qwen3-ASR - GPU required, 52 languages, streaming + batch)" + ) smallest_desc = "Smallest.ai Pulse (cloud-based, fast, requires API key)" @@ -247,10 +287,12 @@ def setup_transcription(self): "3": vibevoice_desc, "4": qwen3_desc, "5": smallest_desc, - "6": "None (skip transcription setup)" + "6": "None (skip transcription setup)", } - choice = self.prompt_choice("Choose your transcription provider:", choices, "1") + choice = self.prompt_choice( + "Choose your transcription provider:", choices, "1" + ) if choice == "1": self.console.print("[blue][INFO][/blue] Deepgram selected") @@ -260,9 +302,9 @@ def setup_transcription(self): api_key = self.prompt_with_existing_masked( prompt_text="Deepgram API key (leave empty to skip)", env_key="DEEPGRAM_API_KEY", - placeholders=['your_deepgram_api_key_here', 'your-deepgram-key-here'], + placeholders=["your_deepgram_api_key_here", "your-deepgram-key-here"], is_password=True, - default="" + default="", ) if api_key: @@ -272,14 +314,20 @@ def setup_transcription(self): # Update config.yml to use Deepgram self.config_manager.update_config_defaults({"stt": "stt-deepgram"}) - 
self.console.print("[green][SUCCESS][/green] Deepgram configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] Deepgram configured in config.yml and .env" + ) self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-deepgram") else: - self.console.print("[yellow][WARNING][/yellow] No API key provided - transcription will not work") + self.console.print( + "[yellow][WARNING][/yellow] No API key provided - transcription will not work" + ) elif choice == "2": self.console.print("[blue][INFO][/blue] Offline Parakeet ASR selected") - parakeet_url = self.prompt_value("Parakeet ASR URL", "http://host.docker.internal:8767") + parakeet_url = self.prompt_value( + "Parakeet ASR URL", "http://host.docker.internal:8767" + ) # Write URL to .env for ${PARAKEET_ASR_URL} placeholder in config.yml self.config["PARAKEET_ASR_URL"] = parakeet_url @@ -287,13 +335,23 @@ def setup_transcription(self): # Update config.yml to use Parakeet self.config_manager.update_config_defaults({"stt": "stt-parakeet-batch"}) - self.console.print("[green][SUCCESS][/green] Parakeet configured in config.yml and .env") - self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-parakeet-batch") - self.console.print("[yellow][WARNING][/yellow] Remember to start Parakeet service: cd ../../extras/asr-services && docker compose up nemo-asr") + self.console.print( + "[green][SUCCESS][/green] Parakeet configured in config.yml and .env" + ) + self.console.print( + "[blue][INFO][/blue] Set defaults.stt: stt-parakeet-batch" + ) + self.console.print( + "[yellow][WARNING][/yellow] Remember to start Parakeet service: cd ../../extras/asr-services && docker compose up nemo-asr" + ) elif choice == "3": - self.console.print("[blue][INFO][/blue] Offline VibeVoice ASR selected (built-in speaker diarization)") - vibevoice_url = self.prompt_value("VibeVoice ASR URL", "http://host.docker.internal:8767") + self.console.print( + "[blue][INFO][/blue] Offline VibeVoice ASR selected (built-in 
speaker diarization)" + ) + vibevoice_url = self.prompt_value( + "VibeVoice ASR URL", "http://host.docker.internal:8767" + ) # Write URL to .env for ${VIBEVOICE_ASR_URL} placeholder in config.yml self.config["VIBEVOICE_ASR_URL"] = vibevoice_url @@ -301,14 +359,24 @@ def setup_transcription(self): # Update config.yml to use VibeVoice self.config_manager.update_config_defaults({"stt": "stt-vibevoice"}) - self.console.print("[green][SUCCESS][/green] VibeVoice configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] VibeVoice configured in config.yml and .env" + ) self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-vibevoice") - self.console.print("[blue][INFO][/blue] VibeVoice provides built-in speaker diarization - pyannote will be skipped") - self.console.print("[yellow][WARNING][/yellow] Remember to start VibeVoice service: cd ../../extras/asr-services && docker compose up vibevoice-asr") + self.console.print( + "[blue][INFO][/blue] VibeVoice provides built-in speaker diarization - pyannote will be skipped" + ) + self.console.print( + "[yellow][WARNING][/yellow] Remember to start VibeVoice service: cd ../../extras/asr-services && docker compose up vibevoice-asr" + ) elif choice == "4": - self.console.print("[blue][INFO][/blue] Qwen3-ASR selected (52 languages, streaming + batch via vLLM)") - qwen3_url = self.prompt_value("Qwen3-ASR URL", "http://host.docker.internal:8767") + self.console.print( + "[blue][INFO][/blue] Qwen3-ASR selected (52 languages, streaming + batch via vLLM)" + ) + qwen3_url = self.prompt_value( + "Qwen3-ASR URL", "http://host.docker.internal:8767" + ) # Write URL to .env for ${QWEN3_ASR_URL} placeholder in config.yml self.config["QWEN3_ASR_URL"] = qwen3_url.replace("http://", "").rstrip("/") @@ -320,9 +388,13 @@ def setup_transcription(self): # Update config.yml to use Qwen3-ASR self.config_manager.update_config_defaults({"stt": "stt-qwen3-asr"}) - self.console.print("[green][SUCCESS][/green] Qwen3-ASR 
configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] Qwen3-ASR configured in config.yml and .env" + ) self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-qwen3-asr") - self.console.print("[yellow][WARNING][/yellow] Remember to start Qwen3-ASR: cd ../../extras/asr-services && docker compose up qwen3-asr-wrapper qwen3-asr-bridge -d") + self.console.print( + "[yellow][WARNING][/yellow] Remember to start Qwen3-ASR: cd ../../extras/asr-services && docker compose up qwen3-asr-wrapper qwen3-asr-bridge -d" + ) elif choice == "5": self.console.print("[blue][INFO][/blue] Smallest.ai Pulse selected") @@ -332,9 +404,9 @@ def setup_transcription(self): api_key = self.prompt_with_existing_masked( prompt_text="Smallest.ai API key (leave empty to skip)", env_key="SMALLEST_API_KEY", - placeholders=['your_smallest_api_key_here', 'your-smallest-key-here'], + placeholders=["your_smallest_api_key_here", "your-smallest-key-here"], is_password=True, - default="" + default="", ) if api_key: @@ -342,16 +414,21 @@ def setup_transcription(self): self.config["SMALLEST_API_KEY"] = api_key # Update config.yml to use Smallest.ai (batch + streaming) - self.config_manager.update_config_defaults({ - "stt": "stt-smallest", - "stt_stream": "stt-smallest-stream" - }) + self.config_manager.update_config_defaults( + {"stt": "stt-smallest", "stt_stream": "stt-smallest-stream"} + ) - self.console.print("[green][SUCCESS][/green] Smallest.ai configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] Smallest.ai configured in config.yml and .env" + ) self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-smallest") - self.console.print("[blue][INFO][/blue] Set defaults.stt_stream: stt-smallest-stream") + self.console.print( + "[blue][INFO][/blue] Set defaults.stt_stream: stt-smallest-stream" + ) else: - self.console.print("[yellow][WARNING][/yellow] No API key provided - transcription will not work") + self.console.print( + 
"[yellow][WARNING][/yellow] No API key provided - transcription will not work" + ) elif choice == "6": self.console.print("[blue][INFO][/blue] Skipping transcription setup") @@ -362,11 +439,16 @@ def setup_streaming_provider(self): When a different streaming provider is specified, sets defaults.stt_stream and enables always_batch_retranscribe (batch provider was set by setup_transcription). """ - if not hasattr(self.args, 'streaming_provider') or not self.args.streaming_provider: + if ( + not hasattr(self.args, "streaming_provider") + or not self.args.streaming_provider + ): return streaming_provider = self.args.streaming_provider - self.console.print(f"\n[green]✅[/green] Streaming provider: {streaming_provider} (configured via wizard)") + self.console.print( + f"\n[green]✅[/green] Streaming provider: {streaming_provider} (configured via wizard)" + ) # Map streaming provider to stt_stream config value provider_to_stt_stream = { @@ -377,7 +459,9 @@ def setup_streaming_provider(self): stream_stt = provider_to_stt_stream.get(streaming_provider) if not stream_stt: - self.console.print(f"[yellow][WARNING][/yellow] Unknown streaming provider: {streaming_provider}") + self.console.print( + f"[yellow][WARNING][/yellow] Unknown streaming provider: {streaming_provider}" + ) return # Set stt_stream (batch stt was already set by setup_transcription) @@ -385,11 +469,11 @@ def setup_streaming_provider(self): # Enable always_batch_retranscribe full_config = self.config_manager.get_full_config() - if 'backend' not in full_config: - full_config['backend'] = {} - if 'transcription' not in full_config['backend']: - full_config['backend']['transcription'] = {} - full_config['backend']['transcription']['always_batch_retranscribe'] = True + if "backend" not in full_config: + full_config["backend"] = {} + if "transcription" not in full_config["backend"]: + full_config["backend"]["transcription"] = {} + full_config["backend"]["transcription"]["always_batch_retranscribe"] = True 
self.config_manager.save_full_config(full_config) self.console.print(f"[blue][INFO][/blue] Set defaults.stt_stream: {stream_stt}") @@ -397,33 +481,47 @@ def setup_streaming_provider(self): # Prompt for streaming provider env vars if not already set if streaming_provider == "deepgram": - existing_key = read_env_value('.env', 'DEEPGRAM_API_KEY') - if not existing_key or existing_key in ('your_deepgram_api_key_here', 'your-deepgram-key-here'): + existing_key = read_env_value(".env", "DEEPGRAM_API_KEY") + if not existing_key or existing_key in ( + "your_deepgram_api_key_here", + "your-deepgram-key-here", + ): api_key = self.prompt_with_existing_masked( prompt_text="Deepgram API key for streaming", env_key="DEEPGRAM_API_KEY", - placeholders=['your_deepgram_api_key_here', 'your-deepgram-key-here'], + placeholders=[ + "your_deepgram_api_key_here", + "your-deepgram-key-here", + ], is_password=True, - default="" + default="", ) if api_key: self.config["DEEPGRAM_API_KEY"] = api_key elif streaming_provider == "smallest": - existing_key = read_env_value('.env', 'SMALLEST_API_KEY') - if not existing_key or existing_key in ('your_smallest_api_key_here', 'your-smallest-key-here'): + existing_key = read_env_value(".env", "SMALLEST_API_KEY") + if not existing_key or existing_key in ( + "your_smallest_api_key_here", + "your-smallest-key-here", + ): api_key = self.prompt_with_existing_masked( prompt_text="Smallest.ai API key for streaming", env_key="SMALLEST_API_KEY", - placeholders=['your_smallest_api_key_here', 'your-smallest-key-here'], + placeholders=[ + "your_smallest_api_key_here", + "your-smallest-key-here", + ], is_password=True, - default="" + default="", ) if api_key: self.config["SMALLEST_API_KEY"] = api_key elif streaming_provider == "qwen3-asr": - existing_url = read_env_value('.env', 'QWEN3_ASR_STREAM_URL') + existing_url = read_env_value(".env", "QWEN3_ASR_STREAM_URL") if not existing_url: - qwen3_url = self.prompt_value("Qwen3-ASR streaming URL", 
"http://host.docker.internal:8769") + qwen3_url = self.prompt_value( + "Qwen3-ASR streaming URL", "http://host.docker.internal:8769" + ) stream_host = qwen3_url.replace("http://", "").rstrip("/") self.config["QWEN3_ASR_STREAM_URL"] = stream_host @@ -431,51 +529,177 @@ def setup_llm(self): """Configure LLM provider - updates config.yml and .env""" self.print_section("LLM Provider Configuration") - self.console.print("[blue][INFO][/blue] LLM configuration will be saved to config.yml") + self.console.print( + "[blue][INFO][/blue] LLM configuration will be saved to config.yml" + ) self.console.print() choices = { "1": "OpenAI (GPT-4, GPT-3.5 - requires API key)", "2": "Ollama (local models - runs locally)", - "3": "Skip (no memory extraction)" + "3": "OpenAI-Compatible (custom endpoint - Groq, Together AI, LM Studio, etc.)", + "4": "Skip (no memory extraction)", } choice = self.prompt_choice("Which LLM provider will you use?", choices, "1") if choice == "1": self.console.print("[blue][INFO][/blue] OpenAI selected") - self.console.print("Get your API key from: https://platform.openai.com/api-keys") + self.console.print( + "Get your API key from: https://platform.openai.com/api-keys" + ) # Use the new masked prompt function api_key = self.prompt_with_existing_masked( prompt_text="OpenAI API key (leave empty to skip)", env_key="OPENAI_API_KEY", - placeholders=['your_openai_api_key_here', 'your-openai-key-here'], + placeholders=["your_openai_api_key_here", "your-openai-key-here"], is_password=True, - default="" + default="", ) if api_key: self.config["OPENAI_API_KEY"] = api_key # Update config.yml to use OpenAI models - self.config_manager.update_config_defaults({"llm": "openai-llm", "embedding": "openai-embed"}) - self.console.print("[green][SUCCESS][/green] OpenAI configured in config.yml") + self.config_manager.update_config_defaults( + {"llm": "openai-llm", "embedding": "openai-embed"} + ) + self.console.print( + "[green][SUCCESS][/green] OpenAI configured in 
config.yml" + ) self.console.print("[blue][INFO][/blue] Set defaults.llm: openai-llm") - self.console.print("[blue][INFO][/blue] Set defaults.embedding: openai-embed") + self.console.print( + "[blue][INFO][/blue] Set defaults.embedding: openai-embed" + ) else: - self.console.print("[yellow][WARNING][/yellow] No API key provided - memory extraction will not work") + self.console.print( + "[yellow][WARNING][/yellow] No API key provided - memory extraction will not work" + ) elif choice == "2": self.console.print("[blue][INFO][/blue] Ollama selected") # Update config.yml to use Ollama models - self.config_manager.update_config_defaults({"llm": "local-llm", "embedding": "local-embed"}) - self.console.print("[green][SUCCESS][/green] Ollama configured in config.yml") + self.config_manager.update_config_defaults( + {"llm": "local-llm", "embedding": "local-embed"} + ) + self.console.print( + "[green][SUCCESS][/green] Ollama configured in config.yml" + ) self.console.print("[blue][INFO][/blue] Set defaults.llm: local-llm") - self.console.print("[blue][INFO][/blue] Set defaults.embedding: local-embed") - self.console.print("[yellow][WARNING][/yellow] Make sure Ollama is running and models are pulled") + self.console.print( + "[blue][INFO][/blue] Set defaults.embedding: local-embed" + ) + self.console.print( + "[yellow][WARNING][/yellow] Make sure Ollama is running and models are pulled" + ) elif choice == "3": - self.console.print("[blue][INFO][/blue] Skipping LLM setup - memory extraction disabled") + self.console.print( + "[blue][INFO][/blue] OpenAI-Compatible custom endpoint selected" + ) + self.console.print( + "This works with any provider that exposes an OpenAI-compatible API" + ) + self.console.print("(e.g., Groq, Together AI, LM Studio, vLLM, etc.)") + self.console.print() + + # Prompt for base URL (required) + base_url = self.prompt_value( + "API Base URL (e.g., https://api.groq.com/openai/v1)", "" + ) + if not base_url: + self.console.print( + 
"[yellow][WARNING][/yellow] No base URL provided - skipping custom LLM setup" + ) + else: + # Prompt for API key + api_key = self.prompt_with_existing_masked( + prompt_text="API Key (leave empty if not required)", + env_key="CUSTOM_LLM_API_KEY", + placeholders=["your_custom_llm_api_key_here"], + is_password=True, + default="", + ) + if api_key: + self.config["CUSTOM_LLM_API_KEY"] = api_key + + # Prompt for model name (required) + model_name = self.prompt_value( + "LLM Model name (e.g., llama-3.1-70b-versatile)", "" + ) + if not model_name: + self.console.print( + "[yellow][WARNING][/yellow] No model name provided - skipping custom LLM setup" + ) + else: + # Create LLM model entry + llm_model = { + "name": "custom-llm", + "description": "Custom OpenAI-compatible LLM", + "model_type": "llm", + "model_provider": "openai", + "api_family": "openai", + "model_name": model_name, + "model_url": base_url, + "api_key": "${oc.env:CUSTOM_LLM_API_KEY,''}", + "model_params": {"temperature": 0.2, "max_tokens": 2000}, + "model_output": "json", + } + self.config_manager.add_or_update_model(llm_model) + + # Prompt for optional embedding model + embedding_model_name = self.prompt_value( + "Embedding model name (leave empty to use Ollama local-embed)", + "", + ) + + if embedding_model_name: + embed_model = { + "name": "custom-embed", + "description": "Custom OpenAI-compatible embeddings", + "model_type": "embedding", + "model_provider": "openai", + "api_family": "openai", + "model_name": embedding_model_name, + "model_url": base_url, + "api_key": "${oc.env:CUSTOM_LLM_API_KEY,''}", + "embedding_dimensions": 1536, + "model_output": "vector", + } + self.config_manager.add_or_update_model(embed_model) + self.config_manager.update_config_defaults( + {"llm": "custom-llm", "embedding": "custom-embed"} + ) + self.console.print( + "[green][SUCCESS][/green] Custom LLM and embedding configured in config.yml" + ) + self.console.print( + "[blue][INFO][/blue] Set defaults.llm: custom-llm" + ) + 
self.console.print( + "[blue][INFO][/blue] Set defaults.embedding: custom-embed" + ) + else: + self.config_manager.update_config_defaults( + {"llm": "custom-llm", "embedding": "local-embed"} + ) + self.console.print( + "[green][SUCCESS][/green] Custom LLM configured in config.yml" + ) + self.console.print( + "[blue][INFO][/blue] Set defaults.llm: custom-llm" + ) + self.console.print( + "[blue][INFO][/blue] Set defaults.embedding: local-embed (Ollama)" + ) + self.console.print( + "[yellow][WARNING][/yellow] Make sure Ollama is running for embeddings" + ) + + elif choice == "4": + self.console.print( + "[blue][INFO][/blue] Skipping LLM setup - memory extraction disabled" + ) # Disable memory extraction in config.yml self.config_manager.update_memory_config({"extraction": {"enabled": False}}) @@ -491,80 +715,115 @@ def setup_memory(self): choice = self.prompt_choice("Choose your memory storage backend:", choices, "1") if choice == "1": - self.console.print("[blue][INFO][/blue] Chronicle Native memory provider selected") + self.console.print( + "[blue][INFO][/blue] Chronicle Native memory provider selected" + ) qdrant_url = self.prompt_value("Qdrant URL", "qdrant") self.config["QDRANT_BASE_URL"] = qdrant_url # Update config.yml (also updates .env automatically) self.config_manager.update_memory_config({"provider": "chronicle"}) - self.console.print("[green][SUCCESS][/green] Chronicle memory provider configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] Chronicle memory provider configured in config.yml and .env" + ) elif choice == "2": self.console.print("[blue][INFO][/blue] OpenMemory MCP selected") - mcp_url = self.prompt_value("OpenMemory MCP server URL", "http://host.docker.internal:8765") + mcp_url = self.prompt_value( + "OpenMemory MCP server URL", "http://host.docker.internal:8765" + ) client_name = self.prompt_value("OpenMemory client name", "chronicle") user_id = self.prompt_value("OpenMemory user ID", "openmemory") timeout = 
self.prompt_value("OpenMemory timeout (seconds)", "30") # Update config.yml with OpenMemory MCP settings (also updates .env automatically) - self.config_manager.update_memory_config({ - "provider": "openmemory_mcp", - "openmemory_mcp": { - "server_url": mcp_url, - "client_name": client_name, - "user_id": user_id, - "timeout": int(timeout) + self.config_manager.update_memory_config( + { + "provider": "openmemory_mcp", + "openmemory_mcp": { + "server_url": mcp_url, + "client_name": client_name, + "user_id": user_id, + "timeout": int(timeout), + }, } - }) - self.console.print("[green][SUCCESS][/green] OpenMemory MCP configured in config.yml and .env") - self.console.print("[yellow][WARNING][/yellow] Remember to start OpenMemory: cd ../../extras/openmemory-mcp && docker compose up -d") + ) + self.console.print( + "[green][SUCCESS][/green] OpenMemory MCP configured in config.yml and .env" + ) + self.console.print( + "[yellow][WARNING][/yellow] Remember to start OpenMemory: cd ../../extras/openmemory-mcp && docker compose up -d" + ) def setup_optional_services(self): """Configure optional services""" # Check if speaker service URL provided via args - has_speaker_arg = hasattr(self.args, 'speaker_service_url') and self.args.speaker_service_url - has_asr_arg = hasattr(self.args, 'parakeet_asr_url') and self.args.parakeet_asr_url + has_speaker_arg = ( + hasattr(self.args, "speaker_service_url") and self.args.speaker_service_url + ) + has_asr_arg = ( + hasattr(self.args, "parakeet_asr_url") and self.args.parakeet_asr_url + ) if has_speaker_arg: self.config["SPEAKER_SERVICE_URL"] = self.args.speaker_service_url - self.console.print(f"[green]✅[/green] Speaker Recognition: {self.args.speaker_service_url} (configured via wizard)") + self.console.print( + f"[green]✅[/green] Speaker Recognition: {self.args.speaker_service_url} (configured via wizard)" + ) if has_asr_arg: self.config["PARAKEET_ASR_URL"] = self.args.parakeet_asr_url - self.console.print(f"[green]✅[/green] Parakeet 
ASR: {self.args.parakeet_asr_url} (configured via wizard)") + self.console.print( + f"[green]✅[/green] Parakeet ASR: {self.args.parakeet_asr_url} (configured via wizard)" + ) # Only show interactive section if not all configured via args if not has_speaker_arg: try: - enable_speaker = Confirm.ask("Enable Speaker Recognition?", default=False) + enable_speaker = Confirm.ask( + "Enable Speaker Recognition?", default=False + ) except EOFError: self.console.print("Using default: No") enable_speaker = False - + if enable_speaker: - speaker_url = self.prompt_value("Speaker Recognition service URL", "http://host.docker.internal:8001") + speaker_url = self.prompt_value( + "Speaker Recognition service URL", + "http://host.docker.internal:8001", + ) self.config["SPEAKER_SERVICE_URL"] = speaker_url - self.console.print("[green][SUCCESS][/green] Speaker Recognition configured") - self.console.print("[blue][INFO][/blue] Start with: cd ../../extras/speaker-recognition && docker compose up -d") - + self.console.print( + "[green][SUCCESS][/green] Speaker Recognition configured" + ) + self.console.print( + "[blue][INFO][/blue] Start with: cd ../../extras/speaker-recognition && docker compose up -d" + ) + # Check if Tailscale auth key provided via args - if hasattr(self.args, 'ts_authkey') and self.args.ts_authkey: + if hasattr(self.args, "ts_authkey") and self.args.ts_authkey: self.config["TS_AUTHKEY"] = self.args.ts_authkey - self.console.print(f"[green][SUCCESS][/green] Tailscale auth key configured (Docker integration enabled)") + self.console.print( + f"[green][SUCCESS][/green] Tailscale auth key configured (Docker integration enabled)" + ) def setup_neo4j(self): """Configure Neo4j credentials (always required - used by Knowledge Graph)""" - neo4j_password = getattr(self.args, 'neo4j_password', None) + neo4j_password = getattr(self.args, "neo4j_password", None) if neo4j_password: - self.console.print(f"[green]✅[/green] Neo4j: password configured via wizard") + 
self.console.print( + f"[green]✅[/green] Neo4j: password configured via wizard" + ) else: # Interactive prompt (standalone init.py run) self.console.print() self.console.print("[bold cyan]Neo4j Configuration[/bold cyan]") - self.console.print("Neo4j is used for Knowledge Graph (entity/relationship extraction)") + self.console.print( + "Neo4j is used for Knowledge Graph (entity/relationship extraction)" + ) self.console.print() neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") @@ -575,49 +834,54 @@ def setup_neo4j(self): def setup_obsidian(self): """Configure Obsidian integration (optional feature flag only - Neo4j credentials handled by setup_neo4j)""" - if hasattr(self.args, 'enable_obsidian') and self.args.enable_obsidian: + if hasattr(self.args, "enable_obsidian") and self.args.enable_obsidian: enable_obsidian = True - self.console.print(f"[green]✅[/green] Obsidian: enabled (configured via wizard)") + self.console.print( + f"[green]✅[/green] Obsidian: enabled (configured via wizard)" + ) else: # Interactive prompt (fallback) self.console.print() self.console.print("[bold cyan]Obsidian Integration (Optional)[/bold cyan]") - self.console.print("Enable graph-based knowledge management for Obsidian vault notes") + self.console.print( + "Enable graph-based knowledge management for Obsidian vault notes" + ) self.console.print() try: - enable_obsidian = Confirm.ask("Enable Obsidian integration?", default=False) + enable_obsidian = Confirm.ask( + "Enable Obsidian integration?", default=False + ) except EOFError: self.console.print("Using default: No") enable_obsidian = False if enable_obsidian: - self.config_manager.update_memory_config({ - "obsidian": { - "enabled": True, - "neo4j_host": "neo4j", - "timeout": 30 - } - }) + self.config_manager.update_memory_config( + {"obsidian": {"enabled": True, "neo4j_host": "neo4j", "timeout": 30}} + ) self.console.print("[green][SUCCESS][/green] Obsidian integration enabled") else: - 
self.config_manager.update_memory_config({ - "obsidian": { - "enabled": False, - "neo4j_host": "neo4j", - "timeout": 30 - } - }) + self.config_manager.update_memory_config( + {"obsidian": {"enabled": False, "neo4j_host": "neo4j", "timeout": 30}} + ) self.console.print("[blue][INFO][/blue] Obsidian integration disabled") def setup_knowledge_graph(self): """Configure Knowledge Graph (Neo4j-based entity/relationship extraction - enabled by default)""" - if hasattr(self.args, 'enable_knowledge_graph') and self.args.enable_knowledge_graph: + if ( + hasattr(self.args, "enable_knowledge_graph") + and self.args.enable_knowledge_graph + ): enable_kg = True else: self.console.print() - self.console.print("[bold cyan]Knowledge Graph (Entity Extraction)[/bold cyan]") - self.console.print("Extract people, places, organizations, events, and tasks from conversations") + self.console.print( + "[bold cyan]Knowledge Graph (Entity Extraction)[/bold cyan]" + ) + self.console.print( + "Extract people, places, organizations, events, and tasks from conversations" + ) self.console.print() try: @@ -627,56 +891,77 @@ def setup_knowledge_graph(self): enable_kg = True if enable_kg: - self.config_manager.update_memory_config({ - "knowledge_graph": { - "enabled": True, - "neo4j_host": "neo4j", - "timeout": 30 + self.config_manager.update_memory_config( + { + "knowledge_graph": { + "enabled": True, + "neo4j_host": "neo4j", + "timeout": 30, + } } - }) + ) self.console.print("[green][SUCCESS][/green] Knowledge Graph enabled") - self.console.print("[blue][INFO][/blue] Entities and relationships will be extracted from conversations") + self.console.print( + "[blue][INFO][/blue] Entities and relationships will be extracted from conversations" + ) else: - self.config_manager.update_memory_config({ - "knowledge_graph": { - "enabled": False, - "neo4j_host": "neo4j", - "timeout": 30 + self.config_manager.update_memory_config( + { + "knowledge_graph": { + "enabled": False, + "neo4j_host": "neo4j", + 
"timeout": 30, + } } - }) + ) self.console.print("[blue][INFO][/blue] Knowledge Graph disabled") def setup_langfuse(self): """Configure LangFuse observability and prompt management""" self.console.print() - self.console.print("[bold cyan]LangFuse Observability & Prompt Management[/bold cyan]") + self.console.print( + "[bold cyan]LangFuse Observability & Prompt Management[/bold cyan]" + ) # Check if keys were passed from wizard (langfuse init already ran) - langfuse_pub = getattr(self.args, 'langfuse_public_key', None) - langfuse_sec = getattr(self.args, 'langfuse_secret_key', None) + langfuse_pub = getattr(self.args, "langfuse_public_key", None) + langfuse_sec = getattr(self.args, "langfuse_secret_key", None) if langfuse_pub and langfuse_sec: # Auto-configure from wizard — no prompts needed - langfuse_host = getattr(self.args, 'langfuse_host', None) or "http://langfuse-web:3000" + langfuse_host = ( + getattr(self.args, "langfuse_host", None) or "http://langfuse-web:3000" + ) self.config["LANGFUSE_HOST"] = langfuse_host self.config["LANGFUSE_PUBLIC_KEY"] = langfuse_pub self.config["LANGFUSE_SECRET_KEY"] = langfuse_sec self.config["LANGFUSE_BASE_URL"] = langfuse_host # Derive browser-accessible URL for deep-links - public_url = getattr(self.args, 'langfuse_public_url', None) or "http://localhost:3002" + public_url = ( + getattr(self.args, "langfuse_public_url", None) + or "http://localhost:3002" + ) self._save_langfuse_public_url(public_url) source = "external" if "langfuse-web" not in langfuse_host else "local" - self.console.print(f"[green][SUCCESS][/green] LangFuse auto-configured ({source})") + self.console.print( + f"[green][SUCCESS][/green] LangFuse auto-configured ({source})" + ) self.console.print(f"[blue][INFO][/blue] Host: {langfuse_host}") self.console.print(f"[blue][INFO][/blue] Public URL: {public_url}") - self.console.print(f"[blue][INFO][/blue] Public key: {self.mask_api_key(langfuse_pub)}") + self.console.print( + f"[blue][INFO][/blue] Public key: 
{self.mask_api_key(langfuse_pub)}" + ) return # Manual configuration (standalone init.py run) - self.console.print("Enable LLM tracing, observability, and prompt management with LangFuse") - self.console.print("Self-host: cd ../../extras/langfuse && docker compose up -d") + self.console.print( + "Enable LLM tracing, observability, and prompt management with LangFuse" + ) + self.console.print( + "Self-host: cd ../../extras/langfuse && docker compose up -d" + ) self.console.print() try: @@ -748,52 +1033,68 @@ def setup_network(self): def setup_https(self): """Configure HTTPS settings for microphone access""" # Check if HTTPS configuration provided via command line - if hasattr(self.args, 'enable_https') and self.args.enable_https: + if hasattr(self.args, "enable_https") and self.args.enable_https: enable_https = True - server_ip = getattr(self.args, 'server_ip', 'localhost') - self.console.print(f"[green]✅[/green] HTTPS: {server_ip} (configured via wizard)") + server_ip = getattr(self.args, "server_ip", "localhost") + self.console.print( + f"[green]✅[/green] HTTPS: {server_ip} (configured via wizard)" + ) else: # Interactive configuration self.print_section("HTTPS Configuration (Optional)") try: - enable_https = Confirm.ask("Enable HTTPS for microphone access?", default=False) + enable_https = Confirm.ask( + "Enable HTTPS for microphone access?", default=False + ) except EOFError: self.console.print("Using default: No") enable_https = False if enable_https: - self.console.print("[blue][INFO][/blue] HTTPS enables microphone access in browsers") + self.console.print( + "[blue][INFO][/blue] HTTPS enables microphone access in browsers" + ) # Try to auto-detect Tailscale address ts_dns, ts_ip = detect_tailscale_info() if ts_dns: - self.console.print(f"[green][AUTO-DETECTED][/green] Tailscale DNS: {ts_dns}") + self.console.print( + f"[green][AUTO-DETECTED][/green] Tailscale DNS: {ts_dns}" + ) if ts_ip: - self.console.print(f"[green][AUTO-DETECTED][/green] Tailscale IP: 
{ts_ip}") + self.console.print( + f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}" + ) default_address = ts_dns elif ts_ip: - self.console.print(f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}") + self.console.print( + f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}" + ) default_address = ts_ip else: self.console.print("[blue][INFO][/blue] Tailscale not detected") - self.console.print("[blue][INFO][/blue] To find your Tailscale address: tailscale status --json | jq -r '.Self.DNSName'") + self.console.print( + "[blue][INFO][/blue] To find your Tailscale address: tailscale status --json | jq -r '.Self.DNSName'" + ) default_address = "localhost" - self.console.print("[blue][INFO][/blue] For local-only access, use 'localhost'") + self.console.print( + "[blue][INFO][/blue] For local-only access, use 'localhost'" + ) # Use the new masked prompt function (not masked for IP, but shows existing) server_ip = self.prompt_with_existing_masked( prompt_text="Server IP/Domain for SSL certificate", env_key="SERVER_IP", - placeholders=['localhost', 'your-server-ip-here'], + placeholders=["localhost", "your-server-ip-here"], is_password=False, - default=default_address + default=default_address, ) - + if enable_https: - + # Generate SSL certificates self.console.print("[blue][INFO][/blue] Generating SSL certificates...") # Use path relative to this script's directory @@ -802,17 +1103,32 @@ def setup_https(self): if ssl_script.exists(): try: # Run from the backend directory so paths work correctly - subprocess.run([str(ssl_script), server_ip], check=True, cwd=str(script_dir), timeout=180) - self.console.print("[green][SUCCESS][/green] SSL certificates generated") + subprocess.run( + [str(ssl_script), server_ip], + check=True, + cwd=str(script_dir), + timeout=180, + ) + self.console.print( + "[green][SUCCESS][/green] SSL certificates generated" + ) except subprocess.TimeoutExpired: - self.console.print("[yellow][WARNING][/yellow] SSL certificate generation timed 
out after 3 minutes") + self.console.print( + "[yellow][WARNING][/yellow] SSL certificate generation timed out after 3 minutes" + ) except subprocess.CalledProcessError: - self.console.print("[yellow][WARNING][/yellow] SSL certificate generation failed") + self.console.print( + "[yellow][WARNING][/yellow] SSL certificate generation failed" + ) else: - self.console.print(f"[yellow][WARNING][/warning] SSL script not found at {ssl_script}") + self.console.print( + f"[yellow][WARNING][/yellow] SSL script not found at {ssl_script}" + ) # Generate Caddyfile from template - self.console.print("[blue][INFO][/blue] Creating Caddyfile configuration...") + self.console.print( + "[blue][INFO][/blue] Creating Caddyfile configuration..." + ) caddyfile_template = script_dir / "Caddyfile.template" caddyfile_path = script_dir / "Caddyfile" @@ -820,32 +1136,50 @@ try: # Check if Caddyfile exists as a directory (common issue) if caddyfile_path.exists() and caddyfile_path.is_dir(): - self.console.print("[red]❌ ERROR: 'Caddyfile' exists as a directory![/red]") - self.console.print("[yellow] Please remove it manually:[/yellow]") - self.console.print(f"[yellow] rm -rf {caddyfile_path}[/yellow]") - self.console.print("[red] HTTPS will NOT work without a proper Caddyfile![/red]") + self.console.print( + "[red]❌ ERROR: 'Caddyfile' exists as a directory![/red]" + ) + self.console.print( + "[yellow] Please remove it manually:[/yellow]" + ) + self.console.print( + f"[yellow] rm -rf {caddyfile_path}[/yellow]" + ) + self.console.print( + "[red] HTTPS will NOT work without a proper Caddyfile![/red]" + ) self.config["HTTPS_ENABLED"] = "false" else: - with open(caddyfile_template, 'r') as f: + with open(caddyfile_template, "r") as f: caddyfile_content = f.read() # Replace TAILSCALE_IP with server_ip - caddyfile_content = caddyfile_content.replace('TAILSCALE_IP', server_ip) + caddyfile_content = caddyfile_content.replace( + "TAILSCALE_IP", server_ip + ) - with 
open(caddyfile_path, 'w') as f: + with open(caddyfile_path, "w") as f: f.write(caddyfile_content) - self.console.print(f"[green][SUCCESS][/green] Caddyfile created for: {server_ip}") + self.console.print( + f"[green][SUCCESS][/green] Caddyfile created for: {server_ip}" + ) self.config["HTTPS_ENABLED"] = "true" self.config["SERVER_IP"] = server_ip except Exception as e: - self.console.print(f"[red]❌ ERROR: Caddyfile generation failed: {e}[/red]") - self.console.print("[red] HTTPS will NOT work without a proper Caddyfile![/red]") + self.console.print( + f"[red]❌ ERROR: Caddyfile generation failed: {e}[/red]" + ) + self.console.print( + "[red] HTTPS will NOT work without a proper Caddyfile![/red]" + ) self.config["HTTPS_ENABLED"] = "false" else: self.console.print("[red]❌ ERROR: Caddyfile.template not found[/red]") - self.console.print("[red] HTTPS will NOT work without a proper Caddyfile![/red]") + self.console.print( + "[red] HTTPS will NOT work without a proper Caddyfile![/red]" + ) self.config["HTTPS_ENABLED"] = "false" else: self.config["HTTPS_ENABLED"] = "false" @@ -863,7 +1197,9 @@ def generate_env_file(self): shutil.copy2(env_template, env_path) self.console.print("[blue][INFO][/blue] Copied .env.template to .env") else: - self.console.print("[yellow][WARNING][/yellow] .env.template not found, creating new .env") + self.console.print( + "[yellow][WARNING][/yellow] .env.template not found, creating new .env" + ) env_path.touch(mode=0o600) # Update configured values using set_key @@ -875,24 +1211,35 @@ def generate_env_file(self): # Ensure secure permissions os.chmod(env_path, 0o600) - self.console.print("[green][SUCCESS][/green] .env file configured successfully with secure permissions") + self.console.print( + "[green][SUCCESS][/green] .env file configured successfully with secure permissions" + ) # Note: config.yml is automatically saved by ConfigManager when updates are made - self.console.print("[blue][INFO][/blue] Configuration saved to config.yml and .env 
(via ConfigManager)") + self.console.print( + "[blue][INFO][/blue] Configuration saved to config.yml and .env (via ConfigManager)" + ) def copy_config_templates(self): """Copy other configuration files""" - if not Path("diarization_config.json").exists() and Path("diarization_config.json.template").exists(): + if ( + not Path("diarization_config.json").exists() + and Path("diarization_config.json.template").exists() + ): shutil.copy2("diarization_config.json.template", "diarization_config.json") - self.console.print("[green][SUCCESS][/green] diarization_config.json created") + self.console.print( + "[green][SUCCESS][/green] diarization_config.json created" + ) def show_summary(self): """Show configuration summary""" self.print_section("Configuration Summary") self.console.print() - self.console.print(f"✅ Admin Account: {self.config.get('ADMIN_EMAIL', 'Not configured')}") + self.console.print( + f"✅ Admin Account: {self.config.get('ADMIN_EMAIL', 'Not configured')}" + ) # Get current config from ConfigManager (single source of truth) config_yml = self.config_manager.get_full_config() @@ -901,10 +1248,16 @@ def show_summary(self): stt_default = config_yml.get("defaults", {}).get("stt", "not set") stt_model = next( (m for m in config_yml.get("models", []) if m.get("name") == stt_default), - None + None, + ) + stt_provider = ( + stt_model.get("model_provider", "unknown") + if stt_model + else "not configured" + ) + self.console.print( + f"✅ Transcription: {stt_provider} ({stt_default}) - config.yml" ) - stt_provider = stt_model.get("model_provider", "unknown") if stt_model else "not configured" - self.console.print(f"✅ Transcription: {stt_provider} ({stt_default}) - config.yml") # Show LLM config from config.yml llm_default = config_yml.get("defaults", {}).get("llm", "not set") @@ -929,13 +1282,13 @@ def show_summary(self): self.console.print(f"✅ Knowledge Graph: Enabled ({neo4j_host})") # Auto-determine URLs based on HTTPS configuration - if 
self.config.get('HTTPS_ENABLED') == 'true': - server_ip = self.config.get('SERVER_IP', 'localhost') + if self.config.get("HTTPS_ENABLED") == "true": + server_ip = self.config.get("SERVER_IP", "localhost") self.console.print(f"✅ Backend URL: https://{server_ip}/") self.console.print(f"✅ Dashboard URL: https://{server_ip}/") else: - backend_port = self.config.get('BACKEND_PUBLIC_PORT', '8000') - webui_port = self.config.get('WEBUI_PORT', '5173') + backend_port = self.config.get("BACKEND_PUBLIC_PORT", "8000") + webui_port = self.config.get("WEBUI_PORT", "5173") self.console.print(f"✅ Backend URL: http://localhost:{backend_port}") self.console.print(f"✅ Dashboard URL: http://localhost:{webui_port}") @@ -950,40 +1303,52 @@ def show_next_steps(self): self.console.print("1. Start the main services:") self.console.print(" [cyan]docker compose up --build -d[/cyan]") self.console.print() - + # Auto-determine URLs for next steps - if self.config.get('HTTPS_ENABLED') == 'true': - server_ip = self.config.get('SERVER_IP', 'localhost') + if self.config.get("HTTPS_ENABLED") == "true": + server_ip = self.config.get("SERVER_IP", "localhost") self.console.print("2. Access the dashboard:") self.console.print(f" [cyan]https://{server_ip}/[/cyan]") self.console.print() self.console.print("3. Check service health:") self.console.print(f" [cyan]curl -k https://{server_ip}/health[/cyan]") else: - webui_port = self.config.get('WEBUI_PORT', '5173') - backend_port = self.config.get('BACKEND_PUBLIC_PORT', '8000') + webui_port = self.config.get("WEBUI_PORT", "5173") + backend_port = self.config.get("BACKEND_PUBLIC_PORT", "8000") self.console.print("2. Access the dashboard:") self.console.print(f" [cyan]http://localhost:{webui_port}[/cyan]") self.console.print() self.console.print("3. 
Check service health:") - self.console.print(f" [cyan]curl http://localhost:{backend_port}/health[/cyan]") + self.console.print( + f" [cyan]curl http://localhost:{backend_port}/health[/cyan]" + ) if self.config.get("MEMORY_PROVIDER") == "openmemory_mcp": self.console.print() self.console.print("4. Start OpenMemory MCP:") - self.console.print(" [cyan]cd ../../extras/openmemory-mcp && docker compose up -d[/cyan]") + self.console.print( + " [cyan]cd ../../extras/openmemory-mcp && docker compose up -d[/cyan]" + ) if self.config.get("TRANSCRIPTION_PROVIDER") == "offline": self.console.print() self.console.print("5. Start Parakeet ASR:") - self.console.print(" [cyan]cd ../../extras/asr-services && docker compose up parakeet -d[/cyan]") + self.console.print( + " [cyan]cd ../../extras/asr-services && docker compose up parakeet -d[/cyan]" + ) def run(self): """Run the complete setup process""" self.print_header("🚀 Chronicle Interactive Setup") - self.console.print("This wizard will help you configure Chronicle with all necessary services.") - self.console.print("[dim]Safe to run again — it backs up your config and preserves previous values.[/dim]") - self.console.print("[dim]When unsure, just press Enter — the defaults will work.[/dim]") + self.console.print( + "This wizard will help you configure Chronicle with all necessary services." 
+ ) + self.console.print( + "[dim]Safe to run again — it backs up your config and preserves previous values.[/dim]" + ) + self.console.print( + "[dim]When unsure, just press Enter — the defaults will work.[/dim]" + ) self.console.print() try: @@ -1018,7 +1383,9 @@ def run(self): self.console.print() self.console.print("📝 [bold]Configuration files updated:[/bold]") self.console.print(f" • .env - API keys and environment variables") - self.console.print(f" • ../../config/config.yml - Model and memory provider configuration") + self.console.print( + f" • ../../config/config.yml - Model and memory provider configuration" + ) self.console.print() self.console.print("For detailed documentation, see:") self.console.print(" • Docs/quickstart.md") @@ -1037,39 +1404,68 @@ def run(self): def main(): """Main entry point""" parser = argparse.ArgumentParser(description="Chronicle Advanced Backend Setup") - parser.add_argument("--speaker-service-url", - help="Speaker Recognition service URL (default: prompt user)") - parser.add_argument("--parakeet-asr-url", - help="Parakeet ASR service URL (default: prompt user)") - parser.add_argument("--transcription-provider", - choices=["deepgram", "parakeet", "vibevoice", "qwen3-asr", "smallest", "none"], - help="Transcription provider (default: prompt user)") - parser.add_argument("--enable-https", action="store_true", - help="Enable HTTPS configuration (default: prompt user)") - parser.add_argument("--server-ip", - help="Server IP/domain for SSL certificate (default: prompt user)") - parser.add_argument("--enable-obsidian", action="store_true", - help="Enable Obsidian/Neo4j integration (default: prompt user)") - parser.add_argument("--enable-knowledge-graph", action="store_true", - help="Enable Knowledge Graph entity extraction (default: prompt user)") - parser.add_argument("--neo4j-password", - help="Neo4j password (default: prompt user)") - parser.add_argument("--ts-authkey", - help="Tailscale auth key for Docker integration (default: 
prompt user)") - parser.add_argument("--langfuse-public-key", - help="LangFuse project public key (from langfuse init or external)") - parser.add_argument("--langfuse-secret-key", - help="LangFuse project secret key (from langfuse init or external)") - parser.add_argument("--langfuse-host", - help="LangFuse host URL (default: http://langfuse-web:3000 for local)") - parser.add_argument("--langfuse-public-url", - help="LangFuse browser-accessible URL for deep-links (default: http://localhost:3002)") - parser.add_argument("--streaming-provider", - choices=["deepgram", "smallest", "qwen3-asr"], - help="Streaming provider when different from batch (enables batch re-transcription)") + parser.add_argument( + "--speaker-service-url", + help="Speaker Recognition service URL (default: prompt user)", + ) + parser.add_argument( + "--parakeet-asr-url", help="Parakeet ASR service URL (default: prompt user)" + ) + parser.add_argument( + "--transcription-provider", + choices=["deepgram", "parakeet", "vibevoice", "qwen3-asr", "smallest", "none"], + help="Transcription provider (default: prompt user)", + ) + parser.add_argument( + "--enable-https", + action="store_true", + help="Enable HTTPS configuration (default: prompt user)", + ) + parser.add_argument( + "--server-ip", + help="Server IP/domain for SSL certificate (default: prompt user)", + ) + parser.add_argument( + "--enable-obsidian", + action="store_true", + help="Enable Obsidian/Neo4j integration (default: prompt user)", + ) + parser.add_argument( + "--enable-knowledge-graph", + action="store_true", + help="Enable Knowledge Graph entity extraction (default: prompt user)", + ) + parser.add_argument( + "--neo4j-password", help="Neo4j password (default: prompt user)" + ) + parser.add_argument( + "--ts-authkey", + help="Tailscale auth key for Docker integration (default: prompt user)", + ) + parser.add_argument( + "--langfuse-public-key", + help="LangFuse project public key (from langfuse init or external)", + ) + 
parser.add_argument( + "--langfuse-secret-key", + help="LangFuse project secret key (from langfuse init or external)", + ) + parser.add_argument( + "--langfuse-host", + help="LangFuse host URL (default: http://langfuse-web:3000 for local)", + ) + parser.add_argument( + "--langfuse-public-url", + help="LangFuse browser-accessible URL for deep-links (default: http://localhost:3002)", + ) + parser.add_argument( + "--streaming-provider", + choices=["deepgram", "smallest", "qwen3-asr"], + help="Streaming provider when different from batch (enables batch re-transcription)", + ) args = parser.parse_args() - + setup = ChronicleSetup(args) setup.run() diff --git a/backends/advanced/src/advanced_omi_backend/app_factory.py b/backends/advanced/src/advanced_omi_backend/app_factory.py index 6083de97..c1d56fed 100644 --- a/backends/advanced/src/advanced_omi_backend/app_factory.py +++ b/backends/advanced/src/advanced_omi_backend/app_factory.py @@ -7,6 +7,7 @@ import asyncio import logging +import time from contextlib import asynccontextmanager from pathlib import Path @@ -122,10 +123,14 @@ async def initialize_openmemory_user() -> None: async def lifespan(app: FastAPI): """Manage application lifespan events.""" config = get_app_config() + startup_start = time.monotonic() # Startup application_logger.info("Starting application...") + # ── Phase 1 (sequential — dependencies) ────────────────────────── + phase_start = time.monotonic() + # Initialize Beanie for all document models try: from beanie import init_beanie @@ -151,200 +156,258 @@ async def lifespan(app: FastAPI): application_logger.error(f"Failed to initialize Beanie: {e}") raise - # Create admin user if needed + # Create admin user if needed (requires Beanie) try: await create_admin_user_if_needed() except Exception as e: application_logger.error(f"Failed to create admin user: {e}") - # Don't raise here as this is not critical for startup - # Initialize Redis connection for RQ - try: - from 
advanced_omi_backend.controllers.queue_controller import redis_conn + application_logger.info( + f"Phase 1 (Beanie + admin) completed in {time.monotonic() - phase_start:.2f}s" + ) - redis_conn.ping() - application_logger.info("Redis connection established for RQ") - application_logger.info( - "RQ workers can be started with: rq worker transcription memory default" - ) - except Exception as e: - application_logger.error(f"Failed to connect to Redis for RQ: {e}") - application_logger.warning( - "RQ queue system will not be available - check Redis connection" - ) + # ── Phase 2 (parallel — all independent) ───────────────────────── + phase_start = time.monotonic() - # Initialize BackgroundTaskManager (must happen before any code path uses it) - try: - task_manager = init_task_manager() - await task_manager.start() - application_logger.info("BackgroundTaskManager initialized and started") - except Exception as e: - application_logger.error(f"Failed to initialize task manager: {e}") - raise # Task manager is essential + async def _init_redis_rq(): + try: + from advanced_omi_backend.controllers.queue_controller import redis_conn - # Initialize ClientManager eagerly (prevents lazy race on first WebSocket connect) - get_client_manager() - application_logger.info("ClientManager initialized") + redis_conn.ping() + application_logger.info("Redis connection established for RQ") + except Exception as e: + application_logger.error(f"Failed to connect to Redis for RQ: {e}") + application_logger.warning( + "RQ queue system will not be available - check Redis connection" + ) - # Initialize OTEL/Galileo if configured (before LLM client so instrumentor patches OpenAI first) - try: - from advanced_omi_backend.observability.otel_setup import init_otel + async def _init_task_manager(): + try: + tm = init_task_manager() + await tm.start() + application_logger.info("BackgroundTaskManager initialized and started") + except Exception as e: + application_logger.error(f"Failed to initialize 
task manager: {e}") + raise # Task manager is essential - init_otel() - except Exception as e: - application_logger.warning(f"OTEL initialization skipped: {e}") + async def _init_client_manager(): + get_client_manager() + application_logger.info("ClientManager initialized") - # Initialize prompt registry with defaults; seed into LangFuse in background - try: - from advanced_omi_backend.prompt_defaults import register_all_defaults - from advanced_omi_backend.prompt_registry import get_prompt_registry + async def _init_otel(): + try: + from advanced_omi_backend.observability.otel_setup import init_otel - prompt_registry = get_prompt_registry() - register_all_defaults(prompt_registry) - application_logger.info( - f"Prompt registry initialized with {len(prompt_registry._defaults)} defaults" - ) + init_otel() + except Exception as e: + application_logger.warning(f"OTEL initialization skipped: {e}") - # Seed prompts in background — Langfuse may not be ready at startup - async def _deferred_seed(): - await asyncio.sleep(10) - await prompt_registry.seed_prompts() + async def _init_prompt_registry(): + try: + from advanced_omi_backend.prompt_defaults import register_all_defaults + from advanced_omi_backend.prompt_registry import get_prompt_registry - asyncio.create_task(_deferred_seed()) - except Exception as e: - application_logger.warning(f"Prompt registry initialization failed: {e}") + registry = get_prompt_registry() + register_all_defaults(registry) + application_logger.info( + f"Prompt registry initialized with {len(registry._defaults)} defaults" + ) + except Exception as e: + application_logger.warning(f"Prompt registry initialization failed: {e}") + + await asyncio.gather( + _init_redis_rq(), + _init_task_manager(), + _init_client_manager(), + _init_otel(), + _init_prompt_registry(), + ) - # Initialize LLM client eagerly (catch config errors at startup, not on first request) - try: - from advanced_omi_backend.llm_client import get_llm_client + 
application_logger.info( + f"Phase 2 (Redis/TaskMgr/ClientMgr/OTEL/Prompts) completed in {time.monotonic() - phase_start:.2f}s" + ) - get_llm_client() - application_logger.info("LLM client initialized from config.yml") - except Exception as e: - application_logger.warning(f"LLM client initialization deferred: {e}") + # ── Phase 3 (parallel — OTEL done, safe for LLM patching) ──────── + phase_start = time.monotonic() - # Initialize audio stream service for Redis Streams - try: - audio_service = get_audio_stream_service() - await audio_service.connect() - application_logger.info("Audio stream service connected to Redis Streams") - application_logger.info( - "Audio stream workers can be started with: python -m advanced_omi_backend.workers.audio_stream_worker" - ) - except Exception as e: - application_logger.error(f"Failed to connect audio stream service: {e}") - application_logger.warning( - "Redis Streams audio processing will not be available" - ) + async def _init_llm_client(): + try: + from advanced_omi_backend.llm_client import get_llm_client - # Initialize Redis client for audio streaming producer (used by WebSocket handlers) - try: - app.state.redis_audio_stream = await redis.from_url( - config.redis_url, encoding="utf-8", decode_responses=False - ) - from advanced_omi_backend.services.audio_stream import AudioStreamProducer + get_llm_client() + application_logger.info("LLM client initialized from config.yml") + except Exception as e: + application_logger.warning(f"LLM client initialization deferred: {e}") - app.state.audio_stream_producer = AudioStreamProducer( - app.state.redis_audio_stream - ) - application_logger.info( - "✅ Redis client for audio streaming producer initialized" - ) + async def _init_audio_stream_service(): + try: + audio_service = get_audio_stream_service() + await audio_service.connect() + application_logger.info("Audio stream service connected to Redis Streams") + except Exception as e: + application_logger.error(f"Failed to connect 
audio stream service: {e}") + application_logger.warning( + "Redis Streams audio processing will not be available" + ) - # Initialize ClientManager Redis for cross-container client→user mapping - from advanced_omi_backend.client_manager import ( - initialize_redis_for_client_manager, - ) + async def _init_redis_audio_producer(): + try: + app.state.redis_audio_stream = await redis.from_url( + config.redis_url, encoding="utf-8", decode_responses=False + ) + from advanced_omi_backend.services.audio_stream import AudioStreamProducer - initialize_redis_for_client_manager(config.redis_url) + app.state.audio_stream_producer = AudioStreamProducer( + app.state.redis_audio_stream + ) + application_logger.info( + "Redis client for audio streaming producer initialized" + ) - except Exception as e: - application_logger.error( - f"Failed to initialize Redis client for audio streaming: {e}", exc_info=True + from advanced_omi_backend.client_manager import ( + initialize_redis_for_client_manager, + ) + + initialize_redis_for_client_manager(config.redis_url) + except Exception as e: + application_logger.error( + f"Failed to initialize Redis client for audio streaming: {e}", + exc_info=True, + ) + application_logger.warning("Audio streaming producer will not be available") + + async def _deferred_prompt_seed(): + """Seed prompts into Langfuse with retry backoff.""" + try: + from advanced_omi_backend.prompt_registry import get_prompt_registry + + registry = get_prompt_registry() + except Exception: + return + + backoff_delays = [0, 2, 4, 8, 16, 32] + for delay in backoff_delays: + if delay: + await asyncio.sleep(delay) + try: + await registry.seed_prompts() + application_logger.info("Prompt seeding to Langfuse completed") + return + except Exception as e: + application_logger.debug( + f"Prompt seeding attempt failed (next retry in {delay}s): {e}" + ) + application_logger.warning( + "Prompt seeding to Langfuse failed after all retries" ) - application_logger.warning("Audio streaming 
producer will not be available") - # Skip memory service pre-initialization to avoid blocking FastAPI startup - # Memory service will be lazily initialized when first used + await asyncio.gather( + _init_llm_client(), + _init_audio_stream_service(), + _init_redis_audio_producer(), + ) + + # Launch deferred prompt seeding as a fire-and-forget background task + asyncio.create_task(_deferred_prompt_seed()) + application_logger.info( - "Memory service will be initialized on first use (lazy loading)" + f"Phase 3 (LLM/AudioStream/RedisProducer) completed in {time.monotonic() - phase_start:.2f}s" ) - # Register OpenMemory user if using openmemory_mcp provider - await initialize_openmemory_user() + # ── Phase 4 (parallel — all independent) ───────────────────────── + phase_start = time.monotonic() - # Start cron scheduler (requires Redis to be available) - try: - from advanced_omi_backend.cron_scheduler import get_scheduler, register_cron_job - from advanced_omi_backend.workers.finetuning_jobs import ( - run_asr_finetuning_job, - run_asr_jargon_extraction_job, - run_speaker_finetuning_job, - ) - from advanced_omi_backend.workers.prompt_optimization_jobs import ( - run_prompt_optimization_job, - ) + application_logger.info( + "Memory service will be initialized on first use (lazy loading)" + ) - register_cron_job("speaker_finetuning", run_speaker_finetuning_job) - register_cron_job("asr_finetuning", run_asr_finetuning_job) - register_cron_job("asr_jargon_extraction", run_asr_jargon_extraction_job) - register_cron_job("prompt_optimization", run_prompt_optimization_job) + async def _init_openmemory(): + await initialize_openmemory_user() - scheduler = get_scheduler() - await scheduler.start() - application_logger.info("Cron scheduler started") - except Exception as e: - application_logger.warning(f"Cron scheduler failed to start: {e}") + async def _init_cron_scheduler(): + try: + from advanced_omi_backend.cron_scheduler import ( + get_scheduler, + register_cron_job, + ) + 
from advanced_omi_backend.workers.annotation_jobs import ( + surface_error_suggestions, + ) + from advanced_omi_backend.workers.finetuning_jobs import ( + run_asr_finetuning_job, + run_asr_jargon_extraction_job, + run_speaker_finetuning_job, + ) + from advanced_omi_backend.workers.prompt_optimization_jobs import ( + run_prompt_optimization_job, + ) - # SystemTracker is used for monitoring and debugging - application_logger.info("Using SystemTracker for monitoring and debugging") + register_cron_job("speaker_finetuning", run_speaker_finetuning_job) + register_cron_job("asr_finetuning", run_asr_finetuning_job) + register_cron_job("asr_jargon_extraction", run_asr_jargon_extraction_job) + register_cron_job("prompt_optimization", run_prompt_optimization_job) + register_cron_job("annotation_suggestions", surface_error_suggestions) - # Initialize plugins using plugin service - try: - from advanced_omi_backend.services.plugin_service import ( - init_plugin_router, - set_plugin_router, - ) + scheduler = get_scheduler() + await scheduler.start() + application_logger.info("Cron scheduler started") + except Exception as e: + application_logger.warning(f"Cron scheduler failed to start: {e}") - plugin_router = init_plugin_router() - - if plugin_router: - # Initialize async resources for each enabled plugin - for plugin_id, plugin in plugin_router.plugins.items(): - if plugin.enabled: - try: - await plugin.initialize() - plugin_router.mark_plugin_initialized(plugin_id) - application_logger.info(f"✅ Plugin '{plugin_id}' initialized") - except Exception as e: - plugin_router.mark_plugin_failed(plugin_id, str(e)) - application_logger.error( - f"Failed to initialize plugin '{plugin_id}': {e}", - exc_info=True, - ) - - health = plugin_router.get_health_summary() - application_logger.info( - f"Plugins initialized: {health['initialized']}/{health['total']} active" - + (f", {health['failed']} failed" if health["failed"] else "") + async def _init_plugins(): + try: + from 
advanced_omi_backend.services.plugin_service import ( + init_plugin_router, + set_plugin_router, ) - # Store in app state for API access - app.state.plugin_router = plugin_router - # Register with plugin service for worker access - set_plugin_router(plugin_router) - else: - application_logger.info("No plugins configured") + plugin_router = init_plugin_router() + + if plugin_router: + for plugin_id, plugin in plugin_router.plugins.items(): + if plugin.enabled: + try: + await plugin.initialize() + plugin_router.mark_plugin_initialized(plugin_id) + application_logger.info(f"Plugin '{plugin_id}' initialized") + except Exception as e: + plugin_router.mark_plugin_failed(plugin_id, str(e)) + application_logger.error( + f"Failed to initialize plugin '{plugin_id}': {e}", + exc_info=True, + ) + + health = plugin_router.get_health_summary() + application_logger.info( + f"Plugins initialized: {health['initialized']}/{health['total']} active" + + (f", {health['failed']} failed" if health["failed"] else "") + ) + + app.state.plugin_router = plugin_router + set_plugin_router(plugin_router) + else: + application_logger.info("No plugins configured") + app.state.plugin_router = None + + except Exception as e: + application_logger.error( + f"Failed to initialize plugin system: {e}", exc_info=True + ) app.state.plugin_router = None - except Exception as e: - application_logger.error( - f"Failed to initialize plugin system: {e}", exc_info=True - ) - app.state.plugin_router = None + await asyncio.gather( + _init_openmemory(), + _init_cron_scheduler(), + _init_plugins(), + ) + + application_logger.info( + f"Phase 4 (OpenMemory/Cron/Plugins) completed in {time.monotonic() - phase_start:.2f}s" + ) + total_startup = time.monotonic() - startup_start application_logger.info( - "Application ready - using application-level processing architecture." + f"Application ready in {total_startup:.2f}s - using application-level processing architecture." 
) logger.info("App ready") diff --git a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py index 1bf41dfc..c2f1ad5d 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py @@ -46,6 +46,25 @@ audio_logger = logging.getLogger("audio_processing") +async def _get_conversation_or_error(conversation_id: str, user: User): + """Fetch a conversation and validate user access. + + Returns (conversation, None) on success, or (None, error_response) on failure. + """ + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) + if not conversation: + return None, JSONResponse( + status_code=404, content={"error": "Conversation not found"} + ) + if not user.is_superuser and conversation.user_id != str(user.user_id): + return None, JSONResponse( + status_code=403, content={"error": "Access forbidden"} + ) + return conversation, None + + async def close_current_conversation(client_id: str, user: User): """Close the current conversation for a specific client. 
@@ -112,18 +131,9 @@ async def close_current_conversation(client_id: str, user: User): async def get_conversation(conversation_id: str, user: User): """Get a single conversation with full transcript details.""" try: - # Find the conversation using Beanie - conversation = await Conversation.find_one( - Conversation.conversation_id == conversation_id - ) - if not conversation: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - # Check ownership for non-admin users - if not user.is_superuser and conversation.user_id != str(user.user_id): - return JSONResponse(status_code=403, content={"error": "Access forbidden"}) + conversation, error = await _get_conversation_or_error(conversation_id, user) + if error: + return error # Build response with explicit curated fields response = { @@ -184,16 +194,9 @@ async def get_conversation(conversation_id: str, user: User): async def get_conversation_memories(conversation_id: str, user: User, limit: int = 100): """Get memories extracted from a specific conversation.""" try: - conversation = await Conversation.find_one( - Conversation.conversation_id == conversation_id - ) - if not conversation: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - if not user.is_superuser and conversation.user_id != str(user.user_id): - return JSONResponse(status_code=403, content={"error": "Access forbidden"}) + conversation, error = await _get_conversation_or_error(conversation_id, user) + if error: + return error memory_service = get_memory_service() memories = await memory_service.get_memories_by_source( @@ -671,29 +674,13 @@ async def delete_conversation( f"Attempting to {'permanently ' if permanent else ''}delete conversation: {masked_id}" ) - # Find the conversation using Beanie - conversation = await Conversation.find_one( - Conversation.conversation_id == conversation_id - ) - - if not conversation: - return JSONResponse( - status_code=404, - content={"error": 
f"Conversation '{conversation_id}' not found"}, - ) - - # Check ownership for non-admin users - if not user.is_superuser and conversation.user_id != str(user.user_id): - logger.warning( - f"User {user.user_id} attempted to delete conversation {conversation_id} without permission" - ) - return JSONResponse( - status_code=403, - content={ - "error": "Access forbidden. You can only delete your own conversations.", - "details": f"Conversation '{conversation_id}' does not belong to your account.", - }, - ) + conversation, error = await _get_conversation_or_error(conversation_id, user) + if error: + if error.status_code == 403: + logger.warning( + f"User {user.user_id} attempted to delete conversation {conversation_id} without permission" + ) + return error # Hard delete (admin only, permanent flag) if permanent and user.is_superuser: @@ -719,18 +706,9 @@ async def restore_conversation(conversation_id: str, user: User) -> JSONResponse user: Requesting user """ try: - conversation = await Conversation.find_one( - Conversation.conversation_id == conversation_id - ) - - if not conversation: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - # Permission check - if not user.is_superuser and conversation.user_id != str(user.user_id): - return JSONResponse(status_code=403, content={"error": "Access denied"}) + conversation, error = await _get_conversation_or_error(conversation_id, user) + if error: + return error if not conversation.deleted: return JSONResponse( @@ -933,16 +911,9 @@ def _enqueue_speaker_reprocessing_chain( async def toggle_star(conversation_id: str, user: User): """Toggle the starred/favorite status of a conversation.""" try: - conversation = await Conversation.find_one( - Conversation.conversation_id == conversation_id - ) - if not conversation: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - if not user.is_superuser and conversation.user_id != str(user.user_id): - return 
JSONResponse(status_code=403, content={"error": "Access forbidden"}) + conversation, error = await _get_conversation_or_error(conversation_id, user) + if error: + return error # Toggle conversation.starred = not conversation.starred @@ -993,17 +964,9 @@ async def toggle_star(conversation_id: str, user: User): async def reprocess_orphan(conversation_id: str, user: User): """Reprocess an orphan audio session - restore if deleted and enqueue full processing chain.""" try: - conversation = await Conversation.find_one( - Conversation.conversation_id == conversation_id - ) - if not conversation: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - # Check ownership - if not user.is_superuser and conversation.user_id != str(user.user_id): - return JSONResponse(status_code=403, content={"error": "Access forbidden"}) + conversation, error = await _get_conversation_or_error(conversation_id, user) + if error: + return error # Verify audio chunks exist (check both deleted and non-deleted) total_chunks = await AudioChunkDocument.find( @@ -1068,23 +1031,11 @@ async def reprocess_orphan(conversation_id: str, user: User): async def reprocess_transcript(conversation_id: str, user: User): """Reprocess transcript for a conversation. Users can only reprocess their own conversations.""" try: - # Find the conversation using Beanie - conversation_model = await Conversation.find_one( - Conversation.conversation_id == conversation_id + conversation_model, error = await _get_conversation_or_error( + conversation_id, user ) - if not conversation_model: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - # Check ownership for non-admin users - if not user.is_superuser and conversation_model.user_id != str(user.user_id): - return JSONResponse( - status_code=403, - content={ - "error": "Access forbidden. You can only reprocess your own conversations." 
- }, - ) + if error: + return error # Get audio_uuid from conversation # Validate audio chunks exist in MongoDB @@ -1137,24 +1088,11 @@ async def reprocess_memory( ): """Reprocess memory extraction for a specific transcript version. Users can only reprocess their own conversations.""" try: - # Find the conversation using Beanie - conversation_model = await Conversation.find_one( - Conversation.conversation_id == conversation_id + conversation_model, error = await _get_conversation_or_error( + conversation_id, user ) - if not conversation_model: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - # Check ownership for non-admin users - if not user.is_superuser and conversation_model.user_id != str(user.user_id): - return JSONResponse( - status_code=403, - content={ - "error": "Access forbidden. You can only reprocess your own conversations." - }, - ) - + if error: + return error # Resolve transcript version ID (handle "active" special case) error, transcript_version_id, transcript_version = _resolve_transcript_version( conversation_model, transcript_version_id @@ -1205,23 +1143,11 @@ async def reprocess_speakers( """ try: # 1. Find conversation and validate ownership - conversation_model = await Conversation.find_one( - Conversation.conversation_id == conversation_id + conversation_model, error = await _get_conversation_or_error( + conversation_id, user ) - if not conversation_model: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - # Check ownership for non-admin users - if not user.is_superuser and conversation_model.user_id != str(user.user_id): - return JSONResponse( - status_code=403, - content={ - "error": "Access forbidden. You can only reprocess your own conversations." - }, - ) - + if error: + return error # 2-3. 
Resolve source transcript version ID and find version object error, source_version_id, source_version = _resolve_transcript_version( conversation_model, transcript_version_id @@ -1349,23 +1275,11 @@ async def activate_transcript_version( ): """Activate a specific transcript version. Users can only modify their own conversations.""" try: - # Find the conversation using Beanie - conversation_model = await Conversation.find_one( - Conversation.conversation_id == conversation_id + conversation_model, error = await _get_conversation_or_error( + conversation_id, user ) - if not conversation_model: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - # Check ownership for non-admin users - if not user.is_superuser and conversation_model.user_id != str(user.user_id): - return JSONResponse( - status_code=403, - content={ - "error": "Access forbidden. You can only modify your own conversations." - }, - ) + if error: + return error # Activate the transcript version using Beanie model method success = conversation_model.set_active_transcript_version(version_id) @@ -1401,23 +1315,11 @@ async def activate_transcript_version( async def activate_memory_version(conversation_id: str, version_id: str, user: User): """Activate a specific memory version. Users can only modify their own conversations.""" try: - # Find the conversation using Beanie - conversation_model = await Conversation.find_one( - Conversation.conversation_id == conversation_id + conversation_model, error = await _get_conversation_or_error( + conversation_id, user ) - if not conversation_model: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - # Check ownership for non-admin users - if not user.is_superuser and conversation_model.user_id != str(user.user_id): - return JSONResponse( - status_code=403, - content={ - "error": "Access forbidden. You can only modify your own conversations." 
- }, - ) + if error: + return error # Activate the memory version using Beanie model method success = conversation_model.set_active_memory_version(version_id) @@ -1449,23 +1351,11 @@ async def activate_memory_version(conversation_id: str, version_id: str, user: U async def get_conversation_version_history(conversation_id: str, user: User): """Get version history for a conversation. Users can only access their own conversations.""" try: - # Find the conversation using Beanie to check ownership - conversation_model = await Conversation.find_one( - Conversation.conversation_id == conversation_id + conversation_model, error = await _get_conversation_or_error( + conversation_id, user ) - if not conversation_model: - return JSONResponse( - status_code=404, content={"error": "Conversation not found"} - ) - - # Check ownership for non-admin users - if not user.is_superuser and conversation_model.user_id != str(user.user_id): - return JSONResponse( - status_code=403, - content={ - "error": "Access forbidden. You can only access your own conversations." - }, - ) + if error: + return error # Get version history from model # Convert datetime objects to ISO strings for JSON serialization diff --git a/backends/advanced/src/advanced_omi_backend/controllers/memory_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/memory_controller.py index fe4fca88..40c1ac51 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/memory_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/memory_controller.py @@ -17,15 +17,19 @@ audio_logger = logging.getLogger("audio_processing") +def _resolve_target_user(user: User, user_id: Optional[str] = None) -> str: + """Return the effective user ID: admins may override with user_id param.""" + if user.is_superuser and user_id: + return user_id + return user.user_id + + async def get_memories(user: User, limit: int, user_id: Optional[str] = None): """Get memories. 
Users see only their own memories, admins can see all or filter by user.""" try: memory_service = get_memory_service() - # Determine which user's memories to fetch - target_user_id = user.user_id - if user.is_superuser and user_id: - target_user_id = user_id + target_user_id = _resolve_target_user(user, user_id) # Execute memory retrieval directly (now async) memories = await memory_service.get_all_memories(target_user_id, limit) @@ -40,7 +44,7 @@ async def get_memories(user: User, limit: int, user_id: Optional[str] = None): "memories": memories_dicts, "count": len(memories), "total_count": total_count, - "user_id": target_user_id + "user_id": target_user_id, } except Exception as e: @@ -50,15 +54,14 @@ async def get_memories(user: User, limit: int, user_id: Optional[str] = None): ) -async def get_memories_with_transcripts(user: User, limit: int, user_id: Optional[str] = None): +async def get_memories_with_transcripts( + user: User, limit: int, user_id: Optional[str] = None +): """Get memories with their source transcripts. 
Users see only their own memories, admins can see all or filter by user.""" try: memory_service = get_memory_service() - # Determine which user's memories to fetch - target_user_id = user.user_id - if user.is_superuser and user_id: - target_user_id = user_id + target_user_id = _resolve_target_user(user, user_id) # Execute memory retrieval directly (now async) memories_with_transcripts = await memory_service.get_memories_with_transcripts( @@ -72,25 +75,32 @@ async def get_memories_with_transcripts(user: User, limit: int, user_id: Optiona } except Exception as e: - audio_logger.error(f"Error fetching memories with transcripts: {e}", exc_info=True) + audio_logger.error( + f"Error fetching memories with transcripts: {e}", exc_info=True + ) return JSONResponse( status_code=500, content={"message": f"Error fetching memories with transcripts: {str(e)}"}, ) -async def search_memories(query: str, user: User, limit: int, score_threshold: float = 0.0, user_id: Optional[str] = None): +async def search_memories( + query: str, + user: User, + limit: int, + score_threshold: float = 0.0, + user_id: Optional[str] = None, +): """Search memories by text query. 
Users can only search their own memories, admins can search all or filter by user.""" try: memory_service = get_memory_service() - # Determine which user's memories to search - target_user_id = user.user_id - if user.is_superuser and user_id: - target_user_id = user_id + target_user_id = _resolve_target_user(user, user_id) # Execute search directly (now async) - search_results = await memory_service.search_memories(query, target_user_id, limit, score_threshold) + search_results = await memory_service.search_memories( + query, target_user_id, limit, score_threshold + ) # Convert MemoryEntry objects to dicts for JSON serialization results_dicts = [result.to_dict() for result in search_results] @@ -122,16 +132,26 @@ async def delete_memory(memory_id: str, user: User): # MemoryEntry is a dataclass, access id attribute directly memory_ids = [str(mem.id) for mem in user_memories] if memory_id not in memory_ids: - return JSONResponse(status_code=404, content={"message": "Memory not found"}) + return JSONResponse( + status_code=404, content={"message": "Memory not found"} + ) # Delete the memory - audio_logger.info(f"Deleting memory {memory_id} for user_id={user.user_id}, email={user.email}") - success = await memory_service.delete_memory(memory_id, user_id=user.user_id, user_email=user.email) + audio_logger.info( + f"Deleting memory {memory_id} for user_id={user.user_id}, email={user.email}" + ) + success = await memory_service.delete_memory( + memory_id, user_id=user.user_id, user_email=user.email + ) if success: - return JSONResponse(content={"message": f"Memory {memory_id} deleted successfully"}) + return JSONResponse( + content={"message": f"Memory {memory_id} deleted successfully"} + ) else: - return JSONResponse(status_code=404, content={"message": "Memory not found"}) + return JSONResponse( + status_code=404, content={"message": "Memory not found"} + ) except Exception as e: audio_logger.error(f"Error deleting memory: {e}", exc_info=True) @@ -146,7 +166,9 @@ async 
def add_memory(content: str, user: User, source_id: Optional[str] = None): memory_service = get_memory_service() # Use source_id or generate a unique one - memory_source_id = source_id or f"manual_{user.user_id}_{int(asyncio.get_event_loop().time())}" + memory_source_id = ( + source_id or f"manual_{user.user_id}_{int(asyncio.get_event_loop().time())}" + ) # Extract memories from content success, memory_ids = await memory_service.add_memory( @@ -156,7 +178,7 @@ async def add_memory(content: str, user: User, source_id: Optional[str] = None): user_id=user.user_id, user_email=user.email, allow_update=False, - db_helper=None + db_helper=None, ) if success: @@ -165,18 +187,19 @@ async def add_memory(content: str, user: User, source_id: Optional[str] = None): "memory_ids": memory_ids, "count": len(memory_ids), "source_id": memory_source_id, - "message": f"Successfully created {len(memory_ids)} memory/memories" + "message": f"Successfully created {len(memory_ids)} memory/memories", } else: return JSONResponse( status_code=500, - content={"success": False, "message": "Failed to create memories"} + content={"success": False, "message": "Failed to create memories"}, ) except Exception as e: audio_logger.error(f"Error adding memory: {e}", exc_info=True) return JSONResponse( - status_code=500, content={"success": False, "message": f"Error adding memory: {str(e)}"} + status_code=500, + content={"success": False, "message": f"Error adding memory: {str(e)}"}, ) @@ -225,7 +248,8 @@ async def get_all_memories_admin(user: User, limit: int): except Exception as e: audio_logger.error(f"Error fetching admin memories: {e}", exc_info=True) return JSONResponse( - status_code=500, content={"message": f"Error fetching admin memories: {str(e)}"} + status_code=500, + content={"message": f"Error fetching admin memories: {str(e)}"}, ) @@ -234,10 +258,7 @@ async def get_memory_by_id(memory_id: str, user: User, user_id: Optional[str] = try: memory_service = get_memory_service() - # Determine which 
user's memory to fetch - target_user_id = user.user_id - if user.is_superuser and user_id: - target_user_id = user_id + target_user_id = _resolve_target_user(user, user_id) # Get the specific memory memory = await memory_service.get_memory(memory_id, target_user_id) @@ -265,11 +286,15 @@ async def get_memory_by_id(memory_id: str, user: User, user_id: Optional[str] = ), } except Exception as e: - logger.warning(f"Failed to fetch source conversation {source_id}: {e}") + logger.warning( + f"Failed to fetch source conversation {source_id}: {e}" + ) return {"memory": memory_dict} else: - return JSONResponse(status_code=404, content={"message": "Memory not found"}) + return JSONResponse( + status_code=404, content={"message": "Memory not found"} + ) except Exception as e: audio_logger.error(f"Error fetching memory {memory_id}: {e}", exc_info=True) diff --git a/backends/advanced/src/advanced_omi_backend/models/annotation.py b/backends/advanced/src/advanced_omi_backend/models/annotation.py index 451d84d1..99974532 100644 --- a/backends/advanced/src/advanced_omi_backend/models/annotation.py +++ b/backends/advanced/src/advanced_omi_backend/models/annotation.py @@ -16,22 +16,26 @@ class AnnotationType(str, Enum): """Type of content being annotated.""" + MEMORY = "memory" TRANSCRIPT = "transcript" DIARIZATION = "diarization" # Speaker identification corrections ENTITY = "entity" # Knowledge graph entity corrections (name/details edits) TITLE = "title" # Conversation title corrections INSERT = "insert" # Insert new segment between existing segments + SPEECH_SUGGESTION_CORRECTION = "speech_suggestion_correction" # User-refined model suggestion (training signal triple) class AnnotationSource(str, Enum): """Origin of the annotation.""" + USER = "user" # User-created edit MODEL_SUGGESTION = "model_suggestion" # AI-generated suggestion class AnnotationStatus(str, Enum): """Lifecycle status of annotation.""" + PENDING = "pending" # Waiting for user review (suggestions) ACCEPTED = 
"accepted" # Applied to content REJECTED = "rejected" # User dismissed suggestion @@ -79,6 +83,11 @@ class Annotation(Document): entity_id: Optional[str] = None # Neo4j entity ID entity_field: Optional[str] = None # Which field was changed ("name" or "details") + # For SPEECH_SUGGESTION_CORRECTION annotations: + model_suggested_text: Optional[str] = ( + None # What AI originally suggested before user edited + ) + # For INSERT annotations: insert_after_index: Optional[int] = None # -1 = before first segment insert_text: Optional[str] = None # e.g., "[laughter]" or "wife laughed" @@ -86,17 +95,17 @@ class Annotation(Document): insert_speaker: Optional[str] = None # Speaker label for "speech" type inserts # Processed tracking (applies to ALL annotation types) - processed: bool = Field(default=False) # Whether annotation has been applied/sent to training + processed: bool = Field( + default=False + ) # Whether annotation has been applied/sent to training processed_at: Optional[datetime] = None # When annotation was processed - processed_by: Optional[str] = None # What processed it (manual, cron, apply, training, etc.) + processed_by: Optional[str] = ( + None # What processed it (manual, cron, apply, training, etc.) 
+ ) # Timestamps (Python 3.12+ compatible) - created_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc) - ) - updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc) - ) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) class Settings: name = "annotations" @@ -132,6 +141,10 @@ def is_title_annotation(self) -> bool: """Check if this is a title annotation.""" return self.annotation_type == AnnotationType.TITLE + def is_speech_suggestion_correction(self) -> bool: + """Check if this is a user-refined model suggestion.""" + return self.annotation_type == AnnotationType.SPEECH_SUGGESTION_CORRECTION + def is_pending_suggestion(self) -> bool: """Check if this is a pending AI suggestion.""" return ( @@ -145,6 +158,7 @@ def is_pending_suggestion(self) -> bool: class AnnotationCreateBase(BaseModel): """Base model for annotation creation.""" + original_text: str = "" # Optional for diarization corrected_text: str = "" # Optional for diarization status: AnnotationStatus = AnnotationStatus.ACCEPTED @@ -152,6 +166,7 @@ class AnnotationCreateBase(BaseModel): class MemoryAnnotationCreate(AnnotationCreateBase): """Create memory annotation request.""" + memory_id: str original_text: str # Required for memory annotations corrected_text: str # Required for memory annotations @@ -159,6 +174,7 @@ class MemoryAnnotationCreate(AnnotationCreateBase): class TranscriptAnnotationCreate(AnnotationCreateBase): """Create transcript annotation request.""" + conversation_id: str segment_index: int original_text: str # Required for transcript annotations @@ -167,6 +183,7 @@ class TranscriptAnnotationCreate(AnnotationCreateBase): class DiarizationAnnotationCreate(BaseModel): """Create diarization annotation request.""" + conversation_id: str segment_index: int original_speaker: str @@ -181,6 +198,7 @@ class 
EntityAnnotationCreate(BaseModel): Dual purpose: feeds both the jargon pipeline (entity name corrections = domain vocabulary the ASR should know) and the entity extraction pipeline (corrections improve future accuracy). """ + entity_id: str entity_field: str # "name" or "details" original_text: str @@ -189,6 +207,7 @@ class EntityAnnotationCreate(BaseModel): class TitleAnnotationCreate(AnnotationCreateBase): """Create title annotation request.""" + conversation_id: str original_text: str corrected_text: str @@ -196,6 +215,7 @@ class TitleAnnotationCreate(AnnotationCreateBase): class InsertAnnotationCreate(BaseModel): """Create insert annotation request (new segment between existing segments).""" + conversation_id: str insert_after_index: int # -1 = before first segment insert_text: str @@ -205,8 +225,10 @@ class InsertAnnotationCreate(BaseModel): class AnnotationUpdate(BaseModel): """Update an existing unprocessed annotation.""" + corrected_text: Optional[str] = None corrected_speaker: Optional[str] = None + model_suggested_text: Optional[str] = None insert_text: Optional[str] = None insert_segment_type: Optional[str] = None insert_speaker: Optional[str] = None @@ -214,6 +236,7 @@ class AnnotationUpdate(BaseModel): class AnnotationResponse(BaseModel): """Annotation response for API.""" + id: str annotation_type: AnnotationType user_id: str @@ -227,6 +250,7 @@ class AnnotationResponse(BaseModel): segment_start_time: Optional[float] = None entity_id: Optional[str] = None entity_field: Optional[str] = None + model_suggested_text: Optional[str] = None insert_after_index: Optional[int] = None insert_text: Optional[str] = None insert_segment_type: Optional[str] = None diff --git a/backends/advanced/src/advanced_omi_backend/prompt_defaults.py b/backends/advanced/src/advanced_omi_backend/prompt_defaults.py index 94b03d31..7e847b41 100644 --- a/backends/advanced/src/advanced_omi_backend/prompt_defaults.py +++ b/backends/advanced/src/advanced_omi_backend/prompt_defaults.py 
@@ -607,6 +607,42 @@ def register_all_defaults(registry: PromptRegistry) -> None: is_dynamic=True, ) + # ------------------------------------------------------------------ + # annotation.transcript_error_detection + # ------------------------------------------------------------------ + registry.register_default( + "annotation.transcript_error_detection", + template="""\ +You are a transcript quality reviewer. Analyze the following transcript segments \ +from a conversation and identify potential transcription errors. + +Look for: +- Misheard words (homophones, phonetically similar substitutions) +- Nonsensical phrases that are likely ASR mistakes +- Obvious hallucinations or repeated/garbled text +- Missing or extra words that break sentence meaning + +Conversation title: {{title}} + +Segments (index: speaker - text): +{{segments_text}} + +Return a JSON array of issues found. Each issue should have: +- "segment_index": the index number of the problematic segment +- "original_text": the exact text from that segment +- "corrected_text": your suggested correction +- "reason": brief explanation (e.g. 
"misheard word", "garbled text", "hallucination") + +If no issues are found, return an empty array: [] + +Return ONLY the JSON array, no other text.""", + name="Transcript Error Detection", + description="Analyzes transcript segments for ASR errors, hallucinations, and misheard words.", + category="annotation", + variables=["title", "segments_text"], + is_dynamic=True, + ) + # ------------------------------------------------------------------ # prompt_optimization.title_optimizer # ------------------------------------------------------------------ diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py index e04e6c76..43ffa212 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py @@ -9,13 +9,14 @@ from datetime import datetime, timezone from typing import List -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import JSONResponse from advanced_omi_backend.auth import current_active_user from advanced_omi_backend.models.annotation import ( Annotation, AnnotationResponse, + AnnotationSource, AnnotationStatus, AnnotationType, AnnotationUpdate, @@ -36,6 +37,103 @@ router = APIRouter(prefix="/annotations", tags=["annotations"]) +@router.get("/suggestions") +async def get_pending_suggestions( + current_user: User = Depends(current_active_user), + limit: int = Query(20, ge=1, le=100), +): + """ + Get pending AI-generated suggestions for the current user. + + Returns MODEL_SUGGESTION annotations with PENDING status, + enriched with conversation context (title, transcript snippet, + audio path) for the swipe review UI. 
+ """ + try: + annotations = ( + await Annotation.find( + Annotation.user_id == current_user.user_id, + Annotation.source == AnnotationSource.MODEL_SUGGESTION, + Annotation.status == AnnotationStatus.PENDING, + ) + .sort("-created_at") + .limit(limit) + .to_list() + ) + + if not annotations: + return [] + + # Batch-fetch conversations for context + conversation_ids = list( + {a.conversation_id for a in annotations if a.conversation_id} + ) + conversations = await Conversation.find( + {"conversation_id": {"$in": conversation_ids}}, + ).to_list() + conv_map = {c.conversation_id: c for c in conversations} + + results = [] + for a in annotations: + conv = conv_map.get(a.conversation_id) + + segment_start = None + segment_end = None + if conv and a.segment_index is not None: + transcript = conv.active_transcript + if ( + transcript + and transcript.segments + and a.segment_index < len(transcript.segments) + ): + seg = transcript.segments[a.segment_index] + segment_start = seg.start + segment_end = seg.end + + results.append( + { + "id": a.id, + "annotation_type": a.annotation_type, + "conversation_id": a.conversation_id, + "segment_index": a.segment_index, + "original_text": a.original_text, + "corrected_text": a.corrected_text, + "created_at": a.created_at.isoformat(), + "conversation_title": conv.title if conv else None, + "transcript_snippet": _get_segment_context(conv, a.segment_index), + "segment_start": segment_start, + "segment_end": segment_end, + } + ) + + return results + + except Exception as e: + logger.error(f"Error fetching suggestions: {e}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to fetch suggestions: {str(e)}" + ) + + +def _get_segment_context( + conversation, segment_index: int | None, context_size: int = 1 +) -> str | None: + """Get a snippet of transcript around the flagged segment for context.""" + if not conversation or segment_index is None: + return None + transcript = conversation.active_transcript + if not 
transcript or not transcript.segments: + return None + start = max(0, segment_index - context_size) + end = min(len(transcript.segments), segment_index + context_size + 1) + lines = [] + for i in range(start, end): + seg = transcript.segments[i] + prefix = ">>> " if i == segment_index else " " + lines.append(f"{prefix}{seg.speaker}: {seg.text}") + return "\n".join(lines) + + @router.post("/memory", response_model=AnnotationResponse) async def create_memory_annotation( annotation_data: MemoryAnnotationCreate, @@ -85,11 +183,15 @@ async def create_memory_annotation( content=annotation_data.corrected_text, user_id=current_user.user_id, ) - logger.info(f"Updated memory {annotation_data.memory_id} with corrected text") + logger.info( + f"Updated memory {annotation_data.memory_id} with corrected text" + ) except Exception as e: logger.error(f"Error updating memory: {e}") # Annotation is saved, but memory update failed - log but don't fail the request - logger.warning(f"Memory annotation {annotation.id} saved but memory update failed") + logger.warning( + f"Memory annotation {annotation.id} saved but memory update failed" + ) return AnnotationResponse.model_validate(annotation) @@ -237,7 +339,22 @@ async def update_annotation_status( annotation.updated_at = datetime.now(timezone.utc) # If accepting a pending suggestion, apply the correction - if status == AnnotationStatus.ACCEPTED and old_status == AnnotationStatus.PENDING: + if ( + status == AnnotationStatus.ACCEPTED + and old_status == AnnotationStatus.PENDING + ): + # Promote to SPEECH_SUGGESTION_CORRECTION if user edited the AI suggestion + if ( + annotation.source == AnnotationSource.MODEL_SUGGESTION + and annotation.model_suggested_text is not None + and annotation.is_transcript_annotation() + ): + annotation.annotation_type = AnnotationType.SPEECH_SUGGESTION_CORRECTION + logger.info( + f"Promoted annotation {annotation_id} to SPEECH_SUGGESTION_CORRECTION " + f"(AI suggested: {annotation.model_suggested_text!r}, 
user decided: {annotation.corrected_text!r})" + ) + if annotation.is_memory_annotation(): # Update memory try: @@ -251,8 +368,11 @@ async def update_annotation_status( except Exception as e: logger.error(f"Error applying memory suggestion: {e}") # Don't fail the status update if memory update fails - elif annotation.is_transcript_annotation(): - # Update transcript segment + elif ( + annotation.is_transcript_annotation() + or annotation.is_speech_suggestion_correction() + ): + # Update transcript segment (same logic for both TRANSCRIPT and SPEECH_SUGGESTION_CORRECTION) try: conversation = await Conversation.find_one( Conversation.conversation_id == annotation.conversation_id, @@ -260,7 +380,9 @@ async def update_annotation_status( ) if conversation: transcript = conversation.active_transcript - if transcript and annotation.segment_index < len(transcript.segments): + if transcript and annotation.segment_index < len( + transcript.segments + ): transcript.segments[annotation.segment_index].text = ( annotation.corrected_text ) @@ -286,7 +408,9 @@ async def update_annotation_status( user_id=annotation.user_id, **update_kwargs, ) - logger.info(f"Applied entity suggestion to entity {annotation.entity_id}") + logger.info( + f"Applied entity suggestion to entity {annotation.entity_id}" + ) except Exception as e: logger.error(f"Error applying entity suggestion: {e}") # Don't fail the status update if entity update fails @@ -310,7 +434,11 @@ async def update_annotation_status( await annotation.save() logger.info(f"Updated annotation {annotation_id} status to {status}") - return {"status": "updated", "annotation_id": annotation_id, "new_status": status} + return { + "status": "updated", + "annotation_id": annotation_id, + "new_status": status, + } except HTTPException: raise @@ -345,7 +473,9 @@ async def delete_annotation( raise HTTPException(status_code=404, detail="Annotation not found") if annotation.processed: - raise HTTPException(status_code=400, detail="Cannot delete a 
processed annotation") + raise HTTPException( + status_code=400, detail="Cannot delete a processed annotation" + ) await annotation.delete() logger.info(f"Deleted annotation {annotation_id}") @@ -384,10 +514,22 @@ async def update_annotation( raise HTTPException(status_code=404, detail="Annotation not found") if annotation.processed: - raise HTTPException(status_code=400, detail="Cannot update a processed annotation") + raise HTTPException( + status_code=400, detail="Cannot update a processed annotation" + ) if update_data.corrected_text is not None: + # Auto-capture AI's original suggestion before user overwrites it + if ( + annotation.source == AnnotationSource.MODEL_SUGGESTION + and annotation.model_suggested_text is None + and annotation.corrected_text + and update_data.corrected_text != annotation.corrected_text + ): + annotation.model_suggested_text = annotation.corrected_text annotation.corrected_text = update_data.corrected_text + if update_data.model_suggested_text is not None: + annotation.model_suggested_text = update_data.model_suggested_text if update_data.corrected_speaker is not None: annotation.corrected_speaker = update_data.corrected_speaker if update_data.insert_text is not None: @@ -441,7 +583,10 @@ async def create_insert_annotation( raise HTTPException(status_code=400, detail="No active transcript found") segment_count = len(active_transcript.segments) - if annotation_data.insert_after_index < -1 or annotation_data.insert_after_index >= segment_count: + if ( + annotation_data.insert_after_index < -1 + or annotation_data.insert_after_index >= segment_count + ): raise HTTPException( status_code=400, detail=f"insert_after_index must be between -1 and {segment_count - 1}", @@ -572,7 +717,9 @@ async def create_entity_annotation( user_id=current_user.user_id, **update_kwargs, ) - logger.info(f"Applied entity correction to Neo4j for entity {annotation_data.entity_id}") + logger.info( + f"Applied entity correction to Neo4j for entity 
{annotation_data.entity_id}" + ) except Exception as e: logger.error(f"Error applying entity correction to Neo4j: {e}") # Annotation is saved but Neo4j update failed — log but don't fail the request @@ -657,7 +804,9 @@ async def create_title_annotation( try: conversation.title = annotation_data.corrected_text await conversation.save() - logger.info(f"Updated title for conversation {annotation_data.conversation_id}") + logger.info( + f"Updated title for conversation {annotation_data.conversation_id}" + ) except Exception as e: logger.error(f"Error updating conversation title: {e}") # Annotation is saved but title update failed — log but don't fail the request @@ -697,7 +846,6 @@ async def get_title_annotations( ) - # === Diarization Annotation Routes === @@ -817,7 +965,10 @@ async def apply_diarization_annotations( if not annotations: return JSONResponse( - content={"message": "No pending annotations to apply", "applied_count": 0} + content={ + "message": "No pending annotations to apply", + "applied_count": 0, + } ) # Get active transcript version @@ -839,7 +990,9 @@ async def apply_diarization_annotations( key=lambda a: a.updated_at, reverse=True, ) - annotation_for_segment = annotations_for_segment[0] if annotations_for_segment else None + annotation_for_segment = ( + annotations_for_segment[0] if annotations_for_segment else None + ) if annotation_for_segment: # Apply correction @@ -951,7 +1104,10 @@ async def apply_all_annotations( a for a in annotations if a.annotation_type == AnnotationType.DIARIZATION ] transcript_annotations = [ - a for a in annotations if a.annotation_type == AnnotationType.TRANSCRIPT + a + for a in annotations + if a.annotation_type + in (AnnotationType.TRANSCRIPT, AnnotationType.SPEECH_SUGGESTION_CORRECTION) ] insert_annotations = [ a for a in annotations if a.annotation_type == AnnotationType.INSERT diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py 
b/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py index 7abb8bbd..2e338232 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py @@ -25,7 +25,9 @@ @router.post("/process-annotations") async def process_annotations_for_training( current_user: User = Depends(current_active_user), - annotation_type: Optional[str] = Query("diarization", description="Type of annotations to process"), + annotation_type: Optional[str] = Query( + "diarization", description="Type of annotations to process" + ), ): """ Send processed annotations to speaker recognition service for training. @@ -44,8 +46,7 @@ async def process_annotations_for_training( # Only admins can trigger training for now (can expand to per-user later) if not current_user.is_superuser: raise HTTPException( - status_code=403, - detail="Only administrators can trigger model training" + status_code=403, detail="Only administrators can trigger model training" ) # Find annotations ready for training @@ -57,15 +58,18 @@ async def process_annotations_for_training( # Filter out already trained annotations (processed_by contains "training") ready_for_training = [ - a for a in annotations + a + for a in annotations if not a.processed_by or "training" not in a.processed_by ] if not ready_for_training: - return JSONResponse(content={ - "message": "No annotations ready for training", - "processed_count": 0 - }) + return JSONResponse( + content={ + "message": "No annotations ready for training", + "processed_count": 0, + } + ) # Import required modules from advanced_omi_backend.models.conversation import Conversation @@ -78,13 +82,16 @@ async def process_annotations_for_training( # Initialize speaker client speaker_client = SpeakerRecognitionClient() - + if not speaker_client.enabled: - return JSONResponse(content={ - "message": "Speaker recognition service is not enabled", - 
"processed_count": 0, - "status": "error" - }, status_code=503) + return JSONResponse( + content={ + "message": "Speaker recognition service is not enabled", + "processed_count": 0, + "status": "error", + }, + status_code=503, + ) # Track processing statistics enrolled_count = 0 @@ -101,27 +108,33 @@ async def process_annotations_for_training( if not conversation or not conversation.active_transcript: failed_count += 1 - errors.append(f"Conversation {annotation.conversation_id[:8]} not found") + errors.append( + f"Conversation {annotation.conversation_id[:8]} not found" + ) continue # Validate segment index - if annotation.segment_index >= len(conversation.active_transcript.segments): + if annotation.segment_index >= len( + conversation.active_transcript.segments + ): failed_count += 1 errors.append(f"Invalid segment index {annotation.segment_index}") continue - segment = conversation.active_transcript.segments[annotation.segment_index] + segment = conversation.active_transcript.segments[ + annotation.segment_index + ] # 2. Extract audio segment from MongoDB logger.info( f"Extracting audio for conversation {annotation.conversation_id[:8]}... " f"segment {annotation.segment_index} ({segment.start:.2f}s - {segment.end:.2f}s)" ) - + wav_bytes = await reconstruct_audio_segment( conversation_id=annotation.conversation_id, start_time=segment.start, - end_time=segment.end + end_time=segment.end, ) if not wav_bytes: @@ -135,42 +148,49 @@ async def process_annotations_for_training( # 3. 
Check if speaker exists existing_speaker = await speaker_client.get_speaker_by_name( speaker_name=annotation.corrected_speaker, - user_id=1 # TODO: Map Chronicle user_id to speaker service user_id + user_id=1, # TODO: Map Chronicle user_id to speaker service user_id ) if existing_speaker: # APPEND to existing speaker - logger.info(f"Appending to existing speaker: {annotation.corrected_speaker}") + logger.info( + f"Appending to existing speaker: {annotation.corrected_speaker}" + ) result = await speaker_client.append_to_speaker( - speaker_id=existing_speaker["id"], - audio_data=wav_bytes + speaker_id=existing_speaker["id"], audio_data=wav_bytes ) - + if "error" in result: logger.error(f"Failed to append to speaker: {result}") failed_count += 1 errors.append(f"Append failed: {result.get('error')}") continue - + appended_count += 1 - logger.info(f"✅ Successfully appended to speaker '{annotation.corrected_speaker}'") + logger.info( + f"✅ Successfully appended to speaker '{annotation.corrected_speaker}'" + ) else: # ENROLL new speaker - logger.info(f"Enrolling new speaker: {annotation.corrected_speaker}") + logger.info( + f"Enrolling new speaker: {annotation.corrected_speaker}" + ) result = await speaker_client.enroll_new_speaker( speaker_name=annotation.corrected_speaker, audio_data=wav_bytes, - user_id=1 # TODO: Map Chronicle user_id to speaker service user_id + user_id=1, # TODO: Map Chronicle user_id to speaker service user_id ) - + if "error" in result: logger.error(f"Failed to enroll speaker: {result}") failed_count += 1 errors.append(f"Enroll failed: {result.get('error')}") continue - + enrolled_count += 1 - logger.info(f"✅ Successfully enrolled new speaker '{annotation.corrected_speaker}'") + logger.info( + f"✅ Successfully enrolled new speaker '{annotation.corrected_speaker}'" + ) # 4. 
Mark annotation as trained if annotation.processed_by: @@ -181,7 +201,9 @@ async def process_annotations_for_training( await annotation.save() except Exception as e: - logger.error(f"Error processing annotation {annotation.id}: {e}", exc_info=True) + logger.error( + f"Error processing annotation {annotation.id}: {e}", exc_info=True + ) failed_count += 1 errors.append(f"Exception: {str(e)[:50]}") continue @@ -192,15 +214,17 @@ async def process_annotations_for_training( f"({enrolled_count} new, {appended_count} appended, {failed_count} failed)" ) - return JSONResponse(content={ - "message": "Training complete", - "enrolled_new_speakers": enrolled_count, - "appended_to_existing": appended_count, - "total_processed": total_processed, - "failed_count": failed_count, - "errors": errors[:10] if errors else [], - "status": "success" if total_processed > 0 else "partial_failure" - }) + return JSONResponse( + content={ + "message": "Training complete", + "enrolled_new_speakers": enrolled_count, + "appended_to_existing": appended_count, + "total_processed": total_processed, + "failed_count": failed_count, + "errors": errors[:10] if errors else [], + "status": "success" if total_processed > 0 else "partial_failure", + } + ) except HTTPException: raise @@ -226,7 +250,9 @@ async def export_asr_dataset( Export job results with counts of conversations exported and annotations consumed. 
""" if not current_user.is_superuser: - raise HTTPException(status_code=403, detail="Only administrators can trigger ASR dataset export") + raise HTTPException( + status_code=403, detail="Only administrators can trigger ASR dataset export" + ) try: from advanced_omi_backend.workers.finetuning_jobs import run_asr_finetuning_job @@ -235,7 +261,9 @@ async def export_asr_dataset( return JSONResponse(content=result) except Exception as e: logger.error(f"ASR dataset export failed: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"ASR dataset export failed: {str(e)}") + raise HTTPException( + status_code=500, detail=f"ASR dataset export failed: {str(e)}" + ) @router.get("/status") @@ -269,7 +297,11 @@ async def get_finetuning_status( ).to_list() # Batch-check which conversation_ids still exist - conv_annotation_types = {AnnotationType.DIARIZATION, AnnotationType.TRANSCRIPT} + conv_annotation_types = { + AnnotationType.DIARIZATION, + AnnotationType.TRANSCRIPT, + AnnotationType.SPEECH_SUGGESTION_CORRECTION, + } all_conv_ids: set[str] = set() for ann_type in conv_annotation_types: for a in all_annotations_by_type.get(ann_type, []): @@ -291,8 +323,12 @@ async def get_finetuning_status( # Identify orphaned annotations for conversation-based types if ann_type in conv_annotation_types: - orphaned = [a for a in annotations if a.conversation_id in orphaned_conv_ids] - non_orphaned = [a for a in annotations if a.conversation_id not in orphaned_conv_ids] + orphaned = [ + a for a in annotations if a.conversation_id in orphaned_conv_ids + ] + non_orphaned = [ + a for a in annotations if a.conversation_id not in orphaned_conv_ids + ] else: # Memory/entity orphan detection is placeholder for now orphaned = [] @@ -300,9 +336,12 @@ async def get_finetuning_status( pending = [a for a in non_orphaned if not a.processed] processed = [a for a in non_orphaned if a.processed] - trained = [a for a in processed if a.processed_by and "training" in a.processed_by] + trained = [ + 
a for a in processed if a.processed_by and "training" in a.processed_by + ] applied_not_trained = [ - a for a in processed + a + for a in processed if not a.processed_by or "training" not in a.processed_by ] @@ -333,9 +372,17 @@ async def get_finetuning_status( if trained_diarization_list: latest_trained = max( trained_diarization_list, - key=lambda a: a.updated_at if a.updated_at else datetime.min.replace(tzinfo=timezone.utc) + key=lambda a: ( + a.updated_at + if a.updated_at + else datetime.min.replace(tzinfo=timezone.utc) + ), + ) + last_training_run = ( + latest_trained.updated_at.isoformat() + if latest_trained.updated_at + else None ) - last_training_run = latest_trained.updated_at.isoformat() if latest_trained.updated_at else None # Get cron job status from scheduler try: @@ -344,7 +391,9 @@ async def get_finetuning_status( scheduler = get_scheduler() all_jobs = await scheduler.get_all_jobs_status() # Find speaker finetuning job for backward compat - speaker_job = next((j for j in all_jobs if j["job_id"] == "speaker_finetuning"), None) + speaker_job = next( + (j for j in all_jobs if j["job_id"] == "speaker_finetuning"), None + ) cron_status = { "enabled": speaker_job["enabled"] if speaker_job else False, "schedule": speaker_job["schedule"] if speaker_job else "0 2 * * *", @@ -359,15 +408,17 @@ async def get_finetuning_status( "next_run": None, } - return JSONResponse(content={ - "pending_annotation_count": pending_count, - "applied_annotation_count": applied_count, - "trained_annotation_count": trained_count, - "last_training_run": last_training_run, - "cron_status": cron_status, - "annotation_counts": annotation_counts, - "orphaned_annotation_count": total_orphaned, - }) + return JSONResponse( + content={ + "pending_annotation_count": pending_count, + "applied_annotation_count": applied_count, + "trained_annotation_count": trained_count, + "last_training_run": last_training_run, + "cron_status": cron_status, + "annotation_counts": annotation_counts, + 
"orphaned_annotation_count": total_orphaned, + } + ) except Exception as e: logger.error(f"Error fetching fine-tuning status: {e}", exc_info=True) @@ -385,7 +436,9 @@ async def get_finetuning_status( @router.delete("/orphaned-annotations") async def delete_orphaned_annotations( current_user: User = Depends(current_active_user), - annotation_type: Optional[str] = Query(None, description="Filter by annotation type (e.g. 'diarization')"), + annotation_type: Optional[str] = Query( + None, description="Filter by annotation type (e.g. 'diarization')" + ), ): """ Find and delete orphaned annotations whose referenced conversation no longer exists. @@ -404,9 +457,17 @@ async def delete_orphaned_annotations( try: requested_type = AnnotationType(annotation_type) except ValueError: - raise HTTPException(status_code=400, detail=f"Unknown annotation type: {annotation_type}") + raise HTTPException( + status_code=400, detail=f"Unknown annotation type: {annotation_type}" + ) if requested_type not in conv_annotation_types: - return JSONResponse(content={"deleted_count": 0, "by_type": {}, "message": "Orphan detection not supported for this type"}) + return JSONResponse( + content={ + "deleted_count": 0, + "by_type": {}, + "message": "Orphan detection not supported for this type", + } + ) types_to_check = {requested_type} else: types_to_check = conv_annotation_types @@ -448,10 +509,12 @@ async def delete_orphaned_annotations( total_deleted += len(orphaned) logger.info(f"Deleted {total_deleted} orphaned annotations: {deleted_by_type}") - return JSONResponse(content={ - "deleted_count": total_deleted, - "by_type": deleted_by_type, - }) + return JSONResponse( + content={ + "deleted_count": total_deleted, + "by_type": deleted_by_type, + } + ) @router.post("/orphaned-annotations/reattach") diff --git a/backends/advanced/src/advanced_omi_backend/services/plugin_service.py b/backends/advanced/src/advanced_omi_backend/services/plugin_service.py index 9a8cc205..a2061c5d 100644 --- 
a/backends/advanced/src/advanced_omi_backend/services/plugin_service.py +++ b/backends/advanced/src/advanced_omi_backend/services/plugin_service.py @@ -18,6 +18,7 @@ from advanced_omi_backend.config_loader import get_plugins_yml_path from advanced_omi_backend.plugins import BasePlugin, PluginRouter +from advanced_omi_backend.plugins.events import PluginEvent from advanced_omi_backend.plugins.services import PluginServices logger = logging.getLogger(__name__) @@ -176,7 +177,9 @@ def replacer(match): return value -def load_plugin_config(plugin_id: str, orchestration_config: Dict[str, Any]) -> Dict[str, Any]: +def load_plugin_config( + plugin_id: str, orchestration_config: Dict[str, Any] +) -> Dict[str, Any]: """ Load complete plugin configuration from multiple sources. @@ -215,9 +218,13 @@ def load_plugin_config(plugin_id: str, orchestration_config: Dict[str, Any]) -> with open(plugin_config_path, "r") as f: plugin_config = yaml.safe_load(f) or {} config.update(plugin_config) - logger.debug(f"Loaded {len(plugin_config)} config keys for '{plugin_id}'") + logger.debug( + f"Loaded {len(plugin_config)} config keys for '{plugin_id}'" + ) else: - logger.debug(f"No config.yml found for plugin '{plugin_id}' at {plugin_config_path}") + logger.debug( + f"No config.yml found for plugin '{plugin_id}' at {plugin_config_path}" + ) except Exception as e: logger.warning(f"Failed to load config.yml for plugin '{plugin_id}': {e}") @@ -398,7 +405,9 @@ def load_schema_yml(plugin_id: str) -> Optional[Dict[str, Any]]: return None -def infer_schema_from_config(plugin_id: str, config_dict: Dict[str, Any]) -> Dict[str, Any]: +def infer_schema_from_config( + plugin_id: str, config_dict: Dict[str, Any] +) -> Dict[str, Any]: """Infer configuration schema from plugin config.yml. 
This function analyzes the config.yml file to generate a JSON schema @@ -480,8 +489,7 @@ def mask_secrets_in_config( if env_var and env_var in secret_env_vars: # Check if env var is set in per-plugin .env or os.environ is_set = bool( - (plugin_env and plugin_env.get(env_var)) - or os.environ.get(env_var) + (plugin_env and plugin_env.get(env_var)) or os.environ.get(env_var) ) masked_config[key] = "••••••••••••" if is_set else "" @@ -527,7 +535,9 @@ def get_plugin_metadata( # Mask secrets in current config current_config = load_plugin_config(plugin_id, orchestration_config) - masked_config = mask_secrets_in_config(current_config, config_schema, plugin_env=plugin_env) + masked_config = mask_secrets_in_config( + current_config, config_schema, plugin_env=plugin_env + ) # Mark which env vars are set (check per-plugin .env first, then os.environ) for env_var_name, env_var_schema in config_schema.get("env_vars", {}).items(): @@ -727,17 +737,24 @@ def _build_plugin_router() -> Optional[PluginRouter]: # Let plugin register its prompts with the prompt registry try: - from advanced_omi_backend.prompt_registry import get_prompt_registry + from advanced_omi_backend.prompt_registry import ( + get_prompt_registry, + ) + plugin.register_prompts(get_prompt_registry()) except Exception as e: - logger.debug(f"Plugin '{plugin_id}' prompt registration skipped: {e}") + logger.debug( + f"Plugin '{plugin_id}' prompt registration skipped: {e}" + ) # Note: async initialization happens in app_factory lifespan or reload_plugins router.register_plugin(plugin_id, plugin) logger.info(f"Plugin '{plugin_id}' registered successfully") except Exception as e: - logger.error(f"Failed to register plugin '{plugin_id}': {e}", exc_info=True) + logger.error( + f"Failed to register plugin '{plugin_id}': {e}", exc_info=True + ) logger.info( f"Plugin registration complete: {len(router.plugins)} plugin(s) registered" @@ -806,6 +823,62 @@ async def ensure_plugin_router() -> Optional[PluginRouter]: return 
plugin_router +async def dispatch_plugin_event( + event: PluginEvent, + user_id: str, + data: dict, + metadata: dict = None, + description: str = "", + require_router: bool = False, +) -> Optional[list]: + """Dispatch an event to the plugin system with standard logging. + + Handles the common pattern of: ensure router -> dispatch event -> log results. + + Args: + event: Plugin event to dispatch + user_id: User ID for the event + data: Event-specific data dict + metadata: Optional metadata dict + description: Log context (e.g., "conversation=abc123, memories=5") + require_router: If True and no router, raise RuntimeError instead of returning None + + Returns: + List of plugin results, or None if no router available + + Raises: + RuntimeError: If require_router=True and no plugin router is available + """ + plugin_router = await ensure_plugin_router() + + if not plugin_router: + if require_router: + raise RuntimeError( + f"Plugin router could not be initialized in worker process. " + f"{event.value} event will NOT be dispatched!" 
+ ) + return None + + logger.info(f"🔌 DISPATCH: {event.value} event ({description})") + + plugin_results = await plugin_router.dispatch_event( + event=event, + user_id=user_id, + data=data, + metadata=metadata or {}, + ) + + result_count = len(plugin_results) if plugin_results else 0 + logger.info(f"🔌 RESULT: {event.value} dispatched to {result_count} plugins") + + if plugin_results: + for result in plugin_results: + if result.message: + logger.info(f" Plugin result: {result.message}") + + return plugin_results + + async def cleanup_plugin_router() -> None: """Clean up the plugin router and all registered plugins.""" global _plugin_router @@ -934,7 +1007,9 @@ def signal_worker_restart() -> None: try: timestamp = time.strftime("%Y-%m-%dT%H:%M:%S") client.set(WORKER_RESTART_KEY, timestamp) - logger.info(f"Worker restart signal sent via Redis key '{WORKER_RESTART_KEY}'") + logger.info( + f"Worker restart signal sent via Redis key '{WORKER_RESTART_KEY}'" + ) finally: client.close() except Exception as e: diff --git a/backends/advanced/src/advanced_omi_backend/utils/job_utils.py b/backends/advanced/src/advanced_omi_backend/utils/job_utils.py index c9028909..695906c4 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/job_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/job_utils.py @@ -10,7 +10,31 @@ logger = logging.getLogger(__name__) -async def check_job_alive(redis_client, current_job, session_id: Optional[str] = None) -> bool: +def update_job_meta(**kwargs) -> None: + """Update the current RQ job's metadata with the given key-value pairs. + + Handles the common boilerplate of: get_current_job() -> null check -> + meta init -> update -> save_meta. 
+ + Args: + **kwargs: Key-value pairs to merge into job.meta + + Example: + update_job_meta(conversation_id="abc", processing_time=1.5) + """ + from rq import get_current_job + + current_job = get_current_job() + if current_job: + if not current_job.meta: + current_job.meta = {} + current_job.meta.update(kwargs) + current_job.save_meta() + + +async def check_job_alive( + redis_client, current_job, session_id: Optional[str] = None +) -> bool: """ Check if current RQ job still exists in Redis. @@ -44,12 +68,19 @@ async def check_job_alive(redis_client, current_job, session_id: Optional[str] = if session_id: session_key = f"audio:session:{session_id}" session_status = await redis_client.hget(session_key, "status") - if session_status and session_status.decode() in ["finalizing", "finished"]: + if session_status and session_status.decode() in [ + "finalizing", + "finished", + ]: # Session ended naturally - not a zombie, just natural cleanup - logger.debug(f"📋 Job {current_job.id} ending naturally (session closed)") + logger.debug( + f"📋 Job {current_job.id} ending naturally (session closed)" + ) return False # True zombie - job deleted while session still active - logger.error(f"🧟 Zombie job detected - job {current_job.id} deleted from Redis while session still active, exiting") + logger.error( + f"🧟 Zombie job detected - job {current_job.id} deleted from Redis while session still active, exiting" + ) return False return True diff --git a/backends/advanced/src/advanced_omi_backend/workers/annotation_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/annotation_jobs.py index 3681ab5f..10d8f65e 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/annotation_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/annotation_jobs.py @@ -4,14 +4,14 @@ These jobs run periodically via the cron scheduler to: 1. Surface potential errors in transcripts and memories for user review 2. 
Fine-tune error detection models using accepted/rejected annotations - -TODO: Implement actual LLM-based error detection and model training logic. """ +import json import logging from datetime import datetime, timedelta, timezone from typing import List +from advanced_omi_backend.llm_client import async_generate from advanced_omi_backend.models.annotation import ( Annotation, AnnotationSource, @@ -20,100 +20,182 @@ ) from advanced_omi_backend.models.conversation import Conversation from advanced_omi_backend.models.user import User +from advanced_omi_backend.prompt_registry import get_prompt_registry logger = logging.getLogger(__name__) +LOOKBACK_DAYS = 7 +MAX_SEGMENTS_PER_PROMPT = 30 +MAX_SUGGESTIONS_PER_RUN = 50 + +PROMPT_ID = "annotation.transcript_error_detection" + async def surface_error_suggestions(): """ - Generate AI suggestions for potential transcript/memory errors. - Runs daily, creates PENDING annotations for user review. + Generate AI suggestions for potential transcript errors. - This is a PLACEHOLDER implementation. To fully implement: - 1. Query recent transcripts and memories (last N days) - 2. Use LLM to analyze content for potential errors: - - Hallucinations (made-up facts) - - Misheard words (audio transcription errors) - - Grammar/spelling issues - - Inconsistencies with other memories - 3. For each potential error: - - Create PENDING annotation with MODEL_SUGGESTION source - - Store original_text and suggested corrected_text - 4. Users can review suggestions in UI (accept/reject) - 5. Accepted suggestions improve future model accuracy - - TODO: Implement LLM-based error detection logic. + Runs daily via cron. For each user, queries recent conversations + and uses the LLM to identify potential transcription errors. + Creates PENDING annotations with MODEL_SUGGESTION source for + user review in the swipe UI. 
""" - logger.info("📝 Checking for annotation suggestions (placeholder)...") + logger.info("Checking for annotation suggestions...") + total_created = 0 try: - # Get all users users = await User.find_all().to_list() - logger.info(f" Found {len(users)} users to analyze") + logger.info(f"Found {len(users)} users to analyze") for user in users: - # TODO: Query recent conversations for this user (last 7 days) - # recent_conversations = await Conversation.find( - # Conversation.user_id == str(user.id), - # Conversation.created_at >= datetime.now(timezone.utc) - timedelta(days=7) - # ).to_list() - - # TODO: For each conversation, analyze transcripts - # for conversation in recent_conversations: - # active_transcript = conversation.get_active_transcript() - # if not active_transcript: - # continue - # - # # TODO: Use LLM to identify potential errors - # # suggestions = await llm_provider.analyze_transcript_for_errors( - # # segments=active_transcript.segments, - # # context=conversation.summary - # # ) - # - # # TODO: Create PENDING annotations for each suggestion - # # for suggestion in suggestions: - # # annotation = Annotation( - # # annotation_type=AnnotationType.TRANSCRIPT, - # # user_id=str(user.id), - # # conversation_id=conversation.conversation_id, - # # segment_index=suggestion.segment_index, - # # original_text=suggestion.original_text, - # # corrected_text=suggestion.suggested_text, - # # source=AnnotationSource.MODEL_SUGGESTION, - # # status=AnnotationStatus.PENDING - # # ) - # # await annotation.save() - - # TODO: Query recent memories for this user - # recent_memories = await memory_service.get_recent_memories( - # user_id=str(user.id), - # days=7 - # ) - - # TODO: Use LLM to identify potential errors in memories - # for memory in recent_memories: - # # TODO: Analyze memory content for hallucinations/errors - # # suggestions = await llm_provider.analyze_memory_for_errors( - # # content=memory.content, - # # metadata=memory.metadata - # # ) - # - # # TODO: 
Create PENDING annotations - # # ... - - # Placeholder logging - logger.debug(f" Analyzed user {user.id} (placeholder)") - - logger.info("✅ Suggestion check complete (placeholder implementation)") - logger.info( - " ℹ️ TODO: Implement LLM-based error detection to create actual suggestions" - ) + user_id = str(user.id) + cutoff = datetime.now(timezone.utc) - timedelta(days=LOOKBACK_DAYS) + + recent_conversations = await Conversation.find( + Conversation.user_id == user_id, + Conversation.created_at >= cutoff, + Conversation.deleted != True, + ).to_list() + + if not recent_conversations: + logger.info( + f"User {user.email or user_id}: no recent conversations, skipping" + ) + continue + + logger.info( + f"User {user.email or user_id}: {len(recent_conversations)} conversations in last {LOOKBACK_DAYS} days" + ) + + # Get conversation IDs that already have pending model suggestions + existing = await Annotation.find( + Annotation.user_id == user_id, + Annotation.source == AnnotationSource.MODEL_SUGGESTION, + Annotation.status == AnnotationStatus.PENDING, + ).to_list() + skip_conversation_ids = { + a.conversation_id for a in existing if a.conversation_id + } + if skip_conversation_ids: + logger.info( + f" Skipping {len(skip_conversation_ids)} conversations with existing pending suggestions" + ) + + created_for_user = 0 + for conversation in recent_conversations: + if total_created >= MAX_SUGGESTIONS_PER_RUN: + logger.info( + f" Reached max suggestions per run ({MAX_SUGGESTIONS_PER_RUN}), stopping" + ) + break + if conversation.conversation_id in skip_conversation_ids: + continue + + active_transcript = conversation.active_transcript + if not active_transcript or not active_transcript.segments: + logger.debug( + f" Conversation '{conversation.title or conversation.conversation_id}': no transcript/segments, skipping" + ) + continue + + seg_count = len(active_transcript.segments) + logger.info( + f" Analyzing '{conversation.title or 'Untitled'}' " + f"({seg_count} segments, 
id={conversation.conversation_id[:8]}...)" + ) + + suggestions = await _analyze_transcript(conversation, active_transcript) + + if not suggestions: + logger.info(f" No issues found") + else: + logger.info(f" LLM found {len(suggestions)} potential issues") + + for suggestion in suggestions: + if total_created >= MAX_SUGGESTIONS_PER_RUN: + break + + seg_idx = suggestion.get("segment_index") + if seg_idx is None or seg_idx >= len(active_transcript.segments): + logger.debug(f" Skipping invalid segment_index={seg_idx}") + continue + + annotation = Annotation( + annotation_type=AnnotationType.TRANSCRIPT, + user_id=user_id, + conversation_id=conversation.conversation_id, + segment_index=seg_idx, + original_text=suggestion.get("original_text", ""), + corrected_text=suggestion.get("corrected_text", ""), + source=AnnotationSource.MODEL_SUGGESTION, + status=AnnotationStatus.PENDING, + ) + await annotation.save() + total_created += 1 + created_for_user += 1 + logger.info( + f" Created suggestion: segment {seg_idx} - " + f"'{suggestion.get('reason', 'unknown')}'" + ) + + logger.info( + f"User {user.email or user_id}: {created_for_user} suggestions created" + ) + + logger.info(f"Suggestion check complete: {total_created} annotations created") except Exception as e: - logger.error(f"❌ Error in surface_error_suggestions: {e}", exc_info=True) + logger.error(f"Error in surface_error_suggestions: {e}", exc_info=True) raise +async def _analyze_transcript(conversation, transcript) -> list[dict]: + """Use LLM to analyze a transcript for potential errors.""" + segments = transcript.segments[:MAX_SEGMENTS_PER_PROMPT] + segments_text = "\n".join( + f"{i}: {seg.speaker} - {seg.text}" + for i, seg in enumerate(segments) + if seg.text.strip() + ) + + if not segments_text: + logger.debug(f" No non-empty segments to analyze") + return [] + + registry = get_prompt_registry() + prompt = await registry.get_prompt( + PROMPT_ID, + title=conversation.title or "Untitled", + 
segments_text=segments_text, + ) + + try: + logger.debug(f" Sending {len(segments)} segments to LLM for analysis...") + response = await async_generate(prompt) + logger.debug(f" LLM response length: {len(response)} chars") + # Parse JSON from response, handling markdown code blocks + text = response.strip() + if text.startswith("```"): + text = text.split("\n", 1)[1] if "\n" in text else text[3:] + text = text.rsplit("```", 1)[0] + suggestions = json.loads(text) + if not isinstance(suggestions, list): + logger.warning(f" LLM returned non-list response, ignoring") + return [] + return suggestions + except json.JSONDecodeError as e: + logger.warning( + f" Failed to parse LLM JSON for '{conversation.title or conversation.conversation_id}': {e}" + ) + logger.debug(f" Raw LLM response: {response[:500]}") + return [] + except Exception as e: + logger.warning( + f" LLM call failed for '{conversation.title or conversation.conversation_id}': {e}" + ) + return [] + + async def finetune_hallucination_model(): """ Fine-tune error detection model using accepted/rejected annotations. 
@@ -199,15 +281,11 @@ async def finetune_hallucination_model(): # Calculate acceptance rate if accepted_count + rejected_count > 0: - acceptance_rate = ( - accepted_count / (accepted_count + rejected_count) - ) * 100 + acceptance_rate = (accepted_count / (accepted_count + rejected_count)) * 100 logger.info(f" Suggestion acceptance rate: {acceptance_rate:.1f}%") logger.info("✅ Training check complete (placeholder implementation)") - logger.info( - " ℹ️ TODO: Implement model fine-tuning using user feedback data" - ) + logger.info(" ℹ️ TODO: Implement model fine-tuning using user feedback data") except Exception as e: logger.error(f"❌ Error in finetune_hallucination_model: {e}", exc_info=True) @@ -216,6 +294,7 @@ async def finetune_hallucination_model(): # Additional helper functions for future implementation + async def analyze_common_error_patterns() -> List[dict]: """ Analyze accepted annotations to identify common error patterns. diff --git a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py index 2142ce07..5f7487e5 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py @@ -24,7 +24,7 @@ from advanced_omi_backend.observability.otel_setup import set_otel_session from advanced_omi_backend.plugins.events import PluginEvent from advanced_omi_backend.services.plugin_service import ( - ensure_plugin_router, + dispatch_plugin_event, get_plugin_router, ) from advanced_omi_backend.utils.conversation_utils import ( @@ -35,6 +35,7 @@ track_speech_activity, update_job_progress_metadata, ) +from advanced_omi_backend.utils.job_utils import update_job_meta logger = logging.getLogger(__name__) @@ -1161,27 +1162,16 @@ async def generate_title_summary_job( processing_time = time.time() - start_time # Update job metadata - from rq import get_current_job - - current_job = 
get_current_job() - if current_job: - if not current_job.meta: - current_job.meta = {} - current_job.meta.update( - { - "conversation_id": conversation_id, - "title": conversation.title, - "summary": conversation.summary, - "detailed_summary_length": ( - len(conversation.detailed_summary) - if conversation.detailed_summary - else 0 - ), - "segment_count": len(segments), - "processing_time": processing_time, - } - ) - current_job.save_meta() + update_job_meta( + conversation_id=conversation_id, + title=conversation.title, + summary=conversation.summary, + detailed_summary_length=( + len(conversation.detailed_summary) if conversation.detailed_summary else 0 + ), + segment_count=len(segments), + processing_time=processing_time, + ) logger.info( f"✅ Title/summary generation completed for {conversation_id} in {processing_time:.2f}s" @@ -1264,64 +1254,25 @@ async def dispatch_conversation_complete_event_job( user_email = user.email if user else "" # Prepare plugin event data (same format as open_conversation_job) + actual_end_reason = end_reason or "file_upload" try: - plugin_router = await ensure_plugin_router() - - # CRITICAL CHECK: Fail loudly if no router - if not plugin_router: - error_msg = ( - f"❌ Plugin router could not be initialized in worker process. " - f"conversation.complete event for {conversation_id[:12]} will NOT be dispatched!" 
- ) - logger.error(error_msg) - - return { - "success": False, - "skipped": True, - "reason": "No plugin router", - "conversation_id": conversation_id, - "error": error_msg, - } - - plugin_data = { - "conversation": { - "client_id": client_id, - "user_id": user_id, - }, - "transcript": conversation.transcript if conversation else "", - "duration": 0, # Duration not tracked for file uploads - "conversation_id": conversation_id, - } - - # Use provided end_reason or default to 'file_upload' for backward compatibility - actual_end_reason = end_reason or "file_upload" - - logger.info( - f"🔌 DISPATCH: conversation.complete event for {conversation_id[:12]} " - f"(end_reason={actual_end_reason}, user={user_id}, client={client_id})" - ) - - plugin_results = await plugin_router.dispatch_event( + plugin_results = await dispatch_plugin_event( event=PluginEvent.CONVERSATION_COMPLETE, user_id=user_id, - data=plugin_data, + data={ + "conversation": { + "client_id": client_id, + "user_id": user_id, + }, + "transcript": conversation.transcript if conversation else "", + "duration": 0, # Duration not tracked for file uploads + "conversation_id": conversation_id, + }, metadata={"end_reason": actual_end_reason}, + description=f"conversation={conversation_id[:12]}, end_reason={actual_end_reason}", + require_router=True, ) - logger.info( - f"🔌 RESULT: conversation.complete dispatched to {len(plugin_results) if plugin_results else 0} plugins" - ) - if plugin_results: - logger.info( - f"📌 Triggered {len(plugin_results)} conversation-level plugins" - ) - for result in plugin_results: - logger.info( - f" Plugin result: success={result.success}, message={result.message}" - ) - if result.message: - logger.info(f" Plugin result: {result.message}") - processing_time = time.time() - start_time logger.info( f"✅ Conversation complete event dispatched for {conversation_id} in {processing_time:.2f}s" @@ -1334,6 +1285,15 @@ async def dispatch_conversation_complete_event_job( 
"processing_time_seconds": processing_time, } + except RuntimeError as e: + logger.error(f"❌ {e}") + return { + "success": False, + "skipped": True, + "reason": "No plugin router", + "conversation_id": conversation_id, + "error": str(e), + } except Exception as e: logger.warning(f"⚠️ Error dispatching conversation complete event: {e}") return { diff --git a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py index 492dc650..1fcfe510 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py @@ -27,7 +27,7 @@ set_otel_session, ) from advanced_omi_backend.plugins.events import PluginEvent -from advanced_omi_backend.services.plugin_service import ensure_plugin_router +from advanced_omi_backend.services.plugin_service import dispatch_plugin_event logger = logging.getLogger(__name__) @@ -391,11 +391,12 @@ async def process_memory_job( logger.warning(f"⚠️ Knowledge graph extraction failed (non-fatal): {e}") # Trigger memory-level plugins (ALWAYS dispatch when success, even with 0 new memories) + memory_count = len(created_memory_ids) if created_memory_ids else 0 try: - plugin_router = await ensure_plugin_router() - - if plugin_router: - plugin_data = { + await dispatch_plugin_event( + event=PluginEvent.MEMORY_PROCESSED, + user_id=user_id, + data={ "memories": created_memory_ids or [], "conversation": { "conversation_id": conversation_id, @@ -403,39 +404,15 @@ async def process_memory_job( "user_id": user_id, "user_email": user_email, }, - "memory_count": ( - len(created_memory_ids) if created_memory_ids else 0 - ), + "memory_count": memory_count, "conversation_id": conversation_id, - } - - logger.info( - f"🔌 DISPATCH: memory.processed event " - f"(conversation={conversation_id[:12]}, memories={len(created_memory_ids) if created_memory_ids else 0})" - ) - - plugin_results = await 
plugin_router.dispatch_event( - event=PluginEvent.MEMORY_PROCESSED, - user_id=user_id, - data=plugin_data, - metadata={ - "processing_time": processing_time, - "memory_provider": memory_provider, - }, - ) - - logger.info( - f"🔌 RESULT: memory.processed dispatched to {len(plugin_results) if plugin_results else 0} plugins" - ) - - if plugin_results: - logger.info( - f"📌 Triggered {len(plugin_results)} memory-level plugins" - ) - for result in plugin_results: - if result.message: - logger.info(f" Plugin result: {result.message}") - + }, + metadata={ + "processing_time": processing_time, + "memory_provider": memory_provider, + }, + description=f"conversation={conversation_id[:12]}, memories={memory_count}", + ) except Exception as e: logger.warning(f"⚠️ Error triggering memory-level plugins: {e}") diff --git a/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py index 0a4192fb..1576d9eb 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py @@ -12,22 +12,17 @@ from advanced_omi_backend.auth import generate_jwt_for_user from advanced_omi_backend.models.conversation import Conversation from advanced_omi_backend.models.job import async_job -from advanced_omi_backend.services.audio_stream import ( - TranscriptionResultsAggregator, -) +from advanced_omi_backend.services.audio_stream import TranscriptionResultsAggregator from advanced_omi_backend.speaker_recognition_client import SpeakerRecognitionClient from advanced_omi_backend.users import get_user_by_id +from advanced_omi_backend.utils.job_utils import update_job_meta logger = logging.getLogger(__name__) @async_job(redis=True, beanie=True) async def check_enrolled_speakers_job( - session_id: str, - user_id: str, - client_id: str, - *, - redis_client=None + session_id: str, user_id: str, client_id: str, *, redis_client=None ) -> Dict[str, Any]: 
""" Check if any enrolled speakers are present in the current audio stream. @@ -54,19 +49,23 @@ async def check_enrolled_speakers_job( # Check for enrolled speakers speaker_client = SpeakerRecognitionClient() - enrolled_present, speaker_result = await speaker_client.check_if_enrolled_speaker_present( - redis_client=redis_client, - client_id=client_id, - session_id=session_id, - user_id=user_id, - transcription_results=raw_results + enrolled_present, speaker_result = ( + await speaker_client.check_if_enrolled_speaker_present( + redis_client=redis_client, + client_id=client_id, + session_id=session_id, + user_id=user_id, + transcription_results=raw_results, + ) ) # Check for errors from speaker service if speaker_result and speaker_result.get("error"): error_type = speaker_result.get("error") error_message = speaker_result.get("message", "Unknown error") - logger.error(f"🎤 [SPEAKER CHECK] Speaker service error: {error_type} - {error_message}") + logger.error( + f"🎤 [SPEAKER CHECK] Speaker service error: {error_type} - {error_message}" + ) # For connection failures, assume no enrolled speakers but allow conversation to proceed # Speaker filtering is optional - if service is down, conversation should still be created @@ -82,7 +81,7 @@ async def check_enrolled_speakers_job( "enrolled_present": False, "identified_speakers": [], "skip_reason": f"Speaker service unavailable: {error_type}", - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } # For other processing errors, also assume no enrolled speakers @@ -93,7 +92,7 @@ async def check_enrolled_speakers_job( "error_details": error_message, "enrolled_present": False, "identified_speakers": [], - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } # Extract identified speakers @@ -101,31 +100,31 @@ async def check_enrolled_speakers_job( if speaker_result and "segments" in speaker_result: for seg in 
speaker_result["segments"]: identified_as = seg.get("identified_as") - if identified_as and identified_as != "Unknown" and identified_as not in identified_speakers: + if ( + identified_as + and identified_as != "Unknown" + and identified_as not in identified_speakers + ): identified_speakers.append(identified_as) processing_time = time.time() - start_time if enrolled_present: - logger.info(f"✅ Enrolled speaker(s) found: {', '.join(identified_speakers)} ({processing_time:.2f}s)") + logger.info( + f"✅ Enrolled speaker(s) found: {', '.join(identified_speakers)} ({processing_time:.2f}s)" + ) else: logger.info(f"⏭️ No enrolled speakers found ({processing_time:.2f}s)") # Update job metadata for timeline tracking - from rq import get_current_job - current_job = get_current_job() - if current_job: - if not current_job.meta: - current_job.meta = {} - current_job.meta.update({ - "session_id": session_id, - "client_id": client_id, - "enrolled_present": enrolled_present, - "identified_speakers": identified_speakers, - "speaker_count": len(identified_speakers), - "processing_time": processing_time - }) - current_job.save_meta() + update_job_meta( + session_id=session_id, + client_id=client_id, + enrolled_present=enrolled_present, + identified_speakers=identified_speakers, + speaker_count=len(identified_speakers), + processing_time=processing_time, + ) return { "success": True, @@ -133,7 +132,7 @@ async def check_enrolled_speakers_job( "enrolled_present": enrolled_present, "identified_speakers": identified_speakers, "speaker_result": speaker_result, - "processing_time_seconds": processing_time + "processing_time_seconds": processing_time, } @@ -144,7 +143,7 @@ async def recognise_speakers_job( transcript_text: str = "", words: list = None, *, - redis_client=None + redis_client=None, ) -> Dict[str, Any]: """ RQ job function for identifying speakers in a transcribed conversation. 
@@ -168,12 +167,16 @@ async def recognise_speakers_job( Dict with processing results """ - logger.info(f"🎤 RQ: Starting speaker recognition for conversation {conversation_id}") + logger.info( + f"🎤 RQ: Starting speaker recognition for conversation {conversation_id}" + ) start_time = time.time() # Get the conversation - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) if not conversation: logger.error(f"Conversation {conversation_id} not found") return {"success": False, "error": "Conversation not found"} @@ -201,7 +204,7 @@ async def recognise_speakers_job( "conversation_id": conversation_id, "version_id": version_id, "speaker_recognition_enabled": False, - "processing_time_seconds": 0 + "processing_time_seconds": 0, } # Get provider capabilities from metadata @@ -222,7 +225,9 @@ async def recognise_speakers_job( # If we have existing segments from provider, proceed to identification if transcript_version.segments: - logger.info(f"🎤 Using {len(transcript_version.segments)} segments from provider") + logger.info( + f"🎤 Using {len(transcript_version.segments)} segments from provider" + ) # Continue to speaker identification below (after this block) else: logger.warning(f"🎤 Provider claimed diarization but no segments found") @@ -237,32 +242,35 @@ async def recognise_speakers_job( if not actual_words and transcript_version.words: # Convert Word objects to dicts for speaker service API actual_words = [ - { - "word": w.word, - "start": w.start, - "end": w.end, - "confidence": w.confidence - } + {"word": w.word, "start": w.start, "end": w.end, "confidence": w.confidence} for w in transcript_version.words ] - logger.info(f"🔤 Loaded {len(actual_words)} words from transcript version.words field") + logger.info( + f"🔤 Loaded {len(actual_words)} words from transcript version.words field" + ) # Backward compatibility: Fall back to 
metadata if words field is empty (old data) elif not actual_words and transcript_version.metadata.get("words"): actual_words = transcript_version.metadata.get("words", []) - logger.info(f"🔤 Loaded {len(actual_words)} words from transcript version metadata (legacy)") + logger.info( + f"🔤 Loaded {len(actual_words)} words from transcript version metadata (legacy)" + ) # Backward compatibility: Extract from segments if that's all we have (old streaming data) elif not actual_words and transcript_version.segments: for segment in transcript_version.segments: if segment.words: for w in segment.words: - actual_words.append({ - "word": w.word, - "start": w.start, - "end": w.end, - "confidence": w.confidence - }) + actual_words.append( + { + "word": w.word, + "start": w.start, + "end": w.end, + "confidence": w.confidence, + } + ) if actual_words: - logger.info(f"🔤 Extracted {len(actual_words)} words from segments (legacy)") + logger.info( + f"🔤 Extracted {len(actual_words)} words from segments (legacy)" + ) if not actual_transcript_text: logger.warning(f"🎤 No transcript text found in version {version_id}") @@ -271,7 +279,7 @@ async def recognise_speakers_job( "conversation_id": conversation_id, "version_id": version_id, "error": "No transcript text available", - "processing_time_seconds": 0 + "processing_time_seconds": 0, } # Check if we can run pyannote diarization @@ -290,7 +298,7 @@ async def recognise_speakers_job( "conversation_id": conversation_id, "version_id": version_id, "error": "No word timestamps and no segments available", - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } # Has existing segments - fall through to run identification on them logger.info( @@ -303,6 +311,7 @@ async def recognise_speakers_job( # 1. Config toggle (per_segment_speaker_id) enables per-segment globally # 2. 
Manual reprocess trigger also enables per-segment for that run from advanced_omi_backend.config import get_misc_settings + misc_config = get_misc_settings() per_segment_config = misc_config.get("per_segment_speaker_id", False) @@ -323,7 +332,11 @@ async def recognise_speakers_job( # Have existing segments and can't/shouldn't run pyannote - do identification only # Covers: provider already diarized, no word timestamps but segments exist, etc. # Only send speech segments for identification; skip event/note segments - speech_segments = [s for s in transcript_version.segments if getattr(s, 'segment_type', 'speech') == 'speech'] + speech_segments = [ + s + for s in transcript_version.segments + if getattr(s, "segment_type", "speech") == "speech" + ] logger.info( f"🎤 Using segment-level speaker identification on {len(speech_segments)} speech segments " f"(skipped {len(transcript_version.segments) - len(speech_segments)} non-speech)" @@ -341,10 +354,7 @@ async def recognise_speakers_job( ) else: # Standard path: full diarization + identification via speaker service - transcript_data = { - "text": actual_transcript_text, - "words": actual_words - } + transcript_data = {"text": actual_transcript_text, "words": actual_words} # Generate backend token for speaker service to fetch audio try: @@ -356,35 +366,41 @@ async def recognise_speakers_job( "conversation_id": conversation_id, "version_id": version_id, "error": "User not found", - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } backend_token = generate_jwt_for_user(user_id, user.email) logger.info(f"🔐 Generated backend token for speaker service") except Exception as token_error: - logger.error(f"Failed to generate backend token: {token_error}", exc_info=True) + logger.error( + f"Failed to generate backend token: {token_error}", exc_info=True + ) return { "success": False, "conversation_id": conversation_id, "version_id": version_id, "error": f"Token generation 
failed: {token_error}", - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } - logger.info(f"🎤 Calling speaker recognition service with conversation_id...") + logger.info( + f"🎤 Calling speaker recognition service with conversation_id..." + ) speaker_result = await speaker_client.diarize_identify_match( conversation_id=conversation_id, backend_token=backend_token, transcript_data=transcript_data, - user_id=user_id + user_id=user_id, ) # Check for errors from speaker service if speaker_result.get("error"): error_type = speaker_result.get("error") error_message = speaker_result.get("message", "Unknown error") - logger.error(f"🎤 Speaker recognition service error: {error_type} - {error_message}") + logger.error( + f"🎤 Speaker recognition service error: {error_type} - {error_message}" + ) # Connection/timeout errors → skip gracefully (existing behavior) if error_type in ("connection_failed", "timeout", "client_error"): @@ -401,7 +417,7 @@ async def recognise_speakers_job( "identified_speakers": [], "skip_reason": f"Speaker service unavailable: {error_type}", "error_type": error_type, - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } # Validation errors → fail job, don't retry @@ -414,7 +430,7 @@ async def recognise_speakers_job( "error": f"Validation error: {error_message}", "error_type": error_type, "retryable": False, # Don't retry validation errors - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } # Resource errors → fail job, can retry later @@ -427,7 +443,7 @@ async def recognise_speakers_job( "error": f"Resource error: {error_message}", "error_type": error_type, "retryable": True, # Can retry later when resources available - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } # Unknown errors → fail job @@ -439,11 +455,15 @@ async 
def recognise_speakers_job( "error": f"Speaker recognition failed: {error_type}", "error_details": error_message, "error_type": error_type, - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } # Service worked but found no segments (legitimate empty result) - if not speaker_result or "segments" not in speaker_result or not speaker_result["segments"]: + if ( + not speaker_result + or "segments" not in speaker_result + or not speaker_result["segments"] + ): logger.warning(f"🎤 Speaker recognition returned no segments") return { "success": True, @@ -451,7 +471,7 @@ async def recognise_speakers_job( "version_id": version_id, "speaker_recognition_enabled": True, "identified_speakers": [], - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } speaker_segments = speaker_result["segments"] @@ -486,12 +506,16 @@ async def recognise_speakers_job( continue # Skip segments with invalid structure - if not isinstance(seg.get("start"), (int, float)) or not isinstance(seg.get("end"), (int, float)): + if not isinstance(seg.get("start"), (int, float)) or not isinstance( + seg.get("end"), (int, float) + ): empty_segment_count += 1 logger.debug(f"Filtered segment with invalid timing: {seg}") continue - speaker_name = seg.get("identified_as") or unknown_label_map.get(seg.get("speaker", "Unknown"), "Unknown Speaker") + speaker_name = seg.get("identified_as") or unknown_label_map.get( + seg.get("speaker", "Unknown"), "Unknown Speaker" + ) # Extract words from speaker service response (already matched to this segment) words_data = seg.get("words", []) @@ -500,13 +524,14 @@ async def recognise_speakers_job( word=w.get("word", ""), start=w.get("start", 0.0), end=w.get("end", 0.0), - confidence=w.get("confidence") + confidence=w.get("confidence"), ) for w in words_data ] # Classify segment type from content from advanced_omi_backend.utils.segment_utils import 
classify_segment_text + seg_classification = classify_segment_text(text) seg_type = "event" if seg_classification == "event" else "speech" @@ -519,18 +544,21 @@ async def recognise_speakers_job( segment_type=seg_type, identified_as=seg.get("identified_as"), confidence=seg.get("confidence"), - words=segment_words # Use words from speaker service + words=segment_words, # Use words from speaker service ) ) if empty_segment_count > 0: - logger.info(f"🔇 Filtered out {empty_segment_count} empty segments from speaker recognition") + logger.info( + f"🔇 Filtered out {empty_segment_count} empty segments from speaker recognition" + ) # Re-insert non-speech segments (event/note) that were skipped during identification # They need to be merged back into position based on timestamps non_speech_segments = [ - s for s in transcript_version.segments - if getattr(s, 'segment_type', 'speech') != 'speech' + s + for s in transcript_version.segments + if getattr(s, "segment_type", "speech") != "speech" ] if non_speech_segments: for ns_seg in non_speech_segments: @@ -541,7 +569,9 @@ async def recognise_speakers_job( insert_pos = i break updated_segments.insert(insert_pos, ns_seg) - logger.info(f"🎤 Re-inserted {len(non_speech_segments)} non-speech segments") + logger.info( + f"🎤 Re-inserted {len(non_speech_segments)} non-speech segments" + ) # Update the transcript version transcript_version.segments = updated_segments @@ -559,24 +589,31 @@ async def recognise_speakers_job( sr_metadata = { "enabled": True, - "identification_mode": "per_segment" if use_per_segment else "majority_vote", + "identification_mode": ( + "per_segment" if use_per_segment else "majority_vote" + ), "identified_speakers": list(identified_speakers), "speaker_count": len(identified_speakers), "total_segments": len(speaker_segments), - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } if speaker_result.get("partial_errors"): sr_metadata["partial_errors"] = 
speaker_result["partial_errors"] transcript_version.metadata["speaker_recognition"] = sr_metadata # Set diarization source if pyannote ran (provider didn't do diarization) - if not provider_has_diarization and transcript_version.diarization_source != "provider": + if ( + not provider_has_diarization + and transcript_version.diarization_source != "provider" + ): transcript_version.diarization_source = "pyannote" await conversation.save() processing_time = time.time() - start_time - logger.info(f"✅ Speaker recognition completed for {conversation_id} in {processing_time:.2f}s") + logger.info( + f"✅ Speaker recognition completed for {conversation_id} in {processing_time:.2f}s" + ) return { "success": True, @@ -585,22 +622,18 @@ async def recognise_speakers_job( "speaker_recognition_enabled": True, "identified_speakers": list(identified_speakers), "segment_count": len(updated_segments), - "processing_time_seconds": processing_time + "processing_time_seconds": processing_time, } except asyncio.TimeoutError as e: logger.error(f"❌ Speaker recognition timeout: {e}") # Add timeout metadata to job - from rq import get_current_job - current_job = get_current_job() - if current_job: - current_job.meta.update({ - "error_type": "timeout", - "audio_duration": conversation.audio_total_duration if conversation else None, - "timeout_occurred_at": time.time() - }) - current_job.save_meta() + update_job_meta( + error_type="timeout", + audio_duration=conversation.audio_total_duration if conversation else None, + timeout_occurred_at=time.time(), + ) return { "success": False, @@ -608,13 +641,16 @@ async def recognise_speakers_job( "version_id": version_id, "error": "Speaker recognition timeout", "error_type": "timeout", - "audio_duration": conversation.audio_total_duration if conversation else None, - "processing_time_seconds": time.time() - start_time + "audio_duration": ( + conversation.audio_total_duration if conversation else None + ), + "processing_time_seconds": time.time() - 
start_time, } except Exception as speaker_error: logger.error(f"❌ Speaker recognition failed: {speaker_error}") import traceback + logger.debug(traceback.format_exc()) return { @@ -622,5 +658,5 @@ async def recognise_speakers_job( "conversation_id": conversation_id, "version_id": version_id, "error": str(speaker_error), - "processing_time_seconds": time.time() - start_time + "processing_time_seconds": time.time() - start_time, } diff --git a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py index fb644ec1..a9e98c5f 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py @@ -10,10 +10,8 @@ import logging import os import time -import uuid import wave from datetime import datetime -from pathlib import Path from typing import Any, Dict from beanie.operators import In @@ -21,23 +19,18 @@ from rq.exceptions import NoSuchJobError from rq.job import Job -from advanced_omi_backend.config import ( - get_backend_config, - get_transcription_job_timeout, -) +from advanced_omi_backend.config import get_transcription_job_timeout from advanced_omi_backend.controllers.queue_controller import ( JOB_RESULT_TTL, - REDIS_URL, - redis_conn, start_post_conversation_jobs, transcription_queue, ) from advanced_omi_backend.models.audio_chunk import AudioChunkDocument from advanced_omi_backend.models.conversation import Conversation -from advanced_omi_backend.models.job import BaseRQJob, JobPriority, async_job +from advanced_omi_backend.models.job import async_job from advanced_omi_backend.plugins.events import PluginEvent from advanced_omi_backend.services.audio_stream import TranscriptionResultsAggregator -from advanced_omi_backend.services.plugin_service import ensure_plugin_router +from advanced_omi_backend.services.plugin_service import dispatch_plugin_event from 
advanced_omi_backend.services.transcription import ( get_transcription_provider, is_transcription_available, @@ -50,6 +43,7 @@ analyze_speech, mark_conversation_deleted, ) +from advanced_omi_backend.utils.job_utils import update_job_meta logger = logging.getLogger(__name__) @@ -295,48 +289,21 @@ def _on_batch_progress(event: dict) -> None: # Trigger transcript-level plugins BEFORE speech validation # This ensures wake-word commands execute even if conversation gets deleted - logger.info( - f"🔍 DEBUG: About to trigger plugins - transcript_text exists: {bool(transcript_text)}" - ) if transcript_text: try: - plugin_router = await ensure_plugin_router() - - if plugin_router: - logger.info( - f"🔍 DEBUG: Preparing to trigger transcript plugins for conversation {conversation_id}" - ) - plugin_data = { + await dispatch_plugin_event( + event=PluginEvent.TRANSCRIPT_BATCH, + user_id=user_id, + data={ "transcript": transcript_text, "segment_id": f"{conversation_id}_batch", "conversation_id": conversation_id, "segments": segments, "word_count": len(words), - } - - logger.info( - f"🔌 DISPATCH: transcript.batch event " - f"(conversation={conversation_id[:12]}, words={len(words)})" - ) - - plugin_results = await plugin_router.dispatch_event( - event=PluginEvent.TRANSCRIPT_BATCH, - user_id=user_id, - data=plugin_data, - metadata={"client_id": client_id}, - ) - - logger.info( - f"🔌 RESULT: transcript.batch dispatched to {len(plugin_results) if plugin_results else 0} plugins" - ) - - if plugin_results: - logger.info( - f"✅ Triggered {len(plugin_results)} transcript plugins in batch mode" - ) - for result in plugin_results: - if result.message: - logger.info(f" Plugin: {result.message}") + }, + metadata={"client_id": client_id}, + description=f"conversation={conversation_id[:12]}, words={len(words)}", + ) except Exception as e: logger.exception( f"⚠️ Error triggering transcript plugins in batch mode: {e}" @@ -573,21 +540,14 @@ def _on_batch_progress(event: dict) -> None: ) # Update 
job metadata with title and summary for UI display - current_job = get_current_job() - if current_job: - if not current_job.meta: - current_job.meta = {} - current_job.meta.update( - { - "conversation_id": conversation_id, - "title": conversation.title, - "summary": conversation.summary, - "transcript_length": len(transcript_text), - "word_count": len(words), - "processing_time": processing_time, - } - ) - current_job.save_meta() + update_job_meta( + conversation_id=conversation_id, + title=conversation.title, + summary=conversation.summary, + transcript_length=len(transcript_text), + word_count=len(words), + processing_time=processing_time, + ) return { "success": True, @@ -972,18 +932,12 @@ async def stream_speech_detection_job( ) # Update job metadata to show status - if current_job: - if not current_job.meta: - current_job.meta = {} - current_job.meta.update( - { - "status": "listening_for_speech", - "session_id": session_id, - "client_id": client_id, - "session_level": True, # Mark as session-level job - } - ) - current_job.save_meta() + update_job_meta( + status="listening_for_speech", + session_id=session_id, + client_id=client_id, + session_level=True, # Mark as session-level job + ) # Track when session closes for graceful shutdown session_closed_at = None diff --git a/backends/advanced/webui/package-lock.json b/backends/advanced/webui/package-lock.json index 54ca06ae..c3bd503e 100644 --- a/backends/advanced/webui/package-lock.json +++ b/backends/advanced/webui/package-lock.json @@ -18,6 +18,7 @@ "d3-selection": "^3.0.0", "d3-time-format": "^4.1.0", "d3-zoom": "^3.0.0", + "framer-motion": "^11.0.0", "lucide-react": "^0.294.0", "react": "^18.2.0", "react-dom": "^18.2.0", @@ -3355,6 +3356,33 @@ "url": "https://github.com/sponsors/rawify" } }, + "node_modules/framer-motion": { + "version": "11.18.2", + "resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-11.18.2.tgz", + "integrity": 
"sha512-5F5Och7wrvtLVElIpclDT0CBzMVg3dL22B64aZwHtsIY8RB4mXICLrkajK4G9R+ieSAGcgrLeae2SeUTg2pr6w==", + "license": "MIT", + "dependencies": { + "motion-dom": "^11.18.1", + "motion-utils": "^11.18.1", + "tslib": "^2.4.0" + }, + "peerDependencies": { + "@emotion/is-prop-valid": "*", + "react": "^18.0.0 || ^19.0.0", + "react-dom": "^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@emotion/is-prop-valid": { + "optional": true + }, + "react": { + "optional": true + }, + "react-dom": { + "optional": true + } + } + }, "node_modules/fs.realpath": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", @@ -4028,6 +4056,21 @@ "node": ">=16 || 14 >=14.17" } }, + "node_modules/motion-dom": { + "version": "11.18.1", + "resolved": "https://registry.npmjs.org/motion-dom/-/motion-dom-11.18.1.tgz", + "integrity": "sha512-g76KvA001z+atjfxczdRtw/RXOM3OMSdd1f4DL77qCTF/+avrRJiawSG4yDibEQ215sr9kpinSlX2pCTJ9zbhw==", + "license": "MIT", + "dependencies": { + "motion-utils": "^11.18.1" + } + }, + "node_modules/motion-utils": { + "version": "11.18.1", + "resolved": "https://registry.npmjs.org/motion-utils/-/motion-utils-11.18.1.tgz", + "integrity": "sha512-49Kt+HKjtbJKLtgO/LKj9Ld+6vw9BjH5d9sc40R/kVyH8GLAXgT42M2NnuPcJNuA3s9ZfZBUcwIgpmZWGEE+hA==", + "license": "MIT" + }, "node_modules/ms": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", @@ -5560,7 +5603,6 @@ "version": "2.8.1", "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", - "dev": true, "license": "0BSD" }, "node_modules/type-check": { diff --git a/backends/advanced/webui/package.json b/backends/advanced/webui/package.json index ca2c77a5..7c497790 100644 --- a/backends/advanced/webui/package.json +++ b/backends/advanced/webui/package.json @@ -13,6 +13,7 @@ "@tanstack/react-query": "^5.90.20", "axios": "^1.6.2", "clsx": "^2.0.0", + 
"framer-motion": "^11.0.0", "cronstrue": "^2.50.0", "d3-array": "^3.2.4", "d3-axis": "^3.0.0", diff --git a/backends/advanced/webui/src/components/UserLoopModal.tsx b/backends/advanced/webui/src/components/UserLoopModal.tsx new file mode 100644 index 00000000..6e671432 --- /dev/null +++ b/backends/advanced/webui/src/components/UserLoopModal.tsx @@ -0,0 +1,533 @@ +import { useState, useEffect, useCallback, useMemo, useRef } from 'react' +import { motion, AnimatePresence, PanInfo } from 'framer-motion' +import { X, Check, Heart, HeartCrack, Pencil, Play, Pause } from 'lucide-react' +import { api, BACKEND_URL } from '../services/api' +import { getStorageKey } from '../utils/storage' + +type DiffToken = { text: string; type: 'equal' | 'added' | 'removed' } + +/** Simple word-level diff using LCS to highlight changes. */ +function computeWordDiff(original: string, corrected: string): { originalTokens: DiffToken[]; correctedTokens: DiffToken[] } { + const a = original.split(/(\s+)/) + const b = corrected.split(/(\s+)/) + + // Build LCS table + const m = a.length, n = b.length + const dp: number[][] = Array.from({ length: m + 1 }, () => Array(n + 1).fill(0)) + for (let i = 1; i <= m; i++) { + for (let j = 1; j <= n; j++) { + dp[i][j] = a[i - 1] === b[j - 1] ? 
dp[i - 1][j - 1] + 1 : Math.max(dp[i - 1][j], dp[i][j - 1]) + } + } + + // Backtrack to get diff + const originalTokens: DiffToken[] = [] + const correctedTokens: DiffToken[] = [] + let i = m, j = n + const origReverse: DiffToken[] = [] + const corrReverse: DiffToken[] = [] + + while (i > 0 || j > 0) { + if (i > 0 && j > 0 && a[i - 1] === b[j - 1]) { + origReverse.push({ text: a[i - 1], type: 'equal' }) + corrReverse.push({ text: b[j - 1], type: 'equal' }) + i--; j-- + } else if (j > 0 && (i === 0 || dp[i][j - 1] >= dp[i - 1][j])) { + corrReverse.push({ text: b[j - 1], type: 'added' }) + j-- + } else { + origReverse.push({ text: a[i - 1], type: 'removed' }) + i-- + } + } + + originalTokens.push(...origReverse.reverse()) + correctedTokens.push(...corrReverse.reverse()) + return { originalTokens, correctedTokens } +} + +const AUTO_SHOW_KEY = 'userloop-auto-show' + +interface Suggestion { + id: string + annotation_type: string + conversation_id: string + segment_index: number | null + original_text: string + corrected_text: string + created_at: string + conversation_title: string | null + transcript_snippet: string | null + segment_start: number | null + segment_end: number | null +} + +/** Read auto-show preference from localStorage (default: false). 
*/ +function getAutoShow(): boolean { + try { + return localStorage.getItem(AUTO_SHOW_KEY) === 'true' + } catch { + return false + } +} + +export default function UserLoopModal() { + const [isOpen, setIsOpen] = useState(false) + const [suggestions, setSuggestions] = useState([]) + const [currentIndex, setCurrentIndex] = useState(0) + const [direction, setDirection] = useState(0) + const [isAnimating, setIsAnimating] = useState(false) + const [particles, setParticles] = useState<{ id: number; x: number; y: number; type: 'heart' | 'heart-break' }[]>([]) + const [isEditing, setIsEditing] = useState(false) + const [editText, setEditText] = useState('') + const [isPlaying, setIsPlaying] = useState(false) + const audioRef = useRef(null) + + const stopAudio = useCallback(() => { + if (audioRef.current) { + audioRef.current.pause() + audioRef.current = null + } + setIsPlaying(false) + }, []) + + const fetchSuggestions = useCallback(async (): Promise => { + try { + const response = await api.get('/api/annotations/suggestions', { params: { limit: 20 } }) + const data = response.data + if (Array.isArray(data) && data.length > 0) { + setSuggestions(data) + setCurrentIndex(0) + return data + } + return [] + } catch { + return [] + } + }, []) + + // Auto-show: only poll & auto-open when the user has opted in via localStorage + useEffect(() => { + if (!getAutoShow()) return + + const check = async () => { + const data = await fetchSuggestions() + if (data.length > 0) setIsOpen(true) + } + check() + const interval = setInterval(check, 60000) + return () => clearInterval(interval) + }, [fetchSuggestions]) + + // Explicit trigger from Fine-tuning page (always works regardless of auto-show) + useEffect(() => { + const handler = () => { + fetchSuggestions().then(data => { + if (data.length > 0) setIsOpen(true) + }) + } + window.addEventListener('open-swipe-ui', handler) + return () => window.removeEventListener('open-swipe-ui', handler) + }, [fetchSuggestions]) + + // Stop audio on 
unmount + useEffect(() => { + return () => { stopAudio() } + }, [stopAudio]) + + // Stop audio when card changes + useEffect(() => { + stopAudio() + }, [currentIndex, stopAudio]) + + // Clean up particles + useEffect(() => { + const timer = setTimeout(() => setParticles([]), 1000) + return () => clearTimeout(timer) + }, [particles]) + + // Close modal when no suggestions left + useEffect(() => { + if (suggestions.length === 0 && isOpen) { + stopAudio() + setIsOpen(false) + } + }, [suggestions.length, isOpen, stopAudio]) + + // Keyboard shortcuts + useEffect(() => { + if (!isOpen || suggestions.length === 0) return + + const handleKeyDown = (e: KeyboardEvent) => { + // Don't capture keys when editing (textarea handles its own keys) + if (isEditing) return + + switch (e.key) { + case 'ArrowDown': + e.preventDefault() + handleSkip() + break + case 'ArrowUp': + e.preventDefault() + setEditText(suggestions[currentIndex]?.corrected_text || '') + setIsEditing(true) + break + case 'ArrowLeft': + e.preventDefault() + handleAction('reject', -1) + break + case 'ArrowRight': + e.preventDefault() + handleAction('accept', 1) + break + } + } + + window.addEventListener('keydown', handleKeyDown) + return () => window.removeEventListener('keydown', handleKeyDown) + }, [isOpen, suggestions, currentIndex, isEditing, isAnimating]) + + const createParticles = (type: 'heart' | 'heart-break') => { + setParticles( + Array.from({ length: 8 }, (_, i) => ({ + id: Date.now() + i, + x: Math.random() * 400 - 200, + y: Math.random() * 200 - 100, + type, + })) + ) + } + + const handleSkip = () => { + if (isAnimating) return + setIsEditing(false) + stopAudio() + if (currentIndex < suggestions.length - 1) { + setCurrentIndex(prev => prev + 1) + } else { + setIsOpen(false) + setSuggestions([]) + } + } + + const handleEditSave = async () => { + const suggestion = suggestions[currentIndex] + if (!suggestion) return + try { + await api.patch(`/api/annotations/${suggestion.id}`, { corrected_text: 
editText }) + // Update local state so diff re-renders with new text + setSuggestions(prev => prev.map((s, i) => i === currentIndex ? { ...s, corrected_text: editText } : s)) + } catch (error) { + console.error('Failed to save edit:', error) + } + setIsEditing(false) + } + + const togglePlay = () => { + const s = suggestions[currentIndex] + if (!s || s.segment_start == null || s.segment_end == null) return + + if (isPlaying && audioRef.current) { + stopAudio() + return + } + + const token = localStorage.getItem(getStorageKey('token')) || '' + const url = `${BACKEND_URL}/api/audio/chunks/${s.conversation_id}?start_time=${s.segment_start}&end_time=${s.segment_end}&token=${token}` + const audio = new Audio(url) + audioRef.current = audio + audio.addEventListener('ended', () => setIsPlaying(false)) + audio.play().then(() => setIsPlaying(true)).catch(() => setIsPlaying(false)) + } + + const handleAction = async (action: 'accept' | 'reject', swipeDirection: number) => { + const suggestion = suggestions[currentIndex] + if (!suggestion || isAnimating) return + + setIsAnimating(true) + setDirection(swipeDirection) + createParticles(action === 'accept' ? 'heart' : 'heart-break') + + try { + const status = action === 'accept' ? 
'accepted' : 'rejected' + await api.patch(`/api/annotations/${suggestion.id}/status`, null, { + params: { status }, + }) + } catch (error) { + console.error(`Failed to ${action} suggestion:`, error) + } + + setTimeout(() => { + if (currentIndex < suggestions.length - 1) { + setCurrentIndex(prev => prev + 1) + } else { + setIsOpen(false) + setSuggestions([]) + } + setIsAnimating(false) + setDirection(0) + }, 400) + } + + const onPanEnd = (_event: MouseEvent | TouchEvent | PointerEvent, info: PanInfo) => { + if (isAnimating) return + const threshold = 100 + if (info.offset.x > threshold) { + handleAction('accept', 1) + } else if (info.offset.x < -threshold) { + handleAction('reject', -1) + } + } + + const diff = useMemo(() => { + if (!isOpen || suggestions.length === 0) return null + const current = suggestions[currentIndex] + return computeWordDiff(current.original_text, current.corrected_text) + }, [isOpen, suggestions, currentIndex]) + + if (!isOpen || suggestions.length === 0) return null + + const current = suggestions[currentIndex] + + const cardVariants = { + enter: (dir: number) => ({ x: dir > 0 ? 1000 : -1000, opacity: 0, scale: 0.8 }), + center: { zIndex: 1, x: 0, opacity: 1, scale: 1 }, + exit: (dir: number) => ({ zIndex: 0, x: dir > 0 ? 1000 : -1000, opacity: 0, scale: 0.8 }), + } + + return ( + + {isOpen && ( + +
+ {/* Particles */} + + {particles.map(p => ( + + {p.type === 'heart' ? ( + + ) : ( + + )} + + ))} + + + {/* Card */} + + {/* Status Overlays */} + + {direction > 0 && ( + + GOOD + + )} + {direction < 0 && ( + + NOPE + + )} + + + {/* Content */} + +
+ Review Suggestion +
+ + {current.conversation_title && ( +
+ {current.conversation_title} +
+ )} + + {/* Original vs corrected with diff highlighting */} +
+
+
Original
+
+ {diff?.originalTokens.map((t, i) => + t.type === 'removed' ? ( + {t.text} + ) : ( + {t.text} + ) + )} +
+
+
+
+ {isEditing ? ( + <> + + Editing... + + ) : ( + 'Suggested' + )} +
+ {isEditing ? ( +
+