-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgradcam_pp.py
61 lines (50 loc) · 2.44 KB
/
gradcam_pp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn.functional as F
from deepnet.gradcam.gradcam import GradCAM
class GradCAMpp(GradCAM):
    """Calculate a GradCAM++ saliency map for an input image.

    The map is computed once during construction and is available through
    ``result`` (and, for backward compatibility, the ``saliency_map``
    instance attribute).

    Arguments:
        model: Model network
        layer_name: Layer of the model on which GradCAM will produce saliency map
        input: Input image
    """

    def __init__(self, model, layer_name, input):
        super(GradCAMpp, self).__init__(model, layer_name, input)
        self.model = model
        self.layer_name = layer_name
        # Look the method up on the class, not the instance: the computed map is
        # stored on the instance under the same name ``saliency_map`` and would
        # otherwise shadow the method (e.g. if the base constructor already ran it).
        type(self).saliency_map(self, input)

    def saliency_map(self, input, class_idx=None, retain_graph=False):
        """Create a saliency map with the same spatial size as ``input``.

        Arguments:
            input (torch.Tensor): input image with shape (1, 3, H, W). Only a
                batch size of 1 is supported.
            class_idx (int): class index for calculating GradCAM++.
                If not specified, the class index that makes the highest model
                prediction score will be used.
            retain_graph (bool): forwarded to ``score.backward``.

        Returns:
            torch.Tensor: detached map of shape (b, 1, H, W), min-max scaled to
            [0, 1]; a constant map comes back as all zeros instead of NaN.
            The map is also stored as ``self.saliency_map``.
        """
        _, _, h, w = input.size()
        logit = self.model(input)
        if class_idx is None:
            # Index of the top-scoring class; meaningful for batch size 1 only.
            score = logit[:, logit.max(1)[-1]].squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model.zero_grad()
        score.backward(retain_graph=retain_graph)
        # Presumably filled in by hooks the base class registers on
        # ``layer_name`` -- NOTE(review): confirm against GradCAM.
        gradients = self.gradients['value']  # dS/dA
        activations = self.activations['value']  # A
        b, k, u, v = gradients.size()

        # GradCAM++ pixel-wise weighting coefficients (alpha).
        alpha_num = gradients.pow(2)
        alpha_denom = gradients.pow(2).mul(2) + \
            activations.mul(gradients.pow(3)).view(b, k, u * v).sum(-1, keepdim=True).view(b, k, 1, 1)
        alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))
        alpha = alpha_num.div(alpha_denom + 1e-7)

        # The paper uses ReLU(dY/dA) == ReLU(exp(S) * dS/dA). exp(S) is a single
        # positive factor that cancels exactly in the min-max normalisation
        # below, but it overflows to inf (or underflows to 0) for large-magnitude
        # scores and turns the whole map into NaN, so it is omitted.
        positive_gradients = F.relu(gradients)
        weights = (alpha * positive_gradients).view(b, k, u * v).sum(-1).view(b, k, 1, 1)

        saliency_map = F.relu((weights * activations).sum(1, keepdim=True))
        # F.upsample is deprecated; F.interpolate performs the same resize.
        saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)

        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        saliency_map = saliency_map - saliency_map_min
        span = saliency_map_max - saliency_map_min
        # Skip the division for a constant map (span == 0), which would yield
        # 0/0 = NaN; NaN inputs still propagate rather than being masked.
        if span > 0:
            saliency_map = saliency_map.div(span)
        saliency_map = saliency_map.detach()

        # Kept as an instance attribute for callers that read it directly.
        self.saliency_map = saliency_map
        return saliency_map

    @property
    def result(self):
        """Return the most recently computed saliency map."""
        return self.saliency_map