From e7062dd517e8f5153d27fa54dd901bd606b04021 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Sun, 22 Jun 2025 19:07:35 -0400 Subject: [PATCH 1/7] better 3D and 2.5D support for raster and vector data --- src/napari_spatialdata/_viewer.py | 31 +++++++------ src/napari_spatialdata/constants/config.py | 2 + src/napari_spatialdata/utils/_utils.py | 9 +++- src/napari_spatialdata/utils/_viewer_utils.py | 46 +++++++++++++++---- 4 files changed, 64 insertions(+), 24 deletions(-) diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index de90a4de..1dfa45a3 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -18,12 +18,11 @@ from spatialdata import get_element_annotators, get_element_instances from spatialdata._core.query.relational_query import _left_join_spatialelement_table from spatialdata._types import ArrayLike -from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_channels +from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_axes_names, get_channels from spatialdata.transformations import Affine, Identity from napari_spatialdata._model import DataModel from napari_spatialdata.constants import config -from napari_spatialdata.constants.config import CIRCLES_AS_POINTS from napari_spatialdata.utils._utils import ( _adjust_channels_order, _get_ellipses_from_circles, @@ -464,7 +463,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: if multi: original_name = original_name[: original_name.rfind("_")] - affine = _get_transform(sdata.images[original_name], selected_cs) + affine = _get_transform(sdata.images[original_name], selected_cs, include_z=True) rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name]) channels = ("RGB(A)",) if rgb else get_channels(sdata.images[original_name]) @@ -511,6 +510,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult df = sdata.shapes[original_name] affine = _get_transform(sdata.shapes[original_name], selected_cs) + # 2.5D circles not supported yet xy = np.array([df.geometry.x, df.geometry.y]).T yx = np.fliplr(xy) radii = df.radius.to_numpy() @@ -535,10 +535,10 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult version = get_napari_version() kwargs: dict[str, Any] = ( {"edge_width": 0.0} - if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS + if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS else {"border_width": 0.0} ) - if CIRCLES_AS_POINTS: + if config.CIRCLES_AS_POINTS: layer = Points( yx, name=key, @@ -550,7 +550,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult assert affine is not None self._adjust_radii_of_points_layer(layer=layer, affine=affine) else: - if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS: + if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS: kwargs |= {"edge_color": "white"} else: kwargs |= {"border_color": "white"} @@ -591,7 +591,8 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi original_name = original_name[: original_name.rfind("_")] df = sdata.shapes[original_name] - affine = _get_transform(sdata.shapes[original_name], selected_cs) + include_z = not config.PROJECT_2_5D_SHAPES_TO_2D + affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z) # when mulitpolygons are present, we select the largest ones if "MultiPolygon" in np.unique(df.geometry.type): @@ -603,7 +604,7 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi df = df.sort_index() # reset the index to the first order simplify = len(df) > config.POLYGON_THRESHOLD - polygons, indices = _get_polygons_properties(df, simplify) + polygons, indices = _get_polygons_properties(df, simplify, include_z=include_z) # this will only work for polygons and not for multipolygons polygons = _transform_coordinates(polygons, f=lambda x: x[::-1]) @@ -656,7 +657,7 @@ def get_sdata_labels(self, sdata: SpatialData, key: str, selected_cs: str, multi original_name = original_name[: original_name.rfind("_")] indices = get_element_instances(sdata.labels[original_name]) - affine = _get_transform(sdata.labels[original_name], selected_cs) + affine = _get_transform(sdata.labels[original_name], selected_cs, include_z=True) rgb_labels, _ = _adjust_channels_order(element=sdata.labels[original_name]) adata, table_name, table_names = self._get_table_data(sdata, original_name) @@ -700,8 +701,10 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi if multi: original_name = original_name[: original_name.rfind("_")] + axes = get_axes_names(sdata.points[original_name]) points = sdata.points[original_name].compute() - affine = _get_transform(sdata.points[original_name], selected_cs) + include_z = "z" in axes and not config.PROJECT_3D_POINTS_TO_2D + affine = _get_transform(sdata.points[original_name], selected_cs, include_z=include_z) adata, table_name, table_names = self._get_table_data(sdata, original_name) if len(points) < config.POINT_THRESHOLD: @@ -721,14 +724,16 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi _, adata = _left_join_spatialelement_table( {"points": {original_name: subsample_points}}, sdata[table_name], match_rows="left" ) - xy = subsample_points[["y", "x"]].values - np.fliplr(xy) + axes = list(sorted(axes, reverse=True)) + if not include_z and "z" in axes: + axes.remove("z") + coords = subsample_points[axes].values # radii_size = _calc_default_radii(self.viewer, sdata, selected_cs) radii_size = 3 version = get_napari_version() kwargs = {"edge_width": 0.0} if version <= packaging.version.parse("0.4.20") else {"border_width": 0.0} layer = Points( - xy, + coords, name=key, size=radii_size * 2, affine=affine, diff --git a/src/napari_spatialdata/constants/config.py b/src/napari_spatialdata/constants/config.py index 14054767..ddcebc23 100644 --- a/src/napari_spatialdata/constants/config.py +++ b/src/napari_spatialdata/constants/config.py @@ -4,3 +4,5 @@ N_SHAPES_WARNING_THRESHOLD = 10000 POINT_SIZE_SCATTERPLOT_WIDGET = 6 CIRCLES_AS_POINTS = True +PROJECT_3D_POINTS_TO_2D = True +PROJECT_2_5D_SHAPES_TO_2D = True diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 237b84b7..7a1fe13c 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -181,7 +181,9 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]: return [[f(xy) for xy in sublist] for sublist in data] -def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | ArrayLike: +def _get_transform( + element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None +) -> None | ArrayLike: if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame): raise RuntimeError("Cannot get transform for {type(element)}") @@ -189,7 +191,10 @@ def _get_transform(element: SpatialElement, coordinate_system_name: str | None = cs = transformations.keys().__iter__().__next__() if coordinate_system_name is None else coordinate_system_name ct = transformations.get(cs) if ct: - return ct.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x")) + axes_element = get_axes_names(element) + include_z = include_z and "z" in axes_element + axes_transformation = ("z", "y", "x") if include_z else ("y", "x") + return ct.to_affine_matrix(input_axes=axes_transformation, output_axes=axes_transformation) return None diff --git a/src/napari_spatialdata/utils/_viewer_utils.py b/src/napari_spatialdata/utils/_viewer_utils.py index 17a724f8..e32a5df0 100644 --- a/src/napari_spatialdata/utils/_viewer_utils.py +++ b/src/napari_spatialdata/utils/_viewer_utils.py @@ -1,18 +1,46 @@ from geopandas import GeoDataFrame +from spatialdata.models import get_axes_names -def _get_polygons_properties(df: GeoDataFrame, simplify: bool) -> tuple[list[list[tuple[float, float]]], list[int]]: +def add_z_to_list_of_xy_tuples(xy: list[tuple[float, float]], z: float) -> list[tuple[float, float, float]]: + """ + Add z coordinates to a list of (x, y) tuples. + + Parameters + ---------- + xy + List of (x, y) tuples. + z + z coordinate to add to each tuple. + + Returns + ------- + list[tuple[float, float, float]] + List of (x, y, z) tuples. + """ + return [(x, y, z) for x, y in xy] + + +def _get_polygons_properties( + df: GeoDataFrame, simplify: bool, include_z: bool +) -> (tuple)[list[list[tuple[float, float]]], list[int]]: + # for the moment this function assumes that there are no "Polygon Z", but that the z + # coordinates, if present, is in a separate column indices = [] polygons = [] + axes = get_axes_names(df) + include_z = include_z and "z" in axes - if simplify: - for i in range(0, len(df)): - indices.append(df.iloc[i].name) + for i in range(0, len(df)): + indices.append(df.iloc[i].name) + if include_z: + z = df.iloc[i].z.item() + if simplify: # This can be removed once napari is sped up in the plotting. It changes the shapes only very slightly - polygons.append(list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords)) - else: - for i in range(0, len(df)): - indices.append(df.iloc[i].name) - polygons.append(list(df.geometry.iloc[i].exterior.coords)) + xy = list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords) + else: + xy = list(df.geometry.iloc[i].exterior.coords) + coords = xy if not include_z else add_z_to_list_of_xy_tuples(xy=xy, z=z) + polygons.append(coords) return polygons, indices From 5c7a549d2008738ace99fd648010b64bfe6064a4 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 14 Aug 2025 16:21:44 +0200 Subject: [PATCH 2/7] fix mypy; workaround qt test (command double registration) --- .mypy.ini | 1 - src/napari_spatialdata/_viewer.py | 2 +- src/napari_spatialdata/utils/_viewer_utils.py | 45 ++++++++++++------- tests/test_cli.py | 2 + tests/test_interactive.py | 3 ++ tests/test_scatterwidgets.py | 3 ++ tests/test_widgets.py | 4 ++ 7 files changed, 42 insertions(+), 18 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index d41a7f36..1c42e37e 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,7 +1,6 @@ [mypy] mypy_path = napari-spatialdata python_version = 3.10 -plugins = numpy.typing.mypy_plugin ignore_errors = False warn_redundant_casts = True diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index 1dfa45a3..1c044c41 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -724,7 +724,7 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi _, adata = _left_join_spatialelement_table( {"points": {original_name: subsample_points}}, sdata[table_name], match_rows="left" ) - axes = list(sorted(axes, reverse=True)) + axes = sorted(axes, reverse=True) if not include_z and "z" in axes: axes.remove("z") coords = subsample_points[axes].values diff --git a/src/napari_spatialdata/utils/_viewer_utils.py b/src/napari_spatialdata/utils/_viewer_utils.py index e32a5df0..d49ee1e8 100644 --- a/src/napari_spatialdata/utils/_viewer_utils.py +++ b/src/napari_spatialdata/utils/_viewer_utils.py @@ -1,3 +1,5 @@ +from typing import cast + from geopandas import GeoDataFrame from spatialdata.models import get_axes_names @@ -21,26 +23,37 @@ def add_z_to_list_of_xy_tuples(xy: list[tuple[float, float]], z: float) -> list[ return [(x, y, z) for x, y in xy] -def _get_polygons_properties( - df: GeoDataFrame, simplify: bool, include_z: bool -) -> (tuple)[list[list[tuple[float, float]]], list[int]]: - # for the moment this function assumes that there are no "Polygon Z", but that the z - # coordinates, if present, is in a separate column - indices = [] - polygons = [] +# type aliases, only used in this module +Coord2D = tuple[float, float] +Coord3D = tuple[float, float, float] +Polygon2D = list[Coord2D] +Polygon3D = list[Coord3D] +Polygon = Polygon2D | Polygon3D + + +def _get_polygons_properties(df: GeoDataFrame, simplify: bool, include_z: bool) -> tuple[list[Polygon], list[int]]: + # assumes no "Polygon Z": z is in separate column if present + indices: list[int] = [] + polygons: list[Polygon] = [] + axes = get_axes_names(df) - include_z = include_z and "z" in axes + add_z = include_z and "z" in axes + + for i in range(len(df)): + indices.append(int(df.index[i])) - for i in range(0, len(df)): - indices.append(df.iloc[i].name) - if include_z: - z = df.iloc[i].z.item() if simplify: - # This can be removed once napari is sped up in the plotting. It changes the shapes only very slightly - xy = list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords) + xy = cast(list[Coord2D], list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords)) + else: + xy = cast(list[Coord2D], list(df.geometry.iloc[i].exterior.coords)) + + coords: Polygon2D | Polygon3D + if add_z: + z_val = float(df.iloc[i].z.item() if hasattr(df.iloc[i].z, "item") else df.iloc[i].z) + coords = cast(Polygon3D, add_z_to_list_of_xy_tuples(xy=xy, z=z_val)) else: - xy = list(df.geometry.iloc[i].exterior.coords) - coords = xy if not include_z else add_z_to_list_of_xy_tuples(xy=xy, z=z) + coords = cast(Polygon2D, xy) + polygons.append(coords) return polygons, indices diff --git a/tests/test_cli.py b/tests/test_cli.py index 3e9501a5..55b3e371 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,6 @@ from pathlib import Path +import pytest from click.testing import CliRunner from napari.viewer import Viewer from spatialdata.datasets import blobs @@ -26,6 +27,7 @@ def test_view_path_not_exists(): Viewer.close_all() +@pytest.mark.usefixtures("mock_app_model") def test_view_path_is_dir(): runner = CliRunner() with runner.isolated_filesystem(): diff --git a/tests/test_interactive.py b/tests/test_interactive.py index ef7805a9..8ddd3894 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -18,6 +18,7 @@ ) +@pytest.mark.usefixtures("mock_app_model") class TestImages(PlotTester, metaclass=PlotTesterMeta): def test_plot_can_add_element_image(self, sdata_blobs: SpatialData): blobs_image = Image2DModel.parse(sdata_blobs["blobs_image"], c_coords=("r", "g", "b")) @@ -51,6 +52,7 @@ def test_switch_coordinate_system(self, sdata_blobs: SpatialData): Viewer.close_all() +@pytest.mark.usefixtures("mock_app_model") def test_plot_can_add_element_switch_cs(sdata_blobs: SpatialData): i = Interactive(sdata=sdata_blobs, headless=True) i.add_element(element="blobs_image", element_coordinate_system="global", view_element_system=True) @@ -59,6 +61,7 @@ def test_plot_can_add_element_switch_cs(sdata_blobs: SpatialData): Viewer.close_all() +@pytest.mark.usefixtures("mock_app_model") class TestInteractive(PlotTester, metaclass=PlotTesterMeta): def test_get_layer_existing(self, sdata_blobs: SpatialData): i = Interactive(sdata=sdata_blobs, headless=True) diff --git a/tests/test_scatterwidgets.py b/tests/test_scatterwidgets.py index 4744faec..f23491bf 100644 --- a/tests/test_scatterwidgets.py +++ b/tests/test_scatterwidgets.py @@ -7,6 +7,9 @@ from napari_spatialdata._model import DataModel from napari_spatialdata._scatterwidgets import PlotWidget +pytestmark = pytest.mark.usefixtures("mock_app_model") + + DATA_LEN = 100 diff --git a/tests/test_widgets.py b/tests/test_widgets.py index 98f32850..b569df52 100644 --- a/tests/test_widgets.py +++ b/tests/test_widgets.py @@ -267,6 +267,7 @@ def test_layer_selection(make_napari_viewer: Any, image: ArrayLike, widget: Any, assert widget.model.adata.n_obs == 0 +@pytest.mark.usefixtures("mock_app_model") def test_export_no_rois(adata_labels): """Test export for no rois situation.""" @@ -277,6 +278,7 @@ def test_export_no_rois(adata_labels): assert scatter_widget.status_label.text() == "Status: No rois selected." +@pytest.mark.usefixtures("mock_app_model") def test_export_no_name(adata_labels, mocker): """Test export - no column name provided.""" @@ -290,6 +292,7 @@ def test_export_no_name(adata_labels, mocker): assert scatter_widget._model.adata.obs.equals(adata_labels.obs) +@pytest.mark.usefixtures("mock_app_model") def test_new_annotation(adata_labels, annotation_values, mocker): """Test export - adding a new annotation.""" @@ -304,6 +307,7 @@ def test_new_annotation(adata_labels, annotation_values, mocker): assert np.array_equal(scatter_widget._model.adata.obs.test, annotation_values) +@pytest.mark.usefixtures("mock_app_model") def test_old_annotation(adata_labels, annotation_values, mocker): """Test updating existing annotation.""" From 2281f0bbcfd5ca92e74320fd6aa7ffc3744d5a69 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 13 Oct 2025 13:15:30 +0200 Subject: [PATCH 3/7] fix deprecated API --- src/napari_spatialdata/_viewer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index 1c044c41..74d806ea 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -18,7 +18,7 @@ from spatialdata import get_element_annotators, get_element_instances from spatialdata._core.query.relational_query import _left_join_spatialelement_table from spatialdata._types import ArrayLike -from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_axes_names, get_channels +from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_axes_names, get_channel_names from spatialdata.transformations import Affine, Identity from napari_spatialdata._model import DataModel @@ -466,7 +466,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: affine = _get_transform(sdata.images[original_name], selected_cs, include_z=True) rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name]) - channels = ("RGB(A)",) if rgb else get_channels(sdata.images[original_name]) + channels = ("RGB(A)",) if rgb else get_channel_names(sdata.images[original_name]) adata = AnnData(shape=(0, len(channels)), var=pd.DataFrame(index=channels)) From d08ce7a9795268fbfe861e06ddba98fd9b6f60ee Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 11 Nov 2025 22:34:58 +0100 Subject: [PATCH 4/7] remove dask pin --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 9097cf86..fea284df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = anndata click cycler - dask>=2024.4.1,<=2024.11.2 + dask>=2025.2.0 geopandas loguru matplotlib From 7f633ae0db07baa4cf678bfc8dfe7d120e16f4c3 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 6 Jan 2026 15:20:06 +0100 Subject: [PATCH 5/7] cleanup --- .gitignore | 1 + src/napari_spatialdata/utils/_viewer_utils.py | 28 ++----------------- 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index cb34b882..7e2989e9 100644 --- a/.gitignore +++ b/.gitignore @@ -117,3 +117,4 @@ tests/plots/generated # Local temp script for testing user bugs (Luca) temp/ +uv.lock diff --git a/src/napari_spatialdata/utils/_viewer_utils.py b/src/napari_spatialdata/utils/_viewer_utils.py index 3b186228..16639ed8 100644 --- a/src/napari_spatialdata/utils/_viewer_utils.py +++ b/src/napari_spatialdata/utils/_viewer_utils.py @@ -1,28 +1,6 @@ -from typing import cast - from geopandas import GeoDataFrame from spatialdata.models import get_axes_names - -def add_z_to_list_of_xy_tuples(xy: list[tuple[float, float]], z: float) -> list[tuple[float, float, float]]: - """ - Add z coordinates to a list of (x, y) tuples. - - Parameters - ---------- - xy - List of (x, y) tuples. - z - z coordinate to add to each tuple. - - Returns - ------- - list[tuple[float, float, float]] - List of (x, y, z) tuples. - """ - return [(x, y, z) for x, y in xy] - - # type aliases, only used in this module Coord2D = tuple[float, float] Coord3D = tuple[float, float, float] @@ -43,14 +21,14 @@ def _get_polygons_properties(df: GeoDataFrame, simplify: bool, include_z: bool) indices.append(int(df.index[i])) if simplify: - xy = cast(list[Coord2D], list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords)) + xy = list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords) else: - xy = cast(list[Coord2D], list(df.geometry.iloc[i].exterior.coords)) + xy = list(df.geometry.iloc[i].exterior.coords) coords: Polygon2D | Polygon3D if add_z: z_val = float(df.iloc[i].z.item() if hasattr(df.iloc[i].z, "item") else df.iloc[i].z) - coords = add_z_to_list_of_xy_tuples(xy=xy, z=z_val) + coords = [(x, y, z_val) for x, y in xy] else: coords = xy From 82464d886593442d348a7e16348c31766160a992 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 6 Jan 2026 16:29:02 +0100 Subject: [PATCH 6/7] add tests for 3D points, 2.5D shapes; improve contribution guide --- docs/contributing.md | 22 +++++ tests/conftest.py | 64 ++++++++++++- tests/test_3d_visualization.py | 161 +++++++++++++++++++++++++++++++++ 3 files changed, 246 insertions(+), 1 deletion(-) create mode 100644 tests/test_3d_visualization.py diff --git a/docs/contributing.md b/docs/contributing.md index 1ee0da05..dd0085cf 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,3 +1,25 @@ # Contributing guide Please refer to the [contribution guide from the `spatialdata` repository](https://github.com/scverse/spatialdata/blob/main/docs/contributing.md). + +## Debugging napari GUI tests + +To visually inspect what a test is rendering in napari: + +1. Change `make_napari_viewer()` to `make_napari_viewer(show=True)` +2. Add `napari.run()` before the end of the test (before the assertions) + +Example: + +```python +import napari + + +def test_my_visualization(make_napari_viewer): + viewer = make_napari_viewer(show=True) + # ... setup code ... + napari.run() + # assertions... +``` + +Remember to revert these changes before committing. diff --git a/tests/conftest.py b/tests/conftest.py index e4f46808..2f78635f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,19 +9,23 @@ from pathlib import Path from typing import Any +import geopandas as gpd import napari import numpy as np import pandas as pd import pytest from anndata import AnnData +from dask.dataframe import from_pandas from loguru import logger from matplotlib.testing.compare import compare_images from scipy import ndimage as ndi +from shapely import MultiPolygon, Polygon from skimage import data from spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.datasets import blobs -from spatialdata.models import TableModel +from spatialdata.models import PointsModel, ShapesModel, TableModel +from spatialdata.transformations import Identity, set_transformation from napari_spatialdata.utils._test_utils import export_figure, save_image @@ -259,3 +263,61 @@ def caplog(caplog): def always_sync(monkeypatch, request): if request.node.get_closest_marker("use_thread_loader") is None: monkeypatch.setattr("napari_spatialdata._sdata_widgets.PROBLEMATIC_NUMPY_MACOS", True) + + +@pytest.fixture +def sdata_3d_points() -> SpatialData: + """Create a SpatialData object with 3D points (x, y, z coordinates).""" + n_points = 10 + rng = np.random.default_rng(SEED) + df = pd.DataFrame( + { + "x": rng.uniform(0, 100, n_points), + "y": rng.uniform(0, 100, n_points), + "z": rng.uniform(0, 50, n_points), + } + ) + dask_df = from_pandas(df, npartitions=1) + points = PointsModel.parse(dask_df) + set_transformation(points, {"global": Identity()}, set_all=True) + + return SpatialData(points={"points_3d": points}) + + +@pytest.fixture +def sdata_2_5d_shapes() -> SpatialData: + """Create a SpatialData object with 2.5D shapes (3 layers at different z, polygons + multipolygons).""" + shapes = {} + + geometries = [] + z_values = [] + indices = [] + for i, z_val in enumerate([0.0, 10.0, 20.0]): + # Add simple polygons (triangles and quadrilaterals) + poly1 = Polygon([(10 + i * 5, 10), (20 + i * 5, 10), (15 + i * 5, 20)]) + poly2 = Polygon([(30 + i * 5, 30), (40 + i * 5, 30), (40 + i * 5, 40), (30 + i * 5, 40)]) + geometries.extend([poly1, poly2]) + indices.extend([0, 1]) + z_values.extend([z_val] * 2) + + # Add a multipolygon (two separate polygon parts) + multi_poly = MultiPolygon( + [ + Polygon([(50 + i * 5, 10), (60 + i * 5, 10), (55 + i * 5, 20)]), + Polygon([(50 + i * 5, 30), (60 + i * 5, 30), (60 + i * 5, 40), (50 + i * 5, 40)]), + ] + ) + geometries.append(multi_poly) + indices.append(2) + z_values.append(z_val) + + gdf = gpd.GeoDataFrame( + {"z": z_values, "geometry": geometries}, + index=indices, + ) + + shape_element = ShapesModel.parse(gdf) + set_transformation(shape_element, {"global": Identity()}, set_all=True) + shapes["shapes_2.5d"] = shape_element + + return SpatialData(shapes=shapes) diff --git a/tests/test_3d_visualization.py b/tests/test_3d_visualization.py new file mode 100644 index 00000000..41ee43f3 --- /dev/null +++ b/tests/test_3d_visualization.py @@ -0,0 +1,161 @@ +"""Tests for 3D points and 2.5D shapes visualization. + +For debugging tips on how to visually inspect tests, see docs/contributing.md. +""" + +from typing import Any + +import pytest +from napari.layers import Points, Shapes +from napari.utils.events import EventedList +from spatialdata import SpatialData + +from napari_spatialdata._sdata_widgets import SdataWidget +from napari_spatialdata.constants import config + + +class Test3DPointsVisualization: + """Test 3D points visualization in napari.""" + + def test_3d_points_projected_to_2d(self, make_napari_viewer: Any, sdata_3d_points: SpatialData): + """Test that 3D points are projected to 2D when config flag is True.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + viewer.dims.ndisplay = 3 + + assert len(viewer.layers) == 1 + assert isinstance(viewer.layers[0], Points) + # 2D projection: points should have 2 coordinates + assert viewer.layers[0].data.shape[1] == 2 + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + def test_3d_points_full_3d(self, make_napari_viewer: Any, sdata_3d_points: SpatialData): + """Test that 3D points are visualized in 3D when config flag is False.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + viewer.dims.ndisplay = 3 + + assert len(viewer.layers) == 1 + assert isinstance(viewer.layers[0], Points) + # Full 3D: points should have 3 coordinates (z, y, x) + assert viewer.layers[0].data.shape[1] == 3 + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + +class Test2_5DShapesVisualization: + """Test 2.5D shapes visualization in napari.""" + + def test_2_5d_shapes_projected_to_2d(self, make_napari_viewer: Any, sdata_2_5d_shapes: SpatialData): + """Test that 2.5D shapes are projected to 2D when config flag is True.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + + # Add 2.5D shapes + widget._onClick("shapes_2.5d") + + assert len(viewer.layers) == 1 + assert isinstance(viewer.layers[0], Shapes) + # 2D projection: shape coordinates should have 2 values per vertex (y, x) + for shape_data in viewer.layers[0].data: + assert shape_data.shape[1] == 2 + + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + def test_2_5d_shapes_full_3d(self, make_napari_viewer: Any, sdata_2_5d_shapes: SpatialData): + """Test that 2.5D shapes are visualized in 3D when config flag is False.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + + # Add 2.5D shapes + widget._onClick("shapes_2.5d") + + assert len(viewer.layers) == 1 + assert isinstance(viewer.layers[0], Shapes) + # Full 3D: shape coordinates should have 3 values per vertex (z, y, x) + for shape_data in viewer.layers[0].data: + assert shape_data.shape[1] == 3 + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + +class TestMixed2D3DVisualization: + """Test mixed 2D and 3D visualization scenarios.""" + + @pytest.mark.parametrize( + "points_dim,shapes_dim", + [ + (3, 2), # Points 3D, Shapes 2D + (2, 3), # Points 2D, Shapes 3D + ], + ) + def test_mixed_dimension_visualization( + self, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + sdata_2_5d_shapes: SpatialData, + points_dim: int, + shapes_dim: int, + ): + """Test that points and shapes can be visualized with different dimension settings.""" + original_points_config = config.PROJECT_3D_POINTS_TO_2D + original_shapes_config = config.PROJECT_2_5D_SHAPES_TO_2D + + try: + config.PROJECT_3D_POINTS_TO_2D = points_dim == 2 + config.PROJECT_2_5D_SHAPES_TO_2D = shapes_dim == 2 + + # Create a combined SpatialData + combined_sdata = SpatialData( + points={"points_3d": sdata_3d_points["points_3d"]}, + shapes={"shapes_2.5d": sdata_2_5d_shapes["shapes_2.5d"]}, + ) + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([combined_sdata])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + + widget._onClick("points_3d") + assert viewer.layers[0].data.shape[1] == points_dim + + widget._onClick("shapes_2.5d") + for shape_data in viewer.layers[1].data: + assert shape_data.shape[1] == shapes_dim + + finally: + config.PROJECT_3D_POINTS_TO_2D = original_points_config + config.PROJECT_2_5D_SHAPES_TO_2D = original_shapes_config From 8af175758357b436223abf435ccc9e17a2ebfca8 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 6 Jan 2026 17:22:02 +0100 Subject: [PATCH 7/7] bump min spatialdata --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 5a0fc5dd..8bcd3ed5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,7 @@ install_requires = scipy shapely scikit-learn - spatialdata>=0.7.0dev0 + spatialdata>=0.7.0dev1 superqt typing_extensions>=4.8.0 vispy