train_sd_unet.py
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.nn.parallel import DistributedDataParallel as DDP
import optix
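
# Benchmarks a single denoising-UNet training step for Stable Diffusion 1.x or
# SDXL on random data, comparing a plain DDP + gradient-checkpointing baseline
# against the "optix" path (which supplies compile(), sliced_vae() and
# utils.setup_distributed() below).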
def enable_tf32():
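    # TF32 speeds up fp32 matmuls/convolutions on Ampere and newer GPUs at a
    # small cost in precision.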
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

def train(model, vae, batchsize=1, use_amp=True, h=512, w=512, is_xl=False,
          use_optix=False):
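    """Run a short fake training loop and report step times and peak memory.

    Random tensors stand in for latents and text-encoder outputs, so no
    dataloader or text encoder is needed.
    """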
    dt = torch.float32
    if not is_xl:
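        # 768 is the SD1.x text-encoder hidden size; the commented-out 1024
        # would match SD2.x.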
        timesteps = torch.arange(batchsize, dtype=torch.int64).cuda() + 100
        encoder_hidden_states = torch.rand([batchsize, 77, 768], dtype=dt).cuda()
        # encoder_hidden_states = torch.rand([batchsize, 77, 1024], dtype=dt).cuda()
    else:
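        # SDXL conditioning: prompt_embeds concatenates both text encoders
        # (768 + 1280 = 2048), text_embeds is the pooled 1280-dim embedding,
        # and time_ids carries the six original/crop/target size values.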
        timesteps = torch.arange(batchsize, dtype=torch.int64).cuda() + 100
        prompt_embeds = torch.rand([batchsize, 77, 2048], dtype=dt).cuda()
        time_ids = torch.rand([batchsize, 6], dtype=dt).cuda()
        text_embeds = torch.rand([batchsize, 1280], dtype=dt).cuda()
        unet_added_conditions = {
            "time_ids": time_ids,
            "text_embeds": text_embeds,
        }
    model.cuda()
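    # optix.compile() wraps model and VAE and returns its own optimizer; the
    # baseline instead uses gradient checkpointing, AdamW and DDP.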
    if use_optix:
        model, vae, opt, _ = optix.compile(model, vae, compile_vae=True)
    else:
        model.enable_gradient_checkpointing()
        opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
        model = DDP(model)
    perf_times = []
    for ind in range(8):
        model_input = torch.rand([batchsize, 3, h, w], dtype=torch.float32).cuda()
        torch.cuda.synchronize()
        beg = time.time()
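        # Encode pixels to latents; 0.18215 is the SD1.x VAE scaling factor
        # (SDXL's VAE nominally uses 0.13025, though the constant does not
        # affect benchmark timings).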
        if not use_optix:
            with torch.no_grad():
                noisy_model_input = vae.encode(model_input).latent_dist.sample().mul_(0.18215)
        else:
            noisy_model_input = optix.sliced_vae(vae, model_input, use_autocast=True, nhwc=True)
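        # fp16 autocast for the UNet forward; no GradScaler is used, which is
        # fine for a throughput benchmark but would risk underflow in real
        # fp16 training.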
        with torch.autocast(dtype=torch.float16, device_type='cuda', enabled=use_amp):
            if is_xl:
                model_pred = model(
                    noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
                ).sample
            else:
                model_pred = model(noisy_model_input, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), torch.rand_like(model_pred).float(), reduction="mean")
        loss.backward()
        opt.step()
        opt.zero_grad()
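        # Synchronize before reading the clock; the first five iterations are
        # treated as warmup and excluded from the timings.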
        torch.cuda.synchronize()
        if ind > 4:
            perf_times.append(time.time() - beg)
        beg = time.time()
    print("max mem (GB)", torch.cuda.max_memory_allocated() / 1e9)
    print(perf_times)
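
# Driver: load the UNet and VAE and benchmark one configuration. With
# use_optix=True the autocast context stays disabled (use_amp=False);
# optix presumably handles mixed precision itself.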
enable_tf32()
rank, world_size, port, addr = optix.utils.setup_distributed()
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
# pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet"
).cuda()
unet.train()
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").cuda()
train(unet, vae, batchsize=4,
      use_amp=False, h=576, w=1024, is_xl='xl' in pretrained_model_name_or_path,
      use_optix=True)