From bacab3620c2378350ea313ff1fbba8400f411725 Mon Sep 17 00:00:00 2001 From: Jasmine Collins Date: Mon, 21 Aug 2023 10:37:43 -0700 Subject: [PATCH] Add partial SDXL model (#61) * add sdxl unet * fix stochastic failures in streaming datasets * add some debug logging * unpin some reqs * add yamls * remove debug prints * allow passing vae model path * add base * remove trailing whitespace * split sdxl into separate model * remove local yamls * clean up sd2 doc * one more doc fix * add NotImplementedError, fix docs --- diffusion/models/models.py | 129 +++++++++++++++++++++++++++++++++++-- 1 file changed, 123 insertions(+), 6 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 6990710f..e2eed1a8 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -44,8 +44,8 @@ def stable_diffusion_2( prompts. Args: - model_name (str, optional): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'. - pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + model_name (str): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'. + pretrained (bool): Whether to load pretrained weights. Defaults to True. prediction_type (str): The type of prediction to use. Must be one of 'sample', 'epsilon', or 'v_prediction'. Default: `epsilon`. train_metrics (list, optional): List of metrics to compute during training. If None, defaults to @@ -54,12 +54,12 @@ def stable_diffusion_2( [MeanSquaredError(), FrechetInceptionDistance(normalize=True)]. val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to [1.0, 3.0, 7.0]. - val_seed (int, optional): Seed to use for generating evaluation images. Defaults to 1138. + val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to [(0, 1)]. 
def stable_diffusion_xl(
    model_name: str = 'stabilityai/stable-diffusion-2-base',
    unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0',
    vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix',
    pretrained: bool = True,
    prediction_type: str = 'epsilon',
    train_metrics: Optional[List] = None,
    val_metrics: Optional[List] = None,
    val_guidance_scales: Optional[List] = None,
    val_seed: int = 1138,
    loss_bins: Optional[List] = None,
    precomputed_latents: bool = False,
    encode_latents_in_fp16: bool = True,
    fsdp: bool = True,
):
    """Stable diffusion 2 training setup + SDXL UNet and VAE.

    Requires batches of matched images and text prompts to train. Generates images from text
    prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2.

    Args:
        model_name (str): Name of the model to load. Determines the text encoder, tokenizer,
            and noise scheduler. Defaults to 'stabilityai/stable-diffusion-2-base'.
        unet_model_name (str): Name of the UNet model whose config is used to build the UNet.
            Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
        vae_model_name (str): Name of the VAE model to load. Defaults to
            'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from
            'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16.
        pretrained (bool): Whether to load pretrained UNet weights. Loading pretrained weights is
            not supported yet, so the default value raises ``NotImplementedError``; pass ``False``
            to build a randomly initialized UNet from the SDXL config. Defaults to True.
        prediction_type (str): The type of prediction to use. Must be one of 'sample',
            'epsilon', or 'v_prediction'. Default: `epsilon`.
        train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
            [MeanSquaredError()].
        val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
            [MeanSquaredError(), FrechetInceptionDistance(normalize=True)].
        val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to
            [1.0, 3.0, 7.0].
        val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
        loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to
            [(0, 1)].
        precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
        encode_latents_in_fp16 (bool): Whether to encode latents in fp16. When True, both the VAE and
            the text encoder are loaded with ``torch_dtype=torch.float16``. Defaults to True.
        fsdp (bool): Whether to use FSDP. Defaults to True.

    Returns:
        The assembled ``StableDiffusion`` model. When CUDA is available it is first passed through
        ``DeviceGPU().module_to_device``.

    Raises:
        NotImplementedError: If ``pretrained`` is True.
    """
    # Fail fast: reject the unsupported configuration before building default metrics
    # (FrechetInceptionDistance among them) or touching caller-supplied metric objects.
    if pretrained:
        raise NotImplementedError('Full SDXL pipeline not implemented yet.')

    if train_metrics is None:
        train_metrics = [MeanSquaredError()]
    if val_metrics is None:
        val_metrics = [MeanSquaredError(), FrechetInceptionDistance(normalize=True)]
    if val_guidance_scales is None:
        val_guidance_scales = [1.0, 3.0, 7.0]
    if loss_bins is None:
        loss_bins = [(0, 1)]
    # Fix a bug where CLIPScore requires grad
    for metric in val_metrics:
        if isinstance(metric, CLIPScore):
            metric.requires_grad_(False)

    # get_config_dict returns a (config_dict, unused_kwargs) pair; only the config is needed.
    unet_config, _ = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')
    # Currently not doing micro-conditioning, so set config appropriately
    unet_config['addition_embed_type'] = None
    # NOTE(review): 1024 presumably matches the hidden size of the SD2 text encoder -- confirm.
    unet_config['cross_attention_dim'] = 1024
    unet = UNet2DConditionModel(**unet_config)

    # Prevent fsdp from wrapping up_blocks and down_blocks because the forward pass calls length on these.
    # The individual blocks inside those containers are flagged for wrapping instead.
    unet.up_blocks._fsdp_wrap = False
    unet.down_blocks._fsdp_wrap = False
    for block in unet.up_blocks:
        block._fsdp_wrap = True
    for block in unet.down_blocks:
        block._fsdp_wrap = True

    # The VAE and text encoder share one dtype choice; an empty dict keeps each library default.
    dtype_kwargs = {'torch_dtype': torch.float16} if encode_latents_in_fp16 else {}
    vae = AutoencoderKL.from_pretrained(vae_model_name, **dtype_kwargs)
    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', **dtype_kwargs)

    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
    noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
    # The inference scheduler copies the training scheduler's noise schedule settings, but takes
    # its prediction type from this function's argument.
    scheduler_config = noise_scheduler.config
    inference_noise_scheduler = DDIMScheduler(
        num_train_timesteps=scheduler_config.num_train_timesteps,
        beta_start=scheduler_config.beta_start,
        beta_end=scheduler_config.beta_end,
        beta_schedule=scheduler_config.beta_schedule,
        trained_betas=scheduler_config.trained_betas,
        clip_sample=scheduler_config.clip_sample,
        set_alpha_to_one=scheduler_config.set_alpha_to_one,
        prediction_type=prediction_type,
    )

    model = StableDiffusion(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        noise_scheduler=noise_scheduler,
        inference_noise_scheduler=inference_noise_scheduler,
        prediction_type=prediction_type,
        train_metrics=train_metrics,
        val_metrics=val_metrics,
        val_guidance_scales=val_guidance_scales,
        val_seed=val_seed,
        loss_bins=loss_bins,
        precomputed_latents=precomputed_latents,
        encode_latents_in_fp16=encode_latents_in_fp16,
        fsdp=fsdp,
    )
    if torch.cuda.is_available():
        model = DeviceGPU().module_to_device(model)
        if is_xformers_installed:
            model.unet.enable_xformers_memory_efficient_attention()
            model.vae.enable_xformers_memory_efficient_attention()
    return model