2 changes: 1 addition & 1 deletion pyproject.toml
@@ -157,7 +157,7 @@ extensions = [
"llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
"llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex.
"lxml>=5.3.0", # For load_web_page tool.
"toolbox-core>=0.1.0", # For tools.toolbox_toolset.ToolboxToolset
"toolbox-adk>=0.1.0", # For tools.toolbox_toolset.ToolboxToolset
]

otel-gcp = ["opentelemetry-instrumentation-google-genai>=0.3b0, <1.0.0"]
8 changes: 7 additions & 1 deletion src/google/adk/flows/llm_flows/audio_cache_manager.py
@@ -181,10 +181,16 @@ async def _flush_cache_to_services(
artifact_ref = f'artifact://{invocation_context.app_name}/{invocation_context.user_id}/{invocation_context.session.id}/_adk_live/{filename}#{revision_id}'

# Create event with file data reference to add to session
# For model events, author should be the agent name, not the role
author = (
invocation_context.agent.name
if audio_cache[0].role == 'model'
else audio_cache[0].role
)
audio_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author=audio_cache[0].role,
author=author,
content=types.Content(
role=audio_cache[0].role,
parts=[
15 changes: 8 additions & 7 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -633,6 +633,12 @@ async def _postprocess_live(
for event in flushed_events:
yield event
if flushed_events:
# NOTE: Returning here is OK for now because we currently only flush
# events on `interrupted` or `turn_complete`. `turn_complete` is a pure
# control event, and although `interrupted` may carry content, that
# content is ignorable because the model has already been interrupted.
# If we later flush events in cases that are not pure control events,
# we should not return here.
return

# Builds the event.
@@ -968,13 +974,8 @@ async def _handle_control_event_flush(
flush_user_audio=True,
flush_model_audio=True,
)
elif getattr(llm_response, 'generation_complete', False):
# model generation complete so we can flush model audio
return await self.audio_cache_manager.flush_caches(
invocation_context,
flush_user_audio=False,
flush_model_audio=True,
)
# TODO: Once generation_complete is surfaced on LlmResponse, we can flush
# model audio here (flush_user_audio=False, flush_model_audio=True).
return []

async def _run_and_handle_error(
src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py
@@ -33,6 +33,8 @@
from ....agents.readonly_context import ReadonlyContext
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ....features import FeatureName
from ....features import is_feature_enabled
from ..._gemini_schema_util import _to_gemini_schema
from ..._gemini_schema_util import _to_snake_case
from ...base_tool import BaseTool
@@ -221,10 +223,17 @@ def from_parsed_operation_str(
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self._operation_parser.get_json_schema()
parameters = _to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
function_decl = FunctionDeclaration(
name=self.name,
description=self.description,
parameters_json_schema=schema_dict,
)
else:
parameters = _to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl

def configure_auth_scheme(
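For context, a minimal sketch (not part of the diff) of what the two declaration paths above produce. It assumes a `google-genai` version that supports `parameters_json_schema`; the `get_weather` schema is hypothetical, and the explicit `Schema` construction stands in for what `_to_gemini_schema` would emit for this simple case.

```python
from google.genai import types

# A hypothetical JSON schema, as OperationParser.get_json_schema() might return.
schema_dict = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}

# Feature enabled: the raw JSON schema is passed through unchanged.
decl_json = types.FunctionDeclaration(
    name="get_weather",
    description="Look up the weather for a city.",
    parameters_json_schema=schema_dict,
)

# Feature disabled: the schema is converted to the Gemini Schema type first.
decl_gemini = types.FunctionDeclaration(
    name="get_weather",
    description="Look up the weather for a city.",
    parameters=types.Schema(
        type=types.Type.OBJECT,
        properties={"city": types.Schema(type=types.Type.STRING)},
        required=["city"],
    ),
)
```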
104 changes: 55 additions & 49 deletions src/google/adk/tools/toolbox_toolset.py
@@ -12,29 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any
from typing import Callable
from typing import List
from typing import Mapping
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

import toolbox_core as toolbox
from typing_extensions import override

from ..agents.readonly_context import ReadonlyContext
from .base_tool import BaseTool
from .base_toolset import BaseToolset
from .function_tool import FunctionTool

if TYPE_CHECKING:
from toolbox_adk import CredentialConfig


class ToolboxToolset(BaseToolset):
"""A class that provides access to toolbox toolsets.

This class acts as a bridge to the `toolbox-adk` package.
You must install `toolbox-adk` to use this class.

Example:
```python
toolbox_toolset = ToolboxToolset("http://127.0.0.1:5000",
toolset_name="my-toolset")
from toolbox_adk import CredentialStrategy

toolbox_toolset = ToolboxToolset(
server_url="http://127.0.0.1:5000",
    # toolset_name and tool_names are optional. If omitted, all tools
    # are loaded.
credentials=CredentialStrategy.toolbox_identity()
)
```
"""
@@ -44,64 +56,58 @@ def __init__(
server_url: str,
toolset_name: Optional[str] = None,
tool_names: Optional[List[str]] = None,
auth_token_getters: Optional[dict[str, Callable[[], str]]] = None,
auth_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
bound_params: Optional[
Mapping[str, Union[Callable[[], Any], Any]]
] = None,
credentials: Optional[CredentialConfig] = None,
additional_headers: Optional[Mapping[str, str]] = None,
**kwargs,
):
"""Args:

server_url: The URL of the toolbox server.
toolset_name: The name of the toolbox toolset to load.
tool_names: The names of the tools to load.
auth_token_getters: A mapping of authentication service names to
callables that return the corresponding authentication token. see:
https://github.com/googleapis/mcp-toolbox-sdk-python/tree/main/packages/toolbox-core#authenticating-tools
for details.
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed. see:
https://github.com/googleapis/mcp-toolbox-sdk-python/tree/main/packages/toolbox-core#binding-parameter-values
for details.
The resulting ToolboxToolset will contain both tools loaded by tool_names
and toolset_name.
server_url: The URL of the toolbox server.
toolset_name: The name of the toolbox toolset to load.
tool_names: The names of the tools to load.
auth_token_getters: (Deprecated) Map of auth token getters.
bound_params: Parameters to bind to the tools.
credentials: (Optional) toolbox_adk.CredentialConfig object.
additional_headers: (Optional) Static headers dictionary.
**kwargs: Additional arguments passed to the underlying
toolbox_adk.ToolboxToolset.
"""
if not tool_names and not toolset_name:
raise ValueError("tool_names and toolset_name cannot both be None")
if not toolset_name and not tool_names:
raise ValueError(
"Either 'toolset_name' or 'tool_names' must be provided."
)

try:
from toolbox_adk import ToolboxToolset as RealToolboxToolset # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError(
"ToolboxToolset requires the 'toolbox-adk' package. "
"Please install it using `pip install toolbox-adk`."
) from exc

super().__init__()
self._server_url = server_url
self._toolbox_client = toolbox.ToolboxClient(server_url)
self._toolset_name = toolset_name
self._tool_names = tool_names
self._auth_token_getters = auth_token_getters or {}
self._bound_params = bound_params or {}

self._delegate = RealToolboxToolset(
server_url=server_url,
toolset_name=toolset_name,
tool_names=tool_names,
credentials=credentials,
additional_headers=additional_headers,
bound_params=bound_params,
auth_token_getters=auth_token_getters,
**kwargs,
)

@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> list[BaseTool]:
tools = []
if self._toolset_name:
tools.extend([
FunctionTool(tool)
for tool in await self._toolbox_client.load_toolset(
self._toolset_name,
auth_token_getters=self._auth_token_getters,
bound_params=self._bound_params,
)
])
if self._tool_names:
tools.extend([
FunctionTool(
await self._toolbox_client.load_tool(
tool_name,
auth_token_getters=self._auth_token_getters,
bound_params=self._bound_params,
)
)
for tool_name in self._tool_names
])
return tools
return await self._delegate.get_tools(readonly_context)

@override
async def close(self):
self._toolbox_client.close()
await self._delegate.close()
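For reference, a minimal usage sketch (not part of the diff) of the rewritten wrapper. It assumes the `toolbox-adk` package is installed and a Toolbox server exposing a `my-toolset` toolset is running locally; the agent name and instruction are made up for illustration.

```python
from google.adk.agents import Agent
from google.adk.tools.toolbox_toolset import ToolboxToolset

# The wrapper now delegates to toolbox_adk.ToolboxToolset and raises
# ImportError with an install hint if toolbox-adk is not installed.
toolset = ToolboxToolset(
    server_url="http://127.0.0.1:5000",
    toolset_name="my-toolset",
)

# Toolsets can be passed directly in an ADK agent's tools list.
root_agent = Agent(
    model="gemini-2.0-flash",
    name="toolbox_agent",
    instruction="Use the Toolbox tools to answer the user's questions.",
    tools=[toolset],
)
```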
51 changes: 51 additions & 0 deletions tests/unittests/flows/llm_flows/test_audio_cache_manager.py
@@ -387,3 +387,54 @@ async def test_filename_uses_first_chunk_timestamp(self):
assert filename.startswith(
f'adk_live_audio_storage_input_audio_{expected_timestamp_ms}'
)

@pytest.mark.asyncio
async def test_flush_event_author_for_user_audio(self):
"""Test that flushed user audio events have 'user' as author."""
invocation_context = await testing_utils.create_invocation_context(
testing_utils.create_test_agent()
)

# Set up mock artifact service
mock_artifact_service = AsyncMock()
mock_artifact_service.save_artifact.return_value = 123
invocation_context.artifact_service = mock_artifact_service

# Cache user input audio
input_blob = types.Blob(data=b'user_audio_data', mime_type='audio/pcm')
self.manager.cache_audio(invocation_context, input_blob, 'input')

# Flush cache and get events
events = await self.manager.flush_caches(
invocation_context, flush_user_audio=True, flush_model_audio=False
)

# Verify event author is 'user' for user audio
assert len(events) == 1
assert events[0].author == 'user'
assert events[0].content.role == 'user'

@pytest.mark.asyncio
async def test_flush_event_author_for_model_audio(self):
"""Test that flushed model audio events have agent name as author, not 'model'."""
agent = testing_utils.create_test_agent(name='my_test_agent')
invocation_context = await testing_utils.create_invocation_context(agent)

# Set up mock artifact service
mock_artifact_service = AsyncMock()
mock_artifact_service.save_artifact.return_value = 123
invocation_context.artifact_service = mock_artifact_service

# Cache model output audio
output_blob = types.Blob(data=b'model_audio_data', mime_type='audio/wav')
self.manager.cache_audio(invocation_context, output_blob, 'output')

# Flush cache and get events
events = await self.manager.flush_caches(
invocation_context, flush_user_audio=False, flush_model_audio=True
)

# Verify event author is agent name (not 'model') for model audio
assert len(events) == 1
assert events[0].author == 'my_test_agent' # Agent name, not 'model'
assert events[0].content.role == 'model' # Role is still 'model'
@@ -29,6 +29,8 @@
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.sessions.state import State
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
from google.adk.tools.openapi_tool.common.common import ApiParameter
@@ -204,6 +206,45 @@ def test_get_declaration(
assert declaration.description == "Test description"
assert isinstance(declaration.parameters, Schema)

def test_get_declaration_with_json_schema_feature_enabled(
self, sample_endpoint, sample_operation
):
"""Test that _get_declaration uses parameters_json_schema when feature is enabled."""
mock_parser = MagicMock(spec=OperationParser)
mock_parser.get_json_schema.return_value = {
"type": "object",
"properties": {
"test_param": {"type": "string"},
},
"required": ["test_param"],
}

tool = RestApiTool(
name="test_tool",
description="Test description",
endpoint=sample_endpoint,
operation=sample_operation,
should_parse_operation=False,
)
tool._operation_parser = mock_parser

with temporary_feature_override(
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, True
):
declaration = tool._get_declaration()

assert isinstance(declaration, FunctionDeclaration)
assert declaration.name == "test_tool"
assert declaration.description == "Test description"
assert declaration.parameters is None
assert declaration.parameters_json_schema == {
"type": "object",
"properties": {
"test_param": {"type": "string"},
},
"required": ["test_param"],
}

@patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
)