train.py
from typing import Optional

import os

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

from mindiffusion.unet import NaiveUnet, ContextUnet
from mindiffusion.ddpm import DDPM, DDPM_Context


def train_cifar10(
    n_epoch: int = 100, device: str = "cuda:0", load_pth: Optional[str] = None
) -> None:

    # Uncomment one of the lines below to choose a model: unconditional DDPM,
    # class-conditional DDPM with one-hot label encoding, or class-conditional
    # DDPM with CLIP embeddings as context.
    # ddpm = DDPM(eps_model=NaiveUnet(3, 3, n_feat=128), betas=(1e-4, 0.02), n_T=1000)
    # ddpm = DDPM_Context(eps_model=ContextUnet(3, 3, n_feat=128, encoding='onehot', nc_feat=10), betas=(1e-4, 0.02), n_T=1000)
    ddpm = DDPM_Context(eps_model=ContextUnet(3, 3, n_feat=128, encoding='clip', nc_feat=512), betas=(1e-4, 0.02), n_T=1000)

    # Optionally resume from a saved checkpoint.
    if load_pth is not None:
        ddpm.load_state_dict(torch.load(load_pth))

    ddpm.to(device)

    # Scale CIFAR-10 images to [-1, 1] to match the model's output range.
    tf = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((32, 32)),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    dataset = CIFAR10(
        "./data",
        train=True,
        download=True,
        transform=tf,
    )
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)
    optim = torch.optim.Adam(ddpm.parameters(), lr=1e-5)

    for i in range(n_epoch):
        print(f"Epoch {i} : ")
        ddpm.train()

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, y in pbar:
            optim.zero_grad()
            x = x.to(device)
            y = y.to(device)
            loss = ddpm(x, y)
            loss.backward()
            # Report an exponential moving average of the loss on the progress bar.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        # After each epoch, sample 8 images and save them alongside 8 real
        # images from the last batch for visual comparison.
        ddpm.eval()
        with torch.no_grad():
            xh = ddpm.sample(8, (3, 32, 32), device)
            xset = torch.cat([xh, x[:8]], dim=0)
            grid = make_grid(xset, normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid, f"./contents/ddpm_sample_cifar_{i}.png")

        # Save a checkpoint after every epoch.
        torch.save(ddpm.state_dict(), "./models/ddpm_context_clip_cifar.pth")
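

# --- Optional: sampling from a saved checkpoint ---
# A minimal sketch (not part of the original script) of how the checkpoint
# written above could be reloaded for inference. The helper name
# `sample_cifar10` is hypothetical; it assumes the same DDPM_Context /
# ContextUnet constructor arguments used in train_cifar10 and reuses only the
# calls already present in this file (load_state_dict, sample, make_grid,
# save_image).
def sample_cifar10(
    load_pth: str = "./models/ddpm_context_clip_cifar.pth", device: str = "cuda"
) -> None:
    ddpm = DDPM_Context(
        eps_model=ContextUnet(3, 3, n_feat=128, encoding='clip', nc_feat=512),
        betas=(1e-4, 0.02),
        n_T=1000,
    )
    ddpm.load_state_dict(torch.load(load_pth, map_location=device))
    ddpm.to(device)
    ddpm.eval()
    with torch.no_grad():
        # Generate 16 samples and write them out as a 4x4 grid.
        xh = ddpm.sample(16, (3, 32, 32), device)
        grid = make_grid(xh, normalize=True, value_range=(-1, 1), nrow=4)
        save_image(grid, "./contents/ddpm_sample_cifar_inference.png")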


if __name__ == "__main__":
    os.makedirs("contents", exist_ok=True)
    os.makedirs("models", exist_ok=True)
    train_cifar10(n_epoch=100, device="cuda", load_pth=None)