
Commit 0dbf79c

are-ces and leseb authored
fix: Fixed WatsonX remote inference provider (#3801)
# What does this PR do?

This PR fixes issues with the WatsonX provider so it works correctly with LiteLLM. The main problem was that WatsonX requests failed because the provider data validator didn't properly handle the API key and project ID. This was fixed by updating the WatsonXProviderDataValidator and ensuring the provider data is loaded correctly. The openai_chat_completion method was also updated to match the behavior of other providers while adding WatsonX-specific fields such as project_id. It still calls `await super().openai_chat_completion.__func__(self, params)` to keep the existing setup and tracing logic. After these changes, WatsonX requests run correctly.

## Test Plan

The changes were tested by running chat completion requests and confirming that credentials and project parameters are passed correctly. I tested with my WatsonX credentials, using the CLI with `uv run llama-stack-client inference chat-completion --session`.

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
Co-authored-by: Sébastien Han <seb@redhat.com>
1 parent 1136daf commit 0dbf79c
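As background on the `await super().openai_chat_completion.__func__(self, params)` call mentioned in the PR summary: accessing the parent method through `super()` returns a bound method, and calling its underlying `__func__` with `self` passed explicitly behaves like a normal `super()` call while keeping whatever setup/tracing wrapping the parent definition carries. A minimal, self-contained sketch of the pattern — the class and field names below are illustrative, not the actual provider code:

```python
import asyncio


class BaseAdapter:
    async def openai_chat_completion(self, params: dict) -> dict:
        # Stand-in for a parent implementation wrapped with setup/tracing logic.
        return {"model": params["model"], "messages": params["messages"]}


class WatsonXLikeAdapter(BaseAdapter):
    def __init__(self, project_id: str) -> None:
        self.project_id = project_id

    async def openai_chat_completion(self, params: dict) -> dict:
        # Add provider-specific fields, then delegate to the parent implementation.
        params = {**params, "project_id": self.project_id}
        # Equivalent to `await super().openai_chat_completion(params)`, but invokes
        # the underlying function object explicitly, as described in the PR summary.
        return await super().openai_chat_completion.__func__(self, params)


async def main() -> None:
    adapter = WatsonXLikeAdapter(project_id="my-project")
    print(await adapter.openai_chat_completion({"model": "m", "messages": ["hi"]}))


asyncio.run(main())
```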

File tree

5 files changed: +254, -26 lines changed


llama_stack/providers/registry/inference.py

Lines changed: 1 addition & 1 deletion
@@ -271,7 +271,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.watsonx",
             config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
-            provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
+            provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
             description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
         ),
         RemoteProviderSpec(

llama_stack/providers/remote/inference/watsonx/config.py

Lines changed: 5 additions & 5 deletions
@@ -7,18 +7,18 @@
 import os
 from typing import Any

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


 class WatsonXProviderDataValidator(BaseModel):
-    model_config = ConfigDict(
-        from_attributes=True,
-        extra="forbid",
+    watsonx_project_id: str | None = Field(
+        default=None,
+        description="IBM WatsonX project ID",
     )
-    watsonx_api_key: str | None
+    watsonx_api_key: str | None = None


 @json_schema_type

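The updated validator above replaces the `ConfigDict`-based definition with plain optional fields, so per-request provider data can carry the API key and project ID without tripping `extra="forbid"`. A rough sketch of how such a pydantic model validates incoming provider data — standalone pydantic only, with made-up values, not the actual Llama Stack request plumbing:

```python
from pydantic import BaseModel, Field


class WatsonXProviderDataValidator(BaseModel):
    # Mirrors the fields added in this commit; both are optional so a request
    # may rely on values from the static provider config instead.
    watsonx_project_id: str | None = Field(
        default=None,
        description="IBM WatsonX project ID",
    )
    watsonx_api_key: str | None = None


# Provider data typically arrives as a JSON object per request.
data = WatsonXProviderDataValidator.model_validate(
    {"watsonx_api_key": "fake-key", "watsonx_project_id": "fake-project"}
)
print(data.watsonx_project_id)  # fake-project

# Missing fields simply default to None instead of failing validation.
print(WatsonXProviderDataValidator.model_validate({}).watsonx_api_key)  # None
```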
llama_stack/providers/remote/inference/watsonx/watsonx.py

Lines changed: 230 additions & 13 deletions
@@ -4,41 +4,258 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
 from typing import Any

+import litellm
 import requests

-from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
 from llama_stack.apis.models import Model
 from llama_stack.apis.models.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
+
+logger = get_logger(name=__name__, category="providers::remote::watsonx")


 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     _model_cache: dict[str, Model] = {}

+    provider_data_api_key_field: str = "watsonx_api_key"
+
     def __init__(self, config: WatsonXConfig):
+        self.available_models = None
+        self.config = config
+        api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
         LiteLLMOpenAIMixin.__init__(
             self,
             litellm_provider_name="watsonx",
-            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
+            api_key_from_config=api_key,
             provider_data_api_key_field="watsonx_api_key",
+            openai_compat_api_base=self.get_base_url(),
         )
-        self.available_models = None
-        self.config = config

-    def get_base_url(self) -> str:
-        return self.config.url
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Override parent method to add timeout and inject usage object when missing.
+        This works around a LiteLLM defect where usage block is sometimes dropped.
+        """
+
+        # Add usage tracking for streaming when telemetry is active
+        stream_options = params.stream_options
+        if params.stream and get_current_span() is not None:
+            if stream_options is None:
+                stream_options = {"include_usage": True}
+            elif "include_usage" not in stream_options:
+                stream_options = {**stream_options, "include_usage": True}
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            messages=params.messages,
+            frequency_penalty=params.frequency_penalty,
+            function_call=params.function_call,
+            functions=params.functions,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_completion_tokens=params.max_completion_tokens,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            parallel_tool_calls=params.parallel_tool_calls,
+            presence_penalty=params.presence_penalty,
+            response_format=params.response_format,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=stream_options,
+            temperature=params.temperature,
+            tool_choice=params.tool_choice,
+            tools=params.tools,
+            top_logprobs=params.top_logprobs,
+            top_p=params.top_p,
+            user=params.user,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        result = await litellm.acompletion(**request_params)
+
+        # If not streaming, check and inject usage if missing
+        if not params.stream:
+            # Use getattr to safely handle cases where usage attribute might not exist
+            if getattr(result, "usage", None) is None:
+                # Create usage object with zeros
+                usage_obj = OpenAIChatCompletionUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                )
+                # Use model_copy to create a new response with the usage injected
+                result = result.model_copy(update={"usage": usage_obj})
+            return result
+
+        # For streaming, wrap the iterator to normalize chunks
+        return self._normalize_stream(result)
+
+    def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
+        """
+        Normalize a chunk to ensure it has all expected attributes.
+        This works around LiteLLM not always including all expected attributes.
+        """
+        # Ensure chunk has usage attribute with zeros if missing
+        if not hasattr(chunk, "usage") or chunk.usage is None:
+            usage_obj = OpenAIChatCompletionUsage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+            chunk = chunk.model_copy(update={"usage": usage_obj})
+
+        # Ensure all delta objects in choices have expected attributes
+        if hasattr(chunk, "choices") and chunk.choices:
+            normalized_choices = []
+            for choice in chunk.choices:
+                if hasattr(choice, "delta") and choice.delta:
+                    delta = choice.delta
+                    # Build update dict for missing attributes
+                    delta_updates = {}
+                    if not hasattr(delta, "refusal"):
+                        delta_updates["refusal"] = None
+                    if not hasattr(delta, "reasoning_content"):
+                        delta_updates["reasoning_content"] = None
+
+                    # If we need to update delta, create a new choice with updated delta
+                    if delta_updates:
+                        new_delta = delta.model_copy(update=delta_updates)
+                        new_choice = choice.model_copy(update={"delta": new_delta})
+                        normalized_choices.append(new_choice)
+                    else:
+                        normalized_choices.append(choice)
+                else:
+                    normalized_choices.append(choice)
+
+            # If we modified any choices, create a new chunk with updated choices
+            if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
+                chunk = chunk.model_copy(update={"choices": normalized_choices})

-    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
-        # Get base parameters from parent
-        params = await super()._get_params(request)
+        return chunk

-        # Add watsonx.ai specific parameters
-        params["project_id"] = self.config.project_id
-        params["time_limit"] = self.config.timeout
-        return params
+    async def _normalize_stream(
+        self, stream: AsyncIterator[OpenAIChatCompletionChunk]
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Normalize all chunks in the stream to ensure they have expected attributes.
+        This works around LiteLLM sometimes not including expected attributes.
+        """
+        try:
+            async for chunk in stream:
+                # Normalize and yield each chunk immediately
+                yield self._normalize_chunk(chunk)
+        except Exception as e:
+            logger.error(f"Error normalizing stream: {e}", exc_info=True)
+            raise
+
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequestWithExtraBody,
+    ) -> OpenAICompletion:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            prompt=params.prompt,
+            best_of=params.best_of,
+            echo=params.echo,
+            frequency_penalty=params.frequency_penalty,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            presence_penalty=params.presence_penalty,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            top_p=params.top_p,
+            user=params.user,
+            suffix=params.suffix,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+        return await litellm.atext_completion(**request_params)
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        model_obj = await self.model_store.get_model(params.model)
+
+        # Convert input to list if it's a string
+        input_list = [params.input] if isinstance(params.input, str) else params.input
+
+        # Call litellm embedding function with watsonx-specific parameters
+        response = litellm.embedding(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            input=input_list,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            dimensions=params.dimensions,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        # Convert response to OpenAI format
+        from llama_stack.apis.inference import OpenAIEmbeddingUsage
+        from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response
+
+        data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response["usage"]["prompt_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
+        )
+
+    def get_base_url(self) -> str:
+        return self.config.url

     # Copied from OpenAIMixin
     async def check_model_availability(self, model: str) -> bool:

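The usage-injection and chunk-normalization logic above leans on pydantic's `model_copy(update=...)`, which returns a new model instance with selected fields replaced instead of mutating the original. A simplified sketch of the same idea, using stand-in chunk/usage models rather than the real LiteLLM/OpenAI response types:

```python
import asyncio
from collections.abc import AsyncIterator

from pydantic import BaseModel


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class Chunk(BaseModel):
    content: str
    usage: Usage | None = None


def normalize_chunk(chunk: Chunk) -> Chunk:
    # Inject a zeroed usage block when the upstream library omitted it,
    # so downstream consumers never see a missing attribute.
    if chunk.usage is None:
        zero = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        return chunk.model_copy(update={"usage": zero})
    return chunk


async def normalize_stream(stream: AsyncIterator[Chunk]) -> AsyncIterator[Chunk]:
    # Yield each chunk immediately after normalization to keep streaming latency low.
    async for chunk in stream:
        yield normalize_chunk(chunk)


async def fake_stream() -> AsyncIterator[Chunk]:
    yield Chunk(content="hel")  # no usage block
    yield Chunk(content="lo", usage=Usage(prompt_tokens=3, completion_tokens=2, total_tokens=5))


async def main() -> None:
    async for chunk in normalize_stream(fake_stream()):
        print(chunk.content, chunk.usage.total_tokens)


asyncio.run(main())
```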
tests/integration/inference/test_openai_completion.py

Lines changed: 9 additions & 2 deletions
@@ -58,7 +58,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         # does not work with the specified model, gpt-5-mini. Please choose different model and try
         # again. You can learn more about which models can be used with each operation here:
         # https://go.microsoft.com/fwlink/?linkid=2197993.'}}"}
-        "remote::watsonx",  # return 404 when hitting the /openai/v1 endpoint
         "remote::llama-openai-compat",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
@@ -68,6 +67,7 @@ def skip_if_doesnt_support_completions_logprobs(client_with_models, model_id):
     provider_type = provider_from_model(client_with_models, model_id).provider_type
     if provider_type in (
         "remote::ollama",  # logprobs is ignored
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions logprobs.")

@@ -110,6 +110,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
         # Error code 400 - {'message': '"n" > 1 is not currently supported', 'type': 'invalid_request_error', 'param': 'n', 'code': 'wrong_api_format'}
         "remote::cerebras",
         "remote::databricks",  # Bad request: parameter "n" must be equal to 1 for streaming mode
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")

@@ -124,7 +125,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
         "remote::databricks",
         "remote::cerebras",
         "remote::runpod",
-        "remote::watsonx",  # watsonx returns 404 when hitting the /openai/v1 endpoint
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")

@@ -508,6 +508,12 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
     assert "hello world" in normalized_content


+def skip_if_doesnt_support_completions_stop_sequence(client_with_models, model_id):
+    provider_type = provider_from_model(client_with_models, model_id).provider_type
+    if provider_type in ("remote::watsonx",):  # openai.BadRequestError: Error code: 400
+        pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions stop sequence.")
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -516,6 +522,7 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
 )
 def test_openai_completion_stop_sequence(client_with_models, openai_client, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+    skip_if_doesnt_support_completions_stop_sequence(client_with_models, text_model_id)

     tc = TestCase(test_case)

tests/integration/inference/test_openai_embeddings.py

Lines changed: 9 additions & 5 deletions
@@ -50,11 +50,15 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):

 def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
-    if provider.provider_type in (
-        "remote::together",  # returns 400
-        "inline::sentence-transformers",
-        # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'}
-        "remote::databricks",
+    if (
+        provider.provider_type
+        in (
+            "remote::together",  # returns 400
+            "inline::sentence-transformers",
+            # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'}
+            "remote::databricks",
+            "remote::watsonx",  # openai.BadRequestError: Error code: 400 - {'detail': "litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384}
+        )
     ):
         pytest.skip(
             f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."
