Skip to content
This repository has been archived by the owner on Nov 17, 2020. It is now read-only.

Commit

Permalink
Guided Grad-CAM bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Karol authored and Karol committed Jun 21, 2020
1 parent a4665e1 commit 9454782
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 20 deletions.
52 changes: 41 additions & 11 deletions gcam/backends/guided_grad_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from gcam.backends.grad_cam import GradCAM
from gcam.backends.guided_backpropagation import GuidedBackPropagation
from gcam import gcam_utils
import torch


class GuidedGradCam():
Expand All @@ -17,23 +18,52 @@ def __init__(self, model, target_layers=None, postprocessor=None, retain_graph=F

def generate_attention_map(self, batch, label):
"""Handles the generation of the attention map from start to finish."""
output, self.output_GCAM, output_batch_size, output_channels, output_shape = self.model_GCAM.generate_attention_map(batch.clone(), label)
_, self.output_GBP, _, _, _ = self.model_GBP.generate_attention_map(batch.clone(), label)
output, self.attention_map_GCAM, output_batch_size, output_channels, output_shape = self.model_GCAM.generate_attention_map(batch.clone(), label)
#_, self.attention_map_GBP, _, _, _ = self.model_GBP.generate_attention_map(batch.clone(), label)
self.attention_map_GBP = self._generate_gbp(batch, label)[""]
#self.attention_map_GBP = self.attention_map_GBP[""]
attention_map = self.generate()
return output, attention_map, output_batch_size, output_channels, output_shape

def get_registered_hooks(self):
"""Returns every hook that was able to register to a layer."""
return self.model_GCAM.get_registered_hooks()

def generate(self):
def generate(self): # TODO: Redo ggcam, find a solution for normalize_per_channel
"""Generates an attention map."""
attention_map_GCAM = self.model_GCAM.generate()
attention_map_GBP = self.model_GBP.generate()[""]
for layer_name in attention_map_GCAM.keys():
if attention_map_GBP.shape == attention_map_GCAM[layer_name].shape:
attention_map_GCAM[layer_name] = np.multiply(attention_map_GCAM[layer_name], attention_map_GBP)
for layer_name in self.attention_map_GCAM.keys():
if self.attention_map_GBP.shape == self.attention_map_GCAM[layer_name].shape:
self.attention_map_GCAM[layer_name] = np.multiply(self.attention_map_GCAM[layer_name], self.attention_map_GBP)
else:
attention_map_GCAM_tmp = gcam_utils.interpolate(attention_map_GCAM[layer_name], attention_map_GBP.shape[2:])
attention_map_GCAM[layer_name] = np.multiply(attention_map_GCAM_tmp, attention_map_GBP)
return attention_map_GCAM
attention_map_GCAM_tmp = gcam_utils.interpolate(self.attention_map_GCAM[layer_name], self.attention_map_GBP.shape[2:])
self.attention_map_GCAM[layer_name] = np.multiply(attention_map_GCAM_tmp, self.attention_map_GBP)
self.attention_map_GCAM[layer_name] = self._normalize_per_channel(self.attention_map_GCAM[layer_name])
return self.attention_map_GCAM

def _generate_gbp(self, batch, label):
output = self.model_GBP.forward(batch)
self.model_GBP.backward(label=label)

attention_map = self.model_GBP.data.grad.clone()
self.model_GBP.data.grad.zero_()
B, _, *data_shape = attention_map.shape
attention_map = attention_map.view(B, 1, -1, *data_shape)
attention_map = torch.mean(attention_map, dim=2) # TODO: mean or sum?
attention_map = attention_map.repeat(1, self.model_GBP.output_channels, *[1 for _ in range(self.model_GBP.input_dim)])
attention_map = attention_map.cpu().numpy()
attention_maps = {}
attention_maps[""] = attention_map
return attention_maps

def _normalize_per_channel(self, attention_map):
if np.min(attention_map) == np.max(attention_map):
return np.zeros(attention_map.shape)
# Normalization per channel
B, C, *data_shape = attention_map.shape
attention_map = np.reshape(attention_map, (B, C, -1))
attention_map_min = np.min(attention_map, axis=2, keepdims=True)[0]
attention_map_max = np.max(attention_map, axis=2, keepdims=True)[0]
attention_map -= attention_map_min
attention_map /= (attention_map_max - attention_map_min)
attention_map = np.reshape(attention_map, (B, C, *data_shape))
return attention_map
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="gcam",
version="0.0.19",
version="0.0.20",
author="Karol Gotkowski",
author_email="KarolGotkowski@gmx.de",
description="An easy to use framework that makes model predictions more interpretable for humans.",
Expand Down
30 changes: 22 additions & 8 deletions tests/test_classification/resnet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ def load_image(self, image_path):
]
)(raw_image[..., ::-1].copy())
image = image.to(self.DEVICE)
return image
return image, raw_image

