-
Notifications
You must be signed in to change notification settings - Fork 27
/
sagan_models.py
255 lines (210 loc) · 10.3 KB
/
sagan_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch.nn.init import xavier_uniform_
def init_weights(m):
    """Initialise a single module in place; meant for ``model.apply(init_weights)``.

    ``nn.Linear`` and ``nn.Conv2d`` modules get Xavier-uniform weights and a
    zero bias. Modules created with ``bias=False`` (``bias is None``) are
    supported and simply keep no bias. Every other module type is left as is.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        xavier_uniform_(m.weight)
        # ``bias`` is None when the layer was built with bias=False
        # (e.g. ``snconv2d(..., bias=False)``); zeroing it would crash.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Build an ``nn.Conv2d`` and wrap it with ``spectral_norm``.

    All arguments are forwarded unchanged to ``nn.Conv2d``.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    return spectral_norm(conv)
def snlinear(in_features, out_features, bias=True):
    """Build an ``nn.Linear`` wrapped with ``spectral_norm``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the layer has an additive bias. Defaults to True (the
            previous, only behaviour); mirrors the ``bias`` flag of ``snconv2d``.
    """
    return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias))
def sn_embedding(num_embeddings, embedding_dim):
    """Build an ``nn.Embedding`` lookup table wrapped with ``spectral_norm``."""
    table = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    return spectral_norm(table)
class Self_Attn(nn.Module):
    """Self-attention layer with a learned residual gate ``sigma``.

    Queries (theta) come from every spatial position; keys (phi) and values
    (g) are computed on a 2x2 max-pooled grid. The attended values are
    projected back to ``in_channels`` and added to the input scaled by
    ``sigma``, which starts at 0, so the layer is the identity at init.
    """

    def __init__(self, in_channels):
        super(Self_Attn, self).__init__()
        self.in_channels = in_channels
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.sigma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: input feature maps of shape (B, C, H, W) with C == in_channels.

        Returns:
            Tensor of the same shape as ``x``: ``x + sigma * attn_g``.
            Only this tensor is returned; the attention map is not.
        """
        b, ch, h, w = x.size()
        # Theta path: one query per input position -> (B, C/8, H*W)
        theta = self.snconv1x1_theta(x).reshape(b, ch // 8, h * w)
        # Phi path: keys on the pooled grid. Take the pooled size from the
        # tensor itself (floor(H/2) * floor(W/2)) instead of H*W//4, which is
        # wrong -- and makes reshape fail -- when H or W is odd.
        phi = self.maxpool(self.snconv1x1_phi(x))
        n_pooled = phi.size(2) * phi.size(3)
        phi = phi.reshape(b, ch // 8, n_pooled)
        # Attention map: (B, H*W, n_pooled), softmax over the pooled keys
        attn = self.softmax(torch.bmm(theta.permute(0, 2, 1), phi))
        # g path: values on the same pooled grid -> (B, C/2, n_pooled)
        g = self.maxpool(self.snconv1x1_g(x)).reshape(b, ch // 2, n_pooled)
        # Weighted sum of values for every query position -> (B, C/2, H*W)
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = self.snconv1x1_attn(attn_g.reshape(b, ch // 2, h, w))
        return x + self.sigma * attn_g
class ConditionalBatchNorm2d(nn.Module):
    """BatchNorm2d whose per-channel scale and shift are selected by class label.

    Adapted from https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
    The inner BatchNorm2d has no affine parameters; instead each row of an
    ``nn.Embedding`` of width ``2 * num_features`` stores [gamma | beta] for
    one class.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        # Start every class as the identity transform: gamma = 1, beta = 0.
        with torch.no_grad():
            self.embed.weight[:, :num_features].fill_(1.)
            self.embed.weight[:, num_features:].zero_()

    def forward(self, x, y):
        """Normalise ``x`` and apply the affine transform of each sample's class.

        Args:
            x: feature maps; the inner BatchNorm2d expects (N, num_features, H, W).
            y: class indices; the ``chunk(2, 1)`` split assumes a 1-D tensor
               of length N.
        """
        normalized = self.bn(x)
        scale, shift = self.embed(y).chunk(2, 1)
        scale = scale.view(-1, self.num_features, 1, 1)
        shift = shift.view(-1, self.num_features, 1, 1)
        return scale * normalized + shift
class GenBlock(nn.Module):
    """Class-conditional residual block that doubles the spatial size.

    Main path:  CondBN -> ReLU -> 2x nearest upsample -> 3x3 SN-conv
                -> CondBN -> ReLU -> 3x3 SN-conv
    Shortcut:   2x nearest upsample -> 1x1 SN-conv (always present)
    """

    def __init__(self, in_channels, out_channels, num_classes):
        super(GenBlock, self).__init__()
        self.cond_bn1 = ConditionalBatchNorm2d(in_channels, num_classes)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.cond_bn2 = ConditionalBatchNorm2d(out_channels, num_classes)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, labels):
        """Return main-path + shortcut output at twice the input's H and W."""
        # Shortcut branch reads the untouched input; the in-place ReLUs below
        # only ever act on freshly computed tensors, never on ``x``.
        skip = self.snconv2d0(F.interpolate(x, scale_factor=2, mode='nearest'))

        h = self.relu(self.cond_bn1(x, labels))
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.snconv2d1(h)
        h = self.relu(self.cond_bn2(h, labels))
        h = self.snconv2d2(h)
        return h + skip
class Generator(nn.Module):
    """Class-conditional SAGAN generator: (z, labels) -> 3-channel image in [-1, 1].

    A linear layer maps z to a (g_conv_dim*16, 4, 4) map; five GenBlocks each
    double H and W (4 -> 128), with self-attention after the third block.
    """

    def __init__(self, z_dim, g_conv_dim, num_classes):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.g_conv_dim = g_conv_dim
        d = g_conv_dim
        # Submodules are created in this exact order so that parameter
        # initialisation (and hence results under a fixed seed) is unchanged.
        self.snlinear0 = snlinear(in_features=z_dim, out_features=d * 16 * 4 * 4)
        self.block1 = GenBlock(d * 16, d * 16, num_classes)
        self.block2 = GenBlock(d * 16, d * 8, num_classes)
        self.block3 = GenBlock(d * 8, d * 4, num_classes)
        self.self_attn = Self_Attn(d * 4)
        self.block4 = GenBlock(d * 4, d * 2, num_classes)
        self.block5 = GenBlock(d * 2, d, num_classes)
        self.bn = nn.BatchNorm2d(d, eps=1e-5, momentum=0.0001, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=d, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()
        self.apply(init_weights)

    def forward(self, z, labels):
        """Map noise ``z`` (n x z_dim) and ``labels`` to an n x 3 x 128 x 128 image."""
        h = self.snlinear0(z).view(-1, self.g_conv_dim * 16, 4, 4)
        # 4x4 -> 8x8 -> 16x16 -> 32x32
        for block in (self.block1, self.block2, self.block3):
            h = block(h, labels)
        h = self.self_attn(h)
        # 32x32 -> 64x64 -> 128x128
        for block in (self.block4, self.block5):
            h = block(h, labels)
        h = self.relu(self.bn(h))
        return self.tanh(self.snconv2d1(h))
class DiscOptBlock(nn.Module):
    """First discriminator block: residual, halves H and W, no leading ReLU.

    Main path:  3x3 SN-conv -> ReLU -> 3x3 SN-conv -> AvgPool2d(2)
    Shortcut:   AvgPool2d(2) -> 1x1 SN-conv
    (The Discriminator applies it directly to the 3-channel input.)
    """

    def __init__(self, in_channels, out_channels):
        super(DiscOptBlock, self).__init__()
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the residual sum at half the input resolution."""
        # The in-place ReLU acts on the first conv's output, never on ``x``.
        residual = self.snconv2d2(self.relu(self.snconv2d1(x)))
        residual = self.downsample(residual)
        shortcut = self.snconv2d0(self.downsample(x))
        return residual + shortcut
class DiscBlock(nn.Module):
    """Pre-activation residual discriminator block with optional 2x downsampling.

    Main path:  ReLU -> 3x3 SN-conv -> ReLU -> 3x3 SN-conv [-> AvgPool2d(2)]
    Shortcut:   [1x1 SN-conv when in_channels != out_channels] [-> AvgPool2d(2)]
    The shortcut always receives the *un-activated* block input.
    """

    def __init__(self, in_channels, out_channels):
        super(DiscBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.ch_mismatch = False
        if in_channels != out_channels:
            self.ch_mismatch = True
            # Only created when needed, so state_dict keys are unchanged.
            self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, downsample=True):
        """Return main-path + shortcut; halves H and W when ``downsample`` is True."""
        x0 = x
        # Out-of-place ReLU: ``x0`` is the same tensor object as ``x``, so an
        # in-place ReLU here would also activate the shortcut input (and
        # mutate the caller's tensor / fail on leaf tensors requiring grad).
        x = F.relu(x)
        x = self.snconv2d1(x)
        x = self.relu(x)  # in-place is safe: the conv output is not aliased
        x = self.snconv2d2(x)
        if downsample:
            x = self.downsample(x)
        # The 1x1 projection exists only when channel counts differ; with
        # equal counts the shortcut is the identity (this also keeps
        # downsample=True usable when in_channels == out_channels).
        if self.ch_mismatch:
            x0 = self.snconv2d0(x0)
        if downsample:
            x0 = self.downsample(x0)
        return x + x0
class Discriminator(nn.Module):
    """Projection discriminator: (image, labels) -> one real-valued score per sample.

    Score = linear(pooled features) + <pooled features, embedding(label)>,
    where the features are ReLU'd and summed over H and W.
    """

    def __init__(self, d_conv_dim, num_classes):
        super(Discriminator, self).__init__()
        self.d_conv_dim = d_conv_dim
        d = d_conv_dim
        # Submodules are created in this exact order so that parameter
        # initialisation (and hence results under a fixed seed) is unchanged.
        self.opt_block1 = DiscOptBlock(3, d)
        self.block1 = DiscBlock(d, d * 2)
        self.self_attn = Self_Attn(d * 2)
        self.block2 = DiscBlock(d * 2, d * 4)
        self.block3 = DiscBlock(d * 4, d * 8)
        self.block4 = DiscBlock(d * 8, d * 16)
        self.block5 = DiscBlock(d * 16, d * 16)
        self.relu = nn.ReLU(inplace=True)
        self.snlinear1 = snlinear(in_features=d * 16, out_features=1)
        self.sn_embedding1 = sn_embedding(num_classes, d * 16)
        self.apply(init_weights)
        # init_weights skips embeddings, so initialise the projection table here.
        xavier_uniform_(self.sn_embedding1.weight)

    def forward(self, x, labels):
        """Score images ``x`` (n x 3 x 128 x 128) against ``labels``; returns shape (n,)."""
        # 128 -> 64 -> 32, then self-attention at 32x32
        h = self.self_attn(self.block1(self.opt_block1(x)))
        # 32 -> 16 -> 8 -> 4
        for block in (self.block2, self.block3, self.block4):
            h = block(h)
        h = self.relu(self.block5(h, downsample=False))
        pooled = h.sum(dim=(2, 3))  # n x d_conv_dim*16
        unconditional = self.snlinear1(pooled).squeeze(1)  # n
        projection = (pooled * self.sn_embedding1(labels)).sum(dim=1)  # n
        return unconditional + projection