diff --git a/.env.example b/.env.example index 6570deb..363dcf4 100644 --- a/.env.example +++ b/.env.example @@ -12,6 +12,7 @@ OPENROUTER_API_KEY=your-api-key-here # Get from https://openrouter.ai/keys LLM_MODEL=anthropic/claude-sonnet-4 # Primary model LLM_FALLBACK_MODEL=anthropic/claude-haiku # Fallback for resilience LLM_TIMEOUT_SECONDS=30.0 # Request timeout +LLM_ENABLE_PROMPT_CACHING=true # Enable Anthropic prompt caching for cost reduction # Rate limiting RATE_LIMIT_REQUESTS=10 # Requests allowed per window diff --git a/src/config.py b/src/config.py index b047175..c7025ab 100644 --- a/src/config.py +++ b/src/config.py @@ -38,6 +38,9 @@ class Settings(BaseSettings): circuit_breaker_fail_max: int = 5 circuit_breaker_timeout: float = 60.0 + # Prompt caching (Anthropic models via OpenRouter) + llm_enable_prompt_caching: bool = True + # Embeddings (via OpenRouter) embedding_api_key: SecretStr | None = None embedding_base_url: str = "https://openrouter.ai/api/v1" diff --git a/src/infrastructure/llm/openrouter.py b/src/infrastructure/llm/openrouter.py index c0e9d2a..6700215 100644 --- a/src/infrastructure/llm/openrouter.py +++ b/src/infrastructure/llm/openrouter.py @@ -47,6 +47,7 @@ def __init__( timeout_seconds: float = 30.0, circuit_breaker_fail_max: int = 5, circuit_breaker_timeout: float = 60.0, + enable_prompt_caching: bool = True, ) -> None: """Initialize the OpenRouter provider. @@ -56,6 +57,7 @@ def __init__( timeout_seconds: Request timeout in seconds. circuit_breaker_fail_max: Open circuit after this many failures. circuit_breaker_timeout: Time in seconds before attempting recovery. + enable_prompt_caching: Enable Anthropic prompt caching via OpenRouter. Raises: LLMConfigurationError: If API key is missing. @@ -72,6 +74,7 @@ def __init__( ) self._default_model = default_model self._timeout = timeout_seconds + self._enable_prompt_caching = enable_prompt_caching # Circuit breaker: fail fast after repeated failures self._breaker = CircuitBreaker( @@ -165,6 +168,60 @@ async def _complete_with_resilience( self._do_complete, system_prompt, user_message, model ) + def _build_system_message(self, system_prompt: str) -> dict[str, object]: + """Build system message with optional cache control. + + The cache_control format follows Anthropic's prompt caching API, + which OpenRouter passes through to Anthropic models. For non-Anthropic + models, the cache_control field is typically ignored by the provider. + + Args: + system_prompt: The system prompt text. + + Returns: + Message dict with cache_control if caching is enabled. + """ + if self._enable_prompt_caching: + return { + "role": "system", + "content": [ + { + "type": "text", + "text": system_prompt, + "cache_control": {"type": "ephemeral"}, + } + ], + } + return {"role": "system", "content": system_prompt} + + def _log_cache_metrics(self, response: object, model: str) -> None: + """Log cache performance metrics from response. + + Args: + response: The API response object. + model: The model used for the request. 
+ """ + usage = getattr(response, "usage", None) + if usage is None: + return + + # Safely extract cache metrics with fallback to 0 + try: + cache_read = int(getattr(usage, "cache_read_input_tokens", 0) or 0) + cache_creation = int(getattr(usage, "cache_creation_input_tokens", 0) or 0) + except (TypeError, ValueError): + # Handle cases where attributes aren't numeric + return + + if cache_read > 0 or cache_creation > 0: + logger.info( + "llm_cache_metrics", + provider=self.PROVIDER_NAME, + model=model, + cache_read_tokens=cache_read, + cache_creation_tokens=cache_creation, + ) + async def _do_complete( self, system_prompt: str, @@ -184,16 +241,20 @@ async def _do_complete( ) try: + messages: list[dict[str, object]] = [ + self._build_system_message(system_prompt), + {"role": "user", "content": user_message}, + ] + response = await self._client.chat.completions.create( model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_message}, - ], + messages=messages, # type: ignore[arg-type] ) content = response.choices[0].message.content or "" + self._log_cache_metrics(response, model) + logger.debug( "llm_request_success", provider=self.PROVIDER_NAME, @@ -333,11 +394,11 @@ async def _do_complete_with_history( ) try: - # Build messages array with system prompt first - all_messages: list[dict[str, str]] = [ - {"role": "system", "content": system_prompt} + # Build messages array with system prompt first (with cache control) + all_messages: list[dict[str, object]] = [ + self._build_system_message(system_prompt) ] - all_messages.extend(messages) + all_messages.extend(messages) # type: ignore[arg-type] response = await self._client.chat.completions.create( model=model, @@ -346,6 +407,8 @@ async def _do_complete_with_history( content = response.choices[0].message.content or "" + self._log_cache_metrics(response, model) + logger.debug( "llm_history_request_success", provider=self.PROVIDER_NAME, diff --git a/src/web/routes.py b/src/web/routes.py index 5d8d34d..d878937 100644 --- a/src/web/routes.py +++ b/src/web/routes.py @@ -51,6 +51,7 @@ def get_llm_provider( timeout_seconds=settings.llm_timeout_seconds, circuit_breaker_fail_max=settings.circuit_breaker_fail_max, circuit_breaker_timeout=settings.circuit_breaker_timeout, + enable_prompt_caching=settings.llm_enable_prompt_caching, ) diff --git a/tests/test_llm_provider.py b/tests/test_llm_provider.py index bbc05e4..1092ea3 100644 --- a/tests/test_llm_provider.py +++ b/tests/test_llm_provider.py @@ -51,6 +51,21 @@ def test_init_with_custom_circuit_breaker_settings(self): assert provider._breaker._fail_max == 3 + def test_init_with_prompt_caching_enabled_by_default(self): + """Provider should enable prompt caching by default.""" + provider = OpenRouterProvider(api_key="test-key") + + assert provider._enable_prompt_caching is True + + def test_init_with_prompt_caching_disabled(self): + """Provider should accept disabled prompt caching.""" + provider = OpenRouterProvider( + api_key="test-key", + enable_prompt_caching=False, + ) + + assert provider._enable_prompt_caching is False + class TestOpenRouterProviderComplete: """Tests for OpenRouterProvider.complete() method.""" @@ -67,6 +82,11 @@ def mock_response(self): response = MagicMock() response.choices = [MagicMock()] response.choices[0].message.content = "Test response" + # Mock usage with no cache activity + usage = MagicMock() + type(usage).cache_read_input_tokens = 0 + type(usage).cache_creation_input_tokens = 0 + response.usage = usage return response 
async def test_complete_returns_content(self, provider, mock_response): @@ -105,6 +125,70 @@ async def test_complete_with_custom_model(self, provider, mock_response): call_kwargs = provider._client.chat.completions.create.call_args.kwargs assert call_kwargs["model"] == "anthropic/claude-haiku" + async def test_complete_includes_cache_control_when_enabled( + self, provider, mock_response + ): + """Complete should include cache_control in system message when enabled.""" + provider._client.chat.completions.create = AsyncMock(return_value=mock_response) + + await provider.complete( + system_prompt="You are helpful.", + user_message="Hello", + ) + + call_kwargs = provider._client.chat.completions.create.call_args.kwargs + messages = call_kwargs["messages"] + system_msg = messages[0] + + # Verify cache control structure + assert system_msg["role"] == "system" + assert isinstance(system_msg["content"], list) + assert len(system_msg["content"]) == 1 + assert system_msg["content"][0]["type"] == "text" + assert system_msg["content"][0]["text"] == "You are helpful." + assert system_msg["content"][0]["cache_control"] == {"type": "ephemeral"} + + async def test_complete_excludes_cache_control_when_disabled(self, mock_response): + """Complete should use plain string content when caching disabled.""" + provider = OpenRouterProvider(api_key="test-key", enable_prompt_caching=False) + provider._client.chat.completions.create = AsyncMock(return_value=mock_response) + + await provider.complete( + system_prompt="You are helpful.", + user_message="Hello", + ) + + call_kwargs = provider._client.chat.completions.create.call_args.kwargs + messages = call_kwargs["messages"] + system_msg = messages[0] + + # Verify plain string format + assert system_msg["role"] == "system" + assert system_msg["content"] == "You are helpful." 
+ + async def test_complete_with_history_includes_cache_control( + self, provider, mock_response + ): + """Complete with history should include cache_control in system message.""" + provider._client.chat.completions.create = AsyncMock(return_value=mock_response) + + await provider.complete_with_history( + system_prompt="You are helpful.", + messages=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + ) + + call_kwargs = provider._client.chat.completions.create.call_args.kwargs + messages = call_kwargs["messages"] + system_msg = messages[0] + + # Verify cache control structure + assert system_msg["role"] == "system" + assert isinstance(system_msg["content"], list) + assert system_msg["content"][0]["cache_control"] == {"type": "ephemeral"} + async def test_complete_timeout_raises_llm_timeout_error(self, provider): """Complete should raise LLMTimeoutError on timeout.""" provider._client.chat.completions.create = AsyncMock( @@ -157,6 +241,11 @@ async def test_complete_handles_empty_response_content(self, provider): response = MagicMock() response.choices = [MagicMock()] response.choices[0].message.content = None + # Mock usage with no cache activity + usage = MagicMock() + type(usage).cache_read_input_tokens = 0 + type(usage).cache_creation_input_tokens = 0 + response.usage = usage provider._client.chat.completions.create = AsyncMock(return_value=response) @@ -178,6 +267,11 @@ async def test_retries_on_connection_error(self): mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "Success after retry" + # Mock usage with no cache activity + usage = MagicMock() + type(usage).cache_read_input_tokens = 0 + type(usage).cache_creation_input_tokens = 0 + mock_response.usage = usage # Fail once, then succeed provider._client.chat.completions.create = AsyncMock( @@ -238,3 +332,138 @@ async def test_circuit_breaker_opens_after_failures(self): assert "temporarily unavailable" in str(exc_info.value) # Also verify it's not LLMRateLimitError (more specific) assert not isinstance(exc_info.value, LLMRateLimitError) + + +class TestOpenRouterProviderCacheMetrics: + """Tests for cache metrics logging.""" + + @pytest.fixture + def provider(self): + """Create a provider with caching enabled.""" + return OpenRouterProvider(api_key="test-key", enable_prompt_caching=True) + + @pytest.fixture + def mock_response_with_cache_hit(self): + """Create a mock response with cache hit metrics.""" + response = MagicMock() + response.choices = [MagicMock()] + response.choices[0].message.content = "Cached response" + # Use PropertyMock to return proper integers + usage = MagicMock() + type(usage).cache_read_input_tokens = 500 + type(usage).cache_creation_input_tokens = 0 + response.usage = usage + return response + + @pytest.fixture + def mock_response_with_cache_write(self): + """Create a mock response with cache write metrics.""" + response = MagicMock() + response.choices = [MagicMock()] + response.choices[0].message.content = "New response" + usage = MagicMock() + type(usage).cache_read_input_tokens = 0 + type(usage).cache_creation_input_tokens = 1200 + response.usage = usage + return response + + @pytest.fixture + def mock_response_without_cache(self): + """Create a mock response without cache metrics.""" + response = MagicMock() + response.choices = [MagicMock()] + response.choices[0].message.content = "Regular response" + usage = MagicMock() + type(usage).cache_read_input_tokens = 0 + type(usage).cache_creation_input_tokens 
= 0 + response.usage = usage + return response + + async def test_logs_cache_hit_metrics( + self, provider, mock_response_with_cache_hit, capsys + ): + """Provider should log cache hit metrics.""" + provider._client.chat.completions.create = AsyncMock( + return_value=mock_response_with_cache_hit + ) + + await provider.complete( + system_prompt="You are helpful.", + user_message="Hello", + ) + + captured = capsys.readouterr() + assert "llm_cache_metrics" in captured.out + assert "cache_read_tokens" in captured.out + + async def test_logs_cache_write_metrics( + self, provider, mock_response_with_cache_write, capsys + ): + """Provider should log cache write metrics.""" + provider._client.chat.completions.create = AsyncMock( + return_value=mock_response_with_cache_write + ) + + await provider.complete( + system_prompt="You are helpful.", + user_message="Hello", + ) + + captured = capsys.readouterr() + assert "llm_cache_metrics" in captured.out + assert "cache_creation_tokens" in captured.out + + async def test_no_log_when_no_cache_activity( + self, provider, mock_response_without_cache, capsys + ): + """Provider should not log when no cache activity.""" + provider._client.chat.completions.create = AsyncMock( + return_value=mock_response_without_cache + ) + + await provider.complete( + system_prompt="You are helpful.", + user_message="Hello", + ) + + captured = capsys.readouterr() + assert "llm_cache_metrics" not in captured.out + + async def test_handles_missing_usage_gracefully(self, provider): + """Provider should handle responses without usage object.""" + response = MagicMock() + response.choices = [MagicMock()] + response.choices[0].message.content = "Response without usage" + response.usage = None # No usage data + + provider._client.chat.completions.create = AsyncMock(return_value=response) + + result = await provider.complete( + system_prompt="You are helpful.", + user_message="Hello", + ) + + # Should complete without error + assert result == "Response without usage" + + async def test_handles_non_numeric_cache_tokens(self, provider, capsys): + """Provider should handle non-numeric cache token values gracefully.""" + response = MagicMock() + response.choices = [MagicMock()] + response.choices[0].message.content = "Response with invalid usage" + # Create usage with non-convertible values (TypeError on int()) + response.usage = MagicMock() + response.usage.cache_read_input_tokens = "not_a_number" + response.usage.cache_creation_input_tokens = None + + provider._client.chat.completions.create = AsyncMock(return_value=response) + + result = await provider.complete( + system_prompt="You are helpful.", + user_message="Hello", + ) + + # Should complete without error and not log cache metrics + assert result == "Response with invalid usage" + captured = capsys.readouterr() + assert "llm_cache_metrics" not in captured.out
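
For context, a minimal standalone sketch (not part of the patch) of the two system-message shapes that _build_system_message produces. It mirrors the logic added in src/infrastructure/llm/openrouter.py above; the helper name and example prompt are illustrative only. The cache_control field follows Anthropic's prompt-caching format, which OpenRouter passes through to Anthropic models; other providers typically ignore it.

# Illustrative sketch mirroring _build_system_message from the patch above;
# not importable project code, just the message shapes sent to OpenRouter.

def build_system_message(system_prompt: str, enable_prompt_caching: bool) -> dict[str, object]:
    """Return the system message in the shape the provider sends to OpenRouter."""
    if enable_prompt_caching:
        # Anthropic prompt-caching format: content becomes a list of text blocks,
        # with cache_control marking the block as cacheable ("ephemeral").
        return {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": system_prompt,
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        }
    # Plain string content when caching is disabled.
    return {"role": "system", "content": system_prompt}


if __name__ == "__main__":
    print(build_system_message("You are helpful.", enable_prompt_caching=True))
    print(build_system_message("You are helpful.", enable_prompt_caching=False))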