8 changes: 4 additions & 4 deletions src/infrastructure/cache/chroma_cache.py
@@ -99,7 +99,7 @@ async def get(self, question: str) -> CacheEntry | None:
# Check if we got a result
ids_result = results.get("ids")
if not ids_result or not ids_result[0]:
logger.debug("cache_miss_no_results", question_length=len(question))
logger.info("cache_miss_no_results", question_length=len(question))
return None

# Get the distance and convert to similarity
@@ -112,7 +112,7 @@ async def get(self, question: str) -> CacheEntry | None:

# Check threshold
if similarity < self._similarity_threshold:
-            logger.debug(
+            logger.info(
"cache_miss_below_threshold",
similarity=similarity,
threshold=self._similarity_threshold,
@@ -136,7 +136,7 @@ async def get(self, question: str) -> CacheEntry | None:
created_at = created_at.replace(tzinfo=UTC)
age_hours = (datetime.now(UTC) - created_at).total_seconds() / 3600
if age_hours > self._ttl_hours:
-            logger.debug(
+            logger.info(
"cache_miss_expired",
age_hours=age_hours,
ttl_hours=self._ttl_hours,
@@ -205,7 +205,7 @@ async def set(
],
)

-        logger.debug(
+        logger.info(
"cache_set",
question_length=len(question),
answer_length=len(answer),
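The only change in this file is the log level: cache miss, expiry, and store events move from debug to info so they are visible at the default log level. The call style (event name first, then key-value context) looks like a structlog-style bound logger; that is an assumption, since the diff does not show how logger is built. A minimal sketch of the pattern, with report_cache_event as a purely illustrative helper:

import structlog

logger = structlog.get_logger(__name__)


def report_cache_event(event: str, question: str, **context: object) -> None:
    # At INFO these events show up in default output, so cache behaviour is
    # observable without enabling debug logging.
    logger.info(event, question_length=len(question), **context)


report_cache_event(
    "cache_miss_below_threshold",
    "What pets are good for volunteers?",
    similarity=0.91,
    threshold=0.95,
)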
31 changes: 25 additions & 6 deletions src/modules/rag/service.py
@@ -305,12 +305,31 @@ async def ask(
)

# Store in cache (only high/medium confidence answers with context)
-        if self._cache is not None and chunks_used and not confidence.needs_review:
-            await self._cache.set(
-                question=question,
-                answer=answer,
-                chunks_json=_serialize_chunks(chunks_used),
-            )
+        if self._cache is not None:
+            if not chunks_used:
+                logger.info(
+                    "rag_cache_skip_no_chunks",
+                    question_length=len(question),
+                )
+            elif confidence.needs_review:
+                logger.info(
+                    "rag_cache_skip_low_confidence",
+                    question_length=len(question),
+                    confidence_level=confidence.level.value,
+                    confidence_score=confidence.score,
+                )
+            else:
+                await self._cache.set(
+                    question=question,
+                    answer=answer,
+                    chunks_json=_serialize_chunks(chunks_used),
+                )
+                logger.info(
+                    "rag_cache_stored",
+                    question_length=len(question),
+                    answer_length=len(answer),
+                    confidence_level=confidence.level.value,
+                )

return RAGResponse(
answer=answer,
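The rewritten block replaces a single combined condition with three explicit branches, so every skipped cache write is logged with its reason. A standalone sketch of that decision follows; CacheOutcome and decide_cache_outcome are hypothetical names used only for illustration, while the real code inlines this logic in RAGService.ask():

from enum import Enum


class CacheOutcome(Enum):
    SKIP_NO_CHUNKS = "rag_cache_skip_no_chunks"
    SKIP_LOW_CONFIDENCE = "rag_cache_skip_low_confidence"
    STORE = "rag_cache_stored"


def decide_cache_outcome(has_chunks: bool, needs_review: bool) -> CacheOutcome:
    if not has_chunks:
        return CacheOutcome.SKIP_NO_CHUNKS  # fallback answer, nothing grounded to cache
    if needs_review:
        return CacheOutcome.SKIP_LOW_CONFIDENCE  # low-confidence answers are not reused
    return CacheOutcome.STORE  # high/medium confidence with retrieved context


assert decide_cache_outcome(has_chunks=True, needs_review=False) is CacheOutcome.STORE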
53 changes: 33 additions & 20 deletions src/web/routes.py
@@ -30,7 +30,7 @@

# Singletons (per process)
_vector_store_cache: dict[str, ChromaVectorStore] = {}
-_semantic_cache_instance: dict[str, ChromaSemanticCache] = {}
+_semantic_cache_instance: dict[str, ChromaSemanticCache | None] = {}
_hybrid_retriever_instance: dict[str, HybridRetriever] = {}


@@ -79,33 +79,46 @@ def get_semantic_cache(
"""Get or create the semantic cache singleton.

Returns None if caching is disabled or embeddings aren't configured.
+    The result (including None) is cached to avoid repeated logging.
"""
+    persist_path = settings.chroma_persist_path
+
+    # Check singleton cache first (includes None entries for disabled cache)
+    if persist_path in _semantic_cache_instance:
+        return _semantic_cache_instance[persist_path]
+
+    # First time: determine cache state and log once
if not settings.cache_enabled:
logger.info("semantic_cache_disabled", reason="cache_enabled=False")
_semantic_cache_instance[persist_path] = None
return None

if settings.embedding_api_key is None:
-        return None
-
-    persist_path = settings.chroma_persist_path
-
-    if persist_path not in _semantic_cache_instance:
-        embedding_provider = OpenAIEmbeddingProvider(
-            api_key=settings.embedding_api_key.get_secret_value(),
-            model=settings.embedding_model,
-            base_url=settings.embedding_base_url,
-            timeout_seconds=settings.embedding_timeout_seconds,
-            circuit_breaker_fail_max=settings.circuit_breaker_fail_max,
-            circuit_breaker_timeout=settings.circuit_breaker_timeout,
+        logger.info(
+            "semantic_cache_disabled",
+            reason="embedding_api_key not configured",
+        )
+        _semantic_cache_instance[persist_path] = None
+        return None

-        _semantic_cache_instance[persist_path] = ChromaSemanticCache(
-            embedding_provider=embedding_provider,
-            persist_path=Path(persist_path),
-            similarity_threshold=settings.cache_similarity_threshold,
-            ttl_hours=settings.cache_ttl_hours,
-        )
+    # Create and cache the semantic cache instance
+    embedding_provider = OpenAIEmbeddingProvider(
+        api_key=settings.embedding_api_key.get_secret_value(),
+        model=settings.embedding_model,
+        base_url=settings.embedding_base_url,
+        timeout_seconds=settings.embedding_timeout_seconds,
+        circuit_breaker_fail_max=settings.circuit_breaker_fail_max,
+        circuit_breaker_timeout=settings.circuit_breaker_timeout,
+    )

-    return _semantic_cache_instance[persist_path]
+    cache = ChromaSemanticCache(
+        embedding_provider=embedding_provider,
+        persist_path=Path(persist_path),
+        similarity_threshold=settings.cache_similarity_threshold,
+        ttl_hours=settings.cache_ttl_hours,
+    )
+    _semantic_cache_instance[persist_path] = cache
+    return cache


def get_hybrid_retriever(
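The restructuring memoizes the disabled outcome as well as real instances, so the semantic_cache_disabled log line fires at most once per process per persist path instead of on every request. A minimal sketch of that pattern, assuming nothing beyond what the diff shows; _instances, build_cache, and get_or_create are illustrative stand-ins:

from collections.abc import Callable

_instances: dict[str, object | None] = {}


def get_or_create(key: str, enabled: bool, build_cache: Callable[[], object]) -> object | None:
    if key in _instances:
        # Covers both a real instance and a cached None, so the disabled
        # branch (and its log line) runs at most once per key.
        return _instances[key]
    if not enabled:
        print(f"semantic_cache_disabled key={key}")  # stand-in for logger.info(...)
        _instances[key] = None
        return None
    instance = build_cache()
    _instances[key] = instance
    return instance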
125 changes: 125 additions & 0 deletions tests/test_semantic_cache.py
@@ -1,5 +1,6 @@
"""Tests for semantic cache."""

+import shutil
import tempfile
from datetime import UTC, datetime
from pathlib import Path
@@ -468,3 +469,127 @@ async def test_clear_cache_is_noop_without_cache(
"""Clear cache should be a no-op when cache is None."""
# Should not raise
await rag_service_no_cache.clear_cache()


class TestCacheEndToEndFlow:
"""End-to-end tests verifying the complete cache flow."""

@pytest.fixture
def mock_llm(self) -> MagicMock:
"""Create a mock LLM provider."""
llm = MagicMock()
llm.complete = AsyncMock(return_value="Dogs make great pets for volunteers.")
return llm

@pytest.fixture
def mock_embeddings(self) -> MagicMock:
"""Create a mock embedding provider that returns consistent embeddings."""
embeddings = MagicMock()
# Return same embedding for similar questions to simulate cache hit
embeddings.embed = AsyncMock(return_value=[0.1] * 1536)
embeddings.embed_batch = AsyncMock(return_value=[[0.1] * 1536])
embeddings.dimensions = 1536
return embeddings

@pytest.fixture
def real_cache(self, mock_embeddings: MagicMock) -> ChromaSemanticCache:
"""Create a real semantic cache for end-to-end testing."""
tmpdir = tempfile.mkdtemp()
cache = ChromaSemanticCache(
embedding_provider=mock_embeddings,
persist_path=Path(tmpdir),
similarity_threshold=0.95,
ttl_hours=24,
)
yield cache
# Cleanup after tests complete
shutil.rmtree(tmpdir, ignore_errors=True)

@pytest.fixture
def mock_vector_store(self) -> MagicMock:
"""Create a mock vector store with indexed documents."""
store = MagicMock()
store.count = MagicMock(return_value=5) # Simulate indexed documents
store.query = AsyncMock(
return_value=[
RetrievalResult(
id="chunk1",
content="Dogs are wonderful companions for shelter volunteers.",
metadata={"source": "volunteer-guide.md", "section": "Pets"},
score=0.92, # High enough for medium/high confidence
),
RetrievalResult(
id="chunk2",
content="Volunteers should always be gentle with animals.",
metadata={"source": "volunteer-guide.md", "section": "Guidelines"},
score=0.88,
),
]
)
store.clear = AsyncMock()
return store

async def test_full_cache_flow_miss_then_hit(
self,
mock_llm: MagicMock,
mock_embeddings: MagicMock,
mock_vector_store: MagicMock,
real_cache: ChromaSemanticCache,
) -> None:
"""Test complete cache flow: miss on first query, hit on second."""
# Create RAG service with real cache
rag_service = RAGService(
llm_provider=mock_llm,
embedding_provider=mock_embeddings,
vector_store=mock_vector_store,
semantic_cache=real_cache,
)

# Verify cache starts empty
assert real_cache.count() == 0

# First query - should be a cache miss, LLM called
question = "What pets are good for volunteers?"
response1 = await rag_service.ask(question)

assert response1.answer == "Dogs make great pets for volunteers."
mock_llm.complete.assert_called_once() # LLM was called

# Verify answer was cached
assert real_cache.count() == 1

# Reset LLM mock to verify it's not called on cache hit
mock_llm.complete.reset_mock()

# Second query with same question - should be a cache hit
response2 = await rag_service.ask(question)

assert response2.answer == "Dogs make great pets for volunteers."
mock_llm.complete.assert_not_called() # LLM was NOT called (cache hit)

# Cache count should still be 1 (no duplicate)
assert real_cache.count() == 1

async def test_cache_not_stored_when_no_chunks(
self,
mock_llm: MagicMock,
mock_embeddings: MagicMock,
real_cache: ChromaSemanticCache,
) -> None:
"""Test that cache is not populated when no documents are indexed."""
# Create vector store with no documents
empty_store = MagicMock()
empty_store.count = MagicMock(return_value=0)

rag_service = RAGService(
llm_provider=mock_llm,
embedding_provider=mock_embeddings,
vector_store=empty_store,
semantic_cache=real_cache,
)

# Query should work but not cache (no indexed documents)
await rag_service.ask("What pets are good?")

# Cache should remain empty (fallback responses not cached)
assert real_cache.count() == 0
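The hit in test_full_cache_flow_miss_then_hit is guaranteed by the mock: it returns the identical vector for every question, so the similarity is maximal and clears the 0.95 threshold. A small sketch of that reasoning, assuming a cosine-style similarity (the cache's actual metric is not shown in this diff):

import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b, strict=True))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


vec = [0.1] * 1536
assert cosine_similarity(vec, vec) > 0.95  # identical embeddings always clear the threshold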