diff --git a/mmpose/codecs/spr.py b/mmpose/codecs/spr.py index 8e09b185c7..fba17f1598 100644 --- a/mmpose/codecs/spr.py +++ b/mmpose/codecs/spr.py @@ -138,7 +138,7 @@ def _get_heatmap_weights(self, Returns: np.ndarray: Heatmap weight array in the same shape with heatmaps """ - heatmap_weights = np.ones(heatmaps.shape) * bg_weight + heatmap_weights = np.ones(heatmaps.shape, dtype=np.float32) * bg_weight heatmap_weights[heatmaps > 0] = fg_weight return heatmap_weights