From 709a4a7b9aa930243bf6d908e2841f607c07f22f Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 7 May 2024 22:24:07 -0400 Subject: [PATCH] Fix typing --- src/control_flow/core/task.py | 20 ++++++++++++++------ src/control_flow/dx.py | 15 +++++++++------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/control_flow/core/task.py b/src/control_flow/core/task.py index 8a1cc89a..d6073b3d 100644 --- a/src/control_flow/core/task.py +++ b/src/control_flow/core/task.py @@ -1,12 +1,12 @@ import datetime import itertools from enum import Enum -from typing import TYPE_CHECKING, Callable, TypeVar +from typing import TYPE_CHECKING, Callable, GenericAlias, TypeVar import marvin import marvin.utilities.tools from marvin.utilities.tools import FunctionTool -from pydantic import Field, TypeAdapter +from pydantic import Field, TypeAdapter, field_validator from control_flow.utilities.logging import get_logger from control_flow.utilities.prefect import wrap_prefect_tool @@ -26,20 +26,24 @@ class TaskStatus(Enum): class Task(ControlFlowModel): - model_config = dict(extra="forbid", allow_arbitrary_types=True) + model_config = dict(extra="forbid", arbitrary_types_allowed=True) objective: str instructions: str | None = None agents: list["Agent"] = [] context: dict = {} status: TaskStatus = TaskStatus.INCOMPLETE result: T = None - result_type: type[T] | None = None + result_type: type[T] | GenericAlias | None = None error: str | None = None tools: list[AssistantTool | Callable] = [] created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) completed_at: datetime.datetime | None = None user_access: bool = False + @field_validator("agents", mode="before") + def _turn_none_into_empty_list(cls, v): + return v or [] + def __init__(self, objective, **kwargs): # allow objective as a positional arg super().__init__(objective=objective, **kwargs) @@ -48,8 +52,10 @@ def run(self, agents: list["Agent"] = None): """ Runs the task with provided agents for up to one cycle through the agents. """ + from control_flow.core.agent import Agent + if not agents and not self.agents: - raise ValueError("No agents provided to run task.") + agents = [Agent()] for agent in agents or self.agents: if self.is_complete(): @@ -60,8 +66,10 @@ def run_until_complete(self, agents: list["Agent"] = None): """ Runs the task with provided agents until it is complete. """ + from control_flow.core.agent import Agent + if not agents and not self.agents: - raise ValueError("No agents provided to run task.") + agents = [Agent()] agents = itertools.cycle(agents or self.agents) while self.is_incomplete(): agent = next(agents) diff --git a/src/control_flow/dx.py b/src/control_flow/dx.py index 499d6cdb..e061c4a7 100644 --- a/src/control_flow/dx.py +++ b/src/control_flow/dx.py @@ -100,15 +100,18 @@ def wrapper(*args, _agents: list[Agent] = None, **kwargs): bound = sig.bind(*args, **kwargs) bound.apply_defaults() - return run_ai.with_options(name=f"Task: {fn.__name__}")( - tasks=objective, + task = Task( + objective=objective, agents=_agents or agents, - cast=fn.__annotations__.get("return"), context=bound.arguments, - tools=tools, - user_access=user_access, + result_type=fn.__annotations__.get("return"), + user_access=user_access or False, + tools=tools or [], ) + task.run_until_complete() + return task.result + return wrapper @@ -163,7 +166,7 @@ def run_ai( # create tasks if tasks: ai_tasks = [ - Task[cast]( + Task( objective=t, context=context or {}, user_access=user_access or False,