-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathmodels.py
208 lines (173 loc) · 8.3 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import torch
from torch.autograd import Variable
import torch.nn as nn
import prunable_nn as pnn
import torch.utils.model_zoo as model_zoo
from torchvision import models
from operator import itemgetter
class VGG(models.VGG):
    """torchvision ``VGG`` with helpers for pruning feature maps one at a time.

    ``prune`` expects ``self.features`` to contain ``pnn.PConv2d`` modules,
    each directly followed by a module exposing ``drop_input_channel``
    (the batch-norm layer), and ``self.classifier[0]`` to expose
    ``drop_inputs``.
    """

    def __init__(self, features, num_classes=1000):
        """Forward ``features`` and ``num_classes`` to ``models.VGG`` unchanged."""
        super().__init__(features, num_classes)

    def pruning(self, flag):
        """Call ``pruning(flag)`` on every prunable module of ``self.features``.

        A module counts as prunable when it has a ``prune_feature_map``
        attribute and more than one output channel.
        """
        prunable = [module for module in self.features
                    if getattr(module, "prune_feature_map", False) and module.out_channels > 1]
        for p in prunable:
            p.pruning(flag)

    def prune(self):
        """Remove the feature map with the smallest Taylor estimate.

        The estimates of every ``pnn.PConv2d`` in ``self.features`` that still
        has more than one output channel are compared; the map holding the
        minimum is pruned, and the layers consuming that map are adjusted:
        the module right after the convolution, then either the next
        ``pnn.PConv2d`` or, when there is none, the first classifier layer.

        Raises:
            ValueError: if no convolution with more than one output channel
                is left, i.e. there is nothing to prune.
        """
        feature_list = list(enumerate(self.features))

        # pair each prunable conv's estimates with its index in self.features
        taylor_estimates_by_module = [
            (module.taylor_estimates, module_idx)
            for module_idx, module in feature_list
            if isinstance(module, pnn.PConv2d) and module.out_channels > 1
        ]
        # flatten to one (estimate, map index, module index) entry per feature map
        taylor_estimates_by_feature_map = [
            (estimate, map_idx, module_idx)
            for estimates_by_map, module_idx in taylor_estimates_by_module
            for map_idx, estimate in enumerate(estimates_by_map)
        ]
        if not taylor_estimates_by_feature_map:
            raise ValueError("no prunable feature maps remain in self.features")

        _, min_map_idx, min_module_idx = min(taylor_estimates_by_feature_map, key=itemgetter(0))

        p_conv2d = self.features[min_module_idx]
        p_conv2d.prune_feature_map(min_map_idx)

        # the module directly after the conv consumes its output channels
        p_batchnorm = self.features[min_module_idx + 1]
        p_batchnorm.drop_input_channel(min_map_idx)

        # Locate the next convolution by scanning forward instead of relying on
        # fixed positions, so any conv/batchnorm/relu[/maxpool] layout works.
        next_conv_idx = next(
            (module_idx for module_idx, module in feature_list[min_module_idx + 2:]
             if isinstance(module, pnn.PConv2d)),
            None,
        )
        if next_conv_idx is None:
            # last convolution: its output feeds the classifier
            first_p_linear = self.classifier[0]
            shape = (first_p_linear.in_features // 49, 7, 7)  # the input is always ?x7x7
            first_p_linear.drop_inputs(shape, min_map_idx)
        else:
            next_p_conv2d = self.features[next_conv_idx]
            next_p_conv2d.drop_input_channel(min_map_idx)
def vgg_model(num_classes):
    """Build the batch-normalised VGG variant and give it a fresh classifier.

    The feature stack is created from configuration ``'A'``, weights are
    fetched from ``model_url`` and loaded non-strictly, and the classifier is
    then replaced by a new three-layer head ending in ``num_classes`` outputs.

    Returns:
        A ``(model, name)`` tuple where ``name`` is ``'vgg11_bn'``.
    """
    def make_layers(cfg, batch_norm=False):
        """Translate a list of channel widths / ``'M'`` markers into a Sequential."""
        stack = []
        previous_width = 3
        for entry in cfg:
            if entry == 'M':
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stack.append(nn.Conv2d(previous_width, entry, kernel_size=3, padding=1))
            if batch_norm:
                stack.append(nn.BatchNorm2d(entry))
            stack.append(nn.ReLU(inplace=True))
            previous_width = entry
        return nn.Sequential(*stack)

    cfg = {'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']}
    model_url = 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth'

    feature_stack = make_layers(cfg['A'], batch_norm=True)
    model = VGG(feature_stack)
    downloaded_state = model_zoo.load_url(model_url)
    model.load_state_dict(downloaded_state, strict=False)

    # swap in a new head sized for the requested number of classes
    head = [
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, num_classes),
    ]
    model.classifier = nn.Sequential(*head)
    return model, 'vgg11_bn'
def chinese_model(num_classes):
    """Return a ``ChineseNet`` with its default config and the tag ``'chinese_net'``."""
    network = ChineseNet(num_classes)
    return network, 'chinese_net'
def chinese_pruned_80(num_classes):
    """Return a ``ChineseNet`` with a reduced channel layout, tagged ``'chinese_net_80'``.

    NOTE(review): the "80" presumably refers to the pruning ratio that
    produced these channel counts -- confirm against the training scripts.
    """
    channel_plan = [26, 'M', 39, 'M', 52, 'M', 75, 93, 'M', 88, 95, 'M']
    network = ChineseNet(num_classes, channel_plan)
    return network, 'chinese_net_80'
