-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathCommon.py
189 lines (156 loc) · 7.43 KB
/
Common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
from contextlib import contextmanager
from math import sqrt, log
import torch
import torch.nn as nn
# import warnings
# warnings.simplefilter('ignore')
class BaseModule(nn.Module):
    """Common base class adding init, loading and bookkeeping helpers to ``nn.Module``.

    Subclasses may assign their activation module to ``self.act_fn`` so that
    ``set_activation_inplace`` can temporarily toggle it.
    """

    def __init__(self):
        # Initialise nn.Module first so attribute registration is fully set up
        # before any attribute is assigned.
        super(BaseModule, self).__init__()
        self.act_fn = None

    @staticmethod
    def _has_trainable_weight(module):
        """Return True when ``module.weight`` exists and requires gradients."""
        weight = getattr(module, "weight", None)
        return weight is not None and weight.requires_grad

    def selu_init_params(self):
        """Initialise Conv2d / BatchNorm2d / Linear sub-modules with trainable weights.

        Conv2d and Linear weights are drawn from N(0, 1 / sqrt(weight.numel()))
        and their biases (when present) are zeroed; BatchNorm2d weights are set
        to 1 and biases to 0. Modules whose weight is frozen or absent
        (e.g. ``BatchNorm2d(affine=False)``) are skipped.
        """
        for m in self.modules():
            if not self._has_trainable_weight(m):
                continue
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight, 0.0, 1.0 / sqrt(m.weight.numel()))
                # Layers built with bias=False have no bias tensor to reset.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def initialize_weights_xavier_uniform(self):
        """Xavier-uniform initialise trainable Conv2d weights; reset BatchNorm2d to (1, 0).

        Conv2d biases (when present) are zeroed. Modules whose weight is frozen
        or absent are skipped.
        """
        for m in self.modules():
            if not self._has_trainable_weight(m):
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def load_state_dict(self, state_dict, strict=True, self_state=False):
        """Best-effort load: copy every entry of ``state_dict`` whose name is known.

        Unlike ``nn.Module.load_state_dict`` this never raises: entries that
        fail to copy and entries whose name is unknown are reported on stdout
        and skipped.

        Args:
            state_dict: mapping from parameter/buffer name to tensor.
            strict: accepted for signature compatibility; not consulted.
            self_state: optional destination mapping to copy into; when falsy
                the module's own ``state_dict()`` is used.
        """
        own_state = self_state if self_state else self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                print("Parameter {} is not in the model. ".format(name))
                continue
            try:
                own_state[name].copy_(param.data)
            except Exception as e:  # deliberate best-effort: report and keep going
                print("Parameter {} fails to load.".format(name))
                print("-----------------------------------------")
                print(e)

    @contextmanager
    def set_activation_inplace(self):
        """Context manager that switches ``self.act_fn`` to in-place mode (saves memory).

        The previous ``inplace`` value is restored on exit, including when the
        body raises. It is a no-op when ``act_fn`` is unset or has no
        ``inplace`` attribute.
        """
        act_fn = getattr(self, 'act_fn', None)
        if act_fn is None or not hasattr(act_fn, 'inplace'):
            yield
            return
        previous = act_fn.inplace
        act_fn.inplace = True
        try:
            yield
        finally:
            # Always undo the toggle so a failure inside the block cannot leave
            # the activation permanently in-place.
            act_fn.inplace = previous

    def total_parameters(self):
        """Print total and trainable parameter counts and return the total."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("Total parameters : {}. Trainable parameters : {}".format(total, trainable))
        return total

    def forward(self, *x):
        """Subclasses must implement the forward pass."""
        raise NotImplementedError
class ResidualFixBlock(BaseModule):
    """Residual block: ``act(conv(act(conv(x))) + x)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel_size, padding, dilation, groups: forwarded to both convolutions.
        activation: activation module used between the convolutions and after
            the residual sum; also stored as ``self.act_fn``.
        conv: convolution layer factory with an ``nn.Conv2d``-like signature.

    Note:
        The skip connection adds the input to the conv output, so the two must
        have the same (or broadcastable) shapes.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1,
                 groups=1, activation=nn.SELU(), conv=nn.Conv2d):
        super(ResidualFixBlock, self).__init__()
        self.act_fn = activation
        self.m = nn.Sequential(
            conv(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, groups=groups),
            activation,
            # The second conv consumes the first conv's output, so its input
            # width is out_channels (not in_channels).
            conv(out_channels, out_channels, kernel_size, padding=padding, dilation=dilation, groups=groups),
        )

    def forward(self, x):
        """Apply the two-conv branch, add the skip connection, then activate."""
        out = self.m(x)
        return self.act_fn(out + x)
