# model_loader.py
import os

import torch
import torch.nn as nn
import torchvision.models as models


def load_model(arch, code_length):
    """
    Load a pre-trained CNN model and wrap it with a hashing layer.

    Args:
        arch (str): Model architecture name ('alexnet' or 'vgg16').
        code_length (int): Hash code length.

    Returns:
        model (torch.nn.Module): Wrapped CNN model.
    """
    if arch == 'alexnet':
        model = models.alexnet(pretrained=True)
        model.classifier = model.classifier[:-2]
        model = ModelWrapper(model, 4096, code_length)
    elif arch == 'vgg16':
        model = models.vgg16(pretrained=True)
        model.classifier = model.classifier[:-3]
        model = ModelWrapper(model, 4096, code_length)
    else:
        raise ValueError('Invalid model name: {}'.format(arch))

    return model
class ModelWrapper(nn.Module):
    """
    Wrap a CNN backbone and append a tanh-activated hash layer.

    Args:
        model (torch.nn.Module): CNN backbone.
        last_node (int): Output size of the backbone's last layer.
        code_length (int): Hash code length.
    """
    def __init__(self, model, last_node, code_length):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.code_length = code_length
        self.hash_layer = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(last_node, code_length),
            nn.Tanh(),
        )

        # When True, forward() returns backbone features instead of hash codes
        self.extract_features = False

    def forward(self, x):
        if self.extract_features:
            return self.model(x)
        else:
            return self.hash_layer(self.model(x))

    def set_extract_features(self, flag):
        """
        Toggle feature-extraction mode.

        Args:
            flag (bool): True if backbone features should be returned instead of hash codes.
        """
        self.extract_features = flag
    def snapshot(self, it, optimizer):
        """
        Save a model snapshot.

        Args:
            it (int): Iteration.
            optimizer (torch.optim.Optimizer): Optimizer.

        Returns:
            None
        """
        # Ensure the checkpoint directory exists before saving
        os.makedirs('checkpoints', exist_ok=True)
        torch.save({
            'iteration': it,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, os.path.join('checkpoints', 'resume_{}.t'.format(it)))
    def load_snapshot(self, root, optimizer=None):
        """
        Load a model snapshot.

        Args:
            root (str): Path to the model snapshot.
            optimizer (torch.optim.Optimizer, optional): Optimizer to restore.

        Returns:
            (optimizer, it): Restored optimizer and iteration, only if 'optimizer' is given.
        """
        checkpoint = torch.load(root)
        self.load_state_dict(checkpoint['model_state_dict'])
        if optimizer:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            it = checkpoint['iteration']
            return optimizer, it
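
# Illustrative usage sketch (assumption, not part of the original module):
# it shows loading a wrapped backbone, computing hash codes, switching to
# feature-extraction mode, and saving/restoring a snapshot. The 'alexnet'
# architecture, 12-bit code length, and dummy input batch are arbitrary
# example choices.
if __name__ == '__main__':
    import torch.optim as optim

    model = load_model('alexnet', 12)
    optimizer = optim.SGD(model.parameters(), lr=1e-3)

    x = torch.randn(2, 3, 224, 224)   # dummy image batch
    codes = model(x)                  # tanh outputs in [-1, 1], shape (2, 12)

    model.set_extract_features(True)  # return 4096-d backbone features instead
    features = model(x)
    model.set_extract_features(False)

    model.snapshot(0, optimizer)
    optimizer, it = model.load_snapshot(os.path.join('checkpoints', 'resume_0.t'), optimizer)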