 from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast

 import litellm
+from litellm.exceptions import ContextWindowExceededError
 from litellm.utils import supports_response_schema
 from pydantic import BaseModel
 from typing_extensions import Unpack, override

 from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ContextWindowOverflowException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
 from ._validation import validate_config_keys
@@ -135,7 +137,11 @@ async def stream(
         logger.debug("request=<%s>", request)

         logger.debug("invoking model")
-        response = await litellm.acompletion(**self.client_args, **request)
+        try:
+            response = await litellm.acompletion(**self.client_args, **request)
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow")
+            raise ContextWindowOverflowException(e) from e

         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
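The new try/except is a thin translation layer, and the same pattern can be exercised outside the class. A minimal, self-contained sketch is below; `ContextWindowOverflowException` is defined inline as a stand-in for the SDK exception imported from `..types.exceptions` in the diff, and `acompletion_with_mapping` is a hypothetical helper name, not part of this PR.

```python
# Sketch of the error-mapping pattern from the stream() hunk above.
# Only litellm's public API is assumed; the exception class is a stand-in.
import litellm
from litellm.exceptions import ContextWindowExceededError


class ContextWindowOverflowException(Exception):
    """Stand-in for the SDK's exception so this sketch is self-contained."""


async def acompletion_with_mapping(**request):
    try:
        return await litellm.acompletion(**request)
    except ContextWindowExceededError as e:
        # Translate the provider-specific error into a provider-agnostic one so
        # callers can recover (e.g. trim conversation history) without knowing
        # anything about litellm's exception hierarchy.
        raise ContextWindowOverflowException(e) from e
```

Chaining with `from e` keeps the original litellm traceback attached, so the provider detail stays visible in logs while callers only need to catch a single exception type.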
@@ -205,15 +211,24 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        if not supports_response_schema(self.get_config()["model_id"]):
+        supports_schema = supports_response_schema(self.get_config()["model_id"])
+
+        # If the provider does not support response schemas, we cannot reliably parse structured output.
+        # In that case we must not call the provider and must raise the documented ValueError.
+        if not supports_schema:
             raise ValueError("Model does not support response_format")

-        response = await litellm.acompletion(
-            **self.client_args,
-            model=self.get_config()["model_id"],
-            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-            response_format=output_model,
-        )
+        # For providers that DO support response schemas, call litellm and map context-window errors.
+        try:
+            response = await litellm.acompletion(
+                **self.client_args,
+                model=self.get_config()["model_id"],
+                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+                response_format=output_model,
+            )
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow in structured_output")
+            raise ContextWindowOverflowException(e) from e

         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")
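The guard at the top of `structured_output` leans on litellm's model-capability metadata rather than on trial and error. A standalone sketch of that check is below, using only `litellm.utils.supports_response_schema`; the model IDs are illustrative examples, not ones taken from this repository.

```python
# Probe litellm's capability metadata the same way the structured_output()
# guard does. Model IDs are illustrative; any model litellm knows about works.
from litellm.utils import supports_response_schema

for model_id in ("gpt-4o", "gpt-3.5-turbo", "claude-3-5-sonnet-20240620"):
    print(model_id, "->", supports_response_schema(model_id))
```

Models for which this returns False now fail fast with the documented `ValueError` before any network call, while models that pass the check get the same context-window translation as `stream`.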