Skip to content

Commit

Permalink
fix(auth_checks.py): add check for reducing db calls if user/team id …
Browse files Browse the repository at this point in the history
…does not exist in db

reduces latency/call by ~100ms
  • Loading branch information
krrishdholakia committed Nov 4, 2024
1 parent 2d5543c commit 051907b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 22 deletions.
2 changes: 1 addition & 1 deletion litellm/caching/dual_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self.redis_batch_cache_expiry = (
default_redis_batch_cache_expiry
or litellm.default_redis_batch_cache_expiry
or 5
or 10
)
self.default_in_memory_ttl = (
default_in_memory_ttl or litellm.default_in_memory_ttl
Expand Down
15 changes: 0 additions & 15 deletions litellm/integrations/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,21 +281,6 @@ async def async_post_call_failure_hook(
# End Parent OTEL Sspan
parent_otel_span.end(end_time=self._to_ns(datetime.now()))

async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode

parent_otel_span = user_api_key_dict.parent_otel_span
if parent_otel_span is not None:
parent_otel_span.set_status(Status(StatusCode.OK))
# End Parent OTEL Sspan
parent_otel_span.end(end_time=self._to_ns(datetime.now()))

def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
Expand Down
58 changes: 52 additions & 6 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.proxy._types import (
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
Expand All @@ -42,6 +43,10 @@
else:
Span = Any


last_db_access_time = LimitedSizeOrderedDict(max_size=100)
db_cache_expiry = 5 # refresh every 5s

all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value


Expand Down Expand Up @@ -383,6 +388,18 @@ def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
return False


def _should_check_db(
key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
) -> bool:
current_time = time.time()
if (
key not in last_db_access_time
or current_time - last_db_access_time[key] >= db_cache_expiry
):
return True
return False


@log_to_opentelemetry
async def get_user_object(
user_id: str,
Expand Down Expand Up @@ -412,10 +429,18 @@ async def get_user_object(
if prisma_client is None:
raise Exception("No db connected")
try:

response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}, include={"organization_memberships": True}
db_access_time_key = "user_id:{}".format(user_id)
should_check_db = _should_check_db(
key=db_access_time_key,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
)
if should_check_db:
response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}, include={"organization_memberships": True}
)
else:
response = None

if response is None:
if user_id_upsert:
Expand Down Expand Up @@ -444,6 +469,9 @@ async def get_user_object(
# save the user object to cache
await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)

# save to db access time
last_db_access_time[db_access_time_key] = time.time()

return _response
except Exception as e: # if user not in db
raise ValueError(
Expand Down Expand Up @@ -515,6 +543,12 @@ async def _delete_cache_key_object(


@log_to_opentelemetry
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
return await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)


async def get_team_object(
team_id: str,
prisma_client: Optional[PrismaClient],
Expand Down Expand Up @@ -544,7 +578,7 @@ async def get_team_object(
):
cached_team_obj = (
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
key=key
key=key, parent_otel_span=parent_otel_span
)
)

Expand All @@ -564,9 +598,18 @@ async def get_team_object(

# else, check db
try:
response = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
db_access_time_key = "team_id:{}".format(team_id)
should_check_db = _should_check_db(
key=db_access_time_key,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
)
if should_check_db:
response = await _get_team_db_check(
team_id=team_id, prisma_client=prisma_client
)
else:
response = None

if response is None:
raise Exception
Expand All @@ -580,6 +623,9 @@ async def get_team_object(
proxy_logging_obj=proxy_logging_obj,
)

# save to db access time
last_db_access_time[db_access_time_key] = time.time()

return _response
except Exception:
raise Exception(
Expand Down

0 comments on commit 051907b

Please sign in to comment.