-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
78 lines (59 loc) · 2.48 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import torch
def dice_score(predict, target, smooth=1.0):
    """Return the smoothed Dice coefficient between two tensors.

    Both inputs are flattened, moved to the CPU and detached from the
    autograd graph, so the result is a 0-dim CPU tensor without gradient.

    Args:
        predict: Tensor of predictions. Values are used as-is (no
            thresholding or sigmoid is applied here).
        target: Tensor of ground-truth values; expected to hold as many
            elements as ``predict`` so the flattened tensors multiply
            element-wise.
        smooth: Constant added to numerator and denominator to avoid a 0/0
            result when both inputs are all zero. The default of 1.0 is the
            value this function previously hard-coded regardless of the
            argument, so calls that rely on the default are unaffected.

    Returns:
        0-dim tensor equal to
        ``(2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.
    """
    # reshape(-1) also accepts non-contiguous inputs, where view(-1) raises.
    pflat = predict.reshape(-1).cpu().detach()
    tflat = target.reshape(-1).cpu().detach()
    intersection = (pflat * tflat).sum()
    return (2.0 * intersection + smooth) / (pflat.sum() + tflat.sum() + smooth)
def metric(probability, truth, threshold=0.5, reduction='none'):
    """Compute batch Dice, separately for positive and negative samples.

    A sample is "negative" when its thresholded ground truth contains no
    foreground element and "positive" otherwise. Negative samples score 1.0
    when the thresholded prediction is also empty and 0.0 otherwise;
    positive samples score ``2 * |p * t| / (|p| + |t|)``.

    Args:
        probability: Torch tensor of predicted probabilities whose first
            dimension is the batch dimension.
        truth: Torch tensor of ground-truth values with the same batch size
            and the same number of elements per sample as ``probability``.
            Values above 0.5 count as foreground.
        threshold: Cut-off above which a probability counts as foreground.
        reduction: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(dice, dice_neg, dice_pos, num_neg, num_pos)`` where ``dice``
        is the mean over all samples, ``dice_neg`` / ``dice_pos`` are the
        means over negative / positive samples (0.0 when the batch has no
        sample of that kind), and ``num_neg`` / ``num_pos`` are the sample
        counts.

    Raises:
        ValueError: If the flattened inputs do not share the same shape.
    """
    batch_size = len(truth)
    with torch.no_grad():
        # reshape (rather than view) also accepts non-contiguous tensors.
        probability = probability.reshape(batch_size, -1)
        truth = truth.reshape(batch_size, -1)
        if probability.shape != truth.shape:
            raise ValueError(
                "probability and truth must have matching shapes, got "
                f"{tuple(probability.shape)} and {tuple(truth.shape)}"
            )

        p = (probability > threshold).float()
        t = (truth > 0.5).float()
        t_sum = t.sum(-1)
        p_sum = p.sum(-1)

        neg_mask = t_sum == 0
        pos_mask = t_sum >= 1
        num_neg = int(neg_mask.sum().item())
        num_pos = int(pos_mask.sum().item())

        # Negative samples: perfect score only if nothing was predicted.
        neg_scores = (p_sum[neg_mask] == 0).float()

        # Positive samples only, so the denominator is always >= 1 and the
        # 0/0 case of empty-truth/empty-prediction rows never arises.
        p_pos = p[pos_mask]
        t_pos = t[pos_mask]
        pos_scores = 2 * (p_pos * t_pos).sum(-1) / (p_pos + t_pos).sum(-1)

        dice = torch.cat([pos_scores, neg_scores]).mean().item()

        # The mean of an empty tensor is NaN; report 0.0 for a missing group
        # explicitly instead of post-processing the NaN.
        dice_neg = neg_scores.mean().item() if num_neg else 0.0
        dice_pos = pos_scores.mean().item() if num_pos else 0.0

    return dice, dice_neg, dice_pos, num_neg, num_pos
class Meter:
    """Collects per-batch Dice and IoU scores so they can be averaged later."""

    def __init__(self, epoch):
        # `epoch` is accepted but not read anywhere in this class.
        self.base_threshold = 0.5  # cut-off handed to metric() and predict()
        self.base_dice_scores = []
        self.dice_neg_scores = []
        self.dice_pos_scores = []
        self.iou_scores = []

    def update(self, targets, outputs):
        """Score one batch and append the results to the running lists.

        `outputs` is passed through a sigmoid before scoring. `predict` and
        `compute_iou_batch` are not defined in this class; they are expected
        to be provided by the surrounding module.
        """
        cutoff = self.base_threshold
        probabilities = torch.sigmoid(outputs)

        batch_dice, batch_neg, batch_pos, _num_neg, _num_pos = metric(
            probabilities, targets, cutoff
        )
        self.base_dice_scores.append(batch_dice)
        self.dice_pos_scores.append(batch_pos)
        self.dice_neg_scores.append(batch_neg)

        masks = predict(probabilities, cutoff)
        self.iou_scores.append(compute_iou_batch(masks, targets, classes=[1]))

    def get_metrics(self):
        """Return `([dice, dice_neg, dice_pos], iou)` averaged over all updates.

        The Dice values use a plain mean; the IoU uses a NaN-ignoring mean.
        """
        score_lists = (
            self.base_dice_scores,
            self.dice_neg_scores,
            self.dice_pos_scores,
        )
        dices = [np.mean(scores) for scores in score_lists]
        return dices, np.nanmean(self.iou_scores)