Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,10 @@ AUTH_ENABLED=true # Enable user authentication
JWT_SECRET_KEY=change-this-to-a-random-secret-key # Secret key for JWT signing (REQUIRED if auth enabled)
JWT_ALGORITHM=HS256 # JWT signing algorithm
JWT_EXPIRE_HOURS=24 # Token expiration in hours

# OpenTelemetry Observability
OTEL_ENABLED=true # Enable OpenTelemetry tracing
OTEL_EXPORTER=none # Exporter: "otlp", "console", or "none"
OTEL_ENDPOINT= # OTLP collector endpoint (e.g., "http://localhost:4318")
OTEL_SERVICE_NAME=retriever # Service name for traces
OTEL_SAMPLE_RATE=1.0 # Sampling rate (0.0 to 1.0)
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ dependencies = [
"email-validator>=2.3.0",
"markdown>=3.10",
"bleach>=6.3.0",
# Observability (OpenTelemetry)
"opentelemetry-api~=1.29.0",
"opentelemetry-sdk~=1.29.0",
"opentelemetry-exporter-otlp-proto-http~=1.29.0",
"opentelemetry-instrumentation-fastapi~=0.50b0",
"opentelemetry-instrumentation-httpx~=0.50b0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -120,6 +126,8 @@ module = [
"chromadb",
"chromadb.*",
"rank_bm25",
"opentelemetry",
"opentelemetry.*",
]
ignore_missing_imports = true

Expand Down
7 changes: 7 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ class Settings(BaseSettings):
# Conversation History
conversation_max_messages: int = 20 # Max messages to include in context

# OpenTelemetry / Observability
otel_enabled: bool = True # Enable OpenTelemetry tracing
otel_exporter: str = "none" # Exporter: "otlp", "console", or "none"
otel_endpoint: str | None = None # OTLP endpoint (e.g., "http://localhost:4318")
otel_service_name: str = "retriever" # Service name for traces
otel_sample_rate: float = 1.0 # Sampling rate (0.0 to 1.0)


@lru_cache
def get_settings() -> Settings:
Expand Down
240 changes: 131 additions & 109 deletions src/infrastructure/cache/chroma_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

from src.infrastructure.cache.protocol import CacheEntry
from src.infrastructure.embeddings import EmbeddingProvider
from src.infrastructure.observability import get_tracer

logger = structlog.get_logger()
tracer = get_tracer(__name__)


