Skip to content

Commit

Permalink
Pass basemodel attributes directly as kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Oct 29, 2024
1 parent 970c9ed commit aa0de2c
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
1 change: 0 additions & 1 deletion .github/ai-labeler.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
labels:
# Simple form: just the name
- bug
- breaking change
- documentation
Expand Down
59 changes: 40 additions & 19 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from prefect.context import TaskRunContext
from pydantic import (
BaseModel,
Field,
PydanticSchemaGenerationError,
RootModel,
Expand Down Expand Up @@ -624,25 +625,45 @@ def get_success_tool(self) -> Tool:
"Please use a custom type or add compatibility."
)

@tool(
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
include_return_description=False,
)
def succeed(result: result_schema) -> str: # type: ignore
if self.is_successful():
raise ValueError(
f"{self.friendly_name()} is already marked successful."
)
if options:
if result not in options:
raise ValueError(f"Invalid option. Please choose one of {options}")
result = options[result]
self.mark_successful(result=result)
return f"{self.friendly_name()} marked successful."

return succeed
# for basemodel subclasses, we accept the model properties directly as kwargs
if isinstance(result_schema, type) and issubclass(result_schema, BaseModel):

def succeed(**kwargs) -> str:
self.mark_successful(result=result_schema(**kwargs))
return f"{self.friendly_name()} marked successful."

return Tool(
fn=succeed,
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
parameters=result_schema.model_json_schema(),
)

# for all other results, we create a single `result` kwarg to capture the result
else:

@tool(
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
include_return_description=False,
)
def succeed(result: result_schema) -> str: # type: ignore
if self.is_successful():
raise ValueError(
f"{self.friendly_name()} is already marked successful."
)
if options:
if result not in options:
raise ValueError(
f"Invalid option. Please choose one of {options}"
)
result = options[result]
self.mark_successful(result=result)
return f"{self.friendly_name()} marked successful."

return succeed

def get_fail_tool(self) -> Tool:
"""
Expand Down

0 comments on commit aa0de2c

Please sign in to comment.