From 5b7c8c04d6e4a688c76fa517922488e3d96353a3 Mon Sep 17 00:00:00 2001 From: Joseph Pagadora Date: Fri, 2 Jan 2026 11:48:59 -0800 Subject: [PATCH] chore: Introduce MetricInfoProvider interface, and refactor metric evaluators to use this interface to provide MetricInfo Co-authored-by: Joseph Pagadora PiperOrigin-RevId: 851406110 --- src/google/adk/evaluation/eval_metrics.py | 10 + .../adk/evaluation/final_response_match_v1.py | 18 -- .../adk/evaluation/final_response_match_v2.py | 18 -- .../adk/evaluation/hallucinations_v1.py | 19 -- .../evaluation/metric_evaluator_registry.py | 30 ++- .../adk/evaluation/metric_info_providers.py | 185 +++++++++++++++++ .../adk/evaluation/response_evaluator.py | 23 -- .../rubric_based_final_response_quality_v1.py | 18 -- .../rubric_based_tool_use_quality_v1.py | 18 -- src/google/adk/evaluation/safety_evaluator.py | 18 -- .../per_turn_user_simulator_quality_v1.py | 21 -- .../adk/evaluation/trajectory_evaluator.py | 20 -- .../test_final_response_match_v1.py | 8 - .../test_final_response_match_v2.py | 10 - .../test_metric_evaluator_registry.py | 196 +++++++++++++----- .../evaluation/test_response_evaluator.py | 26 --- .../test_rubric_based_tool_use_quality_v1.py | 12 -- .../evaluation/test_safety_evaluator.py | 7 - .../evaluation/test_trajectory_evaluator.py | 10 - 19 files changed, 357 insertions(+), 310 deletions(-) create mode 100644 src/google/adk/evaluation/metric_info_providers.py diff --git a/src/google/adk/evaluation/eval_metrics.py b/src/google/adk/evaluation/eval_metrics.py index b7c544ccad..d937d7fe6d 100644 --- a/src/google/adk/evaluation/eval_metrics.py +++ b/src/google/adk/evaluation/eval_metrics.py @@ -14,6 +14,7 @@ from __future__ import annotations +import abc from enum import Enum from typing import Optional from typing import Union @@ -362,3 +363,12 @@ class MetricInfo(EvalBaseModel): metric_value_info: MetricValueInfo = Field( description="Information on the nature of values supported by the metric." ) + + +class MetricInfoProvider(abc.ABC): + """Interface for providing MetricInfo.""" + + @abc.abstractmethod + def get_metric_info(self) -> MetricInfo: + """Returns MetricInfo for a given metric.""" + raise NotImplementedError diff --git a/src/google/adk/evaluation/final_response_match_v1.py b/src/google/adk/evaluation/final_response_match_v1.py index 07365211a4..fb17fe80eb 100644 --- a/src/google/adk/evaluation/final_response_match_v1.py +++ b/src/google/adk/evaluation/final_response_match_v1.py @@ -23,10 +23,6 @@ from .eval_case import ConversationScenario from .eval_case import Invocation from .eval_metrics import EvalMetric -from .eval_metrics import Interval -from .eval_metrics import MetricInfo -from .eval_metrics import MetricValueInfo -from .eval_metrics import PrebuiltMetrics from .evaluator import EvalStatus from .evaluator import EvaluationResult from .evaluator import Evaluator @@ -42,20 +38,6 @@ class RougeEvaluator(Evaluator): def __init__(self, eval_metric: EvalMetric): self._eval_metric = eval_metric - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=PrebuiltMetrics.RESPONSE_MATCH_SCORE.value, - description=( - "This metric evaluates if the agent's final response matches a" - " golden/expected final response using Rouge_1 metric. Value range" - " for this metric is [0,1], with values closer to 1 more desirable." 
- ), - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) - @override def evaluate_invocations( self, diff --git a/src/google/adk/evaluation/final_response_match_v2.py b/src/google/adk/evaluation/final_response_match_v2.py index ea90c37487..e132f14d8d 100644 --- a/src/google/adk/evaluation/final_response_match_v2.py +++ b/src/google/adk/evaluation/final_response_match_v2.py @@ -26,11 +26,7 @@ from .eval_case import Invocation from .eval_metrics import EvalMetric from .eval_metrics import EvalStatus -from .eval_metrics import Interval from .eval_metrics import LlmAsAJudgeCriterion -from .eval_metrics import MetricInfo -from .eval_metrics import MetricValueInfo -from .eval_metrics import PrebuiltMetrics from .evaluator import EvaluationResult from .evaluator import PerInvocationResult from .llm_as_judge import AutoRaterScore @@ -154,20 +150,6 @@ def __init__( ) self._auto_rater_prompt_template = _FINAL_RESPONSE_MATCH_V2_PROMPT - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=PrebuiltMetrics.FINAL_RESPONSE_MATCH_V2.value, - description=( - "This metric evaluates if the agent's final response matches a" - " golden/expected final response using LLM as a judge. Value range" - " for this metric is [0,1], with values closer to 1 more desirable." - ), - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) - @override def format_auto_rater_prompt( self, diff --git a/src/google/adk/evaluation/hallucinations_v1.py b/src/google/adk/evaluation/hallucinations_v1.py index 8ed87d2c6c..21e03e3a3c 100644 --- a/src/google/adk/evaluation/hallucinations_v1.py +++ b/src/google/adk/evaluation/hallucinations_v1.py @@ -40,10 +40,6 @@ from .eval_case import InvocationEvents from .eval_metrics import EvalMetric from .eval_metrics import HallucinationsCriterion -from .eval_metrics import Interval -from .eval_metrics import MetricInfo -from .eval_metrics import MetricValueInfo -from .eval_metrics import PrebuiltMetrics from .evaluator import EvalStatus from .evaluator import EvaluationResult from .evaluator import Evaluator @@ -310,21 +306,6 @@ def _setup_auto_rater(self) -> BaseLlm: llm_class = llm_registry.resolve(model_id) return llm_class(model=model_id) - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=PrebuiltMetrics.HALLUCINATIONS_V1.value, - description=( - "This metric assesses whether a model response contains any false," - " contradictory, or unsupported claims using a LLM as judge. Value" - " range for this metric is [0,1], with values closer to 1 more" - " desirable." 
- ), - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) - def _create_context_for_step( self, app_details: Optional[AppDetails], diff --git a/src/google/adk/evaluation/metric_evaluator_registry.py b/src/google/adk/evaluation/metric_evaluator_registry.py index 0d0fb773ca..9e1fc6c23b 100644 --- a/src/google/adk/evaluation/metric_evaluator_registry.py +++ b/src/google/adk/evaluation/metric_evaluator_registry.py @@ -24,6 +24,14 @@ from .evaluator import Evaluator from .final_response_match_v2 import FinalResponseMatchV2Evaluator from .hallucinations_v1 import HallucinationsV1Evaluator +from .metric_info_providers import FinalResponseMatchV2EvaluatorMetricInfoProvider +from .metric_info_providers import HallucinationsV1EvaluatorMetricInfoProvider +from .metric_info_providers import PerTurnUserSimulatorQualityV1MetricInfoProvider +from .metric_info_providers import ResponseEvaluatorMetricInfoProvider +from .metric_info_providers import RubricBasedFinalResponseQualityV1EvaluatorMetricInfoProvider +from .metric_info_providers import RubricBasedToolUseV1EvaluatorMetricInfoProvider +from .metric_info_providers import SafetyEvaluatorV1MetricInfoProvider +from .metric_info_providers import TrajectoryEvaluatorMetricInfoProvider from .response_evaluator import ResponseEvaluator from .rubric_based_final_response_quality_v1 import RubricBasedFinalResponseQualityV1Evaluator from .rubric_based_tool_use_quality_v1 import RubricBasedToolUseV1Evaluator @@ -91,44 +99,44 @@ def _get_default_metric_evaluator_registry() -> MetricEvaluatorRegistry: metric_evaluator_registry = MetricEvaluatorRegistry() metric_evaluator_registry.register_evaluator( - metric_info=TrajectoryEvaluator.get_metric_info(), + metric_info=TrajectoryEvaluatorMetricInfoProvider().get_metric_info(), evaluator=TrajectoryEvaluator, ) metric_evaluator_registry.register_evaluator( - metric_info=ResponseEvaluator.get_metric_info( + metric_info=ResponseEvaluatorMetricInfoProvider( PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value - ), + ).get_metric_info(), evaluator=ResponseEvaluator, ) metric_evaluator_registry.register_evaluator( - metric_info=ResponseEvaluator.get_metric_info( + metric_info=ResponseEvaluatorMetricInfoProvider( PrebuiltMetrics.RESPONSE_MATCH_SCORE.value - ), + ).get_metric_info(), evaluator=ResponseEvaluator, ) metric_evaluator_registry.register_evaluator( - metric_info=SafetyEvaluatorV1.get_metric_info(), + metric_info=SafetyEvaluatorV1MetricInfoProvider().get_metric_info(), evaluator=SafetyEvaluatorV1, ) metric_evaluator_registry.register_evaluator( - metric_info=FinalResponseMatchV2Evaluator.get_metric_info(), + metric_info=FinalResponseMatchV2EvaluatorMetricInfoProvider().get_metric_info(), evaluator=FinalResponseMatchV2Evaluator, ) metric_evaluator_registry.register_evaluator( - metric_info=RubricBasedFinalResponseQualityV1Evaluator.get_metric_info(), + metric_info=RubricBasedFinalResponseQualityV1EvaluatorMetricInfoProvider().get_metric_info(), evaluator=RubricBasedFinalResponseQualityV1Evaluator, ) metric_evaluator_registry.register_evaluator( - metric_info=HallucinationsV1Evaluator.get_metric_info(), + metric_info=HallucinationsV1EvaluatorMetricInfoProvider().get_metric_info(), evaluator=HallucinationsV1Evaluator, ) metric_evaluator_registry.register_evaluator( - metric_info=RubricBasedToolUseV1Evaluator.get_metric_info(), + metric_info=RubricBasedToolUseV1EvaluatorMetricInfoProvider().get_metric_info(), evaluator=RubricBasedToolUseV1Evaluator, ) 
metric_evaluator_registry.register_evaluator( - metric_info=PerTurnUserSimulatorQualityV1.get_metric_info(), + metric_info=PerTurnUserSimulatorQualityV1MetricInfoProvider().get_metric_info(), evaluator=PerTurnUserSimulatorQualityV1, ) diff --git a/src/google/adk/evaluation/metric_info_providers.py b/src/google/adk/evaluation/metric_info_providers.py new file mode 100644 index 0000000000..4c625b72ef --- /dev/null +++ b/src/google/adk/evaluation/metric_info_providers.py @@ -0,0 +1,185 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .eval_metrics import Interval +from .eval_metrics import MetricInfo +from .eval_metrics import MetricInfoProvider +from .eval_metrics import MetricValueInfo +from .eval_metrics import PrebuiltMetrics + + +class TrajectoryEvaluatorMetricInfoProvider(MetricInfoProvider): + """Metric info provider for TrajectoryEvaluator.""" + + def get_metric_info(self) -> MetricInfo: + return MetricInfo( + metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value, + description=( + "This metric compares two tool call trajectories (expected vs." + " actual) for the same user interaction. It performs an exact match" + " on the tool name and arguments for each step in the trajectory." + " A score of 1.0 indicates a perfect match, while 0.0 indicates a" + " mismatch. Higher values are better." + ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), + ) + + +class ResponseEvaluatorMetricInfoProvider(MetricInfoProvider): + """Metric info provider for ResponseEvaluator.""" + + def __init__(self, metric_name: str): + self._metric_name = metric_name + + def get_metric_info(self) -> MetricInfo: + """Returns MetricInfo for the given metric name.""" + if PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value == self._metric_name: + return MetricInfo( + metric_name=PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value, + description=( + "This metric evaluates how coherent agent's response was. Value" + " range of this metric is [1,5], with values closer to 5 more" + " desirable." + ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=1.0, max_value=5.0) + ), + ) + elif PrebuiltMetrics.RESPONSE_MATCH_SCORE.value == self._metric_name: + return MetricInfo( + metric_name=PrebuiltMetrics.RESPONSE_MATCH_SCORE.value, + description=( + "This metric evaluates if the agent's final response matches a" + " golden/expected final response using Rouge_1 metric. Value" + " range for this metric is [0,1], with values closer to 1 more" + " desirable." 
+ ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), + ) + else: + raise ValueError(f"`{self._metric_name}` is not supported.") + + +class SafetyEvaluatorV1MetricInfoProvider(MetricInfoProvider): + """Metric info provider for SafetyEvaluatorV1.""" + + def get_metric_info(self) -> MetricInfo: + return MetricInfo( + metric_name=PrebuiltMetrics.SAFETY_V1.value, + description=( + "This metric evaluates the safety (harmlessness) of an Agent's" + " Response. Value range of the metric is [0, 1], with values closer" + " to 1 to be more desirable (safe)." + ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), + ) + + +class FinalResponseMatchV2EvaluatorMetricInfoProvider(MetricInfoProvider): + """Metric info provider for FinalResponseMatchV2Evaluator.""" + + def get_metric_info(self) -> MetricInfo: + return MetricInfo( + metric_name=PrebuiltMetrics.FINAL_RESPONSE_MATCH_V2.value, + description=( + "This metric evaluates if the agent's final response matches a" + " golden/expected final response using LLM as a judge. Value range" + " for this metric is [0,1], with values closer to 1 more desirable." + ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), + ) + + +class RubricBasedFinalResponseQualityV1EvaluatorMetricInfoProvider( + MetricInfoProvider +): + """Metric info provider for RubricBasedFinalResponseQualityV1Evaluator.""" + + def get_metric_info(self) -> MetricInfo: + return MetricInfo( + metric_name=PrebuiltMetrics.RUBRIC_BASED_FINAL_RESPONSE_QUALITY_V1.value, + description=( + "This metric assess if the agent's final response against a set of" + " rubrics using LLM as a judge. Value range for this metric is" + " [0,1], with values closer to 1 more desirable." + ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), + ) + + +class HallucinationsV1EvaluatorMetricInfoProvider(MetricInfoProvider): + """Metric info provider for HallucinationsV1Evaluator.""" + + def get_metric_info(self) -> MetricInfo: + return MetricInfo( + metric_name=PrebuiltMetrics.HALLUCINATIONS_V1.value, + description=( + "This metric assesses whether a model response contains any false," + " contradictory, or unsupported claims using a LLM as judge. Value" + " range for this metric is [0,1], with values closer to 1 more" + " desirable." + ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), + ) + + +class RubricBasedToolUseV1EvaluatorMetricInfoProvider(MetricInfoProvider): + """Metric info provider for RubricBasedToolUseV1Evaluator.""" + + def get_metric_info(self) -> MetricInfo: + return MetricInfo( + metric_name=PrebuiltMetrics.RUBRIC_BASED_TOOL_USE_QUALITY_V1.value, + description=( + "This metric assess if the agent's usage of tools against a set of" + " rubrics using LLM as a judge. Value range for this metric is" + " [0,1], with values closer to 1 more desirable." + ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), + ) + + +class PerTurnUserSimulatorQualityV1MetricInfoProvider(MetricInfoProvider): + """Metric info provider for PerTurnUserSimulatorQualityV1.""" + + def get_metric_info(self) -> MetricInfo: + return MetricInfo( + metric_name=PrebuiltMetrics.PER_TURN_USER_SIMULATOR_QUALITY_V1, + description=( + "This metric evaluates if the user messages generated by a " + "user simulator follow the given conversation scenario. It " + "validates each message separately. 
The resulting metric " + "computes the percentage of user messages that we mark as " + "valid. The value range for this metric is [0,1], with values " + "closer to 1 more desirable. " + ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), + ) diff --git a/src/google/adk/evaluation/response_evaluator.py b/src/google/adk/evaluation/response_evaluator.py index 3f7f309a65..3fa3754913 100644 --- a/src/google/adk/evaluation/response_evaluator.py +++ b/src/google/adk/evaluation/response_evaluator.py @@ -21,9 +21,6 @@ from .eval_case import ConversationScenario from .eval_case import Invocation from .eval_metrics import EvalMetric -from .eval_metrics import Interval -from .eval_metrics import MetricInfo -from .eval_metrics import MetricValueInfo from .eval_metrics import PrebuiltMetrics from .evaluator import EvaluationResult from .evaluator import Evaluator @@ -76,26 +73,6 @@ def __init__( self._threshold = threshold - @staticmethod - def get_metric_info(metric_name: str) -> MetricInfo: - """Returns MetricInfo for the given metric name.""" - if PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value == metric_name: - return MetricInfo( - metric_name=PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value, - description=( - "This metric evaluates how coherent agent's response was. Value" - " range of this metric is [1,5], with values closer to 5 more" - " desirable." - ), - metric_value_info=MetricValueInfo( - interval=Interval(min_value=1.0, max_value=5.0) - ), - ) - elif PrebuiltMetrics.RESPONSE_MATCH_SCORE.value == metric_name: - return RougeEvaluator.get_metric_info() - else: - raise ValueError(f"`{metric_name}` is not supported.") - @override def evaluate_invocations( self, diff --git a/src/google/adk/evaluation/rubric_based_final_response_quality_v1.py b/src/google/adk/evaluation/rubric_based_final_response_quality_v1.py index 1b4cb68197..90f02d3b35 100644 --- a/src/google/adk/evaluation/rubric_based_final_response_quality_v1.py +++ b/src/google/adk/evaluation/rubric_based_final_response_quality_v1.py @@ -24,10 +24,6 @@ from .eval_case import Invocation from .eval_case import InvocationEvents from .eval_metrics import EvalMetric -from .eval_metrics import Interval -from .eval_metrics import MetricInfo -from .eval_metrics import MetricValueInfo -from .eval_metrics import PrebuiltMetrics from .eval_metrics import RubricsBasedCriterion from .llm_as_judge_utils import get_text_from_content from .llm_as_judge_utils import get_tool_calls_and_responses_as_json_str @@ -266,20 +262,6 @@ def __init__(self, eval_metric: EvalMetric): _RUBRIC_BASED_FINAL_RESPONSE_QUALITY_V1_PROMPT ) - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=PrebuiltMetrics.RUBRIC_BASED_FINAL_RESPONSE_QUALITY_V1.value, - description=( - "This metric assess if the agent's final response against a set of" - " rubrics using LLM as a judge. Value range for this metric is" - " [0,1], with values closer to 1 more desirable." 
- ), - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) - @override def format_auto_rater_prompt( self, actual_invocation: Invocation, _: Optional[Invocation] diff --git a/src/google/adk/evaluation/rubric_based_tool_use_quality_v1.py b/src/google/adk/evaluation/rubric_based_tool_use_quality_v1.py index 40d48a7cf6..bb64124e45 100644 --- a/src/google/adk/evaluation/rubric_based_tool_use_quality_v1.py +++ b/src/google/adk/evaluation/rubric_based_tool_use_quality_v1.py @@ -23,10 +23,6 @@ from ..utils.feature_decorator import experimental from .eval_case import Invocation from .eval_metrics import EvalMetric -from .eval_metrics import Interval -from .eval_metrics import MetricInfo -from .eval_metrics import MetricValueInfo -from .eval_metrics import PrebuiltMetrics from .eval_metrics import RubricsBasedCriterion from .llm_as_judge_utils import get_text_from_content from .llm_as_judge_utils import get_tool_calls_and_responses_as_json_str @@ -166,20 +162,6 @@ def __init__(self, eval_metric: EvalMetric): ) self._auto_rater_prompt_template = _RUBRIC_BASED_TOOL_USE_QUALITY_V1_PROMPT - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=PrebuiltMetrics.RUBRIC_BASED_TOOL_USE_QUALITY_V1.value, - description=( - "This metric assess if the agent's usage of tools against a set of" - " rubrics using LLM as a judge. Value range for this metric is" - " [0,1], with values closer to 1 more desirable." - ), - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) - @override def format_auto_rater_prompt( self, actual_invocation: Invocation, _: Optional[Invocation] diff --git a/src/google/adk/evaluation/safety_evaluator.py b/src/google/adk/evaluation/safety_evaluator.py index 8c049803a4..f394849e2e 100644 --- a/src/google/adk/evaluation/safety_evaluator.py +++ b/src/google/adk/evaluation/safety_evaluator.py @@ -21,10 +21,6 @@ from .eval_case import ConversationScenario from .eval_case import Invocation from .eval_metrics import EvalMetric -from .eval_metrics import Interval -from .eval_metrics import MetricInfo -from .eval_metrics import MetricValueInfo -from .eval_metrics import PrebuiltMetrics from .evaluator import EvaluationResult from .evaluator import Evaluator from .vertex_ai_eval_facade import _VertexAiEvalFacade @@ -48,20 +44,6 @@ class SafetyEvaluatorV1(Evaluator): def __init__(self, eval_metric: EvalMetric): self._eval_metric = eval_metric - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=PrebuiltMetrics.SAFETY_V1.value, - description=( - "This metric evaluates the safety (harmlessness) of an Agent's" - " Response. Value range of the metric is [0, 1], with values closer" - " to 1 to be more desirable (safe)." 
- ), - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) - @override def evaluate_invocations( self, diff --git a/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py index 1e97cf75b6..cbd2c87e43 100644 --- a/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py +++ b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py @@ -35,11 +35,7 @@ from ..eval_metrics import BaseCriterion from ..eval_metrics import EvalMetric from ..eval_metrics import EvalStatus -from ..eval_metrics import Interval from ..eval_metrics import LlmBackedUserSimulatorCriterion -from ..eval_metrics import MetricInfo -from ..eval_metrics import MetricValueInfo -from ..eval_metrics import PrebuiltMetrics from ..evaluator import EvaluationResult from ..evaluator import Evaluator from ..evaluator import PerInvocationResult @@ -269,23 +265,6 @@ def _deserialize_criterion(self, eval_metric: EvalMetric) -> BaseCriterion: except ValidationError as e: raise expected_criterion_type_error from e - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=PrebuiltMetrics.PER_TURN_USER_SIMULATOR_QUALITY_V1, - description=( - "This metric evaluates if the user messages generated by a " - "user simulator follow the given conversation scenario. It " - "validates each message separately. The resulting metric " - "computes the percentage of user messages that we mark as " - "valid. The value range for this metric is [0,1], with values " - "closer to 1 more desirable. " - ), - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) - @override async def evaluate_invocations( self, diff --git a/src/google/adk/evaluation/trajectory_evaluator.py b/src/google/adk/evaluation/trajectory_evaluator.py index f795459ce7..ef55c7cced 100644 --- a/src/google/adk/evaluation/trajectory_evaluator.py +++ b/src/google/adk/evaluation/trajectory_evaluator.py @@ -26,10 +26,6 @@ from .eval_case import get_all_tool_calls from .eval_case import Invocation from .eval_metrics import EvalMetric -from .eval_metrics import Interval -from .eval_metrics import MetricInfo -from .eval_metrics import MetricValueInfo -from .eval_metrics import PrebuiltMetrics from .eval_metrics import ToolTrajectoryCriterion from .evaluator import EvalStatus from .evaluator import EvaluationResult @@ -99,22 +95,6 @@ def __init__( self._threshold = threshold self._match_type = ToolTrajectoryCriterion.MatchType.EXACT - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value, - description=( - "This metric compares two tool call trajectories (expected vs." - " actual) for the same user interaction. It performs an exact match" - " on the tool name and arguments for each step in the trajectory." - " A score of 1.0 indicates a perfect match, while 0.0 indicates a" - " mismatch. Higher values are better." 
- ), - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) - @override def evaluate_invocations( self, diff --git a/tests/unittests/evaluation/test_final_response_match_v1.py b/tests/unittests/evaluation/test_final_response_match_v1.py index d5fe0464f8..eef35d86d6 100644 --- a/tests/unittests/evaluation/test_final_response_match_v1.py +++ b/tests/unittests/evaluation/test_final_response_match_v1.py @@ -139,11 +139,3 @@ def test_rouge_evaluator_multiple_invocations( expected_score, rel=1e-3 ) assert evaluation_result.overall_eval_status == expected_status - - -def test_get_metric_info(): - """Test get_metric_info function for response match metric.""" - metric_info = RougeEvaluator.get_metric_info() - assert metric_info.metric_name == PrebuiltMetrics.RESPONSE_MATCH_SCORE.value - assert metric_info.metric_value_info.interval.min_value == 0.0 - assert metric_info.metric_value_info.interval.max_value == 1.0 diff --git a/tests/unittests/evaluation/test_final_response_match_v2.py b/tests/unittests/evaluation/test_final_response_match_v2.py index a40dbe091d..d82eea20d3 100644 --- a/tests/unittests/evaluation/test_final_response_match_v2.py +++ b/tests/unittests/evaluation/test_final_response_match_v2.py @@ -486,13 +486,3 @@ def test_aggregate_invocation_results(): # Only 4 / 8 invocations are evaluated, and 2 / 4 are valid. assert aggregated_result.overall_score == 0.5 assert aggregated_result.overall_eval_status == EvalStatus.PASSED - - -def test_get_metric_info(): - """Test get_metric_info function for Final Response Match V2 metric.""" - metric_info = FinalResponseMatchV2Evaluator.get_metric_info() - assert ( - metric_info.metric_name == PrebuiltMetrics.FINAL_RESPONSE_MATCH_V2.value - ) - assert metric_info.metric_value_info.interval.min_value == 0.0 - assert metric_info.metric_value_info.interval.max_value == 1.0 diff --git a/tests/unittests/evaluation/test_metric_evaluator_registry.py b/tests/unittests/evaluation/test_metric_evaluator_registry.py index 60b39d5431..ca5c70267c 100644 --- a/tests/unittests/evaluation/test_metric_evaluator_registry.py +++ b/tests/unittests/evaluation/test_metric_evaluator_registry.py @@ -19,102 +19,192 @@ from google.adk.evaluation.eval_metrics import Interval from google.adk.evaluation.eval_metrics import MetricInfo from google.adk.evaluation.eval_metrics import MetricValueInfo +from google.adk.evaluation.eval_metrics import PrebuiltMetrics from google.adk.evaluation.evaluator import Evaluator +from google.adk.evaluation.metric_evaluator_registry import FinalResponseMatchV2EvaluatorMetricInfoProvider +from google.adk.evaluation.metric_evaluator_registry import HallucinationsV1EvaluatorMetricInfoProvider from google.adk.evaluation.metric_evaluator_registry import MetricEvaluatorRegistry +from google.adk.evaluation.metric_evaluator_registry import PerTurnUserSimulatorQualityV1MetricInfoProvider +from google.adk.evaluation.metric_evaluator_registry import ResponseEvaluatorMetricInfoProvider +from google.adk.evaluation.metric_evaluator_registry import RubricBasedFinalResponseQualityV1EvaluatorMetricInfoProvider +from google.adk.evaluation.metric_evaluator_registry import RubricBasedToolUseV1EvaluatorMetricInfoProvider +from google.adk.evaluation.metric_evaluator_registry import SafetyEvaluatorV1MetricInfoProvider +from google.adk.evaluation.metric_evaluator_registry import TrajectoryEvaluatorMetricInfoProvider import pytest _DUMMY_METRIC_NAME = "dummy_metric_name" +_DUMMY_METRIC_INFO = MetricInfo( + 
metric_name=_DUMMY_METRIC_NAME, + description="Dummy metric description", + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), +) +_ANOTHER_DUMMY_METRIC_INFO = MetricInfo( + metric_name=_DUMMY_METRIC_NAME, + description="Another dummy metric description", + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), +) -class TestMetricEvaluatorRegistry: - """Test cases for MetricEvaluatorRegistry.""" +class DummyEvaluator(Evaluator): - @pytest.fixture - def registry(self): - return MetricEvaluatorRegistry() + def __init__(self, eval_metric: EvalMetric): + self._eval_metric = eval_metric - class DummyEvaluator(Evaluator): + def evaluate_invocations(self, actual_invocations, expected_invocations): + return "dummy_result" - def __init__(self, eval_metric: EvalMetric): - self._eval_metric = eval_metric - def evaluate_invocations(self, actual_invocations, expected_invocations): - return "dummy_result" +class AnotherDummyEvaluator(Evaluator): - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=_DUMMY_METRIC_NAME, - description="Dummy metric description", - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) + def __init__(self, eval_metric: EvalMetric): + self._eval_metric = eval_metric - class AnotherDummyEvaluator(Evaluator): + def evaluate_invocations(self, actual_invocations, expected_invocations): + return "another_dummy_result" - def __init__(self, eval_metric: EvalMetric): - self._eval_metric = eval_metric - def evaluate_invocations(self, actual_invocations, expected_invocations): - return "another_dummy_result" +class TestMetricEvaluatorRegistry: + """Test cases for MetricEvaluatorRegistry.""" - @staticmethod - def get_metric_info() -> MetricInfo: - return MetricInfo( - metric_name=_DUMMY_METRIC_NAME, - description="Another dummy metric description", - metric_value_info=MetricValueInfo( - interval=Interval(min_value=0.0, max_value=1.0) - ), - ) + @pytest.fixture + def registry(self): + return MetricEvaluatorRegistry() def test_register_evaluator(self, registry): - metric_info = TestMetricEvaluatorRegistry.DummyEvaluator.get_metric_info() registry.register_evaluator( - metric_info, - TestMetricEvaluatorRegistry.DummyEvaluator, + _DUMMY_METRIC_INFO, + DummyEvaluator, ) assert _DUMMY_METRIC_NAME in registry._registry assert registry._registry[_DUMMY_METRIC_NAME] == ( - TestMetricEvaluatorRegistry.DummyEvaluator, - metric_info, + DummyEvaluator, + _DUMMY_METRIC_INFO, ) def test_register_evaluator_updates_existing(self, registry): - metric_info = TestMetricEvaluatorRegistry.DummyEvaluator.get_metric_info() registry.register_evaluator( - metric_info, - TestMetricEvaluatorRegistry.DummyEvaluator, + _DUMMY_METRIC_INFO, + DummyEvaluator, ) assert registry._registry[_DUMMY_METRIC_NAME] == ( - TestMetricEvaluatorRegistry.DummyEvaluator, - metric_info, + DummyEvaluator, + _DUMMY_METRIC_INFO, ) - metric_info = ( - TestMetricEvaluatorRegistry.AnotherDummyEvaluator.get_metric_info() - ) registry.register_evaluator( - metric_info, TestMetricEvaluatorRegistry.AnotherDummyEvaluator + _ANOTHER_DUMMY_METRIC_INFO, AnotherDummyEvaluator ) assert registry._registry[_DUMMY_METRIC_NAME] == ( - TestMetricEvaluatorRegistry.AnotherDummyEvaluator, - metric_info, + AnotherDummyEvaluator, + _ANOTHER_DUMMY_METRIC_INFO, ) def test_get_evaluator(self, registry): - metric_info = TestMetricEvaluatorRegistry.DummyEvaluator.get_metric_info() registry.register_evaluator( - 
metric_info, - TestMetricEvaluatorRegistry.DummyEvaluator, + _DUMMY_METRIC_INFO, + DummyEvaluator, ) eval_metric = EvalMetric(metric_name=_DUMMY_METRIC_NAME, threshold=0.5) evaluator = registry.get_evaluator(eval_metric) - assert isinstance(evaluator, TestMetricEvaluatorRegistry.DummyEvaluator) + assert isinstance(evaluator, DummyEvaluator) def test_get_evaluator_not_found(self, registry): eval_metric = EvalMetric(metric_name="non_existent_metric", threshold=0.5) with pytest.raises(NotFoundError): registry.get_evaluator(eval_metric) + + +class TestMetricInfoProviders: + """Test cases for MetricInfoProviders.""" + + def test_trajectory_evaluator_metric_info_provider(self): + metric_info = TrajectoryEvaluatorMetricInfoProvider().get_metric_info() + assert ( + metric_info.metric_name + == PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value + ) + assert metric_info.metric_value_info.interval.min_value == 0.0 + assert metric_info.metric_value_info.interval.max_value == 1.0 + + def test_response_evaluator_metric_info_provider_eval_score(self): + metric_info = ResponseEvaluatorMetricInfoProvider( + PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value + ).get_metric_info() + assert ( + metric_info.metric_name + == PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value + ) + assert metric_info.metric_value_info.interval.min_value == 1.0 + assert metric_info.metric_value_info.interval.max_value == 5.0 + + def test_response_evaluator_metric_info_provider_match_score(self): + metric_info = ResponseEvaluatorMetricInfoProvider( + PrebuiltMetrics.RESPONSE_MATCH_SCORE.value + ).get_metric_info() + assert metric_info.metric_name == PrebuiltMetrics.RESPONSE_MATCH_SCORE.value + assert metric_info.metric_value_info.interval.min_value == 0.0 + assert metric_info.metric_value_info.interval.max_value == 1.0 + + def test_safety_evaluator_v1_metric_info_provider(self): + metric_info = SafetyEvaluatorV1MetricInfoProvider().get_metric_info() + assert metric_info.metric_name == PrebuiltMetrics.SAFETY_V1.value + assert metric_info.metric_value_info.interval.min_value == 0.0 + assert metric_info.metric_value_info.interval.max_value == 1.0 + + def test_final_response_match_v2_evaluator_metric_info_provider(self): + metric_info = ( + FinalResponseMatchV2EvaluatorMetricInfoProvider().get_metric_info() + ) + assert ( + metric_info.metric_name == PrebuiltMetrics.FINAL_RESPONSE_MATCH_V2.value + ) + assert metric_info.metric_value_info.interval.min_value == 0.0 + assert metric_info.metric_value_info.interval.max_value == 1.0 + + def test_rubric_based_final_response_quality_v1_evaluator_metric_info_provider( + self, + ): + metric_info = ( + RubricBasedFinalResponseQualityV1EvaluatorMetricInfoProvider().get_metric_info() + ) + assert ( + metric_info.metric_name + == PrebuiltMetrics.RUBRIC_BASED_FINAL_RESPONSE_QUALITY_V1.value + ) + assert metric_info.metric_value_info.interval.min_value == 0.0 + assert metric_info.metric_value_info.interval.max_value == 1.0 + + def test_hallucinations_v1_evaluator_metric_info_provider(self): + metric_info = ( + HallucinationsV1EvaluatorMetricInfoProvider().get_metric_info() + ) + assert metric_info.metric_name == PrebuiltMetrics.HALLUCINATIONS_V1.value + assert metric_info.metric_value_info.interval.min_value == 0.0 + assert metric_info.metric_value_info.interval.max_value == 1.0 + + def test_rubric_based_tool_use_v1_evaluator_metric_info_provider(self): + metric_info = ( + RubricBasedToolUseV1EvaluatorMetricInfoProvider().get_metric_info() + ) + assert ( + metric_info.metric_name + == 
PrebuiltMetrics.RUBRIC_BASED_TOOL_USE_QUALITY_V1.value + ) + assert metric_info.metric_value_info.interval.min_value == 0.0 + assert metric_info.metric_value_info.interval.max_value == 1.0 + + def test_per_turn_user_simulator_quality_v1_metric_info_provider(self): + metric_info = ( + PerTurnUserSimulatorQualityV1MetricInfoProvider().get_metric_info() + ) + assert ( + metric_info.metric_name + == PrebuiltMetrics.PER_TURN_USER_SIMULATOR_QUALITY_V1.value + ) + assert metric_info.metric_value_info.interval.min_value == 0.0 + assert metric_info.metric_value_info.interval.max_value == 1.0 diff --git a/tests/unittests/evaluation/test_response_evaluator.py b/tests/unittests/evaluation/test_response_evaluator.py index 548ae2209a..bd82b51f25 100644 --- a/tests/unittests/evaluation/test_response_evaluator.py +++ b/tests/unittests/evaluation/test_response_evaluator.py @@ -118,29 +118,3 @@ def test_evaluate_invocations_coherence_metric_passed(self, mocker): assert [m.name for m in mock_kwargs["metrics"]] == [ vertexai_types.PrebuiltMetric.COHERENCE.name ] - - def test_get_metric_info_response_evaluation_score(self): - """Test get_metric_info function for response evaluation metric.""" - metric_info = ResponseEvaluator.get_metric_info( - PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value - ) - assert ( - metric_info.metric_name - == PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value - ) - assert metric_info.metric_value_info.interval.min_value == 1.0 - assert metric_info.metric_value_info.interval.max_value == 5.0 - - def test_get_metric_info_response_match_score(self): - """Test get_metric_info function for response match metric.""" - metric_info = ResponseEvaluator.get_metric_info( - PrebuiltMetrics.RESPONSE_MATCH_SCORE.value - ) - assert metric_info.metric_name == PrebuiltMetrics.RESPONSE_MATCH_SCORE.value - assert metric_info.metric_value_info.interval.min_value == 0.0 - assert metric_info.metric_value_info.interval.max_value == 1.0 - - def test_get_metric_info_invalid(self): - """Test get_metric_info function for invalid metric.""" - with pytest.raises(ValueError): - ResponseEvaluator.get_metric_info("invalid_metric") diff --git a/tests/unittests/evaluation/test_rubric_based_tool_use_quality_v1.py b/tests/unittests/evaluation/test_rubric_based_tool_use_quality_v1.py index aed20a3a7a..9448249a1b 100644 --- a/tests/unittests/evaluation/test_rubric_based_tool_use_quality_v1.py +++ b/tests/unittests/evaluation/test_rubric_based_tool_use_quality_v1.py @@ -136,15 +136,3 @@ def test_format_auto_rater_prompt_with_intermediate_data( assert '"name": "test_func"' in prompt assert '"tool_response":' in prompt assert '"result": "ok"' in prompt - - -def test_get_metric_info(evaluator: RubricBasedToolUseV1Evaluator): - """Tests the get_metric_info method.""" - metric_info = evaluator.get_metric_info() - assert ( - metric_info.metric_name - == PrebuiltMetrics.RUBRIC_BASED_TOOL_USE_QUALITY_V1.value - ) - assert "agent's usage of tools" in metric_info.description - assert metric_info.metric_value_info.interval.min_value == 0.0 - assert metric_info.metric_value_info.interval.max_value == 1.0 diff --git a/tests/unittests/evaluation/test_safety_evaluator.py b/tests/unittests/evaluation/test_safety_evaluator.py index 69a1594474..091798afa7 100644 --- a/tests/unittests/evaluation/test_safety_evaluator.py +++ b/tests/unittests/evaluation/test_safety_evaluator.py @@ -76,10 +76,3 @@ def test_evaluate_invocations_coherence_metric_passed(self, mocker): assert [m.name for m in mock_kwargs["metrics"]] == [ 
vertexai_types.PrebuiltMetric.SAFETY.name ] - - def test_get_metric_info(self): - """Test get_metric_info function for Safety metric.""" - metric_info = SafetyEvaluatorV1.get_metric_info() - assert metric_info.metric_name == PrebuiltMetrics.SAFETY_V1.value - assert metric_info.metric_value_info.interval.min_value == 0.0 - assert metric_info.metric_value_info.interval.max_value == 1.0 diff --git a/tests/unittests/evaluation/test_trajectory_evaluator.py b/tests/unittests/evaluation/test_trajectory_evaluator.py index 0795739768..5edbe06807 100644 --- a/tests/unittests/evaluation/test_trajectory_evaluator.py +++ b/tests/unittests/evaluation/test_trajectory_evaluator.py @@ -30,16 +30,6 @@ ) -def test_get_metric_info(): - """Test get_metric_info function for tool trajectory avg metric.""" - metric_info = TrajectoryEvaluator.get_metric_info() - assert ( - metric_info.metric_name == PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value - ) - assert metric_info.metric_value_info.interval.min_value == 0.0 - assert metric_info.metric_value_info.interval.max_value == 1.0 - - @pytest.fixture def evaluator() -> TrajectoryEvaluator: """Returns a TrajectoryEvaluator."""
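
Usage sketch (illustrative only, not part of this change): the snippet below shows how a downstream metric could adopt the new MetricInfoProvider interface and be registered, mirroring the DummyEvaluator pattern used in the updated registry tests. The metric name "word_overlap_score" and the WordOverlap* classes are hypothetical placeholders; the imports and registry calls are the ones introduced or exercised by this patch.

from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.eval_metrics import Interval
from google.adk.evaluation.eval_metrics import MetricInfo
from google.adk.evaluation.eval_metrics import MetricInfoProvider
from google.adk.evaluation.eval_metrics import MetricValueInfo
from google.adk.evaluation.evaluator import Evaluator
from google.adk.evaluation.metric_evaluator_registry import MetricEvaluatorRegistry


class WordOverlapMetricInfoProvider(MetricInfoProvider):
  """Provides MetricInfo for a hypothetical word-overlap metric."""

  def get_metric_info(self) -> MetricInfo:
    return MetricInfo(
        metric_name="word_overlap_score",  # Hypothetical metric name.
        description=(
            "Illustrative metric. Value range for this metric is [0,1], with"
            " values closer to 1 more desirable."
        ),
        metric_value_info=MetricValueInfo(
            interval=Interval(min_value=0.0, max_value=1.0)
        ),
    )


class WordOverlapEvaluator(Evaluator):
  """Hypothetical evaluator; only the registration wiring is shown here."""

  def __init__(self, eval_metric: EvalMetric):
    self._eval_metric = eval_metric

  def evaluate_invocations(self, actual_invocations, expected_invocations):
    # Real scoring logic would go here; omitted in this sketch.
    raise NotImplementedError


registry = MetricEvaluatorRegistry()
registry.register_evaluator(
    metric_info=WordOverlapMetricInfoProvider().get_metric_info(),
    evaluator=WordOverlapEvaluator,
)
evaluator = registry.get_evaluator(
    EvalMetric(metric_name="word_overlap_score", threshold=0.5)
)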