Skip to content

Commit

Permalink
fix(gemini/): support gemini 'frequency_penalty' and 'presence_penalty'
Browse files Browse the repository at this point in the history
Closes #7748
  • Loading branch information
krrishdholakia committed Jan 15, 2025
1 parent 35919d9 commit ad80b0d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
10 changes: 2 additions & 8 deletions litellm/llms/gemini/chat/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig


class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
): # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
"""
Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
Expand Down Expand Up @@ -82,6 +80,7 @@ def get_supported_openai_params(self, model: str) -> List[str]:
"n",
"stop",
"logprobs",
"frequency_penalty",
]

def map_openai_params(
Expand All @@ -92,11 +91,6 @@ def map_openai_params(
drop_params: bool,
) -> Dict:

# drop frequency_penalty and presence_penalty
if "frequency_penalty" in non_default_params:
del non_default_params["frequency_penalty"]
if "presence_penalty" in non_default_params:
del non_default_params["presence_penalty"]
if litellm.vertex_ai_safety_settings is not None:
optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
return super().map_openai_params(
Expand Down
8 changes: 8 additions & 0 deletions tests/llm_translation/test_optional_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,7 @@ def test_is_vertex_anthropic_model():
is False
)


def test_groq_response_format_json_schema():
optional_params = get_optional_params(
model="llama-3.1-70b-versatile",
Expand All @@ -1072,3 +1073,10 @@ def test_groq_response_format_json_schema():
assert optional_params is not None
assert "response_format" in optional_params
assert optional_params["response_format"]["type"] == "json_object"


def test_gemini_frequency_penalty():
optional_params = get_optional_params(
model="gemini-1.5-flash", custom_llm_provider="gemini", frequency_penalty=0.5
)
assert optional_params["frequency_penalty"] == 0.5

0 comments on commit ad80b0d

Please sign in to comment.