Merge pull request #322 from PrefectHQ/llm-instructions
Add llm-specific prompt instructions
jlowin authored Sep 19, 2024
2 parents 91c8bc7 + 5065e09 commit bbf9c20
Showing 7 changed files with 71 additions and 12 deletions.
29 changes: 26 additions & 3 deletions src/controlflow/llm/rules.py
@@ -1,3 +1,4 @@
+import textwrap
 from typing import Optional
 
 from langchain_anthropic import ChatAnthropic
@@ -16,6 +17,8 @@ class LLMRules(ControlFlowModel):
     necessary.
     """
 
+    model: Optional[BaseChatModel]
+
     # require at least one non-system message
     require_at_least_one_message: bool = False

@@ -41,10 +44,30 @@ class LLMRules(ControlFlowModel):
     # the name associated with a message must conform to a specific format
     require_message_name_format: Optional[str] = None
 
+    def model_instructions(self) -> Optional[list[str]]:
+        pass
+
 
 class OpenAIRules(LLMRules):
     require_message_name_format: str = r"[^a-zA-Z0-9_-]"
 
+    model: ChatOpenAI
+
+    def model_instructions(self) -> list[str]:
+        instructions = []
+        if self.model.model_name.endswith("gpt-4o-mini"):
+            instructions.append(
+                textwrap.dedent(
+                    """
+                    You can only provide a single result for each task, and a
+                    task can only be marked successful one time. Do not make
+                    multiple tool calls in parallel to supply multiple results
+                    to the same task.
+                    """
+                )
+            )
+        return instructions
+
 
 class AnthropicRules(LLMRules):
     require_at_least_one_message: bool = True
@@ -56,8 +79,8 @@ class AnthropicRules(LLMRules):
 
 def rules_for_model(model: BaseChatModel) -> LLMRules:
     if isinstance(model, (ChatOpenAI, AzureChatOpenAI)):
-        return OpenAIRules()
+        return OpenAIRules(model=model)
    elif isinstance(model, ChatAnthropic):
-        return AnthropicRules()
+        return AnthropicRules(model=model)
    else:
-        return LLMRules()
+        return LLMRules(model=model)
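
The net effect of this file's changes: rules now carry a handle to the model they describe, so `model_instructions` can branch on concrete model attributes like `model_name`. A minimal usage sketch (not part of the diff; assumes `langchain-openai` is installed and an OpenAI API key is configured in the environment):

```python
# Sketch: how the new `model` field drives model-specific instructions.
from langchain_openai import ChatOpenAI

from controlflow.llm.rules import rules_for_model

model = ChatOpenAI(model="gpt-4o-mini")
rules = rules_for_model(model)  # dispatches to OpenAIRules(model=model)

# OpenAIRules.model_instructions() returns the single-result warning for
# gpt-4o-mini; the base class returns None, hence the `or []` guard.
for instruction in rules.model_instructions() or []:
    print(instruction)
```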
6 changes: 6 additions & 0 deletions src/controlflow/orchestration/orchestrator.py
@@ -392,19 +392,25 @@ def compile_prompt(self) -> str:
         """
         from controlflow.orchestration.prompt_templates import (
             InstructionsTemplate,
+            LLMInstructionsTemplate,
             TasksTemplate,
             ToolTemplate,
         )
 
         tools = self.get_tools()
+        llm_rules = self.agent.get_llm_rules()
 
         prompts = [
             self.agent.get_prompt(),
             self.flow.get_prompt(),
             TasksTemplate(tasks=self.get_tasks("ready")).render(),
             ToolTemplate(tools=tools).render(),
             InstructionsTemplate(instructions=get_instructions()).render(),
+            LLMInstructionsTemplate(
+                instructions=llm_rules.model_instructions()
+            ).render(),
         ]
 
         prompt = "\n\n".join([p for p in prompts if p])
         return prompt

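Note how the new section stays out of the prompt for models without special rules: a template that declines to render presumably contributes nothing, and the final join filters falsy entries. A self-contained sketch of that assembly pattern:

```python
# Minimal sketch of the prompt-assembly pattern above: falsy sections
# (None or "") are dropped before joining, so an empty LLM-instructions
# section leaves no trace in the compiled prompt.
sections = [
    "# Agent\nYou are the assigned agent.",
    "",  # e.g. a template whose should_render() returned False
    "# Tasks\n- Summarize the report.",
]
prompt = "\n\n".join([s for s in sections if s])
print(prompt)
```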
8 changes: 8 additions & 0 deletions src/controlflow/orchestration/prompt_templates.py
@@ -78,6 +78,14 @@ def should_render(self) -> bool:
         return bool(self.instructions)
 
 
+class LLMInstructionsTemplate(Template):
+    template_path: str = "llm_instructions.jinja"
+    instructions: Optional[list[str]] = None
+
+    def should_render(self) -> bool:
+        return bool(self.instructions)
+
+
 class ToolTemplate(Template):
     template_path: str = "tools.jinja"
     tools: list[Tool]
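Since `should_render` is just `bool(self.instructions)`, the new template renders only when at least one instruction is present. A quick sketch of the expected behavior:

```python
# Sketch of the gating behavior, following should_render's definition.
from controlflow.orchestration.prompt_templates import LLMInstructionsTemplate

assert not LLMInstructionsTemplate(instructions=None).should_render()
assert not LLMInstructionsTemplate(instructions=[]).should_render()
assert LLMInstructionsTemplate(
    instructions=["Provide a single result per task."]
).should_render()
```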
@@ -1,6 +1,6 @@
 # Instructions
 
-You must follow these instructions. Note that instructions can be changed at any time.
+You must follow these instructions at all times. Note that instructions can be changed at any time.
 
 {% for instruction in instructions %}
 - {{ instruction }}
@@ -0,0 +1,9 @@
+# LLM Instructions
+
+These instructions are specific to your LLM model. They must be followed to ensure compliance with the orchestrator and
+other agents.
+
+{% for instruction in instructions %}
+- {{ instruction }}
+
+{% endfor %}
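
Rendered with a single instruction, the new template produces a section like the following (a standalone sketch using `jinja2` directly, bypassing the project's own template loading):

```python
# Standalone render of the template body above using jinja2 directly;
# the project loads it from llm_instructions.jinja via its Template class.
import jinja2

source = (
    "# LLM Instructions\n\n"
    "These instructions are specific to your LLM model. They must be followed "
    "to ensure compliance with the orchestrator and\nother agents.\n\n"
    "{% for instruction in instructions %}\n- {{ instruction }}\n\n{% endfor %}"
)
print(jinja2.Template(source).render(
    instructions=["Provide a single result per task."]
))
```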
15 changes: 11 additions & 4 deletions src/controlflow/orchestration/prompt_templates/tasks.jinja
@@ -23,11 +23,18 @@ The following tasks are active:
 </Task>
 {% endfor %}
 
-Only agents assigned to a task are able to mark the task as complete. You must use a tool to end your turn to let other
-agents participate. If you are asked to talk to other agents, post messages. Do not impersonate another agent! Do not
-impersonate the orchestrator!
+Only agents assigned to a task are able to mark the task as complete. You must
+use a tool to end your turn to let other agents participate. If you are asked to
+talk to other agents, post messages. Do not impersonate another agent! Do not
+impersonate the orchestrator! If you have been assigned a task, then you (and
+other agents) must have the resources, knowledge, or tools required to complete
+it.
+
+A task can only be marked complete one time. Do not attempt to mark a task
+successful more than once. Even if the `result_type` does not appear to match
+the objective, you must supply a single compatible result. Only mark a task
+failed if there is a technical error or issue preventing completion.
 
-Only mark a task failed if there is a technical error or issue preventing completion.
 
 ## Task hierarchy

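The single-completion rule in this template pairs with the tool-side instructions added in task.py below. As a hypothetical illustration of the kind of guard that backs such a rule (illustrative names, not the project's actual API):

```python
# Hypothetical guard (illustrative only): refuse a second success mark.
class TaskAlreadyComplete(Exception):
    pass

def mark_successful(task, result):
    if getattr(task, "is_complete", False):
        raise TaskAlreadyComplete(f"Task {task!r} was already completed")
    task.result = result
    task.is_complete = True
```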
14 changes: 10 additions & 4 deletions src/controlflow/tasks/task.py
@@ -1,4 +1,5 @@
 import datetime
+import textwrap
 import warnings
 from contextlib import ExitStack, contextmanager
 from enum import Enum
@@ -511,7 +512,10 @@ def create_success_tool(self) -> Tool:
         Create an agent-compatible tool for marking this task as successful.
         """
         options = {}
-        instructions = None
+        instructions = textwrap.dedent("""
+            Use this tool to mark the task as successful and provide a result.
+            This tool can only be used one time per task.
+            """)
         result_schema = None
 
         # if the result_type is a tuple of options, then we want the LLM to provide
@@ -532,10 +536,12 @@ def create_success_tool(self) -> Tool:
             options_str = "\n\n".join(
                 f"Option {i}: {option}" for i, option in serialized_options.items()
             )
-            instructions = f"""
+            instructions += "\n\n" + textwrap.dedent("""
                 Provide a single integer as the result, corresponding to the index
-                of your chosen option. Your options are: {options_str}
-                """
+                of your chosen option. Your options are:
+                {options_str}
+                """).format(options_str=options_str)
 
         # otherwise try to load the schema for the result type
         elif self.result_type is not None:
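To see what the two-part instructions look like once assembled, here is a self-contained rerun of the logic above, with hypothetical options standing in for a task's serialized `result_type`:

```python
import textwrap

# Hypothetical stand-in for the task's serialized result options.
serialized_options = {0: '"approve"', 1: '"reject"'}

# Base instructions, as in create_success_tool.
instructions = textwrap.dedent("""
    Use this tool to mark the task as successful and provide a result.
    This tool can only be used one time per task.
    """)

# Option-specific addendum, mirroring the tuple-result branch.
options_str = "\n\n".join(
    f"Option {i}: {option}" for i, option in serialized_options.items()
)
instructions += "\n\n" + textwrap.dedent("""
    Provide a single integer as the result, corresponding to the index
    of your chosen option. Your options are:
    {options_str}
    """).format(options_str=options_str)

print(instructions)
```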
