feat(router.py): add tpm/rpm tracking on success/failure to global_router

Addresses #6914
krrishdholakia committed Nov 26, 2024
1 parent 1b99371 commit 734d00b
Showing 4 changed files with 148 additions and 6 deletions.
78 changes: 72 additions & 6 deletions litellm/router.py
@@ -122,6 +122,7 @@
ModelInfo,
ProviderBudgetConfigType,
RetryPolicy,
RouterCacheEnum,
RouterErrors,
RouterGeneralSettings,
RouterModelGroupAliasItem,
@@ -503,6 +504,14 @@ def __init__( # noqa: PLR0915
litellm.success_callback.append(self.sync_deployment_callback_on_success)
else:
litellm.success_callback = [self.sync_deployment_callback_on_success]
if isinstance(litellm._async_failure_callback, list):
litellm._async_failure_callback.append(
self.async_deployment_callback_on_failure
)
else:
litellm._async_failure_callback = [
self.async_deployment_callback_on_failure
]
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
@@ -3288,13 +3297,14 @@ async def deployment_callback_on_success
):
"""
Track remaining tpm/rpm quota for model in model_list
Currently, only updates TPM usage.
"""
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment_name = kwargs["litellm_params"]["metadata"].get(
"deployment", None
) # stable name - works for wildcard routes as well
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
@@ -3305,6 +3315,8 @@
elif isinstance(id, int):
id = str(id)

parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)

_usage_obj = completion_response.get("usage")
total_tokens = _usage_obj.get("total_tokens", 0) if _usage_obj else 0

@@ -3316,13 +3328,14 @@
"%H-%M"
) # use the same timezone regardless of system clock

tpm_key = f"global_router:{id}:tpm:{current_minute}"
tpm_key = RouterCacheEnum.TPM.value.format(
id=id, current_minute=current_minute, model=deployment_name
)
# ------------
# Update usage
# ------------
# update cache

parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
## TPM
await self.cache.async_increment_cache(
key=tpm_key,
@@ -3331,6 +3344,17 @@
ttl=RoutingArgs.ttl.value,
)

## RPM
rpm_key = RouterCacheEnum.RPM.value.format(
id=id, current_minute=current_minute, model=deployment_name
)
await self.cache.async_increment_cache(
key=rpm_key,
value=1,
parent_otel_span=parent_otel_span,
ttl=RoutingArgs.ttl.value,
)

increment_deployment_successes_for_current_minute(
litellm_router_instance=self,
deployment_id=id,
@@ -3443,6 +3467,40 @@ def deployment_callback_on_failure
except Exception as e:
raise e

async def async_deployment_callback_on_failure(
self, kwargs, completion_response: Optional[Any], start_time, end_time
):
"""
Update RPM usage for a deployment
"""
deployment_name = kwargs["litellm_params"]["metadata"].get(
"deployment", None
) # handles wildcard routes - by giving the original name sent to `litellm.completion`
model_group = kwargs["litellm_params"]["metadata"].get("model_group", None)
model_info = kwargs["litellm_params"].get("model_info", {}) or {}
id = model_info.get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)

dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock

## RPM
rpm_key = RouterCacheEnum.RPM.value.format(
id=id, current_minute=current_minute, model=deployment_name
)
await self.cache.async_increment_cache(
key=rpm_key,
value=1,
parent_otel_span=parent_otel_span,
ttl=RoutingArgs.ttl.value,
)

def log_retry(self, kwargs: dict, e: Exception) -> dict:
"""
When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing
@@ -4472,10 +4530,18 @@ async def get_model_group_usage
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
tpm_keys.append(
f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
RouterCacheEnum.TPM.value.format(
id=model["model_info"]["id"],
model=model["model_name"],
current_minute=current_minute,
)
)
rpm_keys.append(
f"global_router:{model['model_info']['id']}:rpm:{current_minute}"
RouterCacheEnum.RPM.value.format(
id=model["model_info"]["id"],
model=model["model_name"],
current_minute=current_minute,
)
)
combined_tpm_rpm_keys = tpm_keys + rpm_keys

3 changes: 3 additions & 0 deletions litellm/router_utils/response_headers.py
@@ -0,0 +1,3 @@
"""
Return remaining tpm/rpm limits for a given model
"""
5 changes: 5 additions & 0 deletions litellm/types/router.py
@@ -640,3 +640,8 @@ class ProviderBudgetInfo(BaseModel):


ProviderBudgetConfigType = Dict[str, ProviderBudgetInfo]


class RouterCacheEnum(enum.Enum):
TPM = "global_router:{id}:{model}:tpm:{current_minute}"
RPM = "global_router:{id}:{model}:rpm:{current_minute}"
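
The deployment's stable name is now part of the key, so one deployment id gets a distinct per-minute counter per model, which is what makes wildcard routes trackable. A quick illustration, with hypothetical values matching the tests below:

from litellm.types.router import RouterCacheEnum

key = RouterCacheEnum.TPM.value.format(
    id="1", model="gemini/gemini-1.5-flash", current_minute="14-05"
)
print(key)  # global_router:1:gemini/gemini-1.5-flash:tpm:14-05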
68 changes: 68 additions & 0 deletions tests/local_testing/test_router_utils.py
@@ -195,3 +195,71 @@ def test_router_get_model_info_wildcard_routes():
assert model_info is not None
assert model_info["tpm"] is not None
assert model_info["rpm"] is not None


@pytest.mark.asyncio
async def test_call_router_callbacks_on_success():
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1},
},
]
)

with patch.object(
router.cache, "async_increment_cache", new=AsyncMock()
) as mock_callback:
await router.acompletion(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="Hello, I'm good.",
)
await asyncio.sleep(1)
assert mock_callback.call_count == 2

assert (
mock_callback.call_args_list[0]
.kwargs["key"]
.startswith("global_router:1:gemini/gemini-1.5-flash:tpm")
)
assert (
mock_callback.call_args_list[1]
.kwargs["key"]
.startswith("global_router:1:gemini/gemini-1.5-flash:rpm")
)


@pytest.mark.asyncio
async def test_call_router_callbacks_on_failure():
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1},
},
]
)

with patch.object(
router.cache, "async_increment_cache", new=AsyncMock()
) as mock_callback:
with pytest.raises(litellm.RateLimitError):
await router.acompletion(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="litellm.RateLimitError",
num_retries=0,
)
await asyncio.sleep(1)
print(mock_callback.call_args_list)
assert mock_callback.call_count == 1

assert (
mock_callback.call_args_list[0]
.kwargs["key"]
.startswith("global_router:1:gemini/gemini-1.5-flash:rpm")
)
