diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index d15042181b..a4bb187300 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -293,6 +293,7 @@ AsDiscrete, DistanceTransformEDT, FillHoles, + GenerateHeatmap, Invert, KeepLargestConnectedComponent, LabelFilter, @@ -319,6 +320,9 @@ FillHolesD, FillHolesd, FillHolesDict, + GenerateHeatmapd, + GenerateHeatmapD, + GenerateHeatmapDict, InvertD, Invertd, InvertDict, diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 2e733c4f6c..1e795a75de 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -38,7 +38,14 @@ remove_small_objects, ) from monai.transforms.utils_pytorch_numpy_unification import unravel_index -from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option +from monai.utils import ( + TransformBackends, + convert_data_type, + convert_to_tensor, + ensure_tuple, + get_equivalent_dtype, + look_up_option, +) from monai.utils.type_conversion import convert_to_dst_type __all__ = [ @@ -54,6 +61,7 @@ "SobelGradients", "VoteEnsemble", "Invert", + "GenerateHeatmap", "DistanceTransformEDT", ] @@ -742,6 +750,194 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO return self.post_convert(out_pt, img) +class GenerateHeatmap(Transform): + """ + Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates. + + Notes: + - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D. + - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D. + - Output layout uses channel-first convention with one channel per landmark: + - Non-batched points (N, D): (N, Y, X) for 2D or (N, Z, Y, X) for 3D + - Batched points (B, N, D): (B, N, Y, X) for 2D or (B, N, Z, Y, X) for 3D + - Each channel corresponds to one landmark. + + Args: + sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions. + spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform. + A single int value will be broadcast to all spatial dimensions. + truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window. + normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``. + dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes). + + Raises: + ValueError: when ``sigma`` is non-positive or ``spatial_shape`` cannot be resolved. 
+ + """ + + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + + def __init__( + self, + sigma: Sequence[float] | float = 5.0, + spatial_shape: Sequence[int] | None = None, + truncated: float = 4.0, + normalize: bool = True, + dtype: np.dtype | torch.dtype | type = np.float32, + ) -> None: + if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)): + if any(s <= 0 for s in sigma): + raise ValueError("sigma values must be positive.") + self._sigma = tuple(float(s) for s in sigma) + else: + if float(sigma) <= 0: + raise ValueError("sigma must be positive.") + self._sigma = (float(sigma),) + if truncated <= 0: + raise ValueError("truncated must be positive.") + self.truncated = float(truncated) + self.normalize = normalize + self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor) + self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray) + # Validate that dtype is floating-point for meaningful Gaussian values + if not self.torch_dtype.is_floating_point: + raise ValueError(f"dtype must be a floating-point type, got {self.torch_dtype}") + self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape) + + def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor: + """ + Args: + points: landmark coordinates as ndarray/Tensor with shape (N, D) or (B, N, D), + ordered as (Y, X) for 2D or (Z, Y, X) for 3D. + spatial_shape: spatial size as a sequence or single int (broadcasted). If None, uses + the value provided at construction. + + Returns: + Heatmaps with shape (N, *spatial) or (B, N, *spatial), one channel per landmark. + + Raises: + ValueError: if points shape/dimension or spatial_shape is invalid. + """ + original_points = points + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) + + is_batched = points_t.ndim == 3 + if not is_batched: + if points_t.ndim != 2: + raise ValueError( + "points must be a 2D or 3D array with shape (num_points, spatial_dims) or (B, num_points, spatial_dims)." 
+ ) + points_t = points_t.unsqueeze(0) # Add a batch dimension + + if points_t.shape[-1] not in (2, 3): + raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.") + + device = points_t.device + batch_size, num_points, spatial_dims = points_t.shape + + target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims) + sigma = self._resolve_sigma(spatial_dims) + radius = tuple(int(np.ceil(self.truncated * s)) for s in sigma) + + heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device) + image_bounds = tuple(int(s) for s in target_shape) + bounds_t = torch.as_tensor(image_bounds, device=device, dtype=points_t.dtype) + for b_idx in range(batch_size): + for idx, center in enumerate(points_t[b_idx]): + if not torch.isfinite(center).all(): + continue + if not ((center >= 0).all() and (center < bounds_t).all()): + continue + # _make_window expects Python floats; convert only when needed + center_vals = center.tolist() + window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device) + if window_slices is None: + continue + region = heatmap[b_idx, idx][window_slices] + gaussian = self._evaluate_gaussian(coord_shifts, sigma) + updated = torch.maximum(region, gaussian) + # write back + region.copy_(updated) + if self.normalize: + peak = heatmap[b_idx, idx].amax() + denom = torch.where(peak > 0, peak, torch.ones_like(peak)) + heatmap[b_idx, idx].div_(denom) + + if not is_batched: + heatmap = heatmap.squeeze(0) + + target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype + converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype) + return converted + + def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]: + shape = call_shape if call_shape is not None else self.spatial_shape + if shape is None: + raise ValueError("spatial_shape must be provided either at construction time or call time.") + shape_tuple = ensure_tuple(shape) + if len(shape_tuple) != spatial_dims: + if len(shape_tuple) == 1: + shape_tuple = shape_tuple * spatial_dims # type: ignore + else: + raise ValueError( + "spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)." + ) + return tuple(int(s) for s in shape_tuple) + + def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]: + if len(self._sigma) == spatial_dims: + return self._sigma + if len(self._sigma) == 1: + return self._sigma * spatial_dims + raise ValueError("sigma sequence length must equal the number of spatial dimensions.") + + @staticmethod + def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool: + for c, size in zip(center, bounds): + if not (0 <= c < size): + return False + return True + + def _make_window( + self, center: Sequence[float], radius: tuple[int, ...], bounds: tuple[int, ...], device: torch.device + ) -> tuple[tuple[slice, ...] 
| None, tuple[torch.Tensor, ...]]: + slices: list[slice] = [] + coord_shifts: list[torch.Tensor] = [] + for _dim, (c, r, size) in enumerate(zip(center, radius, bounds)): + start = max(int(np.floor(c - r)), 0) + stop = min(int(np.ceil(c + r)) + 1, size) + if start >= stop: + return None, () + slices.append(slice(start, stop)) + coord_shifts.append(torch.arange(start, stop, device=device, dtype=torch.float32) - float(c)) + return tuple(slices), tuple(coord_shifts) + + def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor: + """ + Evaluate Gaussian at given coordinate shifts with specified sigmas. + + Args: + coord_shifts: Per-dimension coordinate offsets from center. + sigma: Per-dimension standard deviations. + + Returns: + Gaussian values at the specified coordinates. + """ + device = coord_shifts[0].device + shape = tuple(len(axis) for axis in coord_shifts) + if 0 in shape: + return torch.zeros(shape, dtype=self.torch_dtype, device=device) + exponent = torch.zeros(shape, dtype=torch.float32, device=device) + for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)): + shift32 = shift.to(torch.float32) + scaled = (shift32 / float(sig)) ** 2 + reshape_shape = [1] * len(coord_shifts) + reshape_shape[dim] = shift.numel() + exponent += scaled.reshape(reshape_shape) + gauss = torch.exp(-0.5 * exponent) + return gauss.to(dtype=self.torch_dtype) + + class ProbNMS(Transform): """ Performs probability based non-maximum suppression (NMS) on the probabilities map via diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 7e1e074f71..8002179256 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -35,6 +35,7 @@ AsDiscrete, DistanceTransformEDT, FillHoles, + GenerateHeatmap, KeepLargestConnectedComponent, LabelFilter, LabelToContour, @@ -48,6 +49,7 @@ from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep +from monai.utils.type_conversion import convert_to_dst_type __all__ = [ "ActivationsD", @@ -95,6 +97,9 @@ "DistanceTransformEDTd", "DistanceTransformEDTD", "DistanceTransformEDTDict", + "GenerateHeatmapd", + "GenerateHeatmapD", + "GenerateHeatmapDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -508,6 +513,175 @@ def __init__(self, keys: KeysCollection, output_key: str | None = None, num_clas super().__init__(keys, ensemble, output_key) +class GenerateHeatmapd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`. + Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image. + + Args: + keys: keys of the corresponding items in the dictionary. + sigma: standard deviation for the Gaussian kernel. Can be a single value or a sequence matching the number + of spatial dimensions. + heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key. + ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will + have the same shape, affine, and spatial metadata as the reference images. + spatial_shape: spatial dimensions of output heatmaps. Can be: + - Single shape (tuple): applied to all keys + - List of shapes: one per key (must match keys length) + truncated: truncation distance for Gaussian kernel computation (in sigmas). 
+ normalize: if True, normalize each heatmap's peak value to 1.0. + dtype: output data type for heatmaps. Defaults to np.float32. + allow_missing_keys: if True, don't raise error if some keys are missing in data. + + Returns: + Dictionary with original data plus generated heatmaps at specified keys. + + Raises: + ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length. + ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys). + ValueError: If input points have invalid shape (must be 2D or 3D). + + Notes: + - Default heatmap_keys are generated as "{key}_heatmap" for each input key + - Shape inference precedence: static spatial_shape > ref_image + - Output shapes: + - Non-batched points (N, D): (N, H, W[, D]) + - Batched points (B, N, D): (B, N, H, W[, D]) + - When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference + """ + + backend = GenerateHeatmap.backend + + # Error messages + _ERR_HEATMAP_KEYS_LEN = "heatmap_keys length must match keys length." + _ERR_REF_KEYS_LEN = "ref_image_keys length must match keys length when provided." + _ERR_SHAPE_LEN = "spatial_shape length must match keys length when providing per-key shapes." + _ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys." + _ERR_INVALID_POINTS = "landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)." + _ERR_REF_NO_SHAPE = "Reference data must define a shape attribute." + + def __init__( + self, + keys: KeysCollection, + sigma: Sequence[float] | float = 5.0, + heatmap_keys: KeysCollection | None = None, + ref_image_keys: KeysCollection | None = None, + spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None, + truncated: float = 4.0, + normalize: bool = True, + dtype: np.dtype | torch.dtype | type = np.float32, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys) + self.ref_image_keys = self._prepare_optional_keys(ref_image_keys) + self.static_shapes = self._prepare_shapes(spatial_shape) + self.generator = GenerateHeatmap( + sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype + ) + + def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: + d = dict(data) + for key, out_key, ref_key, static_shape in self.key_iterator( + d, self.heatmap_keys, self.ref_image_keys, self.static_shapes + ): + points = d[key] + shape = self._determine_shape(points, static_shape, d, ref_key) + # The GenerateHeatmap transform will handle type conversion based on input points + heatmap = self.generator(points, spatial_shape=shape) + # If there's a reference image and we need to match its type/device + reference = d.get(ref_key) if ref_key is not None and ref_key in d else None + if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)): + # Convert to match reference type and device while preserving heatmap's dtype + heatmap, _, _ = convert_to_dst_type( + heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, "device", None) + ) + # Copy metadata if reference is MetaTensor + if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor): + heatmap.affine = reference.affine + self._update_spatial_metadata(heatmap, shape) + d[out_key] = heatmap + return d + + def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]: + if heatmap_keys is None: + return 
tuple(f"{key}_heatmap" for key in self.keys) + keys_tuple = ensure_tuple(heatmap_keys) + if len(keys_tuple) == 1 and len(self.keys) > 1: + keys_tuple = keys_tuple * len(self.keys) + if len(keys_tuple) != len(self.keys): + raise ValueError(self._ERR_HEATMAP_KEYS_LEN) + return keys_tuple + + def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]: + if maybe_keys is None: + return (None,) * len(self.keys) + keys_tuple = ensure_tuple(maybe_keys) + if len(keys_tuple) == 1 and len(self.keys) > 1: + keys_tuple = keys_tuple * len(self.keys) + if len(keys_tuple) != len(self.keys): + raise ValueError(self._ERR_REF_KEYS_LEN) + return tuple(keys_tuple) + + def _prepare_shapes( + self, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None + ) -> tuple[tuple[int, ...] | None, ...]: + if spatial_shape is None: + return (None,) * len(self.keys) + shape_tuple = ensure_tuple(spatial_shape) + if shape_tuple and all(isinstance(v, (int, np.integer)) for v in shape_tuple): + shape = tuple(int(v) for v in shape_tuple) + return (shape,) * len(self.keys) + if len(shape_tuple) == 1 and len(self.keys) > 1: + shape_tuple = shape_tuple * len(self.keys) + if len(shape_tuple) != len(self.keys): + raise ValueError(self._ERR_SHAPE_LEN) + prepared: list[tuple[int, ...] | None] = [] + for item in shape_tuple: + if item is None: + prepared.append(None) + else: + dims = ensure_tuple(item) + prepared.append(tuple(int(v) for v in dims)) + return tuple(prepared) + + def _determine_shape( + self, points: Any, static_shape: tuple[int, ...] | None, data: Mapping[Hashable, Any], ref_key: Hashable | None + ) -> tuple[int, ...]: + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) + if points_t.ndim not in (2, 3): + raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.") + spatial_dims = int(points_t.shape[-1]) + if static_shape is not None: + if len(static_shape) != spatial_dims: + raise ValueError( + f"Provided static spatial_shape has {len(static_shape)} dims; expected {spatial_dims}." 
+ ) + return static_shape + if ref_key is not None and ref_key in data: + return self._shape_from_reference(data[ref_key], spatial_dims) + raise ValueError(self._ERR_NO_SHAPE) + + def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, ...]: + if isinstance(reference, MetaTensor): + meta_shape = reference.meta.get("spatial_shape") + if meta_shape is not None: + dims = ensure_tuple(meta_shape) + if len(dims) == spatial_dims: + return tuple(int(v) for v in dims) + return tuple(int(v) for v in reference.shape[-spatial_dims:]) + if hasattr(reference, "shape"): + return tuple(int(v) for v in reference.shape[-spatial_dims:]) + raise ValueError(self._ERR_REF_NO_SHAPE) + + def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int, ...]) -> None: + """Set spatial_shape explicitly from resolved shape.""" + heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape) + + +GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd + + class ProbNMSd(MapTransform): """ Performs probability based non-maximum suppression (NMS) on the probabilities map via diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py new file mode 100644 index 0000000000..73fedb796c --- /dev/null +++ b/tests/transforms/test_generate_heatmap.py @@ -0,0 +1,312 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.post.array import GenerateHeatmap +from tests.test_utils import TEST_NDARRAYS + + +def _argmax_nd(x) -> np.ndarray: + """argmax for N-D array → returns coordinate vector (z,y,x) or (y,x).""" + if isinstance(x, torch.Tensor): + x = x.cpu().numpy() + return np.asarray(np.unravel_index(np.argmax(x), x.shape)) + + +# Test cases for 2D array inputs with different data types +TEST_CASES_2D = [] +for idx, p in enumerate(TEST_NDARRAYS): + TEST_CASES_2D.append( + [ + f"2d_basic_type{idx}", + p(np.array([[4.2, 7.8], [12.3, 3.6]], dtype=np.float32)), + {"sigma": 1.5, "spatial_shape": (16, 16)}, + (2, 16, 16), + ] + ) + +# Test cases for 3D torch outputs with explicit dtype +TEST_CASES_3D_TORCH = [] +for dtype in [torch.float32, torch.float64]: + TEST_CASES_3D_TORCH.append( + [ + f"3d_torch_{str(dtype).replace('torch.', '')}", + torch.tensor([[1.5, 2.5, 3.5]], dtype=torch.float32), + {"sigma": 1.0, "spatial_shape": (8, 8, 8), "dtype": dtype}, + (1, 8, 8, 8), + dtype, + ] + ) + +# Test cases for 3D numpy outputs with explicit dtype +TEST_CASES_3D_NUMPY = [] +for dtype_obj in [np.float32, np.float64]: + TEST_CASES_3D_NUMPY.append( + [ + f"3d_numpy_{dtype_obj.__name__}", + np.array([[1.5, 2.5, 3.5]], dtype=np.float32), + {"sigma": 1.0, "spatial_shape": (8, 8, 8), "dtype": dtype_obj}, + (1, 8, 8, 8), + dtype_obj, + ] + ) + +# Test cases for different sigma values +TEST_CASES_SIGMA = [] +for sigma in [0.5, 1.0, 2.0, 3.0]: + TEST_CASES_SIGMA.append( + [ + f"sigma_{sigma}", + np.array([[8.0, 8.0]], dtype=np.float32), + {"sigma": sigma, "spatial_shape": (16, 16)}, + (1, 16, 16), + ] + ) + +# Test cases for truncated parameter +TEST_CASES_TRUNCATED = [] +for truncated in [2.0, 4.0, 6.0]: + TEST_CASES_TRUNCATED.append( + [ + f"truncated_{truncated}", + np.array([[8.0, 8.0]], dtype=np.float32), + {"sigma": 2.0, "spatial_shape": (32, 32), "truncated": truncated}, + (1, 32, 32), + ] + ) + +# Test cases for batched 3D with different array types +TEST_CASES_BATCHED = [] +for idx, p in enumerate(TEST_NDARRAYS): + TEST_CASES_BATCHED.append( + [ + f"batched_3d_type{idx}", + p(np.array([[[4.2, 7.8, 1.0]], [[12.3, 3.6, 2.0]]], dtype=np.float32)), + {"sigma": 1.5, "spatial_shape": (16, 16, 16)}, + (2, 1, 16, 16, 16), + ] + ) + +# Test cases for device and dtype propagation (torch only) +TEST_CASES_DEVICE_DTYPE = [] +if torch.cuda.is_available(): + for dtype in [torch.float16, torch.float32, torch.float64]: + TEST_CASES_DEVICE_DTYPE.append( + [ + f"cuda_{str(dtype).replace('torch.', '')}", + torch.tensor([[3.0, 4.0, 5.0]], dtype=torch.float32, device="cuda:0"), + {"sigma": 1.2, "spatial_shape": (10, 10, 10), "dtype": dtype}, + (1, 10, 10, 10), + dtype, + "cuda:0", + ] + ) +else: + for dtype in [torch.float32, torch.float64]: + TEST_CASES_DEVICE_DTYPE.append( + [ + f"cpu_{str(dtype).replace('torch.', '')}", + torch.tensor([[3.0, 4.0, 5.0]], dtype=torch.float32), + {"sigma": 1.2, "spatial_shape": (10, 10, 10), "dtype": dtype}, + (1, 10, 10, 10), + dtype, + "cpu", + ] + ) + + +class TestGenerateHeatmap(unittest.TestCase): + @parameterized.expand(TEST_CASES_2D) + def test_array_2d(self, _, points, params, expected_shape): + transform = GenerateHeatmap(**params) + heatmap = transform(points) + + # Check output type matches input type + if isinstance(points, torch.Tensor): + self.assertIsInstance(heatmap, torch.Tensor) + self.assertEqual(heatmap.dtype, torch.float32) # 
Default dtype for torch + heatmap_np = heatmap.cpu().numpy() + points_np = points.cpu().numpy() + else: + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.dtype, np.float32) # Default dtype for numpy + heatmap_np = heatmap + points_np = points + + self.assertEqual(heatmap.shape, expected_shape) + np.testing.assert_allclose(heatmap_np.max(axis=(1, 2)), np.ones(expected_shape[0]), rtol=1e-5, atol=1e-5) + + # peak should be close to original point location (<= 1px tolerance due to discretization) + for idx in range(expected_shape[0]): + peak = _argmax_nd(heatmap_np[idx]) + self.assertTrue(np.all(np.abs(peak - points_np[idx]) <= 1.0), msg=f"peak={peak}, point={points_np[idx]}") + self.assertLess(heatmap_np[idx, 0, 0], 1e-3) + + @parameterized.expand(TEST_CASES_3D_TORCH) + def test_array_3d_torch_output(self, _, points, params, expected_shape, expected_dtype): + transform = GenerateHeatmap(**params) + heatmap = transform(points) + + self.assertIsInstance(heatmap, torch.Tensor) + self.assertEqual(heatmap.device, points.device) + self.assertEqual(tuple(heatmap.shape), expected_shape) + self.assertEqual(heatmap.dtype, expected_dtype) + self.assertTrue(torch.isclose(heatmap.max(), torch.tensor(1.0, dtype=heatmap.dtype, device=heatmap.device))) + + @parameterized.expand(TEST_CASES_3D_NUMPY) + def test_array_3d_numpy_output(self, _, points, params, expected_shape, expected_dtype): + transform = GenerateHeatmap(**params) + heatmap = transform(points) + + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.shape, expected_shape) + self.assertEqual(heatmap.dtype, expected_dtype) + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + @parameterized.expand(TEST_CASES_DEVICE_DTYPE) + def test_array_torch_device_and_dtype_propagation( + self, _, pts, params, expected_shape, expected_dtype, expected_device + ): + tr = GenerateHeatmap(**params) + hm = tr(pts) + + self.assertIsInstance(hm, torch.Tensor) + self.assertEqual(str(hm.device).split(":")[0], expected_device.split(":")[0]) + self.assertEqual(hm.dtype, expected_dtype) + self.assertEqual(tuple(hm.shape), expected_shape) + self.assertTrue(torch.all(hm >= 0)) + + def test_array_channel_order_identity(self): + # ensure the order of channels follows the order of input points + pts = np.array([[2.0, 2.0], [12.0, 2.0], [2.0, 12.0]], dtype=np.float32) # point A # point B # point C + hm = GenerateHeatmap(sigma=1.2, spatial_shape=(16, 16))(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.shape, (3, 16, 16)) + + peaks = np.vstack([_argmax_nd(hm[i]) for i in range(3)]) + # y,x close to points + np.testing.assert_allclose(peaks, pts, atol=1.0) + + def test_array_points_out_of_bounds(self): + # points outside spatial domain: heatmap should still be valid (no NaN/Inf) and not all-zeros + pts = np.array( + [[-5.0, -5.0], [100.0, 100.0], [8.0, 8.0]], # outside top-left # outside bottom-right # inside + dtype=np.float32, + ) + hm = GenerateHeatmap(sigma=2.0, spatial_shape=(16, 16))(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.shape, (3, 16, 16)) + self.assertFalse(np.isnan(hm).any() or np.isinf(hm).any()) + + # inside point channel should have max≈1; others may clip at border (≤1) + self.assertGreater(hm[2].max(), 0.9) + + @parameterized.expand(TEST_CASES_SIGMA) + def test_array_sigma_scaling_effect(self, _, pt, params, expected_shape): + heatmap = GenerateHeatmap(**params)(pt)[0] + self.assertEqual(heatmap.shape, expected_shape[1:]) + + # All should have peak normalized to 1.0 + 
np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + # Verify heatmap is valid + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + def test_invalid_points_shape_raises(self): + # points must be (N, D) with D in {2,3} + tr = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8)) + with self.assertRaises((ValueError, AssertionError, IndexError, RuntimeError)): + tr(np.zeros((2,), dtype=np.float32)) # wrong rank + + with self.assertRaises((ValueError, AssertionError, IndexError, RuntimeError)): + tr(np.zeros((2, 4), dtype=np.float32)) # D=4 unsupported + + @parameterized.expand(TEST_CASES_BATCHED) + def test_array_batched_3d(self, _, points, params, expected_shape): + transform = GenerateHeatmap(**params) + heatmap = transform(points) + + # Check output type matches input type + if isinstance(points, torch.Tensor): + self.assertIsInstance(heatmap, torch.Tensor) + heatmap_np = heatmap.cpu().numpy() + points_np = points.cpu().numpy() + else: + self.assertIsInstance(heatmap, np.ndarray) + heatmap_np = heatmap + points_np = points + + self.assertEqual(heatmap.shape, expected_shape) + np.testing.assert_allclose( + heatmap_np.max(axis=(2, 3, 4)), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5 + ) + + # Check peaks for each batch item + for i in range(expected_shape[0]): + peak = _argmax_nd(heatmap_np[i, 0]) + self.assertTrue(np.all(np.abs(peak - points_np[i, 0]) <= 1.0), msg=f"peak={peak}, point={points_np[i, 0]}") + + @parameterized.expand(TEST_CASES_TRUNCATED) + def test_truncated_parameter(self, _, pt, params, expected_shape): + heatmap = GenerateHeatmap(**params)(pt)[0] + + # All should have same peak value (normalized to 1.0) + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + # Verify shape and no NaN/Inf + self.assertEqual(heatmap.shape, expected_shape[1:]) + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + def test_torch_to_torch_type_preservation(self): + """Test that torch input produces torch output""" + pts = torch.tensor([[4.0, 4.0]], dtype=torch.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8))(pts) + + self.assertIsInstance(hm, torch.Tensor) + self.assertEqual(hm.dtype, torch.float32) + self.assertEqual(hm.device, pts.device) + + def test_numpy_to_numpy_type_preservation(self): + """Test that numpy input produces numpy output""" + pts = np.array([[4.0, 4.0]], dtype=np.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8))(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.dtype, np.float32) + + def test_dtype_override_torch(self): + """Test dtype parameter works with torch tensors""" + pts = torch.tensor([[4.0, 4.0, 4.0]], dtype=torch.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=torch.float64)(pts) + + self.assertIsInstance(hm, torch.Tensor) + self.assertEqual(hm.dtype, torch.float64) + + def test_dtype_override_numpy(self): + """Test dtype parameter works with numpy arrays""" + pts = np.array([[4.0, 4.0, 4.0]], dtype=np.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=np.float64)(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.dtype, np.float64) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transforms/test_generate_heatmapd.py b/tests/transforms/test_generate_heatmapd.py new file mode 100644 index 0000000000..85cdbdc75c --- /dev/null +++ b/tests/transforms/test_generate_heatmapd.py @@ -0,0 +1,267 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache 
License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.post.dictionary import GenerateHeatmapd +from tests.test_utils import assert_allclose + +# Test cases for dictionary transforms with reference image +# Only test with non-MetaTensor types to avoid affine conflicts +TEST_CASES_WITH_REF = [] +TEST_CASES_WITH_REF.append( + [ + "dict_with_ref_3d_numpy", + np.array([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=np.float32), + {"sigma": 2.0}, + (2, 8, 8, 8), + torch.float32, + True, # uses reference image + ] +) +TEST_CASES_WITH_REF.append( + [ + "dict_with_ref_3d_torch", + torch.tensor([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=torch.float32), + {"sigma": 2.0}, + (2, 8, 8, 8), + torch.float32, + True, # uses reference image + ] +) + +# Test cases for dictionary transforms with static spatial shape +TEST_CASES_STATIC_SHAPE = [] +for shape in [(6, 6), (8, 8, 8), (10, 10, 10)]: + TEST_CASES_STATIC_SHAPE.append( + [ + f"dict_static_shape_{len(shape)}d", + np.array([[1.0] * len(shape)], dtype=np.float32), + {"spatial_shape": shape}, + (1, *shape), + np.float32, + ] + ) + +# Test cases for dtype control +TEST_CASES_DTYPE = [] +for dtype in [torch.float16, torch.float32, torch.float64]: + TEST_CASES_DTYPE.append( + [ + f"dict_dtype_{str(dtype).replace('torch.', '')}", + np.array([[2.0, 3.0, 4.0]], dtype=np.float32), + {"sigma": 1.4, "dtype": dtype}, + (1, 10, 10, 10), + dtype, + ] + ) + +# Test cases for batched dictionary transforms +TEST_CASES_BATCHED = [] +TEST_CASES_BATCHED.append( + [ + "dict_batched_3d", + torch.tensor([[[1.5, 2.5, 3.5]], [[4.5, 5.5, 6.5]]], dtype=torch.float32), + {"sigma": 1.0}, + (2, 1, 8, 8, 8), + torch.float32, + ] +) + +# Test cases for various sigma values +TEST_CASES_SIGMA_VALUES = [] +for sigma in [0.5, 1.0, 2.0, 3.0]: + TEST_CASES_SIGMA_VALUES.append( + [ + f"dict_sigma_{sigma}", + np.array([[4.0, 4.0, 4.0]], dtype=np.float32), + {"sigma": sigma, "spatial_shape": (8, 8, 8)}, + (1, 8, 8, 8), + ] + ) + + +class TestGenerateHeatmapd(unittest.TestCase): + @parameterized.expand(TEST_CASES_WITH_REF) + def test_dict_with_reference_meta(self, _, points, params, expected_shape, *_unused): + affine = torch.eye(4) + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + data = {"points": points, "image": image} + + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", **params) + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), expected_shape) + self.assertEqual(heatmap.meta["spatial_shape"], (8, 8, 8)) + # The heatmap should inherit the reference image's affine + assert_allclose(heatmap.affine, image.affine, type_test=False) + + # Check max values are normalized to 1.0 + max_vals = heatmap.cpu().numpy().max(axis=tuple(range(1, 
len(expected_shape)))) + np.testing.assert_allclose(max_vals, np.ones(expected_shape[0]), rtol=1e-5, atol=1e-5) + + @parameterized.expand(TEST_CASES_STATIC_SHAPE) + def test_dict_static_shape(self, _, points, params, expected_shape, expected_dtype): + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", **params) + result = transform({"points": points}) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.shape, expected_shape) + self.assertEqual(heatmap.dtype, expected_dtype) + + # Verify no NaN or Inf values + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + # Verify max value is 1.0 for normalized heatmaps + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + def test_dict_missing_shape_raises(self): + # Without ref image or explicit spatial_shape, must raise + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap") + with self.assertRaisesRegex(ValueError, "spatial_shape|ref_image_keys"): + transform({"points": np.zeros((1, 2), dtype=np.float32)}) + + @parameterized.expand(TEST_CASES_DTYPE) + def test_dict_dtype_control(self, _, points, params, expected_shape, expected_dtype): + ref = MetaTensor(torch.zeros((1, 10, 10, 10), dtype=torch.float32), affine=torch.eye(4)) + d = {"pts": points, "img": ref} + + tr = GenerateHeatmapd(keys="pts", heatmap_keys="hm", ref_image_keys="img", **params) + out = tr(d) + hm = out["hm"] + + self.assertIsInstance(hm, MetaTensor) + self.assertEqual(tuple(hm.shape), expected_shape) + self.assertEqual(hm.dtype, expected_dtype) + + @parameterized.expand(TEST_CASES_BATCHED) + def test_dict_batched_with_ref(self, _, points, params, expected_shape, _expected_dtype): + affine = torch.eye(4) + # A single reference image is used for the whole batch + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + data = {"points": points, "image": image} + + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", **params) + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), expected_shape) + self.assertEqual(heatmap.meta["spatial_shape"], (8, 8, 8)) + assert_allclose(heatmap.affine, image.affine, type_test=False) + + # Check max values + hm2 = heatmap.reshape(heatmap.shape[0], heatmap.shape[1], -1) + max_vals = hm2.max(dim=2)[0] + np.testing.assert_allclose( + max_vals.cpu().numpy(), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5 + ) + + @parameterized.expand(TEST_CASES_SIGMA_VALUES) + def test_dict_various_sigma(self, _, points, params, expected_shape): + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", **params) + result = transform({"points": points}) + heatmap = result["heatmap"] + + self.assertEqual(heatmap.shape, expected_shape) + # Verify heatmap is normalized + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + # Verify no NaN or Inf + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + def test_dict_multiple_keys(self): + """Test dictionary transform with multiple input/output keys""" + points1 = np.array([[2.0, 2.0]], dtype=np.float32) + points2 = np.array([[4.0, 4.0]], dtype=np.float32) + + data = {"pts1": points1, "pts2": points2} + transform = GenerateHeatmapd( + keys=["pts1", "pts2"], heatmap_keys=["hm1", "hm2"], spatial_shape=(8, 8), sigma=1.0 + ) + + result = transform(data) + + 
self.assertIn("hm1", result) + self.assertIn("hm2", result) + self.assertEqual(result["hm1"].shape, (1, 8, 8)) + self.assertEqual(result["hm2"].shape, (1, 8, 8)) + + # Verify peaks are at different locations + self.assertNotEqual(np.argmax(result["hm1"]), np.argmax(result["hm2"])) + + def test_dict_mismatched_heatmap_keys_length(self): + """Test ValueError when heatmap_keys length doesn't match keys""" + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + heatmap_keys=["hm1", "hm2", "hm3"], # Mismatch: 3 heatmap keys for 2 input keys + spatial_shape=(8, 8), + ) + + def test_dict_mismatched_ref_image_keys_length(self): + """Test ValueError when ref_image_keys length doesn't match keys""" + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + heatmap_keys=["hm1", "hm2"], + ref_image_keys=["img1", "img2", "img3"], # Mismatch: 3 ref keys for 2 input keys + spatial_shape=(8, 8), + ) + + def test_dict_per_key_spatial_shape_mismatch(self): + """Test ValueError when per-key spatial_shape length doesn't match keys""" + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + heatmap_keys=["hm1", "hm2"], + spatial_shape=[(8, 8), (8, 8), (8, 8)], # Mismatch: 3 shapes for 2 keys + sigma=1.0, + ) + + def test_metatensor_points_with_ref(self): + """Test MetaTensor points with reference image - documents current behavior""" + from monai.data import MetaTensor + + # Create MetaTensor points with non-identity affine + points_affine = torch.tensor([[2.0, 0, 0, 0], [0, 2.0, 0, 0], [0, 0, 2.0, 0], [0, 0, 0, 1.0]]) + points = MetaTensor(torch.tensor([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=torch.float32), affine=points_affine) + + # Reference image with identity affine + ref_affine = torch.eye(4) + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=ref_affine) + image.meta["spatial_shape"] = (8, 8, 8) + + data = {"points": points, "image": image} + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", sigma=2.0) + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), (2, 8, 8, 8)) + + # Heatmap should inherit affine from the reference image + assert_allclose(heatmap.affine, image.affine, type_test=False) + + +if __name__ == "__main__": + unittest.main()