diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 52564976..7a79e214 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -394,7 +394,7 @@ async def run_async( @contextmanager def create_context(self): - stack = ctx.get("tasks", []) + stack = ctx.get("tasks") or [] with ctx(tasks=stack + [self]): yield self diff --git a/src/controlflow/utilities/context.py b/src/controlflow/utilities/context.py index 4d03dda5..871221bb 100644 --- a/src/controlflow/utilities/context.py +++ b/src/controlflow/utilities/context.py @@ -71,7 +71,7 @@ def __call__(self, **kwargs: Any) -> Generator[None, None, Any]: ctx = ScopedContext( dict( flow=None, - tasks=[], + tasks=None, agent=None, orchestrator=None, tui=None, diff --git a/tests/flows/test_flows.py b/tests/flows/test_flows.py index f5a85ac3..518827f7 100644 --- a/tests/flows/test_flows.py +++ b/tests/flows/test_flows.py @@ -35,9 +35,9 @@ class TestFlowContext: def test_flow_context_manager(self): with Flow() as flow: assert ctx.get("flow") == flow - assert ctx.get("tasks") == [] + assert ctx.get("tasks") is None assert ctx.get("flow") is None - assert ctx.get("tasks") == [] + assert ctx.get("tasks") is None def test_get_flow_within_context(self): with Flow() as flow: @@ -75,7 +75,7 @@ def test_flow_context_resets_task_tracking(self): nested_task = Task("Nested task") assert nested_task.parent is None assert ctx.get("tasks") == [parent_task] - assert ctx.get("tasks") == [] + assert ctx.get("tasks") is None class TestFlowHistory: diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 83007162..d1596807 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -22,13 +22,13 @@ def test_status_coverage(): def test_context_open_and_close(): - assert ctx.get("tasks") == [] + assert ctx.get("tasks") is None with SimpleTask() as ta: assert ctx.get("tasks") == [ta] with SimpleTask() as tb: assert ctx.get("tasks") == [ta, tb] assert ctx.get("tasks") == [ta] - assert ctx.get("tasks") == [] + assert ctx.get("tasks") is None def test_task_requires_objective():