From 990fd5883d1e3c5bf177c37a906821f587b8b41f Mon Sep 17 00:00:00 2001 From: Gemini Agent Date: Sun, 22 Feb 2026 10:43:43 +0100 Subject: [PATCH] Fix Neo4j stale connection recovery in search path The Slack bot returns "I couldn't find relevant information" when the Neo4j TCP connection goes stale after idle periods. The Neo4j driver's liveness check catches OSError/ServiceUnavailable/SessionExpired but RuntimeError from asyncio's transport layer escapes, silently returning zero search results. Add retry-with-reset at the GraphitiRetriever layer: - _is_connection_error() detects RuntimeError(TCPTransport), OSError, ServiceUnavailable, and SessionExpired - search_chunks() and _lookup_episodes() retry once on connection error after resetting the GraphitiClient singleton (fresh Neo4j connection) - NEO4J_SEARCH_MAX_RETRIES config setting (default: 1) - GraphitiClient.reset_and_reconnect() closes stale connection and clears singleton state --- src/knowledge_base/config.py | 2 + src/knowledge_base/graph/graphiti_client.py | 15 ++ .../graph/graphiti_retriever.py | 236 +++++++++++------- tests/test_connection_retry.py | 228 +++++++++++++++++ 4 files changed, 387 insertions(+), 94 deletions(-) create mode 100644 tests/test_connection_retry.py diff --git a/src/knowledge_base/config.py b/src/knowledge_base/config.py index 499fc00..e939e38 100644 --- a/src/knowledge_base/config.py +++ b/src/knowledge_base/config.py @@ -98,6 +98,8 @@ class Settings(BaseSettings): NEO4J_LIVENESS_CHECK_TIMEOUT: int = 30 # Check connection health before use NEO4J_MAX_CONNECTION_LIFETIME: int = 1800 # Recycle connections every 30 min NEO4J_CONNECTION_ACQUISITION_TIMEOUT: int = 60 # Wait up to 60s for a connection + # Neo4j search retry on connection error (RuntimeError: TCPTransport closed) + NEO4J_SEARCH_MAX_RETRIES: int = 1 # Retries on stale connection in search path # Feature flags for Graphiti-only architecture GRAPH_ENABLE_GRAPHITI: bool = True # Master switch for Graphiti (now required) GRAPH_EXPANSION_ENABLED: bool = True # Always enabled with Graphiti-only diff --git a/src/knowledge_base/graph/graphiti_client.py b/src/knowledge_base/graph/graphiti_client.py index 88d667f..6058af3 100644 --- a/src/knowledge_base/graph/graphiti_client.py +++ b/src/knowledge_base/graph/graphiti_client.py @@ -399,6 +399,21 @@ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]] return SimpleEmbeddingCrossEncoder() + async def reset_and_reconnect(self) -> None: + """Reset the singleton and close the stale connection. + + Call this when a connection error is detected to force + a fresh connection on the next get_client() call. + """ + logger.warning("Resetting Graphiti client due to connection error") + try: + await self.close() + except Exception as e: + logger.warning(f"Error during client close on reset: {e}") + # Force reset even if close fails + GraphitiClient._instance = None + GraphitiClient._initialized = False + async def close(self) -> None: """Close the Graphiti client connection.""" if GraphitiClient._instance is not None: diff --git a/src/knowledge_base/graph/graphiti_retriever.py b/src/knowledge_base/graph/graphiti_retriever.py index 4f0848d..70d8b20 100644 --- a/src/knowledge_base/graph/graphiti_retriever.py +++ b/src/knowledge_base/graph/graphiti_retriever.py @@ -27,6 +27,30 @@ logger = logging.getLogger(__name__) +def _is_connection_error(exc: Exception) -> bool: + """Check if an exception indicates a Neo4j connection failure. + + The Neo4j driver's liveness check catches OSError, ServiceUnavailable, + and SessionExpired — but RuntimeError from asyncio's transport layer + escapes these catch clauses. This helper detects all connection-related + errors so the search path can retry with a fresh connection. + """ + # asyncio transport closed (the specific error seen in production) + if isinstance(exc, RuntimeError) and "TCPTransport" in str(exc): + return True + # Standard Neo4j connection errors + try: + from neo4j.exceptions import ServiceUnavailable, SessionExpired + if isinstance(exc, (ServiceUnavailable, SessionExpired)): + return True + except ImportError: + pass + # OS-level connection errors (broken pipe, connection reset) + if isinstance(exc, OSError): + return True + return False + + @dataclass class SearchResult: """A single search result (compatible with VectorRetriever.SearchResult).""" @@ -181,6 +205,8 @@ def _to_search_result(self, graphiti_result: Any, episode_data: dict | None = No async def _lookup_episodes(self, episode_uuids: list[str]) -> dict[str, dict]: """Look up episode content and metadata from Neo4j by UUID. + Includes retry-with-reset on Neo4j connection errors. + Args: episode_uuids: List of episode UUIDs to look up @@ -190,37 +216,48 @@ async def _lookup_episodes(self, episode_uuids: list[str]) -> dict[str, dict]: if not episode_uuids: return {} - try: - graphiti = await self._get_graphiti() - driver = graphiti.driver - - records, _, _ = await driver.execute_query( - """ - MATCH (ep:Episodic) - WHERE ep.uuid IN $uuids - RETURN ep.uuid as uuid, ep.name as name, - ep.content as content, - ep.source_description as source_desc - """, - uuids=episode_uuids, - ) + max_retries = settings.NEO4J_SEARCH_MAX_RETRIES + for attempt in range(1 + max_retries): + try: + graphiti = await self._get_graphiti() + driver = graphiti.driver + + records, _, _ = await driver.execute_query( + """ + MATCH (ep:Episodic) + WHERE ep.uuid IN $uuids + RETURN ep.uuid as uuid, ep.name as name, + ep.content as content, + ep.source_description as source_desc + """, + uuids=episode_uuids, + ) - result = {} - for record in records: - uuid = record["uuid"] - source_desc = record.get("source_desc") - metadata = self._parse_metadata(source_desc) - result[uuid] = { - "name": record.get("name", ""), - "content": record.get("content", "") or "", - "metadata": metadata, - } + result = {} + for record in records: + uuid = record["uuid"] + source_desc = record.get("source_desc") + metadata = self._parse_metadata(source_desc) + result[uuid] = { + "name": record.get("name", ""), + "content": record.get("content", "") or "", + "metadata": metadata, + } - return result + return result - except Exception as e: - logger.warning(f"Failed to look up episodes: {e}") - return {} + except Exception as e: + if attempt < max_retries and _is_connection_error(e): + logger.warning( + f"Neo4j connection error in episode lookup (attempt {attempt + 1}), " + f"resetting client and retrying: {e}" + ) + await self.client.reset_and_reconnect() + self._graphiti = None + continue + logger.warning(f"Failed to look up episodes: {e}") + return {} + return {} async def search_chunks( self, @@ -248,82 +285,93 @@ async def search_chunks( logger.warning("Graphiti retrieval DISABLED — returning empty results") return [] - try: - graphiti = await self._get_graphiti() - - # Over-fetch to account for filtering - fetch_count = num_results * 3 if (space_key or doc_type or min_quality_score) else num_results - - results = await graphiti.search( - query=query, - num_results=fetch_count, - group_ids=[self.group_id], - ) + max_retries = settings.NEO4J_SEARCH_MAX_RETRIES + for attempt in range(1 + max_retries): + try: + graphiti = await self._get_graphiti() - logger.info(f"Graphiti raw search returned {len(results)} results for: {query[:50]}...") + # Over-fetch to account for filtering + fetch_count = num_results * 3 if (space_key or doc_type or min_quality_score) else num_results - # Collect all episode UUIDs from edge results for batch lookup - all_episode_uuids = [] - for result in results: - episodes = getattr(result, 'episodes', None) or [] - all_episode_uuids.extend(episodes) - - # Batch lookup episode content and metadata - episode_data = {} - if all_episode_uuids: - unique_uuids = list(set(all_episode_uuids)) - episode_data = await self._lookup_episodes(unique_uuids) - logger.info( - f"Looked up {len(episode_data)}/{len(unique_uuids)} episodes " - f"for {len(results)} search results" + results = await graphiti.search( + query=query, + num_results=fetch_count, + group_ids=[self.group_id], ) - # Convert and filter results - search_results = [] - seen_episodes = set() # Deduplicate by episode - for result in results: - # Get episode data for this result (use first episode) - episodes = getattr(result, 'episodes', None) or [] - ep_data = None - for ep_uuid in episodes: - if ep_uuid in episode_data and ep_uuid not in seen_episodes: - ep_data = episode_data[ep_uuid] - seen_episodes.add(ep_uuid) - break - - sr = self._to_search_result(result, episode_data=ep_data) + logger.info(f"Graphiti raw search returned {len(results)} results for: {query[:50]}...") + + # Collect all episode UUIDs from edge results for batch lookup + all_episode_uuids = [] + for result in results: + episodes = getattr(result, 'episodes', None) or [] + all_episode_uuids.extend(episodes) + + # Batch lookup episode content and metadata + episode_data = {} + if all_episode_uuids: + unique_uuids = list(set(all_episode_uuids)) + episode_data = await self._lookup_episodes(unique_uuids) + logger.info( + f"Looked up {len(episode_data)}/{len(unique_uuids)} episodes " + f"for {len(results)} search results" + ) + + # Convert and filter results + search_results = [] + seen_episodes = set() # Deduplicate by episode + for result in results: + # Get episode data for this result (use first episode) + episodes = getattr(result, 'episodes', None) or [] + ep_data = None + for ep_uuid in episodes: + if ep_uuid in episode_data and ep_uuid not in seen_episodes: + ep_data = episode_data[ep_uuid] + seen_episodes.add(ep_uuid) + break - # Skip deleted chunks - if sr.metadata.get('deleted'): - continue + sr = self._to_search_result(result, episode_data=ep_data) - # Apply filters - if space_key and sr.metadata.get('space_key') != space_key: - continue - if doc_type and sr.metadata.get('doc_type') != doc_type: - continue - if min_quality_score and sr.quality_score < min_quality_score: - continue + # Skip deleted chunks + if sr.metadata.get('deleted'): + continue - search_results.append(sr) + # Apply filters + if space_key and sr.metadata.get('space_key') != space_key: + continue + if doc_type and sr.metadata.get('doc_type') != doc_type: + continue + if min_quality_score and sr.quality_score < min_quality_score: + continue - if len(search_results) >= num_results: - break + search_results.append(sr) - if len(results) > 0 and len(search_results) == 0: - logger.warning( - f"Graphiti returned {len(results)} raw results but ALL were filtered out " - f"(space_key={space_key}, doc_type={doc_type}, min_quality={min_quality_score})" - ) - logger.info(f"Graphiti search_chunks returning {len(search_results)}/{len(results)} results for: {query[:50]}...") - return search_results + if len(search_results) >= num_results: + break - except GraphitiClientError as e: - logger.error(f"Graphiti search FAILED: {e}", exc_info=True) - return [] - except Exception as e: - logger.error(f"Unexpected error in Graphiti search: {e}", exc_info=True) - return [] + if len(results) > 0 and len(search_results) == 0: + logger.warning( + f"Graphiti returned {len(results)} raw results but ALL were filtered out " + f"(space_key={space_key}, doc_type={doc_type}, min_quality={min_quality_score})" + ) + logger.info(f"Graphiti search_chunks returning {len(search_results)}/{len(results)} results for: {query[:50]}...") + return search_results + + except Exception as e: + if attempt < max_retries and _is_connection_error(e): + logger.warning( + f"Neo4j connection error on search attempt {attempt + 1}, " + f"resetting client and retrying: {e}" + ) + await self.client.reset_and_reconnect() + self._graphiti = None + continue + if isinstance(e, GraphitiClientError): + logger.error(f"Graphiti search FAILED: {e}", exc_info=True) + else: + logger.error(f"Unexpected error in Graphiti search: {e}", exc_info=True) + return [] + return [] async def search_with_quality_boost( self, diff --git a/tests/test_connection_retry.py b/tests/test_connection_retry.py new file mode 100644 index 0000000..dc29a92 --- /dev/null +++ b/tests/test_connection_retry.py @@ -0,0 +1,228 @@ +"""Tests for Neo4j stale connection retry logic. + +Tests the _is_connection_error() helper and retry-with-reset behavior +in GraphitiRetriever.search_chunks() and _lookup_episodes(). +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from knowledge_base.graph.graphiti_retriever import ( + _is_connection_error, + GraphitiRetriever, + SearchResult, +) + + +class TestIsConnectionError: + """Tests for the _is_connection_error() helper.""" + + def test_runtime_error_tcp_transport(self): + """RuntimeError with TCPTransport is a connection error.""" + exc = RuntimeError("unable to perform operation on ") + assert _is_connection_error(exc) is True + + def test_runtime_error_other(self): + """RuntimeError without TCPTransport is NOT a connection error.""" + exc = RuntimeError("some other runtime error") + assert _is_connection_error(exc) is False + + def test_os_error(self): + """OSError (broken pipe, connection reset) is a connection error.""" + exc = OSError("Connection reset by peer") + assert _is_connection_error(exc) is True + + def test_connection_refused(self): + """ConnectionRefusedError (subclass of OSError) is a connection error.""" + exc = ConnectionRefusedError("Connection refused") + assert _is_connection_error(exc) is True + + def test_service_unavailable(self): + """neo4j.exceptions.ServiceUnavailable is a connection error.""" + try: + from neo4j.exceptions import ServiceUnavailable + exc = ServiceUnavailable("Server unavailable") + assert _is_connection_error(exc) is True + except ImportError: + pytest.skip("neo4j package not installed") + + def test_session_expired(self): + """neo4j.exceptions.SessionExpired is a connection error.""" + try: + from neo4j.exceptions import SessionExpired + exc = SessionExpired("Session expired") + assert _is_connection_error(exc) is True + except ImportError: + pytest.skip("neo4j package not installed") + + def test_value_error_not_connection(self): + """ValueError is NOT a connection error.""" + exc = ValueError("invalid value") + assert _is_connection_error(exc) is False + + def test_key_error_not_connection(self): + """KeyError is NOT a connection error.""" + exc = KeyError("missing key") + assert _is_connection_error(exc) is False + + def test_generic_exception_not_connection(self): + """Generic Exception is NOT a connection error.""" + exc = Exception("something went wrong") + assert _is_connection_error(exc) is False + + +class TestSearchChunksRetry: + """Tests for retry behavior in GraphitiRetriever.search_chunks().""" + + @pytest.fixture + def retriever(self): + """Create a GraphitiRetriever with mocked client.""" + with patch("knowledge_base.graph.graphiti_retriever.get_graphiti_client") as mock_get: + mock_client = MagicMock() + mock_client.get_client = AsyncMock() + mock_client.reset_and_reconnect = AsyncMock() + mock_get.return_value = mock_client + retriever = GraphitiRetriever() + yield retriever + + @pytest.mark.asyncio + async def test_retries_on_tcp_transport_error(self, retriever): + """search_chunks retries once on TCPTransport RuntimeError.""" + mock_graphiti = AsyncMock() + # First call raises connection error, second succeeds + mock_result = MagicMock() + mock_result.episodes = [] + mock_result.score = 0.9 + mock_result.content = "test content" + mock_result.name = "test" + mock_result.source_description = None + mock_result.fact = None + + mock_graphiti.search = AsyncMock( + side_effect=[ + RuntimeError("unable to perform operation on "), + [mock_result], + ] + ) + retriever._graphiti = mock_graphiti + + with patch.object(retriever, "_get_graphiti", new_callable=AsyncMock, return_value=mock_graphiti): + # Make _get_graphiti return the mock on retry too + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + results = await retriever.search_chunks("test query") + + assert len(results) == 1 + assert retriever.client.reset_and_reconnect.call_count == 1 + assert mock_graphiti.search.call_count == 2 + + @pytest.mark.asyncio + async def test_no_retry_on_regular_error(self, retriever): + """search_chunks does NOT retry on non-connection errors.""" + mock_graphiti = AsyncMock() + mock_graphiti.search = AsyncMock(side_effect=ValueError("bad query")) + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + results = await retriever.search_chunks("test query") + + assert results == [] + assert retriever.client.reset_and_reconnect.call_count == 0 + assert mock_graphiti.search.call_count == 1 + + @pytest.mark.asyncio + async def test_returns_empty_after_exhausted_retries(self, retriever): + """search_chunks returns empty list when all retries fail.""" + mock_graphiti = AsyncMock() + mock_graphiti.search = AsyncMock( + side_effect=RuntimeError("unable to perform operation on ") + ) + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + results = await retriever.search_chunks("test query") + + assert results == [] + assert retriever.client.reset_and_reconnect.call_count == 1 + # 2 attempts: initial + 1 retry + assert mock_graphiti.search.call_count == 2 + + @pytest.mark.asyncio + async def test_succeeds_without_retry(self, retriever): + """search_chunks succeeds on first attempt without retry.""" + mock_graphiti = AsyncMock() + mock_result = MagicMock() + mock_result.episodes = [] + mock_result.score = 0.9 + mock_result.content = "test content" + mock_result.name = "test" + mock_result.source_description = None + mock_result.fact = None + + mock_graphiti.search = AsyncMock(return_value=[mock_result]) + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + results = await retriever.search_chunks("test query") + + assert len(results) == 1 + assert retriever.client.reset_and_reconnect.call_count == 0 + assert mock_graphiti.search.call_count == 1 + + +class TestLookupEpisodesRetry: + """Tests for retry behavior in GraphitiRetriever._lookup_episodes().""" + + @pytest.fixture + def retriever(self): + """Create a GraphitiRetriever with mocked client.""" + with patch("knowledge_base.graph.graphiti_retriever.get_graphiti_client") as mock_get: + mock_client = MagicMock() + mock_client.get_client = AsyncMock() + mock_client.reset_and_reconnect = AsyncMock() + mock_get.return_value = mock_client + retriever = GraphitiRetriever() + yield retriever + + @pytest.mark.asyncio + async def test_retries_on_connection_error(self, retriever): + """_lookup_episodes retries on connection error.""" + mock_driver = MagicMock() + mock_record = {"uuid": "abc-123", "name": "test", "content": "hello", "source_desc": None} + mock_driver.execute_query = AsyncMock( + side_effect=[ + RuntimeError("unable to perform operation on "), + ([mock_record], None, None), + ] + ) + + mock_graphiti = MagicMock() + mock_graphiti.driver = mock_driver + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + result = await retriever._lookup_episodes(["abc-123"]) + + assert "abc-123" in result + assert retriever.client.reset_and_reconnect.call_count == 1 + + @pytest.mark.asyncio + async def test_no_retry_on_regular_error(self, retriever): + """_lookup_episodes does NOT retry on non-connection errors.""" + mock_driver = MagicMock() + mock_driver.execute_query = AsyncMock(side_effect=ValueError("bad query")) + + mock_graphiti = MagicMock() + mock_graphiti.driver = mock_driver + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + result = await retriever._lookup_episodes(["abc-123"]) + + assert result == {} + assert retriever.client.reset_and_reconnect.call_count == 0 + + @pytest.mark.asyncio + async def test_empty_uuids_returns_empty(self, retriever): + """_lookup_episodes returns empty for empty UUID list.""" + result = await retriever._lookup_episodes([]) + assert result == {}