Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,72 @@
class EnterpriseCallbackControls:
@staticmethod
def is_callback_disabled_dynamically(
callback: litellm.CALLBACK_TYPES,
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams
) -> bool:
"""
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.

Args:
callback: The callback to check (can be string, CustomLogger instance, or callable)
litellm_params: Parameters containing proxy server request info

Returns:
bool: True if the callback should be disabled, False otherwise
"""
from litellm.litellm_core_utils.custom_logger_registry import (
CustomLoggerRegistry,
callback: litellm.CALLBACK_TYPES,
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams,
) -> bool:
"""
Check if a callback is disabled via the x-litellm-disable-callbacks header or via `litellm_disabled_callbacks` in standard_callback_dynamic_params.

Args:
callback: The callback to check (can be string, CustomLogger instance, or callable)
litellm_params: Parameters containing proxy server request info

Returns:
bool: True if the callback should be disabled, False otherwise
"""
from litellm.litellm_core_utils.custom_logger_registry import (
CustomLoggerRegistry,
)

try:
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(
litellm_params, standard_callback_dynamic_params
)
verbose_logger.debug(
f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}"
)
verbose_logger.debug(
f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}"
)
if disabled_callbacks is not None:
#########################################################
# premium user check
#########################################################
if not EnterpriseCallbackControls._premium_user_check():
return False
#########################################################
if isinstance(callback, str):
if callback.lower() in disabled_callbacks:
verbose_logger.debug(
f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
)
return True
elif isinstance(callback, CustomLogger):
# get the string name of the callback
callback_str = (
CustomLoggerRegistry.get_callback_str_from_class_type(
callback.__class__
)
)
if (
callback_str is not None
and callback_str.lower() in disabled_callbacks
):
verbose_logger.debug(
f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}"
)
return True
return False
except Exception as e:
verbose_logger.debug(f"Error checking disabled callbacks header: {str(e)}")
return False

try:
disabled_callbacks = EnterpriseCallbackControls.get_disabled_callbacks(litellm_params, standard_callback_dynamic_params)
verbose_logger.debug(f"Dynamically disabled callbacks from {X_LITELLM_DISABLE_CALLBACKS}: {disabled_callbacks}")
verbose_logger.debug(f"Checking if {callback} is disabled via headers. Disable callbacks from headers: {disabled_callbacks}")
if disabled_callbacks is not None:
#########################################################
# premium user check
#########################################################
if not EnterpriseCallbackControls._premium_user_check():
return False
#########################################################
if isinstance(callback, str):
if callback.lower() in disabled_callbacks:
verbose_logger.debug(f"Not logging to {callback} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
return True
elif isinstance(callback, CustomLogger):
# get the string name of the callback
callback_str = CustomLoggerRegistry.get_callback_str_from_class_type(callback.__class__)
if callback_str is not None and callback_str.lower() in disabled_callbacks:
verbose_logger.debug(f"Not logging to {callback_str} because it is disabled via {X_LITELLM_DISABLE_CALLBACKS}")
return True
return False
except Exception as e:
verbose_logger.debug(
f"Error checking disabled callbacks header: {str(e)}"
)
return False
@staticmethod
def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_params: StandardCallbackDynamicParams) -> Optional[List[str]]:
def get_disabled_callbacks(
litellm_params: dict,
standard_callback_dynamic_params: StandardCallbackDynamicParams,
) -> Optional[List[str]]:
"""
Get the disabled callbacks from the standard callback dynamic params.
"""
Expand All @@ -71,22 +90,31 @@ def get_disabled_callbacks(litellm_params: dict, standard_callback_dynamic_param
request_headers = get_proxy_server_request_headers(litellm_params)
disabled_callbacks = request_headers.get(X_LITELLM_DISABLE_CALLBACKS, None)
if disabled_callbacks is not None:
disabled_callbacks = set([cb.strip().lower() for cb in disabled_callbacks.split(",")])
disabled_callbacks = set(
[cb.strip().lower() for cb in disabled_callbacks.split(",")]
)
return list(disabled_callbacks)


#########################################################
# check if disabled via request body
#########################################################
if standard_callback_dynamic_params.get("litellm_disabled_callbacks", None) is not None:
return standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)

