import torch
import numpy as np
import cv2

# Smallest positive float32 increment; guards divisions and logs below.
eps = np.finfo(np.float32).eps

def kldiv(s_map, gt):
    """KL divergence between a predicted saliency map and the ground-truth
    density, computed per image on (batch, height, width) tensors."""
    assert s_map.size() == gt.size()
    batch_size = s_map.size(0)
    h = s_map.size(1)
    w = s_map.size(2)

    # Rescale both maps so each sums to 1 over the image.
    sum_s_map = torch.sum(s_map.view(batch_size, -1), 1)
    expand_s_map = sum_s_map.view(batch_size, 1, 1).expand(batch_size, h, w)
    assert expand_s_map.size() == s_map.size()

    sum_gt = torch.sum(gt.view(batch_size, -1), 1)
    expand_gt = sum_gt.view(batch_size, 1, 1).expand(batch_size, h, w)
    assert expand_gt.size() == gt.size()

    # eps prevents division by zero and log(0).
    s_map_temp = s_map / (expand_s_map + eps)
    gt_temp = gt / (expand_gt + eps)

    s_map_temp = s_map_temp.view(batch_size, -1)
    gt_temp = gt_temp.view(batch_size, -1)

    # KL(gt || s_map) summed over pixels; one value per image.
    result = gt_temp * torch.log(eps + gt_temp / (s_map_temp + eps))
    return torch.sum(result, 1)

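# Usage note: kldiv and the metrics below return one value per image; for a
# scalar training loss, average over the batch, e.g.
#   loss = kldiv(pred, dens).mean()
# (pred and dens are hypothetical names used for illustration).
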
def normalize_map(s_map):
    """Min-max normalize a saliency map to [0, 1] (as done in the MIT
    saliency benchmark code)."""
    batch_size = s_map.size(0)
    h = s_map.size(1)
    w = s_map.size(2)
    min_s_map = torch.min(s_map.view(batch_size, -1), 1)[0].view(batch_size, 1, 1).expand(batch_size, h, w)
    max_s_map = torch.max(s_map.view(batch_size, -1), 1)[0].view(batch_size, 1, 1).expand(batch_size, h, w)
    # eps guards against a constant map, where max == min.
    norm_s_map = (s_map - min_s_map) / (max_s_map - min_s_map + eps)
    return norm_s_map

def sim(s_map, gt):
    """Similarity (SIM) metric, computed per image on (batch, height, width)
    tensors; gt is the ground-truth saliency map."""
    batch_size = s_map.size(0)
    h = s_map.size(1)
    w = s_map.size(2)

    # Min-max normalize both maps, then rescale each to sum to 1.
    s_map_norm = normalize_map(s_map)
    gt_norm = normalize_map(gt)

    sum_s_map = torch.sum(s_map_norm.view(batch_size, -1), 1)
    expand_s_map = sum_s_map.view(batch_size, 1, 1).expand(batch_size, h, w)
    assert expand_s_map.size() == s_map_norm.size()

    sum_gt = torch.sum(gt_norm.view(batch_size, -1), 1)
    expand_gt = sum_gt.view(batch_size, 1, 1).expand(batch_size, h, w)

    s_map_norm = s_map_norm / (expand_s_map + eps)
    gt_norm = gt_norm / (expand_gt + eps)

    s_map_norm = s_map_norm.view(batch_size, -1)
    gt_norm = gt_norm.view(batch_size, -1)

    # SIM is the histogram intersection of the two distributions.
    return torch.sum(torch.min(s_map_norm, gt_norm), 1)

def cc(s_map, gt):
    """Pearson correlation coefficient (CC) between two saliency maps,
    computed per image on (batch, height, width) tensors."""
    assert s_map.size() == gt.size()
    batch_size = s_map.size(0)
    h = s_map.size(1)
    w = s_map.size(2)

    # Standardize both maps to zero mean and unit variance.
    mean_s_map = torch.mean(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, h, w)
    std_s_map = torch.std(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, h, w)
    mean_gt = torch.mean(gt.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, h, w)
    std_gt = torch.std(gt.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, h, w)

    # eps guards against constant maps with zero standard deviation.
    s_map_norm = (s_map - mean_s_map) / (std_s_map + eps)
    gt_norm = (gt - mean_gt) / (std_gt + eps)

    ab = torch.sum((s_map_norm * gt_norm).view(batch_size, -1), 1)
    aa = torch.sum((s_map_norm * s_map_norm).view(batch_size, -1), 1)
    bb = torch.sum((gt_norm * gt_norm).view(batch_size, -1), 1)
    return ab / torch.sqrt(aa * bb)

def nss(s_map, gt):
    """Normalized Scanpath Saliency (NSS): mean of the standardized saliency
    map over fixated locations; gt is a binary fixation map."""
    if s_map.size() != gt.size():
        # Resize the prediction to the ground-truth resolution; cv2.resize
        # takes (width, height). This branch assumes a batch of one.
        s_map_np = s_map.cpu().squeeze(0).numpy()
        s_map = torch.from_numpy(cv2.resize(s_map_np, (gt.size(2), gt.size(1)))).float().unsqueeze(0)
        s_map = s_map.to(gt.device)
    assert s_map.size() == gt.size()

    batch_size = s_map.size(0)
    h = s_map.size(1)
    w = s_map.size(2)

    # Standardize the saliency map; eps guards against zero variance.
    mean_s_map = torch.mean(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, h, w)
    std_s_map = torch.std(s_map.view(batch_size, -1), 1).view(batch_size, 1, 1).expand(batch_size, h, w)
    s_map_norm = (s_map - mean_s_map) / (std_s_map + eps)

    # Average the standardized values over the fixation points.
    sum_at_fix = torch.sum((s_map_norm * gt).view(batch_size, -1), 1)
    count = torch.sum(gt.view(batch_size, -1), 1)
    return sum_at_fix / count
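

# Minimal smoke test; a sketch, not part of the original module. It exercises
# each metric on random (batch, height, width) tensors. The shapes and the
# synthetic binary fixation map used for nss() are illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    pred = torch.rand(2, 36, 64)                   # predicted saliency maps
    dens = torch.rand(2, 36, 64)                   # ground-truth densities
    fix = (torch.rand(2, 36, 64) > 0.95).float()   # binary fixation maps

    print("KL :", kldiv(pred, dens))   # lower is better
    print("SIM:", sim(pred, dens))     # in [0, 1], higher is better
    print("CC :", cc(pred, dens))      # in [-1, 1], higher is better
    print("NSS:", nss(pred, fix))      # higher is better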