# vae.py — Gaussian VAE trained on the Frey Faces dataset.
import torch
import numpy as np
import torch.nn as nn
from scipy.io import loadmat
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torch.distributions import Normal, Independent
from trainer import Trainer
class Encoder(nn.Module):
    """Amortized inference network q(z|x).

    Maps a flattened input to the parameters of a diagonal Gaussian over
    the latent code: a mean vector and a strictly positive spread vector.

    Parameters
    ----------
    x_dim : flattened input dimension.
    z_dim : latent dimension.
    hidden_dim : width of the single hidden layer.
    """

    def __init__(self, x_dim, z_dim, hidden_dim):
        super().__init__()
        # One hidden layer; the output is split in half into mean / spread.
        self.net = nn.Sequential(
            nn.Linear(x_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * z_dim),
        )
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim

    def forward(self, x):
        """Return (means, scales), each of shape (batch, z_dim).

        softplus keeps the second half strictly positive. NOTE(review):
        downstream code feeds this value to Normal(mean, scale), i.e. it is
        consumed as a standard deviation; the original named it `vars`,
        which shadowed the builtin and suggested a variance.
        """
        h = self.net(x)
        means = h[:, : self.z_dim]
        scales = F.softplus(h[:, self.z_dim :])
        return means, scales
class Decoder(nn.Module):
    """Generative network p(x|z).

    Maps a latent code to the parameters of a diagonal Gaussian over the
    flattened pixel space: per-pixel means in (0, 1) and positive spreads.

    Parameters
    ----------
    x_dim : flattened output (pixel) dimension.
    z_dim : latent dimension.
    hidden_dim : width of the single hidden layer.
    """

    def __init__(self, x_dim, z_dim, hidden_dim):
        super().__init__()
        # One hidden layer; the output is split in half into mean / spread.
        self.net = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * x_dim),
        )
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim

    def forward(self, z):
        """Return (means, scales), each of shape (batch, x_dim).

        Means are squashed into (0, 1) with sigmoid (the training script
        rescales pixels to [0, 1]); softplus keeps the scales positive.
        Fixes: `F.sigmoid` is deprecated in favor of `torch.sigmoid`, and
        the original named the scales `vars`, shadowing the builtin.
        """
        h = self.net(z)
        means = torch.sigmoid(h[:, : self.x_dim])
        scales = F.softplus(h[:, self.x_dim :])
        return means, scales
class GaussianVAE(nn.Module):
    """Gaussian VAE: diagonal-Gaussian encoder q(z|x) and decoder p(x|z)
    with a standard-normal prior p(z), trained on the analytic-KL ELBO.

    Parameters
    ----------
    x_dim : flattened input dimension (the training script uses 28 * 20).
    z_dim : latent dimension.
    hidden_dim : hidden-layer width for both networks.
    device : torch device string the model and prior live on.
    """

    def __init__(self, x_dim, z_dim, hidden_dim, device="cuda"):
        super().__init__()
        self.encoder = Encoder(x_dim, z_dim, hidden_dim).to(device)
        self.decoder = Decoder(x_dim, z_dim, hidden_dim).to(device)
        # Fixed standard-normal prior over the latent code.
        self.p_z = Normal(torch.zeros(z_dim).to(device), torch.ones(z_dim).to(device))
        self.z_dim = z_dim
        self.optimizer = Adam(self.parameters(), lr=0.001)
        self.device = device

    def train_step(self, real_samples):
        """Run one optimizer step on a batch; return the scalar loss."""
        self.optimizer.zero_grad()
        # Encode to q(z|x); the encoder's second output is consumed as the
        # Normal's scale (std). reinterpreted_batch_ndims=1 sums log-probs
        # over the latent axis (the original passed -1, which happened to
        # sum over the same axis for the operations used here).
        mean_enc, scale_enc = self.encoder(real_samples)
        q_z_given_x = Independent(Normal(mean_enc, scale_enc), 1)
        # Reparameterized sample, then decode to p(x|z).
        z_samples = q_z_given_x.rsample()
        mean_dec, scale_dec = self.decoder(z_samples)
        p_x_given_z = Independent(Normal(mean_dec, scale_dec), 1)
        loss = self.loss_fn(real_samples, q_z_given_x, p_x_given_z)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def loss_fn(self, real_samples, q_z_given_x, p_x_given_z):
        """Negative ELBO: reconstruction NLL + analytic KL(q(z|x) || N(0, I)).

        The KL uses the closed form -0.5 * sum(1 + log s^2 - m^2 - s^2)
        with `.variance` (= scale ** 2) of the posterior; both terms are
        averaged over the batch.
        """
        kl_loss = -(
            0.5
            * (1 + torch.log(q_z_given_x.variance) - q_z_given_x.mean ** 2 - q_z_given_x.variance).sum(-1)
        ).mean()
        cross_entropy_loss = -p_x_given_z.log_prob(real_samples).mean()
        return cross_entropy_loss + kl_loss

    def forward(self, num_samples):
        """Draw `num_samples` latents from the prior; return decoder means."""
        # Fix: the original moved samples to a hard-coded "cuda" device,
        # breaking CPU use; self.device is where the networks actually live.
        z_samples = self.p_z.rsample([num_samples]).to(self.device)
        mean, _ = self.decoder(z_samples)
        return mean

    # Fix: the original had a bare `torch.no_grad()` statement here, which
    # builds and discards the context manager — a no-op; the decorator is
    # what actually disables gradient tracking during sampling.
    @torch.no_grad()
    def sample_images(self, num_samples=16):
        """Sample images and return a uint8 HxWxC grid for display.

        num_samples must be a perfect square so the grid is square. Images
        are reshaped to 1x28x20 — the Frey-face geometry, hard-coded to
        match x_dim = 28 * 20 in the training script.
        """
        # Fix: `assert` is stripped under -O; validate explicitly.
        if num_samples != int(np.sqrt(num_samples)) ** 2:
            raise ValueError("num_samples must be a perfect square")
        display_images = self(num_samples).view(num_samples, 1, 28, 20)
        display_images = make_grid(display_images, nrow=int(np.sqrt(num_samples)))
        return (
            (255 * torch.clip(display_images, 0, 1))
            .detach()
            .cpu()
            .numpy()
            .transpose(1, 2, 0)
            .astype(np.uint8)
        )
if __name__ == "__main__":
    # --- run configuration -------------------------------------------------
    mat_path = "./data/frey_rawface.mat"
    out_dir = "./training-runs"
    run_name = "vae-frey_faces"
    device = "cuda"
    batch_size = 100

    # Frey faces live under key "ff" — presumably (pixels, images); the
    # transpose yields one flattened image per row, and dividing by 255
    # rescales uint8 pixels to [0, 1] before moving the tensor to `device`.
    faces = loadmat(mat_path)["ff"].T / 255.0
    dataset = torch.Tensor(faces).to(device)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # --- model -------------------------------------------------------------
    x_dim, z_dim, hidden_dim = 28 * 20, 20, 200  # 28x20 images, 20-d latent
    vae = GaussianVAE(x_dim, z_dim, hidden_dim, device)

    # --- training ----------------------------------------------------------
    total_steps = int(1e7)
    log_every = int(5e5)
    Trainer(vae, dataloader, total_steps, log_every, out_dir, run_name).fit()