|
4 | 4 | # This source code is licensed under the terms described in the LICENSE file in |
5 | 5 | # the root directory of this source tree. |
6 | 6 |
|
| 7 | +from collections.abc import AsyncIterator |
7 | 8 | from typing import Any |
8 | 9 |
|
| 10 | +import litellm |
9 | 11 | import requests |
10 | 12 |
|
11 | | -from llama_stack.apis.inference import ChatCompletionRequest |
| 13 | +from llama_stack.apis.inference.inference import ( |
| 14 | + OpenAIChatCompletion, |
| 15 | + OpenAIChatCompletionChunk, |
| 16 | + OpenAIChatCompletionRequestWithExtraBody, |
| 17 | + OpenAIChatCompletionUsage, |
| 18 | + OpenAICompletion, |
| 19 | + OpenAICompletionRequestWithExtraBody, |
| 20 | + OpenAIEmbeddingsRequestWithExtraBody, |
| 21 | + OpenAIEmbeddingsResponse, |
| 22 | +) |
12 | 23 | from llama_stack.apis.models import Model |
13 | 24 | from llama_stack.apis.models.models import ModelType |
| 25 | +from llama_stack.log import get_logger |
14 | 26 | from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig |
15 | 27 | from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin |
| 28 | +from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params |
| 29 | +from llama_stack.providers.utils.telemetry.tracing import get_current_span |
| 30 | + |
| 31 | +logger = get_logger(name=__name__, category="providers::remote::watsonx") |
16 | 32 |
|
17 | 33 |
|
18 | 34 | class WatsonXInferenceAdapter(LiteLLMOpenAIMixin): |
19 | 35 | _model_cache: dict[str, Model] = {} |
20 | 36 |
|
| 37 | + provider_data_api_key_field: str = "watsonx_api_key" |
| 38 | + |
21 | 39 | def __init__(self, config: WatsonXConfig): |
| 40 | + self.available_models = None |
| 41 | + self.config = config |
| 42 | + api_key = config.auth_credential.get_secret_value() if config.auth_credential else None |
22 | 43 | LiteLLMOpenAIMixin.__init__( |
23 | 44 | self, |
24 | 45 | litellm_provider_name="watsonx", |
25 | | - api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None, |
| 46 | + api_key_from_config=api_key, |
26 | 47 | provider_data_api_key_field="watsonx_api_key", |
| 48 | + openai_compat_api_base=self.get_base_url(), |
27 | 49 | ) |
28 | | - self.available_models = None |
29 | | - self.config = config |
30 | 50 |
|
31 | | - def get_base_url(self) -> str: |
32 | | - return self.config.url |
| 51 | + async def openai_chat_completion( |
| 52 | + self, |
| 53 | + params: OpenAIChatCompletionRequestWithExtraBody, |
| 54 | + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: |
| 55 | + """ |
| 56 | + Override the parent method to add a request timeout and to inject a usage object when it is missing. |
| 57 | + This works around a LiteLLM defect where the usage block is sometimes dropped from the response. |
| 58 | + """ |
| 59 | + |
| 60 | + # Add usage tracking for streaming when telemetry is active |
| 61 | + stream_options = params.stream_options |
| 62 | + if params.stream and get_current_span() is not None: |
| 63 | + if stream_options is None: |
| 64 | + stream_options = {"include_usage": True} |
| 65 | + elif "include_usage" not in stream_options: |
| 66 | + stream_options = {**stream_options, "include_usage": True} |
| 67 | + |
| 68 | + model_obj = await self.model_store.get_model(params.model) |
| 69 | + |
| 70 | + request_params = await prepare_openai_completion_params( |
| 71 | + model=self.get_litellm_model_name(model_obj.provider_resource_id), |
| 72 | + messages=params.messages, |
| 73 | + frequency_penalty=params.frequency_penalty, |
| 74 | + function_call=params.function_call, |
| 75 | + functions=params.functions, |
| 76 | + logit_bias=params.logit_bias, |
| 77 | + logprobs=params.logprobs, |
| 78 | + max_completion_tokens=params.max_completion_tokens, |
| 79 | + max_tokens=params.max_tokens, |
| 80 | + n=params.n, |
| 81 | + parallel_tool_calls=params.parallel_tool_calls, |
| 82 | + presence_penalty=params.presence_penalty, |
| 83 | + response_format=params.response_format, |
| 84 | + seed=params.seed, |
| 85 | + stop=params.stop, |
| 86 | + stream=params.stream, |
| 87 | + stream_options=stream_options, |
| 88 | + temperature=params.temperature, |
| 89 | + tool_choice=params.tool_choice, |
| 90 | + tools=params.tools, |
| 91 | + top_logprobs=params.top_logprobs, |
| 92 | + top_p=params.top_p, |
| 93 | + user=params.user, |
| 94 | + api_key=self.get_api_key(), |
| 95 | + api_base=self.api_base, |
| 96 | + # These come from the watsonx provider config (timeout plus the watsonx-specific project_id) |
| 97 | + timeout=self.config.timeout, |
| 98 | + project_id=self.config.project_id, |
| 99 | + ) |
| 100 | + |
| 101 | + result = await litellm.acompletion(**request_params) |
| 102 | + |
| 103 | + # If not streaming, check and inject usage if missing |
| 104 | + if not params.stream: |
| 105 | + # Use getattr to safely handle cases where usage attribute might not exist |
| 106 | + if getattr(result, "usage", None) is None: |
| 107 | + # Create usage object with zeros |
| 108 | + usage_obj = OpenAIChatCompletionUsage( |
| 109 | + prompt_tokens=0, |
| 110 | + completion_tokens=0, |
| 111 | + total_tokens=0, |
| 112 | + ) |
| 113 | + # Use model_copy to create a new response with the usage injected |
| 114 | + result = result.model_copy(update={"usage": usage_obj}) |
| 115 | + return result |
| 116 | + |
| 117 | + # For streaming, wrap the iterator to normalize chunks |
| 118 | + return self._normalize_stream(result) |
| 119 | + |
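For reference, a minimal sketch of how a caller might exercise the non-streaming path above. The `adapter` instance and the model id are illustrative placeholders, not part of this change:

```python
from llama_stack.apis.inference.inference import OpenAIChatCompletionRequestWithExtraBody


async def demo_chat(adapter) -> None:
    # "adapter" is assumed to be an initialized WatsonXInferenceAdapter whose model store
    # already has the model registered; "<registered-model-id>" is a placeholder.
    result = await adapter.openai_chat_completion(
        OpenAIChatCompletionRequestWithExtraBody(
            model="<registered-model-id>",
            messages=[{"role": "user", "content": "Hello"}],
            stream=False,
        )
    )
    # With the injection above, usage is always present, even when LiteLLM drops it.
    assert result.usage is not None
```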
| 120 | + def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk: |
| 121 | + """ |
| 122 | + Normalize a chunk so it carries all the attributes callers expect. |
| 123 | + This works around LiteLLM not always including attributes such as usage, refusal, and reasoning_content. |
| 124 | + """ |
| 125 | + # Ensure chunk has usage attribute with zeros if missing |
| 126 | + if not hasattr(chunk, "usage") or chunk.usage is None: |
| 127 | + usage_obj = OpenAIChatCompletionUsage( |
| 128 | + prompt_tokens=0, |
| 129 | + completion_tokens=0, |
| 130 | + total_tokens=0, |
| 131 | + ) |
| 132 | + chunk = chunk.model_copy(update={"usage": usage_obj}) |
| 133 | + |
| 134 | + # Ensure all delta objects in choices have expected attributes |
| 135 | + if hasattr(chunk, "choices") and chunk.choices: |
| 136 | + normalized_choices = [] |
| 137 | + for choice in chunk.choices: |
| 138 | + if hasattr(choice, "delta") and choice.delta: |
| 139 | + delta = choice.delta |
| 140 | + # Build update dict for missing attributes |
| 141 | + delta_updates = {} |
| 142 | + if not hasattr(delta, "refusal"): |
| 143 | + delta_updates["refusal"] = None |
| 144 | + if not hasattr(delta, "reasoning_content"): |
| 145 | + delta_updates["reasoning_content"] = None |
| 146 | + |
| 147 | + # If we need to update delta, create a new choice with updated delta |
| 148 | + if delta_updates: |
| 149 | + new_delta = delta.model_copy(update=delta_updates) |
| 150 | + new_choice = choice.model_copy(update={"delta": new_delta}) |
| 151 | + normalized_choices.append(new_choice) |
| 152 | + else: |
| 153 | + normalized_choices.append(choice) |
| 154 | + else: |
| 155 | + normalized_choices.append(choice) |
| 156 | + |
| 157 | + # If we modified any choices, create a new chunk with updated choices |
| 158 | + if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))): |
| 159 | + chunk = chunk.model_copy(update={"choices": normalized_choices}) |
33 | 160 |
|
34 | | - async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: |
35 | | - # Get base parameters from parent |
36 | | - params = await super()._get_params(request) |
| 161 | + return chunk |
37 | 162 |
|
38 | | - # Add watsonx.ai specific parameters |
39 | | - params["project_id"] = self.config.project_id |
40 | | - params["time_limit"] = self.config.timeout |
41 | | - return params |
| 163 | + async def _normalize_stream( |
| 164 | + self, stream: AsyncIterator[OpenAIChatCompletionChunk] |
| 165 | + ) -> AsyncIterator[OpenAIChatCompletionChunk]: |
| 166 | + """ |
| 167 | + Normalize all chunks in the stream to ensure they have expected attributes. |
| 168 | + This works around LiteLLM sometimes not including expected attributes. |
| 169 | + """ |
| 170 | + try: |
| 171 | + async for chunk in stream: |
| 172 | + # Normalize and yield each chunk immediately |
| 173 | + yield self._normalize_chunk(chunk) |
| 174 | + except Exception as e: |
| 175 | + logger.error(f"Error normalizing stream: {e}", exc_info=True) |
| 176 | + raise |
| 177 | + |
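A similar hedged sketch for the streaming path, showing what the normalization above guarantees to consumers (same placeholder assumptions as before):

```python
from llama_stack.apis.inference.inference import OpenAIChatCompletionRequestWithExtraBody


async def demo_stream(adapter) -> None:
    stream = await adapter.openai_chat_completion(
        OpenAIChatCompletionRequestWithExtraBody(
            model="<registered-model-id>",
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
        )
    )
    async for chunk in stream:
        # Every chunk has been normalized: usage is present (zeros when LiteLLM omitted it)
        # and each choice delta exposes refusal / reasoning_content attributes.
        assert chunk.usage is not None
```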
| 178 | + async def openai_completion( |
| 179 | + self, |
| 180 | + params: OpenAICompletionRequestWithExtraBody, |
| 181 | + ) -> OpenAICompletion: |
| 182 | + """ |
| 183 | + Override the parent method to add watsonx-specific parameters. |
| 184 | + """ |
| 186 | + |
| 187 | + model_obj = await self.model_store.get_model(params.model) |
| 188 | + |
| 189 | + request_params = await prepare_openai_completion_params( |
| 190 | + model=self.get_litellm_model_name(model_obj.provider_resource_id), |
| 191 | + prompt=params.prompt, |
| 192 | + best_of=params.best_of, |
| 193 | + echo=params.echo, |
| 194 | + frequency_penalty=params.frequency_penalty, |
| 195 | + logit_bias=params.logit_bias, |
| 196 | + logprobs=params.logprobs, |
| 197 | + max_tokens=params.max_tokens, |
| 198 | + n=params.n, |
| 199 | + presence_penalty=params.presence_penalty, |
| 200 | + seed=params.seed, |
| 201 | + stop=params.stop, |
| 202 | + stream=params.stream, |
| 203 | + stream_options=params.stream_options, |
| 204 | + temperature=params.temperature, |
| 205 | + top_p=params.top_p, |
| 206 | + user=params.user, |
| 207 | + suffix=params.suffix, |
| 208 | + api_key=self.get_api_key(), |
| 209 | + api_base=self.api_base, |
| 210 | + # These come from the watsonx provider config (timeout plus the watsonx-specific project_id) |
| 211 | + timeout=self.config.timeout, |
| 212 | + project_id=self.config.project_id, |
| 213 | + ) |
| 214 | + return await litellm.atext_completion(**request_params) |
| 215 | + |
| 216 | + async def openai_embeddings( |
| 217 | + self, |
| 218 | + params: OpenAIEmbeddingsRequestWithExtraBody, |
| 219 | + ) -> OpenAIEmbeddingsResponse: |
| 220 | + """ |
| 221 | + Override the parent method to add watsonx-specific parameters. |
| 222 | + """ |
| 223 | + model_obj = await self.model_store.get_model(params.model) |
| 224 | + |
| 225 | + # Convert input to list if it's a string |
| 226 | + input_list = [params.input] if isinstance(params.input, str) else params.input |
| 227 | + |
| 228 | + # Call litellm embedding function with watsonx-specific parameters |
| 229 | + response = litellm.embedding( |
| 230 | + model=self.get_litellm_model_name(model_obj.provider_resource_id), |
| 231 | + input=input_list, |
| 232 | + api_key=self.get_api_key(), |
| 233 | + api_base=self.api_base, |
| 234 | + dimensions=params.dimensions, |
| 235 | + # These come from the watsonx provider config (timeout plus the watsonx-specific project_id) |
| 236 | + timeout=self.config.timeout, |
| 237 | + project_id=self.config.project_id, |
| 238 | + ) |
| 239 | + |
| 240 | + # Convert response to OpenAI format |
| 241 | + from llama_stack.apis.inference import OpenAIEmbeddingUsage |
| 242 | + from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response |
| 243 | + |
| 244 | + data = b64_encode_openai_embeddings_response(response.data, params.encoding_format) |
| 245 | + |
| 246 | + usage = OpenAIEmbeddingUsage( |
| 247 | + prompt_tokens=response["usage"]["prompt_tokens"], |
| 248 | + total_tokens=response["usage"]["total_tokens"], |
| 249 | + ) |
| 250 | + |
| 251 | + return OpenAIEmbeddingsResponse( |
| 252 | + data=data, |
| 253 | + model=model_obj.provider_resource_id, |
| 254 | + usage=usage, |
| 255 | + ) |
| 256 | + |
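And a sketch of the embeddings path; the embedding model id and input text are illustrative placeholders:

```python
from llama_stack.apis.inference.inference import OpenAIEmbeddingsRequestWithExtraBody


async def demo_embeddings(adapter) -> None:
    response = await adapter.openai_embeddings(
        OpenAIEmbeddingsRequestWithExtraBody(
            model="<embedding-model-id>",
            input="The quick brown fox",
        )
    )
    # The response follows the OpenAI embeddings shape: data entries plus token usage.
    print(len(response.data), response.usage.total_tokens)
```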
| 257 | + def get_base_url(self) -> str: |
| 258 | + return self.config.url |
42 | 259 |
|
43 | 260 | # Copied from OpenAIMixin |
44 | 261 | async def check_model_availability(self, model: str) -> bool: |