-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLoss.py
26 lines (20 loc) · 791 Bytes
/
Loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits.

    The logits are passed through a sigmoid, then the Dice coefficient is
    computed over *all* elements of the inputs at once (a single global
    ratio, not a per-sample average). The loss is ``1 - dice``.

    Args:
        smooth: Additive smoothing term applied to both the numerator and
            the denominator. It keeps the ratio defined when both the
            predictions and the targets sum to zero. Defaults to ``1.0``.
    """

    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the Dice loss as a scalar tensor.

        Args:
            preds: Raw, un-normalised scores (logits); sigmoid is applied here.
            targets: Ground-truth mask, multiplied element-wise with the
                probabilities (presumably the same shape as ``preds`` with
                values in [0, 1] -- confirm against the caller).

        Returns:
            ``1 - dice`` where ``dice`` is the smoothed Dice coefficient.
        """
        probs = torch.sigmoid(preds)  # Convert logits to probabilities
        intersection = (probs * targets).sum()
        # The same smoothing term on both sides prevents division by zero and
        # makes an all-empty prediction/target pair score a perfect dice of 1.
        numerator = 2.0 * intersection + self.smooth
        denominator = probs.sum() + targets.sum() + self.smooth
        return 1 - numerator / denominator
class CombinedLoss(nn.Module):
    """Sum of binary cross-entropy (on logits) and Dice loss.

    Both terms receive the same raw logits and targets and are added with
    equal weight.
    """

    def __init__(self):
        super().__init__()
        # Both sub-losses take raw logits and normalise them internally.
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, preds, targets):
        """Return ``BCEWithLogitsLoss(preds, targets) + DiceLoss(preds, targets)``."""
        return self.bce(preds, targets) + self.dice(preds, targets)