feat(router.py): support wildcard routes in get_router_model_info()
Addresses #6914
krrishdholakia committed Nov 26, 2024
1 parent 3f48aaa commit 7f346ba
Showing 7 changed files with 100 additions and 7 deletions.
10 changes: 10 additions & 0 deletions litellm/model_prices_and_context_window_backup.json
@@ -3383,6 +3383,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-001": {
@@ -3406,6 +3408,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash": {
@@ -3428,6 +3432,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-latest": {
@@ -3450,6 +3456,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-8b-exp-0924": {
@@ -3472,6 +3480,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-exp-1114": {
54 changes: 49 additions & 5 deletions litellm/router.py
@@ -4120,7 +4120,24 @@ def get_deployment_by_model_group_name(
raise Exception("Model Name invalid - {}".format(type(model)))
return None

def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
@overload
def get_router_model_info(
self, deployment: dict, received_model_name: str, id: None = None
) -> ModelMapInfo:
pass

@overload
def get_router_model_info(
self, deployment: None, received_model_name: str, id: str
) -> ModelMapInfo:
pass

def get_router_model_info(
self,
deployment: Optional[dict],
received_model_name: str,
id: Optional[str] = None,
) -> ModelMapInfo:
"""
For a given model id, return the model info (max tokens, input cost, output cost, etc.).
@@ -4134,6 +4151,14 @@ def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
Raises:
- ValueError -> If model is not mapped yet
"""
if id is not None:
_deployment = self.get_deployment(model_id=id)
if _deployment is not None:
deployment = _deployment.model_dump(exclude_none=True)

if deployment is None:
raise ValueError("Deployment not found")

## GET BASE MODEL
base_model = deployment.get("model_info", {}).get("base_model", None)
if base_model is None:
@@ -4155,10 +4180,27 @@ def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
elif custom_llm_provider != "azure":
model = _model

potential_models = self.pattern_router.route(received_model_name)
if "*" in model and potential_models is not None: # if wildcard route
for potential_model in potential_models:
try:
if potential_model.get("model_info", {}).get(
"id"
) == deployment.get("model_info", {}).get("id"):
model = potential_model.get("litellm_params", {}).get(
"model"
)
break
except Exception:
pass

## GET LITELLM MODEL INFO - raises exception, if model is not mapped
model_info = litellm.get_model_info(
model="{}/{}".format(custom_llm_provider, model)
)
if not model.startswith(custom_llm_provider):
model_info_name = "{}/{}".format(custom_llm_provider, model)
else:
model_info_name = model

model_info = litellm.get_model_info(model=model_info_name)

## CHECK USER SET MODEL INFO
user_model_info = deployment.get("model_info", {})
@@ -4807,10 +4849,12 @@ def _pre_call_checks( # noqa: PLR0915
base_model = deployment.get("litellm_params", {}).get(
"base_model", None
)
model_info = self.get_router_model_info(
deployment=deployment, received_model_name=model
)
model = base_model or deployment.get("litellm_params", {}).get(
"model", None
)
model_info = self.get_router_model_info(deployment=deployment)

if (
isinstance(model_info, dict)
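Taken together, the overloads above let a caller resolve model info for a wildcard deployment by id alone. A minimal usage sketch, mirroring the new test in tests/local_testing/test_router_utils.py (the model names and the id value are taken from that test):

```python
import os
import litellm
from litellm import Router

# Load the local copy of the cost map so the gemini tpm/rpm entries are available.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

# A single wildcard deployment that matches any "gemini/*" request.
router = Router(
    model_list=[
        {
            "model_name": "gemini/*",
            "litellm_params": {"model": "gemini/*"},
            "model_info": {"id": 1},
        },
    ]
)

# deployment=None + id: the router fetches the deployment by id, then uses
# pattern_router.route(received_model_name) to map the wildcard onto the
# concrete model before calling litellm.get_model_info().
model_info = router.get_router_model_info(
    deployment=None, received_model_name="gemini/gemini-1.5-flash", id="1"
)
print(model_info["tpm"], model_info["rpm"])
```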
2 changes: 2 additions & 0 deletions litellm/types/utils.py
@@ -106,6 +106,8 @@ class ModelInfo(TypedDict, total=False):
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
tpm: Optional[int]
rpm: Optional[int]


class GenericStreamingChunk(TypedDict, total=False):
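Because ModelInfo is declared with total=False, the new tpm/rpm keys are optional and typed Optional[int], so consumers should tolerate both a missing key and an explicit None. A small illustrative helper (the function name is made up for this sketch):

```python
from litellm.types.utils import ModelInfo


def requests_per_minute(info: ModelInfo, default: int = 0) -> int:
    # rpm may be absent (total=False) or set to None, so fall back to a default.
    rpm = info.get("rpm")
    return rpm if rpm is not None else default
```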
2 changes: 2 additions & 0 deletions litellm/utils.py
@@ -4656,6 +4656,8 @@ def _get_max_position_embeddings(model_name):
),
supports_audio_input=_model_info.get("supports_audio_input", False),
supports_audio_output=_model_info.get("supports_audio_output", False),
tpm=_model_info.get("tpm", None),
rpm=_model_info.get("rpm", None),
)
except Exception as e:
if "OllamaError" in str(e):
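With utils.get_model_info() now copying tpm/rpm out of the cost map, provider rate limits can be read directly from the returned info. A rough sketch, assuming the cost-map entries added in this commit are loaded (the expected values come from the gemini/gemini-1.5-flash entry above):

```python
import litellm

# Raises if the model isn't mapped in the cost map ("This model isn't mapped yet").
info = litellm.get_model_info(model="gemini/gemini-1.5-flash")

# tpm/rpm default to None for entries that don't define them.
print("tpm:", info.get("tpm"))  # 4000000 per the updated cost map
print("rpm:", info.get("rpm"))  # 2000 per the updated cost map
```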
10 changes: 10 additions & 0 deletions model_prices_and_context_window.json
@@ -3383,6 +3383,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-001": {
@@ -3406,6 +3408,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash": {
@@ -3428,6 +3432,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-latest": {
@@ -3450,6 +3456,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-8b-exp-0924": {
@@ -3472,6 +3480,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-exp-1114": {
8 changes: 6 additions & 2 deletions tests/local_testing/test_router.py
@@ -2115,10 +2115,14 @@ def test_router_get_model_info(model, base_model, llm_provider):
assert deployment is not None

if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
router.get_router_model_info(deployment=deployment.to_json())
router.get_router_model_info(
deployment=deployment.to_json(), received_model_name=model
)
else:
try:
router.get_router_model_info(deployment=deployment.to_json())
router.get_router_model_info(
deployment=deployment.to_json(), received_model_name=model
)
pytest.fail("Expected this to raise model not mapped error")
except Exception as e:
if "This model isn't mapped yet" in str(e):
21 changes: 21 additions & 0 deletions tests/local_testing/test_router_utils.py
@@ -174,3 +174,24 @@ async def test_update_kwargs_before_fallbacks(call_type):

print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None


def test_router_get_model_info_wildcard_routes():
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1},
},
]
)
model_info = router.get_router_model_info(
deployment=None, received_model_name="gemini/gemini-1.5-flash", id="1"
)
print(model_info)
assert model_info is not None
assert model_info["tpm"] is not None
assert model_info["rpm"] is not None
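
The other branch added in router.py is the failure path: if deployment is None and the id does not resolve via get_deployment(), the method raises ValueError("Deployment not found"). A small sketch of that expected behavior, reusing the wildcard router from the test above (the unknown id is illustrative):

```python
import pytest
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gemini/*",
            "litellm_params": {"model": "gemini/*"},
            "model_info": {"id": 1},
        },
    ]
)

# With deployment=None and an id that matches no deployment,
# get_router_model_info() is expected to raise ValueError("Deployment not found")
# (see the new guard near the top of the method).
with pytest.raises(ValueError):
    router.get_router_model_info(
        deployment=None,
        received_model_name="gemini/gemini-1.5-flash",
        id="does-not-exist",
    )
```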
