Skip to content

Commit 11019a7

Browse files
committed
Auto commit: 2024-12-16 05:00:02
1 parent 517e2f6 commit 11019a7

File tree

8 files changed

+1387
-29
lines changed

8 files changed

+1387
-29
lines changed

models/arch_util.py

Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
import math
2+
import torch
3+
from torch import nn as nn
4+
from torch.nn import functional as F
5+
from torch.nn import init as init
6+
from torch.nn.modules.batchnorm import _BatchNorm
7+
8+
from basicsr.utils import get_root_logger
9+
10+
try:
11+
from basicsr.models.ops.dcn import (ModulatedDeformConvPack, modulated_deform_conv)
12+
13+
except ImportError:
14+
print('Cannot import dcn. Ignore this warning if dcn is not used. '
15+
'Otherwise install BasicSR with compiling dcn.')
16+
ModulatedDeformConvPack = object
17+
modulated_deform_conv = None
18+
19+
20+
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Conv2d and Linear layers receive Kaiming-normal weights that are then
    multiplied by ``scale``; batch-norm layers receive unit weights. Every
    bias that exists is filled with ``bias_fill``. Other module types are
    left untouched.

    Args:
        module_list (list[nn.Module] | tuple[nn.Module] | nn.Module):
            Modules to be initialized. All sub-modules are visited.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments forwarded to ``init.kaiming_normal_``.
    """
    # Accept a tuple as well as a list; anything else is a single module.
    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # Norm layers created with affine=False carry neither a
                # weight nor a bias, so both must be checked before use.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
49+
50+
51+
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack several freshly constructed copies of one block type.

    Args:
        basic_block (nn.module): Callable (typically an nn.Module class) that
            builds one block when called with ``**kwarg``.
        num_basic_block (int): How many blocks to build.
        kwarg (dict): Keyword arguments passed to every ``basic_block`` call.

    Returns:
        nn.Sequential: The blocks, in construction order.
    """
    # Each iteration calls the constructor again, so no block is shared.
    blocks = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*blocks)
65+
66+
67+
class ResidualBlockNoBN(nn.Module):
    """Residual block without batch normalization.

    Layout::

        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Factor applied to the residual branch before it is
            added to the input. Default: 1.
        pytorch_init (bool): If True, keep PyTorch's default initialization;
            otherwise both convolutions are re-initialized through
            default_init_weights with scale 0.1. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        # Two 3x3, stride-1, padding-1 convolutions keep the spatial size.
        self.conv1 = nn.Conv2d(
            num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(
            num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        """Return ``x`` plus the scaled Conv-ReLU-Conv branch output."""
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        return x + branch * self.res_scale
96+
97+
98+
class Upsample(nn.Sequential):
    """Upsample module built from Conv2d + PixelShuffle stages.

    For a power-of-two scale, one (Conv2d to 4x channels, PixelShuffle(2))
    stage is added per factor of two, so scale 1 yields an empty Sequential.
    For scale 3 a single (Conv2d to 9x channels, PixelShuffle(3)) stage is
    used.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # `scale > 0` is required because 0 also satisfies the bit test and
        # would otherwise surface as an unhelpful "math domain error".
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length gives the exact integer log2, avoiding float
            # rounding in math.log.
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
119+
120+
121+
def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with optical flow.

    Each output pixel (row i, column j) is sampled from ``x`` at the location
    ``(j + flow[..., 0], i + flow[..., 1])`` via ``F.grid_sample``.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value. Channel 0
            is added to the column (x) coordinate, channel 1 to the row (y)
            coordinate.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # Build the pixel-coordinate grid directly in float32 on x's device.
    # This avoids torch.meshgrid (which warns when called without `indexing`
    # on recent PyTorch) and avoids routing the indices through x's dtype,
    # which would round coordinates above 2048 for half-precision inputs.
    rows = torch.arange(h, dtype=torch.float32, device=x.device)
    cols = torch.arange(w, dtype=torch.float32, device=x.device)
    grid_y = rows.view(h, 1).expand(h, w)
    grid_x = cols.view(1, w).expand(h, w)
    grid = torch.stack((grid_x, grid_y), 2)  # (h, w, 2); last dim is (x, y)

    vgrid = grid + flow
    # scale grid to [-1,1]; max(..., 1) guards against a zero divisor when a
    # spatial dimension has size 1.
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(
        x,
        vgrid_scaled,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)

    # TODO, what if align_corners=False
    # NOTE(review): the normalisation above divides by (size - 1), which
    # appears to match the align_corners=True convention only - confirm
    # before relying on align_corners=False.
    return output
