Skip to content

Commit 837ab75

Browse files
committed
Improve prefect task logic
1 parent f47552e commit 837ab75

File tree

4 files changed

+80
-36
lines changed

4 files changed

+80
-36
lines changed

src/controlflow/core/flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def add_task(self, task: "Task"):
4545
self._tasks[task.id] = task
4646

4747
@contextmanager
48-
def _context(self, create_prefect_flow_context: bool = True):
48+
def create_context(self, create_prefect_flow_context: bool = True):
4949
if create_prefect_flow_context:
5050
prefect_ctx = prefect_flow_context(name=self.name)
5151
else:
@@ -55,7 +55,7 @@ def _context(self, create_prefect_flow_context: bool = True):
5555

5656
def __enter__(self):
5757
# use stack so we can enter the context multiple times
58-
cm = self._context()
58+
cm = self.create_context()
5959
self._cm_stack.append(cm)
6060
return cm.__enter__()
6161

src/controlflow/core/task.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __init__(
140140
super().__init__(**kwargs)
141141

142142
self._prefect_task = PrefectTrackingTask(
143-
name=f"Working on {self.friendly_name()}",
143+
name=f"Working on {self.friendly_name()}...",
144144
description=self.instructions,
145145
tags=[self.__class__.__name__],
146146
)
@@ -307,6 +307,7 @@ async def run_once_async(self, agent: "Agent" = None, flow: "Flow" = None):
307307
controller = controlflow.Controller(tasks=[self], agents=agent, flow=flow)
308308
await controller.run_once_async()
309309

310+
@prefect.task(task_run_name=lambda _, args: f"Run {args['self'].friendly_name()}")
310311
def _run(
311312
self,
312313
raise_on_error: bool = True,
@@ -319,6 +320,8 @@ def _run(
319320
"""
320321
from controlflow.core.flow import Flow, get_flow
321322

323+
self._prefect_task.is_started = True
324+
322325
if max_iterations == NOTSET:
323326
max_iterations = controlflow.settings.max_task_iterations
324327
if max_iterations is None:
@@ -362,21 +365,17 @@ def run(
362365
If max_iterations is provided, the task will run at most that many times before raising an error.
363366
"""
364367

365-
@prefect.task(task_run_name=f"Run Task {self.id}")
366-
def _run():
367-
gen = self._run(
368-
raise_on_error=raise_on_error,
369-
max_iterations=max_iterations,
370-
flow=flow,
371-
run_async=False,
372-
)
373-
while True:
374-
try:
375-
next(gen)
376-
except StopIteration as e:
377-
return e.value
378-
379-
return _run()
368+
gen = self._run(
369+
raise_on_error=raise_on_error,
370+
max_iterations=max_iterations,
371+
flow=flow,
372+
run_async=False,
373+
)
374+
while True:
375+
try:
376+
next(gen)
377+
except StopIteration as e:
378+
return e.value
380379

381380
async def run_async(
382381
self,
@@ -390,31 +389,27 @@ async def run_async(
390389
If max_iterations is provided, the task will run at most that many times before raising an error.
391390
"""
392391

393-
@prefect.task(task_run_name=f"Run Task {self.id}")
394-
async def _run():
395-
gen = self._run(
396-
raise_on_error=raise_on_error,
397-
max_iterations=max_iterations,
398-
flow=flow,
399-
run_async=True,
400-
)
401-
while True:
402-
try:
403-
await next(gen)
404-
except StopIteration as e:
405-
return e.value
406-
407-
return await _run()
392+
gen = self._run(
393+
raise_on_error=raise_on_error,
394+
max_iterations=max_iterations,
395+
flow=flow,
396+
run_async=True,
397+
)
398+
while True:
399+
try:
400+
await next(gen)
401+
except StopIteration as e:
402+
return e.value
408403

409404
@contextmanager
410-
def _context(self):
405+
def create_context(self):
411406
stack = ctx.get("tasks", [])
412407
with ctx(tasks=stack + [self]):
413408
yield self
414409

415410
def __enter__(self):
416411
# use stack so we can enter the context multiple times
417-
self._cm_stack.append(self._context())
412+
self._cm_stack.append(self.create_context())
418413
return self._cm_stack[-1].__enter__()
419414

420415
def __exit__(self, *exc_info):

src/controlflow/decorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def wrapper(
101101
**flow_kwargs,
102102
)
103103

104-
with flow_obj._context(create_prefect_flow_context=False):
104+
with flow_obj.create_context(create_prefect_flow_context=False):
105105
with controlflow.instructions(instructions):
106106
result = fn(*args, **kwargs)
107107

src/controlflow/utilities/prefect.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,30 @@ def skip(self):
328328

329329

330330
def prefect_task_context(**kwargs):
331+
"""
332+
Creates a Prefect task that starts when the context is entered and ends when
333+
it closes. This is useful for creating a Prefect task that is not tied to a
334+
specific function but governs a block of code. Note that some features, like
335+
retries and caching, will not work.
336+
"""
337+
supported_kwargs = {
338+
"name",
339+
"description",
340+
"task_run_name",
341+
"tags",
342+
"version",
343+
"timeout_seconds",
344+
"log_prints",
345+
"on_completion",
346+
"on_failure",
347+
}
348+
unsupported_kwargs = set(kwargs.keys()) - set(supported_kwargs)
349+
if unsupported_kwargs:
350+
raise ValueError(
351+
f"Unsupported keyword arguments for a task context provided: "
352+
f"{unsupported_kwargs}. Consider using a @task-decorated function instead."
353+
)
354+
331355
@contextmanager
332356
@prefect.task(**kwargs)
333357
def task_context():
@@ -337,6 +361,31 @@ def task_context():
337361

338362

339363
def prefect_flow_context(**kwargs):
364+
"""
365+
Creates a Prefect flow that starts when the context is entered and ends when
366+
it closes. This is useful for creating a Prefect flow that is not tied to a
367+
specific function but governs a block of code. Note that some features, like
368+
retries and caching, will not work.
369+
"""
370+
371+
supported_kwargs = {
372+
"name",
373+
"description",
374+
"flow_run_name",
375+
"tags",
376+
"version",
377+
"timeout_seconds",
378+
"log_prints",
379+
"on_completion",
380+
"on_failure",
381+
}
382+
unsupported_kwargs = set(kwargs.keys()) - set(supported_kwargs)
383+
if unsupported_kwargs:
384+
raise ValueError(
385+
f"Unsupported keyword arguments for a flow context provided: "
386+
f"{unsupported_kwargs}. Consider using a @flow-decorated function instead."
387+
)
388+
340389
@contextmanager
341390
@prefect.flow(**kwargs)
342391
def flow_context():

0 commit comments

Comments
 (0)