From 50c4b8d33a32a4e07e2f1acc010dd13f8b93b317 Mon Sep 17 00:00:00 2001 From: nikkie Date: Thu, 15 Jan 2026 09:24:48 -0800 Subject: [PATCH 1/7] chore: Disable scheduled GitHub Actions workflows in forks Merge https://github.com/google/adk-python/pull/4059 ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: #3961 **Problem:** Excessive notifications from periodic workflow runs (every 6 hours for triage, daily for stale-bot and docs upload) in forks where they are not needed. **Solution:** Add repository checks to prevent scheduled workflows from running in forked repositories. The workflows will now only run in the main google/adk-python repository. ### Testing Plan I think that unit tests and E2E are not needed because this change is for GitHub Actions (not ADK source code). _Please describe the tests that you ran to verify your changes. This is required for all PRs that are not small documentation or typo fixes._ **Unit Tests:** - [ ] I have added or updated unit tests for my change. - [ ] All unit tests pass locally. _Please include a summary of passed `pytest` results._ **Manual End-to-End (E2E) Tests:** _Please provide instructions on how to manually test your changes, including any necessary setup or configuration. Please provide logs or screenshots to help reviewers better understand the fix._ ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [ ] I have added tests that prove my fix is effective or that my feature works. - [ ] New and existing unit tests pass locally with my changes. - [ ] I have manually tested my changes end-to-end. - [ ] Any dependent changes have been merged and published in downstream modules. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4059 from ftnext:disable-fork-actions 991cc94674cd5cef2e3c2aacf243ba7dae7b88ad PiperOrigin-RevId: 856691019 --- .github/workflows/stale-bot.yml | 1 + .github/workflows/triage.yml | 8 +++++--- .github/workflows/upload-adk-docs-to-vertex-ai-search.yml | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/stale-bot.yml b/.github/workflows/stale-bot.yml index b2a9cb82ad..1e539c7ff6 100644 --- a/.github/workflows/stale-bot.yml +++ b/.github/workflows/stale-bot.yml @@ -23,6 +23,7 @@ on: jobs: audit-stale-issues: + if: github.repository == 'google/adk-python' runs-on: ubuntu-latest timeout-minutes: 60 diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml index d19a0e9197..ff1afaac98 100644 --- a/.github/workflows/triage.yml +++ b/.github/workflows/triage.yml @@ -15,9 +15,11 @@ jobs: # - New issues (need component labeling) # - Issues labeled with "planned" (need owner assignment) if: >- - github.event_name == 'schedule' || - github.event.action == 'opened' || - github.event.label.name == 'planned' + github.repository == 'google/adk-python' && ( + github.event_name == 'schedule' || + github.event.action == 'opened' || + github.event.label.name == 'planned' + ) permissions: issues: write contents: read diff --git a/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml b/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml index e8d94eb9dc..f29adbe9e7 100644 --- a/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml +++ b/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml @@ -9,6 +9,7 @@ on: jobs: upload-adk-docs-to-vertex-ai-search: + if: github.repository == 'google/adk-python' runs-on: ubuntu-latest steps: From ed2c3ebde9cafbb5e2bf375f44db1e77cee9fb24 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 15 Jan 2026 10:03:05 -0800 Subject: [PATCH 2/7] fix: Prevent stopping event processing on events with None content PiperOrigin-RevId: 856706510 --- src/google/adk/agents/remote_a2a_agent.py | 11 ++++++++++- .../unittests/agents/test_remote_a2a_agent.py | 18 +++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 167328847c..23a9b47554 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -348,6 +348,13 @@ def _create_a2a_request_for_user_function_response( return a2a_message + def _is_remote_response(self, event: Event) -> bool: + return ( + event.author == self.name + and event.custom_metadata + and event.custom_metadata.get(A2A_METADATA_PREFIX + "response", False) + ) + def _construct_message_parts_from_session( self, ctx: InvocationContext ) -> tuple[list[A2APart], Optional[str]]: @@ -365,7 +372,7 @@ def _construct_message_parts_from_session( events_to_process = [] for event in reversed(ctx.session.events): - if event.author == self.name: + if self._is_remote_response(event): # stop on content generated by current a2a agent given it should already # be in remote session if event.custom_metadata: @@ -496,6 +503,8 @@ async def _handle_a2a_response( invocation_id=ctx.invocation_id, branch=ctx.branch, ) + event.custom_metadata = event.custom_metadata or {} + event.custom_metadata[A2A_METADATA_PREFIX + "response"] = True return event except A2AClientError as e: logger.error("Failed to handle A2A response: %s", e) diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index d395a5516f..6a098dff9a 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -683,7 +683,16 @@ def test_construct_message_parts_from_session_stops_on_agent_reply(self): agent1 = Mock() agent1.content = content2 agent1.author = self.agent.name - agent1.custom_metadata = None + agent1.custom_metadata = { + A2A_METADATA_PREFIX + "response": True, + } + + agent2 = Mock() + agent2.content = None + agent2.author = self.agent.name + # Just actions, no content. Not marked as a response. + agent2.actions = Mock() + agent2.custom_metadata = None part3 = Mock() part3.text = "User 2" @@ -694,7 +703,7 @@ def test_construct_message_parts_from_session_stops_on_agent_reply(self): user2.author = "user" user2.custom_metadata = None - self.mock_session.events = [user1, agent1, user2] + self.mock_session.events = [user1, agent1, user2, agent2] def mock_converter(part): mock_a2a_part = Mock() @@ -785,7 +794,10 @@ def test_construct_message_parts_from_session_stateful_partial_history(self): agent1 = Mock() agent1.content = content2 agent1.author = self.agent.name - agent1.custom_metadata = {A2A_METADATA_PREFIX + "context_id": "ctx-1"} + agent1.custom_metadata = { + A2A_METADATA_PREFIX + "response": True, + A2A_METADATA_PREFIX + "context_id": "ctx-1", + } part3 = Mock() part3.text = "User 2" From 19555e7dce6d60c3b960ca0bc2f928c138ac3cc0 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 15 Jan 2026 10:19:32 -0800 Subject: [PATCH 3/7] fix: Support Generator and Async Generator tool declaration in JSON schema Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 856713741 --- .../adk/tools/_function_tool_declarations.py | 16 ++++ .../tools/test_function_tool_declarations.py | 80 +++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/src/google/adk/tools/_function_tool_declarations.py b/src/google/adk/tools/_function_tool_declarations.py index 7b37390856..5dfc192770 100644 --- a/src/google/adk/tools/_function_tool_declarations.py +++ b/src/google/adk/tools/_function_tool_declarations.py @@ -24,10 +24,13 @@ from __future__ import annotations +import collections.abc import inspect import logging from typing import Any from typing import Callable +from typing import get_args +from typing import get_origin from typing import get_type_hints from typing import Optional from typing import Type @@ -145,6 +148,19 @@ def _build_response_json_schema( except TypeError: pass + # Handle AsyncGenerator and Generator return types (streaming tools) + # AsyncGenerator[YieldType, SendType] -> use YieldType as response schema + # Generator[YieldType, SendType, ReturnType] -> use YieldType as response schema + origin = get_origin(return_annotation) + if origin is not None and ( + origin is collections.abc.AsyncGenerator + or origin is collections.abc.Generator + ): + type_args = get_args(return_annotation) + if type_args: + # First type argument is the yield type + return_annotation = type_args[0] + try: adapter = pydantic.TypeAdapter( return_annotation, diff --git a/tests/unittests/tools/test_function_tool_declarations.py b/tests/unittests/tools/test_function_tool_declarations.py index 252bc9868c..a5443d8e91 100644 --- a/tests/unittests/tools/test_function_tool_declarations.py +++ b/tests/unittests/tools/test_function_tool_declarations.py @@ -23,6 +23,8 @@ from collections.abc import Sequence from enum import Enum from typing import Any +from typing import AsyncGenerator +from typing import Generator from typing import Literal from typing import Optional @@ -840,3 +842,81 @@ class CreateUserRequest(BaseModel): # When passing a BaseModel, there is no function return, so response schema # is None self.assertIsNone(decl.response_json_schema) + + +class TestStreamingReturnTypes(parameterized.TestCase): + """Tests for AsyncGenerator and Generator return types (streaming tools).""" + + def test_async_generator_string_yield(self): + """Test AsyncGenerator[str, None] return type extracts str as response.""" + + async def streaming_tool(param: str) -> AsyncGenerator[str, None]: + """A streaming tool that yields strings.""" + yield param + + decl = build_function_declaration_with_json_schema(streaming_tool) + + self.assertEqual(decl.name, "streaming_tool") + self.assertIsNotNone(decl.parameters_json_schema) + self.assertEqual( + decl.parameters_json_schema["properties"]["param"]["type"], "string" + ) + # Should extract str from AsyncGenerator[str, None] + self.assertEqual(decl.response_json_schema, {"type": "string"}) + + def test_async_generator_int_yield(self): + """Test AsyncGenerator[int, None] return type extracts int as response.""" + + async def counter(start: int) -> AsyncGenerator[int, None]: + """A streaming counter.""" + yield start + + decl = build_function_declaration_with_json_schema(counter) + + self.assertEqual(decl.name, "counter") + # Should extract int from AsyncGenerator[int, None] + self.assertEqual(decl.response_json_schema, {"type": "integer"}) + + def test_async_generator_dict_yield(self): + """Test AsyncGenerator[dict[str, str], None] return type.""" + + async def streaming_dict( + param: str, + ) -> AsyncGenerator[dict[str, str], None]: + """A streaming tool that yields dicts.""" + yield {"result": param} + + decl = build_function_declaration_with_json_schema(streaming_dict) + + self.assertEqual(decl.name, "streaming_dict") + # Should extract dict[str, str] from AsyncGenerator + self.assertEqual( + decl.response_json_schema, + {"additionalProperties": {"type": "string"}, "type": "object"}, + ) + + def test_generator_string_yield(self): + """Test Generator[str, None, None] return type extracts str as response.""" + + def sync_streaming_tool(param: str) -> Generator[str, None, None]: + """A sync streaming tool that yields strings.""" + yield param + + decl = build_function_declaration_with_json_schema(sync_streaming_tool) + + self.assertEqual(decl.name, "sync_streaming_tool") + # Should extract str from Generator[str, None, None] + self.assertEqual(decl.response_json_schema, {"type": "string"}) + + def test_generator_int_yield(self): + """Test Generator[int, None, None] return type extracts int as response.""" + + def sync_counter(start: int) -> Generator[int, None, None]: + """A sync streaming counter.""" + yield start + + decl = build_function_declaration_with_json_schema(sync_counter) + + self.assertEqual(decl.name, "sync_counter") + # Should extract int from Generator[int, None, None] + self.assertEqual(decl.response_json_schema, {"type": "integer"}) From 83d7bb6ef0d952ad04c5d9a61aaf202672c7e17d Mon Sep 17 00:00:00 2001 From: Yeesian Ng Date: Thu, 15 Jan 2026 10:45:40 -0800 Subject: [PATCH 4/7] fix: Use the correct path for config-based agents when deploying to AgentEngine Co-authored-by: Yeesian Ng PiperOrigin-RevId: 856724694 --- src/google/adk/cli/cli_deploy.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index d6ef019adc..86aaefb8a4 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -74,12 +74,7 @@ if {is_config_agent}: from google.adk.agents import config_agent_utils - try: - # This path is for local loading. - root_agent = config_agent_utils.from_config("{agent_folder}/root_agent.yaml") - except FileNotFoundError: - # This path is used to support the file structure in Agent Engine. - root_agent = config_agent_utils.from_config("./{temp_folder}/{app_name}/root_agent.yaml") + root_agent = config_agent_utils.from_config("{agent_folder}/root_agent.yaml") else: from .agent import {adk_app_object} @@ -912,8 +907,7 @@ def to_agent_engine( app_name=app_name, trace_to_cloud_option=trace_to_cloud, is_config_agent=is_config_agent, - temp_folder=temp_folder, - agent_folder=agent_folder, + agent_folder=f'./{temp_folder}', adk_app_object=adk_app_object, adk_app_type=adk_app_type, express_mode=api_key is not None, From 6ad18cc2fc3a3315a0fc240cb51b3283b53116b4 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 15 Jan 2026 11:13:37 -0800 Subject: [PATCH 5/7] fix: Use json.dumps for error messages in SSE events Co-authored-by: George Weale PiperOrigin-RevId: 856737497 --- src/google/adk/cli/adk_web_server.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 6a39ab1a84..0f0657ee0c 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -1558,8 +1558,17 @@ async def event_generator(): yield f"data: {sse_event}\n\n" except Exception as e: logger.exception("Error in event_generator: %s", e) - # You might want to yield an error event here - yield f'data: {{"error": "{str(e)}"}}\n\n' + # Yield a proper Event object for the error + error_event = Event( + author="system", + content=types.Content( + role="model", parts=[types.Part(text=f"Error: {e}")] + ), + ) + yield ( + "data:" + f" {error_event.model_dump_json(by_alias=True, exclude_none=True)}\n\n" + ) # Returns a streaming response with the proper media type for SSE return StreamingResponse( From 6dbe851fca34659dcbbc5048193a9fe46d86d124 Mon Sep 17 00:00:00 2001 From: Yeesian Ng Date: Thu, 15 Jan 2026 11:42:33 -0800 Subject: [PATCH 6/7] chore: Add back unit tests for CLI utility to deploy to AgentEngine Co-authored-by: Yeesian Ng PiperOrigin-RevId: 856749290 --- src/google/adk/cli/cli_deploy.py | 81 +++++++++++++++---- src/google/adk/cli/cli_tools_click.py | 23 ++++-- tests/unittests/cli/utils/test_cli_deploy.py | 67 ++++++++++++++- .../cli/utils/test_cli_tools_click.py | 3 - 4 files changed, 149 insertions(+), 25 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 86aaefb8a4..781274fbfd 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -20,6 +20,7 @@ import subprocess from typing import Final from typing import Optional +import warnings import click from packaging.version import parse @@ -27,6 +28,36 @@ _IS_WINDOWS = os.name == 'nt' _GCLOUD_CMD = 'gcloud.cmd' if _IS_WINDOWS else 'gcloud' _LOCAL_STORAGE_FLAG_MIN_VERSION: Final[str] = '1.21.0' +_AGENT_ENGINE_REQUIREMENT: Final[str] = ( + 'google-cloud-aiplatform[adk,agent_engines]' +) + + +def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: + """Ensures staged requirements include Agent Engine dependencies.""" + if not os.path.exists(requirements_txt_path): + raise FileNotFoundError( + f'requirements.txt not found at: {requirements_txt_path}' + ) + + requirements = '' + with open(requirements_txt_path, 'r', encoding='utf-8') as f: + requirements = f.read() + + for line in requirements.splitlines(): + stripped = line.strip() + if ( + stripped + and not stripped.startswith('#') + and stripped.startswith('google-cloud-aiplatform') + ): + return + + with open(requirements_txt_path, 'a', encoding='utf-8') as f: + if requirements and not requirements.endswith('\n'): + f.write('\n') + f.write(_AGENT_ENGINE_REQUIREMENT + '\n') + _DOCKERFILE_TEMPLATE: Final[str] = """ FROM python:3.11-slim @@ -656,7 +687,7 @@ def to_agent_engine( agent_folder: str, temp_folder: Optional[str] = None, adk_app: str, - staging_bucket: str, + staging_bucket: Optional[str] = None, trace_to_cloud: Optional[bool] = None, api_key: Optional[str] = None, adk_app_object: Optional[str] = None, @@ -699,7 +730,8 @@ def to_agent_engine( files. It will be replaced with the generated files if it already exists. adk_app (str): The name of the file (without .py) containing the AdkApp instance. - staging_bucket (str): The GCS bucket for staging the deployment artifacts. + staging_bucket (str): Deprecated. This argument is no longer required or + used. trace_to_cloud (bool): Whether to enable Cloud Trace. api_key (str): Optional. The API key to use for Express Mode. If not provided, the API key from the GOOGLE_API_KEY environment variable @@ -729,13 +761,6 @@ def to_agent_engine( app_name = os.path.basename(agent_folder) display_name = display_name or app_name parent_folder = os.path.dirname(agent_folder) - if parent_folder != os.getcwd(): - click.echo(f'Please deploy from the project dir: {parent_folder}') - return - tmp_app_name = app_name + '_tmp' + datetime.now().strftime('%Y%m%d_%H%M%S') - temp_folder = temp_folder or tmp_app_name - agent_src_path = os.path.join(parent_folder, temp_folder) - click.echo(f'Staging all files in: {agent_src_path}') adk_app_object = adk_app_object or 'root_agent' if adk_app_object not in ['root_agent', 'app']: click.echo( @@ -743,12 +768,34 @@ def to_agent_engine( ' or "app".' ) return + if staging_bucket: + warnings.warn( + 'WARNING: `staging_bucket` is deprecated and will be removed in a' + ' future release. Please drop it from the list of arguments.', + DeprecationWarning, + stacklevel=2, + ) + + original_cwd = os.getcwd() + did_change_cwd = False + if parent_folder != original_cwd: + click.echo( + 'Agent Engine deployment uses relative paths; temporarily switching ' + f'working directory to: {parent_folder}' + ) + os.chdir(parent_folder) + did_change_cwd = True + tmp_app_name = app_name + '_tmp' + datetime.now().strftime('%Y%m%d_%H%M%S') + temp_folder = temp_folder or tmp_app_name + agent_src_path = os.path.join(parent_folder, temp_folder) + click.echo(f'Staging all files in: {agent_src_path}') # remove agent_src_path if it exists if os.path.exists(agent_src_path): click.echo('Removing existing files') shutil.rmtree(agent_src_path) try: + click.echo(f'Staging all files in: {agent_src_path}') ignore_patterns = None ae_ignore_path = os.path.join(agent_folder, '.ae_ignore') if os.path.exists(ae_ignore_path): @@ -757,15 +804,18 @@ def to_agent_engine( patterns = [pattern.strip() for pattern in f.readlines()] ignore_patterns = shutil.ignore_patterns(*patterns) click.echo('Copying agent source code...') - shutil.copytree(agent_folder, agent_src_path, ignore=ignore_patterns) + shutil.copytree( + agent_folder, + agent_src_path, + ignore=ignore_patterns, + dirs_exist_ok=True, + ) click.echo('Copying agent source code complete.') project = _resolve_project(project) click.echo('Resolving files and dependencies...') agent_config = {} - if staging_bucket: - agent_config['staging_bucket'] = staging_bucket if not agent_engine_config_file: # Attempt to read the agent engine config from .agent_engine_config.json in the dir (if any). agent_engine_config_file = os.path.join( @@ -808,8 +858,9 @@ def to_agent_engine( if not os.path.exists(requirements_txt_path): click.echo(f'Creating {requirements_txt_path}...') with open(requirements_txt_path, 'w', encoding='utf-8') as f: - f.write('google-cloud-aiplatform[adk,agent_engines]') + f.write(_AGENT_ENGINE_REQUIREMENT + '\n') click.echo(f'Created {requirements_txt_path}') + _ensure_agent_engine_dependency(requirements_txt_path) agent_config['requirements_file'] = f'{temp_folder}/requirements.txt' env_vars = {} @@ -940,7 +991,9 @@ def to_agent_engine( click.secho(f'✅ Updated agent engine: {agent_engine_id}', fg='green') finally: click.echo(f'Cleaning up the temp folder: {temp_folder}') - shutil.rmtree(temp_folder) + shutil.rmtree(agent_src_path) + if did_change_cwd: + os.chdir(original_cwd) def to_gke( diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 5d7611f217..241c696351 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -1031,6 +1031,19 @@ def wrapper(*args, **kwargs): return decorator +def _deprecate_staging_bucket(ctx, param, value): + if value: + click.echo( + click.style( + f"WARNING: --{param} is deprecated and will be removed. Please" + " leave it unspecified.", + fg="yellow", + ), + err=True, + ) + return value + + def deprecated_adk_services_options(): """Deprecated ADK services options.""" @@ -1689,10 +1702,8 @@ def cli_migrate_session( "--staging_bucket", type=str, default=None, - help=( - "Optional. GCS bucket for staging the deployment artifacts. It will be" - " ignored if api_key is set." - ), + help="Deprecated. This argument is no longer required or used.", + callback=_deprecate_staging_bucket, ) @click.option( "--agent_engine_id", @@ -1827,8 +1838,7 @@ def cli_deploy_agent_engine( # With Google Cloud Project and Region adk deploy agent_engine --project=[project] --region=[region] - --staging_bucket=[staging_bucket] --display_name=[app_name] - my_agent + --display_name=[app_name] my_agent """ logging.getLogger("vertexai_genai.agentengines").setLevel(logging.INFO) try: @@ -1836,7 +1846,6 @@ def cli_deploy_agent_engine( agent_folder=agent, project=project, region=region, - staging_bucket=staging_bucket, agent_engine_id=agent_engine_id, trace_to_cloud=trace_to_cloud, api_key=api_key, diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index 9a2ebcfac2..dad93583df 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -26,7 +26,6 @@ from typing import Any from typing import Callable from typing import Dict -from typing import Generator from typing import List from typing import Tuple from unittest import mock @@ -227,6 +226,72 @@ def test_get_service_option_by_adk_version( assert actual.rstrip() == expected.rstrip() +@pytest.mark.parametrize("include_requirements", [True, False]) +def test_to_agent_engine_happy_path( + monkeypatch: pytest.MonkeyPatch, + agent_dir: Callable[[bool, bool], Path], + include_requirements: bool, +) -> None: + """Tests the happy path for the `to_agent_engine` function.""" + rmtree_recorder = _Recorder() + monkeypatch.setattr(shutil, "rmtree", rmtree_recorder) + create_recorder = _Recorder() + + fake_vertexai = types.ModuleType("vertexai") + + class _FakeAgentEngines: + + def create(self, *, config: Dict[str, Any]) -> Any: + create_recorder(config=config) + return types.SimpleNamespace( + api_resource=types.SimpleNamespace( + name="projects/p/locations/l/reasoningEngines/e" + ) + ) + + def update(self, *, name: str, config: Dict[str, Any]) -> None: + del name + del config + + class _FakeVertexClient: + + def __init__(self, *args: Any, **kwargs: Any) -> None: + del args + del kwargs + self.agent_engines = _FakeAgentEngines() + + fake_vertexai.Client = _FakeVertexClient + monkeypatch.setitem(sys.modules, "vertexai", fake_vertexai) + src_dir = agent_dir(include_requirements, False) + tmp_dir = src_dir.parent / "tmp" + cli_deploy.to_agent_engine( + agent_folder=str(src_dir), + temp_folder="tmp", + adk_app="my_adk_app", + trace_to_cloud=True, + project="my-gcp-project", + region="us-central1", + display_name="My Test Agent", + description="A test agent.", + ) + agent_file = tmp_dir / "agent.py" + assert agent_file.is_file() + init_file = tmp_dir / "__init__.py" + assert init_file.is_file() + adk_app_file = tmp_dir / "my_adk_app.py" + assert adk_app_file.is_file() + content = adk_app_file.read_text() + assert "from .agent import root_agent" in content + assert "adk_app = AdkApp(" in content + assert "agent=root_agent" in content + assert "enable_tracing=True" in content + reqs_path = tmp_dir / "requirements.txt" + assert reqs_path.is_file() + assert "google-cloud-aiplatform[adk,agent_engines]" in reqs_path.read_text() + assert len(create_recorder.calls) == 1 + assert str(rmtree_recorder.get_last_call_args()[0]) == str(tmp_dir) + + @pytest.mark.parametrize("include_requirements", [True, False]) def test_to_gke_happy_path( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index 95b561e57b..316ffbb6af 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -400,8 +400,6 @@ def test_cli_deploy_agent_engine_success( "test-proj", "--region", "us-central1", - "--staging_bucket", - "gs://mybucket", str(agent_dir), ], ) @@ -410,7 +408,6 @@ def test_cli_deploy_agent_engine_success( called_kwargs = rec.calls[0][1] assert called_kwargs.get("project") == "test-proj" assert called_kwargs.get("region") == "us-central1" - assert called_kwargs.get("staging_bucket") == "gs://mybucket" # cli deploy gke From 19315fe557039fa8bf446525a4830b1c9f40cba9 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Thu, 15 Jan 2026 13:37:43 -0800 Subject: [PATCH 7/7] feat: Support authentication for MCP tool listing Currently only tool calling supports MCP auth. This refactors the auth logic into a auth_utils file and uses it for tool listing as well. Fixes https://github.com/google/adk-python/issues/2168. Co-authored-by: Kathy Wu PiperOrigin-RevId: 856798589 --- .../adk/tools/mcp_tool/mcp_auth_utils.py | 110 +++++++ src/google/adk/tools/mcp_tool/mcp_tool.py | 94 +----- src/google/adk/tools/mcp_tool/mcp_toolset.py | 46 ++- .../tools/mcp_tool/test_mcp_auth_utils.py | 240 ++++++++++++++ .../unittests/tools/mcp_tool/test_mcp_tool.py | 296 ------------------ .../tools/mcp_tool/test_mcp_toolset.py | 90 ++++++ 6 files changed, 490 insertions(+), 386 deletions(-) create mode 100644 src/google/adk/tools/mcp_tool/mcp_auth_utils.py create mode 100644 tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py diff --git a/src/google/adk/tools/mcp_tool/mcp_auth_utils.py b/src/google/adk/tools/mcp_tool/mcp_auth_utils.py new file mode 100644 index 0000000000..b074e67f15 --- /dev/null +++ b/src/google/adk/tools/mcp_tool/mcp_auth_utils.py @@ -0,0 +1,110 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for MCP tool authentication.""" + +from __future__ import annotations + +import base64 +import logging +from typing import Dict +from typing import Optional + +from fastapi.openapi import models as openapi_models +from fastapi.openapi.models import APIKey +from fastapi.openapi.models import HTTPBase + +from ...auth.auth_credential import AuthCredential +from ...auth.auth_schemes import AuthScheme + +logger = logging.getLogger("google_adk." + __name__) + + +def get_mcp_auth_headers( + auth_scheme: Optional[AuthScheme], credential: Optional[AuthCredential] +) -> Optional[Dict[str, str]]: + """Generates HTTP authentication headers for MCP calls. + + Args: + auth_scheme: The authentication scheme. + credential: The resolved authentication credential. + + Returns: + A dictionary of headers, or None if no auth is applicable. + + Raises: + ValueError: If the auth scheme is unsupported or misconfigured. + """ + if not credential: + return None + + headers: Optional[Dict[str, str]] = None + + if credential.oauth2: + headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} + elif credential.http: + if not auth_scheme or not isinstance(auth_scheme, HTTPBase): + logger.warning( + "HTTP credential provided, but auth_scheme is missing or not" + " HTTPBase." + ) + return None + + scheme = auth_scheme.scheme.lower() + if scheme == "bearer" and credential.http.credentials.token: + headers = {"Authorization": f"Bearer {credential.http.credentials.token}"} + elif scheme == "basic": + if ( + credential.http.credentials.username + and credential.http.credentials.password + ): + creds = f"{credential.http.credentials.username}:{credential.http.credentials.password}" + encoded_creds = base64.b64encode(creds.encode()).decode() + headers = {"Authorization": f"Basic {encoded_creds}"} + else: + logger.warning("Basic auth scheme missing username or password.") + elif credential.http.credentials.token: + # Handle other HTTP schemes like Digest, etc. if token is present + headers = { + "Authorization": ( + f"{auth_scheme.scheme} {credential.http.credentials.token}" + ) + } + else: + logger.warning(f"Unsupported or incomplete HTTP auth scheme '{scheme}'.") + elif credential.api_key: + if not auth_scheme or not isinstance(auth_scheme, APIKey): + logger.warning( + "API key credential provided, but auth_scheme is missing or not" + " APIKey." + ) + return None + + if auth_scheme.in_ != openapi_models.APIKeyIn.header: + error_msg = ( + "MCP tools only support header-based API key authentication. " + f"Configured location: {auth_scheme.in_}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + headers = {auth_scheme.name: credential.api_key} + elif credential.service_account: + logger.warning( + "Service account credentials should be exchanged for an access token " + "before calling get_mcp_auth_headers." + ) + else: + logger.warning(f"Unsupported credential type: {type(credential)}") + + return headers diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index b15f2c73fe..3a5a257366 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -14,7 +14,6 @@ from __future__ import annotations -import base64 import inspect import logging from typing import Any @@ -24,7 +23,6 @@ from typing import Union import warnings -from fastapi.openapi.models import APIKeyIn from google.genai.types import FunctionDeclaration from mcp.types import Tool as McpBaseTool from typing_extensions import override @@ -39,6 +37,7 @@ from ..base_authenticated_tool import BaseAuthenticatedTool # import from ..tool_context import ToolContext +from .mcp_auth_utils import get_mcp_auth_headers from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_errors @@ -195,7 +194,12 @@ async def _run_async_impl( Any: The response from the tool. """ # Extract headers from credential for session pooling - auth_headers = await self._get_headers(tool_context, credential) + auth_scheme = ( + self._auth_config.auth_scheme + if hasattr(self, "_auth_config") and self._auth_config + else None + ) + auth_headers = get_mcp_auth_headers(auth_scheme, credential) dynamic_headers = None if self._header_provider: dynamic_headers = self._header_provider( @@ -217,90 +221,6 @@ async def _run_async_impl( response = await session.call_tool(self._mcp_tool.name, arguments=args) return response.model_dump(exclude_none=True, mode="json") - async def _get_headers( - self, tool_context: ToolContext, credential: AuthCredential - ) -> Optional[dict[str, str]]: - """Extracts authentication headers from credentials. - - Args: - tool_context: The tool context of the current invocation. - credential: The authentication credential to process. - - Returns: - Dictionary of headers to add to the request, or None if no auth. - - Raises: - ValueError: If API key authentication is configured for non-header location. - """ - headers: Optional[dict[str, str]] = None - if credential: - if credential.oauth2: - headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} - elif credential.http: - # Handle HTTP authentication schemes - if ( - credential.http.scheme.lower() == "bearer" - and credential.http.credentials.token - ): - headers = { - "Authorization": f"Bearer {credential.http.credentials.token}" - } - elif credential.http.scheme.lower() == "basic": - # Handle basic auth - if ( - credential.http.credentials.username - and credential.http.credentials.password - ): - - credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}" - encoded_credentials = base64.b64encode( - credentials.encode() - ).decode() - headers = {"Authorization": f"Basic {encoded_credentials}"} - elif credential.http.credentials.token: - # Handle other HTTP schemes with token - headers = { - "Authorization": ( - f"{credential.http.scheme} {credential.http.credentials.token}" - ) - } - elif credential.api_key: - if ( - not self._credentials_manager - or not self._credentials_manager._auth_config - ): - error_msg = ( - "Cannot find corresponding auth scheme for API key credential" - f" {credential}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - elif ( - self._credentials_manager._auth_config.auth_scheme.in_ - != APIKeyIn.header - ): - error_msg = ( - "McpTool only supports header-based API key authentication." - " Configured location:" - f" {self._credentials_manager._auth_config.auth_scheme.in_}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - else: - headers = { - self._credentials_manager._auth_config.auth_scheme.name: ( - credential.api_key - ) - } - elif credential.service_account: - # Service accounts should be exchanged for access tokens before reaching this point - logger.warning( - "Service account credentials should be exchanged before MCP" - " session creation" - ) - - return headers - class MCPTool(McpTool): """Deprecated name, use `McpTool` instead.""" diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 035b75878b..1a2dd33ded 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -33,11 +33,14 @@ from ...agents.readonly_context import ReadonlyContext from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme +from ...auth.auth_tool import AuthConfig +from ...auth.credential_manager import CredentialManager from ..base_tool import BaseTool from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate from ..tool_configs import BaseToolConfig from ..tool_configs import ToolArgsConfig +from .mcp_auth_utils import get_mcp_auth_headers from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_errors from .mcp_session_manager import SseConnectionParams @@ -154,13 +157,50 @@ async def get_tools( Returns: List[BaseTool]: A list of tools available under the specified context. """ - headers = ( + provided_headers = ( self._header_provider(readonly_context) if self._header_provider and readonly_context - else None + else {} ) + + auth_headers = {} + if self._auth_scheme: + try: + # Instantiate CredentialsManager to resolve credentials + auth_config = AuthConfig( + auth_scheme=self._auth_scheme, + raw_auth_credential=self._auth_credential, + ) + credentials_manager = CredentialManager(auth_config) + + # Resolve the credential + resolved_credential = await credentials_manager.get_auth_credential( + readonly_context + ) + + if resolved_credential: + auth_headers = get_mcp_auth_headers( + self._auth_scheme, resolved_credential + ) + else: + logger.warning( + "Failed to resolve credential for tool listing, proceeding" + " without auth headers." + ) + except Exception as e: + logger.warning( + "Error generating auth headers for tool listing: %s, proceeding" + " without auth headers.", + e, + exc_info=True, + ) + + merged_headers = {**(provided_headers or {}), **(auth_headers or {})} + # Get session from session manager - session = await self._mcp_session_manager.create_session(headers=headers) + session = await self._mcp_session_manager.create_session( + headers=merged_headers + ) # Fetch available tools from the MCP server timeout_in_seconds = ( diff --git a/tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py b/tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py new file mode 100644 index 0000000000..9e0988467e --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py @@ -0,0 +1,240 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +from unittest.mock import patch + +from fastapi.openapi import models as openapi_models +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.tools.mcp_tool import mcp_auth_utils +import pytest + + +def test_get_mcp_auth_headers_no_credential(): + """Test header generation with no credentials.""" + auth_scheme = openapi_models.HTTPBase(scheme="bearer") + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=None + ) + assert headers is None + + +def test_get_mcp_auth_headers_no_auth_scheme(): + """Test header generation with no auth_scheme.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(access_token="test_token"), + ) + with patch.object(mcp_auth_utils, "logger") as mock_logger: + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=None, credential=credential + ) + assert headers == {"Authorization": "Bearer test_token"} + + +def test_get_mcp_auth_headers_oauth2(): + """Test header generation for OAuth2 credentials.""" + auth_scheme = openapi_models.HTTPBase(scheme="bearer") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(access_token="test_token"), + ) + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + assert headers == {"Authorization": "Bearer test_token"} + + +def test_get_mcp_auth_headers_http_bearer(): + """Test header generation for HTTP Bearer credentials.""" + auth_scheme = openapi_models.HTTPBase(scheme="bearer") + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", credentials=HttpCredentials(token="bearer_token") + ), + ) + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + assert headers == {"Authorization": "Bearer bearer_token"} + + +def test_get_mcp_auth_headers_http_basic(): + """Test header generation for HTTP Basic credentials.""" + auth_scheme = openapi_models.HTTPBase(scheme="basic") + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="basic", + credentials=HttpCredentials(username="user", password="pass"), + ), + ) + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + expected_encoded = base64.b64encode(b"user:pass").decode() + assert headers == {"Authorization": f"Basic {expected_encoded}"} + + +def test_get_mcp_auth_headers_http_basic_missing_credentials(): + """Test header generation for HTTP Basic with missing credentials.""" + auth_scheme = openapi_models.HTTPBase(scheme="basic") + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="basic", + credentials=HttpCredentials(username="user", password=None), + ), + ) + with patch.object(mcp_auth_utils, "logger") as mock_logger: + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + assert headers is None + mock_logger.warning.assert_called_once_with( + "Basic auth scheme missing username or password." + ) + + +def test_get_mcp_auth_headers_http_custom_scheme(): + """Test header generation for custom HTTP scheme.""" + auth_scheme = openapi_models.HTTPBase(scheme="custom") + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="custom", credentials=HttpCredentials(token="custom_token") + ), + ) + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + assert headers == {"Authorization": "custom custom_token"} + + +def test_get_mcp_auth_headers_http_cred_wrong_scheme(): + """Test HTTP credential with non-HTTPBase auth scheme.""" + auth_scheme = openapi_models.APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": openapi_models.APIKeyIn.header, + "name": "X-API-Key", + }) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", credentials=HttpCredentials(token="bearer_token") + ), + ) + with patch.object(mcp_auth_utils, "logger") as mock_logger: + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + assert headers is None + mock_logger.warning.assert_called_once_with( + "HTTP credential provided, but auth_scheme is missing or not HTTPBase." + ) + + +def test_get_mcp_auth_headers_api_key_header(): + """Test header generation for API Key in header.""" + auth_scheme = openapi_models.APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": openapi_models.APIKeyIn.header, + "name": "X-Custom-API-Key", + }) + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + assert headers == {"X-Custom-API-Key": "my_api_key"} + + +def test_get_mcp_auth_headers_api_key_query_raises_error(): + """Test API Key in query raises ValueError.""" + auth_scheme = openapi_models.APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": openapi_models.APIKeyIn.query, + "name": "api_key", + }) + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + with pytest.raises( + ValueError, + match="MCP tools only support header-based API key authentication.", + ): + mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + + +def test_get_mcp_auth_headers_api_key_cookie_raises_error(): + """Test API Key in cookie raises ValueError.""" + auth_scheme = openapi_models.APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": openapi_models.APIKeyIn.cookie, + "name": "session_id", + }) + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + with pytest.raises( + ValueError, + match="MCP tools only support header-based API key authentication.", + ): + mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + + +def test_get_mcp_auth_headers_api_key_cred_wrong_scheme(): + """Test API key credential with non-APIKey auth scheme.""" + auth_scheme = openapi_models.HTTPBase(scheme="bearer") + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + with patch.object(mcp_auth_utils, "logger") as mock_logger: + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + assert headers is None + mock_logger.warning.assert_called_once_with( + "API key credential provided, but auth_scheme is missing or not APIKey." + ) + + +def test_get_mcp_auth_headers_service_account(): + """Test header generation for service account credentials.""" + auth_scheme = openapi_models.HTTPBase(scheme="bearer") + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount(scopes=["test"]), + ) + with patch.object(mcp_auth_utils, "logger") as mock_logger: + headers = mcp_auth_utils.get_mcp_auth_headers( + auth_scheme=auth_scheme, credential=credential + ) + assert headers is None + mock_logger.warning.assert_called_once_with( + "Service account credentials should be exchanged for an access " + "token before calling get_mcp_auth_headers." + ) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 235830195f..72becfa0e4 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -18,10 +18,7 @@ from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.auth.auth_credential import HttpAuth -from google.adk.auth.auth_credential import HttpCredentials from google.adk.auth.auth_credential import OAuth2Auth -from google.adk.auth.auth_credential import ServiceAccount from google.adk.features import FeatureName from google.adk.features._feature_registry import temporary_feature_override from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager @@ -261,240 +258,6 @@ async def test_run_async_impl_with_oauth2(self): headers = call_args[1]["headers"] assert headers == {"Authorization": "Bearer test_access_token"} - @pytest.mark.asyncio - async def test_get_headers_oauth2(self): - """Test header generation for OAuth2 credentials.""" - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - oauth2_auth = OAuth2Auth(access_token="test_token") - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth - ) - - tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, credential) - - assert headers == {"Authorization": "Bearer test_token"} - - @pytest.mark.asyncio - async def test_get_headers_http_bearer(self): - """Test header generation for HTTP Bearer credentials.""" - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - http_auth = HttpAuth( - scheme="bearer", credentials=HttpCredentials(token="bearer_token") - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, http=http_auth - ) - - tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, credential) - - assert headers == {"Authorization": "Bearer bearer_token"} - - @pytest.mark.asyncio - async def test_get_headers_http_basic(self): - """Test header generation for HTTP Basic credentials.""" - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - http_auth = HttpAuth( - scheme="basic", - credentials=HttpCredentials(username="user", password="pass"), - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, http=http_auth - ) - - tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, credential) - - # Should create Basic auth header with base64 encoded credentials - import base64 - - expected_encoded = base64.b64encode(b"user:pass").decode() - assert headers == {"Authorization": f"Basic {expected_encoded}"} - - @pytest.mark.asyncio - async def test_get_headers_api_key_with_valid_header_scheme(self): - """Test header generation for API Key credentials with header-based auth scheme.""" - from fastapi.openapi.models import APIKey - from fastapi.openapi.models import APIKeyIn - from google.adk.auth.auth_schemes import AuthSchemeType - - # Create auth scheme for header-based API key - auth_scheme = APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": APIKeyIn.header, - "name": "X-Custom-API-Key", - }) - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - auth_scheme=auth_scheme, - auth_credential=auth_credential, - ) - - tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, auth_credential) - - assert headers == {"X-Custom-API-Key": "my_api_key"} - - @pytest.mark.asyncio - async def test_get_headers_api_key_with_query_scheme_raises_error(self): - """Test that API Key with query-based auth scheme raises ValueError.""" - from fastapi.openapi.models import APIKey - from fastapi.openapi.models import APIKeyIn - from google.adk.auth.auth_schemes import AuthSchemeType - - # Create auth scheme for query-based API key (not supported) - auth_scheme = APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": APIKeyIn.query, - "name": "api_key", - }) - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - auth_scheme=auth_scheme, - auth_credential=auth_credential, - ) - - tool_context = Mock(spec=ToolContext) - - with pytest.raises( - ValueError, - match="McpTool only supports header-based API key authentication", - ): - await tool._get_headers(tool_context, auth_credential) - - @pytest.mark.asyncio - async def test_get_headers_api_key_with_cookie_scheme_raises_error(self): - """Test that API Key with cookie-based auth scheme raises ValueError.""" - from fastapi.openapi.models import APIKey - from fastapi.openapi.models import APIKeyIn - from google.adk.auth.auth_schemes import AuthSchemeType - - # Create auth scheme for cookie-based API key (not supported) - auth_scheme = APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": APIKeyIn.cookie, - "name": "session_id", - }) - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - auth_scheme=auth_scheme, - auth_credential=auth_credential, - ) - - tool_context = Mock(spec=ToolContext) - - with pytest.raises( - ValueError, - match="McpTool only supports header-based API key authentication", - ): - await tool._get_headers(tool_context, auth_credential) - - @pytest.mark.asyncio - async def test_get_headers_api_key_without_auth_config_raises_error(self): - """Test that API Key without auth config raises ValueError.""" - # Create tool without auth scheme/config - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - tool_context = Mock(spec=ToolContext) - - with pytest.raises( - ValueError, - match="Cannot find corresponding auth scheme for API key credential", - ): - await tool._get_headers(tool_context, credential) - - @pytest.mark.asyncio - async def test_get_headers_api_key_without_credentials_manager_raises_error( - self, - ): - """Test that API Key without credentials manager raises ValueError.""" - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - # Manually set credentials manager to None to simulate error condition - tool._credentials_manager = None - - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - tool_context = Mock(spec=ToolContext) - - with pytest.raises( - ValueError, - match="Cannot find corresponding auth scheme for API key credential", - ): - await tool._get_headers(tool_context, credential) - - @pytest.mark.asyncio - async def test_get_headers_no_credential(self): - """Test header generation with no credentials.""" - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, None) - - assert headers is None - - @pytest.mark.asyncio - async def test_get_headers_service_account(self): - """Test header generation for service account credentials.""" - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - # Create service account credential - service_account = ServiceAccount(scopes=["test"]) - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=service_account, - ) - - tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, credential) - - # Should return None as service account credentials are not supported for direct header generation - assert headers is None - @pytest.mark.asyncio async def test_run_async_impl_with_api_key_header_auth(self): """Test running tool with API key header authentication end-to-end.""" @@ -551,65 +314,6 @@ async def test_run_async_impl_retry_decorator(self): # Check that the method has the retry decorator assert hasattr(tool._run_async_impl, "__wrapped__") - @pytest.mark.asyncio - async def test_get_headers_http_custom_scheme(self): - """Test header generation for custom HTTP scheme.""" - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - ) - - http_auth = HttpAuth( - scheme="custom", credentials=HttpCredentials(token="custom_token") - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, http=http_auth - ) - - tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, credential) - - assert headers == {"Authorization": "custom custom_token"} - - @pytest.mark.asyncio - async def test_get_headers_api_key_error_logging(self): - """Test that API key errors are logged correctly.""" - from fastapi.openapi.models import APIKey - from fastapi.openapi.models import APIKeyIn - from google.adk.auth.auth_schemes import AuthSchemeType - - # Create auth scheme for query-based API key (not supported) - auth_scheme = APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": APIKeyIn.query, - "name": "api_key", - }) - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - - tool = MCPTool( - mcp_tool=self.mock_mcp_tool, - mcp_session_manager=self.mock_session_manager, - auth_scheme=auth_scheme, - auth_credential=auth_credential, - ) - - tool_context = Mock(spec=ToolContext) - - # Test with logging - with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger: - with pytest.raises(ValueError): - await tool._get_headers(tool_context, auth_credential) - - # Verify error was logged - mock_logger.error.assert_called_once() - logged_message = mock_logger.error.call_args[0][0] - assert ( - "McpTool only supports header-based API key authentication" - in logged_message - ) - @pytest.mark.asyncio async def test_run_async_require_confirmation_true_no_confirmation(self): """Test require_confirmation=True with no confirmation in context.""" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index f6d002ed17..0d0b761885 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -30,6 +30,8 @@ from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +from google.adk.tools.mcp_tool.mcp_toolset import McpToolsetConfig +from google.adk.tools.tool_configs import ToolArgsConfig from mcp import StdioServerParameters import pytest @@ -245,6 +247,94 @@ async def test_get_tools_with_header_provider(self): headers=expected_headers ) + @pytest.mark.asyncio + async def test_get_tools_with_auth_headers(self): + """Test get_tools with auth headers.""" + from fastapi.openapi import models as openapi_models + from google.adk.auth.auth_credential import AuthCredentialTypes + from google.adk.auth.auth_credential import OAuth2Auth + + mock_tools = [MockMCPTool("tool1")] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + mock_readonly_context = Mock(spec=ReadonlyContext) + + auth_scheme = openapi_models.HTTPBase(scheme="bearer") + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(access_token="test_token"), + ) + + with patch( + "google.adk.tools.mcp_tool.mcp_toolset.CredentialManager" + ) as MockCredentialManager: + mock_manager_instance = MockCredentialManager.return_value + mock_manager_instance.get_auth_credential = AsyncMock( + return_value=auth_credential + ) + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + toolset._mcp_session_manager = self.mock_session_manager + + await toolset.get_tools(readonly_context=mock_readonly_context) + + self.mock_session_manager.create_session.assert_called_once() + call_args = self.mock_session_manager.create_session.call_args + headers = call_args[1]["headers"] + assert headers == {"Authorization": "Bearer test_token"} + + @pytest.mark.asyncio + async def test_get_tools_with_auth_and_header_provider(self): + """Test get_tools with auth and header_provider.""" + from fastapi.openapi import models as openapi_models + from google.adk.auth.auth_credential import AuthCredentialTypes + from google.adk.auth.auth_credential import OAuth2Auth + + mock_tools = [MockMCPTool("tool1")] + self.mock_session.list_tools = AsyncMock( + return_value=MockListToolsResult(mock_tools) + ) + mock_readonly_context = Mock(spec=ReadonlyContext) + provided_headers = {"X-Tenant-ID": "test-tenant"} + header_provider = Mock(return_value=provided_headers) + + auth_scheme = openapi_models.HTTPBase(scheme="bearer") + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(access_token="test_token"), + ) + + with patch( + "google.adk.tools.mcp_tool.mcp_toolset.CredentialManager" + ) as MockCredentialManager: + mock_manager_instance = MockCredentialManager.return_value + mock_manager_instance.get_auth_credential = AsyncMock( + return_value=auth_credential + ) + + toolset = MCPToolset( + connection_params=self.mock_stdio_params, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + header_provider=header_provider, + ) + toolset._mcp_session_manager = self.mock_session_manager + + await toolset.get_tools(readonly_context=mock_readonly_context) + + self.mock_session_manager.create_session.assert_called_once() + call_args = self.mock_session_manager.create_session.call_args + headers = call_args[1]["headers"] + assert headers == { + "X-Tenant-ID": "test-tenant", + "Authorization": "Bearer test_token", + } + @pytest.mark.asyncio async def test_close_success(self): """Test successful cleanup."""