diff --git a/src/knowledge_base/config.py b/src/knowledge_base/config.py index de6616c..499fc00 100644 --- a/src/knowledge_base/config.py +++ b/src/knowledge_base/config.py @@ -94,6 +94,10 @@ class Settings(BaseSettings): NEO4J_URI: str = "bolt://localhost:7687" NEO4J_USER: str = "neo4j" NEO4J_PASSWORD: str = "" + # Neo4j connection pool resilience (seconds) + NEO4J_LIVENESS_CHECK_TIMEOUT: int = 30 # Check connection health before use + NEO4J_MAX_CONNECTION_LIFETIME: int = 1800 # Recycle connections every 30 min + NEO4J_CONNECTION_ACQUISITION_TIMEOUT: int = 60 # Wait up to 60s for a connection # Feature flags for Graphiti-only architecture GRAPH_ENABLE_GRAPHITI: bool = True # Master switch for Graphiti (now required) GRAPH_EXPANSION_ENABLED: bool = True # Always enabled with Graphiti-only diff --git a/src/knowledge_base/graph/graphiti_client.py b/src/knowledge_base/graph/graphiti_client.py index 63eed71..88d667f 100644 --- a/src/knowledge_base/graph/graphiti_client.py +++ b/src/knowledge_base/graph/graphiti_client.py @@ -7,6 +7,7 @@ NetworkX-based graph during the gradual rollout phase. """ +import asyncio import logging import os from typing import TYPE_CHECKING @@ -38,6 +39,7 @@ class GraphitiClient: _instance: "Graphiti | None" = None _initialized: bool = False + _init_lock: asyncio.Lock | None = None def __init__( self, @@ -68,6 +70,9 @@ def __init__( async def get_client(self) -> "Graphiti": """Get or create the Graphiti client instance. + Uses an asyncio.Lock to prevent concurrent initialization from + multiple async tasks creating duplicate Graphiti instances. + Returns: Configured Graphiti instance @@ -78,22 +83,31 @@ async def get_client(self) -> "Graphiti": if GraphitiClient._instance is not None and GraphitiClient._initialized: return GraphitiClient._instance - try: - if self.backend == "kuzu": - client = await self._create_kuzu_client() - elif self.backend == "neo4j": - client = await self._create_neo4j_client() - else: - raise GraphitiClientError(f"Unsupported graph backend: {self.backend}") + # Lazy-init the lock (safe: first call always happens in a single task) + if GraphitiClient._init_lock is None: + GraphitiClient._init_lock = asyncio.Lock() - GraphitiClient._instance = client - GraphitiClient._initialized = True - logger.info(f"Graphiti client initialized with {self.backend} backend") - return client + async with GraphitiClient._init_lock: + # Double-check after acquiring lock + if GraphitiClient._instance is not None and GraphitiClient._initialized: + return GraphitiClient._instance - except Exception as e: - logger.error(f"Failed to initialize Graphiti client: {e}") - raise GraphitiConnectionError(f"Could not connect to {self.backend}: {e}") from e + try: + if self.backend == "kuzu": + client = await self._create_kuzu_client() + elif self.backend == "neo4j": + client = await self._create_neo4j_client() + else: + raise GraphitiClientError(f"Unsupported graph backend: {self.backend}") + + GraphitiClient._instance = client + GraphitiClient._initialized = True + logger.info(f"Graphiti client initialized with {self.backend} backend") + return client + + except Exception as e: + logger.error(f"Failed to initialize Graphiti client: {e}") + raise GraphitiConnectionError(f"Could not connect to {self.backend}: {e}") from e async def _create_kuzu_client(self) -> "Graphiti": """Create Graphiti client with Kuzu embedded backend.""" @@ -136,10 +150,34 @@ async def _create_kuzu_client(self) -> "Graphiti": async def _create_neo4j_client(self) -> "Graphiti": """Create Graphiti client with Neo4j backend.""" from graphiti_core import Graphiti + from graphiti_core.driver.neo4j_driver import Neo4jDriver + from neo4j import AsyncGraphDatabase if not self.neo4j_password: raise GraphitiClientError("NEO4J_PASSWORD is required for Neo4j backend") + # Create Neo4j driver with connection pool resilience settings + # Default driver has liveness_check_timeout=None (disabled), which causes + # stale connections to fail with "TCPTransport closed" after Neo4j restarts + # or long idle periods (e.g. during multi-hour pipeline intake runs) + neo4j_async_driver = AsyncGraphDatabase.driver( + self.neo4j_uri, + auth=(self.neo4j_user, self.neo4j_password), + liveness_check_timeout=settings.NEO4J_LIVENESS_CHECK_TIMEOUT, + max_connection_lifetime=settings.NEO4J_MAX_CONNECTION_LIFETIME, + connection_acquisition_timeout=settings.NEO4J_CONNECTION_ACQUISITION_TIMEOUT, + ) + + # Create Neo4jDriver wrapper and inject our configured async driver + neo4j_driver = Neo4jDriver( + uri=self.neo4j_uri, + user=self.neo4j_user, + password=self.neo4j_password, + ) + # Replace the default driver with our configured one + await neo4j_driver.client.close() + neo4j_driver.client = neo4j_async_driver + # Create LLM client for entity extraction llm_client = self._get_llm_client() @@ -149,11 +187,9 @@ async def _create_neo4j_client(self) -> "Graphiti": # Create cross encoder for reranking cross_encoder = self._get_cross_encoder() - # Create Graphiti instance with Neo4j (uses default Neo4j driver) + # Create Graphiti instance with our pre-configured Neo4j driver graphiti = Graphiti( - uri=self.neo4j_uri, - user=self.neo4j_user, - password=self.neo4j_password, + graph_driver=neo4j_driver, llm_client=llm_client, embedder=embedder, cross_encoder=cross_encoder,