From aa0de2cd5bca9304f8868864d02ca97f3bab01d0 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:45:02 -0400 Subject: [PATCH 1/3] Pass basemodel attributes directly as kwargs --- .github/ai-labeler.yml | 1 - src/controlflow/tasks/task.py | 59 ++++++++++++++++++++++++----------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/.github/ai-labeler.yml b/.github/ai-labeler.yml index 5a54bf77..b1e0d37f 100644 --- a/.github/ai-labeler.yml +++ b/.github/ai-labeler.yml @@ -1,5 +1,4 @@ labels: - # Simple form: just the name - bug - breaking change - documentation diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 15bd3fcd..856e1283 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -22,6 +22,7 @@ from prefect.context import TaskRunContext from pydantic import ( + BaseModel, Field, PydanticSchemaGenerationError, RootModel, @@ -624,25 +625,45 @@ def get_success_tool(self) -> Tool: "Please use a custom type or add compatibility." ) - @tool( - name=f"mark_task_{self.id}_successful", - description=f"Mark task {self.id} as successful.", - instructions=instructions, - include_return_description=False, - ) - def succeed(result: result_schema) -> str: # type: ignore - if self.is_successful(): - raise ValueError( - f"{self.friendly_name()} is already marked successful." - ) - if options: - if result not in options: - raise ValueError(f"Invalid option. Please choose one of {options}") - result = options[result] - self.mark_successful(result=result) - return f"{self.friendly_name()} marked successful." - - return succeed + # for basemodel subclasses, we accept the model properties directly as kwargs + if isinstance(result_schema, type) and issubclass(result_schema, BaseModel): + + def succeed(**kwargs) -> str: + self.mark_successful(result=result_schema(**kwargs)) + return f"{self.friendly_name()} marked successful." + + return Tool( + fn=succeed, + name=f"mark_task_{self.id}_successful", + description=f"Mark task {self.id} as successful.", + instructions=instructions, + parameters=result_schema.model_json_schema(), + ) + + # for all other results, we create a single `result` kwarg to capture the result + else: + + @tool( + name=f"mark_task_{self.id}_successful", + description=f"Mark task {self.id} as successful.", + instructions=instructions, + include_return_description=False, + ) + def succeed(result: result_schema) -> str: # type: ignore + if self.is_successful(): + raise ValueError( + f"{self.friendly_name()} is already marked successful." + ) + if options: + if result not in options: + raise ValueError( + f"Invalid option. Please choose one of {options}" + ) + result = options[result] + self.mark_successful(result=result) + return f"{self.friendly_name()} marked successful." + + return succeed def get_fail_tool(self) -> Tool: """ From a18cb505aaf602c2fbfb2345b618d4b3998a8d9c Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:58:19 -0400 Subject: [PATCH 2/3] use safe_issubclass --- src/controlflow/tasks/task.py | 3 ++- src/controlflow/utilities/general.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 856e1283..86914d95 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -45,6 +45,7 @@ NOTSET, ControlFlowModel, hash_objects, + safe_issubclass, unwrap, ) from controlflow.utilities.logging import get_logger @@ -626,7 +627,7 @@ def get_success_tool(self) -> Tool: ) # for basemodel subclasses, we accept the model properties directly as kwargs - if isinstance(result_schema, type) and issubclass(result_schema, BaseModel): + if safe_issubclass(result_schema, BaseModel): def succeed(**kwargs) -> str: self.mark_successful(result=result_schema(**kwargs)) diff --git a/src/controlflow/utilities/general.py b/src/controlflow/utilities/general.py index ce8e5e07..08722c7d 100644 --- a/src/controlflow/utilities/general.py +++ b/src/controlflow/utilities/general.py @@ -77,3 +77,10 @@ class PandasSeries(ControlFlowModel): index: Optional[list[str]] = None name: Optional[str] = None dtype: Optional[str] = None + + +def safe_issubclass(cls: type, subclass: type) -> bool: + try: + return isinstance(cls, type) and issubclass(cls, subclass) + except TypeError: + return False From dbcf7a28f1dc49bbfcbe54150a4a3ead58716e22 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 29 Oct 2024 18:00:09 -0400 Subject: [PATCH 3/3] Update general.py --- src/controlflow/utilities/general.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/controlflow/utilities/general.py b/src/controlflow/utilities/general.py index 08722c7d..4edfe363 100644 --- a/src/controlflow/utilities/general.py +++ b/src/controlflow/utilities/general.py @@ -80,6 +80,11 @@ class PandasSeries(ControlFlowModel): def safe_issubclass(cls: type, subclass: type) -> bool: + """ + `issubclass` raises a TypeError if cls is not a type. This helper function + safely checks if cls is a type and then checks if it is a subclass of + subclass. + """ try: return isinstance(cls, type) and issubclass(cls, subclass) except TypeError: