
Commit fe60a38

Litellm dev 01 2025 p4 (#7776)
* fix(gemini/): support gemini 'frequency_penalty' and 'presence_penalty' (Closes #7748)
* feat(proxy_server.py): new env var to disable prisma health check on startup
* test: fix test
1 parent 8353caa commit fe60a38

File tree

litellm/llms/gemini/chat/transformation.py
litellm/proxy/proxy_server.py
tests/llm_translation/test_optional_params.py
tests/proxy_unit_tests/test_proxy_utils.py

4 files changed: +47 -22 lines

litellm/llms/gemini/chat/transformation.py

Lines changed: 2 additions & 8 deletions
@@ -12,9 +12,7 @@
 from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig


-class GoogleAIStudioGeminiConfig(
-    VertexGeminiConfig
-):  # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
+class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
     """
     Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig

@@ -82,6 +80,7 @@ def get_supported_openai_params(self, model: str) -> List[str]:
             "n",
             "stop",
             "logprobs",
+            "frequency_penalty",
         ]

     def map_openai_params(
@@ -92,11 +91,6 @@ def map_openai_params(
         drop_params: bool,
     ) -> Dict:

-        # drop frequency_penalty and presence_penalty
-        if "frequency_penalty" in non_default_params:
-            del non_default_params["frequency_penalty"]
-        if "presence_penalty" in non_default_params:
-            del non_default_params["presence_penalty"]
         if litellm.vertex_ai_safety_settings is not None:
             optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
         return super().map_openai_params(
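
With the override removed, GoogleAIStudioGeminiConfig inherits the penalty handling from VertexGeminiConfig, and frequency_penalty is now advertised as a supported OpenAI param (per the commit message, presence_penalty follows the same path). A minimal usage sketch; the model name and prompt are illustrative, not part of this commit, and it assumes GEMINI_API_KEY is set in the environment:

import litellm

# After this change the penalty params are forwarded for Gemini (AI Studio)
# instead of being silently dropped before the request is built.
response = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "Write a short haiku."}],
    frequency_penalty=0.5,
)
print(response.choices[0].message.content)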

litellm/proxy/proxy_server.py

Lines changed: 5 additions & 1 deletion
@@ -3233,7 +3233,11 @@ async def _setup_prisma_client(
         )  # set the spend logs row count in proxy state. Don't block execution

         # run a health check to ensure the DB is ready
-        await prisma_client.health_check()
+        if (
+            get_secret_bool("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", False)
+            is not True
+        ):
+            await prisma_client.health_check()
         return prisma_client

     @classmethod
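
The new flag only needs to be present in the proxy's environment before startup; when it is unset or falsy, _setup_prisma_client awaits prisma_client.health_check() as before. A minimal opt-out sketch, assuming the value is read through get_secret_bool as in the hunk above (the test below sets it to "true"):

import os

# Skip the Prisma/DB health check when the proxy boots.
# Set this before launching the proxy process; "true" disables the check.
os.environ["DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP"] = "true"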

tests/llm_translation/test_optional_params.py

Lines changed: 8 additions & 13 deletions
@@ -217,19 +217,6 @@ def test_databricks_optional_params():
     assert "user" not in optional_params


-def test_gemini_optional_params():
-    litellm.drop_params = True
-    optional_params = get_optional_params(
-        model="",
-        custom_llm_provider="gemini",
-        max_tokens=10,
-        frequency_penalty=10,
-    )
-    print(f"optional_params: {optional_params}")
-    assert len(optional_params) == 1
-    assert "frequency_penalty" not in optional_params
-
-
 def test_azure_ai_mistral_optional_params():
     litellm.drop_params = True
     optional_params = get_optional_params(
@@ -1063,6 +1050,7 @@ def test_is_vertex_anthropic_model():
         is False
     )

+
 def test_groq_response_format_json_schema():
     optional_params = get_optional_params(
         model="llama-3.1-70b-versatile",
@@ -1072,3 +1060,10 @@ def test_groq_response_format_json_schema():
     assert optional_params is not None
     assert "response_format" in optional_params
     assert optional_params["response_format"]["type"] == "json_object"
+
+
+def test_gemini_frequency_penalty():
+    optional_params = get_optional_params(
+        model="gemini-1.5-flash", custom_llm_provider="gemini", frequency_penalty=0.5
+    )
+    assert optional_params["frequency_penalty"] == 0.5

tests/proxy_unit_tests/test_proxy_utils.py

Lines changed: 32 additions & 0 deletions
@@ -1447,3 +1447,35 @@ def test_update_key_budget_with_temp_budget_increase():
         },
     )
     assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200
+
+
+from unittest.mock import MagicMock, AsyncMock
+
+
+@pytest.mark.asyncio
+async def test_health_check_not_called_when_disabled(monkeypatch):
+    from litellm.proxy.proxy_server import ProxyStartupEvent
+
+    # Mock environment variable
+    monkeypatch.setenv("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", "true")
+
+    # Create mock prisma client
+    mock_prisma = MagicMock()
+    mock_prisma.connect = AsyncMock()
+    mock_prisma.health_check = AsyncMock()
+    mock_prisma.check_view_exists = AsyncMock()
+    mock_prisma._set_spend_logs_row_count_in_proxy_state = AsyncMock()
+    # Mock PrismaClient constructor
+    monkeypatch.setattr(
+        "litellm.proxy.proxy_server.PrismaClient", lambda **kwargs: mock_prisma
+    )
+
+    # Call the setup function
+    await ProxyStartupEvent._setup_prisma_client(
+        database_url="mock_url",
+        proxy_logging_obj=MagicMock(),
+        user_api_key_cache=MagicMock(),
+    )
+
+    # Verify health check wasn't called
+    mock_prisma.health_check.assert_not_called()
