Skip to content

Commit

Permalink
result type → cast
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Apr 5, 2024
1 parent c89f3fa commit 1620a28
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def demo():
name = get_user_name()

# define an AI task inline
interests = run_ai("ask user for three interests", result_type=list[str], user_access=True)
interests = run_ai("ask user for three interests", cast=list[str], user_access=True)

# set instructions for just the next task
with instructions("no more than 8 lines"):
Expand Down
2 changes: 1 addition & 1 deletion examples/readme_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def demo():
# define an AI task inline
interests = run_ai(
"ask user for three interests",
result_type=list[str],
cast=list[str],
user_access=True,
)

Expand Down
15 changes: 9 additions & 6 deletions src/control_flow/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@
### Task {{ task.id }}
- Status: {{ task.status.value }}
- Objective: {{ task.objective }}
{% if task.instructions %}
- Additional instructions: {{ task.instructions }}
{% endif %}
{% if task.status.value == "completed" %}
- Result: {{ task.result }}
{% elif task.status.value == "failed" %}
Expand Down Expand Up @@ -503,7 +506,7 @@ def wrapper(*args, **kwargs):

return run_ai.with_options(name=f"Task: {fn.__name__}")(
task=objective,
result_type=fn.__annotations__.get("return"),
cast=fn.__annotations__.get("return"),
context=bound.arguments,
user_access=user_access,
**agent_kwargs,
Expand All @@ -517,23 +520,23 @@ def _name_from_objective():
from prefect.runtime import task_run

objective = task_run.parameters["task"]
if len(objective) > 32:
return f"Task: {objective[:32]}..."
if len(objective) > 50:
return f"Task: {objective[:50]}..."
return f"Task: {objective}"


@prefect_task(task_run_name=_name_from_objective)
def run_ai(
task: str,
result_type: T = str,
cast: T = str,
context: dict = None,
user_access: bool = None,
model: str = None,
**agent_kwargs: dict,
) -> T:
"""
Run an agent to complete a task with the given objective and context. The
response will be of the given result type.
response will be cast to the given result type.
"""

# load flow
Expand All @@ -542,7 +545,7 @@ def run_ai(
flow = AIFlow()

# create task
ai_task = AITask[result_type](objective=task, context=context)
ai_task = AITask[cast](objective=task, context=context)
flow.add_task(ai_task)

# run agent
Expand Down
4 changes: 4 additions & 0 deletions src/control_flow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class AITask(BaseModel, Generic[T]):

id: int = Field(None, validate_default=True)
objective: str
instructions: Optional[str] = None
context: dict = Field(None, validate_default=True)
status: TaskStatus = TaskStatus.PENDING
result: T = None
Expand Down Expand Up @@ -108,4 +109,7 @@ def fail(self, message: Optional[str] = None):
self.status = TaskStatus.FAILED

def get_result_type(self) -> T:
"""
Returns the `type` of the task's result field.
"""
return self.model_fields["result"].annotation

0 comments on commit 1620a28

Please sign in to comment.