class ConvBlock(BaseModule):
    """A single convolution followed by an activation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size, padding, dilation, groups: forwarded to the convolution.
        activation: activation module applied to the convolution output.
        conv: convolution layer factory with an ``nn.Conv2d``-like signature.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, groups=1,
                 activation=nn.SELU(), conv=nn.Conv2d):
        super(ConvBlock, self).__init__()
        conv_layer = conv(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        layers = [conv_layer, activation]
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        """Run the input through the conv + activation pipeline."""
        features = self.m(x)
        return features
class UpSampleBlock(BaseModule):
    """Upsample by 2, 4 or 8 using stacked (conv -> activation -> PixelShuffle(2)) stages.

    Args:
        channels: number of input (and output) channels.
        scale: total upscaling factor; must be 2, 4 or 8.
        activation: activation module placed after each stage's convolution.
        atrous_rate: dilation of each 3x3 convolution; also used as padding.
        conv: convolution layer factory with an ``nn.Conv2d``-like signature.
    """

    def __init__(self, channels, scale, activation, atrous_rate=1, conv=nn.Conv2d):
        assert scale in [2, 4, 8], "Currently UpSampleBlock supports 2, 4, 8 scaling"
        super(UpSampleBlock, self).__init__()
        # One x2 stage per power of two; round so a float log result can never
        # be truncated to one stage too few.
        num_stages = int(round(log(scale, 2)))
        # Build a fresh stage per step. Repeating one stage instance would make
        # every step share the same convolution weights.
        stages = [
            nn.Sequential(
                conv(channels, 4 * channels, kernel_size=3, padding=atrous_rate, dilation=atrous_rate),
                activation,
                nn.PixelShuffle(2),
            )
            for _ in range(num_stages)
        ]
        self.m = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through every x2 upsampling stage in order."""
        return self.m(x)
class SpatialChannelSqueezeExcitation(BaseModule):
    """Concurrent spatial and channel squeeze-and-excitation.

    The input is gated once per channel and once per spatial position; the two
    gated tensors are summed.

    References:
        https://arxiv.org/abs/1709.01507
        https://arxiv.org/pdf/1803.02579v1.pdf

    Args:
        in_channel: number of channels of the input tensor.
        reduction: divisor for the channel bottleneck width (16 was selected
            by experiment in the paper).
        activation: activation module used inside the channel bottleneck.
    """

    def __init__(self, in_channel, reduction=16, activation=nn.ReLU()):
        super(SpatialChannelSqueezeExcitation, self).__init__()
        # Keep at least 4 hidden nodes so the bottleneck never collapses to 1.
        hidden = max(in_channel // reduction, 4)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        squeeze = nn.Linear(in_channel, hidden)
        expand = nn.Linear(hidden, in_channel)
        self.channel_excite = nn.Sequential(squeeze, activation, expand, nn.Sigmoid())

        pointwise = nn.Conv2d(in_channel, 1, kernel_size=1, stride=1, padding=0, bias=False)
        self.spatial_excite = nn.Sequential(pointwise, nn.Sigmoid())

    def forward(self, x):
        """Return ``x * channel_gate + x * spatial_gate`` for a 4-D input."""
        batch, channels, _, _ = x.size()

        # Channel branch: global average pool -> bottleneck MLP -> per-channel gate.
        pooled = self.avg_pool(x).view(batch, channels)
        channel_gate = self.channel_excite(pooled).view(batch, channels, 1, 1)
        channel_scaled = x * channel_gate

        # Spatial branch: 1x1 conv -> per-position gate.
        spatial_gate = self.spatial_excite(x)
        spatial_scaled = x * spatial_gate

        return channel_scaled + spatial_scaled
class PartialConv(nn.Module):
    """Conv2d with partial-convolution based padding.

    Near the borders part of each window lies in the zero padding. The conv
    response (without bias) is rescaled by
    ``window_size / number_of_in-bounds_taps`` so border outputs are not
    attenuated by the padded zeros.

    References:
        Image Inpainting for Irregular Holes Using Partial Convolutions
        http://masc.cs.gmu.edu/wiki/partialconv/show?time=2018-05-24+21%3A41%3A10
        https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/net.py
        https://github.com/SeitaroShinagawa/chainer-partial_convolution_image_inpainting/blob/master/common/net.py
        https://github.com/NVIDIA/partialconv/blob/master/models/pd_resnet.py

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded to the feature ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(PartialConv, self).__init__()
        self.feature_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                                      padding, dilation, groups, bias)
        # The mask conv always maps 1 channel to 1 channel, so it must not
        # inherit ``groups`` (any value > 1 would be rejected by Conv2d).
        self.mask_conv = nn.Conv2d(1, 1, kernel_size, stride,
                                   padding, dilation, groups=1, bias=False)
        self.window_size = self.mask_conv.kernel_size[0] * self.mask_conv.kernel_size[1]
        # All-ones, frozen kernel: convolving a ones image counts in-bounds taps.
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Apply the convolution and re-weight the response by the border ratio.

        Args:
            x: 4-D input; dimensions 2 and 3 are taken as the spatial size.

        Returns:
            ``(conv(x) - bias) * ratio + bias`` where ``ratio`` is
            ``window_size / in-bounds taps`` (0 where a window lies entirely
            in the padding).
        """
        output = self.feature_conv(x)

        with torch.no_grad():
            # Match the mask kernel's dtype so the module keeps working after
            # .half() / .double().
            ones = torch.ones(1, 1, x.size(2), x.size(3), device=x.device,
                              dtype=self.mask_conv.weight.dtype)
            valid_taps = self.mask_conv(ones)
            # A window that falls entirely inside the padding has zero valid
            # taps; use ratio 0 there instead of dividing by zero.
            ratio = torch.where(valid_taps > 0,
                                self.window_size / valid_taps,
                                torch.zeros_like(valid_taps))

        bias = self.feature_conv.bias
        if bias is None:
            return output * ratio
        # Only the data-dependent part of the response is rescaled.
        bias = bias.view(1, -1, 1, 1)
        return (output - bias) * ratio + bias