Skip to content

Commit

Permalink
Merge pull request #286 from PrefectHQ/flow-parents
Browse files Browse the repository at this point in the history
Reset task parent tracking when nested flows are created
  • Loading branch information
jlowin authored Sep 6, 2024
2 parents a811d6d + c4c3091 commit fe7e638
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/controlflow/flows/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def add_events(self, events: list[Event]):

@contextmanager
def create_context(self):
with ctx(flow=self):
# creating a new flow will reset any parent task tracking
with ctx(flow=self, tasks=None):
yield self


Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ async def run_async(

@contextmanager
def create_context(self):
stack = ctx.get("tasks", [])
stack = ctx.get("tasks") or []
with ctx(tasks=stack + [self]):
yield self

Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/utilities/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __call__(self, **kwargs: Any) -> Generator[None, None, Any]:
ctx = ScopedContext(
dict(
flow=None,
tasks=[],
tasks=None,
agent=None,
orchestrator=None,
tui=None,
Expand Down
15 changes: 13 additions & 2 deletions tests/flows/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ class TestFlowContext:
def test_flow_context_manager(self):
with Flow() as flow:
assert ctx.get("flow") == flow
assert ctx.get("tasks") == []
assert ctx.get("tasks") is None
assert ctx.get("flow") is None
assert ctx.get("tasks") == []
assert ctx.get("tasks") is None

def test_get_flow_within_context(self):
with Flow() as flow:
Expand Down Expand Up @@ -66,6 +66,17 @@ def test_get_flow_nested_contexts(self):
assert get_flow() == flow1
assert get_flow() is None

def test_flow_context_resets_task_tracking(self):
parent_task = Task("Parent task")
with parent_task:
assert ctx.get("tasks") == [parent_task]
with Flow():
assert ctx.get("tasks") is None
nested_task = Task("Nested task")
assert nested_task.parent is None
assert ctx.get("tasks") == [parent_task]
assert ctx.get("tasks") is None


class TestFlowHistory:
def test_get_events_empty(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ def test_status_coverage():


def test_context_open_and_close():
assert ctx.get("tasks") == []
assert ctx.get("tasks") is None
with SimpleTask() as ta:
assert ctx.get("tasks") == [ta]
with SimpleTask() as tb:
assert ctx.get("tasks") == [ta, tb]
assert ctx.get("tasks") == [ta]
assert ctx.get("tasks") == []
assert ctx.get("tasks") is None


def test_task_requires_objective():
Expand Down

0 comments on commit fe7e638

Please sign in to comment.