From a9adf2238692915c7afe781debc38e51e64de7a8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 17 Nov 2024 08:31:21 +0530 Subject: [PATCH 01/30] dataprep. --- training/__init__.py | 0 training/mochi-1/prepare_dataset.py | 673 ++++++++++++++++++++++++++++ 2 files changed, 673 insertions(+) create mode 100644 training/__init__.py create mode 100644 training/mochi-1/prepare_dataset.py diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/training/mochi-1/prepare_dataset.py b/training/mochi-1/prepare_dataset.py new file mode 100644 index 00000000..6e5fea9f --- /dev/null +++ b/training/mochi-1/prepare_dataset.py @@ -0,0 +1,673 @@ +#!/usr/bin/env python3 + +import argparse +import functools +import json +import os +import pathlib +import queue +import traceback +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from diffusers import AutoencoderKL +from diffusers.training_utils import set_seed +from diffusers.utils import export_to_video, get_logger +from torch.utils.data import DataLoader +from torchvision import transforms +from tqdm import tqdm +from transformers import T5EncoderModel, T5Tokenizer + + +import decord # isort:skip + +from ..dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip + + +decord.bridge.set_bridge("torch") + +logger = get_logger(__name__) + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def check_height(x: Any) -> int: + x = int(x) + if x % 16 != 0: + raise argparse.ArgumentTypeError( + f"`--height_buckets` must be divisible by 16, but got {x} which does not fit criteria." + ) + return x + + +def check_width(x: Any) -> int: + x = int(x) + if x % 16 != 0: + raise argparse.ArgumentTypeError( + f"`--width_buckets` must be divisible by 16, but got {x} which does not fit criteria." + ) + return x + + +def check_frames(x: Any) -> int: + x = int(x) + if x % 4 != 0 and x % 4 != 1: + raise argparse.ArgumentTypeError( + f"`--frames_buckets` must be of form `4 * k` or `4 * k + 1`, but got {x} which does not fit criteria." + ) + return x + + +def get_args() -> Dict[str, Any]: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_id", + type=str, + default="genmo/mochi-1-preview", + help="Hugging Face model ID to use for tokenizer, text encoder and VAE.", + ) + parser.add_argument("--data_root", type=str, required=True, help="Path to where training data is located.") + parser.add_argument( + "--dataset_file", type=str, default=None, help="Path to CSV file containing metadata about training data." + ) + parser.add_argument( + "--caption_column", + type=str, + default="caption", + help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the captions. If using the folder structure format for data loading, this should be the name of the file containing line-separated captions (the file should be located in `--data_root`).", + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the video paths. If using the folder structure format for data loading, this should be the name of the file containing line-separated video paths (the file should be located in `--data_root`).", + ) + parser.add_argument( + "--id_token", + type=str, + default=None, + help="Identifier token appended to the start of each prompt if provided.", + ) + parser.add_argument( + "--height_buckets", + nargs="+", + type=check_height, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--width_buckets", + nargs="+", + type=check_width, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--frame_buckets", + nargs="+", + type=check_frames, + default=[49], + ) + parser.add_argument( + "--random_flip", + type=float, + default=None, + help="If random horizontal flip augmentation is to be used, this should be the flip probability.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + parser.add_argument( + "--video_reshape_mode", + type=str, + default=None, + help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", + ) + parser.add_argument( + "--save_image_latents", + action="store_true", + help="Whether or not to encode and store image latents, which are required for image-to-video finetuning. The image latents are the first frame of input videos encoded with the VAE.", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to output directory where preprocessed videos/latents/embeddings will be saved.", + ) + parser.add_argument("--max_num_frames", type=int, default=49, help="Maximum number of frames in output video.") + parser.add_argument( + "--max_sequence_length", type=int, default=226, help="Max sequence length of prompt embeddings." + ) + parser.add_argument("--target_fps", type=int, default=8, help="Frame rate of output videos.") + parser.add_argument( + "--save_latents_and_embeddings", + action="store_true", + help="Whether to encode videos/captions to latents/embeddings and save them in pytorch serializable format.", + ) + parser.add_argument( + "--use_slicing", + action="store_true", + help="Whether to enable sliced encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", + ) + parser.add_argument( + "--use_tiling", + action="store_true", + help="Whether to enable tiled encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", + ) + parser.add_argument("--batch_size", type=int, default=1, help="Number of videos to process at once in the VAE.") + parser.add_argument( + "--num_decode_threads", + type=int, + default=0, + help="Number of decoding threads for `decord` to use. The default `0` means to automatically determine required number of threads.", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="fp32", + help="Data type to use when generating latents and prompt embeddings.", + ) + parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.") + parser.add_argument( + "--num_artifact_workers", type=int, default=4, help="Number of worker threads for serializing artifacts." + ) + return parser.parse_args() + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompts: List[str], + max_sequence_length: int, + device: torch.device, + dtype: torch.dtype, + requires_grad: bool = False, +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompts, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompts, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds + + +to_pil_image = transforms.ToPILImage(mode="RGB") + + +def save_image(image: torch.Tensor, path: pathlib.Path) -> None: + image = to_pil_image(image) + image.save(path) + + +def save_video(video: torch.Tensor, path: pathlib.Path, fps: int = 8) -> None: + video = [to_pil_image(frame) for frame in video] + export_to_video(video, path, fps=fps) + + +def save_prompt(prompt: str, path: pathlib.Path) -> None: + with open(path, "w", encoding="utf-8") as file: + file.write(prompt) + + +def save_metadata(metadata: Dict[str, Any], path: pathlib.Path) -> None: + with open(path, "w", encoding="utf-8") as file: + file.write(json.dumps(metadata)) + + +@torch.no_grad() +def serialize_artifacts( + batch_size: int, + fps: int, + images_dir: Optional[pathlib.Path] = None, + image_latents_dir: Optional[pathlib.Path] = None, + videos_dir: Optional[pathlib.Path] = None, + video_latents_dir: Optional[pathlib.Path] = None, + prompts_dir: Optional[pathlib.Path] = None, + prompt_embeds_dir: Optional[pathlib.Path] = None, + images: Optional[torch.Tensor] = None, + image_latents: Optional[torch.Tensor] = None, + videos: Optional[torch.Tensor] = None, + video_latents: Optional[torch.Tensor] = None, + prompts: Optional[List[str]] = None, + prompt_embeds: Optional[torch.Tensor] = None, +) -> None: + num_frames, height, width = videos.size(1), videos.size(3), videos.size(4) + metadata = [{"num_frames": num_frames, "height": height, "width": width}] + + data_folder_mapper_list = [ + (images, images_dir, lambda img, path: save_image(img[0], path), "png"), + (image_latents, image_latents_dir, torch.save, "pt"), + (videos, videos_dir, functools.partial(save_video, fps=fps), "mp4"), + (video_latents, video_latents_dir, torch.save, "pt"), + (prompts, prompts_dir, save_prompt, "txt"), + (prompt_embeds, prompt_embeds_dir, torch.save, "pt"), + (metadata, videos_dir, save_metadata, "txt"), + ] + filenames = [uuid.uuid4() for _ in range(batch_size)] + + for data, folder, save_fn, extension in data_folder_mapper_list: + if data is None: + continue + for slice, filename in zip(data, filenames): + if isinstance(slice, torch.Tensor): + slice = slice.clone().to("cpu") + path = folder.joinpath(f"{filename}.{extension}") + save_fn(slice, path) + + +def save_intermediates(output_queue: queue.Queue) -> None: + while True: + try: + item = output_queue.get(timeout=30) + if item is None: + break + serialize_artifacts(**item) + + except queue.Empty: + continue + + +@torch.no_grad() +def main(): + args = get_args() + set_seed(args.seed) + + output_dir = pathlib.Path(args.output_dir) + tmp_dir = output_dir.joinpath("tmp") + + output_dir.mkdir(parents=True, exist_ok=True) + tmp_dir.mkdir(parents=True, exist_ok=True) + + # Create task queue for non-blocking serializing of artifacts + output_queue = queue.Queue() + save_thread = ThreadPoolExecutor(max_workers=args.num_artifact_workers) + save_future = save_thread.submit(save_intermediates, output_queue) + + # Initialize distributed processing + if "LOCAL_RANK" in os.environ: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl") + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + # Single GPU + local_rank = 0 + world_size = 1 + rank = 0 + torch.cuda.set_device(rank) + + # Create folders where intermediate tensors from each rank will be saved + images_dir = tmp_dir.joinpath(f"images/{rank}") + image_latents_dir = tmp_dir.joinpath(f"image_latents/{rank}") + videos_dir = tmp_dir.joinpath(f"videos/{rank}") + video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}") + prompts_dir = tmp_dir.joinpath(f"prompts/{rank}") + prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}") + + images_dir.mkdir(parents=True, exist_ok=True) + image_latents_dir.mkdir(parents=True, exist_ok=True) + videos_dir.mkdir(parents=True, exist_ok=True) + video_latents_dir.mkdir(parents=True, exist_ok=True) + prompts_dir.mkdir(parents=True, exist_ok=True) + prompt_embeds_dir.mkdir(parents=True, exist_ok=True) + + weight_dtype = DTYPE_MAPPING[args.dtype] + target_fps = args.target_fps + + if weight_dtype is not None: + weight_dtype = torch.float32 + print("To get the best results, we set `weight_dtype` to `torch.float32`.") + + # 1. Dataset + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": False, + "random_flip": args.random_flip, + "image_to_video": args.save_image_latents, + } + if args.video_reshape_mode is None: + dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + original_dataset_size = len(dataset) + + # Split data among GPUs + if world_size > 1: + samples_per_gpu = original_dataset_size // world_size + start_index = rank * samples_per_gpu + end_index = start_index + samples_per_gpu + if rank == world_size - 1: + end_index = original_dataset_size # Make sure the last GPU gets the remaining data + + # Slice the data + dataset.prompts = dataset.prompts[start_index:end_index] + dataset.video_paths = dataset.video_paths[start_index:end_index] + else: + pass + + rank_dataset_size = len(dataset) + + # 2. Dataloader + def collate_fn(data): + prompts = [x["prompt"] for x in data[0]] + + images = None + if args.save_image_latents: + images = [x["image"] for x in data[0]] + images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True) + + return { + "images": images, + "videos": videos, + "prompts": prompts, + } + + dataloader = DataLoader( + dataset, + batch_size=1, + sampler=BucketSampler(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + ) + + # 3. Prepare models + device = f"cuda:{rank}" + + if args.save_latents_and_embeddings: + tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained( + args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype + ) + text_encoder = text_encoder.to(device) + + vae = AutoencoderKL.from_pretrained(args.model_id, subfolder="vae", torch_dtype=weight_dtype) + vae = vae.to(device) + + if args.use_slicing: + vae.enable_slicing() + if args.use_tiling: + vae.enable_tiling() + + # 4. Compute latents and embeddings and save + if rank == 0: + iterator = tqdm( + dataloader, desc="Encoding", total=(rank_dataset_size + args.batch_size - 1) // args.batch_size + ) + else: + iterator = dataloader + + for step, batch in enumerate(iterator): + try: + images = None + image_latents = None + video_latents = None + prompt_embeds = None + + if args.save_image_latents: + images = batch["images"].to(device, non_blocking=True) + images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + + videos = batch["videos"].to(device, non_blocking=True) + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + + prompts = batch["prompts"] + + # Encode videos & images + # we run under autocast following the official recommendations of Mochi + with torch.autocast(device, torch.bfloat16, cache_enabled=False): + if args.save_latents_and_embeddings: + if args.use_slicing: + if args.save_image_latents: + encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)] + image_latents = torch.cat(encoded_slices) + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)] + video_latents = torch.cat(encoded_slices) + + else: + if args.save_image_latents: + image_latents = vae._encode(images) + image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + video_latents = vae._encode(videos) + + video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + # Encode prompts + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + args.max_sequence_length, + device, + weight_dtype, + requires_grad=False, + ) + + if images is not None: + images = (images.permute(0, 2, 1, 3, 4) + 1) / 2 + + videos = (videos.permute(0, 2, 1, 3, 4) + 1) / 2 + + output_queue.put( + { + "batch_size": len(prompts), + "fps": target_fps, + "images_dir": images_dir, + "image_latents_dir": image_latents_dir, + "videos_dir": videos_dir, + "video_latents_dir": video_latents_dir, + "prompts_dir": prompts_dir, + "prompt_embeds_dir": prompt_embeds_dir, + "images": images, + "image_latents": image_latents, + "videos": videos, + "video_latents": video_latents, + "prompts": prompts, + "prompt_embeds": prompt_embeds, + } + ) + + except Exception: + print("-------------------------") + print(f"An exception occurred while processing data: {rank=}, {world_size=}, {step=}") + traceback.print_exc() + print("-------------------------") + + # 5. Complete distributed processing + if world_size > 1: + dist.barrier() + dist.destroy_process_group() + + output_queue.put(None) + save_thread.shutdown(wait=True) + save_future.result() + + # 6. Combine results from each rank + if rank == 0: + print( + f"Completed preprocessing latents and embeddings. Temporary files from all ranks saved to `{tmp_dir.as_posix()}`" + ) + + # Move files from each rank to common directory + for subfolder, extension in [ + ("images", "png"), + ("image_latents", "pt"), + ("videos", "mp4"), + ("video_latents", "pt"), + ("prompts", "txt"), + ("prompt_embeds", "pt"), + ("videos", "txt"), + ]: + tmp_subfolder = tmp_dir.joinpath(subfolder) + combined_subfolder = output_dir.joinpath(subfolder) + combined_subfolder.mkdir(parents=True, exist_ok=True) + pattern = f"*.{extension}" + + for file in tmp_subfolder.rglob(pattern): + file.replace(combined_subfolder / file.name) + + # Remove temporary directories + def rmdir_recursive(dir: pathlib.Path) -> None: + for child in dir.iterdir(): + if child.is_file(): + child.unlink() + else: + rmdir_recursive(child) + dir.rmdir() + + rmdir_recursive(tmp_dir) + + # Combine prompts and videos into individual text files and single jsonl + prompts_folder = output_dir.joinpath("prompts") + prompts = [] + stems = [] + + for filename in prompts_folder.rglob("*.txt"): + with open(filename, "r") as file: + prompts.append(file.read().strip()) + stems.append(filename.stem) + + prompts_txt = output_dir.joinpath("prompts.txt") + videos_txt = output_dir.joinpath("videos.txt") + data_jsonl = output_dir.joinpath("data.jsonl") + + with open(prompts_txt, "w") as file: + for prompt in prompts: + file.write(f"{prompt}\n") + + with open(videos_txt, "w") as file: + for stem in stems: + file.write(f"videos/{stem}.mp4\n") + + with open(data_jsonl, "w") as file: + for prompt, stem in zip(prompts, stems): + video_metadata_txt = output_dir.joinpath(f"videos/{stem}.txt") + with open(video_metadata_txt, "r", encoding="utf-8") as metadata_file: + metadata = json.loads(metadata_file.read()) + + data = { + "prompt": prompt, + "prompt_embed": f"prompt_embeds/{stem}.pt", + "image": f"images/{stem}.png", + "image_latent": f"image_latents/{stem}.pt", + "video": f"videos/{stem}.mp4", + "video_latent": f"video_latents/{stem}.pt", + "metadata": metadata, + } + file.write(json.dumps(data) + "\n") + + print(f"Completed preprocessing. All files saved to `{output_dir.as_posix()}`") + + +if __name__ == "__main__": + main() From 5ba510e0612e0baac7de5b9667bbc85609f2a6ed Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 17 Nov 2024 10:18:20 +0530 Subject: [PATCH 02/30] updates --- training/dataset.py | 56 ++++++++++++++++------------- training/mochi-1/prepare_dataset.py | 18 +++++----- training/mochi-1/prepare_dataset.sh | 48 +++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 33 deletions(-) create mode 100644 training/mochi-1/prepare_dataset.sh diff --git a/training/dataset.py b/training/dataset.py index 7ace5c06..1d250dcb 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -22,8 +22,8 @@ logger = get_logger(__name__) HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] -WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] -FRAME_BUCKETS = [16, 24, 32, 48, 64, 80] +WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536] +FRAME_BUCKETS = [16, 24, 32, 48, 64, 80, 84] class VideoDataset(Dataset): @@ -144,17 +144,17 @@ def __getitem__(self, index: int) -> Dict[str, Any]: } else: image, video, _ = self._preprocess_video(self.video_paths[index]) - - return { - "prompt": self.id_token + self.prompts[index], - "image": image, - "video": video, - "video_metadata": { - "num_frames": video.shape[0], - "height": video.shape[2], - "width": video.shape[3], - }, - } + if video is not None: + return { + "prompt": self.id_token + self.prompts[index], + "image": image, + "video": video, + "video_metadata": { + "num_frames": video.shape[0], + "height": video.shape[2], + "width": video.shape[3], + }, + } def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: if not self.data_root.exists(): @@ -276,12 +276,15 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: else: video_reader = decord.VideoReader(uri=path.as_posix()) video_num_frames = len(video_reader) + nearest_frame_bucket = min( self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) ) - + if video_num_frames < nearest_frame_bucket: + # TODO: we could handle this by padding zero frames or duplicating the existing frames? + return None, None, None + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) - frames = video_reader.get_batch(frame_indices) frames = frames[:nearest_frame_bucket].float() frames = frames.permute(0, 3, 1, 2).contiguous() @@ -344,6 +347,8 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: nearest_frame_bucket = min( self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) ) + if video_num_frames < nearest_frame_bucket: + return None, None, None frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) @@ -404,16 +409,17 @@ def __len__(self): def __iter__(self): for index, data in enumerate(self.data_source): - video_metadata = data["video_metadata"] - f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] - - self.buckets[(f, h, w)].append(data) - if len(self.buckets[(f, h, w)]) == self.batch_size: - if self.shuffle: - random.shuffle(self.buckets[(f, h, w)]) - yield self.buckets[(f, h, w)] - del self.buckets[(f, h, w)] - self.buckets[(f, h, w)] = [] + if data is not None: + video_metadata = data["video_metadata"] + f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] + + self.buckets[(f, h, w)].append(data) + if len(self.buckets[(f, h, w)]) == self.batch_size: + if self.shuffle: + random.shuffle(self.buckets[(f, h, w)]) + yield self.buckets[(f, h, w)] + del self.buckets[(f, h, w)] + self.buckets[(f, h, w)] = [] if self.drop_last: return diff --git a/training/mochi-1/prepare_dataset.py b/training/mochi-1/prepare_dataset.py index 6e5fea9f..777f7aee 100644 --- a/training/mochi-1/prepare_dataset.py +++ b/training/mochi-1/prepare_dataset.py @@ -13,7 +13,7 @@ import torch import torch.distributed as dist -from diffusers import AutoencoderKL +from diffusers import AutoencoderKLMochi from diffusers.training_utils import set_seed from diffusers.utils import export_to_video, get_logger from torch.utils.data import DataLoader @@ -24,7 +24,9 @@ import decord # isort:skip -from ..dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +import sys +sys.path.append("..") +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip decord.bridge.set_bridge("torch") @@ -485,12 +487,12 @@ def collate_fn(data): tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer") text_encoder = T5EncoderModel.from_pretrained( args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype - ) - text_encoder = text_encoder.to(device) - - vae = AutoencoderKL.from_pretrained(args.model_id, subfolder="vae", torch_dtype=weight_dtype) - vae = vae.to(device) - + ).to(device) + + vae = AutoencoderKLMochi.from_pretrained( + args.model_id, subfolder="vae", torch_dtype=weight_dtype + ).to(device) + if args.use_slicing: vae.enable_slicing() if args.use_tiling: diff --git a/training/mochi-1/prepare_dataset.sh b/training/mochi-1/prepare_dataset.sh new file mode 100644 index 00000000..8081ac27 --- /dev/null +++ b/training/mochi-1/prepare_dataset.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +MODEL_ID="genmo/mochi-1-preview" + +NUM_GPUS=1 + +# For more details on the expected data format, please refer to the README. +DATA_ROOT="/home/sayak/cogvideox-factory/video-dataset-disney" # This needs to be the path to the base directory where your videos are located. +CAPTION_COLUMN="prompt.txt" +VIDEO_COLUMN="videos.txt" +OUTPUT_DIR="/home/sayak/cogvideox-factory/video-dataset-disney/mochi-1/preprocessed-dataset" +HEIGHT_BUCKETS="480" +WIDTH_BUCKETS="848" +FRAME_BUCKETS="84" +MAX_NUM_FRAMES="84" +MAX_SEQUENCE_LENGTH=256 +TARGET_FPS=30 +BATCH_SIZE=1 +DTYPE=fp32 + +# To create a folder-style dataset structure without pre-encoding videos and captions +# For Image-to-Video finetuning, make sure to pass `--save_image_latents` +CMD_WITHOUT_PRE_ENCODING="\ + torchrun --nproc_per_node=$NUM_GPUS \ + prepare_dataset.py \ + --model_id $MODEL_ID \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --output_dir $OUTPUT_DIR \ + --height_buckets $HEIGHT_BUCKETS \ + --width_buckets $WIDTH_BUCKETS \ + --frame_buckets $FRAME_BUCKETS \ + --max_num_frames $MAX_NUM_FRAMES \ + --max_sequence_length $MAX_SEQUENCE_LENGTH \ + --target_fps $TARGET_FPS \ + --batch_size $BATCH_SIZE \ + --dtype $DTYPE +" + +CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_latents_and_embeddings" + +# Select which you'd like to run +CMD=$CMD_WITH_PRE_ENCODING + +echo "===== Running \`$CMD\` =====" +eval $CMD +echo -ne "===== Finished running script =====\n" From 9852c3d994a6b83bcec28854e67effbfadf6f36d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Nov 2024 15:26:52 +0530 Subject: [PATCH 03/30] updates. --- training/mochi-1/args.py | 474 +++++++++++++ training/mochi-1/dataset.py | 444 ++++++++++++ training/mochi-1/deepspeed.yaml | 23 + training/mochi-1/prepare_dataset.py | 57 +- training/mochi-1/text_to_video_lora.py | 948 +++++++++++++++++++++++++ training/mochi-1/train.sh | 56 ++ 6 files changed, 1976 insertions(+), 26 deletions(-) create mode 100644 training/mochi-1/args.py create mode 100644 training/mochi-1/dataset.py create mode 100644 training/mochi-1/deepspeed.yaml create mode 100644 training/mochi-1/text_to_video_lora.py create mode 100644 training/mochi-1/train.sh diff --git a/training/mochi-1/args.py b/training/mochi-1/args.py new file mode 100644 index 00000000..25248e44 --- /dev/null +++ b/training/mochi-1/args.py @@ -0,0 +1,474 @@ +import argparse + + +def _get_model_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + +def _get_dataset_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--dataset_file", + type=str, + default=None, + help=("Path to a CSV file if loading prompts/video paths using this format."), + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.", + ) + parser.add_argument( + "--id_token", + type=str, + default=None, + help="Identifier token appended to the start of each prompt if provided.", + ) + parser.add_argument( + "--height_buckets", + nargs="+", + type=int, + default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--width_buckets", + nargs="+", + type=int, + default=[256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536], + ) + parser.add_argument( + "--frame_buckets", + nargs="+", + type=int, + default=[84], + ) + parser.add_argument( + "--load_tensors", + action="store_true", + help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.", + ) + parser.add_argument( + "--random_flip", + type=float, + default=None, + help="If random horizontal flip augmentation is to be used, this should be the flip probability.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Whether or not to use the pinned memory setting in pytorch dataloader.", + ) + + +def _get_validation_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.", + ) + parser.add_argument( + "--enable_model_cpu_offload", + action="store_true", + default=False, + help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", + ) + + +def _get_training_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.") + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.and an Nvidia Ampere GPU. " + "Default to the value of accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this " + "argument to override the accelerate config." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="mochi-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="All input videos are resized to this height.", + ) + parser.add_argument( + "--width", + type=int, + default=848, + help="All input videos are resized to this width.", + ) + parser.add_argument( + "--video_reshape_mode", + type=str, + default=None, + help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", + ) + parser.add_argument("--fps", type=int, default=30, help="All input videos will be used at this FPS.") + parser.add_argument( + "--max_num_frames", + type=int, + default=84, + help="All input videos will be truncated to these many frames.", + ) + parser.add_argument( + "--skip_frames_start", + type=int, + default=0, + help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", + ) + parser.add_argument( + "--skip_frames_end", + type=int, + default=0, + help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + parser.add_argument( + "--noised_image_dropout", + type=float, + default=0.05, + help="Image condition dropout probability when finetuning image-to-video.", + ) + + +def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "prodigy", "came"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit", + action="store_true", + help="Whether or not to use 8-bit optimizers from `bitsandbytes` or `bitsandbytes`.", + ) + parser.add_argument( + "--use_4bit", + action="store_true", + help="Whether or not to use 4-bit optimizers from `torchao`.", + ) + parser.add_argument( + "--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers." + ) + parser.add_argument( + "--beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--beta2", + type=float, + default=0.95, + help="The beta2 parameter for the Adam and Prodigy optimizers.", + ) + parser.add_argument( + "--beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", + ) + parser.add_argument( + "--prodigy_decouple", + action="store_true", + help="Use AdamW style decoupled weight decay.", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=1e-04, + help="Weight decay to use for optimizer.", + ) + parser.add_argument( + "--epsilon", + type=float, + default=1e-8, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--prodigy_use_bias_correction", + action="store_true", + help="Turn on Adam's bias correction.", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + action="store_true", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", + ) + parser.add_argument( + "--use_cpu_offload_optimizer", + action="store_true", + help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.", + ) + parser.add_argument( + "--offload_gradients", + action="store_true", + help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.", + ) + + +def _get_configuration_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--nccl_timeout", + type=int, + default=600, + help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.", + ) + parser.add_argument( + "--report_to", + type=str, + default=None, + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.") + + _get_model_args(parser) + _get_dataset_args(parser) + _get_training_args(parser) + _get_validation_args(parser) + _get_optimizer_args(parser) + _get_configuration_args(parser) + + return parser.parse_args() diff --git a/training/mochi-1/dataset.py b/training/mochi-1/dataset.py new file mode 100644 index 00000000..09b9e2ce --- /dev/null +++ b/training/mochi-1/dataset.py @@ -0,0 +1,444 @@ +import random +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +import torchvision.transforms as TT +from accelerate.logging import get_logger +from torch.utils.data import Dataset, Sampler +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import resize + + +# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error +# Very few bug reports but it happens. Look in decord Github issues for more relevant information. +import decord # isort:skip + +decord.bridge.set_bridge("torch") + +logger = get_logger(__name__) + +# TODO (sayakpaul): probably not all buckets are needed for Mochi-1? +HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536] +FRAME_BUCKETS = [16, 24, 32, 48, 64, 80, 84] + +VAE_SPATIAL_SCALE_FACTOR = 8 +VAE_TEMPORAL_SCALE_FACTOR = 6 + +class VideoDataset(Dataset): + def __init__( + self, + data_root: str, + dataset_file: Optional[str] = None, + caption_column: str = "text", + video_column: str = "video", + max_num_frames: int = 49, + id_token: Optional[str] = None, + height_buckets: List[int] = None, + width_buckets: List[int] = None, + frame_buckets: List[int] = None, + load_tensors: bool = False, + random_flip: Optional[float] = None, + image_to_video: bool = False, + ) -> None: + super().__init__() + + self.data_root = Path(data_root) + self.dataset_file = dataset_file + self.caption_column = caption_column + self.video_column = video_column + self.max_num_frames = max_num_frames + self.id_token = id_token or "" + self.height_buckets = height_buckets or HEIGHT_BUCKETS + self.width_buckets = width_buckets or WIDTH_BUCKETS + self.frame_buckets = frame_buckets or FRAME_BUCKETS + self.load_tensors = load_tensors + self.random_flip = random_flip + self.image_to_video = image_to_video + + self.resolutions = [ + (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets + ] + + # Two methods of loading data are supported. + # - Using a CSV: caption_column and video_column must be some column in the CSV. One could + # make use of other columns too, such as a motion score or aesthetic score, by modifying the + # logic in CSV processing. + # - Using two files containing line-separate captions and relative paths to videos. + # For a more detailed explanation about preparing dataset format, checkout the README. + if dataset_file is None: + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_local_path() + else: + ( + self.prompts, + self.video_paths, + ) = self._load_dataset_from_csv() + + if len(self.video_paths) != len(self.prompts): + raise ValueError( + f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." + ) + + self.video_transforms = transforms.Compose( + [ + transforms.RandomHorizontalFlip(random_flip) + if random_flip + else transforms.Lambda(self.identity_transform), + transforms.Lambda(self.scale_transform), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + @staticmethod + def identity_transform(x): + return x + + @staticmethod + def scale_transform(x): + return x / 255.0 + + def __len__(self) -> int: + return len(self.video_paths) + + def __getitem__(self, index: int) -> Dict[str, Any]: + if isinstance(index, list): + # Here, index is actually a list of data objects that we need to return. + # The BucketSampler should ideally return indices. But, in the sampler, we'd like + # to have information about num_frames, height and width. Since this is not stored + # as metadata, we need to read the video to get this information. You could read this + # information without loading the full video in memory, but we do it anyway. In order + # to not load the video twice (once to get the metadata, and once to return the loaded video + # based on sampled indices), we cache it in the BucketSampler. When the sampler is + # to yield, we yield the cache data instead of indices. So, this special check ensures + # that data is not loaded a second time. PRs are welcome for improvements. + return index + + if self.load_tensors: + image_latents, video_latents, prompt_embeds, prompt_attention_mask = self._preprocess_video(self.video_paths[index]) + + # This is hardcoded for now. + # Output of the VAE encoding is 2 * output_channels and then it's + # temporal compression factor is 6. Initially, the VAE encodings will have + # 24 latent number of frames. So, if we were to train with a + # max frame size of 84 and frame bucket of [84], we need to have the following logic. + # print(f"{video_latents.shape=}") + latent_num_frames = video_latents.size(0) + # print(f"{latent_num_frames=}") + num_frames = (latent_num_frames // 2) * (VAE_TEMPORAL_SCALE_FACTOR + 1) + + height = video_latents.size(2) * VAE_SPATIAL_SCALE_FACTOR + width = video_latents.size(3) * VAE_SPATIAL_SCALE_FACTOR + + return { + "prompt": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "image": image_latents, + "video": video_latents, + "video_metadata": { + "num_frames": num_frames, + "height": height, + "width": width, + }, + } + else: + image, video, _ = self._preprocess_video(self.video_paths[index]) + if video is not None: + return { + "prompt": self.id_token + self.prompts[index], + "image": image, + "video": video, + "video_metadata": { + "num_frames": video.shape[0], + "height": video.shape[2], + "width": video.shape[3], + }, + } + + def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: + if not self.data_root.exists(): + raise ValueError("Root folder for videos does not exist") + + prompt_path = self.data_root.joinpath(self.caption_column) + video_path = self.data_root.joinpath(self.video_column) + + if not prompt_path.exists() or not prompt_path.is_file(): + raise ValueError( + "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts." + ) + if not video_path.exists() or not video_path.is_file(): + raise ValueError( + "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory." + ) + + with open(prompt_path, "r", encoding="utf-8") as file: + prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] + with open(video_path, "r", encoding="utf-8") as file: + video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] + + if not self.load_tensors and any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]: + df = pd.read_csv(self.dataset_file) + prompts = df[self.caption_column].tolist() + video_paths = df[self.video_column].tolist() + video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths] + + if any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return prompts, video_paths + + def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + r""" + Loads a single video, or latent and prompt embedding, based on initialization parameters. + + If returning a video, returns a [F, C, H, W] video tensor, and None for the prompt embedding. Here, + F, C, H and W are the frames, channels, height and width of the input video. + + If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D]. + F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length + and embedding dimension of prompt embeddings. + """ + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + + indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames)) + frames = video_reader.get_batch(indices) + frames = frames[: self.max_num_frames].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]: + filename_without_ext = path.name.split(".")[0] + pt_filename = f"{filename_without_ext}.pt" + + # The current path is something like: /a/b/c/d/videos/00001.mp4 + # We need to reach: /a/b/c/d/video_latents/00001.pt + image_latents_path = path.parent.parent.joinpath("image_latents") + video_latents_path = path.parent.parent.joinpath("video_latents") + embeds_path = path.parent.parent.joinpath("prompt_embeds") + attention_mask_path = path.parent.parent.joinpath("prompt_attention_mask") + + if ( + not video_latents_path.exists() + or not embeds_path.exists() + or not attention_mask_path.exists() + or (self.image_to_video and not image_latents_path.exists()) + ): + raise ValueError( + f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains three folders named `video_latents`, `prompt_embeds`, and `prompt_attention_mask`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present." + ) + + if self.image_to_video: + image_latent_filepath = image_latents_path.joinpath(pt_filename) + video_latent_filepath = video_latents_path.joinpath(pt_filename) + embeds_filepath = embeds_path.joinpath(pt_filename) + attention_mask_filepath = attention_mask_path.joinpath(pt_filename) + + if not video_latent_filepath.is_file() or not embeds_filepath.is_file() or not attention_mask_filepath.is_file(): + if self.image_to_video: + image_latent_filepath = image_latent_filepath.as_posix() + video_latent_filepath = video_latent_filepath.as_posix() + embeds_filepath = embeds_filepath.as_posix() + attention_mask_filepath = attention_mask_filepath.as_posix() + raise ValueError( + f"The file {video_latent_filepath=} or {embeds_filepath=} or {attention_mask_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`." + ) + + images = ( + torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None + ) + latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True) + embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True) + attention_masks = torch.load(attention_mask_filepath, map_location="cpu", weights_only=True) + + return images, latents, embeds, attention_masks + + +class VideoDatasetWithResizing(VideoDataset): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def _preprocess_video(self, path: Path) -> torch.Tensor: + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + + nearest_frame_bucket = min( + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + ) + if video_num_frames < nearest_frame_bucket: + # TODO: we could handle this by padding zero frames or duplicating the existing frames? + return None, None, None + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + +class VideoDatasetWithResizeAndRectangleCrop(VideoDataset): + def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.video_reshape_mode = video_reshape_mode + + def _resize_for_rectangle_crop(self, arr, image_size): + reshape_mode = self.video_reshape_mode + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + def _preprocess_video(self, path: Path) -> torch.Tensor: + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + nearest_frame_bucket = min( + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + ) + if video_num_frames < nearest_frame_bucket: + return None, None, None + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + +class BucketSampler(Sampler): + r""" + PyTorch Sampler that groups 3D data by height, width and frames. + + Args: + data_source (`VideoDataset`): + A PyTorch dataset object that is an instance of `VideoDataset`. + batch_size (`int`, defaults to `8`): + The batch size to use for training. + shuffle (`bool`, defaults to `True`): + Whether or not to shuffle the data in each batch before dispatching to dataloader. + drop_last (`bool`, defaults to `False`): + Whether or not to drop incomplete buckets of data after completely iterating over all data + in the dataset. If set to True, only batches that have `batch_size` number of entries will + be yielded. If set to False, it is guaranteed that all data in the dataset will be processed + and batches that do not have `batch_size` number of entries will also be yielded. + """ + + def __init__( + self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False + ) -> None: + self.data_source = data_source + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + + self.buckets = {resolution: [] for resolution in data_source.resolutions} + + self._raised_warning_for_drop_last = False + + def __len__(self): + if self.drop_last and not self._raised_warning_for_drop_last: + self._raised_warning_for_drop_last = True + logger.warning( + "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training." + ) + return (len(self.data_source) + self.batch_size - 1) // self.batch_size + + def __iter__(self): + for index, data in enumerate(self.data_source): + if data is not None: + video_metadata = data["video_metadata"] + f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] + + self.buckets[(f, h, w)].append(data) + if len(self.buckets[(f, h, w)]) == self.batch_size: + if self.shuffle: + random.shuffle(self.buckets[(f, h, w)]) + yield self.buckets[(f, h, w)] + del self.buckets[(f, h, w)] + self.buckets[(f, h, w)] = [] + + if self.drop_last: + return + + for fhw, bucket in list(self.buckets.items()): + if len(bucket) == 0: + continue + if self.shuffle: + random.shuffle(bucket) + yield bucket + del self.buckets[fhw] + self.buckets[fhw] = [] diff --git a/training/mochi-1/deepspeed.yaml b/training/mochi-1/deepspeed.yaml new file mode 100644 index 00000000..efbbf6fa --- /dev/null +++ b/training/mochi-1/deepspeed.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/training/mochi-1/prepare_dataset.py b/training/mochi-1/prepare_dataset.py index 777f7aee..2f08b55f 100644 --- a/training/mochi-1/prepare_dataset.py +++ b/training/mochi-1/prepare_dataset.py @@ -8,6 +8,7 @@ import queue import traceback import uuid +from contextlib import nullcontext from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Union @@ -25,7 +26,7 @@ import decord # isort:skip import sys -sys.path.append("..") +sys.path.append(".") from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip @@ -107,13 +108,13 @@ def get_args() -> Dict[str, Any]: "--width_buckets", nargs="+", type=check_width, - default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], + default=[256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536], ) parser.add_argument( "--frame_buckets", nargs="+", type=check_frames, - default=[49], + default=[84], ) parser.add_argument( "--random_flip", @@ -149,11 +150,11 @@ def get_args() -> Dict[str, Any]: required=True, help="Path to output directory where preprocessed videos/latents/embeddings will be saved.", ) - parser.add_argument("--max_num_frames", type=int, default=49, help="Maximum number of frames in output video.") + parser.add_argument("--max_num_frames", type=int, default=84, help="Maximum number of frames in output video.") parser.add_argument( - "--max_sequence_length", type=int, default=226, help="Max sequence length of prompt embeddings." + "--max_sequence_length", type=int, default=256, help="Max sequence length of prompt embeddings." ) - parser.add_argument("--target_fps", type=int, default=8, help="Frame rate of output videos.") + parser.add_argument("--target_fps", type=int, default=30, help="Frame rate of output videos.") parser.add_argument( "--save_latents_and_embeddings", action="store_true", @@ -213,6 +214,8 @@ def _get_t5_prompt_embeds( return_tensors="pt", ) text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool() else: if text_input_ids is None: raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") @@ -220,12 +223,13 @@ def _get_t5_prompt_embeds( prompt_embeds = text_encoder(text_input_ids.to(device))[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) - return prompt_embeds + return prompt_embeds, prompt_attention_mask def encode_prompt( @@ -233,13 +237,13 @@ def encode_prompt( text_encoder: T5EncoderModel, prompt: Union[str, List[str]], num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, + max_sequence_length: int = 256, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, text_input_ids=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt - prompt_embeds = _get_t5_prompt_embeds( + prompt_embeds, prompt_attention_mask = _get_t5_prompt_embeds( tokenizer, text_encoder, prompt=prompt, @@ -249,7 +253,7 @@ def encode_prompt( dtype=dtype, text_input_ids=text_input_ids, ) - return prompt_embeds + return prompt_embeds, prompt_attention_mask def compute_prompt_embeddings( @@ -261,8 +265,9 @@ def compute_prompt_embeddings( dtype: torch.dtype, requires_grad: bool = False, ): - if requires_grad: - prompt_embeds = encode_prompt( + ctx = nullcontext() if requires_grad else torch.no_grad() + with ctx: + prompt_embeds, prompt_attention_mask = encode_prompt( tokenizer, text_encoder, prompts, @@ -271,18 +276,7 @@ def compute_prompt_embeddings( device=device, dtype=dtype, ) - else: - with torch.no_grad(): - prompt_embeds = encode_prompt( - tokenizer, - text_encoder, - prompts, - num_videos_per_prompt=1, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - return prompt_embeds + return prompt_embeds, prompt_attention_mask to_pil_image = transforms.ToPILImage(mode="RGB") @@ -318,12 +312,14 @@ def serialize_artifacts( video_latents_dir: Optional[pathlib.Path] = None, prompts_dir: Optional[pathlib.Path] = None, prompt_embeds_dir: Optional[pathlib.Path] = None, + prompt_attention_mask_dir: Optional[pathlib.Path] = None, images: Optional[torch.Tensor] = None, image_latents: Optional[torch.Tensor] = None, videos: Optional[torch.Tensor] = None, video_latents: Optional[torch.Tensor] = None, prompts: Optional[List[str]] = None, prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None ) -> None: num_frames, height, width = videos.size(1), videos.size(3), videos.size(4) metadata = [{"num_frames": num_frames, "height": height, "width": width}] @@ -335,6 +331,7 @@ def serialize_artifacts( (video_latents, video_latents_dir, torch.save, "pt"), (prompts, prompts_dir, save_prompt, "txt"), (prompt_embeds, prompt_embeds_dir, torch.save, "pt"), + (prompt_attention_mask, prompt_attention_mask_dir, torch.save, "pt"), (metadata, videos_dir, save_metadata, "txt"), ] filenames = [uuid.uuid4() for _ in range(batch_size)] @@ -398,6 +395,7 @@ def main(): video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}") prompts_dir = tmp_dir.joinpath(f"prompts/{rank}") prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}") + prompt_attention_mask_dir = tmp_dir.joinpath(f"prompt_attention_mask/{rank}") images_dir.mkdir(parents=True, exist_ok=True) image_latents_dir.mkdir(parents=True, exist_ok=True) @@ -405,6 +403,7 @@ def main(): video_latents_dir.mkdir(parents=True, exist_ok=True) prompts_dir.mkdir(parents=True, exist_ok=True) prompt_embeds_dir.mkdir(parents=True, exist_ok=True) + prompt_attention_mask_dir.mkdir(parents=True, exist_ok=True) weight_dtype = DTYPE_MAPPING[args.dtype] target_fps = args.target_fps @@ -543,9 +542,10 @@ def collate_fn(data): video_latents = vae._encode(videos) video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + print(f"{video_latents.shape=}") # Encode prompts - prompt_embeds = compute_prompt_embeddings( + prompt_embeds, prompt_attention_mask = compute_prompt_embeddings( tokenizer, text_encoder, prompts, @@ -554,6 +554,7 @@ def collate_fn(data): weight_dtype, requires_grad=False, ) + print(f"{prompt_attention_mask.shape=}") if images is not None: images = (images.permute(0, 2, 1, 3, 4) + 1) / 2 @@ -570,12 +571,14 @@ def collate_fn(data): "video_latents_dir": video_latents_dir, "prompts_dir": prompts_dir, "prompt_embeds_dir": prompt_embeds_dir, + "prompt_attention_mask_dir": prompt_attention_mask_dir, "images": images, "image_latents": image_latents, "videos": videos, "video_latents": video_latents, "prompts": prompts, "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, } ) @@ -608,6 +611,7 @@ def collate_fn(data): ("video_latents", "pt"), ("prompts", "txt"), ("prompt_embeds", "pt"), + ("prompt_attention_mask", "pt"), ("videos", "txt"), ]: tmp_subfolder = tmp_dir.joinpath(subfolder) @@ -660,6 +664,7 @@ def rmdir_recursive(dir: pathlib.Path) -> None: data = { "prompt": prompt, "prompt_embed": f"prompt_embeds/{stem}.pt", + "prompt_attention_mask": f"prompt_attention_mask/{stem}.pt", "image": f"images/{stem}.png", "image_latent": f"image_latents/{stem}.pt", "video": f"videos/{stem}.mp4", diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py new file mode 100644 index 00000000..fd48f195 --- /dev/null +++ b/training/mochi-1/text_to_video_lora.py @@ -0,0 +1,948 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import math +import os +import shutil +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict + +import diffusers +import torch +import transformers +import wandb +import copy +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, +) +from diffusers import ( + AutoencoderKLMochi, + FlowMatchEulerDiscreteScheduler, + MochiPipeline, + MochiTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 +from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel + +from args import get_args # isort:skip + +import sys +sys.path.append("..") + +from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from text_encoder import compute_prompt_embeddings # isort:skip +from utils import get_gradient_norm, get_optimizer, prepare_rotary_positional_embeddings, print_memory, reset_memory # isort:skip + + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": f"video_{i}.mp4"}, + } + ) + + model_description = f""" +# Mochi-1 Preview LoRA Finetune + + + +## Model description + +This is a lora finetune of the Moch-1 preview model `{base_model}`. + +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX, Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. + +```py +TODO +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. + +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="apache-2.0", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "lora", + "mochi-1-preview", + "mochi-1-preview-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + accelerator: Accelerator, + pipe: MochiPipeline, + args: Dict[str, Any], + pipeline_args: Dict[str, Any], + epoch, + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + + pipe = pipe.to(accelerator.device) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + with torch.autocast(accelerator.device.type, torch.bfloat16, cache_enabled=False): + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=30) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30) + for i, filename in enumerate(video_filenames) + ] + } + ) + + return videos + + +class CollateFunction: + def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: + self.weight_dtype = weight_dtype + self.load_tensors = load_tensors + + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompts = [x["prompt"] for x in data[0]] + prompt_attention_mask = None + + if self.load_tensors: + prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) + prompt_attention_mask = torch.stack([x["prompt_attention_mask"] for x in data[0]]) + + videos = [x["video"] for x in data[0]] + videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) + + out_dict = { + "videos": videos, + "prompts": prompts, + } + if prompt_attention_mask is not None: + out_dict.update({"prompt_attention_mask": prompt_attention_mask}) + return out_dict + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + if not args.load_tensors: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + vae = AutoencoderKLMochi.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.to(accelerator.device, dtype=weight_dtype) + + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = MochiTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + noise_scheduler_copy = copy.deepcopy(scheduler) + + vae_config = AutoencoderKLMochi.load_config(args.pretrained_model_name_or_path, subfolder="vae") + vae_in_channels = vae_config["latent_channels"] + has_latents_mean = "latents_mean" in vae_config and vae_config["latents_mean"] is not None + has_latents_std = "latents_std" in vae_config and vae_config["latents_std"] is not None + + VAE_SCALING_FACTOR = vae_config["scaling_factor"] + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.bfloat16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + transformer.requires_grad_(False) + transformer.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + model = unwrap_model(model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + MochiPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + # This is a bit of a hack but I don't know any other solution. + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"Unexpected save model: {unwrap_model(model).__class__}") + else: + transformer_ = MochiTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = MochiPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = { + "params": transformer_lora_parameters, + "lr": args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=args.optimizer, + learning_rate=args.learning_rate, + beta1=args.beta1, + beta2=args.beta2, + beta3=args.beta3, + epsilon=args.epsilon, + weight_decay=args.weight_decay, + prodigy_decouple=args.prodigy_decouple, + prodigy_use_bias_correction=args.prodigy_use_bias_correction, + prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, + use_8bit=args.use_8bit, + use_4bit=args.use_4bit, + use_torchao=args.use_torchao, + use_deepspeed=use_deepspeed_optimizer, + use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, + offload_gradients=args.offload_gradients, + ) + + # Dataset and DataLoader + dataset_init_kwargs = { + "data_root": args.data_root, + "dataset_file": args.dataset_file, + "caption_column": args.caption_column, + "video_column": args.video_column, + "max_num_frames": args.max_num_frames, + "id_token": args.id_token, + "height_buckets": args.height_buckets, + "width_buckets": args.width_buckets, + "frame_buckets": args.frame_buckets, + "load_tensors": args.load_tensors, + "random_flip": args.random_flip, + } + if args.video_reshape_mode is None: + train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) + else: + train_dataset = VideoDatasetWithResizeAndRectangleCrop( + video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs + ) + + collate_fn = CollateFunction(weight_dtype, args.load_tensors) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + pin_memory=args.pin_memory, + prefetch_factor=4, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.use_cpu_offload_optimizer: + lr_scheduler = None + accelerator.print( + "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " + "you are training with those settings, they will be ignored." + ) + else: + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + tracker_name = args.tracker_name or "mochi-1-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + accelerator.print("===== Memory before training =====") + reset_memory(accelerator.device) + print_memory(accelerator.device) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("***** Running training *****") + accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") + accelerator.print(f" Num examples = {len(train_dataset)}") + accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") + accelerator.print(f" Num epochs = {args.num_train_epochs}") + accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") + accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + if args.load_tensors: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + + with accelerator.accumulate(models_to_accumulate): + videos = batch["videos"].to(accelerator.device, non_blocking=True) + prompts = batch["prompts"] + if args.load_tensors: + prompt_attention_mask = batch["prompt_attention_mask"] + + # Encode videos + if not args.load_tensors: + videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(videos).latent_dist + else: + latent_dist = DiagonalGaussianDistribution(videos) + + videos = latent_dist.sample() + videos = videos[:, :vae_in_channels, ...] # to respect `in_channels` for the vae + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(vae_config["latents_mean"]).view(1, vae_in_channels, 1, 1, 1).to(videos.device, videos.dtype) + ) + latents_std = ( + torch.tensor(vae_config["latents_std"]).view(1, vae_in_channels, 1, 1, 1).to(videos.device, videos.dtype) + ) + videos = (videos - latents_mean) * VAE_SCALING_FACTOR / latents_std + else: + videos = videos * VAE_SCALING_FACTOR + + videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + model_input = videos + + # Encode prompts + if not args.load_tensors: + prompt_embeds, prompt_attention_mask = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + else: + prompt_embeds = prompts.to(dtype=weight_dtype) + prompt_attention_mask = prompt_attention_mask.to(accelerator.device) + + # Sample noise that will be added to the latents + noise = torch.randn_like(model_input) + batch_size, num_channels, num_frames, height, width = model_input.shape + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=batch_size, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + # print(f"{timesteps=}") + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # Predict the noise residual + model_pred = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timesteps, + return_dict=False, + )[0] + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + loss = torch.mean( + (weighting * (model_pred.float() - target.float()) ** 2).reshape(batch_size, -1), + dim=1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + gradient_norm_before_clip = get_gradient_norm(transformer_lora_parameters) + accelerator.clip_grad_norm_(transformer_lora_parameters, args.max_grad_norm) + gradient_norm_after_clip = get_gradient_norm(transformer_lora_parameters) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + if not args.use_cpu_offload_optimizer: + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate + logs = {"loss": loss.detach().item(), "lr": last_lr} + # gradnorm + deepspeed: https://github.com/microsoft/DeepSpeed/issues/4555 + if accelerator.distributed_type != DistributedType.DEEPSPEED: + logs.update( + { + "gradient_norm_before_clip": gradient_norm_before_clip, + "gradient_norm_after_clip": gradient_norm_after_clip, + } + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + accelerator.print("===== Memory before validation =====") + print_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + pipe = MochiPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + # torch_dtype=weight_dtype, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": 4.5, + "height": args.height, + "width": args.width, + "max_sequence_length": 256, + } + + log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + + accelerator.print("===== Memory after validation =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + del pipe + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + # transformer = transformer.to(dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + MochiPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + + # Cleanup trained models to save memory + if args.load_tensors: + del transformer + else: + del transformer, text_encoder, vae + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(accelerator.device) + + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + # Final test inference + pipe = MochiPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + # torch_dtype=weight_dtype, + ) + pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora") + pipe.set_adapters(["mochi-lora"], [lora_scaling]) + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": 4.5, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + accelerator=accelerator, + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + validation_outputs.extend(video) + + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/training/mochi-1/train.sh b/training/mochi-1/train.sh new file mode 100644 index 00000000..4c4871e2 --- /dev/null +++ b/training/mochi-1/train.sh @@ -0,0 +1,56 @@ +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 + +GPU_IDS="2" + +DATA_ROOT="/home/sayak/cogvideox-factory/video-dataset-disney/mochi-1/preprocessed-dataset" + +CAPTION_COLUMN="prompts.txt" +VIDEO_COLUMN="videos.txt" + +cmd="accelerate launch --config_file deepspeed.yaml --gpu_ids $GPU_IDS text_to_video_lora.py \ + --pretrained_model_name_or_path genmo/mochi-1-preview \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --id_token BW_STYLE \ + --height_buckets 480 \ + --width_buckets 848 \ + --frame_buckets 84 \ + --load_tensors \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 1 \ + --seed 42 \ + --rank 64 \ + --lora_alpha 64 \ + --mixed_precision bf16 \ + --output_dir mochi-lora \ + --max_num_frames 84 \ + --train_batch_size 1 \ + --dataloader_num_workers 4 \ + --max_train_steps 500 \ + --checkpointing_steps 50 \ + --gradient_accumulation_steps 4 \ + --gradient_checkpointing \ + --learning_rate 0.0001 \ + --lr_scheduler constant \ + --lr_warmup_steps 0 \ + --lr_num_cycles 1 \ + --enable_slicing \ + --enable_tiling \ + --optimizer adamw \ + --beta1 0.9 \ + --beta2 0.95 \ + --beta3 0.99 \ + --weight_decay 0.001 \ + --max_grad_norm 1.0 \ + --allow_tf32 \ + --report_to wandb \ + --push_to_hub \ + --nccl_timeout 1800" + +echo "Running command: $cmd" +eval $cmd +echo -ne "-------------------- Finished executing script --------------------\n\n" \ No newline at end of file From a40ccd22379e899377b77dcd1632e0fbcd4cbeda Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Nov 2024 15:28:37 +0530 Subject: [PATCH 04/30] updates --- training/dataset.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/training/dataset.py b/training/dataset.py index 1d250dcb..411fdfc4 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -22,8 +22,8 @@ logger = get_logger(__name__) HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] -WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536] -FRAME_BUCKETS = [16, 24, 32, 48, 64, 80, 84] +WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +FRAME_BUCKETS = [16, 24, 32, 48, 64, 80] class VideoDataset(Dataset): @@ -144,17 +144,16 @@ def __getitem__(self, index: int) -> Dict[str, Any]: } else: image, video, _ = self._preprocess_video(self.video_paths[index]) - if video is not None: - return { - "prompt": self.id_token + self.prompts[index], - "image": image, - "video": video, - "video_metadata": { - "num_frames": video.shape[0], - "height": video.shape[2], - "width": video.shape[3], - }, - } + return { + "prompt": self.id_token + self.prompts[index], + "image": image, + "video": video, + "video_metadata": { + "num_frames": video.shape[0], + "height": video.shape[2], + "width": video.shape[3], + }, + } def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: if not self.data_root.exists(): @@ -276,13 +275,9 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: else: video_reader = decord.VideoReader(uri=path.as_posix()) video_num_frames = len(video_reader) - nearest_frame_bucket = min( self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) ) - if video_num_frames < nearest_frame_bucket: - # TODO: we could handle this by padding zero frames or duplicating the existing frames? - return None, None, None frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) frames = video_reader.get_batch(frame_indices) From 4c05ea6669640cb304dc37949bcfc8a3080d500c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Nov 2024 15:30:23 +0530 Subject: [PATCH 05/30] updates --- training/dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/training/dataset.py b/training/dataset.py index 411fdfc4..3176d9d5 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -278,7 +278,7 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: nearest_frame_bucket = min( self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) ) - + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) frames = video_reader.get_batch(frame_indices) frames = frames[:nearest_frame_bucket].float() @@ -342,8 +342,6 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: nearest_frame_bucket = min( self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) ) - if video_num_frames < nearest_frame_bucket: - return None, None, None frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) From a8dd94ee1c7644f803eeb34cd7db58dc43910676 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Nov 2024 15:31:11 +0530 Subject: [PATCH 06/30] updates --- training/dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/training/dataset.py b/training/dataset.py index 3176d9d5..309f337d 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -144,6 +144,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: } else: image, video, _ = self._preprocess_video(self.video_paths[index]) + return { "prompt": self.id_token + self.prompts[index], "image": image, @@ -280,6 +281,7 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: ) frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + frames = video_reader.get_batch(frame_indices) frames = frames[:nearest_frame_bucket].float() frames = frames.permute(0, 3, 1, 2).contiguous() From ac83c78a5ac1980a9b8f72b83d210319bae87114 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Nov 2024 15:32:04 +0530 Subject: [PATCH 07/30] updates --- training/dataset.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/training/dataset.py b/training/dataset.py index 309f337d..7ace5c06 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -144,7 +144,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: } else: image, video, _ = self._preprocess_video(self.video_paths[index]) - + return { "prompt": self.id_token + self.prompts[index], "image": image, @@ -281,7 +281,7 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: ) frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) - + frames = video_reader.get_batch(frame_indices) frames = frames[:nearest_frame_bucket].float() frames = frames.permute(0, 3, 1, 2).contiguous() @@ -404,17 +404,16 @@ def __len__(self): def __iter__(self): for index, data in enumerate(self.data_source): - if data is not None: - video_metadata = data["video_metadata"] - f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] - - self.buckets[(f, h, w)].append(data) - if len(self.buckets[(f, h, w)]) == self.batch_size: - if self.shuffle: - random.shuffle(self.buckets[(f, h, w)]) - yield self.buckets[(f, h, w)] - del self.buckets[(f, h, w)] - self.buckets[(f, h, w)] = [] + video_metadata = data["video_metadata"] + f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] + + self.buckets[(f, h, w)].append(data) + if len(self.buckets[(f, h, w)]) == self.batch_size: + if self.shuffle: + random.shuffle(self.buckets[(f, h, w)]) + yield self.buckets[(f, h, w)] + del self.buckets[(f, h, w)] + self.buckets[(f, h, w)] = [] if self.drop_last: return From 1409d478d53fbc41d015ba8d6b86e6bbbb355e53 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Nov 2024 10:21:28 +0530 Subject: [PATCH 08/30] updates. --- training/mochi-1/args.py | 7 +++++++ training/mochi-1/dataset.py | 15 ++++++--------- training/mochi-1/prepare_dataset.py | 9 +++++---- training/mochi-1/prepare_dataset.sh | 5 +++-- training/mochi-1/text_to_video_lora.py | 2 +- 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/training/mochi-1/args.py b/training/mochi-1/args.py index 25248e44..ee7c463d 100644 --- a/training/mochi-1/args.py +++ b/training/mochi-1/args.py @@ -151,6 +151,13 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: default=64, help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", ) + parser.add_argument( + "--target_modules", + nargs="+", + type=str, + default=["to_k", "to_q", "to_v", "to_out.0"], + help="Target modules to train LoRA for." + ) parser.add_argument( "--mixed_precision", type=str, diff --git a/training/mochi-1/dataset.py b/training/mochi-1/dataset.py index 09b9e2ce..0a95ab70 100644 --- a/training/mochi-1/dataset.py +++ b/training/mochi-1/dataset.py @@ -128,9 +128,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: # temporal compression factor is 6. Initially, the VAE encodings will have # 24 latent number of frames. So, if we were to train with a # max frame size of 84 and frame bucket of [84], we need to have the following logic. - # print(f"{video_latents.shape=}") latent_num_frames = video_latents.size(0) - # print(f"{latent_num_frames=}") num_frames = (latent_num_frames // 2) * (VAE_TEMPORAL_SCALE_FACTOR + 1) height = video_latents.size(2) * VAE_SPATIAL_SCALE_FACTOR @@ -288,11 +286,10 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: video_num_frames = len(video_reader) nearest_frame_bucket = min( - self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], + key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), + default=1, ) - if video_num_frames < nearest_frame_bucket: - # TODO: we could handle this by padding zero frames or duplicating the existing frames? - return None, None, None frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) frames = video_reader.get_batch(frame_indices) @@ -355,10 +352,10 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: video_reader = decord.VideoReader(uri=path.as_posix()) video_num_frames = len(video_reader) nearest_frame_bucket = min( - self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], + key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), + default=1, ) - if video_num_frames < nearest_frame_bucket: - return None, None, None frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) diff --git a/training/mochi-1/prepare_dataset.py b/training/mochi-1/prepare_dataset.py index 2f08b55f..d9da99a2 100644 --- a/training/mochi-1/prepare_dataset.py +++ b/training/mochi-1/prepare_dataset.py @@ -321,8 +321,11 @@ def serialize_artifacts( prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None ) -> None: - num_frames, height, width = videos.size(1), videos.size(3), videos.size(4) - metadata = [{"num_frames": num_frames, "height": height, "width": width}] + metadata = [] + for i in range(videos.size(0)): + video = videos[i:i+1] + metadata_dict = {"num_frames": video.size(1), "height": video.size(3), "width": video.size(4)} + metadata.append(metadata_dict) data_folder_mapper_list = [ (images, images_dir, lambda img, path: save_image(img[0], path), "png"), @@ -542,7 +545,6 @@ def collate_fn(data): video_latents = vae._encode(videos) video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) - print(f"{video_latents.shape=}") # Encode prompts prompt_embeds, prompt_attention_mask = compute_prompt_embeddings( @@ -554,7 +556,6 @@ def collate_fn(data): weight_dtype, requires_grad=False, ) - print(f"{prompt_attention_mask.shape=}") if images is not None: images = (images.permute(0, 2, 1, 3, 4) + 1) / 2 diff --git a/training/mochi-1/prepare_dataset.sh b/training/mochi-1/prepare_dataset.sh index 8081ac27..02786e52 100644 --- a/training/mochi-1/prepare_dataset.sh +++ b/training/mochi-1/prepare_dataset.sh @@ -11,11 +11,11 @@ VIDEO_COLUMN="videos.txt" OUTPUT_DIR="/home/sayak/cogvideox-factory/video-dataset-disney/mochi-1/preprocessed-dataset" HEIGHT_BUCKETS="480" WIDTH_BUCKETS="848" -FRAME_BUCKETS="84" +FRAME_BUCKETS="1 84" MAX_NUM_FRAMES="84" MAX_SEQUENCE_LENGTH=256 TARGET_FPS=30 -BATCH_SIZE=1 +BATCH_SIZE=4 DTYPE=fp32 # To create a folder-style dataset structure without pre-encoding videos and captions @@ -35,6 +35,7 @@ CMD_WITHOUT_PRE_ENCODING="\ --max_sequence_length $MAX_SEQUENCE_LENGTH \ --target_fps $TARGET_FPS \ --batch_size $BATCH_SIZE \ + --use_slicing \ --dtype $DTYPE " diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index fd48f195..9f423666 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -354,7 +354,7 @@ def main(args): r=args.rank, lora_alpha=args.lora_alpha, init_lora_weights=True, - target_modules=["to_k", "to_q", "to_v", "to_out.0"], + target_modules=args.target_modules, ) transformer.add_adapter(transformer_lora_config) From c80bd177fdced754b2cb9e9336c6be1ce58604a5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Nov 2024 10:27:24 +0530 Subject: [PATCH 09/30] nearest_frame_bucket. --- training/dataset.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/training/dataset.py b/training/dataset.py index 7ace5c06..bb8f8b90 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -277,7 +277,9 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: video_reader = decord.VideoReader(uri=path.as_posix()) video_num_frames = len(video_reader) nearest_frame_bucket = min( - self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], + key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), + default=1, ) frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) @@ -342,7 +344,9 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: video_reader = decord.VideoReader(uri=path.as_posix()) video_num_frames = len(video_reader) nearest_frame_bucket = min( - self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], + key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), + default=1, ) frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) From 316a705f6a8081af2de388c64fc9b4833daf3a82 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Nov 2024 10:29:16 +0530 Subject: [PATCH 10/30] revert changes to training/dataset.py --- training/dataset.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/training/dataset.py b/training/dataset.py index bb8f8b90..ec47b0b3 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -49,7 +49,7 @@ def __init__( self.caption_column = caption_column self.video_column = video_column self.max_num_frames = max_num_frames - self.id_token = id_token or "" + self.id_token = f"{id_token.strip()} " if id_token else "" self.height_buckets = height_buckets or HEIGHT_BUCKETS self.width_buckets = width_buckets or WIDTH_BUCKETS self.frame_buckets = frame_buckets or FRAME_BUCKETS @@ -277,9 +277,7 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: video_reader = decord.VideoReader(uri=path.as_posix()) video_num_frames = len(video_reader) nearest_frame_bucket = min( - [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], - key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), - default=1, + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) ) frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) @@ -344,9 +342,7 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: video_reader = decord.VideoReader(uri=path.as_posix()) video_num_frames = len(video_reader) nearest_frame_bucket = min( - [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], - key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), - default=1, + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) ) frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) @@ -429,4 +425,4 @@ def __iter__(self): random.shuffle(bucket) yield bucket del self.buckets[fhw] - self.buckets[fhw] = [] + self.buckets[fhw] = [] \ No newline at end of file From 440dc257c85635af8ee93cdfbdd7f294ceeca317 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Nov 2024 10:51:00 +0530 Subject: [PATCH 11/30] better reuse. --- training/mochi-1/dataset.py | 232 ++----------------------- training/mochi-1/text_to_video_lora.py | 4 +- 2 files changed, 19 insertions(+), 217 deletions(-) diff --git a/training/mochi-1/dataset.py b/training/mochi-1/dataset.py index 0a95ab70..fa109c41 100644 --- a/training/mochi-1/dataset.py +++ b/training/mochi-1/dataset.py @@ -1,13 +1,10 @@ -import random from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Tuple import numpy as np -import pandas as pd import torch import torchvision.transforms as TT from accelerate.logging import get_logger -from torch.utils.data import Dataset, Sampler from torchvision import transforms from torchvision.transforms import InterpolationMode from torchvision.transforms.functional import resize @@ -19,6 +16,12 @@ decord.bridge.set_bridge("torch") +import sys +sys.path.append("..") + +from dataset import VideoDataset as VDS +from dataset import BucketSampler + logger = get_logger(__name__) # TODO (sayakpaul): probably not all buckets are needed for Mochi-1? @@ -29,84 +32,11 @@ VAE_SPATIAL_SCALE_FACTOR = 8 VAE_TEMPORAL_SCALE_FACTOR = 6 -class VideoDataset(Dataset): - def __init__( - self, - data_root: str, - dataset_file: Optional[str] = None, - caption_column: str = "text", - video_column: str = "video", - max_num_frames: int = 49, - id_token: Optional[str] = None, - height_buckets: List[int] = None, - width_buckets: List[int] = None, - frame_buckets: List[int] = None, - load_tensors: bool = False, - random_flip: Optional[float] = None, - image_to_video: bool = False, - ) -> None: - super().__init__() - - self.data_root = Path(data_root) - self.dataset_file = dataset_file - self.caption_column = caption_column - self.video_column = video_column - self.max_num_frames = max_num_frames - self.id_token = id_token or "" - self.height_buckets = height_buckets or HEIGHT_BUCKETS - self.width_buckets = width_buckets or WIDTH_BUCKETS - self.frame_buckets = frame_buckets or FRAME_BUCKETS - self.load_tensors = load_tensors - self.random_flip = random_flip - self.image_to_video = image_to_video - - self.resolutions = [ - (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets - ] - - # Two methods of loading data are supported. - # - Using a CSV: caption_column and video_column must be some column in the CSV. One could - # make use of other columns too, such as a motion score or aesthetic score, by modifying the - # logic in CSV processing. - # - Using two files containing line-separate captions and relative paths to videos. - # For a more detailed explanation about preparing dataset format, checkout the README. - if dataset_file is None: - ( - self.prompts, - self.video_paths, - ) = self._load_dataset_from_local_path() - else: - ( - self.prompts, - self.video_paths, - ) = self._load_dataset_from_csv() - - if len(self.video_paths) != len(self.prompts): - raise ValueError( - f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." - ) - - self.video_transforms = transforms.Compose( - [ - transforms.RandomHorizontalFlip(random_flip) - if random_flip - else transforms.Lambda(self.identity_transform), - transforms.Lambda(self.scale_transform), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), - ] - ) - - @staticmethod - def identity_transform(x): - return x - - @staticmethod - def scale_transform(x): - return x / 255.0 - - def __len__(self) -> int: - return len(self.video_paths) +class VideoDataset(VDS): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Overriding this because we calculate `num_frames` differently. def __getitem__(self, index: int) -> Dict[str, Any]: if isinstance(index, list): # Here, index is actually a list of data objects that we need to return. @@ -159,74 +89,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: }, } - def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: - if not self.data_root.exists(): - raise ValueError("Root folder for videos does not exist") - - prompt_path = self.data_root.joinpath(self.caption_column) - video_path = self.data_root.joinpath(self.video_column) - - if not prompt_path.exists() or not prompt_path.is_file(): - raise ValueError( - "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts." - ) - if not video_path.exists() or not video_path.is_file(): - raise ValueError( - "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory." - ) - - with open(prompt_path, "r", encoding="utf-8") as file: - prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] - with open(video_path, "r", encoding="utf-8") as file: - video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] - - if not self.load_tensors and any(not path.is_file() for path in video_paths): - raise ValueError( - f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." - ) - - return prompts, video_paths - - def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]: - df = pd.read_csv(self.dataset_file) - prompts = df[self.caption_column].tolist() - video_paths = df[self.video_column].tolist() - video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths] - - if any(not path.is_file() for path in video_paths): - raise ValueError( - f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." - ) - - return prompts, video_paths - - def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - r""" - Loads a single video, or latent and prompt embedding, based on initialization parameters. - - If returning a video, returns a [F, C, H, W] video tensor, and None for the prompt embedding. Here, - F, C, H and W are the frames, channels, height and width of the input video. - - If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D]. - F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length - and embedding dimension of prompt embeddings. - """ - if self.load_tensors: - return self._load_preprocessed_latents_and_embeds(path) - else: - video_reader = decord.VideoReader(uri=path.as_posix()) - video_num_frames = len(video_reader) - - indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames)) - frames = video_reader.get_batch(indices) - frames = frames[: self.max_num_frames].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) - - image = frames[:1].clone() if self.image_to_video else None - - return image, frames, None - + # Overriding this because we need `prompt_attention_mask`. def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]: filename_without_ext = path.name.split(".")[0] pt_filename = f"{filename_without_ext}.pt" @@ -274,6 +137,10 @@ def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tenso return images, latents, embeds, attention_masks +# We need the `VideoDatasetWithResizing` and `VideoDatasetWithResizeAndRectangleCrop` classes to subclass from +# the new `VideoDataset` class defined in this file. And also because of the changes in +# `_preprocess_video()` (how we handle `nearest_frame_bucket`). + class VideoDatasetWithResizing(VideoDataset): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -373,69 +240,4 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: def _find_nearest_resolution(self, height, width): nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) - return nearest_res[1], nearest_res[2] - - -class BucketSampler(Sampler): - r""" - PyTorch Sampler that groups 3D data by height, width and frames. - - Args: - data_source (`VideoDataset`): - A PyTorch dataset object that is an instance of `VideoDataset`. - batch_size (`int`, defaults to `8`): - The batch size to use for training. - shuffle (`bool`, defaults to `True`): - Whether or not to shuffle the data in each batch before dispatching to dataloader. - drop_last (`bool`, defaults to `False`): - Whether or not to drop incomplete buckets of data after completely iterating over all data - in the dataset. If set to True, only batches that have `batch_size` number of entries will - be yielded. If set to False, it is guaranteed that all data in the dataset will be processed - and batches that do not have `batch_size` number of entries will also be yielded. - """ - - def __init__( - self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False - ) -> None: - self.data_source = data_source - self.batch_size = batch_size - self.shuffle = shuffle - self.drop_last = drop_last - - self.buckets = {resolution: [] for resolution in data_source.resolutions} - - self._raised_warning_for_drop_last = False - - def __len__(self): - if self.drop_last and not self._raised_warning_for_drop_last: - self._raised_warning_for_drop_last = True - logger.warning( - "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training." - ) - return (len(self.data_source) + self.batch_size - 1) // self.batch_size - - def __iter__(self): - for index, data in enumerate(self.data_source): - if data is not None: - video_metadata = data["video_metadata"] - f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] - - self.buckets[(f, h, w)].append(data) - if len(self.buckets[(f, h, w)]) == self.batch_size: - if self.shuffle: - random.shuffle(self.buckets[(f, h, w)]) - yield self.buckets[(f, h, w)] - del self.buckets[(f, h, w)] - self.buckets[(f, h, w)] = [] - - if self.drop_last: - return - - for fhw, bucket in list(self.buckets.items()): - if len(bucket) == 0: - continue - if self.shuffle: - random.shuffle(bucket) - yield bucket - del self.buckets[fhw] - self.buckets[fhw] = [] + return nearest_res[1], nearest_res[2] \ No newline at end of file diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index 9f423666..d0d17369 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -56,11 +56,11 @@ from args import get_args # isort:skip import sys -sys.path.append("..") +sys.path.append(".") from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip from text_encoder import compute_prompt_embeddings # isort:skip -from utils import get_gradient_norm, get_optimizer, prepare_rotary_positional_embeddings, print_memory, reset_memory # isort:skip +from utils import get_gradient_norm, get_optimizer, print_memory, reset_memory # isort:skip logger = get_logger(__name__) From a01592b25db4ec60e7af540d19e8836431411109 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Nov 2024 11:32:24 +0530 Subject: [PATCH 12/30] betterments. --- training/mochi-1/dataset.py | 243 ------------------------- training/mochi-1/prepare_dataset.py | 8 +- training/mochi-1/text_to_video_lora.py | 5 +- 3 files changed, 8 insertions(+), 248 deletions(-) delete mode 100644 training/mochi-1/dataset.py diff --git a/training/mochi-1/dataset.py b/training/mochi-1/dataset.py deleted file mode 100644 index fa109c41..00000000 --- a/training/mochi-1/dataset.py +++ /dev/null @@ -1,243 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Tuple - -import numpy as np -import torch -import torchvision.transforms as TT -from accelerate.logging import get_logger -from torchvision import transforms -from torchvision.transforms import InterpolationMode -from torchvision.transforms.functional import resize - - -# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error -# Very few bug reports but it happens. Look in decord Github issues for more relevant information. -import decord # isort:skip - -decord.bridge.set_bridge("torch") - -import sys -sys.path.append("..") - -from dataset import VideoDataset as VDS -from dataset import BucketSampler - -logger = get_logger(__name__) - -# TODO (sayakpaul): probably not all buckets are needed for Mochi-1? -HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] -WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536] -FRAME_BUCKETS = [16, 24, 32, 48, 64, 80, 84] - -VAE_SPATIAL_SCALE_FACTOR = 8 -VAE_TEMPORAL_SCALE_FACTOR = 6 - -class VideoDataset(VDS): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - # Overriding this because we calculate `num_frames` differently. - def __getitem__(self, index: int) -> Dict[str, Any]: - if isinstance(index, list): - # Here, index is actually a list of data objects that we need to return. - # The BucketSampler should ideally return indices. But, in the sampler, we'd like - # to have information about num_frames, height and width. Since this is not stored - # as metadata, we need to read the video to get this information. You could read this - # information without loading the full video in memory, but we do it anyway. In order - # to not load the video twice (once to get the metadata, and once to return the loaded video - # based on sampled indices), we cache it in the BucketSampler. When the sampler is - # to yield, we yield the cache data instead of indices. So, this special check ensures - # that data is not loaded a second time. PRs are welcome for improvements. - return index - - if self.load_tensors: - image_latents, video_latents, prompt_embeds, prompt_attention_mask = self._preprocess_video(self.video_paths[index]) - - # This is hardcoded for now. - # Output of the VAE encoding is 2 * output_channels and then it's - # temporal compression factor is 6. Initially, the VAE encodings will have - # 24 latent number of frames. So, if we were to train with a - # max frame size of 84 and frame bucket of [84], we need to have the following logic. - latent_num_frames = video_latents.size(0) - num_frames = (latent_num_frames // 2) * (VAE_TEMPORAL_SCALE_FACTOR + 1) - - height = video_latents.size(2) * VAE_SPATIAL_SCALE_FACTOR - width = video_latents.size(3) * VAE_SPATIAL_SCALE_FACTOR - - return { - "prompt": prompt_embeds, - "prompt_attention_mask": prompt_attention_mask, - "image": image_latents, - "video": video_latents, - "video_metadata": { - "num_frames": num_frames, - "height": height, - "width": width, - }, - } - else: - image, video, _ = self._preprocess_video(self.video_paths[index]) - if video is not None: - return { - "prompt": self.id_token + self.prompts[index], - "image": image, - "video": video, - "video_metadata": { - "num_frames": video.shape[0], - "height": video.shape[2], - "width": video.shape[3], - }, - } - - # Overriding this because we need `prompt_attention_mask`. - def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]: - filename_without_ext = path.name.split(".")[0] - pt_filename = f"{filename_without_ext}.pt" - - # The current path is something like: /a/b/c/d/videos/00001.mp4 - # We need to reach: /a/b/c/d/video_latents/00001.pt - image_latents_path = path.parent.parent.joinpath("image_latents") - video_latents_path = path.parent.parent.joinpath("video_latents") - embeds_path = path.parent.parent.joinpath("prompt_embeds") - attention_mask_path = path.parent.parent.joinpath("prompt_attention_mask") - - if ( - not video_latents_path.exists() - or not embeds_path.exists() - or not attention_mask_path.exists() - or (self.image_to_video and not image_latents_path.exists()) - ): - raise ValueError( - f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains three folders named `video_latents`, `prompt_embeds`, and `prompt_attention_mask`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present." - ) - - if self.image_to_video: - image_latent_filepath = image_latents_path.joinpath(pt_filename) - video_latent_filepath = video_latents_path.joinpath(pt_filename) - embeds_filepath = embeds_path.joinpath(pt_filename) - attention_mask_filepath = attention_mask_path.joinpath(pt_filename) - - if not video_latent_filepath.is_file() or not embeds_filepath.is_file() or not attention_mask_filepath.is_file(): - if self.image_to_video: - image_latent_filepath = image_latent_filepath.as_posix() - video_latent_filepath = video_latent_filepath.as_posix() - embeds_filepath = embeds_filepath.as_posix() - attention_mask_filepath = attention_mask_filepath.as_posix() - raise ValueError( - f"The file {video_latent_filepath=} or {embeds_filepath=} or {attention_mask_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`." - ) - - images = ( - torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None - ) - latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True) - embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True) - attention_masks = torch.load(attention_mask_filepath, map_location="cpu", weights_only=True) - - return images, latents, embeds, attention_masks - - -# We need the `VideoDatasetWithResizing` and `VideoDatasetWithResizeAndRectangleCrop` classes to subclass from -# the new `VideoDataset` class defined in this file. And also because of the changes in -# `_preprocess_video()` (how we handle `nearest_frame_bucket`). - -class VideoDatasetWithResizing(VideoDataset): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def _preprocess_video(self, path: Path) -> torch.Tensor: - if self.load_tensors: - return self._load_preprocessed_latents_and_embeds(path) - else: - video_reader = decord.VideoReader(uri=path.as_posix()) - video_num_frames = len(video_reader) - - nearest_frame_bucket = min( - [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], - key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), - default=1, - ) - - frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) - frames = video_reader.get_batch(frame_indices) - frames = frames[:nearest_frame_bucket].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - - nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) - frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) - frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) - - image = frames[:1].clone() if self.image_to_video else None - - return image, frames, None - - def _find_nearest_resolution(self, height, width): - nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) - return nearest_res[1], nearest_res[2] - - -class VideoDatasetWithResizeAndRectangleCrop(VideoDataset): - def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.video_reshape_mode = video_reshape_mode - - def _resize_for_rectangle_crop(self, arr, image_size): - reshape_mode = self.video_reshape_mode - if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: - arr = resize( - arr, - size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], - interpolation=InterpolationMode.BICUBIC, - ) - else: - arr = resize( - arr, - size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], - interpolation=InterpolationMode.BICUBIC, - ) - - h, w = arr.shape[2], arr.shape[3] - arr = arr.squeeze(0) - - delta_h = h - image_size[0] - delta_w = w - image_size[1] - - if reshape_mode == "random" or reshape_mode == "none": - top = np.random.randint(0, delta_h + 1) - left = np.random.randint(0, delta_w + 1) - elif reshape_mode == "center": - top, left = delta_h // 2, delta_w // 2 - else: - raise NotImplementedError - arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) - return arr - - def _preprocess_video(self, path: Path) -> torch.Tensor: - if self.load_tensors: - return self._load_preprocessed_latents_and_embeds(path) - else: - video_reader = decord.VideoReader(uri=path.as_posix()) - video_num_frames = len(video_reader) - nearest_frame_bucket = min( - [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], - key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), - default=1, - ) - - frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) - - frames = video_reader.get_batch(frame_indices) - frames = frames[:nearest_frame_bucket].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - - nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) - frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) - frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) - - image = frames[:1].clone() if self.image_to_video else None - - return image, frames, None - - def _find_nearest_resolution(self, height, width): - nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) - return nearest_res[1], nearest_res[2] \ No newline at end of file diff --git a/training/mochi-1/prepare_dataset.py b/training/mochi-1/prepare_dataset.py index d9da99a2..0826c279 100644 --- a/training/mochi-1/prepare_dataset.py +++ b/training/mochi-1/prepare_dataset.py @@ -21,16 +21,18 @@ from torchvision import transforms from tqdm import tqdm from transformers import T5EncoderModel, T5Tokenizer +from dataset_mochi import VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop import decord # isort:skip +decord.bridge.set_bridge("torch") + import sys -sys.path.append(".") -from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +sys.path.append("..") +from dataset import BucketSampler -decord.bridge.set_bridge("torch") logger = get_logger(__name__) diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index d0d17369..2b9ee427 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -54,11 +54,12 @@ from transformers import AutoTokenizer, T5EncoderModel from args import get_args # isort:skip +from dataset_mochi import VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip import sys -sys.path.append(".") +sys.path.append("..") -from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from dataset import BucketSampler # isort:skip from text_encoder import compute_prompt_embeddings # isort:skip from utils import get_gradient_norm, get_optimizer, print_memory, reset_memory # isort:skip From cb16cbae5fffe1f966255b0830742d448c7c07c8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Nov 2024 11:35:00 +0530 Subject: [PATCH 13/30] dataset_mochi --- dataset_mochi.py | 242 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 dataset_mochi.py diff --git a/dataset_mochi.py b/dataset_mochi.py new file mode 100644 index 00000000..c40a7473 --- /dev/null +++ b/dataset_mochi.py @@ -0,0 +1,242 @@ +from pathlib import Path +from typing import Any, Dict, Tuple + +import numpy as np +import torch +import torchvision.transforms as TT +from accelerate.logging import get_logger +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import resize + + +# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error +# Very few bug reports but it happens. Look in decord Github issues for more relevant information. +import decord # isort:skip + +decord.bridge.set_bridge("torch") + +import sys +sys.path.append("..") + +from dataset import VideoDataset as VDS + +logger = get_logger(__name__) + +# TODO (sayakpaul): probably not all buckets are needed for Mochi-1? +HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] +WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536] +FRAME_BUCKETS = [16, 24, 32, 48, 64, 80, 84] + +VAE_SPATIAL_SCALE_FACTOR = 8 +VAE_TEMPORAL_SCALE_FACTOR = 6 + +class VideoDataset(VDS): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # Overriding this because we calculate `num_frames` differently. + def __getitem__(self, index: int) -> Dict[str, Any]: + if isinstance(index, list): + # Here, index is actually a list of data objects that we need to return. + # The BucketSampler should ideally return indices. But, in the sampler, we'd like + # to have information about num_frames, height and width. Since this is not stored + # as metadata, we need to read the video to get this information. You could read this + # information without loading the full video in memory, but we do it anyway. In order + # to not load the video twice (once to get the metadata, and once to return the loaded video + # based on sampled indices), we cache it in the BucketSampler. When the sampler is + # to yield, we yield the cache data instead of indices. So, this special check ensures + # that data is not loaded a second time. PRs are welcome for improvements. + return index + + if self.load_tensors: + image_latents, video_latents, prompt_embeds, prompt_attention_mask = self._preprocess_video(self.video_paths[index]) + + # This is hardcoded for now. + # Output of the VAE encoding is 2 * output_channels and then it's + # temporal compression factor is 6. Initially, the VAE encodings will have + # 24 latent number of frames. So, if we were to train with a + # max frame size of 84 and frame bucket of [84], we need to have the following logic. + latent_num_frames = video_latents.size(0) + num_frames = (latent_num_frames // 2) * (VAE_TEMPORAL_SCALE_FACTOR + 1) + + height = video_latents.size(2) * VAE_SPATIAL_SCALE_FACTOR + width = video_latents.size(3) * VAE_SPATIAL_SCALE_FACTOR + + return { + "prompt": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "image": image_latents, + "video": video_latents, + "video_metadata": { + "num_frames": num_frames, + "height": height, + "width": width, + }, + } + else: + image, video, _ = self._preprocess_video(self.video_paths[index]) + if video is not None: + return { + "prompt": self.id_token + self.prompts[index], + "image": image, + "video": video, + "video_metadata": { + "num_frames": video.shape[0], + "height": video.shape[2], + "width": video.shape[3], + }, + } + + # Overriding this because we need `prompt_attention_mask`. + def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]: + filename_without_ext = path.name.split(".")[0] + pt_filename = f"{filename_without_ext}.pt" + + # The current path is something like: /a/b/c/d/videos/00001.mp4 + # We need to reach: /a/b/c/d/video_latents/00001.pt + image_latents_path = path.parent.parent.joinpath("image_latents") + video_latents_path = path.parent.parent.joinpath("video_latents") + embeds_path = path.parent.parent.joinpath("prompt_embeds") + attention_mask_path = path.parent.parent.joinpath("prompt_attention_mask") + + if ( + not video_latents_path.exists() + or not embeds_path.exists() + or not attention_mask_path.exists() + or (self.image_to_video and not image_latents_path.exists()) + ): + raise ValueError( + f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains three folders named `video_latents`, `prompt_embeds`, and `prompt_attention_mask`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present." + ) + + if self.image_to_video: + image_latent_filepath = image_latents_path.joinpath(pt_filename) + video_latent_filepath = video_latents_path.joinpath(pt_filename) + embeds_filepath = embeds_path.joinpath(pt_filename) + attention_mask_filepath = attention_mask_path.joinpath(pt_filename) + + if not video_latent_filepath.is_file() or not embeds_filepath.is_file() or not attention_mask_filepath.is_file(): + if self.image_to_video: + image_latent_filepath = image_latent_filepath.as_posix() + video_latent_filepath = video_latent_filepath.as_posix() + embeds_filepath = embeds_filepath.as_posix() + attention_mask_filepath = attention_mask_filepath.as_posix() + raise ValueError( + f"The file {video_latent_filepath=} or {embeds_filepath=} or {attention_mask_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`." + ) + + images = ( + torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None + ) + latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True) + embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True) + attention_masks = torch.load(attention_mask_filepath, map_location="cpu", weights_only=True) + + return images, latents, embeds, attention_masks + + +# We need the `VideoDatasetWithResizing` and `VideoDatasetWithResizeAndRectangleCrop` classes to subclass from +# the new `VideoDataset` class defined in this file. And also because of the changes in +# `_preprocess_video()` (how we handle `nearest_frame_bucket`). + +class VideoDatasetWithResizing(VideoDataset): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def _preprocess_video(self, path: Path) -> torch.Tensor: + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + + nearest_frame_bucket = min( + [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], + key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), + default=1, + ) + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + +class VideoDatasetWithResizeAndRectangleCrop(VideoDataset): + def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.video_reshape_mode = video_reshape_mode + + def _resize_for_rectangle_crop(self, arr, image_size): + reshape_mode = self.video_reshape_mode + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + def _preprocess_video(self, path: Path) -> torch.Tensor: + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + nearest_frame_bucket = min( + [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], + key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), + default=1, + ) + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + image = frames[:1].clone() if self.image_to_video else None + + return image, frames, None + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] \ No newline at end of file From 5208c59394a654b7b04c49e300a60297a48ecada Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Nov 2024 11:36:56 +0530 Subject: [PATCH 14/30] fix --- dataset_mochi.py => training/mochi-1/dataset_mochi.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename dataset_mochi.py => training/mochi-1/dataset_mochi.py (100%) diff --git a/dataset_mochi.py b/training/mochi-1/dataset_mochi.py similarity index 100% rename from dataset_mochi.py rename to training/mochi-1/dataset_mochi.py From 9eea656fa07c2924b76cead59916806f4f2ebfb4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Nov 2024 18:38:05 +0530 Subject: [PATCH 15/30] fixes --- training/mochi-1/text_to_video_lora.py | 6 +++--- training/mochi-1/train.sh | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index 2b9ee427..239d215c 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -309,6 +309,7 @@ def main(args): variant=args.variant, ) scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + # noise_scheduler_copy = FlowMatchEulerDiscreteScheduler.from_config(scheduler.config, invert_sigmas=False) noise_scheduler_copy = copy.deepcopy(scheduler) vae_config = AutoencoderKLMochi.load_config(args.pretrained_model_name_or_path, subfolder="vae") @@ -633,7 +634,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + # notice the reverse. + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps][::-1] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: @@ -660,7 +662,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): latent_dist = DiagonalGaussianDistribution(videos) videos = latent_dist.sample() - videos = videos[:, :vae_in_channels, ...] # to respect `in_channels` for the vae if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(vae_config["latents_mean"]).view(1, vae_in_channels, 1, 1, 1).to(videos.device, videos.dtype) @@ -705,7 +706,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) - # print(f"{timesteps=}") # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 diff --git a/training/mochi-1/train.sh b/training/mochi-1/train.sh index 4c4871e2..ad71ed2a 100644 --- a/training/mochi-1/train.sh +++ b/training/mochi-1/train.sh @@ -26,7 +26,7 @@ cmd="accelerate launch --config_file deepspeed.yaml --gpu_ids $GPU_IDS text_to_v --rank 64 \ --lora_alpha 64 \ --mixed_precision bf16 \ - --output_dir mochi-lora \ + --output_dir /raid/.cache/huggingface/sayak/mochi-lora/ \ --max_num_frames 84 \ --train_batch_size 1 \ --dataloader_num_workers 4 \ From 7e203db9aad5b54422909c72dd262ce5f91448d3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 20 Nov 2024 10:39:52 +0530 Subject: [PATCH 16/30] updates --- training/mochi-1/args.py | 2 +- training/mochi-1/text_to_video_lora.py | 8 ++++++-- training/mochi-1/train.sh | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/training/mochi-1/args.py b/training/mochi-1/args.py index ee7c463d..1d41f5ae 100644 --- a/training/mochi-1/args.py +++ b/training/mochi-1/args.py @@ -369,7 +369,7 @@ def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--beta2", type=float, - default=0.95, + default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers.", ) parser.add_argument( diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index 239d215c..9fc4ac61 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -483,6 +483,7 @@ def load_model_hook(models, input_dir): use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, offload_gradients=args.offload_gradients, ) + accelerator.print(f"Using {optimizer.__class__.__name__} optimizer.") # Dataset and DataLoader dataset_init_kwargs = { @@ -635,9 +636,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) timesteps = timesteps.to(accelerator.device) # notice the reverse. - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps][::-1] + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() + if "invert_sigmas" in noise_scheduler_copy.config and noise_scheduler_copy.config.invert_sigmas: + # https://github.com/huggingface/diffusers/blob/99c0483b67427de467f11aa35d54678fd36a7ea2/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L209 + sigma = 1.0 - sigma while len(sigma.shape) < n_dim: sigma = sigma.unsqueeze(-1) return sigma @@ -938,7 +942,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): repo_id=repo_id, folder_path=args.output_dir, commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], + ignore_patterns=["step_*", "epoch_*", "*.bin"], ) accelerator.end_training() diff --git a/training/mochi-1/train.sh b/training/mochi-1/train.sh index ad71ed2a..e9851cbd 100644 --- a/training/mochi-1/train.sh +++ b/training/mochi-1/train.sh @@ -34,13 +34,13 @@ cmd="accelerate launch --config_file deepspeed.yaml --gpu_ids $GPU_IDS text_to_v --checkpointing_steps 50 \ --gradient_accumulation_steps 4 \ --gradient_checkpointing \ - --learning_rate 0.0001 \ + --learning_rate 1e-5 \ --lr_scheduler constant \ --lr_warmup_steps 0 \ --lr_num_cycles 1 \ --enable_slicing \ --enable_tiling \ - --optimizer adamw \ + --optimizer adamw --use_8bit \ --beta1 0.9 \ --beta2 0.95 \ --beta3 0.99 \ From 9a15eaee417e0b339c1663745fc2d165fe0ffe0d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 20 Nov 2024 19:31:13 +0530 Subject: [PATCH 17/30] updates --- training/mochi-1/dataset_mochi.py | 1 - training/mochi-1/text_to_video_lora.py | 60 +++++++++++++------------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/training/mochi-1/dataset_mochi.py b/training/mochi-1/dataset_mochi.py index c40a7473..2e58f3a0 100644 --- a/training/mochi-1/dataset_mochi.py +++ b/training/mochi-1/dataset_mochi.py @@ -5,7 +5,6 @@ import torch import torchvision.transforms as TT from accelerate.logging import get_logger -from torchvision import transforms from torchvision.transforms import InterpolationMode from torchvision.transforms.functional import resize diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index 9fc4ac61..4ffe908a 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -76,7 +76,7 @@ def save_model_card( fps=8, ): widget_dict = [] - if videos is not None: + if videos is not None and len(videos) > 0: for i, video in enumerate(videos): export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) widget_dict.append( @@ -876,34 +876,34 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): torch.cuda.empty_cache() torch.cuda.synchronize(accelerator.device) - accelerator.print("===== Memory before testing =====") - print_memory(accelerator.device) - reset_memory(accelerator.device) - # Final test inference - pipe = MochiPipeline.from_pretrained( - args.pretrained_model_name_or_path, - revision=args.revision, - variant=args.variant, - # torch_dtype=weight_dtype, - ) - pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config) + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + accelerator.print("===== Memory before testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + + pipe = MochiPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + # torch_dtype=weight_dtype, + ) + pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config) - if args.enable_slicing: - pipe.vae.enable_slicing() - if args.enable_tiling: - pipe.vae.enable_tiling() - if args.enable_model_cpu_offload: - pipe.enable_model_cpu_offload() + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() - # Load LoRA weights - lora_scaling = args.lora_alpha / args.rank - pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora") - pipe.set_adapters(["mochi-lora"], [lora_scaling]) + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora") + pipe.set_adapters(["mochi-lora"], [lora_scaling]) - # Run inference - validation_outputs = [] - if args.validation_prompt and args.num_validation_videos > 0: + # Run inference validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) for validation_prompt in validation_prompts: pipeline_args = { @@ -924,10 +924,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) validation_outputs.extend(video) - accelerator.print("===== Memory after testing =====") - print_memory(accelerator.device) - reset_memory(accelerator.device) - torch.cuda.synchronize(accelerator.device) + accelerator.print("===== Memory after testing =====") + print_memory(accelerator.device) + reset_memory(accelerator.device) + torch.cuda.synchronize(accelerator.device) if args.push_to_hub: save_model_card( @@ -942,7 +942,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): repo_id=repo_id, folder_path=args.output_dir, commit_message="End of training", - ignore_patterns=["step_*", "epoch_*", "*.bin"], + ignore_patterns=["step_*", "epoch_*", "*.bin", "*.pt"], ) accelerator.end_training() From 2dbddd566e49c600d701dd405d30aad21af07147 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 22 Nov 2024 23:34:17 +0530 Subject: [PATCH 18/30] updates --- training/mochi-1/dataset_mochi.py | 133 +++++++++++++------------ training/mochi-1/prepare_dataset.py | 13 +-- training/mochi-1/prepare_dataset.sh | 4 +- training/mochi-1/text_to_video_lora.py | 30 ++---- training/mochi-1/train.sh | 10 +- 5 files changed, 87 insertions(+), 103 deletions(-) diff --git a/training/mochi-1/dataset_mochi.py b/training/mochi-1/dataset_mochi.py index 2e58f3a0..d2af84b9 100644 --- a/training/mochi-1/dataset_mochi.py +++ b/training/mochi-1/dataset_mochi.py @@ -3,10 +3,11 @@ import numpy as np import torch -import torchvision.transforms as TT +from torchvision import transforms from accelerate.logging import get_logger from torchvision.transforms import InterpolationMode from torchvision.transforms.functional import resize +import torch.nn as nn # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error @@ -25,7 +26,7 @@ # TODO (sayakpaul): probably not all buckets are needed for Mochi-1? HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536] -FRAME_BUCKETS = [16, 24, 32, 48, 64, 80, 84] +FRAME_BUCKETS = [16, 24, 32, 48, 64, 80, 85] VAE_SPATIAL_SCALE_FACTOR = 8 VAE_TEMPORAL_SCALE_FACTOR = 6 @@ -34,6 +35,19 @@ class VideoDataset(VDS): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + random_flip = kwargs.get("random_flip", None) + self.video_transforms = transforms.Compose( + [ + transforms.RandomHorizontalFlip([(random_flip)]) + if random_flip + else transforms.Lambda(lambda x: x), + transforms.Lambda(self.scale_transform), + ] + ) + + def scale_transform(self, x): + return x / 127.5 - 1.0 + # Overriding this because we calculate `num_frames` differently. def __getitem__(self, index: int) -> Dict[str, Any]: if isinstance(index, list): @@ -55,9 +69,9 @@ def __getitem__(self, index: int) -> Dict[str, Any]: # Output of the VAE encoding is 2 * output_channels and then it's # temporal compression factor is 6. Initially, the VAE encodings will have # 24 latent number of frames. So, if we were to train with a - # max frame size of 84 and frame bucket of [84], we need to have the following logic. + # max frame size of 85 and frame bucket of [85], we need to have the following logic. latent_num_frames = video_latents.size(0) - num_frames = (latent_num_frames // 2) * (VAE_TEMPORAL_SCALE_FACTOR + 1) + num_frames = ((latent_num_frames // 2) * (VAE_TEMPORAL_SCALE_FACTOR + 1) + 1) height = video_latents.size(2) * VAE_SPATIAL_SCALE_FACTOR width = video_latents.size(3) * VAE_SPATIAL_SCALE_FACTOR @@ -135,13 +149,13 @@ def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tenso return images, latents, embeds, attention_masks -# We need the `VideoDatasetWithResizing` and `VideoDatasetWithResizeAndRectangleCrop` classes to subclass from -# the new `VideoDataset` class defined in this file. And also because of the changes in -# `_preprocess_video()` (how we handle `nearest_frame_bucket`). -class VideoDatasetWithResizing(VideoDataset): - def __init__(self, *args, **kwargs) -> None: +class VideoDatasetWithFlexibleResize(VideoDataset): + def __init__(self, video_reshape_mode: str = None, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + if video_reshape_mode: + assert video_reshape_mode in ["center", "random"] + self.video_reshape_mode = video_reshape_mode def _preprocess_video(self, path: Path) -> torch.Tensor: if self.load_tensors: @@ -151,36 +165,56 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: video_num_frames = len(video_reader) nearest_frame_bucket = min( - [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], - key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), - default=1, + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + ) + frame_indices = list( + range( + 0, + video_num_frames, + 1 if video_num_frames < nearest_frame_bucket else video_num_frames // nearest_frame_bucket + ) ) - - frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) frames = video_reader.get_batch(frame_indices) - frames = frames[:nearest_frame_bucket].float() + + # Pad or truncate frames to match the bucket size + if video_num_frames < nearest_frame_bucket: + pad_size = nearest_frame_bucket - video_num_frames + frames = nn.functional.pad(frames, (0, 0, 0, 0, 0, 0, 0, pad_size)) + frames = frames.float() + else: + frames = frames[:nearest_frame_bucket].float() frames = frames.permute(0, 3, 1, 2).contiguous() + # Find nearest resolution and apply resizing nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) - frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) - frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + if self.video_reshape_mode in {"center", "random"}: + frames = self._resize_for_rectangle_crop(frames, nearest_res) + else: + frames = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) + + # Apply transformations + frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) + # Optionally extract the first frame as an image image = frames[:1].clone() if self.image_to_video else None return image, frames, None - def _find_nearest_resolution(self, height, width): + def _find_nearest_resolution(self, height: int, width: int) -> Tuple[int, int]: + """ + Find the nearest resolution from the predefined list of resolutions. + """ nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) return nearest_res[1], nearest_res[2] - -class VideoDatasetWithResizeAndRectangleCrop(VideoDataset): - def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.video_reshape_mode = video_reshape_mode - - def _resize_for_rectangle_crop(self, arr, image_size): - reshape_mode = self.video_reshape_mode + def _resize_for_rectangle_crop(self, arr: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: + """ + Resize frames for rectangular cropping. + + Args: + arr (torch.Tensor): The video frames tensor [N, C, H, W]. + image_size (Tuple[int, int]): The target resolution (height, width). + """ if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: arr = resize( arr, @@ -194,48 +228,15 @@ def _resize_for_rectangle_crop(self, arr, image_size): interpolation=InterpolationMode.BICUBIC, ) + # Perform cropping h, w = arr.shape[2], arr.shape[3] - arr = arr.squeeze(0) + delta_h, delta_w = h - image_size[0], w - image_size[1] - delta_h = h - image_size[0] - delta_w = w - image_size[1] - - if reshape_mode == "random" or reshape_mode == "none": - top = np.random.randint(0, delta_h + 1) - left = np.random.randint(0, delta_w + 1) - elif reshape_mode == "center": + if self.video_reshape_mode == "random": + top, left = np.random.randint(0, delta_h + 1), np.random.randint(0, delta_w + 1) + elif self.video_reshape_mode == "center": top, left = delta_h // 2, delta_w // 2 else: - raise NotImplementedError - arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) - return arr - - def _preprocess_video(self, path: Path) -> torch.Tensor: - if self.load_tensors: - return self._load_preprocessed_latents_and_embeds(path) - else: - video_reader = decord.VideoReader(uri=path.as_posix()) - video_num_frames = len(video_reader) - nearest_frame_bucket = min( - [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], - key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), - default=1, - ) - - frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + raise NotImplementedError(f"Unsupported reshape mode: {self.video_reshape_mode}") - frames = video_reader.get_batch(frame_indices) - frames = frames[:nearest_frame_bucket].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - - nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) - frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) - frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) - - image = frames[:1].clone() if self.image_to_video else None - - return image, frames, None - - def _find_nearest_resolution(self, height, width): - nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) - return nearest_res[1], nearest_res[2] \ No newline at end of file + return transforms.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) \ No newline at end of file diff --git a/training/mochi-1/prepare_dataset.py b/training/mochi-1/prepare_dataset.py index 0826c279..864e3821 100644 --- a/training/mochi-1/prepare_dataset.py +++ b/training/mochi-1/prepare_dataset.py @@ -21,7 +21,7 @@ from torchvision import transforms from tqdm import tqdm from transformers import T5EncoderModel, T5Tokenizer -from dataset_mochi import VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop +from dataset_mochi import VideoDatasetWithFlexibleResize import decord # isort:skip @@ -326,6 +326,8 @@ def serialize_artifacts( metadata = [] for i in range(videos.size(0)): video = videos[i:i+1] + if video.size(1) == 1: + print(f"{video_latents[i:i+1].shape=}") metadata_dict = {"num_frames": video.size(1), "height": video.size(3), "width": video.size(4)} metadata.append(metadata_dict) @@ -421,6 +423,7 @@ def main(): dataset_init_kwargs = { "data_root": args.data_root, "dataset_file": args.dataset_file, + "video_reshape_mode": args.video_reshape_mode, "caption_column": args.caption_column, "video_column": args.video_column, "max_num_frames": args.max_num_frames, @@ -432,13 +435,7 @@ def main(): "random_flip": args.random_flip, "image_to_video": args.save_image_latents, } - if args.video_reshape_mode is None: - dataset = VideoDatasetWithResizing(**dataset_init_kwargs) - else: - dataset = VideoDatasetWithResizeAndRectangleCrop( - video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs - ) - + dataset = VideoDatasetWithFlexibleResize(**dataset_init_kwargs) original_dataset_size = len(dataset) # Split data among GPUs diff --git a/training/mochi-1/prepare_dataset.sh b/training/mochi-1/prepare_dataset.sh index 02786e52..c2aba5ce 100644 --- a/training/mochi-1/prepare_dataset.sh +++ b/training/mochi-1/prepare_dataset.sh @@ -11,8 +11,8 @@ VIDEO_COLUMN="videos.txt" OUTPUT_DIR="/home/sayak/cogvideox-factory/video-dataset-disney/mochi-1/preprocessed-dataset" HEIGHT_BUCKETS="480" WIDTH_BUCKETS="848" -FRAME_BUCKETS="1 84" -MAX_NUM_FRAMES="84" +FRAME_BUCKETS="85" +MAX_NUM_FRAMES="85" MAX_SEQUENCE_LENGTH=256 TARGET_FPS=30 BATCH_SIZE=4 diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index 4ffe908a..26091e26 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -54,7 +54,7 @@ from transformers import AutoTokenizer, T5EncoderModel from args import get_args # isort:skip -from dataset_mochi import VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip +from dataset_mochi import VideoDatasetWithFlexibleResize # isort:skip import sys sys.path.append("..") @@ -309,7 +309,6 @@ def main(args): variant=args.variant, ) scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - # noise_scheduler_copy = FlowMatchEulerDiscreteScheduler.from_config(scheduler.config, invert_sigmas=False) noise_scheduler_copy = copy.deepcopy(scheduler) vae_config = AutoencoderKLMochi.load_config(args.pretrained_model_name_or_path, subfolder="vae") @@ -347,7 +346,8 @@ def main(args): ) transformer.requires_grad_(False) - transformer.to(accelerator.device, dtype=weight_dtype) + # transformer.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device) if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() @@ -423,9 +423,8 @@ def load_model_hook(models, input_dir): # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. More details: # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 - if args.mixed_precision == "fp16": - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params([transformer_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -439,11 +438,8 @@ def load_model_hook(models, input_dir): args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) - - # Make sure the trainable params are in float32. - if args.mixed_precision == "fp16": - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params([transformer], dtype=torch.float32) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) @@ -488,6 +484,7 @@ def load_model_hook(models, input_dir): # Dataset and DataLoader dataset_init_kwargs = { "data_root": args.data_root, + "video_reshape_mode": args.video_reshape_mode, "dataset_file": args.dataset_file, "caption_column": args.caption_column, "video_column": args.video_column, @@ -499,15 +496,8 @@ def load_model_hook(models, input_dir): "load_tensors": args.load_tensors, "random_flip": args.random_flip, } - if args.video_reshape_mode is None: - train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs) - else: - train_dataset = VideoDatasetWithResizeAndRectangleCrop( - video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs - ) - + train_dataset = VideoDatasetWithFlexibleResize(**dataset_init_kwargs) collate_fn = CollateFunction(weight_dtype, args.load_tensors) - train_dataloader = DataLoader( train_dataset, batch_size=1, @@ -635,7 +625,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) timesteps = timesteps.to(accelerator.device) - # notice the reverse. step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() @@ -944,6 +933,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): commit_message="End of training", ignore_patterns=["step_*", "epoch_*", "*.bin", "*.pt"], ) + accelerator.print(f"Params pushed to {repo_id}.") accelerator.end_training() diff --git a/training/mochi-1/train.sh b/training/mochi-1/train.sh index e9851cbd..69c636b5 100644 --- a/training/mochi-1/train.sh +++ b/training/mochi-1/train.sh @@ -16,21 +16,17 @@ cmd="accelerate launch --config_file deepspeed.yaml --gpu_ids $GPU_IDS text_to_v --id_token BW_STYLE \ --height_buckets 480 \ --width_buckets 848 \ - --frame_buckets 84 \ + --frame_buckets 85 \ --load_tensors \ - --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \ - --validation_prompt_separator ::: \ - --num_validation_videos 1 \ - --validation_epochs 1 \ --seed 42 \ --rank 64 \ --lora_alpha 64 \ --mixed_precision bf16 \ --output_dir /raid/.cache/huggingface/sayak/mochi-lora/ \ - --max_num_frames 84 \ + --max_num_frames 85 \ --train_batch_size 1 \ --dataloader_num_workers 4 \ - --max_train_steps 500 \ + --max_train_steps 10 \ --checkpointing_steps 50 \ --gradient_accumulation_steps 4 \ --gradient_checkpointing \ From 58a06320bafe4aeeb9d676db354738cf7b5dd44f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 26 Nov 2024 12:28:44 +0530 Subject: [PATCH 19/30] updates --- training/mochi-1/dataset_mochi.py | 23 ++--- training/mochi-1/prepare_dataset.py | 10 ++- training/mochi-1/text_to_video_lora.py | 114 +++++++++++++++---------- 3 files changed, 88 insertions(+), 59 deletions(-) diff --git a/training/mochi-1/dataset_mochi.py b/training/mochi-1/dataset_mochi.py index d2af84b9..7bab1ce8 100644 --- a/training/mochi-1/dataset_mochi.py +++ b/training/mochi-1/dataset_mochi.py @@ -3,11 +3,11 @@ import numpy as np import torch -from torchvision import transforms +import torch.nn as nn from accelerate.logging import get_logger +from torchvision import transforms from torchvision.transforms import InterpolationMode from torchvision.transforms.functional import resize -import torch.nn as nn # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error @@ -16,14 +16,17 @@ decord.bridge.set_bridge("torch") -import sys +import sys + + sys.path.append("..") from dataset import VideoDataset as VDS + logger = get_logger(__name__) -# TODO (sayakpaul): probably not all buckets are needed for Mochi-1? +# TODO (sayakpaul): probably not all buckets are needed for Mochi-1? HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536] FRAME_BUCKETS = [16, 24, 32, 48, 64, 80, 85] @@ -68,7 +71,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]: # This is hardcoded for now. # Output of the VAE encoding is 2 * output_channels and then it's # temporal compression factor is 6. Initially, the VAE encodings will have - # 24 latent number of frames. So, if we were to train with a + # 24 latent number of frames. So, if we were to train with a # max frame size of 85 and frame bucket of [85], we need to have the following logic. latent_num_frames = video_latents.size(0) num_frames = ((latent_num_frames // 2) * (VAE_TEMPORAL_SCALE_FACTOR + 1) + 1) @@ -169,13 +172,13 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: ) frame_indices = list( range( - 0, - video_num_frames, + 0, + video_num_frames, 1 if video_num_frames < nearest_frame_bucket else video_num_frames // nearest_frame_bucket ) ) frames = video_reader.get_batch(frame_indices) - + # Pad or truncate frames to match the bucket size if video_num_frames < nearest_frame_bucket: pad_size = nearest_frame_bucket - video_num_frames @@ -210,7 +213,7 @@ def _find_nearest_resolution(self, height: int, width: int) -> Tuple[int, int]: def _resize_for_rectangle_crop(self, arr: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: """ Resize frames for rectangular cropping. - + Args: arr (torch.Tensor): The video frames tensor [N, C, H, W]. image_size (Tuple[int, int]): The target resolution (height, width). @@ -239,4 +242,4 @@ def _resize_for_rectangle_crop(self, arr: torch.Tensor, image_size: Tuple[int, i else: raise NotImplementedError(f"Unsupported reshape mode: {self.video_reshape_mode}") - return transforms.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) \ No newline at end of file + return transforms.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) diff --git a/training/mochi-1/prepare_dataset.py b/training/mochi-1/prepare_dataset.py index 864e3821..0e710db6 100644 --- a/training/mochi-1/prepare_dataset.py +++ b/training/mochi-1/prepare_dataset.py @@ -8,12 +8,13 @@ import queue import traceback import uuid -from contextlib import nullcontext from concurrent.futures import ThreadPoolExecutor +from contextlib import nullcontext from typing import Any, Dict, List, Optional, Union import torch import torch.distributed as dist +from dataset_mochi import VideoDatasetWithFlexibleResize from diffusers import AutoencoderKLMochi from diffusers.training_utils import set_seed from diffusers.utils import export_to_video, get_logger @@ -21,7 +22,6 @@ from torchvision import transforms from tqdm import tqdm from transformers import T5EncoderModel, T5Tokenizer -from dataset_mochi import VideoDatasetWithFlexibleResize import decord # isort:skip @@ -29,6 +29,8 @@ decord.bridge.set_bridge("torch") import sys + + sys.path.append("..") from dataset import BucketSampler @@ -489,11 +491,11 @@ def collate_fn(data): text_encoder = T5EncoderModel.from_pretrained( args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype ).to(device) - + vae = AutoencoderKLMochi.from_pretrained( args.model_id, subfolder="vae", torch_dtype=weight_dtype ).to(device) - + if args.use_slicing: vae.enable_slicing() if args.use_tiling: diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index 26091e26..e73dcd95 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import gc +import json import logging import math import os @@ -26,7 +28,6 @@ import torch import transformers import wandb -import copy from accelerate import Accelerator, DistributedType from accelerate.logging import get_logger from accelerate.utils import ( @@ -43,7 +44,11 @@ ) from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from diffusers.optimization import get_scheduler -from diffusers.training_utils import cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 +from diffusers.training_utils import ( + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, +) from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -53,10 +58,13 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer, T5EncoderModel + from args import get_args # isort:skip from dataset_mochi import VideoDatasetWithFlexibleResize # isort:skip import sys + + sys.path.append("..") from dataset import BucketSampler # isort:skip @@ -78,7 +86,7 @@ def save_model_card( widget_dict = [] if videos is not None and len(videos) > 0: for i, video in enumerate(videos): - export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps) widget_dict.append( { "text": validation_prompt if validation_prompt else " ", @@ -147,7 +155,8 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." ) - pipe = pipe.to(accelerator.device) + if not args.enable_model_cpu_offload: + pipe = pipe.to(accelerator.device) # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None @@ -261,7 +270,7 @@ def main(args): set_seed(args.seed) # Handle the repository creation - if accelerator.is_main_process: + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) @@ -283,7 +292,7 @@ def main(args): subfolder="text_encoder", revision=args.revision, ) - + vae = AutoencoderKLMochi.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", @@ -295,10 +304,11 @@ def main(args): if args.enable_tiling: vae.enable_tiling() + # keep things in FP32. text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=torch.float32) vae.requires_grad_(False) - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=torch.float32) load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 transformer = MochiTransformer3DModel.from_pretrained( @@ -315,7 +325,7 @@ def main(args): vae_in_channels = vae_config["latent_channels"] has_latents_mean = "latents_mean" in vae_config and vae_config["latents_mean"] is not None has_latents_std = "latents_std" in vae_config and vae_config["latents_std"] is not None - + VAE_SCALING_FACTOR = vae_config["scaling_factor"] # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision @@ -344,10 +354,10 @@ def main(args): raise ValueError( "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - + + # keep the transformer in FP32. transformer.requires_grad_(False) - # transformer.to(accelerator.device, dtype=weight_dtype) - transformer.to(accelerator.device) + transformer.to(accelerator.device, torch.float32) if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() @@ -355,7 +365,7 @@ def main(args): transformer_lora_config = LoraConfig( r=args.rank, lora_alpha=args.lora_alpha, - init_lora_weights=True, + init_lora_weights="gaussian", target_modules=args.target_modules, ) transformer.add_adapter(transformer_lora_config) @@ -482,6 +492,18 @@ def load_model_hook(models, input_dir): accelerator.print(f"Using {optimizer.__class__.__name__} optimizer.") # Dataset and DataLoader + if args.load_tensors and args.id_token: + with open(os.path.join(args.data_root, "data.jsonl")) as f: + contents = [json.loads(jline) for jline in f.read().splitlines()] + parsed_id_token = None + for content in contents: + if "id_token" in content: + parsed_id_token = content["id_token"] + if parsed_id_token is not None and parsed_id_token.strip() != args.id_token.strip(): + raise ValueError( + f"Parsed `id_token` from serialized metadata is {parsed_id_token} and provided `id_token` is {args.id_token}. They should match." + ) + dataset_init_kwargs = { "data_root": args.data_root, "video_reshape_mode": args.video_reshape_mode, @@ -497,7 +519,8 @@ def load_model_hook(models, input_dir): "random_flip": args.random_flip, } train_dataset = VideoDatasetWithFlexibleResize(**dataset_init_kwargs) - collate_fn = CollateFunction(weight_dtype, args.load_tensors) + # keeping things in FP32 for now. + collate_fn = CollateFunction(weight_dtype=torch.float32, load_tensors=args.load_tensors) train_dataloader = DataLoader( train_dataset, batch_size=1, @@ -619,7 +642,6 @@ def load_model_hook(models, input_dir): if args.load_tensors: gc.collect() torch.cuda.empty_cache() - torch.cuda.synchronize(accelerator.device) def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) @@ -650,11 +672,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Encode videos if not args.load_tensors: videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] - latent_dist = vae.encode(videos).latent_dist + latent_dist = vae.encode(videos.to(vae.dtype)).latent_dist else: latent_dist = DiagonalGaussianDistribution(videos) - videos = latent_dist.sample() + videos = latent_dist.sample() if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(vae_config["latents_mean"]).view(1, vae_in_channels, 1, 1, 1).to(videos.device, videos.dtype) @@ -665,8 +687,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): videos = (videos - latents_mean) * VAE_SCALING_FACTOR / latents_std else: videos = videos * VAE_SCALING_FACTOR - - videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype) + + # keep in FP32 for now. + videos = videos.to(memory_format=torch.contiguous_format, dtype=torch.float32) model_input = videos # Encode prompts @@ -677,11 +700,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompts, model_config.max_text_seq_length, accelerator.device, - weight_dtype, + weight_dtype=weight_dtype, requires_grad=False, ) else: - prompt_embeds = prompts.to(dtype=weight_dtype) + prompt_embeds = prompts.to(weight_dtype) prompt_attention_mask = prompt_attention_mask.to(accelerator.device) # Sample noise that will be added to the latents @@ -697,15 +720,25 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + # indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + # timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + # revisit. + timesteps = (u * noise_scheduler_copy.config.num_train_timesteps) # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + sigmas = get_sigmas( + timesteps=noise_scheduler_copy.timesteps[timesteps.long()].to(device=model_input.device), + n_dim=model_input.ndim, + dtype=model_input.dtype + ) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # do we need to revisit this? + noisy_model_input = noisy_model_input.to(weight_dtype) # Predict the noise residual + actual_num_train_timesteps = float(noise_scheduler_copy.config.num_train_timesteps) + timesteps = (1 - (timesteps / actual_num_train_timesteps)) * actual_num_train_timesteps # revisit + timesteps = timesteps.to(device=model_input.device) model_pred = transformer( hidden_states=noisy_model_input, encoder_hidden_states=prompt_embeds, @@ -719,7 +752,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss - target = noise - model_input + # target = noise - model_input + target = model_input - noise # as discussed with Ajay loss = torch.mean( (weighting * (model_pred.float() - target.float()) ** 2).reshape(batch_size, -1), @@ -787,11 +821,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if global_step >= args.max_train_steps: break - if accelerator.is_main_process: + if global_step >= args.max_train_steps: + break + + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: accelerator.print("===== Memory before validation =====") print_memory(accelerator.device) - torch.cuda.synchronize(accelerator.device) pipe = MochiPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -799,7 +835,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): scheduler=scheduler, revision=args.revision, variant=args.variant, - # torch_dtype=weight_dtype, ) if args.enable_slicing: @@ -831,25 +866,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): print_memory(accelerator.device) reset_memory(accelerator.device) + + del pipe.text_encoder + del pipe.vae del pipe gc.collect() torch.cuda.empty_cache() - torch.cuda.synchronize(accelerator.device) accelerator.wait_for_everyone() - if accelerator.is_main_process: + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: transformer = unwrap_model(transformer) - dtype = ( - torch.float16 - if args.mixed_precision == "fp16" - else torch.bfloat16 - if args.mixed_precision == "bf16" - else torch.float32 - ) - # transformer = transformer.to(dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) - MochiPipeline.save_lora_weights( save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers, @@ -863,7 +891,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): gc.collect() torch.cuda.empty_cache() - torch.cuda.synchronize(accelerator.device) # Final test inference validation_outputs = [] @@ -871,14 +898,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.print("===== Memory before testing =====") print_memory(accelerator.device) reset_memory(accelerator.device) - + pipe = MochiPipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, - # torch_dtype=weight_dtype, ) - pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config) if args.enable_slicing: pipe.vae.enable_slicing() @@ -898,7 +923,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args = { "prompt": validation_prompt, "guidance_scale": 4.5, - "use_dynamic_cfg": args.use_dynamic_cfg, "height": args.height, "width": args.width, } From 2fde026d30e34dc029b0e8a7a494ff51e52e8e93 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 28 Nov 2024 12:33:25 +0530 Subject: [PATCH 20/30] updates --- training/mochi-1/args.py | 150 ++--- training/mochi-1/dataset_mochi.py | 245 -------- training/mochi-1/dataset_simple.py | 50 ++ training/mochi-1/embed.py | 111 ++++ training/mochi-1/prepare_dataset.py | 682 ----------------------- training/mochi-1/prepare_dataset.sh | 50 +- training/mochi-1/text_to_video_lora.py | 381 +++++-------- training/mochi-1/train.sh | 46 +- training/mochi-1/trim_and_crop_videos.py | 126 +++++ 9 files changed, 466 insertions(+), 1375 deletions(-) delete mode 100644 training/mochi-1/dataset_mochi.py create mode 100644 training/mochi-1/dataset_simple.py create mode 100644 training/mochi-1/embed.py delete mode 100644 training/mochi-1/prepare_dataset.py create mode 100644 training/mochi-1/trim_and_crop_videos.py diff --git a/training/mochi-1/args.py b/training/mochi-1/args.py index 1d41f5ae..46a8a178 100644 --- a/training/mochi-1/args.py +++ b/training/mochi-1/args.py @@ -28,6 +28,11 @@ def _get_model_args(parser: argparse.ArgumentParser) -> None: default=None, help="The directory where the downloaded models and datasets will be stored.", ) + parser.add_argument( + "--cast_dit", + action="store_true", + help="If we should cast DiT params to a lower precision.", + ) def _get_dataset_args(parser: argparse.ArgumentParser) -> None: @@ -38,58 +43,12 @@ def _get_dataset_args(parser: argparse.ArgumentParser) -> None: help=("A folder containing the training data."), ) parser.add_argument( - "--dataset_file", - type=str, - default=None, - help=("Path to a CSV file if loading prompts/video paths using this format."), - ) - parser.add_argument( - "--video_column", - type=str, - default="video", - help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.", - ) - parser.add_argument( - "--caption_column", - type=str, - default="text", - help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.", - ) - parser.add_argument( - "--id_token", - type=str, - default=None, - help="Identifier token appended to the start of each prompt if provided.", - ) - parser.add_argument( - "--height_buckets", - nargs="+", - type=int, - default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], - ) - parser.add_argument( - "--width_buckets", - nargs="+", - type=int, - default=[256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536], - ) - parser.add_argument( - "--frame_buckets", - nargs="+", - type=int, - default=[84], - ) - parser.add_argument( - "--load_tensors", - action="store_true", - help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.", - ) - parser.add_argument( - "--random_flip", + "--caption_dropout", type=float, default=None, - help="If random horizontal flip augmentation is to be used, this should be the flip probability.", + help=("Probability to drop out captions randomly."), ) + parser.add_argument( "--dataloader_num_workers", type=int, @@ -140,15 +99,31 @@ def _get_validation_args(parser: argparse.ArgumentParser) -> None: default=False, help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", ) + parser.add_argument( + "--fps", + type=int, + default=30, + help="FPS to use when serializing the output videos.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + ) + parser.add_argument( + "--width", + type=int, + default=848, + ) def _get_training_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.") + parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.") parser.add_argument( "--lora_alpha", type=int, - default=64, + default=16, help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", ) parser.add_argument( @@ -156,7 +131,7 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: nargs="+", type=str, default=["to_k", "to_q", "to_v", "to_out.0"], - help="Target modules to train LoRA for." + help="Target modules to train LoRA for.", ) parser.add_argument( "--mixed_precision", @@ -175,43 +150,6 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: default="mochi-lora", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument( - "--height", - type=int, - default=480, - help="All input videos are resized to this height.", - ) - parser.add_argument( - "--width", - type=int, - default=848, - help="All input videos are resized to this width.", - ) - parser.add_argument( - "--video_reshape_mode", - type=str, - default=None, - help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", - ) - parser.add_argument("--fps", type=int, default=30, help="All input videos will be used at this FPS.") - parser.add_argument( - "--max_num_frames", - type=int, - default=84, - help="All input videos will be truncated to these many frames.", - ) - parser.add_argument( - "--skip_frames_start", - type=int, - default=0, - help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", - ) - parser.add_argument( - "--skip_frames_end", - type=int, - default=0, - help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", - ) parser.add_argument( "--train_batch_size", type=int, @@ -256,25 +194,6 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) - parser.add_argument( - "--weighting_scheme", - type=str, - default="none", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), - ) - parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument( - "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", - ) parser.add_argument( "--gradient_checkpointing", action="store_true", @@ -283,19 +202,18 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--learning_rate", type=float, - default=1e-4, + default=2e-4, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", - default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_scheduler", type=str, - default="constant", + default="cosine", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' @@ -304,7 +222,7 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--lr_warmup_steps", type=int, - default=500, + default=200, help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( @@ -331,12 +249,6 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: default=False, help="Whether or not to use VAE tiling for saving memory.", ) - parser.add_argument( - "--noised_image_dropout", - type=float, - default=0.05, - help="Image condition dropout probability when finetuning image-to-video.", - ) def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: @@ -386,7 +298,7 @@ def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--weight_decay", type=float, - default=1e-04, + default=0.01, help="Weight decay to use for optimizer.", ) parser.add_argument( diff --git a/training/mochi-1/dataset_mochi.py b/training/mochi-1/dataset_mochi.py deleted file mode 100644 index 7bab1ce8..00000000 --- a/training/mochi-1/dataset_mochi.py +++ /dev/null @@ -1,245 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Tuple - -import numpy as np -import torch -import torch.nn as nn -from accelerate.logging import get_logger -from torchvision import transforms -from torchvision.transforms import InterpolationMode -from torchvision.transforms.functional import resize - - -# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error -# Very few bug reports but it happens. Look in decord Github issues for more relevant information. -import decord # isort:skip - -decord.bridge.set_bridge("torch") - -import sys - - -sys.path.append("..") - -from dataset import VideoDataset as VDS - - -logger = get_logger(__name__) - -# TODO (sayakpaul): probably not all buckets are needed for Mochi-1? -HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] -WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536] -FRAME_BUCKETS = [16, 24, 32, 48, 64, 80, 85] - -VAE_SPATIAL_SCALE_FACTOR = 8 -VAE_TEMPORAL_SCALE_FACTOR = 6 - -class VideoDataset(VDS): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - random_flip = kwargs.get("random_flip", None) - self.video_transforms = transforms.Compose( - [ - transforms.RandomHorizontalFlip([(random_flip)]) - if random_flip - else transforms.Lambda(lambda x: x), - transforms.Lambda(self.scale_transform), - ] - ) - - def scale_transform(self, x): - return x / 127.5 - 1.0 - - # Overriding this because we calculate `num_frames` differently. - def __getitem__(self, index: int) -> Dict[str, Any]: - if isinstance(index, list): - # Here, index is actually a list of data objects that we need to return. - # The BucketSampler should ideally return indices. But, in the sampler, we'd like - # to have information about num_frames, height and width. Since this is not stored - # as metadata, we need to read the video to get this information. You could read this - # information without loading the full video in memory, but we do it anyway. In order - # to not load the video twice (once to get the metadata, and once to return the loaded video - # based on sampled indices), we cache it in the BucketSampler. When the sampler is - # to yield, we yield the cache data instead of indices. So, this special check ensures - # that data is not loaded a second time. PRs are welcome for improvements. - return index - - if self.load_tensors: - image_latents, video_latents, prompt_embeds, prompt_attention_mask = self._preprocess_video(self.video_paths[index]) - - # This is hardcoded for now. - # Output of the VAE encoding is 2 * output_channels and then it's - # temporal compression factor is 6. Initially, the VAE encodings will have - # 24 latent number of frames. So, if we were to train with a - # max frame size of 85 and frame bucket of [85], we need to have the following logic. - latent_num_frames = video_latents.size(0) - num_frames = ((latent_num_frames // 2) * (VAE_TEMPORAL_SCALE_FACTOR + 1) + 1) - - height = video_latents.size(2) * VAE_SPATIAL_SCALE_FACTOR - width = video_latents.size(3) * VAE_SPATIAL_SCALE_FACTOR - - return { - "prompt": prompt_embeds, - "prompt_attention_mask": prompt_attention_mask, - "image": image_latents, - "video": video_latents, - "video_metadata": { - "num_frames": num_frames, - "height": height, - "width": width, - }, - } - else: - image, video, _ = self._preprocess_video(self.video_paths[index]) - if video is not None: - return { - "prompt": self.id_token + self.prompts[index], - "image": image, - "video": video, - "video_metadata": { - "num_frames": video.shape[0], - "height": video.shape[2], - "width": video.shape[3], - }, - } - - # Overriding this because we need `prompt_attention_mask`. - def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]: - filename_without_ext = path.name.split(".")[0] - pt_filename = f"{filename_without_ext}.pt" - - # The current path is something like: /a/b/c/d/videos/00001.mp4 - # We need to reach: /a/b/c/d/video_latents/00001.pt - image_latents_path = path.parent.parent.joinpath("image_latents") - video_latents_path = path.parent.parent.joinpath("video_latents") - embeds_path = path.parent.parent.joinpath("prompt_embeds") - attention_mask_path = path.parent.parent.joinpath("prompt_attention_mask") - - if ( - not video_latents_path.exists() - or not embeds_path.exists() - or not attention_mask_path.exists() - or (self.image_to_video and not image_latents_path.exists()) - ): - raise ValueError( - f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains three folders named `video_latents`, `prompt_embeds`, and `prompt_attention_mask`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present." - ) - - if self.image_to_video: - image_latent_filepath = image_latents_path.joinpath(pt_filename) - video_latent_filepath = video_latents_path.joinpath(pt_filename) - embeds_filepath = embeds_path.joinpath(pt_filename) - attention_mask_filepath = attention_mask_path.joinpath(pt_filename) - - if not video_latent_filepath.is_file() or not embeds_filepath.is_file() or not attention_mask_filepath.is_file(): - if self.image_to_video: - image_latent_filepath = image_latent_filepath.as_posix() - video_latent_filepath = video_latent_filepath.as_posix() - embeds_filepath = embeds_filepath.as_posix() - attention_mask_filepath = attention_mask_filepath.as_posix() - raise ValueError( - f"The file {video_latent_filepath=} or {embeds_filepath=} or {attention_mask_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`." - ) - - images = ( - torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None - ) - latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True) - embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True) - attention_masks = torch.load(attention_mask_filepath, map_location="cpu", weights_only=True) - - return images, latents, embeds, attention_masks - - - -class VideoDatasetWithFlexibleResize(VideoDataset): - def __init__(self, video_reshape_mode: str = None, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - if video_reshape_mode: - assert video_reshape_mode in ["center", "random"] - self.video_reshape_mode = video_reshape_mode - - def _preprocess_video(self, path: Path) -> torch.Tensor: - if self.load_tensors: - return self._load_preprocessed_latents_and_embeds(path) - else: - video_reader = decord.VideoReader(uri=path.as_posix()) - video_num_frames = len(video_reader) - - nearest_frame_bucket = min( - self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) - ) - frame_indices = list( - range( - 0, - video_num_frames, - 1 if video_num_frames < nearest_frame_bucket else video_num_frames // nearest_frame_bucket - ) - ) - frames = video_reader.get_batch(frame_indices) - - # Pad or truncate frames to match the bucket size - if video_num_frames < nearest_frame_bucket: - pad_size = nearest_frame_bucket - video_num_frames - frames = nn.functional.pad(frames, (0, 0, 0, 0, 0, 0, 0, pad_size)) - frames = frames.float() - else: - frames = frames[:nearest_frame_bucket].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - - # Find nearest resolution and apply resizing - nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) - if self.video_reshape_mode in {"center", "random"}: - frames = self._resize_for_rectangle_crop(frames, nearest_res) - else: - frames = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) - - # Apply transformations - frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) - - # Optionally extract the first frame as an image - image = frames[:1].clone() if self.image_to_video else None - - return image, frames, None - - def _find_nearest_resolution(self, height: int, width: int) -> Tuple[int, int]: - """ - Find the nearest resolution from the predefined list of resolutions. - """ - nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) - return nearest_res[1], nearest_res[2] - - def _resize_for_rectangle_crop(self, arr: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: - """ - Resize frames for rectangular cropping. - - Args: - arr (torch.Tensor): The video frames tensor [N, C, H, W]. - image_size (Tuple[int, int]): The target resolution (height, width). - """ - if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: - arr = resize( - arr, - size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], - interpolation=InterpolationMode.BICUBIC, - ) - else: - arr = resize( - arr, - size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], - interpolation=InterpolationMode.BICUBIC, - ) - - # Perform cropping - h, w = arr.shape[2], arr.shape[3] - delta_h, delta_w = h - image_size[0], w - image_size[1] - - if self.video_reshape_mode == "random": - top, left = np.random.randint(0, delta_h + 1), np.random.randint(0, delta_w + 1) - elif self.video_reshape_mode == "center": - top, left = delta_h // 2, delta_w // 2 - else: - raise NotImplementedError(f"Unsupported reshape mode: {self.video_reshape_mode}") - - return transforms.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) diff --git a/training/mochi-1/dataset_simple.py b/training/mochi-1/dataset_simple.py new file mode 100644 index 00000000..8cc6153b --- /dev/null +++ b/training/mochi-1/dataset_simple.py @@ -0,0 +1,50 @@ +""" +Taken from +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py +""" + +from pathlib import Path + +import click +import torch +from torch.utils.data import DataLoader, Dataset + + +def load_to_cpu(x): + return torch.load(x, map_location=torch.device("cpu"), weights_only=True) + + +class LatentEmbedDataset(Dataset): + def __init__(self, file_paths, repeat=1): + self.items = [ + (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt")) + for p in file_paths + if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file() + ] + self.items = self.items * repeat + print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.") + + def __len__(self): + return len(self.items) + + def __getitem__(self, idx): + latent_path, embed_path = self.items[idx] + return load_to_cpu(latent_path), load_to_cpu(embed_path) + + +@click.command() +@click.argument("directory", type=click.Path(exists=True, file_okay=False)) +def process_videos(directory): + dir_path = Path(directory) + mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")] + assert mp4_files, f"No mp4 files found" + + dataset = LatentEmbedDataset(mp4_files) + dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + + for latents, embeds in dataloader: + print([(k, v.shape) for k, v in latents.items()]) + + +if __name__ == "__main__": + process_videos() diff --git a/training/mochi-1/embed.py b/training/mochi-1/embed.py new file mode 100644 index 00000000..ec35ebb0 --- /dev/null +++ b/training/mochi-1/embed.py @@ -0,0 +1,111 @@ +""" +Adapted from: +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py +""" + +import click +import torch +import torchvision +from pathlib import Path +from diffusers import AutoencoderKLMochi, MochiPipeline +from transformers import T5EncoderModel, T5Tokenizer +from tqdm.auto import tqdm + + +def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str): + T, H, W = [int(s) for s in shape.split("x")] + assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6" + video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs") + fps = metadata["video_fps"] + video = video.permute(3, 0, 1, 2) + og_shape = video.shape + assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}" + assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}" + assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}" + if video.shape[1] > T: + video = video[:, :T] + print(f"Trimmed video from {og_shape[1]} to first {T} frames") + video = video.unsqueeze(0) + video = video.float() / 127.5 - 1.0 + video = video.to(model.device) + + assert video.ndim == 5 + + with torch.inference_mode(): + with torch.autocast("cuda", dtype=torch.bfloat16): + ldist = model._encode(video) + + torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt")) + + +@click.command() +@click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path)) +@click.option( + "--model_id", + type=str, + help="Repo id. Should be genmo/mochi-1-preview", + default="genmo/mochi-1-preview", +) +@click.option("--shape", default="163x480x848", help="Shape of the video to encode") +@click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.") +def batch_process(output_dir: Path, model_id: Path, shape: str, overwrite: bool) -> None: + """Process all videos and captions in a directory using a single GPU.""" + # comment out when running on unsupported hardware + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Get all video paths + video_paths = list(output_dir.glob("**/*.mp4")) + if not video_paths: + print(f"No MP4 files found in {output_dir}") + return + + text_paths = list(output_dir.glob("**/*.txt")) + if not text_paths: + print(f"No text files found in {output_dir}") + return + + # load the models + vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda") + text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder") + tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer") + pipeline = MochiPipeline.from_pretrained( + model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None + ).to("cuda") + + for idx, video_path in tqdm(enumerate(sorted(video_paths))): + print(f"Processing {video_path}") + try: + if video_path.with_suffix(".latent.pt").exists() and not overwrite: + print(f"Skipping {video_path}") + continue + + # encode videos. + encode_videos(vae, vid_path=video_path, shape=shape) + + # embed captions. + prompt_path = Path("/".join(str(video_path).split(".")[:-1]) + ".txt") + embed_path = prompt_path.with_suffix(".embed.pt") + + if embed_path.exists() and not overwrite: + print(f"Skipping {prompt_path} - embeddings already exist") + continue + + with open(prompt_path) as f: + text = f.read().strip() + with torch.inference_mode(): + conditioning = pipeline.encode_prompt(prompt=[text]) + + conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]} + torch.save(conditioning, embed_path) + + except Exception as e: + import traceback + + traceback.print_exc() + print(f"Error processing {video_path}: {str(e)}") + + +if __name__ == "__main__": + batch_process() diff --git a/training/mochi-1/prepare_dataset.py b/training/mochi-1/prepare_dataset.py deleted file mode 100644 index 0e710db6..00000000 --- a/training/mochi-1/prepare_dataset.py +++ /dev/null @@ -1,682 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import functools -import json -import os -import pathlib -import queue -import traceback -import uuid -from concurrent.futures import ThreadPoolExecutor -from contextlib import nullcontext -from typing import Any, Dict, List, Optional, Union - -import torch -import torch.distributed as dist -from dataset_mochi import VideoDatasetWithFlexibleResize -from diffusers import AutoencoderKLMochi -from diffusers.training_utils import set_seed -from diffusers.utils import export_to_video, get_logger -from torch.utils.data import DataLoader -from torchvision import transforms -from tqdm import tqdm -from transformers import T5EncoderModel, T5Tokenizer - - -import decord # isort:skip - -decord.bridge.set_bridge("torch") - -import sys - - -sys.path.append("..") - -from dataset import BucketSampler - - -logger = get_logger(__name__) - -DTYPE_MAPPING = { - "fp32": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, -} - - -def check_height(x: Any) -> int: - x = int(x) - if x % 16 != 0: - raise argparse.ArgumentTypeError( - f"`--height_buckets` must be divisible by 16, but got {x} which does not fit criteria." - ) - return x - - -def check_width(x: Any) -> int: - x = int(x) - if x % 16 != 0: - raise argparse.ArgumentTypeError( - f"`--width_buckets` must be divisible by 16, but got {x} which does not fit criteria." - ) - return x - - -def check_frames(x: Any) -> int: - x = int(x) - if x % 4 != 0 and x % 4 != 1: - raise argparse.ArgumentTypeError( - f"`--frames_buckets` must be of form `4 * k` or `4 * k + 1`, but got {x} which does not fit criteria." - ) - return x - - -def get_args() -> Dict[str, Any]: - parser = argparse.ArgumentParser() - parser.add_argument( - "--model_id", - type=str, - default="genmo/mochi-1-preview", - help="Hugging Face model ID to use for tokenizer, text encoder and VAE.", - ) - parser.add_argument("--data_root", type=str, required=True, help="Path to where training data is located.") - parser.add_argument( - "--dataset_file", type=str, default=None, help="Path to CSV file containing metadata about training data." - ) - parser.add_argument( - "--caption_column", - type=str, - default="caption", - help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the captions. If using the folder structure format for data loading, this should be the name of the file containing line-separated captions (the file should be located in `--data_root`).", - ) - parser.add_argument( - "--video_column", - type=str, - default="video", - help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the video paths. If using the folder structure format for data loading, this should be the name of the file containing line-separated video paths (the file should be located in `--data_root`).", - ) - parser.add_argument( - "--id_token", - type=str, - default=None, - help="Identifier token appended to the start of each prompt if provided.", - ) - parser.add_argument( - "--height_buckets", - nargs="+", - type=check_height, - default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], - ) - parser.add_argument( - "--width_buckets", - nargs="+", - type=check_width, - default=[256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536], - ) - parser.add_argument( - "--frame_buckets", - nargs="+", - type=check_frames, - default=[84], - ) - parser.add_argument( - "--random_flip", - type=float, - default=None, - help="If random horizontal flip augmentation is to be used, this should be the flip probability.", - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", - ) - parser.add_argument( - "--pin_memory", - action="store_true", - help="Whether or not to use the pinned memory setting in pytorch dataloader.", - ) - parser.add_argument( - "--video_reshape_mode", - type=str, - default=None, - help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", - ) - parser.add_argument( - "--save_image_latents", - action="store_true", - help="Whether or not to encode and store image latents, which are required for image-to-video finetuning. The image latents are the first frame of input videos encoded with the VAE.", - ) - parser.add_argument( - "--output_dir", - type=str, - required=True, - help="Path to output directory where preprocessed videos/latents/embeddings will be saved.", - ) - parser.add_argument("--max_num_frames", type=int, default=84, help="Maximum number of frames in output video.") - parser.add_argument( - "--max_sequence_length", type=int, default=256, help="Max sequence length of prompt embeddings." - ) - parser.add_argument("--target_fps", type=int, default=30, help="Frame rate of output videos.") - parser.add_argument( - "--save_latents_and_embeddings", - action="store_true", - help="Whether to encode videos/captions to latents/embeddings and save them in pytorch serializable format.", - ) - parser.add_argument( - "--use_slicing", - action="store_true", - help="Whether to enable sliced encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", - ) - parser.add_argument( - "--use_tiling", - action="store_true", - help="Whether to enable tiled encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", - ) - parser.add_argument("--batch_size", type=int, default=1, help="Number of videos to process at once in the VAE.") - parser.add_argument( - "--num_decode_threads", - type=int, - default=0, - help="Number of decoding threads for `decord` to use. The default `0` means to automatically determine required number of threads.", - ) - parser.add_argument( - "--dtype", - type=str, - choices=["fp32", "fp16", "bf16"], - default="fp32", - help="Data type to use when generating latents and prompt embeddings.", - ) - parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.") - parser.add_argument( - "--num_artifact_workers", type=int, default=4, help="Number of worker threads for serializing artifacts." - ) - return parser.parse_args() - - -def _get_t5_prompt_embeds( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]], - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - text_input_ids=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if tokenizer is not None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_attention_mask = text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.bool() - else: - if text_input_ids is None: - raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") - - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) - prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) - - return prompt_embeds, prompt_attention_mask - - -def encode_prompt( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]], - num_videos_per_prompt: int = 1, - max_sequence_length: int = 256, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - text_input_ids=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt_embeds, prompt_attention_mask = _get_t5_prompt_embeds( - tokenizer, - text_encoder, - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - text_input_ids=text_input_ids, - ) - return prompt_embeds, prompt_attention_mask - - -def compute_prompt_embeddings( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompts: List[str], - max_sequence_length: int, - device: torch.device, - dtype: torch.dtype, - requires_grad: bool = False, -): - ctx = nullcontext() if requires_grad else torch.no_grad() - with ctx: - prompt_embeds, prompt_attention_mask = encode_prompt( - tokenizer, - text_encoder, - prompts, - num_videos_per_prompt=1, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - return prompt_embeds, prompt_attention_mask - - -to_pil_image = transforms.ToPILImage(mode="RGB") - - -def save_image(image: torch.Tensor, path: pathlib.Path) -> None: - image = to_pil_image(image) - image.save(path) - - -def save_video(video: torch.Tensor, path: pathlib.Path, fps: int = 8) -> None: - video = [to_pil_image(frame) for frame in video] - export_to_video(video, path, fps=fps) - - -def save_prompt(prompt: str, path: pathlib.Path) -> None: - with open(path, "w", encoding="utf-8") as file: - file.write(prompt) - - -def save_metadata(metadata: Dict[str, Any], path: pathlib.Path) -> None: - with open(path, "w", encoding="utf-8") as file: - file.write(json.dumps(metadata)) - - -@torch.no_grad() -def serialize_artifacts( - batch_size: int, - fps: int, - images_dir: Optional[pathlib.Path] = None, - image_latents_dir: Optional[pathlib.Path] = None, - videos_dir: Optional[pathlib.Path] = None, - video_latents_dir: Optional[pathlib.Path] = None, - prompts_dir: Optional[pathlib.Path] = None, - prompt_embeds_dir: Optional[pathlib.Path] = None, - prompt_attention_mask_dir: Optional[pathlib.Path] = None, - images: Optional[torch.Tensor] = None, - image_latents: Optional[torch.Tensor] = None, - videos: Optional[torch.Tensor] = None, - video_latents: Optional[torch.Tensor] = None, - prompts: Optional[List[str]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None -) -> None: - metadata = [] - for i in range(videos.size(0)): - video = videos[i:i+1] - if video.size(1) == 1: - print(f"{video_latents[i:i+1].shape=}") - metadata_dict = {"num_frames": video.size(1), "height": video.size(3), "width": video.size(4)} - metadata.append(metadata_dict) - - data_folder_mapper_list = [ - (images, images_dir, lambda img, path: save_image(img[0], path), "png"), - (image_latents, image_latents_dir, torch.save, "pt"), - (videos, videos_dir, functools.partial(save_video, fps=fps), "mp4"), - (video_latents, video_latents_dir, torch.save, "pt"), - (prompts, prompts_dir, save_prompt, "txt"), - (prompt_embeds, prompt_embeds_dir, torch.save, "pt"), - (prompt_attention_mask, prompt_attention_mask_dir, torch.save, "pt"), - (metadata, videos_dir, save_metadata, "txt"), - ] - filenames = [uuid.uuid4() for _ in range(batch_size)] - - for data, folder, save_fn, extension in data_folder_mapper_list: - if data is None: - continue - for slice, filename in zip(data, filenames): - if isinstance(slice, torch.Tensor): - slice = slice.clone().to("cpu") - path = folder.joinpath(f"{filename}.{extension}") - save_fn(slice, path) - - -def save_intermediates(output_queue: queue.Queue) -> None: - while True: - try: - item = output_queue.get(timeout=30) - if item is None: - break - serialize_artifacts(**item) - - except queue.Empty: - continue - - -@torch.no_grad() -def main(): - args = get_args() - set_seed(args.seed) - - output_dir = pathlib.Path(args.output_dir) - tmp_dir = output_dir.joinpath("tmp") - - output_dir.mkdir(parents=True, exist_ok=True) - tmp_dir.mkdir(parents=True, exist_ok=True) - - # Create task queue for non-blocking serializing of artifacts - output_queue = queue.Queue() - save_thread = ThreadPoolExecutor(max_workers=args.num_artifact_workers) - save_future = save_thread.submit(save_intermediates, output_queue) - - # Initialize distributed processing - if "LOCAL_RANK" in os.environ: - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - dist.init_process_group(backend="nccl") - world_size = dist.get_world_size() - rank = dist.get_rank() - else: - # Single GPU - local_rank = 0 - world_size = 1 - rank = 0 - torch.cuda.set_device(rank) - - # Create folders where intermediate tensors from each rank will be saved - images_dir = tmp_dir.joinpath(f"images/{rank}") - image_latents_dir = tmp_dir.joinpath(f"image_latents/{rank}") - videos_dir = tmp_dir.joinpath(f"videos/{rank}") - video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}") - prompts_dir = tmp_dir.joinpath(f"prompts/{rank}") - prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}") - prompt_attention_mask_dir = tmp_dir.joinpath(f"prompt_attention_mask/{rank}") - - images_dir.mkdir(parents=True, exist_ok=True) - image_latents_dir.mkdir(parents=True, exist_ok=True) - videos_dir.mkdir(parents=True, exist_ok=True) - video_latents_dir.mkdir(parents=True, exist_ok=True) - prompts_dir.mkdir(parents=True, exist_ok=True) - prompt_embeds_dir.mkdir(parents=True, exist_ok=True) - prompt_attention_mask_dir.mkdir(parents=True, exist_ok=True) - - weight_dtype = DTYPE_MAPPING[args.dtype] - target_fps = args.target_fps - - if weight_dtype is not None: - weight_dtype = torch.float32 - print("To get the best results, we set `weight_dtype` to `torch.float32`.") - - # 1. Dataset - dataset_init_kwargs = { - "data_root": args.data_root, - "dataset_file": args.dataset_file, - "video_reshape_mode": args.video_reshape_mode, - "caption_column": args.caption_column, - "video_column": args.video_column, - "max_num_frames": args.max_num_frames, - "id_token": args.id_token, - "height_buckets": args.height_buckets, - "width_buckets": args.width_buckets, - "frame_buckets": args.frame_buckets, - "load_tensors": False, - "random_flip": args.random_flip, - "image_to_video": args.save_image_latents, - } - dataset = VideoDatasetWithFlexibleResize(**dataset_init_kwargs) - original_dataset_size = len(dataset) - - # Split data among GPUs - if world_size > 1: - samples_per_gpu = original_dataset_size // world_size - start_index = rank * samples_per_gpu - end_index = start_index + samples_per_gpu - if rank == world_size - 1: - end_index = original_dataset_size # Make sure the last GPU gets the remaining data - - # Slice the data - dataset.prompts = dataset.prompts[start_index:end_index] - dataset.video_paths = dataset.video_paths[start_index:end_index] - else: - pass - - rank_dataset_size = len(dataset) - - # 2. Dataloader - def collate_fn(data): - prompts = [x["prompt"] for x in data[0]] - - images = None - if args.save_image_latents: - images = [x["image"] for x in data[0]] - images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True) - - videos = [x["video"] for x in data[0]] - videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True) - - return { - "images": images, - "videos": videos, - "prompts": prompts, - } - - dataloader = DataLoader( - dataset, - batch_size=1, - sampler=BucketSampler(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False), - collate_fn=collate_fn, - num_workers=args.dataloader_num_workers, - pin_memory=args.pin_memory, - ) - - # 3. Prepare models - device = f"cuda:{rank}" - - if args.save_latents_and_embeddings: - tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer") - text_encoder = T5EncoderModel.from_pretrained( - args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype - ).to(device) - - vae = AutoencoderKLMochi.from_pretrained( - args.model_id, subfolder="vae", torch_dtype=weight_dtype - ).to(device) - - if args.use_slicing: - vae.enable_slicing() - if args.use_tiling: - vae.enable_tiling() - - # 4. Compute latents and embeddings and save - if rank == 0: - iterator = tqdm( - dataloader, desc="Encoding", total=(rank_dataset_size + args.batch_size - 1) // args.batch_size - ) - else: - iterator = dataloader - - for step, batch in enumerate(iterator): - try: - images = None - image_latents = None - video_latents = None - prompt_embeds = None - - if args.save_image_latents: - images = batch["images"].to(device, non_blocking=True) - images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] - - videos = batch["videos"].to(device, non_blocking=True) - videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] - - prompts = batch["prompts"] - - # Encode videos & images - # we run under autocast following the official recommendations of Mochi - with torch.autocast(device, torch.bfloat16, cache_enabled=False): - if args.save_latents_and_embeddings: - if args.use_slicing: - if args.save_image_latents: - encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)] - image_latents = torch.cat(encoded_slices) - image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) - - encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)] - video_latents = torch.cat(encoded_slices) - - else: - if args.save_image_latents: - image_latents = vae._encode(images) - image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) - - video_latents = vae._encode(videos) - - video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) - - # Encode prompts - prompt_embeds, prompt_attention_mask = compute_prompt_embeddings( - tokenizer, - text_encoder, - prompts, - args.max_sequence_length, - device, - weight_dtype, - requires_grad=False, - ) - - if images is not None: - images = (images.permute(0, 2, 1, 3, 4) + 1) / 2 - - videos = (videos.permute(0, 2, 1, 3, 4) + 1) / 2 - - output_queue.put( - { - "batch_size": len(prompts), - "fps": target_fps, - "images_dir": images_dir, - "image_latents_dir": image_latents_dir, - "videos_dir": videos_dir, - "video_latents_dir": video_latents_dir, - "prompts_dir": prompts_dir, - "prompt_embeds_dir": prompt_embeds_dir, - "prompt_attention_mask_dir": prompt_attention_mask_dir, - "images": images, - "image_latents": image_latents, - "videos": videos, - "video_latents": video_latents, - "prompts": prompts, - "prompt_embeds": prompt_embeds, - "prompt_attention_mask": prompt_attention_mask, - } - ) - - except Exception: - print("-------------------------") - print(f"An exception occurred while processing data: {rank=}, {world_size=}, {step=}") - traceback.print_exc() - print("-------------------------") - - # 5. Complete distributed processing - if world_size > 1: - dist.barrier() - dist.destroy_process_group() - - output_queue.put(None) - save_thread.shutdown(wait=True) - save_future.result() - - # 6. Combine results from each rank - if rank == 0: - print( - f"Completed preprocessing latents and embeddings. Temporary files from all ranks saved to `{tmp_dir.as_posix()}`" - ) - - # Move files from each rank to common directory - for subfolder, extension in [ - ("images", "png"), - ("image_latents", "pt"), - ("videos", "mp4"), - ("video_latents", "pt"), - ("prompts", "txt"), - ("prompt_embeds", "pt"), - ("prompt_attention_mask", "pt"), - ("videos", "txt"), - ]: - tmp_subfolder = tmp_dir.joinpath(subfolder) - combined_subfolder = output_dir.joinpath(subfolder) - combined_subfolder.mkdir(parents=True, exist_ok=True) - pattern = f"*.{extension}" - - for file in tmp_subfolder.rglob(pattern): - file.replace(combined_subfolder / file.name) - - # Remove temporary directories - def rmdir_recursive(dir: pathlib.Path) -> None: - for child in dir.iterdir(): - if child.is_file(): - child.unlink() - else: - rmdir_recursive(child) - dir.rmdir() - - rmdir_recursive(tmp_dir) - - # Combine prompts and videos into individual text files and single jsonl - prompts_folder = output_dir.joinpath("prompts") - prompts = [] - stems = [] - - for filename in prompts_folder.rglob("*.txt"): - with open(filename, "r") as file: - prompts.append(file.read().strip()) - stems.append(filename.stem) - - prompts_txt = output_dir.joinpath("prompts.txt") - videos_txt = output_dir.joinpath("videos.txt") - data_jsonl = output_dir.joinpath("data.jsonl") - - with open(prompts_txt, "w") as file: - for prompt in prompts: - file.write(f"{prompt}\n") - - with open(videos_txt, "w") as file: - for stem in stems: - file.write(f"videos/{stem}.mp4\n") - - with open(data_jsonl, "w") as file: - for prompt, stem in zip(prompts, stems): - video_metadata_txt = output_dir.joinpath(f"videos/{stem}.txt") - with open(video_metadata_txt, "r", encoding="utf-8") as metadata_file: - metadata = json.loads(metadata_file.read()) - - data = { - "prompt": prompt, - "prompt_embed": f"prompt_embeds/{stem}.pt", - "prompt_attention_mask": f"prompt_attention_mask/{stem}.pt", - "image": f"images/{stem}.png", - "image_latent": f"image_latents/{stem}.pt", - "video": f"videos/{stem}.mp4", - "video_latent": f"video_latents/{stem}.pt", - "metadata": metadata, - } - file.write(json.dumps(data) + "\n") - - print(f"Completed preprocessing. All files saved to `{output_dir.as_posix()}`") - - -if __name__ == "__main__": - main() diff --git a/training/mochi-1/prepare_dataset.sh b/training/mochi-1/prepare_dataset.sh index c2aba5ce..7d6d064f 100644 --- a/training/mochi-1/prepare_dataset.sh +++ b/training/mochi-1/prepare_dataset.sh @@ -1,49 +1,9 @@ #!/bin/bash -MODEL_ID="genmo/mochi-1-preview" +GPU_ID=0 +VIDEO_DIR=/home/sayak/cogvideox-factory/video-dataset-disney-organized +OUTPUT_DIR=videos_prepared -NUM_GPUS=1 +python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=37 --resolution=480x848 --force_upsample -# For more details on the expected data format, please refer to the README. -DATA_ROOT="/home/sayak/cogvideox-factory/video-dataset-disney" # This needs to be the path to the base directory where your videos are located. -CAPTION_COLUMN="prompt.txt" -VIDEO_COLUMN="videos.txt" -OUTPUT_DIR="/home/sayak/cogvideox-factory/video-dataset-disney/mochi-1/preprocessed-dataset" -HEIGHT_BUCKETS="480" -WIDTH_BUCKETS="848" -FRAME_BUCKETS="85" -MAX_NUM_FRAMES="85" -MAX_SEQUENCE_LENGTH=256 -TARGET_FPS=30 -BATCH_SIZE=4 -DTYPE=fp32 - -# To create a folder-style dataset structure without pre-encoding videos and captions -# For Image-to-Video finetuning, make sure to pass `--save_image_latents` -CMD_WITHOUT_PRE_ENCODING="\ - torchrun --nproc_per_node=$NUM_GPUS \ - prepare_dataset.py \ - --model_id $MODEL_ID \ - --data_root $DATA_ROOT \ - --caption_column $CAPTION_COLUMN \ - --video_column $VIDEO_COLUMN \ - --output_dir $OUTPUT_DIR \ - --height_buckets $HEIGHT_BUCKETS \ - --width_buckets $WIDTH_BUCKETS \ - --frame_buckets $FRAME_BUCKETS \ - --max_num_frames $MAX_NUM_FRAMES \ - --max_sequence_length $MAX_SEQUENCE_LENGTH \ - --target_fps $TARGET_FPS \ - --batch_size $BATCH_SIZE \ - --use_slicing \ - --dtype $DTYPE -" - -CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_latents_and_embeddings" - -# Select which you'd like to run -CMD=$CMD_WITH_PRE_ENCODING - -echo "===== Running \`$CMD\` =====" -eval $CMD -echo -ne "===== Finished running script =====\n" +CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=37x480x848 \ No newline at end of file diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index e73dcd95..874af27c 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import gc -import json +import random +from glob import glob import logging import math import os import shutil +import torch.nn.functional as F from datetime import timedelta from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Tuple, List import diffusers import torch @@ -44,11 +45,7 @@ ) from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from diffusers.optimization import get_scheduler -from diffusers.training_utils import ( - cast_training_params, - compute_density_for_timestep_sampling, - compute_loss_weighting_for_sd3, -) +from diffusers.training_utils import cast_training_params from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -56,20 +53,17 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict from torch.utils.data import DataLoader from tqdm.auto import tqdm -from transformers import AutoTokenizer, T5EncoderModel from args import get_args # isort:skip -from dataset_mochi import VideoDatasetWithFlexibleResize # isort:skip +from dataset_simple import LatentEmbedDataset import sys sys.path.append("..") -from dataset import BucketSampler # isort:skip -from text_encoder import compute_prompt_embeddings # isort:skip -from utils import get_gradient_norm, get_optimizer, print_memory, reset_memory # isort:skip +from utils import get_optimizer, print_memory, reset_memory # isort:skip logger = get_logger(__name__) @@ -81,7 +75,7 @@ def save_model_card( base_model: str = None, validation_prompt=None, repo_folder=None, - fps=8, + fps=30, ): widget_dict = [] if videos is not None and len(videos) > 0: @@ -196,29 +190,44 @@ def log_validation( return videos -class CollateFunction: - def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None: - self.weight_dtype = weight_dtype - self.load_tensors = load_tensors - - def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: - prompts = [x["prompt"] for x in data[0]] - prompt_attention_mask = None +# Adapted from the original code: +# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578 +def cast_dit(model, dtype): + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + assert any( + n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"] + ), f"Unexpected linear layer: {name}" + module.to(dtype=dtype) + elif isinstance(module, torch.nn.Conv2d): + module.to(dtype=dtype) + return model - if self.load_tensors: - prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True) - prompt_attention_mask = torch.stack([x["prompt_attention_mask"] for x in data[0]]) - videos = [x["video"] for x in data[0]] - videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True) - - out_dict = { - "videos": videos, - "prompts": prompts, - } - if prompt_attention_mask is not None: - out_dict.update({"prompt_attention_mask": prompt_attention_mask}) - return out_dict +class CollateFunction: + def __init__(self, caption_dropout: float = None) -> None: + self.caption_dropout = caption_dropout + + def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]: + ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0) + z = DiagonalGaussianDistribution(ldists).sample() + assert torch.isfinite(z).all() + + # Sample noise which we will add to the samples. + eps = torch.randn_like(z) + sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32) + + prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0) + prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0) + if self.caption_dropout and random.random() < self.caption_dropout: + prompt_embeds.zero_() + prompt_attention_mask = prompt_attention_mask.long() + prompt_attention_mask.zero_() + prompt_attention_mask = prompt_attention_mask.bool() + + return dict( + z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask + ) def main(args): @@ -281,73 +290,41 @@ def main(args): ).repo_id # Prepare models and scheduler - if not args.load_tensors: - tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=args.revision, - ) - text_encoder = T5EncoderModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - ) - - vae = AutoencoderKLMochi.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="vae", - revision=args.revision, - variant=args.variant, - ) - if args.enable_slicing: - vae.enable_slicing() - if args.enable_tiling: - vae.enable_tiling() - - # keep things in FP32. - text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device, dtype=torch.float32) - vae.requires_grad_(False) - vae.to(accelerator.device, dtype=torch.float32) - - load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 transformer = MochiTransformer3DModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", - torch_dtype=load_dtype, revision=args.revision, variant=args.variant, ) - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - noise_scheduler_copy = copy.deepcopy(scheduler) + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) vae_config = AutoencoderKLMochi.load_config(args.pretrained_model_name_or_path, subfolder="vae") - vae_in_channels = vae_config["latent_channels"] has_latents_mean = "latents_mean" in vae_config and vae_config["latents_mean"] is not None has_latents_std = "latents_std" in vae_config and vae_config["latents_std"] is not None + if has_latents_mean and has_latents_std: + mean = torch.tensor(vae_config["latents_mean"])[:, None, None, None] + std = torch.tensor(vae_config["latents_mean"])[:, None, None, None] - VAE_SCALING_FACTOR = vae_config["scaling_factor"] - - # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 - if accelerator.state.deepspeed_plugin: - # DeepSpeed is handling precision, use what's in the DeepSpeed config - if ( - "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config - and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] - ): - weight_dtype = torch.float16 - if ( - "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config - and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] - ): - weight_dtype = torch.bfloat16 - else: - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 + # if accelerator.state.deepspeed_plugin: + # # DeepSpeed is handling precision, use what's in the DeepSpeed config + # if ( + # "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + # and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + # ): + # weight_dtype = torch.float16 + # if ( + # "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + # and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + # ): + # weight_dtype = torch.bfloat16 + # else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: # due to pytorch#99272, MPS does not yet support bfloat16. @@ -355,11 +332,12 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - # keep the transformer in FP32. transformer.requires_grad_(False) - transformer.to(accelerator.device, torch.float32) + transformer.to(accelerator.device) if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() + if args.cast_dit: + transformer = cast_dit(transformer, weight_dtype) # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( @@ -492,43 +470,19 @@ def load_model_hook(models, input_dir): accelerator.print(f"Using {optimizer.__class__.__name__} optimizer.") # Dataset and DataLoader - if args.load_tensors and args.id_token: - with open(os.path.join(args.data_root, "data.jsonl")) as f: - contents = [json.loads(jline) for jline in f.read().splitlines()] - parsed_id_token = None - for content in contents: - if "id_token" in content: - parsed_id_token = content["id_token"] - if parsed_id_token is not None and parsed_id_token.strip() != args.id_token.strip(): - raise ValueError( - f"Parsed `id_token` from serialized metadata is {parsed_id_token} and provided `id_token` is {args.id_token}. They should match." - ) + train_vids = list(sorted(glob(f"{args.data_root}/*.mp4"))) + train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")] + accelerator.print(f"Found {len(train_vids)} training videos in {args.data_root}") + assert len(train_vids) > 0, f"No training data found in {args.data_root}" - dataset_init_kwargs = { - "data_root": args.data_root, - "video_reshape_mode": args.video_reshape_mode, - "dataset_file": args.dataset_file, - "caption_column": args.caption_column, - "video_column": args.video_column, - "max_num_frames": args.max_num_frames, - "id_token": args.id_token, - "height_buckets": args.height_buckets, - "width_buckets": args.width_buckets, - "frame_buckets": args.frame_buckets, - "load_tensors": args.load_tensors, - "random_flip": args.random_flip, - } - train_dataset = VideoDatasetWithFlexibleResize(**dataset_init_kwargs) - # keeping things in FP32 for now. - collate_fn = CollateFunction(weight_dtype=torch.float32, load_tensors=args.load_tensors) + collate_fn = CollateFunction(caption_dropout=args.caption_dropout) + train_dataset = LatentEmbedDataset(train_vids, repeat=1) train_dataloader = DataLoader( train_dataset, - batch_size=1, - sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True), collate_fn=collate_fn, + batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, pin_memory=args.pin_memory, - prefetch_factor=4, ) # Scheduler and math around the number of training steps. @@ -636,27 +590,6 @@ def load_model_hook(models, input_dir): disable=not accelerator.is_local_main_process, ) - # For DeepSpeed training - model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config - - if args.load_tensors: - gc.collect() - torch.cuda.empty_cache() - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - if "invert_sigmas" in noise_scheduler_copy.config and noise_scheduler_copy.config.invert_sigmas: - # https://github.com/huggingface/diffusers/blob/99c0483b67427de467f11aa35d54678fd36a7ea2/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L209 - sigma = 1.0 - sigma - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - for epoch in range(first_epoch, args.num_train_epochs): transformer.train() @@ -664,108 +597,45 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): models_to_accumulate = [transformer] with accelerator.accumulate(models_to_accumulate): - videos = batch["videos"].to(accelerator.device, non_blocking=True) - prompts = batch["prompts"] - if args.load_tensors: - prompt_attention_mask = batch["prompt_attention_mask"] - - # Encode videos - if not args.load_tensors: - videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] - latent_dist = vae.encode(videos.to(vae.dtype)).latent_dist - else: - latent_dist = DiagonalGaussianDistribution(videos) + z = batch["z"] + # revisit + # if has_latents_mean and has_latents_std: + # z = (z - mean.to(z)) / std.to(z) - videos = latent_dist.sample() - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(vae_config["latents_mean"]).view(1, vae_in_channels, 1, 1, 1).to(videos.device, videos.dtype) - ) - latents_std = ( - torch.tensor(vae_config["latents_std"]).view(1, vae_in_channels, 1, 1, 1).to(videos.device, videos.dtype) - ) - videos = (videos - latents_mean) * VAE_SCALING_FACTOR / latents_std - else: - videos = videos * VAE_SCALING_FACTOR - - # keep in FP32 for now. - videos = videos.to(memory_format=torch.contiguous_format, dtype=torch.float32) - model_input = videos - - # Encode prompts - if not args.load_tensors: - prompt_embeds, prompt_attention_mask = compute_prompt_embeddings( - tokenizer, - text_encoder, - prompts, - model_config.max_text_seq_length, - accelerator.device, - weight_dtype=weight_dtype, - requires_grad=False, - ) - else: - prompt_embeds = prompts.to(weight_dtype) - prompt_attention_mask = prompt_attention_mask.to(accelerator.device) - - # Sample noise that will be added to the latents - noise = torch.randn_like(model_input) - batch_size, num_channels, num_frames, height, width = model_input.shape - - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=batch_size, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - # indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - # timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) - # revisit. - timesteps = (u * noise_scheduler_copy.config.num_train_timesteps) + eps = batch["eps"] + sigma = batch["sigma"] + prompt_embeds = batch["prompt_embeds"] + prompt_attention_mask = batch["prompt_attention_mask"] + sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1] # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 - sigmas = get_sigmas( - timesteps=noise_scheduler_copy.timesteps[timesteps.long()].to(device=model_input.device), - n_dim=model_input.ndim, - dtype=model_input.dtype - ) - noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # do we need to revisit this? - noisy_model_input = noisy_model_input.to(weight_dtype) + z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps + ut = z - eps # Predict the noise residual - actual_num_train_timesteps = float(noise_scheduler_copy.config.num_train_timesteps) - timesteps = (1 - (timesteps / actual_num_train_timesteps)) * actual_num_train_timesteps # revisit - timesteps = timesteps.to(device=model_input.device) - model_pred = transformer( - hidden_states=noisy_model_input, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - timestep=timesteps, - return_dict=False, - )[0] - - # these weighting schemes use a uniform timestep sampling - # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) - - # flow matching loss - # target = noise - model_input - target = model_input - noise # as discussed with Ajay - - loss = torch.mean( - (weighting * (model_pred.float() - target.float()) ** 2).reshape(batch_size, -1), - dim=1, - ) - loss = loss.mean() + # (1 - sigma) because of + # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656 + # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation. + timesteps = (1 - sigma) * scheduler.config.num_train_timesteps + with torch.autocast(accelerator.device.type, weight_dtype): + model_pred = transformer( + hidden_states=z_sigma, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timesteps, + return_dict=False, + )[0] + assert model_pred.shape == z.shape + loss = F.mse_loss(model_pred.float(), ut.float()) accelerator.backward(loss) - if accelerator.sync_gradients: - gradient_norm_before_clip = get_gradient_norm(transformer_lora_parameters) - accelerator.clip_grad_norm_(transformer_lora_parameters, args.max_grad_norm) - gradient_norm_after_clip = get_gradient_norm(transformer_lora_parameters) + # if accelerator.sync_gradients: + # no grad norm for now, following the original code + # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L380 + # gradient_norm_before_clip = get_gradient_norm(transformer_lora_parameters) + # accelerator.clip_grad_norm_(transformer_lora_parameters, args.max_grad_norm) + # gradient_norm_after_clip = get_gradient_norm(transformer_lora_parameters) if accelerator.state.deepspeed_plugin is None: optimizer.step() @@ -807,14 +677,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate logs = {"loss": loss.detach().item(), "lr": last_lr} - # gradnorm + deepspeed: https://github.com/microsoft/DeepSpeed/issues/4555 - if accelerator.distributed_type != DistributedType.DEEPSPEED: - logs.update( - { - "gradient_norm_before_clip": gradient_norm_before_clip, - "gradient_norm_after_clip": gradient_norm_after_clip, - } - ) + # # gradnorm + deepspeed: https://github.com/microsoft/DeepSpeed/issues/4555 + # if accelerator.distributed_type != DistributedType.DEEPSPEED: + # logs.update( + # { + # "gradient_norm_before_clip": gradient_norm_before_clip, + # "gradient_norm_after_clip": gradient_norm_after_clip, + # } + # ) progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) @@ -822,13 +692,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): break if global_step >= args.max_train_steps: - break + break if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: accelerator.print("===== Memory before validation =====") print_memory(accelerator.device) + transformer.eval() pipe = MochiPipeline.from_pretrained( args.pretrained_model_name_or_path, transformer=unwrap_model(transformer), @@ -848,12 +719,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for validation_prompt in validation_prompts: pipeline_args = { "prompt": validation_prompt, - "guidance_scale": 4.5, + "guidance_scale": 6.0, + "num_inference_steps": 64, "height": args.height, "width": args.width, "max_sequence_length": 256, } - log_validation( pipe=pipe, args=args, @@ -866,13 +737,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): print_memory(accelerator.device) reset_memory(accelerator.device) - del pipe.text_encoder del pipe.vae del pipe gc.collect() torch.cuda.empty_cache() + transformer.train() + accelerator.wait_for_everyone() if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: @@ -884,10 +756,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # Cleanup trained models to save memory - if args.load_tensors: - del transformer - else: - del transformer, text_encoder, vae + del transformer gc.collect() torch.cuda.empty_cache() @@ -922,9 +791,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for validation_prompt in validation_prompts: pipeline_args = { "prompt": validation_prompt, - "guidance_scale": 4.5, + "guidance_scale": 6.0, + "num_inference_steps": 64, "height": args.height, "width": args.width, + "max_sequence_length": 256, } video = log_validation( diff --git a/training/mochi-1/train.sh b/training/mochi-1/train.sh index 69c636b5..d15ff695 100644 --- a/training/mochi-1/train.sh +++ b/training/mochi-1/train.sh @@ -1,50 +1,38 @@ +#!/bin/bash export NCCL_P2P_DISABLE=1 export TORCH_NCCL_ENABLE_MONITORING=0 GPU_IDS="2" -DATA_ROOT="/home/sayak/cogvideox-factory/video-dataset-disney/mochi-1/preprocessed-dataset" - -CAPTION_COLUMN="prompts.txt" -VIDEO_COLUMN="videos.txt" +DATA_ROOT="/home/sayak/cogvideox-factory/training/mochi-1/videos_prepared" +MODEL="genmo/mochi-1-preview" +OUTPUT_PATH=/raid/.cache/huggingface/sayak/mochi-lora/ cmd="accelerate launch --config_file deepspeed.yaml --gpu_ids $GPU_IDS text_to_video_lora.py \ - --pretrained_model_name_or_path genmo/mochi-1-preview \ + --pretrained_model_name_or_path $MODEL \ --data_root $DATA_ROOT \ - --caption_column $CAPTION_COLUMN \ - --video_column $VIDEO_COLUMN \ - --id_token BW_STYLE \ - --height_buckets 480 \ - --width_buckets 848 \ - --frame_buckets 85 \ - --load_tensors \ --seed 42 \ - --rank 64 \ - --lora_alpha 64 \ - --mixed_precision bf16 \ - --output_dir /raid/.cache/huggingface/sayak/mochi-lora/ \ - --max_num_frames 85 \ + --mixed_precision "bf16" \ + --output_dir $OUTPUT_PATH \ --train_batch_size 1 \ --dataloader_num_workers 4 \ - --max_train_steps 10 \ - --checkpointing_steps 50 \ + --pin_memory \ + --caption_dropout 0.1 \ + --max_train_steps 2000 \ + --checkpointing_steps 200 \ + --checkpoints_total_limit 1 \ --gradient_accumulation_steps 4 \ --gradient_checkpointing \ - --learning_rate 1e-5 \ - --lr_scheduler constant \ - --lr_warmup_steps 0 \ - --lr_num_cycles 1 \ --enable_slicing \ --enable_tiling \ + --enable_model_cpu_offload \ --optimizer adamw --use_8bit \ - --beta1 0.9 \ - --beta2 0.95 \ - --beta3 0.99 \ - --weight_decay 0.001 \ - --max_grad_norm 1.0 \ + --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 1 \ --allow_tf32 \ --report_to wandb \ - --push_to_hub \ --nccl_timeout 1800" echo "Running command: $cmd" diff --git a/training/mochi-1/trim_and_crop_videos.py b/training/mochi-1/trim_and_crop_videos.py new file mode 100644 index 00000000..0c6f411d --- /dev/null +++ b/training/mochi-1/trim_and_crop_videos.py @@ -0,0 +1,126 @@ +""" +Adapted from: +https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py +""" + +from pathlib import Path +import shutil + +import click +from moviepy.editor import VideoFileClip +from tqdm import tqdm + + +@click.command() +@click.argument("folder", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_folder", type=click.Path(dir_okay=True)) +@click.option("--num_frames", "-f", type=float, default=30, help="Number of frames") +@click.option("--resolution", "-r", type=str, default="480x848", help="Video resolution") +@click.option("--force_upsample", is_flag=True, help="Force upsample.") +def truncate_videos(folder, output_folder, num_frames, resolution, force_upsample): + """Truncate all MP4 and MOV files in FOLDER to specified number of frames and resolution""" + input_path = Path(folder) + output_path = Path(output_folder) + output_path.mkdir(parents=True, exist_ok=True) + + # Parse target resolution + target_height, target_width = map(int, resolution.split("x")) + + # Calculate duration + duration = (num_frames / 30) + 0.09 + + # Find all MP4 and MOV files + video_files = ( + list(input_path.rglob("*.mp4")) + + list(input_path.rglob("*.MOV")) + + list(input_path.rglob("*.mov")) + + list(input_path.rglob("*.MP4")) + ) + + for file_path in tqdm(video_files): + try: + relative_path = file_path.relative_to(input_path) + output_file = output_path / relative_path.with_suffix(".mp4") + output_file.parent.mkdir(parents=True, exist_ok=True) + + click.echo(f"Processing: {file_path}") + video = VideoFileClip(str(file_path)) + + # Skip if video is too short + if video.duration < duration: + click.echo(f"Skipping {file_path} as it is too short") + continue + + # Skip if target resolution is larger than input + if target_width > video.w or target_height > video.h: + if force_upsample: + click.echo( + f"{file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}. So, upsampling the video." + ) + video = video.resize(width=target_width, height=target_height) + else: + click.echo( + f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}" + ) + continue + + # First truncate duration + truncated = video.subclip(0, duration) + + # Calculate crop dimensions to maintain aspect ratio + target_ratio = target_width / target_height + current_ratio = truncated.w / truncated.h + + if current_ratio > target_ratio: + # Video is wider than target ratio - crop width + new_width = int(truncated.h * target_ratio) + x1 = (truncated.w - new_width) // 2 + final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height)) + else: + # Video is taller than target ratio - crop height + new_height = int(truncated.w / target_ratio) + y1 = (truncated.h - new_height) // 2 + final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height)) + + # Set output parameters for consistent MP4 encoding + output_params = { + "codec": "libx264", + "audio": False, # Disable audio + "preset": "medium", # Balance between speed and quality + "bitrate": "5000k", # Adjust as needed + } + + # Set FPS to 30 + final = final.set_fps(30) + + # Check for a corresponding .txt file + txt_file_path = file_path.with_suffix(".txt") + if txt_file_path.exists(): + output_txt_file = output_path / relative_path.with_suffix(".txt") + output_txt_file.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(txt_file_path, output_txt_file) + click.echo(f"Copied {txt_file_path} to {output_txt_file}") + else: + # Print warning in bold yellow with a warning emoji + click.echo( + f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m" + ) + output_txt_file = output_path / relative_path.with_suffix(".txt") + output_txt_file.parent.mkdir(parents=True, exist_ok=True) + output_txt_file.touch() + + # Write the output file + final.write_videofile(str(output_file), **output_params) + + # Clean up + video.close() + truncated.close() + final.close() + + except Exception as e: + click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True) + raise + + +if __name__ == "__main__": + truncate_videos() From ced8558eb06449dbe4dd8a819ed378803b01ddc1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 29 Nov 2024 10:48:54 +0530 Subject: [PATCH 21/30] updates --- training/mochi-1/README.md | 96 ++++ training/mochi-1/args.py | 180 +----- training/mochi-1/deepspeed.yaml | 23 - training/mochi-1/prepare_dataset.sh | 6 +- training/mochi-1/requirements.txt | 7 + training/mochi-1/text_to_video_lora.py | 730 ++++++++----------------- training/mochi-1/train.sh | 19 +- 7 files changed, 355 insertions(+), 706 deletions(-) create mode 100644 training/mochi-1/README.md delete mode 100644 training/mochi-1/deepspeed.yaml create mode 100644 training/mochi-1/requirements.txt diff --git a/training/mochi-1/README.md b/training/mochi-1/README.md new file mode 100644 index 00000000..5e278ca7 --- /dev/null +++ b/training/mochi-1/README.md @@ -0,0 +1,96 @@ +# Simple Mochi-1 finetuner + +Now you can make Mochi-1 your own with `diffusers`, too 🤗 🧨 + +We provide a minimal and faithful reimplementation of the [Mochi-1 original fine-tuner](https://github.com/genmoai/mochi/tree/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner). As usual, we leverage `peft` for things LoRA in our implementation. + +## Getting started + +Install the dependencies: `pip install -r requirements.txt`. Also make sure your `diffusers` installation is from the current `main`. + +Download a demo dataset: + +```bash +huggingface-cli download \ + --repo-type dataset sayakpaul/video-dataset-disney-organized \ + --local-dir video-dataset-disney-organized +``` + +The dataset follows the directory structure expected by the subsequent scripts. In particular, it follows what's prescribed [here](https://github.com/genmoai/mochi/tree/main/demos/fine_tuner#1-collect-your-videos-and-captions): + +```bash +video_1.mp4 +video_1.txt -- One-paragraph description of video_1 +video_2.mp4 +video_2.txt -- One-paragraph description of video_2 +... +``` + +Then run (be sure to check the paths accordingly): + +```bash +bash prepare_dataset.sh +``` + +We can adjust `num_frames` and `resolution`. By default, in `prepare_dataset.sh`, we use `--force_upsample`. This means if the original video resolution is smaller than the requested resolution, we will upsample the video. + +> [!IMPORTANT] +> It's important to have a resolution of at least 480x848 to satisy Mochi-1's requirements. + +Now, we're ready to fine-tune. To launch, run: + +```bash +bash train.sh +``` + +You can disable intermediate validation by: + +```diff +- --validation_prompt "..." \ +- --validation_prompt_separator ::: \ +- --num_validation_videos 1 \ +- --validation_epochs 1 \ +``` + +We haven't rigorously tested but without validation enabled, this script should run under 40GBs of GPU VRAM. + +To use the LoRA checkpoint: + +```py +from diffusers import MochiPipeline +from diffusers.utils import export_to_video +import torch + +pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview") +pipe.load_lora_weights("path-to-lora") +pipe.enable_model_cpu_offload() + +pipeline_args = { + "prompt": "A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions", + "guidance_scale": 6.0, + "num_inference_steps": 64, + "height": 480, + "width": 848, + "max_sequence_length": 256, + "output_type": "np", +} + +with torch.autocast("cuda", torch.bfloat16) + video = pipe(**pipeline_args).frames[0] +export_to_video(video) +``` + +## Known limitations + +(Contributions are welcome 🤗) + +Our script currently doesn't leverage `accelerate` and some of its consequences are detailed below: + +* No support for distributed training. +* No intermediate checkpoint saving and loading support. +* `train_batch_size > 1` are supported but can potentially lead to OOMs because we currently don't have gradient accumulation support. + +**Misc**: + +* We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033). +* `embed.py` script is non-batched. \ No newline at end of file diff --git a/training/mochi-1/args.py b/training/mochi-1/args.py index 46a8a178..4f1cc7e8 100644 --- a/training/mochi-1/args.py +++ b/training/mochi-1/args.py @@ -1,3 +1,9 @@ +""" +Default values taken from +https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml +when applicable. +""" + import argparse @@ -33,6 +39,11 @@ def _get_model_args(parser: argparse.ArgumentParser) -> None: action="store_true", help="If we should cast DiT params to a lower precision.", ) + parser.add_argument( + "--compile_dit", + action="store_true", + help="If we should cast DiT params to a lower precision.", + ) def _get_dataset_args(parser: argparse.ArgumentParser) -> None: @@ -93,6 +104,18 @@ def _get_validation_args(parser: argparse.ArgumentParser) -> None: default=50, help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.", ) + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) parser.add_argument( "--enable_model_cpu_offload", action="store_true", @@ -133,17 +156,6 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: default=["to_k", "to_q", "to_v", "to_out.0"], help="Target modules to train LoRA for.", ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.and an Nvidia Ampere GPU. " - "Default to the value of accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this " - "argument to override the accelerate config." - ), - ) parser.add_argument( "--output_dir", type=str, @@ -163,37 +175,6 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: default=None, help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) parser.add_argument( "--gradient_checkpointing", action="store_true", @@ -210,45 +191,12 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None: action="store_true", help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="cosine", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) parser.add_argument( "--lr_warmup_steps", type=int, default=200, help="Number of steps for the warmup in the lr scheduler.", ) - parser.add_argument( - "--lr_num_cycles", - type=int, - default=1, - help="Number of hard resets of the lr in cosine_with_restarts scheduler.", - ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) - parser.add_argument( - "--enable_slicing", - action="store_true", - default=False, - help="Whether or not to use VAE slicing for saving memory.", - ) - parser.add_argument( - "--enable_tiling", - action="store_true", - default=False, - help="Whether or not to use VAE tiling for saving memory.", - ) def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: @@ -256,78 +204,15 @@ def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: "--optimizer", type=lambda s: s.lower(), default="adam", - choices=["adam", "adamw", "prodigy", "came"], + choices=["adam", "adamw"], help=("The optimizer type to use."), ) - parser.add_argument( - "--use_8bit", - action="store_true", - help="Whether or not to use 8-bit optimizers from `bitsandbytes` or `bitsandbytes`.", - ) - parser.add_argument( - "--use_4bit", - action="store_true", - help="Whether or not to use 4-bit optimizers from `torchao`.", - ) - parser.add_argument( - "--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers." - ) - parser.add_argument( - "--beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam and Prodigy optimizers.", - ) - parser.add_argument( - "--beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam and Prodigy optimizers.", - ) - parser.add_argument( - "--beta3", - type=float, - default=None, - help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", - ) - parser.add_argument( - "--prodigy_decouple", - action="store_true", - help="Use AdamW style decoupled weight decay.", - ) parser.add_argument( "--weight_decay", type=float, default=0.01, help="Weight decay to use for optimizer.", ) - parser.add_argument( - "--epsilon", - type=float, - default=1e-8, - help="Epsilon value for the Adam optimizer and Prodigy optimizers.", - ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument( - "--prodigy_use_bias_correction", - action="store_true", - help="Turn on Adam's bias correction.", - ) - parser.add_argument( - "--prodigy_safeguard_warmup", - action="store_true", - help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", - ) - parser.add_argument( - "--use_cpu_offload_optimizer", - action="store_true", - help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.", - ) - parser.add_argument( - "--offload_gradients", - action="store_true", - help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.", - ) def _get_configuration_args(parser: argparse.ArgumentParser) -> None: @@ -349,12 +234,6 @@ def _get_configuration_args(parser: argparse.ArgumentParser) -> None: default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help="Directory where logs are stored.", - ) parser.add_argument( "--allow_tf32", action="store_true", @@ -363,20 +242,11 @@ def _get_configuration_args(parser: argparse.ArgumentParser) -> None: " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) - parser.add_argument( - "--nccl_timeout", - type=int, - default=600, - help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.", - ) parser.add_argument( "--report_to", type=str, default=None, - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), + help="If logging to wandb." ) diff --git a/training/mochi-1/deepspeed.yaml b/training/mochi-1/deepspeed.yaml deleted file mode 100644 index efbbf6fa..00000000 --- a/training/mochi-1/deepspeed.yaml +++ /dev/null @@ -1,23 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -deepspeed_config: - gradient_accumulation_steps: 1 - gradient_clipping: 1.0 - offload_optimizer_device: cpu - offload_param_device: cpu - zero3_init_flag: false - zero_stage: 2 -distributed_type: DEEPSPEED -downcast_bf16: 'no' -enable_cpu_affinity: false -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: 1 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/training/mochi-1/prepare_dataset.sh b/training/mochi-1/prepare_dataset.sh index 7d6d064f..03a96559 100644 --- a/training/mochi-1/prepare_dataset.sh +++ b/training/mochi-1/prepare_dataset.sh @@ -1,9 +1,11 @@ #!/bin/bash GPU_ID=0 -VIDEO_DIR=/home/sayak/cogvideox-factory/video-dataset-disney-organized +VIDEO_DIR=video-dataset-disney-organized OUTPUT_DIR=videos_prepared +NUM_FRAMES=37 +RESOLUTION=480x848 -python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=37 --resolution=480x848 --force_upsample +python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=37x480x848 \ No newline at end of file diff --git a/training/mochi-1/requirements.txt b/training/mochi-1/requirements.txt new file mode 100644 index 00000000..8fb970ab --- /dev/null +++ b/training/mochi-1/requirements.txt @@ -0,0 +1,7 @@ +peft +transformers +wandb +torch +torchvision +moviepy +click \ No newline at end of file diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index 874af27c..1a7122e5 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -16,41 +16,22 @@ import gc import random from glob import glob -import logging import math import os -import shutil import torch.nn.functional as F -from datetime import timedelta +import numpy as np from pathlib import Path from typing import Any, Dict, Tuple, List -import diffusers import torch -import transformers import wandb -from accelerate import Accelerator, DistributedType -from accelerate.logging import get_logger -from accelerate.utils import ( - DistributedDataParallelKwargs, - InitProcessGroupKwargs, - ProjectConfiguration, - set_seed, -) -from diffusers import ( - AutoencoderKLMochi, - FlowMatchEulerDiscreteScheduler, - MochiPipeline, - MochiTransformer3DModel, -) +from diffusers import FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from diffusers.optimization import get_scheduler from diffusers.training_utils import cast_training_params -from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video +from diffusers.utils import export_to_video from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card -from diffusers.utils.torch_utils import is_compiled_module from huggingface_hub import create_repo, upload_folder -from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from peft import LoraConfig, get_peft_model_state_dict from torch.utils.data import DataLoader from tqdm.auto import tqdm @@ -63,10 +44,23 @@ sys.path.append("..") -from utils import get_optimizer, print_memory, reset_memory # isort:skip +from utils import print_memory, reset_memory # isort:skip -logger = get_logger(__name__) +# Taken from +# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139 +def get_cosine_annealing_lr_scheduler( + optimizer: torch.optim.Optimizer, + warmup_steps: int, + total_steps: int, +): + def lr_lambda(step): + if step < warmup_steps: + return float(step) / float(max(1, warmup_steps)) + else: + return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps))) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) def save_model_card( @@ -84,7 +78,7 @@ def save_model_card( widget_dict.append( { "text": validation_prompt if validation_prompt else " ", - "output": {"url": f"video_{i}.mp4"}, + "output": {"url": f"final_video_{i}.mp4"}, } ) @@ -138,54 +132,53 @@ def save_model_card( def log_validation( - accelerator: Accelerator, pipe: MochiPipeline, args: Dict[str, Any], pipeline_args: Dict[str, Any], epoch, + wandb_run: str = None, is_final_validation: bool = False, ): - logger.info( + print( f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." ) + phase_name = "test" if is_final_validation else "validation" if not args.enable_model_cpu_offload: - pipe = pipe.to(accelerator.device) + pipe = pipe.to("cuda") # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + generator = torch.manual_seed(args.seed) if args.seed else None videos = [] - with torch.autocast(accelerator.device.type, torch.bfloat16, cache_enabled=False): + with torch.autocast("cuda", torch.bfloat16, cache_enabled=False): for _ in range(args.num_validation_videos): video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] videos.append(video) - for tracker in accelerator.trackers: - phase_name = "test" if is_final_validation else "validation" - if tracker.name == "wandb": - video_filenames = [] - for i, video in enumerate(videos): - prompt = ( - pipeline_args["prompt"][:25] - .replace(" ", "_") - .replace(" ", "_") - .replace("'", "_") - .replace('"', "_") - .replace("/", "_") - ) - filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") - export_to_video(video, filename, fps=30) - video_filenames.append(filename) - - tracker.log( - { - phase_name: [ - wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30) - for i, filename in enumerate(video_filenames) - ] - } - ) + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=30) + video_filenames.append(filename) + + if wandb_run: + wandb.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30) + for i, filename in enumerate(video_filenames) + ] + } + ) return videos @@ -231,63 +224,24 @@ def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch. def main(args): + if not torch.cuda.is_available(): + raise ValueError("Not supported without CUDA.") + if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." " Please use `huggingface-cli login` to authenticate with the Hub." ) - if torch.backends.mps.is_available() and args.mixed_precision == "bf16": - # due to pytorch#99272, MPS does not yet support bfloat16. - raise ValueError( - "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." - ) - - logging_dir = Path(args.output_dir, args.logging_dir) - - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - project_config=accelerator_project_config, - kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], - ) - - # Disable AMP for MPS. - if torch.backends.mps.is_available(): - accelerator.native_amp = False - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - # Handle the repository creation - if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) - if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, - exist_ok=True, - ).repo_id + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id # Prepare models and scheduler transformer = MochiTransformer3DModel.from_pretrained( @@ -300,44 +254,14 @@ def main(args): args.pretrained_model_name_or_path, subfolder="scheduler" ) - vae_config = AutoencoderKLMochi.load_config(args.pretrained_model_name_or_path, subfolder="vae") - has_latents_mean = "latents_mean" in vae_config and vae_config["latents_mean"] is not None - has_latents_std = "latents_std" in vae_config and vae_config["latents_std"] is not None - if has_latents_mean and has_latents_std: - mean = torch.tensor(vae_config["latents_mean"])[:, None, None, None] - std = torch.tensor(vae_config["latents_mean"])[:, None, None, None] - - weight_dtype = torch.float32 - # if accelerator.state.deepspeed_plugin: - # # DeepSpeed is handling precision, use what's in the DeepSpeed config - # if ( - # "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config - # and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] - # ): - # weight_dtype = torch.float16 - # if ( - # "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config - # and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] - # ): - # weight_dtype = torch.bfloat16 - # else: - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: - # due to pytorch#99272, MPS does not yet support bfloat16. - raise ValueError( - "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." - ) - transformer.requires_grad_(False) - transformer.to(accelerator.device) + transformer.to("cuda") if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() if args.cast_dit: - transformer = cast_dit(transformer, weight_dtype) + transformer = cast_dit(transformer, torch.bfloat16) + if args.compile_dit: + transformer.compile() # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( @@ -348,131 +272,25 @@ def main(args): ) transformer.add_adapter(transformer_lora_config) - def unwrap_model(model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if accelerator.is_main_process: - transformer_lora_layers_to_save = None - - for model in models: - if isinstance(unwrap_model(model), type(unwrap_model(transformer))): - model = unwrap_model(model) - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - else: - raise ValueError(f"unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again - if weights: - weights.pop() - - MochiPipeline.save_lora_weights( - output_dir, - transformer_lora_layers=transformer_lora_layers_to_save, - ) - - def load_model_hook(models, input_dir): - transformer_ = None - - # This is a bit of a hack but I don't know any other solution. - if not accelerator.distributed_type == DistributedType.DEEPSPEED: - while len(models) > 0: - model = models.pop() - - if isinstance(unwrap_model(model), type(unwrap_model(transformer))): - transformer_ = unwrap_model(model) - else: - raise ValueError(f"Unexpected save model: {unwrap_model(model).__class__}") - else: - transformer_ = MochiTransformer3DModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="transformer" - ) - transformer_.add_adapter(transformer_lora_config) - - lora_state_dict = MochiPipeline.lora_state_dict(input_dir) - - transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") - } - transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - # Make sure the trainable params are in float32. This is again needed since the base models - # are in `weight_dtype`. More details: - # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params([transformer_]) - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32 and torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) + args.learning_rate = args.learning_rate * args.train_batch_size # only upcast trainable parameters (LoRA) into fp32 cast_training_params([transformer], dtype=torch.float32) + # Prepare optimizer transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) - - # Optimization parameters - transformer_parameters_with_lr = { - "params": transformer_lora_parameters, - "lr": args.learning_rate, - } - params_to_optimize = [transformer_parameters_with_lr] - num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) - - use_deepspeed_optimizer = ( - accelerator.state.deepspeed_plugin is not None - and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config - ) - use_deepspeed_scheduler = ( - accelerator.state.deepspeed_plugin is not None - and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config - ) - - optimizer = get_optimizer( - params_to_optimize=params_to_optimize, - optimizer_name=args.optimizer, - learning_rate=args.learning_rate, - beta1=args.beta1, - beta2=args.beta2, - beta3=args.beta3, - epsilon=args.epsilon, - weight_decay=args.weight_decay, - prodigy_decouple=args.prodigy_decouple, - prodigy_use_bias_correction=args.prodigy_use_bias_correction, - prodigy_safeguard_warmup=args.prodigy_safeguard_warmup, - use_8bit=args.use_8bit, - use_4bit=args.use_4bit, - use_torchao=args.use_torchao, - use_deepspeed=use_deepspeed_optimizer, - use_cpu_offload_optimizer=args.use_cpu_offload_optimizer, - offload_gradients=args.offload_gradients, - ) - accelerator.print(f"Using {optimizer.__class__.__name__} optimizer.") + num_trainable_parameters = sum(param.numel() for param in transformer_lora_parameters) + optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay) # Dataset and DataLoader train_vids = list(sorted(glob(f"{args.data_root}/*.mp4"))) train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")] - accelerator.print(f"Found {len(train_vids)} training videos in {args.data_root}") + print(f"Found {len(train_vids)} training videos in {args.data_root}") assert len(train_vids) > 0, f"No training data found in {args.data_root}" collate_fn = CollateFunction(caption_dropout=args.caption_dropout) @@ -485,46 +303,19 @@ def load_model_hook(models, input_dir): pin_memory=args.pin_memory, ) - # Scheduler and math around the number of training steps. + # LR scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = len(train_dataloader) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True - if args.use_cpu_offload_optimizer: - lr_scheduler = None - accelerator.print( - "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If " - "you are training with those settings, they will be ignored." - ) - else: - if use_deepspeed_scheduler: - from accelerate.utils import DummyScheduler - - lr_scheduler = DummyScheduler( - name=args.lr_scheduler, - optimizer=optimizer, - total_num_steps=args.max_train_steps * accelerator.num_processes, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - ) - else: - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, - num_cycles=args.lr_num_cycles, - power=args.lr_power, - ) - - # Prepare everything with our `accelerator`. - transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - transformer, optimizer, train_dataloader, lr_scheduler + lr_scheduler = get_cosine_annealing_lr_scheduler( + optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = len(train_dataloader) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -532,80 +323,43 @@ def load_model_hook(models, input_dir): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. - if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + wandb_run = None + if args.report_to == "wandb": tracker_name = args.tracker_name or "mochi-1-lora" - accelerator.init_trackers(tracker_name, config=vars(args)) + wandb_run = wandb.init(project=tracker_name, config=vars(args)) - accelerator.print("===== Memory before training =====") - reset_memory(accelerator.device) - print_memory(accelerator.device) + print("===== Memory before training =====") + reset_memory("cuda") + print_memory("cuda") # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - accelerator.print("***** Running training *****") - accelerator.print(f" Num trainable parameters = {num_trainable_parameters}") - accelerator.print(f" Num examples = {len(train_dataset)}") - accelerator.print(f" Num batches each epoch = {len(train_dataloader)}") - accelerator.print(f" Num epochs = {args.num_train_epochs}") - accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}") - accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") - accelerator.print(f" Total optimization steps = {args.max_train_steps}") + total_batch_size = args.train_batch_size + print("***** Running training *****") + print(f" Num trainable parameters = {num_trainable_parameters}") + print(f" Num examples = {len(train_dataset)}") + print(f" Num batches each epoch = {len(train_dataloader)}") + print(f" Num epochs = {args.num_train_epochs}") + print(f" Instantaneous batch size per device = {args.train_batch_size}") + print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if not args.resume_from_checkpoint: - initial_global_step = 0 - else: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - initial_global_step = 0 - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - initial_global_step = global_step - first_epoch = global_step // num_update_steps_per_epoch - progress_bar = tqdm( range(0, args.max_train_steps), - initial=initial_global_step, + initial=global_step, desc="Steps", - # Only show the progress bar once on each machine. - disable=not accelerator.is_local_main_process, ) - for epoch in range(first_epoch, args.num_train_epochs): transformer.train() for step, batch in enumerate(train_dataloader): - models_to_accumulate = [transformer] - - with accelerator.accumulate(models_to_accumulate): - z = batch["z"] - # revisit - # if has_latents_mean and has_latents_std: - # z = (z - mean.to(z)) / std.to(z) - - eps = batch["eps"] - sigma = batch["sigma"] - prompt_embeds = batch["prompt_embeds"] - prompt_attention_mask = batch["prompt_attention_mask"] + with torch.no_grad(): + z = batch["z"].to("cuda") + eps = batch["eps"].to("cuda") + sigma = batch["sigma"].to("cuda") + prompt_embeds = batch["prompt_embeds"].to("cuda") + prompt_attention_mask = batch["prompt_attention_mask"].to("cuda") sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1] # Add noise according to flow matching. @@ -613,80 +367,35 @@ def load_model_hook(models, input_dir): z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps ut = z - eps - # Predict the noise residual # (1 - sigma) because of # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656 # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation. timesteps = (1 - sigma) * scheduler.config.num_train_timesteps - with torch.autocast(accelerator.device.type, weight_dtype): - model_pred = transformer( - hidden_states=z_sigma, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - timestep=timesteps, - return_dict=False, - )[0] - assert model_pred.shape == z.shape - loss = F.mse_loss(model_pred.float(), ut.float()) - accelerator.backward(loss) - - # if accelerator.sync_gradients: - # no grad norm for now, following the original code - # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L380 - # gradient_norm_before_clip = get_gradient_norm(transformer_lora_parameters) - # accelerator.clip_grad_norm_(transformer_lora_parameters, args.max_grad_norm) - # gradient_norm_after_clip = get_gradient_norm(transformer_lora_parameters) - - if accelerator.state.deepspeed_plugin is None: - optimizer.step() - optimizer.zero_grad() - - if not args.use_cpu_offload_optimizer: - lr_scheduler.step() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") + + with torch.autocast("cuda", torch.bfloat16): + model_pred = transformer( + hidden_states=z_sigma, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timesteps, + return_dict=False, + )[0] + assert model_pred.shape == z.shape + loss = F.mse_loss(model_pred.float(), ut.float()) + loss.backward() + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + progress_bar.update(1) + global_step += 1 last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate logs = {"loss": loss.detach().item(), "lr": last_lr} - # # gradnorm + deepspeed: https://github.com/microsoft/DeepSpeed/issues/4555 - # if accelerator.distributed_type != DistributedType.DEEPSPEED: - # logs.update( - # { - # "gradient_norm_before_clip": gradient_norm_before_clip, - # "gradient_norm_after_clip": gradient_norm_after_clip, - # } - # ) progress_bar.set_postfix(**logs) - accelerator.log(logs, step=global_step) + if wandb_run: + wandb_run.log(logs, step=global_step) if global_step >= args.max_train_steps: break @@ -694,82 +403,15 @@ def load_model_hook(models, input_dir): if global_step >= args.max_train_steps: break - if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: - if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: - accelerator.print("===== Memory before validation =====") - print_memory(accelerator.device) - - transformer.eval() - pipe = MochiPipeline.from_pretrained( - args.pretrained_model_name_or_path, - transformer=unwrap_model(transformer), - scheduler=scheduler, - revision=args.revision, - variant=args.variant, - ) - - if args.enable_slicing: - pipe.vae.enable_slicing() - if args.enable_tiling: - pipe.vae.enable_tiling() - if args.enable_model_cpu_offload: - pipe.enable_model_cpu_offload() - - validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) - for validation_prompt in validation_prompts: - pipeline_args = { - "prompt": validation_prompt, - "guidance_scale": 6.0, - "num_inference_steps": 64, - "height": args.height, - "width": args.width, - "max_sequence_length": 256, - } - log_validation( - pipe=pipe, - args=args, - accelerator=accelerator, - pipeline_args=pipeline_args, - epoch=epoch, - ) - - accelerator.print("===== Memory after validation =====") - print_memory(accelerator.device) - reset_memory(accelerator.device) - - del pipe.text_encoder - del pipe.vae - del pipe - gc.collect() - torch.cuda.empty_cache() - - transformer.train() - - accelerator.wait_for_everyone() - - if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: - transformer = unwrap_model(transformer) - transformer_lora_layers = get_peft_model_state_dict(transformer) - MochiPipeline.save_lora_weights( - save_directory=args.output_dir, - transformer_lora_layers=transformer_lora_layers, - ) - - # Cleanup trained models to save memory - del transformer - - gc.collect() - torch.cuda.empty_cache() - - # Final test inference - validation_outputs = [] - if args.validation_prompt and args.num_validation_videos > 0: - accelerator.print("===== Memory before testing =====") - print_memory(accelerator.device) - reset_memory(accelerator.device) + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + print("===== Memory before validation =====") + print_memory("cuda") + transformer.eval() pipe = MochiPipeline.from_pretrained( args.pretrained_model_name_or_path, + transformer=transformer, + scheduler=scheduler, revision=args.revision, variant=args.variant, ) @@ -781,12 +423,6 @@ def load_model_hook(models, input_dir): if args.enable_model_cpu_offload: pipe.enable_model_cpu_offload() - # Load LoRA weights - lora_scaling = args.lora_alpha / args.rank - pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora") - pipe.set_adapters(["mochi-lora"], [lora_scaling]) - - # Run inference validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) for validation_prompt in validation_prompts: pipeline_args = { @@ -797,40 +433,104 @@ def load_model_hook(models, input_dir): "width": args.width, "max_sequence_length": 256, } - - video = log_validation( - accelerator=accelerator, + log_validation( pipe=pipe, args=args, pipeline_args=pipeline_args, epoch=epoch, - is_final_validation=True, + wandb_run=wandb_run, ) - validation_outputs.extend(video) - - accelerator.print("===== Memory after testing =====") - print_memory(accelerator.device) - reset_memory(accelerator.device) - torch.cuda.synchronize(accelerator.device) - - if args.push_to_hub: - save_model_card( - repo_id, - videos=validation_outputs, - base_model=args.pretrained_model_name_or_path, - validation_prompt=args.validation_prompt, - repo_folder=args.output_dir, - fps=args.fps, - ) - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*", "*.bin", "*.pt"], - ) - accelerator.print(f"Params pushed to {repo_id}.") - accelerator.end_training() + print("===== Memory after validation =====") + print_memory("cuda") + reset_memory("cuda") + + del pipe.text_encoder + del pipe.vae + del pipe + gc.collect() + torch.cuda.empty_cache() + + transformer.train() + + transformer.eval() + transformer_lora_layers = get_peft_model_state_dict(transformer) + MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers) + + # Cleanup trained models to save memory + del transformer + + gc.collect() + torch.cuda.empty_cache() + + # Final test inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + print("===== Memory before testing =====") + print_memory("cuda") + reset_memory("cuda") + + pipe = MochiPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + if args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora") + pipe.set_adapters(["mochi-lora"], [lora_scaling]) + + # Run inference + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": 6.0, + "num_inference_steps": 64, + "height": args.height, + "width": args.width, + "max_sequence_length": 256, + } + + video = log_validation( + pipe=pipe, + args=args, + pipeline_args=pipeline_args, + epoch=epoch, + wandb_run=wandb_run, + is_final_validation=True, + ) + validation_outputs.extend(video) + + print("===== Memory after testing =====") + print_memory("cuda") + reset_memory("cuda") + torch.cuda.synchronize("cuda") + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*", "*.bin", "*.pt"], + ) + print(f"Params pushed to {repo_id}.") if __name__ == "__main__": diff --git a/training/mochi-1/train.sh b/training/mochi-1/train.sh index d15ff695..f789281c 100644 --- a/training/mochi-1/train.sh +++ b/training/mochi-1/train.sh @@ -2,38 +2,35 @@ export NCCL_P2P_DISABLE=1 export TORCH_NCCL_ENABLE_MONITORING=0 -GPU_IDS="2" +GPU_IDS="0" -DATA_ROOT="/home/sayak/cogvideox-factory/training/mochi-1/videos_prepared" +DATA_ROOT="videos_prepared" MODEL="genmo/mochi-1-preview" -OUTPUT_PATH=/raid/.cache/huggingface/sayak/mochi-lora/ +OUTPUT_PATH="mochi-lora" -cmd="accelerate launch --config_file deepspeed.yaml --gpu_ids $GPU_IDS text_to_video_lora.py \ +cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python text_to_video_lora_simple.py \ --pretrained_model_name_or_path $MODEL \ + --cast_dit \ --data_root $DATA_ROOT \ --seed 42 \ - --mixed_precision "bf16" \ --output_dir $OUTPUT_PATH \ --train_batch_size 1 \ --dataloader_num_workers 4 \ --pin_memory \ --caption_dropout 0.1 \ --max_train_steps 2000 \ - --checkpointing_steps 200 \ - --checkpoints_total_limit 1 \ - --gradient_accumulation_steps 4 \ --gradient_checkpointing \ --enable_slicing \ --enable_tiling \ --enable_model_cpu_offload \ - --optimizer adamw --use_8bit \ - --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \ + --optimizer adamw \ + --validation_prompt \"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \ --validation_prompt_separator ::: \ --num_validation_videos 1 \ --validation_epochs 1 \ --allow_tf32 \ --report_to wandb \ - --nccl_timeout 1800" + --push_to_hub" echo "Running command: $cmd" eval $cmd From 4e3bb7ab5faeb3249fcfb7d12f9367988ee0a409 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 29 Nov 2024 10:52:12 +0530 Subject: [PATCH 22/30] updates --- training/mochi-1/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/training/mochi-1/README.md b/training/mochi-1/README.md index 5e278ca7..90517ca4 100644 --- a/training/mochi-1/README.md +++ b/training/mochi-1/README.md @@ -89,6 +89,7 @@ Our script currently doesn't leverage `accelerate` and some of its consequences * No support for distributed training. * No intermediate checkpoint saving and loading support. * `train_batch_size > 1` are supported but can potentially lead to OOMs because we currently don't have gradient accumulation support. +* No support for 8bit optimizers (but should be relatively easy to add). **Misc**: From e1866d844d2da33725f8d1e375286f24514eb2bb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 29 Nov 2024 13:07:26 +0530 Subject: [PATCH 23/30] better example code. --- training/mochi-1/text_to_video_lora.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index 1a7122e5..dc4f30a2 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -102,7 +102,27 @@ def save_model_card( Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed. ```py -TODO +from diffusers import MochiPipeline +from diffusers.utils import export_to_video +import torch + +pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview") +pipe.load_lora_weights("CHANGE_ME") +pipe.enable_model_cpu_offload() + +pipeline_args = { + "prompt": "CHANGE_ME", + "guidance_scale": 6.0, + "num_inference_steps": 64, + "height": 480, + "width": 848, + "max_sequence_length": 256, + "output_type": "np", +} + +with torch.autocast("cuda", torch.bfloat16) + video = pipe(**pipeline_args).frames[0] +export_to_video(video) ``` For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. From 9c86706cbb5d66293a397f8abc1bac1225cfe589 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 29 Nov 2024 13:08:55 +0530 Subject: [PATCH 24/30] fix help message --- training/mochi-1/args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/mochi-1/args.py b/training/mochi-1/args.py index 4f1cc7e8..51671c0c 100644 --- a/training/mochi-1/args.py +++ b/training/mochi-1/args.py @@ -42,7 +42,7 @@ def _get_model_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--compile_dit", action="store_true", - help="If we should cast DiT params to a lower precision.", + help="If we should compile the DiT.", ) From 38f157c6d089cd9c36fc25cd88927816771d68e9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 29 Nov 2024 13:11:08 +0530 Subject: [PATCH 25/30] Apply suggestions from code review Co-authored-by: Aryan --- training/mochi-1/text_to_video_lora.py | 4 ++-- training/mochi-1/train.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py index dc4f30a2..faa6255c 100644 --- a/training/mochi-1/text_to_video_lora.py +++ b/training/mochi-1/text_to_video_lora.py @@ -89,9 +89,9 @@ def save_model_card( ## Model description -This is a lora finetune of the Moch-1 preview model `{base_model}`. +This is a lora finetune of the Mochi-1 preview model `{base_model}`. -The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX, Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). +The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). ## Download model diff --git a/training/mochi-1/train.sh b/training/mochi-1/train.sh index f789281c..2c378e2e 100644 --- a/training/mochi-1/train.sh +++ b/training/mochi-1/train.sh @@ -8,7 +8,7 @@ DATA_ROOT="videos_prepared" MODEL="genmo/mochi-1-preview" OUTPUT_PATH="mochi-lora" -cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python text_to_video_lora_simple.py \ +cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python text_to_video_lora.py \ --pretrained_model_name_or_path $MODEL \ --cast_dit \ --data_root $DATA_ROOT \ From 8a32b139af861fda2853a4ad0cf3581383a11b24 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 29 Nov 2024 13:22:23 +0530 Subject: [PATCH 26/30] pin moviepy. --- training/mochi-1/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/mochi-1/requirements.txt b/training/mochi-1/requirements.txt index 8fb970ab..e8baa997 100644 --- a/training/mochi-1/requirements.txt +++ b/training/mochi-1/requirements.txt @@ -3,5 +3,5 @@ transformers wandb torch torchvision -moviepy +moviepy==1.0.3 click \ No newline at end of file From 95775ba5d138e185c75901f950b22ff16ab00c19 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 29 Nov 2024 13:27:23 +0530 Subject: [PATCH 27/30] pyav pining. --- training/mochi-1/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/training/mochi-1/requirements.txt b/training/mochi-1/requirements.txt index e8baa997..a03ceeb0 100644 --- a/training/mochi-1/requirements.txt +++ b/training/mochi-1/requirements.txt @@ -3,5 +3,6 @@ transformers wandb torch torchvision +av==11.0.0 moviepy==1.0.3 click \ No newline at end of file From 0011fa140644dbd45d4931386b02c97822647f7b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 29 Nov 2024 13:34:05 +0530 Subject: [PATCH 28/30] better command --- training/mochi-1/prepare_dataset.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/training/mochi-1/prepare_dataset.sh b/training/mochi-1/prepare_dataset.sh index 03a96559..c424b4e5 100644 --- a/training/mochi-1/prepare_dataset.sh +++ b/training/mochi-1/prepare_dataset.sh @@ -6,6 +6,10 @@ OUTPUT_DIR=videos_prepared NUM_FRAMES=37 RESOLUTION=480x848 +# Extract width and height from RESOLUTION +WIDTH=$(echo $RESOLUTION | cut -dx -f1) +HEIGHT=$(echo $RESOLUTION | cut -dx -f2) + python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample -CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=37x480x848 \ No newline at end of file +CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT} From dceded0a9472bdf626beb557aebccb126d977587 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 29 Nov 2024 13:40:02 +0530 Subject: [PATCH 29/30] add a preview table --- training/mochi-1/README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/training/mochi-1/README.md b/training/mochi-1/README.md index 90517ca4..74a2f7c8 100644 --- a/training/mochi-1/README.md +++ b/training/mochi-1/README.md @@ -1,5 +1,16 @@ # Simple Mochi-1 finetuner + + + + + + + + + +
Dataset Sample Test Sample
+ Now you can make Mochi-1 your own with `diffusers`, too 🤗 🧨 We provide a minimal and faithful reimplementation of the [Mochi-1 original fine-tuner](https://github.com/genmoai/mochi/tree/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner). As usual, we leverage `peft` for things LoRA in our implementation. From 7090bcb9dece012110d5c0ec360f68dca0f41739 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 29 Nov 2024 13:47:26 +0530 Subject: [PATCH 30/30] Update README.md --- training/mochi-1/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/training/mochi-1/README.md b/training/mochi-1/README.md index 74a2f7c8..2cbc185d 100644 --- a/training/mochi-1/README.md +++ b/training/mochi-1/README.md @@ -6,8 +6,8 @@ Test Sample - - + + @@ -105,4 +105,4 @@ Our script currently doesn't leverage `accelerate` and some of its consequences **Misc**: * We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033). -* `embed.py` script is non-batched. \ No newline at end of file +* `embed.py` script is non-batched.