164+
165+
166+
def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow according to ratio or shape.

    The flow values are rescaled together with the spatial size: channel 0 is
    multiplied by the width ratio and channel 1 by the height ratio, so the
    displacements stay expressed in pixels of the resized map. The input
    tensor is not modified.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Ignored for modes that do
            not interpolate linearly/cubically (e.g. 'nearest', 'area').
            Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')

    # Work on a copy so the caller's tensor keeps its original values.
    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h

    # F.interpolate raises if align_corners is set for a mode outside the
    # linear/cubic family, so pass None for those modes.
    if interp_mode not in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        align_corners = None
    resized_flow = F.interpolate(
        input=input_flow,
        size=(output_h, output_w),
        mode=interp_mode,
        align_corners=align_corners)
    return resized_flow
210+
211+
212+
# TODO: may write a cpp file
213+
def pixel_unshuffle(x, scale):
    """Rearrange spatial ``scale x scale`` tiles into channels.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw); both spatial
            sizes must be divisible by ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: Feature with shape (b, c * scale^2, hh / scale, hw / scale).
    """
    batch, channels, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    # Split each spatial axis into (coarse position, offset inside the tile).
    tiles = x.view(batch, channels, out_h, scale, out_w, scale)
    # Move the two in-tile offsets next to the channel axis, then fold them
    # into it.
    tiles = tiles.permute(0, 1, 3, 5, 2, 4)
    return tiles.reshape(batch, channels * (scale**2), out_h, out_w)
230+
231+
232+
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Unlike the official DCNv2Pack, which predicts offsets and masks from the
    features being convolved, this variant predicts them from a second,
    separately supplied feature map.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        """Convolve ``x`` using offsets and masks predicted from ``feat``."""
        # conv_offset output is split channel-wise into three equal chunks:
        # the first two form the offsets, the third the modulation mask.
        first, second, mask = torch.chunk(self.conv_offset(feat), 3, dim=1)
        offset = torch.cat((first, second), dim=1)
        mask = torch.sigmoid(mask)

        # Log a warning when the predicted offsets become very large.
        offset_absmean = torch.abs(offset).mean()
        if offset_absmean > 50:
            get_root_logger().warning(
                f'Offset abs mean is {offset_absmean}, larger than 50.')

        return modulated_deform_conv(
            x, offset, mask, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups, self.deformable_groups)
258+
259+
260+
## Channel Attention (CA) Layer
261+
class CALayer(nn.Module):
    """Channel attention (CA) layer.

    The input is globally average-pooled to one value per channel, passed
    through a 1x1 conv bottleneck (channel -> channel // reduction -> channel)
    with a ReLU in between and a sigmoid at the end, and the resulting
    per-channel weights multiply the input.

    Args:
        channel (int): Number of input (and output) channels.
        reduction (int): Channel reduction factor of the bottleneck.
            Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, squeezed, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, kernel_size=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the attention weights."""
        weights = self.conv_du(self.avg_pool(x))
        return x * weights
278+
279+
280+
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a Conv2d whose padding is half the kernel size.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Square kernel size.
        bias (bool): Whether the convolution has a bias. Default: True.

    Returns:
        nn.Conv2d: The convolution layer.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels, out_channels, kernel_size, padding=half_kernel, bias=bias)
