diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index 819c7daec141..5bbd2af9d212 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -10,7 +10,6 @@ import os import time import traceback -import random from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast, overload @@ -20,7 +19,6 @@ from litellm.constants import DB_SPEND_UPDATE_JOB_NAME from litellm.litellm_core_utils.safe_json_loads import safe_json_loads from litellm.proxy._types import ( - DB_CONNECTION_ERROR_TYPES, BaseDailySpendTransaction, DailyTagSpendTransaction, DailyTeamSpendTransaction, @@ -530,7 +528,14 @@ async def _commit_spend_updates_to_db_without_redis_buffer( This is the regular flow of committing to db without using a redis buffer - Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS. + Multi-row writes to the database should always be consistently sorted to minimize the likelihood of deadlocks: + ideally, the sorting order should be chosen so that writes to ALL indexes on a table happen in the same order across all + concurrent transactions, as any out-of-order concurrent write to the table or ANY index increases the chances of deadlocks. + Finding a single consistent order across multiple indexes is generally impossible, so we pick one to minimize the chance + of transient deadlocks, and retry later if we are unlucky. + + Note: This flow can cause deadlocks under high load. Use self._commit_spend_updates_to_db_with_redis() instead + if you experience a high rate of deadlocks that the retry logic fails to handle. """ # Aggregate all in memory spend updates (key, user, end_user, team, team_member, org) and commit to db @@ -600,7 +605,7 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915 """ from litellm.proxy.utils import ( ProxyUpdateSpend, - _raise_failed_update_spend_exception, + _handle_db_exception_retriable, ) ### UPDATE USER TABLE ### @@ -622,26 +627,15 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915 for ( user_id, response_cost, - ) in user_list_transactions.items(): + ) in sorted(user_list_transactions.items()): batcher.litellm_usertable.update_many( where={"user_id": user_id}, data={"spend": {"increment": response_cost}}, ) break - except DB_CONNECTION_ERROR_TYPES as e: - if ( - i >= n_retry_times - ): # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, - start_time=start_time, - proxy_logging_obj=proxy_logging_obj, - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + await _handle_db_exception_retriable( + e=e, i=i, n_retry_times=n_retry_times, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) ### UPDATE END-USER TABLE ### @@ -677,26 +671,15 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915 for ( token, response_cost, - ) in key_list_transactions.items(): + ) in sorted(key_list_transactions.items()): batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists where={"token": token}, data={"spend": {"increment": response_cost}}, ) break - except DB_CONNECTION_ERROR_TYPES as e: - if ( - i >= n_retry_times - ): # If we've reached the maximum
number of retries - _raise_failed_update_spend_exception( - e=e, - start_time=start_time, - proxy_logging_obj=proxy_logging_obj, - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + await _handle_db_exception_retriable( + e=e, i=i, n_retry_times=n_retry_times, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) ### UPDATE TEAM TABLE ### @@ -718,7 +701,7 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915 for ( team_id, response_cost, - ) in team_list_transactions.items(): + ) in sorted(team_list_transactions.items()): verbose_proxy_logger.debug( "Updating spend for team id={} by {}".format( team_id, response_cost @@ -729,20 +712,9 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915 data={"spend": {"increment": response_cost}}, ) break - except DB_CONNECTION_ERROR_TYPES as e: - if ( - i >= n_retry_times - ): # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, - start_time=start_time, - proxy_logging_obj=proxy_logging_obj, - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + await _handle_db_exception_retriable( + e=e, i=i, n_retry_times=n_retry_times, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) ### UPDATE TEAM Membership TABLE with spend ### @@ -768,7 +740,7 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915 for ( key, response_cost, - ) in team_member_list_transactions.items(): + ) in sorted(team_member_list_transactions.items()): # key is "team_id::::user_id::" team_id = key.split("::")[1] user_id = key.split("::")[3] @@ -778,20 +750,9 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915 data={"spend": {"increment": response_cost}}, ) break - except DB_CONNECTION_ERROR_TYPES as e: - if ( - i >= n_retry_times - ): # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, - start_time=start_time, - proxy_logging_obj=proxy_logging_obj, - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + await _handle_db_exception_retriable( + e=e, i=i, n_retry_times=n_retry_times, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) ### UPDATE ORG TABLE ### @@ -810,33 +771,15 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915 for ( org_id, response_cost, - ) in org_list_transactions.items(): + ) in sorted(org_list_transactions.items()): batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists where={"organization_id": org_id}, data={"spend": {"increment": response_cost}}, ) break - except DB_CONNECTION_ERROR_TYPES as e: - if ( - i >= n_retry_times - ): # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, - start_time=start_time, - proxy_logging_obj=proxy_logging_obj, - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep( - # Sleep a random amount to avoid retrying and deadlocking again: when two transactions deadlock they are - # cancelled basically at the same time, so if they wait the same time they will also retry 
at the same time - # and thus they are more likely to deadlock again. - # Instead, we sleep a random amount so that they retry at slightly different times, lowering the chance of - # repeated deadlocks, and therefore of exceeding the retry limit. - random.uniform(2**i, 2 ** (i + 1)) - ) except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + await _handle_db_exception_retriable( + e=e, i=i, n_retry_times=n_retry_times, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) ### UPDATE TAG TABLE ### @@ -873,7 +816,7 @@ async def _update_entity_spend_in_db( prisma_client: Prisma client instance proxy_logging_obj: Proxy logging object """ - from litellm.proxy.utils import _raise_failed_update_spend_exception + from litellm.proxy.utils import _handle_db_exception_retriable verbose_proxy_logger.debug( f"{entity_name} Spend transactions: {transactions}" @@ -886,7 +829,7 @@ async def _update_entity_spend_in_db( timeout=timedelta(seconds=60) ) as transaction: async with transaction.batch_() as batcher: - for entity_id, response_cost in transactions.items(): + for entity_id, response_cost in sorted(transactions.items()): verbose_proxy_logger.debug( f"Updating spend for {entity_name} {where_field}={entity_id} by {response_cost}" ) @@ -895,17 +838,9 @@ async def _update_entity_spend_in_db( data={"spend": {"increment": response_cost}}, ) break - except DB_CONNECTION_ERROR_TYPES as e: - if i >= n_retry_times: - _raise_failed_update_spend_exception( - e=e, - start_time=start_time, - proxy_logging_obj=proxy_logging_obj, - ) - await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + await _handle_db_exception_retriable( + e=e, i=i, n_retry_times=n_retry_times, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) # fmt: off @@ -971,7 +906,7 @@ async def _update_daily_spend( """ Generic function to update daily spend for any entity type (user, team, tag) """ - from litellm.proxy.utils import _raise_failed_update_spend_exception + from litellm.proxy.utils import _raise_failed_update_spend_exception, _handle_db_exception_retriable verbose_proxy_logger.debug( f"Daily {entity_type.capitalize()} Spend transactions: {len(daily_spend_transactions)}" @@ -984,26 +919,12 @@ async def _update_daily_spend( try: # Sort the transactions to minimize the probability of deadlocks by reducing the chance of concurrent # trasactions locking the same rows/ranges in different orders. - transactions_to_process = dict( - sorted( - daily_spend_transactions.items(), - # Normally to avoid deadlocks we would sort by the index, but since we have sprinkled indexes - # on our schema like we're discount Salt Bae, we just sort by all fields that have an index, - # in an ad-hoc (but hopefully sensible) order of indexes. The actual ordering matters less than - # ensuring that all concurrent transactions sort in the same order. - # We could in theory use the dict key, as it contains basically the same fields, but this is more - # robust to future changes in the key format. - # If _update_daily_spend ever gets the ability to write to multiple tables at once, the sorting - # should sort by the table first. 
- key=lambda x: ( - x[1]["date"], - x[1].get(entity_id_field) or "", - x[1]["api_key"], - x[1]["model"], - x[1]["custom_llm_provider"], - ), - )[:BATCH_SIZE] - ) + # Normally to avoid deadlocks we would sort by the index, but since we have sprinkled indexes + # on our schema like we're discount Salt Bae, we just sort by the dict key. The actual ordering + # matters less than ensuring that all concurrent transactions sort in the same order. + # If _update_daily_spend ever gets the ability to write to multiple tables at once, the sorting + # should sort by the table first. + transactions_to_process = dict(sorted(daily_spend_transactions.items())[:BATCH_SIZE]) if len(transactions_to_process) == 0: verbose_proxy_logger.debug( @@ -1122,20 +1043,9 @@ async def _update_daily_spend( break - except DB_CONNECTION_ERROR_TYPES as e: - if i >= n_retry_times: - _raise_failed_update_spend_exception( - e=e, - start_time=start_time, - proxy_logging_obj=proxy_logging_obj, - ) - await asyncio.sleep( - # Sleep a random amount to avoid retrying and deadlocking again: when two transactions deadlock they are - # cancelled basically at the same time, so if they wait the same time they will also retry at the same time - # and thus they are more likely to deadlock again. - # Instead, we sleep a random amount so that they retry at slightly different times, lowering the chance of - # repeated deadlocks, and therefore of exceeding the retry limit. - random.uniform(2**i, 2 ** (i + 1)) + except Exception as e: + await _handle_db_exception_retriable( + e=e, i=i, n_retry_times=n_retry_times, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) except Exception as e: diff --git a/litellm/proxy/db/exception_handler.py b/litellm/proxy/db/exception_handler.py index db73f9e9c939..d3078139a9eb 100644 --- a/litellm/proxy/db/exception_handler.py +++ b/litellm/proxy/db/exception_handler.py @@ -43,6 +43,35 @@ def is_database_connection_error(e: Exception) -> bool: if isinstance(e, ProxyException) and e.type == ProxyErrorTypes.no_db_connection: return True return False + + @staticmethod + def is_database_retriable_exception(e: Exception) -> bool: + """ + Returns True if the exception is from a condition (e.g. deadlock, broken connection, etc.) that should be retried. + """ + import re + + if isinstance(e, DB_CONNECTION_ERROR_TYPES): # TODO: is this actually needed? + return True + + # Deadlocks should normally be retried. + # Postgres right now, on deadlock, triggers an exception similar to: + # Error occurred during query execution: ConnectorError(ConnectorError { user_facing_error: None, + # kind: QueryError(PostgresError { code: "40P01", message: "deadlock detected", severity: "ERROR", + # detail: Some("Process 3753505 waits for ShareLock on transaction 5729447; blocked by process 3755128.\n + # Process 3755128 waits for ShareLock on transaction 5729448; blocked by process 3753505."), column: None, + # hint: Some("See server log for query details.") }), transient: false }) + # Unfortunately there does not seem to be an easy way to properly parse that or otherwise detect the specific + # issue, so just match using a regular expression. This is definitely not ideal, but there is not much we can do about + # it. + if re.search(r'\bConnectorError\b.*?\bQueryError\b.*?\bPostgresError\b.*?"40P01"', str(e), re.DOTALL): + return True + + # TODO: add additional specific cases (be careful to not add exceptions that should not be retried!)
+ # If many more additional regular expressions are added, it may make sense to combine them into a single one, + # or use something like hyperscan. + + return False @staticmethod def handle_db_exception(e: Exception): diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 62ab69c32ea6..41dae9ad9309 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -3,6 +3,7 @@ import hashlib import json import os +import random import smtplib import threading import time @@ -25,7 +26,6 @@ from litellm import _custom_logger_compatible_callbacks_literal from litellm.constants import DEFAULT_MODEL_CREATED_AT_TIME, MAX_TEAM_LIST_LIMIT from litellm.proxy._types import ( - DB_CONNECTION_ERROR_TYPES, CommonProxyErrors, ProxyErrorTypes, ProxyException, @@ -3266,7 +3266,7 @@ async def update_end_user_spend( for ( end_user_id, response_cost, - ) in end_user_list_transactions.items(): + ) in sorted(end_user_list_transactions.items()): if litellm.max_end_user_budget is not None: pass batcher.litellm_endusertable.upsert( @@ -3282,16 +3282,9 @@ async def update_end_user_spend( ) break - except DB_CONNECTION_ERROR_TYPES as e: - if i >= n_retry_times: # If we've reached the maximum number of retries - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj - ) - # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff except Exception as e: - _raise_failed_update_spend_exception( - e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + await _handle_db_exception_retriable( + e=e, i=i, n_retry_times=n_retry_times, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) @staticmethod @@ -3352,12 +3345,10 @@ async def update_spend_logs( f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}" ) break - except DB_CONNECTION_ERROR_TYPES: - if i is None: - i = 0 - if i >= n_retry_times: - raise - await asyncio.sleep(2**i) + except Exception as e: + await _handle_db_exception_retriable( + e=e, i=i, n_retry_times=n_retry_times, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) except Exception as e: prisma_client.spend_log_transactions = prisma_client.spend_log_transactions[ len(logs_to_process) : @@ -3444,6 +3435,21 @@ def _raise_failed_update_spend_exception( raise e +async def _handle_db_exception_retriable(e: Exception, i: int, n_retry_times: int, start_time: float, proxy_logging_obj: ProxyLogging): + from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler + + if PrismaDBExceptionHandler.is_database_retriable_exception(e): + if i < n_retry_times: + await asyncio.sleep(random.uniform(2**i, 2 ** (i + 1))) # Exponential backoff with jitter + return None # continue with the next retry + + _raise_failed_update_spend_exception( + e=e, + start_time=start_time, + proxy_logging_obj=proxy_logging_obj, + ) + + def _is_projected_spend_over_limit( current_spend: float, soft_budget_limit: Optional[float] ):
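As context for the sorted(...) changes in db_spend_update_writer.py and utils.py, here is a minimal illustration of the lock-ordering argument made in the updated docstring. The transaction contents below are hypothetical and this snippet is not part of the patch.

# Two concurrent spend flushes that touch the same two user rows.
writer_a = {"user_b": 0.02, "user_a": 0.01}
writer_b = {"user_a": 0.03, "user_b": 0.04}

# Unsorted dict iteration lets writer A lock user_b then user_a while writer B
# locks user_a then user_b: a lock-order inversion that Postgres resolves by
# aborting one transaction with SQLSTATE 40P01 ("deadlock detected").
# Sorting gives every writer the same acquisition order, so one simply waits.
for user_id, cost in sorted(writer_a.items()):
    print(f"would increment spend for {user_id} by {cost}")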
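The new is_database_retriable_exception classifier matches the stringified Prisma error with a regex. A standalone sanity check of that regex against an abbreviated version of the error text quoted in its comment might look like the sketch below; the sample strings are illustrative, not taken from a live system.

import re

# Same pattern as in exception_handler.py: Postgres reports deadlocks as SQLSTATE 40P01.
DEADLOCK_PATTERN = re.compile(
    r'\bConnectorError\b.*?\bQueryError\b.*?\bPostgresError\b.*?"40P01"', re.DOTALL
)

deadlock_error = (
    'Error occurred during query execution: ConnectorError(ConnectorError { user_facing_error: None, '
    'kind: QueryError(PostgresError { code: "40P01", message: "deadlock detected", severity: "ERROR", '
    'detail: Some("..."), column: None, hint: Some("See server log for query details.") }), transient: false })'
)
constraint_error = 'Unique constraint failed on the fields: (`token`)'

assert DEADLOCK_PATTERN.search(deadlock_error) is not None   # deadlock -> retriable
assert DEADLOCK_PATTERN.search(constraint_error) is None     # not retriable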
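Finally, a self-contained sketch of the retry shape that the repeated _handle_db_exception_retriable call sites converge on. The enclosing `for i in range(...)` loop sits outside the hunks above, so its exact form is assumed here, and do_batch_write / is_retriable are hypothetical stand-ins for the Prisma batch transaction and PrismaDBExceptionHandler.is_database_retriable_exception.

import asyncio
import random


async def commit_with_retries(do_batch_write, is_retriable, n_retry_times: int = 3) -> None:
    for i in range(n_retry_times + 1):
        try:
            await do_batch_write()  # sorted batch update inside a transaction
            break  # success, mirroring the `break` after each batcher block in the diff
        except Exception as e:
            if is_retriable(e) and i < n_retry_times:
                # Jittered exponential backoff: deadlocked transactions are aborted at
                # roughly the same moment, so identical sleeps would make them collide again.
                await asyncio.sleep(random.uniform(2**i, 2 ** (i + 1)))
                continue
            raise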