Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,46 @@ def validate_name(cls, value: str):
)
return value

@field_validator('sub_agents', mode='after')
@classmethod
def validate_sub_agents_unique_names(
cls, value: list[BaseAgent]
) -> list[BaseAgent]:
"""Validates that all sub-agents have unique names.

Args:
value: The list of sub-agents to validate.

Returns:
The validated list of sub-agents.

Raises:
ValueError: If duplicate sub-agent names are found.
"""
if not value:
return value

seen_names: set[str] = set()
duplicates: set[str] = set()

for sub_agent in value:
name = sub_agent.name
if name in seen_names:
duplicates.add(name)
else:
seen_names.add(name)

if duplicates:
duplicate_names_str = ', '.join(
f'`{name}`' for name in sorted(duplicates)
)
raise ValueError(
f'Found duplicate sub-agent names: {duplicate_names_str}. '
'All sub-agents must have unique names.'
)

return value

def __set_parent_agent_for_sub_agents(self) -> BaseAgent:
for sub_agent in self.sub_agents:
if sub_agent.parent_agent is not None:
Expand Down
6 changes: 3 additions & 3 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
from litellm import acompletion
from litellm import ChatCompletionAssistantMessage
from litellm import ChatCompletionAssistantToolCall
from litellm import ChatCompletionDeveloperMessage
from litellm import ChatCompletionMessageToolCall
from litellm import ChatCompletionSystemMessage
from litellm import ChatCompletionToolMessage
from litellm import ChatCompletionUserMessage
from litellm import completion
Expand Down Expand Up @@ -983,8 +983,8 @@ def _get_completion_inputs(
if llm_request.config.system_instruction:
messages.insert(
0,
ChatCompletionDeveloperMessage(
role="developer",
ChatCompletionSystemMessage(
role="system",
content=llm_request.config.system_instruction,
),
)
Expand Down
29 changes: 26 additions & 3 deletions src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,22 @@ class AgentTool(BaseTool):
Attributes:
agent: The agent to wrap.
skip_summarization: Whether to skip summarization of the agent output.
include_plugins: Whether to propagate plugins from the parent runner context
to the agent's runner. When True (default), the agent will inherit all
plugins from its parent. Set to False to run the agent with an isolated
plugin environment.
"""

def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
def __init__(
self,
agent: BaseAgent,
skip_summarization: bool = False,
*,
include_plugins: bool = True,
):
self.agent = agent
self.skip_summarization: bool = skip_summarization
self.include_plugins = include_plugins

super().__init__(name=agent.name, description=agent.description)

Expand All @@ -68,6 +79,8 @@ def _get_declaration(self) -> types.FunctionDeclaration:
result = _automatic_function_calling_util.build_function_declaration(
func=self.agent.input_schema, variant=self._api_variant
)
# Override the description with the agent's description
result.description = self.agent.description
else:
result = types.FunctionDeclaration(
parameters=types.Schema(
Expand Down Expand Up @@ -130,14 +143,19 @@ async def run_async(
invocation_context.app_name if invocation_context else None
)
child_app_name = parent_app_name or self.agent.name
plugins = (
tool_context._invocation_context.plugin_manager.plugins
if self.include_plugins
else None
)
runner = Runner(
app_name=child_app_name,
agent=self.agent,
artifact_service=ForwardingArtifactService(tool_context),
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
credential_service=tool_context._invocation_context.credential_service,
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
plugins=plugins,
)

state_dict = {
Expand Down Expand Up @@ -192,7 +210,9 @@ def from_config(
agent_tool_config.agent, config_abs_path
)
return cls(
agent=agent, skip_summarization=agent_tool_config.skip_summarization
agent=agent,
skip_summarization=agent_tool_config.skip_summarization,
include_plugins=agent_tool_config.include_plugins,
)


Expand All @@ -204,3 +224,6 @@ class AgentToolConfig(BaseToolConfig):

skip_summarization: bool = False
"""Whether to skip summarization of the agent output."""

include_plugins: bool = True
"""Whether to include plugins from parent runner context."""
104 changes: 104 additions & 0 deletions tests/unittests/agents/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,110 @@ def test_set_parent_agent_for_sub_agent_twice(
)


def test_validate_sub_agents_unique_names_single_duplicate(
request: pytest.FixtureRequest,
):
"""Test that duplicate sub-agent names raise ValueError."""
duplicate_name = f'{request.function.__name__}_duplicate_agent'
sub_agent_1 = _TestingAgent(name=duplicate_name)
sub_agent_2 = _TestingAgent(name=duplicate_name)

with pytest.raises(ValueError, match='Found duplicate sub-agent names'):
_ = _TestingAgent(
name=f'{request.function.__name__}_parent',
sub_agents=[sub_agent_1, sub_agent_2],
)


def test_validate_sub_agents_unique_names_multiple_duplicates(
request: pytest.FixtureRequest,
):
"""Test that multiple duplicate sub-agent names are all reported."""
duplicate_name_1 = f'{request.function.__name__}_duplicate_1'
duplicate_name_2 = f'{request.function.__name__}_duplicate_2'

sub_agents = [
_TestingAgent(name=duplicate_name_1),
_TestingAgent(name=f'{request.function.__name__}_unique'),
_TestingAgent(name=duplicate_name_1), # First duplicate
_TestingAgent(name=duplicate_name_2),
_TestingAgent(name=duplicate_name_2), # Second duplicate
]

with pytest.raises(ValueError) as exc_info:
_ = _TestingAgent(
name=f'{request.function.__name__}_parent',
sub_agents=sub_agents,
)

error_message = str(exc_info.value)
# Verify each duplicate name appears exactly once in the error message
assert error_message.count(duplicate_name_1) == 1
assert error_message.count(duplicate_name_2) == 1
# Verify both duplicate names are present
assert duplicate_name_1 in error_message
assert duplicate_name_2 in error_message


def test_validate_sub_agents_unique_names_triple_duplicate(
request: pytest.FixtureRequest,
):
"""Test that a name appearing three times is reported only once."""
duplicate_name = f'{request.function.__name__}_triple_duplicate'

sub_agents = [
_TestingAgent(name=duplicate_name),
_TestingAgent(name=f'{request.function.__name__}_unique'),
_TestingAgent(name=duplicate_name), # Second occurrence
_TestingAgent(name=duplicate_name), # Third occurrence
]

with pytest.raises(ValueError) as exc_info:
_ = _TestingAgent(
name=f'{request.function.__name__}_parent',
sub_agents=sub_agents,
)

error_message = str(exc_info.value)
# Verify the duplicate name appears exactly once in the error message
# (not three times even though it appears three times in the list)
assert error_message.count(duplicate_name) == 1
assert duplicate_name in error_message


def test_validate_sub_agents_unique_names_no_duplicates(
request: pytest.FixtureRequest,
):
"""Test that unique sub-agent names pass validation."""
sub_agents = [
_TestingAgent(name=f'{request.function.__name__}_sub_agent_1'),
_TestingAgent(name=f'{request.function.__name__}_sub_agent_2'),
_TestingAgent(name=f'{request.function.__name__}_sub_agent_3'),
]

parent = _TestingAgent(
name=f'{request.function.__name__}_parent',
sub_agents=sub_agents,
)

assert len(parent.sub_agents) == 3
assert parent.sub_agents[0].name == f'{request.function.__name__}_sub_agent_1'
assert parent.sub_agents[1].name == f'{request.function.__name__}_sub_agent_2'
assert parent.sub_agents[2].name == f'{request.function.__name__}_sub_agent_3'


def test_validate_sub_agents_unique_names_empty_list(
request: pytest.FixtureRequest,
):
"""Test that empty sub-agents list passes validation."""
parent = _TestingAgent(
name=f'{request.function.__name__}_parent',
sub_agents=[],
)

assert len(parent.sub_agents) == 0


if __name__ == '__main__':
pytest.main([__file__])

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,7 @@ async def test_generate_content_async_with_system_instruction(

_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "developer"
assert kwargs["messages"][0]["role"] == "system"
assert kwargs["messages"][0]["content"] == "Test system instruction"
assert kwargs["messages"][1]["role"] == "user"
assert kwargs["messages"][1]["content"] == "Test prompt"
Expand Down
132 changes: 132 additions & 0 deletions tests/unittests/tools/test_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,135 @@ class CustomInput(BaseModel):
# Should have string response schema for VERTEX_AI when no output_schema
assert declaration.response is not None
assert declaration.response.type == types.Type.STRING


def test_include_plugins_default_true():
"""Test that plugins are propagated by default (include_plugins=True)."""

# Create a test plugin that tracks callbacks
class TrackingPlugin(BasePlugin):

def __init__(self, name: str):
super().__init__(name)
self.before_agent_calls = 0

async def before_agent_callback(self, **kwargs):
self.before_agent_calls += 1

tracking_plugin = TrackingPlugin(name='tracking')

mock_model = testing_utils.MockModel.create(
responses=[function_call_no_schema, 'response1', 'response2']
)

tool_agent = Agent(
name='tool_agent',
model=mock_model,
)

root_agent = Agent(
name='root_agent',
model=mock_model,
tools=[AgentTool(agent=tool_agent)], # Default include_plugins=True
)

runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
runner.run('test1')

# Plugin should be called for both root_agent and tool_agent
assert tracking_plugin.before_agent_calls == 2


def test_include_plugins_explicit_true():
"""Test that plugins are propagated when include_plugins=True."""

class TrackingPlugin(BasePlugin):

def __init__(self, name: str):
super().__init__(name)
self.before_agent_calls = 0

async def before_agent_callback(self, **kwargs):
self.before_agent_calls += 1

tracking_plugin = TrackingPlugin(name='tracking')

mock_model = testing_utils.MockModel.create(
responses=[function_call_no_schema, 'response1', 'response2']
)

tool_agent = Agent(
name='tool_agent',
model=mock_model,
)

root_agent = Agent(
name='root_agent',
model=mock_model,
tools=[AgentTool(agent=tool_agent, include_plugins=True)],
)

runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
runner.run('test1')

# Plugin should be called for both root_agent and tool_agent
assert tracking_plugin.before_agent_calls == 2


def test_include_plugins_false():
"""Test that plugins are NOT propagated when include_plugins=False."""

class TrackingPlugin(BasePlugin):

def __init__(self, name: str):
super().__init__(name)
self.before_agent_calls = 0

async def before_agent_callback(self, **kwargs):
self.before_agent_calls += 1

tracking_plugin = TrackingPlugin(name='tracking')

mock_model = testing_utils.MockModel.create(
responses=[function_call_no_schema, 'response1', 'response2']
)

tool_agent = Agent(
name='tool_agent',
model=mock_model,
)

root_agent = Agent(
name='root_agent',
model=mock_model,
tools=[AgentTool(agent=tool_agent, include_plugins=False)],
)

runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
runner.run('test1')

# Plugin should only be called for root_agent, not tool_agent
assert tracking_plugin.before_agent_calls == 1


def test_agent_tool_description_with_input_schema():
"""Test that agent description is propagated when using input_schema."""

class CustomInput(BaseModel):
"""This is the Pydantic model docstring."""

custom_input: str

agent_description = 'This is the agent description that should be used'
tool_agent = Agent(
name='tool_agent',
model=testing_utils.MockModel.create(responses=['test response']),
description=agent_description,
input_schema=CustomInput,
)

agent_tool = AgentTool(agent=tool_agent)
declaration = agent_tool._get_declaration()

# The description should come from the agent, not the Pydantic model
assert declaration.description == agent_description
Loading