diff --git a/tests/litellm/llms/azure/test_azure_common_utils.py b/tests/litellm/llms/azure/test_azure_common_utils.py
index a6419c6245f5..7d8c0650f3e4 100644
--- a/tests/litellm/llms/azure/test_azure_common_utils.py
+++ b/tests/litellm/llms/azure/test_azure_common_utils.py
@@ -370,3 +370,87 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
     for call in azure_calls:
         assert "api_key" in call.kwargs, "api_key not found in parameters"
         assert "api_base" in call.kwargs, "api_base not found in parameters"
+
+
+@pytest.mark.parametrize(
+    "call_type",
+    [
+        CallTypes.atext_completion,
+        CallTypes.acompletion,
+    ],
+)
+@pytest.mark.asyncio
+async def test_ensure_initialize_azure_sdk_client_always_used_azure_text(call_type):
+    from litellm.router import Router
+
+    # Create a router with an Azure text-completion model
+    azure_model_name = "azure_text/chatgpt-v-2"
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": azure_model_name,
+                    "api_key": "test-api-key",
+                    "api_version": os.getenv("AZURE_API_VERSION", "2023-05-15"),
+                    "api_base": os.getenv(
+                        "AZURE_API_BASE", "https://test.openai.azure.com"
+                    ),
+                },
+            }
+        ],
+    )
+
+    # Prepare test input based on call type
+    test_inputs = {
+        "acompletion": {
+            "messages": [{"role": "user", "content": "Hello, how are you?"}]
+        },
+        "atext_completion": {"prompt": "Hello, how are you?"},
+    }
+
+    # Get the appropriate input for this call type
+    input_kwarg = test_inputs.get(call_type.value, {})
+
+    patch_target = "litellm.main.azure_text_completions.initialize_azure_sdk_client"
+
+    # Mock the initialize_azure_sdk_client function
+    with patch(patch_target) as mock_init_azure:
+        # No real API call should go out; initialize_azure_sdk_client is mocked above.
+        # Call the appropriate router method
+        try:
+            router_method = getattr(router, call_type.value, None)
+            if router_method is None:
+                pytest.skip(
+                    f"Skipping {call_type.value} because it is not supported on Router"
+                )
+            await router_method(
+                model="gpt-3.5-turbo",
+                **input_kwarg,
+                num_retries=0,
+                azure_ad_token="oidc/test-token",
+            )
+        except Exception:
+            traceback.print_exc()
+
+        # Verify initialize_azure_sdk_client was called
+        mock_init_azure.assert_called_once()
+
+        # Verify it was called with the expected azure_ad_token
+        calls = mock_init_azure.call_args_list
+        azure_calls = list(calls)
+
+        litellm_params = azure_calls[0].kwargs["litellm_params"]
+        print("litellm_params", litellm_params)
+
+        assert (
+            "azure_ad_token" in litellm_params
+        ), "azure_ad_token not found in parameters"
+        assert (
+            litellm_params["azure_ad_token"] == "oidc/test-token"
+        ), "azure_ad_token is not correct"
+
+        # More detailed verification (optional)
+        for call in azure_calls:
+            assert "api_key" in call.kwargs, "api_key not found in parameters"
+            assert "api_base" in call.kwargs, "api_base not found in parameters"