Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@
AsDiscrete,
DistanceTransformEDT,
FillHoles,
GenerateHeatmap,
Invert,
KeepLargestConnectedComponent,
LabelFilter,
Expand All @@ -319,6 +320,9 @@
FillHolesD,
FillHolesd,
FillHolesDict,
GenerateHeatmapd,
GenerateHeatmapD,
GenerateHeatmapDict,
InvertD,
Invertd,
InvertDict,
Expand Down
198 changes: 197 additions & 1 deletion monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,14 @@
remove_small_objects,
)
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option
from monai.utils import (
TransformBackends,
convert_data_type,
convert_to_tensor,
ensure_tuple,
get_equivalent_dtype,
look_up_option,
)
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
Expand All @@ -54,6 +61,7 @@
"SobelGradients",
"VoteEnsemble",
"Invert",
"GenerateHeatmap",
"DistanceTransformEDT",
]

Expand Down Expand Up @@ -742,6 +750,194 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
return self.post_convert(out_pt, img)


class GenerateHeatmap(Transform):
    """
    Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.

    Notes:
        - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
        - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D.
        - Output layout uses channel-first convention with one channel per landmark:
            - Non-batched points (N, D): (N, Y, X) for 2D or (N, Z, Y, X) for 3D
            - Batched points (B, N, D): (B, N, Y, X) for 2D or (B, N, Z, Y, X) for 3D
        - Each channel corresponds to one landmark.
        - A landmark with non-finite coordinates, or lying outside ``[0, size)`` along any axis,
          is skipped and its channel stays all zeros.

    Args:
        sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
        spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
            A single int value will be broadcast to all spatial dimensions.
        truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.
        normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
        dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).

    Raises:
        ValueError: at construction, when ``sigma`` or ``truncated`` is non-positive or ``dtype`` is not a
            floating-point type; at call time, when ``spatial_shape`` or ``sigma`` cannot be resolved against
            the dimensionality of the points.

    """

    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

    def __init__(
        self,
        sigma: Sequence[float] | float = 5.0,
        spatial_shape: Sequence[int] | int | None = None,
        truncated: float = 4.0,
        normalize: bool = True,
        dtype: np.dtype | torch.dtype | type = np.float32,
    ) -> None:
        if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):
            if any(s <= 0 for s in sigma):
                raise ValueError("sigma values must be positive.")
            self._sigma = tuple(float(s) for s in sigma)
        else:
            if float(sigma) <= 0:
                raise ValueError("sigma must be positive.")
            self._sigma = (float(sigma),)
        if truncated <= 0:
            raise ValueError("truncated must be positive.")
        self.truncated = float(truncated)
        self.normalize = normalize
        self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
        self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
        # Validate that dtype is floating-point for meaningful Gaussian values
        if not self.torch_dtype.is_floating_point:
            raise ValueError(f"dtype must be a floating-point type, got {self.torch_dtype}")
        # ``ensure_tuple`` wraps a bare int into a 1-tuple so that the documented
        # "single int is broadcast" form works here too; the broadcast itself is
        # deferred to ``_resolve_spatial_shape`` once the landmark dimensionality is known.
        self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in ensure_tuple(spatial_shape))

    def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | int | None = None) -> NdarrayOrTensor:
        """
        Args:
            points: landmark coordinates as ndarray/Tensor with shape (N, D) or (B, N, D),
                ordered as (Y, X) for 2D or (Z, Y, X) for 3D.
            spatial_shape: spatial size as a sequence or single int (broadcasted). If None, uses
                the value provided at construction.

        Returns:
            Heatmaps with shape (N, *spatial) or (B, N, *spatial), one channel per landmark.

        Raises:
            ValueError: if points shape/dimension or spatial_shape is invalid.
        """
        original_points = points
        points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)

        is_batched = points_t.ndim == 3
        if not is_batched:
            if points_t.ndim != 2:
                raise ValueError(
                    "points must be a 2D or 3D array with shape (num_points, spatial_dims) or (B, num_points, spatial_dims)."
                )
            points_t = points_t.unsqueeze(0)  # Add a batch dimension

        if points_t.shape[-1] not in (2, 3):
            raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")

        device = points_t.device
        batch_size, num_points, spatial_dims = points_t.shape

        target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
        sigma = self._resolve_sigma(spatial_dims)
        # half-width (in voxels) of the window in which the gaussian is evaluated, per axis
        radius = tuple(int(np.ceil(self.truncated * s)) for s in sigma)

        heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device)
        image_bounds = tuple(int(s) for s in target_shape)
        bounds_t = torch.as_tensor(image_bounds, device=device, dtype=points_t.dtype)
        for b_idx in range(batch_size):
            for idx, center in enumerate(points_t[b_idx]):
                # skip NaN/inf landmarks: their channel stays all zeros
                if not torch.isfinite(center).all():
                    continue
                # skip landmarks outside the image: their channel stays all zeros
                if not ((center >= 0).all() and (center < bounds_t).all()):
                    continue
                # _make_window expects Python floats; convert only when needed
                center_vals = center.tolist()
                window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device)
                if window_slices is None:
                    continue
                # ``region`` is a view into ``heatmap``, so the in-place copy below writes through
                region = heatmap[b_idx, idx][window_slices]
                gaussian = self._evaluate_gaussian(coord_shifts, sigma)
                updated = torch.maximum(region, gaussian)
                # write back
                region.copy_(updated)
                if self.normalize:
                    # rescale so the channel peak is 1; guard against dividing by a zero peak
                    peak = heatmap[b_idx, idx].amax()
                    denom = torch.where(peak > 0, peak, torch.ones_like(peak))
                    heatmap[b_idx, idx].div_(denom)

        if not is_batched:
            heatmap = heatmap.squeeze(0)

        target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
        converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype)
        return converted

    def _resolve_spatial_shape(self, call_shape: Sequence[int] | int | None, spatial_dims: int) -> tuple[int, ...]:
        """
        Pick the call-time shape if given, else the constructor fallback, and expand it to ``spatial_dims`` entries.

        Raises:
            ValueError: if no shape is available, or its length is neither 1 nor ``spatial_dims``.
        """
        shape = call_shape if call_shape is not None else self.spatial_shape
        if shape is None:
            raise ValueError("spatial_shape must be provided either at construction time or call time.")
        shape_tuple = ensure_tuple(shape)
        if len(shape_tuple) != spatial_dims:
            if len(shape_tuple) == 1:
                # broadcast a single size to every spatial axis
                shape_tuple = shape_tuple * spatial_dims  # type: ignore
            else:
                raise ValueError(
                    "spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)."
                )
        return tuple(int(s) for s in shape_tuple)

    def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
        """
        Return one sigma per spatial axis, broadcasting a single stored value.

        Raises:
            ValueError: if the stored sigma length is neither 1 nor ``spatial_dims``.
        """
        if len(self._sigma) == spatial_dims:
            return self._sigma
        if len(self._sigma) == 1:
            return self._sigma * spatial_dims
        raise ValueError("sigma sequence length must equal the number of spatial dimensions.")

    @staticmethod
    def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
        """Return ``True`` when every coordinate ``c`` satisfies ``0 <= c < size`` for its axis."""
        for c, size in zip(center, bounds):
            if not (0 <= c < size):
                return False
        return True

    def _make_window(
        self, center: Sequence[float], radius: tuple[int, ...], bounds: tuple[int, ...], device: torch.device
    ) -> tuple[tuple[slice, ...] | None, tuple[torch.Tensor, ...]]:
        """
        Compute the cropped window around ``center`` and the per-axis offsets of its voxels from ``center``.

        Args:
            center: landmark coordinates, one float per spatial axis.
            radius: half-width of the window in voxels, per axis.
            bounds: image size per axis; the window is clipped to ``[0, size)``.
            device: device on which the offset tensors are created.

        Returns:
            ``(slices, coord_shifts)`` where ``slices`` indexes the window in the image and
            ``coord_shifts[d]`` holds ``voxel_index - center[d]`` along axis ``d``.
            ``(None, ())`` when the clipped window is empty along any axis.
        """
        slices: list[slice] = []
        coord_shifts: list[torch.Tensor] = []
        for c, r, size in zip(center, radius, bounds):
            start = max(int(np.floor(c - r)), 0)
            stop = min(int(np.ceil(c + r)) + 1, size)  # +1 because slice stop is exclusive
            if start >= stop:
                return None, ()
            slices.append(slice(start, stop))
            coord_shifts.append(torch.arange(start, stop, device=device, dtype=torch.float32) - float(c))
        return tuple(slices), tuple(coord_shifts)

    def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
        """
        Evaluate Gaussian at given coordinate shifts with specified sigmas.

        Args:
            coord_shifts: Per-dimension coordinate offsets from center.
            sigma: Per-dimension standard deviations.

        Returns:
            Gaussian values at the specified coordinates.
        """
        device = coord_shifts[0].device
        shape = tuple(len(axis) for axis in coord_shifts)
        if 0 in shape:
            return torch.zeros(shape, dtype=self.torch_dtype, device=device)
        # accumulate sum_d (shift_d / sigma_d)^2 in float32, then cast once at the end
        exponent = torch.zeros(shape, dtype=torch.float32, device=device)
        for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)):
            shift32 = shift.to(torch.float32)
            scaled = (shift32 / float(sig)) ** 2
            # place this axis' 1D term along ``dim`` so the in-place add broadcasts over the window
            reshape_shape = [1] * len(coord_shifts)
            reshape_shape[dim] = shift.numel()
            exponent += scaled.reshape(reshape_shape)
        gauss = torch.exp(-0.5 * exponent)
        return gauss.to(dtype=self.torch_dtype)


class ProbNMS(Transform):
"""
Performs probability based non-maximum suppression (NMS) on the probabilities map via
Expand Down
Loading
Loading