diff --git a/contributing/samples/data_agent/README.md b/contributing/samples/data_agent/README.md
new file mode 100644
index 0000000000..33405c9116
--- /dev/null
+++ b/contributing/samples/data_agent/README.md
@@ -0,0 +1,61 @@
+# Data Agent Sample
+
+This sample agent demonstrates ADK's first-party tools for interacting with
+Data Agents powered by the
+[Conversational Analytics API](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/overview).
+These tools are distributed via the `google.adk.tools.data_agent` module and
+let you list, inspect, and chat with Data Agents using natural language.
+
+The tools support stateful conversations: you can ask follow-up questions in
+the same session, and the agent will maintain context.
+
+## Prerequisites
+
+1. An active Google Cloud project with the BigQuery and Gemini APIs enabled.
+2. Google Cloud authentication configured for Application Default Credentials:
+   ```bash
+   gcloud auth application-default login
+   ```
+3. At least one Data Agent created. You can create Data Agents via the
+   [Conversational Analytics API](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/overview),
+   its
+   [Python SDK](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/build-agent-sdk),
+   or, for BigQuery data,
+   [BigQuery Studio](https://docs.cloud.google.com/bigquery/docs/create-data-agents#create_a_data_agent).
+   These agents are created and configured in the Google Cloud console and
+   point to your BigQuery tables or other data sources.
+4. Follow the official
+   [Setup and prerequisites](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/overview#setup)
+   guide to enable the API and configure IAM permissions and authentication for
+   your data sources.
+
+## Tools Used
+
+* `list_accessible_data_agents`: Lists the Data Agents you have permission to
+  access in the configured GCP project.
+* `get_data_agent_info`: Retrieves details about a specific Data Agent given
+  its full resource name.
+* `ask_data_agent`: Chats with a specific Data Agent using natural language.
+  This tool maintains conversation state: if you ask multiple questions to the
+  same agent in one session, it will reuse the same conversation, allowing for
+  follow-ups. If you switch agents, a new conversation will be started for the
+  new agent.
+
+## How to Run
+
+1. Navigate to the root of the ADK repository.
+2. Run the agent using the ADK CLI:
+   ```bash
+   adk run contributing/samples/data_agent
+   ```
+3. The CLI will prompt you for input. You can ask questions like the examples
+   below.
+
+## Sample Prompts
+
+* "List accessible data agents."
+* "Using agent
+  `projects/my-project/locations/global/dataAgents/sales-agent-123`, who were
+  my top 3 customers last quarter?"
+* "How does that compare to the quarter before?"
diff --git a/contributing/samples/data_agent/__init__.py b/contributing/samples/data_agent/__init__.py
new file mode 100644
index 0000000000..4015e47d6e
--- /dev/null
+++ b/contributing/samples/data_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/contributing/samples/data_agent/agent.py b/contributing/samples/data_agent/agent.py
new file mode 100644
index 0000000000..634a476286
--- /dev/null
+++ b/contributing/samples/data_agent/agent.py
@@ -0,0 +1,84 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from google.adk.agents import Agent
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.tools.data_agent.config import DataAgentToolConfig
+from google.adk.tools.data_agent.credentials import DataAgentCredentialsConfig
+from google.adk.tools.data_agent.data_agent_toolset import DataAgentToolset
+import google.auth
+import google.auth.transport.requests
+
+# Define the desired credential type.
+# By default, use Application Default Credentials (ADC) from the local
+# environment, which can be set up by following
+# https://cloud.google.com/docs/authentication/provide-credentials-adc.
+CREDENTIALS_TYPE = None
+
+if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
+  # Initialize the tools to do interactive OAuth.
+  # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET
+  # must be set.
+  credentials_config = DataAgentCredentialsConfig(
+      client_id=os.getenv("OAUTH_CLIENT_ID"),
+      client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
+  )
+elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT:
+  # Initialize the tools to use the credentials from a service account key.
+  # If this flow is enabled, make sure to replace the file path with your own
+  # service account key file:
+  # https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys
+  creds, _ = google.auth.load_credentials_from_file(
+      "service_account_key.json",
+      scopes=["https://www.googleapis.com/auth/cloud-platform"],
+  )
+  creds.refresh(google.auth.transport.requests.Request())
+  credentials_config = DataAgentCredentialsConfig(credentials=creds)
+else:
+  # Initialize the tools to use the application default credentials.
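+  # (This is the default in this sample, since CREDENTIALS_TYPE is None.)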
+ # https://cloud.google.com/docs/authentication/provide-credentials-adc + application_default_credentials, _ = google.auth.default() + credentials_config = DataAgentCredentialsConfig( + credentials=application_default_credentials + ) + +tool_config = DataAgentToolConfig( + max_query_result_rows=100, +) +da_toolset = DataAgentToolset( + credentials_config=credentials_config, + data_agent_tool_config=tool_config, + tool_filter=[ + "list_accessible_data_agents", + "get_data_agent_info", + "ask_data_agent", + ], +) + +root_agent = Agent( + name="data_agent", + model="gemini-2.0-flash", + description="Agent to answer user questions using Data Agents.", + instruction=( + "## Persona\nYou are a helpful assistant that uses Data Agents" + " to answer user questions about their data.\n\n## Tools\n- You can" + " list available data agents using `list_accessible_data_agents`.\n-" + " You can get information about a specific data agent using" + " `get_data_agent_info`.\n- You can chat with a specific data" + " agent using `ask_data_agent`.\n" + ), + tools=[da_toolset], +) diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index ee10108be8..154c77caf6 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -33,6 +33,8 @@ class FeatureName(str, Enum): BIGTABLE_TOOL_SETTINGS = "BIGTABLE_TOOL_SETTINGS" BIGTABLE_TOOLSET = "BIGTABLE_TOOLSET" COMPUTER_USE = "COMPUTER_USE" + DATA_AGENT_TOOL_CONFIG = "DATA_AGENT_TOOL_CONFIG" + DATA_AGENT_TOOLSET = "DATA_AGENT_TOOLSET" GOOGLE_CREDENTIALS_CONFIG = "GOOGLE_CREDENTIALS_CONFIG" GOOGLE_TOOL = "GOOGLE_TOOL" JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL" @@ -97,6 +99,12 @@ class FeatureConfig: FeatureName.COMPUTER_USE: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), + FeatureName.DATA_AGENT_TOOL_CONFIG: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), + FeatureName.DATA_AGENT_TOOLSET: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.GOOGLE_CREDENTIALS_CONFIG: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 5c742a9367..91135dce5f 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -15,9 +15,11 @@ from __future__ import annotations from typing import Any +from typing import Optional from typing import TYPE_CHECKING from google.genai import types +from pydantic import BaseModel from pydantic import model_validator from typing_extensions import override @@ -37,6 +39,56 @@ from ..agents.base_agent import BaseAgent +def _get_input_schema(agent: BaseAgent) -> Optional[type[BaseModel]]: + """Extracts the input_schema from an agent. + + For LlmAgent, returns its input_schema directly. + For agents with sub_agents, recursively searches the first sub-agent for an + input_schema. + + Args: + agent: The agent to extract input_schema from. + + Returns: + The input_schema if found, None otherwise. + """ + from ..agents.llm_agent import LlmAgent + + if isinstance(agent, LlmAgent): + return agent.input_schema + + # For composite agents, check the first sub-agent + if agent.sub_agents: + return _get_input_schema(agent.sub_agents[0]) + + return None + + +def _get_output_schema(agent: BaseAgent) -> Optional[type[BaseModel]]: + """Extracts the output_schema from an agent. + + For LlmAgent, returns its output_schema directly. 
+ For agents with sub_agents, recursively searches the last sub-agent for an + output_schema. + + Args: + agent: The agent to extract output_schema from. + + Returns: + The output_schema if found, None otherwise. + """ + from ..agents.llm_agent import LlmAgent + + if isinstance(agent, LlmAgent): + return agent.output_schema + + # For composite agents, check the last sub-agent + if agent.sub_agents: + return _get_output_schema(agent.sub_agents[-1]) + + return None + + class AgentTool(BaseTool): """A tool that wraps an agent. @@ -74,12 +126,14 @@ def populate_name(cls, data: Any) -> Any: @override def _get_declaration(self) -> types.FunctionDeclaration: - from ..agents.llm_agent import LlmAgent from ..utils.variant_utils import GoogleLLMVariant - if isinstance(self.agent, LlmAgent) and self.agent.input_schema: + input_schema = _get_input_schema(self.agent) + output_schema = _get_output_schema(self.agent) + + if input_schema: result = _automatic_function_calling_util.build_function_declaration( - func=self.agent.input_schema, variant=self._api_variant + func=input_schema, variant=self._api_variant ) # Override the description with the agent's description result.description = self.agent.description @@ -114,7 +168,7 @@ def _get_declaration(self) -> types.FunctionDeclaration: # Set response schema for non-GEMINI_API variants if self._api_variant != GoogleLLMVariant.GEMINI_API: # Determine response type based on agent's output schema - if isinstance(self.agent, LlmAgent) and self.agent.output_schema: + if output_schema: # Agent has structured output schema - response is an object if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL): result.response_json_schema = {'type': 'object'} @@ -137,15 +191,15 @@ async def run_async( args: dict[str, Any], tool_context: ToolContext, ) -> Any: - from ..agents.llm_agent import LlmAgent from ..runners import Runner from ..sessions.in_memory_session_service import InMemorySessionService if self.skip_summarization: tool_context.actions.skip_summarization = True - if isinstance(self.agent, LlmAgent) and self.agent.input_schema: - input_value = self.agent.input_schema.model_validate(args) + input_schema = _get_input_schema(self.agent) + if input_schema: + input_value = input_schema.model_validate(args) content = types.Content( role='user', parts=[ @@ -212,10 +266,11 @@ async def run_async( merged_text = '\n'.join( p.text for p in last_content.parts if p.text and not p.thought ) - if isinstance(self.agent, LlmAgent) and self.agent.output_schema: - tool_result = self.agent.output_schema.model_validate_json( - merged_text - ).model_dump(exclude_none=True) + output_schema = _get_output_schema(self.agent) + if output_schema: + tool_result = output_schema.model_validate_json(merged_text).model_dump( + exclude_none=True + ) else: tool_result = merged_text return tool_result diff --git a/src/google/adk/tools/data_agent/__init__.py b/src/google/adk/tools/data_agent/__init__.py new file mode 100644 index 0000000000..e203faa07d --- /dev/null +++ b/src/google/adk/tools/data_agent/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data Agent Tools.""" + +from __future__ import annotations + +from .credentials import DataAgentCredentialsConfig +from .data_agent_toolset import DataAgentToolset + +__all__ = [ + "DataAgentCredentialsConfig", + "DataAgentToolset", +] diff --git a/src/google/adk/tools/data_agent/config.py b/src/google/adk/tools/data_agent/config.py new file mode 100644 index 0000000000..3b86047764 --- /dev/null +++ b/src/google/adk/tools/data_agent/config.py @@ -0,0 +1,35 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pydantic import BaseModel +from pydantic import ConfigDict + +from ...features import experimental +from ...features import FeatureName + + +@experimental(FeatureName.DATA_AGENT_TOOL_CONFIG) +class DataAgentToolConfig(BaseModel): + """Configuration for Data Agent tools.""" + + # Forbid any fields not defined in the model + model_config = ConfigDict(extra='forbid') + + max_query_result_rows: int = 50 + """Maximum number of rows to return from a query. + + By default, the query result will be limited to 50 rows. + """ diff --git a/src/google/adk/tools/data_agent/credentials.py b/src/google/adk/tools/data_agent/credentials.py new file mode 100644 index 0000000000..3503cfa050 --- /dev/null +++ b/src/google/adk/tools/data_agent/credentials.py @@ -0,0 +1,36 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
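+
+"""Credentials configuration for the Data Agent tools."""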
+ +from __future__ import annotations + +from .._google_credentials import BaseGoogleCredentialsConfig + +DATA_AGENT_TOKEN_CACHE_KEY = "data_agent_token_cache" +DATA_AGENT_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] + + +class DataAgentCredentialsConfig(BaseGoogleCredentialsConfig): + """Data Agent Credentials Configuration for Google API tools.""" + + def __post_init__(self) -> DataAgentCredentialsConfig: + """Populate default scope if scopes is None.""" + super().__post_init__() + + if not self.scopes: + self.scopes = DATA_AGENT_DEFAULT_SCOPE + + # Set the token cache key + self._token_cache_key = DATA_AGENT_TOKEN_CACHE_KEY + + return self diff --git a/src/google/adk/tools/data_agent/data_agent_tool.py b/src/google/adk/tools/data_agent/data_agent_tool.py new file mode 100644 index 0000000000..8b5a88822d --- /dev/null +++ b/src/google/adk/tools/data_agent/data_agent_tool.py @@ -0,0 +1,491 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +from typing import Any + +from google.auth.credentials import Credentials +import requests + +from ..tool_context import ToolContext +from .config import DataAgentToolConfig + +BASE_URL = "https://geminidataanalytics.googleapis.com/v1beta" + + +def _get_http_headers( + credentials: Credentials, +) -> dict[str, str]: + """Prepares headers for HTTP requests.""" + if not credentials.token: + error_details = ( + "The provided credentials object does not have a valid access" + " token.\n\nThis is often because the credentials need to be" + " refreshed or require specific API scopes. Please ensure the" + " credentials are prepared correctly before calling this" + " function.\n\nThere may be other underlying causes as well." + ) + raise ValueError(error_details) + return { + "Authorization": f"Bearer {credentials.token}", + "Content-Type": "application/json", + } + + +def list_accessible_data_agents( + project_id: str, + credentials: Credentials, +) -> dict[str, Any]: + """Lists accessible data agents in a project. + + Args: + project_id: The project to list agents in. + credentials: The credentials to use for the request. + + Returns: + A dictionary containing the status and a list of data agents with their + detailed information, including name, display_name, description (if + available), create_time, update_time, and data_analytics_agent context, + or error details if the request fails. + + Examples: + >>> list_accessible_data_agents( + ... project_id="my-gcp-project", + ... credentials=credentials, + ... 
) + { + "status": "SUCCESS", + "response": [ + { + "name": "projects/my-project/locations/global/dataAgents/agent1", + "displayName": "My Test Agent", + "createTime": "2025-10-01T22:44:22.473927629Z", + "updateTime": "2025-10-01T22:44:23.094541325Z", + "dataAnalyticsAgent": { + "publishedContext": { + "datasourceReferences": [{ + "bq": { + "tableReferences": [{ + "projectId": "my-project", + "datasetId": "dataset1", + "tableId": "table1" + }] + } + }] + } + } + }, + { + "name": "projects/my-project/locations/global/dataAgents/agent2", + "displayName": "", + "description": "Description for Agent 2.", + "createTime": "2025-06-23T20:23:48.650597312Z", + "updateTime": "2025-06-23T20:23:49.437095391Z", + "dataAnalyticsAgent": { + "publishedContext": { + "datasourceReferences": [{ + "bq": { + "tableReferences": [{ + "projectId": "another-project", + "datasetId": "dataset2", + "tableId": "table2" + }] + } + }], + "systemInstruction": "You are a helpful assistant.", + "options": {"analysis": {"python": {"enabled": True}}} + } + } + } + ] + } + """ + try: + headers = _get_http_headers(credentials) + list_url = f"{BASE_URL}/projects/{project_id}/locations/global/dataAgents:listAccessible" + resp = requests.get( + list_url, + headers=headers, + ) + resp.raise_for_status() + return { + "status": "SUCCESS", + "response": resp.json().get("dataAgents", []), + } + except Exception as ex: # pylint: disable=broad-except + return { + "status": "ERROR", + "error_details": repr(ex), + } + + +def get_data_agent_info( + data_agent_name: str, + credentials: Credentials, +) -> dict[str, Any]: + """Gets a data agent by name. + + Args: + data_agent_name: The name of the agent to get, in format + projects/{project}/locations/{location}/dataAgents/{agent}. + credentials: The credentials to use for the request. + + Returns: + A dictionary containing the status and details of a data agent, + including name, display_name, description (if available), + create_time, update_time, and data_analytics_agent context, + or error details if the request fails. + + Examples: + >>> get_data_agent_info( + ... + data_agent_name="projects/my-project/locations/global/dataAgents/agent-1", + ... credentials=credentials, + ... ) + { + "status": "SUCCESS", + "response": { + "name": "projects/my-project/locations/global/dataAgents/agent-1", + "description": "Description for Agent 1.", + "createTime": "2025-06-23T20:23:48.650597312Z", + "updateTime": "2025-06-23T20:23:49.437095391Z", + "dataAnalyticsAgent": { + "publishedContext": { + "systemInstruction": "You are a helpful assistant.", + "options": {"analysis": {"python": {"enabled": True}}}, + "datasourceReferences": { + "bq": { + "tableReferences": [{ + "projectId": "my-gcp-project", + "datasetId": "dataset1", + "tableId": "table1" + }] + } + }, + } + } + } + } + """ + try: + headers = _get_http_headers(credentials) + get_url = f"{BASE_URL}/{data_agent_name}" + resp = requests.get( + get_url, + headers=headers, + ) + resp.raise_for_status() + return { + "status": "SUCCESS", + "response": resp.json(), + } + except Exception as ex: # pylint: disable=broad-except + return { + "status": "ERROR", + "error_details": repr(ex), + } + + +def ask_data_agent( + data_agent_name: str, + query: str, + *, + credentials: Credentials, + settings: DataAgentToolConfig, + tool_context: ToolContext, +) -> dict[str, Any]: + """Asks a question to a data agent. + + Args: + data_agent_name: The resource name of an existing data agent to ask, in + format projects/{project}/locations/{location}/dataAgents/{agent}. 
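+      For example: projects/my-project/locations/global/dataAgents/sales-agent.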
+    query: The question to ask the agent.
+    credentials: The credentials to use for the request.
+    settings: The tool configuration, including the maximum number of query
+      result rows to return.
+    tool_context: The context for the tool.
+
+  Returns:
+    A dictionary with two keys:
+    - 'status': A string indicating the final status (e.g., "SUCCESS").
+    - 'response': A list of dictionaries, where each dictionary
+      represents a step in the agent's execution process (e.g., SQL
+      generation, data retrieval, final answer). Note that the 'Answer'
+      step contains a text response which may summarize findings or refer
+      to previous steps of agent execution, such as 'Data Retrieved'; in
+      that case, the 'Answer' step does not include the result data.
+
+  Examples:
+    A query to a data agent, showing the full return structure.
+    The original question: "Which customer from New York spent the most last
+    month?"
+
+    >>> ask_data_agent(
+    ...     data_agent_name="projects/my-project/locations/global/dataAgents/sales-agent",
+    ...     query="Which customer from New York spent the most last month?",
+    ...     credentials=credentials,
+    ...     settings=settings,
+    ...     tool_context=tool_context,
+    ... )
+    {
+      "status": "SUCCESS",
+      "response": [
+        {
+          "Question": "Which customer from New York spent the most last
+          month?"
+        },
+        {
+          "Schema Resolved": [
+            {
+              "source_name": "my-gcp-project.sales_data.customers",
+              "schema": {
+                "headers": ["Column", "Type", "Description", "Mode"],
+                "rows": [
+                  ["customer_id", "INT64", "Customer ID", "REQUIRED"],
+                  ["customer_name", "STRING", "Customer Name", "NULLABLE"],
+                ]
+              }
+            }
+          ]
+        },
+        {
+          "Retrieval Query": {
+            "Query Name": "top_spender",
+            "Question": "Find top spending customer from New York in the last
+            month."
+          }
+        },
+        {
+          "SQL Generated": "SELECT t1.customer_name, SUM(t2.order_total) ... "
+        },
+        {
+          "Data Retrieved": {
+            "headers": ["customer_name", "total_spent"],
+            "rows": [["Jane Doe", 1234.56]],
+            "summary": "Showing all 1 rows."
+          }
+        },
+        {
+          "Answer": "The customer who spent the most last month was Jane Doe."
+        }
+      ]
+    }
+  """
+  try:
+    headers = _get_http_headers(credentials)
+
+    # Validate that the agent exists and is accessible before chatting.
+    agent_info = get_data_agent_info(data_agent_name, credentials)
+    if agent_info.get("status") == "ERROR":
+      return agent_info
+    parent = data_agent_name.rsplit("/", 2)[0]
+    chat_url = f"{BASE_URL}/{parent}:chat"
+    chat_payload = {
+        "messages": [{"userMessage": {"text": query}}],
+        "dataAgentContext": {
+            "dataAgent": data_agent_name,
+        },
+        "clientIdEnum": "GOOGLE_ADK",
+    }
+    resp = _get_stream(
+        chat_url,
+        chat_payload,
+        headers=headers,
+        max_query_result_rows=settings.max_query_result_rows,
+    )
+    return {"status": "SUCCESS", "response": resp}
+  except Exception as ex:  # pylint: disable=broad-except
+    return {
+        "status": "ERROR",
+        "error_details": repr(ex),
+    }
+
+
+def _get_stream(
+    url: str,
+    ca_payload: dict[str, Any],
+    *,
+    headers: dict[str, str],
+    max_query_result_rows: int,
+) -> list[dict[str, Any]]:
+  """Sends a JSON request to a streaming API and returns a list of messages."""
+  s = requests.Session()
+
+  accumulator = ""
+  messages = []
+
+  with s.post(url, json=ca_payload, headers=headers, stream=True) as resp:
+    # Surface HTTP-level errors; in-stream errors are handled below.
+    resp.raise_for_status()
+    for line in resp.iter_lines():
+      if not line:
+        continue
+
+      decoded_line = str(line, encoding="utf-8")
+
+      # The response is a JSON array streamed line by line; strip the array
+      # framing so each accumulated object can be parsed on its own.
+      if decoded_line == "[{":
+        accumulator = "{"
+      elif decoded_line == "}]":
+        accumulator += "}"
+      elif decoded_line == ",":
+        continue
+      else:
+        accumulator += decoded_line
+
+      try:
+        data_json = json.loads(accumulator)
+      except ValueError:
+        # The accumulated text is not a complete JSON object yet.
+        continue
+      if "systemMessage" not in data_json:
+        if "error" in data_json:
+          _append_message(
+              messages,
+              _handle_error(data_json["error"]),
+          )
+        continue
+
+      system_message = data_json["systemMessage"]
+      if "text" in system_message:
+        _append_message(
+            messages,
+            _handle_text_response(system_message["text"]),
+        )
+      elif "schema" in system_message:
+        _append_message(
+            messages,
+            _handle_schema_response(system_message["schema"]),
+        )
+      elif "data" in system_message:
+        _append_message(
+            messages,
+            _handle_data_response(
+                system_message["data"], max_query_result_rows
+            ),
+        )
+      accumulator = ""
+  return messages
+
+
+def _format_bq_table_ref(table_ref: dict[str, str]) -> str:
+  """Formats a BigQuery table reference dictionary into a string."""
+  return f"{table_ref.get('projectId')}.{table_ref.get('datasetId')}.{table_ref.get('tableId')}"
+
+
+def _format_schema_as_dict(
+    data: dict[str, Any],
+) -> dict[str, list[Any]]:
+  """Extracts schema fields into a dictionary."""
+  fields = data.get("fields", [])
+  if not fields:
+    return {"columns": []}
+
+  headers = ["Column", "Type", "Description", "Mode"]
+  rows: list[list[str]] = []
+  for field in fields:
+    row_list = [
+        field.get("name", ""),
+        field.get("type", ""),
+        field.get("description", ""),
+        field.get("mode", ""),
+    ]
+    rows.append(row_list)
+
+  return {"headers": headers, "rows": rows}
+
+
+def _format_datasource_as_dict(datasource: dict[str, Any]) -> dict[str, Any]:
+  """Formats a full datasource object into a dictionary with its name and schema."""
+  source_name = _format_bq_table_ref(datasource["bigqueryTableReference"])
+
+  schema = _format_schema_as_dict(datasource["schema"])
+  return {"source_name": source_name, "schema": schema}
+
+
+def _handle_text_response(resp: dict[str, Any]) -> dict[str, str]:
+  """Formats a text response into a dictionary."""
+  parts = resp.get("parts", [])
+  return {"Answer": "".join(parts)}
+
+
+def _handle_schema_response(resp: dict[str, Any]) -> dict[str, Any]:
+  """Formats a schema response into a dictionary."""
+  if "query" in resp:
+    return {"Question": resp["query"].get("question", "")}
+  elif "result" in resp:
+    datasources = resp["result"].get("datasources", [])
+    # Format each datasource into a {source_name, schema} dictionary
+    formatted_sources = [_format_datasource_as_dict(ds) for ds in datasources]
+    return {"Schema Resolved": formatted_sources}
+  return {}
+
+
+def _handle_data_response(
+    resp: dict[str, Any], max_query_result_rows: int
+) -> dict[str, Any]:
+  """Formats a data response into a dictionary."""
+  if "query" in resp:
+    query = resp["query"]
+    return {
+        "Retrieval Query": {
+            "Query Name": query.get("name", "N/A"),
+            "Question": query.get("question", "N/A"),
+        }
+    }
+  elif "generatedSql" in resp:
+    return {"SQL Generated": resp["generatedSql"]}
+  elif "result" in resp:
+    schema = resp["result"]["schema"]
+    headers = [field.get("name") for field in schema.get("fields", [])]
+
+    all_rows = resp["result"].get("data", [])
+    total_rows = len(all_rows)
+
+    # Truncate the result to the configured maximum number of rows.
+    compact_rows = []
+    for row_dict in all_rows[:max_query_result_rows]:
+      row_values = [row_dict.get(header) for header in headers]
+      compact_rows.append(row_values)
+
+    summary_string = f"Showing all {total_rows} rows."
+    if total_rows > max_query_result_rows:
+      summary_string = (
+          f"Showing the first {len(compact_rows)} of {total_rows} total rows."
+      )
+
+    return {
+        "Data Retrieved": {
+            "headers": headers,
+            "rows": compact_rows,
+            "summary": summary_string,
+        }
+    }
+
+  return {}
+
+
+def _handle_error(resp: dict[str, Any]) -> dict[str, dict[str, Any]]:
+  """Formats an error response into a dictionary."""
+  return {
+      "Error": {
+          "Code": resp.get("code", "N/A"),
+          "Message": resp.get("message", "No message provided."),
+      }
+  }
+
+
+def _append_message(
+    messages: list[dict[str, Any]],
+    new_message: dict[str, Any],
+):
+  """Appends a message to the list, skipping empty messages."""
+  if not new_message:
+    return
+
+  messages.append(new_message)
diff --git a/src/google/adk/tools/data_agent/data_agent_toolset.py b/src/google/adk/tools/data_agent/data_agent_toolset.py
new file mode 100644
index 0000000000..3579770fb5
--- /dev/null
+++ b/src/google/adk/tools/data_agent/data_agent_toolset.py
@@ -0,0 +1,93 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import List
+from typing import Optional
+from typing import Union
+
+from google.adk.agents.readonly_context import ReadonlyContext
+from typing_extensions import override
+
+from .
import data_agent_tool +from ...features import experimental +from ...features import FeatureName +from ...tools.base_tool import BaseTool +from ...tools.base_toolset import BaseToolset +from ...tools.base_toolset import ToolPredicate +from ...tools.google_tool import GoogleTool +from .config import DataAgentToolConfig +from .credentials import DataAgentCredentialsConfig + + +@experimental(FeatureName.DATA_AGENT_TOOLSET) +class DataAgentToolset(BaseToolset): + """Data Agent Toolset contains tools for interacting with data agents.""" + + def __init__( + self, + *, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + credentials_config: Optional[DataAgentCredentialsConfig] = None, + data_agent_tool_config: Optional[DataAgentToolConfig] = None, + ): + super().__init__(tool_filter=tool_filter) + self._credentials_config = credentials_config + self._tool_settings = ( + data_agent_tool_config + if data_agent_tool_config + else DataAgentToolConfig() + ) + + def _is_tool_selected( + self, tool: BaseTool, readonly_context: ReadonlyContext + ) -> bool: + if self.tool_filter is None: + return True + + if isinstance(self.tool_filter, ToolPredicate): + return self.tool_filter(tool, readonly_context) + + if isinstance(self.tool_filter, list): + return tool.name in self.tool_filter + + return False + + @override + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> List[BaseTool]: + all_tools = [ + GoogleTool( + func=func, + credentials_config=self._credentials_config, + tool_settings=self._tool_settings, + ) + for func in [ + data_agent_tool.list_accessible_data_agents, + data_agent_tool.get_data_agent_info, + data_agent_tool.ask_data_agent, + ] + ] + + return [ + tool + for tool in all_tools + if self._is_tool_selected(tool, readonly_context) + ] + + @override + async def close(self): + pass diff --git a/tests/unittests/tools/data_agent/test_data_agent_tool.py b/tests/unittests/tools/data_agent/test_data_agent_tool.py new file mode 100644 index 0000000000..6aa57e650f --- /dev/null +++ b/tests/unittests/tools/data_agent/test_data_agent_tool.py @@ -0,0 +1,198 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
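+
+"""Unit tests for the data_agent tool functions."""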
+ +import pathlib +from unittest import mock + +from google.adk.tools.data_agent import data_agent_tool +from google.adk.tools.tool_context import ToolContext +import pytest +import requests +import yaml + + +@mock.patch.object(data_agent_tool, "requests", autospec=True) +def test_list_accessible_data_agents_success(mock_requests): + """Tests list_accessible_data_agents success path.""" + mock_creds = mock.Mock() + mock_creds.token = "fake-token" + mock_response = mock.Mock() + mock_response.json.return_value = {"dataAgents": ["agent1", "agent2"]} + mock_response.raise_for_status.return_value = None + mock_requests.get.return_value = mock_response + result = data_agent_tool.list_accessible_data_agents( + "test-project", mock_creds + ) + assert result["status"] == "SUCCESS" + assert result["response"] == ["agent1", "agent2"] + mock_requests.get.assert_called_once() + + +@mock.patch.object(data_agent_tool, "requests", autospec=True) +def test_list_accessible_data_agents_exception(mock_requests): + """Tests list_accessible_data_agents exception path.""" + mock_creds = mock.Mock() + mock_creds.token = "fake-token" + mock_requests.get.side_effect = Exception("List failed!") + result = data_agent_tool.list_accessible_data_agents( + "test-project", mock_creds + ) + assert result["status"] == "ERROR" + assert "List failed!" in result["error_details"] + mock_requests.get.assert_called_once() + + +@mock.patch.object(data_agent_tool, "requests", autospec=True) +def test_get_data_agent_info_success(mock_requests): + """Tests get_data_agent_info success path.""" + mock_creds = mock.Mock() + mock_creds.token = "fake-token" + mock_response = mock.Mock() + mock_response.json.return_value = "agent_info" + mock_response.raise_for_status.return_value = None + mock_requests.get.return_value = mock_response + result = data_agent_tool.get_data_agent_info("agent_name", mock_creds) + assert result["status"] == "SUCCESS" + assert result["response"] == "agent_info" + mock_requests.get.assert_called_once() + + +@mock.patch.object(data_agent_tool, "requests", autospec=True) +def test_get_data_agent_info_exception(mock_requests): + """Tests get_data_agent_info exception path.""" + mock_creds = mock.Mock() + mock_creds.token = "fake-token" + mock_requests.get.side_effect = Exception("Get failed!") + result = data_agent_tool.get_data_agent_info("agent_name", mock_creds) + assert result["status"] == "ERROR" + assert "Get failed!" 
in result["error_details"] + mock_requests.get.assert_called_once() + + +@mock.patch.object(data_agent_tool, "_get_stream", autospec=True) +@mock.patch.object(data_agent_tool, "requests", autospec=True) +@mock.patch.object(data_agent_tool, "get_data_agent_info", autospec=True) +def test_ask_data_agent_success( + mock_get_agent_info, mock_requests, mock_get_stream +): + """Tests ask_data_agent success path.""" + mock_creds = mock.Mock() + mock_creds.token = "fake-token" + mock_get_agent_info.return_value = {"status": "SUCCESS", "response": {}} + mock_get_stream.return_value = [ + {"Answer": "response1"}, + {"Answer": "response2"}, + ] + mock_invocation_context = mock.Mock() + mock_invocation_context.session.state = {} + mock_context = ToolContext(mock_invocation_context) + mock_settings = mock.Mock() + + result = data_agent_tool.ask_data_agent( + "projects/p/locations/l/dataAgents/a", + "query", + credentials=mock_creds, + tool_context=mock_context, + settings=mock_settings, + ) + assert result["status"] == "SUCCESS" + assert result["response"] == [ + {"Answer": "response1"}, + {"Answer": "response2"}, + ] + mock_get_agent_info.assert_called_once() + mock_get_stream.assert_called_once() + + +@mock.patch.object(data_agent_tool, "_get_stream", autospec=True) +@mock.patch.object(data_agent_tool, "requests", autospec=True) +@mock.patch.object(data_agent_tool, "get_data_agent_info", autospec=True) +def test_ask_data_agent_exception( + mock_get_agent_info, mock_requests, mock_get_stream +): + """Tests ask_data_agent exception path.""" + mock_creds = mock.Mock() + mock_creds.token = "fake-token" + mock_get_agent_info.return_value = {"status": "SUCCESS", "response": {}} + mock_get_stream.side_effect = Exception("Chat failed!") + mock_invocation_context = mock.Mock() + mock_invocation_context.session.state = {} + mock_context = ToolContext(mock_invocation_context) + mock_settings = mock.Mock() + + result = data_agent_tool.ask_data_agent( + "projects/p/locations/l/dataAgents/a", + "query", + credentials=mock_creds, + tool_context=mock_context, + settings=mock_settings, + ) + assert result["status"] == "ERROR" + assert "Chat failed!" in result["error_details"] + mock_get_stream.assert_called_once() + + +@pytest.mark.parametrize( + "case_file_path", + [ + pytest.param("test_data/ask_data_insights_penguins_highest_mass.yaml"), + ], +) +@mock.patch.object(requests.Session, "post") +def test_get_stream_from_file(mock_post, case_file_path): + """Runs a full integration test for the _get_stream function using data from a specific file.""" + # 1. Construct the full, absolute path to the data file + full_path = pathlib.Path(__file__).parent.parent / "bigquery" / case_file_path + + # 2. Load the test case data from the specified YAML file + with open(full_path, "r", encoding="utf-8") as f: + case_data = yaml.safe_load(f) + + # 3. Prepare the mock stream and expected output from the loaded data + mock_stream_str = case_data["mock_api_stream"] + fake_stream_lines = [ + line.encode("utf-8") for line in mock_stream_str.splitlines() + ] + # Load the expected output as a list of dictionaries, not a single string + expected_final_list = case_data["expected_output"] + data_retrieved = { + "Data Retrieved": { + "headers": ["island", "average_body_mass"], + "rows": [ + ["Biscoe", "4716.017964071853"], + ["Dream", "3712.9032258064512"], + ["Torgersen", "3706.3725490196075"], + ], + "summary": "Showing all 3 rows.", + } + } + expected_final_list.insert(-1, data_retrieved) + + # 4. 
Configure the mock for requests.post + mock_response = mock.Mock() + mock_response.iter_lines.return_value = fake_stream_lines + # Add raise_for_status mock which is called in the updated code + mock_response.raise_for_status.return_value = None + mock_post.return_value.__enter__.return_value = mock_response + + # 5. Call the function under test + result = data_agent_tool._get_stream( # pylint: disable=protected-access + url="fake_url", + ca_payload={}, + headers={}, + max_query_result_rows=50, + ) + + # 6. Assert that the final list of dicts matches the expected output + assert result == expected_final_list diff --git a/tests/unittests/tools/data_agent/test_data_agent_toolset.py b/tests/unittests/tools/data_agent/test_data_agent_toolset.py new file mode 100644 index 0000000000..ccc478db7e --- /dev/null +++ b/tests/unittests/tools/data_agent/test_data_agent_toolset.py @@ -0,0 +1,128 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.tools.data_agent import DataAgentCredentialsConfig +from google.adk.tools.data_agent import DataAgentToolset +from google.adk.tools.data_agent.config import DataAgentToolConfig +from google.adk.tools.google_tool import GoogleTool +import pytest + + +@pytest.mark.asyncio +async def test_data_agent_toolset_tools_default(): + """Test default DataAgentToolset. + + This test verifies the behavior of the DataAgentToolset when no filter is + specified. + """ + credentials_config = DataAgentCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = DataAgentToolset( + credentials_config=credentials_config, data_agent_tool_config=None + ) + # Verify that the tool config is initialized to default values. + assert isinstance(toolset._tool_settings, DataAgentToolConfig) # pylint: disable=protected-access + assert toolset._tool_settings.__dict__ == DataAgentToolConfig().__dict__ # pylint: disable=protected-access + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == 3 + assert all(isinstance(tool, GoogleTool) for tool in tools) + + expected_tool_names = set([ + "list_accessible_data_agents", + "get_data_agent_info", + "ask_data_agent", + ]) + actual_tool_names = {tool.name for tool in tools} + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + "selected_tools", + [ + pytest.param([], id="None"), + pytest.param( + ["list_accessible_data_agents", "get_data_agent_info"], + id="list_and_get", + ), + pytest.param(["ask_data_agent"], id="ask"), + ], +) +@pytest.mark.asyncio +async def test_data_agent_toolset_tools_selective(selected_tools): + """Test DataAgentToolset with filter. + + This test verifies the behavior of the DataAgentToolset when filter is + specified. A use case for this would be when the agent builder wants to + use only a subset of the tools provided by the toolset. 
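+  An empty filter list results in no tools being returned.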
+ """ + credentials_config = DataAgentCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = DataAgentToolset( + credentials_config=credentials_config, tool_filter=selected_tools + ) + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(selected_tools) + assert all(isinstance(tool, GoogleTool) for tool in tools) + + expected_tool_names = set(selected_tools) + actual_tool_names = {tool.name for tool in tools} + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + ("selected_tools", "returned_tools"), + [ + pytest.param(["unknown"], [], id="all-unknown"), + pytest.param( + ["unknown", "ask_data_agent"], + ["ask_data_agent"], + id="mixed-known-unknown", + ), + ], +) +@pytest.mark.asyncio +async def test_data_agent_toolset_unknown_tool(selected_tools, returned_tools): + """Test DataAgentToolset with filter. + + This test verifies the behavior of the DataAgentToolset when filter is + specified with an unknown tool. + """ + credentials_config = DataAgentCredentialsConfig( + client_id="abc", client_secret="def" + ) + + toolset = DataAgentToolset( + credentials_config=credentials_config, tool_filter=selected_tools + ) + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(returned_tools) + assert all(isinstance(tool, GoogleTool) for tool in tools) + + expected_tool_names = set(returned_tools) + actual_tool_names = {tool.name for tool in tools} + assert actual_tool_names == expected_tool_names diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index f63301c1cb..b5f59be0fc 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -942,3 +942,225 @@ async def test_run_async_handles_none_parts_in_response(): ) assert tool_result == '' + + +class TestAgentToolWithCompositeAgents: + """Tests for AgentTool wrapping composite agents (SequentialAgent, etc.).""" + + def test_sequential_agent_with_first_sub_agent_input_schema(self): + """Test that AgentTool exposes input_schema from first sub-agent of SequentialAgent.""" + + class CustomInput(BaseModel): + query: str + language: str + + first_agent = Agent( + name='first_agent', + model=testing_utils.MockModel.create(responses=['response1']), + input_schema=CustomInput, + ) + + second_agent = Agent( + name='second_agent', + model=testing_utils.MockModel.create(responses=['response2']), + ) + + sequence = SequentialAgent( + name='sequence', + description='Process the query through multiple steps', + sub_agents=[first_agent, second_agent], + ) + + agent_tool = AgentTool(agent=sequence) + declaration = agent_tool._get_declaration() + + # Should expose CustomInput schema, not fallback to 'request' + assert declaration.name == 'sequence' + assert declaration.description == 'Process the query through multiple steps' + assert declaration.parameters.properties['query'].type == 'STRING' + assert declaration.parameters.properties['language'].type == 'STRING' + assert 'request' not in declaration.parameters.properties + + def test_sequential_agent_without_input_schema_falls_back_to_request(self): + """Test that AgentTool falls back to 'request' when no sub-agent has input_schema.""" + + first_agent = Agent( + name='first_agent', + model=testing_utils.MockModel.create(responses=['response1']), + ) + + second_agent = Agent( + name='second_agent', + model=testing_utils.MockModel.create(responses=['response2']), + ) + + sequence = SequentialAgent( + 
name='sequence', + description='Process the query through multiple steps', + sub_agents=[first_agent, second_agent], + ) + + agent_tool = AgentTool(agent=sequence) + declaration = agent_tool._get_declaration() + + # Should fall back to 'request' parameter + assert declaration.name == 'sequence' + assert declaration.parameters.properties['request'].type == 'STRING' + assert 'query' not in declaration.parameters.properties + + @mark.parametrize( + 'env_variables', + [ + 'VERTEX', + ], + indirect=True, + ) + def test_sequential_agent_with_last_sub_agent_output_schema( + self, env_variables + ): + """Test that AgentTool uses output_schema from last sub-agent of SequentialAgent.""" + + class CustomOutput(BaseModel): + result: str + + first_agent = Agent( + name='first_agent', + model=testing_utils.MockModel.create(responses=['response1']), + ) + + second_agent = Agent( + name='second_agent', + model=testing_utils.MockModel.create(responses=['response2']), + output_schema=CustomOutput, + ) + + sequence = SequentialAgent( + name='sequence', + description='Process the query', + sub_agents=[first_agent, second_agent], + ) + + agent_tool = AgentTool(agent=sequence) + declaration = agent_tool._get_declaration() + + # Should have object response schema from last sub-agent + assert declaration.response is not None + assert declaration.response.type == types.Type.OBJECT + + def test_nested_sequential_agent_input_schema(self): + """Test that AgentTool recursively finds input_schema in nested composite agents.""" + + class CustomInput(BaseModel): + deep_query: str + + inner_agent = Agent( + name='inner_agent', + model=testing_utils.MockModel.create(responses=['response1']), + input_schema=CustomInput, + ) + + inner_sequence = SequentialAgent( + name='inner_sequence', + sub_agents=[inner_agent], + ) + + outer_sequence = SequentialAgent( + name='outer_sequence', + description='Nested sequence', + sub_agents=[inner_sequence], + ) + + agent_tool = AgentTool(agent=outer_sequence) + declaration = agent_tool._get_declaration() + + # Should recursively find CustomInput from inner_agent + assert declaration.name == 'outer_sequence' + assert 'deep_query' in declaration.parameters.properties + assert declaration.parameters.properties['deep_query'].type == 'STRING' + assert 'request' not in declaration.parameters.properties + + @mark.parametrize( + 'env_variables', + [ + 'GOOGLE_AI', + 'VERTEX', + ], + indirect=True, + ) + def test_sequential_agent_custom_schema_end_to_end(self, env_variables): + """Test end-to-end flow with SequentialAgent using custom input/output schema.""" + + class CustomInput(BaseModel): + custom_input: str + + class CustomOutput(BaseModel): + custom_output: str + + function_call_seq = Part.from_function_call( + name='sequence', args={'custom_input': 'test_input'} + ) + + mock_model = testing_utils.MockModel.create( + responses=[ + function_call_seq, + '{"custom_output": "step1_response"}', + '{"custom_output": "final_response"}', + 'root_response', + ] + ) + + first_agent = Agent( + name='first_agent', + model=mock_model, + input_schema=CustomInput, + ) + + second_agent = Agent( + name='second_agent', + model=mock_model, + output_schema=CustomOutput, + output_key='seq_output', + ) + + sequence = SequentialAgent( + name='sequence', + description='A sequential pipeline', + sub_agents=[first_agent, second_agent], + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=sequence)], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + runner.run('test1') 
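+
+    # The mock model consumes its queued responses in order: the function
+    # call from the root agent, then each sub-agent's JSON text output, and
+    # finally the root agent's 'root_response' summary.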
+ + # Verify the tool declaration sent to LLM has the correct schema + # The first request is from root_agent, which should have the tool declaration + first_request = mock_model.requests[0] + tool_declarations = first_request.config.tools + assert len(tool_declarations) == 1 + + sequence_tool = tool_declarations[0].function_declarations[0] + assert sequence_tool.name == 'sequence' + # Should have 'custom_input' parameter from first sub-agent's input_schema + assert 'custom_input' in sequence_tool.parameters.properties + # Should NOT have the fallback 'request' parameter + assert 'request' not in sequence_tool.parameters.properties + + def test_empty_sequential_agent_falls_back_to_request(self): + """Test that AgentTool with empty SequentialAgent falls back to 'request'.""" + + sequence = SequentialAgent( + name='empty_sequence', + description='An empty sequence', + sub_agents=[], + ) + + agent_tool = AgentTool(agent=sequence) + declaration = agent_tool._get_declaration() + + # Should fall back to 'request' parameter + assert declaration.parameters.properties['request'].type == 'STRING'