if (
standard_callback_dynamic_params.get("litellm_disabled_callbacks", None)
is not None
):
return standard_callback_dynamic_params.get(
"litellm_disabled_callbacks", None
)

return None

@staticmethod
def _premium_user_check():
from litellm.proxy.proxy_server import premium_user

if premium_user:
return True
verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}")
return False
verbose_logger.warning(
f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}"
)
return False
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,16 @@ async def send_key_created_email(
)

# Check if API key should be included in email
include_api_key = get_secret_bool(secret_name="EMAIL_INCLUDE_API_KEY", default_value=True)
include_api_key = get_secret_bool(
secret_name="EMAIL_INCLUDE_API_KEY", default_value=True
)
if include_api_key is None:
include_api_key = True # Default to True if not set
key_token_display = send_key_created_email_event.virtual_key if include_api_key else "[Key hidden for security - retrieve from dashboard]"
key_token_display = (
send_key_created_email_event.virtual_key
if include_api_key
else "[Key hidden for security - retrieve from dashboard]"
)

email_html_content = KEY_CREATED_EMAIL_TEMPLATE.format(
email_logo_url=email_params.logo_url,
Expand Down Expand Up @@ -131,10 +137,16 @@ async def send_key_rotated_email(
)

# Check if API key should be included in email
include_api_key = get_secret_bool(secret_name="EMAIL_INCLUDE_API_KEY", default_value=True)
include_api_key = get_secret_bool(
secret_name="EMAIL_INCLUDE_API_KEY", default_value=True
)
if include_api_key is None:
include_api_key = True # Default to True if not set
key_token_display = send_key_rotated_email_event.virtual_key if include_api_key else "[Key hidden for security - retrieve from dashboard]"
key_token_display = (
send_key_rotated_email_event.virtual_key
if include_api_key
else "[Key hidden for security - retrieve from dashboard]"
)

email_html_content = KEY_ROTATED_EMAIL_TEMPLATE.format(
email_logo_url=email_params.logo_url,
Expand Down Expand Up @@ -184,9 +196,14 @@ async def _get_email_params(
unused_custom_fields = []

# Function to safely get custom value or default
def get_custom_or_default(custom_value: Optional[str], default_value: str, field_name: str) -> str:
if custom_value is not None: # Only check premium if trying to use custom value
def get_custom_or_default(
custom_value: Optional[str], default_value: str, field_name: str
) -> str:
if (
custom_value is not None
): # Only check premium if trying to use custom value
from litellm.proxy.proxy_server import premium_user

if premium_user is not True:
unused_custom_fields.append(field_name)
return default_value
Expand All @@ -195,34 +212,44 @@ def get_custom_or_default(custom_value: Optional[str], default_value: str, field

# Get parameters, falling back to defaults if custom values aren't allowed
logo_url = get_custom_or_default(custom_logo, LITELLM_LOGO_URL, "logo URL")
support_contact = get_custom_or_default(custom_support, self.DEFAULT_SUPPORT_EMAIL, "support contact")
base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000") # Not a premium feature
signature = get_custom_or_default(custom_signature, EMAIL_FOOTER, "email signature")
support_contact = get_custom_or_default(
custom_support, self.DEFAULT_SUPPORT_EMAIL, "support contact"
)
base_url = os.getenv(
"PROXY_BASE_URL", "http://0.0.0.0:4000"
) # Not a premium feature
signature = get_custom_or_default(
custom_signature, EMAIL_FOOTER, "email signature"
)

# Get custom subject template based on email event type
if email_event == EmailEvent.new_user_invitation:
subject_template = get_custom_or_default(
custom_subject_invitation,
self.DEFAULT_SUBJECT_TEMPLATES[EmailEvent.new_user_invitation],
"invitation subject template"
"invitation subject template",
)
elif email_event == EmailEvent.virtual_key_created:
subject_template = get_custom_or_default(
custom_subject_key_created,
self.DEFAULT_SUBJECT_TEMPLATES[EmailEvent.virtual_key_created],
"key created subject template"
"key created subject template",
)
elif email_event == EmailEvent.virtual_key_rotated:
custom_subject_key_rotated = os.getenv("EMAIL_SUBJECT_KEY_ROTATED", None)
subject_template = get_custom_or_default(
custom_subject_key_rotated,
self.DEFAULT_SUBJECT_TEMPLATES[EmailEvent.virtual_key_rotated],
"key rotated subject template"
"key rotated subject template",
)
else:
subject_template = "LiteLLM: {event_message}"

subject = subject_template.format(event_message=event_message) if event_message else "LiteLLM Notification"
subject = (
subject_template.format(event_message=event_message)
if event_message
else "LiteLLM Notification"
)

recipient_email: Optional[
str
Expand All @@ -246,9 +273,7 @@ def get_custom_or_default(custom_value: Optional[str], default_value: str, field
"This is an Enterprise feature. To use custom email fields, please upgrade to LiteLLM Enterprise. "
"Schedule a meeting here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
)
verbose_proxy_logger.warning(
f"{warning_msg}"
)
verbose_proxy_logger.warning(f"{warning_msg}")

return EmailParams(
logo_url=logo_url,
Expand Down Expand Up @@ -295,76 +320,91 @@ async def _get_invitation_link(self, user_id: Optional[str], base_url: str) -> s
if not user_id:
verbose_proxy_logger.debug("No user_id provided for invitation link")
return base_url

if not await self._is_prisma_client_available():
return base_url

# Wait for any concurrent invitation creation to complete
await self._wait_for_invitation_creation()

# Get or create invitation
invitation = await self._get_or_create_invitation(user_id)
if not invitation:
verbose_proxy_logger.warning(f"Failed to get/create invitation for user_id: {user_id}")
verbose_proxy_logger.warning(
f"Failed to get/create invitation for user_id: {user_id}"
)
return base_url

return self._construct_invitation_link(invitation.id, base_url)

async def _is_prisma_client_available(self) -> bool:
"""Check if Prisma client is available"""
from litellm.proxy.proxy_server import prisma_client

if prisma_client is None:
verbose_proxy_logger.debug("Prisma client not found. Unable to lookup invitation")
verbose_proxy_logger.debug(
"Prisma client not found. Unable to lookup invitation"
)
return False
return True

async def _wait_for_invitation_creation(self) -> None:
"""
Wait for any concurrent invitation creation to complete.

The UI calls /invitation/new to generate the invitation link.
We wait to ensure any pending invitation creation is completed.
"""
import asyncio

await asyncio.sleep(10)

async def _get_or_create_invitation(self, user_id: str):
"""
Get existing invitation or create a new one for the user

Returns:
Invitation object with id attribute, or None if failed
"""
from litellm.proxy.management_helpers.user_invitation import (
create_invitation_for_user,
)
from litellm.proxy.proxy_server import prisma_client

if prisma_client is None:
verbose_proxy_logger.error("Prisma client is None in _get_or_create_invitation")
verbose_proxy_logger.error(
"Prisma client is None in _get_or_create_invitation"
)
return None

try:
# Try to get existing invitation
existing_invitations = await prisma_client.db.litellm_invitationlink.find_many(
where={"user_id": user_id},
order={"created_at": "desc"},
existing_invitations = (
await prisma_client.db.litellm_invitationlink.find_many(
where={"user_id": user_id},
order={"created_at": "desc"},
)
)

if existing_invitations and len(existing_invitations) > 0:
verbose_proxy_logger.debug(f"Found existing invitation for user_id: {user_id}")
verbose_proxy_logger.debug(
f"Found existing invitation for user_id: {user_id}"
)
return existing_invitations[0]

# Create new invitation if none exists
verbose_proxy_logger.debug(f"Creating new invitation for user_id: {user_id}")
verbose_proxy_logger.debug(
f"Creating new invitation for user_id: {user_id}"
)
return await create_invitation_for_user(
data=InvitationNew(user_id=user_id),
user_api_key_dict=UserAPIKeyAuth(user_id=user_id),
)

except Exception as e:
verbose_proxy_logger.error(f"Error getting/creating invitation for user_id {user_id}: {e}")
verbose_proxy_logger.error(
f"Error getting/creating invitation for user_id {user_id}: {e}"
)
return None

def _construct_invitation_link(self, invitation_id: str, base_url: str) -> str:
Expand Down
Loading
Loading