-
Notifications
You must be signed in to change notification settings - Fork 0
/
decoder.py
executable file
·37 lines (32 loc) · 942 Bytes
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter, \
OpenClipAdapter, DecoderTrainer, CLIP
# Example script: build and prepare to train the decoder stage of DALL-E 2
# using dalle2_pytorch.  (The diffusion-prior training referred to below is
# not shown in this file.)
# do above for many steps ...
# decoder (with unet)
# Base U-Net of the decoder cascade.  image_embed_dim / text_embed_dim = 768
# presumably match the embedding width of the CLIP "ViT-L/14" adapter used
# below — TODO confirm against the dalle2_pytorch docs.
unet1 = Unet(
dim = 128,
image_embed_dim = 768,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
attn_heads=8,
attn_dim_head=64,
text_embed_dim = 768,
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()
# NOTE(review): misleading name — this is the Decoder *model* itself (a
# diffusion decoder producing 128x128 images, conditioned via the CLIP
# adapter), not a parameter dict.  Consider renaming to `decoder` if nothing
# else imports this module.
decoder_params = Decoder(
unet = unet1,
image_sizes = [128],
clip = OpenAIClipAdapter("ViT-L/14"),
timesteps = 1000,
sample_timesteps = 64,
image_cond_drop_prob = 0.1,      # classifier-free guidance dropout on image embeds
text_cond_drop_prob = 0.1        # classifier-free guidance dropout on text conditioning
).cuda()
# NOTE(review): also misleadingly named — this is a DecoderTrainer wrapping
# the decoder above, with AdamW-style lr/wd settings and an EMA copy of the
# weights (EMA updates start after step 1000, then every 10 steps).
# Consider renaming to `decoder_trainer`.
decoder_params1 = DecoderTrainer(
decoder_params,
lr = 1e-4,
wd = 1e-2,
ema_beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
)