From 7f346ba78152282a69f5708e2be40303825158c8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Nov 2024 16:20:42 +0530 Subject: [PATCH] feat(router.py): support wildcard routes in `get_router_model_info()` Addresses https://github.com/BerriAI/litellm/issues/6914 --- ...odel_prices_and_context_window_backup.json | 10 ++++ litellm/router.py | 54 +++++++++++++++++-- litellm/types/utils.py | 2 + litellm/utils.py | 2 + model_prices_and_context_window.json | 10 ++++ tests/local_testing/test_router.py | 8 ++- tests/local_testing/test_router_utils.py | 21 ++++++++ 7 files changed, 100 insertions(+), 7 deletions(-) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index a56472f7f0b2..23eb59e0bb93 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -3383,6 +3383,8 @@ "supports_vision": true, "supports_response_schema": true, "supports_prompt_caching": true, + "tpm": 4000000, + "rpm": 2000, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash-001": { @@ -3406,6 +3408,8 @@ "supports_vision": true, "supports_response_schema": true, "supports_prompt_caching": true, + "tpm": 4000000, + "rpm": 2000, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash": { @@ -3428,6 +3432,8 @@ "supports_function_calling": true, "supports_vision": true, "supports_response_schema": true, + "tpm": 4000000, + "rpm": 2000, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash-latest": { @@ -3450,6 +3456,8 @@ "supports_function_calling": true, "supports_vision": true, "supports_response_schema": true, + "tpm": 4000000, + "rpm": 2000, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash-8b-exp-0924": { @@ -3472,6 +3480,8 @@ "supports_function_calling": true, "supports_vision": true, "supports_response_schema": true, + "tpm": 4000000, + "rpm": 2000, "source": 
"https://ai.google.dev/pricing" }, "gemini/gemini-exp-1114": { diff --git a/litellm/router.py b/litellm/router.py index 26ca705265d7..1cf0aceff128 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4120,7 +4120,24 @@ def get_deployment_by_model_group_name( raise Exception("Model Name invalid - {}".format(type(model))) return None - def get_router_model_info(self, deployment: dict) -> ModelMapInfo: + @overload + def get_router_model_info( + self, deployment: dict, received_model_name: str, id: None = None + ) -> ModelMapInfo: + pass + + @overload + def get_router_model_info( + self, deployment: None, received_model_name: str, id: str + ) -> ModelMapInfo: + pass + + def get_router_model_info( + self, + deployment: Optional[dict], + received_model_name: str, + id: Optional[str] = None, + ) -> ModelMapInfo: """ For a given model id, return the model info (max tokens, input cost, output cost, etc.). @@ -4134,6 +4151,14 @@ def get_router_model_info(self, deployment: dict) -> ModelMapInfo: Raises: - ValueError -> If model is not mapped yet """ + if id is not None: + _deployment = self.get_deployment(model_id=id) + if _deployment is not None: + deployment = _deployment.model_dump(exclude_none=True) + + if deployment is None: + raise ValueError("Deployment not found") + ## GET BASE MODEL base_model = deployment.get("model_info", {}).get("base_model", None) if base_model is None: @@ -4155,10 +4180,27 @@ def get_router_model_info(self, deployment: dict) -> ModelMapInfo: elif custom_llm_provider != "azure": model = _model + potential_models = self.pattern_router.route(received_model_name) + if "*" in model and potential_models is not None: # if wildcard route + for potential_model in potential_models: + try: + if potential_model.get("model_info", {}).get( + "id" + ) == deployment.get("model_info", {}).get("id"): + model = potential_model.get("litellm_params", {}).get( + "model" + ) + break + except Exception: + pass + ## GET LITELLM MODEL INFO - raises exception, if 
model is not mapped - model_info = litellm.get_model_info( - model="{}/{}".format(custom_llm_provider, model) - ) + if not model.startswith(custom_llm_provider): + model_info_name = "{}/{}".format(custom_llm_provider, model) + else: + model_info_name = model + + model_info = litellm.get_model_info(model=model_info_name) ## CHECK USER SET MODEL INFO user_model_info = deployment.get("model_info", {}) @@ -4807,10 +4849,12 @@ def _pre_call_checks( # noqa: PLR0915 base_model = deployment.get("litellm_params", {}).get( "base_model", None ) + model_info = self.get_router_model_info( + deployment=deployment, received_model_name=model + ) model = base_model or deployment.get("litellm_params", {}).get( "model", None ) - model_info = self.get_router_model_info(deployment=deployment) if ( isinstance(model_info, dict) diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 9fc58dff6942..93b4a39d3be1 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -106,6 +106,8 @@ class ModelInfo(TypedDict, total=False): supports_prompt_caching: Optional[bool] supports_audio_input: Optional[bool] supports_audio_output: Optional[bool] + tpm: Optional[int] + rpm: Optional[int] class GenericStreamingChunk(TypedDict, total=False): diff --git a/litellm/utils.py b/litellm/utils.py index 262af341817b..b925fbf5bba2 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4656,6 +4656,8 @@ def _get_max_position_embeddings(model_name): ), supports_audio_input=_model_info.get("supports_audio_input", False), supports_audio_output=_model_info.get("supports_audio_output", False), + tpm=_model_info.get("tpm", None), + rpm=_model_info.get("rpm", None), ) except Exception as e: if "OllamaError" in str(e): diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index a56472f7f0b2..23eb59e0bb93 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -3383,6 +3383,8 @@ "supports_vision": true, 
"supports_response_schema": true, "supports_prompt_caching": true, + "tpm": 4000000, + "rpm": 2000, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash-001": { @@ -3406,6 +3408,8 @@ "supports_vision": true, "supports_response_schema": true, "supports_prompt_caching": true, + "tpm": 4000000, + "rpm": 2000, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash": { @@ -3428,6 +3432,8 @@ "supports_function_calling": true, "supports_vision": true, "supports_response_schema": true, + "tpm": 4000000, + "rpm": 2000, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash-latest": { @@ -3450,6 +3456,8 @@ "supports_function_calling": true, "supports_vision": true, "supports_response_schema": true, + "tpm": 4000000, + "rpm": 2000, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-1.5-flash-8b-exp-0924": { @@ -3472,6 +3480,8 @@ "supports_function_calling": true, "supports_vision": true, "supports_response_schema": true, + "tpm": 4000000, + "rpm": 2000, "source": "https://ai.google.dev/pricing" }, "gemini/gemini-exp-1114": { diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index 20867e766dbe..7b53d42db0f8 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -2115,10 +2115,14 @@ def test_router_get_model_info(model, base_model, llm_provider): assert deployment is not None if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"): - router.get_router_model_info(deployment=deployment.to_json()) + router.get_router_model_info( + deployment=deployment.to_json(), received_model_name=model + ) else: try: - router.get_router_model_info(deployment=deployment.to_json()) + router.get_router_model_info( + deployment=deployment.to_json(), received_model_name=model + ) pytest.fail("Expected this to raise model not mapped error") except Exception as e: if "This model isn't mapped yet" in str(e): diff --git 
a/tests/local_testing/test_router_utils.py b/tests/local_testing/test_router_utils.py index d266cfbd9602..c6a1623dbd12 100644 --- a/tests/local_testing/test_router_utils.py +++ b/tests/local_testing/test_router_utils.py @@ -174,3 +174,24 @@ async def test_update_kwargs_before_fallbacks(call_type): print(mock_client.call_args.kwargs) assert mock_client.call_args.kwargs["litellm_trace_id"] is not None + + +def test_router_get_model_info_wildcard_routes(): + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + router = Router( + model_list=[ + { + "model_name": "gemini/*", + "litellm_params": {"model": "gemini/*"}, + "model_info": {"id": "1"}, + }, + ] + ) + model_info = router.get_router_model_info( + deployment=None, received_model_name="gemini/gemini-1.5-flash", id="1" + ) + print(model_info) + assert model_info is not None + assert model_info["tpm"] is not None + assert model_info["rpm"] is not None