284+
285+
286+
## Residual Channel Attention Block (RCAB)
287+
class RCAB(nn.Module):
    """Residual Channel Attention Block (RCAB).

    Body layout: conv -> [BN] -> act -> conv -> [BN] -> CALayer, followed by
    a skip connection that adds the block input.

    Args:
        conv (callable): Factory called as
            ``conv(n_feat, n_feat, kernel_size, bias=bias)``.
            Default: default_conv.
        n_feat (int): Number of feature channels. Default: 64.
        kernel_size (int): Kernel size passed to ``conv``. Default: 3.
        reduction (int): Reduction factor passed to CALayer. Default: 1.
        bias (bool): Whether the convolutions use a bias. Default: True.
        bn (bool): Insert BatchNorm2d after each convolution. Default: False.
        act (nn.Module): Activation placed after the first convolution.
            Default: nn.ReLU(True).
        res_scale (float): Stored on the module; the forward pass does not
            apply it. Default: 1.
    """

    def __init__(self, conv=default_conv, n_feat=64, kernel_size=3, reduction=1, bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super().__init__()
        layers = [conv(n_feat, n_feat, kernel_size, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(n_feat))
        # The activation follows the first convolution only.
        layers.append(act)
        layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm2d(n_feat))
        layers.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return the body output with the input added in place."""
        out = self.body(x)
        # res_scale is intentionally not applied here (unscaled residual).
        out += x
        return out
304+
305+
306+
## Residual Group (RG)
307+
class ResidualGroup(nn.Module):
    """Residual Group (RG): a stack of RCABs plus a closing convolution.

    The body is ``n_resblocks`` RCAB blocks followed by one ``conv`` layer;
    the group input is added to the body output.

    Args:
        conv (callable): Convolution factory used by every RCAB and for the
            closing convolution. Default: default_conv.
        n_feat (int): Number of feature channels. Default: 64.
        kernel_size (int): Kernel size passed to ``conv``. Default: 3.
        reduction (int): Channel-attention reduction factor passed to each
            RCAB. Default: 1.
        act (nn.Module): Activation passed to each RCAB. The same module
            instance is shared by all blocks. Default: nn.ReLU(True).
        res_scale (float): Residual scale passed to each RCAB. Default: 1.
        n_resblocks (int): Number of RCAB blocks. Default: 30.
    """

    def __init__(self, conv=default_conv, n_feat=64, kernel_size=3, reduction=1, act=nn.ReLU(True), res_scale=1, n_resblocks=30):
        super(ResidualGroup, self).__init__()
        # Forward the caller's `act` and `res_scale` instead of silently
        # replacing them with hard-coded values.
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction, bias=True, bn=False,
                act=act, res_scale=res_scale)
            for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        """Return the body output with the group input added in place."""
        res = self.body(x)
        res += x
        return res
322+
323+
324+
class RCABWithInputConv(nn.Module):
    """RCAB blocks with a convolution in front.

    Layout: Conv2d(in_channels -> out_channels, 3x3) -> LeakyReLU(0.1) ->
    ``num_blocks`` RCAB blocks.

    Args:
        in_channels (int): Number of input channels of the first conv.
        out_channels (int): Number of channels of the residual blocks.
            Default: 64.
        num_blocks (int): Number of residual blocks. Default: 30.
    """

    def __init__(self, in_channels, out_channels=64, num_blocks=30):
        super().__init__()

        # The blocks are constructed before the input convolution so that
        # parameter initialization draws random numbers in the same order.
        blocks = [
            RCAB(default_conv, out_channels, 3, 1, act=nn.ReLU(True), res_scale=1)
            for _ in range(num_blocks)
        ]

        # a convolution used to match the channels of the residual blocks
        input_conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True)
        input_act = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        self.main = nn.Sequential(input_conv, input_act, *blocks)

    def forward(self, feat):
        """
        Forward function for RCABWithInputConv.
        Args:
            feat (Tensor): Input feature with shape (n, in_channels, h, w)
        Returns:
            Tensor: Output feature with shape (n, out_channels, h, w)
        """
        return self.main(feat)

models/pixel_shuffle3d

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 8eb64b7ecce011ef2cc8adafb9c687c8593bf6ed

0 commit comments

Comments
 (0)