-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathhmc.py
52 lines (37 loc) · 1.6 KB
/
hmc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
def hamiltonian(x, v, model):
    """Return the Hamiltonian (kinetic + potential energy) of each sample.

    The kinetic term is ``0.5 * sum(v ** 2)`` taken over every non-batch
    dimension of ``v``; the potential term is ``model(x).squeeze()``.

    Args:
        x: Input handed to ``model`` unchanged.
        v: Momentum tensor whose dimension 0 is the batch dimension. Any
            number of trailing dimensions (at least one) is accepted.
        model: Callable evaluated as ``model(x)``. Its squeezed output is
            added to the per-sample kinetic energy, so it is presumably one
            energy per sample (e.g. shape ``(B,)`` or ``(B, 1)``).

    Returns:
        Tensor holding ``kinetic + potential`` for each sample.
    """
    # Collapse all non-batch dimensions at once instead of chaining a fixed
    # number of ``sum(dim=1)`` calls, so the momentum may have any rank >= 2.
    kinetic = 0.5 * torch.pow(v, 2).flatten(start_dim=1).sum(dim=1)
    potential = model(x).squeeze()
    return kinetic + potential
def _logit_energy_grad(z, model):
    """Return the gradient of ``model(sigmoid(z)).sum()`` with respect to ``z``."""
    # enable_grad lets the sampler run even when the caller is inside a
    # ``torch.no_grad()`` block; detaching keeps each evaluation independent.
    with torch.enable_grad():
        z = z.detach().requires_grad_(True)
        energy = model(torch.sigmoid(z))
        return torch.autograd.grad([energy.sum()], [z])[0]


def leapfrog_step(x, v, model, step_size, num_steps, label, sample=False):
    """Integrate Hamiltonian dynamics with the leapfrog scheme in logit space.

    ``x`` is mapped to logits ``z = log(x / (1 - x + 1e-10))``; the potential
    energy is ``model(sigmoid(z))``. Each of the ``num_steps`` iterations is a
    half momentum update, a full position update and another half momentum
    update. Consecutive half updates are equivalent to one full update, so
    this is the standard leapfrog integrator.

    Args:
        x: Starting samples; the logit transform assumes values in [0, 1].
            An entry that is exactly 0 maps to ``-inf`` under this transform.
        v: Initial momentum, same shape as ``x``.
        model: Callable returning energies for ``model(sigmoid(z))``; only the
            sum of its output is differentiated.
        step_size: Integrator step size.
        num_steps: Number of leapfrog steps; with 0 the momentum is returned
            unchanged.
        label: Unused; kept for interface compatibility.
        sample: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(x_new, v_new, grad)``: the end position mapped back through
        ``sigmoid``, the end momentum, and the energy gradient with respect to
        the logits at the end position. All three are detached.
    """
    z = torch.log(x / (1 - x + 1e-10)).detach()
    v = v.detach()
    half_step = 0.5 * step_size

    grad = _logit_energy_grad(z, model)
    for _ in range(num_steps):
        v = v - half_step * grad
        z = z + step_size * v
        # The gradient at the new position closes this step and is reused to
        # open the next one, so there is one model evaluation per step.
        grad = _logit_energy_grad(z, model)
        v = v - half_step * grad

    return torch.sigmoid(z), v, grad
def gen_hmc_image(im_neg, step_size, temperature, model_fn, num_steps=10, sample=False):
    """Run one Hamiltonian Monte Carlo transition on a batch of samples.

    A momentum is drawn as ``0.1 * randn_like(im_neg)``, a proposal is made
    with ``leapfrog_step``, and a per-sample Metropolis test on the change in
    Hamiltonian decides whether the proposal or the original sample is kept.

    Args:
        im_neg: Current batch of samples (batch along dimension 0).
        step_size: Leapfrog step size forwarded to ``leapfrog_step``.
        temperature: Unused; kept for interface compatibility.
        model_fn: Energy callable forwarded to ``leapfrog_step`` and
            ``hamiltonian``.
        num_steps: Number of leapfrog steps per proposal.
        sample: Unused; kept for interface compatibility.

    Returns:
        Tensor shaped like ``im_neg`` holding the proposal where it was
        accepted and the original sample where it was rejected.
    """
    v = 0.1 * torch.randn_like(im_neg)

    im_neg_new, v_new, _ = leapfrog_step(im_neg, v, model_fn, step_size, num_steps, None)

    orig = hamiltonian(im_neg, v, model_fn)
    new = hamiltonian(im_neg_new, v_new, model_fn)

    # Metropolis rule: accept with probability min(1, exp(H_old - H_new)).
    accept = torch.exp(orig - new) > torch.rand(new.size(0)).to(im_neg.device)

    # Only rejected proposals fall back to the previous sample; previously the
    # accepted ones were overwritten, which inverted the acceptance test.
    reject = ~accept
    im_neg_new[reject] = im_neg[reject]

    return im_neg_new