Skip to content

Commit

Permalink
Improve default handling for agents and parents
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed May 16, 2024
1 parent 7b2e415 commit cbd614c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 21 deletions.
55 changes: 36 additions & 19 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,9 @@ class Task(ControlFlowModel):
instructions: Union[str, None] = Field(
None, description="Detailed instructions for completing the task."
)
agents: Union[list["Agent"], None] = Field(
agents: list["Agent"] = Field(
None,
description="The agents assigned to the task. If None, the task will use its flow's default agents.",
validate_default=True,
description="The agents assigned to the task. If not provided, agents will be inferred from the parent task, flow, or global default.",
)
context: dict = Field(
default_factory=dict,
Expand Down Expand Up @@ -143,6 +142,7 @@ def __init__(
self,
objective=None,
result_type=None,
*,
parent: "Task" = None,
**kwargs,
):
Expand All @@ -159,16 +159,37 @@ def __init__(
+ "\n".join(additional_instructions)
).strip()

super().__init__(**kwargs)

# setup up relationships
if parent is None:
parent_tasks = ctx.get("tasks", [])
parent = parent_tasks[-1] if parent_tasks else None

# set up default agents
# - if provided, use the provided agents
# - if not provided, use the parent's agents
# - if no parent, use the flow's agents
# - if no flow, use the default agent
if not kwargs.get("agents"):
from controlflow.core.agent import default_agent
from controlflow.core.flow import get_flow

if parent and parent.agents:
kwargs["agents"] = parent.agents
else:
try:
flow = get_flow()
except ValueError:
flow = None
if flow and flow.agents:
kwargs["agents"] = flow.agents
else:
kwargs["agents"] = [default_agent()]

super().__init__(**kwargs)

# register task with parent
if parent is not None:
parent.add_subtask(self)
for task in self.depends_on:
self.add_dependency(task)

def __repr__(self):
include_fields = [
Expand All @@ -192,18 +213,6 @@ def __repr__(self):

@field_validator("agents", mode="before")
def _default_agents(cls, v):
from controlflow.core.agent import default_agent
from controlflow.core.flow import get_flow

if v is None:
try:
flow = get_flow()
except ValueError:
flow = None
if flow and flow.agents:
v = flow.agents
else:
v = [default_agent()]
if not v:
raise ValueError("At least one agent is required.")
return v
Expand All @@ -222,14 +231,22 @@ def _finalize(self):
flow = get_flow()
flow.add_task(self)

# create dependencies to tasks passed in as depends_on
for task in self.depends_on:
self.add_dependency(task)

# create dependencies to tasks passed in as context
context_tasks = collect_tasks(self.context)

for task in context_tasks:
if task not in self.depends_on:
self.depends_on.append(task)

return self

def parent(self) -> Optional["Task"]:
return self._parent

@field_serializer("subtasks")
def _serialize_subtasks(subtasks: list["Task"]):
return [t.id for t in subtasks]
Expand Down
51 changes: 49 additions & 2 deletions tests/core/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import AsyncMock

import pytest
from controlflow.core.agent import Agent
from controlflow.core.agent import Agent, default_agent
from controlflow.core.flow import Flow
from controlflow.core.graph import EdgeType
from controlflow.core.task import Task, TaskStatus
Expand Down Expand Up @@ -37,7 +37,21 @@ def test_task_subtasks():
task1 = Task(objective="Task 1")
task2 = Task(objective="Task 2", parent=task1)
assert task2 in task1.subtasks
assert task2._parent == task1
assert task2.parent() is task1


def test_task_parent_context():
with Task("grandparent") as task1:
with Task("parent") as task2:
task3 = Task("child")

assert task3.parent() is task2
assert task2.parent() is task1
assert task1.parent() is None

assert task1.subtasks == [task2]
assert task2.subtasks == [task3]
assert task3.subtasks == []


def test_task_agent_assignment():
Expand All @@ -46,6 +60,39 @@ def test_task_agent_assignment():
assert agent in task.agents


def test_task_loads_agent_from_parent():
agent = Agent(name="Test Agent")
with Task("parent", agents=[agent]):
child = Task("child")

assert child.agents == [agent]


def test_task_loads_agent_from_flow():
agent = Agent(name="Test Agent")
with Flow(agents=[agent]):
task = Task("task")

assert task.agents == [agent]


def test_task_loads_agent_from_default_if_none_otherwise():
agent = default_agent()
task = Task("task")

assert task.agents == [agent]


def test_task_loads_agent_from_parent_before_flow():
agent1 = Agent(name="Test Agent 1")
agent2 = Agent(name="Test Agent 2")
with Flow(agents=[agent1]):
with Task("parent", agents=[agent2]):
child = Task("child")

assert child.agents == [agent2]


def test_task_tracking(mock_run):
with Flow() as flow:
task = Task(objective="Test objective")
Expand Down

0 comments on commit cbd614c

Please sign in to comment.