train_ldm.py
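
"""Train a segmentation-conditioned LatentDiffusion model with Jittor.

Hyper-parameters are read from `configs/ldm.yaml` via Hydra. The script trains
the diffusion UNet and the conditioning network on top of a frozen first-stage
autoencoder, logs losses, generated samples, reconstructions and ground-truth
masks to Weights & Biases, and writes checkpoints to
`./save/<wandb-run-id>/checkpoints`.

NOTE (assumption, not stated in this file): given the use of `jt.rank` and
`jt.world_size`, multi-GPU training is presumably launched MPI-style, e.g.

    mpirun -np 4 python train_ldm.py
"""
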
import builtins
import os
import time
import json
import hydra
import jittor as jt
from omegaconf import OmegaConf
import wandb
from model_jittor.dataset import get_ldm_dataloader
from model_jittor.ldm.ddpm import LatentDiffusion
from utils import make_grid, save_checkpoint, toggle_to_train, toggle_to_eval
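
# A minimal sketch of what `configs/ldm.yaml` is expected to provide (key names
# inferred from the `cfg.*` accesses in this script; the values below are
# placeholders, not the repository's actual defaults):
#
#   project: ldm
#   name: baseline
#   lr: 1.0e-4
#   start_epoch: 0
#   epochs: 200
#   print_freq: 100
#   sample_freq: 5
#   resume: null
#   save_dir: null        # overwritten at runtime from the wandb run dir
#   jittor: {use_cuda: true, amp_level: 0, seed: 0}
#   data: {batch_size: 16, ...}    # forwarded to get_ldm_dataloader(**cfg.data)
#   model: {...}                   # forwarded to LatentDiffusion(**cfg.model)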

with open('./assets/class_labels.json', 'r') as f:
    # NOTE: the class ids and the label names are not matched; they are only used
    # for the wandb mask legend
    _class_labels = json.load(f)
class_labels = {int(key): val for key, val in _class_labels.items()}


@hydra.main(version_base=None, config_path='configs', config_name='ldm.yaml')
def init_and_run(cfg):
    # only print on the master node
    if jt.world_size > 1 and jt.rank != 0:
        def print_pass(*args, **kwargs): pass
        builtins.print = print_pass

    # init jittor
    jt.flags.use_cuda = cfg.jittor.use_cuda
    jt.flags.auto_mixed_precision_level = cfg.jittor.amp_level
    jt.set_global_seed(cfg.jittor.seed)

    # scale the global batch size and the lr linearly with the number of GPUs
    if jt.world_size > 1:
        batch_size_old, lr_old = cfg.data.batch_size, cfg.lr
        cfg.data.batch_size *= jt.world_size
        cfg.lr *= jt.world_size
        print(f"Adjust batch size: {batch_size_old} -> {cfg.data.batch_size}")
        print(f"Adjust lr: {lr_old} -> {cfg.lr}")

    # configure wandb and the checkpoint save dir
    if jt.rank == 0:
        config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
        wandb.init(project=cfg.project, name=cfg.name, config=config)
        # save ckpts in './save/<run-id>/checkpoints'; <run-id> is generated by wandb
        cfg.save_dir = wandb.run.dir.replace('wandb', 'save').replace(
            'files', 'checkpoints')
        os.makedirs(cfg.save_dir, exist_ok=True)
        print(f'Saving checkpoints in {cfg.save_dir}')

    main(cfg)


def main(cfg):
    # data  # TODO: use all data for training ddpm
    train_loader, val_loader = get_ldm_dataloader(**cfg.data)

    # init model and ema model
    model = LatentDiffusion(**cfg.model)

    # configure optimizer  # TODO: try lr_scheduler
    optimizer = jt.optim.Adam(
        list(model.model.parameters()) + list(model.cond_stage_model.parameters()),
        lr=cfg.lr,
    )
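
    # NOTE: only the diffusion UNet (`model.model`) and the conditioning network
    # (`model.cond_stage_model`) are optimized; the first-stage autoencoder
    # (`model.first_stage_model`) gets no gradients here and stays frozen.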

    # optionally resume from a checkpoint
    if cfg.resume is not None:
        assert os.path.isfile(cfg.resume), f"checkpoint not found: {cfg.resume}"
        checkpoint = jt.load(cfg.resume)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        cfg.start_epoch = checkpoint['epoch'] + 1  # start from the next epoch

    print('Start training, good luck!')
    for epoch in range(cfg.start_epoch, cfg.epochs):
        start_time = time.time()
        toggle_to_train(model)
        for i, (img, seg, _) in enumerate(train_loader):
            global_train_steps = epoch * len(train_loader) + i
            loss = model(img, seg)
            optimizer.step(loss)  # jittor runs backward inside optimizer.step(loss)
            model.step_ema()  # update the ema unet
            if global_train_steps % cfg.print_freq == 0:
                print(
                    f"epoch: {epoch:3d}\t",
                    f"iter: [{i:4d}/{len(train_loader)}]\t",
                    f"loss {loss.item():7.3f}\t",
                )
                if jt.rank == 0:  # TODO: wrap wandb.log so it is a no-op off the master node
                    wandb.log({
                        "train/epoch": epoch,
                        "train/iter": global_train_steps,
                        "train/loss": loss.item(),
                    })
        train_time = time.time() - start_time
        print(f'Epoch {epoch:3d} training time: {train_time/60:.2f} min.')

        # sample the val set
        toggle_to_eval(model)
        if epoch % cfg.sample_freq == 0:
            img, seg, _ = next(iter(val_loader))
            img, seg = img[:4], seg[:4]
            img_sample = model.sample_and_decode(seg)
            img_rec = model.first_stage_model(img)
            # map images from [-1, 1] back to [0, 1] for visualization
            img = jt.clamp((img.detach() + 1) / 2, 0, 1)
            img_sample = jt.clamp((img_sample.detach() + 1) / 2, 0, 1)
            img_rec = jt.clamp((img_rec.detach() + 1) / 2, 0, 1)
            # reduce the channel dimension of seg to per-pixel class ids
            seg, _ = jt.argmax(seg.detach(), dim=1, keepdims=True)
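            # wandb renders `mask_data` (integer class ids per pixel) as an overlay
            # on the image, with a legend built from `class_labels` (which, per the
            # NOTE at the top of the file, may not match the true class ids)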
            if jt.rank == 0:
                wandb.log({
                    'generated': wandb.Image(make_grid(img_sample.data, n_cols=4)),
                    'reconstructed': wandb.Image(make_grid(img_rec.data, n_cols=4)),
                    'original': wandb.Image(
                        make_grid(img.data, n_cols=4), masks={
                            "ground_truth":
                                {"mask_data": make_grid(seg.data, n_cols=4),
                                 "class_labels": class_labels}
                        },
                    ),
                })

        save_checkpoint({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, save_dir=cfg.save_dir, filename=f"epoch_{epoch}.ckpt")


if __name__ == "__main__":
    init_and_run()