def chinese_pruned_90(num_classes):
    """Return a ``ChineseNet`` with a reduced channel layout, tagged ``'chinese_net_90'``.

    NOTE(review): the "90" presumably refers to the pruning ratio that
    produced these channel counts -- confirm against the training scripts.
    """
    channel_plan = [15, 'M', 14, 'M', 20, 'M', 27, 31, 'M', 28, 30, 'M']
    network = ChineseNet(num_classes, channel_plan)
    return network, 'chinese_net_90'
class ChineseNet(nn.Module):
    """Prunable convolutional classifier, used for chinese ocr.

    Inspired by https://arxiv.org/abs/1702.07975.

    ``config`` is a list of output-channel counts interleaved with ``'M'``
    markers: each number adds a ``PConv2d`` + ``PBatchNorm2d`` + ``PReLU``
    group and each ``'M'`` adds a max-pool layer.
    """

    def __init__(self, num_classes, config=None):
        """Build the feature stack from ``config`` and a classifier for ``num_classes``.

        Args:
            num_classes: number of outputs of the final linear layer.
            config: layer layout; when ``None`` the default layout is used.
        """
        super(ChineseNet, self).__init__()
        self.config = [96, 'M', 128, 'M', 160, 'M', 256, 256, 'M', 384, 384, 'M'] if config is None else config
        self.features = self.make_layers()
        self.classifier = nn.Sequential(
            # input is 96x96, output from features section should always be 2x2
            pnn.PLinear(self.config[-2]*2*2, 1024),
            nn.BatchNorm1d(1024),
            nn.PReLU(),
            nn.Dropout(),
            nn.Linear(1024, num_classes)
        )
        # when True, forward() expands the 1d batch norm by hand for export
        self.convert_to_onnx = False
        self.__pruning = False

    def pruning(self, flag):
        """Record ``flag`` and call ``pruning(flag)`` on every prunable feature module.

        A module counts as prunable when it has a ``prune_feature_map``
        attribute and more than one output channel.
        """
        self.__pruning = flag
        prunable = [module for module in self.features
                    if getattr(module, "prune_feature_map", False) and module.out_channels > 1]
        for p in prunable:
            p.pruning(flag)

    def make_layers(self):
        """Create the feature ``nn.Sequential`` described by ``self.config``."""
        layers = []
        in_channels = 1
        for v in self.config:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=3, stride=2)]
            else:
                conv2d = pnn.PConv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, pnn.PBatchNorm2d(v), nn.PReLU()]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the feature stack, flatten per sample, then classify.

        With ``convert_to_onnx`` set, the classifier's ``BatchNorm1d`` is
        replaced by the equivalent arithmetic on its running statistics.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)
        if not self.convert_to_onnx:
            return self.classifier(x)

        x = self.classifier[0](x)
        # manually perform 1d batchnorm, caffe2 currently requires a resize,
        # which is hard to squeeze into the exported network
        bn_1d = self.classifier[1]
        numerator = (x - Variable(bn_1d.running_mean))
        denominator = Variable(torch.sqrt(bn_1d.running_var + bn_1d.eps))
        x = numerator/denominator*Variable(bn_1d.weight.data) + Variable(bn_1d.bias.data)
        x = self.classifier[2](x)
        x = self.classifier[3](x)
        x = self.classifier[4](x)
        return x

    def prune(self):
        """Remove the feature map with the smallest Taylor estimate.

        The estimates of every prunable module in ``self.features`` that
        still has more than one output channel are compared; the map holding
        the minimum is pruned, and the layers consuming that map are
        adjusted: the module right after the convolution, then either the
        next prunable convolution or, when there is none, the first
        classifier layer.

        Raises:
            ValueError: if no prunable module with more than one output
                channel is left, i.e. there is nothing to prune.
        """
        feature_list = list(enumerate(self.features))

        # pair each prunable conv's estimates with its index in self.features
        taylor_estimates_by_module = [
            (module.taylor_estimates, module_idx)
            for module_idx, module in feature_list
            if getattr(module, "prune_feature_map", False) and module.out_channels > 1
        ]
        # flatten to one (estimate, map index, module index) entry per feature map
        taylor_estimates_by_feature_map = [
            (estimate, map_idx, module_idx)
            for estimates_by_map, module_idx in taylor_estimates_by_module
            for map_idx, estimate in enumerate(estimates_by_map)
        ]
        if not taylor_estimates_by_feature_map:
            raise ValueError("no prunable feature maps remain in self.features")

        _, min_map_idx, min_module_idx = min(taylor_estimates_by_feature_map, key=itemgetter(0))

        p_conv2d = self.features[min_module_idx]
        p_conv2d.prune_feature_map(min_map_idx)

        # the module directly after the conv consumes its output channels
        p_batchnorm = self.features[min_module_idx + 1]
        p_batchnorm.drop_input_channel(min_map_idx)

        # Locate the next convolution by scanning forward instead of relying on
        # fixed positions, so the lookup follows whatever layout config produced.
        next_conv_idx = next(
            (module_idx for module_idx, module in feature_list[min_module_idx + 2:]
             if getattr(module, "prune_feature_map", False)),
            None,
        )
        if next_conv_idx is None:
            # last convolution: its output feeds the classifier
            first_p_linear = self.classifier[0]
            shape = (first_p_linear.in_features // 4, 2, 2)  # the input is always ?x2x2
            first_p_linear.drop_inputs(shape, min_map_idx)
        else:
            next_p_conv2d = self.features[next_conv_idx]
            next_p_conv2d.drop_input_channel(min_map_idx)