-
Notifications
You must be signed in to change notification settings - Fork 95
/
Copy pathmodel.py
executable file
·37 lines (31 loc) · 1.07 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from torch import nn
from models import vinet
def generate_model(opt):
    """Build the VINet model on GPU and optionally load pretrained weights.

    Args:
        opt: Options object. Reads ``opt.model`` (must be ``'vinet_final'``),
            ``opt.no_cuda`` (must be ``False``) and ``opt.pretrain_path``
            (checkpoint path; a falsy value skips weight loading). ``opt`` is
            also forwarded to ``vinet.VINet_final``.

    Returns:
        A tuple ``(model, parameters)``: the ``nn.DataParallel``-wrapped model
        and the iterator returned by ``model.parameters()``.

    Raises:
        ValueError: If ``opt.model`` is not ``'vinet_final'``.
        AssertionError: If ``opt.no_cuda`` is not ``False``.
    """
    # Validate options up front with explicit raises: a bare ``assert`` is
    # stripped under ``python -O``, and swallowing the failure would leave
    # ``model`` unbound and crash later with an unrelated UnboundLocalError.
    if opt.model != 'vinet_final':
        raise ValueError(
            'Model name should be: vinet_final (got {!r})'.format(opt.model)
        )
    if opt.no_cuda is not False:
        # AssertionError is kept so the exception type matches the original
        # ``assert(opt.no_cuda is False)`` check.
        raise AssertionError('generate_model requires opt.no_cuda to be False')

    # Errors raised while constructing the network now propagate unchanged
    # instead of being reported as a wrong model name.
    model = vinet.VINet_final(opt=opt)
    model = model.cuda()
    model = nn.DataParallel(model)

    if opt.pretrain_path:
        print('Loading pretrained model {}'.format(opt.pretrain_path))
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        pretrain = torch.load(opt.pretrain_path)
        pretrained_state = pretrain['state_dict']

        child_dict = model.state_dict()
        # Copy only the entries whose keys exist in both the model and the
        # checkpoint; every other model entry keeps its current value.
        parent_dict = {
            key: pretrained_state[key]
            for key in child_dict
            if key in pretrained_state
        }
        loaded = len(parent_dict)
        empty = len(child_dict) - loaded
        print('Loaded: %d/%d params' % (loaded, loaded + empty))

        child_dict.update(parent_dict)
        model.load_state_dict(child_dict)

    return model, model.parameters()