diff --git a/backends/advanced/docker-compose.yml b/backends/advanced/docker-compose.yml index 95cc4cab..3eb7e108 100644 --- a/backends/advanced/docker-compose.yml +++ b/backends/advanced/docker-compose.yml @@ -267,9 +267,6 @@ services: timeout: 10s retries: 5 start_period: 30s - profiles: - - obsidian - - knowledge-graph # ollama: # image: ollama/ollama:latest diff --git a/backends/advanced/init.py b/backends/advanced/init.py index ff57242d..39561c71 100644 --- a/backends/advanced/init.py +++ b/backends/advanced/init.py @@ -444,41 +444,44 @@ def setup_optional_services(self): self.config["TS_AUTHKEY"] = self.args.ts_authkey self.console.print(f"[green][SUCCESS][/green] Tailscale auth key configured (Docker integration enabled)") + def setup_neo4j(self): + """Configure Neo4j credentials (always required - used by Knowledge Graph)""" + neo4j_password = getattr(self.args, 'neo4j_password', None) + + if neo4j_password: + self.console.print(f"[green]✅[/green] Neo4j: password configured via wizard") + else: + # Interactive prompt (standalone init.py run) + self.console.print() + self.console.print("[bold cyan]Neo4j Configuration[/bold cyan]") + self.console.print("Neo4j is used for Knowledge Graph (entity/relationship extraction)") + self.console.print() + neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") + + self.config["NEO4J_HOST"] = "neo4j" + self.config["NEO4J_USER"] = "neo4j" + self.config["NEO4J_PASSWORD"] = neo4j_password + self.console.print("[green][SUCCESS][/green] Neo4j credentials configured") + def setup_obsidian(self): - """Configure Obsidian/Neo4j integration""" - # Check if enabled via command line + """Configure Obsidian integration (optional feature flag only - Neo4j credentials handled by setup_neo4j)""" if hasattr(self.args, 'enable_obsidian') and self.args.enable_obsidian: enable_obsidian = True - neo4j_password = getattr(self.args, 'neo4j_password', None) - - if not neo4j_password: - self.console.print("[yellow][WARNING][/yellow] --enable-obsidian provided but no password") - neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") - - self.console.print(f"[green]✅[/green] Obsidian/Neo4j: enabled (configured via wizard)") + self.console.print(f"[green]✅[/green] Obsidian: enabled (configured via wizard)") else: # Interactive prompt (fallback) self.console.print() - self.console.print("[bold cyan]Obsidian/Neo4j Integration[/bold cyan]") + self.console.print("[bold cyan]Obsidian Integration (Optional)[/bold cyan]") self.console.print("Enable graph-based knowledge management for Obsidian vault notes") self.console.print() try: - enable_obsidian = Confirm.ask("Enable Obsidian/Neo4j integration?", default=False) + enable_obsidian = Confirm.ask("Enable Obsidian integration?", default=False) except EOFError: self.console.print("Using default: No") enable_obsidian = False - if enable_obsidian: - neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") - if enable_obsidian: - # Update .env with credentials only (secrets, not feature flags) - self.config["NEO4J_HOST"] = "neo4j" - self.config["NEO4J_USER"] = "neo4j" - self.config["NEO4J_PASSWORD"] = neo4j_password - - # Update config.yml with feature flag (source of truth) - auto-saves via ConfigManager self.config_manager.update_memory_config({ "obsidian": { "enabled": True, @@ -486,11 +489,8 @@ def setup_obsidian(self): "timeout": 30 } }) - - self.console.print("[green][SUCCESS][/green] Obsidian/Neo4j configured") - self.console.print("[blue][INFO][/blue] Neo4j will start automatically with --profile obsidian") + self.console.print("[green][SUCCESS][/green] Obsidian integration enabled") else: - # Explicitly disable Obsidian in config.yml when not enabled self.config_manager.update_memory_config({ "obsidian": { "enabled": False, @@ -498,52 +498,25 @@ def setup_obsidian(self): "timeout": 30 } }) - self.console.print("[blue][INFO][/blue] Obsidian/Neo4j integration disabled") + self.console.print("[blue][INFO][/blue] Obsidian integration disabled") def setup_knowledge_graph(self): - """Configure Knowledge Graph (Neo4j-based entity/relationship extraction)""" - # Check if enabled via command line + """Configure Knowledge Graph (Neo4j-based entity/relationship extraction - enabled by default)""" if hasattr(self.args, 'enable_knowledge_graph') and self.args.enable_knowledge_graph: enable_kg = True - neo4j_password = getattr(self.args, 'neo4j_password', None) - - if not neo4j_password: - # Check if already set from obsidian setup - neo4j_password = self.config.get("NEO4J_PASSWORD") - if not neo4j_password: - self.console.print("[yellow][WARNING][/yellow] --enable-knowledge-graph provided but no password") - neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") else: - # Interactive prompt (fallback) self.console.print() self.console.print("[bold cyan]Knowledge Graph (Entity Extraction)[/bold cyan]") - self.console.print("Enable graph-based entity and relationship extraction from conversations") - self.console.print("Extracts: People, Places, Organizations, Events, Promises/Tasks") + self.console.print("Extract people, places, organizations, events, and tasks from conversations") self.console.print() try: - enable_kg = Confirm.ask("Enable Knowledge Graph?", default=False) + enable_kg = Confirm.ask("Enable Knowledge Graph?", default=True) except EOFError: - self.console.print("Using default: No") - enable_kg = False - - if enable_kg: - # Check if Neo4j password already set from obsidian setup - existing_password = self.config.get("NEO4J_PASSWORD") - if existing_password: - self.console.print("[blue][INFO][/blue] Using Neo4j password from Obsidian configuration") - neo4j_password = existing_password - else: - neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") + self.console.print("Using default: Yes") + enable_kg = True if enable_kg: - # Update .env with credentials only (secrets, not feature flags) - self.config["NEO4J_HOST"] = "neo4j" - self.config["NEO4J_USER"] = "neo4j" - if neo4j_password: - self.config["NEO4J_PASSWORD"] = neo4j_password - - # Update config.yml with feature flag (source of truth) - auto-saves via ConfigManager self.config_manager.update_memory_config({ "knowledge_graph": { "enabled": True, @@ -551,12 +524,9 @@ def setup_knowledge_graph(self): "timeout": 30 } }) - - self.console.print("[green][SUCCESS][/green] Knowledge Graph configured") - self.console.print("[blue][INFO][/blue] Neo4j will start automatically with --profile knowledge-graph") + self.console.print("[green][SUCCESS][/green] Knowledge Graph enabled") self.console.print("[blue][INFO][/blue] Entities and relationships will be extracted from conversations") else: - # Explicitly disable Knowledge Graph in config.yml when not enabled self.config_manager.update_memory_config({ "knowledge_graph": { "enabled": False, @@ -577,12 +547,14 @@ def setup_langfuse(self): if langfuse_pub and langfuse_sec: # Auto-configure from wizard — no prompts needed - self.config["LANGFUSE_HOST"] = "http://langfuse-web:3000" + langfuse_host = getattr(self.args, 'langfuse_host', None) or "http://langfuse-web:3000" + self.config["LANGFUSE_HOST"] = langfuse_host self.config["LANGFUSE_PUBLIC_KEY"] = langfuse_pub self.config["LANGFUSE_SECRET_KEY"] = langfuse_sec - self.config["LANGFUSE_BASE_URL"] = "http://langfuse-web:3000" - self.console.print("[green][SUCCESS][/green] LangFuse auto-configured from wizard") - self.console.print(f"[blue][INFO][/blue] Host: http://langfuse-web:3000") + self.config["LANGFUSE_BASE_URL"] = langfuse_host + source = "external" if "langfuse-web" not in langfuse_host else "local" + self.console.print(f"[green][SUCCESS][/green] LangFuse auto-configured ({source})") + self.console.print(f"[blue][INFO][/blue] Host: {langfuse_host}") self.console.print(f"[blue][INFO][/blue] Public key: {self.mask_api_key(langfuse_pub)}") return @@ -842,25 +814,7 @@ def show_next_steps(self): config_yml = self.config_manager.get_full_config() self.console.print("1. Start the main services:") - # Include --profile obsidian/knowledge-graph if enabled (read from config.yml) - obsidian_enabled = config_yml.get("memory", {}).get("obsidian", {}).get("enabled", False) - kg_enabled = config_yml.get("memory", {}).get("knowledge_graph", {}).get("enabled", False) - - profiles = [] - profile_notes = [] - if obsidian_enabled: - profiles.append("obsidian") - profile_notes.append("Obsidian integration") - if kg_enabled: - profiles.append("knowledge-graph") - profile_notes.append("Knowledge Graph") - - if profiles: - profile_args = " ".join([f"--profile {p}" for p in profiles]) - self.console.print(f" [cyan]docker compose {profile_args} up --build -d[/cyan]") - self.console.print(f" [dim](Includes Neo4j for: {', '.join(profile_notes)})[/dim]") - else: - self.console.print(" [cyan]docker compose up --build -d[/cyan]") + self.console.print(" [cyan]docker compose up --build -d[/cyan]") self.console.print() # Auto-determine URLs for next steps @@ -908,6 +862,7 @@ def run(self): self.setup_llm() self.setup_memory() self.setup_optional_services() + self.setup_neo4j() self.setup_obsidian() self.setup_knowledge_graph() self.setup_langfuse() @@ -967,9 +922,11 @@ def main(): parser.add_argument("--ts-authkey", help="Tailscale auth key for Docker integration (default: prompt user)") parser.add_argument("--langfuse-public-key", - help="LangFuse project public key (from langfuse init)") + help="LangFuse project public key (from langfuse init or external)") parser.add_argument("--langfuse-secret-key", - help="LangFuse project secret key (from langfuse init)") + help="LangFuse project secret key (from langfuse init or external)") + parser.add_argument("--langfuse-host", + help="LangFuse host URL (default: http://langfuse-web:3000 for local)") args = parser.parse_args() diff --git a/backends/advanced/pyproject.toml b/backends/advanced/pyproject.toml index 23c736d7..e0e964c0 100644 --- a/backends/advanced/pyproject.toml +++ b/backends/advanced/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "google-auth-oauthlib>=1.0.0", "google-auth-httplib2>=0.2.0", "websockets>=12.0", + "croniter>=1.3.0", ] [project.optional-dependencies] diff --git a/backends/advanced/src/advanced_omi_backend/app_factory.py b/backends/advanced/src/advanced_omi_backend/app_factory.py index 3c0417eb..6a99c841 100644 --- a/backends/advanced/src/advanced_omi_backend/app_factory.py +++ b/backends/advanced/src/advanced_omi_backend/app_factory.py @@ -221,6 +221,23 @@ async def lifespan(app: FastAPI): # Register OpenMemory user if using openmemory_mcp provider await initialize_openmemory_user() + # Start cron scheduler (requires Redis to be available) + try: + from advanced_omi_backend.cron_scheduler import get_scheduler, register_cron_job + from advanced_omi_backend.workers.finetuning_jobs import ( + run_asr_jargon_extraction_job, + run_speaker_finetuning_job, + ) + + register_cron_job("speaker_finetuning", run_speaker_finetuning_job) + register_cron_job("asr_jargon_extraction", run_asr_jargon_extraction_job) + + scheduler = get_scheduler() + await scheduler.start() + application_logger.info("Cron scheduler started") + except Exception as e: + application_logger.warning(f"Cron scheduler failed to start: {e}") + # SystemTracker is used for monitoring and debugging application_logger.info("Using SystemTracker for monitoring and debugging") @@ -319,6 +336,16 @@ async def lifespan(app: FastAPI): except Exception as e: application_logger.error(f"Error shutting down plugins: {e}") + # Shutdown cron scheduler + try: + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + await scheduler.stop() + application_logger.info("Cron scheduler stopped") + except Exception as e: + application_logger.error(f"Error stopping cron scheduler: {e}") + # Shutdown memory service and speaker service shutdown_memory_service() application_logger.info("Memory and speaker services shut down.") diff --git a/backends/advanced/src/advanced_omi_backend/config.py b/backends/advanced/src/advanced_omi_backend/config.py index 77a842ce..63b1dcd7 100644 --- a/backends/advanced/src/advanced_omi_backend/config.py +++ b/backends/advanced/src/advanced_omi_backend/config.py @@ -198,9 +198,14 @@ def get_misc_settings() -> dict: transcription_cfg = get_backend_config('transcription') transcription_settings = OmegaConf.to_container(transcription_cfg, resolve=True) if transcription_cfg else {} + # Get speaker recognition settings for per_segment_speaker_id + speaker_cfg = get_backend_config('speaker_recognition') + speaker_settings = OmegaConf.to_container(speaker_cfg, resolve=True) if speaker_cfg else {} + return { 'always_persist_enabled': audio_settings.get('always_persist_enabled', False), - 'use_provider_segments': transcription_settings.get('use_provider_segments', False) + 'use_provider_segments': transcription_settings.get('use_provider_segments', False), + 'per_segment_speaker_id': speaker_settings.get('per_segment_speaker_id', False), } @@ -228,4 +233,10 @@ def save_misc_settings(settings: dict) -> bool: if not save_config_section('backend.transcription', transcription_settings): success = False + # Save speaker recognition settings if per_segment_speaker_id is provided + if 'per_segment_speaker_id' in settings: + speaker_settings = {'per_segment_speaker_id': settings['per_segment_speaker_id']} + if not save_config_section('backend.speaker_recognition', speaker_settings): + success = False + return success \ No newline at end of file diff --git a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py index 734df6ed..c9046484 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py @@ -187,7 +187,7 @@ async def upload_and_process_audio_files( conversation_id, version_id, "batch", # trigger - job_timeout=1800, # 30 minutes + job_timeout=900, # 15 minutes result_ttl=JOB_RESULT_TTL, job_id=transcribe_job_id, description=f"Transcribe uploaded file {conversation_id[:8]}", diff --git a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py index f327a545..f4ffe096 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py @@ -488,7 +488,7 @@ async def reprocess_transcript(conversation_id: str, user: User): conversation_id, version_id, "reprocess", - job_timeout=600, + job_timeout=900, # 15 minutes result_ttl=JOB_RESULT_TTL, job_id=f"reprocess_{conversation_id[:8]}", description=f"Transcribe audio for {conversation_id[:8]}", @@ -722,14 +722,24 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u provider_capabilities.get("diarization", False) or source_version.diarization_source == "provider" ) + has_words = bool(source_version.words) + has_segments = bool(source_version.segments) - if not source_version.words and not (provider_has_diarization and source_version.segments): + if not has_words and not has_segments: return JSONResponse( status_code=400, content={ - "error": "Cannot re-diarize transcript without word timings. Words are required for diarization." + "error": ( + "Cannot re-diarize transcript without word timings or segments. " + "Word timestamps or provider segments are required." + ) }, ) + if not has_words and has_segments and not provider_has_diarization: + logger.warning( + "Reprocessing speakers without word timings; " + "falling back to segment-based identification only." + ) # 5. Check if speaker recognition is enabled speaker_config = get_service_config("speaker_recognition") @@ -752,10 +762,13 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u "reprocessing_type": "speaker_diarization", "source_version_id": source_version_id, "trigger": "manual_reprocess", + "provider_capabilities": provider_capabilities, } - if provider_has_diarization: + use_segments = provider_has_diarization or not has_words + if use_segments: new_segments = source_version.segments # COPY provider segments - new_metadata["provider_capabilities"] = provider_capabilities + if not has_words and not provider_has_diarization: + new_metadata["segments_only"] = True else: new_segments = [] # Empty - will be populated by speaker job @@ -772,7 +785,7 @@ async def reprocess_speakers(conversation_id: str, transcript_version_id: str, u ) # Carry over diarization_source so speaker job knows to use segment identification - if provider_has_diarization: + if provider_has_diarization or (not has_words and has_segments): new_version.diarization_source = "provider" # Save conversation with new version diff --git a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py index fcf80de4..98e8f81b 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py @@ -1021,7 +1021,7 @@ async def _process_rolling_batch( conversation_id, version_id, f"rolling_batch_{batch_number}", # trigger - job_timeout=1800, # 30 minutes + job_timeout=900, # 15 minutes result_ttl=JOB_RESULT_TTL, job_id=transcribe_job_id, description=f"Transcribe rolling batch #{batch_number} {conversation_id[:8]}", @@ -1137,7 +1137,7 @@ async def _process_batch_audio_complete( conversation_id, version_id, "batch", # trigger - job_timeout=1800, # 30 minutes + job_timeout=900, # 15 minutes result_ttl=JOB_RESULT_TTL, job_id=transcribe_job_id, description=f"Transcribe batch audio {conversation_id[:8]}", diff --git a/backends/advanced/src/advanced_omi_backend/cron_scheduler.py b/backends/advanced/src/advanced_omi_backend/cron_scheduler.py new file mode 100644 index 00000000..b79effbf --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/cron_scheduler.py @@ -0,0 +1,277 @@ +""" +Config-driven asyncio cron scheduler for Chronicle. + +Reads job definitions from config.yml ``cron_jobs`` section, uses ``croniter`` +to compute next-run times, and dispatches registered job functions. State +(last_run / next_run) is persisted in Redis so it survives restarts. + +Usage: + scheduler = get_scheduler() + await scheduler.start() # call during FastAPI lifespan startup + await scheduler.stop() # call during shutdown +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Callable, Coroutine, Dict, List, Optional + +import redis.asyncio as aioredis +from croniter import croniter + +from advanced_omi_backend.config_loader import load_config, save_config_section + +logger = logging.getLogger(__name__) + +# Redis key prefixes +_LAST_RUN_KEY = "cron:last_run:{job_id}" +_NEXT_RUN_KEY = "cron:next_run:{job_id}" + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + +@dataclass +class CronJobConfig: + job_id: str + enabled: bool + schedule: str + description: str + next_run: Optional[datetime] = None + last_run: Optional[datetime] = None + running: bool = False + last_error: Optional[str] = None + + +# --------------------------------------------------------------------------- +# Job registry – maps job_id → async callable +# --------------------------------------------------------------------------- + +JobFunc = Callable[[], Coroutine[Any, Any, dict]] + +_JOB_REGISTRY: Dict[str, JobFunc] = {} + + +def register_cron_job(job_id: str, func: JobFunc) -> None: + """Register a job function so the scheduler can dispatch it.""" + _JOB_REGISTRY[job_id] = func + + +def _get_job_func(job_id: str) -> Optional[JobFunc]: + return _JOB_REGISTRY.get(job_id) + + +# --------------------------------------------------------------------------- +# Scheduler +# --------------------------------------------------------------------------- + +class CronScheduler: + def __init__(self) -> None: + self.jobs: Dict[str, CronJobConfig] = {} + self._running = False + self._task: Optional[asyncio.Task] = None + self._redis: Optional[aioredis.Redis] = None + + # -- lifecycle ----------------------------------------------------------- + + async def start(self) -> None: + """Load config, restore state from Redis, and start the scheduler loop.""" + import os + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + self._redis = aioredis.from_url(redis_url, decode_responses=True) + + self._load_jobs_from_config() + await self._restore_state() + + self._running = True + self._task = asyncio.create_task(self._loop()) + logger.info(f"Cron scheduler started with {len(self.jobs)} jobs") + + async def stop(self) -> None: + """Cancel the scheduler loop and close Redis.""" + self._running = False + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + if self._redis: + await self._redis.close() + logger.info("Cron scheduler stopped") + + # -- public API ---------------------------------------------------------- + + async def run_job_now(self, job_id: str) -> dict: + """Manually trigger a job regardless of schedule.""" + if job_id not in self.jobs: + raise ValueError(f"Unknown cron job: {job_id}") + return await self._execute_job(job_id) + + async def update_job( + self, + job_id: str, + enabled: Optional[bool] = None, + schedule: Optional[str] = None, + ) -> None: + """Update a job's config and persist to config.yml.""" + if job_id not in self.jobs: + raise ValueError(f"Unknown cron job: {job_id}") + + cfg = self.jobs[job_id] + + if schedule is not None: + # Validate cron expression + if not croniter.is_valid(schedule): + raise ValueError(f"Invalid cron expression: {schedule}") + cfg.schedule = schedule + cfg.next_run = croniter(schedule, datetime.now(timezone.utc)).get_next(datetime) + + if enabled is not None: + cfg.enabled = enabled + + # Persist changes to config.yml + save_config_section( + f"cron_jobs.{job_id}", + {"enabled": cfg.enabled, "schedule": cfg.schedule, "description": cfg.description}, + ) + + # Update next_run in Redis + if self._redis and cfg.next_run: + await self._redis.set( + _NEXT_RUN_KEY.format(job_id=job_id), + cfg.next_run.isoformat(), + ) + + logger.info(f"Updated cron job '{job_id}': enabled={cfg.enabled}, schedule={cfg.schedule}") + + async def get_all_jobs_status(self) -> List[dict]: + """Return status of all registered cron jobs.""" + result = [] + for job_id, cfg in self.jobs.items(): + result.append({ + "job_id": job_id, + "enabled": cfg.enabled, + "schedule": cfg.schedule, + "description": cfg.description, + "last_run": cfg.last_run.isoformat() if cfg.last_run else None, + "next_run": cfg.next_run.isoformat() if cfg.next_run else None, + "running": cfg.running, + "last_error": cfg.last_error, + }) + return result + + # -- internals ----------------------------------------------------------- + + def _load_jobs_from_config(self) -> None: + """Read cron_jobs section from config.yml.""" + cfg = load_config() + cron_section = cfg.get("cron_jobs", {}) + + for job_id, job_cfg in cron_section.items(): + schedule = str(job_cfg.get("schedule", "0 * * * *")) + now = datetime.now(timezone.utc) + self.jobs[job_id] = CronJobConfig( + job_id=job_id, + enabled=bool(job_cfg.get("enabled", False)), + schedule=schedule, + description=str(job_cfg.get("description", "")), + next_run=croniter(schedule, now).get_next(datetime), + ) + + async def _restore_state(self) -> None: + """Restore last_run / next_run from Redis.""" + if not self._redis: + return + for job_id, cfg in self.jobs.items(): + try: + lr = await self._redis.get(_LAST_RUN_KEY.format(job_id=job_id)) + if lr: + cfg.last_run = datetime.fromisoformat(lr) + nr = await self._redis.get(_NEXT_RUN_KEY.format(job_id=job_id)) + if nr: + cfg.next_run = datetime.fromisoformat(nr) + except Exception as e: + logger.warning(f"Failed to restore state for job '{job_id}': {e}") + + async def _persist_state(self, job_id: str) -> None: + """Write last_run / next_run to Redis.""" + if not self._redis: + return + cfg = self.jobs[job_id] + try: + if cfg.last_run: + await self._redis.set( + _LAST_RUN_KEY.format(job_id=job_id), + cfg.last_run.isoformat(), + ) + if cfg.next_run: + await self._redis.set( + _NEXT_RUN_KEY.format(job_id=job_id), + cfg.next_run.isoformat(), + ) + except Exception as e: + logger.warning(f"Failed to persist state for job '{job_id}': {e}") + + async def _execute_job(self, job_id: str) -> dict: + """Run the job function and update state.""" + cfg = self.jobs[job_id] + func = _get_job_func(job_id) + if func is None: + msg = f"No function registered for cron job '{job_id}'" + logger.error(msg) + cfg.last_error = msg + return {"error": msg} + + cfg.running = True + cfg.last_error = None + now = datetime.now(timezone.utc) + logger.info(f"Executing cron job '{job_id}'") + + try: + result = await func() + cfg.last_run = now + cfg.next_run = croniter(cfg.schedule, now).get_next(datetime) + await self._persist_state(job_id) + logger.info(f"Cron job '{job_id}' completed: {result}") + return result or {} + except Exception as e: + cfg.last_error = str(e) + logger.error(f"Cron job '{job_id}' failed: {e}", exc_info=True) + # Still advance next_run so we don't spin on failures + cfg.last_run = now + cfg.next_run = croniter(cfg.schedule, now).get_next(datetime) + await self._persist_state(job_id) + return {"error": str(e)} + finally: + cfg.running = False + + async def _loop(self) -> None: + """Main scheduler loop – checks every 30s for due jobs.""" + while self._running: + try: + now = datetime.now(timezone.utc) + for job_id, cfg in self.jobs.items(): + if not cfg.enabled or cfg.running: + continue + if cfg.next_run and now >= cfg.next_run: + asyncio.create_task(self._execute_job(job_id)) + except Exception as e: + logger.error(f"Error in cron scheduler loop: {e}", exc_info=True) + await asyncio.sleep(30) + + +# --------------------------------------------------------------------------- +# Singleton +# --------------------------------------------------------------------------- + +_scheduler: Optional[CronScheduler] = None + + +def get_scheduler() -> CronScheduler: + """Get (or create) the global CronScheduler singleton.""" + global _scheduler + if _scheduler is None: + _scheduler = CronScheduler() + return _scheduler diff --git a/backends/advanced/src/advanced_omi_backend/models/annotation.py b/backends/advanced/src/advanced_omi_backend/models/annotation.py index ac8ceefe..8eecf81a 100644 --- a/backends/advanced/src/advanced_omi_backend/models/annotation.py +++ b/backends/advanced/src/advanced_omi_backend/models/annotation.py @@ -19,6 +19,7 @@ class AnnotationType(str, Enum): MEMORY = "memory" TRANSCRIPT = "transcript" DIARIZATION = "diarization" # Speaker identification corrections + ENTITY = "entity" # Knowledge graph entity corrections (name/details edits) class AnnotationSource(str, Enum): @@ -70,6 +71,12 @@ class Annotation(Document): corrected_speaker: Optional[str] = None # Speaker label after correction segment_start_time: Optional[float] = None # Time offset for reference + # For ENTITY annotations: + # Dual purpose: feeds both the jargon pipeline (entity name corrections = domain vocabulary + # the ASR should know) and the entity extraction pipeline (corrections improve future accuracy). + entity_id: Optional[str] = None # Neo4j entity ID + entity_field: Optional[str] = None # Which field was changed ("name" or "details") + # Processed tracking (applies to ALL annotation types) processed: bool = Field(default=False) # Whether annotation has been applied/sent to training processed_at: Optional[datetime] = None # When annotation was processed @@ -88,11 +95,12 @@ class Settings: # Create indexes on commonly queried fields # Note: Enum fields and Optional fields don't use Indexed() wrapper indexes = [ - "annotation_type", # Query by type (memory vs transcript vs diarization) + "annotation_type", # Query by type (memory vs transcript vs diarization vs entity) "user_id", # User-scoped queries "status", # Filter by status (pending/accepted/rejected) "memory_id", # Lookup annotations for specific memory "conversation_id", # Lookup annotations for specific conversation + "entity_id", # Lookup annotations for specific entity "processed", # Query unprocessed annotations ] @@ -108,6 +116,10 @@ def is_diarization_annotation(self) -> bool: """Check if this is a diarization annotation.""" return self.annotation_type == AnnotationType.DIARIZATION + def is_entity_annotation(self) -> bool: + """Check if this is an entity annotation.""" + return self.annotation_type == AnnotationType.ENTITY + def is_pending_suggestion(self) -> bool: """Check if this is a pending AI suggestion.""" return ( @@ -151,6 +163,18 @@ class DiarizationAnnotationCreate(BaseModel): status: AnnotationStatus = AnnotationStatus.ACCEPTED +class EntityAnnotationCreate(BaseModel): + """Create entity annotation request. + + Dual purpose: feeds both the jargon pipeline (entity name corrections = domain vocabulary + the ASR should know) and the entity extraction pipeline (corrections improve future accuracy). + """ + entity_id: str + entity_field: str # "name" or "details" + original_text: str + corrected_text: str + + class AnnotationResponse(BaseModel): """Annotation response for API.""" id: str @@ -164,6 +188,8 @@ class AnnotationResponse(BaseModel): original_speaker: Optional[str] = None corrected_speaker: Optional[str] = None segment_start_time: Optional[float] = None + entity_id: Optional[str] = None + entity_field: Optional[str] = None processed: bool = False processed_at: Optional[datetime] = None processed_by: Optional[str] = None diff --git a/backends/advanced/src/advanced_omi_backend/prompt_defaults.py b/backends/advanced/src/advanced_omi_backend/prompt_defaults.py index eca71cfc..58a1fdd0 100644 --- a/backends/advanced/src/advanced_omi_backend/prompt_defaults.py +++ b/backends/advanced/src/advanced_omi_backend/prompt_defaults.py @@ -260,6 +260,60 @@ def register_all_defaults(registry: PromptRegistry) -> None: category="memory", ) + # ------------------------------------------------------------------ + # memory.reprocess_speaker_update + # ------------------------------------------------------------------ + registry.register_default( + "memory.reprocess_speaker_update", + template="""\ +You are a memory correction system. A conversation's transcript has been reprocessed with \ +updated speaker identification. The words spoken are the same, but speakers have been \ +re-identified more accurately. Your job is to update the existing memories so they \ +correctly attribute information to the right people. + +## Rules + +1. **UPDATE** — If a memory attributes information to a speaker whose label changed, \ +rewrite it with the correct speaker name. Keep the same `id`. +2. **NONE** — If the memory is unaffected by the speaker changes, leave it unchanged. +3. **DELETE** — If a memory is now nonsensical or completely wrong because the speaker \ +was misidentified (e.g., personal traits wrongly attributed), remove it. +4. **ADD** — If the corrected transcript reveals important new facts that become clear \ +only with the correct speaker attribution, add them. + +## Important guidelines + +- Focus on **speaker attribution corrections**. This is the primary reason for reprocessing. +- A change from "Speaker 0" to "John" means memories referencing "Speaker 0" must now \ +reference "John". +- A change from "Alice" to "Bob" means facts previously attributed to "Alice" must be \ +attributed to "Bob" instead — this is critical because it changes *who* said or did something. +- Preserve the factual content when only the speaker name changes. +- Do NOT add memories that duplicate existing ones. +- When you UPDATE, always include `old_memory` with the previous text. + +## Output format (strict JSON only) + +Return ONLY a valid JSON object with this structure: + +{ + "memory": [ + { + "id": "", + "event": "UPDATE|NONE|DELETE|ADD", + "text": "", + "old_memory": "" + } + ] +} + +Do not output any text outside the JSON object. +""", + name="Reprocess Speaker Update", + description="Updates existing memories after speaker re-identification to correct speaker attribution.", + category="memory", + ) + # ------------------------------------------------------------------ # memory.temporal_extraction # ------------------------------------------------------------------ @@ -343,44 +397,23 @@ def register_all_defaults(registry: PromptRegistry) -> None: ) # ------------------------------------------------------------------ - # conversation.title + # conversation.title_summary # ------------------------------------------------------------------ registry.register_default( - "conversation.title", + "conversation.title_summary", template="""\ -Generate a concise, descriptive title (3-6 words) for this conversation transcript. - -Rules: -- Maximum 6 words -- Capture the main topic or theme -- Do NOT include speaker names or participants -- No quotes or special characters -- Examples: "Planning Weekend Trip", "Work Project Discussion", "Medical Appointment" - -Title:""", - name="Conversation Title", - description="Generates a short title for a conversation from its transcript.", - category="conversation", - ) +Based on the full conversation transcript below, generate a concise title and a brief summary. - # ------------------------------------------------------------------ - # conversation.short_summary - # ------------------------------------------------------------------ - registry.register_default( - "conversation.short_summary", - template="""\ -Generate a brief, informative summary (1-2 sentences, max 120 characters) for this conversation. +Respond in this exact format: +Title: +Summary: Rules: -- Maximum 120 characters -- 1-2 complete sentences -{{speaker_instruction}}- Capture key topics and outcomes -- Use present tense -- Be specific and informative - -Summary:""", - name="Conversation Short Summary", - description="Generates a brief 1-2 sentence summary of a conversation.", +- Title: Maximum 6 words, capture the main topic/theme, no quotes or special characters +- Summary: Maximum 120 characters, capture key topics and outcomes, use present tense +{{speaker_instruction}}""", + name="Conversation Title & Summary", + description="Generates both title and short summary from full conversation context in one LLM call.", category="conversation", variables=["speaker_instruction"], is_dynamic=True, @@ -486,6 +519,42 @@ def register_all_defaults(registry: PromptRegistry) -> None: category="knowledge_graph", ) + # ------------------------------------------------------------------ + # asr.hot_words + # ------------------------------------------------------------------ + registry.register_default( + "asr.hot_words", + template="hey vivi, chronicle, omi", + name="ASR Hot Words", + description="Comma-separated hot words for speech recognition. " + "For Deepgram: boosts keyword recognition via keyterm. " + "For VibeVoice: passed as context_info to guide the LLM backbone. " + "Supports names, technical terms, and domain-specific vocabulary.", + category="asr", + ) + + # ------------------------------------------------------------------ + # asr.jargon_extraction + # ------------------------------------------------------------------ + registry.register_default( + "asr.jargon_extraction", + template="""\ +Extract up to 20 key jargon terms, names, and technical vocabulary from these memory facts. +Return ONLY a comma-separated list of words or short phrases (1-3 words each). +Focus on: proper nouns, technical terms, domain-specific vocabulary, names of people/places/products. +Skip generic everyday words. + +Memory facts: +{{memories}} + +Jargon:""", + name="ASR Jargon Extraction", + description="Extracts key jargon terms from user memories for ASR context boosting.", + category="asr", + variables=["memories"], + is_dynamic=True, + ) + # ------------------------------------------------------------------ # transcription.title_summary # ------------------------------------------------------------------ diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py index f85a99ed..c4e49ce1 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/annotation_routes.py @@ -19,10 +19,12 @@ AnnotationStatus, AnnotationType, DiarizationAnnotationCreate, + EntityAnnotationCreate, MemoryAnnotationCreate, TranscriptAnnotationCreate, ) from advanced_omi_backend.models.conversation import Conversation +from advanced_omi_backend.services.knowledge_graph import get_knowledge_graph_service from advanced_omi_backend.services.memory import get_memory_service from advanced_omi_backend.users import User @@ -266,6 +268,25 @@ async def update_annotation_status( except Exception as e: logger.error(f"Error applying transcript suggestion: {e}") # Don't fail the status update if segment update fails + elif annotation.is_entity_annotation(): + # Update entity in Neo4j + try: + kg_service = get_knowledge_graph_service() + update_kwargs = {} + if annotation.entity_field == "name": + update_kwargs["name"] = annotation.corrected_text + elif annotation.entity_field == "details": + update_kwargs["details"] = annotation.corrected_text + if update_kwargs: + await kg_service.update_entity( + entity_id=annotation.entity_id, + user_id=annotation.user_id, + **update_kwargs, + ) + logger.info(f"Applied entity suggestion to entity {annotation.entity_id}") + except Exception as e: + logger.error(f"Error applying entity suggestion: {e}") + # Don't fail the status update if entity update fails await annotation.save() logger.info(f"Updated annotation {annotation_id} status to {status}") @@ -282,6 +303,113 @@ async def update_annotation_status( ) +# === Entity Annotation Routes === + + +@router.post("/entity", response_model=AnnotationResponse) +async def create_entity_annotation( + annotation_data: EntityAnnotationCreate, + current_user: User = Depends(current_active_user), +): + """ + Create annotation for entity edit (name or details correction). + + - Validates user owns the entity + - Creates annotation record for jargon/finetuning pipeline + - Applies correction to Neo4j immediately + - Marked as processed=False for downstream cron consumption + + Dual purpose: entity name corrections feed both the jargon pipeline + (domain vocabulary for ASR) and the entity extraction pipeline + (improving future extraction accuracy). + """ + try: + # Validate entity_field + if annotation_data.entity_field not in ("name", "details"): + raise HTTPException( + status_code=400, + detail="entity_field must be 'name' or 'details'", + ) + + # Verify entity exists and belongs to user + kg_service = get_knowledge_graph_service() + entity = await kg_service.get_entity( + entity_id=annotation_data.entity_id, + user_id=current_user.user_id, + ) + if not entity: + raise HTTPException(status_code=404, detail="Entity not found") + + # Create annotation + annotation = Annotation( + annotation_type=AnnotationType.ENTITY, + user_id=current_user.user_id, + entity_id=annotation_data.entity_id, + entity_field=annotation_data.entity_field, + original_text=annotation_data.original_text, + corrected_text=annotation_data.corrected_text, + status=AnnotationStatus.ACCEPTED, + processed=False, # Unprocessed — jargon/finetuning cron will consume later + ) + await annotation.save() + logger.info( + f"Created entity annotation {annotation.id} for entity {annotation_data.entity_id} " + f"field={annotation_data.entity_field}" + ) + + # Apply correction to Neo4j immediately + try: + update_kwargs = {} + if annotation_data.entity_field == "name": + update_kwargs["name"] = annotation_data.corrected_text + elif annotation_data.entity_field == "details": + update_kwargs["details"] = annotation_data.corrected_text + + await kg_service.update_entity( + entity_id=annotation_data.entity_id, + user_id=current_user.user_id, + **update_kwargs, + ) + logger.info(f"Applied entity correction to Neo4j for entity {annotation_data.entity_id}") + except Exception as e: + logger.error(f"Error applying entity correction to Neo4j: {e}") + # Annotation is saved but Neo4j update failed — log but don't fail the request + + return AnnotationResponse.model_validate(annotation) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error creating entity annotation: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Failed to create entity annotation: {str(e)}", + ) + + +@router.get("/entity/{entity_id}", response_model=List[AnnotationResponse]) +async def get_entity_annotations( + entity_id: str, + current_user: User = Depends(current_active_user), +): + """Get all annotations for an entity.""" + try: + annotations = await Annotation.find( + Annotation.annotation_type == AnnotationType.ENTITY, + Annotation.entity_id == entity_id, + Annotation.user_id == current_user.user_id, + ).to_list() + + return [AnnotationResponse.model_validate(a) for a in annotations] + + except Exception as e: + logger.error(f"Error fetching entity annotations: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Failed to fetch entity annotations: {str(e)}", + ) + + # === Diarization Annotation Routes === diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py index f3792e0b..72c271b6 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/finetuning_routes.py @@ -1,7 +1,8 @@ """ Fine-tuning routes for Chronicle API. -Handles sending annotation corrections to speaker recognition service for training. +Handles sending annotation corrections to speaker recognition service for training +and cron job management for automated tasks. """ import logging @@ -10,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import JSONResponse +from pydantic import BaseModel from advanced_omi_backend.auth import current_active_user from advanced_omi_backend.models.annotation import Annotation, AnnotationType @@ -56,7 +58,7 @@ async def process_annotations_for_training( # Filter out already trained annotations (processed_by contains "training") ready_for_training = [ a for a in annotations - if a.processed_by and "training" not in a.processed_by + if not a.processed_by or "training" not in a.processed_by ] if not ready_for_training: @@ -96,16 +98,14 @@ async def process_annotations_for_training( conversation = await Conversation.find_one( Conversation.conversation_id == annotation.conversation_id ) - + if not conversation or not conversation.active_transcript: - logger.warning(f"Conversation {annotation.conversation_id} not found or has no transcript") failed_count += 1 errors.append(f"Conversation {annotation.conversation_id[:8]} not found") continue # Validate segment index if annotation.segment_index >= len(conversation.active_transcript.segments): - logger.warning(f"Invalid segment index {annotation.segment_index} for conversation {annotation.conversation_id}") failed_count += 1 errors.append(f"Invalid segment index {annotation.segment_index}") continue @@ -198,7 +198,7 @@ async def process_annotations_for_training( "appended_to_existing": appended_count, "total_processed": total_processed, "failed_count": failed_count, - "errors": errors[:10] if errors else [], # Limit error list + "errors": errors[:10] if errors else [], "status": "success" if total_processed > 0 else "partial_failure" }) @@ -227,48 +227,111 @@ async def get_finetuning_status( - cron_status: Cron job schedule and last run info """ try: - # Count annotations by status - pending_count = await Annotation.find( - Annotation.annotation_type == AnnotationType.DIARIZATION, - Annotation.processed == False, - ).count() - - # Get all processed annotations - all_processed = await Annotation.find( - Annotation.annotation_type == AnnotationType.DIARIZATION, - Annotation.processed == True, - ).to_list() - - # Split into trained vs not-yet-trained - trained_annotations = [ - a for a in all_processed - if a.processed_by and "training" in a.processed_by - ] - applied_not_trained = [ - a for a in all_processed - if not a.processed_by or "training" not in a.processed_by - ] - - applied_count = len(applied_not_trained) - trained_count = len(trained_annotations) + # ------------------------------------------------------------------ + # Per-type annotation counts (with orphan detection) + # ------------------------------------------------------------------ + from advanced_omi_backend.models.conversation import Conversation - # Get last training run timestamp + annotation_counts: dict[str, dict] = {} + trained_diarization_list: list = [] + + # Collect all annotations to batch-check for orphans + all_annotations_by_type: dict[AnnotationType, list] = {} + for ann_type in AnnotationType: + all_annotations_by_type[ann_type] = await Annotation.find( + Annotation.annotation_type == ann_type, + ).to_list() + + # Batch-check which conversation_ids still exist + conv_annotation_types = {AnnotationType.DIARIZATION, AnnotationType.TRANSCRIPT} + all_conv_ids: set[str] = set() + for ann_type in conv_annotation_types: + for a in all_annotations_by_type.get(ann_type, []): + if a.conversation_id: + all_conv_ids.add(a.conversation_id) + + existing_conv_ids: set[str] = set() + if all_conv_ids: + existing_convs = await Conversation.find( + {"conversation_id": {"$in": list(all_conv_ids)}}, + ).to_list() + existing_conv_ids = {c.conversation_id for c in existing_convs} + + orphaned_conv_ids = all_conv_ids - existing_conv_ids + + total_orphaned = 0 + for ann_type in AnnotationType: + annotations = all_annotations_by_type[ann_type] + + # Identify orphaned annotations for conversation-based types + if ann_type in conv_annotation_types: + orphaned = [a for a in annotations if a.conversation_id in orphaned_conv_ids] + non_orphaned = [a for a in annotations if a.conversation_id not in orphaned_conv_ids] + else: + # Memory/entity orphan detection is placeholder for now + orphaned = [] + non_orphaned = annotations + + pending = [a for a in non_orphaned if not a.processed] + processed = [a for a in non_orphaned if a.processed] + trained = [a for a in processed if a.processed_by and "training" in a.processed_by] + applied_not_trained = [ + a for a in processed + if not a.processed_by or "training" not in a.processed_by + ] + + orphan_count = len(orphaned) + total_orphaned += orphan_count + + annotation_counts[ann_type.value] = { + "total": len(non_orphaned), + "pending": len(pending), + "applied": len(applied_not_trained), + "trained": len(trained), + "orphaned": orphan_count, + } + + if ann_type == AnnotationType.DIARIZATION: + trained_diarization_list = trained + + # ------------------------------------------------------------------ + # Diarization-specific fields (backward compat) + # ------------------------------------------------------------------ + diarization = annotation_counts.get("diarization", {}) + pending_count = diarization.get("pending", 0) + applied_count = diarization.get("applied", 0) + trained_count = diarization.get("trained", 0) + + # Get last training run timestamp from diarization annotations last_training_run = None - if trained_annotations: - # Find most recent trained annotation + if trained_diarization_list: latest_trained = max( - trained_annotations, + trained_diarization_list, key=lambda a: a.updated_at if a.updated_at else datetime.min.replace(tzinfo=timezone.utc) ) last_training_run = latest_trained.updated_at.isoformat() if latest_trained.updated_at else None - # TODO: Get cron job status from scheduler - cron_status = { - "enabled": False, # Placeholder - "schedule": "0 2 * * *", # Example: daily at 2 AM - "last_run": None, - "next_run": None, - } + # Get cron job status from scheduler + try: + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + all_jobs = await scheduler.get_all_jobs_status() + # Find speaker finetuning job for backward compat + speaker_job = next((j for j in all_jobs if j["job_id"] == "speaker_finetuning"), None) + cron_status = { + "enabled": speaker_job["enabled"] if speaker_job else False, + "schedule": speaker_job["schedule"] if speaker_job else "0 2 * * *", + "last_run": speaker_job["last_run"] if speaker_job else None, + "next_run": speaker_job["next_run"] if speaker_job else None, + } + except Exception: + cron_status = { + "enabled": False, + "schedule": "0 2 * * *", + "last_run": None, + "next_run": None, + } return JSONResponse(content={ "pending_annotation_count": pending_count, @@ -276,6 +339,8 @@ async def get_finetuning_status( "trained_annotation_count": trained_count, "last_training_run": last_training_run, "cron_status": cron_status, + "annotation_counts": annotation_counts, + "orphaned_annotation_count": total_orphaned, }) except Exception as e: @@ -284,3 +349,154 @@ async def get_finetuning_status( status_code=500, detail=f"Failed to fetch fine-tuning status: {str(e)}", ) + + +# --------------------------------------------------------------------------- +# Orphaned Annotation Management Endpoints +# --------------------------------------------------------------------------- + + +@router.delete("/orphaned-annotations") +async def delete_orphaned_annotations( + current_user: User = Depends(current_active_user), + annotation_type: Optional[str] = Query(None, description="Filter by annotation type (e.g. 'diarization')"), +): + """ + Find and delete orphaned annotations whose referenced conversation no longer exists. + + Only handles conversation-based annotation types (diarization, transcript). + """ + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + from advanced_omi_backend.models.conversation import Conversation + + conv_annotation_types = {AnnotationType.DIARIZATION, AnnotationType.TRANSCRIPT} + + # Filter to requested type if specified + if annotation_type: + try: + requested_type = AnnotationType(annotation_type) + except ValueError: + raise HTTPException(status_code=400, detail=f"Unknown annotation type: {annotation_type}") + if requested_type not in conv_annotation_types: + return JSONResponse(content={"deleted_count": 0, "by_type": {}, "message": "Orphan detection not supported for this type"}) + types_to_check = {requested_type} + else: + types_to_check = conv_annotation_types + + # Collect all conversation_ids referenced by these annotation types + all_conv_ids: set[str] = set() + annotations_by_type: dict[AnnotationType, list] = {} + for ann_type in types_to_check: + annotations = await Annotation.find( + Annotation.annotation_type == ann_type, + ).to_list() + annotations_by_type[ann_type] = annotations + for a in annotations: + if a.conversation_id: + all_conv_ids.add(a.conversation_id) + + if not all_conv_ids: + return JSONResponse(content={"deleted_count": 0, "by_type": {}}) + + # Batch-check which conversations still exist + existing_convs = await Conversation.find( + {"conversation_id": {"$in": list(all_conv_ids)}}, + ).to_list() + existing_conv_ids = {c.conversation_id for c in existing_convs} + orphaned_conv_ids = all_conv_ids - existing_conv_ids + + if not orphaned_conv_ids: + return JSONResponse(content={"deleted_count": 0, "by_type": {}}) + + # Delete orphaned annotations + deleted_by_type: dict[str, int] = {} + total_deleted = 0 + for ann_type, annotations in annotations_by_type.items(): + orphaned = [a for a in annotations if a.conversation_id in orphaned_conv_ids] + for a in orphaned: + await a.delete() + if orphaned: + deleted_by_type[ann_type.value] = len(orphaned) + total_deleted += len(orphaned) + + logger.info(f"Deleted {total_deleted} orphaned annotations: {deleted_by_type}") + return JSONResponse(content={ + "deleted_count": total_deleted, + "by_type": deleted_by_type, + }) + + +@router.post("/orphaned-annotations/reattach") +async def reattach_orphaned_annotations( + current_user: User = Depends(current_active_user), +): + """Placeholder for reattaching orphaned annotations to a different conversation.""" + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + raise HTTPException(status_code=501, detail="Reattach functionality coming soon") + + +# --------------------------------------------------------------------------- +# Cron Job Management Endpoints +# --------------------------------------------------------------------------- + + +class CronJobUpdate(BaseModel): + enabled: Optional[bool] = None + schedule: Optional[str] = None + + +@router.get("/cron-jobs") +async def get_cron_jobs(current_user: User = Depends(current_active_user)): + """List all cron jobs with status, schedule, last/next run.""" + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + return await scheduler.get_all_jobs_status() + + +@router.put("/cron-jobs/{job_id}") +async def update_cron_job( + job_id: str, + body: CronJobUpdate, + current_user: User = Depends(current_active_user), +): + """Update a cron job's schedule or enabled state.""" + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + try: + await scheduler.update_job(job_id, enabled=body.enabled, schedule=body.schedule) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + return {"message": f"Job '{job_id}' updated", "job_id": job_id} + + +@router.post("/cron-jobs/{job_id}/run") +async def run_cron_job_now( + job_id: str, + current_user: User = Depends(current_active_user), +): + """Manually trigger a cron job.""" + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Admin access required") + + from advanced_omi_backend.cron_scheduler import get_scheduler + + scheduler = get_scheduler() + try: + result = await scheduler.run_job_now(job_id) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + return result diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py index d4680951..060b6ff7 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py @@ -13,6 +13,11 @@ from pydantic import BaseModel from advanced_omi_backend.auth import current_active_user +from advanced_omi_backend.models.annotation import ( + Annotation, + AnnotationStatus, + AnnotationType, +) from advanced_omi_backend.services.knowledge_graph import ( KnowledgeGraphService, PromiseStatus, @@ -30,6 +35,13 @@ # ============================================================================= +class UpdateEntityRequest(BaseModel): + """Request model for updating entity fields.""" + name: Optional[str] = None + details: Optional[str] = None + icon: Optional[str] = None + + class UpdatePromiseRequest(BaseModel): """Request model for updating promise status.""" status: str # pending, in_progress, completed, cancelled @@ -144,6 +156,78 @@ async def get_entity_relationships( ) +@router.patch("/entities/{entity_id}") +async def update_entity( + entity_id: str, + request: UpdateEntityRequest, + current_user: User = Depends(current_active_user), +): + """Update an entity's name, details, or icon. + + Also creates entity annotations as a side effect for each changed field. + These annotations feed the jargon and entity extraction pipelines. + """ + try: + if not request.name and not request.details and not request.icon: + raise HTTPException( + status_code=400, + detail="At least one field (name, details, icon) must be provided", + ) + + service = get_knowledge_graph_service() + + # Get current entity for annotation original values + existing = await service.get_entity( + entity_id=entity_id, + user_id=str(current_user.id), + ) + if not existing: + raise HTTPException(status_code=404, detail="Entity not found") + + # Apply update to Neo4j + updated = await service.update_entity( + entity_id=entity_id, + user_id=str(current_user.id), + name=request.name, + details=request.details, + icon=request.icon, + ) + if not updated: + raise HTTPException(status_code=404, detail="Entity not found") + + # Create annotations for changed text fields (name, details) + # These feed the jargon pipeline and entity extraction pipeline. + # Icon changes don't create annotations (not text corrections). + for field in ("name", "details"): + new_value = getattr(request, field) + if new_value is not None: + old_value = getattr(existing, field) or "" + annotation = Annotation( + annotation_type=AnnotationType.ENTITY, + user_id=str(current_user.id), + entity_id=entity_id, + entity_field=field, + original_text=old_value, + corrected_text=new_value, + status=AnnotationStatus.ACCEPTED, + processed=False, + ) + await annotation.save() + logger.info( + f"Created entity annotation for {field} change on entity {entity_id}" + ) + + return {"entity": updated.to_dict()} + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating entity {entity_id}: {e}") + return JSONResponse( + status_code=500, + content={"message": f"Error updating entity: {str(e)}"}, + ) + + @router.delete("/entities/{entity_id}") async def delete_entity( entity_id: str, diff --git a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py index 6562dccc..5f13508d 100644 --- a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py +++ b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py @@ -535,6 +535,44 @@ async def search_entities( return entities + async def update_entity( + self, + entity_id: str, + user_id: str, + name: str = None, + details: str = None, + icon: str = None, + ) -> Optional[Entity]: + """Update an entity's fields (partial update via COALESCE). + + Args: + entity_id: Entity UUID + user_id: User ID for permission check + name: New name (None keeps existing) + details: New details (None keeps existing) + icon: New icon (None keeps existing) + + Returns: + Updated Entity object or None if not found + """ + self._ensure_initialized() + + results = self._write.run( + queries.UPDATE_ENTITY, + id=entity_id, + user_id=user_id, + name=name, + details=details, + icon=icon, + metadata=None, + ) + + if not results: + return None + + entity_data = dict(results[0]["e"]) + return self._row_to_entity(entity_data) + async def delete_entity( self, entity_id: str, diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/base.py b/backends/advanced/src/advanced_omi_backend/services/memory/base.py index 7df7748e..277a92f8 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/base.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/base.py @@ -205,6 +205,45 @@ async def update_memory( """ return False + async def reprocess_memory( + self, + transcript: str, + client_id: str, + source_id: str, + user_id: str, + user_email: str, + transcript_diff: Optional[List[Dict[str, Any]]] = None, + previous_transcript: Optional[str] = None, + ) -> Tuple[bool, List[str]]: + """Reprocess memories after transcript or speaker changes. + + This method is called when a conversation's transcript has been + reprocessed (e.g., speaker re-identification) and memories need + to be updated to reflect the changes. + + The default implementation falls back to normal ``add_memory`` + with ``allow_update=True``. Providers that support diff-aware + reprocessing should override this method. + + Args: + transcript: Updated full transcript text (with corrected speakers) + client_id: Client identifier + source_id: Conversation/source identifier + user_id: User identifier + user_email: User email address + transcript_diff: List of dicts describing what changed between + the old and new transcript (speaker changes, text changes). + Each dict has keys like ``type``, ``old_speaker``, + ``new_speaker``, ``text``, ``start``, ``end``. + previous_transcript: The previous transcript text (before changes) + + Returns: + Tuple of (success: bool, affected_memory_ids: List[str]) + """ + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, allow_update=True + ) + @abstractmethod async def delete_memory( self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None @@ -331,6 +370,37 @@ async def propose_memory_actions( """ pass + async def propose_reprocess_actions( + self, + existing_memories: List[Dict[str, str]], + diff_context: str, + new_transcript: str, + custom_prompt: Optional[str] = None, + ) -> Dict[str, Any]: + """Propose memory updates after transcript reprocessing (e.g., speaker changes). + + Uses the LLM to review existing conversation memories in light of + specific transcript changes (speaker re-identification, text corrections) + and propose targeted ADD/UPDATE/DELETE/NONE actions. + + Default implementation raises NotImplementedError. Providers that + support diff-aware reprocessing should override this method. + + Args: + existing_memories: List of existing memories for the conversation + (each dict has ``id`` and ``text`` keys) + diff_context: Formatted string describing what changed in the + transcript (e.g., speaker relabelling details) + new_transcript: The updated full transcript text + custom_prompt: Optional custom system prompt + + Returns: + Dictionary containing proposed actions in ``{"memory": [...]}`` format + """ + raise NotImplementedError( + f"{type(self).__name__} does not support propose_reprocess_actions" + ) + @abstractmethod async def test_connection(self) -> bool: """Test connection to the LLM provider. @@ -415,6 +485,24 @@ async def count_memories(self, user_id: str) -> Optional[int]: """ return None + async def get_memories_by_source( + self, user_id: str, source_id: str, limit: int = 100 + ) -> List["MemoryEntry"]: + """Get all memories for a specific source (conversation) for a user. + + Default implementation returns empty list. Vector stores should + override to filter by metadata.source_id. + + Args: + user_id: User identifier + source_id: Source/conversation identifier + limit: Maximum number of memories to return + + Returns: + List of MemoryEntry objects for the specified source + """ + return [] + @abstractmethod async def update_memory( self, diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py b/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py index 3e4f4535..0e704be3 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py @@ -197,6 +197,80 @@ """ +REPROCESS_SPEAKER_UPDATE_PROMPT = """ +You are a memory correction system. A conversation's transcript has been reprocessed with \ +updated speaker identification. The words spoken are the same, but speakers have been \ +re-identified more accurately. Your job is to update the existing memories so they \ +correctly attribute information to the right people. + +## Rules + +1. **UPDATE** — If a memory attributes information to a speaker whose label changed, \ +rewrite it with the correct speaker name. Keep the same `id`. +2. **NONE** — If the memory is unaffected by the speaker changes, leave it unchanged. +3. **DELETE** — If a memory is now nonsensical or completely wrong because the speaker \ +was misidentified (e.g., personal traits wrongly attributed), remove it. +4. **ADD** — If the corrected transcript reveals important new facts that become clear \ +only with the correct speaker attribution, add them. + +## Important guidelines + +- Focus on **speaker attribution corrections**. This is the primary reason for reprocessing. +- A change from "Speaker 0" to "John" means memories referencing "Speaker 0" must now \ +reference "John". +- A change from "Alice" to "Bob" means facts previously attributed to "Alice" must be \ +attributed to "Bob" instead — this is critical because it changes *who* said or did something. +- Preserve the factual content when only the speaker name changes. +- Do NOT add memories that duplicate existing ones. +- When you UPDATE, always include `old_memory` with the previous text. + +## Output format (strict JSON only) + +Return ONLY a valid JSON object with this structure: + +{ + "memory": [ + { + "id": "", + "event": "UPDATE|NONE|DELETE|ADD", + "text": "", + "old_memory": "" + } + ] +} + +Do not output any text outside the JSON object. +""" + + +def build_reprocess_speaker_messages( + existing_memories: list, + diff_context: str, + new_transcript: str, +) -> str: + """Build the user message for the reprocess-after-speaker-change LLM call. + + Args: + existing_memories: List of dicts with ``id`` and ``text`` keys + diff_context: Formatted string of speaker changes + new_transcript: Full updated transcript with corrected speakers + + Returns: + Formatted user message string + """ + memories_json = json.dumps(existing_memories, ensure_ascii=False) + + return ( + "## Existing Memories for This Conversation\n" + f"{memories_json}\n\n" + "## Speaker Changes in Transcript\n" + f"{diff_context}\n\n" + "## Updated Full Transcript (with corrected speakers)\n" + f"{new_transcript}\n\n" + "Output:" + ) + + PROCEDURAL_MEMORY_SYSTEM_PROMPT = """ You are a memory summarization system that records and preserves the complete interaction history between a human and an AI agent. You are provided with the agent's execution history over the past N steps. Your task is to produce a comprehensive summary of the agent's output history that contains every detail necessary for the agent to continue the task without ambiguity. **Every output produced by the agent must be recorded verbatim as part of the summary.** diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py index 1eddae93..09accf00 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py @@ -476,6 +476,202 @@ def shutdown(self) -> None: self.vector_store = None memory_logger.info("Memory service shut down") + async def reprocess_memory( + self, + transcript: str, + client_id: str, + source_id: str, + user_id: str, + user_email: str, + transcript_diff: Optional[list] = None, + previous_transcript: Optional[str] = None, + ) -> Tuple[bool, List[str]]: + """Reprocess memories after speaker re-identification. + + Instead of extracting fresh facts from scratch, this method: + 1. Fetches existing memories for this specific conversation + 2. Computes what changed (speaker labels) between old and new transcript + 3. Asks the LLM to make targeted updates to the existing memories + + Falls back to normal ``add_memory`` when there are no existing + memories or no meaningful diff. + + Args: + transcript: Updated full transcript (with corrected speakers) + client_id: Client identifier + source_id: Conversation identifier + user_id: User identifier + user_email: User email + transcript_diff: List of dicts describing speaker changes + previous_transcript: Previous transcript text (before changes) + + Returns: + Tuple of (success, affected_memory_ids) + """ + await self._ensure_initialized() + + try: + # 1. Get existing memories for this conversation + existing_memories = await self.vector_store.get_memories_by_source( + user_id, source_id + ) + + # 2. If no existing memories, fall back to normal extraction + if not existing_memories: + memory_logger.info( + f"🔄 Reprocess: no existing memories for {source_id}, " + f"falling back to normal extraction" + ) + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + + # 3. If no diff provided, fall back to normal extraction + if not transcript_diff: + memory_logger.info( + f"🔄 Reprocess: no transcript diff for {source_id}, " + f"falling back to normal extraction" + ) + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + + # 4. Format the diff for the LLM + diff_text = self._format_speaker_diff(transcript_diff) + + memory_logger.info( + f"🔄 Reprocess: {len(existing_memories)} existing memories, " + f"{len(transcript_diff)} speaker changes for {source_id}" + ) + + # 5. Build temp ID mapping (avoid hallucinated UUIDs) + temp_uuid_mapping = {} + existing_memory_dicts = [] + for idx, mem in enumerate(existing_memories): + temp_uuid_mapping[str(idx)] = mem.id + existing_memory_dicts.append({"id": str(idx), "text": mem.content}) + + # 6. Ask LLM for targeted update actions + try: + actions_obj = await self.llm_provider.propose_reprocess_actions( + existing_memories=existing_memory_dicts, + diff_context=diff_text, + new_transcript=transcript, + ) + memory_logger.info( + f"🔄 Reprocess LLM returned actions: {actions_obj}" + ) + except NotImplementedError: + memory_logger.warning( + "LLM provider does not support propose_reprocess_actions, " + "falling back to normal extraction" + ) + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + except Exception as e: + memory_logger.error(f"Reprocess LLM call failed: {e}") + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + + # 7. Normalize and pre-generate embeddings for ADD/UPDATE actions + actions_list = self._normalize_actions(actions_obj) + + texts_needing_embeddings = [ + action.get("text") + for action in actions_list + if action.get("event") in ("ADD", "UPDATE") + and action.get("text") + and isinstance(action.get("text"), str) + ] + + text_to_embedding = {} + if texts_needing_embeddings: + try: + embeddings = await asyncio.wait_for( + self.llm_provider.generate_embeddings(texts_needing_embeddings), + timeout=self.config.timeout_seconds, + ) + text_to_embedding = dict( + zip(texts_needing_embeddings, embeddings) + ) + except Exception as e: + memory_logger.warning( + f"Batch embedding generation failed for reprocess: {e}" + ) + + # 8. Apply the actions (reuses existing infrastructure) + created_ids = await self._apply_memory_actions( + actions_list, + text_to_embedding, + temp_uuid_mapping, + client_id, + source_id, + user_id, + user_email, + ) + + memory_logger.info( + f"✅ Reprocess complete for {source_id}: " + f"{len(created_ids)} memories affected" + ) + return True, created_ids + + except Exception as e: + memory_logger.error( + f"❌ Reprocess memory failed for {source_id}: {e}" + ) + # Fall back to normal extraction on any unexpected error + memory_logger.info( + f"🔄 Falling back to normal extraction after reprocess error" + ) + return await self.add_memory( + transcript, client_id, source_id, user_id, user_email, + allow_update=True, + ) + + @staticmethod + def _format_speaker_diff(transcript_diff: list) -> str: + """Format a transcript diff into a human-readable string for the LLM. + + Args: + transcript_diff: List of change dicts from + ``compute_speaker_diff`` + + Returns: + Formatted multi-line string describing the changes + """ + if not transcript_diff: + return "No changes detected." + + lines = [] + for change in transcript_diff: + change_type = change.get("type", "unknown") + if change_type == "speaker_change": + lines.append( + f"- \"{change.get('text', '')}\" " + f"was spoken by \"{change.get('old_speaker', '?')}\" " + f"but is now identified as \"{change.get('new_speaker', '?')}\"" + ) + elif change_type == "text_change": + lines.append( + f"- Segment by {change.get('speaker', '?')}: " + f"text changed from \"{change.get('old_text', '')}\" " + f"to \"{change.get('new_text', '')}\"" + ) + elif change_type == "new_segment": + lines.append( + f"- New segment: {change.get('speaker', '?')}: " + f"\"{change.get('text', '')}\"" + ) + + return "\n".join(lines) + # Private helper methods def _deduplicate_memories(self, memories_text: List[str]) -> List[str]: diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py index 9b00e8b1..a3f68c5f 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py @@ -20,9 +20,9 @@ from ..base import LLMProviderBase from ..prompts import ( - FACT_RETRIEVAL_PROMPT, + REPROCESS_SPEAKER_UPDATE_PROMPT, + build_reprocess_speaker_messages, build_update_memory_messages, - get_update_memory_messages, ) from ..update_memory_utils import ( extract_assistant_xml_from_openai_response, @@ -357,6 +357,97 @@ async def propose_memory_actions( return {} + async def propose_reprocess_actions( + self, + existing_memories: List[Dict[str, str]], + diff_context: str, + new_transcript: str, + custom_prompt: Optional[str] = None, + ) -> Dict[str, Any]: + """Propose memory updates after speaker re-identification. + + Sends the existing conversation memories, the speaker change diff, + and the corrected transcript to the LLM. Returns JSON with + ADD/UPDATE/DELETE/NONE actions. + + The system prompt is resolved in priority order: + 1. ``custom_prompt`` argument (if provided) + 2. Langfuse override via the prompt registry + (prompt id ``memory.reprocess_speaker_update``) + 3. Registered default from ``prompt_defaults.py`` + + Args: + existing_memories: List of {id, text} dicts for this conversation + diff_context: Formatted string of speaker changes + new_transcript: Full updated transcript with corrected speakers + custom_prompt: Optional custom system prompt + + Returns: + Dictionary with ``memory`` key containing action list + """ + try: + # Resolve prompt: explicit arg → Langfuse/registry → hardcoded fallback + if custom_prompt and custom_prompt.strip(): + system_prompt = custom_prompt + else: + try: + registry = get_prompt_registry() + system_prompt = await registry.get_prompt( + "memory.reprocess_speaker_update" + ) + except Exception as e: + memory_logger.debug( + f"Registry prompt fetch failed for " + f"memory.reprocess_speaker_update: {e}, " + f"using hardcoded fallback" + ) + system_prompt = REPROCESS_SPEAKER_UPDATE_PROMPT + + user_content = build_reprocess_speaker_messages( + existing_memories, diff_context, new_transcript + ) + + messages = [ + {"role": "system", "content": system_prompt.strip()}, + {"role": "user", "content": user_content}, + ] + + memory_logger.info( + f"🔄 Reprocess: asking LLM with {len(existing_memories)} existing memories " + f"and speaker diff" + ) + memory_logger.debug( + f"🔄 Reprocess user content (first 300 chars): {user_content[:300]}..." + ) + + client = _get_openai_client( + api_key=self.api_key, base_url=self.base_url, is_async=True + ) + response = await client.chat.completions.create( + model=self.model, + messages=messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + response_format={"type": "json_object"}, + ) + content = (response.choices[0].message.content or "").strip() + + if not content: + memory_logger.warning("Reprocess LLM returned empty content") + return {} + + result = json.loads(content) + memory_logger.info(f"🔄 Reprocess LLM returned: {result}") + return result + + except json.JSONDecodeError as e: + memory_logger.error(f"Reprocess LLM returned invalid JSON: {e}") + return {} + except Exception as e: + memory_logger.error(f"propose_reprocess_actions failed: {e}") + return {} + + class OllamaProvider(LLMProviderBase): """Ollama LLM provider implementation. diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py index 9fed0126..50678642 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py @@ -19,6 +19,7 @@ FilterSelector, MatchValue, PointStruct, + Range, VectorParams, ) @@ -445,7 +446,112 @@ async def get_memory(self, memory_id: str, user_id: Optional[str] = None) -> Opt memory_logger.error(f"Qdrant get memory failed for {memory_id}: {e}") return None + async def get_memories_by_source( + self, user_id: str, source_id: str, limit: int = 100 + ) -> List[MemoryEntry]: + """Get all memories for a specific source (conversation) for a user. + Filters by both user_id and source_id in metadata to return only + memories extracted from a particular conversation. + Args: + user_id: User identifier + source_id: Source/conversation identifier + limit: Maximum number of memories to return + + Returns: + List of MemoryEntry objects for the specified source + """ + try: + search_filter = Filter( + must=[ + FieldCondition( + key="metadata.user_id", + match=MatchValue(value=user_id), + ), + FieldCondition( + key="metadata.source_id", + match=MatchValue(value=source_id), + ), + ] + ) + + results = await self.client.scroll( + collection_name=self.collection_name, + scroll_filter=search_filter, + limit=limit, + ) + + memories = [] + for point in results[0]: + memory = MemoryEntry( + id=str(point.id), + content=point.payload.get("content", ""), + metadata=point.payload.get("metadata", {}), + created_at=point.payload.get("created_at"), + updated_at=point.payload.get("updated_at"), + ) + memories.append(memory) + + memory_logger.info( + f"Found {len(memories)} memories for source {source_id} (user {user_id})" + ) + return memories + + except Exception as e: + memory_logger.error(f"Qdrant get memories by source failed: {e}") + return [] + + async def get_recent_memories( + self, user_id: str, since_timestamp: int, limit: int = 100 + ) -> List[MemoryEntry]: + """Get memories created after a given unix timestamp for a user. + Args: + user_id: User identifier + since_timestamp: Unix timestamp; only memories at or after this time are returned + limit: Maximum number of memories to return + + Returns: + List of MemoryEntry objects + """ + try: + search_filter = Filter( + must=[ + FieldCondition( + key="metadata.user_id", + match=MatchValue(value=user_id), + ), + FieldCondition( + key="metadata.timestamp", + range=Range(gte=since_timestamp), + ), + ] + ) + + results = await self.client.scroll( + collection_name=self.collection_name, + scroll_filter=search_filter, + limit=limit, + ) + + memories = [] + for point in results[0]: + memory = MemoryEntry( + id=str(point.id), + content=point.payload.get("content", ""), + metadata=point.payload.get("metadata", {}), + created_at=point.payload.get("created_at"), + updated_at=point.payload.get("updated_at"), + ) + memories.append(memory) + + memory_logger.info( + f"Found {len(memories)} recent memories since {since_timestamp} for user {user_id}" + ) + return memories + + except Exception as e: + memory_logger.error(f"Qdrant get recent memories failed: {e}") + return [] diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py index 71b213b8..9c7f1d21 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py @@ -17,6 +17,7 @@ from advanced_omi_backend.config_loader import get_backend_config from advanced_omi_backend.model_registry import get_models_registry +from advanced_omi_backend.prompt_registry import get_prompt_registry from .base import ( BaseTranscriptionProvider, @@ -27,6 +28,26 @@ logger = logging.getLogger(__name__) +def _parse_hot_words_to_keyterm(hot_words_str: str) -> str: + """Convert comma-separated hot words to Deepgram keyterm format. + + Input: "hey vivi, chronicle, omi" + Output: "hey vivi Hey Vivi chronicle Chronicle omi Omi" + """ + if not hot_words_str or not hot_words_str.strip(): + return "" + terms = [] + for word in hot_words_str.split(","): + word = word.strip() + if not word: + continue + terms.append(word) + capitalized = word.title() + if capitalized != word: + terms.append(capitalized) + return " ".join(terms) + + def _dotted_get(d: dict | list | None, dotted: Optional[str]): """Safely extract a value from nested dict/list using dotted paths. @@ -99,7 +120,7 @@ def get_capabilities_dict(self) -> dict: """ return {cap: True for cap in self._capabilities} - async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False) -> dict: + async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False, context_info: Optional[str] = None, **kwargs) -> dict: # Special handling for mock provider (no HTTP server needed) if self.model.model_provider == "mock": from .mock_provider import MockTranscriptionProvider @@ -148,14 +169,34 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = if "diarize" in query: query["diarize"] = "true" if diarize else "false" + # Use caller-provided context or fall back to LangFuse prompt store + if context_info: + hot_words_str = context_info + else: + hot_words_str = "" + try: + registry = get_prompt_registry() + hot_words_str = await registry.get_prompt("asr.hot_words") + except Exception as e: + logger.debug(f"Failed to fetch asr.hot_words prompt: {e}") + + # For Deepgram: inject as keyterm query param + if self.model.model_provider == "deepgram" and hot_words_str.strip(): + keyterm = _parse_hot_words_to_keyterm(hot_words_str) + if keyterm: + query["keyterm"] = keyterm + timeout = op.get("timeout", 300) try: async with httpx.AsyncClient(timeout=timeout) as client: if method == "POST": if use_multipart: - # Send as multipart file upload (for Parakeet) + # Send as multipart file upload (for Parakeet/VibeVoice) files = {"file": ("audio.wav", audio_data, "audio/wav")} - resp = await client.post(url, headers=headers, params=query, files=files) + data = {} + if hot_words_str and hot_words_str.strip(): + data["context_info"] = hot_words_str.strip() + resp = await client.post(url, headers=headers, params=query, files=files, data=data) else: # Send as raw audio data (for Deepgram) resp = await client.post(url, headers=headers, params=query, content=audio_data) @@ -240,6 +281,18 @@ async def start_stream(self, client_id: str, sample_rate: int = 16000, diarize: if diarize and "diarize" in query_dict: query_dict["diarize"] = "true" + # Inject hot words for streaming (Deepgram only) + if self.model.model_provider == "deepgram": + try: + registry = get_prompt_registry() + hot_words_str = await registry.get_prompt("asr.hot_words") + if hot_words_str and hot_words_str.strip(): + keyterm = _parse_hot_words_to_keyterm(hot_words_str) + if keyterm: + query_dict["keyterm"] = keyterm + except Exception as e: + logger.debug(f"Failed to fetch asr.hot_words for streaming: {e}") + # Normalize boolean values to lowercase strings (Deepgram expects "true"/"false", not "True"/"False") normalized_query = {} for k, v in query_dict.items(): diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/base.py b/backends/advanced/src/advanced_omi_backend/services/transcription/base.py index 7d0f2306..bc5cf6f7 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/base.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/base.py @@ -122,12 +122,14 @@ def mode(self) -> str: return "batch" @abc.abstractmethod - async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False) -> dict: + async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False, context_info: Optional[str] = None, **kwargs) -> dict: """Transcribe audio data. Args: audio_data: Raw audio bytes sample_rate: Audio sample rate diarize: Whether to enable speaker diarization (provider-dependent) + context_info: Optional ASR context (hot words, jargon) to boost recognition + **kwargs: Additional parameters (e.g. Langfuse trace IDs) """ pass diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/context.py b/backends/advanced/src/advanced_omi_backend/services/transcription/context.py new file mode 100644 index 00000000..6b08daa8 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/context.py @@ -0,0 +1,90 @@ +""" +ASR context builder for transcription. + +Combines static hot words from the prompt registry with per-user dynamic +jargon cached in Redis by the ``asr_jargon_extraction`` cron job. +""" + +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import redis.asyncio as aioredis + +logger = logging.getLogger(__name__) + +REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") + + +@dataclass +class TranscriptionContext: + """Structured context gathered before transcription. + + Holds the individual context components so callers can inspect them + and Langfuse spans can log them as structured metadata. + """ + + hot_words: str = "" + user_jargon: str = "" + user_id: Optional[str] = None + + @property + def combined(self) -> str: + """Comma-separated string suitable for passing to ASR providers.""" + parts = [p.strip() for p in [self.hot_words, self.user_jargon] if p and p.strip()] + return ", ".join(parts) + + def to_metadata(self) -> dict: + """Return a dict suitable for Langfuse span metadata.""" + return { + "hot_words": self.hot_words[:200] if self.hot_words else "", + "user_jargon": self.user_jargon[:200] if self.user_jargon else "", + "user_id": self.user_id, + "combined_length": len(self.combined), + } + + +async def gather_transcription_context(user_id: Optional[str] = None) -> TranscriptionContext: + """Build structured transcription context: static hot words + cached user jargon. + + Args: + user_id: If provided, also look up per-user jargon from Redis. + + Returns: + TranscriptionContext with individual components. + """ + from advanced_omi_backend.prompt_registry import get_prompt_registry + + registry = get_prompt_registry() + hot_words = await registry.get_prompt("asr.hot_words") + + user_jargon = "" + if user_id: + try: + redis_client = aioredis.from_url(REDIS_URL, decode_responses=True) + try: + user_jargon = await redis_client.get(f"asr:jargon:{user_id}") or "" + finally: + await redis_client.close() + except Exception: + pass # Redis unavailable → skip dynamic jargon + + return TranscriptionContext( + hot_words=hot_words or "", + user_jargon=user_jargon, + user_id=user_id, + ) + + +async def get_asr_context(user_id: Optional[str] = None) -> str: + """Build combined ASR context string (backward-compatible alias). + + Args: + user_id: If provided, also look up per-user jargon from Redis. + + Returns: + Comma-separated string of context terms for the ASR provider. + """ + ctx = await gather_transcription_context(user_id) + return ctx.combined diff --git a/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py b/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py index 7c14cccd..c09580af 100644 --- a/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py +++ b/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py @@ -274,8 +274,10 @@ async def identify_segment( form_data.add_field( "file", audio_wav_bytes, filename="segment.wav", content_type="audio/wav" ) + # TODO: Implement proper user mapping between MongoDB ObjectIds and speaker service integer IDs + # Speaker service expects integer user_id, not MongoDB ObjectId strings if user_id is not None: - form_data.add_field("user_id", str(user_id)) + form_data.add_field("user_id", "1") if similarity_threshold is not None: form_data.add_field("similarity_threshold", str(similarity_threshold)) @@ -309,24 +311,32 @@ async def identify_provider_segments( conversation_id: str, segments: List[Dict], user_id: Optional[str] = None, + per_segment: bool = False, + min_segment_duration: float = 1.5, ) -> Dict: """ - Identify speakers in provider-diarized segments using majority-vote per label. + Identify speakers in provider-diarized segments. + + Default mode: majority-vote per label. Picks top 3 longest segments per label, + identifies each, and majority-votes to map labels to names. - For each unique speaker label, picks the top 3 longest segments (min 1.5s), - extracts audio, calls /identify, and majority-votes to map labels to names. + Per-segment mode (per_segment=True): identifies every segment individually. + Used during reprocessing so that fine-tuned embeddings benefit each segment. Args: conversation_id: Conversation ID for audio extraction from MongoDB segments: List of dicts with keys: start, end, text, speaker user_id: Optional user ID for speaker identification + per_segment: If True, identify each segment individually instead of majority-vote + min_segment_duration: Minimum segment duration in seconds for identification Returns: Dict with 'segments' list matching diarize_identify_match() format """ if hasattr(self, "_mock_client"): return await self._mock_client.identify_provider_segments( - conversation_id, segments, user_id + conversation_id, segments, user_id, + per_segment=per_segment, min_segment_duration=min_segment_duration, ) if not self.enabled: @@ -340,7 +350,6 @@ async def identify_provider_segments( config = get_diarization_settings() similarity_threshold = config.get("similarity_threshold", 0.15) - MIN_SEGMENT_DURATION = 1.5 MAX_SAMPLES_PER_LABEL = 3 # Detect non-speech segments (e.g. [Music], [Environmental Sounds], [Human Sounds]) @@ -379,14 +388,26 @@ def _is_non_speech(seg: Dict) -> bool: f"{len(label_groups)} unique labels: {list(label_groups.keys())}" ) - # For each label, pick top N longest segments >= MIN_SEGMENT_DURATION + # Per-segment mode: identify every segment individually (used during reprocess) + if per_segment: + return await self._identify_per_segment( + conversation_id=conversation_id, + segments=segments, + speech_segments=speech_segments, + non_speech_indices=non_speech_indices, + user_id=user_id, + similarity_threshold=similarity_threshold, + min_segment_duration=min_segment_duration, + ) + + # For each label, pick top N longest segments >= min_segment_duration label_samples: Dict[str, List[Dict]] = {} for label, segs in label_groups.items(): - eligible = [s for s in segs if (s["end"] - s["start"]) >= MIN_SEGMENT_DURATION] + eligible = [s for s in segs if (s["end"] - s["start"]) >= min_segment_duration] eligible.sort(key=lambda s: s["end"] - s["start"], reverse=True) label_samples[label] = eligible[:MAX_SAMPLES_PER_LABEL] if not label_samples[label]: - logger.info(f"🎤 Label '{label}': no segments >= {MIN_SEGMENT_DURATION}s, skipping identification") + logger.info(f"🎤 Label '{label}': no segments >= {min_segment_duration}s, skipping identification") # Extract audio and identify concurrently with semaphore semaphore = asyncio.Semaphore(3) @@ -398,7 +419,7 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: conversation_id, seg["start"], seg["end"] ) result = await self.identify_segment( - wav_bytes, user_id="1", similarity_threshold=similarity_threshold + wav_bytes, user_id=user_id, similarity_threshold=similarity_threshold ) return result except Exception as e: @@ -484,6 +505,150 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: return {"segments": result_segments} + async def _identify_per_segment( + self, + conversation_id: str, + segments: List[Dict], + speech_segments: List[Dict], + non_speech_indices: set, + user_id: Optional[str], + similarity_threshold: float, + min_segment_duration: float, + ) -> Dict: + """ + Identify every speech segment individually (no majority vote). + + Used during reprocessing so that fine-tuned speaker embeddings + benefit each segment directly. + + Args: + conversation_id: Conversation ID for audio extraction + segments: All segments (speech + non-speech) in original order + speech_segments: Only the speech segments + non_speech_indices: Indices of non-speech segments in the original list + user_id: User ID for speaker identification + similarity_threshold: Similarity threshold for identification + min_segment_duration: Minimum duration for identification attempt + + Returns: + Dict with 'segments' list matching diarize_identify_match() format + """ + from advanced_omi_backend.utils.audio_chunk_utils import ( + reconstruct_audio_segment, + ) + + logger.info( + f"🎤 Per-segment identification: {len(speech_segments)} speech segments " + f"(min_duration={min_segment_duration}s)" + ) + + semaphore = asyncio.Semaphore(3) + + async def _identify_one(seg: Dict) -> Optional[Dict]: + async with semaphore: + try: + wav_bytes = await reconstruct_audio_segment( + conversation_id, seg["start"], seg["end"] + ) + return await self.identify_segment( + wav_bytes, user_id=user_id, similarity_threshold=similarity_threshold + ) + except Exception as e: + logger.warning( + f"🎤 Failed to identify segment [{seg['start']:.1f}-{seg['end']:.1f}]: {e}" + ) + return None + + # Build tasks for speech segments that meet the duration threshold + seg_tasks: List[tuple] = [] # (original_index, task_or_None) + all_tasks = [] + for i, seg in enumerate(segments): + if i in non_speech_indices: + seg_tasks.append((i, None)) + continue + duration = seg["end"] - seg["start"] + if duration >= min_segment_duration: + task = asyncio.create_task(_identify_one(seg)) + seg_tasks.append((i, task)) + all_tasks.append(task) + else: + seg_tasks.append((i, None)) # too short + + if all_tasks: + await asyncio.gather(*all_tasks, return_exceptions=True) + + # Build result segments + result_segments = [] + identified_count = 0 + for i, seg in enumerate(segments): + label = seg.get("speaker", "Unknown") + + if i in non_speech_indices: + result_segments.append({ + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": label, + "confidence": 0.0, + "status": "non_speech", + }) + continue + + # Find the matching task entry + task_entry = seg_tasks[i] + task = task_entry[1] + + if task is None: + # Too short for identification + result_segments.append({ + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": label, + "confidence": 0.0, + "status": "too_short", + }) + continue + + try: + result = task.result() + except Exception: + result = None + + if result and result.get("found"): + name = result.get("speaker_name", label) + confidence = result.get("confidence", 0.0) + result_segments.append({ + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": name, + "confidence": confidence, + "status": "identified", + }) + identified_count += 1 + else: + result_segments.append({ + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": label, + "confidence": 0.0, + "status": "unknown", + }) + + logger.info( + f"🎤 Per-segment identification complete: " + f"{identified_count}/{len(speech_segments)} segments identified, " + f"{len(result_segments)} total segments" + ) + + return {"segments": result_segments} + async def diarize_and_identify( self, audio_data: bytes, words: None, user_id: Optional[str] = None # NOT IMPLEMENTED ) -> Dict: diff --git a/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py b/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py index 8ba68adf..0e9f4cae 100644 --- a/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py +++ b/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py @@ -181,6 +181,8 @@ async def identify_provider_segments( conversation_id: str, segments: List[Dict], user_id: Optional[str] = None, + per_segment: bool = False, + min_segment_duration: float = 1.5, ) -> Dict: """Mock identify_provider_segments - returns segments with original labels.""" logger.info(f"🎤 Mock identify_provider_segments: {len(segments)} segments") diff --git a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py index 89991327..4ceaee51 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py @@ -159,65 +159,20 @@ def analyze_speech(transcript_data: dict) -> dict: } -async def generate_title(text: str, segments: Optional[list] = None) -> str: +async def generate_title_and_summary( + text: str, segments: Optional[list] = None +) -> tuple[str, str]: """ - Generate an LLM-powered title from conversation text. + Generate title and short summary in a single LLM call using full conversation context. Args: text: Conversation transcript (used if segments not provided) segments: Optional list of speaker segments with structure: [{"speaker": str, "text": str, "start": float, "end": float}, ...] - If provided, uses speaker-aware conversation formatting + If provided, uses speaker-formatted text for richer context Returns: - str: Generated title (3-6 words) or fallback - - Note: - Title intentionally does NOT include speaker names - focuses on topic/theme only. - """ - # Format conversation text from segments if provided - if segments: - conversation_text = "" - for segment in segments[:10]: # Use first 10 segments for title generation - segment_text = segment.text.strip() if segment.text else "" - if segment_text: - conversation_text += f"{segment_text}\n" - text = conversation_text if conversation_text.strip() else text - - if not text or len(text.strip()) < 10: - return "Conversation" - - try: - registry = get_prompt_registry() - prompt_template = await registry.get_prompt("conversation.title") - prompt = f"""{prompt_template} - -"{text[:500]}" -""" - - title = await async_generate(prompt, temperature=0.3) - return title.strip().strip('"').strip("'") or "Conversation" - - except Exception as e: - logger.warning(f"Failed to generate LLM title: {e}") - # Fallback to simple title generation - words = text.split()[:6] - title = " ".join(words) - return title[:40] + "..." if len(title) > 40 else title or "Conversation" - - -async def generate_short_summary(text: str, segments: Optional[list] = None) -> str: - """ - Generate a brief LLM-powered summary from conversation text. - - Args: - text: Conversation transcript (used if segments not provided) - segments: Optional list of speaker segments with structure: - [{"speaker": str, "text": str, "start": float, "end": float}, ...] - If provided, includes speaker context in summary - - Returns: - str: Generated short summary (1-2 sentences, max 120 chars) or fallback + Tuple of (title, short_summary) """ # Format conversation text from segments if provided conversation_text = text @@ -241,37 +196,52 @@ async def generate_short_summary(text: str, segments: Optional[list] = None) -> include_speakers = len(speakers_in_conv) > 0 if not conversation_text or len(conversation_text.strip()) < 10: - return "No content" + return "Conversation", "No content" try: speaker_instruction = ( - '- Include speaker names when relevant (e.g., "John discusses X with Sarah")\n' + '- Include speaker names when relevant in the summary (e.g., "John discusses X with Sarah")\n' if include_speakers else "" ) registry = get_prompt_registry() prompt_text = await registry.get_prompt( - "conversation.short_summary", + "conversation.title_summary", speaker_instruction=speaker_instruction, ) prompt = f"""{prompt_text} -"{conversation_text[:1000]}" +TRANSCRIPT: +"{conversation_text}" """ - summary = await async_generate(prompt, temperature=0.3) - return summary.strip().strip('"').strip("'") or "No content" + response = await async_generate(prompt, temperature=0.3) + + # Parse response for Title: and Summary: lines + title = None + summary = None + for line in response.strip().split("\n"): + line = line.strip() + if line.startswith("Title:"): + title = line.replace("Title:", "").strip().strip('"').strip("'") + elif line.startswith("Summary:"): + summary = line.replace("Summary:", "").strip().strip('"').strip("'") + + title = title or "Conversation" + summary = summary or "No content" + + return title, summary except Exception as e: - logger.warning(f"Failed to generate LLM short summary: {e}") - # Fallback to simple summary generation - return ( - conversation_text[:120] + "..." - if len(conversation_text) > 120 - else conversation_text or "No content" - ) + logger.warning(f"Failed to generate title and summary: {e}") + # Fallback + words = text.split()[:6] + fallback_title = " ".join(words) + fallback_title = fallback_title[:40] + "..." if len(fallback_title) > 40 else fallback_title + fallback_summary = text[:120] + "..." if len(text) > 120 else text + return fallback_title or "Conversation", fallback_summary or "No content" diff --git a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py index 18420ddf..04754431 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py @@ -840,8 +840,7 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) from advanced_omi_backend.models.conversation import Conversation from advanced_omi_backend.utils.conversation_utils import ( generate_detailed_summary, - generate_short_summary, - generate_title, + generate_title_and_summary, ) logger.info(f"📝 Starting title/summary generation for conversation {conversation_id}") @@ -893,12 +892,11 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) except Exception as mem_error: logger.warning(f"⚠️ Could not fetch memory context (continuing without): {mem_error}") - # Generate all three summaries in parallel for efficiency + # Generate title+summary (one call) and detailed summary in parallel import asyncio - title, short_summary, detailed_summary = await asyncio.gather( - generate_title(transcript_text, segments=segments), - generate_short_summary(transcript_text, segments=segments), + (title, short_summary), detailed_summary = await asyncio.gather( + generate_title_and_summary(transcript_text, segments=segments), generate_detailed_summary( transcript_text, segments=segments, memory_context=memory_context ), diff --git a/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py new file mode 100644 index 00000000..f7f77daf --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py @@ -0,0 +1,235 @@ +""" +Cron job implementations for the Chronicle scheduler. + +Jobs: + - speaker_finetuning: sends applied diarization annotations to speaker service + - asr_jargon_extraction: extracts jargon from recent memories, caches in Redis +""" + +import logging +import os +import time +from datetime import datetime, timezone +from typing import Optional + +import redis.asyncio as aioredis + +from advanced_omi_backend.llm_client import async_generate +from advanced_omi_backend.prompt_registry import get_prompt_registry + +logger = logging.getLogger(__name__) + +REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") + +# TTL for cached jargon: 2 hours (job runs every 30 min, so always refreshed) +JARGON_CACHE_TTL = 7200 + +# Maximum number of recent memories to pull per user +MAX_RECENT_MEMORIES = 50 + +# How far back to look for memories (24 hours in seconds) +MEMORY_LOOKBACK_SECONDS = 86400 + + +# --------------------------------------------------------------------------- +# Job 1: Speaker Fine-tuning +# --------------------------------------------------------------------------- + +async def run_speaker_finetuning_job() -> dict: + """Process applied diarization annotations and send to speaker recognition service. + + This mirrors the logic in ``finetuning_routes.process_annotations_for_training`` + but is invocable from the cron scheduler without an HTTP request. + """ + from advanced_omi_backend.models.annotation import Annotation, AnnotationType + from advanced_omi_backend.models.conversation import Conversation + from advanced_omi_backend.speaker_recognition_client import SpeakerRecognitionClient + from advanced_omi_backend.utils.audio_chunk_utils import reconstruct_audio_segment + + # Find annotations ready for training + annotations = await Annotation.find( + Annotation.annotation_type == AnnotationType.DIARIZATION, + Annotation.processed == True, + ).to_list() + + ready_for_training = [ + a for a in annotations if not a.processed_by or "training" not in a.processed_by + ] + + if not ready_for_training: + logger.info("Speaker finetuning: no annotations ready for training") + return {"processed": 0, "message": "No annotations ready for training"} + + speaker_client = SpeakerRecognitionClient() + if not speaker_client.enabled: + logger.warning("Speaker finetuning: speaker recognition service is not enabled") + return {"processed": 0, "message": "Speaker recognition service not enabled"} + + enrolled = 0 + appended = 0 + failed = 0 + cleaned = 0 + + for annotation in ready_for_training: + try: + conversation = await Conversation.find_one( + Conversation.conversation_id == annotation.conversation_id + ) + if not conversation or not conversation.active_transcript: + logger.warning( + f"Conversation {annotation.conversation_id} not found — " + f"deleting orphaned annotation {annotation.id}" + ) + await annotation.delete() + cleaned += 1 + continue + + if annotation.segment_index >= len(conversation.active_transcript.segments): + logger.warning( + f"Invalid segment index {annotation.segment_index} for " + f"conversation {annotation.conversation_id} — " + f"deleting orphaned annotation {annotation.id}" + ) + await annotation.delete() + cleaned += 1 + continue + + segment = conversation.active_transcript.segments[annotation.segment_index] + + wav_bytes = await reconstruct_audio_segment( + conversation_id=annotation.conversation_id, + start_time=segment.start, + end_time=segment.end, + ) + if not wav_bytes: + failed += 1 + continue + + existing_speaker = await speaker_client.get_speaker_by_name( + speaker_name=annotation.corrected_speaker, + user_id=1, + ) + + if existing_speaker: + result = await speaker_client.append_to_speaker( + speaker_id=existing_speaker["id"], audio_data=wav_bytes + ) + if "error" in result: + failed += 1 + continue + appended += 1 + else: + result = await speaker_client.enroll_new_speaker( + speaker_name=annotation.corrected_speaker, + audio_data=wav_bytes, + user_id=1, + ) + if "error" in result: + failed += 1 + continue + enrolled += 1 + + # Mark as trained + annotation.processed_by = ( + f"{annotation.processed_by},training" if annotation.processed_by else "training" + ) + annotation.updated_at = datetime.now(timezone.utc) + await annotation.save() + + except Exception as e: + logger.error(f"Speaker finetuning: error processing annotation {annotation.id}: {e}") + failed += 1 + + total = enrolled + appended + logger.info( + f"Speaker finetuning complete: {total} processed " + f"({enrolled} new, {appended} appended, {failed} failed, {cleaned} orphaned cleaned)" + ) + return {"enrolled": enrolled, "appended": appended, "failed": failed, "cleaned": cleaned, "processed": total} + + +# --------------------------------------------------------------------------- +# Job 2: ASR Jargon Extraction +# --------------------------------------------------------------------------- + +async def run_asr_jargon_extraction_job() -> dict: + """Extract jargon from recent memories for all users and cache in Redis.""" + from advanced_omi_backend.models.user import User + + users = await User.find_all().to_list() + processed = 0 + skipped = 0 + errors = 0 + + redis_client = aioredis.from_url(REDIS_URL, decode_responses=True) + try: + for user in users: + user_id = str(user.id) + try: + jargon = await _extract_jargon_for_user(user_id) + if jargon: + await redis_client.set(f"asr:jargon:{user_id}", jargon, ex=JARGON_CACHE_TTL) + processed += 1 + logger.debug(f"Cached jargon for user {user_id}: {jargon[:80]}...") + else: + skipped += 1 + except Exception as e: + logger.error(f"Jargon extraction failed for user {user_id}: {e}") + errors += 1 + finally: + await redis_client.close() + + logger.info( + f"ASR jargon extraction complete: {processed} users processed, " + f"{skipped} skipped, {errors} errors" + ) + return {"users_processed": processed, "skipped": skipped, "errors": errors} + + +async def _extract_jargon_for_user(user_id: str) -> Optional[str]: + """Pull recent memories from Qdrant, call LLM to extract jargon terms. + + Returns a comma-separated string of jargon terms, or None if nothing found. + """ + from advanced_omi_backend.services.memory import get_memory_service + from advanced_omi_backend.services.memory.providers.chronicle import MemoryService + + memory_service = get_memory_service() + + # Only works with Chronicle provider (has Qdrant vector store) + if not isinstance(memory_service, MemoryService): + logger.debug("Jargon extraction requires Chronicle memory provider, skipping") + return None + + if memory_service.vector_store is None: + return None + + since_ts = int(time.time()) - MEMORY_LOOKBACK_SECONDS + + memories = await memory_service.vector_store.get_recent_memories( + user_id=user_id, + since_timestamp=since_ts, + limit=MAX_RECENT_MEMORIES, + ) + + if not memories: + return None + + # Concatenate memory content + memory_text = "\n".join(m.content for m in memories if m.content) + if not memory_text.strip(): + return None + + # Use LLM to extract jargon + registry = get_prompt_registry() + prompt_template = await registry.get_prompt("asr.jargon_extraction", memories=memory_text) + + result = await async_generate(prompt_template) + + # Clean up: strip whitespace, remove empty items + if result: + terms = [t.strip() for t in result.split(",") if t.strip()] + if terms: + return ", ".join(terms) + + return None diff --git a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py index 9c227bd9..f4adb6e3 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py @@ -2,12 +2,20 @@ Memory-related RQ job functions. This module contains jobs related to memory extraction and processing. + +Supports two processing pathways: +1. **Normal extraction**: Extracts fresh facts from transcript, deduplicates + against existing user memories, and proposes ADD/UPDATE/DELETE actions. +2. **Speaker reprocess**: When triggered after speaker re-identification, + computes a diff between old and new speaker labels, fetches existing + conversation-specific memories, and asks the LLM to make targeted + corrections to speaker attribution in those memories. """ import logging import time import uuid -from typing import Any, Dict +from typing import Any, Dict, List, Optional from advanced_omi_backend.controllers.queue_controller import ( JOB_RESULT_TTL, @@ -21,6 +29,85 @@ MIN_CONVERSATION_LENGTH = 10 +def compute_speaker_diff( + old_segments: list, + new_segments: list, +) -> List[Dict[str, Any]]: + """Compare old and new transcript segments to identify speaker changes. + + Matches segments by time overlap and detects where speaker labels differ. + + Args: + old_segments: Segments from the previous transcript version + new_segments: Segments from the new (active) transcript version + + Returns: + List of change dicts, each with keys: + - ``type``: "speaker_change", "text_change", or "new_segment" + - ``text``: The segment text + - ``old_speaker`` / ``new_speaker``: For speaker changes + - ``old_text`` / ``new_text``: For text changes + - ``start`` / ``end``: Time boundaries + """ + changes: List[Dict[str, Any]] = [] + + for new_seg in new_segments: + new_start = new_seg.start + new_end = new_seg.end + + # Find best matching old segment by time overlap + best_match = None + best_overlap = 0.0 + + for old_seg in old_segments: + overlap_start = max(old_seg.start, new_start) + overlap_end = min(old_seg.end, new_end) + overlap = max(0.0, overlap_end - overlap_start) + + if overlap > best_overlap: + best_overlap = overlap + best_match = old_seg + + if best_match: + # Check for speaker change + if best_match.speaker != new_seg.speaker: + changes.append( + { + "type": "speaker_change", + "text": new_seg.text.strip(), + "old_speaker": best_match.speaker, + "new_speaker": new_seg.speaker, + "start": new_start, + "end": new_end, + } + ) + # Check for text change (less common in speaker reprocessing) + if best_match.text.strip() != new_seg.text.strip(): + changes.append( + { + "type": "text_change", + "old_text": best_match.text.strip(), + "new_text": new_seg.text.strip(), + "speaker": new_seg.speaker, + "start": new_start, + "end": new_end, + } + ) + else: + # No matching old segment found + changes.append( + { + "type": "new_segment", + "text": new_seg.text.strip(), + "speaker": new_seg.speaker, + "start": new_start, + "end": new_end, + } + ) + + return changes + + @async_job(redis=True, beanie=True) async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict[str, Any]: """ @@ -113,17 +200,42 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict ) return {"success": True, "skipped": True, "reason": "No primary speakers"} - # Process memory - memory_service = get_memory_service() - memory_result = await memory_service.add_memory( - full_conversation, - client_id, - conversation_id, - user_id, - user_email, - allow_update=True, + # Detect reprocess trigger from RQ job metadata + from rq import get_current_job as _get_current_job + + current_rq_job = _get_current_job() + trigger = ( + current_rq_job.meta.get("trigger") + if current_rq_job and current_rq_job.meta + else None ) + # Process memory — choose pathway based on trigger + memory_service = get_memory_service() + + if trigger == "reprocess_after_speaker": + # === Speaker reprocess pathway === + # Compute diff between old and new transcript versions + memory_result = await _process_speaker_reprocess( + memory_service=memory_service, + conversation_model=conversation_model, + full_conversation=full_conversation, + client_id=client_id, + conversation_id=conversation_id, + user_id=user_id, + user_email=user_email, + ) + else: + # === Normal extraction pathway === + memory_result = await memory_service.add_memory( + full_conversation, + client_id, + conversation_id, + user_id, + user_email, + allow_update=True, + ) + if memory_result: success, created_memory_ids = memory_result @@ -301,6 +413,119 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict return {"success": False, "error": "Memory service returned False"} +async def _process_speaker_reprocess( + memory_service, + conversation_model, + full_conversation: str, + client_id: str, + conversation_id: str, + user_id: str, + user_email: str, +): + """Handle memory reprocessing after speaker re-identification. + + Computes the diff between the previous and current transcript versions + (specifically speaker label changes), then delegates to the memory + service's ``reprocess_memory`` method for targeted updates. + + Falls back to normal ``add_memory`` if diff computation fails or + no meaningful changes are detected. + + Args: + memory_service: Active memory service instance + conversation_model: Conversation Beanie document + full_conversation: New transcript as dialogue lines + client_id: Client identifier + conversation_id: Conversation identifier + user_id: User identifier + user_email: User email + + Returns: + Tuple of (success, memory_ids) matching ``add_memory`` return type + """ + active_version = conversation_model.active_transcript + + if not active_version: + logger.warning( + f"🔄 Reprocess: no active transcript version for {conversation_id}, " + f"falling back to normal extraction" + ) + return await memory_service.add_memory( + full_conversation, client_id, conversation_id, user_id, user_email, + allow_update=True, + ) + + # Find the source (previous) transcript version from metadata + source_version_id = active_version.metadata.get("source_version_id") + + if not source_version_id: + logger.warning( + f"🔄 Reprocess: no source_version_id in active transcript metadata " + f"for {conversation_id}, falling back to normal extraction" + ) + return await memory_service.add_memory( + full_conversation, client_id, conversation_id, user_id, user_email, + allow_update=True, + ) + + # Find the source version's segments + source_version = None + for v in conversation_model.transcript_versions: + if v.version_id == source_version_id: + source_version = v + break + + if not source_version or not source_version.segments: + logger.warning( + f"🔄 Reprocess: source version {source_version_id} not found or has no segments " + f"for {conversation_id}, falling back to normal extraction" + ) + return await memory_service.add_memory( + full_conversation, client_id, conversation_id, user_id, user_email, + allow_update=True, + ) + + # Compute the speaker diff + transcript_diff = compute_speaker_diff( + source_version.segments, + active_version.segments, + ) + + if not transcript_diff: + logger.info( + f"🔄 Reprocess: no speaker changes detected between versions " + f"for {conversation_id}, falling back to normal extraction" + ) + return await memory_service.add_memory( + full_conversation, client_id, conversation_id, user_id, user_email, + allow_update=True, + ) + + # Build the previous transcript for context + previous_lines = [] + for seg in source_version.segments: + text = seg.text.strip() + if text: + previous_lines.append(f"{seg.speaker}: {text}") + previous_transcript = "\n".join(previous_lines) + + logger.info( + f"🔄 Reprocess: detected {len(transcript_diff)} changes " + f"(speakers reprocessed) for {conversation_id}" + ) + + # Use the reprocess pathway + return await memory_service.reprocess_memory( + transcript=full_conversation, + client_id=client_id, + source_id=conversation_id, + user_id=user_id, + user_email=user_email, + transcript_diff=transcript_diff, + previous_transcript=previous_transcript, + ) + + def enqueue_memory_processing( conversation_id: str, priority: JobPriority = JobPriority.NORMAL, diff --git a/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py index 8c90701e..5ab6afa6 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/speaker_jobs.py @@ -279,25 +279,12 @@ async def recognise_speakers_job( can_run_pyannote = bool(actual_words) and not provider_has_diarization if not actual_words and not provider_has_diarization: - # No words AND provider didn't diarize - we have a problem - # This can happen with VibeVoice if it fails to return segments - logger.warning( - f"🎤 No word timestamps available and provider didn't diarize. " - f"Speaker recognition cannot improve segments." - ) - # Keep existing segments and return success (we can't do better) - if transcript_version.segments: - return { - "success": True, - "conversation_id": conversation_id, - "version_id": version_id, - "speaker_recognition_enabled": True, - "identified_speakers": [], - "segment_count": len(transcript_version.segments), - "skip_reason": "No word timestamps available for pyannote, keeping provider segments", - "processing_time_seconds": time.time() - start_time - } - else: + if not transcript_version.segments: + # No words, no provider diarization, no existing segments - nothing we can do + logger.warning( + f"🎤 No word timestamps available, provider didn't diarize, " + f"and no existing segments to identify." + ) return { "success": False, "conversation_id": conversation_id, @@ -305,11 +292,37 @@ async def recognise_speakers_job( "error": "No word timestamps and no segments available", "processing_time_seconds": time.time() - start_time } + # Has existing segments - fall through to run identification on them + logger.info( + f"🎤 No word timestamps for pyannote re-diarization, but " + f"{len(transcript_version.segments)} existing segments found. " + f"Running speaker identification on existing segments." + ) + + # Determine speaker identification mode: + # 1. Config toggle (per_segment_speaker_id) enables per-segment globally + # 2. Manual reprocess trigger also enables per-segment for that run + from advanced_omi_backend.config import get_misc_settings + misc_config = get_misc_settings() + per_segment_config = misc_config.get("per_segment_speaker_id", False) + + trigger = transcript_version.metadata.get("trigger", "") + is_reprocess = trigger == "manual_reprocess" + + use_per_segment = per_segment_config or is_reprocess + if use_per_segment: + reason = [] + if per_segment_config: + reason.append("config toggle enabled") + if is_reprocess: + reason.append("manual reprocess") + logger.info(f"🎤 Per-segment identification mode active ({', '.join(reason)})") try: - if provider_has_diarization and transcript_version.segments: - # Provider already diarized (e.g. VibeVoice) - use segment-level identification - logger.info(f"🎤 Using segment-level speaker identification for provider-diarized segments") + if transcript_version.segments and not can_run_pyannote: + # Have existing segments and can't/shouldn't run pyannote - do identification only + # Covers: provider already diarized, no word timestamps but segments exist, etc. + logger.info(f"🎤 Using segment-level speaker identification on {len(transcript_version.segments)} existing segments") segments_data = [ {"start": s.start, "end": s.end, "text": s.text, "speaker": s.speaker} for s in transcript_version.segments @@ -318,6 +331,8 @@ async def recognise_speakers_job( conversation_id=conversation_id, segments=segments_data, user_id=user_id, + per_segment=use_per_segment, + min_segment_duration=0.5 if use_per_segment else 1.5, ) else: # Standard path: full diarization + identification via speaker service @@ -502,6 +517,7 @@ async def recognise_speakers_job( transcript_version.metadata["speaker_recognition"] = { "enabled": True, + "identification_mode": "per_segment" if use_per_segment else "majority_vote", "identified_speakers": list(identified_speakers), "speaker_count": len(identified_speakers), "total_segments": len(speaker_segments), diff --git a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py index 19483f55..d1cc23e1 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py @@ -215,13 +215,25 @@ async def transcribe_full_audio_job( logger.error(f"Failed to reconstruct audio from MongoDB: {e}", exc_info=True) raise RuntimeError(f"Audio reconstruction failed: {e}") + # Build ASR context (static hot words + per-user cached jargon) + try: + from advanced_omi_backend.services.transcription.context import get_asr_context + + context_info = await get_asr_context(user_id=user_id) + except Exception as e: + logger.warning(f"Failed to build ASR context: {e}") + context_info = None + try: # Transcribe the audio directly from memory (no disk I/O needed) - transcription_result = await provider.transcribe( - audio_data=wav_data, # Pass bytes directly, already in memory - sample_rate=16000, - diarize=True, - ) + transcribe_kwargs: Dict[str, Any] = { + "audio_data": wav_data, + "sample_rate": 16000, + "diarize": True, + } + if context_info: + transcribe_kwargs["context_info"] = context_info + transcription_result = await provider.transcribe(**transcribe_kwargs) except ConnectionError as e: logger.exception(f"Transcription service unreachable for {conversation_id}") raise RuntimeError(str(e)) @@ -653,7 +665,7 @@ async def create_audio_only_conversation( @async_job(redis=True, beanie=True) async def transcription_fallback_check_job( - session_id: str, user_id: str, client_id: str, timeout_seconds: int = 1800, *, redis_client=None + session_id: str, user_id: str, client_id: str, timeout_seconds: int = 900, *, redis_client=None ) -> Dict[str, Any]: """ Check if streaming transcription succeeded, fallback to batch if needed. @@ -666,7 +678,7 @@ async def transcription_fallback_check_job( session_id: Stream session ID user_id: User ID client_id: Client ID - timeout_seconds: Max wait time for batch transcription (default 30 minutes) + timeout_seconds: Max wait time for batch transcription (default 15 minutes) redis_client: Redis client (injected by decorator) Returns: @@ -822,7 +834,7 @@ async def transcription_fallback_check_job( conversation.conversation_id, version_id, "batch_fallback", - job_timeout=1800, + job_timeout=900, # 15 minutes job_id=f"transcribe_{conversation.conversation_id[:12]}", description=f"Batch transcription fallback for {session_id[:8]}", meta={"session_id": session_id, "client_id": client_id}, @@ -1248,8 +1260,8 @@ async def stream_speech_detection_job( session_id, user_id, client_id, - timeout_seconds=1800, # 30 minutes for batch transcription - job_timeout=2400, # 40 minutes job timeout + timeout_seconds=900, # 15 minutes for batch transcription + job_timeout=1200, # 20 minutes job timeout (includes overhead for fallback check) job_id=f"fallback_check_{session_id[:12]}", description=f"Transcription fallback check for {session_id[:8]} (no speech)", meta={"session_id": session_id, "client_id": client_id, "no_speech": True}, diff --git a/backends/advanced/webui/package-lock.json b/backends/advanced/webui/package-lock.json index ead72812..7fe0c6d6 100644 --- a/backends/advanced/webui/package-lock.json +++ b/backends/advanced/webui/package-lock.json @@ -10,6 +10,7 @@ "dependencies": { "axios": "^1.6.2", "clsx": "^2.0.0", + "cronstrue": "^2.50.0", "d3": "^7.8.5", "frappe-gantt": "^1.0.4", "lucide-react": "^0.294.0", @@ -2719,6 +2720,15 @@ "dev": true, "license": "MIT" }, + "node_modules/cronstrue": { + "version": "2.59.0", + "resolved": "https://registry.npmjs.org/cronstrue/-/cronstrue-2.59.0.tgz", + "integrity": "sha512-YKGmAy84hKH+hHIIER07VCAHf9u0Ldelx1uU6EBxsRPDXIA1m5fsKmJfyC3xBhw6cVC/1i83VdbL4PvepTrt8A==", + "license": "MIT", + "bin": { + "cronstrue": "bin/cli.js" + } + }, "node_modules/cross-spawn": { "version": "7.0.6", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", diff --git a/backends/advanced/webui/package.json b/backends/advanced/webui/package.json index b933d8db..c6c1f771 100644 --- a/backends/advanced/webui/package.json +++ b/backends/advanced/webui/package.json @@ -11,6 +11,7 @@ }, "dependencies": { "axios": "^1.6.2", + "cronstrue": "^2.50.0", "clsx": "^2.0.0", "d3": "^7.8.5", "frappe-gantt": "^1.0.4", diff --git a/backends/advanced/webui/src/components/knowledge-graph/EntityCard.tsx b/backends/advanced/webui/src/components/knowledge-graph/EntityCard.tsx index 76d28cf9..bb48fef1 100644 --- a/backends/advanced/webui/src/components/knowledge-graph/EntityCard.tsx +++ b/backends/advanced/webui/src/components/knowledge-graph/EntityCard.tsx @@ -1,4 +1,6 @@ -import { User, MapPin, Building, Calendar, Package, Link2 } from 'lucide-react' +import { useState } from 'react' +import { User, MapPin, Building, Calendar, Package, Link2, Pencil, Check, X } from 'lucide-react' +import { knowledgeGraphApi } from '../../services/api' export interface Entity { id: string @@ -20,6 +22,7 @@ export interface Entity { interface EntityCardProps { entity: Entity onClick?: (entity: Entity) => void + onEntityUpdated?: (entity: Entity) => void compact?: boolean } @@ -39,7 +42,12 @@ const typeColors: Record = { thing: 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-300', } -export default function EntityCard({ entity, onClick, compact = false }: EntityCardProps) { +export default function EntityCard({ entity, onClick, onEntityUpdated, compact = false }: EntityCardProps) { + const [isEditing, setIsEditing] = useState(false) + const [editName, setEditName] = useState(entity.name) + const [editDetails, setEditDetails] = useState(entity.details || '') + const [saving, setSaving] = useState(false) + const icon = typeIcons[entity.type] || const colorClass = typeColors[entity.type] || typeColors.thing @@ -52,6 +60,41 @@ export default function EntityCard({ entity, onClick, compact = false }: EntityC } } + const handleEditClick = (e: React.MouseEvent) => { + e.stopPropagation() + setEditName(entity.name) + setEditDetails(entity.details || '') + setIsEditing(true) + } + + const handleCancel = (e: React.MouseEvent) => { + e.stopPropagation() + setIsEditing(false) + } + + const handleSave = async (e: React.MouseEvent) => { + e.stopPropagation() + const updates: { name?: string; details?: string } = {} + if (editName.trim() !== entity.name) updates.name = editName.trim() + if (editDetails.trim() !== (entity.details || '')) updates.details = editDetails.trim() + + if (Object.keys(updates).length === 0) { + setIsEditing(false) + return + } + + try { + setSaving(true) + const response = await knowledgeGraphApi.updateEntity(entity.id, updates) + setIsEditing(false) + onEntityUpdated?.(response.data.entity) + } catch (err) { + console.error('Failed to update entity:', err) + } finally { + setSaving(false) + } + } + if (compact) { return (
onClick?.(entity)} - className="bg-white dark:bg-gray-800 rounded-lg border border-gray-200 dark:border-gray-700 p-4 hover:border-blue-400 dark:hover:border-blue-500 transition-colors cursor-pointer group" + onClick={() => !isEditing && onClick?.(entity)} + className={`bg-white dark:bg-gray-800 rounded-lg border border-gray-200 dark:border-gray-700 p-4 hover:border-blue-400 dark:hover:border-blue-500 transition-colors ${isEditing ? '' : 'cursor-pointer'} group`} >
-
-
+
+
{entity.icon ? ( {entity.icon} ) : ( icon )}
-
-

- {entity.name} -

+
+ {isEditing ? ( + setEditName(e.target.value)} + onClick={(e) => e.stopPropagation()} + className="w-full px-2 py-1 text-sm font-semibold border border-gray-300 dark:border-gray-600 rounded bg-white dark:bg-gray-700 text-gray-900 dark:text-gray-100 focus:outline-none focus:ring-2 focus:ring-blue-500" + autoFocus + /> + ) : ( +

+ {entity.name} +

+ )} {entity.type}
- {entity.relationship_count != null && entity.relationship_count > 0 && ( -
- - {entity.relationship_count} -
- )} +
+ {isEditing ? ( + <> + + + + ) : ( + <> + + {entity.relationship_count != null && entity.relationship_count > 0 && ( +
+ + {entity.relationship_count} +
+ )} + + )} +
- {entity.details && ( -

- {entity.details} -

+ {isEditing ? ( +