Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/knowledge_base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ class Settings(BaseSettings):
NEO4J_URI: str = "bolt://localhost:7687"
NEO4J_USER: str = "neo4j"
NEO4J_PASSWORD: str = ""
# Neo4j connection pool resilience (seconds)
NEO4J_LIVENESS_CHECK_TIMEOUT: int = 30 # Check connection health before use
NEO4J_MAX_CONNECTION_LIFETIME: int = 1800 # Recycle connections every 30 min
NEO4J_CONNECTION_ACQUISITION_TIMEOUT: int = 60 # Wait up to 60s for a connection
# Feature flags for Graphiti-only architecture
GRAPH_ENABLE_GRAPHITI: bool = True # Master switch for Graphiti (now required)
GRAPH_EXPANSION_ENABLED: bool = True # Always enabled with Graphiti-only
Expand Down
72 changes: 54 additions & 18 deletions src/knowledge_base/graph/graphiti_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
NetworkX-based graph during the gradual rollout phase.
"""

import asyncio
import logging
import os
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -38,6 +39,7 @@ class GraphitiClient:

_instance: "Graphiti | None" = None
_initialized: bool = False
_init_lock: asyncio.Lock | None = None

def __init__(
self,
Expand Down Expand Up @@ -68,6 +70,9 @@ def __init__(
async def get_client(self) -> "Graphiti":
"""Get or create the Graphiti client instance.

Uses an asyncio.Lock to serialize initialization, so that concurrent calls
from multiple async tasks cannot create duplicate Graphiti instances.

Returns:
Configured Graphiti instance

Expand All @@ -78,22 +83,31 @@ async def get_client(self) -> "Graphiti":
if GraphitiClient._instance is not None and GraphitiClient._initialized:
return GraphitiClient._instance

try:
if self.backend == "kuzu":
client = await self._create_kuzu_client()
elif self.backend == "neo4j":
client = await self._create_neo4j_client()
else:
raise GraphitiClientError(f"Unsupported graph backend: {self.backend}")
# Lazy-init the lock (safe: there is no await between the None check and the
# assignment, so no other task on the same event loop can interleave here)
if GraphitiClient._init_lock is None:
GraphitiClient._init_lock = asyncio.Lock()

GraphitiClient._instance = client
GraphitiClient._initialized = True
logger.info(f"Graphiti client initialized with {self.backend} backend")
return client
async with GraphitiClient._init_lock:
# Double-check after acquiring lock
if GraphitiClient._instance is not None and GraphitiClient._initialized:
return GraphitiClient._instance

except Exception as e:
logger.error(f"Failed to initialize Graphiti client: {e}")
raise GraphitiConnectionError(f"Could not connect to {self.backend}: {e}") from e
try:
if self.backend == "kuzu":
client = await self._create_kuzu_client()
elif self.backend == "neo4j":
client = await self._create_neo4j_client()
else:
raise GraphitiClientError(f"Unsupported graph backend: {self.backend}")

GraphitiClient._instance = client
GraphitiClient._initialized = True
logger.info(f"Graphiti client initialized with {self.backend} backend")
return client

except Exception as e:
logger.error(f"Failed to initialize Graphiti client: {e}")
raise GraphitiConnectionError(f"Could not connect to {self.backend}: {e}") from e

async def _create_kuzu_client(self) -> "Graphiti":
"""Create Graphiti client with Kuzu embedded backend."""
Expand Down Expand Up @@ -136,10 +150,34 @@ async def _create_kuzu_client(self) -> "Graphiti":
async def _create_neo4j_client(self) -> "Graphiti":
"""Create Graphiti client with Neo4j backend."""
from graphiti_core import Graphiti
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from neo4j import AsyncGraphDatabase

if not self.neo4j_password:
raise GraphitiClientError("NEO4J_PASSWORD is required for Neo4j backend")

# Create Neo4j driver with connection pool resilience settings
# Default driver has liveness_check_timeout=None (disabled), which causes
# stale connections to fail with "TCPTransport closed" after Neo4j restarts
# or long idle periods (e.g. during multi-hour pipeline intake runs)
neo4j_async_driver = AsyncGraphDatabase.driver(
self.neo4j_uri,
auth=(self.neo4j_user, self.neo4j_password),
liveness_check_timeout=settings.NEO4J_LIVENESS_CHECK_TIMEOUT,
max_connection_lifetime=settings.NEO4J_MAX_CONNECTION_LIFETIME,
connection_acquisition_timeout=settings.NEO4J_CONNECTION_ACQUISITION_TIMEOUT,
)

# Create Neo4jDriver wrapper and inject our configured async driver
neo4j_driver = Neo4jDriver(
uri=self.neo4j_uri,
user=self.neo4j_user,
password=self.neo4j_password,
)
# Close the default driver held by the wrapper, then replace it with our configured one
await neo4j_driver.client.close()
neo4j_driver.client = neo4j_async_driver

# Create LLM client for entity extraction
llm_client = self._get_llm_client()

Expand All @@ -149,11 +187,9 @@ async def _create_neo4j_client(self) -> "Graphiti":
# Create cross encoder for reranking
cross_encoder = self._get_cross_encoder()

# Create Graphiti instance with Neo4j (uses default Neo4j driver)
# Create Graphiti instance with our pre-configured Neo4j driver
graphiti = Graphiti(
uri=self.neo4j_uri,
user=self.neo4j_user,
password=self.neo4j_password,
graph_driver=neo4j_driver,
llm_client=llm_client,
embedder=embedder,
cross_encoder=cross_encoder,
Expand Down