layers.py
import torch
from torch import nn
from torch.nn import functional as F

# Flatten layer
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

# Reshape layer
class Reshape(nn.Module):
    def __init__(self, outer_shape):
        super(Reshape, self).__init__()
        self.outer_shape = outer_shape

    def forward(self, x):
        return x.view(x.size(0), *self.outer_shape)

# Sample from the Gumbel-Softmax distribution and optionally discretize.
class GumbelSoftmax(nn.Module):
    def __init__(self, in_dim, num_class, class_dim):
        super(GumbelSoftmax, self).__init__()
        self.logits = nn.Linear(in_dim, num_class * class_dim)
        self.in_dim = in_dim
        self.num_class = num_class
        self.class_dim = class_dim

    def sample_gumbel(self, shape, is_cuda=False, eps=1e-20):
        # Draw Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1)
        U = torch.rand(shape)
        if is_cuda:
            U = U.cuda()
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        # Add Gumbel noise to the logits and apply a temperature-scaled softmax
        y = logits + self.sample_gumbel(logits.size(), logits.is_cuda)
        return F.softmax(y / temperature, dim=-1)

    def gumbel_softmax(self, logits, temperature, hard=False):
        """
        Straight-through (ST) Gumbel-Softmax.
        input:  [*, n_class] logits
        return: [*, n_class] relaxed sample; a one-hot vector if hard=True
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        if not hard:
            return y
        # Discretize by taking the argmax of the relaxed sample
        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        # Straight-through estimator: forward pass uses y_hard, gradients flow through y
        y_hard = (y_hard - y).detach() + y
        return y_hard

    def forward(self, x, temperature=1.0, hard=False):
        logits_y = self.logits(x)                     # [batch, class_dim * num_class]
        logits_y = logits_y.view(-1, self.num_class)  # [batch * class_dim, num_class]
        q_y = F.softmax(logits_y, dim=-1)
        log_q_y = torch.log(q_y + 1e-20)
        y = self.gumbel_softmax(logits_y, temperature, hard)  # [batch * class_dim, num_class]
        return y, log_q_y, q_y
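
# Usage sketch (illustrative only; the dimensions below are assumptions, not
# values taken from this repository): with in_dim=64, num_class=5 categories and
# class_dim=10 categorical variables, the layer maps a [batch, 64] input to
# samples and posteriors of shape [batch * 10, 5].
#
#     gumbel = GumbelSoftmax(in_dim=64, num_class=5, class_dim=10)
#     x = torch.randn(8, 64)
#     y, log_q_y, q_y = gumbel(x, temperature=0.5, hard=True)  # each: [80, 5]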

# Sample from a Gaussian distribution
class Gaussian(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Gaussian, self).__init__()
        self.mu = nn.Linear(in_dim, out_dim)
        self.var = nn.Linear(in_dim, out_dim)

    def reparameterize(self, mu, var):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        std = torch.sqrt(var + 1e-10)
        noise = torch.randn_like(std)
        z = mu + noise * std
        return z

    def forward(self, x):
        mu = self.mu(x)
        var = F.softplus(self.var(x))  # softplus keeps the variance positive
        z = self.reparameterize(mu, var)
        return mu, var, z
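
# Minimal smoke test, a sketch added for illustration; the dimensions below are
# assumptions rather than values from the original code.
if __name__ == "__main__":
    batch, in_dim, out_dim = 4, 16, 8
    num_class, class_dim = 5, 3

    x_img = torch.randn(batch, 1, 4, 4)
    flat = Flatten()(x_img)            # [4, 16]
    back = Reshape((1, 4, 4))(flat)    # [4, 1, 4, 4]

    gumbel = GumbelSoftmax(in_dim, num_class, class_dim)
    y, log_q_y, q_y = gumbel(torch.randn(batch, in_dim), temperature=1.0, hard=True)
    # y contains one-hot rows of shape [batch * class_dim, num_class] = [12, 5]

    mu, var, z = Gaussian(in_dim, out_dim)(torch.randn(batch, in_dim))  # each: [4, 8]
    print(flat.shape, back.shape, y.shape, mu.shape, var.shape, z.shape)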