diff --git a/src/nbatch/_runner.py b/src/nbatch/_runner.py index 24df94a..995d8b6 100644 --- a/src/nbatch/_runner.py +++ b/src/nbatch/_runner.py @@ -23,15 +23,6 @@ # Module-level logger _logger = logging.getLogger('nbatch.runner') -# Check for napari availability -try: - from napari.qt.threading import create_worker - - HAS_NAPARI = True -except ImportError: - HAS_NAPARI = False - create_worker = None - class BatchRunner: """Orchestrates batch operations with threading, progress, and cancellation. @@ -163,10 +154,11 @@ def cancel(self) -> None: self._cancel_requested = True self._was_cancelled = True # Set immediately for threaded cases # Store local reference to avoid race condition with _handle_finished - worker = self._worker if HAS_NAPARI else None + # Check for quit method to determine if it's a napari worker + worker = self._worker - # If using napari worker, request quit - if worker is not None: + # If using napari worker (has quit method), request quit + if worker is not None and hasattr(worker, 'quit'): with contextlib.suppress(RuntimeError): worker.quit() @@ -235,14 +227,18 @@ def run( if self._on_start is not None: self._on_start(len(items_list)) - if threaded and HAS_NAPARI: - self._run_napari_threaded( - func, items_list, args, kwargs, log_file, log_header - ) - elif threaded: - self._run_thread_fallback( - func, items_list, args, kwargs, log_file, log_header - ) + # Try napari threading first, fall back to standard threading + if threaded: + try: + import napari.qt.threading # noqa: F401 + + self._run_napari_threaded( + func, items_list, args, kwargs, log_file, log_header + ) + except ImportError: + self._run_thread_fallback( + func, items_list, args, kwargs, log_file, log_header + ) else: self._run_sync( func, items_list, args, kwargs, log_file, log_header @@ -285,6 +281,7 @@ def _run_napari_threaded( log_header: Mapping[str, object] | None, ) -> None: """Run batch using napari's create_worker for Qt-safe threading.""" + from napari.qt.threading import create_worker def _worker_func(): """Generator function for napari worker.""" diff --git a/tests/test_runner.py b/tests/test_runner.py index 3b01fc0..017109b 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -489,10 +489,17 @@ class TestBatchRunnerThreaded: def test_run_threaded_fallback_no_napari(self, monkeypatch): """Test threaded execution using fallback (concurrent.futures).""" - # Force the fallback path by pretending napari isn't available - import nbatch._runner as runner_module + # Force the fallback path by making napari import fail + import builtins - monkeypatch.setattr(runner_module, 'HAS_NAPARI', False) + real_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == 'napari.qt.threading' or name.startswith('napari'): + raise ImportError('napari not available') + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, '__import__', mock_import) results = [] completed = [] @@ -506,7 +513,7 @@ def process(item): on_complete=lambda: completed.append(True), ) - # Run threaded (will use fallback since we monkeypatched HAS_NAPARI) + # Run threaded (will use fallback since napari import fails) runner.run(process, [1, 2, 3], threaded=True) # Wait for completion