diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index fdc83b1557..f0b8fba022 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -1553,6 +1553,47 @@ def cli_deploy_cloud_run( click.secho(f"Deploy failed: {e}", fg="red", err=True) +@main.group() +def migrate(): + """ADK migration commands.""" + pass + + +@migrate.command("session", cls=HelpfulCommand) +@click.option( + "--source_db_url", + required=True, + help=( + "SQLAlchemy URL of source database in database session service, e.g." + " sqlite:///source.db." + ), +) +@click.option( + "--dest_db_url", + required=True, + help=( + "SQLAlchemy URL of destination database in database session service," + " e.g. sqlite:///dest.db." + ), +) +@click.option( + "--log_level", + type=LOG_LEVELS, + default="INFO", + help="Optional. Set the logging level", +) +def cli_migrate_session( + *, source_db_url: str, dest_db_url: str, log_level: str +): + """Migrates a session database to the latest schema version.""" + logs.setup_adk_logger(getattr(logging, log_level.upper())) + try: + migration_runner.upgrade(source_db_url, dest_db_url) + click.secho("Migration check and upgrade process finished.", fg="green") + except Exception as e: + click.secho(f"Migration failed: {e}", fg="red", err=True) + + @deploy.command("agent_engine") @click.option( "--api_key", diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 036b56ef23..2ab0130639 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -14,8 +14,10 @@ from __future__ import annotations +from contextlib import contextmanager from dataclasses import dataclass from enum import Enum +from typing import Generator import warnings from ..utils.env_utils import is_env_enabled @@ -24,17 +26,23 @@ class FeatureName(str, Enum): """Feature names.""" + AUTHENTICATED_FUNCTION_TOOL = "AUTHENTICATED_FUNCTION_TOOL" + BASE_AUTHENTICATED_TOOL = "BASE_AUTHENTICATED_TOOL" BIG_QUERY_TOOLSET = "BIG_QUERY_TOOLSET" BIG_QUERY_TOOL_CONFIG = "BIG_QUERY_TOOL_CONFIG" BIGTABLE_TOOL_SETTINGS = "BIGTABLE_TOOL_SETTINGS" + BIGTABLE_TOOLSET = "BIGTABLE_TOOLSET" COMPUTER_USE = "COMPUTER_USE" GOOGLE_CREDENTIALS_CONFIG = "GOOGLE_CREDENTIALS_CONFIG" GOOGLE_TOOL = "GOOGLE_TOOL" JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL" PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING" + PUBSUB_TOOL_CONFIG = "PUBSUB_TOOL_CONFIG" PUBSUB_TOOLSET = "PUBSUB_TOOLSET" SPANNER_TOOLSET = "SPANNER_TOOLSET" SPANNER_TOOL_SETTINGS = "SPANNER_TOOL_SETTINGS" + TOOL_CONFIG = "TOOL_CONFIG" + TOOL_CONFIRMATION = "TOOL_CONFIRMATION" class FeatureStage(Enum): @@ -67,6 +75,12 @@ class FeatureConfig: # Central registry: FeatureName -> FeatureConfig _FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = { + FeatureName.AUTHENTICATED_FUNCTION_TOOL: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), + FeatureName.BASE_AUTHENTICATED_TOOL: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.BIG_QUERY_TOOLSET: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), @@ -76,6 +90,9 @@ class FeatureConfig: FeatureName.BIGTABLE_TOOL_SETTINGS: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), + FeatureName.BIGTABLE_TOOLSET: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.COMPUTER_USE: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), @@ -91,6 +108,9 @@ class FeatureConfig: FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), + FeatureName.PUBSUB_TOOL_CONFIG: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.PUBSUB_TOOLSET: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), @@ -100,6 +120,12 @@ class FeatureConfig: FeatureName.SPANNER_TOOL_SETTINGS: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), + FeatureName.TOOL_CONFIG: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), + FeatureName.TOOL_CONFIRMATION: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), } # Track which experimental features have already warned (warn only once) @@ -240,3 +266,52 @@ def _emit_non_stable_warning_once( f"[{feature_stage.name.upper()}] feature {feature_name} is enabled." ) warnings.warn(full_message, category=UserWarning, stacklevel=4) + + +@contextmanager +def temporary_feature_override( + feature_name: FeatureName, + enabled: bool, +) -> Generator[None, None, None]: + """Temporarily override a feature's enabled state within a context. + + This context manager is useful for testing or temporarily enabling/disabling + a feature within a specific scope. The original state is restored when the + context exits. + + Args: + feature_name: The feature name to override. + enabled: Whether the feature should be enabled. + + Yields: + None + + Example: + ```python + from google.adk.features import FeatureName, temporary_feature_override + + # Temporarily enable a feature for testing + with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True): + # Feature is enabled here + result = some_function_that_checks_feature() + # Feature is restored to original state here + ``` + """ + config = _get_feature_config(feature_name) + if config is None: + raise ValueError(f"Feature {feature_name} is not registered.") + + # Save the original override state + had_override = feature_name in _FEATURE_OVERRIDES + original_value = _FEATURE_OVERRIDES.get(feature_name) + + # Apply the temporary override + _FEATURE_OVERRIDES[feature_name] = enabled + try: + yield + finally: + # Restore the original state + if had_override: + _FEATURE_OVERRIDES[feature_name] = original_value + else: + _FEATURE_OVERRIDES.pop(feature_name, None) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index dda3689d26..9fb02d865d 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -461,7 +461,8 @@ async def _content_to_message_param( A litellm Message, a list of litellm Messages. """ - tool_messages = [] + tool_messages: list[Message] = [] + non_tool_parts: list[types.Part] = [] for part in content.parts: if part.function_response: response = part.function_response.response @@ -477,9 +478,22 @@ async def _content_to_message_param( content=response_content, ) ) - if tool_messages: + else: + non_tool_parts.append(part) + + if tool_messages and not non_tool_parts: return tool_messages if len(tool_messages) > 1 else tool_messages[0] + if tool_messages and non_tool_parts: + follow_up = await _content_to_message_param( + types.Content(role=content.role, parts=non_tool_parts), + provider=provider, + ) + follow_up_messages = ( + follow_up if isinstance(follow_up, list) else [follow_up] + ) + return tool_messages + follow_up_messages + # Handle user or assistant messages role = _to_litellm_role(content.role) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 730216e0b0..cbf2c59548 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -153,15 +153,21 @@ def __init__( """Initializes the Runner. Developers should provide either an `app` instance or both `app_name` and - `agent`. Providing a mix of `app` and `app_name`/`agent` will result in a - `ValueError`. Providing `app` is the recommended way to create a runner. + `agent`. When `app` is provided, `app_name` can optionally override the + app's name (useful for deployment scenarios like Agent Engine where the + resource name differs from the app's identifier). However, `agent` should + not be provided when `app` is provided. Providing `app` is the recommended + way to create a runner. Args: - app: An optional `App` instance. If provided, `app_name` and `agent` - should not be specified. + app: An optional `App` instance. If provided, `agent` should not be + specified. `app_name` can optionally override `app.name`. app_name: The application name of the runner. Required if `app` is not - provided. - agent: The root agent to run. Required if `app` is not provided. + provided. If `app` is provided, this can optionally override `app.name` + (e.g., for deployment scenarios where a resource name differs from the + app identifier). + agent: The root agent to run. Required if `app` is not provided. Should + not be provided when `app` is provided. plugins: Deprecated. A list of plugins for the runner. Please use the `app` argument to provide plugins instead. artifact_service: The artifact service for the runner. @@ -171,8 +177,8 @@ def __init__( plugin_close_timeout: The timeout in seconds for plugin close methods. Raises: - ValueError: If `app` is provided along with `app_name` or `plugins`, or - if `app` is not provided but either `app_name` or `agent` is missing. + ValueError: If `app` is provided along with `agent` or `plugins`, or if + `app` is not provided but either `app_name` or `agent` is missing. """ self.app = app ( @@ -213,7 +219,8 @@ def _validate_runner_params( Args: app: An optional `App` instance. - app_name: The application name of the runner. + app_name: The application name of the runner. Can override app.name when + app is provided. agent: The root agent to run. plugins: A list of plugins for the runner. @@ -232,10 +239,6 @@ def _validate_runner_params( ) if app: - if app_name: - raise ValueError( - 'When app is provided, app_name should not be provided.' - ) if agent: raise ValueError('When app is provided, agent should not be provided.') if plugins: @@ -243,7 +246,9 @@ def _validate_runner_params( 'When app is provided, plugins should not be provided and should be' ' provided in the app instead.' ) - app_name = app.name + # Allow app_name to override app.name (useful for deployment scenarios + # like Agent Engine where resource names differ from app identifiers) + app_name = app_name or app.name agent = app.root_agent plugins = app.plugins context_cache_config = app.context_cache_config diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 799f0ea4dd..ea40bee0c3 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -23,6 +23,8 @@ from . import _automatic_function_calling_util from ..agents.common_configs import AgentRefConfig +from ..features import FeatureName +from ..features import is_feature_enabled from ..memory.in_memory_memory_service import InMemoryMemoryService from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService @@ -82,29 +84,48 @@ def _get_declaration(self) -> types.FunctionDeclaration: # Override the description with the agent's description result.description = self.agent.description else: - result = types.FunctionDeclaration( - parameters=types.Schema( - type=types.Type.OBJECT, - properties={ - 'request': types.Schema( - type=types.Type.STRING, - ), - }, - required=['request'], - ), - description=self.agent.description, - name=self.name, - ) + if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL): + result = types.FunctionDeclaration( + name=self.name, + description=self.agent.description, + parameters_json_schema={ + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + ) + else: + result = types.FunctionDeclaration( + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + 'request': types.Schema( + type=types.Type.STRING, + ), + }, + required=['request'], + ), + description=self.agent.description, + name=self.name, + ) # Set response schema for non-GEMINI_API variants if self._api_variant != GoogleLLMVariant.GEMINI_API: # Determine response type based on agent's output schema if isinstance(self.agent, LlmAgent) and self.agent.output_schema: # Agent has structured output schema - response is an object - result.response = types.Schema(type=types.Type.OBJECT) + if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL): + result.response_json_schema = {'type': 'object'} + else: + result.response = types.Schema(type=types.Type.OBJECT) else: # Agent returns text - response is a string - result.response = types.Schema(type=types.Type.STRING) + if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL): + result.response_json_schema = {'type': 'string'} + else: + result.response = types.Schema(type=types.Type.STRING) result.name = self.name return result diff --git a/src/google/adk/tools/authenticated_function_tool.py b/src/google/adk/tools/authenticated_function_tool.py index 01e44ed000..5a1cc932fb 100644 --- a/src/google/adk/tools/authenticated_function_tool.py +++ b/src/google/adk/tools/authenticated_function_tool.py @@ -18,7 +18,6 @@ import logging from typing import Any from typing import Callable -from typing import Dict from typing import Optional from typing import Union @@ -27,14 +26,15 @@ from ..auth.auth_credential import AuthCredential from ..auth.auth_tool import AuthConfig from ..auth.credential_manager import CredentialManager -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .function_tool import FunctionTool from .tool_context import ToolContext logger = logging.getLogger("google_adk." + __name__) -@experimental +@experimental(FeatureName.AUTHENTICATED_FUNCTION_TOOL) class AuthenticatedFunctionTool(FunctionTool): """A FunctionTool that handles authentication before the actual tool logic gets called. Functions can accept a special `credential` argument which is the diff --git a/src/google/adk/tools/base_authenticated_tool.py b/src/google/adk/tools/base_authenticated_tool.py index 6279d4f725..862d1cef5a 100644 --- a/src/google/adk/tools/base_authenticated_tool.py +++ b/src/google/adk/tools/base_authenticated_tool.py @@ -25,14 +25,15 @@ from ..auth.auth_credential import AuthCredential from ..auth.auth_tool import AuthConfig from ..auth.credential_manager import CredentialManager -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .base_tool import BaseTool from .tool_context import ToolContext logger = logging.getLogger("google_adk." + __name__) -@experimental +@experimental(FeatureName.BASE_AUTHENTICATED_TOOL) class BaseAuthenticatedTool(BaseTool): """A base tool class that handles authentication before the actual tool logic gets called. Functions can accept a special `credential` argument which is the diff --git a/src/google/adk/tools/bigtable/bigtable_toolset.py b/src/google/adk/tools/bigtable/bigtable_toolset.py index 3b39e908a9..424a6fee21 100644 --- a/src/google/adk/tools/bigtable/bigtable_toolset.py +++ b/src/google/adk/tools/bigtable/bigtable_toolset.py @@ -23,18 +23,19 @@ from . import metadata_tool from . import query_tool +from ...features import experimental +from ...features import FeatureName from ...tools.base_tool import BaseTool from ...tools.base_toolset import BaseToolset from ...tools.base_toolset import ToolPredicate from ...tools.google_tool import GoogleTool -from ...utils.feature_decorator import experimental from .bigtable_credentials import BigtableCredentialsConfig from .settings import BigtableToolSettings DEFAULT_BIGTABLE_TOOL_NAME_PREFIX = "bigtable" -@experimental +@experimental(FeatureName.BIGTABLE_TOOLSET) class BigtableToolset(BaseToolset): """Bigtable Toolset contains tools for interacting with Bigtable data and metadata. diff --git a/src/google/adk/tools/pubsub/config.py b/src/google/adk/tools/pubsub/config.py index eb48a1f7f4..60f21f1e9b 100644 --- a/src/google/adk/tools/pubsub/config.py +++ b/src/google/adk/tools/pubsub/config.py @@ -17,10 +17,11 @@ from pydantic import BaseModel from pydantic import ConfigDict -from ...utils.feature_decorator import experimental +from ...features import experimental +from ...features import FeatureName -@experimental('Config defaults may have breaking change in the future.') +@experimental(FeatureName.PUBSUB_TOOL_CONFIG) class PubSubToolConfig(BaseModel): """Configuration for Pub/Sub tools.""" diff --git a/src/google/adk/tools/tool_configs.py b/src/google/adk/tools/tool_configs.py index 6953afabd5..bfeba5697b 100644 --- a/src/google/adk/tools/tool_configs.py +++ b/src/google/adk/tools/tool_configs.py @@ -20,24 +20,25 @@ from pydantic import ConfigDict from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName -@experimental +@experimental(FeatureName.TOOL_CONFIG) class BaseToolConfig(BaseModel): """The base class for all tool configs.""" model_config = ConfigDict(extra="forbid") -@experimental +@experimental(FeatureName.TOOL_CONFIG) class ToolArgsConfig(BaseModel): """Config to host free key-value pairs for the args in ToolConfig.""" model_config = ConfigDict(extra="allow") -@experimental +@experimental(FeatureName.TOOL_CONFIG) class ToolConfig(BaseModel): """The configuration for a tool. diff --git a/src/google/adk/tools/tool_confirmation.py b/src/google/adk/tools/tool_confirmation.py index a561ac6a95..6f71699c48 100644 --- a/src/google/adk/tools/tool_confirmation.py +++ b/src/google/adk/tools/tool_confirmation.py @@ -20,12 +20,12 @@ from pydantic import alias_generators from pydantic import BaseModel from pydantic import ConfigDict -from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName -@experimental +@experimental(FeatureName.TOOL_CONFIRMATION) class ToolConfirmation(BaseModel): """Represents a tool confirmation configuration.""" diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 01ece3f183..b3f4bd9e25 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1813,6 +1813,46 @@ async def test_content_to_message_param_multi_part_function_response(): assert messages[1]["content"] == '{"value": 123}' +@pytest.mark.asyncio +async def test_content_to_message_param_function_response_with_extra_parts(): + tool_part = types.Part.from_function_response( + name="load_image", + response={"status": "success"}, + ) + tool_part.function_response.id = "tool_call_1" + + text_part = types.Part.from_text(text="[Image: img_123.png]") + image_bytes = b"test_image_data" + image_part = types.Part.from_bytes(data=image_bytes, mime_type="image/png") + + content = types.Content( + role="user", + parts=[tool_part, text_part, image_part], + ) + + messages = await _content_to_message_param(content) + assert isinstance(messages, list) + assert messages == [ + { + "role": "tool", + "tool_call_id": "tool_call_1", + "content": '{"status": "success"}', + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "[Image: img_123.png]"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + }, + ] + + @pytest.mark.asyncio async def test_content_to_message_param_function_response_preserves_string(): """Tests that string responses are used directly without double-serialization. diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index c347a78931..710ca90d31 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -648,20 +648,38 @@ async def test_runner_passes_plugin_close_timeout(self): ) assert runner.plugin_manager._close_timeout == 10.0 - def test_runner_init_raises_error_with_app_and_app_name_and_agent(self): - """Test that ValueError is raised when app, app_name and agent are provided.""" + @pytest.mark.filterwarnings( + "ignore:The `plugins` argument is deprecated:DeprecationWarning" + ) + def test_runner_init_raises_error_with_app_and_agent(self): + """Test that ValueError is raised when app and agent are provided.""" with pytest.raises( ValueError, - match="When app is provided, app_name should not be provided.", + match="When app is provided, agent should not be provided.", ): Runner( app=App(name="test_app", root_agent=self.root_agent), - app_name="test_app", agent=self.root_agent, session_service=self.session_service, artifact_service=self.artifact_service, ) + @pytest.mark.filterwarnings( + "ignore:The `plugins` argument is deprecated:DeprecationWarning" + ) + def test_runner_init_allows_app_name_override_with_app(self): + """Test that app_name can override app.name when both are provided.""" + app = App(name="test_app", root_agent=self.root_agent) + runner = Runner( + app=app, + app_name="override_name", + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + assert runner.app_name == "override_name" + assert runner.agent == self.root_agent + assert runner.app == app + def test_runner_init_raises_error_without_app_and_app_name(self): """Test ValueError is raised when app is not provided and app_name is missing.""" with pytest.raises( diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 1284e73bce..235830195f 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -22,6 +22,8 @@ from google.adk.auth.auth_credential import HttpCredentials from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_credential import ServiceAccount +from google.adk.features import FeatureName +from google.adk.features._feature_registry import temporary_feature_override from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.tool_context import ToolContext @@ -129,17 +131,16 @@ def test_get_declaration(self): assert declaration.description == "Test tool description" assert declaration.parameters is not None - def test_get_declaration_with_json_schema_for_func_decl_enabled( - self, monkeypatch - ): + def test_get_declaration_with_json_schema_for_func_decl_enabled(self): """Test function declaration generation with json schema for func decl enabled.""" tool = MCPTool( mcp_tool=self.mock_mcp_tool, mcp_session_manager=self.mock_session_manager, ) - with monkeypatch.context() as m: - m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true") + with temporary_feature_override( + FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True + ): declaration = tool._get_declaration() assert isinstance(declaration, FunctionDeclaration) @@ -151,7 +152,7 @@ def test_get_declaration_with_json_schema_for_func_decl_enabled( assert declaration.response_json_schema is None def test_get_declaration_with_output_schema_and_json_schema_for_func_decl_enabled( - self, monkeypatch + self, ): """Test function declaration generation with an output schema and json schema for func decl enabled.""" output_schema = { @@ -169,8 +170,9 @@ def test_get_declaration_with_output_schema_and_json_schema_for_func_decl_enable mcp_session_manager=self.mock_session_manager, ) - with monkeypatch.context() as m: - m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true") + with temporary_feature_override( + FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True + ): declaration = tool._get_declaration() assert isinstance(declaration, FunctionDeclaration) @@ -178,7 +180,7 @@ def test_get_declaration_with_output_schema_and_json_schema_for_func_decl_enable assert declaration.response_json_schema == output_schema def test_get_declaration_with_empty_output_schema_and_json_schema_for_func_decl_enabled( - self, monkeypatch + self, ): """Test function declaration with an empty output schema and json schema for func decl enabled.""" tool = MCPTool( @@ -186,8 +188,9 @@ def test_get_declaration_with_empty_output_schema_and_json_schema_for_func_decl_ mcp_session_manager=self.mock_session_manager, ) - with monkeypatch.context() as m: - m.setenv("ADK_ENABLE_JSON_SCHEMA_FOR_FUNC_DECL", "true") + with temporary_feature_override( + FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True + ): declaration = tool._get_declaration() assert declaration.response is None diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index a9723b4347..48a7a995bb 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -21,6 +21,8 @@ from google.adk.agents.run_config import RunConfig from google.adk.agents.sequential_agent import SequentialAgent from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.features import FeatureName +from google.adk.features._feature_registry import temporary_feature_override from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse @@ -33,6 +35,7 @@ from google.genai import types from google.genai.types import Part from pydantic import BaseModel +import pytest from pytest import mark from .. import testing_utils @@ -702,3 +705,198 @@ class CustomInput(BaseModel): # The description should come from the agent, not the Pydantic model assert declaration.description == agent_description + + +@pytest.fixture +def enable_json_schema_feature(): + """Fixture to enable JSON_SCHEMA_FOR_FUNC_DECL feature for a test.""" + with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True): + yield + + +def test_agent_tool_no_schema_with_json_schema_feature( + enable_json_schema_feature, +): + """Test AgentTool without input_schema uses parameters_json_schema when feature enabled.""" + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + } + + +@mark.parametrize( + 'env_variables', + [ + 'VERTEX', # Test VERTEX_AI variant + ], + indirect=True, +) +def test_agent_tool_response_json_schema_no_output_schema_vertex_ai( + env_variables, + enable_json_schema_feature, +): + """Test AgentTool with no output schema uses response_json_schema for VERTEX_AI when feature enabled.""" + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + 'response_json_schema': {'type': 'string'}, + } + + +@mark.parametrize( + 'env_variables', + [ + 'VERTEX', # Test VERTEX_AI variant + ], + indirect=True, +) +def test_agent_tool_response_json_schema_with_output_schema_vertex_ai( + env_variables, + enable_json_schema_feature, +): + """Test AgentTool with output schema uses response_json_schema for VERTEX_AI when feature enabled.""" + + class CustomOutput(BaseModel): + custom_output: str + + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + output_schema=CustomOutput, + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + 'response_json_schema': {'type': 'object'}, + } + + +@mark.parametrize( + 'env_variables', + [ + 'GOOGLE_AI', # Test GEMINI_API variant + ], + indirect=True, +) +def test_agent_tool_no_response_json_schema_gemini_api( + env_variables, + enable_json_schema_feature, +): + """Test AgentTool with GEMINI_API variant has no response_json_schema when feature enabled.""" + + class CustomOutput(BaseModel): + custom_output: str + + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + output_schema=CustomOutput, + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + # GEMINI_API should not have response_json_schema + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'type': 'object', + 'properties': { + 'request': {'type': 'string'}, + }, + 'required': ['request'], + }, + } + + +@mark.parametrize( + 'env_variables', + [ + 'VERTEX', # Test VERTEX_AI variant + ], + indirect=True, +) +def test_agent_tool_with_input_schema_uses_json_schema_feature( + env_variables, + enable_json_schema_feature, +): + """Test AgentTool with input_schema uses parameters_json_schema when feature enabled.""" + + class CustomInput(BaseModel): + custom_input: str + + class CustomOutput(BaseModel): + custom_output: str + + tool_agent = Agent( + name='tool_agent', + description='A tool agent for testing.', + model=testing_utils.MockModel.create(responses=['test response']), + input_schema=CustomInput, + output_schema=CustomOutput, + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + # When input_schema is provided, build_function_declaration uses Pydantic's + # model_json_schema() which includes additional fields like 'title' + assert declaration.model_dump(exclude_none=True) == { + 'name': 'tool_agent', + 'description': 'A tool agent for testing.', + 'parameters_json_schema': { + 'properties': { + 'custom_input': {'title': 'Custom Input', 'type': 'string'}, + }, + 'required': ['custom_input'], + 'title': 'CustomInput', + 'type': 'object', + }, + 'response_json_schema': {'type': 'object'}, + } diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index dd85b20c86..3797c4ed56 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -13,9 +13,9 @@ # limitations under the License. from enum import Enum -from unittest import mock from google.adk.features import FeatureName +from google.adk.features._feature_registry import temporary_feature_override from google.adk.tools import _automatic_function_calling_util from google.adk.tools.tool_context import ToolContext from google.adk.utils.variant_utils import GoogleLLMVariant @@ -435,11 +435,8 @@ class TestJsonSchemaFeatureFlagEnabled: @pytest.fixture(autouse=True) def enable_feature_flag(self): """Enable the JSON_SCHEMA_FOR_FUNC_DECL feature flag for all tests.""" - with mock.patch.object( - _automatic_function_calling_util, - 'is_feature_enabled', - autospec=True, - side_effect=lambda f: f == FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, + with temporary_feature_override( + FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True ): yield