import torch
import numpy as np
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss


class Fast_MMD(torch.nn.Module):
    """Fast MMD estimate between two samples using random Fourier features."""

    def __init__(self, gamma, features_out) -> None:
        super().__init__()
        self.gamma = gamma
        self.features_out = features_out

    def forward(self, z1, z2):
        features_in = z1.shape[-1]
        # Draw a fresh random projection and phase on the same device as the inputs.
        w = torch.randn((features_in, self.features_out), device=z1.device)
        b = torch.zeros((self.features_out,), device=z1.device).uniform_(0, 2 * np.pi)
        psi_z1 = self.psi(z1, w, b).mean(dim=0)
        psi_z2 = self.psi(z2, w, b).mean(dim=0)
        return torch.norm(psi_z1 - psi_z2, 2)

    def psi(self, x, w, b):
        # Random Fourier feature map: sqrt(2/D) * cos(sqrt(2/gamma) * xW + b),
        # with the uniform phase b added after the frequency scaling.
        return np.sqrt(2 / self.features_out) * (np.sqrt(2 / self.gamma) * (x @ w) + b).cos()


class VFAE_loss(torch.nn.Module):
    """VFAE objective: supervised CE + reconstruction BCE + KL terms + MMD penalty."""

    def __init__(self, alpha, beta, gamma, dims_out):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.ce_loss = CrossEntropyLoss()
        self.bce_loss = BCEWithLogitsLoss()
        self.mmd = Fast_MMD(gamma, dims_out)

    def forward(self, y_true, y_pred):
        x, s, y = y_true['x'], y_true['s'], y_true['y']
        x_s = torch.cat([x, s], dim=-1)
        y = y.long().reshape(y.size(0))
        supervised_loss = self.ce_loss(y_pred['y_pred'], y)
        reconstruction_loss = self.bce_loss(y_pred['x_pred'], x_s)
        zeros = torch.zeros_like(y_pred['z1_enc_logvar'])
        # KL between the z1 encoder distribution and the z1 decoder (conditional prior) distribution.
        kl_loss_z1 = self.kl_divergence(y_pred['z1_enc_logvar'], y_pred['z1_dec_logvar'],
                                        y_pred['z1_enc_mean'], y_pred['z1_dec_mean'])
        # KL between the z2 encoder distribution and the standard normal prior
        # (zero mean, zero log-variance).
        kl_loss_z2 = self.kl_divergence(y_pred['z2_enc_logvar'], zeros,
                                        y_pred['z2_enc_mean'], zeros)
        vfae_loss = self.alpha * supervised_loss + kl_loss_z1 + kl_loss_z2 + reconstruction_loss
        # MMD penalty between latent representations of the two sensitive groups.
        z1_encoded = y_pred['z1_enc']
        z1_sensitive, z1_nonsensitive = self.separate_sensitive(z1_encoded, s)
        if len(z1_sensitive) > 0 and len(z1_nonsensitive) > 0:
            vfae_loss += self.beta * self.mmd(z1_sensitive, z1_nonsensitive)
        return vfae_loss

    @staticmethod
    def kl_divergence(logvar_z1, logvar_z2, mu_z1, mu_z2):
        # KL(N(mu_z1, exp(logvar_z1)) || N(mu_z2, exp(logvar_z2))), summed over
        # latent dimensions and averaged over the batch.
        per_example_kl = logvar_z2 - logvar_z1 - 1 + (logvar_z1.exp() + (mu_z1 - mu_z2).square()) / logvar_z2.exp()
        kl = 0.5 * torch.sum(per_example_kl, dim=1)
        return kl.mean()

    @staticmethod
    def separate_sensitive(variables, s):
        # Split rows by the binary sensitive attribute s (1 = sensitive group, 0 = non-sensitive).
        sensitive_ix = (s == 1).nonzero()[:, 0]
        nonsensitive_ix = (s == 0).nonzero()[:, 0]
        sensitive = variables[sensitive_ix]
        nonsensitive = variables[nonsensitive_ix]
        return sensitive, nonsensitive
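

# A minimal usage sketch, kept behind __main__ so importing this module stays
# side-effect free. The tensor shapes and dictionary keys below (x_dim, z_dim,
# the random dummy tensors, etc.) are assumptions made only for illustration,
# not requirements imposed elsewhere in this file.
if __name__ == "__main__":
    n, x_dim, s_dim, z_dim, n_classes = 8, 10, 1, 4, 2
    y_true = {
        'x': torch.rand(n, x_dim),                     # inputs in [0, 1] for the BCE target
        's': torch.randint(0, 2, (n, s_dim)).float(),  # binary sensitive attribute
        'y': torch.randint(0, n_classes, (n, 1)),      # class labels
    }
    y_pred = {
        'y_pred': torch.randn(n, n_classes),           # classifier logits
        'x_pred': torch.randn(n, x_dim + s_dim),       # reconstruction logits for [x, s]
        'z1_enc': torch.randn(n, z_dim),
        'z1_enc_mean': torch.randn(n, z_dim),
        'z1_enc_logvar': torch.randn(n, z_dim),
        'z1_dec_mean': torch.randn(n, z_dim),
        'z1_dec_logvar': torch.randn(n, z_dim),
        'z2_enc_mean': torch.randn(n, z_dim),
        'z2_enc_logvar': torch.randn(n, z_dim),
    }
    criterion = VFAE_loss(alpha=1.0, beta=1.0, gamma=1.0, dims_out=50)
    print(criterion(y_true, y_pred).item())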