diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 109a14f..c74071e 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -30,7 +30,7 @@ jobs: fail-fast: false matrix: platform: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.10", "3.11", "3.12", "3.13"] + python-version: ["3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v6 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 53a2c87..7c1c33f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,6 +7,10 @@ repos: - id: trailing-whitespace exclude: ^\.napari-hub/.* - id: check-yaml # checks for correct yaml syntax for github actions ex. + exclude: | + (?x)( + |^tests/resources/Workflow/workflows/.*\.yaml$ + ) - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.14.10 hooks: diff --git a/README.md b/README.md index 4f97452..db6da8a 100644 --- a/README.md +++ b/README.md @@ -9,23 +9,26 @@ [![npe2](https://img.shields.io/badge/plugin-npe2-blue?link=https://napari.org/stable/plugins/index.html)](https://napari.org/stable/plugins/index.html) [![Copier](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/copier-org/copier/master/img/badge/badge-grayscale-inverted-border-purple.json)](https://github.com/copier-org/copier) -reproducible processing workflows with napari +**Reproducible processing workflows with napari** ----------------------------------- +A re-implementation of [napari-workflows](https://github.com/haesleinhuepf/napari-workflows) with backwards compatibility. -This [napari] plugin was generated with [copier] using the [napari-plugin-template] (main). +--- - +## What is ndev-workflows? -## Installation +`ndev-workflows` is the workflow backend for napari image processing pipelines. 
It's a **drop-in replacement** for [napari-workflows](https://github.com/haesleinhuepf/napari-workflows) by Robert Haase, with these key improvements: + +- **Safe YAML loading** — Uses `yaml.safe_load()` (no arbitrary code execution) +- **Backwards compatible** — Automatically loads and migrates legacy napari-workflows files, and detects missing dependencies +- **Same API** — Most code works without changes +- **Future-ready** — Designed for upcoming npe2 workflow contributions (WIP), without relying on npe1, napari-time-slicer, and napari-tools-menu for interactivity + +--- -You can install `ndev-workflows` via [pip]: +## Installation ```bash pip install ndev-workflows @@ -37,41 +40,132 @@ If napari is not already installed, you can install `ndev-workflows` with napari pip install "ndev-workflows[all]" ``` +--- + +## Quick Start + +```python +from ndev_workflows import Workflow, save_workflow, load_workflow +from skimage.filters import gaussian + +# Create workflow +workflow = Workflow() +workflow.set("blurred", gaussian, "input_image", sigma=2.0) +workflow.set("input_image", my_image) + +# Execute +result = workflow.get("blurred") + +# Save +save_workflow("pipeline.yaml", workflow, name="My Pipeline") + +# Load and reuse +loaded = load_workflow("pipeline.yaml") +loaded.set("input_image", new_image) +result = loaded.get("blurred") +``` + +--- + +## YAML Format + +Saved workflows use a safe, human-readable format: + +```yaml +name: Nucleus Segmentation +description: Gaussian blur and thresholding +modified: '2025-12-22' + +inputs: + - raw_image + +outputs: + - labels + +tasks: + blurred: + function: skimage.filters.gaussian + params: + arg0: raw_image + sigma: 2.0 + + labels: + function: skimage.measure.label + params: + arg0: blurred +``` + +**Key features:** + +- No `!python/object` tags (safe to share) +- Functions imported by module path +- Params use `arg0`, `arg1`, etc. for positional args and keyword names for kwargs -To install latest development version: +**Legacy format**: Old napari-workflows YAML files are automatically detected and migrated when loaded. + +--- + +## Important Notes + +### Function Dependencies + +⚠️ Workflows **don't bundle functions** — they only store module paths. Recipients need the same packages installed. + +If loading fails with `WorkflowNotRunnableError`, install the missing package: ```bash -pip install git+https://github.com/ndev-kit/ndev-workflows.git +pip install scikit-image # for skimage functions +pip install napari-segment-blobs-and-things-with-membranes # for that plugin ``` +### Lazy Loading + +Inspect workflows without importing functions: + +```python +workflow = load_workflow("untrusted.yaml", lazy=True) +print(workflow.tasks) # Safe - doesn't execute +``` + +--- + +## Integration + +### Front-end plugins for interactive workflow building: +- [napari-assistant](https://github.com/haesleinhuepf/napari-assistant) +- [napari-workflow-optimizer](https://github.com/haesleinhuepf/napari-workflow-optimizer) +- [napari-workflow-inspector](https://github.com/haesleinhuepf/napari-workflow-inspector) + +### Works with processing plugins: + +- [napari-segment-blobs-and-things-with-membranes](https://www.napari-hub.org/plugins/napari-segment-blobs-and-things-with-membranes) +- [pyclesperanto](https://github.com/clesperanto/napari_pyclesperanto_assistant) +- And more! + +--- ## Contributing -Contributions are very welcome. Tests can be run with [tox], please ensure -the coverage at least stays the same before you submit a pull request. 
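+Set up a development environment with [uv](https://docs.astral.sh/uv/); on
+macOS/Linux, activate the environment with `source .venv/bin/activate`
+instead of the Windows command shown below:
+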
+```bash +git clone https://github.com/ndev-kit/ndev-workflows.git +cd ndev-workflows +uv venv +.venv\Scripts\activate +uv pip install -e . --group dev +pytest +``` + +--- ## License Distributed under the terms of the [BSD-3] license, "ndev-workflows" is free and open source software +Fork of [napari-workflows](https://github.com/haesleinhuepf/napari-workflows) by Robert Haase. -## Issues - -If you encounter any problems, please [file an issue] along with a detailed description. +--- -[napari]: https://github.com/napari/napari -[copier]: https://copier.readthedocs.io/en/stable/ -[MIT]: http://opensource.org/licenses/MIT -[BSD-3]: http://opensource.org/licenses/BSD-3-Clause -[GNU GPL v3.0]: http://www.gnu.org/licenses/gpl-3.0.txt -[GNU LGPL v3.0]: http://www.gnu.org/licenses/lgpl-3.0.txt -[Apache Software License 2.0]: http://www.apache.org/licenses/LICENSE-2.0 -[Mozilla Public License 2.0]: https://www.mozilla.org/media/MPL/2.0/index.txt -[napari-plugin-template]: https://github.com/napari/napari-plugin-template - -[file an issue]: https://github.com/ndev-kit/ndev-workflows/issues +## Issues -[tox]: https://tox.readthedocs.io/en/latest/ -[pip]: https://pypi.org/project/pip/ -[PyPI]: https://pypi.org/ +[File an issue](https://github.com/ndev-kit/ndev-workflows/issues) with your environment details, YAML file (if applicable), and error messages. diff --git a/pyproject.toml b/pyproject.toml index 34f585f..019a388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,19 +17,25 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Image Processing", ] -requires-python = ">=3.10" +requires-python = ">=3.11" # ndevio requires 3.11+ # napari can be included in dependencies if napari imports are required. # However, you should not include napari[all], napari[qt], # or any other Qt bindings directly (e.g. PyQt5, PySide2). # See best practices: https://napari.org/stable/plugins/building_a_plugin/best_practices.html dependencies = [ + "napari", + "nbatch>=0.0.4", + "ndevio>=0.6.0", + "magicgui", + "magic-class", "numpy", + "dask", + "pyyaml", ] [project.optional-dependencies] @@ -42,6 +48,9 @@ dev = [ "tox-uv", "pytest", # https://docs.pytest.org/en/latest/contents.html "pytest-cov", # https://pytest-cov.readthedocs.io/en/latest/ + "pytest-qt", + "napari[pyqt6]", + "napari-segment-blobs-and-things-with-membranes", # TODO: adds 76 transient dependencies, yuck. currently for legacy/current sample test workflows. will try to remove in future ] [project.entry-points."napari.manifest"] diff --git a/src/ndev_workflows/__init__.py b/src/ndev_workflows/__init__.py index d4f2bcf..c5a0cb9 100644 --- a/src/ndev_workflows/__init__.py +++ b/src/ndev_workflows/__init__.py @@ -1,7 +1,35 @@ +"""ndev-workflows: Reproducible processing workflows with napari. + +This package provides workflow management and batch processing for napari. 
+It is a fork of napari-workflows by Robert Haase (BSD-3-Clause license), +enhanced with: +- Safe YAML loading (no arbitrary code execution) +- Human-readable workflow format +- Integration with ndev-settings and nbatch +- npe2-native plugin architecture + +Example +------- +>>> from ndev_workflows import Workflow, save_workflow, load_workflow +>>> w = Workflow() +>>> w.set("blurred", gaussian, "input", sigma=2.0) +>>> save_workflow("my_workflow.yaml", w, name="My Pipeline") +>>> +>>> loaded = load_workflow("my_workflow.yaml") +>>> loaded.set("input", image_data) +>>> result = loaded.get("blurred") +""" + try: from ._version import version as __version__ except ImportError: __version__ = 'unknown' +from ._io import load_workflow, save_workflow +from ._workflow import Workflow -__all__ = () +__all__ = [ + 'Workflow', + 'load_workflow', + 'save_workflow', +] diff --git a/src/ndev_workflows/_batch.py b/src/ndev_workflows/_batch.py new file mode 100644 index 0000000..6e6a9b3 --- /dev/null +++ b/src/ndev_workflows/_batch.py @@ -0,0 +1,105 @@ +"""Batch-processing helpers for ndev-workflows.""" + +from __future__ import annotations + +from pathlib import Path + +from nbatch import batch + + +@batch(on_error='continue') +def process_workflow_file( + image_file: Path, + result_dir: Path, + workflow_file: Path, + root_index_list: list[int], + task_names: list[str], + keep_original_images: bool, + root_list: list[str], + squeezed_img_dims: str, +) -> Path: + """Process a single image file through a workflow. + + Loads a fresh workflow instance per file for thread safety. + + Parameters + ---------- + image_file : Path + Path to the image file to process. + result_dir : Path + Directory to save results. + workflow_file : Path + Path to the workflow YAML file. + root_index_list : list[int] + Indices of channels to use as workflow roots. + task_names : list[str] + Names of workflow tasks to execute. + keep_original_images : bool + Whether to concatenate original images with results. + root_list : list[str] + Names of root channels (for output naming). + squeezed_img_dims : str + Squeezed dimension order of the image. + + Returns + ------- + Path + Path to the saved output file. 
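+
+    Example
+    -------
+    A hypothetical single-file invocation; the paths, channel index, and
+    task names are illustrative, and how ``nbatch`` drives this function in
+    batch mode is not shown here:
+
+    >>> process_workflow_file(
+    ...     image_file=Path('raw/img_001.tif'),
+    ...     result_dir=Path('results'),
+    ...     workflow_file=Path('pipeline.yaml'),
+    ...     root_index_list=[0],
+    ...     task_names=['labels'],
+    ...     keep_original_images=False,
+    ...     root_list=['raw_image'],
+    ...     squeezed_img_dims='ZYX',
+    ... )  # doctest: +SKIP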
+ """ + import dask.array as da + import numpy as np + from bioio.writers import OmeTiffWriter + from bioio_base import transforms + from ndevio import nImage + + from ._io import load_workflow + from ._spec import ensure_runnable + + workflow = load_workflow(workflow_file, lazy=True) + workflow = ensure_runnable(workflow) + + img = nImage(image_file) + + # Capture roots before modifying workflow (stable list of graph inputs) + root_names = workflow.roots() + + root_stack = [] + for idx, root_index in enumerate(root_index_list): + if 'S' in img.dims.order: + root_img = img.get_image_data('TSZYX', S=root_index) + else: + root_img = img.get_image_data('TCZYX', C=root_index) + + root_stack.append(root_img) + workflow.set(name=root_names[idx], func_or_data=np.squeeze(root_img)) + + result = workflow.get(name=task_names) + + result_stack = np.asarray(result) + result_stack = transforms.reshape_data( + data=result_stack, + given_dims='C' + squeezed_img_dims, + return_dims='TCZYX', + ) + + if result_stack.dtype == np.int64: + result_stack = result_stack.astype(np.int32) + + if keep_original_images: + dask_images = da.concatenate(root_stack, axis=1) # along "C" + result_stack = da.concatenate([dask_images, result_stack], axis=1) + result_names = root_list + task_names + else: + result_names = task_names + + output_path = result_dir / (image_file.stem + '.tiff') + OmeTiffWriter.save( + data=result_stack, + uri=output_path, + dim_order='TCZYX', + channel_names=result_names, + image_name=image_file.stem, + physical_pixel_sizes=img.physical_pixel_sizes, + ) + + return output_path diff --git a/src/ndev_workflows/_io.py b/src/ndev_workflows/_io.py new file mode 100644 index 0000000..a551357 --- /dev/null +++ b/src/ndev_workflows/_io.py @@ -0,0 +1,174 @@ +"""YAML-based workflow persistence. + +This module provides functions for saving and loading workflows in a +human-readable YAML format that is safe to load (no arbitrary code execution). + +For loading legacy napari-workflows files, see `_io_legacy.py`. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import yaml + +from ._io_legacy import is_legacy_format +from ._spec import ( + ensure_runnable, + legacy_yaml_to_spec_dict, + spec_dict_to_workflow, + workflow_to_spec_dict, +) +from ._workflow import Workflow, WorkflowNotRunnableError + +if TYPE_CHECKING: + pass + + +class WorkflowYAMLError(Exception): + """Error during workflow YAML serialization/deserialization.""" + + +def save_workflow( + filename: str | Path, + workflow: Workflow, + *, + name: str | None = None, + description: str | None = None, +) -> None: + """Save a workflow to a YAML file. + + Parameters + ---------- + filename : str or Path + Path to save the workflow. + workflow : Workflow + The workflow to save. + name : str, optional + Human-readable name for the workflow. + description : str, optional + Description of what the workflow does. + + Example + ------- + >>> from ndev_workflows import Workflow, save_workflow + >>> workflow = Workflow() + >>> workflow.set("blurred", gaussian, "input", sigma=2.0) + >>> save_workflow("my_workflow.yaml", workflow, name="Blur Pipeline") + """ + data = workflow_to_spec_dict(workflow, name=name, description=description) + + with open(filename, 'w') as f: + yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False) + + +def load_workflow(filename: str | Path, *, lazy: bool = False) -> Workflow: + """Load a workflow from a YAML file. 
+ + Automatically detects legacy napari-workflows format and loads appropriately. + + Parameters + ---------- + filename : str or Path + Path to the YAML file. + lazy : bool, optional + If True, don't import functions (use CallableRef placeholders). + Default is False (import functions). + + Returns + ------- + Workflow + The loaded workflow. + + Raises + ------ + WorkflowYAMLError + If loading fails or functions cannot be imported (when lazy=False). + + Example + ------- + >>> from ndev_workflows import load_workflow + >>> workflow = load_workflow("my_workflow.yaml") + >>> workflow.set("input", image_data) + >>> result = workflow.get("output") + """ + # Always normalize to a spec dict first, then apply the same + # lazy/eager workflow construction logic. + if is_legacy_format(filename): + spec = legacy_yaml_to_spec_dict(filename, include_modified=False) + is_legacy = True + else: + try: + with open(filename) as f: + spec = yaml.safe_load(f) + except yaml.YAMLError as e: + raise WorkflowYAMLError(f'Failed to parse YAML: {e}') from e + is_legacy = False + + workflow = spec_dict_to_workflow(spec, lazy=True) + workflow.metadata['legacy'] = is_legacy + if lazy: + return workflow + + try: + return ensure_runnable(workflow) + except WorkflowNotRunnableError as e: + raise WorkflowYAMLError(str(e)) from e + + +def migrate_legacy( + input_file: str | Path, + output_file: str | Path | None = None, + *, + name: str | None = None, +) -> Workflow: + """Migrate a legacy napari-workflows file to the new format. + + Parameters + ---------- + input_file : str or Path + Path to the legacy YAML file. + output_file : str or Path, optional + Path for the output file. If None, appends '_migrated' to the name. + name : str, optional + Name for the migrated workflow. + + Returns + ------- + Workflow + The migrated workflow. + + Example + ------- + >>> workflow = migrate_legacy("old_workflow.yaml", "new_workflow.yaml") + """ + input_path = Path(input_file) + + if output_file is None: + output_file = input_path.with_stem(input_path.stem + '_migrated') + + migrated_name = name or f'Migrated: {input_path.stem}' + + # Convert legacy YAML to new spec dict lazily, then save the spec. + spec = legacy_yaml_to_spec_dict( + input_file, + name=migrated_name, + include_modified=True, + ) + + with open(output_file, 'w') as f: + yaml.safe_dump(spec, f, default_flow_style=False, sort_keys=False) + + # Return a normalized lazy workflow (new-format in-memory representation). + return spec_dict_to_workflow(spec, lazy=True) + + +# Re-export commonly used items +__all__ = [ + 'WorkflowYAMLError', + 'save_workflow', + 'load_workflow', + 'migrate_legacy', + 'is_legacy_format', +] diff --git a/src/ndev_workflows/_io_legacy.py b/src/ndev_workflows/_io_legacy.py new file mode 100644 index 0000000..a11d144 --- /dev/null +++ b/src/ndev_workflows/_io_legacy.py @@ -0,0 +1,133 @@ +"""Legacy format (napari-workflows) persistence. + +This module handles loading workflows saved with the original napari-workflows +package. These files use Python pickle-style YAML tags that require unsafe loading. + +For new workflows, use the functions in `_io.py` which use a plain YAML format. +""" + +from __future__ import annotations + +from functools import partial +from pathlib import Path + +import yaml + +from ._workflow import CallableRef, Workflow + + +def is_legacy_format(filename: str | Path) -> bool: + """Check if a workflow file is in legacy napari-workflows format. + + Parameters + ---------- + filename : str or Path + Path to the YAML file. 
+ + Returns + ------- + bool + True if the file uses legacy !!python/object format. + """ + with open(filename, encoding='utf-8', errors='replace') as f: + first_line = f.readline() + return '!!python/object:napari_workflows' in first_line + + +def load_legacy_lazy(filename: str | Path) -> Workflow: + """Load a legacy workflow without importing function modules. + + This is useful for inspecting or migrating workflows when the + original function modules are not installed. + + Parameters + ---------- + filename : str or Path + Path to the YAML file. + + Returns + ------- + Workflow + The loaded workflow with CallableRef placeholders. + + Notes + ----- + The returned workflow cannot be executed (functions are placeholders), + but it can be inspected, migrated, or its structure can be examined. + """ + + class LazyLoader(yaml.SafeLoader): + pass + + def construct_python_tuple(loader, node): + return tuple(loader.construct_sequence(node)) + + def construct_python_name(loader, suffix, node): + """Return a CallableRef instead of importing.""" + # suffix is like 'skimage.filters.gaussian' + parts = suffix.rsplit('.', 1) + if len(parts) == 2: + module, name = parts + else: + module = '' + name = suffix + return CallableRef(module, name) + + def construct_legacy_workflow(loader, node): + mapping = loader.construct_mapping(node, deep=True) + workflow = Workflow() + workflow._tasks = mapping.get('_tasks', {}) + return workflow + + def construct_functools_partial(loader, node): + """Construct a CallableRef from legacy functools.partial tags. + + Legacy napari-workflows YAML may encode keyword arguments as + ``!!python/object/apply:functools.partial``. + + We keep this loader *lazy* by returning a CallableRef and storing + the keyword arguments on it. + """ + seq = loader.construct_sequence(node, deep=True) + if not seq: + return None + + func = seq[0] + + # Common encodings: + # [func, args_tuple_or_list, kwargs_dict] + # [func, kwargs_dict] + kwargs = {} + if len(seq) >= 3 and isinstance(seq[2], dict): + kwargs = seq[2] + elif len(seq) == 2 and isinstance(seq[1], dict): + kwargs = seq[1] + + if isinstance(func, CallableRef): + func.kwargs = dict(kwargs) + return func + + # Fallback: if we somehow got a real callable, keep it callable. + # This is still safe because SafeLoader will not construct arbitrary + # callables unless we registered constructors for them. + if kwargs and callable(func): + return partial(func, **kwargs) + return func + + LazyLoader.add_constructor( + 'tag:yaml.org,2002:python/tuple', construct_python_tuple + ) + LazyLoader.add_multi_constructor( + 'tag:yaml.org,2002:python/name:', construct_python_name + ) + LazyLoader.add_constructor( + 'tag:yaml.org,2002:python/object:napari_workflows._workflow.Workflow', + construct_legacy_workflow, + ) + LazyLoader.add_constructor( + 'tag:yaml.org,2002:python/object/apply:functools.partial', + construct_functools_partial, + ) + + with open(filename, 'rb') as stream: + return yaml.load(stream, Loader=LazyLoader) diff --git a/src/ndev_workflows/_manager.py b/src/ndev_workflows/_manager.py new file mode 100644 index 0000000..1f2c4a3 --- /dev/null +++ b/src/ndev_workflows/_manager.py @@ -0,0 +1,409 @@ +"""Workflow manager for napari viewer integration. + +This module is derived from napari-workflows by Robert Haase (BSD-3-Clause). +See NOTICE file for attribution. + +The WorkflowManager provides a singleton pattern for managing workflows +per napari viewer, with automatic layer updates and undo/redo support. 
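+
+Example
+-------
+A minimal sketch, assuming an existing napari ``viewer`` with a layer named
+``'raw'`` and that ``skimage.filters.gaussian`` is importable:
+
+>>> from skimage.filters import gaussian
+>>> from ndev_workflows._manager import WorkflowManager
+>>> manager = WorkflowManager.install(viewer)               # doctest: +SKIP
+>>> manager.update('blurred', gaussian, 'raw', sigma=2.0)   # doctest: +SKIP
+>>> manager.undo()                                          # doctest: +SKIP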
+""" + +from __future__ import annotations + +import threading +import time +import warnings +from collections.abc import Callable +from typing import TYPE_CHECKING, Any +from weakref import WeakValueDictionary + +from ._undo_redo import UndoRedoController +from ._workflow import Workflow + +if TYPE_CHECKING: + from napari import Viewer + from napari.layers import Layer + + +# Global registry of WorkflowManager instances per viewer +_managers: WeakValueDictionary[int, WorkflowManager] = WeakValueDictionary() + + +class WorkflowManager: + """Manages a workflow attached to a napari viewer. + + The WorkflowManager provides: + - Singleton pattern (one manager per viewer) + - Automatic layer updates when sources change + - Undo/redo functionality + - Background worker for non-blocking updates + - Code generation from workflow + + Parameters + ---------- + viewer : napari.Viewer + The napari viewer to manage. + + Notes + ----- + Use ``WorkflowManager.install(viewer)`` to get or create a manager + for a viewer. Do not instantiate directly. + """ + + def __init__(self, viewer: Viewer) -> None: + """Initialize the WorkflowManager. + + Parameters + ---------- + viewer : napari.Viewer + The napari viewer to manage. + """ + self._viewer = viewer + self._workflow = Workflow() + self._undo_redo = UndoRedoController(self._workflow) + + # Background worker for auto-updates + self._update_requested = threading.Event() + self._stop_worker = threading.Event() + self._pending_updates: list[str] = [] + self._worker_thread: threading.Thread | None = None + + # Auto-update settings + self._auto_update_enabled = True + self._update_delay = 0.1 # seconds + + # Start background worker + self._start_worker() + + @classmethod + def install(cls, viewer: Viewer) -> WorkflowManager: + """Get or create a WorkflowManager for a viewer. + + Parameters + ---------- + viewer : napari.Viewer + The napari viewer. + + Returns + ------- + WorkflowManager + The workflow manager for this viewer. + """ + viewer_id = id(viewer) + if viewer_id not in _managers: + manager = cls(viewer) + _managers[viewer_id] = manager + return _managers[viewer_id] + + @property + def workflow(self) -> Workflow: + """The workflow being managed.""" + return self._workflow + + @property + def viewer(self) -> Viewer: + """The napari viewer being managed.""" + return self._viewer + + @property + def undo_redo(self) -> UndoRedoController: + """The undo/redo controller.""" + return self._undo_redo + + def update( + self, + target_layer: str | Layer, + function: Callable, + *args: Any, + **kwargs: Any, + ) -> None: + """Update or add a workflow step. + + Parameters + ---------- + target_layer : str or Layer + The target layer name or layer object. + function : Callable + The function to apply. + *args : Any + Arguments for the function (can include layer names). + **kwargs : Any + Keyword arguments for the function. + + Notes + ----- + This saves the current state for undo, updates the workflow, + and schedules dependent layers for re-execution. 
+ """ + # Save state for undo + self._undo_redo.save_state() + + # Get target name + target_name = ( + target_layer + if isinstance(target_layer, str) + else target_layer.name + ) + + # Convert layer objects to names in args + processed_args = [] + for arg in args: + if hasattr(arg, 'name'): + processed_args.append(arg.name) + else: + processed_args.append(arg) + + # Update workflow + self._workflow.set(target_name, function, *processed_args, **kwargs) + + # Schedule update of this and dependent layers + self._schedule_update(target_name) + + def _schedule_update(self, name: str) -> None: + """Schedule a layer update in the background. + + Parameters + ---------- + name : str + The task name to update. + """ + if not self._auto_update_enabled: + return + + # Add to pending updates + if name not in self._pending_updates: + self._pending_updates.append(name) + + # Also schedule followers + for follower in self._workflow.followers_of(name): + if follower not in self._pending_updates: + self._pending_updates.append(follower) + + # Signal worker + self._update_requested.set() + + def _start_worker(self) -> None: + """Start the background update worker.""" + if self._worker_thread is not None: + return + + self._stop_worker.clear() + self._worker_thread = threading.Thread( + target=self._worker_loop, + daemon=True, + name='WorkflowManager-worker', + ) + self._worker_thread.start() + + def _worker_loop(self) -> None: + """Background worker loop for processing updates.""" + while not self._stop_worker.is_set(): + # Wait for update request + self._update_requested.wait(timeout=0.5) + if self._stop_worker.is_set(): + break + + # Small delay to batch updates + time.sleep(self._update_delay) + + # Process pending updates + while self._pending_updates: + name = self._pending_updates.pop(0) + try: + self._execute_update(name) + except (ValueError, TypeError, RuntimeError, KeyError) as e: + warnings.warn( + f"Workflow update failed for '{name}': {e}", + stacklevel=2, + ) + + self._update_requested.clear() + + def _execute_update(self, name: str) -> None: + """Execute a workflow update and refresh the layer. + + Parameters + ---------- + name : str + The task name to execute. + """ + # Check if this is a processing step (not raw data) + if self._workflow.is_data_task(name): + return + + # Execute the workflow step + try: + result = self._workflow.get(name) + except (ValueError, TypeError, RuntimeError, KeyError) as e: + warnings.warn( + f"Failed to compute '{name}': {e}", + stacklevel=2, + ) + return + + # Update the layer if it exists + try: + layer = self._viewer.layers[name] + layer.data = result + except KeyError: + # Layer doesn't exist, could add it + pass + except (ValueError, TypeError, RuntimeError) as e: + warnings.warn( + f"Failed to update layer '{name}': {e}", + stacklevel=2, + ) + + def stop(self) -> None: + """Stop the background worker.""" + self._stop_worker.set() + self._update_requested.set() # Wake up worker + if self._worker_thread is not None: + self._worker_thread.join(timeout=1.0) + self._worker_thread = None + + def invalidate(self, name: str) -> None: + """Invalidate a task and its followers. + + Parameters + ---------- + name : str + The task name to invalidate. + + Notes + ----- + This schedules the task and all dependent tasks for re-execution. 
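+
+        Example
+        -------
+        A hypothetical call that re-computes ``'blurred'`` and everything
+        downstream of it:
+
+        >>> manager.invalidate('blurred')  # doctest: +SKIP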
+ """ + self._schedule_update(name) + + def undo(self) -> None: + """Undo the last workflow change.""" + self._undo_redo.undo() + # Refresh all layers + for name in self._workflow: + if not self._workflow.is_data_task(name): + self._schedule_update(name) + + def redo(self) -> None: + """Redo the last undone change.""" + self._undo_redo.redo() + # Refresh all layers + for name in self._workflow: + if not self._workflow.is_data_task(name): + self._schedule_update(name) + + def clear(self) -> None: + """Clear the workflow.""" + self._undo_redo.save_state() + self._workflow.clear() + + def to_python_code( + self, + notebook: bool = False, + use_napari: bool = True, + ) -> str: + """Generate Python code from the workflow. + + Parameters + ---------- + notebook : bool, optional + If True, format as Jupyter notebook cells. Default False. + use_napari : bool, optional + If True, include napari viewer code. Default True. + + Returns + ------- + str + Python code that reproduces the workflow. + """ + lines = [] + + # Collect imports + imports = set() + for name in self._workflow: + func = self._workflow.get_function(name) + if func is not None: + # Handle partial functions + if hasattr(func, 'func'): + func = func.func + if hasattr(func, '__module__') and hasattr(func, '__name__'): + imports.add( + f'from {func.__module__} import {func.__name__}' + ) + + # Add imports + if imports: + lines.extend(sorted(imports)) + lines.append('') + + if use_napari: + lines.append('import napari') + lines.append('viewer = napari.Viewer()') + lines.append('') + + # Generate code for each task in dependency order + executed = set() + + def generate_task(name: str) -> None: + if name in executed: + return + + # First generate dependencies + for source in self._workflow.sources_of(name): + generate_task(source) + + task = self._workflow.get_task(name) + if task is None: + return + + if self._workflow.is_data_task(name): + # Data task - placeholder + lines.append(f'# {name} = ') + else: + # Processing task + func = task[0] + args = task[1:] + + # Get function name + if hasattr(func, 'func'): + func_name = func.func.__name__ + # Include kwargs from partial + if hasattr(func, 'keywords') and func.keywords: + kwargs_str = ', '.join( + f'{k}={repr(v)}' for k, v in func.keywords.items() + ) + args_str = ', '.join(str(a) for a in args) + if args_str: + call = f'{func_name}({args_str}, {kwargs_str})' + else: + call = f'{func_name}({kwargs_str})' + else: + args_str = ', '.join(str(a) for a in args) + call = f'{func_name}({args_str})' + else: + func_name = getattr(func, '__name__', 'unknown_function') + args_str = ', '.join(str(a) for a in args) + call = f'{func_name}({args_str})' + + lines.append(f'{name} = {call}') + + executed.add(name) + + for name in self._workflow: + generate_task(name) + + if use_napari: + lines.append('') + lines.append('napari.run()') + + code = '\n'.join(lines) + + if notebook: + # Split into cells at blank lines + # This is a simple implementation; could be enhanced + code = code.replace('\n\n', '\n# %%\n') + + return code + + def __del__(self) -> None: + """Cleanup when manager is deleted.""" + self.stop() diff --git a/src/ndev_workflows/_spec.py b/src/ndev_workflows/_spec.py new file mode 100644 index 0000000..6043dfb --- /dev/null +++ b/src/ndev_workflows/_spec.py @@ -0,0 +1,176 @@ +"""Workflow <-> YAML spec conversion. + +This module is *purely* about translating between: +- in-memory :class:`ndev_workflows.Workflow` graphs, and +- the new, safe, human-readable YAML "spec" dict. 
+ +It intentionally does not read/write files. Disk I/O lives in `_io.py`. +Legacy YAML parsing lives in `_io_legacy.py`. +""" + +from __future__ import annotations + +import importlib +from datetime import datetime +from functools import partial +from pathlib import Path + +from ._io_legacy import load_legacy_lazy +from ._workflow import CallableRef, Workflow + + +def workflow_to_spec_dict( + workflow: Workflow, + *, + name: str | None = None, + description: str | None = None, + include_modified: bool = True, +) -> dict: + """Convert a workflow to the new YAML spec dict.""" + spec: dict = {} + if name: + spec['name'] = name + if description: + spec['description'] = description + if include_modified: + spec['modified'] = datetime.now().date().isoformat() + + tasks: dict[str, dict] = {} + saved_task_names: set[str] = set() + + for task_name, task in workflow.tasks.items(): + # Skip data tasks (not tuples) and empty tuples. + if not isinstance(task, tuple) or len(task) == 0: + continue + + func = task[0] + args = task[1:] + + if isinstance(func, CallableRef): + func_path = f'{func.module}.{func.name}' + kwargs = getattr(func, 'kwargs', {}) + elif isinstance(func, partial): + func_path = f'{func.func.__module__}.{func.func.__name__}' + kwargs = dict(func.keywords) if func.keywords else {} + elif callable(func): + func_path = f'{func.__module__}.{func.__name__}' + kwargs = {} + else: + # Unknown task encoding + continue + + saved_task_names.add(task_name) + + params: dict[str, object] = { + f'arg{i}': arg for i, arg in enumerate(args) + } + params.update(kwargs) + + tasks[task_name] = { + 'function': func_path, + 'params': params, + } + + # Inputs: referenced names that aren't saved as tasks. + all_referenced: set[str] = set() + for task_data in tasks.values(): + for param_name, param_value in task_data['params'].items(): + if isinstance(param_value, str) and param_name.startswith('arg'): + all_referenced.add(param_value) + + inputs = [n for n in all_referenced if n not in saved_task_names] + outputs = [n for n in saved_task_names if n not in all_referenced] + + spec['inputs'] = inputs + spec['outputs'] = outputs + spec['tasks'] = tasks + + return spec + + +def spec_dict_to_workflow(spec: dict, *, lazy: bool = False) -> Workflow: + """Convert a new-format YAML spec dict to a Workflow object.""" + workflow = Workflow() + workflow.metadata = { + 'name': spec.get('name'), + 'description': spec.get('description'), + 'modified': spec.get('modified'), + 'inputs': spec.get('inputs', []), + 'outputs': spec.get('outputs', []), + } + tasks = spec.get('tasks', {}) + + for task_name, task_data in tasks.items(): + func_path = task_data['function'] + params = task_data.get('params', {}) + + module_path, _, func_name = func_path.rpartition('.') + + if lazy: + func = CallableRef(module_path, func_name) + else: + try: + module = importlib.import_module(module_path) + func = getattr(module, func_name) + except (ImportError, AttributeError) as e: + raise ImportError( + f"Cannot import function '{func_name}' from '{module_path}': {e}" + ) from e + + # Extract args and kwargs + args: list[object] = [] + kwargs: dict[str, object] = {} + for param_name, param_value in params.items(): + if param_name.startswith('arg') and param_name[3:].isdigit(): + idx = int(param_name[3:]) + while len(args) <= idx: + args.append(None) + args[idx] = param_value + else: + kwargs[param_name] = param_value + + # Apply kwargs + if kwargs and not lazy: + func = partial(func, **kwargs) + elif kwargs and lazy: + func.kwargs = kwargs + + 
workflow._tasks[task_name] = (func, *args) + + return workflow + + +def ensure_runnable( + workflow_or_spec: Workflow | dict, +) -> Workflow: + """Ensure a workflow is runnable. + + Accepts either a Workflow (possibly loaded with ``lazy=True``) or a + new-format spec dict. + """ + if isinstance(workflow_or_spec, dict): + workflow = spec_dict_to_workflow(workflow_or_spec, lazy=True) + else: + workflow = workflow_or_spec + return workflow.ensure_runnable() + + +def legacy_yaml_to_spec_dict( + filename: str | Path, + *, + name: str | None = None, + description: str | None = None, + include_modified: bool = False, +) -> dict: + """Load a legacy napari-workflows YAML and convert it to the new spec dict. + + This function is intentionally *lazy*: it never imports referenced + functions. + """ + legacy_workflow = load_legacy_lazy(filename) + return workflow_to_spec_dict( + legacy_workflow, + name=name, + description=description, + include_modified=include_modified, + ) diff --git a/src/ndev_workflows/_undo_redo.py b/src/ndev_workflows/_undo_redo.py new file mode 100644 index 0000000..a4ba3c2 --- /dev/null +++ b/src/ndev_workflows/_undo_redo.py @@ -0,0 +1,179 @@ +"""Undo/redo functionality for workflows. + +This module is derived from napari-workflows by Robert Haase (BSD-3-Clause). +See NOTICE file for attribution. + +Provides an UndoRedoController that tracks workflow states and allows +undo/redo operations. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._workflow import Workflow + + +class UndoRedoController: + """Controller for undo/redo operations on workflows. + + The controller maintains two stacks: + - undo_stack: Previous states that can be restored + - redo_stack: States undone that can be re-applied + + Parameters + ---------- + workflow : Workflow + The workflow to track. + max_history : int, optional + Maximum number of states to keep in history. Default 50. + + Examples + -------- + >>> controller = UndoRedoController(workflow) + >>> controller.save_state() # Before making changes + >>> # ... make changes to workflow ... + >>> controller.undo() # Restore previous state + >>> controller.redo() # Re-apply the undone changes + """ + + def __init__(self, workflow: Workflow, max_history: int = 50) -> None: + """Initialize the undo/redo controller. + + Parameters + ---------- + workflow : Workflow + The workflow to track. + max_history : int, optional + Maximum number of states to keep in history. Default 50. + """ + self._workflow = workflow + self._max_history = max_history + self._undo_stack: list[dict] = [] + self._redo_stack: list[dict] = [] + + @property + def can_undo(self) -> bool: + """Whether undo is available.""" + return len(self._undo_stack) > 0 + + @property + def can_redo(self) -> bool: + """Whether redo is available.""" + return len(self._redo_stack) > 0 + + @property + def undo_stack_size(self) -> int: + """Number of states in undo stack.""" + return len(self._undo_stack) + + @property + def redo_stack_size(self) -> int: + """Number of states in redo stack.""" + return len(self._redo_stack) + + def save_state(self) -> None: + """Save the current workflow state to the undo stack. + + Call this before making changes to the workflow that you want + to be undoable. + + Notes + ----- + Saving a new state clears the redo stack. 
+ """ + from copy import deepcopy + + # Deep copy the tasks dict + state = deepcopy(self._workflow._tasks) + self._undo_stack.append(state) + + # Trim history if needed + while len(self._undo_stack) > self._max_history: + self._undo_stack.pop(0) + + # Clear redo stack when new state is saved + self._redo_stack.clear() + + def undo(self) -> bool: + """Restore the previous workflow state. + + Returns + ------- + bool + True if undo was successful, False if stack was empty. + """ + if not self.can_undo: + return False + + from copy import deepcopy + + # Save current state to redo stack + current_state = deepcopy(self._workflow._tasks) + self._redo_stack.append(current_state) + + # Restore previous state + previous_state = self._undo_stack.pop() + self._workflow._tasks = previous_state + + return True + + def redo(self) -> bool: + """Re-apply a previously undone change. + + Returns + ------- + bool + True if redo was successful, False if stack was empty. + """ + if not self.can_redo: + return False + + from copy import deepcopy + + # Save current state to undo stack + current_state = deepcopy(self._workflow._tasks) + self._undo_stack.append(current_state) + + # Restore redo state + redo_state = self._redo_stack.pop() + self._workflow._tasks = redo_state + + return True + + def clear_history(self) -> None: + """Clear all undo/redo history.""" + self._undo_stack.clear() + self._redo_stack.clear() + + def get_workflow_copy(self) -> Workflow: + """Get a copy of the current workflow. + + Returns + ------- + Workflow + A deep copy of the tracked workflow. + """ + return self._workflow.copy() + + +def copy_workflow_state(workflow: Workflow) -> Workflow: + """Create a copy of a workflow. + + Parameters + ---------- + workflow : Workflow + The workflow to copy. + + Returns + ------- + Workflow + A deep copy of the workflow. + + Notes + ----- + This is a convenience function for external use. The UndoRedoController + uses internal state copying for efficiency. + """ + return workflow.copy() diff --git a/src/ndev_workflows/_workflow.py b/src/ndev_workflows/_workflow.py new file mode 100644 index 0000000..bd6e956 --- /dev/null +++ b/src/ndev_workflows/_workflow.py @@ -0,0 +1,549 @@ +"""Core Workflow class for ndev-workflows. + +The Workflow class represents a dask-compatible task graph that tracks +dependencies between processing steps in napari. 
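+
+Example
+-------
+A minimal sketch of the underlying dask graph (assumes ``scikit-image`` is
+installed); keyword arguments are stored in a ``functools.partial`` and the
+task mapping can be executed by dask directly:
+
+>>> import numpy as np
+>>> from dask.threaded import get
+>>> from skimage.filters import gaussian
+>>> from ndev_workflows import Workflow
+>>> w = Workflow()
+>>> w.set('input', np.zeros((8, 8)))
+>>> w.set('blurred', gaussian, 'input', sigma=2.0)
+>>> dict(w.tasks)['blurred'][0].func is gaussian
+True
+>>> get(dict(w.tasks), 'blurred').shape
+(8, 8)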
+""" + +from __future__ import annotations + +import importlib +from collections.abc import Callable, Iterable +from copy import deepcopy +from dataclasses import dataclass +from functools import partial +from types import MappingProxyType +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + pass + + +class CallableRef: + """Placeholder for a callable that hasn't been imported yet.""" + + def __init__(self, module: str, name: str): + self.module = module + self.name = name + self.kwargs: dict = {} + + def __repr__(self) -> str: + if self.kwargs: + return ( + f'CallableRef({self.module}.{self.name}, kwargs={self.kwargs})' + ) + return f'CallableRef({self.module}.{self.name})' + + def __call__(self, *args, **kwargs): + raise WorkflowNotRunnableError( + [ + MissingCallable( + module=self.module, + name=self.name, + error='CallableRef is unresolved (lazy workflow)', + ) + ] + ) + + +@dataclass(frozen=True, slots=True) +class MissingCallable: + module: str + name: str + error: str + + +class WorkflowNotRunnableError(RuntimeError): + def __init__(self, missing: Iterable[MissingCallable]): + self.missing = tuple(missing) + lines: list[str] = ['Workflow is not runnable; missing callables:'] + for item in self.missing: + top = item.module.split('.', 1)[0] if item.module else '' + suggestion = None + if top: + alt = top.replace('_', '-') + if alt != top: + suggestion = f'pip install {alt}' + else: + suggestion = f'pip install {top}' + + msg = f'- Cannot import {item.module}.{item.name}: {item.error}' + if suggestion: + msg += f' (try: {suggestion})' + lines.append(msg) + super().__init__('\n'.join(lines)) + + +class Workflow: + """A dask-compatible task graph for image processing workflows. + + The Workflow class stores processing steps as a dictionary of tasks, + where each task is a tuple of (function, *args). Arguments can reference + other task names, creating a dependency graph that is lazily evaluated. + + This is compatible with dask's task graph format, allowing workflows + to be executed with dask's threaded scheduler. + + Parameters + ---------- + None + + Attributes + ---------- + _tasks : dict + Dictionary mapping task names to (function, *args) tuples. + + Examples + -------- + >>> from ndev_workflows import Workflow + >>> from skimage.filters import gaussian + >>> w = Workflow() + >>> w.set("input", image_data) # Raw data, not a processing step + >>> w.set("blurred", gaussian, "input", sigma=2.0) + >>> result = w.get("blurred") # Executes the graph + """ + + def __init__(self) -> None: + """Initialize an empty workflow.""" + self._tasks: dict[str, tuple] = {} + self.metadata: dict[str, object] = {} + + def set( + self, + name: str, + func_or_data: Callable | Any, + *args: Any, + **kwargs: Any, + ) -> None: + """Add or update a task in the workflow. + + Parameters + ---------- + name : str + The name/key for this task. Can be used as a reference in other tasks. + func_or_data : Callable or Any + Either a callable (function) to execute, or raw data to store. + If a callable, it will be executed with the provided args/kwargs. + *args : Any + Positional arguments for the function. String args that match + task names will be resolved to those task outputs. + **kwargs : Any + Keyword arguments for the function. + + Notes + ----- + When storing raw data (not a callable), the data is stored directly + (not as a tuple). This is compatible with dask's task graph format. 
+ + When storing a callable, the task format is: + ``(func_or_partial, *args)`` + """ + if not callable(func_or_data): + # Raw data - store directly (dask-compatible) + self._tasks[name] = func_or_data + return + + func: Callable + # Store only explicitly provided kwargs; do not bake in defaults. + # This keeps YAML exports minimal/stable and matches typical + # napari-workflows behavior. + func = partial(func_or_data, **kwargs) if kwargs else func_or_data + + # Store as dask-compatible task tuple + self._tasks[name] = (func, *args) + + @property + def tasks(self): + """Read-only view of the underlying dask task graph. + + This is the public accessor for the workflow's task graph. The + underlying storage is ``self._tasks`` for dask-graph compatibility. + Prefer this property over accessing ``_tasks`` directly. + """ + return MappingProxyType(self._tasks) + + def get(self, name: str | list[str]) -> Any: + """Execute the workflow graph and return the result for task(s). + + Parameters + ---------- + name : str or list[str] + The name of the task to compute, or a list of task names. + + Returns + ------- + Any + The computed result of the task. If a list of names is given, + returns a list of results. + + Raises + ------ + KeyError + If a task name is not found in the workflow. + + Notes + ----- + This uses dask's threaded scheduler to execute the task graph, + automatically resolving dependencies. + """ + from dask.threaded import get as dask_get + + if isinstance(name, list): + for n in name: + if n not in self._tasks: + raise KeyError(f"Task '{n}' not found in workflow") + return [dask_get(self._tasks, n) for n in name] + + if name not in self._tasks: + raise KeyError(f"Task '{name}' not found in workflow") + return dask_get(self._tasks, name) + + def roots(self) -> list[str]: + """Return workflow input names (graph roots). + + Roots are names that are used as inputs to processing tasks and are + not produced by any other processing task. + + Importantly, roots remain roots even after you provide data via + ``workflow.set(root_name, data)``. + + Returns + ------- + list[str] + List of names that are referenced but not defined as tasks. + + Notes + ----- + This is *not* the same as :meth:`external_inputs`, which returns only + undefined inputs (i.e. roots that have not been provided as data tasks). + """ + # Build a dependency edge list: source -> task_name + sources_in_order: list[str] = [] + sources_seen: set[str] = set() + targets: set[str] = set() + + for task_name, task in self._tasks.items(): + if not isinstance(task, tuple) or len(task) <= 1: + continue + + targets.add(task_name) + for arg in task[1:]: + if not isinstance(arg, str): + continue + if arg not in sources_seen: + sources_seen.add(arg) + sources_in_order.append(arg) + + # Roots are sources that are never targets. + return [name for name in sources_in_order if name not in targets] + + def leaves(self) -> list[str]: + """Return the leaf nodes (outputs) of the workflow. + + Leaves are tasks that do not have any followers - nothing + depends on them. These are typically the final outputs. + + Returns + ------- + list[str] + List of task names that are leaves. 
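+
+        Example
+        -------
+        A sketch using a placeholder function in place of real image
+        processing steps:
+
+        >>> from ndev_workflows import Workflow
+        >>> def identity(x):
+        ...     return x
+        >>> w = Workflow()
+        >>> w.set('blurred', identity, 'raw')
+        >>> w.set('labels', identity, 'blurred')
+        >>> w.roots(), w.leaves()
+        (['raw'], ['labels'])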
+ """ + # Collect tasks that ARE referenced by other tasks + has_followers = set() + for task in self._tasks.values(): + # Only tuples can have arguments that reference other tasks + if isinstance(task, tuple) and len(task) > 1: + for arg in task[1:]: + if isinstance(arg, str) and arg in self._tasks: + has_followers.add(arg) + + # Leaves are tasks with no followers + return [name for name in self._tasks if name not in has_followers] + + def leafs(self) -> list[str]: + """Alias for :meth:`leaves` (napari-workflows spelling).""" + return self.leaves() + + def get_undefined_inputs(self) -> list[str]: + """Return undefined input names. + + These are roots that are referenced by processing tasks but have not + been provided as tasks (typically via ``workflow.set(name, data)``). + + Returns + ------- + list[str] + List of names referenced but not defined. + + """ + # Preserve the stable ordering of roots(). + return [name for name in self.roots() if name not in self._tasks] + + def processing_task_names(self) -> list[str]: + """Return names of processing tasks (excluding raw data tasks).""" + return [ + name + for name, task in self._tasks.items() + if isinstance(task, tuple) and len(task) > 0 + ] + + def ensure_runnable(self) -> Workflow: + """Resolve any CallableRef placeholders into real imported callables. + + Parameters + ---------- + Returns + ------- + Workflow + Self (mutated in place). + + Raises + ------ + NotRunnableWorkflowError + If one or more callables cannot be imported. + """ + missing: list[MissingCallable] = [] + + for task_name, task in list(self._tasks.items()): + if not isinstance(task, tuple) or len(task) == 0: + continue + + func = task[0] + if not isinstance(func, CallableRef): + continue + + try: + module = importlib.import_module(func.module) + real_func = getattr(module, func.name) + except (ImportError, AttributeError) as e: + missing.append( + MissingCallable( + module=func.module, + name=func.name, + error=str(e), + ) + ) + continue + + if getattr(func, 'kwargs', None): + real_func = partial(real_func, **dict(func.kwargs)) + + self._tasks[task_name] = (real_func, *task[1:]) + + if missing: + raise WorkflowNotRunnableError(missing) + + return self + + def root_functions(self) -> dict[str, tuple]: + """Return the functions that operate directly on root inputs. + + These are the first processing steps in the workflow - functions + that take root tasks (data) as their primary input. + + Returns + ------- + dict[str, tuple] + Dictionary mapping task names to their task tuples for all + tasks that depend directly on root tasks. + + Notes + ----- + This is useful for initializing workflows when loading, as you + typically need to connect input data to these root functions. + """ + root_names = set(self.roots()) + root_funcs = {} + + for name, task in self._tasks.items(): + # Skip data tasks (stored directly, not as tuples) + if not isinstance(task, tuple): + continue + if len(task) == 0 or not callable(task[0]): + continue + + # Check if any source is a root + sources = self.sources_of(name) + if any(src in root_names for src in sources): + root_funcs[name] = task + + return root_funcs + + def followers_of(self, name: str) -> list[str]: + """Return tasks that depend on the given task. + + Parameters + ---------- + name : str + The name of the task to find followers for. + + Returns + ------- + list[str] + List of task names that depend on this task. 
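+
+        Example
+        -------
+        A sketch where two tasks consume the same source:
+
+        >>> from ndev_workflows import Workflow
+        >>> def identity(x):
+        ...     return x
+        >>> w = Workflow()
+        >>> w.set('blurred', identity, 'raw')
+        >>> w.set('edges', identity, 'blurred')
+        >>> w.set('labels', identity, 'blurred')
+        >>> w.followers_of('blurred')
+        ['edges', 'labels']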
+ """ + followers = [] + for task_name, task in self._tasks.items(): + if task_name == name: + continue + # Only tuples can have arguments that reference other tasks + if isinstance(task, tuple) and len(task) > 1: + for arg in task[1:]: + if arg == name: + followers.append(task_name) + break + return followers + + def sources_of(self, name: str) -> list[str]: + """Return names that the given task depends on. + + Parameters + ---------- + name : str + The name of the task to find sources for. + + Returns + ------- + list[str] + List of names that this task references as inputs. + Includes both defined tasks and external references. + """ + if name not in self._tasks: + return [] + + task = self._tasks[name] + sources = [] + + # Only tuples can have arguments that reference other tasks + if isinstance(task, tuple) and len(task) > 1: + for arg in task[1:]: + if isinstance(arg, str): + sources.append(arg) + + return sources + + def keys(self) -> list[str]: + """Return all task names in the workflow. + + Returns + ------- + list[str] + List of all task names. + """ + return list(self._tasks.keys()) + + def __contains__(self, name: str) -> bool: + """Check if a task name exists in the workflow.""" + return name in self._tasks + + def __len__(self) -> int: + """Return the number of tasks in the workflow.""" + return len(self._tasks) + + def __iter__(self): + """Iterate over task names.""" + return iter(self._tasks) + + def __repr__(self) -> str: + """Return a string representation of the workflow.""" + n_tasks = len(self._tasks) + roots = self.roots() + leafs = self.leaves() + return f'Workflow({n_tasks} tasks, roots={roots}, leafs={leafs})' + + def copy(self) -> Workflow: + """Create a deep copy of this workflow. + + Returns + ------- + Workflow + A new Workflow with copied tasks. + """ + new_workflow = Workflow() + new_workflow._tasks = deepcopy(self._tasks) + return new_workflow + + def remove(self, name: str) -> None: + """Remove a task from the workflow. + + Parameters + ---------- + name : str + The name of the task to remove. + + Notes + ----- + This does not check for or update dependencies. Tasks that + depended on the removed task will fail when executed. + """ + if name in self._tasks: + del self._tasks[name] + + def clear(self) -> None: + """Remove all tasks from the workflow.""" + self._tasks.clear() + + def get_task(self, name: str) -> tuple | None: + """Get the raw task tuple for a given name. + + Parameters + ---------- + name : str + The name of the task. + + Returns + ------- + tuple or None + The task tuple (function, *args), or None if not found. + """ + return self._tasks.get(name) + + def get_function(self, name: str) -> Callable | None: + """Get the function for a given task. + + Parameters + ---------- + name : str + The name of the task. + + Returns + ------- + Callable or None + The function (may be a partial), or None if not found + or if the task is raw data. + """ + task = self._tasks.get(name) + if task is None: + return None + # Data tasks are stored directly (not as tuples) + if not isinstance(task, tuple): + return None + if len(task) == 0: + return None + func = task[0] + return func if callable(func) else None + + def is_data_task(self, name: str) -> bool: + """Check if a task represents raw data (not a processing step). + + Parameters + ---------- + name : str + The name of the task. + + Returns + ------- + bool + True if the task is raw data, False if it's a processing step. 
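+
+        Example
+        -------
+        A sketch distinguishing a data task from a processing task:
+
+        >>> import numpy as np
+        >>> from ndev_workflows import Workflow
+        >>> w = Workflow()
+        >>> w.set('raw', np.zeros((4, 4)))
+        >>> w.set('mean', np.mean, 'raw')
+        >>> w.is_data_task('raw'), w.is_data_task('mean')
+        (True, False)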
+ """ + task = self._tasks.get(name) + if task is None: + return False + # Data tasks are stored directly (not as tuples) + # Processing tasks are stored as tuples (func, *args) + if not isinstance(task, tuple): + return True + # Edge case: empty tuple would be data, but shouldn't happen + if len(task) == 0: + return True + # If it's a tuple with a callable first element, it's a processing task + return not callable(task[0]) diff --git a/src/ndev_workflows/napari.yaml b/src/ndev_workflows/napari.yaml index 9da002d..495ca3a 100644 --- a/src/ndev_workflows/napari.yaml +++ b/src/ndev_workflows/napari.yaml @@ -1,8 +1,12 @@ name: ndev-workflows display_name: ndev Workflows -# use 'hidden' to remove plugin from napari hub search results visibility: public -# see https://napari.org/stable/plugins/technical_references/manifest.html#fields for valid categories -# categories: [] +categories: [Utilities] contributions: commands: + - id: ndev-workflows.workflow_container + title: Workflow Container + python_name: ndev_workflows.widgets._workflow_container:WorkflowContainer + widgets: + - command: ndev-workflows.workflow_container + display_name: Workflow Container diff --git a/src/ndev_workflows/widgets/__init__.py b/src/ndev_workflows/widgets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ndev_workflows/widgets/_workflow_container.py b/src/ndev_workflows/widgets/_workflow_container.py new file mode 100644 index 0000000..e1b3b75 --- /dev/null +++ b/src/ndev_workflows/widgets/_workflow_container.py @@ -0,0 +1,579 @@ +"""Workflow container widget for batch processing with napari-workflows. + +This module provides a Container widget for managing napari-workflows in both +interactive (viewer) and batch processing modes. It integrates with nbatch +for parallel execution of workflows on multiple files. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from magicgui.widgets import ( + CheckBox, + ComboBox, + Container, + FileEdit, + LineEdit, + ProgressBar, + PushButton, + Select, +) + +if TYPE_CHECKING: + import napari + + +class WorkflowContainer(Container): + """Container widget for managing napari-workflows. + + Provides both interactive (viewer) and batch processing modes for + executing napari-workflows. Integrates with nbatch for parallel + batch processing with progress tracking and error handling. + + Parameters + ---------- + viewer : napari.viewer.Viewer, optional + The napari viewer instance. If None, viewer-based workflow + execution will be disabled. + + Attributes + ---------- + workflow : napari_workflows.Workflow or None + The currently loaded workflow. + image_files : list[Path] + List of image files for batch processing. + + Example + ------- + >>> container = WorkflowContainer(viewer) + >>> viewer.window.add_dock_widget(container) + >>> # Select workflow file, image directory, and run batch + + """ + + def __init__(self, viewer: napari.viewer.Viewer = None): + """Initialize the WorkflowContainer widget. + + Parameters + ---------- + viewer : napari.viewer.Viewer, optional + The napari viewer instance. 
+ + """ + super().__init__() + self._viewer = viewer if viewer is not None else None + self._channel_names = [] + self._img_dims = '' + self._squeezed_img_dims = '' + self.image_files = [] + self.workflow = None + self._workflow_inputs = [] # Declared inputs from YAML (stable) + self._root_scale = None + + self._init_widgets() + self._init_batch_runner() + self._init_viewer_container() + self._init_batch_container() + self._init_tasks_container() + self._init_layout() + self._connect_events() + + def _init_batch_runner(self): + """Initialize the BatchRunner for batch processing.""" + from nbatch import BatchRunner + + self._batch_runner = BatchRunner( + on_start=self._on_batch_start, + on_item_complete=self._on_batch_item_complete, + on_complete=self._on_batch_complete, + on_error=self._on_batch_error, + on_cancel=self._on_batch_cancel, + ) + + def _on_batch_start(self, total: int): + """Callback when batch starts - initialize progress bar.""" + self._progress_bar.label = f'Workflow on {total} images' + self._progress_bar.value = 0 + self._progress_bar.max = total + self.batch_button.enabled = False + self._cancel_button.enabled = True + + def _get_viewer_layers(self): + """Get layers from the viewer.""" + if self._viewer is None: + return [] + return list(self._viewer.layers) + + def _init_widgets(self): + """Initialize non-Container widgets.""" + self.workflow_file = FileEdit( + label='Workflow File', + filter='*.yaml', + tooltip='Select a workflow file to load', + ) + self._workflow_roots = LineEdit(label='Workflow Roots:') + self._progress_bar = ProgressBar(label='Progress:') + + def _init_viewer_container(self): + """Initialize the viewer container tab widgets.""" + self.viewer_button = PushButton(text='Viewer Workflow') + self._viewer_roots_container = Container(layout='vertical', label=None) + self._viewer_roots_container.native.layout().addStretch() + self._viewer_container = Container( + layout='vertical', + widgets=[ + self.viewer_button, + self._viewer_roots_container, + ], + label='Viewer', + labels=None, + ) + + def _init_batch_container(self): + """Initialize the batch container tab widgets.""" + self.image_directory = FileEdit(label='Image Directory', mode='d') + self.result_directory = FileEdit(label='Result Directory', mode='d') + self._keep_original_images = CheckBox( + label='Keep Original Images', + value=False, + tooltip='If checked, the original images will be ' + 'concatenated with the results', + ) + self.batch_button = PushButton(label='Batch Workflow') + self._cancel_button = PushButton(label='Cancel') + self._cancel_button.enabled = False + self._batch_info_container = Container( + layout='vertical', + widgets=[ + self.image_directory, + self.result_directory, + self._keep_original_images, + self.batch_button, + self._cancel_button, + ], + ) + + self._batch_roots_container = Container(layout='vertical', label=None) + self._batch_roots_container.native.layout().addStretch() + + self._batch_container = Container( + layout='vertical', + widgets=[ + self._batch_info_container, + self._batch_roots_container, + ], + label='Batch', + labels=None, + ) + + def _init_tasks_container(self): + """Initialize the tasks container.""" + self._tasks_select = Select( + choices=[], + nullable=False, + allow_multiple=True, + ) + self._tasks_container = Container( + layout='vertical', + widgets=[self._tasks_select], + label='Tasks', + ) + + def _init_layout(self): + """Initialize the layout of the widgets.""" + from magicclass.widgets import TabbedContainer + + self.extend( + [ + 
self.workflow_file, + self._workflow_roots, + self._progress_bar, + ] + ) + self._tabs = TabbedContainer( + widgets=[ + self._viewer_container, + self._batch_container, + self._tasks_container, + ], + label=None, + labels=None, + ) + self.native.layout().addWidget(self._tabs.native) + self.native.layout().addStretch() + + def _connect_events(self): + """Connect the events of the widgets to respective methods.""" + self.image_directory.changed.connect(self._get_image_info) + self.workflow_file.changed.connect(self._get_workflow_info) + self.batch_button.clicked.connect(self.batch_workflow) + self._cancel_button.clicked.connect(self._batch_runner.cancel) + self.viewer_button.clicked.connect(self.viewer_workflow_threaded) + + if self._viewer is not None: + self._viewer.layers.events.removed.connect( + self._update_layer_choices + ) + self._viewer.layers.events.inserted.connect( + self._update_layer_choices + ) + + def _get_image_info(self): + """Get channels and dims from first image in the directory.""" + from ndevio import helpers, nImage + + self.image_dir, self.image_files = helpers.get_directory_and_files( + self.image_directory.value, + ) + img = nImage(self.image_files[0]) + + self._channel_names = helpers.get_channel_names(img) + + for widget in self._batch_roots_container: + widget.choices = self._channel_names + + self._squeezed_img_dims = helpers.get_squeezed_dim_order(img) + return self._squeezed_img_dims + + def _update_layer_choices(self): + """Update the choices of the layers for the viewer workflow.""" + for widget in self._viewer_roots_container: + widget.choices = self._get_viewer_layers() + return + + def _update_roots(self): + """Get the roots from the workflow and update the ComboBox widgets.""" + from ndevio import helpers + + self._batch_roots_container.clear() + self._viewer_roots_container.clear() + + for idx, root in enumerate(self.workflow.roots()): + short_root = helpers.elide_string(root, max_length=12) + + batch_root_combo = ComboBox( + label=f'Root {idx}: {short_root}', + choices=self._channel_names, + nullable=True, + value=None, + ) + self._batch_roots_container.append(batch_root_combo) + + viewer_root_combo = ComboBox( + label=f'Root {idx}: {short_root}', + choices=self._get_viewer_layers(), + nullable=True, + value=None, + ) + self._viewer_roots_container.append(viewer_root_combo) + + return + + def _update_task_choices(self, workflow=None, tasks=None, leafs=None): + """Update the choices of the tasks with the workflow tasks. + + Parameters + ---------- + workflow : Workflow, optional + Workflow object to extract tasks from. Used when full workflow is loaded. + tasks : list[str], optional + List of task names. Used with v3 metadata preview. + leafs : list[str], optional + Default selected tasks (outputs). Used with v3 metadata. + """ + if tasks is not None: + self._tasks_select.choices = tasks + self._tasks_select.value = leafs if leafs else tasks[-1:] + elif workflow is not None: + self._tasks_select.choices = workflow.processing_task_names() + self._tasks_select.value = workflow.leafs() + + def _get_workflow_info(self): + """Load the workflow file and update the roots and leafs. + + Uses the loaded Workflow's metadata for fast preview. + """ + from .._io import load_workflow + + workflow_path = self.workflow_file.value + + # Load workflow lazily so missing optional deps don't break the UI. + # load_workflow() does not import task functions when lazy=True. 
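+        # Illustrative sketch of the lazy preview path (file name is
+        # hypothetical):
+        #     wf = load_workflow('pipeline.yaml', lazy=True)
+        #     wf.metadata.get('inputs')        # e.g. ['raw_image']
+        #     wf.processing_task_names()       # e.g. ['blurred', 'labels']
+        # No task functions (skimage, plugin packages, ...) are imported here.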
+ try: + self.workflow = load_workflow(workflow_path, lazy=True) + except Exception: # noqa + self.workflow = None + return + + metadata = getattr(self.workflow, 'metadata', {}) or {} + self._workflow_inputs = list(metadata.get('inputs', [])) + self._workflow_roots.value = str(self._workflow_inputs) + self._update_roots_from_list(self._workflow_inputs) + + self._update_task_choices( + tasks=self.workflow.processing_task_names(), + leafs=list(metadata.get('outputs', [])), + ) + return + + def _update_roots_from_list(self, roots: list[str]): + """Update root ComboBox widgets from a list of root names. + + Used for v3 format metadata preview. + """ + from ndevio import helpers + + self._batch_roots_container.clear() + self._viewer_roots_container.clear() + + for idx, root in enumerate(roots): + short_root = helpers.elide_string(root, max_length=12) + + batch_root_combo = ComboBox( + label=f'Root {idx}: {short_root}', + choices=self._channel_names, + nullable=True, + value=None, + ) + self._batch_roots_container.append(batch_root_combo) + + viewer_root_combo = ComboBox( + label=f'Root {idx}: {short_root}', + choices=self._get_viewer_layers(), + nullable=True, + value=None, + ) + self._viewer_roots_container.append(viewer_root_combo) + + return + + def _update_progress_bar(self, value): + self._progress_bar.value = value + return + + def _on_batch_item_complete(self, result, ctx): + """Callback when a batch item completes successfully.""" + self._progress_bar.value = ctx.index + 1 + + def _on_batch_complete(self): + """Callback when the entire batch completes.""" + total = self._progress_bar.max + errors = self._batch_runner.error_count + if errors > 0: + self._progress_bar.label = ( + f'Completed {total - errors} Images ({errors} Errors)' + ) + else: + self._progress_bar.label = f'Completed {total} Images' + self.batch_button.enabled = True + self._cancel_button.enabled = False + + def _on_batch_error(self, ctx, exception): + """Callback when a batch item fails. + + Note: Error logging is handled by BatchRunner's internal logger. + This callback only updates the UI. 
+ """ + self._progress_bar.label = f'Error on {ctx.item.name}: {exception}' + + def _on_batch_cancel(self): + """Callback when the batch is cancelled.""" + self._progress_bar.label = 'Cancelled' + self.batch_button.enabled = True + self._cancel_button.enabled = False + + def batch_workflow(self): + """Run the workflow on all images in the image directory.""" + from .._batch import process_workflow_file + + result_dir = self.result_directory.value + image_files = self.image_files + + root_list = [widget.value for widget in self._batch_roots_container] + root_index_list = [self._channel_names.index(r) for r in root_list] + task_names = self._tasks_select.value + + self._batch_runner.run( + process_workflow_file, + image_files, + result_dir=result_dir, + workflow_file=self.workflow_file.value, + root_index_list=root_index_list, + task_names=task_names, + keep_original_images=self._keep_original_images.value, + root_list=root_list, + squeezed_img_dims=self._squeezed_img_dims, + log_file=result_dir / 'workflow.log.txt', + log_header={ + 'Image Directory': str(self.image_directory.value), + 'Result Directory': str(result_dir), + 'Workflow File': str(self.workflow_file.value), + 'Roots': str(root_list), + 'Tasks': str(task_names), + }, + threaded=True, + ) + + def viewer_workflow(self): + """Run the workflow on the viewer layers.""" + from .._io import load_workflow + from .._spec import ensure_runnable + from .._workflow import WorkflowNotRunnableError + + # Reload workflow for fresh state (previous run may have set data) + workflow = load_workflow(self.workflow_file.value, lazy=True) + + try: + workflow = ensure_runnable(workflow) + except WorkflowNotRunnableError as e: + from napari.utils.notifications import show_error + + show_error(str(e)) + return + + root_layer_list = [ + widget.value for widget in self._viewer_roots_container + ] + self._root_scale = root_layer_list[0].scale + + # Use stored input names (stable, from YAML metadata) + for root_idx, root_layer in enumerate(root_layer_list): + workflow.set( + name=self._workflow_inputs[root_idx], + func_or_data=root_layer.data, + ) + + for task_idx, task in enumerate(self._tasks_select.value): + func = workflow.get_function(task) + result = workflow.get(name=task) + yield task_idx, task, result, func + + return + + def _viewer_workflow_yielded(self, value): + task_idx, task, result, func = value + self._add_result_to_viewer(task=task, result=result, func=func) + self._progress_bar.value = task_idx + 1 + return + + def _add_result_to_viewer(self, *, task: str, result, func=None) -> None: + """Add a workflow result to the viewer using a best-effort layer choice. + + Rules: + - If the task returns a napari LayerDataTuple ``(data, kwargs, layer_type)``, + use that (this is the recommended way for non-image outputs like shapes). + - Otherwise, if the result looks array-like, choose between labels vs image + conservatively and fall back to add_image. + + Notes + ----- + We intentionally do NOT try to guess points/shapes from an ``(N, D)`` array + because that is ambiguous with images. For those cases, return a + LayerDataTuple from the workflow task. 
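+
+        Examples
+        --------
+        A task can return an explicit LayerDataTuple, e.g. a hypothetical
+        shapes output::
+
+            (coords, {'shape_type': 'polygon', 'edge_color': 'red'}, 'shapes')
+
+        which this method forwards to ``viewer.add_shapes(coords, ...)``,
+        filling in ``name`` and ``scale`` defaults when they are not given.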
+ """ + if self._viewer is None: + return + + scale = self._root_scale if self._root_scale is not None else None + + # Preferred: explicit LayerDataTuple + if ( + isinstance(result, tuple) + and len(result) == 3 + and isinstance(result[2], str) + ): + data, kwargs, layer_type = result + if kwargs is None: + kwargs = {} + if not isinstance(kwargs, dict): + kwargs = dict(kwargs) + + kwargs.setdefault('name', task) + if scale is not None and 'scale' not in kwargs: + kwargs['scale'] = scale + + add_name = f'add_{layer_type}' + add_fn = getattr(self._viewer, add_name, None) + if callable(add_fn): + add_fn(data, **kwargs) + return + + # Fallback: array-like results -> labels vs image + looks_array_like = all( + hasattr(result, attr) for attr in ('shape', 'ndim', 'dtype') + ) + if looks_array_like: + import numpy as np + + def _is_probably_labels(arr) -> bool: + try: + if arr.ndim < 2: + return False + if not ( + np.issubdtype(arr.dtype, np.integer) + or np.issubdtype(arr.dtype, np.bool_) + ): + return False + + flat = np.asarray(arr).ravel() + if flat.size == 0: + return False + + # Sample to avoid expensive unique() on large arrays. + if flat.size > 4096: + step = max(1, flat.size // 4096) + flat = flat[::step] + + uniq = np.unique(flat) + if uniq.size > 256: + return False + return not uniq.min(initial=0) < 0 + except Exception: # noqa + return False + + if _is_probably_labels(result): + self._viewer.add_labels( + result, + name=task, + scale=scale, + ) + return + + self._viewer.add_image( + result, + name=task, + blending='additive', + scale=scale, + ) + return + + # Last resort: try add_image, otherwise show a helpful error. + try: + self._viewer.add_image( + result, + name=task, + blending='additive', + scale=scale, + ) + except Exception as e: # noqa + from napari.utils.notifications import show_error + + show_error( + f"Cannot add result for task '{task}' to the viewer: {e}. " + 'For non-image outputs, return a LayerDataTuple ' + '(data, kwargs, layer_type) from the workflow task.' 
+ ) + + def viewer_workflow_threaded(self): + """Run the viewer workflow with threading and progress bar updates.""" + from napari.qt import create_worker + + self._progress_bar.label = 'Workflow on Viewer Layers' + self._progress_bar.value = 0 + self._progress_bar.max = len(self._tasks_select.value) + + self._viewer_worker = create_worker(self.viewer_workflow) + self._viewer_worker.yielded.connect(self._viewer_workflow_yielded) + self._viewer_worker.start() + return diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..88af0c1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,143 @@ +"""Pytest configuration and shared fixtures for ndev-workflows tests.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +from ndev_workflows import Workflow + +# ============================================================================= +# Helper functions (module-level for import resolution in saved workflows) +# ============================================================================= + + +def add_value(x, value: int = 10): + """Helper function for addition tests.""" + return x + value + + +def multiply_value(x, factor: float = 2.0): + """Helper function for multiplication tests.""" + return x * factor + + +def threshold_value(x, thresh: float = 128.0): + """Helper function for thresholding tests.""" + return (x > thresh).astype(np.uint8) + + +def blur_value(x, sigma: float = 1.0): + """Helper function for blurring tests.""" + from scipy.ndimage import gaussian_filter + + return gaussian_filter(x.astype(float), sigma=sigma) + + +# ============================================================================= +# Path fixtures +# ============================================================================= + + +@pytest.fixture +def resources_path() -> Path: + """Path to the test resources directory.""" + return Path(__file__).parent / 'resources' + + +@pytest.fixture +def workflow_resources_path(resources_path: Path) -> Path: + """Path to the Workflow test resources.""" + return resources_path / 'Workflow' + + +@pytest.fixture +def sample_workflow_path(workflow_resources_path: Path) -> Path: + """Path to the sample 2-roots-2-leafs workflow.""" + return ( + workflow_resources_path + / 'workflows' + / 'cpu_workflow-2roots-2leafs.yaml' + ) + + +@pytest.fixture +def legacy_workflow_path(workflow_resources_path: Path) -> Path: + """Path to a legacy format workflow.""" + return workflow_resources_path / 'workflows' / 'legacy_simple.yaml' + + +@pytest.fixture +def images_path(workflow_resources_path: Path) -> Path: + """Path to the test images directory.""" + return workflow_resources_path / 'Images' + + +# ============================================================================= +# Image fixtures +# ============================================================================= + + +@pytest.fixture +def sample_image() -> np.ndarray: + """Create a sample 2D image for testing.""" + return np.random.randint(0, 255, (64, 64), dtype=np.uint8) + + +@pytest.fixture +def sample_3d_image() -> np.ndarray: + """Create a sample 3D image for testing.""" + return np.random.randint(0, 255, (16, 64, 64), dtype=np.uint8) + + +# ============================================================================= +# Workflow fixtures +# ============================================================================= + + +@pytest.fixture +def empty_workflow() -> Workflow: + """Create an empty workflow.""" + return Workflow() + + +@pytest.fixture 
+def simple_workflow(sample_image: np.ndarray) -> Workflow: + """Create a simple workflow with one processing step.""" + w = Workflow() + w.set('input', sample_image) + w.set('output', add_value, 'input', value=20) + return w + + +@pytest.fixture +def chain_workflow(sample_image: np.ndarray) -> Workflow: + """Create a workflow with a chain of processing steps.""" + w = Workflow() + w.set('input', sample_image) + w.set('multiplied', multiply_value, 'input', factor=2.0) + w.set('added', add_value, 'multiplied', value=5) + return w + + +@pytest.fixture +def branching_workflow(sample_image: np.ndarray) -> Workflow: + """Create a workflow with branching (one input, multiple outputs).""" + w = Workflow() + w.set('input', sample_image) + w.set('blurred_1', blur_value, 'input', sigma=1.0) + w.set('blurred_2', blur_value, 'input', sigma=2.0) + w.set('binary', threshold_value, 'blurred_1', thresh=100.0) + return w + + +@pytest.fixture +def saveable_workflow(sample_image: np.ndarray) -> Workflow: + """Create a workflow using module-level functions (can be saved/loaded).""" + w = Workflow() + w.set('input', sample_image) + w.set('step1', add_value, 'input', value=10) + w.set('step2', multiply_value, 'step1', factor=2.0) + return w diff --git a/tests/resources/Workflow/Images/cells3d2ch.tiff b/tests/resources/Workflow/Images/cells3d2ch.tiff new file mode 100644 index 0000000..b9526fb Binary files /dev/null and b/tests/resources/Workflow/Images/cells3d2ch.tiff differ diff --git a/tests/resources/Workflow/Labels/cells3d2ch.tiff b/tests/resources/Workflow/Labels/cells3d2ch.tiff new file mode 100644 index 0000000..8e3cb5a Binary files /dev/null and b/tests/resources/Workflow/Labels/cells3d2ch.tiff differ diff --git a/tests/resources/Workflow/ShapesAsLabels/cells3d2ch.tiff b/tests/resources/Workflow/ShapesAsLabels/cells3d2ch.tiff new file mode 100644 index 0000000..11e4a71 Binary files /dev/null and b/tests/resources/Workflow/ShapesAsLabels/cells3d2ch.tiff differ diff --git a/tests/resources/Workflow/workflows/cpu_workflow-2roots-2leafs.yaml b/tests/resources/Workflow/workflows/cpu_workflow-2roots-2leafs.yaml new file mode 100644 index 0000000..8c80c97 --- /dev/null +++ b/tests/resources/Workflow/workflows/cpu_workflow-2roots-2leafs.yaml @@ -0,0 +1,37 @@ +name: CPU Workflow 2 Roots 2 Leafs +description: Test workflow using napari-segment-blobs-and-things-with-membranes functions +modified: '2025-12-19' +inputs: +- membrane +- nuclei +outputs: +- membrane-label +- nucleus-label +tasks: + membrane-gb: + function: napari_segment_blobs_and_things_with_membranes.gaussian_blur + params: + arg0: membrane + sigma: 1.0 + membrane-threshold: + function: napari_segment_blobs_and_things_with_membranes.threshold_otsu + params: + arg0: membrane-gb + membrane-label: + function: napari_segment_blobs_and_things_with_membranes.connected_component_labeling + params: + arg0: membrane-threshold + + nucleus-gb: + function: napari_segment_blobs_and_things_with_membranes.gaussian_blur + params: + arg0: nuclei + sigma: 1.0 + nucleus-threshold: + function: napari_segment_blobs_and_things_with_membranes.threshold_otsu + params: + arg0: nucleus-gb + nucleus-label: + function: napari_segment_blobs_and_things_with_membranes.connected_component_labeling + params: + arg0: nucleus-threshold diff --git a/tests/resources/Workflow/workflows/cpu_workflow-2roots-2leafs_legacy.yaml b/tests/resources/Workflow/workflows/cpu_workflow-2roots-2leafs_legacy.yaml new file mode 100644 index 0000000..72afbed --- /dev/null +++ 
b/tests/resources/Workflow/workflows/cpu_workflow-2roots-2leafs_legacy.yaml @@ -0,0 +1,26 @@ +!!python/object:napari_workflows._workflow.Workflow +_tasks: + membrane-gb: !!python/tuple + - &id001 !!python/name:napari_segment_blobs_and_things_with_membranes.gaussian_blur '' + - membrane + - 1 + membrane-label: !!python/tuple + - &id002 !!python/name:skimage.measure._label.label '' + - membrane-threshold + - null + - false + membrane-threshold: !!python/tuple + - &id003 !!python/name:napari_segment_blobs_and_things_with_membranes.threshold_otsu '' + - membrane-gb + nucleus-gb: !!python/tuple + - *id001 + - nucleus + - 1 + nucleus-label: !!python/tuple + - *id002 + - nucleus-threshold + - null + - false + nucleus-threshold: !!python/tuple + - *id003 + - nucleus-gb diff --git a/tests/resources/Workflow/workflows/legacy_simple.yaml b/tests/resources/Workflow/workflows/legacy_simple.yaml new file mode 100644 index 0000000..355a5ec --- /dev/null +++ b/tests/resources/Workflow/workflows/legacy_simple.yaml @@ -0,0 +1,9 @@ +!!python/object:napari_workflows._workflow.Workflow +_tasks: + blurred: !!python/tuple + - !!python/name:skimage.filters.gaussian '' + - image + - 2.0 + labels: !!python/tuple + - !!python/name:skimage.measure.label '' + - blurred diff --git a/tests/test_batch.py b/tests/test_batch.py new file mode 100644 index 0000000..3af612e --- /dev/null +++ b/tests/test_batch.py @@ -0,0 +1,63 @@ +"""Tests for batch processing (_batch.py). + +Note: The process_workflow_file function is decorated with @batch from nbatch, +which modifies its signature. Direct testing is done via integration tests +in test_workflow_container.py through the WorkflowContainer.batch_workflow(). +""" + +from __future__ import annotations + +from pathlib import Path + + +class TestBatchDecorator: + """Test the @batch decorator behavior.""" + + def test_batch_continues_on_error(self, tmp_path: Path): + """Test that @batch decorator allows continuing on error.""" + from nbatch import batch + + errors = [] + + @batch(on_error='continue') + def process_item(item: Path, output: Path) -> Path: + if 'bad' in item.name: + raise ValueError(f'Bad file: {item}') + output_file = output / item.name + output_file.write_text('processed') + return output_file + + # Create test files + (tmp_path / 'good1.txt').write_text('test') + (tmp_path / 'bad_file.txt').write_text('test') + (tmp_path / 'good2.txt').write_text('test') + + output_dir = tmp_path / 'output' + output_dir.mkdir() + + files = list(tmp_path.glob('*.txt')) + + # Process all files - should not raise + results = [] + for f in files: + try: + result = process_item(f, output_dir) + results.append(result) + except ValueError as e: + errors.append(str(e)) + + # Good files should be processed + assert (output_dir / 'good1.txt').exists() + assert (output_dir / 'good2.txt').exists() + # Bad file should have raised + assert len(errors) == 1 + + +class TestProcessWorkflowFileImport: + """Test that process_workflow_file can be imported.""" + + def test_import_process_workflow_file(self): + """Test that the function can be imported.""" + from ndev_workflows._batch import process_workflow_file + + assert callable(process_workflow_file) diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 0000000..37d2874 --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,366 @@ +"""Tests for workflow I/O (save/load).""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import numpy as np +import pytest + +from ndev_workflows import 
Workflow, load_workflow, save_workflow +from ndev_workflows._io import WorkflowYAMLError +from ndev_workflows._spec import ensure_runnable +from ndev_workflows._workflow import WorkflowNotRunnableError + + +# Define helper functions at module level for import resolution +def add_helper(x, value=10): + """Helper function for addition.""" + return x + value + + +def multiply_helper(x, factor=2.0): + """Helper function for multiplication.""" + return x * factor + + +class TestWorkflowSaveLoad: + """Test YAML save/load functionality.""" + + def test_save_and_load_simple( + self, tmp_path: Path, sample_image: np.ndarray + ): + """Test saving and loading a simple workflow.""" + w = Workflow() + w.set('input', sample_image) + w.set('output', add_helper, 'input', value=20) + + filepath = tmp_path / 'workflow.yaml' + save_workflow(filepath, w) + + assert filepath.exists() + + # Load the workflow + loaded = load_workflow(filepath) + assert isinstance(loaded, Workflow) + assert 'output' in loaded + + # Data tasks are not saved + assert 'input' not in loaded + + def test_save_excludes_data_tasks( + self, tmp_path: Path, sample_image: np.ndarray + ): + """Test that raw data tasks are excluded from saved file.""" + w = Workflow() + w.set('data1', sample_image) + w.set('data2', np.zeros((10, 10))) + w.set('processed', add_helper, 'data1', value=5) + + filepath = tmp_path / 'workflow.yaml' + save_workflow(filepath, w) + + loaded = load_workflow(filepath) + + # Only processing step should be saved + assert 'processed' in loaded + assert 'data1' not in loaded + assert 'data2' not in loaded + + def test_loaded_workflow_executes( + self, tmp_path: Path, sample_image: np.ndarray + ): + """Test that loaded workflow can be executed.""" + w = Workflow() + w.set('input', sample_image) + w.set('result', add_helper, 'input', value=100) + + filepath = tmp_path / 'workflow.yaml' + save_workflow(filepath, w) + + loaded = load_workflow(filepath) + loaded.set('input', sample_image) # Provide data + + result = loaded.get('result') + expected = sample_image + 100 + np.testing.assert_array_equal(result, expected) + + def test_save_load_chain(self, tmp_path: Path, sample_image: np.ndarray): + """Test saving/loading a chain of operations.""" + w = Workflow() + w.set('input', sample_image) + w.set('step1', add_helper, 'input', value=10) + w.set('step2', multiply_helper, 'step1', factor=2.0) + + filepath = tmp_path / 'chain.yaml' + save_workflow(filepath, w) + + loaded = load_workflow(filepath) + loaded.set('input', sample_image) + + result = loaded.get('step2') + expected = (sample_image + 10) * 2.0 + np.testing.assert_allclose(result, expected) + + +class TestWorkflowMetadata: + """Test metadata extraction from workflow files.""" + + def test_get_metadata_from_saved_workflow( + self, tmp_path: Path, sample_image: np.ndarray + ): + """Test extracting metadata from a saved workflow.""" + w = Workflow() + w.set('input', sample_image) + w.set('processed', multiply_helper, 'input', factor=3.0) + + filepath = tmp_path / 'workflow.yaml' + save_workflow(filepath, w, name='Test Workflow') + + metadata = load_workflow(filepath, lazy=True).metadata + + assert metadata['name'] == 'Test Workflow' + assert metadata['legacy'] is False + + def test_get_metadata_inputs_outputs( + self, tmp_path: Path, sample_image: np.ndarray + ): + """Test that inputs and outputs are correctly extracted.""" + w = Workflow() + # Note: Data tasks are not saved - only function tasks are + # So 'input' becomes an external reference when loaded + w.set('input', 
sample_image) + w.set('step1', add_helper, 'input', value=10) + w.set('output', multiply_helper, 'step1', factor=2.0) + + filepath = tmp_path / 'workflow.yaml' + save_workflow(filepath, w) + + metadata = load_workflow(filepath, lazy=True).metadata + + # When loaded, 'input' becomes an external input (since data tasks aren't saved) + # The saved workflow's external_inputs() finds 'input' as referenced but undefined + assert 'input' in metadata['inputs'] + # Outputs are leafs + assert 'output' in metadata['outputs'] + + def test_get_metadata_tasks( + self, tmp_path: Path, sample_image: np.ndarray + ): + """Test that task names are correctly extracted.""" + w = Workflow() + w.set('input', sample_image) + w.set('blur', add_helper, 'input') + w.set('threshold', multiply_helper, 'blur') + + filepath = tmp_path / 'workflow.yaml' + save_workflow(filepath, w) + + workflow = load_workflow(filepath, lazy=True) + + assert 'blur' in workflow.processing_task_names() + assert 'threshold' in workflow.processing_task_names() + + +class TestWorkflowYAMLError: + """Test error handling in workflow I/O.""" + + def test_load_nonexistent_file(self, tmp_path: Path): + """Test loading a file that doesn't exist.""" + with pytest.raises(FileNotFoundError): + load_workflow(tmp_path / 'nonexistent.yaml') + + def test_load_invalid_yaml(self, tmp_path: Path): + """Test loading invalid YAML.""" + filepath = tmp_path / 'invalid.yaml' + filepath.write_text('not: valid: yaml: content: [[[') + + with pytest.raises(WorkflowYAMLError): + load_workflow(filepath) + + def test_load_with_unimportable_function(self, tmp_path: Path): + """Test that unimportable functions raise error.""" + filepath = tmp_path / 'bad_workflow.yaml' + filepath.write_text(""" +name: Bad Workflow +inputs: [input] +outputs: [bad] +tasks: + bad: + function: nonexistent.module.fake_function + params: + arg0: input +""") + + with pytest.raises(WorkflowYAMLError, match='Cannot import'): + load_workflow(filepath) + + +class TestPathTypes: + """Test that both str and Path work for file operations.""" + + def test_save_with_path(self, tmp_path: Path, sample_image: np.ndarray): + """Test save_workflow with Path object.""" + w = Workflow() + w.set('x', sample_image) + w.set('y', add_helper, 'x') + + save_workflow(tmp_path / 'test.yaml', w) + assert (tmp_path / 'test.yaml').exists() + + def test_save_with_str(self, tmp_path: Path, sample_image: np.ndarray): + """Test save_workflow with string path.""" + w = Workflow() + w.set('x', sample_image) + w.set('y', add_helper, 'x') + + save_workflow(str(tmp_path / 'test.yaml'), w) + assert (tmp_path / 'test.yaml').exists() + + def test_load_with_path(self, tmp_path: Path, sample_image: np.ndarray): + """Test load_workflow with Path object.""" + w = Workflow() + w.set('x', sample_image) + w.set('y', add_helper, 'x') + + filepath = tmp_path / 'test.yaml' + save_workflow(filepath, w) + + loaded = load_workflow(filepath) + assert 'y' in loaded + + def test_load_with_str(self, tmp_path: Path, sample_image: np.ndarray): + """Test load_workflow with string path.""" + w = Workflow() + w.set('x', sample_image) + w.set('y', add_helper, 'x') + + filepath = tmp_path / 'test.yaml' + save_workflow(filepath, w) + + loaded = load_workflow(str(filepath)) + assert 'y' in loaded + + +@pytest.fixture +def legacy_workflow_path() -> Path: + """Path to legacy format test file.""" + return Path('tests/resources/Workflow/workflows/legacy_simple.yaml') + + +class TestLegacyFormatLoading: + """Test loading legacy napari-workflows format.""" + + def 
test_load_legacy_format(self, legacy_workflow_path: Path): + """Test that legacy format is detected and loaded.""" + from ndev_workflows._io import is_legacy_format + + assert is_legacy_format(legacy_workflow_path) + + workflow = load_workflow(legacy_workflow_path) + assert isinstance(workflow, Workflow) + + def test_legacy_format_has_correct_tasks(self, legacy_workflow_path: Path): + """Test that legacy format loads all tasks.""" + workflow = load_workflow(legacy_workflow_path) + + assert 'blurred' in workflow + assert 'labels' in workflow + + def test_legacy_format_has_correct_roots(self, legacy_workflow_path: Path): + """Test that legacy format correctly identifies roots.""" + workflow = load_workflow(legacy_workflow_path) + + roots = workflow.roots() + assert 'image' in roots + + def test_legacy_format_has_correct_leafs(self, legacy_workflow_path: Path): + """Test that legacy format correctly identifies leafs.""" + workflow = load_workflow(legacy_workflow_path) + + leafs = workflow.leaves() + assert 'labels' in leafs + + def test_legacy_format_executes(self, legacy_workflow_path: Path): + """Test that loaded legacy workflow can execute.""" + workflow = load_workflow(legacy_workflow_path) + + # Set the input + test_image = np.random.randint(0, 255, (32, 32), dtype=np.uint8) + workflow.set('image', test_image) + + # Execute + result = workflow.get('labels') + assert result is not None + assert result.shape == test_image.shape + + def test_legacy_format_lazy_loading(self, legacy_workflow_path: Path): + """Test lazy loading doesn't import functions.""" + workflow = load_workflow(legacy_workflow_path, lazy=True) + + # Should have CallableRef placeholders + from ndev_workflows._workflow import CallableRef + + task = workflow._tasks['blurred'] + assert isinstance(task[0], CallableRef) + + def test_legacy_metadata(self, legacy_workflow_path: Path): + """Test getting metadata from legacy format.""" + metadata_obj = load_workflow(legacy_workflow_path, lazy=True).metadata + assert isinstance(metadata_obj, dict) + metadata: dict[str, Any] = metadata_obj + + assert metadata['legacy'] is True + + inputs = metadata.get('inputs') + outputs = metadata.get('outputs') + assert isinstance(inputs, list) + assert isinstance(outputs, list) + + assert 'image' in inputs + assert 'labels' in outputs + + +def test_ensure_runnable_from_spec_executes(): + spec = { + 'name': 'sqrt test', + 'inputs': ['x'], + 'outputs': ['y'], + 'tasks': { + 'y': { + 'function': 'math.sqrt', + 'params': {'arg0': 'x'}, + } + }, + } + + w = ensure_runnable(spec) + w.set('x', 9.0) + assert w.get('y') == 3.0 + + +def test_ensure_runnable_reports_missing_callable(): + spec = { + 'inputs': ['x'], + 'outputs': ['y'], + 'tasks': { + 'y': { + 'function': 'nonexistent.module.fake_function', + 'params': {'arg0': 'x'}, + } + }, + } + + with pytest.raises(WorkflowNotRunnableError, match='Cannot import'): + ensure_runnable(spec) + + +def test_workflow_method_ensure_runnable_resolves_callable_ref(): + from ndev_workflows._workflow import CallableRef, Workflow + + w = Workflow() + w._tasks['y'] = (CallableRef('math', 'sqrt'), 'x') + w.ensure_runnable() + w.set('x', 16.0) + assert w.get('y') == 4.0 diff --git a/tests/test_manager.py b/tests/test_manager.py new file mode 100644 index 0000000..ab5bc91 --- /dev/null +++ b/tests/test_manager.py @@ -0,0 +1,327 @@ +"""Tests for WorkflowManager (_manager.py). + +WorkflowManager requires napari, so tests use minimal viewer mocking +where possible and make_napari_viewer for integration tests. 
+""" + +from __future__ import annotations + +import time + +import numpy as np +import pytest + +from ndev_workflows._manager import WorkflowManager, _managers + + +# Use module-level functions for tests +def add_value(x, value: int = 10): + """Helper function for addition tests.""" + return x + value + + +def multiply_value(x, factor: float = 2.0): + """Helper function for multiplication tests.""" + return x * factor + + +class MockLayer: + """Minimal mock for napari layer.""" + + def __init__(self, name: str, data=None): + self.name = name + self.data = data + + +class MockLayers(dict): + """Mock for viewer.layers that supports dict-like access.""" + + +class MockViewer: + """Minimal mock viewer for testing without Qt.""" + + def __init__(self): + self.layers = MockLayers() + + +@pytest.fixture +def mock_viewer() -> MockViewer: + """Create a mock viewer.""" + return MockViewer() + + +@pytest.fixture(autouse=True) +def clear_managers(): + """Clear global managers registry before each test.""" + _managers.clear() + yield + _managers.clear() + + +class TestWorkflowManagerCreation: + """Test WorkflowManager instantiation and singleton pattern.""" + + def test_install_creates_manager(self, mock_viewer: MockViewer): + """Test that install creates a new manager.""" + manager = WorkflowManager.install(mock_viewer) + + assert manager is not None + assert isinstance(manager, WorkflowManager) + assert manager.viewer is mock_viewer + + def test_install_returns_existing_manager(self, mock_viewer: MockViewer): + """Test that install returns existing manager for same viewer.""" + manager1 = WorkflowManager.install(mock_viewer) + manager2 = WorkflowManager.install(mock_viewer) + + assert manager1 is manager2 + + def test_different_viewers_get_different_managers(self): + """Test that different viewers get different managers.""" + viewer1 = MockViewer() + viewer2 = MockViewer() + + manager1 = WorkflowManager.install(viewer1) + manager2 = WorkflowManager.install(viewer2) + + assert manager1 is not manager2 + + def test_manager_has_empty_workflow(self, mock_viewer: MockViewer): + """Test that new manager has empty workflow.""" + manager = WorkflowManager.install(mock_viewer) + + assert len(manager.workflow) == 0 + + def test_manager_has_undo_redo_controller(self, mock_viewer: MockViewer): + """Test that manager has undo/redo controller.""" + manager = WorkflowManager.install(mock_viewer) + + assert manager.undo_redo is not None + assert manager.undo_redo.can_undo is False + + +class TestWorkflowManagerUpdate: + """Test workflow update functionality.""" + + def test_update_adds_task(self, mock_viewer: MockViewer): + """Test that update adds a task to the workflow.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + + manager.workflow.set('input', data) + manager.update('output', add_value, 'input', value=5) + + assert 'output' in manager.workflow + + def test_update_saves_undo_state(self, mock_viewer: MockViewer): + """Test that update saves state for undo.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + manager.workflow.set('input', data) + + assert manager.undo_redo.can_undo is False + + manager.update('output', add_value, 'input', value=5) + + assert manager.undo_redo.can_undo is True + + def test_update_with_layer_object(self, mock_viewer: MockViewer): + """Test that update converts layer objects to names.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + manager.workflow.set('input_layer', data) + + # Create 
mock layer + layer = MockLayer('input_layer', data) + + manager.update('output', add_value, layer, value=5) + + # Should have stored the layer name, not the object + sources = manager.workflow.sources_of('output') + assert 'input_layer' in sources + + +class TestWorkflowManagerUndoRedo: + """Test undo/redo integration.""" + + def test_undo_reverts_workflow(self, mock_viewer: MockViewer): + """Test that undo reverts the workflow.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + manager.workflow.set('input', data) + + # Add a task + manager.update('output', add_value, 'input', value=5) + assert 'output' in manager.workflow + + # Undo + manager.undo() + assert 'output' not in manager.workflow + + def test_redo_reapplies_change(self, mock_viewer: MockViewer): + """Test that redo reapplies the undone change.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + manager.workflow.set('input', data) + + manager.update('output', add_value, 'input', value=5) + manager.undo() + assert 'output' not in manager.workflow + + manager.redo() + assert 'output' in manager.workflow + + +class TestWorkflowManagerClear: + """Test workflow clearing.""" + + def test_clear_removes_all_tasks(self, mock_viewer: MockViewer): + """Test that clear removes all tasks.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + + manager.workflow.set('input', data) + manager.workflow.set('output', add_value, 'input') + assert len(manager.workflow) == 2 + + manager.clear() + assert len(manager.workflow) == 0 + + def test_clear_saves_undo_state(self, mock_viewer: MockViewer): + """Test that clear is undoable.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + + manager.workflow.set('input', data) + manager.workflow.set('output', add_value, 'input') + + manager.clear() + assert len(manager.workflow) == 0 + + manager.undo() + assert len(manager.workflow) == 2 + + +class TestWorkflowManagerInvalidate: + """Test task invalidation.""" + + def test_invalidate_schedules_update(self, mock_viewer: MockViewer): + """Test that invalidate schedules an update.""" + manager = WorkflowManager.install(mock_viewer) + manager._auto_update_enabled = False # Disable for test + + manager.invalidate('test') + + # Since auto-update is disabled, pending should be empty + assert 'test' not in manager._pending_updates + + +class TestWorkflowManagerStop: + """Test worker thread management.""" + + def test_stop_terminates_worker(self, mock_viewer: MockViewer): + """Test that stop terminates the background worker.""" + manager = WorkflowManager.install(mock_viewer) + + assert manager._worker_thread is not None + assert manager._worker_thread.is_alive() + + manager.stop() + + # Give thread time to stop + time.sleep(0.2) + + # After stop, thread should be None or not alive + if manager._worker_thread is not None: + assert not manager._worker_thread.is_alive() + + +class TestWorkflowManagerCodeGeneration: + """Test Python code generation.""" + + def test_to_python_code_includes_imports(self, mock_viewer: MockViewer): + """Test that generated code includes imports.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + + manager.workflow.set('input', data) + manager.workflow.set('output', add_value, 'input', value=5) + + code = manager.to_python_code(use_napari=False) + + # Function defined in this module, so import will reference test_manager + assert 'import add_value' in code + assert 'add_value' in code + + def 
test_to_python_code_includes_napari(self, mock_viewer: MockViewer): + """Test that generated code includes napari when requested.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + + manager.workflow.set('input', data) + manager.workflow.set('output', add_value, 'input', value=5) + + code = manager.to_python_code(use_napari=True) + + assert 'import napari' in code + assert 'viewer = napari.Viewer()' in code + assert 'napari.run()' in code + + def test_to_python_code_data_placeholder(self, mock_viewer: MockViewer): + """Test that data tasks get placeholder comments.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + + manager.workflow.set('input', data) + manager.workflow.set('output', add_value, 'input') + + code = manager.to_python_code(use_napari=False) + + assert '# input = ' in code + + def test_to_python_code_notebook_format(self, mock_viewer: MockViewer): + """Test notebook format adds cell markers.""" + manager = WorkflowManager.install(mock_viewer) + data = np.zeros((10, 10)) + + manager.workflow.set('input', data) + manager.workflow.set('output', add_value, 'input') + + code = manager.to_python_code(use_napari=True, notebook=True) + + assert '# %%' in code + + +class TestWorkflowManagerWithNapari: + """Integration tests requiring actual napari viewer.""" + + @pytest.fixture + def napari_manager(self, make_napari_viewer): + """Create a manager with real napari viewer.""" + viewer = make_napari_viewer() + manager = WorkflowManager.install(viewer) + yield manager + manager.stop() + + def test_execute_update_updates_layer(self, napari_manager): + """Test that executing update refreshes layer data.""" + data = np.random.randint(0, 255, (64, 64), dtype=np.uint8) + + # Add a layer directly to viewer + napari_manager.viewer.add_image(data, name='input') + + # Add the data to workflow + napari_manager.workflow.set('input', data) + napari_manager.workflow.set('output', add_value, 'input', value=10) + + # Add output layer + output_layer = napari_manager.viewer.add_image( + np.zeros_like(data), name='output' + ) + + # Execute the update + napari_manager._execute_update('output') + + # Check layer data was updated + expected = data + 10 + np.testing.assert_array_equal(output_layer.data, expected) diff --git a/tests/test_spec.py b/tests/test_spec.py new file mode 100644 index 0000000..9c62115 --- /dev/null +++ b/tests/test_spec.py @@ -0,0 +1,282 @@ +"""Tests for workflow spec conversion (_spec.py).""" + +from __future__ import annotations + +from functools import partial + +import numpy as np +import pytest + +from ndev_workflows import Workflow +from ndev_workflows._spec import ( + ensure_runnable, + spec_dict_to_workflow, + workflow_to_spec_dict, +) +from ndev_workflows._workflow import CallableRef, WorkflowNotRunnableError + + +# Use module-level functions from tests.conftest for serialization tests +def add_value(x, value: int = 10): + """Helper function for addition tests.""" + return x + value + + +def multiply_value(x, factor: float = 2.0): + """Helper function for multiplication tests.""" + return x * factor + + +class TestWorkflowToSpecDict: + """Tests for workflow_to_spec_dict conversion.""" + + def test_converts_simple_workflow(self, sample_image: np.ndarray): + """Test converting a simple workflow to spec dict.""" + w = Workflow() + w.set('input', sample_image) + w.set('output', add_value, 'input', value=20) + + spec = workflow_to_spec_dict(w) + + assert 'tasks' in spec + assert 'output' in spec['tasks'] + # Function path 
should contain 'add_value' + assert 'add_value' in spec['tasks']['output']['function'] + assert spec['tasks']['output']['params']['arg0'] == 'input' + assert spec['tasks']['output']['params']['value'] == 20 + + def test_excludes_data_tasks(self, sample_image: np.ndarray): + """Test that data (non-callable) tasks are not included.""" + w = Workflow() + w.set('data1', sample_image) + w.set('data2', np.zeros((10, 10))) + w.set('processed', add_value, 'data1', value=5) + + spec = workflow_to_spec_dict(w) + + # Data tasks should not be in tasks dict + assert 'data1' not in spec['tasks'] + assert 'data2' not in spec['tasks'] + assert 'processed' in spec['tasks'] + + def test_identifies_inputs_and_outputs(self, sample_image: np.ndarray): + """Test that inputs and outputs are correctly identified.""" + w = Workflow() + w.set('input', sample_image) + w.set('step1', add_value, 'input', value=10) + w.set('output', multiply_value, 'step1', factor=2.0) + + spec = workflow_to_spec_dict(w) + + # 'input' is referenced but not saved as a task + assert 'input' in spec['inputs'] + # 'output' is a leaf (nothing depends on it) + assert 'output' in spec['outputs'] + + def test_includes_metadata(self, sample_image: np.ndarray): + """Test that name and description are included.""" + w = Workflow() + w.set('x', sample_image) + w.set('y', add_value, 'x') + + spec = workflow_to_spec_dict(w, name='Test', description='A test') + + assert spec['name'] == 'Test' + assert spec['description'] == 'A test' + assert 'modified' in spec + + def test_handles_callable_ref(self): + """Test converting a workflow with CallableRef placeholders.""" + w = Workflow() + ref = CallableRef('math', 'sqrt') + ref.kwargs = {'x': 9} + w._tasks['result'] = (ref, 'input') + + spec = workflow_to_spec_dict(w) + + assert 'result' in spec['tasks'] + assert spec['tasks']['result']['function'] == 'math.sqrt' + + +class TestSpecDictToWorkflow: + """Tests for spec_dict_to_workflow conversion.""" + + def test_creates_workflow_from_spec(self): + """Test creating a workflow from a spec dict.""" + spec = { + 'inputs': ['x'], + 'outputs': ['y'], + 'tasks': { + 'y': { + 'function': 'math.sqrt', + 'params': {'arg0': 'x'}, + } + }, + } + + w = spec_dict_to_workflow(spec, lazy=False) + + assert 'y' in w + w.set('x', 16.0) + assert w.get('y') == 4.0 + + def test_lazy_creates_callable_refs(self): + """Test lazy loading creates CallableRef placeholders.""" + spec = { + 'inputs': ['x'], + 'outputs': ['y'], + 'tasks': { + 'y': { + 'function': 'math.sqrt', + 'params': {'arg0': 'x'}, + } + }, + } + + w = spec_dict_to_workflow(spec, lazy=True) + + assert 'y' in w + task = w._tasks['y'] + assert isinstance(task[0], CallableRef) + + def test_kwargs_attached_to_callable_ref(self): + """Test that kwargs are attached to CallableRef when lazy.""" + spec = { + 'inputs': ['x'], + 'outputs': ['y'], + 'tasks': { + 'y': { + 'function': 'builtins.round', + 'params': {'arg0': 'x', 'ndigits': 2}, + } + }, + } + + w = spec_dict_to_workflow(spec, lazy=True) + + task = w._tasks['y'] + ref = task[0] + assert ref.kwargs == {'ndigits': 2} + + def test_eager_creates_partial_with_kwargs(self): + """Test that eager loading creates partial functions with kwargs.""" + spec = { + 'inputs': ['x'], + 'outputs': ['y'], + 'tasks': { + 'y': { + 'function': 'builtins.round', + 'params': {'arg0': 'x', 'ndigits': 3}, + } + }, + } + + w = spec_dict_to_workflow(spec, lazy=False) + + task = w._tasks['y'] + func = task[0] + assert isinstance(func, partial) + assert func.keywords == {'ndigits': 3} + + def 
test_preserves_metadata(self): + """Test that metadata is preserved.""" + spec = { + 'name': 'Test Workflow', + 'description': 'A test', + 'modified': '2025-01-01', + 'inputs': ['x'], + 'outputs': ['y'], + 'tasks': { + 'y': { + 'function': 'math.sqrt', + 'params': {'arg0': 'x'}, + } + }, + } + + w = spec_dict_to_workflow(spec, lazy=True) + + assert w.metadata['name'] == 'Test Workflow' + assert w.metadata['description'] == 'A test' + + +class TestEnsureRunnable: + """Tests for ensure_runnable function.""" + + def test_from_spec_dict(self): + """Test ensure_runnable from a spec dict.""" + spec = { + 'inputs': ['x'], + 'outputs': ['y'], + 'tasks': { + 'y': { + 'function': 'math.sqrt', + 'params': {'arg0': 'x'}, + } + }, + } + + w = ensure_runnable(spec) + + w.set('x', 25.0) + assert w.get('y') == 5.0 + + def test_from_lazy_workflow(self): + """Test ensure_runnable from a lazy workflow.""" + w = Workflow() + ref = CallableRef('math', 'sqrt') + w._tasks['y'] = (ref, 'x') + + w = ensure_runnable(w) + + w.set('x', 36.0) + assert w.get('y') == 6.0 + + def test_raises_for_missing_callable(self): + """Test that missing callables raise WorkflowNotRunnableError.""" + spec = { + 'inputs': ['x'], + 'outputs': ['y'], + 'tasks': { + 'y': { + 'function': 'nonexistent.module.fake_function', + 'params': {'arg0': 'x'}, + } + }, + } + + with pytest.raises(WorkflowNotRunnableError, match='Cannot import'): + ensure_runnable(spec) + + def test_already_runnable_workflow_unchanged( + self, simple_workflow: Workflow + ): + """Test that already runnable workflow passes through.""" + result = ensure_runnable(simple_workflow) + + assert 'output' in result + # The workflow should still work + assert result.get('output') is not None + + +class TestRoundTrip: + """Test round-trip conversion workflow -> spec -> workflow.""" + + def test_roundtrip_preserves_functionality(self, sample_image: np.ndarray): + """Test that converting to spec and back preserves behavior.""" + # Create original workflow + original = Workflow() + original.set('input', sample_image) + original.set('step1', add_value, 'input', value=10) + original.set('step2', multiply_value, 'step1', factor=3.0) + + # Convert to spec and back + spec = workflow_to_spec_dict(original) + restored = spec_dict_to_workflow(spec, lazy=False) + + # Provide input and execute + restored.set('input', sample_image) + result = restored.get('step2') + + expected = (sample_image + 10) * 3.0 + np.testing.assert_allclose(result, expected) diff --git a/tests/test_undo_redo.py b/tests/test_undo_redo.py new file mode 100644 index 0000000..eb92186 --- /dev/null +++ b/tests/test_undo_redo.py @@ -0,0 +1,228 @@ +"""Tests for UndoRedoController.""" + +from __future__ import annotations + +import numpy as np + +from ndev_workflows import Workflow +from ndev_workflows._undo_redo import UndoRedoController + + +def add_value(x, value=10): + """Test function for addition.""" + return x + value + + +class TestUndoRedoBasics: + """Test basic undo/redo functionality.""" + + def test_controller_creation(self, empty_workflow: Workflow): + """Test creating an UndoRedoController.""" + controller = UndoRedoController(empty_workflow) + + assert controller.can_undo is False + assert controller.can_redo is False + assert controller.undo_stack_size == 0 + assert controller.redo_stack_size == 0 + + def test_save_state(self, empty_workflow: Workflow): + """Test saving state.""" + controller = UndoRedoController(empty_workflow) + + controller.save_state() + + assert controller.can_undo is True + assert 
controller.undo_stack_size == 1 + + def test_undo_restores_state(self, sample_image: np.ndarray): + """Test that undo restores previous state.""" + w = Workflow() + controller = UndoRedoController(w) + + # Initial empty state + controller.save_state() + + # Add a task + w.set('input', sample_image) + + # Check task was added + assert 'input' in w + + # Undo + result = controller.undo() + + assert result is True + assert 'input' not in w + assert controller.can_redo is True + + def test_redo_reapplies_change(self, sample_image: np.ndarray): + """Test that redo reapplies undone change.""" + w = Workflow() + controller = UndoRedoController(w) + + # Save empty state + controller.save_state() + + # Add task + w.set('input', sample_image) + + # Undo + controller.undo() + assert 'input' not in w + + # Redo + result = controller.redo() + + assert result is True + assert 'input' in w + assert controller.can_undo is True + + def test_undo_empty_stack_returns_false(self, empty_workflow: Workflow): + """Test undo with empty stack returns False.""" + controller = UndoRedoController(empty_workflow) + + result = controller.undo() + + assert result is False + + def test_redo_empty_stack_returns_false(self, empty_workflow: Workflow): + """Test redo with empty stack returns False.""" + controller = UndoRedoController(empty_workflow) + + result = controller.redo() + + assert result is False + + +class TestUndoRedoSequence: + """Test undo/redo with multiple operations.""" + + def test_multiple_undos(self, sample_image: np.ndarray): + """Test multiple undo operations.""" + w = Workflow() + controller = UndoRedoController(w) + + # State 0: empty + controller.save_state() + + # State 1: add input + w.set('input', sample_image) + controller.save_state() + + # State 2: add step1 + w.set('step1', add_value, 'input') + controller.save_state() + + # State 3: add step2 + w.set('step2', add_value, 'step1', value=20) + + assert len(w) == 3 + + # Undo to state 2 + controller.undo() + assert 'step2' not in w + assert 'step1' in w + + # Undo to state 1 + controller.undo() + assert 'step1' not in w + assert 'input' in w + + # Undo to state 0 + controller.undo() + assert 'input' not in w + assert len(w) == 0 + + def test_undo_redo_sequence(self, sample_image: np.ndarray): + """Test alternating undo/redo.""" + w = Workflow() + controller = UndoRedoController(w) + + controller.save_state() + w.set('a', sample_image) + controller.save_state() + w.set('b', add_value, 'a') + + # Undo + controller.undo() + assert 'b' not in w + assert 'a' in w + + # Redo + controller.redo() + assert 'b' in w + + # Undo again + controller.undo() + assert 'b' not in w + + def test_new_change_clears_redo_stack(self, sample_image: np.ndarray): + """Test that making a new change clears redo stack.""" + w = Workflow() + controller = UndoRedoController(w) + + controller.save_state() + w.set('a', sample_image) + + # Undo + controller.undo() + assert controller.can_redo is True + + # Make new change + controller.save_state() + w.set('b', sample_image) + + # Redo should no longer be available + assert controller.can_redo is False + + +class TestUndoRedoMaxHistory: + """Test history size limits.""" + + def test_max_history_enforced(self, sample_image: np.ndarray): + """Test that max_history limits undo stack size.""" + w = Workflow() + controller = UndoRedoController(w, max_history=3) + + # Add 5 states + for i in range(5): + controller.save_state() + w.set(f'step{i}', sample_image) + + # Should only have 3 states + assert controller.undo_stack_size == 3 + + 
def test_clear_history(self, sample_image: np.ndarray): + """Test clearing history.""" + w = Workflow() + controller = UndoRedoController(w) + + controller.save_state() + w.set('a', sample_image) + controller.save_state() + + # Undo to create redo stack + controller.undo() + + assert controller.can_undo is True + assert controller.can_redo is True + + # Clear history + controller.clear_history() + + assert controller.can_undo is False + assert controller.can_redo is False + + +class TestUndoRedoWorkflowCopy: + """Test workflow copying via undo/redo controller.""" + + def test_get_workflow_copy(self, simple_workflow: Workflow): + """Test getting a workflow copy.""" + controller = UndoRedoController(simple_workflow) + + copy = controller.get_workflow_copy() + + assert copy is not simple_workflow + assert len(copy) == len(simple_workflow) + assert copy.keys() == simple_workflow.keys() diff --git a/tests/test_workflow.py b/tests/test_workflow.py new file mode 100644 index 0000000..e528038 --- /dev/null +++ b/tests/test_workflow.py @@ -0,0 +1,274 @@ +"""Tests for the core Workflow class.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from ndev_workflows import Workflow + + +class TestWorkflowBasics: + """Test basic Workflow functionality.""" + + def test_create_empty_workflow(self): + """Test creating an empty workflow.""" + w = Workflow() + assert len(w) == 0 + assert w.keys() == [] + + def test_set_raw_data(self, sample_image: np.ndarray): + """Test setting raw data in workflow.""" + w = Workflow() + w.set('input', sample_image) + + assert 'input' in w + assert len(w) == 1 + + def test_set_processing_step(self, sample_image: np.ndarray): + """Test setting a processing step.""" + + def double(x): + return x * 2 + + w = Workflow() + w.set('input', sample_image) + w.set('doubled', double, 'input') + + assert 'input' in w + assert 'doubled' in w + assert len(w) == 2 + + def test_get_executes_task(self, sample_image: np.ndarray): + """Test that get() executes the task graph.""" + + def add_one(x): + return x + 1 + + w = Workflow() + w.set('input', sample_image) + w.set('output', add_one, 'input') + + result = w.get('output') + expected = sample_image + 1 + np.testing.assert_array_equal(result, expected) + + def test_get_missing_task_raises(self, empty_workflow: Workflow): + """Test that get() raises KeyError for missing task.""" + with pytest.raises(KeyError, match='not found'): + empty_workflow.get('nonexistent') + + def test_kwargs_passed_correctly(self, sample_image: np.ndarray): + """Test that kwargs are passed to the function.""" + + def scale(x, factor=1.0): + return x * factor + + w = Workflow() + w.set('input', sample_image) + w.set('scaled', scale, 'input', factor=3.0) + + result = w.get('scaled') + expected = sample_image * 3.0 + np.testing.assert_array_equal(result, expected) + + +class TestWorkflowGraphOperations: + """Test workflow graph analysis operations.""" + + def test_roots_returns_external_references(self, sample_image: np.ndarray): + """Test roots() returns names referenced but not defined.""" + + def process(x): + return x * 2 + + w = Workflow() + # Only set a function that references 'input', don't define 'input' + w.set('output', process, 'input') + + roots = w.roots() + # 'input' is referenced but not defined, so it's a root + assert 'input' in roots + assert len(roots) == 1 + + def test_roots_empty_when_all_defined(self, simple_workflow: Workflow): + """Test roots() includes data inputs (graph roots).""" + # simple_workflow has 'input' 
defined as data and used by a task, + # so it is a graph root (even though it's not an external input). + roots = simple_workflow.roots() + assert roots == ['input'] + + def test_leafs_single_output(self, simple_workflow: Workflow): + """Test leafs() with single output.""" + leafs = simple_workflow.leaves() + assert 'output' in leafs + assert len(leafs) == 1 + + def test_leafs_multiple_outputs(self, branching_workflow: Workflow): + """Test leafs() with multiple outputs.""" + leafs = branching_workflow.leaves() + # blurred_2 and binary are both leafs (nothing depends on them) + assert 'blurred_2' in leafs + assert 'binary' in leafs + assert 'input' not in leafs + + def test_followers_of(self, branching_workflow: Workflow): + """Test followers_of() returns dependent tasks.""" + followers = branching_workflow.followers_of('input') + assert 'blurred_1' in followers + assert 'blurred_2' in followers + + def test_sources_of(self, chain_workflow: Workflow): + """Test sources_of() returns dependencies.""" + sources = chain_workflow.sources_of('added') + assert 'multiplied' in sources + + sources = chain_workflow.sources_of('multiplied') + assert 'input' in sources + + def test_sources_of_root(self, chain_workflow: Workflow): + """Test sources_of() returns empty for root.""" + sources = chain_workflow.sources_of('input') + assert sources == [] + + def test_external_inputs_none_when_complete( + self, simple_workflow: Workflow + ): + """Test external_inputs() returns empty for complete workflow.""" + external = simple_workflow.get_undefined_inputs() + assert external == [] + + def test_external_inputs_finds_missing(self, sample_image: np.ndarray): + """Test external_inputs() finds undefined references.""" + + def process(x): + return x * 2 + + w = Workflow() + # Create a task that references 'missing_input' which doesn't exist + w.set('result', process, 'missing_input') + + external = w.get_undefined_inputs() + assert 'missing_input' in external + assert len(external) == 1 + + def test_root_functions(self, sample_image: np.ndarray): + """Test root_functions() returns functions that operate on roots.""" + + def step1(x): + return x + 1 + + def step2(x): + return x * 2 + + w = Workflow() + # Don't define 'input' - it will be a root (external reference) + w.set('processed', step1, 'input') # Operates on root 'input' + w.set('final', step2, 'processed') # Operates on non-root + + root_funcs = w.root_functions() + # 'processed' depends on 'input' which is a root (undefined) + assert 'processed' in root_funcs + # 'final' depends on 'processed' which is defined, not a root + assert 'final' not in root_funcs + + +class TestWorkflowCopyAndModify: + """Test workflow copy and modification operations.""" + + def test_copy_creates_independent_workflow( + self, simple_workflow: Workflow + ): + """Test that copy() creates an independent copy.""" + copied = simple_workflow.copy() + + assert len(copied) == len(simple_workflow) + assert copied.keys() == simple_workflow.keys() + + # Modify original + simple_workflow.set('new_task', lambda x: x, 'input') + + # Copy should be unaffected + assert 'new_task' in simple_workflow + assert 'new_task' not in copied + + def test_remove_task(self, simple_workflow: Workflow): + """Test removing a task.""" + simple_workflow.remove('output') + assert 'output' not in simple_workflow + assert 'input' in simple_workflow + + def test_clear_removes_all(self, simple_workflow: Workflow): + """Test clear() removes all tasks.""" + simple_workflow.clear() + assert len(simple_workflow) == 0 + 
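For orientation, here is what those graph-introspection calls look like outside of pytest: a small sketch built on a hypothetical two-branch pipeline (the `simple_workflow`/`branching_workflow` fixtures referenced above live in conftest.py, which this part of the diff does not show), with expected results noted only where the tests above assert them:

```python
# Sketch of Workflow graph introspection, mirroring the behavior these tests assert.
import numpy as np

from ndev_workflows import Workflow

def blur(x, sigma=1.0):
    return x  # stand-in for a real filter

def threshold(x):
    return x > 0

w = Workflow()
w.set('input', np.zeros((16, 16)))
w.set('blurred_1', blur, 'input', sigma=1.0)
w.set('blurred_2', blur, 'input', sigma=4.0)
w.set('binary', threshold, 'blurred_1')

w.roots()                  # ['input'] -- the data node feeding the graph
w.leaves()                 # 'blurred_2' and 'binary': nothing depends on them
w.followers_of('input')    # tasks that consume 'input'
w.sources_of('binary')     # ['blurred_1']
w.get_undefined_inputs()   # [] -- every referenced name is defined

# Editing the graph:
snapshot = w.copy()        # independent copy; later edits don't leak across
w.remove('binary')
w.clear()
```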
+ def test_is_data_task(self, simple_workflow: Workflow): + """Test is_data_task() correctly identifies raw data.""" + assert simple_workflow.is_data_task('input') is True + assert simple_workflow.is_data_task('output') is False + + def test_get_function(self, simple_workflow: Workflow): + """Test get_function() returns the function.""" + func = simple_workflow.get_function('output') + assert callable(func) + + # Data task should return None + func = simple_workflow.get_function('input') + assert func is None + + +class TestWorkflowChainExecution: + """Test execution of chained workflows.""" + + def test_chain_execution(self, sample_image: np.ndarray): + """Test executing a chain of operations.""" + + def add(x, value=0): + return x + value + + def multiply(x, factor=1): + return x * factor + + w = Workflow() + w.set('input', sample_image) + w.set('step1', add, 'input', value=10) + w.set('step2', multiply, 'step1', factor=2) + w.set('step3', add, 'step2', value=5) + + result = w.get('step3') + expected = (sample_image + 10) * 2 + 5 + np.testing.assert_array_equal(result, expected) + + def test_branching_execution(self, sample_image: np.ndarray): + """Test executing a branching workflow.""" + + def add(x, value=0): + return x + value + + w = Workflow() + w.set('input', sample_image) + w.set('branch1', add, 'input', value=10) + w.set('branch2', add, 'input', value=20) + + result1 = w.get('branch1') + result2 = w.get('branch2') + + np.testing.assert_array_equal(result1, sample_image + 10) + np.testing.assert_array_equal(result2, sample_image + 20) + + +class TestWorkflowRepr: + """Test workflow string representation.""" + + def test_repr_empty(self, empty_workflow: Workflow): + """Test repr of empty workflow.""" + r = repr(empty_workflow) + assert 'Workflow' in r + assert '0 tasks' in r + + def test_repr_with_tasks(self, simple_workflow: Workflow): + """Test repr with tasks.""" + r = repr(simple_workflow) + assert 'Workflow' in r + assert '2 tasks' in r diff --git a/tests/widgets/__init__.py b/tests/widgets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/widgets/test_workflow_container.py b/tests/widgets/test_workflow_container.py new file mode 100644 index 0000000..4711f36 --- /dev/null +++ b/tests/widgets/test_workflow_container.py @@ -0,0 +1,410 @@ +"""Tests for WorkflowContainer widget. 
+ +Organized into: +- Unit tests (no napari/Qt dependencies) +- Widget tests (qtbot for async, no viewer) +- Integration tests (full viewer, only when needed) +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import numpy as np +from ndevio import nImage + +from ndev_workflows.widgets._workflow_container import WorkflowContainer + + +class MockWorkflow: + """Mock workflow for testing without loading actual workflow.""" + + def roots(self): + return ['root1', 'root2'] + + def leaves(self): + return ['leaf1', 'leaf2'] + + def leafs(self): + return self.leaves() + + def set(self, name, func_or_data): + pass + + def get(self, name): + pass + + +# ============================================================================= +# Unit tests - No Qt/napari viewer dependencies +# ============================================================================= + + +class TestWorkflowContainerBasics: + """Basic initialization and property tests without viewer.""" + + def test_init_no_viewer(self): + """Test initialization without viewer.""" + container = WorkflowContainer() + + assert container._viewer is None + assert container._channel_names == [] + assert container._img_dims == '' + + def test_get_workflow_info_loads_workflow(self, sample_workflow_path): + """Test loading a workflow file populates workflow info.""" + container = WorkflowContainer() + container.workflow_file.value = sample_workflow_path + + assert container.workflow is not None + assert container._workflow_roots.value == str( + container.workflow.roots() + ) + + def test_workflow_roots_updated(self, sample_workflow_path): + """Test that root containers are updated from workflow.""" + container = WorkflowContainer() + container.workflow_file.value = sample_workflow_path + + roots = container.workflow.roots() + assert len(container._batch_roots_container) == len(roots) + assert len(container._viewer_roots_container) == len(roots) + + def test_tasks_select_populated(self, sample_workflow_path): + """Test that tasks select is populated from workflow.""" + container = WorkflowContainer() + container.workflow_file.value = sample_workflow_path + + assert len(container._tasks_select.choices) > 0 + assert container._tasks_select.value == list( + container.workflow.leaves() + ) + + def test_update_progress_bar(self): + """Test progress bar update.""" + container = WorkflowContainer() + container._progress_bar.value = 0 + container._progress_bar.max = 10 + container._update_progress_bar(9) + + assert container._progress_bar.value == 9 + + +class TestWorkflowContainerCallbacks: + """Test callback methods without running actual batch.""" + + def test_on_batch_error_updates_label(self): + """Test that error callback updates progress bar label.""" + from nbatch import BatchContext + + container = WorkflowContainer() + + mock_item = MagicMock() + mock_item.name = 'bad_file.tiff' + ctx = MagicMock(spec=BatchContext) + ctx.item = mock_item + + test_exception = ValueError('Test error message') + container._on_batch_error(ctx, test_exception) + + assert 'Error on bad_file.tiff' in container._progress_bar.label + assert 'Test error message' in container._progress_bar.label + + def test_on_batch_complete_updates_label(self): + """Test that complete callback updates progress bar label.""" + container = WorkflowContainer() + container._progress_bar.max = 5 + + container._on_batch_complete() + + assert 'Completed 5 Images' in container._progress_bar.label + assert container.batch_button.enabled is True + assert 
container._cancel_button.enabled is False + + def test_on_batch_cancel_updates_label(self): + """Test that cancel callback updates progress bar label.""" + container = WorkflowContainer() + container._on_batch_cancel() + + assert container._progress_bar.label == 'Cancelled' + assert container.batch_button.enabled is True + + def test_on_batch_start_sets_progress(self): + """Test that start callback initializes progress bar.""" + container = WorkflowContainer() + container._on_batch_start(total=10) + + assert container._progress_bar.value == 0 + assert container._progress_bar.max == 10 + assert container.batch_button.enabled is False + assert container._cancel_button.enabled is True + + +class TestMockWorkflowTests: + """Tests with mock workflow for faster widget tests (no viewer).""" + + def test_update_roots_with_mock_workflow(self): + """Test _update_roots with a mock workflow.""" + container = WorkflowContainer() + container.workflow = MockWorkflow() + container._channel_names = ['red', 'green', 'blue'] + + container._update_roots() + + assert len(container._batch_roots_container) == 2 + assert len(container._viewer_roots_container) == 2 + + for idx, root in enumerate(container._batch_roots_container): + assert root.label == f'Root {idx}: {MockWorkflow().roots()[idx]}' + assert root.choices == (None, 'red', 'green', 'blue') + assert root._nullable is True + assert root.value is None + + +# ============================================================================= +# Batch processing tests (with qtbot for async, no viewer) +# ============================================================================= + + +class TestBatchWorkflow: + """Test batch workflow execution.""" + + def test_batch_workflow_basic( + self, tmp_path, qtbot, sample_workflow_path, images_path + ): + """Test basic batch workflow execution.""" + container = WorkflowContainer() + container.workflow_file.value = sample_workflow_path + container.image_directory.value = images_path + + output_folder = tmp_path / 'Output' + output_folder.mkdir() + container.result_directory.value = output_folder + + container._batch_roots_container[0].value = 'membrane' + container._batch_roots_container[1].value = 'nuclei' + + container.batch_workflow() + + qtbot.waitUntil( + lambda: not container._batch_runner.is_running, timeout=15000 + ) + + assert output_folder.exists() + assert (output_folder / 'cells3d2ch.tiff').exists() + assert (output_folder / 'workflow.log.txt').exists() + + def test_batch_workflow_leaf_tasks_only( + self, tmp_path, qtbot, sample_workflow_path, images_path + ): + """Test batch workflow outputs only leaf tasks by default.""" + container = WorkflowContainer() + container.workflow_file.value = sample_workflow_path + container.image_directory.value = images_path + + output_folder = tmp_path / 'Output' + output_folder.mkdir(exist_ok=True) + container.result_directory.value = output_folder + + container._batch_roots_container[0].value = 'membrane' + container._batch_roots_container[1].value = 'nuclei' + + container.batch_workflow() + + qtbot.waitUntil( + lambda: not container._batch_runner.is_running, timeout=15000 + ) + + assert container._progress_bar.value == 1 + assert (output_folder / 'cells3d2ch.tiff').exists() + + img = nImage(output_folder / 'cells3d2ch.tiff') + assert len(img.channel_names) == 2 + + def test_batch_workflow_keep_original_images( + self, tmp_path, qtbot, sample_workflow_path, images_path + ): + """Test batch workflow with keep_original_images.""" + container = WorkflowContainer() + 
container.workflow_file.value = sample_workflow_path + container.image_directory.value = images_path + + output_folder = tmp_path / 'Output' + output_folder.mkdir() + container.result_directory.value = output_folder + + container._batch_roots_container[0].value = 'membrane' + container._batch_roots_container[1].value = 'nuclei' + container._keep_original_images.value = True + + container.batch_button.clicked() + + qtbot.waitUntil( + lambda: not container._batch_runner.is_running, timeout=15000 + ) + + img = nImage(output_folder / 'cells3d2ch.tiff') + assert len(img.channel_names) == 4 + assert img.channel_names == [ + 'membrane', + 'nuclei', + 'membrane-label', + 'nucleus-label', + ] + + def test_batch_workflow_all_tasks( + self, tmp_path, qtbot, sample_workflow_path, images_path + ): + """Test batch workflow with all tasks selected.""" + container = WorkflowContainer() + container.workflow_file.value = sample_workflow_path + container.image_directory.value = images_path + + output_folder = tmp_path / 'Output' + output_folder.mkdir() + container.result_directory.value = output_folder + + container._batch_roots_container[0].value = 'membrane' + container._batch_roots_container[1].value = 'nuclei' + container._tasks_select.value = list(container.workflow._tasks.keys()) + + container.batch_workflow() + + qtbot.waitUntil( + lambda: not container._batch_runner.is_running, timeout=15000 + ) + + img = nImage(output_folder / 'cells3d2ch.tiff') + assert len(img.channel_names) == 6 + + def test_cancel_button_stops_batch( + self, tmp_path, qtbot, sample_workflow_path, images_path + ): + """Test that cancel button stops the batch runner.""" + container = WorkflowContainer() + container.workflow_file.value = sample_workflow_path + container.image_directory.value = images_path + + output_folder = tmp_path / 'Output' + output_folder.mkdir() + container.result_directory.value = output_folder + + container._batch_roots_container[0].value = 'membrane' + container._batch_roots_container[1].value = 'nuclei' + + container.batch_workflow() + assert container._batch_runner.is_running + + container._cancel_button.clicked() + + qtbot.waitUntil( + lambda: not container._batch_runner.is_running, timeout=15000 + ) + + assert not container._batch_runner.is_running + + +# ============================================================================= +# Viewer workflow tests (require napari viewer) +# ============================================================================= + + +class TestViewerWorkflow: + """Tests that require an actual napari viewer.""" + + def test_init_with_viewer(self, make_napari_viewer): + """Test initialization with viewer.""" + viewer = make_napari_viewer() + container = WorkflowContainer(viewer) + + assert container._viewer == viewer + assert container._channel_names == [] + assert container._img_dims == '' + + def test_update_roots_with_viewer(self, make_napari_viewer): + """Test _update_roots with a viewer updates layer choices.""" + viewer = make_napari_viewer() + container = WorkflowContainer(viewer) + + container.workflow = MockWorkflow() + container._channel_names = ['red', 'green', 'blue'] + + container._update_roots() + + assert len(container._batch_roots_container) == 2 + assert len(container._viewer_roots_container) == 2 + + for idx, root in enumerate(container._viewer_roots_container): + assert ( + root.label == f'Root {idx}: {container.workflow.roots()[idx]}' + ) + assert root.choices == (None,) # No layers yet + assert root._nullable is True + + # Add layers to viewer + 
viewer.open_sample('napari', 'cells3d') + viewer.add_labels(np.random.randint(0, 2, (10, 10, 10))) + + # Layer choices should update + for root in container._viewer_roots_container: + assert len(root.choices) == 4 # None + 3 layers + + def test_viewer_workflow_generator( + self, make_napari_viewer, sample_workflow_path + ): + """Test viewer_workflow yields results.""" + viewer = make_napari_viewer() + container = WorkflowContainer(viewer) + container.workflow_file.value = sample_workflow_path + + viewer.open_sample('napari', 'cells3d') + container._viewer_roots_container[0].value = viewer.layers['membrane'] + container._viewer_roots_container[1].value = viewer.layers['nuclei'] + + generator = container.viewer_workflow() + + expected_results = [ + (0, 'membrane-label'), + (1, 'nucleus-label'), + ] + for idx, (task_idx, task, result, _func) in enumerate(generator): + assert task_idx == expected_results[idx][0] + assert task == expected_results[idx][1] + assert isinstance(result, np.ndarray) + + def test_viewer_workflow_yielded_adds_layer(self, make_napari_viewer): + """Test _viewer_workflow_yielded adds layer to viewer.""" + viewer = make_napari_viewer() + container = WorkflowContainer(viewer) + data = np.random.randint(0, 2, (10, 10, 10)) + + value = (1, 'test-name', data, None) + container._viewer_workflow_yielded(value) + + assert container._progress_bar.value == 2 # idx + 1 + assert viewer.layers[0].name == 'test-name' + assert viewer.layers[0].data.shape == data.shape + assert np.array_equal(viewer.layers[0].scale, (1, 1, 1)) + + def test_viewer_workflow_threaded( + self, make_napari_viewer, sample_workflow_path, qtbot + ): + """Test threaded viewer workflow execution.""" + viewer = make_napari_viewer() + container = WorkflowContainer(viewer) + container.workflow_file.value = sample_workflow_path + + viewer.open_sample('napari', 'cells3d') + container._viewer_roots_container[0].value = viewer.layers['membrane'] + container._viewer_roots_container[1].value = viewer.layers['nuclei'] + + container.viewer_workflow_threaded() + + with qtbot.waitSignal( + container._viewer_worker.finished, timeout=15000 + ): + pass + + assert container._progress_bar.value == 2 + assert viewer.layers[2].name == 'membrane-label' + assert viewer.layers[3].name == 'nucleus-label'
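Taken together, the batch tests above imply a headless usage pattern roughly like the sketch below. The file and directory paths are placeholders, and `_batch_roots_container` is the same private attribute the tests poke, so treat this as an illustration of the widget's flow rather than a supported public API:

```python
# Sketch of the headless batch flow driven by the widget tests above.
# Paths are placeholders; a workflow YAML and an image folder must exist.
from pathlib import Path

from ndev_workflows.widgets._workflow_container import WorkflowContainer

container = WorkflowContainer()                        # no napari viewer needed
container.workflow_file.value = Path('pipeline.yaml')  # loading populates roots/tasks
container.image_directory.value = Path('images/')
container.result_directory.value = Path('output/')

# Map each workflow root to a channel name found in the input images.
container._batch_roots_container[0].value = 'membrane'
container._batch_roots_container[1].value = 'nuclei'

container.batch_workflow()                             # runs asynchronously

# When the runner finishes, per-image results and a 'workflow.log.txt' are
# written to the result directory; progress and errors are mirrored on the
# widget's progress bar, as the callback tests assert.
```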