diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 97eb85dfdc..64343302fe 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -1015,12 +1015,15 @@ async def run_live( # Pre-processing for live streaming tools # Inspect the tool's parameters to find if it uses LiveRequestQueue invocation_context.active_streaming_tools = {} - # TODO(hangfei): switch to use canonical_tools. - # for shell agents, there is no tools associated with it so we should skip. - if hasattr(invocation_context.agent, 'tools'): + # For shell agents, there is no canonical_tools method so we should skip. + if hasattr(invocation_context.agent, 'canonical_tools'): import inspect - for tool in invocation_context.agent.tools: + # Use canonical_tools to get properly wrapped BaseTool instances + canonical_tools = await invocation_context.agent.canonical_tools( + invocation_context + ) + for tool in canonical_tools: # We use `inspect.signature()` to examine the tool's underlying function (`tool.func`). # This approach is deliberately chosen over `typing.get_type_hints()` for robustness. # @@ -1044,10 +1047,14 @@ async def run_live( if param.annotation is LiveRequestQueue: if not invocation_context.active_streaming_tools: invocation_context.active_streaming_tools = {} + + logger.debug( + 'Register streaming tool with input stream: %s', tool.name + ) active_streaming_tool = ActiveStreamingTool( stream=LiveRequestQueue() ) - invocation_context.active_streaming_tools[tool.__name__] = ( + invocation_context.active_streaming_tools[tool.name] = ( active_streaming_tool ) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index ea40bee0c3..2b82b663be 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -207,7 +207,7 @@ async def run_async( # to avoid "Attempted to exit cancel scope in a different task" errors await runner.close() - if not last_content: + if last_content is None or last_content.parts is None: return '' merged_text = '\n'.join( p.text for p in last_content.parts if p.text and not p.thought diff --git a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py index 0f1a6895d8..a32f43bab8 100644 --- a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py +++ b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py @@ -25,6 +25,8 @@ from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme +from ...features import FeatureName +from ...features import is_feature_enabled from .._gemini_schema_util import _to_gemini_schema from ..base_tool import BaseTool from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool @@ -125,10 +127,17 @@ def _get_declaration(self) -> FunctionDeclaration: if field in schema_dict['required']: schema_dict['required'].remove(field) - parameters = _to_gemini_schema(schema_dict) - function_decl = FunctionDeclaration( - name=self.name, description=self.description, parameters=parameters - ) + if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL): + function_decl = FunctionDeclaration( + name=self.name, + description=self.description, + parameters_json_schema=schema_dict, + ) + else: + parameters = _to_gemini_schema(schema_dict) + function_decl = FunctionDeclaration( + name=self.name, description=self.description, parameters=parameters + ) return function_decl def _prepare_dynamic_euc(self, auth_credential: AuthCredential) -> str: diff --git a/src/google/adk/tools/google_search_agent_tool.py b/src/google/adk/tools/google_search_agent_tool.py index 77cb6fedf9..c88a986b29 100644 --- a/src/google/adk/tools/google_search_agent_tool.py +++ b/src/google/adk/tools/google_search_agent_tool.py @@ -123,7 +123,7 @@ async def run_async( last_content = event.content last_grounding_metadata = event.grounding_metadata - if not last_content: + if last_content is None or last_content.parts is None: return '' merged_text = '\n'.join(p.text for p in last_content.parts if p.text) if isinstance(self.agent, LlmAgent) and self.agent.output_schema: diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index bb44ce73d6..c876bff53a 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -23,6 +23,7 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App @@ -34,6 +35,7 @@ from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.session import Session +from google.adk.tools.function_tool import FunctionTool from google.genai import types import pytest @@ -358,6 +360,88 @@ async def test_run_live_auto_create_session(): assert session is not None +@pytest.mark.asyncio +async def test_run_live_detects_streaming_tools_with_canonical_tools(): + """run_live should detect streaming tools using canonical_tools and tool.name.""" + + # Define streaming tools - one as raw function, one wrapped in FunctionTool + async def raw_streaming_tool( + input_stream: LiveRequestQueue, + ) -> AsyncGenerator[str, None]: + """A raw streaming tool function.""" + yield "test" + + async def wrapped_streaming_tool( + input_stream: LiveRequestQueue, + ) -> AsyncGenerator[str, None]: + """A streaming tool wrapped in FunctionTool.""" + yield "test" + + def non_streaming_tool(param: str) -> str: + """A regular non-streaming tool.""" + return param + + # Create a mock LlmAgent that yields an event and captures invocation context + captured_context = {} + + class StreamingToolsAgent(LlmAgent): + + async def _run_live_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: + # Capture the active_streaming_tools for verification + captured_context["active_streaming_tools"] = ( + invocation_context.active_streaming_tools + ) + yield Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + content=types.Content( + role="model", parts=[types.Part(text="streaming test")] + ), + ) + + agent = StreamingToolsAgent( + name="streaming_agent", + model="gemini-2.0-flash", + tools=[ + raw_streaming_tool, # Raw function + FunctionTool(wrapped_streaming_tool), # Wrapped in FunctionTool + non_streaming_tool, # Non-streaming tool (should not be detected) + ], + ) + + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name="streaming_test_app", + agent=agent, + session_service=session_service, + artifact_service=artifact_service, + auto_create_session=True, + ) + + live_queue = LiveRequestQueue() + + agen = runner.run_live( + user_id="user", + session_id="test_session", + live_request_queue=live_queue, + ) + + event = await agen.__anext__() + await agen.aclose() + + assert event.author == "streaming_agent" + + # Verify streaming tools were detected correctly + active_tools = captured_context.get("active_streaming_tools", {}) + assert "raw_streaming_tool" in active_tools + assert "wrapped_streaming_tool" in active_tools + # Non-streaming tool should not be detected + assert "non_streaming_tool" not in active_tools + + @pytest.mark.asyncio async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch): project_root = tmp_path / "workspace" diff --git a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py index f70af0601e..d5b8407d9f 100644 --- a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py +++ b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py @@ -18,6 +18,8 @@ from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import HttpAuth from google.adk.auth.auth_credential import HttpCredentials +from google.adk.features import FeatureName +from google.adk.features._feature_registry import temporary_feature_override from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import AuthPreparationResult @@ -254,3 +256,23 @@ async def test_run_with_auth_async( args=expected_call_args, tool_context={} ) assert result == {"status": "success", "data": "mock_data"} + + +def test_get_declaration_with_json_schema_feature_enabled(integration_tool): + """Tests the generation of the function declaration with JSON schema feature enabled.""" + with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True): + declaration = integration_tool._get_declaration() + + assert isinstance(declaration, FunctionDeclaration) + assert declaration.name == "test_integration_tool" + assert declaration.description == "Test integration tool description." + assert declaration.parameters is None + assert declaration.parameters_json_schema == { + "type": "object", + "properties": { + "user_id": {"type": "string", "description": "User ID"}, + "page_size": {"type": "integer"}, + "filter": {"type": "string"}, + }, + "required": ["user_id"], + } diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index 48a7a995bb..902318715b 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -900,3 +900,45 @@ class CustomOutput(BaseModel): }, 'response_json_schema': {'type': 'object'}, } + + +@mark.asyncio +async def test_run_async_handles_none_parts_in_response(): + """Verify run_async handles None parts in response without raising TypeError.""" + + # Mock model for the tool_agent that returns content with parts=None + # This simulates the condition causing the TypeError + tool_agent_model = testing_utils.MockModel.create( + responses=[ + LlmResponse( + content=types.Content(parts=None), + ) + ] + ) + + tool_agent = Agent( + name='tool_agent', + model=tool_agent_model, + ) + + agent_tool = AgentTool(agent=tool_agent) + + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=tool_agent, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context=invocation_context) + + # This should not raise `TypeError: 'NoneType' object is not iterable`. + tool_result = await agent_tool.run_async( + args={'request': 'test request'}, tool_context=tool_context + ) + + assert tool_result == ''