Skip to content

Commit 9ec74d2

Browse files
author
sergey.vilov
committed
not load optimizer for inference
1 parent da2239a commit 9ec74d2

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

NNC/nn.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -266,24 +266,25 @@ def collate_fn(data):
266266

267267
optimizer = torch.optim.AdamW(model_params, lr=input_params.learning_rate, weight_decay=input_params.weight_decay) #define optimizer
268268

269+
last_epoch = -1
270+
269271
if input_params.load_weights:
270272

271273
if torch.cuda.is_available():
272274
#load on gpu
273275
model.load_state_dict(torch.load(input_params.config_start_base + '_model'))
274-
optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer'))
276+
if not input_params.inference_mode:
277+
optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer'))
275278
else:
276279
#load on cpu
277280
model.load_state_dict(torch.load(input_params.config_start_base + '_model', map_location=torch.device('cpu')))
278-
optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer', map_location=torch.device('cpu')))
281+
if not input_params.inference_mode:
282+
optimizer.load_state_dict(torch.load(input_params.config_start_base + '_optimizer', map_location=torch.device('cpu')))
279283

280284
last_epoch = int(input_params.config_start_base.split('_')[-2]) #infer previous epoch from input_params.config_start_base
281285

282-
else:
283-
284-
last_epoch = -1
285-
286-
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
286+
if not input_params.inference_mode:
287+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
287288
milestones=[input_params.lr_sch_milestones],
288289
gamma=input_params.lr_sch_gamma,
289290
last_epoch=last_epoch, verbose=False) #define learning rate scheduler

0 commit comments

Comments
 (0)