From e0eb5d4f4dd9844ad1bd1bc7caa3251f7c9eec60 Mon Sep 17 00:00:00 2001 From: Peng Lu Date: Tue, 26 Mar 2024 14:45:15 +0800 Subject: [PATCH] [Fix] fix loss computation in MSPNHead (#2993) --- mmpose/models/heads/heatmap_heads/mspn_head.py | 4 +++- .../test_heads/test_heatmap_heads/test_mspn_head.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mmpose/models/heads/heatmap_heads/mspn_head.py b/mmpose/models/heads/heatmap_heads/mspn_head.py index 8b7cddf798..4d2c0bfcef 100644 --- a/mmpose/models/heads/heatmap_heads/mspn_head.py +++ b/mmpose/models/heads/heatmap_heads/mspn_head.py @@ -394,7 +394,9 @@ def loss(self, keypoint_weights = torch.cat([ d.gt_instance_labels.keypoint_weights for d in batch_data_samples - ]) # shape: [B*N, L, K] + ], + dim=1) + keypoint_weights = keypoint_weights.transpose(0, 1) # [B*N, L, K] # calculate losses over multiple stages and multiple units losses = dict() diff --git a/tests/test_models/test_heads/test_heatmap_heads/test_mspn_head.py b/tests/test_models/test_heads/test_heatmap_heads/test_mspn_head.py index ce3d19b688..5643ff00ba 100644 --- a/tests/test_models/test_heads/test_heatmap_heads/test_mspn_head.py +++ b/tests/test_models/test_heads/test_heatmap_heads/test_mspn_head.py @@ -44,6 +44,7 @@ def _get_data_samples(self, with_heatmap=True, with_reg_label=False, num_levels=num_levels)['data_samples'] + return batch_data_samples def test_init(self): @@ -153,6 +154,10 @@ def test_loss(self): (unit_channels, 32, 24), (unit_channels, 64, 48)]) batch_data_samples = self._get_data_samples( batch_size=2, heatmap_size=(48, 64), num_levels=4) + for ds in batch_data_samples: + ds.gt_instance_labels = InstanceData( + keypoint_weights=ds.gt_instance_labels.keypoint_weights. + transpose(0, 1)) losses = head.loss(feats, batch_data_samples) self.assertIsInstance(losses['loss_kpt'], torch.Tensor) @@ -189,6 +194,10 @@ def test_loss(self): (unit_channels, 32, 24), (unit_channels, 64, 48)]) batch_data_samples = self._get_data_samples( batch_size=2, heatmap_size=(48, 64), num_levels=16) + for ds in batch_data_samples: + ds.gt_instance_labels = InstanceData( + keypoint_weights=ds.gt_instance_labels.keypoint_weights. + transpose(0, 1)) losses = head.loss(feats, batch_data_samples) self.assertIsInstance(losses['loss_kpt'], torch.Tensor)