-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_diffusion.py
53 lines (46 loc) · 1.92 KB
/
train_diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import yaml
import os
from torch.utils.data import DataLoader
import logging
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
from timm.utils import AverageMeter
from utils import check_dir
from dataset.toy_2d import Toy_Dataset
from model.diffusion.diffusion import DDPM
# Load the run's hyper-parameters from ./config_diffusion.yaml (resolved against
# the current working directory). This script reads 'batch_size', 'epochs' and
# 'plot_freq' from it and hands the whole dict to DDPM.
# safe_load (not yaml.load) avoids constructing arbitrary Python objects.
with open("./config_diffusion.yaml", 'r', encoding="utf-8") as f:
    config = yaml.safe_load(f)
# Run directory: receives the timestamped log file, periodic sample plots and
# the final checkpoint.
base_dir = "./log/diffusion/25gaussians"
# Use a torch.device on both branches so `device` has a single type whether or
# not CUDA is available; tensor.to() and torch.randn(device=...) accept it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Project helper; presumably creates base_dir if it is missing — confirm.
# It must run before basicConfig, which opens a file inside base_dir.
check_dir(base_dir)
# One log file per launch, named by start time (e.g. 20240101_120000.log).
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
logging.basicConfig(
    level=logging.INFO,
    filename=os.path.join(base_dir, f"{timestamp}.log"),
    filemode='a',
    format='%(asctime)s - %(message)s',
)
# Record the full config on the console and in the log file.
print(config)
logging.info(config)
# A fixed (10000, 2) Gaussian noise batch, drawn once so that every sample plot
# starts from the same inputs and successive snapshots are directly comparable.
fixed_noise = torch.randn(10000, 2, device=device)
train_dataset = Toy_Dataset(data_name="25gaussians")
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
)
# NOTE(review): the model is never moved to `device` in this script; presumably
# DDPM places its networks itself (e.g. from config) — confirm.
model = DDPM(config)
# Progress bar over epochs; the training loop iterates this object directly.
train_process = tqdm(range(config["epochs"]))
# Main training loop.
# - Each epoch logs the mean training loss over its batches.
# - Every config['plot_freq'] epochs, it also saves a scatter plot of the real
#   data (blue) against samples generated from fixed_noise (red) to
#   base_dir/NNNN.png.
for epoch in train_process:
    mean_loss = AverageMeter()
    for batch in train_data_loader:
        batch = batch.to(device)
        # model.train presumably runs one optimisation step and returns the
        # batch loss as a scalar tensor — confirm in DDPM.
        loss = model.train(batch)
        # Store a plain float so the log reads "0.0123", not "tensor(0.0123)".
        mean_loss.update(loss.item())
    log = "Epoch: {}\t Loss of DDPM: {}\t ".format(epoch, mean_loss.avg)
    logging.info(log)
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(log)
    if (epoch + 1) % config['plot_freq'] == 0:
        fake_samples, _ = model.sample(fixed_noise)
        fake_samples = fake_samples.detach().cpu().numpy()
        fig, ax = plt.subplots()
        ax.scatter(train_dataset.data[:, 0], train_dataset.data[:, 1],
                   color='blue', label='True', s=2, alpha=0.5)
        ax.scatter(fake_samples[:, 0], fake_samples[:, 1],
                   color='red', label='Fake', s=2, alpha=0.5)
        ax.set_xlim((-4, 4))
        ax.set_ylim((-4, 4))
        ax.grid(True)
        ax.legend()  # without this the 'True'/'Fake' labels are never shown
        # Zero-padded epoch number keeps the image files sorted, e.g. 0010.png.
        fig.savefig(os.path.join(base_dir, "{}.png".format(str(epoch + 1).zfill(4))))
        plt.close(fig)
# Save the trained network into the run directory, next to the log and plots.
# Joining with os.path.join fixes the old `base_dir + "final.pt"`, which lacked
# a separator and wrote ./log/diffusion/25gaussiansfinal.pt instead.
# torch.save pickles the whole object, so loading it requires the defining
# classes to be importable.
torch.save(model.model, os.path.join(base_dir, "final.pt"))