diff --git a/.gitignore b/.gitignore index cb34b882..7e2989e9 100644 --- a/.gitignore +++ b/.gitignore @@ -117,3 +117,4 @@ tests/plots/generated # Local temp script for testing user bugs (Luca) temp/ +uv.lock diff --git a/docs/contributing.md b/docs/contributing.md index 1ee0da05..dd0085cf 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,3 +1,25 @@ # Contributing guide Please refer to the [contribution guide from the `spatialdata` repository](https://github.com/scverse/spatialdata/blob/main/docs/contributing.md). + +## Debugging napari GUI tests + +To visually inspect what a test is rendering in napari: + +1. Change `make_napari_viewer()` to `make_napari_viewer(show=True)` +2. Add `napari.run()` before the end of the test (before the assertions) + +Example: + +```python +import napari + + +def test_my_visualization(make_napari_viewer): + viewer = make_napari_viewer(show=True) + # ... setup code ... + napari.run() + # assertions... +``` + +Remember to revert these changes before committing. diff --git a/setup.cfg b/setup.cfg index 5a0fc5dd..8bcd3ed5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,7 @@ install_requires = scipy shapely scikit-learn - spatialdata>=0.7.0dev0 + spatialdata>=0.7.0dev1 superqt typing_extensions>=4.8.0 vispy diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index c6af7139..8d5e3d71 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -18,12 +18,11 @@ from spatialdata import get_element_annotators, get_element_instances from spatialdata._core.query.relational_query import _left_join_spatialelement_table from spatialdata._types import ArrayLike -from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_channel_names +from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_axes_names, get_channel_names from spatialdata.transformations import Affine, Identity from napari_spatialdata._model import DataModel from napari_spatialdata.constants import config -from napari_spatialdata.constants.config import CIRCLES_AS_POINTS from napari_spatialdata.utils._utils import ( _adjust_channels_order, _get_ellipses_from_circles, @@ -470,7 +469,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: if multi: original_name = original_name[: original_name.rfind("_")] - affine = _get_transform(sdata.images[original_name], selected_cs) + affine = _get_transform(sdata.images[original_name], selected_cs, include_z=True) rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name]) channels = ("RGB(A)",) if rgb else get_channel_names(sdata.images[original_name]) @@ -517,6 +516,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult df = sdata.shapes[original_name] affine = _get_transform(sdata.shapes[original_name], selected_cs) + # 2.5D circles not supported yet xy = np.array([df.geometry.x, df.geometry.y]).T yx = np.fliplr(xy) radii = df.radius.to_numpy() @@ -541,10 +541,10 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult version = get_napari_version() kwargs: dict[str, Any] = ( {"edge_width": 0.0} - if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS + if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS else {"border_width": 0.0} ) - if CIRCLES_AS_POINTS: + if config.CIRCLES_AS_POINTS: layer = Points( yx, name=key, @@ -556,7 +556,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult assert affine is not None self._adjust_radii_of_points_layer(layer=layer, affine=affine) else: - if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS: + if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS: kwargs |= {"edge_color": "white"} else: kwargs |= {"border_color": "white"} @@ -597,7 +597,8 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi original_name = original_name[: original_name.rfind("_")] df = sdata.shapes[original_name] - affine = _get_transform(sdata.shapes[original_name], selected_cs) + include_z = not config.PROJECT_2_5D_SHAPES_TO_2D + affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z) # when mulitpolygons are present, we select the largest ones if "MultiPolygon" in np.unique(df.geometry.type): @@ -609,7 +610,7 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi df = df.sort_index() # reset the index to the first order simplify = len(df) > config.POLYGON_THRESHOLD - polygons, indices = _get_polygons_properties(df, simplify) + polygons, indices = _get_polygons_properties(df, simplify, include_z=include_z) # this will only work for polygons and not for multipolygons polygons = _transform_coordinates(polygons, f=lambda x: x[::-1]) @@ -662,7 +663,7 @@ def get_sdata_labels(self, sdata: SpatialData, key: str, selected_cs: str, multi original_name = original_name[: original_name.rfind("_")] indices = get_element_instances(sdata.labels[original_name]) - affine = _get_transform(sdata.labels[original_name], selected_cs) + affine = _get_transform(sdata.labels[original_name], selected_cs, include_z=True) rgb_labels, _ = _adjust_channels_order(element=sdata.labels[original_name]) adata, table_name, table_names = self._get_table_data(sdata, original_name) @@ -706,8 +707,10 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi if multi: original_name = original_name[: original_name.rfind("_")] + axes = get_axes_names(sdata.points[original_name]) points = sdata.points[original_name].compute() - affine = _get_transform(sdata.points[original_name], selected_cs) + include_z = "z" in axes and not config.PROJECT_3D_POINTS_TO_2D + affine = _get_transform(sdata.points[original_name], selected_cs, include_z=include_z) adata, table_name, table_names = self._get_table_data(sdata, original_name) if len(points) < config.POINT_THRESHOLD: @@ -727,14 +730,16 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi _, adata = _left_join_spatialelement_table( {"points": {original_name: subsample_points}}, sdata[table_name], match_rows="left" ) - xy = subsample_points[["y", "x"]].values - np.fliplr(xy) + axes = sorted(axes, reverse=True) + if not include_z and "z" in axes: + axes.remove("z") + coords = subsample_points[axes].values # radii_size = _calc_default_radii(self.viewer, sdata, selected_cs) radii_size = 3 version = get_napari_version() kwargs = {"edge_width": 0.0} if version <= packaging.version.parse("0.4.20") else {"border_width": 0.0} layer = Points( - xy, + coords, name=key, size=radii_size * 2, affine=affine, diff --git a/src/napari_spatialdata/constants/config.py b/src/napari_spatialdata/constants/config.py index 14054767..ddcebc23 100644 --- a/src/napari_spatialdata/constants/config.py +++ b/src/napari_spatialdata/constants/config.py @@ -4,3 +4,5 @@ N_SHAPES_WARNING_THRESHOLD = 10000 POINT_SIZE_SCATTERPLOT_WIDGET = 6 CIRCLES_AS_POINTS = True +PROJECT_3D_POINTS_TO_2D = True +PROJECT_2_5D_SHAPES_TO_2D = True diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 03399977..cfe37e23 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -181,7 +181,9 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]: return [[f(xy) for xy in sublist] for sublist in data] -def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | ArrayLike: +def _get_transform( + element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None +) -> None | ArrayLike: if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame): raise RuntimeError("Cannot get transform for {type(element)}") @@ -189,7 +191,10 @@ def _get_transform(element: SpatialElement, coordinate_system_name: str | None = cs = transformations.keys().__iter__().__next__() if coordinate_system_name is None else coordinate_system_name ct = transformations.get(cs) if ct: - return ct.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x")) + axes_element = get_axes_names(element) + include_z = include_z and "z" in axes_element + axes_transformation = ("z", "y", "x") if include_z else ("y", "x") + return ct.to_affine_matrix(input_axes=axes_transformation, output_axes=axes_transformation) return None diff --git a/src/napari_spatialdata/utils/_viewer_utils.py b/src/napari_spatialdata/utils/_viewer_utils.py index 17a724f8..16639ed8 100644 --- a/src/napari_spatialdata/utils/_viewer_utils.py +++ b/src/napari_spatialdata/utils/_viewer_utils.py @@ -1,18 +1,37 @@ from geopandas import GeoDataFrame +from spatialdata.models import get_axes_names +# type aliases, only used in this module +Coord2D = tuple[float, float] +Coord3D = tuple[float, float, float] +Polygon2D = list[Coord2D] +Polygon3D = list[Coord3D] +Polygon = Polygon2D | Polygon3D -def _get_polygons_properties(df: GeoDataFrame, simplify: bool) -> tuple[list[list[tuple[float, float]]], list[int]]: - indices = [] - polygons = [] - if simplify: - for i in range(0, len(df)): - indices.append(df.iloc[i].name) - # This can be removed once napari is sped up in the plotting. It changes the shapes only very slightly - polygons.append(list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords)) - else: - for i in range(0, len(df)): - indices.append(df.iloc[i].name) - polygons.append(list(df.geometry.iloc[i].exterior.coords)) +def _get_polygons_properties(df: GeoDataFrame, simplify: bool, include_z: bool) -> tuple[list[Polygon], list[int]]: + # assumes no "Polygon Z": z is in separate column if present + indices: list[int] = [] + polygons: list[Polygon] = [] + + axes = get_axes_names(df) + add_z = include_z and "z" in axes + + for i in range(len(df)): + indices.append(int(df.index[i])) + + if simplify: + xy = list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords) + else: + xy = list(df.geometry.iloc[i].exterior.coords) + + coords: Polygon2D | Polygon3D + if add_z: + z_val = float(df.iloc[i].z.item() if hasattr(df.iloc[i].z, "item") else df.iloc[i].z) + coords = [(x, y, z_val) for x, y in xy] + else: + coords = xy + + polygons.append(coords) return polygons, indices diff --git a/tests/conftest.py b/tests/conftest.py index e4f46808..2f78635f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,19 +9,23 @@ from pathlib import Path from typing import Any +import geopandas as gpd import napari import numpy as np import pandas as pd import pytest from anndata import AnnData +from dask.dataframe import from_pandas from loguru import logger from matplotlib.testing.compare import compare_images from scipy import ndimage as ndi +from shapely import MultiPolygon, Polygon from skimage import data from spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.datasets import blobs -from spatialdata.models import TableModel +from spatialdata.models import PointsModel, ShapesModel, TableModel +from spatialdata.transformations import Identity, set_transformation from napari_spatialdata.utils._test_utils import export_figure, save_image @@ -259,3 +263,61 @@ def caplog(caplog): def always_sync(monkeypatch, request): if request.node.get_closest_marker("use_thread_loader") is None: monkeypatch.setattr("napari_spatialdata._sdata_widgets.PROBLEMATIC_NUMPY_MACOS", True) + + +@pytest.fixture +def sdata_3d_points() -> SpatialData: + """Create a SpatialData object with 3D points (x, y, z coordinates).""" + n_points = 10 + rng = np.random.default_rng(SEED) + df = pd.DataFrame( + { + "x": rng.uniform(0, 100, n_points), + "y": rng.uniform(0, 100, n_points), + "z": rng.uniform(0, 50, n_points), + } + ) + dask_df = from_pandas(df, npartitions=1) + points = PointsModel.parse(dask_df) + set_transformation(points, {"global": Identity()}, set_all=True) + + return SpatialData(points={"points_3d": points}) + + +@pytest.fixture +def sdata_2_5d_shapes() -> SpatialData: + """Create a SpatialData object with 2.5D shapes (3 layers at different z, polygons + multipolygons).""" + shapes = {} + + geometries = [] + z_values = [] + indices = [] + for i, z_val in enumerate([0.0, 10.0, 20.0]): + # Add simple polygons (triangles and quadrilaterals) + poly1 = Polygon([(10 + i * 5, 10), (20 + i * 5, 10), (15 + i * 5, 20)]) + poly2 = Polygon([(30 + i * 5, 30), (40 + i * 5, 30), (40 + i * 5, 40), (30 + i * 5, 40)]) + geometries.extend([poly1, poly2]) + indices.extend([0, 1]) + z_values.extend([z_val] * 2) + + # Add a multipolygon (two separate polygon parts) + multi_poly = MultiPolygon( + [ + Polygon([(50 + i * 5, 10), (60 + i * 5, 10), (55 + i * 5, 20)]), + Polygon([(50 + i * 5, 30), (60 + i * 5, 30), (60 + i * 5, 40), (50 + i * 5, 40)]), + ] + ) + geometries.append(multi_poly) + indices.append(2) + z_values.append(z_val) + + gdf = gpd.GeoDataFrame( + {"z": z_values, "geometry": geometries}, + index=indices, + ) + + shape_element = ShapesModel.parse(gdf) + set_transformation(shape_element, {"global": Identity()}, set_all=True) + shapes["shapes_2.5d"] = shape_element + + return SpatialData(shapes=shapes) diff --git a/tests/test_3d_visualization.py b/tests/test_3d_visualization.py new file mode 100644 index 00000000..41ee43f3 --- /dev/null +++ b/tests/test_3d_visualization.py @@ -0,0 +1,161 @@ +"""Tests for 3D points and 2.5D shapes visualization. + +For debugging tips on how to visually inspect tests, see docs/contributing.md. +""" + +from typing import Any + +import pytest +from napari.layers import Points, Shapes +from napari.utils.events import EventedList +from spatialdata import SpatialData + +from napari_spatialdata._sdata_widgets import SdataWidget +from napari_spatialdata.constants import config + + +class Test3DPointsVisualization: + """Test 3D points visualization in napari.""" + + def test_3d_points_projected_to_2d(self, make_napari_viewer: Any, sdata_3d_points: SpatialData): + """Test that 3D points are projected to 2D when config flag is True.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + viewer.dims.ndisplay = 3 + + assert len(viewer.layers) == 1 + assert isinstance(viewer.layers[0], Points) + # 2D projection: points should have 2 coordinates + assert viewer.layers[0].data.shape[1] == 2 + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + def test_3d_points_full_3d(self, make_napari_viewer: Any, sdata_3d_points: SpatialData): + """Test that 3D points are visualized in 3D when config flag is False.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + viewer.dims.ndisplay = 3 + + assert len(viewer.layers) == 1 + assert isinstance(viewer.layers[0], Points) + # Full 3D: points should have 3 coordinates (z, y, x) + assert viewer.layers[0].data.shape[1] == 3 + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + +class Test2_5DShapesVisualization: + """Test 2.5D shapes visualization in napari.""" + + def test_2_5d_shapes_projected_to_2d(self, make_napari_viewer: Any, sdata_2_5d_shapes: SpatialData): + """Test that 2.5D shapes are projected to 2D when config flag is True.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + + # Add 2.5D shapes + widget._onClick("shapes_2.5d") + + assert len(viewer.layers) == 1 + assert isinstance(viewer.layers[0], Shapes) + # 2D projection: shape coordinates should have 2 values per vertex (y, x) + for shape_data in viewer.layers[0].data: + assert shape_data.shape[1] == 2 + + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + def test_2_5d_shapes_full_3d(self, make_napari_viewer: Any, sdata_2_5d_shapes: SpatialData): + """Test that 2.5D shapes are visualized in 3D when config flag is False.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + + # Add 2.5D shapes + widget._onClick("shapes_2.5d") + + assert len(viewer.layers) == 1 + assert isinstance(viewer.layers[0], Shapes) + # Full 3D: shape coordinates should have 3 values per vertex (z, y, x) + for shape_data in viewer.layers[0].data: + assert shape_data.shape[1] == 3 + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + +class TestMixed2D3DVisualization: + """Test mixed 2D and 3D visualization scenarios.""" + + @pytest.mark.parametrize( + "points_dim,shapes_dim", + [ + (3, 2), # Points 3D, Shapes 2D + (2, 3), # Points 2D, Shapes 3D + ], + ) + def test_mixed_dimension_visualization( + self, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + sdata_2_5d_shapes: SpatialData, + points_dim: int, + shapes_dim: int, + ): + """Test that points and shapes can be visualized with different dimension settings.""" + original_points_config = config.PROJECT_3D_POINTS_TO_2D + original_shapes_config = config.PROJECT_2_5D_SHAPES_TO_2D + + try: + config.PROJECT_3D_POINTS_TO_2D = points_dim == 2 + config.PROJECT_2_5D_SHAPES_TO_2D = shapes_dim == 2 + + # Create a combined SpatialData + combined_sdata = SpatialData( + points={"points_3d": sdata_3d_points["points_3d"]}, + shapes={"shapes_2.5d": sdata_2_5d_shapes["shapes_2.5d"]}, + ) + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([combined_sdata])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + + widget._onClick("points_3d") + assert viewer.layers[0].data.shape[1] == points_dim + + widget._onClick("shapes_2.5d") + for shape_data in viewer.layers[1].data: + assert shape_data.shape[1] == shapes_dim + + finally: + config.PROJECT_3D_POINTS_TO_2D = original_points_config + config.PROJECT_2_5D_SHAPES_TO_2D = original_shapes_config