diff --git a/README.md b/README.md index 36416aa..9fede58 100644 --- a/README.md +++ b/README.md @@ -204,6 +204,7 @@ class MyWidget: # Create runner once - reusable for all batches self.runner = BatchRunner( + on_start=self._on_batch_start, on_item_complete=self._on_item_complete, on_complete=self._on_batch_complete, on_error=self._on_item_error, @@ -213,6 +214,11 @@ class MyWidget: self._run_button.clicked.connect(self.run_batch) self._cancel_button.clicked.connect(self.runner.cancel) + def _on_batch_start(self, total): + """Called when batch starts with total item count.""" + self._progress_bar.setValue(0) + self._progress_bar.setMaximum(total) + def _on_item_complete(self, result, ctx): """Called after each item completes.""" self._progress_bar.setValue(ctx.index + 1) @@ -221,7 +227,11 @@ class MyWidget: self._viewer.add_image(result, name=f"Result {ctx.index}") def _on_batch_complete(self): - self._progress_bar.label = "Complete!" + errors = self.runner.error_count + if errors > 0: + self._progress_bar.label = f"Done with {errors} errors" + else: + self._progress_bar.label = "Complete!" def _on_item_error(self, ctx, exception): self._progress_bar.label = f"Error on {ctx.item.name}" @@ -231,7 +241,6 @@ class MyWidget: def run_batch(self): """Triggered by 'Run' button - just one line!""" - self._progress_bar.max = len(self.files) self.runner.run( process_image, self.files, @@ -335,6 +344,7 @@ def batch_logger( class BatchRunner: def __init__( self, + on_start: Callable[[int], None] | None = None, on_item_complete: Callable[[Any, BatchContext], None] | None = None, on_complete: Callable[[], None] | None = None, on_error: Callable[[BatchContext, Exception], None] | None = None, @@ -351,7 +361,7 @@ class BatchRunner: log_header: Mapping[str, object] | None = None, patterns: str | Sequence[str] = '*', recursive: bool = False, - **kwargs, + **kwargs, # Passed to func! ) -> None: ... def cancel(self) -> None: ... @@ -361,6 +371,9 @@ class BatchRunner: @property def was_cancelled(self) -> bool: ... + + @property + def error_count(self) -> int: ... # Errors in current/last batch ``` ## Contributing diff --git a/src/nbatch/_runner.py b/src/nbatch/_runner.py index fca9307..24df94a 100644 --- a/src/nbatch/_runner.py +++ b/src/nbatch/_runner.py @@ -42,6 +42,9 @@ class BatchRunner: Parameters ---------- + on_start : Callable[[int], None] | None, optional + Called when batch starts, receives total item count. Use to initialize + progress bars (e.g., ``on_start=lambda total: progress_bar.setMaximum(total)``). on_item_complete : Callable[[Any, BatchContext], None] | None, optional Called after each item completes successfully. Receives the result and BatchContext. Use for progress bars and adding results to viewer. @@ -60,12 +63,15 @@ class BatchRunner: True if a batch is currently being processed. was_cancelled : bool True if the last batch was cancelled before completion. + error_count : int + Number of errors encountered in the current/last batch. Examples -------- Basic usage in a napari widget: >>> runner = BatchRunner( + ... on_start=lambda total: progress_bar.setMaximum(total), ... on_item_complete=lambda r, ctx: progress_bar.setValue(ctx.index + 1), ... on_complete=lambda: print("Done!"), ... ) @@ -91,15 +97,26 @@ class BatchRunner: ... log_file="output/batch.log", ... log_header={"Input": str(input_dir), "Files": len(files)}, ... ) + + Passing additional arguments to the function (no partial needed!): + + >>> runner.run( + ... process_image, + ... files, + ... output_dir=output_path, # passed to process_image + ... sigma=2.0, # passed to process_image + ... ) """ def __init__( self, + on_start: Callable[[int], None] | None = None, on_item_complete: Callable[[Any, BatchContext], None] | None = None, on_complete: Callable[[], None] | None = None, on_error: Callable[[BatchContext, Exception], None] | None = None, on_cancel: Callable[[], None] | None = None, ): + self._on_start = on_start self._on_item_complete = on_item_complete self._on_complete = on_complete self._on_error = on_error @@ -110,6 +127,7 @@ def __init__( self._cancel_requested = False self._is_running = False self._was_cancelled = False + self._error_count = 0 self._lock = threading.Lock() # For logging within run @@ -127,6 +145,12 @@ def was_cancelled(self) -> bool: with self._lock: return self._was_cancelled + @property + def error_count(self) -> int: + """Number of errors encountered in the current/last batch.""" + with self._lock: + return self._error_count + def cancel(self) -> None: """Request cancellation of the running batch. @@ -202,10 +226,15 @@ def run( self._is_running = True self._cancel_requested = False self._was_cancelled = False + self._error_count = 0 # Normalize items to a list items_list = self._normalize_items(items, patterns, recursive) + # Call on_start callback with total count + if self._on_start is not None: + self._on_start(len(items_list)) + if threaded and HAS_NAPARI: self._run_napari_threaded( func, items_list, args, kwargs, log_file, log_header @@ -388,6 +417,8 @@ def _handle_yielded( result, ctx, error = value if error is not None: + with self._lock: + self._error_count += 1 if self._on_error is not None: self._on_error(ctx, error) else: diff --git a/tests/test_runner.py b/tests/test_runner.py index 7f7708e..bec41d0 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -29,23 +29,77 @@ def test_init_default(self): def test_init_with_callbacks(self): """Test BatchRunner initializes with callbacks.""" + on_start = MagicMock() on_item = MagicMock() on_complete = MagicMock() on_error = MagicMock() on_cancel = MagicMock() runner = BatchRunner( + on_start=on_start, on_item_complete=on_item, on_complete=on_complete, on_error=on_error, on_cancel=on_cancel, ) + assert runner._on_start is on_start assert runner._on_item_complete is on_item assert runner._on_complete is on_complete assert runner._on_error is on_error assert runner._on_cancel is on_cancel + def test_error_count_initial(self): + """Test error_count is 0 initially.""" + runner = BatchRunner() + assert runner.error_count == 0 + + def test_on_start_callback(self): + """Test on_start callback is called with total count.""" + on_start = MagicMock() + + def process(item): + return item + + runner = BatchRunner(on_start=on_start) + runner.run(process, [1, 2, 3, 4, 5], threaded=False) + + on_start.assert_called_once_with(5) + + def test_error_count_tracking(self): + """Test error_count tracks errors during batch.""" + + def process(item): + if item in [2, 4]: + raise ValueError(f'Error on {item}') + return item + + runner = BatchRunner() + runner.run(process, [1, 2, 3, 4, 5], threaded=False) + + assert runner.error_count == 2 + + def test_error_count_reset_on_new_run(self): + """Test error_count is reset at start of each run.""" + + def fail_some(item): + if item == 2: + raise ValueError('Error') + return item + + def succeed_all(item): + return item + + runner = BatchRunner() + + # First run with error + runner.run(fail_some, [1, 2, 3], threaded=False) + assert runner.error_count == 1 + + # Second run without errors - error_count should reset + runner.run(succeed_all, [1, 2, 3], threaded=False) + assert runner.error_count == 0 + class TestBatchRunnerSync: """Test BatchRunner synchronous execution.""" @@ -643,3 +697,33 @@ def on_complete(): assert sorted(results) == [2, 4, 11, 12] assert batch_count[0] == 2 + + def test_on_start_callback_threaded(self, qtbot): + """Test on_start callback is called with total count in threaded mode.""" + start_total = [] + + def process(item): + return item + + runner = BatchRunner(on_start=lambda total: start_total.append(total)) + + runner.run(process, [1, 2, 3, 4, 5], threaded=True) + + qtbot.waitUntil(lambda: not runner.is_running, timeout=5000) + + assert start_total == [5] + + def test_error_count_threaded(self, qtbot): + """Test error_count tracks errors correctly in threaded mode.""" + + def process(item): + if item in [2, 4]: + raise ValueError(f'Error on {item}') + return item + + runner = BatchRunner() + runner.run(process, [1, 2, 3, 4, 5], threaded=True) + + qtbot.waitUntil(lambda: not runner.is_running, timeout=5000) + + assert runner.error_count == 2