-
Notifications
You must be signed in to change notification settings - Fork 24
/
run_resnet_demo.py
executable file
·182 lines (138 loc) · 5.79 KB
/
run_resnet_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import torch
import torchvision
from torchvision import models
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import os
import os.path as osp
import yaml
import numpy as np
import PIL.Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import utils
# Output folder: the directory that contains this script.
here = osp.dirname(osp.abspath(__file__))
# -----------------------------------------------------------------------------
# 0. User-defined settings
# -----------------------------------------------------------------------------
gpu = 0  # index of the GPU exposed to CUDA (gpu:0 by default)
num_class = 8277  # number of identities in UMD-Faces = size of the final layer
# Checkpoint of the trained ResNet-50 network:
model_path = ('./umd-face/logs/'
              'MODEL-resnet_umdfaces_CFG-006_TIME-20180114-141943/'
              'model_best.pth.tar')
# -----------------------------------------------------------------------------
# 1. GPU setup
# -----------------------------------------------------------------------------
# Expose only the selected GPU to CUDA. This is set before the first CUDA
# query below (torch.cuda.is_available) so that it takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
cuda = torch.cuda.is_available()
# Seed the CPU random number generator, and the GPU one when CUDA is present.
torch.manual_seed(1337)
if cuda:
    torch.cuda.manual_seed(1337)
    # Enable cuDNN and let it benchmark/select convolution algorithms.
    torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True
# -----------------------------------------------------------------------------
# 2. Data preparation
# -----------------------------------------------------------------------------
# Samples images taken for demo purpose from LFW:
# http://vis-www.cs.umass.edu/lfw/
data_root = './samples/verif'
file_path = [osp.join(data_root, 'Recep_Tayyip_Erdogan_0012.jpg'),
             osp.join(data_root, 'Recep_Tayyip_Erdogan_0015.jpg'),
             osp.join(data_root, 'Quincy_Jones_0001.jpg')]
# The first two files show the same person (matched pair); the third shows a
# different person (used for the mismatched pair further below).
image = [PIL.Image.open(f).convert('RGB') for f in file_path]
# Data transforms
# http://pytorch.org/docs/master/torchvision/transforms.html
# NOTE: these should be consistent with the training script val_loader
# Since LFW images (250x250) are not close-crops, we modify the cropping a bit.
# Per-channel normalization statistics; presumably the ImageNet values used at
# training time -- they must match the training script.
RGB_MEAN = [0.485, 0.456, 0.406]
RGB_STD = [0.229, 0.224, 0.225]
test_transform = transforms.Compose([
    transforms.CenterCrop(150),  # 150x150 center crop
    # Resize replaces the deprecated (and, in recent torchvision, removed)
    # transforms.Scale alias; same behavior, bilinear by default.
    transforms.Resize((224, 224)),  # resized to the network's required input size
    transforms.ToTensor(),  # PIL image -> C x H x W float tensor
    transforms.Normalize(mean=RGB_MEAN, std=RGB_STD),
])
# apply the transform
inputs = [test_transform(im) for im in image]
# -----------------------------------------------------------------------------
# 3. Model
# -----------------------------------------------------------------------------
# PyTorch ResNet model definition:
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
# ResNet docs:
# http://pytorch.org/docs/master/torchvision/models.html#id3
# Build the bare architecture. Every parameter and buffer is replaced by the
# checkpoint below (load_state_dict is strict by default), so fetching the
# ImageNet weights first (pretrained=True) would only cost a download.
model = torchvision.models.resnet50()
# Replace last layer (by default, resnet has 1000 output categories)
model.fc = torch.nn.Linear(2048, num_class)  # change to current dataset's classes
# Pre-trained PyTorch model loaded from a file. The map_location function keeps
# every stored tensor on the CPU, so a checkpoint saved from GPU(s) also loads
# on a CPU-only machine; the model is moved to the GPU below when available.
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']
if checkpoint['arch'] == 'DataParallel':
    # If we trained and saved our model using DataParallel, every key carries a
    # 'module.' prefix. Strip it instead of re-wrapping the network in
    # DataParallel, which would require all the training GPUs to be visible.
    prefix = 'module.'
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                  for k, v in state_dict.items()}
model.load_state_dict(state_dict)
if cuda:
    model = model.cuda()
