forked from naoto0804/pytorch-inpainting-with-partial-conv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
59 lines (46 loc) · 2.08 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
def gram_matrix(feat):
    """Return the normalized Gram matrix of a batch of feature maps.

    Adapted from
    https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py

    Args:
        feat: 4-D tensor whose size unpacks as ``(b, ch, h, w)``.

    Returns:
        Tensor of shape ``(b, ch, ch)``: for each batch element, the
        channel-by-channel inner products over all ``h * w`` positions,
        divided by ``ch * h * w``.
    """
    (b, ch, h, w) = feat.size()
    # reshape (rather than view) so non-contiguous inputs, e.g. permuted or
    # sliced feature maps, are accepted; it only copies when it has to.
    flat = feat.reshape(b, ch, h * w)
    gram = torch.bmm(flat, flat.transpose(1, 2)) / (ch * h * w)
    return gram
def _mean_abs_or_zero(diff):
    """Return ``mean(|diff|)``, or a zero scalar when *diff* has no elements."""
    if diff.numel() == 0:
        # torch.mean over an empty tensor is NaN, which would poison the total
        # loss. An empty sum is 0 and keeps diff's dtype, device and graph.
        return diff.sum()
    return torch.mean(torch.abs(diff))


def total_variation_loss(image):
    """Return the total-variation loss of a 4-D image batch.

    The image is shifted by one pixel along the last axis and along the
    second-to-last axis; the mean absolute difference of each shift is
    computed and the two means are added.

    Args:
        image: 4-D tensor (it is indexed with four slices).

    Returns:
        Scalar tensor. A direction with no neighbouring pixels (size 1 along
        that axis, or an empty batch) contributes 0 instead of NaN.
    """
    # shift one pixel and get difference (for both x and y direction)
    diff_last_axis = image[:, :, :, :-1] - image[:, :, :, 1:]
    diff_prev_axis = image[:, :, :-1, :] - image[:, :, 1:, :]
    return _mean_abs_or_zero(diff_last_axis) + _mean_abs_or_zero(diff_prev_axis)
class InpaintingLoss(nn.Module):
    """Bundle of unweighted loss terms for image inpainting.

    Args:
        extractor: callable applied to 3-channel image batches. Its result is
            indexed at 0, 1 and 2, and each entry is compared with L1 and
            passed to ``gram_matrix``, so it must provide at least three 4-D
            feature maps.
    """

    def __init__(self, extractor):
        super().__init__()
        self.l1 = nn.L1Loss()
        self.extractor = extractor

    def forward(self, input, mask, output, gt):
        """Compute the individual loss terms.

        Args:
            input: image batch fed to the network; used where ``mask`` is set.
            mask: multiplicative mask; ``mask`` weights the 'valid' term and
                ``1 - mask`` the 'hole' term (presumably binary with 1 marking
                known pixels -- confirm against the caller).
            output: network prediction; its channel dimension (dim 1) must
                be 1 or 3.
            gt: ground-truth image batch.

        Returns:
            dict with keys 'hole', 'valid', 'prc', 'style' and 'tv', in that
            order. The terms are not weighted or summed here.

        Raises:
            ValueError: if ``output`` has neither 1 nor 3 channels.
        """
        channels = output.shape[1]
        if channels not in (1, 3):
            raise ValueError(
                'only gray (1-channel) and RGB (3-channel) images are '
                f'supported, got {channels} channels'
            )

        loss_dict = {}
        # Composite image: keep the input where the mask is set and use the
        # prediction elsewhere.
        output_comp = mask * input + (1 - mask) * output

        loss_dict['hole'] = self.l1((1 - mask) * output, (1 - mask) * gt)
        loss_dict['valid'] = self.l1(mask * output, mask * gt)

        if channels == 3:
            feat_output_comp = self.extractor(output_comp)
            feat_output = self.extractor(output)
            feat_gt = self.extractor(gt)
        else:
            # Gray images: replicate the single channel three times, because
            # the extractor is only ever called with 3-channel input here.
            feat_output_comp = self.extractor(torch.cat([output_comp] * 3, 1))
            feat_output = self.extractor(torch.cat([output] * 3, 1))
            feat_gt = self.extractor(torch.cat([gt] * 3, 1))

        prc = 0.0
        style = 0.0
        for i in range(3):
            f_out, f_comp, f_gt = feat_output[i], feat_output_comp[i], feat_gt[i]

            # Perceptual term: L1 distance between feature maps.
            prc += self.l1(f_out, f_gt)
            prc += self.l1(f_comp, f_gt)

            # Style term: L1 distance between Gram matrices. The ground-truth
            # Gram matrix is shared by both comparisons, so compute it once.
            gram_gt = gram_matrix(f_gt)
            style += self.l1(gram_matrix(f_out), gram_gt)
            style += self.l1(gram_matrix(f_comp), gram_gt)

        loss_dict['prc'] = prc
        loss_dict['style'] = style
        loss_dict['tv'] = total_variation_loss(output_comp)

        return loss_dict