Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/knowledge_base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class Settings(BaseSettings):
NEO4J_LIVENESS_CHECK_TIMEOUT: int = 30 # Check connection health before use
NEO4J_MAX_CONNECTION_LIFETIME: int = 1800 # Recycle connections every 30 min
NEO4J_CONNECTION_ACQUISITION_TIMEOUT: int = 60 # Wait up to 60s for a connection
# Neo4j search retry on connection error (RuntimeError: TCPTransport closed)
NEO4J_SEARCH_MAX_RETRIES: int = 1 # Retries on stale connection in search path
# Feature flags for Graphiti-only architecture
GRAPH_ENABLE_GRAPHITI: bool = True # Master switch for Graphiti (now required)
GRAPH_EXPANSION_ENABLED: bool = True # Always enabled with Graphiti-only
Expand Down
15 changes: 15 additions & 0 deletions src/knowledge_base/graph/graphiti_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,21 @@ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]

return SimpleEmbeddingCrossEncoder()

async def reset_and_reconnect(self) -> None:
    """Tear down the stale singleton so the next get_client() reconnects.

    Intended for use after a detected connection error: performs a
    best-effort close of the current connection, then unconditionally
    clears the singleton state so a fresh connection is established on
    the next get_client() call.
    """
    logger.warning("Resetting Graphiti client due to connection error")
    try:
        await self.close()
    except Exception as e:
        # Best-effort: a failed close must not prevent the reset below.
        logger.warning(f"Error during client close on reset: {e}")
    finally:
        # Drop singleton state regardless of the close outcome.
        GraphitiClient._instance = None
        GraphitiClient._initialized = False

async def close(self) -> None:
"""Close the Graphiti client connection."""
if GraphitiClient._instance is not None:
Expand Down
236 changes: 142 additions & 94 deletions src/knowledge_base/graph/graphiti_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,30 @@
logger = logging.getLogger(__name__)


def _is_connection_error(exc: Exception) -> bool:
"""Check if an exception indicates a Neo4j connection failure.

The Neo4j driver's liveness check catches OSError, ServiceUnavailable,
and SessionExpired — but RuntimeError from asyncio's transport layer
escapes these catch clauses. This helper detects all connection-related
errors so the search path can retry with a fresh connection.
"""
# asyncio transport closed (the specific error seen in production)
if isinstance(exc, RuntimeError) and "TCPTransport" in str(exc):
return True
# Standard Neo4j connection errors
try:
from neo4j.exceptions import ServiceUnavailable, SessionExpired
if isinstance(exc, (ServiceUnavailable, SessionExpired)):
return True
except ImportError:
pass
# OS-level connection errors (broken pipe, connection reset)
if isinstance(exc, OSError):
return True
return False


