-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
executable file
·101 lines (87 loc) · 3.35 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import io
import random
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms as trans

from data.data_pipe import de_preprocess
# Select matplotlib's "agg" backend at import time; presumably so the plotting
# helpers in this module can render to in-memory buffers without a display.
plt.switch_backend('agg')
def separate_bn_paras(modules):
    """Split layer parameters into batch-norm and non-batch-norm lists.

    Args:
        modules: Either a list of layer objects, or an object exposing a
            ``modules()`` method whose results are used as the layer list.

    Returns:
        A ``(paras_only_bn, paras_wo_bn)`` tuple of lists. Layers are
        classified by the text of ``str(layer.__class__)``: a layer whose
        class string contains ``'batchnorm'`` contributes to the first list,
        any other layer to the second.
    """
    layers = modules if isinstance(modules, list) else list(modules.modules())

    bn_params = []
    other_params = []
    for layer in layers:
        class_text = str(layer.__class__)
        # Layers whose class string mentions 'model' or 'container' are
        # skipped entirely (presumably wrappers whose children are visited
        # separately -- confirm against the model definitions).
        if 'model' in class_text or 'container' in class_text:
            continue
        bucket = bn_params if 'batchnorm' in class_text else other_params
        bucket.extend(layer.parameters())
    return bn_params, other_params
def hflip_batch(imgs_tensor):
    """Return a horizontally flipped copy of every image in a batch.

    Each element obtained by iterating ``imgs_tensor`` is passed through
    ``de_preprocess`` (imported from ``data.data_pipe``; presumably it undoes
    the input normalization -- confirm), converted to a PIL image, flipped
    horizontally, converted back to a tensor and normalized with mean 0.5 and
    std 0.5 on three channels.

    Args:
        imgs_tensor: Batch tensor; iterated along its first dimension.

    Returns:
        A new tensor created with ``torch.empty_like(imgs_tensor)`` and
        filled with the flipped images.
    """
    flip_pipeline = trans.Compose([
        de_preprocess,
        trans.ToPILImage(),
        trans.functional.hflip,
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    flipped = torch.empty_like(imgs_tensor)
    for idx, single_img in enumerate(imgs_tensor):
        flipped[idx] = flip_pipeline(single_img)
    return flipped
def get_time():
    """Return the current local time as a ``YYYY-MM-DD-HH-MM`` string.

    The timestamp is truncated to the minute and uses only digits and
    dashes.

    Returns:
        str: For example ``'2024-01-02-03-04'``.
    """
    # Format explicitly instead of slicing str(datetime.now()): the string
    # form omits the fractional part when microsecond == 0, so a fixed-width
    # slice would then cut into the date itself.
    return datetime.now().strftime('%Y-%m-%d-%H-%M')
def gen_plot(fpr, tpr):
    """Plot an ROC curve with pyplot and return it as an in-memory JPEG.

    Args:
        fpr: Values plotted on the x axis (labelled "FPR").
        tpr: Values plotted on the y axis (labelled "TPR").

    Returns:
        io.BytesIO: Buffer holding the saved JPEG, rewound to position 0.
    """
    fig = plt.figure()
    try:
        plt.xlabel("FPR", fontsize=14)
        plt.ylabel("TPR", fontsize=14)
        plt.title("ROC Curve", fontsize=14)
        plt.plot(fpr, tpr, linewidth=2)
        buf = io.BytesIO()
        plt.savefig(buf, format='jpeg')
    finally:
        # Always release this specific figure, even when plotting or saving
        # raises; otherwise pyplot keeps it registered.
        plt.close(fig)
    buf.seek(0)
    return buf
def cos_dist(x1, x2):
    """Return the cosine of the angle between two tensors.

    The element-wise product is summed over all elements and divided by the
    product of the two ``torch.norm`` values. Despite the function's name,
    the result is a similarity: 1 for parallel inputs, 0 for orthogonal ones.

    Args:
        x1: First tensor.
        x2: Second tensor; must be multiplicable element-wise with ``x1``.

    Returns:
        A zero-dimensional tensor holding the ratio.
    """
    dot_product = torch.sum(x1 * x2)
    norm_product = torch.norm(x1) * torch.norm(x2)
    return dot_product / norm_product
def fixed_img_list(lfw_pair_text):
    """Read all lines of a text file and return them in shuffled order.

    Args:
        lfw_pair_text: Path of the text file to read.

    Returns:
        list[str]: Every line of the file (line endings preserved), shuffled
        in place with the module-level ``random`` generator.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # The context manager closes the file even if reading raises.
    with open(lfw_pair_text, 'r') as f:
        lines = f.readlines()
    random.shuffle(lines)
    return lines
def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
    """Compute ROC metrics for embeddings stored as interleaved pairs.

    Args:
        embeddings: Sequence sliced with ``[0::2]`` and ``[1::2]``; elements
            at even positions are compared against those at odd positions.
        actual_issame: Per-pair labels; converted with ``np.asarray`` before
            being handed to ``calculate_roc``.
        nrof_folds: Forwarded to ``calculate_roc``.
        pca: Forwarded to ``calculate_roc``.

    Returns:
        The ``(tpr, fpr, accuracy, best_thresholds)`` tuple produced by
        ``calculate_roc``. Validation-rate metrics are not computed.
    """
    # Candidate thresholds from 0 up to (but excluding) 4 in steps of 0.01.
    thresholds = np.arange(0, 4, 0.01)
    embeddings1 = embeddings[0::2]
    embeddings2 = embeddings[1::2]
    # NOTE(review): calculate_roc is neither defined nor imported in this
    # file; it must be made available from elsewhere -- confirm.
    tpr, fpr, accuracy, best_thresholds = calculate_roc(
        thresholds,
        embeddings1,
        embeddings2,
        np.asarray(actual_issame),
        nrof_folds=nrof_folds,
        pca=pca,
    )
    return tpr, fpr, accuracy, best_thresholds
def evaluate_dist(embeddings, actual_issame, nrof_folds=10, pca=0):
    """Compute ROC metrics plus distances for interleaved embedding pairs.

    Args:
        embeddings: Sequence sliced with ``[0::2]`` and ``[1::2]``; elements
            at even positions are compared against those at odd positions.
        actual_issame: Per-pair labels; converted with ``np.asarray`` before
            being handed to ``calculate_roc_dist``.
        nrof_folds: Forwarded to ``calculate_roc_dist``.
        pca: Forwarded to ``calculate_roc_dist``.

    Returns:
        The ``(tpr, fpr, accuracy, best_thresholds, dist)`` tuple produced by
        ``calculate_roc_dist``. Validation-rate metrics are not computed.
    """
    # Candidate thresholds from 0 up to (but excluding) 4 in steps of 0.01.
    thresholds = np.arange(0, 4, 0.01)
    embeddings1 = embeddings[0::2]
    embeddings2 = embeddings[1::2]
    # NOTE(review): calculate_roc_dist is neither defined nor imported in
    # this file; it must be made available from elsewhere -- confirm.
    tpr, fpr, accuracy, best_thresholds, dist = calculate_roc_dist(
        thresholds,
        embeddings1,
        embeddings2,
        np.asarray(actual_issame),
        nrof_folds=nrof_folds,
        pca=pca,
    )
    return tpr, fpr, accuracy, best_thresholds, dist