diff --git a/src/knowledge_base/config.py b/src/knowledge_base/config.py index 499fc00..e939e38 100644 --- a/src/knowledge_base/config.py +++ b/src/knowledge_base/config.py @@ -98,6 +98,8 @@ class Settings(BaseSettings): NEO4J_LIVENESS_CHECK_TIMEOUT: int = 30 # Check connection health before use NEO4J_MAX_CONNECTION_LIFETIME: int = 1800 # Recycle connections every 30 min NEO4J_CONNECTION_ACQUISITION_TIMEOUT: int = 60 # Wait up to 60s for a connection + # Neo4j search retry on connection error (RuntimeError: TCPTransport closed) + NEO4J_SEARCH_MAX_RETRIES: int = 1 # Retries on stale connection in search path # Feature flags for Graphiti-only architecture GRAPH_ENABLE_GRAPHITI: bool = True # Master switch for Graphiti (now required) GRAPH_EXPANSION_ENABLED: bool = True # Always enabled with Graphiti-only diff --git a/src/knowledge_base/graph/graphiti_client.py b/src/knowledge_base/graph/graphiti_client.py index 88d667f..6058af3 100644 --- a/src/knowledge_base/graph/graphiti_client.py +++ b/src/knowledge_base/graph/graphiti_client.py @@ -399,6 +399,21 @@ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]] return SimpleEmbeddingCrossEncoder() + async def reset_and_reconnect(self) -> None: + """Reset the singleton and close the stale connection. + + Call this when a connection error is detected to force + a fresh connection on the next get_client() call. + """ + logger.warning("Resetting Graphiti client due to connection error") + try: + await self.close() + except Exception as e: + logger.warning(f"Error during client close on reset: {e}") + # Force reset even if close fails + GraphitiClient._instance = None + GraphitiClient._initialized = False + async def close(self) -> None: """Close the Graphiti client connection.""" if GraphitiClient._instance is not None: diff --git a/src/knowledge_base/graph/graphiti_retriever.py b/src/knowledge_base/graph/graphiti_retriever.py index 4f0848d..70d8b20 100644 --- a/src/knowledge_base/graph/graphiti_retriever.py +++ b/src/knowledge_base/graph/graphiti_retriever.py @@ -27,6 +27,30 @@ logger = logging.getLogger(__name__) +def _is_connection_error(exc: Exception) -> bool: + """Check if an exception indicates a Neo4j connection failure. + + The Neo4j driver's liveness check catches OSError, ServiceUnavailable, + and SessionExpired — but RuntimeError from asyncio's transport layer + escapes these catch clauses. This helper detects all connection-related + errors so the search path can retry with a fresh connection. + """ + # asyncio transport closed (the specific error seen in production) + if isinstance(exc, RuntimeError) and "TCPTransport" in str(exc): + return True + # Standard Neo4j connection errors + try: + from neo4j.exceptions import ServiceUnavailable, SessionExpired + if isinstance(exc, (ServiceUnavailable, SessionExpired)): + return True + except ImportError: + pass + # OS-level connection errors (broken pipe, connection reset) + if isinstance(exc, OSError): + return True + return False + + @dataclass class SearchResult: """A single search result (compatible with VectorRetriever.SearchResult).""" @@ -181,6 +205,8 @@ def _to_search_result(self, graphiti_result: Any, episode_data: dict | None = No async def _lookup_episodes(self, episode_uuids: list[str]) -> dict[str, dict]: """Look up episode content and metadata from Neo4j by UUID. + Includes retry-with-reset on Neo4j connection errors. + Args: episode_uuids: List of episode UUIDs to look up @@ -190,37 +216,48 @@ async def _lookup_episodes(self, episode_uuids: list[str]) -> dict[str, dict]: if not episode_uuids: return {} - try: - graphiti = await self._get_graphiti() - driver = graphiti.driver - - records, _, _ = await driver.execute_query( - """ - MATCH (ep:Episodic) - WHERE ep.uuid IN $uuids - RETURN ep.uuid as uuid, ep.name as name, - ep.content as content, - ep.source_description as source_desc - """, - uuids=episode_uuids, - ) + max_retries = settings.NEO4J_SEARCH_MAX_RETRIES + for attempt in range(1 + max_retries): + try: + graphiti = await self._get_graphiti() + driver = graphiti.driver + + records, _, _ = await driver.execute_query( + """ + MATCH (ep:Episodic) + WHERE ep.uuid IN $uuids + RETURN ep.uuid as uuid, ep.name as name, + ep.content as content, + ep.source_description as source_desc + """, + uuids=episode_uuids, + ) - result = {} - for record in records: - uuid = record["uuid"] - source_desc = record.get("source_desc") - metadata = self._parse_metadata(source_desc) - result[uuid] = { - "name": record.get("name", ""), - "content": record.get("content", "") or "", - "metadata": metadata, - } + result = {} + for record in records: + uuid = record["uuid"] + source_desc = record.get("source_desc") + metadata = self._parse_metadata(source_desc) + result[uuid] = { + "name": record.get("name", ""), + "content": record.get("content", "") or "", + "metadata": metadata, + } - return result + return result - except Exception as e: - logger.warning(f"Failed to look up episodes: {e}") - return {} + except Exception as e: + if attempt < max_retries and _is_connection_error(e): + logger.warning( + f"Neo4j connection error in episode lookup (attempt {attempt + 1}), " + f"resetting client and retrying: {e}" + ) + await self.client.reset_and_reconnect() + self._graphiti = None + continue + logger.warning(f"Failed to look up episodes: {e}") + return {} + return {} async def search_chunks( self, @@ -248,82 +285,93 @@ async def search_chunks( logger.warning("Graphiti retrieval DISABLED — returning empty results") return [] - try: - graphiti = await self._get_graphiti() - - # Over-fetch to account for filtering - fetch_count = num_results * 3 if (space_key or doc_type or min_quality_score) else num_results - - results = await graphiti.search( - query=query, - num_results=fetch_count, - group_ids=[self.group_id], - ) + max_retries = settings.NEO4J_SEARCH_MAX_RETRIES + for attempt in range(1 + max_retries): + try: + graphiti = await self._get_graphiti() - logger.info(f"Graphiti raw search returned {len(results)} results for: {query[:50]}...") + # Over-fetch to account for filtering + fetch_count = num_results * 3 if (space_key or doc_type or min_quality_score) else num_results - # Collect all episode UUIDs from edge results for batch lookup - all_episode_uuids = [] - for result in results: - episodes = getattr(result, 'episodes', None) or [] - all_episode_uuids.extend(episodes) - - # Batch lookup episode content and metadata - episode_data = {} - if all_episode_uuids: - unique_uuids = list(set(all_episode_uuids)) - episode_data = await self._lookup_episodes(unique_uuids) - logger.info( - f"Looked up {len(episode_data)}/{len(unique_uuids)} episodes " - f"for {len(results)} search results" + results = await graphiti.search( + query=query, + num_results=fetch_count, + group_ids=[self.group_id], ) - # Convert and filter results - search_results = [] - seen_episodes = set() # Deduplicate by episode - for result in results: - # Get episode data for this result (use first episode) - episodes = getattr(result, 'episodes', None) or [] - ep_data = None - for ep_uuid in episodes: - if ep_uuid in episode_data and ep_uuid not in seen_episodes: - ep_data = episode_data[ep_uuid] - seen_episodes.add(ep_uuid) - break - - sr = self._to_search_result(result, episode_data=ep_data) + logger.info(f"Graphiti raw search returned {len(results)} results for: {query[:50]}...") + + # Collect all episode UUIDs from edge results for batch lookup + all_episode_uuids = [] + for result in results: + episodes = getattr(result, 'episodes', None) or [] + all_episode_uuids.extend(episodes) + + # Batch lookup episode content and metadata + episode_data = {} + if all_episode_uuids: + unique_uuids = list(set(all_episode_uuids)) + episode_data = await self._lookup_episodes(unique_uuids) + logger.info( + f"Looked up {len(episode_data)}/{len(unique_uuids)} episodes " + f"for {len(results)} search results" + ) + + # Convert and filter results + search_results = [] + seen_episodes = set() # Deduplicate by episode + for result in results: + # Get episode data for this result (use first episode) + episodes = getattr(result, 'episodes', None) or [] + ep_data = None + for ep_uuid in episodes: + if ep_uuid in episode_data and ep_uuid not in seen_episodes: + ep_data = episode_data[ep_uuid] + seen_episodes.add(ep_uuid) + break - # Skip deleted chunks - if sr.metadata.get('deleted'): - continue + sr = self._to_search_result(result, episode_data=ep_data) - # Apply filters - if space_key and sr.metadata.get('space_key') != space_key: - continue - if doc_type and sr.metadata.get('doc_type') != doc_type: - continue - if min_quality_score and sr.quality_score < min_quality_score: - continue + # Skip deleted chunks + if sr.metadata.get('deleted'): + continue - search_results.append(sr) + # Apply filters + if space_key and sr.metadata.get('space_key') != space_key: + continue + if doc_type and sr.metadata.get('doc_type') != doc_type: + continue + if min_quality_score and sr.quality_score < min_quality_score: + continue - if len(search_results) >= num_results: - break + search_results.append(sr) - if len(results) > 0 and len(search_results) == 0: - logger.warning( - f"Graphiti returned {len(results)} raw results but ALL were filtered out " - f"(space_key={space_key}, doc_type={doc_type}, min_quality={min_quality_score})" - ) - logger.info(f"Graphiti search_chunks returning {len(search_results)}/{len(results)} results for: {query[:50]}...") - return search_results + if len(search_results) >= num_results: + break - except GraphitiClientError as e: - logger.error(f"Graphiti search FAILED: {e}", exc_info=True) - return [] - except Exception as e: - logger.error(f"Unexpected error in Graphiti search: {e}", exc_info=True) - return [] + if len(results) > 0 and len(search_results) == 0: + logger.warning( + f"Graphiti returned {len(results)} raw results but ALL were filtered out " + f"(space_key={space_key}, doc_type={doc_type}, min_quality={min_quality_score})" + ) + logger.info(f"Graphiti search_chunks returning {len(search_results)}/{len(results)} results for: {query[:50]}...") + return search_results + + except Exception as e: + if attempt < max_retries and _is_connection_error(e): + logger.warning( + f"Neo4j connection error on search attempt {attempt + 1}, " + f"resetting client and retrying: {e}" + ) + await self.client.reset_and_reconnect() + self._graphiti = None + continue + if isinstance(e, GraphitiClientError): + logger.error(f"Graphiti search FAILED: {e}", exc_info=True) + else: + logger.error(f"Unexpected error in Graphiti search: {e}", exc_info=True) + return [] + return [] async def search_with_quality_boost( self, diff --git a/tests/test_connection_retry.py b/tests/test_connection_retry.py new file mode 100644 index 0000000..dc29a92 --- /dev/null +++ b/tests/test_connection_retry.py @@ -0,0 +1,228 @@ +"""Tests for Neo4j stale connection retry logic. + +Tests the _is_connection_error() helper and retry-with-reset behavior +in GraphitiRetriever.search_chunks() and _lookup_episodes(). +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from knowledge_base.graph.graphiti_retriever import ( + _is_connection_error, + GraphitiRetriever, + SearchResult, +) + + +class TestIsConnectionError: + """Tests for the _is_connection_error() helper.""" + + def test_runtime_error_tcp_transport(self): + """RuntimeError with TCPTransport is a connection error.""" + exc = RuntimeError("unable to perform operation on ") + assert _is_connection_error(exc) is True + + def test_runtime_error_other(self): + """RuntimeError without TCPTransport is NOT a connection error.""" + exc = RuntimeError("some other runtime error") + assert _is_connection_error(exc) is False + + def test_os_error(self): + """OSError (broken pipe, connection reset) is a connection error.""" + exc = OSError("Connection reset by peer") + assert _is_connection_error(exc) is True + + def test_connection_refused(self): + """ConnectionRefusedError (subclass of OSError) is a connection error.""" + exc = ConnectionRefusedError("Connection refused") + assert _is_connection_error(exc) is True + + def test_service_unavailable(self): + """neo4j.exceptions.ServiceUnavailable is a connection error.""" + try: + from neo4j.exceptions import ServiceUnavailable + exc = ServiceUnavailable("Server unavailable") + assert _is_connection_error(exc) is True + except ImportError: + pytest.skip("neo4j package not installed") + + def test_session_expired(self): + """neo4j.exceptions.SessionExpired is a connection error.""" + try: + from neo4j.exceptions import SessionExpired + exc = SessionExpired("Session expired") + assert _is_connection_error(exc) is True + except ImportError: + pytest.skip("neo4j package not installed") + + def test_value_error_not_connection(self): + """ValueError is NOT a connection error.""" + exc = ValueError("invalid value") + assert _is_connection_error(exc) is False + + def test_key_error_not_connection(self): + """KeyError is NOT a connection error.""" + exc = KeyError("missing key") + assert _is_connection_error(exc) is False + + def test_generic_exception_not_connection(self): + """Generic Exception is NOT a connection error.""" + exc = Exception("something went wrong") + assert _is_connection_error(exc) is False + + +class TestSearchChunksRetry: + """Tests for retry behavior in GraphitiRetriever.search_chunks().""" + + @pytest.fixture + def retriever(self): + """Create a GraphitiRetriever with mocked client.""" + with patch("knowledge_base.graph.graphiti_retriever.get_graphiti_client") as mock_get: + mock_client = MagicMock() + mock_client.get_client = AsyncMock() + mock_client.reset_and_reconnect = AsyncMock() + mock_get.return_value = mock_client + retriever = GraphitiRetriever() + yield retriever + + @pytest.mark.asyncio + async def test_retries_on_tcp_transport_error(self, retriever): + """search_chunks retries once on TCPTransport RuntimeError.""" + mock_graphiti = AsyncMock() + # First call raises connection error, second succeeds + mock_result = MagicMock() + mock_result.episodes = [] + mock_result.score = 0.9 + mock_result.content = "test content" + mock_result.name = "test" + mock_result.source_description = None + mock_result.fact = None + + mock_graphiti.search = AsyncMock( + side_effect=[ + RuntimeError("unable to perform operation on "), + [mock_result], + ] + ) + retriever._graphiti = mock_graphiti + + with patch.object(retriever, "_get_graphiti", new_callable=AsyncMock, return_value=mock_graphiti): + # Make _get_graphiti return the mock on retry too + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + results = await retriever.search_chunks("test query") + + assert len(results) == 1 + assert retriever.client.reset_and_reconnect.call_count == 1 + assert mock_graphiti.search.call_count == 2 + + @pytest.mark.asyncio + async def test_no_retry_on_regular_error(self, retriever): + """search_chunks does NOT retry on non-connection errors.""" + mock_graphiti = AsyncMock() + mock_graphiti.search = AsyncMock(side_effect=ValueError("bad query")) + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + results = await retriever.search_chunks("test query") + + assert results == [] + assert retriever.client.reset_and_reconnect.call_count == 0 + assert mock_graphiti.search.call_count == 1 + + @pytest.mark.asyncio + async def test_returns_empty_after_exhausted_retries(self, retriever): + """search_chunks returns empty list when all retries fail.""" + mock_graphiti = AsyncMock() + mock_graphiti.search = AsyncMock( + side_effect=RuntimeError("unable to perform operation on ") + ) + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + results = await retriever.search_chunks("test query") + + assert results == [] + assert retriever.client.reset_and_reconnect.call_count == 1 + # 2 attempts: initial + 1 retry + assert mock_graphiti.search.call_count == 2 + + @pytest.mark.asyncio + async def test_succeeds_without_retry(self, retriever): + """search_chunks succeeds on first attempt without retry.""" + mock_graphiti = AsyncMock() + mock_result = MagicMock() + mock_result.episodes = [] + mock_result.score = 0.9 + mock_result.content = "test content" + mock_result.name = "test" + mock_result.source_description = None + mock_result.fact = None + + mock_graphiti.search = AsyncMock(return_value=[mock_result]) + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + results = await retriever.search_chunks("test query") + + assert len(results) == 1 + assert retriever.client.reset_and_reconnect.call_count == 0 + assert mock_graphiti.search.call_count == 1 + + +class TestLookupEpisodesRetry: + """Tests for retry behavior in GraphitiRetriever._lookup_episodes().""" + + @pytest.fixture + def retriever(self): + """Create a GraphitiRetriever with mocked client.""" + with patch("knowledge_base.graph.graphiti_retriever.get_graphiti_client") as mock_get: + mock_client = MagicMock() + mock_client.get_client = AsyncMock() + mock_client.reset_and_reconnect = AsyncMock() + mock_get.return_value = mock_client + retriever = GraphitiRetriever() + yield retriever + + @pytest.mark.asyncio + async def test_retries_on_connection_error(self, retriever): + """_lookup_episodes retries on connection error.""" + mock_driver = MagicMock() + mock_record = {"uuid": "abc-123", "name": "test", "content": "hello", "source_desc": None} + mock_driver.execute_query = AsyncMock( + side_effect=[ + RuntimeError("unable to perform operation on "), + ([mock_record], None, None), + ] + ) + + mock_graphiti = MagicMock() + mock_graphiti.driver = mock_driver + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + result = await retriever._lookup_episodes(["abc-123"]) + + assert "abc-123" in result + assert retriever.client.reset_and_reconnect.call_count == 1 + + @pytest.mark.asyncio + async def test_no_retry_on_regular_error(self, retriever): + """_lookup_episodes does NOT retry on non-connection errors.""" + mock_driver = MagicMock() + mock_driver.execute_query = AsyncMock(side_effect=ValueError("bad query")) + + mock_graphiti = MagicMock() + mock_graphiti.driver = mock_driver + retriever._graphiti = mock_graphiti + retriever._get_graphiti = AsyncMock(return_value=mock_graphiti) + + result = await retriever._lookup_episodes(["abc-123"]) + + assert result == {} + assert retriever.client.reset_and_reconnect.call_count == 0 + + @pytest.mark.asyncio + async def test_empty_uuids_returns_empty(self, retriever): + """_lookup_episodes returns empty for empty UUID list.""" + result = await retriever._lookup_episodes([]) + assert result == {}