-
Notifications
You must be signed in to change notification settings - Fork 41
/
train_ppt2_bn.py
1046 lines (921 loc) · 43.7 KB
/
train_ppt2_bn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python
# modified from https://github.com/TencentARC/BrushNet/blob/main/examples/brushnet/train_brushnet.py
import argparse
import gc
import logging
import math
import os
import shutil
from pathlib import Path
import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import PretrainedConfig
import diffusers
import powerpaint.datasets
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from powerpaint.datasets import ProbPickingDataset
from powerpaint.models import BrushNetModel, UNet2DConditionModel
from powerpaint.pipelines import StableDiffusionPowerPaintBrushNetPipeline
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.27.0.dev0")
logger = get_logger(__name__)
def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
    """Write a model card (``README.md``) describing the trained PowerPaint weights.

    Args:
        repo_id: Repository id; shown in the card title and passed to
            ``load_or_create_model_card``.
        image_logs: Optional iterable of mappings, each providing a
            ``"validation_prompt"`` value and a ``"validation_image"`` object
            with a ``save(path)`` method. Every image is written into
            ``repo_folder`` and linked from the card.
        base_model: Identifier of the base model the weights were trained on.
        repo_folder: Directory that receives the example images and ``README.md``.
    """
    # NOTE(review): `log_validation` in this file returns a flat list of images,
    # not mappings with the keys read below - confirm the caller builds
    # `image_logs` in the expected shape.
    img_str = ""
    if image_logs is not None:
        img_str = "You can find some example images below.\n\n"
        for i, log in enumerate(image_logs):
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]
            image_name = f"image_{i}.png"
            validation_image.save(os.path.join(repo_folder, image_name))
            img_str += f"prompt: {validation_prompt}\n"
            # Link the file that was actually written above; the previous link
            # pointed at `images_{i}.png`, which is never created.
            img_str += f"![image_{i}](./{image_name})\n"

    # The template lines start at column 0 so the Markdown heading renders as a heading.
    model_description = f"""
# PowerPaint - {repo_id}
These are PowerPaint weights trained on {base_model} with new type of conditioning.
{img_str}
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="creativeml-openrail-m",
        base_model=base_model,
        model_description=model_description,
        inference=True,
    )
    tags = [
        "stable-diffusion",
        "stable-diffusion-diffusers",
        "text-to-image",
        "diffusers",
        "PowerPaint",
        "diffusers-training",
    ]
    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(tokenizer, text_encoder, brushnet, args, accelerator, weight_dtype, step):
    """Run the validation cases through an inference pipeline and log the results.

    The pipeline is rebuilt from ``args.pretrained_model_name_or_path`` with a
    freshly loaded UNet, while the tokenizer, text encoder and brushnet come
    from the caller. For every case in ``args.validation_data.cases`` the masked
    input and one generated image per prompt are pasted side by side into a
    grid, which is saved under ``args.output_dir``.

    Args:
        tokenizer: Tokenizer handed to the pipeline.
        text_encoder: Text encoder; unwrapped through ``accelerator`` before use.
        brushnet: BrushNet model; unwrapped through ``accelerator`` before use.
        args: Configuration namespace. Reads ``pretrained_model_name_or_path``,
            ``revision``, ``variant``, ``enable_xformers_memory_efficient_attention``,
            ``validation_data`` (``data_root`` and ``cases``) and ``output_dir``.
        accelerator: Provides the device, model unwrapping and the trackers.
        weight_dtype: ``torch_dtype`` used when loading the pipeline.
        step: Step value used as the grid file-name prefix and as the
            tensorboard global step.

    Returns:
        Flat list of the generated images, ordered by case and then by prompt.
    """
    logger.info("Running validation... ")

    # use fixed model from pretrained models, and text_encoder and tokenizer from trainer
    pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
        ),
        tokenizer=tokenizer,
        text_encoder=accelerator.unwrap_model(text_encoder),
        brushnet=accelerator.unwrap_model(brushnet),
        safety_checker=None,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
        local_files_only=True,  # load files from local cache
    )
    pipe = pipe.to(accelerator.device)
    pipe.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipe.enable_xformers_memory_efficient_attention()

    image_logs = []
    # Caption (task name) of every generated image, kept index-aligned with
    # `image_logs` so trackers can label images from all cases correctly.
    image_captions = []
    for case in args.validation_data.cases:
        validation_prompts = case.prompt

        # load validation image and its mask (mask is resized to the image and made single-channel)
        validation_image = Image.open(os.path.join(args.validation_data.data_root, case.image)).convert("RGB")
        validation_mask = Image.open(os.path.join(args.validation_data.data_root, case.mask))
        validation_mask = validation_mask.resize((validation_image.size[0], validation_image.size[1]), Image.NEAREST)
        validation_mask = validation_mask.convert("L")

        # Fill the masked region with `hole_value` so the pipeline receives the masked image.
        hole_value = (0, 0, 0)
        validation_image = Image.composite(
            Image.new("RGB", (validation_image.size[0], validation_image.size[1]), hole_value),
            validation_image,
            validation_mask,
        )

        # One tile for the masked input plus one tile per prompt.
        image_grid = Image.new(
            "RGB",
            (validation_image.size[0] * (1 + len(validation_prompts)), validation_image.size[1]),
            (255, 255, 255),
        )
        image_grid.paste(validation_image, (0, 0))

        # Inputs substituted for prompts whose task is "t2i": an all-white mask over an all-black image.
        t2i_mask = Image.new("RGB", (validation_image.size[0], validation_image.size[1]), (255, 255, 255)).convert("L")
        t2i_image = Image.new("RGB", (validation_image.size[0], validation_image.size[1]), (0, 0, 0))

        for i, p in enumerate(validation_prompts):
            with torch.autocast(accelerator.device.type):
                image = pipe(
                    promptA=p.promptA,
                    promptB=p.promptB,
                    prompt=p.prompt,
                    negative_promptA=p.negative_promptA,
                    negative_promptB=p.negative_promptB,
                    negative_prompt=p.negative_prompt,
                    tradeoff=p.tradeoff,
                    image=validation_image if p.task != "t2i" else t2i_image,
                    mask=validation_mask if p.task != "t2i" else t2i_mask,
                    num_inference_steps=20,
                ).images[0]
            image_logs.append(image)
            image_captions.append(f"{p.task}")
            image_grid.paste(image, (validation_image.size[0] * (i + 1), 0))

        image_grid.save(os.path.join(args.output_dir, f"{str(step).zfill(3)}_{os.path.basename(case.image)}"))
        gc.collect()
        torch.cuda.empty_cache()

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            # NOTE(review): np.stack needs every image to share one shape - confirm
            # all validation cases produce equally sized outputs.
            np_images = np.stack([np.asarray(img) for img in image_logs])
            tracker.writer.add_images("validation", np_images, step, dataformats="NHWC")
        elif tracker.name == "wandb":
            # Pair each image with the caption recorded when it was generated.
            # Zipping against only the first case's prompts dropped or
            # mislabelled the images of every later case.
            tracker.log(
                {
                    "validation": [
                        wandb.Image(image, caption=caption)
                        for image, caption in zip(image_logs, image_captions)
                    ]
                }
            )
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    # Drop the pipeline and release cached GPU memory before returning.
    del pipe
    gc.collect()
    torch.cuda.empty_cache()

    return image_logs
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named by a checkpoint's ``text_encoder`` config.

    Args:
        pretrained_model_name_or_path: Model path or identifier whose
            ``text_encoder`` subfolder config is loaded.
        revision: Revision passed through to ``PretrainedConfig.from_pretrained``.

    Returns:
        The class matching the first entry of the config's ``architectures``.

    Raises:
        ValueError: If that architecture is not one of the handled names.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    model_class = encoder_config.architectures[0]

    # Each class is imported only inside the branch that returns it.
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    if model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation

    if model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel

    raise ValueError(f"{model_class} is not supported.")
def _build_arg_parser():
    """Create the argument parser holding every command-line option of the script."""
    parser = argparse.ArgumentParser(
        description="Simple example of a PowerPaint based on brushnet architecture training script."
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="yaml for configuration",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--powerpaint_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained powerpaint model or model identifier from huggingface.co/models."
        " If not specified powerpaint weights are initialized from unet.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="runs/ppt2_bn",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=10000)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
            "instructions."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warm up period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warm up in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory, resolved relative to"
            " `--output_dir`."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=(
            "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
            " behaviors, so disable this argument if it causes any problems. More info:"
            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
        ),
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
    )
    parser.add_argument(
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
        help="The column of the dataset containing the powerpaint conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--proportion_empty_prompts",
        type=float,
        default=0,
        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
    )
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
        "More details here: https://arxiv.org/abs/2303.09556.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation runs every prompt of each case listed under"
            " `validation_data.cases` in the config and logs the resulting images."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="train_powerpaint_brushnet",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    return parser


def parse_args(input_args=None):
    """Parse command-line options, merge an optional YAML config, and validate them.

    Args:
        input_args: Optional list of argument strings. When ``None``, argparse
            reads the arguments from ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``. When ``--config`` is given, every
        top-level key of that file is stored on the namespace, replacing any
        value of the same name that came from the command line or a default.

    Raises:
        ValueError: If ``proportion_empty_prompts`` is outside ``[0, 1]`` or
            ``resolution`` is not divisible by 8.
    """
    # argparse treats `None` as "use sys.argv", so no separate branch is needed.
    args = _build_arg_parser().parse_args(input_args)

    # use omegaconf to manage configurations
    if args.config is not None:
        config = OmegaConf.load(args.config)
        # Config entries are written last, so they take precedence over CLI values.
        vars(args).update(config.items())

    # Validation runs after the merge so config-provided values are checked too.
    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
        raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the brushnet encoder."
        )

    return args
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
torch.manual_seed(args.seed)
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# saving training configuration to output_dir
to_save_config = OmegaConf.create(vars(args))
OmegaConf.save(config=to_save_config, f=os.path.join(args.output_dir, "training_config.yaml"))
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# initialize from pre-trained pipeline
pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
),
safety_checker=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
local_files_only=True, # load files from local cache
)
if args.powerpaint_model_name_or_path:
logger.info("Loading existing powerpaint weights")
pipe.brushnet = BrushNetModel.from_pretrained(args.powerpaint_model_name_or_path)
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# IMPORTANT: add learnable tokens for task prompts into tokenizer
placeholder_tokens = [v.placeholder_tokens for k, v in args.task_prompt.items()]
initializer_token = [v.initializer_token for k, v in args.task_prompt.items()]
num_vectors_per_token = [v.num_vectors_per_token for k, v in args.task_prompt.items()]
placeholder_token_ids = pipe.add_tokens(
placeholder_tokens, initializer_token, num_vectors_per_token, initialize_parameters=True
)
vae, tokenizer, unet, noise_scheduler = pipe.vae, pipe.tokenizer, pipe.unet, pipe.scheduler
text_encoder, brushnet = pipe.text_encoder.to(torch.float32), pipe.brushnet.to(torch.float32)
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for model in models:
sub_dir = "brushnet" if isinstance(model, type(unwrap_model(brushnet))) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(text_encoder))):
# load transformers style into model
load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
model.config = load_model.config
else:
# load diffusers style into model
load_model = BrushNetModel.from_pretrained(input_dir, subfolder="brushnet")
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())
del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if args.gradient_checkpointing:
brushnet.enable_gradient_checkpointing()
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
brushnet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# Check that all trainable models are in full precision
low_precision_error_string = (
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
" doing mixed precision training, copy of the weights should still be float32."
)
if unwrap_model(brushnet).dtype != torch.float32:
raise ValueError(f"BrushNet loaded as datatype {unwrap_model(brushnet).dtype}. {low_precision_error_string}")
# Enable TF32 matmuls for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

# Optionally scale the base learning rate by accumulation steps, per-device batch size
# and number of processes.
if args.scale_lr:
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

# Pick the optimizer implementation: plain AdamW by default, or the 8-bit variant from
# bitsandbytes for lower memory usage (e.g. fine-tuning on 16GB GPUs).
if not args.use_8bit_adam:
    optimizer_class = torch.optim.AdamW
else:
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        )
    else:
        optimizer_class = bnb.optim.AdamW8bit
# Two things are trained: (1) the text encoder's token embeddings and (2) brushnet.
# The vae and the unet are frozen entirely.
for frozen_module in (vae, unet):
    frozen_module.requires_grad_(False)

# Inside the text encoder, freeze everything except the token embeddings.
text_model = text_encoder.text_model
for frozen_module in (
    text_model.encoder,
    text_model.final_layer_norm,
    text_model.embeddings.position_embedding,
):
    frozen_module.requires_grad_(False)

# A single optimizer drives both the brushnet weights and the input embedding table.
trainable_params = [*brushnet.parameters(), *text_encoder.get_input_embeddings().parameters()]
optimizer = optimizer_class(
    trainable_params,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
# Preprocessing applied to training images: resize and center-crop to `args.resolution`,
# random horizontal flip, conversion to a tensor, then normalization with mean 0.5 / std 0.5.
preprocessing_steps = [
    transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(args.resolution),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
train_transforms = transforms.Compose(preprocessing_steps)
# Prepare the datasets and the dataloader for training.
# Several datasets can be combined behind a single dataloader.
datasets_list = []
# One dataset is instantiated per entry of `args.train_data.datasets`; the class is looked
# up by name in `powerpaint.datasets` and the entry itself is forwarded as keyword arguments.
for d in args.train_data.datasets:
dataset_class = getattr(powerpaint.datasets, d.dataset_class)
dataset_ = dataset_class(train_transforms, pipe, args.task_prompt, **d)
# Each dataset is stored together with its `prob` value
# (presumably a sampling probability - confirm against ProbPickingDataset).
datasets_list.append({"dataset": dataset_, "prob": d.prob})
train_dataset = ProbPickingDataset(datasets_list)
with accelerator.main_process_first():
# Optionally restrict training to a seeded random subset of `max_train_samples` items.
# NOTE(review): `.shuffle(...).select(...)` looks like the Hugging Face `datasets` API;
# confirm that ProbPickingDataset implements it before relying on `max_train_samples`.
if args.max_train_samples is not None:
train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Work out how many optimizer steps the run will take, then build the LR scheduler.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

# Remember whether `max_train_steps` was derived from the epoch count, so that it can be
# re-derived once the dataloader has been prepared.
overrode_max_train_steps = args.max_train_steps is None
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

# Warmup and total step counts are multiplied by the number of processes.
scheduler_kwargs = dict(
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
    num_cycles=args.lr_num_cycles,
    power=args.lr_power,
)
lr_scheduler = get_scheduler(args.lr_scheduler, **scheduler_kwargs)
# Put both trainable modules into training mode.
for trainable_module in (brushnet, text_encoder):
    trainable_module.train()

# Prepare everything that takes part in optimization with our `accelerator`.
brushnet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    brushnet, text_encoder, optimizer, train_dataloader, lr_scheduler
)

# The frozen vae and unet are not passed to `prepare`; move them to the device and cast
# them to `weight_dtype` here.
for frozen_module in (vae, unet):
    frozen_module.to(accelerator.device, dtype=weight_dtype)

# The size of the training dataloader may have changed during `prepare`, so the step
# arithmetic is redone with the prepared dataloader.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# The epoch count is then re-derived from the (possibly updated) step budget.
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Initialize the experiment trackers on the main process and store the run configuration.
if accelerator.is_main_process:
    # tensorboard cannot handle list types for config, so only scalar-like values
    # (and tensors) are kept; every dropped key is logged.
    tracker_config = {}
    for key, value in vars(args).items():
        if isinstance(value, (int, float, str, bool, torch.Tensor)):
            tracker_config[key] = value
        else:
            logger.info(f"Removed {key} (type:{type(value)}) from tracker_config")
    accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
# Train!
# Effective batch size per optimizer step across devices and accumulation steps.
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

# Summary of the run configuration, logged once before the loop starts.
startup_messages = (
    f"***** Running training for {args.tracker_project_name} *****",
    f" Num examples = {len(train_dataset)}",
    f" Num Epochs = {args.num_train_epochs}",
    f" Instantaneous batch size per device = {args.train_batch_size}",
    f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}",
    f" Gradient Accumulation steps = {args.gradient_accumulation_steps}",
    f" Total optimization steps = {int(args.max_train_steps)}",
)
for message in startup_messages:
    logger.info(message)

# Counters; both are overwritten when a checkpoint is resumed.
global_step, first_epoch = 0, 0
# Potentially load in the weights and states from a previous save.
# `initial_global_step` seeds the progress bar; it stays 0 unless a checkpoint is loaded.
initial_global_step = 0
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint != "latest":
        # normpath strips a trailing path separator; without it, basename of
        # "out/checkpoint-500/" would be the empty string and the step parse below fails.
        path = os.path.basename(os.path.normpath(args.resume_from_checkpoint))
    else:
        # Get the most recent checkpoint: the `checkpoint-<step>` directory with the
        # highest step. Entries without a numeric step suffix are ignored instead of
        # crashing the sort key.
        checkpoint_dirs = [
            d
            for d in os.listdir(args.output_dir)
            if d.startswith("checkpoint-") and d.split("-")[1].isdigit()
        ]
        checkpoint_dirs.sort(key=lambda name: int(name.split("-")[1]))
        path = checkpoint_dirs[-1] if checkpoint_dirs else None

    if path is None:
        logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
        args.resume_from_checkpoint = None
    else:
        logger.info(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path), map_location="cpu")
        # The optimizer step count is encoded in the directory name (`checkpoint-<step>`).
        global_step = int(path.split("-")[1])
        initial_global_step = global_step
        first_epoch = global_step // num_update_steps_per_epoch
# Progress bar over optimizer steps, starting at the resumed step when applicable.
progress_bar = tqdm(
    range(int(args.max_train_steps)),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)

image_logs = None

# Snapshot of the embedding table taken before training; it is the reference used to
# restore the rows that must not be updated.
token_embedding = accelerator.unwrap_model(text_encoder).get_input_embeddings()
orig_embeds_params = token_embedding.weight.data.clone()
# Epoch loop; it starts at `first_epoch`, which is non-zero only after resuming from a checkpoint.
for _ in range(first_epoch, args.num_train_epochs):
# Running loss for logging; it is reset each time an optimizer step has been logged.
train_loss = 0.0
for batch in train_dataloader:
# Gradient accumulation is delegated to the accelerator.
# NOTE(review): only brushnet is passed here although the text encoder's token embeddings
# are trained as well - confirm this is intended for multi-process accumulation.
with accelerator.accumulate(brushnet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
latents = latents * vae.config.scaling_factor

# we follow the same annotation for mask as
# https://github.com/huggingface/diffusers/blob/v0.30.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
# mask: 1 for masked regions and 0 for known regions
# The mask is resized to the spatial size of the latents rather than a hard-coded (64, 64):
# it is concatenated with the VAE-encoded masked image below, so both must share the same
# height/width, and a fixed size only matches one particular training resolution.
mask = torch.nn.functional.interpolate(batch["mask"], size=latents.shape[-2:])

# Keep only the known pixels of the image (mask < 0.5) ...
mask_image = batch["pixel_values"] * (batch["mask"] < 0.5)
# ... and convert the hole value from 0 to -1 due to [-1, 1] range
mask_image = mask_image - batch["mask"]

# Encode the masked image with the same VAE and scaling as the target latents.
mask_image_latents = vae.encode(mask_image.to(weight_dtype)).latent_dist.sample()
mask_image_latents = (mask_image_latents * vae.config.scaling_factor).to(weight_dtype)

# BrushNet conditioning: the mask channel(s) followed by the masked-image latents.
conditioning_latents = torch.concat([mask, mask_image_latents], 1)
# Draw the noise that will be added to the latents.
noise = torch.randn_like(latents)
bsz = latents.shape[0]

# One random diffusion timestep per image in the batch.
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
).long()

# Forward diffusion process: noise the latents according to each sample's timestep.
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# Text embeddings: `input_ids` conditions the unet, while `input_idsA` / `input_idsB`
# are blended for brushnet; each is (bs, 77, 768) per the original annotation.
encoder_hidden_states_unet, encoder_hidden_statesA, encoder_hidden_statesB = (
    text_encoder(batch[ids_key], return_dict=False)[0]
    for ids_key in ("input_ids", "input_idsA", "input_idsB")
)

# tradeoff between two text embeddings (bs, 2, 1)
tradeoff = batch["tradeoff"].unsqueeze(-1)
weight_a = tradeoff[:, 0:1, :]
weight_b = tradeoff[:, 1:, :]
# The B embedding is detached, so only the A branch passes gradients to the text encoder.
encoder_hidden_states_brushnet = weight_a * encoder_hidden_statesA + weight_b * encoder_hidden_statesB.detach()
# BrushNet forward pass: yields residual features for the down, mid and up blocks.
brushnet_outputs = brushnet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states_brushnet.to(weight_dtype),
    brushnet_cond=conditioning_latents,
    return_dict=False,
)
down_block_res_samples, mid_block_res_sample, up_block_res_samples = brushnet_outputs

# The residuals are cast to `weight_dtype` before being handed to the frozen unet.
unet_residuals = dict(
    down_block_add_samples=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
    mid_block_add_sample=mid_block_res_sample.to(dtype=weight_dtype),
    up_block_add_samples=[sample.to(dtype=weight_dtype) for sample in up_block_res_samples],
)

# Predict the noise residual; the unet's text embedding is detached from the graph.
model_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states_unet.detach().to(weight_dtype),
    **unet_residuals,
    return_dict=False,
)[0]
# The regression target depends on what the scheduler is configured to predict.
prediction_type = noise_scheduler.config.prediction_type
if prediction_type == "epsilon":
    target = noise
elif prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    raise ValueError(f"Unknown prediction type {prediction_type}")

if args.snr_gamma is not None:
    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
    # This is discussed in Section 4.2 of the same paper.
    snr = compute_snr(noise_scheduler, timesteps)
    gamma_cap = args.snr_gamma * torch.ones_like(timesteps)
    mse_loss_weights = torch.stack([snr, gamma_cap], dim=1).min(dim=1)[0]
    if prediction_type == "epsilon":
        mse_loss_weights = mse_loss_weights / snr
    elif prediction_type == "v_prediction":
        mse_loss_weights = mse_loss_weights / (snr + 1)

    # Per-sample MSE (mean over all non-batch dims), weighted, then averaged over the batch.
    elementwise_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    non_batch_dims = list(range(1, len(elementwise_loss.shape)))
    per_sample_loss = elementwise_loss.mean(dim=non_batch_dims) * mse_loss_weights
    loss = per_sample_loss.mean()
else:
    # Plain unweighted MSE.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate; gradients are clipped only when the accelerator reports that they are
# being synchronized for this step.
accelerator.backward(loss)
if accelerator.sync_gradients:
    token_embedding = accelerator.unwrap_model(text_encoder).get_input_embeddings()
    params_to_clip = [*brushnet.parameters(), *token_embedding.parameters()]
    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

# Let's make sure we don't update any embedding weights besides the newly added token:
# every row outside the placeholder-token id range is reset to its original value.
first_placeholder_id = min(placeholder_token_ids)
last_placeholder_id = max(placeholder_token_ids)
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[first_placeholder_id : last_placeholder_id + 1] = False
with torch.no_grad():
    embedding_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    embedding_weight[index_no_updates] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes.
# When it has: advance the step counters, log the accumulated loss, and (on the main
# process) handle periodic checkpointing and validation.
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# `train_loss` holds the loss accumulated since the last logged step; reset it afterwards.
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if accelerator.is_main_process:
# Save a checkpoint every `checkpointing_steps` optimizer steps.
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
# Existing checkpoints, ordered by the step number encoded in `checkpoint-<step>`.
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
# Remove the oldest ones so that the new checkpoint fits within the limit.
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
# The directory name carries the step; the resume logic parses it back from the name.
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
# Run validation every `validation_steps` steps, but only when `args` has a
# `validation_data` attribute.
if hasattr(args, "validation_data") and global_step % args.validation_steps == 0:
image_logs = log_validation(
tokenizer,
text_encoder,
brushnet,
args,
accelerator,
weight_dtype,
global_step,
)