diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index 35a78952..8bcf9369 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -110,7 +110,8 @@ def add_events(self, events: list[Event]): @contextmanager def create_context(self): - with ctx(flow=self): + # creating a new flow will reset any parent task tracking + with ctx(flow=self, tasks=None): yield self diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 52564976..7a79e214 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -394,7 +394,7 @@ async def run_async( @contextmanager def create_context(self): - stack = ctx.get("tasks", []) + stack = ctx.get("tasks") or [] with ctx(tasks=stack + [self]): yield self diff --git a/src/controlflow/utilities/context.py b/src/controlflow/utilities/context.py index 4d03dda5..871221bb 100644 --- a/src/controlflow/utilities/context.py +++ b/src/controlflow/utilities/context.py @@ -71,7 +71,7 @@ def __call__(self, **kwargs: Any) -> Generator[None, None, Any]: ctx = ScopedContext( dict( flow=None, - tasks=[], + tasks=None, agent=None, orchestrator=None, tui=None, diff --git a/tests/flows/test_flows.py b/tests/flows/test_flows.py index 9da9851f..518827f7 100644 --- a/tests/flows/test_flows.py +++ b/tests/flows/test_flows.py @@ -35,9 +35,9 @@ class TestFlowContext: def test_flow_context_manager(self): with Flow() as flow: assert ctx.get("flow") == flow - assert ctx.get("tasks") == [] + assert ctx.get("tasks") is None assert ctx.get("flow") is None - assert ctx.get("tasks") == [] + assert ctx.get("tasks") is None def test_get_flow_within_context(self): with Flow() as flow: @@ -66,6 +66,17 @@ def test_get_flow_nested_contexts(self): assert get_flow() == flow1 assert get_flow() is None + def test_flow_context_resets_task_tracking(self): + parent_task = Task("Parent task") + with parent_task: + assert ctx.get("tasks") == [parent_task] + with Flow(): + assert ctx.get("tasks") is None + nested_task = Task("Nested task") + assert nested_task.parent is None + assert ctx.get("tasks") == [parent_task] + assert ctx.get("tasks") is None + class TestFlowHistory: def test_get_events_empty(self): diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 83007162..d1596807 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -22,13 +22,13 @@ def test_status_coverage(): def test_context_open_and_close(): - assert ctx.get("tasks") == [] + assert ctx.get("tasks") is None with SimpleTask() as ta: assert ctx.get("tasks") == [ta] with SimpleTask() as tb: assert ctx.get("tasks") == [ta, tb] assert ctx.get("tasks") == [ta] - assert ctx.get("tasks") == [] + assert ctx.get("tasks") is None def test_task_requires_objective():