Commit da8adbc

sys_prompt support in Agent
# What does this PR do?

The current default system prompt for llama3.2 tends to overindex on tool calling and doesn't work well when the prompt does not require tool calling. This PR adds an option to override the default system prompt, and organizes tool-related configs into a new config object.

- [ ] Addresses issue (#issue)

## Test Plan

`LLAMA_STACK_CONFIG=together pytest --inference-model=meta-llama/Llama-3.3-70B-Instruct -s -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior`

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
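For a quick sense of what this enables, here is a minimal usage sketch (not part of the diff: the import paths, base URL, and `agent_config` fields are assumptions; the `Agent`, `create_session`, and `create_turn` call shapes mirror the client-SDK test added below):

```python
# Hedged sketch, not from this commit: import paths and base_url are assumed.
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger

client = LlamaStackClient(base_url="http://localhost:5001")  # hypothetical endpoint

agent_config = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",
    # With system_message_behavior set to "replace", these instructions replace
    # the default tool-calling system prompt entirely; the test below suggests
    # "{{ function_description }}" marks where tool definitions are substituted.
    "instructions": "You are a helpful assistant.\n\n{{ function_description }}",
    "tool_config": {"system_message_behavior": "replace"},
}

agent = Agent(client, agent_config)
session_id = agent.create_session("demo-session")
response = agent.create_turn(
    messages=[{"role": "user", "content": "tell me a joke about bicycles"}],
    session_id=session_id,
)
for log in EventLogger().log(response):
    print(log)
```

The default behavior is unchanged when `tool_config` is omitted: as the diff below shows, legacy `tool_choice` / `tool_prompt_format` values are folded into a `ToolConfig` automatically.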
1 parent 9fe5e40 commit da8adbc

File tree

4 files changed: +116, -4 lines

llama_stack/apis/agents/agents.py

Lines changed: 18 additions & 2 deletions
```diff
@@ -33,6 +33,7 @@
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
+    ToolConfig,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
@@ -153,11 +154,24 @@ class AgentConfigCommon(BaseModel):
     output_shields: Optional[List[str]] = Field(default_factory=list)
     toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
     client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
+    tool_config: Optional[ToolConfig] = Field(default=None)
 
     max_infer_iters: int = 10
 
+    def model_post_init(self, __context):
+        if self.tool_config:
+            if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
+                raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
+            if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
+                raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
+        if self.tool_config is None:
+            self.tool_config = ToolConfig(
+                tool_choice=self.tool_choice,
+                tool_prompt_format=self.tool_prompt_format,
+            )
+
 
 @json_schema_type
 class AgentConfig(AgentConfigCommon):
@@ -268,6 +282,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     toolgroups: Optional[List[AgentToolGroup]] = None
 
     stream: Optional[bool] = False
+    tool_config: Optional[ToolConfig] = None
 
 
 @json_schema_type
@@ -315,6 +330,7 @@ async def create_agent_turn(
         stream: Optional[bool] = False,
         documents: Optional[List[Document]] = None,
         toolgroups: Optional[List[AgentToolGroup]] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
 
     @webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
```
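The `model_post_init` hook above defines the migration semantics: when only the deprecated top-level fields are set, they are folded into an auto-created `ToolConfig`; when the old fields and a supplied `tool_config` conflict, validation fails. A small sketch of both paths (import paths inferred from the file locations in this diff, and assuming the remaining `AgentConfigCommon` fields keep their defaults):

```python
# Sketch of AgentConfigCommon.model_post_init behavior; import paths are
# inferred from the files touched in this commit, not verified.
from llama_stack.apis.agents.agents import AgentConfigCommon
from llama_stack.apis.inference import ToolChoice, ToolConfig

# Deprecated field only: it is copied into an auto-created tool_config.
cfg = AgentConfigCommon(tool_choice=ToolChoice.required)
assert cfg.tool_config.tool_choice == ToolChoice.required

# Deprecated field contradicting the new tool_config: rejected.
try:
    AgentConfigCommon(
        tool_choice=ToolChoice.required,
        tool_config=ToolConfig(tool_choice=ToolChoice.auto),
    )
except ValueError as err:
    print(err)  # tool_choice is deprecated. Use tool_choice in tool_config instead.
```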

llama_stack/providers/inline/agents/meta_reference/agent_instance.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -496,10 +496,11 @@ async def _run(
             tools=[
                 tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
             ],
-            tool_prompt_format=self.agent_config.tool_prompt_format,
+            tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
             response_format=self.agent_config.response_format,
             stream=True,
             sampling_params=sampling_params,
+            tool_config=self.agent_config.tool_config,
         ):
             event = chunk.event
             if event.event_type == ChatCompletionResponseEventType.start:
```

llama_stack/providers/inline/agents/meta_reference/agents.py

Lines changed: 14 additions & 1 deletion
```diff
@@ -25,7 +25,12 @@
     Session,
     Turn,
 )
-from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
+from llama_stack.apis.inference import (
+    Inference,
+    ToolConfig,
+    ToolResponseMessage,
+    UserMessage,
+)
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
@@ -76,6 +81,12 @@ async def create_agent(
     ) -> AgentCreateResponse:
         agent_id = str(uuid.uuid4())
 
+        if agent_config.tool_config is None:
+            agent_config.tool_config = ToolConfig(
+                tool_choice=agent_config.tool_choice,
+                tool_prompt_format=agent_config.tool_prompt_format,
+            )
+
         await self.persistence_store.set(
             key=f"agent:{agent_id}",
             value=agent_config.model_dump_json(),
@@ -140,6 +151,7 @@ async def create_agent_turn(
         toolgroups: Optional[List[AgentToolGroup]] = None,
         documents: Optional[List[Document]] = None,
         stream: Optional[bool] = False,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
@@ -148,6 +160,7 @@ async def create_agent_turn(
             stream=True,
             toolgroups=toolgroups,
             documents=documents,
+            tool_config=tool_config,
         )
         if stream:
             return self._create_agent_turn_streaming(request)
```
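Since `create_agent_turn` now accepts a `tool_config` and threads it into `AgentTurnCreateRequest`, the tool configuration can also be overridden per turn rather than fixed at agent creation. A rough sketch of such a request (import paths and the ID literals are illustrative assumptions, not from this commit):

```python
# Per-turn override sketch; IDs and import paths are assumptions.
from llama_stack.apis.agents import AgentTurnCreateRequest
from llama_stack.apis.inference import ToolConfig, UserMessage

request = AgentTurnCreateRequest(
    agent_id="agent-uuid",      # placeholder
    session_id="session-uuid",  # placeholder
    messages=[UserMessage(content="tell me a joke about bicycles")],
    stream=True,
    # Takes precedence over the agent-level tool_config for this turn.
    tool_config=ToolConfig(system_message_behavior="replace"),
)
```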

tests/client-sdk/agents/test_agents.py

Lines changed: 82 additions & 0 deletions
```diff
@@ -263,6 +263,88 @@ def test_custom_tool(llama_stack_client, agent_config):
     assert "CustomTool" in logs_str
 
 
+def test_override_system_message_behavior(llama_stack_client, agent_config):
+    client_tool = TestClientTool()
+    agent_config = {
+        **agent_config,
+        "instructions": "You are a pirate",
+        "client_tools": [client_tool.get_tool_definition()],
+    }
+
+    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    response = agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "tell me a joke about bicycles",
+            },
+        ],
+        session_id=session_id,
+    )
+
+    logs = [str(log) for log in EventLogger().log(response) if log is not None]
+    logs_str = "".join(logs)
+    print(logs_str)
+    # can't tell a joke: "I don't have a function"
+    assert "function" in logs_str
+
+    # with system message behavior replace
+    instructions = """
+    You are a helpful assistant. You have access to functions, but you should only use them if they are required.
+
+    You are an expert in composing functions. You are given a question and a set of possible functions.
+    Based on the question, you may or may not need to make one or more function/tool calls to achieve the purpose.
+    If none of the function can be used, don't return [], instead answer the question directly without using functions. If the given question lacks the parameters required by the function,
+    also point it out.
+
+    {{ function_description }}
+    """
+    agent_config = {
+        **agent_config,
+        "instructions": instructions,
+        "client_tools": [client_tool.get_tool_definition()],
+        "tool_config": {
+            "system_message_behavior": "replace",
+        },
+    }
+
+    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    response = agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "tell me a joke about bicycles",
+            },
+        ],
+        session_id=session_id,
+    )
+
+    logs = [str(log) for log in EventLogger().log(response) if log is not None]
+    logs_str = "".join(logs)
+    print(logs_str)
+    assert "bicycle" in logs_str
+
+    response = agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "What is the boiling point of polyjuice?",
+            },
+        ],
+        session_id=session_id,
+    )
+
+    logs = [str(log) for log in EventLogger().log(response) if log is not None]
+    logs_str = "".join(logs)
+    print(logs_str)
+    assert "-100" in logs_str
+    assert "CustomTool" in logs_str
+
+
 def test_rag_agent(llama_stack_client, agent_config):
     urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
     documents = [
```
