-
Notifications
You must be signed in to change notification settings - Fork 9
/
grbm.py
362 lines (311 loc) · 14.3 KB
/
grbm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import cosine_schedule
class GRBM(nn.Module):
    """ Gaussian-Bernoulli Restricted Boltzmann Machines (GRBM)

    Real-valued visible units v (``visible_size`` = D) and binary hidden units
    h (``hidden_size`` = H) with energy

        E(v, h) = 0.5 * sum_i (v_i - mu_i)**2 / var_i
                  - (v / var)^T W h - b^T h,          var = exp(log_var).

    All gradients are computed in closed form under ``torch.no_grad``;
    ``CD_grad`` writes them into ``param.grad`` so a torch optimizer can apply
    the update.

    NOTE(review): the energy / energy-gradient helpers divide by the batch
    size B, and the Langevin samplers and Metropolis-Hastings corrections use
    those scaled values consistently, so the chains appear to target
    exp(-E / B) rather than exp(-E). Confirm this tempering is intended.
    """

    def __init__(self,
                 visible_size,
                 hidden_size,
                 CD_step=1,
                 CD_burnin=0,
                 init_var=1e-0,
                 inference_method='Gibbs',
                 Langevin_step=10,
                 Langevin_eta=1.0,
                 is_anneal_Langevin=True,
                 Langevin_adjust_step=0) -> None:
        """
        Args:
            visible_size: number of Gaussian visible units D.
            hidden_size: number of Bernoulli hidden units H.
            CD_step: number of negative-phase sampling steps.
            CD_burnin: negative-phase samples from steps [CD_burnin, CD_step)
                are used for the gradient.
            init_var: initial visible variance (also scales the init std of W).
            inference_method: one of 'Gibbs', 'Langevin', 'Gibbs-Langevin'.
            Langevin_step: Langevin steps per Gibbs-Langevin sweep.
            Langevin_eta: Langevin step size; multiplied by the mean visible
                variance at sampling time.
            is_anneal_Langevin: if True, step sizes come from
                ``cosine_schedule(eta_max=eta, T=...)``; otherwise constant.
            Langevin_adjust_step: the Metropolis-Hastings correction is applied
                from this step index onward (0 = always).
        """
        super().__init__()
        # we use samples in [CD_burnin, CD_step) steps
        assert CD_burnin >= 0 and CD_burnin <= CD_step
        assert inference_method in ['Gibbs', 'Langevin', 'Gibbs-Langevin']
        self.visible_size = visible_size
        self.hidden_size = hidden_size
        self.CD_step = CD_step
        self.CD_burnin = CD_burnin
        self.init_var = init_var
        self.inference_method = inference_method
        self.Langevin_step = Langevin_step
        self.Langevin_eta = Langevin_eta
        self.is_anneal_Langevin = is_anneal_Langevin
        self.Langevin_adjust_step = Langevin_adjust_step
        self.W = nn.Parameter(torch.Tensor(visible_size, hidden_size))
        self.b = nn.Parameter(torch.Tensor(hidden_size))
        self.mu = nn.Parameter(torch.Tensor(visible_size))
        self.log_var = nn.Parameter(torch.Tensor(visible_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise W ~ N(0, std**2) with std = init_var / sqrt(D + H),
        b = mu = 0 and log_var = log(init_var)."""
        nn.init.normal_(self.W,
                        std=1.0 * self.init_var /
                        np.sqrt(self.visible_size + self.hidden_size))
        nn.init.constant_(self.b, 0.0)
        nn.init.constant_(self.mu, 0.0)
        # initial per-dimension variance equals init_var
        nn.init.constant_(self.log_var, np.log(self.init_var))

    def get_var(self):
        """Per-dimension visible variance exp(log_var), floored at 1e-8."""
        return self.log_var.exp().clip(min=1e-8)

    def set_Langevin_eta(self, eta):
        self.Langevin_eta = eta

    def set_Langevin_adjust_step(self, step):
        self.Langevin_adjust_step = step

    @torch.no_grad()
    def energy(self, v, h):
        """Joint energy E(v, h) of each sample divided by the batch size B.

        v: (B, D) visible batch; h: (B, H) hidden batch. Returns shape (B,).
        """
        B = v.shape[0]
        var = self.get_var()
        eng = 0.5 * ((v - self.mu)**2 / var).sum(dim=1)
        eng -= ((v / var).mm(self.W) * h).sum(dim=1) + h.mv(self.b)
        return eng / B

    @torch.no_grad()
    def marginal_energy(self, v):
        """Free energy of v (hiddens summed out) divided by the batch size B.

        Returns shape (B,).
        """
        B = v.shape[0]
        var = self.get_var()
        eng = 0.5 * ((v - self.mu)**2 / var).sum(dim=1)
        eng -= F.softplus((v / var).mm(self.W) + self.b).sum(dim=1)
        return eng / B

    @torch.no_grad()
    def energy_grad_v(self, v, h):
        """dE(v, h)/dv per sample, divided by B. Returns shape (B, D)."""
        B = v.shape[0]
        var = self.get_var()
        return ((v - self.mu) / var - h.mm(self.W.T) / var) / B

    @torch.no_grad()
    def marginal_energy_grad_v(self, v):
        """Gradient of the free energy w.r.t. v, divided by B. Shape (B, D)."""
        B = v.shape[0]
        var = self.get_var()
        return ((v - self.mu) / var -
                torch.sigmoid((v / var).mm(self.W) + self.b).mm(self.W.T) / var) / B

    @torch.no_grad()
    def energy_grad_param(self, v, h):
        """Batch-averaged gradient of E(v, h) w.r.t. each parameter.

        Returns a dict keyed by parameter name ('W', 'b', 'mu', 'log_var').
        """
        var = self.get_var()
        grad = {}
        grad['W'] = -torch.einsum("bi,bj->ij", v / var, h) / v.shape[0]
        grad['b'] = -h.mean(dim=0)
        grad['mu'] = ((self.mu - v) / var).mean(dim=0)
        grad['log_var'] = (-0.5 * (v - self.mu)**2 / var +
                           ((v / var) * h.mm(self.W.T))).mean(dim=0)
        return grad

    @torch.no_grad()
    def marginal_energy_grad_param(self, v):
        """Batch-averaged gradient of the free energy w.r.t. each parameter.

        Same keys as ``energy_grad_param``; h is replaced by p(h=1 | v).
        """
        var = self.get_var()
        vv = v / var
        tmp = torch.sigmoid(vv.mm(self.W) + self.b)
        grad = {}
        grad['W'] = -torch.einsum("bi,bj->ij", vv, tmp) / v.shape[0]
        grad['b'] = -tmp.mean(dim=0)
        grad['mu'] = ((self.mu - v) / var).mean(dim=0)
        grad['log_var'] = (-0.5 * (v - self.mu)**2 / var +
                           (vv * tmp.mm(self.W.T))).mean(dim=0)
        return grad

    @torch.no_grad()
    def prob_h_given_v(self, v, var):
        """p(h = 1 | v), shape (B, H)."""
        return torch.sigmoid((v / var).mm(self.W) + self.b)

    @torch.no_grad()
    def prob_v_given_h(self, h):
        """Mean of the Gaussian p(v | h): W h + mu, shape (B, D)."""
        return h.mm(self.W.T) + self.mu

    @staticmethod
    def _mh_accept(log_ratio):
        """Return a (B, 1) boolean mask of accepted Metropolis-Hastings moves.

        Each sample is accepted with probability min(1, exp(log_ratio)).
        """
        u = torch.rand(log_ratio.shape[0], device=log_ratio.device)
        ratio = torch.minimum(torch.ones_like(log_ratio), log_ratio.exp())
        return (u < ratio).view(-1, 1)

    @torch.no_grad()
    def log_metropolis_ratio_Gibbs_Langevin(self, v_old, h_old, v_new, h_new, eta_list):
        """ Metropolis-Hasting ratio of accepting the move from old to new state

        The forward move runs len(eta_list) Langevin steps on v with h = h_old
        held fixed, then samples h_new ~ p(h | v_new). ``energy_grad_v`` is
        linear in v, so the K-step Langevin proposal is Gaussian and its
        log-density is evaluated in closed form (normalisers cancel).

        Args:
            eta_list: the step sizes that the sampler actually used.

        Returns:
            Shape (B,): log[pi(new) q(old | new)] - log[pi(old) q(new | old)].
        """
        B = v_old.shape[0]
        var = self.get_var()
        eng_diff = -self.energy(v_new, h_new) + self.energy(v_old, h_old)
        # Hidden part: the reverse move samples h_old | v_old and the forward
        # move samples h_new | v_new.
        logit_h_new = (v_new / var).mm(self.W) + self.b
        logit_h_old = (v_old / var).mm(self.W) + self.b
        log_q_h_reverse = -F.binary_cross_entropy_with_logits(
            logit_h_old, h_old, reduction='none').sum(dim=1)
        log_q_h_forward = -F.binary_cross_entropy_with_logits(
            logit_h_new, h_new, reduction='none').sum(dim=1)
        # Visible part: each step maps v -> beta_in_k * v + const + noise with
        # beta_in_k = 1 - eta_k / (B * var), so after K steps
        # v_K = beta[0] * v_0 + va * (mu + W h) + Gaussian noise.
        eta = torch.as_tensor(eta_list, dtype=var.dtype, device=var.device)  # shape (K,)
        beta_in = 1.0 - eta.unsqueeze(1) / (B * var.unsqueeze(0))  # shape (K, D)
        # beta[k] = prod_{j >= k} beta_in[j], padded so that beta[K] = 1
        beta = torch.flip(torch.cumprod(torch.flip(beta_in, [0]), 0), [0])
        beta = F.pad(beta, [0, 0, 0, 1], "constant", 1.0)  # shape (K + 1, D)
        va = (beta[1:] * eta.view(-1, 1)).sum(dim=0) / (B * var)  # shape (D,)
        # noise variance of v_K is 2 * tilde_sigma_sqrt**2
        tilde_sigma_sqrt = ((beta[1:]**2 * eta.view(-1, 1)).sum(dim=0)).sqrt()  # shape (D,)
        log_q_v_reverse = -torch.pow((v_old - beta[0] * v_new - va * (
            self.mu + h_new.mm(self.W.T))) / (2 * tilde_sigma_sqrt), 2.0).sum(dim=1)
        log_q_v_forward = -torch.pow((v_new - beta[0] * v_old - va * (
            self.mu + h_old.mm(self.W.T))) / (2 * tilde_sigma_sqrt), 2.0).sum(dim=1)
        return (eng_diff + log_q_v_reverse - log_q_v_forward +
                log_q_h_reverse - log_q_h_forward)

    @torch.no_grad()
    def log_metropolis_ratio_Langevin_one_step(self, v_old, v_new, grad_old, eta):
        """ Metropolis-Hasting ratio of accepting the move from old to new state

        For one Langevin step on the marginal energy with step size ``eta``
        (MALA); ``grad_old`` is the marginal-energy gradient at ``v_old``.
        Returns shape (B,).
        """
        eng_diff = -self.marginal_energy(v_new) + self.marginal_energy(v_old)
        proposal_eng_new = -torch.pow(
            v_old - v_new + eta * self.marginal_energy_grad_v(v_new),
            2.0).sum(dim=1) / (4 * eta)
        proposal_eng_old = -torch.pow(
            v_new - v_old + eta * grad_old, 2.0).sum(dim=1) / (4 * eta)
        return eng_diff + proposal_eng_new - proposal_eng_old

    @torch.no_grad()
    def Gibbs_sampling_vh(self, v, num_steps=10, burn_in=0):
        """Block Gibbs sampling; returns [(v, h), ...] for steps >= burn_in."""
        samples, var = [], self.get_var()
        std = var.sqrt()
        h = torch.bernoulli(self.prob_h_given_v(v, var))
        for ii in range(num_steps):
            # backward sampling
            mu = self.prob_v_given_h(h)
            v = mu + torch.randn_like(mu) * std
            # forward sampling
            h = torch.bernoulli(self.prob_h_given_v(v, var))
            if ii >= burn_in:
                samples.append((v, h))
        return samples

    @torch.no_grad()
    def Langevin_sampling_v(self,
                            v,
                            num_steps=10,
                            eta=1.0e+0,
                            burn_in=0,
                            is_anneal=True,
                            adjust_step=0):
        """Langevin sampling on the marginal energy of v.

        A Metropolis-Hastings correction is applied from step ``adjust_step``
        on. Returns the list of v samples for steps >= burn_in.
        """
        # the annealing schedule is only needed (and built) when annealing
        eta_list = (cosine_schedule(eta_max=eta, T=num_steps)
                    if is_anneal else [eta] * num_steps)
        samples = []
        for ii in range(num_steps):
            eta_ii = eta_list[ii]
            grad_v = self.marginal_energy_grad_v(v)
            v_new = v - eta_ii * grad_v + \
                torch.randn_like(v) * np.sqrt(eta_ii * 2)
            if ii >= adjust_step:
                log_ratio = self.log_metropolis_ratio_Langevin_one_step(
                    v, v_new, grad_v, eta_ii)
                # torch.where avoids NaN/inf leaking in from the rejected branch
                v = torch.where(self._mh_accept(log_ratio), v_new, v)
            else:
                v = v_new
            if ii >= burn_in:
                samples.append(v)
        return samples

    @torch.no_grad()
    def Gibbs_Langevin_sampling_vh(self,
                                   v,
                                   num_steps=10,
                                   num_steps_Langevin=10,
                                   eta=1.0e+0,
                                   burn_in=0,
                                   is_anneal=True,
                                   adjust_step=0):
        """Gibbs sampling where v | h is drawn by ``num_steps_Langevin``
        Langevin steps, with an optional Metropolis-Hastings correction.

        Returns [(v, h), ...] for sweeps >= burn_in.
        """
        var = self.get_var()
        # The same step sizes must drive the Langevin proposal AND its MH
        # density; otherwise (is_anneal=False) acceptance would be computed
        # for a proposal that was never made.
        eta_list = (cosine_schedule(eta_max=eta, T=num_steps_Langevin)
                    if is_anneal else [eta] * num_steps_Langevin)
        samples = []
        h = torch.bernoulli(self.prob_h_given_v(v, var))
        for ii in range(num_steps):
            v_old, h_old = v, h
            # backward sampling: Langevin steps on E(v, h_old)
            for jj in range(num_steps_Langevin):
                eta_jj = eta_list[jj]
                grad_v = self.energy_grad_v(v, h_old)
                v = v - eta_jj * grad_v + \
                    torch.randn_like(v) * np.sqrt(eta_jj * 2)
            # forward sampling
            h = torch.bernoulli(self.prob_h_given_v(v, var))
            if ii >= adjust_step:
                log_ratio = self.log_metropolis_ratio_Gibbs_Langevin(
                    v_old, h_old, v, h, eta_list)
                accept = self._mh_accept(log_ratio)
                v = torch.where(accept, v, v_old)
                h = torch.where(accept, h, h_old)
            if ii >= burn_in:
                samples.append((v, h))
        return samples

    @torch.no_grad()
    def reconstruction(self, v):
        """MSE between v and W p(h | v) + mu (v is flattened per sample)."""
        v, var = v.view(v.shape[0], -1), self.get_var()
        prob_h = self.prob_h_given_v(v, var)
        v_bar = self.prob_v_given_h(prob_h)
        return F.mse_loss(v, v_bar)

    @torch.no_grad()
    def sampling(self, v_init, num_steps=1, save_gap=1):
        """Run the configured sampler from ``v_init`` for ``num_steps`` steps.

        Returns a list of (step, tensor) pairs shaped like ``v_init``:
        (0, v_init), every ``save_gap``-th intermediate sample, and finally
        (num_steps, conditional mean of p(v | h)).
        """
        v_shape = v_init.shape
        v = v_init.view(v_shape[0], -1)
        var = self.get_var()
        var_mean = var.mean().item()
        if self.inference_method == 'Gibbs':
            samples = self.Gibbs_sampling_vh(v, num_steps=num_steps - 1)
            samples = [xx[0] for xx in samples]  # extract v
        elif self.inference_method == 'Langevin':
            samples = self.Langevin_sampling_v(v,
                                               num_steps=num_steps - 1,
                                               eta=self.Langevin_eta * var_mean,
                                               is_anneal=self.is_anneal_Langevin,
                                               adjust_step=self.Langevin_adjust_step)
        elif self.inference_method == 'Gibbs-Langevin':
            samples = self.Gibbs_Langevin_sampling_vh(
                v,
                num_steps=num_steps - 1,
                num_steps_Langevin=self.Langevin_step,
                eta=self.Langevin_eta * var_mean,
                is_anneal=self.is_anneal_Langevin,
                adjust_step=self.Langevin_adjust_step)
            samples = [xx[0] for xx in samples]  # extract v
        else:
            raise ValueError(
                f"unknown inference_method: {self.inference_method!r}")
        # With num_steps == 1 (the default) the chain yields no samples, so
        # the final step is conditioned on the initial state.
        last_v = samples[-1] if samples else v
        # use conditional mean as the last sample
        h = torch.bernoulli(self.prob_h_given_v(last_v, var))
        mu = self.prob_v_given_h(h)
        v_list = [(0, v_init)] + [(ii + 1, samples[ii].view(v_shape).detach())
                                  for ii in range(num_steps - 1)
                                  if (ii + 1) % save_gap == 0
                                  ] + [(num_steps, mu.view(v_shape).detach())]
        return v_list

    @torch.no_grad()
    def positive_grad(self, v):
        """Energy gradient at the data v with h ~ p(h | v)."""
        h = torch.bernoulli(self.prob_h_given_v(v, self.get_var()))
        grad = self.energy_grad_param(v, h)
        return grad

    @torch.no_grad()
    def negative_grad(self, v):
        """Energy gradient averaged over negative samples from a chain
        started at v, using samples from steps [CD_burnin, CD_step)."""
        var = self.get_var()
        var_mean = var.mean().item()
        if self.inference_method == 'Gibbs':
            samples = self.Gibbs_sampling_vh(v,
                                             num_steps=self.CD_step,
                                             burn_in=self.CD_burnin)
            v_neg = torch.cat([xx[0] for xx in samples], dim=0)
            h_neg = torch.cat([xx[1] for xx in samples], dim=0)
            grad = self.energy_grad_param(v_neg, h_neg)
        elif self.inference_method == 'Langevin':
            samples = self.Langevin_sampling_v(v,
                                               num_steps=self.CD_step,
                                               burn_in=self.CD_burnin,
                                               eta=self.Langevin_eta * var_mean,
                                               is_anneal=self.is_anneal_Langevin,
                                               adjust_step=self.Langevin_adjust_step)
            v_neg = torch.cat(samples, dim=0)
            # Langevin only samples v, so use the marginal (free-energy) gradient
            grad = self.marginal_energy_grad_param(v_neg)
        elif self.inference_method == 'Gibbs-Langevin':
            samples = self.Gibbs_Langevin_sampling_vh(
                v,
                num_steps=self.CD_step,
                burn_in=self.CD_burnin,
                num_steps_Langevin=self.Langevin_step,
                eta=self.Langevin_eta * var_mean,
                is_anneal=self.is_anneal_Langevin,
                adjust_step=self.Langevin_adjust_step)
            v_neg = torch.cat([xx[0] for xx in samples], dim=0)
            h_neg = torch.cat([xx[1] for xx in samples], dim=0)
            grad = self.energy_grad_param(v_neg, h_neg)
        return grad

    @torch.no_grad()
    def CD_grad(self, v):
        """Write (positive - negative) energy gradients into ``param.grad``.

        The positive phase uses the data v (flattened per sample); the
        negative chain starts from standard Gaussian noise, not from the data.
        """
        v = v.view(v.shape[0], -1)
        # postive gradient
        grad_pos = self.positive_grad(v)
        # negative gradient
        v_neg = torch.randn_like(v)
        grad_neg = self.negative_grad(v_neg)
        # compute update
        for name, param in self.named_parameters():
            param.grad = grad_pos[name] - grad_neg[name]