sys_prompt support in Agent
# What does this PR do?

The current default system prompt for Llama 3.2 tends to over-index on tool calling and does not work well when the prompt does not require tool calling.

This PR adds an option to override the default system prompt and organizes tool-related configs into a new `ToolConfig` object.
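For illustration, a minimal sketch of the intended usage, mirroring the new test below; the endpoint, model choice, and the set of other required `AgentConfig` fields (omitted for brevity) are assumptions:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

client = LlamaStackClient(base_url="http://localhost:5000")  # assumed endpoint

# With system_message_behavior="replace", the instructions fully replace the
# default tool-calling system prompt; the {{ function_description }} placeholder
# is substituted with the definitions of the configured tools.
agent_config = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",
    "instructions": (
        "You are a helpful assistant. Only call functions when required.\n"
        "{{ function_description }}"
    ),
    "tool_config": {"system_message_behavior": "replace"},
    # ...other required AgentConfig fields omitted for brevity
}
agent = Agent(client, agent_config)
```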

- [ ] Addresses issue (#issue)


## Test Plan


LLAMA_STACK_CONFIG=together pytest --inference-model=meta-llama/Llama-3.3-70B-Instruct -s -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
ehhuang committed Feb 3, 2025
1 parent 77a5273 commit f275609
Showing 4 changed files with 116 additions and 4 deletions.
20 changes: 18 additions & 2 deletions llama_stack/apis/agents/agents.py
@@ -33,6 +33,7 @@
ToolResponse,
ToolResponseMessage,
UserMessage,
ToolConfig,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
@@ -153,11 +154,24 @@ class AgentConfigCommon(BaseModel):
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
tool_config: Optional[ToolConfig] = Field(default=None)

max_infer_iters: int = 10

def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
if self.tool_config is None:
self.tool_config = ToolConfig(
tool_choice=self.tool_choice,
tool_prompt_format=self.tool_prompt_format,
)


@json_schema_type
class AgentConfig(AgentConfigCommon):
@@ -268,6 +282,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
toolgroups: Optional[List[AgentToolGroup]] = None

stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None


@json_schema_type
@@ -315,6 +330,7 @@ async def create_agent_turn(
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
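For context, a minimal sketch (not part of the commit) of the back-compat behavior that `model_post_init` above introduces; it assumes the remaining `AgentConfigCommon` fields all carry defaults and that the class and enums are importable as shown:

```python
from llama_stack.apis.agents import AgentConfigCommon
from llama_stack.apis.inference import ToolChoice, ToolConfig

# Only the deprecated field is set: model_post_init folds it into a new ToolConfig.
cfg = AgentConfigCommon(tool_choice=ToolChoice.required)
assert cfg.tool_config.tool_choice == ToolChoice.required

# Conflicting deprecated and new values fail validation.
try:
    AgentConfigCommon(
        tool_choice=ToolChoice.required,
        tool_config=ToolConfig(tool_choice=ToolChoice.auto),
    )
except ValueError as err:
    print(err)  # tool_choice is deprecated. Use tool_choice in tool_config instead.
```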
3 changes: 2 additions & 1 deletion llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -496,10 +496,11 @@ async def _run(
tools=[
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tool_prompt_format=self.agent_config.tool_prompt_format,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
sampling_params=sampling_params,
tool_config=self.agent_config.tool_config,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
15 changes: 14 additions & 1 deletion llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -25,7 +25,12 @@
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
@@ -76,6 +81,12 @@ async def create_agent(
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())

if agent_config.tool_config is None:
agent_config.tool_config = ToolConfig(
tool_choice=agent_config.tool_choice,
tool_prompt_format=agent_config.tool_prompt_format,
)

await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.model_dump_json(),
@@ -140,6 +151,7 @@ async def create_agent_turn(
toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@@ -148,6 +160,7 @@
stream=True,
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
)
if stream:
return self._create_agent_turn_streaming(request)
82 changes: 82 additions & 0 deletions tests/client-sdk/agents/test_agents.py
@@ -263,6 +263,88 @@ def test_custom_tool(llama_stack_client, agent_config):
assert "CustomTool" in logs_str


def test_override_system_message_behavior(llama_stack_client, agent_config):
client_tool = TestClientTool()
agent_config = {
**agent_config,
"instructions": "You are a pirate",
"client_tools": [client_tool.get_tool_definition()],
}

agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")

response = agent.create_turn(
messages=[
{
"role": "user",
"content": "tell me a joke about bicycles",
},
],
session_id=session_id,
)

logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
print(logs_str)
# can't tell a joke: "I don't have a function"
assert "function" in logs_str

# with system message behavior replace
instructions = """
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you may or may not need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, don't return [], instead answer the question directly without using functions. If the given question lacks the parameters required by the function,
also point it out.
{{ function_description }}
"""
agent_config = {
**agent_config,
"instructions": instructions,
"client_tools": [client_tool.get_tool_definition()],
"tool_config": {
"system_message_behavior": "replace",
},
}

agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")

response = agent.create_turn(
messages=[
{
"role": "user",
"content": "tell me a joke about bicycles",
},
],
session_id=session_id,
)

logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
print(logs_str)
assert "bicycle" in logs_str

response = agent.create_turn(
messages=[
{
"role": "user",
"content": "What is the boiling point of polyjuice?",
},
],
session_id=session_id,
)

logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
print(logs_str)
assert "-100" in logs_str
assert "CustomTool" in logs_str


def test_rag_agent(llama_stack_client, agent_config):
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
documents = [
