Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,12 +1015,15 @@ async def run_live(
# Pre-processing for live streaming tools
# Inspect the tool's parameters to find if it uses LiveRequestQueue
invocation_context.active_streaming_tools = {}
# TODO(hangfei): switch to use canonical_tools.
# for shell agents, there is no tools associated with it so we should skip.
if hasattr(invocation_context.agent, 'tools'):
# For shell agents, there is no canonical_tools method so we should skip.
if hasattr(invocation_context.agent, 'canonical_tools'):
import inspect

for tool in invocation_context.agent.tools:
# Use canonical_tools to get properly wrapped BaseTool instances
canonical_tools = await invocation_context.agent.canonical_tools(
invocation_context
)
for tool in canonical_tools:
# We use `inspect.signature()` to examine the tool's underlying function (`tool.func`).
# This approach is deliberately chosen over `typing.get_type_hints()` for robustness.
#
Expand All @@ -1044,10 +1047,14 @@ async def run_live(
if param.annotation is LiveRequestQueue:
if not invocation_context.active_streaming_tools:
invocation_context.active_streaming_tools = {}

logger.debug(
'Register streaming tool with input stream: %s', tool.name
)
active_streaming_tool = ActiveStreamingTool(
stream=LiveRequestQueue()
)
invocation_context.active_streaming_tools[tool.__name__] = (
invocation_context.active_streaming_tools[tool.name] = (
active_streaming_tool
)

Expand Down
2 changes: 1 addition & 1 deletion src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ async def run_async(
# to avoid "Attempted to exit cancel scope in a different task" errors
await runner.close()

if not last_content:
if last_content is None or last_content.parts is None:
return ''
merged_text = '\n'.join(
p.text for p in last_content.parts if p.text and not p.thought
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from ...features import FeatureName
from ...features import is_feature_enabled
from .._gemini_schema_util import _to_gemini_schema
from ..base_tool import BaseTool
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
Expand Down Expand Up @@ -125,10 +127,17 @@ def _get_declaration(self) -> FunctionDeclaration:
if field in schema_dict['required']:
schema_dict['required'].remove(field)

parameters = _to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
function_decl = FunctionDeclaration(
name=self.name,
description=self.description,
parameters_json_schema=schema_dict,
)
else:
parameters = _to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl

def _prepare_dynamic_euc(self, auth_credential: AuthCredential) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/google/adk/tools/google_search_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def run_async(
last_content = event.content
last_grounding_metadata = event.grounding_metadata

if not last_content:
if last_content is None or last_content.parts is None:
return ''
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
Expand Down
84 changes: 84 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.context_cache_config import ContextCacheConfig
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.live_request_queue import LiveRequestQueue
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.apps.app import App
Expand All @@ -34,6 +35,7 @@
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.adk.tools.function_tool import FunctionTool
from google.genai import types
import pytest

Expand Down Expand Up @@ -358,6 +360,88 @@ async def test_run_live_auto_create_session():
assert session is not None


@pytest.mark.asyncio
async def test_run_live_detects_streaming_tools_with_canonical_tools():
"""run_live should detect streaming tools using canonical_tools and tool.name."""

# Define streaming tools - one as raw function, one wrapped in FunctionTool
async def raw_streaming_tool(
input_stream: LiveRequestQueue,
) -> AsyncGenerator[str, None]:
"""A raw streaming tool function."""
yield "test"

async def wrapped_streaming_tool(
input_stream: LiveRequestQueue,
) -> AsyncGenerator[str, None]:
"""A streaming tool wrapped in FunctionTool."""
yield "test"

def non_streaming_tool(param: str) -> str:
"""A regular non-streaming tool."""
return param

# Create a mock LlmAgent that yields an event and captures invocation context
captured_context = {}

class StreamingToolsAgent(LlmAgent):

async def _run_live_impl(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
# Capture the active_streaming_tools for verification
captured_context["active_streaming_tools"] = (
invocation_context.active_streaming_tools
)
yield Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
content=types.Content(
role="model", parts=[types.Part(text="streaming test")]
),
)

agent = StreamingToolsAgent(
name="streaming_agent",
model="gemini-2.0-flash",
tools=[
raw_streaming_tool, # Raw function
FunctionTool(wrapped_streaming_tool), # Wrapped in FunctionTool
non_streaming_tool, # Non-streaming tool (should not be detected)
],
)

session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name="streaming_test_app",
agent=agent,
session_service=session_service,
artifact_service=artifact_service,
auto_create_session=True,
)

live_queue = LiveRequestQueue()

agen = runner.run_live(
user_id="user",
session_id="test_session",
live_request_queue=live_queue,
)

event = await agen.__anext__()
await agen.aclose()

assert event.author == "streaming_agent"

# Verify streaming tools were detected correctly
active_tools = captured_context.get("active_streaming_tools", {})
assert "raw_streaming_tool" in active_tools
assert "wrapped_streaming_tool" in active_tools
# Non-streaming tool should not be detected
assert "non_streaming_tool" not in active_tools


@pytest.mark.asyncio
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
project_root = tmp_path / "workspace"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import AuthPreparationResult
Expand Down Expand Up @@ -254,3 +256,23 @@ async def test_run_with_auth_async(
args=expected_call_args, tool_context={}
)
assert result == {"status": "success", "data": "mock_data"}


def test_get_declaration_with_json_schema_feature_enabled(integration_tool):
"""Tests the generation of the function declaration with JSON schema feature enabled."""
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True):
declaration = integration_tool._get_declaration()

assert isinstance(declaration, FunctionDeclaration)
assert declaration.name == "test_integration_tool"
assert declaration.description == "Test integration tool description."
assert declaration.parameters is None
assert declaration.parameters_json_schema == {
"type": "object",
"properties": {
"user_id": {"type": "string", "description": "User ID"},
"page_size": {"type": "integer"},
"filter": {"type": "string"},
},
"required": ["user_id"],
}
42 changes: 42 additions & 0 deletions tests/unittests/tools/test_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,3 +900,45 @@ class CustomOutput(BaseModel):
},
'response_json_schema': {'type': 'object'},
}


@mark.asyncio
async def test_run_async_handles_none_parts_in_response():
"""Verify run_async handles None parts in response without raising TypeError."""

# Mock model for the tool_agent that returns content with parts=None
# This simulates the condition causing the TypeError
tool_agent_model = testing_utils.MockModel.create(
responses=[
LlmResponse(
content=types.Content(parts=None),
)
]
)

tool_agent = Agent(
name='tool_agent',
model=tool_agent_model,
)

agent_tool = AgentTool(agent=tool_agent)

session_service = InMemorySessionService()
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)

invocation_context = InvocationContext(
invocation_id='invocation_id',
agent=tool_agent,
session=session,
session_service=session_service,
)
tool_context = ToolContext(invocation_context=invocation_context)

# This should not raise `TypeError: 'NoneType' object is not iterable`.
tool_result = await agent_tool.run_async(
args={'request': 'test request'}, tool_context=tool_context
)

assert tool_result == ''
Loading