From 71c3014cfb766a2128458e0766bea899ecff1255 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Sun, 17 Sep 2023 20:39:45 +0000 Subject: [PATCH 01/29] Add base trainer for any accelerate model --- src/drlx/trainer/base_accelerate.py | 122 ++++++++++++++++++++++++++++ src/drlx/trainer/ddpo_trainer.py | 94 +-------------------- 2 files changed, 124 insertions(+), 92 deletions(-) create mode 100644 src/drlx/trainer/base_accelerate.py diff --git a/src/drlx/trainer/base_accelerate.py b/src/drlx/trainer/base_accelerate.py new file mode 100644 index 0000000..4381e2f --- /dev/null +++ b/src/drlx/trainer/base_accelerate.py @@ -0,0 +1,122 @@ +from drlx.trainer import BaseTrainer +from drlx.configs import DRLXConfig +from drlx.sampling import Sampler +from drlx.utils import suppress_warnings + +from accelerate import Accelerator +import wandb +import logging +import torch +from diffusers import StableDiffusionPipeline + + +class AcceleratedTrainer(BaseTrainer): + """ + Base class for any trainer using accelerate. Assumes model comes from a pretrained + pipeline + + :param config: DRLX config. Method config can be anything. + :type config: DRLXConfig + """ + def __init__(self, config : DRLXConfig): + super().__init__(config) + # Figure out batch size and accumulation steps + if self.config.train.target_batch is not None: # Just use normal batch_size + self.accum_steps = (self.config.train.target_batch // self.config.train.batch_size) + else: + self.accum_steps = 1 + + self.accelerator = Accelerator( + log_with = config.logging.log_with, + gradient_accumulation_steps = self.accum_steps + ) + + # Disable tokenizer warnings since they clutter the CLI + kw_str = self.config.train.suppress_log_keywords + if kw_str is not None: + for prefix in kw_str.split(","): + suppress_warnings(prefix.strip()) + + self.pipe = None # Store reference to pipeline so that we can use save_pretrained later + self.model = self.setup_model() + self.optimizer = self.setup_optimizer() + self.scheduler = self.setup_scheduler() + + self.sampler = self.model.sampler + self.model, self.optimizer, self.scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.scheduler + ) + + # Setup tracking + + tracker_kwargs = {} + self.use_wandb = not (config.logging.wandb_project is None) + if self.use_wandb: + log = config.logging + tracker_kwargs["wandb"] = { + "name" : log.run_name, + "entity" : log.wandb_entity, + "mode" : "online" + } + + self.accelerator.init_trackers( + project_name = log.wandb_project, + config = config.to_dict(), + init_kwargs = tracker_kwargs + ) + + self.world_size = self.accelerator.state.num_processes + + def setup_model(self): + """ + Set up model from config. 
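+        When config.model.model_path is set, the denoiser is pulled out of a pretrained
+        StableDiffusionPipeline, and the pipeline object is kept on self.pipe so that
+        save_pretrained can re-export it later.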
+ """ + model = self.get_arch(self.config)(self.config.model, sampler = Sampler(self.config.sampler)) + if self.config.model.model_path is not None: + model, pipe = model.from_pretrained_pipeline(StableDiffusionPipeline, self.config.model.model_path) + + self.pipe = pipe + return model + + def save_checkpoint(self, fp : str, components = None): + """ + Save checkpoint in main process + + :param fp: File path to save checkpoint to + """ + if self.accelerator.is_main_process: + os.makedirs(fp, exist_ok = True) + self.accelerator.save_state(output_dir=fp) + self.accelerator.wait_for_everyone() # need to use this twice or a corrupted state is saved + + def save_pretrained(self, fp : str): + """ + Save model into pretrained pipeline so it can be loaded in pipeline later + + :param fp: File path to save to + """ + if self.accelerator.is_main_process: + os.makedirs(fp, exist_ok = True) + unwrapped_model = self.accelerator.unwrap_model(self.model) + self.pipe.unet = unwrapped_model.unet + self.pipe.save_pretrained(fp, safe_serialization = unwrapped_model.config.use_safetensors) + self.accelerator.wait_for_everyone() + + def extract_pipeline(self): + """ + Return original pipeline with finetuned denoiser plugged in + + :return: Diffusers pipeline + """ + + self.pipe.unet = self.accelerator.unwrap_model(self.model).unet + return self.pipe + + def load_checkpoint(self, fp : str): + """ + Load checkpoint + + :param fp: File path to checkpoint to load from + """ + self.accelerator.load_state(fp) + self.accelerator.print("Succesfully loaded checkpoint") \ No newline at end of file diff --git a/src/drlx/trainer/ddpo_trainer.py b/src/drlx/trainer/ddpo_trainer.py index ce9fe29..153bf76 100644 --- a/src/drlx/trainer/ddpo_trainer.py +++ b/src/drlx/trainer/ddpo_trainer.py @@ -3,7 +3,7 @@ from accelerate import Accelerator from drlx.configs import DRLXConfig, DDPOConfig -from drlx.trainer import BaseTrainer +from drlx.trainer.base_accelerate import AcceleratedTrainer from drlx.sampling import DDPOSampler from drlx.utils import suppress_warnings, Timer, PerPromptStatTracker, scoped_seed, save_images @@ -79,7 +79,7 @@ def collate(batch): return DataLoader(self, collate_fn=collate, **kwargs) -class DDPOTrainer(BaseTrainer): +class DDPOTrainer(AcceleratedTrainer): """ DDPO Accelerated Trainer initilization from config. 
During init, sets up model, optimizer, sampler and logging @@ -92,53 +92,6 @@ def __init__(self, config : DRLXConfig): assert isinstance(self.config.method, DDPOConfig), "ERROR: Method config must be DDPO config" - # Figure out batch size and accumulation steps - if self.config.train.target_batch is not None: # Just use normal batch_size - self.accum_steps = (self.config.train.target_batch // self.config.train.batch_size) - else: - self.accum_steps = 1 - - self.accelerator = Accelerator( - log_with = config.logging.log_with, - gradient_accumulation_steps = self.accum_steps - ) - - # Disable tokenizer warnings since they clutter the CLI - kw_str = self.config.train.suppress_log_keywords - if kw_str is not None: - for prefix in kw_str.split(","): - suppress_warnings(prefix.strip()) - - self.pipe = None # Store reference to pipeline so that we can use save_pretrained later - self.model = self.setup_model() - self.optimizer = self.setup_optimizer() - self.scheduler = self.setup_scheduler() - - self.sampler = self.model.sampler - self.model, self.optimizer, self.scheduler = self.accelerator.prepare( - self.model, self.optimizer, self.scheduler - ) - - # Setup tracking - - tracker_kwargs = {} - self.use_wandb = not (config.logging.wandb_project is None) - if self.use_wandb: - log = config.logging - tracker_kwargs["wandb"] = { - "name" : log.run_name, - "entity" : log.wandb_entity, - "mode" : "online" - } - - self.accelerator.init_trackers( - project_name = log.wandb_project, - config = config.to_dict(), - init_kwargs = tracker_kwargs - ) - - self.world_size = self.accelerator.state.num_processes - def setup_model(self): """ Set up model from config. @@ -381,46 +334,3 @@ def time_per_1k(n_samples : int): last_epoch_time = time_per_1k(self.config.train.num_samples_per_epoch) del metrics, dataloader, experience_loader - - def save_checkpoint(self, fp : str, components = None): - """ - Save checkpoint in main process - - :param fp: File path to save checkpoint to - """ - if self.accelerator.is_main_process: - os.makedirs(fp, exist_ok = True) - self.accelerator.save_state(output_dir=fp) - self.accelerator.wait_for_everyone() # need to use this twice or a corrupted state is saved - - def save_pretrained(self, fp : str): - """ - Save model into pretrained pipeline so it can be loaded in pipeline later - - :param fp: File path to save to - """ - if self.accelerator.is_main_process: - os.makedirs(fp, exist_ok = True) - unwrapped_model = self.accelerator.unwrap_model(self.model) - self.pipe.unet = unwrapped_model.unet - self.pipe.save_pretrained(fp, safe_serialization = unwrapped_model.config.use_safetensors) - self.accelerator.wait_for_everyone() - - def extract_pipeline(self): - """ - Return original pipeline with finetuned denoiser plugged in - - :return: Diffusers pipeline - """ - - self.pipe.unet = self.accelerator.unwrap_model(self.model).unet - return self.pipe - - def load_checkpoint(self, fp : str): - """ - Load checkpoint - - :param fp: File path to checkpoint to load from - """ - self.accelerator.load_state(fp) - self.accelerator.print("Succesfully loaded checkpoint") From af358b7529e4985daf42f84d31a6a0072a21ec28 Mon Sep 17 00:00:00 2001 From: shahbuland Matiana Date: Wed, 27 Sep 2023 04:28:06 +0000 Subject: [PATCH 02/29] Add PickaPic pipeline for DPO --- configs/dpo_pickapic.yml | 41 ++++++++++++++++ examples/DPO/download_pickapic_wds.py | 54 ++++++++++++++++++++ examples/DPO/train_pickapic.py | 14 ++++++ src/drlx/pipeline/pickapic_wds.py | 71 +++++++++++++++++++++++++++ 4 files changed, 180 
insertions(+) create mode 100644 configs/dpo_pickapic.yml create mode 100644 examples/DPO/download_pickapic_wds.py create mode 100644 examples/DPO/train_pickapic.py create mode 100644 src/drlx/pipeline/pickapic_wds.py diff --git a/configs/dpo_pickapic.yml b/configs/dpo_pickapic.yml new file mode 100644 index 0000000..7276c9d --- /dev/null +++ b/configs/dpo_pickapic.yml @@ -0,0 +1,41 @@ +method: + name : "DPO" + +model: + model_path: "stabilityai/stable-diffusion-2-1-base" + model_arch_type: "LDMUnet" + attention_slicing: True + xformers_memory_efficient: True + gradient_checkpointing: True + +sampler: + guidance_scale: 7.5 + num_inference_steps: 50 + +optimizer: + name: "adamw" + kwargs: + lr: 1.0e-5 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + +scheduler: + name: "linear" # Name of learning rate scheduler + kwargs: + start_factor: 1.0 + end_factor: 1.0 + +logging: + run_name: 'dpo_pickapic' + #wandb_entity: None + #wandb_project: None + +train: + num_epochs: 500 + num_samples_per_epoch: 256 + batch_size: 4 + sample_batch_size: 32 + grad_clip: 1.0 + checkpoint_interval: 50 + tf32: True + suppress_log_keywords: "diffusers.pipelines,transformers" \ No newline at end of file diff --git a/examples/DPO/download_pickapic_wds.py b/examples/DPO/download_pickapic_wds.py new file mode 100644 index 0000000..820c06b --- /dev/null +++ b/examples/DPO/download_pickapic_wds.py @@ -0,0 +1,54 @@ +from datasets import load_dataset +import requests +import os +from tqdm import tqdm +import tarfile +from multiprocessing import Pool, cpu_count + +""" +This script takes the filtered version of the PickAPic prompt dataset +and downloads the associated images, then tars them. This tar file can then +be moved to S3 or loaded directly if needed. Number of samples can be specified +""" + +n_samples = 1000 +data_root = "./pickapic_sample" +url = "CarperAI/pickapic_v1_no_images_training_sfw" +n_cpus = cpu_count() # Detect the number of CPUs + +base_name = os.path.basename(data_root).replace('.', '').replace('/', '') + +def make_tarfile(output_filename, source_dir): + with tarfile.open(output_filename, "w") as tar: + tar.add(source_dir, arcname=os.path.basename(source_dir)) + +def download_image(args): + url, filename = args + response = requests.get(url) + with open(filename, 'wb') as f: + f.write(response.content) + +if __name__ == "__main__": + ds = load_dataset("CarperAI/pickapic_v1_no_images_training_sfw")['train'] + os.makedirs(data_root, exist_ok = True) + + id_counter = 0 + with Pool(n_cpus) as p: + for row in tqdm(ds, total = n_samples): + if id_counter >= n_samples: + break + if row['has_label']: + id_str = str(id_counter).zfill(8) + with open(os.path.join(data_root, f'{id_str}.prompt.txt'), 'w', encoding='utf-8') as f: + # Ensure the caption is in UTF-8 format + caption = row['caption'].encode('utf-8').decode('utf-8') + f.write(caption) + if row['label_0']: + p.map(download_image, [(row['image_0_url'], os.path.join(data_root, f'{id_str}.chosen.png')), + (row['image_1_url'], os.path.join(data_root, f'{id_str}.rejected.png'))]) + else: + p.map(download_image, [(row['image_1_url'], os.path.join(data_root, f'{id_str}.chosen.png')), + (row['image_0_url'], os.path.join(data_root, f'{id_str}.rejected.png'))]) + id_counter += 1 + + make_tarfile(f"{base_name}.tar", data_root) \ No newline at end of file diff --git a/examples/DPO/train_pickapic.py b/examples/DPO/train_pickapic.py new file mode 100644 index 0000000..be60b94 --- /dev/null +++ b/examples/DPO/train_pickapic.py @@ -0,0 +1,14 @@ +import sys 
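+# NOTE: this example assumes it is run from the repository root, so that the
+# "./src" path appended below resolves to the drlx package.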
+sys.path.append("./src") + +from drlx.pipeline.pickapic_wds import PickAPicPipeline +from drlx.trainer.dpo_trainer import DPOTrainer +from drlx.configs import DRLXConfig + +pipe = PickAPicPipeline() +resume = False + +config = DRLXConfig.load_yaml("configs/dpo_pickapic.yml") +trainer = DPOTrainer(config) + +trainer.train(pipe) \ No newline at end of file diff --git a/src/drlx/pipeline/pickapic_wds.py b/src/drlx/pipeline/pickapic_wds.py new file mode 100644 index 0000000..919e86f --- /dev/null +++ b/src/drlx/pipeline/pickapic_wds.py @@ -0,0 +1,71 @@ +import webdataset as wds +import io +from PIL import Image + +from drlx.pipeline import Pipeline + +# Utility function for processing bytes being streamed from WDS +# Returns as dictionary with: +# list of chosen images (PIL format) +# +def wds_initial_collate(sample): + """ + Initial function to call in a collate function to map + list of dictionaries into list of data elements for + further processing + + :return: A dictionary that contains: + - A list of chosen images (PIL images) + - A list of rejected images (PIL images) + - A list of prompts + """ + + result = { + "chosen" : [], + "rejected" : [], + "prompt" : [] + } + for d in sample: + result['chosen'].append( + Image.open(io.BytesIO(d['chosen.png'])) + ) + result['rejected'].append( + Image.open(io.BytesIO(d['rejected.png'])) + ) + result['prompt'].append( + d['prompt.txt'].decode('utf-8') + ) + + return result + +class PickAPicPipeline(Pipeline): + """ + Pipeline that uses webdataset to load pickapic with images and prompts + + :param url: URL/path to tar file for WDS + """ + def __init__(self, url : str, *args): + super().__init__(args) + + self.dataset = wds.WebDataset(url) + + def collate(self, data): + data = wds_initial_collate(data) + return data + + def create_loader(self, **kwargs): + if self.prep is None: + raise ValueError("Preprocessing function must be set before creating a dataloader.") + + if 'shuffle' in kwargs: + if kwargs['shuffle'] and 'generator' not in kwargs: + generator = torch.Generator() + generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) + kwargs['generator'] = generator + + return DataLoader(self.dataset, collate_fn = self.collate) + + + + + From b1406232e9dd0e4c8223be90d243fddc8609fb23 Mon Sep 17 00:00:00 2001 From: shahbuland Matiana Date: Wed, 27 Sep 2023 04:38:35 +0000 Subject: [PATCH 03/29] add skeleton for DPO trainer --- src/drlx/trainer/dpo_trainer.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 src/drlx/trainer/dpo_trainer.py diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py new file mode 100644 index 0000000..eca17c7 --- /dev/null +++ b/src/drlx/trainer/dpo_trainer.py @@ -0,0 +1,26 @@ +from drlx.trainer.base_accelerate import AcceleratedTrainer +from drlx.configs import DRLXConfig, DPOConfig + +class DPOTrainer(AcceleratedTrainer): + """ + DPO Accelerated Trainer initilization from config. During init, sets up model, optimizer, sampler and logging + + :param config: DRLX config + :type config: DRLXConfig + """ + + def __init__(self, config : DRLXConfig): + super().__init__(config) + + assert isinstance(self.config.method, DPOConfig), "ERROR: Method config must be DDPO config" + + def train(self, pipeline): + """ + Trains model based on config parameters. 
Needs to be passed a pipeline that + supplies chosen images, rejeceted images and prompts + + :param pipeline: Pipeline to draw images and prompts from + :type Pipline: Pipeline + """ + pass + From 704a85ced4ec6c262a6b64a444f67599a6b3794a Mon Sep 17 00:00:00 2001 From: Shahbuland Matiana Date: Tue, 23 Jan 2024 23:19:45 +0000 Subject: [PATCH 04/29] Pipeline for DPO --- src/drlx/pipeline/dpo_pipeline.py | 30 +++++++++++++++ src/drlx/pipeline/pickapic_dpo.py | 64 +++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 src/drlx/pipeline/dpo_pipeline.py create mode 100644 src/drlx/pipeline/pickapic_dpo.py diff --git a/src/drlx/pipeline/dpo_pipeline.py b/src/drlx/pipeline/dpo_pipeline.py new file mode 100644 index 0000000..95dc880 --- /dev/null +++ b/src/drlx/pipeline/dpo_pipeline.py @@ -0,0 +1,30 @@ +from abc import abstractmethod +from typing import Tuple, Callable + +from PIL import Image + +from drlx.pipeline import Pipeline + +class DPOPipeline(Pipeline): + """ + Pipeline for training with DPO. Returns prompts, chosen images, and rejected images + """ + def __init__(self, *args): + super().__init__(*args) + + @abstractmethod + def __getitem__(self, index : int) -> Tuple[str, Image.Image, Image.Image]: + pass + + def make_default_collate(self, prep : Callable): + def collate(batch : Iterable[Tuple[str, Image.Image, Image.Image]]): + prompts = [d[0] for d in batch] + chosen = [d[1] for d in batch] + rejected = [d[2] for d in batch] + + return prep(prompts, chosen, rejected) + + return collate + + + diff --git a/src/drlx/pipeline/pickapic_dpo.py b/src/drlx/pipeline/pickapic_dpo.py new file mode 100644 index 0000000..2891b24 --- /dev/null +++ b/src/drlx/pipeline/pickapic_dpo.py @@ -0,0 +1,64 @@ +from datasets import load_dataset +import io + +from drlx.pipeline.dpo_pipeline import DPOPipeline + +import torch +from torchvision import transforms +from torch.utils.data import Dataset, DataLoader + +def convert_bytes_to_image(image_bytes, id): + try: + image = Image.open(io.BytesIO(image_bytes)) + image = image.resize((512, 512)) + return image + except Exception as e: + print(f"An error occurred: {e}") + +def create_train_dataset(): + ds = load_dataset("yuvalkirstain/pickapic_v2",split='train', streaming=True + ds = ds.filter(lambda example: example['has_label'] == True and example['label_0'] != 0.5) + return ds + +class Collator: + def __call__(self, batch): + # Batch is list of rows which are dicts + image_0_bytes = [b['jpg_0'] for b in batch] + image_1_bytes = [b['jpg_1'] for b in batch] + uid_0 = [b['image_0_uid'] for b in batch] + uid_1 = [b['image_1_uid'] for b in batch] + + label_0s = [b['label_0'] for b in batch] + + for i in range(len(batch)): + if not label_0s[i]: # label_1 is 1 => jpg_1 is the chosen one + image_0_bytes[i], image_1_bytes[i] = image_1_bytes[i], image_0_bytes[i] + # Swap so image_0 is always the chosen one + + prompts = [b['caption'] for b in batch] + + images_0 = [convert_bytes_to_image(i, id) for (i, id) in zip(image_0_bytes, uid_0)] + images_1 = [convert_bytes_to_image(i, id) for (i, id) in zip(image_1_bytes, uid_1)] + + images_0 = torch.stack([transforms.ToTensor()(image) for image in images_0]) + images_0 = images_0 * 2 - 1 + + images_1 = torch.stack([transforms.ToTensor()(image) for image in images_1]) + images_1 = images_1 * 2 - 1 + + return { + "chosen_pixel_values" : image_0, + "rejected_pixel_values" : image_1, + "prompts" : prompts + } + +class PickAPicDPOPipeline(DPOPipeline): + """ + Pipeline for training LDM with DPO + """ + def 
__init__(self): + self.train_ds = create_train_dataset() + self.dc = Collator() + + def create_loader(**kwargs): + return DataLoader(self.train_ds, collate_fn = self.dc, **kwargs) \ No newline at end of file From 68ec7891593d2fb2a10b65d884b7b0bccc31480e Mon Sep 17 00:00:00 2001 From: Shahbuland Matiana Date: Tue, 23 Jan 2024 23:20:08 +0000 Subject: [PATCH 05/29] Allow for list of images instead of just list of np arrays for sample drawing --- src/drlx/utils/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/drlx/utils/__init__.py b/src/drlx/utils/__init__.py index 49e776d..1dd0244 100644 --- a/src/drlx/utils/__init__.py +++ b/src/drlx/utils/__init__.py @@ -223,7 +223,12 @@ def save_images(images : np.array, fp : str): os.makedirs(fp, exist_ok = True) - images = [Image.fromarray(image) for image in images] + if isinstance(images, np.ndarray): + images = [Image.fromarray(image) for image in images] + elif isinstance(images, list) and all(isinstance(i, Image.Image) for i in images): + pass + else: + raise ValueError("Images should be either a numpy array or a list of PIL Images") for i, image in enumerate(images): image.save(os.path.join(fp,f"{i}.png")) From e70e537b35106e0e2b05ffa44ea6cebe6f2f5ff5 Mon Sep 17 00:00:00 2001 From: Shahbuland Matiana Date: Tue, 23 Jan 2024 23:20:29 +0000 Subject: [PATCH 06/29] Add sampler for DPO --- src/drlx/sampling/__init__.py | 103 +++++++++++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 1 deletion(-) diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index e7c5593..6a13059 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -5,6 +5,7 @@ from tqdm import tqdm import math import einops as eo +import torch.nn.functional as F from drlx.utils import rescale_noise_cfg @@ -287,4 +288,104 @@ def compute_loss( if accelerator is not None: metrics = accelerator.reduce(metrics, 'mean') - return metrics \ No newline at end of file + return metrics + +class DPOSampler(Sampler): + def compute_loss( + self, + prompts, + chosen_img, + rejected_img, + denoiser, + ref_denoiser, + vae, + device, + method_config, + accelerator = None + ): + if accelerator is None: + dn_unwrapped = accelerator.unwrap_model(denoiser) + else: + dn_unwrapped = denoiser + + scheduler = dn_unwrapped.scheduler + preprocess = dn_unwrapped.preprocess + beta = method_config.beta + + # Text and image preprocessing + with torch.no_grad(): + text_embeds = preprocess( + prompts, mode = "embeds", device = device, + num_images_per_prompt = 1, + do_classifier_free_guidance = self.config.guidance_scale > 1.0 + ).detach() + + chosen_latent = vae.encode(chosen_img).latent_dist.sample() + rejected_latent = vae.encode(rejected_img).latent_dist.sample() + + # sample random ts + timesteps = torch.randint( + 0, self.config.num_inference_steps, (len(chosen_pred),), device = device, dtype = torch.long + ) + + # One step of noising to samples + noise = torch.randn_like(chosen_latent) # [B, C, H, W] + noisy_chosen = scheduler.add_noise(chosen_latent, noise, timesteps) + noisy_rejected = scheduler.add_noise(rejected_latent, noise, timesteps) + + def predict(model, pixel_values): + return model( + pixel_values = pixel_values, + time_step = timesteps, + text_embeds = text_embeds + ) + + # Stack predictions for chosen and rejected samples + chosen_out, rejected_out = predict(denoiser, noisy_chosen), predict(denoiser, noisy_rejected) + with torch.no_grad(): + chosen_ref, rejected_ref = predict(ref_denoiser, noisy_chosen), 
predict(ref_denoiser, noisy_rejected) + + if scheduler.config.prediction_type == "epsilon": + chosen_target, rejected_target = noise, noise + elif scheduler.config.prediction_type == "v_prediction": + chosen_target = scheduler.get_velocity( + chosen_latent, + nosie, + timesteps + ) + rejected_target = scheduler.get_velocity( + rejected_latent, + noise, + timesteps + ) + + # Basic Diffusion Loss + mse_chosen = F.mse_loss(chosen_out, chosen_target, reduction = "mean") + mse_rejected = F.mse_loss(rejected_out, rejected_target, reduction = "mean") + base_loss = 0.5 * (mse_chosen.mean() + mse_rejected.mean()) # logging + model_diff = mse_chosen - mse_rejected + + # Reference model + with torch.no_grad(): + ref_mse_chosen = F.mse_loss(chosen_ref, chosen_target, reduction = "mean") + ref_mse_rejected = F.mse_loss(rejeceted_ref, rejeceted_target, reduction = "mean") + ref_diff = ref_mse_chosen - ref_mse_rejected + + # DPO Objective + surr_loss = (-beta/2) * (model_diff - ref_diff) + loss = -1 * F.logsigmoid(surr_loss.mean()) + + # Get approx accuracy as models probability of giving chosen over rejected + acc = (surr_loss > 0).sum().float() / len(surr_loss) + acc += 0.5 * (surr_loss == 0).sum().float() / len(surr_loss) # 50% for when both match + + if accelerator is None: + loss.backward() + else: + accelerator.backward(loss) + + return { + "loss" : loss.item(), + "diffusion_loss" : base_loss.item(), + "accuracy" : acc.item() + } \ No newline at end of file From 2202d9f224c54f29db12c2353295df5524776fe6 Mon Sep 17 00:00:00 2001 From: Shahbuland Matiana Date: Tue, 23 Jan 2024 23:20:53 +0000 Subject: [PATCH 07/29] Add method config for DPO --- src/drlx/configs.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/drlx/configs.py b/src/drlx/configs.py index 3aa6d29..4393144 100644 --- a/src/drlx/configs.py +++ b/src/drlx/configs.py @@ -92,6 +92,18 @@ class DDPOConfig(MethodConfig): buffer_size: int = 32 # Set to None to avoid using per prompt stat tracker min_count: int = 16 +@register_method("DPO") +@dataclass +class DPOConfig(MethodConfig): + """ + Config for DPO-related hyperparams + + :param beta: Deviation from initial model + :type beta: float + """ + name : str = "DPO" + beta : float = 0.9 + @dataclass class TrainConfig(ConfigClass): """ From 9304bac74d8e3f67ed6b383bb4f5d2db0bbdc92c Mon Sep 17 00:00:00 2001 From: Shahbuland Matiana Date: Tue, 23 Jan 2024 23:21:06 +0000 Subject: [PATCH 08/29] Add DPO trainer initial version --- src/drlx/trainer/dpo_trainer.py | 171 ++++++++++++++++++++++++++++++-- 1 file changed, 163 insertions(+), 8 deletions(-) diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index eca17c7..d2d42a1 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -1,9 +1,29 @@ -from drlx.trainer.base_accelerate import AcceleratedTrainer +from torchtyping import TensorType +from typing import Iterable, Tuple, Callable + +from accelerate import Accelerator from drlx.configs import DRLXConfig, DPOConfig +from drlx.trainer.base_accelerate import AcceleratedTrainer +from drlx.sampling import DPOSampler +from drlx.utils import suppress_warnings, Timer, scoped_seed, save_images + +import torch +import einops as eo +import os +import gc +import logging +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +import numpy as np +import wandb +import accelerate.utils +from PIL import Image + +from diffusers import StableDiffusionPipeline class DPOTrainer(AcceleratedTrainer): """ - DPO Accelerated Trainer 
initilization from config. During init, sets up model, optimizer, sampler and logging
+    DPO Accelerated Trainer initialization from config. During init, sets up model, optimizer, sampler and logging

     :param config: DRLX config
     :type config: DRLXConfig
     """

     def __init__(self, config : DRLXConfig):
         super().__init__(config)

-        assert isinstance(self.config.method, DPOConfig), "ERROR: Method config must be DDPO config"
+        assert isinstance(self.config.method, DPOConfig), "ERROR: Method config must be DPO config"

+    def setup_model(self):
+        """
+        Set up model from config.
+        """
+        model = self.get_arch(self.config)(self.config.model, sampler = DPOSampler(self.config.sampler))
+        if self.config.model.model_path is not None:
+            model, pipe = model.from_pretrained_pipeline(StableDiffusionPipeline, self.config.model.model_path)
+
+        self.pipe = pipe
+        return model
+
+    def loss(
+        self,
+        prompts, chosen_img, rejected_img, ref_denoiser
+    ):
+        """
+        Get loss for training
+
+        :param prompts: Prompts for the preference pairs
+        :param chosen_img: Pixel values of the preferred images
+        :param rejected_img: Pixel values of the rejected images
+        :param ref_denoiser: Frozen reference denoiser for the DPO comparison
+        """
+        return self.sampler.compute_loss(
+            prompts=prompts, chosen_img=chosen_img, rejected_img=rejected_img,
+            denoiser=self.model, ref_denoiser=ref_denoiser, vae=self.model.vae
+            device=self.accelerator.device,
+            method_config=self.config.method,
+            accelerator=self.accelerator
+        )
+
+    @torch.no_grad()
+    def deterministic_sample(self, prompts):
+        """
+        Sample images deterministically. Utility for visualizing changes for fixed prompts through training.
+        """
+        gen = torch.Generator(device=self.pipe.device).manual_seed(self.config.train.seed)
+        self.pipe.unet = self.model.unet
+        return self.pipe(prompts, generator = gen).images
+
     def train(self, pipeline):
         """
+        Trains the model based on config parameters. Needs to be passed a preference pipeline that supplies prompts with chosen and rejected images.
- :param pipeline: Pipeline to draw images and prompts from - :type Pipline: Pipeline + :param pipeline: Pipeline to draw tuples from with prompts + :type prompt_pipeline: DPOPipeline """ - pass + + # === SETUP === + + # Singular dataloader made to get a sample of prompts + # This sample batch is dependent on config seed so it can be same across runs + with scoped_seed(self.config.train.seed): + dataloader = self.accelerator.prepare( + pipeline.create_loader(batch_size = self.config.train.batch_size, shuffle = False) + ) + sample_prompts = self.config.train.sample_prompts + if sample_prompts is None: + sample_prompts = [] + if len(sample_prompts) < self.config.train.batch_size: + new_sample_prompts = next(iter(dataloader))["prompts"] + sample_prompts += new_sample_prompts + sample_prompts = sample_prompts[:self.config.train.batch_size] + + # Now make main dataloader + + assert isinstance(self.sampler, DPOSampler), "Error: Model Sampler for DPO training must be DPO sampler" + + if isinstance(reward_fn, torch.nn.Module): + reward_fn = self.accelerator.prepare(reward_fn) + + # Set the epoch count + epochs = self.config.train.num_epochs + if self.config.train.total_samples is not None: + epochs = int(self.config.train.total_samples // self.config.train.num_samples_per_epoch) + + # Timer to measure time per 1k images (as metric) + timer = Timer() + def time_per_1k(n_samples : int): + total_time = timer.hit() + return total_time * 1000 / n_samples + last_batch_time = timer.hit() + + # Ref model + ref_model = self.setup_model() + + # === MAIN TRAINING LOOP === + + mean_rewards = [] + accum = 0 + last_epoch_time = timer.hit() + for epoch in range(epochs): + dataloader = pipeline.create_loader(batch_size = self.config.train.batch_size, shuffle = True) + dataloader = self.accelerator.prepare(dataloader) + + # Clean up unused resources + self.accelerator._dataloaders = [] # Clear dataloaders + gc.collect() + torch.cuda.empty_cache() + + self.accelerator.print(f"Epoch {epoch}/{epochs}.") + + for batch in dataloader: + metrics = self.compute_loss( + prompts = batch['prompts'], + chosen_img = batch['chosen_pixel_values'], + rejected_img = batch['rejected_pixel_values'], + ref_model + ) + + self.accelerator.wait_for_everyone() + # Generate the sample prompts + self.pipe.unet = self.model.unet + with torch.no_grad(): + with scoped_seed(self.config.train.seed): + sample_imgs = self.deterministic_sample(sample_prompts) + sample_imgs = [wandb.Image(img, caption = prompt) for (img, prompt) in zip(sample_imgs, sample_prompts)] + + # Logging + if self.use_wandb: + self.accelerator.log({ + "base_loss" : metrics["diffusion_loss"], + "accuracy" : metrics["accuracy"], + "dpo_loss" : metrics["loss"], + "time_per_1k" : last_batch_time, + "img_sample" : sample_imgs + }) + # save images + if self.accelerator.is_main_process and self.config.train.save_samples: + save_images(sample_imgs, f"./samples/{self.config.logging.run_name}/{epoch}") + + + + # Save model every [interval] epochs + accum += 1 + if accum % self.config.train.checkpoint_interval == 0 and self.config.train.checkpoint_interval > 0: + self.accelerator.print("Saving...") + base_path = f"./checkpoints/{self.config.logging.run_name}" + output_path = f"./output/{self.config.logging.run_name}" + self.accelerator.wait_for_everyone() + self.save_checkpoint(f"{base_path}/{accum}") + self.save_pretrained(output_path) + + last_epoch_time = time_per_1k(self.config.train.num_samples_per_epoch) + + del metrics + del dataloader From 
1752621028a0eeccee33a8b19a026485d86c3608 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Wed, 24 Jan 2024 03:20:40 +0000 Subject: [PATCH 09/29] basic debugs --- src/drlx/trainer/dpo_trainer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index d2d42a1..ca96d22 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -56,7 +56,7 @@ def loss( """ return self.sampler.compute_loss( prompts=prompts, chosen_img=chosen_img, rejected_img=rejected_img, - denoiser=self.model, ref_denoiser=ref_denoiser, vae=self.model.vae + denoiser=self.model, ref_denoiser=ref_denoiser, vae=self.model.vae, device=self.accelerator.device, method_config=self.config.method, accelerator=self.accelerator @@ -134,14 +134,21 @@ def time_per_1k(n_samples : int): self.accelerator.print(f"Epoch {epoch}/{epochs}.") for batch in dataloader: - metrics = self.compute_loss( + metrics = self.loss( prompts = batch['prompts'], chosen_img = batch['chosen_pixel_values'], rejected_img = batch['rejected_pixel_values'], - ref_model + ref_denoiser = ref_model ) self.accelerator.wait_for_everyone() + + # Optimizer step + self.accelerator.clip_grad_norm_(self.model.parameters(), self.config.train.grad_clip) + self.optimizer.step() + self.scheduler.step() + self.optimizer.zero_grad() + # Generate the sample prompts self.pipe.unet = self.model.unet with torch.no_grad(): From b11e873a0a22af4c6371c896ad9b7cbb3ad32d46 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Wed, 24 Jan 2024 03:21:12 +0000 Subject: [PATCH 10/29] Remove streaming --- src/drlx/pipeline/pickapic_dpo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/drlx/pipeline/pickapic_dpo.py b/src/drlx/pipeline/pickapic_dpo.py index 2891b24..f032c23 100644 --- a/src/drlx/pipeline/pickapic_dpo.py +++ b/src/drlx/pipeline/pickapic_dpo.py @@ -6,6 +6,7 @@ import torch from torchvision import transforms from torch.utils.data import Dataset, DataLoader +from PIL import Image def convert_bytes_to_image(image_bytes, id): try: @@ -16,7 +17,7 @@ def convert_bytes_to_image(image_bytes, id): print(f"An error occurred: {e}") def create_train_dataset(): - ds = load_dataset("yuvalkirstain/pickapic_v2",split='train', streaming=True + ds = load_dataset("yuvalkirstain/pickapic_v2",split='train') ds = ds.filter(lambda example: example['has_label'] == True and example['label_0'] != 0.5) return ds @@ -47,8 +48,8 @@ def __call__(self, batch): images_1 = images_1 * 2 - 1 return { - "chosen_pixel_values" : image_0, - "rejected_pixel_values" : image_1, + "chosen_pixel_values" : images_0, + "rejected_pixel_values" : images_1, "prompts" : prompts } @@ -60,5 +61,5 @@ def __init__(self): self.train_ds = create_train_dataset() self.dc = Collator() - def create_loader(**kwargs): + def create_loader(self, **kwargs): return DataLoader(self.train_ds, collate_fn = self.dc, **kwargs) \ No newline at end of file From 9201d6aca04e41fc2bca3e70ad17d5560483d4fb Mon Sep 17 00:00:00 2001 From: shahbuland Date: Wed, 24 Jan 2024 21:06:38 +0000 Subject: [PATCH 11/29] minor bug fixes --- configs/dpo_pickapic.yml | 2 +- src/drlx/pipeline/pickapic_wds.py | 71 ------------------------------- src/drlx/sampling/__init__.py | 18 +++++--- src/drlx/trainer/dpo_trainer.py | 4 +- 4 files changed, 14 insertions(+), 81 deletions(-) delete mode 100644 src/drlx/pipeline/pickapic_wds.py diff --git a/configs/dpo_pickapic.yml b/configs/dpo_pickapic.yml index 7276c9d..c11b02e 100644 --- 
a/configs/dpo_pickapic.yml +++ b/configs/dpo_pickapic.yml @@ -5,7 +5,7 @@ model: model_path: "stabilityai/stable-diffusion-2-1-base" model_arch_type: "LDMUnet" attention_slicing: True - xformers_memory_efficient: True + xformers_memory_efficient: False gradient_checkpointing: True sampler: diff --git a/src/drlx/pipeline/pickapic_wds.py b/src/drlx/pipeline/pickapic_wds.py deleted file mode 100644 index 919e86f..0000000 --- a/src/drlx/pipeline/pickapic_wds.py +++ /dev/null @@ -1,71 +0,0 @@ -import webdataset as wds -import io -from PIL import Image - -from drlx.pipeline import Pipeline - -# Utility function for processing bytes being streamed from WDS -# Returns as dictionary with: -# list of chosen images (PIL format) -# -def wds_initial_collate(sample): - """ - Initial function to call in a collate function to map - list of dictionaries into list of data elements for - further processing - - :return: A dictionary that contains: - - A list of chosen images (PIL images) - - A list of rejected images (PIL images) - - A list of prompts - """ - - result = { - "chosen" : [], - "rejected" : [], - "prompt" : [] - } - for d in sample: - result['chosen'].append( - Image.open(io.BytesIO(d['chosen.png'])) - ) - result['rejected'].append( - Image.open(io.BytesIO(d['rejected.png'])) - ) - result['prompt'].append( - d['prompt.txt'].decode('utf-8') - ) - - return result - -class PickAPicPipeline(Pipeline): - """ - Pipeline that uses webdataset to load pickapic with images and prompts - - :param url: URL/path to tar file for WDS - """ - def __init__(self, url : str, *args): - super().__init__(args) - - self.dataset = wds.WebDataset(url) - - def collate(self, data): - data = wds_initial_collate(data) - return data - - def create_loader(self, **kwargs): - if self.prep is None: - raise ValueError("Preprocessing function must be set before creating a dataloader.") - - if 'shuffle' in kwargs: - if kwargs['shuffle'] and 'generator' not in kwargs: - generator = torch.Generator() - generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) - kwargs['generator'] = generator - - return DataLoader(self.dataset, collate_fn = self.collate) - - - - - diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index 6a13059..6e1bea3 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -297,12 +297,16 @@ def compute_loss( chosen_img, rejected_img, denoiser, - ref_denoiser, vae, device, method_config, - accelerator = None + accelerator = None, + ref_denoiser = none ): + """ + Compute metrics and do backwards pass on loss. Assumes LoRA if reference is not given. 
+ """ + do_lora = ref_denoiser is None if accelerator is None: dn_unwrapped = accelerator.unwrap_model(denoiser) else: @@ -325,7 +329,7 @@ def compute_loss( # sample random ts timesteps = torch.randint( - 0, self.config.num_inference_steps, (len(chosen_pred),), device = device, dtype = torch.long + 0, self.config.num_inference_steps, (len(chosen_img),), device = device, dtype = torch.long ) # One step of noising to samples @@ -335,15 +339,17 @@ def compute_loss( def predict(model, pixel_values): return model( - pixel_values = pixel_values, - time_step = timesteps, - text_embeds = text_embeds + pixel_values = pixel_values.to(model.device), + time_step = timesteps.to(model.device), + text_embeds = text_embeds.to(model.device) ) # Stack predictions for chosen and rejected samples chosen_out, rejected_out = predict(denoiser, noisy_chosen), predict(denoiser, noisy_rejected) with torch.no_grad(): chosen_ref, rejected_ref = predict(ref_denoiser, noisy_chosen), predict(ref_denoiser, noisy_rejected) + chosen_ref = chosen_ref.to(denoiser.device) + rejected_ref = rejected_ref.to(denoiser.device) if scheduler.config.prediction_type == "epsilon": chosen_target, rejected_target = noise, noise diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index ca96d22..76a1da0 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -99,9 +99,6 @@ def train(self, pipeline): assert isinstance(self.sampler, DPOSampler), "Error: Model Sampler for DPO training must be DPO sampler" - if isinstance(reward_fn, torch.nn.Module): - reward_fn = self.accelerator.prepare(reward_fn) - # Set the epoch count epochs = self.config.train.num_epochs if self.config.train.total_samples is not None: @@ -116,6 +113,7 @@ def time_per_1k(n_samples : int): # Ref model ref_model = self.setup_model() + ref_model = ref_model.to("cuda:1") # === MAIN TRAINING LOOP === From e16526f6dcbd73d4bd1226690fc7d1e9a3ee5245 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Thu, 25 Jan 2024 03:45:35 +0000 Subject: [PATCH 12/29] LoRA, refactorings, quick bug fixes --- src/drlx/sampling/__init__.py | 108 ++++++++++++++++++-------------- src/drlx/trainer/dpo_trainer.py | 13 +++- 2 files changed, 70 insertions(+), 51 deletions(-) diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index 6e1bea3..616c9ea 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -221,13 +221,8 @@ def compute_loss( "clip_frac" : [], # Proportion of policy updates where magnitude of update was clipped } - if accelerator is None: - denoiser_unwrapped = denoiser - else: - denoiser_unwrapped = accelerator.unwrap_model(denoiser) - - scheduler = denoiser_unwrapped.scheduler - preprocess = denoiser_unwrapped.preprocess + scheduler = accelerator.unwrap_model(denoiser).scheduler + preprocess = accelerator.unwrap_model(denoiser).preprocess adv_clip = method_config.clip_advantages # clip value for advantages pi_clip = method_config.clip_ratio # clip value for policy ratio @@ -301,19 +296,16 @@ def compute_loss( device, method_config, accelerator = None, - ref_denoiser = none + ref_denoiser = None ): - """ - Compute metrics and do backwards pass on loss. Assumes LoRA if reference is not given. - """ + """ + Compute metrics and do backwards pass on loss. Assumes LoRA if reference is not given. 
+ """ do_lora = ref_denoiser is None - if accelerator is None: - dn_unwrapped = accelerator.unwrap_model(denoiser) - else: - dn_unwrapped = denoiser - scheduler = dn_unwrapped.scheduler - preprocess = dn_unwrapped.preprocess + scheduler = accelerator.unwrap_model(denoiser).scheduler + preprocess = accelerator.unwrap_model(denoiser).preprocess + beta = method_config.beta # Text and image preprocessing @@ -337,45 +329,65 @@ def compute_loss( noisy_chosen = scheduler.add_noise(chosen_latent, noise, timesteps) noisy_rejected = scheduler.add_noise(rejected_latent, noise, timesteps) - def predict(model, pixel_values): - return model( - pixel_values = pixel_values.to(model.device), - time_step = timesteps.to(model.device), - text_embeds = text_embeds.to(model.device) - ) - - # Stack predictions for chosen and rejected samples - chosen_out, rejected_out = predict(denoiser, noisy_chosen), predict(denoiser, noisy_rejected) - with torch.no_grad(): - chosen_ref, rejected_ref = predict(ref_denoiser, noisy_chosen), predict(ref_denoiser, noisy_rejected) - chosen_ref = chosen_ref.to(denoiser.device) - rejected_ref = rejected_ref.to(denoiser.device) + # Doubling across chosen and rejeceted + def double_up(x): + return torch.cat([x,x], dim = 0) + def double_down(x): + n = len(x) + return x[:n//2], x[n//2:] + + # Double everything up so we can input both chosen and rejected at the same time + timesteps = double_up(timesteps) + noise = double_up(noise) + text_embeds = double_up(text_embeds) + latent = torch.cat([chosen_latent, rejected_latent]) + + noisy_inputs = scheduler.add_noise( + latent, + noise, + timesteps + ) + + # Get targets if scheduler.config.prediction_type == "epsilon": - chosen_target, rejected_target = noise, noise + target = noise elif scheduler.config.prediction_type == "v_prediction": - chosen_target = scheduler.get_velocity( - chosen_latent, - nosie, - timesteps - ) - rejected_target = scheduler.get_velocity( - rejected_latent, + target = scheduler.get_velocity( + latent, noise, timesteps ) + + # utility function to get loss simpler + def split_mse(pred, target): + mse = eo.reduce(F.mse_loss(pred, target), 'b ... -> b', reduction = "mean") + chosen, rejected = double_down(mse) + return mse.mean(), chose.mean() - rejected.mean() + + # Forward pass and loss for DPO denoiser + pred = denoiser( + pixel_values = noisy_inputs, + time_step = timesteps, + text_embeds = text_embeds + ) + model_diff, base_loss = split_mse(pred, targets) - # Basic Diffusion Loss - mse_chosen = F.mse_loss(chosen_out, chosen_target, reduction = "mean") - mse_rejected = F.mse_loss(rejected_out, rejected_target, reduction = "mean") - base_loss = 0.5 * (mse_chosen.mean() + mse_rejected.mean()) # logging - model_diff = mse_chosen - mse_rejected - - # Reference model + # Forward pass and loss for refrence with torch.no_grad(): - ref_mse_chosen = F.mse_loss(chosen_ref, chosen_target, reduction = "mean") - ref_mse_rejected = F.mse_loss(rejeceted_ref, rejeceted_target, reduction = "mean") - ref_diff = ref_mse_chosen - ref_mse_rejected + if do_lora: + accelerator.unwrap_model(denoiser).disable_adapters() + + ref_pred = denoiser( + pixel_values = noisy_inputs, + time_step = timesteps, + text_embeds = text_embeds + ) + ref_diff, _ = split_mse(ref_pred, targets) + + accelerator.unwrap_model(denoiser).enable_adapters() + else: + pass # TODO: Maybe not needed? Do we want non-LoRA DPO? 
# DPO Objective surr_loss = (-beta/2) * (model_diff - ref_diff) diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index 76a1da0..1319e32 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -80,6 +80,7 @@ def train(self, pipeline): """ # === SETUP === + do_lora = self.config.model.lora_rank is not None # Singular dataloader made to get a sample of prompts # This sample batch is dependent on config seed so it can be same across runs @@ -112,8 +113,11 @@ def time_per_1k(n_samples : int): last_batch_time = timer.hit() # Ref model - ref_model = self.setup_model() - ref_model = ref_model.to("cuda:1") + if not do_lora: + ref_model = self.setup_model() + ref_model = ref_model.to("cuda:1") + else: + ref_model = None # === MAIN TRAINING LOOP === @@ -142,7 +146,10 @@ def time_per_1k(n_samples : int): self.accelerator.wait_for_everyone() # Optimizer step - self.accelerator.clip_grad_norm_(self.model.parameters(), self.config.train.grad_clip) + self.accelerator.clip_grad_norm_( + filter(lambda p: p.requires_grad, self.model.parameters()), + self.config.train.grad_clip + ) self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() From 14fe25462298775d96fc1edfc714d042c390b6ef Mon Sep 17 00:00:00 2001 From: shahbuland Date: Thu, 25 Jan 2024 04:35:40 +0000 Subject: [PATCH 13/29] small bug fixes --- src/drlx/denoisers/ldm_unet.py | 11 +++++++++++ src/drlx/sampling/__init__.py | 8 ++++---- src/drlx/trainer/base_accelerate.py | 1 + src/drlx/trainer/dpo_trainer.py | 9 +++++---- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/drlx/denoisers/ldm_unet.py b/src/drlx/denoisers/ldm_unet.py index 3ef62d0..71bec0f 100644 --- a/src/drlx/denoisers/ldm_unet.py +++ b/src/drlx/denoisers/ldm_unet.py @@ -165,5 +165,16 @@ def forward( encoder_hidden_states = text_embeds ).sample + @property + def device(self): + return self.unet.device + + def enable_adapters(self): + if self.config.lora_rank: + self.unet.enable_adapters() + + def disable_adapters(self): + if self.config.lora_rank: + self.unet.disable_adapters() diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index 616c9ea..1908739 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -361,9 +361,9 @@ def double_down(x): # utility function to get loss simpler def split_mse(pred, target): - mse = eo.reduce(F.mse_loss(pred, target), 'b ... -> b', reduction = "mean") + mse = eo.reduce(F.mse_loss(pred, target, reduction = 'none'), 'b ... 
-> b', reduction = "mean") chosen, rejected = double_down(mse) - return mse.mean(), chose.mean() - rejected.mean() + return chosen - rejected, mse.mean() # Forward pass and loss for DPO denoiser pred = denoiser( @@ -371,7 +371,7 @@ def split_mse(pred, target): time_step = timesteps, text_embeds = text_embeds ) - model_diff, base_loss = split_mse(pred, targets) + model_diff, base_loss = split_mse(pred, target) # Forward pass and loss for refrence with torch.no_grad(): @@ -383,7 +383,7 @@ def split_mse(pred, target): time_step = timesteps, text_embeds = text_embeds ) - ref_diff, _ = split_mse(ref_pred, targets) + ref_diff, _ = split_mse(ref_pred, target) accelerator.unwrap_model(denoiser).enable_adapters() else: diff --git a/src/drlx/trainer/base_accelerate.py b/src/drlx/trainer/base_accelerate.py index 327d0c2..7895012 100644 --- a/src/drlx/trainer/base_accelerate.py +++ b/src/drlx/trainer/base_accelerate.py @@ -8,6 +8,7 @@ import logging import torch from diffusers import StableDiffusionPipeline +import os class AcceleratedTrainer(BaseTrainer): diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index 1319e32..7063725 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -42,7 +42,8 @@ def setup_model(self): if self.config.model.model_path is not None: model, pipe = model.from_pretrained_pipeline(StableDiffusionPipeline, self.config.model.model_path) - self.pipe = pipe + self.pipe = pipe + self.pipe.set_progress_bar_config(disable=True) return model def loss( @@ -135,7 +136,7 @@ def time_per_1k(n_samples : int): self.accelerator.print(f"Epoch {epoch}/{epochs}.") - for batch in dataloader: + for batch in tqdm(dataloader): metrics = self.loss( prompts = batch['prompts'], chosen_img = batch['chosen_pixel_values'], @@ -159,7 +160,7 @@ def time_per_1k(n_samples : int): with torch.no_grad(): with scoped_seed(self.config.train.seed): sample_imgs = self.deterministic_sample(sample_prompts) - sample_imgs = [wandb.Image(img, caption = prompt) for (img, prompt) in zip(sample_imgs, sample_prompts)] + sample_imgs_wandb = [wandb.Image(img, caption = prompt) for (img, prompt) in zip(sample_imgs, sample_prompts)] # Logging if self.use_wandb: @@ -168,7 +169,7 @@ def time_per_1k(n_samples : int): "accuracy" : metrics["accuracy"], "dpo_loss" : metrics["loss"], "time_per_1k" : last_batch_time, - "img_sample" : sample_imgs + "img_sample" : sample_imgs_wandb }) # save images if self.accelerator.is_main_process and self.config.train.save_samples: From e121257c6ed7ee9ba57310c35ba1d69341e59518 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Thu, 25 Jan 2024 05:35:49 +0000 Subject: [PATCH 14/29] bug fixes --- src/drlx/sampling/__init__.py | 5 +- src/drlx/trainer/dpo_trainer.py | 104 +++++++++++++++++--------------- 2 files changed, 57 insertions(+), 52 deletions(-) diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index 1908739..4042352 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -305,6 +305,7 @@ def compute_loss( scheduler = accelerator.unwrap_model(denoiser).scheduler preprocess = accelerator.unwrap_model(denoiser).preprocess + encode = accelerator.unwrap_model(vae).encode beta = method_config.beta @@ -316,8 +317,8 @@ def compute_loss( do_classifier_free_guidance = self.config.guidance_scale > 1.0 ).detach() - chosen_latent = vae.encode(chosen_img).latent_dist.sample() - rejected_latent = vae.encode(rejected_img).latent_dist.sample() + chosen_latent = encode(chosen_img).latent_dist.sample() + 
rejected_latent = encode(rejected_img).latent_dist.sample() # sample random ts timesteps = torch.randint( diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index 7063725..3569360 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -32,6 +32,10 @@ class DPOTrainer(AcceleratedTrainer): def __init__(self, config : DRLXConfig): super().__init__(config) + # DPO requires we use vae encode, so let's put it on all GPUs + self.vae = self.accelerator.unwrap_model(self.model).vae + self.vae = self.accelerator.prepare(self.vae) + assert isinstance(self.config.method, DPOConfig), "ERROR: Method config must be DPO config" def setup_model(self): @@ -57,7 +61,7 @@ def loss( """ return self.sampler.compute_loss( prompts=prompts, chosen_img=chosen_img, rejected_img=rejected_img, - denoiser=self.model, ref_denoiser=ref_denoiser, vae=self.model.vae, + denoiser=self.model, ref_denoiser=ref_denoiser, vae=self.vae, device=self.accelerator.device, method_config=self.config.method, accelerator=self.accelerator @@ -69,7 +73,7 @@ def deterministic_sample(self, prompts): Sample images deterministically. Utility for visualizing changes for fixed prompts through training. """ gen = torch.Generator(device=self.pipe.device).manual_seed(self.config.train.seed) - self.pipe.unet = self.model.unet + self.pipe.unet = self.accelerator.unwrap_model(self.model).unet return self.pipe(prompts, generator = gen).images def train(self, pipeline): @@ -137,57 +141,57 @@ def time_per_1k(n_samples : int): self.accelerator.print(f"Epoch {epoch}/{epochs}.") for batch in tqdm(dataloader): - metrics = self.loss( - prompts = batch['prompts'], - chosen_img = batch['chosen_pixel_values'], - rejected_img = batch['rejected_pixel_values'], - ref_denoiser = ref_model - ) - - self.accelerator.wait_for_everyone() - - # Optimizer step - self.accelerator.clip_grad_norm_( - filter(lambda p: p.requires_grad, self.model.parameters()), - self.config.train.grad_clip - ) - self.optimizer.step() - self.scheduler.step() - self.optimizer.zero_grad() - - # Generate the sample prompts - self.pipe.unet = self.model.unet - with torch.no_grad(): - with scoped_seed(self.config.train.seed): - sample_imgs = self.deterministic_sample(sample_prompts) - sample_imgs_wandb = [wandb.Image(img, caption = prompt) for (img, prompt) in zip(sample_imgs, sample_prompts)] - - # Logging - if self.use_wandb: - self.accelerator.log({ - "base_loss" : metrics["diffusion_loss"], - "accuracy" : metrics["accuracy"], - "dpo_loss" : metrics["loss"], - "time_per_1k" : last_batch_time, - "img_sample" : sample_imgs_wandb - }) - # save images - if self.accelerator.is_main_process and self.config.train.save_samples: - save_images(sample_imgs, f"./samples/{self.config.logging.run_name}/{epoch}") - - + with self.accelerator.accumulate(self.model): + metrics = self.loss( + prompts = batch['prompts'], + chosen_img = batch['chosen_pixel_values'], + rejected_img = batch['rejected_pixel_values'], + ref_denoiser = ref_model + ) - # Save model every [interval] epochs - accum += 1 - if accum % self.config.train.checkpoint_interval == 0 and self.config.train.checkpoint_interval > 0: - self.accelerator.print("Saving...") - base_path = f"./checkpoints/{self.config.logging.run_name}" - output_path = f"./output/{self.config.logging.run_name}" self.accelerator.wait_for_everyone() - self.save_checkpoint(f"{base_path}/{accum}") - self.save_pretrained(output_path) - last_epoch_time = time_per_1k(self.config.train.num_samples_per_epoch) + # Optimizer 
step + self.accelerator.clip_grad_norm_( + filter(lambda p: p.requires_grad, self.model.parameters()), + self.config.train.grad_clip + ) + self.optimizer.step() + self.scheduler.step() + self.optimizer.zero_grad() + + # Generate the sample prompts + with torch.no_grad(): + with scoped_seed(self.config.train.seed): + sample_imgs = self.deterministic_sample(sample_prompts) + sample_imgs_wandb = [wandb.Image(img, caption = prompt) for (img, prompt) in zip(sample_imgs, sample_prompts)] + + # Logging + if self.use_wandb: + self.accelerator.log({ + "base_loss" : metrics["diffusion_loss"], + "accuracy" : metrics["accuracy"], + "dpo_loss" : metrics["loss"], + "time_per_1k" : last_batch_time, + "img_sample" : sample_imgs_wandb + }) + # save images + if self.accelerator.is_main_process and self.config.train.save_samples: + save_images(sample_imgs, f"./samples/{self.config.logging.run_name}/{epoch}") + + + + # Save model every [interval] epochs + accum += 1 + if accum % self.config.train.checkpoint_interval == 0 and self.config.train.checkpoint_interval > 0: + self.accelerator.print("Saving...") + base_path = f"./checkpoints/{self.config.logging.run_name}" + output_path = f"./output/{self.config.logging.run_name}" + self.accelerator.wait_for_everyone() + self.save_checkpoint(f"{base_path}/{accum}") + self.save_pretrained(output_path) + + last_epoch_time = time_per_1k(self.config.train.num_samples_per_epoch) del metrics del dataloader From 6d9e03df4eadaef675afb24d94b69798e6f82f17 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Thu, 25 Jan 2024 21:39:04 +0000 Subject: [PATCH 15/29] Fix import errors and checkpointing --- src/drlx/trainer/base_accelerate.py | 35 ++--------------------------- src/drlx/trainer/dpo_trainer.py | 3 ++- 2 files changed, 4 insertions(+), 34 deletions(-) diff --git a/src/drlx/trainer/base_accelerate.py b/src/drlx/trainer/base_accelerate.py index 7895012..1d85881 100644 --- a/src/drlx/trainer/base_accelerate.py +++ b/src/drlx/trainer/base_accelerate.py @@ -10,6 +10,8 @@ from diffusers import StableDiffusionPipeline import os +from diffusers.utils import convert_state_dict_to_diffusers +from peft.utils import get_peft_model_state_dict class AcceleratedTrainer(BaseTrainer): """ @@ -79,30 +81,6 @@ def setup_model(self): self.pipe = pipe return model - def save_checkpoint(self, fp : str, components = None): - """ - Save checkpoint in main process - - :param fp: File path to save checkpoint to - """ - if self.accelerator.is_main_process: - os.makedirs(fp, exist_ok = True) - self.accelerator.save_state(output_dir=fp) - self.accelerator.wait_for_everyone() # need to use this twice or a corrupted state is saved - - def save_pretrained(self, fp : str): - """ - Save model into pretrained pipeline so it can be loaded in pipeline later - - :param fp: File path to save to - """ - if self.accelerator.is_main_process: - os.makedirs(fp, exist_ok = True) - unwrapped_model = self.accelerator.unwrap_model(self.model) - self.pipe.unet = unwrapped_model.unet - self.pipe.save_pretrained(fp, safe_serialization = unwrapped_model.config.use_safetensors) - self.accelerator.wait_for_everyone() - def extract_pipeline(self): """ Return original pipeline with finetuned denoiser plugged in @@ -159,12 +137,3 @@ def extract_pipeline(self): self.pipe.unet = self.accelerator.unwrap_model(self.model).unet return self.pipe - - def load_checkpoint(self, fp : str): - """ - Load checkpoint - - :param fp: File path to checkpoint to load from - """ - self.accelerator.load_state(fp) - self.accelerator.print("Succesfully 
loaded checkpoint") diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index 3569360..9d91973 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -188,10 +188,11 @@ def time_per_1k(n_samples : int): base_path = f"./checkpoints/{self.config.logging.run_name}" output_path = f"./output/{self.config.logging.run_name}" self.accelerator.wait_for_everyone() + # Commenting this out for now so I can test rest of the code even though this is broken self.save_checkpoint(f"{base_path}/{accum}") self.save_pretrained(output_path) - last_epoch_time = time_per_1k(self.config.train.num_samples_per_epoch) + last_epoch_time = time_per_1k(self.config.train.num_samples_per_epoch) del metrics del dataloader From c2350cb8d8012c8e16914fa0b930052ef29c6227 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Fri, 26 Jan 2024 00:21:09 +0000 Subject: [PATCH 16/29] Add base model loss deviation to sampling as metric --- src/drlx/sampling/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index 4042352..905a887 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -384,7 +384,7 @@ def split_mse(pred, target): time_step = timesteps, text_embeds = text_embeds ) - ref_diff, _ = split_mse(ref_pred, target) + ref_diff, ref_loss = split_mse(ref_pred, target) accelerator.unwrap_model(denoiser).enable_adapters() else: @@ -406,5 +406,6 @@ def split_mse(pred, target): return { "loss" : loss.item(), "diffusion_loss" : base_loss.item(), - "accuracy" : acc.item() + "accuracy" : acc.item(), + "ref_deviation" : (ref_loss - base_loss) ** 2 } \ No newline at end of file From 765b9f66b317015a0c42d80e74c5157b0cd38965 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Fri, 26 Jan 2024 00:21:33 +0000 Subject: [PATCH 17/29] Add base model loss deviation to trainer logging as metric --- src/drlx/trainer/dpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index 9d91973..e604c4f 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -172,6 +172,7 @@ def time_per_1k(n_samples : int): "base_loss" : metrics["diffusion_loss"], "accuracy" : metrics["accuracy"], "dpo_loss" : metrics["loss"], + "ref_deviation" : metrics["ref_deviation"], "time_per_1k" : last_batch_time, "img_sample" : sample_imgs_wandb }) From 74012cc745f57e568bd24a883de9a966f372624e Mon Sep 17 00:00:00 2001 From: shahbuland Date: Fri, 26 Jan 2024 03:51:27 +0000 Subject: [PATCH 18/29] Add non-lora training with memory saving options in config --- src/drlx/configs.py | 4 ++++ src/drlx/sampling/__init__.py | 9 ++++++++- src/drlx/trainer/dpo_trainer.py | 7 +++++-- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/drlx/configs.py b/src/drlx/configs.py index edd3a28..d4fa200 100644 --- a/src/drlx/configs.py +++ b/src/drlx/configs.py @@ -99,9 +99,13 @@ class DPOConfig(MethodConfig): :param beta: Deviation from initial model :type beta: float + + :param ref_mem_strategy: Strategy for managing reference model on memory. By default, puts it in 16 bit. 
+ :type ref_mem_strategy: str """ name : str = "DPO" beta : float = 0.9 + ref_mem_strategy : str = None # None or "half" @dataclass class TrainConfig(ConfigClass): diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index 905a887..208ea56 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -308,6 +308,7 @@ def compute_loss( encode = accelerator.unwrap_model(vae).encode beta = method_config.beta + ref_strategy = method_config.ref_mem_strategy # Text and image preprocessing with torch.no_grad(): @@ -388,7 +389,13 @@ def split_mse(pred, target): accelerator.unwrap_model(denoiser).enable_adapters() else: - pass # TODO: Maybe not needed? Do we want non-LoRA DPO? + ref_inputs = { + "sample" : noisy_inputs.half() if ref_strategy == "half" else noisy_inputs, + "timestep" : timesteps, + "encoder_hidden_states" : text_embeds.half() if ref_strategy == "half" else text_embeds + } + ref_pred = ref_denoiser(**ref_inputs).sample + ref_diff, ref_loss = split_mse(ref_pred, target) # DPO Objective surr_loss = (-beta/2) * (model_diff - ref_diff) diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index e604c4f..bed859a 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -18,6 +18,7 @@ import wandb import accelerate.utils from PIL import Image +from copy import deepcopy from diffusers import StableDiffusionPipeline @@ -119,8 +120,10 @@ def time_per_1k(n_samples : int): # Ref model if not do_lora: - ref_model = self.setup_model() - ref_model = ref_model.to("cuda:1") + ref_model = deepcopy(self.accelerator.unwrap_model(self.model).unet) + ref_model.requires_grad = False + if self.config.method.ref_mem_strategy == "half": ref_model = ref_model.half() + ref_model = self.accelerator.prepare(ref_model) else: ref_model = None From ef91f92b8d88bf75da856a2ad64666e2c2070b67 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Sun, 28 Jan 2024 17:33:05 +0000 Subject: [PATCH 19/29] some refactorings to sampling, add rmsprop --- src/drlx/sampling/__init__.py | 4 +--- src/drlx/utils/__init__.py | 3 +++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index 208ea56..f9e158c 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -328,8 +328,6 @@ def compute_loss( # One step of noising to samples noise = torch.randn_like(chosen_latent) # [B, C, H, W] - noisy_chosen = scheduler.add_noise(chosen_latent, noise, timesteps) - noisy_rejected = scheduler.add_noise(rejected_latent, noise, timesteps) # Doubling across chosen and rejeceted def double_up(x): @@ -398,7 +396,7 @@ def split_mse(pred, target): ref_diff, ref_loss = split_mse(ref_pred, target) # DPO Objective - surr_loss = (-beta/2) * (model_diff - ref_diff) + surr_loss = -beta * (model_diff - ref_diff) loss = -1 * F.logsigmoid(surr_loss.mean()) # Get approx accuracy as models probability of giving chosen over rejected diff --git a/src/drlx/utils/__init__.py b/src/drlx/utils/__init__.py index 1dd0244..1cab956 100644 --- a/src/drlx/utils/__init__.py +++ b/src/drlx/utils/__init__.py @@ -22,6 +22,7 @@ class OptimizerName(str, Enum): ADAM_8BIT_BNB: str = "adam_8bit_bnb" ADAMW_8BIT_BNB: str = "adamw_8bit_bnb" SGD: str = "sgd" + RMSPROP: str = "rmsprop" def get_optimizer_class(name: OptimizerName): @@ -57,6 +58,8 @@ def get_optimizer_class(name: OptimizerName): ) if name == OptimizerName.SGD.value: return torch.optim.SGD + if name == OptimizerName.RMSPROP.value: + return 
torch.optim.RMSprop supported_optimizers = [o.value for o in OptimizerName] raise ValueError(f"`{name}` is not a supported optimizer. " f"Supported optimizers are: {supported_optimizers}") From 38847c58dee999ac204663084f5c39f1310be0a1 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Sun, 28 Jan 2024 17:33:59 +0000 Subject: [PATCH 20/29] Delete old DPO example, push new one --- examples/DPO/download_pickapic_wds.py | 54 --------------------------- examples/DPO/train_pickapic.py | 14 ------- examples/DPO2/train_dpo_pickapic.py | 24 ++++++++++++ 3 files changed, 24 insertions(+), 68 deletions(-) delete mode 100644 examples/DPO/download_pickapic_wds.py delete mode 100644 examples/DPO/train_pickapic.py create mode 100644 examples/DPO2/train_dpo_pickapic.py diff --git a/examples/DPO/download_pickapic_wds.py b/examples/DPO/download_pickapic_wds.py deleted file mode 100644 index 820c06b..0000000 --- a/examples/DPO/download_pickapic_wds.py +++ /dev/null @@ -1,54 +0,0 @@ -from datasets import load_dataset -import requests -import os -from tqdm import tqdm -import tarfile -from multiprocessing import Pool, cpu_count - -""" -This script takes the filtered version of the PickAPic prompt dataset -and downloads the associated images, then tars them. This tar file can then -be moved to S3 or loaded directly if needed. Number of samples can be specified -""" - -n_samples = 1000 -data_root = "./pickapic_sample" -url = "CarperAI/pickapic_v1_no_images_training_sfw" -n_cpus = cpu_count() # Detect the number of CPUs - -base_name = os.path.basename(data_root).replace('.', '').replace('/', '') - -def make_tarfile(output_filename, source_dir): - with tarfile.open(output_filename, "w") as tar: - tar.add(source_dir, arcname=os.path.basename(source_dir)) - -def download_image(args): - url, filename = args - response = requests.get(url) - with open(filename, 'wb') as f: - f.write(response.content) - -if __name__ == "__main__": - ds = load_dataset("CarperAI/pickapic_v1_no_images_training_sfw")['train'] - os.makedirs(data_root, exist_ok = True) - - id_counter = 0 - with Pool(n_cpus) as p: - for row in tqdm(ds, total = n_samples): - if id_counter >= n_samples: - break - if row['has_label']: - id_str = str(id_counter).zfill(8) - with open(os.path.join(data_root, f'{id_str}.prompt.txt'), 'w', encoding='utf-8') as f: - # Ensure the caption is in UTF-8 format - caption = row['caption'].encode('utf-8').decode('utf-8') - f.write(caption) - if row['label_0']: - p.map(download_image, [(row['image_0_url'], os.path.join(data_root, f'{id_str}.chosen.png')), - (row['image_1_url'], os.path.join(data_root, f'{id_str}.rejected.png'))]) - else: - p.map(download_image, [(row['image_1_url'], os.path.join(data_root, f'{id_str}.chosen.png')), - (row['image_0_url'], os.path.join(data_root, f'{id_str}.rejected.png'))]) - id_counter += 1 - - make_tarfile(f"{base_name}.tar", data_root) \ No newline at end of file diff --git a/examples/DPO/train_pickapic.py b/examples/DPO/train_pickapic.py deleted file mode 100644 index be60b94..0000000 --- a/examples/DPO/train_pickapic.py +++ /dev/null @@ -1,14 +0,0 @@ -import sys -sys.path.append("./src") - -from drlx.pipeline.pickapic_wds import PickAPicPipeline -from drlx.trainer.dpo_trainer import DPOTrainer -from drlx.configs import DRLXConfig - -pipe = PickAPicPipeline() -resume = False - -config = DRLXConfig.load_yaml("configs/dpo_pickapic.yml") -trainer = DPOTrainer(config) - -trainer.train(pipe) \ No newline at end of file diff --git a/examples/DPO2/train_dpo_pickapic.py 
b/examples/DPO2/train_dpo_pickapic.py
new file mode 100644
index 0000000..c581c52
--- /dev/null
+++ b/examples/DPO2/train_dpo_pickapic.py
@@ -0,0 +1,24 @@
+import sys
+
+sys.path.append("./src")
+
+from drlx.trainer.dpo_trainer import DPOTrainer
+from drlx.configs import DRLXConfig
+from drlx.utils import get_latest_checkpoint
+
+# Pipeline first
+from drlx.pipeline.pickapic_dpo import PickAPicDPOPipeline
+
+import torch
+
+pipe = PickAPicDPOPipeline()
+resume = False
+
+config = DRLXConfig.load_yaml("configs/dpo_pickapic.yml")
+trainer = DPOTrainer(config)
+
+if resume:
+    cp_dir = get_latest_checkpoint(f"checkpoints/{config.logging.run_name}")
+    trainer.load_checkpoint(cp_dir)
+
+trainer.train(pipe)
\ No newline at end of file

From be05515a90ca652a854136856c0d75fcf95b3603 Mon Sep 17 00:00:00 2001
From: shahbuland
Date: Tue, 13 Feb 2024 00:35:24 +0000
Subject: [PATCH 21/29] Rename DPO2 to DPO

---
 examples/{DPO2 => DPO}/train_dpo_pickapic.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename examples/{DPO2 => DPO}/train_dpo_pickapic.py (100%)

diff --git a/examples/DPO2/train_dpo_pickapic.py b/examples/DPO/train_dpo_pickapic.py
similarity index 100%
rename from examples/DPO2/train_dpo_pickapic.py
rename to examples/DPO/train_dpo_pickapic.py

From e6023a364d61fcb5280248f12d70546bac3b4519 Mon Sep 17 00:00:00 2001
From: shahbuland
Date: Tue, 13 Feb 2024 00:50:39 +0000
Subject: [PATCH 22/29] Move DPO and DDPO sampler to their own files for better
 organization

---
 src/drlx/sampling/__init__.py     | 328 +-----------------------------
 src/drlx/sampling/ddpo_sampler.py | 206 +++++++++++++++++++
 src/drlx/sampling/dpo_sampler.py  | 136 +++++++++++++
 3 files changed, 343 insertions(+), 327 deletions(-)
 create mode 100644 src/drlx/sampling/ddpo_sampler.py
 create mode 100644 src/drlx/sampling/dpo_sampler.py

diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py
index f9e158c..21f8904 100644
--- a/src/drlx/sampling/__init__.py
+++ b/src/drlx/sampling/__init__.py
@@ -87,330 +87,4 @@ def sample(self, prompts : Iterable[str], denoiser, device = None, show_progress
         if self.config.postprocess:
             return denoiser_unwrapped.postprocess(latents)
         else:
-            return latents
-
-class DDPOSampler(Sampler):
-    def step_and_logprobs(self,
-        scheduler,
-        pred : TensorType["b", "c", "h", "w"],
-        t : float,
-        latents : TensorType["b", "c", "h", "w"],
-        old_pred : Optional[TensorType["b", "c", "h", "w"]] = None
-    ):
-        """
-        Steps backwards using scheduler. Considers the prediction as an action sampled
-        from a normal distribution and returns average log probability for that prediction.
-        Can also be used to find probability of current model giving some other prediction (old_pred)
-
-        :param scheduler: Scheduler being used for diffusion process
-        :param pred: Denoiser prediction with CFG and scaling accounted for
-        :param t: Timestep in diffusion process
-        :param latents: Latent vector given as input to denoiser
-        :param old_pred: Alternate prediction. If given, computes log probability of current model predicting alternative output.
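One note on the new example script above: it calls drlx.utils.get_latest_checkpoint, whose body does not appear in this series. A minimal sketch of what it might look like, assuming checkpoints land in integer-named subdirectories as the trainer's save_checkpoint(f"{base_path}/{accum}") produces; the real helper may differ:

import os

def get_latest_checkpoint(checkpoint_root: str) -> str:
    # Sketch only: assumes each checkpoint lives in a subdirectory named by an
    # integer step count, matching how the trainer names its save paths.
    subdirs = [
        d for d in os.listdir(checkpoint_root)
        if d.isdigit() and os.path.isdir(os.path.join(checkpoint_root, d))
    ]
    if not subdirs:
        raise FileNotFoundError(f"No checkpoints found under {checkpoint_root}")
    return os.path.join(checkpoint_root, max(subdirs, key=int))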
- """ - scheduler_out = scheduler.step(pred, t, latents, self.config.eta, variance_noise=0) - - # computing log_probs - t_1 = t - scheduler.config.num_train_timesteps // self.config.num_inference_steps - variance = scheduler._get_variance(t, t_1) - std_dev_t = self.config.eta * variance ** 0.5 - prev_sample_mean = scheduler_out.prev_sample - prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t - - std_dev_t = torch.clip(std_dev_t, 1e-6) # force sigma > 1e-6 - - # If old_pred provided, we are finding probability of new model outputting same action as before - # Otherwise finding probability of current action - action = old_pred if old_pred is not None else prev_sample # Log prob of new model giving old output - log_probs = -((action.detach() - prev_sample_mean) ** 2) / (2 * std_dev_t ** 2) - torch.log(std_dev_t) - math.log(math.sqrt(2 * math.pi)) - log_probs = eo.reduce(log_probs, 'b c h w -> b', 'mean') - - return prev_sample, log_probs - - @torch.no_grad() - def sample( - self, prompts, denoiser, device, - show_progress : bool = False, - accelerator = None - ) -> Iterable[torch.Tensor]: - """ - DDPO sampling is analagous to playing a game in an RL environment. This function samples - given denoiser and prompts but in addition to giving latents also gives log probabilities - for predictions as well as ALL predictions (i.e. at each timestep) - - :param prompts: Text prompts to condition denoiser - :param denoiser: Denoising model - :param device: Device to do inference on - :param show_progress: Display progress bar? - :param accelerator: Accelerator object for accelerated training (optional) - - :return: triple of final denoised latents, all model predictions, all log probabilities for each prediction - """ - - if accelerator is None: - denoiser_unwrapped = denoiser - else: - denoiser_unwrapped = accelerator.unwrap_model(denoiser) - - scheduler = denoiser_unwrapped.scheduler - preprocess = denoiser_unwrapped.preprocess - noise_shape = denoiser_unwrapped.get_input_shape() - - text_embeds = preprocess( - prompts, mode = "embeds", device = device, - num_images_per_prompt = 1, - do_classifier_free_guidance = self.config.guidance_scale > 1.0 - ).detach() - - scheduler.set_timesteps(self.config.num_inference_steps, device = device) - latents = torch.randn(len(prompts), *noise_shape, device = device) - - all_step_preds, all_log_probs = [latents], [] - - for t in tqdm(scheduler.timesteps, disable = not show_progress): - latent_input = torch.cat([latents] * 2) - latent_input = scheduler.scale_model_input(latent_input, t) - - pred = denoiser( - pixel_values = latent_input, - time_step = t, - text_embeds = text_embeds - ) - - # cfg - pred = self.cfg_rescale(pred) - - # step - prev_sample, log_probs = self.step_and_logprobs(scheduler, pred, t, latents) - - all_step_preds.append(prev_sample) - all_log_probs.append(log_probs) - latents = prev_sample - - return latents, torch.stack(all_step_preds), torch.stack(all_log_probs) - - def compute_loss( - self, prompts, denoiser, device, - show_progress : bool = False, - advantages = None, old_preds = None, old_log_probs = None, - method_config : DDPOConfig = None, - accelerator = None - ): - - - """ - Computes the loss for the DDPO sampling process. This function is used to train the denoiser model. 
- - :param prompts: Text prompts to condition the denoiser - :param denoiser: Denoising model - :param device: Device to perform model inference on - :param show_progress: Whether to display a progress bar for the sampling steps - :param advantages: Normalized advantages obtained from reward computation - :param old_preds: Previous predictions from past model - :param old_log_probs: Log probabilities of predictions from past model - :param method_config: Configuration for the DDPO method - :param accelerator: Accelerator object for accelerated training (optional) - - :return: Total loss computed over the sampling process - """ - - # All metrics are reduced and gathered before result is returned - metrics = { - "loss" : [], - "kl_div" : [], # ~ KL div between new policy and old one (average) - "clip_frac" : [], # Proportion of policy updates where magnitude of update was clipped - } - - scheduler = accelerator.unwrap_model(denoiser).scheduler - preprocess = accelerator.unwrap_model(denoiser).preprocess - - adv_clip = method_config.clip_advantages # clip value for advantages - pi_clip = method_config.clip_ratio # clip value for policy ratio - - text_embeds = preprocess( - prompts, mode = "embeds", device = device, - num_images_per_prompt = 1, - do_classifier_free_guidance = self.config.guidance_scale > 1.0 - ).detach() - - scheduler.set_timesteps(self.config.num_inference_steps, device = device) - total_loss = 0. - - for i, t in enumerate(tqdm(scheduler.timesteps, disable = not show_progress)): - latent_input = torch.cat([old_preds[i].detach()] * 2) - latent_input = scheduler.scale_model_input(latent_input, t) - - pred = denoiser( - pixel_values = latent_input, - time_step = t, - text_embeds = text_embeds - ) - - # cfg - pred = self.cfg_rescale(pred) - - # step - prev_sample, log_probs = self.step_and_logprobs( - scheduler, pred, t, old_preds[i], - old_preds[i+1] - ) - - # Need to be computed and detached again because of autograd weirdness - clipped_advs = torch.clip(advantages,-adv_clip,adv_clip).detach() - - # ppo actor loss - - ratio = torch.exp(log_probs - old_log_probs[i].detach()) - surr1 = -clipped_advs * ratio - surr2 = -clipped_advs * torch.clip(ratio, 1. - pi_clip, 1. + pi_clip) - loss = torch.max(surr1, surr2).mean() - if accelerator is not None: - accelerator.backward(loss) - else: - loss.backward() - - # Metric computations - kl_div = 0.5 * (log_probs - old_log_probs[i]).mean() ** 2 - clip_frac = ((ratio < 1 - pi_clip) | (ratio > 1 + pi_clip)).float().mean() - - metrics["loss"].append(loss.item()) - metrics["kl_div"].append(kl_div.item()) - metrics["clip_frac"].append(clip_frac.item()) - - # Reduce across timesteps then across devices - for k in metrics: - metrics[k] = torch.tensor(metrics[k]).mean().cuda() # Needed for reduction to work - if accelerator is not None: - metrics = accelerator.reduce(metrics, 'mean') - - return metrics - -class DPOSampler(Sampler): - def compute_loss( - self, - prompts, - chosen_img, - rejected_img, - denoiser, - vae, - device, - method_config, - accelerator = None, - ref_denoiser = None - ): - """ - Compute metrics and do backwards pass on loss. Assumes LoRA if reference is not given. 
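Before this loss lands in its new file below, the chosen/rejected batching trick it relies on is worth seeing in isolation. This toy example mirrors the double_up and double_down helpers from the sampler and verifies that they round-trip:

import torch

def double_up(x):
    return torch.cat([x, x], dim=0)

def double_down(x):
    n = len(x)
    return x[:n // 2], x[n // 2:]

t = torch.tensor([3, 7])              # e.g. two sampled timesteps
doubled = double_up(t)                # tensor([3, 7, 3, 7]), shared by chosen and rejected
chosen, rejected = double_down(doubled)
assert torch.equal(chosen, rejected)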
- """ - do_lora = ref_denoiser is None - - scheduler = accelerator.unwrap_model(denoiser).scheduler - preprocess = accelerator.unwrap_model(denoiser).preprocess - encode = accelerator.unwrap_model(vae).encode - - beta = method_config.beta - ref_strategy = method_config.ref_mem_strategy - - # Text and image preprocessing - with torch.no_grad(): - text_embeds = preprocess( - prompts, mode = "embeds", device = device, - num_images_per_prompt = 1, - do_classifier_free_guidance = self.config.guidance_scale > 1.0 - ).detach() - - chosen_latent = encode(chosen_img).latent_dist.sample() - rejected_latent = encode(rejected_img).latent_dist.sample() - - # sample random ts - timesteps = torch.randint( - 0, self.config.num_inference_steps, (len(chosen_img),), device = device, dtype = torch.long - ) - - # One step of noising to samples - noise = torch.randn_like(chosen_latent) # [B, C, H, W] - - # Doubling across chosen and rejeceted - def double_up(x): - return torch.cat([x,x], dim = 0) - - def double_down(x): - n = len(x) - return x[:n//2], x[n//2:] - - # Double everything up so we can input both chosen and rejected at the same time - timesteps = double_up(timesteps) - noise = double_up(noise) - text_embeds = double_up(text_embeds) - latent = torch.cat([chosen_latent, rejected_latent]) - - noisy_inputs = scheduler.add_noise( - latent, - noise, - timesteps - ) - - # Get targets - if scheduler.config.prediction_type == "epsilon": - target = noise - elif scheduler.config.prediction_type == "v_prediction": - target = scheduler.get_velocity( - latent, - noise, - timesteps - ) - - # utility function to get loss simpler - def split_mse(pred, target): - mse = eo.reduce(F.mse_loss(pred, target, reduction = 'none'), 'b ... -> b', reduction = "mean") - chosen, rejected = double_down(mse) - return chosen - rejected, mse.mean() - - # Forward pass and loss for DPO denoiser - pred = denoiser( - pixel_values = noisy_inputs, - time_step = timesteps, - text_embeds = text_embeds - ) - model_diff, base_loss = split_mse(pred, target) - - # Forward pass and loss for refrence - with torch.no_grad(): - if do_lora: - accelerator.unwrap_model(denoiser).disable_adapters() - - ref_pred = denoiser( - pixel_values = noisy_inputs, - time_step = timesteps, - text_embeds = text_embeds - ) - ref_diff, ref_loss = split_mse(ref_pred, target) - - accelerator.unwrap_model(denoiser).enable_adapters() - else: - ref_inputs = { - "sample" : noisy_inputs.half() if ref_strategy == "half" else noisy_inputs, - "timestep" : timesteps, - "encoder_hidden_states" : text_embeds.half() if ref_strategy == "half" else text_embeds - } - ref_pred = ref_denoiser(**ref_inputs).sample - ref_diff, ref_loss = split_mse(ref_pred, target) - - # DPO Objective - surr_loss = -beta * (model_diff - ref_diff) - loss = -1 * F.logsigmoid(surr_loss.mean()) - - # Get approx accuracy as models probability of giving chosen over rejected - acc = (surr_loss > 0).sum().float() / len(surr_loss) - acc += 0.5 * (surr_loss == 0).sum().float() / len(surr_loss) # 50% for when both match - - if accelerator is None: - loss.backward() - else: - accelerator.backward(loss) - - return { - "loss" : loss.item(), - "diffusion_loss" : base_loss.item(), - "accuracy" : acc.item(), - "ref_deviation" : (ref_loss - base_loss) ** 2 - } \ No newline at end of file + return latents \ No newline at end of file diff --git a/src/drlx/sampling/ddpo_sampler.py b/src/drlx/sampling/ddpo_sampler.py new file mode 100644 index 0000000..affee9a --- /dev/null +++ b/src/drlx/sampling/ddpo_sampler.py @@ -0,0 
+1,206 @@
+from torchtyping import TensorType
+from typing import Iterable, Optional
+
+import einops as eo
+import torch
+from tqdm import tqdm
+import math
+
+from drlx.sampling import Sampler
+from drlx.configs import DDPOConfig
+
+class DDPOSampler(Sampler):
+    def step_and_logprobs(self,
+        scheduler,
+        pred : TensorType["b", "c", "h", "w"],
+        t : float,
+        latents : TensorType["b", "c", "h", "w"],
+        old_pred : Optional[TensorType["b", "c", "h", "w"]] = None
+    ):
+        """
+        Steps backwards using scheduler. Considers the prediction as an action sampled
+        from a normal distribution and returns average log probability for that prediction.
+        Can also be used to find probability of current model giving some other prediction (old_pred)
+
+        :param scheduler: Scheduler being used for diffusion process
+        :param pred: Denoiser prediction with CFG and scaling accounted for
+        :param t: Timestep in diffusion process
+        :param latents: Latent vector given as input to denoiser
+        :param old_pred: Alternate prediction. If given, computes log probability of current model predicting alternative output.
+        """
+        scheduler_out = scheduler.step(pred, t, latents, self.config.eta, variance_noise=0)
+
+        # computing log_probs
+        t_1 = t - scheduler.config.num_train_timesteps // self.config.num_inference_steps
+        variance = scheduler._get_variance(t, t_1)
+        std_dev_t = self.config.eta * variance ** 0.5
+        prev_sample_mean = scheduler_out.prev_sample
+        prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t
+
+        std_dev_t = torch.clip(std_dev_t, 1e-6) # force sigma > 1e-6
+
+        # If old_pred provided, we are finding probability of new model outputting same action as before
+        # Otherwise finding probability of current action
+        action = old_pred if old_pred is not None else prev_sample # Log prob of new model giving old output
+        log_probs = -((action.detach() - prev_sample_mean) ** 2) / (2 * std_dev_t ** 2) - torch.log(std_dev_t) - math.log(math.sqrt(2 * math.pi))
+        log_probs = eo.reduce(log_probs, 'b c h w -> b', 'mean')
+
+        return prev_sample, log_probs
+
+    @torch.no_grad()
+    def sample(
+        self, prompts, denoiser, device,
+        show_progress : bool = False,
+        accelerator = None
+    ) -> Iterable[torch.Tensor]:
+        """
+        DDPO sampling is analogous to playing a game in an RL environment. This function samples
+        given denoiser and prompts but in addition to giving latents also gives log probabilities
+        for predictions as well as ALL predictions (i.e. at each timestep)
+
+        :param prompts: Text prompts to condition denoiser
+        :param denoiser: Denoising model
+        :param device: Device to do inference on
+        :param show_progress: Display progress bar?
+ :param accelerator: Accelerator object for accelerated training (optional) + + :return: triple of final denoised latents, all model predictions, all log probabilities for each prediction + """ + + if accelerator is None: + denoiser_unwrapped = denoiser + else: + denoiser_unwrapped = accelerator.unwrap_model(denoiser) + + scheduler = denoiser_unwrapped.scheduler + preprocess = denoiser_unwrapped.preprocess + noise_shape = denoiser_unwrapped.get_input_shape() + + text_embeds = preprocess( + prompts, mode = "embeds", device = device, + num_images_per_prompt = 1, + do_classifier_free_guidance = self.config.guidance_scale > 1.0 + ).detach() + + scheduler.set_timesteps(self.config.num_inference_steps, device = device) + latents = torch.randn(len(prompts), *noise_shape, device = device) + + all_step_preds, all_log_probs = [latents], [] + + for t in tqdm(scheduler.timesteps, disable = not show_progress): + latent_input = torch.cat([latents] * 2) + latent_input = scheduler.scale_model_input(latent_input, t) + + pred = denoiser( + pixel_values = latent_input, + time_step = t, + text_embeds = text_embeds + ) + + # cfg + pred = self.cfg_rescale(pred) + + # step + prev_sample, log_probs = self.step_and_logprobs(scheduler, pred, t, latents) + + all_step_preds.append(prev_sample) + all_log_probs.append(log_probs) + latents = prev_sample + + return latents, torch.stack(all_step_preds), torch.stack(all_log_probs) + + def compute_loss( + self, prompts, denoiser, device, + show_progress : bool = False, + advantages = None, old_preds = None, old_log_probs = None, + method_config : DDPOConfig = None, + accelerator = None + ): + + + """ + Computes the loss for the DDPO sampling process. This function is used to train the denoiser model. + + :param prompts: Text prompts to condition the denoiser + :param denoiser: Denoising model + :param device: Device to perform model inference on + :param show_progress: Whether to display a progress bar for the sampling steps + :param advantages: Normalized advantages obtained from reward computation + :param old_preds: Previous predictions from past model + :param old_log_probs: Log probabilities of predictions from past model + :param method_config: Configuration for the DDPO method + :param accelerator: Accelerator object for accelerated training (optional) + + :return: Total loss computed over the sampling process + """ + + # All metrics are reduced and gathered before result is returned + metrics = { + "loss" : [], + "kl_div" : [], # ~ KL div between new policy and old one (average) + "clip_frac" : [], # Proportion of policy updates where magnitude of update was clipped + } + + scheduler = accelerator.unwrap_model(denoiser).scheduler + preprocess = accelerator.unwrap_model(denoiser).preprocess + + adv_clip = method_config.clip_advantages # clip value for advantages + pi_clip = method_config.clip_ratio # clip value for policy ratio + + text_embeds = preprocess( + prompts, mode = "embeds", device = device, + num_images_per_prompt = 1, + do_classifier_free_guidance = self.config.guidance_scale > 1.0 + ).detach() + + scheduler.set_timesteps(self.config.num_inference_steps, device = device) + total_loss = 0. 
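The loop that follows applies a PPO-style clipped surrogate at every timestep. On toy numbers the effect of the clipping is easy to see; the values here are hypothetical and only the formula matches compute_loss:

import torch

log_probs = torch.tensor([-1.00])      # hypothetical current-policy log-prob
old_log_probs = torch.tensor([-1.20])  # hypothetical behaviour-policy log-prob
advantages = torch.tensor([0.5])
pi_clip = 0.2                          # hypothetical clip_ratio

ratio = torch.exp(log_probs - old_log_probs)            # ~1.22, outside the clip window
surr1 = -advantages * ratio
surr2 = -advantages * torch.clip(ratio, 1. - pi_clip, 1. + pi_clip)
loss = torch.max(surr1, surr2).mean()  # clipping caps the update when the ratio drifts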
+ + for i, t in enumerate(tqdm(scheduler.timesteps, disable = not show_progress)): + latent_input = torch.cat([old_preds[i].detach()] * 2) + latent_input = scheduler.scale_model_input(latent_input, t) + + pred = denoiser( + pixel_values = latent_input, + time_step = t, + text_embeds = text_embeds + ) + + # cfg + pred = self.cfg_rescale(pred) + + # step + prev_sample, log_probs = self.step_and_logprobs( + scheduler, pred, t, old_preds[i], + old_preds[i+1] + ) + + # Need to be computed and detached again because of autograd weirdness + clipped_advs = torch.clip(advantages,-adv_clip,adv_clip).detach() + + # ppo actor loss + + ratio = torch.exp(log_probs - old_log_probs[i].detach()) + surr1 = -clipped_advs * ratio + surr2 = -clipped_advs * torch.clip(ratio, 1. - pi_clip, 1. + pi_clip) + loss = torch.max(surr1, surr2).mean() + if accelerator is not None: + accelerator.backward(loss) + else: + loss.backward() + + # Metric computations + kl_div = 0.5 * (log_probs - old_log_probs[i]).mean() ** 2 + clip_frac = ((ratio < 1 - pi_clip) | (ratio > 1 + pi_clip)).float().mean() + + metrics["loss"].append(loss.item()) + metrics["kl_div"].append(kl_div.item()) + metrics["clip_frac"].append(clip_frac.item()) + + # Reduce across timesteps then across devices + for k in metrics: + metrics[k] = torch.tensor(metrics[k]).mean().cuda() # Needed for reduction to work + if accelerator is not None: + metrics = accelerator.reduce(metrics, 'mean') + + return metrics diff --git a/src/drlx/sampling/dpo_sampler.py b/src/drlx/sampling/dpo_sampler.py new file mode 100644 index 0000000..35e1040 --- /dev/null +++ b/src/drlx/sampling/dpo_sampler.py @@ -0,0 +1,136 @@ +import torch +import torch.nn.functional as F +import einops as eo + +from drlx.sampling import Sampler +from drlx.configs import DPOConfig + +class DPOSampler(Sampler): + def compute_loss( + self, + prompts, + chosen_img, + rejected_img, + denoiser, + vae, + device, + method_config : DPOConfig, + accelerator = None, + ref_denoiser = None + ): + """ + Compute metrics and do backwards pass on loss. Assumes LoRA if reference is not given. 
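The objective assembled below can be sanity-checked on scalars. Writing D_model and D_ref for the per-pair difference (chosen MSE minus rejected MSE) under the trained and reference denoisers, the loss is -log sigmoid(-beta * (D_model - D_ref)). A toy run with hypothetical numbers:

import torch
import torch.nn.functional as F

beta = 0.9                                # matches the DPOConfig default
model_diff = torch.tensor([-0.02, 0.01])  # hypothetical chosen MSE - rejected MSE, trained model
ref_diff = torch.tensor([-0.01, 0.02])    # same difference for the reference model
surr_loss = -beta * (model_diff - ref_diff)
loss = -F.logsigmoid(surr_loss.mean())    # small when the model favours chosen more than the reference does
acc = (surr_loss > 0).float().mean()      # fraction of pairs ranked correctly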
+ """ + do_lora = ref_denoiser is None + + scheduler = accelerator.unwrap_model(denoiser).scheduler + preprocess = accelerator.unwrap_model(denoiser).preprocess + encode = accelerator.unwrap_model(vae).encode + + beta = method_config.beta + ref_strategy = method_config.ref_mem_strategy + + # Text and image preprocessing + with torch.no_grad(): + text_embeds = preprocess( + prompts, mode = "embeds", device = device, + num_images_per_prompt = 1, + do_classifier_free_guidance = self.config.guidance_scale > 1.0 + ).detach() + + chosen_latent = encode(chosen_img).latent_dist.sample() + rejected_latent = encode(rejected_img).latent_dist.sample() + + # sample random ts + timesteps = torch.randint( + 0, self.config.num_inference_steps, (len(chosen_img),), device = device, dtype = torch.long + ) + + # One step of noising to samples + noise = torch.randn_like(chosen_latent) # [B, C, H, W] + + # Doubling across chosen and rejeceted + def double_up(x): + return torch.cat([x,x], dim = 0) + + def double_down(x): + n = len(x) + return x[:n//2], x[n//2:] + + # Double everything up so we can input both chosen and rejected at the same time + timesteps = double_up(timesteps) + noise = double_up(noise) + text_embeds = double_up(text_embeds) + latent = torch.cat([chosen_latent, rejected_latent]) + + noisy_inputs = scheduler.add_noise( + latent, + noise, + timesteps + ) + + # Get targets + if scheduler.config.prediction_type == "epsilon": + target = noise + elif scheduler.config.prediction_type == "v_prediction": + target = scheduler.get_velocity( + latent, + noise, + timesteps + ) + + # utility function to get loss simpler + def split_mse(pred, target): + mse = eo.reduce(F.mse_loss(pred, target, reduction = 'none'), 'b ... -> b', reduction = "mean") + chosen, rejected = double_down(mse) + return chosen - rejected, mse.mean() + + # Forward pass and loss for DPO denoiser + pred = denoiser( + pixel_values = noisy_inputs, + time_step = timesteps, + text_embeds = text_embeds + ) + model_diff, base_loss = split_mse(pred, target) + + # Forward pass and loss for refrence + with torch.no_grad(): + if do_lora: + accelerator.unwrap_model(denoiser).disable_adapters() + + ref_pred = denoiser( + pixel_values = noisy_inputs, + time_step = timesteps, + text_embeds = text_embeds + ) + ref_diff, ref_loss = split_mse(ref_pred, target) + + accelerator.unwrap_model(denoiser).enable_adapters() + else: + ref_inputs = { + "sample" : noisy_inputs.half() if ref_strategy == "half" else noisy_inputs, + "timestep" : timesteps, + "encoder_hidden_states" : text_embeds.half() if ref_strategy == "half" else text_embeds + } + ref_pred = ref_denoiser(**ref_inputs).sample + ref_diff, ref_loss = split_mse(ref_pred, target) + + # DPO Objective + surr_loss = -beta * (model_diff - ref_diff) + loss = -1 * F.logsigmoid(surr_loss.mean()) + + # Get approx accuracy as models probability of giving chosen over rejected + acc = (surr_loss > 0).sum().float() / len(surr_loss) + acc += 0.5 * (surr_loss == 0).sum().float() / len(surr_loss) # 50% for when both match + + if accelerator is None: + loss.backward() + else: + accelerator.backward(loss) + + return { + "loss" : loss.item(), + "diffusion_loss" : base_loss.item(), + "accuracy" : acc.item(), + "ref_deviation" : (ref_loss - base_loss) ** 2 + } \ No newline at end of file From 5253473b89cec364cff232fca566ca9327c2d9a2 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Tue, 13 Feb 2024 02:14:46 +0000 Subject: [PATCH 23/29] prepare for adding SDXL --- src/drlx/configs.py | 16 ++++++++-------- 
src/drlx/denoisers/ldm_unet.py | 21 +++++++++++++++++++-- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/drlx/configs.py b/src/drlx/configs.py index d4fa200..2b4f9ee 100644 --- a/src/drlx/configs.py +++ b/src/drlx/configs.py @@ -235,14 +235,14 @@ class ModelConfig(ConfigClass): :param model_path: Path or name of the model (local or on huggingface hub) :type model_path: str - :param model_arch_type: Type of model architecture. - :type model_arch_type: str + :param pipeline_kwargs: Keyword arguments for pipeline if model is being loaded from one + :type pipeline_kwargs: dict - :param use_safetensors: Use safe tensors when loading pipeline? - :type use_safetensors: bool + :param sdxl: Using SDXL model? + :type sdxl: bool - :param local_model: Force model to load checkpoint locally only - :type local_model: bool + :param model_arch_type: Type of model architecture. Defaults to LDM UNet + :type model_arch_type: str :param attention_slicing: Whether to use attention slicing :type attention_slicing: bool @@ -258,9 +258,9 @@ class ModelConfig(ConfigClass): """ model_path: str = None + pipeline_kwargs : dict = None + sdxl : bool = False model_arch_type: str = None - use_safetensors : bool = False - local_model : bool = False attention_slicing: bool = False xformers_memory_efficient: bool = False gradient_checkpointing: bool = False diff --git a/src/drlx/denoisers/ldm_unet.py b/src/drlx/denoisers/ldm_unet.py index 71bec0f..222024c 100644 --- a/src/drlx/denoisers/ldm_unet.py +++ b/src/drlx/denoisers/ldm_unet.py @@ -28,7 +28,10 @@ def __init__(self, config : ModelConfig, sampler_config : SamplerConfig = None, super().__init__(config, sampler_config, sampler) self.unet : UNet2DConditionModel = None + self.text_encoder = None + self.text_encoder_2 = None # SDXL Support, just needs to be here for device mapping + self.vae = None self.encode_prompt : Callable = None @@ -37,6 +40,8 @@ def __init__(self, config : ModelConfig, sampler_config : SamplerConfig = None, self.scale_factor = None + self.sdxl_flag = self.config.sdxl + def get_input_shape(self) -> Tuple[int]: """ Figure out latent noise input shape for the UNet. 
Requires that unet and vae are defined @@ -65,16 +70,28 @@ def from_pretrained_pipeline(self, cls : Type, path : str): :rtype: LDMUNet """ - pipe = cls.from_pretrained(path, use_safetensors = self.config.use_safetensors, local_files_only = self.config.local_model) + kwargs = self.config.pipeline_kwargs + if kwargs['variant'] == "fp16": + kwargs['torch_dtype'] = torch.float16 + else: + kwargs["torch_dtype"] = torch.float32 + + pipe = cls.from_pretrained(path, **kwargs) if self.config.attention_slicing: pipe.enable_attention_slicing() if self.config.xformers_memory_efficient: pipe.enable_xformers_memory_efficient_attention() self.unet = pipe.unet self.text_encoder = pipe.text_encoder + + # SDXL compat + if self.sdxl_flag: + self.text_encoder_2 = pipe.text_encoder_2 + self.vae = pipe.vae self.scale_factor = pipe.vae_scale_factor - self.encode_prompt = pipe._encode_prompt + + self.encode_prompt = pipe.encode_prompt self.text_encoder.requires_grad_(False) self.vae.requires_grad_(False) From 54f6ec163bc89b0c12aa288a24f41edf434b758e Mon Sep 17 00:00:00 2001 From: shahbuland Date: Tue, 13 Feb 2024 02:15:24 +0000 Subject: [PATCH 24/29] Fix issue with modularizing samplers --- src/drlx/sampling/__init__.py | 93 +------------------------------ src/drlx/sampling/base.py | 89 +++++++++++++++++++++++++++++ src/drlx/sampling/ddpo_sampler.py | 7 ++- src/drlx/sampling/dpo_sampler.py | 7 ++- 4 files changed, 103 insertions(+), 93 deletions(-) create mode 100644 src/drlx/sampling/base.py diff --git a/src/drlx/sampling/__init__.py b/src/drlx/sampling/__init__.py index 21f8904..fe8a7f5 100644 --- a/src/drlx/sampling/__init__.py +++ b/src/drlx/sampling/__init__.py @@ -1,90 +1,3 @@ -from typing import Union, Iterable, Tuple, Any, Optional -from torchtyping import TensorType - -import torch -from tqdm import tqdm -import math -import einops as eo -import torch.nn.functional as F - -from drlx.utils import rescale_noise_cfg - -from drlx.configs import SamplerConfig, DDPOConfig - -class Sampler: - """ - Generic class for sampling generations using a denoiser. Assumes LDMUnet - """ - def __init__(self, config : SamplerConfig = SamplerConfig()): - self.config = config - - def cfg_rescale(self, pred : TensorType["2 * b", "c", "h", "w"]): - """ - Applies classifier free guidance to prediction and rescales if cfg_rescaling is enabled - - :param pred: - Assumed to be batched repeated prediction with first half consisting of - unconditioned (empty token) predictions and second half being conditioned - predictions - """ - - pred_uncond, pred_cond = pred.chunk(2) - pred = pred_uncond + self.config.guidance_scale * (pred_cond - pred_uncond) - - if self.config.guidance_rescale is not None: - pred = rescale_noise_cfg(pred, pred_cond, self.config.guidance_rescale) - - return pred - - @torch.no_grad() - def sample(self, prompts : Iterable[str], denoiser, device = None, show_progress : bool = False, accelerator = None): - """ - Samples latents given some prompts and a denoiser - - :param prompts: Text prompts for image generation (to condition denoiser) - :param denoiser: Model to use for denoising - :param device: Device on which to perform model inference - :param show_progress: Whether to display a progress bar for the sampling steps - :param accelerator: Accelerator object for accelerated training (optional) - - :return: Latents unless postprocess flag is set to true in config, in which case VAE decoded latents are returned (i.e. 
images) - """ - if accelerator is None: - denoiser_unwrapped = denoiser - else: - denoiser_unwrapped = accelerator.unwrap_model(denoiser) - - scheduler = denoiser_unwrapped.scheduler - preprocess = denoiser_unwrapped.preprocess - noise_shape = denoiser_unwrapped.get_input_shape() - - text_embeds = preprocess( - prompts, mode = "embeds", device = device, - num_images_per_prompt = 1, - do_classifier_free_guidance = self.config.guidance_scale > 1.0 - ).detach() - - scheduler.set_timesteps(self.config.num_inference_steps, device = device) - latents = torch.randn(len(prompts), *noise_shape, device = device) - - for i, t in enumerate(tqdm(scheduler.timesteps), disable = not show_progress): - input = torch.cat([latents] * 2) - input = scheduler.scale_model_input(input, t) - - pred = denoiser( - pixel_values=input, - time_step = t, - text_embeds = text_embeds - ) - - # guidance - pred = self.cfg_rescale(pred) - - # step backward - scheduler_out = scheduler.step(pred, t, latents, self.config.eta) - latents = scheduler_out.prev_sample - - if self.config.postprocess: - return denoiser_unwrapped.postprocess(latents) - else: - return latents \ No newline at end of file +from .base import Sampler +from .ddpo_sampler import DDPOSampler +from .dpo_sampler import DPOSampler \ No newline at end of file diff --git a/src/drlx/sampling/base.py b/src/drlx/sampling/base.py new file mode 100644 index 0000000..88cf118 --- /dev/null +++ b/src/drlx/sampling/base.py @@ -0,0 +1,89 @@ +from typing import Union, Iterable, Tuple, Any, Optional +from torchtyping import TensorType + +import torch +from tqdm import tqdm +import math +import einops as eo +import torch.nn.functional as F + +from drlx.utils import rescale_noise_cfg +from drlx.configs import SamplerConfig + +class Sampler: + """ + Generic class for sampling generations using a denoiser. Assumes LDMUnet + """ + def __init__(self, config : SamplerConfig = SamplerConfig()): + self.config = config + + def cfg_rescale(self, pred : TensorType["2 * b", "c", "h", "w"]): + """ + Applies classifier free guidance to prediction and rescales if cfg_rescaling is enabled + + :param pred: + Assumed to be batched repeated prediction with first half consisting of + unconditioned (empty token) predictions and second half being conditioned + predictions + """ + + pred_uncond, pred_cond = pred.chunk(2) + pred = pred_uncond + self.config.guidance_scale * (pred_cond - pred_uncond) + + if self.config.guidance_rescale is not None: + pred = rescale_noise_cfg(pred, pred_cond, self.config.guidance_rescale) + + return pred + + @torch.no_grad() + def sample(self, prompts : Iterable[str], denoiser, device = None, show_progress : bool = False, accelerator = None): + """ + Samples latents given some prompts and a denoiser + + :param prompts: Text prompts for image generation (to condition denoiser) + :param denoiser: Model to use for denoising + :param device: Device on which to perform model inference + :param show_progress: Whether to display a progress bar for the sampling steps + :param accelerator: Accelerator object for accelerated training (optional) + + :return: Latents unless postprocess flag is set to true in config, in which case VAE decoded latents are returned (i.e. 
images) + """ + if accelerator is None: + denoiser_unwrapped = denoiser + else: + denoiser_unwrapped = accelerator.unwrap_model(denoiser) + + scheduler = denoiser_unwrapped.scheduler + preprocess = denoiser_unwrapped.preprocess + noise_shape = denoiser_unwrapped.get_input_shape() + + text_embeds = preprocess( + prompts, mode = "embeds", device = device, + num_images_per_prompt = 1, + do_classifier_free_guidance = self.config.guidance_scale > 1.0 + ).detach() + + scheduler.set_timesteps(self.config.num_inference_steps, device = device) + latents = torch.randn(len(prompts), *noise_shape, device = device) + + for i, t in enumerate(tqdm(scheduler.timesteps), disable = not show_progress): + input = torch.cat([latents] * 2) + input = scheduler.scale_model_input(input, t) + + pred = denoiser( + pixel_values=input, + time_step = t, + text_embeds = text_embeds + ) + + # guidance + pred = self.cfg_rescale(pred) + + # step backward + scheduler_out = scheduler.step(pred, t, latents, self.config.eta) + latents = scheduler_out.prev_sample + + if self.config.postprocess: + return denoiser_unwrapped.postprocess(latents) + else: + return latents \ No newline at end of file diff --git a/src/drlx/sampling/ddpo_sampler.py b/src/drlx/sampling/ddpo_sampler.py index affee9a..d0a9049 100644 --- a/src/drlx/sampling/ddpo_sampler.py +++ b/src/drlx/sampling/ddpo_sampler.py @@ -6,7 +6,7 @@ from tqdm import tqdm import math -from drlx.sampling import Sampler +from drlx.sampling.base import Sampler from drlx.configs import DDPOConfig class DDPOSampler(Sampler): @@ -82,13 +82,16 @@ def sample( do_classifier_free_guidance = self.config.guidance_scale > 1.0 ).detach() + # If not SDXL, we assume encode prompts gave normal and negative embeds, which we concat + text_embeds = torch.cat([text_embeds[1], text_embeds[0]]) + scheduler.set_timesteps(self.config.num_inference_steps, device = device) latents = torch.randn(len(prompts), *noise_shape, device = device) all_step_preds, all_log_probs = [latents], [] for t in tqdm(scheduler.timesteps, disable = not show_progress): - latent_input = torch.cat([latents] * 2) + latent_input = torch.cat([latents] * 2) # Double for CFG latent_input = scheduler.scale_model_input(latent_input, t) pred = denoiser( diff --git a/src/drlx/sampling/dpo_sampler.py b/src/drlx/sampling/dpo_sampler.py index 35e1040..d0036b5 100644 --- a/src/drlx/sampling/dpo_sampler.py +++ b/src/drlx/sampling/dpo_sampler.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import einops as eo -from drlx.sampling import Sampler +from drlx.sampling.base import Sampler from drlx.configs import DPOConfig class DPOSampler(Sampler): @@ -25,6 +25,7 @@ def compute_loss( scheduler = accelerator.unwrap_model(denoiser).scheduler preprocess = accelerator.unwrap_model(denoiser).preprocess + sdxl_flag = accelerator.unwrap_model(denoiser).sdxl_flag encode = accelerator.unwrap_model(vae).encode beta = method_config.beta @@ -38,6 +39,10 @@ def compute_loss( do_classifier_free_guidance = self.config.guidance_scale > 1.0 ).detach() + # The value returned above varies depending on model + # With most models its two values, positive and negative prompts + # With DPO we don't care about CFG, so just only get the positive prompts + chosen_latent = encode(chosen_img).latent_dist.sample() rejected_latent = encode(rejected_img).latent_dist.sample() From 44f163d52883f634e6275842ca0611634623e056 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Tue, 13 Feb 2024 03:58:10 +0000 Subject: [PATCH 25/29] Add SDXL support and reorganize config for model --- 
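A note on the SDXL conditioning this patch introduces: the SDXL UNet expects packed time_ids (original size, crop top-left, target size) alongside pooled text embeds. The get_time_ids helper added below derives them from the image batch alone. A toy check of the two possible layouts, assuming a hypothetical 512x512 batch; as far as I can tell, diffusers' SDXL UNet flattens time_ids before projecting them, so the flat (b d) grouping used below and the conventional [b, 6] layout feed the same values:

import torch
import einops as eo

batch = torch.zeros(2, 3, 512, 512)                      # hypothetical image batch
b, c, h, w = batch.shape
add_time_ids = torch.tensor([h, w, 0, 0, h, w], dtype=batch.dtype)
flat = eo.repeat(add_time_ids, 'd -> (b d)', b=b)        # shape (12,), the layout below
per_sample = eo.repeat(add_time_ids, 'd -> b d', b=b)    # shape (2, 6), the conventional layout
assert flat.reshape(b, -1).equal(per_sample)             # same values either way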
src/drlx/configs.py | 2 +- src/drlx/denoisers/ldm_unet.py | 11 +++++------ src/drlx/sampling/ddpo_sampler.py | 8 ++++++-- src/drlx/sampling/dpo_sampler.py | 23 +++++++++++++++++++---- src/drlx/utils/sdxl.py | 14 ++++++++++++++ 5 files changed, 45 insertions(+), 13 deletions(-) create mode 100644 src/drlx/utils/sdxl.py diff --git a/src/drlx/configs.py b/src/drlx/configs.py index 2b4f9ee..4912bc1 100644 --- a/src/drlx/configs.py +++ b/src/drlx/configs.py @@ -160,7 +160,7 @@ class TrainConfig(ConfigClass): num_epochs: int = 50 total_samples: int = None num_samples_per_epoch: int = 256 - grad_clip: float = 1.0 + grad_clip: float = -1 checkpoint_interval: int = 10 checkpoint_path: str = "checkpoints" seed: int = 0 diff --git a/src/drlx/denoisers/ldm_unet.py b/src/drlx/denoisers/ldm_unet.py index 222024c..f92253b 100644 --- a/src/drlx/denoisers/ldm_unet.py +++ b/src/drlx/denoisers/ldm_unet.py @@ -71,10 +71,7 @@ def from_pretrained_pipeline(self, cls : Type, path : str): """ kwargs = self.config.pipeline_kwargs - if kwargs['variant'] == "fp16": - kwargs['torch_dtype'] = torch.float16 - else: - kwargs["torch_dtype"] = torch.float32 + kwargs["torch_dtype"] = torch.float32 pipe = cls.from_pretrained(path, **kwargs) @@ -166,7 +163,8 @@ def forward( time_step : Union[TensorType["batch"], int], # Note diffusers tyically does 999->0 as steps input_ids : TensorType["batch", "seq_len"] = None, attention_mask : TensorType["batch", "seq_len"] = None, - text_embeds : TensorType["batch", "d"] = None + text_embeds : TensorType["batch", "d"] = None, + added_cond_kwargs = {} ) -> TensorType["batch", "channels", "height", "width"]: """ For text conditioned UNET, inputs are assumed to be: @@ -179,7 +177,8 @@ def forward( return self.unet( pixel_values, time_step, - encoder_hidden_states = text_embeds + encoder_hidden_states = text_embeds, + added_cond_kwargs = added_cond_kwargs ).sample @property diff --git a/src/drlx/sampling/ddpo_sampler.py b/src/drlx/sampling/ddpo_sampler.py index d0a9049..8439e2b 100644 --- a/src/drlx/sampling/ddpo_sampler.py +++ b/src/drlx/sampling/ddpo_sampler.py @@ -74,16 +74,20 @@ def sample( scheduler = denoiser_unwrapped.scheduler preprocess = denoiser_unwrapped.preprocess + sdxl_flag = denoiser_unwrapped.sdxl_flag noise_shape = denoiser_unwrapped.get_input_shape() text_embeds = preprocess( prompts, mode = "embeds", device = device, num_images_per_prompt = 1, do_classifier_free_guidance = self.config.guidance_scale > 1.0 - ).detach() + ) # If not SDXL, we assume encode prompts gave normal and negative embeds, which we concat - text_embeds = torch.cat([text_embeds[1], text_embeds[0]]) + if sdxl_flag: + pass # TODO: SDXL Support for DDPO + else: + text_embeds = torch.cat([text_embeds[1], text_embeds[0]]).detach() scheduler.set_timesteps(self.config.num_inference_steps, device = device) latents = torch.randn(len(prompts), *noise_shape, device = device) diff --git a/src/drlx/sampling/dpo_sampler.py b/src/drlx/sampling/dpo_sampler.py index d0036b5..d3c8b03 100644 --- a/src/drlx/sampling/dpo_sampler.py +++ b/src/drlx/sampling/dpo_sampler.py @@ -4,6 +4,7 @@ from drlx.sampling.base import Sampler from drlx.configs import DPOConfig +from drlx.utils.sdxl import get_time_ids class DPOSampler(Sampler): def compute_loss( @@ -37,11 +38,17 @@ def compute_loss( prompts, mode = "embeds", device = device, num_images_per_prompt = 1, do_classifier_free_guidance = self.config.guidance_scale > 1.0 - ).detach() + ) # The value returned above varies depending on model # With most models its two values, positive 
and negative prompts # With DPO we don't care about CFG, so just only get the positive prompts + added_cond_kwargs = {} + if sdxl_flag: + added_cond_kwargs['text_embeds'] = text_embeds[2].detach() # Pooled prompt embeds + added_cond_kwargs['time_ids'] = get_time_ids(chosen_img) + + text_embeds = text_embeds[0].detach() chosen_latent = encode(chosen_img).latent_dist.sample() rejected_latent = encode(rejected_img).latent_dist.sample() @@ -66,6 +73,11 @@ def double_down(x): timesteps = double_up(timesteps) noise = double_up(noise) text_embeds = double_up(text_embeds) + + if sdxl_flag: + added_cond_kwargs['text_embeds'] = double_up(added_cond_kwargs['text_embeds']) + added_cond_kwargs['time_ids'] = double_up(added_cond_kwargs['time_ids']) + latent = torch.cat([chosen_latent, rejected_latent]) noisy_inputs = scheduler.add_noise( @@ -94,7 +106,8 @@ def split_mse(pred, target): pred = denoiser( pixel_values = noisy_inputs, time_step = timesteps, - text_embeds = text_embeds + text_embeds = text_embeds, + added_cond_kwargs = added_cond_kwargs ) model_diff, base_loss = split_mse(pred, target) @@ -106,7 +119,8 @@ def split_mse(pred, target): ref_pred = denoiser( pixel_values = noisy_inputs, time_step = timesteps, - text_embeds = text_embeds + text_embeds = text_embeds, + added_cond_kwargs = added_cond_kwargs ) ref_diff, ref_loss = split_mse(ref_pred, target) @@ -115,7 +129,8 @@ def split_mse(pred, target): ref_inputs = { "sample" : noisy_inputs.half() if ref_strategy == "half" else noisy_inputs, "timestep" : timesteps, - "encoder_hidden_states" : text_embeds.half() if ref_strategy == "half" else text_embeds + "encoder_hidden_states" : text_embeds.half() if ref_strategy == "half" else text_embeds, + "added_cond_kwargs" : added_cond_kwargs } ref_pred = ref_denoiser(**ref_inputs).sample ref_diff, ref_loss = split_mse(ref_pred, target) diff --git a/src/drlx/utils/sdxl.py b/src/drlx/utils/sdxl.py new file mode 100644 index 0000000..205d24d --- /dev/null +++ b/src/drlx/utils/sdxl.py @@ -0,0 +1,14 @@ +import einops as eo +import torch + +def get_time_ids(batch): + """ + Computes time ids needed for SDXL in a heavily simplified manner that only requires image size + (assumes square images). Assumes crop top left is (0,0) for all images. Infers all needed info from batch of images. 
+ """ + + b, c, h, w = batch.shape + + # input_size, crop, input_size + add_time_ids = torch.tensor([h, w, 0, 0, h, w], device = batch.device, dtype = batch.dtype) + return eo.repeat(add_time_ids, 'd -> (b d)', b = b) From 565efe605f03581692567c77580a1ec796683796 Mon Sep 17 00:00:00 2001 From: shahbuland Date: Tue, 13 Feb 2024 04:01:41 +0000 Subject: [PATCH 26/29] Remove mandatory gradient clipping and fix model saving with new config --- src/drlx/trainer/base_accelerate.py | 2 +- src/drlx/trainer/dpo_trainer.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/drlx/trainer/base_accelerate.py b/src/drlx/trainer/base_accelerate.py index 1d85881..daa49c9 100644 --- a/src/drlx/trainer/base_accelerate.py +++ b/src/drlx/trainer/base_accelerate.py @@ -125,7 +125,7 @@ def save_pretrained(self, fp : str): StableDiffusionPipeline.save_lora_weights(fp, unet_lora_layers=unet_lora_state_dict, safe_serialization = unwrapped_model.config.use_safetensors) else: self.pipe.unet = unwrapped_model.unet - self.pipe.save_pretrained(fp, safe_serialization = unwrapped_model.config.use_safetensors) + self.pipe.save_pretrained(fp, safe_serialization = unwrapped_model.config.pipeline_kwargs['use_safetensors']) self.accelerator.wait_for_everyone() def extract_pipeline(self): diff --git a/src/drlx/trainer/dpo_trainer.py b/src/drlx/trainer/dpo_trainer.py index bed859a..6033ebd 100644 --- a/src/drlx/trainer/dpo_trainer.py +++ b/src/drlx/trainer/dpo_trainer.py @@ -20,7 +20,7 @@ from PIL import Image from copy import deepcopy -from diffusers import StableDiffusionPipeline +from diffusers import DiffusionPipeline class DPOTrainer(AcceleratedTrainer): """ @@ -45,7 +45,7 @@ def setup_model(self): """ model = self.get_arch(self.config)(self.config.model, sampler = DPOSampler(self.config.sampler)) if self.config.model.model_path is not None: - model, pipe = model.from_pretrained_pipeline(StableDiffusionPipeline, self.config.model.model_path) + model, pipe = model.from_pretrained_pipeline(DiffusionPipeline, self.config.model.model_path) self.pipe = pipe self.pipe.set_progress_bar_config(disable=True) @@ -155,10 +155,11 @@ def time_per_1k(n_samples : int): self.accelerator.wait_for_everyone() # Optimizer step - self.accelerator.clip_grad_norm_( - filter(lambda p: p.requires_grad, self.model.parameters()), - self.config.train.grad_clip - ) + if self.config.train.grad_clip > 0: + self.accelerator.clip_grad_norm_( + filter(lambda p: p.requires_grad, self.model.parameters()), + self.config.train.grad_clip + ) self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() From 4324932e40eb4be202ec52b63a5cb561fb8b9cbd Mon Sep 17 00:00:00 2001 From: Shahbuland Matiana <44281577+shahbuland@users.noreply.github.com> Date: Mon, 12 Feb 2024 23:07:34 -0500 Subject: [PATCH 27/29] Update dpo_pickapic.yml --- configs/dpo_pickapic.yml | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/configs/dpo_pickapic.yml b/configs/dpo_pickapic.yml index c11b02e..768d1ac 100644 --- a/configs/dpo_pickapic.yml +++ b/configs/dpo_pickapic.yml @@ -2,11 +2,16 @@ method: name : "DPO" model: - model_path: "stabilityai/stable-diffusion-2-1-base" + model_path: "runwayml/stable-diffusion-v1-5" + pipeline_kwargs: + use_safetensors: True + variant: "fp16" + sdxl: False model_arch_type: "LDMUnet" attention_slicing: True xformers_memory_efficient: False gradient_checkpointing: True + sampler: guidance_scale: 7.5 @@ -15,7 +20,7 @@ sampler: optimizer: name: "adamw" kwargs: - lr: 1.0e-5 + lr: 
2.048e-8 weight_decay: 1.0e-4 betas: [0.9, 0.999] @@ -27,15 +32,14 @@ scheduler: logging: run_name: 'dpo_pickapic' - #wandb_entity: None + #wandb_entity: None #wandb_project: None train: num_epochs: 500 num_samples_per_epoch: 256 batch_size: 4 - sample_batch_size: 32 - grad_clip: 1.0 - checkpoint_interval: 50 + target_batch: 256 + checkpoint_interval: 640 tf32: True - suppress_log_keywords: "diffusers.pipelines,transformers" \ No newline at end of file + suppress_log_keywords: "diffusers.pipelines,transformers" From dde1265a084725ea6b0d3a9950d0cc48e66c2ff1 Mon Sep 17 00:00:00 2001 From: Shahbuland Matiana <44281577+shahbuland@users.noreply.github.com> Date: Mon, 12 Feb 2024 23:13:56 -0500 Subject: [PATCH 28/29] Update dpo_pickapic.yml --- configs/dpo_pickapic.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/dpo_pickapic.yml b/configs/dpo_pickapic.yml index 768d1ac..9059946 100644 --- a/configs/dpo_pickapic.yml +++ b/configs/dpo_pickapic.yml @@ -2,7 +2,7 @@ method: name : "DPO" model: - model_path: "runwayml/stable-diffusion-v1-5" + model_path: "stabilityai/stable-diffusion-2-1" pipeline_kwargs: use_safetensors: True variant: "fp16" From 70f182749a629288e6e6b40f944623f77c46c6d0 Mon Sep 17 00:00:00 2001 From: Shahbuland Matiana <44281577+shahbuland@users.noreply.github.com> Date: Mon, 12 Feb 2024 23:14:04 -0500 Subject: [PATCH 29/29] Update dpo_pickapic.yml --- configs/dpo_pickapic.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/dpo_pickapic.yml b/configs/dpo_pickapic.yml index 9059946..b65e562 100644 --- a/configs/dpo_pickapic.yml +++ b/configs/dpo_pickapic.yml @@ -38,7 +38,7 @@ logging: train: num_epochs: 500 num_samples_per_epoch: 256 - batch_size: 4 + batch_size: 1 target_batch: 256 checkpoint_interval: 640 tf32: True
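Once a run with this final config finishes, save_pretrained writes a standard diffusers pipeline under ./output/<run_name>. A hypothetical usage sketch for sampling from the finetuned model; the prompt and dtype are placeholders:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "./output/dpo_pickapic",   # matches run_name in the config above
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("sample.png")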