Skip to content

Commit

Permalink
Revert "Revert "JWT Auth - enforce_rbac support + UI team view, spe…
Browse files Browse the repository at this point in the history
…nd calc fix (#7863)""

This reverts commit 8f7b9ae.
  • Loading branch information
ishaan-jaff committed Jan 24, 2025
1 parent f2d9ce9 commit 085c4ad
Show file tree
Hide file tree
Showing 11 changed files with 419 additions and 196 deletions.
4 changes: 4 additions & 0 deletions docs/my-website/docs/proxy/admin_ui_sso.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# ✨ SSO for Admin UI

:::info
Expand Down
29 changes: 28 additions & 1 deletion docs/my-website/docs/proxy/token_auth.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ general_settings:
admin_jwt_scope: "litellm-proxy-admin"
```

## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)
## Tracking End-Users / Internal Users / Team / Org

Set the field in the jwt token, which corresponds to a litellm user / team / org.

Expand Down Expand Up @@ -156,6 +156,33 @@ scope: ["litellm-proxy-admin",...]
scope: "litellm-proxy-admin ..."
```
## Enforce Role-Based Access Control (RBAC)
Reject a JWT token if it's valid but doesn't have the required scopes / fields.
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
```yaml
general_settings:
master_key: sk-1234
enable_jwt_auth: True
litellm_jwtauth:
admin_jwt_scope: "litellm_proxy_endpoints_access"
admin_allowed_routes:
- openai_routes
- info_routes
public_key_ttl: 600
enforce_rbac: true # 👈 Enforce RBAC
```

Expected Scope in JWT:

```
{
"scope": "litellm_proxy_endpoints_access"
}
```

## Advanced - Allowed Routes

Configure which routes a JWT can access via the config.
Expand Down
17 changes: 17 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
- public_allowed_routes: list of allowed routes for authenticated but unknown litellm role jwt tokens.
- enforce_rbac: If true, enforce RBAC for all routes.
See `auth_checks.py` for the specific routes
"""
Expand Down Expand Up @@ -446,6 +448,8 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
)
end_user_id_jwt_field: Optional[str] = None
public_key_ttl: float = 600
public_allowed_routes: List[str] = ["public_routes"]
enforce_rbac: bool = False

def __init__(self, **kwargs: Any) -> None:
# get the attribute names for this Pydantic model
Expand Down Expand Up @@ -2283,6 +2287,19 @@ class ProxyStateVariables(TypedDict):
UI_TEAM_ID = "litellm-dashboard"


class JWTAuthBuilderResult(TypedDict):
is_proxy_admin: bool
team_object: Optional[LiteLLM_TeamTable]
user_object: Optional[LiteLLM_UserTable]
end_user_object: Optional[LiteLLM_EndUserTable]
org_object: Optional[LiteLLM_OrganizationTable]
token: str
team_id: Optional[str]
user_id: Optional[str]
end_user_id: Optional[str]
org_id: Optional[str]


class ClientSideFallbackModel(TypedDict, total=False):
"""
Dictionary passed when client configuring input
Expand Down
34 changes: 2 additions & 32 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

import asyncio
import inspect
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
Expand All @@ -24,7 +23,6 @@
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
CallInfo,
CommonProxyErrors,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
LiteLLM_OrganizationTable,
Expand Down Expand Up @@ -57,33 +55,6 @@
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value


def _allowed_import_check() -> bool:
from litellm.proxy.auth.user_api_key_auth import _user_api_key_auth_builder

# Get the calling frame
caller_frame = inspect.stack()[2]
caller_function = caller_frame.function
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)

allowed_function = "_user_api_key_auth_builder"
allowed_signature = inspect.signature(_user_api_key_auth_builder)
if caller_function_callable is None or not callable(caller_function_callable):
raise Exception(f"Caller function {caller_function} is not callable")
caller_signature = inspect.signature(caller_function_callable)

if caller_signature != allowed_signature:
raise TypeError(
f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}"
)
# Check if the caller module is allowed
if caller_function != allowed_function:
raise ImportError(
f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}"
)

return True


def common_checks( # noqa: PLR0915
request_body: dict,
team_object: Optional[LiteLLM_TeamTable],
Expand All @@ -108,7 +79,6 @@ def common_checks( # noqa: PLR0915
9. Check if request body is safe
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
"""
_allowed_import_check()
_model = request_body.get("model", None)
if team_object is not None and team_object.blocked is True:
raise Exception(
Expand Down Expand Up @@ -846,7 +816,7 @@ async def get_org_object(
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
):
) -> Optional[LiteLLM_OrganizationTable]:
"""
- Check if org id in proxy Org Table
- if valid, return LiteLLM_OrganizationTable object
Expand All @@ -861,7 +831,7 @@ async def get_org_object(
cached_org_obj = user_api_key_cache.async_get_cache(key="org_id:{}".format(org_id))
if cached_org_obj is not None:
if isinstance(cached_org_obj, dict):
return cached_org_obj
return LiteLLM_OrganizationTable(**cached_org_obj)
elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
return cached_org_obj
# else, check db
Expand Down
38 changes: 37 additions & 1 deletion litellm/proxy/auth/handle_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import JWKKeyValue, JWTKeyItem, LiteLLM_JWTAuth
from litellm.proxy._types import (
JWKKeyValue,
JWTKeyItem,
LiteLLM_JWTAuth,
LitellmUserRoles,
)
from litellm.proxy.utils import PrismaClient


Expand Down Expand Up @@ -54,6 +59,34 @@ def is_jwt(self, token: str):
parts = token.split(".")
return len(parts) == 3

def get_rbac_role(self, token: dict) -> Optional[LitellmUserRoles]:
"""
Returns the RBAC role the token 'belongs' to.
RBAC roles allowed to make requests:
- PROXY_ADMIN: can make requests to all routes
- TEAM: can make requests to routes associated with a team
- INTERNAL_USER: can make requests to routes associated with a user
Resolves: https://github.com/BerriAI/litellm/issues/6793
Returns:
- PROXY_ADMIN: if token is admin
- TEAM: if token is associated with a team
- INTERNAL_USER: if token is associated with a user
- None: if token is not associated with a team or user
"""
scopes = self.get_scopes(token=token)
is_admin = self.is_admin(scopes=scopes)
if is_admin:
return LitellmUserRoles.PROXY_ADMIN
elif self.get_team_id(token=token, default_value=None) is not None:
return LitellmUserRoles.TEAM
elif self.get_user_id(token=token, default_value=None) is not None:
return LitellmUserRoles.INTERNAL_USER

return None

def is_admin(self, scopes: list) -> bool:
if self.litellm_jwtauth.admin_jwt_scope in scopes:
return True
Expand All @@ -68,12 +101,14 @@ def get_end_user_id(
self, token: dict, default_value: Optional[str]
) -> Optional[str]:
try:

if self.litellm_jwtauth.end_user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
else:
user_id = None
except KeyError:
user_id = default_value

return user_id

def is_required_team_id(self) -> bool:
Expand Down Expand Up @@ -169,6 +204,7 @@ def get_scopes(self, token: dict) -> list:
return scopes

async def get_public_key(self, kid: Optional[str]) -> dict:

keys_url = os.getenv("JWT_PUBLIC_KEY_URL")

if keys_url is None:
Expand Down
Loading

0 comments on commit 085c4ad

Please sign in to comment.