def test_gbp(self):
model = gcam.inject(self.model, output_dir=os.path.join(self.current_path, 'results/resnet152/test_gbp'), backend='gbp',
evaluate=False, save_scores=False, save_maps=True, save_pickle=False, channels=1)
model.eval()
data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False)
for i, batch in enumerate(data_loader):
_ = model(batch[0])
_ = model(batch[0][0])

del model
gc.collect()
Expand All @@ -56,10 +56,26 @@ def test_gcam(self):
evaluate=False, save_scores=False, save_maps=True, save_pickle=False, channels=1)
model.eval()
data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False)
model.test_run(next(iter(data_loader))[0])

for i, batch in enumerate(data_loader):
_ = model(batch[0])
_ = model(batch[0][0])

del model
gc.collect()
torch.cuda.empty_cache()

if CLEAR and os.path.isdir(os.path.join(self.current_path, 'results/resnet152')):
shutil.rmtree(os.path.join(self.current_path, 'results/resnet152'))

def test_gcam_overlay(self):
layer = 'layer4'
model = gcam.inject(self.model, output_dir=os.path.join(self.current_path, 'results/resnet152/test_gcam_overlay'), backend='gcam', layer=layer,
evaluate=False, save_scores=False, save_maps=True, save_pickle=False, channels=1)
model.eval()
data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False)

for i, batch in enumerate(data_loader):
_ = model(batch[0][0], raw_input=batch[0][1])

del model
gc.collect()
Expand All @@ -74,10 +90,9 @@ def test_ggcam(self):
evaluate=False, save_scores=False, save_maps=True, save_pickle=False, channels=1)
model.eval()
data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False)
model.test_run(next(iter(data_loader))[0])

for i, batch in enumerate(data_loader):
_ = model(batch[0])
_ = model(batch[0][0])

del model
gc.collect()
Expand All @@ -92,10 +107,9 @@ def test_gcampp(self):
evaluate=False, save_scores=False, save_maps=True, save_pickle=False, channels=1)
model.eval()
data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False)
model.test_run(next(iter(data_loader))[0])

for i, batch in enumerate(data_loader):
_ = model(batch[0])
_ = model(batch[0][0])

del model
gc.collect()
Expand Down
19 changes: 19 additions & 0 deletions tests/test_segmentation/unet_seg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,25 @@ def test_gcam(self):
if CLEAR and os.path.isdir(os.path.join(self.current_path, 'results/unet_seg')):
shutil.rmtree(os.path.join(self.current_path, 'results/unet_seg'))

def test_gcam_overlay(self):
layer = 'full'
metric = 'wioa'
model = gcam.inject(self.model, output_dir=os.path.join(self.current_path, 'results/unet_seg/gcam_overlay'), backend='gcam', layer=layer,
evaluate=True, save_scores=False, save_maps=True, save_pickle=False, metric=metric, label=lambda x: 0.5 < x, channels=1)
model.eval()
data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False)
model.test_run(next(iter(data_loader))["img"])

for i, batch in enumerate(data_loader):
_ = model(batch["img"], mask=batch["gt"], raw_input=batch["img"])

del model
gc.collect()
torch.cuda.empty_cache()

if CLEAR and os.path.isdir(os.path.join(self.current_path, 'results/unet_seg')):
shutil.rmtree(os.path.join(self.current_path, 'results/unet_seg'))

def test_ggcam(self):
layer = 'full'
metric = 'wioa'
Expand Down

0 comments on commit 9454782

Please sign in to comment.