-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
67 lines (57 loc) · 1.88 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch.nn as nn
import torch
import numpy as np
from torch.autograd import Variable
from torch import nn
from collections import namedtuple
from torchvision import models
import matplotlib.pyplot as plt
import util
import torch.nn.functional as F
import projection
class GANLoss(nn.Module):
    """Log-likelihood GAN loss for predictions that are already probabilities.

    The loss is ``-mean(log(p))`` when the target is "real" and
    ``-mean(log(1 - p))`` when the target is "fake", with a small epsilon
    added inside the logarithm.

    The computation lives in ``forward`` (not in an overridden ``__call__``)
    so that ``nn.Module.__call__`` keeps running registered hooks. Calling
    the instance as ``loss(input, target_is_real)`` works as before.

    Args:
        target_real_label: Accepted for interface compatibility; not used
            by the computation.
        target_fake_label: Accepted for interface compatibility; not used
            by the computation.
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        # Added inside log() so the argument is never exactly zero.
        self.eps = 1e-9

    def forward(self, input, target_is_real):
        """Return the scalar loss for ``input``.

        Args:
            input: Tensor of predictions; presumably values in [0, 1]
                (e.g. the output of a sigmoid) -- TODO confirm with callers.
            target_is_real: Truthy to score against the "real" target,
                falsy to score against the "fake" target.

        Returns:
            A zero-dimensional tensor holding the mean negative log value.
        """
        if target_is_real:
            return -1. * torch.mean(torch.log(input + self.eps))
        return -1. * torch.mean(torch.log(1 - input + self.eps))
class PatchDiscriminator(nn.Module):
    """Stack of reflection-padded 4x4 convolutions ending in a sigmoid map.

    Layout of ``self.model`` (an ``nn.Sequential``):

    1. pad -> conv(input_nc -> ndf, stride 2) -> LeakyReLU(0.2)
    2. ``n_layers`` times: pad -> conv -> BatchNorm2d -> Dropout2d(0.5)
       -> LeakyReLU(0.2). The output width is ``ndf * min(2**(i+1), 8)``;
       every stage uses stride 2 except the last, which uses stride 1.
    3. pad -> conv(-> 1 channel, stride 1) -> Sigmoid

    Args:
        input_nc: Number of channels of the input tensor.
        ndf: Number of filters produced by the first convolution.
        n_layers: Number of intermediate normalised convolution stages.
    """

    def __init__(self, input_nc, ndf=28, n_layers=3):
        super().__init__()

        def padded_conv(in_channels, out_channels, step):
            # Reflection pad of 1 followed by an unpadded 4x4 convolution.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=4, stride=step, padding=0),
            ]

        # Stem: no normalisation or dropout on the first stage.
        layers = padded_conv(input_nc, ndf, 2)
        layers.append(nn.LeakyReLU(0.2, True))

        width = ndf
        final_idx = n_layers - 1
        for idx in range(n_layers):
            # Channel multiplier doubles per stage and is capped at 8.
            next_width = ndf * min(2 ** (idx + 1), 8)
            step = 2 if idx != final_idx else 1
            layers.extend(padded_conv(width, next_width, step))
            layers.extend([
                nn.BatchNorm2d(next_width),
                nn.Dropout2d(0.5),
                nn.LeakyReLU(0.2, True),
            ])
            width = next_width

        # Head: collapse to a single channel and squash with a sigmoid.
        layers.extend(padded_conv(width, 1, 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the sequential stack and return its output."""
        return self.model(input)