-
Notifications
You must be signed in to change notification settings - Fork 112
/
Copy pathloss.py
125 lines (82 loc) · 3.92 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
基于Dice的loss函数,计算时pred和target的shape必须相同,亦即target为onehot编码后的Tensor
"""
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
    """Soft Dice loss, averaged over channels and then over the batch.

    ``pred`` and ``target`` must have the same shape
    ``(batch, channels, *spatial)``; ``target`` is the one-hot encoding of the
    labels. Any number of spatial dimensions is supported, so both 2-D
    ``(B, C, H, W)`` and 3-D ``(B, C, D, H, W)`` inputs work.

    For each sample and channel the coefficient is
    ``2 * sum(p * t) / (sum(p ** 2) + sum(t ** 2) + 1)``; the loss is the Dice
    distance ``1 - dice``, averaged and clamped to ``[0, 1]``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the scalar Dice loss of ``pred`` against ``target``.

        Raises:
            ValueError: if ``pred`` has fewer than 3 dimensions (no spatial axis).
        """
        if pred.dim() < 3:
            raise ValueError(
                "expected pred of shape (batch, channels, *spatial), got {}".format(
                    tuple(pred.shape)
                )
            )
        smooth = 1
        # Reduce over every spatial axis at once so the input rank does not matter.
        dims = tuple(range(2, pred.dim()))
        intersection = (pred * target).sum(dim=dims)
        denominator = pred.pow(2).sum(dim=dims) + target.pow(2).sum(dim=dims) + smooth
        # (batch, channels) -> (batch,): average the per-channel coefficients.
        dice = (2 * intersection / denominator).mean(dim=1)
        # Dice distance, averaged over the batch.
        return torch.clamp((1 - dice).mean(), 0, 1)
class ELDiceLoss(nn.Module):
    """Exponential-logarithmic Dice loss.

    Per sample, the soft Dice coefficient
    ``2 * sum(p * t) / (sum(p ** 2) + sum(t ** 2) + 1)`` is averaged over
    channels to give ``dice``; the loss is
    ``mean_batch((-ln(dice + 1e-5)) ** gamma)`` clamped to ``[0, 2]``.

    ``pred`` and ``target`` must have the same shape
    ``(batch, channels, *spatial)`` with ``target`` one-hot encoded; any number
    of spatial dimensions is supported.

    Args:
        gamma: exponent applied to the negative log Dice (default 0.3).
    """

    def __init__(self, gamma=0.3):
        super().__init__()
        self.gamma = gamma

    def forward(self, pred, target):
        """Return the scalar exponential-logarithmic Dice loss.

        Raises:
            ValueError: if ``pred`` has fewer than 3 dimensions (no spatial axis).
        """
        if pred.dim() < 3:
            raise ValueError(
                "expected pred of shape (batch, channels, *spatial), got {}".format(
                    tuple(pred.shape)
                )
            )
        smooth = 1
        dims = tuple(range(2, pred.dim()))
        intersection = (pred * target).sum(dim=dims)
        denominator = pred.pow(2).sum(dim=dims) + target.pow(2).sum(dim=dims) + smooth
        # (batch, channels) -> (batch,): average the per-channel coefficients.
        dice = (2 * intersection / denominator).mean(dim=1)
        # For large, near-perfect masks dice can exceed 1 - 1e-5, so
        # dice + 1e-5 > 1 and -log(...) becomes slightly negative; a negative
        # base raised to a fractional power is NaN. Clamp the base at zero.
        log_term = torch.clamp(-torch.log(dice + 1e-5), min=0)
        return torch.clamp(torch.pow(log_term, self.gamma).mean(), 0, 2)
class HybridLoss(nn.Module):
    """Soft Dice loss plus a weighted binary cross-entropy term.

    The Dice part uses ``2 * sum(p * t) / (sum(p ** 2) + sum(t ** 2) + 1)`` per
    sample and channel, averaged over channels; the Dice distance is averaged
    over the batch and clamped to ``[0, 1]``. ``nn.BCELoss`` expects ``pred`` to
    contain probabilities and ``target`` to have the same shape as ``pred``.

    Args:
        bce_weight: multiplier applied to the BCE term (default 1.0).
    """

    def __init__(self, bce_weight=1.0):
        super().__init__()
        self.bce_loss = nn.BCELoss()
        self.bce_weight = bce_weight

    def forward(self, pred, target):
        """Return ``dice_distance + bce_weight * BCE(pred, target)`` as a scalar.

        Raises:
            ValueError: if ``pred`` has fewer than 3 dimensions (no spatial axis).
        """
        if pred.dim() < 3:
            raise ValueError(
                "expected pred of shape (batch, channels, *spatial), got {}".format(
                    tuple(pred.shape)
                )
            )
        smooth = 1
        # Reduce over every spatial axis at once so 2-D and 3-D inputs both work.
        dims = tuple(range(2, pred.dim()))
        intersection = (pred * target).sum(dim=dims)
        denominator = pred.pow(2).sum(dim=dims) + target.pow(2).sum(dim=dims) + smooth
        # (batch, channels) -> (batch,): average the per-channel coefficients.
        dice = (2 * intersection / denominator).mean(dim=1)
        dice_distance = torch.clamp((1 - dice).mean(), 0, 1)
        return dice_distance + self.bce_loss(pred, target) * self.bce_weight
class JaccardLoss(nn.Module):
    """Soft Jaccard (IoU) loss, averaged over channels and then over the batch.

    Per sample and channel:
    ``jaccard = sum(p * t) / (sum(p ** 2) + sum(t ** 2) - sum(p * t) + 1)``.
    The loss is the Jaccard distance ``1 - jaccard``, averaged and clamped to
    ``[0, 1]``. ``pred`` and ``target`` must have the same shape
    ``(batch, channels, *spatial)``; any number of spatial dimensions works.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the scalar Jaccard loss of ``pred`` against ``target``.

        Raises:
            ValueError: if ``pred`` has fewer than 3 dimensions (no spatial axis).
        """
        if pred.dim() < 3:
            raise ValueError(
                "expected pred of shape (batch, channels, *spatial), got {}".format(
                    tuple(pred.shape)
                )
            )
        smooth = 1
        # Reduce over every spatial axis at once so 2-D and 3-D inputs both work.
        dims = tuple(range(2, pred.dim()))
        intersection = (pred * target).sum(dim=dims)
        union = (
            pred.pow(2).sum(dim=dims)
            + target.pow(2).sum(dim=dims)
            - intersection
            + smooth
        )
        # (batch, channels) -> (batch,): average the per-channel coefficients.
        jaccard = (intersection / union).mean(dim=1)
        # Jaccard distance, averaged over the batch.
        return torch.clamp((1 - jaccard).mean(), 0, 1)
class SSLoss(nn.Module):
    """Sensitivity-specificity loss.

    Per sample and channel, with ``e = (p - t) ** 2``:

    * ``s1 = sum(e * t) / (1 + sum(t))`` (squared error where ``t`` is set),
    * ``s2 = sum(e * (1 - t)) / (1 + sum(1 - t))`` (squared error elsewhere),

    combined as ``r * s1 + (1 - r) * s2``. The result is averaged over channels
    and then over the batch, so a scalar suitable for ``backward()`` is
    returned. ``pred`` and ``target`` must have the same shape
    ``(batch, channels, *spatial)``; any number of spatial dimensions works.

    Args:
        r: weight of the sensitivity term ``s1`` (default 0.05).
    """

    def __init__(self, r=0.05):
        super().__init__()
        self.r = r

    def forward(self, pred, target):
        """Return the scalar sensitivity-specificity loss.

        Raises:
            ValueError: if ``pred`` has fewer than 3 dimensions (no spatial axis).
        """
        if pred.dim() < 3:
            raise ValueError(
                "expected pred of shape (batch, channels, *spatial), got {}".format(
                    tuple(pred.shape)
                )
            )
        smooth = 1
        dims = tuple(range(2, pred.dim()))
        sq_err = (pred - target).pow(2)
        background = 1 - target
        s1 = (sq_err * target).sum(dim=dims) / (smooth + target.sum(dim=dims))
        s2 = (sq_err * background).sum(dim=dims) / (smooth + background.sum(dim=dims))
        # (batch, channels) -> (batch,): average the per-channel losses.
        per_sample = (self.r * s1 + (1 - self.r) * s2).mean(dim=1)
        # Reduce over the batch so the loss is a scalar for batch sizes > 1 too.
        return per_sample.mean()
class TverskyLoss(nn.Module):
    """Soft Tversky loss, averaged over channels and then over the batch.

    Per sample and channel, with ``tp = sum(p * t)``,
    ``fp = sum(p * (1 - t))`` and ``fn = sum((1 - p) * t)``:
    ``tversky = tp / (tp + alpha * fp + beta * fn + 1)``. The loss is
    ``1 - tversky``, averaged and clamped to ``[0, 2]``. ``pred`` and
    ``target`` must have the same shape ``(batch, channels, *spatial)``.

    Args:
        alpha: weight of the false-positive term (default 0.3).
        beta: weight of the false-negative term (default 0.7).
    """

    def __init__(self, alpha=0.3, beta=0.7):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, pred, target):
        """Return the scalar Tversky loss of ``pred`` against ``target``.

        Raises:
            ValueError: if ``pred`` has fewer than 3 dimensions (no spatial axis).
        """
        if pred.dim() < 3:
            raise ValueError(
                "expected pred of shape (batch, channels, *spatial), got {}".format(
                    tuple(pred.shape)
                )
            )
        smooth = 1
        # Reduce over every spatial axis at once so 2-D and 3-D inputs both work.
        dims = tuple(range(2, pred.dim()))
        tp = (pred * target).sum(dim=dims)
        fp = (pred * (1 - target)).sum(dim=dims)
        fn = ((1 - pred) * target).sum(dim=dims)
        # (batch, channels) -> (batch,): average the per-channel indices.
        tversky = (tp / (tp + self.alpha * fp + self.beta * fn + smooth)).mean(dim=1)
        return torch.clamp((1 - tversky).mean(), 0, 2)