diff --git a/litellm/llms/gemini/chat/transformation.py b/litellm/llms/gemini/chat/transformation.py index fb891ae0ef04..313bb99af74e 100644 --- a/litellm/llms/gemini/chat/transformation.py +++ b/litellm/llms/gemini/chat/transformation.py @@ -12,9 +12,7 @@ from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig -class GoogleAIStudioGeminiConfig( - VertexGeminiConfig -): # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported +class GoogleAIStudioGeminiConfig(VertexGeminiConfig): """ Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig @@ -82,6 +80,7 @@ def get_supported_openai_params(self, model: str) -> List[str]: "n", "stop", "logprobs", + "frequency_penalty", ] def map_openai_params( @@ -92,11 +91,6 @@ def map_openai_params( drop_params: bool, ) -> Dict: - # drop frequency_penalty and presence_penalty - if "frequency_penalty" in non_default_params: - del non_default_params["frequency_penalty"] - if "presence_penalty" in non_default_params: - del non_default_params["presence_penalty"] if litellm.vertex_ai_safety_settings is not None: optional_params["safety_settings"] = litellm.vertex_ai_safety_settings return super().map_openai_params( diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py index dcb38f62415d..ca9ba960db6d 100644 --- a/tests/llm_translation/test_optional_params.py +++ b/tests/llm_translation/test_optional_params.py @@ -1063,6 +1063,7 @@ def test_is_vertex_anthropic_model(): is False ) + def test_groq_response_format_json_schema(): optional_params = get_optional_params( model="llama-3.1-70b-versatile", @@ -1072,3 +1073,10 @@ def test_groq_response_format_json_schema(): assert optional_params is not None assert "response_format" in optional_params assert optional_params["response_format"]["type"] == "json_object" + + +def test_gemini_frequency_penalty(): + optional_params = get_optional_params( + model="gemini-1.5-flash", custom_llm_provider="gemini", frequency_penalty=0.5 + ) + assert optional_params["frequency_penalty"] == 0.5