1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch .nn import functional as F
5
+ from torch .nn import init as init
6
+ from torch .nn .modules .batchnorm import _BatchNorm
7
+
8
+ from basicsr .utils import get_root_logger
9
+
10
+ try :
11
+ from basicsr .models .ops .dcn import (ModulatedDeformConvPack , modulated_deform_conv )
12
+
13
+ except ImportError :
14
+ print ('Cannot import dcn. Ignore this warning if dcn is not used. '
15
+ 'Otherwise install BasicSR with compiling dcn.' )
16
+ ModulatedDeformConvPack = object
17
+ modulated_deform_conv = None
18
+
19
+
20
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Every submodule of the given module(s) is visited. Conv2d and Linear
    weights are drawn with Kaiming-normal initialization and then multiplied
    by ``scale``. BatchNorm weights are set to 1. Every bias that exists is
    filled with ``bias_fill``. Other module types are left untouched.

    Args:
        module_list (list[nn.Module] | tuple[nn.Module] | nn.Module):
            Modules to be initialized. A single module is treated as a
            one-element list.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function
            (forwarded to ``init.kaiming_normal_``).
    """
    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # Norm layers built with affine=False carry neither weight
                # nor bias, so both must be checked before being touched.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
49
+
50
+
51
def make_layer(basic_block, num_basic_block, **kwarg):
    """Build a sequential stack of identically configured blocks.

    Args:
        basic_block (nn.module): nn.module class (or other callable) that is
            called once per block to create it.
        num_basic_block (int): How many blocks to create.
        **kwarg: Keyword arguments passed to every ``basic_block`` call.

    Returns:
        nn.Sequential: The created blocks, chained in creation order.
    """
    stacked = [basic_block(**kwarg) for _ in range(num_basic_block)]
    return nn.Sequential(*stacked)
65
+
66
+
67
class ResidualBlockNoBN(nn.Module):
    """Two-convolution residual block that uses no batch normalization.

    Data flow:
        ---Conv-ReLU-Conv-+-
         |________________|

    The output of the second convolution is multiplied by ``res_scale``
    before it is added to the block input.

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If True, keep the default PyTorch
            initialization of the two convolutions; otherwise re-initialize
            them through default_init_weights with a scale of 0.1.
            Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        # 3x3 convolutions with stride 1 and padding 1 keep the spatial size.
        self.conv1 = nn.Conv2d(
            num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(
            num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        """Return ``x`` plus the scaled output of the Conv-ReLU-Conv branch."""
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        return x + branch * self.res_scale
96
+
97
+
98
class Upsample(nn.Sequential):
    """Upsample module.

    For a power-of-two scale, log2(scale) stages are chained, each made of a
    3x3 convolution that expands the channels by 4 followed by
    ``PixelShuffle(2)``. For scale 3 a single stage expands the channels by 9
    and applies ``PixelShuffle(3)``. A scale of 1 yields an empty
    (pass-through) sequence.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # The bit trick alone also accepts 0, so positivity is checked first;
        # otherwise scale=0 would fail later with an unrelated math error.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # log2 is exact for powers of two, unlike log(scale, 2).
            for _ in range(int(math.log2(scale))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
119
+
120
+
121
def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with optical flow.

    A grid of pixel coordinates is built, the flow is added to it, the
    result is rescaled to [-1, 1] and handed to ``F.grid_sample``.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value. It is
            added to pixel coordinates; index 0 of the last dimension is the
            horizontal (x) component and index 1 the vertical (y) component.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, height, width = x.size()

    # Pixel-coordinate grid of shape (h, w, 2); last dimension is (x, y).
    rows = torch.arange(0, height).type_as(x)
    cols = torch.arange(0, width).type_as(x)
    grid_y, grid_x = torch.meshgrid(rows, cols)
    base_grid = torch.stack((grid_x, grid_y), 2).float()
    base_grid.requires_grad = False

    # Sampling positions in pixels; broadcasts over the batch dimension.
    sample_pos = base_grid + flow

    # Map [0, size - 1] to [-1, 1]; max(..., 1) avoids dividing by zero for
    # a single-pixel axis.
    norm_x = 2.0 * sample_pos[:, :, :, 0] / max(width - 1, 1) - 1.0
    norm_y = 2.0 * sample_pos[:, :, :, 1] / max(height - 1, 1) - 1.0
    norm_grid = torch.stack((norm_x, norm_y), dim=3)

    # TODO, what if align_corners=False
    return F.grid_sample(
        x,
        norm_grid,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)
164
+
165
+
166
def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow according to ratio or shape.

    The flow values are rescaled together with the spatial size: channel 0
    is multiplied by the width ratio and channel 1 by the height ratio. The
    input tensor is not modified.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0). The resulting size is truncated to an integer.
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, src_h, src_w = flow.size()

    if size_type not in ('ratio', 'shape'):
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')
    if size_type == 'ratio':
        dst_h = int(src_h * sizes[0])
        dst_w = int(src_w * sizes[1])
    else:
        dst_h, dst_w = sizes[0], sizes[1]

    # Work on a copy so that the in-place scaling leaves the caller's
    # tensor untouched.
    scaled = flow.clone()
    scaled[:, 0, :, :] *= dst_w / src_w
    scaled[:, 1, :, :] *= dst_h / src_h

    return F.interpolate(
        input=scaled,
        size=(dst_h, dst_w),
        mode=interp_mode,
        align_corners=align_corners)
210
+
211
+
212
+ # TODO: may write a cpp file
213
def pixel_unshuffle(x, scale):
    """Pixel unshuffle.

    Moves each ``scale`` x ``scale`` spatial neighbourhood into the channel
    dimension, shrinking height and width by ``scale``.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw). Both hh and hw
            must be divisible by ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature, with shape
            (b, c * scale**2, hh // scale, hw // scale).
    """
    batch, chans, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale

    # Split each spatial axis into (coarse position, offset within the
    # neighbourhood), then move both offsets next to the channel axis.
    blocks = x.view(batch, chans, out_h, scale, out_w, scale)
    blocks = blocks.permute(0, 1, 3, 5, 2, 4)
    return blocks.reshape(batch, chans * (scale ** 2), out_h, out_w)
230
+
231
+
232
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes another different
    features to generate offsets and masks.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        """Apply the deformable convolution to ``x``, guided by ``feat``.

        Args:
            x (Tensor): Input passed to the deformable convolution.
            feat (Tensor): Features fed to ``self.conv_offset`` to predict
                the offsets and the mask.

        Returns:
            The result of ``modulated_deform_conv``.
        """
        offset_and_mask = self.conv_offset(feat)
        # Three equal chunks along the channel axis: the first two form the
        # offsets, the third becomes the mask after a sigmoid.
        offset_a, offset_b, mask = torch.chunk(offset_and_mask, 3, dim=1)
        offset = torch.cat((offset_a, offset_b), dim=1)
        mask = torch.sigmoid(mask)

        # Log a warning when the predicted offsets are unusually large.
        mean_abs_offset = torch.mean(torch.abs(offset))
        if mean_abs_offset > 50:
            get_root_logger().warning(
                f'Offset abs mean is {mean_abs_offset}, larger than 50.')

        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
258
+
259
+
260
+ ## Channel Attention (CA) Layer
261
class CALayer(nn.Module):
    """Channel attention (CA) layer.

    The input is globally average-pooled to one value per channel, passed
    through two 1x1 convolutions (channel reduction, ReLU, channel
    expansion, sigmoid) and the resulting per-channel weights multiply the
    input.

    Args:
        channel (int): Number of input (and output) channels.
        reduction (int): Divisor applied to ``channel`` to get the width of
            the hidden 1x1 convolution. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its per-channel attention weights."""
        channel_weight = self.conv_du(self.avg_pool(x))
        return x * channel_weight
278
+
279
+
280
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Create a Conv2d whose padding is half the kernel size.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size; the padding is ``kernel_size // 2``.
        bias (bool): Whether the convolution has a bias. Default: True.

    Returns:
        nn.Conv2d: The configured convolution.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=half_kernel, bias=bias)
284
+
285
+
286
+ ## Residual Channel Attention Block (RCAB)
287
class RCAB(nn.Module):
    """Residual Channel Attention Block (RCAB).

    The body is conv -> [BatchNorm] -> act -> conv -> [BatchNorm] -> CALayer,
    and its output is added to the block input.

    Args:
        conv (callable): Factory called as
            ``conv(n_feat, n_feat, kernel_size, bias=bias)`` for each of the
            two convolutions. Default: default_conv.
        n_feat (int): Number of feature channels. Default: 64.
        kernel_size (int): Kernel size given to ``conv``. Default: 3.
        reduction (int): Reduction passed to CALayer. Default: 1.
        bias (bool): Bias flag given to ``conv``. Default: True.
        bn (bool): If True, a BatchNorm2d follows each convolution.
            Default: False.
        act (nn.Module): Activation placed after the first convolution.
            Default: nn.ReLU(True).
        res_scale (float): Stored as ``self.res_scale``; ``forward`` does not
            apply it. Default: 1.
    """

    def __init__(self, conv=default_conv, n_feat=64, kernel_size=3, reduction=1, bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super().__init__()
        layers = []
        for is_first in (True, False):
            layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            # Only the first convolution is followed by the activation.
            if is_first:
                layers.append(act)
        layers.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return the body output with ``x`` added to it in place."""
        out = self.body(x)
        out += x
        return out
304
+
305
+
306
+ ## Residual Group (RG)
307
class ResidualGroup(nn.Module):
    """Residual Group (RG): a chain of RCABs closed by a convolution.

    The output of the chain is added to the group input.

    Args:
        conv (callable): Convolution factory used by every RCAB and for the
            closing convolution. Default: default_conv.
        n_feat (int): Number of feature channels. Default: 64.
        kernel_size (int): Kernel size of the convolutions. Default: 3.
        reduction (int): Reduction passed to every RCAB. Default: 1.
        act (nn.Module): Accepted but not forwarded; every RCAB is built
            with its own ``nn.ReLU(True)``.
        res_scale (float): Accepted but not forwarded; every RCAB is built
            with ``res_scale=1``.
        n_resblocks (int): Number of RCABs in the group. Default: 30.
    """

    def __init__(self, conv=default_conv, n_feat=64, kernel_size=3, reduction=1, act=nn.ReLU(True), res_scale=1, n_resblocks=30):
        super().__init__()
        stages = []
        for _ in range(n_resblocks):
            stages.append(
                RCAB(conv, n_feat, kernel_size, reduction,
                     bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        # Closing convolution after the last RCAB.
        stages.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*stages)

    def forward(self, x):
        """Return the body output with ``x`` added to it in place."""
        out = self.body(x)
        out += x
        return out
322
+
323
+
324
class RCABWithInputConv(nn.Module):
    """RCAB blocks with a convolution in front.

    The layout is a 3x3 convolution, a LeakyReLU (negative slope 0.1) and
    then ``num_blocks`` RCABs.

    Args:
        in_channels (int): Number of input channels of the first conv.
        out_channels (int): Number of channels of the residual blocks.
            Default: 64.
        num_blocks (int): Number of residual blocks. Default: 30.
    """

    def __init__(self, in_channels, out_channels=64, num_blocks=30):
        super().__init__()

        # The RCABs are created before the input convolution, matching the
        # order in which their parameters are initialized.
        blocks = [
            RCAB(default_conv, out_channels, 3, 1,
                 act=nn.ReLU(True), res_scale=1)
            for _ in range(num_blocks)
        ]

        # a convolution used to match the channels of the residual blocks
        head = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        ]

        self.main = nn.Sequential(*head, *blocks)

    def forward(self, feat):
        """Forward function for RCABWithInputConv.

        Args:
            feat (Tensor): Input feature with shape (n, in_channels, h, w)

        Returns:
            Tensor: Output feature with shape (n, out_channels, h, w)
        """
        return self.main(feat)
0 commit comments