forked from songdejia/EAST
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
51 lines (42 loc) · 1.98 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from torch.autograd import Variable
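### Ground truth and predictions are assumed channels-first: LossFunc splits the geometry maps along dim 1 (presumably bs * channels * H * W), not bs * W * H * channels as in the TF original.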
import torch
import torch.nn as nn
def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
    """Compute the masked dice *loss* ``1 - 2*|T*P| / (|T| + |P| + eps)``.

    Despite the name, the returned value is the loss (one minus the dice
    coefficient), not the coefficient itself.

    All three inputs are multiplied element-wise, so they must be tensors
    with broadcast-compatible shapes.

    :param y_true_cls: ground-truth score map.
    :param y_pred_cls: predicted score map.
    :param training_mask: weight map applied to both score maps; presumably
        0/1 valued (TODO confirm with the data pipeline).
    :return: scalar tensor with the dice loss. When both masked sums are
        zero the loss is 1, thanks to the epsilon in the denominator.
    """
    smoothing = 1e-5
    overlap = torch.sum(y_true_cls * y_pred_cls * training_mask)
    true_mass = torch.sum(y_true_cls * training_mask)
    pred_mass = torch.sum(y_pred_cls * training_mask)
    # Epsilon keeps the ratio defined when both masked maps are empty.
    denominator = true_mass + pred_mass + smoothing
    return 1. - (2 * overlap / denominator)
class LossFunc(nn.Module):
    """EAST loss: weighted dice loss on the score map plus IoU and angle loss on geometry.

    :param cls_weight: multiplier for the dice (classification) loss. The
        default 0.01 is the previously hard-coded value, chosen to scale the
        classification loss to match the IoU part.
    :param theta_weight: multiplier for the angle term ``1 - cos(theta_pred - theta_gt)``.
        The default 20 is the previously hard-coded value.
    """

    def __init__(self, cls_weight=0.01, theta_weight=20):
        super(LossFunc, self).__init__()
        self.cls_weight = cls_weight
        self.theta_weight = theta_weight

    def forward(self, y_true_cls, y_pred_cls, y_true_geo, y_pred_geo, training_mask):
        """Return the total scalar loss.

        :param y_true_cls: ground-truth score map. It multiplies the per-pixel
            geometry loss, which has size 1 on dim 1, so it must broadcast
            against it (TODO confirm the exact shape with the caller).
        :param y_pred_cls: predicted score map, used only by the dice term.
        :param y_true_geo: ground-truth geometry with exactly 5 channels on
            dim 1: d1 (top), d2 (right), d3 (bottom), d4 (left), theta.
        :param y_pred_geo: predicted geometry, same layout as ``y_true_geo``.
        :param training_mask: weight map applied to both the dice and the
            geometry terms.
        :return: scalar tensor ``mean(L_g * y_true_cls * training_mask) + cls_weight * dice_loss``.
        """
        # Out-of-place scaling, so the autograd graph is not modified in place.
        classification_loss = self.cls_weight * dice_coefficient(
            y_true_cls, y_pred_cls, training_mask)

        # Split geometry into 5 single-channel maps along dim 1
        # (equivalent of the TF split on axis=3 in the channels-last original).
        # d1 -> top, d2 -> right, d3 -> bottom, d4 -> left
        d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = torch.split(y_true_geo, 1, 1)
        d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = torch.split(y_pred_geo, 1, 1)

        area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
        area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
        # Width and height of the overlap of the two boxes, measured from the pixel.
        w_intersect = torch.min(d2_gt, d2_pred) + torch.min(d4_gt, d4_pred)
        h_intersect = torch.min(d1_gt, d1_pred) + torch.min(d3_gt, d3_pred)
        area_intersect = w_intersect * h_intersect
        area_union = area_gt + area_pred - area_intersect

        # +1 in numerator and denominator, presumably to keep the ratio finite
        # (and the log defined) when both boxes are degenerate.
        L_AABB = -torch.log((area_intersect + 1.0) / (area_union + 1.0))
        L_theta = 1 - torch.cos(theta_pred - theta_gt)
        L_g = L_AABB + self.theta_weight * L_theta

        # Mean over all elements (masked-out pixels contribute zeros but are
        # still counted in the denominator).
        return torch.mean(L_g * y_true_cls * training_mask) + classification_loss