"""sampling.py: sample images from a trained diffusion checkpoint and save them as PNG files."""
import argparse
import numpy as np
import torch
from .data.dataset import tensor_to_pil_image
from model import DiffusionModule
from scheduler import DDPMScheduler
from pathlib import Path
def main(args):
    """Sample 500 images from a diffusion checkpoint and save them as PNGs.

    The checkpoint is loaded into a ``DiffusionModule``, moved to the CUDA
    device selected by ``args.gpu``, and given a freshly built linear
    ``DDPMScheduler`` with the same number of training timesteps. Images are
    then sampled batch by batch and written to ``args.save_dir`` as
    ``<index>.png``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``batch_size``, ``save_dir``, ``gpu``, ``ckpt_path``, ``use_cfg``
            and ``cfg_scale``.

    Raises:
        ValueError: If ``args.batch_size`` is not positive.
        AssertionError: If ``args.use_cfg`` is set but the loaded network's
            ``use_cfg`` flag is false.
    """
    batch_size = args.batch_size
    # Validate before touching the filesystem or the GPU: a batch size of 0
    # used to surface as a ZeroDivisionError, and a negative one silently
    # produced no images at all.
    if batch_size <= 0:
        raise ValueError(
            f"--batch_size must be a positive integer, got {batch_size}."
        )

    save_dir = Path(args.save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    device = f"cuda:{args.gpu}"

    ddpm = DiffusionModule(None, None)
    ddpm.load(args.ckpt_path)
    ddpm.eval()
    ddpm = ddpm.to(device)

    # Swap the loaded variance scheduler for a new linear-schedule one that
    # keeps the checkpoint's number of training timesteps.
    num_train_timesteps = ddpm.var_scheduler.num_train_timesteps
    ddpm.var_scheduler = DDPMScheduler(
        num_train_timesteps,
        beta_1=1e-4,
        beta_T=0.02,
        mode="linear",
    ).to(device)

    # The guidance configuration does not change between batches, so resolve
    # it once. An explicit raise (rather than ``assert``) keeps the check
    # alive under ``python -O`` while preserving the exception type.
    if args.use_cfg:  # Enable CFG sampling
        if not ddpm.network.use_cfg:
            raise AssertionError("The model was not trained to support CFG.")
        guidance_scale = args.cfg_scale
    else:
        guidance_scale = 0.0

    total_num_samples = 500

    # Step through [0, total_num_samples) in strides of batch_size; the last
    # batch is truncated so exactly total_num_samples images are produced.
    for sidx in range(0, total_num_samples, batch_size):
        eidx = min(sidx + batch_size, total_num_samples)
        num_in_batch = eidx - sidx

        samples = ddpm.sample(
            num_in_batch,
            # Random integer class labels in {1, 2, 3}, one per sample.
            class_label=torch.randint(1, 4, (num_in_batch,)),
            guidance_scale=guidance_scale,
        )

        pil_images = tensor_to_pil_image(samples)

        # File names are the global sample index, not the index in the batch.
        for j, img in zip(range(sidx, eidx), pil_images):
            img.save(save_dir / f"{j}.png")
            print(f"Saved the {j}-th image.")
def parse_args(argv=None):
    """Parse the command-line options of the sampling script.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``batch_size``, ``gpu``, ``ckpt_path``,
        ``save_dir``, ``use_cfg``, ``sample_method`` and ``cfg_scale``.
    """
    parser = argparse.ArgumentParser(
        description="Sample images from a diffusion checkpoint."
    )
    parser.add_argument(
        "--batch_size", type=int, default=64,
        help="Number of images sampled per batch.",
    )
    parser.add_argument(
        "--gpu", type=int, default=0,
        help="Index of the CUDA device to run on.",
    )
    # Both paths are required: without them the script previously failed
    # later with an opaque TypeError instead of a usage message.
    parser.add_argument(
        "--ckpt_path", type=str, required=True,
        help="Path of the checkpoint to load.",
    )
    parser.add_argument(
        "--save_dir", type=str, required=True,
        help="Directory the sampled PNG images are written to.",
    )
    parser.add_argument(
        "--use_cfg", action="store_true",
        help="Sample with classifier-free guidance.",
    )
    # NOTE(review): this option is parsed but main() does not read it.
    parser.add_argument(
        "--sample_method", type=str, default="ddpm",
        help="Name of the sampling method.",
    )
    parser.add_argument(
        "--cfg_scale", type=float, default=7.5,
        help="Guidance scale used when --use_cfg is set.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(parse_args())