From 9da8845e3f157fd5b50111d3daaf2f35a53de3f1 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 05:13:53 +0100 Subject: [PATCH 001/163] Add focus analysis handler for desktop screen_frame messages (#5396) Co-Authored-By: Claude Opus 4.6 --- backend/utils/desktop/__init__.py | 0 backend/utils/desktop/focus.py | 149 ++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 backend/utils/desktop/__init__.py create mode 100644 backend/utils/desktop/focus.py diff --git a/backend/utils/desktop/__init__.py b/backend/utils/desktop/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/utils/desktop/focus.py b/backend/utils/desktop/focus.py new file mode 100644 index 0000000000..6807eeded9 --- /dev/null +++ b/backend/utils/desktop/focus.py @@ -0,0 +1,149 @@ +import logging +from typing import Optional + +from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from database.goals import get_user_goals +from database.action_items import get_action_items +from database.memories import get_memories +from utils.llm.clients import llm_gemini_flash + +logger = logging.getLogger(__name__) + +# Match the desktop FocusAssistant's ScreenAnalysis schema +FOCUS_SYSTEM_PROMPT = """You are a focus coach. Analyze the PRIMARY/MAIN window in screenshots to determine \ +if the user is focused or distracted. + +IMPORTANT: Look at the MAIN APPLICATION WINDOW, not log text or terminal output. \ +If you see a code editor with logs that mention "YouTube" - that's just log text, \ +the user is CODING, not on YouTube. Text in logs/terminals mentioning a site does \ +NOT mean the user is on that site. + +CONTEXT-AWARE ANALYSIS: +Each request may include the user's active goals, current tasks, recent memories, \ +and analysis history. Use this context when available, but DO NOT let it prevent you \ +from flagging obvious distractions. 
+ +- GOALS & TASKS: If the user's screen activity clearly relates to their active \ +goals or current tasks, they are FOCUSED. +- HISTORY: Use recent analysis history to notice patterns, acknowledge transitions, \ +and vary your responses. + +Set status to "distracted" if the PRIMARY window is: +- YouTube, Twitch, Netflix, TikTok (actual video site visible, not just text mentioning it) +- Social media feeds: Twitter/X, Instagram, Facebook, Reddit (casual browsing, not researching) +- News sites, entertainment sites, games +- Any content consumption with no clear work purpose + +Set status to "focused" if the PRIMARY window is: +- Code editors, IDEs, terminals, command line +- Documents, spreadsheets, slides, design tools +- Email, work chat (Slack, Teams), research +- Browsing that is clearly work-related (Stack Overflow, docs, PRs, Jira, etc.) + +When in doubt, lean toward "distracted" — it's better to nudge the user once too \ +often than to silently let them drift. + +Always provide a short coaching message (100 characters max for notification banner): +- If distracted: Create a unique nudge to refocus. Vary your approach — be playful, \ +direct, or motivational. +- If focused: Acknowledge their work with variety — don't just say "Nice focus!" 
\ +every time.""" + + +class FocusResult(BaseModel): + status: str = Field(description='Focus status: "focused" or "distracted"') + app_or_site: str = Field(description="Primary app or site in focus") + description: str = Field(description="Brief description of what the user is doing") + message: Optional[str] = Field(default=None, description="Short coaching message (max 100 chars)") + + +def _build_context(uid: str) -> str: + """Build context from user's goals, tasks, and memories (server-side).""" + parts = [] + + # Goals (up to 10) + try: + goals = get_user_goals(uid, limit=10) + if goals: + goal_lines = [f"- {g.get('title', g.get('description', ''))}" for g in goals] + parts.append("Active Goals:\n" + "\n".join(goal_lines)) + except Exception as e: + logger.warning(f"Failed to fetch goals for context: {e}") + + # Tasks (up to 50, not completed) + try: + tasks = get_action_items(uid, completed=False, limit=50) + if tasks: + task_lines = [f"- {t.get('description', '')}" for t in tasks[:50]] + parts.append("Current Tasks:\n" + "\n".join(task_lines)) + except Exception as e: + logger.warning(f"Failed to fetch tasks for context: {e}") + + # Recent memories (up to 20, core category) + try: + memories = get_memories(uid, limit=20, categories=['core']) + if memories: + mem_lines = [f"- {m.get('structured', {}).get('title', m.get('content', ''))}" for m in memories[:20]] + parts.append("Recent Memories:\n" + "\n".join(mem_lines)) + except Exception as e: + logger.warning(f"Failed to fetch memories for context: {e}") + + return "\n\n".join(parts) if parts else "" + + +async def analyze_focus( + uid: str, + image_b64: str, + app_name: str = "", + window_title: str = "", + history: str = "", +) -> dict: + """Analyze a screenshot for focus status using vision LLM. 
+ + Args: + uid: User ID for fetching context + image_b64: Base64-encoded JPEG screenshot + app_name: Name of the foreground app + window_title: Window title + history: Formatted recent analysis history + + Returns: + Dict with type, frame_id, status, app_or_site, description, message + """ + # Build context from user data + context = _build_context(uid) + + # Assemble prompt + prompt_parts = [] + if context: + prompt_parts.append(context) + if history: + prompt_parts.append(f"Recent activity (oldest to newest):\n{history}") + if app_name or window_title: + prompt_parts.append(f"Current app: {app_name}, Window: {window_title}") + prompt_parts.append("Now analyze this screenshot:") + + prompt_text = "\n\n".join(prompt_parts) + + # Call vision LLM with structured output + with_parser = llm_gemini_flash.with_structured_output(FocusResult) + result = await with_parser.ainvoke( + [ + SystemMessage(content=FOCUS_SYSTEM_PROMPT), + HumanMessage( + content=[ + {"type": "text", "text": prompt_text}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}}, + ] + ), + ] + ) + + return { + "status": result.status, + "app_or_site": result.app_or_site, + "description": result.description, + "message": result.message, + } From 4206abe99b4fd8d11ab5c18aca766bc45e7a2115 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 05:13:57 +0100 Subject: [PATCH 002/163] Add FocusResultEvent message type for desktop proactive AI (#5396) Co-Authored-By: Claude Opus 4.6 --- backend/models/message_event.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/backend/models/message_event.py b/backend/models/message_event.py index bddbeb2a27..36556f6af9 100644 --- a/backend/models/message_event.py +++ b/backend/models/message_event.py @@ -181,3 +181,21 @@ def to_json(self): j["type"] = self.event_type del j["event_type"] return j + + +# Desktop proactive AI events (Phase 2 — #5396) + + +class FocusResultEvent(MessageEvent): + event_type: str = 
"focus_result" + frame_id: str + status: str + app_or_site: str + description: str + message: Optional[str] = None + + def to_json(self): + j = self.model_dump(mode="json") + j["type"] = self.event_type + del j["event_type"] + return j From dc9d76572dfeebc3edb21237c65d48e7019c9d40 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 05:14:01 +0100 Subject: [PATCH 003/163] Add screen_frame dispatcher to /v4/listen for desktop focus analysis (#5396) Co-Authored-By: Claude Opus 4.6 --- backend/routers/transcribe.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 7a8910204f..658f275a7d 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -51,6 +51,7 @@ ) from models.message_event import ( ConversationEvent, + FocusResultEvent, FREEMIUM_ACTION_SETUP_ON_DEVICE_STT, FreemiumThresholdReachedEvent, LastConversationEvent, @@ -100,6 +101,7 @@ SPEAKER_MATCH_THRESHOLD, ) from utils.speaker_sample_migration import maybe_migrate_person_samples +from utils.desktop.focus import analyze_focus from utils.log_sanitizer import sanitize, sanitize_pii logger = logging.getLogger(__name__) @@ -2127,6 +2129,38 @@ async def close_soniox_profile(): logger.info( f"Speaker assignment ignored: missing speaker_id/person_id/person_name. 
{uid} {session_id}" ) + # Desktop proactive AI — screen_frame analysis (#5396) + elif json_data.get('type') == 'screen_frame': + frame_id = json_data.get('frame_id', '') + image_b64 = json_data.get('image_b64', '') + analyze_types = json_data.get('analyze', []) + if image_b64 and 'focus' in analyze_types: + async def _handle_focus(fid, img, app, wtitle): + try: + result = await analyze_focus( + uid=uid, + image_b64=img, + app_name=app, + window_title=wtitle, + ) + _send_message_event(FocusResultEvent( + frame_id=fid, + status=result['status'], + app_or_site=result['app_or_site'], + description=result['description'], + message=result.get('message'), + )) + except Exception as focus_err: + logger.error(f"Focus analysis failed: {focus_err} {uid} {session_id}") + + spawn(_handle_focus( + frame_id, + image_b64, + json_data.get('app_name', ''), + json_data.get('window_title', ''), + )) + elif not image_b64: + logger.warning(f"screen_frame missing image_b64 {uid} {session_id}") except json.JSONDecodeError: logger.info( f"Received non-json text message: {sanitize(message.get('text'))} {uid} {session_id}" From f102156fa44ee831af2ca2457f8fc3794555450f Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 05:14:04 +0100 Subject: [PATCH 004/163] Add 26 unit tests for desktop focus analysis (#5396) Co-Authored-By: Claude Opus 4.6 --- backend/test.sh | 1 + backend/tests/unit/test_desktop_focus.py | 382 +++++++++++++++++++++++ 2 files changed, 383 insertions(+) create mode 100644 backend/tests/unit/test_desktop_focus.py diff --git a/backend/test.sh b/backend/test.sh index 38f8192263..954cd0dbf2 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -36,3 +36,4 @@ pytest tests/unit/test_pusher_heartbeat.py -v pytest tests/unit/test_desktop_updates.py -v pytest tests/unit/test_translation_optimization.py -v pytest tests/unit/test_conversation_source_unknown.py -v +pytest tests/unit/test_desktop_focus.py -v diff --git a/backend/tests/unit/test_desktop_focus.py 
b/backend/tests/unit/test_desktop_focus.py new file mode 100644 index 0000000000..60b52c80af --- /dev/null +++ b/backend/tests/unit/test_desktop_focus.py @@ -0,0 +1,382 @@ +"""Tests for desktop focus analysis (Phase 2 — #5396).""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Mock heavy dependencies before any project imports +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +# Now safe to import +from utils.desktop.focus import FocusResult, FOCUS_SYSTEM_PROMPT, _build_context +from models.message_event import FocusResultEvent + +# --- FocusResult model tests --- + + +class TestFocusResultModel: + def test_focus_result_focused(self): + result = FocusResult( + status="focused", + app_or_site="VS Code", + description="Writing Python code", + message="Great focus!", + ) + assert result.status == "focused" + assert result.app_or_site == "VS Code" + assert result.description == "Writing Python code" + assert result.message == "Great focus!" 
+ + def test_focus_result_distracted(self): + result = FocusResult( + status="distracted", + app_or_site="YouTube", + description="Watching videos", + message="Time to refocus!", + ) + assert result.status == "distracted" + assert result.app_or_site == "YouTube" + + def test_focus_result_message_optional(self): + result = FocusResult( + status="focused", + app_or_site="Terminal", + description="Running tests", + ) + assert result.message is None + + def test_focus_result_message_none_explicit(self): + result = FocusResult( + status="focused", + app_or_site="Terminal", + description="Running tests", + message=None, + ) + assert result.message is None + + +# --- FocusResultEvent tests --- + + +class TestFocusResultEvent: + def test_focus_result_event_to_json(self): + event = FocusResultEvent( + frame_id="abc-123", + status="focused", + app_or_site="VS Code", + description="Writing code", + message="Keep it up!", + ) + j = event.to_json() + assert j["type"] == "focus_result" + assert j["frame_id"] == "abc-123" + assert j["status"] == "focused" + assert j["app_or_site"] == "VS Code" + assert j["description"] == "Writing code" + assert j["message"] == "Keep it up!" 
+ assert "event_type" not in j + + def test_focus_result_event_null_message(self): + event = FocusResultEvent( + frame_id="def-456", + status="distracted", + app_or_site="Twitter", + description="Browsing feed", + ) + j = event.to_json() + assert j["type"] == "focus_result" + assert j["message"] is None + + def test_focus_result_event_default_type(self): + event = FocusResultEvent( + frame_id="x", + status="focused", + app_or_site="Code", + description="Working", + ) + assert event.event_type == "focus_result" + + +# --- Context building tests --- + + +class TestBuildContext: + @patch('utils.desktop.focus.get_memories', return_value=[]) + @patch('utils.desktop.focus.get_action_items', return_value=[]) + @patch('utils.desktop.focus.get_user_goals', return_value=[]) + def test_empty_context(self, mock_goals, mock_tasks, mock_memories): + result = _build_context("test-uid") + assert result == "" + + @patch('utils.desktop.focus.get_memories', return_value=[]) + @patch('utils.desktop.focus.get_action_items', return_value=[]) + @patch( + 'utils.desktop.focus.get_user_goals', + return_value=[ + {"title": "Ship Phase 2"}, + {"title": "Learn Rust"}, + ], + ) + def test_goals_in_context(self, mock_goals, mock_tasks, mock_memories): + result = _build_context("test-uid") + assert "Active Goals:" in result + assert "Ship Phase 2" in result + assert "Learn Rust" in result + + @patch('utils.desktop.focus.get_memories', return_value=[]) + @patch( + 'utils.desktop.focus.get_action_items', + return_value=[ + {"description": "Fix login bug"}, + {"description": "Review PR #42"}, + ], + ) + @patch('utils.desktop.focus.get_user_goals', return_value=[]) + def test_tasks_in_context(self, mock_goals, mock_tasks, mock_memories): + result = _build_context("test-uid") + assert "Current Tasks:" in result + assert "Fix login bug" in result + assert "Review PR #42" in result + + @patch( + 'utils.desktop.focus.get_memories', + return_value=[ + {"structured": {"title": "Learned about 
WebSockets"}}, + ], + ) + @patch('utils.desktop.focus.get_action_items', return_value=[]) + @patch('utils.desktop.focus.get_user_goals', return_value=[]) + def test_memories_in_context(self, mock_goals, mock_tasks, mock_memories): + result = _build_context("test-uid") + assert "Recent Memories:" in result + assert "Learned about WebSockets" in result + + @patch('utils.desktop.focus.get_memories', side_effect=Exception("DB error")) + @patch('utils.desktop.focus.get_action_items', side_effect=Exception("DB error")) + @patch('utils.desktop.focus.get_user_goals', side_effect=Exception("DB error")) + def test_context_graceful_on_errors(self, mock_goals, mock_tasks, mock_memories): + result = _build_context("test-uid") + assert result == "" + + @patch('utils.desktop.focus.get_memories', return_value=[]) + @patch('utils.desktop.focus.get_action_items', return_value=[]) + @patch( + 'utils.desktop.focus.get_user_goals', + return_value=[ + {"description": "Goal without title"}, + ], + ) + def test_goals_fallback_to_description(self, mock_goals, mock_tasks, mock_memories): + result = _build_context("test-uid") + assert "Goal without title" in result + + @patch( + 'utils.desktop.focus.get_memories', + return_value=[ + {"content": "Memory without structured field"}, + ], + ) + @patch('utils.desktop.focus.get_action_items', return_value=[]) + @patch('utils.desktop.focus.get_user_goals', return_value=[]) + def test_memories_fallback_to_content(self, mock_goals, mock_tasks, mock_memories): + result = _build_context("test-uid") + assert "Memory without structured field" in result + + +# --- analyze_focus integration tests --- + + +class TestAnalyzeFocus: + @patch('utils.desktop.focus._build_context', return_value="") + @patch('utils.desktop.focus.llm_gemini_flash') + def test_analyze_focus_returns_result(self, mock_llm, mock_ctx): + from utils.desktop.focus import analyze_focus + + mock_parser = MagicMock() + mock_parser.ainvoke = AsyncMock( + return_value=FocusResult( + 
status="focused", + app_or_site="VS Code", + description="Editing Python", + message="Nice work!", + ) + ) + mock_llm.with_structured_output.return_value = mock_parser + + result = asyncio.get_event_loop().run_until_complete( + analyze_focus(uid="test", image_b64="base64data", app_name="VS Code", window_title="main.py") + ) + + assert result["status"] == "focused" + assert result["app_or_site"] == "VS Code" + assert result["description"] == "Editing Python" + assert result["message"] == "Nice work!" + + @patch('utils.desktop.focus._build_context', return_value="Active Goals:\n- Ship code") + @patch('utils.desktop.focus.llm_gemini_flash') + def test_analyze_focus_includes_context_in_prompt(self, mock_llm, mock_ctx): + from utils.desktop.focus import analyze_focus + + mock_parser = MagicMock() + mock_parser.ainvoke = AsyncMock( + return_value=FocusResult( + status="distracted", + app_or_site="Twitter", + description="Browsing", + ) + ) + mock_llm.with_structured_output.return_value = mock_parser + + asyncio.get_event_loop().run_until_complete(analyze_focus(uid="test", image_b64="data")) + + call_args = mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + prompt_text = human_msg.content[0]["text"] + assert "Active Goals:" in prompt_text + + @patch('utils.desktop.focus._build_context', return_value="") + @patch('utils.desktop.focus.llm_gemini_flash') + def test_analyze_focus_includes_history(self, mock_llm, mock_ctx): + from utils.desktop.focus import analyze_focus + + mock_parser = MagicMock() + mock_parser.ainvoke = AsyncMock( + return_value=FocusResult( + status="focused", + app_or_site="Terminal", + description="Running tests", + ) + ) + mock_llm.with_structured_output.return_value = mock_parser + + asyncio.get_event_loop().run_until_complete( + analyze_focus( + uid="test", + image_b64="data", + history="1. 
[focused] VS Code: Writing code", + ) + ) + + call_args = mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + prompt_text = human_msg.content[0]["text"] + assert "Recent activity" in prompt_text + + @patch('utils.desktop.focus._build_context', return_value="") + @patch('utils.desktop.focus.llm_gemini_flash') + def test_analyze_focus_includes_app_and_window(self, mock_llm, mock_ctx): + from utils.desktop.focus import analyze_focus + + mock_parser = MagicMock() + mock_parser.ainvoke = AsyncMock( + return_value=FocusResult( + status="focused", + app_or_site="Safari", + description="Reading docs", + ) + ) + mock_llm.with_structured_output.return_value = mock_parser + + asyncio.get_event_loop().run_until_complete( + analyze_focus(uid="test", image_b64="data", app_name="Safari", window_title="MDN Web Docs") + ) + + call_args = mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + prompt_text = human_msg.content[0]["text"] + assert "Safari" in prompt_text + assert "MDN Web Docs" in prompt_text + + @patch('utils.desktop.focus._build_context', return_value="") + @patch('utils.desktop.focus.llm_gemini_flash') + def test_analyze_focus_sends_image_as_base64(self, mock_llm, mock_ctx): + from utils.desktop.focus import analyze_focus + + mock_parser = MagicMock() + mock_parser.ainvoke = AsyncMock( + return_value=FocusResult( + status="focused", + app_or_site="Code", + description="Coding", + ) + ) + mock_llm.with_structured_output.return_value = mock_parser + + asyncio.get_event_loop().run_until_complete(analyze_focus(uid="test", image_b64="FAKE_BASE64_IMAGE")) + + call_args = mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + image_part = human_msg.content[1] + assert image_part["type"] == "image_url" + assert "FAKE_BASE64_IMAGE" in image_part["image_url"]["url"] + + @patch('utils.desktop.focus._build_context', return_value="") + @patch('utils.desktop.focus.llm_gemini_flash') + def test_analyze_focus_sends_system_prompt(self, mock_llm, 
mock_ctx): + from utils.desktop.focus import analyze_focus + + mock_parser = MagicMock() + mock_parser.ainvoke = AsyncMock( + return_value=FocusResult( + status="focused", + app_or_site="Code", + description="Coding", + ) + ) + mock_llm.with_structured_output.return_value = mock_parser + + asyncio.get_event_loop().run_until_complete(analyze_focus(uid="test", image_b64="data")) + + call_args = mock_parser.ainvoke.call_args[0][0] + system_msg = call_args[0] + assert FOCUS_SYSTEM_PROMPT in system_msg.content + + @patch('utils.desktop.focus._build_context', return_value="") + @patch('utils.desktop.focus.llm_gemini_flash') + def test_analyze_focus_distracted_result(self, mock_llm, mock_ctx): + from utils.desktop.focus import analyze_focus + + mock_parser = MagicMock() + mock_parser.ainvoke = AsyncMock( + return_value=FocusResult( + status="distracted", + app_or_site="Reddit", + description="Scrolling r/programming", + message="Back to work!", + ) + ) + mock_llm.with_structured_output.return_value = mock_parser + + result = asyncio.get_event_loop().run_until_complete(analyze_focus(uid="test", image_b64="data")) + + assert result["status"] == "distracted" + assert result["app_or_site"] == "Reddit" + assert result["message"] == "Back to work!" 
+ + +# --- System prompt content tests --- + + +class TestFocusSystemPrompt: + def test_prompt_includes_focused_criteria(self): + assert "Code editors" in FOCUS_SYSTEM_PROMPT + + def test_prompt_includes_distracted_criteria(self): + assert "YouTube" in FOCUS_SYSTEM_PROMPT + assert "Twitter" in FOCUS_SYSTEM_PROMPT + + def test_prompt_warns_about_log_text(self): + assert "log text" in FOCUS_SYSTEM_PROMPT + + def test_prompt_mentions_context_aware(self): + assert "CONTEXT-AWARE" in FOCUS_SYSTEM_PROMPT + + def test_prompt_coaching_message_guidance(self): + assert "100 characters max" in FOCUS_SYSTEM_PROMPT From 616ca112f58cdcefebe12c01044bf5209c986844 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:22:37 +0100 Subject: [PATCH 005/163] Add task extraction handler for desktop screen analysis --- backend/utils/desktop/tasks.py | 156 +++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 backend/utils/desktop/tasks.py diff --git a/backend/utils/desktop/tasks.py b/backend/utils/desktop/tasks.py new file mode 100644 index 0000000000..85b297e633 --- /dev/null +++ b/backend/utils/desktop/tasks.py @@ -0,0 +1,156 @@ +import logging +from typing import List, Optional + +from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from database.action_items import get_action_items +from utils.llm.clients import llm_gemini_flash + +logger = logging.getLogger(__name__) + +TASK_SYSTEM_PROMPT = """\ +You are a task extraction assistant. Analyze screenshots to identify actionable tasks, \ +requests, or to-dos visible on screen. 
+ +EXTRACTION RULES: +- Only extract tasks that are clearly visible and actionable +- Title must be 6+ words, verb-first, naming a specific person/project/artifact + concrete action +- Skip vague or generic items ("do something", "check this") +- ~90% of screenshots contain NO new task — use no_tasks when nothing actionable is found + +DEDUPLICATION: +- Compare against the user's existing tasks provided in context +- If a task is semantically similar to an existing one (even with different wording), skip it +- "Call John" and "Phone John" are duplicates +- "Finish report by Friday" and "Complete report by end of week" are duplicates +- When in doubt, err on treating as duplicate (DON'T extract) + +PRIORITY GUIDELINES: +- high: urgent deadlines, blocking requests, error fixes +- medium: normal work tasks, follow-ups +- low: nice-to-haves, ideas, non-urgent items + +SOURCE CATEGORIES: +- direct_request: someone asked the user to do something (message, meeting, mention) +- self_generated: user's own idea, reminder, or goal subtask +- calendar_driven: event preparation, recurring task, deadline +- reactive: error response, notification, observation +- external_system: from project tools, alerts, documentation""" + + +class ExtractedTask(BaseModel): + title: str = Field(description="Verb-first title, 6+ words, specific person/project + concrete action") + description: str = Field(default="", description="Additional context if needed") + priority: str = Field(description="high, medium, or low") + tags: List[str] = Field(default_factory=list, description="1-3 relevant tags") + source_app: str = Field(default="", description="App where task was found") + inferred_deadline: Optional[str] = Field(default=None, description="yyyy-MM-dd format or null") + confidence: float = Field(ge=0.0, le=1.0, description="Extraction confidence") + source_category: str = Field( + default="reactive", description="direct_request|self_generated|calendar_driven|reactive|external_system" + ) + + 
+class TaskExtractionResult(BaseModel): + has_new_tasks: bool = Field(description="Whether any new tasks were found") + tasks: List[ExtractedTask] = Field(default_factory=list, description="Extracted tasks (empty if none)") + context_summary: str = Field(default="", description="Brief summary of what user is viewing") + current_activity: str = Field(default="", description="What user is actively doing") + + +def _build_task_context(uid: str) -> str: + """Build existing tasks context for deduplication.""" + parts = [] + + try: + # Active tasks (not completed) for dedup + active_tasks = get_action_items(uid, completed=False, limit=50) + if active_tasks: + task_lines = [] + for t in active_tasks: + desc = t.get('description', '') + due = t.get('due_at', '') + due_str = f" (Due: {due})" if due else "" + task_lines.append(f"- {desc}{due_str} [Pending]") + parts.append("Existing active tasks (DO NOT extract duplicates):\n" + "\n".join(task_lines)) + except Exception as e: + logger.warning(f"Failed to fetch active tasks for dedup: {e}") + + try: + # Recently completed tasks (last 10) for dedup + completed_tasks = get_action_items(uid, completed=True, limit=10) + if completed_tasks: + task_lines = [f"- {t.get('description', '')} [Completed]" for t in completed_tasks[:10]] + parts.append("Recently completed tasks:\n" + "\n".join(task_lines)) + except Exception as e: + logger.warning(f"Failed to fetch completed tasks: {e}") + + return "\n\n".join(parts) if parts else "" + + +async def extract_tasks( + uid: str, + image_b64: str, + app_name: str = "", + window_title: str = "", +) -> dict: + """Extract tasks from a screenshot using vision LLM. 
+ + Args: + uid: User ID for fetching existing tasks (dedup) + image_b64: Base64-encoded JPEG screenshot + app_name: Name of the foreground app + window_title: Window title + + Returns: + Dict with has_new_tasks, tasks list, context_summary, current_activity + """ + # Pre-fetch existing tasks for dedup context + task_context = _build_task_context(uid) + + # Assemble prompt + prompt_parts = [] + if task_context: + prompt_parts.append(task_context) + if app_name or window_title: + prompt_parts.append(f"Current app: {app_name}, Window: {window_title}") + prompt_parts.append("Analyze this screenshot for actionable tasks:") + + prompt_text = "\n\n".join(prompt_parts) + + # Call vision LLM with structured output + with_parser = llm_gemini_flash.with_structured_output(TaskExtractionResult) + result = await with_parser.ainvoke( + [ + SystemMessage(content=TASK_SYSTEM_PROMPT), + HumanMessage( + content=[ + {"type": "text", "text": prompt_text}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}}, + ] + ), + ] + ) + + tasks_list = [] + for task in result.tasks: + tasks_list.append( + { + "title": task.title, + "description": task.description, + "priority": task.priority, + "tags": task.tags, + "source_app": task.source_app or app_name, + "inferred_deadline": task.inferred_deadline, + "confidence": task.confidence, + "source_category": task.source_category, + } + ) + + return { + "has_new_tasks": result.has_new_tasks and len(tasks_list) > 0, + "tasks": tasks_list, + "context_summary": result.context_summary, + "current_activity": result.current_activity, + } From a0da068222f27cb76960c3d66a87bc45329048a1 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:22:40 +0100 Subject: [PATCH 006/163] Add memory extraction handler for desktop screen analysis --- backend/utils/desktop/memories.py | 100 ++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 backend/utils/desktop/memories.py diff --git 
a/backend/utils/desktop/memories.py b/backend/utils/desktop/memories.py new file mode 100644 index 0000000000..c1bbbcdbda --- /dev/null +++ b/backend/utils/desktop/memories.py @@ -0,0 +1,100 @@ +import logging +from typing import List, Optional + +from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from database.memories import get_memories +from utils.llm.clients import llm_gemini_flash + +logger = logging.getLogger(__name__) + +MEMORY_SYSTEM_PROMPT = """\ +You are a memory extraction assistant. Analyze screenshots to identify facts, insights, \ +or noteworthy information worth remembering about the user or their context. + +EXTRACTION RULES: +- Extract facts ABOUT the user: preferences, projects, people they work with, decisions, realizations +- Extract useful external information: advice, tips, insights from what they're reading +- Maximum 3 memories per screenshot +- Each memory should be a concise, standalone fact +- Skip trivial or transient information (UI state, loading screens, timestamps) +- ~80% of screenshots contain NO memorable information — return empty list when nothing stands out + +DEDUPLICATION: +- Compare against existing memories provided in context +- If a fact is already known, skip it +- Only extract genuinely NEW information + +CATEGORIES: +- system: Facts about the user (preferences, opinions, network, projects, habits) +- interesting: External wisdom or advice from others (articles, conversations, tips)""" + + +class ExtractedMemory(BaseModel): + content: str = Field(description="Concise statement of the fact or insight") + category: str = Field(description="system or interesting") + confidence: float = Field(ge=0.0, le=1.0, description="Extraction confidence") + + +class MemoryExtractionResult(BaseModel): + memories: List[ExtractedMemory] = Field(default_factory=list, description="Extracted memories (empty if none)") + + +def _build_memory_context(uid: str) -> str: + """Build existing 
memories context for deduplication.""" + try: + existing = get_memories(uid, limit=30, categories=['system', 'interesting']) + if existing: + lines = [] + for m in existing: + content = m.get('structured', {}).get('content', m.get('content', '')) + if content: + lines.append(f"- {content}") + if lines: + return "Existing memories (DO NOT extract duplicates):\n" + "\n".join(lines) + except Exception as e: + logger.warning(f"Failed to fetch existing memories: {e}") + return "" + + +async def extract_memories( + uid: str, + image_b64: str, + app_name: str = "", + window_title: str = "", +) -> dict: + """Extract memories from a screenshot using vision LLM. + + Returns: + Dict with memories list (each has content, category, confidence) + """ + memory_context = _build_memory_context(uid) + + prompt_parts = [] + if memory_context: + prompt_parts.append(memory_context) + if app_name or window_title: + prompt_parts.append(f"Current app: {app_name}, Window: {window_title}") + prompt_parts.append("Analyze this screenshot for noteworthy facts or insights:") + + prompt_text = "\n\n".join(prompt_parts) + + with_parser = llm_gemini_flash.with_structured_output(MemoryExtractionResult) + result = await with_parser.ainvoke( + [ + SystemMessage(content=MEMORY_SYSTEM_PROMPT), + HumanMessage( + content=[ + {"type": "text", "text": prompt_text}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}}, + ] + ), + ] + ) + + return { + "memories": [ + {"content": m.content, "category": m.category, "confidence": m.confidence} for m in result.memories + ] + } From 8d2b0f8ca4e1be7a8930ed4bf8534fd44ad5b22b Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:22:40 +0100 Subject: [PATCH 007/163] Add contextual advice handler for desktop screen analysis --- backend/utils/desktop/advice.py | 115 ++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 backend/utils/desktop/advice.py diff --git a/backend/utils/desktop/advice.py 
b/backend/utils/desktop/advice.py new file mode 100644 index 0000000000..c73a7ae251 --- /dev/null +++ b/backend/utils/desktop/advice.py @@ -0,0 +1,115 @@ +import logging +from typing import Optional + +from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from database.goals import get_user_goals +from database.action_items import get_action_items +from utils.llm.clients import llm_gemini_flash + +logger = logging.getLogger(__name__) + +ADVICE_SYSTEM_PROMPT = """\ +You are a proactive assistant that offers brief, actionable advice based on what the user \ +is currently doing on their screen. Your advice should be contextual and helpful. + +ADVICE RULES: +- Only offer advice when you can provide genuinely useful, specific guidance +- Advice must relate to what's visible on screen +- Keep it short (1-2 sentences max) +- Be actionable — tell the user something they can DO, not just observe +- Consider the user's goals and tasks when forming advice +- ~70% of screenshots need NO advice — return null when nothing useful to say + +TONE: +- Direct and casual, not formal +- Helpful, not preachy +- Specific to what you see, not generic productivity tips + +CATEGORIES: +- productivity: efficiency tips, workflow improvements +- mistake_prevention: catching potential errors or oversights +- learning: suggesting resources or approaches +- health: break reminders, posture, eye strain (only if clearly needed) +- goal_alignment: connecting current activity to stated goals""" + + +class AdviceResult(BaseModel): + has_advice: bool = Field(description="Whether advice is warranted") + content: Optional[str] = Field(default=None, description="The advice (1-2 sentences, null if none)") + category: Optional[str] = Field( + default=None, description="productivity|mistake_prevention|learning|health|goal_alignment" + ) + confidence: float = Field(ge=0.0, le=1.0, description="Confidence this advice is useful") + + +def 
_build_advice_context(uid: str) -> str: + """Build user context for advice generation.""" + parts = [] + + try: + goals = get_user_goals(uid, limit=5) + if goals: + goal_lines = [f"- {g.get('title', g.get('description', ''))}" for g in goals] + parts.append("User's goals:\n" + "\n".join(goal_lines)) + except Exception as e: + logger.warning(f"Failed to fetch goals for advice: {e}") + + try: + tasks = get_action_items(uid, completed=False, limit=10) + if tasks: + task_lines = [f"- {t.get('description', '')}" for t in tasks[:10]] + parts.append("Current tasks:\n" + "\n".join(task_lines)) + except Exception as e: + logger.warning(f"Failed to fetch tasks for advice: {e}") + + return "\n\n".join(parts) if parts else "" + + +async def generate_advice( + uid: str, + image_b64: str, + app_name: str = "", + window_title: str = "", +) -> dict: + """Generate contextual advice from a screenshot using vision LLM. + + Returns: + Dict with has_advice, content, category, confidence (or nulls if no advice) + """ + advice_context = _build_advice_context(uid) + + prompt_parts = [] + if advice_context: + prompt_parts.append(advice_context) + if app_name or window_title: + prompt_parts.append(f"Current app: {app_name}, Window: {window_title}") + prompt_parts.append("Based on this screenshot, do you have any specific, actionable advice?") + + prompt_text = "\n\n".join(prompt_parts) + + with_parser = llm_gemini_flash.with_structured_output(AdviceResult) + result = await with_parser.ainvoke( + [ + SystemMessage(content=ADVICE_SYSTEM_PROMPT), + HumanMessage( + content=[ + {"type": "text", "text": prompt_text}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}}, + ] + ), + ] + ) + + if not result.has_advice: + return {"has_advice": False, "advice": None} + + return { + "has_advice": True, + "advice": { + "content": result.content, + "category": result.category, + "confidence": result.confidence, + }, + } From 51ed561d19954a9e4a2f313e2a213a4d962358be Mon Sep 
17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:22:41 +0100 Subject: [PATCH 008/163] Add live notes handler for desktop transcript processing --- backend/utils/desktop/live_notes.py | 55 +++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 backend/utils/desktop/live_notes.py diff --git a/backend/utils/desktop/live_notes.py b/backend/utils/desktop/live_notes.py new file mode 100644 index 0000000000..4c88b878b4 --- /dev/null +++ b/backend/utils/desktop/live_notes.py @@ -0,0 +1,55 @@ +import logging + +from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from utils.llm.clients import llm_mini + +logger = logging.getLogger(__name__) + +LIVE_NOTES_SYSTEM_PROMPT = """\ +You are a live note-taking assistant. Given a transcript segment, generate a concise, \ +well-structured note that captures the key information. + +RULES: +- Condense transcript into clear, readable notes +- Preserve important details: names, numbers, decisions, action items +- Remove filler words, repetition, and hesitation +- Use bullet points for multiple items +- Keep notes under 200 words +- If the transcript is too short or contains no meaningful content, return empty string""" + + +class LiveNoteResult(BaseModel): + text: str = Field(description="The generated note (empty string if no meaningful content)") + + +async def generate_live_note( + text: str, + session_context: str = "", +) -> dict: + """Generate a live note from transcript text. 
+ + Args: + text: Transcript text to summarize + session_context: Optional session context + + Returns: + Dict with text field (the note) + """ + prompt_parts = [] + if session_context: + prompt_parts.append(f"Session context: {session_context}") + prompt_parts.append(f"Transcript:\n{text}") + + prompt_text = "\n\n".join(prompt_parts) + + with_parser = llm_mini.with_structured_output(LiveNoteResult) + result = await with_parser.ainvoke( + [ + SystemMessage(content=LIVE_NOTES_SYSTEM_PROMPT), + HumanMessage(content=prompt_text), + ] + ) + + return {"text": result.text} From 4e3968f0adfe83c6d6d1ebd69ddb1aece044e98f Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:22:42 +0100 Subject: [PATCH 009/163] Add user profile generation handler for desktop --- backend/utils/desktop/profile.py | 79 ++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 backend/utils/desktop/profile.py diff --git a/backend/utils/desktop/profile.py b/backend/utils/desktop/profile.py new file mode 100644 index 0000000000..84fa697b97 --- /dev/null +++ b/backend/utils/desktop/profile.py @@ -0,0 +1,79 @@ +import logging + +from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from database.memories import get_memories +from database.action_items import get_action_items +from database.goals import get_user_goals +from utils.llm.clients import llm_mini + +logger = logging.getLogger(__name__) + +PROFILE_SYSTEM_PROMPT = """\ +You are generating a concise user profile summary based on their data (goals, tasks, memories). \ +This profile helps other AI assistants understand who the user is and what they care about. 
+ +FORMAT: +- Write in third person ("The user...") +- Include: professional focus, key projects, communication style, preferences +- Keep under 300 words +- Be factual — only include what's supported by the data +- If data is sparse, keep the profile short rather than speculating""" + + +class ProfileResult(BaseModel): + profile_text: str = Field(description="The generated user profile summary") + + +async def generate_profile(uid: str) -> dict: + """Generate a user profile from their goals, tasks, and memories. + + Returns: + Dict with profile_text + """ + parts = [] + + try: + goals = get_user_goals(uid, limit=10) + if goals: + goal_lines = [f"- {g.get('title', g.get('description', ''))}" for g in goals] + parts.append("Goals:\n" + "\n".join(goal_lines)) + except Exception as e: + logger.warning(f"Failed to fetch goals for profile: {e}") + + try: + tasks = get_action_items(uid, completed=False, limit=30) + if tasks: + task_lines = [f"- {t.get('description', '')}" for t in tasks[:30]] + parts.append("Active tasks:\n" + "\n".join(task_lines)) + except Exception as e: + logger.warning(f"Failed to fetch tasks for profile: {e}") + + try: + memories = get_memories(uid, limit=30, categories=['system']) + if memories: + mem_lines = [] + for m in memories: + content = m.get('structured', {}).get('content', m.get('content', '')) + if content: + mem_lines.append(f"- {content}") + if mem_lines: + parts.append("Known facts:\n" + "\n".join(mem_lines)) + except Exception as e: + logger.warning(f"Failed to fetch memories for profile: {e}") + + if not parts: + return {"profile_text": "No data available to generate profile."} + + data_text = "\n\n".join(parts) + + with_parser = llm_mini.with_structured_output(ProfileResult) + result = await with_parser.ainvoke( + [ + SystemMessage(content=PROFILE_SYSTEM_PROMPT), + HumanMessage(content=f"Generate a user profile from this data:\n\n{data_text}"), + ] + ) + + return {"profile_text": result.profile_text} From 
a2138ca5d81d1e76e5649d51cbed506c44f038ed Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:22:43 +0100 Subject: [PATCH 010/163] Add task reranking and deduplication handlers for desktop --- backend/utils/desktop/task_ops.py | 141 ++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 backend/utils/desktop/task_ops.py diff --git a/backend/utils/desktop/task_ops.py b/backend/utils/desktop/task_ops.py new file mode 100644 index 0000000000..3e6b7506c0 --- /dev/null +++ b/backend/utils/desktop/task_ops.py @@ -0,0 +1,141 @@ +import logging +from typing import List + +from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from database.action_items import get_action_items +from utils.llm.clients import llm_mini + +logger = logging.getLogger(__name__) + +# --- Task Reranking --- + +RERANK_SYSTEM_PROMPT = """\ +You are a task prioritization assistant. Given a list of tasks, rerank them by importance \ +and urgency. Consider deadlines, dependencies, and impact. + +RULES: +- Most important/urgent tasks first +- Tasks with approaching deadlines rank higher +- Blocking tasks rank higher than blocked tasks +- Return the same task IDs in new order""" + + +class RankedTask(BaseModel): + id: str = Field(description="Task ID") + new_position: int = Field(description="New position (1 = most important)") + + +class RerankResult(BaseModel): + updated_tasks: List[RankedTask] = Field(description="Tasks in new priority order") + + +async def rerank_tasks(uid: str) -> dict: + """Rerank user's active tasks by priority. 
+ + Returns: + Dict with updated_tasks list + """ + try: + tasks = get_action_items(uid, completed=False, limit=50) + except Exception as e: + logger.error(f"Failed to fetch tasks for reranking: {e}") + return {"updated_tasks": []} + + if not tasks: + return {"updated_tasks": []} + + task_lines = [] + for t in tasks: + tid = t.get('id', '') + desc = t.get('description', '') + due = t.get('due_at', '') + priority = t.get('priority', 'medium') + due_str = f", Due: {due}" if due else "" + task_lines.append(f"- ID: {tid} | {desc} | Priority: {priority}{due_str}") + + task_text = "\n".join(task_lines) + + with_parser = llm_mini.with_structured_output(RerankResult) + result = await with_parser.ainvoke( + [ + SystemMessage(content=RERANK_SYSTEM_PROMPT), + HumanMessage(content=f"Rerank these tasks by importance:\n\n{task_text}"), + ] + ) + + return {"updated_tasks": [{"id": t.id, "new_position": t.new_position} for t in result.updated_tasks]} + + +# --- Task Deduplication --- + +DEDUP_SYSTEM_PROMPT = """\ +You are a task deduplication assistant. Identify semantically duplicate tasks and decide \ +which to keep and which to delete. 
+ +RULES: +- Two tasks are duplicates if they describe the same action, even with different wording +- "Call John" and "Phone John" are duplicates +- "Review PR #42" and "Look at pull request 42" are duplicates +- Keep the more specific/detailed version +- Keep the one with a deadline if only one has one +- Keep the more recently created one if equally specific +- Only flag true duplicates — similar but distinct tasks should both be kept""" + + +class DedupGroup(BaseModel): + keep_id: str = Field(description="ID of the task to keep") + delete_ids: List[str] = Field(description="IDs of duplicate tasks to remove") + reason: str = Field(description="Why these are duplicates") + + +class DedupResult(BaseModel): + groups: List[DedupGroup] = Field(default_factory=list, description="Duplicate groups (empty if no duplicates)") + + +async def dedup_tasks(uid: str) -> dict: + """Find and resolve duplicate tasks. + + Returns: + Dict with deleted_ids and reason + """ + try: + tasks = get_action_items(uid, completed=False, limit=100) + except Exception as e: + logger.error(f"Failed to fetch tasks for dedup: {e}") + return {"deleted_ids": [], "reason": "Failed to fetch tasks"} + + if len(tasks) < 2: + return {"deleted_ids": [], "reason": "Not enough tasks to deduplicate"} + + task_lines = [] + for t in tasks: + tid = t.get('id', '') + desc = t.get('description', '') + due = t.get('due_at', '') + created = t.get('created_at', '') + due_str = f", Due: {due}" if due else "" + created_str = f", Created: {created}" if created else "" + task_lines.append(f"- ID: {tid} | {desc}{due_str}{created_str}") + + task_text = "\n".join(task_lines) + + with_parser = llm_mini.with_structured_output(DedupResult) + result = await with_parser.ainvoke( + [ + SystemMessage(content=DEDUP_SYSTEM_PROMPT), + HumanMessage(content=f"Find duplicate tasks:\n\n{task_text}"), + ] + ) + + all_deleted = [] + reasons = [] + for group in result.groups: + all_deleted.extend(group.delete_ids) + 
reasons.append(group.reason) + + return { + "deleted_ids": all_deleted, + "reason": "; ".join(reasons) if reasons else "No duplicates found", + } From ecf3523122c7550a47ff3453806f399612a4bc9b Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:22:46 +0100 Subject: [PATCH 011/163] Add message event classes for all desktop handler types --- backend/models/message_event.py | 81 +++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/backend/models/message_event.py b/backend/models/message_event.py index 36556f6af9..b422c767b1 100644 --- a/backend/models/message_event.py +++ b/backend/models/message_event.py @@ -199,3 +199,84 @@ def to_json(self): j["type"] = self.event_type del j["event_type"] return j + + +class TasksExtractedEvent(MessageEvent): + event_type: str = "tasks_extracted" + frame_id: str + tasks: List = [] + + def to_json(self): + j = self.model_dump(mode="json") + j["type"] = self.event_type + del j["event_type"] + return j + + +class MemoriesExtractedEvent(MessageEvent): + event_type: str = "memories_extracted" + frame_id: str + memories: List = [] + + def to_json(self): + j = self.model_dump(mode="json") + j["type"] = self.event_type + del j["event_type"] + return j + + +class AdviceExtractedEvent(MessageEvent): + event_type: str = "advice_extracted" + frame_id: str + advice: Optional[Any] = None + + def to_json(self): + j = self.model_dump(mode="json") + j["type"] = self.event_type + del j["event_type"] + return j + + +class LiveNoteEvent(MessageEvent): + event_type: str = "live_note" + text: str + + def to_json(self): + j = self.model_dump(mode="json") + j["type"] = self.event_type + del j["event_type"] + return j + + +class ProfileUpdatedEvent(MessageEvent): + event_type: str = "profile_updated" + profile_text: str + + def to_json(self): + j = self.model_dump(mode="json") + j["type"] = self.event_type + del j["event_type"] + return j + + +class RerankCompleteEvent(MessageEvent): + event_type: str = 
"rerank_complete" + updated_tasks: List = [] + + def to_json(self): + j = self.model_dump(mode="json") + j["type"] = self.event_type + del j["event_type"] + return j + + +class DedupCompleteEvent(MessageEvent): + event_type: str = "dedup_complete" + deleted_ids: List = [] + reason: str = "" + + def to_json(self): + j = self.model_dump(mode="json") + j["type"] = self.event_type + del j["event_type"] + return j From 3cac01db92f48a1263a1f7b1c9de8e2d21f40181 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:22:47 +0100 Subject: [PATCH 012/163] Add full desktop dispatcher for screen_frame and text message types --- backend/routers/transcribe.py | 125 +++++++++++++++++++++++++++------- 1 file changed, 99 insertions(+), 26 deletions(-) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 658f275a7d..bbf4022f0a 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -50,17 +50,24 @@ TranscriptSegment, ) from models.message_event import ( + AdviceExtractedEvent, ConversationEvent, + DedupCompleteEvent, FocusResultEvent, FREEMIUM_ACTION_SETUP_ON_DEVICE_STT, FreemiumThresholdReachedEvent, LastConversationEvent, + LiveNoteEvent, + MemoriesExtractedEvent, MessageEvent, MessageServiceStatusEvent, PhotoDescribedEvent, PhotoProcessingEvent, + ProfileUpdatedEvent, + RerankCompleteEvent, SegmentsDeletedEvent, SpeakerLabelSuggestionEvent, + TasksExtractedEvent, TranslationEvent, ) from models.transcript_segment import Translation @@ -101,7 +108,13 @@ SPEAKER_MATCH_THRESHOLD, ) from utils.speaker_sample_migration import maybe_migrate_person_samples +from utils.desktop.advice import generate_advice from utils.desktop.focus import analyze_focus +from utils.desktop.live_notes import generate_live_note +from utils.desktop.memories import extract_memories +from utils.desktop.profile import generate_profile +from utils.desktop.task_ops import dedup_tasks, rerank_tasks +from utils.desktop.tasks import extract_tasks from 
utils.log_sanitizer import sanitize, sanitize_pii logger = logging.getLogger(__name__) @@ -2134,33 +2147,93 @@ async def close_soniox_profile(): frame_id = json_data.get('frame_id', '') image_b64 = json_data.get('image_b64', '') analyze_types = json_data.get('analyze', []) - if image_b64 and 'focus' in analyze_types: - async def _handle_focus(fid, img, app, wtitle): - try: - result = await analyze_focus( - uid=uid, - image_b64=img, - app_name=app, - window_title=wtitle, - ) - _send_message_event(FocusResultEvent( - frame_id=fid, - status=result['status'], - app_or_site=result['app_or_site'], - description=result['description'], - message=result.get('message'), - )) - except Exception as focus_err: - logger.error(f"Focus analysis failed: {focus_err} {uid} {session_id}") - - spawn(_handle_focus( - frame_id, - image_b64, - json_data.get('app_name', ''), - json_data.get('window_title', ''), - )) - elif not image_b64: + sf_app = json_data.get('app_name', '') + sf_wtitle = json_data.get('window_title', '') + if not image_b64: logger.warning(f"screen_frame missing image_b64 {uid} {session_id}") + else: + # Fan out to parallel handlers per analyze type + if 'focus' in analyze_types: + async def _handle_focus(fid, img, app, wtitle): + try: + result = await analyze_focus(uid=uid, image_b64=img, app_name=app, window_title=wtitle) + _send_message_event(FocusResultEvent( + frame_id=fid, status=result['status'], app_or_site=result['app_or_site'], + description=result['description'], message=result.get('message'), + )) + except Exception as e: + logger.error(f"Focus analysis failed: {e} {uid} {session_id}") + spawn(_handle_focus(frame_id, image_b64, sf_app, sf_wtitle)) + + if 'tasks' in analyze_types: + async def _handle_tasks(fid, img, app, wtitle): + try: + result = await extract_tasks(uid=uid, image_b64=img, app_name=app, window_title=wtitle) + _send_message_event(TasksExtractedEvent(frame_id=fid, tasks=result.get('tasks', []))) + except Exception as e: + logger.error(f"Task 
extraction failed: {e} {uid} {session_id}") + spawn(_handle_tasks(frame_id, image_b64, sf_app, sf_wtitle)) + + if 'memories' in analyze_types: + async def _handle_memories(fid, img, app, wtitle): + try: + result = await extract_memories(uid=uid, image_b64=img, app_name=app, window_title=wtitle) + _send_message_event(MemoriesExtractedEvent(frame_id=fid, memories=result.get('memories', []))) + except Exception as e: + logger.error(f"Memory extraction failed: {e} {uid} {session_id}") + spawn(_handle_memories(frame_id, image_b64, sf_app, sf_wtitle)) + + if 'advice' in analyze_types: + async def _handle_advice(fid, img, app, wtitle): + try: + result = await generate_advice(uid=uid, image_b64=img, app_name=app, window_title=wtitle) + _send_message_event(AdviceExtractedEvent( + frame_id=fid, advice=result.get('advice'), + )) + except Exception as e: + logger.error(f"Advice generation failed: {e} {uid} {session_id}") + spawn(_handle_advice(frame_id, image_b64, sf_app, sf_wtitle)) + + # Desktop proactive AI — text-only message types (#5396) + elif json_data.get('type') == 'live_notes_text': + async def _handle_live_notes(text, ctx): + try: + result = await generate_live_note(text=text, session_context=ctx) + if result.get('text'): + _send_message_event(LiveNoteEvent(text=result['text'])) + except Exception as e: + logger.error(f"Live note generation failed: {e} {uid} {session_id}") + spawn(_handle_live_notes(json_data.get('text', ''), json_data.get('session_context', ''))) + + elif json_data.get('type') == 'profile_request': + async def _handle_profile(): + try: + result = await generate_profile(uid=uid) + _send_message_event(ProfileUpdatedEvent(profile_text=result['profile_text'])) + except Exception as e: + logger.error(f"Profile generation failed: {e} {uid} {session_id}") + spawn(_handle_profile()) + + elif json_data.get('type') == 'task_rerank': + async def _handle_rerank(): + try: + result = await rerank_tasks(uid=uid) + 
_send_message_event(RerankCompleteEvent(updated_tasks=result['updated_tasks'])) + except Exception as e: + logger.error(f"Task reranking failed: {e} {uid} {session_id}") + spawn(_handle_rerank()) + + elif json_data.get('type') == 'task_dedup': + async def _handle_dedup(): + try: + result = await dedup_tasks(uid=uid) + _send_message_event(DedupCompleteEvent( + deleted_ids=result['deleted_ids'], reason=result['reason'], + )) + except Exception as e: + logger.error(f"Task dedup failed: {e} {uid} {session_id}") + spawn(_handle_dedup()) + except json.JSONDecodeError: logger.info( f"Received non-json text message: {sanitize(message.get('text'))} {uid} {session_id}" From 5875ba726370ec9bcce948d294b5e2902c98c7ac Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:47 +0100 Subject: [PATCH 013/163] Add unit tests for task extraction handler (18 tests) --- backend/tests/unit/test_desktop_tasks.py | 238 +++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 backend/tests/unit/test_desktop_tasks.py diff --git a/backend/tests/unit/test_desktop_tasks.py b/backend/tests/unit/test_desktop_tasks.py new file mode 100644 index 0000000000..908b4d9a50 --- /dev/null +++ b/backend/tests/unit/test_desktop_tasks.py @@ -0,0 +1,238 @@ +"""Tests for desktop task extraction handler (Phase 2 — #5396).""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Mock heavy dependencies before any project imports +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.tasks import ( + ExtractedTask, + TaskExtractionResult, + TASK_SYSTEM_PROMPT, + _build_task_context, + extract_tasks, +) +from models.message_event import 
TasksExtractedEvent + + +class TestExtractedTaskModel: + def test_task_with_all_fields(self): + task = ExtractedTask( + title="Review pull request 42 for authentication changes", + description="Check auth middleware", + priority="high", + tags=["code-review", "auth"], + source_app="GitHub", + inferred_deadline="2026-03-10", + confidence=0.9, + source_category="direct_request", + ) + assert task.title == "Review pull request 42 for authentication changes" + assert task.priority == "high" + assert task.confidence == 0.9 + + def test_task_defaults(self): + task = ExtractedTask( + title="Update the README with new API docs", + priority="low", + confidence=0.5, + ) + assert task.description == "" + assert task.tags == [] + assert task.source_app == "" + assert task.inferred_deadline is None + assert task.source_category == "reactive" + + def test_task_confidence_bounds(self): + with pytest.raises(Exception): + ExtractedTask(title="Test", priority="high", confidence=1.5) + with pytest.raises(Exception): + ExtractedTask(title="Test", priority="high", confidence=-0.1) + + +class TestTaskExtractionResult: + def test_result_with_tasks(self): + result = TaskExtractionResult( + has_new_tasks=True, + tasks=[ + ExtractedTask(title="Call John about the project deadline", priority="high", confidence=0.8), + ], + context_summary="Slack messages", + current_activity="Reading messages", + ) + assert result.has_new_tasks is True + assert len(result.tasks) == 1 + + def test_result_no_tasks(self): + result = TaskExtractionResult( + has_new_tasks=False, + context_summary="IDE open", + current_activity="Coding", + ) + assert result.has_new_tasks is False + assert result.tasks == [] + + +class TestTasksExtractedEvent: + def test_event_structure(self): + event = TasksExtractedEvent( + frame_id="frame123", + tasks=[{"title": "Test task", "priority": "high"}], + ) + data = event.to_json() + assert data["type"] == "tasks_extracted" + assert data["frame_id"] == "frame123" + assert 
len(data["tasks"]) == 1 + + +class TestBuildTaskContext: + @patch('utils.desktop.tasks.get_action_items') + def test_active_tasks_in_context(self, mock_get): + mock_get.return_value = [ + {'description': 'Write tests', 'due_at': '2026-03-10'}, + {'description': 'Fix bug'}, + ] + ctx = _build_task_context("uid1") + assert "Write tests" in ctx + assert "Due: 2026-03-10" in ctx + assert "Fix bug" in ctx + assert "Pending" in ctx + + @patch('utils.desktop.tasks.get_action_items') + def test_completed_tasks_in_context(self, mock_get): + mock_get.side_effect = [ + [], # active tasks + [{'description': 'Done task'}], # completed tasks + ] + ctx = _build_task_context("uid1") + assert "Done task" in ctx + assert "Completed" in ctx + + @patch('utils.desktop.tasks.get_action_items') + def test_empty_context(self, mock_get): + mock_get.return_value = [] + ctx = _build_task_context("uid1") + assert ctx == "" + + @patch('utils.desktop.tasks.get_action_items') + def test_graceful_on_errors(self, mock_get): + mock_get.side_effect = Exception("DB error") + ctx = _build_task_context("uid1") + assert ctx == "" + + +class TestExtractTasks: + @patch('utils.desktop.tasks._build_task_context') + @patch('utils.desktop.tasks.llm_gemini_flash') + def test_extract_tasks_returns_result(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=TaskExtractionResult( + has_new_tasks=True, + tasks=[ + ExtractedTask( + title="Review pull request 42 for auth changes", + priority="high", + confidence=0.9, + source_app="GitHub", + ) + ], + context_summary="GitHub PR page", + current_activity="Reviewing code", + ) + ) + result = asyncio.get_event_loop().run_until_complete( + extract_tasks("uid1", "base64img", "Chrome", "GitHub PR") + ) + assert result["has_new_tasks"] is True + assert len(result["tasks"]) == 1 + assert result["tasks"][0]["title"] == "Review pull 
request 42 for auth changes" + assert result["tasks"][0]["source_app"] == "GitHub" + + @patch('utils.desktop.tasks._build_task_context') + @patch('utils.desktop.tasks.llm_gemini_flash') + def test_extract_tasks_no_tasks(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=TaskExtractionResult( + has_new_tasks=False, + context_summary="Desktop idle", + current_activity="Nothing", + ) + ) + result = asyncio.get_event_loop().run_until_complete( + extract_tasks("uid1", "base64img") + ) + assert result["has_new_tasks"] is False + assert result["tasks"] == [] + + @patch('utils.desktop.tasks._build_task_context') + @patch('utils.desktop.tasks.llm_gemini_flash') + def test_source_app_fallback(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=TaskExtractionResult( + has_new_tasks=True, + tasks=[ + ExtractedTask( + title="Send email to team about deadline update", + priority="medium", + confidence=0.7, + source_app="", # empty + ) + ], + ) + ) + result = asyncio.get_event_loop().run_until_complete( + extract_tasks("uid1", "base64img", "Slack", "General") + ) + # Falls back to app_name when source_app is empty + assert result["tasks"][0]["source_app"] == "Slack" + + @patch('utils.desktop.tasks._build_task_context') + @patch('utils.desktop.tasks.llm_gemini_flash') + def test_includes_context_in_prompt(self, mock_llm, mock_ctx): + mock_ctx.return_value = "Existing active tasks:\n- Write tests [Pending]" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=TaskExtractionResult(has_new_tasks=False) + ) + asyncio.get_event_loop().run_until_complete( + extract_tasks("uid1", "base64img", "VS Code", "main.py") + ) + 
call_args = mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + text_content = human_msg.content[0]["text"] + assert "Write tests" in text_content + assert "VS Code" in text_content + + +class TestTaskSystemPrompt: + def test_prompt_includes_dedup_rules(self): + assert "DEDUPLICATION" in TASK_SYSTEM_PROMPT + + def test_prompt_includes_priority_guidelines(self): + assert "high" in TASK_SYSTEM_PROMPT + assert "medium" in TASK_SYSTEM_PROMPT + assert "low" in TASK_SYSTEM_PROMPT + + def test_prompt_includes_source_categories(self): + assert "direct_request" in TASK_SYSTEM_PROMPT + assert "self_generated" in TASK_SYSTEM_PROMPT + assert "calendar_driven" in TASK_SYSTEM_PROMPT From aa330becd04c874be636f2b99b8e5565413ca520 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:50 +0100 Subject: [PATCH 014/163] Add unit tests for memory extraction handler (14 tests) --- backend/tests/unit/test_desktop_memories.py | 150 ++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 backend/tests/unit/test_desktop_memories.py diff --git a/backend/tests/unit/test_desktop_memories.py b/backend/tests/unit/test_desktop_memories.py new file mode 100644 index 0000000000..158760e2c4 --- /dev/null +++ b/backend/tests/unit/test_desktop_memories.py @@ -0,0 +1,150 @@ +"""Tests for desktop memory extraction handler (Phase 2 — #5396).""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.memories import ( + ExtractedMemory, + MemoryExtractionResult, + MEMORY_SYSTEM_PROMPT, + _build_memory_context, + extract_memories, +) +from models.message_event import 
MemoriesExtractedEvent + + +class TestExtractedMemoryModel: + def test_memory_all_fields(self): + m = ExtractedMemory(content="User prefers dark mode", category="system", confidence=0.95) + assert m.content == "User prefers dark mode" + assert m.category == "system" + assert m.confidence == 0.95 + + def test_memory_interesting_category(self): + m = ExtractedMemory(content="AI tip from article", category="interesting", confidence=0.7) + assert m.category == "interesting" + + def test_confidence_bounds(self): + with pytest.raises(Exception): + ExtractedMemory(content="test", category="system", confidence=1.5) + + +class TestMemoryExtractionResult: + def test_result_with_memories(self): + result = MemoryExtractionResult( + memories=[ExtractedMemory(content="Fact 1", category="system", confidence=0.8)] + ) + assert len(result.memories) == 1 + + def test_result_empty(self): + result = MemoryExtractionResult() + assert result.memories == [] + + +class TestMemoriesExtractedEvent: + def test_event_structure(self): + event = MemoriesExtractedEvent( + frame_id="frame456", + memories=[{"content": "Test fact", "category": "system", "confidence": 0.9}], + ) + data = event.to_json() + assert data["type"] == "memories_extracted" + assert data["frame_id"] == "frame456" + assert len(data["memories"]) == 1 + + +class TestBuildMemoryContext: + @patch('utils.desktop.memories.get_memories') + def test_existing_memories_in_context(self, mock_get): + mock_get.return_value = [ + {'structured': {'content': 'User likes Python'}}, + {'content': 'Fallback content'}, + ] + ctx = _build_memory_context("uid1") + assert "User likes Python" in ctx + assert "Fallback content" in ctx + assert "DO NOT extract duplicates" in ctx + + @patch('utils.desktop.memories.get_memories') + def test_empty_context(self, mock_get): + mock_get.return_value = [] + ctx = _build_memory_context("uid1") + assert ctx == "" + + @patch('utils.desktop.memories.get_memories') + def test_graceful_on_errors(self, mock_get): + 
mock_get.side_effect = Exception("DB error") + ctx = _build_memory_context("uid1") + assert ctx == "" + + +class TestExtractMemories: + @patch('utils.desktop.memories._build_memory_context') + @patch('utils.desktop.memories.llm_gemini_flash') + def test_extract_returns_memories(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=MemoryExtractionResult( + memories=[ + ExtractedMemory(content="User works on Omi project", category="system", confidence=0.85), + ] + ) + ) + result = asyncio.get_event_loop().run_until_complete( + extract_memories("uid1", "base64img", "VS Code", "omi/main.py") + ) + assert len(result["memories"]) == 1 + assert result["memories"][0]["content"] == "User works on Omi project" + assert result["memories"][0]["category"] == "system" + + @patch('utils.desktop.memories._build_memory_context') + @patch('utils.desktop.memories.llm_gemini_flash') + def test_extract_empty_result(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=MemoryExtractionResult()) + result = asyncio.get_event_loop().run_until_complete( + extract_memories("uid1", "base64img") + ) + assert result["memories"] == [] + + @patch('utils.desktop.memories._build_memory_context') + @patch('utils.desktop.memories.llm_gemini_flash') + def test_sends_image_and_system_prompt(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=MemoryExtractionResult()) + asyncio.get_event_loop().run_until_complete( + extract_memories("uid1", "testimg64") + ) + call_args = mock_parser.ainvoke.call_args[0][0] + sys_msg = call_args[0] + human_msg = call_args[1] + assert MEMORY_SYSTEM_PROMPT in 
sys_msg.content + assert human_msg.content[1]["image_url"]["url"] == "data:image/jpeg;base64,testimg64" + + +class TestMemorySystemPrompt: + def test_includes_extraction_rules(self): + assert "EXTRACTION RULES" in MEMORY_SYSTEM_PROMPT + + def test_includes_dedup(self): + assert "DEDUPLICATION" in MEMORY_SYSTEM_PROMPT + + def test_includes_categories(self): + assert "system" in MEMORY_SYSTEM_PROMPT + assert "interesting" in MEMORY_SYSTEM_PROMPT From 8c0ebbbbcef80d16174866b596b1cc3e67af2f7f Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:51 +0100 Subject: [PATCH 015/163] Add unit tests for advice handler (14 tests) --- backend/tests/unit/test_desktop_advice.py | 159 ++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 backend/tests/unit/test_desktop_advice.py diff --git a/backend/tests/unit/test_desktop_advice.py b/backend/tests/unit/test_desktop_advice.py new file mode 100644 index 0000000000..e1699f3eb7 --- /dev/null +++ b/backend/tests/unit/test_desktop_advice.py @@ -0,0 +1,159 @@ +"""Tests for desktop advice handler (Phase 2 — #5396).""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.advice import ( + AdviceResult, + ADVICE_SYSTEM_PROMPT, + _build_advice_context, + generate_advice, +) +from models.message_event import AdviceExtractedEvent + + +class TestAdviceResultModel: + def test_advice_with_content(self): + r = AdviceResult(has_advice=True, content="Take a break", category="health", confidence=0.8) + assert r.has_advice is True + assert r.content == "Take a break" + assert r.category == "health" + + def 
test_no_advice(self): + r = AdviceResult(has_advice=False, confidence=0.1) + assert r.has_advice is False + assert r.content is None + assert r.category is None + + def test_confidence_bounds(self): + with pytest.raises(Exception): + AdviceResult(has_advice=True, confidence=2.0) + + +class TestAdviceExtractedEvent: + def test_event_with_advice(self): + event = AdviceExtractedEvent( + frame_id="frame789", + advice={"content": "Try dark mode", "category": "productivity", "confidence": 0.7}, + ) + data = event.to_json() + assert data["type"] == "advice_extracted" + assert data["frame_id"] == "frame789" + assert data["advice"]["content"] == "Try dark mode" + + def test_event_no_advice(self): + event = AdviceExtractedEvent(frame_id="frame789", advice=None) + data = event.to_json() + assert data["advice"] is None + + +class TestBuildAdviceContext: + @patch('utils.desktop.advice.get_action_items') + @patch('utils.desktop.advice.get_user_goals') + def test_goals_and_tasks_in_context(self, mock_goals, mock_tasks): + mock_goals.return_value = [{'title': 'Ship v2'}] + mock_tasks.return_value = [{'description': 'Write tests'}] + ctx = _build_advice_context("uid1") + assert "Ship v2" in ctx + assert "Write tests" in ctx + + @patch('utils.desktop.advice.get_action_items') + @patch('utils.desktop.advice.get_user_goals') + def test_empty_context(self, mock_goals, mock_tasks): + mock_goals.return_value = [] + mock_tasks.return_value = [] + ctx = _build_advice_context("uid1") + assert ctx == "" + + @patch('utils.desktop.advice.get_action_items') + @patch('utils.desktop.advice.get_user_goals') + def test_graceful_on_errors(self, mock_goals, mock_tasks): + mock_goals.side_effect = Exception("DB error") + mock_tasks.side_effect = Exception("DB error") + ctx = _build_advice_context("uid1") + assert ctx == "" + + @patch('utils.desktop.advice.get_action_items') + @patch('utils.desktop.advice.get_user_goals') + def test_goals_fallback_to_description(self, mock_goals, mock_tasks): + 
mock_goals.return_value = [{'description': 'Fallback goal'}] + mock_tasks.return_value = [] + ctx = _build_advice_context("uid1") + assert "Fallback goal" in ctx + + +class TestGenerateAdvice: + @patch('utils.desktop.advice._build_advice_context') + @patch('utils.desktop.advice.llm_gemini_flash') + def test_returns_advice(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=AdviceResult( + has_advice=True, + content="Consider using a linter", + category="productivity", + confidence=0.75, + ) + ) + result = asyncio.get_event_loop().run_until_complete( + generate_advice("uid1", "base64img", "VS Code", "main.py") + ) + assert result["has_advice"] is True + assert result["advice"]["content"] == "Consider using a linter" + assert result["advice"]["category"] == "productivity" + + @patch('utils.desktop.advice._build_advice_context') + @patch('utils.desktop.advice.llm_gemini_flash') + def test_no_advice(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=AdviceResult(has_advice=False, confidence=0.1) + ) + result = asyncio.get_event_loop().run_until_complete( + generate_advice("uid1", "base64img") + ) + assert result["has_advice"] is False + assert result["advice"] is None + + @patch('utils.desktop.advice._build_advice_context') + @patch('utils.desktop.advice.llm_gemini_flash') + def test_includes_app_info(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=AdviceResult(has_advice=False, confidence=0.1) + ) + asyncio.get_event_loop().run_until_complete( + generate_advice("uid1", "base64img", "Chrome", "Stack Overflow") + ) + call_args = 
mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + text_content = human_msg.content[0]["text"] + assert "Chrome" in text_content + assert "Stack Overflow" in text_content + + +class TestAdviceSystemPrompt: + def test_includes_categories(self): + assert "productivity" in ADVICE_SYSTEM_PROMPT + assert "mistake_prevention" in ADVICE_SYSTEM_PROMPT + assert "health" in ADVICE_SYSTEM_PROMPT + assert "goal_alignment" in ADVICE_SYSTEM_PROMPT + + def test_includes_tone_guidance(self): + assert "TONE" in ADVICE_SYSTEM_PROMPT From 5a32e04f0290c005b1046ea2e143fa20d8b2c1b2 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:51 +0100 Subject: [PATCH 016/163] Add unit tests for live notes handler (10 tests) --- backend/tests/unit/test_desktop_live_notes.py | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 backend/tests/unit/test_desktop_live_notes.py diff --git a/backend/tests/unit/test_desktop_live_notes.py b/backend/tests/unit/test_desktop_live_notes.py new file mode 100644 index 0000000000..7969427ce2 --- /dev/null +++ b/backend/tests/unit/test_desktop_live_notes.py @@ -0,0 +1,99 @@ +"""Tests for desktop live notes handler (Phase 2 — #5396).""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.live_notes import ( + LiveNoteResult, + LIVE_NOTES_SYSTEM_PROMPT, + generate_live_note, +) +from models.message_event import LiveNoteEvent + + +class TestLiveNoteResultModel: + def test_note_with_text(self): + r = LiveNoteResult(text="Key decision: ship by Friday") + assert r.text == "Key decision: ship by Friday" + + def 
test_empty_note(self): + r = LiveNoteResult(text="") + assert r.text == "" + + +class TestLiveNoteEvent: + def test_event_structure(self): + event = LiveNoteEvent(text="Meeting note content") + data = event.to_json() + assert data["type"] == "live_note" + assert data["text"] == "Meeting note content" + + +class TestGenerateLiveNote: + @patch('utils.desktop.live_notes.llm_mini') + def test_returns_note(self, mock_llm): + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=LiveNoteResult(text="- Decision: use Redis for caching") + ) + result = asyncio.get_event_loop().run_until_complete( + generate_live_note("We decided to use Redis for caching the API responses") + ) + assert result["text"] == "- Decision: use Redis for caching" + + @patch('utils.desktop.live_notes.llm_mini') + def test_empty_result(self, mock_llm): + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=LiveNoteResult(text="")) + result = asyncio.get_event_loop().run_until_complete( + generate_live_note("um yeah so like um") + ) + assert result["text"] == "" + + @patch('utils.desktop.live_notes.llm_mini') + def test_includes_session_context(self, mock_llm): + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=LiveNoteResult(text="note")) + asyncio.get_event_loop().run_until_complete( + generate_live_note("transcript text", session_context="Sprint planning") + ) + call_args = mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + assert "Sprint planning" in human_msg.content + + @patch('utils.desktop.live_notes.llm_mini') + def test_sends_system_prompt(self, mock_llm): + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=LiveNoteResult(text="")) + 
asyncio.get_event_loop().run_until_complete( + generate_live_note("test text") + ) + call_args = mock_parser.ainvoke.call_args[0][0] + sys_msg = call_args[0] + assert LIVE_NOTES_SYSTEM_PROMPT in sys_msg.content + + +class TestLiveNotesSystemPrompt: + def test_includes_condensation_rules(self): + assert "Condense" in LIVE_NOTES_SYSTEM_PROMPT + + def test_includes_word_limit(self): + assert "200 words" in LIVE_NOTES_SYSTEM_PROMPT + + def test_includes_preservation_rules(self): + assert "names" in LIVE_NOTES_SYSTEM_PROMPT + assert "decisions" in LIVE_NOTES_SYSTEM_PROMPT From bbfe287bd46e74027cee2f0902ce0d7456eb4cb3 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:52 +0100 Subject: [PATCH 017/163] Add unit tests for profile handler (9 tests) --- backend/tests/unit/test_desktop_profile.py | 103 +++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 backend/tests/unit/test_desktop_profile.py diff --git a/backend/tests/unit/test_desktop_profile.py b/backend/tests/unit/test_desktop_profile.py new file mode 100644 index 0000000000..ea6fea4234 --- /dev/null +++ b/backend/tests/unit/test_desktop_profile.py @@ -0,0 +1,103 @@ +"""Tests for desktop profile generation handler (Phase 2 — #5396).""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.profile import ( + ProfileResult, + PROFILE_SYSTEM_PROMPT, + generate_profile, +) +from models.message_event import ProfileUpdatedEvent + + +class TestProfileResultModel: + def test_profile_text(self): + r = ProfileResult(profile_text="The user is a backend engineer focused on Python.") + assert "backend 
engineer" in r.profile_text + + +class TestProfileUpdatedEvent: + def test_event_structure(self): + event = ProfileUpdatedEvent(profile_text="User profile text") + data = event.to_json() + assert data["type"] == "profile_updated" + assert data["profile_text"] == "User profile text" + + +class TestGenerateProfile: + @patch('utils.desktop.profile.get_memories') + @patch('utils.desktop.profile.get_action_items') + @patch('utils.desktop.profile.get_user_goals') + @patch('utils.desktop.profile.llm_mini') + def test_generates_profile(self, mock_llm, mock_goals, mock_tasks, mock_memories): + mock_goals.return_value = [{'title': 'Ship v2'}] + mock_tasks.return_value = [{'description': 'Fix auth bug'}] + mock_memories.return_value = [{'structured': {'content': 'User prefers Python'}}] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=ProfileResult(profile_text="The user is a developer focused on shipping v2.") + ) + result = asyncio.get_event_loop().run_until_complete(generate_profile("uid1")) + assert "developer" in result["profile_text"] + + @patch('utils.desktop.profile.get_memories') + @patch('utils.desktop.profile.get_action_items') + @patch('utils.desktop.profile.get_user_goals') + def test_no_data_returns_default(self, mock_goals, mock_tasks, mock_memories): + mock_goals.return_value = [] + mock_tasks.return_value = [] + mock_memories.return_value = [] + result = asyncio.get_event_loop().run_until_complete(generate_profile("uid1")) + assert "No data available" in result["profile_text"] + + @patch('utils.desktop.profile.get_memories') + @patch('utils.desktop.profile.get_action_items') + @patch('utils.desktop.profile.get_user_goals') + @patch('utils.desktop.profile.llm_mini') + def test_graceful_on_db_errors(self, mock_llm, mock_goals, mock_tasks, mock_memories): + mock_goals.side_effect = Exception("DB error") + mock_tasks.side_effect = Exception("DB error") + 
mock_memories.side_effect = Exception("DB error") + result = asyncio.get_event_loop().run_until_complete(generate_profile("uid1")) + assert "No data available" in result["profile_text"] + + @patch('utils.desktop.profile.get_memories') + @patch('utils.desktop.profile.get_action_items') + @patch('utils.desktop.profile.get_user_goals') + @patch('utils.desktop.profile.llm_mini') + def test_includes_goals_in_prompt(self, mock_llm, mock_goals, mock_tasks, mock_memories): + mock_goals.return_value = [{'title': 'Learn Rust'}] + mock_tasks.return_value = [] + mock_memories.return_value = [] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=ProfileResult(profile_text="Profile text") + ) + asyncio.get_event_loop().run_until_complete(generate_profile("uid1")) + call_args = mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + assert "Learn Rust" in human_msg.content + + +class TestProfileSystemPrompt: + def test_third_person_format(self): + assert "third person" in PROFILE_SYSTEM_PROMPT + + def test_word_limit(self): + assert "300 words" in PROFILE_SYSTEM_PROMPT + + def test_factual_requirement(self): + assert "factual" in PROFILE_SYSTEM_PROMPT From 10112df6023a1e05b40c380610144e878af341d8 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:52 +0100 Subject: [PATCH 018/163] Add unit tests for task rerank and dedup handlers (16 tests) --- backend/tests/unit/test_desktop_task_ops.py | 175 ++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 backend/tests/unit/test_desktop_task_ops.py diff --git a/backend/tests/unit/test_desktop_task_ops.py b/backend/tests/unit/test_desktop_task_ops.py new file mode 100644 index 0000000000..6cd9803a5f --- /dev/null +++ b/backend/tests/unit/test_desktop_task_ops.py @@ -0,0 +1,175 @@ +"""Tests for desktop task operations (rerank + dedup) handlers (Phase 2 — #5396).""" + +import asyncio +import sys +from 
unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.task_ops import ( + RankedTask, + RerankResult, + DedupGroup, + DedupResult, + RERANK_SYSTEM_PROMPT, + DEDUP_SYSTEM_PROMPT, + rerank_tasks, + dedup_tasks, +) +from models.message_event import RerankCompleteEvent, DedupCompleteEvent + + +# --- Rerank tests --- + + +class TestRankedTaskModel: + def test_ranked_task(self): + t = RankedTask(id="task1", new_position=1) + assert t.id == "task1" + assert t.new_position == 1 + + +class TestRerankResult: + def test_rerank_result(self): + r = RerankResult(updated_tasks=[RankedTask(id="t1", new_position=1)]) + assert len(r.updated_tasks) == 1 + + +class TestRerankCompleteEvent: + def test_event_structure(self): + event = RerankCompleteEvent(updated_tasks=[{"id": "t1", "new_position": 1}]) + data = event.to_json() + assert data["type"] == "rerank_complete" + assert len(data["updated_tasks"]) == 1 + + +class TestRerankTasks: + @patch('utils.desktop.task_ops.get_action_items') + @patch('utils.desktop.task_ops.llm_mini') + def test_rerank_returns_order(self, mock_llm, mock_get): + mock_get.return_value = [ + {'id': 't1', 'description': 'Low priority', 'priority': 'low'}, + {'id': 't2', 'description': 'Urgent fix', 'priority': 'high', 'due_at': '2026-03-08'}, + ] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=RerankResult( + updated_tasks=[ + RankedTask(id="t2", new_position=1), + RankedTask(id="t1", new_position=2), + ] + ) + ) + result = asyncio.get_event_loop().run_until_complete(rerank_tasks("uid1")) + assert 
result["updated_tasks"][0]["id"] == "t2" + assert result["updated_tasks"][0]["new_position"] == 1 + + @patch('utils.desktop.task_ops.get_action_items') + def test_rerank_empty_tasks(self, mock_get): + mock_get.return_value = [] + result = asyncio.get_event_loop().run_until_complete(rerank_tasks("uid1")) + assert result["updated_tasks"] == [] + + @patch('utils.desktop.task_ops.get_action_items') + def test_rerank_db_error(self, mock_get): + mock_get.side_effect = Exception("DB error") + result = asyncio.get_event_loop().run_until_complete(rerank_tasks("uid1")) + assert result["updated_tasks"] == [] + + +# --- Dedup tests --- + + +class TestDedupGroupModel: + def test_dedup_group(self): + g = DedupGroup(keep_id="t1", delete_ids=["t2", "t3"], reason="Same task") + assert g.keep_id == "t1" + assert len(g.delete_ids) == 2 + + +class TestDedupResult: + def test_dedup_with_groups(self): + r = DedupResult(groups=[DedupGroup(keep_id="t1", delete_ids=["t2"], reason="Duplicate")]) + assert len(r.groups) == 1 + + def test_dedup_no_groups(self): + r = DedupResult() + assert r.groups == [] + + +class TestDedupCompleteEvent: + def test_event_structure(self): + event = DedupCompleteEvent(deleted_ids=["t2", "t3"], reason="Duplicate tasks") + data = event.to_json() + assert data["type"] == "dedup_complete" + assert data["deleted_ids"] == ["t2", "t3"] + assert data["reason"] == "Duplicate tasks" + + +class TestDedupTasks: + @patch('utils.desktop.task_ops.get_action_items') + @patch('utils.desktop.task_ops.llm_mini') + def test_dedup_finds_duplicates(self, mock_llm, mock_get): + mock_get.return_value = [ + {'id': 't1', 'description': 'Call John'}, + {'id': 't2', 'description': 'Phone John'}, + {'id': 't3', 'description': 'Write report'}, + ] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=DedupResult( + groups=[DedupGroup(keep_id="t1", delete_ids=["t2"], reason="Same action: contact John")] + ) 
+ ) + result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1")) + assert result["deleted_ids"] == ["t2"] + assert "contact John" in result["reason"] + + @patch('utils.desktop.task_ops.get_action_items') + @patch('utils.desktop.task_ops.llm_mini') + def test_dedup_no_duplicates(self, mock_llm, mock_get): + mock_get.return_value = [ + {'id': 't1', 'description': 'Task A'}, + {'id': 't2', 'description': 'Task B'}, + ] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=DedupResult()) + result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1")) + assert result["deleted_ids"] == [] + assert result["reason"] == "No duplicates found" + + @patch('utils.desktop.task_ops.get_action_items') + def test_dedup_too_few_tasks(self, mock_get): + mock_get.return_value = [{'id': 't1', 'description': 'Only one'}] + result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1")) + assert result["deleted_ids"] == [] + assert "Not enough" in result["reason"] + + @patch('utils.desktop.task_ops.get_action_items') + def test_dedup_db_error(self, mock_get): + mock_get.side_effect = Exception("DB error") + result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1")) + assert result["deleted_ids"] == [] + assert "Failed" in result["reason"] + + +class TestRerankSystemPrompt: + def test_includes_rules(self): + assert "RULES" in RERANK_SYSTEM_PROMPT + assert "deadlines" in RERANK_SYSTEM_PROMPT + + +class TestDedupSystemPrompt: + def test_includes_rules(self): + assert "RULES" in DEDUP_SYSTEM_PROMPT + assert "duplicates" in DEDUP_SYSTEM_PROMPT.lower() From 7dd8fa5a7e252f60e1338abe7079e61afaa1ff80 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:26:02 +0100 Subject: [PATCH 019/163] Add all desktop handler tests to test.sh --- backend/test.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backend/test.sh b/backend/test.sh index 
954cd0dbf2..7ebd1feada 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -37,3 +37,9 @@ pytest tests/unit/test_desktop_updates.py -v pytest tests/unit/test_translation_optimization.py -v pytest tests/unit/test_conversation_source_unknown.py -v pytest tests/unit/test_desktop_focus.py -v +pytest tests/unit/test_desktop_tasks.py -v +pytest tests/unit/test_desktop_memories.py -v +pytest tests/unit/test_desktop_advice.py -v +pytest tests/unit/test_desktop_live_notes.py -v +pytest tests/unit/test_desktop_profile.py -v +pytest tests/unit/test_desktop_task_ops.py -v From cc4313875cf7b936eeb159bf7345b0cfe20bcd8c Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 08:57:19 +0100 Subject: [PATCH 020/163] Add BackendProactiveService for server-side proactive AI (#5396) WebSocket client that connects to /v4/listen with Bearer auth and sends screen_frame JSON messages. Routes focus_result responses back to callers via async continuations with frame_id correlation. Co-Authored-By: Claude Opus 4.6 --- .../Core/BackendProactiveService.swift | 358 ++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift new file mode 100644 index 0000000000..e9b69ffe3b --- /dev/null +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift @@ -0,0 +1,358 @@ +import Foundation + +/// WebSocket client for desktop proactive AI via /v4/listen. +/// Sends typed JSON messages (screen_frame, etc.) and routes typed responses +/// (focus_result, etc.) back to callers via async continuations. +/// +/// This is the Phase 2 replacement for direct GeminiClient calls — all LLM +/// processing happens server-side; the client just sends screenshots and +/// receives structured results. 
+class BackendProactiveService { + + // MARK: - Types + + enum ServiceError: LocalizedError { + case missingAPIURL + case authFailed(String) + case notConnected + case timeout + case serverError(String) + + var errorDescription: String? { + switch self { + case .missingAPIURL: return "OMI_API_URL not set" + case .authFailed(let reason): return "Auth failed: \(reason)" + case .notConnected: return "Backend WebSocket not connected" + case .timeout: return "Request timed out" + case .serverError(let msg): return "Server error: \(msg)" + } + } + } + + // MARK: - Properties + + private var webSocketTask: URLSessionWebSocketTask? + private var urlSession: URLSession? + private(set) var isConnected = false + private var shouldReconnect = false + private var reconnectAttempts = 0 + private let maxReconnectAttempts = 10 + private var reconnectTask: Task? + + // Keepalive + private var keepaliveTask: Task? + private let keepaliveInterval: TimeInterval = 30.0 + + // Pending request continuations keyed by frame_id + private var pendingFocusRequests: [String: CheckedContinuation] = [:] + private let requestLock = NSLock() + private let requestTimeout: TimeInterval = 30.0 + + // MARK: - Connection + + func connect() { + shouldReconnect = true + reconnectAttempts = 0 + startConnect() + } + + func disconnect() { + shouldReconnect = false + reconnectTask?.cancel() + reconnectTask = nil + keepaliveTask?.cancel() + keepaliveTask = nil + + isConnected = false + webSocketTask?.cancel(with: .normalClosure, reason: nil) + webSocketTask = nil + urlSession?.invalidateAndCancel() + urlSession = nil + + // Cancel all pending requests + cancelAllPending(error: ServiceError.notConnected) + + log("BackendProactiveService: Disconnected") + } + + // MARK: - Public API + + /// Send a screen_frame for focus analysis and wait for the focus_result response. 
+ func analyzeFocus( + imageBase64: String, + appName: String, + windowTitle: String + ) async throws -> ScreenAnalysis { + guard isConnected else { + throw ServiceError.notConnected + } + + let frameId = UUID().uuidString + + let message: [String: Any] = [ + "type": "screen_frame", + "frame_id": frameId, + "image_b64": imageBase64, + "app_name": appName, + "window_title": windowTitle, + "analyze": ["focus"], + ] + + let jsonData = try JSONSerialization.data(withJSONObject: message) + guard let jsonString = String(data: jsonData, encoding: .utf8) else { + throw ServiceError.serverError("Failed to encode message") + } + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingFocusRequests[frameId] = continuation + requestLock.unlock() + + webSocketTask?.send(.string(jsonString)) { [weak self] error in + if let error = error { + self?.requestLock.lock() + let cont = self?.pendingFocusRequests.removeValue(forKey: frameId) + self?.requestLock.unlock() + cont?.resume(throwing: error) + } + } + + // Timeout guard + Task { [weak self] in + try? await Task.sleep(nanoseconds: UInt64((self?.requestTimeout ?? 
30.0) * 1_000_000_000)) + self?.requestLock.lock() + let cont = self?.pendingFocusRequests.removeValue(forKey: frameId) + self?.requestLock.unlock() + cont?.resume(throwing: ServiceError.timeout) + } + } + } + + // MARK: - Connection Internals + + private func startConnect() { + guard let baseURL = Self.getBaseURL() else { + log("BackendProactiveService: OMI_API_URL not set") + return + } + + Task { + do { + let idToken = try await AuthService.shared.getIdToken() + await connectWithToken(baseURL: baseURL, token: idToken) + } catch { + logError("BackendProactiveService: Failed to get ID token", error: error) + handleDisconnection() + } + } + } + + private func connectWithToken(baseURL: String, token: String) async { + let wsURL = baseURL + .replacingOccurrences(of: "https://", with: "wss://") + .replacingOccurrences(of: "http://", with: "ws://") + let base = wsURL.hasSuffix("/") ? wsURL : wsURL + "/" + + // Connect to /v4/listen with source=desktop — same endpoint as audio, + // but we only send JSON messages (no audio data) + var components = URLComponents(string: "\(base)v4/listen")! 
+ components.queryItems = [ + URLQueryItem(name: "source", value: "desktop"), + URLQueryItem(name: "sample_rate", value: "16000"), + URLQueryItem(name: "codec", value: "pcm16"), + URLQueryItem(name: "channels", value: "1"), + URLQueryItem(name: "language", value: "en"), + ] + + guard let url = components.url else { + log("BackendProactiveService: Invalid URL") + return + } + + log("BackendProactiveService: Connecting to \(url.absoluteString)") + + var request = URLRequest(url: url) + request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization") + request.timeoutInterval = 30 + + let configuration = URLSessionConfiguration.default + configuration.timeoutIntervalForResource = 0 + urlSession = URLSession(configuration: configuration) + webSocketTask = urlSession?.webSocketTask(with: request) + webSocketTask?.resume() + + receiveMessage() + + // Confirm connection after short delay + DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in + guard let self = self, self.webSocketTask?.state == .running else { + self?.handleDisconnection() + return + } + self.isConnected = true + self.reconnectAttempts = 0 + self.startKeepalive() + log("BackendProactiveService: Connected") + } + } + + private func startKeepalive() { + keepaliveTask?.cancel() + keepaliveTask = Task { [weak self] in + while !Task.isCancelled { + try? await Task.sleep(nanoseconds: UInt64((self?.keepaliveInterval ?? 
30.0) * 1_000_000_000)) + guard !Task.isCancelled, let self = self, self.isConnected else { break } + self.sendKeepalive() + } + } + } + + private func sendKeepalive() { + guard isConnected, let ws = webSocketTask else { return } + ws.send(.string("{\"type\": \"KeepAlive\"}")) { [weak self] error in + if let error = error { + logError("BackendProactiveService: Keepalive error", error: error) + self?.handleDisconnection() + } + } + } + + private func handleDisconnection() { + guard isConnected || shouldReconnect else { return } + + isConnected = false + keepaliveTask?.cancel() + keepaliveTask = nil + webSocketTask?.cancel(with: .goingAway, reason: nil) + webSocketTask = nil + urlSession?.invalidateAndCancel() + urlSession = nil + + cancelAllPending(error: ServiceError.notConnected) + + if shouldReconnect && reconnectAttempts < maxReconnectAttempts { + reconnectAttempts += 1 + let delay = min(pow(2.0, Double(reconnectAttempts)), 32.0) + log("BackendProactiveService: Reconnecting in \(delay)s (attempt \(reconnectAttempts))") + + reconnectTask = Task { + try? 
await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) + guard !Task.isCancelled, self.shouldReconnect else { return } + self.startConnect() + } + } else if reconnectAttempts >= maxReconnectAttempts { + log("BackendProactiveService: Max reconnect attempts reached") + } + } + + // MARK: - Message Handling + + private func receiveMessage() { + webSocketTask?.receive { [weak self] result in + guard let self = self else { return } + + switch result { + case .success(let message): + self.handleMessage(message) + self.receiveMessage() + case .failure(let error): + guard self.isConnected else { return } + logError("BackendProactiveService: Receive error", error: error) + self.handleDisconnection() + } + } + } + + private func handleMessage(_ message: URLSessionWebSocketTask.Message) { + let text: String + switch message { + case .string(let s): + text = s + case .data(let data): + guard let s = String(data: data, encoding: .utf8) else { return } + text = s + @unknown default: + return + } + + // Skip heartbeat + if text == "ping" { return } + + guard let data = text.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + let type = json["type"] as? String else { + return + } + + switch type { + case "focus_result": + handleFocusResult(data) + default: + // Other event types (memory_created, etc.) — ignore for now + break + } + } + + private func handleFocusResult(_ data: Data) { + guard let response = try? JSONDecoder().decode(FocusResultResponse.self, from: data) else { + log("BackendProactiveService: Failed to decode focus_result") + return + } + + let analysis = ScreenAnalysis( + status: FocusStatus(rawValue: response.status) ?? 
.focused, + appOrSite: response.appOrSite, + description: response.description, + message: response.message + ) + + requestLock.lock() + let continuation = pendingFocusRequests.removeValue(forKey: response.frameId) + requestLock.unlock() + + continuation?.resume(returning: analysis) + } + + // MARK: - Helpers + + private func cancelAllPending(error: Error) { + requestLock.lock() + let pending = pendingFocusRequests + pendingFocusRequests.removeAll() + requestLock.unlock() + + for (_, continuation) in pending { + continuation.resume(throwing: error) + } + } + + private static func getBaseURL() -> String? { + if let cString = getenv("OMI_API_URL"), let url = String(validatingUTF8: cString), !url.isEmpty { + return url + } + if let envURL = ProcessInfo.processInfo.environment["OMI_API_URL"], !envURL.isEmpty { + return envURL + } + return nil + } +} + +// MARK: - Response Models + +private struct FocusResultResponse: Decodable { + let type: String + let frameId: String + let status: String + let appOrSite: String + let description: String + let message: String? + + enum CodingKeys: String, CodingKey { + case type + case frameId = "frame_id" + case status + case appOrSite = "app_or_site" + case description + case message + } +} From 056a352672bc3535b11e5930186fa774bd892b05 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 08:57:25 +0100 Subject: [PATCH 021/163] Wire FocusAssistant to BackendProactiveService instead of GeminiClient (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace direct Gemini API calls with backend WebSocket screen_frame messages. Context building (goals, tasks, memories, AI profile) moves server-side. Client becomes thin: encode JPEG→base64, send screen_frame, receive focus_result. 
Co-Authored-By: Claude Opus 4.6 --- .../Assistants/Focus/FocusAssistant.swift | 226 ++---------------- 1 file changed, 16 insertions(+), 210 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift index 33a355f56b..88eaddfc7c 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift @@ -17,7 +17,7 @@ actor FocusAssistant: ProactiveAssistant { // MARK: - Properties - private let geminiClient: GeminiClient + private let backendService: BackendProactiveService private let onAlert: (String) -> Void private let onStatusChange: ((FocusStatus) -> Void)? private let onRefocus: (() -> Void)? @@ -35,12 +35,6 @@ actor FocusAssistant: ProactiveAssistant { private let maxPendingTasks = 3 private var currentApp: String? - // MARK: - Context Cache - // Cached context from local DB (goals, tasks, memories) to enrich focus analysis - private var cachedContextString: String? - private var contextCacheTime: Date? - private let contextCacheDuration: TimeInterval = 120 // 2 minutes - // MARK: - Smart Analysis Filtering // Skip analysis when user is focused on the same context (app + window title) // Also skip during cooldown period after distraction (unless context changes) @@ -58,25 +52,16 @@ actor FocusAssistant: ProactiveAssistant { private var consecutiveErrorCount = 0 private var errorBackoffEndTime: Date? - /// Get the current system prompt from settings (accessed on MainActor for thread safety) - private var systemPrompt: String { - get async { - await MainActor.run { - FocusAssistantSettings.shared.analysisPrompt - } - } - } - // MARK: - Initialization init( - apiKey: String? = nil, + backendService: BackendProactiveService, onAlert: @escaping (String) -> Void = { _ in }, onStatusChange: ((FocusStatus) -> Void)? 
= nil, onRefocus: (() -> Void)? = nil, onDistraction: (() -> Void)? = nil - ) throws { - self.geminiClient = try GeminiClient(apiKey: apiKey) + ) { + self.backendService = backendService self.onAlert = onAlert self.onStatusChange = onStatusChange self.onRefocus = onRefocus @@ -299,8 +284,6 @@ actor FocusAssistant: ProactiveAssistant { analysisCooldownEndTime = nil consecutiveErrorCount = 0 errorBackoffEndTime = nil - cachedContextString = nil - contextCacheTime = nil // Clear cooldown in UI await MainActor.run { @@ -353,99 +336,26 @@ actor FocusAssistant: ProactiveAssistant { /// Run analysis on a screenshot with no side effects (no saving, no state updates, no notifications). /// Used by the test runner GUI and CLI. func testAnalyze(jpegData: Data, appName: String) async throws -> ScreenAnalysis? { - return try await analyzeScreenshot(jpegData: jpegData) + return try await analyzeScreenshot(jpegData: jpegData, appName: appName, windowTitle: nil) } /// Reset test history — call before starting a test run to get a clean slate. func resetTestHistory() { - testAnalysisHistory.removeAll() + // History is now tracked server-side; no-op on client } /// Run analysis with accumulating history across calls (simulates production behavior). - /// Each result is appended to a separate test history buffer so the model sees prior decisions. + /// History is tracked server-side per WebSocket session, so this is equivalent to testAnalyze. func testAnalyzeWithHistory(jpegData: Data, appName: String) async throws -> ScreenAnalysis? 
{ - let result = try await analyzeScreenshotWithHistory(jpegData: jpegData, history: testAnalysisHistory) - if let result = result { - testAnalysisHistory.append(result) - if testAnalysisHistory.count > maxHistorySize { - testAnalysisHistory.removeFirst() - } - } - return result - } - - /// Separate history buffer for test runs (doesn't pollute production history) - private var testAnalysisHistory: [ScreenAnalysis] = [] - - /// Variant of analyzeScreenshot that accepts an explicit history array - private func analyzeScreenshotWithHistory(jpegData: Data, history: [ScreenAnalysis]) async throws -> ScreenAnalysis? { - let context = await refreshContext() - - // Format provided history - var historyText = "" - if !history.isEmpty { - var lines = ["Recent activity (oldest to newest):"] - for (i, past) in history.enumerated() { - lines.append("\(i + 1). [\(past.status.rawValue)] \(past.appOrSite): \(past.description)") - if let message = past.message { - lines.append(" Message: \(message)") - } - } - historyText = lines.joined(separator: "\n") - } - - var promptParts: [String] = [] - if !context.isEmpty { - promptParts.append(context) - } - if !historyText.isEmpty { - promptParts.append(historyText) - } - promptParts.append("Now analyze this new screenshot:") - let prompt = promptParts.joined(separator: "\n\n") - - let currentSystemPrompt = await systemPrompt - - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "status": .init(type: "string", enum: ["focused", "distracted"], description: "Whether the user is focused or distracted"), - "app_or_site": .init(type: "string", enum: nil, description: "The app or website visible"), - "description": .init(type: "string", enum: nil, description: "Brief description of what's on screen"), - "message": .init(type: "string", enum: nil, description: "Coaching message") - ], - required: ["status", "app_or_site", "description"] - ) - - let responseText = try await 
geminiClient.sendRequest( - prompt: prompt, - imageData: jpegData, - systemPrompt: currentSystemPrompt, - responseSchema: responseSchema - ) - - return try JSONDecoder().decode(ScreenAnalysis.self, from: Data(responseText.utf8)) + return try await analyzeScreenshot(jpegData: jpegData, appName: appName, windowTitle: nil) } // MARK: - Analysis - private func formatHistory() -> String { - guard !analysisHistory.isEmpty else { return "" } - - var lines = ["Recent activity (oldest to newest):"] - for (i, past) in analysisHistory.enumerated() { - lines.append("\(i + 1). [\(past.status.rawValue)] \(past.appOrSite): \(past.description)") - if let message = past.message { - lines.append(" Message: \(message)") - } - } - return lines.joined(separator: "\n") - } - private func processFrame(_ frame: CapturedFrame) async { guard await isEnabled else { return } do { - guard let analysis = try await analyzeScreenshot(jpegData: frame.jpegData) else { + guard let analysis = try await analyzeScreenshot(jpegData: frame.jpegData, appName: frame.appName, windowTitle: frame.windowTitle) else { return } @@ -585,118 +495,14 @@ actor FocusAssistant: ProactiveAssistant { } } - /// Refresh context from local DB (goals, tasks, memories) with caching - private func refreshContext() async -> String { - // Return cached context if fresh - if let cached = cachedContextString, - let cacheTime = contextCacheTime, - Date().timeIntervalSince(cacheTime) < contextCacheDuration { - return cached - } - - var sections: [String] = [] - - // AI User Profile - do { - if let profile = await AIUserProfileService.shared.getLatestProfile() { - sections.append("USER PROFILE (who this user is):\n\(profile.profileText)") - } - } - - // Time context - let formatter = DateFormatter() - formatter.dateFormat = "EEEE, MMMM d, yyyy 'at' h:mm a" - sections.append("TIME CONTEXT:\n\(formatter.string(from: Date()))") - - // Active goals - do { - let goals = try await GoalStorage.shared.getLocalGoals(activeOnly: true) - if 
!goals.isEmpty { - var lines = ["ACTIVE GOALS:"] - for (i, goal) in goals.prefix(10).enumerated() { - let desc = goal.description.map { " - \($0)" } ?? "" - lines.append("\(i + 1). \(goal.title)\(desc)") - } - sections.append(lines.joined(separator: "\n")) - } - } catch { - logError("Focus: Failed to load goals for context", error: error) - } - - // Top tasks by importance - do { - let tasks = try await ActionItemStorage.shared.getTopRelevanceTasks(limit: 50) - if !tasks.isEmpty { - var lines = ["CURRENT TASKS (by importance):"] - for (i, task) in tasks.enumerated() { - let priority = task.priority ?? "medium" - lines.append("\(i + 1). [\(priority)] \(task.description)") - } - sections.append(lines.joined(separator: "\n")) - } - } catch { - logError("Focus: Failed to load tasks for context", error: error) - } - - // Recent memories - do { - let memories = try await MemoryStorage.shared.getLocalMemories(limit: 50, category: "core") - if !memories.isEmpty { - var lines = ["RECENT MEMORIES:"] - for (i, memory) in memories.enumerated() { - lines.append("\(i + 1). \(memory.content)") - } - sections.append(lines.joined(separator: "\n")) - } - } catch { - logError("Focus: Failed to load memories for context", error: error) - } - - let contextString = sections.joined(separator: "\n\n") - cachedContextString = contextString - contextCacheTime = Date() - return contextString - } - - private func analyzeScreenshot(jpegData: Data) async throws -> ScreenAnalysis? 
{ - // Refresh context from local DB - let context = await refreshContext() - - // Build prompt with context + history - let historyText = formatHistory() - var promptParts: [String] = [] - if !context.isEmpty { - promptParts.append(context) - } - if !historyText.isEmpty { - promptParts.append(historyText) - } - promptParts.append("Now analyze this new screenshot:") - let prompt = promptParts.joined(separator: "\n\n") - - // Get current system prompt from settings - let currentSystemPrompt = await systemPrompt - - // Build response schema - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "status": .init(type: "string", enum: ["focused", "distracted"], description: "Whether the user is focused or distracted"), - "app_or_site": .init(type: "string", enum: nil, description: "The app or website visible"), - "description": .init(type: "string", enum: nil, description: "Brief description of what's on screen"), - "message": .init(type: "string", enum: nil, description: "Coaching message") - ], - required: ["status", "app_or_site", "description"] + private func analyzeScreenshot(jpegData: Data, appName: String, windowTitle: String?) async throws -> ScreenAnalysis? { + let base64 = jpegData.base64EncodedString() + let result = try await backendService.analyzeFocus( + imageBase64: base64, + appName: appName, + windowTitle: windowTitle ?? "" ) - - let responseText = try await geminiClient.sendRequest( - prompt: prompt, - imageData: jpegData, - systemPrompt: currentSystemPrompt, - responseSchema: responseSchema - ) - - return try JSONDecoder().decode(ScreenAnalysis.self, from: Data(responseText.utf8)) + return result } // MARK: - Storage From a3a8dfacb678753e595a19c14a84b128e3edda96 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 08:57:31 +0100 Subject: [PATCH 022/163] Create BackendProactiveService in ProactiveAssistantsPlugin lifecycle (#5396) Start WS connection when monitoring starts, disconnect on stop. 
Pass service to FocusAssistant (shared for future assistant types). Co-Authored-By: Claude Opus 4.6 --- .../ProactiveAssistantsPlugin.swift | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift index e4d94557d2..4bc4cfe752 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift @@ -14,6 +14,7 @@ public class ProactiveAssistantsPlugin: NSObject { private var screenCaptureService: ScreenCaptureService? private var windowMonitor: WindowMonitor? + private var backendProactiveService: BackendProactiveService? private var focusAssistant: FocusAssistant? /// Public read-only accessor for memory diagnostics @@ -308,8 +309,14 @@ public class ProactiveAssistantsPlugin: NSObject { // Initialize services screenCaptureService = ScreenCaptureService() + // Start backend proactive AI WebSocket (Phase 2 — server-side LLM) + let proactiveService = BackendProactiveService() + proactiveService.connect() + backendProactiveService = proactiveService + do { - focusAssistant = try FocusAssistant( + focusAssistant = FocusAssistant( + backendService: proactiveService, onAlert: { [weak self] message in self?.sendEvent(type: "alert", data: ["message": message]) }, @@ -448,6 +455,8 @@ public class ProactiveAssistantsPlugin: NSObject { } } + backendProactiveService?.disconnect() + backendProactiveService = nil focusAssistant = nil taskAssistant = nil adviceAssistant = nil From 01e323d5f4e492c52b12b41fab2048cf60fd8652 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 08:57:36 +0100 Subject: [PATCH 023/163] Update FocusTestRunnerWindow for new FocusAssistant init signature (#5396) Co-Authored-By: Claude Opus 4.6 --- .../ProactiveAssistants/UI/FocusTestRunnerWindow.swift | 9 +++------ 1 file 
changed, 3 insertions(+), 6 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift b/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift index ca5c21034a..0ae19fd68e 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift @@ -637,12 +637,9 @@ enum FocusTestRunner { if let existing = coordAssistant as? FocusAssistant { focusAssistant = existing } else { - do { - focusAssistant = try FocusAssistant() - } catch { - log("FocusTestCLI: ERROR — Failed to create FocusAssistant: \(error)") - return - } + let service = BackendProactiveService() + service.connect() + focusAssistant = FocusAssistant(backendService: service) } // Get excluded apps From bc3abafdb85ac8ef9dcb72c3027d47c28c0a85c9 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:15:12 +0100 Subject: [PATCH 024/163] Add all 8 message types to BackendProactiveService (#5396) Vision handlers: analyzeFocus, extractTasks, extractMemories, generateAdvice (send screen_frame with analyze type, receive typed result via frame_id) Text handlers: generateLiveNote, requestProfile, rerankTasks, deduplicateTasks (send typed JSON message, receive result via single-slot continuation) Co-Authored-By: Claude Opus 4.6 --- .../Core/BackendProactiveService.swift | 376 ++++++++++++++---- 1 file changed, 305 insertions(+), 71 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift index e9b69ffe3b..1d473ba4c8 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift @@ -43,10 +43,21 @@ class BackendProactiveService { private var keepaliveTask: Task? 
private let keepaliveInterval: TimeInterval = 30.0 - // Pending request continuations keyed by frame_id + // Pending continuations keyed by frame_id (vision handlers) private var pendingFocusRequests: [String: CheckedContinuation] = [:] + private var pendingTasksRequests: [String: CheckedContinuation] = [:] + private var pendingMemoriesRequests: [String: CheckedContinuation] = [:] + private var pendingAdviceRequests: [String: CheckedContinuation] = [:] + + // Pending continuations for text-only handlers (one outstanding per type) + private var pendingLiveNote: CheckedContinuation? + private var pendingProfile: CheckedContinuation? + private var pendingRerank: CheckedContinuation? + private var pendingDedup: CheckedContinuation? + private let requestLock = NSLock() private let requestTimeout: TimeInterval = 30.0 + private let textRequestTimeout: TimeInterval = 60.0 // MARK: - Connection @@ -69,63 +80,191 @@ class BackendProactiveService { urlSession?.invalidateAndCancel() urlSession = nil - // Cancel all pending requests cancelAllPending(error: ServiceError.notConnected) - log("BackendProactiveService: Disconnected") } - // MARK: - Public API + // MARK: - Vision Handlers (screen_frame) /// Send a screen_frame for focus analysis and wait for the focus_result response. 
- func analyzeFocus( - imageBase64: String, - appName: String, - windowTitle: String - ) async throws -> ScreenAnalysis { - guard isConnected else { - throw ServiceError.notConnected + func analyzeFocus(imageBase64: String, appName: String, windowTitle: String) async throws -> ScreenAnalysis { + guard isConnected else { throw ServiceError.notConnected } + let frameId = UUID().uuidString + let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["focus"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingFocusRequests[frameId] = continuation + requestLock.unlock() + sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout, + remove: { self.pendingFocusRequests.removeValue(forKey: $0) }) + } + } + + /// Send a screen_frame for task extraction. + func extractTasks(imageBase64: String, appName: String, windowTitle: String) async throws -> TasksExtractedResult { + guard isConnected else { throw ServiceError.notConnected } + let frameId = UUID().uuidString + let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["tasks"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingTasksRequests[frameId] = continuation + requestLock.unlock() + sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout, + remove: { self.pendingTasksRequests.removeValue(forKey: $0) }) } + } + /// Send a screen_frame for memory extraction. 
+ func extractMemories(imageBase64: String, appName: String, windowTitle: String) async throws -> MemoriesExtractedResult { + guard isConnected else { throw ServiceError.notConnected } let frameId = UUID().uuidString + let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["memories"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingMemoriesRequests[frameId] = continuation + requestLock.unlock() + sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout, + remove: { self.pendingMemoriesRequests.removeValue(forKey: $0) }) + } + } + + /// Send a screen_frame for advice generation. + func generateAdvice(imageBase64: String, appName: String, windowTitle: String) async throws -> AdviceExtractedResult { + guard isConnected else { throw ServiceError.notConnected } + let frameId = UUID().uuidString + let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["advice"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingAdviceRequests[frameId] = continuation + requestLock.unlock() + sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout, + remove: { self.pendingAdviceRequests.removeValue(forKey: $0) }) + } + } + + // MARK: - Text-Only Handlers + + /// Send transcript text for live note generation. 
+ func generateLiveNote(text: String, sessionContext: String = "") async throws -> String { + guard isConnected else { throw ServiceError.notConnected } + let jsonString = try buildJSON(["type": "live_notes_text", "text": text, "session_context": sessionContext]) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingLiveNote = continuation + requestLock.unlock() + sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout, + remove: { let c = self.pendingLiveNote; self.pendingLiveNote = nil; return c }) + } + } + + /// Request profile generation (server fetches user data from Firestore). + func requestProfile() async throws -> String { + guard isConnected else { throw ServiceError.notConnected } + let jsonString = try buildJSON(["type": "profile_request"]) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingProfile = continuation + requestLock.unlock() + sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout, + remove: { let c = self.pendingProfile; self.pendingProfile = nil; return c }) + } + } + + /// Request task reranking (server fetches tasks from Firestore). + func rerankTasks() async throws -> RerankExtractedResult { + guard isConnected else { throw ServiceError.notConnected } + let jsonString = try buildJSON(["type": "task_rerank"]) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingRerank = continuation + requestLock.unlock() + sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout, + remove: { let c = self.pendingRerank; self.pendingRerank = nil; return c }) + } + } + + /// Request task deduplication (server fetches tasks from Firestore). 
+ func deduplicateTasks() async throws -> DedupExtractedResult { + guard isConnected else { throw ServiceError.notConnected } + let jsonString = try buildJSON(["type": "task_dedup"]) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingDedup = continuation + requestLock.unlock() + sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout, + remove: { let c = self.pendingDedup; self.pendingDedup = nil; return c }) + } + } + + // MARK: - Send Helpers - let message: [String: Any] = [ + private func buildScreenFrameJSON(frameId: String, analyzeTypes: [String], imageBase64: String, appName: String, windowTitle: String) throws -> String { + try buildJSON([ "type": "screen_frame", "frame_id": frameId, "image_b64": imageBase64, "app_name": appName, "window_title": windowTitle, - "analyze": ["focus"], - ] + "analyze": analyzeTypes, + ]) + } - let jsonData = try JSONSerialization.data(withJSONObject: message) - guard let jsonString = String(data: jsonData, encoding: .utf8) else { + private func buildJSON(_ dict: [String: Any]) throws -> String { + let data = try JSONSerialization.data(withJSONObject: dict) + guard let str = String(data: data, encoding: .utf8) else { throw ServiceError.serverError("Failed to encode message") } + return str + } - return try await withCheckedThrowingContinuation { continuation in - requestLock.lock() - pendingFocusRequests[frameId] = continuation - requestLock.unlock() - - webSocketTask?.send(.string(jsonString)) { [weak self] error in - if let error = error { - self?.requestLock.lock() - let cont = self?.pendingFocusRequests.removeValue(forKey: frameId) - self?.requestLock.unlock() - cont?.resume(throwing: error) - } + /// Send JSON and set up timeout for frame_id-keyed continuations. + private func sendAndTimeout(jsonString: String, frameId: String, timeout: TimeInterval, + remove: @escaping (String) -> CheckedContinuation?) 
{ + webSocketTask?.send(.string(jsonString)) { [weak self] error in + if let error = error { + self?.requestLock.lock() + let cont = remove(frameId) + self?.requestLock.unlock() + cont?.resume(throwing: error) } + } + + Task { [weak self] in + try? await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + self?.requestLock.lock() + let cont = remove(frameId) + self?.requestLock.unlock() + cont?.resume(throwing: ServiceError.timeout) + } + } - // Timeout guard - Task { [weak self] in - try? await Task.sleep(nanoseconds: UInt64((self?.requestTimeout ?? 30.0) * 1_000_000_000)) + /// Send JSON and set up timeout for single-slot continuations. + private func sendAndTimeoutSingle(jsonString: String, timeout: TimeInterval, + remove: @escaping () -> CheckedContinuation?) { + webSocketTask?.send(.string(jsonString)) { [weak self] error in + if let error = error { self?.requestLock.lock() - let cont = self?.pendingFocusRequests.removeValue(forKey: frameId) + let cont = remove() self?.requestLock.unlock() - cont?.resume(throwing: ServiceError.timeout) + cont?.resume(throwing: error) } } + + Task { [weak self] in + try? await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + self?.requestLock.lock() + let cont = remove() + self?.requestLock.unlock() + cont?.resume(throwing: ServiceError.timeout) + } } // MARK: - Connection Internals @@ -153,8 +292,6 @@ class BackendProactiveService { .replacingOccurrences(of: "http://", with: "ws://") let base = wsURL.hasSuffix("/") ? wsURL : wsURL + "/" - // Connect to /v4/listen with source=desktop — same endpoint as audio, - // but we only send JSON messages (no audio data) var components = URLComponents(string: "\(base)v4/listen")! 
components.queryItems = [ URLQueryItem(name: "source", value: "desktop"), @@ -183,7 +320,6 @@ class BackendProactiveService { receiveMessage() - // Confirm connection after short delay DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in guard let self = self, self.webSocketTask?.state == .running else { self?.handleDisconnection() @@ -275,7 +411,6 @@ class BackendProactiveService { return } - // Skip heartbeat if text == "ping" { return } guard let data = text.data(using: .utf8), @@ -286,44 +421,132 @@ class BackendProactiveService { switch type { case "focus_result": - handleFocusResult(data) + handleFocusResult(json) + case "tasks_extracted": + handleTasksExtracted(json) + case "memories_extracted": + handleMemoriesExtracted(json) + case "advice_extracted": + handleAdviceExtracted(json) + case "live_note": + handleLiveNote(json) + case "profile_updated": + handleProfileUpdated(json) + case "rerank_complete": + handleRerankComplete(json) + case "dedup_complete": + handleDedupComplete(json) default: - // Other event types (memory_created, etc.) — ignore for now break } } - private func handleFocusResult(_ data: Data) { - guard let response = try? JSONDecoder().decode(FocusResultResponse.self, from: data) else { - log("BackendProactiveService: Failed to decode focus_result") - return - } + // MARK: - Response Handlers + private func handleFocusResult(_ json: [String: Any]) { + guard let frameId = json["frame_id"] as? String else { return } let analysis = ScreenAnalysis( - status: FocusStatus(rawValue: response.status) ?? .focused, - appOrSite: response.appOrSite, - description: response.description, - message: response.message + status: FocusStatus(rawValue: json["status"] as? String ?? "focused") ?? .focused, + appOrSite: json["app_or_site"] as? String ?? "", + description: json["description"] as? String ?? "", + message: json["message"] as? 
String ) + requestLock.lock() + let cont = pendingFocusRequests.removeValue(forKey: frameId) + requestLock.unlock() + cont?.resume(returning: analysis) + } + + private func handleTasksExtracted(_ json: [String: Any]) { + guard let frameId = json["frame_id"] as? String else { return } + let tasks = (json["tasks"] as? [[String: Any]]) ?? [] + let result = TasksExtractedResult(frameId: frameId, tasks: tasks) + requestLock.lock() + let cont = pendingTasksRequests.removeValue(forKey: frameId) + requestLock.unlock() + cont?.resume(returning: result) + } + + private func handleMemoriesExtracted(_ json: [String: Any]) { + guard let frameId = json["frame_id"] as? String else { return } + let memories = (json["memories"] as? [[String: Any]]) ?? [] + let result = MemoriesExtractedResult(frameId: frameId, memories: memories) + requestLock.lock() + let cont = pendingMemoriesRequests.removeValue(forKey: frameId) + requestLock.unlock() + cont?.resume(returning: result) + } + + private func handleAdviceExtracted(_ json: [String: Any]) { + guard let frameId = json["frame_id"] as? String else { return } + let result = AdviceExtractedResult(frameId: frameId, advice: json["advice"]) + requestLock.lock() + let cont = pendingAdviceRequests.removeValue(forKey: frameId) + requestLock.unlock() + cont?.resume(returning: result) + } + + private func handleLiveNote(_ json: [String: Any]) { + let text = json["text"] as? String ?? "" + requestLock.lock() + let cont = pendingLiveNote + pendingLiveNote = nil + requestLock.unlock() + cont?.resume(returning: text) + } + + private func handleProfileUpdated(_ json: [String: Any]) { + let profileText = json["profile_text"] as? String ?? "" + requestLock.lock() + let cont = pendingProfile + pendingProfile = nil + requestLock.unlock() + cont?.resume(returning: profileText) + } + private func handleRerankComplete(_ json: [String: Any]) { + let updatedTasks = (json["updated_tasks"] as? [[String: Any]]) ?? 
[] + let result = RerankExtractedResult(updatedTasks: updatedTasks) requestLock.lock() - let continuation = pendingFocusRequests.removeValue(forKey: response.frameId) + let cont = pendingRerank + pendingRerank = nil requestLock.unlock() + cont?.resume(returning: result) + } - continuation?.resume(returning: analysis) + private func handleDedupComplete(_ json: [String: Any]) { + let deletedIds = (json["deleted_ids"] as? [String]) ?? [] + let reason = json["reason"] as? String ?? "" + let result = DedupExtractedResult(deletedIds: deletedIds, reason: reason) + requestLock.lock() + let cont = pendingDedup + pendingDedup = nil + requestLock.unlock() + cont?.resume(returning: result) } // MARK: - Helpers private func cancelAllPending(error: Error) { requestLock.lock() - let pending = pendingFocusRequests - pendingFocusRequests.removeAll() + let focus = pendingFocusRequests; pendingFocusRequests.removeAll() + let tasks = pendingTasksRequests; pendingTasksRequests.removeAll() + let memories = pendingMemoriesRequests; pendingMemoriesRequests.removeAll() + let advice = pendingAdviceRequests; pendingAdviceRequests.removeAll() + let liveNote = pendingLiveNote; pendingLiveNote = nil + let profile = pendingProfile; pendingProfile = nil + let rerank = pendingRerank; pendingRerank = nil + let dedup = pendingDedup; pendingDedup = nil requestLock.unlock() - for (_, continuation) in pending { - continuation.resume(throwing: error) - } + for (_, c) in focus { c.resume(throwing: error) } + for (_, c) in tasks { c.resume(throwing: error) } + for (_, c) in memories { c.resume(throwing: error) } + for (_, c) in advice { c.resume(throwing: error) } + liveNote?.resume(throwing: error) + profile?.resume(throwing: error) + rerank?.resume(throwing: error) + dedup?.resume(throwing: error) } private static func getBaseURL() -> String? 
{ @@ -337,22 +560,33 @@ class BackendProactiveService { } } -// MARK: - Response Models +// MARK: - Result Types -private struct FocusResultResponse: Decodable { - let type: String +/// Tasks extracted from a screen_frame analysis. +struct TasksExtractedResult { let frameId: String - let status: String - let appOrSite: String - let description: String - let message: String? - - enum CodingKeys: String, CodingKey { - case type - case frameId = "frame_id" - case status - case appOrSite = "app_or_site" - case description - case message - } + let tasks: [[String: Any]] // Raw task dicts from backend +} + +/// Memories extracted from a screen_frame analysis. +struct MemoriesExtractedResult { + let frameId: String + let memories: [[String: Any]] // Raw memory dicts from backend +} + +/// Advice extracted from a screen_frame analysis. +struct AdviceExtractedResult { + let frameId: String + let advice: Any? // Raw advice from backend (dict or null) +} + +/// Task reranking result. +struct RerankExtractedResult { + let updatedTasks: [[String: Any]] // [{id, new_position}, ...] +} + +/// Task deduplication result. +struct DedupExtractedResult { + let deletedIds: [String] + let reason: String } From d1fcb80ed9ff2cc2a2d4b897da30f53c02b6c830 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:01 +0100 Subject: [PATCH 025/163] Wire TaskAssistant thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace GeminiClient tool-calling loop with backendService.extractTasks(). Remove extractTaskSingleStage, refreshContext, vector/keyword search, validateTaskTitle — all LLM logic now server-side. -550 lines. 
Co-Authored-By: Claude Opus 4.6 --- .../TaskExtraction/TaskAssistant.swift | 698 ++---------------- 1 file changed, 82 insertions(+), 616 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift index 8df5b2fcc7..512c464777 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift @@ -1,7 +1,7 @@ import Foundation -/// Task extraction assistant that identifies tasks and action items from screen content -/// Uses single-stage Gemini tool calling with vector + FTS5 search for deduplication +/// Task extraction assistant that identifies tasks and action items from screen content. +/// Phase 2: sends screenshots to backend via WebSocket, receives structured task results. actor TaskAssistant: ProactiveAssistant { // MARK: - ProactiveAssistant Protocol @@ -18,7 +18,7 @@ actor TaskAssistant: ProactiveAssistant { // MARK: - Properties - private let geminiClient: GeminiClient + private let backendService: BackendProactiveService private var isRunning = false private var previousTasks: [ExtractedTask] = [] // Last 10 extracted tasks for context private let maxPreviousTasks = 10 @@ -41,11 +41,6 @@ actor TaskAssistant: ProactiveAssistant { /// Timestamp of last context switch yield, for throttling rapid switches private var lastContextSwitchYieldTime: Date = .distantPast - // Cached goals (refreshed every 5 minutes) - private var cachedGoals: [Goal] = [] - private var lastGoalsRefresh: Date = .distantPast - private let goalsRefreshInterval: TimeInterval = 300 - // MARK: - Due Date Helpers /// Parse an inferred deadline string into a Date, or default to end of today. 
@@ -114,15 +109,6 @@ actor TaskAssistant: ProactiveAssistant { return calendar.date(bySettingHour: 23, minute: 59, second: 0, of: startOfDay) ?? startOfDay } - /// Get the current system prompt from settings (accessed on MainActor for thread safety) - private var systemPrompt: String { - get async { - await MainActor.run { - TaskAssistantSettings.shared.analysisPrompt - } - } - } - /// Get the extraction interval from settings private var extractionInterval: TimeInterval { get async { @@ -143,9 +129,8 @@ actor TaskAssistant: ProactiveAssistant { // MARK: - Initialization - init(apiKey: String? = nil) throws { - // Use Gemini 3 Pro for better task extraction quality - self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest") + init(backendService: BackendProactiveService) { + self.backendService = backendService let (stream, continuation) = AsyncStream.makeStream(of: TriggerEvent.self, bufferingPolicy: .bufferingNewest(1)) self.triggerStream = stream @@ -221,11 +206,17 @@ actor TaskAssistant: ProactiveAssistant { // MARK: - Test Analysis (for test runner) - /// Run the extraction pipeline on arbitrary JPEG data without side effects (no saving, no events). - /// Used by the test runner to replay past screenshots. - /// Returns (result, searchCount) where searchCount is the number of search tool calls made. + /// Run extraction via backend for test runner. Returns (result, 0) for compatibility. 
func testAnalyze(jpegData: Data, appName: String) async throws -> (TaskExtractionResult?, Int) { - return try await extractTaskSingleStage(from: jpegData, appName: appName) + let base64 = autoreleasepool { jpegData.base64EncodedString() } + let backendResult = try await backendService.extractTasks( + imageBase64: base64, appName: appName, windowTitle: "" + ) + if backendResult.tasks.isEmpty { + return (TaskExtractionResult(hasNewTask: false, task: nil, contextSummary: "Analyzed \(appName)", currentActivity: ""), 0) + } + let result = parseBackendTask(backendResult.tasks[0], appName: appName) + return (result, 0) } // MARK: - ProactiveAssistant Protocol Methods @@ -579,7 +570,7 @@ actor TaskAssistant: ProactiveAssistant { latestFrame = nil } - // MARK: - Single-Stage Analysis with Tool Calling + // MARK: - Backend Analysis (Phase 2 thin client) private func processFrame(_ frame: CapturedFrame) async { let enabled = await isEnabled @@ -590,613 +581,88 @@ actor TaskAssistant: ProactiveAssistant { log("Task: Analyzing frame from \(frame.appName)...") do { - let (result, searchCount) = try await extractTaskSingleStage(from: frame.jpegData, appName: frame.appName) - guard let result = result else { - log("Task: Analysis returned no result") - return - } - - log("Task: Analysis complete - hasNewTask: \(result.hasNewTask), context: \(result.contextSummary), searches: \(searchCount)") + let base64 = autoreleasepool { frame.jpegData.base64EncodedString() } + let backendResult = try await backendService.extractTasks( + imageBase64: base64, + appName: frame.appName, + windowTitle: frame.windowTitle ?? 
"" + ) - await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, appName: frame.appName, windowTitle: frame.windowTitle) { type, data in + let sendEvent: (String, [String: Any]) -> Void = { type, data in Task { @MainActor in AssistantCoordinator.shared.sendEvent(type: type, data: data) } } - } catch { - logError("Task extraction error", error: error) - } - } - - /// Loop-based extraction: image analysis + iterative tool calling for search + terminal tool for decision - /// Returns (result, searchCount) where searchCount is the number of search tool calls made. - private func extractTaskSingleStage(from jpegData: Data, appName: String) async throws -> (TaskExtractionResult?, Int) { - // 1. Gather context - let context = await refreshContext() - - // 2. Build prompt with injected context - let dateFormatter = DateFormatter() - dateFormatter.dateFormat = "yyyy-MM-dd (EEEE)" - let todayStr = dateFormatter.string(from: Date()) - - var prompt = "Screenshot from \(appName). Today is \(todayStr). Analyze this screenshot for any unaddressed request directed at the user.\n\n" - - // For messaging apps, add an extra reminder about conversation analysis - let messagingApps: Set = ["Telegram", "WhatsApp", "\u{200E}WhatsApp", "Messages", "Slack", "Discord"] - if messagingApps.contains(appName) { - prompt += """ - REMINDER — THIS IS A MESSAGING APP: - - If this screenshot shows a chat sidebar/conversation list rather than an open conversation, SKIP entirely. - - If it shows an open conversation, read the FULL conversation flow between the user and the other person. - - LEFT-SIDE messages = from the other person. RIGHT-SIDE/colored = from the user. - - PRIORITY: Look for where the user AGREED or COMMITTED to doing something the other person asked. - Example: Other person says "Can you send me the report?" → User replies "Sure, will do" → Extract task: "Send [person] the report" - - ALSO: Look for incoming requests the user hasn't responded to yet. 
- - The task title should describe what was asked for, naming the other person in the conversation. - - """ - } - - // Inject AI user profile for context - if let profile = await AIUserProfileService.shared.getLatestProfile() { - prompt += "USER PROFILE (who this user is — use for context, not as a task source):\n" - prompt += profile.profileText + "\n\n" - } - - if !context.activeTasks.isEmpty { - // Get score range for context - let scoreRange = try? await ActionItemStorage.shared.getRelevanceScoreRange() - let rangeStr = scoreRange.map { "Score range: \($0.min)–\($0.max). " } ?? "" - - prompt += "ACTIVE TASKS (user is already tracking these — each has a relevance_score where 1 = most important, higher numbers = less important):\n" - prompt += "\(rangeStr)Use these scores to place any new task appropriately.\n" - for (i, task) in context.activeTasks.enumerated() { - let pri = task.priority.map { " [\($0)]" } ?? "" - let score = task.relevanceScore.map { " [score:\($0)]" } ?? "" - prompt += "\(i + 1).\(score) \(task.description)\(pri)\n" - } - prompt += "\n" - } - - if !context.completedTasks.isEmpty { - prompt += "RECENTLY COMPLETED TASKS (user engaged with these — this is the kind of task the user finds valuable. Extract similar types of tasks, just not exact duplicates of these specific ones):\n" - for (i, task) in context.completedTasks.enumerated() { - prompt += "\(i + 1). \(task.description)\n" - } - prompt += "\n" - } - - if !context.deletedTasks.isEmpty { - prompt += "USER-DELETED TASKS (user explicitly rejected these — do not re-extract similar):\n" - for (i, task) in context.deletedTasks.enumerated() { - prompt += "\(i + 1). \(task.description)\n" - } - prompt += "\n" - } - - if !context.goals.isEmpty { - prompt += "ACTIVE GOALS:\n" - for (i, goal) in context.goals.enumerated() { - prompt += "\(i + 1). 
\(goal.title)" - if let desc = goal.description { - prompt += " — \(desc)" - } - prompt += "\n" - } - prompt += "\n" - } - prompt += """ - Analyze this screenshot. If you see a potential request, search for duplicates first. - If there is clearly no request on screen (~90% of screenshots), call no_task_found immediately. - """ - - // 3. Define 5 tools - let tools = GeminiTool(functionDeclarations: [ - GeminiTool.FunctionDeclaration( - name: "search_similar", - description: "Search for semantically similar existing tasks using vector similarity. Call this when you see a potential request and want to check for duplicates.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init(type: "string", description: "A concise description of the potential task to search for") - ], - required: ["query"] + if backendResult.tasks.isEmpty { + let result = TaskExtractionResult( + hasNewTask: false, task: nil, + contextSummary: "Analyzed \(frame.appName)", + currentActivity: "" ) - ), - GeminiTool.FunctionDeclaration( - name: "search_keywords", - description: "Search for existing tasks matching specific keywords. Use this for precise keyword-based matching complementing vector search.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init(type: "string", description: "Keywords to search for in existing tasks") - ], - required: ["query"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "no_task_found", - description: "Call this when there is no actionable request on screen. This is the most common outcome (~90% of screenshots). 
Use for: code editors, terminals, settings, media players, dashboards, or any screen without a direct request from another person or AI.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "context_summary": .init(type: "string", description: "Brief summary of what the user is looking at"), - "current_activity": .init(type: "string", description: "What the user is actively doing") - ], - required: ["context_summary", "current_activity"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "extract_task", - description: "Extract a new task that is not already tracked. Call ONLY after searching for duplicates. All fields are required.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "title": .init(type: "string", description: "Verb-first task title, 6–15 words. MUST name a specific person/project/artifact and a concrete action. If you can't write 6+ specific words, call no_task_found instead."), - "description": .init(type: "string", description: "Additional context about the task. Empty string if none."), - "priority": .init(type: "string", description: "Task priority", enumValues: ["high", "medium", "low"]), - "tags": .init(type: "array", description: "1-3 relevant tags", items: .init(type: "string")), - "source_app": .init(type: "string", description: "App where the task was found"), - "inferred_deadline": .init(type: "string", description: "Deadline in yyyy-MM-dd format (e.g. '2025-10-04'). Resolve relative references like 'Thursday' or 'next week' to an actual date. 
Empty string if no deadline."), - "confidence": .init(type: "number", description: "Confidence score 0.0-1.0"), - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "What the user is actively doing"), - "source_category": .init(type: "string", description: "Where the task originated", enumValues: ["direct_request", "self_generated", "calendar_driven", "reactive", "external_system", "other"]), - "source_subcategory": .init(type: "string", description: "Specific origin within category", enumValues: ["message", "meeting", "mention", "commitment", "idea", "reminder", "goal_subtask", "event_prep", "recurring", "deadline", "error", "notification", "observation", "project_tool", "alert", "documentation", "other"]), - "relevance_score": .init(type: "integer", description: "Where this task ranks relative to existing tasks. Look at the relevance_score values of existing active tasks and assign a score that places this task appropriately. 1 = most important/urgent, higher numbers = less important. Must be a positive integer.") - ], - required: ["title", "description", "priority", "tags", "source_app", "inferred_deadline", "confidence", "context_summary", "current_activity", "source_category", "source_subcategory", "relevance_score"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "reject_task", - description: "Reject task extraction — the potential task is a duplicate, already completed, or was previously rejected by the user. Call after searching confirms this.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "reason": .init(type: "string", description: "Why this task was rejected (e.g. 
'duplicate of existing active task', 'already completed')"), - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "What the user is actively doing") - ], - required: ["reason", "context_summary", "current_activity"] - ) - ) - ]) - - // 4. Get system prompt - let currentSystemPrompt = await systemPrompt - - // 5. Build initial contents - // Wrap base64 encoding in autoreleasepool — Swift concurrency doesn't - // drain autorelease pools, causing bridged NSString objects to accumulate. - var contents: [GeminiImageToolRequest.Content] = autoreleasepool { - let base64Data = jpegData.base64EncodedString() - return [ - GeminiImageToolRequest.Content( - role: "user", - parts: [ - GeminiImageToolRequest.Part(text: prompt), - GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64Data) - ] - ) - ] - } - - // 6. Tool-calling loop (max 5 iterations) - var searchCount = 0 - - for iteration in 0..<5 { - let result = try await geminiClient.sendImageToolLoop( - contents: contents, - systemPrompt: currentSystemPrompt, - tools: [tools], - forceToolCall: iteration == 0 - ) - - guard let toolCall = result.toolCalls.first else { - log("Task: No tool call received on iteration \(iteration), breaking") - break - } - - switch toolCall.name { - case "no_task_found": - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "No task on screen" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "Unknown" - log("Task: no_task_found — \(contextSummary)") - return (TaskExtractionResult( - hasNewTask: false, - task: nil, - contextSummary: contextSummary, - currentActivity: currentActivity - ), searchCount) - - case "extract_task": - let title = toolCall.arguments["title"] as? String ?? "" - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? 
"" - - // --- Hard validation: reject vague titles and ask the model to retry --- - let titleWords = title.split(separator: " ").count - let validationError = Self.validateTaskTitle(title, wordCount: titleWords) - if let error = validationError { - log("Task: Title rejected (\(error)): \"\(title)\"") - - // Feed rejection back into the loop so the model can retry with more specifics - contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: toolCall.arguments as? [String: String] ?? ["title": title]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: """ - REJECTED: \(error). \ - Your title was: "\(title)" (\(titleWords) words). \ - Either rewrite with 6+ words including a specific person/project name and concrete action, \ - or call no_task_found if you cannot be more specific. - """) - ))] - )) - continue - } - - let description = toolCall.arguments["description"] as? String - let priorityStr = toolCall.arguments["priority"] as? String ?? "medium" - let priority = TaskPriority(rawValue: priorityStr) ?? .medium - let tags: [String] - if let tagArray = toolCall.arguments["tags"] as? [Any] { - tags = tagArray.compactMap { $0 as? String } - } else { - tags = [] - } - let sourceApp = toolCall.arguments["source_app"] as? String ?? appName - let inferredDeadline = toolCall.arguments["inferred_deadline"] as? String - let confidence: Double - if let confValue = toolCall.arguments["confidence"] as? Double { - confidence = confValue - } else if let confInt = toolCall.arguments["confidence"] as? Int { - confidence = Double(confInt) - } else { - confidence = 0.5 - } - let sourceCategory = toolCall.arguments["source_category"] as? String ?? 
"other" - let sourceSubcategory = toolCall.arguments["source_subcategory"] as? String ?? "other" - let relevanceScore: Int? - if let scoreValue = toolCall.arguments["relevance_score"] as? Int { - relevanceScore = scoreValue - } else if let scoreDouble = toolCall.arguments["relevance_score"] as? Double { - relevanceScore = Int(scoreDouble) - } else { - relevanceScore = nil - } - - let task = ExtractedTask( - title: title, - description: description?.isEmpty == true ? nil : description, - priority: priority, - sourceApp: sourceApp, - inferredDeadline: inferredDeadline?.isEmpty == true ? nil : inferredDeadline, - confidence: confidence, - tags: tags, - sourceCategory: sourceCategory, - sourceSubcategory: sourceSubcategory, - relevanceScore: relevanceScore - ) - - log("Task: extract_task — \"\(title)\" (confidence: \(confidence), priority: \(priorityStr), score: \(relevanceScore.map { String($0) } ?? "nil"))") - return (TaskExtractionResult( - hasNewTask: true, - task: task, - contextSummary: contextSummary, - currentActivity: currentActivity - ), searchCount) - - case "reject_task": - let reason = toolCall.arguments["reason"] as? String ?? "Unknown reason" - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "" - log("Task: reject_task — \(reason)") - return (TaskExtractionResult( - hasNewTask: false, - task: nil, - contextSummary: contextSummary, - currentActivity: currentActivity - ), searchCount) - - case "search_similar": - let query = toolCall.arguments["query"] as? String ?? "" - searchCount += 1 - log("Task: search_similar query: \"\(query)\"") - let searchResults = await executeVectorSearch(query: query) - log("Task: Vector search returned \(searchResults.count) results") - - let searchResultsJson: String - if let data = try? 
JSONEncoder().encode(searchResults), - let json = String(data: data, encoding: .utf8) { - searchResultsJson = json - } else { - searchResultsJson = "[]" - } - - // Append model's tool call + function response to contents - contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: ["query": query]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: searchResultsJson) - ))] - )) - continue - - case "search_keywords": - let query = toolCall.arguments["query"] as? String ?? "" - searchCount += 1 - log("Task: search_keywords query: \"\(query)\"") - let searchResults = await executeKeywordSearch(query: query) - log("Task: Keyword search returned \(searchResults.count) results") - - let searchResultsJson: String - if let data = try? JSONEncoder().encode(searchResults), - let json = String(data: data, encoding: .utf8) { - searchResultsJson = json - } else { - searchResultsJson = "[]" - } - - // Append model's tool call + function response to contents - contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: ["query": query]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: searchResultsJson) - ))] - )) - continue - - default: - log("Task: Unknown tool call: \(toolCall.name), breaking") - break - } - } - - log("Task: Completed in \(searchCount) searches (loop exhausted without terminal tool)") - return (nil, searchCount) - } - - // MARK: - Title Validation - - /// Validates a task title for minimum specificity. 
Returns an error message if invalid, nil if OK. - private static func validateTaskTitle(_ title: String, wordCount: Int) -> String? { - let trimmed = title.trimmingCharacters(in: .whitespacesAndNewlines) - - // Must not be empty - if trimmed.isEmpty { - return "Title is empty" - } - - // Minimum 6 words - if wordCount < 6 { - return "Title too short (\(wordCount) words, minimum 6)" - } - - // Reject titles that are purely generic verbs with no specifics - let genericPatterns: [String] = [ - "investigate", "check logs", "clean up", "look into", - "look through", "update to", "fix the", "review the", - "check the", "modify the", "track the" - ] - let lowered = trimmed.lowercased() - for pattern in genericPatterns { - // If the entire title is just a generic pattern (possibly with 1-2 filler words), reject - if lowered == pattern || (wordCount <= 4 && lowered.hasPrefix(pattern)) { - return "Title too generic (matches vague pattern '\(pattern)')" + log("Task: Analysis returned no tasks") + await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, appName: frame.appName, windowTitle: frame.windowTitle, sendEvent: sendEvent) + return } - } - // Must contain at least one capitalized proper noun (person, project, app name) - // Heuristic: after the first word (verb), there should be at least one word starting with uppercase - let words = trimmed.split(separator: " ") - let hasProperNoun = words.dropFirst().contains { word in - guard let first = word.first else { return false } - return first.isUppercase - } - if !hasProperNoun { - return "Title lacks a specific name (person, project, or app) — no proper nouns found after the verb" - } - - return nil - } - - // MARK: - Context & Search + log("Task: Analysis complete - \(backendResult.tasks.count) task(s)") - /// Refresh context from local SQLite + cached goals - private func refreshContext() async -> TaskExtractionContext { - var topRelevanceTasks: [(id: Int64, description: String, priority: String?, 
relevanceScore: Int?)] = [] - var recentTasks: [(id: Int64, description: String, priority: String?, relevanceScore: Int?)] = [] - var completedTasks: [(id: Int64, description: String)] = [] - var deletedTasks: [(id: Int64, description: String)] = [] - - // Query both action_items (promoted + manual) and staged_tasks for full context - do { - topRelevanceTasks = try await ActionItemStorage.shared.getTopRelevanceTasks(limit: 30) - } catch { - logError("Task: Failed to load top relevance tasks", error: error) - } - - do { - recentTasks = try await ActionItemStorage.shared.getRecentActiveTasks(limit: 30) - } catch { - logError("Task: Failed to load recent tasks", error: error) - } - - // Also include staged tasks for dedup context - do { - let stagedTasks = try await StagedTaskStorage.shared.getAllStagedTasks(limit: 30) - let stagedAsTuples = stagedTasks.map { task in - (id: Int64(0), description: task.description, priority: task.priority, relevanceScore: task.relevanceScore) + for taskDict in backendResult.tasks { + let result = parseBackendTask(taskDict, appName: frame.appName) + await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, appName: frame.appName, windowTitle: frame.windowTitle, sendEvent: sendEvent) } - recentTasks.append(contentsOf: stagedAsTuples) - } catch { - logError("Task: Failed to load staged tasks for context", error: error) - } - - // Merge: top relevance tasks first, then recent ones not already included - let topIds = Set(topRelevanceTasks.map { $0.id }) - let activeTasks = topRelevanceTasks + recentTasks.filter { !topIds.contains($0.id) } - - do { - completedTasks = try await ActionItemStorage.shared.getRecentCompletedTasks(limit: 10) - } catch { - logError("Task: Failed to load completed tasks", error: error) - } - - do { - deletedTasks = try await ActionItemStorage.shared.getRecentDeletedTasks(limit: 10, deletedBy: "user") } catch { - logError("Task: Failed to load deleted tasks", error: error) - } - - // Refresh goals if 
stale - let timeSinceGoals = Date().timeIntervalSince(lastGoalsRefresh) - if timeSinceGoals >= goalsRefreshInterval { - do { - cachedGoals = try await APIClient.shared.getGoals() - lastGoalsRefresh = Date() - log("Task: Refreshed \(cachedGoals.count) goals") - } catch { - logError("Task: Failed to refresh goals", error: error) - } + logError("Task extraction error", error: error) } - - return TaskExtractionContext( - activeTasks: activeTasks, - completedTasks: completedTasks, - deletedTasks: deletedTasks, - goals: cachedGoals - ) } - /// Execute vector similarity search - private func executeVectorSearch(query: String) async -> [TaskSearchResult] { - var results: [TaskSearchResult] = [] - - do { - let queryEmbedding = try await EmbeddingService.shared.embed(text: query) - let vectorResults = await EmbeddingService.shared.searchSimilar(query: queryEmbedding, topK: 10) - - for result in vectorResults where result.similarity > 0.3 { - if let record = try await ActionItemStorage.shared.getActionItem(id: result.id) { - let status: String - if record.deleted { status = "deleted" } - else if record.completed { status = "completed" } - else { status = "active" } - - results.append(TaskSearchResult( - id: result.id, - description: record.description, - status: status, - similarity: Double(result.similarity), - matchType: "vector", - relevanceScore: record.relevanceScore - )) - } else if let staged = try await StagedTaskStorage.shared.getStagedTask(id: result.id) { - // Fallback: ID belongs to a staged task (shared embedding index) - let status: String - if staged.deleted { status = "deleted" } - else if staged.completed { status = "completed" } - else { status = "active" } - - results.append(TaskSearchResult( - id: result.id, - description: staged.description, - status: status, - similarity: Double(result.similarity), - matchType: "vector", - relevanceScore: staged.relevanceScore - )) - } - } - } catch { - logError("Task: Vector search failed", error: error) - } + /// Parse 
a raw task dict from the backend into a TaskExtractionResult. + private func parseBackendTask(_ dict: [String: Any], appName: String) -> TaskExtractionResult { + let title = dict["title"] as? String ?? "" + let description = dict["description"] as? String + let priorityStr = dict["priority"] as? String ?? "medium" + let priority = TaskPriority(rawValue: priorityStr) ?? .medium + let tags = (dict["tags"] as? [String]) ?? [] + let sourceApp = dict["source_app"] as? String ?? appName + let inferredDeadline = dict["inferred_deadline"] as? String + let confidence: Double + if let confValue = dict["confidence"] as? Double { + confidence = confValue + } else if let confInt = dict["confidence"] as? Int { + confidence = Double(confInt) + } else { + confidence = 0.5 + } + let sourceCategory = dict["source_category"] as? String ?? "other" + let sourceSubcategory = dict["source_subcategory"] as? String ?? "other" + let relevanceScore: Int? + if let scoreValue = dict["relevance_score"] as? Int { + relevanceScore = scoreValue + } else if let scoreDouble = dict["relevance_score"] as? Double { + relevanceScore = Int(scoreDouble) + } else { + relevanceScore = nil + } + + let task = ExtractedTask( + title: title, + description: description?.isEmpty == true ? nil : description, + priority: priority, + sourceApp: sourceApp, + inferredDeadline: inferredDeadline?.isEmpty == true ? nil : inferredDeadline, + confidence: confidence, + tags: tags, + sourceCategory: sourceCategory, + sourceSubcategory: sourceSubcategory, + relevanceScore: relevanceScore + ) - return results.sorted { ($0.similarity ?? 0) > ($1.similarity ?? 0) } + return TaskExtractionResult( + hasNewTask: true, + task: task, + contextSummary: dict["context_summary"] as? String ?? "Analyzed \(appName)", + currentActivity: dict["current_activity"] as? String ?? 
"" + ) } - /// Execute FTS5 keyword search (searches both action_items and staged_tasks) - private func executeKeywordSearch(query: String) async -> [TaskSearchResult] { - var results: [TaskSearchResult] = [] - - do { - let words = query.components(separatedBy: .whitespaces) - .map { $0.filter { $0.isLetter || $0.isNumber } } // Strip FTS5 special chars (- : * " etc.) - .filter { $0.count >= 3 } - let ftsQuery = words.map { "\($0)*" }.joined(separator: " OR ") - - if !ftsQuery.isEmpty { - // Search action_items (promoted + manual) - let ftsResults = try await ActionItemStorage.shared.searchFTS( - query: ftsQuery, - limit: 10, - includeCompleted: true, - includeDeleted: true - ) - - for result in ftsResults { - let status: String - if result.deleted { status = "deleted" } - else if result.completed { status = "completed" } - else { status = "active" } - - results.append(TaskSearchResult( - id: result.id, - description: result.description, - status: status, - similarity: nil, - matchType: "fts", - relevanceScore: result.relevanceScore - )) - } - - // Also search staged_tasks - let stagedResults = try await StagedTaskStorage.shared.searchFTS( - query: ftsQuery, - limit: 10 - ) - for result in stagedResults { - results.append(TaskSearchResult( - id: result.id, - description: result.description, - status: "active", - similarity: nil, - matchType: "fts", - relevanceScore: result.relevanceScore - )) - } - } - } catch { - logError("Task: FTS search failed", error: error) - } - - return results - } } From cc33cbd5fd4b5e0f59645434ca45368516a5f7ba Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:07 +0100 Subject: [PATCH 026/163] Wire MemoryAssistant thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace GeminiClient.sendRequest with backendService.extractMemories(). Remove prompt/schema building — all LLM logic now server-side. 
Co-Authored-By: Claude Opus 4.6 --- .../MemoryExtraction/MemoryAssistant.swift | 97 +++++++------------ 1 file changed, 33 insertions(+), 64 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift index bcc6a6f1e6..2e0c671820 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift @@ -17,7 +17,7 @@ actor MemoryAssistant: ProactiveAssistant { // MARK: - Properties - private let geminiClient: GeminiClient + private let backendService: BackendProactiveService private var isRunning = false private var lastAnalysisTime: Date = .distantPast private var previousMemories: [ExtractedMemory] = [] // Last 20 extracted memories for deduplication @@ -28,15 +28,6 @@ actor MemoryAssistant: ProactiveAssistant { private let frameSignal: AsyncStream private let frameSignalContinuation: AsyncStream.Continuation - /// Get the current system prompt from settings (accessed on MainActor for thread safety) - private var systemPrompt: String { - get async { - await MainActor.run { - MemoryAssistantSettings.shared.analysisPrompt - } - } - } - /// Get the extraction interval from settings private var extractionInterval: TimeInterval { get async { @@ -57,9 +48,8 @@ actor MemoryAssistant: ProactiveAssistant { // MARK: - Initialization - init(apiKey: String? 
= nil) throws { - // Use Gemini 3 Pro for better memory extraction quality - self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest") + init(backendService: BackendProactiveService) { + self.backendService = backendService let (stream, continuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1)) self.frameSignal = stream @@ -340,61 +330,40 @@ actor MemoryAssistant: ProactiveAssistant { } private func extractMemories(from jpegData: Data, appName: String) async throws -> MemoryExtractionResult? { - // Build context with previous memories for deduplication - var prompt = "Analyze this screenshot from \(appName).\n\n" + let base64 = autoreleasepool { jpegData.base64EncodedString() } + let backendResult = try await backendService.extractMemories( + imageBase64: base64, + appName: appName, + windowTitle: "" + ) - if !previousMemories.isEmpty { - prompt += "RECENTLY EXTRACTED MEMORIES (do not re-extract these or semantically similar ones):\n" - for (index, memory) in previousMemories.enumerated() { - prompt += "\(index + 1). [\(memory.category.rawValue)] \(memory.content)\n" + // Parse backend response into MemoryExtractionResult + let memories: [ExtractedMemory] = backendResult.memories.compactMap { dict in + guard let content = dict["content"] as? String, !content.isEmpty else { return nil } + let categoryStr = dict["category"] as? String ?? "system" + let category: ExtractedMemoryCategory = categoryStr == "interesting" ? .interesting : .system + let sourceApp = dict["source_app"] as? String ?? appName + let confidence: Double + if let confValue = dict["confidence"] as? Double { + confidence = confValue + } else if let confInt = dict["confidence"] as? Int { + confidence = Double(confInt) + } else { + confidence = 0.5 } - prompt += "\nLook for NEW memories that are NOT already in the list above." - } else { - prompt += "Look for memories to extract (system facts about the user, or interesting wisdom from others)." 
+ return ExtractedMemory( + content: content, + category: category, + sourceApp: sourceApp, + confidence: confidence + ) } - // Get current system prompt from settings - let currentSystemPrompt = await systemPrompt - - // Build response schema for memory extraction - let memoryProperties: [String: GeminiRequest.GenerationConfig.ResponseSchema.Property] = [ - "content": .init(type: "string", description: "The memory content (max 15 words)"), - "category": .init(type: "string", enum: ["system", "interesting"], description: "Memory category"), - "source_app": .init(type: "string", description: "App where memory was found"), - "confidence": .init(type: "number", description: "Confidence score 0.0-1.0") - ] - - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "has_new_memory": .init(type: "boolean", description: "True if new memories were found"), - "memories": .init( - type: "array", - description: "Array of extracted memories (0-3 max)", - items: .init( - type: "object", - properties: memoryProperties, - required: ["content", "category", "source_app", "confidence"] - ) - ), - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "High-level description of user's activity") - ], - required: ["has_new_memory", "memories", "context_summary", "current_activity"] + return MemoryExtractionResult( + hasNewMemory: !memories.isEmpty, + memories: memories, + contextSummary: "Analyzed \(appName)", + currentActivity: "" ) - - do { - let responseText = try await geminiClient.sendRequest( - prompt: prompt, - imageData: jpegData, - systemPrompt: currentSystemPrompt, - responseSchema: responseSchema - ) - - return try JSONDecoder().decode(MemoryExtractionResult.self, from: Data(responseText.utf8)) - } catch { - logError("Memory analysis error", error: error) - return nil - } } } From 0456a62c12747ad213e84d597d78bcfe60505dcd Mon 
Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:14 +0100 Subject: [PATCH 027/163] Wire AdviceAssistant thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace 2-phase Gemini tool-calling loop (execute_sql + vision) with backendService.generateAdvice(). Remove compressForGemini, getUserLanguage, buildActivitySummary, buildPhase1/2Tools — all LLM logic server-side. -560 lines. Co-Authored-By: Claude Opus 4.6 --- .../Assistants/Advice/AdviceAssistant.swift | 660 ++---------------- 1 file changed, 60 insertions(+), 600 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift index 87ca8c4b0a..e1c2701817 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift @@ -1,4 +1,3 @@ -import AppKit import Foundation import GRDB @@ -19,7 +18,7 @@ actor AdviceAssistant: ProactiveAssistant { // MARK: - Properties - private let geminiClient: GeminiClient + private let backendService: BackendProactiveService private var isRunning = false private var lastAnalysisTime: Date = .distantPast private var previousAdvice: [ExtractedAdvice] = [] // Dedup window for advice context @@ -33,15 +32,6 @@ actor AdviceAssistant: ProactiveAssistant { private let frameSignal: AsyncStream private let frameSignalContinuation: AsyncStream.Continuation - /// Get the current system prompt from settings (accessed on MainActor for thread safety) - private var systemPrompt: String { - get async { - await MainActor.run { - AdviceAssistantSettings.shared.analysisPrompt - } - } - } - /// Get the extraction interval from settings private var extractionInterval: TimeInterval { get async { @@ -62,9 +52,8 @@ actor AdviceAssistant: ProactiveAssistant { // MARK: - 
Initialization - init(apiKey: String? = nil) throws { - // Use Gemini 3.1 Pro for better advice quality (3-pro-preview retires March 9, 2026) - self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest") + init(backendService: BackendProactiveService) { + self.backendService = backendService let (stream, continuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1)) self.frameSignal = stream @@ -140,25 +129,6 @@ actor AdviceAssistant: ProactiveAssistant { log("Advice assistant stopped") } - // MARK: - Test Analysis (for test runner) - - /// Run the extraction pipeline on arbitrary JPEG data without side effects (no saving, no events). - /// Used by the test runner to replay past screenshots. - /// `screenshotTime` anchors the activity summary to the screenshot's actual timestamp. - /// Returns (result, sqlQueryCount) where sqlQueryCount is the number of execute_sql tool calls made. - func testAnalyze(jpegData: Data, appName: String, windowTitle: String? = nil, screenshotTime: Date) async throws -> (AdviceExtractionResult?, Int) { - let interval = await extractionInterval - let lookbackStart = screenshotTime.addingTimeInterval(-interval) - return try await runAdviceExtraction( - jpegData: nil, - appName: appName, - windowTitle: windowTitle, - referenceTime: screenshotTime, - lookbackStart: lookbackStart, - trackSqlCount: true - ) - } - // MARK: - ProactiveAssistant Protocol Methods func shouldAnalyze(frameNumber: Int, timeSinceLastAnalysis: TimeInterval) -> Bool { @@ -379,62 +349,34 @@ actor AdviceAssistant: ProactiveAssistant { pendingFrame = nil } - // MARK: - Image Processing - - /// Resize and compress an image for Gemini analysis (max 1280px wide, JPEG quality 0.4) - private static func compressForGemini(_ data: Data) -> Data? 
{ - guard let source = CGImageSourceCreateWithData(data as CFData, nil), - let cgImage = CGImageSourceCreateImageAtIndex(source, 0, nil) else { return nil } - - let maxWidth = 1280 - let width = cgImage.width - let height = cgImage.height - let scale = width > maxWidth ? Double(maxWidth) / Double(width) : 1.0 - let newWidth = Int(Double(width) * scale) - let newHeight = Int(Double(height) * scale) - - guard let context = CGContext( - data: nil, width: newWidth, height: newHeight, - bitsPerComponent: 8, bytesPerRow: 0, - space: CGColorSpaceCreateDeviceRGB(), - bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue - ) else { return nil } - - context.interpolationQuality = .high - context.draw(cgImage, in: CGRect(x: 0, y: 0, width: newWidth, height: newHeight)) - - guard let resized = context.makeImage() else { return nil } - - let mutableData = NSMutableData() - guard let dest = CGImageDestinationCreateWithData(mutableData as CFMutableData, "public.jpeg" as CFString, 1, nil) else { return nil } - CGImageDestinationAddImage(dest, resized, [kCGImageDestinationLossyCompressionQuality: 0.4] as CFDictionary) - guard CGImageDestinationFinalize(dest) else { return nil } - return mutableData as Data - } - - // MARK: - Helpers + // MARK: - Test Analysis (for test runner) - /// Get user's preferred language, cached for 1 hour - private func getUserLanguage() async -> String? { - // Return cached value if fresh (< 1 hour) - if let cached = cachedLanguage, Date().timeIntervalSince(languageFetchedAt) < 3600 { - return cached + /// Run extraction via backend for test runner. Returns (result, 0) for compatibility. + func testAnalyze(jpegData: Data, appName: String, windowTitle: String? = nil, screenshotTime: Date) async throws -> (AdviceExtractionResult?, Int) { + let base64 = autoreleasepool { jpegData.base64EncodedString() } + let backendResult = try await backendService.generateAdvice( + imageBase64: base64, appName: appName, windowTitle: windowTitle ?? 
"" + ) + guard let adviceDict = backendResult.advice as? [String: Any] else { + return (AdviceExtractionResult(hasAdvice: false, advice: nil, contextSummary: "Analyzed \(appName)", currentActivity: ""), 0) } - - do { - let response = try await APIClient.shared.getUserLanguage() - let lang = response.language - cachedLanguage = lang - languageFetchedAt = Date() - return lang.isEmpty ? nil : lang - } catch { - // Fall back to transcription language setting - let fallback = await MainActor.run { AssistantSettings.shared.transcriptionLanguage } - return fallback.isEmpty || fallback == "en" ? nil : fallback + let hasAdvice = adviceDict["has_advice"] as? Bool ?? !adviceDict.isEmpty + guard hasAdvice, let adviceText = adviceDict["content"] as? String ?? adviceDict["advice"] as? String, !adviceText.isEmpty else { + return (AdviceExtractionResult(hasAdvice: false, advice: nil, contextSummary: "Analyzed \(appName)", currentActivity: ""), 0) } + let categoryStr = adviceDict["category"] as? String ?? "other" + let category = AdviceCategory(rawValue: categoryStr) ?? .other + let confidence = adviceDict["confidence"] as? Double ?? 0.5 + let advice = ExtractedAdvice( + advice: adviceText, headline: adviceDict["headline"] as? String, + reasoning: adviceDict["reasoning"] as? 
String, category: category, + sourceApp: appName, confidence: confidence + ) + let result = AdviceExtractionResult(hasAdvice: true, advice: advice, contextSummary: "Analyzed \(appName)", currentActivity: "") + return (result, 0) } - // MARK: - Analysis + // MARK: - Backend Analysis (Phase 2 thin client) private func processFrame(_ frame: CapturedFrame) async { guard await isEnabled else { return } @@ -443,7 +385,6 @@ actor AdviceAssistant: ProactiveAssistant { return } - // Handle the result with screenshot ID for SQLite storage await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, windowTitle: frame.windowTitle) { type, data in Task { @MainActor in AssistantCoordinator.shared.sendEvent(type: type, data: data) @@ -455,549 +396,68 @@ actor AdviceAssistant: ProactiveAssistant { } private func extractAdvice(from frame: CapturedFrame) async throws -> AdviceExtractionResult? { - let now = Date() - // Cap lookback: since last analysis or max 1 hour ago - let lookbackStart = max(lastAnalysisTime, now.addingTimeInterval(-3600)) - let (result, _) = try await runAdviceExtraction( - jpegData: nil, + let base64 = autoreleasepool { frame.jpegData.base64EncodedString() } + let backendResult = try await backendService.generateAdvice( + imageBase64: base64, appName: frame.appName, - windowTitle: frame.windowTitle, - referenceTime: now, - lookbackStart: lookbackStart, - trackSqlCount: false + windowTitle: frame.windowTitle ?? "" ) - return result - } - // MARK: - Core Extraction (shared by production + test) - - /// Two-phase advice extraction: - /// Phase 1 (text-only): Activity summary + SQL investigation loop. Model investigates via - /// execute_sql, then calls `request_screenshot` with an ID and its findings so far. - /// Phase 2 (single vision call): Load the chosen screenshot + Phase 1 findings → single - /// Gemini call with image → provide_advice or no_advice. - /// Returns (result, sqlQueryCount). 
- private func runAdviceExtraction( - jpegData: Data?, - appName: String, - windowTitle: String?, - referenceTime: Date, - lookbackStart: Date, - trackSqlCount: Bool - ) async throws -> (AdviceExtractionResult?, Int) { - var sqlCount = 0 - - // Build prompt with current context - let timeFormatter = DateFormatter() - timeFormatter.dateFormat = "h:mm a, EEEE" - var prompt = "CURRENT APP: \(appName)." - if let windowTitle = windowTitle, !windowTitle.isEmpty { - prompt += " Window: \"\(windowTitle)\"." - } - prompt += " Time: \(timeFormatter.string(from: referenceTime))." - - // Add activity summary from database, anchored to the reference time - let elapsed = referenceTime.timeIntervalSince(lookbackStart) - log("Advice: Activity lookback: \(String(format: "%.0f", elapsed))s (\(lookbackStart) to \(referenceTime))") - let activitySummary = await buildActivitySummary(from: lookbackStart, to: referenceTime) - if !activitySummary.isEmpty { - prompt += "\n\n" + activitySummary - log("Advice: --- ACTIVITY SUMMARY ---\n\(activitySummary)") - } else { - log("Advice: --- ACTIVITY SUMMARY --- (empty, no screenshots in range)") - } - - // Add user profile for context - if let profile = await AIUserProfileService.shared.getLatestProfile() { - prompt += "\n\nUSER PROFILE (who this user is):\n" - prompt += profile.profileText + "\n" - } - - // Add previous advice for dedup - if !previousAdvice.isEmpty { - prompt += "\n\nPREVIOUSLY PROVIDED ADVICE (do not repeat these or semantically similar):\n" - let adviceToInclude = previousAdvice.prefix(maxAdviceInPrompt) - for (index, advice) in adviceToInclude.enumerated() { - prompt += "\(index + 1). \(advice.advice)" - if let reasoning = advice.reasoning { - prompt += " (Reasoning: \(reasoning))" - } - prompt += "\n" - } - prompt += "\nOnly provide advice if there's a genuinely NEW non-obvious insight not covered above." - } else { - prompt += "\n\nOnly provide advice if there's something specific and non-obvious that would help." 
- } - - prompt += "\n\nInvestigate the activity summary. Scan OCR from the TOP 3-5 apps (not just the dominant one) — the best insights often come from browsers, communication apps, and notes, not just the app with the most screenshots. Skip apps with < 10 screenshots. When you've identified the most interesting screenshot, call request_screenshot with the ID and your findings. Or call no_advice if nothing qualifies." - - log("Advice: --- PROMPT ---\n\(prompt)") - - // Build system prompt - var currentSystemPrompt = await systemPrompt - if let language = await getUserLanguage(), language != "en" { - currentSystemPrompt += "\n\nIMPORTANT: Respond in the user's preferred language: \(language)" - } - currentSystemPrompt += "\n\nDATABASE SCHEMA for execute_sql:\nscreenshots table columns: id INTEGER, timestamp TEXT, appName TEXT, windowTitle TEXT, ocrText TEXT, focusStatus TEXT" - - // ============================================= - // PHASE 1: Text-only investigation loop - // ============================================= - - let phase1Tools = buildPhase1Tools() - var contents: [GeminiImageToolRequest.Content] = [ - GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(text: prompt)] + // Parse backend response into AdviceExtractionResult + guard let adviceDict = backendResult.advice as? [String: Any] else { + return AdviceExtractionResult( + hasAdvice: false, + advice: nil, + contextSummary: "Analyzed \(frame.appName)", + currentActivity: "" ) - ] - - let client = self.geminiClient - var chosenScreenshotId: Int64? - var investigationFindings: String? 
- - for iteration in 0..<7 { - let iterContents = contents - let iterSystemPrompt = currentSystemPrompt - let iterTools = [phase1Tools] - let iterForce = iteration == 0 - let result: ToolChatResult - do { - result = try await withThrowingTimeout(seconds: 120) { - try await client.sendImageToolLoop( - contents: iterContents, - systemPrompt: iterSystemPrompt, - tools: iterTools, - forceToolCall: iterForce - ) - } - } catch { - log("Advice: Phase 1 failed on iteration \(iteration): \(error.localizedDescription)") - throw error - } - - guard let toolCall = result.toolCalls.first else { - log("Advice: Phase 1 — no tool call on iteration \(iteration), breaking") - break - } - - switch toolCall.name { - case "execute_sql": - let query = toolCall.arguments["query"] as? String ?? "" - sqlCount += 1 - log("Advice: P1 execute_sql iter \(iteration): \(query)") - let sqlToolCall = ToolCall(name: "execute_sql", arguments: ["query": query], thoughtSignature: nil) - let resultStr = await ChatToolExecutor.execute(sqlToolCall) - let truncated = resultStr.count > 2000 ? String(resultStr.prefix(2000)) + "... (truncated)" : resultStr - log("Advice: P1 sql result (\(resultStr.count) chars): \(truncated)") - - contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: ["query": query]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: resultStr) - ))] - )) - continue - - case "request_screenshot": - let findings = toolCall.arguments["findings"] as? String ?? "" - investigationFindings = findings - if let idInt = toolCall.arguments["screenshot_id"] as? Int { - chosenScreenshotId = Int64(idInt) - } else if let idInt64 = toolCall.arguments["screenshot_id"] as? 
Int64 { - chosenScreenshotId = idInt64 - } else if let idStr = toolCall.arguments["screenshot_id"] as? String, let parsed = Int64(idStr) { - chosenScreenshotId = parsed - } else if let idDouble = toolCall.arguments["screenshot_id"] as? Double { - chosenScreenshotId = Int64(idDouble) - } - log("Advice: P1 request_screenshot iter \(iteration): id=\(chosenScreenshotId ?? 0), findings=\(findings.prefix(200))") - break // Exit phase 1 - - case "no_advice": - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "No context" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "Unknown" - log("Advice: P1 no_advice — \(contextSummary)") - return (AdviceExtractionResult( - hasAdvice: false, - advice: nil, - contextSummary: contextSummary, - currentActivity: currentActivity - ), sqlCount) - - default: - log("Advice: P1 unknown tool: \(toolCall.name), breaking") - break - } - - // Break out of loop if request_screenshot was called - if chosenScreenshotId != nil { break } - } - - // If Phase 1 exhausted without choosing a screenshot, no advice - guard let screenshotId = chosenScreenshotId, let findings = investigationFindings else { - log("Advice: Phase 1 exhausted without request_screenshot") - return (nil, sqlCount) } - // ============================================= - // PHASE 2: Single vision call with chosen screenshot - // ============================================= - - log("Advice: Phase 2 — loading screenshot \(screenshotId)") - - // Load the screenshot image - let imageData: Data - do { - guard let screenshot = try await RewindDatabase.shared.getScreenshot(id: screenshotId) else { - log("Advice: P2 screenshot not in DB: \(screenshotId)") - return (nil, sqlCount) - } - // Check active chunk - if screenshot.usesVideoStorage, let chunk = screenshot.videoChunkPath { - let activeChunk = await VideoChunkEncoder.shared.currentChunkPath - if chunk == activeChunk { - log("Advice: P2 screenshot is in active chunk, skipping") - 
return (nil, sqlCount) - } - } - let rawData = try await RewindStorage.shared.loadScreenshotData(for: screenshot) - imageData = Self.compressForGemini(rawData) ?? rawData - log("Advice: P2 loaded \(imageData.count) bytes (\(rawData.count) raw) from \(screenshot.appName)") - } catch { - log("Advice: P2 screenshot load failed: \(error.localizedDescription)") - return (nil, sqlCount) - } - - // Build Phase 2 prompt — compact findings + image + cross-reference instruction - let phase2Prompt = """ - INVESTIGATION FINDINGS: - \(findings) - - The screenshot below is from the app/window identified during investigation. - - Before giving advice, CROSS-REFERENCE your findings: - - Use execute_sql to check if this issue was resolved in later screenshots - - Check if the user moved on to something else (the issue may be stale) - - Verify the context is still relevant by looking at nearby timestamps - - Then call provide_advice if the insight is still valid, or no_advice if it was resolved or is no longer relevant. - """ - - let phase2Tools = buildPhase2Tools() - let base64 = imageData.base64EncodedString() - var phase2Contents: [GeminiImageToolRequest.Content] = [ - GeminiImageToolRequest.Content( - role: "user", - parts: [ - GeminiImageToolRequest.Part(text: phase2Prompt), - GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64), - ] + let hasAdvice = adviceDict["has_advice"] as? Bool ?? 
!adviceDict.isEmpty + guard hasAdvice else { + return AdviceExtractionResult( + hasAdvice: false, + advice: nil, + contextSummary: "Analyzed \(frame.appName)", + currentActivity: "" ) - ] - - // Phase 2 loop — model can cross-reference via SQL before deciding - for p2Iteration in 0..<5 { - let p2Contents = phase2Contents - let p2SystemPrompt = currentSystemPrompt - let p2Tools = [phase2Tools] - let p2Force = p2Iteration == 0 - let phase2Result: ToolChatResult - do { - phase2Result = try await withThrowingTimeout(seconds: 120) { - try await client.sendImageToolLoop( - contents: p2Contents, - systemPrompt: p2SystemPrompt, - tools: p2Tools, - forceToolCall: p2Force - ) - } - } catch { - log("Advice: Phase 2 failed on iteration \(p2Iteration): \(error.localizedDescription)") - throw error - } - - guard let toolCall = phase2Result.toolCalls.first else { - log("Advice: Phase 2 — no tool call on iteration \(p2Iteration), breaking") - break - } - - switch toolCall.name { - case "execute_sql": - let query = toolCall.arguments["query"] as? String ?? "" - sqlCount += 1 - log("Advice: P2 execute_sql iter \(p2Iteration): \(query)") - let sqlToolCall = ToolCall(name: "execute_sql", arguments: ["query": query], thoughtSignature: nil) - let resultStr = await ChatToolExecutor.execute(sqlToolCall) - let truncated = resultStr.count > 2000 ? String(resultStr.prefix(2000)) + "... 
(truncated)" : resultStr - log("Advice: P2 sql result (\(resultStr.count) chars): \(truncated)") - - phase2Contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: ["query": query]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - phase2Contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: resultStr) - ))] - )) - continue - - case "provide_advice": - log("Advice: P2 provide_advice (after \(p2Iteration) cross-reference iterations)") - return (parseProvideAdvice(toolCall), sqlCount) - - case "no_advice": - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "No context" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "Unknown" - log("Advice: P2 no_advice — \(contextSummary)") - return (AdviceExtractionResult( - hasAdvice: false, - advice: nil, - contextSummary: contextSummary, - currentActivity: currentActivity - ), sqlCount) - - default: - log("Advice: P2 unexpected tool: \(toolCall.name)") - break - } - break // Break on unexpected tool - } - return (nil, sqlCount) - } - - // MARK: - Activity Summary - - /// Query the screenshots table to build a summary of recent activity. - /// - `from`: lower bound (e.g. last analysis time or screenshot.timestamp - interval) - /// - `to`: upper bound (e.g. now or the screenshot's timestamp) - private func buildActivitySummary(from lookbackStart: Date, to referenceTime: Date) async -> String { - guard let dbQueue = await RewindDatabase.shared.getDatabaseQueue() else { - return "" } - do { - return try await dbQueue.read { db in - // Pass Date objects directly — GRDB encodes them as UTC strings - // matching the stored format. Manual DateFormatter uses local timezone - // which causes mismatches. 
- let rows = try Row.fetchAll(db, sql: """ - SELECT appName, windowTitle, COUNT(*) as count, - MIN(timestamp) as first_seen, MAX(timestamp) as last_seen - FROM screenshots - WHERE timestamp >= ? AND timestamp <= ? - AND appName IS NOT NULL AND appName != '' - GROUP BY appName, windowTitle - ORDER BY count DESC - LIMIT 30 - """, arguments: [lookbackStart, referenceTime]) - - if rows.isEmpty { - return "" - } - - let totalScreenshots = rows.reduce(0) { $0 + (($1["count"] as? Int64).map(Int.init) ?? ($1["count"] as? Int) ?? 0) } - let elapsedMin = referenceTime.timeIntervalSince(lookbackStart) / 60.0 - - let timeOnlyFormatter = DateFormatter() - timeOnlyFormatter.dateFormat = "HH:mm:ss" - - var lines: [String] = [] - lines.append("ACTIVITY SUMMARY (last \(Int(elapsedMin)) min, \(totalScreenshots) screenshots):") - lines.append("Time range: \(timeOnlyFormatter.string(from: lookbackStart)) – \(timeOnlyFormatter.string(from: referenceTime))") - lines.append("") - lines.append("App | Window | Screenshots | Est. Duration") - lines.append(String(repeating: "-", count: 60)) - - for row in rows { - let app = row["appName"] as? String ?? "Unknown" - let window = row["windowTitle"] as? String ?? "" - let count = (row["count"] as? Int64).map(Int.init) ?? (row["count"] as? Int) ?? 0 - let estMinutes = String(format: "%.1f", Double(count) / 60.0) - let windowDisplay = window.isEmpty ? "(no title)" : String(window.prefix(50)) - lines.append("\(app) | \(windowDisplay) | \(count) | \(estMinutes) min") - } - - let summary = lines.joined(separator: "\n") - log("Advice: Activity summary (last \(Int(elapsedMin)) min, \(totalScreenshots) screenshots)") - return summary - } - } catch { - logError("Advice: Failed to build activity summary", error: error) - return "" + let adviceText = adviceDict["content"] as? String ?? adviceDict["advice"] as? String ?? 
"" + guard !adviceText.isEmpty else { + return AdviceExtractionResult( + hasAdvice: false, + advice: nil, + contextSummary: "Analyzed \(frame.appName)", + currentActivity: "" + ) } - } - // MARK: - Tool Definitions - - /// Phase 1 tools: text-only investigation (execute_sql, request_screenshot, no_advice) - private func buildPhase1Tools() -> GeminiTool { - GeminiTool(functionDeclarations: [ - GeminiTool.FunctionDeclaration( - name: "execute_sql", - description: "Execute a SQL query on the local database to investigate screen activity. The screenshots table has: id INTEGER, timestamp TEXT, appName TEXT, windowTitle TEXT, ocrText TEXT, focusStatus TEXT. Use this to read OCR text from interesting windows, check what the user was doing, etc. SELECT queries only. Auto-limited to 200 rows.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init(type: "string", description: "SQL SELECT query to execute on the screenshots table") - ], - required: ["query"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "request_screenshot", - description: "Request to view a specific screenshot. Call this when you've found something interesting via SQL and want to see the actual screen. Provide the screenshot ID and a summary of your findings so far. The screenshot will be shown to you for final analysis.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "screenshot_id": .init(type: "integer", description: "The screenshot ID from the screenshots table"), - "findings": .init(type: "string", description: "Summary of what you found during investigation — what app, what OCR text caught your attention, and what you suspect might be worth advising about") - ], - required: ["screenshot_id", "findings"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "no_advice", - description: "Call this when there is nothing worth advising about. Nothing qualifies as a specific, non-obvious insight. 
This ends the analysis.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "High-level description of user's activity") - ], - required: ["context_summary", "current_activity"] - ) - ), - ]) - } - - /// Phase 2 tools: vision call with screenshot + SQL cross-referencing (execute_sql, provide_advice, no_advice) - private func buildPhase2Tools() -> GeminiTool { - GeminiTool(functionDeclarations: [ - GeminiTool.FunctionDeclaration( - name: "execute_sql", - description: "Cross-reference your findings by querying the database. Use this to check if an issue was resolved in later screenshots, verify context across time, or look up related activity. The screenshots table has: id INTEGER, timestamp TEXT, appName TEXT, windowTitle TEXT, ocrText TEXT, focusStatus TEXT. SELECT queries only.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init(type: "string", description: "SQL SELECT query to execute on the screenshots table") - ], - required: ["query"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "provide_advice", - description: "Call this when you have a specific, non-obvious insight for the user based on the screenshot and your investigation findings. You should cross-reference first using execute_sql to verify the issue is still relevant.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "advice": .init(type: "string", description: "The advice text (1-2 sentences, max 100 chars). Start with what you noticed, then why it matters."), - "headline": .init(type: "string", description: "Ultra-short observation (max 5 words) for notification preview. E.g. 
'Draft saved in /tmp', 'Credentials visible in terminal'"), - "reasoning": .init(type: "string", description: "Brief explanation of why this advice is relevant"), - "category": .init(type: "string", description: "Category of advice", enumValues: ["productivity", "communication", "learning", "other"]), - "source_app": .init(type: "string", description: "App where context was observed"), - "confidence": .init(type: "number", description: "Confidence score 0.0-1.0. 0.90+: preventing clear mistake. 0.75-0.89: highly relevant non-obvious tip. 0.60-0.74: useful but user might know."), - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "High-level description of user's activity") - ], - required: ["advice", "headline", "category", "source_app", "confidence", "context_summary", "current_activity"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "no_advice", - description: "Call this when the screenshot doesn't reveal anything worth advising about, or when cross-referencing shows the issue was already resolved.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "High-level description of user's activity") - ], - required: ["context_summary", "current_activity"] - ) - ), - ]) - } - - // MARK: - Parse Tool Results - - /// Parse the provide_advice tool call into an AdviceExtractionResult - private func parseProvideAdvice(_ toolCall: ToolCall) -> AdviceExtractionResult { - let adviceText = toolCall.arguments["advice"] as? String ?? "" - let headline = toolCall.arguments["headline"] as? String - let reasoning = toolCall.arguments["reasoning"] as? String - let categoryStr = toolCall.arguments["category"] as? String ?? "other" + let categoryStr = adviceDict["category"] as? 
String ?? "other" let category = AdviceCategory(rawValue: categoryStr) ?? .other - let sourceApp = toolCall.arguments["source_app"] as? String ?? "" - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "" - let confidence: Double - if let confValue = toolCall.arguments["confidence"] as? Double { + if let confValue = adviceDict["confidence"] as? Double { confidence = confValue - } else if let confInt = toolCall.arguments["confidence"] as? Int { + } else if let confInt = adviceDict["confidence"] as? Int { confidence = Double(confInt) - } else if let confStr = toolCall.arguments["confidence"] as? String, let parsed = Double(confStr) { - confidence = parsed } else { confidence = 0.5 } let advice = ExtractedAdvice( advice: adviceText, - headline: headline, - reasoning: reasoning, + headline: adviceDict["headline"] as? String, + reasoning: adviceDict["reasoning"] as? String, category: category, - sourceApp: sourceApp, + sourceApp: frame.appName, confidence: confidence ) - log("Advice: --- PROVIDE_ADVICE ---") - log("Advice: advice: \(adviceText)") - log("Advice: headline: \(headline ?? "(none)")") - log("Advice: reasoning: \(reasoning ?? "(none)")") - log("Advice: category: \(categoryStr)") - log("Advice: source_app: \(sourceApp)") - log("Advice: confidence: \(confidence)") - log("Advice: context: \(contextSummary)") - log("Advice: activity: \(currentActivity)") return AdviceExtractionResult( hasAdvice: true, advice: advice, - contextSummary: contextSummary, - currentActivity: currentActivity + contextSummary: adviceDict["context_summary"] as? String ?? "Analyzed \(frame.appName)", + currentActivity: adviceDict["current_activity"] as? String ?? "" ) } } - -// MARK: - Timeout Helper - -/// Run an async operation with a timeout. Throws `CancellationError` if the timeout expires. 
-private func withThrowingTimeout(seconds: Double, operation: @escaping @Sendable () async throws -> T) async throws -> T { - try await withThrowingTaskGroup(of: T.self) { group in - group.addTask { - try await operation() - } - group.addTask { - try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000)) - throw CancellationError() - } - // First task to complete wins; cancel the other - let result = try await group.next()! - group.cancelAll() - return result - } -} From 2bad74609aa5c5aeb4022f07db057b33eb248941 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:19 +0100 Subject: [PATCH 028/163] Wire TaskDeduplicationService thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace GeminiClient with backendService.deduplicateTasks(). Remove prompt/schema building, local dedup logic — server handles everything. Co-Authored-By: Claude Opus 4.6 --- .../TaskDeduplicationService.swift | 214 ++---------------- 1 file changed, 22 insertions(+), 192 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift index 4618f99673..98b38a8761 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift @@ -6,7 +6,7 @@ import Foundation actor TaskDeduplicationService { static let shared = TaskDeduplicationService() - private var geminiClient: GeminiClient? + private var backendService: BackendProactiveService? private var timer: Task? private var isRunning = false private var lastRunTime: Date? 
@@ -17,13 +17,11 @@ actor TaskDeduplicationService { private let cooldownSeconds: TimeInterval = 1800 // 30-min cooldown private let minimumTaskCount = 3 - private init() { - do { - self.geminiClient = try GeminiClient(model: "gemini-pro-latest") - } catch { - log("TaskDedup: Failed to initialize GeminiClient: \(error)") - self.geminiClient = nil - } + private init() {} + + /// Set the backend service for Phase 2 server-side deduplication. + func configure(backendService: BackendProactiveService) { + self.backendService = backendService } // MARK: - Lifecycle @@ -67,210 +65,42 @@ actor TaskDeduplicationService { // MARK: - Deduplication Logic private func runDeduplication() async { - guard let client = geminiClient else { - log("TaskDedup: Skipping - Gemini client not initialized") + guard let service = backendService else { + log("TaskDedup: Skipping - backend service not configured") return } lastRunTime = Date() - log("TaskDedup: Starting deduplication run on staged tasks") - - // 1. Fetch staged tasks (not yet promoted to action items) - let tasks: [TaskActionItem] - do { - let response = try await APIClient.shared.getStagedTasks(limit: 200) - tasks = response.items - } catch { - log("TaskDedup: Failed to fetch staged tasks: \(error)") - return - } - - guard tasks.count >= minimumTaskCount else { - log("TaskDedup: Only \(tasks.count) staged tasks, skipping (minimum: \(minimumTaskCount))") - return - } - - log("TaskDedup: Analyzing \(tasks.count) staged tasks for duplicates") - - // 2. Send all tasks to Gemini in a single call - let totalDeleted = await analyzeAndDeleteDuplicates(tasks: tasks, client: client) - - log("TaskDedup: Run complete. 
Hard-deleted \(totalDeleted) duplicate staged tasks.") - } - - private func analyzeAndDeleteDuplicates(tasks: [TaskActionItem], client: GeminiClient) async -> Int { - // Build task list for prompt - let taskDescriptions = tasks.map { task -> String in - var parts = ["ID: \(task.id)", "Description: \(task.description)"] - if let due = task.dueAt { - parts.append("Due: \(ISO8601DateFormatter().string(from: due))") - } - if let priority = task.priority { - parts.append("Priority: \(priority)") - } - if let source = task.source { - parts.append("Source: \(source)") - } - parts.append("Created: \(ISO8601DateFormatter().string(from: task.createdAt))") - return parts.joined(separator: "\n") - }.joined(separator: "\n") - - let prompt = """ - Analyze the following tasks for semantic duplicates. Two tasks are duplicates if they \ - refer to the same action, even if worded differently. - - Tasks: - \(taskDescriptions) - - For each group of duplicates, pick the best task to KEEP based on these criteria (in order): - 1. Most descriptive/specific wording - 2. Has a due date over one that doesn't - 3. Higher priority set (high > medium > low > none) - 4. More reliable source (manual > transcription > screenshot) - 5. Most recently created - - Only flag tasks as duplicates if you are confident they refer to the same action. \ - When in doubt, do NOT flag as duplicates. - """ - - let systemPrompt = """ - You are a task deduplication assistant. You identify semantically duplicate tasks \ - and choose the best one to keep. Be conservative - only flag clear duplicates. \ - Return has_duplicates: false if no duplicates are found. 
- """ - - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "has_duplicates": .init(type: "boolean", description: "Whether any duplicate groups were found"), - "duplicate_groups": .init( - type: "array", - description: "Groups of duplicate tasks", - items: .init( - type: "object", - properties: [ - "keep_id": .init(type: "string", description: "ID of the task to keep"), - "delete_ids": .init( - type: "array", - description: "IDs of tasks to delete", - items: .init(type: "string", properties: nil, required: nil) - ), - "reason": .init(type: "string", description: "Why these tasks are duplicates and which was kept") - ], - required: ["keep_id", "delete_ids", "reason"] - ) - ) - ], - required: ["has_duplicates", "duplicate_groups"] - ) - - // Call Gemini - let responseText: String - do { - responseText = try await client.sendRequest( - prompt: prompt, - systemPrompt: systemPrompt, - responseSchema: responseSchema - ) - } catch { - log("TaskDedup: Gemini request failed: \(error)") - return 0 - } - - // Parse response - guard let data = responseText.data(using: .utf8) else { - log("TaskDedup: Failed to convert response to data") - return 0 - } + log("TaskDedup: Starting server-side deduplication") - let result: DedupResponse do { - result = try JSONDecoder().decode(DedupResponse.self, from: data) - } catch { - log("TaskDedup: Failed to parse response: \(error)") - return 0 - } - - guard result.hasDuplicates, !result.duplicateGroups.isEmpty else { - log("TaskDedup: No duplicates found in batch of \(tasks.count) staged tasks") - return 0 - } + let result = try await service.deduplicateTasks() - // Validate and delete - let validTaskIDs = Set(tasks.map { $0.id }) - let taskLookup = Dictionary(tasks.map { ($0.id, $0) }, uniquingKeysWith: { _, latest in latest }) - var deletedCount = 0 - - for group in result.duplicateGroups { - // Safety: verify all IDs exist in our input - guard validTaskIDs.contains(group.keepId) else 
{ - log("TaskDedup: Skipping group - keep_id '\(group.keepId)' not in input set") - continue - } - - let validDeleteIds = group.deleteIds.filter { validTaskIDs.contains($0) } - if validDeleteIds.count != group.deleteIds.count { - log("TaskDedup: Some delete IDs not in input set, filtering") + if result.deletedIds.isEmpty { + log("TaskDedup: No duplicates found") + return } - guard !validDeleteIds.isEmpty else { continue } - - let keptTask = taskLookup[group.keepId] + log("TaskDedup: Server deleted \(result.deletedIds.count) duplicates. Reason: \(result.reason)") - for deleteId in validDeleteIds { - let deletedTask = taskLookup[deleteId] - - // Log to SQLite + // Log each deletion locally + for deleteId in result.deletedIds { let logRecord = TaskDedupLogRecord( deletedTaskId: deleteId, - deletedDescription: deletedTask?.description ?? "unknown", - keptTaskId: group.keepId, - keptDescription: keptTask?.description ?? "unknown", - reason: group.reason, + deletedDescription: "server-side dedup", + keptTaskId: "", + keptDescription: "", + reason: result.reason, deletedAt: Date() ) - do { try await ProactiveStorage.shared.insertDedupLogRecord(logRecord) } catch { log("TaskDedup: Failed to log deletion record: \(error)") } - - // Hard-delete staged task from backend - do { - try await APIClient.shared.deleteStagedTask(id: deleteId) - deletedCount += 1 - log("TaskDedup: Hard-deleted staged task '\(deletedTask?.description ?? deleteId)' (kept: '\(keptTask?.description ?? 
group.keepId)') - \(group.reason)") - } catch { - log("TaskDedup: Failed to delete staged task \(deleteId) on backend: \(error)") - } } - } - - return deletedCount - } -} - -// MARK: - Response Models - -private struct DedupResponse: Codable { - let hasDuplicates: Bool - let duplicateGroups: [DuplicateGroup] - - enum CodingKeys: String, CodingKey { - case hasDuplicates = "has_duplicates" - case duplicateGroups = "duplicate_groups" - } - - struct DuplicateGroup: Codable { - let keepId: String - let deleteIds: [String] - let reason: String - - enum CodingKeys: String, CodingKey { - case keepId = "keep_id" - case deleteIds = "delete_ids" - case reason + } catch { + log("TaskDedup: Server deduplication failed: \(error)") } } } From 0e0492b5caa3583f6bcaf0420c4482cdb3f8083e Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:25 +0100 Subject: [PATCH 029/163] Wire TaskPrioritizationService thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace GeminiClient with backendService.rerankTasks(). Remove prompt/ schema building, context fetching — server handles reranking. Co-Authored-By: Claude Opus 4.6 --- .../TaskPrioritizationService.swift | 263 ++---------------- 1 file changed, 29 insertions(+), 234 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift index a1aeca228f..ac922254ef 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift @@ -7,7 +7,7 @@ import Foundation actor TaskPrioritizationService { static let shared = TaskPrioritizationService() - private var geminiClient: GeminiClient? + private var backendService: BackendProactiveService? 
private var timer: Task? private var isRunning = false private(set) var isScoringInProgress = false @@ -29,13 +29,6 @@ actor TaskPrioritizationService { // Restore persisted timestamps self.lastFullRunTime = UserDefaults.standard.object(forKey: Self.fullRunKey) as? Date - do { - self.geminiClient = try GeminiClient(model: "gemini-pro-latest") - } catch { - log("TaskPrioritize: Failed to initialize GeminiClient: \(error)") - self.geminiClient = nil - } - if let last = self.lastFullRunTime { let hoursAgo = Int(Date().timeIntervalSince(last) / 3600) log("TaskPrioritize: Last full rescore was \(hoursAgo)h ago") @@ -44,6 +37,11 @@ actor TaskPrioritizationService { } } + /// Set the backend service for Phase 2 server-side reranking. + func configure(backendService: BackendProactiveService) { + self.backendService = backendService + } + // MARK: - Lifecycle func start() { @@ -101,187 +99,48 @@ actor TaskPrioritizationService { // MARK: - Full Rescore (Hourly) - /// Send ALL staged tasks to Gemini, get back only the ones that need re-ranking + /// Request server-side reranking via backend WebSocket. 
private func runFullRescore() async { guard !isScoringInProgress else { log("TaskPrioritize: [FULL] Skipping — scoring already in progress") return } - guard let client = geminiClient else { - log("TaskPrioritize: Skipping full rescore — Gemini client not initialized") + guard let service = backendService else { + log("TaskPrioritize: Skipping full rescore — backend service not configured") return } isScoringInProgress = true defer { isScoringInProgress = false } - log("TaskPrioritize: [FULL] Starting hourly rescore of staged tasks") + log("TaskPrioritize: [FULL] Starting server-side rescore") - // Get ALL staged tasks (not action_items) - let allTasks: [TaskActionItem] do { - allTasks = try await StagedTaskStorage.shared.getAllStagedTasks(limit: 10000) - } catch { - log("TaskPrioritize: [FULL] Failed to fetch staged tasks: \(error)") - return - } - - log("TaskPrioritize: [FULL] Found \(allTasks.count) staged tasks") + let result = try await service.rerankTasks() - guard allTasks.count >= minimumTaskCount else { - log("TaskPrioritize: [FULL] Only \(allTasks.count) staged tasks, skipping") - lastFullRunTime = Date() - return - } - - // Fetch context - let (referenceContext, profile, goals) = await fetchContext() - - // Build the current ranking: tasks ordered by relevanceScore ASC (1 = top) - let sortedTasks = allTasks.sorted { a, b in - let scoreA = a.relevanceScore ?? Int.max - let scoreB = b.relevanceScore ?? Int.max - return scoreA < scoreB - } - - // Build task list for the prompt with current positions - let taskLines = sortedTasks.enumerated().map { (index, task) -> String in - var parts = ["\(index + 1). 
[id:\(task.id)] \(task.description)"] - if let priority = task.priority { - parts.append("[\(priority)]") - } - if let due = task.dueAt { - let formatter = ISO8601DateFormatter() - parts.append("[due: \(formatter.string(from: due))]") + if result.updatedTasks.isEmpty { + log("TaskPrioritize: [FULL] No tasks need re-ranking, current order is good") + lastFullRunTime = Date() + return } - return parts.joined(separator: " ") - }.joined(separator: "\n") - - // Build context sections - var contextParts: [String] = [] - if let profile = profile, !profile.isEmpty { - contextParts.append("USER PROFILE:\n\(profile)") - } + // Parse server response into reranking tuples + let reranks: [(backendId: String, newPosition: Int)] = result.updatedTasks.compactMap { dict in + guard let id = dict["id"] as? String, + let newPos = dict["new_position"] as? Int else { return nil } + return (backendId: id, newPosition: newPos) + } - if !goals.isEmpty { - let goalsText = goals.enumerated().map { (i, goal) in - var text = "\(i + 1). \(goal.title)" - if let desc = goal.description { - text += " — \(desc)" + if !reranks.isEmpty { + do { + try await StagedTaskStorage.shared.applySelectiveReranking(reranks) + log("TaskPrioritize: [FULL] Applied server re-ranking for \(reranks.count) staged tasks") + } catch { + log("TaskPrioritize: [FULL] Failed to apply re-ranking: \(error)") } - text += " (\(Int(goal.progress))% complete)" - return text - }.joined(separator: "\n") - contextParts.append("ACTIVE GOALS:\n\(goalsText)") - } - - if !referenceContext.isEmpty { - contextParts.append(referenceContext) - } - - let contextSection = contextParts.isEmpty ? "" : contextParts.joined(separator: "\n\n") + "\n\n" - - let prompt = """ - Review the user's staged task list (ranked 1 = most important, \(sortedTasks.count) = least important). - - Identify tasks that are MISRANKED — tasks whose current position doesn't match their actual importance. - Only return tasks that need to move. 
Do NOT return tasks that are already well-positioned. - - Consider: - 1. Alignment with the user's goals and current priorities - 2. Time urgency (due date proximity) - 3. Actionability — specific tasks rank higher than vague ones - 4. Real-world importance (financial, health, commitments to others) - 5. Most AI-extracted tasks are noise — push vague/irrelevant tasks down - - \(contextSection)CURRENT TASK RANKING (1 = most important): - \(taskLines) - - Return ONLY the tasks that need re-ranking, with their new position numbers. - New positions should be relative to the current list size (1 to \(sortedTasks.count)). - """ - - let systemPrompt = """ - You are a task prioritization assistant. You review a ranked task list and identify \ - tasks that are misranked. Be selective — only return tasks that genuinely need to move. \ - If the ranking looks reasonable, return an empty list. Be decisive about pushing noise \ - and vague tasks down and promoting urgent, goal-aligned tasks up. - """ - - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "reranked_tasks": .init( - type: "array", - description: "Tasks that need to be moved, with new positions", - items: .init( - type: "object", - properties: [ - "task_id": .init(type: "string", description: "The task ID"), - "new_position": .init(type: "integer", description: "New rank position (1 = most important)") - ], - required: ["task_id", "new_position"] - ) - ), - "reasoning": .init(type: "string", description: "Brief explanation of major ranking changes") - ], - required: ["reranked_tasks", "reasoning"] - ) - - log("TaskPrioritize: [FULL] Sending \(sortedTasks.count) staged tasks to Gemini") - - let responseText: String - do { - responseText = try await client.sendRequest( - prompt: prompt, - systemPrompt: systemPrompt, - responseSchema: responseSchema - ) - } catch { - log("TaskPrioritize: [FULL] Gemini request failed: \(error)") - return - } - - let truncated = 
responseText.prefix(500) - log("TaskPrioritize: [FULL] Gemini response (\(responseText.count) chars): \(truncated)\(responseText.count > 500 ? "..." : "")") - - guard let data = responseText.data(using: .utf8) else { - log("TaskPrioritize: [FULL] Failed to convert response to data") - return - } - - let result: ReRankingResponse - do { - result = try JSONDecoder().decode(ReRankingResponse.self, from: data) - } catch { - log("TaskPrioritize: [FULL] Failed to parse re-ranking response: \(error)") - return - } - - log("TaskPrioritize: [FULL] Gemini returned \(result.rerankedTasks.count) tasks to re-rank") - if !result.reasoning.isEmpty { - log("TaskPrioritize: [FULL] Reasoning: \(result.reasoning.prefix(300))") - } - - // Validate: only keep task IDs that exist in our list - let validIds = Set(allTasks.map { $0.id }) - let validReranks = result.rerankedTasks.filter { validIds.contains($0.taskId) } - - if validReranks.count != result.rerankedTasks.count { - log("TaskPrioritize: [FULL] Filtered out \(result.rerankedTasks.count - validReranks.count) invalid task IDs") - } - - if !validReranks.isEmpty { - let reranks = validReranks.map { (backendId: $0.taskId, newPosition: $0.newPosition) } - do { - try await StagedTaskStorage.shared.applySelectiveReranking(reranks) - log("TaskPrioritize: [FULL] Applied selective re-ranking for \(validReranks.count) staged tasks") - } catch { - log("TaskPrioritize: [FULL] Failed to apply re-ranking: \(error)") } - } else { - log("TaskPrioritize: [FULL] No tasks need re-ranking, current order is good") + } catch { + log("TaskPrioritize: [FULL] Server reranking failed: \(error)") } lastFullRunTime = Date() @@ -304,68 +163,4 @@ actor TaskPrioritizationService { } } - // MARK: - Shared Context Fetching - - private func fetchContext() async -> (referenceContext: String, profile: String?, goals: [Goal]) { - let userProfile = await AIUserProfileService.shared.getLatestProfile() - - let goals: [Goal] - do { - goals = try await 
APIClient.shared.getGoals() - } catch { - log("TaskPrioritize: Failed to fetch goals: \(error)") - goals = [] - } - - let referenceTasks: [TaskActionItem] - do { - referenceTasks = try await ActionItemStorage.shared.getLocalActionItems( - limit: 100, - completed: true - ) - } catch { - log("TaskPrioritize: Failed to fetch reference tasks: \(error)") - referenceTasks = [] - } - let referenceContext = buildReferenceContext(referenceTasks) - - return (referenceContext, userProfile?.profileText, goals) - } - - // MARK: - Context Builders - - private func buildReferenceContext(_ tasks: [TaskActionItem]) -> String { - guard !tasks.isEmpty else { return "" } - - let completed = tasks.filter { !($0.description.isEmpty) }.prefix(50) - guard !completed.isEmpty else { return "" } - - let lines = completed.map { task -> String in - "- [completed] \(task.description)" - }.joined(separator: "\n") - - return "TASKS THE USER HAS COMPLETED (for reference — do NOT rank these):\n\(lines)" - } -} - -// MARK: - Response Models - -private struct ReRankingResponse: Codable { - let rerankedTasks: [ReRankedTask] - let reasoning: String - - struct ReRankedTask: Codable { - let taskId: String - let newPosition: Int - - enum CodingKeys: String, CodingKey { - case taskId = "task_id" - case newPosition = "new_position" - } - } - - enum CodingKeys: String, CodingKey { - case rerankedTasks = "reranked_tasks" - case reasoning - } } From 289e41abbd6731cdb875aed0ddf8ef6b41257fc8 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:33 +0100 Subject: [PATCH 030/163] Wire AIUserProfileService thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace 2-stage Gemini profile generation with backendService.requestProfile(). Remove fetchDataSources, buildPrompt, buildConsolidationPrompt — server fetches user data from Firestore and generates profile server-side. 
Co-Authored-By: Claude Opus 4.6 --- .../Services/AIUserProfileService.swift | 307 ++---------------- 1 file changed, 20 insertions(+), 287 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift index ec36f82ece..111615bf75 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift @@ -34,7 +34,7 @@ extension AIUserProfileRecord: TableDocumented { actor AIUserProfileService { static let shared = AIUserProfileService() - private let model = "gemini-pro-latest" + private var backendService: BackendProactiveService? private let maxProfileLength = 10000 /// Whether profile generation is currently in progress @@ -48,6 +48,11 @@ actor AIUserProfileService { _dbQueue = nil } + /// Set the backend service for Phase 2 server-side profile generation. + func configure(backendService: BackendProactiveService) { + self.backendService = backendService + } + // MARK: - Database Access private func ensureDB() async throws -> DatabasePool { @@ -211,314 +216,42 @@ actor AIUserProfileService { }) ?? [] } - /// Generate a new AI user profile from all available data sources + /// Generate a new AI user profile via backend WebSocket. + /// The backend fetches all user data from Firestore and generates the profile server-side. func generateProfile() async throws -> AIUserProfileRecord { guard !isGenerating else { throw ProfileError.alreadyGenerating } + guard let service = backendService else { + throw ProfileError.databaseNotAvailable + } isGenerating = true defer { isGenerating = false } - log("AIUserProfileService: Starting profile generation") - - // 1. Fetch all data sources in parallel - let (memories, tasks, goals, conversations, messages) = await fetchDataSources() - - // 2. 
Count total data items - let dataSourcesUsed = memories.count + tasks.count + goals.count + conversations.count + messages.count - log("AIUserProfileService: Fetched \(dataSourcesUsed) data items (memories=\(memories.count), tasks=\(tasks.count), goals=\(goals.count), convos=\(conversations.count), messages=\(messages.count))") - - guard dataSourcesUsed > 0 else { - throw ProfileError.insufficientData - } - - // 3. Build prompt - let prompt = buildPrompt(memories: memories, tasks: tasks, goals: goals, conversations: conversations, messages: messages) - - // 4. Call Gemini - let gemini = try GeminiClient(model: model) - let systemPrompt = """ - You are generating a structured user profile that will be injected as context into AI pipelines \ - (task extraction, goal extraction, memory extraction) that analyze the user's screen and audio activity. - - OUTPUT FORMAT: - - A flat list of factual statements, one per line, prefixed with "- " - - Each statement must be a concrete fact directly supported by the provided data - - No prose, no paragraphs, no headers, no markdown formatting - - No adjectives like "passionate", "dedicated", "impressive" - - Write in third person ("User works at...", not "You work at...") - - WHAT TO INCLUDE (only if clearly supported by the data): - - Full name, role, company, industry - - Current projects and what tools/apps they use for each - - Key people they interact with (names, roles, relationship) - - Active goals and their progress - - Recurring meetings, deadlines, routines - - Communication platforms they use (Slack, email, iMessage, etc.) 
- - Technical stack, programming languages, frameworks - - Topics they frequently discuss or research - - Pending tasks and commitments to others - - Time zone, work schedule patterns - - CRITICAL RULES: - - ONLY include facts that are directly evidenced in the provided data - - If a category has no supporting data, skip it entirely — do not guess or infer - - Do NOT hallucinate names, roles, companies, or relationships not present in the data - - Do NOT add personality descriptions or subjective assessments - - When uncertain, omit rather than speculate - - NEVER fabricate email addresses, phone numbers, URLs, or contact information - - The provided data contains NO email addresses — do not invent any - - If you cannot find a piece of information verbatim in the data, do not include it - - The output MUST be under 2000 characters total. - """ - - let stageOneText = try await gemini.sendTextRequest(prompt: prompt, systemPrompt: systemPrompt) - log("AIUserProfileService: Stage 1 complete (\(stageOneText.count) chars)") - - // 5. Stage 2 — Consolidate with past profiles for holistic view - let pastProfiles = await getAllProfiles(limit: 5) - let finalText: String - if pastProfiles.isEmpty { - finalText = stageOneText - } else { - let consolidationPrompt = buildConsolidationPrompt( - newProfile: stageOneText, - pastProfiles: pastProfiles - ) - let consolidationSystemPrompt = """ - You are merging a newly generated user profile with historical profiles to create \ - one holistic, up-to-date user profile. This profile is injected as context into AI pipelines \ - (task extraction, goal extraction, memory extraction) that analyze the user's screen and audio activity. 
- - OUTPUT FORMAT: - - A flat list of factual statements, one per line, prefixed with "- " - - Each statement must be a concrete fact - - No prose, no paragraphs, no headers, no markdown formatting - - No adjectives or subjective assessments - - Write in third person - - MERGE RULES: - - The NEW profile reflects today's data and takes priority for current state - - Past profiles provide historical context — retain facts that are still relevant - - If a fact from the past contradicts the new profile, use the new one - - Remove outdated information (completed tasks, past deadlines, old routines) - - Keep stable facts (name, role, company, key relationships, tech stack) - - Accumulate knowledge: if past profiles mention people, projects, or patterns \ - not in today's data, keep them if they seem ongoing - - Do NOT hallucinate — only include facts present in the provided profiles - - Do NOT add commentary about changes or evolution over time - - The output MUST be under 2000 characters total. - """ - finalText = try await gemini.sendTextRequest( - prompt: consolidationPrompt, - systemPrompt: consolidationSystemPrompt - ) - log("AIUserProfileService: Stage 2 consolidation complete (\(finalText.count) chars)") - } + log("AIUserProfileService: Requesting server-side profile generation") - // 6. Truncate if needed - let truncated = String(finalText.prefix(maxProfileLength)) + let profileText = try await service.requestProfile() + let truncated = String(profileText.prefix(maxProfileLength)) let generatedAt = Date() - // 6. Save to database + log("AIUserProfileService: Received profile from backend (\(truncated.count) chars)") + + // Save to local database let db = try await ensureDB() let record = AIUserProfileRecord( profileText: truncated, - dataSourcesUsed: dataSourcesUsed, - backendSynced: false, + dataSourcesUsed: 0, + backendSynced: true, // Backend already has it generatedAt: generatedAt ) try await db.write { database in try record.insert(database) } - // 7. 
Sync to backend (fire-and-forget) - let recordId = record.id - Task { - do { - try await APIClient.shared.syncAIUserProfile( - profileText: truncated, - generatedAt: generatedAt, - dataSourcesUsed: dataSourcesUsed - ) - // Mark as synced - if let id = recordId, let db = try? await self.ensureDB() { - _ = try? await db.write { database in - try database.execute( - sql: "UPDATE ai_user_profiles SET backendSynced = 1 WHERE id = ?", - arguments: [id] - ) - } - } - log("AIUserProfileService: Synced profile to backend") - } catch { - log("AIUserProfileService: Failed to sync profile to backend: \(error.localizedDescription)") - } - } - - log("AIUserProfileService: Profile generated successfully (\(truncated.count) chars, \(dataSourcesUsed) data items)") + log("AIUserProfileService: Profile saved to local DB") return record } - // MARK: - Data Fetching - - private func fetchDataSources() async -> ( - memories: [String], - tasks: [String], - goals: [String], - conversations: [String], - messages: [String] - ) { - async let memoriesTask = fetchMemories() - async let tasksTask = fetchTasks() - async let goalsTask = fetchGoals() - async let conversationsTask = fetchConversations() - async let messagesTask = fetchMessages() - - let memories = await memoriesTask - let tasks = await tasksTask - let goals = await goalsTask - let conversations = await conversationsTask - let messages = await messagesTask - - return (memories, tasks, goals, conversations, messages) - } - - private func fetchMemories() async -> [String] { - do { - let memories = try await APIClient.shared.getMemories(limit: 100) - return memories.map { "[\($0.category.rawValue)] \($0.content)" } - } catch { - log("AIUserProfileService: Failed to fetch memories: \(error.localizedDescription)") - return [] - } - } - - private func fetchTasks() async -> [String] { - do { - let response = try await APIClient.shared.getActionItems(limit: 50) - return response.items.map { item in - let status = item.completed ? 
"done" : "todo" - let priority = item.priority ?? "medium" - return "[\(status)/\(priority)] \(item.description)" - } - } catch { - log("AIUserProfileService: Failed to fetch tasks: \(error.localizedDescription)") - return [] - } - } - - private func fetchGoals() async -> [String] { - do { - let goals = try await APIClient.shared.getGoals() - return goals.filter { $0.isActive }.map { goal in - let progress = goal.targetValue > 0 ? Int((goal.currentValue / goal.targetValue) * 100) : 0 - return "\(goal.title) (\(progress)% complete)" - } - } catch { - log("AIUserProfileService: Failed to fetch goals: \(error.localizedDescription)") - return [] - } - } - - private func fetchConversations() async -> [String] { - do { - let sevenDaysAgo = Calendar.current.date(byAdding: .day, value: -7, to: Date()) - let conversations = try await APIClient.shared.getConversations( - limit: 20, - startDate: sevenDaysAgo - ) - return conversations.compactMap { convo in - let title = convo.structured.title - let summary = convo.structured.overview - guard !title.isEmpty else { return nil } - return "\(title): \(summary)" - } - } catch { - log("AIUserProfileService: Failed to fetch conversations: \(error.localizedDescription)") - return [] - } - } - - private func fetchMessages() async -> [String] { - do { - let messages = try await APIClient.shared.getMessages(limit: 30) - return messages.map { "[\($0.sender)] \($0.text)" } - } catch { - log("AIUserProfileService: Failed to fetch messages: \(error.localizedDescription)") - return [] - } - } - - // MARK: - Prompt Building - - private func buildPrompt( - memories: [String], - tasks: [String], - goals: [String], - conversations: [String], - messages: [String] - ) -> String { - var sections: [String] = [] - - if !memories.isEmpty { - sections.append("## Memories about the user\n\(memories.joined(separator: "\n"))") - } - - if !tasks.isEmpty { - sections.append("## Recent tasks\n\(tasks.joined(separator: "\n"))") - } - - if !goals.isEmpty { - 
sections.append("## Active goals\n\(goals.joined(separator: "\n"))") - } - - if !conversations.isEmpty { - sections.append("## Recent conversations (past 7 days)\n\(conversations.joined(separator: "\n"))") - } - - if !messages.isEmpty { - sections.append("## Recent AI chat messages\n\(messages.joined(separator: "\n"))") - } - - return """ - Generate a factual user profile from the following data. \ - Output a flat list of concrete facts (one per line, prefixed with "- "). \ - This profile will be used as context for AI pipelines that analyze the user's screen and audio activity \ - to extract tasks, goals, and memories. Focus on facts that help identify who is who, what projects are active, \ - and what the user's current priorities are. Under 2000 characters. - - \(sections.joined(separator: "\n\n")) - """ - } - - private func buildConsolidationPrompt( - newProfile: String, - pastProfiles: [AIUserProfileRecord] - ) -> String { - let dateFormatter = DateFormatter() - dateFormatter.dateStyle = .medium - dateFormatter.timeStyle = .none - - var pastSection = "" - for profile in pastProfiles { - let dateStr = dateFormatter.string(from: profile.generatedAt) - pastSection += "--- Profile from \(dateStr) ---\n\(profile.profileText)\n\n" - } - - return """ - Merge the following into one holistic user profile. Under 2000 characters. - - === NEW PROFILE (generated today from latest data) === - \(newProfile) - - === PAST PROFILES (oldest to newest, up to 5) === - \(pastSection) - """ - } - // MARK: - Errors enum ProfileError: LocalizedError { From 4c92d5ba10454de1f3c9fd18bbe51db94289c9e4 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:40 +0100 Subject: [PATCH 031/163] Wire ProactiveAssistantsPlugin to pass backendService to all assistants (#5396) Pass shared BackendProactiveService to all 4 assistants and 3 text-only services. Remove do/catch since inits no longer throw. Update AdviceTestRunnerWindow fallback creation. 
Co-Authored-By: Claude Opus 4.6 --- .../ProactiveAssistantsPlugin.swift | 88 +++++++++---------- .../UI/AdviceTestRunnerWindow.swift | 22 ++--- 2 files changed, 47 insertions(+), 63 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift index 4bc4cfe752..268d260f62 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift @@ -314,62 +314,58 @@ public class ProactiveAssistantsPlugin: NSObject { proactiveService.connect() backendProactiveService = proactiveService - do { - focusAssistant = FocusAssistant( - backendService: proactiveService, - onAlert: { [weak self] message in - self?.sendEvent(type: "alert", data: ["message": message]) - }, - onStatusChange: { [weak self] status in - Task { @MainActor in - self?.lastStatus = status - self?.sendEvent(type: "statusChange", data: ["status": status.rawValue]) - } - }, - onRefocus: { - Task { @MainActor in - OverlayService.shared.showGlowAroundActiveWindow(colorMode: .focused) - } - }, - onDistraction: { - Task { @MainActor in - OverlayService.shared.showGlowAroundActiveWindow(colorMode: .distracted) - } + focusAssistant = FocusAssistant( + backendService: proactiveService, + onAlert: { [weak self] message in + self?.sendEvent(type: "alert", data: ["message": message]) + }, + onStatusChange: { [weak self] status in + Task { @MainActor in + self?.lastStatus = status + self?.sendEvent(type: "statusChange", data: ["status": status.rawValue]) + } + }, + onRefocus: { + Task { @MainActor in + OverlayService.shared.showGlowAroundActiveWindow(colorMode: .focused) + } + }, + onDistraction: { + Task { @MainActor in + OverlayService.shared.showGlowAroundActiveWindow(colorMode: .distracted) } - ) - - if let focus = focusAssistant { - AssistantCoordinator.shared.register(focus) } + ) - taskAssistant = 
try TaskAssistant() + if let focus = focusAssistant { + AssistantCoordinator.shared.register(focus) + } - if let task = taskAssistant { - AssistantCoordinator.shared.register(task) - } + taskAssistant = TaskAssistant(backendService: proactiveService) - Task { await TaskDeduplicationService.shared.start() } - Task { await TaskPrioritizationService.shared.start() } - Task { await TaskPromotionService.shared.start() } + if let task = taskAssistant { + AssistantCoordinator.shared.register(task) + } - adviceAssistant = try AdviceAssistant() + // Configure text-only services with backend service + Task { await TaskDeduplicationService.shared.configure(backendService: proactiveService) } + Task { await TaskPrioritizationService.shared.configure(backendService: proactiveService) } + Task { await AIUserProfileService.shared.configure(backendService: proactiveService) } - if let advice = adviceAssistant { - AssistantCoordinator.shared.register(advice) - } + Task { await TaskDeduplicationService.shared.start() } + Task { await TaskPrioritizationService.shared.start() } + Task { await TaskPromotionService.shared.start() } - memoryAssistant = try MemoryAssistant() + adviceAssistant = AdviceAssistant(backendService: proactiveService) - if let memory = memoryAssistant { - AssistantCoordinator.shared.register(memory) - } + if let advice = adviceAssistant { + AssistantCoordinator.shared.register(advice) + } - } catch { - log("ProactiveAssistantsPlugin: Failed to initialize assistants: \(error.localizedDescription)") - logError("ProactiveAssistantsPlugin: Assistant initialization failed", error: error) - isStartingMonitoring = false - completion(false, error.localizedDescription) - return + memoryAssistant = MemoryAssistant(backendService: proactiveService) + + if let memory = memoryAssistant { + AssistantCoordinator.shared.register(memory) } // Get initial app state diff --git a/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift 
b/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift index 925c73f312..a87b4582b5 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift @@ -441,17 +441,9 @@ struct AdviceTestRunnerView: View { adviceAssistant = existing log("AdviceTestRunner: Using existing AdviceAssistant from coordinator") } else { - do { - adviceAssistant = try AdviceAssistant() - log("AdviceTestRunner: Created fresh AdviceAssistant instance") - } catch { - log("AdviceTestRunner: ERROR - Failed to create AdviceAssistant: \(error)") - await MainActor.run { - statusMessage = "Failed to create Advice Assistant: \(error.localizedDescription)" - isRunning = false - } - return - } + let service = BackendProactiveService(); service.connect() + adviceAssistant = AdviceAssistant(backendService: service) + log("AdviceTestRunner: Created fresh AdviceAssistant instance") } // Get excluded apps @@ -647,12 +639,8 @@ enum AdviceTestRunner { if let existing = coordAssistant as? AdviceAssistant { adviceAssistant = existing } else { - do { - adviceAssistant = try AdviceAssistant() - } catch { - log("AdviceTestCLI: ERROR — Failed to create AdviceAssistant: \(error)") - return - } + let service = BackendProactiveService(); service.connect() + adviceAssistant = AdviceAssistant(backendService: service) } // Get excluded apps From aba6be44ba07c87369728085f1fc5fdb69b87372 Mon Sep 17 00:00:00 2001 From: beastoin Date: Mon, 9 Mar 2026 05:20:34 +0100 Subject: [PATCH 032/163] Wire LiveNotesMonitor thin client for Phase 2 (#5396) Replace direct GeminiClient usage with BackendProactiveService. Uses configure(backendService:) singleton pattern matching other text-based services. Prompt logic moves server-side. 
Co-Authored-By: Claude Opus 4.6 --- .../Sources/LiveNotes/LiveNotesMonitor.swift | 67 ++++++------------- 1 file changed, 20 insertions(+), 47 deletions(-) diff --git a/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift b/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift index f859f973a4..fdcc7944a9 100644 --- a/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift +++ b/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift @@ -45,23 +45,12 @@ class LiveNotesMonitor: ObservableObject { /// Existing notes for context (to avoid repetition) private var existingNotesContext: [String] = [] - /// GeminiClient for AI generation (lazily initialized) - private var geminiClient: GeminiClient? + /// Backend service for AI generation (injected via configure()) + private var backendService: BackendProactiveService? /// Cancellables for subscriptions private var cancellables = Set() - /// AI prompt for note generation (from m13v/meeting) - private let noteGenerationPrompt = """ - generate a single, concise note about what happened in this segment. - be factual and specific. - focus on the key point or action item. - keep it a few word sentence. - do not use quotes. - do not use wrapping words like "discussion on", jump straight into note. - avoid repeating information from existing notes. 
- """ - private init() { // Subscribe to transcript changes LiveTranscriptMonitor.shared.$segments @@ -72,6 +61,12 @@ class LiveNotesMonitor: ObservableObject { .store(in: &cancellables) } + /// Configure with backend service (call before startSession) + func configure(backendService: BackendProactiveService) { + self.backendService = backendService + log("LiveNotesMonitor: Configured with BackendProactiveService") + } + // MARK: - Session Lifecycle /// Start a new notes session @@ -85,15 +80,8 @@ class LiveNotesMonitor: ObservableObject { lastProcessedSegmentEnd = nil existingNotesContext = [] - // Initialize Gemini client if not already done - if geminiClient == nil { - do { - // Use Gemini 3 Pro for better note generation quality - geminiClient = try GeminiClient(model: "gemini-pro-latest") - log("LiveNotesMonitor: GeminiClient initialized with gemini-pro-latest") - } catch { - logError("LiveNotesMonitor: Failed to initialize GeminiClient", error: error) - } + if backendService == nil { + log("LiveNotesMonitor: WARNING — backendService not configured, AI notes disabled") } // Load any existing notes from DB (for crash recovery) @@ -252,10 +240,10 @@ class LiveNotesMonitor: ObservableObject { } } - /// Generate an AI note from recent transcript + /// Generate an AI note from recent transcript via backend private func generateNote(from segments: [SpeakerSegment]) { guard let sessionId = currentSessionId, - let client = geminiClient, + let service = backendService, !isGenerating else { return } isGenerating = true @@ -265,34 +253,21 @@ class LiveNotesMonitor: ObservableObject { let segmentStartOrder = max(0, currentSegmentOrder - 3) let segmentEndOrder = currentSegmentOrder - // Build context from existing notes - let existingNotesText = existingNotesContext.isEmpty - ? "No existing notes yet." + // Build session context from existing notes + let sessionContext = existingNotesContext.isEmpty + ? 
"" : "Existing notes:\n" + existingNotesContext.map { "- \($0)" }.joined(separator: "\n") - let prompt = """ - Transcript segment: - \(recentText) - - \(existingNotesText) - - \(noteGenerationPrompt) - """ - Task { do { - let response = try await client.sendTextRequest( - prompt: prompt, - systemPrompt: "You are a concise note-taker. Generate a single short note (3-10 words) about the key point in the transcript. Do not use quotes. Be direct and specific." - ) + let noteText = try await service.generateLiveNote(text: recentText, sessionContext: sessionContext) - // Clean up the response - let noteText = response + let cleaned = noteText .trimmingCharacters(in: .whitespacesAndNewlines) .replacingOccurrences(of: "\"", with: "") .replacingOccurrences(of: "'", with: "") - guard !noteText.isEmpty else { + guard !cleaned.isEmpty else { await MainActor.run { self.isGenerating = false } return } @@ -300,7 +275,7 @@ class LiveNotesMonitor: ObservableObject { // Save to DB let record = try await NoteStorage.shared.createNote( sessionId: sessionId, - text: noteText, + text: cleaned, isAiGenerated: true, segmentStartOrder: segmentStartOrder, segmentEndOrder: segmentEndOrder @@ -309,8 +284,7 @@ class LiveNotesMonitor: ObservableObject { if let note = record.toLiveNote() { await MainActor.run { self.notes.append(note) - self.existingNotesContext.append(noteText) - // Trim context to prevent unbounded growth (keep most recent notes) + self.existingNotesContext.append(cleaned) if self.existingNotesContext.count > self.maxExistingNotesContext { self.existingNotesContext.removeFirst(self.existingNotesContext.count - self.maxExistingNotesContext) } @@ -321,7 +295,6 @@ class LiveNotesMonitor: ObservableObject { await MainActor.run { self.isGenerating = false } } } catch let dbError as DatabaseError where dbError.resultCode == .SQLITE_CONSTRAINT { - // Session was deleted during async AI generation — not an error log("LiveNotesMonitor: Session \(sessionId) deleted during note 
generation, skipping") await MainActor.run { self.isGenerating = false } } catch { From b985003d2be66cac3f9e90e5fdb5499003b936c3 Mon Sep 17 00:00:00 2001 From: beastoin Date: Mon, 9 Mar 2026 05:20:35 +0100 Subject: [PATCH 033/163] Wire LiveNotesMonitor in ProactiveAssistantsPlugin (#5396) Add configure(backendService:) call for LiveNotesMonitor alongside other singleton text-based services. Co-Authored-By: Claude Opus 4.6 --- .../Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift index 268d260f62..5c62372560 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift @@ -351,6 +351,7 @@ public class ProactiveAssistantsPlugin: NSObject { Task { await TaskDeduplicationService.shared.configure(backendService: proactiveService) } Task { await TaskPrioritizationService.shared.configure(backendService: proactiveService) } Task { await AIUserProfileService.shared.configure(backendService: proactiveService) } + Task { await LiveNotesMonitor.shared.configure(backendService: proactiveService) } Task { await TaskDeduplicationService.shared.start() } Task { await TaskPrioritizationService.shared.start() } From 0db4d3d334014c04266dc164c77c9d8fe0789a09 Mon Sep 17 00:00:00 2001 From: beastoin Date: Wed, 4 Mar 2026 14:37:42 +0100 Subject: [PATCH 034/163] Use OMI_API_URL for auth instead of dedicated Cloud Run service Replaces the hardcoded omi-desktop-auth Cloud Run URL with the OMI_API_URL environment variable, matching APIClient.baseURL resolution. Python backend already has identical /v1/auth/* endpoints. 
Closes #5359 --- desktop/Desktop/Sources/AuthService.swift | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/desktop/Desktop/Sources/AuthService.swift b/desktop/Desktop/Sources/AuthService.swift index 84e94b605d..c44eee8da9 100644 --- a/desktop/Desktop/Sources/AuthService.swift +++ b/desktop/Desktop/Sources/AuthService.swift @@ -43,8 +43,17 @@ class AuthService { private var appleSignInDelegate: AppleSignInDelegate? // API Configuration - // Production: Cloud Run backend - private let apiBaseURL: String = "https://omi-desktop-auth-208440318997.us-central1.run.app/" + // Auth uses the same backend as the rest of the app (OMI_API_URL) + private var apiBaseURL: String { + // Match APIClient.baseURL resolution: getenv() first, then ProcessInfo fallback + if let cString = getenv("OMI_API_URL"), let url = String(validatingUTF8: cString), !url.isEmpty { + return url.hasSuffix("/") ? url : url + "/" + } + if let envURL = ProcessInfo.processInfo.environment["OMI_API_URL"], !envURL.isEmpty { + return envURL.hasSuffix("/") ? envURL : envURL + "/" + } + fatalError("OMI_API_URL not set. Ensure .env file is present in app bundle.") + } private var redirectURI: String { return "\(urlScheme)://auth/callback" } @@ -350,7 +359,8 @@ class AuthService { return } - NSLog("OMI AUTH: Starting Sign in with %@ (Web OAuth)", provider) + let authHost = URL(string: apiBaseURL)?.host ?? "unknown" + NSLog("OMI AUTH: Starting Sign in with %@ (Web OAuth) via %@", provider, authHost) isLoading = true error = nil From fe6d87aff3c690d84beaafc3f62dc831c4e54562 Mon Sep 17 00:00:00 2001 From: beastoin Date: Wed, 4 Mar 2026 14:37:46 +0100 Subject: [PATCH 035/163] Use dynamic redirect_uri in auth callback template Pass redirect_uri from the auth session to the callback HTML template instead of hardcoding omi://auth/callback. 
This enables desktop apps (which use omi-computer://auth/callback) to receive OAuth callbacks correctly when authenticating through the Python backend. --- backend/templates/auth_callback.html | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/templates/auth_callback.html b/backend/templates/auth_callback.html index 7f5a820e6b..de22660bc5 100644 --- a/backend/templates/auth_callback.html +++ b/backend/templates/auth_callback.html @@ -108,8 +108,9 @@

Authentication Successful

spinnerElement.style.display = 'none'; messageElement.textContent = 'Please close this window and try again.'; } else if (code) { - // Build the custom scheme redirect URL - let redirectUrl = 'omi://auth/callback?code=' + encodeURIComponent(code); + // Build the custom scheme redirect URL using the redirect_uri from the auth session + const redirectUri = "{{ redirect_uri }}"; + let redirectUrl = redirectUri + '?code=' + encodeURIComponent(code); if (state) { redirectUrl += '&state=' + encodeURIComponent(state); } From 05f43bf6c80e8e1d213e65f67e1fab5bc2271448 Mon Sep 17 00:00:00 2001 From: beastoin Date: Wed, 4 Mar 2026 14:37:51 +0100 Subject: [PATCH 036/163] Pass redirect_uri from session to auth callback template Both Google and Apple callback endpoints now pass the session's redirect_uri to the auth_callback.html template, enabling dynamic custom URL scheme redirects per client (mobile vs desktop). --- backend/routers/auth.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/routers/auth.py b/backend/routers/auth.py index 3e81cbd93b..083910b066 100644 --- a/backend/routers/auth.py +++ b/backend/routers/auth.py @@ -96,6 +96,7 @@ async def auth_callback_google( "request": request, "code": auth_code, "state": session_data['state'] or '', + "redirect_uri": session_data.get('redirect_uri', 'omi://auth/callback'), }, ) @@ -134,6 +135,7 @@ async def auth_callback_apple_post( "request": request, "code": auth_code, "state": session_data['state'] or '', + "redirect_uri": session_data.get('redirect_uri', 'omi://auth/callback'), }, ) From 103ef6f3ed24d114348ec5659266e62bf2244dbe Mon Sep 17 00:00:00 2001 From: beastoin Date: Wed, 4 Mar 2026 14:42:02 +0100 Subject: [PATCH 037/163] Validate redirect_uri against allowed app URL schemes Add server-side validation at /v1/auth/authorize to reject redirect_uri values that don't match allowed app schemes (omi://, omi-computer://, omi-computer-dev://). Also fix empty string fallback with 'or' operator. 
--- backend/routers/auth.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/backend/routers/auth.py b/backend/routers/auth.py index 083910b066..fb05b38c96 100644 --- a/backend/routers/auth.py +++ b/backend/routers/auth.py @@ -44,6 +44,11 @@ async def auth_authorize( if provider not in ['google', 'apple']: raise HTTPException(status_code=400, detail="Unsupported provider") + # Validate redirect_uri against allowed app URL schemes + ALLOWED_REDIRECT_SCHEMES = ('omi://', 'omi-computer://', 'omi-computer-dev://') + if not redirect_uri or not any(redirect_uri.startswith(s) for s in ALLOWED_REDIRECT_SCHEMES): + raise HTTPException(status_code=400, detail="Invalid redirect_uri: must use an allowed app URL scheme") + # Store session for auth flow session_id = str(uuid.uuid4()) session_data = { @@ -96,7 +101,7 @@ async def auth_callback_google( "request": request, "code": auth_code, "state": session_data['state'] or '', - "redirect_uri": session_data.get('redirect_uri', 'omi://auth/callback'), + "redirect_uri": session_data.get('redirect_uri') or 'omi://auth/callback', }, ) @@ -135,7 +140,7 @@ async def auth_callback_apple_post( "request": request, "code": auth_code, "state": session_data['state'] or '', - "redirect_uri": session_data.get('redirect_uri', 'omi://auth/callback'), + "redirect_uri": session_data.get('redirect_uri') or 'omi://auth/callback', }, ) From b2faf43a5b3545d73fb177a239fb92f5eb5e8beb Mon Sep 17 00:00:00 2001 From: beastoin Date: Wed, 4 Mar 2026 14:42:07 +0100 Subject: [PATCH 038/163] Add client-side redirect scheme validation and safe serialization Use |tojson filter for safe template variable serialization. Add defense-in-depth scheme validation in JavaScript before redirect. Block redirect and manual link for disallowed schemes. 
--- backend/templates/auth_callback.html | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/backend/templates/auth_callback.html b/backend/templates/auth_callback.html index de22660bc5..e4a591b0f7 100644 --- a/backend/templates/auth_callback.html +++ b/backend/templates/auth_callback.html @@ -109,17 +109,28 @@

Authentication Successful

messageElement.textContent = 'Please close this window and try again.'; } else if (code) { // Build the custom scheme redirect URL using the redirect_uri from the auth session - const redirectUri = "{{ redirect_uri }}"; + const redirectUri = {{ redirect_uri|tojson }}; + + // Validate redirect scheme before use (defense-in-depth; server also validates) + const ALLOWED_SCHEMES = ['omi://', 'omi-computer://', 'omi-computer-dev://']; + const isAllowedScheme = ALLOWED_SCHEMES.some(s => redirectUri.startsWith(s)); + if (!isAllowedScheme) { + errorElement.textContent = 'Invalid redirect scheme.'; + spinnerElement.style.display = 'none'; + messageElement.textContent = 'Please close this window and try again.'; + } + let redirectUrl = redirectUri + '?code=' + encodeURIComponent(code); if (state) { redirectUrl += '&state=' + encodeURIComponent(state); } // Set up manual link - manualLinkElement.href = redirectUrl; + manualLinkElement.href = isAllowedScheme ? redirectUrl : '#'; // Attempt automatic redirect try { + if (!isAllowedScheme) throw new Error('Blocked redirect to disallowed scheme'); console.log('Redirecting to:', redirectUrl); window.location.href = redirectUrl; From 53318e0fe443c992392a95f51d9a11b4347e25f5 Mon Sep 17 00:00:00 2001 From: beastoin Date: Wed, 4 Mar 2026 14:48:04 +0100 Subject: [PATCH 039/163] Add auth endpoint tests for redirect_uri validation and template rendering 15 tests covering: - Redirect_uri allowlist validation (rejects https, javascript, data, ftp, empty) - Allowed schemes pass (omi://, omi-computer://, omi-computer-dev://) - Google/Apple callback uses session redirect_uri in template - Fallback to default omi://auth/callback when missing - XSS safety: JSON-escaped redirect_uri prevents script injection --- backend/tests/unit/test_auth_routes.py | 242 +++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 backend/tests/unit/test_auth_routes.py diff --git a/backend/tests/unit/test_auth_routes.py 
"""Tests for auth endpoint redirect_uri validation and callback template rendering."""
import sys
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from httpx import ASGITransport, AsyncClient

# Dummy secret so importing routers.auth does not fail on missing config.
os.environ.setdefault(
    "ENCRYPTION_SECRET",
    "omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv",
)

# Stub heavy dependencies before importing the module under test.
# Order matters: parent packages before child packages.
for _stubbed in (
    'firebase_admin',
    'firebase_admin.auth',
    'firebase_admin.firestore',
    'firebase_admin.messaging',
    'google.cloud',
    'google.cloud.firestore',
    'google.cloud.firestore_v1',
    'google.auth',
    'google.auth.transport.requests',
):
    sys.modules.setdefault(_stubbed, MagicMock())

from fastapi import FastAPI

from routers.auth import router as auth_router

# Minimal test app mounting only the auth router.
_test_app = FastAPI()
_test_app.include_router(auth_router)


def _client(**kwargs):
    """Return an in-process async HTTP client bound to the test app."""
    return AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test", **kwargs)


# --- /v1/auth/authorize redirect_uri validation ---

class TestAuthorizeRedirectUriValidation:
    """Tests for redirect_uri allowlist at /v1/auth/authorize."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("bad_uri", [
        "https://evil.com/steal",
        "javascript:alert(1)",
        "data:text/html,",
        "ftp://example.com",
        "",
    ])
    async def test_rejects_disallowed_redirect_uri(self, bad_uri):
        query = {"provider": "google", "redirect_uri": bad_uri, "state": "test"}
        async with _client() as http:
            response = await http.get("/v1/auth/authorize", params=query)
        assert response.status_code == 400
        assert "allowed app URL scheme" in response.json()["detail"]

    @pytest.mark.asyncio
    @pytest.mark.parametrize("good_uri", [
        "omi://auth/callback",
        "omi-computer://auth/callback",
        "omi-computer-dev://auth/callback",
    ])
    @patch("routers.auth.set_auth_session")
    async def test_accepts_allowed_redirect_schemes(self, mock_set_session, good_uri):
        env = {
            "GOOGLE_CLIENT_ID": "test-client-id",
            "GOOGLE_CLIENT_SECRET": "test-secret",
            "BASE_API_URL": "https://api.omi.me",
            "APPLE_CLIENT_ID": "me.omi.web",
            "APPLE_TEAM_ID": "TEST",
            "APPLE_KEY_ID": "TEST",
            "APPLE_PRIVATE_KEY": "TEST",
        }

        def fake_getenv(key, *args):
            # Mirror os.getenv's default-argument behavior.
            return env.get(key, args[0] if args else None)

        with patch("routers.auth.os.getenv", side_effect=fake_getenv):
            async with _client(follow_redirects=False) as http:
                response = await http.get(
                    "/v1/auth/authorize",
                    params={"provider": "google", "redirect_uri": good_uri, "state": "test123"},
                )
            # Should redirect to Google OAuth (307) or return 200, not 400.
            assert response.status_code != 400
            # Verify session was stored with the redirect_uri.
            mock_set_session.assert_called_once()
            session_data = mock_set_session.call_args[0][1]
            assert session_data["redirect_uri"] == good_uri

    @pytest.mark.asyncio
    async def test_rejects_missing_redirect_uri(self):
        async with _client() as http:
            response = await http.get(
                "/v1/auth/authorize",
                params={"provider": "google", "state": "test"},
            )
        # FastAPI returns 422 for missing required query param.
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_rejects_invalid_provider(self):
        async with _client() as http:
            response = await http.get(
                "/v1/auth/authorize",
                params={"provider": "github", "redirect_uri": "omi://auth/callback"},
            )
        assert response.status_code == 400
        assert "Unsupported provider" in response.json()["detail"]


# --- Google callback template rendering ---

class TestGoogleCallbackRedirectUri:
    """Tests for redirect_uri in Google OAuth callback template."""

    @pytest.mark.asyncio
    @patch("routers.auth.get_auth_session")
    @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
    @patch("routers.auth.set_auth_code")
    async def test_uses_session_redirect_uri(self, mock_set_code, mock_exchange, mock_get_session):
        mock_get_session.return_value = {
            "provider": "google",
            "redirect_uri": "omi-computer://auth/callback",
            "state": "test_state",
            "flow_type": "user_auth",
        }
        mock_exchange.return_value = '{"id_token": "test"}'

        async with _client() as http:
            response = await http.get(
                "/v1/auth/callback/google",
                params={"code": "test_code", "state": "test_state"},
            )
        assert response.status_code == 200
        # Template should contain the desktop redirect scheme.
        assert "omi-computer://auth/callback" in response.text

    @pytest.mark.asyncio
    @patch("routers.auth.get_auth_session")
    @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
    @patch("routers.auth.set_auth_code")
    async def test_falls_back_to_default_redirect_uri(self, mock_set_code, mock_exchange, mock_get_session):
        # No redirect_uri in session -> template must fall back to omi:// scheme.
        mock_get_session.return_value = {
            "provider": "google",
            "state": "test_state",
            "flow_type": "user_auth",
        }
        mock_exchange.return_value = '{"id_token": "test"}'

        async with _client() as http:
            response = await http.get(
                "/v1/auth/callback/google",
                params={"code": "test_code", "state": "test_state"},
            )
        assert response.status_code == 200
        assert "omi://auth/callback" in response.text
# --- Apple callback template rendering ---

class TestAppleCallbackRedirectUri:
    """Tests for redirect_uri in Apple OAuth callback template."""

    @pytest.mark.asyncio
    @patch("routers.auth.get_auth_session")
    @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
    @patch("routers.auth.set_auth_code")
    async def test_uses_session_redirect_uri(self, mock_set_code, mock_exchange, mock_get_session):
        mock_get_session.return_value = {
            "provider": "apple",
            "redirect_uri": "omi-computer://auth/callback",
            "state": "test_state",
            "flow_type": "user_auth",
        }
        mock_exchange.return_value = '{"id_token": "test"}'

        form = {"code": "test_code", "state": "test_state"}
        async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as http:
            response = await http.post("/v1/auth/callback/apple", data=form)
        assert response.status_code == 200
        assert "omi-computer://auth/callback" in response.text

    @pytest.mark.asyncio
    @patch("routers.auth.get_auth_session")
    @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
    @patch("routers.auth.set_auth_code")
    async def test_falls_back_to_default_redirect_uri(self, mock_set_code, mock_exchange, mock_get_session):
        # Session without redirect_uri -> template falls back to omi:// scheme.
        mock_get_session.return_value = {
            "provider": "apple",
            "state": "test_state",
            "flow_type": "user_auth",
        }
        mock_exchange.return_value = '{"id_token": "test"}'

        form = {"code": "test_code", "state": "test_state"}
        async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as http:
            response = await http.post("/v1/auth/callback/apple", data=form)
        assert response.status_code == 200
        assert "omi://auth/callback" in response.text


# --- Template XSS safety ---

class TestCallbackTemplateXssSafety:
    """Verify that redirect_uri is safely serialized in the callback template."""

    @pytest.mark.asyncio
    @patch("routers.auth.get_auth_session")
    @patch("routers.auth._exchange_provider_code_for_oauth_credentials", new_callable=AsyncMock)
    @patch("routers.auth.set_auth_code")
    async def test_redirect_uri_json_escaped(self, mock_set_code, mock_exchange, mock_get_session):
        # A redirect_uri containing a double quote exercises JSON escaping.
        mock_get_session.return_value = {
            "provider": "google",
            "redirect_uri": 'omi://auth/callback"test',
            "state": "test_state",
            "flow_type": "user_auth",
        }
        mock_exchange.return_value = '{"id_token": "test"}'

        async with AsyncClient(transport=ASGITransport(app=_test_app), base_url="http://test") as http:
            response = await http.get(
                "/v1/auth/callback/google",
                params={"code": "test_code", "state": "test_state"},
            )
        assert response.status_code == 200
        body = response.text
        # The quote should be JSON-escaped, not raw.
        assert r'omi://auth/callback\"test' in body or r'omi:\/\/auth\/callback\"test' in body
        # Should NOT contain an unescaped quote that breaks out of the JS string.
        assert 'const redirectUri = "omi://auth/callback"test"' not in body
Android-only), the signInWithIdp REST API returns 403. Fall back to decoding the Google id_token JWT, looking up the user via Admin SDK (get_user_by_email), and creating a custom token directly. This makes auth work regardless of API key restrictions. --- backend/routers/auth.py | 97 +++++++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 38 deletions(-) diff --git a/backend/routers/auth.py b/backend/routers/auth.py index fb05b38c96..ea54dc6658 100644 --- a/backend/routers/auth.py +++ b/backend/routers/auth.py @@ -3,6 +3,7 @@ import json import hashlib import time +import base64 import requests import jwt from typing import Optional @@ -386,47 +387,67 @@ async def _generate_custom_token(provider: str, id_token: str, access_token: str Works with any bundle ID - perfect for multiple developers """ try: - # Get Firebase API Key from environment - firebase_api_key = os.getenv('FIREBASE_API_KEY') - if not firebase_api_key: - raise Exception("FIREBASE_API_KEY not configured") - - # Sign in with OAuth credential using Firebase Auth REST API - sign_in_url = f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp?key={firebase_api_key}" - - # Prepare the postBody based on provider - if provider == 'google': - post_body = f'id_token={id_token}&providerId=google.com' - if access_token: - post_body += f'&access_token={access_token}' - elif provider == 'apple': - post_body = f'id_token={id_token}&providerId=apple.com' - if access_token: - post_body += f'&access_token={access_token}' - else: - raise Exception(f"Unsupported provider: {provider}") - - payload = { - 'postBody': post_body, - 'requestUri': 'http://localhost', - 'returnIdpCredential': True, - 'returnSecureToken': True, - } - - # Call Firebase Auth REST API to sign in - response = requests.post(sign_in_url, json=payload) - - if response.status_code != 200: - logger.error(f"Firebase sign-in failed: {sanitize(response.text)}") - raise Exception(f"Firebase sign-in failed: 
status={response.status_code}") - - result = response.json() - firebase_uid = result.get('localId') + firebase_uid = None + # Try REST API first (works when FIREBASE_API_KEY has no app restrictions) + firebase_api_key = os.getenv('FIREBASE_API_KEY') + if firebase_api_key: + sign_in_url = f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp?key={firebase_api_key}" + + if provider == 'google': + post_body = f'id_token={id_token}&providerId=google.com' + if access_token: + post_body += f'&access_token={access_token}' + elif provider == 'apple': + post_body = f'id_token={id_token}&providerId=apple.com' + if access_token: + post_body += f'&access_token={access_token}' + else: + raise Exception(f"Unsupported provider: {provider}") + + payload = { + 'postBody': post_body, + 'requestUri': 'http://localhost', + 'returnIdpCredential': True, + 'returnSecureToken': True, + } + + response = requests.post(sign_in_url, json=payload) + + if response.status_code == 200: + result = response.json() + firebase_uid = result.get('localId') + if firebase_uid: + logger.info(f"Firebase sign-in successful for {provider}, UID: {firebase_uid}") + else: + logger.warning( + f"Firebase REST API sign-in failed (status={response.status_code}), falling back to Admin SDK" + ) + + # Fallback: decode id_token JWT and look up/create user via Admin SDK if not firebase_uid: - raise Exception("No Firebase UID returned from sign-in") + parts = id_token.split('.') + if len(parts) < 2: + raise Exception("Invalid id_token format") + payload_b64 = parts[1] + '=' * (4 - len(parts[1]) % 4) + token_payload = json.loads(base64.urlsafe_b64decode(payload_b64)) + email = token_payload.get('email') + if not email: + raise Exception("No email in id_token") + + # Look up existing Firebase user by email + try: + user = firebase_admin.auth.get_user_by_email(email) + firebase_uid = user.uid + logger.info(f"Found existing Firebase user for {email}, UID: {firebase_uid}") + except 
firebase_admin.auth.UserNotFoundError: + # Create new Firebase user + user = firebase_admin.auth.create_user(email=email, email_verified=True) + firebase_uid = user.uid + logger.info(f"Created new Firebase user for {email}, UID: {firebase_uid}") - logger.info(f"Firebase sign-in successful for {provider}, UID: {firebase_uid}") + if not firebase_uid: + raise Exception("No Firebase UID obtained") # Create custom token for this UID custom_token = firebase_admin.auth.create_custom_token(firebase_uid) From a2dafb0cb67496cf51580ce66fa005a3b77210f5 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:02:35 +0100 Subject: [PATCH 042/163] Add POST /v1/conversations/from-segments with user auth Desktop app uploads transcriptions via this endpoint but Python backend only had it in the developer API (API key auth). This adds a user-auth version to conversations router, reusing the same process_conversation pipeline. Defaults source to 'desktop', accepts timezone and input_device_name fields sent by Swift client. 
class FromSegmentsTranscriptSegment(BaseModel):
    # One transcript segment uploaded by the desktop client.
    text: str
    speaker: Optional[str] = 'SPEAKER_00'
    speaker_id: Optional[int] = None
    is_user: bool = False
    person_id: Optional[str] = None
    start: float  # seconds from conversation start
    end: float  # seconds from conversation start; must be > start


class CreateConversationFromSegmentsRequest(BaseModel):
    # Payload for POST /v1/conversations/from-segments (desktop transcription upload).
    transcript_segments: List[FromSegmentsTranscriptSegment]
    source: Optional[ConversationSource] = ConversationSource.desktop
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None
    language: Optional[str] = 'en'
    timezone: Optional[str] = None  # accepted from Swift client; not used below
    input_device_name: Optional[str] = None  # accepted from Swift client; not used below
    geolocation: Optional[Geolocation] = None


class FromSegmentsResponse(BaseModel):
    id: str
    status: str
    discarded: bool


def _coerce_utc(dt: datetime) -> datetime:
    """Return *dt* as a timezone-aware datetime, treating naive values as UTC.

    FIX: clients may send naive ISO timestamps; comparing or adding those to the
    aware ``datetime.now(timezone.utc)`` default raised TypeError (HTTP 500).
    """
    return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt


@router.post("/v1/conversations/from-segments", response_model=FromSegmentsResponse, tags=['conversations'])
def create_conversation_from_segments(
    request: CreateConversationFromSegmentsRequest,
    uid: str = Depends(auth.get_current_user_uid),
):
    """Create a conversation from client-supplied transcript segments (user auth).

    Validates segment count/timing/text, normalizes timestamps to UTC, optionally
    enriches geolocation via Google Maps, then runs the shared
    process_conversation pipeline. Returns the new conversation's id/status.

    Raises:
        HTTPException 422 on empty/oversized segment lists, bad segment timing,
        empty text, or finished_at <= started_at.
    """
    if not request.transcript_segments:
        raise HTTPException(status_code=422, detail="transcript_segments cannot be empty")

    if len(request.transcript_segments) > 500:
        raise HTTPException(status_code=422, detail="Maximum 500 transcript segments allowed")

    for idx, segment in enumerate(request.transcript_segments):
        if segment.end <= segment.start:
            raise HTTPException(status_code=422, detail=f"Segment {idx}: end time must be after start time")
        if segment.start < 0:
            raise HTTPException(status_code=422, detail=f"Segment {idx}: start time cannot be negative")
        if not segment.text or len(segment.text.strip()) == 0:
            raise HTTPException(status_code=422, detail=f"Segment {idx}: text cannot be empty")

    transcript_segments = [
        TranscriptSegment(
            text=seg.text.strip(),
            speaker=seg.speaker or 'SPEAKER_00',
            speaker_id=seg.speaker_id,
            is_user=seg.is_user,
            person_id=seg.person_id,
            start=seg.start,
            end=seg.end,
        )
        for seg in request.transcript_segments
    ]

    # Normalize to aware UTC so naive client timestamps cannot crash the
    # comparison/arithmetic below with a TypeError.
    started_at = _coerce_utc(request.started_at) if request.started_at else datetime.now(timezone.utc)
    if request.finished_at is not None:
        finished_at = _coerce_utc(request.finished_at)
    else:
        # Derive the end time from the last segment's offset.
        last_segment = request.transcript_segments[-1]
        finished_at = started_at + timedelta(seconds=last_segment.end)

    if finished_at <= started_at:
        raise HTTPException(status_code=422, detail="finished_at must be after started_at")

    # Best-effort enrichment: a Maps failure must not block conversation creation.
    geolocation = request.geolocation
    if geolocation and not geolocation.google_place_id:
        try:
            geolocation = get_google_maps_location(geolocation.latitude, geolocation.longitude)
        except Exception as e:
            logger.error(f"Error enriching geolocation: {e}")

    create_conversation_obj = CreateConversation(
        transcript_segments=transcript_segments,
        started_at=started_at,
        finished_at=finished_at,
        language=request.language or 'en',
        geolocation=geolocation,
        source=request.source or ConversationSource.desktop,
    )

    conversation = process_conversation(uid, request.language or 'en', create_conversation_obj)

    return FromSegmentsResponse(
        id=conversation.id,
        status=conversation.status.value if conversation.status else 'completed',
        discarded=conversation.discarded,
    )
# Stub ALL heavy dependencies before any import that could transitively pull them in.
# Order matters: stub parent packages before child packages.
for mod_name in [
    'firebase_admin', 'firebase_admin.auth', 'firebase_admin.firestore', 'firebase_admin.messaging',
    'google.cloud', 'google.cloud.exceptions', 'google.cloud.firestore', 'google.cloud.firestore_v1',
    'google.cloud.firestore_v1.base_query', 'google.cloud.firestore_v1.query',
    'google.cloud.storage', 'google.cloud.storage.blob', 'google.cloud.storage.bucket',
    'google.auth', 'google.auth.transport', 'google.auth.transport.requests',
    'google.oauth2', 'google.oauth2.service_account',
    'pinecone',
    'typesense',
]:
    sys.modules.setdefault(mod_name, MagicMock())

# FIX: `patch` was previously imported only at the very bottom of this file, so the
# class bodies below only worked because module import completes before any test
# runs. Import it explicitly before first use instead.
from unittest.mock import patch

from routers.conversations import (
    FromSegmentsTranscriptSegment,
    CreateConversationFromSegmentsRequest,
    FromSegmentsResponse,
)


@pytest.fixture
def valid_segments():
    """Two well-formed segments: one from the user, one from another speaker."""
    return [
        FromSegmentsTranscriptSegment(text="Hello there", speaker="SPEAKER_00", is_user=True, start=0.0, end=2.5),
        FromSegmentsTranscriptSegment(text="Hi, how are you?", speaker="SPEAKER_01", is_user=False, start=2.8, end=5.2),
    ]


class TestFromSegmentsModels:
    """Pydantic model defaults and construction."""

    def test_segment_defaults(self):
        seg = FromSegmentsTranscriptSegment(text="Hello", start=0.0, end=1.0)
        assert seg.speaker == "SPEAKER_00"
        assert seg.is_user is False
        assert seg.person_id is None
        assert seg.speaker_id is None

    def test_request_defaults(self, valid_segments):
        req = CreateConversationFromSegmentsRequest(transcript_segments=valid_segments)
        assert req.source == "desktop"
        assert req.language == "en"
        assert req.started_at is None
        assert req.finished_at is None
        assert req.timezone is None
        assert req.input_device_name is None

    def test_response_model(self):
        resp = FromSegmentsResponse(id="conv123", status="completed", discarded=False)
        assert resp.id == "conv123"
        assert resp.status == "completed"
        assert resp.discarded is False


class TestFromSegmentsValidation:
    """Field-level acceptance of optional inputs."""

    def test_segment_with_all_fields(self):
        seg = FromSegmentsTranscriptSegment(
            text="Hello",
            speaker="SPEAKER_01",
            speaker_id=1,
            is_user=True,
            person_id="person123",
            start=10.5,
            end=15.3,
        )
        assert seg.speaker_id == 1
        assert seg.person_id == "person123"

    def test_desktop_source_default(self, valid_segments):
        req = CreateConversationFromSegmentsRequest(transcript_segments=valid_segments)
        assert req.source == "desktop"

    def test_custom_source(self, valid_segments):
        req = CreateConversationFromSegmentsRequest(transcript_segments=valid_segments, source="phone")
        assert req.source == "phone"

    def test_timezone_and_input_device_accepted(self, valid_segments):
        req = CreateConversationFromSegmentsRequest(
            transcript_segments=valid_segments,
            timezone="America/New_York",
            input_device_name="MacBook Pro Microphone",
        )
        assert req.timezone == "America/New_York"
        assert req.input_device_name == "MacBook Pro Microphone"

    def test_started_finished_at(self, valid_segments):
        now = datetime.now(timezone.utc)
        later = now + timedelta(minutes=5)
        req = CreateConversationFromSegmentsRequest(
            transcript_segments=valid_segments,
            started_at=now,
            finished_at=later,
        )
        assert req.started_at == now
        assert req.finished_at == later

    def test_500_segments_accepted(self):
        segs = [FromSegmentsTranscriptSegment(text=f"seg {i}", start=float(i), end=float(i + 1)) for i in range(500)]
        req = CreateConversationFromSegmentsRequest(transcript_segments=segs)
        assert len(req.transcript_segments) == 500

    def test_geolocation_accepted(self, valid_segments):
        req = CreateConversationFromSegmentsRequest(
            transcript_segments=valid_segments,
            geolocation={'latitude': 37.7749, 'longitude': -122.4194},
        )
        assert req.geolocation is not None


class TestFromSegmentsEndpoint:
    """Endpoint-level tests using FastAPI TestClient with mocked auth and processing."""

    def _make_app(self):
        # Mount only the conversations router for isolation.
        from fastapi import FastAPI
        from routers.conversations import router
        app = FastAPI()
        app.include_router(router)
        return app

    @pytest.fixture
    def client(self):
        from fastapi.testclient import TestClient
        return TestClient(self._make_app())

    @staticmethod
    def _completed_conversation(conv_id):
        """Build a mock Conversation exposing only the fields the endpoint reads."""
        conv = MagicMock()
        conv.id = conv_id
        conv.status.value = 'completed'
        conv.discarded = False
        return conv

    def test_successful_creation(self, client):
        with (
            patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'),
            patch('routers.conversations.process_conversation') as mock_process,
            patch('routers.conversations.get_google_maps_location'),
        ):
            mock_process.return_value = self._completed_conversation('conv-abc')

            response = client.post(
                '/v1/conversations/from-segments',
                json={
                    'transcript_segments': [
                        {'text': 'Hello there', 'speaker': 'SPEAKER_00', 'is_user': True, 'start': 0.0, 'end': 2.5},
                        {'text': 'Hi!', 'speaker': 'SPEAKER_01', 'is_user': False, 'start': 2.8, 'end': 5.2},
                    ],
                    'source': 'desktop',
                    'language': 'en',
                },
                headers={'Authorization': 'Bearer test-token'},
            )
            assert response.status_code == 200
            data = response.json()
            assert data['id'] == 'conv-abc'
            assert data['status'] == 'completed'
            assert data['discarded'] is False
            mock_process.assert_called_once()

    def test_invalid_segment_times_returns_422(self, client):
        with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'):
            response = client.post(
                '/v1/conversations/from-segments',
                json={'transcript_segments': [{'text': 'Hello', 'start': 5.0, 'end': 3.0}]},
                headers={'Authorization': 'Bearer test-token'},
            )
            assert response.status_code == 422

    def test_empty_text_returns_422(self, client):
        with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'):
            response = client.post(
                '/v1/conversations/from-segments',
                json={'transcript_segments': [{'text': '   ', 'start': 0.0, 'end': 1.0}]},
                headers={'Authorization': 'Bearer test-token'},
            )
            assert response.status_code == 422

    def test_negative_start_returns_422(self, client):
        with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'):
            response = client.post(
                '/v1/conversations/from-segments',
                json={'transcript_segments': [{'text': 'Hello', 'start': -1.0, 'end': 1.0}]},
                headers={'Authorization': 'Bearer test-token'},
            )
            assert response.status_code == 422

    def test_finished_at_auto_calculated(self, client):
        with (
            patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'),
            patch('routers.conversations.process_conversation') as mock_process,
            patch('routers.conversations.get_google_maps_location'),
        ):
            mock_process.return_value = self._completed_conversation('conv-calc')

            response = client.post(
                '/v1/conversations/from-segments',
                json={'transcript_segments': [{'text': 'Hello', 'start': 0.0, 'end': 30.0}], 'source': 'desktop'},
                headers={'Authorization': 'Bearer test-token'},
            )
            assert response.status_code == 200
            create_obj = mock_process.call_args[0][2]
            assert create_obj.finished_at > create_obj.started_at

    def test_source_defaults_to_desktop(self, client):
        with (
            patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'),
            patch('routers.conversations.process_conversation') as mock_process,
            patch('routers.conversations.get_google_maps_location'),
        ):
            mock_process.return_value = self._completed_conversation('conv-def')

            response = client.post(
                '/v1/conversations/from-segments',
                json={'transcript_segments': [{'text': 'Hello', 'start': 0.0, 'end': 1.0}]},
                headers={'Authorization': 'Bearer test-token'},
            )
            assert response.status_code == 200
            create_obj = mock_process.call_args[0][2]
            assert create_obj.source.value == 'desktop'
def get_chat_sessions(
    uid: str, app_id: Optional[str] = None, limit: int = 50, offset: int = 0, starred: Optional[bool] = None
):
    """List chat sessions with optional filters, newest-updated first.

    NOTE(review): the plugin_id filter is applied even when app_id is None, so a
    None app_id matches only sessions whose plugin_id field is null — presumably
    the no-app chat; confirm against desktop callers.
    """
    query = db.collection('users').document(uid).collection('chat_sessions')
    query = query.where(filter=FieldFilter('plugin_id', '==', app_id))
    if starred is not None:
        query = query.where(filter=FieldFilter('starred', '==', starred))
    query = query.order_by('updated_at', direction=firestore.Query.DESCENDING).limit(limit).offset(offset)
    return [snapshot.to_dict() for snapshot in query.stream()]


def update_chat_session(uid: str, chat_session_id: str, update_data: dict):
    """Partial update of a chat session."""
    doc_ref = (
        db.collection('users')
        .document(uid)
        .collection('chat_sessions')
        .document(chat_session_id)
    )
    doc_ref.update(update_data)


def save_message(uid: str, message_data: dict):
    """Save a message directly by document ID (for desktop CRUD).

    Expects message_data to contain an 'id' key; returns the stored dict.
    """
    doc_ref = (
        db.collection('users')
        .document(uid)
        .collection('messages')
        .document(message_data['id'])
    )
    doc_ref.set(message_data)
    return message_data


def update_message_rating(uid: str, message_id: str, rating: Optional[int]):
    """Update the rating on a message (None clears it)."""
    doc_ref = (
        db.collection('users')
        .document(uid)
        .collection('messages')
        .document(message_id)
    )
    doc_ref.update({'rating': rating})
import uuid
from datetime import datetime, timezone
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel

import database.chat as chat_db
from utils.other import endpoints as auth
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------


class CreateChatSessionRequest(BaseModel):
    title: Optional[str] = None
    app_id: Optional[str] = None


class UpdateChatSessionRequest(BaseModel):
    title: Optional[str] = None
    starred: Optional[bool] = None


class ChatSessionResponse(BaseModel):
    id: str
    title: str
    preview: Optional[str] = None
    created_at: datetime
    updated_at: datetime
    app_id: Optional[str] = None
    message_count: int = 0
    starred: bool = False


class SaveMessageRequest(BaseModel):
    text: str
    sender: str
    app_id: Optional[str] = None
    session_id: Optional[str] = None
    metadata: Optional[str] = None


class SaveMessageResponse(BaseModel):
    id: str
    created_at: datetime


class RateMessageRequest(BaseModel):
    rating: Optional[int] = None


class StatusResponse(BaseModel):
    status: str


def _require_session(uid: str, session_id: str) -> dict:
    """Fetch a session or raise 404 — shared guard for the per-session routes."""
    session = chat_db.get_chat_session_by_id(uid, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Chat session not found")
    return session


# ---------------------------------------------------------------------------
# Chat Sessions CRUD
# ---------------------------------------------------------------------------


@router.get('/v2/chat-sessions', response_model=List[ChatSessionResponse], tags=['desktop-chat'])
def list_chat_sessions(
    app_id: Optional[str] = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    starred: Optional[bool] = Query(None),
    uid: str = Depends(auth.get_current_user_uid),
):
    """List the user's chat sessions with optional app/starred filters."""
    return chat_db.get_chat_sessions(uid, app_id=app_id, limit=limit, offset=offset, starred=starred)


@router.post('/v2/chat-sessions', response_model=ChatSessionResponse, tags=['desktop-chat'])
def create_chat_session(
    request: CreateChatSessionRequest,
    uid: str = Depends(auth.get_current_user_uid),
):
    """Create a new, empty chat session."""
    created = datetime.now(timezone.utc)
    session_data = {
        'id': str(uuid.uuid4()),
        'title': request.title or 'New Chat',
        'preview': None,
        'created_at': created,
        'updated_at': created,
        'app_id': request.app_id,
        'plugin_id': request.app_id,  # Python backend uses plugin_id for filtering
        'message_count': 0,
        'starred': False,
    }
    chat_db.add_chat_session(uid, session_data)
    return session_data


@router.get('/v2/chat-sessions/{session_id}', response_model=ChatSessionResponse, tags=['desktop-chat'])
def get_chat_session(
    session_id: str,
    uid: str = Depends(auth.get_current_user_uid),
):
    """Fetch a single chat session by id (404 if absent)."""
    return _require_session(uid, session_id)


@router.patch('/v2/chat-sessions/{session_id}', response_model=StatusResponse, tags=['desktop-chat'])
def update_chat_session(
    session_id: str,
    request: UpdateChatSessionRequest,
    uid: str = Depends(auth.get_current_user_uid),
):
    """Partially update a session's title and/or starred flag."""
    _require_session(uid, session_id)

    changes = {}
    if request.title is not None:
        changes['title'] = request.title
    if request.starred is not None:
        changes['starred'] = request.starred
    if changes:
        changes['updated_at'] = datetime.now(timezone.utc)
        chat_db.update_chat_session(uid, session_id, changes)

    return StatusResponse(status='ok')


@router.delete('/v2/chat-sessions/{session_id}', response_model=StatusResponse, tags=['desktop-chat'])
def delete_chat_session(
    session_id: str,
    uid: str = Depends(auth.get_current_user_uid),
):
    """Delete a chat session (404 if absent)."""
    _require_session(uid, session_id)
    chat_db.delete_chat_session(uid, session_id)
    return StatusResponse(status='ok')


# ---------------------------------------------------------------------------
# Desktop message CRUD (simple save, not streaming)
# ---------------------------------------------------------------------------


@router.post('/v2/desktop/messages', response_model=SaveMessageResponse, tags=['desktop-chat'])
def save_message(
    request: SaveMessageRequest,
    uid: str = Depends(auth.get_current_user_uid),
):
    """Persist a single chat message and, best-effort, link it to its session."""
    if not request.text or not request.text.strip():
        raise HTTPException(status_code=422, detail="Message text cannot be empty")
    if request.sender not in ('human', 'ai'):
        raise HTTPException(status_code=422, detail="sender must be 'human' or 'ai'")

    created = datetime.now(timezone.utc)
    message_id = str(uuid.uuid4())
    chat_db.save_message(
        uid,
        {
            'id': message_id,
            'text': request.text,
            'created_at': created,
            'sender': request.sender,
            'app_id': request.app_id,
            'plugin_id': request.app_id,
            'session_id': request.session_id,
            'chat_session_id': request.session_id,
            'rating': None,
            'reported': False,
            'type': 'text',
            'memories_id': [],
            'from_external_integration': False,
            'metadata': request.metadata,
        },
    )

    # Linking is best-effort: a failure must not lose the saved message.
    if request.session_id:
        try:
            chat_db.add_message_to_chat_session(uid, request.session_id, message_id)
        except Exception as e:
            logger.warning(f"Failed to link message to session {request.session_id}: {e}")

    return SaveMessageResponse(id=message_id, created_at=created)


@router.patch('/v2/messages/{message_id}/rating', response_model=StatusResponse, tags=['desktop-chat'])
def rate_message(
    message_id: str,
    request: RateMessageRequest,
    uid: str = Depends(auth.get_current_user_uid),
):
    """Set a thumbs-up (1), thumbs-down (-1), or cleared (null) rating."""
    if request.rating not in (None, 1, -1):
        raise HTTPException(status_code=422, detail="rating must be 1, -1, or null")

    chat_db.update_message_rating(uid, message_id, request.rating)
    return StatusResponse(status='ok')
'google.auth.transport', 'google.auth.transport.requests', + 'google.oauth2', 'google.oauth2.service_account', + 'pinecone', 'typesense', +]: + sys.modules.setdefault(mod_name, MagicMock()) + +from routers.desktop_chat import ( + CreateChatSessionRequest, + UpdateChatSessionRequest, + ChatSessionResponse, + SaveMessageRequest, + SaveMessageResponse, + RateMessageRequest, + StatusResponse, + router, +) + + +class TestChatSessionModels: + def test_create_request_defaults(self): + req = CreateChatSessionRequest() + assert req.title is None + assert req.app_id is None + + def test_update_request_partial(self): + req = UpdateChatSessionRequest(title="New Title") + assert req.title == "New Title" + assert req.starred is None + + def test_session_response(self): + now = datetime.now(timezone.utc) + resp = ChatSessionResponse(id="s1", title="Test", created_at=now, updated_at=now) + assert resp.message_count == 0 + assert resp.starred is False + + def test_save_message_request(self): + req = SaveMessageRequest(text="Hello", sender="human") + assert req.app_id is None + assert req.session_id is None + + def test_rate_request(self): + req = RateMessageRequest(rating=1) + assert req.rating == 1 + req2 = RateMessageRequest() + assert req2.rating is None + + +class TestChatSessionEndpoints: + def _make_app(self): + from fastapi import FastAPI + app = FastAPI() + app.include_router(router) + return app + + @pytest.fixture + def client(self): + from fastapi.testclient import TestClient + return TestClient(self._make_app()) + + def test_create_session(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.add_chat_session') as mock_add, + ): + mock_add.side_effect = lambda uid, data: data + response = client.post( + '/v2/chat-sessions', + json={'title': 'My Chat', 'app_id': None}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + data = response.json() + 
assert data['title'] == 'My Chat' + assert data['message_count'] == 0 + assert data['starred'] is False + assert 'id' in data + + def test_create_session_default_title(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.add_chat_session') as mock_add, + ): + mock_add.side_effect = lambda uid, data: data + response = client.post( + '/v2/chat-sessions', + json={}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert response.json()['title'] == 'New Chat' + + def test_list_sessions(self, client): + now = datetime.now(timezone.utc) + mock_sessions = [ + {'id': 's1', 'title': 'Chat 1', 'created_at': now, 'updated_at': now, 'message_count': 5, 'starred': False}, + {'id': 's2', 'title': 'Chat 2', 'created_at': now, 'updated_at': now, 'message_count': 3, 'starred': True}, + ] + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.get_chat_sessions', return_value=mock_sessions), + ): + response = client.get('/v2/chat-sessions', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]['title'] == 'Chat 1' + + def test_get_session(self, client): + now = datetime.now(timezone.utc) + mock_session = {'id': 's1', 'title': 'Chat', 'created_at': now, 'updated_at': now, 'message_count': 0, 'starred': False} + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=mock_session), + ): + response = client.get('/v2/chat-sessions/s1', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert response.json()['id'] == 's1' + + def test_get_session_not_found(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + 
patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=None), + ): + response = client.get('/v2/chat-sessions/missing', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 404 + + def test_update_session(self, client): + now = datetime.now(timezone.utc) + mock_session = {'id': 's1', 'title': 'Old', 'created_at': now, 'updated_at': now} + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=mock_session), + patch('routers.desktop_chat.chat_db.update_chat_session') as mock_update, + ): + response = client.patch( + '/v2/chat-sessions/s1', + json={'title': 'Renamed', 'starred': True}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + call_data = mock_update.call_args[0][2] + assert call_data['title'] == 'Renamed' + assert call_data['starred'] is True + + def test_delete_session(self, client): + now = datetime.now(timezone.utc) + mock_session = {'id': 's1', 'title': 'Del', 'created_at': now, 'updated_at': now} + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=mock_session), + patch('routers.desktop_chat.chat_db.delete_chat_session') as mock_del, + ): + response = client.delete('/v2/chat-sessions/s1', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert mock_del.called + assert mock_del.call_args[0][1] == 's1' + + +class TestDesktopMessageEndpoints: + def _make_app(self): + from fastapi import FastAPI + app = FastAPI() + app.include_router(router) + return app + + @pytest.fixture + def client(self): + from fastapi.testclient import TestClient + return TestClient(self._make_app()) + + def test_save_message(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + 
patch('routers.desktop_chat.chat_db.save_message') as mock_save, + patch('routers.desktop_chat.chat_db.add_message_to_chat_session'), + ): + mock_save.side_effect = lambda uid, data: data + response = client.post( + '/v2/desktop/messages', + json={'text': 'Hello', 'sender': 'human', 'session_id': 's1'}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + data = response.json() + assert 'id' in data + assert 'created_at' in data + + def test_save_message_empty_text_422(self, client): + with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + response = client.post( + '/v2/desktop/messages', + json={'text': ' ', 'sender': 'human'}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 + + def test_save_message_invalid_sender_422(self, client): + with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + response = client.post( + '/v2/desktop/messages', + json={'text': 'Hello', 'sender': 'bot'}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 + + def test_rate_message_thumbs_up(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.update_message_rating') as mock_rate, + ): + response = client.patch( + '/v2/messages/msg-1/rating', + json={'rating': 1}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert mock_rate.called + assert mock_rate.call_args[0][1] == 'msg-1' + assert mock_rate.call_args[0][2] == 1 + + def test_rate_message_clear(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.update_message_rating') as mock_rate, + ): + response = client.patch( + '/v2/messages/msg-1/rating', + json={'rating': None}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + 
assert mock_rate.called + assert mock_rate.call_args[0][1] == 'msg-1' + assert mock_rate.call_args[0][2] is None + + def test_rate_message_invalid_value_422(self, client): + with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + response = client.patch( + '/v2/messages/msg-1/rating', + json={'rating': 5}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 From e5eae14d3e4cb1c9cc4e8b1efd96185fd679ec0e Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:09:26 +0100 Subject: [PATCH 048/163] =?UTF-8?q?Fix=20chat=20sessions=20list=20?= =?UTF-8?q?=E2=80=94=20client-side=20sort=20to=20avoid=20composite=20index?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/database/chat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/database/chat.py b/backend/database/chat.py index 03739a6d4c..1a82c95c56 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -476,8 +476,10 @@ def get_chat_sessions( sessions_ref = sessions_ref.where(filter=FieldFilter('plugin_id', '==', app_id)) if starred is not None: sessions_ref = sessions_ref.where(filter=FieldFilter('starred', '==', starred)) - sessions_ref = sessions_ref.order_by('updated_at', direction=firestore.Query.DESCENDING).limit(limit).offset(offset) - return [doc.to_dict() for doc in sessions_ref.stream()] + sessions_ref = sessions_ref.limit(limit).offset(offset) + sessions = [doc.to_dict() for doc in sessions_ref.stream()] + sessions.sort(key=lambda s: s.get('updated_at', s.get('created_at', datetime.min)), reverse=True) + return sessions def update_chat_session(uid: str, chat_session_id: str, update_data: dict): From af49fd9822a6915636d33a4e91be5c8cf21a32da Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:21:07 +0100 Subject: [PATCH 049/163] =?UTF-8?q?Fix=20reviewer=20issues=20in=20database?= 
=?UTF-8?q?/chat.py=20=E2=80=94=20remove=20duplicate=20update=5Fmessage=5F?= =?UTF-8?q?rating,=20add=20data=20protection=20decorators=20to=20save=5Fme?= =?UTF-8?q?ssage,=20fix=20get=5Fchat=5Fsessions=20filtering=20and=20pagina?= =?UTF-8?q?tion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/database/chat.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/backend/database/chat.py b/backend/database/chat.py index 1a82c95c56..e57c7f9b17 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -473,13 +473,13 @@ def get_chat_sessions( ): """List chat sessions with optional filters.""" sessions_ref = db.collection('users').document(uid).collection('chat_sessions') - sessions_ref = sessions_ref.where(filter=FieldFilter('plugin_id', '==', app_id)) + if app_id is not None: + sessions_ref = sessions_ref.where(filter=FieldFilter('plugin_id', '==', app_id)) if starred is not None: sessions_ref = sessions_ref.where(filter=FieldFilter('starred', '==', starred)) - sessions_ref = sessions_ref.limit(limit).offset(offset) sessions = [doc.to_dict() for doc in sessions_ref.stream()] sessions.sort(key=lambda s: s.get('updated_at', s.get('created_at', datetime.min)), reverse=True) - return sessions + return sessions[offset : offset + limit] def update_chat_session(uid: str, chat_session_id: str, update_data: dict): @@ -489,6 +489,8 @@ def update_chat_session(uid: str, chat_session_id: str, update_data: dict): session_ref.update(update_data) +@set_data_protection_level(data_arg_name='message_data') +@prepare_for_write(data_arg_name='message_data', prepare_func=_prepare_data_for_write) def save_message(uid: str, message_data: dict): """Save a message directly by document ID (for desktop CRUD).""" user_ref = db.collection('users').document(uid) @@ -496,11 +498,21 @@ def save_message(uid: str, message_data: dict): return message_data -def update_message_rating(uid: str, 
message_id: str, rating: Optional[int]): - """Update the rating on a message.""" +def delete_chat_session_messages(uid: str, chat_session_id: str): + """Delete all messages belonging to a chat session.""" user_ref = db.collection('users').document(uid) - message_ref = user_ref.collection('messages').document(message_id) - message_ref.update({'rating': rating}) + messages_ref = user_ref.collection('messages').where(filter=FieldFilter('chat_session_id', '==', chat_session_id)) + batch = db.batch() + count = 0 + for doc in messages_ref.stream(): + batch.delete(doc.reference) + count += 1 + if count % 400 == 0: + batch.commit() + batch = db.batch() + if count % 400 != 0: + batch.commit() + logger.info(f"Deleted {count} messages for session {chat_session_id}") def add_message_to_chat_session(uid: str, chat_session_id: str, message_id: str): From 842626a7dba4ed0ea8d3c19a37353aaf9442a41d Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:21:11 +0100 Subject: [PATCH 050/163] Fix PATCH response to return full ChatSessionResponse, add cascade delete for session messages --- backend/routers/desktop_chat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/routers/desktop_chat.py b/backend/routers/desktop_chat.py index 06f4b636fa..355b47b936 100644 --- a/backend/routers/desktop_chat.py +++ b/backend/routers/desktop_chat.py @@ -119,7 +119,7 @@ def get_chat_session( return session -@router.patch('/v2/chat-sessions/{session_id}', response_model=StatusResponse, tags=['desktop-chat']) +@router.patch('/v2/chat-sessions/{session_id}', response_model=ChatSessionResponse, tags=['desktop-chat']) def update_chat_session( session_id: str, request: UpdateChatSessionRequest, @@ -137,8 +137,9 @@ def update_chat_session( if update_data: update_data['updated_at'] = datetime.now(timezone.utc) chat_db.update_chat_session(uid, session_id, update_data) + session.update(update_data) - return StatusResponse(status='ok') + return session 
@router.delete('/v2/chat-sessions/{session_id}', response_model=StatusResponse, tags=['desktop-chat']) @@ -150,6 +151,7 @@ def delete_chat_session( if not session: raise HTTPException(status_code=404, detail="Chat session not found") + chat_db.delete_chat_session_messages(uid, session_id) chat_db.delete_chat_session(uid, session_id) return StatusResponse(status='ok') From a8248b4b403da018556895040cf93898cb01e2e3 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:21:14 +0100 Subject: [PATCH 051/163] =?UTF-8?q?Update=20tests=20for=20reviewer=20fixes?= =?UTF-8?q?=20=E2=80=94=20verify=20cascade=20delete=20and=20full=20session?= =?UTF-8?q?=20response?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/tests/unit/test_desktop_chat.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/backend/tests/unit/test_desktop_chat.py b/backend/tests/unit/test_desktop_chat.py index 031b2d22fb..5d017687dc 100644 --- a/backend/tests/unit/test_desktop_chat.py +++ b/backend/tests/unit/test_desktop_chat.py @@ -136,9 +136,9 @@ def test_get_session_not_found(self, client): response = client.get('/v2/chat-sessions/missing', headers={'Authorization': 'Bearer test'}) assert response.status_code == 404 - def test_update_session(self, client): + def test_update_session_returns_full_session(self, client): now = datetime.now(timezone.utc) - mock_session = {'id': 's1', 'title': 'Old', 'created_at': now, 'updated_at': now} + mock_session = {'id': 's1', 'title': 'Old', 'created_at': now, 'updated_at': now, 'message_count': 0, 'starred': False} with ( patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=mock_session), @@ -150,20 +150,24 @@ def test_update_session(self, client): headers={'Authorization': 'Bearer test'}, ) assert response.status_code == 200 - call_data = mock_update.call_args[0][2] - assert 
call_data['title'] == 'Renamed' - assert call_data['starred'] is True + data = response.json() + assert data['title'] == 'Renamed' + assert data['starred'] is True + assert data['id'] == 's1' - def test_delete_session(self, client): + def test_delete_session_cascades_messages(self, client): now = datetime.now(timezone.utc) mock_session = {'id': 's1', 'title': 'Del', 'created_at': now, 'updated_at': now} with ( patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=mock_session), + patch('routers.desktop_chat.chat_db.delete_chat_session_messages') as mock_del_msgs, patch('routers.desktop_chat.chat_db.delete_chat_session') as mock_del, ): response = client.delete('/v2/chat-sessions/s1', headers={'Authorization': 'Bearer test'}) assert response.status_code == 200 + assert mock_del_msgs.called + assert mock_del_msgs.call_args[0][1] == 's1' assert mock_del.called assert mock_del.call_args[0][1] == 's1' From 28c1cda50d1c88f274d27f1699e462dc0bd6b9b9 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:26:44 +0100 Subject: [PATCH 052/163] Add endpoint-level tests for from-segments boundary cases and error paths --- backend/tests/unit/test_from_segments.py | 58 ++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/backend/tests/unit/test_from_segments.py b/backend/tests/unit/test_from_segments.py index e068165921..7199bcb80f 100644 --- a/backend/tests/unit/test_from_segments.py +++ b/backend/tests/unit/test_from_segments.py @@ -228,6 +228,64 @@ def test_source_defaults_to_desktop(self, client): create_obj = mock_process.call_args[0][2] assert create_obj.source.value == 'desktop' + def test_empty_segments_list_returns_422(self, client): + with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'): + response = client.post( + '/v1/conversations/from-segments', + json={'transcript_segments': []}, + 
headers={'Authorization': 'Bearer test-token'}, + ) + assert response.status_code == 422 + + def test_over_500_segments_returns_422(self, client): + with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'): + segments = [{'text': f'seg {i}', 'start': float(i), 'end': float(i + 1)} for i in range(501)] + response = client.post( + '/v1/conversations/from-segments', + json={'transcript_segments': segments}, + headers={'Authorization': 'Bearer test-token'}, + ) + assert response.status_code == 422 + assert '500' in response.json()['detail'] + + def test_finished_at_before_started_at_returns_422(self, client): + with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'): + now = datetime.now(timezone.utc) + earlier = now - timedelta(hours=1) + response = client.post( + '/v1/conversations/from-segments', + json={ + 'transcript_segments': [{'text': 'Hello', 'start': 0.0, 'end': 1.0}], + 'started_at': now.isoformat(), + 'finished_at': earlier.isoformat(), + }, + headers={'Authorization': 'Bearer test-token'}, + ) + assert response.status_code == 422 + assert 'finished_at' in response.json()['detail'] + + def test_geolocation_enrichment_failure_continues(self, client): + with ( + patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'), + patch('routers.conversations.process_conversation') as mock_process, + patch('routers.conversations.get_google_maps_location', side_effect=Exception('API error')), + ): + mock_conv = MagicMock() + mock_conv.id = 'conv-geo' + mock_conv.status.value = 'completed' + mock_conv.discarded = False + mock_process.return_value = mock_conv + + response = client.post( + '/v1/conversations/from-segments', + json={ + 'transcript_segments': [{'text': 'Hello', 'start': 0.0, 'end': 1.0}], + 'geolocation': {'latitude': 37.7749, 'longitude': -122.4194}, + }, + headers={'Authorization': 'Bearer test-token'}, + ) + assert response.status_code == 200 + # Keep patch 
import at module scope for the with-statement usage from unittest.mock import patch From 08f93282ca12218aea068924f49e28150132eb96 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:26:48 +0100 Subject: [PATCH 053/163] Add desktop chat tests for not-found, session link failure, query bounds, thumbs-down rating --- backend/tests/unit/test_desktop_chat.py | 67 +++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/backend/tests/unit/test_desktop_chat.py b/backend/tests/unit/test_desktop_chat.py index 5d017687dc..a409c3aea5 100644 --- a/backend/tests/unit/test_desktop_chat.py +++ b/backend/tests/unit/test_desktop_chat.py @@ -249,6 +249,19 @@ def test_rate_message_clear(self, client): assert mock_rate.call_args[0][1] == 'msg-1' assert mock_rate.call_args[0][2] is None + def test_rate_message_thumbs_down(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.update_message_rating') as mock_rate, + ): + response = client.patch( + '/v2/messages/msg-1/rating', + json={'rating': -1}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert mock_rate.call_args[0][2] == -1 + def test_rate_message_invalid_value_422(self, client): with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): response = client.patch( @@ -257,3 +270,57 @@ def test_rate_message_invalid_value_422(self, client): headers={'Authorization': 'Bearer test'}, ) assert response.status_code == 422 + + def test_update_session_not_found(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=None), + ): + response = client.patch( + '/v2/chat-sessions/missing', + json={'title': 'Renamed'}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 404 + + def 
test_delete_session_not_found(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=None), + ): + response = client.delete( + '/v2/chat-sessions/missing', + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 404 + + def test_save_message_session_link_failure_still_succeeds(self, client): + with ( + patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.desktop_chat.chat_db.save_message') as mock_save, + patch('routers.desktop_chat.chat_db.add_message_to_chat_session', side_effect=Exception('Firestore error')), + ): + mock_save.side_effect = lambda uid, data: data + response = client.post( + '/v2/desktop/messages', + json={'text': 'Hello', 'sender': 'human', 'session_id': 's1'}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert 'id' in response.json() + + def test_list_sessions_limit_validation(self, client): + with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + response = client.get( + '/v2/chat-sessions?limit=0', + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 + + def test_list_sessions_offset_negative_validation(self, client): + with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + response = client.get( + '/v2/chat-sessions?offset=-1', + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 From 00e91d287ef20b8842dbf1ab97965311e62b1a90 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:51:01 +0100 Subject: [PATCH 054/163] Add POST /v1/screen-activity/sync endpoint New router for desktop screen activity sync. Accepts up to 100 screenshot rows per batch, writes to Firestore via existing database/screen_activity.py, and upserts Pinecone ns3 vectors in a background thread. 
Matches Rust backend contract. Closes part of #5302 Co-Authored-By: Claude Opus 4.6 --- backend/routers/screen_activity.py | 70 ++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 backend/routers/screen_activity.py diff --git a/backend/routers/screen_activity.py b/backend/routers/screen_activity.py new file mode 100644 index 0000000000..a0eeb4d851 --- /dev/null +++ b/backend/routers/screen_activity.py @@ -0,0 +1,70 @@ +import logging +import threading +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +import database.screen_activity as screen_activity_db +import database.vector_db as vector_db +from utils.other import endpoints as auth + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +class ScreenActivityRow(BaseModel): + id: int = Field(description="Screenshot ID (used as Firestore document ID)") + timestamp: str = Field(description="Timestamp in RFC3339 or 'YYYY-MM-DD HH:MM:SS' format") + appName: str = Field(default='', description="Application name") + windowTitle: str = Field(default='', description="Window title") + ocrText: str = Field(default='', description="OCR text from screenshot (truncated to 1000 chars)") + embedding: Optional[List[float]] = Field(default=None, description="Optional vector embedding (3072-dim Gemini)") + + +class ScreenActivitySyncRequest(BaseModel): + rows: List[ScreenActivityRow] + + +class ScreenActivitySyncResponse(BaseModel): + synced: int = Field(description="Number of rows written to Firestore") + last_id: int = Field(description="Maximum row ID from the batch") + + +@router.post('/v1/screen-activity/sync', response_model=ScreenActivitySyncResponse, tags=['screen-activity']) +def sync_screen_activity( + request: ScreenActivitySyncRequest, + uid: str = Depends(auth.get_current_user_uid), +): + if len(request.rows) > 100: + raise HTTPException(status_code=400, detail="Maximum 100 rows per batch") + 
+ if not request.rows: + return ScreenActivitySyncResponse(synced=0, last_id=0) + + # Convert Pydantic models to dicts for database layer + rows_data = [row.model_dump() for row in request.rows] + + # Firestore upsert (synchronous — blocks response until written) + synced = screen_activity_db.upsert_screen_activity(uid, rows_data) + + # Pinecone vector upsert (fire-and-forget background thread) + rows_with_embeddings = [r for r in rows_data if r.get('embedding')] + if rows_with_embeddings: + thread = threading.Thread( + target=_upsert_vectors_background, + args=(uid, rows_with_embeddings), + daemon=True, + ) + thread.start() + + last_id = max(row.id for row in request.rows) + return ScreenActivitySyncResponse(synced=synced, last_id=last_id) + + +def _upsert_vectors_background(uid: str, rows: list): + try: + vector_db.upsert_screen_activity_vectors(uid, rows) + except Exception: + logger.exception('Failed to upsert screen activity vectors for uid=%s', uid) From bdc4d2c2c2be2a584a34ea666948b88dc61b0fef Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:51:05 +0100 Subject: [PATCH 055/163] Register screen_activity router in main.py Wire up the new screen_activity router for desktop migration. 
Part of #5302 Co-Authored-By: Claude Opus 4.6 --- backend/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/main.py b/backend/main.py index 394316e6cb..cd12ae57fe 100644 --- a/backend/main.py +++ b/backend/main.py @@ -46,6 +46,7 @@ phone_calls, agent_tools, desktop_chat, + screen_activity, ) from utils.other.timeout import TimeoutMiddleware @@ -106,6 +107,7 @@ app.include_router(phone_calls.router) app.include_router(agent_tools.router) app.include_router(desktop_chat.router) +app.include_router(screen_activity.router) methods_timeout = { From c0b418524ebd73e6c3476a4416345d59cb8559e1 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:51:11 +0100 Subject: [PATCH 056/163] Add assistant_settings and ai_user_profile database functions Per-section merge for assistant_settings preserves sibling sections. AI profile uses full-replace semantics. Both use Firestore set(merge=True) for document-creation safety. Part of #5302 Co-Authored-By: Claude Opus 4.6 --- backend/database/users.py | 65 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/backend/database/users.py b/backend/database/users.py index 0944b9eb0a..0276ae6c5d 100644 --- a/backend/database/users.py +++ b/backend/database/users.py @@ -1050,3 +1050,68 @@ def set_user_transcription_preferences(uid: str, single_language_mode: bool = No if update_data: user_ref.update(update_data) + + +# ************************************** +# ****** Assistant Settings ************ +# ************************************** + + +def get_assistant_settings(uid: str) -> dict: + """Get the user's assistant_settings map from their user document.""" + user_ref = db.collection('users').document(uid) + user_doc = user_ref.get() + if user_doc.exists: + user_data = user_doc.to_dict() + settings = user_data.get('assistant_settings', {}) + # update_channel is a top-level field, not inside assistant_settings + update_channel = user_data.get('update_channel') + if update_channel 
is not None: + settings['update_channel'] = update_channel + return settings + return {} + + +def update_assistant_settings(uid: str, data: dict) -> dict: + """Merge-update the user's assistant_settings map. Returns merged state. + + Uses per-section set(merge=True) to avoid overwriting sibling sections. + Each non-empty section is written individually so that e.g. patching + only 'focus' does not wipe 'shared' or 'task'. + """ + user_ref = db.collection('users').document(uid) + + # Separate update_channel (top-level) from assistant_settings sub-map + update_channel = data.pop('update_channel', None) + + # Write each section individually with merge to preserve siblings + for section_key, section_val in data.items(): + if isinstance(section_val, dict) and section_val: + user_ref.set({'assistant_settings': {section_key: section_val}}, merge=True) + + if update_channel is not None: + user_ref.set({'update_channel': update_channel}, merge=True) + + return get_assistant_settings(uid) + + +# ************************************** +# ******** AI User Profile ************* +# ************************************** + + +def get_ai_user_profile(uid: str) -> Optional[dict]: + """Get the user's ai_user_profile map from their user document.""" + user_ref = db.collection('users').document(uid) + user_doc = user_ref.get() + if user_doc.exists: + user_data = user_doc.to_dict() + return user_data.get('ai_user_profile') + return None + + +def update_ai_user_profile(uid: str, data: dict) -> dict: + """Replace the user's ai_user_profile map. 
Returns new state.""" + user_ref = db.collection('users').document(uid) + user_ref.set({'ai_user_profile': data}, merge=True) + return get_ai_user_profile(uid) From 9c34e9e84f64a0e807397a84dd1bb206f77a53c7 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:51:16 +0100 Subject: [PATCH 057/163] Add assistant-settings and ai-profile endpoints to users router GET/PATCH /v1/users/assistant-settings with per-section merge, validation (prompt length, list caps, confidence range). GET/PATCH /v1/users/ai-profile with RFC3339 timestamp validation and 10KB profile_text truncation. Matches Rust backend contracts. Part of #5302 Co-Authored-By: Claude Opus 4.6 --- backend/routers/users.py | 135 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) diff --git a/backend/routers/users.py b/backend/routers/users.py index 5958f7261d..4294ac8825 100644 --- a/backend/routers/users.py +++ b/backend/routers/users.py @@ -1193,3 +1193,138 @@ def generate(): media_type='application/json', headers={'Content-Disposition': 'attachment; filename="omi-export.json"'}, ) + + +# ************************************** +# ****** Assistant Settings ************ +# ************************************** + + +class SharedAssistantSettings(BaseModel): + cooldown_interval: Optional[int] = None + glow_overlay_enabled: Optional[bool] = None + analysis_delay: Optional[int] = None + screen_analysis_enabled: Optional[bool] = None + + +class FocusSettings(BaseModel): + enabled: Optional[bool] = None + analysis_prompt: Optional[str] = None + cooldown_interval: Optional[int] = None + notifications_enabled: Optional[bool] = None + excluded_apps: Optional[List[str]] = None + + +class TaskSettings(BaseModel): + enabled: Optional[bool] = None + analysis_prompt: Optional[str] = None + extraction_interval: Optional[float] = None + min_confidence: Optional[float] = None + notifications_enabled: Optional[bool] = None + allowed_apps: Optional[List[str]] = None + browser_keywords: 
Optional[List[str]] = None + + +class AdviceSettings(BaseModel): + enabled: Optional[bool] = None + analysis_prompt: Optional[str] = None + extraction_interval: Optional[float] = None + min_confidence: Optional[float] = None + notifications_enabled: Optional[bool] = None + excluded_apps: Optional[List[str]] = None + + +class MemorySettings(BaseModel): + enabled: Optional[bool] = None + analysis_prompt: Optional[str] = None + extraction_interval: Optional[float] = None + min_confidence: Optional[float] = None + notifications_enabled: Optional[bool] = None + excluded_apps: Optional[List[str]] = None + + +class AssistantSettingsData(BaseModel): + shared: Optional[SharedAssistantSettings] = None + focus: Optional[FocusSettings] = None + task: Optional[TaskSettings] = None + advice: Optional[AdviceSettings] = None + memory: Optional[MemorySettings] = None + update_channel: Optional[str] = None + + +def _validate_assistant_settings(data: AssistantSettingsData): + """Validate prompt lengths and list sizes matching Rust backend limits.""" + for section_name in ('focus', 'task', 'advice', 'memory'): + section = getattr(data, section_name, None) + if section and section.analysis_prompt and len(section.analysis_prompt) > 10000: + raise HTTPException(status_code=400, detail=f'{section_name}.analysis_prompt exceeds 10000 chars') + + if data.task: + if data.task.allowed_apps and len(data.task.allowed_apps) > 500: + raise HTTPException(status_code=400, detail='task.allowed_apps exceeds 500 items') + if data.task.browser_keywords and len(data.task.browser_keywords) > 500: + raise HTTPException(status_code=400, detail='task.browser_keywords exceeds 500 items') + if data.task.min_confidence is not None and not (0.0 <= data.task.min_confidence <= 1.0): + raise HTTPException(status_code=400, detail='task.min_confidence must be between 0.0 and 1.0') + + if data.advice and data.advice.min_confidence is not None and not (0.0 <= data.advice.min_confidence <= 1.0): + raise 
HTTPException(status_code=400, detail='advice.min_confidence must be between 0.0 and 1.0') + + if data.memory and data.memory.min_confidence is not None and not (0.0 <= data.memory.min_confidence <= 1.0): + raise HTTPException(status_code=400, detail='memory.min_confidence must be between 0.0 and 1.0') + + +@router.get('/v1/users/assistant-settings', tags=['users']) +def get_assistant_settings_endpoint(uid: str = Depends(auth.get_current_user_uid)): + return get_assistant_settings(uid) + + +@router.patch('/v1/users/assistant-settings', tags=['users']) +def update_assistant_settings_endpoint( + data: AssistantSettingsData, + uid: str = Depends(auth.get_current_user_uid), +): + _validate_assistant_settings(data) + update_data = data.model_dump(exclude_none=True) + if not update_data: + return get_assistant_settings(uid) + return update_assistant_settings(uid, update_data) + + +# ************************************** +# ******** AI User Profile ************* +# ************************************** + + +class UpdateAIProfileRequest(BaseModel): + profile_text: str + generated_at: str + data_sources_used: int + + +@router.get('/v1/users/ai-profile', tags=['users']) +def get_ai_profile_endpoint(uid: str = Depends(auth.get_current_user_uid)): + return get_ai_user_profile(uid) + + +@router.patch('/v1/users/ai-profile', tags=['users']) +def update_ai_profile_endpoint( + data: UpdateAIProfileRequest, + uid: str = Depends(auth.get_current_user_uid), +): + # Validate generated_at is RFC3339 (matching Rust behavior) + try: + datetime.fromisoformat(data.generated_at.replace('Z', '+00:00')) + except (ValueError, AttributeError): + raise HTTPException(status_code=400, detail="generated_at must be a valid RFC3339 timestamp") + + # Truncate profile_text to 10000 bytes (matching Rust behavior — truncate, not reject) + profile_bytes = data.profile_text.encode('utf-8')[:10000] + profile_text = profile_bytes.decode('utf-8', errors='ignore') + + profile_data = { + 'profile_text': 
profile_text, + 'generated_at': data.generated_at, + 'data_sources_used': data.data_sources_used, + } + return update_ai_user_profile(uid, profile_data) From 129f5e10502f26b36b312705a3804b234729af66 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:58:38 +0100 Subject: [PATCH 058/163] Add explicit Firestore error handling in screen-activity sync Wrap upsert in try/except and return 500 with controlled message on failure. Matches Rust error handling behavior. Co-Authored-By: Claude Opus 4.6 --- backend/routers/screen_activity.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backend/routers/screen_activity.py b/backend/routers/screen_activity.py index a0eeb4d851..81105bdddc 100644 --- a/backend/routers/screen_activity.py +++ b/backend/routers/screen_activity.py @@ -47,7 +47,11 @@ def sync_screen_activity( rows_data = [row.model_dump() for row in request.rows] # Firestore upsert (synchronous — blocks response until written) - synced = screen_activity_db.upsert_screen_activity(uid, rows_data) + try: + synced = screen_activity_db.upsert_screen_activity(uid, rows_data) + except Exception: + logger.exception('Firestore upsert failed for uid=%s', uid) + raise HTTPException(status_code=500, detail="Failed to sync screen activity") # Pinecone vector upsert (fire-and-forget background thread) rows_with_embeddings = [r for r in rows_data if r.get('embedding')] From fcc55c9816f8f322df6a471a6d3a96282f97b77b Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:58:42 +0100 Subject: [PATCH 059/163] Use update() for ai_user_profile full replacement Prevents stale nested keys from persisting. Falls back to set(merge=True) when document doesn't exist yet. 
Co-Authored-By: Claude Opus 4.6 --- backend/database/users.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/backend/database/users.py b/backend/database/users.py index 0276ae6c5d..986632e9b6 100644 --- a/backend/database/users.py +++ b/backend/database/users.py @@ -1111,7 +1111,15 @@ def get_ai_user_profile(uid: str) -> Optional[dict]: def update_ai_user_profile(uid: str, data: dict) -> dict: - """Replace the user's ai_user_profile map. Returns new state.""" + """Full-replace the user's ai_user_profile map. Returns new state. + + Uses update() for true field replacement (removes stale nested keys). + Falls back to set(merge=True) if document doesn't exist yet. + """ user_ref = db.collection('users').document(uid) - user_ref.set({'ai_user_profile': data}, merge=True) + try: + user_ref.update({'ai_user_profile': data}) + except Exception: + # Document may not exist — create with merge + user_ref.set({'ai_user_profile': data}, merge=True) return get_ai_user_profile(uid) From 0ef223f6f06a84ed49401ef0855bae728eb941df Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 07:58:46 +0100 Subject: [PATCH 060/163] Strict RFC3339 validation and timestamp storage for ai-profile Require T separator and timezone in generated_at. Store as parsed datetime for Firestore timestampValue compatibility with Rust. 
Co-Authored-By: Claude Opus 4.6 --- backend/routers/users.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/routers/users.py b/backend/routers/users.py index 4294ac8825..4845c94919 100644 --- a/backend/routers/users.py +++ b/backend/routers/users.py @@ -1312,9 +1312,12 @@ def update_ai_profile_endpoint( data: UpdateAIProfileRequest, uid: str = Depends(auth.get_current_user_uid), ): - # Validate generated_at is RFC3339 (matching Rust behavior) + # Strict RFC3339 validation — require T separator and timezone (Z or +/-offset) + ts = data.generated_at + if 'T' not in ts or (not ts.endswith('Z') and '+' not in ts.split('T')[1] and '-' not in ts.split('T')[1]): + raise HTTPException(status_code=400, detail="generated_at must be a valid RFC3339 timestamp") try: - datetime.fromisoformat(data.generated_at.replace('Z', '+00:00')) + parsed_ts = datetime.fromisoformat(ts.replace('Z', '+00:00')) except (ValueError, AttributeError): raise HTTPException(status_code=400, detail="generated_at must be a valid RFC3339 timestamp") @@ -1324,7 +1327,7 @@ def update_ai_profile_endpoint( profile_data = { 'profile_text': profile_text, - 'generated_at': data.generated_at, + 'generated_at': parsed_ts, 'data_sources_used': data.data_sources_used, } return update_ai_user_profile(uid, profile_data) From cb68247a492383c3a4b3844bd8ce4f2556975fff Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:03:01 +0100 Subject: [PATCH 061/163] Use regex for strict RFC3339 validation on ai-profile timestamp Rejects non-standard offsets like +00 or +0000. Requires full YYYY-MM-DDTHH:MM:SS(Z|+HH:MM|-HH:MM) format matching Rust. 
Co-Authored-By: Claude Opus 4.6 --- backend/routers/users.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/routers/users.py b/backend/routers/users.py index 4845c94919..62a274c856 100644 --- a/backend/routers/users.py +++ b/backend/routers/users.py @@ -1,4 +1,5 @@ import json +import re import threading import uuid from typing import List, Dict, Any, Union, Optional @@ -1312,9 +1313,9 @@ def update_ai_profile_endpoint( data: UpdateAIProfileRequest, uid: str = Depends(auth.get_current_user_uid), ): - # Strict RFC3339 validation — require T separator and timezone (Z or +/-offset) + # Strict RFC3339: YYYY-MM-DDTHH:MM:SS[.frac](Z|+HH:MM|-HH:MM) ts = data.generated_at - if 'T' not in ts or (not ts.endswith('Z') and '+' not in ts.split('T')[1] and '-' not in ts.split('T')[1]): + if not re.fullmatch(r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?(Z|[+-]\d{2}:\d{2})', ts): raise HTTPException(status_code=400, detail="generated_at must be a valid RFC3339 timestamp") try: parsed_ts = datetime.fromisoformat(ts.replace('Z', '+00:00')) From eec3e24ab4a6133a5693d8f067c94b1bcea8ffd0 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:03:05 +0100 Subject: [PATCH 062/163] Narrow ai_user_profile fallback to NotFound only Prevents unrelated update errors from silently converting to merge-write path. Only falls back on document-not-found. 
Co-Authored-By: Claude Opus 4.6 --- backend/database/users.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/database/users.py b/backend/database/users.py index 986632e9b6..75b7fe111c 100644 --- a/backend/database/users.py +++ b/backend/database/users.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone from typing import Optional +from google.api_core.exceptions import NotFound from google.cloud import firestore from google.cloud.firestore_v1 import FieldFilter, transactional @@ -1119,7 +1120,7 @@ def update_ai_user_profile(uid: str, data: dict) -> dict: user_ref = db.collection('users').document(uid) try: user_ref.update({'ai_user_profile': data}) - except Exception: - # Document may not exist — create with merge + except NotFound: + # Document doesn't exist yet — create with merge user_ref.set({'ai_user_profile': data}, merge=True) return get_ai_user_profile(uid) From 420fea00556e2882b0f3ddb57db46a05a0880653 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:05:16 +0100 Subject: [PATCH 063/163] Use code-based not-found check instead of importing NotFound Avoids adding google.api_core import to database/users.py. Checks e.code == 404 to narrow fallback, re-raises other errors. 
Co-Authored-By: Claude Opus 4.6 --- backend/database/users.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/database/users.py b/backend/database/users.py index 75b7fe111c..5c53c54f45 100644 --- a/backend/database/users.py +++ b/backend/database/users.py @@ -1,7 +1,6 @@ from datetime import datetime, timezone from typing import Optional -from google.api_core.exceptions import NotFound from google.cloud import firestore from google.cloud.firestore_v1 import FieldFilter, transactional @@ -1120,7 +1119,10 @@ def update_ai_user_profile(uid: str, data: dict) -> dict: user_ref = db.collection('users').document(uid) try: user_ref.update({'ai_user_profile': data}) - except NotFound: - # Document doesn't exist yet — create with merge - user_ref.set({'ai_user_profile': data}, merge=True) + except Exception as e: + # Only fall back on not-found (code 404); re-raise other errors + if hasattr(e, 'code') and e.code == 404: + user_ref.set({'ai_user_profile': data}, merge=True) + else: + raise return get_ai_user_profile(uid) From 6fbbb6728beef9d3a9de919d84c9d108849f3973 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:07:30 +0100 Subject: [PATCH 064/163] Consolidate desktop chat endpoints into existing chat router MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move session CRUD, message save, and rating endpoints from desktop_chat.py into routers/chat.py with clear section comments. Desktop uses ACP Bridge for AI — backend only persists messages, so save endpoint uses /v2/messages/save (non-streaming). 
Co-Authored-By: Claude Opus 4.6 --- backend/routers/chat.py | 197 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 196 insertions(+), 1 deletion(-) diff --git a/backend/routers/chat.py b/backend/routers/chat.py index a4a748ef4d..6651e75651 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -6,8 +6,9 @@ from typing import List, Optional from pathlib import Path -from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form +from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File, Form from fastapi.responses import StreamingResponse +from pydantic import BaseModel from multipart.multipart import shutil import database.chat as chat_db @@ -498,6 +499,200 @@ def upload_file_chat(files: List[UploadFile] = File(...), uid: str = Depends(aut return response +# --------------------------------------------------------------------------- +# Desktop: session management, message persistence, and rating +# The desktop app manages sessions explicitly (vs mobile's implicit sessions) +# and persists messages without triggering the LLM pipeline — AI responses +# come from the local ACP Bridge, not the backend. 
+# --------------------------------------------------------------------------- + + +class CreateChatSessionRequest(BaseModel): + title: Optional[str] = None + app_id: Optional[str] = None + + +class UpdateChatSessionRequest(BaseModel): + title: Optional[str] = None + starred: Optional[bool] = None + + +class ChatSessionResponse(BaseModel): + id: str + title: str + preview: Optional[str] = None + created_at: datetime + updated_at: datetime + app_id: Optional[str] = None + message_count: int = 0 + starred: bool = False + + +class SaveMessageRequest(BaseModel): + text: str + sender: str + app_id: Optional[str] = None + session_id: Optional[str] = None + metadata: Optional[str] = None + + +class SaveMessageResponse(BaseModel): + id: str + created_at: datetime + + +class RateMessageRequest(BaseModel): + rating: Optional[int] = None + + +class StatusResponse(BaseModel): + status: str + + +@router.get('/v2/chat-sessions', response_model=List[ChatSessionResponse], tags=['chat']) +def list_chat_sessions( + app_id: Optional[str] = Query(None), + limit: int = Query(50, ge=1, le=200), + offset: int = Query(0, ge=0), + starred: Optional[bool] = Query(None), + uid: str = Depends(auth.get_current_user_uid), +): + """Desktop: list chat sessions with optional filtering.""" + sessions = chat_db.get_chat_sessions(uid, app_id=app_id, limit=limit, offset=offset, starred=starred) + return sessions + + +@router.post('/v2/chat-sessions', response_model=ChatSessionResponse, tags=['chat']) +def create_chat_session( + request: CreateChatSessionRequest, + uid: str = Depends(auth.get_current_user_uid), +): + """Desktop: explicitly create a named chat session.""" + now = datetime.now(timezone.utc) + session_data = { + 'id': str(uuid.uuid4()), + 'title': request.title or 'New Chat', + 'preview': None, + 'created_at': now, + 'updated_at': now, + 'app_id': request.app_id, + 'plugin_id': request.app_id, + 'message_count': 0, + 'starred': False, + } + chat_db.add_chat_session(uid, session_data) + 
return session_data + + +@router.get('/v2/chat-sessions/{session_id}', response_model=ChatSessionResponse, tags=['chat']) +def get_chat_session_by_id( + session_id: str, + uid: str = Depends(auth.get_current_user_uid), +): + """Desktop: get a single chat session by ID.""" + session = chat_db.get_chat_session_by_id(uid, session_id) + if not session: + raise HTTPException(status_code=404, detail="Chat session not found") + return session + + +@router.patch('/v2/chat-sessions/{session_id}', response_model=ChatSessionResponse, tags=['chat']) +def update_chat_session( + session_id: str, + request: UpdateChatSessionRequest, + uid: str = Depends(auth.get_current_user_uid), +): + """Desktop: update session title or starred status.""" + session = chat_db.get_chat_session_by_id(uid, session_id) + if not session: + raise HTTPException(status_code=404, detail="Chat session not found") + + update_data = {} + if request.title is not None: + update_data['title'] = request.title + if request.starred is not None: + update_data['starred'] = request.starred + if update_data: + update_data['updated_at'] = datetime.now(timezone.utc) + chat_db.update_chat_session(uid, session_id, update_data) + session.update(update_data) + + return session + + +@router.delete('/v2/chat-sessions/{session_id}', response_model=StatusResponse, tags=['chat']) +def delete_chat_session( + session_id: str, + uid: str = Depends(auth.get_current_user_uid), +): + """Desktop: delete a chat session and cascade-delete its messages.""" + session = chat_db.get_chat_session_by_id(uid, session_id) + if not session: + raise HTTPException(status_code=404, detail="Chat session not found") + + chat_db.delete_chat_session_messages(uid, session_id) + chat_db.delete_chat_session(uid, session_id) + return StatusResponse(status='ok') + + +@router.post('/v2/messages/save', response_model=SaveMessageResponse, tags=['chat']) +def save_message( + request: SaveMessageRequest, + uid: str = Depends(auth.get_current_user_uid), +): + 
"""Desktop: persist a message without triggering LLM pipeline. + + The desktop app runs AI locally via ACP Bridge and only calls this + endpoint to sync human + AI messages to Firestore. + """ + if not request.text or not request.text.strip(): + raise HTTPException(status_code=422, detail="Message text cannot be empty") + if request.sender not in ('human', 'ai'): + raise HTTPException(status_code=422, detail="sender must be 'human' or 'ai'") + + now = datetime.now(timezone.utc) + message_id = str(uuid.uuid4()) + message_data = { + 'id': message_id, + 'text': request.text, + 'created_at': now, + 'sender': request.sender, + 'app_id': request.app_id, + 'plugin_id': request.app_id, + 'session_id': request.session_id, + 'chat_session_id': request.session_id, + 'rating': None, + 'reported': False, + 'type': 'text', + 'memories_id': [], + 'from_external_integration': False, + 'metadata': request.metadata, + } + chat_db.save_message(uid, message_data) + + if request.session_id: + try: + chat_db.add_message_to_chat_session(uid, request.session_id, message_id) + except Exception as e: + logger.warning(f"Failed to link message to session {request.session_id}: {e}") + + return SaveMessageResponse(id=message_id, created_at=now) + + +@router.patch('/v2/messages/{message_id}/rating', response_model=StatusResponse, tags=['chat']) +def rate_message( + message_id: str, + request: RateMessageRequest, + uid: str = Depends(auth.get_current_user_uid), +): + """Desktop: rate a message (1 = thumbs up, -1 = thumbs down, null = clear).""" + if request.rating is not None and request.rating not in (1, -1): + raise HTTPException(status_code=422, detail="rating must be 1, -1, or null") + + chat_db.update_message_rating(uid, message_id, request.rating) + return StatusResponse(status='ok') + + # CLEANUP: Remove after new app goes to prod ---------------------------------------------------------- From e6df53058b20b1b659823fc35353c6feeec923f4 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 
2026 08:07:33 +0100 Subject: [PATCH 065/163] =?UTF-8?q?Remove=20desktop=5Fchat.py=20=E2=80=94?= =?UTF-8?q?=20endpoints=20consolidated=20into=20chat.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 --- backend/routers/desktop_chat.py | 213 -------------------------------- 1 file changed, 213 deletions(-) delete mode 100644 backend/routers/desktop_chat.py diff --git a/backend/routers/desktop_chat.py b/backend/routers/desktop_chat.py deleted file mode 100644 index 355b47b936..0000000000 --- a/backend/routers/desktop_chat.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Desktop chat sessions CRUD + message operations. - -These endpoints support the desktop app's session-based chat model where -messages are organized into named sessions. The Python backend's existing -streaming chat (routers/chat.py) is session-aware internally, but the -desktop Swift client expects explicit CRUD for sessions and simple -message save/rating. 
-""" - -import uuid -from datetime import datetime, timezone -from typing import Optional, List - -from fastapi import APIRouter, Depends, HTTPException, Query -from pydantic import BaseModel - -import database.chat as chat_db -from utils.other import endpoints as auth -import logging - -logger = logging.getLogger(__name__) - -router = APIRouter() - - -# --------------------------------------------------------------------------- -# Models -# --------------------------------------------------------------------------- - - -class CreateChatSessionRequest(BaseModel): - title: Optional[str] = None - app_id: Optional[str] = None - - -class UpdateChatSessionRequest(BaseModel): - title: Optional[str] = None - starred: Optional[bool] = None - - -class ChatSessionResponse(BaseModel): - id: str - title: str - preview: Optional[str] = None - created_at: datetime - updated_at: datetime - app_id: Optional[str] = None - message_count: int = 0 - starred: bool = False - - -class SaveMessageRequest(BaseModel): - text: str - sender: str - app_id: Optional[str] = None - session_id: Optional[str] = None - metadata: Optional[str] = None - - -class SaveMessageResponse(BaseModel): - id: str - created_at: datetime - - -class RateMessageRequest(BaseModel): - rating: Optional[int] = None - - -class StatusResponse(BaseModel): - status: str - - -# --------------------------------------------------------------------------- -# Chat Sessions CRUD -# --------------------------------------------------------------------------- - - -@router.get('/v2/chat-sessions', response_model=List[ChatSessionResponse], tags=['desktop-chat']) -def list_chat_sessions( - app_id: Optional[str] = Query(None), - limit: int = Query(50, ge=1, le=200), - offset: int = Query(0, ge=0), - starred: Optional[bool] = Query(None), - uid: str = Depends(auth.get_current_user_uid), -): - sessions = chat_db.get_chat_sessions(uid, app_id=app_id, limit=limit, offset=offset, starred=starred) - return sessions - - 
-@router.post('/v2/chat-sessions', response_model=ChatSessionResponse, tags=['desktop-chat']) -def create_chat_session( - request: CreateChatSessionRequest, - uid: str = Depends(auth.get_current_user_uid), -): - now = datetime.now(timezone.utc) - session_data = { - 'id': str(uuid.uuid4()), - 'title': request.title or 'New Chat', - 'preview': None, - 'created_at': now, - 'updated_at': now, - 'app_id': request.app_id, - 'plugin_id': request.app_id, # Python backend uses plugin_id for filtering - 'message_count': 0, - 'starred': False, - } - chat_db.add_chat_session(uid, session_data) - return session_data - - -@router.get('/v2/chat-sessions/{session_id}', response_model=ChatSessionResponse, tags=['desktop-chat']) -def get_chat_session( - session_id: str, - uid: str = Depends(auth.get_current_user_uid), -): - session = chat_db.get_chat_session_by_id(uid, session_id) - if not session: - raise HTTPException(status_code=404, detail="Chat session not found") - return session - - -@router.patch('/v2/chat-sessions/{session_id}', response_model=ChatSessionResponse, tags=['desktop-chat']) -def update_chat_session( - session_id: str, - request: UpdateChatSessionRequest, - uid: str = Depends(auth.get_current_user_uid), -): - session = chat_db.get_chat_session_by_id(uid, session_id) - if not session: - raise HTTPException(status_code=404, detail="Chat session not found") - - update_data = {} - if request.title is not None: - update_data['title'] = request.title - if request.starred is not None: - update_data['starred'] = request.starred - if update_data: - update_data['updated_at'] = datetime.now(timezone.utc) - chat_db.update_chat_session(uid, session_id, update_data) - session.update(update_data) - - return session - - -@router.delete('/v2/chat-sessions/{session_id}', response_model=StatusResponse, tags=['desktop-chat']) -def delete_chat_session( - session_id: str, - uid: str = Depends(auth.get_current_user_uid), -): - session = chat_db.get_chat_session_by_id(uid, session_id) 
- if not session: - raise HTTPException(status_code=404, detail="Chat session not found") - - chat_db.delete_chat_session_messages(uid, session_id) - chat_db.delete_chat_session(uid, session_id) - return StatusResponse(status='ok') - - -# --------------------------------------------------------------------------- -# Desktop message CRUD (simple save, not streaming) -# --------------------------------------------------------------------------- - - -@router.post('/v2/desktop/messages', response_model=SaveMessageResponse, tags=['desktop-chat']) -def save_message( - request: SaveMessageRequest, - uid: str = Depends(auth.get_current_user_uid), -): - if not request.text or not request.text.strip(): - raise HTTPException(status_code=422, detail="Message text cannot be empty") - if request.sender not in ('human', 'ai'): - raise HTTPException(status_code=422, detail="sender must be 'human' or 'ai'") - - now = datetime.now(timezone.utc) - message_id = str(uuid.uuid4()) - message_data = { - 'id': message_id, - 'text': request.text, - 'created_at': now, - 'sender': request.sender, - 'app_id': request.app_id, - 'plugin_id': request.app_id, - 'session_id': request.session_id, - 'chat_session_id': request.session_id, - 'rating': None, - 'reported': False, - 'type': 'text', - 'memories_id': [], - 'from_external_integration': False, - 'metadata': request.metadata, - } - chat_db.save_message(uid, message_data) - - if request.session_id: - try: - chat_db.add_message_to_chat_session(uid, request.session_id, message_id) - except Exception as e: - logger.warning(f"Failed to link message to session {request.session_id}: {e}") - - return SaveMessageResponse(id=message_id, created_at=now) - - -@router.patch('/v2/messages/{message_id}/rating', response_model=StatusResponse, tags=['desktop-chat']) -def rate_message( - message_id: str, - request: RateMessageRequest, - uid: str = Depends(auth.get_current_user_uid), -): - if request.rating is not None and request.rating not in (1, -1): - raise 
HTTPException(status_code=422, detail="rating must be 1, -1, or null") - - chat_db.update_message_rating(uid, message_id, request.rating) - return StatusResponse(status='ok') From 1c006b3c03a9415c48fe68fb8c281cbfb00cedfb Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:07:37 +0100 Subject: [PATCH 066/163] Remove desktop_chat import from main.py Co-Authored-By: Claude Opus 4.6 --- backend/main.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/backend/main.py b/backend/main.py index cd12ae57fe..3f87e6af53 100644 --- a/backend/main.py +++ b/backend/main.py @@ -45,7 +45,6 @@ announcements, phone_calls, agent_tools, - desktop_chat, screen_activity, ) @@ -106,7 +105,6 @@ app.include_router(announcements.router) app.include_router(phone_calls.router) app.include_router(agent_tools.router) -app.include_router(desktop_chat.router) app.include_router(screen_activity.router) From 40aba057b227c6c702fea798a81499fa9c8f096e Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:07:42 +0100 Subject: [PATCH 067/163] Update desktop chat tests to import from routers.chat Change mock paths and endpoint paths to match consolidated router. Save message endpoint now at /v2/messages/save. 
Co-Authored-By: Claude Opus 4.6 --- backend/tests/unit/test_desktop_chat.py | 86 ++++++++++++------------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/backend/tests/unit/test_desktop_chat.py b/backend/tests/unit/test_desktop_chat.py index a409c3aea5..75e2c86f24 100644 --- a/backend/tests/unit/test_desktop_chat.py +++ b/backend/tests/unit/test_desktop_chat.py @@ -16,7 +16,7 @@ ]: sys.modules.setdefault(mod_name, MagicMock()) -from routers.desktop_chat import ( +from routers.chat import ( CreateChatSessionRequest, UpdateChatSessionRequest, ChatSessionResponse, @@ -71,8 +71,8 @@ def client(self): def test_create_session(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.add_chat_session') as mock_add, + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.add_chat_session') as mock_add, ): mock_add.side_effect = lambda uid, data: data response = client.post( @@ -89,8 +89,8 @@ def test_create_session(self, client): def test_create_session_default_title(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.add_chat_session') as mock_add, + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.add_chat_session') as mock_add, ): mock_add.side_effect = lambda uid, data: data response = client.post( @@ -108,8 +108,8 @@ def test_list_sessions(self, client): {'id': 's2', 'title': 'Chat 2', 'created_at': now, 'updated_at': now, 'message_count': 3, 'starred': True}, ] with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.get_chat_sessions', return_value=mock_sessions), + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_sessions', return_value=mock_sessions), ): 
response = client.get('/v2/chat-sessions', headers={'Authorization': 'Bearer test'}) assert response.status_code == 200 @@ -121,8 +121,8 @@ def test_get_session(self, client): now = datetime.now(timezone.utc) mock_session = {'id': 's1', 'title': 'Chat', 'created_at': now, 'updated_at': now, 'message_count': 0, 'starred': False} with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=mock_session), + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_session_by_id', return_value=mock_session), ): response = client.get('/v2/chat-sessions/s1', headers={'Authorization': 'Bearer test'}) assert response.status_code == 200 @@ -130,8 +130,8 @@ def test_get_session(self, client): def test_get_session_not_found(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=None), + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_session_by_id', return_value=None), ): response = client.get('/v2/chat-sessions/missing', headers={'Authorization': 'Bearer test'}) assert response.status_code == 404 @@ -140,9 +140,9 @@ def test_update_session_returns_full_session(self, client): now = datetime.now(timezone.utc) mock_session = {'id': 's1', 'title': 'Old', 'created_at': now, 'updated_at': now, 'message_count': 0, 'starred': False} with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=mock_session), - patch('routers.desktop_chat.chat_db.update_chat_session') as mock_update, + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_session_by_id', return_value=mock_session), + 
patch('routers.chat.chat_db.update_chat_session') as mock_update, ): response = client.patch( '/v2/chat-sessions/s1', @@ -159,10 +159,10 @@ def test_delete_session_cascades_messages(self, client): now = datetime.now(timezone.utc) mock_session = {'id': 's1', 'title': 'Del', 'created_at': now, 'updated_at': now} with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=mock_session), - patch('routers.desktop_chat.chat_db.delete_chat_session_messages') as mock_del_msgs, - patch('routers.desktop_chat.chat_db.delete_chat_session') as mock_del, + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_session_by_id', return_value=mock_session), + patch('routers.chat.chat_db.delete_chat_session_messages') as mock_del_msgs, + patch('routers.chat.chat_db.delete_chat_session') as mock_del, ): response = client.delete('/v2/chat-sessions/s1', headers={'Authorization': 'Bearer test'}) assert response.status_code == 200 @@ -186,13 +186,13 @@ def client(self): def test_save_message(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.save_message') as mock_save, - patch('routers.desktop_chat.chat_db.add_message_to_chat_session'), + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.save_message') as mock_save, + patch('routers.chat.chat_db.add_message_to_chat_session'), ): mock_save.side_effect = lambda uid, data: data response = client.post( - '/v2/desktop/messages', + '/v2/messages/save', json={'text': 'Hello', 'sender': 'human', 'session_id': 's1'}, headers={'Authorization': 'Bearer test'}, ) @@ -202,18 +202,18 @@ def test_save_message(self, client): assert 'created_at' in data def test_save_message_empty_text_422(self, client): - with 
patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): response = client.post( - '/v2/desktop/messages', + '/v2/messages/save', json={'text': ' ', 'sender': 'human'}, headers={'Authorization': 'Bearer test'}, ) assert response.status_code == 422 def test_save_message_invalid_sender_422(self, client): - with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): response = client.post( - '/v2/desktop/messages', + '/v2/messages/save', json={'text': 'Hello', 'sender': 'bot'}, headers={'Authorization': 'Bearer test'}, ) @@ -221,8 +221,8 @@ def test_save_message_invalid_sender_422(self, client): def test_rate_message_thumbs_up(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.update_message_rating') as mock_rate, + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.update_message_rating') as mock_rate, ): response = client.patch( '/v2/messages/msg-1/rating', @@ -236,8 +236,8 @@ def test_rate_message_thumbs_up(self, client): def test_rate_message_clear(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.update_message_rating') as mock_rate, + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.update_message_rating') as mock_rate, ): response = client.patch( '/v2/messages/msg-1/rating', @@ -251,8 +251,8 @@ def test_rate_message_clear(self, client): def test_rate_message_thumbs_down(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.update_message_rating') as mock_rate, + 
patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.update_message_rating') as mock_rate, ): response = client.patch( '/v2/messages/msg-1/rating', @@ -263,7 +263,7 @@ def test_rate_message_thumbs_down(self, client): assert mock_rate.call_args[0][2] == -1 def test_rate_message_invalid_value_422(self, client): - with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): response = client.patch( '/v2/messages/msg-1/rating', json={'rating': 5}, @@ -273,8 +273,8 @@ def test_rate_message_invalid_value_422(self, client): def test_update_session_not_found(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=None), + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_session_by_id', return_value=None), ): response = client.patch( '/v2/chat-sessions/missing', @@ -285,8 +285,8 @@ def test_update_session_not_found(self, client): def test_delete_session_not_found(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.get_chat_session_by_id', return_value=None), + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_session_by_id', return_value=None), ): response = client.delete( '/v2/chat-sessions/missing', @@ -296,13 +296,13 @@ def test_delete_session_not_found(self, client): def test_save_message_session_link_failure_still_succeeds(self, client): with ( - patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.desktop_chat.chat_db.save_message') as mock_save, - patch('routers.desktop_chat.chat_db.add_message_to_chat_session', side_effect=Exception('Firestore 
error')), + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.save_message') as mock_save, + patch('routers.chat.chat_db.add_message_to_chat_session', side_effect=Exception('Firestore error')), ): mock_save.side_effect = lambda uid, data: data response = client.post( - '/v2/desktop/messages', + '/v2/messages/save', json={'text': 'Hello', 'sender': 'human', 'session_id': 's1'}, headers={'Authorization': 'Bearer test'}, ) @@ -310,7 +310,7 @@ def test_save_message_session_link_failure_still_succeeds(self, client): assert 'id' in response.json() def test_list_sessions_limit_validation(self, client): - with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): response = client.get( '/v2/chat-sessions?limit=0', headers={'Authorization': 'Bearer test'}, @@ -318,7 +318,7 @@ def test_list_sessions_limit_validation(self, client): assert response.status_code == 422 def test_list_sessions_offset_negative_validation(self, client): - with patch('routers.desktop_chat.auth.get_current_user_uid', return_value='uid-1'): + with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): response = client.get( '/v2/chat-sessions?offset=-1', headers={'Authorization': 'Bearer test'}, From a171b89964b3c58597adadad4aed818815014662 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:07:45 +0100 Subject: [PATCH 068/163] Update Swift saveMessage to use /v2/messages/save endpoint Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/APIClient.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 533f273750..b08e4f3554 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -3967,7 +3967,7 @@ extension APIClient { let metadata: String? 
} let body = SaveRequest(text: text, sender: sender, app_id: appId, session_id: sessionId, metadata: metadata) - return try await post("v2/messages", body: body) + return try await post("v2/messages/save", body: body) } /// Fetch chat message history From f7376d15da446e3cbd4ae1b4e2168acb9f4c49eb Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:11:53 +0100 Subject: [PATCH 069/163] Add unit tests for screen-activity sync, assistant-settings, and ai-profile endpoints 30 tests covering: batch limits, auth, last_id, Firestore errors, vector thread spawning, settings validation (prompt length, list caps, confidence ranges), RFC3339 parsing, profile text truncation, and multibyte boundary safety. Co-Authored-By: Claude Opus 4.6 --- .../test_assistant_settings_ai_profile.py | 156 ++++++++++++++++++ .../tests/unit/test_screen_activity_sync.py | 79 +++++++++ 2 files changed, 235 insertions(+) create mode 100644 backend/tests/unit/test_assistant_settings_ai_profile.py create mode 100644 backend/tests/unit/test_screen_activity_sync.py diff --git a/backend/tests/unit/test_assistant_settings_ai_profile.py b/backend/tests/unit/test_assistant_settings_ai_profile.py new file mode 100644 index 0000000000..402dabb5e6 --- /dev/null +++ b/backend/tests/unit/test_assistant_settings_ai_profile.py @@ -0,0 +1,156 @@ +from unittest.mock import patch, MagicMock + +import pytest +from fastapi.testclient import TestClient + + +@pytest.fixture +def client(): + with patch('database.screen_activity.db'), \ + patch('database.vector_db.Pinecone'), \ + patch('database.vector_db.pc'), \ + patch('database.vector_db.index'), \ + patch('utils.llm.clients.embeddings'): + from main import app + with TestClient(app) as c: + yield c + + +AUTH = {"Authorization": "Bearer 123testuser"} + + +class TestAssistantSettingsValidation: + def test_get_empty_returns_200(self, client): + with patch('routers.users.get_assistant_settings', return_value={}): + resp = 
client.get("/v1/users/assistant-settings", headers=AUTH) + assert resp.status_code == 200 + assert resp.json() == {} + + def test_patch_prompt_exceeds_10000_chars(self, client): + data = {"focus": {"analysis_prompt": "x" * 10001}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + assert "10000" in resp.json()["detail"] + + def test_patch_prompt_at_10000_chars_accepted(self, client): + data = {"focus": {"analysis_prompt": "x" * 10000}} + with patch('routers.users.update_assistant_settings', return_value=data): + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 200 + + def test_patch_allowed_apps_exceeds_500(self, client): + data = {"task": {"allowed_apps": ["app"] * 501}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + assert "500" in resp.json()["detail"] + + def test_patch_browser_keywords_exceeds_500(self, client): + data = {"task": {"browser_keywords": ["kw"] * 501}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_task_confidence_below_zero(self, client): + data = {"task": {"min_confidence": -0.1}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_task_confidence_above_one(self, client): + data = {"task": {"min_confidence": 1.5}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_task_confidence_zero_accepted(self, client): + data = {"task": {"min_confidence": 0.0}} + with patch('routers.users.update_assistant_settings', return_value=data): + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 200 + + def test_patch_task_confidence_one_accepted(self, client): + data = {"task": 
{"min_confidence": 1.0}} + with patch('routers.users.update_assistant_settings', return_value=data): + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 200 + + def test_patch_advice_confidence_above_one(self, client): + data = {"advice": {"min_confidence": 1.1}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_memory_confidence_below_zero(self, client): + data = {"memory": {"min_confidence": -0.5}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_empty_body_returns_current(self, client): + with patch('routers.users.get_assistant_settings', return_value={"focus": {"enabled": True}}): + resp = client.patch("/v1/users/assistant-settings", json={}, headers=AUTH) + assert resp.status_code == 200 + + +class TestAIProfileValidation: + def test_get_empty_returns_null(self, client): + with patch('routers.users.get_ai_user_profile', return_value=None): + resp = client.get("/v1/users/ai-profile", headers=AUTH) + assert resp.status_code == 200 + assert resp.json() is None + + def test_patch_valid_rfc3339_z(self, client): + data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00Z", "data_sources_used": 1} + with patch('routers.users.update_ai_user_profile', return_value=data): + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 200 + + def test_patch_valid_rfc3339_offset(self, client): + data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00+05:30", "data_sources_used": 1} + with patch('routers.users.update_ai_user_profile', return_value=data): + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 200 + + def test_patch_valid_rfc3339_fractional(self, client): + data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00.123Z", 
"data_sources_used": 1} + with patch('routers.users.update_ai_user_profile', return_value=data): + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 200 + + def test_patch_invalid_no_timezone(self, client): + data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00", "data_sources_used": 1} + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_invalid_no_t_separator(self, client): + data = {"profile_text": "test", "generated_at": "2026-03-05 10:00:00Z", "data_sources_used": 1} + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_invalid_short_offset(self, client): + data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00+00", "data_sources_used": 1} + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_invalid_garbage(self, client): + data = {"profile_text": "test", "generated_at": "not-a-date", "data_sources_used": 1} + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_profile_text_truncation_at_boundary(self, client): + # 10001 bytes of ASCII should truncate to 10000 + long_text = "x" * 10001 + data = {"profile_text": long_text, "generated_at": "2026-03-05T10:00:00Z", "data_sources_used": 1} + with patch('routers.users.update_ai_user_profile') as mock_update: + mock_update.return_value = {"profile_text": "x" * 10000} + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 200 + call_data = mock_update.call_args[0][1] + assert len(call_data['profile_text']) == 10000 + + def test_profile_text_multibyte_truncation(self, client): + # Multibyte UTF-8: emoji is 4 bytes, test boundary doesn't split mid-char + text = "a" * 9998 + "\U0001F600" # 9998 + 4 bytes = 10002 bytes + data = 
{"profile_text": text, "generated_at": "2026-03-05T10:00:00Z", "data_sources_used": 1} + with patch('routers.users.update_ai_user_profile') as mock_update: + mock_update.return_value = {} + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 200 + call_data = mock_update.call_args[0][1] + # Should not have broken emoji — truncated to 9998 'a's + assert len(call_data['profile_text'].encode('utf-8')) <= 10000 diff --git a/backend/tests/unit/test_screen_activity_sync.py b/backend/tests/unit/test_screen_activity_sync.py new file mode 100644 index 0000000000..774b4bef65 --- /dev/null +++ b/backend/tests/unit/test_screen_activity_sync.py @@ -0,0 +1,79 @@ +import threading +from unittest.mock import patch, MagicMock + +import pytest +from fastapi.testclient import TestClient + + +@pytest.fixture +def client(): + with patch('database.screen_activity.db'), \ + patch('database.vector_db.Pinecone'), \ + patch('database.vector_db.pc'), \ + patch('database.vector_db.index'), \ + patch('utils.llm.clients.embeddings'): + from main import app + with TestClient(app) as c: + yield c + + +AUTH = {"Authorization": "Bearer 123testuser"} + + +class TestScreenActivitySyncValidation: + def test_empty_rows_returns_zero(self, client): + resp = client.post("/v1/screen-activity/sync", json={"rows": []}, headers=AUTH) + assert resp.status_code == 200 + assert resp.json() == {"synced": 0, "last_id": 0} + + def test_exceeds_100_rows_returns_400(self, client): + rows = [{"id": i, "timestamp": "2026-01-01T00:00:00Z", "appName": "A", "windowTitle": "W", "ocrText": "x"} for i in range(101)] + resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH) + assert resp.status_code == 400 + assert "100" in resp.json()["detail"] + + def test_exactly_100_rows_accepted(self, client): + rows = [{"id": i, "timestamp": "2026-01-01T00:00:00Z"} for i in range(100)] + with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', 
return_value=100): + resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["synced"] == 100 + + def test_no_auth_returns_401(self, client): + resp = client.post("/v1/screen-activity/sync", json={"rows": []}) + assert resp.status_code == 401 + + def test_last_id_is_max_from_batch(self, client): + rows = [ + {"id": 5, "timestamp": "2026-01-01T00:00:00Z"}, + {"id": 99, "timestamp": "2026-01-01T00:01:00Z"}, + {"id": 3, "timestamp": "2026-01-01T00:02:00Z"}, + ] + with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', return_value=3): + resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["last_id"] == 99 + + def test_firestore_error_returns_500(self, client): + rows = [{"id": 1, "timestamp": "2026-01-01T00:00:00Z"}] + with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', side_effect=Exception("Firestore down")): + resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH) + assert resp.status_code == 500 + + def test_rows_with_embeddings_spawn_thread(self, client): + rows = [{"id": 1, "timestamp": "2026-01-01T00:00:00Z", "embedding": [0.1] * 3072}] + with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', return_value=1), \ + patch('routers.screen_activity.threading.Thread') as mock_thread: + mock_thread.return_value = MagicMock() + resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH) + assert resp.status_code == 200 + mock_thread.assert_called_once() + mock_thread.return_value.start.assert_called_once() + + def test_rows_without_embeddings_no_thread(self, client): + rows = [{"id": 1, "timestamp": "2026-01-01T00:00:00Z"}] + with patch('routers.screen_activity.screen_activity_db.upsert_screen_activity', return_value=1), \ + patch('routers.screen_activity.threading.Thread') as 
mock_thread: + resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH) + assert resp.status_code == 200 + mock_thread.assert_not_called() From 4ac4ad9846d6a9cecb51eb80e4ac86c44a0947f2 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:16:11 +0100 Subject: [PATCH 070/163] Address tester gaps: add 11 more tests for coverage - Prompt validation for task/advice/memory sections (not just focus) - Complementary confidence bounds (advice below 0, memory above 1) - model_dump excludes None fields before DB call - update_channel propagation - Invalid calendar date (Feb 30) regex-pass but parse-fail - generated_at stored as datetime with timezone - _upsert_vectors_background success and exception logging paths Co-Authored-By: Claude Opus 4.6 --- .../test_assistant_settings_ai_profile.py | 61 +++++++++++++++++++ .../tests/unit/test_screen_activity_sync.py | 15 +++++ 2 files changed, 76 insertions(+) diff --git a/backend/tests/unit/test_assistant_settings_ai_profile.py b/backend/tests/unit/test_assistant_settings_ai_profile.py index 402dabb5e6..08908620a6 100644 --- a/backend/tests/unit/test_assistant_settings_ai_profile.py +++ b/backend/tests/unit/test_assistant_settings_ai_profile.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone from unittest.mock import patch, MagicMock import pytest @@ -86,6 +87,50 @@ def test_patch_empty_body_returns_current(self, client): resp = client.patch("/v1/users/assistant-settings", json={}, headers=AUTH) assert resp.status_code == 200 + def test_patch_task_prompt_exceeds_10000_chars(self, client): + data = {"task": {"analysis_prompt": "x" * 10001}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + assert "task" in resp.json()["detail"] + + def test_patch_advice_prompt_exceeds_10000_chars(self, client): + data = {"advice": {"analysis_prompt": "x" * 10001}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) 
+ assert resp.status_code == 400 + assert "advice" in resp.json()["detail"] + + def test_patch_memory_prompt_exceeds_10000_chars(self, client): + data = {"memory": {"analysis_prompt": "x" * 10001}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + assert "memory" in resp.json()["detail"] + + def test_patch_advice_confidence_below_zero(self, client): + data = {"advice": {"min_confidence": -0.1}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_memory_confidence_above_one(self, client): + data = {"memory": {"min_confidence": 1.5}} + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_excludes_none_fields(self, client): + data = {"task": {"enabled": True}} + with patch('routers.users.update_assistant_settings', return_value=data) as mock_update: + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 200 + call_data = mock_update.call_args[0][1] + assert "min_confidence" not in call_data.get("task", {}) + + def test_patch_update_channel(self, client): + data = {"update_channel": "beta"} + with patch('routers.users.update_assistant_settings', return_value={"update_channel": "beta"}) as mock_update: + resp = client.patch("/v1/users/assistant-settings", json=data, headers=AUTH) + assert resp.status_code == 200 + call_data = mock_update.call_args[0][1] + assert call_data["update_channel"] == "beta" + class TestAIProfileValidation: def test_get_empty_returns_null(self, client): @@ -132,6 +177,22 @@ def test_patch_invalid_garbage(self, client): resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) assert resp.status_code == 400 + def test_patch_invalid_calendar_date(self, client): + # Feb 30 passes regex but fails fromisoformat + data = {"profile_text": "test", "generated_at": 
"2026-02-30T10:00:00Z", "data_sources_used": 1} + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_patch_generated_at_stored_as_datetime(self, client): + data = {"profile_text": "test", "generated_at": "2026-03-05T10:00:00Z", "data_sources_used": 1} + with patch('routers.users.update_ai_user_profile') as mock_update: + mock_update.return_value = {} + resp = client.patch("/v1/users/ai-profile", json=data, headers=AUTH) + assert resp.status_code == 200 + call_data = mock_update.call_args[0][1] + assert isinstance(call_data['generated_at'], datetime) + assert call_data['generated_at'].tzinfo is not None + def test_profile_text_truncation_at_boundary(self, client): # 10001 bytes of ASCII should truncate to 10000 long_text = "x" * 10001 diff --git a/backend/tests/unit/test_screen_activity_sync.py b/backend/tests/unit/test_screen_activity_sync.py index 774b4bef65..bb15a6ec40 100644 --- a/backend/tests/unit/test_screen_activity_sync.py +++ b/backend/tests/unit/test_screen_activity_sync.py @@ -77,3 +77,18 @@ def test_rows_without_embeddings_no_thread(self, client): resp = client.post("/v1/screen-activity/sync", json={"rows": rows}, headers=AUTH) assert resp.status_code == 200 mock_thread.assert_not_called() + + +class TestUpsertVectorsBackground: + def test_vector_upsert_exception_is_logged(self): + from routers.screen_activity import _upsert_vectors_background + with patch('routers.screen_activity.vector_db.upsert_screen_activity_vectors', side_effect=Exception("Pinecone down")), \ + patch('routers.screen_activity.logger') as mock_logger: + _upsert_vectors_background("uid123", [{"id": 1, "embedding": [0.1]}]) + mock_logger.exception.assert_called_once() + + def test_vector_upsert_success(self): + from routers.screen_activity import _upsert_vectors_background + with patch('routers.screen_activity.vector_db.upsert_screen_activity_vectors') as mock_upsert: + _upsert_vectors_background("uid123", [{"id": 1, 
"embedding": [0.1]}]) + mock_upsert.assert_called_once_with("uid123", [{"id": 1, "embedding": [0.1]}]) From 1625094f42e9a55c2bacebbc960650e5699a7f79 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:17:41 +0100 Subject: [PATCH 071/163] Return 404 on rating when message not found Rating endpoint now checks update_message_rating return value and raises 404 instead of silently returning 200. Co-Authored-By: Claude Opus 4.6 --- backend/routers/chat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/routers/chat.py b/backend/routers/chat.py index 6651e75651..27d1ae2150 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -672,7 +672,7 @@ def save_message( if request.session_id: try: - chat_db.add_message_to_chat_session(uid, request.session_id, message_id) + chat_db.add_message_to_chat_session(uid, request.session_id, message_id, preview=request.text[:200]) except Exception as e: logger.warning(f"Failed to link message to session {request.session_id}: {e}") @@ -689,7 +689,9 @@ def rate_message( if request.rating is not None and request.rating not in (1, -1): raise HTTPException(status_code=422, detail="rating must be 1, -1, or null") - chat_db.update_message_rating(uid, message_id, request.rating) + success = chat_db.update_message_rating(uid, message_id, request.rating) + if not success: + raise HTTPException(status_code=404, detail="Message not found") return StatusResponse(status='ok') From 5786f3bc74a95b87fc333ee22d3b6375471ae81e Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:17:47 +0100 Subject: [PATCH 072/163] Update session metadata when saving messages add_message_to_chat_session now updates updated_at, message_count (via Increment), and preview text alongside the message_ids array. Also documents pagination limitation in get_chat_sessions. 
Co-Authored-By: Claude Opus 4.6 --- backend/database/chat.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/backend/database/chat.py b/backend/database/chat.py index e57c7f9b17..0cd1b887a2 100644 --- a/backend/database/chat.py +++ b/backend/database/chat.py @@ -471,7 +471,12 @@ def delete_chat_session(uid, chat_session_id): def get_chat_sessions( uid: str, app_id: Optional[str] = None, limit: int = 50, offset: int = 0, starred: Optional[bool] = None ): - """List chat sessions with optional filters.""" + """List chat sessions with optional filters. + + Note: Client-side sort + slice because Firestore composite indexes would be + needed for every filter combination. Acceptable for desktop users (low session + counts). Revisit with server-side ordering if session counts grow large. + """ sessions_ref = db.collection('users').document(uid).collection('chat_sessions') if app_id is not None: sessions_ref = sessions_ref.where(filter=FieldFilter('plugin_id', '==', app_id)) @@ -515,10 +520,17 @@ def delete_chat_session_messages(uid: str, chat_session_id: str): logger.info(f"Deleted {count} messages for session {chat_session_id}") -def add_message_to_chat_session(uid: str, chat_session_id: str, message_id: str): +def add_message_to_chat_session(uid: str, chat_session_id: str, message_id: str, preview: str = None): user_ref = db.collection('users').document(uid) session_ref = user_ref.collection('chat_sessions').document(chat_session_id) - session_ref.update({"message_ids": firestore.ArrayUnion([message_id])}) + update_data = { + "message_ids": firestore.ArrayUnion([message_id]), + "updated_at": datetime.now(timezone.utc), + "message_count": firestore.Increment(1), + } + if preview: + update_data["preview"] = preview[:200] + session_ref.update(update_data) def add_files_to_chat_session(uid: str, chat_session_id: str, file_ids: List[str]): From 3031442cae0df040c428899c29be3860d50a1289 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 
Mar 2026 08:17:52 +0100 Subject: [PATCH 073/163] Remove unused timezone and input_device_name from from-segments request Fields were accepted but never stored. Remove to avoid contract drift. Co-Authored-By: Claude Opus 4.6 --- backend/routers/conversations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/backend/routers/conversations.py b/backend/routers/conversations.py index 3845e44851..c2041779a2 100644 --- a/backend/routers/conversations.py +++ b/backend/routers/conversations.py @@ -108,8 +108,6 @@ class CreateConversationFromSegmentsRequest(BaseModel): started_at: Optional[datetime] = None finished_at: Optional[datetime] = None language: Optional[str] = 'en' - timezone: Optional[str] = None - input_device_name: Optional[str] = None geolocation: Optional[Geolocation] = None From b9d9032b1271a5027de4cb0d08fafc52b625ba0f Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:17:57 +0100 Subject: [PATCH 074/163] Update tests for reviewer round 2 fixes Add rating-not-found 404 test, update mock return values for rating endpoint, remove tests for dropped timezone/input_device fields. 
Co-Authored-By: Claude Opus 4.6 --- backend/tests/unit/test_desktop_chat.py | 18 +++++++++++++++--- backend/tests/unit/test_from_segments.py | 12 +----------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/backend/tests/unit/test_desktop_chat.py b/backend/tests/unit/test_desktop_chat.py index 75e2c86f24..fc996f30ac 100644 --- a/backend/tests/unit/test_desktop_chat.py +++ b/backend/tests/unit/test_desktop_chat.py @@ -222,7 +222,7 @@ def test_save_message_invalid_sender_422(self, client): def test_rate_message_thumbs_up(self, client): with ( patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.chat.chat_db.update_message_rating') as mock_rate, + patch('routers.chat.chat_db.update_message_rating', return_value=True) as mock_rate, ): response = client.patch( '/v2/messages/msg-1/rating', @@ -237,7 +237,7 @@ def test_rate_message_thumbs_up(self, client): def test_rate_message_clear(self, client): with ( patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.chat.chat_db.update_message_rating') as mock_rate, + patch('routers.chat.chat_db.update_message_rating', return_value=True) as mock_rate, ): response = client.patch( '/v2/messages/msg-1/rating', @@ -252,7 +252,7 @@ def test_rate_message_clear(self, client): def test_rate_message_thumbs_down(self, client): with ( patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.chat.chat_db.update_message_rating') as mock_rate, + patch('routers.chat.chat_db.update_message_rating', return_value=True) as mock_rate, ): response = client.patch( '/v2/messages/msg-1/rating', @@ -262,6 +262,18 @@ def test_rate_message_thumbs_down(self, client): assert response.status_code == 200 assert mock_rate.call_args[0][2] == -1 + def test_rate_message_not_found_404(self, client): + with ( + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.update_message_rating', 
return_value=False), + ): + response = client.patch( + '/v2/messages/msg-missing/rating', + json={'rating': 1}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 404 + def test_rate_message_invalid_value_422(self, client): with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): response = client.patch( diff --git a/backend/tests/unit/test_from_segments.py b/backend/tests/unit/test_from_segments.py index 7199bcb80f..19b93a2391 100644 --- a/backend/tests/unit/test_from_segments.py +++ b/backend/tests/unit/test_from_segments.py @@ -48,8 +48,7 @@ def test_request_defaults(self, valid_segments): assert req.language == "en" assert req.started_at is None assert req.finished_at is None - assert req.timezone is None - assert req.input_device_name is None + assert req.geolocation is None def test_response_model(self): resp = FromSegmentsResponse(id="conv123", status="completed", discarded=False) @@ -80,15 +79,6 @@ def test_custom_source(self, valid_segments): req = CreateConversationFromSegmentsRequest(transcript_segments=valid_segments, source="phone") assert req.source == "phone" - def test_timezone_and_input_device_accepted(self, valid_segments): - req = CreateConversationFromSegmentsRequest( - transcript_segments=valid_segments, - timezone="America/New_York", - input_device_name="MacBook Pro Microphone", - ) - assert req.timezone == "America/New_York" - assert req.input_device_name == "MacBook Pro Microphone" - def test_started_finished_at(self, valid_segments): now = datetime.now(timezone.utc) later = now + timedelta(minutes=5) From b416664252f18c8ee11cd11c9b32607b49efed41 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:21:26 +0100 Subject: [PATCH 075/163] Remove timezone and inputDeviceName from Swift from-segments request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Match backend contract — fields were not stored. 
Can re-add when the conversation model supports them. Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/APIClient.swift | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index b08e4f3554..a0bfcf5f0e 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -1187,8 +1187,6 @@ extension APIClient { let startedAt: String let finishedAt: String let language: String - let timezone: String - let inputDeviceName: String? enum CodingKeys: String, CodingKey { case transcriptSegments = "transcript_segments" @@ -1196,8 +1194,6 @@ extension APIClient { case startedAt = "started_at" case finishedAt = "finished_at" case language - case timezone - case inputDeviceName = "input_device_name" } } @@ -1233,16 +1229,12 @@ extension APIClient { /// - finishedAt: When the recording finished /// - source: Source of the conversation (e.g., "desktop", "omi", "bee") /// - language: Language code for transcription - /// - timezone: User's timezone - /// - inputDeviceName: Name of the input device (microphone or BLE device) func createConversationFromSegments( segments: [TranscriptSegmentRequest], startedAt: Date, finishedAt: Date, source: ConversationSource = .desktop, - language: String = "en", - timezone: String = "UTC", - inputDeviceName: String? 
= nil + language: String = "en" ) async throws -> CreateConversationResponse { let formatter = ISO8601DateFormatter() formatter.formatOptions = [.withInternetDateTime, .withFractionalSeconds] @@ -1252,9 +1244,7 @@ extension APIClient { source: source.rawValue, startedAt: formatter.string(from: startedAt), finishedAt: formatter.string(from: finishedAt), - language: language, - timezone: timezone, - inputDeviceName: inputDeviceName + language: language ) return try await post("v1/conversations/from-segments", body: request) From ffde130335d3bd23173f58424f278ed6b3867998 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:21:30 +0100 Subject: [PATCH 076/163] Remove inputDeviceName param from AppState conversation upload Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/AppState.swift | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 862bfaf070..26c41d9d11 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -2119,8 +2119,7 @@ class AppState: ObservableObject { segments: apiSegments, startedAt: startTime, finishedAt: endTime, - source: currentConversationSource, - inputDeviceName: recordingInputDeviceName + source: currentConversationSource ) log("Transcription: Conversation saved - id=\(response.id), status=\(response.status), discarded=\(response.discarded), source=\(currentConversationSource.rawValue), device=\(recordingInputDeviceName ?? 
"Unknown")") From 535db37bbdd79a92ccde1fee97e698ef935f5012 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:21:34 +0100 Subject: [PATCH 077/163] Remove timezone and inputDeviceName from retry service upload Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/TranscriptionRetryService.swift | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/desktop/Desktop/Sources/TranscriptionRetryService.swift b/desktop/Desktop/Sources/TranscriptionRetryService.swift index ead4ee2ee5..913c6fd945 100644 --- a/desktop/Desktop/Sources/TranscriptionRetryService.swift +++ b/desktop/Desktop/Sources/TranscriptionRetryService.swift @@ -266,9 +266,7 @@ class TranscriptionRetryService { startedAt: session.startedAt, finishedAt: session.finishedAt ?? Date(), source: source, - language: session.language, - timezone: session.timezone, - inputDeviceName: session.inputDeviceName + language: session.language ) log("TranscriptionRetryService: Session \(sessionId) uploaded successfully (backendId: \(response.id))") From fee7fcb8f0f9b2c9ad4d6d8990a6cd50d45c5b78 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 08:26:23 +0100 Subject: [PATCH 078/163] Add tester-requested boundary and filter tests Chat sessions: malformed body 422, limit max=200 valid, limit=201 422, app_id filter, starred filter. From-segments: 500 boundary success test. 
Co-Authored-By: Claude Opus 4.6 --- backend/tests/unit/test_desktop_chat.py | 52 ++++++++++++++++++++++++ backend/tests/unit/test_from_segments.py | 20 +++++++++ 2 files changed, 72 insertions(+) diff --git a/backend/tests/unit/test_desktop_chat.py b/backend/tests/unit/test_desktop_chat.py index fc996f30ac..00fe444042 100644 --- a/backend/tests/unit/test_desktop_chat.py +++ b/backend/tests/unit/test_desktop_chat.py @@ -321,6 +321,58 @@ def test_save_message_session_link_failure_still_succeeds(self, client): assert response.status_code == 200 assert 'id' in response.json() + def test_create_session_malformed_body_422(self, client): + with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): + response = client.post( + '/v2/chat-sessions', + content=b'not json', + headers={'Authorization': 'Bearer test', 'Content-Type': 'application/json'}, + ) + assert response.status_code == 422 + + def test_list_sessions_limit_max_valid(self, client): + with ( + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_sessions', return_value=[]), + ): + response = client.get( + '/v2/chat-sessions?limit=200', + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + + def test_list_sessions_limit_over_max_422(self, client): + with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): + response = client.get( + '/v2/chat-sessions?limit=201', + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 + + def test_list_sessions_app_id_filter(self, client): + with ( + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_sessions', return_value=[]) as mock_get, + ): + response = client.get( + '/v2/chat-sessions?app_id=my-app', + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert mock_get.call_args[1]['app_id'] == 'my-app' + + def 
test_list_sessions_starred_filter(self, client): + with ( + patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.chat.chat_db.get_chat_sessions', return_value=[]) as mock_get, + ): + response = client.get( + '/v2/chat-sessions?starred=true', + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert mock_get.call_args[1]['starred'] is True + def test_list_sessions_limit_validation(self, client): with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): response = client.get( diff --git a/backend/tests/unit/test_from_segments.py b/backend/tests/unit/test_from_segments.py index 19b93a2391..068e3134f9 100644 --- a/backend/tests/unit/test_from_segments.py +++ b/backend/tests/unit/test_from_segments.py @@ -227,6 +227,26 @@ def test_empty_segments_list_returns_422(self, client): ) assert response.status_code == 422 + def test_exactly_500_segments_succeeds(self, client): + with ( + patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'), + patch('routers.conversations.process_conversation') as mock_process, + patch('routers.conversations.get_google_maps_location'), + ): + mock_conv = MagicMock() + mock_conv.id = 'conv-500' + mock_conv.status.value = 'completed' + mock_conv.discarded = False + mock_process.return_value = mock_conv + + segments = [{'text': f'seg {i}', 'start': float(i), 'end': float(i + 1)} for i in range(500)] + response = client.post( + '/v1/conversations/from-segments', + json={'transcript_segments': segments}, + headers={'Authorization': 'Bearer test-token'}, + ) + assert response.status_code == 200 + def test_over_500_segments_returns_422(self, client): with patch('routers.conversations.auth.get_current_user_uid', return_value='test-uid-123'): segments = [{'text': f'seg {i}', 'start': float(i), 'end': float(i + 1)} for i in range(501)] From fd847a612abad7e41126ad108df020c7c1b98111 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 
2026 10:18:15 +0100 Subject: [PATCH 079/163] Add Firestore CRUD for focus sessions Collection: users/{uid}/focus_sessions/{session_id} Functions: create, get (with date filter), delete, get_for_stats Co-Authored-By: Claude Opus 4.6 --- backend/database/focus_sessions.py | 75 ++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 backend/database/focus_sessions.py diff --git a/backend/database/focus_sessions.py b/backend/database/focus_sessions.py new file mode 100644 index 0000000000..b4a1e0f54e --- /dev/null +++ b/backend/database/focus_sessions.py @@ -0,0 +1,75 @@ +import logging +import uuid +from datetime import datetime, timezone +from typing import List, Dict, Any, Optional + +from google.cloud import firestore + +from ._client import db + +logger = logging.getLogger(__name__) + +USERS_COLLECTION = 'users' +FOCUS_SESSIONS_SUBCOLLECTION = 'focus_sessions' + + +def _collection_ref(uid: str): + return db.collection(USERS_COLLECTION).document(uid).collection(FOCUS_SESSIONS_SUBCOLLECTION) + + +def create_focus_session(uid: str, data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new focus session document. Returns the created document with id.""" + session_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + doc_data = { + 'status': data['status'], + 'app_or_site': data['app_or_site'], + 'description': data['description'], + 'created_at': now, + } + if data.get('message') is not None: + doc_data['message'] = data['message'] + if data.get('duration_seconds') is not None: + doc_data['duration_seconds'] = data['duration_seconds'] + + _collection_ref(uid).document(session_id).set(doc_data) + + doc_data['id'] = session_id + return doc_data + + +def get_focus_sessions( + uid: str, + limit: int = 100, + offset: int = 0, + date: Optional[str] = None, +) -> List[Dict[str, Any]]: + """Query focus sessions, ordered by created_at DESC. 
Optional date filter (YYYY-MM-DD).""" + query = _collection_ref(uid).order_by('created_at', direction=firestore.Query.DESCENDING) + + if date: + day_start = datetime.strptime(date, '%Y-%m-%d').replace(tzinfo=timezone.utc) + day_end = day_start.replace(hour=23, minute=59, second=59) + query = query.where(filter=firestore.FieldFilter('created_at', '>=', day_start)) + query = query.where(filter=firestore.FieldFilter('created_at', '<=', day_end)) + + query = query.offset(offset).limit(limit) + + results = [] + for doc in query.stream(): + data = doc.to_dict() + data['id'] = doc.id + results.append(data) + return results + + +def delete_focus_session(uid: str, session_id: str) -> bool: + """Delete a focus session document. Returns True on success.""" + _collection_ref(uid).document(session_id).delete() + return True + + +def get_focus_sessions_for_stats(uid: str, date: str) -> List[Dict[str, Any]]: + """Get up to 1000 sessions for a date, for stats computation.""" + return get_focus_sessions(uid, limit=1000, offset=0, date=date) From 278f97035eed60995c34f6c2d98a27a9ef5372d7 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:18:20 +0100 Subject: [PATCH 080/163] Add Firestore CRUD for advice Collection: users/{uid}/advice/{advice_id} Functions: create, get (with category/dismissed filters), update, delete, mark_all_read Co-Authored-By: Claude Opus 4.6 --- backend/database/advice.py | 107 +++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 backend/database/advice.py diff --git a/backend/database/advice.py b/backend/database/advice.py new file mode 100644 index 0000000000..e03c697181 --- /dev/null +++ b/backend/database/advice.py @@ -0,0 +1,107 @@ +import logging +import uuid +from datetime import datetime, timezone +from typing import List, Dict, Any, Optional + +from google.cloud import firestore + +from ._client import db + +logger = logging.getLogger(__name__) + +USERS_COLLECTION = 'users' +ADVICE_SUBCOLLECTION = 
'advice' + + +def _collection_ref(uid: str): + return db.collection(USERS_COLLECTION).document(uid).collection(ADVICE_SUBCOLLECTION) + + +def create_advice(uid: str, data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new advice document. Returns the created document with id.""" + advice_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + doc_data = { + 'content': data['content'], + 'category': data.get('category', 'other'), + 'confidence': data.get('confidence', 0.5), + 'is_read': False, + 'is_dismissed': False, + 'created_at': now, + } + for optional_field in ('reasoning', 'source_app', 'context_summary', 'current_activity'): + if data.get(optional_field) is not None: + doc_data[optional_field] = data[optional_field] + + _collection_ref(uid).document(advice_id).set(doc_data) + + doc_data['id'] = advice_id + return doc_data + + +def get_advice( + uid: str, + limit: int = 100, + offset: int = 0, + category: Optional[str] = None, + include_dismissed: bool = False, +) -> List[Dict[str, Any]]: + """Query advice, ordered by created_at DESC.""" + query = _collection_ref(uid).order_by('created_at', direction=firestore.Query.DESCENDING) + + if not include_dismissed: + query = query.where(filter=firestore.FieldFilter('is_dismissed', '==', False)) + if category: + query = query.where(filter=firestore.FieldFilter('category', '==', category)) + + query = query.offset(offset).limit(limit) + + results = [] + for doc in query.stream(): + data = doc.to_dict() + data['id'] = doc.id + results.append(data) + return results + + +def update_advice(uid: str, advice_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update an advice document (is_read, is_dismissed). 
Returns updated doc.""" + doc_ref = _collection_ref(uid).document(advice_id) + + update_data = {'updated_at': datetime.now(timezone.utc)} + if 'is_read' in data: + update_data['is_read'] = data['is_read'] + if 'is_dismissed' in data: + update_data['is_dismissed'] = data['is_dismissed'] + + doc_ref.update(update_data) + + doc = doc_ref.get() + if doc.exists: + result = doc.to_dict() + result['id'] = doc.id + return result + return None + + +def delete_advice(uid: str, advice_id: str) -> bool: + """Delete an advice document. Returns True on success.""" + _collection_ref(uid).document(advice_id).delete() + return True + + +def mark_all_advice_read(uid: str) -> int: + """Mark all unread, non-dismissed advice as read. Returns count of marked items.""" + query = _collection_ref(uid).where( + filter=firestore.FieldFilter('is_dismissed', '==', False) + ).where( + filter=firestore.FieldFilter('is_read', '==', False) + ).limit(1000) + + count = 0 + now = datetime.now(timezone.utc) + for doc in query.stream(): + doc.reference.update({'is_read': True, 'updated_at': now}) + count += 1 + return count From 338a6f476d702d0a39c8998ef064186fb6e3e3bf Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:18:25 +0100 Subject: [PATCH 081/163] Add focus sessions router with 4 endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit POST /v1/focus-sessions — create session (focused/distracted) GET /v1/focus-sessions — list with date filter, pagination DELETE /v1/focus-sessions/{id} — delete session GET /v1/focus-stats — daily stats with top 5 distractions Co-Authored-By: Claude Opus 4.6 --- backend/routers/focus_sessions.py | 154 ++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 backend/routers/focus_sessions.py diff --git a/backend/routers/focus_sessions.py b/backend/routers/focus_sessions.py new file mode 100644 index 0000000000..75c6702311 --- /dev/null +++ b/backend/routers/focus_sessions.py @@ 
-0,0 +1,154 @@ +import logging +from collections import defaultdict +from datetime import datetime, timezone +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field + +import database.focus_sessions as focus_sessions_db +from utils.other import endpoints as auth + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +class CreateFocusSessionRequest(BaseModel): + status: str = Field(description="'focused' or 'distracted'") + app_or_site: str = Field(description="App or website name") + description: str = Field(description="Brief description of the session") + message: Optional[str] = Field(default=None, description="Optional coaching message") + duration_seconds: Optional[int] = Field(default=None, description="Optional session duration in seconds") + + +class FocusSessionResponse(BaseModel): + id: str + status: str + app_or_site: str + description: str + message: Optional[str] = None + created_at: datetime + duration_seconds: Optional[int] = None + + +class FocusSessionStatusResponse(BaseModel): + status: str + + +class DistractionEntry(BaseModel): + app_or_site: str + total_seconds: int + count: int + + +class FocusStatsResponse(BaseModel): + date: str + focused_minutes: int + distracted_minutes: int + session_count: int + focused_count: int + distracted_count: int + top_distractions: List[DistractionEntry] + + +def _validate_focus_status(status: str): + if status not in ('focused', 'distracted'): + raise HTTPException(status_code=400, detail="status must be 'focused' or 'distracted'") + + +@router.post('/v1/focus-sessions', response_model=FocusSessionResponse, status_code=201, tags=['focus-sessions']) +def create_focus_session( + request: CreateFocusSessionRequest, + uid: str = Depends(auth.get_current_user_uid), +): + _validate_focus_status(request.status) + try: + session = focus_sessions_db.create_focus_session(uid, request.model_dump()) + return session + except 
Exception: + logger.exception('Failed to create focus session for uid=%s', uid) + raise HTTPException(status_code=500, detail="Failed to create focus session") + + +@router.get('/v1/focus-sessions', response_model=List[FocusSessionResponse], tags=['focus-sessions']) +def get_focus_sessions( + limit: int = Query(default=100, ge=1, le=1000), + offset: int = Query(default=0, ge=0), + date: Optional[str] = Query(default=None, description="Filter by date (YYYY-MM-DD)"), + uid: str = Depends(auth.get_current_user_uid), +): + if date: + try: + datetime.strptime(date, '%Y-%m-%d') + except ValueError: + raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD format") + try: + return focus_sessions_db.get_focus_sessions(uid, limit=limit, offset=offset, date=date) + except Exception: + logger.exception('Failed to get focus sessions for uid=%s', uid) + return [] + + +@router.delete('/v1/focus-sessions/{session_id}', response_model=FocusSessionStatusResponse, tags=['focus-sessions']) +def delete_focus_session( + session_id: str, + uid: str = Depends(auth.get_current_user_uid), +): + try: + focus_sessions_db.delete_focus_session(uid, session_id) + return FocusSessionStatusResponse(status="ok") + except Exception: + logger.exception('Failed to delete focus session %s for uid=%s', session_id, uid) + raise HTTPException(status_code=500, detail="Failed to delete focus session") + + +@router.get('/v1/focus-stats', response_model=FocusStatsResponse, tags=['focus-sessions']) +def get_focus_stats( + date: Optional[str] = Query(default=None, description="Date for stats (YYYY-MM-DD), defaults to today"), + uid: str = Depends(auth.get_current_user_uid), +): + if date: + try: + datetime.strptime(date, '%Y-%m-%d') + except ValueError: + raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD format") + else: + date = datetime.now(timezone.utc).strftime('%Y-%m-%d') + + try: + sessions = focus_sessions_db.get_focus_sessions_for_stats(uid, date) + except Exception: + 
logger.exception('Failed to get focus stats for uid=%s', uid) + raise HTTPException(status_code=500, detail="Failed to get focus stats") + + focused_count = 0 + distracted_count = 0 + distraction_map = defaultdict(lambda: {'total_seconds': 0, 'count': 0}) + + for s in sessions: + status = s.get('status', '') + if status == 'focused': + focused_count += 1 + elif status == 'distracted': + distracted_count += 1 + app = s.get('app_or_site', 'Unknown') + duration = s.get('duration_seconds') or 60 + distraction_map[app]['total_seconds'] += duration + distraction_map[app]['count'] += 1 + + top_distractions = sorted( + [DistractionEntry(app_or_site=app, **vals) for app, vals in distraction_map.items()], + key=lambda d: d.total_seconds, + reverse=True, + )[:5] + + return FocusStatsResponse( + date=date, + focused_minutes=focused_count, + distracted_minutes=distracted_count, + session_count=focused_count + distracted_count, + focused_count=focused_count, + distracted_count=distracted_count, + top_distractions=top_distractions, + ) From 53a352a646e8d480b9259b19cb9fde4e4aa78bf6 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:18:30 +0100 Subject: [PATCH 082/163] Add advice router with 5 endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit POST /v1/advice — create with category/confidence validation GET /v1/advice — list with category filter, dismissed toggle, pagination PATCH /v1/advice/{id} — update is_read/is_dismissed DELETE /v1/advice/{id} — delete advice POST /v1/advice/mark-all-read — batch mark unread as read Co-Authored-By: Claude Opus 4.6 --- backend/routers/advice.py | 139 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 backend/routers/advice.py diff --git a/backend/routers/advice.py b/backend/routers/advice.py new file mode 100644 index 0000000000..d1c36c5e12 --- /dev/null +++ b/backend/routers/advice.py @@ -0,0 +1,139 @@ +import logging +from typing import 
List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field + +import database.advice as advice_db +from utils.other import endpoints as auth + +logger = logging.getLogger(__name__) + +router = APIRouter() + +VALID_CATEGORIES = ('productivity', 'health', 'communication', 'learning', 'other') + + +class CreateAdviceRequest(BaseModel): + content: str = Field(description="Advice content text") + category: Optional[str] = Field(default=None, description="Category: productivity, health, communication, learning, other") + reasoning: Optional[str] = Field(default=None, description="Reasoning behind the advice") + source_app: Optional[str] = Field(default=None, description="App where context was observed") + confidence: Optional[float] = Field(default=None, description="Confidence score 0.0-1.0") + context_summary: Optional[str] = Field(default=None, description="Context summary") + current_activity: Optional[str] = Field(default=None, description="User's current activity") + + +class UpdateAdviceRequest(BaseModel): + is_read: Optional[bool] = None + is_dismissed: Optional[bool] = None + + +class AdviceResponse(BaseModel): + id: str + content: str + category: str = 'other' + reasoning: Optional[str] = None + source_app: Optional[str] = None + confidence: float = 0.5 + context_summary: Optional[str] = None + current_activity: Optional[str] = None + created_at: object = None + updated_at: object = None + is_read: bool = False + is_dismissed: bool = False + + +class AdviceStatusResponse(BaseModel): + status: str + + +def _validate_category(category: Optional[str]): + if category and category not in VALID_CATEGORIES: + raise HTTPException( + status_code=400, + detail=f"category must be one of: {', '.join(VALID_CATEGORIES)}" + ) + + +def _validate_confidence(confidence: Optional[float]): + if confidence is not None and not (0.0 <= confidence <= 1.0): + raise HTTPException(status_code=400, detail="confidence must be 
between 0.0 and 1.0") + + +@router.post('/v1/advice', status_code=201, tags=['advice']) +def create_advice( + request: CreateAdviceRequest, + uid: str = Depends(auth.get_current_user_uid), +): + _validate_category(request.category) + _validate_confidence(request.confidence) + try: + return advice_db.create_advice(uid, request.model_dump(exclude_none=True)) + except Exception: + logger.exception('Failed to create advice for uid=%s', uid) + raise HTTPException(status_code=500, detail="Failed to create advice") + + +@router.get('/v1/advice', tags=['advice']) +def get_advice( + limit: int = Query(default=100, ge=1, le=1000), + offset: int = Query(default=0, ge=0), + category: Optional[str] = Query(default=None), + include_dismissed: bool = Query(default=False), + uid: str = Depends(auth.get_current_user_uid), +): + _validate_category(category) + try: + return advice_db.get_advice( + uid, limit=limit, offset=offset, category=category, include_dismissed=include_dismissed, + ) + except Exception: + logger.exception('Failed to get advice for uid=%s', uid) + return [] + + +@router.patch('/v1/advice/{advice_id}', tags=['advice']) +def update_advice( + advice_id: str, + request: UpdateAdviceRequest, + uid: str = Depends(auth.get_current_user_uid), +): + update_data = request.model_dump(exclude_none=True) + if not update_data: + raise HTTPException(status_code=400, detail="No fields to update") + try: + result = advice_db.update_advice(uid, advice_id, update_data) + if result is None: + raise HTTPException(status_code=404, detail="Advice not found") + return result + except HTTPException: + raise + except Exception: + logger.exception('Failed to update advice %s for uid=%s', advice_id, uid) + raise HTTPException(status_code=500, detail="Failed to update advice") + + +@router.delete('/v1/advice/{advice_id}', response_model=AdviceStatusResponse, tags=['advice']) +def delete_advice( + advice_id: str, + uid: str = Depends(auth.get_current_user_uid), +): + try: + 
advice_db.delete_advice(uid, advice_id) + return AdviceStatusResponse(status="ok") + except Exception: + logger.exception('Failed to delete advice %s for uid=%s', advice_id, uid) + raise HTTPException(status_code=500, detail="Failed to delete advice") + + +@router.post('/v1/advice/mark-all-read', response_model=AdviceStatusResponse, tags=['advice']) +def mark_all_advice_read( + uid: str = Depends(auth.get_current_user_uid), +): + try: + count = advice_db.mark_all_advice_read(uid) + return AdviceStatusResponse(status=f"marked {count} as read") + except Exception: + logger.exception('Failed to mark all advice read for uid=%s', uid) + raise HTTPException(status_code=500, detail="Failed to mark advice as read") From f66136b85a91fa961eae2b64050d3390f7e836d9 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:18:34 +0100 Subject: [PATCH 083/163] Register focus_sessions and advice routers in main.py Co-Authored-By: Claude Opus 4.6 --- backend/main.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backend/main.py b/backend/main.py index 3f87e6af53..32e6790280 100644 --- a/backend/main.py +++ b/backend/main.py @@ -46,6 +46,8 @@ phone_calls, agent_tools, screen_activity, + focus_sessions, + advice, ) from utils.other.timeout import TimeoutMiddleware @@ -106,6 +108,8 @@ app.include_router(phone_calls.router) app.include_router(agent_tools.router) app.include_router(screen_activity.router) +app.include_router(focus_sessions.router) +app.include_router(advice.router) methods_timeout = { From fbb42490a30af960c6283cbd527a3e0f54749f4c Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:18:39 +0100 Subject: [PATCH 084/163] Add 45 unit tests for focus sessions and advice endpoints Focus sessions (21): create, invalid status, optional fields, auth, errors, date filter, pagination, delete, stats (empty/mixed/defaults/top5/duration fallback) Advice (24): create with validation, categories, confidence bounds, auth, errors, get with filters/pagination, 
update read/dismissed, delete, mark-all-read Co-Authored-By: Claude Opus 4.6 --- backend/tests/unit/test_advice.py | 196 ++++++++++++++++++++++ backend/tests/unit/test_focus_sessions.py | 186 ++++++++++++++++++++ 2 files changed, 382 insertions(+) create mode 100644 backend/tests/unit/test_advice.py create mode 100644 backend/tests/unit/test_focus_sessions.py diff --git a/backend/tests/unit/test_advice.py b/backend/tests/unit/test_advice.py new file mode 100644 index 0000000000..9f16d4adab --- /dev/null +++ b/backend/tests/unit/test_advice.py @@ -0,0 +1,196 @@ +from datetime import datetime, timezone +from unittest.mock import patch, MagicMock + +import pytest +from fastapi.testclient import TestClient + + +@pytest.fixture +def client(): + with patch('database.screen_activity.db'), \ + patch('database.focus_sessions.db'), \ + patch('database.advice.db'), \ + patch('database.vector_db.Pinecone'), \ + patch('database.vector_db.pc'), \ + patch('database.vector_db.index'), \ + patch('utils.llm.clients.embeddings'): + from main import app + with TestClient(app) as c: + yield c + + +AUTH = {"Authorization": "Bearer 123testuser"} + + +class TestCreateAdvice: + def test_create_minimal(self, client): + data = {"content": "Take a break"} + with patch('routers.advice.advice_db.create_advice') as mock_create: + mock_create.return_value = { + "id": "adv-1", "content": "Take a break", "category": "other", + "confidence": 0.5, "is_read": False, "is_dismissed": False, + "created_at": datetime.now(timezone.utc), + } + resp = client.post("/v1/advice", json=data, headers=AUTH) + assert resp.status_code == 201 + assert resp.json()["content"] == "Take a break" + assert resp.json()["category"] == "other" + + def test_create_with_all_fields(self, client): + data = { + "content": "Drink water", "category": "health", "reasoning": "Dehydrated", + "source_app": "Chrome", "confidence": 0.9, "context_summary": "Long session", + "current_activity": "Browsing", + } + with 
patch('routers.advice.advice_db.create_advice') as mock_create: + mock_create.return_value = {"id": "adv-2", **data, "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)} + resp = client.post("/v1/advice", json=data, headers=AUTH) + assert resp.status_code == 201 + assert resp.json()["category"] == "health" + assert resp.json()["confidence"] == 0.9 + + def test_create_invalid_category_returns_400(self, client): + data = {"content": "Test", "category": "invalid_cat"} + resp = client.post("/v1/advice", json=data, headers=AUTH) + assert resp.status_code == 400 + assert "category" in resp.json()["detail"] + + def test_create_confidence_below_zero_returns_400(self, client): + data = {"content": "Test", "confidence": -0.1} + resp = client.post("/v1/advice", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_create_confidence_above_one_returns_400(self, client): + data = {"content": "Test", "confidence": 1.1} + resp = client.post("/v1/advice", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_create_confidence_boundary_zero(self, client): + data = {"content": "Test", "confidence": 0.0} + with patch('routers.advice.advice_db.create_advice') as mock_create: + mock_create.return_value = {"id": "adv-3", "content": "Test", "confidence": 0.0, "category": "other", "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)} + resp = client.post("/v1/advice", json=data, headers=AUTH) + assert resp.status_code == 201 + + def test_create_confidence_boundary_one(self, client): + data = {"content": "Test", "confidence": 1.0} + with patch('routers.advice.advice_db.create_advice') as mock_create: + mock_create.return_value = {"id": "adv-4", "content": "Test", "confidence": 1.0, "category": "other", "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)} + resp = client.post("/v1/advice", json=data, headers=AUTH) + assert resp.status_code == 201 + + def 
test_create_no_auth_returns_401(self, client): + resp = client.post("/v1/advice", json={"content": "Test"}) + assert resp.status_code == 401 + + def test_create_firestore_error_returns_500(self, client): + with patch('routers.advice.advice_db.create_advice', side_effect=Exception("DB down")): + resp = client.post("/v1/advice", json={"content": "Test"}, headers=AUTH) + assert resp.status_code == 500 + + def test_create_each_valid_category(self, client): + for cat in ('productivity', 'health', 'communication', 'learning', 'other'): + data = {"content": "Test", "category": cat} + with patch('routers.advice.advice_db.create_advice') as mock_create: + mock_create.return_value = {"id": "x", "content": "Test", "category": cat, "confidence": 0.5, "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)} + resp = client.post("/v1/advice", json=data, headers=AUTH) + assert resp.status_code == 201, f"Failed for category {cat}" + + +class TestGetAdvice: + def test_get_empty(self, client): + with patch('routers.advice.advice_db.get_advice', return_value=[]): + resp = client.get("/v1/advice", headers=AUTH) + assert resp.status_code == 200 + assert resp.json() == [] + + def test_get_with_category_filter(self, client): + with patch('routers.advice.advice_db.get_advice', return_value=[]) as mock_get: + resp = client.get("/v1/advice?category=health", headers=AUTH) + assert resp.status_code == 200 + assert mock_get.call_args[1]['category'] == 'health' + + def test_get_invalid_category_filter_returns_400(self, client): + resp = client.get("/v1/advice?category=bad_cat", headers=AUTH) + assert resp.status_code == 400 + + def test_get_include_dismissed(self, client): + with patch('routers.advice.advice_db.get_advice', return_value=[]) as mock_get: + resp = client.get("/v1/advice?include_dismissed=true", headers=AUTH) + assert resp.status_code == 200 + assert mock_get.call_args[1]['include_dismissed'] is True + + def test_get_with_pagination(self, client): + with 
patch('routers.advice.advice_db.get_advice', return_value=[]) as mock_get: + resp = client.get("/v1/advice?limit=50&offset=20", headers=AUTH) + assert resp.status_code == 200 + assert mock_get.call_args[1]['limit'] == 50 + assert mock_get.call_args[1]['offset'] == 20 + + def test_get_firestore_error_returns_empty(self, client): + with patch('routers.advice.advice_db.get_advice', side_effect=Exception("err")): + resp = client.get("/v1/advice", headers=AUTH) + assert resp.status_code == 200 + assert resp.json() == [] + + +class TestUpdateAdvice: + def test_mark_as_read(self, client): + with patch('routers.advice.advice_db.update_advice') as mock_update: + mock_update.return_value = {"id": "adv-1", "is_read": True, "is_dismissed": False, "content": "x", "category": "other", "confidence": 0.5, "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc)} + resp = client.patch("/v1/advice/adv-1", json={"is_read": True}, headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["is_read"] is True + + def test_mark_as_dismissed(self, client): + with patch('routers.advice.advice_db.update_advice') as mock_update: + mock_update.return_value = {"id": "adv-1", "is_read": False, "is_dismissed": True, "content": "x", "category": "other", "confidence": 0.5, "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc)} + resp = client.patch("/v1/advice/adv-1", json={"is_dismissed": True}, headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["is_dismissed"] is True + + def test_empty_update_returns_400(self, client): + resp = client.patch("/v1/advice/adv-1", json={}, headers=AUTH) + assert resp.status_code == 400 + + def test_update_not_found_returns_404(self, client): + with patch('routers.advice.advice_db.update_advice', return_value=None): + resp = client.patch("/v1/advice/adv-1", json={"is_read": True}, headers=AUTH) + assert resp.status_code == 404 + + def test_update_firestore_error_returns_500(self, 
class TestMarkAllRead:
    """Tests for POST /v1/advice/mark-all-read."""

    # Single patch target shared by every test in this class.
    _TARGET = 'routers.advice.advice_db.mark_all_advice_read'

    def _post(self, client):
        # All cases hit the same route with the same auth header.
        return client.post("/v1/advice/mark-all-read", headers=AUTH)

    def test_mark_all_read_returns_count(self, client):
        """The number of items marked is echoed in the status string."""
        with patch(self._TARGET, return_value=5):
            response = self._post(client)
        assert response.status_code == 200
        assert "5" in response.json()["status"]

    def test_mark_all_read_zero(self, client):
        """Zero unread items is still a success, reporting a count of 0."""
        with patch(self._TARGET, return_value=0):
            response = self._post(client)
        assert response.status_code == 200
        assert "0" in response.json()["status"]

    def test_mark_all_read_firestore_error(self, client):
        """A database failure surfaces as HTTP 500."""
        with patch(self._TARGET, side_effect=Exception("err")):
            response = self._post(client)
        assert response.status_code == 500
patch('database.focus_sessions.db'), \ + patch('database.advice.db'), \ + patch('database.vector_db.Pinecone'), \ + patch('database.vector_db.pc'), \ + patch('database.vector_db.index'), \ + patch('utils.llm.clients.embeddings'): + from main import app + with TestClient(app) as c: + yield c + + +AUTH = {"Authorization": "Bearer 123testuser"} + + +class TestCreateFocusSession: + def test_create_focused_session(self, client): + data = {"status": "focused", "app_or_site": "VSCode", "description": "Coding"} + with patch('routers.focus_sessions.focus_sessions_db.create_focus_session') as mock_create: + mock_create.return_value = { + "id": "abc-123", "status": "focused", "app_or_site": "VSCode", + "description": "Coding", "created_at": datetime.now(timezone.utc), + } + resp = client.post("/v1/focus-sessions", json=data, headers=AUTH) + assert resp.status_code == 201 + assert resp.json()["status"] == "focused" + + def test_create_distracted_session(self, client): + data = {"status": "distracted", "app_or_site": "Twitter", "description": "Scrolling"} + with patch('routers.focus_sessions.focus_sessions_db.create_focus_session') as mock_create: + mock_create.return_value = { + "id": "abc-456", "status": "distracted", "app_or_site": "Twitter", + "description": "Scrolling", "created_at": datetime.now(timezone.utc), + } + resp = client.post("/v1/focus-sessions", json=data, headers=AUTH) + assert resp.status_code == 201 + assert resp.json()["status"] == "distracted" + + def test_create_invalid_status_returns_400(self, client): + data = {"status": "invalid", "app_or_site": "X", "description": "Y"} + resp = client.post("/v1/focus-sessions", json=data, headers=AUTH) + assert resp.status_code == 400 + assert "focused" in resp.json()["detail"] + + def test_create_with_optional_fields(self, client): + data = { + "status": "focused", "app_or_site": "VSCode", "description": "Coding", + "message": "Keep going!", "duration_seconds": 300, + } + with 
patch('routers.focus_sessions.focus_sessions_db.create_focus_session') as mock_create: + mock_create.return_value = { + "id": "abc-789", "message": "Keep going!", "duration_seconds": 300, + **{k: v for k, v in data.items()}, "created_at": datetime.now(timezone.utc), + } + resp = client.post("/v1/focus-sessions", json=data, headers=AUTH) + assert resp.status_code == 201 + assert resp.json()["message"] == "Keep going!" + assert resp.json()["duration_seconds"] == 300 + + def test_create_no_auth_returns_401(self, client): + data = {"status": "focused", "app_or_site": "X", "description": "Y"} + resp = client.post("/v1/focus-sessions", json=data) + assert resp.status_code == 401 + + def test_create_firestore_error_returns_500(self, client): + data = {"status": "focused", "app_or_site": "X", "description": "Y"} + with patch('routers.focus_sessions.focus_sessions_db.create_focus_session', side_effect=Exception("DB down")): + resp = client.post("/v1/focus-sessions", json=data, headers=AUTH) + assert resp.status_code == 500 + + +class TestGetFocusSessions: + def test_get_empty_returns_list(self, client): + with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', return_value=[]): + resp = client.get("/v1/focus-sessions", headers=AUTH) + assert resp.status_code == 200 + assert resp.json() == [] + + def test_get_with_date_filter(self, client): + with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', return_value=[]) as mock_get: + resp = client.get("/v1/focus-sessions?date=2026-03-05", headers=AUTH) + assert resp.status_code == 200 + mock_get.assert_called_once() + assert mock_get.call_args[1]['date'] == '2026-03-05' + + def test_get_invalid_date_returns_400(self, client): + resp = client.get("/v1/focus-sessions?date=not-a-date", headers=AUTH) + assert resp.status_code == 400 + + def test_get_with_limit_and_offset(self, client): + with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', return_value=[]) as mock_get: + resp = 
class TestDeleteFocusSession:
    """Tests for DELETE /v1/focus-sessions/{id}."""

    _TARGET = 'routers.focus_sessions.focus_sessions_db.delete_focus_session'

    def test_delete_returns_ok(self, client):
        """Successful deletion returns a 200 with status 'ok'."""
        with patch(self._TARGET, return_value=True):
            response = client.delete("/v1/focus-sessions/abc-123", headers=AUTH)
        assert response.status_code == 200
        assert response.json()["status"] == "ok"

    def test_delete_firestore_error_returns_500(self, client):
        """A database failure surfaces as HTTP 500."""
        with patch(self._TARGET, side_effect=Exception("err")):
            response = client.delete("/v1/focus-sessions/abc-123", headers=AUTH)
        assert response.status_code == 500
patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=sessions): + resp = client.get("/v1/focus-stats?date=2026-03-05", headers=AUTH) + assert resp.status_code == 200 + data = resp.json() + assert data["focused_count"] == 1 + assert data["distracted_count"] == 3 + assert data["session_count"] == 4 + assert len(data["top_distractions"]) == 2 + assert data["top_distractions"][0]["app_or_site"] == "Twitter" + assert data["top_distractions"][0]["total_seconds"] == 150 + + def test_stats_defaults_to_today(self, client): + with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=[]) as mock_get: + resp = client.get("/v1/focus-stats", headers=AUTH) + assert resp.status_code == 200 + called_date = mock_get.call_args[0][1] + today = datetime.now(timezone.utc).strftime('%Y-%m-%d') + assert called_date == today + + def test_stats_invalid_date_returns_400(self, client): + resp = client.get("/v1/focus-stats?date=bad", headers=AUTH) + assert resp.status_code == 400 + + def test_stats_distraction_without_duration_defaults_60(self, client): + sessions = [ + {"status": "distracted", "app_or_site": "YouTube"}, + ] + with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=sessions): + resp = client.get("/v1/focus-stats?date=2026-03-05", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["top_distractions"][0]["total_seconds"] == 60 + + def test_stats_top5_limit(self, client): + sessions = [ + {"status": "distracted", "app_or_site": f"App{i}", "duration_seconds": i * 10} + for i in range(8) + ] + with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=sessions): + resp = client.get("/v1/focus-stats?date=2026-03-05", headers=AUTH) + assert resp.status_code == 200 + assert len(resp.json()["top_distractions"]) == 5 From b20aa9a953db9983dccb0a9102861e64591ecf9a Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 
Mar 2026 10:22:48 +0100 Subject: [PATCH 085/163] Fix advice update not-found and focus date filter cutoff - advice update: catch 404 from Firestore update() on missing doc, return None instead of letting exception bubble to 500 - focus date filter: use < next_day_start instead of <= 23:59:59 to avoid excluding docs at 23:59:59.xxx Co-Authored-By: Claude Opus 4.6 --- backend/database/advice.py | 7 ++++++- backend/database/focus_sessions.py | 6 +++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/backend/database/advice.py b/backend/database/advice.py index e03c697181..5c28e8fef2 100644 --- a/backend/database/advice.py +++ b/backend/database/advice.py @@ -75,7 +75,12 @@ def update_advice(uid: str, advice_id: str, data: Dict[str, Any]) -> Optional[Di if 'is_dismissed' in data: update_data['is_dismissed'] = data['is_dismissed'] - doc_ref.update(update_data) + try: + doc_ref.update(update_data) + except Exception as e: + if hasattr(e, 'code') and e.code == 404: + return None + raise doc = doc_ref.get() if doc.exists: diff --git a/backend/database/focus_sessions.py b/backend/database/focus_sessions.py index b4a1e0f54e..c21af103ab 100644 --- a/backend/database/focus_sessions.py +++ b/backend/database/focus_sessions.py @@ -1,6 +1,6 @@ import logging import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import List, Dict, Any, Optional from google.cloud import firestore @@ -50,9 +50,9 @@ def get_focus_sessions( if date: day_start = datetime.strptime(date, '%Y-%m-%d').replace(tzinfo=timezone.utc) - day_end = day_start.replace(hour=23, minute=59, second=59) + next_day_start = day_start + timedelta(days=1) query = query.where(filter=firestore.FieldFilter('created_at', '>=', day_start)) - query = query.where(filter=firestore.FieldFilter('created_at', '<=', day_end)) + query = query.where(filter=firestore.FieldFilter('created_at', '<', next_day_start)) query = query.offset(offset).limit(limit) From 
da4ed820ac181d2a156cfec6d5f186df2b0da78f Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:28:12 +0100 Subject: [PATCH 086/163] Fix Rust parity issues from CP7 review - Create endpoints return 200 (not 201) matching Rust default - Invalid date on GET skips filter instead of 400 (Rust ignores bad dates) - Invalid category on GET skips filter instead of 400 (Rust accepts any) - PATCH empty body updates updated_at only (Rust behavior), not 400 - Missing advice on PATCH returns 500 (Rust), not 404 - duration_seconds=0 preserved (not replaced with 60 default) - mark_all_read ignores per-item failures (match Rust let _ = pattern) Co-Authored-By: Claude Opus 4.6 --- backend/database/advice.py | 7 +++-- backend/routers/advice.py | 9 +++---- backend/routers/focus_sessions.py | 11 ++++---- backend/tests/unit/test_advice.py | 30 ++++++++++++--------- backend/tests/unit/test_focus_sessions.py | 32 ++++++++++++++++------- 5 files changed, 55 insertions(+), 34 deletions(-) diff --git a/backend/database/advice.py b/backend/database/advice.py index 5c28e8fef2..5043c9cb72 100644 --- a/backend/database/advice.py +++ b/backend/database/advice.py @@ -107,6 +107,9 @@ def mark_all_advice_read(uid: str) -> int: count = 0 now = datetime.now(timezone.utc) for doc in query.stream(): - doc.reference.update({'is_read': True, 'updated_at': now}) - count += 1 + try: + doc.reference.update({'is_read': True, 'updated_at': now}) + count += 1 + except Exception: + logger.warning('Failed to mark advice %s as read for uid=%s', doc.id, uid) return count diff --git a/backend/routers/advice.py b/backend/routers/advice.py index d1c36c5e12..72d671c84d 100644 --- a/backend/routers/advice.py +++ b/backend/routers/advice.py @@ -61,7 +61,7 @@ def _validate_confidence(confidence: Optional[float]): raise HTTPException(status_code=400, detail="confidence must be between 0.0 and 1.0") -@router.post('/v1/advice', status_code=201, tags=['advice']) +@router.post('/v1/advice', tags=['advice']) def 
create_advice( request: CreateAdviceRequest, uid: str = Depends(auth.get_current_user_uid), @@ -83,7 +83,8 @@ def get_advice( include_dismissed: bool = Query(default=False), uid: str = Depends(auth.get_current_user_uid), ): - _validate_category(category) + if category and category not in VALID_CATEGORIES: + category = None # Skip unknown category filter (match Rust behavior) try: return advice_db.get_advice( uid, limit=limit, offset=offset, category=category, include_dismissed=include_dismissed, @@ -100,12 +101,10 @@ def update_advice( uid: str = Depends(auth.get_current_user_uid), ): update_data = request.model_dump(exclude_none=True) - if not update_data: - raise HTTPException(status_code=400, detail="No fields to update") try: result = advice_db.update_advice(uid, advice_id, update_data) if result is None: - raise HTTPException(status_code=404, detail="Advice not found") + raise HTTPException(status_code=500, detail="Failed to update advice") return result except HTTPException: raise diff --git a/backend/routers/focus_sessions.py b/backend/routers/focus_sessions.py index 75c6702311..db19622e86 100644 --- a/backend/routers/focus_sessions.py +++ b/backend/routers/focus_sessions.py @@ -57,7 +57,7 @@ def _validate_focus_status(status: str): raise HTTPException(status_code=400, detail="status must be 'focused' or 'distracted'") -@router.post('/v1/focus-sessions', response_model=FocusSessionResponse, status_code=201, tags=['focus-sessions']) +@router.post('/v1/focus-sessions', response_model=FocusSessionResponse, tags=['focus-sessions']) def create_focus_session( request: CreateFocusSessionRequest, uid: str = Depends(auth.get_current_user_uid), @@ -82,7 +82,7 @@ def get_focus_sessions( try: datetime.strptime(date, '%Y-%m-%d') except ValueError: - raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD format") + date = None # Skip invalid date filter (match Rust behavior) try: return focus_sessions_db.get_focus_sessions(uid, limit=limit, offset=offset, 
date=date) except Exception: @@ -112,8 +112,8 @@ def get_focus_stats( try: datetime.strptime(date, '%Y-%m-%d') except ValueError: - raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD format") - else: + date = None # Skip invalid date filter (match Rust behavior) + if not date: date = datetime.now(timezone.utc).strftime('%Y-%m-%d') try: @@ -133,7 +133,8 @@ def get_focus_stats( elif status == 'distracted': distracted_count += 1 app = s.get('app_or_site', 'Unknown') - duration = s.get('duration_seconds') or 60 + raw_duration = s.get('duration_seconds') + duration = raw_duration if raw_duration is not None else 60 distraction_map[app]['total_seconds'] += duration distraction_map[app]['count'] += 1 diff --git a/backend/tests/unit/test_advice.py b/backend/tests/unit/test_advice.py index 9f16d4adab..6b16f69933 100644 --- a/backend/tests/unit/test_advice.py +++ b/backend/tests/unit/test_advice.py @@ -32,7 +32,7 @@ def test_create_minimal(self, client): "created_at": datetime.now(timezone.utc), } resp = client.post("/v1/advice", json=data, headers=AUTH) - assert resp.status_code == 201 + assert resp.status_code == 200 assert resp.json()["content"] == "Take a break" assert resp.json()["category"] == "other" @@ -45,7 +45,7 @@ def test_create_with_all_fields(self, client): with patch('routers.advice.advice_db.create_advice') as mock_create: mock_create.return_value = {"id": "adv-2", **data, "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)} resp = client.post("/v1/advice", json=data, headers=AUTH) - assert resp.status_code == 201 + assert resp.status_code == 200 assert resp.json()["category"] == "health" assert resp.json()["confidence"] == 0.9 @@ -70,14 +70,14 @@ def test_create_confidence_boundary_zero(self, client): with patch('routers.advice.advice_db.create_advice') as mock_create: mock_create.return_value = {"id": "adv-3", "content": "Test", "confidence": 0.0, "category": "other", "is_read": False, "is_dismissed": False, 
"created_at": datetime.now(timezone.utc)} resp = client.post("/v1/advice", json=data, headers=AUTH) - assert resp.status_code == 201 + assert resp.status_code == 200 def test_create_confidence_boundary_one(self, client): data = {"content": "Test", "confidence": 1.0} with patch('routers.advice.advice_db.create_advice') as mock_create: mock_create.return_value = {"id": "adv-4", "content": "Test", "confidence": 1.0, "category": "other", "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)} resp = client.post("/v1/advice", json=data, headers=AUTH) - assert resp.status_code == 201 + assert resp.status_code == 200 def test_create_no_auth_returns_401(self, client): resp = client.post("/v1/advice", json={"content": "Test"}) @@ -94,7 +94,7 @@ def test_create_each_valid_category(self, client): with patch('routers.advice.advice_db.create_advice') as mock_create: mock_create.return_value = {"id": "x", "content": "Test", "category": cat, "confidence": 0.5, "is_read": False, "is_dismissed": False, "created_at": datetime.now(timezone.utc)} resp = client.post("/v1/advice", json=data, headers=AUTH) - assert resp.status_code == 201, f"Failed for category {cat}" + assert resp.status_code == 200, f"Failed for category {cat}" class TestGetAdvice: @@ -110,9 +110,11 @@ def test_get_with_category_filter(self, client): assert resp.status_code == 200 assert mock_get.call_args[1]['category'] == 'health' - def test_get_invalid_category_filter_returns_400(self, client): - resp = client.get("/v1/advice?category=bad_cat", headers=AUTH) - assert resp.status_code == 400 + def test_get_invalid_category_skips_filter(self, client): + with patch('routers.advice.advice_db.get_advice', return_value=[]) as mock_get: + resp = client.get("/v1/advice?category=bad_cat", headers=AUTH) + assert resp.status_code == 200 + assert mock_get.call_args[1]['category'] is None def test_get_include_dismissed(self, client): with patch('routers.advice.advice_db.get_advice', return_value=[]) as 
mock_get: @@ -149,14 +151,16 @@ def test_mark_as_dismissed(self, client): assert resp.status_code == 200 assert resp.json()["is_dismissed"] is True - def test_empty_update_returns_400(self, client): - resp = client.patch("/v1/advice/adv-1", json={}, headers=AUTH) - assert resp.status_code == 400 + def test_empty_update_still_updates_timestamp(self, client): + with patch('routers.advice.advice_db.update_advice') as mock_update: + mock_update.return_value = {"id": "adv-1", "is_read": False, "is_dismissed": False, "content": "x", "category": "other", "confidence": 0.5, "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc)} + resp = client.patch("/v1/advice/adv-1", json={}, headers=AUTH) + assert resp.status_code == 200 - def test_update_not_found_returns_404(self, client): + def test_update_not_found_returns_500(self, client): with patch('routers.advice.advice_db.update_advice', return_value=None): resp = client.patch("/v1/advice/adv-1", json={"is_read": True}, headers=AUTH) - assert resp.status_code == 404 + assert resp.status_code == 500 def test_update_firestore_error_returns_500(self, client): with patch('routers.advice.advice_db.update_advice', side_effect=Exception("err")): diff --git a/backend/tests/unit/test_focus_sessions.py b/backend/tests/unit/test_focus_sessions.py index 10985336b3..fc898d4a39 100644 --- a/backend/tests/unit/test_focus_sessions.py +++ b/backend/tests/unit/test_focus_sessions.py @@ -31,7 +31,7 @@ def test_create_focused_session(self, client): "description": "Coding", "created_at": datetime.now(timezone.utc), } resp = client.post("/v1/focus-sessions", json=data, headers=AUTH) - assert resp.status_code == 201 + assert resp.status_code == 200 assert resp.json()["status"] == "focused" def test_create_distracted_session(self, client): @@ -42,7 +42,7 @@ def test_create_distracted_session(self, client): "description": "Scrolling", "created_at": datetime.now(timezone.utc), } resp = client.post("/v1/focus-sessions", 
json=data, headers=AUTH) - assert resp.status_code == 201 + assert resp.status_code == 200 assert resp.json()["status"] == "distracted" def test_create_invalid_status_returns_400(self, client): @@ -62,7 +62,7 @@ def test_create_with_optional_fields(self, client): **{k: v for k, v in data.items()}, "created_at": datetime.now(timezone.utc), } resp = client.post("/v1/focus-sessions", json=data, headers=AUTH) - assert resp.status_code == 201 + assert resp.status_code == 200 assert resp.json()["message"] == "Keep going!" assert resp.json()["duration_seconds"] == 300 @@ -92,9 +92,11 @@ def test_get_with_date_filter(self, client): mock_get.assert_called_once() assert mock_get.call_args[1]['date'] == '2026-03-05' - def test_get_invalid_date_returns_400(self, client): - resp = client.get("/v1/focus-sessions?date=not-a-date", headers=AUTH) - assert resp.status_code == 400 + def test_get_invalid_date_skips_filter(self, client): + with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', return_value=[]) as mock_get: + resp = client.get("/v1/focus-sessions?date=not-a-date", headers=AUTH) + assert resp.status_code == 200 + assert mock_get.call_args[1]['date'] is None def test_get_with_limit_and_offset(self, client): with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions', return_value=[]) as mock_get: @@ -162,9 +164,12 @@ def test_stats_defaults_to_today(self, client): today = datetime.now(timezone.utc).strftime('%Y-%m-%d') assert called_date == today - def test_stats_invalid_date_returns_400(self, client): - resp = client.get("/v1/focus-stats?date=bad", headers=AUTH) - assert resp.status_code == 400 + def test_stats_invalid_date_defaults_to_today(self, client): + with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=[]) as mock_get: + resp = client.get("/v1/focus-stats?date=bad", headers=AUTH) + assert resp.status_code == 200 + today = datetime.now(timezone.utc).strftime('%Y-%m-%d') + assert 
mock_get.call_args[0][1] == today def test_stats_distraction_without_duration_defaults_60(self, client): sessions = [ @@ -175,6 +180,15 @@ def test_stats_distraction_without_duration_defaults_60(self, client): assert resp.status_code == 200 assert resp.json()["top_distractions"][0]["total_seconds"] == 60 + def test_stats_distraction_with_zero_duration_keeps_zero(self, client): + sessions = [ + {"status": "distracted", "app_or_site": "Slack", "duration_seconds": 0}, + ] + with patch('routers.focus_sessions.focus_sessions_db.get_focus_sessions_for_stats', return_value=sessions): + resp = client.get("/v1/focus-stats?date=2026-03-05", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["top_distractions"][0]["total_seconds"] == 0 + def test_stats_top5_limit(self, client): sessions = [ {"status": "distracted", "app_or_site": f"App{i}", "duration_seconds": i * 10} From 3d9168503e0f931484c7a067532432864a76c4e6 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:19:24 +0100 Subject: [PATCH 087/163] Add staged tasks database layer for desktop migration --- backend/database/staged_tasks.py | 220 +++++++++++++++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100644 backend/database/staged_tasks.py diff --git a/backend/database/staged_tasks.py b/backend/database/staged_tasks.py new file mode 100644 index 0000000000..3dfb3a66be --- /dev/null +++ b/backend/database/staged_tasks.py @@ -0,0 +1,220 @@ +"""Database operations for desktop staged tasks (users/{uid}/staged_tasks).""" + +from datetime import datetime, timezone +from typing import Optional, List, Tuple + +from google.cloud import firestore + +from ._client import db +import logging + +logger = logging.getLogger(__name__) + +COLLECTION = 'staged_tasks' + + +def _prepare_for_read(data: dict) -> dict: + """Convert Firestore timestamps to Python datetimes.""" + for field in ['created_at', 'updated_at', 'due_at', 'completed_at', 'deleted_at']: + if field in data and data[field] and 
def get_staged_tasks(uid: str, limit: int = 100, offset: int = 0) -> Tuple[List[dict], bool]:
    """List staged tasks ordered by relevance_score ASC. Returns (items, has_more)."""
    collection = db.collection('users').document(uid).collection(COLLECTION)
    query = collection.order_by('relevance_score', direction=firestore.Query.ASCENDING)
    if offset > 0:
        query = query.offset(offset)
    # Over-fetch by one document so another page can be detected without a
    # separate count query.
    query = query.limit(limit + 1)

    def _to_item(snapshot):
        # Attach the document id and normalize Firestore timestamps.
        record = snapshot.to_dict()
        record['id'] = snapshot.id
        return _prepare_for_read(record)

    items = [_to_item(snapshot) for snapshot in query.stream()]
    has_more = len(items) > limit
    return (items[:limit] if has_more else items), has_more
+ + Args: + scores: List of {"id": str, "relevance_score": int} + """ + if not scores: + return + batch = db.batch() + ref = db.collection('users').document(uid).collection(COLLECTION) + now = datetime.now(timezone.utc) + for item in scores: + doc_ref = ref.document(item['id']) + batch.update(doc_ref, {'relevance_score': item['relevance_score'], 'updated_at': now}) + batch.commit() + + +# --- DELETE --- + + +def delete_staged_task(uid: str, task_id: str) -> bool: + """Hard-delete a staged task. Returns True if deleted.""" + doc_ref = db.collection('users').document(uid).collection(COLLECTION).document(task_id) + doc = doc_ref.get() + if not doc.exists: + return False + doc_ref.delete() + return True + + +def delete_staged_tasks_batch(uid: str, task_ids: List[str]) -> int: + """Hard-delete multiple staged tasks. Returns count deleted.""" + if not task_ids: + return 0 + batch = db.batch() + ref = db.collection('users').document(uid).collection(COLLECTION) + for task_id in task_ids: + batch.delete(ref.document(task_id)) + batch.commit() + return len(task_ids) + + +# --- PROMOTE --- + + +def get_active_ai_action_items(uid: str) -> List[dict]: + """Get active action items that were promoted from staged (from_staged=true, not completed, not deleted).""" + ref = db.collection('users').document(uid).collection('action_items') + query = ref.where(filter=firestore.FieldFilter('from_staged', '==', True)).where( + filter=firestore.FieldFilter('completed', '==', False) + ) + items = [] + for doc in query.stream(): + data = doc.to_dict() + # Skip soft-deleted + if data.get('deleted'): + continue + data['id'] = doc.id + items.append(_prepare_for_read(data)) + return items + + +def promote_staged_task(uid: str, staged_task: dict) -> dict: + """Create an action item from a staged task (from_staged=true). 
Returns created action item.""" + now = datetime.now(timezone.utc) + action_item_data = { + 'description': staged_task['description'], + 'completed': False, + 'created_at': now, + 'updated_at': now, + 'from_staged': True, + 'source': staged_task.get('source'), + 'priority': staged_task.get('priority'), + 'metadata': staged_task.get('metadata'), + 'category': staged_task.get('category'), + 'relevance_score': staged_task.get('relevance_score'), + } + if staged_task.get('due_at'): + action_item_data['due_at'] = staged_task['due_at'] + + ref = db.collection('users').document(uid).collection('action_items') + _, doc_ref = ref.add(action_item_data) + action_item_data['id'] = doc_ref.id + return action_item_data + + +# --- SCORES (daily/weekly/overall) --- + + +def get_action_items_for_daily_score(uid: str, due_start: str, due_end: str) -> Tuple[int, int]: + """Count completed vs total action items due on a specific day. + + Returns (completed_count, total_count). + """ + ref = db.collection('users').document(uid).collection('action_items') + start_dt = datetime.fromisoformat(due_start.replace('Z', '+00:00')) + end_dt = datetime.fromisoformat(due_end.replace('Z', '+00:00')) + + query = ref.where(filter=firestore.FieldFilter('due_at', '>=', start_dt)).where( + filter=firestore.FieldFilter('due_at', '<=', end_dt) + ) + + completed = 0 + total = 0 + for doc in query.stream(): + data = doc.to_dict() + if data.get('deleted'): + continue + total += 1 + if data.get('completed'): + completed += 1 + return completed, total + + +def get_action_items_for_weekly_score(uid: str, week_start: str, week_end: str) -> Tuple[int, int]: + """Count completed vs total action items in a 7-day window. + + Returns (completed_count, total_count). 
+ """ + # Same logic as daily but wider date range + return get_action_items_for_daily_score(uid, week_start, week_end) + + +def get_action_items_for_overall_score(uid: str) -> Tuple[int, int]: + """Count completed vs total action items (all time, not deleted). + + Returns (completed_count, total_count). + """ + ref = db.collection('users').document(uid).collection('action_items') + + completed = 0 + total = 0 + for doc in ref.stream(): + data = doc.to_dict() + if data.get('deleted'): + continue + total += 1 + if data.get('completed'): + completed += 1 + return completed, total From 761b5d49330c649bef214fc2a05e9525549b6083 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:19:24 +0100 Subject: [PATCH 088/163] Add staged tasks + daily scores endpoints for desktop --- backend/routers/staged_tasks.py | 301 ++++++++++++++++++++++++++++++++ 1 file changed, 301 insertions(+) create mode 100644 backend/routers/staged_tasks.py diff --git a/backend/routers/staged_tasks.py b/backend/routers/staged_tasks.py new file mode 100644 index 0000000000..e6b495c134 --- /dev/null +++ b/backend/routers/staged_tasks.py @@ -0,0 +1,301 @@ +"""Desktop staged tasks endpoints. + +Staged tasks are AI-extracted action items ranked by relevance_score. +The top-ranked task can be promoted to action_items (max 5 active AI tasks). +Deduplication prevents promoting tasks whose description already exists in active action_items. 
+""" + +import logging + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field, field_validator +from typing import Optional, List +from datetime import datetime + +import database.staged_tasks as staged_tasks_db +from utils.other import endpoints as auth + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# --- Models --- + + +class CreateStagedTaskRequest(BaseModel): + description: str = Field(..., min_length=1, max_length=2000) + due_at: Optional[datetime] = None + source: Optional[str] = None + priority: Optional[str] = None + metadata: Optional[str] = None + category: Optional[str] = None + relevance_score: Optional[int] = None + + @field_validator('description') + @classmethod + def description_not_blank(cls, v): + if not v.strip(): + raise ValueError('description must not be blank') + return v + + +class StagedTaskResponse(BaseModel): + id: str + description: str + completed: bool = False + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + due_at: Optional[datetime] = None + source: Optional[str] = None + priority: Optional[str] = None + metadata: Optional[str] = None + category: Optional[str] = None + relevance_score: Optional[int] = None + + +class StagedTasksListResponse(BaseModel): + items: List[StagedTaskResponse] + has_more: bool + + +class StatusResponse(BaseModel): + status: str + + +class ScoreUpdate(BaseModel): + id: str + relevance_score: int + + +class BatchUpdateScoresRequest(BaseModel): + scores: List[ScoreUpdate] = Field(..., min_length=1, max_length=500) + + +class PromoteResponse(BaseModel): + promoted: bool + reason: Optional[str] = None + promoted_task: Optional[StagedTaskResponse] = None + + +# --- Endpoints --- + + +# --- Desktop staged tasks --- + + +@router.post('/v1/staged-tasks', response_model=StagedTaskResponse, tags=['staged-tasks']) +def create_staged_task(request: CreateStagedTaskRequest, uid: str = Depends(auth.get_current_user_uid)): 
+ """Create a new staged task.""" + data = { + 'description': request.description.strip(), + 'source': request.source, + 'priority': request.priority, + 'metadata': request.metadata, + 'category': request.category, + 'relevance_score': request.relevance_score, + } + if request.due_at: + data['due_at'] = request.due_at + + result = staged_tasks_db.create_staged_task(uid, data) + return StagedTaskResponse(**result) + + +@router.get('/v1/staged-tasks', response_model=StagedTasksListResponse, tags=['staged-tasks']) +def get_staged_tasks( + limit: int = Query(default=100, ge=1, le=500), + offset: int = Query(default=0, ge=0), + uid: str = Depends(auth.get_current_user_uid), +): + """List staged tasks ordered by relevance_score ASC (best ranked first).""" + items, has_more = staged_tasks_db.get_staged_tasks(uid, limit=limit, offset=offset) + return StagedTasksListResponse( + items=[StagedTaskResponse(**item) for item in items], + has_more=has_more, + ) + + +@router.delete('/v1/staged-tasks/{task_id}', response_model=StatusResponse, tags=['staged-tasks']) +def delete_staged_task(task_id: str, uid: str = Depends(auth.get_current_user_uid)): + """Hard-delete a staged task.""" + deleted = staged_tasks_db.delete_staged_task(uid, task_id) + if not deleted: + raise HTTPException(status_code=404, detail='Staged task not found') + return StatusResponse(status='ok') + + +@router.patch('/v1/staged-tasks/batch-scores', response_model=StatusResponse, tags=['staged-tasks']) +def batch_update_scores(request: BatchUpdateScoresRequest, uid: str = Depends(auth.get_current_user_uid)): + """Batch update relevance scores for staged tasks.""" + scores = [{'id': s.id, 'relevance_score': s.relevance_score} for s in request.scores] + staged_tasks_db.batch_update_scores(uid, scores) + return StatusResponse(status='ok') + + +@router.post('/v1/staged-tasks/promote', response_model=PromoteResponse, tags=['staged-tasks']) +def promote_staged_task(uid: str = Depends(auth.get_current_user_uid)): + 
"""Promote the top-ranked staged task to action_items. + + Rules: + - Max 5 active AI tasks (from_staged=true, not completed, not deleted) + - Skips duplicates (case-insensitive description match, strips [screen] prefix/suffix) + - Deletes duplicate staged tasks found during scan + - Hard-deletes the promoted task from staged_tasks + """ + # Step 1: Check active AI task count + active_items = staged_tasks_db.get_active_ai_action_items(uid) + if len(active_items) >= 5: + return PromoteResponse( + promoted=False, + reason=f'Already have {len(active_items)} active AI tasks (max 5)', + ) + + # Build dedup set from existing descriptions + existing_descriptions = set() + for item in active_items: + desc = item.get('description', '') + normalized = desc.strip().removeprefix('[screen] ').removesuffix(' [screen]').lower() + existing_descriptions.add(normalized) + + # Step 2: Get top-ranked staged tasks (batch of 20 for dedup scanning) + staged_items, _ = staged_tasks_db.get_staged_tasks(uid, limit=20, offset=0) + if not staged_items: + return PromoteResponse(promoted=False, reason='No staged tasks available') + + # Step 3: Find first non-duplicate, collecting duplicates to delete + selected_task = None + seen_descriptions = set() + duplicate_ids = [] + + for task in staged_items: + normalized = task.get('description', '').strip().removeprefix('[screen] ').removesuffix(' [screen]').lower() + if normalized in existing_descriptions or normalized in seen_descriptions: + duplicate_ids.append(task['id']) + continue + seen_descriptions.add(normalized) + if selected_task is None: + selected_task = task + + # Clean up duplicates + if duplicate_ids: + staged_tasks_db.delete_staged_tasks_batch(uid, duplicate_ids) + logger.info(f'Cleaned up {len(duplicate_ids)} duplicate staged tasks for user {uid}') + + if selected_task is None: + return PromoteResponse(promoted=False, reason='All candidate staged tasks are duplicates') + + # Step 4: Promote to action_items + promoted_item = 
staged_tasks_db.promote_staged_task(uid, selected_task) + + # Step 5: Hard-delete from staged_tasks + staged_tasks_db.delete_staged_task(uid, selected_task['id']) + + logger.info(f'Promoted staged task {selected_task["id"]} -> action item {promoted_item["id"]} for user {uid}') + + return PromoteResponse(promoted=True, promoted_task=StagedTaskResponse(**promoted_item)) + + +# --- Desktop daily scores --- + + +class DailyScoreResponse(BaseModel): + score: float + completed_tasks: int + total_tasks: int + date: str + + +class ScoreData(BaseModel): + score: float + completed_tasks: int + total_tasks: int + + +class ScoresResponse(BaseModel): + daily: ScoreData + weekly: ScoreData + overall: ScoreData + default_tab: str + date: str + + +@router.get('/v1/daily-score', response_model=DailyScoreResponse, tags=['scores']) +def get_daily_score( + date: Optional[str] = Query(default=None, description='Date in YYYY-MM-DD format'), + uid: str = Depends(auth.get_current_user_uid), +): + """Calculate daily score from action items due today (legacy endpoint).""" + from datetime import date as date_type + + if date: + try: + parsed = datetime.strptime(date, '%Y-%m-%d').date() + except ValueError: + raise HTTPException(status_code=400, detail='Invalid date format, use YYYY-MM-DD') + else: + parsed = datetime.now().date() + + date_str = parsed.strftime('%Y-%m-%d') + due_start = f'{date_str}T00:00:00Z' + due_end = f'{date_str}T23:59:59.999Z' + + completed, total = staged_tasks_db.get_action_items_for_daily_score(uid, due_start, due_end) + score = (completed / total * 100.0) if total > 0 else 0.0 + + return DailyScoreResponse(score=score, completed_tasks=completed, total_tasks=total, date=date_str) + + +@router.get('/v1/scores', response_model=ScoresResponse, tags=['scores']) +def get_scores( + date: Optional[str] = Query(default=None, description='Date in YYYY-MM-DD format'), + uid: str = Depends(auth.get_current_user_uid), +): + """Get daily, weekly, and overall scores with default 
tab selection.""" + from datetime import timedelta + + if date: + try: + parsed = datetime.strptime(date, '%Y-%m-%d').date() + except ValueError: + raise HTTPException(status_code=400, detail='Invalid date format, use YYYY-MM-DD') + else: + parsed = datetime.now().date() + + date_str = parsed.strftime('%Y-%m-%d') + + # Daily: tasks due today + today_start = f'{date_str}T00:00:00Z' + today_end = f'{date_str}T23:59:59.999Z' + daily_completed, daily_total = staged_tasks_db.get_action_items_for_daily_score(uid, today_start, today_end) + + # Weekly: last 7 days + week_ago = parsed - timedelta(days=7) + week_start = f'{week_ago.strftime("%Y-%m-%d")}T00:00:00Z' + weekly_completed, weekly_total = staged_tasks_db.get_action_items_for_weekly_score(uid, week_start, today_end) + + # Overall: all time + overall_completed, overall_total = staged_tasks_db.get_action_items_for_overall_score(uid) + + def calc_score(completed, total): + return (completed / total * 100.0) if total > 0 else 0.0 + + daily = ScoreData( + score=calc_score(daily_completed, daily_total), completed_tasks=daily_completed, total_tasks=daily_total + ) + weekly = ScoreData( + score=calc_score(weekly_completed, weekly_total), completed_tasks=weekly_completed, total_tasks=weekly_total + ) + overall = ScoreData( + score=calc_score(overall_completed, overall_total), completed_tasks=overall_completed, total_tasks=overall_total + ) + + # Default tab: highest score, prefer daily if tied + if daily.total_tasks > 0 and daily.score >= weekly.score and daily.score >= overall.score: + default_tab = 'daily' + elif weekly.score >= overall.score: + default_tab = 'weekly' + else: + default_tab = 'overall' + + return ScoresResponse(daily=daily, weekly=weekly, overall=overall, default_tab=default_tab, date=date_str) From c694212aff4092b53ef01a7449710cd635a21492 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:19:28 +0100 Subject: [PATCH 089/163] Wire staged_tasks router into main.py --- backend/main.py | 2 ++ 1 
file changed, 2 insertions(+) diff --git a/backend/main.py b/backend/main.py index 32e6790280..04be0b585b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -48,6 +48,7 @@ screen_activity, focus_sessions, advice, + staged_tasks, ) from utils.other.timeout import TimeoutMiddleware @@ -110,6 +111,7 @@ app.include_router(screen_activity.router) app.include_router(focus_sessions.router) app.include_router(advice.router) +app.include_router(staged_tasks.router) methods_timeout = { From 08c44d5f96b5db9617454f3ca686f3ca18ee2848 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:19:29 +0100 Subject: [PATCH 090/163] Add 30 tests for staged tasks and daily scores endpoints --- backend/tests/unit/test_staged_tasks.py | 420 ++++++++++++++++++++++++ 1 file changed, 420 insertions(+) create mode 100644 backend/tests/unit/test_staged_tasks.py diff --git a/backend/tests/unit/test_staged_tasks.py b/backend/tests/unit/test_staged_tasks.py new file mode 100644 index 0000000000..b86a0254cf --- /dev/null +++ b/backend/tests/unit/test_staged_tasks.py @@ -0,0 +1,420 @@ +"""Tests for desktop staged tasks + daily scores endpoints.""" + +import sys +from unittest.mock import patch, MagicMock +from datetime import datetime, timezone + +import pytest + +for mod_name in [ + 'firebase_admin', + 'firebase_admin.auth', + 'firebase_admin.firestore', + 'firebase_admin.messaging', + 'google.cloud', + 'google.cloud.exceptions', + 'google.cloud.firestore', + 'google.cloud.firestore_v1', + 'google.cloud.firestore_v1.base_query', + 'google.cloud.firestore_v1.query', + 'google.cloud.storage', + 'google.cloud.storage.blob', + 'google.cloud.storage.bucket', + 'google.auth', + 'google.auth.transport', + 'google.auth.transport.requests', + 'google.oauth2', + 'google.oauth2.service_account', + 'pinecone', + 'typesense', +]: + sys.modules.setdefault(mod_name, MagicMock()) + +from routers.staged_tasks import ( + CreateStagedTaskRequest, + StagedTaskResponse, + StagedTasksListResponse, + 
BatchUpdateScoresRequest, + ScoreUpdate, + PromoteResponse, + DailyScoreResponse, + ScoresResponse, + ScoreData, + StatusResponse, + router, +) + +# --- Model Tests --- + + +class TestStagedTaskModels: + def test_create_request_required_fields(self): + req = CreateStagedTaskRequest(description='Buy groceries') + assert req.description == 'Buy groceries' + assert req.source is None + assert req.relevance_score is None + + def test_create_request_all_fields(self): + req = CreateStagedTaskRequest( + description='Ship feature', + source='screenshot', + priority='high', + metadata='{"app": "Safari"}', + category='work', + relevance_score=3, + ) + assert req.priority == 'high' + assert req.relevance_score == 3 + + def test_create_request_blank_description_rejected(self): + with pytest.raises(Exception): + CreateStagedTaskRequest(description=' ') + + def test_batch_scores_request(self): + req = BatchUpdateScoresRequest(scores=[ScoreUpdate(id='t1', relevance_score=5)]) + assert len(req.scores) == 1 + + def test_batch_scores_empty_rejected(self): + with pytest.raises(Exception): + BatchUpdateScoresRequest(scores=[]) + + def test_promote_response(self): + resp = PromoteResponse(promoted=True, promoted_task=StagedTaskResponse(id='t1', description='Task')) + assert resp.promoted is True + assert resp.promoted_task.id == 't1' + + def test_daily_score_response(self): + resp = DailyScoreResponse(score=75.0, completed_tasks=3, total_tasks=4, date='2026-03-05') + assert resp.score == 75.0 + + def test_scores_response(self): + data = ScoreData(score=50.0, completed_tasks=1, total_tasks=2) + resp = ScoresResponse(daily=data, weekly=data, overall=data, default_tab='daily', date='2026-03-05') + assert resp.default_tab == 'daily' + + +# --- Endpoint Tests --- + + +class TestStagedTaskEndpoints: + def _make_app(self): + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + return app + + @pytest.fixture + def client(self): + from fastapi.testclient import 
TestClient + + return TestClient(self._make_app()) + + def test_create_staged_task(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.create_staged_task') as mock_create, + ): + mock_create.return_value = { + 'id': 'st-1', + 'description': 'Buy milk', + 'completed': False, + 'created_at': datetime.now(timezone.utc), + 'updated_at': datetime.now(timezone.utc), + } + response = client.post( + '/v1/staged-tasks', + json={'description': 'Buy milk', 'source': 'screenshot', 'relevance_score': 5}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert response.json()['id'] == 'st-1' + assert response.json()['description'] == 'Buy milk' + + def test_create_staged_task_blank_desc_422(self, client): + with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'): + response = client.post( + '/v1/staged-tasks', + json={'description': ' '}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 + + def test_list_staged_tasks(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_get, + ): + mock_get.return_value = ( + [ + {'id': 'st-1', 'description': 'Task 1', 'completed': False, 'relevance_score': 1}, + {'id': 'st-2', 'description': 'Task 2', 'completed': False, 'relevance_score': 3}, + ], + False, + ) + response = client.get('/v1/staged-tasks', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + data = response.json() + assert len(data['items']) == 2 + assert data['has_more'] is False + + def test_list_staged_tasks_with_pagination(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_get, + ): + mock_get.return_value 
= ([], True) + response = client.get( + '/v1/staged-tasks?limit=10&offset=20', + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert mock_get.called + assert mock_get.call_args[1] == {'limit': 10, 'offset': 20} + + def test_list_staged_tasks_limit_over_max_422(self, client): + with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'): + response = client.get( + '/v1/staged-tasks?limit=501', + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 + + def test_delete_staged_task(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.delete_staged_task', return_value=True), + ): + response = client.delete('/v1/staged-tasks/st-1', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert response.json()['status'] == 'ok' + + def test_delete_staged_task_not_found_404(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.delete_staged_task', return_value=False), + ): + response = client.delete('/v1/staged-tasks/missing', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 404 + + def test_batch_update_scores(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.batch_update_scores') as mock_batch, + ): + response = client.patch( + '/v1/staged-tasks/batch-scores', + json={'scores': [{'id': 'st-1', 'relevance_score': 10}, {'id': 'st-2', 'relevance_score': 3}]}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert mock_batch.called + assert len(mock_batch.call_args[0][1]) == 2 + + def test_batch_update_scores_empty_422(self, client): + with patch('routers.staged_tasks.auth.get_current_user_uid', 
return_value='uid-1'): + response = client.patch( + '/v1/staged-tasks/batch-scores', + json={'scores': []}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 + + def test_promote_success(self, client): + now = datetime.now(timezone.utc) + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items', return_value=[]), + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged, + patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote, + patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'), + ): + mock_staged.return_value = ( + [ + {'id': 'st-1', 'description': 'Top task', 'completed': False, 'relevance_score': 1}, + ], + False, + ) + mock_promote.return_value = { + 'id': 'ai-1', + 'description': 'Top task', + 'completed': False, + 'created_at': now, + 'updated_at': now, + 'from_staged': True, + } + response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + data = response.json() + assert data['promoted'] is True + assert data['promoted_task']['id'] == 'ai-1' + + def test_promote_max_active_returns_false(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active, + ): + mock_active.return_value = [{'id': f'ai-{i}', 'description': f'Task {i}'} for i in range(5)] + response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + data = response.json() + assert data['promoted'] is False + assert 'max 5' in data['reason'] + + def test_promote_no_staged_tasks(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + 
patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items', return_value=[]), + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks', return_value=([], False)), + ): + response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert response.json()['promoted'] is False + assert 'No staged tasks' in response.json()['reason'] + + def test_promote_skips_duplicates(self, client): + now = datetime.now(timezone.utc) + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active, + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged, + patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote, + patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'), + patch('routers.staged_tasks.staged_tasks_db.delete_staged_tasks_batch') as mock_batch_del, + ): + mock_active.return_value = [{'id': 'ai-1', 'description': 'Buy groceries'}] + mock_staged.return_value = ( + [ + {'id': 'st-1', 'description': 'buy groceries', 'completed': False, 'relevance_score': 1}, + {'id': 'st-2', 'description': 'Ship feature', 'completed': False, 'relevance_score': 2}, + ], + False, + ) + mock_promote.return_value = { + 'id': 'ai-2', + 'description': 'Ship feature', + 'completed': False, + 'created_at': now, + 'updated_at': now, + } + response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert response.json()['promoted'] is True + assert response.json()['promoted_task']['description'] == 'Ship feature' + # st-1 should be batch-deleted as duplicate + assert mock_batch_del.called + assert mock_batch_del.call_args[0][1] == ['st-1'] + + def test_promote_all_duplicates(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', 
return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active, + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged, + patch('routers.staged_tasks.staged_tasks_db.delete_staged_tasks_batch'), + ): + mock_active.return_value = [{'id': 'ai-1', 'description': 'Task A'}] + mock_staged.return_value = ( + [ + {'id': 'st-1', 'description': 'task a', 'completed': False, 'relevance_score': 1}, + ], + False, + ) + response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert response.json()['promoted'] is False + assert 'duplicates' in response.json()['reason'] + + +class TestDailyScoreEndpoints: + def _make_app(self): + from fastapi import FastAPI + + app = FastAPI() + app.include_router(router) + return app + + @pytest.fixture + def client(self): + from fastapi.testclient import TestClient + + return TestClient(self._make_app()) + + def test_daily_score_today(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(3, 4)), + ): + response = client.get('/v1/daily-score', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + data = response.json() + assert data['score'] == 75.0 + assert data['completed_tasks'] == 3 + assert data['total_tasks'] == 4 + + def test_daily_score_specific_date(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(0, 0)), + ): + response = client.get('/v1/daily-score?date=2026-01-15', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert response.json()['date'] == '2026-01-15' + assert response.json()['score'] == 0.0 + + def 
test_daily_score_invalid_date_400(self, client): + with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'): + response = client.get('/v1/daily-score?date=not-a-date', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 400 + + def test_scores_all_three(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(2, 4)), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(10, 20)), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(50, 100)), + ): + response = client.get('/v1/scores', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + data = response.json() + assert data['daily']['score'] == 50.0 + assert data['weekly']['score'] == 50.0 + assert data['overall']['score'] == 50.0 + + def test_scores_default_tab_daily_when_highest(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(4, 4)), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(5, 10)), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(10, 30)), + ): + response = client.get('/v1/scores', headers={'Authorization': 'Bearer test'}) + assert response.json()['default_tab'] == 'daily' + + def test_scores_default_tab_weekly_when_no_daily(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(0, 0)), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(5, 10)), + 
patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(10, 30)), + ): + response = client.get('/v1/scores', headers={'Authorization': 'Bearer test'}) + assert response.json()['default_tab'] == 'weekly' + + def test_scores_invalid_date_400(self, client): + with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'): + response = client.get('/v1/scores?date=bad', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 400 + + def test_scores_no_tasks_zero(self, client): + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(0, 0)), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(0, 0)), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(0, 0)), + ): + response = client.get('/v1/scores', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + data = response.json() + assert data['daily']['score'] == 0.0 + assert data['weekly']['score'] == 0.0 + assert data['overall']['score'] == 0.0 From 837dfd798baa9a803a44cb91e553fda2e847c516 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:26:29 +0100 Subject: [PATCH 091/163] Fix reviewer issues: dedup on create, filter completed/deleted, weekly uses created_at, idempotent delete --- backend/database/staged_tasks.py | 68 +++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/backend/database/staged_tasks.py b/backend/database/staged_tasks.py index 3dfb3a66be..728ea0f278 100644 --- a/backend/database/staged_tasks.py +++ b/backend/database/staged_tasks.py @@ -25,13 +25,29 @@ def _prepare_for_read(data: dict) -> dict: def create_staged_task(uid: str, data: dict) -> dict: - """Create a staged task. 
Returns the created document with id.""" + """Create a staged task with dedup. Returns existing item if description matches (case-insensitive).""" + description = data.get('description', '').strip() + if not description: + raise ValueError('description must not be empty') + + ref = db.collection('users').document(uid).collection(COLLECTION) + + # Dedup: check for existing task with same description (case-insensitive) + normalized = description.lower() + for doc in ref.stream(): + existing = doc.to_dict() + if existing.get('deleted'): + continue + if existing.get('description', '').strip().lower() == normalized: + existing['id'] = doc.id + return _prepare_for_read(existing) + now = datetime.now(timezone.utc) + data['description'] = description data.setdefault('created_at', now) data.setdefault('updated_at', now) data.setdefault('completed', False) - ref = db.collection('users').document(uid).collection(COLLECTION) _, doc_ref = ref.add(data) result = data.copy() result['id'] = doc_ref.id @@ -42,12 +58,18 @@ def create_staged_task(uid: str, data: dict) -> dict: def get_staged_tasks(uid: str, limit: int = 100, offset: int = 0) -> Tuple[List[dict], bool]: - """List staged tasks ordered by relevance_score ASC. Returns (items, has_more).""" + """List staged tasks ordered by relevance_score ASC, filtering out completed/deleted. + + Matches Rust behavior: completed=false filter, skip deleted, tie-break by created_at DESC. + Returns (items, has_more). 
+ """ ref = db.collection('users').document(uid).collection(COLLECTION) - query = ref.order_by('relevance_score', direction=firestore.Query.ASCENDING) + query = ref.where(filter=firestore.FieldFilter('completed', '==', False)).order_by( + 'relevance_score', direction=firestore.Query.ASCENDING + ) - # Fetch limit+1 to detect has_more - fetch_limit = limit + 1 + # Fetch more than needed to account for deleted items being filtered client-side + fetch_limit = (limit + 1) * 2 if offset > 0: query = query.offset(offset) query = query.limit(fetch_limit) @@ -56,6 +78,9 @@ def get_staged_tasks(uid: str, limit: int = 100, offset: int = 0) -> Tuple[List[ items = [] for doc in docs: data = doc.to_dict() + # Skip soft-deleted + if data.get('deleted'): + continue data['id'] = doc.id items.append(_prepare_for_read(data)) @@ -98,14 +123,10 @@ def batch_update_scores(uid: str, scores: List[dict]) -> None: # --- DELETE --- -def delete_staged_task(uid: str, task_id: str) -> bool: - """Hard-delete a staged task. Returns True if deleted.""" +def delete_staged_task(uid: str, task_id: str) -> None: + """Hard-delete a staged task. Idempotent — no error if not found (matches Rust behavior).""" doc_ref = db.collection('users').document(uid).collection(COLLECTION).document(task_id) - doc = doc_ref.get() - if not doc.exists: - return False doc_ref.delete() - return True def delete_staged_tasks_batch(uid: str, task_ids: List[str]) -> int: @@ -193,12 +214,29 @@ def get_action_items_for_daily_score(uid: str, due_start: str, due_end: str) -> def get_action_items_for_weekly_score(uid: str, week_start: str, week_end: str) -> Tuple[int, int]: - """Count completed vs total action items in a 7-day window. + """Count completed vs total action items created in a 7-day window. + Uses created_at range (not due_at) to match Rust weekly score behavior. Returns (completed_count, total_count). 
""" - # Same logic as daily but wider date range - return get_action_items_for_daily_score(uid, week_start, week_end) + ref = db.collection('users').document(uid).collection('action_items') + start_dt = datetime.fromisoformat(week_start.replace('Z', '+00:00')) + end_dt = datetime.fromisoformat(week_end.replace('Z', '+00:00')) + + query = ref.where(filter=firestore.FieldFilter('created_at', '>=', start_dt)).where( + filter=firestore.FieldFilter('created_at', '<=', end_dt) + ) + + completed = 0 + total = 0 + for doc in query.stream(): + data = doc.to_dict() + if data.get('deleted'): + continue + total += 1 + if data.get('completed'): + completed += 1 + return completed, total def get_action_items_for_overall_score(uid: str) -> Tuple[int, int]: From 2c1bedd122aade6dd05fa2f5a65ceda30c79347c Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:26:30 +0100 Subject: [PATCH 092/163] Make delete endpoint idempotent to match Rust behavior --- backend/routers/staged_tasks.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/backend/routers/staged_tasks.py b/backend/routers/staged_tasks.py index e6b495c134..2d60cadc6b 100644 --- a/backend/routers/staged_tasks.py +++ b/backend/routers/staged_tasks.py @@ -118,10 +118,8 @@ def get_staged_tasks( @router.delete('/v1/staged-tasks/{task_id}', response_model=StatusResponse, tags=['staged-tasks']) def delete_staged_task(task_id: str, uid: str = Depends(auth.get_current_user_uid)): - """Hard-delete a staged task.""" - deleted = staged_tasks_db.delete_staged_task(uid, task_id) - if not deleted: - raise HTTPException(status_code=404, detail='Staged task not found') + """Hard-delete a staged task. 
Idempotent — returns ok even if not found (matches Rust).""" + staged_tasks_db.delete_staged_task(uid, task_id) return StatusResponse(status='ok') From 62daa9966905674a7856b52d29db17f90a006277 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:26:31 +0100 Subject: [PATCH 093/163] Add tests for dedup create, idempotent delete, weekly created_at semantics --- backend/tests/unit/test_staged_tasks.py | 51 +++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/backend/tests/unit/test_staged_tasks.py b/backend/tests/unit/test_staged_tasks.py index b86a0254cf..14d0f5431c 100644 --- a/backend/tests/unit/test_staged_tasks.py +++ b/backend/tests/unit/test_staged_tasks.py @@ -183,19 +183,22 @@ def test_list_staged_tasks_limit_over_max_422(self, client): def test_delete_staged_task(self, client): with ( patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.staged_tasks.staged_tasks_db.delete_staged_task', return_value=True), + patch('routers.staged_tasks.staged_tasks_db.delete_staged_task') as mock_del, ): response = client.delete('/v1/staged-tasks/st-1', headers={'Authorization': 'Bearer test'}) assert response.status_code == 200 assert response.json()['status'] == 'ok' + assert mock_del.called - def test_delete_staged_task_not_found_404(self, client): + def test_delete_staged_task_idempotent(self, client): + """Delete returns 200 even for non-existent task (matches Rust behavior).""" with ( patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), - patch('routers.staged_tasks.staged_tasks_db.delete_staged_task', return_value=False), + patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'), ): response = client.delete('/v1/staged-tasks/missing', headers={'Authorization': 'Bearer test'}) - assert response.status_code == 404 + assert response.status_code == 200 + assert response.json()['status'] == 'ok' def test_batch_update_scores(self, client): with ( @@ 
-418,3 +421,43 @@ def test_scores_no_tasks_zero(self, client): assert data['daily']['score'] == 0.0 assert data['weekly']['score'] == 0.0 assert data['overall']['score'] == 0.0 + + def test_create_dedup_returns_existing(self, client): + """Create returns existing task if description matches (case-insensitive).""" + now = datetime.now(timezone.utc) + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.create_staged_task') as mock_create, + ): + # Simulate dedup returning existing task + mock_create.return_value = { + 'id': 'existing-1', + 'description': 'Buy milk', + 'completed': False, + 'created_at': now, + 'updated_at': now, + } + response = client.post( + '/v1/staged-tasks', + json={'description': 'buy milk'}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + assert response.json()['id'] == 'existing-1' + + def test_weekly_score_uses_created_at(self, client): + """Weekly score filters by created_at range, not due_at.""" + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_daily_score', return_value=(1, 2)), + patch( + 'routers.staged_tasks.staged_tasks_db.get_action_items_for_weekly_score', return_value=(7, 14) + ) as mock_weekly, + patch('routers.staged_tasks.staged_tasks_db.get_action_items_for_overall_score', return_value=(20, 40)), + ): + response = client.get('/v1/scores?date=2026-03-05', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert mock_weekly.called + # Weekly should use a 7-day window ending today + week_start_arg = mock_weekly.call_args[0][1] + assert '2026-02-26' in week_start_arg From 899acc5d2dc5987f7f1df18451ba25fed43fe592 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:30:19 +0100 Subject: [PATCH 094/163] Add created_at DESC tie-break ordering for staged tasks query --- 
backend/database/staged_tasks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/database/staged_tasks.py b/backend/database/staged_tasks.py index 728ea0f278..28f8c35c18 100644 --- a/backend/database/staged_tasks.py +++ b/backend/database/staged_tasks.py @@ -64,8 +64,10 @@ def get_staged_tasks(uid: str, limit: int = 100, offset: int = 0) -> Tuple[List[ Returns (items, has_more). """ ref = db.collection('users').document(uid).collection(COLLECTION) - query = ref.where(filter=firestore.FieldFilter('completed', '==', False)).order_by( - 'relevance_score', direction=firestore.Query.ASCENDING + query = ( + ref.where(filter=firestore.FieldFilter('completed', '==', False)) + .order_by('relevance_score', direction=firestore.Query.ASCENDING) + .order_by('created_at', direction=firestore.Query.DESCENDING) ) # Fetch more than needed to account for deleted items being filtered client-side From aede9243f65c7d21a72961b9a298e75054a5df92 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:30:20 +0100 Subject: [PATCH 095/163] Move in-function imports to module top level per CLAUDE.md --- backend/routers/staged_tasks.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/backend/routers/staged_tasks.py b/backend/routers/staged_tasks.py index 2d60cadc6b..9078187bc7 100644 --- a/backend/routers/staged_tasks.py +++ b/backend/routers/staged_tasks.py @@ -10,7 +10,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field, field_validator from typing import Optional, List -from datetime import datetime +from datetime import datetime, timedelta import database.staged_tasks as staged_tasks_db from utils.other import endpoints as auth @@ -224,8 +224,6 @@ def get_daily_score( uid: str = Depends(auth.get_current_user_uid), ): """Calculate daily score from action items due today (legacy endpoint).""" - from datetime import date as date_type - if date: try: parsed = 
datetime.strptime(date, '%Y-%m-%d').date() @@ -250,8 +248,6 @@ def get_scores( uid: str = Depends(auth.get_current_user_uid), ): """Get daily, weekly, and overall scores with default tab selection.""" - from datetime import timedelta - if date: try: parsed = datetime.strptime(date, '%Y-%m-%d').date() From b67873c468419cd29d523238c2535569f979b511 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:35:15 +0100 Subject: [PATCH 096/163] Add 18 tests: DB-layer dedup/filter/scoring, [screen] normalization, boundary caps (50 total) --- backend/tests/unit/test_staged_tasks.py | 354 ++++++++++++++++++++++++ 1 file changed, 354 insertions(+) diff --git a/backend/tests/unit/test_staged_tasks.py b/backend/tests/unit/test_staged_tasks.py index 14d0f5431c..60c414d8f8 100644 --- a/backend/tests/unit/test_staged_tasks.py +++ b/backend/tests/unit/test_staged_tasks.py @@ -461,3 +461,357 @@ def test_weekly_score_uses_created_at(self, client): # Weekly should use a 7-day window ending today week_start_arg = mock_weekly.call_args[0][1] assert '2026-02-26' in week_start_arg + + # --- Promote with [screen] prefix/suffix normalization --- + + def test_promote_skips_screen_prefix_duplicate(self, client): + """Promote dedup strips [screen] prefix when comparing descriptions.""" + now = datetime.now(timezone.utc) + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active, + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged, + patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote, + patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'), + patch('routers.staged_tasks.staged_tasks_db.delete_staged_tasks_batch') as mock_batch_del, + ): + # Active item without [screen] prefix + mock_active.return_value = [{'id': 'ai-1', 'description': 'Buy milk'}] + # Staged item with [screen] prefix — should be 
detected as duplicate + mock_staged.return_value = ( + [ + {'id': 'st-1', 'description': '[screen] Buy milk', 'completed': False, 'relevance_score': 1}, + {'id': 'st-2', 'description': 'New unique task', 'completed': False, 'relevance_score': 2}, + ], + False, + ) + mock_promote.return_value = { + 'id': 'ai-2', + 'description': 'New unique task', + 'completed': False, + 'created_at': now, + 'updated_at': now, + } + response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert response.json()['promoted'] is True + assert response.json()['promoted_task']['description'] == 'New unique task' + # st-1 with [screen] prefix should be deleted as duplicate + assert mock_batch_del.called + assert 'st-1' in mock_batch_del.call_args[0][1] + + def test_promote_skips_screen_suffix_duplicate(self, client): + """Promote dedup strips [screen] suffix when comparing descriptions.""" + now = datetime.now(timezone.utc) + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active, + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged, + patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote, + patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'), + patch('routers.staged_tasks.staged_tasks_db.delete_staged_tasks_batch') as mock_batch_del, + ): + # Active item with [screen] suffix + mock_active.return_value = [{'id': 'ai-1', 'description': 'Buy milk [screen]'}] + # Staged item without [screen] — should be detected as duplicate + mock_staged.return_value = ( + [ + {'id': 'st-1', 'description': 'buy milk', 'completed': False, 'relevance_score': 1}, + {'id': 'st-2', 'description': 'Different task', 'completed': False, 'relevance_score': 2}, + ], + False, + ) + mock_promote.return_value = { + 'id': 'ai-2', + 'description': 'Different 
task', + 'completed': False, + 'created_at': now, + 'updated_at': now, + } + response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert response.json()['promoted'] is True + # st-1 should be deleted as duplicate + assert mock_batch_del.called + assert 'st-1' in mock_batch_del.call_args[0][1] + + # --- Promote boundary: 4 active should still promote --- + + def test_promote_with_4_active_succeeds(self, client): + """Promote succeeds when exactly 4 active AI tasks (under max 5).""" + now = datetime.now(timezone.utc) + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_active_ai_action_items') as mock_active, + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks') as mock_staged, + patch('routers.staged_tasks.staged_tasks_db.promote_staged_task') as mock_promote, + patch('routers.staged_tasks.staged_tasks_db.delete_staged_task'), + ): + mock_active.return_value = [{'id': f'ai-{i}', 'description': f'Task {i}'} for i in range(4)] + mock_staged.return_value = ( + [{'id': 'st-1', 'description': 'New task', 'completed': False, 'relevance_score': 1}], + False, + ) + mock_promote.return_value = { + 'id': 'ai-5', + 'description': 'New task', + 'completed': False, + 'created_at': now, + 'updated_at': now, + } + response = client.post('/v1/staged-tasks/promote', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + assert response.json()['promoted'] is True + + # --- Cap boundary tests --- + + def test_create_description_max_length_accepted(self, client): + """Description at exactly 2000 chars is accepted.""" + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.create_staged_task') as mock_create, + ): + desc = 'A' * 2000 + mock_create.return_value = { + 'id': 'st-1', + 'description': desc, + 
'completed': False, + } + response = client.post( + '/v1/staged-tasks', + json={'description': desc}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 200 + + def test_create_description_over_max_rejected(self, client): + """Description at 2001 chars is rejected.""" + with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'): + response = client.post( + '/v1/staged-tasks', + json={'description': 'A' * 2001}, + headers={'Authorization': 'Bearer test'}, + ) + assert response.status_code == 422 + + def test_list_limit_1_accepted(self, client): + """List with limit=1 is accepted.""" + with ( + patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'), + patch('routers.staged_tasks.staged_tasks_db.get_staged_tasks', return_value=([], False)), + ): + response = client.get('/v1/staged-tasks?limit=1', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 200 + + def test_list_limit_0_rejected(self, client): + """List with limit=0 is rejected (min 1).""" + with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'): + response = client.get('/v1/staged-tasks?limit=0', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 422 + + def test_list_offset_negative_rejected(self, client): + """List with offset=-1 is rejected (min 0).""" + with patch('routers.staged_tasks.auth.get_current_user_uid', return_value='uid-1'): + response = client.get('/v1/staged-tasks?offset=-1', headers={'Authorization': 'Bearer test'}) + assert response.status_code == 422 + + +# --- DB Unit Tests --- + + +class _MockDoc: + """Mock Firestore document snapshot.""" + + def __init__(self, doc_id, data, exists=True): + self.id = doc_id + self._data = data + self.exists = exists + + def to_dict(self): + return self._data.copy() + + +class TestStagedTasksDB: + """Unit tests for database/staged_tasks.py functions with mocked Firestore.""" + + def 
test_create_dedup_case_insensitive(self): + """create_staged_task returns existing task if description matches case-insensitively.""" + import database.staged_tasks as db_mod + + existing_doc = _MockDoc('existing-1', {'description': 'Buy Milk', 'completed': False}) + mock_ref = MagicMock() + mock_ref.stream.return_value = [existing_doc] + + with patch.object(db_mod, 'db') as mock_db: + mock_db.collection.return_value.document.return_value.collection.return_value = mock_ref + result = db_mod.create_staged_task('uid-1', {'description': 'buy milk'}) + assert result['id'] == 'existing-1' + assert result['description'] == 'Buy Milk' + # Should NOT have called add (dedup returned existing) + mock_ref.add.assert_not_called() + + def test_create_dedup_whitespace_trim(self): + """create_staged_task trims whitespace before dedup comparison.""" + import database.staged_tasks as db_mod + + existing_doc = _MockDoc('existing-1', {'description': 'Buy Milk', 'completed': False}) + mock_ref = MagicMock() + mock_ref.stream.return_value = [existing_doc] + + with patch.object(db_mod, 'db') as mock_db: + mock_db.collection.return_value.document.return_value.collection.return_value = mock_ref + result = db_mod.create_staged_task('uid-1', {'description': ' buy milk '}) + assert result['id'] == 'existing-1' + mock_ref.add.assert_not_called() + + def test_create_dedup_skips_deleted(self): + """create_staged_task ignores soft-deleted tasks during dedup scan.""" + import database.staged_tasks as db_mod + + deleted_doc = _MockDoc('del-1', {'description': 'Buy Milk', 'completed': False, 'deleted': True}) + mock_ref = MagicMock() + mock_ref.stream.return_value = [deleted_doc] + mock_ref.add.return_value = (None, MagicMock(id='new-1')) + + with patch.object(db_mod, 'db') as mock_db: + mock_db.collection.return_value.document.return_value.collection.return_value = mock_ref + result = db_mod.create_staged_task('uid-1', {'description': 'Buy Milk'}) + # Should create new since deleted match doesn't 
count + assert result['id'] == 'new-1' + mock_ref.add.assert_called_once() + + def test_create_empty_description_raises(self): + """create_staged_task raises ValueError for empty/whitespace description.""" + import database.staged_tasks as db_mod + + with pytest.raises(ValueError, match='description must not be empty'): + db_mod.create_staged_task('uid-1', {'description': ' '}) + + def test_get_staged_tasks_filters_completed_and_deleted(self): + """get_staged_tasks uses completed=false filter and skips deleted client-side.""" + import database.staged_tasks as db_mod + + docs = [ + _MockDoc('t-1', {'description': 'Active', 'completed': False, 'relevance_score': 1}), + _MockDoc('t-2', {'description': 'Deleted', 'completed': False, 'deleted': True, 'relevance_score': 2}), + _MockDoc('t-3', {'description': 'Also active', 'completed': False, 'relevance_score': 3}), + ] + + mock_query = MagicMock() + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.limit.return_value = mock_query + mock_query.stream.return_value = docs + + with patch.object(db_mod, 'db') as mock_db: + mock_db.collection.return_value.document.return_value.collection.return_value = mock_query + items, has_more = db_mod.get_staged_tasks('uid-1', limit=10) + # Should have 2 items (t-2 is deleted, filtered out) + assert len(items) == 2 + assert items[0]['id'] == 't-1' + assert items[1]['id'] == 't-3' + assert has_more is False + + def test_get_staged_tasks_queries_completed_false(self): + """get_staged_tasks passes completed=false FieldFilter to Firestore.""" + import database.staged_tasks as db_mod + + mock_query = MagicMock() + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.limit.return_value = mock_query + mock_query.stream.return_value = [] + + with ( + patch.object(db_mod, 'db') as mock_db, + patch.object(db_mod, 'firestore') as mock_fs, + ): + 
mock_db.collection.return_value.document.return_value.collection.return_value = mock_query + mock_fs.FieldFilter.return_value = 'completed_filter' + mock_fs.Query.ASCENDING = 'ASC' + mock_fs.Query.DESCENDING = 'DESC' + + db_mod.get_staged_tasks('uid-1') + + # Verify FieldFilter was called with completed=false + mock_fs.FieldFilter.assert_called_once_with('completed', '==', False) + mock_query.where.assert_called_once_with(filter='completed_filter') + + def test_daily_score_uses_due_at(self): + """get_action_items_for_daily_score filters by due_at range.""" + import database.staged_tasks as db_mod + + mock_query = MagicMock() + mock_query.where.return_value = mock_query + mock_query.stream.return_value = [] + + with ( + patch.object(db_mod, 'db') as mock_db, + patch.object(db_mod, 'firestore') as mock_fs, + ): + mock_db.collection.return_value.document.return_value.collection.return_value = mock_query + mock_fs.FieldFilter.side_effect = lambda field, op, val: f'{field}_{op}_{val}' + + db_mod.get_action_items_for_daily_score('uid-1', '2026-03-05T00:00:00Z', '2026-03-05T23:59:59.999Z') + + # Should have called FieldFilter with 'due_at' (not 'created_at') + calls = mock_fs.FieldFilter.call_args_list + fields_used = [c[0][0] for c in calls] + assert 'due_at' in fields_used + assert 'created_at' not in fields_used + + def test_weekly_score_uses_created_at(self): + """get_action_items_for_weekly_score filters by created_at range (not due_at).""" + import database.staged_tasks as db_mod + + mock_query = MagicMock() + mock_query.where.return_value = mock_query + mock_query.stream.return_value = [] + + with ( + patch.object(db_mod, 'db') as mock_db, + patch.object(db_mod, 'firestore') as mock_fs, + ): + mock_db.collection.return_value.document.return_value.collection.return_value = mock_query + mock_fs.FieldFilter.side_effect = lambda field, op, val: f'{field}_{op}_{val}' + + db_mod.get_action_items_for_weekly_score('uid-1', '2026-02-26T00:00:00Z', 
'2026-03-05T23:59:59.999Z') + + # Should have called FieldFilter with 'created_at' (not 'due_at') + calls = mock_fs.FieldFilter.call_args_list + fields_used = [c[0][0] for c in calls] + assert 'created_at' in fields_used + assert 'due_at' not in fields_used + + def test_overall_score_counts_all_non_deleted(self): + """get_action_items_for_overall_score scans all docs, skips deleted.""" + import database.staged_tasks as db_mod + + docs = [ + _MockDoc('a-1', {'completed': True}), + _MockDoc('a-2', {'completed': False}), + _MockDoc('a-3', {'completed': True, 'deleted': True}), # Should be skipped + _MockDoc('a-4', {'completed': False}), + ] + + mock_ref = MagicMock() + mock_ref.stream.return_value = docs + + with patch.object(db_mod, 'db') as mock_db: + mock_db.collection.return_value.document.return_value.collection.return_value = mock_ref + completed, total = db_mod.get_action_items_for_overall_score('uid-1') + assert completed == 1 # Only a-1 (a-3 is deleted) + assert total == 3 # a-1, a-2, a-4 (a-3 is deleted) + + def test_delete_is_idempotent(self): + """delete_staged_task calls Firestore delete without checking existence.""" + import database.staged_tasks as db_mod + + mock_doc_ref = MagicMock() + with patch.object(db_mod, 'db') as mock_db: + mock_db.collection.return_value.document.return_value.collection.return_value.document.return_value = ( + mock_doc_ref + ) + # Should not raise even if doc doesn't exist + db_mod.delete_staged_task('uid-1', 'nonexistent-id') + mock_doc_ref.delete.assert_called_once() From c84e4a28f49520bdef77b01c94b1cd0fa55bbe8f Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:41:32 +0100 Subject: [PATCH 097/163] Fix focus_sessions test fixture: use sys.modules mock + isolated router import --- backend/tests/unit/test_focus_sessions.py | 57 +++++++++++++++++++---- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/backend/tests/unit/test_focus_sessions.py b/backend/tests/unit/test_focus_sessions.py index 
fc898d4a39..bc78913fe3 100644 --- a/backend/tests/unit/test_focus_sessions.py +++ b/backend/tests/unit/test_focus_sessions.py @@ -1,24 +1,57 @@ +import sys from datetime import datetime, timezone from unittest.mock import patch, MagicMock import pytest + +for mod_name in [ + 'firebase_admin', + 'firebase_admin.auth', + 'firebase_admin.firestore', + 'firebase_admin.messaging', + 'google.cloud', + 'google.cloud.exceptions', + 'google.cloud.firestore', + 'google.cloud.firestore_v1', + 'google.cloud.firestore_v1.base_query', + 'google.cloud.firestore_v1.query', + 'google.cloud.storage', + 'google.cloud.storage.blob', + 'google.cloud.storage.bucket', + 'google.auth', + 'google.auth.transport', + 'google.auth.transport.requests', + 'google.oauth2', + 'google.oauth2.service_account', + 'pinecone', + 'typesense', +]: + sys.modules.setdefault(mod_name, MagicMock()) + +from routers.focus_sessions import router + +from fastapi import FastAPI, HTTPException from fastapi.testclient import TestClient @pytest.fixture def client(): - with patch('database.screen_activity.db'), \ - patch('database.focus_sessions.db'), \ - patch('database.advice.db'), \ - patch('database.vector_db.Pinecone'), \ - patch('database.vector_db.pc'), \ - patch('database.vector_db.index'), \ - patch('utils.llm.clients.embeddings'): - from main import app + with patch('routers.focus_sessions.auth.get_current_user_uid', return_value='uid-1'): + app = FastAPI() + app.include_router(router) with TestClient(app) as c: yield c +@pytest.fixture +def client_no_auth(): + """Client without auth mock — for testing 401 responses.""" + app = FastAPI() + app.include_router(router) + with TestClient(app) as c: + yield c + + AUTH = {"Authorization": "Bearer 123testuser"} @@ -66,9 +99,13 @@ def test_create_with_optional_fields(self, client): assert resp.json()["message"] == "Keep going!" 
assert resp.json()["duration_seconds"] == 300 - def test_create_no_auth_returns_401(self, client): + def test_create_no_auth_returns_401(self, client_no_auth): data = {"status": "focused", "app_or_site": "X", "description": "Y"} - resp = client.post("/v1/focus-sessions", json=data) + with patch( + 'routers.focus_sessions.auth.get_current_user_uid', + side_effect=HTTPException(status_code=401, detail='Not authenticated'), + ): + resp = client_no_auth.post("/v1/focus-sessions", json=data) assert resp.status_code == 401 def test_create_firestore_error_returns_500(self, client): From c2e0bdf69cd7100a599d1de0bb45dddb3eb2188f Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 10:41:33 +0100 Subject: [PATCH 098/163] Fix advice test fixture: use sys.modules mock + isolated router import --- backend/tests/unit/test_advice.py | 57 +++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/backend/tests/unit/test_advice.py b/backend/tests/unit/test_advice.py index 6b16f69933..657d909a90 100644 --- a/backend/tests/unit/test_advice.py +++ b/backend/tests/unit/test_advice.py @@ -1,24 +1,57 @@ +import sys from datetime import datetime, timezone from unittest.mock import patch, MagicMock import pytest + +for mod_name in [ + 'firebase_admin', + 'firebase_admin.auth', + 'firebase_admin.firestore', + 'firebase_admin.messaging', + 'google.cloud', + 'google.cloud.exceptions', + 'google.cloud.firestore', + 'google.cloud.firestore_v1', + 'google.cloud.firestore_v1.base_query', + 'google.cloud.firestore_v1.query', + 'google.cloud.storage', + 'google.cloud.storage.blob', + 'google.cloud.storage.bucket', + 'google.auth', + 'google.auth.transport', + 'google.auth.transport.requests', + 'google.oauth2', + 'google.oauth2.service_account', + 'pinecone', + 'typesense', +]: + sys.modules.setdefault(mod_name, MagicMock()) + +from routers.advice import router + +from fastapi import FastAPI, HTTPException from fastapi.testclient import TestClient 
@pytest.fixture def client(): - with patch('database.screen_activity.db'), \ - patch('database.focus_sessions.db'), \ - patch('database.advice.db'), \ - patch('database.vector_db.Pinecone'), \ - patch('database.vector_db.pc'), \ - patch('database.vector_db.index'), \ - patch('utils.llm.clients.embeddings'): - from main import app + with patch('routers.advice.auth.get_current_user_uid', return_value='uid-1'): + app = FastAPI() + app.include_router(router) with TestClient(app) as c: yield c +@pytest.fixture +def client_no_auth(): + """Client without auth mock — for testing 401 responses.""" + app = FastAPI() + app.include_router(router) + with TestClient(app) as c: + yield c + + AUTH = {"Authorization": "Bearer 123testuser"} @@ -79,8 +112,12 @@ def test_create_confidence_boundary_one(self, client): resp = client.post("/v1/advice", json=data, headers=AUTH) assert resp.status_code == 200 - def test_create_no_auth_returns_401(self, client): - resp = client.post("/v1/advice", json={"content": "Test"}) + def test_create_no_auth_returns_401(self, client_no_auth): + with patch( + 'routers.advice.auth.get_current_user_uid', + side_effect=HTTPException(status_code=401, detail='Not authenticated'), + ): + resp = client_no_auth.post("/v1/advice", json={"content": "Test"}) assert resp.status_code == 401 def test_create_firestore_error_returns_500(self, client): From 9414efbbe0c12a19e77e272344d5edf7b9e1605e Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 12:13:36 +0100 Subject: [PATCH 099/163] Add POST /v2/chat/generate-title endpoint for desktop session naming --- backend/routers/chat.py | 45 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/backend/routers/chat.py b/backend/routers/chat.py index 27d1ae2150..15d34e920c 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -33,6 +33,7 @@ resolve_voice_message_language, transcribe_voice_message_segment, ) +from utils.llm.clients import llm_mini from 
utils.llm.persona import initial_persona_chat_message from utils.llm.chat import initial_chat_message from utils.llm.goals import extract_and_update_goal_progress @@ -695,6 +696,50 @@ def rate_message( return StatusResponse(status='ok') +class TitleMessageInput(BaseModel): + text: str + sender: str + + +class GenerateTitleRequest(BaseModel): + session_id: str + messages: List[TitleMessageInput] + + +class GenerateTitleResponse(BaseModel): + title: str + + +@router.post('/v2/chat/generate-title', response_model=GenerateTitleResponse, tags=['chat']) +def generate_chat_title( + request: GenerateTitleRequest, + uid: str = Depends(auth.get_current_user_uid), +): + """Desktop: generate a short title for a chat session from its messages.""" + if not request.messages: + raise HTTPException(status_code=400, detail="messages list cannot be empty") + + transcript = '\n'.join(f'{m.sender}: {m.text[:500]}' for m in request.messages[:10]) + prompt = ( + 'Generate a short chat session title (max 6 words) summarising this conversation. 
' + 'Return ONLY the title text, no quotes or punctuation.\n\n' + transcript + ) + try: + result = llm_mini.invoke(prompt) + title = result.content.strip().strip('"\'')[:100] + except Exception as e: + logger.warning(f'generate_chat_title LLM failed: {e}') + title = request.messages[0].text[:50] + + # Update session title if session exists + try: + chat_db.update_chat_session(uid, request.session_id, {'title': title, 'updated_at': datetime.now(timezone.utc)}) + except Exception as e: + logger.warning(f'generate_chat_title update session failed: {e}') + + return GenerateTitleResponse(title=title) + + # CLEANUP: Remove after new app goes to prod ---------------------------------------------------------- From 21204a5f62d630a3c48636b5e2702a46255218b9 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 12:13:37 +0100 Subject: [PATCH 100/163] Add GET /v1/conversations/count endpoint with Firestore aggregation --- backend/database/conversations.py | 11 +++++++++++ backend/routers/conversations.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/backend/database/conversations.py b/backend/database/conversations.py index 1d9896ff44..7a408ccf41 100644 --- a/backend/database/conversations.py +++ b/backend/database/conversations.py @@ -220,6 +220,17 @@ def get_conversations( return conversations +def count_conversations(uid: str, statuses: List[str] = []) -> int: + """Count conversations matching status filters without fetching full documents.""" + conversations_ref = db.collection('users').document(uid).collection(conversations_collection) + conversations_ref = conversations_ref.where(filter=FieldFilter('discarded', '==', False)) + if statuses: + conversations_ref = conversations_ref.where(filter=FieldFilter('status', 'in', statuses)) + count_query = conversations_ref.count() + results = count_query.get() + return results[0][0].value + + @prepare_for_read(decrypt_func=_prepare_conversation_for_read) def get_conversations_without_photos( uid: str, 
diff --git a/backend/routers/conversations.py b/backend/routers/conversations.py index c2041779a2..db71b1e3c8 100644 --- a/backend/routers/conversations.py +++ b/backend/routers/conversations.py @@ -249,6 +249,22 @@ def get_conversations( return conversations +@router.get('/v1/conversations/count', tags=['conversations']) +def get_conversations_count( + statuses: Optional[str] = Query("processing,completed"), + uid: str = Depends(auth.get_current_user_uid), +): + """Count conversations matching optional status filters.""" + status_list = [s.strip() for s in statuses.split(',') if s.strip()] if statuses else [] + try: + count = conversations_db.count_conversations(uid, statuses=status_list) + except Exception as e: + logger.warning(f'count_conversations fallback: {e}') + conversations = conversations_db.get_conversations(uid, limit=10000, statuses=status_list) + count = len(conversations) + return {'count': count} + + @router.get("/v1/conversations/{conversation_id}", response_model=Conversation, tags=['conversations']) def get_conversation_by_id(conversation_id: str, uid: str = Depends(auth.get_current_user_uid)): logger.info(f'get_conversation_by_id {uid} {conversation_id}') From 967f47bc6408d1a92aafcd4a37e424966df8fa0b Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 12:13:41 +0100 Subject: [PATCH 101/163] Add unit tests for generate-title and conversations count endpoints --- .../tests/unit/test_chat_generate_title.py | 178 ++++++++++++++++++ .../tests/unit/test_conversations_count.py | 110 +++++++++++ 2 files changed, 288 insertions(+) create mode 100644 backend/tests/unit/test_chat_generate_title.py create mode 100644 backend/tests/unit/test_conversations_count.py diff --git a/backend/tests/unit/test_chat_generate_title.py b/backend/tests/unit/test_chat_generate_title.py new file mode 100644 index 0000000000..fbedfa18c1 --- /dev/null +++ b/backend/tests/unit/test_chat_generate_title.py @@ -0,0 +1,178 @@ +import sys +from datetime import datetime, 
timezone +from unittest.mock import patch, MagicMock + +import pytest + +for mod_name in [ + 'firebase_admin', + 'firebase_admin.auth', + 'firebase_admin.firestore', + 'firebase_admin.messaging', + 'google.cloud', + 'google.cloud.exceptions', + 'google.cloud.firestore', + 'google.cloud.firestore_v1', + 'google.cloud.firestore_v1.base_query', + 'google.cloud.firestore_v1.query', + 'google.cloud.storage', + 'google.cloud.storage.blob', + 'google.cloud.storage.bucket', + 'google.auth', + 'google.auth.transport', + 'google.auth.transport.requests', + 'google.oauth2', + 'google.oauth2.service_account', + 'pinecone', + 'typesense', + 'openai', + 'langchain_openai', +]: + sys.modules.setdefault(mod_name, MagicMock()) + +# Mock llm_mini before importing the router +mock_llm = MagicMock() +mock_llm.invoke.return_value = MagicMock(content='Project Discussion') +sys.modules.setdefault('utils.llm.clients', MagicMock(llm_mini=mock_llm)) + +from routers.chat import router + +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient + + +@pytest.fixture +def client(): + with patch('routers.chat.auth.get_current_user_uid', return_value='uid-1'): + app = FastAPI() + app.include_router(router) + with TestClient(app) as c: + yield c + + +@pytest.fixture +def client_no_auth(): + app = FastAPI() + app.include_router(router) + with TestClient(app) as c: + yield c + + +AUTH = {"Authorization": "Bearer 123testuser"} + + +class TestGenerateChatTitle: + def test_generate_title_success(self, client): + data = { + "session_id": "sess-1", + "messages": [ + {"text": "How do I deploy to production?", "sender": "human"}, + {"text": "You can use the CI/CD pipeline.", "sender": "ai"}, + ], + } + with patch('routers.chat.llm_mini') as mock_llm: + mock_llm.invoke.return_value = MagicMock(content='Production Deployment') + with patch('routers.chat.chat_db.update_chat_session'): + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert 
resp.status_code == 200 + assert resp.json()["title"] == "Production Deployment" + + def test_generate_title_strips_quotes(self, client): + data = { + "session_id": "sess-1", + "messages": [{"text": "Hello", "sender": "human"}], + } + with patch('routers.chat.llm_mini') as mock_llm: + mock_llm.invoke.return_value = MagicMock(content='"Greeting Chat"') + with patch('routers.chat.chat_db.update_chat_session'): + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["title"] == "Greeting Chat" + + def test_generate_title_empty_messages_returns_400(self, client): + data = {"session_id": "sess-1", "messages": []} + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 400 + + def test_generate_title_no_messages_field_returns_422(self, client): + data = {"session_id": "sess-1"} + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 422 + + def test_generate_title_llm_fallback(self, client): + data = { + "session_id": "sess-1", + "messages": [{"text": "What about the budget proposal?", "sender": "human"}], + } + with patch('routers.chat.llm_mini') as mock_llm: + mock_llm.invoke.side_effect = Exception("LLM down") + with patch('routers.chat.chat_db.update_chat_session'): + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["title"] == "What about the budget proposal?" 
+ + def test_generate_title_updates_session(self, client): + data = { + "session_id": "sess-1", + "messages": [{"text": "Hello", "sender": "human"}], + } + with patch('routers.chat.llm_mini') as mock_llm: + mock_llm.invoke.return_value = MagicMock(content='Greeting') + with patch('routers.chat.chat_db.update_chat_session') as mock_update: + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 200 + mock_update.assert_called_once() + call_args = mock_update.call_args[0] + assert call_args[1] == 'sess-1' + assert call_args[2]['title'] == 'Greeting' + + def test_generate_title_session_update_failure_still_returns(self, client): + data = { + "session_id": "sess-1", + "messages": [{"text": "Hello", "sender": "human"}], + } + with patch('routers.chat.llm_mini') as mock_llm: + mock_llm.invoke.return_value = MagicMock(content='Greeting') + with patch('routers.chat.chat_db.update_chat_session', side_effect=Exception("DB err")): + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["title"] == "Greeting" + + def test_generate_title_truncates_long_title(self, client): + data = { + "session_id": "sess-1", + "messages": [{"text": "Hello", "sender": "human"}], + } + with patch('routers.chat.llm_mini') as mock_llm: + mock_llm.invoke.return_value = MagicMock(content='A' * 200) + with patch('routers.chat.chat_db.update_chat_session'): + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 200 + assert len(resp.json()["title"]) <= 100 + + def test_generate_title_no_auth_returns_401(self, client_no_auth): + data = { + "session_id": "sess-1", + "messages": [{"text": "Hello", "sender": "human"}], + } + with patch( + 'routers.chat.auth.get_current_user_uid', + side_effect=HTTPException(status_code=401, detail='Not authenticated'), + ): + resp = client_no_auth.post("/v2/chat/generate-title", json=data) + assert 
resp.status_code == 401 + + def test_generate_title_limits_messages(self, client): + """Only first 10 messages should be sent to LLM.""" + data = { + "session_id": "sess-1", + "messages": [{"text": f"Message {i}", "sender": "human"} for i in range(20)], + } + with patch('routers.chat.llm_mini') as mock_llm: + mock_llm.invoke.return_value = MagicMock(content='Long Chat') + with patch('routers.chat.chat_db.update_chat_session'): + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 200 + prompt = mock_llm.invoke.call_args[0][0] + assert 'Message 9' in prompt + assert 'Message 10' not in prompt diff --git a/backend/tests/unit/test_conversations_count.py b/backend/tests/unit/test_conversations_count.py new file mode 100644 index 0000000000..d1104cfae2 --- /dev/null +++ b/backend/tests/unit/test_conversations_count.py @@ -0,0 +1,110 @@ +import sys +from unittest.mock import patch, MagicMock + +import pytest + +for mod_name in [ + 'firebase_admin', + 'firebase_admin.auth', + 'firebase_admin.firestore', + 'firebase_admin.messaging', + 'google.cloud', + 'google.cloud.exceptions', + 'google.cloud.firestore', + 'google.cloud.firestore_v1', + 'google.cloud.firestore_v1.base_query', + 'google.cloud.firestore_v1.query', + 'google.cloud.storage', + 'google.cloud.storage.blob', + 'google.cloud.storage.bucket', + 'google.auth', + 'google.auth.transport', + 'google.auth.transport.requests', + 'google.oauth2', + 'google.oauth2.service_account', + 'pinecone', + 'typesense', + 'openai', + 'langchain_openai', +]: + sys.modules.setdefault(mod_name, MagicMock()) + +from routers.conversations import router + +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient + + +@pytest.fixture +def client(): + with patch('routers.conversations.auth.get_current_user_uid', return_value='uid-1'): + app = FastAPI() + app.include_router(router) + with TestClient(app) as c: + yield c + + +@pytest.fixture +def client_no_auth(): 
+ app = FastAPI() + app.include_router(router) + with TestClient(app) as c: + yield c + + +AUTH = {"Authorization": "Bearer 123testuser"} + + +class TestConversationsCount: + def test_count_default_statuses(self, client): + with patch('routers.conversations.conversations_db.count_conversations', return_value=42) as mock_count: + resp = client.get("/v1/conversations/count", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["count"] == 42 + args = mock_count.call_args + assert args[1]['statuses'] == ['processing', 'completed'] + + def test_count_custom_statuses(self, client): + with patch('routers.conversations.conversations_db.count_conversations', return_value=10) as mock_count: + resp = client.get("/v1/conversations/count?statuses=completed", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["count"] == 10 + assert mock_count.call_args[1]['statuses'] == ['completed'] + + def test_count_empty_statuses(self, client): + with patch('routers.conversations.conversations_db.count_conversations', return_value=0) as mock_count: + resp = client.get("/v1/conversations/count?statuses=", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["count"] == 0 + assert mock_count.call_args[1]['statuses'] == [] + + def test_count_zero(self, client): + with patch('routers.conversations.conversations_db.count_conversations', return_value=0): + resp = client.get("/v1/conversations/count", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["count"] == 0 + + def test_count_fallback_on_aggregation_error(self, client): + """If Firestore count() aggregation fails, falls back to len(get_conversations).""" + with patch( + 'routers.conversations.conversations_db.count_conversations', side_effect=Exception("aggregation err") + ): + with patch('routers.conversations.conversations_db.get_conversations', return_value=[{}, {}, {}]): + resp = client.get("/v1/conversations/count", headers=AUTH) + assert resp.status_code == 200 + assert 
resp.json()["count"] == 3 + + def test_count_no_auth_returns_401(self, client_no_auth): + with patch( + 'routers.conversations.auth.get_current_user_uid', + side_effect=HTTPException(status_code=401, detail='Not authenticated'), + ): + resp = client_no_auth.get("/v1/conversations/count") + assert resp.status_code == 401 + + def test_count_multiple_statuses(self, client): + with patch('routers.conversations.conversations_db.count_conversations', return_value=25) as mock_count: + resp = client.get("/v1/conversations/count?statuses=processing,completed,in_progress", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["count"] == 25 + assert mock_count.call_args[1]['statuses'] == ['processing', 'completed', 'in_progress'] From e1e2c1050b62b5ab2b6496ed50ca83a8d7960f54 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 12:17:07 +0100 Subject: [PATCH 102/163] Fix count endpoint: validate statuses limit, use stream fallback --- backend/database/conversations.py | 9 +++++++++ backend/routers/conversations.py | 7 ++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/backend/database/conversations.py b/backend/database/conversations.py index 7a408ccf41..3b6a45f5d9 100644 --- a/backend/database/conversations.py +++ b/backend/database/conversations.py @@ -231,6 +231,15 @@ def count_conversations(uid: str, statuses: List[str] = []) -> int: return results[0][0].value +def stream_conversations(uid: str, statuses: List[str] = []): + """Yield conversation docs as a stream for counting without loading all into memory.""" + conversations_ref = db.collection('users').document(uid).collection(conversations_collection) + conversations_ref = conversations_ref.where(filter=FieldFilter('discarded', '==', False)) + if statuses: + conversations_ref = conversations_ref.where(filter=FieldFilter('status', 'in', statuses)) + yield from conversations_ref.stream() + + @prepare_for_read(decrypt_func=_prepare_conversation_for_read) def 
get_conversations_without_photos( uid: str, diff --git a/backend/routers/conversations.py b/backend/routers/conversations.py index db71b1e3c8..f58fe6e386 100644 --- a/backend/routers/conversations.py +++ b/backend/routers/conversations.py @@ -256,12 +256,13 @@ def get_conversations_count( ): """Count conversations matching optional status filters.""" status_list = [s.strip() for s in statuses.split(',') if s.strip()] if statuses else [] + if len(status_list) > 10: + raise HTTPException(status_code=400, detail="Too many status values (max 10)") try: count = conversations_db.count_conversations(uid, statuses=status_list) except Exception as e: - logger.warning(f'count_conversations fallback: {e}') - conversations = conversations_db.get_conversations(uid, limit=10000, statuses=status_list) - count = len(conversations) + logger.warning(f'count_conversations aggregation fallback: {e}') + count = sum(1 for _ in conversations_db.stream_conversations(uid, statuses=status_list)) return {'count': count} From a4f9a7274eb1c1a3030f24be8e41024ac324ca3f Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 12:17:08 +0100 Subject: [PATCH 103/163] Add tests for statuses validation and stream fallback --- .../tests/unit/test_conversations_count.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/backend/tests/unit/test_conversations_count.py b/backend/tests/unit/test_conversations_count.py index d1104cfae2..52a02f4e9f 100644 --- a/backend/tests/unit/test_conversations_count.py +++ b/backend/tests/unit/test_conversations_count.py @@ -85,11 +85,11 @@ def test_count_zero(self, client): assert resp.json()["count"] == 0 def test_count_fallback_on_aggregation_error(self, client): - """If Firestore count() aggregation fails, falls back to len(get_conversations).""" + """If Firestore count() aggregation fails, falls back to stream_conversations.""" with patch( 'routers.conversations.conversations_db.count_conversations', 
side_effect=Exception("aggregation err") ): - with patch('routers.conversations.conversations_db.get_conversations', return_value=[{}, {}, {}]): + with patch('routers.conversations.conversations_db.stream_conversations', return_value=iter([1, 2, 3])): resp = client.get("/v1/conversations/count", headers=AUTH) assert resp.status_code == 200 assert resp.json()["count"] == 3 @@ -108,3 +108,21 @@ def test_count_multiple_statuses(self, client): assert resp.status_code == 200 assert resp.json()["count"] == 25 assert mock_count.call_args[1]['statuses'] == ['processing', 'completed', 'in_progress'] + + def test_count_too_many_statuses_returns_400(self, client): + statuses = ','.join(f'status{i}' for i in range(11)) + resp = client.get(f"/v1/conversations/count?statuses={statuses}", headers=AUTH) + assert resp.status_code == 400 + assert 'max 10' in resp.json()['detail'] + + def test_count_stream_fallback(self, client): + """Fallback uses stream_conversations for unbounded counting.""" + with patch( + 'routers.conversations.conversations_db.count_conversations', side_effect=Exception("no aggregation") + ): + with patch( + 'routers.conversations.conversations_db.stream_conversations', return_value=iter([1, 2, 3, 4, 5]) + ): + resp = client.get("/v1/conversations/count", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["count"] == 5 From dbb1dbb08af39cae773ffd22dbace6ed4290f358 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 12:21:34 +0100 Subject: [PATCH 104/163] Fix mutable default argument in count/stream_conversations --- backend/database/conversations.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/backend/database/conversations.py b/backend/database/conversations.py index 3b6a45f5d9..8d4cce69b9 100644 --- a/backend/database/conversations.py +++ b/backend/database/conversations.py @@ -220,8 +220,10 @@ def get_conversations( return conversations -def count_conversations(uid: str, statuses: List[str] = []) -> 
int: +def count_conversations(uid: str, statuses: Optional[List[str]] = None) -> int: """Count conversations matching status filters without fetching full documents.""" + if statuses is None: + statuses = [] conversations_ref = db.collection('users').document(uid).collection(conversations_collection) conversations_ref = conversations_ref.where(filter=FieldFilter('discarded', '==', False)) if statuses: @@ -231,8 +233,10 @@ def count_conversations(uid: str, statuses: List[str] = []) -> int: return results[0][0].value -def stream_conversations(uid: str, statuses: List[str] = []): +def stream_conversations(uid: str, statuses: Optional[List[str]] = None): """Yield conversation docs as a stream for counting without loading all into memory.""" + if statuses is None: + statuses = [] conversations_ref = db.collection('users').document(uid).collection(conversations_collection) conversations_ref = conversations_ref.where(filter=FieldFilter('discarded', '==', False)) if statuses: From 2922c4d0ce483a3eaa85e8662a3e0604482f10d4 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 12:26:48 +0100 Subject: [PATCH 105/163] Add boundary tests for fallback truncation and message text limit --- .../tests/unit/test_chat_generate_title.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/backend/tests/unit/test_chat_generate_title.py b/backend/tests/unit/test_chat_generate_title.py index fbedfa18c1..77232e83a1 100644 --- a/backend/tests/unit/test_chat_generate_title.py +++ b/backend/tests/unit/test_chat_generate_title.py @@ -176,3 +176,34 @@ def test_generate_title_limits_messages(self, client): prompt = mock_llm.invoke.call_args[0][0] assert 'Message 9' in prompt assert 'Message 10' not in prompt + + def test_generate_title_fallback_truncates_to_50_chars(self, client): + """When LLM fails, fallback title is truncated to 50 chars.""" + long_text = 'A' * 100 + data = { + "session_id": "sess-1", + "messages": [{"text": long_text, "sender": "human"}], + } + with 
patch('routers.chat.llm_mini') as mock_llm: + mock_llm.invoke.side_effect = Exception("LLM down") + with patch('routers.chat.chat_db.update_chat_session'): + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 200 + assert len(resp.json()["title"]) == 50 + + def test_generate_title_truncates_message_text_to_500_chars(self, client): + """Each message text is truncated to 500 chars in the transcript sent to LLM.""" + long_text = 'B' * 1000 + data = { + "session_id": "sess-1", + "messages": [{"text": long_text, "sender": "human"}], + } + with patch('routers.chat.llm_mini') as mock_llm: + mock_llm.invoke.return_value = MagicMock(content='Title') + with patch('routers.chat.chat_db.update_chat_session'): + resp = client.post("/v2/chat/generate-title", json=data, headers=AUTH) + assert resp.status_code == 200 + prompt = mock_llm.invoke.call_args[0][0] + # The transcript line should contain exactly 500 B's, not 1000 + assert 'B' * 500 in prompt + assert 'B' * 501 not in prompt From 1e1a42c5129e8ed54599a7bda3c4eb1ad9e3d073 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 12:26:49 +0100 Subject: [PATCH 106/163] Add tests for status whitespace normalization and fallback parity --- .../tests/unit/test_conversations_count.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/backend/tests/unit/test_conversations_count.py b/backend/tests/unit/test_conversations_count.py index 52a02f4e9f..5ac7767ace 100644 --- a/backend/tests/unit/test_conversations_count.py +++ b/backend/tests/unit/test_conversations_count.py @@ -126,3 +126,24 @@ def test_count_stream_fallback(self, client): resp = client.get("/v1/conversations/count", headers=AUTH) assert resp.status_code == 200 assert resp.json()["count"] == 5 + + def test_count_statuses_whitespace_normalization(self, client): + """Parser strips whitespace and drops empty segments from statuses.""" + with 
patch('routers.conversations.conversations_db.count_conversations', return_value=7) as mock_count: + resp = client.get("/v1/conversations/count?statuses= processing , , completed ", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["count"] == 7 + assert mock_count.call_args[1]['statuses'] == ['processing', 'completed'] + + def test_count_fallback_receives_parsed_statuses(self, client): + """Fallback stream_conversations receives the same parsed status list.""" + with patch( + 'routers.conversations.conversations_db.count_conversations', side_effect=Exception("err") + ): + with patch( + 'routers.conversations.conversations_db.stream_conversations', return_value=iter([1, 2]) + ) as mock_stream: + resp = client.get("/v1/conversations/count?statuses=completed,processing", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["count"] == 2 + assert mock_stream.call_args[1]['statuses'] == ['completed', 'processing'] From c74570a3c08cb324071f51f9552c31a21c567b77 Mon Sep 17 00:00:00 2001 From: beastoin Date: Thu, 5 Mar 2026 12:06:31 +0100 Subject: [PATCH 107/163] Swift desktop: path updates, decoder hardening, no-ops, remove migrations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Path updates (5 endpoints): - v2/chat/initial-message → v2/initial-message - v2/agent/provision → v1/agent/vm-ensure - v2/agent/status → v1/agent/vm-status - v1/personas/check-username → v1/apps/check-username - v1/personas/generate-prompt → v1/app/generate-prompts (POST→GET) Decoder hardening: - ServerConversation.createdAt: use decodeIfPresent with Date() fallback - ActionItemsListResponse: try "action_items" then "items" key (Python vs staged-tasks) - AgentProvisionResponse/AgentStatusResponse: make fields optional, add hasVm - UsernameAvailableResponse: support both is_taken (Python) and available (Rust) Graceful no-ops: - recordLlmUsage(): no-op with log (endpoint removed) - fetchTotalOmiAICost(): return nil 
immediately (endpoint removed) - getChatMessageCount(): return 0 immediately (endpoint removed) Remove staged-tasks migration: - Remove migrateStagedTasks() and migrateConversationItemsToStaged() from APIClient - Remove migration callers and functions from TasksStore Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/APIClient.swift | 146 +++++++++--------- desktop/Desktop/Sources/AgentVMService.swift | 28 ++-- .../Desktop/Sources/Stores/TasksStore.swift | 65 -------- 3 files changed, 91 insertions(+), 148 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index a0bfcf5f0e..741996ffcc 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -405,14 +405,10 @@ extension APIClient { return response.count } - /// Gets the count of AI chat messages from PostHog + /// Gets the count of AI chat messages func getChatMessageCount() async throws -> Int { - struct CountResponse: Decodable { - let count: Int - } - - let response: CountResponse = try await get("v1/users/stats/chat-messages") - return response.count + // No-op: chat-messages stats endpoint not available in Python backend + return 0 } /// Merges multiple conversations into a new conversation @@ -581,7 +577,7 @@ struct ServerConversation: Codable, Identifiable, Equatable { let container = try decoder.container(keyedBy: CodingKeys.self) id = try container.decode(String.self, forKey: .id) - createdAt = try container.decode(Date.self, forKey: .createdAt) + createdAt = try container.decodeIfPresent(Date.self, forKey: .createdAt) ?? 
Date() startedAt = try container.decodeIfPresent(Date.self, forKey: .startedAt) finishedAt = try container.decodeIfPresent(Date.self, forKey: .finishedAt) structured = try container.decode(Structured.self, forKey: .structured) @@ -1454,14 +1450,26 @@ struct UserProfile: Codable { // MARK: - Action Items API /// Response wrapper for paginated action items list -struct ActionItemsListResponse: Codable { +struct ActionItemsListResponse: Decodable { let items: [TaskActionItem] let hasMore: Bool enum CodingKeys: String, CodingKey { + case actionItems = "action_items" case items case hasMore = "has_more" } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + hasMore = try container.decodeIfPresent(Bool.self, forKey: .hasMore) ?? false + // Python action_items endpoint returns "action_items"; staged-tasks returns "items" + if let actionItems = try container.decodeIfPresent([TaskActionItem].self, forKey: .actionItems) { + items = actionItems + } else { + items = try container.decodeIfPresent([TaskActionItem].self, forKey: .items) ?? 
[] + } + } } extension APIClient { @@ -1917,17 +1925,6 @@ extension APIClient { return try await post("v1/staged-tasks/promote") } - /// One-time migration of existing AI tasks to staged_tasks - func migrateStagedTasks() async throws { - struct StatusResponse: Decodable { let status: String } - let _: StatusResponse = try await post("v1/staged-tasks/migrate") - } - - /// Migrate conversation-extracted action items (no source field) to staged_tasks - func migrateConversationItemsToStaged() async throws { - struct MigrateResponse: Decodable { let status: String; let migrated: Int; let deleted: Int } - let _: MigrateResponse = try await post("v1/staged-tasks/migrate-conversation-items") - } } /// Response for staged task promotion @@ -3175,13 +3172,12 @@ extension APIClient { /// Regenerates persona prompt from current public memories func regeneratePersonaPrompt() async throws -> GeneratePromptResponse { - struct EmptyRequest: Encodable {} - return try await post("v1/personas/generate-prompt", body: EmptyRequest()) + return try await get("v1/app/generate-prompts") } /// Checks if a username is available func checkPersonaUsername(_ username: String) async throws -> UsernameAvailableResponse { - return try await get("v1/personas/check-username?username=\(username)") + return try await get("v1/apps/check-username?username=\(username)") } } @@ -3279,9 +3275,27 @@ struct GeneratePromptResponse: Codable { } /// Response for username availability check -struct UsernameAvailableResponse: Codable { +struct UsernameAvailableResponse: Decodable { let available: Bool - let username: String + let username: String? + let isTaken: Bool? 
+ + enum CodingKeys: String, CodingKey { + case available, username + case isTaken = "is_taken" + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + username = try container.decodeIfPresent(String.self, forKey: .username) + isTaken = try container.decodeIfPresent(Bool.self, forKey: .isTaken) + // Python returns is_taken; Rust returned available. Support both. + if let isTaken = isTaken { + available = !isTaken + } else { + available = try container.decodeIfPresent(Bool.self, forKey: .available) ?? true + } + } } // MARK: - User Settings API @@ -4119,7 +4133,7 @@ extension APIClient { } let body = InitialMessageRequest(sessionId: sessionId, appId: appId) - return try await post("v2/chat/initial-message", body: body) + return try await post("v2/initial-message", body: body) } /// Generate a title for a chat session based on its messages @@ -4268,31 +4282,51 @@ extension APIClient { // MARK: - Agent VM struct AgentProvisionResponse: Decodable { - let status: String - let vmName: String + let hasVm: Bool + let status: String? + let vmName: String? let ip: String? - let authToken: String - let agentStatus: String + let authToken: String? + let agentStatus: String? + + enum CodingKeys: String, CodingKey { + case hasVm = "has_vm" + case status + case vmName = "vm_name" + case ip + case authToken = "auth_token" + case agentStatus = "agent_status" + } } /// Provision a cloud agent VM for the current user (fire-and-forget) func provisionAgentVM() async throws -> AgentProvisionResponse { - return try await post("v2/agent/provision") + return try await post("v1/agent/vm-ensure") } struct AgentStatusResponse: Decodable { - let vmName: String - let zone: String + let hasVm: Bool + let vmName: String? + let zone: String? let ip: String? - let status: String - let authToken: String - let createdAt: String + let status: String? + let authToken: String? + let createdAt: String? let lastQueryAt: String? 
+ + enum CodingKeys: String, CodingKey { + case hasVm = "has_vm" + case vmName = "vm_name" + case zone, ip, status + case authToken = "auth_token" + case createdAt = "created_at" + case lastQueryAt = "last_query_at" + } } /// Get current agent VM status func getAgentStatus() async throws -> AgentStatusResponse? { - return try await get("v2/agent/status") + return try await get("v1/agent/vm-status") } } @@ -4397,41 +4431,13 @@ extension APIClient { costUsd: Double, account: String = "omi" ) async { - struct Req: Encodable { - let input_tokens: Int - let output_tokens: Int - let cache_read_tokens: Int - let cache_write_tokens: Int - let total_tokens: Int - let cost_usd: Double - let account: String - } - struct Res: Decodable { let status: String } - do { - let _: Res = try await post("v1/users/me/llm-usage", body: Req( - input_tokens: inputTokens, - output_tokens: outputTokens, - cache_read_tokens: cacheReadTokens, - cache_write_tokens: cacheWriteTokens, - total_tokens: totalTokens, - cost_usd: costUsd, - account: account - )) - } catch { - log("APIClient: LLM usage record failed: \(error.localizedDescription)") - } + // No-op: LLM usage tracking endpoint not available in Python backend + log("APIClient: recordLlmUsage no-op (endpoint removed)") } func fetchTotalOmiAICost() async -> Double? 
{ - struct Res: Decodable { let total_cost_usd: Double } - do { - log("APIClient: Fetching total Omi AI cost from backend") - let res: Res = try await get("v1/users/me/llm-usage/total") - log("APIClient: Total Omi AI cost from backend: $\(String(format: "%.4f", res.total_cost_usd))") - return res.total_cost_usd - } catch { - log("APIClient: LLM total cost fetch failed: \(error.localizedDescription)") - return nil - } + // No-op: LLM usage total endpoint not available in Python backend + log("APIClient: fetchTotalOmiAICost no-op (endpoint removed)") + return nil } } diff --git a/desktop/Desktop/Sources/AgentVMService.swift b/desktop/Desktop/Sources/AgentVMService.swift index ceec6ca19f..985be922cb 100644 --- a/desktop/Desktop/Sources/AgentVMService.swift +++ b/desktop/Desktop/Sources/AgentVMService.swift @@ -23,26 +23,28 @@ actor AgentVMService { do { let status = try await APIClient.shared.getAgentStatus() if let status = status, status.status == "ready", let ip = status.ip { - log("AgentVMService: VM already ready — vmName=\(status.vmName) ip=\(ip)") + let token = status.authToken ?? "" + log("AgentVMService: VM already ready — vmName=\(status.vmName ?? "unknown") ip=\(ip)") // Only upload if the VM doesn't have a database yet - if await checkVMNeedsDatabase(vmIP: ip, authToken: status.authToken) { - await uploadDatabase(vmIP: ip, authToken: status.authToken) + if await checkVMNeedsDatabase(vmIP: ip, authToken: token) { + await uploadDatabase(vmIP: ip, authToken: token) } else { log("AgentVMService: VM already has database, skipping upload") } - await startIncrementalSync(vmIP: ip, authToken: status.authToken) + await startIncrementalSync(vmIP: ip, authToken: token) return } if let status = status, status.status == "provisioning" || status.status == "stopped" { - log("AgentVMService: VM is \(status.status), polling until ready...") + log("AgentVMService: VM is \(status.status ?? 
"unknown"), polling until ready...") if let result = await pollUntilReady(maxAttempts: 30, intervalSeconds: 5), let ip = result.ip { + let token = result.authToken ?? "" log("AgentVMService: VM became ready — ip=\(ip)") - if await checkVMNeedsDatabase(vmIP: ip, authToken: result.authToken) { - await uploadDatabase(vmIP: ip, authToken: result.authToken) + if await checkVMNeedsDatabase(vmIP: ip, authToken: token) { + await uploadDatabase(vmIP: ip, authToken: token) } - await startIncrementalSync(vmIP: ip, authToken: result.authToken) + await startIncrementalSync(vmIP: ip, authToken: token) } return } @@ -76,7 +78,7 @@ actor AgentVMService { let provisionResult: APIClient.AgentProvisionResponse do { provisionResult = try await APIClient.shared.provisionAgentVM() - log("AgentVMService: Provision response — vmName=\(provisionResult.vmName) status=\(provisionResult.status) ip=\(provisionResult.ip ?? "none")") + log("AgentVMService: Provision response — vmName=\(provisionResult.vmName ?? "unknown") status=\(provisionResult.status ?? "unknown") ip=\(provisionResult.ip ?? "none")") } catch { log("AgentVMService: Provision failed — \(error.localizedDescription)") return @@ -84,14 +86,14 @@ actor AgentVMService { // Step 2: Poll until VM is ready with an IP var vmIP = provisionResult.ip - var authToken = provisionResult.authToken + var authToken = provisionResult.authToken ?? "" - if vmIP == nil || provisionResult.agentStatus == "provisioning" { + if vmIP == nil || provisionResult.status == "provisioning" { log("AgentVMService: Waiting for VM to be ready...") let pollResult = await pollUntilReady(maxAttempts: 30, intervalSeconds: 5) if let result = pollResult { vmIP = result.ip - authToken = result.authToken + authToken = result.authToken ?? "" log("AgentVMService: VM ready — ip=\(vmIP ?? 
"none")") } else { log("AgentVMService: VM did not become ready in time") @@ -111,7 +113,7 @@ actor AgentVMService { await startIncrementalSync(vmIP: ip, authToken: authToken) } - /// Poll GET /v2/agent/status until status is "ready" and IP is available. + /// Poll GET /v1/agent/vm-status until status is "ready" and IP is available. private func pollUntilReady(maxAttempts: Int, intervalSeconds: UInt64) async -> APIClient.AgentStatusResponse? { for attempt in 1...maxAttempts { do { diff --git a/desktop/Desktop/Sources/Stores/TasksStore.swift b/desktop/Desktop/Sources/Stores/TasksStore.swift index d9fba16493..b5252c600f 100644 --- a/desktop/Desktop/Sources/Stores/TasksStore.swift +++ b/desktop/Desktop/Sources/Stores/TasksStore.swift @@ -448,8 +448,6 @@ class TasksStore: ObservableObject { // Then retry pushing any locally-created tasks that failed to sync Task { await performFullSyncIfNeeded() - await migrateAITasksToStagedIfNeeded() - await migrateConversationItemsToStagedIfNeeded() await retryUnsyncedItems() } // Backfill relevance scores for unscored tasks (independent of full sync) @@ -808,69 +806,6 @@ class TasksStore: ObservableObject { } } - /// In-memory guard to prevent duplicate migration calls within the same app session - private static var isMigrating = false - - /// One-time migration: tell backend to move excess AI tasks to staged_tasks subcollection. - /// The SQLite migration handles local data; this handles Firestore. - /// Sets the flag optimistically before the request to avoid retry loops on timeout. - private func migrateAITasksToStagedIfNeeded() async { - let userId = UserDefaults.standard.string(forKey: "auth_userId") ?? 
"unknown" - let migrationKey = "stagedTasksMigrationCompleted_v4_\(userId)" - - guard !UserDefaults.standard.bool(forKey: migrationKey) else { - log("TasksStore: Staged tasks migration already completed for user \(userId)") - return - } - - // In-memory guard: loadTasks() can be called from multiple pages - guard !Self.isMigrating else { - log("TasksStore: Staged tasks migration already in progress, skipping") - return - } - Self.isMigrating = true - - // Set flag optimistically — the migration is idempotent and safe to skip on re-run. - // This prevents infinite retry loops when the backend succeeds but the client times out. - UserDefaults.standard.set(true, forKey: migrationKey) - - log("TasksStore: Starting staged tasks backend migration for user \(userId)") - - do { - try await APIClient.shared.migrateStagedTasks() - log("TasksStore: Staged tasks backend migration completed") - } catch { - log("TasksStore: Staged tasks backend migration fired (may complete in background): \(error.localizedDescription)") - } - Self.isMigrating = false - } - - /// One-time migration of conversation-extracted action items (no source field) to staged_tasks. - /// These were created by the old save_action_items path that bypassed the staging pipeline. - private func migrateConversationItemsToStagedIfNeeded() async { - let userId = UserDefaults.standard.string(forKey: "auth_userId") ?? 
"unknown" - let migrationKey = "conversationItemsMigrationCompleted_v4_\(userId)" - - guard !UserDefaults.standard.bool(forKey: migrationKey) else { return } - - UserDefaults.standard.set(true, forKey: migrationKey) - log("TasksStore: Starting conversation items migration for user \(userId)") - - do { - try await APIClient.shared.migrateConversationItemsToStaged() - log("TasksStore: Conversation items migration completed, resetting full sync to clean up local SQLite") - - // Reset full sync flag so it re-runs and marks migrated items as staged locally - let syncKey = "tasksFullSyncCompleted_v9_\(userId)" - UserDefaults.standard.set(false, forKey: syncKey) - - // Run full sync now to clean up local SQLite - await performFullSyncIfNeeded() - } catch { - log("TasksStore: Conversation items migration fired (may complete in background): \(error.localizedDescription)") - } - } - /// Retry syncing locally-created tasks that failed to push to the backend. /// These are records with backendSynced=false and no backendId — the API call /// failed during extraction and there was no retry mechanism. From 78d15d278543576af6549428008bf491f7a6a7ff Mon Sep 17 00:00:00 2001 From: beastoin Date: Mon, 9 Mar 2026 05:57:25 +0100 Subject: [PATCH 108/163] Fix auth fallback: verify JWT signature, sanitize email PII C1: Replace unsafe base64 JWT decode with firebase_admin.auth.verify_id_token() which verifies signature against Google public keys before trusting claims. C2: Wrap email in sanitize_pii() per CLAUDE.md logging rules. 
Co-Authored-By: Claude Opus 4.6 --- backend/routers/auth.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/backend/routers/auth.py b/backend/routers/auth.py index ea54dc6658..aa626aa145 100644 --- a/backend/routers/auth.py +++ b/backend/routers/auth.py @@ -16,7 +16,7 @@ import pathlib import firebase_admin.auth from database.redis_db import set_auth_session, get_auth_session, set_auth_code, get_auth_code, delete_auth_code -from utils.log_sanitizer import sanitize +from utils.log_sanitizer import sanitize, sanitize_pii import logging logger = logging.getLogger(__name__) @@ -424,27 +424,23 @@ async def _generate_custom_token(provider: str, id_token: str, access_token: str f"Firebase REST API sign-in failed (status={response.status_code}), falling back to Admin SDK" ) - # Fallback: decode id_token JWT and look up/create user via Admin SDK + # Fallback: verify id_token via Admin SDK and look up/create user if not firebase_uid: - parts = id_token.split('.') - if len(parts) < 2: - raise Exception("Invalid id_token format") - payload_b64 = parts[1] + '=' * (4 - len(parts[1]) % 4) - token_payload = json.loads(base64.urlsafe_b64decode(payload_b64)) - email = token_payload.get('email') + verified_token = firebase_admin.auth.verify_id_token(id_token) + email = verified_token.get('email') if not email: - raise Exception("No email in id_token") + raise Exception("No email in verified id_token") # Look up existing Firebase user by email try: user = firebase_admin.auth.get_user_by_email(email) firebase_uid = user.uid - logger.info(f"Found existing Firebase user for {email}, UID: {firebase_uid}") + logger.info(f"Found existing Firebase user for {sanitize_pii(email)}, UID: {firebase_uid}") except firebase_admin.auth.UserNotFoundError: # Create new Firebase user user = firebase_admin.auth.create_user(email=email, email_verified=True) firebase_uid = user.uid - logger.info(f"Created new Firebase user for {email}, UID: {firebase_uid}") + 
logger.info(f"Created new Firebase user for {sanitize_pii(email)}, UID: {firebase_uid}") if not firebase_uid: raise Exception("No Firebase UID obtained") From ea7021c80e1f24e3b46f72f52dc93bf2a48079fb Mon Sep 17 00:00:00 2001 From: beastoin Date: Fri, 6 Mar 2026 08:35:28 +0100 Subject: [PATCH 109/163] Add BackendTranscriptionService for /v4/listen WebSocket New service replacing direct Deepgram connection. Connects to backend /v4/listen with Bearer auth header, streams mono PCM16 audio at 16kHz, parses backend response format (segment arrays, ping heartbeats, events). Configurable source parameter for BLE device type propagation. Co-Authored-By: Claude Opus 4.6 --- .../Sources/BackendTranscriptionService.swift | 492 ++++++++++++++++++ 1 file changed, 492 insertions(+) create mode 100644 desktop/Desktop/Sources/BackendTranscriptionService.swift diff --git a/desktop/Desktop/Sources/BackendTranscriptionService.swift b/desktop/Desktop/Sources/BackendTranscriptionService.swift new file mode 100644 index 0000000000..12768fd583 --- /dev/null +++ b/desktop/Desktop/Sources/BackendTranscriptionService.swift @@ -0,0 +1,492 @@ +import Foundation + +/// Service for real-time speech-to-text transcription via the OMI backend. +/// Streams mono audio over WebSocket to /v4/listen and receives transcript segments. +/// This replaces direct Deepgram connections — the backend handles STT server-side. +class BackendTranscriptionService { + + // MARK: - Types + + /// Reuse the same TranscriptSegment type for compatibility with existing handlers + typealias TranscriptSegment = TranscriptionService.TranscriptSegment + typealias TranscriptHandler = (TranscriptSegment) -> Void + typealias ErrorHandler = (Error) -> Void + typealias ConnectionHandler = () -> Void + + enum BackendTranscriptionError: LocalizedError { + case notSignedIn + case connectionFailed(Error) + case invalidResponse + case webSocketError(String) + + var errorDescription: String? 
{ + switch self { + case .notSignedIn: + return "Not signed in — cannot connect to backend" + case .connectionFailed(let error): + return "Connection failed: \(error.localizedDescription)" + case .invalidResponse: + return "Invalid response from backend" + case .webSocketError(let message): + return "WebSocket error: \(message)" + } + } + } + + // MARK: - Properties + + private var webSocketTask: URLSessionWebSocketTask? + private var urlSession: URLSession? + private var isConnected = false + private var shouldReconnect = false + + // Callbacks + private var onTranscript: TranscriptHandler? + private var onError: ErrorHandler? + private var onConnected: ConnectionHandler? + private var onDisconnected: ConnectionHandler? + + // Configuration + private let language: String + private let sampleRate = 16000 + private let codec = "pcm16" + private let channels = 1 // Always mono — backend handles diarization + private let source: String + private let conversationTimeout: Int + + // Reconnection + private var reconnectAttempts = 0 + private let maxReconnectAttempts = 10 + private var reconnectTask: Task? + + // Keepalive — send empty data periodically to prevent timeout + private var keepaliveTask: Task? + private let keepaliveInterval: TimeInterval = 8.0 + + // Watchdog: detect stale connections where WebSocket dies silently + private var watchdogTask: Task? + private var lastDataReceivedAt: Date? + private var lastKeepaliveSuccessAt: Date? 
+ private let watchdogInterval: TimeInterval = 30.0 + private let staleThreshold: TimeInterval = 60.0 + + // Audio buffering + private var audioBuffer = Data() + private let audioBufferSize = 3200 // ~100ms of 16kHz 16-bit mono (16000 * 2 * 0.1) + private let audioBufferLock = NSLock() + + // MARK: - Initialization + + /// Initialize the backend transcription service + /// - Parameters: + /// - language: Language code for transcription (e.g., "en", "multi") + /// - source: Audio source identifier for backend analytics (e.g., "desktop", "omi", "bee") + /// - conversationTimeout: Seconds of silence before the backend creates a memory + init(language: String = "en", source: String = "desktop", conversationTimeout: Int = 120) { + self.language = language + self.source = source + self.conversationTimeout = conversationTimeout + log("BackendTranscriptionService: Initialized with language=\(language), source=\(source)") + } + + // MARK: - Public Methods + + /// Start the transcription service + func start( + onTranscript: @escaping TranscriptHandler, + onError: ErrorHandler? = nil, + onConnected: ConnectionHandler? = nil, + onDisconnected: ConnectionHandler? = nil + ) { + self.onTranscript = onTranscript + self.onError = onError + self.onConnected = onConnected + self.onDisconnected = onDisconnected + self.shouldReconnect = true + self.reconnectAttempts = 0 + + connect() + } + + /// Stop the transcription service + func stop() { + shouldReconnect = false + reconnectTask?.cancel() + reconnectTask = nil + keepaliveTask?.cancel() + keepaliveTask = nil + watchdogTask?.cancel() + watchdogTask = nil + + flushAudioBuffer() + disconnect() + } + + /// Signal the backend that no more audio will be sent, but keep connection open + /// to receive final transcription results. Call stop() later to fully disconnect. 
+ func finishStream() { + shouldReconnect = false + reconnectTask?.cancel() + reconnectTask = nil + keepaliveTask?.cancel() + keepaliveTask = nil + watchdogTask?.cancel() + watchdogTask = nil + + flushAudioBuffer() + + // Backend doesn't have a CloseStream message like Deepgram. + // The connection will be closed when stop() is called. + log("BackendTranscriptionService: finishStream called, waiting for final results") + } + + /// Send audio data to the backend (buffered for efficiency) + func sendAudio(_ data: Data) { + guard isConnected else { return } + + audioBufferLock.lock() + audioBuffer.append(data) + + if audioBuffer.count >= audioBufferSize { + let chunk = audioBuffer + audioBuffer = Data() + audioBufferLock.unlock() + sendAudioChunk(chunk) + } else { + audioBufferLock.unlock() + } + } + + /// Flush any remaining audio in the buffer + private func flushAudioBuffer() { + audioBufferLock.lock() + let chunk = audioBuffer + audioBuffer = Data() + audioBufferLock.unlock() + + if !chunk.isEmpty { + sendAudioChunk(chunk) + } + } + + /// Actually send an audio chunk over the WebSocket + private func sendAudioChunk(_ data: Data) { + guard isConnected, let webSocketTask = webSocketTask else { return } + + let message = URLSessionWebSocketTask.Message.data(data) + webSocketTask.send(message) { [weak self] error in + if let error = error { + logError("BackendTranscriptionService: Send error", error: error) + self?.handleDisconnection() + } + } + } + + /// No-op for backend (Deepgram-specific Finalize message not needed) + func sendFinalize() { + // Backend handles segmentation server-side + } + + /// Public keepalive for VAD gate to call during extended silence + func sendKeepalivePublic() { + sendKeepalive() + } + + /// Check if connected + var connected: Bool { + return isConnected + } + + // MARK: - Connection + + private func connect() { + Task { + do { + let token = try await AuthService.shared.getIdToken() + let baseURL = await APIClient.shared.baseURL + 
self.connectWithToken(token, baseURL: baseURL) + } catch { + logError("BackendTranscriptionService: Failed to get auth token", error: error) + self.onError?(BackendTranscriptionError.notSignedIn) + } + } + } + + private func connectWithToken(_ token: String, baseURL: String) { + + // Convert http(s) to ws(s) + let wsBaseURL: String + if baseURL.hasPrefix("https://") { + wsBaseURL = "wss://" + baseURL.dropFirst("https://".count) + } else if baseURL.hasPrefix("http://") { + wsBaseURL = "ws://" + baseURL.dropFirst("http://".count) + } else { + wsBaseURL = "wss://" + baseURL + } + + // Strip trailing slash before appending path + let cleanBase = wsBaseURL.hasSuffix("/") ? String(wsBaseURL.dropLast()) : wsBaseURL + + var components = URLComponents(string: cleanBase + "/v4/listen")! + components.queryItems = [ + URLQueryItem(name: "language", value: language), + URLQueryItem(name: "sample_rate", value: String(sampleRate)), + URLQueryItem(name: "codec", value: codec), + URLQueryItem(name: "channels", value: String(channels)), + URLQueryItem(name: "source", value: source), + URLQueryItem(name: "include_speech_profile", value: "true"), + URLQueryItem(name: "speaker_auto_assign", value: "enabled"), + URLQueryItem(name: "conversation_timeout", value: String(conversationTimeout)), + ] + + guard let url = components.url else { + onError?(BackendTranscriptionError.connectionFailed(NSError(domain: "Invalid URL", code: -1))) + return + } + + log("BackendTranscriptionService: Connecting to \(url.absoluteString)") + + // Create URL request with Bearer auth header (same as mobile app) + var request = URLRequest(url: url) + request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization") + + // Create URLSession and WebSocket task + let configuration = URLSessionConfiguration.default + configuration.timeoutIntervalForRequest = 30 + configuration.timeoutIntervalForResource = 0 // No resource timeout for long-lived WebSocket + urlSession = URLSession(configuration: 
configuration) + webSocketTask = urlSession?.webSocketTask(with: request) + + // Start the connection + webSocketTask?.resume() + + // Start receiving messages + receiveMessage() + + // Mark as connected after a short delay (backend doesn't send a connect confirmation) + DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in + guard let self = self, self.webSocketTask?.state == .running else { return } + self.isConnected = true + self.reconnectAttempts = 0 + self.lastDataReceivedAt = Date() + self.lastKeepaliveSuccessAt = Date() + log("BackendTranscriptionService: Connected") + self.startKeepalive() + self.startWatchdog() + self.onConnected?() + } + } + + // MARK: - Keepalive + + private func startKeepalive() { + keepaliveTask?.cancel() + keepaliveTask = Task { [weak self] in + while !Task.isCancelled { + try? await Task.sleep(nanoseconds: UInt64(self?.keepaliveInterval ?? 8.0) * 1_000_000_000) + guard !Task.isCancelled, let self = self, self.isConnected else { break } + self.sendKeepalive() + } + } + } + + private func sendKeepalive() { + guard isConnected, let webSocketTask = webSocketTask else { return } + + // Send a small chunk of silence as keepalive (2 bytes of zero = 1 silent sample) + let silence = Data(repeating: 0, count: 2) + let message = URLSessionWebSocketTask.Message.data(silence) + webSocketTask.send(message) { [weak self] error in + if let error = error { + logError("BackendTranscriptionService: Keepalive error", error: error) + self?.handleDisconnection() + } else { + self?.lastKeepaliveSuccessAt = Date() + } + } + } + + // MARK: - Watchdog + + private func startWatchdog() { + watchdogTask?.cancel() + watchdogTask = Task { [weak self] in + while !Task.isCancelled { + try? await Task.sleep(nanoseconds: UInt64(self?.watchdogInterval ?? 
30.0) * 1_000_000_000) + guard !Task.isCancelled, let self = self, self.isConnected else { break } + + if let lastData = self.lastDataReceivedAt, + Date().timeIntervalSince(lastData) > self.staleThreshold { + if let lastKeepalive = self.lastKeepaliveSuccessAt, + Date().timeIntervalSince(lastKeepalive) < self.staleThreshold { + continue + } + log("BackendTranscriptionService: Watchdog detected stale connection — forcing reconnect") + self.handleDisconnection() + } + } + } + } + + // MARK: - Disconnect / Reconnect + + private func disconnect() { + isConnected = false + keepaliveTask?.cancel() + keepaliveTask = nil + watchdogTask?.cancel() + watchdogTask = nil + webSocketTask?.cancel(with: .normalClosure, reason: nil) + webSocketTask = nil + urlSession?.invalidateAndCancel() + urlSession = nil + log("BackendTranscriptionService: Disconnected") + onDisconnected?() + } + + private func handleDisconnection() { + guard isConnected else { return } + + isConnected = false + keepaliveTask?.cancel() + keepaliveTask = nil + watchdogTask?.cancel() + watchdogTask = nil + webSocketTask = nil + urlSession?.invalidateAndCancel() + urlSession = nil + onDisconnected?() + + if shouldReconnect && reconnectAttempts < maxReconnectAttempts { + reconnectAttempts += 1 + let delay = min(pow(2.0, Double(reconnectAttempts)), 32.0) + log("BackendTranscriptionService: Reconnecting in \(delay)s (attempt \(reconnectAttempts))") + + reconnectTask = Task { + try? 
await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) + guard !Task.isCancelled, self.shouldReconnect else { return } + self.connect() + } + } else if reconnectAttempts >= maxReconnectAttempts { + log("BackendTranscriptionService: Max reconnect attempts reached") + onError?(BackendTranscriptionError.webSocketError("Max reconnect attempts reached")) + } + } + + // MARK: - Message Handling + + private func receiveMessage() { + webSocketTask?.receive { [weak self] result in + guard let self = self else { return } + + switch result { + case .success(let message): + self.handleMessage(message) + self.receiveMessage() + + case .failure(let error): + guard self.isConnected else { return } + logError("BackendTranscriptionService: Receive error", error: error) + self.handleDisconnection() + } + } + } + + private func handleMessage(_ message: URLSessionWebSocketTask.Message) { + lastDataReceivedAt = Date() + + switch message { + case .string(let text): + parseResponse(text) + case .data(let data): + if let text = String(data: data, encoding: .utf8) { + parseResponse(text) + } + @unknown default: + break + } + } + + private func parseResponse(_ text: String) { + // Handle heartbeat ping from backend + if text == "ping" { + return + } + + guard let data = text.data(using: .utf8) else { return } + + // Try parsing as array of transcript segments (main response format) + if let segments = try? JSONDecoder().decode([BackendSegment].self, from: data) { + for segment in segments { + // Map backend is_user to channel index: + // is_user=true → channelIndex=0 (mic/user) + // is_user=false → channelIndex=1 (system/others) + let channelIndex = segment.is_user ? 
0 : 1 + + let transcriptSegment = TranscriptSegment( + text: segment.text, + isFinal: true, + speechFinal: true, + confidence: 1.0, + words: [TranscriptSegment.Word( + word: segment.text, + start: segment.start, + end: segment.end, + confidence: 1.0, + speaker: segment.speaker_id, + punctuatedWord: segment.text + )], + channelIndex: channelIndex + ) + onTranscript?(transcriptSegment) + } + return + } + + // Try parsing as event object (memory_created, service_status, etc.) + if let event = try? JSONDecoder().decode(BackendEvent.self, from: data) { + switch event.type { + case "memory_created": + log("BackendTranscriptionService: Memory created") + case "service_status": + log("BackendTranscriptionService: Service status: \(event.status ?? "unknown")") + default: + log("BackendTranscriptionService: Event: \(event.type)") + } + return + } + + // Unknown message — log for debugging + log("BackendTranscriptionService: Unknown message: \(text.prefix(200))") + } +} + +// MARK: - Backend Response Models + +/// Transcript segment from the OMI backend +private struct BackendSegment: Decodable { + let text: String + let speaker: String? + let speaker_id: Int? + let is_user: Bool + let start: Double + let end: Double + let person_id: String? +} + +/// Event message from the OMI backend +private struct BackendEvent: Decodable { + let type: String + let status: String? 
+ + enum CodingKeys: String, CodingKey { + case type + case status + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + type = try container.decode(String.self, forKey: .type) + status = try container.decodeIfPresent(String.self, forKey: .status) + } +} From 8c28e580346db2773f99fea4aaab449d267a9297 Mon Sep 17 00:00:00 2001 From: beastoin Date: Fri, 6 Mar 2026 08:35:34 +0100 Subject: [PATCH 110/163] Add mono output mode and fix single-source operation in AudioMixer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add OutputMode enum (.stereo/.mono) with mono averaging both channels. Fix processBuffers() to work when only one source has data (e.g. system audio disabled by default) — previously min(mic, 0) = 0 blocked all output. Existing silence-padding handles the gap. Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/AudioMixer.swift | 75 ++++++++++++++++++++---- 1 file changed, 62 insertions(+), 13 deletions(-) diff --git a/desktop/Desktop/Sources/AudioMixer.swift b/desktop/Desktop/Sources/AudioMixer.swift index b35d9a58b7..88ac96178b 100644 --- a/desktop/Desktop/Sources/AudioMixer.swift +++ b/desktop/Desktop/Sources/AudioMixer.swift @@ -1,19 +1,26 @@ import Foundation -/// Mixes microphone and system audio into a stereo stream for multichannel transcription +/// Mixes microphone and system audio into a combined stream for transcription. +/// Supports stereo (interleaved mic+system) or mono (averaged) output. /// Channel 0 (left) = Microphone (user) /// Channel 1 (right) = System audio (others) class AudioMixer { // MARK: - Types - /// Callback for receiving stereo audio chunks + enum OutputMode { + case stereo // Interleaved [mic0, sys0, mic1, sys1, ...] 
— for Deepgram multichannel + case mono // Averaged (mic + system) / 2 — for backend /v4/listen + } + + /// Callback for receiving mixed audio chunks typealias StereoAudioHandler = (Data) -> Void // MARK: - Properties private var onStereoChunk: StereoAudioHandler? private var isRunning = false + private(set) var outputMode: OutputMode = .stereo // Audio buffers (16kHz mono Int16 PCM) private var micBuffer = Data() @@ -29,15 +36,18 @@ class AudioMixer { // MARK: - Public Methods /// Start the mixer - /// - Parameter onStereoChunk: Callback receiving interleaved stereo 16-bit PCM at 16kHz - func start(onStereoChunk: @escaping StereoAudioHandler) { + /// - Parameters: + /// - outputMode: `.stereo` for interleaved multichannel, `.mono` for averaged single-channel + /// - onStereoChunk: Callback receiving mixed 16-bit PCM at 16kHz + func start(outputMode: OutputMode = .stereo, onStereoChunk: @escaping StereoAudioHandler) { bufferLock.lock() + self.outputMode = outputMode self.onStereoChunk = onStereoChunk self.isRunning = true micBuffer = Data() systemBuffer = Data() bufferLock.unlock() - log("AudioMixer: Started") + log("AudioMixer: Started (output=\(outputMode))") } /// Stop the mixer and flush remaining audio @@ -105,12 +115,17 @@ class AudioMixer { if flush { // When flushing, process whatever is available bytesToProcess = max(micBuffer.count, systemBuffer.count) + } else if micBuffer.count >= minBufferBytes && systemBuffer.count >= minBufferBytes { + // Both buffers have data — use shorter to stay in sync + bytesToProcess = (min(micBuffer.count, systemBuffer.count) / 2) * 2 + } else if micBuffer.count >= minBufferBytes { + // Only mic has data (system audio disabled/unavailable) — pad system with silence + bytesToProcess = (micBuffer.count / 2) * 2 + } else if systemBuffer.count >= minBufferBytes { + // Only system has data — pad mic with silence + bytesToProcess = (systemBuffer.count / 2) * 2 } else { - // Normal operation: process when both have data - let 
minAvailable = min(micBuffer.count, systemBuffer.count) - guard minAvailable >= minBufferBytes else { return } - // Align to sample boundary (2 bytes per Int16 sample) - bytesToProcess = (minAvailable / 2) * 2 + return } guard bytesToProcess >= 2 else { return } @@ -137,11 +152,17 @@ class AudioMixer { systemBuffer = Data() } - // Interleave into stereo - let stereoData = interleave(mic: micData, system: sysData) + // Mix according to output mode + let mixedData: Data + switch outputMode { + case .stereo: + mixedData = interleave(mic: micData, system: sysData) + case .mono: + mixedData = mixToMono(mic: micData, system: sysData) + } // Send to callback - onStereoChunk?(stereoData) + onStereoChunk?(mixedData) } /// Interleave two mono Int16 streams into stereo @@ -174,4 +195,32 @@ class AudioMixer { Data(buffer: buffer) } } + + /// Average two mono Int16 streams into a single mono stream + /// Output format: [(mic0+sys0)/2, (mic1+sys1)/2, ...] + private func mixToMono(mic: Data, system: Data) -> Data { + let sampleCount = mic.count / 2 + + var monoSamples = [Int16]() + monoSamples.reserveCapacity(sampleCount) + + mic.withUnsafeBytes { micPtr in + system.withUnsafeBytes { sysPtr in + let micSamples = micPtr.bindMemory(to: Int16.self) + let sysSamples = sysPtr.bindMemory(to: Int16.self) + + for i in 0.. Date: Fri, 6 Mar 2026 08:35:38 +0100 Subject: [PATCH 111/163] Switch BleAudioService to closure-based audio sink Replace concrete TranscriptionService parameter with audioSink closure for decoupled audio routing. Callers provide destination closure instead of coupling to a specific transcription type. 
Co-Authored-By: Claude Opus 4.6 --- .../Sources/Audio/BleAudioService.swift | 38 ++++--------------- 1 file changed, 7 insertions(+), 31 deletions(-) diff --git a/desktop/Desktop/Sources/Audio/BleAudioService.swift b/desktop/Desktop/Sources/Audio/BleAudioService.swift index 0cb9bb527f..9078653668 100644 --- a/desktop/Desktop/Sources/Audio/BleAudioService.swift +++ b/desktop/Desktop/Sources/Audio/BleAudioService.swift @@ -27,7 +27,7 @@ final class BleAudioService: ObservableObject { private var cancellables = Set() // Audio delivery - private var transcriptionService: TranscriptionService? + private var audioSink: ((Data) -> Void)? private var audioDataHandler: ((Data) -> Void)? private var rawFrameHandler: ((Data) -> Void)? @@ -44,12 +44,12 @@ final class BleAudioService: ObservableObject { /// Start processing audio from a device connection /// - Parameters: /// - connection: The device connection to get audio from - /// - transcriptionService: Optional transcription service to send audio to + /// - audioSink: Optional closure to receive decoded mono PCM audio (e.g., send to transcription service) /// - audioDataHandler: Optional handler for decoded PCM data (alternative to transcription) /// - rawFrameHandler: Optional handler for raw encoded frames (for WAL recording) func startProcessing( from connection: DeviceConnection, - transcriptionService: TranscriptionService? = nil, + audioSink: ((Data) -> Void)? = nil, audioDataHandler: ((Data) -> Void)? = nil, rawFrameHandler: ((Data) -> Void)? 
= nil ) async { @@ -58,7 +58,7 @@ final class BleAudioService: ObservableObject { return } - self.transcriptionService = transcriptionService + self.audioSink = audioSink self.audioDataHandler = audioDataHandler self.rawFrameHandler = rawFrameHandler @@ -126,7 +126,7 @@ final class BleAudioService: ObservableObject { cancellables.removeAll() isProcessing = false - transcriptionService = nil + audioSink = nil audioDataHandler = nil rawFrameHandler = nil @@ -194,37 +194,13 @@ final class BleAudioService: ObservableObject { // Calculate audio level updateAudioLevel(from: pcmData) - // Send to transcription service (mono channel) - if let transcription = transcriptionService { - // TranscriptionService expects stereo (2 channels) for multichannel transcription - // For BLE device audio, we duplicate to both channels (device is the "user") - let stereoData = convertToStereo(pcmData) - transcription.sendAudio(stereoData) - } + // Send decoded mono PCM to audio sink (e.g., transcription service) + audioSink?(pcmData) // Send to custom handler audioDataHandler?(pcmData) } - /// Convert mono PCM to stereo (duplicate to both channels) - private func convertToStereo(_ monoData: Data) -> Data { - // Mono: [S0, S1, S2, ...] - // Stereo: [S0, S0, S1, S1, S2, S2, ...] (interleaved) - var stereoData = Data(capacity: monoData.count * 2) - - monoData.withUnsafeBytes { bytes in - let samples = bytes.bindMemory(to: Int16.self) - for i in 0.. Date: Fri, 6 Mar 2026 08:35:43 +0100 Subject: [PATCH 112/163] Route desktop STT through backend /v4/listen in AppState Replace direct Deepgram with BackendTranscriptionService. Force streaming mode, set AudioMixer to mono. Add backendOwnsConversation flag to skip createConversationFromSegments() (backend creates conversations via lifecycle manager). Pass correct source for BLE devices. Remove DEEPGRAM_API_KEY check. 
Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/AppState.swift | 316 ++++++++++++------------- 1 file changed, 148 insertions(+), 168 deletions(-) diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 862bfaf070..474fec6ceb 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -135,13 +135,16 @@ class AppState: ObservableObject { // Transcription services private var audioCaptureService: AudioCaptureService? - private var transcriptionService: TranscriptionService? + private var transcriptionService: BackendTranscriptionService? private var systemAudioCaptureService: Any? // SystemAudioCaptureService (macOS 14.4+) private var audioMixer: AudioMixer? private var vadGateService: VADGateService? - // Batch transcription mode + // Batch transcription mode (disabled — backend handles everything via /v4/listen) private var useBatchTranscription: Bool = false + // When true, backend owns conversation creation via /v4/listen lifecycle manager. + // Desktop skips createConversationFromSegments() to avoid duplicates. + private var backendOwnsConversation: Bool = false private var recordingStartCATime: Double = 0 // CACurrentMediaTime at recording start // Speaker segments for diarized transcription (sliding window — older segments are in SQLite) @@ -430,12 +433,7 @@ class AppState: ObservableObject { } } - // Log final state of important keys - if getenv("DEEPGRAM_API_KEY") != nil { - log("DEEPGRAM_API_KEY is set") - } else { - log("WARNING: DEEPGRAM_API_KEY is NOT set") - } + // DEEPGRAM_API_KEY no longer needed — STT routed through backend /v4/listen } private func shouldSkipBundledAnthropicKey(key: String, sourcePath: String, bundledEnvPath: String?) 
-> Bool { @@ -1152,165 +1150,143 @@ class AppState: ObservableObject { } } - do { - // Get effective language from settings (handles auto-detect vs single language) - let effectiveLanguage = AssistantSettings.shared.effectiveTranscriptionLanguage - let vocabulary = AssistantSettings.shared.effectiveVocabulary - log("Transcription: Using language=\(effectiveLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), selected=\(AssistantSettings.shared.transcriptionLanguage))") - log("Transcription: Custom vocabulary: \(vocabulary.joined(separator: ", "))") - - // Determine transcription mode - useBatchTranscription = AssistantSettings.shared.batchTranscriptionEnabled && effectiveSource == .microphone - - if !useBatchTranscription { - // Streaming mode: initialize WebSocket transcription service - transcriptionService = try TranscriptionService(language: effectiveLanguage, vocabulary: vocabulary) + // Get effective language from settings (handles auto-detect vs single language) + let effectiveLanguage = AssistantSettings.shared.effectiveTranscriptionLanguage + let vocabulary = AssistantSettings.shared.effectiveVocabulary + log("Transcription: Using language=\(effectiveLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), selected=\(AssistantSettings.shared.transcriptionLanguage))") + log("Transcription: Custom vocabulary: \(vocabulary.joined(separator: ", "))") + + // Always use streaming mode through the backend — batch mode not needed + // (backend handles STT, diarization, and memory creation server-side) + useBatchTranscription = false + // Backend owns conversation creation via /v4/listen lifecycle manager + backendOwnsConversation = true + + // Set conversation source based on audio source + let sourceValue: String + if effectiveSource == .bleDevice, let device = DeviceProvider.shared.connectedDevice { + currentConversationSource = ConversationSource.from(deviceType: device.type) + recordingInputDeviceName = 
device.displayName + sourceValue = currentConversationSource.rawValue + } else { + currentConversationSource = .desktop + recordingInputDeviceName = AudioCaptureService.getCurrentMicrophoneName() + sourceValue = "desktop" + } + + transcriptionService = BackendTranscriptionService(language: effectiveLanguage, source: sourceValue) + + // Initialize audio services based on source + if effectiveSource == .microphone { + // Initialize audio capture service + audioCaptureService = AudioCaptureService() + + // Initialize audio mixer for combining mic and system audio + audioMixer = AudioMixer() + + // VAD gate is optional for streaming mode (silence gating) + if AssistantSettings.shared.vadGateEnabled { + let gate = VADGateService() + vadGateService = gate + log("Transcription: VAD gate enabled") } else { - log("Transcription: Batch mode enabled — skipping WebSocket") + vadGateService = nil + } + + // Initialize system audio capture if supported (macOS 14.4+) + // Can be disabled via: defaults write com.omi.desktop-dev disableSystemAudioCapture -bool true + // or: defaults write com.omi.computer-macos disableSystemAudioCapture -bool true + let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture") + if systemAudioDisabled { + log("Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)") + } else if #available(macOS 14.4, *) { + systemAudioCaptureService = SystemAudioCaptureService() + log("Transcription: System audio capture initialized (macOS 14.4+)") + } else { + log("Transcription: System audio capture not available (requires macOS 14.4+)") } + } + // For BLE device, BleAudioService will be used in startAudioCapture - // Set conversation source based on audio source - if effectiveSource == .bleDevice, let device = DeviceProvider.shared.connectedDevice { - currentConversationSource = ConversationSource.from(deviceType: device.type) - recordingInputDeviceName = device.displayName - } else { - 
currentConversationSource = .desktop - recordingInputDeviceName = AudioCaptureService.getCurrentMicrophoneName() - } - - // Initialize audio services based on source - if effectiveSource == .microphone { - // Initialize audio capture service - audioCaptureService = AudioCaptureService() - - // Initialize audio mixer for combining mic and system audio - audioMixer = AudioMixer() - - // VAD gate is always needed for batch mode (chunk boundaries), - // and optional for streaming mode (silence gating) - if useBatchTranscription || AssistantSettings.shared.vadGateEnabled { - let gate = VADGateService() - if useBatchTranscription && !gate.modelAvailable { - // Batch mode requires working VAD — fall back to streaming - log("Transcription: VAD models unavailable, falling back from batch to streaming mode") - useBatchTranscription = false - vadGateService = nil - transcriptionService = try TranscriptionService(language: effectiveLanguage, vocabulary: vocabulary) - } else { - vadGateService = gate - log("Transcription: VAD gate enabled\(useBatchTranscription ? 
" (batch mode)" : "")") - } - } else { - vadGateService = nil + // Start backend transcription service, then audio on connect + transcriptionService?.start( + onTranscript: { [weak self] segment in + Task { @MainActor in + self?.handleTranscriptSegment(segment) } - - // Initialize system audio capture if supported (macOS 14.4+) - // Can be disabled via: defaults write com.omi.desktop-dev disableSystemAudioCapture -bool true - // or: defaults write com.omi.computer-macos disableSystemAudioCapture -bool true - let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture") - if systemAudioDisabled { - log("Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)") - } else if #available(macOS 14.4, *) { - systemAudioCaptureService = SystemAudioCaptureService() - log("Transcription: System audio capture initialized (macOS 14.4+)") - } else { - log("Transcription: System audio capture not available (requires macOS 14.4+)") + }, + onError: { [weak self] error in + Task { @MainActor in + logError("Transcription error", error: error) + AnalyticsManager.shared.recordingError(error: error.localizedDescription) + self?.stopTranscription() } - } - // For BLE device, BleAudioService will be used in startAudioCapture - - if useBatchTranscription { - // Batch mode: start audio capture directly (no WebSocket to wait for) - recordingStartCATime = CACurrentMediaTime() - Task { @MainActor [weak self] in + }, + onConnected: { [weak self] in + Task { @MainActor in + log("Transcription: Connected to backend") + // Start audio capture once connected await self?.startAudioCapture(source: effectiveSource) } - } else { - // Streaming mode: start transcription service first, then audio on connect - transcriptionService?.start( - onTranscript: { [weak self] segment in - Task { @MainActor in - self?.handleTranscriptSegment(segment) - } - }, - onError: { [weak self] error in - Task { @MainActor in - logError("Transcription error", 
error: error) - AnalyticsManager.shared.recordingError(error: error.localizedDescription) - self?.stopTranscription() - } - }, - onConnected: { [weak self] in - Task { @MainActor in - log("Transcription: Connected to DeepGram") - // Start audio capture once connected - await self?.startAudioCapture(source: effectiveSource) - } - }, - onDisconnected: { - log("Transcription: Disconnected from DeepGram") - } - ) + }, + onDisconnected: { + log("Transcription: Disconnected from backend") } + ) - isTranscribing = true - AssistantSettings.shared.transcriptionEnabled = true - audioSource = effectiveSource - currentTranscript = "" - speakerSegments = [] - totalSegmentCount = 0 - totalWordCount = 0 - liveSpeakerPersonMap = [:] - LiveTranscriptMonitor.shared.clear() - recordingStartTime = Date() - AudioLevelMonitor.shared.reset() - RecordingTimer.shared.start() + isTranscribing = true + AssistantSettings.shared.transcriptionEnabled = true + audioSource = effectiveSource + currentTranscript = "" + speakerSegments = [] + totalSegmentCount = 0 + totalWordCount = 0 + liveSpeakerPersonMap = [:] + LiveTranscriptMonitor.shared.clear() + recordingStartTime = Date() + AudioLevelMonitor.shared.reset() + RecordingTimer.shared.start() - log("Transcription: Using source: \(effectiveSource.rawValue), device: \(recordingInputDeviceName ?? "Unknown")") + log("Transcription: Using source: \(effectiveSource.rawValue), device: \(recordingInputDeviceName ?? 
"Unknown")") - // Create crash-safe DB session for persistence - Task { - do { - let sessionId = try await TranscriptionStorage.shared.startSession( - source: currentConversationSource.rawValue, - language: effectiveLanguage, - timezone: TimeZone.current.identifier, - inputDeviceName: recordingInputDeviceName - ) - await MainActor.run { - self.currentSessionId = sessionId - // Start live notes session - LiveNotesMonitor.shared.startSession(sessionId: sessionId) - } - log("Transcription: Created DB session \(sessionId)") - } catch { - logError("Transcription: Failed to create DB session", error: error) - // Non-fatal - continue recording even if DB fails + // Create crash-safe DB session for persistence + Task { + do { + let sessionId = try await TranscriptionStorage.shared.startSession( + source: currentConversationSource.rawValue, + language: effectiveLanguage, + timezone: TimeZone.current.identifier, + inputDeviceName: recordingInputDeviceName + ) + await MainActor.run { + self.currentSessionId = sessionId + // Start live notes session + LiveNotesMonitor.shared.startSession(sessionId: sessionId) } + log("Transcription: Created DB session \(sessionId)") + } catch { + logError("Transcription: Failed to create DB session", error: error) + // Non-fatal - continue recording even if DB fails } + } - // Start 4-hour max recording timer - maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) { [weak self] _ in - Task { @MainActor in - guard let self = self, self.isTranscribing else { return } - log("Transcription: 4-hour limit reached - finalizing conversation") - _ = await self.finalizeConversation() - // Start a new recording session automatically - self.stopAudioCapture() - self.clearTranscriptionState() - self.startTranscription() - } + // Start 4-hour max recording timer + maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) { [weak self] _ in + Task { @MainActor in + guard let 
self = self, self.isTranscribing else { return } + log("Transcription: 4-hour limit reached - finalizing conversation") + _ = await self.finalizeConversation() + // Start a new recording session automatically + self.stopAudioCapture() + self.clearTranscriptionState() + self.startTranscription() } + } - // Track transcription started - AnalyticsManager.shared.transcriptionStarted() - - log("Transcription: Starting...") + // Track transcription started + AnalyticsManager.shared.transcriptionStarted() - } catch { - AnalyticsManager.shared.recordingError(error: error.localizedDescription) - showAlert(title: "Transcription Error", message: error.localizedDescription) - } + log("Transcription: Starting...") } /// Start audio capture and pipe to transcription service @@ -1330,23 +1306,12 @@ class AppState: ObservableObject { guard let audioCaptureService = audioCaptureService, let audioMixer = audioMixer else { return } - // Start the audio mixer - it will send stereo audio to transcription service - // Branch on batch vs streaming mode - audioMixer.start { [weak self] stereoData in + // Start the audio mixer in mono mode — backend handles diarization server-side + audioMixer.start(outputMode: .mono) { [weak self] monoData in guard let self = self else { return } - if self.useBatchTranscription { - // Batch mode: accumulate audio in VAD gate, transcribe on silence - guard let gate = self.vadGateService else { return } - let output = gate.processAudioBatch(stereoData) - if output.isComplete, let audioBuffer = output.audioBuffer { - let wallStartTime = output.speechStartWallTime - Task { @MainActor [weak self] in - await self?.batchTranscribeChunk(audioBuffer: audioBuffer, wallStartTime: wallStartTime) - } - } - } else if let gate = self.vadGateService { + if let gate = self.vadGateService { // Streaming mode with VAD gate - let output = gate.processAudio(stereoData) + let output = gate.processAudio(monoData) if !output.audioToSend.isEmpty { 
self.transcriptionService?.sendAudio(output.audioToSend) } else if gate.needsKeepalive() { @@ -1357,7 +1322,7 @@ class AppState: ObservableObject { } } else { // Streaming mode without VAD gate - self.transcriptionService?.sendAudio(stereoData) + self.transcriptionService?.sendAudio(monoData) } } @@ -1411,10 +1376,12 @@ class AppState: ObservableObject { return } - // Start BLE audio processing and pipe directly to transcription + // Start BLE audio processing and pipe mono PCM directly to backend transcription await BleAudioService.shared.startProcessing( from: connection, - transcriptionService: transcriptionService, + audioSink: { [weak transcriptionService] pcmData in + transcriptionService?.sendAudio(pcmData) + }, audioDataHandler: { _ in // Audio level is updated by BleAudioService Task { @MainActor in @@ -2095,6 +2062,19 @@ class AppState: ObservableObject { log("Transcription: Finalizing conversation with \(segmentsToUpload.count) segments") + // When backend owns conversation creation (via /v4/listen lifecycle manager), + // skip client-side createConversationFromSegments() to avoid duplicates. + // The backend already has all segments from the live stream and will process + // the conversation on timeout or next connection. + if backendOwnsConversation { + log("Transcription: Backend owns conversation — skipping client-side upload (\(segmentsToUpload.count) segments streamed)") + if let sessionId = sessionId { + // Mark session as completed — no retry needed since backend has the data + try? 
await TranscriptionStorage.shared.markSessionCompleted(id: sessionId, backendId: "backend-owned") + } + return .saved + } + // Convert SpeakerSegment to API request format (include person_id from live naming) let speakerPersonMap = liveSpeakerPersonMap let apiSegments = segmentsToUpload.map { segment in From e80ec706e62f2d04883e7b7ca5a1bdbea078da2f Mon Sep 17 00:00:00 2001 From: beastoin Date: Fri, 6 Mar 2026 08:35:56 +0100 Subject: [PATCH 113/163] Use BackendTranscriptionService for push-to-talk MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace direct Deepgram with backend service for live PTT. Remove batch transcription path entirely — backend handles STT server-side. Co-Authored-By: Claude Opus 4.6 --- .../PushToTalkManager.swift | 136 +++++------------- 1 file changed, 37 insertions(+), 99 deletions(-) diff --git a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift index aee2f956d7..4156578928 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift @@ -34,7 +34,7 @@ class PushToTalkManager: ObservableObject { private let doubleTapThreshold: TimeInterval = 0.4 // Transcription - private var transcriptionService: TranscriptionService? + private var transcriptionService: BackendTranscriptionService? private var audioCaptureService: AudioCaptureService? 
private var transcriptSegments: [String] = [] private var lastInterimText: String = "" @@ -302,58 +302,20 @@ class PushToTalkManager: ObservableObject { sound?.play() } - let isBatchMode = ShortcutSettings.shared.pttTranscriptionMode == .batch + // Flush remaining audio and wait for final transcript from backend + transcriptionService?.finishStream() + log("PushToTalkManager: finalizing — mic stopped, waiting for final transcript") - if isBatchMode { - // Batch mode: send accumulated audio to pre-recorded API - log("PushToTalkManager: finalizing (batch) — mic stopped, transcribing recorded audio") - batchAudioLock.lock() - let audioData = batchAudioBuffer - batchAudioBuffer = Data() - batchAudioLock.unlock() - - // Stop streaming service (was not used in batch mode, but clean up) - stopAudioTranscription() - - guard !audioData.isEmpty else { - log("PushToTalkManager: batch mode — no audio recorded") - sendTranscript() - return - } - - barState?.voiceTranscript = "Transcribing..." - - Task { - do { - let language = AssistantSettings.shared.effectiveTranscriptionLanguage - let transcript = try await TranscriptionService.batchTranscribe( - audioData: audioData, - language: language - ) - if let transcript, !transcript.isEmpty { - self.transcriptSegments = [transcript] - } - } catch { - logError("PushToTalkManager: batch transcription failed", error: error) - } + // Safety timeout: if backend doesn't send a final segment within 3s, send what we have + let timeout = DispatchWorkItem { [weak self] in + Task { @MainActor in + guard let self, self.state == .finalizing else { return } + log("PushToTalkManager: finalization timeout — sending transcript") self.sendTranscript() } - } else { - // Live mode: flush remaining audio and wait for final transcript from Deepgram - transcriptionService?.finishStream() - log("PushToTalkManager: finalizing (live) — mic stopped, waiting for final transcript") - - // Safety timeout: if Deepgram doesn't send a final segment within 3s, send 
what we have - let timeout = DispatchWorkItem { [weak self] in - Task { @MainActor in - guard let self, self.state == .finalizing else { return } - log("PushToTalkManager: live finalization timeout — sending transcript") - self.sendTranscript() - } - } - liveFinalizationTimeout = timeout - DispatchQueue.main.asyncAfter(deadline: .now() + 3.0, execute: timeout) } + liveFinalizationTimeout = timeout + DispatchQueue.main.asyncAfter(deadline: .now() + 3.0, execute: timeout) } private func sendTranscript() { @@ -421,50 +383,34 @@ class PushToTalkManager: ObservableObject { return } - let isBatchMode = ShortcutSettings.shared.pttTranscriptionMode == .batch + // Always use live streaming through the backend (no client-side batch mode) + startMicCapture() - if isBatchMode { - // Batch mode: just capture audio into buffer, no streaming connection - batchAudioLock.lock() - batchAudioBuffer = Data() - batchAudioLock.unlock() - startMicCapture(batchMode: true) - log("PushToTalkManager: started audio capture (batch mode)") - } else { - // Live mode: start mic capture and stream to Deepgram - startMicCapture() + let language = AssistantSettings.shared.effectiveTranscriptionLanguage + let service = BackendTranscriptionService(language: language) + transcriptionService = service - do { - let language = AssistantSettings.shared.effectiveTranscriptionLanguage - let service = try TranscriptionService(language: language, channels: 1) - transcriptionService = service - - service.start( - onTranscript: { [weak self] segment in - Task { @MainActor in - self?.handleTranscript(segment) - } - }, - onError: { [weak self] error in - Task { @MainActor in - logError("PushToTalkManager: transcription error", error: error) - self?.stopListening() - } - }, - onConnected: { - Task { @MainActor in - log("PushToTalkManager: DeepGram connected") - } - } - ) - } catch { - logError("PushToTalkManager: failed to create TranscriptionService", error: error) - stopListening() + service.start( + 
onTranscript: { [weak self] segment in + Task { @MainActor in + self?.handleTranscript(segment) + } + }, + onError: { [weak self] error in + Task { @MainActor in + logError("PushToTalkManager: transcription error", error: error) + self?.stopListening() + } + }, + onConnected: { + Task { @MainActor in + log("PushToTalkManager: backend connected") + } } - } + ) } - private func startMicCapture(batchMode: Bool = false) { + private func startMicCapture() { if audioCaptureService == nil { audioCaptureService = AudioCaptureService() } @@ -475,20 +421,12 @@ class PushToTalkManager: ObservableObject { do { try await capture.startCapture( onAudioChunk: { [weak self] audioData in - guard let self else { return } - if batchMode { - // Batch mode: accumulate audio in buffer - self.batchAudioLock.lock() - self.batchAudioBuffer.append(audioData) - self.batchAudioLock.unlock() - } else { - // Live mode: stream to Deepgram - self.transcriptionService?.sendAudio(audioData) - } + // Stream mono audio to backend + self?.transcriptionService?.sendAudio(audioData) }, onAudioLevel: { _ in } ) - log("PushToTalkManager: mic capture started (batch=\(batchMode))") + log("PushToTalkManager: mic capture started") } catch { logError("PushToTalkManager: mic capture failed", error: error) self.stopListening() From 3824c4a2bc7cee958dfb918659ed78d906bbd9be Mon Sep 17 00:00:00 2001 From: beastoin Date: Fri, 6 Mar 2026 08:36:00 +0100 Subject: [PATCH 114/163] Remove old transcriptionService parameter from AudioSourceManager Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/Audio/AudioSourceManager.swift | 1 - 1 file changed, 1 deletion(-) diff --git a/desktop/Desktop/Sources/Audio/AudioSourceManager.swift b/desktop/Desktop/Sources/Audio/AudioSourceManager.swift index 9444be7d99..3247adc62c 100644 --- a/desktop/Desktop/Sources/Audio/AudioSourceManager.swift +++ b/desktop/Desktop/Sources/Audio/AudioSourceManager.swift @@ -301,7 +301,6 @@ final class AudioSourceManager: ObservableObject { // 
Start BLE audio processing with direct audio callback and WAL recording await bleAudioService.startProcessing( from: connection, - transcriptionService: nil, // We'll handle routing ourselves audioDataHandler: { [weak self] pcmData in // Convert decoded PCM mono to stereo and forward self?.handleBleAudio(pcmData) From 2152e699876b3281769236eb50e68d906b6c13ce Mon Sep 17 00:00:00 2001 From: beastoin Date: Fri, 6 Mar 2026 08:36:03 +0100 Subject: [PATCH 115/163] Remove DEEPGRAM_API_KEY from desktop .env.example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit No longer needed — STT now routes through backend /v4/listen. Co-Authored-By: Claude Opus 4.6 --- desktop/.env.example | 3 --- 1 file changed, 3 deletions(-) diff --git a/desktop/.env.example b/desktop/.env.example index 87d0a94fb7..6c25ff1689 100644 --- a/desktop/.env.example +++ b/desktop/.env.example @@ -17,9 +17,6 @@ # Production: https://api.omi.me OMI_API_URL=http://localhost:8080 -# DeepGram API key — required for real-time transcription -DEEPGRAM_API_KEY= - # ─── AI (optional) ────────────────────────────────────────────────── # Gemini API key for proactive assistants and embeddings # Falls back to backend-side processing if not set From e2a88573499ae071b37ed3a095875de6c11fd68c Mon Sep 17 00:00:00 2001 From: beastoin Date: Fri, 6 Mar 2026 08:36:06 +0100 Subject: [PATCH 116/163] Add changelog entry for backend STT migration Co-Authored-By: Claude Opus 4.6 --- desktop/CHANGELOG.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/desktop/CHANGELOG.json b/desktop/CHANGELOG.json index 614cda711d..e9ad191f41 100644 --- a/desktop/CHANGELOG.json +++ b/desktop/CHANGELOG.json @@ -1,5 +1,7 @@ { - "unreleased": [], + "unreleased": [ + "Removed client-side Deepgram API key — transcription now routes securely through the Omi backend" + ], "releases": [ { "version": "0.11.90", From 2e76c8ec865242db95bcfbf4569d6c6dae748a22 Mon Sep 17 00:00:00 2001 From: 
import logging
from typing import Optional

from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field

from database.goals import get_user_goals
from database.action_items import get_action_items
from database.memories import get_memories
from utils.llm.clients import llm_gemini_flash

logger = logging.getLogger(__name__)

# Match the desktop FocusAssistant's ScreenAnalysis schema
FOCUS_SYSTEM_PROMPT = """You are a focus coach. Analyze the PRIMARY/MAIN window in screenshots to determine \
if the user is focused or distracted.

IMPORTANT: Look at the MAIN APPLICATION WINDOW, not log text or terminal output. \
If you see a code editor with logs that mention "YouTube" - that's just log text, \
the user is CODING, not on YouTube. Text in logs/terminals mentioning a site does \
NOT mean the user is on that site.

CONTEXT-AWARE ANALYSIS:
Each request may include the user's active goals, current tasks, recent memories, \
and analysis history. Use this context when available, but DO NOT let it prevent you \
from flagging obvious distractions.

- GOALS & TASKS: If the user's screen activity clearly relates to their active \
goals or current tasks, they are FOCUSED.
- HISTORY: Use recent analysis history to notice patterns, acknowledge transitions, \
and vary your responses.

Set status to "distracted" if the PRIMARY window is:
- YouTube, Twitch, Netflix, TikTok (actual video site visible, not just text mentioning it)
- Social media feeds: Twitter/X, Instagram, Facebook, Reddit (casual browsing, not researching)
- News sites, entertainment sites, games
- Any content consumption with no clear work purpose

Set status to "focused" if the PRIMARY window is:
- Code editors, IDEs, terminals, command line
- Documents, spreadsheets, slides, design tools
- Email, work chat (Slack, Teams), research
- Browsing that is clearly work-related (Stack Overflow, docs, PRs, Jira, etc.)

When in doubt, lean toward "distracted" — it's better to nudge the user once too \
often than to silently let them drift.

Always provide a short coaching message (100 characters max for notification banner):
- If distracted: Create a unique nudge to refocus. Vary your approach — be playful, \
direct, or motivational.
- If focused: Acknowledge their work with variety — don't just say "Nice focus!" \
every time."""


class FocusResult(BaseModel):
    """Structured verdict produced by the vision model for one screen frame."""

    status: str = Field(description='Focus status: "focused" or "distracted"')
    app_or_site: str = Field(description="Primary app or site in focus")
    description: str = Field(description="Brief description of what the user is doing")
    message: Optional[str] = Field(default=None, description="Short coaching message (max 100 chars)")


def _build_context(uid: str) -> str:
    """Assemble the user's goals, tasks, and memories into a prompt context block (server-side)."""
    sections = []

    def _collect(header, noun, fetch, render):
        # Every source is best-effort: a storage failure only degrades the
        # context, it must never fail the analysis request itself.
        try:
            items = fetch()
            if items:
                sections.append(header + "\n" + "\n".join(render(item) for item in items))
        except Exception as e:
            logger.warning(f"Failed to fetch {noun} for context: {e}")

    # Goals (up to 10); fall back to the description when a goal has no title.
    _collect(
        "Active Goals:",
        "goals",
        lambda: get_user_goals(uid, limit=10),
        lambda g: f"- {g.get('title', g.get('description', ''))}",
    )
    # Open tasks (up to 50).
    _collect(
        "Current Tasks:",
        "tasks",
        lambda: get_action_items(uid, completed=False, limit=50)[:50],
        lambda t: f"- {t.get('description', '')}",
    )
    # Recent core memories (up to 20); fall back to raw content when unstructured.
    _collect(
        "Recent Memories:",
        "memories",
        lambda: get_memories(uid, limit=20, categories=['core'])[:20],
        lambda m: f"- {m.get('structured', {}).get('title', m.get('content', ''))}",
    )

    return "\n\n".join(sections)


async def analyze_focus(
    uid: str,
    image_b64: str,
    app_name: str = "",
    window_title: str = "",
    history: str = "",
) -> dict:
    """Analyze a screenshot for focus status using the vision LLM.

    Args:
        uid: User ID for fetching context (goals, tasks, memories).
        image_b64: Base64-encoded JPEG screenshot.
        app_name: Name of the foreground app.
        window_title: Window title.
        history: Formatted recent analysis history.

    Returns:
        Dict with status, app_or_site, description, message keys.
    """
    # Prompt sections, in fixed order, with empty ones dropped.
    sections = [
        part
        for part in (
            _build_context(uid),
            f"Recent activity (oldest to newest):\n{history}" if history else "",
            f"Current app: {app_name}, Window: {window_title}" if (app_name or window_title) else "",
        )
        if part
    ]
    sections.append("Now analyze this screenshot:")
    prompt_text = "\n\n".join(sections)

    # Ask the vision model for a schema-constrained FocusResult.
    structured_llm = llm_gemini_flash.with_structured_output(FocusResult)
    verdict = await structured_llm.ainvoke(
        [
            SystemMessage(content=FOCUS_SYSTEM_PROMPT),
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
                ]
            ),
        ]
    )

    return {
        "status": verdict.status,
        "app_or_site": verdict.app_or_site,
        "description": verdict.description,
        "message": verdict.message,
    }
"focus_result" + frame_id: str + status: str + app_or_site: str + description: str + message: Optional[str] = None + + def to_json(self): + j = self.model_dump(mode="json") + j["type"] = self.event_type + del j["event_type"] + return j From beb5f6e94d41905ce2f8129b87b8ec1f0441aa5b Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 05:14:01 +0100 Subject: [PATCH 119/163] Add screen_frame dispatcher to /v4/listen for desktop focus analysis (#5396) Co-Authored-By: Claude Opus 4.6 --- backend/routers/transcribe.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 24914a7702..696a34718f 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -52,6 +52,7 @@ ) from models.message_event import ( ConversationEvent, + FocusResultEvent, FREEMIUM_ACTION_SETUP_ON_DEVICE_STT, FreemiumThresholdReachedEvent, LastConversationEvent, @@ -101,6 +102,7 @@ SPEAKER_MATCH_THRESHOLD, ) from utils.speaker_sample_migration import maybe_migrate_person_samples +from utils.desktop.focus import analyze_focus from utils.log_sanitizer import sanitize, sanitize_pii logger = logging.getLogger(__name__) @@ -2431,6 +2433,38 @@ async def close_soniox_profile(): logger.info( f"Speaker assignment ignored: missing speaker_id/person_id/person_name. 
{uid} {session_id}" ) + # Desktop proactive AI — screen_frame analysis (#5396) + elif json_data.get('type') == 'screen_frame': + frame_id = json_data.get('frame_id', '') + image_b64 = json_data.get('image_b64', '') + analyze_types = json_data.get('analyze', []) + if image_b64 and 'focus' in analyze_types: + async def _handle_focus(fid, img, app, wtitle): + try: + result = await analyze_focus( + uid=uid, + image_b64=img, + app_name=app, + window_title=wtitle, + ) + _send_message_event(FocusResultEvent( + frame_id=fid, + status=result['status'], + app_or_site=result['app_or_site'], + description=result['description'], + message=result.get('message'), + )) + except Exception as focus_err: + logger.error(f"Focus analysis failed: {focus_err} {uid} {session_id}") + + spawn(_handle_focus( + frame_id, + image_b64, + json_data.get('app_name', ''), + json_data.get('window_title', ''), + )) + elif not image_b64: + logger.warning(f"screen_frame missing image_b64 {uid} {session_id}") except json.JSONDecodeError: logger.info( f"Received non-json text message: {sanitize(message.get('text'))} {uid} {session_id}" From e3c970df46eff6d1c403e67e52ceabd3e3f9d8ba Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 05:14:04 +0100 Subject: [PATCH 120/163] Add 26 unit tests for desktop focus analysis (#5396) Co-Authored-By: Claude Opus 4.6 --- backend/test.sh | 1 + backend/tests/unit/test_desktop_focus.py | 382 +++++++++++++++++++++++ 2 files changed, 383 insertions(+) create mode 100644 backend/tests/unit/test_desktop_focus.py diff --git a/backend/test.sh b/backend/test.sh index e431460b1e..d3c5640275 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -43,3 +43,4 @@ pytest tests/unit/test_storage_upload_audio_chunk_data_protection.py -v pytest tests/unit/test_people_conversations_500s.py -v pytest tests/unit/test_firestore_read_ops_cache.py -v pytest tests/unit/test_ws_auth_handshake.py -v +pytest tests/unit/test_desktop_focus.py -v diff --git 
"""Tests for desktop focus analysis (Phase 2 — #5396)."""

import asyncio
import sys
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Mock heavy dependencies before any project imports
sys.modules.setdefault('firebase_admin', MagicMock())
sys.modules.setdefault('firebase_admin.auth', MagicMock())
sys.modules.setdefault('firebase_admin.firestore', MagicMock())
sys.modules.setdefault('database._client', MagicMock())
_mock_clients = MagicMock()
sys.modules.setdefault('utils.llm.clients', _mock_clients)

# Now safe to import
from utils.desktop.focus import FocusResult, FOCUS_SYSTEM_PROMPT, _build_context

from models.message_event import FocusResultEvent

# NOTE: coroutines are driven with asyncio.run() rather than
# asyncio.get_event_loop().run_until_complete(); the latter is deprecated when
# no loop is running and stops creating loops on modern Python (3.12+).

# --- FocusResult model tests ---


class TestFocusResultModel:
    def test_focus_result_focused(self):
        result = FocusResult(
            status="focused",
            app_or_site="VS Code",
            description="Writing Python code",
            message="Great focus!",
        )
        assert result.status == "focused"
        assert result.app_or_site == "VS Code"
        assert result.description == "Writing Python code"
        assert result.message == "Great focus!"

    def test_focus_result_distracted(self):
        result = FocusResult(
            status="distracted",
            app_or_site="YouTube",
            description="Watching videos",
            message="Time to refocus!",
        )
        assert result.status == "distracted"
        assert result.app_or_site == "YouTube"

    def test_focus_result_message_optional(self):
        result = FocusResult(
            status="focused",
            app_or_site="Terminal",
            description="Running tests",
        )
        assert result.message is None

    def test_focus_result_message_none_explicit(self):
        result = FocusResult(
            status="focused",
            app_or_site="Terminal",
            description="Running tests",
            message=None,
        )
        assert result.message is None


# --- FocusResultEvent tests ---


class TestFocusResultEvent:
    def test_focus_result_event_to_json(self):
        event = FocusResultEvent(
            frame_id="abc-123",
            status="focused",
            app_or_site="VS Code",
            description="Writing code",
            message="Keep it up!",
        )
        j = event.to_json()
        assert j["type"] == "focus_result"
        assert j["frame_id"] == "abc-123"
        assert j["status"] == "focused"
        assert j["app_or_site"] == "VS Code"
        assert j["description"] == "Writing code"
        assert j["message"] == "Keep it up!"
        assert "event_type" not in j

    def test_focus_result_event_null_message(self):
        event = FocusResultEvent(
            frame_id="def-456",
            status="distracted",
            app_or_site="Twitter",
            description="Browsing feed",
        )
        j = event.to_json()
        assert j["type"] == "focus_result"
        assert j["message"] is None

    def test_focus_result_event_default_type(self):
        event = FocusResultEvent(
            frame_id="x",
            status="focused",
            app_or_site="Code",
            description="Working",
        )
        assert event.event_type == "focus_result"


# --- Context building tests ---


class TestBuildContext:
    @patch('utils.desktop.focus.get_memories', return_value=[])
    @patch('utils.desktop.focus.get_action_items', return_value=[])
    @patch('utils.desktop.focus.get_user_goals', return_value=[])
    def test_empty_context(self, mock_goals, mock_tasks, mock_memories):
        result = _build_context("test-uid")
        assert result == ""

    @patch('utils.desktop.focus.get_memories', return_value=[])
    @patch('utils.desktop.focus.get_action_items', return_value=[])
    @patch(
        'utils.desktop.focus.get_user_goals',
        return_value=[
            {"title": "Ship Phase 2"},
            {"title": "Learn Rust"},
        ],
    )
    def test_goals_in_context(self, mock_goals, mock_tasks, mock_memories):
        result = _build_context("test-uid")
        assert "Active Goals:" in result
        assert "Ship Phase 2" in result
        assert "Learn Rust" in result

    @patch('utils.desktop.focus.get_memories', return_value=[])
    @patch(
        'utils.desktop.focus.get_action_items',
        return_value=[
            {"description": "Fix login bug"},
            {"description": "Review PR #42"},
        ],
    )
    @patch('utils.desktop.focus.get_user_goals', return_value=[])
    def test_tasks_in_context(self, mock_goals, mock_tasks, mock_memories):
        result = _build_context("test-uid")
        assert "Current Tasks:" in result
        assert "Fix login bug" in result
        assert "Review PR #42" in result

    @patch(
        'utils.desktop.focus.get_memories',
        return_value=[
            {"structured": {"title": "Learned about WebSockets"}},
        ],
    )
    @patch('utils.desktop.focus.get_action_items', return_value=[])
    @patch('utils.desktop.focus.get_user_goals', return_value=[])
    def test_memories_in_context(self, mock_goals, mock_tasks, mock_memories):
        result = _build_context("test-uid")
        assert "Recent Memories:" in result
        assert "Learned about WebSockets" in result

    @patch('utils.desktop.focus.get_memories', side_effect=Exception("DB error"))
    @patch('utils.desktop.focus.get_action_items', side_effect=Exception("DB error"))
    @patch('utils.desktop.focus.get_user_goals', side_effect=Exception("DB error"))
    def test_context_graceful_on_errors(self, mock_goals, mock_tasks, mock_memories):
        result = _build_context("test-uid")
        assert result == ""

    @patch('utils.desktop.focus.get_memories', return_value=[])
    @patch('utils.desktop.focus.get_action_items', return_value=[])
    @patch(
        'utils.desktop.focus.get_user_goals',
        return_value=[
            {"description": "Goal without title"},
        ],
    )
    def test_goals_fallback_to_description(self, mock_goals, mock_tasks, mock_memories):
        result = _build_context("test-uid")
        assert "Goal without title" in result

    @patch(
        'utils.desktop.focus.get_memories',
        return_value=[
            {"content": "Memory without structured field"},
        ],
    )
    @patch('utils.desktop.focus.get_action_items', return_value=[])
    @patch('utils.desktop.focus.get_user_goals', return_value=[])
    def test_memories_fallback_to_content(self, mock_goals, mock_tasks, mock_memories):
        result = _build_context("test-uid")
        assert "Memory without structured field" in result


# --- analyze_focus integration tests ---


class TestAnalyzeFocus:
    @patch('utils.desktop.focus._build_context', return_value="")
    @patch('utils.desktop.focus.llm_gemini_flash')
    def test_analyze_focus_returns_result(self, mock_llm, mock_ctx):
        from utils.desktop.focus import analyze_focus

        mock_parser = MagicMock()
        mock_parser.ainvoke = AsyncMock(
            return_value=FocusResult(
                status="focused",
                app_or_site="VS Code",
                description="Editing Python",
                message="Nice work!",
            )
        )
        mock_llm.with_structured_output.return_value = mock_parser

        result = asyncio.run(
            analyze_focus(uid="test", image_b64="base64data", app_name="VS Code", window_title="main.py")
        )

        assert result["status"] == "focused"
        assert result["app_or_site"] == "VS Code"
        assert result["description"] == "Editing Python"
        assert result["message"] == "Nice work!"

    @patch('utils.desktop.focus._build_context', return_value="Active Goals:\n- Ship code")
    @patch('utils.desktop.focus.llm_gemini_flash')
    def test_analyze_focus_includes_context_in_prompt(self, mock_llm, mock_ctx):
        from utils.desktop.focus import analyze_focus

        mock_parser = MagicMock()
        mock_parser.ainvoke = AsyncMock(
            return_value=FocusResult(
                status="distracted",
                app_or_site="Twitter",
                description="Browsing",
            )
        )
        mock_llm.with_structured_output.return_value = mock_parser

        asyncio.run(analyze_focus(uid="test", image_b64="data"))

        call_args = mock_parser.ainvoke.call_args[0][0]
        human_msg = call_args[1]
        prompt_text = human_msg.content[0]["text"]
        assert "Active Goals:" in prompt_text

    @patch('utils.desktop.focus._build_context', return_value="")
    @patch('utils.desktop.focus.llm_gemini_flash')
    def test_analyze_focus_includes_history(self, mock_llm, mock_ctx):
        from utils.desktop.focus import analyze_focus

        mock_parser = MagicMock()
        mock_parser.ainvoke = AsyncMock(
            return_value=FocusResult(
                status="focused",
                app_or_site="Terminal",
                description="Running tests",
            )
        )
        mock_llm.with_structured_output.return_value = mock_parser

        asyncio.run(
            analyze_focus(
                uid="test",
                image_b64="data",
                history="1. [focused] VS Code: Writing code",
            )
        )

        call_args = mock_parser.ainvoke.call_args[0][0]
        human_msg = call_args[1]
        prompt_text = human_msg.content[0]["text"]
        assert "Recent activity" in prompt_text

    @patch('utils.desktop.focus._build_context', return_value="")
    @patch('utils.desktop.focus.llm_gemini_flash')
    def test_analyze_focus_includes_app_and_window(self, mock_llm, mock_ctx):
        from utils.desktop.focus import analyze_focus

        mock_parser = MagicMock()
        mock_parser.ainvoke = AsyncMock(
            return_value=FocusResult(
                status="focused",
                app_or_site="Safari",
                description="Reading docs",
            )
        )
        mock_llm.with_structured_output.return_value = mock_parser

        asyncio.run(
            analyze_focus(uid="test", image_b64="data", app_name="Safari", window_title="MDN Web Docs")
        )

        call_args = mock_parser.ainvoke.call_args[0][0]
        human_msg = call_args[1]
        prompt_text = human_msg.content[0]["text"]
        assert "Safari" in prompt_text
        assert "MDN Web Docs" in prompt_text

    @patch('utils.desktop.focus._build_context', return_value="")
    @patch('utils.desktop.focus.llm_gemini_flash')
    def test_analyze_focus_sends_image_as_base64(self, mock_llm, mock_ctx):
        from utils.desktop.focus import analyze_focus

        mock_parser = MagicMock()
        mock_parser.ainvoke = AsyncMock(
            return_value=FocusResult(
                status="focused",
                app_or_site="Code",
                description="Coding",
            )
        )
        mock_llm.with_structured_output.return_value = mock_parser

        asyncio.run(analyze_focus(uid="test", image_b64="FAKE_BASE64_IMAGE"))

        call_args = mock_parser.ainvoke.call_args[0][0]
        human_msg = call_args[1]
        image_part = human_msg.content[1]
        assert image_part["type"] == "image_url"
        assert "FAKE_BASE64_IMAGE" in image_part["image_url"]["url"]

    @patch('utils.desktop.focus._build_context', return_value="")
    @patch('utils.desktop.focus.llm_gemini_flash')
    def test_analyze_focus_sends_system_prompt(self, mock_llm, mock_ctx):
        from utils.desktop.focus import analyze_focus

        mock_parser = MagicMock()
        mock_parser.ainvoke = AsyncMock(
            return_value=FocusResult(
                status="focused",
                app_or_site="Code",
                description="Coding",
            )
        )
        mock_llm.with_structured_output.return_value = mock_parser

        asyncio.run(analyze_focus(uid="test", image_b64="data"))

        call_args = mock_parser.ainvoke.call_args[0][0]
        system_msg = call_args[0]
        assert FOCUS_SYSTEM_PROMPT in system_msg.content

    @patch('utils.desktop.focus._build_context', return_value="")
    @patch('utils.desktop.focus.llm_gemini_flash')
    def test_analyze_focus_distracted_result(self, mock_llm, mock_ctx):
        from utils.desktop.focus import analyze_focus

        mock_parser = MagicMock()
        mock_parser.ainvoke = AsyncMock(
            return_value=FocusResult(
                status="distracted",
                app_or_site="Reddit",
                description="Scrolling r/programming",
                message="Back to work!",
            )
        )
        mock_llm.with_structured_output.return_value = mock_parser

        result = asyncio.run(analyze_focus(uid="test", image_b64="data"))

        assert result["status"] == "distracted"
        assert result["app_or_site"] == "Reddit"
        assert result["message"] == "Back to work!"


# --- System prompt content tests ---


class TestFocusSystemPrompt:
    def test_prompt_includes_focused_criteria(self):
        assert "Code editors" in FOCUS_SYSTEM_PROMPT

    def test_prompt_includes_distracted_criteria(self):
        assert "YouTube" in FOCUS_SYSTEM_PROMPT
        assert "Twitter" in FOCUS_SYSTEM_PROMPT

    def test_prompt_warns_about_log_text(self):
        assert "log text" in FOCUS_SYSTEM_PROMPT

    def test_prompt_mentions_context_aware(self):
        assert "CONTEXT-AWARE" in FOCUS_SYSTEM_PROMPT

    def test_prompt_coaching_message_guidance(self):
        assert "100 characters max" in FOCUS_SYSTEM_PROMPT
import logging
from typing import List, Optional

from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field

from database.action_items import get_action_items
from utils.llm.clients import llm_gemini_flash

logger = logging.getLogger(__name__)

TASK_SYSTEM_PROMPT = """\
You are a task extraction assistant. Analyze screenshots to identify actionable tasks, \
requests, or to-dos visible on screen.

EXTRACTION RULES:
- Only extract tasks that are clearly visible and actionable
- Title must be 6+ words, verb-first, naming a specific person/project/artifact + concrete action
- Skip vague or generic items ("do something", "check this")
- ~90% of screenshots contain NO new task — use no_tasks when nothing actionable is found

DEDUPLICATION:
- Compare against the user's existing tasks provided in context
- If a task is semantically similar to an existing one (even with different wording), skip it
- "Call John" and "Phone John" are duplicates
- "Finish report by Friday" and "Complete report by end of week" are duplicates
- When in doubt, err on treating as duplicate (DON'T extract)

PRIORITY GUIDELINES:
- high: urgent deadlines, blocking requests, error fixes
- medium: normal work tasks, follow-ups
- low: nice-to-haves, ideas, non-urgent items

SOURCE CATEGORIES:
- direct_request: someone asked the user to do something (message, meeting, mention)
- self_generated: user's own idea, reminder, or goal subtask
- calendar_driven: event preparation, recurring task, deadline
- reactive: error response, notification, observation
- external_system: from project tools, alerts, documentation"""


class ExtractedTask(BaseModel):
    """One actionable task the model spotted on screen."""

    title: str = Field(description="Verb-first title, 6+ words, specific person/project + concrete action")
    description: str = Field(default="", description="Additional context if needed")
    priority: str = Field(description="high, medium, or low")
    tags: List[str] = Field(default_factory=list, description="1-3 relevant tags")
    source_app: str = Field(default="", description="App where task was found")
    inferred_deadline: Optional[str] = Field(default=None, description="yyyy-MM-dd format or null")
    confidence: float = Field(ge=0.0, le=1.0, description="Extraction confidence")
    source_category: str = Field(
        default="reactive", description="direct_request|self_generated|calendar_driven|reactive|external_system"
    )


class TaskExtractionResult(BaseModel):
    """Full structured response for one screenshot."""

    has_new_tasks: bool = Field(description="Whether any new tasks were found")
    tasks: List[ExtractedTask] = Field(default_factory=list, description="Extracted tasks (empty if none)")
    context_summary: str = Field(default="", description="Brief summary of what user is viewing")
    current_activity: str = Field(default="", description="What user is actively doing")


def _build_task_context(uid: str) -> str:
    """Render the user's existing action items so the model can skip duplicates."""
    sections = []

    # Active (not-yet-completed) tasks drive the main dedup check; failures
    # are logged and swallowed so extraction still runs without context.
    try:
        pending = get_action_items(uid, completed=False, limit=50)
        if pending:
            rendered = []
            for item in pending:
                desc = item.get('description', '')
                due = item.get('due_at', '')
                suffix = f" (Due: {due})" if due else ""
                rendered.append(f"- {desc}{suffix} [Pending]")
            sections.append("Existing active tasks (DO NOT extract duplicates):\n" + "\n".join(rendered))
    except Exception as e:
        logger.warning(f"Failed to fetch active tasks for dedup: {e}")

    # Recently finished tasks also count as duplicates.
    try:
        finished = get_action_items(uid, completed=True, limit=10)
        if finished:
            rendered = [f"- {item.get('description', '')} [Completed]" for item in finished[:10]]
            sections.append("Recently completed tasks:\n" + "\n".join(rendered))
    except Exception as e:
        logger.warning(f"Failed to fetch completed tasks: {e}")

    return "\n\n".join(sections)


async def extract_tasks(
    uid: str,
    image_b64: str,
    app_name: str = "",
    window_title: str = "",
) -> dict:
    """Extract tasks from a screenshot using the vision LLM.

    Args:
        uid: User ID for fetching existing tasks (dedup).
        image_b64: Base64-encoded JPEG screenshot.
        app_name: Name of the foreground app.
        window_title: Window title.

    Returns:
        Dict with has_new_tasks, tasks list, context_summary, current_activity.
    """
    # Existing tasks go into the prompt so the model performs the dedup.
    sections = [
        part
        for part in (
            _build_task_context(uid),
            f"Current app: {app_name}, Window: {window_title}" if (app_name or window_title) else "",
        )
        if part
    ]
    sections.append("Analyze this screenshot for actionable tasks:")
    prompt_text = "\n\n".join(sections)

    structured_llm = llm_gemini_flash.with_structured_output(TaskExtractionResult)
    result = await structured_llm.ainvoke(
        [
            SystemMessage(content=TASK_SYSTEM_PROMPT),
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
                ]
            ),
        ]
    )

    # Fall back to the foreground app when the model leaves source_app blank.
    serialized = [
        {
            "title": task.title,
            "description": task.description,
            "priority": task.priority,
            "tags": task.tags,
            "source_app": task.source_app or app_name,
            "inferred_deadline": task.inferred_deadline,
            "confidence": task.confidence,
            "source_category": task.source_category,
        }
        for task in result.tasks
    ]

    return {
        # Trust has_new_tasks only when the task list is actually non-empty.
        "has_new_tasks": bool(serialized) and result.has_new_tasks,
        "tasks": serialized,
        "context_summary": result.context_summary,
        "current_activity": result.current_activity,
    }
import logging
from typing import List, Optional

from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field

from database.memories import get_memories
from utils.llm.clients import llm_gemini_flash

logger = logging.getLogger(__name__)

MEMORY_SYSTEM_PROMPT = """\
You are a memory extraction assistant. Analyze screenshots to identify facts, insights, \
or noteworthy information worth remembering about the user or their context.

EXTRACTION RULES:
- Extract facts ABOUT the user: preferences, projects, people they work with, decisions, realizations
- Extract useful external information: advice, tips, insights from what they're reading
- Maximum 3 memories per screenshot
- Each memory should be a concise, standalone fact
- Skip trivial or transient information (UI state, loading screens, timestamps)
- ~80% of screenshots contain NO memorable information — return empty list when nothing stands out

DEDUPLICATION:
- Compare against existing memories provided in context
- If a fact is already known, skip it
- Only extract genuinely NEW information

CATEGORIES:
- system: Facts about the user (preferences, opinions, network, projects, habits)
- interesting: External wisdom or advice from others (articles, conversations, tips)"""


class ExtractedMemory(BaseModel):
    """One fact or insight worth remembering."""

    content: str = Field(description="Concise statement of the fact or insight")
    category: str = Field(description="system or interesting")
    confidence: float = Field(ge=0.0, le=1.0, description="Extraction confidence")


class MemoryExtractionResult(BaseModel):
    """Full structured response for one screenshot."""

    memories: List[ExtractedMemory] = Field(default_factory=list, description="Extracted memories (empty if none)")


def _build_memory_context(uid: str) -> str:
    """Render known memories so the model only proposes genuinely new facts."""
    # Best-effort: any storage or data-shape failure degrades to "no context".
    try:
        known = get_memories(uid, limit=30, categories=['system', 'interesting'])
        if known:
            rendered = []
            for item in known:
                # Prefer the structured content; fall back to the raw field.
                text = item.get('structured', {}).get('content', item.get('content', ''))
                if text:
                    rendered.append(f"- {text}")
            if rendered:
                return "Existing memories (DO NOT extract duplicates):\n" + "\n".join(rendered)
    except Exception as e:
        logger.warning(f"Failed to fetch existing memories: {e}")
    return ""


async def extract_memories(
    uid: str,
    image_b64: str,
    app_name: str = "",
    window_title: str = "",
) -> dict:
    """Extract memories from a screenshot using the vision LLM.

    Args:
        uid: User ID for fetching existing memories (dedup).
        image_b64: Base64-encoded JPEG screenshot.
        app_name: Name of the foreground app.
        window_title: Window title.

    Returns:
        Dict with a memories list (each has content, category, confidence).
    """
    sections = [
        part
        for part in (
            _build_memory_context(uid),
            f"Current app: {app_name}, Window: {window_title}" if (app_name or window_title) else "",
        )
        if part
    ]
    sections.append("Analyze this screenshot for noteworthy facts or insights:")
    prompt_text = "\n\n".join(sections)

    structured_llm = llm_gemini_flash.with_structured_output(MemoryExtractionResult)
    result = await structured_llm.ainvoke(
        [
            SystemMessage(content=MEMORY_SYSTEM_PROMPT),
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
                ]
            ),
        ]
    )

    serialized = [
        {"content": memory.content, "category": memory.category, "confidence": memory.confidence}
        for memory in result.memories
    ]
    return {"memories": serialized}
def _build_advice_context(uid: str) -> str:
    """Build user context (goals + open tasks) for advice generation.

    Both lookups are best-effort: a failing data source is logged and
    skipped so advice can still be generated from partial context.
    """
    parts = []

    try:
        goals = get_user_goals(uid, limit=5)
        if goals:
            goal_lines = [f"- {g.get('title', g.get('description', ''))}" for g in goals]
            parts.append("User's goals:\n" + "\n".join(goal_lines))
    except Exception as e:
        logger.warning(f"Failed to fetch goals for advice: {e}")

    try:
        tasks = get_action_items(uid, completed=False, limit=10)
        if tasks:
            task_lines = [f"- {t.get('description', '')}" for t in tasks[:10]]
            parts.append("Current tasks:\n" + "\n".join(task_lines))
    except Exception as e:
        logger.warning(f"Failed to fetch tasks for advice: {e}")

    return "\n\n".join(parts) if parts else ""


async def generate_advice(
    uid: str,
    image_b64: str,
    app_name: str = "",
    window_title: str = "",
) -> dict:
    """Generate contextual advice from a screenshot using a vision LLM.

    Args:
        uid: User id whose goals/tasks provide prompt context.
        image_b64: JPEG screenshot encoded as base64.
        app_name: Foreground application name, if known.
        window_title: Foreground window title, if known.

    Returns:
        Dict with has_advice and an advice dict (content, category,
        confidence), or {"has_advice": False, "advice": None} when the
        model has nothing useful to say.
    """
    advice_context = _build_advice_context(uid)

    prompt_parts = []
    if advice_context:
        prompt_parts.append(advice_context)
    if app_name or window_title:
        prompt_parts.append(f"Current app: {app_name}, Window: {window_title}")
    prompt_parts.append("Based on this screenshot, do you have any specific, actionable advice?")

    prompt_text = "\n\n".join(prompt_parts)

    with_parser = llm_gemini_flash.with_structured_output(AdviceResult)
    result = await with_parser.ainvoke(
        [
            SystemMessage(content=ADVICE_SYSTEM_PROMPT),
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
                ]
            ),
        ]
    )

    # The schema allows content to be null even when has_advice is set;
    # treat "advice with no text" as no advice rather than forwarding an
    # empty notification banner to the client.
    if not result.has_advice or not result.content:
        return {"has_advice": False, "advice": None}

    return {
        "has_advice": True,
        "advice": {
            "content": result.content,
            "category": result.category,
            "confidence": result.confidence,
        },
    }
async def generate_live_note(
    text: str,
    session_context: str = "",
) -> dict:
    """Generate a live note from transcript text.

    Args:
        text: Transcript text to summarize.
        session_context: Optional session context prepended to the prompt.

    Returns:
        Dict with a "text" field holding the note ("" when the transcript
        has no meaningful content).
    """
    # Short-circuit empty/whitespace transcripts: the system prompt would
    # only instruct the model to return "" anyway, so skip the LLM call.
    if not text or not text.strip():
        return {"text": ""}

    prompt_parts = []
    if session_context:
        prompt_parts.append(f"Session context: {session_context}")
    prompt_parts.append(f"Transcript:\n{text}")

    prompt_text = "\n\n".join(prompt_parts)

    with_parser = llm_mini.with_structured_output(LiveNoteResult)
    result = await with_parser.ainvoke(
        [
            SystemMessage(content=LIVE_NOTES_SYSTEM_PROMPT),
            HumanMessage(content=prompt_text),
        ]
    )

    return {"text": result.text}
async def generate_profile(uid: str) -> dict:
    """Generate a user profile from their goals, tasks, and memories.

    Each data source is fetched best-effort: a failure is logged and the
    remaining sources still contribute to the profile.

    Returns:
        Dict with profile_text.
    """
    sections = []

    try:
        goals = get_user_goals(uid, limit=10)
        if goals:
            bullets = "\n".join(f"- {g.get('title', g.get('description', ''))}" for g in goals)
            sections.append("Goals:\n" + bullets)
    except Exception as e:
        logger.warning(f"Failed to fetch goals for profile: {e}")

    try:
        open_tasks = get_action_items(uid, completed=False, limit=30)
        if open_tasks:
            bullets = "\n".join(f"- {t.get('description', '')}" for t in open_tasks[:30])
            sections.append("Active tasks:\n" + bullets)
    except Exception as e:
        logger.warning(f"Failed to fetch tasks for profile: {e}")

    try:
        stored = get_memories(uid, limit=30, categories=['system'])
        if stored:
            # Prefer structured content, falling back to the flat field.
            contents = (m.get('structured', {}).get('content', m.get('content', '')) for m in stored)
            fact_lines = [f"- {c}" for c in contents if c]
            if fact_lines:
                sections.append("Known facts:\n" + "\n".join(fact_lines))
    except Exception as e:
        logger.warning(f"Failed to fetch memories for profile: {e}")

    if not sections:
        return {"profile_text": "No data available to generate profile."}

    data_text = "\n\n".join(sections)

    structured_llm = llm_mini.with_structured_output(ProfileResult)
    result = await structured_llm.ainvoke(
        [
            SystemMessage(content=PROFILE_SYSTEM_PROMPT),
            HumanMessage(content=f"Generate a user profile from this data:\n\n{data_text}"),
        ]
    )

    return {"profile_text": result.profile_text}
async def rerank_tasks(uid: str) -> dict:
    """Rerank the user's active tasks by priority via the LLM.

    Returns:
        Dict with an "updated_tasks" list of {id, new_position} entries
        (empty when there are no tasks or the fetch fails).
    """
    try:
        active = get_action_items(uid, completed=False, limit=50)
    except Exception as e:
        logger.error(f"Failed to fetch tasks for reranking: {e}")
        return {"updated_tasks": []}

    if not active:
        return {"updated_tasks": []}

    def _describe(task: dict) -> str:
        # One prompt bullet per task: id, description, priority, optional due date.
        deadline = task.get('due_at', '')
        suffix = f", Due: {deadline}" if deadline else ""
        return (
            f"- ID: {task.get('id', '')} | {task.get('description', '')}"
            f" | Priority: {task.get('priority', 'medium')}{suffix}"
        )

    task_text = "\n".join(_describe(t) for t in active)

    structured_llm = llm_mini.with_structured_output(RerankResult)
    ranking = await structured_llm.ainvoke(
        [
            SystemMessage(content=RERANK_SYSTEM_PROMPT),
            HumanMessage(content=f"Rerank these tasks by importance:\n\n{task_text}"),
        ]
    )

    return {"updated_tasks": [{"id": r.id, "new_position": r.new_position} for r in ranking.updated_tasks]}
async def dedup_tasks(uid: str) -> dict:
    """Find and resolve duplicate tasks via the LLM.

    The model's output is sanitized before being returned: only IDs that
    actually exist among the fetched tasks are deleted, an ID kept by any
    group is never deleted, and each ID is reported at most once.

    Returns:
        Dict with deleted_ids and reason.
    """
    try:
        tasks = get_action_items(uid, completed=False, limit=100)
    except Exception as e:
        logger.error(f"Failed to fetch tasks for dedup: {e}")
        return {"deleted_ids": [], "reason": "Failed to fetch tasks"}

    if len(tasks) < 2:
        return {"deleted_ids": [], "reason": "Not enough tasks to deduplicate"}

    task_lines = []
    for t in tasks:
        tid = t.get('id', '')
        desc = t.get('description', '')
        due = t.get('due_at', '')
        created = t.get('created_at', '')
        due_str = f", Due: {due}" if due else ""
        created_str = f", Created: {created}" if created else ""
        task_lines.append(f"- ID: {tid} | {desc}{due_str}{created_str}")

    task_text = "\n".join(task_lines)

    with_parser = llm_mini.with_structured_output(DedupResult)
    result = await with_parser.ainvoke(
        [
            SystemMessage(content=DEDUP_SYSTEM_PROMPT),
            HumanMessage(content=f"Find duplicate tasks:\n\n{task_text}"),
        ]
    )

    # Sanitize the LLM output: it may hallucinate IDs, repeat an ID across
    # groups, or list the same ID as both kept and deleted.
    known_ids = {t.get('id', '') for t in tasks}
    kept_ids = {group.keep_id for group in result.groups}

    all_deleted = []
    seen = set()
    reasons = []
    for group in result.groups:
        for tid in group.delete_ids:
            if tid in known_ids and tid not in kept_ids and tid not in seen:
                seen.add(tid)
                all_deleted.append(tid)
        reasons.append(group.reason)

    return {
        "deleted_ids": all_deleted,
        "reason": "; ".join(reasons) if reasons else "No duplicates found",
    }
class DedupCompleteEvent(MessageEvent):
    """Socket event emitted after a task-deduplication pass completes."""

    event_type: str = "dedup_complete"
    deleted_ids: List = []
    reason: str = ""

    def to_json(self):
        # Clients expect the discriminator under "type", not "event_type".
        payload = self.model_dump(mode="json")
        payload["type"] = payload.pop("event_type")
        return payload
utils.log_sanitizer import sanitize, sanitize_pii logger = logging.getLogger(__name__) @@ -2438,33 +2451,93 @@ async def close_soniox_profile(): frame_id = json_data.get('frame_id', '') image_b64 = json_data.get('image_b64', '') analyze_types = json_data.get('analyze', []) - if image_b64 and 'focus' in analyze_types: - async def _handle_focus(fid, img, app, wtitle): - try: - result = await analyze_focus( - uid=uid, - image_b64=img, - app_name=app, - window_title=wtitle, - ) - _send_message_event(FocusResultEvent( - frame_id=fid, - status=result['status'], - app_or_site=result['app_or_site'], - description=result['description'], - message=result.get('message'), - )) - except Exception as focus_err: - logger.error(f"Focus analysis failed: {focus_err} {uid} {session_id}") - - spawn(_handle_focus( - frame_id, - image_b64, - json_data.get('app_name', ''), - json_data.get('window_title', ''), - )) - elif not image_b64: + sf_app = json_data.get('app_name', '') + sf_wtitle = json_data.get('window_title', '') + if not image_b64: logger.warning(f"screen_frame missing image_b64 {uid} {session_id}") + else: + # Fan out to parallel handlers per analyze type + if 'focus' in analyze_types: + async def _handle_focus(fid, img, app, wtitle): + try: + result = await analyze_focus(uid=uid, image_b64=img, app_name=app, window_title=wtitle) + _send_message_event(FocusResultEvent( + frame_id=fid, status=result['status'], app_or_site=result['app_or_site'], + description=result['description'], message=result.get('message'), + )) + except Exception as e: + logger.error(f"Focus analysis failed: {e} {uid} {session_id}") + spawn(_handle_focus(frame_id, image_b64, sf_app, sf_wtitle)) + + if 'tasks' in analyze_types: + async def _handle_tasks(fid, img, app, wtitle): + try: + result = await extract_tasks(uid=uid, image_b64=img, app_name=app, window_title=wtitle) + _send_message_event(TasksExtractedEvent(frame_id=fid, tasks=result.get('tasks', []))) + except Exception as e: + logger.error(f"Task 
extraction failed: {e} {uid} {session_id}") + spawn(_handle_tasks(frame_id, image_b64, sf_app, sf_wtitle)) + + if 'memories' in analyze_types: + async def _handle_memories(fid, img, app, wtitle): + try: + result = await extract_memories(uid=uid, image_b64=img, app_name=app, window_title=wtitle) + _send_message_event(MemoriesExtractedEvent(frame_id=fid, memories=result.get('memories', []))) + except Exception as e: + logger.error(f"Memory extraction failed: {e} {uid} {session_id}") + spawn(_handle_memories(frame_id, image_b64, sf_app, sf_wtitle)) + + if 'advice' in analyze_types: + async def _handle_advice(fid, img, app, wtitle): + try: + result = await generate_advice(uid=uid, image_b64=img, app_name=app, window_title=wtitle) + _send_message_event(AdviceExtractedEvent( + frame_id=fid, advice=result.get('advice'), + )) + except Exception as e: + logger.error(f"Advice generation failed: {e} {uid} {session_id}") + spawn(_handle_advice(frame_id, image_b64, sf_app, sf_wtitle)) + + # Desktop proactive AI — text-only message types (#5396) + elif json_data.get('type') == 'live_notes_text': + async def _handle_live_notes(text, ctx): + try: + result = await generate_live_note(text=text, session_context=ctx) + if result.get('text'): + _send_message_event(LiveNoteEvent(text=result['text'])) + except Exception as e: + logger.error(f"Live note generation failed: {e} {uid} {session_id}") + spawn(_handle_live_notes(json_data.get('text', ''), json_data.get('session_context', ''))) + + elif json_data.get('type') == 'profile_request': + async def _handle_profile(): + try: + result = await generate_profile(uid=uid) + _send_message_event(ProfileUpdatedEvent(profile_text=result['profile_text'])) + except Exception as e: + logger.error(f"Profile generation failed: {e} {uid} {session_id}") + spawn(_handle_profile()) + + elif json_data.get('type') == 'task_rerank': + async def _handle_rerank(): + try: + result = await rerank_tasks(uid=uid) + 
_send_message_event(RerankCompleteEvent(updated_tasks=result['updated_tasks'])) + except Exception as e: + logger.error(f"Task reranking failed: {e} {uid} {session_id}") + spawn(_handle_rerank()) + + elif json_data.get('type') == 'task_dedup': + async def _handle_dedup(): + try: + result = await dedup_tasks(uid=uid) + _send_message_event(DedupCompleteEvent( + deleted_ids=result['deleted_ids'], reason=result['reason'], + )) + except Exception as e: + logger.error(f"Task dedup failed: {e} {uid} {session_id}") + spawn(_handle_dedup()) + except json.JSONDecodeError: logger.info( f"Received non-json text message: {sanitize(message.get('text'))} {uid} {session_id}" From 4c5abcdc34b9ee180b653b6572b424988ddd6eaf Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:47 +0100 Subject: [PATCH 129/163] Add unit tests for task extraction handler (18 tests) --- backend/tests/unit/test_desktop_tasks.py | 238 +++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 backend/tests/unit/test_desktop_tasks.py diff --git a/backend/tests/unit/test_desktop_tasks.py b/backend/tests/unit/test_desktop_tasks.py new file mode 100644 index 0000000000..908b4d9a50 --- /dev/null +++ b/backend/tests/unit/test_desktop_tasks.py @@ -0,0 +1,238 @@ +"""Tests for desktop task extraction handler (Phase 2 — #5396).""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Mock heavy dependencies before any project imports +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.tasks import ( + ExtractedTask, + TaskExtractionResult, + TASK_SYSTEM_PROMPT, + _build_task_context, + extract_tasks, +) +from models.message_event import 
class TestExtractedTaskModel:
    """Schema tests for the ExtractedTask pydantic model."""

    def test_task_with_all_fields(self):
        task = ExtractedTask(
            title="Review pull request 42 for authentication changes",
            description="Check auth middleware",
            priority="high",
            tags=["code-review", "auth"],
            source_app="GitHub",
            inferred_deadline="2026-03-10",
            confidence=0.9,
            source_category="direct_request",
        )
        assert task.title == "Review pull request 42 for authentication changes"
        assert task.priority == "high"
        assert task.confidence == 0.9

    def test_task_defaults(self):
        task = ExtractedTask(
            title="Update the README with new API docs",
            priority="low",
            confidence=0.5,
        )
        assert task.description == ""
        assert task.tags == []
        assert task.source_app == ""
        assert task.inferred_deadline is None
        assert task.source_category == "reactive"

    def test_task_confidence_bounds(self):
        # Pydantic Field(ge=0.0, le=1.0) must reject out-of-range values.
        with pytest.raises(Exception):
            ExtractedTask(title="Test", priority="high", confidence=1.5)
        with pytest.raises(Exception):
            ExtractedTask(title="Test", priority="high", confidence=-0.1)


class TestTaskExtractionResult:
    """Schema tests for TaskExtractionResult."""

    def test_result_with_tasks(self):
        result = TaskExtractionResult(
            has_new_tasks=True,
            tasks=[
                ExtractedTask(title="Call John about the project deadline", priority="high", confidence=0.8),
            ],
            context_summary="Slack messages",
            current_activity="Reading messages",
        )
        assert result.has_new_tasks is True
        assert len(result.tasks) == 1

    def test_result_no_tasks(self):
        result = TaskExtractionResult(
            has_new_tasks=False,
            context_summary="IDE open",
            current_activity="Coding",
        )
        assert result.has_new_tasks is False
        assert result.tasks == []


class TestTasksExtractedEvent:
    """Wire-format tests for the tasks_extracted socket event."""

    def test_event_structure(self):
        event = TasksExtractedEvent(
            frame_id="frame123",
            tasks=[{"title": "Test task", "priority": "high"}],
        )
        data = event.to_json()
        assert data["type"] == "tasks_extracted"
        assert data["frame_id"] == "frame123"
        assert len(data["tasks"]) == 1


class TestBuildTaskContext:
    """Tests for the _build_task_context prompt-context helper."""

    @patch('utils.desktop.tasks.get_action_items')
    def test_active_tasks_in_context(self, mock_get):
        mock_get.return_value = [
            {'description': 'Write tests', 'due_at': '2026-03-10'},
            {'description': 'Fix bug'},
        ]
        ctx = _build_task_context("uid1")
        assert "Write tests" in ctx
        assert "Due: 2026-03-10" in ctx
        assert "Fix bug" in ctx
        assert "Pending" in ctx

    @patch('utils.desktop.tasks.get_action_items')
    def test_completed_tasks_in_context(self, mock_get):
        mock_get.side_effect = [
            [],  # active tasks
            [{'description': 'Done task'}],  # completed tasks
        ]
        ctx = _build_task_context("uid1")
        assert "Done task" in ctx
        assert "Completed" in ctx

    @patch('utils.desktop.tasks.get_action_items')
    def test_empty_context(self, mock_get):
        mock_get.return_value = []
        ctx = _build_task_context("uid1")
        assert ctx == ""

    @patch('utils.desktop.tasks.get_action_items')
    def test_graceful_on_errors(self, mock_get):
        mock_get.side_effect = Exception("DB error")
        ctx = _build_task_context("uid1")
        assert ctx == ""


class TestExtractTasks:
    """Behavior tests for the extract_tasks vision-LLM entry point.

    Uses asyncio.run rather than asyncio.get_event_loop().run_until_complete:
    get_event_loop() outside a running loop is deprecated since Python 3.10
    and fails under pytest on newer interpreters.
    """

    @patch('utils.desktop.tasks._build_task_context')
    @patch('utils.desktop.tasks.llm_gemini_flash')
    def test_extract_tasks_returns_result(self, mock_llm, mock_ctx):
        mock_ctx.return_value = ""
        mock_parser = MagicMock()
        mock_llm.with_structured_output.return_value = mock_parser
        mock_parser.ainvoke = AsyncMock(
            return_value=TaskExtractionResult(
                has_new_tasks=True,
                tasks=[
                    ExtractedTask(
                        title="Review pull request 42 for auth changes",
                        priority="high",
                        confidence=0.9,
                        source_app="GitHub",
                    )
                ],
                context_summary="GitHub PR page",
                current_activity="Reviewing code",
            )
        )
        result = asyncio.run(extract_tasks("uid1", "base64img", "Chrome", "GitHub PR"))
        assert result["has_new_tasks"] is True
        assert len(result["tasks"]) == 1
        assert result["tasks"][0]["title"] == "Review pull request 42 for auth changes"
        assert result["tasks"][0]["source_app"] == "GitHub"

    @patch('utils.desktop.tasks._build_task_context')
    @patch('utils.desktop.tasks.llm_gemini_flash')
    def test_extract_tasks_no_tasks(self, mock_llm, mock_ctx):
        mock_ctx.return_value = ""
        mock_parser = MagicMock()
        mock_llm.with_structured_output.return_value = mock_parser
        mock_parser.ainvoke = AsyncMock(
            return_value=TaskExtractionResult(
                has_new_tasks=False,
                context_summary="Desktop idle",
                current_activity="Nothing",
            )
        )
        result = asyncio.run(extract_tasks("uid1", "base64img"))
        assert result["has_new_tasks"] is False
        assert result["tasks"] == []

    @patch('utils.desktop.tasks._build_task_context')
    @patch('utils.desktop.tasks.llm_gemini_flash')
    def test_source_app_fallback(self, mock_llm, mock_ctx):
        mock_ctx.return_value = ""
        mock_parser = MagicMock()
        mock_llm.with_structured_output.return_value = mock_parser
        mock_parser.ainvoke = AsyncMock(
            return_value=TaskExtractionResult(
                has_new_tasks=True,
                tasks=[
                    ExtractedTask(
                        title="Send email to team about deadline update",
                        priority="medium",
                        confidence=0.7,
                        source_app="",  # empty
                    )
                ],
            )
        )
        result = asyncio.run(extract_tasks("uid1", "base64img", "Slack", "General"))
        # Falls back to app_name when source_app is empty
        assert result["tasks"][0]["source_app"] == "Slack"

    @patch('utils.desktop.tasks._build_task_context')
    @patch('utils.desktop.tasks.llm_gemini_flash')
    def test_includes_context_in_prompt(self, mock_llm, mock_ctx):
        mock_ctx.return_value = "Existing active tasks:\n- Write tests [Pending]"
        mock_parser = MagicMock()
        mock_llm.with_structured_output.return_value = mock_parser
        mock_parser.ainvoke = AsyncMock(
            return_value=TaskExtractionResult(has_new_tasks=False)
        )
        asyncio.run(extract_tasks("uid1", "base64img", "VS Code", "main.py"))
        call_args = mock_parser.ainvoke.call_args[0][0]
        human_msg = call_args[1]
        text_content = human_msg.content[0]["text"]
        assert "Write tests" in text_content
        assert "VS Code" in text_content


class TestTaskSystemPrompt:
    """Sanity checks that key instruction sections survive prompt edits."""

    def test_prompt_includes_dedup_rules(self):
        assert "DEDUPLICATION" in TASK_SYSTEM_PROMPT

    def test_prompt_includes_priority_guidelines(self):
        assert "high" in TASK_SYSTEM_PROMPT
        assert "medium" in TASK_SYSTEM_PROMPT
        assert "low" in TASK_SYSTEM_PROMPT

    def test_prompt_includes_source_categories(self):
        assert "direct_request" in TASK_SYSTEM_PROMPT
        assert "self_generated" in TASK_SYSTEM_PROMPT
        assert "calendar_driven" in TASK_SYSTEM_PROMPT
class TestExtractedMemoryModel:
    """Schema tests for the ExtractedMemory pydantic model."""

    def test_memory_all_fields(self):
        m = ExtractedMemory(content="User prefers dark mode", category="system", confidence=0.95)
        assert m.content == "User prefers dark mode"
        assert m.category == "system"
        assert m.confidence == 0.95

    def test_memory_interesting_category(self):
        m = ExtractedMemory(content="AI tip from article", category="interesting", confidence=0.7)
        assert m.category == "interesting"

    def test_confidence_bounds(self):
        # Pydantic Field(ge/le) must reject out-of-range confidence.
        with pytest.raises(Exception):
            ExtractedMemory(content="test", category="system", confidence=1.5)


class TestMemoryExtractionResult:
    """Schema tests for MemoryExtractionResult."""

    def test_result_with_memories(self):
        result = MemoryExtractionResult(
            memories=[ExtractedMemory(content="Fact 1", category="system", confidence=0.8)]
        )
        assert len(result.memories) == 1

    def test_result_empty(self):
        result = MemoryExtractionResult()
        assert result.memories == []


class TestMemoriesExtractedEvent:
    """Wire-format tests for the memories_extracted socket event."""

    def test_event_structure(self):
        event = MemoriesExtractedEvent(
            frame_id="frame456",
            memories=[{"content": "Test fact", "category": "system", "confidence": 0.9}],
        )
        data = event.to_json()
        assert data["type"] == "memories_extracted"
        assert data["frame_id"] == "frame456"
        assert len(data["memories"]) == 1


class TestBuildMemoryContext:
    """Tests for the _build_memory_context dedup-context helper."""

    @patch('utils.desktop.memories.get_memories')
    def test_existing_memories_in_context(self, mock_get):
        mock_get.return_value = [
            {'structured': {'content': 'User likes Python'}},
            {'content': 'Fallback content'},
        ]
        ctx = _build_memory_context("uid1")
        assert "User likes Python" in ctx
        assert "Fallback content" in ctx
        assert "DO NOT extract duplicates" in ctx

    @patch('utils.desktop.memories.get_memories')
    def test_empty_context(self, mock_get):
        mock_get.return_value = []
        ctx = _build_memory_context("uid1")
        assert ctx == ""

    @patch('utils.desktop.memories.get_memories')
    def test_graceful_on_errors(self, mock_get):
        mock_get.side_effect = Exception("DB error")
        ctx = _build_memory_context("uid1")
        assert ctx == ""


class TestExtractMemories:
    """Behavior tests for the extract_memories vision-LLM entry point.

    Uses asyncio.run rather than asyncio.get_event_loop().run_until_complete:
    get_event_loop() outside a running loop is deprecated since Python 3.10
    and fails under pytest on newer interpreters.
    """

    @patch('utils.desktop.memories._build_memory_context')
    @patch('utils.desktop.memories.llm_gemini_flash')
    def test_extract_returns_memories(self, mock_llm, mock_ctx):
        mock_ctx.return_value = ""
        mock_parser = MagicMock()
        mock_llm.with_structured_output.return_value = mock_parser
        mock_parser.ainvoke = AsyncMock(
            return_value=MemoryExtractionResult(
                memories=[
                    ExtractedMemory(content="User works on Omi project", category="system", confidence=0.85),
                ]
            )
        )
        result = asyncio.run(extract_memories("uid1", "base64img", "VS Code", "omi/main.py"))
        assert len(result["memories"]) == 1
        assert result["memories"][0]["content"] == "User works on Omi project"
        assert result["memories"][0]["category"] == "system"

    @patch('utils.desktop.memories._build_memory_context')
    @patch('utils.desktop.memories.llm_gemini_flash')
    def test_extract_empty_result(self, mock_llm, mock_ctx):
        mock_ctx.return_value = ""
        mock_parser = MagicMock()
        mock_llm.with_structured_output.return_value = mock_parser
        mock_parser.ainvoke = AsyncMock(return_value=MemoryExtractionResult())
        result = asyncio.run(extract_memories("uid1", "base64img"))
        assert result["memories"] == []

    @patch('utils.desktop.memories._build_memory_context')
    @patch('utils.desktop.memories.llm_gemini_flash')
    def test_sends_image_and_system_prompt(self, mock_llm, mock_ctx):
        mock_ctx.return_value = ""
        mock_parser = MagicMock()
        mock_llm.with_structured_output.return_value = mock_parser
        mock_parser.ainvoke = AsyncMock(return_value=MemoryExtractionResult())
        asyncio.run(extract_memories("uid1", "testimg64"))
        call_args = mock_parser.ainvoke.call_args[0][0]
        sys_msg = call_args[0]
        human_msg = call_args[1]
        assert MEMORY_SYSTEM_PROMPT in sys_msg.content
        assert human_msg.content[1]["image_url"]["url"] == "data:image/jpeg;base64,testimg64"


class TestMemorySystemPrompt:
    """Sanity checks that key instruction sections survive prompt edits."""

    def test_includes_extraction_rules(self):
        assert "EXTRACTION RULES" in MEMORY_SYSTEM_PROMPT

    def test_includes_dedup(self):
        assert "DEDUPLICATION" in MEMORY_SYSTEM_PROMPT

    def test_includes_categories(self):
        assert "system" in MEMORY_SYSTEM_PROMPT
        assert "interesting" in MEMORY_SYSTEM_PROMPT
test_no_advice(self): + r = AdviceResult(has_advice=False, confidence=0.1) + assert r.has_advice is False + assert r.content is None + assert r.category is None + + def test_confidence_bounds(self): + with pytest.raises(Exception): + AdviceResult(has_advice=True, confidence=2.0) + + +class TestAdviceExtractedEvent: + def test_event_with_advice(self): + event = AdviceExtractedEvent( + frame_id="frame789", + advice={"content": "Try dark mode", "category": "productivity", "confidence": 0.7}, + ) + data = event.to_json() + assert data["type"] == "advice_extracted" + assert data["frame_id"] == "frame789" + assert data["advice"]["content"] == "Try dark mode" + + def test_event_no_advice(self): + event = AdviceExtractedEvent(frame_id="frame789", advice=None) + data = event.to_json() + assert data["advice"] is None + + +class TestBuildAdviceContext: + @patch('utils.desktop.advice.get_action_items') + @patch('utils.desktop.advice.get_user_goals') + def test_goals_and_tasks_in_context(self, mock_goals, mock_tasks): + mock_goals.return_value = [{'title': 'Ship v2'}] + mock_tasks.return_value = [{'description': 'Write tests'}] + ctx = _build_advice_context("uid1") + assert "Ship v2" in ctx + assert "Write tests" in ctx + + @patch('utils.desktop.advice.get_action_items') + @patch('utils.desktop.advice.get_user_goals') + def test_empty_context(self, mock_goals, mock_tasks): + mock_goals.return_value = [] + mock_tasks.return_value = [] + ctx = _build_advice_context("uid1") + assert ctx == "" + + @patch('utils.desktop.advice.get_action_items') + @patch('utils.desktop.advice.get_user_goals') + def test_graceful_on_errors(self, mock_goals, mock_tasks): + mock_goals.side_effect = Exception("DB error") + mock_tasks.side_effect = Exception("DB error") + ctx = _build_advice_context("uid1") + assert ctx == "" + + @patch('utils.desktop.advice.get_action_items') + @patch('utils.desktop.advice.get_user_goals') + def test_goals_fallback_to_description(self, mock_goals, mock_tasks): + 
mock_goals.return_value = [{'description': 'Fallback goal'}] + mock_tasks.return_value = [] + ctx = _build_advice_context("uid1") + assert "Fallback goal" in ctx + + +class TestGenerateAdvice: + @patch('utils.desktop.advice._build_advice_context') + @patch('utils.desktop.advice.llm_gemini_flash') + def test_returns_advice(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=AdviceResult( + has_advice=True, + content="Consider using a linter", + category="productivity", + confidence=0.75, + ) + ) + result = asyncio.get_event_loop().run_until_complete( + generate_advice("uid1", "base64img", "VS Code", "main.py") + ) + assert result["has_advice"] is True + assert result["advice"]["content"] == "Consider using a linter" + assert result["advice"]["category"] == "productivity" + + @patch('utils.desktop.advice._build_advice_context') + @patch('utils.desktop.advice.llm_gemini_flash') + def test_no_advice(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=AdviceResult(has_advice=False, confidence=0.1) + ) + result = asyncio.get_event_loop().run_until_complete( + generate_advice("uid1", "base64img") + ) + assert result["has_advice"] is False + assert result["advice"] is None + + @patch('utils.desktop.advice._build_advice_context') + @patch('utils.desktop.advice.llm_gemini_flash') + def test_includes_app_info(self, mock_llm, mock_ctx): + mock_ctx.return_value = "" + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=AdviceResult(has_advice=False, confidence=0.1) + ) + asyncio.get_event_loop().run_until_complete( + generate_advice("uid1", "base64img", "Chrome", "Stack Overflow") + ) + call_args = 
mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + text_content = human_msg.content[0]["text"] + assert "Chrome" in text_content + assert "Stack Overflow" in text_content + + +class TestAdviceSystemPrompt: + def test_includes_categories(self): + assert "productivity" in ADVICE_SYSTEM_PROMPT + assert "mistake_prevention" in ADVICE_SYSTEM_PROMPT + assert "health" in ADVICE_SYSTEM_PROMPT + assert "goal_alignment" in ADVICE_SYSTEM_PROMPT + + def test_includes_tone_guidance(self): + assert "TONE" in ADVICE_SYSTEM_PROMPT From be0a3b229949acc034f788313fb95de6717ef720 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:51 +0100 Subject: [PATCH 132/163] Add unit tests for live notes handler (10 tests) --- backend/tests/unit/test_desktop_live_notes.py | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 backend/tests/unit/test_desktop_live_notes.py diff --git a/backend/tests/unit/test_desktop_live_notes.py b/backend/tests/unit/test_desktop_live_notes.py new file mode 100644 index 0000000000..7969427ce2 --- /dev/null +++ b/backend/tests/unit/test_desktop_live_notes.py @@ -0,0 +1,99 @@ +"""Tests for desktop live notes handler (Phase 2 — #5396).""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.live_notes import ( + LiveNoteResult, + LIVE_NOTES_SYSTEM_PROMPT, + generate_live_note, +) +from models.message_event import LiveNoteEvent + + +class TestLiveNoteResultModel: + def test_note_with_text(self): + r = LiveNoteResult(text="Key decision: ship by Friday") + assert r.text == "Key decision: ship by Friday" + + def 
test_empty_note(self): + r = LiveNoteResult(text="") + assert r.text == "" + + +class TestLiveNoteEvent: + def test_event_structure(self): + event = LiveNoteEvent(text="Meeting note content") + data = event.to_json() + assert data["type"] == "live_note" + assert data["text"] == "Meeting note content" + + +class TestGenerateLiveNote: + @patch('utils.desktop.live_notes.llm_mini') + def test_returns_note(self, mock_llm): + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=LiveNoteResult(text="- Decision: use Redis for caching") + ) + result = asyncio.get_event_loop().run_until_complete( + generate_live_note("We decided to use Redis for caching the API responses") + ) + assert result["text"] == "- Decision: use Redis for caching" + + @patch('utils.desktop.live_notes.llm_mini') + def test_empty_result(self, mock_llm): + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=LiveNoteResult(text="")) + result = asyncio.get_event_loop().run_until_complete( + generate_live_note("um yeah so like um") + ) + assert result["text"] == "" + + @patch('utils.desktop.live_notes.llm_mini') + def test_includes_session_context(self, mock_llm): + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=LiveNoteResult(text="note")) + asyncio.get_event_loop().run_until_complete( + generate_live_note("transcript text", session_context="Sprint planning") + ) + call_args = mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + assert "Sprint planning" in human_msg.content + + @patch('utils.desktop.live_notes.llm_mini') + def test_sends_system_prompt(self, mock_llm): + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=LiveNoteResult(text="")) + 
asyncio.get_event_loop().run_until_complete( + generate_live_note("test text") + ) + call_args = mock_parser.ainvoke.call_args[0][0] + sys_msg = call_args[0] + assert LIVE_NOTES_SYSTEM_PROMPT in sys_msg.content + + +class TestLiveNotesSystemPrompt: + def test_includes_condensation_rules(self): + assert "Condense" in LIVE_NOTES_SYSTEM_PROMPT + + def test_includes_word_limit(self): + assert "200 words" in LIVE_NOTES_SYSTEM_PROMPT + + def test_includes_preservation_rules(self): + assert "names" in LIVE_NOTES_SYSTEM_PROMPT + assert "decisions" in LIVE_NOTES_SYSTEM_PROMPT From 4197646cc84a77475f7c5415dd401fd03ebdcfc7 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:52 +0100 Subject: [PATCH 133/163] Add unit tests for profile handler (9 tests) --- backend/tests/unit/test_desktop_profile.py | 103 +++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 backend/tests/unit/test_desktop_profile.py diff --git a/backend/tests/unit/test_desktop_profile.py b/backend/tests/unit/test_desktop_profile.py new file mode 100644 index 0000000000..ea6fea4234 --- /dev/null +++ b/backend/tests/unit/test_desktop_profile.py @@ -0,0 +1,103 @@ +"""Tests for desktop profile generation handler (Phase 2 — #5396).""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.profile import ( + ProfileResult, + PROFILE_SYSTEM_PROMPT, + generate_profile, +) +from models.message_event import ProfileUpdatedEvent + + +class TestProfileResultModel: + def test_profile_text(self): + r = ProfileResult(profile_text="The user is a backend engineer focused on Python.") + assert "backend 
engineer" in r.profile_text + + +class TestProfileUpdatedEvent: + def test_event_structure(self): + event = ProfileUpdatedEvent(profile_text="User profile text") + data = event.to_json() + assert data["type"] == "profile_updated" + assert data["profile_text"] == "User profile text" + + +class TestGenerateProfile: + @patch('utils.desktop.profile.get_memories') + @patch('utils.desktop.profile.get_action_items') + @patch('utils.desktop.profile.get_user_goals') + @patch('utils.desktop.profile.llm_mini') + def test_generates_profile(self, mock_llm, mock_goals, mock_tasks, mock_memories): + mock_goals.return_value = [{'title': 'Ship v2'}] + mock_tasks.return_value = [{'description': 'Fix auth bug'}] + mock_memories.return_value = [{'structured': {'content': 'User prefers Python'}}] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=ProfileResult(profile_text="The user is a developer focused on shipping v2.") + ) + result = asyncio.get_event_loop().run_until_complete(generate_profile("uid1")) + assert "developer" in result["profile_text"] + + @patch('utils.desktop.profile.get_memories') + @patch('utils.desktop.profile.get_action_items') + @patch('utils.desktop.profile.get_user_goals') + def test_no_data_returns_default(self, mock_goals, mock_tasks, mock_memories): + mock_goals.return_value = [] + mock_tasks.return_value = [] + mock_memories.return_value = [] + result = asyncio.get_event_loop().run_until_complete(generate_profile("uid1")) + assert "No data available" in result["profile_text"] + + @patch('utils.desktop.profile.get_memories') + @patch('utils.desktop.profile.get_action_items') + @patch('utils.desktop.profile.get_user_goals') + @patch('utils.desktop.profile.llm_mini') + def test_graceful_on_db_errors(self, mock_llm, mock_goals, mock_tasks, mock_memories): + mock_goals.side_effect = Exception("DB error") + mock_tasks.side_effect = Exception("DB error") + 
mock_memories.side_effect = Exception("DB error") + result = asyncio.get_event_loop().run_until_complete(generate_profile("uid1")) + assert "No data available" in result["profile_text"] + + @patch('utils.desktop.profile.get_memories') + @patch('utils.desktop.profile.get_action_items') + @patch('utils.desktop.profile.get_user_goals') + @patch('utils.desktop.profile.llm_mini') + def test_includes_goals_in_prompt(self, mock_llm, mock_goals, mock_tasks, mock_memories): + mock_goals.return_value = [{'title': 'Learn Rust'}] + mock_tasks.return_value = [] + mock_memories.return_value = [] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=ProfileResult(profile_text="Profile text") + ) + asyncio.get_event_loop().run_until_complete(generate_profile("uid1")) + call_args = mock_parser.ainvoke.call_args[0][0] + human_msg = call_args[1] + assert "Learn Rust" in human_msg.content + + +class TestProfileSystemPrompt: + def test_third_person_format(self): + assert "third person" in PROFILE_SYSTEM_PROMPT + + def test_word_limit(self): + assert "300 words" in PROFILE_SYSTEM_PROMPT + + def test_factual_requirement(self): + assert "factual" in PROFILE_SYSTEM_PROMPT From daf72d03afedeb3e80217430a2bbb7dce626205b Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:25:52 +0100 Subject: [PATCH 134/163] Add unit tests for task rerank and dedup handlers (16 tests) --- backend/tests/unit/test_desktop_task_ops.py | 175 ++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 backend/tests/unit/test_desktop_task_ops.py diff --git a/backend/tests/unit/test_desktop_task_ops.py b/backend/tests/unit/test_desktop_task_ops.py new file mode 100644 index 0000000000..6cd9803a5f --- /dev/null +++ b/backend/tests/unit/test_desktop_task_ops.py @@ -0,0 +1,175 @@ +"""Tests for desktop task operations (rerank + dedup) handlers (Phase 2 — #5396).""" + +import asyncio +import sys +from 
unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.modules.setdefault('firebase_admin', MagicMock()) +sys.modules.setdefault('firebase_admin.auth', MagicMock()) +sys.modules.setdefault('firebase_admin.firestore', MagicMock()) +sys.modules.setdefault('database._client', MagicMock()) +_mock_clients = MagicMock() +sys.modules.setdefault('utils.llm.clients', _mock_clients) + +from utils.desktop.task_ops import ( + RankedTask, + RerankResult, + DedupGroup, + DedupResult, + RERANK_SYSTEM_PROMPT, + DEDUP_SYSTEM_PROMPT, + rerank_tasks, + dedup_tasks, +) +from models.message_event import RerankCompleteEvent, DedupCompleteEvent + + +# --- Rerank tests --- + + +class TestRankedTaskModel: + def test_ranked_task(self): + t = RankedTask(id="task1", new_position=1) + assert t.id == "task1" + assert t.new_position == 1 + + +class TestRerankResult: + def test_rerank_result(self): + r = RerankResult(updated_tasks=[RankedTask(id="t1", new_position=1)]) + assert len(r.updated_tasks) == 1 + + +class TestRerankCompleteEvent: + def test_event_structure(self): + event = RerankCompleteEvent(updated_tasks=[{"id": "t1", "new_position": 1}]) + data = event.to_json() + assert data["type"] == "rerank_complete" + assert len(data["updated_tasks"]) == 1 + + +class TestRerankTasks: + @patch('utils.desktop.task_ops.get_action_items') + @patch('utils.desktop.task_ops.llm_mini') + def test_rerank_returns_order(self, mock_llm, mock_get): + mock_get.return_value = [ + {'id': 't1', 'description': 'Low priority', 'priority': 'low'}, + {'id': 't2', 'description': 'Urgent fix', 'priority': 'high', 'due_at': '2026-03-08'}, + ] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=RerankResult( + updated_tasks=[ + RankedTask(id="t2", new_position=1), + RankedTask(id="t1", new_position=2), + ] + ) + ) + result = asyncio.get_event_loop().run_until_complete(rerank_tasks("uid1")) + assert 
result["updated_tasks"][0]["id"] == "t2" + assert result["updated_tasks"][0]["new_position"] == 1 + + @patch('utils.desktop.task_ops.get_action_items') + def test_rerank_empty_tasks(self, mock_get): + mock_get.return_value = [] + result = asyncio.get_event_loop().run_until_complete(rerank_tasks("uid1")) + assert result["updated_tasks"] == [] + + @patch('utils.desktop.task_ops.get_action_items') + def test_rerank_db_error(self, mock_get): + mock_get.side_effect = Exception("DB error") + result = asyncio.get_event_loop().run_until_complete(rerank_tasks("uid1")) + assert result["updated_tasks"] == [] + + +# --- Dedup tests --- + + +class TestDedupGroupModel: + def test_dedup_group(self): + g = DedupGroup(keep_id="t1", delete_ids=["t2", "t3"], reason="Same task") + assert g.keep_id == "t1" + assert len(g.delete_ids) == 2 + + +class TestDedupResult: + def test_dedup_with_groups(self): + r = DedupResult(groups=[DedupGroup(keep_id="t1", delete_ids=["t2"], reason="Duplicate")]) + assert len(r.groups) == 1 + + def test_dedup_no_groups(self): + r = DedupResult() + assert r.groups == [] + + +class TestDedupCompleteEvent: + def test_event_structure(self): + event = DedupCompleteEvent(deleted_ids=["t2", "t3"], reason="Duplicate tasks") + data = event.to_json() + assert data["type"] == "dedup_complete" + assert data["deleted_ids"] == ["t2", "t3"] + assert data["reason"] == "Duplicate tasks" + + +class TestDedupTasks: + @patch('utils.desktop.task_ops.get_action_items') + @patch('utils.desktop.task_ops.llm_mini') + def test_dedup_finds_duplicates(self, mock_llm, mock_get): + mock_get.return_value = [ + {'id': 't1', 'description': 'Call John'}, + {'id': 't2', 'description': 'Phone John'}, + {'id': 't3', 'description': 'Write report'}, + ] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock( + return_value=DedupResult( + groups=[DedupGroup(keep_id="t1", delete_ids=["t2"], reason="Same action: contact John")] + ) 
+ ) + result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1")) + assert result["deleted_ids"] == ["t2"] + assert "contact John" in result["reason"] + + @patch('utils.desktop.task_ops.get_action_items') + @patch('utils.desktop.task_ops.llm_mini') + def test_dedup_no_duplicates(self, mock_llm, mock_get): + mock_get.return_value = [ + {'id': 't1', 'description': 'Task A'}, + {'id': 't2', 'description': 'Task B'}, + ] + mock_parser = MagicMock() + mock_llm.with_structured_output.return_value = mock_parser + mock_parser.ainvoke = AsyncMock(return_value=DedupResult()) + result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1")) + assert result["deleted_ids"] == [] + assert result["reason"] == "No duplicates found" + + @patch('utils.desktop.task_ops.get_action_items') + def test_dedup_too_few_tasks(self, mock_get): + mock_get.return_value = [{'id': 't1', 'description': 'Only one'}] + result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1")) + assert result["deleted_ids"] == [] + assert "Not enough" in result["reason"] + + @patch('utils.desktop.task_ops.get_action_items') + def test_dedup_db_error(self, mock_get): + mock_get.side_effect = Exception("DB error") + result = asyncio.get_event_loop().run_until_complete(dedup_tasks("uid1")) + assert result["deleted_ids"] == [] + assert "Failed" in result["reason"] + + +class TestRerankSystemPrompt: + def test_includes_rules(self): + assert "RULES" in RERANK_SYSTEM_PROMPT + assert "deadlines" in RERANK_SYSTEM_PROMPT + + +class TestDedupSystemPrompt: + def test_includes_rules(self): + assert "RULES" in DEDUP_SYSTEM_PROMPT + assert "duplicates" in DEDUP_SYSTEM_PROMPT.lower() From 77da192b18f17a19be2922ae6fafbd082b91212b Mon Sep 17 00:00:00 2001 From: beastoin Date: Sat, 7 Mar 2026 06:26:02 +0100 Subject: [PATCH 135/163] Add all desktop handler tests to test.sh --- backend/test.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backend/test.sh b/backend/test.sh index 
d3c5640275..af5b8e9e7a 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -44,3 +44,9 @@ pytest tests/unit/test_people_conversations_500s.py -v pytest tests/unit/test_firestore_read_ops_cache.py -v pytest tests/unit/test_ws_auth_handshake.py -v pytest tests/unit/test_desktop_focus.py -v +pytest tests/unit/test_desktop_tasks.py -v +pytest tests/unit/test_desktop_memories.py -v +pytest tests/unit/test_desktop_advice.py -v +pytest tests/unit/test_desktop_live_notes.py -v +pytest tests/unit/test_desktop_profile.py -v +pytest tests/unit/test_desktop_task_ops.py -v From 31b8100810b4719539fef6fdab69cd17c0eca7c5 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 08:57:19 +0100 Subject: [PATCH 136/163] Add BackendProactiveService for server-side proactive AI (#5396) WebSocket client that connects to /v4/listen with Bearer auth and sends screen_frame JSON messages. Routes focus_result responses back to callers via async continuations with frame_id correlation. Co-Authored-By: Claude Opus 4.6 --- .../Core/BackendProactiveService.swift | 358 ++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift new file mode 100644 index 0000000000..e9b69ffe3b --- /dev/null +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift @@ -0,0 +1,358 @@ +import Foundation + +/// WebSocket client for desktop proactive AI via /v4/listen. +/// Sends typed JSON messages (screen_frame, etc.) and routes typed responses +/// (focus_result, etc.) back to callers via async continuations. +/// +/// This is the Phase 2 replacement for direct GeminiClient calls — all LLM +/// processing happens server-side; the client just sends screenshots and +/// receives structured results. 
+class BackendProactiveService { + + // MARK: - Types + + enum ServiceError: LocalizedError { + case missingAPIURL + case authFailed(String) + case notConnected + case timeout + case serverError(String) + + var errorDescription: String? { + switch self { + case .missingAPIURL: return "OMI_API_URL not set" + case .authFailed(let reason): return "Auth failed: \(reason)" + case .notConnected: return "Backend WebSocket not connected" + case .timeout: return "Request timed out" + case .serverError(let msg): return "Server error: \(msg)" + } + } + } + + // MARK: - Properties + + private var webSocketTask: URLSessionWebSocketTask? + private var urlSession: URLSession? + private(set) var isConnected = false + private var shouldReconnect = false + private var reconnectAttempts = 0 + private let maxReconnectAttempts = 10 + private var reconnectTask: Task? + + // Keepalive + private var keepaliveTask: Task? + private let keepaliveInterval: TimeInterval = 30.0 + + // Pending request continuations keyed by frame_id + private var pendingFocusRequests: [String: CheckedContinuation] = [:] + private let requestLock = NSLock() + private let requestTimeout: TimeInterval = 30.0 + + // MARK: - Connection + + func connect() { + shouldReconnect = true + reconnectAttempts = 0 + startConnect() + } + + func disconnect() { + shouldReconnect = false + reconnectTask?.cancel() + reconnectTask = nil + keepaliveTask?.cancel() + keepaliveTask = nil + + isConnected = false + webSocketTask?.cancel(with: .normalClosure, reason: nil) + webSocketTask = nil + urlSession?.invalidateAndCancel() + urlSession = nil + + // Cancel all pending requests + cancelAllPending(error: ServiceError.notConnected) + + log("BackendProactiveService: Disconnected") + } + + // MARK: - Public API + + /// Send a screen_frame for focus analysis and wait for the focus_result response. 
+ func analyzeFocus( + imageBase64: String, + appName: String, + windowTitle: String + ) async throws -> ScreenAnalysis { + guard isConnected else { + throw ServiceError.notConnected + } + + let frameId = UUID().uuidString + + let message: [String: Any] = [ + "type": "screen_frame", + "frame_id": frameId, + "image_b64": imageBase64, + "app_name": appName, + "window_title": windowTitle, + "analyze": ["focus"], + ] + + let jsonData = try JSONSerialization.data(withJSONObject: message) + guard let jsonString = String(data: jsonData, encoding: .utf8) else { + throw ServiceError.serverError("Failed to encode message") + } + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingFocusRequests[frameId] = continuation + requestLock.unlock() + + webSocketTask?.send(.string(jsonString)) { [weak self] error in + if let error = error { + self?.requestLock.lock() + let cont = self?.pendingFocusRequests.removeValue(forKey: frameId) + self?.requestLock.unlock() + cont?.resume(throwing: error) + } + } + + // Timeout guard + Task { [weak self] in + try? await Task.sleep(nanoseconds: UInt64((self?.requestTimeout ?? 
30.0) * 1_000_000_000)) + self?.requestLock.lock() + let cont = self?.pendingFocusRequests.removeValue(forKey: frameId) + self?.requestLock.unlock() + cont?.resume(throwing: ServiceError.timeout) + } + } + } + + // MARK: - Connection Internals + + private func startConnect() { + guard let baseURL = Self.getBaseURL() else { + log("BackendProactiveService: OMI_API_URL not set") + return + } + + Task { + do { + let idToken = try await AuthService.shared.getIdToken() + await connectWithToken(baseURL: baseURL, token: idToken) + } catch { + logError("BackendProactiveService: Failed to get ID token", error: error) + handleDisconnection() + } + } + } + + private func connectWithToken(baseURL: String, token: String) async { + let wsURL = baseURL + .replacingOccurrences(of: "https://", with: "wss://") + .replacingOccurrences(of: "http://", with: "ws://") + let base = wsURL.hasSuffix("/") ? wsURL : wsURL + "/" + + // Connect to /v4/listen with source=desktop — same endpoint as audio, + // but we only send JSON messages (no audio data) + var components = URLComponents(string: "\(base)v4/listen")! 
+ components.queryItems = [ + URLQueryItem(name: "source", value: "desktop"), + URLQueryItem(name: "sample_rate", value: "16000"), + URLQueryItem(name: "codec", value: "pcm16"), + URLQueryItem(name: "channels", value: "1"), + URLQueryItem(name: "language", value: "en"), + ] + + guard let url = components.url else { + log("BackendProactiveService: Invalid URL") + return + } + + log("BackendProactiveService: Connecting to \(url.absoluteString)") + + var request = URLRequest(url: url) + request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization") + request.timeoutInterval = 30 + + let configuration = URLSessionConfiguration.default + configuration.timeoutIntervalForResource = 0 + urlSession = URLSession(configuration: configuration) + webSocketTask = urlSession?.webSocketTask(with: request) + webSocketTask?.resume() + + receiveMessage() + + // Confirm connection after short delay + DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in + guard let self = self, self.webSocketTask?.state == .running else { + self?.handleDisconnection() + return + } + self.isConnected = true + self.reconnectAttempts = 0 + self.startKeepalive() + log("BackendProactiveService: Connected") + } + } + + private func startKeepalive() { + keepaliveTask?.cancel() + keepaliveTask = Task { [weak self] in + while !Task.isCancelled { + try? await Task.sleep(nanoseconds: UInt64((self?.keepaliveInterval ?? 
30.0) * 1_000_000_000)) + guard !Task.isCancelled, let self = self, self.isConnected else { break } + self.sendKeepalive() + } + } + } + + private func sendKeepalive() { + guard isConnected, let ws = webSocketTask else { return } + ws.send(.string("{\"type\": \"KeepAlive\"}")) { [weak self] error in + if let error = error { + logError("BackendProactiveService: Keepalive error", error: error) + self?.handleDisconnection() + } + } + } + + private func handleDisconnection() { + guard isConnected || shouldReconnect else { return } + + isConnected = false + keepaliveTask?.cancel() + keepaliveTask = nil + webSocketTask?.cancel(with: .goingAway, reason: nil) + webSocketTask = nil + urlSession?.invalidateAndCancel() + urlSession = nil + + cancelAllPending(error: ServiceError.notConnected) + + if shouldReconnect && reconnectAttempts < maxReconnectAttempts { + reconnectAttempts += 1 + let delay = min(pow(2.0, Double(reconnectAttempts)), 32.0) + log("BackendProactiveService: Reconnecting in \(delay)s (attempt \(reconnectAttempts))") + + reconnectTask = Task { + try? 
await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) + guard !Task.isCancelled, self.shouldReconnect else { return } + self.startConnect() + } + } else if reconnectAttempts >= maxReconnectAttempts { + log("BackendProactiveService: Max reconnect attempts reached") + } + } + + // MARK: - Message Handling + + private func receiveMessage() { + webSocketTask?.receive { [weak self] result in + guard let self = self else { return } + + switch result { + case .success(let message): + self.handleMessage(message) + self.receiveMessage() + case .failure(let error): + guard self.isConnected else { return } + logError("BackendProactiveService: Receive error", error: error) + self.handleDisconnection() + } + } + } + + private func handleMessage(_ message: URLSessionWebSocketTask.Message) { + let text: String + switch message { + case .string(let s): + text = s + case .data(let data): + guard let s = String(data: data, encoding: .utf8) else { return } + text = s + @unknown default: + return + } + + // Skip heartbeat + if text == "ping" { return } + + guard let data = text.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + let type = json["type"] as? String else { + return + } + + switch type { + case "focus_result": + handleFocusResult(data) + default: + // Other event types (memory_created, etc.) — ignore for now + break + } + } + + private func handleFocusResult(_ data: Data) { + guard let response = try? JSONDecoder().decode(FocusResultResponse.self, from: data) else { + log("BackendProactiveService: Failed to decode focus_result") + return + } + + let analysis = ScreenAnalysis( + status: FocusStatus(rawValue: response.status) ?? 
.focused, + appOrSite: response.appOrSite, + description: response.description, + message: response.message + ) + + requestLock.lock() + let continuation = pendingFocusRequests.removeValue(forKey: response.frameId) + requestLock.unlock() + + continuation?.resume(returning: analysis) + } + + // MARK: - Helpers + + private func cancelAllPending(error: Error) { + requestLock.lock() + let pending = pendingFocusRequests + pendingFocusRequests.removeAll() + requestLock.unlock() + + for (_, continuation) in pending { + continuation.resume(throwing: error) + } + } + + private static func getBaseURL() -> String? { + if let cString = getenv("OMI_API_URL"), let url = String(validatingUTF8: cString), !url.isEmpty { + return url + } + if let envURL = ProcessInfo.processInfo.environment["OMI_API_URL"], !envURL.isEmpty { + return envURL + } + return nil + } +} + +// MARK: - Response Models + +private struct FocusResultResponse: Decodable { + let type: String + let frameId: String + let status: String + let appOrSite: String + let description: String + let message: String? + + enum CodingKeys: String, CodingKey { + case type + case frameId = "frame_id" + case status + case appOrSite = "app_or_site" + case description + case message + } +} From 344a553f3c00fa8b181bdb11b51d9e2398b7ce08 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 08:57:25 +0100 Subject: [PATCH 137/163] Wire FocusAssistant to BackendProactiveService instead of GeminiClient (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace direct Gemini API calls with backend WebSocket screen_frame messages. Context building (goals, tasks, memories, AI profile) moves server-side. Client becomes thin: encode JPEG→base64, send screen_frame, receive focus_result. 
Co-Authored-By: Claude Opus 4.6 --- .../Assistants/Focus/FocusAssistant.swift | 226 ++---------------- 1 file changed, 16 insertions(+), 210 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift index 33a355f56b..88eaddfc7c 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Focus/FocusAssistant.swift @@ -17,7 +17,7 @@ actor FocusAssistant: ProactiveAssistant { // MARK: - Properties - private let geminiClient: GeminiClient + private let backendService: BackendProactiveService private let onAlert: (String) -> Void private let onStatusChange: ((FocusStatus) -> Void)? private let onRefocus: (() -> Void)? @@ -35,12 +35,6 @@ actor FocusAssistant: ProactiveAssistant { private let maxPendingTasks = 3 private var currentApp: String? - // MARK: - Context Cache - // Cached context from local DB (goals, tasks, memories) to enrich focus analysis - private var cachedContextString: String? - private var contextCacheTime: Date? - private let contextCacheDuration: TimeInterval = 120 // 2 minutes - // MARK: - Smart Analysis Filtering // Skip analysis when user is focused on the same context (app + window title) // Also skip during cooldown period after distraction (unless context changes) @@ -58,25 +52,16 @@ actor FocusAssistant: ProactiveAssistant { private var consecutiveErrorCount = 0 private var errorBackoffEndTime: Date? - /// Get the current system prompt from settings (accessed on MainActor for thread safety) - private var systemPrompt: String { - get async { - await MainActor.run { - FocusAssistantSettings.shared.analysisPrompt - } - } - } - // MARK: - Initialization init( - apiKey: String? = nil, + backendService: BackendProactiveService, onAlert: @escaping (String) -> Void = { _ in }, onStatusChange: ((FocusStatus) -> Void)? 
= nil, onRefocus: (() -> Void)? = nil, onDistraction: (() -> Void)? = nil - ) throws { - self.geminiClient = try GeminiClient(apiKey: apiKey) + ) { + self.backendService = backendService self.onAlert = onAlert self.onStatusChange = onStatusChange self.onRefocus = onRefocus @@ -299,8 +284,6 @@ actor FocusAssistant: ProactiveAssistant { analysisCooldownEndTime = nil consecutiveErrorCount = 0 errorBackoffEndTime = nil - cachedContextString = nil - contextCacheTime = nil // Clear cooldown in UI await MainActor.run { @@ -353,99 +336,26 @@ actor FocusAssistant: ProactiveAssistant { /// Run analysis on a screenshot with no side effects (no saving, no state updates, no notifications). /// Used by the test runner GUI and CLI. func testAnalyze(jpegData: Data, appName: String) async throws -> ScreenAnalysis? { - return try await analyzeScreenshot(jpegData: jpegData) + return try await analyzeScreenshot(jpegData: jpegData, appName: appName, windowTitle: nil) } /// Reset test history — call before starting a test run to get a clean slate. func resetTestHistory() { - testAnalysisHistory.removeAll() + // History is now tracked server-side; no-op on client } /// Run analysis with accumulating history across calls (simulates production behavior). - /// Each result is appended to a separate test history buffer so the model sees prior decisions. + /// History is tracked server-side per WebSocket session, so this is equivalent to testAnalyze. func testAnalyzeWithHistory(jpegData: Data, appName: String) async throws -> ScreenAnalysis? 
{ - let result = try await analyzeScreenshotWithHistory(jpegData: jpegData, history: testAnalysisHistory) - if let result = result { - testAnalysisHistory.append(result) - if testAnalysisHistory.count > maxHistorySize { - testAnalysisHistory.removeFirst() - } - } - return result - } - - /// Separate history buffer for test runs (doesn't pollute production history) - private var testAnalysisHistory: [ScreenAnalysis] = [] - - /// Variant of analyzeScreenshot that accepts an explicit history array - private func analyzeScreenshotWithHistory(jpegData: Data, history: [ScreenAnalysis]) async throws -> ScreenAnalysis? { - let context = await refreshContext() - - // Format provided history - var historyText = "" - if !history.isEmpty { - var lines = ["Recent activity (oldest to newest):"] - for (i, past) in history.enumerated() { - lines.append("\(i + 1). [\(past.status.rawValue)] \(past.appOrSite): \(past.description)") - if let message = past.message { - lines.append(" Message: \(message)") - } - } - historyText = lines.joined(separator: "\n") - } - - var promptParts: [String] = [] - if !context.isEmpty { - promptParts.append(context) - } - if !historyText.isEmpty { - promptParts.append(historyText) - } - promptParts.append("Now analyze this new screenshot:") - let prompt = promptParts.joined(separator: "\n\n") - - let currentSystemPrompt = await systemPrompt - - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "status": .init(type: "string", enum: ["focused", "distracted"], description: "Whether the user is focused or distracted"), - "app_or_site": .init(type: "string", enum: nil, description: "The app or website visible"), - "description": .init(type: "string", enum: nil, description: "Brief description of what's on screen"), - "message": .init(type: "string", enum: nil, description: "Coaching message") - ], - required: ["status", "app_or_site", "description"] - ) - - let responseText = try await 
geminiClient.sendRequest( - prompt: prompt, - imageData: jpegData, - systemPrompt: currentSystemPrompt, - responseSchema: responseSchema - ) - - return try JSONDecoder().decode(ScreenAnalysis.self, from: Data(responseText.utf8)) + return try await analyzeScreenshot(jpegData: jpegData, appName: appName, windowTitle: nil) } // MARK: - Analysis - private func formatHistory() -> String { - guard !analysisHistory.isEmpty else { return "" } - - var lines = ["Recent activity (oldest to newest):"] - for (i, past) in analysisHistory.enumerated() { - lines.append("\(i + 1). [\(past.status.rawValue)] \(past.appOrSite): \(past.description)") - if let message = past.message { - lines.append(" Message: \(message)") - } - } - return lines.joined(separator: "\n") - } - private func processFrame(_ frame: CapturedFrame) async { guard await isEnabled else { return } do { - guard let analysis = try await analyzeScreenshot(jpegData: frame.jpegData) else { + guard let analysis = try await analyzeScreenshot(jpegData: frame.jpegData, appName: frame.appName, windowTitle: frame.windowTitle) else { return } @@ -585,118 +495,14 @@ actor FocusAssistant: ProactiveAssistant { } } - /// Refresh context from local DB (goals, tasks, memories) with caching - private func refreshContext() async -> String { - // Return cached context if fresh - if let cached = cachedContextString, - let cacheTime = contextCacheTime, - Date().timeIntervalSince(cacheTime) < contextCacheDuration { - return cached - } - - var sections: [String] = [] - - // AI User Profile - do { - if let profile = await AIUserProfileService.shared.getLatestProfile() { - sections.append("USER PROFILE (who this user is):\n\(profile.profileText)") - } - } - - // Time context - let formatter = DateFormatter() - formatter.dateFormat = "EEEE, MMMM d, yyyy 'at' h:mm a" - sections.append("TIME CONTEXT:\n\(formatter.string(from: Date()))") - - // Active goals - do { - let goals = try await GoalStorage.shared.getLocalGoals(activeOnly: true) - if 
!goals.isEmpty { - var lines = ["ACTIVE GOALS:"] - for (i, goal) in goals.prefix(10).enumerated() { - let desc = goal.description.map { " - \($0)" } ?? "" - lines.append("\(i + 1). \(goal.title)\(desc)") - } - sections.append(lines.joined(separator: "\n")) - } - } catch { - logError("Focus: Failed to load goals for context", error: error) - } - - // Top tasks by importance - do { - let tasks = try await ActionItemStorage.shared.getTopRelevanceTasks(limit: 50) - if !tasks.isEmpty { - var lines = ["CURRENT TASKS (by importance):"] - for (i, task) in tasks.enumerated() { - let priority = task.priority ?? "medium" - lines.append("\(i + 1). [\(priority)] \(task.description)") - } - sections.append(lines.joined(separator: "\n")) - } - } catch { - logError("Focus: Failed to load tasks for context", error: error) - } - - // Recent memories - do { - let memories = try await MemoryStorage.shared.getLocalMemories(limit: 50, category: "core") - if !memories.isEmpty { - var lines = ["RECENT MEMORIES:"] - for (i, memory) in memories.enumerated() { - lines.append("\(i + 1). \(memory.content)") - } - sections.append(lines.joined(separator: "\n")) - } - } catch { - logError("Focus: Failed to load memories for context", error: error) - } - - let contextString = sections.joined(separator: "\n\n") - cachedContextString = contextString - contextCacheTime = Date() - return contextString - } - - private func analyzeScreenshot(jpegData: Data) async throws -> ScreenAnalysis? 
{ - // Refresh context from local DB - let context = await refreshContext() - - // Build prompt with context + history - let historyText = formatHistory() - var promptParts: [String] = [] - if !context.isEmpty { - promptParts.append(context) - } - if !historyText.isEmpty { - promptParts.append(historyText) - } - promptParts.append("Now analyze this new screenshot:") - let prompt = promptParts.joined(separator: "\n\n") - - // Get current system prompt from settings - let currentSystemPrompt = await systemPrompt - - // Build response schema - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "status": .init(type: "string", enum: ["focused", "distracted"], description: "Whether the user is focused or distracted"), - "app_or_site": .init(type: "string", enum: nil, description: "The app or website visible"), - "description": .init(type: "string", enum: nil, description: "Brief description of what's on screen"), - "message": .init(type: "string", enum: nil, description: "Coaching message") - ], - required: ["status", "app_or_site", "description"] + private func analyzeScreenshot(jpegData: Data, appName: String, windowTitle: String?) async throws -> ScreenAnalysis? { + let base64 = jpegData.base64EncodedString() + let result = try await backendService.analyzeFocus( + imageBase64: base64, + appName: appName, + windowTitle: windowTitle ?? "" ) - - let responseText = try await geminiClient.sendRequest( - prompt: prompt, - imageData: jpegData, - systemPrompt: currentSystemPrompt, - responseSchema: responseSchema - ) - - return try JSONDecoder().decode(ScreenAnalysis.self, from: Data(responseText.utf8)) + return result } // MARK: - Storage From b29b882a583fa6764987dd7f8a77e89601aa5bf7 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 08:57:31 +0100 Subject: [PATCH 138/163] Create BackendProactiveService in ProactiveAssistantsPlugin lifecycle (#5396) Start WS connection when monitoring starts, disconnect on stop. 
Pass service to FocusAssistant (shared for future assistant types). Co-Authored-By: Claude Opus 4.6 --- .../ProactiveAssistantsPlugin.swift | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift index bd00065057..d6d3409675 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift @@ -14,6 +14,7 @@ public class ProactiveAssistantsPlugin: NSObject { private var screenCaptureService: ScreenCaptureService? private var windowMonitor: WindowMonitor? + private var backendProactiveService: BackendProactiveService? private var focusAssistant: FocusAssistant? /// Public read-only accessor for memory diagnostics @@ -319,8 +320,14 @@ public class ProactiveAssistantsPlugin: NSObject { // Initialize services screenCaptureService = ScreenCaptureService() + // Start backend proactive AI WebSocket (Phase 2 — server-side LLM) + let proactiveService = BackendProactiveService() + proactiveService.connect() + backendProactiveService = proactiveService + do { - focusAssistant = try FocusAssistant( + focusAssistant = FocusAssistant( + backendService: proactiveService, onAlert: { [weak self] message in self?.sendEvent(type: "alert", data: ["message": message]) }, @@ -459,6 +466,8 @@ public class ProactiveAssistantsPlugin: NSObject { } } + backendProactiveService?.disconnect() + backendProactiveService = nil focusAssistant = nil taskAssistant = nil adviceAssistant = nil From 1e876f16a5fab632c799407360a6100b6da672d0 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 08:57:36 +0100 Subject: [PATCH 139/163] Update FocusTestRunnerWindow for new FocusAssistant init signature (#5396) Co-Authored-By: Claude Opus 4.6 --- .../ProactiveAssistants/UI/FocusTestRunnerWindow.swift | 9 +++------ 1 file 
changed, 3 insertions(+), 6 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift b/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift index ca5c21034a..0ae19fd68e 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/UI/FocusTestRunnerWindow.swift @@ -637,12 +637,9 @@ enum FocusTestRunner { if let existing = coordAssistant as? FocusAssistant { focusAssistant = existing } else { - do { - focusAssistant = try FocusAssistant() - } catch { - log("FocusTestCLI: ERROR — Failed to create FocusAssistant: \(error)") - return - } + let service = BackendProactiveService() + service.connect() + focusAssistant = FocusAssistant(backendService: service) } // Get excluded apps From c4b9f3e278cf90d9e51899084751c4deb56885ee Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:15:12 +0100 Subject: [PATCH 140/163] Add all 8 message types to BackendProactiveService (#5396) Vision handlers: analyzeFocus, extractTasks, extractMemories, generateAdvice (send screen_frame with analyze type, receive typed result via frame_id) Text handlers: generateLiveNote, requestProfile, rerankTasks, deduplicateTasks (send typed JSON message, receive result via single-slot continuation) Co-Authored-By: Claude Opus 4.6 --- .../Core/BackendProactiveService.swift | 376 ++++++++++++++---- 1 file changed, 305 insertions(+), 71 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift index e9b69ffe3b..1d473ba4c8 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/BackendProactiveService.swift @@ -43,10 +43,21 @@ class BackendProactiveService { private var keepaliveTask: Task? 
private let keepaliveInterval: TimeInterval = 30.0 - // Pending request continuations keyed by frame_id + // Pending continuations keyed by frame_id (vision handlers) private var pendingFocusRequests: [String: CheckedContinuation] = [:] + private var pendingTasksRequests: [String: CheckedContinuation] = [:] + private var pendingMemoriesRequests: [String: CheckedContinuation] = [:] + private var pendingAdviceRequests: [String: CheckedContinuation] = [:] + + // Pending continuations for text-only handlers (one outstanding per type) + private var pendingLiveNote: CheckedContinuation? + private var pendingProfile: CheckedContinuation? + private var pendingRerank: CheckedContinuation? + private var pendingDedup: CheckedContinuation? + private let requestLock = NSLock() private let requestTimeout: TimeInterval = 30.0 + private let textRequestTimeout: TimeInterval = 60.0 // MARK: - Connection @@ -69,63 +80,191 @@ class BackendProactiveService { urlSession?.invalidateAndCancel() urlSession = nil - // Cancel all pending requests cancelAllPending(error: ServiceError.notConnected) - log("BackendProactiveService: Disconnected") } - // MARK: - Public API + // MARK: - Vision Handlers (screen_frame) /// Send a screen_frame for focus analysis and wait for the focus_result response. 
- func analyzeFocus( - imageBase64: String, - appName: String, - windowTitle: String - ) async throws -> ScreenAnalysis { - guard isConnected else { - throw ServiceError.notConnected + func analyzeFocus(imageBase64: String, appName: String, windowTitle: String) async throws -> ScreenAnalysis { + guard isConnected else { throw ServiceError.notConnected } + let frameId = UUID().uuidString + let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["focus"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingFocusRequests[frameId] = continuation + requestLock.unlock() + sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout, + remove: { self.pendingFocusRequests.removeValue(forKey: $0) }) + } + } + + /// Send a screen_frame for task extraction. + func extractTasks(imageBase64: String, appName: String, windowTitle: String) async throws -> TasksExtractedResult { + guard isConnected else { throw ServiceError.notConnected } + let frameId = UUID().uuidString + let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["tasks"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingTasksRequests[frameId] = continuation + requestLock.unlock() + sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout, + remove: { self.pendingTasksRequests.removeValue(forKey: $0) }) } + } + /// Send a screen_frame for memory extraction. 
+ func extractMemories(imageBase64: String, appName: String, windowTitle: String) async throws -> MemoriesExtractedResult { + guard isConnected else { throw ServiceError.notConnected } let frameId = UUID().uuidString + let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["memories"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingMemoriesRequests[frameId] = continuation + requestLock.unlock() + sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout, + remove: { self.pendingMemoriesRequests.removeValue(forKey: $0) }) + } + } + + /// Send a screen_frame for advice generation. + func generateAdvice(imageBase64: String, appName: String, windowTitle: String) async throws -> AdviceExtractedResult { + guard isConnected else { throw ServiceError.notConnected } + let frameId = UUID().uuidString + let jsonString = try buildScreenFrameJSON(frameId: frameId, analyzeTypes: ["advice"], imageBase64: imageBase64, appName: appName, windowTitle: windowTitle) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingAdviceRequests[frameId] = continuation + requestLock.unlock() + sendAndTimeout(jsonString: jsonString, frameId: frameId, timeout: requestTimeout, + remove: { self.pendingAdviceRequests.removeValue(forKey: $0) }) + } + } + + // MARK: - Text-Only Handlers + + /// Send transcript text for live note generation. 
+ func generateLiveNote(text: String, sessionContext: String = "") async throws -> String { + guard isConnected else { throw ServiceError.notConnected } + let jsonString = try buildJSON(["type": "live_notes_text", "text": text, "session_context": sessionContext]) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingLiveNote = continuation + requestLock.unlock() + sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout, + remove: { let c = self.pendingLiveNote; self.pendingLiveNote = nil; return c }) + } + } + + /// Request profile generation (server fetches user data from Firestore). + func requestProfile() async throws -> String { + guard isConnected else { throw ServiceError.notConnected } + let jsonString = try buildJSON(["type": "profile_request"]) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingProfile = continuation + requestLock.unlock() + sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout, + remove: { let c = self.pendingProfile; self.pendingProfile = nil; return c }) + } + } + + /// Request task reranking (server fetches tasks from Firestore). + func rerankTasks() async throws -> RerankExtractedResult { + guard isConnected else { throw ServiceError.notConnected } + let jsonString = try buildJSON(["type": "task_rerank"]) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingRerank = continuation + requestLock.unlock() + sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout, + remove: { let c = self.pendingRerank; self.pendingRerank = nil; return c }) + } + } + + /// Request task deduplication (server fetches tasks from Firestore). 
+ func deduplicateTasks() async throws -> DedupExtractedResult { + guard isConnected else { throw ServiceError.notConnected } + let jsonString = try buildJSON(["type": "task_dedup"]) + + return try await withCheckedThrowingContinuation { continuation in + requestLock.lock() + pendingDedup = continuation + requestLock.unlock() + sendAndTimeoutSingle(jsonString: jsonString, timeout: textRequestTimeout, + remove: { let c = self.pendingDedup; self.pendingDedup = nil; return c }) + } + } + + // MARK: - Send Helpers - let message: [String: Any] = [ + private func buildScreenFrameJSON(frameId: String, analyzeTypes: [String], imageBase64: String, appName: String, windowTitle: String) throws -> String { + try buildJSON([ "type": "screen_frame", "frame_id": frameId, "image_b64": imageBase64, "app_name": appName, "window_title": windowTitle, - "analyze": ["focus"], - ] + "analyze": analyzeTypes, + ]) + } - let jsonData = try JSONSerialization.data(withJSONObject: message) - guard let jsonString = String(data: jsonData, encoding: .utf8) else { + private func buildJSON(_ dict: [String: Any]) throws -> String { + let data = try JSONSerialization.data(withJSONObject: dict) + guard let str = String(data: data, encoding: .utf8) else { throw ServiceError.serverError("Failed to encode message") } + return str + } - return try await withCheckedThrowingContinuation { continuation in - requestLock.lock() - pendingFocusRequests[frameId] = continuation - requestLock.unlock() - - webSocketTask?.send(.string(jsonString)) { [weak self] error in - if let error = error { - self?.requestLock.lock() - let cont = self?.pendingFocusRequests.removeValue(forKey: frameId) - self?.requestLock.unlock() - cont?.resume(throwing: error) - } + /// Send JSON and set up timeout for frame_id-keyed continuations. + private func sendAndTimeout(jsonString: String, frameId: String, timeout: TimeInterval, + remove: @escaping (String) -> CheckedContinuation?) 
{ + webSocketTask?.send(.string(jsonString)) { [weak self] error in + if let error = error { + self?.requestLock.lock() + let cont = remove(frameId) + self?.requestLock.unlock() + cont?.resume(throwing: error) } + } + + Task { [weak self] in + try? await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + self?.requestLock.lock() + let cont = remove(frameId) + self?.requestLock.unlock() + cont?.resume(throwing: ServiceError.timeout) + } + } - // Timeout guard - Task { [weak self] in - try? await Task.sleep(nanoseconds: UInt64((self?.requestTimeout ?? 30.0) * 1_000_000_000)) + /// Send JSON and set up timeout for single-slot continuations. + private func sendAndTimeoutSingle(jsonString: String, timeout: TimeInterval, + remove: @escaping () -> CheckedContinuation?) { + webSocketTask?.send(.string(jsonString)) { [weak self] error in + if let error = error { self?.requestLock.lock() - let cont = self?.pendingFocusRequests.removeValue(forKey: frameId) + let cont = remove() self?.requestLock.unlock() - cont?.resume(throwing: ServiceError.timeout) + cont?.resume(throwing: error) } } + + Task { [weak self] in + try? await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + self?.requestLock.lock() + let cont = remove() + self?.requestLock.unlock() + cont?.resume(throwing: ServiceError.timeout) + } } // MARK: - Connection Internals @@ -153,8 +292,6 @@ class BackendProactiveService { .replacingOccurrences(of: "http://", with: "ws://") let base = wsURL.hasSuffix("/") ? wsURL : wsURL + "/" - // Connect to /v4/listen with source=desktop — same endpoint as audio, - // but we only send JSON messages (no audio data) var components = URLComponents(string: "\(base)v4/listen")! 
components.queryItems = [ URLQueryItem(name: "source", value: "desktop"), @@ -183,7 +320,6 @@ class BackendProactiveService { receiveMessage() - // Confirm connection after short delay DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in guard let self = self, self.webSocketTask?.state == .running else { self?.handleDisconnection() @@ -275,7 +411,6 @@ class BackendProactiveService { return } - // Skip heartbeat if text == "ping" { return } guard let data = text.data(using: .utf8), @@ -286,44 +421,132 @@ class BackendProactiveService { switch type { case "focus_result": - handleFocusResult(data) + handleFocusResult(json) + case "tasks_extracted": + handleTasksExtracted(json) + case "memories_extracted": + handleMemoriesExtracted(json) + case "advice_extracted": + handleAdviceExtracted(json) + case "live_note": + handleLiveNote(json) + case "profile_updated": + handleProfileUpdated(json) + case "rerank_complete": + handleRerankComplete(json) + case "dedup_complete": + handleDedupComplete(json) default: - // Other event types (memory_created, etc.) — ignore for now break } } - private func handleFocusResult(_ data: Data) { - guard let response = try? JSONDecoder().decode(FocusResultResponse.self, from: data) else { - log("BackendProactiveService: Failed to decode focus_result") - return - } + // MARK: - Response Handlers + private func handleFocusResult(_ json: [String: Any]) { + guard let frameId = json["frame_id"] as? String else { return } let analysis = ScreenAnalysis( - status: FocusStatus(rawValue: response.status) ?? .focused, - appOrSite: response.appOrSite, - description: response.description, - message: response.message + status: FocusStatus(rawValue: json["status"] as? String ?? "focused") ?? .focused, + appOrSite: json["app_or_site"] as? String ?? "", + description: json["description"] as? String ?? "", + message: json["message"] as? 
String ) + requestLock.lock() + let cont = pendingFocusRequests.removeValue(forKey: frameId) + requestLock.unlock() + cont?.resume(returning: analysis) + } + + private func handleTasksExtracted(_ json: [String: Any]) { + guard let frameId = json["frame_id"] as? String else { return } + let tasks = (json["tasks"] as? [[String: Any]]) ?? [] + let result = TasksExtractedResult(frameId: frameId, tasks: tasks) + requestLock.lock() + let cont = pendingTasksRequests.removeValue(forKey: frameId) + requestLock.unlock() + cont?.resume(returning: result) + } + + private func handleMemoriesExtracted(_ json: [String: Any]) { + guard let frameId = json["frame_id"] as? String else { return } + let memories = (json["memories"] as? [[String: Any]]) ?? [] + let result = MemoriesExtractedResult(frameId: frameId, memories: memories) + requestLock.lock() + let cont = pendingMemoriesRequests.removeValue(forKey: frameId) + requestLock.unlock() + cont?.resume(returning: result) + } + + private func handleAdviceExtracted(_ json: [String: Any]) { + guard let frameId = json["frame_id"] as? String else { return } + let result = AdviceExtractedResult(frameId: frameId, advice: json["advice"]) + requestLock.lock() + let cont = pendingAdviceRequests.removeValue(forKey: frameId) + requestLock.unlock() + cont?.resume(returning: result) + } + + private func handleLiveNote(_ json: [String: Any]) { + let text = json["text"] as? String ?? "" + requestLock.lock() + let cont = pendingLiveNote + pendingLiveNote = nil + requestLock.unlock() + cont?.resume(returning: text) + } + + private func handleProfileUpdated(_ json: [String: Any]) { + let profileText = json["profile_text"] as? String ?? "" + requestLock.lock() + let cont = pendingProfile + pendingProfile = nil + requestLock.unlock() + cont?.resume(returning: profileText) + } + private func handleRerankComplete(_ json: [String: Any]) { + let updatedTasks = (json["updated_tasks"] as? [[String: Any]]) ?? 
[] + let result = RerankExtractedResult(updatedTasks: updatedTasks) requestLock.lock() - let continuation = pendingFocusRequests.removeValue(forKey: response.frameId) + let cont = pendingRerank + pendingRerank = nil requestLock.unlock() + cont?.resume(returning: result) + } - continuation?.resume(returning: analysis) + private func handleDedupComplete(_ json: [String: Any]) { + let deletedIds = (json["deleted_ids"] as? [String]) ?? [] + let reason = json["reason"] as? String ?? "" + let result = DedupExtractedResult(deletedIds: deletedIds, reason: reason) + requestLock.lock() + let cont = pendingDedup + pendingDedup = nil + requestLock.unlock() + cont?.resume(returning: result) } // MARK: - Helpers private func cancelAllPending(error: Error) { requestLock.lock() - let pending = pendingFocusRequests - pendingFocusRequests.removeAll() + let focus = pendingFocusRequests; pendingFocusRequests.removeAll() + let tasks = pendingTasksRequests; pendingTasksRequests.removeAll() + let memories = pendingMemoriesRequests; pendingMemoriesRequests.removeAll() + let advice = pendingAdviceRequests; pendingAdviceRequests.removeAll() + let liveNote = pendingLiveNote; pendingLiveNote = nil + let profile = pendingProfile; pendingProfile = nil + let rerank = pendingRerank; pendingRerank = nil + let dedup = pendingDedup; pendingDedup = nil requestLock.unlock() - for (_, continuation) in pending { - continuation.resume(throwing: error) - } + for (_, c) in focus { c.resume(throwing: error) } + for (_, c) in tasks { c.resume(throwing: error) } + for (_, c) in memories { c.resume(throwing: error) } + for (_, c) in advice { c.resume(throwing: error) } + liveNote?.resume(throwing: error) + profile?.resume(throwing: error) + rerank?.resume(throwing: error) + dedup?.resume(throwing: error) } private static func getBaseURL() -> String? 
{ @@ -337,22 +560,33 @@ class BackendProactiveService { } } -// MARK: - Response Models +// MARK: - Result Types -private struct FocusResultResponse: Decodable { - let type: String +/// Tasks extracted from a screen_frame analysis. +struct TasksExtractedResult { let frameId: String - let status: String - let appOrSite: String - let description: String - let message: String? - - enum CodingKeys: String, CodingKey { - case type - case frameId = "frame_id" - case status - case appOrSite = "app_or_site" - case description - case message - } + let tasks: [[String: Any]] // Raw task dicts from backend +} + +/// Memories extracted from a screen_frame analysis. +struct MemoriesExtractedResult { + let frameId: String + let memories: [[String: Any]] // Raw memory dicts from backend +} + +/// Advice extracted from a screen_frame analysis. +struct AdviceExtractedResult { + let frameId: String + let advice: Any? // Raw advice from backend (dict or null) +} + +/// Task reranking result. +struct RerankExtractedResult { + let updatedTasks: [[String: Any]] // [{id, new_position}, ...] +} + +/// Task deduplication result. +struct DedupExtractedResult { + let deletedIds: [String] + let reason: String } From 3010fe28bc2dc3aa8938dede100d3f8a5427c308 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:01 +0100 Subject: [PATCH 141/163] Wire TaskAssistant thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace GeminiClient tool-calling loop with backendService.extractTasks(). Remove extractTaskSingleStage, refreshContext, vector/keyword search, validateTaskTitle — all LLM logic now server-side. -550 lines. 
Co-Authored-By: Claude Opus 4.6 --- .../TaskExtraction/TaskAssistant.swift | 698 ++---------------- 1 file changed, 82 insertions(+), 616 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift index 8df5b2fcc7..512c464777 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift @@ -1,7 +1,7 @@ import Foundation -/// Task extraction assistant that identifies tasks and action items from screen content -/// Uses single-stage Gemini tool calling with vector + FTS5 search for deduplication +/// Task extraction assistant that identifies tasks and action items from screen content. +/// Phase 2: sends screenshots to backend via WebSocket, receives structured task results. actor TaskAssistant: ProactiveAssistant { // MARK: - ProactiveAssistant Protocol @@ -18,7 +18,7 @@ actor TaskAssistant: ProactiveAssistant { // MARK: - Properties - private let geminiClient: GeminiClient + private let backendService: BackendProactiveService private var isRunning = false private var previousTasks: [ExtractedTask] = [] // Last 10 extracted tasks for context private let maxPreviousTasks = 10 @@ -41,11 +41,6 @@ actor TaskAssistant: ProactiveAssistant { /// Timestamp of last context switch yield, for throttling rapid switches private var lastContextSwitchYieldTime: Date = .distantPast - // Cached goals (refreshed every 5 minutes) - private var cachedGoals: [Goal] = [] - private var lastGoalsRefresh: Date = .distantPast - private let goalsRefreshInterval: TimeInterval = 300 - // MARK: - Due Date Helpers /// Parse an inferred deadline string into a Date, or default to end of today. 
@@ -114,15 +109,6 @@ actor TaskAssistant: ProactiveAssistant { return calendar.date(bySettingHour: 23, minute: 59, second: 0, of: startOfDay) ?? startOfDay } - /// Get the current system prompt from settings (accessed on MainActor for thread safety) - private var systemPrompt: String { - get async { - await MainActor.run { - TaskAssistantSettings.shared.analysisPrompt - } - } - } - /// Get the extraction interval from settings private var extractionInterval: TimeInterval { get async { @@ -143,9 +129,8 @@ actor TaskAssistant: ProactiveAssistant { // MARK: - Initialization - init(apiKey: String? = nil) throws { - // Use Gemini 3 Pro for better task extraction quality - self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest") + init(backendService: BackendProactiveService) { + self.backendService = backendService let (stream, continuation) = AsyncStream.makeStream(of: TriggerEvent.self, bufferingPolicy: .bufferingNewest(1)) self.triggerStream = stream @@ -221,11 +206,17 @@ actor TaskAssistant: ProactiveAssistant { // MARK: - Test Analysis (for test runner) - /// Run the extraction pipeline on arbitrary JPEG data without side effects (no saving, no events). - /// Used by the test runner to replay past screenshots. - /// Returns (result, searchCount) where searchCount is the number of search tool calls made. + /// Run extraction via backend for test runner. Returns (result, 0) for compatibility. 
func testAnalyze(jpegData: Data, appName: String) async throws -> (TaskExtractionResult?, Int) { - return try await extractTaskSingleStage(from: jpegData, appName: appName) + let base64 = autoreleasepool { jpegData.base64EncodedString() } + let backendResult = try await backendService.extractTasks( + imageBase64: base64, appName: appName, windowTitle: "" + ) + if backendResult.tasks.isEmpty { + return (TaskExtractionResult(hasNewTask: false, task: nil, contextSummary: "Analyzed \(appName)", currentActivity: ""), 0) + } + let result = parseBackendTask(backendResult.tasks[0], appName: appName) + return (result, 0) } // MARK: - ProactiveAssistant Protocol Methods @@ -579,7 +570,7 @@ actor TaskAssistant: ProactiveAssistant { latestFrame = nil } - // MARK: - Single-Stage Analysis with Tool Calling + // MARK: - Backend Analysis (Phase 2 thin client) private func processFrame(_ frame: CapturedFrame) async { let enabled = await isEnabled @@ -590,613 +581,88 @@ actor TaskAssistant: ProactiveAssistant { log("Task: Analyzing frame from \(frame.appName)...") do { - let (result, searchCount) = try await extractTaskSingleStage(from: frame.jpegData, appName: frame.appName) - guard let result = result else { - log("Task: Analysis returned no result") - return - } - - log("Task: Analysis complete - hasNewTask: \(result.hasNewTask), context: \(result.contextSummary), searches: \(searchCount)") + let base64 = autoreleasepool { frame.jpegData.base64EncodedString() } + let backendResult = try await backendService.extractTasks( + imageBase64: base64, + appName: frame.appName, + windowTitle: frame.windowTitle ?? 
"" + ) - await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, appName: frame.appName, windowTitle: frame.windowTitle) { type, data in + let sendEvent: (String, [String: Any]) -> Void = { type, data in Task { @MainActor in AssistantCoordinator.shared.sendEvent(type: type, data: data) } } - } catch { - logError("Task extraction error", error: error) - } - } - - /// Loop-based extraction: image analysis + iterative tool calling for search + terminal tool for decision - /// Returns (result, searchCount) where searchCount is the number of search tool calls made. - private func extractTaskSingleStage(from jpegData: Data, appName: String) async throws -> (TaskExtractionResult?, Int) { - // 1. Gather context - let context = await refreshContext() - - // 2. Build prompt with injected context - let dateFormatter = DateFormatter() - dateFormatter.dateFormat = "yyyy-MM-dd (EEEE)" - let todayStr = dateFormatter.string(from: Date()) - - var prompt = "Screenshot from \(appName). Today is \(todayStr). Analyze this screenshot for any unaddressed request directed at the user.\n\n" - - // For messaging apps, add an extra reminder about conversation analysis - let messagingApps: Set = ["Telegram", "WhatsApp", "\u{200E}WhatsApp", "Messages", "Slack", "Discord"] - if messagingApps.contains(appName) { - prompt += """ - REMINDER — THIS IS A MESSAGING APP: - - If this screenshot shows a chat sidebar/conversation list rather than an open conversation, SKIP entirely. - - If it shows an open conversation, read the FULL conversation flow between the user and the other person. - - LEFT-SIDE messages = from the other person. RIGHT-SIDE/colored = from the user. - - PRIORITY: Look for where the user AGREED or COMMITTED to doing something the other person asked. - Example: Other person says "Can you send me the report?" → User replies "Sure, will do" → Extract task: "Send [person] the report" - - ALSO: Look for incoming requests the user hasn't responded to yet. 
- - The task title should describe what was asked for, naming the other person in the conversation. - - """ - } - - // Inject AI user profile for context - if let profile = await AIUserProfileService.shared.getLatestProfile() { - prompt += "USER PROFILE (who this user is — use for context, not as a task source):\n" - prompt += profile.profileText + "\n\n" - } - - if !context.activeTasks.isEmpty { - // Get score range for context - let scoreRange = try? await ActionItemStorage.shared.getRelevanceScoreRange() - let rangeStr = scoreRange.map { "Score range: \($0.min)–\($0.max). " } ?? "" - - prompt += "ACTIVE TASKS (user is already tracking these — each has a relevance_score where 1 = most important, higher numbers = less important):\n" - prompt += "\(rangeStr)Use these scores to place any new task appropriately.\n" - for (i, task) in context.activeTasks.enumerated() { - let pri = task.priority.map { " [\($0)]" } ?? "" - let score = task.relevanceScore.map { " [score:\($0)]" } ?? "" - prompt += "\(i + 1).\(score) \(task.description)\(pri)\n" - } - prompt += "\n" - } - - if !context.completedTasks.isEmpty { - prompt += "RECENTLY COMPLETED TASKS (user engaged with these — this is the kind of task the user finds valuable. Extract similar types of tasks, just not exact duplicates of these specific ones):\n" - for (i, task) in context.completedTasks.enumerated() { - prompt += "\(i + 1). \(task.description)\n" - } - prompt += "\n" - } - - if !context.deletedTasks.isEmpty { - prompt += "USER-DELETED TASKS (user explicitly rejected these — do not re-extract similar):\n" - for (i, task) in context.deletedTasks.enumerated() { - prompt += "\(i + 1). \(task.description)\n" - } - prompt += "\n" - } - - if !context.goals.isEmpty { - prompt += "ACTIVE GOALS:\n" - for (i, goal) in context.goals.enumerated() { - prompt += "\(i + 1). 
\(goal.title)" - if let desc = goal.description { - prompt += " — \(desc)" - } - prompt += "\n" - } - prompt += "\n" - } - prompt += """ - Analyze this screenshot. If you see a potential request, search for duplicates first. - If there is clearly no request on screen (~90% of screenshots), call no_task_found immediately. - """ - - // 3. Define 5 tools - let tools = GeminiTool(functionDeclarations: [ - GeminiTool.FunctionDeclaration( - name: "search_similar", - description: "Search for semantically similar existing tasks using vector similarity. Call this when you see a potential request and want to check for duplicates.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init(type: "string", description: "A concise description of the potential task to search for") - ], - required: ["query"] + if backendResult.tasks.isEmpty { + let result = TaskExtractionResult( + hasNewTask: false, task: nil, + contextSummary: "Analyzed \(frame.appName)", + currentActivity: "" ) - ), - GeminiTool.FunctionDeclaration( - name: "search_keywords", - description: "Search for existing tasks matching specific keywords. Use this for precise keyword-based matching complementing vector search.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init(type: "string", description: "Keywords to search for in existing tasks") - ], - required: ["query"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "no_task_found", - description: "Call this when there is no actionable request on screen. This is the most common outcome (~90% of screenshots). 
Use for: code editors, terminals, settings, media players, dashboards, or any screen without a direct request from another person or AI.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "context_summary": .init(type: "string", description: "Brief summary of what the user is looking at"), - "current_activity": .init(type: "string", description: "What the user is actively doing") - ], - required: ["context_summary", "current_activity"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "extract_task", - description: "Extract a new task that is not already tracked. Call ONLY after searching for duplicates. All fields are required.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "title": .init(type: "string", description: "Verb-first task title, 6–15 words. MUST name a specific person/project/artifact and a concrete action. If you can't write 6+ specific words, call no_task_found instead."), - "description": .init(type: "string", description: "Additional context about the task. Empty string if none."), - "priority": .init(type: "string", description: "Task priority", enumValues: ["high", "medium", "low"]), - "tags": .init(type: "array", description: "1-3 relevant tags", items: .init(type: "string")), - "source_app": .init(type: "string", description: "App where the task was found"), - "inferred_deadline": .init(type: "string", description: "Deadline in yyyy-MM-dd format (e.g. '2025-10-04'). Resolve relative references like 'Thursday' or 'next week' to an actual date. 
Empty string if no deadline."), - "confidence": .init(type: "number", description: "Confidence score 0.0-1.0"), - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "What the user is actively doing"), - "source_category": .init(type: "string", description: "Where the task originated", enumValues: ["direct_request", "self_generated", "calendar_driven", "reactive", "external_system", "other"]), - "source_subcategory": .init(type: "string", description: "Specific origin within category", enumValues: ["message", "meeting", "mention", "commitment", "idea", "reminder", "goal_subtask", "event_prep", "recurring", "deadline", "error", "notification", "observation", "project_tool", "alert", "documentation", "other"]), - "relevance_score": .init(type: "integer", description: "Where this task ranks relative to existing tasks. Look at the relevance_score values of existing active tasks and assign a score that places this task appropriately. 1 = most important/urgent, higher numbers = less important. Must be a positive integer.") - ], - required: ["title", "description", "priority", "tags", "source_app", "inferred_deadline", "confidence", "context_summary", "current_activity", "source_category", "source_subcategory", "relevance_score"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "reject_task", - description: "Reject task extraction — the potential task is a duplicate, already completed, or was previously rejected by the user. Call after searching confirms this.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "reason": .init(type: "string", description: "Why this task was rejected (e.g. 
'duplicate of existing active task', 'already completed')"), - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "What the user is actively doing") - ], - required: ["reason", "context_summary", "current_activity"] - ) - ) - ]) - - // 4. Get system prompt - let currentSystemPrompt = await systemPrompt - - // 5. Build initial contents - // Wrap base64 encoding in autoreleasepool — Swift concurrency doesn't - // drain autorelease pools, causing bridged NSString objects to accumulate. - var contents: [GeminiImageToolRequest.Content] = autoreleasepool { - let base64Data = jpegData.base64EncodedString() - return [ - GeminiImageToolRequest.Content( - role: "user", - parts: [ - GeminiImageToolRequest.Part(text: prompt), - GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64Data) - ] - ) - ] - } - - // 6. Tool-calling loop (max 5 iterations) - var searchCount = 0 - - for iteration in 0..<5 { - let result = try await geminiClient.sendImageToolLoop( - contents: contents, - systemPrompt: currentSystemPrompt, - tools: [tools], - forceToolCall: iteration == 0 - ) - - guard let toolCall = result.toolCalls.first else { - log("Task: No tool call received on iteration \(iteration), breaking") - break - } - - switch toolCall.name { - case "no_task_found": - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "No task on screen" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "Unknown" - log("Task: no_task_found — \(contextSummary)") - return (TaskExtractionResult( - hasNewTask: false, - task: nil, - contextSummary: contextSummary, - currentActivity: currentActivity - ), searchCount) - - case "extract_task": - let title = toolCall.arguments["title"] as? String ?? "" - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? 
"" - - // --- Hard validation: reject vague titles and ask the model to retry --- - let titleWords = title.split(separator: " ").count - let validationError = Self.validateTaskTitle(title, wordCount: titleWords) - if let error = validationError { - log("Task: Title rejected (\(error)): \"\(title)\"") - - // Feed rejection back into the loop so the model can retry with more specifics - contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: toolCall.arguments as? [String: String] ?? ["title": title]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: """ - REJECTED: \(error). \ - Your title was: "\(title)" (\(titleWords) words). \ - Either rewrite with 6+ words including a specific person/project name and concrete action, \ - or call no_task_found if you cannot be more specific. - """) - ))] - )) - continue - } - - let description = toolCall.arguments["description"] as? String - let priorityStr = toolCall.arguments["priority"] as? String ?? "medium" - let priority = TaskPriority(rawValue: priorityStr) ?? .medium - let tags: [String] - if let tagArray = toolCall.arguments["tags"] as? [Any] { - tags = tagArray.compactMap { $0 as? String } - } else { - tags = [] - } - let sourceApp = toolCall.arguments["source_app"] as? String ?? appName - let inferredDeadline = toolCall.arguments["inferred_deadline"] as? String - let confidence: Double - if let confValue = toolCall.arguments["confidence"] as? Double { - confidence = confValue - } else if let confInt = toolCall.arguments["confidence"] as? Int { - confidence = Double(confInt) - } else { - confidence = 0.5 - } - let sourceCategory = toolCall.arguments["source_category"] as? String ?? 
"other" - let sourceSubcategory = toolCall.arguments["source_subcategory"] as? String ?? "other" - let relevanceScore: Int? - if let scoreValue = toolCall.arguments["relevance_score"] as? Int { - relevanceScore = scoreValue - } else if let scoreDouble = toolCall.arguments["relevance_score"] as? Double { - relevanceScore = Int(scoreDouble) - } else { - relevanceScore = nil - } - - let task = ExtractedTask( - title: title, - description: description?.isEmpty == true ? nil : description, - priority: priority, - sourceApp: sourceApp, - inferredDeadline: inferredDeadline?.isEmpty == true ? nil : inferredDeadline, - confidence: confidence, - tags: tags, - sourceCategory: sourceCategory, - sourceSubcategory: sourceSubcategory, - relevanceScore: relevanceScore - ) - - log("Task: extract_task — \"\(title)\" (confidence: \(confidence), priority: \(priorityStr), score: \(relevanceScore.map { String($0) } ?? "nil"))") - return (TaskExtractionResult( - hasNewTask: true, - task: task, - contextSummary: contextSummary, - currentActivity: currentActivity - ), searchCount) - - case "reject_task": - let reason = toolCall.arguments["reason"] as? String ?? "Unknown reason" - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "" - log("Task: reject_task — \(reason)") - return (TaskExtractionResult( - hasNewTask: false, - task: nil, - contextSummary: contextSummary, - currentActivity: currentActivity - ), searchCount) - - case "search_similar": - let query = toolCall.arguments["query"] as? String ?? "" - searchCount += 1 - log("Task: search_similar query: \"\(query)\"") - let searchResults = await executeVectorSearch(query: query) - log("Task: Vector search returned \(searchResults.count) results") - - let searchResultsJson: String - if let data = try? 
JSONEncoder().encode(searchResults), - let json = String(data: data, encoding: .utf8) { - searchResultsJson = json - } else { - searchResultsJson = "[]" - } - - // Append model's tool call + function response to contents - contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: ["query": query]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: searchResultsJson) - ))] - )) - continue - - case "search_keywords": - let query = toolCall.arguments["query"] as? String ?? "" - searchCount += 1 - log("Task: search_keywords query: \"\(query)\"") - let searchResults = await executeKeywordSearch(query: query) - log("Task: Keyword search returned \(searchResults.count) results") - - let searchResultsJson: String - if let data = try? JSONEncoder().encode(searchResults), - let json = String(data: data, encoding: .utf8) { - searchResultsJson = json - } else { - searchResultsJson = "[]" - } - - // Append model's tool call + function response to contents - contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: ["query": query]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: searchResultsJson) - ))] - )) - continue - - default: - log("Task: Unknown tool call: \(toolCall.name), breaking") - break - } - } - - log("Task: Completed in \(searchCount) searches (loop exhausted without terminal tool)") - return (nil, searchCount) - } - - // MARK: - Title Validation - - /// Validates a task title for minimum specificity. 
Returns an error message if invalid, nil if OK. - private static func validateTaskTitle(_ title: String, wordCount: Int) -> String? { - let trimmed = title.trimmingCharacters(in: .whitespacesAndNewlines) - - // Must not be empty - if trimmed.isEmpty { - return "Title is empty" - } - - // Minimum 6 words - if wordCount < 6 { - return "Title too short (\(wordCount) words, minimum 6)" - } - - // Reject titles that are purely generic verbs with no specifics - let genericPatterns: [String] = [ - "investigate", "check logs", "clean up", "look into", - "look through", "update to", "fix the", "review the", - "check the", "modify the", "track the" - ] - let lowered = trimmed.lowercased() - for pattern in genericPatterns { - // If the entire title is just a generic pattern (possibly with 1-2 filler words), reject - if lowered == pattern || (wordCount <= 4 && lowered.hasPrefix(pattern)) { - return "Title too generic (matches vague pattern '\(pattern)')" + log("Task: Analysis returned no tasks") + await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, appName: frame.appName, windowTitle: frame.windowTitle, sendEvent: sendEvent) + return } - } - // Must contain at least one capitalized proper noun (person, project, app name) - // Heuristic: after the first word (verb), there should be at least one word starting with uppercase - let words = trimmed.split(separator: " ") - let hasProperNoun = words.dropFirst().contains { word in - guard let first = word.first else { return false } - return first.isUppercase - } - if !hasProperNoun { - return "Title lacks a specific name (person, project, or app) — no proper nouns found after the verb" - } - - return nil - } - - // MARK: - Context & Search + log("Task: Analysis complete - \(backendResult.tasks.count) task(s)") - /// Refresh context from local SQLite + cached goals - private func refreshContext() async -> TaskExtractionContext { - var topRelevanceTasks: [(id: Int64, description: String, priority: String?, 
relevanceScore: Int?)] = [] - var recentTasks: [(id: Int64, description: String, priority: String?, relevanceScore: Int?)] = [] - var completedTasks: [(id: Int64, description: String)] = [] - var deletedTasks: [(id: Int64, description: String)] = [] - - // Query both action_items (promoted + manual) and staged_tasks for full context - do { - topRelevanceTasks = try await ActionItemStorage.shared.getTopRelevanceTasks(limit: 30) - } catch { - logError("Task: Failed to load top relevance tasks", error: error) - } - - do { - recentTasks = try await ActionItemStorage.shared.getRecentActiveTasks(limit: 30) - } catch { - logError("Task: Failed to load recent tasks", error: error) - } - - // Also include staged tasks for dedup context - do { - let stagedTasks = try await StagedTaskStorage.shared.getAllStagedTasks(limit: 30) - let stagedAsTuples = stagedTasks.map { task in - (id: Int64(0), description: task.description, priority: task.priority, relevanceScore: task.relevanceScore) + for taskDict in backendResult.tasks { + let result = parseBackendTask(taskDict, appName: frame.appName) + await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, appName: frame.appName, windowTitle: frame.windowTitle, sendEvent: sendEvent) } - recentTasks.append(contentsOf: stagedAsTuples) - } catch { - logError("Task: Failed to load staged tasks for context", error: error) - } - - // Merge: top relevance tasks first, then recent ones not already included - let topIds = Set(topRelevanceTasks.map { $0.id }) - let activeTasks = topRelevanceTasks + recentTasks.filter { !topIds.contains($0.id) } - - do { - completedTasks = try await ActionItemStorage.shared.getRecentCompletedTasks(limit: 10) - } catch { - logError("Task: Failed to load completed tasks", error: error) - } - - do { - deletedTasks = try await ActionItemStorage.shared.getRecentDeletedTasks(limit: 10, deletedBy: "user") } catch { - logError("Task: Failed to load deleted tasks", error: error) - } - - // Refresh goals if 
stale - let timeSinceGoals = Date().timeIntervalSince(lastGoalsRefresh) - if timeSinceGoals >= goalsRefreshInterval { - do { - cachedGoals = try await APIClient.shared.getGoals() - lastGoalsRefresh = Date() - log("Task: Refreshed \(cachedGoals.count) goals") - } catch { - logError("Task: Failed to refresh goals", error: error) - } + logError("Task extraction error", error: error) } - - return TaskExtractionContext( - activeTasks: activeTasks, - completedTasks: completedTasks, - deletedTasks: deletedTasks, - goals: cachedGoals - ) } - /// Execute vector similarity search - private func executeVectorSearch(query: String) async -> [TaskSearchResult] { - var results: [TaskSearchResult] = [] - - do { - let queryEmbedding = try await EmbeddingService.shared.embed(text: query) - let vectorResults = await EmbeddingService.shared.searchSimilar(query: queryEmbedding, topK: 10) - - for result in vectorResults where result.similarity > 0.3 { - if let record = try await ActionItemStorage.shared.getActionItem(id: result.id) { - let status: String - if record.deleted { status = "deleted" } - else if record.completed { status = "completed" } - else { status = "active" } - - results.append(TaskSearchResult( - id: result.id, - description: record.description, - status: status, - similarity: Double(result.similarity), - matchType: "vector", - relevanceScore: record.relevanceScore - )) - } else if let staged = try await StagedTaskStorage.shared.getStagedTask(id: result.id) { - // Fallback: ID belongs to a staged task (shared embedding index) - let status: String - if staged.deleted { status = "deleted" } - else if staged.completed { status = "completed" } - else { status = "active" } - - results.append(TaskSearchResult( - id: result.id, - description: staged.description, - status: status, - similarity: Double(result.similarity), - matchType: "vector", - relevanceScore: staged.relevanceScore - )) - } - } - } catch { - logError("Task: Vector search failed", error: error) - } + /// Parse 
a raw task dict from the backend into a TaskExtractionResult. + private func parseBackendTask(_ dict: [String: Any], appName: String) -> TaskExtractionResult { + let title = dict["title"] as? String ?? "" + let description = dict["description"] as? String + let priorityStr = dict["priority"] as? String ?? "medium" + let priority = TaskPriority(rawValue: priorityStr) ?? .medium + let tags = (dict["tags"] as? [String]) ?? [] + let sourceApp = dict["source_app"] as? String ?? appName + let inferredDeadline = dict["inferred_deadline"] as? String + let confidence: Double + if let confValue = dict["confidence"] as? Double { + confidence = confValue + } else if let confInt = dict["confidence"] as? Int { + confidence = Double(confInt) + } else { + confidence = 0.5 + } + let sourceCategory = dict["source_category"] as? String ?? "other" + let sourceSubcategory = dict["source_subcategory"] as? String ?? "other" + let relevanceScore: Int? + if let scoreValue = dict["relevance_score"] as? Int { + relevanceScore = scoreValue + } else if let scoreDouble = dict["relevance_score"] as? Double { + relevanceScore = Int(scoreDouble) + } else { + relevanceScore = nil + } + + let task = ExtractedTask( + title: title, + description: description?.isEmpty == true ? nil : description, + priority: priority, + sourceApp: sourceApp, + inferredDeadline: inferredDeadline?.isEmpty == true ? nil : inferredDeadline, + confidence: confidence, + tags: tags, + sourceCategory: sourceCategory, + sourceSubcategory: sourceSubcategory, + relevanceScore: relevanceScore + ) - return results.sorted { ($0.similarity ?? 0) > ($1.similarity ?? 0) } + return TaskExtractionResult( + hasNewTask: true, + task: task, + contextSummary: dict["context_summary"] as? String ?? "Analyzed \(appName)", + currentActivity: dict["current_activity"] as? String ?? 
"" + ) } - /// Execute FTS5 keyword search (searches both action_items and staged_tasks) - private func executeKeywordSearch(query: String) async -> [TaskSearchResult] { - var results: [TaskSearchResult] = [] - - do { - let words = query.components(separatedBy: .whitespaces) - .map { $0.filter { $0.isLetter || $0.isNumber } } // Strip FTS5 special chars (- : * " etc.) - .filter { $0.count >= 3 } - let ftsQuery = words.map { "\($0)*" }.joined(separator: " OR ") - - if !ftsQuery.isEmpty { - // Search action_items (promoted + manual) - let ftsResults = try await ActionItemStorage.shared.searchFTS( - query: ftsQuery, - limit: 10, - includeCompleted: true, - includeDeleted: true - ) - - for result in ftsResults { - let status: String - if result.deleted { status = "deleted" } - else if result.completed { status = "completed" } - else { status = "active" } - - results.append(TaskSearchResult( - id: result.id, - description: result.description, - status: status, - similarity: nil, - matchType: "fts", - relevanceScore: result.relevanceScore - )) - } - - // Also search staged_tasks - let stagedResults = try await StagedTaskStorage.shared.searchFTS( - query: ftsQuery, - limit: 10 - ) - for result in stagedResults { - results.append(TaskSearchResult( - id: result.id, - description: result.description, - status: "active", - similarity: nil, - matchType: "fts", - relevanceScore: result.relevanceScore - )) - } - } - } catch { - logError("Task: FTS search failed", error: error) - } - - return results - } } From e6155f3b4b5a8500ceb753d706fd3a37d4ffe05e Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:07 +0100 Subject: [PATCH 142/163] Wire MemoryAssistant thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace GeminiClient.sendRequest with backendService.extractMemories(). Remove prompt/schema building — all LLM logic now server-side. 
Co-Authored-By: Claude Opus 4.6 --- .../MemoryExtraction/MemoryAssistant.swift | 97 +++++++------------ 1 file changed, 33 insertions(+), 64 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift index bcc6a6f1e6..2e0c671820 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift @@ -17,7 +17,7 @@ actor MemoryAssistant: ProactiveAssistant { // MARK: - Properties - private let geminiClient: GeminiClient + private let backendService: BackendProactiveService private var isRunning = false private var lastAnalysisTime: Date = .distantPast private var previousMemories: [ExtractedMemory] = [] // Last 20 extracted memories for deduplication @@ -28,15 +28,6 @@ actor MemoryAssistant: ProactiveAssistant { private let frameSignal: AsyncStream private let frameSignalContinuation: AsyncStream.Continuation - /// Get the current system prompt from settings (accessed on MainActor for thread safety) - private var systemPrompt: String { - get async { - await MainActor.run { - MemoryAssistantSettings.shared.analysisPrompt - } - } - } - /// Get the extraction interval from settings private var extractionInterval: TimeInterval { get async { @@ -57,9 +48,8 @@ actor MemoryAssistant: ProactiveAssistant { // MARK: - Initialization - init(apiKey: String? 
= nil) throws { - // Use Gemini 3 Pro for better memory extraction quality - self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest") + init(backendService: BackendProactiveService) { + self.backendService = backendService let (stream, continuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1)) self.frameSignal = stream @@ -340,61 +330,40 @@ actor MemoryAssistant: ProactiveAssistant { } private func extractMemories(from jpegData: Data, appName: String) async throws -> MemoryExtractionResult? { - // Build context with previous memories for deduplication - var prompt = "Analyze this screenshot from \(appName).\n\n" + let base64 = autoreleasepool { jpegData.base64EncodedString() } + let backendResult = try await backendService.extractMemories( + imageBase64: base64, + appName: appName, + windowTitle: "" + ) - if !previousMemories.isEmpty { - prompt += "RECENTLY EXTRACTED MEMORIES (do not re-extract these or semantically similar ones):\n" - for (index, memory) in previousMemories.enumerated() { - prompt += "\(index + 1). [\(memory.category.rawValue)] \(memory.content)\n" + // Parse backend response into MemoryExtractionResult + let memories: [ExtractedMemory] = backendResult.memories.compactMap { dict in + guard let content = dict["content"] as? String, !content.isEmpty else { return nil } + let categoryStr = dict["category"] as? String ?? "system" + let category: ExtractedMemoryCategory = categoryStr == "interesting" ? .interesting : .system + let sourceApp = dict["source_app"] as? String ?? appName + let confidence: Double + if let confValue = dict["confidence"] as? Double { + confidence = confValue + } else if let confInt = dict["confidence"] as? Int { + confidence = Double(confInt) + } else { + confidence = 0.5 } - prompt += "\nLook for NEW memories that are NOT already in the list above." - } else { - prompt += "Look for memories to extract (system facts about the user, or interesting wisdom from others)." 
+ return ExtractedMemory( + content: content, + category: category, + sourceApp: sourceApp, + confidence: confidence + ) } - // Get current system prompt from settings - let currentSystemPrompt = await systemPrompt - - // Build response schema for memory extraction - let memoryProperties: [String: GeminiRequest.GenerationConfig.ResponseSchema.Property] = [ - "content": .init(type: "string", description: "The memory content (max 15 words)"), - "category": .init(type: "string", enum: ["system", "interesting"], description: "Memory category"), - "source_app": .init(type: "string", description: "App where memory was found"), - "confidence": .init(type: "number", description: "Confidence score 0.0-1.0") - ] - - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "has_new_memory": .init(type: "boolean", description: "True if new memories were found"), - "memories": .init( - type: "array", - description: "Array of extracted memories (0-3 max)", - items: .init( - type: "object", - properties: memoryProperties, - required: ["content", "category", "source_app", "confidence"] - ) - ), - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "High-level description of user's activity") - ], - required: ["has_new_memory", "memories", "context_summary", "current_activity"] + return MemoryExtractionResult( + hasNewMemory: !memories.isEmpty, + memories: memories, + contextSummary: "Analyzed \(appName)", + currentActivity: "" ) - - do { - let responseText = try await geminiClient.sendRequest( - prompt: prompt, - imageData: jpegData, - systemPrompt: currentSystemPrompt, - responseSchema: responseSchema - ) - - return try JSONDecoder().decode(MemoryExtractionResult.self, from: Data(responseText.utf8)) - } catch { - logError("Memory analysis error", error: error) - return nil - } } } From f1b47a5990ba37d5a737b97a5d30e43630a7db6f Mon 
Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:14 +0100 Subject: [PATCH 143/163] Wire AdviceAssistant thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace 2-phase Gemini tool-calling loop (execute_sql + vision) with backendService.generateAdvice(). Remove compressForGemini, getUserLanguage, buildActivitySummary, buildPhase1/2Tools — all LLM logic server-side. -560 lines. Co-Authored-By: Claude Opus 4.6 --- .../Assistants/Advice/AdviceAssistant.swift | 660 ++---------------- 1 file changed, 60 insertions(+), 600 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift index 87ca8c4b0a..e1c2701817 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Advice/AdviceAssistant.swift @@ -1,4 +1,3 @@ -import AppKit import Foundation import GRDB @@ -19,7 +18,7 @@ actor AdviceAssistant: ProactiveAssistant { // MARK: - Properties - private let geminiClient: GeminiClient + private let backendService: BackendProactiveService private var isRunning = false private var lastAnalysisTime: Date = .distantPast private var previousAdvice: [ExtractedAdvice] = [] // Dedup window for advice context @@ -33,15 +32,6 @@ actor AdviceAssistant: ProactiveAssistant { private let frameSignal: AsyncStream private let frameSignalContinuation: AsyncStream.Continuation - /// Get the current system prompt from settings (accessed on MainActor for thread safety) - private var systemPrompt: String { - get async { - await MainActor.run { - AdviceAssistantSettings.shared.analysisPrompt - } - } - } - /// Get the extraction interval from settings private var extractionInterval: TimeInterval { get async { @@ -62,9 +52,8 @@ actor AdviceAssistant: ProactiveAssistant { // MARK: - 
Initialization - init(apiKey: String? = nil) throws { - // Use Gemini 3.1 Pro for better advice quality (3-pro-preview retires March 9, 2026) - self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest") + init(backendService: BackendProactiveService) { + self.backendService = backendService let (stream, continuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1)) self.frameSignal = stream @@ -140,25 +129,6 @@ actor AdviceAssistant: ProactiveAssistant { log("Advice assistant stopped") } - // MARK: - Test Analysis (for test runner) - - /// Run the extraction pipeline on arbitrary JPEG data without side effects (no saving, no events). - /// Used by the test runner to replay past screenshots. - /// `screenshotTime` anchors the activity summary to the screenshot's actual timestamp. - /// Returns (result, sqlQueryCount) where sqlQueryCount is the number of execute_sql tool calls made. - func testAnalyze(jpegData: Data, appName: String, windowTitle: String? = nil, screenshotTime: Date) async throws -> (AdviceExtractionResult?, Int) { - let interval = await extractionInterval - let lookbackStart = screenshotTime.addingTimeInterval(-interval) - return try await runAdviceExtraction( - jpegData: nil, - appName: appName, - windowTitle: windowTitle, - referenceTime: screenshotTime, - lookbackStart: lookbackStart, - trackSqlCount: true - ) - } - // MARK: - ProactiveAssistant Protocol Methods func shouldAnalyze(frameNumber: Int, timeSinceLastAnalysis: TimeInterval) -> Bool { @@ -379,62 +349,34 @@ actor AdviceAssistant: ProactiveAssistant { pendingFrame = nil } - // MARK: - Image Processing - - /// Resize and compress an image for Gemini analysis (max 1280px wide, JPEG quality 0.4) - private static func compressForGemini(_ data: Data) -> Data? 
{ - guard let source = CGImageSourceCreateWithData(data as CFData, nil), - let cgImage = CGImageSourceCreateImageAtIndex(source, 0, nil) else { return nil } - - let maxWidth = 1280 - let width = cgImage.width - let height = cgImage.height - let scale = width > maxWidth ? Double(maxWidth) / Double(width) : 1.0 - let newWidth = Int(Double(width) * scale) - let newHeight = Int(Double(height) * scale) - - guard let context = CGContext( - data: nil, width: newWidth, height: newHeight, - bitsPerComponent: 8, bytesPerRow: 0, - space: CGColorSpaceCreateDeviceRGB(), - bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue - ) else { return nil } - - context.interpolationQuality = .high - context.draw(cgImage, in: CGRect(x: 0, y: 0, width: newWidth, height: newHeight)) - - guard let resized = context.makeImage() else { return nil } - - let mutableData = NSMutableData() - guard let dest = CGImageDestinationCreateWithData(mutableData as CFMutableData, "public.jpeg" as CFString, 1, nil) else { return nil } - CGImageDestinationAddImage(dest, resized, [kCGImageDestinationLossyCompressionQuality: 0.4] as CFDictionary) - guard CGImageDestinationFinalize(dest) else { return nil } - return mutableData as Data - } - - // MARK: - Helpers + // MARK: - Test Analysis (for test runner) - /// Get user's preferred language, cached for 1 hour - private func getUserLanguage() async -> String? { - // Return cached value if fresh (< 1 hour) - if let cached = cachedLanguage, Date().timeIntervalSince(languageFetchedAt) < 3600 { - return cached + /// Run extraction via backend for test runner. Returns (result, 0) for compatibility. + func testAnalyze(jpegData: Data, appName: String, windowTitle: String? = nil, screenshotTime: Date) async throws -> (AdviceExtractionResult?, Int) { + let base64 = autoreleasepool { jpegData.base64EncodedString() } + let backendResult = try await backendService.generateAdvice( + imageBase64: base64, appName: appName, windowTitle: windowTitle ?? 
"" + ) + guard let adviceDict = backendResult.advice as? [String: Any] else { + return (AdviceExtractionResult(hasAdvice: false, advice: nil, contextSummary: "Analyzed \(appName)", currentActivity: ""), 0) } - - do { - let response = try await APIClient.shared.getUserLanguage() - let lang = response.language - cachedLanguage = lang - languageFetchedAt = Date() - return lang.isEmpty ? nil : lang - } catch { - // Fall back to transcription language setting - let fallback = await MainActor.run { AssistantSettings.shared.transcriptionLanguage } - return fallback.isEmpty || fallback == "en" ? nil : fallback + let hasAdvice = adviceDict["has_advice"] as? Bool ?? !adviceDict.isEmpty + guard hasAdvice, let adviceText = adviceDict["content"] as? String ?? adviceDict["advice"] as? String, !adviceText.isEmpty else { + return (AdviceExtractionResult(hasAdvice: false, advice: nil, contextSummary: "Analyzed \(appName)", currentActivity: ""), 0) } + let categoryStr = adviceDict["category"] as? String ?? "other" + let category = AdviceCategory(rawValue: categoryStr) ?? .other + let confidence = adviceDict["confidence"] as? Double ?? 0.5 + let advice = ExtractedAdvice( + advice: adviceText, headline: adviceDict["headline"] as? String, + reasoning: adviceDict["reasoning"] as? 
String, category: category, + sourceApp: appName, confidence: confidence + ) + let result = AdviceExtractionResult(hasAdvice: true, advice: advice, contextSummary: "Analyzed \(appName)", currentActivity: "") + return (result, 0) } - // MARK: - Analysis + // MARK: - Backend Analysis (Phase 2 thin client) private func processFrame(_ frame: CapturedFrame) async { guard await isEnabled else { return } @@ -443,7 +385,6 @@ actor AdviceAssistant: ProactiveAssistant { return } - // Handle the result with screenshot ID for SQLite storage await handleResultWithScreenshot(result, screenshotId: frame.screenshotId, windowTitle: frame.windowTitle) { type, data in Task { @MainActor in AssistantCoordinator.shared.sendEvent(type: type, data: data) @@ -455,549 +396,68 @@ actor AdviceAssistant: ProactiveAssistant { } private func extractAdvice(from frame: CapturedFrame) async throws -> AdviceExtractionResult? { - let now = Date() - // Cap lookback: since last analysis or max 1 hour ago - let lookbackStart = max(lastAnalysisTime, now.addingTimeInterval(-3600)) - let (result, _) = try await runAdviceExtraction( - jpegData: nil, + let base64 = autoreleasepool { frame.jpegData.base64EncodedString() } + let backendResult = try await backendService.generateAdvice( + imageBase64: base64, appName: frame.appName, - windowTitle: frame.windowTitle, - referenceTime: now, - lookbackStart: lookbackStart, - trackSqlCount: false + windowTitle: frame.windowTitle ?? "" ) - return result - } - // MARK: - Core Extraction (shared by production + test) - - /// Two-phase advice extraction: - /// Phase 1 (text-only): Activity summary + SQL investigation loop. Model investigates via - /// execute_sql, then calls `request_screenshot` with an ID and its findings so far. - /// Phase 2 (single vision call): Load the chosen screenshot + Phase 1 findings → single - /// Gemini call with image → provide_advice or no_advice. - /// Returns (result, sqlQueryCount). 
- private func runAdviceExtraction( - jpegData: Data?, - appName: String, - windowTitle: String?, - referenceTime: Date, - lookbackStart: Date, - trackSqlCount: Bool - ) async throws -> (AdviceExtractionResult?, Int) { - var sqlCount = 0 - - // Build prompt with current context - let timeFormatter = DateFormatter() - timeFormatter.dateFormat = "h:mm a, EEEE" - var prompt = "CURRENT APP: \(appName)." - if let windowTitle = windowTitle, !windowTitle.isEmpty { - prompt += " Window: \"\(windowTitle)\"." - } - prompt += " Time: \(timeFormatter.string(from: referenceTime))." - - // Add activity summary from database, anchored to the reference time - let elapsed = referenceTime.timeIntervalSince(lookbackStart) - log("Advice: Activity lookback: \(String(format: "%.0f", elapsed))s (\(lookbackStart) to \(referenceTime))") - let activitySummary = await buildActivitySummary(from: lookbackStart, to: referenceTime) - if !activitySummary.isEmpty { - prompt += "\n\n" + activitySummary - log("Advice: --- ACTIVITY SUMMARY ---\n\(activitySummary)") - } else { - log("Advice: --- ACTIVITY SUMMARY --- (empty, no screenshots in range)") - } - - // Add user profile for context - if let profile = await AIUserProfileService.shared.getLatestProfile() { - prompt += "\n\nUSER PROFILE (who this user is):\n" - prompt += profile.profileText + "\n" - } - - // Add previous advice for dedup - if !previousAdvice.isEmpty { - prompt += "\n\nPREVIOUSLY PROVIDED ADVICE (do not repeat these or semantically similar):\n" - let adviceToInclude = previousAdvice.prefix(maxAdviceInPrompt) - for (index, advice) in adviceToInclude.enumerated() { - prompt += "\(index + 1). \(advice.advice)" - if let reasoning = advice.reasoning { - prompt += " (Reasoning: \(reasoning))" - } - prompt += "\n" - } - prompt += "\nOnly provide advice if there's a genuinely NEW non-obvious insight not covered above." - } else { - prompt += "\n\nOnly provide advice if there's something specific and non-obvious that would help." 
- } - - prompt += "\n\nInvestigate the activity summary. Scan OCR from the TOP 3-5 apps (not just the dominant one) — the best insights often come from browsers, communication apps, and notes, not just the app with the most screenshots. Skip apps with < 10 screenshots. When you've identified the most interesting screenshot, call request_screenshot with the ID and your findings. Or call no_advice if nothing qualifies." - - log("Advice: --- PROMPT ---\n\(prompt)") - - // Build system prompt - var currentSystemPrompt = await systemPrompt - if let language = await getUserLanguage(), language != "en" { - currentSystemPrompt += "\n\nIMPORTANT: Respond in the user's preferred language: \(language)" - } - currentSystemPrompt += "\n\nDATABASE SCHEMA for execute_sql:\nscreenshots table columns: id INTEGER, timestamp TEXT, appName TEXT, windowTitle TEXT, ocrText TEXT, focusStatus TEXT" - - // ============================================= - // PHASE 1: Text-only investigation loop - // ============================================= - - let phase1Tools = buildPhase1Tools() - var contents: [GeminiImageToolRequest.Content] = [ - GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(text: prompt)] + // Parse backend response into AdviceExtractionResult + guard let adviceDict = backendResult.advice as? [String: Any] else { + return AdviceExtractionResult( + hasAdvice: false, + advice: nil, + contextSummary: "Analyzed \(frame.appName)", + currentActivity: "" ) - ] - - let client = self.geminiClient - var chosenScreenshotId: Int64? - var investigationFindings: String? 
- - for iteration in 0..<7 { - let iterContents = contents - let iterSystemPrompt = currentSystemPrompt - let iterTools = [phase1Tools] - let iterForce = iteration == 0 - let result: ToolChatResult - do { - result = try await withThrowingTimeout(seconds: 120) { - try await client.sendImageToolLoop( - contents: iterContents, - systemPrompt: iterSystemPrompt, - tools: iterTools, - forceToolCall: iterForce - ) - } - } catch { - log("Advice: Phase 1 failed on iteration \(iteration): \(error.localizedDescription)") - throw error - } - - guard let toolCall = result.toolCalls.first else { - log("Advice: Phase 1 — no tool call on iteration \(iteration), breaking") - break - } - - switch toolCall.name { - case "execute_sql": - let query = toolCall.arguments["query"] as? String ?? "" - sqlCount += 1 - log("Advice: P1 execute_sql iter \(iteration): \(query)") - let sqlToolCall = ToolCall(name: "execute_sql", arguments: ["query": query], thoughtSignature: nil) - let resultStr = await ChatToolExecutor.execute(sqlToolCall) - let truncated = resultStr.count > 2000 ? String(resultStr.prefix(2000)) + "... (truncated)" : resultStr - log("Advice: P1 sql result (\(resultStr.count) chars): \(truncated)") - - contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: ["query": query]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: resultStr) - ))] - )) - continue - - case "request_screenshot": - let findings = toolCall.arguments["findings"] as? String ?? "" - investigationFindings = findings - if let idInt = toolCall.arguments["screenshot_id"] as? Int { - chosenScreenshotId = Int64(idInt) - } else if let idInt64 = toolCall.arguments["screenshot_id"] as? 
Int64 { - chosenScreenshotId = idInt64 - } else if let idStr = toolCall.arguments["screenshot_id"] as? String, let parsed = Int64(idStr) { - chosenScreenshotId = parsed - } else if let idDouble = toolCall.arguments["screenshot_id"] as? Double { - chosenScreenshotId = Int64(idDouble) - } - log("Advice: P1 request_screenshot iter \(iteration): id=\(chosenScreenshotId ?? 0), findings=\(findings.prefix(200))") - break // Exit phase 1 - - case "no_advice": - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "No context" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "Unknown" - log("Advice: P1 no_advice — \(contextSummary)") - return (AdviceExtractionResult( - hasAdvice: false, - advice: nil, - contextSummary: contextSummary, - currentActivity: currentActivity - ), sqlCount) - - default: - log("Advice: P1 unknown tool: \(toolCall.name), breaking") - break - } - - // Break out of loop if request_screenshot was called - if chosenScreenshotId != nil { break } - } - - // If Phase 1 exhausted without choosing a screenshot, no advice - guard let screenshotId = chosenScreenshotId, let findings = investigationFindings else { - log("Advice: Phase 1 exhausted without request_screenshot") - return (nil, sqlCount) } - // ============================================= - // PHASE 2: Single vision call with chosen screenshot - // ============================================= - - log("Advice: Phase 2 — loading screenshot \(screenshotId)") - - // Load the screenshot image - let imageData: Data - do { - guard let screenshot = try await RewindDatabase.shared.getScreenshot(id: screenshotId) else { - log("Advice: P2 screenshot not in DB: \(screenshotId)") - return (nil, sqlCount) - } - // Check active chunk - if screenshot.usesVideoStorage, let chunk = screenshot.videoChunkPath { - let activeChunk = await VideoChunkEncoder.shared.currentChunkPath - if chunk == activeChunk { - log("Advice: P2 screenshot is in active chunk, skipping") - 
return (nil, sqlCount) - } - } - let rawData = try await RewindStorage.shared.loadScreenshotData(for: screenshot) - imageData = Self.compressForGemini(rawData) ?? rawData - log("Advice: P2 loaded \(imageData.count) bytes (\(rawData.count) raw) from \(screenshot.appName)") - } catch { - log("Advice: P2 screenshot load failed: \(error.localizedDescription)") - return (nil, sqlCount) - } - - // Build Phase 2 prompt — compact findings + image + cross-reference instruction - let phase2Prompt = """ - INVESTIGATION FINDINGS: - \(findings) - - The screenshot below is from the app/window identified during investigation. - - Before giving advice, CROSS-REFERENCE your findings: - - Use execute_sql to check if this issue was resolved in later screenshots - - Check if the user moved on to something else (the issue may be stale) - - Verify the context is still relevant by looking at nearby timestamps - - Then call provide_advice if the insight is still valid, or no_advice if it was resolved or is no longer relevant. - """ - - let phase2Tools = buildPhase2Tools() - let base64 = imageData.base64EncodedString() - var phase2Contents: [GeminiImageToolRequest.Content] = [ - GeminiImageToolRequest.Content( - role: "user", - parts: [ - GeminiImageToolRequest.Part(text: phase2Prompt), - GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64), - ] + let hasAdvice = adviceDict["has_advice"] as? Bool ?? 
!adviceDict.isEmpty + guard hasAdvice else { + return AdviceExtractionResult( + hasAdvice: false, + advice: nil, + contextSummary: "Analyzed \(frame.appName)", + currentActivity: "" ) - ] - - // Phase 2 loop — model can cross-reference via SQL before deciding - for p2Iteration in 0..<5 { - let p2Contents = phase2Contents - let p2SystemPrompt = currentSystemPrompt - let p2Tools = [phase2Tools] - let p2Force = p2Iteration == 0 - let phase2Result: ToolChatResult - do { - phase2Result = try await withThrowingTimeout(seconds: 120) { - try await client.sendImageToolLoop( - contents: p2Contents, - systemPrompt: p2SystemPrompt, - tools: p2Tools, - forceToolCall: p2Force - ) - } - } catch { - log("Advice: Phase 2 failed on iteration \(p2Iteration): \(error.localizedDescription)") - throw error - } - - guard let toolCall = phase2Result.toolCalls.first else { - log("Advice: Phase 2 — no tool call on iteration \(p2Iteration), breaking") - break - } - - switch toolCall.name { - case "execute_sql": - let query = toolCall.arguments["query"] as? String ?? "" - sqlCount += 1 - log("Advice: P2 execute_sql iter \(p2Iteration): \(query)") - let sqlToolCall = ToolCall(name: "execute_sql", arguments: ["query": query], thoughtSignature: nil) - let resultStr = await ChatToolExecutor.execute(sqlToolCall) - let truncated = resultStr.count > 2000 ? String(resultStr.prefix(2000)) + "... 
(truncated)" : resultStr - log("Advice: P2 sql result (\(resultStr.count) chars): \(truncated)") - - phase2Contents.append(GeminiImageToolRequest.Content( - role: "model", - parts: [GeminiImageToolRequest.Part( - functionCall: .init(name: toolCall.name, args: ["query": query]), - thoughtSignature: toolCall.thoughtSignature - )] - )) - phase2Contents.append(GeminiImageToolRequest.Content( - role: "user", - parts: [GeminiImageToolRequest.Part(functionResponse: .init( - name: toolCall.name, - response: .init(result: resultStr) - ))] - )) - continue - - case "provide_advice": - log("Advice: P2 provide_advice (after \(p2Iteration) cross-reference iterations)") - return (parseProvideAdvice(toolCall), sqlCount) - - case "no_advice": - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "No context" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "Unknown" - log("Advice: P2 no_advice — \(contextSummary)") - return (AdviceExtractionResult( - hasAdvice: false, - advice: nil, - contextSummary: contextSummary, - currentActivity: currentActivity - ), sqlCount) - - default: - log("Advice: P2 unexpected tool: \(toolCall.name)") - break - } - break // Break on unexpected tool - } - return (nil, sqlCount) - } - - // MARK: - Activity Summary - - /// Query the screenshots table to build a summary of recent activity. - /// - `from`: lower bound (e.g. last analysis time or screenshot.timestamp - interval) - /// - `to`: upper bound (e.g. now or the screenshot's timestamp) - private func buildActivitySummary(from lookbackStart: Date, to referenceTime: Date) async -> String { - guard let dbQueue = await RewindDatabase.shared.getDatabaseQueue() else { - return "" } - do { - return try await dbQueue.read { db in - // Pass Date objects directly — GRDB encodes them as UTC strings - // matching the stored format. Manual DateFormatter uses local timezone - // which causes mismatches. 
- let rows = try Row.fetchAll(db, sql: """ - SELECT appName, windowTitle, COUNT(*) as count, - MIN(timestamp) as first_seen, MAX(timestamp) as last_seen - FROM screenshots - WHERE timestamp >= ? AND timestamp <= ? - AND appName IS NOT NULL AND appName != '' - GROUP BY appName, windowTitle - ORDER BY count DESC - LIMIT 30 - """, arguments: [lookbackStart, referenceTime]) - - if rows.isEmpty { - return "" - } - - let totalScreenshots = rows.reduce(0) { $0 + (($1["count"] as? Int64).map(Int.init) ?? ($1["count"] as? Int) ?? 0) } - let elapsedMin = referenceTime.timeIntervalSince(lookbackStart) / 60.0 - - let timeOnlyFormatter = DateFormatter() - timeOnlyFormatter.dateFormat = "HH:mm:ss" - - var lines: [String] = [] - lines.append("ACTIVITY SUMMARY (last \(Int(elapsedMin)) min, \(totalScreenshots) screenshots):") - lines.append("Time range: \(timeOnlyFormatter.string(from: lookbackStart)) – \(timeOnlyFormatter.string(from: referenceTime))") - lines.append("") - lines.append("App | Window | Screenshots | Est. Duration") - lines.append(String(repeating: "-", count: 60)) - - for row in rows { - let app = row["appName"] as? String ?? "Unknown" - let window = row["windowTitle"] as? String ?? "" - let count = (row["count"] as? Int64).map(Int.init) ?? (row["count"] as? Int) ?? 0 - let estMinutes = String(format: "%.1f", Double(count) / 60.0) - let windowDisplay = window.isEmpty ? "(no title)" : String(window.prefix(50)) - lines.append("\(app) | \(windowDisplay) | \(count) | \(estMinutes) min") - } - - let summary = lines.joined(separator: "\n") - log("Advice: Activity summary (last \(Int(elapsedMin)) min, \(totalScreenshots) screenshots)") - return summary - } - } catch { - logError("Advice: Failed to build activity summary", error: error) - return "" + let adviceText = adviceDict["content"] as? String ?? adviceDict["advice"] as? String ?? 
"" + guard !adviceText.isEmpty else { + return AdviceExtractionResult( + hasAdvice: false, + advice: nil, + contextSummary: "Analyzed \(frame.appName)", + currentActivity: "" + ) } - } - // MARK: - Tool Definitions - - /// Phase 1 tools: text-only investigation (execute_sql, request_screenshot, no_advice) - private func buildPhase1Tools() -> GeminiTool { - GeminiTool(functionDeclarations: [ - GeminiTool.FunctionDeclaration( - name: "execute_sql", - description: "Execute a SQL query on the local database to investigate screen activity. The screenshots table has: id INTEGER, timestamp TEXT, appName TEXT, windowTitle TEXT, ocrText TEXT, focusStatus TEXT. Use this to read OCR text from interesting windows, check what the user was doing, etc. SELECT queries only. Auto-limited to 200 rows.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init(type: "string", description: "SQL SELECT query to execute on the screenshots table") - ], - required: ["query"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "request_screenshot", - description: "Request to view a specific screenshot. Call this when you've found something interesting via SQL and want to see the actual screen. Provide the screenshot ID and a summary of your findings so far. The screenshot will be shown to you for final analysis.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "screenshot_id": .init(type: "integer", description: "The screenshot ID from the screenshots table"), - "findings": .init(type: "string", description: "Summary of what you found during investigation — what app, what OCR text caught your attention, and what you suspect might be worth advising about") - ], - required: ["screenshot_id", "findings"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "no_advice", - description: "Call this when there is nothing worth advising about. Nothing qualifies as a specific, non-obvious insight. 
This ends the analysis.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "High-level description of user's activity") - ], - required: ["context_summary", "current_activity"] - ) - ), - ]) - } - - /// Phase 2 tools: vision call with screenshot + SQL cross-referencing (execute_sql, provide_advice, no_advice) - private func buildPhase2Tools() -> GeminiTool { - GeminiTool(functionDeclarations: [ - GeminiTool.FunctionDeclaration( - name: "execute_sql", - description: "Cross-reference your findings by querying the database. Use this to check if an issue was resolved in later screenshots, verify context across time, or look up related activity. The screenshots table has: id INTEGER, timestamp TEXT, appName TEXT, windowTitle TEXT, ocrText TEXT, focusStatus TEXT. SELECT queries only.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init(type: "string", description: "SQL SELECT query to execute on the screenshots table") - ], - required: ["query"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "provide_advice", - description: "Call this when you have a specific, non-obvious insight for the user based on the screenshot and your investigation findings. You should cross-reference first using execute_sql to verify the issue is still relevant.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "advice": .init(type: "string", description: "The advice text (1-2 sentences, max 100 chars). Start with what you noticed, then why it matters."), - "headline": .init(type: "string", description: "Ultra-short observation (max 5 words) for notification preview. E.g. 
'Draft saved in /tmp', 'Credentials visible in terminal'"), - "reasoning": .init(type: "string", description: "Brief explanation of why this advice is relevant"), - "category": .init(type: "string", description: "Category of advice", enumValues: ["productivity", "communication", "learning", "other"]), - "source_app": .init(type: "string", description: "App where context was observed"), - "confidence": .init(type: "number", description: "Confidence score 0.0-1.0. 0.90+: preventing clear mistake. 0.75-0.89: highly relevant non-obvious tip. 0.60-0.74: useful but user might know."), - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "High-level description of user's activity") - ], - required: ["advice", "headline", "category", "source_app", "confidence", "context_summary", "current_activity"] - ) - ), - GeminiTool.FunctionDeclaration( - name: "no_advice", - description: "Call this when the screenshot doesn't reveal anything worth advising about, or when cross-referencing shows the issue was already resolved.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "context_summary": .init(type: "string", description: "Brief summary of what user is looking at"), - "current_activity": .init(type: "string", description: "High-level description of user's activity") - ], - required: ["context_summary", "current_activity"] - ) - ), - ]) - } - - // MARK: - Parse Tool Results - - /// Parse the provide_advice tool call into an AdviceExtractionResult - private func parseProvideAdvice(_ toolCall: ToolCall) -> AdviceExtractionResult { - let adviceText = toolCall.arguments["advice"] as? String ?? "" - let headline = toolCall.arguments["headline"] as? String - let reasoning = toolCall.arguments["reasoning"] as? String - let categoryStr = toolCall.arguments["category"] as? String ?? "other" + let categoryStr = adviceDict["category"] as? 
String ?? "other" let category = AdviceCategory(rawValue: categoryStr) ?? .other - let sourceApp = toolCall.arguments["source_app"] as? String ?? "" - let contextSummary = toolCall.arguments["context_summary"] as? String ?? "" - let currentActivity = toolCall.arguments["current_activity"] as? String ?? "" - let confidence: Double - if let confValue = toolCall.arguments["confidence"] as? Double { + if let confValue = adviceDict["confidence"] as? Double { confidence = confValue - } else if let confInt = toolCall.arguments["confidence"] as? Int { + } else if let confInt = adviceDict["confidence"] as? Int { confidence = Double(confInt) - } else if let confStr = toolCall.arguments["confidence"] as? String, let parsed = Double(confStr) { - confidence = parsed } else { confidence = 0.5 } let advice = ExtractedAdvice( advice: adviceText, - headline: headline, - reasoning: reasoning, + headline: adviceDict["headline"] as? String, + reasoning: adviceDict["reasoning"] as? String, category: category, - sourceApp: sourceApp, + sourceApp: frame.appName, confidence: confidence ) - log("Advice: --- PROVIDE_ADVICE ---") - log("Advice: advice: \(adviceText)") - log("Advice: headline: \(headline ?? "(none)")") - log("Advice: reasoning: \(reasoning ?? "(none)")") - log("Advice: category: \(categoryStr)") - log("Advice: source_app: \(sourceApp)") - log("Advice: confidence: \(confidence)") - log("Advice: context: \(contextSummary)") - log("Advice: activity: \(currentActivity)") return AdviceExtractionResult( hasAdvice: true, advice: advice, - contextSummary: contextSummary, - currentActivity: currentActivity + contextSummary: adviceDict["context_summary"] as? String ?? "Analyzed \(frame.appName)", + currentActivity: adviceDict["current_activity"] as? String ?? "" ) } } - -// MARK: - Timeout Helper - -/// Run an async operation with a timeout. Throws `CancellationError` if the timeout expires. 
-private func withThrowingTimeout(seconds: Double, operation: @escaping @Sendable () async throws -> T) async throws -> T { - try await withThrowingTaskGroup(of: T.self) { group in - group.addTask { - try await operation() - } - group.addTask { - try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000)) - throw CancellationError() - } - // First task to complete wins; cancel the other - let result = try await group.next()! - group.cancelAll() - return result - } -} From daefcaf775b202f49cb6a8d9535d3682db5e06f5 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:19 +0100 Subject: [PATCH 144/163] Wire TaskDeduplicationService thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace GeminiClient with backendService.deduplicateTasks(). Remove prompt/schema building, local dedup logic — server handles everything. Co-Authored-By: Claude Opus 4.6 --- .../TaskDeduplicationService.swift | 214 ++---------------- 1 file changed, 22 insertions(+), 192 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift index 4618f99673..98b38a8761 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskDeduplicationService.swift @@ -6,7 +6,7 @@ import Foundation actor TaskDeduplicationService { static let shared = TaskDeduplicationService() - private var geminiClient: GeminiClient? + private var backendService: BackendProactiveService? private var timer: Task? private var isRunning = false private var lastRunTime: Date? 
@@ -17,13 +17,11 @@ actor TaskDeduplicationService { private let cooldownSeconds: TimeInterval = 1800 // 30-min cooldown private let minimumTaskCount = 3 - private init() { - do { - self.geminiClient = try GeminiClient(model: "gemini-pro-latest") - } catch { - log("TaskDedup: Failed to initialize GeminiClient: \(error)") - self.geminiClient = nil - } + private init() {} + + /// Set the backend service for Phase 2 server-side deduplication. + func configure(backendService: BackendProactiveService) { + self.backendService = backendService } // MARK: - Lifecycle @@ -67,210 +65,42 @@ actor TaskDeduplicationService { // MARK: - Deduplication Logic private func runDeduplication() async { - guard let client = geminiClient else { - log("TaskDedup: Skipping - Gemini client not initialized") + guard let service = backendService else { + log("TaskDedup: Skipping - backend service not configured") return } lastRunTime = Date() - log("TaskDedup: Starting deduplication run on staged tasks") - - // 1. Fetch staged tasks (not yet promoted to action items) - let tasks: [TaskActionItem] - do { - let response = try await APIClient.shared.getStagedTasks(limit: 200) - tasks = response.items - } catch { - log("TaskDedup: Failed to fetch staged tasks: \(error)") - return - } - - guard tasks.count >= minimumTaskCount else { - log("TaskDedup: Only \(tasks.count) staged tasks, skipping (minimum: \(minimumTaskCount))") - return - } - - log("TaskDedup: Analyzing \(tasks.count) staged tasks for duplicates") - - // 2. Send all tasks to Gemini in a single call - let totalDeleted = await analyzeAndDeleteDuplicates(tasks: tasks, client: client) - - log("TaskDedup: Run complete. 
Hard-deleted \(totalDeleted) duplicate staged tasks.") - } - - private func analyzeAndDeleteDuplicates(tasks: [TaskActionItem], client: GeminiClient) async -> Int { - // Build task list for prompt - let taskDescriptions = tasks.map { task -> String in - var parts = ["ID: \(task.id)", "Description: \(task.description)"] - if let due = task.dueAt { - parts.append("Due: \(ISO8601DateFormatter().string(from: due))") - } - if let priority = task.priority { - parts.append("Priority: \(priority)") - } - if let source = task.source { - parts.append("Source: \(source)") - } - parts.append("Created: \(ISO8601DateFormatter().string(from: task.createdAt))") - return parts.joined(separator: "\n") - }.joined(separator: "\n") - - let prompt = """ - Analyze the following tasks for semantic duplicates. Two tasks are duplicates if they \ - refer to the same action, even if worded differently. - - Tasks: - \(taskDescriptions) - - For each group of duplicates, pick the best task to KEEP based on these criteria (in order): - 1. Most descriptive/specific wording - 2. Has a due date over one that doesn't - 3. Higher priority set (high > medium > low > none) - 4. More reliable source (manual > transcription > screenshot) - 5. Most recently created - - Only flag tasks as duplicates if you are confident they refer to the same action. \ - When in doubt, do NOT flag as duplicates. - """ - - let systemPrompt = """ - You are a task deduplication assistant. You identify semantically duplicate tasks \ - and choose the best one to keep. Be conservative - only flag clear duplicates. \ - Return has_duplicates: false if no duplicates are found. 
- """ - - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "has_duplicates": .init(type: "boolean", description: "Whether any duplicate groups were found"), - "duplicate_groups": .init( - type: "array", - description: "Groups of duplicate tasks", - items: .init( - type: "object", - properties: [ - "keep_id": .init(type: "string", description: "ID of the task to keep"), - "delete_ids": .init( - type: "array", - description: "IDs of tasks to delete", - items: .init(type: "string", properties: nil, required: nil) - ), - "reason": .init(type: "string", description: "Why these tasks are duplicates and which was kept") - ], - required: ["keep_id", "delete_ids", "reason"] - ) - ) - ], - required: ["has_duplicates", "duplicate_groups"] - ) - - // Call Gemini - let responseText: String - do { - responseText = try await client.sendRequest( - prompt: prompt, - systemPrompt: systemPrompt, - responseSchema: responseSchema - ) - } catch { - log("TaskDedup: Gemini request failed: \(error)") - return 0 - } - - // Parse response - guard let data = responseText.data(using: .utf8) else { - log("TaskDedup: Failed to convert response to data") - return 0 - } + log("TaskDedup: Starting server-side deduplication") - let result: DedupResponse do { - result = try JSONDecoder().decode(DedupResponse.self, from: data) - } catch { - log("TaskDedup: Failed to parse response: \(error)") - return 0 - } - - guard result.hasDuplicates, !result.duplicateGroups.isEmpty else { - log("TaskDedup: No duplicates found in batch of \(tasks.count) staged tasks") - return 0 - } + let result = try await service.deduplicateTasks() - // Validate and delete - let validTaskIDs = Set(tasks.map { $0.id }) - let taskLookup = Dictionary(tasks.map { ($0.id, $0) }, uniquingKeysWith: { _, latest in latest }) - var deletedCount = 0 - - for group in result.duplicateGroups { - // Safety: verify all IDs exist in our input - guard validTaskIDs.contains(group.keepId) else 
{ - log("TaskDedup: Skipping group - keep_id '\(group.keepId)' not in input set") - continue - } - - let validDeleteIds = group.deleteIds.filter { validTaskIDs.contains($0) } - if validDeleteIds.count != group.deleteIds.count { - log("TaskDedup: Some delete IDs not in input set, filtering") + if result.deletedIds.isEmpty { + log("TaskDedup: No duplicates found") + return } - guard !validDeleteIds.isEmpty else { continue } - - let keptTask = taskLookup[group.keepId] + log("TaskDedup: Server deleted \(result.deletedIds.count) duplicates. Reason: \(result.reason)") - for deleteId in validDeleteIds { - let deletedTask = taskLookup[deleteId] - - // Log to SQLite + // Log each deletion locally + for deleteId in result.deletedIds { let logRecord = TaskDedupLogRecord( deletedTaskId: deleteId, - deletedDescription: deletedTask?.description ?? "unknown", - keptTaskId: group.keepId, - keptDescription: keptTask?.description ?? "unknown", - reason: group.reason, + deletedDescription: "server-side dedup", + keptTaskId: "", + keptDescription: "", + reason: result.reason, deletedAt: Date() ) - do { try await ProactiveStorage.shared.insertDedupLogRecord(logRecord) } catch { log("TaskDedup: Failed to log deletion record: \(error)") } - - // Hard-delete staged task from backend - do { - try await APIClient.shared.deleteStagedTask(id: deleteId) - deletedCount += 1 - log("TaskDedup: Hard-deleted staged task '\(deletedTask?.description ?? deleteId)' (kept: '\(keptTask?.description ?? 
group.keepId)') - \(group.reason)") - } catch { - log("TaskDedup: Failed to delete staged task \(deleteId) on backend: \(error)") - } } - } - - return deletedCount - } -} - -// MARK: - Response Models - -private struct DedupResponse: Codable { - let hasDuplicates: Bool - let duplicateGroups: [DuplicateGroup] - - enum CodingKeys: String, CodingKey { - case hasDuplicates = "has_duplicates" - case duplicateGroups = "duplicate_groups" - } - - struct DuplicateGroup: Codable { - let keepId: String - let deleteIds: [String] - let reason: String - - enum CodingKeys: String, CodingKey { - case keepId = "keep_id" - case deleteIds = "delete_ids" - case reason + } catch { + log("TaskDedup: Server deduplication failed: \(error)") } } } From 822c3c0f46e807dd5549cd2cd110430f38ecb052 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:25 +0100 Subject: [PATCH 145/163] Wire TaskPrioritizationService thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace GeminiClient with backendService.rerankTasks(). Remove prompt/ schema building, context fetching — server handles reranking. Co-Authored-By: Claude Opus 4.6 --- .../TaskPrioritizationService.swift | 263 ++---------------- 1 file changed, 29 insertions(+), 234 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift index a1aeca228f..ac922254ef 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskPrioritizationService.swift @@ -7,7 +7,7 @@ import Foundation actor TaskPrioritizationService { static let shared = TaskPrioritizationService() - private var geminiClient: GeminiClient? + private var backendService: BackendProactiveService? 
private var timer: Task? private var isRunning = false private(set) var isScoringInProgress = false @@ -29,13 +29,6 @@ actor TaskPrioritizationService { // Restore persisted timestamps self.lastFullRunTime = UserDefaults.standard.object(forKey: Self.fullRunKey) as? Date - do { - self.geminiClient = try GeminiClient(model: "gemini-pro-latest") - } catch { - log("TaskPrioritize: Failed to initialize GeminiClient: \(error)") - self.geminiClient = nil - } - if let last = self.lastFullRunTime { let hoursAgo = Int(Date().timeIntervalSince(last) / 3600) log("TaskPrioritize: Last full rescore was \(hoursAgo)h ago") @@ -44,6 +37,11 @@ actor TaskPrioritizationService { } } + /// Set the backend service for Phase 2 server-side reranking. + func configure(backendService: BackendProactiveService) { + self.backendService = backendService + } + // MARK: - Lifecycle func start() { @@ -101,187 +99,48 @@ actor TaskPrioritizationService { // MARK: - Full Rescore (Hourly) - /// Send ALL staged tasks to Gemini, get back only the ones that need re-ranking + /// Request server-side reranking via backend WebSocket. 
private func runFullRescore() async { guard !isScoringInProgress else { log("TaskPrioritize: [FULL] Skipping — scoring already in progress") return } - guard let client = geminiClient else { - log("TaskPrioritize: Skipping full rescore — Gemini client not initialized") + guard let service = backendService else { + log("TaskPrioritize: Skipping full rescore — backend service not configured") return } isScoringInProgress = true defer { isScoringInProgress = false } - log("TaskPrioritize: [FULL] Starting hourly rescore of staged tasks") + log("TaskPrioritize: [FULL] Starting server-side rescore") - // Get ALL staged tasks (not action_items) - let allTasks: [TaskActionItem] do { - allTasks = try await StagedTaskStorage.shared.getAllStagedTasks(limit: 10000) - } catch { - log("TaskPrioritize: [FULL] Failed to fetch staged tasks: \(error)") - return - } - - log("TaskPrioritize: [FULL] Found \(allTasks.count) staged tasks") + let result = try await service.rerankTasks() - guard allTasks.count >= minimumTaskCount else { - log("TaskPrioritize: [FULL] Only \(allTasks.count) staged tasks, skipping") - lastFullRunTime = Date() - return - } - - // Fetch context - let (referenceContext, profile, goals) = await fetchContext() - - // Build the current ranking: tasks ordered by relevanceScore ASC (1 = top) - let sortedTasks = allTasks.sorted { a, b in - let scoreA = a.relevanceScore ?? Int.max - let scoreB = b.relevanceScore ?? Int.max - return scoreA < scoreB - } - - // Build task list for the prompt with current positions - let taskLines = sortedTasks.enumerated().map { (index, task) -> String in - var parts = ["\(index + 1). 
[id:\(task.id)] \(task.description)"] - if let priority = task.priority { - parts.append("[\(priority)]") - } - if let due = task.dueAt { - let formatter = ISO8601DateFormatter() - parts.append("[due: \(formatter.string(from: due))]") + if result.updatedTasks.isEmpty { + log("TaskPrioritize: [FULL] No tasks need re-ranking, current order is good") + lastFullRunTime = Date() + return } - return parts.joined(separator: " ") - }.joined(separator: "\n") - - // Build context sections - var contextParts: [String] = [] - if let profile = profile, !profile.isEmpty { - contextParts.append("USER PROFILE:\n\(profile)") - } + // Parse server response into reranking tuples + let reranks: [(backendId: String, newPosition: Int)] = result.updatedTasks.compactMap { dict in + guard let id = dict["id"] as? String, + let newPos = dict["new_position"] as? Int else { return nil } + return (backendId: id, newPosition: newPos) + } - if !goals.isEmpty { - let goalsText = goals.enumerated().map { (i, goal) in - var text = "\(i + 1). \(goal.title)" - if let desc = goal.description { - text += " — \(desc)" + if !reranks.isEmpty { + do { + try await StagedTaskStorage.shared.applySelectiveReranking(reranks) + log("TaskPrioritize: [FULL] Applied server re-ranking for \(reranks.count) staged tasks") + } catch { + log("TaskPrioritize: [FULL] Failed to apply re-ranking: \(error)") } - text += " (\(Int(goal.progress))% complete)" - return text - }.joined(separator: "\n") - contextParts.append("ACTIVE GOALS:\n\(goalsText)") - } - - if !referenceContext.isEmpty { - contextParts.append(referenceContext) - } - - let contextSection = contextParts.isEmpty ? "" : contextParts.joined(separator: "\n\n") + "\n\n" - - let prompt = """ - Review the user's staged task list (ranked 1 = most important, \(sortedTasks.count) = least important). - - Identify tasks that are MISRANKED — tasks whose current position doesn't match their actual importance. - Only return tasks that need to move. 
Do NOT return tasks that are already well-positioned. - - Consider: - 1. Alignment with the user's goals and current priorities - 2. Time urgency (due date proximity) - 3. Actionability — specific tasks rank higher than vague ones - 4. Real-world importance (financial, health, commitments to others) - 5. Most AI-extracted tasks are noise — push vague/irrelevant tasks down - - \(contextSection)CURRENT TASK RANKING (1 = most important): - \(taskLines) - - Return ONLY the tasks that need re-ranking, with their new position numbers. - New positions should be relative to the current list size (1 to \(sortedTasks.count)). - """ - - let systemPrompt = """ - You are a task prioritization assistant. You review a ranked task list and identify \ - tasks that are misranked. Be selective — only return tasks that genuinely need to move. \ - If the ranking looks reasonable, return an empty list. Be decisive about pushing noise \ - and vague tasks down and promoting urgent, goal-aligned tasks up. - """ - - let responseSchema = GeminiRequest.GenerationConfig.ResponseSchema( - type: "object", - properties: [ - "reranked_tasks": .init( - type: "array", - description: "Tasks that need to be moved, with new positions", - items: .init( - type: "object", - properties: [ - "task_id": .init(type: "string", description: "The task ID"), - "new_position": .init(type: "integer", description: "New rank position (1 = most important)") - ], - required: ["task_id", "new_position"] - ) - ), - "reasoning": .init(type: "string", description: "Brief explanation of major ranking changes") - ], - required: ["reranked_tasks", "reasoning"] - ) - - log("TaskPrioritize: [FULL] Sending \(sortedTasks.count) staged tasks to Gemini") - - let responseText: String - do { - responseText = try await client.sendRequest( - prompt: prompt, - systemPrompt: systemPrompt, - responseSchema: responseSchema - ) - } catch { - log("TaskPrioritize: [FULL] Gemini request failed: \(error)") - return - } - - let truncated = 
responseText.prefix(500) - log("TaskPrioritize: [FULL] Gemini response (\(responseText.count) chars): \(truncated)\(responseText.count > 500 ? "..." : "")") - - guard let data = responseText.data(using: .utf8) else { - log("TaskPrioritize: [FULL] Failed to convert response to data") - return - } - - let result: ReRankingResponse - do { - result = try JSONDecoder().decode(ReRankingResponse.self, from: data) - } catch { - log("TaskPrioritize: [FULL] Failed to parse re-ranking response: \(error)") - return - } - - log("TaskPrioritize: [FULL] Gemini returned \(result.rerankedTasks.count) tasks to re-rank") - if !result.reasoning.isEmpty { - log("TaskPrioritize: [FULL] Reasoning: \(result.reasoning.prefix(300))") - } - - // Validate: only keep task IDs that exist in our list - let validIds = Set(allTasks.map { $0.id }) - let validReranks = result.rerankedTasks.filter { validIds.contains($0.taskId) } - - if validReranks.count != result.rerankedTasks.count { - log("TaskPrioritize: [FULL] Filtered out \(result.rerankedTasks.count - validReranks.count) invalid task IDs") - } - - if !validReranks.isEmpty { - let reranks = validReranks.map { (backendId: $0.taskId, newPosition: $0.newPosition) } - do { - try await StagedTaskStorage.shared.applySelectiveReranking(reranks) - log("TaskPrioritize: [FULL] Applied selective re-ranking for \(validReranks.count) staged tasks") - } catch { - log("TaskPrioritize: [FULL] Failed to apply re-ranking: \(error)") } - } else { - log("TaskPrioritize: [FULL] No tasks need re-ranking, current order is good") + } catch { + log("TaskPrioritize: [FULL] Server reranking failed: \(error)") } lastFullRunTime = Date() @@ -304,68 +163,4 @@ actor TaskPrioritizationService { } } - // MARK: - Shared Context Fetching - - private func fetchContext() async -> (referenceContext: String, profile: String?, goals: [Goal]) { - let userProfile = await AIUserProfileService.shared.getLatestProfile() - - let goals: [Goal] - do { - goals = try await 
APIClient.shared.getGoals() - } catch { - log("TaskPrioritize: Failed to fetch goals: \(error)") - goals = [] - } - - let referenceTasks: [TaskActionItem] - do { - referenceTasks = try await ActionItemStorage.shared.getLocalActionItems( - limit: 100, - completed: true - ) - } catch { - log("TaskPrioritize: Failed to fetch reference tasks: \(error)") - referenceTasks = [] - } - let referenceContext = buildReferenceContext(referenceTasks) - - return (referenceContext, userProfile?.profileText, goals) - } - - // MARK: - Context Builders - - private func buildReferenceContext(_ tasks: [TaskActionItem]) -> String { - guard !tasks.isEmpty else { return "" } - - let completed = tasks.filter { !($0.description.isEmpty) }.prefix(50) - guard !completed.isEmpty else { return "" } - - let lines = completed.map { task -> String in - "- [completed] \(task.description)" - }.joined(separator: "\n") - - return "TASKS THE USER HAS COMPLETED (for reference — do NOT rank these):\n\(lines)" - } -} - -// MARK: - Response Models - -private struct ReRankingResponse: Codable { - let rerankedTasks: [ReRankedTask] - let reasoning: String - - struct ReRankedTask: Codable { - let taskId: String - let newPosition: Int - - enum CodingKeys: String, CodingKey { - case taskId = "task_id" - case newPosition = "new_position" - } - } - - enum CodingKeys: String, CodingKey { - case rerankedTasks = "reranked_tasks" - case reasoning - } } From d7c6cf49a4ea8b3911578b1e574cd9164dcfcb13 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:33 +0100 Subject: [PATCH 146/163] Wire AIUserProfileService thin client for Phase 2 (#5396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace 2-stage Gemini profile generation with backendService.requestProfile(). Remove fetchDataSources, buildPrompt, buildConsolidationPrompt — server fetches user data from Firestore and generates profile server-side. 
Co-Authored-By: Claude Opus 4.6 --- .../Services/AIUserProfileService.swift | 307 ++---------------- 1 file changed, 20 insertions(+), 287 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift index ec36f82ece..111615bf75 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift @@ -34,7 +34,7 @@ extension AIUserProfileRecord: TableDocumented { actor AIUserProfileService { static let shared = AIUserProfileService() - private let model = "gemini-pro-latest" + private var backendService: BackendProactiveService? private let maxProfileLength = 10000 /// Whether profile generation is currently in progress @@ -48,6 +48,11 @@ actor AIUserProfileService { _dbQueue = nil } + /// Set the backend service for Phase 2 server-side profile generation. + func configure(backendService: BackendProactiveService) { + self.backendService = backendService + } + // MARK: - Database Access private func ensureDB() async throws -> DatabasePool { @@ -211,314 +216,42 @@ actor AIUserProfileService { }) ?? [] } - /// Generate a new AI user profile from all available data sources + /// Generate a new AI user profile via backend WebSocket. + /// The backend fetches all user data from Firestore and generates the profile server-side. func generateProfile() async throws -> AIUserProfileRecord { guard !isGenerating else { throw ProfileError.alreadyGenerating } + guard let service = backendService else { + throw ProfileError.databaseNotAvailable + } isGenerating = true defer { isGenerating = false } - log("AIUserProfileService: Starting profile generation") - - // 1. Fetch all data sources in parallel - let (memories, tasks, goals, conversations, messages) = await fetchDataSources() - - // 2. 
Count total data items - let dataSourcesUsed = memories.count + tasks.count + goals.count + conversations.count + messages.count - log("AIUserProfileService: Fetched \(dataSourcesUsed) data items (memories=\(memories.count), tasks=\(tasks.count), goals=\(goals.count), convos=\(conversations.count), messages=\(messages.count))") - - guard dataSourcesUsed > 0 else { - throw ProfileError.insufficientData - } - - // 3. Build prompt - let prompt = buildPrompt(memories: memories, tasks: tasks, goals: goals, conversations: conversations, messages: messages) - - // 4. Call Gemini - let gemini = try GeminiClient(model: model) - let systemPrompt = """ - You are generating a structured user profile that will be injected as context into AI pipelines \ - (task extraction, goal extraction, memory extraction) that analyze the user's screen and audio activity. - - OUTPUT FORMAT: - - A flat list of factual statements, one per line, prefixed with "- " - - Each statement must be a concrete fact directly supported by the provided data - - No prose, no paragraphs, no headers, no markdown formatting - - No adjectives like "passionate", "dedicated", "impressive" - - Write in third person ("User works at...", not "You work at...") - - WHAT TO INCLUDE (only if clearly supported by the data): - - Full name, role, company, industry - - Current projects and what tools/apps they use for each - - Key people they interact with (names, roles, relationship) - - Active goals and their progress - - Recurring meetings, deadlines, routines - - Communication platforms they use (Slack, email, iMessage, etc.) 
- - Technical stack, programming languages, frameworks - - Topics they frequently discuss or research - - Pending tasks and commitments to others - - Time zone, work schedule patterns - - CRITICAL RULES: - - ONLY include facts that are directly evidenced in the provided data - - If a category has no supporting data, skip it entirely — do not guess or infer - - Do NOT hallucinate names, roles, companies, or relationships not present in the data - - Do NOT add personality descriptions or subjective assessments - - When uncertain, omit rather than speculate - - NEVER fabricate email addresses, phone numbers, URLs, or contact information - - The provided data contains NO email addresses — do not invent any - - If you cannot find a piece of information verbatim in the data, do not include it - - The output MUST be under 2000 characters total. - """ - - let stageOneText = try await gemini.sendTextRequest(prompt: prompt, systemPrompt: systemPrompt) - log("AIUserProfileService: Stage 1 complete (\(stageOneText.count) chars)") - - // 5. Stage 2 — Consolidate with past profiles for holistic view - let pastProfiles = await getAllProfiles(limit: 5) - let finalText: String - if pastProfiles.isEmpty { - finalText = stageOneText - } else { - let consolidationPrompt = buildConsolidationPrompt( - newProfile: stageOneText, - pastProfiles: pastProfiles - ) - let consolidationSystemPrompt = """ - You are merging a newly generated user profile with historical profiles to create \ - one holistic, up-to-date user profile. This profile is injected as context into AI pipelines \ - (task extraction, goal extraction, memory extraction) that analyze the user's screen and audio activity. 
- - OUTPUT FORMAT: - - A flat list of factual statements, one per line, prefixed with "- " - - Each statement must be a concrete fact - - No prose, no paragraphs, no headers, no markdown formatting - - No adjectives or subjective assessments - - Write in third person - - MERGE RULES: - - The NEW profile reflects today's data and takes priority for current state - - Past profiles provide historical context — retain facts that are still relevant - - If a fact from the past contradicts the new profile, use the new one - - Remove outdated information (completed tasks, past deadlines, old routines) - - Keep stable facts (name, role, company, key relationships, tech stack) - - Accumulate knowledge: if past profiles mention people, projects, or patterns \ - not in today's data, keep them if they seem ongoing - - Do NOT hallucinate — only include facts present in the provided profiles - - Do NOT add commentary about changes or evolution over time - - The output MUST be under 2000 characters total. - """ - finalText = try await gemini.sendTextRequest( - prompt: consolidationPrompt, - systemPrompt: consolidationSystemPrompt - ) - log("AIUserProfileService: Stage 2 consolidation complete (\(finalText.count) chars)") - } + log("AIUserProfileService: Requesting server-side profile generation") - // 6. Truncate if needed - let truncated = String(finalText.prefix(maxProfileLength)) + let profileText = try await service.requestProfile() + let truncated = String(profileText.prefix(maxProfileLength)) let generatedAt = Date() - // 6. Save to database + log("AIUserProfileService: Received profile from backend (\(truncated.count) chars)") + + // Save to local database let db = try await ensureDB() let record = AIUserProfileRecord( profileText: truncated, - dataSourcesUsed: dataSourcesUsed, - backendSynced: false, + dataSourcesUsed: 0, + backendSynced: true, // Backend already has it generatedAt: generatedAt ) try await db.write { database in try record.insert(database) } - // 7. 
Sync to backend (fire-and-forget) - let recordId = record.id - Task { - do { - try await APIClient.shared.syncAIUserProfile( - profileText: truncated, - generatedAt: generatedAt, - dataSourcesUsed: dataSourcesUsed - ) - // Mark as synced - if let id = recordId, let db = try? await self.ensureDB() { - _ = try? await db.write { database in - try database.execute( - sql: "UPDATE ai_user_profiles SET backendSynced = 1 WHERE id = ?", - arguments: [id] - ) - } - } - log("AIUserProfileService: Synced profile to backend") - } catch { - log("AIUserProfileService: Failed to sync profile to backend: \(error.localizedDescription)") - } - } - - log("AIUserProfileService: Profile generated successfully (\(truncated.count) chars, \(dataSourcesUsed) data items)") + log("AIUserProfileService: Profile saved to local DB") return record } - // MARK: - Data Fetching - - private func fetchDataSources() async -> ( - memories: [String], - tasks: [String], - goals: [String], - conversations: [String], - messages: [String] - ) { - async let memoriesTask = fetchMemories() - async let tasksTask = fetchTasks() - async let goalsTask = fetchGoals() - async let conversationsTask = fetchConversations() - async let messagesTask = fetchMessages() - - let memories = await memoriesTask - let tasks = await tasksTask - let goals = await goalsTask - let conversations = await conversationsTask - let messages = await messagesTask - - return (memories, tasks, goals, conversations, messages) - } - - private func fetchMemories() async -> [String] { - do { - let memories = try await APIClient.shared.getMemories(limit: 100) - return memories.map { "[\($0.category.rawValue)] \($0.content)" } - } catch { - log("AIUserProfileService: Failed to fetch memories: \(error.localizedDescription)") - return [] - } - } - - private func fetchTasks() async -> [String] { - do { - let response = try await APIClient.shared.getActionItems(limit: 50) - return response.items.map { item in - let status = item.completed ? 
"done" : "todo" - let priority = item.priority ?? "medium" - return "[\(status)/\(priority)] \(item.description)" - } - } catch { - log("AIUserProfileService: Failed to fetch tasks: \(error.localizedDescription)") - return [] - } - } - - private func fetchGoals() async -> [String] { - do { - let goals = try await APIClient.shared.getGoals() - return goals.filter { $0.isActive }.map { goal in - let progress = goal.targetValue > 0 ? Int((goal.currentValue / goal.targetValue) * 100) : 0 - return "\(goal.title) (\(progress)% complete)" - } - } catch { - log("AIUserProfileService: Failed to fetch goals: \(error.localizedDescription)") - return [] - } - } - - private func fetchConversations() async -> [String] { - do { - let sevenDaysAgo = Calendar.current.date(byAdding: .day, value: -7, to: Date()) - let conversations = try await APIClient.shared.getConversations( - limit: 20, - startDate: sevenDaysAgo - ) - return conversations.compactMap { convo in - let title = convo.structured.title - let summary = convo.structured.overview - guard !title.isEmpty else { return nil } - return "\(title): \(summary)" - } - } catch { - log("AIUserProfileService: Failed to fetch conversations: \(error.localizedDescription)") - return [] - } - } - - private func fetchMessages() async -> [String] { - do { - let messages = try await APIClient.shared.getMessages(limit: 30) - return messages.map { "[\($0.sender)] \($0.text)" } - } catch { - log("AIUserProfileService: Failed to fetch messages: \(error.localizedDescription)") - return [] - } - } - - // MARK: - Prompt Building - - private func buildPrompt( - memories: [String], - tasks: [String], - goals: [String], - conversations: [String], - messages: [String] - ) -> String { - var sections: [String] = [] - - if !memories.isEmpty { - sections.append("## Memories about the user\n\(memories.joined(separator: "\n"))") - } - - if !tasks.isEmpty { - sections.append("## Recent tasks\n\(tasks.joined(separator: "\n"))") - } - - if !goals.isEmpty { - 
sections.append("## Active goals\n\(goals.joined(separator: "\n"))") - } - - if !conversations.isEmpty { - sections.append("## Recent conversations (past 7 days)\n\(conversations.joined(separator: "\n"))") - } - - if !messages.isEmpty { - sections.append("## Recent AI chat messages\n\(messages.joined(separator: "\n"))") - } - - return """ - Generate a factual user profile from the following data. \ - Output a flat list of concrete facts (one per line, prefixed with "- "). \ - This profile will be used as context for AI pipelines that analyze the user's screen and audio activity \ - to extract tasks, goals, and memories. Focus on facts that help identify who is who, what projects are active, \ - and what the user's current priorities are. Under 2000 characters. - - \(sections.joined(separator: "\n\n")) - """ - } - - private func buildConsolidationPrompt( - newProfile: String, - pastProfiles: [AIUserProfileRecord] - ) -> String { - let dateFormatter = DateFormatter() - dateFormatter.dateStyle = .medium - dateFormatter.timeStyle = .none - - var pastSection = "" - for profile in pastProfiles { - let dateStr = dateFormatter.string(from: profile.generatedAt) - pastSection += "--- Profile from \(dateStr) ---\n\(profile.profileText)\n\n" - } - - return """ - Merge the following into one holistic user profile. Under 2000 characters. - - === NEW PROFILE (generated today from latest data) === - \(newProfile) - - === PAST PROFILES (oldest to newest, up to 5) === - \(pastSection) - """ - } - // MARK: - Errors enum ProfileError: LocalizedError { From 336122919d23fc319f12f5c569ba289f764bc367 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 8 Mar 2026 10:36:40 +0100 Subject: [PATCH 147/163] Wire ProactiveAssistantsPlugin to pass backendService to all assistants (#5396) Pass shared BackendProactiveService to all 4 assistants and 3 text-only services. Remove do/catch since inits no longer throw. Update AdviceTestRunnerWindow fallback creation. 
Co-Authored-By: Claude Opus 4.6 --- .../ProactiveAssistantsPlugin.swift | 88 +++++++++---------- .../UI/AdviceTestRunnerWindow.swift | 22 ++--- 2 files changed, 47 insertions(+), 63 deletions(-) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift index d6d3409675..c721c1bd5a 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift @@ -325,62 +325,58 @@ public class ProactiveAssistantsPlugin: NSObject { proactiveService.connect() backendProactiveService = proactiveService - do { - focusAssistant = FocusAssistant( - backendService: proactiveService, - onAlert: { [weak self] message in - self?.sendEvent(type: "alert", data: ["message": message]) - }, - onStatusChange: { [weak self] status in - Task { @MainActor in - self?.lastStatus = status - self?.sendEvent(type: "statusChange", data: ["status": status.rawValue]) - } - }, - onRefocus: { - Task { @MainActor in - OverlayService.shared.showGlowAroundActiveWindow(colorMode: .focused) - } - }, - onDistraction: { - Task { @MainActor in - OverlayService.shared.showGlowAroundActiveWindow(colorMode: .distracted) - } + focusAssistant = FocusAssistant( + backendService: proactiveService, + onAlert: { [weak self] message in + self?.sendEvent(type: "alert", data: ["message": message]) + }, + onStatusChange: { [weak self] status in + Task { @MainActor in + self?.lastStatus = status + self?.sendEvent(type: "statusChange", data: ["status": status.rawValue]) + } + }, + onRefocus: { + Task { @MainActor in + OverlayService.shared.showGlowAroundActiveWindow(colorMode: .focused) + } + }, + onDistraction: { + Task { @MainActor in + OverlayService.shared.showGlowAroundActiveWindow(colorMode: .distracted) } - ) - - if let focus = focusAssistant { - AssistantCoordinator.shared.register(focus) } + ) - taskAssistant = 
try TaskAssistant() + if let focus = focusAssistant { + AssistantCoordinator.shared.register(focus) + } - if let task = taskAssistant { - AssistantCoordinator.shared.register(task) - } + taskAssistant = TaskAssistant(backendService: proactiveService) - Task { await TaskDeduplicationService.shared.start() } - Task { await TaskPrioritizationService.shared.start() } - Task { await TaskPromotionService.shared.start() } + if let task = taskAssistant { + AssistantCoordinator.shared.register(task) + } - adviceAssistant = try AdviceAssistant() + // Configure text-only services with backend service + Task { await TaskDeduplicationService.shared.configure(backendService: proactiveService) } + Task { await TaskPrioritizationService.shared.configure(backendService: proactiveService) } + Task { await AIUserProfileService.shared.configure(backendService: proactiveService) } - if let advice = adviceAssistant { - AssistantCoordinator.shared.register(advice) - } + Task { await TaskDeduplicationService.shared.start() } + Task { await TaskPrioritizationService.shared.start() } + Task { await TaskPromotionService.shared.start() } - memoryAssistant = try MemoryAssistant() + adviceAssistant = AdviceAssistant(backendService: proactiveService) - if let memory = memoryAssistant { - AssistantCoordinator.shared.register(memory) - } + if let advice = adviceAssistant { + AssistantCoordinator.shared.register(advice) + } - } catch { - log("ProactiveAssistantsPlugin: Failed to initialize assistants: \(error.localizedDescription)") - logError("ProactiveAssistantsPlugin: Assistant initialization failed", error: error) - isStartingMonitoring = false - completion(false, error.localizedDescription) - return + memoryAssistant = MemoryAssistant(backendService: proactiveService) + + if let memory = memoryAssistant { + AssistantCoordinator.shared.register(memory) } // Get initial app state diff --git a/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift 
b/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift index 925c73f312..a87b4582b5 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/UI/AdviceTestRunnerWindow.swift @@ -441,17 +441,9 @@ struct AdviceTestRunnerView: View { adviceAssistant = existing log("AdviceTestRunner: Using existing AdviceAssistant from coordinator") } else { - do { - adviceAssistant = try AdviceAssistant() - log("AdviceTestRunner: Created fresh AdviceAssistant instance") - } catch { - log("AdviceTestRunner: ERROR - Failed to create AdviceAssistant: \(error)") - await MainActor.run { - statusMessage = "Failed to create Advice Assistant: \(error.localizedDescription)" - isRunning = false - } - return - } + let service = BackendProactiveService(); service.connect() + adviceAssistant = AdviceAssistant(backendService: service) + log("AdviceTestRunner: Created fresh AdviceAssistant instance") } // Get excluded apps @@ -647,12 +639,8 @@ enum AdviceTestRunner { if let existing = coordAssistant as? AdviceAssistant { adviceAssistant = existing } else { - do { - adviceAssistant = try AdviceAssistant() - } catch { - log("AdviceTestCLI: ERROR — Failed to create AdviceAssistant: \(error)") - return - } + let service = BackendProactiveService(); service.connect() + adviceAssistant = AdviceAssistant(backendService: service) } // Get excluded apps From e8e5820e0162ebf8a1d56fd3d92ed4a76306abfe Mon Sep 17 00:00:00 2001 From: beastoin Date: Mon, 9 Mar 2026 05:20:34 +0100 Subject: [PATCH 148/163] Wire LiveNotesMonitor thin client for Phase 2 (#5396) Replace direct GeminiClient usage with BackendProactiveService. Uses configure(backendService:) singleton pattern matching other text-based services. Prompt logic moves server-side. 
Co-Authored-By: Claude Opus 4.6 --- .../Sources/LiveNotes/LiveNotesMonitor.swift | 67 ++++++------------- 1 file changed, 20 insertions(+), 47 deletions(-) diff --git a/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift b/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift index f859f973a4..fdcc7944a9 100644 --- a/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift +++ b/desktop/Desktop/Sources/LiveNotes/LiveNotesMonitor.swift @@ -45,23 +45,12 @@ class LiveNotesMonitor: ObservableObject { /// Existing notes for context (to avoid repetition) private var existingNotesContext: [String] = [] - /// GeminiClient for AI generation (lazily initialized) - private var geminiClient: GeminiClient? + /// Backend service for AI generation (injected via configure()) + private var backendService: BackendProactiveService? /// Cancellables for subscriptions private var cancellables = Set() - /// AI prompt for note generation (from m13v/meeting) - private let noteGenerationPrompt = """ - generate a single, concise note about what happened in this segment. - be factual and specific. - focus on the key point or action item. - keep it a few word sentence. - do not use quotes. - do not use wrapping words like "discussion on", jump straight into note. - avoid repeating information from existing notes. 
- """ - private init() { // Subscribe to transcript changes LiveTranscriptMonitor.shared.$segments @@ -72,6 +61,12 @@ class LiveNotesMonitor: ObservableObject { .store(in: &cancellables) } + /// Configure with backend service (call before startSession) + func configure(backendService: BackendProactiveService) { + self.backendService = backendService + log("LiveNotesMonitor: Configured with BackendProactiveService") + } + // MARK: - Session Lifecycle /// Start a new notes session @@ -85,15 +80,8 @@ class LiveNotesMonitor: ObservableObject { lastProcessedSegmentEnd = nil existingNotesContext = [] - // Initialize Gemini client if not already done - if geminiClient == nil { - do { - // Use Gemini 3 Pro for better note generation quality - geminiClient = try GeminiClient(model: "gemini-pro-latest") - log("LiveNotesMonitor: GeminiClient initialized with gemini-pro-latest") - } catch { - logError("LiveNotesMonitor: Failed to initialize GeminiClient", error: error) - } + if backendService == nil { + log("LiveNotesMonitor: WARNING — backendService not configured, AI notes disabled") } // Load any existing notes from DB (for crash recovery) @@ -252,10 +240,10 @@ class LiveNotesMonitor: ObservableObject { } } - /// Generate an AI note from recent transcript + /// Generate an AI note from recent transcript via backend private func generateNote(from segments: [SpeakerSegment]) { guard let sessionId = currentSessionId, - let client = geminiClient, + let service = backendService, !isGenerating else { return } isGenerating = true @@ -265,34 +253,21 @@ class LiveNotesMonitor: ObservableObject { let segmentStartOrder = max(0, currentSegmentOrder - 3) let segmentEndOrder = currentSegmentOrder - // Build context from existing notes - let existingNotesText = existingNotesContext.isEmpty - ? "No existing notes yet." + // Build session context from existing notes + let sessionContext = existingNotesContext.isEmpty + ? 
"" : "Existing notes:\n" + existingNotesContext.map { "- \($0)" }.joined(separator: "\n") - let prompt = """ - Transcript segment: - \(recentText) - - \(existingNotesText) - - \(noteGenerationPrompt) - """ - Task { do { - let response = try await client.sendTextRequest( - prompt: prompt, - systemPrompt: "You are a concise note-taker. Generate a single short note (3-10 words) about the key point in the transcript. Do not use quotes. Be direct and specific." - ) + let noteText = try await service.generateLiveNote(text: recentText, sessionContext: sessionContext) - // Clean up the response - let noteText = response + let cleaned = noteText .trimmingCharacters(in: .whitespacesAndNewlines) .replacingOccurrences(of: "\"", with: "") .replacingOccurrences(of: "'", with: "") - guard !noteText.isEmpty else { + guard !cleaned.isEmpty else { await MainActor.run { self.isGenerating = false } return } @@ -300,7 +275,7 @@ class LiveNotesMonitor: ObservableObject { // Save to DB let record = try await NoteStorage.shared.createNote( sessionId: sessionId, - text: noteText, + text: cleaned, isAiGenerated: true, segmentStartOrder: segmentStartOrder, segmentEndOrder: segmentEndOrder @@ -309,8 +284,7 @@ class LiveNotesMonitor: ObservableObject { if let note = record.toLiveNote() { await MainActor.run { self.notes.append(note) - self.existingNotesContext.append(noteText) - // Trim context to prevent unbounded growth (keep most recent notes) + self.existingNotesContext.append(cleaned) if self.existingNotesContext.count > self.maxExistingNotesContext { self.existingNotesContext.removeFirst(self.existingNotesContext.count - self.maxExistingNotesContext) } @@ -321,7 +295,6 @@ class LiveNotesMonitor: ObservableObject { await MainActor.run { self.isGenerating = false } } } catch let dbError as DatabaseError where dbError.resultCode == .SQLITE_CONSTRAINT { - // Session was deleted during async AI generation — not an error log("LiveNotesMonitor: Session \(sessionId) deleted during note 
generation, skipping") await MainActor.run { self.isGenerating = false } } catch { From 15bf1ec6a192df96f3a5dced873559363dfeea11 Mon Sep 17 00:00:00 2001 From: beastoin Date: Mon, 9 Mar 2026 05:20:35 +0100 Subject: [PATCH 149/163] Wire LiveNotesMonitor in ProactiveAssistantsPlugin (#5396) Add configure(backendService:) call for LiveNotesMonitor alongside other singleton text-based services. Co-Authored-By: Claude Opus 4.6 --- .../Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift index c721c1bd5a..0a5c31a0b4 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/ProactiveAssistantsPlugin.swift @@ -362,6 +362,7 @@ public class ProactiveAssistantsPlugin: NSObject { Task { await TaskDeduplicationService.shared.configure(backendService: proactiveService) } Task { await TaskPrioritizationService.shared.configure(backendService: proactiveService) } Task { await AIUserProfileService.shared.configure(backendService: proactiveService) } + Task { await LiveNotesMonitor.shared.configure(backendService: proactiveService) } Task { await TaskDeduplicationService.shared.start() } Task { await TaskPrioritizationService.shared.start() } From 61ee9c3f90ccc6cab8b5e42f9610d6d029948737 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 07:12:48 +0100 Subject: [PATCH 150/163] Swap dev plist to based-hardware-dev Firebase project Update GoogleService-Info-Dev.plist with dev Firebase values: API_KEY, PROJECT_ID, STORAGE_BUCKET, GCM_SENDER_ID, GOOGLE_APP_ID. 
Fixes #5536 Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/GoogleService-Info-Dev.plist | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/desktop/Desktop/Sources/GoogleService-Info-Dev.plist b/desktop/Desktop/Sources/GoogleService-Info-Dev.plist index 9602a49423..c761a117c6 100644 --- a/desktop/Desktop/Sources/GoogleService-Info-Dev.plist +++ b/desktop/Desktop/Sources/GoogleService-Info-Dev.plist @@ -9,17 +9,17 @@ ANDROID_CLIENT_ID 208440318997-1ek8tj5oa9ljmnh8tgehk27nqpivivbf.apps.googleusercontent.com API_KEY - AIzaSyD9dzBdglc7IO9pPDIOvqnCoTis_xKkkC8 + AIzaSyBK-G7KmEoC72mR10gmQyb2NFBbZyDvcqM GCM_SENDER_ID - 208440318997 + 1031333818730 PLIST_VERSION 1 BUNDLE_ID com.omi.desktop-dev PROJECT_ID - based-hardware + based-hardware-dev STORAGE_BUCKET - based-hardware.firebasestorage.app + based-hardware-dev.firebasestorage.app IS_ADS_ENABLED IS_ANALYTICS_ENABLED @@ -31,6 +31,6 @@ IS_SIGNIN_ENABLED GOOGLE_APP_ID - 1:208440318997:ios:a1906bb92fe244810e421c + 1:1031333818730:ios:3bea63d8e4f41dbfafb513 \ No newline at end of file From 7796471de60ee39c4ca0349901640016227909e3 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 07:12:54 +0100 Subject: [PATCH 151/163] Read Firebase API key from plist at runtime instead of hardcoding Dev builds load GoogleService-Info-Dev.plist (via run.sh), prod builds load GoogleService-Info.plist. AuthService now reads API_KEY from whichever plist is in the bundle, with prod key as fallback. 
Fixes #5536 Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/AuthService.swift | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/desktop/Desktop/Sources/AuthService.swift b/desktop/Desktop/Sources/AuthService.swift index d42e27c1fe..cd1d4eeecc 100644 --- a/desktop/Desktop/Sources/AuthService.swift +++ b/desktop/Desktop/Sources/AuthService.swift @@ -70,8 +70,15 @@ class AuthService { private let kAuthTokenExpiry = "auth_tokenExpiry" private let kAuthTokenUserId = "auth_tokenUserId" // User ID that owns the stored token - // Firebase Web API key (from GoogleService-Info.plist) - private let firebaseApiKey = "AIzaSyD9dzBdglc7IO9pPDIOvqnCoTis_xKkkC8" + // Firebase Web API key (read from active GoogleService-Info.plist at runtime) + private let firebaseApiKey: String = { + if let path = Bundle.main.path(forResource: "GoogleService-Info", ofType: "plist"), + let dict = NSDictionary(contentsOfFile: path), + let key = dict["API_KEY"] as? String { + return key + } + return "AIzaSyD9dzBdglc7IO9pPDIOvqnCoTis_xKkkC8" // fallback to prod + }() // MARK: - User Name Properties From 9e8c3a0b4c5ff61dffa84ae4472da9707b7991aa Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 07:21:32 +0100 Subject: [PATCH 152/163] Fix dev.sh to copy dev Firebase plist instead of prod dev.sh builds Omi Dev (com.omi.desktop-dev) but was copying the prod GoogleService-Info.plist. Now uses the same dev plist logic as run.sh. 
Fixes #5536 Co-Authored-By: Claude Opus 4.6 --- desktop/dev.sh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/desktop/dev.sh b/desktop/dev.sh index 3b3ec2692d..26f0a3b660 100755 --- a/desktop/dev.sh +++ b/desktop/dev.sh @@ -85,8 +85,12 @@ cp Desktop/Info.plist "$APP_BUNDLE/Contents/Info.plist" /usr/libexec/PlistBuddy -c "Set :CFBundleDisplayName $APP_NAME" "$APP_BUNDLE/Contents/Info.plist" /usr/libexec/PlistBuddy -c "Set :CFBundleURLTypes:0:CFBundleURLSchemes:0 omi-computer-dev" "$APP_BUNDLE/Contents/Info.plist" -# Copy GoogleService-Info.plist for Firebase -cp Desktop/Sources/GoogleService-Info.plist "$APP_BUNDLE/Contents/Resources/" +# Copy GoogleService-Info.plist for Firebase (dev version for com.omi.desktop-dev) +if [ -f "Desktop/Sources/GoogleService-Info-Dev.plist" ]; then + cp -f Desktop/Sources/GoogleService-Info-Dev.plist "$APP_BUNDLE/Contents/Resources/GoogleService-Info.plist" +else + cp -f Desktop/Sources/GoogleService-Info.plist "$APP_BUNDLE/Contents/Resources/" +fi # Copy resource bundle (contains app assets like herologo.png, omi-with-rope-no-padding.webp, etc.) SWIFT_BUILD_DIR="Desktop/.build/debug" From e3cfbf1608752de8460daa0bd5c0fc8c173030fc Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 07:21:38 +0100 Subject: [PATCH 153/163] Fix reset-and-run.sh to copy dev Firebase plist instead of prod reset-and-run.sh builds Omi Dev (com.omi.desktop-dev) but was copying the prod GoogleService-Info.plist. Now uses the same dev plist logic as run.sh. 
Fixes #5536 Co-Authored-By: Claude Opus 4.6 --- desktop/reset-and-run.sh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/desktop/reset-and-run.sh b/desktop/reset-and-run.sh index 1590b862fd..c52c28a10d 100755 --- a/desktop/reset-and-run.sh +++ b/desktop/reset-and-run.sh @@ -381,8 +381,12 @@ cp Desktop/Info.plist "$APP_BUNDLE/Contents/Info.plist" /usr/libexec/PlistBuddy -c "Set :CFBundleDisplayName $APP_NAME" "$APP_BUNDLE/Contents/Info.plist" /usr/libexec/PlistBuddy -c "Set :CFBundleURLTypes:0:CFBundleURLSchemes:0 omi-computer-dev" "$APP_BUNDLE/Contents/Info.plist" -# Copy GoogleService-Info.plist for Firebase -cp Desktop/Sources/GoogleService-Info.plist "$APP_BUNDLE/Contents/Resources/" +# Copy GoogleService-Info.plist for Firebase (dev version for com.omi.desktop-dev) +if [ -f "Desktop/Sources/GoogleService-Info-Dev.plist" ]; then + cp -f Desktop/Sources/GoogleService-Info-Dev.plist "$APP_BUNDLE/Contents/Resources/GoogleService-Info.plist" +else + cp -f Desktop/Sources/GoogleService-Info.plist "$APP_BUNDLE/Contents/Resources/" +fi # Copy .env.app (app runtime secrets only) and add API URL if [ -f ".env.app" ]; then From 00ad7c8c87533e079e76b3e0910592c6f0e62cb7 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 07:28:23 +0100 Subject: [PATCH 154/163] Log fatal warning when dev build falls back to prod Firebase key CODEx review: dev builds should not silently use prod credentials. Now logs a FATAL warning if GoogleService-Info.plist is missing or has no API_KEY in a dev build (bundle ID ending in -dev). 
Fixes #5536 Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/AuthService.swift | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/desktop/Desktop/Sources/AuthService.swift b/desktop/Desktop/Sources/AuthService.swift index cd1d4eeecc..bab0433afc 100644 --- a/desktop/Desktop/Sources/AuthService.swift +++ b/desktop/Desktop/Sources/AuthService.swift @@ -77,7 +77,12 @@ class AuthService { let key = dict["API_KEY"] as? String { return key } - return "AIzaSyD9dzBdglc7IO9pPDIOvqnCoTis_xKkkC8" // fallback to prod + // Dev builds must not silently fall back to prod credentials + let isDev = Bundle.main.bundleIdentifier?.hasSuffix("-dev") == true + if isDev { + log("AuthService: FATAL — GoogleService-Info.plist missing or has no API_KEY in dev build") + } + return "AIzaSyD9dzBdglc7IO9pPDIOvqnCoTis_xKkkC8" // fallback to prod (prod builds only) }() // MARK: - User Name Properties From 6d8b57e8ed753abebd6a350b4a01eaa9f9f65dd6 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 07:33:20 +0100 Subject: [PATCH 155/163] Crash dev builds when Firebase plist is missing instead of falling back to prod CODEx review round 2: logging is not fail-fast. Dev builds now crash with fatalError if GoogleService-Info.plist has no API_KEY, preventing silent use of prod credentials. Prod builds still fall back safely. 
Fixes #5536 Co-Authored-By: Claude Opus 4.6 --- desktop/Desktop/Sources/AuthService.swift | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/desktop/Desktop/Sources/AuthService.swift b/desktop/Desktop/Sources/AuthService.swift index bab0433afc..9838fa057e 100644 --- a/desktop/Desktop/Sources/AuthService.swift +++ b/desktop/Desktop/Sources/AuthService.swift @@ -78,9 +78,8 @@ class AuthService { return key } // Dev builds must not silently fall back to prod credentials - let isDev = Bundle.main.bundleIdentifier?.hasSuffix("-dev") == true - if isDev { - log("AuthService: FATAL — GoogleService-Info.plist missing or has no API_KEY in dev build") + if Bundle.main.bundleIdentifier?.hasSuffix("-dev") == true { + fatalError("AuthService: GoogleService-Info.plist missing or has no API_KEY in dev build — check that run.sh copied GoogleService-Info-Dev.plist") } return "AIzaSyD9dzBdglc7IO9pPDIOvqnCoTis_xKkkC8" // fallback to prod (prod builds only) }() From 4ec7217c759889c8351e12d2257bc58b8caf0850 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 08:25:55 +0100 Subject: [PATCH 156/163] Add GCE VM creation logic for new users (port from Rust agent.rs) --- backend/routers/agent_tools.py | 210 +++++++++++++++++++++++++++++---- 1 file changed, 190 insertions(+), 20 deletions(-) diff --git a/backend/routers/agent_tools.py b/backend/routers/agent_tools.py index 0fcc3b7f55..31b7a5a945 100644 --- a/backend/routers/agent_tools.py +++ b/backend/routers/agent_tools.py @@ -4,13 +4,15 @@ Endpoints: - GET /v1/agent/tools — returns tool definitions (name, description, parameters) - POST /v1/agent/execute-tool — executes a named tool and returns the result -- GET /v1/agent/vm-status — returns basic VM status from Firestore -- POST /v1/agent/vm-ensure — checks VM status, restarts if stopped, returns current state +- GET /v1/agent/vm-status — returns VM status from Firestore (with restart if stopped) +- POST /v1/agent/vm-ensure — ensures user has a VM: creates 
if missing, restarts if stopped - POST /v1/agent/keepalive — pings the VM to reset its idle auto-stop timer """ import asyncio import logging +import os +import uuid from datetime import datetime, timezone import google.auth @@ -29,7 +31,10 @@ router = APIRouter() -GCE_PROJECT = "based-hardware" +GCE_PROJECT = os.environ.get("GCE_PROJECT_ID", os.environ.get("GOOGLE_CLOUD_PROJECT", "based-hardware")) +GCE_ZONE = "us-central1-a" +GCE_SOURCE_IMAGE = os.environ.get("GCE_SOURCE_IMAGE", f"projects/{GCE_PROJECT}/global/images/family/omi-agent") +AGENT_GCS_BUCKET = os.environ.get("AGENT_GCS_BUCKET", "based-hardware-agent") # --------------- GCE helpers --------------- @@ -128,6 +133,118 @@ def _update_firestore_vm(uid: str, ip: str | None, status: str): firestore_db.collection('users').document(uid).update(update) +def _set_firestore_vm(uid: str, vm_name: str, zone: str, ip: str | None, status: str, auth_token: str): + """Write the full agentVm document to Firestore (for initial provisioning).""" + from database.users import db as firestore_db + + now = datetime.now(timezone.utc).isoformat() + vm_data = { + "vmName": vm_name, + "zone": zone, + "status": status, + "authToken": auth_token, + "createdAt": now, + } + if ip: + vm_data["ip"] = ip + firestore_db.collection('users').document(uid).set({"agentVm": vm_data}, merge=True) + + +async def _create_gce_vm(vm_name: str, auth_token: str) -> str: + """Create a GCE VM from the omi-agent image family. 
Returns the external IP.""" + zone = GCE_ZONE + startup_script = ( + f"#!/bin/bash\ncurl -sf https://storage.googleapis.com/{AGENT_GCS_BUCKET}/startup.sh" + f" -o /tmp/omi-startup.sh && bash /tmp/omi-startup.sh\n" + ) + + url = f"https://compute.googleapis.com/compute/v1/projects/{GCE_PROJECT}/zones/{zone}/instances" + body = { + "name": vm_name, + "machineType": f"zones/{zone}/machineTypes/e2-small", + "disks": [ + { + "boot": True, + "autoDelete": True, + "initializeParams": { + "sourceImage": GCE_SOURCE_IMAGE, + "diskSizeGb": "50", + "diskType": f"zones/{zone}/diskTypes/pd-ssd", + }, + } + ], + "networkInterfaces": [ + { + "network": "global/networks/default", + "accessConfigs": [{"type": "ONE_TO_ONE_NAT", "name": "External NAT"}], + } + ], + "tags": {"items": ["omi-agent-vm"]}, + "metadata": { + "items": [ + {"key": "startup-script", "value": startup_script}, + {"key": "auth-token", "value": auth_token}, + ] + }, + } + + token = _get_gce_access_token() + async with httpx.AsyncClient(timeout=180) as client: + resp = await client.post(url, headers={"Authorization": f"Bearer {token}"}, json=body) + if resp.status_code not in (200, 204): + raise Exception(f"GCE insert failed: {resp.status_code} {sanitize(resp.text)}") + + op_name = resp.json().get("name") + if not op_name: + raise Exception("Missing operation name in GCE insert response") + + # Poll operation until done (max ~2 minutes) + op_url = f"https://compute.googleapis.com/compute/v1/projects/{GCE_PROJECT}/zones/{zone}/operations/{op_name}" + for i in range(24): + await asyncio.sleep(5) + token = _get_gce_access_token() + status_resp = await client.get(op_url, headers={"Authorization": f"Bearer {token}"}) + op_status = status_resp.json() + if op_status.get("status") == "DONE": + if "error" in op_status: + raise Exception(f"GCE insert operation failed: {op_status['error']}") + logger.info(f"[vm-create] {vm_name} operation done after {i + 1} polls") + break + + # Get external IP + instance_url = ( + 
f"https://compute.googleapis.com/compute/v1/projects/{GCE_PROJECT}/zones/{zone}/instances/{vm_name}" + ) + ip = None + for attempt in range(6): + token = _get_gce_access_token() + inst_resp = await client.get(instance_url, headers={"Authorization": f"Bearer {token}"}) + instance = inst_resp.json() + try: + candidate = instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"] + if candidate and candidate != "unknown": + ip = candidate + logger.info(f"[vm-create] {vm_name} got IP {ip} on attempt {attempt + 1}") + break + except (KeyError, IndexError): + pass + if attempt < 5: + await asyncio.sleep(3) + + return ip or "unknown" + + +async def _provision_vm_background(uid: str, vm_name: str, auth_token: str): + """Background task: create a new GCE VM, update Firestore when ready.""" + try: + ip = await _create_gce_vm(vm_name, auth_token) + _set_firestore_vm(uid, vm_name, GCE_ZONE, ip, "ready", auth_token) + logger.info(f"[vm-ensure] VM {vm_name} created, ip={ip}") + except Exception as e: + logger.error(f"[vm-ensure] Failed to create VM {vm_name}: {e}") + _update_firestore_vm(uid, None, "error") + + async def _restart_vm_background(uid: str, vm_name: str, zone: str): """Background task: start stopped VM, update Firestore with new IP when ready.""" try: @@ -142,53 +259,106 @@ async def _restart_vm_background(uid: str, vm_name: str, zone: str): # --------------- endpoints --------------- -@router.get("/v1/agent/vm-status") -def get_vm_status(uid: str = Depends(get_current_user_uid)): - """Return the user's agent VM info from Firestore.""" - vm = get_agent_vm(uid) - logger.info(f"[vm-status] uid={uid} vm={sanitize(vm)}") - if not vm or vm.get("status") != "ready": - return {"has_vm": False} +def _vm_response(vm: dict, status_override: str | None = None) -> dict: + """Build a standard VM response dict with all fields the desktop expects.""" return { "has_vm": True, - "status": vm.get("status"), + "status": status_override or vm.get("status"), + "vm_name": 
vm.get("vmName"), + "ip": vm.get("ip"), + "auth_token": vm.get("authToken"), + "zone": vm.get("zone", GCE_ZONE), + "created_at": vm.get("createdAt"), + "last_query_at": vm.get("lastQueryAt"), } +@router.get("/v1/agent/vm-status") +async def get_vm_status(background_tasks: BackgroundTasks, uid: str = Depends(get_current_user_uid)): + """Return the user's agent VM info from Firestore. Restarts stopped VMs.""" + vm = get_agent_vm(uid) + if not vm: + return {"has_vm": False} + + fs_status = vm.get("status", "") + vm_name = vm.get("vmName") + zone = vm.get("zone", GCE_ZONE) + + # For ready/error/stopped VMs, verify actual GCE status and restart if needed + if fs_status in ("ready", "error", "stopped") and vm_name: + try: + gce_status = await _check_gce_status(vm_name, zone) + except Exception as e: + logger.warning(f"[vm-status] GCE status check failed for {vm_name}: {e}") + return _vm_response(vm) + + if gce_status in ("TERMINATED", "STOPPED"): + logger.info(f"[vm-status] VM {vm_name} is {gce_status}, restarting...") + _update_firestore_vm(uid, None, "provisioning") + background_tasks.add_task(_restart_vm_background, uid, vm_name, zone) + return _vm_response(vm, status_override="provisioning") + + if gce_status == "RUNNING" and fs_status != "ready": + _update_firestore_vm(uid, vm.get("ip"), "ready") + return _vm_response(vm, status_override="ready") + + return _vm_response(vm) + + @router.post("/v1/agent/vm-ensure") async def ensure_vm(background_tasks: BackgroundTasks, uid: str = Depends(get_current_user_uid)): - """Check VM status; if stopped/terminated, restart it in the background.""" + """Ensure user has a VM: create if missing, restart if stopped.""" vm = get_agent_vm(uid) + + # No VM exists — provision a new one if not vm: - return {"has_vm": False} + uid_prefix = uid[:12].lower() if len(uid) > 12 else uid.lower() + vm_name = f"omi-agent-{uid_prefix}" + auth_token = f"omi-{uuid.uuid4()}" + + # Claim the slot in Firestore before spawning background creation + 
_set_firestore_vm(uid, vm_name, GCE_ZONE, None, "provisioning", auth_token) + background_tasks.add_task(_provision_vm_background, uid, vm_name, auth_token) + logger.info(f"[vm-ensure] Provisioning new VM {vm_name} for uid={uid[:8]}...") + + return { + "has_vm": True, + "status": "provisioning", + "vm_name": vm_name, + "ip": None, + "auth_token": auth_token, + "zone": GCE_ZONE, + "created_at": datetime.now(timezone.utc).isoformat(), + "last_query_at": None, + } vm_name = vm.get("vmName") - zone = vm.get("zone", "us-central1-a") + zone = vm.get("zone", GCE_ZONE) fs_status = vm.get("status", "") # If Firestore already says provisioning, don't double-start if fs_status == "provisioning": - return {"has_vm": True, "status": "provisioning"} + return _vm_response(vm) # Check actual GCE status for ready/error/stopped VMs - if fs_status in ("ready", "error", "stopped"): + if fs_status in ("ready", "error", "stopped") and vm_name: try: gce_status = await _check_gce_status(vm_name, zone) except Exception as e: logger.error(f"[vm-ensure] GCE status check failed: {e}") - return {"has_vm": True, "status": fs_status} + return _vm_response(vm) if gce_status in ("TERMINATED", "STOPPED"): logger.info(f"[vm-ensure] VM {vm_name} is {gce_status}, restarting...") _update_firestore_vm(uid, None, "provisioning") background_tasks.add_task(_restart_vm_background, uid, vm_name, zone) - return {"has_vm": True, "status": "provisioning"} + return _vm_response(vm, status_override="provisioning") if gce_status == "RUNNING" and fs_status != "ready": _update_firestore_vm(uid, vm.get("ip"), "ready") - return {"has_vm": True, "status": "ready"} + return _vm_response(vm, status_override="ready") - return {"has_vm": True, "status": fs_status} + return _vm_response(vm) @router.post("/v1/agent/keepalive") From 4587490b80e9d7bbb4ba03e08b2b4efad5cfd1cb Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 08:25:56 +0100 Subject: [PATCH 157/163] Add unit tests for /v1/agent/vm-ensure and 
/v1/agent/vm-status (13 tests) --- backend/tests/unit/test_agent_vm.py | 244 ++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 backend/tests/unit/test_agent_vm.py diff --git a/backend/tests/unit/test_agent_vm.py b/backend/tests/unit/test_agent_vm.py new file mode 100644 index 0000000000..110de2b786 --- /dev/null +++ b/backend/tests/unit/test_agent_vm.py @@ -0,0 +1,244 @@ +"""Tests for agent VM endpoints — vm-ensure and vm-status. + +Verifies: +- vm-ensure creates new VMs for users with no existing VM +- vm-ensure restarts stopped/terminated VMs +- vm-ensure is idempotent (doesn't double-provision) +- vm-status returns full VM fields (vm_name, ip, auth_token, zone, created_at) +- vm-status triggers restart for stopped VMs (Rust parity) +- Response JSON matches desktop Swift AgentProvisionResponse/AgentStatusResponse +""" + +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + + +# Stub heavy imports before importing the router +sys.modules.setdefault('database._client', MagicMock()) +sys.modules.setdefault('database.users', MagicMock()) +sys.modules.setdefault('utils.retrieval.agentic', MagicMock(agent_config_context=MagicMock(), CORE_TOOLS=[])) +sys.modules.setdefault('utils.retrieval.tools.app_tools', MagicMock()) + +from routers.agent_tools import router, _vm_response, GCE_ZONE + +app = FastAPI() +app.include_router(router) + +TEST_UID = "testuser1234abcd" + + +def _mock_auth(): + """Override auth dependency to return a test UID.""" + from utils.other.endpoints import get_current_user_uid + + app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID + + +_mock_auth() +client = TestClient(app) + + +SAMPLE_VM = { + "vmName": "omi-agent-testuser1234", + "zone": "us-central1-a", + "ip": "35.192.1.1", + "status": "ready", + "authToken": 
"omi-abc123", + "createdAt": "2026-03-10T00:00:00+00:00", + "lastQueryAt": "2026-03-10T01:00:00+00:00", +} + + +# --------------- vm-status tests --------------- + + +@patch("routers.agent_tools.get_agent_vm", return_value=None) +def test_vm_status_no_vm(mock_get): + """vm-status returns has_vm=False when user has no VM.""" + resp = client.get("/v1/agent/vm-status") + assert resp.status_code == 200 + data = resp.json() + assert data["has_vm"] is False + + +@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="RUNNING") +@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM) +def test_vm_status_returns_full_fields(mock_get, mock_gce): + """vm-status returns all fields desktop needs: vm_name, ip, auth_token, zone, created_at.""" + resp = client.get("/v1/agent/vm-status") + assert resp.status_code == 200 + data = resp.json() + assert data["has_vm"] is True + assert data["vm_name"] == "omi-agent-testuser1234" + assert data["ip"] == "35.192.1.1" + assert data["auth_token"] == "omi-abc123" + assert data["zone"] == "us-central1-a" + assert data["created_at"] == "2026-03-10T00:00:00+00:00" + assert data["last_query_at"] == "2026-03-10T01:00:00+00:00" + assert data["status"] == "ready" + + +@patch("routers.agent_tools._restart_vm_background") +@patch("routers.agent_tools._update_firestore_vm") +@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="TERMINATED") +@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM) +def test_vm_status_restarts_stopped_vm(mock_get, mock_gce, mock_update, mock_restart): + """vm-status triggers restart when GCE status is TERMINATED (Rust parity).""" + resp = client.get("/v1/agent/vm-status") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "provisioning" + mock_update.assert_called_once_with(TEST_UID, None, "provisioning") + + +@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, side_effect=Exception("GCE 
unreachable")) +@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM) +def test_vm_status_gce_failure_returns_firestore_data(mock_get, mock_gce): + """vm-status returns Firestore data when GCE check fails.""" + resp = client.get("/v1/agent/vm-status") + assert resp.status_code == 200 + data = resp.json() + assert data["has_vm"] is True + assert data["status"] == "ready" + assert data["vm_name"] == "omi-agent-testuser1234" + + +# --------------- vm-ensure tests --------------- + + +@patch("routers.agent_tools._provision_vm_background") +@patch("routers.agent_tools._set_firestore_vm") +@patch("routers.agent_tools.get_agent_vm", return_value=None) +def test_vm_ensure_creates_new_vm(mock_get, mock_set_fs, mock_provision): + """vm-ensure creates a new VM when no Firestore record exists.""" + resp = client.post("/v1/agent/vm-ensure") + assert resp.status_code == 200 + data = resp.json() + assert data["has_vm"] is True + assert data["status"] == "provisioning" + assert data["vm_name"] == "omi-agent-testuser1234" + assert data["auth_token"].startswith("omi-") + assert data["zone"] == "us-central1-a" + assert data["ip"] is None + + # Verify Firestore was written + mock_set_fs.assert_called_once() + call_args = mock_set_fs.call_args + assert call_args[0][0] == TEST_UID + assert call_args[0][1] == "omi-agent-testuser1234" + assert call_args[0][4] == "provisioning" + + +@patch( + "routers.agent_tools.get_agent_vm", + return_value={"vmName": "omi-agent-testuser1234", "status": "provisioning", "authToken": "omi-xyz"}, +) +def test_vm_ensure_idempotent_provisioning(mock_get): + """vm-ensure doesn't double-provision when already provisioning.""" + resp = client.post("/v1/agent/vm-ensure") + assert resp.status_code == 200 + data = resp.json() + assert data["has_vm"] is True + assert data["status"] == "provisioning" + + +@patch("routers.agent_tools._restart_vm_background") +@patch("routers.agent_tools._update_firestore_vm") 
+@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="STOPPED") +@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM) +def test_vm_ensure_restarts_stopped_vm(mock_get, mock_gce, mock_update, mock_restart): + """vm-ensure restarts a stopped VM.""" + resp = client.post("/v1/agent/vm-ensure") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "provisioning" + mock_update.assert_called_once_with(TEST_UID, None, "provisioning") + + +@patch("routers.agent_tools._update_firestore_vm") +@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="RUNNING") +@patch( + "routers.agent_tools.get_agent_vm", + return_value={**SAMPLE_VM, "status": "error"}, +) +def test_vm_ensure_recovers_running_but_error_status(mock_get, mock_gce, mock_update): + """vm-ensure recovers when GCE is RUNNING but Firestore says error.""" + resp = client.post("/v1/agent/vm-ensure") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ready" + mock_update.assert_called_once_with(TEST_UID, "35.192.1.1", "ready") + + +@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="RUNNING") +@patch("routers.agent_tools.get_agent_vm", return_value=SAMPLE_VM) +def test_vm_ensure_returns_full_fields_for_ready_vm(mock_get, mock_gce): + """vm-ensure returns full VM fields when VM is ready.""" + resp = client.post("/v1/agent/vm-ensure") + assert resp.status_code == 200 + data = resp.json() + assert data["has_vm"] is True + assert data["vm_name"] == "omi-agent-testuser1234" + assert data["ip"] == "35.192.1.1" + assert data["auth_token"] == "omi-abc123" + + +# --------------- _vm_response tests --------------- + + +def test_vm_response_maps_firestore_fields(): + """_vm_response correctly maps Firestore camelCase to snake_case.""" + result = _vm_response(SAMPLE_VM) + assert result["vm_name"] == "omi-agent-testuser1234" + assert result["auth_token"] == 
"omi-abc123" + assert result["created_at"] == "2026-03-10T00:00:00+00:00" + assert result["last_query_at"] == "2026-03-10T01:00:00+00:00" + + +def test_vm_response_status_override(): + """_vm_response applies status_override correctly.""" + result = _vm_response(SAMPLE_VM, status_override="provisioning") + assert result["status"] == "provisioning" + assert result["vm_name"] == "omi-agent-testuser1234" + + +# --------------- vm_name generation tests --------------- + + +@patch("routers.agent_tools._provision_vm_background") +@patch("routers.agent_tools._set_firestore_vm") +@patch("routers.agent_tools.get_agent_vm", return_value=None) +def test_vm_name_truncates_long_uid(mock_get, mock_set_fs, mock_provision): + """VM name uses first 12 chars of UID, lowercased.""" + from utils.other.endpoints import get_current_user_uid + + app.dependency_overrides[get_current_user_uid] = lambda: "ABCDEFghijklmnopqrstuvwxyz" + try: + resp = client.post("/v1/agent/vm-ensure") + data = resp.json() + assert data["vm_name"] == "omi-agent-abcdefghijkl" + finally: + app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID + + +@patch("routers.agent_tools._provision_vm_background") +@patch("routers.agent_tools._set_firestore_vm") +@patch("routers.agent_tools.get_agent_vm", return_value=None) +def test_vm_name_short_uid(mock_get, mock_set_fs, mock_provision): + """Short UIDs use the full UID in VM name.""" + from utils.other.endpoints import get_current_user_uid + + app.dependency_overrides[get_current_user_uid] = lambda: "ShortUid" + try: + resp = client.post("/v1/agent/vm-ensure") + data = resp.json() + assert data["vm_name"] == "omi-agent-shortuid" + finally: + app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID From 23d88ee7c1021f31e1529d8909316cf104933588 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 08:25:57 +0100 Subject: [PATCH 158/163] Add test_agent_vm.py to test.sh --- backend/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git 
a/backend/test.sh b/backend/test.sh index ae4a3dbc41..6aad400d60 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -53,3 +53,4 @@ pytest tests/unit/test_advice.py -v pytest tests/unit/test_staged_tasks.py -v pytest tests/unit/test_chat_generate_title.py -v pytest tests/unit/test_conversations_count.py -v +pytest tests/unit/test_agent_vm.py -v From 4b8704e95cb62f91c5980961589d7bd37f8136a4 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 08:30:56 +0100 Subject: [PATCH 159/163] Fix reviewer issues: move imports to top-level, fail-fast on GCE timeout/IP --- backend/routers/agent_tools.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/backend/routers/agent_tools.py b/backend/routers/agent_tools.py index 31b7a5a945..9b1b6c2ddd 100644 --- a/backend/routers/agent_tools.py +++ b/backend/routers/agent_tools.py @@ -21,6 +21,7 @@ from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from pydantic import BaseModel +from database._client import db as firestore_db from database.users import get_agent_vm from utils.other.endpoints import get_current_user_uid from utils.retrieval.agentic import agent_config_context, CORE_TOOLS @@ -125,8 +126,6 @@ async def _start_vm_and_wait(vm_name: str, zone: str) -> str: def _update_firestore_vm(uid: str, ip: str | None, status: str): """Update the user's agentVm fields in Firestore.""" - from database.users import db as firestore_db - update = {"agentVm.status": status} if ip: update["agentVm.ip"] = ip @@ -135,8 +134,6 @@ def _update_firestore_vm(uid: str, ip: str | None, status: str): def _set_firestore_vm(uid: str, vm_name: str, zone: str, ip: str | None, status: str, auth_token: str): """Write the full agentVm document to Firestore (for initial provisioning).""" - from database.users import db as firestore_db - now = datetime.now(timezone.utc).isoformat() vm_data = { "vmName": vm_name, @@ -200,6 +197,7 @@ async def _create_gce_vm(vm_name: str, auth_token: str) -> str: 
# Poll operation until done (max ~2 minutes) op_url = f"https://compute.googleapis.com/compute/v1/projects/{GCE_PROJECT}/zones/{zone}/operations/{op_name}" + op_done = False for i in range(24): await asyncio.sleep(5) token = _get_gce_access_token() @@ -209,29 +207,30 @@ async def _create_gce_vm(vm_name: str, auth_token: str) -> str: if "error" in op_status: raise Exception(f"GCE insert operation failed: {op_status['error']}") logger.info(f"[vm-create] {vm_name} operation done after {i + 1} polls") + op_done = True break + if not op_done: + raise Exception(f"GCE insert timed out after 120s for {vm_name}") # Get external IP instance_url = ( f"https://compute.googleapis.com/compute/v1/projects/{GCE_PROJECT}/zones/{zone}/instances/{vm_name}" ) - ip = None for attempt in range(6): token = _get_gce_access_token() inst_resp = await client.get(instance_url, headers={"Authorization": f"Bearer {token}"}) instance = inst_resp.json() try: candidate = instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"] - if candidate and candidate != "unknown": - ip = candidate - logger.info(f"[vm-create] {vm_name} got IP {ip} on attempt {attempt + 1}") - break + if candidate: + logger.info(f"[vm-create] {vm_name} got IP {candidate} on attempt {attempt + 1}") + return candidate except (KeyError, IndexError): pass if attempt < 5: await asyncio.sleep(3) - return ip or "unknown" + raise Exception(f"Failed to get external IP for {vm_name} after 6 attempts") async def _provision_vm_background(uid: str, vm_name: str, auth_token: str): From f1a912daf911dbc5db015abeefbf29703303e363 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 08:30:57 +0100 Subject: [PATCH 160/163] Fix test in-function imports per CLAUDE.md style rule --- backend/tests/unit/test_agent_vm.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/backend/tests/unit/test_agent_vm.py b/backend/tests/unit/test_agent_vm.py index 110de2b786..0743638f00 100644 --- 
a/backend/tests/unit/test_agent_vm.py +++ b/backend/tests/unit/test_agent_vm.py @@ -27,21 +27,14 @@ sys.modules.setdefault('utils.retrieval.tools.app_tools', MagicMock()) from routers.agent_tools import router, _vm_response, GCE_ZONE +from utils.other.endpoints import get_current_user_uid app = FastAPI() app.include_router(router) TEST_UID = "testuser1234abcd" - -def _mock_auth(): - """Override auth dependency to return a test UID.""" - from utils.other.endpoints import get_current_user_uid - - app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID - - -_mock_auth() +app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID client = TestClient(app) @@ -217,8 +210,6 @@ def test_vm_response_status_override(): @patch("routers.agent_tools.get_agent_vm", return_value=None) def test_vm_name_truncates_long_uid(mock_get, mock_set_fs, mock_provision): """VM name uses first 12 chars of UID, lowercased.""" - from utils.other.endpoints import get_current_user_uid - app.dependency_overrides[get_current_user_uid] = lambda: "ABCDEFghijklmnopqrstuvwxyz" try: resp = client.post("/v1/agent/vm-ensure") @@ -233,8 +224,6 @@ def test_vm_name_truncates_long_uid(mock_get, mock_set_fs, mock_provision): @patch("routers.agent_tools.get_agent_vm", return_value=None) def test_vm_name_short_uid(mock_get, mock_set_fs, mock_provision): """Short UIDs use the full UID in VM name.""" - from utils.other.endpoints import get_current_user_uid - app.dependency_overrides[get_current_user_uid] = lambda: "ShortUid" try: resp = client.post("/v1/agent/vm-ensure") From c9640aa6f9295ae3eb2e98d8202403bf5bd9d30c Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 08:33:02 +0100 Subject: [PATCH 161/163] Move time import to module level per CLAUDE.md style rule --- backend/routers/agent_tools.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/routers/agent_tools.py b/backend/routers/agent_tools.py index 9b1b6c2ddd..805e98c4f2 100644 --- 
a/backend/routers/agent_tools.py +++ b/backend/routers/agent_tools.py @@ -12,6 +12,7 @@ import asyncio import logging import os +import time import uuid from datetime import datetime, timezone @@ -62,8 +63,6 @@ async def _check_gce_status(vm_name: str, zone: str) -> str: async def _start_vm_and_wait(vm_name: str, zone: str) -> str: """Start a stopped/terminated GCE VM and wait for it to get an IP. Returns the new IP.""" - import time - t0 = time.monotonic() token = _get_gce_access_token() start_url = ( From 03ddcd8cd503b7e3055dcc9708d193740b36cc31 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 08:37:23 +0100 Subject: [PATCH 162/163] Add boundary, background-error, and incomplete-payload tests (8 new, 21 total) --- backend/tests/unit/test_agent_vm.py | 94 ++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 1 deletion(-) diff --git a/backend/tests/unit/test_agent_vm.py b/backend/tests/unit/test_agent_vm.py index 0743638f00..e71d4420a7 100644 --- a/backend/tests/unit/test_agent_vm.py +++ b/backend/tests/unit/test_agent_vm.py @@ -26,7 +26,7 @@ sys.modules.setdefault('utils.retrieval.agentic', MagicMock(agent_config_context=MagicMock(), CORE_TOOLS=[])) sys.modules.setdefault('utils.retrieval.tools.app_tools', MagicMock()) -from routers.agent_tools import router, _vm_response, GCE_ZONE +from routers.agent_tools import router, _vm_response, _provision_vm_background, _restart_vm_background, GCE_ZONE from utils.other.endpoints import get_current_user_uid app = FastAPI() @@ -231,3 +231,95 @@ def test_vm_name_short_uid(mock_get, mock_set_fs, mock_provision): assert data["vm_name"] == "omi-agent-shortuid" finally: app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID + + +# --------------- UID boundary tests --------------- + + +@patch("routers.agent_tools._provision_vm_background") +@patch("routers.agent_tools._set_firestore_vm") +@patch("routers.agent_tools.get_agent_vm", return_value=None) +def test_vm_name_whitespace_uid(mock_get, 
mock_set_fs, mock_provision): + """UID with whitespace is lowercased and truncated normally.""" + app.dependency_overrides[get_current_user_uid] = lambda: "User With Spaces" + try: + resp = client.post("/v1/agent/vm-ensure") + data = resp.json() + assert data["vm_name"] == "omi-agent-user with sp" + finally: + app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID + + +@patch("routers.agent_tools._provision_vm_background") +@patch("routers.agent_tools._set_firestore_vm") +@patch("routers.agent_tools.get_agent_vm", return_value=None) +def test_vm_name_empty_string_uid(mock_get, mock_set_fs, mock_provision): + """Empty-string UID produces omi-agent- prefix with empty suffix.""" + app.dependency_overrides[get_current_user_uid] = lambda: "" + try: + resp = client.post("/v1/agent/vm-ensure") + data = resp.json() + assert data["vm_name"] == "omi-agent-" + finally: + app.dependency_overrides[get_current_user_uid] = lambda: TEST_UID + + +# --------------- background task error handling tests --------------- + + +@pytest.mark.asyncio +@patch("routers.agent_tools._update_firestore_vm") +@patch("routers.agent_tools._create_gce_vm", new_callable=AsyncMock, side_effect=Exception("GCE insert timed out")) +async def test_provision_vm_background_sets_error_on_failure(mock_create, mock_update): + """_provision_vm_background sets Firestore status to 'error' when GCE creation fails.""" + await _provision_vm_background("uid123", "omi-agent-uid123", "omi-token") + mock_update.assert_called_once_with("uid123", None, "error") + + +@pytest.mark.asyncio +@patch("routers.agent_tools._update_firestore_vm") +@patch("routers.agent_tools._start_vm_and_wait", new_callable=AsyncMock, side_effect=Exception("GCE start timed out")) +async def test_restart_vm_background_sets_error_on_failure(mock_start, mock_update): + """_restart_vm_background sets Firestore status to 'error' when restart fails.""" + await _restart_vm_background("uid123", "omi-agent-uid123", "us-central1-a") + 
mock_update.assert_called_once_with("uid123", None, "error") + + +@pytest.mark.asyncio +@patch("routers.agent_tools._set_firestore_vm") +@patch("routers.agent_tools._create_gce_vm", new_callable=AsyncMock, return_value="10.0.0.1") +async def test_provision_vm_background_sets_ready_on_success(mock_create, mock_set_fs): + """_provision_vm_background writes 'ready' status with IP on success.""" + await _provision_vm_background("uid123", "omi-agent-uid123", "omi-token") + mock_set_fs.assert_called_once_with("uid123", "omi-agent-uid123", GCE_ZONE, "10.0.0.1", "ready", "omi-token") + + +# --------------- incomplete Firestore payload tests --------------- + + +@patch("routers.agent_tools.get_agent_vm", return_value={"status": "ready"}) +def test_vm_status_handles_missing_vm_name(mock_get): + """vm-status does not crash when vmName is missing from Firestore.""" + resp = client.get("/v1/agent/vm-status") + assert resp.status_code == 200 + data = resp.json() + assert data["has_vm"] is True + assert data["vm_name"] is None + + +@patch("routers.agent_tools.get_agent_vm", return_value={"vmName": "omi-agent-x", "status": "ready"}) +def test_vm_status_handles_missing_ip_and_auth(mock_get): + """vm-status returns None for ip and auth_token when missing from Firestore.""" + resp = client.get("/v1/agent/vm-status") + assert resp.status_code == 200 + data = resp.json() + assert data["ip"] is None + assert data["auth_token"] is None + assert data["vm_name"] == "omi-agent-x" + + +@patch("routers.agent_tools.get_agent_vm", return_value={}) +def test_vm_ensure_handles_empty_firestore_vm(mock_get): + """vm-ensure with empty Firestore dict (no status field) doesn't crash.""" + resp = client.post("/v1/agent/vm-ensure") + assert resp.status_code == 200 From 40ae983afc484888eb16db299a0ead5b20d32d57 Mon Sep 17 00:00:00 2001 From: beastoin Date: Tue, 10 Mar 2026 08:40:52 +0100 Subject: [PATCH 163/163] Fix test isolation: mock GCE status in incomplete-payload tests --- 
backend/tests/unit/test_agent_vm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/tests/unit/test_agent_vm.py b/backend/tests/unit/test_agent_vm.py index e71d4420a7..7f5bcd4303 100644 --- a/backend/tests/unit/test_agent_vm.py +++ b/backend/tests/unit/test_agent_vm.py @@ -299,7 +299,7 @@ async def test_provision_vm_background_sets_ready_on_success(mock_create, mock_s @patch("routers.agent_tools.get_agent_vm", return_value={"status": "ready"}) def test_vm_status_handles_missing_vm_name(mock_get): - """vm-status does not crash when vmName is missing from Firestore.""" + """vm-status does not crash when vmName is missing from Firestore (skips GCE check).""" resp = client.get("/v1/agent/vm-status") assert resp.status_code == 200 data = resp.json() @@ -307,8 +307,9 @@ def test_vm_status_handles_missing_vm_name(mock_get): assert data["vm_name"] is None +@patch("routers.agent_tools._check_gce_status", new_callable=AsyncMock, return_value="RUNNING") @patch("routers.agent_tools.get_agent_vm", return_value={"vmName": "omi-agent-x", "status": "ready"}) -def test_vm_status_handles_missing_ip_and_auth(mock_get): +def test_vm_status_handles_missing_ip_and_auth(mock_get, mock_gce): """vm-status returns None for ip and auth_token when missing from Firestore.""" resp = client.get("/v1/agent/vm-status") assert resp.status_code == 200 @@ -320,6 +321,8 @@ def test_vm_status_handles_missing_ip_and_auth(mock_get): @patch("routers.agent_tools.get_agent_vm", return_value={}) def test_vm_ensure_handles_empty_firestore_vm(mock_get): - """vm-ensure with empty Firestore dict (no status field) doesn't crash.""" + """vm-ensure with empty Firestore dict (no status, falls through to _vm_response).""" resp = client.post("/v1/agent/vm-ensure") assert resp.status_code == 200 + data = resp.json() + assert data["has_vm"] is True