-
Notifications
You must be signed in to change notification settings - Fork 86
/
metrics.py
61 lines (52 loc) · 1.8 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import numpy as np
import cv2
def _threshold(x, threshold=None):
    """Binarize ``x`` against ``threshold``, or return it untouched.

    Args:
        x: Value to binarize. Presumably a ``torch.Tensor``; it must support
            ``>`` and the result must offer ``.type(dtype)``.
        threshold: Cut-off value. When ``None`` no binarization happens.

    Returns:
        ``x`` itself when ``threshold`` is ``None``; otherwise a mask holding
        1 where ``x > threshold`` and 0 elsewhere, cast back to ``x.dtype``.
    """
    if threshold is None:
        return x
    above_cutoff = x > threshold
    # Cast the boolean mask back so later arithmetic keeps the input dtype.
    return above_cutoff.type(x.dtype)
def _list_tensor(x, y):
    """Coerce predictions and targets to tensors and squash logits.

    Each argument is converted independently: tensors pass through as-is,
    anything else (list, tuple, NumPy array, ...) is converted via
    ``torch.tensor(np.array(value))``.

    If the prediction tensor contains any negative entry it cannot already be
    a probability map, so it is treated as raw logits and passed through a
    sigmoid. Note this is a heuristic: logits that happen to be all
    non-negative are left unchanged.

    Args:
        x: Predictions (tensor or array-like).
        y: Targets (tensor or array-like).

    Returns:
        Tuple ``(x, y)`` of tensors, with ``x`` sigmoid-activated when it
        contained negative values.
    """
    def _as_tensor(value):
        # Tensors are returned untouched so their device is preserved;
        # everything else goes through NumPy exactly like the list path did.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(np.array(value))

    x = _as_tensor(x)
    y = _as_tensor(y)
    # min() is undefined on an empty tensor, so guard on numel() first.
    if x.numel() > 0 and x.min() < 0:
        x = torch.sigmoid(x)
    return x, y
def iou(pr, gt, eps=1e-7, threshold=0.5):
    """Compute intersection-over-union for each entry along dim 0.

    Inputs are normalized with ``_list_tensor`` and then both are binarized
    with ``_threshold``. All dimensions except the first are summed out, so
    dim 0 is treated as the batch axis (4-D input behaves exactly as before;
    other ranks >= 2 are now accepted too).

    Args:
        pr: Predictions (tensor or array-like), at least 2-D.
        gt: Targets; expected to have the same shape as ``pr``.
        eps: Added to numerator and denominator so two empty masks score 1
            instead of producing 0/0.
        threshold: Binarization cut-off applied to both inputs; ``None``
            disables binarization.

    Returns:
        NumPy array with one IoU value per entry along dim 0.

    Raises:
        ValueError: If the prediction has fewer than 2 dimensions.
    """
    pr_, gt_ = _list_tensor(pr, gt)
    pr_ = _threshold(pr_, threshold=threshold)
    gt_ = _threshold(gt_, threshold=threshold)
    if pr_.dim() < 2:
        raise ValueError(
            'iou expects at least 2 dimensions (batch first), got %d' % pr_.dim()
        )
    # Reduce over everything but the batch axis; equals [1, 2, 3] for 4-D.
    reduce_dims = tuple(range(1, pr_.dim()))
    intersection = torch.sum(gt_ * pr_, dim=reduce_dims)
    union = (
        torch.sum(gt_, dim=reduce_dims)
        + torch.sum(pr_, dim=reduce_dims)
        - intersection
    )
    scores = (intersection + eps) / (union + eps)
    # detach() is required before numpy() when the tensor still requires
    # grad (possible with threshold=None, which skips the comparison).
    return scores.detach().cpu().numpy()
def dice(pr, gt, eps=1e-7, threshold=0.5):
    """Compute the Dice coefficient for each entry along dim 0.

    Inputs are normalized with ``_list_tensor`` and then both are binarized
    with ``_threshold``. All dimensions except the first are summed out, so
    dim 0 is treated as the batch axis (4-D input behaves exactly as before;
    other ranks >= 2 are now accepted too).

    Args:
        pr: Predictions (tensor or array-like), at least 2-D.
        gt: Targets; expected to have the same shape as ``pr``.
        eps: Added to numerator and denominator so two empty masks score 1
            instead of producing 0/0.
        threshold: Binarization cut-off applied to both inputs; ``None``
            disables binarization.

    Returns:
        NumPy array with one Dice value per entry along dim 0.

    Raises:
        ValueError: If the prediction has fewer than 2 dimensions.
    """
    pr_, gt_ = _list_tensor(pr, gt)
    pr_ = _threshold(pr_, threshold=threshold)
    gt_ = _threshold(gt_, threshold=threshold)
    if pr_.dim() < 2:
        raise ValueError(
            'dice expects at least 2 dimensions (batch first), got %d' % pr_.dim()
        )
    # Reduce over everything but the batch axis; equals [1, 2, 3] for 4-D.
    reduce_dims = tuple(range(1, pr_.dim()))
    intersection = torch.sum(gt_ * pr_, dim=reduce_dims)
    # Dice divides by the summed sizes of both masks, not by their union.
    total_size = torch.sum(gt_, dim=reduce_dims) + torch.sum(pr_, dim=reduce_dims)
    scores = (2. * intersection + eps) / (total_size + eps)
    # detach() is required before numpy() when the tensor still requires
    # grad (possible with threshold=None, which skips the comparison).
    return scores.detach().cpu().numpy()
def SegMetrics(pred, label, metrics):
    """Compute batch-averaged segmentation metrics.

    Args:
        pred: Predictions, forwarded to ``iou`` / ``dice``. Must not be
            ``None``.
        label: Targets, forwarded to ``iou`` / ``dice``.
        metrics: A metric name or an iterable of names. Supported names are
            ``'iou'`` and ``'dice'``. Non-string entries are skipped.

    Returns:
        1-D NumPy array holding the mean of each requested metric, in the
        order the names were given.

    Raises:
        ValueError: If ``pred`` is ``None`` or a metric name is not
            recognized.
    """
    # Validate up front: checking after the loop was too late, because a
    # None prediction already crashed inside the metric helpers.
    if pred is None:
        raise ValueError('metric mistakes in calculations')

    if isinstance(metrics, str):
        metrics = [metrics]

    metric_list = []
    for metric in metrics:
        if not isinstance(metric, str):
            continue
        if metric == 'iou':
            metric_list.append(np.mean(iou(pred, label)))
        elif metric == 'dice':
            metric_list.append(np.mean(dice(pred, label)))
        else:
            raise ValueError('metric %s not recognized' % metric)
    return np.array(metric_list)