Skip to content

Commit fc28c71

Browse files
committed
fix nitpick comments
Signed-off-by: sewon.jeon <sewon.jeon@connecteve.com>
1 parent fd4be38 commit fc28c71

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

monai/transforms/post/array.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -800,11 +800,24 @@ def __init__(
800800
self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
801801
self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
802802
# Validate that dtype is floating-point for meaningful Gaussian values
803-
if self.torch_dtype not in (torch.float16, torch.float32, torch.float64, torch.bfloat16):
803+
if not self.torch_dtype.is_floating_point:
804804
raise ValueError(f"dtype must be a floating-point type, got {self.torch_dtype}")
805805
self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)
806806

807807
def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:
808+
"""
809+
Args:
810+
points: landmark coordinates as ndarray/Tensor with shape (N, D) or (B, N, D),
811+
ordered as (Y, X) for 2D or (Z, Y, X) for 3D.
812+
spatial_shape: spatial size as a sequence or single int (broadcasted). If None, uses
813+
the value provided at construction.
814+
815+
Returns:
816+
Heatmaps with shape (N, *spatial) or (B, N, *spatial), one channel per landmark.
817+
818+
Raises:
819+
ValueError: if points shape/dimension or spatial_shape is invalid.
820+
"""
808821
original_points = points
809822
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
810823

@@ -828,13 +841,15 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
828841

829842
heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device)
830843
image_bounds = tuple(int(s) for s in target_shape)
844+
bounds_t = torch.as_tensor(image_bounds, device=device, dtype=points_t.dtype)
831845
for b_idx in range(batch_size):
832846
for idx, center in enumerate(points_t[b_idx]):
833-
center_vals = center.tolist()
834-
if not np.all(np.isfinite(center_vals)):
847+
if not torch.isfinite(center).all():
835848
continue
836-
if not self._is_inside(center_vals, image_bounds):
849+
if not ((center >= 0).all() and (center < bounds_t).all()):
837850
continue
851+
# _make_window expects Python floats; convert only when needed
852+
center_vals = center.tolist()
838853
window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device)
839854
if window_slices is None:
840855
continue

0 commit comments

Comments
 (0)