Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically resolve tasks within flows #36

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions examples/task_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from controlflow import Task, flow


@flow
def book_ideas():
genre = Task("pick a genre", str)
ideas = Task(
"generate three short ideas for a book",
list[str],
context=dict(genre=genre),
)
abstract = Task(
"pick one idea and write an abstract",
str,
context=dict(ideas=ideas, genre=genre),
)
title = Task(
"pick a title",
str,
context=dict(abstract=abstract),
)

return dict(genre=genre, ideas=ideas, abstract=abstract, title=title)


if __name__ == "__main__":
result = book_ideas()
print(result)
5 changes: 3 additions & 2 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .settings import settings

from .core.flow import Flow, reset_global_flow as _reset_global_flow, flow
from .core.task import Task, task
from .core.flow import Flow, reset_global_flow as _reset_global_flow
from .core.task import Task
from .core.agent import Agent
from .core.controller.controller import Controller
from .instructions import instructions
from .decorators import flow, task

Flow.model_rebuild()
Task.model_rebuild()
Expand Down
60 changes: 0 additions & 60 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import functools
import inspect
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Union

import prefect
from marvin.beta.assistants import Thread
from openai.types.beta.threads import Message
from pydantic import Field, field_validator

import controlflow
from controlflow.utilities.context import ctx
from controlflow.utilities.logging import get_logger
from controlflow.utilities.marvin import patch_marvin
from controlflow.utilities.types import ControlFlowModel, ToolType

if TYPE_CHECKING:
Expand Down Expand Up @@ -95,59 +91,3 @@ def get_flow_messages(limit: int = None) -> list[Message]:
"""
flow = get_flow()
return flow.thread.get_messages(limit=limit)


def flow(
fn=None,
*,
thread: Thread = None,
instructions: str = None,
tools: list[ToolType] = None,
agents: list["Agent"] = None,
):
"""
A decorator that runs a function as a Flow
"""

if fn is None:
return functools.partial(
flow,
thread=thread,
tools=tools,
agents=agents,
)

sig = inspect.signature(fn)

@functools.wraps(fn)
def wrapper(
*args,
flow_kwargs: dict = None,
**kwargs,
):
# first process callargs
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

flow_kwargs = flow_kwargs or {}

if thread is not None:
flow_kwargs.setdefault("thread", thread)
if tools is not None:
flow_kwargs.setdefault("tools", tools)
if agents is not None:
flow_kwargs.setdefault("agents", agents)

p_fn = prefect.flow(fn)

flow_obj = Flow(**flow_kwargs, context=bound.arguments)

logger.info(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with ctx(flow=flow_obj), patch_marvin():
with controlflow.instructions(instructions):
return p_fn(*args, **kwargs)

return wrapper
79 changes: 41 additions & 38 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
GenericAlias,
Literal,
TypeVar,
Expand All @@ -31,6 +29,7 @@
from controlflow.utilities.context import ctx
from controlflow.utilities.logging import get_logger
from controlflow.utilities.prefect import wrap_prefect_tool
from controlflow.utilities.tasks import collect_tasks, visit_task_collection
from controlflow.utilities.types import (
NOTSET,
AssistantTool,
Expand All @@ -53,32 +52,6 @@ class TaskStatus(Enum):
SKIPPED = "skipped"


def visit_task_collection(
val: Any, fn: Callable, recursion_limit: int = 3, _counter: int = 0
) -> list["Task"]:
if _counter >= recursion_limit:
return val

if isinstance(val, dict):
result = {}
for key, value in list(val.items()):
result[key] = visit_task_collection(
value, fn=fn, recursion_limit=recursion_limit, _counter=_counter + 1
)
elif isinstance(val, (list, set, tuple)):
result = []
for item in val:
result.append(
visit_task_collection(
item, fn=fn, recursion_limit=recursion_limit, _counter=_counter + 1
)
)
elif isinstance(val, Task):
return fn(val)

return val


class Task(ControlFlowModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:5]))
objective: str = Field(
Expand Down Expand Up @@ -109,6 +82,7 @@ class Task(ControlFlowModel):
error: Union[str, None] = None
tools: list[ToolType] = []
user_access: bool = False
is_auto_completed_by_subtasks: bool = False
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
_parent: "Union[Task, None]" = None
_downstreams: list["Task"] = []
Expand Down Expand Up @@ -191,16 +165,16 @@ def _turn_list_into_literal_result_type(cls, v):

@model_validator(mode="after")
def _finalize(self):
# create dependencies to tasks passed in as context
tasks = []

def visitor(task):
tasks.append(task)
return task
# validate correlated settings
if self.result_type is not None and self.is_auto_completed_by_subtasks:
raise ValueError(
"Tasks with a result type cannot be auto-completed by their subtasks."
)

visit_task_collection(self.context, visitor)
# create dependencies to tasks passed in as context
context_tasks = collect_tasks(self.context)

for task in tasks:
for task in context_tasks:
if task not in self.depends_on:
self.depends_on.append(task)
return self
Expand Down Expand Up @@ -283,7 +257,7 @@ def run_once(self, agent: "Agent" = None):

controller.run_once()

def run(self, max_iterations: int = NOTSET) -> T:
def run(self, raise_on_error: bool = True, max_iterations: int = NOTSET) -> T:
"""
Runs the task with provided agents until it is complete.

Expand All @@ -304,7 +278,7 @@ def run(self, max_iterations: int = NOTSET) -> T:
counter += 1
if self.is_successful():
return self.result
elif self.is_failed():
elif self.is_failed() and raise_on_error:
raise ValueError(f"{self.friendly_name()} failed: {self.error}")

@contextmanager
Expand Down Expand Up @@ -394,6 +368,9 @@ def get_tools(self) -> list[ToolType]:
tools.append(marvin.utilities.tools.tool_from_function(talk_to_human))
return [wrap_prefect_tool(t) for t in tools]

def dependencies(self):
return self.depends_on + self.subtasks

def mark_successful(self, result: T = None, validate: bool = True):
if validate:
if any(t.is_incomplete() for t in self.depends_on):
Expand All @@ -418,15 +395,41 @@ def mark_successful(self, result: T = None, validate: bool = True):

self.result = result
self.status = TaskStatus.SUCCESSFUL

# attempt to complete the parent, if appropriate
if (
self._parent
and self._parent.is_auto_completed_by_subtasks
and all_complete(self._parent.dependencies())
):
self._parent.mark_successful(validate=True)

return f"{self.friendly_name()} marked successful. Updated task definition: {self.model_dump()}"

def mark_failed(self, message: Union[str, None] = None):
self.error = message
self.status = TaskStatus.FAILED

# attempt to fail the parent, if appropriate
if (
self._parent
and self._parent.is_auto_completed_by_subtasks
and all_complete(self._parent.dependencies())
):
self._parent.mark_failed()

return f"{self.friendly_name()} marked failed. Updated task definition: {self.model_dump()}"

def mark_skipped(self):
self.status = TaskStatus.SKIPPED
# attempt to complete the parent, if appropriate
if (
self._parent
and self._parent.is_auto_completed_by_subtasks
and all_complete(self._parent.dependencies())
):
self._parent.mark_successful(validate=False)

return f"{self.friendly_name()} marked skipped. Updated task definition: {self.model_dump()}"


Expand Down
Loading
Loading