Skip to content

Remove max turns from task #274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ def _run_turn(self, max_calls_per_turn: Optional[int] = None):

for task in self.get_tasks("assigned"):
task.mark_running()
if task.max_turns and task._turns >= task.max_turns:
task.mark_failed(reason="Max turns exceeded.")
else:
task._turns += 1

calls = 0
while not self.turn_strategy.should_end_turn():
Expand Down Expand Up @@ -164,10 +160,6 @@ async def _run_turn_async(self, max_calls_per_turn: Optional[int] = None):

for task in self.get_tasks("assigned"):
task.mark_running()
if task.max_turns and task._turns >= task.max_turns:
task.mark_failed(reason="Max turns exceeded.")
else:
task._turns += 1

calls = 0
while not self.turn_strategy.should_end_turn():
Expand Down
5 changes: 0 additions & 5 deletions src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ class Settings(ControlFlowSettings):
)

# ------------ orchestration settings ------------
task_max_turns: Optional[int] = Field(
default=None,
description="The maximum number of agent turns allowed when attempting to run any task. "
"Turns are counted across the life of the task. If None, tasks may run indefinitely.",
)
orchestrator_max_turns: Optional[int] = Field(
default=100,
description="The maximum number of agent turns allowed when orchestrating tasks. "
Expand Down
5 changes: 0 additions & 5 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,8 @@ class Task(ControlFlowModel):
)
interactive: bool = False
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
max_turns: Optional[int] = Field(
default_factory=lambda: controlflow.settings.task_max_turns,
description="The maximum number of turns to attempt to run a task.",
)
_subtasks: set["Task"] = set()
_downstreams: set["Task"] = set()
_turns: int = 0
_cm_stack: list[contextmanager] = []

model_config = dict(extra="forbid", arbitrary_types_allowed=True)
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
def temp_controlflow_settings():
with temporary_settings(
enable_print_handler=False,
task_max_turns=10,
orchestrator_max_turns=10,
orchestrator_max_calls_per_turn=10,
):
Expand Down
48 changes: 1 addition & 47 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@
from controlflow.agents import Agent
from controlflow.flows import Flow
from controlflow.instructions import instructions
from controlflow.orchestration import turn_strategies
from controlflow.orchestration.orchestrator import Orchestrator
from controlflow.settings import temporary_settings
from controlflow.tasks.task import (
COMPLETE_STATUSES,
INCOMPLETE_STATUSES,
Task,
TaskStatus,
)
from controlflow.utilities.context import ctx
from controlflow.utilities.testing import FakeLLM, SimpleTask
from controlflow.utilities.testing import SimpleTask


def test_status_coverage():
Expand Down Expand Up @@ -448,46 +445,3 @@ def mock_run_model(*args, **kwargs):
)

assert call_count == expected_calls


class TestMaxTurns:
def test_default_max_turns(self):
with temporary_settings(task_max_turns=99):
task = Task("Test task")
assert task.max_turns == 99

def test_custom_max_turns(self):
task = Task("Test task", max_turns=10)
assert task.max_turns == 10

def test_max_turns_reached(self, default_fake_llm: FakeLLM):
task = Task("Test task", max_turns=3)

with pytest.raises(ValueError, match="Max turns exceeded"):
task.run(max_calls_per_turn=1)

assert task._turns == 3
assert task.is_failed()
assert task.result == "Max turns exceeded."

def test_max_turns_only_applies_if_task_is_ready_and_assigned_to_active_agent(
self, default_fake_llm
):
agent1 = Agent()
agent2 = Agent()
task1 = Task("Test task 1", max_turns=3, agents=[agent1])
task2 = Task("Test task 2", max_turns=3, agents=[agent2])

Orchestrator(
flow=Flow(),
tasks=[task1, task2],
agent=agent1,
turn_strategy=turn_strategies.Single(),
).run(max_calls_per_turn=1)

assert task1._turns == 3
assert task1.is_failed()
assert task1.result == "Max turns exceeded."

assert task2._turns == 0
assert task2.is_ready()