forked from kwea123/ngp_pl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
40 lines (31 loc) · 1.06 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from torch import nn
def _shift_scale_normalize(disp, eps):
    """Shift `disp` by its median and divide by its mean absolute deviation."""
    # torch.median returns the lower of the two middle values for an
    # even number of elements.
    t = torch.median(disp)
    # Clamp the scale so a constant input (zero deviation) does not
    # produce 0/0 = NaN; for any non-degenerate input this is a no-op.
    s = torch.mean(torch.abs(disp - t)).clamp_min(eps)
    return (disp - t) / s


def shiftscale_inv_depthloss(disp_pred, disp_gt, *, eps=1e-8):
    """
    Computes the shift- and scale-invariant depth loss as proposed in
    https://arxiv.org/pdf/1907.01341.pdf.
    Each input is shifted by its median and scaled by its mean absolute
    deviation from that median before the two are compared.
    Inputs:
        disp_pred: (N) disparity predicted by the network
        disp_gt: (N) disparity produced by an image-based method
        eps: lower bound on each scale, preventing division by zero when an
            input is constant (keyword-only; default 1e-8)
    Outputs:
        loss: (N) per-element squared difference of the normalized
            disparities (no reduction is applied)
    """
    disp_pred_n = _shift_scale_normalize(disp_pred, eps)
    disp_gt_n = _shift_scale_normalize(disp_gt, eps)
    return (disp_pred_n - disp_gt_n)**2
class NeRFLoss(nn.Module):
    """
    Training loss terms for the NeRF model.
    `forward` returns a dict of un-reduced loss tensors; reducing and
    combining them is left to the caller.
    """

    def __init__(self, lambda_opa=1e-3):
        """
        Inputs:
            lambda_opa: weight applied to the opacity regularizer
                (default 1e-3, the value previously hard-coded).
        """
        super().__init__()
        self.lambda_opa = lambda_opa

    def forward(self, results, target, **kwargs):
        """
        Inputs:
            results: dict holding predicted 'rgb' and 'opacity' tensors
                (opacity presumably per-ray accumulated opacity in [0, 1] —
                confirm against the renderer).
            target: dict holding ground-truth 'rgb', broadcastable against
                results['rgb'].
            kwargs: accepted and ignored, kept for call-site compatibility.
        Outputs:
            d: dict with element-wise 'rgb' (squared error) and 'opacity'
                (weighted regularizer) loss terms, no reduction applied.
        """
        d = {}
        d['rgb'] = (results['rgb'] - target['rgb'])**2
        # Small offset keeps log() finite when opacity is exactly 0.
        o = results['opacity'] + 1e-10
        # -o*log(o) is ~0 at o=0 and o=1 and positive in between, which
        # encourages opacity to be either 0 or 1 to avoid floaters.
        d['opacity'] = self.lambda_opa * (-o * torch.log(o))
        return d