diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py
index 39f26d0fbd..9a44e60f40 100644
--- a/monai/visualize/class_activation_maps.py
+++ b/monai/visualize/class_activation_maps.py
@@ -124,19 +124,31 @@ def get_layer(self, layer_id: str | Callable[[nn.Module], nn.Module]) -> nn.Modu
                     return cast(nn.Module, mod)
         raise NotImplementedError(f"Could not find {layer_id}.")
 
-    def class_score(self, logits: torch.Tensor, class_idx: int) -> torch.Tensor:
-        return logits[:, class_idx].squeeze()
+    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
+        if isinstance(class_idx, int):
+            return logits[:, class_idx].squeeze()
+        elif class_idx.numel() == 1:
+            return logits[:, class_idx.item()]
+        elif len(class_idx.view(-1)) == logits.shape[0]:
+            return torch.gather(logits, 1, class_idx.unsqueeze(1)).squeeze(1)
+        else:
+            raise ValueError("expect length of class_idx equal to batch size")
 
     def __call__(self, x, class_idx=None, retain_graph=False, **kwargs):
         train = self.model.training
         self.model.eval()
         logits = self.model(x, **kwargs)
-        self.class_idx = logits.max(1)[-1] if class_idx is None else class_idx
+        if class_idx is None:
+            self.class_idx = logits.max(1)[-1]
+        elif isinstance(class_idx, torch.Tensor):
+            self.class_idx = class_idx.to(logits.device)
+        else:
+            self.class_idx = class_idx
         acti, grad = None, None
         if self.register_forward:
             acti = tuple(self.activations[layer] for layer in self.target_layers)
         if self.register_backward:
-            self.score = self.class_score(logits, cast(int, self.class_idx))
+            self.score = self.class_score(logits, self.class_idx)
             self.model.zero_grad()
             self.score.sum().backward(retain_graph=retain_graph)
             for layer in self.target_layers:
diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py
index c54c9cd4ca..a9355f04f3 100644
--- a/monai/visualize/gradient_based.py
+++ b/monai/visualize/gradient_based.py
@@ -88,8 +88,6 @@ def model(self, m):
     def get_grad(
         self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph: bool = True, **kwargs: Any
     ) -> torch.Tensor:
-        if x.shape[0] != 1:
-            raise ValueError("expect batch size of 1")
         x.requires_grad = True
 
         self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs)
diff --git a/tests/integration/test_vis_gradbased.py b/tests/integration/test_vis_gradbased.py
index e9db0af240..39b51d809b 100644
--- a/tests/integration/test_vis_gradbased.py
+++ b/tests/integration/test_vis_gradbased.py
@@ -38,14 +38,19 @@ def __call__(self, x, adjoint_info):
 for type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad):
     # 2D densenet
     TESTS.append([type, DENSENET2D, (1, 1, 48, 64)])
+    TESTS.append([type, DENSENET2D, (4, 1, 48, 64)])
     # 3D densenet
     TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6)])
+    TESTS.append([type, DENSENET3D, (2, 1, 6, 6, 6)])
     # 2D senet
     TESTS.append([type, SENET2D, (1, 3, 64, 64)])
+    TESTS.append([type, SENET2D, (3, 3, 64, 64)])
     # 3D senet
     TESTS.append([type, SENET3D, (1, 3, 8, 8, 48)])
+    TESTS.append([type, SENET3D, (2, 3, 8, 8, 48)])
     # 2D densenet - adjoint
     TESTS.append([type, DENSENET2DADJOINT, (1, 1, 48, 64)])
+    TESTS.append([type, DENSENET2DADJOINT, (3, 1, 48, 64)])
 
 
 class TestGradientClassActivationMap(unittest.TestCase):