@@ -800,11 +800,24 @@ def __init__(
800
800
self .torch_dtype = get_equivalent_dtype (dtype , torch .Tensor )
801
801
self .numpy_dtype = get_equivalent_dtype (dtype , np .ndarray )
802
802
# Validate that dtype is floating-point for meaningful Gaussian values
803
- if self .torch_dtype not in ( torch . float16 , torch . float32 , torch . float64 , torch . bfloat16 ) :
803
+ if not self .torch_dtype . is_floating_point :
804
804
raise ValueError (f"dtype must be a floating-point type, got { self .torch_dtype } " )
805
805
self .spatial_shape = None if spatial_shape is None else tuple (int (s ) for s in spatial_shape )
806
806
807
807
def __call__ (self , points : NdarrayOrTensor , spatial_shape : Sequence [int ] | None = None ) -> NdarrayOrTensor :
808
+ """
809
+ Args:
810
+ points: landmark coordinates as ndarray/Tensor with shape (N, D) or (B, N, D),
811
+ ordered as (Y, X) for 2D or (Z, Y, X) for 3D.
812
+ spatial_shape: spatial size as a sequence or single int (broadcasted). If None, uses
813
+ the value provided at construction.
814
+
815
+ Returns:
816
+ Heatmaps with shape (N, *spatial) or (B, N, *spatial), one channel per landmark.
817
+
818
+ Raises:
819
+ ValueError: if points shape/dimension or spatial_shape is invalid.
820
+ """
808
821
original_points = points
809
822
points_t = convert_to_tensor (points , dtype = torch .float32 , track_meta = False )
810
823
@@ -828,13 +841,15 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
828
841
829
842
heatmap = torch .zeros ((batch_size , num_points , * target_shape ), dtype = self .torch_dtype , device = device )
830
843
image_bounds = tuple (int (s ) for s in target_shape )
844
+ bounds_t = torch .as_tensor (image_bounds , device = device , dtype = points_t .dtype )
831
845
for b_idx in range (batch_size ):
832
846
for idx , center in enumerate (points_t [b_idx ]):
833
- center_vals = center .tolist ()
834
- if not np .all (np .isfinite (center_vals )):
847
+ if not torch .isfinite (center ).all ():
835
848
continue
836
- if not self . _is_inside ( center_vals , image_bounds ):
849
+ if not (( center >= 0 ). all () and ( center < bounds_t ). all () ):
837
850
continue
851
+ # _make_window expects Python floats; convert only when needed
852
+ center_vals = center .tolist ()
838
853
window_slices , coord_shifts = self ._make_window (center_vals , radius , image_bounds , device )
839
854
if window_slices is None :
840
855
continue
0 commit comments