forked from rosinality/vq-vae-2-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample.py
executable file
·100 lines (76 loc) · 2.71 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import os
import torch
from torchvision.utils import save_image
from tqdm import tqdm
from vqvae import VQVAE
from pixelsnail import PixelSNAIL
@torch.no_grad()
def sample_model(model, device, batch, size, temperature, condition=None):
    """Autoregressively sample a grid of discrete codes from a prior model.

    Positions are filled one at a time in raster order (row ``i``, then
    column ``j``). At each step the model sees rows ``0..i`` of the grid
    sampled so far and a code is drawn from its temperature-scaled
    distribution for position ``(i, j)``.

    Args:
        model: Callable returning ``(logits, cache)``. ``logits`` is indexed as
            ``logits[:, :, i, j]`` with the class scores on dim 1.
        device: Device on which the code grid is allocated.
        batch: Number of independent samples to draw.
        size: ``(height, width)`` of the code grid.
        temperature: Positive divisor applied to the logits before the
            softmax. Values below 1 sharpen the distribution and values above
            1 flatten it.
        condition: Optional conditioning input forwarded unchanged to
            ``model`` on every call.

    Returns:
        An int64 tensor of shape ``(batch, size[0], size[1])`` on ``device``
        holding the sampled codes.

    Raises:
        ValueError: If ``temperature`` is not positive. Otherwise the division
            would produce inf/NaN probabilities that only fail later inside
            ``torch.multinomial``.
    """
    if temperature <= 0:
        raise ValueError(
            'temperature must be positive, got {!r}'.format(temperature)
        )

    height, width = size[0], size[1]
    # Allocate directly on the target device rather than building the tensor on
    # the CPU and copying it across.
    row = torch.zeros(batch, *size, dtype=torch.int64, device=device)
    cache = {}

    for i in tqdm(range(height)):
        for j in range(width):
            # Only rows 0..i are passed in. Entries at or after (i, j) are still
            # zero; the model is presumably causal so it ignores them.
            out, cache = model(row[:, : i + 1, :], condition=condition, cache=cache)
            prob = torch.softmax(out[:, :, i, j] / temperature, 1)
            row[:, i, j] = torch.multinomial(prob, 1).squeeze(-1)

    return row
def _load_checkpoint(path, device):
    """Load a checkpoint file onto ``device`` across PyTorch versions."""
    import inspect

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (or a different GPU).
    kwargs = {'map_location': device}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses the
    # argparse.Namespace stored under 'args' by the training script. Older
    # releases do not accept the keyword at all, so only pass it when supported.
    # NOTE(review): weights_only=False unpickles arbitrary objects. Only load
    # checkpoints from a trusted source.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


def load_model(model, checkpoint, device):
    """Build a network of the requested kind and load its checkpoint weights.

    Args:
        model: Which network to build: ``'vqvae'``, ``'pixelsnail_top'`` or
            ``'pixelsnail_bottom'``.
        checkpoint: Checkpoint file name, joined onto the relative
            ``checkpoint`` directory (an absolute path is used as-is, per
            ``os.path.join``).
        device: Device the checkpoint tensors are mapped to and the network is
            moved to.

    Returns:
        The network with weights loaded, on ``device``, in eval mode.

    Raises:
        ValueError: If ``model`` is not one of the known names, or if a
            PixelSNAIL checkpoint lacks the ``'args'`` entry needed to rebuild
            the network's hyper-parameters.
    """
    name = model
    if name not in ('vqvae', 'pixelsnail_top', 'pixelsnail_bottom'):
        raise ValueError('unknown model type: {!r}'.format(name))

    ckpt = _load_checkpoint(os.path.join('checkpoint', checkpoint), device)

    if name == 'vqvae':
        net = VQVAE()
    else:
        # The PixelSNAIL architecture is rebuilt from the hyper-parameters the
        # training run stored alongside the weights.
        if 'args' not in ckpt:
            raise ValueError(
                "checkpoint {!r} has no 'args' entry; cannot rebuild {}".format(
                    checkpoint, name
                )
            )
        args = ckpt['args']
        # NOTE(review): 512 looks like the codebook size, and 5 / 4 look like
        # kernel size / block count. Confirm against PixelSNAIL's signature.
        if name == 'pixelsnail_top':
            net = PixelSNAIL(
                [32, 32],
                512,
                args.channel,
                5,
                4,
                args.n_res_block,
                args.n_res_channel,
                dropout=args.dropout,
                n_out_res_block=args.n_out_res_block,
            )
        else:
            net = PixelSNAIL(
                [64, 64],
                512,
                args.channel,
                5,
                4,
                args.n_res_block,
                args.n_res_channel,
                attention=False,
                dropout=args.dropout,
                n_cond_res_block=args.n_cond_res_block,
                cond_res_channel=args.n_res_channel,
            )

    # Training checkpoints wrap the state dict under 'model'; a bare state dict
    # is used directly.
    state_dict = ckpt['model'] if 'model' in ckpt else ckpt
    net.load_state_dict(state_dict)
    net = net.to(device)
    net.eval()
    return net
if __name__ == '__main__':
    # Sample a batch of images:
    # 1. draw top-level codes from the top prior;
    # 2. draw bottom-level codes conditioned on them;
    # 3. decode both with the VQ-VAE and write the grid to `filename`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8)
    # The three checkpoints are mandatory: a missing value would otherwise
    # reach os.path.join as None and crash with an opaque TypeError.
    parser.add_argument('--vqvae', type=str, required=True)
    parser.add_argument('--top', type=str, required=True)
    parser.add_argument('--bottom', type=str, required=True)
    parser.add_argument('--temp', type=float, default=1.0)
    parser.add_argument('filename', type=str)
    args = parser.parse_args()

    # Use the GPU when available; fall back to the CPU instead of failing.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_vqvae = load_model('vqvae', args.vqvae, device)
    model_top = load_model('pixelsnail_top', args.top, device)
    model_bottom = load_model('pixelsnail_bottom', args.bottom, device)

    top_sample = sample_model(model_top, device, args.batch, [32, 32], args.temp)
    bottom_sample = sample_model(
        model_bottom, device, args.batch, [64, 64], args.temp, condition=top_sample
    )

    # Decoding is inference only, so skip building an autograd graph.
    with torch.no_grad():
        decoded_sample = model_vqvae.decode_code(top_sample, bottom_sample)

    # Map [-1, 1] to [0, 1] explicitly. This is exactly what
    # normalize=True with range=(-1, 1) did, but torchvision renamed that
    # keyword to value_range and later dropped `range`, so doing it here works
    # on every version.
    decoded_sample = (decoded_sample.clamp(-1, 1) + 1) / 2
    save_image(decoded_sample, args.filename)