forked from CherBass/CapsPix2Pix
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: custom_losses.py
162 lines (129 loc) · 5.67 KB
/
custom_losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
def jaccard_loss(inputs, targets):
    """Background-ratio-weighted Jaccard-style loss averaged over four terms.

    Combines foreground/background overlap ratios normalized by both the
    target mass (recall-like) and the prediction mass (precision-like).
    The foreground terms are weighted by the mean fraction of background
    pixels so that patches with few vessel pixels contribute more.

    pred/target layout: first dim is batch; the background ratio divides by
    targets.size(2) ** 2, which assumes square H x W patches — TODO confirm.
    """
    batch = targets.size(0)
    side = targets.size(2)
    pred = inputs.contiguous().view(batch, -1)
    gt = targets.contiguous().view(batch, -1)

    eps = 0.000001
    fg_overlap = (pred * gt).sum(dim=1)
    fg_total = gt.sum(dim=1)
    bg_overlap = ((1 - pred) * (1 - gt)).sum(dim=1)
    bg_total = (1 - gt).sum(dim=1)
    # Larger weight for patches dominated by background (few vessel pixels).
    bg_ratio = torch.mean(bg_total / side ** 2)
    pred_fg = pred.sum(dim=1)
    pred_bg = (1 - pred).sum(dim=1)

    recall_fg = 1 - torch.mean(fg_overlap / (fg_total + eps))
    recall_bg = 1 - torch.mean(bg_overlap / (bg_total + eps))
    precision_fg = 1 - torch.mean(fg_overlap / (pred_fg + eps))
    precision_bg = 1 - torch.mean(bg_overlap / (pred_bg + eps))
    return (bg_ratio * recall_fg + recall_bg
            + (1 - bg_ratio) * precision_fg + precision_bg) / 4
def jaccard_loss_2(inputs, targets):
    """Variant of jaccard_loss: only the foreground-recall term is weighted
    by the background-pixel ratio; the precision-like term is unweighted.

    pred/target layout: first dim is batch; the background ratio divides by
    targets.size(2) ** 2, which assumes square H x W patches — TODO confirm.
    """
    batch = targets.size(0)
    side = targets.size(2)
    pred = inputs.contiguous().view(batch, -1)
    gt = targets.contiguous().view(batch, -1)

    eps = 0.000001
    fg_overlap = (pred * gt).sum(dim=1)
    fg_total = gt.sum(dim=1)
    bg_overlap = ((1 - pred) * (1 - gt)).sum(dim=1)
    bg_total = (1 - gt).sum(dim=1)
    # Larger weight for patches dominated by background (few vessel pixels).
    bg_ratio = torch.mean(bg_total / side ** 2)
    pred_fg = pred.sum(dim=1)
    pred_bg = (1 - pred).sum(dim=1)

    recall_fg = 1 - torch.mean(fg_overlap / (fg_total + eps))
    recall_bg = 1 - torch.mean(bg_overlap / (bg_total + eps))
    precision_fg = 1 - torch.mean(fg_overlap / (pred_fg + eps))
    precision_bg = 1 - torch.mean(bg_overlap / (pred_bg + eps))
    return (bg_ratio * recall_fg + recall_bg
            + precision_fg + precision_bg) / 4
def dice_loss(pred, target):
    """Differentiable soft-Dice loss for real-valued pred/target tensors.

    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    Returns 1 - dice, with dice clamped below 1 so the loss never hits 0.
    """
    smooth = 1.
    epsilon = 10e-8
    # contiguous() because the inputs may come from a torch.view op
    p = pred.contiguous().view(-1)
    t = target.contiguous().view(-1)
    overlap = torch.sum(p * t)
    denom = torch.sum(p * p) + torch.sum(t * t)
    score = (2. * overlap + smooth) / (denom + smooth)
    score = torch.clamp(score.mean(dim=0), 0, 1.0 - epsilon)
    return 1 - score
def dice_hard(pred, target):
    """Differentiable soft-Dice loss for real-valued pred/target tensors.

    NOTE(review): despite the name, this is computationally identical to
    dice_loss — no thresholding/binarization is applied. Confirm whether a
    hard (binarized) variant was intended.

    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    Returns 1 - dice, with dice clamped below 1 so the loss never hits 0.
    """
    smooth = 1.
    epsilon = 10e-8
    # contiguous() because the inputs may come from a torch.view op
    p = pred.contiguous().view(-1)
    t = target.contiguous().view(-1)
    overlap = torch.sum(p * t)
    denom = torch.sum(p * p) + torch.sum(t * t)
    score = (2. * overlap + smooth) / (denom + smooth)
    score = torch.clamp(score.mean(dim=0), 0, 1.0 - epsilon)
    return 1 - score
def dice_coeff(pred, target):
    """Differentiable soft-Dice coefficient for real-valued tensors.

    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    Returns the dice score itself (higher is better), clamped below 1.
    """
    smooth = 1.
    epsilon = 10e-8
    # contiguous() because the inputs may come from a torch.view op
    p = pred.contiguous().view(-1)
    t = target.contiguous().view(-1)
    overlap = torch.sum(p * t)
    denom = torch.sum(p * p) + torch.sum(t * t)
    score = (2. * overlap + smooth) / (denom + smooth)
    return torch.clamp(score.mean(dim=0), 0, 1.0 - epsilon)
def dice_soft(pred, target, loss_type='sorensen', smooth=1e-5, from_logits=False):
    """Soft dice (Sørensen or Jaccard) coefficient for comparing the similarity
    of two batches of data, usually used for binary image segmentation,
    i.e. labels are binary. The coefficient is between 0 and 1; 1 means a
    total match.
    Parameters
    -----------
    pred : tensor
        A distribution with shape: [batch_size, ....], (any dimensions).
    target : tensor
        A distribution with shape: [batch_size, ....], (any dimensions).
    loss_type : string
        ``jaccard`` or ``sorensen``, default is ``sorensen``.
    smooth : float
        This small value will be added to the numerator and denominator.
        If both y_pred and y_true are empty, it makes sure dice is 1.
        If either y_pred or y_true are empty (all pixels are background),
        dice = ``smooth/(small_value + smooth)``; then if smooth is very
        small, dice is close to 0 (even when image values are below the
        threshold), so a higher smooth gives a higher dice in that case.
    Raises
    -----------
    ValueError
        If ``loss_type`` is neither ``jaccard`` nor ``sorensen``.
    References
    -----------
    - `Wiki-Dice <https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient>`_
    """
    if not from_logits:
        # NOTE(review): this maps probabilities back to logits and then
        # computes dice on the logits themselves — unusual; confirm intent.
        _epsilon = 1e-7
        pred = torch.clamp(pred, _epsilon, 1 - _epsilon)
        pred = torch.log(pred / (1 - pred))
    inse = torch.sum(pred * target)
    if loss_type == 'jaccard':
        # Squared-magnitude normalization (Jaccard-style denominator).
        pred_norm = torch.sum(pred * pred)
        target_norm = torch.sum(target * target)
    elif loss_type == 'sorensen':
        # Plain-sum normalization (Sørensen-style denominator).
        pred_norm = torch.sum(pred)
        target_norm = torch.sum(target)
    else:
        # Fixed typo ("Unknow") and include the offending value.
        raise ValueError(
            f"Unknown loss_type: {loss_type!r} (expected 'jaccard' or 'sorensen')")
    dice = (2. * inse + smooth) / (pred_norm + target_norm + smooth)
    dice = dice.mean(dim=0)
    return dice