-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] add Mochi-1 trainer #90
Merged
Merged
Changes from 23 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
a9adf22
dataprep.
sayakpaul 5ba510e
updates
sayakpaul 9852c3d
updates.
sayakpaul a40ccd2
updates
sayakpaul 4c05ea6
updates
sayakpaul a8dd94e
updates
sayakpaul ac83c78
updates
sayakpaul 1409d47
updates.
sayakpaul c80bd17
nearest_frame_bucket.
sayakpaul 316a705
revert changes to training/dataset.py
sayakpaul 4216705
Merge branch 'main' into mochi-1-tuner
sayakpaul 440dc25
better reuse.
sayakpaul a01592b
betterments.
sayakpaul cb16cba
dataset_mochi
sayakpaul 5208c59
fix
sayakpaul 9eea656
fixes
sayakpaul 7e203db
updates
sayakpaul 9a15eae
updates
sayakpaul 2dbddd5
updates
sayakpaul 58a0632
updates
sayakpaul 2fde026
updates
sayakpaul ced8558
updates
sayakpaul 4e3bb7a
updates
sayakpaul e1866d8
better example code.
sayakpaul 9c86706
fix help message
sayakpaul 38f157c
Apply suggestions from code review
sayakpaul 8a32b13
pin moviepy.
sayakpaul 95775ba
pyav pining.
sayakpaul 0011fa1
better command
sayakpaul dceded0
add a preview table
sayakpaul 7090bcb
Update README.md
sayakpaul File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Simple Mochi-1 finetuner | ||
|
||
Now you can make Mochi-1 your own with `diffusers`, too 🤗 🧨 | ||
|
||
We provide a minimal and faithful reimplementation of the [Mochi-1 original fine-tuner](https://github.com/genmoai/mochi/tree/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner). As usual, we leverage `peft` for things LoRA in our implementation. | ||
|
||
## Getting started | ||
|
||
Install the dependencies: `pip install -r requirements.txt`. Also make sure your `diffusers` installation is from the current `main`. | ||
|
||
Download a demo dataset: | ||
|
||
```bash | ||
huggingface-cli download \ | ||
--repo-type dataset sayakpaul/video-dataset-disney-organized \ | ||
--local-dir video-dataset-disney-organized | ||
``` | ||
|
||
The dataset follows the directory structure expected by the subsequent scripts. In particular, it follows what's prescribed [here](https://github.com/genmoai/mochi/tree/main/demos/fine_tuner#1-collect-your-videos-and-captions): | ||
|
||
```bash | ||
video_1.mp4 | ||
video_1.txt -- One-paragraph description of video_1 | ||
video_2.mp4 | ||
video_2.txt -- One-paragraph description of video_2 | ||
... | ||
``` | ||
|
||
Then run (be sure to check the paths accordingly): | ||
|
||
```bash | ||
bash prepare_dataset.sh | ||
``` | ||
|
||
We can adjust `num_frames` and `resolution`. By default, in `prepare_dataset.sh`, we use `--force_upsample`. This means if the original video resolution is smaller than the requested resolution, we will upsample the video. | ||
|
||
> [!IMPORTANT] | ||
> It's important to have a resolution of at least 480x848 to satisy Mochi-1's requirements. | ||
|
||
Now, we're ready to fine-tune. To launch, run: | ||
|
||
```bash | ||
bash train.sh | ||
``` | ||
|
||
You can disable intermediate validation by: | ||
|
||
```diff | ||
- --validation_prompt "..." \ | ||
- --validation_prompt_separator ::: \ | ||
- --num_validation_videos 1 \ | ||
- --validation_epochs 1 \ | ||
``` | ||
|
||
We haven't rigorously tested but without validation enabled, this script should run under 40GBs of GPU VRAM. | ||
|
||
To use the LoRA checkpoint: | ||
|
||
```py | ||
from diffusers import MochiPipeline | ||
from diffusers.utils import export_to_video | ||
import torch | ||
|
||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview") | ||
pipe.load_lora_weights("path-to-lora") | ||
pipe.enable_model_cpu_offload() | ||
|
||
pipeline_args = { | ||
"prompt": "A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions", | ||
"guidance_scale": 6.0, | ||
"num_inference_steps": 64, | ||
"height": 480, | ||
"width": 848, | ||
"max_sequence_length": 256, | ||
"output_type": "np", | ||
} | ||
|
||
with torch.autocast("cuda", torch.bfloat16) | ||
video = pipe(**pipeline_args).frames[0] | ||
export_to_video(video) | ||
``` | ||
|
||
## Known limitations | ||
|
||
(Contributions are welcome 🤗) | ||
|
||
Our script currently doesn't leverage `accelerate` and some of its consequences are detailed below: | ||
|
||
* No support for distributed training. | ||
* No intermediate checkpoint saving and loading support. | ||
* `train_batch_size > 1` are supported but can potentially lead to OOMs because we currently don't have gradient accumulation support. | ||
* No support for 8bit optimizers (but should be relatively easy to add). | ||
|
||
**Misc**: | ||
|
||
* We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033). | ||
* `embed.py` script is non-batched. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,263 @@ | ||
""" | ||
Default values taken from | ||
https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml | ||
when applicable. | ||
""" | ||
|
||
import argparse | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay to have a separate file for now for faster iterations. We will be refactoring the repo with a more modular API in the future anyway |
||
|
||
|
||
def _get_model_args(parser: argparse.ArgumentParser) -> None: | ||
parser.add_argument( | ||
"--pretrained_model_name_or_path", | ||
type=str, | ||
default=None, | ||
required=True, | ||
help="Path to pretrained model or model identifier from huggingface.co/models.", | ||
) | ||
parser.add_argument( | ||
"--revision", | ||
type=str, | ||
default=None, | ||
required=False, | ||
help="Revision of pretrained model identifier from huggingface.co/models.", | ||
) | ||
parser.add_argument( | ||
"--variant", | ||
type=str, | ||
default=None, | ||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", | ||
) | ||
parser.add_argument( | ||
"--cache_dir", | ||
type=str, | ||
default=None, | ||
help="The directory where the downloaded models and datasets will be stored.", | ||
) | ||
parser.add_argument( | ||
"--cast_dit", | ||
action="store_true", | ||
help="If we should cast DiT params to a lower precision.", | ||
) | ||
parser.add_argument( | ||
"--compile_dit", | ||
action="store_true", | ||
help="If we should cast DiT params to a lower precision.", | ||
a-r-r-o-w marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
|
||
def _get_dataset_args(parser: argparse.ArgumentParser) -> None: | ||
parser.add_argument( | ||
"--data_root", | ||
type=str, | ||
default=None, | ||
help=("A folder containing the training data."), | ||
) | ||
parser.add_argument( | ||
"--caption_dropout", | ||
type=float, | ||
default=None, | ||
help=("Probability to drop out captions randomly."), | ||
) | ||
|
||
parser.add_argument( | ||
"--dataloader_num_workers", | ||
type=int, | ||
default=0, | ||
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", | ||
) | ||
parser.add_argument( | ||
"--pin_memory", | ||
action="store_true", | ||
help="Whether or not to use the pinned memory setting in pytorch dataloader.", | ||
) | ||
|
||
|
||
def _get_validation_args(parser: argparse.ArgumentParser) -> None: | ||
parser.add_argument( | ||
"--validation_prompt", | ||
type=str, | ||
default=None, | ||
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", | ||
) | ||
parser.add_argument( | ||
"--validation_images", | ||
type=str, | ||
default=None, | ||
help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", | ||
) | ||
parser.add_argument( | ||
"--validation_prompt_separator", | ||
type=str, | ||
default=":::", | ||
help="String that separates multiple validation prompts", | ||
) | ||
parser.add_argument( | ||
"--num_validation_videos", | ||
type=int, | ||
default=1, | ||
help="Number of videos that should be generated during validation per `validation_prompt`.", | ||
) | ||
parser.add_argument( | ||
"--validation_epochs", | ||
type=int, | ||
default=50, | ||
help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.", | ||
) | ||
parser.add_argument( | ||
"--enable_slicing", | ||
action="store_true", | ||
default=False, | ||
help="Whether or not to use VAE slicing for saving memory.", | ||
) | ||
parser.add_argument( | ||
"--enable_tiling", | ||
action="store_true", | ||
default=False, | ||
help="Whether or not to use VAE tiling for saving memory.", | ||
) | ||
parser.add_argument( | ||
"--enable_model_cpu_offload", | ||
action="store_true", | ||
default=False, | ||
help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", | ||
) | ||
parser.add_argument( | ||
"--fps", | ||
type=int, | ||
default=30, | ||
help="FPS to use when serializing the output videos.", | ||
) | ||
parser.add_argument( | ||
"--height", | ||
type=int, | ||
default=480, | ||
) | ||
parser.add_argument( | ||
"--width", | ||
type=int, | ||
default=848, | ||
) | ||
|
||
|
||
def _get_training_args(parser: argparse.ArgumentParser) -> None: | ||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") | ||
parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.") | ||
parser.add_argument( | ||
"--lora_alpha", | ||
type=int, | ||
default=16, | ||
help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", | ||
) | ||
parser.add_argument( | ||
"--target_modules", | ||
nargs="+", | ||
type=str, | ||
default=["to_k", "to_q", "to_v", "to_out.0"], | ||
help="Target modules to train LoRA for.", | ||
) | ||
parser.add_argument( | ||
"--output_dir", | ||
type=str, | ||
default="mochi-lora", | ||
help="The output directory where the model predictions and checkpoints will be written.", | ||
) | ||
parser.add_argument( | ||
"--train_batch_size", | ||
type=int, | ||
default=4, | ||
help="Batch size (per device) for the training dataloader.", | ||
) | ||
parser.add_argument("--num_train_epochs", type=int, default=1) | ||
parser.add_argument( | ||
"--max_train_steps", | ||
type=int, | ||
default=None, | ||
help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", | ||
) | ||
parser.add_argument( | ||
"--gradient_checkpointing", | ||
action="store_true", | ||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | ||
) | ||
parser.add_argument( | ||
"--learning_rate", | ||
type=float, | ||
default=2e-4, | ||
help="Initial learning rate (after the potential warmup period) to use.", | ||
) | ||
parser.add_argument( | ||
"--scale_lr", | ||
action="store_true", | ||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | ||
) | ||
parser.add_argument( | ||
"--lr_warmup_steps", | ||
type=int, | ||
default=200, | ||
help="Number of steps for the warmup in the lr scheduler.", | ||
) | ||
|
||
|
||
def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: | ||
parser.add_argument( | ||
"--optimizer", | ||
type=lambda s: s.lower(), | ||
default="adam", | ||
choices=["adam", "adamw"], | ||
help=("The optimizer type to use."), | ||
) | ||
parser.add_argument( | ||
"--weight_decay", | ||
type=float, | ||
default=0.01, | ||
help="Weight decay to use for optimizer.", | ||
) | ||
|
||
|
||
def _get_configuration_args(parser: argparse.ArgumentParser) -> None: | ||
parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") | ||
parser.add_argument( | ||
"--push_to_hub", | ||
action="store_true", | ||
help="Whether or not to push the model to the Hub.", | ||
) | ||
parser.add_argument( | ||
"--hub_token", | ||
type=str, | ||
default=None, | ||
help="The token to use to push to the Model Hub.", | ||
) | ||
parser.add_argument( | ||
"--hub_model_id", | ||
type=str, | ||
default=None, | ||
help="The name of the repository to keep in sync with the local `output_dir`.", | ||
) | ||
parser.add_argument( | ||
"--allow_tf32", | ||
action="store_true", | ||
help=( | ||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" | ||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" | ||
), | ||
) | ||
parser.add_argument( | ||
"--report_to", | ||
type=str, | ||
default=None, | ||
help="If logging to wandb." | ||
) | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.") | ||
|
||
_get_model_args(parser) | ||
_get_dataset_args(parser) | ||
_get_training_args(parser) | ||
_get_validation_args(parser) | ||
_get_optimizer_args(parser) | ||
_get_configuration_args(parser) | ||
|
||
return parser.parse_args() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is for how many frames btw?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
37 frames (similar to the original example).