diff --git a/contributing/samples/api_registry_agent/agent.py b/contributing/samples/api_registry_agent/agent.py index 5247c9fac5..9f55ef8069 100644 --- a/contributing/samples/api_registry_agent/agent.py +++ b/contributing/samples/api_registry_agent/agent.py @@ -26,7 +26,7 @@ mcp_server_name=MCP_SERVER_NAME, ) root_agent = LlmAgent( - model="gemini-2.0-flash", + model="gemini-2.5-flash", name="bigquery_assistant", instruction=f""" You are a helpful data analyst assistant with access to BigQuery. The project ID is: {PROJECT_ID} diff --git a/contributing/samples/application_integration_agent/README.md b/contributing/samples/application_integration_agent/README.md index 0e0a70c17c..961a65eb53 100644 --- a/contributing/samples/application_integration_agent/README.md +++ b/contributing/samples/application_integration_agent/README.md @@ -7,7 +7,7 @@ This sample demonstrates how to use the `ApplicationIntegrationToolset` within a ## Prerequisites 1. **Set up Integration Connection:** - * You need an existing [Integration connection](https://cloud.google.com/integration-connectors/docs/overview) configured to interact with your Jira instance. Follow the [documentation](https://google.github.io/adk-docs/tools/google-cloud-tools/#use-integration-connectors) to provision the Integration Connector in Google Cloud and then use this [documentation](https://cloud.google.com/integration-connectors/docs/connectors/jiracloud/configure) to create an Jira connection. Note the `Connection Name`, `Project ID`, and `Location` of your connection. + * You need an existing [Integration connection](https://cloud.google.com/integration-connectors/docs/overview) configured to interact with your Jira instance. Follow the [documentation](https://google.github.io/adk-docs/tools/google-cloud-tools/#use-integration-connectors) to provision the Integration Connector in Google Cloud and then use this [documentation](https://cloud.google.com/integration-connectors/docs/connectors/jiracloud/configure) to create a Jira connection. Note the `Connection Name`, `Project ID`, and `Location` of your connection. * 2. **Configure Environment Variables:** diff --git a/contributing/samples/mcp_dynamic_header_agent/agent.py b/contributing/samples/mcp_dynamic_header_agent/agent.py index 17768c16e4..0d3ce7f6a6 100644 --- a/contributing/samples/mcp_dynamic_header_agent/agent.py +++ b/contributing/samples/mcp_dynamic_header_agent/agent.py @@ -18,7 +18,7 @@ from google.adk.tools.mcp_tool.mcp_toolset import McpToolset root_agent = LlmAgent( - model='gemini-2.0-flash', + model='gemini-2.5-flash', name='tenant_agent', instruction="""You are a helpful assistant that helps users get tenant information. Call the get_tenant_data tool when the user asks for tenant data.""", diff --git a/contributing/samples/mcp_postgres_agent/agent.py b/contributing/samples/mcp_postgres_agent/agent.py index 7224c34ab5..09e75b44d0 100644 --- a/contributing/samples/mcp_postgres_agent/agent.py +++ b/contributing/samples/mcp_postgres_agent/agent.py @@ -31,7 +31,7 @@ ) root_agent = LlmAgent( - model="gemini-2.0-flash", + model="gemini-2.5-flash", name="postgres_agent", instruction=( "You are a PostgreSQL database assistant. " diff --git a/contributing/samples/mcp_service_account_agent/agent.py b/contributing/samples/mcp_service_account_agent/agent.py index a62e30cea2..975ff0ccb7 100644 --- a/contributing/samples/mcp_service_account_agent/agent.py +++ b/contributing/samples/mcp_service_account_agent/agent.py @@ -29,7 +29,7 @@ SCOPES = {"https://www.googleapis.com/auth/cloud-platform": ""} root_agent = LlmAgent( - model="gemini-2.0-flash", + model="gemini-2.5-flash", name="enterprise_assistant", instruction=""" Help the user with the tools available to you. diff --git a/contributing/samples/mcp_sse_agent/agent.py b/contributing/samples/mcp_sse_agent/agent.py index 2afbb930b1..8bddbca79f 100755 --- a/contributing/samples/mcp_sse_agent/agent.py +++ b/contributing/samples/mcp_sse_agent/agent.py @@ -28,7 +28,7 @@ ) root_agent = LlmAgent( - model='gemini-2.0-flash', + model='gemini-2.5-flash', name='enterprise_assistant', instruction=McpInstructionProvider( connection_params=connection_params, diff --git a/contributing/samples/mcp_stdio_notion_agent/agent.py b/contributing/samples/mcp_stdio_notion_agent/agent.py index 55ea56ec49..1a624f7ee1 100644 --- a/contributing/samples/mcp_stdio_notion_agent/agent.py +++ b/contributing/samples/mcp_stdio_notion_agent/agent.py @@ -29,7 +29,7 @@ }) root_agent = LlmAgent( - model="gemini-2.0-flash", + model="gemini-2.5-flash", name="notion_agent", instruction=( "You are my workspace assistant. " diff --git a/contributing/samples/mcp_stdio_server_agent/agent.py b/contributing/samples/mcp_stdio_server_agent/agent.py index 1799bd56d0..8b336ba281 100755 --- a/contributing/samples/mcp_stdio_server_agent/agent.py +++ b/contributing/samples/mcp_stdio_server_agent/agent.py @@ -23,7 +23,7 @@ _allowed_path = os.path.dirname(os.path.abspath(__file__)) root_agent = LlmAgent( - model='gemini-2.0-flash', + model='gemini-2.5-flash', name='enterprise_assistant', instruction=f"""\ Help user accessing their file systems. diff --git a/contributing/samples/mcp_streamablehttp_agent/agent.py b/contributing/samples/mcp_streamablehttp_agent/agent.py index e2223b0f13..1da782d04a 100644 --- a/contributing/samples/mcp_streamablehttp_agent/agent.py +++ b/contributing/samples/mcp_streamablehttp_agent/agent.py @@ -22,7 +22,7 @@ _allowed_path = os.path.dirname(os.path.abspath(__file__)) root_agent = LlmAgent( - model='gemini-2.0-flash', + model='gemini-2.5-flash', name='enterprise_assistant', instruction=f"""\ Help user accessing their file systems. diff --git a/contributing/samples/multi_agent_seq_config/README.md b/contributing/samples/multi_agent_seq_config/README.md index a2cd462465..af0dcee2fc 100644 --- a/contributing/samples/multi_agent_seq_config/README.md +++ b/contributing/samples/multi_agent_seq_config/README.md @@ -6,7 +6,7 @@ The whole process is: 1. An agent backed by a cheap and fast model to write initial version. 2. An agent backed by a smarter and a little more expensive to review the code. -3. An final agent backed by the smartest and slowest model to write the final revision. +3. A final agent backed by the smartest and slowest model to write the final revision. Sample queries: diff --git a/contributing/samples/spanner_rag_agent/README.md b/contributing/samples/spanner_rag_agent/README.md index 99b60794fe..08d134b990 100644 --- a/contributing/samples/spanner_rag_agent/README.md +++ b/contributing/samples/spanner_rag_agent/README.md @@ -181,7 +181,7 @@ type. ## 💬 Sample prompts -* I'd like to buy a starter bike for my 3 year old child, can you show me the recommendation? +* I'd like to buy a starter bike for my 3-year-old child, can you show me the recommendation? ![Spanner RAG Sample Agent](Spanner_RAG_Sample_Agent.png) diff --git a/pyproject.toml b/pyproject.toml index ad1da6fff1..fc50b24199 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,7 @@ test = [ "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent - "litellm>=1.75.5, <2.0.0", # For LiteLLM tests + "litellm>=1.75.5, <1.81.0", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests "openai>=1.100.2", # For LiteLLM "pytest-asyncio>=0.25.0", @@ -153,7 +153,7 @@ extensions = [ "docker>=7.0.0", # For ContainerCodeExecutor "kubernetes>=29.0.0", # For GkeCodeExecutor "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent - "litellm>=1.75.5", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it + "litellm>=1.75.5, <1.81.0", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex. "llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex. "lxml>=5.3.0", # For load_web_page tool. diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 514e31b272..759ac532fd 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -38,8 +38,6 @@ from ...agents.run_config import StreamingMode from ...agents.transcription_entry import TranscriptionEntry from ...events.event import Event -from ...features import FeatureName -from ...features import is_feature_enabled from ...models.base_llm_connection import BaseLlmConnection from ...models.llm_request import LlmRequest from ...models.llm_response import LlmResponse @@ -274,6 +272,25 @@ async def _send_to_model( await llm_connection.send_realtime(live_request.blob) if live_request.content: + content = live_request.content + # Persist user text content to session (similar to non-live mode) + # Skip function responses - they are already handled separately + is_function_response = content.parts and any( + part.function_response for part in content.parts + ) + if not is_function_response: + if not content.role: + content.role = 'user' + user_content_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author='user', + content=content, + ) + await invocation_context.session_service.append_event( + session=invocation_context.session, + event=user_content_event, + ) await llm_connection.send_content(live_request.content) async def _receive_from_model( @@ -393,8 +410,8 @@ async def _run_one_step_async( current_invocation=True, current_branch=True ) - # Long-running tool calls should have been handled before this point. - # If there are still long-running tool calls, it means the agent is paused + # Long running tool calls should have been handled before this point. + # If there are still long running tool calls, it means the agent is paused # before, and its branch hasn't been resumed yet. if ( invocation_context.is_resumable @@ -551,14 +568,11 @@ async def _postprocess_async( # Handles function calls. if model_response_event.get_function_calls(): - if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING): - # In progressive SSE streaming mode stage 1, we skip partial FC events - # Only execute FCs in the final aggregated event (partial=False) - if ( - invocation_context.run_config.streaming_mode == StreamingMode.SSE - and model_response_event.partial - ): - return + # Skip partial function call events - they should not trigger execution + # since partial events are not saved to session (see runners.py). + # Only execute function calls in the non-partial events. + if model_response_event.partial: + return async with Aclosing( self._postprocess_handle_function_calls_async( diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bfccbdc947..b931561c4d 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -815,6 +815,7 @@ async def _exec_with_plugin( await self.session_service.append_event( session=session, event=buffered_event ) + yield buffered_event # yield buffered events to caller buffered_events = [] else: # non-transcription event or empty transcription event, for diff --git a/src/google/adk/tools/mcp_tool/mcp_auth_utils.py b/src/google/adk/tools/mcp_tool/mcp_auth_utils.py deleted file mode 100644 index b074e67f15..0000000000 --- a/src/google/adk/tools/mcp_tool/mcp_auth_utils.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for MCP tool authentication.""" - -from __future__ import annotations - -import base64 -import logging -from typing import Dict -from typing import Optional - -from fastapi.openapi import models as openapi_models -from fastapi.openapi.models import APIKey -from fastapi.openapi.models import HTTPBase - -from ...auth.auth_credential import AuthCredential -from ...auth.auth_schemes import AuthScheme - -logger = logging.getLogger("google_adk." + __name__) - - -def get_mcp_auth_headers( - auth_scheme: Optional[AuthScheme], credential: Optional[AuthCredential] -) -> Optional[Dict[str, str]]: - """Generates HTTP authentication headers for MCP calls. - - Args: - auth_scheme: The authentication scheme. - credential: The resolved authentication credential. - - Returns: - A dictionary of headers, or None if no auth is applicable. - - Raises: - ValueError: If the auth scheme is unsupported or misconfigured. - """ - if not credential: - return None - - headers: Optional[Dict[str, str]] = None - - if credential.oauth2: - headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} - elif credential.http: - if not auth_scheme or not isinstance(auth_scheme, HTTPBase): - logger.warning( - "HTTP credential provided, but auth_scheme is missing or not" - " HTTPBase." - ) - return None - - scheme = auth_scheme.scheme.lower() - if scheme == "bearer" and credential.http.credentials.token: - headers = {"Authorization": f"Bearer {credential.http.credentials.token}"} - elif scheme == "basic": - if ( - credential.http.credentials.username - and credential.http.credentials.password - ): - creds = f"{credential.http.credentials.username}:{credential.http.credentials.password}" - encoded_creds = base64.b64encode(creds.encode()).decode() - headers = {"Authorization": f"Basic {encoded_creds}"} - else: - logger.warning("Basic auth scheme missing username or password.") - elif credential.http.credentials.token: - # Handle other HTTP schemes like Digest, etc. if token is present - headers = { - "Authorization": ( - f"{auth_scheme.scheme} {credential.http.credentials.token}" - ) - } - else: - logger.warning(f"Unsupported or incomplete HTTP auth scheme '{scheme}'.") - elif credential.api_key: - if not auth_scheme or not isinstance(auth_scheme, APIKey): - logger.warning( - "API key credential provided, but auth_scheme is missing or not" - " APIKey." - ) - return None - - if auth_scheme.in_ != openapi_models.APIKeyIn.header: - error_msg = ( - "MCP tools only support header-based API key authentication. " - f"Configured location: {auth_scheme.in_}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - headers = {auth_scheme.name: credential.api_key} - elif credential.service_account: - logger.warning( - "Service account credentials should be exchanged for an access token " - "before calling get_mcp_auth_headers." - ) - else: - logger.warning(f"Unsupported credential type: {type(credential)}") - - return headers diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 0b0c95ee35..a5b598fd81 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import base64 import inspect import logging from typing import Any @@ -23,6 +24,7 @@ from typing import Union import warnings +from fastapi.openapi.models import APIKeyIn from google.genai.types import FunctionDeclaration from mcp.types import Tool as McpBaseTool from typing_extensions import override @@ -37,7 +39,6 @@ from ..base_authenticated_tool import BaseAuthenticatedTool # import from ..tool_context import ToolContext -from .mcp_auth_utils import get_mcp_auth_headers from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_errors @@ -194,12 +195,7 @@ async def _run_async_impl( Any: The response from the tool. """ # Extract headers from credential for session pooling - auth_scheme = ( - self._auth_config.auth_scheme - if hasattr(self, "_auth_config") and self._auth_config - else None - ) - auth_headers = get_mcp_auth_headers(auth_scheme, credential) + auth_headers = await self._get_headers(tool_context, credential) dynamic_headers = None if self._header_provider: dynamic_headers = self._header_provider( @@ -221,6 +217,90 @@ async def _run_async_impl( response = await session.call_tool(self._mcp_tool.name, arguments=args) return response.model_dump(exclude_none=True, mode="json") + async def _get_headers( + self, tool_context: ToolContext, credential: AuthCredential + ) -> Optional[dict[str, str]]: + """Extracts authentication headers from credentials. + + Args: + tool_context: The tool context of the current invocation. + credential: The authentication credential to process. + + Returns: + Dictionary of headers to add to the request, or None if no auth. + + Raises: + ValueError: If API key authentication is configured for non-header location. + """ + headers: Optional[dict[str, str]] = None + if credential: + if credential.oauth2: + headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} + elif credential.http: + # Handle HTTP authentication schemes + if ( + credential.http.scheme.lower() == "bearer" + and credential.http.credentials.token + ): + headers = { + "Authorization": f"Bearer {credential.http.credentials.token}" + } + elif credential.http.scheme.lower() == "basic": + # Handle basic auth + if ( + credential.http.credentials.username + and credential.http.credentials.password + ): + + credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}" + encoded_credentials = base64.b64encode( + credentials.encode() + ).decode() + headers = {"Authorization": f"Basic {encoded_credentials}"} + elif credential.http.credentials.token: + # Handle other HTTP schemes with token + headers = { + "Authorization": ( + f"{credential.http.scheme} {credential.http.credentials.token}" + ) + } + elif credential.api_key: + if ( + not self._credentials_manager + or not self._credentials_manager._auth_config + ): + error_msg = ( + "Cannot find corresponding auth scheme for API key credential" + f" {credential}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + elif ( + self._credentials_manager._auth_config.auth_scheme.in_ + != APIKeyIn.header + ): + error_msg = ( + "McpTool only supports header-based API key authentication." + " Configured location:" + f" {self._credentials_manager._auth_config.auth_scheme.in_}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + else: + headers = { + self._credentials_manager._auth_config.auth_scheme.name: ( + credential.api_key + ) + } + elif credential.service_account: + # Service accounts should be exchanged for access tokens before reaching this point + logger.warning( + "Service account credentials should be exchanged before MCP" + " session creation" + ) + + return headers + class MCPTool(McpTool): """Deprecated name, use `McpTool` instead.""" diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 77abf20332..fed0684ed2 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -33,14 +33,11 @@ from ...agents.readonly_context import ReadonlyContext from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme -from ...auth.auth_tool import AuthConfig -from ...auth.credential_manager import CredentialManager from ..base_tool import BaseTool from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate from ..tool_configs import BaseToolConfig from ..tool_configs import ToolArgsConfig -from .mcp_auth_utils import get_mcp_auth_headers from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_errors from .mcp_session_manager import SseConnectionParams @@ -157,50 +154,13 @@ async def get_tools( Returns: List[BaseTool]: A list of tools available under the specified context. """ - provided_headers = ( + headers = ( self._header_provider(readonly_context) if self._header_provider and readonly_context - else {} + else None ) - - auth_headers = {} - if self._auth_scheme: - try: - # Instantiate CredentialsManager to resolve credentials - auth_config = AuthConfig( - auth_scheme=self._auth_scheme, - raw_auth_credential=self._auth_credential, - ) - credentials_manager = CredentialManager(auth_config) - - # Resolve the credential - resolved_credential = await credentials_manager.get_auth_credential( - readonly_context - ) - - if resolved_credential: - auth_headers = get_mcp_auth_headers( - self._auth_scheme, resolved_credential - ) - else: - logger.warning( - "Failed to resolve credential for tool listing, proceeding" - " without auth headers." - ) - except Exception as e: - logger.warning( - "Error generating auth headers for tool listing: %s, proceeding" - " without auth headers.", - e, - exc_info=True, - ) - - merged_headers = {**(provided_headers or {}), **(auth_headers or {})} - # Get session from session manager - session = await self._mcp_session_manager.create_session( - headers=merged_headers - ) + session = await self._mcp_session_manager.create_session(headers=headers) # Fetch available tools from the MCP server timeout_in_seconds = ( diff --git a/src/google/adk/tools/pubsub/__init__.py b/src/google/adk/tools/pubsub/__init__.py index aae36ae528..d488c317d9 100644 --- a/src/google/adk/tools/pubsub/__init__.py +++ b/src/google/adk/tools/pubsub/__init__.py @@ -14,7 +14,7 @@ """Pub/Sub Tools (Experimental). -Pub/Sub Tools under this module are hand crafted and customized while the tools +Pub/Sub Tools under this module are handcrafted and customized while the tools under google.adk.tools.google_api_tool are auto generated based on API definition. The rationales to have customized tool are: diff --git a/src/google/adk/tools/spanner/settings.py b/src/google/adk/tools/spanner/settings.py index dca8ef098b..ca7e05f1b1 100644 --- a/src/google/adk/tools/spanner/settings.py +++ b/src/google/adk/tools/spanner/settings.py @@ -115,7 +115,7 @@ class VectorSearchIndexSettings(BaseModel): """ num_branches: Optional[int] = None - """Optional. The number of branches to further parititon the vector data. + """Optional. The number of branches to further partition the vector data. You can only designate num_branches for trees with 3 levels. The number of branches must be fewer than the number of leaves @@ -165,7 +165,7 @@ class SpannerVectorStoreSettings(BaseModel): """Required. The vector store table columns to return in the vector similarity search result. By default, only the `content_column` value and the distance value are returned. - If sepecified, the list of selected columns and the distance value are returned. + If specified, the list of selected columns and the distance value are returned. For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value. """ diff --git a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py index 9b58d5bcbd..37dfbb95eb 100644 --- a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py +++ b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py @@ -631,3 +631,144 @@ async def process(): args = fc_part.function_call.args assert args["num"] == 100 assert args["s"] == "ADK" + + +class PartialFunctionCallMockModel(BaseLlm): + """A mock model that yields partial function call events followed by final.""" + + model: str = "partial-fc-mock" + tool_call_count: int = 0 + + @classmethod + def supported_models(cls) -> list[str]: + return ["partial-fc-mock"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + """Yield partial FC events then final, simulating streaming behavior.""" + + # Check if this is a follow-up call (after function response) + has_function_response = False + for content in llm_request.contents: + for part in content.parts or []: + if part.function_response: + has_function_response = True + break + + if has_function_response: + # Final response after function execution + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="Function executed once.")], + ), + partial=False, + ) + return + + # First call: yield partial FC events then final + # Partial event 1 + yield LlmResponse( + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name="track_execution", args={"call_id": "partial_1"} + ) + ], + ), + partial=True, + ) + + # Partial event 2 + yield LlmResponse( + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name="track_execution", args={"call_id": "partial_2"} + ) + ], + ), + partial=True, + ) + + # Final aggregated event (only this should trigger execution) + yield LlmResponse( + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name="track_execution", args={"call_id": "final"} + ) + ], + ), + partial=False, + finish_reason=types.FinishReason.STOP, + ) + + +def test_partial_function_calls_not_executed_in_none_streaming_mode(): + """Test that partial function call events are skipped regardless of mode.""" + execution_log = [] + + def track_execution(call_id: str) -> str: + """A tool that logs each execution to verify call count.""" + execution_log.append(call_id) + return f"Executed: {call_id}" + + mock_model = PartialFunctionCallMockModel() + + agent = Agent( + name="partial_fc_test_agent", + model=mock_model, + tools=[track_execution], + ) + + # Use StreamingMode.NONE to verify partial FCs are still skipped + run_config = RunConfig(streaming_mode=StreamingMode.NONE) + + runner = InMemoryRunner(agent=agent) + + session = runner.session_service.create_session_sync( + app_name=runner.app_name, user_id="test_user" + ) + + events = [] + for event in runner.run( + user_id="test_user", + session_id=session.id, + new_message=types.Content( + role="user", + parts=[types.Part.from_text(text="Test partial FC handling")], + ), + run_config=run_config, + ): + events.append(event) + + # Verify the tool was only executed once (from the final event) + assert ( + len(execution_log) == 1 + ), f"Expected 1 execution, got {len(execution_log)}: {execution_log}" + assert ( + execution_log[0] == "final" + ), f"Expected 'final' execution, got: {execution_log[0]}" + + # Verify partial events were yielded but not executed + partial_events = [e for e in events if e.partial] + assert ( + len(partial_events) == 2 + ), f"Expected 2 partial events, got {len(partial_events)}" + + # Verify there's a function response event (from the final FC execution) + function_response_events = [ + e + for e in events + if e.content + and e.content.parts + and any(p.function_response for p in e.content.parts) + ] + assert ( + len(function_response_events) == 1 + ), f"Expected 1 function response event, got {len(function_response_events)}" diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py index 57743d6076..e697eacdc0 100644 --- a/tests/unittests/streaming/test_streaming.py +++ b/tests/unittests/streaming/test_streaming.py @@ -1009,3 +1009,200 @@ async def consume_responses(session: testing_utils.Session): assert stock_call_found, 'Expected monitor_stock_price function call event.' assert video_call_found, 'Expected monitor_video_stream function call event.' + + +def test_live_streaming_buffered_function_call_yielded_during_transcription(): + """Test that function calls buffered during transcription are yielded. + + This tests the fix for the bug where function_call and function_response + events were buffered during active transcription but never yielded to the + caller. The fix ensures buffered events are yielded after transcription ends. + """ + function_call = types.Part.from_function_call( + name='get_weather', args={'location': 'San Francisco'} + ) + + response1 = LlmResponse( + input_transcription=types.Transcription(text='Show'), + partial=True, # ← Triggers is_transcribing = True + ) + response2 = LlmResponse( + content=types.Content( + role='model', parts=[function_call] + ), # ← Gets buffered + turn_complete=False, + ) + response3 = LlmResponse( + input_transcription=types.Transcription(text='Show me the weather'), + partial=False, # ← Transcription ends, buffered events yielded + ) + response4 = LlmResponse( + turn_complete=True, + ) + + mock_model = testing_utils.MockModel.create( + [response1, response2, response3, response4] + ) + + def get_weather(location: str) -> dict: + return {'temperature': 22, 'location': location} + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[get_weather], + ) + + class CustomTestRunner(testing_utils.InMemoryRunner): + + def run_live( + self, + live_request_queue: LiveRequestQueue, + run_config: testing_utils.RunConfig = None, + ) -> list[testing_utils.Event]: + collected_responses = [] + + async def consume_responses(session: testing_utils.Session): + run_res = self.runner.run_live( + session=session, + live_request_queue=live_request_queue, + run_config=run_config or testing_utils.RunConfig(), + ) + + async for response in run_res: + collected_responses.append(response) + if len(collected_responses) >= 5: + return + + try: + session = self.session + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + return collected_responses + + runner = CustomTestRunner(root_agent=root_agent) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=types.Blob(data=b'Show me the weather', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue) + + assert res_events is not None, 'Expected a list of events, got None.' + assert len(res_events) >= 1, 'Expected at least one event.' + + function_call_found = False + function_response_found = False + + for event in res_events: + if event.content and event.content.parts: + for part in event.content.parts: + if part.function_call and part.function_call.name == 'get_weather': + function_call_found = True + assert part.function_call.args['location'] == 'San Francisco' + if ( + part.function_response + and part.function_response.name == 'get_weather' + ): + function_response_found = True + assert part.function_response.response['temperature'] == 22 + + assert function_call_found, 'Buffered function_call event was not yielded.' + assert ( + function_response_found + ), 'Buffered function_response event was not yielded.' + + +def test_live_streaming_text_content_persisted_in_session(): + """Test that user text content sent via send_content is persisted in session.""" + response1 = LlmResponse( + content=types.Content( + role='model', parts=[types.Part(text='Hello! How can I help you?')] + ), + turn_complete=True, + ) + + mock_model = testing_utils.MockModel.create([response1]) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[], + ) + + class CustomTestRunner(testing_utils.InMemoryRunner): + + def run_live_and_get_session( + self, + live_request_queue: LiveRequestQueue, + run_config: testing_utils.RunConfig = None, + ) -> tuple[list[testing_utils.Event], testing_utils.Session]: + collected_responses = [] + + async def consume_responses(session: testing_utils.Session): + run_res = self.runner.run_live( + session=session, + live_request_queue=live_request_queue, + run_config=run_config or testing_utils.RunConfig(), + ) + async for response in run_res: + collected_responses.append(response) + if len(collected_responses) >= 1: + return + + try: + session = self.session + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + # Get the updated session + updated_session = self.runner.session_service.get_session_sync( + app_name=self.app_name, + user_id=session.user_id, + session_id=session.id, + ) + return collected_responses, updated_session + + runner = CustomTestRunner(root_agent=root_agent) + live_request_queue = LiveRequestQueue() + + # Send text content (not audio blob) + user_text = 'Hello, this is a test message' + live_request_queue.send_content( + types.Content(role='user', parts=[types.Part(text=user_text)]) + ) + + res_events, session = runner.run_live_and_get_session(live_request_queue) + + assert res_events is not None, 'Expected a list of events, got None.' + + # Check that user text content was persisted in the session + user_content_found = False + for event in session.events: + if event.author == 'user' and event.content: + for part in event.content.parts: + if part.text and user_text in part.text: + user_content_found = True + break + + assert user_content_found, ( + f'Expected user text content "{user_text}" to be persisted in session. ' + f'Session events: {[e.content for e in session.events]}' + ) diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py index f76668b165..4f9b8636fe 100644 --- a/tests/unittests/testing_utils.py +++ b/tests/unittests/testing_utils.py @@ -409,6 +409,10 @@ async def send_realtime(self, blob: types.Blob): async def receive(self) -> AsyncGenerator[LlmResponse, None]: """Yield each of the pre-defined LlmResponses.""" for response in self.llm_responses: + # Yield control to allow other tasks (like send_task) to run first. + # This ensures user content gets persisted before the mock response + # is yielded. + await asyncio.sleep(0) yield response async def close(self): diff --git a/tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py b/tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py deleted file mode 100644 index 9e0988467e..0000000000 --- a/tests/unittests/tools/mcp_tool/test_mcp_auth_utils.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -from unittest.mock import patch - -from fastapi.openapi import models as openapi_models -from google.adk.auth.auth_credential import AuthCredential -from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.auth.auth_credential import HttpAuth -from google.adk.auth.auth_credential import HttpCredentials -from google.adk.auth.auth_credential import OAuth2Auth -from google.adk.auth.auth_credential import ServiceAccount -from google.adk.auth.auth_schemes import AuthSchemeType -from google.adk.tools.mcp_tool import mcp_auth_utils -import pytest - - -def test_get_mcp_auth_headers_no_credential(): - """Test header generation with no credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=None - ) - assert headers is None - - -def test_get_mcp_auth_headers_no_auth_scheme(): - """Test header generation with no auth_scheme.""" - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(access_token="test_token"), - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=None, credential=credential - ) - assert headers == {"Authorization": "Bearer test_token"} - - -def test_get_mcp_auth_headers_oauth2(): - """Test header generation for OAuth2 credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(access_token="test_token"), - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers == {"Authorization": "Bearer test_token"} - - -def test_get_mcp_auth_headers_http_bearer(): - """Test header generation for HTTP Bearer credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", credentials=HttpCredentials(token="bearer_token") - ), - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers == {"Authorization": "Bearer bearer_token"} - - -def test_get_mcp_auth_headers_http_basic(): - """Test header generation for HTTP Basic credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="basic") - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="basic", - credentials=HttpCredentials(username="user", password="pass"), - ), - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - expected_encoded = base64.b64encode(b"user:pass").decode() - assert headers == {"Authorization": f"Basic {expected_encoded}"} - - -def test_get_mcp_auth_headers_http_basic_missing_credentials(): - """Test header generation for HTTP Basic with missing credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="basic") - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="basic", - credentials=HttpCredentials(username="user", password=None), - ), - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers is None - mock_logger.warning.assert_called_once_with( - "Basic auth scheme missing username or password." - ) - - -def test_get_mcp_auth_headers_http_custom_scheme(): - """Test header generation for custom HTTP scheme.""" - auth_scheme = openapi_models.HTTPBase(scheme="custom") - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="custom", credentials=HttpCredentials(token="custom_token") - ), - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers == {"Authorization": "custom custom_token"} - - -def test_get_mcp_auth_headers_http_cred_wrong_scheme(): - """Test HTTP credential with non-HTTPBase auth scheme.""" - auth_scheme = openapi_models.APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": openapi_models.APIKeyIn.header, - "name": "X-API-Key", - }) - credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, - http=HttpAuth( - scheme="bearer", credentials=HttpCredentials(token="bearer_token") - ), - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers is None - mock_logger.warning.assert_called_once_with( - "HTTP credential provided, but auth_scheme is missing or not HTTPBase." - ) - - -def test_get_mcp_auth_headers_api_key_header(): - """Test header generation for API Key in header.""" - auth_scheme = openapi_models.APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": openapi_models.APIKeyIn.header, - "name": "X-Custom-API-Key", - }) - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers == {"X-Custom-API-Key": "my_api_key"} - - -def test_get_mcp_auth_headers_api_key_query_raises_error(): - """Test API Key in query raises ValueError.""" - auth_scheme = openapi_models.APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": openapi_models.APIKeyIn.query, - "name": "api_key", - }) - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - with pytest.raises( - ValueError, - match="MCP tools only support header-based API key authentication.", - ): - mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - - -def test_get_mcp_auth_headers_api_key_cookie_raises_error(): - """Test API Key in cookie raises ValueError.""" - auth_scheme = openapi_models.APIKey(**{ - "type": AuthSchemeType.apiKey, - "in": openapi_models.APIKeyIn.cookie, - "name": "session_id", - }) - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - with pytest.raises( - ValueError, - match="MCP tools only support header-based API key authentication.", - ): - mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - - -def test_get_mcp_auth_headers_api_key_cred_wrong_scheme(): - """Test API key credential with non-APIKey auth scheme.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - credential = AuthCredential( - auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers is None - mock_logger.warning.assert_called_once_with( - "API key credential provided, but auth_scheme is missing or not APIKey." - ) - - -def test_get_mcp_auth_headers_service_account(): - """Test header generation for service account credentials.""" - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - credential = AuthCredential( - auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, - service_account=ServiceAccount(scopes=["test"]), - ) - with patch.object(mcp_auth_utils, "logger") as mock_logger: - headers = mcp_auth_utils.get_mcp_auth_headers( - auth_scheme=auth_scheme, credential=credential - ) - assert headers is None - mock_logger.warning.assert_called_once_with( - "Service account credentials should be exchanged for an access " - "token before calling get_mcp_auth_headers." - ) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index cb13a4e541..0bf28cb3c2 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -18,7 +18,10 @@ from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_credential import ServiceAccount from google.adk.features import FeatureName from google.adk.features._feature_registry import temporary_feature_override from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager @@ -258,6 +261,240 @@ async def test_run_async_impl_with_oauth2(self): headers = call_args[1]["headers"] assert headers == {"Authorization": "Bearer test_access_token"} + @pytest.mark.asyncio + async def test_get_headers_oauth2(self): + """Test header generation for OAuth2 credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + oauth2_auth = OAuth2Auth(access_token="test_token") + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer test_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_bearer(self): + """Test header generation for HTTP Bearer credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="bearer", credentials=HttpCredentials(token="bearer_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "Bearer bearer_token"} + + @pytest.mark.asyncio + async def test_get_headers_http_basic(self): + """Test header generation for HTTP Basic credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="basic", + credentials=HttpCredentials(username="user", password="pass"), + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should create Basic auth header with base64 encoded credentials + import base64 + + expected_encoded = base64.b64encode(b"user:pass").decode() + assert headers == {"Authorization": f"Basic {expected_encoded}"} + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_valid_header_scheme(self): + """Test header generation for API Key credentials with header-based auth scheme.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for header-based API key + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.header, + "name": "X-Custom-API-Key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, auth_credential) + + assert headers == {"X-Custom-API-Key": "my_api_key"} + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_query_scheme_raises_error(self): + """Test that API Key with query-based auth scheme raises ValueError.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for query-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.query, + "name": "api_key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="McpTool only supports header-based API key authentication", + ): + await tool._get_headers(tool_context, auth_credential) + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_cookie_scheme_raises_error(self): + """Test that API Key with cookie-based auth scheme raises ValueError.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for cookie-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.cookie, + "name": "session_id", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="McpTool only supports header-based API key authentication", + ): + await tool._get_headers(tool_context, auth_credential) + + @pytest.mark.asyncio + async def test_get_headers_api_key_without_auth_config_raises_error(self): + """Test that API Key without auth config raises ValueError.""" + # Create tool without auth scheme/config + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="Cannot find corresponding auth scheme for API key credential", + ): + await tool._get_headers(tool_context, credential) + + @pytest.mark.asyncio + async def test_get_headers_api_key_without_credentials_manager_raises_error( + self, + ): + """Test that API Key without credentials manager raises ValueError.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Manually set credentials manager to None to simulate error condition + tool._credentials_manager = None + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="Cannot find corresponding auth scheme for API key credential", + ): + await tool._get_headers(tool_context, credential) + + @pytest.mark.asyncio + async def test_get_headers_no_credential(self): + """Test header generation with no credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, None) + + assert headers is None + + @pytest.mark.asyncio + async def test_get_headers_service_account(self): + """Test header generation for service account credentials.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + # Create service account credential + service_account = ServiceAccount(scopes=["test"]) + credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=service_account, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + # Should return None as service account credentials are not supported for direct header generation + assert headers is None + @pytest.mark.asyncio async def test_run_async_impl_with_api_key_header_auth(self): """Test running tool with API key header authentication end-to-end.""" @@ -314,6 +551,65 @@ async def test_run_async_impl_retry_decorator(self): # Check that the method has the retry decorator assert hasattr(tool._run_async_impl, "__wrapped__") + @pytest.mark.asyncio + async def test_get_headers_http_custom_scheme(self): + """Test header generation for custom HTTP scheme.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + http_auth = HttpAuth( + scheme="custom", credentials=HttpCredentials(token="custom_token") + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=http_auth + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, credential) + + assert headers == {"Authorization": "custom custom_token"} + + @pytest.mark.asyncio + async def test_get_headers_api_key_error_logging(self): + """Test that API key errors are logged correctly.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for query-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.query, + "name": "api_key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + + # Test with logging + with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger: + with pytest.raises(ValueError): + await tool._get_headers(tool_context, auth_credential) + + # Verify error was logged + mock_logger.error.assert_called_once() + logged_message = mock_logger.error.call_args[0][0] + assert ( + "McpTool only supports header-based API key authentication" + in logged_message + ) + @pytest.mark.asyncio async def test_run_async_require_confirmation_true_no_confirmation(self): """Test require_confirmation=True with no confirmation in context.""" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 34d8c5b7fb..83b112d8b9 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -30,8 +30,6 @@ from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from google.adk.tools.mcp_tool.mcp_toolset import McpToolset -from google.adk.tools.mcp_tool.mcp_toolset import McpToolsetConfig -from google.adk.tools.tool_configs import ToolArgsConfig from mcp import StdioServerParameters import pytest @@ -247,94 +245,6 @@ async def test_get_tools_with_header_provider(self): headers=expected_headers ) - @pytest.mark.asyncio - async def test_get_tools_with_auth_headers(self): - """Test get_tools with auth headers.""" - from fastapi.openapi import models as openapi_models - from google.adk.auth.auth_credential import AuthCredentialTypes - from google.adk.auth.auth_credential import OAuth2Auth - - mock_tools = [MockMCPTool("tool1")] - self.mock_session.list_tools = AsyncMock( - return_value=MockListToolsResult(mock_tools) - ) - mock_readonly_context = Mock(spec=ReadonlyContext) - - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(access_token="test_token"), - ) - - with patch( - "google.adk.tools.mcp_tool.mcp_toolset.CredentialManager" - ) as MockCredentialManager: - mock_manager_instance = MockCredentialManager.return_value - mock_manager_instance.get_auth_credential = AsyncMock( - return_value=auth_credential - ) - - toolset = MCPToolset( - connection_params=self.mock_stdio_params, - auth_scheme=auth_scheme, - auth_credential=auth_credential, - ) - toolset._mcp_session_manager = self.mock_session_manager - - await toolset.get_tools(readonly_context=mock_readonly_context) - - self.mock_session_manager.create_session.assert_called_once() - call_args = self.mock_session_manager.create_session.call_args - headers = call_args[1]["headers"] - assert headers == {"Authorization": "Bearer test_token"} - - @pytest.mark.asyncio - async def test_get_tools_with_auth_and_header_provider(self): - """Test get_tools with auth and header_provider.""" - from fastapi.openapi import models as openapi_models - from google.adk.auth.auth_credential import AuthCredentialTypes - from google.adk.auth.auth_credential import OAuth2Auth - - mock_tools = [MockMCPTool("tool1")] - self.mock_session.list_tools = AsyncMock( - return_value=MockListToolsResult(mock_tools) - ) - mock_readonly_context = Mock(spec=ReadonlyContext) - provided_headers = {"X-Tenant-ID": "test-tenant"} - header_provider = Mock(return_value=provided_headers) - - auth_scheme = openapi_models.HTTPBase(scheme="bearer") - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(access_token="test_token"), - ) - - with patch( - "google.adk.tools.mcp_tool.mcp_toolset.CredentialManager" - ) as MockCredentialManager: - mock_manager_instance = MockCredentialManager.return_value - mock_manager_instance.get_auth_credential = AsyncMock( - return_value=auth_credential - ) - - toolset = MCPToolset( - connection_params=self.mock_stdio_params, - auth_scheme=auth_scheme, - auth_credential=auth_credential, - header_provider=header_provider, - ) - toolset._mcp_session_manager = self.mock_session_manager - - await toolset.get_tools(readonly_context=mock_readonly_context) - - self.mock_session_manager.create_session.assert_called_once() - call_args = self.mock_session_manager.create_session.call_args - headers = call_args[1]["headers"] - assert headers == { - "X-Tenant-ID": "test-tenant", - "Authorization": "Bearer test_token", - } - @pytest.mark.asyncio async def test_close_success(self): """Test successful cleanup."""