-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
69 lines (37 loc) · 1.07 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
def f_score(pr, gt, *, zero_division=1.0):
    """Return the F1 (Dice) score ``2*TP / (2*TP + FN + FP)``.

    TP = sum(gt * pr), FP = sum(pr) - TP, FN = sum(gt) - TP. Intended for
    binary masks (0/1 values); possibly also used with soft predictions in
    [0, 1] -- TODO confirm with callers.

    Args:
        pr: Prediction tensor.
        gt: Ground-truth tensor, broadcastable with ``pr``.
        zero_division: Value returned when the denominator is zero (for
            non-negative inputs: neither ``pr`` nor ``gt`` has any positive).
            Defaults to 1.0 -- two empty masks agree perfectly -- instead of
            the NaN that a plain 0/0 would give, which would otherwise poison
            any average this score is folded into.

    Returns:
        0-dim floating-point tensor.
    """
    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp
    fn = torch.sum(gt) - tp
    den = 2 * tp + fn + fp
    valid = den > 0
    # Divide by a safe denominator first so the unused 0/0 branch never
    # produces NaN (which would also leak NaN into gradients through where).
    safe_den = torch.where(valid, den, torch.ones_like(den))
    ratio = (2 * tp) / safe_den
    return torch.where(valid, ratio, torch.full_like(ratio, zero_division))
def accuracy(pr, gt):
    """Return the fraction of elements where ``pr`` equals ``gt``.

    Elements are compared exactly with ``==``; the number of matches is
    divided by the element count of ``gt``.

    Returns:
        0-dim floating-point tensor.
    """
    n_elements = gt.numel()
    matches = (pr == gt).sum()
    return matches / n_elements
def precision(pr, gt, *, zero_division=float("nan")):
    """Return the precision ``TP / (TP + FP)``.

    TP = sum(gt * pr), FP = sum(pr) - TP. Intended for binary masks
    (0/1 values) -- TODO confirm whether soft inputs are also used.

    Args:
        pr: Prediction tensor.
        gt: Ground-truth tensor, broadcastable with ``pr``.
        zero_division: Value returned when ``TP + FP == 0`` (for non-negative
            inputs: ``pr`` has no positives). Defaults to NaN, i.e. the same
            result as the plain 0/0.

    Returns:
        0-dim floating-point tensor.
    """
    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp
    den = tp + fp
    valid = den > 0
    # Safe denominator keeps the discarded 0/0 branch NaN-free (also for grads).
    safe_den = torch.where(valid, den, torch.ones_like(den))
    ratio = tp / safe_den
    return torch.where(valid, ratio, torch.full_like(ratio, zero_division))
def recall(pr, gt, *, zero_division=float("nan")):
    """Return the recall (sensitivity) ``TP / (TP + FN)``.

    TP = sum(gt * pr), FN = sum(gt) - TP. Intended for binary masks
    (0/1 values) -- TODO confirm whether soft inputs are also used.

    Args:
        pr: Prediction tensor.
        gt: Ground-truth tensor, broadcastable with ``pr``.
        zero_division: Value returned when ``TP + FN == 0`` (for non-negative
            inputs: ``gt`` has no positives). Defaults to NaN, i.e. the same
            result as the plain 0/0.

    Returns:
        0-dim floating-point tensor.
    """
    tp = torch.sum(gt * pr)
    fn = torch.sum(gt) - tp
    den = tp + fn
    valid = den > 0
    # Safe denominator keeps the discarded 0/0 branch NaN-free (also for grads).
    safe_den = torch.where(valid, den, torch.ones_like(den))
    ratio = tp / safe_den
    return torch.where(valid, ratio, torch.full_like(ratio, zero_division))
def specificity(pr, gt, *, zero_division=float("nan")):
    """Return the specificity (true-negative rate) ``TN / (TN + FP)``.

    TN = sum(|gt - 1| * |pr - 1|), TP = sum(gt * pr), FP = sum(pr) - TP.
    Intended for binary masks (0/1 values or bool) -- TODO confirm whether
    soft inputs are also used.

    Args:
        pr: Prediction tensor (numeric or bool).
        gt: Ground-truth tensor, broadcastable with ``pr`` (numeric or bool).
        zero_division: Value returned when ``TN + FP == 0`` (for binary
            masks: ``gt`` has no negatives). Defaults to NaN, i.e. the same
            result as the plain 0/0.

    Returns:
        0-dim floating-point tensor.
    """
    # torch.sub rejects bool tensors, so promote bool masks to integers first;
    # numeric inputs are left untouched and computed exactly as before.
    if gt.dtype == torch.bool:
        gt = gt.to(torch.int64)
    if pr.dtype == torch.bool:
        pr = pr.to(torch.int64)
    neg_gt = torch.abs(torch.sub(gt, 1))
    neg_pr = torch.abs(torch.sub(pr, 1))
    tn = torch.sum(neg_gt * neg_pr)
    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp
    den = tn + fp
    valid = den > 0
    # Safe denominator keeps the discarded 0/0 branch NaN-free (also for grads).
    safe_den = torch.where(valid, den, torch.ones_like(den))
    ratio = tn / safe_den
    return torch.where(valid, ratio, torch.full_like(ratio, zero_division))
def batch_metric(pr, gt, metric, *, first_channel=1):
    """Average a per-sample metric over the batch (dim 0).

    For each sample ``i`` this calls
    ``metric(pr[i, first_channel:, ...], gt[i, first_channel:, ...])`` on
    detached CPU copies of the inputs and returns the mean of the results.

    Args:
        pr: Prediction tensor, batch on dim 0 and channels on dim 1; any
            number of trailing dims (previously exactly 4-D was required).
        gt: Ground-truth tensor with the same shape as ``pr``.
        metric: Callable ``metric(pr_sample, gt_sample)`` returning a scalar,
            e.g. ``f_score``.
        first_channel: First channel passed to ``metric``. Defaults to 1, so
            channel 0 is skipped (presumably a background class -- confirm).

    Returns:
        Mean of the per-sample metric values.

    Raises:
        ValueError: If the shapes differ, the inputs have fewer than 2 dims,
            or the batch is empty.
    """
    if pr.shape != gt.shape:
        raise ValueError(
            f"shape mismatch: pr {tuple(pr.shape)} vs gt {tuple(gt.shape)}"
        )
    if pr.dim() < 2:
        raise ValueError(
            f"expected at least (batch, channel) dims, got shape {tuple(pr.shape)}"
        )
    # Detach before copying so the CPU copy is not recorded by autograd.
    pr_data, gt_data = pr.detach().cpu(), gt.detach().cpu()
    batch_size = pr_data.shape[0]
    if batch_size == 0:
        raise ValueError("cannot average a metric over an empty batch")
    total = 0
    for i in range(batch_size):
        total += metric(pr_data[i, first_channel:, ...], gt_data[i, first_channel:, ...])
    return total / batch_size