Skip to content

Commit

Permalink
add check for _model_is_within_list_of_allowed_models
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaan-jaff committed Nov 25, 2024
1 parent 0e36333 commit e9cdbff
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,15 @@ def common_checks( # noqa: PLR0915
):
# this means the team has access to all models on the proxy
if (
"all-proxy-models" in team_object.models
or "*" in team_object.models
or "openai/*" in team_object.models
_model_is_within_list_of_allowed_models(
model=_model, allowed_models=team_object.models
)
is True
):
# this means the team has access to all models on the proxy
pass
# check if the team model is an access_group
elif model_in_access_group(_model, team_object.models) is True:
pass
elif _model and "*" in _model:
pass
else:
raise Exception(
f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
Expand Down Expand Up @@ -868,25 +866,39 @@ async def can_key_call_model(
filtered_models += models_in_current_access_groups
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")

# Check for universal access patterns
if len(filtered_models) == 0:
return True
if "*" in filtered_models:
return True
if model_matches_patterns(model=model, allowed_models=filtered_models) is True:
return True

if model is not None and model not in filtered_models:
if (
_model_is_within_list_of_allowed_models(
model=model, allowed_models=filtered_models
)
is False
):
raise ValueError(
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
f"API Key not allowed to access model. List of allowed models={filtered_models}. Tried to access {model}"
)

valid_token.models = filtered_models
verbose_proxy_logger.debug(
f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
)
return True


def _model_is_within_list_of_allowed_models(
model: str, allowed_models: List[str]
) -> bool:
# Check for universal access patterns
if len(allowed_models) == 0:
return True
if "*" in allowed_models:
return True
if "all-proxy-models" in allowed_models:
return True
if model_matches_patterns(model=model, allowed_models=allowed_models) is True:
return True

return False


def model_matches_patterns(model: str, allowed_models: List[str]) -> bool:
"""
Helper function to check if a model matches any of the allowed model patterns.
Expand Down

0 comments on commit e9cdbff

Please sign in to comment.