class ChromaSemanticCache:
Expand Down Expand Up @@ -83,94 +85,107 @@ async def get(self, question: str) -> CacheEntry | None:
CacheEntry if a similar question was found above the
similarity threshold, None otherwise.
"""
if self._collection.count() == 0:
return None

# Embed the question
query_embedding = await self._embeddings.embed(question)

# Query for similar cached questions
results = self._collection.query(
query_embeddings=[query_embedding], # type: ignore[arg-type]
n_results=1,
include=["documents", "metadatas", "distances"], # type: ignore[list-item]
)
with tracer.start_as_current_span("cache.get") as span:
span.set_attribute("cache.question_length", len(question))

if self._collection.count() == 0:
span.set_attribute("cache.hit", False)
span.set_attribute("cache.miss_reason", "empty_cache")
return None

# Embed the question
query_embedding = await self._embeddings.embed(question)

# Query for similar cached questions
results = self._collection.query(
query_embeddings=[query_embedding],
n_results=1,
include=["documents", "metadatas", "distances"],
)

# Check if we got a result
ids_result = results.get("ids")
if not ids_result or not ids_result[0]:
logger.debug("cache_miss_no_results", question_length=len(question))
return None

# Get the distance and convert to similarity
distances_result = results.get("distances")
if distances_result and distances_result[0]:
distance = float(distances_result[0][0])
else:
distance = 2.0 # Max cosine distance
similarity = 1.0 - distance

# Check threshold
if similarity < self._similarity_threshold:
logger.debug(
"cache_miss_below_threshold",
# Check if we got a result
ids_result = results.get("ids")
if not ids_result or not ids_result[0]:
span.set_attribute("cache.hit", False)
span.set_attribute("cache.miss_reason", "no_results")
logger.debug("cache_miss_no_results", question_length=len(question))
return None

# Get the distance and convert to similarity
distances_result = results.get("distances")
if distances_result and distances_result[0]:
distance = float(distances_result[0][0])
else:
distance = 2.0 # Max cosine distance
similarity = 1.0 - distance
span.set_attribute("cache.similarity", similarity)

# Check threshold
if similarity < self._similarity_threshold:
span.set_attribute("cache.hit", False)
span.set_attribute("cache.miss_reason", "below_threshold")
logger.debug(
"cache_miss_below_threshold",
similarity=similarity,
threshold=self._similarity_threshold,
question_length=len(question),
)
return None

# Check TTL
metadatas_result = results.get("metadatas")
if metadatas_result and metadatas_result[0]:
metadata = metadatas_result[0][0]
else:
metadata = {}
created_at_str = str(metadata.get("created_at", ""))

if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
# Handle timezone-naive datetime from storage
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=UTC)
age_hours = (datetime.now(UTC) - created_at).total_seconds() / 3600
if age_hours > self._ttl_hours:
span.set_attribute("cache.hit", False)
span.set_attribute("cache.miss_reason", "expired")
logger.debug(
"cache_miss_expired",
age_hours=age_hours,
ttl_hours=self._ttl_hours,
)
return None
except ValueError:
pass # Unparseable created_at: skip the TTL check instead of treating the entry as expired

# Cache hit!
span.set_attribute("cache.hit", True)
documents_result = results.get("documents")
if documents_result and documents_result[0]:
answer = str(documents_result[0][0])
else:
answer = ""

cached_question = str(metadata.get("question", ""))
chunks_json = str(metadata.get("chunks_json", "[]"))

logger.info(
"cache_hit",
similarity=similarity,
threshold=self._similarity_threshold,
question_length=len(question),
cached_question_length=len(cached_question),
query_question_length=len(question),
)
return None

# Check TTL
metadatas_result = results.get("metadatas")
if metadatas_result and metadatas_result[0]:
metadata = metadatas_result[0][0]
else:
metadata = {}
created_at_str = str(metadata.get("created_at", ""))

if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
# Handle timezone-naive datetime from storage
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=UTC)
age_hours = (datetime.now(UTC) - created_at).total_seconds() / 3600
if age_hours > self._ttl_hours:
logger.debug(
"cache_miss_expired",
age_hours=age_hours,
ttl_hours=self._ttl_hours,
)
return None
except ValueError:
pass # Unparseable created_at: skip the TTL check instead of treating the entry as expired

# Cache hit!
documents_result = results.get("documents")
if documents_result and documents_result[0]:
answer = str(documents_result[0][0])
else:
answer = ""

cached_question = str(metadata.get("question", ""))
chunks_json = str(metadata.get("chunks_json", "[]"))

logger.info(
"cache_hit",
similarity=similarity,
cached_question_length=len(cached_question),
query_question_length=len(question),
)

return CacheEntry(
question=cached_question,
answer=answer,
chunks_json=chunks_json,
created_at=datetime.fromisoformat(created_at_str)
if created_at_str
else datetime.now(UTC),
similarity_score=similarity,
)
return CacheEntry(
question=cached_question,
answer=answer,
chunks_json=chunks_json,
created_at=datetime.fromisoformat(created_at_str)
if created_at_str
else datetime.now(UTC),
similarity_score=similarity,
)

async def set(
self,
Expand All @@ -185,32 +200,39 @@ async def set(
answer: The generated answer.
chunks_json: JSON-serialized list of ChunkWithScore used.
"""
# Generate embedding for the question
question_embedding = await self._embeddings.embed(question)

# Generate unique ID
entry_id = f"cache:{uuid.uuid4().hex}"

# Store with metadata
self._collection.add(
ids=[entry_id],
documents=[answer],
embeddings=[question_embedding], # type: ignore[arg-type]
metadatas=[
{
"question": question,
"chunks_json": chunks_json,
"created_at": datetime.now(UTC).isoformat(),
}
],
)
with tracer.start_as_current_span("cache.set") as span:
span.set_attribute("cache.question_length", len(question))
span.set_attribute("cache.answer_length", len(answer))

# Generate embedding for the question
question_embedding = await self._embeddings.embed(question)

# Generate unique ID
entry_id = f"cache:{uuid.uuid4().hex}"

# Store with metadata
self._collection.add(
ids=[entry_id],
documents=[answer],
embeddings=[question_embedding],
metadatas=[
{
"question": question,
"chunks_json": chunks_json,
"created_at": datetime.now(UTC).isoformat(),
}
],
)

logger.debug(
"cache_set",
question_length=len(question),
answer_length=len(answer),
total_cached=self._collection.count(),
)
total_cached = self._collection.count()
span.set_attribute("cache.total_entries", total_cached)

logger.debug(
"cache_set",
question_length=len(question),
answer_length=len(answer),
total_cached=total_cached,
)

async def clear(self) -> None:
"""Clear all cached entries.
Expand Down
Loading
Loading