Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed May 8, 2024
1 parent 50fef85 commit 709a4a7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
20 changes: 14 additions & 6 deletions src/control_flow/core/task.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import datetime
import itertools
from enum import Enum
from typing import TYPE_CHECKING, Callable, TypeVar
from typing import TYPE_CHECKING, Callable, GenericAlias, TypeVar

import marvin
import marvin.utilities.tools
from marvin.utilities.tools import FunctionTool
from pydantic import Field, TypeAdapter
from pydantic import Field, TypeAdapter, field_validator

from control_flow.utilities.logging import get_logger
from control_flow.utilities.prefect import wrap_prefect_tool
Expand All @@ -26,20 +26,24 @@ class TaskStatus(Enum):


class Task(ControlFlowModel):
model_config = dict(extra="forbid", allow_arbitrary_types=True)
model_config = dict(extra="forbid", arbitrary_types_allowed=True)
objective: str
instructions: str | None = None
agents: list["Agent"] = []
context: dict = {}
status: TaskStatus = TaskStatus.INCOMPLETE
result: T = None
result_type: type[T] | None = None
result_type: type[T] | GenericAlias | None = None
error: str | None = None
tools: list[AssistantTool | Callable] = []
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
completed_at: datetime.datetime | None = None
user_access: bool = False

@field_validator("agents", mode="before")
def _turn_none_into_empty_list(cls, v):
return v or []

def __init__(self, objective, **kwargs):
# allow objective as a positional arg
super().__init__(objective=objective, **kwargs)
Expand All @@ -48,8 +52,10 @@ def run(self, agents: list["Agent"] = None):
"""
Runs the task with provided agents for up to one cycle through the agents.
"""
from control_flow.core.agent import Agent

if not agents and not self.agents:
raise ValueError("No agents provided to run task.")
agents = [Agent()]

for agent in agents or self.agents:
if self.is_complete():
Expand All @@ -60,8 +66,10 @@ def run_until_complete(self, agents: list["Agent"] = None):
"""
Runs the task with provided agents until it is complete.
"""
from control_flow.core.agent import Agent

if not agents and not self.agents:
raise ValueError("No agents provided to run task.")
agents = [Agent()]
agents = itertools.cycle(agents or self.agents)
while self.is_incomplete():
agent = next(agents)
Expand Down
15 changes: 9 additions & 6 deletions src/control_flow/dx.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,18 @@ def wrapper(*args, _agents: list[Agent] = None, **kwargs):
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

return run_ai.with_options(name=f"Task: {fn.__name__}")(
tasks=objective,
task = Task(
objective=objective,
agents=_agents or agents,
cast=fn.__annotations__.get("return"),
context=bound.arguments,
tools=tools,
user_access=user_access,
result_type=fn.__annotations__.get("return"),
user_access=user_access or False,
tools=tools or [],
)

task.run_until_complete()
return task.result

return wrapper


Expand Down Expand Up @@ -163,7 +166,7 @@ def run_ai(
# create tasks
if tasks:
ai_tasks = [
Task[cast](
Task(
objective=t,
context=context or {},
user_access=user_access or False,
Expand Down

0 comments on commit 709a4a7

Please sign in to comment.