Commit ecddb8b

Add SENet.

1 parent c19222c

2 files changed: +174 −30 lines

experiment.py
+19 −30

@@ -11,6 +11,7 @@
 from sklearn.model_selection import KFold, StratifiedKFold
 from model.densenet import *
 from model.resnet import *
+from model.senet import *
 from core.mixup import Mixup, OneHotCrossEntropy
 from core.snap_scheduler import SnapScheduler
 from tqdm import tqdm
@@ -32,12 +33,17 @@
     'densenet169': densenet169,
     'densenet201': densenet201,
     'densenet161': densenet161,
+    'senet18': se_resnet18,
+    'senet34': se_resnet34,
+    'senet50': se_resnet50,
+    'senet101': se_resnet101,
+    'senet152': se_resnet152,
 }
 
 class Experiment(object):
     def __init__(self, model: str, batch_size: int, epochs: int, lr: float, eval_interval: int=1,
                  optimizer: str='sgd', schedule: str=None, step_size: int=10, gamma: float=0.5, use_mixup: bool=True,
-                 mixup_alpha: float=0.5, conv_fixed: bool=False, weighted: bool=False, cross_validate: bool=False,
+                 mixup_alpha: float=0.5, weighted: bool=False, cross_validate: bool=False,
                  n_splits: int=5, seed: int=42, metric: str='accuracy', no_snaps: bool=False, debug_limit: int=None,
                  device: str=('cuda' if torch.cuda.is_available() else 'cpu'), num_processes: int=8, multi_gpu: bool=False, **kwargs):
         self.set_seed(seed)
@@ -52,7 +58,6 @@ def __init__(self, model: str, batch_size: int, epochs: int, lr: float, eval_interval: int=1,
         self.gamma = gamma
         self.optimizer_str = optimizer
         self.use_mixup = use_mixup
-        self.conv_fixed = conv_fixed
         self.weighted = weighted
         self.cross_validate = cross_validate
         self.n_splits = n_splits
@@ -99,15 +104,9 @@ def __init__(self, model: str, batch_size: int, epochs: int, lr: float, eval_interval: int=1,
         self.model = self.load_model()
 
         if optimizer == 'sgd':
-            if self.conv_fixed:
-                self.optimizer = optim.SGD(self.model.fc.parameters(), lr=self.lr, momentum=0.9)
-            else:
-                self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
+            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
         elif optimizer == 'adam':
-            if self.conv_fixed:
-                self.optimizer = optim.Adam(self.model.fc.parameters(), lr=self.lr, amsgrad=False)
-            else:
-                self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False)
+            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False)
 
         if self.schedule is not None:
             if self.schedule.lower() == 'step':
@@ -153,21 +152,19 @@ def get_loaders(self, num_workers=8):
                 'test': thd.DataLoader(self.testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_processes)}
 
     def load_model(self):
-        model = pretrained_models[self.model_str](pretrained=True)
-        if self.conv_fixed:
-            logger.warning("Fixing weights")
-            for param in model.parameters():
-                param.requires_grad = False
-
         classifier = lambda num_features: nn.Linear(num_features, self.num_classes)
 
         if self.model_str.startswith('densenet'):
+            model = pretrained_models[self.model_str](pretrained=True)
             num_ftrs = model.classifier.in_features
             model.classifier = classifier(num_ftrs)
         elif self.model_str.startswith('resnet'):
+            model = pretrained_models[self.model_str](pretrained=True)
             num_ftrs = model.fc.in_features
             model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
             model.fc = classifier(num_ftrs)
+        elif self.model_str.startswith('senet'):
+            model = pretrained_models[self.model_str](num_classes=self.num_classes)
         else:
             raise ValueError(f'Invalid model string. Received {self.model_str}.')
 
@@ -303,15 +300,9 @@ def split_run(self):
             self.model = self.load_model()
 
             if self.optimizer_str == 'sgd':
-                if self.conv_fixed:
-                    self.optimizer = optim.SGD(self.model.fc.parameters(), lr=self.lr, momentum=0.9)
-                else:
-                    self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
+                self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
             elif self.optimizer_str == 'adam':
-                if self.conv_fixed:
-                    self.optimizer = optim.Adam(self.model.fc.parameters(), lr=self.lr, amsgrad=False)
-                else:
-                    self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False)
+                self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False)
 
             self.single_run(run_fname=f'run-{split_num}')
 
@@ -346,16 +337,14 @@ def run(self):
     parser.add_argument('--gamma', type=float, default=0.5, help='Gamma argument for scheduler (only applies to step and exponential).')
     # Prevent from using mixup
     parser.add_argument('--no_mixup', action='store_true', help='Flag whether to use mixup.')
-    # Fix weights of convolutional layers
-    parser.add_argument('--conv_fixed', action='store_true', help='Flag whether to fix weights of convolutional layers.')
-    # Weight classes to tackle inbalance
-    parser.add_argument('-w', '--weighted', action='store_true', help='Flag whether to weight classes.')
     # Use cross validation
     parser.add_argument('-cv', '--cross_validate', action='store_true', help='Flag whether to use cross validation.')
     # Alpha parameter for Mixup's Beta distribution
    parser.add_argument('-alpha', '--mixup_alpha', type=float, default=0.8, help="Alpha parameter for Mixup's Beta distribution.")
     # Prevent from storing snapshots
     parser.add_argument('--no_snaps', action='store_true', help='Flag whether to prevent from storing snapshots.')
+    # Evaluation interval
+    parser.add_argument('--eval_interval', type=int, default=1, help='How often to run evaluation.')
     # Debug limit to decrease size of dataset
     parser.add_argument('--debug_limit', type=int, default=None, help='Debug limit to decrease size of dataset.')
     # Seed
@@ -373,7 +362,7 @@ def run(self):
     if args.gpu_device is not None:
         torch.cuda.set_device(args.gpu_device)
 
-    exp = Experiment(args.model, args.batch_size, args.epochs, args.learning_rate, use_mixup=(not args.no_mixup),
-                     mixup_alpha=args.mixup_alpha, conv_fixed=args.conv_fixed, weighted=args.weighted, cross_validate=args.cross_validate, schedule=args.scheduler,
+    exp = Experiment(args.model, args.batch_size, args.epochs, args.learning_rate, eval_interval=args.eval_interval, use_mixup=(not args.no_mixup),
+                     mixup_alpha=args.mixup_alpha, cross_validate=args.cross_validate, schedule=args.scheduler,
                      seed=args.seed, no_snaps=args.no_snaps, debug_limit=args.debug_limit, num_processes=args.num_workers, multi_gpu=args.multi_gpu)
     exp.run()
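
With the registry entries above, an SE-ResNet is selected by name like any other backbone. Below is a hypothetical usage sketch of the updated Experiment class, assuming the argument-parsing block sits under an `if __name__ == '__main__':` guard so the class is importable; the keyword names follow the `__init__` signature in this diff, and the specific values are illustrative only:

# Hypothetical sketch (not part of the commit): train the new SE-ResNet-50
# backbone through the Experiment class. 'senet50' is the key registered in
# pretrained_models above; dataset handling happens inside Experiment.
from experiment import Experiment

exp = Experiment('senet50', batch_size=32, epochs=50, lr=0.01,
                 eval_interval=1,          # argument exposed by this commit
                 optimizer='sgd', use_mixup=True, mixup_alpha=0.8)
exp.run()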

model/senet.py
+155 (new file)

import torch.nn as nn

from model.resnet import ResNet


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average pooling reduces each channel to a scalar.
        y = self.avg_pool(x).view(b, c)
        # Excitation: a bottleneck MLP maps the descriptor to per-channel gates in (0, 1).
        y = self.fc(y).view(b, c, 1, 1)
        # Recalibrate: scale each input channel by its gate.
        return x * y


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
        super(SEBasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


def se_resnet18(num_classes):
    """Constructs an SE-ResNet-18 model.

    Args:
        num_classes (int): Number of output classes.
    """
    model = ResNet(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet34(num_classes):
    """Constructs an SE-ResNet-34 model.

    Args:
        num_classes (int): Number of output classes.
    """
    model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet50(num_classes):
    """Constructs an SE-ResNet-50 model.

    Args:
        num_classes (int): Number of output classes.
    """
    model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet101(num_classes):
    """Constructs an SE-ResNet-101 model.

    Args:
        num_classes (int): Number of output classes.
    """
    model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet152(num_classes):
    """Constructs an SE-ResNet-152 model.

    Args:
        num_classes (int): Number of output classes.
    """
    model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model
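
The SE layer implements channel attention: from the pooled descriptor z it computes gates s = sigmoid(W2 · ReLU(W1 · z)) and rescales the input as x ⊙ s, so tensor shapes are preserved end to end. A minimal smoke-test sketch, under the assumption that model.resnet.ResNet follows the standard torchvision constructor (block class, per-stage layer counts, num_classes):

# Sanity-check sketch (not part of the commit) for the new SE modules.
import torch
from model.senet import SELayer, se_resnet50

x = torch.randn(2, 64, 56, 56)
se = SELayer(channel=64, reduction=16)
assert se(x).shape == x.shape            # channel gating preserves shape

model = se_resnet50(num_classes=10)      # same factory the 'senet50' entry maps to
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)                      # expected: torch.Size([2, 10])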
