""" Adapted from the original implementation. """
import collections
import dataclasses
from typing import List
import torch
@dataclasses.dataclass
class VoVNetParams:
stem_out: int
    stage_conv_ch: List[int]  # Channel depth of each stage's conv layers.
    stage_out_ch: List[int]  # Channel depth of each stage's concatenated output.
layer_per_block: int
block_per_stage: List[int]
    dw: bool


_STAGE_SPECS = {
"vovnet-19-slim-dw": VoVNetParams(
64, [64, 80, 96, 112], [112, 256, 384, 512], 3, [1, 1, 1, 1], True
),
"vovnet-19-dw": VoVNetParams(
64, [128, 160, 192, 224], [256, 512, 768, 1024], 3, [1, 1, 1, 1], True
),
"vovnet-19-slim": VoVNetParams(
128, [64, 80, 96, 112], [112, 256, 384, 512], 3, [1, 1, 1, 1], False
),
"vovnet-19": VoVNetParams(
128, [128, 160, 192, 224], [256, 512, 768, 1024], 3, [1, 1, 1, 1], False
),
"vovnet-39": VoVNetParams(
128, [128, 160, 192, 224], [256, 512, 768, 1024], 5, [1, 1, 2, 2], False
),
"vovnet-57": VoVNetParams(
128, [128, 160, 192, 224], [256, 512, 768, 1024], 5, [1, 1, 4, 3], False
),
"vovnet-99": VoVNetParams(
128, [128, 160, 192, 224], [256, 512, 768, 1024], 5, [1, 3, 9, 3], False
),
}
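# Example: "vovnet-39" pairs a 128-channel stem with stage convs of
# [128, 160, 192, 224] channels, 5 conv layers per block, and [1, 1, 2, 2]
# blocks across its four stages, without depthwise convolutions.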

_BN_MOMENTUM = 1e-1
_BN_EPS = 1e-5


def dw_conv(
in_channels: int, out_channels: int, stride: int = 1
) -> List[torch.nn.Module]:
""" Depthwise separable pointwise linear convolution. """
return [
torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=3,
padding=1,
stride=stride,
groups=in_channels,
bias=False,
),
torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
torch.nn.BatchNorm2d(out_channels, eps=_BN_EPS, momentum=_BN_MOMENTUM),
torch.nn.ReLU(inplace=True),
    ]


def conv(
in_channels: int,
out_channels: int,
stride: int = 1,
groups: int = 1,
kernel_size: int = 3,
padding: int = 1,
) -> List[torch.nn.Module]:
""" 3x3 convolution with padding."""
return [
torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
),
torch.nn.BatchNorm2d(out_channels, eps=_BN_EPS, momentum=_BN_MOMENTUM),
torch.nn.ReLU(inplace=True),
    ]


def pointwise(in_channels: int, out_channels: int) -> List[torch.nn.Module]:
    """ Pointwise (1x1) convolution followed by batch norm and ReLU. """
return [
torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
torch.nn.BatchNorm2d(out_channels, eps=_BN_EPS, momentum=_BN_MOMENTUM),
torch.nn.ReLU(inplace=True),
    ]


class ESE(torch.nn.Module):
    """Effective Squeeze-Excitation (eSE), adapted from the Squeeze-Excitation
    module used in EfficientNet. The idea is to not squeeze the number of
    channels, which keeps more channel information."""
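    # Shapes: (N, C, H, W) -> avg_pool -> (N, C, 1, 1) -> fc -> sigmoid gate,
    # broadcast-multiplied against the input to re-weight each channel.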
def __init__(self, channel: int) -> None:
super().__init__()
self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Conv2d(channel, channel, kernel_size=1)  # Acts as a linear layer.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.avg_pool(x)
out = self.fc(out)
        return torch.sigmoid(out) * x


class OSA(torch.nn.Module):
def __init__(
self,
in_channels: int,
stage_channels: int,
concat_channels: int,
layer_per_block: int,
use_depthwise: bool = False,
) -> None:
""" Implementation of an OSA layer which takes the output of its conv layers and
concatenates them into one large tensor which is passed to the next layer. The
goal with this concatenation is to preserve information flow through the model
layers. This also ends up helping with small object detection.
Args:
in_channels: Channel depth of the input to the OSA block.
stage_channels: Channel depth to reduce the input.
concat_channels: Channel depth to force on the concatenated output of the
comprising layers in a block.
layer_per_block: The number of layers in this OSA block.
use_depthwise: Wether to use depthwise separable pointwise linear convs.
"""
super().__init__()
# Keep track of the size of the final concatenation tensor.
aggregated = in_channels
self.isReduced = in_channels != stage_channels
# If this OSA block is not the first in the OSA stage, we can
# leverage the fact that subsequent OSA blocks have the same input and
# output channel depth, concat_channels. This lets us reuse the concept of
# a residual from ResNet models.
self.identity = in_channels == concat_channels
self.layers = torch.nn.ModuleList()
self.use_depthwise = use_depthwise
conv_op = dw_conv if use_depthwise else conv
# If this model uses depthwise and the input channel depth needs to be reduced
# to the stage_channels size, add a pointwise layer to adjust the depth. If the
# model is not depthwise, let the first OSA layer do the resizing.
if self.use_depthwise and self.isReduced:
self.conv_reduction = torch.nn.Sequential(
*pointwise(in_channels, stage_channels)
)
in_channels = stage_channels
for _ in range(layer_per_block):
self.layers.append(
torch.nn.Sequential(*conv_op(in_channels, stage_channels))
)
in_channels = stage_channels
# feature aggregation
aggregated += layer_per_block * stage_channels
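        # e.g. for "vovnet-39" stage 2: 128 (input) + 5 layers * 128 = 768 aggregated
        # channels, which the pointwise conv below reduces to concat_channels (256).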
self.concat = torch.nn.Sequential(*pointwise(aggregated, concat_channels))
        self.ese = ESE(concat_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.identity:
identity_feat = x
output = [x]
if self.use_depthwise and self.isReduced:
x = self.conv_reduction(x)
        # Loop through all the layers, collecting each intermediate output so
        # they can be aggregated by the concatenation below.
for layer in self.layers:
x = layer(x)
output.append(x)
x = torch.cat(output, dim=1)
xt = self.concat(x)
xt = self.ese(xt)
if self.identity:
xt += identity_feat
        return xt


class OSA_stage(torch.nn.Sequential):
def __init__(
self,
in_channels: int,
stage_channels: int,
concat_channels: int,
block_per_stage: int,
layer_per_block: int,
stage_num: int,
use_depthwise: bool = False,
) -> None:
"""An OSA stage which is comprised of OSA blocks.
Args:
in_channels: Channel depth of the input to the OSA stage.
stage_channels: Channel depth to reduce the input of the block to.
concat_channels: Channel depth to force on the concatenated output of the
comprising layers in a block.
block_per_stage: Number of OSA blocks in this stage.
layer_per_block: The number of layers per OSA block.
stage_num: The OSA stage index.
            use_depthwise: Whether to use depthwise separable convolutions.
"""
super().__init__()
# Use maxpool to downsample the input to this OSA stage.
self.add_module(
"Pooling", torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
for idx in range(block_per_stage):
            # Add the OSA blocks. If this is the first block in the stage, use the
            # given in_channels; the remaining blocks take the concatenated channel
            # depth output by the previous block.
self.add_module(
f"OSA{stage_num}_{idx + 1}",
OSA(
in_channels if idx == 0 else concat_channels,
stage_channels,
concat_channels,
layer_per_block,
use_depthwise=use_depthwise,
),
            )


class VoVNet(torch.nn.Sequential):
def __init__(
self, model_name: str, num_classes: int = 10, input_channels: int = 3
) -> None:
"""
Args:
model_name: Which model to create.
num_classes: The number of classification classes.
input_channels: The number of input channels.
Usage:
>>> net = VoVNet("vovnet-19-slim-dw", num_classes=1000)
>>> with torch.no_grad():
... out = net(torch.randn(1, 3, 512, 512))
>>> print(out.shape)
torch.Size([1, 1000])
>>> net = VoVNet("vovnet-19-dw", num_classes=1000)
>>> with torch.no_grad():
... out = net(torch.randn(1, 3, 512, 512))
>>> print(out.shape)
torch.Size([1, 1000])
>>> net = VoVNet("vovnet-19-slim", num_classes=1000)
>>> with torch.no_grad():
... out = net(torch.randn(1, 3, 512, 512))
>>> print(out.shape)
torch.Size([1, 1000])
>>> net = VoVNet("vovnet-19", num_classes=1000)
>>> with torch.no_grad():
... out = net(torch.randn(1, 3, 512, 512))
>>> print(out.shape)
torch.Size([1, 1000])
>>> net = VoVNet("vovnet-39", num_classes=1000)
>>> with torch.no_grad():
... out = net(torch.randn(1, 3, 512, 512))
>>> print(out.shape)
torch.Size([1, 1000])
>>> net = VoVNet("vovnet-57", num_classes=1000)
>>> with torch.no_grad():
... out = net(torch.randn(1, 3, 512, 512))
>>> print(out.shape)
torch.Size([1, 1000])
>>> net = VoVNet("vovnet-99", num_classes=1000)
>>> with torch.no_grad():
... out = net(torch.randn(1, 3, 512, 512))
>>> print(out.shape)
torch.Size([1, 1000])
"""
super().__init__()
assert model_name in _STAGE_SPECS, f"{model_name} not supported."
stem_ch = _STAGE_SPECS[model_name].stem_out
config_stage_ch = _STAGE_SPECS[model_name].stage_conv_ch
config_concat_ch = _STAGE_SPECS[model_name].stage_out_ch
block_per_stage = _STAGE_SPECS[model_name].block_per_stage
layer_per_block = _STAGE_SPECS[model_name].layer_per_block
conv_type = dw_conv if _STAGE_SPECS[model_name].dw else conv
# Construct the stem.
stem = conv(input_channels, 64, stride=2)
stem += conv_type(64, 64)
        # The original implementation uses stride=2 on the conv below. Here we
        # instead pool at every OSA stage, whereas the original skips pooling
        # at the first OSA stage.
stem += conv_type(64, stem_ch)
self.model = torch.nn.Sequential()
self.model.add_module("stem", torch.nn.Sequential(*stem))
self._out_feature_channels = [stem_ch]
        # Collect the input depth for each OSA stage: the stem's output followed
        # by the concatenated channel depth output by each preceding stage.
in_ch_list = [stem_ch] + config_concat_ch[:-1]
        # Add the OSA stages (typically four).
for idx in range(len(config_stage_ch)):
self.model.add_module(
f"OSA_{(idx + 2)}",
OSA_stage(
in_ch_list[idx],
config_stage_ch[idx],
config_concat_ch[idx],
block_per_stage[idx],
layer_per_block,
idx + 2,
_STAGE_SPECS[model_name].dw,
),
)
self._out_feature_channels.append(config_concat_ch[idx])
# Add the classification head.
self.model.add_module(
"classifier",
torch.nn.Sequential(
                torch.nn.BatchNorm2d(
                    self._out_feature_channels[-1], eps=_BN_EPS, momentum=_BN_MOMENTUM
                ),
torch.nn.AdaptiveAvgPool2d(1),
torch.nn.Flatten(),
torch.nn.Dropout(0.2),
torch.nn.Linear(self._out_feature_channels[-1], num_classes, bias=True),
),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def forward_pyramids(self, x: torch.Tensor) -> collections.OrderedDict:
        """ Return the output of the stem and of each OSA stage, keyed by pyramid
        level (1 for the stem, 2-5 for the OSA stages).
        Args:
            x: The input tensor.
Usage:
>>> net = VoVNet("vovnet-19-slim-dw", num_classes=1000)
>>> net.delete_classification_head()
>>> with torch.no_grad():
... out = net.forward_pyramids(torch.randn(1, 3, 512, 512))
>>> [level.shape[-1] for level in out.values()] # Check the height/widths of levels
[256, 128, 64, 32, 16]
>>> [level.shape[1] for level in out.values()] == net._out_feature_channels
True
"""
levels = collections.OrderedDict()
levels[1] = self.model.stem(x)
levels[2] = self.model.OSA_2(levels[1])
levels[3] = self.model.OSA_3(levels[2])
levels[4] = self.model.OSA_4(levels[3])
levels[5] = self.model.OSA_5(levels[4])
        return levels

    def delete_classification_head(self) -> None:
""" Call this before using model as an object detection backbone. """
        del self.model.classifier

    def get_pyramid_channels(self) -> List[int]:
        """ Return the number of channels for each pyramid level. """
        return self._out_feature_channels
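

if __name__ == "__main__":
    # Minimal smoke-test sketch: instantiate each configured variant and run a
    # single forward pass on a random input to confirm the classifier output
    # shape. The 64x64 input is arbitrary; any size that survives the /32
    # downsampling works since the head uses adaptive average pooling.
    for name in _STAGE_SPECS:
        net = VoVNet(name, num_classes=10)
        net.eval()
        with torch.no_grad():
            out = net(torch.randn(1, 3, 64, 64))
        assert out.shape == (1, 10), f"{name}: unexpected output shape {out.shape}"
        print(f"{name}: output shape {tuple(out.shape)}")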