From e678bf8bc150a58defef8733e8c86d6267af8d46 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sun, 16 Jun 2024 13:55:58 -0400 Subject: [PATCH] Fix task run name --- src/controlflow/core/controller/controller.py | 6 +++++- src/controlflow/core/task.py | 10 +++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py index 82310742..71326c4d 100644 --- a/src/controlflow/core/controller/controller.py +++ b/src/controlflow/core/controller/controller.py @@ -158,7 +158,11 @@ def _setup_run(self): # start tracking tasks for task in ready_tasks: if not task._prefect_task.is_started: - task._prefect_task.start() + task._prefect_task.start( + depends_on=[ + t.result for t in task.depends_on if t.result is not None + ] + ) # if there are no ready tasks, return. This will usually happen because # all the tasks are complete. diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py index 45eae433..208a3c45 100644 --- a/src/controlflow/core/task.py +++ b/src/controlflow/core/task.py @@ -16,6 +16,7 @@ ) import prefect +from prefect.context import TaskRunContext from pydantic import ( Field, PydanticSchemaGenerationError, @@ -54,6 +55,11 @@ logger = get_logger(__name__) +def get_task_run_name() -> str: + context = TaskRunContext.get() + return f'Run {context.parameters['self'].friendly_name()}' + + class TaskStatus(Enum): INCOMPLETE = "INCOMPLETE" SUCCESSFUL = "SUCCESSFUL" @@ -307,7 +313,7 @@ async def run_once_async(self, agent: "Agent" = None, flow: "Flow" = None): controller = controlflow.Controller(tasks=[self], agents=agent, flow=flow) await controller.run_once_async() - @prefect.task(task_run_name=lambda _, args: f"Run {args['self'].friendly_name()}") + @prefect.task(task_run_name=get_task_run_name) def _run( self, raise_on_error: bool = True, @@ -320,8 +326,6 @@ def _run( """ from controlflow.core.flow import Flow, get_flow - self._prefect_task.is_started = True - if max_iterations == NOTSET: max_iterations = controlflow.settings.max_task_iterations if max_iterations is None: