diff --git a/gcam/backends/guided_grad_cam.py b/gcam/backends/guided_grad_cam.py index 2f85296..a33c5b4 100644 --- a/gcam/backends/guided_grad_cam.py +++ b/gcam/backends/guided_grad_cam.py @@ -2,6 +2,7 @@ from gcam.backends.grad_cam import GradCAM from gcam.backends.guided_backpropagation import GuidedBackPropagation from gcam import gcam_utils +import torch class GuidedGradCam(): @@ -17,8 +18,10 @@ def __init__(self, model, target_layers=None, postprocessor=None, retain_graph=F def generate_attention_map(self, batch, label): """Handles the generation of the attention map from start to finish.""" - output, self.output_GCAM, output_batch_size, output_channels, output_shape = self.model_GCAM.generate_attention_map(batch.clone(), label) - _, self.output_GBP, _, _, _ = self.model_GBP.generate_attention_map(batch.clone(), label) + output, self.attention_map_GCAM, output_batch_size, output_channels, output_shape = self.model_GCAM.generate_attention_map(batch.clone(), label) + #_, self.attention_map_GBP, _, _, _ = self.model_GBP.generate_attention_map(batch.clone(), label) + self.attention_map_GBP = self._generate_gbp(batch, label)[""] + #self.attention_map_GBP = self.attention_map_GBP[""] attention_map = self.generate() return output, attention_map, output_batch_size, output_channels, output_shape @@ -26,14 +29,41 @@ def get_registered_hooks(self): """Returns every hook that was able to register to a layer.""" return self.model_GCAM.get_registered_hooks() - def generate(self): + def generate(self): # TODO: Redo ggcam, find a solution for normalize_per_channel """Generates an attention map.""" - attention_map_GCAM = self.model_GCAM.generate() - attention_map_GBP = self.model_GBP.generate()[""] - for layer_name in attention_map_GCAM.keys(): - if attention_map_GBP.shape == attention_map_GCAM[layer_name].shape: - attention_map_GCAM[layer_name] = np.multiply(attention_map_GCAM[layer_name], attention_map_GBP) + for layer_name in self.attention_map_GCAM.keys(): + if self.attention_map_GBP.shape == self.attention_map_GCAM[layer_name].shape: + self.attention_map_GCAM[layer_name] = np.multiply(self.attention_map_GCAM[layer_name], self.attention_map_GBP) else: - attention_map_GCAM_tmp = gcam_utils.interpolate(attention_map_GCAM[layer_name], attention_map_GBP.shape[2:]) - attention_map_GCAM[layer_name] = np.multiply(attention_map_GCAM_tmp, attention_map_GBP) - return attention_map_GCAM + attention_map_GCAM_tmp = gcam_utils.interpolate(self.attention_map_GCAM[layer_name], self.attention_map_GBP.shape[2:]) + self.attention_map_GCAM[layer_name] = np.multiply(attention_map_GCAM_tmp, self.attention_map_GBP) + self.attention_map_GCAM[layer_name] = self._normalize_per_channel(self.attention_map_GCAM[layer_name]) + return self.attention_map_GCAM + + def _generate_gbp(self, batch, label): + output = self.model_GBP.forward(batch) + self.model_GBP.backward(label=label) + + attention_map = self.model_GBP.data.grad.clone() + self.model_GBP.data.grad.zero_() + B, _, *data_shape = attention_map.shape + attention_map = attention_map.view(B, 1, -1, *data_shape) + attention_map = torch.mean(attention_map, dim=2) # TODO: mean or sum? + attention_map = attention_map.repeat(1, self.model_GBP.output_channels, *[1 for _ in range(self.model_GBP.input_dim)]) + attention_map = attention_map.cpu().numpy() + attention_maps = {} + attention_maps[""] = attention_map + return attention_maps + + def _normalize_per_channel(self, attention_map): + if np.min(attention_map) == np.max(attention_map): + return np.zeros(attention_map.shape) + # Normalization per channel + B, C, *data_shape = attention_map.shape + attention_map = np.reshape(attention_map, (B, C, -1)) + attention_map_min = np.min(attention_map, axis=2, keepdims=True)[0] + attention_map_max = np.max(attention_map, axis=2, keepdims=True)[0] + attention_map -= attention_map_min + attention_map /= (attention_map_max - attention_map_min) + attention_map = np.reshape(attention_map, (B, C, *data_shape)) + return attention_map diff --git a/setup.py b/setup.py index d9672e4..2cb14eb 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="gcam", - version="0.0.19", + version="0.0.20", author="Karol Gotkowski", author_email="KarolGotkowski@gmx.de", description="An easy to use framework that makes model predictions more interpretable for humans.", diff --git a/tests/test_classification/resnet_test.py b/tests/test_classification/resnet_test.py index 7d580cd..5cf5f9a 100644 --- a/tests/test_classification/resnet_test.py +++ b/tests/test_classification/resnet_test.py @@ -33,7 +33,7 @@ def load_image(self, image_path): ] )(raw_image[..., ::-1].copy()) image = image.to(self.DEVICE) - return image + return image, raw_image def test_gbp(self): model = gcam.inject(self.model, output_dir=os.path.join(self.current_path, 'results/resnet152/test_gbp'), backend='gbp', @@ -41,7 +41,7 @@ def test_gbp(self): model.eval() data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False) for i, batch in enumerate(data_loader): - _ = model(batch[0]) + _ = model(batch[0][0]) del model gc.collect() @@ -56,10 +56,26 @@ def test_gcam(self): evaluate=False, save_scores=False, save_maps=True, save_pickle=False, channels=1) model.eval() data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False) - model.test_run(next(iter(data_loader))[0]) for i, batch in enumerate(data_loader): - _ = model(batch[0]) + _ = model(batch[0][0]) + + del model + gc.collect() + torch.cuda.empty_cache() + + if CLEAR and os.path.isdir(os.path.join(self.current_path, 'results/resnet152')): + shutil.rmtree(os.path.join(self.current_path, 'results/resnet152')) + + def test_gcam_overlay(self): + layer = 'layer4' + model = gcam.inject(self.model, output_dir=os.path.join(self.current_path, 'results/resnet152/test_gcam_overlay'), backend='gcam', layer=layer, + evaluate=False, save_scores=False, save_maps=True, save_pickle=False, channels=1) + model.eval() + data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False) + + for i, batch in enumerate(data_loader): + _ = model(batch[0][0], raw_input=batch[0][1]) del model gc.collect() @@ -74,10 +90,9 @@ def test_ggcam(self): evaluate=False, save_scores=False, save_maps=True, save_pickle=False, channels=1) model.eval() data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False) - model.test_run(next(iter(data_loader))[0]) for i, batch in enumerate(data_loader): - _ = model(batch[0]) + _ = model(batch[0][0]) del model gc.collect() @@ -92,10 +107,9 @@ def test_gcampp(self): evaluate=False, save_scores=False, save_maps=True, save_pickle=False, channels=1) model.eval() data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False) - model.test_run(next(iter(data_loader))[0]) for i, batch in enumerate(data_loader): - _ = model(batch[0]) + _ = model(batch[0][0]) del model gc.collect() diff --git a/tests/test_segmentation/unet_seg_test.py b/tests/test_segmentation/unet_seg_test.py index 9b3e628..3b7b0d9 100644 --- a/tests/test_segmentation/unet_seg_test.py +++ b/tests/test_segmentation/unet_seg_test.py @@ -59,6 +59,25 @@ def test_gcam(self): if CLEAR and os.path.isdir(os.path.join(self.current_path, 'results/unet_seg')): shutil.rmtree(os.path.join(self.current_path, 'results/unet_seg')) + def test_gcam_overlay(self): + layer = 'full' + metric = 'wioa' + model = gcam.inject(self.model, output_dir=os.path.join(self.current_path, 'results/unet_seg/gcam_overlay'), backend='gcam', layer=layer, + evaluate=True, save_scores=False, save_maps=True, save_pickle=False, metric=metric, label=lambda x: 0.5 < x, channels=1) + model.eval() + data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False) + model.test_run(next(iter(data_loader))["img"]) + + for i, batch in enumerate(data_loader): + _ = model(batch["img"], mask=batch["gt"], raw_input=batch["img"]) + + del model + gc.collect() + torch.cuda.empty_cache() + + if CLEAR and os.path.isdir(os.path.join(self.current_path, 'results/unet_seg')): + shutil.rmtree(os.path.join(self.current_path, 'results/unet_seg')) + def test_ggcam(self): layer = 'full' metric = 'wioa'