diff --git a/train_DDP.py b/train_DDP.py index 105e410..c3d20d3 100755 --- a/train_DDP.py +++ b/train_DDP.py @@ -451,10 +451,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25): # load best model weights model.load_state_dict(last_model_wts) - if len(opt.gpu_ids)>1: - save_network(model.module, 'last') - else: - save_network(model, 'last') + save_network(model.module, opt.name, 'last') return model