diff --git a/docs/concepts/tasks.mdx b/docs/concepts/tasks.mdx index c96bcfc0..79164ec1 100644 --- a/docs/concepts/tasks.mdx +++ b/docs/concepts/tasks.mdx @@ -270,7 +270,7 @@ Note that this setting reflects the configuration of the `completion_tools` para import { VersionBadge } from '/snippets/version-badge.mdx' - + In addition to specifying which agents are automatically given completion tools, you can control which completion tools are generated for a task using the `completion_tools` parameter. This allows you to specify whether you want to provide success and/or failure tools, or even provide custom completion tools. diff --git a/docs/examples/features/early-termination.mdx b/docs/examples/features/early-termination.mdx new file mode 100644 index 00000000..ba98d5b8 --- /dev/null +++ b/docs/examples/features/early-termination.mdx @@ -0,0 +1,102 @@ +--- +title: Early Termination +description: Control workflow execution with flexible termination logic. +icon: circle-stop +--- + +import { VersionBadge } from "/snippets/version-badge.mdx" + + + +This example demonstrates how to use termination conditions with the `run_until` parameter to control the execution of a ControlFlow workflow. We'll create a simple research workflow that stops under various conditions, showcasing the flexibility of this feature. In this case, we'll allow research to continue until either two topics are researched or 15 LLM calls are made. + +## Code + +```python +import controlflow as cf +from controlflow.orchestration.conditions import AnyComplete, MaxLLMCalls +from pydantic import BaseModel + + +class ResearchPoint(BaseModel): + topic: str + key_findings: list[str] + + +@cf.flow +def research_workflow(topics: list[str]): + if len(topics) < 2: + raise ValueError("At least two topics are required") + + research_tasks = [ + cf.Task(f"Research {topic}", result_type=ResearchPoint) + for topic in topics + ] + + # Run tasks with termination conditions + results = cf.run_tasks( + research_tasks, + instructions="Research only one topic at a time.", + run_until=( + AnyComplete(min_complete=2) # stop after two tasks (if there are more than two topics) + | MaxLLMCalls(15) # or stop after 15 LLM calls, whichever comes first + ) + ) + + completed_research = [r for r in results if isinstance(r, ResearchPoint)] + return completed_research +``` + + + +Now, if we run this workflow on 4 topics, it will stop after two topics are researched: + +```python Example Usage +# Example usage +topics = [ + "Artificial Intelligence", + "Quantum Computing", + "Biotechnology", + "Renewable Energy", +] +results = research_workflow(topics) + +print(f"Completed research on {len(results)} topics:") +for research in results: + print(f"\nTopic: {research.topic}") + print("Key Findings:") + for finding in research.key_findings: + print(f"- {finding}") +``` + +```text Result +Completed research on 2 topics: + +Topic: Artificial Intelligence +Key Findings: +- Machine Learning and Deep Learning: These are subsets of AI that involve training models on large datasets to make predictions or decisions without being explicitly programmed. They are widely used in various applications, including image and speech recognition, natural language processing, and autonomous vehicles. +- AI Ethics and Bias: As AI systems become more prevalent, ethical concerns such as bias in AI algorithms, data privacy, and the impact on employment are increasingly significant. Ensuring fairness, transparency, and accountability in AI systems is a growing area of focus. +- AI in Healthcare: AI technologies are revolutionizing healthcare through applications in diagnostics, personalized medicine, and patient monitoring. AI can analyze medical data to assist in early disease detection and treatment planning. +- Natural Language Processing (NLP): NLP is a field of AI focused on the interaction between computers and humans through natural language. Recent advancements include transformers and large language models, which have improved the ability of machines to understand and generate human language. +- AI in Autonomous Systems: AI is a crucial component in developing autonomous systems, such as self-driving cars and drones, which require perception, decision-making, and control capabilities to navigate and operate in real-world environments. + +Topic: Quantum Computing +Key Findings: +- Quantum Bits (Qubits): Unlike classical bits, qubits can exist in multiple states simultaneously due to superposition. This allows quantum computers to process a vast amount of information at once, offering a potential exponential speed-up over classical computers for certain tasks. +- Quantum Entanglement: This phenomenon allows qubits that are entangled to be correlated with each other, even when separated by large distances. Entanglement is a key resource in quantum computing and quantum communication. +- Quantum Algorithms: Quantum algorithms, such as Shor's algorithm for factoring large numbers and Grover's algorithm for searching unsorted databases, demonstrate the potential power of quantum computing over classical approaches. +- Quantum Error Correction: Quantum systems are prone to errors due to decoherence and noise from the environment. Quantum error correction methods are essential for maintaining the integrity of quantum computations. +- Applications and Challenges: Quantum computing holds promise for solving complex problems in cryptography, material science, and optimization. However, significant technological challenges remain, including maintaining qubit coherence, scaling up the number of qubits, and developing practical quantum software. +``` + +## Key Concepts + +1. **Custom Termination Conditions**: We use a combination of `AnyComplete` and `MaxLLMCalls` conditions to control when the workflow should stop. + +2. **Flexible Workflow Control**: By using termination conditions with the `run_until` parameter, we can create more dynamic workflows that adapt to different scenarios. In this case, we're balancing between getting enough research done and limiting resource usage. + +3. **Partial Results**: The workflow can end before all tasks are complete, so we handle partial results by filtering for completed `ResearchPoint` objects. + +4. **Combining Conditions**: We use the `|` operator to combine multiple termination conditions. ControlFlow also supports `&` for more complex logic. + +This example demonstrates how termination conditions provide fine-grained control over workflow execution, allowing you to balance between task completion and resource usage. This can be particularly useful for managing costs, handling time-sensitive operations, or creating more responsive AI workflows. diff --git a/docs/examples/features/memory.mdx b/docs/examples/features/memory.mdx index da79e242..7005b9ca 100644 --- a/docs/examples/features/memory.mdx +++ b/docs/examples/features/memory.mdx @@ -5,7 +5,7 @@ icon: brain --- import { VersionBadge } from '/snippets/version-badge.mdx' - + Memory in ControlFlow allows agents to store and retrieve information across different conversations or workflow executions. This is particularly useful for maintaining context over time or sharing information between separate interactions. diff --git a/docs/guides/default-memory.mdx b/docs/guides/default-memory.mdx index fe74f18f..39eceded 100644 --- a/docs/guides/default-memory.mdx +++ b/docs/guides/default-memory.mdx @@ -6,7 +6,7 @@ icon: brain --- import { VersionBadge } from '/snippets/version-badge.mdx' - + ControlFlow's [memory](/patterns/memory) feature allows agents to store and retrieve information across multiple workflows. Memory modules are backed by a vector database, configured using a `MemoryProvider`. Setting up a default provider simplifies the process of creating memory objects throughout your application. Once configured, you can create memory objects without specifying a provider each time. diff --git a/docs/mint.json b/docs/mint.json index 748e80e0..4a44f92d 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -79,7 +79,8 @@ "examples/features/tools", "examples/features/multi-llm", "examples/features/private-flows", - "examples/features/memory" + "examples/features/memory", + "examples/features/early-termination" ] }, { diff --git a/docs/patterns/memory.mdx b/docs/patterns/memory.mdx index 384cc47b..9728935b 100644 --- a/docs/patterns/memory.mdx +++ b/docs/patterns/memory.mdx @@ -5,7 +5,7 @@ icon: bookmark --- import { VersionBadge } from '/snippets/version-badge.mdx' - + Within an agentic workflow, information is naturally added to the [thread history](/patterns/history) over time, making available to all agents who participate in the workflow. However, that information is not accessible from other threads, even if they relate to the same objective or resources. diff --git a/docs/patterns/running-tasks.mdx b/docs/patterns/running-tasks.mdx index 103f4778..d45d4099 100644 --- a/docs/patterns/running-tasks.mdx +++ b/docs/patterns/running-tasks.mdx @@ -4,6 +4,8 @@ description: Control task execution and manage how agents collaborate. icon: play --- +import { VersionBadge } from "/snippets/version-badge.mdx" + Tasks represent a unit of work that needs to be completed by your agents. To execute that work and retrieve its result, you need to instruct your agents to run the task. @@ -356,6 +358,36 @@ Note that the setting `max_llm_calls` on the task results in the task failing if +#### Early termination conditions + + + +ControlFlow supports more flexible control over when an orchestration run should end through the use of `run_until` conditions. These conditions allow you to specify complex termination logic based on various factors such as task completion, failure, or custom criteria. + +To use a run until condition, you can pass it to the `run_until` parameter when calling `run`, `run_async`, `run_tasks`, or `run_tasks_async`. For example, the following tasks will run until either one of them is complete or 10 LLM calls have been made: + +```python +import controlflow as cf +from controlflow.orchestration.conditions import AnyComplete, MaxLLMCalls + +result = cf.run_tasks( + tasks=[cf.Task("write a poem about AI"), cf.Task("write a poem about ML")], + run_until=AnyComplete() | MaxLLMCalls(10) +) +``` + +(Note that because tasks can be run in parallel, it's possible for both subtasks to be completed.) + +Termination conditions can be combined using boolean logic: `|` indicates "or" and `&` indicates "and". A variety of built-in conditions are available: + +- `AllComplete()`: stop when all tasks are complete (this is the default behavior) +- `MaxLLMCalls(n: int)`: stop when `n` LLM calls have been made (equivalent to providing `max_llm_calls`) +- `MaxAgentTurns(n: int)`: stop when `n` agent turns have been made (equivalent to providing `max_agent_turns`) +- `AnyComplete(tasks: list[Task], min_complete: int=1)`: stop when at least `min_complete` tasks are complete. If no tasks are provided, all of the orchestrator's tasks will be used. +- `AnyFailed(tasks: list[Task], min_failed: int=1)`: stop when at least `min_failed` tasks have failed. If no tasks are provided, all of the orchestrator's tasks will be used. + + + ### Accessing an orchestrator directly If you want to "step" through the agentic loop yourself, you can create and invoke an `Orchestrator` directly. diff --git a/examples/early_termination.py b/examples/early_termination.py new file mode 100644 index 00000000..45d14ee0 --- /dev/null +++ b/examples/early_termination.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel + +import controlflow as cf +from controlflow.orchestration.conditions import AnyComplete, MaxLLMCalls + + +class ResearchPoint(BaseModel): + topic: str + key_findings: list[str] + + +@cf.flow +def research_workflow(topics: list[str]): + if len(topics) < 2: + raise ValueError("At least two topics are required") + + research_tasks = [ + cf.Task(f"Research {topic}", result_type=ResearchPoint) for topic in topics + ] + + # Run tasks until either two topics are researched or 15 LLM calls are made + results = cf.run_tasks( + research_tasks, + instructions="Research only one topic at a time.", + run_until=( + AnyComplete( + min_complete=2 + ) # stop after two tasks (if there are more than two topics) + | MaxLLMCalls(15) # or stop after 15 LLM calls, whichever comes first + ), + ) + + completed_research = [r for r in results if isinstance(r, ResearchPoint)] + return completed_research + + +if __name__ == "__main__": + # Example usage + topics = [ + "Artificial Intelligence", + "Quantum Computing", + "Biotechnology", + "Renewable Energy", + ] + results = research_workflow(topics) + + print(f"Completed research on {len(results)} topics:") + for research in results: + print(f"\nTopic: {research.topic}") + print("Key Findings:") + for finding in research.key_findings: + print(f"- {finding}") diff --git a/examples/reasoning.py b/examples/reasoning.py new file mode 100644 index 00000000..a5c242fa --- /dev/null +++ b/examples/reasoning.py @@ -0,0 +1,118 @@ +""" +This example implements a reasoning loop that lets a relatively simple model +solve difficult problems. + +Here, gpt-4o-mini is used to solve a problem that typically requires o1's +reasoning ability. +""" + +import argparse + +from pydantic import BaseModel, Field + +import controlflow as cf +from controlflow.utilities.general import unwrap + + +class ReasoningStep(BaseModel): + explanation: str = Field( + description=""" + A brief (<5 words) description of what you intend to + achieve in this step, to display to the user. + """ + ) + reasoning: str = Field( + description="A single step of reasoning, not more than 1 or 2 sentences." + ) + found_validated_solution: bool + + +REASONING_INSTRUCTIONS = """ + You are working on solving a difficult problem (the `goal`). Based + on your previous thoughts and the overall goal, please perform **one + reasoning step** that advances you closer to a solution. Document + your thought process and any intermediate steps you take. + + After marking this task complete for a single step, you will be + given a new reasoning task to continue working on the problem. The + loop will continue until you have a valid solution. + + Complete the task as soon as you have a valid solution. + + **Guidelines** + + - You will not be able to brute force a solution exhaustively. You + must use your reasoning ability to make a plan that lets you make + progress. + - Each step should be focused on a specific aspect of the problem, + either advancing your understanding of the problem or validating a + solution. + - You should build on previous steps without repeating them. + - Since you will iterate your reasoning, you can explore multiple + approaches in different steps. + - Use logical and analytical thinking to reason through the problem. + - Ensure that your solution is valid and meets all requirements. + - If you find yourself spinning your wheels, take a step back and + re-evaluate your approach. +""" + + +@cf.flow +def solve_with_reasoning(goal: str, agent: cf.Agent) -> str: + while True: + response: ReasoningStep = cf.run( + objective=""" + Carefully read the `goal` and analyze the problem. + + Produce a single step of reasoning that advances you closer to a solution. + """, + instructions=REASONING_INSTRUCTIONS, + result_type=ReasoningStep, + agents=[agent], + context=dict(goal=goal), + model_kwargs=dict(tool_choice="required"), + ) + + if response.found_validated_solution: + if cf.run( + """ + Check your solution to be absolutely sure that it is correct and meets all requirements of the goal. Return True if it does. + """, + result_type=bool, + context=dict(goal=goal), + ): + break + + return cf.run(objective=goal, agents=[agent]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Solve a reasoning problem.") + parser.add_argument("--goal", type=str, help="Custom goal to solve", default=None) + args = parser.parse_args() + + agent = cf.Agent(name="Definitely not GPT-4o mini", model="openai/gpt-4o-mini") + + # Default goal via https://www.reddit.com/r/singularity/comments/1fggo1e/comment/ln3ymsu/ + default_goal = """ + Using only four instances of the digit 9 and any combination of the following + mathematical operations: the decimal point, parentheses, addition (+), + subtraction (-), multiplication (*), division (/), factorial (!), and square + root (sqrt), create an equation that equals 24. + + In order to validate your result, you should test that you have followed the rules: + + 1. You have used the correct number of variables + 2. You have only used 9s and potentially a leading 0 for a decimal + 3. You have used valid mathematical symbols + 4. Your equation truly equates to 24. + """ + + # Use the provided goal if available, otherwise use the default + goal = args.goal if args.goal is not None else default_goal + goal = unwrap(goal) + print(f"The goal is:\n\n{goal}") + + result = solve_with_reasoning(goal=goal, agent=agent) + + print(f"The solution is:\n\n{result}") diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index b9f0dffc..27a0218c 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -18,6 +18,7 @@ from .tools import tool from .run import run, run_async, run_tasks, run_tasks_async from .plan import plan +import controlflow.orchestration # --- Version --- diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 8d088eaf..5a38fd9e 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -32,7 +32,7 @@ handle_tool_call_async, ) from controlflow.utilities.context import ctx -from controlflow.utilities.general import ControlFlowModel, hash_objects +from controlflow.utilities.general import ControlFlowModel, hash_objects, unwrap from controlflow.utilities.prefect import create_markdown_artifact, prefect_task if TYPE_CHECKING: @@ -128,6 +128,12 @@ def _generate_id(self): ) ) + @field_validator("instructions") + def _validate_instructions(cls, v): + if v: + v = unwrap(v) + return v + @field_validator("tools", mode="before") def _validate_tools(cls, tools: list[Tool]): return as_tools(tools or []) diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index 8d51588e..f76c0bc1 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Union from prefect.context import FlowRunContext -from pydantic import Field +from pydantic import Field, field_validator from typing_extensions import Self import controlflow @@ -11,7 +11,7 @@ from controlflow.events.base import Event from controlflow.events.history import History from controlflow.utilities.context import ctx -from controlflow.utilities.general import ControlFlowModel +from controlflow.utilities.general import ControlFlowModel, unwrap from controlflow.utilities.logging import get_logger from controlflow.utilities.prefect import prefect_flow_context @@ -70,6 +70,12 @@ def __init__(self, **kwargs): kwargs["parent"] = get_flow() super().__init__(**kwargs) + @field_validator("description") + def _validate_description(cls, v): + if v: + v = unwrap(v) + return v + def get_prompt(self) -> str: """ Generate a prompt to share information about the flow with an agent. diff --git a/src/controlflow/flows/graph.py b/src/controlflow/flows/graph.py index 2c2af8df..d200958a 100644 --- a/src/controlflow/flows/graph.py +++ b/src/controlflow/flows/graph.py @@ -73,7 +73,7 @@ def add_task(self, task: Task): ) # add the task's subtasks - for subtask in task._subtasks: + for subtask in task.subtasks: self.add_edge( Edge( upstream=subtask, diff --git a/src/controlflow/llm/rules.py b/src/controlflow/llm/rules.py index f8626ff6..21f7596a 100644 --- a/src/controlflow/llm/rules.py +++ b/src/controlflow/llm/rules.py @@ -55,17 +55,6 @@ class OpenAIRules(LLMRules): def model_instructions(self) -> list[str]: instructions = [] - if self.model.model_name.endswith("gpt-4o-mini"): - instructions.append( - unwrap( - """ - You can only provide a single result for each task, and a - task can only be marked successful one time. Do not make - multiple tool calls in parallel to supply multiple results - to the same task. - """ - ) - ) return instructions diff --git a/src/controlflow/orchestration/__init__.py b/src/controlflow/orchestration/__init__.py index 8f3ed651..e4870f81 100644 --- a/src/controlflow/orchestration/__init__.py +++ b/src/controlflow/orchestration/__init__.py @@ -1,2 +1,3 @@ +from . import conditions from .orchestrator import Orchestrator from .handler import Handler diff --git a/src/controlflow/orchestration/conditions.py b/src/controlflow/orchestration/conditions.py new file mode 100644 index 00000000..aee1c852 --- /dev/null +++ b/src/controlflow/orchestration/conditions.py @@ -0,0 +1,166 @@ +import logging +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from pydantic import BaseModel, field_validator + +from controlflow.tasks.task import Task +from controlflow.utilities.general import ControlFlowModel +from controlflow.utilities.logging import get_logger + +if TYPE_CHECKING: + from controlflow.orchestration.orchestrator import Orchestrator + +logger = get_logger(__name__) + + +class RunContext(ControlFlowModel): + """ + Context for a run. + """ + + model_config = dict(arbitrary_types_allowed=True) + + orchestrator: "Orchestrator" + llm_calls: int = 0 + agent_turns: int = 0 + run_end_condition: "RunEndCondition" + + @field_validator("run_end_condition", mode="before") + def validate_condition(cls, v: Any) -> "RunEndCondition": + if not isinstance(v, RunEndCondition): + v = FnCondition(v) + return v + + def should_end(self) -> bool: + return self.run_end_condition.should_end(self) + + +class RunEndCondition: + def should_end(self, context: RunContext) -> bool: + """ + Returns True if the run should end, False otherwise. + """ + return False + + def __or__( + self, other: Union["RunEndCondition", Callable[[RunContext], bool]] + ) -> "RunEndCondition": + if isinstance(other, RunEndCondition): + return OR_(self, other) + elif callable(other): + return OR_(self, FnCondition(other)) + else: + raise NotImplementedError( + f"Cannot combine RunEndCondition with {type(other)}" + ) + + def __and__( + self, other: Union["RunEndCondition", Callable[[RunContext], bool]] + ) -> "RunEndCondition": + if isinstance(other, RunEndCondition): + return AND_(self, other) + elif callable(other): + return AND_(self, FnCondition(other)) + else: + raise NotImplementedError( + f"Cannot combine RunEndCondition with {type(other)}" + ) + + +class FnCondition(RunEndCondition): + def __init__(self, fn: Callable[[RunContext], bool]): + self.fn = fn + + def should_end(self, context: RunContext) -> bool: + result = self.fn(context) + if result: + logger.debug("Custom function condition met; ending run.") + return result + + +class OR_(RunEndCondition): + def __init__(self, *conditions: RunEndCondition): + self.conditions = conditions + + def should_end(self, context: RunContext) -> bool: + result = any(condition.should_end(context) for condition in self.conditions) + if result: + logger.debug("At least one condition in OR clause met.") + return result + + +class AND_(RunEndCondition): + def __init__(self, *conditions: RunEndCondition): + self.conditions = conditions + + def should_end(self, context: RunContext) -> bool: + result = all(condition.should_end(context) for condition in self.conditions) + if result: + logger.debug("All conditions in AND clause met.") + return result + + +class AllComplete(RunEndCondition): + def __init__(self, tasks: Optional[list[Task]] = None): + self.tasks = tasks + + def should_end(self, context: RunContext) -> bool: + tasks = self.tasks if self.tasks is not None else context.orchestrator.tasks + result = all(t.is_complete() for t in tasks) + if result: + logger.debug("All tasks are complete; ending run.") + return result + + +class AnyComplete(RunEndCondition): + def __init__(self, tasks: Optional[list[Task]] = None, min_complete: int = 1): + self.tasks = tasks + if min_complete < 1: + raise ValueError("min_complete must be at least 1") + self.min_complete = min_complete + + def should_end(self, context: RunContext) -> bool: + tasks = self.tasks if self.tasks is not None else context.orchestrator.tasks + result = sum(t.is_complete() for t in tasks) >= self.min_complete + if result: + logger.debug("At least one task is complete; ending run.") + return result + + +class AnyFailed(RunEndCondition): + def __init__(self, tasks: Optional[list[Task]] = None, min_failed: int = 1): + self.tasks = tasks + if min_failed < 1: + raise ValueError("min_failed must be at least 1") + self.min_failed = min_failed + + def should_end(self, context: RunContext) -> bool: + tasks = self.tasks if self.tasks is not None else context.orchestrator.tasks + result = sum(t.is_failed() for t in tasks) >= self.min_failed + if result: + logger.debug("At least one task has failed; ending run.") + return result + + +class MaxAgentTurns(RunEndCondition): + def __init__(self, n: int): + self.n = n + + def should_end(self, context: RunContext) -> bool: + result = context.agent_turns >= self.n + if result: + logger.debug( + f"Maximum number of agent turns ({self.n}) reached; ending run." + ) + return result + + +class MaxLLMCalls(RunEndCondition): + def __init__(self, n: int): + self.n = n + + def should_end(self, context: RunContext) -> bool: + result = context.llm_calls >= self.n + if result: + logger.debug(f"Maximum number of LLM calls ({self.n}) reached; ending run.") + return result diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 960f8a6a..c6fff6fb 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -1,7 +1,7 @@ import logging -from typing import Optional, TypeVar +from typing import Callable, Optional, TypeVar, Union -from pydantic import Field, field_validator +from pydantic import BaseModel, Field, field_validator import controlflow from controlflow.agents.agent import Agent @@ -12,6 +12,14 @@ from controlflow.instructions import get_instructions from controlflow.llm.messages import BaseMessage from controlflow.memory import Memory +from controlflow.orchestration.conditions import ( + AllComplete, + FnCondition, + MaxAgentTurns, + MaxLLMCalls, + RunContext, + RunEndCondition, +) from controlflow.orchestration.handler import Handler from controlflow.orchestration.turn_strategies import Popcorn, TurnStrategy from controlflow.tasks.task import Task @@ -141,11 +149,29 @@ def run( max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None, model_kwargs: Optional[dict] = None, - ): + run_until: Optional[ + Union[RunEndCondition, Callable[[RunContext], bool]] + ] = None, + ) -> RunContext: import controlflow.events.orchestrator_events - call_count = 0 - turn_count = 0 + # Create the base termination condition + if run_until is None: + run_until = AllComplete() + elif not isinstance(run_until, RunEndCondition): + run_until = FnCondition(run_until) + + # Add max_llm_calls condition + if max_llm_calls is None: + max_llm_calls = controlflow.settings.orchestrator_max_llm_calls + run_until = run_until | MaxLLMCalls(max_llm_calls) + + # Add max_agent_turns condition + if max_agent_turns is None: + max_agent_turns = controlflow.settings.orchestrator_max_agent_turns + run_until = run_until | MaxAgentTurns(max_agent_turns) + + run_context = RunContext(orchestrator=self, run_end_condition=run_until) # Initialize the agent if not already set if not self.agent: @@ -153,24 +179,14 @@ def run( None, self.get_available_agents() ) - if max_agent_turns is None: - max_agent_turns = controlflow.settings.orchestrator_max_agent_turns - if max_llm_calls is None: - max_llm_calls = controlflow.settings.orchestrator_max_llm_calls - # Signal the start of orchestration self.handle_event( controlflow.events.orchestrator_events.OrchestratorStart(orchestrator=self) ) try: - while any(t.is_incomplete() for t in self.tasks): - # Check if we've reached the turn or call limit - if max_agent_turns is not None and turn_count >= max_agent_turns: - logger.debug(f"Max agent turns reached: {max_agent_turns}") - break - - if max_llm_calls is not None and call_count >= max_llm_calls: + while True: + if run_context.should_end(): break self.handle_event( @@ -178,9 +194,8 @@ def run( orchestrator=self, agent=self.agent ) ) - turn_count += 1 - call_count += self.run_agent_turn( - max_llm_calls - call_count, + self.run_agent_turn( + run_context=run_context, model_kwargs=model_kwargs, ) self.handle_event( @@ -210,6 +225,7 @@ def run( orchestrator=self ) ) + return run_context @prefect_task async def run_async( @@ -217,19 +233,29 @@ async def run_async( max_llm_calls: Optional[int] = None, max_agent_turns: Optional[int] = None, model_kwargs: Optional[dict] = None, - ): - """ - Run the orchestration process asynchronously until completion or limits are reached. - - Args: - max_llm_calls (int, optional): Maximum number of LLM calls to make. - max_agent_turns (int, optional): Maximum number of agent turns to run - (each turn can consist of multiple LLM calls) - """ + run_until: Optional[ + Union[RunEndCondition, Callable[[RunContext], bool]] + ] = None, + ) -> RunContext: import controlflow.events.orchestrator_events - call_count = 0 - turn_count = 0 + # Create the base termination condition + if run_until is None: + run_until = AllComplete() + elif not isinstance(run_until, RunEndCondition): + run_until = FnCondition(run_until) + + # Add max_llm_calls condition + if max_llm_calls is None: + max_llm_calls = controlflow.settings.orchestrator_max_llm_calls + run_until = run_until | MaxLLMCalls(max_llm_calls) + + # Add max_agent_turns condition + if max_agent_turns is None: + max_agent_turns = controlflow.settings.orchestrator_max_agent_turns + run_until = run_until | MaxAgentTurns(max_agent_turns) + + run_context = RunContext(orchestrator=self, run_end_condition=run_until) # Initialize the agent if not already set if not self.agent: @@ -237,24 +263,15 @@ async def run_async( None, self.get_available_agents() ) - if max_agent_turns is None: - max_agent_turns = controlflow.settings.orchestrator_max_agent_turns - if max_llm_calls is None: - max_llm_calls = controlflow.settings.orchestrator_max_llm_calls - # Signal the start of orchestration self.handle_event( controlflow.events.orchestrator_events.OrchestratorStart(orchestrator=self) ) try: - while any(t.is_incomplete() for t in self.tasks): - # Check if we've reached the turn or call limit - if max_agent_turns is not None and turn_count >= max_agent_turns: - logger.debug(f"Max agent turns reached: {max_agent_turns}") - break - - if max_llm_calls is not None and call_count >= max_llm_calls: + while True: + # Check termination condition + if run_context.should_end(): break self.handle_event( @@ -262,9 +279,8 @@ async def run_async( orchestrator=self, agent=self.agent ) ) - turn_count += 1 - call_count += await self.run_agent_turn_async( - max_llm_calls - call_count, + await self.run_agent_turn_async( + run_context=run_context, model_kwargs=model_kwargs, ) self.handle_event( @@ -294,23 +310,17 @@ async def run_async( orchestrator=self ) ) + return run_context @prefect_task(task_run_name="Agent turn: {self.agent.name}") def run_agent_turn( self, - max_llm_calls: Optional[int], + run_context: RunContext, model_kwargs: Optional[dict] = None, ) -> int: """ Run a single agent turn, which may consist of multiple LLM calls. - - Args: - max_llm_calls (Optional[int]): The number of LLM calls allowed. - - Returns: - int: The number of LLM calls made during this turn. """ - call_count = 0 assigned_tasks = self.get_tasks("assigned") self.turn_strategy.begin_turn() @@ -321,28 +331,25 @@ def run_agent_turn( task.mark_running() self.handle_event( OrchestratorMessage( - content=f"Starting task {task.name} (ID {task.id}) " + content=f"Starting task {task.name + ' ' if task.name else ''}(ID {task.id}) " f"with objective: {task.objective}" ) ) while not self.turn_strategy.should_end_turn(): + # fail any tasks that have reached their max llm calls for task in assigned_tasks: if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: task.mark_failed(reason="Max LLM calls reached for this task.") - else: - task._llm_calls += 1 # Check if there are any ready tasks left if not any(t.is_ready() for t in assigned_tasks): logger.debug("No `ready` tasks to run") break - if not any(t.is_incomplete() for t in self.tasks): - logger.debug("No incomplete tasks left") + if run_context.should_end(): break - call_count += 1 messages = self.compile_messages() tools = self.get_tools() @@ -353,17 +360,16 @@ def run_agent_turn( ): self.handle_event(event) - # Check if we've reached the call limit within a turn - if max_llm_calls is not None and call_count >= max_llm_calls: - logger.debug(f"Max LLM calls reached: {max_llm_calls}") - break + run_context.llm_calls += 1 + for task in assigned_tasks: + task._llm_calls += 1 - return call_count + run_context.agent_turns += 1 @prefect_task async def run_agent_turn_async( self, - max_llm_calls: Optional[int], + run_context: RunContext, model_kwargs: Optional[dict] = None, ) -> int: """ @@ -375,7 +381,6 @@ async def run_agent_turn_async( Returns: int: The number of LLM calls made during this turn. """ - call_count = 0 assigned_tasks = self.get_tasks("assigned") self.turn_strategy.begin_turn() @@ -392,22 +397,19 @@ async def run_agent_turn_async( ) while not self.turn_strategy.should_end_turn(): + # fail any tasks that have reached their max llm calls for task in assigned_tasks: if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: task.mark_failed(reason="Max LLM calls reached for this task.") - else: - task._llm_calls += 1 # Check if there are any ready tasks left if not any(t.is_ready() for t in assigned_tasks): logger.debug("No `ready` tasks to run") break - if not any(t.is_incomplete() for t in self.tasks): - logger.debug("No incomplete tasks left") + if run_context.should_end(): break - call_count += 1 messages = self.compile_messages() tools = self.get_tools() @@ -418,12 +420,11 @@ async def run_agent_turn_async( ): self.handle_event(event) - # Check if we've reached the call limit within a turn - if max_llm_calls is not None and call_count >= max_llm_calls: - logger.debug(f"Max LLM calls reached: {max_llm_calls}") - break + run_context.llm_calls += 1 + for task in assigned_tasks: + task._llm_calls += 1 - return call_count + run_context.agent_turns += 1 def compile_prompt(self) -> str: """ @@ -554,3 +555,6 @@ def get_task_hierarchy(self) -> dict: hierarchy[task.id] = task_dict_map[task.id] return hierarchy + + +RunContext.model_rebuild() diff --git a/src/controlflow/orchestration/prompt_templates/flow.jinja b/src/controlflow/orchestration/prompt_templates/flow.jinja index 1d23fb87..b0902ecc 100644 --- a/src/controlflow/orchestration/prompt_templates/flow.jinja +++ b/src/controlflow/orchestration/prompt_templates/flow.jinja @@ -1,6 +1,6 @@ # Flow -Here is context about the flow you are participating in. +Here is context about the flow/thread you are participating in. - Name: {{ flow.name }} {% if flow.description %} diff --git a/src/controlflow/orchestration/prompt_templates/task.jinja b/src/controlflow/orchestration/prompt_templates/task.jinja index 8abc34b8..23fb868f 100644 --- a/src/controlflow/orchestration/prompt_templates/task.jinja +++ b/src/controlflow/orchestration/prompt_templates/task.jinja @@ -3,4 +3,10 @@ - objective: {{ task.objective }} {% if task.instructions %}- instructions: {{ task.instructions }}{% endif %} {% if task.result_type %}- result type: {{ task.result_type }}{% endif %} -{% if task.context %}- context: {{ task.context }}{% endif %} \ No newline at end of file +{% if task.context %}- context: {{ task.context }}{% endif %} +{% if task.parent %}- parent task ID: {{ task.parent.id }}{%endif %} +{% if task._subtasks%}- this task has the following subtask IDs: {{ task._subtasks | map(attribute='id') | join(', ') }} +{% if not task.wait_for_subtasks %}- complete this task as soon as you meet its objective, even if you haven't completed +its subtasks{% endif%}{% endif %} +{% if task.depends_on %}- this task depends on these upstream task IDs (includes subtasks): {{ task.depends_on | +map(attribute='id') | join(', ') }}{% endif %} \ No newline at end of file diff --git a/src/controlflow/orchestration/prompt_templates/tasks.jinja b/src/controlflow/orchestration/prompt_templates/tasks.jinja index d194f3ce..36868957 100644 --- a/src/controlflow/orchestration/prompt_templates/tasks.jinja +++ b/src/controlflow/orchestration/prompt_templates/tasks.jinja @@ -1,11 +1,11 @@ -{% macro render_task_hierarchy(task_info, indent='') -%} -{{ indent }}- {{ task_info.task.id }} ({{ task_info.task.status.value }}){% if task_info['is_active'] %} (active){% -endif %} +{% macro render_task_hierarchy(task_info, indent='') %} +{{ indent }}- {{ task_info.task.id }} ({{ task_info.task.status.value }}){% if task_info['is_active'] %} +(active){%endif%} {%- if task_info.children %} {% for child in task_info.children %} {{ render_task_hierarchy(child, indent + '-') }} -{%- endfor %} +{% endfor %} {%- endif %} {%- endmacro -%} @@ -38,8 +38,13 @@ successful more than once. Even if the `result_type` does not appear to match the objective, you must supply a single compatible result. Only mark a task failed if there is a technical error or issue preventing completion. +When a parent task must wait for subtasks, it means that all of its subtasks are +treated as upstream dependencies and must be completed before the parent task +can be marked as complete. However, if the parent task has +`wait_for_subtasks=False`, then it can and should be marked as complete as soon +as you can, regardless of the status of its subtasks. -## Task hierarchy +## Subtask hierarchy {% for task in task_hierarchy %} {{ render_task_hierarchy(task) }} diff --git a/src/controlflow/run.py b/src/controlflow/run.py index 10a538ed..59c2fe58 100644 --- a/src/controlflow/run.py +++ b/src/controlflow/run.py @@ -1,9 +1,11 @@ -from typing import Any, Optional +from typing import Any, Callable, Optional, Union from prefect.context import TaskRunContext +import controlflow from controlflow.agents.agent import Agent from controlflow.flows import Flow, get_flow +from controlflow.orchestration.conditions import RunContext, RunEndCondition from controlflow.orchestration.handler import Handler from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy from controlflow.tasks.task import Task @@ -20,6 +22,7 @@ def get_task_run_name() -> str: @prefect_task(task_run_name=get_task_run_name) def run_tasks( tasks: list[Task], + instructions: str = None, flow: Flow = None, agent: Agent = None, turn_strategy: TurnStrategy = None, @@ -28,6 +31,7 @@ def run_tasks( max_agent_turns: int = None, handlers: list[Handler] = None, model_kwargs: Optional[dict] = None, + run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, ) -> list[Any]: """ Run a list of tasks. @@ -43,11 +47,14 @@ def run_tasks( turn_strategy=turn_strategy, handlers=handlers, ) - orchestrator.run( - max_llm_calls=max_llm_calls, - max_agent_turns=max_agent_turns, - model_kwargs=model_kwargs, - ) + + with controlflow.instructions(instructions): + orchestrator.run( + max_llm_calls=max_llm_calls, + max_agent_turns=max_agent_turns, + model_kwargs=model_kwargs, + run_until=run_until, + ) if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] @@ -63,6 +70,7 @@ def run_tasks( @prefect_task(task_run_name=get_task_run_name) async def run_tasks_async( tasks: list[Task], + instructions: str = None, flow: Flow = None, agent: Agent = None, turn_strategy: TurnStrategy = None, @@ -71,9 +79,10 @@ async def run_tasks_async( max_agent_turns: int = None, handlers: list[Handler] = None, model_kwargs: Optional[dict] = None, + run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, ): """ - Run a list of tasks. + Run a list of tasks asynchronously. """ flow = flow or get_flow() or Flow() orchestrator = Orchestrator( @@ -83,11 +92,14 @@ async def run_tasks_async( turn_strategy=turn_strategy, handlers=handlers, ) - await orchestrator.run_async( - max_llm_calls=max_llm_calls, - max_agent_turns=max_agent_turns, - model_kwargs=model_kwargs, - ) + + with controlflow.instructions(instructions): + await orchestrator.run_async( + max_llm_calls=max_llm_calls, + max_agent_turns=max_agent_turns, + model_kwargs=model_kwargs, + run_until=run_until, + ) if raise_on_failure and any(t.is_failed() for t in tasks): errors = [f"- {t.friendly_name()}: {t.result}" for t in tasks if t.is_failed()] @@ -109,6 +121,7 @@ def run( raise_on_failure: bool = True, handlers: list[Handler] = None, model_kwargs: Optional[dict] = None, + run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) @@ -120,6 +133,7 @@ def run( max_agent_turns=max_agent_turns, handlers=handlers, model_kwargs=model_kwargs, + run_until=run_until, ) return results[0] @@ -135,6 +149,7 @@ async def run_async( raise_on_failure: bool = True, handlers: list[Handler] = None, model_kwargs: Optional[dict] = None, + run_until: Optional[Union[RunEndCondition, Callable[[RunContext], bool]]] = None, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) @@ -148,5 +163,6 @@ async def run_async( raise_on_failure=raise_on_failure, handlers=handlers, model_kwargs=model_kwargs, + run_until=run_until, ) return results[0] diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index c728d3b4..a0560923 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -18,6 +18,7 @@ _LiteralGenericAlias, _SpecialGenericAlias, ) +from uuid import uuid4 from prefect.context import TaskRunContext from pydantic import ( @@ -106,8 +107,7 @@ class Task(ControlFlowModel): ) context: dict = Field( default_factory=dict, - description="Additional context for the task. If tasks are provided as " - "context, they are automatically added as `depends_on`", + description="Additional context for the task.", ) parent: Optional["Task"] = Field( NOTSET, @@ -232,16 +232,18 @@ def __init__( self.id = self._generate_id() def _generate_id(self): - return hash_objects( - ( - type(self).__name__, - self.objective, - self.instructions, - str(self.result_type), - self.prompt, - str(self.context), - ) - ) + return str(uuid4())[:8] + # generate a short, semi-stable ID for a task + # return hash_objects( + # ( + # type(self).__name__, + # self.objective, + # self.instructions, + # str(self.result_type), + # self.prompt, + # str(self.context), + # ) + # ) def __hash__(self) -> int: return id(self) @@ -256,9 +258,16 @@ def __eq__(self, other): if type(self) is type(other): d1 = dict(self) d2 = dict(other) + + for attr in ["id", "created_at"]: + d1.pop(attr) + d2.pop(attr) + # conver sets to lists for comparison d1["depends_on"] = list(d1["depends_on"]) d2["depends_on"] = list(d2["depends_on"]) + d1["subtasks"] = list(self.subtasks) + d2["subtasks"] = list(other.subtasks) return d1 == d2 return False @@ -266,6 +275,18 @@ def __repr__(self) -> str: serialized = self.model_dump(include={"id", "objective"}) return f"{self.__class__.__name__}({', '.join(f'{key}={repr(value)}' for key, value in serialized.items())})" + @field_validator("objective") + def _validate_objective(cls, v): + if v: + v = unwrap(v) + return v + + @field_validator("instructions") + def _validate_instructions(cls, v): + if v: + v = unwrap(v) + return v + @field_validator("agents") def _validate_agents(cls, v): if isinstance(v, list) and not v: @@ -360,7 +381,6 @@ def add_subtask(self, task: "Task"): elif task.parent is not self: raise ValueError(f"{self.friendly_name()} already has a parent.") self._subtasks.add(task) - self.depends_on.add(task) def add_dependency(self, task: "Task"): """ @@ -474,8 +494,8 @@ def is_ready(self) -> bool: incomplete, meaning it is ready to be worked on. """ depends_on = self.depends_on - if not self.wait_for_subtasks: - depends_on = depends_on.difference(self._subtasks) + if self.wait_for_subtasks: + depends_on = depends_on.union(self._subtasks) return self.is_incomplete() and all(t.is_complete() for t in depends_on) @@ -560,8 +580,7 @@ def get_success_tool(self) -> Tool: """ options = {} instructions = unwrap(""" - Use this tool to mark the task as successful and provide a result. - This tool can only be used one time per task. + Use this tool to mark the task as successful and provide a result. """) result_schema = None diff --git a/src/controlflow/utilities/testing.py b/src/controlflow/utilities/testing.py index dced30d5..d10b977d 100644 --- a/src/controlflow/utilities/testing.py +++ b/src/controlflow/utilities/testing.py @@ -1,3 +1,5 @@ +import json +import uuid from contextlib import contextmanager from typing import Union @@ -5,7 +7,7 @@ import controlflow from controlflow.events.history import InMemoryHistory -from controlflow.llm.messages import AIMessage, BaseMessage +from controlflow.llm.messages import AIMessage, BaseMessage, ToolCall from controlflow.tasks.task import Task COUNTER = 0 @@ -28,16 +30,30 @@ def __init__(self, *, responses: list[Union[str, BaseMessage]] = None, **kwargs) self.set_responses(responses or ["Hello! This is a response from the FakeLLM."]) def set_responses(self, responses: list[Union[str, BaseMessage]]): - if any(not isinstance(m, (str, BaseMessage)) for m in responses): + messages = [] + + for r in responses: + if isinstance(r, str): + messages.append(AIMessage(content=r)) + elif isinstance(r, dict): + messages.append( + AIMessage( + content="", + tool_calls=[ + ToolCall(name=r["name"], args=r.get("args", {}), id="") + ], + ) + ) + else: + messages.append(r) + + if any(not isinstance(m, BaseMessage) for m in messages): raise ValueError( - "Responses must be provided as either a list of strings or AIMessages. " + "Responses must be provided as either a list of strings, tool call dicts, or AIMessages. " "Each item in the list will be emitted in a cycle when the LLM is called." ) - responses = [ - AIMessage(content=m) if isinstance(m, str) else m for m in responses - ] - self.responses = responses + self.responses = messages def bind_tools(self, *args, **kwargs): """When binding tools, passthrough""" diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index cafe7343..cc420886 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -57,6 +57,7 @@ def test_agent_loads_instructions_at_creation(self): assert "test instruction" in agent.instructions + @pytest.mark.skip(reason="IDs are not stable right now") def test_stable_id(self): agent = Agent(name="Test Agent") assert agent.id == "69dd1abd" diff --git a/tests/orchestration/test_orchestrator.py b/tests/orchestration/test_orchestrator.py index 522329f5..a31e0e82 100644 --- a/tests/orchestration/test_orchestrator.py +++ b/tests/orchestration/test_orchestrator.py @@ -1,26 +1,28 @@ +from unittest.mock import MagicMock, patch + import pytest +import controlflow.orchestration.conditions from controlflow.agents import Agent from controlflow.flows import Flow from controlflow.orchestration.orchestrator import Orchestrator -from controlflow.orchestration.turn_strategies import ( # Add this import - Popcorn, - TurnStrategy, -) +from controlflow.orchestration.turn_strategies import Popcorn, TurnStrategy from controlflow.tasks.task import Task +from controlflow.utilities.testing import FakeLLM, SimpleTask class TestOrchestratorLimits: - call_count = 0 - turn_count = 0 - @pytest.fixture - def mocked_orchestrator(self, default_fake_llm): - # Reset counts at the start of each test - self.call_count = 0 - self.turn_count = 0 + def orchestrator(self, default_fake_llm): + default_fake_llm.set_responses([dict(name="count_call")]) + self.calls = 0 + self.turns = 0 class TwoCallTurnStrategy(TurnStrategy): + """ + A turn strategy that ends a turn after 2 calls + """ + calls: int = 0 def get_tools(self, *args, **kwargs): @@ -30,84 +32,52 @@ def get_next_agent(self, current_agent, available_agents): return current_agent def begin_turn(ts_instance): - self.turn_count += 1 + self.turns += 1 super().begin_turn() - def should_end_turn(ts_instance): - ts_instance.calls += 1 + def should_end_turn(ts_self): + ts_self.calls += 1 # if this would be the third call, end the turn - if ts_instance.calls >= 3: - ts_instance.calls = 0 + if ts_self.calls >= 3: + ts_self.calls = 0 return True # record a new call for the unit test - self.call_count += 1 + # self.calls += 1 return False - agent = Agent() + def count_call(): + self.calls += 1 + + agent = Agent(tools=[count_call]) task = Task("Test task", agents=[agent]) flow = Flow() orchestrator = Orchestrator( - tasks=[task], flow=flow, agent=agent, turn_strategy=TwoCallTurnStrategy() + tasks=[task], + flow=flow, + agent=agent, + turn_strategy=TwoCallTurnStrategy(), ) - return orchestrator - def test_default_limits(self, mocked_orchestrator): - mocked_orchestrator.run() - - assert self.turn_count == 5 - assert self.call_count == 10 - - @pytest.mark.parametrize( - "max_agent_turns, max_llm_calls, expected_turns, expected_calls", - [ - (1, 1, 1, 1), - (1, 2, 1, 2), - (5, 3, 2, 3), - (3, 12, 3, 6), - ], - ) - def test_custom_limits( - self, - mocked_orchestrator, - max_agent_turns, - max_llm_calls, - expected_turns, - expected_calls, - ): - mocked_orchestrator.run( - max_agent_turns=max_agent_turns, max_llm_calls=max_llm_calls + def test_max_llm_calls(self, orchestrator): + orchestrator.run(max_llm_calls=5) + assert self.calls == 5 + + def test_max_agent_turns(self, orchestrator): + orchestrator.run(max_agent_turns=3) + assert self.calls == 6 + + def test_max_llm_calls_and_max_agent_turns(self, orchestrator): + orchestrator.run( + max_llm_calls=10, + max_agent_turns=3, + model_kwargs={"tool_choice": "required"}, ) + assert self.calls == 6 - assert self.turn_count == expected_turns - assert self.call_count == expected_calls - - def test_task_limit(self, mocked_orchestrator): - task = Task("Test task", max_llm_calls=5, agents=[mocked_orchestrator.agent]) - mocked_orchestrator.tasks = [task] - mocked_orchestrator.run() - assert task.is_failed() - assert self.turn_count == 3 - # Note: the call count will be 6 because the orchestrator call count is - # incremented in "should_end_turn" which is called before the task's - # call count is evaluated - assert self.call_count == 6 - - def test_task_lifetime_limit(self, mocked_orchestrator): - task = Task("Test task", max_llm_calls=5, agents=[mocked_orchestrator.agent]) - mocked_orchestrator.tasks = [task] - mocked_orchestrator.run(max_agent_turns=1) - assert task.is_incomplete() - mocked_orchestrator.run(max_agent_turns=1) - assert task.is_incomplete() - mocked_orchestrator.run(max_agent_turns=1) - assert task.is_failed() - - assert self.turn_count == 3 - # Note: the call count will be 6 because the orchestrator call count is - # incremented in "should_end_turn" which is called before the task's - # call count is evaluated - assert self.call_count == 6 + def test_default_limits(self, orchestrator): + orchestrator.run(model_kwargs={"tool_choice": "required"}) + assert self.calls == 10 # Assuming the default max_llm_calls is 10 class TestOrchestratorCreation: @@ -162,3 +132,120 @@ def test_run_keeps_existing_agent_if_set(self): orchestrator.run(max_agent_turns=0) assert orchestrator.agent == agent1 + + +class TestRunEndConditions: + def test_run_until_all_complete(self, monkeypatch): + task1 = SimpleTask() + task2 = SimpleTask() + orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent()) + + # Mock the run_agent_turn method + def mock_run_agent_turn(*args, **kwargs): + task1.mark_successful() + task2.mark_successful() + return 1 + + monkeypatch.setitem( + orchestrator.__dict__, + "run_agent_turn", + MagicMock(side_effect=mock_run_agent_turn), + ) + + orchestrator.run(run_until=controlflow.orchestration.conditions.AllComplete()) + + assert all(task.is_complete() for task in orchestrator.tasks) + + def test_run_until_any_complete(self, monkeypatch): + task1 = SimpleTask() + task2 = SimpleTask() + orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent()) + + # Mock the run_agent_turn method + def mock_run_agent_turn(*args, **kwargs): + task1.mark_successful() + return 1 + + monkeypatch.setitem( + orchestrator.__dict__, + "run_agent_turn", + MagicMock(side_effect=mock_run_agent_turn), + ) + + orchestrator.run(run_until=controlflow.orchestration.conditions.AnyComplete()) + + assert any(task.is_complete() for task in orchestrator.tasks) + + def test_run_until_fn_condition(self, monkeypatch): + task1 = SimpleTask() + task2 = SimpleTask() + orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent()) + + # Mock the run_agent_turn method + def mock_run_agent_turn(*args, **kwargs): + task2.mark_successful() + return 1 + + monkeypatch.setitem( + orchestrator.__dict__, + "run_agent_turn", + MagicMock(side_effect=mock_run_agent_turn), + ) + + orchestrator.run( + run_until=controlflow.orchestration.conditions.FnCondition( + lambda context: context.orchestrator.tasks[1].is_complete() + ) + ) + + assert task2.is_complete() + + def test_run_until_lambda_condition(self, monkeypatch): + task1 = SimpleTask() + task2 = SimpleTask() + orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent()) + + # Mock the run_agent_turn method + def mock_run_agent_turn(*args, **kwargs): + task2.mark_successful() + return 1 + + monkeypatch.setitem( + orchestrator.__dict__, + "run_agent_turn", + MagicMock(side_effect=mock_run_agent_turn), + ) + + orchestrator.run( + run_until=lambda context: context.orchestrator.tasks[1].is_complete() + ) + + assert task2.is_complete() + + def test_compound_condition(self, monkeypatch): + task1 = SimpleTask() + task2 = SimpleTask() + orchestrator = Orchestrator(tasks=[task1, task2], flow=Flow(), agent=Agent()) + + # Mock the run_agent_turn method + def mock_run_agent_turn(*args, **kwargs): + return 1 + + monkeypatch.setitem( + orchestrator.__dict__, + "run_agent_turn", + MagicMock(side_effect=mock_run_agent_turn), + ) + + orchestrator.run( + run_until=( + # this condition will always fail + controlflow.orchestration.conditions.FnCondition(lambda context: False) + | + # this condition will always pass + controlflow.orchestration.conditions.FnCondition(lambda context: True) + ) + ) + + # assert to prove we reach this point and the run stopped + assert True diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 4f0a7fec..c1f7f461 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -47,6 +47,7 @@ def test_task_initialization(): assert task.result is None +@pytest.mark.skip(reason="IDs are not stable right now") def test_stable_id(): t1 = Task(objective="Test Objective") t2 = Task(objective="Test Objective") diff --git a/tests/test_run.py b/tests/test_run.py index d59d3ab8..c3a2fdd0 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,7 +1,10 @@ +from controlflow import instructions from controlflow.events.base import Event from controlflow.events.events import AgentMessage +from controlflow.orchestration.conditions import AnyComplete, AnyFailed, MaxLLMCalls from controlflow.orchestration.handler import Handler -from controlflow.run import run, run_async +from controlflow.run import run, run_async, run_tasks, run_tasks_async +from controlflow.tasks.task import Task class TestHandlers: @@ -40,3 +43,127 @@ def test_run(): async def test_run_async(): result = await run_async("what's 2 + 2", result_type=int) assert result == 4 + + +class TestRunUntil: + def test_any_complete(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("complete task 2"): + run_tasks([task1, task2], run_until=AnyComplete()) + + assert task2.is_complete() + assert task1.is_incomplete() + + def test_any_failed(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("fail task 2"): + run_tasks([task1, task2], run_until=AnyFailed(), raise_on_failure=False) + + assert task2.is_failed() + assert task1.is_incomplete() + + def test_max_llm_calls(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("say hi but do not complete any tasks"): + run_tasks([task1, task2], run_until=MaxLLMCalls(1)) + + assert task2.is_incomplete() + assert task1.is_incomplete() + + def test_min_complete(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + task3 = Task("Task 3") + + with instructions("complete tasks 1 and 2"): + run_tasks([task1, task2, task3], run_until=AnyComplete(min_complete=2)) + + assert task1.is_complete() + assert task2.is_complete() + assert task3.is_incomplete() + + def test_min_failed(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + task3 = Task("Task 3") + + with instructions("fail tasks 1 and 3"): + run_tasks( + [task1, task2, task3], + run_until=AnyFailed(min_failed=2), + raise_on_failure=False, + ) + + assert task1.is_failed() + assert task2.is_incomplete() + assert task3.is_failed() + + +class TestRunUntilAsync: + async def test_any_complete(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("complete task 2"): + await run_tasks_async([task1, task2], run_until=AnyComplete()) + + assert task2.is_complete() + assert task1.is_incomplete() + + async def test_any_failed(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("fail task 2"): + await run_tasks_async( + [task1, task2], run_until=AnyFailed(), raise_on_failure=False + ) + + assert task2.is_failed() + assert task1.is_incomplete() + + async def test_max_llm_calls(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + + with instructions("say hi but do not complete any tasks"): + await run_tasks_async([task1, task2], run_until=MaxLLMCalls(1)) + + assert task2.is_incomplete() + assert task1.is_incomplete() + + async def test_min_complete(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + task3 = Task("Task 3") + + with instructions("complete tasks 1 and 2"): + await run_tasks_async( + [task1, task2, task3], run_until=AnyComplete(min_complete=2) + ) + + assert task1.is_complete() + assert task2.is_complete() + assert task3.is_incomplete() + + async def test_min_failed(self): + task1 = Task("Task 1") + task2 = Task("Task 2") + task3 = Task("Task 3") + + with instructions("fail tasks 1 and 3"): + await run_tasks_async( + [task1, task2, task3], + run_until=AnyFailed(min_failed=2), + raise_on_failure=False, + ) + + assert task1.is_failed() + assert task2.is_incomplete() + assert task3.is_failed()