# Convert the trained network into a "feature extractor"
# From https://github.com/meliketoy/fine-tuning.pytorch/blob/master/extract_features.py#L85
feature_map = list(model.children())
feature_map.pop()  # remove the final "class prediction" layer (model.fc)
extractor = nn.Sequential(*feature_map)  # create feature extractor
# Inference mode: ResNet's BatchNorm layers must use the running statistics
# loaded from the checkpoint. A freshly built model is in training mode, where
# each forward pass would normalize with the statistics of the current
# (single-image) batch and overwrite those running averages. The child modules
# are shared with `model`, so this switches the whole network.
extractor.eval()
# Inspect the structure - it is a nested list of various modules
print(extractor[-1])  # last layer of the model - avg-pool
print(extractor[-2][-1])  # second-last layer's last module - output is 2048-dim
# -----------------------------------------------------------------------------
# 4. Feature extraction
# -----------------------------------------------------------------------------
# - simple, one input sample at a time
# torch.no_grad() (PyTorch >= 0.4) replaces the deprecated
# Variable(x, volatile=True) idiom; `volatile` is ignored by modern PyTorch,
# so the old code would build an autograd graph for every forward pass.
features = []
with torch.no_grad():
    for x in inputs:
        if cuda:
            x = x.cuda()
        x = x.unsqueeze(0)  # add batch_dim=1 in the front: 1 x C x H x W
        feat = extractor(x).view(-1)  # extract features of input `x`, reshape to 1-D vector
        features.append(feat)
features = torch.stack(features)  # N x 2048 for N inputs
# Bring the features to the CPU for the distance computation and plotting
# below; .cpu() returns the same tensor when it already lives on the CPU.
features = features.cpu()
features = F.normalize(features, p=2, dim=1)  # L2-normalize each row
# -----------------------------------------------------------------------------
# 5. Face verification
# -----------------------------------------------------------------------------
# L2-distance between features (Tensors) of same and different pairs.
# float() yields a plain Python number on every PyTorch version (newer
# versions return a 0-dim tensor from .norm()).
d1 = float((features[0] - features[1]).norm(p=2))  # same pair
d2 = float((features[0] - features[2]).norm(p=2))  # different pair
print('matched pair: %.2f' % d1)
print('mismatched pair: %.2f' % d2)
# Demo sanity check. An explicit raise (instead of `assert`) keeps the check
# active when Python runs with -O.
if not d1 < d2:
    raise AssertionError('expected the matched pair to be closer than the '
                         'mismatched pair (d1=%.3f, d2=%.3f)' % (d1, d2))
# visualizations
# 2x2 grid: top row = matched pair, bottom row = mismatched pair; the right
# column is titled with the feature distance of its row's pair.
# Draw into the axes returned by plt.subplots() (instead of re-selecting them
# with plt.subplot()), and lay the figure out once after all panels are filled.
out_file = osp.join(here, 'demo_verif.png')
fig, ax = plt.subplots(nrows=2, ncols=2)
panels = [
    (ax[0, 0], image[0], 'matched pair'),
    (ax[0, 1], image[1], 'd = %.3f' % d1),
    (ax[1, 0], image[0], 'mismatched pair'),
    (ax[1, 1], image[2], 'd = %.3f' % d2),
]
for axis, img, title in panels:
    axis.imshow(img)
    axis.set_title(title)
fig.tight_layout()
fig.savefig(out_file, bbox_inches='tight')
plt.close(fig)  # release the figure's memory once it has been written
print('Visualization saved in ' + out_file)