From f2756098213ed431401b01cc1cba3b4e3dfd7783 Mon Sep 17 00:00:00 2001
From: Eric Huang
Date: Mon, 3 Feb 2025 12:22:09 -0800
Subject: [PATCH] sys_prompt support in Agent

# What does this PR do?

The current default system prompt for llama3.2 tends to over-index on tool calling and does not work well when the prompt does not require tool calling. This PR adds an option to override the default system prompt, and organizes the tool-related configs into a new `ToolConfig` object.

- [ ] Addresses issue (#issue)

## Test Plan

LLAMA_STACK_CONFIG=together pytest --inference-model=meta-llama/Llama-3.3-70B-Instruct -s -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
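## Example usage

A minimal sketch of the new override, mirroring the integration test added below. The `tool_config` field, the `system_message_behavior: "replace"` value, and the `{{ function_description }}` placeholder come from this patch; the client construction, base URL, and model id are illustrative assumptions.

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger

# Assumed local server endpoint; adjust to your deployment.
client = LlamaStackClient(base_url="http://localhost:8321")

agent_config = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",  # assumed model id
    # With system_message_behavior="replace", these instructions replace the
    # default system prompt instead of being appended to it; the server fills
    # {{ function_description }} with the definitions of the available tools.
    "instructions": "You are a helpful assistant.\n\n{{ function_description }}",
    "tool_config": {"system_message_behavior": "replace"},
}

agent = Agent(client, agent_config)
session_id = agent.create_session("demo-session")
response = agent.create_turn(
    messages=[{"role": "user", "content": "tell me a joke about bicycles"}],
    session_id=session_id,
)
for log in EventLogger().log(response):
    if log is not None:
        print(log)
```

Without the override, the default llama3.2 system prompt tends to answer a non-tool question like this with "I don't have a function", which is what the first half of the new test asserts.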
Use tool_prompt_format in tool_config instead.") + if self.tool_config is None: + self.tool_config = ToolConfig( + tool_choice=self.tool_choice, + tool_prompt_format=self.tool_prompt_format, + ) + @json_schema_type class AgentConfig(AgentConfigCommon): @@ -268,6 +282,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): toolgroups: Optional[List[AgentToolGroup]] = None stream: Optional[bool] = False + tool_config: Optional[ToolConfig] = None @json_schema_type @@ -315,6 +330,7 @@ async def create_agent_turn( stream: Optional[bool] = False, documents: Optional[List[Document]] = None, toolgroups: Optional[List[AgentToolGroup]] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET") diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index f5ddbab40..51691c546 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -496,10 +496,11 @@ async def _run( tools=[ tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP ], - tool_prompt_format=self.agent_config.tool_prompt_format, + tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, response_format=self.agent_config.response_format, stream=True, sampling_params=sampling_params, + tool_config=self.agent_config.tool_config, ): event = chunk.event if event.event_type == ChatCompletionResponseEventType.start: diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index b9e3066c6..8f9fa2d82 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -25,7 +25,12 @@ Session, Turn, ) -from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage +from llama_stack.apis.inference import ( + Inference, + ToolConfig, + ToolResponseMessage, + UserMessage, +) from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO @@ -76,6 +81,12 @@ async def create_agent( ) -> AgentCreateResponse: agent_id = str(uuid.uuid4()) + if agent_config.tool_config is None: + agent_config.tool_config = ToolConfig( + tool_choice=agent_config.tool_choice, + tool_prompt_format=agent_config.tool_prompt_format, + ) + await self.persistence_store.set( key=f"agent:{agent_id}", value=agent_config.model_dump_json(), @@ -140,6 +151,7 @@ async def create_agent_turn( toolgroups: Optional[List[AgentToolGroup]] = None, documents: Optional[List[Document]] = None, stream: Optional[bool] = False, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: request = AgentTurnCreateRequest( agent_id=agent_id, @@ -148,6 +160,7 @@ async def create_agent_turn( stream=True, toolgroups=toolgroups, documents=documents, + tool_config=tool_config, ) if stream: return self._create_agent_turn_streaming(request) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 7a62da35f..efdb0b4ec 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -263,6 +263,88 @@ def test_custom_tool(llama_stack_client, agent_config): assert 
"CustomTool" in logs_str +def test_override_system_message_behavior(llama_stack_client, agent_config): + client_tool = TestClientTool() + agent_config = { + **agent_config, + "instructions": "You are a pirate", + "client_tools": [client_tool.get_tool_definition()], + } + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "tell me a joke about bicycles", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) + # can't tell a joke: "I don't have a function" + assert "function" in logs_str + + # with system message behavior replace + instructions = """ + You are a helpful assistant. You have access to functions, but you should only use them if they are required. + + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you may or may not need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, don't return [], instead answer the question directly without using functions. If the given question lacks the parameters required by the function, + also point it out. + + {{ function_description }} + """ + agent_config = { + **agent_config, + "instructions": instructions, + "client_tools": [client_tool.get_tool_definition()], + "tool_config": { + "system_message_behavior": "replace", + }, + } + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "tell me a joke about bicycles", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) + assert "bicycle" in logs_str + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "What is the boiling point of polyjuice?", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + print(logs_str) + assert "-100" in logs_str + assert "CustomTool" in logs_str + + def test_rag_agent(llama_stack_client, agent_config): urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"] documents = [