Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/google/adk/evaluation/eval_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import TypeAlias

from .common import EvalBaseModel
Expand Down Expand Up @@ -71,8 +72,10 @@ class JudgeModelOptions(EvalBaseModel):
),
)

judge_model_config: Optional[genai_types.GenerateContentConfig] = Field(
default=genai_types.GenerateContentConfig,
judge_model_config: SkipJsonSchema[
Optional[genai_types.GenerateContentConfig]
] = Field(
default=None,
description="The configuration for the judge model.",
)

Expand Down
5 changes: 4 additions & 1 deletion src/google/adk/evaluation/hallucinations_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,10 @@ def __init__(self, eval_metric: EvalMetric):
self.segmenter_prompt = _HALLUCINATIONS_V1_SEGMENTER_PROMPT
self.sentence_validator_prompt = _HALLUCINATIONS_V1_VALIDATOR_PROMPT
self._model = self._judge_model_options.judge_model
self._model_config = self._judge_model_options.judge_model_config
self._model_config = (
self._judge_model_options.judge_model_config
or genai_types.GenerateContentConfig()
)

def _setup_auto_rater(self) -> BaseLlm:
model_id = self._judge_model_options.judge_model
Expand Down
3 changes: 2 additions & 1 deletion src/google/adk/evaluation/llm_as_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ async def evaluate_invocations(
role="user",
)
],
config=self._judge_model_options.judge_model_config,
config=self._judge_model_options.judge_model_config
or genai_types.GenerateContentConfig(),
)
add_default_retry_options_if_not_present(llm_request)
num_samples = self._judge_model_options.num_samples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
self,
project: str,
location: str,
connection_template_override: Optional[str] = None,
integration: Optional[str] = None,
triggers: Optional[List[str]] = None,
connection: Optional[str] = None,
Expand All @@ -104,6 +105,8 @@ def __init__(
Args:
project: The GCP project ID.
location: The GCP location.
connection_template_override: Overrides `ExecuteConnection` default
integration name.
integration: The integration name.
triggers: The list of trigger names in the integration.
connection: The connection name.
Expand All @@ -129,6 +132,7 @@ def __init__(
super().__init__(tool_filter=tool_filter)
self.project = project
self.location = location
self._connection_template_override = connection_template_override
self._integration = integration
self._triggers = triggers
self._connection = connection
Expand All @@ -142,6 +146,7 @@ def __init__(
integration_client = IntegrationClient(
project,
location,
connection_template_override,
integration,
triggers,
connection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
self,
project: str,
location: str,
connection_template_override: Optional[str] = None,
integration: Optional[str] = None,
triggers: Optional[List[str]] = None,
connection: Optional[str] = None,
Expand All @@ -50,6 +51,8 @@ def __init__(
Args:
project: The Google Cloud project ID.
location: The Google Cloud location (e.g., us-central1).
connection_template_override: Overrides `ExecuteConnection` default
integration name.
integration: The integration name.
triggers: The list of trigger IDs for the integration.
connection: The connection name.
Expand All @@ -62,6 +65,7 @@ def __init__(
"""
self.project = project
self.location = location
self.connection_template_override = connection_template_override
self.integration = integration
self.triggers = triggers
self.connection = connection
Expand Down Expand Up @@ -130,7 +134,7 @@ def get_openapi_spec_for_connection(self, tool_name="", tool_instructions=""):
Exception: For any other unexpected errors.
"""
# Application Integration needs to be provisioned in the same region as connection and an integration with name "ExecuteConnection" and trigger "api_trigger/ExecuteConnection" should be created as per the documentation.
integration_name = "ExecuteConnection"
integration_name = self.connection_template_override or "ExecuteConnection"
connections_client = ConnectionsClient(
self.project,
self.location,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,15 @@ async def test_initialization_with_integration_and_trigger(
project, location, integration=integration_name, triggers=triggers
)
mock_integration_client.assert_called_once_with(
project, location, integration_name, triggers, None, None, None, None
project,
location,
None,
integration_name,
triggers,
None,
None,
None,
None,
)
mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once()
mock_connections_client.assert_not_called()
Expand All @@ -218,6 +226,7 @@ async def test_initialization_with_integration_and_list_of_triggers(
mock_integration_client.assert_called_once_with(
project,
location,
None,
integration_name,
triggers,
None,
Expand Down Expand Up @@ -247,7 +256,7 @@ async def test_initialization_with_integration_and_empty_trigger_list(
project, location, integration=integration_name
)
mock_integration_client.assert_called_once_with(
project, location, integration_name, None, None, None, None, None
project, location, None, integration_name, None, None, None, None, None
)
mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once()
mock_connections_client.assert_not_called()
Expand Down Expand Up @@ -287,6 +296,7 @@ async def test_initialization_with_connection_and_entity_operations(
location,
None,
None,
None,
connection_name,
entity_operations_list,
None,
Expand Down Expand Up @@ -335,7 +345,15 @@ async def test_initialization_with_connection_and_actions(
tool_instructions=tool_instructions,
)
mock_integration_client.assert_called_once_with(
project, location, None, None, connection_name, None, actions_list, None
project,
location,
None,
None,
None,
connection_name,
None,
actions_list,
None,
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
Expand Down Expand Up @@ -414,6 +432,7 @@ def test_initialization_with_service_account_credentials(
mock_integration_client.assert_called_once_with(
project,
location,
None,
integration_name,
triggers,
None,
Expand Down Expand Up @@ -441,7 +460,15 @@ def test_initialization_without_explicit_service_account_credentials(
project, location, integration=integration_name, triggers=triggers
)
mock_integration_client.assert_called_once_with(
project, location, integration_name, triggers, None, None, None, None
project,
location,
None,
integration_name,
triggers,
None,
None,
None,
None,
)
mock_openapi_toolset.assert_called_once()
_, kwargs = mock_openapi_toolset.call_args
Expand Down Expand Up @@ -542,7 +569,15 @@ async def test_init_with_connection_and_custom_auth(
auth_credential=auth_credential,
)
mock_integration_client.assert_called_once_with(
project, location, None, None, connection_name, None, actions_list, None
project,
location,
None,
None,
None,
connection_name,
None,
actions_list,
None,
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
Expand Down Expand Up @@ -611,7 +646,15 @@ async def test_init_with_connection_with_auth_override_disabled_and_custom_auth(
auth_credential=auth_credential,
)
mock_integration_client.assert_called_once_with(
project, location, None, None, connection_name, None, actions_list, None
project,
location,
None,
None,
None,
connection_name,
None,
actions_list,
None,
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
Expand Down
Loading