@dataclass
class SearchResult:
"""A single search result (compatible with VectorRetriever.SearchResult)."""
Expand Down Expand Up @@ -181,6 +205,8 @@ def _to_search_result(self, graphiti_result: Any, episode_data: dict | None = No
async def _lookup_episodes(self, episode_uuids: list[str]) -> dict[str, dict]:
    """Look up episode content and metadata from Neo4j by UUID.

    Includes retry-with-reset on Neo4j connection errors: on a stale
    connection the Graphiti client singleton is reset and the query is
    retried up to NEO4J_SEARCH_MAX_RETRIES additional times.

    Args:
        episode_uuids: List of episode UUIDs to look up

    Returns:
        Mapping of episode UUID -> {"name", "content", "metadata"}.
        Empty dict when no UUIDs are given or the lookup fails.
    """
    if not episode_uuids:
        return {}

    max_retries = settings.NEO4J_SEARCH_MAX_RETRIES
    for attempt in range(1 + max_retries):
        try:
            graphiti = await self._get_graphiti()
            driver = graphiti.driver

            records, _, _ = await driver.execute_query(
                """
                MATCH (ep:Episodic)
                WHERE ep.uuid IN $uuids
                RETURN ep.uuid as uuid, ep.name as name,
                       ep.content as content,
                       ep.source_description as source_desc
                """,
                uuids=episode_uuids,
            )

            result = {}
            for record in records:
                uuid = record["uuid"]
                source_desc = record.get("source_desc")
                # Episode metadata is serialized into source_description.
                metadata = self._parse_metadata(source_desc)
                result[uuid] = {
                    "name": record.get("name", ""),
                    "content": record.get("content", "") or "",
                    "metadata": metadata,
                }

            return result

        except Exception as e:
            # On a stale connection, reset the client singleton and the
            # cached handle so the next attempt reconnects from scratch.
            if attempt < max_retries and _is_connection_error(e):
                logger.warning(
                    f"Neo4j connection error in episode lookup (attempt {attempt + 1}), "
                    f"resetting client and retrying: {e}"
                )
                await self.client.reset_and_reconnect()
                self._graphiti = None
                continue
            # Non-retryable (or retries exhausted): degrade gracefully.
            logger.warning(f"Failed to look up episodes: {e}")
            return {}
    return {}

async def search_chunks(
self,
Expand Down Expand Up @@ -248,82 +285,93 @@ async def search_chunks(
logger.warning("Graphiti retrieval DISABLED — returning empty results")
return []

try:
graphiti = await self._get_graphiti()

# Over-fetch to account for filtering
fetch_count = num_results * 3 if (space_key or doc_type or min_quality_score) else num_results

results = await graphiti.search(
query=query,
num_results=fetch_count,
group_ids=[self.group_id],
)
max_retries = settings.NEO4J_SEARCH_MAX_RETRIES
for attempt in range(1 + max_retries):
try:
graphiti = await self._get_graphiti()

logger.info(f"Graphiti raw search returned {len(results)} results for: {query[:50]}...")
# Over-fetch to account for filtering
fetch_count = num_results * 3 if (space_key or doc_type or min_quality_score) else num_results

# Collect all episode UUIDs from edge results for batch lookup
all_episode_uuids = []
for result in results:
episodes = getattr(result, 'episodes', None) or []
all_episode_uuids.extend(episodes)

# Batch lookup episode content and metadata
episode_data = {}
if all_episode_uuids:
unique_uuids = list(set(all_episode_uuids))
episode_data = await self._lookup_episodes(unique_uuids)
logger.info(
f"Looked up {len(episode_data)}/{len(unique_uuids)} episodes "
f"for {len(results)} search results"
results = await graphiti.search(
query=query,
num_results=fetch_count,
group_ids=[self.group_id],
)

# Convert and filter results
search_results = []
seen_episodes = set() # Deduplicate by episode
for result in results:
# Get episode data for this result (use first episode)
episodes = getattr(result, 'episodes', None) or []
ep_data = None
for ep_uuid in episodes:
if ep_uuid in episode_data and ep_uuid not in seen_episodes:
ep_data = episode_data[ep_uuid]
seen_episodes.add(ep_uuid)
break

sr = self._to_search_result(result, episode_data=ep_data)
logger.info(f"Graphiti raw search returned {len(results)} results for: {query[:50]}...")

# Collect all episode UUIDs from edge results for batch lookup
all_episode_uuids = []
for result in results:
episodes = getattr(result, 'episodes', None) or []
all_episode_uuids.extend(episodes)

# Batch lookup episode content and metadata
episode_data = {}
if all_episode_uuids:
unique_uuids = list(set(all_episode_uuids))
episode_data = await self._lookup_episodes(unique_uuids)
logger.info(
f"Looked up {len(episode_data)}/{len(unique_uuids)} episodes "
f"for {len(results)} search results"
)

# Convert and filter results
search_results = []
seen_episodes = set() # Deduplicate by episode
for result in results:
# Get episode data for this result (use first episode)
episodes = getattr(result, 'episodes', None) or []
ep_data = None
for ep_uuid in episodes:
if ep_uuid in episode_data and ep_uuid not in seen_episodes:
ep_data = episode_data[ep_uuid]
seen_episodes.add(ep_uuid)
break

# Skip deleted chunks
if sr.metadata.get('deleted'):
continue
sr = self._to_search_result(result, episode_data=ep_data)

# Apply filters
if space_key and sr.metadata.get('space_key') != space_key:
continue
if doc_type and sr.metadata.get('doc_type') != doc_type:
continue
if min_quality_score and sr.quality_score < min_quality_score:
continue
# Skip deleted chunks
if sr.metadata.get('deleted'):
continue

search_results.append(sr)
# Apply filters
if space_key and sr.metadata.get('space_key') != space_key:
continue
if doc_type and sr.metadata.get('doc_type') != doc_type:
continue
if min_quality_score and sr.quality_score < min_quality_score:
continue

if len(search_results) >= num_results:
break
search_results.append(sr)

if len(results) > 0 and len(search_results) == 0:
logger.warning(
f"Graphiti returned {len(results)} raw results but ALL were filtered out "
f"(space_key={space_key}, doc_type={doc_type}, min_quality={min_quality_score})"
)
logger.info(f"Graphiti search_chunks returning {len(search_results)}/{len(results)} results for: {query[:50]}...")
return search_results
if len(search_results) >= num_results:
break

except GraphitiClientError as e:
logger.error(f"Graphiti search FAILED: {e}", exc_info=True)
return []
except Exception as e:
logger.error(f"Unexpected error in Graphiti search: {e}", exc_info=True)
return []
if len(results) > 0 and len(search_results) == 0:
logger.warning(
f"Graphiti returned {len(results)} raw results but ALL were filtered out "
f"(space_key={space_key}, doc_type={doc_type}, min_quality={min_quality_score})"
)
logger.info(f"Graphiti search_chunks returning {len(search_results)}/{len(results)} results for: {query[:50]}...")
return search_results

except Exception as e:
if attempt < max_retries and _is_connection_error(e):
logger.warning(
f"Neo4j connection error on search attempt {attempt + 1}, "
f"resetting client and retrying: {e}"
)
await self.client.reset_and_reconnect()
self._graphiti = None
continue
if isinstance(e, GraphitiClientError):
logger.error(f"Graphiti search FAILED: {e}", exc_info=True)
else:
logger.error(f"Unexpected error in Graphiti search: {e}", exc_info=True)
return []
return []

async def search_with_quality_boost(
self,
Expand Down
Loading