Skip to content

Commit

Permalink
fix(azure/completions): migrate completions endpoint to support base …
Browse files Browse the repository at this point in the history
…azure llm class

enables consistent auth logic across all azure calls
  • Loading branch information
krrishdholakia committed Mar 12, 2025
1 parent 42af49c commit 23bf7b5
Showing 1 changed file with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions tests/litellm/llms/azure/test_azure_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,87 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
for call in azure_calls:
assert "api_key" in call.kwargs, "api_key not found in parameters"
assert "api_base" in call.kwargs, "api_base not found in parameters"


@pytest.mark.parametrize(
"call_type",
[
CallTypes.atext_completion,
CallTypes.acompletion,
],
)
@pytest.mark.asyncio
async def test_ensure_initialize_azure_sdk_client_always_used_azure_text(call_type):
from litellm.router import Router

# Create a router with an Azure model
azure_model_name = "azure_text/chatgpt-v-2"
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": azure_model_name,
"api_key": "test-api-key",
"api_version": os.getenv("AZURE_API_VERSION", "2023-05-15"),
"api_base": os.getenv(
"AZURE_API_BASE", "https://test.openai.azure.com"
),
},
}
],
)

# Prepare test input based on call type
test_inputs = {
"acompletion": {
"messages": [{"role": "user", "content": "Hello, how are you?"}]
},
"atext_completion": {"prompt": "Hello, how are you?"},
}

# Get appropriate input for this call type
input_kwarg = test_inputs.get(call_type.value, {})

patch_target = "litellm.main.azure_text_completions.initialize_azure_sdk_client"

# Mock the initialize_azure_sdk_client function
with patch(patch_target) as mock_init_azure:
# Also mock async_function_with_fallbacks to prevent actual API calls
# Call the appropriate router method
try:
get_attr = getattr(router, call_type.value, None)
if get_attr is None:
pytest.skip(
f"Skipping {call_type.value} because it is not supported on Router"
)
await getattr(router, call_type.value)(
model="gpt-3.5-turbo",
**input_kwarg,
num_retries=0,
azure_ad_token="oidc/test-token",
)
except Exception as e:
traceback.print_exc()

# Verify initialize_azure_sdk_client was called
mock_init_azure.assert_called_once()

# Verify it was called with the right model name
calls = mock_init_azure.call_args_list
azure_calls = [call for call in calls]

litellm_params = azure_calls[0].kwargs["litellm_params"]
print("litellm_params", litellm_params)

assert (
"azure_ad_token" in litellm_params
), "azure_ad_token not found in parameters"
assert (
litellm_params["azure_ad_token"] == "oidc/test-token"
), "azure_ad_token is not correct"

# More detailed verification (optional)
for call in azure_calls:
assert "api_key" in call.kwargs, "api_key not found in parameters"
assert "api_base" in call.kwargs, "api_base not found in parameters"

0 comments on commit 23bf7b5

Please sign in to comment.