@@ -266,24 +266,25 @@ def collate_fn(data):
 
 optimizer = torch.optim.AdamW(model_params, lr=input_params.learning_rate, weight_decay=input_params.weight_decay) #define optimizer
 
+last_epoch = -1
+
 if input_params.load_weights:
 
     if torch.cuda.is_available():
         #load on gpu
         model.load_state_dict(torch.load(input_params.config_start_base + '_model'))
-        optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer'))
+        if not input_params.inference_mode:
+            optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer'))
     else:
         #load on cpu
         model.load_state_dict(torch.load(input_params.config_start_base + '_model', map_location=torch.device('cpu')))
-        optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer', map_location=torch.device('cpu')))
+        if not input_params.inference_mode:
+            optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer', map_location=torch.device('cpu')))
 
     last_epoch = int(input_params.config_start_base.split('_')[-2]) #infer previous epoch from input_params.config_start_base
 
-else:
-
-    last_epoch = -1
-
-lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
+if not input_params.inference_mode:
+    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[input_params.lr_sch_milestones],
                                                     gamma=input_params.lr_sch_gamma,
                                                     last_epoch=last_epoch, verbose=False) #define learning rate scheduler
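
In this patch, everything optimizer-related becomes conditional on `input_params.inference_mode`: the model weights are still restored, but the `_optimizer` checkpoint and the `MultiStepLR` scheduler are only set up when training will actually resume, so an inference run no longer needs an optimizer checkpoint to exist. Hoisting `last_epoch = -1` out of the old `else:` branch keeps the default defined on every path. Below is a minimal runnable sketch of the resulting control flow; the toy `Linear` model, the `Namespace` stand-in for `input_params`, the `/tmp` checkpoint paths, and the single `map_location` load (which collapses the patch's duplicated GPU/CPU branches) are all illustrative, not code from this repository.

import torch
from argparse import Namespace

# Toy stand-ins (illustrative only; the real script builds these elsewhere).
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

# Write fake checkpoints so the sketch runs end to end.
torch.save(model.state_dict(), '/tmp/chkpt_epoch_3__model')
torch.save(optimizer.state_dict(), '/tmp/chkpt_epoch_3__optimizer')

input_params = Namespace(load_weights=True,
                         inference_mode=True,
                         config_start_base='/tmp/chkpt_epoch_3_')

last_epoch = -1  # default, hoisted as in the patch

if input_params.load_weights:
    # One load path: map_location sends tensors to whichever device is
    # available, standing in for the script's explicit GPU/CPU branches.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.load_state_dict(torch.load(input_params.config_start_base + '_model',
                                     map_location=device))
    if not input_params.inference_mode:
        # Optimizer state is only needed when training resumes.
        optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer',
                                             map_location=device))
    last_epoch = int(input_params.config_start_base.split('_')[-2])  # -> 3

if not input_params.inference_mode:
    # The scheduler is likewise skipped entirely at inference time.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[10],
                                                        gamma=0.1,
                                                        last_epoch=last_epoch)

Loading with a precomputed `map_location` behaves the same as the script's explicit GPU/CPU branches, since `torch.load` accepts a device string; the per-branch `inference_mode` guards in the patch would merge into the single guard shown here.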