Commit 2a26ffa

fix(litellm): map LiteLLM context-window errors to ContextWindowOverflowException (#994)
1 parent 776fd93 commit 2a26ffa
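
The change applies a standard exception-translation pattern at an integration boundary: catch the third-party error where it is raised and re-raise a library-owned type, chained with "from e" so the original cause stays visible in the traceback. A self-contained sketch of the pattern with stand-in classes (all names below are illustrative stand-ins; only the real code uses LiteLLM's and strands' actual exception types):

class ContextWindowOverflowException(Exception):
    """Stand-in for strands.types.exceptions.ContextWindowOverflowException."""

class ProviderContextError(Exception):
    """Stand-in for litellm.exceptions.ContextWindowExceededError."""

def call_provider():
    raise ProviderContextError("input exceeds the model's context window")

def call_through_boundary():
    try:
        call_provider()
    except ProviderContextError as e:
        # "from e" chains the exceptions: the provider error remains
        # reachable as __cause__ while callers catch only the owned type.
        raise ContextWindowOverflowException(e) from e

try:
    call_through_boundary()
except ContextWindowOverflowException as exc:
    print("caught:", exc, "| cause:", repr(exc.__cause__))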

2 files changed: +35 -8 lines changed

src/strands/models/litellm.py

Lines changed: 23 additions & 8 deletions
@@ -8,11 +8,13 @@
 from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast
 
 import litellm
+from litellm.exceptions import ContextWindowExceededError
 from litellm.utils import supports_response_schema
 from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
 from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ContextWindowOverflowException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
 from ._validation import validate_config_keys
@@ -135,7 +137,11 @@ async def stream(
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
-        response = await litellm.acompletion(**self.client_args, **request)
+        try:
+            response = await litellm.acompletion(**self.client_args, **request)
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow")
+            raise ContextWindowOverflowException(e) from e
 
         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
@@ -205,15 +211,24 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        if not supports_response_schema(self.get_config()["model_id"]):
+        supports_schema = supports_response_schema(self.get_config()["model_id"])
+
+        # If the provider does not support response schemas, we cannot reliably parse structured output.
+        # In that case we must not call the provider and must raise the documented ValueError.
+        if not supports_schema:
             raise ValueError("Model does not support response_format")
 
-        response = await litellm.acompletion(
-            **self.client_args,
-            model=self.get_config()["model_id"],
-            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-            response_format=output_model,
-        )
+        # For providers that DO support response schemas, call litellm and map context-window errors.
+        try:
+            response = await litellm.acompletion(
+                **self.client_args,
+                model=self.get_config()["model_id"],
+                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+                response_format=output_model,
+            )
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow in structured_output")
+            raise ContextWindowOverflowException(e) from e
 
         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")
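
For downstream code, the effect of the mapping above is that callers no longer need to import LiteLLM's provider-specific exception; they can handle the library-owned type instead. A minimal caller-side sketch, assuming a LiteLLM-routable model id (the id and the recovery step below are illustrative, not part of this commit):

import asyncio

from strands.models.litellm import LiteLLMModel
from strands.types.exceptions import ContextWindowOverflowException

async def main() -> None:
    # Arbitrary LiteLLM-routable model id, for illustration only.
    model = LiteLLMModel(model_id="openai/gpt-4o")
    messages = [{"role": "user", "content": [{"text": "a very long prompt ..."}]}]

    try:
        async for event in model.stream(messages):
            print(event)
    except ContextWindowOverflowException:
        # One typed exception regardless of underlying provider; a caller
        # could trim older conversation turns here and retry.
        print("prompt exceeded the model's context window")

asyncio.run(main())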

tests/strands/models/test_litellm.py

Lines changed: 12 additions & 0 deletions
@@ -3,9 +3,11 @@
 
 import pydantic
 import pytest
+from litellm.exceptions import ContextWindowExceededError
 
 import strands
 from strands.models.litellm import LiteLLMModel
+from strands.types.exceptions import ContextWindowOverflowException
 
 
 @pytest.fixture
@@ -332,3 +334,13 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
     model.format_request(messages, tool_choice=None)
 
     assert len(captured_warnings) == 0
+
+
+@pytest.mark.asyncio
+async def test_context_window_maps_to_typed_exception(litellm_acompletion, model):
+    """Test that a typed ContextWindowExceededError is mapped correctly."""
+    litellm_acompletion.side_effect = ContextWindowExceededError(message="test error", model="x", llm_provider="y")
+
+    with pytest.raises(ContextWindowOverflowException):
+        async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]):
+            pass
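
The new test relies on the litellm_acompletion and model fixtures defined earlier in this file, which the diff does not show. A plausible minimal sketch of those fixtures, assuming litellm.acompletion is patched with an AsyncMock so side_effect can be set; the repo's actual fixture bodies may differ:

import unittest.mock

import pytest

from strands.models.litellm import LiteLLMModel


@pytest.fixture
def litellm_acompletion():
    # Replace litellm.acompletion with an AsyncMock so tests can set
    # return_value or side_effect without reaching a real provider.
    with unittest.mock.patch(
        "strands.models.litellm.litellm.acompletion",
        new_callable=unittest.mock.AsyncMock,
    ) as mock_acompletion:
        yield mock_acompletion


@pytest.fixture
def model():
    # Arbitrary model id; the patched client never contacts a provider.
    return LiteLLMModel(model_id="m1")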
