Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ pytest tests/unit/test_storage_upload_audio_chunk_data_protection.py -v
pytest tests/unit/test_people_conversations_500s.py -v
pytest tests/unit/test_firestore_read_ops_cache.py -v
pytest tests/unit/test_ws_auth_handshake.py -v
pytest tests/unit/test_chat_context_truncation.py -v
226 changes: 226 additions & 0 deletions backend/tests/unit/test_chat_context_truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""
Tests for chat context truncation in conversation tools.

Verifies that get_conversations_tool and search_conversations_tool
truncate large result strings to prevent context overflow and 504 timeouts.
Issue #4927: Chat freezes with lengthy date ranges.
"""

import sys
import unittest
from datetime import datetime, timezone, timedelta
from unittest.mock import patch, MagicMock

# Stub heavy dependencies before importing conversation_tools.
# Each module name below is replaced with a MagicMock in sys.modules so that
# importing the code under test does not pull in Firebase / Firestore /
# LangChain at collection time. Only missing modules are stubbed, so a real
# installation (if one exists in the environment) is left untouched.
for mod_name in [
    'firebase_admin',
    'firebase_admin.firestore',
    'firebase_admin.auth',
    'firebase_admin.credentials',
    'firebase_admin.messaging',
    'google.cloud.firestore',
    'google.cloud.firestore_v1',
    'google.cloud.firestore_v1.base_query',
    'langchain_core',
    'langchain_core.runnables',
    'langchain_core.tools',
    'database.conversations',
    'database.users',
    'database.vector_db',
    'utils.llm.clients',
]:
    if mod_name not in sys.modules:
        sys.modules[mod_name] = MagicMock()

# Mock the tool decorator to be a no-op so decorated functions remain plain,
# directly-callable Python functions in the tests.
mock_tool = MagicMock(side_effect=lambda f: f)
sys.modules['langchain_core.tools'].tool = mock_tool

# Mock RunnableConfig; these tests never construct or pass a real config object.
sys.modules['langchain_core.runnables'].RunnableConfig = None

# Now import the module under test.
# NOTE(review): models.conversation / models.other are real project modules
# (not stubbed above) — they must be importable without the stubbed backends.
from models.conversation import Conversation
from models.other import Person


def _make_conversation(index: int, overview_size: int = 200) -> dict:
"""Create a fake conversation dict with a specified overview size."""
return {
'id': f'conv-{index}',
'created_at': datetime(2026, 3, 1, tzinfo=timezone.utc) - timedelta(days=index),
'started_at': datetime(2026, 3, 1, 10, 0, tzinfo=timezone.utc) - timedelta(days=index),
'finished_at': datetime(2026, 3, 1, 11, 0, tzinfo=timezone.utc) - timedelta(days=index),
'structured': {
'title': f'Conversation about topic {index}',
'overview': 'X' * overview_size,
'category': 'personal',
'action_items': [],
'events': [],
'emoji': '',
},
'transcript_segments': [],
'plugins_results': [],
'apps_results': [],
'photos': [],
'source': 'friend',
'language': 'en',
'status': 'completed',
}


def _make_conversations(count: int, overview_size: int = 200) -> list:
"""Create a list of fake conversation dicts."""
return [_make_conversation(i, overview_size) for i in range(count)]


class TestConversationContextTruncation(unittest.TestCase):
    """Test that conversations_to_string output is properly bounded."""

    def test_small_result_not_truncated(self):
        """10 conversations with small overviews should not be truncated."""
        conversations = [Conversation(**data) for data in _make_conversations(10, overview_size=100)]
        rendered = Conversation.conversations_to_string(conversations)
        # Every entry must survive and no truncation marker may appear.
        self.assertEqual(10, rendered.count('Conversation #'))
        self.assertNotIn('[Note:', rendered)

    def test_conversations_to_string_output_format(self):
        """Verify basic output format of conversations_to_string."""
        conversations = [Conversation(**data) for data in _make_conversations(3)]
        rendered = Conversation.conversations_to_string(conversations)
        for header in ('Conversation #1', 'Conversation #2', 'Conversation #3'):
            self.assertIn(header, rendered)
        self.assertIn('---------------------', rendered)


class TestGetConversationsToolTruncation(unittest.TestCase):
    """Test truncation logic in get_conversations_tool."""

    def _call_tool_with_conversations(self, conversations_data, max_result_chars=None):
        """Helper to simulate the truncation logic from get_conversations_tool.

        Returns a ``(result_string, total_conversation_count)`` tuple.
        NOTE(review): this mirrors the production algorithm rather than calling
        the real tool — keep the two in sync if the tool's logic changes.
        """
        char_cap = max_result_chars or 1_600_000
        conversations = [Conversation(**data) for data in conversations_data]
        result = Conversation.conversations_to_string(conversations)

        if len(result) > char_cap:
            separator = "\n\n---------------------\n\n"
            kept_parts = []
            used_chars = 0
            for conversation in conversations:
                part = Conversation.conversations_to_string([conversation])
                cost = len(part) + len(separator)
                # Stop once the cap is exceeded, but always keep at least one.
                if kept_parts and used_chars + cost > char_cap:
                    break
                kept_parts.append(part)
                used_chars += cost

            omitted = len(conversations) - len(kept_parts)
            result = separator.join(kept_parts)
            if omitted > 0:
                result += f"\n\n[Note: {omitted} older conversations omitted to fit context. Ask about a shorter time period for full details.]"

        return result, len(conversations)
Comment on lines +101 to +129
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests exercise a re-implementation, not the actual tool

_call_tool_with_conversations is a copy-paste of the truncation algorithm rather than a call to get_conversations_tool. This means the tests can pass even if the logic inside the real tool diverges (e.g., wrong parameters forwarded to conversations_to_string, or the truncation block is accidentally removed). The numbering bug noted above is a direct consequence of this gap — the helper calls conversations_to_string([conversation]) identically to the production code, so the tests pass while every conversation gets labelled "Conversation #1".

Consider patching out the Firestore/config dependencies and exercising get_conversations_tool end-to-end, or at least extracting the truncation logic into a standalone helper that both the tool and the tests can call directly.


def test_large_result_gets_truncated(self):
"""Many conversations with large overviews should be truncated."""
# Each conversation ~5100 chars. 500 convs = ~2.5M chars > 1.6M limit
conversations_data = _make_conversations(500, overview_size=5000)
result, total = self._call_tool_with_conversations(conversations_data)

self.assertEqual(total, 500)
self.assertLessEqual(len(result), 1_700_000) # ~1.6M + truncation note
self.assertIn('[Note:', result)
self.assertIn('older conversations omitted', result)

def test_small_result_passes_through(self):
"""Few conversations should not be truncated."""
conversations_data = _make_conversations(5, overview_size=200)
result, total = self._call_tool_with_conversations(conversations_data)

self.assertEqual(total, 5)
self.assertNotIn('[Note:', result)
self.assertEqual(result.count('Conversation #'), 5)

def test_truncation_preserves_order(self):
"""Truncated result should contain the first (most recent) conversations."""
conversations_data = _make_conversations(500, overview_size=5000)
result, _ = self._call_tool_with_conversations(conversations_data)

# First conversation should always be present
self.assertIn('Conversation #1', result)
# Last conversation should be omitted
self.assertNotIn('Conversation #500', result)

def test_truncation_with_custom_limit(self):
"""Truncation should work with a smaller limit."""
# Use 10K char limit — each conv ~300 chars, should fit ~30
conversations_data = _make_conversations(100, overview_size=200)
result, total = self._call_tool_with_conversations(conversations_data, max_result_chars=10_000)

self.assertEqual(total, 100)
self.assertLessEqual(len(result), 11_000) # 10K + note
self.assertIn('[Note:', result)

def test_single_huge_conversation_included(self):
"""A single conversation larger than the limit should still be included."""
conversations_data = _make_conversations(1, overview_size=2_000_000)
result, total = self._call_tool_with_conversations(conversations_data)

# Even if it exceeds the limit, 1 conversation should always be returned
self.assertEqual(total, 1)
self.assertIn('Conversation #1', result)

def test_truncation_note_includes_count(self):
"""Truncation note should include the number of omitted conversations."""
conversations_data = _make_conversations(500, overview_size=5000)
result, _ = self._call_tool_with_conversations(conversations_data)

self.assertIn('[Note:', result)
# Extract the omitted count from the note
import re

match = re.search(r'\[Note: (\d+) older conversations omitted', result)
self.assertIsNotNone(match)
omitted = int(match.group(1))
self.assertGreater(omitted, 0)
self.assertLess(omitted, 200)


class TestTokenEstimation(unittest.TestCase):
    """Test that context size stays within safety guard limits."""

    def test_truncated_result_fits_safety_guard(self):
        """Truncated result should fit within 500K token safety guard."""
        char_cap = 1_600_000
        separator = "\n\n---------------------\n\n"
        conversations = [Conversation(**d) for d in _make_conversations(1000, overview_size=5000)]

        # Replay the production truncation: keep accumulating conversations
        # until adding one more would exceed the cap (always keep at least one).
        kept_parts = []
        used_chars = 0
        for conversation in conversations:
            part = Conversation.conversations_to_string([conversation])
            cost = len(part) + len(separator)
            if kept_parts and used_chars + cost > char_cap:
                break
            kept_parts.append(part)
            used_chars += cost

        result = separator.join(kept_parts)

        # Rough heuristic: ~4 characters per token.
        estimated_tokens = len(result) // 4
        self.assertLess(estimated_tokens, 500_000, "Truncated result should fit within 500K token safety guard")


# Allow running this test file directly (e.g. `python test_chat_context_truncation.py`)
# in addition to the pytest invocation registered in backend/test.sh.
if __name__ == '__main__':
    unittest.main()
65 changes: 62 additions & 3 deletions backend/utils/retrieval/tools/conversation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,41 @@ def get_conversations_tool(
f"📚 get_conversations_tool - Added {len(conversations)} conversations to collection (total: {len(conversations_collected)})"
)

# Return formatted string
# Return formatted string with context size guard
# Cap output at ~400K tokens (~1.6M chars) to stay well under 500K token safety guard limit.
# This prevents both the "narrow down" error and 504 timeouts on large date ranges.
MAX_RESULT_CHARS = 1_600_000
result = Conversation.conversations_to_string(
conversations, use_transcript=include_transcript, include_timestamps=include_timestamps, people=people
)
Comment on lines 254 to 256
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double-formatting wastes memory and time on the hot path

The full conversations_to_string(conversations, ...) is called unconditionally before the truncation check, producing a potentially 5–10 MB string that is immediately discarded when the result exceeds MAX_RESULT_CHARS. Then the loop calls conversations_to_string([conv]) one more time per conversation. For the large-range queries this PR is meant to fix (e.g., limit=5000, include_transcript=True), this doubles the serialisation work and peak memory, and could itself approach the 120 s HTTP timeout before truncation even begins.

Consider building the per-conversation parts first (one pass), then joining and truncating in a second step — this avoids ever materialising the oversized string:

separator = "\n\n---------------------\n\n"
parts = [
    Conversation.conversations_to_string(
        [conv],
        use_transcript=include_transcript,
        include_timestamps=include_timestamps,
        people=people,
    )
    for conv in conversations
]

total = 0
kept = []
for part in parts:
    if total + len(part) + len(separator) > MAX_RESULT_CHARS and kept:
        break
    kept.append(part)
    total += len(part) + len(separator)

result = separator.join(kept)


if len(result) > MAX_RESULT_CHARS:
# Rebuild with truncation: include conversations until we hit the limit
truncated_parts = []
total_chars = 0
included_count = 0
separator = "\n\n---------------------\n\n"
for i, conversation in enumerate(conversations):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused loop variable i

i is assigned by enumerate but never referenced inside the loop body.

Suggested change
for i, conversation in enumerate(conversations):
for conversation in conversations:

part = Conversation.conversations_to_string(
[conversation],
use_transcript=include_transcript,
include_timestamps=include_timestamps,
people=people,
)
if total_chars + len(part) + len(separator) > MAX_RESULT_CHARS and included_count > 0:
break
truncated_parts.append(part)
total_chars += len(part) + len(separator)
included_count += 1
Comment on lines +265 to +275
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All truncated conversations labelled "Conversation #1"

conversations_to_string([conversation]) is called with a single-element list, so enumerate always starts at i=0 — meaning every conversation in the truncated output gets the header Conversation #1. When truncation is active, the LLM receives a context where every entry claims to be the first conversation.

The same issue exists in search_conversations_tool at lines 509-519.

A straightforward fix is to avoid re-calling conversations_to_string on individual items and instead carry the per-conversation strings through a single pre-computation pass:

# Pre-compute individual parts once, preserving original indices
all_parts = []
for idx, conversation in enumerate(conversations):
    sep_str = f"Conversation #{idx + 1}\n"
    # ... format using the same logic, or factor out a single-conversation formatter

Alternatively, build the formatted parts first and then join/truncate, rather than formatting the full list first and re-formatting from scratch on overflow.


omitted = len(conversations) - included_count
result = separator.join(truncated_parts)
if omitted > 0:
result += f"\n\n[Note: {omitted} older conversations omitted to fit context. Ask about a shorter time period for full details.]"
logger.info(
f"🔍 get_conversations_tool - Truncated result: included {included_count}/{len(conversations)}, omitted {omitted}"
)

logger.info(f"🔍 get_conversations_tool - Generated result string, length: {len(result)}")
return result

Expand Down Expand Up @@ -462,12 +493,40 @@ def search_conversations_tool(
f"📚 search_conversations_tool - Added {len(conversations)} conversations to collection (total: {len(conversations_collected)})"
)

# Return formatted string
# Return formatted string with context size guard
MAX_RESULT_CHARS = 1_600_000
result = f"Found {len(conversations)} conversations semantically matching '{query}':\n\n"
result += Conversation.conversations_to_string(
formatted = Conversation.conversations_to_string(
conversations, use_transcript=include_transcript, include_timestamps=include_timestamps, people=people
)

if len(formatted) > MAX_RESULT_CHARS:
truncated_parts = []
total_chars = 0
included_count = 0
separator = "\n\n---------------------\n\n"
for conversation in conversations:
part = Conversation.conversations_to_string(
[conversation],
use_transcript=include_transcript,
include_timestamps=include_timestamps,
people=people,
)
if total_chars + len(part) + len(separator) > MAX_RESULT_CHARS and included_count > 0:
break
truncated_parts.append(part)
total_chars += len(part) + len(separator)
included_count += 1

omitted = len(conversations) - included_count
formatted = separator.join(truncated_parts)
if omitted > 0:
formatted += f"\n\n[Note: {omitted} conversations omitted to fit context. Try a more specific query.]"
logger.info(
f"🔍 search_conversations_tool - Truncated: included {included_count}/{len(conversations)}, omitted {omitted}"
)

result += formatted
logger.info(f"🔍 search_conversations_tool - Generated result string, length: {len(result)}")

return result
Expand Down