Skip to content

Commit

Permalink
Improve instructions generation for success tool
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Oct 31, 2024
1 parent 973d663 commit 29f76d8
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,7 @@ def get_success_tool(self) -> Tool:
Create an agent-compatible tool for marking this task as successful.
"""
options = {}
instructions = unwrap("""
Use this tool to mark the task as successful and provide a result.
""")
instructions = None
result_schema = None

# if the result_type is a tuple of options, then we want the LLM to provide
Expand All @@ -605,12 +603,14 @@ def get_success_tool(self) -> Tool:
options_str = "\n\n".join(
f"Option {i}: {option}" for i, option in serialized_options.items()
)
instructions += "\n\n" + unwrap("""
Provide a single integer as the result, corresponding to the index
instructions = unwrap(
"""
Provide a single integer as the task result, corresponding to the index
of your chosen option. Your options are:
{options_str}
""").format(options_str=options_str)
"""
).format(options_str=options_str)

# otherwise try to load the schema for the result type
elif self.result_type is not None:
Expand All @@ -628,6 +628,11 @@ def get_success_tool(self) -> Tool:

# for basemodel subclasses, we accept the model properties directly as kwargs
if safe_issubclass(result_schema, BaseModel):
instructions = unwrap(
f"""
Use this tool to mark the task as successful and provide a result. The result schema is: {result_schema}
"""
)

def succeed(**kwargs) -> str:
self.mark_successful(result=result_schema(**kwargs))
Expand All @@ -642,29 +647,48 @@ def succeed(**kwargs) -> str:
)

# for all other results, we create a single `result` kwarg to capture the result
else:
elif result_schema is not None:
instructions = unwrap(
f"""
Use this tool to mark the task as successful and provide a result.
The result schema is: {{"task_result": {result_schema}}}
"""
)

@tool(
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
include_return_description=False,
)
def succeed(result: result_schema) -> str: # type: ignore
def succeed(task_result: result_schema) -> str: # type: ignore
if self.is_successful():
raise ValueError(
f"{self.friendly_name()} is already marked successful."
)
if options:
if result not in options:
if task_result not in options:
raise ValueError(
f"Invalid option. Please choose one of {options}"
)
result = options[result]
self.mark_successful(result=result)
task_result = options[task_result]
self.mark_successful(result=task_result)
return f"{self.friendly_name()} marked successful."

# for no result schema, we provide a tool that takes no arguments
else:

@tool(
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
include_return_description=False,
)
def succeed() -> str:
self.mark_successful()
return f"{self.friendly_name()} marked successful."

return succeed
return succeed

def get_fail_tool(self) -> Tool:
"""
Expand All @@ -673,8 +697,10 @@ def get_fail_tool(self) -> Tool:

@tool(
name=f"mark_task_{self.id}_failed",
description=(
f"Mark task {self.id} as failed. Only use when technical errors prevent success. Provide a detailed reason for the failure."
description=unwrap(
f"""Mark task {self.id} as failed. Only use when technical
errors prevent success. Provide a detailed reason for the
failure."""
),
include_return_description=False,
)
Expand Down

0 comments on commit 29f76d8

Please sign in to comment.