From 8ef905b36b26f2d84c386ab5359d76a61ca643ac Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Thu, 18 Sep 2025 13:23:18 +0900 Subject: [PATCH 01/19] Adds a transform to generate heatmap from landmarks Adds a `GenerateHeatmap` transform to create gaussian response maps from landmark coordinates. This transform is implemented for both array and dictionary-based workflows. It enables the generation of heatmaps from landmark data, facilitating tasks like landmark localization and visualization. The transform supports 2D and 3D coordinates and offers options for controlling the gaussian standard deviation, spatial shape, truncation, normalization, and data type. --- monai/transforms/post/array.py | 150 +++++++++++++++++++++++++++- monai/transforms/post/dictionary.py | 136 +++++++++++++++++++++++++ tests/test_generate_heatmap.py | 90 +++++++++++++++++ 3 files changed, 375 insertions(+), 1 deletion(-) create mode 100644 tests/test_generate_heatmap.py diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 2e733c4f6c..4d419819d6 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -38,7 +38,14 @@ remove_small_objects, ) from monai.transforms.utils_pytorch_numpy_unification import unravel_index -from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option +from monai.utils import ( + TransformBackends, + convert_data_type, + convert_to_tensor, + ensure_tuple, + get_equivalent_dtype, + look_up_option, +) from monai.utils.type_conversion import convert_to_dst_type __all__ = [ @@ -54,6 +61,7 @@ "SobelGradients", "VoteEnsemble", "Invert", + "GenerateHeatmap", "DistanceTransformEDT", ] @@ -742,6 +750,146 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO return self.post_convert(out_pt, img) +class GenerateHeatmap(Transform): + """ + Generate per-landmark gaussian response maps for 2D or 3D coordinates. 
+ + Args: + sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions. + spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform. + truncate: extent, in multiples of ``sigma``, used to crop the gaussian support window. + normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``. + dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes). + + Raises: + ValueError: when ``sigma`` is non-positive or ``spatial_shape`` cannot be resolved. + + """ + + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + + def __init__( + self, + sigma: Sequence[float] | float = 5.0, + spatial_shape: Sequence[int] | None = None, + truncate: float = 3.0, + normalize: bool = True, + dtype: np.dtype | torch.dtype | type = np.float32, + ) -> None: + if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)): + if any(s <= 0 for s in sigma): + raise ValueError("sigma values must be positive.") + self._sigma = tuple(float(s) for s in sigma) + else: + if float(sigma) <= 0: + raise ValueError("sigma must be positive.") + self._sigma = float(sigma) + if truncate <= 0: + raise ValueError("truncate must be positive.") + self.truncate = float(truncate) + self.normalize = normalize + self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor) + self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray) + self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape) + + def __call__( + self, + points: NdarrayOrTensor, + spatial_shape: Sequence[int] | None = None, + ) -> NdarrayOrTensor: + original_points = points + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) + if points_t.ndim != 2: + raise ValueError("points must be a 2D array with shape (num_points, spatial_dims).") + device = points_t.device + num_points, spatial_dims = points_t.shape + if spatial_dims not in (2, 3): + raise 
ValueError("GenerateHeatmap only supports 2D or 3D landmarks.") + + target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims) + sigma = self._resolve_sigma(spatial_dims) + radius = tuple(int(np.ceil(self.truncate * s)) for s in sigma) + + heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device) + image_bounds = tuple(int(s) for s in target_shape) + for idx, center in enumerate(points_t): + center_vals = center.tolist() + if not np.all(np.isfinite(center_vals)): + continue + if not self._is_inside(center_vals, image_bounds): + continue + window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device) + if window_slices is None: + continue + region = heatmap[(idx, *window_slices)] + gaussian = self._evaluate_gaussian(coord_shifts, sigma) + torch.maximum(region, gaussian, out=region) + if self.normalize: + max_val = heatmap[idx].max() + if max_val.item() > 0: + heatmap[idx] /= max_val + + target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype + converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype) + return converted + + def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]: + shape = call_shape if call_shape is not None else self.spatial_shape + if shape is None: + raise ValueError("spatial_shape must be provided either at construction time or call time.") + shape_tuple = ensure_tuple(shape) + if len(shape_tuple) != spatial_dims: + if len(shape_tuple) == 1: + shape_tuple = shape_tuple * spatial_dims # type: ignore + else: + raise ValueError("spatial_shape length must match spatial dimension of the landmarks.") + return tuple(int(s) for s in shape_tuple) + + def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]: + if isinstance(self._sigma, tuple): + if len(self._sigma) == spatial_dims: + return self._sigma + if len(self._sigma) == 1: + return 
self._sigma * spatial_dims + raise ValueError("sigma sequence length must equal the number of spatial dimensions.") + return (self._sigma,) * spatial_dims + + @staticmethod + def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool: + return all(0 <= c < size for c, size in zip(center, bounds)) + + def _make_window( + self, + center: Sequence[float], + radius: tuple[int, ...], + bounds: tuple[int, ...], + device: torch.device, + ) -> tuple[tuple[slice, ...] | None, tuple[torch.Tensor, ...]]: + slices: list[slice] = [] + coord_shifts: list[torch.Tensor] = [] + for dim, (c, r, size) in enumerate(zip(center, radius, bounds)): + start = max(int(np.floor(c - r)), 0) + stop = min(int(np.ceil(c + r)) + 1, size) + if start >= stop: + return None, () + slices.append(slice(start, stop)) + coord_shifts.append(torch.arange(start, stop, device=device, dtype=self.torch_dtype) - float(c)) + return tuple(slices), tuple(coord_shifts) + + def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor: + device = coord_shifts[0].device + shape = tuple(len(axis) for axis in coord_shifts) + if 0 in shape: + return torch.zeros(shape, dtype=self.torch_dtype, device=device) + exponent = torch.zeros(shape, dtype=self.torch_dtype, device=device) + for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)): + scaled = (shift / float(sig)) ** 2 + reshape_shape = [1] * len(coord_shifts) + reshape_shape[dim] = shift.numel() + exponent += scaled.reshape(reshape_shape) + return torch.exp(-0.5 * exponent) + + class ProbNMS(Transform): """ Performs probability based non-maximum suppression (NMS) on the probabilities map via diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 7e1e074f71..02b939a9bb 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -35,6 +35,7 @@ AsDiscrete, DistanceTransformEDT, FillHoles, + GenerateHeatmap, 
KeepLargestConnectedComponent, LabelFilter, LabelToContour, @@ -48,6 +49,7 @@ from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep +from monai.utils.type_conversion import convert_to_dst_type __all__ = [ "ActivationsD", @@ -95,6 +97,9 @@ "DistanceTransformEDTd", "DistanceTransformEDTD", "DistanceTransformEDTDict", + "GenerateHeatmapd", + "GenerateHeatmapD", + "GenerateHeatmapDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -508,6 +513,137 @@ def __init__(self, keys: KeysCollection, output_key: str | None = None, num_clas super().__init__(keys, ensemble, output_key) +class GenerateHeatmapd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`. + Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image. + """ + + backend = GenerateHeatmap.backend + + def __init__( + self, + keys: KeysCollection, + sigma: Sequence[float] | float = 5.0, + heatmap_keys: KeysCollection | None = None, + ref_image_keys: KeysCollection | None = None, + spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None, + truncate: float = 3.0, + normalize: bool = True, + dtype: np.dtype | type = np.float32, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys) + self.ref_image_keys = self._prepare_optional_keys(ref_image_keys) + self.static_shapes = self._prepare_shapes(spatial_shape) + self.generator = GenerateHeatmap( + sigma=sigma, + spatial_shape=None, + truncate=truncate, + normalize=normalize, + dtype=dtype, + ) + + def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: + d = dict(data) + for key, out_key, ref_key, static_shape in self.key_iterator( + d, self.heatmap_keys, self.ref_image_keys, 
self.static_shapes + ): + points = d[key] + shape = self._determine_shape(points, static_shape, d, ref_key) + heatmap = self.generator(points, spatial_shape=shape) + reference = d.get(ref_key) if ref_key is not None and ref_key in d else None + d[out_key] = self._prepare_output(heatmap, reference) + return d + + def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]: + if heatmap_keys is None: + return tuple(f"{key}_heatmap" for key in self.keys) + keys_tuple = ensure_tuple(heatmap_keys) + if len(keys_tuple) == 1 and len(self.keys) > 1: + keys_tuple = keys_tuple * len(self.keys) + if len(keys_tuple) != len(self.keys): + raise ValueError("heatmap_keys length must match keys length.") + return keys_tuple + + def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]: + if maybe_keys is None: + return (None,) * len(self.keys) + keys_tuple = ensure_tuple(maybe_keys) + if len(keys_tuple) == 1 and len(self.keys) > 1: + keys_tuple = keys_tuple * len(self.keys) + if len(keys_tuple) != len(self.keys): + raise ValueError("ref_image_keys length must match keys length when provided.") + return tuple(keys_tuple) + + def _prepare_shapes( + self, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None + ) -> tuple[tuple[int, ...] | None, ...]: + if spatial_shape is None: + return (None,) * len(self.keys) + shape_tuple = ensure_tuple(spatial_shape) + if shape_tuple and all(isinstance(v, (int, np.integer)) for v in shape_tuple): + shape = tuple(int(v) for v in shape_tuple) + return (shape,) * len(self.keys) + if len(shape_tuple) == 1 and len(self.keys) > 1: + shape_tuple = shape_tuple * len(self.keys) + if len(shape_tuple) != len(self.keys): + raise ValueError("spatial_shape length must match keys length when providing per-key shapes.") + prepared: list[tuple[int, ...] 
| None] = [] + for item in shape_tuple: + if item is None: + prepared.append(None) + else: + dims = ensure_tuple(item) + prepared.append(tuple(int(v) for v in dims)) + return tuple(prepared) + + def _determine_shape( + self, + points: Any, + static_shape: tuple[int, ...] | None, + data: Mapping[Hashable, Any], + ref_key: Hashable | None, + ) -> tuple[int, ...]: + if static_shape is not None: + return static_shape + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) + if points_t.ndim != 2: + raise ValueError("landmark arrays must be 2D with shape (num_points, spatial_dims).") + spatial_dims = int(points_t.shape[1]) + if ref_key is not None and ref_key in data: + return self._shape_from_reference(data[ref_key], spatial_dims) + raise ValueError( + "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys." + ) + + def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, ...]: + if isinstance(reference, MetaTensor): + meta_shape = reference.meta.get("spatial_shape") + if meta_shape is not None: + dims = ensure_tuple(meta_shape) + if len(dims) == spatial_dims: + return tuple(int(v) for v in dims) + return tuple(int(v) for v in reference.shape[-spatial_dims:]) + if hasattr(reference, "shape"): + return tuple(int(v) for v in reference.shape[-spatial_dims:]) + raise ValueError("Reference data must define a shape attribute.") + + def _prepare_output(self, heatmap: NdarrayOrTensor, reference: Any) -> Any: + if isinstance(reference, MetaTensor): + converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=reference.dtype, device=reference.device) + converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) + return converted + if isinstance(reference, torch.Tensor): + converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=reference.dtype, device=reference.device) + return converted + return heatmap + + +GenerateHeatmapD = GenerateHeatmapDict = 
GenerateHeatmapd + + class ProbNMSd(MapTransform): """ Performs probability based non-maximum suppression (NMS) on the probabilities map via diff --git a/tests/test_generate_heatmap.py b/tests/test_generate_heatmap.py new file mode 100644 index 0000000000..ff594719d3 --- /dev/null +++ b/tests/test_generate_heatmap.py @@ -0,0 +1,90 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np +import pytest +import torch + +from monai.data import MetaTensor +from monai.transforms.post.array import GenerateHeatmap +from monai.transforms.post.dictionary import GenerateHeatmapd + + +def test_generate_heatmap_array_2d() -> None: + points = np.array([[4.2, 7.8], [12.3, 3.6]], dtype=np.float32) + transform = GenerateHeatmap(sigma=1.5, spatial_shape=(16, 16)) + + heatmap = transform(points) + + assert heatmap.shape == (2, 16, 16) + assert heatmap.dtype == np.float32 + np.testing.assert_allclose(heatmap.max(axis=(1, 2)), np.ones(2), rtol=1e-5, atol=1e-5) + + for idx, channel in enumerate(heatmap): + max_idx = np.array(np.unravel_index(np.argmax(channel), channel.shape)) + assert np.all(np.abs(max_idx - points[idx]) <= 1) + assert channel[0, 0] < 1e-3 + + +def test_generate_heatmap_array_torch_output() -> None: + points = torch.tensor([[1.5, 2.5, 3.5]], dtype=torch.float32) + transform = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=torch.float32) + + heatmap = 
transform(points.to(device=points.device)) + + assert isinstance(heatmap, torch.Tensor) + assert heatmap.device == points.device + assert heatmap.shape == (1, 8, 8, 8) + assert torch.isclose(heatmap.max(), torch.tensor(1.0, dtype=heatmap.dtype, device=heatmap.device)) + + +def test_generate_heatmapd_with_reference_meta() -> None: + points = np.array([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=np.float32) + affine = torch.eye(4) + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + data = {"points": points, "image": image} + + transform = GenerateHeatmapd( + keys="points", + heatmap_keys="heatmap", + ref_image_keys="image", + sigma=2.0, + ) + + result = transform(data) + heatmap = result["heatmap"] + + assert isinstance(heatmap, MetaTensor) + assert tuple(heatmap.shape) == (2, 8, 8, 8) + assert heatmap.meta["spatial_shape"] == (8, 8, 8) + assert torch.allclose(heatmap.affine, image.affine) + np.testing.assert_allclose(heatmap.cpu().numpy().max(axis=(1, 2, 3)), np.ones(2), rtol=1e-5, atol=1e-5) + + +def test_generate_heatmapd_static_shape() -> None: + points = np.array([[1.0, 1.0]], dtype=np.float32) + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", spatial_shape=(6, 6)) + + result = transform({"points": points}) + + heatmap = result["heatmap"] + assert isinstance(heatmap, np.ndarray) + assert heatmap.shape == (1, 6, 6) + + +def test_generate_heatmapd_missing_shape_raises() -> None: + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap") + + with pytest.raises(ValueError): + transform({"points": np.zeros((1, 2), dtype=np.float32)}) From 226bf906c9b62e3ca196ebd29dd81b0eede48b4a Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Fri, 19 Sep 2025 19:45:25 +0900 Subject: [PATCH 02/19] Adds heatmap generation demo and tests Introduces a new interactive notebook demonstrating landmark to heatmap conversion using MONAI transforms. 
This includes: - A notebook with array and dictionary transform modes. - A test suite for the `GenerateHeatmap` transform. This enhancement enables users to visualize and interact with heatmap generation, facilitating a better understanding and application of the MONAI transforms. --- 2d_mdtest.ipynb | 258 ++++++++++++++++++++++ tests/test_generate_heatmap.py | 90 -------- tests/transforms/test_generate_heatmap.py | 176 +++++++++++++++ 3 files changed, 434 insertions(+), 90 deletions(-) create mode 100644 2d_mdtest.ipynb delete mode 100644 tests/test_generate_heatmap.py create mode 100644 tests/transforms/test_generate_heatmap.py diff --git a/2d_mdtest.ipynb b/2d_mdtest.ipynb new file mode 100644 index 0000000000..54571c37cf --- /dev/null +++ b/2d_mdtest.ipynb @@ -0,0 +1,258 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2D Landmark → Heatmap (MNIST / MedMNIST) — with MONAI\n", + "Interactive demo converting clicked landmarks to Gaussian heatmaps **using MONAI**.\n", + "\n", + "**Modes**\n", + "1) Array transform: `GenerateHeatmap`\n", + "2) Dict transform: `GenerateHeatmapd` with optional `MetaTensor` reference" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Installation requirements for interactive notebook\n", + "# %pip install --upgrade pip\n", + "# %pip install torch torchvision monai medmnist matplotlib ipywidgets\n", + "#\n", + "# For JupyterLab users, also run:\n", + "# %pip install jupyterlab-widgets\n", + "#\n", + "# For interactive matplotlib (optional, used in first implementation):\n", + "# %pip install ipympl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from torchvision import datasets, transforms\n", + "from monai.transforms.post.array import GenerateHeatmap\n", + "from 
monai.transforms.post.dictionary import GenerateHeatmapd\n", + "from monai.data import MetaTensor\n", + "\n", + "try:\n", + " import medmnist\n", + " from medmnist import PathMNIST\n", + "\n", + " HAS_MEDMNIST = True\n", + "except Exception:\n", + " HAS_MEDMNIST = False\n", + " print(\"medmnist not available. Run `pip install medmnist` to enable PathMNIST.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image shape: 28x28\n" + ] + } + ], + "source": [ + "# Load a small 2D image\n", + "use_medmnist = False\n", + "if use_medmnist and HAS_MEDMNIST:\n", + " ds = PathMNIST(split=\"test\", download=True, as_rgb=True)\n", + " img = np.asarray(ds[0][0]).mean(axis=2).astype(np.float32)\n", + "else:\n", + " mnist = datasets.MNIST(root=\"./data\", train=False, download=True, transform=transforms.ToTensor())\n", + " img = mnist[0][0][0].numpy().astype(np.float32)\n", + "\n", + "if img.max() > 0:\n", + " img = img / float(img.max())\n", + "H, W = img.shape\n", + "print(f\"Image shape: {H}x{W}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Heatmap helper using GenerateHeatmap\n", + "# NOTE: GenerateHeatmap expects points in (y, x) == (row, col) order matching array indexing.\n", + "# This wrapper accepts user-friendly (x, y) and internally reorders to (y, x).\n", + "\n", + "sigma = 3.0\n", + "\n", + "\n", + "def heatmap_with_array_transform(x, y, sigma_override=None):\n", + " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", + " tr = GenerateHeatmap(sigma=s, spatial_shape=(H, W))\n", + " # Reorder (x,y) -> (y,x) for the transform\n", + " pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n", + " return tr(pts_yx) # (N,H,W) where pts interpreted as (row, col)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + 
"source": [ + "affine = torch.eye(4)\n", + "ref_img = MetaTensor(torch.from_numpy(img).unsqueeze(0), affine=affine)\n", + "ref_img.meta[\"spatial_shape\"] = (H, W)\n", + "\n", + "# Dictionary version wrapper also accepts (x,y) and converts to (y,x)\n", + "\n", + "\n", + "def heatmap_with_dict_transform(x, y, sigma_override=None, use_ref=True):\n", + " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", + " tr = GenerateHeatmapd(\n", + " keys=\"points\",\n", + " heatmap_keys=\"heatmap\",\n", + " ref_image_keys=\"ref\" if use_ref else None,\n", + " spatial_shape=None if use_ref else (H, W),\n", + " sigma=s,\n", + " )\n", + " pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n", + " data = {\"points\": pts_yx, \"ref\": ref_img}\n", + " out = tr(data)\n", + " return out[\"heatmap\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Random points (x, y):\n", + "[[24.567745 18.05713 ]\n", + " [15.933253 2.4389822]\n", + " [21.020788 17.367664 ]]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Simple random landmark → heatmap example (no interactivity)\n", + "# Re-run this cell to sample new random points and regenerate heatmaps.\n", + "# INTERNAL NOTE: GenerateHeatmap consumes (row=y, col=x). We sample (x,y) for user readability and convert.\n", + "\n", + "import random\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Parameters\n", + "num_points = 3 # number of random landmarks\n", + "sigma_demo = 3.0 # Gaussian sigma\n", + "combine_mode = \"max\" # or 'sum'\n", + "\n", + "# Sample random (x,y) points within image bounds (user-friendly)\n", + "points_xy = np.array(\n", + " [[random.uniform(0, W - 1), random.uniform(0, H - 1)] for _ in range(num_points)], dtype=np.float32\n", + ") # (N,2)\n", + "print(\"Random points (x, y):\")\n", + "print(points_xy)\n", + "\n", + "# Convert to (y,x) for the transform\n", + "yx_points = points_xy[:, [1, 0]].copy()\n", + "\n", + "array_tr = GenerateHeatmap(sigma=sigma_demo, spatial_shape=(H, W))\n", + "heatmaps = array_tr(yx_points) # now correct orientation\n", + "\n", + "if combine_mode == \"max\":\n", + " combined = heatmaps.max(axis=0)\n", + "elif combine_mode == \"sum\":\n", + " combined = heatmaps.sum(axis=0)\n", + " if combined.max() > 0:\n", + " combined = combined / combined.max()\n", + "else:\n", + " raise ValueError(\"combine_mode must be 'max' or 'sum'\")\n", + "\n", + "# Plot\n", + "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", + "axes[0].imshow(img, cmap=\"gray\", vmin=0.0, vmax=1.0, origin=\"upper\")\n", + "axes[0].set_title(\"Base Image\")\n", + "axes[0].set_axis_off()\n", + "for x, y in points_xy:\n", + " axes[0].plot(x, y, \"r+\", markersize=12, markeredgewidth=2)\n", + "\n", + "axes[1].imshow(img, cmap=\"gray\", vmin=0.0, vmax=1.0, origin=\"upper\")\n", + "axes[1].imshow(combined, alpha=0.6, cmap=\"hot\", origin=\"upper\")\n", + "axes[1].set_title(f\"Combined Heatmap (mode={combine_mode}, 
sigma={sigma_demo})\")\n", + "axes[1].set_axis_off()\n", + "for x, y in points_xy:\n", + " axes[1].plot(x, y, \"c+\", markersize=12, markeredgewidth=2)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Individual channels\n", + "fig2, axes2 = plt.subplots(1, num_points, figsize=(4 * num_points, 4))\n", + "if num_points == 1:\n", + " axes2 = [axes2]\n", + "for i, ax in enumerate(axes2):\n", + " ax.imshow(heatmaps[i], cmap=\"hot\", origin=\"upper\")\n", + " ax.plot(points_xy[i, 0], points_xy[i, 1], \"w+\", markersize=12, markeredgewidth=2)\n", + " ax.set_title(f\"Point {i}: (x={points_xy[i,0]:.1f}, y={points_xy[i,1]:.1f})\")\n", + " ax.set_axis_off()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".conda", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_generate_heatmap.py b/tests/test_generate_heatmap.py deleted file mode 100644 index ff594719d3..0000000000 --- a/tests/test_generate_heatmap.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import numpy as np -import pytest -import torch - -from monai.data import MetaTensor -from monai.transforms.post.array import GenerateHeatmap -from monai.transforms.post.dictionary import GenerateHeatmapd - - -def test_generate_heatmap_array_2d() -> None: - points = np.array([[4.2, 7.8], [12.3, 3.6]], dtype=np.float32) - transform = GenerateHeatmap(sigma=1.5, spatial_shape=(16, 16)) - - heatmap = transform(points) - - assert heatmap.shape == (2, 16, 16) - assert heatmap.dtype == np.float32 - np.testing.assert_allclose(heatmap.max(axis=(1, 2)), np.ones(2), rtol=1e-5, atol=1e-5) - - for idx, channel in enumerate(heatmap): - max_idx = np.array(np.unravel_index(np.argmax(channel), channel.shape)) - assert np.all(np.abs(max_idx - points[idx]) <= 1) - assert channel[0, 0] < 1e-3 - - -def test_generate_heatmap_array_torch_output() -> None: - points = torch.tensor([[1.5, 2.5, 3.5]], dtype=torch.float32) - transform = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=torch.float32) - - heatmap = transform(points.to(device=points.device)) - - assert isinstance(heatmap, torch.Tensor) - assert heatmap.device == points.device - assert heatmap.shape == (1, 8, 8, 8) - assert torch.isclose(heatmap.max(), torch.tensor(1.0, dtype=heatmap.dtype, device=heatmap.device)) - - -def test_generate_heatmapd_with_reference_meta() -> None: - points = np.array([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=np.float32) - affine = torch.eye(4) - image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) - image.meta["spatial_shape"] = (8, 8, 8) - data = {"points": points, "image": image} - - transform = GenerateHeatmapd( - keys="points", - heatmap_keys="heatmap", - ref_image_keys="image", - sigma=2.0, - ) - - result = transform(data) - heatmap = result["heatmap"] - - assert isinstance(heatmap, MetaTensor) - assert tuple(heatmap.shape) == (2, 8, 8, 8) - assert heatmap.meta["spatial_shape"] == (8, 8, 8) - assert 
torch.allclose(heatmap.affine, image.affine) - np.testing.assert_allclose(heatmap.cpu().numpy().max(axis=(1, 2, 3)), np.ones(2), rtol=1e-5, atol=1e-5) - - -def test_generate_heatmapd_static_shape() -> None: - points = np.array([[1.0, 1.0]], dtype=np.float32) - transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", spatial_shape=(6, 6)) - - result = transform({"points": points}) - - heatmap = result["heatmap"] - assert isinstance(heatmap, np.ndarray) - assert heatmap.shape == (1, 6, 6) - - -def test_generate_heatmapd_missing_shape_raises() -> None: - transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap") - - with pytest.raises(ValueError): - transform({"points": np.zeros((1, 2), dtype=np.float32)}) diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py new file mode 100644 index 0000000000..123ac087c5 --- /dev/null +++ b/tests/transforms/test_generate_heatmap.py @@ -0,0 +1,176 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest +import math +import numpy as np +import torch + +from monai.data import MetaTensor +from monai.transforms.post.array import GenerateHeatmap +from monai.transforms.post.dictionary import GenerateHeatmapd +from tests.test_utils import assert_allclose + + +def _argmax_nd(x: np.ndarray) -> np.ndarray: + """argmax for N-D array → returns coordinate vector (z,y,x) or (y,x).""" + return np.asarray(np.unravel_index(np.argmax(x), x.shape)) + + +class TestGenerateHeatmap(unittest.TestCase): + def test_array_2d(self): + points = np.array([[4.2, 7.8], [12.3, 3.6]], dtype=np.float32) + transform = GenerateHeatmap(sigma=1.5, spatial_shape=(16, 16)) + + heatmap = transform(points) + + self.assertEqual(heatmap.shape, (2, 16, 16)) + self.assertEqual(heatmap.dtype, np.float32) + np.testing.assert_allclose(heatmap.max(axis=(1, 2)), np.ones(2), rtol=1e-5, atol=1e-5) + + # peak should be close to original point location (<= 1px tolerance due to discretization) + for idx, channel in enumerate(heatmap): + peak = _argmax_nd(channel) + self.assertTrue(np.all(np.abs(peak - points[idx]) <= 1.0), msg=f"peak={peak}, point={points[idx]}") + self.assertLess(channel[0, 0], 1e-3) + + def test_array_3d_torch_output(self): + points = torch.tensor([[1.5, 2.5, 3.5]], dtype=torch.float32) + transform = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=torch.float32) + + heatmap = transform(points.to(device=points.device)) + + self.assertIsInstance(heatmap, torch.Tensor) + self.assertEqual(heatmap.device, points.device) + self.assertEqual(tuple(heatmap.shape), (1, 8, 8, 8)) + self.assertTrue(torch.isclose(heatmap.max(), torch.tensor(1.0, dtype=heatmap.dtype, device=heatmap.device))) + + def test_array_torch_device_and_dtype_propagation(self): + # verify dtype parameter honored and CUDA (if available) + dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + device = torch.device("cuda:0") if torch.cuda.is_available() else 
torch.device("cpu") + + pts = torch.tensor([[3.0, 4.0, 5.0]], dtype=torch.float32, device=device) + tr = GenerateHeatmap(sigma=1.2, spatial_shape=(10, 10, 10), dtype=dtype) + + hm = tr(pts) + self.assertIsInstance(hm, torch.Tensor) + self.assertEqual(hm.device, device) + self.assertEqual(hm.dtype, dtype) + self.assertEqual(tuple(hm.shape), (1, 10, 10, 10)) + self.assertTrue(torch.all(hm >= 0)) + + def test_array_channel_order_identity(self): + # ensure the order of channels follows the order of input points + pts = np.array( + [ + [2.0, 2.0], # point A + [12.0, 2.0], # point B + [2.0, 12.0], # point C + ], + dtype=np.float32, + ) + hm = GenerateHeatmap(sigma=1.2, spatial_shape=(16, 16))(pts) + self.assertEqual(hm.shape, (3, 16, 16)) + + peaks = np.vstack([_argmax_nd(hm[i]) for i in range(3)]) + # y,x close to points + np.testing.assert_allclose(peaks, pts, atol=1.0) + + def test_array_points_out_of_bounds(self): + # points outside spatial domain: heatmap should still be valid (no NaN/Inf) and not all-zeros + pts = np.array( + [ + [-5.0, -5.0], # outside top-left + [100.0, 100.0], # outside bottom-right + [8.0, 8.0], # inside + ], + dtype=np.float32, + ) + hm = GenerateHeatmap(sigma=2.0, spatial_shape=(16, 16))(pts) + self.assertEqual(hm.shape, (3, 16, 16)) + self.assertFalse(np.isnan(hm).any() or np.isinf(hm).any()) + + # inside point channel should have max≈1; others may clip at border (≤1) + self.assertGreater(hm[2].max(), 0.9) + + def test_array_sigma_scaling_effect(self): + # Larger sigma should spread mass (lower peak), smaller sigma higher peak + pt = np.array([[8.0, 8.0]], dtype=np.float32) + small = GenerateHeatmap(sigma=0.8, spatial_shape=(16, 16))(pt)[0] + large = GenerateHeatmap(sigma=2.5, spatial_shape=(16, 16))(pt)[0] + self.assertGreater(small.max(), large.max() - 1e-6) # small sigma peak >= large sigma peak + + def test_dict_with_reference_meta(self): + points = np.array([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=np.float32) + affine = torch.eye(4) + 
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + data = {"points": points, "image": image} + + transform = GenerateHeatmapd( + keys="points", + heatmap_keys="heatmap", + ref_image_keys="image", + sigma=2.0, + ) + + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), (2, 8, 8, 8)) + self.assertEqual(heatmap.meta["spatial_shape"], (8, 8, 8)) + assert_allclose(heatmap.affine, image.affine, type_test=False) + np.testing.assert_allclose(heatmap.cpu().numpy().max(axis=(1, 2, 3)), np.ones(2), rtol=1e-5, atol=1e-5) + + def test_dict_static_shape(self): + points = np.array([[1.0, 1.0]], dtype=np.float32) + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", spatial_shape=(6, 6)) + + result = transform({"points": points}) + heatmap = result["heatmap"] + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.shape, (1, 6, 6)) + + def test_dict_missing_shape_raises(self): + # Without ref image or explicit spatial_shape, must raise + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap") + with self.assertRaises(ValueError): + transform({"points": np.zeros((1, 2), dtype=np.float32)}) + + def test_invalid_points_shape_raises(self): + # points must be (N, D) with D in {2,3} + tr = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8)) + with self.assertRaises((ValueError, AssertionError, IndexError, RuntimeError)): + tr(np.zeros((2,), dtype=np.float32)) # wrong rank + + with self.assertRaises((ValueError, AssertionError, IndexError, RuntimeError)): + tr(np.zeros((2, 4), dtype=np.float32)) # D=4 unsupported + + def test_dict_dtype_control(self): + # Ensure dtype argument controls output dtype for dictionary transform too + points = np.array([[2.0, 3.0, 4.0]], dtype=np.float32) + ref = MetaTensor(torch.zeros((1, 10, 10, 10), dtype=torch.float32), affine=torch.eye(4)) + d = 
{"pts": points, "img": ref} + + tr = GenerateHeatmapd(keys="pts", heatmap_keys="hm", ref_image_keys="img", sigma=1.4, dtype=torch.float16) + out = tr(d) + hm = out["hm"] + self.assertIsInstance(hm, MetaTensor) + self.assertEqual(tuple(hm.shape), (1, 10, 10, 10)) + self.assertEqual(hm.dtype, torch.float16) + + +if __name__ == "__main__": + unittest.main() From 3097baf5204b03872a87f8baa2e6d34411f9f42f Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sun, 21 Sep 2025 16:51:09 +0900 Subject: [PATCH 03/19] Enables batched input for heatmap generation Extends the `GenerateHeatmap` transform to support batched inputs, allowing for more efficient processing of multiple landmark sets. This change modifies the transform to handle inputs with a batch dimension (B, N, spatial_dims) in addition to single-point inputs (N, spatial_dims). It also includes a demonstration of 3D heatmap generation using PyVista for visualization. --- 2d_mdtest.ipynb | 25 ++- 3d_heatmap_pyvista.ipynb | 238 ++++++++++++++++++++++ monai/transforms/post/array.py | 56 +++-- monai/transforms/post/dictionary.py | 6 +- tests/transforms/test_generate_heatmap.py | 52 +++++ 5 files changed, 347 insertions(+), 30 deletions(-) create mode 100644 3d_heatmap_pyvista.ipynb diff --git a/2d_mdtest.ipynb b/2d_mdtest.ipynb index 54571c37cf..73cd466375 100644 --- a/2d_mdtest.ipynb +++ b/2d_mdtest.ipynb @@ -91,6 +91,7 @@ "# Heatmap helper using GenerateHeatmap\n", "# NOTE: GenerateHeatmap expects points in (y, x) == (row, col) order matching array indexing.\n", "# This wrapper accepts user-friendly (x, y) and internally reorders to (y, x).\n", + "# It now supports batched inputs.\n", "\n", "sigma = 3.0\n", "\n", @@ -99,8 +100,12 @@ " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", " tr = GenerateHeatmap(sigma=s, spatial_shape=(H, W))\n", " # Reorder (x,y) -> (y,x) for the transform\n", - " pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n", - " return tr(pts_yx) # (N,H,W) where 
pts interpreted as (row, col)" + " # Support batched and non-batched inputs\n", + " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", + " if pts.ndim == 2:\n", + " pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n", + " pts_yx = pts[..., [1, 0]]\n", + " return tr(pts_yx) # (B, N, H, W)\n" ] }, { @@ -125,10 +130,14 @@ " spatial_shape=None if use_ref else (H, W),\n", " sigma=s,\n", " )\n", - " pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n", + " # Support batched and non-batched inputs\n", + " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", + " if pts.ndim == 2:\n", + " pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n", + " pts_yx = pts[..., [1, 0]]\n", " data = {\"points\": pts_yx, \"ref\": ref_img}\n", " out = tr(data)\n", - " return out[\"heatmap\"]" + " return out[\"heatmap\"]\n" ] }, { @@ -179,6 +188,7 @@ "num_points = 3 # number of random landmarks\n", "sigma_demo = 3.0 # Gaussian sigma\n", "combine_mode = \"max\" # or 'sum'\n", + "batched_input = True # Set to True to test batched input\n", "\n", "# Sample random (x,y) points within image bounds (user-friendly)\n", "points_xy = np.array(\n", @@ -189,10 +199,15 @@ "\n", "# Convert to (y,x) for the transform\n", "yx_points = points_xy[:, [1, 0]].copy()\n", + "if batched_input:\n", + " yx_points = yx_points[np.newaxis, ...] 
# Add a batch dimension\n", "\n", "array_tr = GenerateHeatmap(sigma=sigma_demo, spatial_shape=(H, W))\n", "heatmaps = array_tr(yx_points) # now correct orientation\n", "\n", + "if batched_input:\n", + " heatmaps = heatmaps.squeeze(0) # Remove batch dim for plotting\n", + "\n", "if combine_mode == \"max\":\n", " combined = heatmaps.max(axis=0)\n", "elif combine_mode == \"sum\":\n", @@ -230,7 +245,7 @@ " ax.set_title(f\"Point {i}: (x={points_xy[i,0]:.1f}, y={points_xy[i,1]:.1f})\")\n", " ax.set_axis_off()\n", "plt.tight_layout()\n", - "plt.show()" + "plt.show()\n" ] } ], diff --git a/3d_heatmap_pyvista.ipynb b/3d_heatmap_pyvista.ipynb new file mode 100644 index 0000000000..fcd3abdf4c --- /dev/null +++ b/3d_heatmap_pyvista.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ec5556c6", + "metadata": {}, + "source": [ + "\n", + "# 3D Volume ??Heatmap (MONAI) ??**PyVista edition**\n", + "Interactive 3D volume + heatmap rendering using **PyVista (trame backend)**. \n", + "- Points are **(x, y, z)** in *voxel index* order (MONAI/ITK-style physical coordinate order). \n", + "- The base volume is shown in grayscale; the Gaussian heatmap is overlaid in a hot colormap.\n", + "\n", + "> If you don't have the deps, install first:\n", + "```bash\n", + "pip install monai torch numpy matplotlib pyvista[trame] medmnist\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4d634d14", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from monai.transforms.post.array import GenerateHeatmap\n", + "from monai.transforms.post.dictionary import GenerateHeatmapd\n", + "from monai.data import MetaTensor\n", + "\n", + "# Optional data source (MedMNIST)\n", + "HAS_MEDMNIST = False\n", + "try:\n", + " from medmnist import OrganMNIST3D\n", + " HAS_MEDMNIST = True\n", + "except Exception:\n", + " print(\"medmnist not available. 
Fallback to synthetic volume.\")\n", + "\n", + "# PyVista (3D rendering)\n", + "HAS_PYVISTA = True\n", + "try:\n", + " import pyvista as pv\n", + " pv.set_jupyter_backend(\"trame\") # interactive in-notebook UI\n", + "except Exception as e:\n", + " HAS_PYVISTA = False\n", + " print(\"PyVista or trame backend not available:\", e)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0dbf0348", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Volume shape (D, H, W) = (28, 28, 28)\n" + ] + } + ], + "source": [ + "\n", + "# --- Load a 3D test volume ---\n", + "if HAS_MEDMNIST:\n", + " ds = OrganMNIST3D(split=\"test\", download=True, as_rgb=False)\n", + " vol = np.asarray(ds[0][0])\n", + " if vol.ndim == 4: # (C, D, H, W) -> (D, H, W)\n", + " vol = vol[0]\n", + "else:\n", + " # Synthetic 3D blob (D=H=W=28)\n", + " D = H = W = 28\n", + " z, y, x = np.mgrid[0:D, 0:H, 0:W]\n", + " vol = np.exp(-((x-14)**2 + (y-14)**2 + (z-14)**2) / (2*5.0**2)).astype(np.float32)\n", + "\n", + "# Normalize for nicer visualization\n", + "vol = vol.astype(np.float32)\n", + "if vol.max() > 0:\n", + " vol = vol / float(vol.max())\n", + "\n", + "D, H, W = vol.shape\n", + "print(\"Volume shape (D, H, W) =\", (D, H, W))\n", + "\n", + "# Reference MetaTensor for dictionary-based transform (identity affine)\n", + "ref = MetaTensor(torch.from_numpy(vol).unsqueeze(0), affine=torch.eye(4))\n", + "ref.meta[\"spatial_shape\"] = (D, H, W)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc664bbf", + "metadata": {}, + "outputs": [], + "source": [ + "# --- Heatmap generators ---\n", + "def heatmap3d_array(points_xyz: np.ndarray, sigma: float = 2.0):\n", + " \"\"\"Return (B, N, D, H, W) heatmap from landmarks specified as (x, y, z).\"\"\"\n", + " if points_xyz.ndim == 2: # (N, 3)\n", + " points_xyz = points_xyz[np.newaxis, ...] 
# -> (1, N, 3)\n", + " points_zyx = points_xyz[..., [2, 1, 0]].copy() # reorder to (z,y,x) for MONAI\n", + " tr = GenerateHeatmap(sigma=float(sigma), spatial_shape=(D, H, W))\n", + " hm = tr(points_zyx) # (B, N, D, H, W)\n", + " return hm\n", + "\n", + "def heatmap3d_dict(points_xyz: np.ndarray, sigma: float = 2.0, use_ref: bool = True):\n", + " \"\"\"Return (B, N, D, H, W) heatmap from landmarks specified as (x, y, z).\"\"\"\n", + " if points_xyz.ndim == 2: # (N, 3)\n", + " points_xyz = points_xyz[np.newaxis, ...] # -> (1, N, 3)\n", + " points_zyx = points_xyz[..., [2, 1, 0]].copy() # reorder to (z,y,x) for MONAI\n", + " tr = GenerateHeatmapd(\n", + " keys=\"points\",\n", + " heatmap_keys=\"heatmap\",\n", + " ref_image_keys=\"ref\" if use_ref else None,\n", + " spatial_shape=None if use_ref else (D, H, W),\n", + " sigma=float(sigma),\n", + " )\n", + " data = {\"points\": points_zyx, \"ref\": ref}\n", + " out = tr(data)\n", + " return out[\"heatmap\"] # Tensor or np.ndarray, shape (B, N, D, H, W)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdcfbb37", + "metadata": {}, + "outputs": [ + { + "ename": "IndexError", + "evalue": "index 14 is out of bounds for axis 0 with size 1", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 11\u001b[39m\n\u001b[32m 9\u001b[39m fig, ax = plt.subplots()\n\u001b[32m 10\u001b[39m ax.imshow(vol[z_idx], cmap=\u001b[33m\"\u001b[39m\u001b[33mgray\u001b[39m\u001b[33m\"\u001b[39m, vmin=\u001b[32m0.0\u001b[39m, vmax=\u001b[32m1.0\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m ax.imshow(\u001b[43mhm\u001b[49m\u001b[43m[\u001b[49m\u001b[43mz_idx\u001b[49m\u001b[43m]\u001b[49m, alpha=\u001b[32m0.6\u001b[39m)\n\u001b[32m 12\u001b[39m 
ax.set_title(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m2D check @ z=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mz_idx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | input(x,y,z)=(\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcx\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcy\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcz\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m) | sigma=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msigma\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 13\u001b[39m ax.set_axis_off()\n", + "\u001b[31mIndexError\u001b[39m: index 14 is out of bounds for axis 0 with size 1" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# --- Quick 2D sanity check (central slice) ---\n", + "# This confirms that an (x,y,z) input maps to the expected (z,y,x) voxel index.\n", + "points_xyz = np.array([[W // 2, H // 2, D // 2]], dtype=np.float32)\n", + "sigma = 2.0\n", + "\n", + "hm_batch = heatmap3d_array(points_xyz, sigma=sigma)\n", + "hm = hm_batch.squeeze(0).squeeze(0) # (D,H,W)\n", + "z_idx = int(points_xyz[0, 2])\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.imshow(vol[z_idx], cmap=\"gray\", vmin=0.0, vmax=1.0)\n", + "ax.imshow(hm[z_idx], alpha=0.6, cmap=\"hot\")\n", + "ax.set_title(f\"2D check @ z={z_idx} | input(x,y,z)=({points_xyz[0,0]:.1f},{points_xyz[0,1]:.1f},{points_xyz[0,2]:.1f}) | sigma={sigma:.1f}\")\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f033c71", + "metadata": {}, + "outputs": [], + "source": [ + "# --- PyVista 3D volume rendering ---\n", + "def render_3d_pyvista(points_xyz: np.ndarray, sigma: float = 2.0, use_dict: bool = False):\n", + " if not HAS_PYVISTA:\n", + " raise RuntimeError(\"PyVista backend is not available in this environment.\")\n", + "\n", + " # Build heatmap (B, N, D, H, W)\n", + " if use_dict:\n", + " hm_any = heatmap3d_dict(points_xyz, sigma=sigma, use_ref=True)\n", + " else:\n", + " hm_any = heatmap3d_array(points_xyz, sigma=sigma)\n", + "\n", + " # Combine heatmaps (max projection over channel dimension)\n", + " if hasattr(hm_any, \"cpu\"):\n", + " hm_combined = hm_any.cpu().numpy().max(axis=1).squeeze(0) # (D,H,W)\n", + " else:\n", + " hm_combined = hm_any.max(axis=1).squeeze(0)\n", + "\n", + " print(f\"Rendering 3D volume with {points_xyz.shape[0]} landmark(s) | sigma={sigma:.1f}\")\n", + "\n", + " # Wrap numpy arrays directly; PyVista understands (Z,Y,X) ordering for UniformGrid\n", + " grid_vol = pv.wrap(vol) # (D,H,W)\n", + " grid_hm = pv.wrap(hm_combined) # (D,H,W)\n", + "\n", + " p = 
pv.Plotter()\n", + " # Base volume (grayscale-ish)\n", + " p.add_volume(grid_vol, cmap=\"bone\", opacity=\"sigmoid_6\", shade=True)\n", + " # Heatmap overlay (hot colormap, more transparent)\n", + " p.add_volume(grid_hm, cmap=\"hot\", opacity=\"sigmoid_3\", shade=True)\n", + " p.show()\n", + "\n", + "# Example: try moving the landmark; call this cell repeatedly with new values.\n", + "render_3d_pyvista(np.array([[14.0, 10.0, 18.0], [20.0, 20.0, 10.0]]), sigma=3.0, use_dict=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".conda", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 4d419819d6..f710120293 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -799,35 +799,47 @@ def __call__( ) -> NdarrayOrTensor: original_points = points points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) - if points_t.ndim != 2: - raise ValueError("points must be a 2D array with shape (num_points, spatial_dims).") - device = points_t.device - num_points, spatial_dims = points_t.shape - if spatial_dims not in (2, 3): + + is_batched = points_t.ndim == 3 + if not is_batched: + if points_t.ndim != 2: + raise ValueError( + "points must be a 2D or 3D array with shape (num_points, spatial_dims) or (B, num_points, spatial_dims)." 
+ ) + points_t = points_t.unsqueeze(0) # Add a batch dimension + + if points_t.shape[-1] not in (2, 3): raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.") + device = points_t.device + batch_size, num_points, spatial_dims = points_t.shape + target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims) sigma = self._resolve_sigma(spatial_dims) radius = tuple(int(np.ceil(self.truncate * s)) for s in sigma) - heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device) + heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device) image_bounds = tuple(int(s) for s in target_shape) - for idx, center in enumerate(points_t): - center_vals = center.tolist() - if not np.all(np.isfinite(center_vals)): - continue - if not self._is_inside(center_vals, image_bounds): - continue - window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device) - if window_slices is None: - continue - region = heatmap[(idx, *window_slices)] - gaussian = self._evaluate_gaussian(coord_shifts, sigma) - torch.maximum(region, gaussian, out=region) - if self.normalize: - max_val = heatmap[idx].max() - if max_val.item() > 0: - heatmap[idx] /= max_val + for b_idx in range(batch_size): + for idx, center in enumerate(points_t[b_idx]): + center_vals = center.tolist() + if not np.all(np.isfinite(center_vals)): + continue + if not self._is_inside(center_vals, image_bounds): + continue + window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device) + if window_slices is None: + continue + region = heatmap[(b_idx, idx, *window_slices)] + gaussian = self._evaluate_gaussian(coord_shifts, sigma) + torch.maximum(region, gaussian, out=region) + if self.normalize: + max_val = heatmap[b_idx, idx].max() + if max_val.item() > 0: + heatmap[b_idx, idx] /= max_val + + if not is_batched: + heatmap = heatmap.squeeze(0) target_dtype = self.torch_dtype if 
isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 02b939a9bb..813804dcc6 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -609,9 +609,9 @@ def _determine_shape( if static_shape is not None: return static_shape points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) - if points_t.ndim != 2: - raise ValueError("landmark arrays must be 2D with shape (num_points, spatial_dims).") - spatial_dims = int(points_t.shape[1]) + if points_t.ndim not in (2, 3): + raise ValueError("landmark arrays must be 2D or 3D with shape (num_points, spatial_dims) or (B, num_points, spatial_dims).") + spatial_dims = int(points_t.shape[-1]) if ref_key is not None and ref_key in data: return self._shape_from_reference(data[ref_key], spatial_dims) raise ValueError( diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py index 123ac087c5..ccd3e28bdb 100644 --- a/tests/transforms/test_generate_heatmap.py +++ b/tests/transforms/test_generate_heatmap.py @@ -171,6 +171,58 @@ def test_dict_dtype_control(self): self.assertEqual(tuple(hm.shape), (1, 10, 10, 10)) self.assertEqual(hm.dtype, torch.float16) + def test_array_batched_3d(self): + points = np.array( + [ + [[4.2, 7.8, 1.0]], # Batch 1 + [[12.3, 3.6, 2.0]], # Batch 2 + ], + dtype=np.float32, + ) + transform = GenerateHeatmap(sigma=1.5, spatial_shape=(16, 16, 16)) + + heatmap = transform(points) + + self.assertEqual(heatmap.shape, (2, 1, 16, 16, 16)) + self.assertEqual(heatmap.dtype, np.float32) + np.testing.assert_allclose(heatmap.max(axis=(2, 3, 4)), np.ones((2, 1)), rtol=1e-5, atol=1e-5) + + # Check peaks for each batch item + for i in range(2): + peak = _argmax_nd(heatmap[i, 0]) + self.assertTrue(np.all(np.abs(peak - points[i, 0]) <= 
1.0), msg=f"peak={peak}, point={points[i, 0]}") + + def test_dict_batched_with_ref(self): + points = torch.tensor( + [ + [[1.5, 2.5, 3.5]], # Batch 1 + [[4.5, 5.5, 6.5]], # Batch 2 + ], + dtype=torch.float32, + ) + affine = torch.eye(4) + # A single reference image is used for the whole batch + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + data = {"points": points, "image": image} + + transform = GenerateHeatmapd( + keys="points", + heatmap_keys="heatmap", + ref_image_keys="image", + sigma=1.0, + ) + + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), (2, 1, 8, 8, 8)) + self.assertEqual(heatmap.meta["spatial_shape"], (8, 8, 8)) + assert_allclose(heatmap.affine, image.affine, type_test=False) + max_vals = heatmap.max(dim=2)[0].max(dim=2)[0].max(dim=2)[0] + np.testing.assert_allclose(max_vals.cpu().numpy(), np.ones((2, 1)), rtol=1e-5, atol=1e-5) + if __name__ == "__main__": unittest.main() From 08a715ad36051f9fc04f41358b4f670adeeb7c1d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 21 Sep 2025 07:54:11 +0000 Subject: [PATCH 04/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 2d_mdtest.ipynb | 539 +++++++++++----------- 3d_heatmap_pyvista.ipynb | 470 +++++++++---------- tests/transforms/test_generate_heatmap.py | 1 - 3 files changed, 504 insertions(+), 506 deletions(-) diff --git a/2d_mdtest.ipynb b/2d_mdtest.ipynb index 73cd466375..5235740c81 100644 --- a/2d_mdtest.ipynb +++ b/2d_mdtest.ipynb @@ -1,273 +1,272 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 2D Landmark → Heatmap (MNIST / MedMNIST) — with MONAI\n", - "Interactive demo converting clicked landmarks to Gaussian heatmaps **using MONAI**.\n", - "\n", - "**Modes**\n", - "1) 
Array transform: `GenerateHeatmap`\n", - "2) Dict transform: `GenerateHeatmapd` with optional `MetaTensor` reference" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# Installation requirements for interactive notebook\n", - "# %pip install --upgrade pip\n", - "# %pip install torch torchvision monai medmnist matplotlib ipywidgets\n", - "#\n", - "# For JupyterLab users, also run:\n", - "# %pip install jupyterlab-widgets\n", - "#\n", - "# For interactive matplotlib (optional, used in first implementation):\n", - "# %pip install ipympl" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n", - "import matplotlib.pyplot as plt\n", - "from torchvision import datasets, transforms\n", - "from monai.transforms.post.array import GenerateHeatmap\n", - "from monai.transforms.post.dictionary import GenerateHeatmapd\n", - "from monai.data import MetaTensor\n", - "\n", - "try:\n", - " import medmnist\n", - " from medmnist import PathMNIST\n", - "\n", - " HAS_MEDMNIST = True\n", - "except Exception:\n", - " HAS_MEDMNIST = False\n", - " print(\"medmnist not available. 
Run `pip install medmnist` to enable PathMNIST.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Image shape: 28x28\n" - ] - } - ], - "source": [ - "# Load a small 2D image\n", - "use_medmnist = False\n", - "if use_medmnist and HAS_MEDMNIST:\n", - " ds = PathMNIST(split=\"test\", download=True, as_rgb=True)\n", - " img = np.asarray(ds[0][0]).mean(axis=2).astype(np.float32)\n", - "else:\n", - " mnist = datasets.MNIST(root=\"./data\", train=False, download=True, transform=transforms.ToTensor())\n", - " img = mnist[0][0][0].numpy().astype(np.float32)\n", - "\n", - "if img.max() > 0:\n", - " img = img / float(img.max())\n", - "H, W = img.shape\n", - "print(f\"Image shape: {H}x{W}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Heatmap helper using GenerateHeatmap\n", - "# NOTE: GenerateHeatmap expects points in (y, x) == (row, col) order matching array indexing.\n", - "# This wrapper accepts user-friendly (x, y) and internally reorders to (y, x).\n", - "# It now supports batched inputs.\n", - "\n", - "sigma = 3.0\n", - "\n", - "\n", - "def heatmap_with_array_transform(x, y, sigma_override=None):\n", - " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", - " tr = GenerateHeatmap(sigma=s, spatial_shape=(H, W))\n", - " # Reorder (x,y) -> (y,x) for the transform\n", - " # Support batched and non-batched inputs\n", - " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", - " if pts.ndim == 2:\n", - " pts = pts[np.newaxis, ...] 
# Add batch dimension: (N, 2) -> (1, N, 2)\n", - " pts_yx = pts[..., [1, 0]]\n", - " return tr(pts_yx) # (B, N, H, W)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "affine = torch.eye(4)\n", - "ref_img = MetaTensor(torch.from_numpy(img).unsqueeze(0), affine=affine)\n", - "ref_img.meta[\"spatial_shape\"] = (H, W)\n", - "\n", - "# Dictionary version wrapper also accepts (x,y) and converts to (y,x)\n", - "\n", - "\n", - "def heatmap_with_dict_transform(x, y, sigma_override=None, use_ref=True):\n", - " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", - " tr = GenerateHeatmapd(\n", - " keys=\"points\",\n", - " heatmap_keys=\"heatmap\",\n", - " ref_image_keys=\"ref\" if use_ref else None,\n", - " spatial_shape=None if use_ref else (H, W),\n", - " sigma=s,\n", - " )\n", - " # Support batched and non-batched inputs\n", - " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", - " if pts.ndim == 2:\n", - " pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n", - " pts_yx = pts[..., [1, 0]]\n", - " data = {\"points\": pts_yx, \"ref\": ref_img}\n", - " out = tr(data)\n", - " return out[\"heatmap\"]\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Random points (x, y):\n", - "[[24.567745 18.05713 ]\n", - " [15.933253 2.4389822]\n", - " [21.020788 17.367664 ]]\n" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Simple random landmark → heatmap example (no interactivity)\n", - "# Re-run this cell to sample new random points and regenerate heatmaps.\n", - "# INTERNAL NOTE: GenerateHeatmap consumes (row=y, col=x). We sample (x,y) for user readability and convert.\n", - "\n", - "import random\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Parameters\n", - "num_points = 3 # number of random landmarks\n", - "sigma_demo = 3.0 # Gaussian sigma\n", - "combine_mode = \"max\" # or 'sum'\n", - "batched_input = True # Set to True to test batched input\n", - "\n", - "# Sample random (x,y) points within image bounds (user-friendly)\n", - "points_xy = np.array(\n", - " [[random.uniform(0, W - 1), random.uniform(0, H - 1)] for _ in range(num_points)], dtype=np.float32\n", - ") # (N,2)\n", - "print(\"Random points (x, y):\")\n", - "print(points_xy)\n", - "\n", - "# Convert to (y,x) for the transform\n", - "yx_points = points_xy[:, [1, 0]].copy()\n", - "if batched_input:\n", - " yx_points = yx_points[np.newaxis, ...] 
# Add a batch dimension\n", - "\n", - "array_tr = GenerateHeatmap(sigma=sigma_demo, spatial_shape=(H, W))\n", - "heatmaps = array_tr(yx_points) # now correct orientation\n", - "\n", - "if batched_input:\n", - " heatmaps = heatmaps.squeeze(0) # Remove batch dim for plotting\n", - "\n", - "if combine_mode == \"max\":\n", - " combined = heatmaps.max(axis=0)\n", - "elif combine_mode == \"sum\":\n", - " combined = heatmaps.sum(axis=0)\n", - " if combined.max() > 0:\n", - " combined = combined / combined.max()\n", - "else:\n", - " raise ValueError(\"combine_mode must be 'max' or 'sum'\")\n", - "\n", - "# Plot\n", - "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", - "axes[0].imshow(img, cmap=\"gray\", vmin=0.0, vmax=1.0, origin=\"upper\")\n", - "axes[0].set_title(\"Base Image\")\n", - "axes[0].set_axis_off()\n", - "for x, y in points_xy:\n", - " axes[0].plot(x, y, \"r+\", markersize=12, markeredgewidth=2)\n", - "\n", - "axes[1].imshow(img, cmap=\"gray\", vmin=0.0, vmax=1.0, origin=\"upper\")\n", - "axes[1].imshow(combined, alpha=0.6, cmap=\"hot\", origin=\"upper\")\n", - "axes[1].set_title(f\"Combined Heatmap (mode={combine_mode}, sigma={sigma_demo})\")\n", - "axes[1].set_axis_off()\n", - "for x, y in points_xy:\n", - " axes[1].plot(x, y, \"c+\", markersize=12, markeredgewidth=2)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Individual channels\n", - "fig2, axes2 = plt.subplots(1, num_points, figsize=(4 * num_points, 4))\n", - "if num_points == 1:\n", - " axes2 = [axes2]\n", - "for i, ax in enumerate(axes2):\n", - " ax.imshow(heatmaps[i], cmap=\"hot\", origin=\"upper\")\n", - " ax.plot(points_xy[i, 0], points_xy[i, 1], \"w+\", markersize=12, markeredgewidth=2)\n", - " ax.set_title(f\"Point {i}: (x={points_xy[i,0]:.1f}, y={points_xy[i,1]:.1f})\")\n", - " ax.set_axis_off()\n", - "plt.tight_layout()\n", - "plt.show()\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".conda", - "language": "python", - "name": "python3" - }, - 
"language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.13" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2D Landmark → Heatmap (MNIST / MedMNIST) — with MONAI\n", + "Interactive demo converting clicked landmarks to Gaussian heatmaps **using MONAI**.\n", + "\n", + "**Modes**\n", + "1) Array transform: `GenerateHeatmap`\n", + "2) Dict transform: `GenerateHeatmapd` with optional `MetaTensor` reference" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Installation requirements for interactive notebook\n", + "# %pip install --upgrade pip\n", + "# %pip install torch torchvision monai medmnist matplotlib ipywidgets\n", + "#\n", + "# For JupyterLab users, also run:\n", + "# %pip install jupyterlab-widgets\n", + "#\n", + "# For interactive matplotlib (optional, used in first implementation):\n", + "# %pip install ipympl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from torchvision import datasets, transforms\n", + "from monai.transforms.post.array import GenerateHeatmap\n", + "from monai.transforms.post.dictionary import GenerateHeatmapd\n", + "from monai.data import MetaTensor\n", + "\n", + "try:\n", + " import medmnist\n", + " from medmnist import PathMNIST\n", + "\n", + " HAS_MEDMNIST = True\n", + "except Exception:\n", + " HAS_MEDMNIST = False\n", + " print(\"medmnist not available. 
Run `pip install medmnist` to enable PathMNIST.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image shape: 28x28\n" + ] + } + ], + "source": [ + "# Load a small 2D image\n", + "use_medmnist = False\n", + "if use_medmnist and HAS_MEDMNIST:\n", + " ds = PathMNIST(split=\"test\", download=True, as_rgb=True)\n", + " img = np.asarray(ds[0][0]).mean(axis=2).astype(np.float32)\n", + "else:\n", + " mnist = datasets.MNIST(root=\"./data\", train=False, download=True, transform=transforms.ToTensor())\n", + " img = mnist[0][0][0].numpy().astype(np.float32)\n", + "\n", + "if img.max() > 0:\n", + " img = img / float(img.max())\n", + "H, W = img.shape\n", + "print(f\"Image shape: {H}x{W}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Heatmap helper using GenerateHeatmap\n", + "# NOTE: GenerateHeatmap expects points in (y, x) == (row, col) order matching array indexing.\n", + "# This wrapper accepts user-friendly (x, y) and internally reorders to (y, x).\n", + "# It now supports batched inputs.\n", + "\n", + "sigma = 3.0\n", + "\n", + "\n", + "def heatmap_with_array_transform(x, y, sigma_override=None):\n", + " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", + " tr = GenerateHeatmap(sigma=s, spatial_shape=(H, W))\n", + " # Reorder (x,y) -> (y,x) for the transform\n", + " # Support batched and non-batched inputs\n", + " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", + " if pts.ndim == 2:\n", + " pts = pts[np.newaxis, ...] 
# Add batch dimension: (N, 2) -> (1, N, 2)\n", + " pts_yx = pts[..., [1, 0]]\n", + " return tr(pts_yx) # (B, N, H, W)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "affine = torch.eye(4)\n", + "ref_img = MetaTensor(torch.from_numpy(img).unsqueeze(0), affine=affine)\n", + "ref_img.meta[\"spatial_shape\"] = (H, W)\n", + "\n", + "# Dictionary version wrapper also accepts (x,y) and converts to (y,x)\n", + "\n", + "\n", + "def heatmap_with_dict_transform(x, y, sigma_override=None, use_ref=True):\n", + " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", + " tr = GenerateHeatmapd(\n", + " keys=\"points\",\n", + " heatmap_keys=\"heatmap\",\n", + " ref_image_keys=\"ref\" if use_ref else None,\n", + " spatial_shape=None if use_ref else (H, W),\n", + " sigma=s,\n", + " )\n", + " # Support batched and non-batched inputs\n", + " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", + " if pts.ndim == 2:\n", + " pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n", + " pts_yx = pts[..., [1, 0]]\n", + " data = {\"points\": pts_yx, \"ref\": ref_img}\n", + " out = tr(data)\n", + " return out[\"heatmap\"]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Random points (x, y):\n", + "[[24.567745 18.05713 ]\n", + " [15.933253 2.4389822]\n", + " [21.020788 17.367664 ]]\n" + ] }, - "nbformat": 4, - "nbformat_minor": 2 + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Simple random landmark → heatmap example (no interactivity)\n", + "# Re-run this cell to sample new random points and regenerate heatmaps.\n", + "# INTERNAL NOTE: GenerateHeatmap consumes (row=y, col=x). We sample (x,y) for user readability and convert.\n", + "\n", + "import random\n", + "\n", + "# Parameters\n", + "num_points = 3 # number of random landmarks\n", + "sigma_demo = 3.0 # Gaussian sigma\n", + "combine_mode = \"max\" # or 'sum'\n", + "batched_input = True # Set to True to test batched input\n", + "\n", + "# Sample random (x,y) points within image bounds (user-friendly)\n", + "points_xy = np.array(\n", + " [[random.uniform(0, W - 1), random.uniform(0, H - 1)] for _ in range(num_points)], dtype=np.float32\n", + ") # (N,2)\n", + "print(\"Random points (x, y):\")\n", + "print(points_xy)\n", + "\n", + "# Convert to (y,x) for the transform\n", + "yx_points = points_xy[:, [1, 0]].copy()\n", + "if batched_input:\n", + " yx_points = yx_points[np.newaxis, ...] 
# Add a batch dimension\n", + "\n", + "array_tr = GenerateHeatmap(sigma=sigma_demo, spatial_shape=(H, W))\n", + "heatmaps = array_tr(yx_points) # now correct orientation\n", + "\n", + "if batched_input:\n", + " heatmaps = heatmaps.squeeze(0) # Remove batch dim for plotting\n", + "\n", + "if combine_mode == \"max\":\n", + " combined = heatmaps.max(axis=0)\n", + "elif combine_mode == \"sum\":\n", + " combined = heatmaps.sum(axis=0)\n", + " if combined.max() > 0:\n", + " combined = combined / combined.max()\n", + "else:\n", + " raise ValueError(\"combine_mode must be 'max' or 'sum'\")\n", + "\n", + "# Plot\n", + "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", + "axes[0].imshow(img, cmap=\"gray\", vmin=0.0, vmax=1.0, origin=\"upper\")\n", + "axes[0].set_title(\"Base Image\")\n", + "axes[0].set_axis_off()\n", + "for x, y in points_xy:\n", + " axes[0].plot(x, y, \"r+\", markersize=12, markeredgewidth=2)\n", + "\n", + "axes[1].imshow(img, cmap=\"gray\", vmin=0.0, vmax=1.0, origin=\"upper\")\n", + "axes[1].imshow(combined, alpha=0.6, cmap=\"hot\", origin=\"upper\")\n", + "axes[1].set_title(f\"Combined Heatmap (mode={combine_mode}, sigma={sigma_demo})\")\n", + "axes[1].set_axis_off()\n", + "for x, y in points_xy:\n", + " axes[1].plot(x, y, \"c+\", markersize=12, markeredgewidth=2)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Individual channels\n", + "fig2, axes2 = plt.subplots(1, num_points, figsize=(4 * num_points, 4))\n", + "if num_points == 1:\n", + " axes2 = [axes2]\n", + "for i, ax in enumerate(axes2):\n", + " ax.imshow(heatmaps[i], cmap=\"hot\", origin=\"upper\")\n", + " ax.plot(points_xy[i, 0], points_xy[i, 1], \"w+\", markersize=12, markeredgewidth=2)\n", + " ax.set_title(f\"Point {i}: (x={points_xy[i,0]:.1f}, y={points_xy[i,1]:.1f})\")\n", + " ax.set_axis_off()\n", + "plt.tight_layout()\n", + "plt.show()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".conda", + "language": "python", + "name": "python3" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/3d_heatmap_pyvista.ipynb b/3d_heatmap_pyvista.ipynb index fcd3abdf4c..b7be6075f3 100644 --- a/3d_heatmap_pyvista.ipynb +++ b/3d_heatmap_pyvista.ipynb @@ -1,238 +1,238 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "ec5556c6", - "metadata": {}, - "source": [ - "\n", - "# 3D Volume ??Heatmap (MONAI) ??**PyVista edition**\n", - "Interactive 3D volume + heatmap rendering using **PyVista (trame backend)**. \n", - "- Points are **(x, y, z)** in *voxel index* order (MONAI/ITK-style physical coordinate order). \n", - "- The base volume is shown in grayscale; the Gaussian heatmap is overlaid in a hot colormap.\n", - "\n", - "> If you don't have the deps, install first:\n", - "```bash\n", - "pip install monai torch numpy matplotlib pyvista[trame] medmnist\n", - "```\n" - ] + "cells": [ + { + "cell_type": "markdown", + "id": "ec5556c6", + "metadata": {}, + "source": [ + "\n", + "# 3D Volume ??Heatmap (MONAI) ??**PyVista edition**\n", + "Interactive 3D volume + heatmap rendering using **PyVista (trame backend)**. \n", + "- Points are **(x, y, z)** in *voxel index* order (MONAI/ITK-style physical coordinate order). 
\n", + "- The base volume is shown in grayscale; the Gaussian heatmap is overlaid in a hot colormap.\n", + "\n", + "> If you don't have the deps, install first:\n", + "```bash\n", + "pip install monai torch numpy matplotlib pyvista[trame] medmnist\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4d634d14", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from monai.transforms.post.array import GenerateHeatmap\n", + "from monai.transforms.post.dictionary import GenerateHeatmapd\n", + "from monai.data import MetaTensor\n", + "\n", + "# Optional data source (MedMNIST)\n", + "HAS_MEDMNIST = False\n", + "try:\n", + " from medmnist import OrganMNIST3D\n", + " HAS_MEDMNIST = True\n", + "except Exception:\n", + " print(\"medmnist not available. Fallback to synthetic volume.\")\n", + "\n", + "# PyVista (3D rendering)\n", + "HAS_PYVISTA = True\n", + "try:\n", + " import pyvista as pv\n", + " pv.set_jupyter_backend(\"trame\") # interactive in-notebook UI\n", + "except Exception as e:\n", + " HAS_PYVISTA = False\n", + " print(\"PyVista or trame backend not available:\", e)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0dbf0348", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Volume shape (D, H, W) = (28, 28, 28)\n" + ] + } + ], + "source": [ + "\n", + "# --- Load a 3D test volume ---\n", + "if HAS_MEDMNIST:\n", + " ds = OrganMNIST3D(split=\"test\", download=True, as_rgb=False)\n", + " vol = np.asarray(ds[0][0])\n", + " if vol.ndim == 4: # (C, D, H, W) -> (D, H, W)\n", + " vol = vol[0]\n", + "else:\n", + " # Synthetic 3D blob (D=H=W=28)\n", + " D = H = W = 28\n", + " z, y, x = np.mgrid[0:D, 0:H, 0:W]\n", + " vol = np.exp(-((x-14)**2 + (y-14)**2 + (z-14)**2) / (2*5.0**2)).astype(np.float32)\n", + "\n", + "# Normalize for nicer visualization\n", + "vol = 
vol.astype(np.float32)\n", + "if vol.max() > 0:\n", + " vol = vol / float(vol.max())\n", + "\n", + "D, H, W = vol.shape\n", + "print(\"Volume shape (D, H, W) =\", (D, H, W))\n", + "\n", + "# Reference MetaTensor for dictionary-based transform (identity affine)\n", + "ref = MetaTensor(torch.from_numpy(vol).unsqueeze(0), affine=torch.eye(4))\n", + "ref.meta[\"spatial_shape\"] = (D, H, W)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc664bbf", + "metadata": {}, + "outputs": [], + "source": [ + "# --- Heatmap generators ---\n", + "def heatmap3d_array(points_xyz: np.ndarray, sigma: float = 2.0):\n", + " \"\"\"Return (B, N, D, H, W) heatmap from landmarks specified as (x, y, z).\"\"\"\n", + " if points_xyz.ndim == 2: # (N, 3)\n", + " points_xyz = points_xyz[np.newaxis, ...] # -> (1, N, 3)\n", + " points_zyx = points_xyz[..., [2, 1, 0]].copy() # reorder to (z,y,x) for MONAI\n", + " tr = GenerateHeatmap(sigma=float(sigma), spatial_shape=(D, H, W))\n", + " hm = tr(points_zyx) # (B, N, D, H, W)\n", + " return hm\n", + "\n", + "def heatmap3d_dict(points_xyz: np.ndarray, sigma: float = 2.0, use_ref: bool = True):\n", + " \"\"\"Return (B, N, D, H, W) heatmap from landmarks specified as (x, y, z).\"\"\"\n", + " if points_xyz.ndim == 2: # (N, 3)\n", + " points_xyz = points_xyz[np.newaxis, ...] 
# -> (1, N, 3)\n", + " points_zyx = points_xyz[..., [2, 1, 0]].copy() # reorder to (z,y,x) for MONAI\n", + " tr = GenerateHeatmapd(\n", + " keys=\"points\",\n", + " heatmap_keys=\"heatmap\",\n", + " ref_image_keys=\"ref\" if use_ref else None,\n", + " spatial_shape=None if use_ref else (D, H, W),\n", + " sigma=float(sigma),\n", + " )\n", + " data = {\"points\": points_zyx, \"ref\": ref}\n", + " out = tr(data)\n", + " return out[\"heatmap\"] # Tensor or np.ndarray, shape (B, N, D, H, W)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdcfbb37", + "metadata": {}, + "outputs": [ + { + "ename": "IndexError", + "evalue": "index 14 is out of bounds for axis 0 with size 1", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 11\u001b[39m\n\u001b[32m 9\u001b[39m fig, ax = plt.subplots()\n\u001b[32m 10\u001b[39m ax.imshow(vol[z_idx], cmap=\u001b[33m\"\u001b[39m\u001b[33mgray\u001b[39m\u001b[33m\"\u001b[39m, vmin=\u001b[32m0.0\u001b[39m, vmax=\u001b[32m1.0\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m ax.imshow(\u001b[43mhm\u001b[49m\u001b[43m[\u001b[49m\u001b[43mz_idx\u001b[49m\u001b[43m]\u001b[49m, alpha=\u001b[32m0.6\u001b[39m)\n\u001b[32m 12\u001b[39m ax.set_title(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m2D check @ z=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mz_idx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | 
input(x,y,z)=(\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcx\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcy\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcz\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m) | sigma=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msigma\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 13\u001b[39m ax.set_axis_off()\n", + "\u001b[31mIndexError\u001b[39m: index 14 is out of bounds for axis 0 with size 1" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] }, - { - "cell_type": "code", - "execution_count": 1, - "id": "4d634d14", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "import numpy as np\n", - "import torch\n", - "import matplotlib.pyplot as plt\n", - "\n", - "from monai.transforms.post.array import GenerateHeatmap\n", - "from monai.transforms.post.dictionary import GenerateHeatmapd\n", - "from monai.data import MetaTensor\n", - "\n", - "# Optional data source (MedMNIST)\n", - "HAS_MEDMNIST = False\n", - "try:\n", - " from medmnist import OrganMNIST3D\n", - " HAS_MEDMNIST = True\n", - "except Exception:\n", - " print(\"medmnist not available. Fallback to synthetic volume.\")\n", - "\n", - "# PyVista (3D rendering)\n", - "HAS_PYVISTA = True\n", - "try:\n", - " import pyvista as pv\n", - " pv.set_jupyter_backend(\"trame\") # interactive in-notebook UI\n", - "except Exception as e:\n", - " HAS_PYVISTA = False\n", - " print(\"PyVista or trame backend not available:\", e)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "0dbf0348", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Volume shape (D, H, W) = (28, 28, 28)\n" - ] - } - ], - "source": [ - "\n", - "# --- Load a 3D test volume ---\n", - "if HAS_MEDMNIST:\n", - " ds = OrganMNIST3D(split=\"test\", download=True, as_rgb=False)\n", - " vol = np.asarray(ds[0][0])\n", - " if vol.ndim == 4: # (C, D, H, W) -> (D, H, W)\n", - " vol = vol[0]\n", - "else:\n", - " # Synthetic 3D blob (D=H=W=28)\n", - " D = H = W = 28\n", - " z, y, x = np.mgrid[0:D, 0:H, 0:W]\n", - " vol = np.exp(-((x-14)**2 + (y-14)**2 + (z-14)**2) / (2*5.0**2)).astype(np.float32)\n", - "\n", - "# Normalize for nicer visualization\n", - "vol = vol.astype(np.float32)\n", - "if vol.max() > 0:\n", - " vol = vol / float(vol.max())\n", - "\n", - "D, H, W = vol.shape\n", - "print(\"Volume shape (D, H, W) =\", (D, H, W))\n", - "\n", - "# Reference MetaTensor for dictionary-based transform (identity affine)\n", - "ref = 
MetaTensor(torch.from_numpy(vol).unsqueeze(0), affine=torch.eye(4))\n", - "ref.meta[\"spatial_shape\"] = (D, H, W)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cc664bbf", - "metadata": {}, - "outputs": [], - "source": [ - "# --- Heatmap generators ---\n", - "def heatmap3d_array(points_xyz: np.ndarray, sigma: float = 2.0):\n", - " \"\"\"Return (B, N, D, H, W) heatmap from landmarks specified as (x, y, z).\"\"\"\n", - " if points_xyz.ndim == 2: # (N, 3)\n", - " points_xyz = points_xyz[np.newaxis, ...] # -> (1, N, 3)\n", - " points_zyx = points_xyz[..., [2, 1, 0]].copy() # reorder to (z,y,x) for MONAI\n", - " tr = GenerateHeatmap(sigma=float(sigma), spatial_shape=(D, H, W))\n", - " hm = tr(points_zyx) # (B, N, D, H, W)\n", - " return hm\n", - "\n", - "def heatmap3d_dict(points_xyz: np.ndarray, sigma: float = 2.0, use_ref: bool = True):\n", - " \"\"\"Return (B, N, D, H, W) heatmap from landmarks specified as (x, y, z).\"\"\"\n", - " if points_xyz.ndim == 2: # (N, 3)\n", - " points_xyz = points_xyz[np.newaxis, ...] 
# -> (1, N, 3)\n", - " points_zyx = points_xyz[..., [2, 1, 0]].copy() # reorder to (z,y,x) for MONAI\n", - " tr = GenerateHeatmapd(\n", - " keys=\"points\",\n", - " heatmap_keys=\"heatmap\",\n", - " ref_image_keys=\"ref\" if use_ref else None,\n", - " spatial_shape=None if use_ref else (D, H, W),\n", - " sigma=float(sigma),\n", - " )\n", - " data = {\"points\": points_zyx, \"ref\": ref}\n", - " out = tr(data)\n", - " return out[\"heatmap\"] # Tensor or np.ndarray, shape (B, N, D, H, W)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fdcfbb37", - "metadata": {}, - "outputs": [ - { - "ename": "IndexError", - "evalue": "index 14 is out of bounds for axis 0 with size 1", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 11\u001b[39m\n\u001b[32m 9\u001b[39m fig, ax = plt.subplots()\n\u001b[32m 10\u001b[39m ax.imshow(vol[z_idx], cmap=\u001b[33m\"\u001b[39m\u001b[33mgray\u001b[39m\u001b[33m\"\u001b[39m, vmin=\u001b[32m0.0\u001b[39m, vmax=\u001b[32m1.0\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m ax.imshow(\u001b[43mhm\u001b[49m\u001b[43m[\u001b[49m\u001b[43mz_idx\u001b[49m\u001b[43m]\u001b[49m, alpha=\u001b[32m0.6\u001b[39m)\n\u001b[32m 12\u001b[39m ax.set_title(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m2D check @ z=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mz_idx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | 
input(x,y,z)=(\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcx\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcy\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcz\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m) | sigma=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msigma\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 13\u001b[39m ax.set_axis_off()\n", - "\u001b[31mIndexError\u001b[39m: index 14 is out of bounds for axis 0 with size 1" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# --- Quick 2D sanity check (central slice) ---\n", - "# This confirms that an (x,y,z) input maps to the expected (z,y,x) voxel index.\n", - "points_xyz = np.array([[W // 2, H // 2, D // 2]], dtype=np.float32)\n", - "sigma = 2.0\n", - "\n", - "hm_batch = heatmap3d_array(points_xyz, sigma=sigma)\n", - "hm = hm_batch.squeeze(0).squeeze(0) # (D,H,W)\n", - "z_idx = int(points_xyz[0, 2])\n", - "\n", - "fig, ax = plt.subplots()\n", - "ax.imshow(vol[z_idx], cmap=\"gray\", vmin=0.0, vmax=1.0)\n", - "ax.imshow(hm[z_idx], alpha=0.6, cmap=\"hot\")\n", - "ax.set_title(f\"2D check @ z={z_idx} | input(x,y,z)=({points_xyz[0,0]:.1f},{points_xyz[0,1]:.1f},{points_xyz[0,2]:.1f}) | sigma={sigma:.1f}\")\n", - "ax.set_axis_off()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4f033c71", - "metadata": {}, - "outputs": [], - "source": [ - "# --- PyVista 3D volume rendering ---\n", - "def render_3d_pyvista(points_xyz: np.ndarray, sigma: float = 2.0, use_dict: bool = False):\n", - " if not HAS_PYVISTA:\n", - " raise RuntimeError(\"PyVista backend is not available in this environment.\")\n", - "\n", - " # Build heatmap (B, N, D, H, W)\n", - " if use_dict:\n", - " hm_any = heatmap3d_dict(points_xyz, sigma=sigma, use_ref=True)\n", - " else:\n", - " hm_any = heatmap3d_array(points_xyz, sigma=sigma)\n", - "\n", - " # Combine heatmaps (max projection over channel dimension)\n", - " if hasattr(hm_any, \"cpu\"):\n", - " hm_combined = hm_any.cpu().numpy().max(axis=1).squeeze(0) # (D,H,W)\n", - " else:\n", - " hm_combined = hm_any.max(axis=1).squeeze(0)\n", - "\n", - " print(f\"Rendering 3D volume with {points_xyz.shape[0]} landmark(s) | sigma={sigma:.1f}\")\n", - "\n", - " # Wrap numpy arrays directly; PyVista understands (Z,Y,X) ordering for UniformGrid\n", - " grid_vol = pv.wrap(vol) # (D,H,W)\n", - " grid_hm = pv.wrap(hm_combined) # (D,H,W)\n", - "\n", - " p = 
pv.Plotter()\n", - " # Base volume (grayscale-ish)\n", - " p.add_volume(grid_vol, cmap=\"bone\", opacity=\"sigmoid_6\", shade=True)\n", - " # Heatmap overlay (hot colormap, more transparent)\n", - " p.add_volume(grid_hm, cmap=\"hot\", opacity=\"sigmoid_3\", shade=True)\n", - " p.show()\n", - "\n", - "# Example: try moving the landmark; call this cell repeatedly with new values.\n", - "render_3d_pyvista(np.array([[14.0, 10.0, 18.0], [20.0, 20.0, 10.0]]), sigma=3.0, use_dict=False)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".conda", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# --- Quick 2D sanity check (central slice) ---\n", + "# This confirms that an (x,y,z) input maps to the expected (z,y,x) voxel index.\n", + "points_xyz = np.array([[W // 2, H // 2, D // 2]], dtype=np.float32)\n", + "sigma = 2.0\n", + "\n", + "hm_batch = heatmap3d_array(points_xyz, sigma=sigma)\n", + "hm = hm_batch.squeeze(0).squeeze(0) # (D,H,W)\n", + "z_idx = int(points_xyz[0, 2])\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.imshow(vol[z_idx], cmap=\"gray\", vmin=0.0, vmax=1.0)\n", + "ax.imshow(hm[z_idx], alpha=0.6, cmap=\"hot\")\n", + "ax.set_title(f\"2D check @ z={z_idx} | input(x,y,z)=({points_xyz[0,0]:.1f},{points_xyz[0,1]:.1f},{points_xyz[0,2]:.1f}) | sigma={sigma:.1f}\")\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f033c71", + "metadata": {}, + "outputs": [], + "source": [ + "# --- PyVista 3D volume rendering ---\n", + "def render_3d_pyvista(points_xyz: np.ndarray, sigma: float = 2.0, use_dict: bool = False):\n", + 
" if not HAS_PYVISTA:\n", + " raise RuntimeError(\"PyVista backend is not available in this environment.\")\n", + "\n", + " # Build heatmap (B, N, D, H, W)\n", + " if use_dict:\n", + " hm_any = heatmap3d_dict(points_xyz, sigma=sigma, use_ref=True)\n", + " else:\n", + " hm_any = heatmap3d_array(points_xyz, sigma=sigma)\n", + "\n", + " # Combine heatmaps (max projection over channel dimension)\n", + " if hasattr(hm_any, \"cpu\"):\n", + " hm_combined = hm_any.cpu().numpy().max(axis=1).squeeze(0) # (D,H,W)\n", + " else:\n", + " hm_combined = hm_any.max(axis=1).squeeze(0)\n", + "\n", + " print(f\"Rendering 3D volume with {points_xyz.shape[0]} landmark(s) | sigma={sigma:.1f}\")\n", + "\n", + " # Wrap numpy arrays directly; PyVista understands (Z,Y,X) ordering for UniformGrid\n", + " grid_vol = pv.wrap(vol) # (D,H,W)\n", + " grid_hm = pv.wrap(hm_combined) # (D,H,W)\n", + "\n", + " p = pv.Plotter()\n", + " # Base volume (grayscale-ish)\n", + " p.add_volume(grid_vol, cmap=\"bone\", opacity=\"sigmoid_6\", shade=True)\n", + " # Heatmap overlay (hot colormap, more transparent)\n", + " p.add_volume(grid_hm, cmap=\"hot\", opacity=\"sigmoid_3\", shade=True)\n", + " p.show()\n", + "\n", + "# Example: try moving the landmark; call this cell repeatedly with new values.\n", + "render_3d_pyvista(np.array([[14.0, 10.0, 18.0], [20.0, 20.0, 10.0]]), sigma=3.0, use_dict=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".conda", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py index ccd3e28bdb..85935ba2c1 100644 --- a/tests/transforms/test_generate_heatmap.py +++ 
b/tests/transforms/test_generate_heatmap.py @@ -12,7 +12,6 @@ from __future__ import annotations import unittest -import math import numpy as np import torch From 25ceb7fde7789d5eb92e6d1948163ecaffeb1f54 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sun, 21 Sep 2025 20:32:14 +0900 Subject: [PATCH 05/19] Refactors GenerateHeatmap transforms for clarity Streamlines the GenerateHeatmap and GenerateHeatmapd transforms for better usability and code clarity. Specifically: - Improves the input landmark array validation to provide a more descriptive error message. - Removes example notebooks. DCO Remediation Commit for sewon.jeon I, sewon.jeon , hereby add my Signed-off-by to this commit: 8ef905b36b26f2d84c386ab5359d76a61ca643ac I, sewon.jeon , hereby add my Signed-off-by to this commit: 226bf906c9b62e3ca196ebd29dd81b0eede48b4a I, sewon.jeon , hereby add my Signed-off-by to this commit: 3097baf5204b03872a87f8baa2e6d34411f9f42f I, sewon.jeon , hereby add my Signed-off-by to this commit: 0072cb081dbe805cd4ed569c722f76f9721ab239 Signed-off-by: sewon.jeon --- 2d_mdtest.ipynb | 272 ---------------------- 3d_heatmap_pyvista.ipynb | 238 ------------------- monai/transforms/post/dictionary.py | 2 +- tests/transforms/test_generate_heatmap.py | 1 + 4 files changed, 2 insertions(+), 511 deletions(-) delete mode 100644 2d_mdtest.ipynb delete mode 100644 3d_heatmap_pyvista.ipynb diff --git a/2d_mdtest.ipynb b/2d_mdtest.ipynb deleted file mode 100644 index 5235740c81..0000000000 --- a/2d_mdtest.ipynb +++ /dev/null @@ -1,272 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 2D Landmark → Heatmap (MNIST / MedMNIST) — with MONAI\n", - "Interactive demo converting clicked landmarks to Gaussian heatmaps **using MONAI**.\n", - "\n", - "**Modes**\n", - "1) Array transform: `GenerateHeatmap`\n", - "2) Dict transform: `GenerateHeatmapd` with optional `MetaTensor` reference" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, 
- "outputs": [], - "source": [ - "# Installation requirements for interactive notebook\n", - "# %pip install --upgrade pip\n", - "# %pip install torch torchvision monai medmnist matplotlib ipywidgets\n", - "#\n", - "# For JupyterLab users, also run:\n", - "# %pip install jupyterlab-widgets\n", - "#\n", - "# For interactive matplotlib (optional, used in first implementation):\n", - "# %pip install ipympl" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n", - "import matplotlib.pyplot as plt\n", - "from torchvision import datasets, transforms\n", - "from monai.transforms.post.array import GenerateHeatmap\n", - "from monai.transforms.post.dictionary import GenerateHeatmapd\n", - "from monai.data import MetaTensor\n", - "\n", - "try:\n", - " import medmnist\n", - " from medmnist import PathMNIST\n", - "\n", - " HAS_MEDMNIST = True\n", - "except Exception:\n", - " HAS_MEDMNIST = False\n", - " print(\"medmnist not available. 
Run `pip install medmnist` to enable PathMNIST.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Image shape: 28x28\n" - ] - } - ], - "source": [ - "# Load a small 2D image\n", - "use_medmnist = False\n", - "if use_medmnist and HAS_MEDMNIST:\n", - " ds = PathMNIST(split=\"test\", download=True, as_rgb=True)\n", - " img = np.asarray(ds[0][0]).mean(axis=2).astype(np.float32)\n", - "else:\n", - " mnist = datasets.MNIST(root=\"./data\", train=False, download=True, transform=transforms.ToTensor())\n", - " img = mnist[0][0][0].numpy().astype(np.float32)\n", - "\n", - "if img.max() > 0:\n", - " img = img / float(img.max())\n", - "H, W = img.shape\n", - "print(f\"Image shape: {H}x{W}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Heatmap helper using GenerateHeatmap\n", - "# NOTE: GenerateHeatmap expects points in (y, x) == (row, col) order matching array indexing.\n", - "# This wrapper accepts user-friendly (x, y) and internally reorders to (y, x).\n", - "# It now supports batched inputs.\n", - "\n", - "sigma = 3.0\n", - "\n", - "\n", - "def heatmap_with_array_transform(x, y, sigma_override=None):\n", - " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", - " tr = GenerateHeatmap(sigma=s, spatial_shape=(H, W))\n", - " # Reorder (x,y) -> (y,x) for the transform\n", - " # Support batched and non-batched inputs\n", - " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", - " if pts.ndim == 2:\n", - " pts = pts[np.newaxis, ...] 
# Add batch dimension: (N, 2) -> (1, N, 2)\n", - " pts_yx = pts[..., [1, 0]]\n", - " return tr(pts_yx) # (B, N, H, W)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "affine = torch.eye(4)\n", - "ref_img = MetaTensor(torch.from_numpy(img).unsqueeze(0), affine=affine)\n", - "ref_img.meta[\"spatial_shape\"] = (H, W)\n", - "\n", - "# Dictionary version wrapper also accepts (x,y) and converts to (y,x)\n", - "\n", - "\n", - "def heatmap_with_dict_transform(x, y, sigma_override=None, use_ref=True):\n", - " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", - " tr = GenerateHeatmapd(\n", - " keys=\"points\",\n", - " heatmap_keys=\"heatmap\",\n", - " ref_image_keys=\"ref\" if use_ref else None,\n", - " spatial_shape=None if use_ref else (H, W),\n", - " sigma=s,\n", - " )\n", - " # Support batched and non-batched inputs\n", - " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", - " if pts.ndim == 2:\n", - " pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n", - " pts_yx = pts[..., [1, 0]]\n", - " data = {\"points\": pts_yx, \"ref\": ref_img}\n", - " out = tr(data)\n", - " return out[\"heatmap\"]\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Random points (x, y):\n", - "[[24.567745 18.05713 ]\n", - " [15.933253 2.4389822]\n", - " [21.020788 17.367664 ]]\n" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Simple random landmark → heatmap example (no interactivity)\n", - "# Re-run this cell to sample new random points and regenerate heatmaps.\n", - "# INTERNAL NOTE: GenerateHeatmap consumes (row=y, col=x). We sample (x,y) for user readability and convert.\n", - "\n", - "import random\n", - "\n", - "# Parameters\n", - "num_points = 3 # number of random landmarks\n", - "sigma_demo = 3.0 # Gaussian sigma\n", - "combine_mode = \"max\" # or 'sum'\n", - "batched_input = True # Set to True to test batched input\n", - "\n", - "# Sample random (x,y) points within image bounds (user-friendly)\n", - "points_xy = np.array(\n", - " [[random.uniform(0, W - 1), random.uniform(0, H - 1)] for _ in range(num_points)], dtype=np.float32\n", - ") # (N,2)\n", - "print(\"Random points (x, y):\")\n", - "print(points_xy)\n", - "\n", - "# Convert to (y,x) for the transform\n", - "yx_points = points_xy[:, [1, 0]].copy()\n", - "if batched_input:\n", - " yx_points = yx_points[np.newaxis, ...] 
# Add a batch dimension\n", - "\n", - "array_tr = GenerateHeatmap(sigma=sigma_demo, spatial_shape=(H, W))\n", - "heatmaps = array_tr(yx_points) # now correct orientation\n", - "\n", - "if batched_input:\n", - " heatmaps = heatmaps.squeeze(0) # Remove batch dim for plotting\n", - "\n", - "if combine_mode == \"max\":\n", - " combined = heatmaps.max(axis=0)\n", - "elif combine_mode == \"sum\":\n", - " combined = heatmaps.sum(axis=0)\n", - " if combined.max() > 0:\n", - " combined = combined / combined.max()\n", - "else:\n", - " raise ValueError(\"combine_mode must be 'max' or 'sum'\")\n", - "\n", - "# Plot\n", - "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", - "axes[0].imshow(img, cmap=\"gray\", vmin=0.0, vmax=1.0, origin=\"upper\")\n", - "axes[0].set_title(\"Base Image\")\n", - "axes[0].set_axis_off()\n", - "for x, y in points_xy:\n", - " axes[0].plot(x, y, \"r+\", markersize=12, markeredgewidth=2)\n", - "\n", - "axes[1].imshow(img, cmap=\"gray\", vmin=0.0, vmax=1.0, origin=\"upper\")\n", - "axes[1].imshow(combined, alpha=0.6, cmap=\"hot\", origin=\"upper\")\n", - "axes[1].set_title(f\"Combined Heatmap (mode={combine_mode}, sigma={sigma_demo})\")\n", - "axes[1].set_axis_off()\n", - "for x, y in points_xy:\n", - " axes[1].plot(x, y, \"c+\", markersize=12, markeredgewidth=2)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Individual channels\n", - "fig2, axes2 = plt.subplots(1, num_points, figsize=(4 * num_points, 4))\n", - "if num_points == 1:\n", - " axes2 = [axes2]\n", - "for i, ax in enumerate(axes2):\n", - " ax.imshow(heatmaps[i], cmap=\"hot\", origin=\"upper\")\n", - " ax.plot(points_xy[i, 0], points_xy[i, 1], \"w+\", markersize=12, markeredgewidth=2)\n", - " ax.set_title(f\"Point {i}: (x={points_xy[i,0]:.1f}, y={points_xy[i,1]:.1f})\")\n", - " ax.set_axis_off()\n", - "plt.tight_layout()\n", - "plt.show()\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".conda", - "language": "python", - "name": "python3" - }, - 
"language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/3d_heatmap_pyvista.ipynb b/3d_heatmap_pyvista.ipynb deleted file mode 100644 index b7be6075f3..0000000000 --- a/3d_heatmap_pyvista.ipynb +++ /dev/null @@ -1,238 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "ec5556c6", - "metadata": {}, - "source": [ - "\n", - "# 3D Volume ??Heatmap (MONAI) ??**PyVista edition**\n", - "Interactive 3D volume + heatmap rendering using **PyVista (trame backend)**. \n", - "- Points are **(x, y, z)** in *voxel index* order (MONAI/ITK-style physical coordinate order). \n", - "- The base volume is shown in grayscale; the Gaussian heatmap is overlaid in a hot colormap.\n", - "\n", - "> If you don't have the deps, install first:\n", - "```bash\n", - "pip install monai torch numpy matplotlib pyvista[trame] medmnist\n", - "```\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "4d634d14", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "import numpy as np\n", - "import torch\n", - "import matplotlib.pyplot as plt\n", - "\n", - "from monai.transforms.post.array import GenerateHeatmap\n", - "from monai.transforms.post.dictionary import GenerateHeatmapd\n", - "from monai.data import MetaTensor\n", - "\n", - "# Optional data source (MedMNIST)\n", - "HAS_MEDMNIST = False\n", - "try:\n", - " from medmnist import OrganMNIST3D\n", - " HAS_MEDMNIST = True\n", - "except Exception:\n", - " print(\"medmnist not available. 
Fallback to synthetic volume.\")\n", - "\n", - "# PyVista (3D rendering)\n", - "HAS_PYVISTA = True\n", - "try:\n", - " import pyvista as pv\n", - " pv.set_jupyter_backend(\"trame\") # interactive in-notebook UI\n", - "except Exception as e:\n", - " HAS_PYVISTA = False\n", - " print(\"PyVista or trame backend not available:\", e)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "0dbf0348", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Volume shape (D, H, W) = (28, 28, 28)\n" - ] - } - ], - "source": [ - "\n", - "# --- Load a 3D test volume ---\n", - "if HAS_MEDMNIST:\n", - " ds = OrganMNIST3D(split=\"test\", download=True, as_rgb=False)\n", - " vol = np.asarray(ds[0][0])\n", - " if vol.ndim == 4: # (C, D, H, W) -> (D, H, W)\n", - " vol = vol[0]\n", - "else:\n", - " # Synthetic 3D blob (D=H=W=28)\n", - " D = H = W = 28\n", - " z, y, x = np.mgrid[0:D, 0:H, 0:W]\n", - " vol = np.exp(-((x-14)**2 + (y-14)**2 + (z-14)**2) / (2*5.0**2)).astype(np.float32)\n", - "\n", - "# Normalize for nicer visualization\n", - "vol = vol.astype(np.float32)\n", - "if vol.max() > 0:\n", - " vol = vol / float(vol.max())\n", - "\n", - "D, H, W = vol.shape\n", - "print(\"Volume shape (D, H, W) =\", (D, H, W))\n", - "\n", - "# Reference MetaTensor for dictionary-based transform (identity affine)\n", - "ref = MetaTensor(torch.from_numpy(vol).unsqueeze(0), affine=torch.eye(4))\n", - "ref.meta[\"spatial_shape\"] = (D, H, W)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cc664bbf", - "metadata": {}, - "outputs": [], - "source": [ - "# --- Heatmap generators ---\n", - "def heatmap3d_array(points_xyz: np.ndarray, sigma: float = 2.0):\n", - " \"\"\"Return (B, N, D, H, W) heatmap from landmarks specified as (x, y, z).\"\"\"\n", - " if points_xyz.ndim == 2: # (N, 3)\n", - " points_xyz = points_xyz[np.newaxis, ...] 
# -> (1, N, 3)\n", - " points_zyx = points_xyz[..., [2, 1, 0]].copy() # reorder to (z,y,x) for MONAI\n", - " tr = GenerateHeatmap(sigma=float(sigma), spatial_shape=(D, H, W))\n", - " hm = tr(points_zyx) # (B, N, D, H, W)\n", - " return hm\n", - "\n", - "def heatmap3d_dict(points_xyz: np.ndarray, sigma: float = 2.0, use_ref: bool = True):\n", - " \"\"\"Return (B, N, D, H, W) heatmap from landmarks specified as (x, y, z).\"\"\"\n", - " if points_xyz.ndim == 2: # (N, 3)\n", - " points_xyz = points_xyz[np.newaxis, ...] # -> (1, N, 3)\n", - " points_zyx = points_xyz[..., [2, 1, 0]].copy() # reorder to (z,y,x) for MONAI\n", - " tr = GenerateHeatmapd(\n", - " keys=\"points\",\n", - " heatmap_keys=\"heatmap\",\n", - " ref_image_keys=\"ref\" if use_ref else None,\n", - " spatial_shape=None if use_ref else (D, H, W),\n", - " sigma=float(sigma),\n", - " )\n", - " data = {\"points\": points_zyx, \"ref\": ref}\n", - " out = tr(data)\n", - " return out[\"heatmap\"] # Tensor or np.ndarray, shape (B, N, D, H, W)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fdcfbb37", - "metadata": {}, - "outputs": [ - { - "ename": "IndexError", - "evalue": "index 14 is out of bounds for axis 0 with size 1", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 11\u001b[39m\n\u001b[32m 9\u001b[39m fig, ax = plt.subplots()\n\u001b[32m 10\u001b[39m ax.imshow(vol[z_idx], cmap=\u001b[33m\"\u001b[39m\u001b[33mgray\u001b[39m\u001b[33m\"\u001b[39m, vmin=\u001b[32m0.0\u001b[39m, vmax=\u001b[32m1.0\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m ax.imshow(\u001b[43mhm\u001b[49m\u001b[43m[\u001b[49m\u001b[43mz_idx\u001b[49m\u001b[43m]\u001b[49m, alpha=\u001b[32m0.6\u001b[39m)\n\u001b[32m 12\u001b[39m 
ax.set_title(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m2D check @ z=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mz_idx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m | input(x,y,z)=(\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcx\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcy\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcz\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m) | sigma=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msigma\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.1f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 13\u001b[39m ax.set_axis_off()\n", - "\u001b[31mIndexError\u001b[39m: index 14 is out of bounds for axis 0 with size 1" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# --- Quick 2D sanity check (central slice) ---\n", - "# This confirms that an (x,y,z) input maps to the expected (z,y,x) voxel index.\n", - "points_xyz = np.array([[W // 2, H // 2, D // 2]], dtype=np.float32)\n", - "sigma = 2.0\n", - "\n", - "hm_batch = heatmap3d_array(points_xyz, sigma=sigma)\n", - "hm = hm_batch.squeeze(0).squeeze(0) # (D,H,W)\n", - "z_idx = int(points_xyz[0, 2])\n", - "\n", - "fig, ax = plt.subplots()\n", - "ax.imshow(vol[z_idx], cmap=\"gray\", vmin=0.0, vmax=1.0)\n", - "ax.imshow(hm[z_idx], alpha=0.6, cmap=\"hot\")\n", - "ax.set_title(f\"2D check @ z={z_idx} | input(x,y,z)=({points_xyz[0,0]:.1f},{points_xyz[0,1]:.1f},{points_xyz[0,2]:.1f}) | sigma={sigma:.1f}\")\n", - "ax.set_axis_off()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4f033c71", - "metadata": {}, - "outputs": [], - "source": [ - "# --- PyVista 3D volume rendering ---\n", - "def render_3d_pyvista(points_xyz: np.ndarray, sigma: float = 2.0, use_dict: bool = False):\n", - " if not HAS_PYVISTA:\n", - " raise RuntimeError(\"PyVista backend is not available in this environment.\")\n", - "\n", - " # Build heatmap (B, N, D, H, W)\n", - " if use_dict:\n", - " hm_any = heatmap3d_dict(points_xyz, sigma=sigma, use_ref=True)\n", - " else:\n", - " hm_any = heatmap3d_array(points_xyz, sigma=sigma)\n", - "\n", - " # Combine heatmaps (max projection over channel dimension)\n", - " if hasattr(hm_any, \"cpu\"):\n", - " hm_combined = hm_any.cpu().numpy().max(axis=1).squeeze(0) # (D,H,W)\n", - " else:\n", - " hm_combined = hm_any.max(axis=1).squeeze(0)\n", - "\n", - " print(f\"Rendering 3D volume with {points_xyz.shape[0]} landmark(s) | sigma={sigma:.1f}\")\n", - "\n", - " # Wrap numpy arrays directly; PyVista understands (Z,Y,X) ordering for UniformGrid\n", - " grid_vol = pv.wrap(vol) # (D,H,W)\n", - " grid_hm = pv.wrap(hm_combined) # (D,H,W)\n", - "\n", - " p = 
pv.Plotter()\n", - " # Base volume (grayscale-ish)\n", - " p.add_volume(grid_vol, cmap=\"bone\", opacity=\"sigmoid_6\", shade=True)\n", - " # Heatmap overlay (hot colormap, more transparent)\n", - " p.add_volume(grid_hm, cmap=\"hot\", opacity=\"sigmoid_3\", shade=True)\n", - " p.show()\n", - "\n", - "# Example: try moving the landmark; call this cell repeatedly with new values.\n", - "render_3d_pyvista(np.array([[14.0, 10.0, 18.0], [20.0, 20.0, 10.0]]), sigma=3.0, use_dict=False)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".conda", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 813804dcc6..bd650040e9 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -610,7 +610,7 @@ def _determine_shape( return static_shape points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) if points_t.ndim not in (2, 3): - raise ValueError("landmark arrays must be 2D or 3D with shape (num_points, spatial_dims) or (B, num_points, spatial_dims).") + raise ValueError("landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D).") spatial_dims = int(points_t.shape[-1]) if ref_key is not None and ref_key in data: return self._shape_from_reference(data[ref_key], spatial_dims) diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py index 85935ba2c1..8c6f158186 100644 --- a/tests/transforms/test_generate_heatmap.py +++ b/tests/transforms/test_generate_heatmap.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest + import numpy as np import torch From 
9e33e7c3d18cad1cda24f57e818d87225641cac7 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sun, 21 Sep 2025 21:28:35 +0900 Subject: [PATCH 06/19] rename parameter Signed-off-by: sewon.jeon --- monai/transforms/post/array.py | 12 ++++++------ monai/transforms/post/dictionary.py | 24 ++++++++++++++++++----- tests/transforms/test_generate_heatmap.py | 19 ++++++++++++++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index f710120293..5e90254148 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -757,7 +757,7 @@ class GenerateHeatmap(Transform): Args: sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions. spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform. - truncate: extent, in multiples of ``sigma``, used to crop the gaussian support window. + truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window. normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``. dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes). 
@@ -772,7 +772,7 @@ def __init__( self, sigma: Sequence[float] | float = 5.0, spatial_shape: Sequence[int] | None = None, - truncate: float = 3.0, + truncated: float = 4.0, normalize: bool = True, dtype: np.dtype | torch.dtype | type = np.float32, ) -> None: @@ -784,9 +784,9 @@ def __init__( if float(sigma) <= 0: raise ValueError("sigma must be positive.") self._sigma = float(sigma) - if truncate <= 0: - raise ValueError("truncate must be positive.") - self.truncate = float(truncate) + if truncated <= 0: + raise ValueError("truncated must be positive.") + self.truncated = float(truncated) self.normalize = normalize self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor) self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray) @@ -816,7 +816,7 @@ def __call__( target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims) sigma = self._resolve_sigma(spatial_dims) - radius = tuple(int(np.ceil(self.truncate * s)) for s in sigma) + radius = tuple(int(np.ceil(self.truncated * s)) for s in sigma) heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device) image_bounds = tuple(int(s) for s in target_shape) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index bd650040e9..bd006ef648 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -528,7 +528,7 @@ def __init__( heatmap_keys: KeysCollection | None = None, ref_image_keys: KeysCollection | None = None, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None, - truncate: float = 3.0, + truncated: float = 4.0, normalize: bool = True, dtype: np.dtype | type = np.float32, allow_missing_keys: bool = False, @@ -540,7 +540,7 @@ def __init__( self.generator = GenerateHeatmap( sigma=sigma, spatial_shape=None, - truncate=truncate, + truncated=truncated, normalize=normalize, dtype=dtype, ) @@ -632,11 +632,25 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> 
tuple[int, def _prepare_output(self, heatmap: NdarrayOrTensor, reference: Any) -> Any: if isinstance(reference, MetaTensor): - converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=reference.dtype, device=reference.device) - converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) + # Use heatmap's dtype (from generator), not reference's dtype + converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=heatmap.dtype, device=reference.device) + # For batched data shape is (B, C, *spatial), for non-batched it's (C, *spatial) + if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D) + converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:]) + elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D) + # Need to check if this is batched 2D or non-batched 3D + if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])): + # Non-batched 3D + converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) + else: + # Batched 2D + converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:]) + else: # 2D non-batched: (C, H, W) + converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) return converted if isinstance(reference, torch.Tensor): - converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=reference.dtype, device=reference.device) + # Use heatmap's dtype (from generator), not reference's dtype + converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=heatmap.dtype, device=reference.device) return converted return heatmap diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py index 8c6f158186..6caf9a89df 100644 --- a/tests/transforms/test_generate_heatmap.py +++ b/tests/transforms/test_generate_heatmap.py @@ -223,6 +223,25 @@ def test_dict_batched_with_ref(self): max_vals = heatmap.max(dim=2)[0].max(dim=2)[0].max(dim=2)[0] np.testing.assert_allclose(max_vals.cpu().numpy(), 
np.ones((2, 1)), rtol=1e-5, atol=1e-5) + def test_truncated_parameter(self): + # Test that truncated parameter correctly controls window size + pt = np.array([[8.0, 8.0]], dtype=np.float32) + sigma = 2.0 + + # Test with different truncated values + small_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=2.0)(pt)[0] + default_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=4.0)(pt)[0] # default + large_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=6.0)(pt)[0] + + # Larger truncated should capture more of the gaussian, resulting in slightly higher total sum + self.assertLess(small_truncated.sum(), default_truncated.sum()) + self.assertLess(default_truncated.sum(), large_truncated.sum()) + + # All should have same peak value (normalized to 1.0) + np.testing.assert_allclose(small_truncated.max(), 1.0, rtol=1e-5) + np.testing.assert_allclose(default_truncated.max(), 1.0, rtol=1e-5) + np.testing.assert_allclose(large_truncated.max(), 1.0, rtol=1e-5) + if __name__ == "__main__": unittest.main() From 62831e60c2f4d9cbb4cf287fb5a272ab9025f7cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 21 Sep 2025 12:30:07 +0000 Subject: [PATCH 07/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/transforms/test_generate_heatmap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py index 6caf9a89df..31aff5a89a 100644 --- a/tests/transforms/test_generate_heatmap.py +++ b/tests/transforms/test_generate_heatmap.py @@ -227,16 +227,16 @@ def test_truncated_parameter(self): # Test that truncated parameter correctly controls window size pt = np.array([[8.0, 8.0]], dtype=np.float32) sigma = 2.0 - + # Test with different truncated values small_truncated = 
GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=2.0)(pt)[0] default_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=4.0)(pt)[0] # default large_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=6.0)(pt)[0] - + # Larger truncated should capture more of the gaussian, resulting in slightly higher total sum self.assertLess(small_truncated.sum(), default_truncated.sum()) self.assertLess(default_truncated.sum(), large_truncated.sum()) - + # All should have same peak value (normalized to 1.0) np.testing.assert_allclose(small_truncated.max(), 1.0, rtol=1e-5) np.testing.assert_allclose(default_truncated.max(), 1.0, rtol=1e-5) From 15ec97a1c6e3fb36e3e1e6ae62d77bd11fd7a5af Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sun, 21 Sep 2025 22:16:14 +0900 Subject: [PATCH 08/19] fix formatting Signed-off-by: sewon.jeon --- monai/transforms/post/array.py | 12 +----- monai/transforms/post/dictionary.py | 12 +----- tests/transforms/test_generate_heatmap.py | 45 +++-------------------- 3 files changed, 10 insertions(+), 59 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 5e90254148..aacd9abd1e 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -792,11 +792,7 @@ def __init__( self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray) self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape) - def __call__( - self, - points: NdarrayOrTensor, - spatial_shape: Sequence[int] | None = None, - ) -> NdarrayOrTensor: + def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor: original_points = points points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) @@ -871,11 +867,7 @@ def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool: return all(0 <= c < size for c, size in zip(center, bounds)) def _make_window( - self, - center: 
Sequence[float], - radius: tuple[int, ...], - bounds: tuple[int, ...], - device: torch.device, + self, center: Sequence[float], radius: tuple[int, ...], bounds: tuple[int, ...], device: torch.device ) -> tuple[tuple[slice, ...] | None, tuple[torch.Tensor, ...]]: slices: list[slice] = [] coord_shifts: list[torch.Tensor] = [] diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index bd006ef648..7665646451 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -538,11 +538,7 @@ def __init__( self.ref_image_keys = self._prepare_optional_keys(ref_image_keys) self.static_shapes = self._prepare_shapes(spatial_shape) self.generator = GenerateHeatmap( - sigma=sigma, - spatial_shape=None, - truncated=truncated, - normalize=normalize, - dtype=dtype, + sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype ) def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: @@ -600,11 +596,7 @@ def _prepare_shapes( return tuple(prepared) def _determine_shape( - self, - points: Any, - static_shape: tuple[int, ...] | None, - data: Mapping[Hashable, Any], - ref_key: Hashable | None, + self, points: Any, static_shape: tuple[int, ...] 
| None, data: Mapping[Hashable, Any], ref_key: Hashable | None ) -> tuple[int, ...]: if static_shape is not None: return static_shape diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py index 31aff5a89a..994c7fcfc2 100644 --- a/tests/transforms/test_generate_heatmap.py +++ b/tests/transforms/test_generate_heatmap.py @@ -72,14 +72,7 @@ def test_array_torch_device_and_dtype_propagation(self): def test_array_channel_order_identity(self): # ensure the order of channels follows the order of input points - pts = np.array( - [ - [2.0, 2.0], # point A - [12.0, 2.0], # point B - [2.0, 12.0], # point C - ], - dtype=np.float32, - ) + pts = np.array([[2.0, 2.0], [12.0, 2.0], [2.0, 12.0]], dtype=np.float32) # point A # point B # point C hm = GenerateHeatmap(sigma=1.2, spatial_shape=(16, 16))(pts) self.assertEqual(hm.shape, (3, 16, 16)) @@ -90,11 +83,7 @@ def test_array_channel_order_identity(self): def test_array_points_out_of_bounds(self): # points outside spatial domain: heatmap should still be valid (no NaN/Inf) and not all-zeros pts = np.array( - [ - [-5.0, -5.0], # outside top-left - [100.0, 100.0], # outside bottom-right - [8.0, 8.0], # inside - ], + [[-5.0, -5.0], [100.0, 100.0], [8.0, 8.0]], # outside top-left # outside bottom-right # inside dtype=np.float32, ) hm = GenerateHeatmap(sigma=2.0, spatial_shape=(16, 16))(pts) @@ -118,12 +107,7 @@ def test_dict_with_reference_meta(self): image.meta["spatial_shape"] = (8, 8, 8) data = {"points": points, "image": image} - transform = GenerateHeatmapd( - keys="points", - heatmap_keys="heatmap", - ref_image_keys="image", - sigma=2.0, - ) + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", sigma=2.0) result = transform(data) heatmap = result["heatmap"] @@ -172,13 +156,7 @@ def test_dict_dtype_control(self): self.assertEqual(hm.dtype, torch.float16) def test_array_batched_3d(self): - points = np.array( - [ - [[4.2, 7.8, 1.0]], # Batch 1 - 
[[12.3, 3.6, 2.0]], # Batch 2 - ], - dtype=np.float32, - ) + points = np.array([[[4.2, 7.8, 1.0]], [[12.3, 3.6, 2.0]]], dtype=np.float32) # Batch 1 # Batch 2 transform = GenerateHeatmap(sigma=1.5, spatial_shape=(16, 16, 16)) heatmap = transform(points) @@ -193,25 +171,14 @@ def test_array_batched_3d(self): self.assertTrue(np.all(np.abs(peak - points[i, 0]) <= 1.0), msg=f"peak={peak}, point={points[i, 0]}") def test_dict_batched_with_ref(self): - points = torch.tensor( - [ - [[1.5, 2.5, 3.5]], # Batch 1 - [[4.5, 5.5, 6.5]], # Batch 2 - ], - dtype=torch.float32, - ) + points = torch.tensor([[[1.5, 2.5, 3.5]], [[4.5, 5.5, 6.5]]], dtype=torch.float32) # Batch 1 # Batch 2 affine = torch.eye(4) # A single reference image is used for the whole batch image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) image.meta["spatial_shape"] = (8, 8, 8) data = {"points": points, "image": image} - transform = GenerateHeatmapd( - keys="points", - heatmap_keys="heatmap", - ref_image_keys="image", - sigma=1.0, - ) + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", sigma=1.0) result = transform(data) heatmap = result["heatmap"] From 5bc7993fe2894f1c61b69d4f95293f9e7c762b32 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sun, 21 Sep 2025 22:27:35 +0900 Subject: [PATCH 09/19] fix flake8 Signed-off-by: sewon.jeon --- monai/transforms/post/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index aacd9abd1e..9b5b942a36 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -871,7 +871,7 @@ def _make_window( ) -> tuple[tuple[slice, ...] 
| None, tuple[torch.Tensor, ...]]: slices: list[slice] = [] coord_shifts: list[torch.Tensor] = [] - for dim, (c, r, size) in enumerate(zip(center, radius, bounds)): + for _dim, (c, r, size) in enumerate(zip(center, radius, bounds)): start = max(int(np.floor(c - r)), 0) stop = min(int(np.ceil(c + r)) + 1, size) if start >= stop: From 0e907bb29355257bcb0f7fcea9dfa5d54eae658a Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sun, 21 Sep 2025 23:05:45 +0900 Subject: [PATCH 10/19] fix meta tensor problem Signed-off-by: sewon.jeon --- monai/transforms/post/array.py | 16 ++++----- monai/transforms/post/dictionary.py | 50 +++++++++++++++-------------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 9b5b942a36..8d227f5359 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -783,7 +783,7 @@ def __init__( else: if float(sigma) <= 0: raise ValueError("sigma must be positive.") - self._sigma = float(sigma) + self._sigma = (float(sigma),) if truncated <= 0: raise ValueError("truncated must be positive.") self.truncated = float(truncated) @@ -826,7 +826,7 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device) if window_slices is None: continue - region = heatmap[(b_idx, idx, *window_slices)] + region = heatmap[b_idx, idx][window_slices] gaussian = self._evaluate_gaussian(coord_shifts, sigma) torch.maximum(region, gaussian, out=region) if self.normalize: @@ -854,13 +854,11 @@ def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: return tuple(int(s) for s in shape_tuple) def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]: - if isinstance(self._sigma, tuple): - if len(self._sigma) == spatial_dims: - return self._sigma - if len(self._sigma) == 1: - return self._sigma * spatial_dims - raise ValueError("sigma 
sequence length must equal the number of spatial dimensions.") - return (self._sigma,) * spatial_dims + if len(self._sigma) == spatial_dims: + return self._sigma + if len(self._sigma) == 1: + return self._sigma * spatial_dims + raise ValueError("sigma sequence length must equal the number of spatial dimensions.") @staticmethod def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool: diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 7665646451..635afea088 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -548,9 +548,19 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: ): points = d[key] shape = self._determine_shape(points, static_shape, d, ref_key) + # The GenerateHeatmap transform will handle type conversion based on input points heatmap = self.generator(points, spatial_shape=shape) + # If there's a reference image and we need to match its type/device reference = d.get(ref_key) if ref_key is not None and ref_key in d else None - d[out_key] = self._prepare_output(heatmap, reference) + if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)): + # Convert to match reference type and device while preserving heatmap's dtype + heatmap, _, _ = convert_to_dst_type( + heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, "device", None) + ) + # Copy metadata if reference is MetaTensor + if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor): + self._update_spatial_metadata(heatmap, reference) + d[out_key] = heatmap return d def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]: @@ -622,29 +632,21 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, return tuple(int(v) for v in reference.shape[-spatial_dims:]) raise ValueError("Reference data must define a shape attribute.") - def _prepare_output(self, heatmap: 
NdarrayOrTensor, reference: Any) -> Any: - if isinstance(reference, MetaTensor): - # Use heatmap's dtype (from generator), not reference's dtype - converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=heatmap.dtype, device=reference.device) - # For batched data shape is (B, C, *spatial), for non-batched it's (C, *spatial) - if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D) - converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:]) - elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D) - # Need to check if this is batched 2D or non-batched 3D - if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])): - # Non-batched 3D - converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) - else: - # Batched 2D - converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:]) - else: # 2D non-batched: (C, H, W) - converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) - return converted - if isinstance(reference, torch.Tensor): - # Use heatmap's dtype (from generator), not reference's dtype - converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=heatmap.dtype, device=reference.device) - return converted - return heatmap + def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None: + """Update spatial metadata of heatmap based on its dimensions.""" + # Update spatial_shape metadata based on heatmap dimensions + if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D) + heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:]) + elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D) + # Need to check if this is batched 2D or non-batched 3D + if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])): + # Non-batched 3D + heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) + else: + # Batched 2D + heatmap.meta["spatial_shape"] = tuple(int(v) for v in 
heatmap.shape[2:]) + else: # 2D non-batched: (C, H, W) + heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd From 54a81a5d74c70a9e68737bd293b51ed2635b272e Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sun, 21 Sep 2025 23:26:32 +0900 Subject: [PATCH 11/19] better unit tests Signed-off-by: sewon.jeon --- tests/transforms/test_generate_heatmap.py | 360 +++++++++++++-------- tests/transforms/test_generate_heatmapd.py | 232 +++++++++++++ 2 files changed, 461 insertions(+), 131 deletions(-) create mode 100644 tests/transforms/test_generate_heatmapd.py diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py index 994c7fcfc2..4875aa31c1 100644 --- a/tests/transforms/test_generate_heatmap.py +++ b/tests/transforms/test_generate_heatmap.py @@ -1,6 +1,6 @@ # Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); -# You may not use this file except in compliance with the License. +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software @@ -15,65 +15,188 @@ import numpy as np import torch +from parameterized import parameterized -from monai.data import MetaTensor from monai.transforms.post.array import GenerateHeatmap -from monai.transforms.post.dictionary import GenerateHeatmapd -from tests.test_utils import assert_allclose +from tests.test_utils import TEST_NDARRAYS -def _argmax_nd(x: np.ndarray) -> np.ndarray: +def _argmax_nd(x) -> np.ndarray: """argmax for N-D array → returns coordinate vector (z,y,x) or (y,x).""" + if isinstance(x, torch.Tensor): + x = x.cpu().numpy() return np.asarray(np.unravel_index(np.argmax(x), x.shape)) -class TestGenerateHeatmap(unittest.TestCase): - def test_array_2d(self): - points = np.array([[4.2, 7.8], [12.3, 3.6]], dtype=np.float32) - transform = GenerateHeatmap(sigma=1.5, spatial_shape=(16, 16)) +# Test cases for 2D array inputs with different data types +TEST_CASES_2D = [] +for idx, p in enumerate(TEST_NDARRAYS): + TEST_CASES_2D.append( + [ + f"2d_basic_type{idx}", + p(np.array([[4.2, 7.8], [12.3, 3.6]], dtype=np.float32)), + {"sigma": 1.5, "spatial_shape": (16, 16)}, + (2, 16, 16), + ] + ) + +# Test cases for 3D torch outputs with explicit dtype +TEST_CASES_3D_TORCH = [] +for dtype in [torch.float32, torch.float64]: + TEST_CASES_3D_TORCH.append( + [ + f"3d_torch_{str(dtype).replace('torch.', '')}", + torch.tensor([[1.5, 2.5, 3.5]], dtype=torch.float32), + {"sigma": 1.0, "spatial_shape": (8, 8, 8), "dtype": dtype}, + (1, 8, 8, 8), + dtype, + ] + ) + +# Test cases for 3D numpy outputs with explicit dtype +TEST_CASES_3D_NUMPY = [] +for dtype in [np.float32, np.float64]: + TEST_CASES_3D_NUMPY.append( + [ + f"3d_numpy_{dtype.__name__}", + np.array([[1.5, 2.5, 3.5]], dtype=np.float32), + {"sigma": 1.0, "spatial_shape": (8, 8, 8), "dtype": dtype}, + (1, 8, 8, 8), + dtype, + ] + ) + +# Test cases for different 
sigma values +TEST_CASES_SIGMA = [] +for sigma in [0.5, 1.0, 2.0, 3.0]: + TEST_CASES_SIGMA.append( + [ + f"sigma_{sigma}", + np.array([[8.0, 8.0]], dtype=np.float32), + {"sigma": sigma, "spatial_shape": (16, 16)}, + (1, 16, 16), + ] + ) + +# Test cases for truncated parameter +TEST_CASES_TRUNCATED = [] +for truncated in [2.0, 4.0, 6.0]: + TEST_CASES_TRUNCATED.append( + [ + f"truncated_{truncated}", + np.array([[8.0, 8.0]], dtype=np.float32), + {"sigma": 2.0, "spatial_shape": (32, 32), "truncated": truncated}, + (1, 32, 32), + ] + ) + +# Test cases for batched 3D with different array types +TEST_CASES_BATCHED = [] +for idx, p in enumerate(TEST_NDARRAYS): + TEST_CASES_BATCHED.append( + [ + f"batched_3d_type{idx}", + p(np.array([[[4.2, 7.8, 1.0]], [[12.3, 3.6, 2.0]]], dtype=np.float32)), + {"sigma": 1.5, "spatial_shape": (16, 16, 16)}, + (2, 1, 16, 16, 16), + ] + ) + +# Test cases for device and dtype propagation (torch only) +TEST_CASES_DEVICE_DTYPE = [] +if torch.cuda.is_available(): + for dtype in [torch.float16, torch.float32, torch.float64]: + TEST_CASES_DEVICE_DTYPE.append( + [ + f"cuda_{str(dtype).replace('torch.', '')}", + torch.tensor([[3.0, 4.0, 5.0]], dtype=torch.float32, device="cuda:0"), + {"sigma": 1.2, "spatial_shape": (10, 10, 10), "dtype": dtype}, + (1, 10, 10, 10), + dtype, + "cuda:0", + ] + ) +else: + for dtype in [torch.float32, torch.float64]: + TEST_CASES_DEVICE_DTYPE.append( + [ + f"cpu_{str(dtype).replace('torch.', '')}", + torch.tensor([[3.0, 4.0, 5.0]], dtype=torch.float32), + {"sigma": 1.2, "spatial_shape": (10, 10, 10), "dtype": dtype}, + (1, 10, 10, 10), + dtype, + "cpu", + ] + ) + +class TestGenerateHeatmap(unittest.TestCase): + @parameterized.expand(TEST_CASES_2D) + def test_array_2d(self, _, points, params, expected_shape): + transform = GenerateHeatmap(**params) heatmap = transform(points) - self.assertEqual(heatmap.shape, (2, 16, 16)) - self.assertEqual(heatmap.dtype, np.float32) - np.testing.assert_allclose(heatmap.max(axis=(1, 2)), 
np.ones(2), rtol=1e-5, atol=1e-5) + # Check output type matches input type + if isinstance(points, torch.Tensor): + self.assertIsInstance(heatmap, torch.Tensor) + self.assertEqual(heatmap.dtype, torch.float32) # Default dtype for torch + heatmap_np = heatmap.cpu().numpy() + points_np = points.cpu().numpy() + else: + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.dtype, np.float32) # Default dtype for numpy + heatmap_np = heatmap + points_np = points + + self.assertEqual(heatmap.shape, expected_shape) + np.testing.assert_allclose(heatmap_np.max(axis=(1, 2)), np.ones(expected_shape[0]), rtol=1e-5, atol=1e-5) # peak should be close to original point location (<= 1px tolerance due to discretization) - for idx, channel in enumerate(heatmap): - peak = _argmax_nd(channel) - self.assertTrue(np.all(np.abs(peak - points[idx]) <= 1.0), msg=f"peak={peak}, point={points[idx]}") - self.assertLess(channel[0, 0], 1e-3) - - def test_array_3d_torch_output(self): - points = torch.tensor([[1.5, 2.5, 3.5]], dtype=torch.float32) - transform = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=torch.float32) - - heatmap = transform(points.to(device=points.device)) + for idx in range(expected_shape[0]): + peak = _argmax_nd(heatmap_np[idx]) + self.assertTrue(np.all(np.abs(peak - points_np[idx]) <= 1.0), msg=f"peak={peak}, point={points_np[idx]}") + self.assertLess(heatmap_np[idx, 0, 0], 1e-3) + + @parameterized.expand(TEST_CASES_3D_TORCH) + def test_array_3d_torch_output(self, _, points, params, expected_shape, expected_dtype): + transform = GenerateHeatmap(**params) + heatmap = transform(points) self.assertIsInstance(heatmap, torch.Tensor) self.assertEqual(heatmap.device, points.device) - self.assertEqual(tuple(heatmap.shape), (1, 8, 8, 8)) + self.assertEqual(tuple(heatmap.shape), expected_shape) + self.assertEqual(heatmap.dtype, expected_dtype) self.assertTrue(torch.isclose(heatmap.max(), torch.tensor(1.0, dtype=heatmap.dtype, device=heatmap.device))) - def 
test_array_torch_device_and_dtype_propagation(self): - # verify dtype parameter honored and CUDA (if available) - dtype = torch.float16 if torch.cuda.is_available() else torch.float32 - device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") - - pts = torch.tensor([[3.0, 4.0, 5.0]], dtype=torch.float32, device=device) - tr = GenerateHeatmap(sigma=1.2, spatial_shape=(10, 10, 10), dtype=dtype) + @parameterized.expand(TEST_CASES_3D_NUMPY) + def test_array_3d_numpy_output(self, _, points, params, expected_shape, expected_dtype): + transform = GenerateHeatmap(**params) + heatmap = transform(points) + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.shape, expected_shape) + self.assertEqual(heatmap.dtype, expected_dtype) + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + @parameterized.expand(TEST_CASES_DEVICE_DTYPE) + def test_array_torch_device_and_dtype_propagation( + self, _, pts, params, expected_shape, expected_dtype, expected_device + ): + tr = GenerateHeatmap(**params) hm = tr(pts) + self.assertIsInstance(hm, torch.Tensor) - self.assertEqual(hm.device, device) - self.assertEqual(hm.dtype, dtype) - self.assertEqual(tuple(hm.shape), (1, 10, 10, 10)) + self.assertEqual(str(hm.device).split(":")[0], expected_device.split(":")[0]) + self.assertEqual(hm.dtype, expected_dtype) + self.assertEqual(tuple(hm.shape), expected_shape) self.assertTrue(torch.all(hm >= 0)) def test_array_channel_order_identity(self): # ensure the order of channels follows the order of input points pts = np.array([[2.0, 2.0], [12.0, 2.0], [2.0, 12.0]], dtype=np.float32) # point A # point B # point C hm = GenerateHeatmap(sigma=1.2, spatial_shape=(16, 16))(pts) + + self.assertIsInstance(hm, np.ndarray) self.assertEqual(hm.shape, (3, 16, 16)) peaks = np.vstack([_argmax_nd(hm[i]) for i in range(3)]) @@ -87,51 +210,24 @@ def test_array_points_out_of_bounds(self): dtype=np.float32, ) hm = GenerateHeatmap(sigma=2.0, spatial_shape=(16, 
16))(pts) + + self.assertIsInstance(hm, np.ndarray) self.assertEqual(hm.shape, (3, 16, 16)) self.assertFalse(np.isnan(hm).any() or np.isinf(hm).any()) # inside point channel should have max≈1; others may clip at border (≤1) self.assertGreater(hm[2].max(), 0.9) - def test_array_sigma_scaling_effect(self): - # Larger sigma should spread mass (lower peak), smaller sigma higher peak - pt = np.array([[8.0, 8.0]], dtype=np.float32) - small = GenerateHeatmap(sigma=0.8, spatial_shape=(16, 16))(pt)[0] - large = GenerateHeatmap(sigma=2.5, spatial_shape=(16, 16))(pt)[0] - self.assertGreater(small.max(), large.max() - 1e-6) # small sigma peak >= large sigma peak - - def test_dict_with_reference_meta(self): - points = np.array([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=np.float32) - affine = torch.eye(4) - image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) - image.meta["spatial_shape"] = (8, 8, 8) - data = {"points": points, "image": image} - - transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", sigma=2.0) - - result = transform(data) - heatmap = result["heatmap"] - - self.assertIsInstance(heatmap, MetaTensor) - self.assertEqual(tuple(heatmap.shape), (2, 8, 8, 8)) - self.assertEqual(heatmap.meta["spatial_shape"], (8, 8, 8)) - assert_allclose(heatmap.affine, image.affine, type_test=False) - np.testing.assert_allclose(heatmap.cpu().numpy().max(axis=(1, 2, 3)), np.ones(2), rtol=1e-5, atol=1e-5) - - def test_dict_static_shape(self): - points = np.array([[1.0, 1.0]], dtype=np.float32) - transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", spatial_shape=(6, 6)) - - result = transform({"points": points}) - heatmap = result["heatmap"] - self.assertIsInstance(heatmap, np.ndarray) - self.assertEqual(heatmap.shape, (1, 6, 6)) + @parameterized.expand(TEST_CASES_SIGMA) + def test_array_sigma_scaling_effect(self, _, pt, params, expected_shape): + heatmap = GenerateHeatmap(**params)(pt)[0] + 
self.assertEqual(heatmap.shape, expected_shape[1:]) + + # All should have peak normalized to 1.0 + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) - def test_dict_missing_shape_raises(self): - # Without ref image or explicit spatial_shape, must raise - transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap") - with self.assertRaises(ValueError): - transform({"points": np.zeros((1, 2), dtype=np.float32)}) + # Verify heatmap is valid + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) def test_invalid_points_shape_raises(self): # points must be (N, D) with D in {2,3} @@ -142,72 +238,74 @@ def test_invalid_points_shape_raises(self): with self.assertRaises((ValueError, AssertionError, IndexError, RuntimeError)): tr(np.zeros((2, 4), dtype=np.float32)) # D=4 unsupported - def test_dict_dtype_control(self): - # Ensure dtype argument controls output dtype for dictionary transform too - points = np.array([[2.0, 3.0, 4.0]], dtype=np.float32) - ref = MetaTensor(torch.zeros((1, 10, 10, 10), dtype=torch.float32), affine=torch.eye(4)) - d = {"pts": points, "img": ref} - - tr = GenerateHeatmapd(keys="pts", heatmap_keys="hm", ref_image_keys="img", sigma=1.4, dtype=torch.float16) - out = tr(d) - hm = out["hm"] - self.assertIsInstance(hm, MetaTensor) - self.assertEqual(tuple(hm.shape), (1, 10, 10, 10)) - self.assertEqual(hm.dtype, torch.float16) - - def test_array_batched_3d(self): - points = np.array([[[4.2, 7.8, 1.0]], [[12.3, 3.6, 2.0]]], dtype=np.float32) # Batch 1 # Batch 2 - transform = GenerateHeatmap(sigma=1.5, spatial_shape=(16, 16, 16)) - + @parameterized.expand(TEST_CASES_BATCHED) + def test_array_batched_3d(self, _, points, params, expected_shape): + transform = GenerateHeatmap(**params) heatmap = transform(points) - self.assertEqual(heatmap.shape, (2, 1, 16, 16, 16)) - self.assertEqual(heatmap.dtype, np.float32) - np.testing.assert_allclose(heatmap.max(axis=(2, 3, 4)), np.ones((2, 1)), rtol=1e-5, atol=1e-5) + # Check output type 
matches input type + if isinstance(points, torch.Tensor): + self.assertIsInstance(heatmap, torch.Tensor) + heatmap_np = heatmap.cpu().numpy() + points_np = points.cpu().numpy() + else: + self.assertIsInstance(heatmap, np.ndarray) + heatmap_np = heatmap + points_np = points + + self.assertEqual(heatmap.shape, expected_shape) + np.testing.assert_allclose( + heatmap_np.max(axis=(2, 3, 4)), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5 + ) # Check peaks for each batch item - for i in range(2): - peak = _argmax_nd(heatmap[i, 0]) - self.assertTrue(np.all(np.abs(peak - points[i, 0]) <= 1.0), msg=f"peak={peak}, point={points[i, 0]}") - - def test_dict_batched_with_ref(self): - points = torch.tensor([[[1.5, 2.5, 3.5]], [[4.5, 5.5, 6.5]]], dtype=torch.float32) # Batch 1 # Batch 2 - affine = torch.eye(4) - # A single reference image is used for the whole batch - image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) - image.meta["spatial_shape"] = (8, 8, 8) - data = {"points": points, "image": image} - - transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", sigma=1.0) - - result = transform(data) - heatmap = result["heatmap"] - - self.assertIsInstance(heatmap, MetaTensor) - self.assertEqual(tuple(heatmap.shape), (2, 1, 8, 8, 8)) - self.assertEqual(heatmap.meta["spatial_shape"], (8, 8, 8)) - assert_allclose(heatmap.affine, image.affine, type_test=False) - max_vals = heatmap.max(dim=2)[0].max(dim=2)[0].max(dim=2)[0] - np.testing.assert_allclose(max_vals.cpu().numpy(), np.ones((2, 1)), rtol=1e-5, atol=1e-5) - - def test_truncated_parameter(self): - # Test that truncated parameter correctly controls window size - pt = np.array([[8.0, 8.0]], dtype=np.float32) - sigma = 2.0 - - # Test with different truncated values - small_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=2.0)(pt)[0] - default_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), 
truncated=4.0)(pt)[0] # default - large_truncated = GenerateHeatmap(sigma=sigma, spatial_shape=(32, 32), truncated=6.0)(pt)[0] - - # Larger truncated should capture more of the gaussian, resulting in slightly higher total sum - self.assertLess(small_truncated.sum(), default_truncated.sum()) - self.assertLess(default_truncated.sum(), large_truncated.sum()) + for i in range(expected_shape[0]): + peak = _argmax_nd(heatmap_np[i, 0]) + self.assertTrue(np.all(np.abs(peak - points_np[i, 0]) <= 1.0), msg=f"peak={peak}, point={points_np[i, 0]}") + + @parameterized.expand(TEST_CASES_TRUNCATED) + def test_truncated_parameter(self, _, pt, params, expected_shape): + heatmap = GenerateHeatmap(**params)(pt)[0] # All should have same peak value (normalized to 1.0) - np.testing.assert_allclose(small_truncated.max(), 1.0, rtol=1e-5) - np.testing.assert_allclose(default_truncated.max(), 1.0, rtol=1e-5) - np.testing.assert_allclose(large_truncated.max(), 1.0, rtol=1e-5) + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + # Verify shape and no NaN/Inf + self.assertEqual(heatmap.shape, expected_shape[1:]) + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + def test_torch_to_torch_type_preservation(self): + """Test that torch input produces torch output""" + pts = torch.tensor([[4.0, 4.0]], dtype=torch.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8))(pts) + + self.assertIsInstance(hm, torch.Tensor) + self.assertEqual(hm.dtype, torch.float32) + self.assertEqual(hm.device, pts.device) + + def test_numpy_to_numpy_type_preservation(self): + """Test that numpy input produces numpy output""" + pts = np.array([[4.0, 4.0]], dtype=np.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8))(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.dtype, np.float32) + + def test_dtype_override_torch(self): + """Test dtype parameter works with torch tensors""" + pts = torch.tensor([[4.0, 4.0, 4.0]], dtype=torch.float32) + hm = 
GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=torch.float64)(pts) + + self.assertIsInstance(hm, torch.Tensor) + self.assertEqual(hm.dtype, torch.float64) + + def test_dtype_override_numpy(self): + """Test dtype parameter works with numpy arrays""" + pts = np.array([[4.0, 4.0, 4.0]], dtype=np.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=np.float64)(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.dtype, np.float64) if __name__ == "__main__": diff --git a/tests/transforms/test_generate_heatmapd.py b/tests/transforms/test_generate_heatmapd.py new file mode 100644 index 0000000000..c32a1d1f2d --- /dev/null +++ b/tests/transforms/test_generate_heatmapd.py @@ -0,0 +1,232 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.post.dictionary import GenerateHeatmapd +from tests.test_utils import TEST_NDARRAYS, assert_allclose + +# Test cases for dictionary transforms with reference image +# Only test with non-MetaTensor types to avoid affine conflicts +TEST_CASES_WITH_REF = [] +TEST_CASES_WITH_REF.append( + [ + "dict_with_ref_3d_numpy", + np.array([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=np.float32), + {"sigma": 2.0}, + (2, 8, 8, 8), + torch.float32, + True, # uses reference image + ] +) +TEST_CASES_WITH_REF.append( + [ + "dict_with_ref_3d_torch", + torch.tensor([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=torch.float32), + {"sigma": 2.0}, + (2, 8, 8, 8), + torch.float32, + True, # uses reference image + ] +) + +# Test cases for dictionary transforms with static spatial shape +TEST_CASES_STATIC_SHAPE = [] +for shape in [(6, 6), (8, 8, 8), (10, 10, 10)]: + TEST_CASES_STATIC_SHAPE.append( + [ + f"dict_static_shape_{len(shape)}d", + np.array([[1.0] * len(shape)], dtype=np.float32), + {"spatial_shape": shape}, + (1,) + shape, + np.float32, + ] + ) + +# Test cases for dtype control +TEST_CASES_DTYPE = [] +for dtype in [torch.float16, torch.float32, torch.float64]: + TEST_CASES_DTYPE.append( + [ + f"dict_dtype_{str(dtype).replace('torch.', '')}", + np.array([[2.0, 3.0, 4.0]], dtype=np.float32), + {"sigma": 1.4, "dtype": dtype}, + (1, 10, 10, 10), + dtype, + ] + ) + +# Test cases for batched dictionary transforms +TEST_CASES_BATCHED = [] +TEST_CASES_BATCHED.append( + [ + "dict_batched_3d", + torch.tensor([[[1.5, 2.5, 3.5]], [[4.5, 5.5, 6.5]]], dtype=torch.float32), + {"sigma": 1.0}, + (2, 1, 8, 8, 8), + torch.float32, + ] +) + +# Test cases for various sigma values +TEST_CASES_SIGMA_VALUES = [] +for sigma in [0.5, 1.0, 2.0, 3.0]: + TEST_CASES_SIGMA_VALUES.append( + [ + f"dict_sigma_{sigma}", + 
np.array([[4.0, 4.0, 4.0]], dtype=np.float32), + {"sigma": sigma, "spatial_shape": (8, 8, 8)}, + (1, 8, 8, 8), + ] + ) + + +class TestGenerateHeatmapd(unittest.TestCase): + @parameterized.expand(TEST_CASES_WITH_REF) + def test_dict_with_reference_meta(self, _, points, params, expected_shape, expected_dtype, uses_ref): + affine = torch.eye(4) + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + data = {"points": points, "image": image} + + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", **params) + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), expected_shape) + self.assertEqual(heatmap.meta["spatial_shape"], (8, 8, 8)) + # The heatmap should inherit the reference image's affine + assert_allclose(heatmap.affine, image.affine, type_test=False) + + # Check max values are normalized to 1.0 + max_vals = heatmap.cpu().numpy().max(axis=tuple(range(1, len(expected_shape)))) + np.testing.assert_allclose(max_vals, np.ones(expected_shape[0]), rtol=1e-5, atol=1e-5) + + @parameterized.expand(TEST_CASES_STATIC_SHAPE) + def test_dict_static_shape(self, _, points, params, expected_shape, expected_dtype): + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", **params) + result = transform({"points": points}) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.shape, expected_shape) + self.assertEqual(heatmap.dtype, expected_dtype) + + def test_dict_missing_shape_raises(self): + # Without ref image or explicit spatial_shape, must raise + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap") + with self.assertRaises(ValueError): + transform({"points": np.zeros((1, 2), dtype=np.float32)}) + + @parameterized.expand(TEST_CASES_DTYPE) + def test_dict_dtype_control(self, _, points, params, 
expected_shape, expected_dtype): + ref = MetaTensor(torch.zeros((1, 10, 10, 10), dtype=torch.float32), affine=torch.eye(4)) + d = {"pts": points, "img": ref} + + tr = GenerateHeatmapd(keys="pts", heatmap_keys="hm", ref_image_keys="img", **params) + out = tr(d) + hm = out["hm"] + + self.assertIsInstance(hm, MetaTensor) + self.assertEqual(tuple(hm.shape), expected_shape) + self.assertEqual(hm.dtype, expected_dtype) + + @parameterized.expand(TEST_CASES_BATCHED) + def test_dict_batched_with_ref(self, _, points, params, expected_shape, expected_dtype): + affine = torch.eye(4) + # A single reference image is used for the whole batch + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + data = {"points": points, "image": image} + + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", **params) + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), expected_shape) + self.assertEqual(heatmap.meta["spatial_shape"], (8, 8, 8)) + assert_allclose(heatmap.affine, image.affine, type_test=False) + + # Check max values + max_vals = heatmap.max(dim=2)[0].max(dim=2)[0].max(dim=2)[0] + np.testing.assert_allclose( + max_vals.cpu().numpy(), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5 + ) + + @parameterized.expand(TEST_CASES_SIGMA_VALUES) + def test_dict_various_sigma(self, _, points, params, expected_shape): + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", **params) + result = transform({"points": points}) + heatmap = result["heatmap"] + + self.assertEqual(heatmap.shape, expected_shape) + # Verify heatmap is normalized + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + # Verify no NaN or Inf + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + def test_dict_multiple_keys(self): + """Test dictionary transform 
with multiple input/output keys""" + points1 = np.array([[2.0, 2.0]], dtype=np.float32) + points2 = np.array([[4.0, 4.0]], dtype=np.float32) + + data = {"pts1": points1, "pts2": points2} + transform = GenerateHeatmapd( + keys=["pts1", "pts2"], heatmap_keys=["hm1", "hm2"], spatial_shape=(8, 8), sigma=1.0 + ) + + result = transform(data) + + self.assertIn("hm1", result) + self.assertIn("hm2", result) + self.assertEqual(result["hm1"].shape, (1, 8, 8)) + self.assertEqual(result["hm2"].shape, (1, 8, 8)) + + # Verify peaks are at different locations + self.assertNotEqual(np.argmax(result["hm1"]), np.argmax(result["hm2"])) + + def test_metatensor_points_with_ref(self): + """Test MetaTensor points with reference image - documents current behavior""" + from monai.data import MetaTensor + + # Create MetaTensor points with non-identity affine + points_affine = torch.tensor([[2.0, 0, 0, 0], [0, 2.0, 0, 0], [0, 0, 2.0, 0], [0, 0, 0, 1.0]]) + points = MetaTensor(torch.tensor([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=torch.float32), affine=points_affine) + + # Reference image with identity affine + ref_affine = torch.eye(4) + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=ref_affine) + image.meta["spatial_shape"] = (8, 8, 8) + + data = {"points": points, "image": image} + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", sigma=2.0) + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), (2, 8, 8, 8)) + + # Note: Currently the heatmap may inherit affine from points MetaTensor + # This test documents the current behavior + # Ideally, the heatmap should use the reference image's affine + + +if __name__ == "__main__": + unittest.main() From 4b367ab9e95e2b91bb7e651f9d808f0f0c8d2793 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 21 Sep 2025 14:27:49 +0000 
Subject: [PATCH 12/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/transforms/test_generate_heatmapd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transforms/test_generate_heatmapd.py b/tests/transforms/test_generate_heatmapd.py index c32a1d1f2d..035529983e 100644 --- a/tests/transforms/test_generate_heatmapd.py +++ b/tests/transforms/test_generate_heatmapd.py @@ -19,7 +19,7 @@ from monai.data import MetaTensor from monai.transforms.post.dictionary import GenerateHeatmapd -from tests.test_utils import TEST_NDARRAYS, assert_allclose +from tests.test_utils import assert_allclose # Test cases for dictionary transforms with reference image # Only test with non-MetaTensor types to avoid affine conflicts From eafe59a57eb9b1a242290bdf888318b027c6daee Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sun, 21 Sep 2025 23:40:02 +0900 Subject: [PATCH 13/19] fix test error Signed-off-by: sewon.jeon --- tests/transforms/test_generate_heatmap.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py index 4875aa31c1..73fedb796c 100644 --- a/tests/transforms/test_generate_heatmap.py +++ b/tests/transforms/test_generate_heatmap.py @@ -55,14 +55,14 @@ def _argmax_nd(x) -> np.ndarray: # Test cases for 3D numpy outputs with explicit dtype TEST_CASES_3D_NUMPY = [] -for dtype in [np.float32, np.float64]: +for dtype_obj in [np.float32, np.float64]: TEST_CASES_3D_NUMPY.append( [ - f"3d_numpy_{dtype.__name__}", + f"3d_numpy_{dtype_obj.__name__}", np.array([[1.5, 2.5, 3.5]], dtype=np.float32), - {"sigma": 1.0, "spatial_shape": (8, 8, 8), "dtype": dtype}, + {"sigma": 1.0, "spatial_shape": (8, 8, 8), "dtype": dtype_obj}, (1, 8, 8, 8), - dtype, + dtype_obj, ] ) From aaf283349127db9a3357624fddbe92344081e6a0 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Wed, 24 Sep 2025 19:29:59 +0900 Subject: 
[PATCH 14/19] address the code review. Signed-off-by: sewon.jeon --- monai/transforms/post/array.py | 18 ++++++++-- monai/transforms/post/dictionary.py | 42 ++++++++++++---------- tests/transforms/test_generate_heatmapd.py | 4 +-- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 8d227f5359..755f11caa0 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -16,6 +16,7 @@ import warnings from collections.abc import Callable, Iterable, Sequence +from typing import ClassVar import numpy as np import torch @@ -766,7 +767,7 @@ class GenerateHeatmap(Transform): """ - backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + backend: ClassVar[list] = [TransformBackends.NUMPY, TransformBackends.TORCH] def __init__( self, @@ -862,7 +863,10 @@ def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]: @staticmethod def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool: - return all(0 <= c < size for c, size in zip(center, bounds)) + for c, size in zip(center, bounds): + if not (0 <= c < size): + return False + return True def _make_window( self, center: Sequence[float], radius: tuple[int, ...], bounds: tuple[int, ...], device: torch.device @@ -879,6 +883,16 @@ def _make_window( return tuple(slices), tuple(coord_shifts) def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor: + """ + Evaluate Gaussian at given coordinate shifts with specified sigmas. + + Args: + coord_shifts: Per-dimension coordinate offsets from center. + sigma: Per-dimension standard deviations. + + Returns: + Gaussian values at the specified coordinates. 
+ """ device = coord_shifts[0].device shape = tuple(len(axis) for axis in coord_shifts) if 0 in shape: diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 635afea088..5fed494028 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -521,6 +521,14 @@ class GenerateHeatmapd(MapTransform): backend = GenerateHeatmap.backend + # Error messages + _ERR_HEATMAP_KEYS_LEN = "heatmap_keys length must match keys length." + _ERR_REF_KEYS_LEN = "ref_image_keys length must match keys length when provided." + _ERR_SHAPE_LEN = "spatial_shape length must match keys length when providing per-key shapes." + _ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys." + _ERR_INVALID_POINTS = "landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)." + _ERR_REF_NO_SHAPE = "Reference data must define a shape attribute." + def __init__( self, keys: KeysCollection, @@ -570,7 +578,7 @@ def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Ha if len(keys_tuple) == 1 and len(self.keys) > 1: keys_tuple = keys_tuple * len(self.keys) if len(keys_tuple) != len(self.keys): - raise ValueError("heatmap_keys length must match keys length.") + raise ValueError(self._ERR_HEATMAP_KEYS_LEN) return keys_tuple def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]: @@ -580,7 +588,7 @@ def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Has if len(keys_tuple) == 1 and len(self.keys) > 1: keys_tuple = keys_tuple * len(self.keys) if len(keys_tuple) != len(self.keys): - raise ValueError("ref_image_keys length must match keys length when provided.") + raise ValueError(self._ERR_REF_KEYS_LEN) return tuple(keys_tuple) def _prepare_shapes( @@ -595,7 +603,7 @@ def _prepare_shapes( if len(shape_tuple) == 1 and len(self.keys) > 1: shape_tuple = shape_tuple * len(self.keys) if 
len(shape_tuple) != len(self.keys): - raise ValueError("spatial_shape length must match keys length when providing per-key shapes.") + raise ValueError(self._ERR_SHAPE_LEN) prepared: list[tuple[int, ...] | None] = [] for item in shape_tuple: if item is None: @@ -612,13 +620,11 @@ def _determine_shape( return static_shape points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) if points_t.ndim not in (2, 3): - raise ValueError("landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D).") + raise ValueError(self._ERR_INVALID_POINTS) spatial_dims = int(points_t.shape[-1]) if ref_key is not None and ref_key in data: return self._shape_from_reference(data[ref_key], spatial_dims) - raise ValueError( - "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys." - ) + raise ValueError(self._ERR_NO_SHAPE) def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, ...]: if isinstance(reference, MetaTensor): @@ -630,23 +636,23 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, return tuple(int(v) for v in reference.shape[-spatial_dims:]) if hasattr(reference, "shape"): return tuple(int(v) for v in reference.shape[-spatial_dims:]) - raise ValueError("Reference data must define a shape attribute.") + raise ValueError(self._ERR_REF_NO_SHAPE) def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None: """Update spatial metadata of heatmap based on its dimensions.""" - # Update spatial_shape metadata based on heatmap dimensions + # Determine if batched based on reference's batch dimension + ref_spatial_shape = reference.meta.get("spatial_shape", []) + ref_is_batched = len(reference.shape) > len(ref_spatial_shape) + 1 + if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D) - heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:]) + spatial_shape = heatmap.shape[2:] elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D 
non-batched (C, H, W, D) - # Need to check if this is batched 2D or non-batched 3D - if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])): - # Non-batched 3D - heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) - else: - # Batched 2D - heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:]) + # Disambiguate: 2D batched vs 3D non-batched + spatial_shape = heatmap.shape[2:] if ref_is_batched else heatmap.shape[1:] else: # 2D non-batched: (C, H, W) - heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:]) + spatial_shape = heatmap.shape[1:] + + heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape) GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd diff --git a/tests/transforms/test_generate_heatmapd.py b/tests/transforms/test_generate_heatmapd.py index 035529983e..fc8a2b8e55 100644 --- a/tests/transforms/test_generate_heatmapd.py +++ b/tests/transforms/test_generate_heatmapd.py @@ -98,7 +98,7 @@ class TestGenerateHeatmapd(unittest.TestCase): @parameterized.expand(TEST_CASES_WITH_REF) - def test_dict_with_reference_meta(self, _, points, params, expected_shape, expected_dtype, uses_ref): + def test_dict_with_reference_meta(self, _, points, params, expected_shape, *_unused): affine = torch.eye(4) image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) image.meta["spatial_shape"] = (8, 8, 8) @@ -148,7 +148,7 @@ def test_dict_dtype_control(self, _, points, params, expected_shape, expected_dt self.assertEqual(hm.dtype, expected_dtype) @parameterized.expand(TEST_CASES_BATCHED) - def test_dict_batched_with_ref(self, _, points, params, expected_shape, expected_dtype): + def test_dict_batched_with_ref(self, _, points, params, expected_shape, _expected_dtype): affine = torch.eye(4) # A single reference image is used for the whole batch image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) From 
1bf0850d208922fa68ada015ff04a8a226f49344 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Thu, 25 Sep 2025 13:43:14 +0900 Subject: [PATCH 15/19] fixes Signed-off-by: sewon.jeon --- monai/transforms/post/array.py | 33 +++++++++++++++------- monai/transforms/post/dictionary.py | 24 ++++++++-------- tests/transforms/test_generate_heatmapd.py | 5 ++-- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 755f11caa0..833344ac6f 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -753,7 +753,14 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO class GenerateHeatmap(Transform): """ - Generate per-landmark gaussian response maps for 2D or 3D coordinates. + Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates. + + Notes: + - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D. + - Output shape: + - Non-batched points (N, D): (N, H, W[, D]) + - Batched points (B, N, D): (B, N, H, W[, D]) + - Each channel corresponds to one landmark. Args: sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions. 
@@ -829,11 +836,13 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None continue region = heatmap[b_idx, idx][window_slices] gaussian = self._evaluate_gaussian(coord_shifts, sigma) - torch.maximum(region, gaussian, out=region) + updated = torch.maximum(region, gaussian) + # write back + region.copy_(updated) if self.normalize: - max_val = heatmap[b_idx, idx].max() - if max_val.item() > 0: - heatmap[b_idx, idx] /= max_val + peak = updated.max() + if peak.item() > 0: + heatmap[b_idx, idx] /= peak if not is_batched: heatmap = heatmap.squeeze(0) @@ -851,7 +860,9 @@ def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: if len(shape_tuple) == 1: shape_tuple = shape_tuple * spatial_dims # type: ignore else: - raise ValueError("spatial_shape length must match spatial dimension of the landmarks.") + raise ValueError( + "spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)." + ) return tuple(int(s) for s in shape_tuple) def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]: @@ -879,7 +890,7 @@ def _make_window( if start >= stop: return None, () slices.append(slice(start, stop)) - coord_shifts.append(torch.arange(start, stop, device=device, dtype=self.torch_dtype) - float(c)) + coord_shifts.append(torch.arange(start, stop, device=device, dtype=torch.float32) - float(c)) return tuple(slices), tuple(coord_shifts) def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor: @@ -897,13 +908,15 @@ def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tupl shape = tuple(len(axis) for axis in coord_shifts) if 0 in shape: return torch.zeros(shape, dtype=self.torch_dtype, device=device) - exponent = torch.zeros(shape, dtype=self.torch_dtype, device=device) + exponent = torch.zeros(shape, dtype=torch.float32, device=device) for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)): - scaled = (shift / 
float(sig)) ** 2 + shift32 = shift.to(torch.float32) + scaled = (shift32 / float(sig)) ** 2 reshape_shape = [1] * len(coord_shifts) reshape_shape[dim] = shift.numel() exponent += scaled.reshape(reshape_shape) - return torch.exp(-0.5 * exponent) + gauss = torch.exp(-0.5 * exponent) + return gauss.to(dtype=self.torch_dtype) class ProbNMS(Transform): diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 5fed494028..7c8a2b15ac 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -517,6 +517,13 @@ class GenerateHeatmapd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`. Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image. + + Notes: + - Default heatmap_keys are generated as "{key}_heatmap" for each input key + - Shape inference precedence: static spatial_shape > ref_image + - Output shapes: + - Non-batched points (N, D): (N, H, W[, D]) + - Batched points (B, N, D): (B, N, H, W[, D]) """ backend = GenerateHeatmap.backend @@ -538,7 +545,7 @@ def __init__( spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None, truncated: float = 4.0, normalize: bool = True, - dtype: np.dtype | type = np.float32, + dtype: np.dtype | torch.dtype | type = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -567,6 +574,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: ) # Copy metadata if reference is MetaTensor if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor): + heatmap.affine = reference.affine self._update_spatial_metadata(heatmap, reference) d[out_key] = heatmap return d @@ -640,18 +648,8 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None: """Update spatial metadata of heatmap based 
on its dimensions.""" - # Determine if batched based on reference's batch dimension - ref_spatial_shape = reference.meta.get("spatial_shape", []) - ref_is_batched = len(reference.shape) > len(ref_spatial_shape) + 1 - - if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D) - spatial_shape = heatmap.shape[2:] - elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D) - # Disambiguate: 2D batched vs 3D non-batched - spatial_shape = heatmap.shape[2:] if ref_is_batched else heatmap.shape[1:] - else: # 2D non-batched: (C, H, W) - spatial_shape = heatmap.shape[1:] - + # trailing dims after channel are spatial regardless of batch presence + spatial_shape = heatmap.shape[-(reference.ndim - 1) :] heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape) diff --git a/tests/transforms/test_generate_heatmapd.py b/tests/transforms/test_generate_heatmapd.py index fc8a2b8e55..3d474177b6 100644 --- a/tests/transforms/test_generate_heatmapd.py +++ b/tests/transforms/test_generate_heatmapd.py @@ -53,7 +53,7 @@ f"dict_static_shape_{len(shape)}d", np.array([[1.0] * len(shape)], dtype=np.float32), {"spatial_shape": shape}, - (1,) + shape, + (1, *shape), np.float32, ] ) @@ -165,7 +165,8 @@ def test_dict_batched_with_ref(self, _, points, params, expected_shape, _expecte assert_allclose(heatmap.affine, image.affine, type_test=False) # Check max values - max_vals = heatmap.max(dim=2)[0].max(dim=2)[0].max(dim=2)[0] + hm2 = heatmap.reshape(heatmap.shape[0], heatmap.shape[1], -1) + max_vals = hm2.max(dim=2)[0] np.testing.assert_allclose( max_vals.cpu().numpy(), np.ones((expected_shape[0], expected_shape[1])), rtol=1e-5, atol=1e-5 ) From 2c7b4d014da73d4793ca918ffecc3fec4dcc30ca Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sat, 27 Sep 2025 00:01:58 +0900 Subject: [PATCH 16/19] Improves GenerateHeatmap transform and documentation Enhances the GenerateHeatmap transform with better normalization, spatial metadata handling, and comprehensive documentation. 
The changes ensure correct heatmap normalization, and improve handling of spatial metadata inheritance from reference images. Also improves input validation and fixes shape inconsistencies. Adds new test cases to cover edge cases and improve code reliability. Signed-off-by: sewon.jeon --- monai/transforms/post/array.py | 17 ++++----- monai/transforms/post/dictionary.py | 33 +++++++++++++---- tests/transforms/test_generate_heatmapd.py | 42 +++++++++++++++++++--- 3 files changed, 74 insertions(+), 18 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 833344ac6f..cc9637d0ef 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -16,7 +16,6 @@ import warnings from collections.abc import Callable, Iterable, Sequence -from typing import ClassVar import numpy as np import torch @@ -757,14 +756,16 @@ class GenerateHeatmap(Transform): Notes: - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D. - - Output shape: - - Non-batched points (N, D): (N, H, W[, D]) - - Batched points (B, N, D): (B, N, H, W[, D]) + - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D. + - Output layout uses channel-first convention with one channel per landmark: + - Non-batched points (N, D): (N, Y, X) for 2D or (N, Z, Y, X) for 3D + - Batched points (B, N, D): (B, N, Y, X) for 2D or (B, N, Z, Y, X) for 3D - Each channel corresponds to one landmark. Args: sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions. spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform. + A single int value will be broadcast to all spatial dimensions. truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window. normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``. dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes). 
@@ -774,7 +775,7 @@ class GenerateHeatmap(Transform): """ - backend: ClassVar[list] = [TransformBackends.NUMPY, TransformBackends.TORCH] + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] def __init__( self, @@ -840,9 +841,9 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None # write back region.copy_(updated) if self.normalize: - peak = updated.max() - if peak.item() > 0: - heatmap[b_idx, idx] /= peak + peak = updated.amax() + denom = torch.where(peak > 0, peak, torch.ones_like(peak)) + heatmap[b_idx, idx] = heatmap[b_idx, idx] / denom if not is_batched: heatmap = heatmap.squeeze(0) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 7c8a2b15ac..f174cbbb3d 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -518,12 +518,35 @@ class GenerateHeatmapd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`. Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image. + Args: + keys: keys of the corresponding items in the dictionary. + sigma: standard deviation for the Gaussian kernel. Can be a single value or sequence matching number of points. + heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key. + ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will + have the same shape, affine, and spatial metadata as the reference images. + spatial_shape: spatial dimensions of output heatmaps. Can be: + - Single shape (tuple): applied to all keys + - List of shapes: one per key (must match keys length) + truncated: truncation distance for Gaussian kernel computation (in sigmas). + normalize: if True, normalize each heatmap's peak value to 1.0. + dtype: output data type for heatmaps. Defaults to np.float32. + allow_missing_keys: if True, don't raise error if some keys are missing in data. 
+ + Returns: + Dictionary with original data plus generated heatmaps at specified keys. + + Raises: + ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length. + ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys). + ValueError: If input points have invalid shape (must be 2D or 3D). + Notes: - Default heatmap_keys are generated as "{key}_heatmap" for each input key - Shape inference precedence: static spatial_shape > ref_image - Output shapes: - Non-batched points (N, D): (N, H, W[, D]) - Batched points (B, N, D): (B, N, H, W[, D]) + - When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference """ backend = GenerateHeatmap.backend @@ -575,7 +598,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: # Copy metadata if reference is MetaTensor if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor): heatmap.affine = reference.affine - self._update_spatial_metadata(heatmap, reference) + self._update_spatial_metadata(heatmap, shape) d[out_key] = heatmap return d @@ -628,7 +651,7 @@ def _determine_shape( return static_shape points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) if points_t.ndim not in (2, 3): - raise ValueError(self._ERR_INVALID_POINTS) + raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.") spatial_dims = int(points_t.shape[-1]) if ref_key is not None and ref_key in data: return self._shape_from_reference(data[ref_key], spatial_dims) @@ -646,10 +669,8 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, return tuple(int(v) for v in reference.shape[-spatial_dims:]) raise ValueError(self._ERR_REF_NO_SHAPE) - def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None: - """Update spatial metadata of heatmap based on its dimensions.""" - # trailing dims after channel are spatial regardless of batch presence - spatial_shape = 
heatmap.shape[-(reference.ndim - 1) :] + def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int, ...]) -> None: + """Set spatial_shape explicitly from resolved shape.""" heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape) diff --git a/tests/transforms/test_generate_heatmapd.py b/tests/transforms/test_generate_heatmapd.py index 3d474177b6..85cdbdc75c 100644 --- a/tests/transforms/test_generate_heatmapd.py +++ b/tests/transforms/test_generate_heatmapd.py @@ -128,10 +128,16 @@ def test_dict_static_shape(self, _, points, params, expected_shape, expected_dty self.assertEqual(heatmap.shape, expected_shape) self.assertEqual(heatmap.dtype, expected_dtype) + # Verify no NaN or Inf values + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + # Verify max value is 1.0 for normalized heatmaps + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + def test_dict_missing_shape_raises(self): # Without ref image or explicit spatial_shape, must raise transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap") - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "spatial_shape|ref_image_keys"): transform({"points": np.zeros((1, 2), dtype=np.float32)}) @parameterized.expand(TEST_CASES_DTYPE) @@ -203,6 +209,35 @@ def test_dict_multiple_keys(self): # Verify peaks are at different locations self.assertNotEqual(np.argmax(result["hm1"]), np.argmax(result["hm2"])) + def test_dict_mismatched_heatmap_keys_length(self): + """Test ValueError when heatmap_keys length doesn't match keys""" + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + heatmap_keys=["hm1", "hm2", "hm3"], # Mismatch: 3 heatmap keys for 2 input keys + spatial_shape=(8, 8), + ) + + def test_dict_mismatched_ref_image_keys_length(self): + """Test ValueError when ref_image_keys length doesn't match keys""" + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + 
heatmap_keys=["hm1", "hm2"], + ref_image_keys=["img1", "img2", "img3"], # Mismatch: 3 ref keys for 2 input keys + spatial_shape=(8, 8), + ) + + def test_dict_per_key_spatial_shape_mismatch(self): + """Test ValueError when per-key spatial_shape length doesn't match keys""" + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + heatmap_keys=["hm1", "hm2"], + spatial_shape=[(8, 8), (8, 8), (8, 8)], # Mismatch: 3 shapes for 2 keys + sigma=1.0, + ) + def test_metatensor_points_with_ref(self): """Test MetaTensor points with reference image - documents current behavior""" from monai.data import MetaTensor @@ -224,9 +259,8 @@ def test_metatensor_points_with_ref(self): self.assertIsInstance(heatmap, MetaTensor) self.assertEqual(tuple(heatmap.shape), (2, 8, 8, 8)) - # Note: Currently the heatmap may inherit affine from points MetaTensor - # This test documents the current behavior - # Ideally, the heatmap should use the reference image's affine + # Heatmap should inherit affine from the reference image + assert_allclose(heatmap.affine, image.affine, type_test=False) if __name__ == "__main__": From 1b5888bfd54d9d7be09e834ef46f45151625aa42 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sat, 27 Sep 2025 00:32:39 +0900 Subject: [PATCH 17/19] Fixes heatmap normalization and shape checking Fixes an issue where heatmap normalization took its peak from the local update window instead of the whole landmark channel. Adds a check to ensure that the provided static shape matches the number of spatial dimensions.
Signed-off-by: sewon.jeon --- monai/transforms/post/array.py | 4 ++-- monai/transforms/post/dictionary.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index cc9637d0ef..b9f3949900 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -841,9 +841,9 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None # write back region.copy_(updated) if self.normalize: - peak = updated.amax() + peak = heatmap[b_idx, idx].amax() denom = torch.where(peak > 0, peak, torch.ones_like(peak)) - heatmap[b_idx, idx] = heatmap[b_idx, idx] / denom + heatmap[b_idx, idx].div_(denom) if not is_batched: heatmap = heatmap.squeeze(0) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index f174cbbb3d..8002179256 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -520,7 +520,8 @@ class GenerateHeatmapd(MapTransform): Args: keys: keys of the corresponding items in the dictionary. - sigma: standard deviation for the Gaussian kernel. Can be a single value or sequence matching number of points. + sigma: standard deviation for the Gaussian kernel. Can be a single value or a sequence matching the number + of spatial dimensions. heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key. ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will have the same shape, affine, and spatial metadata as the reference images. @@ -647,12 +648,16 @@ def _prepare_shapes( def _determine_shape( self, points: Any, static_shape: tuple[int, ...] 
| None, data: Mapping[Hashable, Any], ref_key: Hashable | None ) -> tuple[int, ...]: - if static_shape is not None: - return static_shape points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) if points_t.ndim not in (2, 3): raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.") spatial_dims = int(points_t.shape[-1]) + if static_shape is not None: + if len(static_shape) != spatial_dims: + raise ValueError( + f"Provided static spatial_shape has {len(static_shape)} dims; expected {spatial_dims}." + ) + return static_shape if ref_key is not None and ref_key in data: return self._shape_from_reference(data[ref_key], spatial_dims) raise ValueError(self._ERR_NO_SHAPE) From fd4be38f093812d94d8aadf94019892c1cdc35d1 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sat, 27 Sep 2025 00:51:21 +0900 Subject: [PATCH 18/19] Adds `GenerateHeatmap` transform Exports `GenerateHeatmap` and its dictionary-based variants (`GenerateHeatmapd`, `GenerateHeatmapD`, `GenerateHeatmapDict`) from `monai.transforms`, and validates that the requested dtype is a floating-point type.
Signed-off-by: sewon.jeon --- monai/transforms/__init__.py | 4 ++++ monai/transforms/post/array.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index d15042181b..a4bb187300 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -293,6 +293,7 @@ AsDiscrete, DistanceTransformEDT, FillHoles, + GenerateHeatmap, Invert, KeepLargestConnectedComponent, LabelFilter, @@ -319,6 +320,9 @@ FillHolesD, FillHolesd, FillHolesDict, + GenerateHeatmapd, + GenerateHeatmapD, + GenerateHeatmapDict, InvertD, Invertd, InvertDict, diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index b9f3949900..eef714459d 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -799,6 +799,9 @@ def __init__( self.normalize = normalize self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor) self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray) + # Validate that dtype is floating-point for meaningful Gaussian values + if self.torch_dtype not in (torch.float16, torch.float32, torch.float64, torch.bfloat16): + raise ValueError(f"dtype must be a floating-point type, got {self.torch_dtype}") self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape) def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor: From fc28c712e9cf51082b90d45be94059bac43742e7 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Sat, 27 Sep 2025 01:11:20 +0900 Subject: [PATCH 19/19] fix nitpick comments Signed-off-by: sewon.jeon --- monai/transforms/post/array.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index eef714459d..1e795a75de 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -800,11 +800,24 @@ def __init__( self.torch_dtype = 
get_equivalent_dtype(dtype, torch.Tensor) self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray) # Validate that dtype is floating-point for meaningful Gaussian values - if self.torch_dtype not in (torch.float16, torch.float32, torch.float64, torch.bfloat16): + if not self.torch_dtype.is_floating_point: raise ValueError(f"dtype must be a floating-point type, got {self.torch_dtype}") self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape) def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor: + """ + Args: + points: landmark coordinates as ndarray/Tensor with shape (N, D) or (B, N, D), + ordered as (Y, X) for 2D or (Z, Y, X) for 3D. + spatial_shape: spatial size as a sequence or single int (broadcasted). If None, uses + the value provided at construction. + + Returns: + Heatmaps with shape (N, *spatial) or (B, N, *spatial), one channel per landmark. + + Raises: + ValueError: if points shape/dimension or spatial_shape is invalid. 
+ """ original_points = points points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) @@ -828,13 +841,15 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device) image_bounds = tuple(int(s) for s in target_shape) + bounds_t = torch.as_tensor(image_bounds, device=device, dtype=points_t.dtype) for b_idx in range(batch_size): for idx, center in enumerate(points_t[b_idx]): - center_vals = center.tolist() - if not np.all(np.isfinite(center_vals)): + if not torch.isfinite(center).all(): continue - if not self._is_inside(center_vals, image_bounds): + if not ((center >= 0).all() and (center < bounds_t).all()): continue + # _make_window expects Python floats; convert only when needed + center_vals = center.tolist() window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device) if window_slices is None: continue