Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class MyWidget:

# Create runner once - reusable for all batches
self.runner = BatchRunner(
on_start=self._on_batch_start,
on_item_complete=self._on_item_complete,
on_complete=self._on_batch_complete,
on_error=self._on_item_error,
Expand All @@ -213,6 +214,11 @@ class MyWidget:
self._run_button.clicked.connect(self.run_batch)
self._cancel_button.clicked.connect(self.runner.cancel)

def _on_batch_start(self, total):
"""Called when batch starts with total item count."""
self._progress_bar.setValue(0)
self._progress_bar.setMaximum(total)

def _on_item_complete(self, result, ctx):
"""Called after each item completes."""
self._progress_bar.setValue(ctx.index + 1)
Expand All @@ -221,7 +227,11 @@ class MyWidget:
self._viewer.add_image(result, name=f"Result {ctx.index}")

def _on_batch_complete(self):
self._progress_bar.label = "Complete!"
errors = self.runner.error_count
if errors > 0:
self._progress_bar.label = f"Done with {errors} errors"
else:
self._progress_bar.label = "Complete!"

def _on_item_error(self, ctx, exception):
self._progress_bar.label = f"Error on {ctx.item.name}"
Expand All @@ -231,7 +241,6 @@ class MyWidget:

def run_batch(self):
"""Triggered by 'Run' button - just one line!"""
self._progress_bar.max = len(self.files)
self.runner.run(
process_image,
self.files,
Expand Down Expand Up @@ -335,6 +344,7 @@ def batch_logger(
class BatchRunner:
def __init__(
self,
on_start: Callable[[int], None] | None = None,
on_item_complete: Callable[[Any, BatchContext], None] | None = None,
on_complete: Callable[[], None] | None = None,
on_error: Callable[[BatchContext, Exception], None] | None = None,
Expand All @@ -351,7 +361,7 @@ class BatchRunner:
log_header: Mapping[str, object] | None = None,
patterns: str | Sequence[str] = '*',
recursive: bool = False,
**kwargs,
**kwargs, # Passed to func!
) -> None: ...

def cancel(self) -> None: ...
Expand All @@ -361,6 +371,9 @@ class BatchRunner:

@property
def was_cancelled(self) -> bool: ...

@property
def error_count(self) -> int: ... # Errors in current/last batch
```

## Contributing
Expand Down
31 changes: 31 additions & 0 deletions src/nbatch/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class BatchRunner:

Parameters
----------
on_start : Callable[[int], None] | None, optional
Called when batch starts, receives total item count. Use to initialize
progress bars (e.g., ``on_start=lambda total: progress_bar.setMaximum(total)``).
on_item_complete : Callable[[Any, BatchContext], None] | None, optional
Called after each item completes successfully. Receives the result
and BatchContext. Use for progress bars and adding results to viewer.
Expand All @@ -60,12 +63,15 @@ class BatchRunner:
True if a batch is currently being processed.
was_cancelled : bool
True if the last batch was cancelled before completion.
error_count : int
Number of errors encountered in the current/last batch.

Examples
--------
Basic usage in a napari widget:

>>> runner = BatchRunner(
... on_start=lambda total: progress_bar.setMaximum(total),
... on_item_complete=lambda r, ctx: progress_bar.setValue(ctx.index + 1),
... on_complete=lambda: print("Done!"),
... )
Expand All @@ -91,15 +97,26 @@ class BatchRunner:
... log_file="output/batch.log",
... log_header={"Input": str(input_dir), "Files": len(files)},
... )

Passing additional arguments to the function (no partial needed!):

>>> runner.run(
... process_image,
... files,
... output_dir=output_path, # passed to process_image
... sigma=2.0, # passed to process_image
... )
"""

def __init__(
self,
on_start: Callable[[int], None] | None = None,
on_item_complete: Callable[[Any, BatchContext], None] | None = None,
on_complete: Callable[[], None] | None = None,
on_error: Callable[[BatchContext, Exception], None] | None = None,
on_cancel: Callable[[], None] | None = None,
):
self._on_start = on_start
self._on_item_complete = on_item_complete
self._on_complete = on_complete
self._on_error = on_error
Expand All @@ -110,6 +127,7 @@ def __init__(
self._cancel_requested = False
self._is_running = False
self._was_cancelled = False
self._error_count = 0
self._lock = threading.Lock()

# For logging within run
Expand All @@ -127,6 +145,12 @@ def was_cancelled(self) -> bool:
with self._lock:
return self._was_cancelled

@property
def error_count(self) -> int:
    """Number of errors recorded for the current or most recent batch."""
    # Read under the lock so a concurrent worker incrementing the
    # counter cannot interleave with this access.
    with self._lock:
        count = self._error_count
    return count

def cancel(self) -> None:
"""Request cancellation of the running batch.

Expand Down Expand Up @@ -202,10 +226,15 @@ def run(
self._is_running = True
self._cancel_requested = False
self._was_cancelled = False
self._error_count = 0

# Normalize items to a list
items_list = self._normalize_items(items, patterns, recursive)

# Call on_start callback with total count
if self._on_start is not None:
self._on_start(len(items_list))

if threaded and HAS_NAPARI:
self._run_napari_threaded(
func, items_list, args, kwargs, log_file, log_header
Expand Down Expand Up @@ -388,6 +417,8 @@ def _handle_yielded(
result, ctx, error = value

if error is not None:
with self._lock:
self._error_count += 1
if self._on_error is not None:
self._on_error(ctx, error)
else:
Expand Down
84 changes: 84 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,77 @@ def test_init_default(self):

def test_init_with_callbacks(self):
    """Constructor should store each supplied callback unchanged."""
    names = ('on_start', 'on_item_complete', 'on_complete', 'on_error', 'on_cancel')
    callbacks = {name: MagicMock() for name in names}

    runner = BatchRunner(**callbacks)

    # Each private slot must hold the exact object passed in (identity, not equality).
    assert runner._on_start is callbacks['on_start']
    assert runner._on_item_complete is callbacks['on_item_complete']
    assert runner._on_complete is callbacks['on_complete']
    assert runner._on_error is callbacks['on_error']
    assert runner._on_cancel is callbacks['on_cancel']

def test_error_count_initial(self):
    """A fresh runner reports zero errors before any batch has run."""
    assert BatchRunner().error_count == 0

def test_on_start_callback(self):
    """on_start must fire exactly once with the total item count."""
    on_start = MagicMock()
    runner = BatchRunner(on_start=on_start)

    items = [1, 2, 3, 4, 5]
    runner.run(lambda item: item, items, threaded=False)

    # Derive the expectation from the input so the two cannot drift apart.
    on_start.assert_called_once_with(len(items))
Comment on lines +57 to +67
Copy link

Copilot AI Dec 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The on_start callback is only tested in synchronous mode (threaded=False). For consistency with other callback tests (e.g., on_item_complete, on_complete, on_error in TestBatchRunnerNapariThreading), consider adding a test for on_start in threaded mode to ensure it works correctly in both execution paths.

Copilot uses AI. Check for mistakes.

def test_error_count_tracking(self):
    """error_count must equal the number of items that raised."""
    failing = {2, 4}

    def process(item):
        if item in failing:
            raise ValueError(f'Error on {item}')
        return item

    runner = BatchRunner()
    runner.run(process, [1, 2, 3, 4, 5], threaded=False)

    assert runner.error_count == len(failing)
Comment on lines +69 to +80
Copy link

Copilot AI Dec 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error_count property is only tested in synchronous mode (threaded=False). For consistency with other error handling tests (e.g., test_run_napari_threaded_with_errors in TestBatchRunnerNapariThreading), consider adding a test for error_count in threaded mode to ensure the thread-safe counter works correctly across concurrent execution paths.

Copilot uses AI. Check for mistakes.

def test_error_count_reset_on_new_run(self):
    """Starting a new batch must clear the error counter from the previous one."""
    runner = BatchRunner()

    def fail_some(item):
        if item == 2:
            raise ValueError('Error')
        return item

    # First batch: exactly one item raises.
    runner.run(fail_some, [1, 2, 3], threaded=False)
    assert runner.error_count == 1

    # Second batch succeeds everywhere, so the counter must drop back to zero.
    runner.run(lambda item: item, [1, 2, 3], threaded=False)
    assert runner.error_count == 0


class TestBatchRunnerSync:
"""Test BatchRunner synchronous execution."""
Expand Down Expand Up @@ -643,3 +697,33 @@ def on_complete():

assert sorted(results) == [2, 4, 11, 12]
assert batch_count[0] == 2

def test_on_start_callback_threaded(self, qtbot):
    """on_start must receive the total item count in threaded mode too."""
    totals = []
    runner = BatchRunner(on_start=totals.append)

    runner.run(lambda item: item, [1, 2, 3, 4, 5], threaded=True)

    # Block until the worker thread finishes before asserting.
    qtbot.waitUntil(lambda: not runner.is_running, timeout=5000)

    assert totals == [5]

def test_error_count_threaded(self, qtbot):
    """The thread-safe error counter must be accurate in threaded mode."""
    failing = (2, 4)

    def process(item):
        if item in failing:
            raise ValueError(f'Error on {item}')
        return item

    runner = BatchRunner()
    runner.run(process, [1, 2, 3, 4, 5], threaded=True)

    # Wait for the background batch to drain before reading the counter.
    qtbot.waitUntil(lambda: not runner.is_running, timeout=5000)

    assert runner.error_count == 2