From 672b57f1b76580023d1f348de76227291a9c1012 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 14 Jan 2026 15:48:52 -0800 Subject: [PATCH 01/11] chore: add a sample BigQuery agent using BigQuery MCP tools PiperOrigin-RevId: 856400285 --- contributing/samples/bigquery/README.md | 10 ++-- contributing/samples/bigquery_mcp/README.md | 55 +++++++++++++++++++ contributing/samples/bigquery_mcp/__init__.py | 15 +++++ contributing/samples/bigquery_mcp/agent.py | 51 +++++++++++++++++ src/google/adk/cli/cli_tools_click.py | 2 - 5 files changed, 126 insertions(+), 7 deletions(-) create mode 100644 contributing/samples/bigquery_mcp/README.md create mode 100644 contributing/samples/bigquery_mcp/__init__.py create mode 100644 contributing/samples/bigquery_mcp/agent.py diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index 960b6f40c2..f6e3bb66f9 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -24,11 +24,11 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: 5. `get_job_info` Fetches metadata about a BigQuery job. -5. `execute_sql` +6. `execute_sql` Runs or dry-runs a SQL query in BigQuery. -6. `ask_data_insights` +7. `ask_data_insights` Natural language-in, natural language-out tool that answers questions about structured data in BigQuery. Provides a one-stop solution for generating @@ -38,18 +38,18 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: the official [Conversational Analytics API documentation](https://cloud.google.com/gemini/docs/conversational-analytics-api/overview) for instructions. -7. `forecast` +8. `forecast` Perform time series forecasting using BigQuery's `AI.FORECAST` function, leveraging the TimesFM 2.0 model. -8. `analyze_contribution` +9. 
`analyze_contribution` Perform contribution analysis in BigQuery by creating a temporary `CONTRIBUTION_ANALYSIS` model and then querying it with `ML.GET_INSIGHTS` to find top contributors for a given metric. -9. `detect_anomalies` +10. `detect_anomalies` Perform time series anomaly detection in BigQuery by creating a temporary `ARIMA_PLUS` model and then querying it with diff --git a/contributing/samples/bigquery_mcp/README.md b/contributing/samples/bigquery_mcp/README.md new file mode 100644 index 0000000000..bce19976ca --- /dev/null +++ b/contributing/samples/bigquery_mcp/README.md @@ -0,0 +1,55 @@ +# BigQuery MCP Toolset Sample + +## Introduction + +This sample agent demonstrates using ADK's `McpToolset` to interact with +BigQuery's official MCP endpoint, allowing an agent to access and execute +tools by leveraging the Model Context Protocol (MCP). These tools include: + + +1. `list_dataset_ids` + + Fetches BigQuery dataset ids present in a GCP project. + +2. `get_dataset_info` + + Fetches metadata about a BigQuery dataset. + +3. `list_table_ids` + + Fetches table ids present in a BigQuery dataset. + +4. `get_table_info` + + Fetches metadata about a BigQuery table. + +5. `execute_sql` + + Runs or dry-runs a SQL query in BigQuery. + +## How to use + +Set up your project and local authentication by following the guide +[Use the BigQuery remote MCP server](https://docs.cloud.google.com/bigquery/docs/use-bigquery-mcp). +This agent uses Application Default Credentials (ADC) to authenticate with the +BigQuery MCP endpoint. + +Set up environment variables in your `.env` file for using +[Google AI Studio](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-ai-studio) +or +[Google Cloud Vertex AI](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai) +for the LLM service for your agent.
For example, for using Google AI Studio you +would set: + +* GOOGLE_GENAI_USE_VERTEXAI=FALSE +* GOOGLE_API_KEY={your api key} + +Then run the agent using `adk run .` or `adk web .` in this directory. + +## Sample prompts + +* which weather datasets exist in bigquery public data? +* tell me more about noaa_lightning +* which tables exist in the ml_datasets dataset? +* show more details about the penguins table +* compute penguins population per island. diff --git a/contributing/samples/bigquery_mcp/__init__.py b/contributing/samples/bigquery_mcp/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/bigquery_mcp/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/bigquery_mcp/agent.py b/contributing/samples/bigquery_mcp/agent.py new file mode 100644 index 0000000000..4116bc6cf4 --- /dev/null +++ b/contributing/samples/bigquery_mcp/agent.py @@ -0,0 +1,51 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +import google.auth.transport.requests + +BIGQUERY_AGENT_NAME = "adk_sample_bigquery_mcp_agent" +BIGQUERY_MCP_ENDPOINT = "https://bigquery.googleapis.com/mcp" +BIGQUERY_SCOPE = "https://www.googleapis.com/auth/bigquery" + +# Fetch an access token once, at import, via application default credentials. +# https://cloud.google.com/docs/authentication/provide-credentials-adc +credentials, project_id = google.auth.default(scopes=[BIGQUERY_SCOPE]) +credentials.refresh(google.auth.transport.requests.Request()) +oauth_token = credentials.token + +bigquery_mcp_toolset = McpToolset( + connection_params=StreamableHTTPConnectionParams( + url=BIGQUERY_MCP_ENDPOINT, + headers={"Authorization": f"Bearer {oauth_token}"}, + ) +) + +# The variable name `root_agent` determines what your root agent is for the +# debug CLI +root_agent = LlmAgent( + model="gemini-2.5-flash", + name=BIGQUERY_AGENT_NAME, + description=( + "Agent to answer questions about BigQuery data and models and execute" + " SQL queries using MCP." + ), + instruction="""\ + You are a data science agent with access to several BigQuery tools provided via MCP. + Make use of those tools to answer the user's questions.
+ """, + tools=[bigquery_mcp_toolset], +) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 91b4a07b5d..5d7611f217 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -1291,7 +1291,6 @@ async def _lifespan(app: FastAPI): host=host, port=port, reload=reload, - log_level=log_level.lower(), ) server = uvicorn.Server(config) @@ -1368,7 +1367,6 @@ def cli_api_server( host=host, port=port, reload=reload, - log_level=log_level.lower(), ) server = uvicorn.Server(config) server.run() From 7b035aa9fc43a43489aeffea8f877cd7eaa09f35 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 14 Jan 2026 15:59:46 -0800 Subject: [PATCH 02/11] chore: Always log api backend when connecting to live model Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 856404282 --- src/google/adk/models/google_llm.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index c38f854c93..c243f56a6a 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -378,6 +378,13 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: types.Part.from_text(text=llm_request.config.system_instruction) ], ) + + logger.info( + 'Trying to connect to live model: %s with api backend: %s', + llm_request.model, + self._api_backend, + ) + if ( llm_request.live_connect_config.session_resumption and llm_request.live_connect_config.session_resumption.transparent @@ -386,17 +393,13 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: 'session resumption config: %s', llm_request.live_connect_config.session_resumption, ) - logger.debug( - 'self._api_backend: %s', - self._api_backend, - ) + if self._api_backend == GoogleLLMVariant.GEMINI_API: raise ValueError( 'Transparent session resumption is only supported for Vertex AI' ' backend. 
Please use Vertex AI backend.' ) llm_request.live_connect_config.tools = llm_request.config.tools - logger.info('Connecting to live for model: %s', llm_request.model) logger.debug('Connecting to live with llm_request:%s', llm_request) logger.debug('Live connect config: %s', llm_request.live_connect_config) async with self._live_api_client.aio.live.connect( From fdc98d5c927bfef021e87cf72103892e4c2ac12a Mon Sep 17 00:00:00 2001 From: George Weale Date: Wed, 14 Jan 2026 16:00:20 -0800 Subject: [PATCH 03/11] fix: Convert unsupported inline artifact MIME types to text in LoadArtifactsTool The LoadArtifactsTool now checks if an artifact's inline data MIME type is supported by Gemini. If not, it attempts to convert the artifact content into a text Part Close #4028 Co-authored-by: George Weale PiperOrigin-RevId: 856404510 --- src/google/adk/tools/load_artifacts_tool.py | 105 +++++++++++- .../tools/test_load_artifacts_tool.py | 162 ++++++++++++++++++ 2 files changed, 265 insertions(+), 2 deletions(-) create mode 100644 tests/unittests/tools/test_load_artifacts_tool.py diff --git a/src/google/adk/tools/load_artifacts_tool.py b/src/google/adk/tools/load_artifacts_tool.py index 0e91380517..dbdc1f26f2 100644 --- a/src/google/adk/tools/load_artifacts_tool.py +++ b/src/google/adk/tools/load_artifacts_tool.py @@ -14,6 +14,8 @@ from __future__ import annotations +import base64 +import binascii import json import logging from typing import Any @@ -24,6 +26,19 @@ from .base_tool import BaseTool +# MIME types Gemini accepts for inline data in requests. +_GEMINI_SUPPORTED_INLINE_MIME_PREFIXES = ( + 'image/', + 'audio/', + 'video/', +) +_GEMINI_SUPPORTED_INLINE_MIME_TYPES = frozenset({'application/pdf'}) +_TEXT_LIKE_MIME_TYPES = frozenset({ + 'application/csv', + 'application/json', + 'application/xml', +}) + if TYPE_CHECKING: from ..models.llm_request import LlmRequest from .tool_context import ToolContext @@ -31,6 +46,79 @@ logger = logging.getLogger('google_adk.' 
+ __name__) +def _normalize_mime_type(mime_type: str | None) -> str | None: + """Returns the normalized MIME type, without parameters like charset.""" + if not mime_type: + return None + return mime_type.split(';', 1)[0].strip() + + +def _is_inline_mime_type_supported(mime_type: str | None) -> bool: + """Returns True if Gemini accepts this MIME type as inline data.""" + normalized = _normalize_mime_type(mime_type) + if not normalized: + return False + return normalized.startswith(_GEMINI_SUPPORTED_INLINE_MIME_PREFIXES) or ( + normalized in _GEMINI_SUPPORTED_INLINE_MIME_TYPES + ) + + +def _maybe_base64_to_bytes(data: str) -> bytes | None: + """Best-effort base64 decode for both std and urlsafe formats.""" + try: + return base64.b64decode(data, validate=True) + except (binascii.Error, ValueError): + try: + return base64.urlsafe_b64decode(data) + except (binascii.Error, ValueError): + return None + + +def _as_safe_part_for_llm( + artifact: types.Part, artifact_name: str +) -> types.Part: + """Returns a Part that is safe to send to Gemini.""" + inline_data = artifact.inline_data + if inline_data is None: + return artifact + + if _is_inline_mime_type_supported(inline_data.mime_type): + return artifact + + mime_type = _normalize_mime_type(inline_data.mime_type) or ( + 'application/octet-stream' + ) + data = inline_data.data + if data is None: + return types.Part.from_text( + text=( + f'[Artifact: {artifact_name}, type: {mime_type}. 
' + 'No inline data was provided.]' + ) + ) + + if isinstance(data, str): + decoded = _maybe_base64_to_bytes(data) + if decoded is None: + return types.Part.from_text(text=data) + data = decoded + + if mime_type.startswith('text/') or mime_type in _TEXT_LIKE_MIME_TYPES: + try: + return types.Part.from_text(text=data.decode('utf-8')) + except UnicodeDecodeError: + return types.Part.from_text(text=data.decode('utf-8', errors='replace')) + + size_kb = len(data) / 1024 + return types.Part.from_text( + text=( + f'[Binary artifact: {artifact_name}, ' + f'type: {mime_type}, size: {size_kb:.1f} KB. ' + 'Content cannot be displayed inline.]' + ) + ) + + class LoadArtifactsTool(BaseTool): """A tool that loads the artifacts and adds them to the session.""" @@ -108,7 +196,8 @@ async def _append_artifacts_to_llm_request( if llm_request.contents and llm_request.contents[-1].parts: function_response = llm_request.contents[-1].parts[0].function_response if function_response and function_response.name == 'load_artifacts': - artifact_names = function_response.response['artifact_names'] + response = function_response.response or {} + artifact_names = response.get('artifact_names', []) for artifact_name in artifact_names: # Try session-scoped first (default behavior) artifact = await tool_context.load_artifact(artifact_name) @@ -122,6 +211,18 @@ async def _append_artifacts_to_llm_request( if artifact is None: logger.warning('Artifact "%s" not found, skipping', artifact_name) continue + + artifact_part = _as_safe_part_for_llm(artifact, artifact_name) + if artifact_part is not artifact: + mime_type = ( + artifact.inline_data.mime_type if artifact.inline_data else None + ) + logger.debug( + 'Converted artifact "%s" (mime_type=%s) to text Part', + artifact_name, + mime_type, + ) + llm_request.contents.append( types.Content( role='user', @@ -129,7 +230,7 @@ async def _append_artifacts_to_llm_request( types.Part.from_text( text=f'Artifact {artifact_name} is:' ), - artifact, + artifact_part, 
], ) ) diff --git a/tests/unittests/tools/test_load_artifacts_tool.py b/tests/unittests/tools/test_load_artifacts_tool.py new file mode 100644 index 0000000000..1ea50bb33c --- /dev/null +++ b/tests/unittests/tools/test_load_artifacts_tool.py @@ -0,0 +1,162 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 + +from google.adk.models.llm_request import LlmRequest +from google.adk.tools.load_artifacts_tool import _maybe_base64_to_bytes +from google.adk.tools.load_artifacts_tool import load_artifacts_tool +from google.genai import types +from pytest import mark + + +class _StubToolContext: + """Minimal ToolContext stub for LoadArtifactsTool tests.""" + + def __init__(self, artifacts_by_name: dict[str, types.Part]): + self._artifacts_by_name = artifacts_by_name + + async def list_artifacts(self) -> list[str]: + return list(self._artifacts_by_name.keys()) + + async def load_artifact(self, name: str) -> types.Part | None: + return self._artifacts_by_name.get(name) + + +@mark.asyncio +async def test_load_artifacts_converts_unsupported_mime_to_text(): + """Unsupported inline MIME types are converted to text parts.""" + artifact_name = 'test.csv' + csv_bytes = b'col1,col2\n1,2\n' + artifact = types.Part( + inline_data=types.Blob(data=csv_bytes, mime_type='application/csv') + ) + + tool_context = _StubToolContext({artifact_name: artifact}) + llm_request = LlmRequest( + contents=[ + types.Content( + role='user', + parts=[ + types.Part( 
+ function_response=types.FunctionResponse( + name='load_artifacts', + response={'artifact_names': [artifact_name]}, + ) + ) + ], + ) + ] + ) + + await load_artifacts_tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.contents[-1].parts[0].text == ( + f'Artifact {artifact_name} is:' + ) + artifact_part = llm_request.contents[-1].parts[1] + assert artifact_part.inline_data is None + assert artifact_part.text == csv_bytes.decode('utf-8') + + +@mark.asyncio +async def test_load_artifacts_converts_base64_unsupported_mime_to_text(): + """Unsupported base64 string data is converted to text parts.""" + artifact_name = 'test.csv' + csv_bytes = b'col1,col2\n1,2\n' + csv_base64 = base64.b64encode(csv_bytes).decode('ascii') + artifact = types.Part( + inline_data=types.Blob(data=csv_base64, mime_type='application/csv') + ) + + tool_context = _StubToolContext({artifact_name: artifact}) + llm_request = LlmRequest( + contents=[ + types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='load_artifacts', + response={'artifact_names': [artifact_name]}, + ) + ) + ], + ) + ] + ) + + await load_artifacts_tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + artifact_part = llm_request.contents[-1].parts[1] + assert artifact_part.inline_data is None + assert artifact_part.text == csv_bytes.decode('utf-8') + + +@mark.asyncio +async def test_load_artifacts_keeps_supported_mime_types(): + """Supported inline MIME types are passed through unchanged.""" + artifact_name = 'test.pdf' + artifact = types.Part( + inline_data=types.Blob(data=b'%PDF-1.4', mime_type='application/pdf') + ) + + tool_context = _StubToolContext({artifact_name: artifact}) + llm_request = LlmRequest( + contents=[ + types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='load_artifacts', + response={'artifact_names': [artifact_name]}, + ) 
+ ) + ], + ) + ] + ) + + await load_artifacts_tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + artifact_part = llm_request.contents[-1].parts[1] + assert artifact_part.inline_data is not None + assert artifact_part.inline_data.mime_type == 'application/pdf' + + +def test_maybe_base64_to_bytes_decodes_standard_base64(): + """Standard base64 encoded strings are decoded correctly.""" + original = b'hello world' + encoded = base64.b64encode(original).decode('ascii') + assert _maybe_base64_to_bytes(encoded) == original + + +def test_maybe_base64_to_bytes_decodes_urlsafe_base64(): + """URL-safe base64 encoded strings are decoded correctly.""" + original = b'\xfb\xff\xfe' # bytes that produce +/ in std but -_ in urlsafe + encoded = base64.urlsafe_b64encode(original).decode('ascii') + assert _maybe_base64_to_bytes(encoded) == original + + +def test_maybe_base64_to_bytes_returns_none_for_invalid(): + """Invalid base64 strings return None.""" + # Single character is invalid (base64 requires length % 4 == 0 after padding) + assert _maybe_base64_to_bytes('x') is None From 8e7cc16f1248174e247780f20785401ecf525305 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Wed, 14 Jan 2026 16:18:34 -0800 Subject: [PATCH 04/11] docs: Refactor ADK release analyzer with workflow agents Co-authored-by: Xuan Yang PiperOrigin-RevId: 856412858 --- .../adk_release_analyzer/agent.py | 642 +++++++++++++++--- .../samples/adk_documentation/tools.py | 111 +++ 2 files changed, 673 insertions(+), 80 deletions(-) diff --git a/contributing/samples/adk_documentation/adk_release_analyzer/agent.py b/contributing/samples/adk_documentation/adk_release_analyzer/agent.py index 738217c3e2..ddad17d310 100644 --- a/contributing/samples/adk_documentation/adk_release_analyzer/agent.py +++ b/contributing/samples/adk_documentation/adk_release_analyzer/agent.py @@ -12,8 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+"""ADK Release Analyzer Agent - Multi-agent architecture for analyzing releases. + +This agent uses a SequentialAgent + LoopAgent pattern to handle large releases +without context overflow: + +1. PlannerAgent: Collects changed files and creates analysis groups +2. LoopAgent + FileGroupAnalyzer: Processes one group at a time +3. SummaryAgent: Compiles all findings and creates the GitHub issue + +State keys used: +- start_tag, end_tag: Release tags being compared +- compare_url: GitHub compare URL +- file_groups: List of file groups to analyze +- current_group_index: Index of current group being processed +- recommendations: Accumulated recommendations from all groups +""" + import os import sys +from typing import Any SAMPLES_DIR = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "..") @@ -29,12 +47,21 @@ from adk_documentation.settings import LOCAL_REPOS_DIR_PATH from adk_documentation.tools import clone_or_pull_repo from adk_documentation.tools import create_issue -from adk_documentation.tools import get_changed_files_between_releases +from adk_documentation.tools import get_changed_files_summary +from adk_documentation.tools import get_file_diff_for_release from adk_documentation.tools import list_directory_contents from adk_documentation.tools import list_releases from adk_documentation.tools import read_local_git_repo_file_content from adk_documentation.tools import search_local_git_repo from google.adk import Agent +from google.adk.agents.loop_agent import LoopAgent +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.tools.exit_loop_tool import exit_loop +from google.adk.tools.tool_context import ToolContext + +# Maximum number of files per analysis group to avoid context overflow +MAX_FILES_PER_GROUP = 5 if IS_INTERACTIVE: APPROVAL_INSTRUCTION = ( @@ -43,96 +70,551 @@ ) else: APPROVAL_INSTRUCTION = ( - "**Do not** wait or ask for user approval or 
confirmation for creating or" - " updating the issue." + "**Do not** wait or ask for user approval or confirmation for creating" + " or updating the issue." ) + +# ============================================================================= +# Tool functions for state management +# ============================================================================= + + +def get_next_file_group(tool_context: ToolContext) -> dict[str, Any]: + """Gets the next group of files to analyze from the state. + + This tool retrieves the next file group from state["file_groups"] + and increments the current_group_index. + + Args: + tool_context: The tool context providing access to state. + + Returns: + A dictionary with the next file group or indication that all groups + are processed. + """ + file_groups = tool_context.state.get("file_groups", []) + current_index = tool_context.state.get("current_group_index", 0) + + if current_index >= len(file_groups): + return { + "status": "complete", + "message": "All file groups have been processed.", + "total_groups": len(file_groups), + "processed": current_index, + } + + current_group = file_groups[current_index] + tool_context.state["current_group_index"] = current_index + 1 + + return { + "status": "success", + "group_index": current_index, + "total_groups": len(file_groups), + "remaining": len(file_groups) - current_index - 1, + "files": current_group, + } + + +def save_group_recommendations( + tool_context: ToolContext, + group_index: int, + recommendations: list[dict[str, str]], +) -> dict[str, Any]: + """Saves recommendations for a file group to state. + + Args: + tool_context: The tool context providing access to state. + group_index: The index of the group these recommendations belong to. 
+ recommendations: List of recommendation dicts with keys: + - summary: Brief summary of the change + - doc_file: Path to the doc file to update + - current_state: Current content in the doc + - proposed_change: What should be changed + - reasoning: Why this change is needed + - reference: Reference to the code file + + Returns: + A dictionary confirming the save operation. + """ + all_recommendations = tool_context.state.get("recommendations", []) + all_recommendations.extend(recommendations) + tool_context.state["recommendations"] = all_recommendations + + return { + "status": "success", + "group_index": group_index, + "new_recommendations": len(recommendations), + "total_recommendations": len(all_recommendations), + } + + +def get_all_recommendations(tool_context: ToolContext) -> dict[str, Any]: + """Retrieves all accumulated recommendations from state. + + Args: + tool_context: The tool context providing access to state. + + Returns: + A dictionary with all recommendations and metadata. + """ + recommendations = tool_context.state.get("recommendations", []) + start_tag = tool_context.state.get("start_tag", "unknown") + end_tag = tool_context.state.get("end_tag", "unknown") + compare_url = tool_context.state.get("compare_url", "") + + return { + "status": "success", + "start_tag": start_tag, + "end_tag": end_tag, + "compare_url": compare_url, + "total_recommendations": len(recommendations), + "recommendations": recommendations, + } + + +def save_release_info( + tool_context: ToolContext, + start_tag: str, + end_tag: str, + compare_url: str, + file_groups: list[list[dict[str, Any]]], + release_summary: str, + all_changed_files: list[str], +) -> dict[str, Any]: + """Saves release info and file groups to state for processing. + + Args: + tool_context: The tool context providing access to state. + start_tag: The starting release tag. + end_tag: The ending release tag. + compare_url: The GitHub compare URL. 
+ file_groups: List of file groups, where each group is a list of file + info dicts. + release_summary: A high-level summary of all changes in this release, + including the main themes (e.g., "new feature X", "refactoring Y", + "bug fixes in Z"). This helps individual analyzers understand the + bigger picture. + all_changed_files: List of all changed file paths (for cross-reference). + + Returns: + A dictionary confirming the save operation. + """ + tool_context.state["start_tag"] = start_tag + tool_context.state["end_tag"] = end_tag + tool_context.state["compare_url"] = compare_url + tool_context.state["file_groups"] = file_groups + tool_context.state["current_group_index"] = 0 + tool_context.state["recommendations"] = [] + tool_context.state["release_summary"] = release_summary + tool_context.state["all_changed_files"] = all_changed_files + + return { + "status": "success", + "start_tag": start_tag, + "end_tag": end_tag, + "total_groups": len(file_groups), + "total_files": sum(len(group) for group in file_groups), + } + + +def get_release_context(tool_context: ToolContext) -> dict[str, Any]: + """Gets the global release context for cross-group awareness. + + This allows individual file group analyzers to understand: + - The overall theme of the release + - What other files were changed (for identifying related changes) + - What recommendations have already been made (to avoid duplicates) + + Args: + tool_context: The tool context providing access to state. + + Returns: + A dictionary with global release context. 
+ """ + return { + "status": "success", + "start_tag": tool_context.state.get("start_tag", "unknown"), + "end_tag": tool_context.state.get("end_tag", "unknown"), + "release_summary": tool_context.state.get("release_summary", ""), + "all_changed_files": tool_context.state.get("all_changed_files", []), + "existing_recommendations": tool_context.state.get("recommendations", []), + "current_group_index": tool_context.state.get("current_group_index", 0), + "total_groups": len(tool_context.state.get("file_groups", [])), + } + + +# ============================================================================= +# Agent 1: Planner Agent +# ============================================================================= + +planner_agent = Agent( + model="gemini-2.5-pro", + name="release_planner", + description=( + "Plans the analysis by fetching release info and organizing files into" + " groups for incremental processing." + ), + instruction=f""" +# 1. Identity +You are the Release Planner, responsible for setting up the analysis of ADK +Python releases. You gather information about changes and organize them for +efficient processing. + +# 2. Workflow +1. First, call `clone_or_pull_repo` for both repositories: + - ADK Python codebase: owner={CODE_OWNER}, repo={CODE_REPO}, path={LOCAL_REPOS_DIR_PATH}/{CODE_REPO} + - ADK Docs: owner={DOC_OWNER}, repo={DOC_REPO}, path={LOCAL_REPOS_DIR_PATH}/{DOC_REPO} + +2. Call `list_releases` to find the release tags for {CODE_OWNER}/{CODE_REPO}. + - By default, compare the two most recent releases. + - If the user specifies tags, use those instead. + +3. Call `get_changed_files_summary` to get the list of changed files WITHOUT + the full patches (to save context space). + +4. Filter and organize the files: + - **INCLUDE** only files in `src/google/adk/` directory + - **EXCLUDE** test files, `__init__.py`, and files outside src/ + - **IMPORTANT**: Do NOT exclude any file just because it has few changes. 
+ Even single-line changes to public APIs need documentation updates. + - **PRIORITIZE** by importance: + a) New files (status: "added") - ALWAYS include these + b) CLI files (cli/) - often contain user-facing flags and options + c) Tool files (tools/) - may contain new tools or tool parameters + d) Core files (agents/, models/, sessions/, memory/, a2a/, flows/, + plugins/, evaluation/) + e) Files with many changes (high additions + deletions) + +5. **Create a high-level release summary** based on the changed files: + - Identify the main themes (e.g., "new tool X added", "refactoring of Y") + - Note any files that appear related (e.g., same feature area) + - This summary will be shared with individual file analyzers so they + understand the bigger picture. + +6. Group the filtered files into groups of at most {MAX_FILES_PER_GROUP} files each. + - **IMPORTANT**: Group RELATED files together (same directory or feature) + - Files that are part of the same feature should be in the same group + - Each group should be independently analyzable + +7. Call `save_release_info` to save: + - start_tag, end_tag + - compare_url + - file_groups (the organized groups) + - release_summary (the high-level summary you created) + - all_changed_files (list of all file paths for cross-reference) + +# 3. 
Output +Provide a summary of: +- Which releases are being compared +- The high-level themes of this release +- How many files changed in total +- How many files are relevant for doc analysis +- How many groups were created +""", + tools=[ + clone_or_pull_repo, + list_releases, + get_changed_files_summary, + save_release_info, + ], + output_key="planner_output", +) + + +# ============================================================================= +# Agent 2: File Group Analyzer (runs inside LoopAgent) +# ============================================================================= + + +def file_analyzer_instruction(readonly_context: ReadonlyContext) -> str: + """Dynamic instruction that includes current state info.""" + start_tag = readonly_context.state.get("start_tag", "unknown") + end_tag = readonly_context.state.get("end_tag", "unknown") + release_summary = readonly_context.state.get("release_summary", "") + + return f""" +# 1. Identity +You are the File Group Analyzer, responsible for analyzing a group of changed +files and finding related documentation that needs updating. + +# 2. Context +- Comparing releases: {start_tag} to {end_tag} +- Code repository: {CODE_OWNER}/{CODE_REPO} +- Docs repository: {DOC_OWNER}/{DOC_REPO} +- Docs local path: {LOCAL_REPOS_DIR_PATH}/{DOC_REPO} +- Code local path: {LOCAL_REPOS_DIR_PATH}/{CODE_REPO} + +## Release Summary (from Planner) +{release_summary} + +# 3. Workflow +1. Call `get_next_file_group` to get the next group of files to analyze. + - If status is "complete", call the `exit_loop` tool to exit the loop. + +2. **FIRST**, call `get_release_context` to understand: + - The overall release themes (to understand how your files fit in) + - What other files were changed (to identify related changes) + - What recommendations already exist (to AVOID DUPLICATES) + +3. For each file in the group: + a) Call `get_file_diff_for_release` to get the patch content for that file. + b) Analyze the changes THOROUGHLY. 
Look for: + **API Changes:** + - New functions, classes, methods (especially public ones) + - New parameters added to existing functions + - New CLI arguments or flags (look for argparse, click decorators) + - New environment variables (look for os.environ, getenv) + - New tools or features being added + - Renamed or deprecated functionality + **Behavior Changes (even without API changes):** + - Default values changed + - Error handling or exception types changed + - Return value format or content changed + - Side effects added or removed + - Performance characteristics changed + - Edge case handling changed + - Validation rules changed + c) Consider how this file relates to OTHER changed files in this release. + d) Generate MULTIPLE search patterns based on: + - Class/function names that changed + - Feature names mentioned in the file path + - Keywords from the patch content (e.g., "local_storage", "allow_origins") + - Tool names, parameter names, environment variable names + +4. For EACH significant change, call `search_local_git_repo` to find related docs + in {LOCAL_REPOS_DIR_PATH}/{DOC_REPO}/docs/ + - Search for the feature name, class name, or related keywords + - If no docs found, recommend creating new documentation + +5. Call `read_local_git_repo_file_content` to read the relevant doc files + and check if they need updating. + +6. For each documentation update needed, create a recommendation with: + - summary: Brief summary of what needs to change + - doc_file: Relative path in the docs repo (e.g., docs/tools/google-search.md) + - current_state: What the doc currently says + - proposed_change: What it should say instead + - reasoning: Why this update is needed + - reference: The source code file path + - related_files: Other changed files that are part of the same change (if any) + +7. Call `save_group_recommendations` with all recommendations for this group. + +8. After saving, output a brief summary of what you found for this group. + +# 4. 
Rules +- **BE THOROUGH**: Check EVERY change in the diff that could affect users. + This includes API changes AND behavior changes (default values, error handling, + return formats, side effects, etc.). +- Focus on changes that users need to know about +- Include behavior changes even if the API signature stays the same +- If a change only affects auto-generated API reference docs, note that + regeneration is needed instead of manual updates +- **AVOID DUPLICATES**: Check existing_recommendations before adding new ones +- **CROSS-REFERENCE**: If files in your group relate to files in other groups, + mention this in your recommendation so the Summary agent can consolidate +- **DON'T MISS ITEMS**: Better to have too many recommendations than too few. + If unsure whether something needs documentation, include it. +- For new features with no existing docs, recommend creating a new page +""" + + +file_group_analyzer = Agent( + model="gemini-2.5-pro", + name="file_group_analyzer", + description=( + "Analyzes a group of changed files and generates recommendations." 
+ ), + instruction=file_analyzer_instruction, + tools=[ + get_next_file_group, + get_release_context, # Get global context to avoid duplicates + get_file_diff_for_release, + search_local_git_repo, + read_local_git_repo_file_content, + list_directory_contents, + save_group_recommendations, + exit_loop, # Call this when all groups are processed + ], + output_key="analyzer_output", +) + +# Loop agent that processes file groups one at a time +file_analysis_loop = LoopAgent( + name="file_analysis_loop", + sub_agents=[file_group_analyzer], + max_iterations=50, # Safety limit +) + + +# ============================================================================= +# Agent 3: Summary Agent +# ============================================================================= + + +def summary_instruction(readonly_context: ReadonlyContext) -> str: + """Dynamic instruction with release info.""" + start_tag = readonly_context.state.get("start_tag", "unknown") + end_tag = readonly_context.state.get("end_tag", "unknown") + + return f""" +# 1. Identity +You are the Summary Agent, responsible for compiling all recommendations into +a well-formatted GitHub issue. + +# 2. Workflow +1. Call `get_all_recommendations` to retrieve all accumulated recommendations. + +2. Organize the recommendations: + - Group by importance: Feature changes > Bug fixes > Other + - Within each group, sort by number of affected files + - Remove duplicates or merge similar recommendations + +3. Format the issue body using this template for each recommendation: + ``` + ### N. **Summary of the change** + + **Doc file**: path/to/doc.md + + **Current state**: + > Current content in the doc + + **Proposed Change**: + > What it should say instead + + **Reasoning**: + Explanation of why this change is necessary. + + **Reference**: src/google/adk/path/to/file.py + ``` + +4. 
Create the GitHub issue: + - Title: "Found docs updates needed from ADK python release {start_tag} to {end_tag}" + - Include the compare link at the top + - {APPROVAL_INSTRUCTION} + +5. Call `create_issue` for {DOC_OWNER}/{DOC_REPO} with the formatted content. + +# 3. Output +Present a summary of: +- Total recommendations created +- Issue URL if created +- Any notes about the analysis +""" + + +summary_agent = Agent( + model="gemini-2.5-pro", + name="summary_agent", + description="Compiles recommendations and creates the GitHub issue.", + instruction=summary_instruction, + tools=[ + get_all_recommendations, + create_issue, + ], + output_key="summary_output", +) + + +# ============================================================================= +# Pipeline Agent: Sequential orchestration of the analysis +# ============================================================================= + +analysis_pipeline = SequentialAgent( + name="analysis_pipeline", + description=( + "Executes the release analysis pipeline: planning, file analysis, and" + " summary generation." + ), + sub_agents=[ + planner_agent, + file_analysis_loop, + summary_agent, + ], +) + + +# ============================================================================= +# Root Agent: Entry point that understands user requests +# ============================================================================= + root_agent = Agent( model="gemini-2.5-pro", name="adk_release_analyzer", description=( - "Analyze the changes between two ADK releases and generate instructions" - " about how to update the ADK docs." + "Analyzes ADK Python releases and generates documentation update" + " recommendations." ), instruction=f""" - # 1. Identity - You are a helper bot that checks if ADK docs in GitHub Repository {DOC_REPO} owned by {DOC_OWNER} - should be updated based on the changes in the ADK Python codebase in GitHub Repository {CODE_REPO} owned by {CODE_OWNER}. 
- - You are very familiar with GitHub, especially how to search for files in a GitHub repository using git grep. - - # 2. Responsibilities - Your core responsibility includes: - - Find all the code changes between the two ADK releases. - - Find **all** the related docs files in ADK Docs repository under the "/docs/" directory. - - Compare the code changes with the docs files and analyze the differences. - - Write the instructions about how to update the ADK docs in markdown format and create a GitHub issue in the GitHub Repository {DOC_REPO} with the instructions. - - # 3. Workflow - 1. Always call the `clone_or_pull_repo` tool to make sure the ADK docs and codebase repos exist in the local folder {LOCAL_REPOS_DIR_PATH}/repo_name and are the latest version. - 2. Find the code changes between the two ADK releases. - - You should call the `get_changed_files_between_releases` tool to find all the code changes between the two ADK releases. - - You can call the `list_releases` tool to find the release tags. - 3. Understand the code changes between the two ADK releases. - - You should focus on the main ADK Python codebase, ignore the changes in tests or other auxiliary files. - 4. Come up with a list of regex search patterns to search for related docs files. - 5. Use the `search_local_git_repo` tool to search for related docs files using the regex patterns. - - You should look into all the related docs files, not only the most relevant one. - - Prefer searching from the root directory of the ADK Docs repository (i.e. /docs/), unless you are certain that the file is in a specific directory. - 6. Read the found docs files using the `read_local_git_repo_file_content` tool to find all the docs to update. - - You should read all the found docs files and check if they are up to date. - 7. Compare the code changes and docs files, and analyze the differences. - - You should not only check the code snippets in the docs, but also the text contents. - 8. 
Write the instructions about how to update the ADK docs in a markdown format. - - For **each** recommended change, reference the code changes. - - For **each** recommended change, follow the format of the following template: - ``` - 1. **Highlighted summary of the change**. - Details of the change. - - **Current state**: - Current content in the doc - - **Proposed Change**: - Proposed change to the doc. - - **Reasoning**: - Explanation of why this change is necessary. - - **Reference**: - Reference to the code file (e.g. src/google/adk/tools/spanner/metadata_tool.py). - ``` - - When referencing doc file, use the full relative path of the doc file in the ADK Docs repository (e.g. docs/sessions/memory.md). - 9. Create or recommend to create a GitHub issue in the GitHub Repository {DOC_REPO} with the instructions using the `create_issue` tool. - - The title of the issue should be "Found docs updates needed from ADK python release to ", where start_tag and end_tag are the release tags. - - The body of the issue should be the instructions about how to update the ADK docs. - - Include the compare link between the two ADK releases in the issue body, e.g. https://github.com/google/adk-python/compare/v1.14.0...v1.14.1. - - **{APPROVAL_INSTRUCTION}** - - # 4. Guidelines & Rules - - **File Paths:** Always use absolute paths when calling the tools to read files, list directories, or search the codebase. - - **Tool Call Parallelism:** Execute multiple independent tool calls in parallel when feasible (i.e. searching the codebase). - - **Explanation:** Provide concise explanations for your actions and reasoning for each step. - - **Reference:** For each recommended change, reference the code changes (i.e. links to the commits) **AND** the code files (i.e. relative paths to the code files in the codebase). - - **Sorting:** Sort the recommended changes by the importance of the changes, from the most important to the least important. 
- - Here are the importance groups: Feature changes > Bug fixes > Other changes. - - Within each importance group, sort the changes by the number of files they affect. - - Within each group of changes with the same number of files, sort by the number of lines changed in each file. - - **API Reference Updates:** ADK Docs repository has auto-generated API reference docs for the ADK Python codebase, which can be found in the "/docs/api-reference/python" directory. - - If a change in the codebase can be covered by the auto-generated API reference docs, you should just recommend to update the API reference docs (i.e. regenerate the API reference docs) instead of the other human-written ADK docs. - - # 5. Output - Present the following in an easy to read format as the final output to the user. - - The actions you took and the reasoning - - The summary of the differences found - """, +# 1. Identity +You are the ADK Release Analyzer, a helper bot that analyzes changes between +ADK Python releases and identifies documentation updates needed in the ADK +Docs repository. + +# 2. Capabilities +You can help users in several ways: + +## A. Full Release Analysis (delegate to analysis_pipeline) +When users want a complete analysis of releases, delegate to the +`analysis_pipeline` sub-agent. This will: +- Clone/update repositories +- Analyze all changed files +- Generate recommendations +- Create a GitHub issue + +Use this when users say things like: +- "Analyze the latest releases" +- "Check what docs need updating for v1.15.0" +- "Run a full analysis" + +## B. Quick Queries (use your tools directly) +For targeted questions, use your tools directly WITHOUT delegating: + +- **"How should I modify doc1.md?"** → Use `search_local_git_repo` to find + mentions of doc1.md in the codebase, then use `get_changed_files_summary` + to see what changed, and provide specific guidance. 
+ +- **"What changed in the tools module?"** → Use `get_changed_files_summary` + and filter for tools/ directory. + +- **"Show me the recommendations from the last analysis"** → Use + `get_all_recommendations` to retrieve stored recommendations. + +- **"What releases are available?"** → Use `list_releases` directly. + +# 3. Workflow Decision +1. First, understand what the user is asking: + - Full analysis request → delegate to analysis_pipeline + - Specific question about a file/module → use tools directly + - Query about previous results → use get_all_recommendations + +2. For quick queries, ensure repos are cloned first using `clone_or_pull_repo` + if needed. + +3. Always explain what you're doing and provide clear, actionable answers. + +# 4. Available Tools +- `clone_or_pull_repo`: Ensure local repos are up to date +- `list_releases`: See available release tags +- `get_changed_files_summary`: Get list of changed files (lightweight) +- `get_file_diff_for_release`: Get patch for a specific file +- `search_local_git_repo`: Search for patterns in repos +- `read_local_git_repo_file_content`: Read file contents +- `get_all_recommendations`: Retrieve recommendations from previous analysis + +# 5. 
def get_file_diff_for_release(
    repo_owner: str,
    repo_name: str,
    start_tag: str,
    end_tag: str,
    file_path: str,
) -> Dict[str, Any]:
  """Gets the diff/patch for a specific file between two release tags.

  This is useful for incremental processing where you want to analyze
  one file at a time instead of loading all changes at once.

  Args:
    repo_owner: The name of the repository owner.
    repo_name: The name of the repository.
    start_tag: The older tag (base) for the comparison.
    end_tag: The newer tag (head) for the comparison.
    file_path: The relative path of the file to get the diff for.

  Returns:
    A dictionary containing the status and the file diff details. On success
    the "file" entry holds the path, change status, addition/deletion/change
    counts and the patch text (a placeholder string is used when the compare
    response carries no "patch" field for that file). When the file is not
    listed in the comparison, or the request fails, the value produced by
    `error_response` is returned instead.
  """
  url = (
      f"{GITHUB_BASE_URL}/repos/{repo_owner}/{repo_name}"
      f"/compare/{start_tag}...{end_tag}"
  )

  try:
    comparison_data = get_request(url)
    changed_files = comparison_data.get("files", [])

    for file_data in changed_files:
      if file_data.get("filename") != file_path:
        continue
      return {
          "status": "success",
          "file": {
              "relative_path": file_data.get("filename"),
              "status": file_data.get("status"),
              "additions": file_data.get("additions"),
              "deletions": file_data.get("deletions"),
              "changes": file_data.get("changes"),
              "patch": file_data.get("patch", "No patch available."),
          },
      }

    return error_response(f"File {file_path} not found in the comparison.")
  except requests.exceptions.HTTPError as e:
    return error_response(f"HTTP Error: {e}")
  except requests.exceptions.RequestException as e:
    return error_response(f"Request Error: {e}")


def get_changed_files_summary(
    repo_owner: str, repo_name: str, start_tag: str, end_tag: str
) -> Dict[str, Any]:
  """Gets a summary of changed files between two releases without patches.

  This is a lighter-weight version of get_changed_files_between_releases
  that only returns file paths and metadata, without the actual diff content.
  Use this for planning which files to analyze.

  Args:
    repo_owner: The name of the repository owner.
    repo_name: The name of the repository.
    start_tag: The older tag (base) for the comparison.
    end_tag: The newer tag (head) for the comparison.

  Returns:
    A dictionary containing the status and a summary of changed files:
    "total_files", the flat "files" list, "files_by_directory" (the same
    entries keyed by their first path component, with files that have no
    directory component collected under "root") and the human-facing
    "compare_url". On a request failure the value produced by
    `error_response` is returned instead.
  """
  url = (
      f"{GITHUB_BASE_URL}/repos/{repo_owner}/{repo_name}"
      f"/compare/{start_tag}...{end_tag}"
  )

  try:
    comparison_data = get_request(url)
    changed_files = comparison_data.get("files", [])

    # Group files by directory for easier processing
    files_by_dir: Dict[str, List[Dict[str, Any]]] = {}
    formatted_files = []

    for file_data in changed_files:
      file_info = {
          "relative_path": file_data.get("filename"),
          "status": file_data.get("status"),
          "additions": file_data.get("additions"),
          "deletions": file_data.get("deletions"),
          "changes": file_data.get("changes"),
      }
      formatted_files.append(file_info)

      # Group by top-level directory. `str.split("/")` never returns an empty
      # list, so the fallback has to key off the missing separator instead:
      # a path without "/" lives at the repository root.
      path = file_data.get("filename") or ""
      top_dir, separator, _ = path.partition("/")
      if not separator:
        top_dir = "root"
      files_by_dir.setdefault(top_dir, []).append(file_info)

    return {
        "status": "success",
        "total_files": len(formatted_files),
        "files": formatted_files,
        "files_by_directory": files_by_dir,
        "compare_url": (
            f"https://github.com/{repo_owner}/{repo_name}"
            f"/compare/{start_tag}...{end_tag}"
        ),
    }
  except requests.exceptions.HTTPError as e:
    return error_response(f"HTTP Error: {e}")
  except requests.exceptions.RequestException as e:
    return error_response(f"Request Error: {e}")
Co-authored-by: Kathy Wu PiperOrigin-RevId: 856419611 --- .../tools/mcp_tool/test_mcp_toolset.py | 51 +++++++++++++ tests/unittests/tools/test_mcp_toolset.py | 71 ------------------- 2 files changed, 51 insertions(+), 71 deletions(-) delete mode 100644 tests/unittests/tools/test_mcp_toolset.py diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 5809efe56f..f6d002ed17 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -17,6 +17,7 @@ import sys import unittest from unittest.mock import AsyncMock +from unittest.mock import MagicMock from unittest.mock import Mock from unittest.mock import patch @@ -28,6 +29,7 @@ from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset from mcp import StdioServerParameters import pytest @@ -302,3 +304,52 @@ async def test_get_tools_retry_decorator(self): # Check that the method has the retry decorator assert hasattr(toolset.get_tools, "__wrapped__") + + @pytest.mark.asyncio + async def test_mcp_toolset_with_prefix(self): + """Test that McpToolset correctly applies the tool_name_prefix.""" + # Mock the connection parameters + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + + # Mock the MCPSessionManager and its create_session method + mock_session_manager = MagicMock() + mock_session = MagicMock() + + # Mock the list_tools response from the MCP server + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "tool 1 desc" + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = "tool 2 desc" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool1, mock_tool2] + mock_session.list_tools = 
AsyncMock(return_value=list_tools_result) + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + + # Create an instance of McpToolset with a prefix + toolset = McpToolset( + connection_params=mock_connection_params, + tool_name_prefix="my_prefix", + ) + + # Replace the internal session manager with our mock + toolset._mcp_session_manager = mock_session_manager + + # Get the tools from the toolset + tools = await toolset.get_tools() + + # The get_tools method in McpToolset returns MCPTool objects, which are + # instances of BaseTool. The prefixing is handled by the BaseToolset, + # so we need to call get_tools_with_prefix to get the prefixed tools. + prefixed_tools = await toolset.get_tools_with_prefix() + + # Assert that the tools are prefixed correctly + assert len(prefixed_tools) == 2 + assert prefixed_tools[0].name == "my_prefix_tool1" + assert prefixed_tools[1].name == "my_prefix_tool2" + + # Assert that the original tools are not modified + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" diff --git a/tests/unittests/tools/test_mcp_toolset.py b/tests/unittests/tools/test_mcp_toolset.py deleted file mode 100644 index 7bfd912669..0000000000 --- a/tests/unittests/tools/test_mcp_toolset.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Unit tests for McpToolset.""" - -from unittest.mock import AsyncMock -from unittest.mock import MagicMock - -from google.adk.tools.mcp_tool.mcp_toolset import McpToolset -import pytest - - -@pytest.mark.asyncio -async def test_mcp_toolset_with_prefix(): - """Test that McpToolset correctly applies the tool_name_prefix.""" - # Mock the connection parameters - mock_connection_params = MagicMock() - mock_connection_params.timeout = None - - # Mock the MCPSessionManager and its create_session method - mock_session_manager = MagicMock() - mock_session = MagicMock() - - # Mock the list_tools response from the MCP server - mock_tool1 = MagicMock() - mock_tool1.name = "tool1" - mock_tool1.description = "tool 1 desc" - mock_tool2 = MagicMock() - mock_tool2.name = "tool2" - mock_tool2.description = "tool 2 desc" - list_tools_result = MagicMock() - list_tools_result.tools = [mock_tool1, mock_tool2] - mock_session.list_tools = AsyncMock(return_value=list_tools_result) - mock_session_manager.create_session = AsyncMock(return_value=mock_session) - - # Create an instance of McpToolset with a prefix - toolset = McpToolset( - connection_params=mock_connection_params, - tool_name_prefix="my_prefix", - ) - - # Replace the internal session manager with our mock - toolset._mcp_session_manager = mock_session_manager - - # Get the tools from the toolset - tools = await toolset.get_tools() - - # The get_tools method in McpToolset returns MCPTool objects, which are - # instances of BaseTool. The prefixing is handled by the BaseToolset, - # so we need to call get_tools_with_prefix to get the prefixed tools. 
def _map_finish_reason(
    finish_reason: Any,
) -> types.FinishReason | None:
  """Translates a LiteLLM finish_reason into a google-genai FinishReason.

  Falsy input yields None, a value that already is a `types.FinishReason` is
  passed through untouched, and anything else is stringified, lower-cased and
  looked up in `_FINISH_REASON_MAPPING`, defaulting to
  `types.FinishReason.OTHER` for unknown reasons.
  """
  if not finish_reason:
    return None
  return (
      finish_reason
      if isinstance(finish_reason, types.FinishReason)
      else _FINISH_REASON_MAPPING.get(
          str(finish_reason).lower(), types.FinishReason.OTHER
      )
  )
@@ -1840,6 +1852,9 @@ async def generate_content_async( else None, ) ) + aggregated_llm_response_with_tool_call.finish_reason = ( + _map_finish_reason(finish_reason) + ) text = "" reasoning_parts = [] function_calls.clear() @@ -1854,6 +1869,9 @@ async def generate_content_async( if reasoning_parts else None, ) + aggregated_llm_response.finish_reason = _map_finish_reason( + finish_reason + ) text = "" reasoning_parts = [] diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index c687ceb0cb..f6428087b0 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -2880,6 +2880,7 @@ async def test_generate_content_async_stream( "test_arg": "test_value" } assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" + assert responses[3].finish_reason == types.FinishReason.STOP assert responses[3].model_version == "test_model" mock_completion.assert_called_once() @@ -2900,6 +2901,55 @@ async def test_generate_content_async_stream( ) +@pytest.mark.asyncio +async def test_generate_content_async_stream_sets_finish_reason( + mock_completion, lite_llm_instance +): + mock_completion.return_value = iter([ + ModelResponse( + model="test_model", + choices=[ + StreamingChoices( + finish_reason=None, + delta=Delta(role="assistant", content="Hello "), + ) + ], + ), + ModelResponse( + model="test_model", + choices=[ + StreamingChoices( + finish_reason=None, + delta=Delta(role="assistant", content="world"), + ) + ], + ), + ModelResponse( + model="test_model", + choices=[StreamingChoices(finish_reason="stop", delta=Delta())], + ), + ]) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + ) + + responses = [ + response + async for response in lite_llm_instance.generate_content_async( + llm_request, stream=True + ) + ] + + assert responses[-1].partial is False + assert responses[-1].finish_reason == 
types.FinishReason.STOP + assert responses[-1].content.parts[0].text == "Hello world" + + @pytest.mark.asyncio async def test_generate_content_async_stream_with_usage_metadata( mock_completion, lite_llm_instance @@ -2944,6 +2994,7 @@ async def test_generate_content_async_stream_with_usage_metadata( "test_arg": "test_value" } assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" + assert responses[3].finish_reason == types.FinishReason.STOP assert responses[3].usage_metadata.prompt_token_count == 10 assert responses[3].usage_metadata.candidates_token_count == 5 From 712b5a393d44e7b5ce35fc459da98361bae4bb16 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 14 Jan 2026 16:50:45 -0800 Subject: [PATCH 07/11] fix: Only filter out audio content when sending history audio is transcribed thus no need to be sent, but other blob(e.g. image) should still be sent. Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 856422986 --- .../adk/models/gemini_llm_connection.py | 16 +- src/google/adk/utils/content_utils.py | 38 ++++ .../models/test_gemini_llm_connection.py | 174 ++++++++++++++++++ 3 files changed, 224 insertions(+), 4 deletions(-) create mode 100644 src/google/adk/utils/content_utils.py diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 327157e2a6..158a5cabc1 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -20,6 +20,7 @@ from google.genai import types +from ..utils.content_utils import filter_audio_parts from ..utils.context_utils import Aclosing from ..utils.variant_utils import GoogleLLMVariant from .base_llm_connection import BaseLlmConnection @@ -63,15 +64,22 @@ async def send_history(self, history: list[types.Content]): # TODO: Remove this filter and translate unary contents to streaming # contents properly. 
- # We ignore any audio from user during the agent transfer phase + # Filter out audio parts from history because: + # 1. audio has already been transcribed. + # 2. sending audio via connection.send or connection.send_live_content is + # not supported by LIVE API (session will be corrupted). + # This method is called when: + # 1. Agent transfer to a new agent + # 2. Establishing a new live connection with previous ADK session history + contents = [ - content + filtered for content in history - if content.parts and content.parts[0].text + if (filtered := filter_audio_parts(content)) is not None ] - logger.debug('Sending history to live connection: %s', contents) if contents: + logger.debug('Sending history to live connection: %s', contents) await self._gemini_session.send( input=types.LiveClientContent( turns=contents, diff --git a/src/google/adk/utils/content_utils.py b/src/google/adk/utils/content_utils.py new file mode 100644 index 0000000000..379c31ec96 --- /dev/null +++ b/src/google/adk/utils/content_utils.py @@ -0,0 +1,38 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _has_audio_mime_type(blob: types.Blob | types.FileData | None) -> bool:
  """Returns True when `blob` is set and its mime_type starts with 'audio/'."""
  return bool(
      blob and blob.mime_type and blob.mime_type.startswith('audio/')
  )


def is_audio_part(part: types.Part) -> bool:
  """Tells whether a part carries audio, inline or by file reference.

  Args:
    part: The content part to inspect. Only its `inline_data` and `file_data`
      attributes are read.

  Returns:
    True if either attribute is set with a mime type beginning with 'audio/';
    False otherwise. The result is always a real bool, never the raw operand
    of the underlying `and`/`or` chain (e.g. None or an empty string).
  """
  return _has_audio_mime_type(part.inline_data) or _has_audio_mime_type(
      part.file_data
  )


def filter_audio_parts(content: types.Content) -> types.Content | None:
  """Builds a copy of `content` with every audio part removed.

  Args:
    content: The content whose parts should be filtered.

  Returns:
    A new `types.Content` with the same role and only the non-audio parts, or
    None when `content` has no parts or all of its parts are audio.
  """
  if not content.parts:
    return None
  filtered_parts = [part for part in content.parts if not is_audio_part(part)]
  if not filtered_parts:
    return None
  return types.Content(role=content.role, parts=filtered_parts)
assert responses[2].output_transcription.finished is True assert responses[2].partial is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'audio_part', + [ + types.Part( + inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm') + ), + types.Part( + file_data=types.FileData( + file_uri='artifact://app/user/session/_adk_live/audio.pcm#1', + mime_type='audio/pcm', + ) + ), + ], +) +async def test_send_history_filters_audio(mock_gemini_session, audio_part): + """Test that audio parts (inline or file_data) are filtered out.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[audio_part], + ), + types.Content( + role='model', parts=[types.Part.from_text(text='I heard you')] + ), + ] + + await connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + sent_contents = call_args['input'].turns + # Only the model response should be sent (user audio filtered out) + assert len(sent_contents) == 1 + assert sent_contents[0].role == 'model' + assert sent_contents[0].parts == [types.Part.from_text(text='I heard you')] + + +@pytest.mark.asyncio +async def test_send_history_keeps_image_data(mock_gemini_session): + """Test that image data is NOT filtered out.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + image_blob = types.Blob(data=b'\x89PNG\r\n', mime_type='image/png') + history = [ + types.Content( + role='user', + parts=[types.Part(inline_data=image_blob)], + ), + types.Content( + role='model', parts=[types.Part.from_text(text='Nice image!')] + ), + ] + + await connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + sent_contents = call_args['input'].turns + # Both contents should be sent (image is not filtered) + assert 
len(sent_contents) == 2 + assert sent_contents[0].parts[0].inline_data == image_blob + + +@pytest.mark.asyncio +async def test_send_history_mixed_content_filters_only_audio( + mock_gemini_session, +): + """Test that mixed content keeps non-audio parts.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[ + types.Part( + inline_data=types.Blob( + data=b'\x00\xFF', mime_type='audio/wav' + ) + ), + types.Part.from_text(text='transcribed text'), + ], + ), + ] + + await connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + sent_contents = call_args['input'].turns + # Content should be sent but only with the text part + assert len(sent_contents) == 1 + assert len(sent_contents[0].parts) == 1 + assert sent_contents[0].parts[0].text == 'transcribed text' + + +@pytest.mark.asyncio +async def test_send_history_all_audio_content_not_sent(mock_gemini_session): + """Test that content with only audio parts is completely removed.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[ + types.Part( + inline_data=types.Blob( + data=b'\x00\xFF', mime_type='audio/pcm' + ) + ), + types.Part( + file_data=types.FileData( + file_uri='artifact://audio.pcm#1', + mime_type='audio/wav', + ) + ), + ], + ), + ] + + await connection.send_history(history) + + # No content should be sent since all parts are audio + mock_gemini_session.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_send_history_empty_history_not_sent(mock_gemini_session): + """Test that empty history does not call send.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + + await connection.send_history([]) + + mock_gemini_session.send.assert_not_called() + + 
+@pytest.mark.asyncio +@pytest.mark.parametrize( + 'audio_mime_type', + ['audio/pcm', 'audio/wav', 'audio/mp3', 'audio/ogg'], +) +async def test_send_history_filters_various_audio_mime_types( + mock_gemini_session, + audio_mime_type, +): + """Test that various audio mime types are all filtered.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[ + types.Part( + inline_data=types.Blob(data=b'', mime_type=audio_mime_type) + ) + ], + ), + ] + + await connection.send_history(history) + + # No content should be sent since the only part is audio + mock_gemini_session.send.assert_not_called() From 1133ce219c5a7a9a85222b03e348ba6b13830c8f Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 14 Jan 2026 17:06:38 -0800 Subject: [PATCH 08/11] feat: convert A2UI messages between A2A DataPart metadata and ADK events 1. Convert A2A responses containing a DataPart to ADK events. By default, this is done by serializing the DataPart to JSON and embedding it within the inline_data field of a GenAI Part, wrapped with custom tags ( and ). 2. Convert ADK events back to A2A requests. 
Specifically, messages stored in inline_data with the text/plain mime type and content wrapped within the custom tags ( and ) are deserialized from JSON back into an A2A DataPart PiperOrigin-RevId: 856426615 --- .../adk/a2a/converters/part_converter.py | 28 ++- .../a2a/converters/test_part_converter.py | 199 +++++++++++++----- 2 files changed, 174 insertions(+), 53 deletions(-) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index dfe6f4a0a2..21428b6381 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -40,6 +40,9 @@ A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response' A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = 'code_execution_result' A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = 'executable_code' +A2A_DATA_PART_TEXT_MIME_TYPE = 'text/plain' +A2A_DATA_PART_START_TAG = b'' +A2A_DATA_PART_END_TAG = b'' A2APartToGenAIPartConverter = Callable[ @@ -130,7 +133,16 @@ def convert_a2a_part_to_genai_part( part.data, by_alias=True ) ) - return genai_types.Part(text=json.dumps(part.data)) + return genai_types.Part( + inline_data=genai_types.Blob( + data=A2A_DATA_PART_START_TAG + + part.model_dump_json(by_alias=True, exclude_none=True).encode( + 'utf-8' + ) + + A2A_DATA_PART_END_TAG, + mime_type=A2A_DATA_PART_TEXT_MIME_TYPE, + ) + ) logger.warning( 'Cannot convert unsupported part type: %s for A2A part: %s', @@ -163,6 +175,20 @@ def convert_genai_part_to_a2a_part( ) if part.inline_data: + if ( + part.inline_data.mime_type == A2A_DATA_PART_TEXT_MIME_TYPE + and part.inline_data.data is not None + and part.inline_data.data.startswith(A2A_DATA_PART_START_TAG) + and part.inline_data.data.endswith(A2A_DATA_PART_END_TAG) + ): + return a2a_types.Part( + root=a2a_types.DataPart.model_validate_json( + part.inline_data.data[ + len(A2A_DATA_PART_START_TAG) : -len(A2A_DATA_PART_END_TAG) + ] + ) + ) + # The default case for inline_data is 
to convert it to FileWithBytes. a2a_part = a2a_types.FilePart( file=a2a_types.FileWithBytes( bytes=base64.b64encode(part.inline_data.data).decode('utf-8'), diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 541ab7709d..00c9ddc5e0 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -17,11 +17,14 @@ from unittest.mock import patch from a2a import types as a2a_types +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_END_TAG from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_START_TAG +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_TEXT_MIME_TYPE from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part from google.adk.a2a.converters.utils import _get_adk_metadata_key @@ -154,12 +157,43 @@ def test_convert_data_part_function_response(self): "data": [1, 2, 3], } - def test_convert_data_part_without_special_metadata(self): - """Test conversion of A2A DataPart without special metadata to text.""" + @pytest.mark.parametrize( + "test_name, data, metadata", + [ + ( + "without_special_metadata", + {"key": "value", "number": 123}, + {"other": "metadata"}, + ), + ( + "no_metadata", + {"key": "value", "array": [1, 2, 3]}, + None, + ), + ( + "complex_data", + { + "nested": { + "array": [1, 2, 
{"inner": "value"}], + "boolean": True, + "null_value": None, + }, + "unicode": "Hello 世界 🌍", + }, + None, + ), + ( + "empty_metadata", + {"key": "value"}, + {}, + ), + ], + ) + def test_convert_data_part_to_inline_data(self, test_name, data, metadata): + """Test conversion of A2A DataPart to GenAI inline_data Part.""" # Arrange - data = {"key": "value", "number": 123} a2a_part = a2a_types.Part( - root=a2a_types.DataPart(data=data, metadata={"other": "metadata"}) + root=a2a_types.DataPart(data=data, metadata=metadata) ) # Act @@ -168,21 +202,17 @@ def test_convert_data_part_without_special_metadata(self): # Assert assert result is not None assert isinstance(result, genai_types.Part) - assert result.text == json.dumps(data) - - def test_convert_data_part_no_metadata(self): - """Test conversion of A2A DataPart with no metadata to text.""" - # Arrange - data = {"key": "value", "array": [1, 2, 3]} - a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data)) - - # Act - result = convert_a2a_part_to_genai_part(a2a_part) - - # Assert - assert result is not None - assert isinstance(result, genai_types.Part) - assert result.text == json.dumps(data) + assert result.inline_data is not None + assert result.inline_data.mime_type == A2A_DATA_PART_TEXT_MIME_TYPE + assert result.inline_data.data.startswith(A2A_DATA_PART_START_TAG) + assert result.inline_data.data.endswith(A2A_DATA_PART_END_TAG) + converted_data_part = a2a_types.DataPart.model_validate_json( + result.inline_data.data[ + len(A2A_DATA_PART_START_TAG) : -len(A2A_DATA_PART_END_TAG) + ] + ) + assert converted_data_part.data == data + assert converted_data_part.metadata == metadata def test_convert_unsupported_file_type(self): """Test handling of unsupported file types.""" @@ -325,6 +355,32 @@ def test_convert_inline_data_part_with_video_metadata(self): assert result.root.metadata is not None assert _get_adk_metadata_key("video_metadata") in result.root.metadata + def test_convert_inline_data_part_to_data_part(self): 
+ """Test conversion of GenAI inline_data Part to A2A DataPart.""" + # Arrange + data = {"key": "value"} + metadata = {"meta": "data"} + a2a_part_to_convert = a2a_types.DataPart(data=data, metadata=metadata) + json_data = a2a_part_to_convert.model_dump_json( + by_alias=True, exclude_none=True + ).encode("utf-8") + genai_part = genai_types.Part( + inline_data=genai_types.Blob( + data=A2A_DATA_PART_START_TAG + json_data + A2A_DATA_PART_END_TAG, + mime_type=A2A_DATA_PART_TEXT_MIME_TYPE, + ) + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + assert result.root.data == data + assert result.root.metadata == metadata + def test_convert_function_call_part(self): """Test conversion of GenAI function_call Part to A2A Part.""" # Arrange @@ -596,6 +652,47 @@ def test_executable_code_round_trip(self): ) assert result_genai_part.executable_code.code == executable_code.code + def test_data_part_round_trip(self): + """Test round-trip conversion for data parts.""" + # Arrange + data = {"key": "value"} + metadata = {"meta": "data"} + a2a_part = a2a_types.Part( + root=a2a_types.DataPart(data=data, metadata=metadata) + ) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.DataPart) + assert result_a2a_part.root.data == data + assert result_a2a_part.root.metadata == metadata + + def test_data_part_with_mime_type_metadata_round_trip(self): + """Test round-trip conversion for data parts with 'mime_type' in metadata.""" + # Arrange + data = {"content": "some data"} + metadata = {"meta": "data", "mime_type": "application/json"} + a2a_part = a2a_types.Part( + root=a2a_types.DataPart(data=data, 
metadata=metadata) + ) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.DataPart) + assert result_a2a_part.root.data == data + # The 'mime_type' key in the metadata should be preserved as is + assert result_a2a_part.root.metadata == metadata + class TestEdgeCases: """Test cases for edge cases and error conditions.""" @@ -612,6 +709,37 @@ def test_empty_text_part(self): assert result is not None assert result.text == "" + def test_genai_inline_data_with_mimetype_to_a2a(self): + """Test conversion of GenAI inline_data with 'mimeType' in DataPart metadata to A2A. + + This tests if 'mimeType' in metadata of a DataPart wrapped in inline_data + is correctly handled, ensuring the key casing is preserved. + """ + # Arrange + data = {"key": "value"} + metadata = {"adk_type": "some_type", "mimeType": "image/png"} + a2a_part_inner = a2a_types.DataPart(data=data, metadata=metadata) + json_data = a2a_part_inner.model_dump_json( + by_alias=True, exclude_none=True + ).encode("utf-8") + genai_part = genai_types.Part( + inline_data=genai_types.Blob( + data=A2A_DATA_PART_START_TAG + json_data + A2A_DATA_PART_END_TAG, + mime_type=A2A_DATA_PART_TEXT_MIME_TYPE, + ) + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + assert result.root.data == data + # The key casing should be preserved from the JSON + assert result.root.metadata == metadata + def test_none_input_a2a_to_genai(self): """Test handling of None input for A2A to GenAI conversion.""" # This test depends on how the function handles None input @@ -626,39 +754,6 @@ def test_none_input_genai_to_a2a(self): with pytest.raises(AttributeError): 
convert_genai_part_to_a2a_part(None) - def test_data_part_with_complex_data(self): - """Test conversion of DataPart with complex nested data.""" - # Arrange - complex_data = { - "nested": { - "array": [1, 2, {"inner": "value"}], - "boolean": True, - "null_value": None, - }, - "unicode": "Hello 世界 🌍", - } - a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=complex_data)) - - # Act - result = convert_a2a_part_to_genai_part(a2a_part) - - # Assert - assert result is not None - assert result.text == json.dumps(complex_data) - - def test_data_part_with_empty_metadata(self): - """Test conversion of DataPart with empty metadata dict.""" - # Arrange - data = {"key": "value"} - a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data, metadata={})) - - # Act - result = convert_a2a_part_to_genai_part(a2a_part) - - # Assert - assert result is not None - assert result.text == json.dumps(data) - class TestNewConstants: """Test cases for new constants and functionality.""" From cce430da799766686e65f6cae02ba64e916d5c8a Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 14 Jan 2026 18:09:33 -0800 Subject: [PATCH 09/11] feat: start and close ClientSession in a single task in McpSessionManager Merge https://github.com/google/adk-python/pull/4025 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: - #3950 - #3731 - #3708 **2. Or, if no issue exists, describe the change:** **Problem:** - `ClientSession` of https://github.com/modelcontextprotocol/python-sdk uses AnyIO for async task management. - AnyIO TaskGroup requires its start and close must happen in a same task. - Since `McpSessionManager` does not create task per client, the client might be closed by different task, cause the error: `Attempted to exit cancel scope in a different task than it was entered in`. 
**Solution:** I suggest two changes: Handling the `ClientSession` in a single task - To start and close `ClientSession` from the same task, we need to wrap the whole lifecycle of `ClientSession` in a single task. - `SessionContext` wraps the initialization and disposal of `ClientSession` in a single task, ensuring that the `ClientSession` is handled only in a dedicated task. Add timeout for `ClientSession` - Since we now use one task per `ClientSession`, the task should never be leaked. - But `McpSessionManager` does not pass a timeout directly to `ClientSession` when the type is not STDIO. - There is only a timeout for the `httpx` client when the MCP type is SSE or StreamableHTTP. - But that timeout applies only to the `httpx` client, so if there is an issue in the MCP client itself (e.g. https://github.com/modelcontextprotocol/python-sdk/issues/262), a tool call waits for the result **FOREVER**! - To overcome this issue, I propagated the `sse_read_timeout` to `ClientSession`. - `timeout` is too short to serve as the tool call timeout, since its default value is only 5s. - `sse_read_timeout` was originally intended as the read timeout for SSE (default value of 5m or 300s), but most server-side SSE implementations (e.g. FastAPI, etc.) send a ping periodically (about every 15s, I assume), so under normal circumstances this timeout is quite useless. - If the server does not send pings, the timeout is effectively the tool call timeout. Therefore, it would be appropriate to use `sse_read_timeout` as the tool call timeout. - Most tool calls should finish within 5 minutes, and the SSE timeout is adjustable if not. - If this change is not acceptable, we could add a dedicated parameter for the tool call timeout (e.g. `tool_call_timeout`). ### Testing Plan - Although this does not change the interface itself, it changes the session management logic, so some existing tests are no longer valid. - I made changes to those tests, especially those that validate session state (e.g. checking whether `initialize()` was called).
- Since now session is encapsulated with `SessionContext`, we cannot validate the initialized state of the session in `TestMcpSessionManager`, should validate it at `TestSessionContext`. - Added a simple test for reproducing the issue(`test_create_and_close_session_in_different_tasks`). - Also made a test for the new component: `SessionContext`. **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. ```plaintext =================================================================================== 3689 passed, 1 skipped, 2205 warnings in 63.39s (0:01:03) =================================================================================== ``` **Manual End-to-End (E2E) Tests:** _Please provide instructions on how to manually test your changes, including any necessary setup or configuration. Please provide logs or screenshots to help reviewers better understand the fix._ ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [ ] ~~Any dependent changes have been merged and published in downstream modules.~~ `no deps has been changed` ### Additional context This PR is related to https://github.com/modelcontextprotocol/python-sdk/pull/1817 since it also fixes endless tool call awaiting. 
Co-authored-by: Kathy Wu COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4025 from challenger71498:feat/task-based-mcp-session-manager f7f7cd0c9c96840361c30499d08c33a189f57d86 PiperOrigin-RevId: 856438147 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 36 +- .../adk/tools/mcp_tool/session_context.py | 194 ++++++ .../mcp_tool/test_mcp_session_manager.py | 114 +++- .../tools/mcp_tool/test_session_context.py | 550 ++++++++++++++++++ 4 files changed, 848 insertions(+), 46 deletions(-) create mode 100644 src/google/adk/tools/mcp_tool/session_context.py create mode 100644 tests/unittests/tools/mcp_tool/test_session_context.py diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 89f0145727..ebd91dc354 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -41,6 +41,8 @@ from pydantic import BaseModel from pydantic import ConfigDict +from .session_context import SessionContext + logger = logging.getLogger('google_adk.' 
+ __name__) @@ -385,29 +387,27 @@ async def create_session( if hasattr(self._connection_params, 'timeout') else None ) + sse_read_timeout_in_seconds = ( + self._connection_params.sse_read_timeout + if hasattr(self._connection_params, 'sse_read_timeout') + else None + ) try: client = self._create_client(merged_headers) - - transports = await asyncio.wait_for( - exit_stack.enter_async_context(client), + is_stdio = isinstance(self._connection_params, StdioConnectionParams) + + session = await asyncio.wait_for( + exit_stack.enter_async_context( + SessionContext( + client=client, + timeout=timeout_in_seconds, + sse_read_timeout=sse_read_timeout_in_seconds, + is_stdio=is_stdio, + ) + ), timeout=timeout_in_seconds, ) - # The streamable http client returns a GetSessionCallback in addition to the - # read/write MemoryObjectStreams needed to build the ClientSession, we limit - # then to the two first values to be compatible with all clients. - if isinstance(self._connection_params, StdioConnectionParams): - session = await exit_stack.enter_async_context( - ClientSession( - *transports[:2], - read_timeout_seconds=timedelta(seconds=timeout_in_seconds), - ) - ) - else: - session = await exit_stack.enter_async_context( - ClientSession(*transports[:2]) - ) - await asyncio.wait_for(session.initialize(), timeout=timeout_in_seconds) # Store session and exit stack in the pool self._sessions[session_key] = (session, exit_stack) diff --git a/src/google/adk/tools/mcp_tool/session_context.py b/src/google/adk/tools/mcp_tool/session_context.py new file mode 100644 index 0000000000..ca637d0489 --- /dev/null +++ b/src/google/adk/tools/mcp_tool/session_context.py @@ -0,0 +1,194 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import AsyncExitStack +from datetime import timedelta +import logging +from typing import AsyncContextManager +from typing import Optional + +from mcp import ClientSession + +logger = logging.getLogger('google_adk.' + __name__) + + +class SessionContext: + """Represents the context of a single MCP session within a dedicated task. + + AnyIO's TaskGroup/CancelScope requires that the start and end of a scope + occur within the same task. Since MCP clients use AnyIO internally, we need + to ensure that the client's entire lifecycle (creation, usage, and cleanup) + happens within a single dedicated task. + + This class spawns a background task that: + 1. Enters the MCP client's async context and initializes the session + 2. Signals readiness via an asyncio.Event + 3. Waits for a close signal + 4. Cleans up the client within the same task + + This ensures CancelScope constraints are satisfied regardless of which + task calls start() or close(). + + Can be used in two ways: + 1. Direct method calls: start() and close() + 2. As an async context manager: async with lifecycle as session: ... + """ + + def __init__( + self, + client: AsyncContextManager, + timeout: Optional[float], + sse_read_timeout: Optional[float], + is_stdio: bool = False, + ): + """ + Args: + client: An MCP client context manager (e.g., from streamablehttp_client, + sse_client, or stdio_client). + timeout: Timeout in seconds for connection and initialization. 
+ sse_read_timeout: Timeout in seconds for reading data from the MCP SSE + server. + is_stdio: Whether this is a stdio connection (affects read timeout). + """ + self._client = client + self._timeout = timeout + self._sse_read_timeout = sse_read_timeout + self._is_stdio = is_stdio + self._session: Optional[ClientSession] = None + self._ready_event = asyncio.Event() + self._close_event = asyncio.Event() + self._task: Optional[asyncio.Task] = None + self._task_lock = asyncio.Lock() + + @property + def session(self) -> Optional[ClientSession]: + """Get the managed ClientSession, if available.""" + return self._session + + async def start(self) -> ClientSession: + """Start the runner and wait for the session to be ready. + + Returns: + The initialized ClientSession. + + Raises: + ConnectionError: If session creation fails. + """ + async with self._task_lock: + if self._session: + logger.debug( + 'Session has already been created, returning existing session' + ) + return self._session + + if self._close_event.is_set(): + raise ConnectionError( + 'Failed to create MCP session: session already closed' + ) + + if not self._task: + self._task = asyncio.create_task(self._run()) + + await self._ready_event.wait() + + if self._task.cancelled(): + raise ConnectionError('Failed to create MCP session: task cancelled') + + if self._task.done() and self._task.exception(): + raise ConnectionError( + f'Failed to create MCP session: {self._task.exception()}' + ) from self._task.exception() + + return self._session + + async def close(self): + """Signal the context task to close and wait for cleanup.""" + # Set the close event to signal the task to close. + # Even if start has not been called, we need to set the close event + # to signal the task to close right away. 
+ async with self._task_lock: + self._close_event.set() + + # If start has not been called, only set the close event and return + if not self._task: + return + + if not self._ready_event.is_set(): + self._task.cancel() + + try: + await asyncio.wait_for(self._task, timeout=self._timeout) + except asyncio.TimeoutError: + logger.warning('Failed to close MCP session: task timed out') + self._task.cancel() + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning(f'Failed to close MCP session: {e}') + + async def __aenter__(self) -> ClientSession: + return await self.start() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def _run(self): + """Run the complete session context within a single task.""" + try: + async with AsyncExitStack() as exit_stack: + transports = await asyncio.wait_for( + exit_stack.enter_async_context(self._client), + timeout=self._timeout, + ) + # The streamable http client returns a GetSessionCallback in addition + # to the read/write MemoryObjectStreams needed to build the + # ClientSession. We limit to the first two values to be compatible + # with all clients. + if self._is_stdio: + session = await exit_stack.enter_async_context( + ClientSession( + *transports[:2], + read_timeout_seconds=timedelta(seconds=self._timeout) + if self._timeout is not None + else None, + ) + ) + else: + # For SSE and Streamable HTTP clients, use the sse_read_timeout + # instead of the connection timeout as the read_timeout for the session. 
+ session = await exit_stack.enter_async_context( + ClientSession( + *transports[:2], + read_timeout_seconds=timedelta(seconds=self._sse_read_timeout) + if self._sse_read_timeout is not None + else None, + ) + ) + await asyncio.wait_for(session.initialize(), timeout=self._timeout) + logger.debug('Session has been successfully initialized') + + self._session = session + self._ready_event.set() + + # Wait for close signal - the session remains valid while we wait + await self._close_event.wait() + except BaseException as e: + logger.warning(f'Error on session runner task: {e}') + raise + finally: + self._ready_event.set() + self._close_event.set() diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 7e18d9d457..ae91bed13d 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -56,6 +56,33 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): pass +class MockSessionContext: + """Mock SessionContext for testing.""" + + def __init__(self, session=None): + """Initialize MockSessionContext. + + Args: + session: The mock session to return from __aenter__ and session property. 
+ """ + self._session = session + self._aenter_mock = AsyncMock(return_value=session) + self._aexit_mock = AsyncMock(return_value=False) + + @property + def session(self): + """Get the mock session.""" + return self._session + + async def __aenter__(self): + """Enter the async context manager.""" + return await self._aenter_mock() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the async context manager.""" + return await self._aexit_mock(exc_type, exc_val, exc_tb) + + class TestMCPSessionManager: """Test suite for MCPSessionManager class.""" @@ -241,7 +268,6 @@ async def test_create_session_stdio_new(self): """Test creating a new stdio session.""" manager = MCPSessionManager(self.mock_stdio_connection_params) - mock_session = MockClientSession() mock_exit_stack = MockAsyncExitStack() with patch( @@ -251,17 +277,19 @@ async def test_create_session_stdio_new(self): "google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack" ) as mock_exit_stack_class: with patch( - "google.adk.tools.mcp_tool.mcp_session_manager.ClientSession" - ) as mock_session_class: + "google.adk.tools.mcp_tool.mcp_session_manager.SessionContext" + ) as mock_session_context_class: # Setup mocks mock_exit_stack_class.return_value = mock_exit_stack mock_stdio.return_value = AsyncMock() - mock_exit_stack.enter_async_context.side_effect = [ - ("read", "write"), # First call returns transports - mock_session, # Second call returns session - ] - mock_session_class.return_value = mock_session + + # Mock SessionContext using MockSessionContext + # Create a mock session that will be returned by SessionContext + mock_session = AsyncMock() + mock_session_context = MockSessionContext(session=mock_session) + mock_session_context_class.return_value = mock_session_context + mock_exit_stack.enter_async_context.return_value = mock_session # Create session session = await manager.create_session() @@ -271,8 +299,10 @@ async def test_create_session_stdio_new(self): assert 
len(manager._sessions) == 1 assert "stdio_session" in manager._sessions - # Verify session was initialized - mock_session.initialize.assert_called_once() + # Verify SessionContext was created + mock_session_context_class.assert_called_once() + # Verify enter_async_context was called (which internally calls __aenter__) + mock_exit_stack.enter_async_context.assert_called_once() @pytest.mark.asyncio async def test_create_session_reuse_existing(self): @@ -300,39 +330,37 @@ async def test_create_session_reuse_existing(self): @pytest.mark.asyncio @patch("google.adk.tools.mcp_tool.mcp_session_manager.stdio_client") @patch("google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack") - @patch("google.adk.tools.mcp_tool.mcp_session_manager.ClientSession") + @patch("google.adk.tools.mcp_tool.mcp_session_manager.SessionContext") async def test_create_session_timeout( - self, mock_session_class, mock_exit_stack_class, mock_stdio + self, mock_session_context_class, mock_exit_stack_class, mock_stdio ): """Test session creation timeout.""" manager = MCPSessionManager(self.mock_stdio_connection_params) - mock_session = MockClientSession() mock_exit_stack = MockAsyncExitStack() mock_exit_stack_class.return_value = mock_exit_stack mock_stdio.return_value = AsyncMock() - mock_exit_stack.enter_async_context.side_effect = [ - ("read", "write"), # First call returns transports - mock_session, # Second call returns session - ] - mock_session_class.return_value = mock_session - # Simulate timeout during session initialization - mock_session.initialize.side_effect = asyncio.TimeoutError("Test timeout") + # Mock SessionContext + mock_session_context = AsyncMock() + mock_session_context.__aenter__ = AsyncMock( + return_value=MockClientSession() + ) + mock_session_context.__aexit__ = AsyncMock(return_value=False) + mock_session_context_class.return_value = mock_session_context + + # Mock enter_async_context to raise TimeoutError (simulating asyncio.wait_for timeout) + 
mock_exit_stack.enter_async_context = AsyncMock( + side_effect=asyncio.TimeoutError("Test timeout") + ) # Expect ConnectionError due to timeout with pytest.raises(ConnectionError, match="Failed to create MCP session"): await manager.create_session() - # Verify ClientSession called with timeout - mock_session_class.assert_called_with( - "read", - "write", - read_timeout_seconds=timedelta( - seconds=manager._connection_params.timeout - ), - ) + # Verify SessionContext was created + mock_session_context_class.assert_called_once() # Verify session was not added to pool assert not manager._sessions # Verify cleanup was called @@ -390,6 +418,36 @@ async def test_close_with_errors(self): assert "Warning: Error during MCP session cleanup" in error_output assert "Close error 1" in error_output + @pytest.mark.asyncio + @patch("google.adk.tools.mcp_tool.mcp_session_manager.stdio_client") + @patch("google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack") + @patch("google.adk.tools.mcp_tool.mcp_session_manager.SessionContext") + async def test_create_and_close_session_in_different_tasks( + self, mock_session_context_class, mock_exit_stack_class, mock_stdio + ): + """Test creating and closing a session in different tasks.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + mock_exit_stack_class.return_value = MockAsyncExitStack() + mock_stdio.return_value = AsyncMock() + + # Mock SessionContext + mock_session_context = AsyncMock() + mock_session_context.__aenter__ = AsyncMock( + return_value=MockClientSession() + ) + mock_session_context.__aexit__ = AsyncMock(return_value=False) + mock_session_context_class.return_value = mock_session_context + + # Create session in a new task + await asyncio.create_task(manager.create_session()) + + # Close session in another task + await asyncio.create_task(manager.close()) + + # Verify session was closed + assert not manager._sessions + @pytest.mark.asyncio async def test_retry_on_errors_decorator(): diff --git 
a/tests/unittests/tools/mcp_tool/test_session_context.py b/tests/unittests/tools/mcp_tool/test_session_context.py new file mode 100644 index 0000000000..161cd1aba3 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_session_context.py @@ -0,0 +1,550 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import AsyncExitStack +from datetime import timedelta +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.tools.mcp_tool.session_context import SessionContext +from mcp import ClientSession +import pytest + + +class MockClientSession: + """Mock ClientSession for testing.""" + + def __init__(self, *args, **kwargs): + self._initialized = False + self._args = args + self._kwargs = kwargs + + async def initialize(self): + """Mock initialize method.""" + self._initialized = True + return self + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + +class MockClient: + """Mock MCP client.""" + + def __init__( + self, + transports=None, + raise_on_enter=None, + delay_on_enter=0, + ): + self._transports = transports or ('read_stream', 'write_stream') + self._raise_on_enter = raise_on_enter + self._delay_on_enter = delay_on_enter + self._entered = False + self._exited = False + + async def __aenter__(self): + if self._delay_on_enter > 0: + await 
asyncio.sleep(self._delay_on_enter) + if self._raise_on_enter: + raise self._raise_on_enter + self._entered = True + return self._transports + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._exited = True + return False + + +class TestSessionContext: + """Test suite for SessionContext class.""" + + @pytest.mark.asyncio + async def test_start_success_ready_event_set_and_session_returned(self): + """Test that start() sets _ready_event and returns session.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Mock ClientSession + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + session = await session_context.start() + + # Verify ready_event was set + assert session_context._ready_event.is_set() + + # Verify session was returned + assert session == mock_session + assert session_context.session == mock_session + + # Verify initialize was called + assert mock_session._initialized + + # Verify task was created and is still running (waiting for close) + assert session_context._task is not None + assert not session_context._task.done() + + # Clean up + await session_context.close() + + @pytest.mark.asyncio + async def test_start_raises_connection_error_on_exception(self): + """Test that start() raises ConnectionError when exception occurs.""" + test_exception = ValueError('Connection failed') + mock_client = MockClient(raise_on_enter=test_exception) + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + # Verify ConnectionError message contains original exception + assert 'Failed to create MCP session' in str(exc_info.value) + assert 'Connection failed' in str(exc_info.value) + + # Verify ready_event was set (in finally 
block) + assert session_context._ready_event.is_set() + + @pytest.mark.asyncio + async def test_start_raises_connection_error_on_cancelled_error(self): + """Test that start() raises ConnectionError when CancelledError occurs.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Mock session that will cause cancellation + mock_session = MockClientSession() + + # Make initialize raise CancelledError + async def cancelled_initialize(): + raise asyncio.CancelledError('Task cancelled') + + mock_session.initialize = cancelled_initialize + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + # Should raise ConnectionError (not CancelledError directly) + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + # Verify it's a ConnectionError about cancellation + assert 'Failed to create MCP session' in str(exc_info.value) + assert 'task cancelled' in str(exc_info.value) + + # Verify ready_event was set + assert session_context._ready_event.is_set() + + @pytest.mark.asyncio + async def test_close_cleans_up_task(self): + """Test that close() properly cleans up the task.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Mock ClientSession + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + # Start the session context + await session_context.start() + + # Verify task is running + assert session_context._task is not None + assert not session_context._task.done() + + # Close the session context + await session_context.close() + + # Wait a bit for cleanup + await asyncio.sleep(0.1) + + # Verify close_event was set + assert session_context._close_event.is_set() + + 
# Verify task completed (may take a moment) + # The task should finish after close_event is set + assert session_context._task.done() + + @pytest.mark.asyncio + async def test_session_exception_does_not_break_event_loop(self): + """Test that session exceptions don't break the event loop.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Mock ClientSession that raises exception during use + mock_session = MockClientSession() + + async def failing_operation(): + raise RuntimeError('Session operation failed') + + mock_session.failing_operation = failing_operation + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + # Start the session context + session = await session_context.start() + + # Use session and trigger exception + with pytest.raises(RuntimeError, match='Session operation failed'): + await session.failing_operation() + + # Close the session context - should not break event loop + await session_context.close() + + # Verify event loop is still healthy by running another task + result = await asyncio.sleep(0.01) + assert result is None + + @pytest.mark.asyncio + async def test_async_context_manager(self): + """Test using SessionContext as async context manager.""" + mock_client = MockClient() + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + async with SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) as session: + assert session == mock_session + # Verify initialize was called by checking _initialized flag + assert session._initialized + + @pytest.mark.asyncio + async def test_timeout_during_connection(self): + """Test timeout during client connection.""" + # Client that takes longer than timeout + mock_client = 
MockClient(delay_on_enter=10.0) + session_context = SessionContext( + mock_client, timeout=0.1, sse_read_timeout=None + ) + + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + assert 'Failed to create MCP session' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_timeout_during_initialization(self): + """Test timeout during session initialization.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=0.1, sse_read_timeout=None + ) + + # Mock ClientSession with slow initialize + mock_session = MockClientSession() + + async def slow_initialize(): + await asyncio.sleep(1.0) + return mock_session + + mock_session.initialize = slow_initialize + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + assert 'Failed to create MCP session' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_stdio_client_with_read_timeout(self): + """Test stdio client includes read_timeout_seconds parameter.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None, is_stdio=True + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Verify ClientSession was called with read_timeout_seconds for stdio + call_args = mock_session_class.call_args + assert 'read_timeout_seconds' in call_args.kwargs + assert call_args.kwargs['read_timeout_seconds'] == timedelta(seconds=5.0) + + await session_context.close() + + @pytest.mark.asyncio + async def test_non_stdio_client_without_read_timeout(self): + """Test non-stdio client does not include read_timeout_seconds.""" + mock_client = 
MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None, is_stdio=False + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Verify ClientSession was called with read_timeout_seconds=None for non-stdio + # when sse_read_timeout is None + call_args = mock_session_class.call_args + assert 'read_timeout_seconds' in call_args.kwargs + assert call_args.kwargs['read_timeout_seconds'] is None + + await session_context.close() + + @pytest.mark.asyncio + async def test_sse_read_timeout_passed_to_client_session(self): + """Test that sse_read_timeout is passed to ClientSession for non-stdio.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=300.0, is_stdio=False + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Verify ClientSession was called with sse_read_timeout + call_args = mock_session_class.call_args + assert 'read_timeout_seconds' in call_args.kwargs + assert call_args.kwargs['read_timeout_seconds'] == timedelta( + seconds=300.0 + ) + + await session_context.close() + + @pytest.mark.asyncio + async def test_close_multiple_times(self): + """Test that close() can be called multiple times safely.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Close multiple times + await session_context.close() + await 
session_context.close() + await session_context.close() + + # Should not raise exception + assert session_context._close_event.is_set() + + @pytest.mark.asyncio + async def test_close_before_start(self): + """Test that close() works even if start() was never called.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Close before starting should not raise + await session_context.close() + + assert session_context._close_event.is_set() + + @pytest.mark.asyncio + async def test_close_before_start_ends(self): + """Test that close() before start() ends the task.""" + # Client has enough time to delay the start task + mock_client = MockClient(delay_on_enter=10.0) + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + start_task = asyncio.create_task(session_context.start()) + await asyncio.sleep(0.1) + assert not start_task.done() + + # Call close before start() ends the task + await session_context.close() + await asyncio.sleep(0.1) + + assert start_task.done() + assert isinstance( + start_task.exception(), ConnectionError + ) and 'task cancelled' in str(start_task.exception()) + + @pytest.mark.asyncio + async def test_close_before_start_called(self): + """Test that close() before start() called sets the close event.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Call close() before start() called + await session_context.close() + await asyncio.sleep(0.1) + + assert session_context._task is None + assert session_context._close_event.is_set() + + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + assert 'session already closed' in str(exc_info.value) + assert session_context._task is None + + @pytest.mark.asyncio + async def test_session_property(self): + """Test that session property returns the managed session.""" + mock_client = MockClient() 
+ session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Initially None + assert session_context.session is None + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Should return the session + assert session_context.session == mock_session + + await session_context.close() + + @pytest.mark.asyncio + async def test_client_cleanup_on_exception(self): + """Test that client is properly cleaned up even when exception occurs.""" + test_exception = RuntimeError('Test error') + mock_client = MockClient(raise_on_enter=test_exception) + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + with pytest.raises(ConnectionError): + await session_context.start() + + # Wait a bit for cleanup + await asyncio.sleep(0.1) + + # Verify task completed + assert session_context._task.done() + + @pytest.mark.asyncio + async def test_close_handles_cancelled_error(self): + """Test that close() handles CancelledError gracefully.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Cancel the task + if session_context._task: + session_context._task.cancel() + + # Close should handle CancelledError gracefully + await session_context.close() + + # Should not raise exception + assert session_context._close_event.is_set() + + @pytest.mark.asyncio + async def test_close_handles_exception_during_cleanup(self): + """Test that close() handles exceptions during cleanup gracefully.""" + mock_client = MockClient() + session_context = SessionContext( + 
mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Create a mock session that raises during exit + class FailingMockSession(MockClientSession): + + async def __aexit__(self, exc_type, exc_val, exc_tb): + raise RuntimeError('Cleanup failed') + + failing_session = FailingMockSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = failing_session + + await session_context.start() + + # Close should handle the exception gracefully + await session_context.close() + + # Should not raise exception + assert session_context._close_event.is_set() From d4da1bb7330cdb87c1dcbe0b9023148357a6bd07 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 14 Jan 2026 19:18:44 -0800 Subject: [PATCH 10/11] fix: Initialize self._auth_config inside BaseAuthenticatedTool So that we can access self._auth_config in McpTool for getting auth headers Co-authored-by: Kathy Wu PiperOrigin-RevId: 856451693 --- src/google/adk/tools/base_authenticated_tool.py | 1 + tests/unittests/tools/test_base_authenticated_tool.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/google/adk/tools/base_authenticated_tool.py b/src/google/adk/tools/base_authenticated_tool.py index 862d1cef5a..92e395d4ac 100644 --- a/src/google/adk/tools/base_authenticated_tool.py +++ b/src/google/adk/tools/base_authenticated_tool.py @@ -66,6 +66,7 @@ def __init__( name=name, description=description, ) + self._auth_config = auth_config if auth_config and auth_config.auth_scheme: self._credentials_manager = CredentialManager(auth_config=auth_config) diff --git a/tests/unittests/tools/test_base_authenticated_tool.py b/tests/unittests/tools/test_base_authenticated_tool.py index 55454224d8..5f7bf53f7d 100644 --- a/tests/unittests/tools/test_base_authenticated_tool.py +++ b/tests/unittests/tools/test_base_authenticated_tool.py @@ -90,6 +90,7 @@ def test_init_with_auth_config(self): assert tool.description == "Test description" assert 
tool._credentials_manager is not None assert tool._response_for_auth_required == unauthenticated_response + assert tool._auth_config == auth_config def test_init_with_no_auth_config(self): """Test initialization without auth_config.""" @@ -99,6 +100,7 @@ def test_init_with_no_auth_config(self): assert tool.description == "Test authenticated tool" assert tool._credentials_manager is None assert tool._response_for_auth_required is None + assert tool._auth_config is None def test_init_with_empty_auth_scheme(self): """Test initialization with auth_config but no auth_scheme.""" From 7c282973ea193841fee79f90b8a91c5e02627ccc Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 14 Jan 2026 19:56:13 -0800 Subject: [PATCH 11/11] fix: Support Generator and AsyncGenerator tool declaration use yield type as return type Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 856459995 --- .../tools/_automatic_function_calling_util.py | 17 ++++ .../tools/test_from_function_with_options.py | 77 +++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py index 92df88718a..2b00c79917 100644 --- a/src/google/adk/tools/_automatic_function_calling_util.py +++ b/src/google/adk/tools/_automatic_function_calling_util.py @@ -14,12 +14,15 @@ from __future__ import annotations +import collections.abc import inspect from types import FunctionType import typing from typing import Any from typing import Callable from typing import Dict +from typing import get_args +from typing import get_origin from typing import Optional from typing import Union @@ -391,6 +394,20 @@ def from_function_with_options( return_annotation = inspect.signature(func).return_annotation + # Handle AsyncGenerator and Generator return types (streaming tools) + # AsyncGenerator[YieldType, SendType] -> use YieldType as response schema + # Generator[YieldType, SendType, ReturnType] -> use YieldType 
as response schema + origin = get_origin(return_annotation) + if origin is not None and ( + origin is collections.abc.AsyncGenerator + or origin is collections.abc.Generator + ): + type_args = get_args(return_annotation) + if type_args: + # First type argument is the yield type + yield_type = type_args[0] + return_annotation = yield_type + # Handle functions with no return annotation if return_annotation is inspect._empty: # Functions with no return annotation can return any type diff --git a/tests/unittests/tools/test_from_function_with_options.py b/tests/unittests/tools/test_from_function_with_options.py index 61670a2678..eae164538f 100644 --- a/tests/unittests/tools/test_from_function_with_options.py +++ b/tests/unittests/tools/test_from_function_with_options.py @@ -14,7 +14,9 @@ from collections.abc import Sequence from typing import Any +from typing import AsyncGenerator from typing import Dict +from typing import Generator from google.adk.tools import _automatic_function_calling_util from google.adk.utils.variant_utils import GoogleLLMVariant @@ -242,3 +244,78 @@ def test_function( assert declaration.name == 'test_function' assert declaration.response.type == types.Type.ARRAY assert declaration.response.items.type == types.Type.STRING + + +def test_from_function_with_async_generator_return_vertex(): + """Test from_function_with_options with AsyncGenerator return for VERTEX_AI.""" + + async def test_function(param: str) -> AsyncGenerator[str, None]: + """A streaming function that yields strings.""" + yield param + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # VERTEX_AI should extract yield type (str) from AsyncGenerator[str, None] + assert declaration.response is not None + assert declaration.response.type == 
types.Type.STRING + + +def test_from_function_with_async_generator_return_gemini(): + """Test from_function_with_options with AsyncGenerator return for GEMINI_API.""" + + async def test_function(param: str) -> AsyncGenerator[str, None]: + """A streaming function that yields strings.""" + yield param + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.GEMINI_API + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # GEMINI_API should not have response schema + assert declaration.response is None + + +def test_from_function_with_generator_return_vertex(): + """Test from_function_with_options with Generator return for VERTEX_AI.""" + + def test_function(param: str) -> Generator[int, None, None]: + """A streaming function that yields integers.""" + yield 42 + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # VERTEX_AI should extract yield type (int) from Generator[int, None, None] + assert declaration.response is not None + assert declaration.response.type == types.Type.INTEGER + + +def test_from_function_with_async_generator_complex_yield_type_vertex(): + """Test from_function_with_options with AsyncGenerator yielding dict.""" + + async def test_function(param: str) -> AsyncGenerator[Dict[str, str], None]: + """A streaming function that yields dicts.""" + yield {'result': param} + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type 
== 'STRING' + # VERTEX_AI should extract yield type (Dict[str, str]) from AsyncGenerator + assert declaration.response is not None + assert declaration.response.type == types.Type.OBJECT