diff --git a/training/__init__.py b/training/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/training/dataset.py b/training/dataset.py
index dca375df..ec47b0b3 100644
--- a/training/dataset.py
+++ b/training/dataset.py
@@ -425,4 +425,4 @@ def __iter__(self):
random.shuffle(bucket)
yield bucket
del self.buckets[fhw]
- self.buckets[fhw] = []
+ self.buckets[fhw] = []
\ No newline at end of file
diff --git a/training/mochi-1/README.md b/training/mochi-1/README.md
new file mode 100644
index 00000000..2cbc185d
--- /dev/null
+++ b/training/mochi-1/README.md
@@ -0,0 +1,108 @@
+# Simple Mochi-1 finetuner
+
+
+
+ Dataset Sample |
+ Test Sample |
+
+
+ |
+ |
+
+
+
+Now you can make Mochi-1 your own with `diffusers`, too 🤗 🧨
+
+We provide a minimal and faithful reimplementation of the [Mochi-1 original fine-tuner](https://github.com/genmoai/mochi/tree/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner). As usual, we leverage `peft` for all things LoRA in our implementation.
+
+## Getting started
+
+Install the dependencies: `pip install -r requirements.txt`. Also make sure your `diffusers` installation is from the current `main`.
+
+Download a demo dataset:
+
+```bash
+huggingface-cli download \
+ --repo-type dataset sayakpaul/video-dataset-disney-organized \
+ --local-dir video-dataset-disney-organized
+```
+
+The dataset follows the directory structure expected by the subsequent scripts. In particular, it follows what's prescribed [here](https://github.com/genmoai/mochi/tree/main/demos/fine_tuner#1-collect-your-videos-and-captions):
+
+```bash
+video_1.mp4
+video_1.txt -- One-paragraph description of video_1
+video_2.mp4
+video_2.txt -- One-paragraph description of video_2
+...
+```
+
+Then run (be sure to check the paths accordingly):
+
+```bash
+bash prepare_dataset.sh
+```
+
+We can adjust `num_frames` and `resolution`. By default, in `prepare_dataset.sh`, we use `--force_upsample`. This means if the original video resolution is smaller than the requested resolution, we will upsample the video.
+
+> [!IMPORTANT]
+> It's important to have a resolution of at least 480x848 to satisfy Mochi-1's requirements.
+
+Now, we're ready to fine-tune. To launch, run:
+
+```bash
+bash train.sh
+```
+
+You can disable intermediate validation by:
+
+```diff
+- --validation_prompt "..." \
+- --validation_prompt_separator ::: \
+- --num_validation_videos 1 \
+- --validation_epochs 1 \
+```
+
+We haven't rigorously tested but without validation enabled, this script should run under 40GBs of GPU VRAM.
+
+To use the LoRA checkpoint:
+
+```py
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+import torch
+
+pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
+pipe.load_lora_weights("path-to-lora")
+pipe.enable_model_cpu_offload()
+
+pipeline_args = {
+ "prompt": "A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions",
+ "guidance_scale": 6.0,
+ "num_inference_steps": 64,
+ "height": 480,
+ "width": 848,
+ "max_sequence_length": 256,
+ "output_type": "np",
+}
+
+with torch.autocast("cuda", torch.bfloat16):
+ video = pipe(**pipeline_args).frames[0]
+export_to_video(video, "output.mp4")
+```
+
+## Known limitations
+
+(Contributions are welcome 🤗)
+
+Our script currently doesn't leverage `accelerate` and some of its consequences are detailed below:
+
+* No support for distributed training.
+* No intermediate checkpoint saving and loading support.
+* `train_batch_size > 1` is supported but can potentially lead to OOMs because we currently don't have gradient accumulation support.
+* No support for 8bit optimizers (but should be relatively easy to add).
+
+**Misc**:
+
+* We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033).
+* `embed.py` script is non-batched.
diff --git a/training/mochi-1/args.py b/training/mochi-1/args.py
new file mode 100644
index 00000000..51671c0c
--- /dev/null
+++ b/training/mochi-1/args.py
@@ -0,0 +1,263 @@
+"""
+Default values taken from
+https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml
+when applicable.
+"""
+
+import argparse
+
+
+def _get_model_args(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument(
+ "--cast_dit",
+ action="store_true",
+ help="If we should cast DiT params to a lower precision.",
+ )
+ parser.add_argument(
+ "--compile_dit",
+ action="store_true",
+ help="If we should compile the DiT.",
+ )
+
+
+def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument(
+ "--data_root",
+ type=str,
+ default=None,
+ help=("A folder containing the training data."),
+ )
+ parser.add_argument(
+ "--caption_dropout",
+ type=float,
+ default=None,
+ help=("Probability to drop out captions randomly."),
+ )
+
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
+ )
+ parser.add_argument(
+ "--pin_memory",
+ action="store_true",
+ help="Whether or not to use the pinned memory setting in pytorch dataloader.",
+ )
+
+
+def _get_validation_args(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
+ )
+ parser.add_argument(
+ "--validation_images",
+ type=str,
+ default=None,
+        help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
+ )
+ parser.add_argument(
+ "--validation_prompt_separator",
+ type=str,
+ default=":::",
+ help="String that separates multiple validation prompts",
+ )
+ parser.add_argument(
+ "--num_validation_videos",
+ type=int,
+ default=1,
+ help="Number of videos that should be generated during validation per `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
+ )
+ parser.add_argument(
+ "--enable_slicing",
+ action="store_true",
+ default=False,
+ help="Whether or not to use VAE slicing for saving memory.",
+ )
+ parser.add_argument(
+ "--enable_tiling",
+ action="store_true",
+ default=False,
+ help="Whether or not to use VAE tiling for saving memory.",
+ )
+ parser.add_argument(
+ "--enable_model_cpu_offload",
+ action="store_true",
+ default=False,
+ help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
+ )
+ parser.add_argument(
+ "--fps",
+ type=int,
+ default=30,
+ help="FPS to use when serializing the output videos.",
+ )
+ parser.add_argument(
+ "--height",
+ type=int,
+ default=480,
+ )
+ parser.add_argument(
+ "--width",
+ type=int,
+ default=848,
+ )
+
+
+def _get_training_args(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.")
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=16,
+ help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
+ )
+ parser.add_argument(
+ "--target_modules",
+ nargs="+",
+ type=str,
+ default=["to_k", "to_q", "to_v", "to_out.0"],
+ help="Target modules to train LoRA for.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="mochi-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--train_batch_size",
+ type=int,
+ default=4,
+ help="Batch size (per device) for the training dataloader.",
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=2e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps",
+ type=int,
+ default=200,
+ help="Number of steps for the warmup in the lr scheduler.",
+ )
+
+
+def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument(
+ "--optimizer",
+ type=lambda s: s.lower(),
+ default="adam",
+ choices=["adam", "adamw"],
+ help=("The optimizer type to use."),
+ )
+ parser.add_argument(
+ "--weight_decay",
+ type=float,
+ default=0.01,
+ help="Weight decay to use for optimizer.",
+ )
+
+
+def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether or not to push the model to the Hub.",
+ )
+ parser.add_argument(
+ "--hub_token",
+ type=str,
+ default=None,
+ help="The token to use to push to the Model Hub.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default=None,
+ help="If logging to wandb."
+ )
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.")
+
+ _get_model_args(parser)
+ _get_dataset_args(parser)
+ _get_training_args(parser)
+ _get_validation_args(parser)
+ _get_optimizer_args(parser)
+ _get_configuration_args(parser)
+
+ return parser.parse_args()
diff --git a/training/mochi-1/dataset_simple.py b/training/mochi-1/dataset_simple.py
new file mode 100644
index 00000000..8cc6153b
--- /dev/null
+++ b/training/mochi-1/dataset_simple.py
@@ -0,0 +1,50 @@
+"""
+Taken from
+https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py
+"""
+
+from pathlib import Path
+
+import click
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+
+def load_to_cpu(x):
+ return torch.load(x, map_location=torch.device("cpu"), weights_only=True)
+
+
+class LatentEmbedDataset(Dataset):
+ def __init__(self, file_paths, repeat=1):
+ self.items = [
+ (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt"))
+ for p in file_paths
+ if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file()
+ ]
+ self.items = self.items * repeat
+ print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.")
+
+ def __len__(self):
+ return len(self.items)
+
+ def __getitem__(self, idx):
+ latent_path, embed_path = self.items[idx]
+ return load_to_cpu(latent_path), load_to_cpu(embed_path)
+
+
+@click.command()
+@click.argument("directory", type=click.Path(exists=True, file_okay=False))
+def process_videos(directory):
+ dir_path = Path(directory)
+ mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")]
+    assert mp4_files, f"No mp4 files found in {directory}"
+
+ dataset = LatentEmbedDataset(mp4_files)
+ dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
+
+ for latents, embeds in dataloader:
+ print([(k, v.shape) for k, v in latents.items()])
+
+
+if __name__ == "__main__":
+ process_videos()
diff --git a/training/mochi-1/embed.py b/training/mochi-1/embed.py
new file mode 100644
index 00000000..ec35ebb0
--- /dev/null
+++ b/training/mochi-1/embed.py
@@ -0,0 +1,111 @@
+"""
+Adapted from:
+https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py
+https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py
+"""
+
+import click
+import torch
+import torchvision
+from pathlib import Path
+from diffusers import AutoencoderKLMochi, MochiPipeline
+from transformers import T5EncoderModel, T5Tokenizer
+from tqdm.auto import tqdm
+
+
+def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str):
+ T, H, W = [int(s) for s in shape.split("x")]
+ assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6"
+ video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs")
+ fps = metadata["video_fps"]
+ video = video.permute(3, 0, 1, 2)
+ og_shape = video.shape
+ assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
+ assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
+ assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
+ if video.shape[1] > T:
+ video = video[:, :T]
+ print(f"Trimmed video from {og_shape[1]} to first {T} frames")
+ video = video.unsqueeze(0)
+ video = video.float() / 127.5 - 1.0
+ video = video.to(model.device)
+
+ assert video.ndim == 5
+
+ with torch.inference_mode():
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ ldist = model._encode(video)
+
+ torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt"))
+
+
+@click.command()
+@click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
+@click.option(
+ "--model_id",
+ type=str,
+ help="Repo id. Should be genmo/mochi-1-preview",
+ default="genmo/mochi-1-preview",
+)
+@click.option("--shape", default="163x480x848", help="Shape of the video to encode")
+@click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.")
+def batch_process(output_dir: Path, model_id: Path, shape: str, overwrite: bool) -> None:
+ """Process all videos and captions in a directory using a single GPU."""
+ # comment out when running on unsupported hardware
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ # Get all video paths
+ video_paths = list(output_dir.glob("**/*.mp4"))
+ if not video_paths:
+ print(f"No MP4 files found in {output_dir}")
+ return
+
+ text_paths = list(output_dir.glob("**/*.txt"))
+ if not text_paths:
+ print(f"No text files found in {output_dir}")
+ return
+
+ # load the models
+ vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda")
+ text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
+ tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
+ pipeline = MochiPipeline.from_pretrained(
+ model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None
+ ).to("cuda")
+
+ for idx, video_path in tqdm(enumerate(sorted(video_paths))):
+ print(f"Processing {video_path}")
+ try:
+ if video_path.with_suffix(".latent.pt").exists() and not overwrite:
+ print(f"Skipping {video_path}")
+ continue
+
+ # encode videos.
+ encode_videos(vae, vid_path=video_path, shape=shape)
+
+ # embed captions.
+            prompt_path = video_path.with_suffix(".txt")
+ embed_path = prompt_path.with_suffix(".embed.pt")
+
+ if embed_path.exists() and not overwrite:
+ print(f"Skipping {prompt_path} - embeddings already exist")
+ continue
+
+ with open(prompt_path) as f:
+ text = f.read().strip()
+ with torch.inference_mode():
+ conditioning = pipeline.encode_prompt(prompt=[text])
+
+ conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]}
+ torch.save(conditioning, embed_path)
+
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ print(f"Error processing {video_path}: {str(e)}")
+
+
+if __name__ == "__main__":
+ batch_process()
diff --git a/training/mochi-1/prepare_dataset.sh b/training/mochi-1/prepare_dataset.sh
new file mode 100644
index 00000000..c424b4e5
--- /dev/null
+++ b/training/mochi-1/prepare_dataset.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+GPU_ID=0
+VIDEO_DIR=video-dataset-disney-organized
+OUTPUT_DIR=videos_prepared
+NUM_FRAMES=37
+RESOLUTION=480x848
+
+# Extract height and width from RESOLUTION (format is HEIGHTxWIDTH, e.g. 480x848)
+HEIGHT=$(echo $RESOLUTION | cut -dx -f1)
+WIDTH=$(echo $RESOLUTION | cut -dx -f2)
+
+python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample
+
+CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${HEIGHT}x${WIDTH}
diff --git a/training/mochi-1/requirements.txt b/training/mochi-1/requirements.txt
new file mode 100644
index 00000000..a03ceeb0
--- /dev/null
+++ b/training/mochi-1/requirements.txt
@@ -0,0 +1,8 @@
+peft
+transformers
+wandb
+torch
+torchvision
+av==11.0.0
+moviepy==1.0.3
+click
\ No newline at end of file
diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py
new file mode 100644
index 00000000..faa6255c
--- /dev/null
+++ b/training/mochi-1/text_to_video_lora.py
@@ -0,0 +1,558 @@
+# Copyright 2024 The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import random
+from glob import glob
+import math
+import os
+import torch.nn.functional as F
+import numpy as np
+from pathlib import Path
+from typing import Any, Dict, Tuple, List
+
+import torch
+import wandb
+from diffusers import FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.training_utils import cast_training_params
+from diffusers.utils import export_to_video
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, get_peft_model_state_dict
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+
+from args import get_args # isort:skip
+from dataset_simple import LatentEmbedDataset
+
+import sys
+
+
+sys.path.append("..")
+
+from utils import print_memory, reset_memory # isort:skip
+
+
+# Taken from
+# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139
+def get_cosine_annealing_lr_scheduler(
+ optimizer: torch.optim.Optimizer,
+ warmup_steps: int,
+ total_steps: int,
+):
+ def lr_lambda(step):
+ if step < warmup_steps:
+ return float(step) / float(max(1, warmup_steps))
+ else:
+ return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
+
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+
+def save_model_card(
+ repo_id: str,
+ videos=None,
+ base_model: str = None,
+ validation_prompt=None,
+ repo_folder=None,
+ fps=30,
+):
+ widget_dict = []
+ if videos is not None and len(videos) > 0:
+ for i, video in enumerate(videos):
+ export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
+ widget_dict.append(
+ {
+ "text": validation_prompt if validation_prompt else " ",
+ "output": {"url": f"final_video_{i}.mp4"},
+ }
+ )
+
+ model_description = f"""
+# Mochi-1 Preview LoRA Finetune
+
+
+
+## Model description
+
+This is a lora finetune of the Mochi-1 preview model `{base_model}`.
+
+The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
+
+## Download model
+
+[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
+
+## Usage
+
+Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
+
+```py
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+import torch
+
+pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
+pipe.load_lora_weights("CHANGE_ME")
+pipe.enable_model_cpu_offload()
+
+pipeline_args = {
+ "prompt": "CHANGE_ME",
+ "guidance_scale": 6.0,
+ "num_inference_steps": 64,
+ "height": 480,
+ "width": 848,
+ "max_sequence_length": 256,
+ "output_type": "np",
+}
+
+with torch.autocast("cuda", torch.bfloat16):
+ video = pipe(**pipeline_args).frames[0]
+export_to_video(video, "output.mp4")
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
+
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="apache-2.0",
+ base_model=base_model,
+ prompt=validation_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-video",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "mochi-1-preview",
+ "mochi-1-preview-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipe: MochiPipeline,
+ args: Dict[str, Any],
+ pipeline_args: Dict[str, Any],
+ epoch,
+ wandb_run: str = None,
+ is_final_validation: bool = False,
+):
+ print(
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
+ )
+ phase_name = "test" if is_final_validation else "validation"
+
+ if not args.enable_model_cpu_offload:
+ pipe = pipe.to("cuda")
+
+ # run inference
+    generator = torch.manual_seed(args.seed) if args.seed is not None else None
+
+ videos = []
+ with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
+ for _ in range(args.num_validation_videos):
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
+ videos.append(video)
+
+ video_filenames = []
+ for i, video in enumerate(videos):
+ prompt = (
+ pipeline_args["prompt"][:25]
+ .replace(" ", "_")
+ .replace(" ", "_")
+ .replace("'", "_")
+ .replace('"', "_")
+ .replace("/", "_")
+ )
+ filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
+        export_to_video(video, filename, fps=args.fps)
+ video_filenames.append(filename)
+
+ if wandb_run:
+ wandb.log(
+ {
+ phase_name: [
+                    wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=args.fps)
+ for i, filename in enumerate(video_filenames)
+ ]
+ }
+ )
+
+ return videos
+
+
+# Adapted from the original code:
+# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578
+def cast_dit(model, dtype):
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ assert any(
+ n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"]
+ ), f"Unexpected linear layer: {name}"
+ module.to(dtype=dtype)
+ elif isinstance(module, torch.nn.Conv2d):
+ module.to(dtype=dtype)
+ return model
+
+
+class CollateFunction:
+ def __init__(self, caption_dropout: float = None) -> None:
+ self.caption_dropout = caption_dropout
+
+ def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+ ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0)
+ z = DiagonalGaussianDistribution(ldists).sample()
+ assert torch.isfinite(z).all()
+
+ # Sample noise which we will add to the samples.
+ eps = torch.randn_like(z)
+ sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)
+
+ prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0)
+ prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0)
+ if self.caption_dropout and random.random() < self.caption_dropout:
+ prompt_embeds.zero_()
+ prompt_attention_mask = prompt_attention_mask.long()
+ prompt_attention_mask.zero_()
+ prompt_attention_mask = prompt_attention_mask.bool()
+
+ return dict(
+ z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask
+ )
+
+
+def main(args):
+ if not torch.cuda.is_available():
+ raise ValueError("Not supported without CUDA.")
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ # Handle the repository creation
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Prepare models and scheduler
+ transformer = MochiTransformer3DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+
+ transformer.requires_grad_(False)
+ transformer.to("cuda")
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ if args.cast_dit:
+ transformer = cast_dit(transformer, torch.bfloat16)
+ if args.compile_dit:
+ transformer.compile()
+
+ # now we will add new LoRA weights to the attention layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.lora_alpha,
+ init_lora_weights="gaussian",
+ target_modules=args.target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * args.train_batch_size
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params([transformer], dtype=torch.float32)
+
+ # Prepare optimizer
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+ num_trainable_parameters = sum(param.numel() for param in transformer_lora_parameters)
+ optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
+
+ # Dataset and DataLoader
+ train_vids = list(sorted(glob(f"{args.data_root}/*.mp4")))
+ train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
+ print(f"Found {len(train_vids)} training videos in {args.data_root}")
+ assert len(train_vids) > 0, f"No training data found in {args.data_root}"
+
+ collate_fn = CollateFunction(caption_dropout=args.caption_dropout)
+ train_dataset = LatentEmbedDataset(train_vids, repeat=1)
+ train_dataloader = DataLoader(
+ train_dataset,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ pin_memory=args.pin_memory,
+ )
+
+ # LR scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = len(train_dataloader)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_cosine_annealing_lr_scheduler(
+ optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = len(train_dataloader)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ wandb_run = None
+ if args.report_to == "wandb":
+ tracker_name = args.tracker_name or "mochi-1-lora"
+ wandb_run = wandb.init(project=tracker_name, config=vars(args))
+
+ print("===== Memory before training =====")
+ reset_memory("cuda")
+ print_memory("cuda")
+
+ # Train!
+ total_batch_size = args.train_batch_size
+ print("***** Running training *****")
+ print(f" Num trainable parameters = {num_trainable_parameters}")
+ print(f" Num examples = {len(train_dataset)}")
+ print(f" Num batches each epoch = {len(train_dataloader)}")
+ print(f" Num epochs = {args.num_train_epochs}")
+ print(f" Instantaneous batch size per device = {args.train_batch_size}")
+ print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ print(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+ first_epoch = 0
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=global_step,
+ desc="Steps",
+ )
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ with torch.no_grad():
+ z = batch["z"].to("cuda")
+ eps = batch["eps"].to("cuda")
+ sigma = batch["sigma"].to("cuda")
+ prompt_embeds = batch["prompt_embeds"].to("cuda")
+ prompt_attention_mask = batch["prompt_attention_mask"].to("cuda")
+
+ sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1]
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
+ ut = z - eps
+
+ # (1 - sigma) because of
+ # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656
+ # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation.
+ timesteps = (1 - sigma) * scheduler.config.num_train_timesteps
+
+ with torch.autocast("cuda", torch.bfloat16):
+ model_pred = transformer(
+ hidden_states=z_sigma,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+ assert model_pred.shape == z.shape
+ loss = F.mse_loss(model_pred.float(), ut.float())
+ loss.backward()
+
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+ progress_bar.update(1)
+ global_step += 1
+
+ last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
+ logs = {"loss": loss.detach().item(), "lr": last_lr}
+ progress_bar.set_postfix(**logs)
+ if wandb_run:
+ wandb_run.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
+ print("===== Memory before validation =====")
+ print_memory("cuda")
+
+ transformer.eval()
+ pipe = MochiPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=transformer,
+ scheduler=scheduler,
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ if args.enable_slicing:
+ pipe.vae.enable_slicing()
+ if args.enable_tiling:
+ pipe.vae.enable_tiling()
+ if args.enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
+ for validation_prompt in validation_prompts:
+ pipeline_args = {
+ "prompt": validation_prompt,
+ "guidance_scale": 6.0,
+ "num_inference_steps": 64,
+ "height": args.height,
+ "width": args.width,
+ "max_sequence_length": 256,
+ }
+ log_validation(
+ pipe=pipe,
+ args=args,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ wandb_run=wandb_run,
+ )
+
+ print("===== Memory after validation =====")
+ print_memory("cuda")
+ reset_memory("cuda")
+
+ del pipe.text_encoder
+ del pipe.vae
+ del pipe
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ transformer.train()
+
+ transformer.eval()
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+ MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers)
+
+ # Cleanup trained models to save memory
+ del transformer
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Final test inference
+ validation_outputs = []
+ if args.validation_prompt and args.num_validation_videos > 0:
+ print("===== Memory before testing =====")
+ print_memory("cuda")
+ reset_memory("cuda")
+
+ pipe = MochiPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ if args.enable_slicing:
+ pipe.vae.enable_slicing()
+ if args.enable_tiling:
+ pipe.vae.enable_tiling()
+ if args.enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+
+ # Load LoRA weights
+ lora_scaling = args.lora_alpha / args.rank
+ pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora")
+ pipe.set_adapters(["mochi-lora"], [lora_scaling])
+
+ # Run inference
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
+ for validation_prompt in validation_prompts:
+ pipeline_args = {
+ "prompt": validation_prompt,
+ "guidance_scale": 6.0,
+ "num_inference_steps": 64,
+ "height": args.height,
+ "width": args.width,
+ "max_sequence_length": 256,
+ }
+
+ video = log_validation(
+ pipe=pipe,
+ args=args,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ wandb_run=wandb_run,
+ is_final_validation=True,
+ )
+ validation_outputs.extend(video)
+
+ print("===== Memory after testing =====")
+ print_memory("cuda")
+ reset_memory("cuda")
+ torch.cuda.synchronize("cuda")
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ videos=validation_outputs,
+ base_model=args.pretrained_model_name_or_path,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ fps=args.fps,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*", "*.bin", "*.pt"],
+ )
+ print(f"Params pushed to {repo_id}.")
+
+
if __name__ == "__main__":
    # Entry point: parse the CLI flags and launch training.
    main(get_args())
diff --git a/training/mochi-1/train.sh b/training/mochi-1/train.sh
new file mode 100644
index 00000000..2c378e2e
--- /dev/null
+++ b/training/mochi-1/train.sh
@@ -0,0 +1,37 @@
#!/bin/bash
# Launch single-GPU LoRA fine-tuning of Mochi-1 via text_to_video_lora.py.
# Assumes the dataset has already been preprocessed (see prepare_dataset.sh).

# Work around NCCL hangs seen on some hosts; harmless for single-GPU runs.
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0

# Comma-separated GPU indices made visible to the training process.
GPU_IDS="0"

# Paths: preprocessed dataset root, base model repo id, and LoRA output dir.
DATA_ROOT="videos_prepared"
MODEL="genmo/mochi-1-preview"
OUTPUT_PATH="mochi-lora"

# Build the full command as a string so it can be echoed for reproducibility
# before being executed with `eval` below.
cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python text_to_video_lora.py \
  --pretrained_model_name_or_path $MODEL \
  --cast_dit \
  --data_root $DATA_ROOT \
  --seed 42 \
  --output_dir $OUTPUT_PATH \
  --train_batch_size 1 \
  --dataloader_num_workers 4 \
  --pin_memory \
  --caption_dropout 0.1 \
  --max_train_steps 2000 \
  --gradient_checkpointing \
  --enable_slicing \
  --enable_tiling \
  --enable_model_cpu_offload \
  --optimizer adamw \
  --validation_prompt \"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \
  --validation_prompt_separator ::: \
  --num_validation_videos 1 \
  --validation_epochs 1 \
  --allow_tf32 \
  --report_to wandb \
  --push_to_hub"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"
\ No newline at end of file
diff --git a/training/mochi-1/trim_and_crop_videos.py b/training/mochi-1/trim_and_crop_videos.py
new file mode 100644
index 00000000..0c6f411d
--- /dev/null
+++ b/training/mochi-1/trim_and_crop_videos.py
@@ -0,0 +1,126 @@
+"""
+Adapted from:
+https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py
+"""
+
+from pathlib import Path
+import shutil
+
+import click
+from moviepy.editor import VideoFileClip
+from tqdm import tqdm
+
+
@click.command()
@click.argument("folder", type=click.Path(exists=True, dir_okay=True))
@click.argument("output_folder", type=click.Path(dir_okay=True))
@click.option("--num_frames", "-f", type=float, default=30, help="Number of frames")
@click.option("--resolution", "-r", type=str, default="480x848", help="Video resolution")
@click.option("--force_upsample", is_flag=True, help="Force upsample.")
def truncate_videos(folder, output_folder, num_frames, resolution, force_upsample):
    """Truncate all MP4 and MOV files in FOLDER to specified number of frames and resolution.

    Each video found under FOLDER (recursively, any case of .mp4/.mov) is trimmed to
    roughly ``num_frames`` frames (assuming 30 fps input), center-cropped to the aspect
    ratio of ``resolution`` ("HEIGHTxWIDTH"), resized to exactly that resolution, forced
    to 30 fps, and re-encoded as H.264 MP4 without audio under OUTPUT_FOLDER, mirroring
    the input directory layout.  A sibling ``.txt`` caption file is copied alongside the
    output when present; otherwise an empty caption file is created and a warning shown.

    Videos shorter than the requested duration are skipped.  Videos smaller than the
    target resolution are skipped unless ``--force_upsample`` is given, in which case
    they are upsampled first.

    Raises:
        Any exception from moviepy/ffmpeg is echoed in red and re-raised, aborting the run.
    """
    input_path = Path(folder)
    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    # Resolution string is "HEIGHTxWIDTH" (e.g. the default "480x848").
    target_height, target_width = map(int, resolution.split("x"))

    # Clip duration in seconds, assuming 30 fps source; the 0.09s of slack guards
    # against dropping the final frame to subclip rounding.
    duration = (num_frames / 30) + 0.09

    # Find all MP4 and MOV files, in either case.
    video_files = (
        list(input_path.rglob("*.mp4"))
        + list(input_path.rglob("*.MOV"))
        + list(input_path.rglob("*.mov"))
        + list(input_path.rglob("*.MP4"))
    )

    for file_path in tqdm(video_files):
        try:
            # Mirror the input directory structure under the output folder.
            relative_path = file_path.relative_to(input_path)
            output_file = output_path / relative_path.with_suffix(".mp4")
            output_file.parent.mkdir(parents=True, exist_ok=True)

            click.echo(f"Processing: {file_path}")
            video = VideoFileClip(str(file_path))

            # Skip if video is too short.  Close the clip before skipping so we do
            # not leak the underlying ffmpeg reader/file handle across a large run.
            if video.duration < duration:
                click.echo(f"Skipping {file_path} as it is too short")
                video.close()
                continue

            # Target resolution larger than the source: upsample if allowed,
            # otherwise skip (again closing the clip first to avoid the leak).
            if target_width > video.w or target_height > video.h:
                if force_upsample:
                    click.echo(
                        f"{file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}. So, upsampling the video."
                    )
                    video = video.resize(width=target_width, height=target_height)
                else:
                    click.echo(
                        f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}"
                    )
                    video.close()
                    continue

            # First truncate duration
            truncated = video.subclip(0, duration)

            # Center-crop to the target aspect ratio before the final resize.
            target_ratio = target_width / target_height
            current_ratio = truncated.w / truncated.h

            if current_ratio > target_ratio:
                # Video is wider than target ratio - crop width
                new_width = int(truncated.h * target_ratio)
                x1 = (truncated.w - new_width) // 2
                final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height))
            else:
                # Video is taller than target ratio - crop height (keep full width,
                # so the new height is width / target_ratio)
                new_height = int(truncated.w / target_ratio)
                y1 = (truncated.h - new_height) // 2
                final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height))

            # Set output parameters for consistent MP4 encoding
            output_params = {
                "codec": "libx264",
                "audio": False,  # Disable audio
                "preset": "medium",  # Balance between speed and quality
                "bitrate": "5000k",  # Adjust as needed
            }

            # Set FPS to 30
            final = final.set_fps(30)

            # Copy the corresponding .txt caption, or create an empty one.
            txt_file_path = file_path.with_suffix(".txt")
            if txt_file_path.exists():
                output_txt_file = output_path / relative_path.with_suffix(".txt")
                output_txt_file.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(txt_file_path, output_txt_file)
                click.echo(f"Copied {txt_file_path} to {output_txt_file}")
            else:
                # Print warning in bold yellow with a warning emoji
                click.echo(
                    f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m"
                )
                output_txt_file = output_path / relative_path.with_suffix(".txt")
                output_txt_file.parent.mkdir(parents=True, exist_ok=True)
                output_txt_file.touch()

            # Write the output file
            final.write_videofile(str(output_file), **output_params)

            # Release the ffmpeg readers held by all three clip objects.
            video.close()
            truncated.close()
            final.close()

        except Exception as e:
            click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True)
            raise
+
+
if __name__ == "__main__":
    # Click parses sys.argv when the command object is invoked with no arguments.
    truncate_videos()