diff --git a/.github/ai-labeler.yml b/.github/ai-labeler.yml index 5a54bf77..b1e0d37f 100644 --- a/.github/ai-labeler.yml +++ b/.github/ai-labeler.yml @@ -1,5 +1,4 @@ labels: - # Simple form: just the name - bug - breaking change - documentation diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 15bd3fcd..86914d95 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -22,6 +22,7 @@ from prefect.context import TaskRunContext from pydantic import ( + BaseModel, Field, PydanticSchemaGenerationError, RootModel, @@ -44,6 +45,7 @@ NOTSET, ControlFlowModel, hash_objects, + safe_issubclass, unwrap, ) from controlflow.utilities.logging import get_logger @@ -624,25 +626,45 @@ def get_success_tool(self) -> Tool: "Please use a custom type or add compatibility." ) - @tool( - name=f"mark_task_{self.id}_successful", - description=f"Mark task {self.id} as successful.", - instructions=instructions, - include_return_description=False, - ) - def succeed(result: result_schema) -> str: # type: ignore - if self.is_successful(): - raise ValueError( - f"{self.friendly_name()} is already marked successful." - ) - if options: - if result not in options: - raise ValueError(f"Invalid option. Please choose one of {options}") - result = options[result] - self.mark_successful(result=result) - return f"{self.friendly_name()} marked successful." - - return succeed + # for basemodel subclasses, we accept the model properties directly as kwargs + if safe_issubclass(result_schema, BaseModel): + + def succeed(**kwargs) -> str: + self.mark_successful(result=result_schema(**kwargs)) + return f"{self.friendly_name()} marked successful." + + return Tool( + fn=succeed, + name=f"mark_task_{self.id}_successful", + description=f"Mark task {self.id} as successful.", + instructions=instructions, + parameters=result_schema.model_json_schema(), + ) + + # for all other results, we create a single `result` kwarg to capture the result + else: + + @tool( + name=f"mark_task_{self.id}_successful", + description=f"Mark task {self.id} as successful.", + instructions=instructions, + include_return_description=False, + ) + def succeed(result: result_schema) -> str: # type: ignore + if self.is_successful(): + raise ValueError( + f"{self.friendly_name()} is already marked successful." + ) + if options: + if result not in options: + raise ValueError( + f"Invalid option. Please choose one of {options}" + ) + result = options[result] + self.mark_successful(result=result) + return f"{self.friendly_name()} marked successful." + + return succeed def get_fail_tool(self) -> Tool: """ diff --git a/src/controlflow/utilities/general.py b/src/controlflow/utilities/general.py index ce8e5e07..4edfe363 100644 --- a/src/controlflow/utilities/general.py +++ b/src/controlflow/utilities/general.py @@ -77,3 +77,15 @@ class PandasSeries(ControlFlowModel): index: Optional[list[str]] = None name: Optional[str] = None dtype: Optional[str] = None + + +def safe_issubclass(cls: type, subclass: type) -> bool: + """ + `issubclass` raises a TypeError if cls is not a type. This helper function + safely checks if cls is a type and then checks if it is a subclass of + subclass. + """ + try: + return isinstance(cls, type) and issubclass(cls, subclass) + except TypeError: + return False