-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
boundary_loss.py
62 lines (48 loc) · 1.84 KB
/
boundary_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmseg.registry import MODELS
@MODELS.register_module()
class BoundaryLoss(nn.Module):
    """Boundary loss.

    This function is modified from
    `PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L122>`_. # noqa

    Licensed under the MIT License.

    Computes a class-balanced binary cross-entropy (with logits) between the
    boundary head's predictions and the boundary ground truth. Ground-truth
    elements equal to 1 are weighted by the fraction of elements equal to 0,
    and vice versa. Elements with any other label get weight 0, so they are
    effectively ignored. They are still counted in the denominator of the
    ``'mean'`` reduction.

    Args:
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_boundary'.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 loss_name: str = 'loss_boundary'):
        super().__init__()
        self.loss_weight = loss_weight
        self.loss_name_ = loss_name

    def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor:
        """Forward function.

        Args:
            bd_pre (Tensor): Predictions (logits) of the boundary head. Must
                be 4-D. Dim 1 is moved last before flattening. The expected
                shape is presumably (N, 1, H, W); TODO confirm against the
                head.
            bd_gt (Tensor): Ground truth of the boundary. It must have the
                same number of elements as ``bd_pre``, e.g. (N, H, W).
                Non-contiguous tensors are accepted.

        Returns:
            Tensor: Scalar loss tensor, scaled by ``loss_weight``.
        """
        # Flatten both inputs to (1, num_elements). Moving dim 1 last keeps
        # the element order aligned with ``bd_gt`` when that dim has size 1.
        log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1)
        # ``reshape`` (unlike ``view``) also accepts non-contiguous ground
        # truth, and copies only when it has to.
        target_t = bd_gt.reshape(1, -1).float()

        pos_index = target_t == 1
        neg_index = target_t == 0
        pos_num = pos_index.sum()
        neg_num = neg_index.sum()
        sum_num = pos_num + neg_num

        # Class balancing: each class is weighted by the other class's share.
        # Labels other than 0/1 keep weight 0. If sum_num is 0, both masks
        # are empty, so the 0/0 ratios below are never written into
        # ``weight``.
        weight = torch.zeros_like(log_p)
        weight[pos_index] = neg_num * 1.0 / sum_num
        weight[neg_index] = pos_num * 1.0 / sum_num

        loss = F.binary_cross_entropy_with_logits(
            log_p, target_t, weight, reduction='mean')
        return self.loss_weight * loss

    @property
    def loss_name(self):
        """str: Name of this loss item, as given at construction."""
        return self.loss_name_