Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
brilee committed Feb 2, 2024
1 parent 72e123d commit ec0a38a
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions lilac/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,18 @@
TaskFn = Union[Callable[..., Any], Callable[..., Awaitable[Any]]]


class CancellationError(Exception):
"""An error raised when a task is cancelled."""

pass


class TaskStatus(str, Enum):
"""Enum holding a tasks status."""

PENDING = 'pending'
COMPLETED = 'completed'
CANCELLED = 'cancelled'
ERROR = 'error'


Expand Down Expand Up @@ -127,6 +134,8 @@ def launch_task(self, task_id: TaskId, run_fn: Callable[..., Any]) -> None:
def _wrapper() -> None:
try:
run_fn()
except CancellationError:
self.set_cancelled(task_id)
except Exception as e:
log(e)
self.set_error(task_id, str(e))
Expand Down Expand Up @@ -157,6 +166,12 @@ def report_progress(self, task_id: TaskId, progress: int) -> None:
elapsed = pretty_timedelta(timedelta(seconds=elapsed_sec))
task.details = f'{progress:,}/{task.total_len:,} [{elapsed} {ex_per_sec:,.2f} ex/s]'

def set_cancelled(self, task_id: TaskId) -> None:
"""Mark a task as cancelled."""
task = self._task_info[task_id]
task.status = TaskStatus.CANCELLED
task.end_timestamp = datetime.now().isoformat()

def set_error(self, task_id: TaskId, error: str) -> None:
"""Mark a task as errored."""
task = self._task_info[task_id]
Expand Down Expand Up @@ -212,15 +227,19 @@ def progress_reporter(it: Iterator[TProgress]) -> Iterator[TProgress]:
try:
for item in tqdm(it, initial=progress, total=task_info.total_len, desc=task_info.description):
if task_manager._task_stopped[task_id]:
raise AssertionError('Task cancelled successfully!')
raise CancellationError('Task cancelled successfully!')
progress += 1
if progress % 100 == 0:
task_manager.report_progress(task_id, progress)
yield item
except CancellationError:
task_manager.set_cancelled(task_id)
raise
except Exception as e:
task_manager.set_error(task_id, str(e))
raise e
task_manager.set_completed(task_id)
else:
task_manager.set_completed(task_id)

return progress_reporter

Expand Down

0 comments on commit ec0a38a

Please sign in to comment.