-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathlosses_and_metrics.py
41 lines (35 loc) · 1.42 KB
/
losses_and_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import constants as C
def reshape_targs(targs, mask_val=C.BATCH_PAD_VAL):
    """
    Flatten padded targets into 2D and drop the padding rows.

    Args:
        - targs: tensor whose last dimension holds the target columns; all
          leading dimensions are collapsed into a single row dimension.
        - mask_val: value that marks a padding row; only the first column
          of each row is compared against it.
    Returns:
        2D tensor of shape (n_rows_kept, targs.size(-1)) containing every
        row whose first column is not equal to `mask_val`, in order.
    """
    # reshape (rather than view) also accepts non-contiguous inputs, e.g.
    # targets that were transposed or sliced upstream; it still returns a
    # view without copying whenever the memory layout allows it.
    flat = targs.reshape(-1, targs.size(-1))
    return flat[flat[:, 0] != mask_val]
def group_mean_log_mae(y_true, y_pred, types, sc_mean=0, sc_std=1, floor=1e-9):
    """
    Mean over groups of the log of each group's mean absolute error.

    Args:
        - y_true: true values (torch tensor, numpy array, list or Series).
        - y_pred: predicted values, same length as `y_true` once flattened.
        - types: group label for each element, same length once flattened.
        - sc_mean, sc_std: both `y_true` and `y_pred` are mapped through
          `sc_mean + x * sc_std` before comparison (presumably undoing a
          standardization applied to the targets).
        - floor: lower bound applied to each group's MAE before taking the
          log, so a perfectly predicted group yields log(floor) instead of
          -inf. Pass 0 to reproduce the unclipped log.
    Returns:
        Scalar: the mean over groups of log(max(group MAE, floor)).
    """
    def proc(x):
        # Detach so tensors that require grad can be converted to numpy, and
        # convert non-tensor inputs too (they were previously turned into
        # None). Converting to a plain array also prevents pandas from
        # aligning a Series argument on its index during groupby.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        return np.asarray(x).ravel()

    y_true, y_pred, types = proc(y_true), proc(y_pred), proc(types)
    y_true = sc_mean + y_true * sc_std
    y_pred = sc_mean + y_pred * sc_std
    maes = pd.Series(y_true - y_pred).abs().groupby(types).mean()
    return np.log(maes.clip(lower=floor)).mean()
def contribs_rmse_loss(preds, targs):
    """
    Sum of per-column RMSEs between predictions and (unpadded) targets.

    Every column (the scalar coupling (sc) contributions and, in the last
    column, the sc constant) gets its own RMSE computed over all rows of
    the batch; those RMSEs are then added together.

    Args:
        - preds: tensor of shape (n_sc_batch, 5) containing predictions. Last
          column is the scalar coupling constant.
        - targs: tensor of shape (batch_size, max_n_sc_per_molecule, 5)
          containing true values. Last column is the scalar coupling constant.
          Padding rows are removed by `reshape_targs` before comparison.
    """
    flat_targs = reshape_targs(targs)
    squared_err = (preds - flat_targs).pow(2)
    per_column_rmse = squared_err.mean(dim=0).sqrt()
    return per_column_rmse.sum()
def rmse(preds, targs):
    """
    Root-mean-squared error on the last column only.

    `targs` is flattened and stripped of padding rows by `reshape_targs`
    before its last column is compared with the last column of `preds`.
    """
    flat_targs = reshape_targs(targs)
    pred_last, true_last = preds[:, -1], flat_targs[:, -1]
    return F.mse_loss(pred_last, true_last).sqrt()
def mae(preds, targs):
    """
    Mean absolute error on the last column only.

    `targs` is flattened and stripped of padding rows by `reshape_targs`
    before its last column is compared with the last column of `preds`.
    """
    flat_targs = reshape_targs(targs)
    abs_err = (preds[:, -1] - flat_targs[:, -1]).abs()
    return abs_err.mean()