paramSave.py
import numpy as np
import os
import torch

# Allow duplicate OpenMP runtimes (workaround for libiomp conflicts between PyTorch and MKL)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from utils.datasets import *
from models import *


def saveParams(path, model, fName="weights.dat"):
    """Flatten all model parameters into one 1-D array and dump it to a raw binary file."""
    if not os.path.exists(path):
        os.makedirs(path)
    # float64 accumulator: np.empty(0) defaults to float64, so concatenation
    # promotes the float32 weights and the file is written as 8-byte doubles
    params = np.empty(0)
    Dict = model.state_dict()
    for name in Dict:
        param = Dict[name].numpy()
        # Skip BatchNorm bookkeeping tensors (num_batches_tracked)
        if "num_batches" in name:
            continue
        param = param.reshape(param.size)
        params = np.concatenate((params, param))
    params.tofile(path + "/" + fName)


if __name__ == "__main__":
    # Export each finetuned checkpoint as a flat binary weight file
    path = "checkpoints/bestFinetuneHR93_34.weights"
    model = ROBO(bn=False, inch=3, halfRes=True)
    model.load_state_dict(torch.load(path, map_location={'cuda:0': 'cpu'}))
    saveParams("checkpoints/", model, fName="weightsHR.dat")

    path = "checkpoints/bestFinetune2C93_43.weights"
    model = ROBO(bn=False, inch=2, halfRes=False)
    model.load_state_dict(torch.load(path, map_location={'cuda:0': 'cpu'}))
    saveParams("checkpoints/", model, fName="weights2C.dat")

    path = "checkpoints/bestFinetune2CHR93_32.weights"
    model = ROBO(bn=False, inch=2, halfRes=True)
    model.load_state_dict(torch.load(path, map_location={'cuda:0': 'cpu'}))
    saveParams("checkpoints/", model, fName="weights2CHR.dat")

    path = "checkpoints/bestFinetuneBN97_79.weights"
    model = ROBO(bn=True, inch=3, halfRes=False)
    model.load_state_dict(torch.load(path, map_location={'cuda:0': 'cpu'}))
    saveParams("checkpoints/", model, fName="weightsBN.dat")

    path = "checkpoints/bestFinetune93_41.weights"
    model = ROBO(bn=False, inch=3, halfRes=False)
    model.load_state_dict(torch.load(path, map_location={'cuda:0': 'cpu'}))
    saveParams("checkpoints/", model, fName="weights.dat")
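Note: the .dat files produced above are raw dumps from numpy.tofile(), which stores no shape or dtype metadata. A minimal read-back sketch follows, assuming the consumer knows the layout; the loadParams name is illustrative and not part of this repository, and dtype=np.float64 follows from the float64 accumulator used in saveParams.

    import numpy as np

    def loadParams(fName="checkpoints/weights.dat"):
        # tofile() stores bare values, so dtype and per-layer shapes must be
        # supplied by the reader; this returns one long parameter vector.
        return np.fromfile(fName, dtype=np.float64)

    flat = loadParams()
    print(flat.shape)  # reshape per-layer slices according to the model definition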