diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index ea3cc56e..6743d12a 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -26,10 +26,10 @@ def _nd_sphere( """ if ndim == 2: - return morph.disk(radius) + return morph.disk(radius).astype(bool) if ndim == 3: - return morph.ball(radius) + return morph.ball(radius).astype(bool) raise ValueError(f"Spherical is only implemented for 2D and 3D, got ndim={ndim}") diff --git a/src/tracksdata/nodes/_test/test_mask.py b/src/tracksdata/nodes/_test/test_mask.py index 6723dc44..d804f21d 100644 --- a/src/tracksdata/nodes/_test/test_mask.py +++ b/src/tracksdata/nodes/_test/test_mask.py @@ -376,6 +376,7 @@ def test_mask_from_coordinates_2d_basic() -> None: # Should be a disk of radius 2, shape (5,5), centered at (5,5) assert mask.mask.shape == (5, 5) assert mask.mask[2, 2] # center pixel is True + assert mask.mask.dtype == bool np.testing.assert_array_equal(mask.bbox, [3, 3, 8, 8])