utils.py
import os
import shutil
import torch
import numpy as np
from sklearn import metrics


def get_class(x):
    """Binarize a matrix of prediction scores: entries >= 0.5 map to 1, the rest to 0."""
    res = []
    for row in x:
        row_res = []
        for col in row:
            if col >= 0.5:
                row_res.append(1)
            else:
                row_res.append(0)
        res.append(row_res)
    return np.array(res)


def multilabel_score(y_true, y_pred):
    """Weighted element-wise accuracy for multi-label targets.

    A correctly predicted positive label counts 100, a correctly predicted
    negative label counts 1; the total is averaged over all entries.
    """
    score = 0.0
    for i in range(y_true.shape[0]):
        for j in range(y_true.shape[1]):
            if y_true[i][j] == y_pred[i][j]:
                if y_true[i][j] == 0:
                    score += 1
                if y_true[i][j] == 1:
                    score += 100
    score /= (y_true.shape[0] * y_true.shape[1])
    return score
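

# Illustrative sketch (not part of the original file): a tiny worked example of
# multilabel_score. With one sample and two labels, the correct positive
# contributes 100 and the correct negative contributes 1, so the average is 50.5.
def _multilabel_score_example():
    return multilabel_score(np.array([[1, 0]]), np.array([[1, 0]]))  # == 50.5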


def class_eval(prediction, target, pred_type):
    """Compute accuracy, precision, recall, F1 and AUC for binary, multi-label or multi-class output."""
    prediction = prediction.cpu().numpy()
    target = target.cpu().numpy()
    if prediction.shape[1] == 2:
        pred_label = np.argmax(prediction, axis=1)
        target_label = np.squeeze(target)
        precision, recall, fscore, _ = metrics.precision_recall_fscore_support(
            target_label, pred_label, average='binary')
        try:
            auc_score = metrics.roc_auc_score(target_label, prediction[:, 1])
        except ValueError:  # all true labels are 0, so AUC is undefined
            auc_score = 0.0
        accuracy = metrics.accuracy_score(target_label, pred_label)
    else:
        if pred_type == 'multi_label':
            pred_label = get_class(prediction)
            accuracy = multilabel_score(target, pred_label)
            precision = 0.0
            recall = 0.0
            fscore = 0.0
            auc_score = 0.0
        elif pred_type == 'multi_class':
            pred_label = np.argmax(prediction, axis=1)
            target_label = np.squeeze(target)  # assumes target holds integer class labels
            precision, recall, fscore, _ = metrics.precision_recall_fscore_support(
                target_label, pred_label, average='macro')  # macro-average over classes
            auc_score = 0.0
            accuracy = metrics.accuracy_score(target_label, pred_label)
    return accuracy, precision, recall, fscore, auc_score
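

# Illustrative usage sketch (an assumption, not part of the original file):
# class_eval takes torch tensors of class scores and targets; the dummy shapes
# below model a two-class problem, in which pred_type is ignored.
def _class_eval_example():
    probs = torch.tensor([[0.9, 0.1], [0.2, 0.8]])  # scores for 2 samples, 2 classes
    labels = torch.tensor([[0], [1]])               # ground-truth class labels
    return class_eval(probs, labels, pred_type='binary')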


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape rather than view: `correct` is non-contiguous after the transpose
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
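

# Illustrative usage sketch (an assumption, not part of the original file):
# top-1 / top-5 accuracy from raw logits and integer class targets.
def _topk_accuracy_example():
    logits = torch.randn(8, 10)            # a batch of 8 samples, 10 classes
    targets = torch.randint(0, 10, (8,))   # integer class labels
    top1, top5 = accuracy(logits, targets, topk=(1, 5))
    return top1.item(), top5.item()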


def save_checkpoint(state, is_best, path, filename='checkpoint.pth.tar'):
    filename = os.path.join(path, filename)
    torch.save(state, filename)
    if is_best:
        print('Saving checkpoint ...\n')
        best_model = os.path.join(path, 'model_best.pth.tar')
        shutil.copyfile(filename, best_model)


def load_checkpoint(model, checkpoint_path):
    if os.path.isfile(checkpoint_path):
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(checkpoint_path, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_path))


def check_fields(required_fields, session):
    for field in required_fields:
        if field not in session:
            raise Exception('%s should be configured in IO session' % field)


def print_progress(text):
    print("=====%s=====" % text)


def str2list(s, sep=','):
    return s.split(sep)