@@ -11,6 +11,7 @@
 from sklearn.model_selection import KFold, StratifiedKFold
 from model.densenet import *
 from model.resnet import *
+from model.senet import *
 from core.mixup import Mixup, OneHotCrossEntropy
 from core.snap_scheduler import SnapScheduler
 from tqdm import tqdm
@@ -32,12 +33,17 @@
     'densenet169': densenet169,
     'densenet201': densenet201,
     'densenet161': densenet161,
+    'senet18': se_resnet18,
+    'senet34': se_resnet34,
+    'senet50': se_resnet50,
+    'senet101': se_resnet101,
+    'senet152': se_resnet152,
 }
 
 class Experiment(object):
     def __init__(self, model: str, batch_size: int, epochs: int, lr: float, eval_interval: int = 1,
                  optimizer: str = 'sgd', schedule: str = None, step_size: int = 10, gamma: float = 0.5, use_mixup: bool = True,
-                 mixup_alpha: float = 0.5, conv_fixed: bool = False, weighted: bool = False, cross_validate: bool = False,
+                 mixup_alpha: float = 0.5, weighted: bool = False, cross_validate: bool = False,
                  n_splits: int = 5, seed: int = 42, metric: str = 'accuracy', no_snaps: bool = False, debug_limit: int = None,
                  device: str = ('cuda' if torch.cuda.is_available() else 'cpu'), num_processes: int = 8, multi_gpu: bool = False, **kwargs):
         self.set_seed(seed)
@@ -52,7 +58,6 @@ def __init__(self, model: str, batch_size: int, epochs: int, lr: float, eval_int
         self.gamma = gamma
         self.optimizer_str = optimizer
         self.use_mixup = use_mixup
-        self.conv_fixed = conv_fixed
         self.weighted = weighted
         self.cross_validate = cross_validate
         self.n_splits = n_splits
@@ -99,15 +104,9 @@ def __init__(self, model: str, batch_size: int, epochs: int, lr: float, eval_int
         self.model = self.load_model()
 
         if optimizer == 'sgd':
-            if self.conv_fixed:
-                self.optimizer = optim.SGD(self.model.fc.parameters(), lr=self.lr, momentum=0.9)
-            else:
-                self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
+            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
         elif optimizer == 'adam':
-            if self.conv_fixed:
-                self.optimizer = optim.Adam(self.model.fc.parameters(), lr=self.lr, amsgrad=False)
-            else:
-                self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False)
+            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False)
 
         if self.schedule is not None:
             if self.schedule.lower() == 'step':
@@ -153,21 +152,19 @@ def get_loaders(self, num_workers=8):
                 'test': thd.DataLoader(self.testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_processes)}
 
     def load_model(self):
-        model = pretrained_models[self.model_str](pretrained=True)
-        if self.conv_fixed:
-            logger.warning("Fixing weights")
-            for param in model.parameters():
-                param.requires_grad = False
-
         classifier = lambda num_features: nn.Linear(num_features, self.num_classes)
 
         if self.model_str.startswith('densenet'):
+            model = pretrained_models[self.model_str](pretrained=True)
             num_ftrs = model.classifier.in_features
             model.classifier = classifier(num_ftrs)
         elif self.model_str.startswith('resnet'):
+            model = pretrained_models[self.model_str](pretrained=True)
             num_ftrs = model.fc.in_features
             model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
             model.fc = classifier(num_ftrs)
+        elif self.model_str.startswith('senet'):
+            model = pretrained_models[self.model_str](num_classes=self.num_classes)
         else:
             raise ValueError(f'Invalid model string. Received {self.model_str}.')
 
@@ -303,15 +300,9 @@ def split_run(self):
             self.model = self.load_model()
 
             if self.optimizer_str == 'sgd':
-                if self.conv_fixed:
-                    self.optimizer = optim.SGD(self.model.fc.parameters(), lr=self.lr, momentum=0.9)
-                else:
-                    self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
+                self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
             elif self.optimizer_str == 'adam':
-                if self.conv_fixed:
-                    self.optimizer = optim.Adam(self.model.fc.parameters(), lr=self.lr, amsgrad=False)
-                else:
-                    self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False)
+                self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False)
 
             self.single_run(run_fname=f'run-{split_num}')
 
@@ -346,16 +337,14 @@ def run(self):
    parser.add_argument('--gamma', type=float, default=0.5, help='Gamma argument for scheduler (only applies to step and exponential).')
    # Prevent from using mixup
    parser.add_argument('--no_mixup', action='store_true', help='Flag whether to use mixup.')
-    # Fix weights of convolutional layers
-    parser.add_argument('--conv_fixed', action='store_true', help='Flag whether to fix weights of convolutional layers.')
-    # Weight classes to tackle inbalance
-    parser.add_argument('-w', '--weighted', action='store_true', help='Flag whether to weight classes.')
    # Use cross validation
    parser.add_argument('-cv', '--cross_validate', action='store_true', help='Flag whether to use cross validation.')
    # Alpha parameter for Mixup's Beta distribution
    parser.add_argument('-alpha', '--mixup_alpha', type=float, default=0.8, help="Alpha parameter for Mixup's Beta distribution.")
    # Prevent from storing snapshots
    parser.add_argument('--no_snaps', action='store_true', help='Flag whether to prevent from storing snapshots.')
+    # Evaluation interval
+    parser.add_argument('--eval_interval', type=int, default=1, help='How often to run evaluation.')
    # Debug limit to decrease size of dataset
    parser.add_argument('--debug_limit', type=int, default=None, help='Debug limit to decrease size of dataset.')
    # Seed
@@ -373,7 +362,7 @@ def run(self):
    if args.gpu_device is not None:
        torch.cuda.set_device(args.gpu_device)
 
-    exp = Experiment(args.model, args.batch_size, args.epochs, args.learning_rate, use_mixup=(not args.no_mixup),
-                     mixup_alpha=args.mixup_alpha, conv_fixed=args.conv_fixed, weighted=args.weighted, cross_validate=args.cross_validate, schedule=args.scheduler,
+    exp = Experiment(args.model, args.batch_size, args.epochs, args.learning_rate, eval_interval=args.eval_interval, use_mixup=(not args.no_mixup),
+                     mixup_alpha=args.mixup_alpha, cross_validate=args.cross_validate, schedule=args.scheduler,
                      seed=args.seed, no_snaps=args.no_snaps, debug_limit=args.debug_limit, num_processes=args.num_workers, multi_gpu=args.multi_gpu)
    exp.run()
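
For context, a minimal sketch of how the updated Experiment class could be driven after this change. The model key 'senet50' and the eval_interval argument come straight from the diff above; the module path experiment and the remaining values are illustrative assumptions, not part of the commit:

    import torch
    from experiment import Experiment  # assumed module path for the file changed in this diff

    # 'senet50' resolves to se_resnet50 through the pretrained_models mapping added above;
    # eval_interval=5 runs evaluation every fifth epoch via the new constructor/CLI option.
    exp = Experiment('senet50', batch_size=32, epochs=50, lr=0.01,
                     eval_interval=5, use_mixup=True, mixup_alpha=0.8)
    exp.run()

Note that the new senet branch of load_model() passes num_classes to the constructor directly instead of swapping in a new classifier head, presumably because the se_resnet* factories in model/senet.py build their own head rather than loading a torchvision pretrained one.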