1 change: 1 addition & 0 deletions .gitignore
@@ -117,3 +117,4 @@ tests/plots/generated

# Local temp script for testing user bugs (Luca)
temp/
uv.lock
22 changes: 22 additions & 0 deletions docs/contributing.md
@@ -1,3 +1,25 @@
# Contributing guide

Please refer to the [contribution guide from the `spatialdata` repository](https://github.com/scverse/spatialdata/blob/main/docs/contributing.md).

## Debugging napari GUI tests

To visually inspect what a test is rendering in napari:

1. Change `make_napari_viewer()` to `make_napari_viewer(show=True)`
2. Add `napari.run()` before the end of the test (before the assertions)

Example:

```python
import napari


def test_my_visualization(make_napari_viewer):
viewer = make_napari_viewer(show=True)
# ... setup code ...
napari.run()
# assertions...
```

Remember to revert these changes before committing.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -54,7 +54,7 @@ install_requires =
scipy
shapely
scikit-learn
spatialdata>=0.7.0dev0
spatialdata>=0.7.0dev1
superqt
typing_extensions>=4.8.0
vispy
31 changes: 18 additions & 13 deletions src/napari_spatialdata/_viewer.py
@@ -18,12 +18,11 @@
from spatialdata import get_element_annotators, get_element_instances
from spatialdata._core.query.relational_query import _left_join_spatialelement_table
from spatialdata._types import ArrayLike
from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_channel_names
from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_axes_names, get_channel_names
from spatialdata.transformations import Affine, Identity

from napari_spatialdata._model import DataModel
from napari_spatialdata.constants import config
from napari_spatialdata.constants.config import CIRCLES_AS_POINTS
from napari_spatialdata.utils._utils import (
_adjust_channels_order,
_get_ellipses_from_circles,
@@ -470,7 +469,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi:
if multi:
original_name = original_name[: original_name.rfind("_")]

affine = _get_transform(sdata.images[original_name], selected_cs)
affine = _get_transform(sdata.images[original_name], selected_cs, include_z=True)
rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name])

channels = ("RGB(A)",) if rgb else get_channel_names(sdata.images[original_name])
@@ -517,6 +516,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
df = sdata.shapes[original_name]
affine = _get_transform(sdata.shapes[original_name], selected_cs)

# 2.5D circles not supported yet
xy = np.array([df.geometry.x, df.geometry.y]).T
yx = np.fliplr(xy)
radii = df.radius.to_numpy()
@@ -541,10 +541,10 @@
version = get_napari_version()
kwargs: dict[str, Any] = (
{"edge_width": 0.0}
if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS
if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS
else {"border_width": 0.0}
)
if CIRCLES_AS_POINTS:
if config.CIRCLES_AS_POINTS:
layer = Points(
yx,
name=key,
@@ -556,7 +556,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
assert affine is not None
self._adjust_radii_of_points_layer(layer=layer, affine=affine)
else:
if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS:
if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS:
kwargs |= {"edge_color": "white"}
else:
kwargs |= {"border_color": "white"}
@@ -597,7 +597,8 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi
original_name = original_name[: original_name.rfind("_")]

df = sdata.shapes[original_name]
affine = _get_transform(sdata.shapes[original_name], selected_cs)
include_z = not config.PROJECT_2_5D_SHAPES_TO_2D
affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z)

# when mulitpolygons are present, we select the largest ones
if "MultiPolygon" in np.unique(df.geometry.type):
@@ -609,7 +610,7 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi
df = df.sort_index() # reset the index to the first order

simplify = len(df) > config.POLYGON_THRESHOLD
polygons, indices = _get_polygons_properties(df, simplify)
polygons, indices = _get_polygons_properties(df, simplify, include_z=include_z)

# this will only work for polygons and not for multipolygons
polygons = _transform_coordinates(polygons, f=lambda x: x[::-1])
@@ -662,7 +663,7 @@ def get_sdata_labels(self, sdata: SpatialData, key: str, selected_cs: str, multi
original_name = original_name[: original_name.rfind("_")]

indices = get_element_instances(sdata.labels[original_name])
affine = _get_transform(sdata.labels[original_name], selected_cs)
affine = _get_transform(sdata.labels[original_name], selected_cs, include_z=True)
rgb_labels, _ = _adjust_channels_order(element=sdata.labels[original_name])

adata, table_name, table_names = self._get_table_data(sdata, original_name)
@@ -706,8 +707,10 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
if multi:
original_name = original_name[: original_name.rfind("_")]

axes = get_axes_names(sdata.points[original_name])
points = sdata.points[original_name].compute()
affine = _get_transform(sdata.points[original_name], selected_cs)
include_z = "z" in axes and not config.PROJECT_3D_POINTS_TO_2D
affine = _get_transform(sdata.points[original_name], selected_cs, include_z=include_z)
adata, table_name, table_names = self._get_table_data(sdata, original_name)

if len(points) < config.POINT_THRESHOLD:
@@ -727,14 +730,16 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
_, adata = _left_join_spatialelement_table(
{"points": {original_name: subsample_points}}, sdata[table_name], match_rows="left"
)
xy = subsample_points[["y", "x"]].values
np.fliplr(xy)
axes = sorted(axes, reverse=True)
if not include_z and "z" in axes:
axes.remove("z")
coords = subsample_points[axes].values
# radii_size = _calc_default_radii(self.viewer, sdata, selected_cs)
radii_size = 3
version = get_napari_version()
kwargs = {"edge_width": 0.0} if version <= packaging.version.parse("0.4.20") else {"border_width": 0.0}
layer = Points(
xy,
coords,
name=key,
size=radii_size * 2,
affine=affine,
2 changes: 2 additions & 0 deletions src/napari_spatialdata/constants/config.py
@@ -4,3 +4,5 @@
N_SHAPES_WARNING_THRESHOLD = 10000
POINT_SIZE_SCATTERPLOT_WIDGET = 6
CIRCLES_AS_POINTS = True
PROJECT_3D_POINTS_TO_2D = True
PROJECT_2_5D_SHAPES_TO_2D = True
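
Because `_viewer.py` now reads these flags as `config.<NAME>` attributes instead of importing the constants directly, they can in principle be toggled at runtime before layers are built. A minimal sketch (the `data.zarr` path is a placeholder, not part of this PR):

```python
from spatialdata import read_zarr

from napari_spatialdata import Interactive
from napari_spatialdata.constants import config

# Keep z coordinates instead of projecting to 2D (both flags default to True).
config.PROJECT_3D_POINTS_TO_2D = False
config.PROJECT_2_5D_SHAPES_TO_2D = False

sdata = read_zarr("data.zarr")  # placeholder path
Interactive(sdata)
```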
9 changes: 7 additions & 2 deletions src/napari_spatialdata/utils/_utils.py
@@ -181,15 +181,20 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]:
return [[f(xy) for xy in sublist] for sublist in data]


def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | ArrayLike:
def _get_transform(
element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None
) -> None | ArrayLike:
if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame):
raise RuntimeError("Cannot get transform for {type(element)}")

transformations = get_transformation(element, get_all=True)
cs = transformations.keys().__iter__().__next__() if coordinate_system_name is None else coordinate_system_name
ct = transformations.get(cs)
if ct:
return ct.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x"))
axes_element = get_axes_names(element)
include_z = include_z and "z" in axes_element
axes_transformation = ("z", "y", "x") if include_z else ("y", "x")
return ct.to_affine_matrix(input_axes=axes_transformation, output_axes=axes_transformation)
return None


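For illustration only (not part of the diff): a minimal sketch of what `include_z` changes in `_get_transform`, assuming a 3D image parsed with `Image3DModel` and its default `Identity` transform in the `"global"` coordinate system.

```python
import numpy as np
from spatialdata.models import Image3DModel

from napari_spatialdata.utils._utils import _get_transform

# Hypothetical (c, z, y, x) image; Image3DModel.parse attaches an Identity transform by default.
img = Image3DModel.parse(np.zeros((1, 4, 8, 8)), dims=("c", "z", "y", "x"))

affine_2d = _get_transform(img)                  # homogeneous (y, x) matrix -> shape (3, 3)
affine_3d = _get_transform(img, include_z=True)  # homogeneous (z, y, x) matrix -> shape (4, 4)
assert affine_2d.shape == (3, 3) and affine_3d.shape == (4, 4)
```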
43 changes: 31 additions & 12 deletions src/napari_spatialdata/utils/_viewer_utils.py
@@ -1,18 +1,37 @@
from geopandas import GeoDataFrame
from spatialdata.models import get_axes_names

# type aliases, only used in this module
Coord2D = tuple[float, float]
Coord3D = tuple[float, float, float]
Polygon2D = list[Coord2D]
Polygon3D = list[Coord3D]
Polygon = Polygon2D | Polygon3D

def _get_polygons_properties(df: GeoDataFrame, simplify: bool) -> tuple[list[list[tuple[float, float]]], list[int]]:
indices = []
polygons = []

if simplify:
for i in range(0, len(df)):
indices.append(df.iloc[i].name)
# This can be removed once napari is sped up in the plotting. It changes the shapes only very slightly
polygons.append(list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords))
else:
for i in range(0, len(df)):
indices.append(df.iloc[i].name)
polygons.append(list(df.geometry.iloc[i].exterior.coords))
def _get_polygons_properties(df: GeoDataFrame, simplify: bool, include_z: bool) -> tuple[list[Polygon], list[int]]:
# assumes no "Polygon Z": z is in separate column if present
indices: list[int] = []
polygons: list[Polygon] = []

axes = get_axes_names(df)
add_z = include_z and "z" in axes

for i in range(len(df)):
indices.append(int(df.index[i]))

if simplify:
xy = list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords)
else:
xy = list(df.geometry.iloc[i].exterior.coords)

coords: Polygon2D | Polygon3D
if add_z:
z_val = float(df.iloc[i].z.item() if hasattr(df.iloc[i].z, "item") else df.iloc[i].z)
coords = [(x, y, z_val) for x, y in xy]
else:
coords = xy

polygons.append(coords)

return polygons, indices
64 changes: 63 additions & 1 deletion tests/conftest.py
@@ -9,19 +9,23 @@
from pathlib import Path
from typing import Any

import geopandas as gpd
import napari
import numpy as np
import pandas as pd
import pytest
from anndata import AnnData
from dask.dataframe import from_pandas
from loguru import logger
from matplotlib.testing.compare import compare_images
from scipy import ndimage as ndi
from shapely import MultiPolygon, Polygon
from skimage import data
from spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata.datasets import blobs
from spatialdata.models import TableModel
from spatialdata.models import PointsModel, ShapesModel, TableModel
from spatialdata.transformations import Identity, set_transformation

from napari_spatialdata.utils._test_utils import export_figure, save_image

@@ -259,3 +263,61 @@ def caplog(caplog):
def always_sync(monkeypatch, request):
if request.node.get_closest_marker("use_thread_loader") is None:
monkeypatch.setattr("napari_spatialdata._sdata_widgets.PROBLEMATIC_NUMPY_MACOS", True)


@pytest.fixture
def sdata_3d_points() -> SpatialData:
"""Create a SpatialData object with 3D points (x, y, z coordinates)."""
n_points = 10
rng = np.random.default_rng(SEED)
df = pd.DataFrame(
{
"x": rng.uniform(0, 100, n_points),
"y": rng.uniform(0, 100, n_points),
"z": rng.uniform(0, 50, n_points),
}
)
dask_df = from_pandas(df, npartitions=1)
points = PointsModel.parse(dask_df)
set_transformation(points, {"global": Identity()}, set_all=True)

return SpatialData(points={"points_3d": points})


@pytest.fixture
def sdata_2_5d_shapes() -> SpatialData:
"""Create a SpatialData object with 2.5D shapes (3 layers at different z, polygons + multipolygons)."""
shapes = {}

geometries = []
z_values = []
indices = []
for i, z_val in enumerate([0.0, 10.0, 20.0]):
# Add simple polygons (triangles and quadrilaterals)
poly1 = Polygon([(10 + i * 5, 10), (20 + i * 5, 10), (15 + i * 5, 20)])
poly2 = Polygon([(30 + i * 5, 30), (40 + i * 5, 30), (40 + i * 5, 40), (30 + i * 5, 40)])
geometries.extend([poly1, poly2])
indices.extend([0, 1])
z_values.extend([z_val] * 2)

# Add a multipolygon (two separate polygon parts)
multi_poly = MultiPolygon(
[
Polygon([(50 + i * 5, 10), (60 + i * 5, 10), (55 + i * 5, 20)]),
Polygon([(50 + i * 5, 30), (60 + i * 5, 30), (60 + i * 5, 40), (50 + i * 5, 40)]),
]
)
geometries.append(multi_poly)
indices.append(2)
z_values.append(z_val)

gdf = gpd.GeoDataFrame(
{"z": z_values, "geometry": geometries},
index=indices,
)

shape_element = ShapesModel.parse(gdf)
set_transformation(shape_element, {"global": Identity()}, set_all=True)
shapes["shapes_2.5d"] = shape_element

return SpatialData(shapes=shapes)