From e515e0f321a259016c5e5f6b388ecf02ae343ba7 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 15 Dec 2025 10:05:07 -0800 Subject: [PATCH 1/4] feat: Introduce a post-hoc, per-turn evaluator for user simulations PiperOrigin-RevId: 844818512 --- src/google/adk/evaluation/eval_metrics.py | 15 + src/google/adk/evaluation/evaluator.py | 4 + .../adk/evaluation/final_response_match_v1.py | 2 + .../adk/evaluation/hallucinations_v1.py | 2 + src/google/adk/evaluation/llm_as_judge.py | 2 + .../adk/evaluation/local_eval_service.py | 4 + .../evaluation/metric_evaluator_registry.py | 5 + .../adk/evaluation/response_evaluator.py | 2 + .../per_turn_user_simulator_quality_v1.py | 506 +++++++++++++++ .../adk/evaluation/vertex_ai_eval_facade.py | 2 + ...est_per_turn_user_simulation_quality_v1.py | 613 ++++++++++++++++++ 11 files changed, 1157 insertions(+) create mode 100644 src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py create mode 100644 tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py diff --git a/src/google/adk/evaluation/eval_metrics.py b/src/google/adk/evaluation/eval_metrics.py index 79a587eb04..b7c544ccad 100644 --- a/src/google/adk/evaluation/eval_metrics.py +++ b/src/google/adk/evaluation/eval_metrics.py @@ -57,6 +57,8 @@ class PrebuiltMetrics(Enum): RUBRIC_BASED_TOOL_USE_QUALITY_V1 = "rubric_based_tool_use_quality_v1" + PER_TURN_USER_SIMULATOR_QUALITY_V1 = "per_turn_user_simulator_quality_v1" + MetricName: TypeAlias = Union[str, PrebuiltMetrics] Threshold: TypeAlias = float @@ -223,6 +225,19 @@ class MatchType(Enum): ) +class LlmBackedUserSimulatorCriterion(LlmAsAJudgeCriterion): + """Criterion for LLM-backed User Simulator Evaluators.""" + + stop_signal: str = Field( + default="", + description=( + "Stop signal to validate the successful completion of a conversation." + " For optimal performance, this should match the one in the User" + " Simulator." + ), + ) + + class EvalMetric(EvalBaseModel): """A metric used to evaluate a particular aspect of an eval case.""" diff --git a/src/google/adk/evaluation/evaluator.py b/src/google/adk/evaluation/evaluator.py index c235bb1e71..c41fed74d9 100644 --- a/src/google/adk/evaluation/evaluator.py +++ b/src/google/adk/evaluation/evaluator.py @@ -20,6 +20,7 @@ from pydantic import BaseModel from typing_extensions import TypeAlias +from .eval_case import ConversationScenario from .eval_case import Invocation from .eval_metrics import BaseCriterion from .eval_metrics import EvalStatus @@ -62,6 +63,7 @@ def evaluate_invocations( self, actual_invocations: list[Invocation], expected_invocations: Optional[list[Invocation]], + conversation_scenario: Optional[ConversationScenario], ) -> EvaluationResult: """Returns EvaluationResult after performing evaluations using actual and expected invocations. @@ -72,5 +74,7 @@ def evaluate_invocations( usually act as a benchmark/golden response. If these are specified usually the expectation is that the length of this list and actual invocation is the same. + conversation_scenario: An optional conversation scenario for multi-turn + conversations. 
""" raise NotImplementedError() diff --git a/src/google/adk/evaluation/final_response_match_v1.py b/src/google/adk/evaluation/final_response_match_v1.py index 83d0d4fc01..06a6440882 100644 --- a/src/google/adk/evaluation/final_response_match_v1.py +++ b/src/google/adk/evaluation/final_response_match_v1.py @@ -20,6 +20,7 @@ from typing_extensions import override from ..dependencies.rouge_scorer import rouge_scorer +from .eval_case import ConversationScenario from .eval_case import Invocation from .eval_metrics import EvalMetric from .eval_metrics import Interval @@ -60,6 +61,7 @@ def evaluate_invocations( self, actual_invocations: list[Invocation], expected_invocations: Optional[list[Invocation]], + _: Optional[ConversationScenario] = None, ) -> EvaluationResult: if expected_invocations is None: raise ValueError("expected_invocations is required for this metric.") diff --git a/src/google/adk/evaluation/hallucinations_v1.py b/src/google/adk/evaluation/hallucinations_v1.py index 587774a74c..84a4a115b2 100644 --- a/src/google/adk/evaluation/hallucinations_v1.py +++ b/src/google/adk/evaluation/hallucinations_v1.py @@ -34,6 +34,7 @@ from ..utils.feature_decorator import experimental from ._retry_options_utils import add_default_retry_options_if_not_present from .app_details import AppDetails +from .eval_case import ConversationScenario from .eval_case import Invocation from .eval_case import InvocationEvent from .eval_case import InvocationEvents @@ -720,6 +721,7 @@ async def evaluate_invocations( self, actual_invocations: list[Invocation], expected_invocations: Optional[list[Invocation]], + _: Optional[ConversationScenario] = None, ) -> EvaluationResult: # expected_invocations are not required by the metric and if they are not # supplied, we provide a list of None to rest of the code. 
diff --git a/src/google/adk/evaluation/llm_as_judge.py b/src/google/adk/evaluation/llm_as_judge.py index 226884d3a2..0f2d890139 100644 --- a/src/google/adk/evaluation/llm_as_judge.py +++ b/src/google/adk/evaluation/llm_as_judge.py @@ -29,6 +29,7 @@ from ..utils.feature_decorator import experimental from ._retry_options_utils import add_default_retry_options_if_not_present from .common import EvalBaseModel +from .eval_case import ConversationScenario from .eval_case import Invocation from .eval_metrics import BaseCriterion from .eval_metrics import EvalMetric @@ -118,6 +119,7 @@ async def evaluate_invocations( self, actual_invocations: list[Invocation], expected_invocations: Optional[list[Invocation]], + _: Optional[ConversationScenario] = None, ) -> EvaluationResult: if self._expected_invocations_required and expected_invocations is None: raise ValueError("expected_invocations is needed by this metric.") diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index f454266e00..5acbff0680 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -40,6 +40,7 @@ from .base_eval_service import InferenceRequest from .base_eval_service import InferenceResult from .base_eval_service import InferenceStatus +from .eval_case import ConversationScenario from .eval_case import Invocation from .eval_metrics import EvalMetric from .eval_metrics import EvalMetricResult @@ -256,6 +257,7 @@ async def _evaluate_single_inference_result( eval_metric=eval_metric, actual_invocations=inference_result.inferences, expected_invocations=eval_case.conversation, + conversation_scenario=eval_case.conversation_scenario, ) except Exception as e: # We intentionally catch the Exception as we don't want failures to @@ -345,6 +347,7 @@ async def _evaluate_metric( eval_metric: EvalMetric, actual_invocations: list[Invocation], expected_invocations: Optional[list[Invocation]], + conversation_scenario: Optional[ConversationScenario], ) -> EvaluationResult: """Returns EvaluationResult obtained from evaluating a metric using an Evaluator.""" @@ -359,6 +362,7 @@ async def _evaluate_metric( return await metric_evaluator.evaluate_invocations( actual_invocations=actual_invocations, expected_invocations=expected_invocations, + conversation_scenario=conversation_scenario, ) else: # Metrics that perform computation synchronously, mostly these don't diff --git a/src/google/adk/evaluation/metric_evaluator_registry.py b/src/google/adk/evaluation/metric_evaluator_registry.py index 0e8c54d8fb..0d0fb773ca 100644 --- a/src/google/adk/evaluation/metric_evaluator_registry.py +++ b/src/google/adk/evaluation/metric_evaluator_registry.py @@ -28,6 +28,7 @@ from .rubric_based_final_response_quality_v1 import RubricBasedFinalResponseQualityV1Evaluator from .rubric_based_tool_use_quality_v1 import RubricBasedToolUseV1Evaluator from .safety_evaluator import SafetyEvaluatorV1 +from .simulation.per_turn_user_simulator_quality_v1 import PerTurnUserSimulatorQualityV1 from .trajectory_evaluator import TrajectoryEvaluator logger = logging.getLogger("google_adk." 
+ __name__) @@ -126,6 +127,10 @@ def _get_default_metric_evaluator_registry() -> MetricEvaluatorRegistry: metric_info=RubricBasedToolUseV1Evaluator.get_metric_info(), evaluator=RubricBasedToolUseV1Evaluator, ) + metric_evaluator_registry.register_evaluator( + metric_info=PerTurnUserSimulatorQualityV1.get_metric_info(), + evaluator=PerTurnUserSimulatorQualityV1, + ) return metric_evaluator_registry diff --git a/src/google/adk/evaluation/response_evaluator.py b/src/google/adk/evaluation/response_evaluator.py index 685222f2f7..5052aca2ac 100644 --- a/src/google/adk/evaluation/response_evaluator.py +++ b/src/google/adk/evaluation/response_evaluator.py @@ -18,6 +18,7 @@ from typing_extensions import override +from .eval_case import ConversationScenario from .eval_case import Invocation from .eval_metrics import EvalMetric from .eval_metrics import Interval @@ -100,6 +101,7 @@ def evaluate_invocations( self, actual_invocations: list[Invocation], expected_invocations: Optional[list[Invocation]], + _: Optional[ConversationScenario] = None, ) -> EvaluationResult: # If the metric is response_match_score, just use the RougeEvaluator. if self._metric_name == PrebuiltMetrics.RESPONSE_MATCH_SCORE.value: diff --git a/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py new file mode 100644 index 0000000000..5624bc0ec9 --- /dev/null +++ b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py @@ -0,0 +1,506 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
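+
+# This module defines PerTurnUserSimulatorQualityV1, an experimental,
+# LLM-backed evaluator that scores each user-simulator turn against the
+# given ConversationScenario (starting prompt, conversation plan, and stop
+# signal) and reports the fraction of valid turns. See the class docstring
+# below for details.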
+ +from __future__ import annotations + +import logging +import re +from typing import ClassVar +from typing import Optional + +from google.genai import types as genai_types +from pydantic import ValidationError +from typing_extensions import override + +from ...models.base_llm import BaseLlm +from ...models.llm_request import LlmRequest +from ...models.llm_response import LlmResponse +from ...models.registry import LLMRegistry +from ...utils.context_utils import Aclosing +from ...utils.feature_decorator import experimental +from .._retry_options_utils import add_default_retry_options_if_not_present +from ..eval_case import ConversationScenario +from ..eval_case import Invocation +from ..eval_metrics import BaseCriterion +from ..eval_metrics import EvalMetric +from ..eval_metrics import EvalStatus +from ..eval_metrics import Interval +from ..eval_metrics import LlmBackedUserSimulatorCriterion +from ..eval_metrics import MetricInfo +from ..eval_metrics import MetricValueInfo +from ..eval_metrics import PrebuiltMetrics +from ..evaluator import EvaluationResult +from ..evaluator import Evaluator +from ..evaluator import PerInvocationResult +from ..llm_as_judge import AutoRaterScore +from ..llm_as_judge_utils import get_eval_status +from ..llm_as_judge_utils import get_text_from_content +from ..llm_as_judge_utils import Label + +_LATEST_TURN_USER_SIMULATOR_EVALUATOR_PROMPT = """ +You are a data scientist tasked with evaluating the quality of a User Simulator that is interacting with an Agent. +Your task is to determine if the Generated User Response is consistent with: + - The Conversation Plan: A list of high-level goals that the User Simulator is expected to achieve in the conversation. + - The Conversation History: The exchange between the User Simulator and the Agent so far. +To determine this, we provide specific Evaluation Criteria that must be satisfied by the Generated User Response. + +# Definition of Conversation Plan +The Conversation Plan specifies the goals that the User Simulator must execute. +The Conversation Plan also specifies the information and details that are needed to complete the goals. +The Conversation Plan is sequential in nature and the User Simulator must ensure the sequence is followed. + +# Definition of Conversation History +The Conversation History is the actual dialogue between the User Simulator and the Agent. +The Conversation History may not be complete, but the exsisting dialogue should adhere to the Conversation Plan. +The Conversation History may contain instances where the User Simulator troubleshoots an incorrect/inappropriate response from the Agent in order to enforce the Conversation Plan. +The Conversation History is finished only when the User Simulator outputs `{stop_signal}` in its response. If this token is missing, the conversation between the User Simulator and the Agent has not finished, and more turns can be generated. + +# Definition of Generated User Response +The Generated User Response is a the next user response in the conversation between a User Simulator and an Agent. +The Generated User Response was generated by the User Simulator based on a Conversation Plan and Conversation History. + +# Evaluation Criteria +Your task is to evaluate the Generated User Response on a PASS/FAIL basis looking for specific errors. +The Generated User Response is marked as PASS unless it contains any of the Violations listed below, in which case it is marked as FAIL. 
+ +** CONVERSATION_PLAN_FOLLOWED ** +Does the Generated User Response stick to the Conversation Plan? + +Mark as FAIL if any of the following Violations occur: +- The Generated User Response repeats a high-level goal that was already completed in previous turns. +- The Generated User Response provides details for a high-level goal that was already completed. +- The Generated User Response response agrees to change the topic or perform a task not listed in the Conversation Plan. +- The Generated User Response invents a new goal not present in the Conversation Plan. +- The Generated User Response invents details (e.g., a made-up phone number or address) not provided in the Conversation Plan. + +** STOP_CONDITION_FOLLOWED ** +Did the conversation end exactly when it was supposed to? + +Mark as FAIL if any of the following Violations occur: +- The conversation should have ended, but the Generated User Response did not use `{stop_signal}`. +- The Generated User Response used `{stop_signal}`, but tasks in the Conversation Plan are still incomplete AND the Agent has not failed. +- The Agent successfully transferred the User Simulator to a human/live agent, but the Generated User Response continued instead of using `{stop_signal}`. + +** USER_GOAL_ORIENTED ** +Is the User Simulator acting naturally, or is it "data dumping"? + +Mark as FAIL if any of the following Violations occur: +- The Generated User Response provides specific details for a high-level goal (email content, recipient address, phone numbers) BEFORE the Agent has explicitly asked for them. +- The Generated User Response tries to accomplish more than one high-level task in a single turn. + +** LIMITED_TROUBLESHOOTING ** +Does the User Simulator have the correct amount of patience? (Note: Please check the conversation history and count the number of Agent errors). + +Mark as FAIL if any of the following Violations occur: +- The Generated User Response ends the conversation immediately after the first Agent error. +- On the second Agent error, the Generated User Response response continues the conversation without using `{stop_signal}`. +- After the second Agent error, the Generated User Response tries to continue the conversation or continues addressing errors without using `{stop_signal}`. + +** RESPONSIVENESS ** +Does the User Simulator answer what is asked? + +Mark as FAIL if any of the following Violations occur: +- The Agent asked a question (or multiple questions), and the Generated User Response failed to address one or all of them. +- The Agent asked for information NOT in the Conversation Plan, and the Generated User Response made up an answer instead of stating, e.g., "I don't know" or "I don't have that info." + +** CORRECTS_AGENT ** +Does the User Simulator catch the Agent's mistakes? + +Mark as FAIL if any of the following Violations occur: +- The Agent provided incorrect information, but the Generated User Response continued as if it was correct. +- The Agent made a dangerous assumption (e.g., sending an email without asking for the content first), and the Generated User Response continues without correcting the Agent. + +** CONVERSATIONAL_TONE ** +Does the User Simulator sound like a human? + +Mark as FAIL if any of the following Violations occur: +- The Generated User Response uses overly complex sentence structures, or uses technical jargon inappropriately. +- The Generated User Response is sterile and purely functional (direct commands) with no natural conversational framing. 
+- The Generated User Response is too formal in nature, employing overly polite phrases and expressions. +- The Generated User Response is a "wall of text" where a simple sentence would suffice. + +# Output Format +Format your response in the following JSON format: +{{ + "criteria": [ + {{ + "name": "CRITERIA_NAME_1", + "reasoning": "reasoning", + "passes": True or False, + }}, + {{ + "name": "CRITERIA_NAME_2", + "reasoning": "reasoning", + "passes": True or False, + }}, + ... + ], + "is_valid": True or False, +}} + +# Conversation Plan +{conversation_plan} + +# Conversation History +{conversation_history} + +# Generated User Response +{generated_user_response} +""".strip() + + +def _parse_llm_response(response: str) -> Label: + """Parses the LLM response and extracts the final label. + + Args: + response: LLM response. + + Returns: + The extracted label, either VALID, INVALID, or NOT_FOUND. + """ + # Regex matching the label field in the response. + is_valid_match = re.search( + r'"is_valid":\s*\[*[\n\s]*"*([^"^\]^\s]*)"*[\n\s]*\]*\s*[,\n\}]', + response, + ) + + # If there was not match for "is_valid", return NOT_FOUND + if is_valid_match is None: + return Label.NOT_FOUND + + # Remove any trailing whitespace, commas, or end-brackets from the label. + label = is_valid_match.group(1).strip(r"\s,\}").lower() + if label in [ + Label.INVALID.value, + Label.ALMOST.value, + Label.FALSE.value, + *Label.PARTIALLY_VALID.value, + ]: + return Label.INVALID + elif label in [Label.VALID.value, Label.TRUE.value]: + return Label.VALID + else: + return Label.NOT_FOUND + + +def _format_conversation_history(invocations: list[Invocation]) -> str: + conversation_history = [] + for invocation in invocations: + if invocation.user_content is not None: + conversation_history.append( + f"user: {get_text_from_content(invocation.user_content)}" + ) + + final_response = invocation.final_response + if final_response is not None: + conversation_history.append( + f"{final_response.role}: {get_text_from_content(final_response)}" + ) + return "\n\n".join(conversation_history) + + +def _get_stop_signal_invocation(stop_signal: str) -> Invocation: + return Invocation( + invocation_id="stop_signal_proxy_invocation", + user_content=genai_types.Content( + parts=[genai_types.Part(text=stop_signal)] + ), + ) + + +@experimental +class PerTurnUserSimulatorQualityV1(Evaluator): + """Per turn user simulator evaluator. + + This evaluator verifies that the conversation from a user simulator sticks + to the given conversation scenario: + - In the first turn, it verifies that the user simulator output the + specified starting prompt. + - For all the other turns, it verifies that the user simulator stuck to the + conversation plan. + - It also verifies that the user simulator finished the conversation + appropriately. + This evaluator uses an LLM to verify all turns except the first one. It + aggregates repeated invocation samples by taking majority vote. The overall + score is the fraction of turns of the conversation before the verifier + detects an issue with the user simulator. 
+ """ + + criterion_type: ClassVar[type[LlmBackedUserSimulatorCriterion]] = ( + LlmBackedUserSimulatorCriterion + ) + + def __init__( + self, + eval_metric: EvalMetric, + ): + self._eval_metric = eval_metric + self._criterion = self._deserialize_criterion(eval_metric) + + self._prompt_template = _LATEST_TURN_USER_SIMULATOR_EVALUATOR_PROMPT + + self._llm_options = self._criterion.judge_model_options + self._stop_signal = self._criterion.stop_signal + self._llm = self._setup_llm() + + def _deserialize_criterion(self, eval_metric: EvalMetric) -> BaseCriterion: + expected_criterion_type_error = ValueError( + f"`{eval_metric.metric_name}` metric expects a criterion of type" + f" `{self.criterion_type}`." + ) + try: + if self._eval_metric.criterion is None: + raise expected_criterion_type_error + + return self.criterion_type.model_validate( + self._eval_metric.criterion.model_dump() + ) + except ValidationError as e: + raise expected_criterion_type_error from e + + @staticmethod + def get_metric_info() -> MetricInfo: + return MetricInfo( + metric_name=PrebuiltMetrics.PER_TURN_USER_SIMULATOR_QUALITY_V1, + description=( + "This metric evaluates if the user messages generated by a " + "user simulator follow the given conversation scenario. It " + "validates each message separately. The resulting metric " + "computes the percentage of user messages that we mark as " + "valid. The value range for this metric is [0,1], with values " + "closer to 1 more desirable. " + ), + metric_value_info=MetricValueInfo( + interval=Interval(min_value=0.0, max_value=1.0) + ), + ) + + @override + async def evaluate_invocations( + self, + actual_invocations: list[Invocation], + expected_invocations: Optional[list[Invocation]], + conversation_scenario: Optional[ConversationScenario], + ) -> EvaluationResult: + del expected_invocations + + # Evaluate the first invocation contains the given starting prompt. + results = [ + self._evaluate_first_turn(actual_invocations[0], conversation_scenario) + ] + + # Evaluate the rest of the invocations. + for i, invocation in enumerate(actual_invocations): + # skip the first invocation. + if i == 0: + continue + + result = await self._evaluate_intermediate_turn( + invocation_at_step=invocation, + invocation_history=actual_invocations[:i], + conversation_scenario=conversation_scenario, + ) + results.append(result) + + if not results: + return EvaluationResult() + + # Evaluate whether the conversation ended correctly. + stop_signal_evaluation = await self._evaluate_stop_signal_turn( + invocation_history=actual_invocations, + conversation_scenario=conversation_scenario, + ) + + # If the conversation did not end correctly, indicate so by marking the + # last user turn as failed. + if stop_signal_evaluation.eval_status == EvalStatus.FAILED: + results[-1] = stop_signal_evaluation + + return self._aggregate_conversation_results(results) + + def _setup_llm(self) -> BaseLlm: + model_id = self._llm_options.judge_model + llm_registry = LLMRegistry() + llm_class = llm_registry.resolve(model_id) + return llm_class(model=model_id) + + def _format_llm_prompt( + self, + invocation: Invocation, + conversation_scenario: ConversationScenario, + previous_invocations: Optional[list[Invocation]], + ) -> str: + if previous_invocations is None: + raise ValueError( + "Previous invocations should have a set value when " + "formatting the LLM prompt. 
" + f"Encountered: {previous_invocations}" + ) + + if conversation_scenario is None: + raise ValueError( + "Conversation scenario should have a set value when " + "formatting the LLM prompt. " + f"Encountered: {conversation_scenario}" + ) + + return self._prompt_template.format( + conversation_plan=conversation_scenario.conversation_plan, + conversation_history=_format_conversation_history(previous_invocations), + generated_user_response=get_text_from_content(invocation.user_content), + stop_signal=self._stop_signal, + ) + + def _convert_llm_response_to_score( + self, auto_rater_response: LlmResponse + ) -> AutoRaterScore: + response_text = get_text_from_content(auto_rater_response.content) + if response_text is None or not response_text: + return AutoRaterScore() + label = _parse_llm_response(response_text) + + if label == Label.VALID: + return AutoRaterScore(score=1.0) + elif label == Label.INVALID: + return AutoRaterScore(score=0.0) + else: + return AutoRaterScore() + + def _aggregate_samples( + self, + per_invocation_samples: list[PerInvocationResult], + ) -> PerInvocationResult: + """Aggregates samples by taking majority vote.""" + if not per_invocation_samples: + raise ValueError("No samples to aggregate into a result.") + + positive_results = [s for s in per_invocation_samples if s.score == 1.0] + negative_results = [s for s in per_invocation_samples if s.score == 0.0] + + if not positive_results and not negative_results: + return per_invocation_samples[0] + elif len(positive_results) > len(negative_results): + return positive_results[0] + else: # len(negative_results) >= len(positive_results) + return negative_results[0] + + def _aggregate_conversation_results( + self, per_invocation_results: list[PerInvocationResult] + ) -> EvaluationResult: + """Computes the fraction of results that resulted in a pass status.""" + num_valid = 0 + num_evaluated = 0 + for result in per_invocation_results: + if result.eval_status == EvalStatus.PASSED: + num_valid += result.score + + num_evaluated += 1 + + # If no invocation was evaluated, we mark the score as None. 
+ if num_evaluated == 0: + return EvaluationResult( + per_invocation_results=per_invocation_results, + ) + + overall_score = num_valid / num_evaluated + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=get_eval_status( + overall_score, self._criterion.threshold + ), + per_invocation_results=per_invocation_results, + ) + + def _evaluate_first_turn( + self, + first_invocation: Invocation, + conversation_scenario: ConversationScenario, + ) -> PerInvocationResult: + if first_invocation.user_content is None: + return PerInvocationResult( + actual_invocation=first_invocation, + eval_status=EvalStatus.NOT_EVALUATED, + ) + + score = int( + get_text_from_content(first_invocation.user_content).strip() + == conversation_scenario.starting_prompt.strip() + ) + return PerInvocationResult( + actual_invocation=first_invocation, + score=score, + eval_status=get_eval_status(score, self._eval_metric.threshold), + ) + + async def _evaluate_intermediate_turn( + self, + invocation_at_step: Invocation, + invocation_history: list[Invocation], + conversation_scenario: Optional[ConversationScenario], + ) -> PerInvocationResult: + + auto_rater_prompt = self._format_llm_prompt( + invocation=invocation_at_step, + conversation_scenario=conversation_scenario, + previous_invocations=invocation_history, + ) + + llm_request = LlmRequest( + model=self._llm_options.judge_model, + contents=[ + genai_types.Content( + parts=[genai_types.Part(text=auto_rater_prompt)], + role="user", + ) + ], + config=self._llm_options.judge_model_config, + ) + add_default_retry_options_if_not_present(llm_request) + num_samples = self._llm_options.num_samples + samples = [] + for _ in range(num_samples): + llm_score = await self._sample_llm(llm_request) + samples.append( + PerInvocationResult( + eval_status=get_eval_status( + llm_score.score, self._eval_metric.threshold + ), + score=llm_score.score, + actual_invocation=invocation_at_step, + ) + ) + if not samples: + return PerInvocationResult( + eval_status=EvalStatus.NOT_EVALUATED, + actual_invocation=invocation_at_step, + ) + + return self._aggregate_samples(samples) + + async def _evaluate_stop_signal_turn( + self, + invocation_history: list[Invocation], + conversation_scenario: ConversationScenario, + ) -> PerInvocationResult: + return await self._evaluate_intermediate_turn( + invocation_at_step=_get_stop_signal_invocation(self._stop_signal), + invocation_history=invocation_history, + conversation_scenario=conversation_scenario, + ) + + async def _sample_llm(self, llm_request: LlmRequest) -> AutoRaterScore: + async with Aclosing(self._llm.generate_content_async(llm_request)) as agen: + async for llm_response in agen: + # Non-streaming call, so there is only one response content. 
+ return self._convert_llm_response_to_score(llm_response) diff --git a/src/google/adk/evaluation/vertex_ai_eval_facade.py b/src/google/adk/evaluation/vertex_ai_eval_facade.py index bddcbe53f3..92ac8fad27 100644 --- a/src/google/adk/evaluation/vertex_ai_eval_facade.py +++ b/src/google/adk/evaluation/vertex_ai_eval_facade.py @@ -23,6 +23,7 @@ import pandas as pd from typing_extensions import override +from .eval_case import ConversationScenario from .eval_case import Invocation from .evaluator import EvalStatus from .evaluator import EvaluationResult @@ -69,6 +70,7 @@ def evaluate_invocations( self, actual_invocations: list[Invocation], expected_invocations: Optional[list[Invocation]], + _: Optional[ConversationScenario] = None, ) -> EvaluationResult: if self._expected_invocations_required and expected_invocations is None: raise ValueError("expected_invocations is needed by this metric.") diff --git a/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py b/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py new file mode 100644 index 0000000000..25d8776978 --- /dev/null +++ b/tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_v1.py @@ -0,0 +1,613 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
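+
+# Unit tests for PerTurnUserSimulatorQualityV1: LLM verdict parsing
+# (_parse_llm_response), prompt and conversation-history formatting,
+# per-sample and per-conversation aggregation, first-turn checks, and
+# end-to-end evaluate_invocations behavior.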
+ +from __future__ import annotations + +from google.adk.evaluation.eval_case import ConversationScenario +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_metrics import EvalMetric +from google.adk.evaluation.eval_metrics import EvalStatus +from google.adk.evaluation.eval_metrics import JudgeModelOptions +from google.adk.evaluation.eval_metrics import LlmBackedUserSimulatorCriterion +from google.adk.evaluation.evaluator import PerInvocationResult +from google.adk.evaluation.llm_as_judge import AutoRaterScore +from google.adk.evaluation.llm_as_judge_utils import Label +from google.adk.evaluation.simulation.per_turn_user_simulator_quality_v1 import _format_conversation_history +from google.adk.evaluation.simulation.per_turn_user_simulator_quality_v1 import _parse_llm_response +from google.adk.evaluation.simulation.per_turn_user_simulator_quality_v1 import PerTurnUserSimulatorQualityV1 +from google.adk.models.llm_response import LlmResponse +from google.genai import types as genai_types +import pytest + + +@pytest.mark.parametrize( + "response_text", + [ + """```json + { + "criteria": [ + { + "name": "TEST_NAME", + "reasoning": "test_resonining", + "passes": True + } + ], + "is_valid_undefined_key": True + } + ```""", + """```json + { + "criteria": [ + { + "name": "TEST_NAME", + "reasoning": "test_resonining", + "passes": True + } + ], + "is_valid": "undefined label", + } + ```""", + ], +) +def test_parse_llm_response_label_not_found(response_text): + label = _parse_llm_response(response_text) + assert label == Label.NOT_FOUND + + +@pytest.mark.parametrize( + "response_text", + [ + """```json + { + "criteria": [ + { + "name": "TEST_NAME", + "reasoning": "test_resonining", + "passes": True + } + ], + "is_valid": True + } + ```""", + """```json + { + "criteria": [ + { + "name": "TEST_NAME", + "reasoning": "test_resonining", + "passes": True + } + ], + "is_valid": "true" + } + ```""", + """```json + { + "criteria": [ + { + "name": "TEST_NAME", + "reasoning": "test_resonining", + "passes": True + } + ], + "is_valid": "valid" + } + ```""", + ], +) +def test_parse_llm_response_label_valid(response_text): + label = _parse_llm_response(response_text) + assert label == Label.VALID + + +@pytest.mark.parametrize( + "response_text", + [ + """```json + { + "criteria": [ + { + "name": "TEST_NAME", + "reasoning": "test_resonining", + "passes": False + } + ], + "is_valid": False + } + ```""", + """```json + { + "criteria": [ + { + "name": "TEST_NAME", + "reasoning": "test_resonining", + "passes": False + } + ], + "is_valid": "false", + } + ```""", + """```json + { + "criteria": [ + { + "name": "TEST_NAME", + "reasoning": "test_resonining", + "passes": False + } + ], + "is_valid": "invalid", + } + ```""", + ], +) +def test_parse_llm_response_label_invalid(response_text): + label = _parse_llm_response(response_text) + assert label == Label.INVALID + + +def create_test_template() -> str: + return """This is a test template with stop signal: `{stop_signal}`. 
+ +# Conversation Plan +{conversation_plan} + +# Conversation History +{conversation_history} + +# Generated User Response +{generated_user_response} +""".strip() + + +def _create_test_evaluator( + threshold: float = 1.0, stop_signal: str = "test stop signal" +) -> PerTurnUserSimulatorQualityV1: + evaluator = PerTurnUserSimulatorQualityV1( + EvalMetric( + metric_name="test_per_turn_user_simulator_quality_v1", + threshold=threshold, + criterion=LlmBackedUserSimulatorCriterion( + threshold=threshold, + stop_signal=stop_signal, + judge_model_options=JudgeModelOptions( + judge_model="gemini-2.5-flash", + judge_model_config=genai_types.GenerateContentConfig(), + num_samples=3, + ), + ), + ), + ) + evaluator._prompt_template = create_test_template() + return evaluator + + +def _create_test_conversation_scenario( + conversation_plan: str = "test conversation plan", + starting_prompt: str = "test starting prompt", +) -> ConversationScenario: + """Returns a ConversationScenario.""" + return ConversationScenario( + starting_prompt=starting_prompt, + conversation_plan=conversation_plan, + ) + + +def _create_test_invocation( + invocation_id: str, + user_content: str = "user content", + model_content: str = "model content", +) -> Invocation: + return Invocation( + invocation_id=invocation_id, + user_content=genai_types.Content( + parts=[genai_types.Part(text=user_content)], + role="user", + ), + final_response=genai_types.Content( + parts=[genai_types.Part(text=model_content)], + role="model", + ), + ) + + +def _create_test_invocations( + conversation_history: list[str], +) -> list[Invocation]: + conversation_length = len(conversation_history) + + assert conversation_length % 2 == 0 + + invocations = [] + for i in range(conversation_length // 2): + user_message = conversation_history[2 * i] + model_message = conversation_history[2 * i + 1] + + invocations.append( + _create_test_invocation( + "turn {i}", user_content=user_message, model_content=model_message + ) + ) + + return invocations + + +def test_format_llm_prompt(): + evaluator = _create_test_evaluator(stop_signal="test stop signal") + + starting_prompt = "first user prompt." + conversation_scenario = _create_test_conversation_scenario( + conversation_plan="test conversation plan.", + starting_prompt=starting_prompt, + ) + invocation_history = _create_test_invocations([ + starting_prompt, + "first agent response.", + "second user prompt.", + "second agent response.", + "third user prompt.", + "third agent response.", + ]) + + prompt = evaluator._format_llm_prompt( + invocation=invocation_history[-1], + conversation_scenario=conversation_scenario, + previous_invocations=invocation_history[:-1], + ) + + assert ( + prompt == """This is a test template with stop signal: `test stop signal`. + +# Conversation Plan +test conversation plan. + +# Conversation History +user: first user prompt. + +model: first agent response. + +user: second user prompt. + +model: second agent response. + +# Generated User Response +third user prompt. 
+""".strip() + ) + + +def test_convert_llm_response_to_score_pass(): + evaluator = _create_test_evaluator() + auto_rater_response = """```json +{ + "is_valid": True, +} +```""" + llm_response = LlmResponse( + content=genai_types.Content( + parts=[genai_types.Part(text=auto_rater_response)], + role="model", + ) + ) + auto_rater_score = evaluator._convert_llm_response_to_score(llm_response) + assert auto_rater_score == AutoRaterScore(score=1.0) + + +def test_convert_llm_response_to_score_failure(): + evaluator = _create_test_evaluator() + auto_rater_response = """```json +{ + "is_valid": False, +} +```""" + llm_response = LlmResponse( + content=genai_types.Content( + parts=[genai_types.Part(text=auto_rater_response)], + role="model", + ) + ) + auto_rater_score = evaluator._convert_llm_response_to_score(llm_response) + assert auto_rater_score == AutoRaterScore(score=0.0) + + +def test_convert_llm_response_to_score_invalid_json(): + evaluator = _create_test_evaluator() + llm_response = LlmResponse( + content=genai_types.Content( + parts=[genai_types.Part(text="invalid json")], + role="model", + ) + ) + auto_rater_score = evaluator._convert_llm_response_to_score(llm_response) + assert auto_rater_score == AutoRaterScore() + + +def test_convert_llm_response_to_score_missing_key(): + evaluator = _create_test_evaluator() + llm_response = LlmResponse( + content=genai_types.Content( + parts=[genai_types.Part(text="{}")], + role="model", + ) + ) + auto_rater_score = evaluator._convert_llm_response_to_score(llm_response) + assert auto_rater_score == AutoRaterScore() + + +def test_aggregate_samples_not_evaluated(): + evaluator = _create_test_evaluator() + samples = [ + PerInvocationResult( + actual_invocation=_create_test_invocation("1"), + score=None, + eval_status=EvalStatus.NOT_EVALUATED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("2"), + score=None, + eval_status=EvalStatus.NOT_EVALUATED, + ), + ] + + aggregation = evaluator._aggregate_samples(samples) + assert aggregation == samples[0] + + +def test_aggregate_samples_pass(): + evaluator = _create_test_evaluator() + # The majority of results should be positive. + samples = [ + PerInvocationResult( + actual_invocation=_create_test_invocation("1"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("2"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("3"), + score=0.0, + eval_status=EvalStatus.FAILED, + ), + ] + + aggregation_result = evaluator._aggregate_samples(samples) + + assert aggregation_result.score == 1.0 + assert aggregation_result.eval_status == EvalStatus.PASSED + + +def test_aggregate_samples_failure(): + evaluator = _create_test_evaluator() + + # The majority of results should be negative. 
+ samples = [ + PerInvocationResult( + actual_invocation=_create_test_invocation("1"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("2"), + score=0.0, + eval_status=EvalStatus.FAILED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("3"), + score=0.0, + eval_status=EvalStatus.FAILED, + ), + ] + + aggregation_result = evaluator._aggregate_samples(samples) + + assert aggregation_result.score == 0.0 + assert aggregation_result.eval_status == EvalStatus.FAILED + + +def test_format_conversation_history(): + conversation_history = [ + "first user prompt.", + "first agent response.", + "second user prompt.", + "second agent response.", + ] + invocation_history = _create_test_invocations(conversation_history) + formatted_history = _format_conversation_history(invocation_history) + assert formatted_history == """user: first user prompt. + +model: first agent response. + +user: second user prompt. + +model: second agent response.""" + + +def test_evaluate_first_turn_pass(): + evaluator = _create_test_evaluator( + threshold=0.8, stop_signal="test stop signal" + ) + conversation_scenario = _create_test_conversation_scenario( + conversation_plan="plan", + starting_prompt="test starting prompt", + ) + invocation = _create_test_invocation("1", user_content="test starting prompt") + + result = evaluator._evaluate_first_turn(invocation, conversation_scenario) + + assert result.score == 1.0 + assert result.eval_status == EvalStatus.PASSED + + +def test_evaluate_first_turn_failure(): + evaluator = _create_test_evaluator( + threshold=1.0, stop_signal="test stop signal" + ) + conversation_scenario = _create_test_conversation_scenario( + conversation_plan="plan", + starting_prompt="test starting prompt", + ) + invocation = _create_test_invocation("1", "wrong starting prompt") + + result = evaluator._evaluate_first_turn(invocation, conversation_scenario) + + assert result.score == 0.0 + assert result.eval_status == EvalStatus.FAILED + + +def test_aggregate_conversation_results_all_pass_produces_pass(): + evaluator = _create_test_evaluator() + results = [ + PerInvocationResult( + actual_invocation=_create_test_invocation("1"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("2"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("3"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("4"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + ] + aggregation = evaluator._aggregate_conversation_results(results) + assert aggregation.overall_score == 1.0 + assert aggregation.overall_eval_status == EvalStatus.PASSED + + +def test_aggregate_conversation_results_percentage_above_threshold_produces_pass(): + evaluator = _create_test_evaluator(threshold=0.7) + results = [ + PerInvocationResult( + actual_invocation=_create_test_invocation("1"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("2"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("3"), + score=0.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("4"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + ] + aggregation = 
evaluator._aggregate_conversation_results(results) + assert aggregation.overall_score == 0.75 + assert aggregation.overall_eval_status == EvalStatus.PASSED + + +def test_aggregate_conversation_results_all_failures_produces_failure(): + evaluator = _create_test_evaluator() + results = [ + PerInvocationResult( + actual_invocation=_create_test_invocation("1"), + score=0.0, + eval_status=EvalStatus.FAILED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("2"), + score=0.0, + eval_status=EvalStatus.FAILED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("3"), + score=0.0, + eval_status=EvalStatus.FAILED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("4"), + score=0.0, + eval_status=EvalStatus.FAILED, + ), + ] + aggregation = evaluator._aggregate_conversation_results(results) + assert aggregation.overall_score == 0.0 + assert aggregation.overall_eval_status == EvalStatus.FAILED + + +def test_aggregate_conversation_percentage_below_threshold_produces_failure(): + evaluator = _create_test_evaluator(threshold=1.0) + results = [ + PerInvocationResult( + actual_invocation=_create_test_invocation("1"), + score=0.0, + eval_status=EvalStatus.FAILED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("2"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("3"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + PerInvocationResult( + actual_invocation=_create_test_invocation("4"), + score=1.0, + eval_status=EvalStatus.PASSED, + ), + ] + aggregation = evaluator._aggregate_conversation_results(results) + assert aggregation.overall_score == 0.75 + assert aggregation.overall_eval_status == EvalStatus.FAILED + + +@pytest.mark.asyncio +async def test_evaluate_invocations_all_pass(): + evaluator = _create_test_evaluator() + + async def sample_llm_valid(*args, **kwargs): + return AutoRaterScore(score=1.0) + + evaluator._sample_llm = sample_llm_valid + starting_prompt = "first user prompt." + conversation_scenario = _create_test_conversation_scenario( + starting_prompt=starting_prompt + ) + invocations = _create_test_invocations( + [starting_prompt, "model 1.", "user 2.", "model 2."] + ) + result = await evaluator.evaluate_invocations( + actual_invocations=invocations, + expected_invocations=None, + conversation_scenario=conversation_scenario, + ) + + assert result.overall_score == 1.0 + assert result.overall_eval_status == EvalStatus.PASSED + assert len(result.per_invocation_results) == 2 + assert result.per_invocation_results[0].score == 1.0 + assert result.per_invocation_results[1].score == 1.0 From 8335f35015c7d4349bc4ac47dedbe99663b78e62 Mon Sep 17 00:00:00 2001 From: Hiroaki Sano Date: Mon, 15 Dec 2025 11:08:54 -0800 Subject: [PATCH 2/4] fix: Change error_message column type to TEXT in DatabaseSessionService Merge https://github.com/google/adk-python/pull/3917 To migrate from existing DB : ALTER TABLE events ALTER COLUMN error_message TYPE TEXT; -- PostgreSQL ALTER TABLE events MODIFY error_message TEXT; -- MySQL SQLite: Doesn't enforce VARCHAR length limits anyway. No impact. ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** n/a **2. 
Or, if no issue exists, describe the change:** **Problem:** When storing events with error messages longer than 1024 characters using `DatabaseSessionService`, PostgreSQL raises: ``` ERROR: value too long for type character varying(1024) ``` The `error_message` column in `StorageEvent` is defined as `String(1024)`, which maps to `VARCHAR(1024)`. Error messages can exceed 1024 characters. **Solution:** Change the column type from `String(1024)` to `Text` to allow unlimited length error messages. ### Testing Plan **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. $ pytest ./tests/unittests/sessions/ -v ======================= 75 passed, 3 warnings in 26.92s ======================== **Manual End-to-End (E2E) Tests:** - Verified that events with long error messages (>1024 chars) can be stored in PostgreSQL - Verified backward compatibility with existing databases ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [x] Any dependent changes have been merged and published in downstream modules. ### Additional context This is a minimal change (1 line) that only affects the `error_message` column type definition. Co-authored-by: Xiang (Sean) Zhou COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3917 from hiroakis:main 1474fd552cdbd7206de383e5507fd8a733aecda1 PiperOrigin-RevId: 844845692 --- src/google/adk/sessions/database_session_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index a352918211..047340c55e 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -269,7 +269,7 @@ class StorageEvent(Base): error_code: Mapped[str] = mapped_column( String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True ) - error_message: Mapped[str] = mapped_column(String(1024), nullable=True) + error_message: Mapped[str] = mapped_column(Text, nullable=True) interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) input_transcription: Mapped[dict[str, Any]] = mapped_column( DynamicJSON, nullable=True From e8ab7dafa96d5890a4fff919b9fa180993ef5830 Mon Sep 17 00:00:00 2001 From: Liang Wu Date: Mon, 15 Dec 2025 13:38:13 -0800 Subject: [PATCH 3/4] chore: Move SQLite migration script to migration/ folder Co-authored-by: Liang Wu PiperOrigin-RevId: 844902789 --- .../sessions/{ => migration}/migrate_from_sqlalchemy_sqlite.py | 0 src/google/adk/sessions/sqlite_session_service.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/google/adk/sessions/{ => migration}/migrate_from_sqlalchemy_sqlite.py (100%) diff --git a/src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py similarity index 100% rename from src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py rename to src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 
8ba6531f52..e0d44b3872 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -107,7 +107,7 @@ def __init__(self, db_path: str): f"Database {db_path} seems to use an old schema." " Please run the migration command to" " migrate it to the new schema. Example: `python -m" - " google.adk.sessions.migrate_from_sqlalchemy_sqlite" + " google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite" f" --source_db_path {db_path} --dest_db_path" f" {db_path}.new` then backup {db_path} and rename" f" {db_path}.new to {db_path}." From a0885064b0cbef3b25484025da0748dc64320d4a Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Mon, 15 Dec 2025 13:45:41 -0800 Subject: [PATCH 4/4] chore: Add `override_feature_enabled` to override the default feature enable states Co-authored-by: Xuan Yang PiperOrigin-RevId: 844905911 --- src/google/adk/features/__init__.py | 2 + src/google/adk/features/_feature_registry.py | 45 +++++++++- .../features/test_feature_registry.py | 83 ++++++++++++++++++- 3 files changed, 128 insertions(+), 2 deletions(-) diff --git a/src/google/adk/features/__init__.py b/src/google/adk/features/__init__.py index f948c0779d..578a44966e 100644 --- a/src/google/adk/features/__init__.py +++ b/src/google/adk/features/__init__.py @@ -17,6 +17,7 @@ from ._feature_decorator import working_in_progress from ._feature_registry import FeatureName from ._feature_registry import is_feature_enabled +from ._feature_registry import override_feature_enabled __all__ = [ "experimental", @@ -24,4 +25,5 @@ "working_in_progress", "FeatureName", "is_feature_enabled", + "override_feature_enabled", ] diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 46b56eb6d9..934d9f8768 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -101,6 +101,9 @@ class FeatureConfig: # Track which experimental features have already warned (warn only once) _WARNED_FEATURES: set[FeatureName] = set() +# Programmatic overrides (highest priority, checked before env vars) +_FEATURE_OVERRIDES: dict[FeatureName, bool] = {} + def _get_feature_config( feature_name: FeatureName, @@ -129,12 +132,45 @@ def _register_feature( _FEATURE_REGISTRY[feature_name] = config +def override_feature_enabled( + feature_name: FeatureName, + enabled: bool, +) -> None: + """Programmatically override a feature's enabled state. + + This override takes highest priority, superseding environment variables + and registry defaults. Use this when environment variables are not + available or practical in your deployment environment. + + Args: + feature_name: The feature name to override. + enabled: Whether the feature should be enabled. + + Example: + ```python + from google.adk.features import FeatureName, override_feature_enabled + + # Enable a feature programmatically + override_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True) + ``` + """ + config = _get_feature_config(feature_name) + if config is None: + raise ValueError(f"Feature {feature_name} is not registered.") + _FEATURE_OVERRIDES[feature_name] = enabled + + def is_feature_enabled(feature_name: FeatureName) -> bool: """Check if a feature is enabled at runtime. This function is used for runtime behavior gating within stable features. It allows you to conditionally enable new behavior based on feature flags. + Priority order (highest to lowest): + 1. Programmatic overrides (via override_feature_enabled) + 2. 
Environment variables (ADK_ENABLE_* / ADK_DISABLE_*) + 3. Registry defaults + Args: feature_name: The feature name (e.g., FeatureName.RESUMABILITY). @@ -156,7 +192,14 @@ def _execute_agent_loop(): if config is None: raise ValueError(f"Feature {feature_name} is not registered.") - # Check environment variables first (highest priority) + # Check programmatic overrides first (highest priority) + if feature_name in _FEATURE_OVERRIDES: + enabled = _FEATURE_OVERRIDES[feature_name] + if enabled and config.stage != FeatureStage.STABLE: + _emit_non_stable_warning_once(feature_name, config.stage) + return enabled + + # Check environment variables second feature_name_str = ( feature_name.value if isinstance(feature_name, FeatureName) diff --git a/tests/unittests/features/test_feature_registry.py b/tests/unittests/features/test_feature_registry.py index ab84d986ea..1d6b0f2d6d 100644 --- a/tests/unittests/features/test_feature_registry.py +++ b/tests/unittests/features/test_feature_registry.py @@ -17,6 +17,7 @@ import os import warnings +from google.adk.features._feature_registry import _FEATURE_OVERRIDES from google.adk.features._feature_registry import _FEATURE_REGISTRY from google.adk.features._feature_registry import _get_feature_config from google.adk.features._feature_registry import _register_feature @@ -24,6 +25,7 @@ from google.adk.features._feature_registry import FeatureConfig from google.adk.features._feature_registry import FeatureStage from google.adk.features._feature_registry import is_feature_enabled +from google.adk.features._feature_registry import override_feature_enabled import pytest FEATURE_CONFIG_WIP = FeatureConfig(FeatureStage.WIP, default_on=False) @@ -38,7 +40,7 @@ @pytest.fixture(autouse=True) def reset_env_and_registry(monkeypatch): - """Reset environment variables and registry before each test.""" + """Reset environment variables, registry and overrides before each test.""" # Clean up environment variables for key in list(os.environ.keys()): if key.startswith("ADK_ENABLE_") or key.startswith("ADK_DISABLE_"): @@ -47,11 +49,17 @@ def reset_env_and_registry(monkeypatch): # Reset warned features set _WARNED_FEATURES.clear() + # Reset feature overrides + _FEATURE_OVERRIDES.clear() + yield # Reset warned features set _WARNED_FEATURES.clear() + # Reset feature overrides + _FEATURE_OVERRIDES.clear() + class TestGetFeatureConfig: """Tests for get_feature_config() function.""" @@ -159,3 +167,76 @@ def test_warn_once_per_feature(self, monkeypatch): assert "[EXPERIMENTAL] feature DISABLED_FEATURE is enabled." in str( w[0].message ) + + +class TestOverrideFeatureEnabled: + """Tests for override_feature_enabled() function.""" + + def test_override_not_in_registry_raises_value_error(self): + """Overriding features not in registry raises ValueError.""" + with pytest.raises(ValueError): + override_feature_enabled("UNKNOWN_FEATURE", True) + + def test_override_enables_disabled_feature(self): + """Programmatic override can enable a disabled feature.""" + _register_feature("OVERRIDE_TEST", FEATURE_CONFIG_EXPERIMENTAL_DISABLED) + assert not is_feature_enabled("OVERRIDE_TEST") + + override_feature_enabled("OVERRIDE_TEST", True) + with warnings.catch_warnings(record=True) as w: + assert is_feature_enabled("OVERRIDE_TEST") + assert len(w) == 1 + assert "[EXPERIMENTAL] feature OVERRIDE_TEST is enabled." 
in str( + w[0].message + ) + + def test_override_disables_enabled_feature(self): + """Programmatic override can disable an enabled feature.""" + _register_feature("OVERRIDE_TEST", FEATURE_CONFIG_EXPERIMENTAL_ENABLED) + + override_feature_enabled("OVERRIDE_TEST", False) + with warnings.catch_warnings(record=True) as w: + assert not is_feature_enabled("OVERRIDE_TEST") + assert not w + + def test_override_takes_precedence_over_env_enable(self, monkeypatch): + """Programmatic override takes precedence over ADK_ENABLE_* env var.""" + _register_feature("PRIORITY_TEST", FEATURE_CONFIG_EXPERIMENTAL_DISABLED) + + # Set env var to enable + monkeypatch.setenv("ADK_ENABLE_PRIORITY_TEST", "true") + assert is_feature_enabled("PRIORITY_TEST") + + # But override to disable + override_feature_enabled("PRIORITY_TEST", False) + + with warnings.catch_warnings(record=True) as w: + assert not is_feature_enabled("PRIORITY_TEST") + assert not w + + def test_override_takes_precedence_over_env_disable(self, monkeypatch): + """Programmatic override takes precedence over ADK_DISABLE_* env var.""" + _register_feature("PRIORITY_TEST", FEATURE_CONFIG_EXPERIMENTAL_ENABLED) + + # Set env var to disable + monkeypatch.setenv("ADK_DISABLE_PRIORITY_TEST", "true") + assert not is_feature_enabled("PRIORITY_TEST") + + # But override to enable + override_feature_enabled("PRIORITY_TEST", True) + + with warnings.catch_warnings(record=True) as w: + assert is_feature_enabled("PRIORITY_TEST") + assert len(w) == 1 + assert "[EXPERIMENTAL] feature PRIORITY_TEST is enabled." in str( + w[0].message + ) + + def test_override_stable_feature_no_warning(self): + """Overriding stable features does not emit warnings.""" + _register_feature("STABLE_OVERRIDE", FEATURE_CONFIG_STABLE) + + override_feature_enabled("STABLE_OVERRIDE", True) + with warnings.catch_warnings(record=True) as w: + assert is_feature_enabled("STABLE_OVERRIDE") + assert not w
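
As a usage reference for the metric introduced in [PATCH 1/4], here is a minimal sketch of wiring up the per-turn user simulator evaluator directly, mirroring the unit tests above. The model name, stop signal, scenario text, and invocations are illustrative placeholders, and running it requires Gemini API credentials; in normal use the registry resolves the evaluator from `PrebuiltMetrics.PER_TURN_USER_SIMULATOR_QUALITY_V1` and the invocations come from a recorded user-simulator session.

```
import asyncio

from google.adk.evaluation.eval_case import ConversationScenario
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.eval_metrics import JudgeModelOptions
from google.adk.evaluation.eval_metrics import LlmBackedUserSimulatorCriterion
from google.adk.evaluation.simulation.per_turn_user_simulator_quality_v1 import PerTurnUserSimulatorQualityV1
from google.genai import types as genai_types


def _invocation(invocation_id: str, user_text: str, model_text: str) -> Invocation:
  # Helper mirroring the test fixtures: one user turn plus the agent's reply.
  return Invocation(
      invocation_id=invocation_id,
      user_content=genai_types.Content(
          parts=[genai_types.Part(text=user_text)], role="user"
      ),
      final_response=genai_types.Content(
          parts=[genai_types.Part(text=model_text)], role="model"
      ),
  )


async def main():
  scenario = ConversationScenario(
      starting_prompt="Hi, I need to reschedule my appointment.",
      conversation_plan="1. Ask to reschedule. 2. Provide the new date.",
  )
  metric = EvalMetric(
      metric_name="per_turn_user_simulator_quality_v1",
      threshold=0.8,
      criterion=LlmBackedUserSimulatorCriterion(
          threshold=0.8,
          stop_signal="</end_conversation>",  # Illustrative; match your simulator.
          judge_model_options=JudgeModelOptions(
              judge_model="gemini-2.5-flash",
              judge_model_config=genai_types.GenerateContentConfig(),
              num_samples=3,
          ),
      ),
  )
  evaluator = PerTurnUserSimulatorQualityV1(metric)
  result = await evaluator.evaluate_invocations(
      actual_invocations=[
          _invocation("turn_0", scenario.starting_prompt, "Sure, what date?"),
          _invocation("turn_1", "Next Tuesday at 10am.", "Done, see you then!"),
      ],
      expected_invocations=None,
      conversation_scenario=scenario,
  )
  print(result.overall_score, result.overall_eval_status)


asyncio.run(main())
```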