62 changes: 47 additions & 15 deletions server.py
@@ -81,6 +81,7 @@ def format(self, record):
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")

# Get preferred provider (default to openai)
PREFERRED_PROVIDER = os.environ.get("PREFERRED_PROVIDER", "openai").lower()
@@ -112,6 +113,12 @@ def format(self, record):
"gemini-2.0-flash"
]

# List of DeepSeek models
DEEPSEEK_MODELS = [
"deepseek-reasoner",
"deepseek-chat"
]
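For reference, a condensed sketch of what the new DEEPSEEK_API_KEY variable and DEEPSEEK_MODELS list are used for, based on the prefixing branch added to validate_model_field further down. The helper name below is hypothetical and not part of the diff.

import os

DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")   # new env var in this diff
DEEPSEEK_MODELS = ["deepseek-reasoner", "deepseek-chat"]

def add_deepseek_prefix(name: str) -> str:
    # Hypothetical helper mirroring the new prefixing branch: a bare DeepSeek
    # model name gains a "deepseek/" prefix so LiteLLM can route it.
    if name in DEEPSEEK_MODELS and not name.startswith("deepseek/"):
        return f"deepseek/{name}"
    return name

# add_deepseek_prefix("deepseek-chat")           -> "deepseek/deepseek-chat"
# add_deepseek_prefix("deepseek/deepseek-chat")  -> unchanged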

# Helper function to clean schema for Gemini
def clean_gemini_schema(schema: Any) -> Any:
"""Recursively removes unsupported fields from a JSON schema for Gemini."""
@@ -202,26 +209,32 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
clean_v = clean_v[7:]
elif clean_v.startswith('gemini/'):
clean_v = clean_v[7:]
elif clean_v.startswith('deepseek/'):
clean_v = clean_v[9:]

# --- Mapping Logic --- START ---
mapped = False
# Map Haiku to SMALL_MODEL based on provider preference
if 'haiku' in clean_v.lower():
if PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
# Check if SMALL_MODEL already has a provider prefix
if "/" in SMALL_MODEL:
new_model = SMALL_MODEL
elif PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
new_model = f"gemini/{SMALL_MODEL}"
mapped = True
else:
new_model = f"openai/{SMALL_MODEL}"
mapped = True
mapped = True

# Map Sonnet to BIG_MODEL based on provider preference
elif 'sonnet' in clean_v.lower():
if PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
# Check if BIG_MODEL already has a provider prefix
if "/" in BIG_MODEL:
new_model = BIG_MODEL
elif PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
new_model = f"gemini/{BIG_MODEL}"
mapped = True
else:
new_model = f"openai/{BIG_MODEL}"
mapped = True
mapped = True

# Add prefixes to non-mapped models if they match known lists
elif not mapped:
@@ -231,13 +244,16 @@ def validate_model_field(cls, v, info): # Renamed to avoid conflict
elif clean_v in OPENAI_MODELS and not v.startswith('openai/'):
new_model = f"openai/{clean_v}"
mapped = True # Technically mapped to add prefix
elif clean_v in DEEPSEEK_MODELS and not v.startswith('deepseek/'):
new_model = f"deepseek/{clean_v}"
mapped = True # Technically mapped to add prefix
# --- Mapping Logic --- END ---

if mapped:
logger.debug(f"📌 MODEL MAPPING: '{original_model}' ➡️ '{new_model}'")
else:
# If no mapping occurred and no prefix exists, log warning or decide default
if not v.startswith(('openai/', 'gemini/', 'anthropic/')):
if not v.startswith(('openai/', 'gemini/', 'anthropic/', 'deepseek/')):
logger.warning(f"⚠️ No prefix or mapping rule for model: '{original_model}'. Using as is.")
new_model = v # Ensure we return the original if no rule applied
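A standalone sketch of the Haiku branch above, with illustrative values that are not part of this diff (PREFERRED_PROVIDER="openai", SMALL_MODEL="deepseek/deepseek-chat"). The new "/" check lets an already-prefixed SMALL_MODEL or BIG_MODEL pass through unchanged instead of being forced under openai/.

GEMINI_MODELS = ["gemini-2.0-flash"]  # abbreviated from the list above

def map_haiku(requested: str, small_model: str, preferred_provider: str) -> str:
    # Condensed version of the validator's Haiku branch, for illustration only.
    if "haiku" in requested.lower():
        if "/" in small_model:            # e.g. "deepseek/deepseek-chat"
            return small_model            # keep the explicit provider prefix
        if preferred_provider == "google" and small_model in GEMINI_MODELS:
            return f"gemini/{small_model}"
        return f"openai/{small_model}"
    return requested

# map_haiku("claude-3-5-haiku-20241022", "deepseek/deepseek-chat", "openai")
#   -> "deepseek/deepseek-chat"
# map_haiku("claude-3-5-haiku-20241022", "gpt-4o-mini", "openai")
#   -> "openai/gpt-4o-mini"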

@@ -275,26 +291,32 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
clean_v = clean_v[7:]
elif clean_v.startswith('gemini/'):
clean_v = clean_v[7:]
elif clean_v.startswith('deepseek/'):
clean_v = clean_v[9:]

# --- Mapping Logic --- START ---
mapped = False
# Map Haiku to SMALL_MODEL based on provider preference
if 'haiku' in clean_v.lower():
if PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
# Check if SMALL_MODEL already has a provider prefix
if "/" in SMALL_MODEL:
new_model = SMALL_MODEL
elif PREFERRED_PROVIDER == "google" and SMALL_MODEL in GEMINI_MODELS:
new_model = f"gemini/{SMALL_MODEL}"
mapped = True
else:
new_model = f"openai/{SMALL_MODEL}"
mapped = True
mapped = True

# Map Sonnet to BIG_MODEL based on provider preference
elif 'sonnet' in clean_v.lower():
if PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
# Check if BIG_MODEL already has a provider prefix
if "/" in BIG_MODEL:
new_model = BIG_MODEL
elif PREFERRED_PROVIDER == "google" and BIG_MODEL in GEMINI_MODELS:
new_model = f"gemini/{BIG_MODEL}"
mapped = True
else:
new_model = f"openai/{BIG_MODEL}"
mapped = True
mapped = True

# Add prefixes to non-mapped models if they match known lists
elif not mapped:
@@ -304,12 +326,15 @@ def validate_model_token_count(cls, v, info): # Renamed to avoid conflict
elif clean_v in OPENAI_MODELS and not v.startswith('openai/'):
new_model = f"openai/{clean_v}"
mapped = True # Technically mapped to add prefix
elif clean_v in DEEPSEEK_MODELS and not v.startswith('deepseek/'):
new_model = f"deepseek/{clean_v}"
mapped = True # Technically mapped to add prefix
# --- Mapping Logic --- END ---

if mapped:
logger.debug(f"📌 TOKEN COUNT MAPPING: '{original_model}' ➡️ '{new_model}'")
else:
if not v.startswith(('openai/', 'gemini/', 'anthropic/')):
if not v.startswith(('openai/', 'gemini/', 'anthropic/', 'deepseek/')):
logger.warning(f"⚠️ No prefix or mapping rule for token count model: '{original_model}'. Using as is.")
new_model = v # Ensure we return the original if no rule applied

@@ -1097,6 +1122,8 @@ async def create_message(
clean_model = clean_model[len("anthropic/"):]
elif clean_model.startswith("openai/"):
clean_model = clean_model[len("openai/"):]
elif clean_model.startswith("deepseek/"):
clean_model = clean_model[len("deepseek/"):]

logger.debug(f"📊 PROCESSING REQUEST: Model={request.model}, Stream={request.stream}")

@@ -1110,6 +1137,9 @@
elif request.model.startswith("gemini/"):
litellm_request["api_key"] = GEMINI_API_KEY
logger.debug(f"Using Gemini API key for model: {request.model}")
elif request.model.startswith("deepseek/"):
litellm_request["api_key"] = DEEPSEEK_API_KEY
logger.debug(f"Using DeepSeek API key for model: {request.model}")
else:
litellm_request["api_key"] = ANTHROPIC_API_KEY
logger.debug(f"Using Anthropic API key for model: {request.model}")
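Illustrative only: how the api_key selected above is expected to reach DeepSeek through LiteLLM. The request dict is a stripped-down stand-in for the one the server actually builds, and per-call api_key passing is assumed to behave as in the existing openai/ and gemini/ branches.

import os
import litellm

litellm_request = {
    "model": "deepseek/deepseek-chat",                  # prefix chosen by the mapping logic
    "messages": [{"role": "user", "content": "ping"}],  # stand-in payload
    "max_tokens": 128,
    "api_key": os.environ.get("DEEPSEEK_API_KEY"),      # key selected by the new branch
}
response = litellm.completion(**litellm_request)
print(response.choices[0].message.content)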
@@ -1354,6 +1384,8 @@ async def count_tokens(
clean_model = clean_model[len("anthropic/"):]
elif clean_model.startswith("openai/"):
clean_model = clean_model[len("openai/"):]
elif clean_model.startswith("deepseek/"):
clean_model = clean_model[len("deepseek/"):]

# Convert the messages to a format LiteLLM can understand
converted_request = convert_anthropic_to_litellm(
@@ -1462,4 +1494,4 @@ def log_request_beautifully(method, path, claude_model, openai_model, num_messag
sys.exit(0)

# Configure uvicorn to run with minimal logs
uvicorn.run(app, host="0.0.0.0", port=8082, log_level="error")
uvicorn.run(app, host="0.0.0.0", port=8082, log_level="error")