From aa0de2cd5bca9304f8868864d02ca97f3bab01d0 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:45:02 -0400 Subject: [PATCH] Pass basemodel attributes directly as kwargs --- .github/ai-labeler.yml | 1 - src/controlflow/tasks/task.py | 59 ++++++++++++++++++++++++----------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/.github/ai-labeler.yml b/.github/ai-labeler.yml index 5a54bf77..b1e0d37f 100644 --- a/.github/ai-labeler.yml +++ b/.github/ai-labeler.yml @@ -1,5 +1,4 @@ labels: - # Simple form: just the name - bug - breaking change - documentation diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 15bd3fcd..856e1283 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -22,6 +22,7 @@ from prefect.context import TaskRunContext from pydantic import ( + BaseModel, Field, PydanticSchemaGenerationError, RootModel, @@ -624,25 +625,45 @@ def get_success_tool(self) -> Tool: "Please use a custom type or add compatibility." ) - @tool( - name=f"mark_task_{self.id}_successful", - description=f"Mark task {self.id} as successful.", - instructions=instructions, - include_return_description=False, - ) - def succeed(result: result_schema) -> str: # type: ignore - if self.is_successful(): - raise ValueError( - f"{self.friendly_name()} is already marked successful." - ) - if options: - if result not in options: - raise ValueError(f"Invalid option. Please choose one of {options}") - result = options[result] - self.mark_successful(result=result) - return f"{self.friendly_name()} marked successful." - - return succeed + # for basemodel subclasses, we accept the model properties directly as kwargs + if isinstance(result_schema, type) and issubclass(result_schema, BaseModel): + + def succeed(**kwargs) -> str: + self.mark_successful(result=result_schema(**kwargs)) + return f"{self.friendly_name()} marked successful." + + return Tool( + fn=succeed, + name=f"mark_task_{self.id}_successful", + description=f"Mark task {self.id} as successful.", + instructions=instructions, + parameters=result_schema.model_json_schema(), + ) + + # for all other results, we create a single `result` kwarg to capture the result + else: + + @tool( + name=f"mark_task_{self.id}_successful", + description=f"Mark task {self.id} as successful.", + instructions=instructions, + include_return_description=False, + ) + def succeed(result: result_schema) -> str: # type: ignore + if self.is_successful(): + raise ValueError( + f"{self.friendly_name()} is already marked successful." + ) + if options: + if result not in options: + raise ValueError( + f"Invalid option. Please choose one of {options}" + ) + result = options[result] + self.mark_successful(result=result) + return f"{self.friendly_name()} marked successful." + + return succeed def get_fail_tool(self) -> Tool: """