diff --git a/lilac/tasks.py b/lilac/tasks.py index 43a453dc..955f5dab 100644 --- a/lilac/tasks.py +++ b/lilac/tasks.py @@ -28,11 +28,18 @@ TaskFn = Union[Callable[..., Any], Callable[..., Awaitable[Any]]] +class CancellationError(Exception): + """An error raised when a task is cancelled.""" + + pass + + class TaskStatus(str, Enum): """Enum holding a tasks status.""" PENDING = 'pending' COMPLETED = 'completed' + CANCELLED = 'cancelled' ERROR = 'error' @@ -127,6 +134,8 @@ def launch_task(self, task_id: TaskId, run_fn: Callable[..., Any]) -> None: def _wrapper() -> None: try: run_fn() + except CancellationError: + self.set_cancelled(task_id) except Exception as e: log(e) self.set_error(task_id, str(e)) @@ -157,6 +166,12 @@ def report_progress(self, task_id: TaskId, progress: int) -> None: elapsed = pretty_timedelta(timedelta(seconds=elapsed_sec)) task.details = f'{progress:,}/{task.total_len:,} [{elapsed} {ex_per_sec:,.2f} ex/s]' + def set_cancelled(self, task_id: TaskId) -> None: + """Mark a task as cancelled.""" + task = self._task_info[task_id] + task.status = TaskStatus.CANCELLED + task.end_timestamp = datetime.now().isoformat() + def set_error(self, task_id: TaskId, error: str) -> None: """Mark a task as errored.""" task = self._task_info[task_id] @@ -212,15 +227,19 @@ def progress_reporter(it: Iterator[TProgress]) -> Iterator[TProgress]: try: for item in tqdm(it, initial=progress, total=task_info.total_len, desc=task_info.description): if task_manager._task_stopped[task_id]: - raise AssertionError('Task cancelled successfully!') + raise CancellationError('Task cancelled successfully!') progress += 1 if progress % 100 == 0: task_manager.report_progress(task_id, progress) yield item + except CancellationError: + task_manager.set_cancelled(task_id) + raise except Exception as e: task_manager.set_error(task_id, str(e)) raise e - task_manager.set_completed(task_id) + else: + task_manager.set_completed(task_id) return progress_reporter