Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 17 additions & 20 deletions src/nbatch/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@
# Module-level logger
_logger = logging.getLogger('nbatch.runner')

# Check for napari availability
try:
from napari.qt.threading import create_worker

HAS_NAPARI = True
except ImportError:
HAS_NAPARI = False
create_worker = None


class BatchRunner:
"""Orchestrates batch operations with threading, progress, and cancellation.
Expand Down Expand Up @@ -163,10 +154,11 @@ def cancel(self) -> None:
self._cancel_requested = True
self._was_cancelled = True # Set immediately for threaded cases
# Store local reference to avoid race condition with _handle_finished
worker = self._worker if HAS_NAPARI else None
# Check for quit method to determine if it's a napari worker
worker = self._worker

# If using napari worker, request quit
if worker is not None:
# If using napari worker (has quit method), request quit
if worker is not None and hasattr(worker, 'quit'):
with contextlib.suppress(RuntimeError):
worker.quit()

Expand Down Expand Up @@ -235,14 +227,18 @@ def run(
if self._on_start is not None:
self._on_start(len(items_list))

if threaded and HAS_NAPARI:
self._run_napari_threaded(
func, items_list, args, kwargs, log_file, log_header
)
elif threaded:
self._run_thread_fallback(
func, items_list, args, kwargs, log_file, log_header
)
# Try napari threading first, fall back to standard threading
if threaded:
try:
import napari.qt.threading # noqa: F401

self._run_napari_threaded(
func, items_list, args, kwargs, log_file, log_header
)
except ImportError:
self._run_thread_fallback(
func, items_list, args, kwargs, log_file, log_header
)
else:
self._run_sync(
func, items_list, args, kwargs, log_file, log_header
Expand Down Expand Up @@ -285,6 +281,7 @@ def _run_napari_threaded(
log_header: Mapping[str, object] | None,
) -> None:
"""Run batch using napari's create_worker for Qt-safe threading."""
from napari.qt.threading import create_worker

def _worker_func():
"""Generator function for napari worker."""
Expand Down
15 changes: 11 additions & 4 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,17 @@ class TestBatchRunnerThreaded:

def test_run_threaded_fallback_no_napari(self, monkeypatch):
"""Test threaded execution using fallback (concurrent.futures)."""
# Force the fallback path by pretending napari isn't available
import nbatch._runner as runner_module
# Force the fallback path by making napari import fail
import builtins

monkeypatch.setattr(runner_module, 'HAS_NAPARI', False)
real_import = builtins.__import__

def mock_import(name, *args, **kwargs):
if name == 'napari.qt.threading' or name.startswith('napari'):
raise ImportError('napari not available')
return real_import(name, *args, **kwargs)

monkeypatch.setattr(builtins, '__import__', mock_import)

results = []
completed = []
Expand All @@ -506,7 +513,7 @@ def process(item):
on_complete=lambda: completed.append(True),
)

# Run threaded (will use fallback since we monkeypatched HAS_NAPARI)
# Run threaded (will use fallback since napari import fails)
runner.run(process, [1, 2, 3], threaded=True)

# Wait for completion
Expand Down