diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index a644cb8b90..fccde1cb6f 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -563,6 +563,46 @@ def validate_name(cls, value: str): ) return value + @field_validator('sub_agents', mode='after') + @classmethod + def validate_sub_agents_unique_names( + cls, value: list[BaseAgent] + ) -> list[BaseAgent]: + """Validates that all sub-agents have unique names. + + Args: + value: The list of sub-agents to validate. + + Returns: + The validated list of sub-agents. + + Raises: + ValueError: If duplicate sub-agent names are found. + """ + if not value: + return value + + seen_names: set[str] = set() + duplicates: set[str] = set() + + for sub_agent in value: + name = sub_agent.name + if name in seen_names: + duplicates.add(name) + else: + seen_names.add(name) + + if duplicates: + duplicate_names_str = ', '.join( + f'`{name}`' for name in sorted(duplicates) + ) + raise ValueError( + f'Found duplicate sub-agent names: {duplicate_names_str}. ' + 'All sub-agents must have unique names.' + ) + + return value + def __set_parent_agent_for_sub_agents(self) -> BaseAgent: for sub_agent in self.sub_agents: if sub_agent.parent_agent is not None: diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index c9712974ea..9e3698b190 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -39,8 +39,8 @@ from litellm import acompletion from litellm import ChatCompletionAssistantMessage from litellm import ChatCompletionAssistantToolCall -from litellm import ChatCompletionDeveloperMessage from litellm import ChatCompletionMessageToolCall +from litellm import ChatCompletionSystemMessage from litellm import ChatCompletionToolMessage from litellm import ChatCompletionUserMessage from litellm import completion @@ -983,8 +983,8 @@ def _get_completion_inputs( if llm_request.config.system_instruction: messages.insert( 0, - ChatCompletionDeveloperMessage( - role="developer", + ChatCompletionSystemMessage( + role="system", content=llm_request.config.system_instruction, ), ) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 702f6e43aa..46d8616619 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -45,11 +45,22 @@ class AgentTool(BaseTool): Attributes: agent: The agent to wrap. skip_summarization: Whether to skip summarization of the agent output. + include_plugins: Whether to propagate plugins from the parent runner context + to the agent's runner. When True (default), the agent will inherit all + plugins from its parent. Set to False to run the agent with an isolated + plugin environment. """ - def __init__(self, agent: BaseAgent, skip_summarization: bool = False): + def __init__( + self, + agent: BaseAgent, + skip_summarization: bool = False, + *, + include_plugins: bool = True, + ): self.agent = agent self.skip_summarization: bool = skip_summarization + self.include_plugins = include_plugins super().__init__(name=agent.name, description=agent.description) @@ -68,6 +79,8 @@ def _get_declaration(self) -> types.FunctionDeclaration: result = _automatic_function_calling_util.build_function_declaration( func=self.agent.input_schema, variant=self._api_variant ) + # Override the description with the agent's description + result.description = self.agent.description else: result = types.FunctionDeclaration( parameters=types.Schema( @@ -130,6 +143,11 @@ async def run_async( invocation_context.app_name if invocation_context else None ) child_app_name = parent_app_name or self.agent.name + plugins = ( + tool_context._invocation_context.plugin_manager.plugins + if self.include_plugins + else None + ) runner = Runner( app_name=child_app_name, agent=self.agent, @@ -137,7 +155,7 @@ async def run_async( session_service=InMemorySessionService(), memory_service=InMemoryMemoryService(), credential_service=tool_context._invocation_context.credential_service, - plugins=list(tool_context._invocation_context.plugin_manager.plugins), + plugins=plugins, ) state_dict = { @@ -192,7 +210,9 @@ def from_config( agent_tool_config.agent, config_abs_path ) return cls( - agent=agent, skip_summarization=agent_tool_config.skip_summarization + agent=agent, + skip_summarization=agent_tool_config.skip_summarization, + include_plugins=agent_tool_config.include_plugins, ) @@ -204,3 +224,6 @@ class AgentToolConfig(BaseToolConfig): skip_summarization: bool = False """Whether to skip summarization of the agent output.""" + + include_plugins: bool = True + """Whether to include plugins from parent runner context.""" diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 663179f670..860cc8d4f0 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -854,6 +854,110 @@ def test_set_parent_agent_for_sub_agent_twice( ) +def test_validate_sub_agents_unique_names_single_duplicate( + request: pytest.FixtureRequest, +): + """Test that duplicate sub-agent names raise ValueError.""" + duplicate_name = f'{request.function.__name__}_duplicate_agent' + sub_agent_1 = _TestingAgent(name=duplicate_name) + sub_agent_2 = _TestingAgent(name=duplicate_name) + + with pytest.raises(ValueError, match='Found duplicate sub-agent names'): + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=[sub_agent_1, sub_agent_2], + ) + + +def test_validate_sub_agents_unique_names_multiple_duplicates( + request: pytest.FixtureRequest, +): + """Test that multiple duplicate sub-agent names are all reported.""" + duplicate_name_1 = f'{request.function.__name__}_duplicate_1' + duplicate_name_2 = f'{request.function.__name__}_duplicate_2' + + sub_agents = [ + _TestingAgent(name=duplicate_name_1), + _TestingAgent(name=f'{request.function.__name__}_unique'), + _TestingAgent(name=duplicate_name_1), # First duplicate + _TestingAgent(name=duplicate_name_2), + _TestingAgent(name=duplicate_name_2), # Second duplicate + ] + + with pytest.raises(ValueError) as exc_info: + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + error_message = str(exc_info.value) + # Verify each duplicate name appears exactly once in the error message + assert error_message.count(duplicate_name_1) == 1 + assert error_message.count(duplicate_name_2) == 1 + # Verify both duplicate names are present + assert duplicate_name_1 in error_message + assert duplicate_name_2 in error_message + + +def test_validate_sub_agents_unique_names_triple_duplicate( + request: pytest.FixtureRequest, +): + """Test that a name appearing three times is reported only once.""" + duplicate_name = f'{request.function.__name__}_triple_duplicate' + + sub_agents = [ + _TestingAgent(name=duplicate_name), + _TestingAgent(name=f'{request.function.__name__}_unique'), + _TestingAgent(name=duplicate_name), # Second occurrence + _TestingAgent(name=duplicate_name), # Third occurrence + ] + + with pytest.raises(ValueError) as exc_info: + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + error_message = str(exc_info.value) + # Verify the duplicate name appears exactly once in the error message + # (not three times even though it appears three times in the list) + assert error_message.count(duplicate_name) == 1 + assert duplicate_name in error_message + + +def test_validate_sub_agents_unique_names_no_duplicates( + request: pytest.FixtureRequest, +): + """Test that unique sub-agent names pass validation.""" + sub_agents = [ + _TestingAgent(name=f'{request.function.__name__}_sub_agent_1'), + _TestingAgent(name=f'{request.function.__name__}_sub_agent_2'), + _TestingAgent(name=f'{request.function.__name__}_sub_agent_3'), + ] + + parent = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + assert len(parent.sub_agents) == 3 + assert parent.sub_agents[0].name == f'{request.function.__name__}_sub_agent_1' + assert parent.sub_agents[1].name == f'{request.function.__name__}_sub_agent_2' + assert parent.sub_agents[2].name == f'{request.function.__name__}_sub_agent_3' + + +def test_validate_sub_agents_unique_names_empty_list( + request: pytest.FixtureRequest, +): + """Test that empty sub-agents list passes validation.""" + parent = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=[], + ) + + assert len(parent.sub_agents) == 0 + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index c486806d37..f65fc77a61 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1195,7 +1195,7 @@ async def test_generate_content_async_with_system_instruction( _, kwargs = mock_acompletion.call_args assert kwargs["model"] == "test_model" - assert kwargs["messages"][0]["role"] == "developer" + assert kwargs["messages"][0]["role"] == "system" assert kwargs["messages"][0]["content"] == "Test system instruction" assert kwargs["messages"][1]["role"] == "user" assert kwargs["messages"][1]["content"] == "Test prompt" diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index 85e8b9caa1..a9723b4347 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -570,3 +570,135 @@ class CustomInput(BaseModel): # Should have string response schema for VERTEX_AI when no output_schema assert declaration.response is not None assert declaration.response.type == types.Type.STRING + + +def test_include_plugins_default_true(): + """Test that plugins are propagated by default (include_plugins=True).""" + + # Create a test plugin that tracks callbacks + class TrackingPlugin(BasePlugin): + + def __init__(self, name: str): + super().__init__(name) + self.before_agent_calls = 0 + + async def before_agent_callback(self, **kwargs): + self.before_agent_calls += 1 + + tracking_plugin = TrackingPlugin(name='tracking') + + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent)], # Default include_plugins=True + ) + + runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin]) + runner.run('test1') + + # Plugin should be called for both root_agent and tool_agent + assert tracking_plugin.before_agent_calls == 2 + + +def test_include_plugins_explicit_true(): + """Test that plugins are propagated when include_plugins=True.""" + + class TrackingPlugin(BasePlugin): + + def __init__(self, name: str): + super().__init__(name) + self.before_agent_calls = 0 + + async def before_agent_callback(self, **kwargs): + self.before_agent_calls += 1 + + tracking_plugin = TrackingPlugin(name='tracking') + + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent, include_plugins=True)], + ) + + runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin]) + runner.run('test1') + + # Plugin should be called for both root_agent and tool_agent + assert tracking_plugin.before_agent_calls == 2 + + +def test_include_plugins_false(): + """Test that plugins are NOT propagated when include_plugins=False.""" + + class TrackingPlugin(BasePlugin): + + def __init__(self, name: str): + super().__init__(name) + self.before_agent_calls = 0 + + async def before_agent_callback(self, **kwargs): + self.before_agent_calls += 1 + + tracking_plugin = TrackingPlugin(name='tracking') + + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent, include_plugins=False)], + ) + + runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin]) + runner.run('test1') + + # Plugin should only be called for root_agent, not tool_agent + assert tracking_plugin.before_agent_calls == 1 + + +def test_agent_tool_description_with_input_schema(): + """Test that agent description is propagated when using input_schema.""" + + class CustomInput(BaseModel): + """This is the Pydantic model docstring.""" + + custom_input: str + + agent_description = 'This is the agent description that should be used' + tool_agent = Agent( + name='tool_agent', + model=testing_utils.MockModel.create(responses=['test response']), + description=agent_description, + input_schema=CustomInput, + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + # The description should come from the agent, not the Pydantic model + assert declaration.description == agent_description