-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmetric.py
73 lines (61 loc) · 2.33 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
import torch
from sklearn.metrics import roc_auc_score
def roc_auc(output, target):
    """Area under the ROC curve for binary scores.

    Moves both tensors to CPU numpy arrays and delegates to
    sklearn's ``roc_auc_score``. Returns ``np.nan`` when the batch
    contains only a single class (sklearn raises ValueError there).
    """
    with torch.no_grad():
        y_true = target.cpu().numpy()
        y_score = output.cpu().numpy()
        try:
            return roc_auc_score(y_true, y_score)
        except ValueError:
            # batch with only one class: AUC is undefined
            return np.nan
def accuracy(output, target):
    """Fraction of correct binary predictions at a 0.5 threshold.

    Args:
        output: per-sample scores/probabilities, shape (N,) or (N, 1).
        target: ground-truth labels of length N (0/1 values).

    Returns:
        float in [0, 1]: correct predictions divided by N.
    """
    with torch.no_grad():
        # Flatten to 1-D before comparing: a (N, 1) prediction tensor
        # against a (N,) target would otherwise broadcast to (N, N)
        # and silently corrupt the correct-count.
        pred = (output > 0.5).reshape(-1)
        assert pred.shape[0] == len(target)
        correct = torch.sum(pred == target).item()
    return correct / len(target)
def balanced_accuracy(output, target):
    """Mean of the per-class recalls for binary predictions at 0.5.

    Returns ``np.nan`` when the batch holds only one class, since
    one of the two recalls would be undefined.
    """
    with torch.no_grad():
        pred = (output > 0.5).squeeze()
        assert pred.shape[0] == len(target)
        neg_total = torch.sum(target == 0).item()
        pos_total = torch.sum(target == 1).item()
        # single-class batch: balanced accuracy is undefined
        if not (neg_total and pos_total):
            return np.nan
        combined = pred + target
        # sum == 0 -> true negative, sum == 2 -> true positive
        neg_recall = torch.sum(combined == 0).item() / neg_total
        pos_recall = torch.sum(combined == 2).item() / pos_total
        return 0.5 * (neg_recall + pos_recall)
def accuracy_with_logit(output, target):
    """Fraction of correct binary predictions from raw logits.

    Applies a sigmoid and thresholds at 0.5 (equivalent to logit > 0).

    Args:
        output: per-sample logits, shape (N,) or (N, 1).
        target: ground-truth labels of length N (0/1 values).

    Returns:
        float in [0, 1]: correct predictions divided by N.
    """
    with torch.no_grad():
        # Flatten to 1-D before comparing: a (N, 1) prediction tensor
        # against a (N,) target would otherwise broadcast to (N, N)
        # and silently corrupt the correct-count.
        pred = (torch.sigmoid(output) > 0.5).reshape(-1)
        assert pred.shape[0] == len(target)
        correct = torch.sum(pred == target).item()
    return correct / len(target)
def balanced_accuracy_with_logit(output, target):
    """Mean of the per-class recalls, starting from raw logits.

    Applies a sigmoid and thresholds at 0.5 (equivalent to logit > 0).
    Returns ``np.nan`` when the batch holds only one class, since
    one of the two recalls would be undefined.
    """
    with torch.no_grad():
        pred = (torch.sigmoid(output) > 0.5).squeeze()
        assert pred.shape[0] == len(target)
        neg_total = torch.sum(target == 0).item()
        pos_total = torch.sum(target == 1).item()
        # single-class batch: balanced accuracy is undefined
        if not (neg_total and pos_total):
            return np.nan
        combined = pred + target
        # sum == 0 -> true negative, sum == 2 -> true positive
        neg_recall = torch.sum(combined == 0).item() / neg_total
        pos_recall = torch.sum(combined == 2).item() / pos_total
        return 0.5 * (neg_recall + pos_recall)
def top_k_acc(output, target, k=3):
with torch.no_grad():
pred = torch.topk(output, k, dim=1)[1]
assert pred.shape[0] == len(target)
correct = 0
for i in range(k):
correct += torch.sum(pred[:, i] == target).item()
return correct / len(target)