From 910f65473f7604c2b6cb3052d71ff0c3c11a0aaa Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 21 Jan 2026 10:06:08 -0800 Subject: [PATCH 1/7] docs: Update to gemini-2.5-flash for api registry and mcp sample agents Co-authored-by: Kathy Wu PiperOrigin-RevId: 859151854 --- contributing/samples/api_registry_agent/agent.py | 2 +- contributing/samples/mcp_dynamic_header_agent/agent.py | 2 +- contributing/samples/mcp_postgres_agent/agent.py | 2 +- contributing/samples/mcp_service_account_agent/agent.py | 2 +- contributing/samples/mcp_sse_agent/agent.py | 2 +- contributing/samples/mcp_stdio_notion_agent/agent.py | 2 +- contributing/samples/mcp_stdio_server_agent/agent.py | 2 +- contributing/samples/mcp_streamablehttp_agent/agent.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/contributing/samples/api_registry_agent/agent.py b/contributing/samples/api_registry_agent/agent.py index 5247c9fac5..9f55ef8069 100644 --- a/contributing/samples/api_registry_agent/agent.py +++ b/contributing/samples/api_registry_agent/agent.py @@ -26,7 +26,7 @@ mcp_server_name=MCP_SERVER_NAME, ) root_agent = LlmAgent( - model="gemini-2.0-flash", + model="gemini-2.5-flash", name="bigquery_assistant", instruction=f""" You are a helpful data analyst assistant with access to BigQuery. The project ID is: {PROJECT_ID} diff --git a/contributing/samples/mcp_dynamic_header_agent/agent.py b/contributing/samples/mcp_dynamic_header_agent/agent.py index 17768c16e4..0d3ce7f6a6 100644 --- a/contributing/samples/mcp_dynamic_header_agent/agent.py +++ b/contributing/samples/mcp_dynamic_header_agent/agent.py @@ -18,7 +18,7 @@ from google.adk.tools.mcp_tool.mcp_toolset import McpToolset root_agent = LlmAgent( - model='gemini-2.0-flash', + model='gemini-2.5-flash', name='tenant_agent', instruction="""You are a helpful assistant that helps users get tenant information. Call the get_tenant_data tool when the user asks for tenant data.""", diff --git a/contributing/samples/mcp_postgres_agent/agent.py b/contributing/samples/mcp_postgres_agent/agent.py index 7224c34ab5..09e75b44d0 100644 --- a/contributing/samples/mcp_postgres_agent/agent.py +++ b/contributing/samples/mcp_postgres_agent/agent.py @@ -31,7 +31,7 @@ ) root_agent = LlmAgent( - model="gemini-2.0-flash", + model="gemini-2.5-flash", name="postgres_agent", instruction=( "You are a PostgreSQL database assistant. " diff --git a/contributing/samples/mcp_service_account_agent/agent.py b/contributing/samples/mcp_service_account_agent/agent.py index a62e30cea2..975ff0ccb7 100644 --- a/contributing/samples/mcp_service_account_agent/agent.py +++ b/contributing/samples/mcp_service_account_agent/agent.py @@ -29,7 +29,7 @@ SCOPES = {"https://www.googleapis.com/auth/cloud-platform": ""} root_agent = LlmAgent( - model="gemini-2.0-flash", + model="gemini-2.5-flash", name="enterprise_assistant", instruction=""" Help the user with the tools available to you. diff --git a/contributing/samples/mcp_sse_agent/agent.py b/contributing/samples/mcp_sse_agent/agent.py index 2afbb930b1..8bddbca79f 100755 --- a/contributing/samples/mcp_sse_agent/agent.py +++ b/contributing/samples/mcp_sse_agent/agent.py @@ -28,7 +28,7 @@ ) root_agent = LlmAgent( - model='gemini-2.0-flash', + model='gemini-2.5-flash', name='enterprise_assistant', instruction=McpInstructionProvider( connection_params=connection_params, diff --git a/contributing/samples/mcp_stdio_notion_agent/agent.py b/contributing/samples/mcp_stdio_notion_agent/agent.py index 55ea56ec49..1a624f7ee1 100644 --- a/contributing/samples/mcp_stdio_notion_agent/agent.py +++ b/contributing/samples/mcp_stdio_notion_agent/agent.py @@ -29,7 +29,7 @@ }) root_agent = LlmAgent( - model="gemini-2.0-flash", + model="gemini-2.5-flash", name="notion_agent", instruction=( "You are my workspace assistant. " diff --git a/contributing/samples/mcp_stdio_server_agent/agent.py b/contributing/samples/mcp_stdio_server_agent/agent.py index 1799bd56d0..8b336ba281 100755 --- a/contributing/samples/mcp_stdio_server_agent/agent.py +++ b/contributing/samples/mcp_stdio_server_agent/agent.py @@ -23,7 +23,7 @@ _allowed_path = os.path.dirname(os.path.abspath(__file__)) root_agent = LlmAgent( - model='gemini-2.0-flash', + model='gemini-2.5-flash', name='enterprise_assistant', instruction=f"""\ Help user accessing their file systems. diff --git a/contributing/samples/mcp_streamablehttp_agent/agent.py b/contributing/samples/mcp_streamablehttp_agent/agent.py index e2223b0f13..1da782d04a 100644 --- a/contributing/samples/mcp_streamablehttp_agent/agent.py +++ b/contributing/samples/mcp_streamablehttp_agent/agent.py @@ -22,7 +22,7 @@ _allowed_path = os.path.dirname(os.path.abspath(__file__)) root_agent = LlmAgent( - model='gemini-2.0-flash', + model='gemini-2.5-flash', name='enterprise_assistant', instruction=f"""\ Help user accessing their file systems. From 7b25b8fb1daf54d7694bf405d545d46d2c012d2b Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 21 Jan 2026 10:21:14 -0800 Subject: [PATCH 2/7] fix(runner): Yield buffered function_call/function_response events during live streaming Bug: In live streaming mode, when function_call and function_response events arrive during active transcription, they are correctly buffered but never yielded to the caller. This causes callers to miss these events even though they are saved to the session. Fix: Add yield buffered_event after appending buffered events to the session when transcription ends. Testing: - Added unit test: test_live_streaming_buffered_function_call_yielded_during_transcription - Test verifies buffered events are yielded by: 1. Simulating partial transcription (triggers buffering) 2. Sending function_call during transcription (gets buffered) 3. Ending transcription (should yield buffered events) 4. Asserting both function_call and function_response are in yielded events Test results: - With fix: PASSED - Without fix (yield commented out): FAILED with "Buffered function_call event was not yielded" - Example event flow after fix: EVENT: partial=True, input_transcription="Show me the weather" EVENT: function_call=get_weather, args={'location': 'NYC'} <- Now yielded EVENT: function_response=get_weather, response={...} <- Now yielded EVENT: partial=False, input_transcription="Show me the weather for today" PiperOrigin-RevId: 859158546 --- src/google/adk/runners.py | 1 + tests/unittests/streaming/test_streaming.py | 111 ++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bfccbdc947..b931561c4d 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -815,6 +815,7 @@ async def _exec_with_plugin( await self.session_service.append_event( session=session, event=buffered_event ) + yield buffered_event # yield buffered events to caller buffered_events = [] else: # non-transcription event or empty transcription event, for diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py index 57743d6076..5ee4721c05 100644 --- a/tests/unittests/streaming/test_streaming.py +++ b/tests/unittests/streaming/test_streaming.py @@ -1009,3 +1009,114 @@ async def consume_responses(session: testing_utils.Session): assert stock_call_found, 'Expected monitor_stock_price function call event.' assert video_call_found, 'Expected monitor_video_stream function call event.' + + +def test_live_streaming_buffered_function_call_yielded_during_transcription(): + """Test that function calls buffered during transcription are yielded. + + This tests the fix for the bug where function_call and function_response + events were buffered during active transcription but never yielded to the + caller. The fix ensures buffered events are yielded after transcription ends. + """ + function_call = types.Part.from_function_call( + name='get_weather', args={'location': 'San Francisco'} + ) + + response1 = LlmResponse( + input_transcription=types.Transcription(text='Show'), + partial=True, # ← Triggers is_transcribing = True + ) + response2 = LlmResponse( + content=types.Content( + role='model', parts=[function_call] + ), # ← Gets buffered + turn_complete=False, + ) + response3 = LlmResponse( + input_transcription=types.Transcription(text='Show me the weather'), + partial=False, # ← Transcription ends, buffered events yielded + ) + response4 = LlmResponse( + turn_complete=True, + ) + + mock_model = testing_utils.MockModel.create( + [response1, response2, response3, response4] + ) + + def get_weather(location: str) -> dict: + return {'temperature': 22, 'location': location} + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[get_weather], + ) + + class CustomTestRunner(testing_utils.InMemoryRunner): + + def run_live( + self, + live_request_queue: LiveRequestQueue, + run_config: testing_utils.RunConfig = None, + ) -> list[testing_utils.Event]: + collected_responses = [] + + async def consume_responses(session: testing_utils.Session): + run_res = self.runner.run_live( + session=session, + live_request_queue=live_request_queue, + run_config=run_config or testing_utils.RunConfig(), + ) + + async for response in run_res: + collected_responses.append(response) + if len(collected_responses) >= 5: + return + + try: + session = self.session + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + return collected_responses + + runner = CustomTestRunner(root_agent=root_agent) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=types.Blob(data=b'Show me the weather', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue) + + assert res_events is not None, 'Expected a list of events, got None.' + assert len(res_events) >= 1, 'Expected at least one event.' + + function_call_found = False + function_response_found = False + + for event in res_events: + if event.content and event.content.parts: + for part in event.content.parts: + if part.function_call and part.function_call.name == 'get_weather': + function_call_found = True + assert part.function_call.args['location'] == 'San Francisco' + if ( + part.function_response + and part.function_response.name == 'get_weather' + ): + function_response_found = True + assert part.function_response.response['temperature'] == 22 + + assert function_call_found, 'Buffered function_call event was not yielded.' + assert ( + function_response_found + ), 'Buffered function_response event was not yielded.' From d62f9c896c301aba3a781e868735e16f946a8862 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Wed, 21 Jan 2026 11:19:46 -0800 Subject: [PATCH 3/7] chore: Always skip executing partial function calls Related: https://github.com/google/adk-python/issues/4159 Co-authored-by: Xuan Yang PiperOrigin-RevId: 859184844 --- .../adk/flows/llm_flows/base_llm_flow.py | 15 +- .../test_progressive_sse_streaming.py | 141 ++++++++++++++++++ 2 files changed, 146 insertions(+), 10 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 514e31b272..d57e31ffd9 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -38,8 +38,6 @@ from ...agents.run_config import StreamingMode from ...agents.transcription_entry import TranscriptionEntry from ...events.event import Event -from ...features import FeatureName -from ...features import is_feature_enabled from ...models.base_llm_connection import BaseLlmConnection from ...models.llm_request import LlmRequest from ...models.llm_response import LlmResponse @@ -551,14 +549,11 @@ async def _postprocess_async( # Handles function calls. if model_response_event.get_function_calls(): - if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING): - # In progressive SSE streaming mode stage 1, we skip partial FC events - # Only execute FCs in the final aggregated event (partial=False) - if ( - invocation_context.run_config.streaming_mode == StreamingMode.SSE - and model_response_event.partial - ): - return + # Skip partial function call events - they should not trigger execution + # since partial events are not saved to session (see runners.py). + # Only execute function calls in the non-partial events. + if model_response_event.partial: + return async with Aclosing( self._postprocess_handle_function_calls_async( diff --git a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py index 9b58d5bcbd..37dfbb95eb 100644 --- a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py +++ b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py @@ -631,3 +631,144 @@ async def process(): args = fc_part.function_call.args assert args["num"] == 100 assert args["s"] == "ADK" + + +class PartialFunctionCallMockModel(BaseLlm): + """A mock model that yields partial function call events followed by final.""" + + model: str = "partial-fc-mock" + tool_call_count: int = 0 + + @classmethod + def supported_models(cls) -> list[str]: + return ["partial-fc-mock"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + """Yield partial FC events then final, simulating streaming behavior.""" + + # Check if this is a follow-up call (after function response) + has_function_response = False + for content in llm_request.contents: + for part in content.parts or []: + if part.function_response: + has_function_response = True + break + + if has_function_response: + # Final response after function execution + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="Function executed once.")], + ), + partial=False, + ) + return + + # First call: yield partial FC events then final + # Partial event 1 + yield LlmResponse( + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name="track_execution", args={"call_id": "partial_1"} + ) + ], + ), + partial=True, + ) + + # Partial event 2 + yield LlmResponse( + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name="track_execution", args={"call_id": "partial_2"} + ) + ], + ), + partial=True, + ) + + # Final aggregated event (only this should trigger execution) + yield LlmResponse( + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name="track_execution", args={"call_id": "final"} + ) + ], + ), + partial=False, + finish_reason=types.FinishReason.STOP, + ) + + +def test_partial_function_calls_not_executed_in_none_streaming_mode(): + """Test that partial function call events are skipped regardless of mode.""" + execution_log = [] + + def track_execution(call_id: str) -> str: + """A tool that logs each execution to verify call count.""" + execution_log.append(call_id) + return f"Executed: {call_id}" + + mock_model = PartialFunctionCallMockModel() + + agent = Agent( + name="partial_fc_test_agent", + model=mock_model, + tools=[track_execution], + ) + + # Use StreamingMode.NONE to verify partial FCs are still skipped + run_config = RunConfig(streaming_mode=StreamingMode.NONE) + + runner = InMemoryRunner(agent=agent) + + session = runner.session_service.create_session_sync( + app_name=runner.app_name, user_id="test_user" + ) + + events = [] + for event in runner.run( + user_id="test_user", + session_id=session.id, + new_message=types.Content( + role="user", + parts=[types.Part.from_text(text="Test partial FC handling")], + ), + run_config=run_config, + ): + events.append(event) + + # Verify the tool was only executed once (from the final event) + assert ( + len(execution_log) == 1 + ), f"Expected 1 execution, got {len(execution_log)}: {execution_log}" + assert ( + execution_log[0] == "final" + ), f"Expected 'final' execution, got: {execution_log[0]}" + + # Verify partial events were yielded but not executed + partial_events = [e for e in events if e.partial] + assert ( + len(partial_events) == 2 + ), f"Expected 2 partial events, got {len(partial_events)}" + + # Verify there's a function response event (from the final FC execution) + function_response_events = [ + e + for e in events + if e.content + and e.content.parts + and any(p.function_response for p in e.content.parts) + ] + assert ( + len(function_response_events) == 1 + ), f"Expected 1 function response event, got {len(function_response_events)}" From e3d542a5ba3d357407f8cd29cfdd722f583c8564 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 21 Jan 2026 11:57:05 -0800 Subject: [PATCH 4/7] feat: Support authentication for MCP tool listing Currently only tool calling supports MCP auth. This refactors the auth logic into a auth_utils file and uses it for tool listing as well. Fixes https://github.com/google/adk-python/issues/2168. Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 859201722 --- .../adk/tools/mcp_tool/mcp_auth_utils.py | 110 ------- src/google/adk/tools/mcp_tool/mcp_tool.py | 94 +++++- src/google/adk/tools/mcp_tool/mcp_toolset.py | 46 +-- .../tools/mcp_tool/test_mcp_auth_utils.py | 240 -------------- .../unittests/tools/mcp_tool/test_mcp_tool.py | 296 ++++++++++++++++++ .../tools/mcp_tool/test_mcp_toolset.py | 90 ------ 6 files changed, 386 insertions(+), 490 deletions(-) delete mode 100644 src/google/adk/tools/mcp_tool/mcp_auth_utils.py delete mode 100644 tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py diff --git a/src/google/adk/tools/mcp_tool/mcp_auth_utils.py b/src/google/adk/tools/mcp_tool/mcp_auth_utils.py deleted file mode 100644 index b074e67f15..0000000000 --- a/src/google/adk/tools/mcp_tool/mcp_auth_utils.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for MCP tool authentication.""" - -from __future__ import annotations - -import base64 -import logging -from typing import Dict -from typing import Optional - -from fastapi.openapi import models as openapi_models -from fastapi.openapi.models import APIKey -from fastapi.openapi.models import HTTPBase - -from ...auth.auth_credential import AuthCredential -from ...auth.auth_schemes import AuthScheme - -logger = logging.getLogger("google_adk." + __name__) - - -def get_mcp_auth_headers( - auth_scheme: Optional[AuthScheme], credential: Optional[AuthCredential] -) -> Optional[Dict[str, str]]: - """Generates HTTP authentication headers for MCP calls. - - Args: - auth_scheme: The authentication scheme. - credential: The resolved authentication credential. - - Returns: - A dictionary of headers, or None if no auth is applicable. - - Raises: - ValueError: If the auth scheme is unsupported or misconfigured. - """ - if not credential: - return None - - headers: Optional[Dict[str, str]] = None - - if credential.oauth2: - headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} - elif credential.http: - if not auth_scheme or not isinstance(auth_scheme, HTTPBase): - logger.warning( - "HTTP credential provided, but auth_scheme is missing or not" - " HTTPBase." - ) - return None - - scheme = auth_scheme.scheme.lower() - if scheme == "bearer" and credential.http.credentials.token: - headers = {"Authorization": f"Bearer {credential.http.credentials.token}"} - elif scheme == "basic": - if ( - credential.http.credentials.username - and credential.http.credentials.password - ): - creds = f"{credential.http.credentials.username}:{credential.http.credentials.password}" - encoded_creds = base64.b64encode(creds.encode()).decode() - headers = {"Authorization": f"Basic {encoded_creds}"} - else: - logger.warning("Basic auth scheme missing username or password.") - elif credential.http.credentials.token: - # Handle other HTTP schemes like Digest, etc. if token is present - headers = { - "Authorization": ( - f"{auth_scheme.scheme} {credential.http.credentials.token}" - ) - } - else: - logger.warning(f"Unsupported or incomplete HTTP auth scheme '{scheme}'.") - elif credential.api_key: - if not auth_scheme or not isinstance(auth_scheme, APIKey): - logger.warning( - "API key credential provided, but auth_scheme is missing or not" - " APIKey." - ) - return None - - if auth_scheme.in_ != openapi_models.APIKeyIn.header: - error_msg = ( - "MCP tools only support header-based API key authentication. " - f"Configured location: {auth_scheme.in_}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - headers = {auth_scheme.name: credential.api_key} - elif credential.service_account: - logger.warning( - "Service account credentials should be exchanged for an access token " - "before calling get_mcp_auth_headers." - ) - else: - logger.warning(f"Unsupported credential type: {type(credential)}") - - return headers diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 0b0c95ee35..a5b598fd81 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import base64 import inspect import logging from typing import Any @@ -23,6 +24,7 @@ from typing import Union import warnings +from fastapi.openapi.models import APIKeyIn from google.genai.types import FunctionDeclaration from mcp.types import Tool as McpBaseTool from typing_extensions import override @@ -37,7 +39,6 @@ from ..base_authenticated_tool import BaseAuthenticatedTool # import from ..tool_context import ToolContext -from .mcp_auth_utils import get_mcp_auth_headers from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_errors @@ -194,12 +195,7 @@ async def _run_async_impl( Any: The response from the tool. """ # Extract headers from credential for session pooling - auth_scheme = ( - self._auth_config.auth_scheme - if hasattr(self, "_auth_config") and self._auth_config - else None - ) - auth_headers = get_mcp_auth_headers(auth_scheme, credential) + auth_headers = await self._get_headers(tool_context, credential) dynamic_headers = None if self._header_provider: dynamic_headers = self._header_provider( @@ -221,6 +217,90 @@ async def _run_async_impl( response = await session.call_tool(self._mcp_tool.name, arguments=args) return response.model_dump(exclude_none=True, mode="json") + async def _get_headers( + self, tool_context: ToolContext, credential: AuthCredential + ) -> Optional[dict[str, str]]: + """Extracts authentication headers from credentials. + + Args: + tool_context: The tool context of the current invocation. + credential: The authentication credential to process. + + Returns: + Dictionary of headers to add to the request, or None if no auth. + + Raises: + ValueError: If API key authentication is configured for non-header location. + """ + headers: Optional[dict[str, str]] = None + if credential: + if credential.oauth2: + headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} + elif credential.http: + # Handle HTTP authentication schemes + if ( + credential.http.scheme.lower() == "bearer" + and credential.http.credentials.token + ): + headers = { + "Authorization": f"Bearer {credential.http.credentials.token}" + } + elif credential.http.scheme.lower() == "basic": + # Handle basic auth + if ( + credential.http.credentials.username + and credential.http.credentials.password + ): + + credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}" + encoded_credentials = base64.b64encode( + credentials.encode() + ).decode() + headers = {"Authorization": f"Basic {encoded_credentials}"} + elif credential.http.credentials.token: + # Handle other HTTP schemes with token + headers = { + "Authorization": ( + f"{credential.http.scheme} {credential.http.credentials.token}" + ) + } + elif credential.api_key: + if ( + not self._credentials_manager + or not self._credentials_manager._auth_config + ): + error_msg = ( + "Cannot find corresponding auth scheme for API key credential" + f" {credential}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + elif ( + self._credentials_manager._auth_config.auth_scheme.in_ + != APIKeyIn.header + ): + error_msg = ( + "McpTool only supports header-based API key authentication." + " Configured location:" + f" {self._credentials_manager._auth_config.auth_scheme.in_}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + else: + headers = { + self._credentials_manager._auth_config.auth_scheme.name: ( + credential.api_key + ) + } + elif credential.service_account: + # Service accounts should be exchanged for access tokens before reaching this point + logger.warning( + "Service account credentials should be exchanged before MCP" + " session creation" + ) + + return headers + class MCPTool(McpTool): """Deprecated name, use `McpTool` instead.""" diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 77abf20332..fed0684ed2 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -33,14 +33,11 @@ from ...agents.readonly_context import ReadonlyContext from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme -from ...auth.auth_tool import AuthConfig -from ...auth.credential_manager import CredentialManager from ..base_tool import BaseTool from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate from ..tool_configs import BaseToolConfig from ..tool_configs import ToolArgsConfig -from .mcp_auth_utils import get_mcp_auth_headers from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_errors from .mcp_session_manager import SseConnectionParams @@ -157,50 +154,13 @@ async def get_tools( Returns: List[BaseTool]: A list of tools available under the specified context. """ - provided_headers = ( + headers = ( self._header_provider(readonly_context) if self._header_provider and readonly_context - else {} + else None ) - - auth_headers = {} - if self._auth_scheme: - try: - # Instantiate CredentialsManager to resolve credentials - auth_config = AuthConfig( - auth_scheme=self._auth_scheme, - raw_auth_credential=self._auth_credential, - ) - credentials_manager = CredentialManager(auth_config) - - # Resolve the credential - resolved_credential = await credentials_manager.get_auth_credential( - readonly_context - ) - - if resolved_credential: - auth_headers = get_mcp_auth_headers( - self._auth_scheme, resolved_credential - ) - else: - logger.warning( - "Failed to resolve credential for tool listing, proceeding" - " without auth headers." - ) - except Exception as e: - logger.warning( - "Error generating auth headers for tool listing: %s, proceeding" - " without auth headers.", - e, - exc_info=True, - ) - - merged_headers = {**(provided_headers or {}), **(auth_headers or {})} - # Get session from session manager - session = await self._mcp_session_manager.create_session( - headers=merged_headers - ) + session = await self._mcp_session_manager.create_session(headers=headers) # Fetch available tools from the MCP server timeout_in_seconds = ( diff --git a/tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py b/tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py deleted file mode 100644 index 9e0988467e..0000000000 --- a/tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -from unittest.mock import patch - -from fastapi.openapi import models as openapi_models -from google.adk.auth.auth_credential import AuthCredential -from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.auth.auth_credential import HttpAuth -from google.adk.auth.auth_credential import HttpCredentials -from google.adk.auth.auth_credential import OAuth2Auth -from google.adk.auth.auth_credential import ServiceAccount -from google.adk.auth.auth_schemes import AuthSchemeType -from google.adk.tools.mcp_tool import mcp_auth_utils -import pytest - - -def test_get_mcp_auth_headers_no_credential(): - """Test header generation with no credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=None - ) - assert headers is None - - -def test_get_mcp_auth_headers_no_auth_scheme(): - """Test header generation with no auth_scheme.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(access_token="test_token"), - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=None, credential=credential - ) - assert headers == {"Authorization": "Bearer test_token"} - - -def test_get_mcp_auth_headers_oauth2(): - """Test header generation for OAuth2 credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(access_token="test_token"), - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers == {"Authorization": "Bearer test_token"} - - -def test_get_mcp_auth_headers_http_bearer(): - """Test header generation for HTTP Bearer credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", credentials=HttpCredentials(token="bearer_token") - ), - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers == {"Authorization": "Bearer bearer_token"} - - -def test_get_mcp_auth_headers_http_basic(): - """Test header generation for HTTP Basic credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="basic") - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="basic", - credentials=HttpCredentials(username="user", password="pass"), - ), - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - expected_encoded = base64.b64encode(b"user:pass").decode() - assert headers == {"Authorization": f"Basic {expected_encoded}"} - - -def test_get_mcp_auth_headers_http_basic_missing_credentials(): - """Test header generation for HTTP Basic with missing credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="basic") - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="basic", - credentials=HttpCredentials(username="user", password=None), - ), - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers is None - mock_logger.warning.assert_called_once_with( - "Basic auth scheme missing username or password." - ) - - -def test_get_mcp_auth_headers_http_custom_scheme(): - """Test header generation for custom HTTP scheme.""" - auth_scheme = openapi_models.HTTPBase(scheme="custom") - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="custom", credentials=HttpCredentials(token="custom_token") - ), - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers == {"Authorization": "custom custom_token"} - - -def test_get_mcp_auth_headers_http_cred_wrong_scheme(): - """Test HTTP credential with non-HTTPBase auth scheme.""" - auth_scheme = openapi_models.APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": openapi_models.APIKeyIn.header, - "name": "X-API-Key", - }) - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", credentials=HttpCredentials(token="bearer_token") - ), - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers is None - mock_logger.warning.assert_called_once_with( - "HTTP credential provided, but auth_scheme is missing or not HTTPBase." - ) - - -def test_get_mcp_auth_headers_api_key_header(): - """Test header generation for API Key in header.""" - auth_scheme = openapi_models.APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": openapi_models.APIKeyIn.header, - "name": "X-Custom-API-Key", - }) - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers == {"X-Custom-API-Key": "my_api_key"} - - -def test_get_mcp_auth_headers_api_key_query_raises_error(): - """Test API Key in query raises ValueError.""" - auth_scheme = openapi_models.APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": openapi_models.APIKeyIn.query, - "name": "api_key", - }) - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - with pytest.raises( - ValueError, - match="MCP tools only support header-based API key authentication.", - ): - mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - - -def test_get_mcp_auth_headers_api_key_cookie_raises_error(): - """Test API Key in cookie raises ValueError.""" - auth_scheme = openapi_models.APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": openapi_models.APIKeyIn.cookie, - "name": "session_id", - }) - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - with pytest.raises( - ValueError, - match="MCP tools only support header-based API key authentication.", - ): - mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - - -def test_get_mcp_auth_headers_api_key_cred_wrong_scheme(): - """Test API key credential with non-APIKey auth scheme.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers is None - mock_logger.warning.assert_called_once_with( - "API key credential provided, but auth_scheme is missing or not APIKey." - ) - - -def test_get_mcp_auth_headers_service_account(): - """Test header generation for service account credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount(scopes=["test"]), - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers is None - mock_logger.warning.assert_called_once_with( - "Service account credentials should be exchanged for an access " - "token before calling get_mcp_auth_headers." - ) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index cb13a4e541..0bf28cb3c2 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -18,7 +18,10 @@ from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount from google.adk.features import FeatureName from google.adk.features._feature_registry import temporary_feature_override from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager @@ -258,6 +261,240 @@ async def test_run_async_impl_with_oauth2(self): headers = call_args[1]["headers"] assert headers == {"Authorization": "Bearer test_access_token"} + @pytest.mark.asyncio + async def test_get_headers_oauth2(self): + """Test header generation for OAuth2 credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + oauth2_auth = OAuth2Auth(access_token="test_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer test_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_bearer(self): + """Test header generation for HTTP Bearer credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="bearer", credentials=HttpCredentials(token="bearer_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer bearer_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_basic(self): + """Test header generation for HTTP Basic credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="basic", + credentials=HttpCredentials(username="user", password="pass"), + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should create Basic auth header with base64 encoded credentials + import base64 + + expected_encoded = base64.b64encode(b"user:pass").decode() + assert headers == {"Authorization": f"Basic {expected_encoded}"} + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_valid_header_scheme(self): + """Test header generation for API Key credentials with header-based auth scheme.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for header-based API key + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.header, + "name": "X-Custom-API-Key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, auth_credential) + + assert headers == {"X-Custom-API-Key": "my_api_key"} + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_query_scheme_raises_error(self): + """Test that API Key with query-based auth scheme raises ValueError.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for query-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.query, + "name": "api_key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="McpTool only supports header-based API key authentication", + ): + await tool._get_headers(tool_context, auth_credential) + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_cookie_scheme_raises_error(self): + """Test that API Key with cookie-based auth scheme raises ValueError.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for cookie-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.cookie, + "name": "session_id", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="McpTool only supports header-based API key authentication", + ): + await tool._get_headers(tool_context, auth_credential) + + @pytest.mark.asyncio + async def test_get_headers_api_key_without_auth_config_raises_error(self): + """Test that API Key without auth config raises ValueError.""" + # Create tool without auth scheme/config + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="Cannot find corresponding auth scheme for API key credential", + ): + await tool._get_headers(tool_context, credential) + + @pytest.mark.asyncio + async def test_get_headers_api_key_without_credentials_manager_raises_error( + self, + ): + """Test that API Key without credentials manager raises ValueError.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Manually set credentials manager to None to simulate error condition + tool._credentials_manager = None + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="Cannot find corresponding auth scheme for API key credential", + ): + await tool._get_headers(tool_context, credential) + + @pytest.mark.asyncio + async def test_get_headers_no_credential(self): + """Test header generation with no credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, None) + + assert headers is None + + @pytest.mark.asyncio + async def test_get_headers_service_account(self): + """Test header generation for service account credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create service account credential + service_account = ServiceAccount(scopes=["test"]) + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=service_account, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should return None as service account credentials are not supported for direct header generation + assert headers is None + @pytest.mark.asyncio async def test_run_async_impl_with_api_key_header_auth(self): """Test running tool with API key header authentication end-to-end.""" @@ -314,6 +551,65 @@ async def test_run_async_impl_retry_decorator(self): # Check that the method has the retry decorator assert hasattr(tool._run_async_impl, "__wrapped__") + @pytest.mark.asyncio + async def test_get_headers_http_custom_scheme(self): + """Test header generation for custom HTTP scheme.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="custom", credentials=HttpCredentials(token="custom_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "custom custom_token"} + + @pytest.mark.asyncio + async def test_get_headers_api_key_error_logging(self): + """Test that API key errors are logged correctly.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for query-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.query, + "name": "api_key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + + # Test with logging + with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger: + with pytest.raises(ValueError): + await tool._get_headers(tool_context, auth_credential) + + # Verify error was logged + mock_logger.error.assert_called_once() + logged_message = mock_logger.error.call_args[0][0] + assert ( + "McpTool only supports header-based API key authentication" + in logged_message + ) + @pytest.mark.asyncio async def test_run_async_require_confirmation_true_no_confirmation(self): """Test require_confirmation=True with no confirmation in context.""" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 34d8c5b7fb..83b112d8b9 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -30,8 +30,6 @@ from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from google.adk.tools.mcp_tool.mcp_toolset import McpToolset -from google.adk.tools.mcp_tool.mcp_toolset import McpToolsetConfig -from google.adk.tools.tool_configs import ToolArgsConfig from mcp import StdioServerParameters import pytest @@ -247,94 +245,6 @@ async def test_get_tools_with_header_provider(self): headers=expected_headers ) - @pytest.mark.asyncio - async def test_get_tools_with_auth_headers(self): - """Test get_tools with auth headers.""" - from fastapi.openapi import models as openapi_models - from google.adk.auth.auth_credential import AuthCredentialTypes - from google.adk.auth.auth_credential import OAuth2Auth - - mock_tools = [MockMCPTool("tool1")] - self.mock_session.list_tools = AsyncMock( - return_value=MockListToolsResult(mock_tools) - ) - mock_readonly_context = Mock(spec=ReadonlyContext) - - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(access_token="test_token"), - ) - - with patch( - "google.adk.tools.mcp_tool.mcp_toolset.CredentialManager" - ) as MockCredentialManager: - mock_manager_instance = MockCredentialManager.return_value - mock_manager_instance.get_auth_credential = AsyncMock( - return_value=auth_credential - ) - - toolset = MCPToolset( - connection_params=self.mock_stdio_params, - auth_scheme=auth_scheme, - auth_credential=auth_credential, - ) - toolset._mcp_session_manager = self.mock_session_manager - - await toolset.get_tools(readonly_context=mock_readonly_context) - - self.mock_session_manager.create_session.assert_called_once() - call_args = self.mock_session_manager.create_session.call_args - headers = call_args[1]["headers"] - assert headers == {"Authorization": "Bearer test_token"} - - @pytest.mark.asyncio - async def test_get_tools_with_auth_and_header_provider(self): - """Test get_tools with auth and header_provider.""" - from fastapi.openapi import models as openapi_models - from google.adk.auth.auth_credential import AuthCredentialTypes - from google.adk.auth.auth_credential import OAuth2Auth - - mock_tools = [MockMCPTool("tool1")] - self.mock_session.list_tools = AsyncMock( - return_value=MockListToolsResult(mock_tools) - ) - mock_readonly_context = Mock(spec=ReadonlyContext) - provided_headers = {"X-Tenant-ID": "test-tenant"} - header_provider = Mock(return_value=provided_headers) - - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(access_token="test_token"), - ) - - with patch( - "google.adk.tools.mcp_tool.mcp_toolset.CredentialManager" - ) as MockCredentialManager: - mock_manager_instance = MockCredentialManager.return_value - mock_manager_instance.get_auth_credential = AsyncMock( - return_value=auth_credential - ) - - toolset = MCPToolset( - connection_params=self.mock_stdio_params, - auth_scheme=auth_scheme, - auth_credential=auth_credential, - header_provider=header_provider, - ) - toolset._mcp_session_manager = self.mock_session_manager - - await toolset.get_tools(readonly_context=mock_readonly_context) - - self.mock_session_manager.create_session.assert_called_once() - call_args = self.mock_session_manager.create_session.call_args - headers = call_args[1]["headers"] - assert headers == { - "X-Tenant-ID": "test-tenant", - "Authorization": "Bearer test_token", - } - @pytest.mark.asyncio async def test_close_success(self): """Test successful cleanup.""" From 5d941460b607cf80af12a8226af7bfcd2ca3a3b1 Mon Sep 17 00:00:00 2001 From: Didier Durand <2927957+didier-durand@users.noreply.github.com> Date: Wed, 21 Jan 2026 12:01:27 -0800 Subject: [PATCH 5/7] docs: Fix various typos Merge https://github.com/google/adk-python/pull/4186 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4186 from didier-durand:main f8a52b58e133551bec1d98660fd552afd881a8b9 PiperOrigin-RevId: 859203551 --- contributing/samples/application_integration_agent/README.md | 2 +- contributing/samples/multi_agent_seq_config/README.md | 2 +- contributing/samples/spanner_rag_agent/README.md | 2 +- src/google/adk/tools/pubsub/__init__.py | 2 +- src/google/adk/tools/spanner/settings.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/contributing/samples/application_integration_agent/README.md b/contributing/samples/application_integration_agent/README.md index 0e0a70c17c..961a65eb53 100644 --- a/contributing/samples/application_integration_agent/README.md +++ b/contributing/samples/application_integration_agent/README.md @@ -7,7 +7,7 @@ This sample demonstrates how to use the `ApplicationIntegrationToolset` within a ## Prerequisites 1. **Set up Integration Connection:** - * You need an existing [Integration connection](https://cloud.google.com/integration-connectors/docs/overview) configured to interact with your Jira instance. Follow the [documentation](https://google.github.io/adk-docs/tools/google-cloud-tools/#use-integration-connectors) to provision the Integration Connector in Google Cloud and then use this [documentation](https://cloud.google.com/integration-connectors/docs/connectors/jiracloud/configure) to create an Jira connection. Note the `Connection Name`, `Project ID`, and `Location` of your connection. + * You need an existing [Integration connection](https://cloud.google.com/integration-connectors/docs/overview) configured to interact with your Jira instance. Follow the [documentation](https://google.github.io/adk-docs/tools/google-cloud-tools/#use-integration-connectors) to provision the Integration Connector in Google Cloud and then use this [documentation](https://cloud.google.com/integration-connectors/docs/connectors/jiracloud/configure) to create a Jira connection. Note the `Connection Name`, `Project ID`, and `Location` of your connection. * 2. **Configure Environment Variables:** diff --git a/contributing/samples/multi_agent_seq_config/README.md b/contributing/samples/multi_agent_seq_config/README.md index a2cd462465..af0dcee2fc 100644 --- a/contributing/samples/multi_agent_seq_config/README.md +++ b/contributing/samples/multi_agent_seq_config/README.md @@ -6,7 +6,7 @@ The whole process is: 1. An agent backed by a cheap and fast model to write initial version. 2. An agent backed by a smarter and a little more expensive to review the code. -3. An final agent backed by the smartest and slowest model to write the final revision. +3. A final agent backed by the smartest and slowest model to write the final revision. Sample queries: diff --git a/contributing/samples/spanner_rag_agent/README.md b/contributing/samples/spanner_rag_agent/README.md index 99b60794fe..08d134b990 100644 --- a/contributing/samples/spanner_rag_agent/README.md +++ b/contributing/samples/spanner_rag_agent/README.md @@ -181,7 +181,7 @@ type. ## 💬 Sample prompts -* I'd like to buy a starter bike for my 3 year old child, can you show me the recommendation? +* I'd like to buy a starter bike for my 3-year-old child, can you show me the recommendation? ![Spanner RAG Sample Agent](Spanner_RAG_Sample_Agent.png) diff --git a/src/google/adk/tools/pubsub/__init__.py b/src/google/adk/tools/pubsub/__init__.py index aae36ae528..d488c317d9 100644 --- a/src/google/adk/tools/pubsub/__init__.py +++ b/src/google/adk/tools/pubsub/__init__.py @@ -14,7 +14,7 @@ """Pub/Sub Tools (Experimental). -Pub/Sub Tools under this module are hand crafted and customized while the tools +Pub/Sub Tools under this module are handcrafted and customized while the tools under google.adk.tools.google_api_tool are auto generated based on API definition. The rationales to have customized tool are: diff --git a/src/google/adk/tools/spanner/settings.py b/src/google/adk/tools/spanner/settings.py index dca8ef098b..ca7e05f1b1 100644 --- a/src/google/adk/tools/spanner/settings.py +++ b/src/google/adk/tools/spanner/settings.py @@ -115,7 +115,7 @@ class VectorSearchIndexSettings(BaseModel): """ num_branches: Optional[int] = None - """Optional. The number of branches to further parititon the vector data. + """Optional. The number of branches to further partition the vector data. You can only designate num_branches for trees with 3 levels. The number of branches must be fewer than the number of leaves @@ -165,7 +165,7 @@ class SpannerVectorStoreSettings(BaseModel): """Required. The vector store table columns to return in the vector similarity search result. By default, only the `content_column` value and the distance value are returned. - If sepecified, the list of selected columns and the distance value are returned. + If specified, the list of selected columns and the distance value are returned. For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value. """ From a04828dd8a848482acbd48acc7da432d0d2cb0aa Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 21 Jan 2026 12:10:37 -0800 Subject: [PATCH 6/7] feat: Persist user input content to session in live mode Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 859207592 --- .../adk/flows/llm_flows/base_llm_flow.py | 23 ++++- tests/unittests/streaming/test_streaming.py | 86 +++++++++++++++++++ tests/unittests/testing_utils.py | 4 + 3 files changed, 111 insertions(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index d57e31ffd9..759ac532fd 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -272,6 +272,25 @@ async def _send_to_model( await llm_connection.send_realtime(live_request.blob) if live_request.content: + content = live_request.content + # Persist user text content to session (similar to non-live mode) + # Skip function responses - they are already handled separately + is_function_response = content.parts and any( + part.function_response for part in content.parts + ) + if not is_function_response: + if not content.role: + content.role = 'user' + user_content_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author='user', + content=content, + ) + await invocation_context.session_service.append_event( + session=invocation_context.session, + event=user_content_event, + ) await llm_connection.send_content(live_request.content) async def _receive_from_model( @@ -391,8 +410,8 @@ async def _run_one_step_async( current_invocation=True, current_branch=True ) - # Long-running tool calls should have been handled before this point. - # If there are still long-running tool calls, it means the agent is paused + # Long running tool calls should have been handled before this point. + # If there are still long running tool calls, it means the agent is paused # before, and its branch hasn't been resumed yet. if ( invocation_context.is_resumable diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py index 5ee4721c05..e697eacdc0 100644 --- a/tests/unittests/streaming/test_streaming.py +++ b/tests/unittests/streaming/test_streaming.py @@ -1120,3 +1120,89 @@ async def consume_responses(session: testing_utils.Session): assert ( function_response_found ), 'Buffered function_response event was not yielded.' + + +def test_live_streaming_text_content_persisted_in_session(): + """Test that user text content sent via send_content is persisted in session.""" + response1 = LlmResponse( + content=types.Content( + role='model', parts=[types.Part(text='Hello! How can I help you?')] + ), + turn_complete=True, + ) + + mock_model = testing_utils.MockModel.create([response1]) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[], + ) + + class CustomTestRunner(testing_utils.InMemoryRunner): + + def run_live_and_get_session( + self, + live_request_queue: LiveRequestQueue, + run_config: testing_utils.RunConfig = None, + ) -> tuple[list[testing_utils.Event], testing_utils.Session]: + collected_responses = [] + + async def consume_responses(session: testing_utils.Session): + run_res = self.runner.run_live( + session=session, + live_request_queue=live_request_queue, + run_config=run_config or testing_utils.RunConfig(), + ) + async for response in run_res: + collected_responses.append(response) + if len(collected_responses) >= 1: + return + + try: + session = self.session + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + # Get the updated session + updated_session = self.runner.session_service.get_session_sync( + app_name=self.app_name, + user_id=session.user_id, + session_id=session.id, + ) + return collected_responses, updated_session + + runner = CustomTestRunner(root_agent=root_agent) + live_request_queue = LiveRequestQueue() + + # Send text content (not audio blob) + user_text = 'Hello, this is a test message' + live_request_queue.send_content( + types.Content(role='user', parts=[types.Part(text=user_text)]) + ) + + res_events, session = runner.run_live_and_get_session(live_request_queue) + + assert res_events is not None, 'Expected a list of events, got None.' + + # Check that user text content was persisted in the session + user_content_found = False + for event in session.events: + if event.author == 'user' and event.content: + for part in event.content.parts: + if part.text and user_text in part.text: + user_content_found = True + break + + assert user_content_found, ( + f'Expected user text content "{user_text}" to be persisted in session. ' + f'Session events: {[e.content for e in session.events]}' + ) diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py index f76668b165..4f9b8636fe 100644 --- a/tests/unittests/testing_utils.py +++ b/tests/unittests/testing_utils.py @@ -409,6 +409,10 @@ async def send_realtime(self, blob: types.Blob): async def receive(self) -> AsyncGenerator[LlmResponse, None]: """Yield each of the pre-defined LlmResponses.""" for response in self.llm_responses: + # Yield control to allow other tasks (like send_task) to run first. + # This ensures user content gets persisted before the mock response + # is yielded. + await asyncio.sleep(0) yield response async def close(self): From 3e3566bd0e54447b25faf33df0df1291db60bf1b Mon Sep 17 00:00:00 2001 From: Liang Wu Date: Wed, 21 Jan 2026 13:30:36 -0800 Subject: [PATCH 7/7] chore: Pin litellm dependency to versions below 1.81.0 Update the litellm version constraint in both project dependencies and dev dependencies to exclude versions 1.81.0 and higher because unit test breakages in GitHub actions introduced by it. This is a stopgap before the actual fix is added. Close #4225 Co-authored-by: Liang Wu PiperOrigin-RevId: 859239880 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ad1da6fff1..fc50b24199 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,7 @@ test = [ "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent - "litellm>=1.75.5, <2.0.0", # For LiteLLM tests + "litellm>=1.75.5, <1.81.0", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests "openai>=1.100.2", # For LiteLLM "pytest-asyncio>=0.25.0", @@ -153,7 +153,7 @@ extensions = [ "docker>=7.0.0", # For ContainerCodeExecutor "kubernetes>=29.0.0", # For GkeCodeExecutor "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent - "litellm>=1.75.5", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it + "litellm>=1.75.5, <1.81.0", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex. "llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex. "lxml>=5.3.0", # For load_web_page tool.