From 29f76d8de3c1a31323b65530ac3c587df2840cc8 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:03:45 -0400 Subject: [PATCH] Improve instructions generation for success tool --- src/controlflow/tasks/task.py | 54 ++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 86914d95..52fa6dcf 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -582,9 +582,7 @@ def get_success_tool(self) -> Tool: Create an agent-compatible tool for marking this task as successful. """ options = {} - instructions = unwrap(""" - Use this tool to mark the task as successful and provide a result. - """) + instructions = None result_schema = None # if the result_type is a tuple of options, then we want the LLM to provide @@ -605,12 +603,14 @@ def get_success_tool(self) -> Tool: options_str = "\n\n".join( f"Option {i}: {option}" for i, option in serialized_options.items() ) - instructions += "\n\n" + unwrap(""" - Provide a single integer as the result, corresponding to the index + instructions = unwrap( + """ + Provide a single integer as the task result, corresponding to the index of your chosen option. Your options are: {options_str} - """).format(options_str=options_str) + """ + ).format(options_str=options_str) # otherwise try to load the schema for the result type elif self.result_type is not None: @@ -628,6 +628,11 @@ def get_success_tool(self) -> Tool: # for basemodel subclasses, we accept the model properties directly as kwargs if safe_issubclass(result_schema, BaseModel): + instructions = unwrap( + f""" + Use this tool to mark the task as successful and provide a result. The result schema is: {result_schema} + """ + ) def succeed(**kwargs) -> str: self.mark_successful(result=result_schema(**kwargs)) @@ -642,7 +647,13 @@ def succeed(**kwargs) -> str: ) # for all other results, we create a single `result` kwarg to capture the result - else: + elif result_schema is not None: + instructions = unwrap( + f""" + Use this tool to mark the task as successful and provide a result. + The result schema is: {{"task_result": {result_schema}}} + """ + ) @tool( name=f"mark_task_{self.id}_successful", @@ -650,21 +661,34 @@ def succeed(**kwargs) -> str: instructions=instructions, include_return_description=False, ) - def succeed(result: result_schema) -> str: # type: ignore + def succeed(task_result: result_schema) -> str: # type: ignore if self.is_successful(): raise ValueError( f"{self.friendly_name()} is already marked successful." ) if options: - if result not in options: + if task_result not in options: raise ValueError( f"Invalid option. Please choose one of {options}" ) - result = options[result] - self.mark_successful(result=result) + task_result = options[task_result] + self.mark_successful(result=task_result) + return f"{self.friendly_name()} marked successful." + + # for no result schema, we provide a tool that takes no arguments + else: + + @tool( + name=f"mark_task_{self.id}_successful", + description=f"Mark task {self.id} as successful.", + instructions=instructions, + include_return_description=False, + ) + def succeed() -> str: + self.mark_successful() return f"{self.friendly_name()} marked successful." - return succeed + return succeed def get_fail_tool(self) -> Tool: """ @@ -673,8 +697,10 @@ def get_fail_tool(self) -> Tool: @tool( name=f"mark_task_{self.id}_failed", - description=( - f"Mark task {self.id} as failed. Only use when technical errors prevent success. Provide a detailed reason for the failure." + description=unwrap( + f"""Mark task {self.id} as failed. Only use when technical + errors prevent success. Provide a detailed reason for the + failure.""" ), include_return_description=False, )