[feat] add Mochi-1 trainer #90
Conversation
torch.cuda.empty_cache()
torch.cuda.synchronize(accelerator.device)


def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
This probably needs to be revisited because we follow an inverse sigma scheme in Mochi-1, and the linear-quadratic sigma schedule perhaps needs to be incorporated as well.
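For reference, a minimal sketch of a linear-quadratic sigma schedule (the step split, `threshold_noise` value, and the final flip for the "inverse" convention are assumptions here, not necessarily what Mochi-1 uses):

```python
import numpy as np


def linear_quadratic_sigmas(num_steps, threshold_noise=0.025, linear_steps=None):
    """Sketch of a linear-quadratic schedule: linear ramp first, quadratic after,
    so early steps are packed closely together."""
    if linear_steps is None:
        linear_steps = num_steps // 2
    # Linear portion: evenly spaced values up to roughly `threshold_noise`.
    linear = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    # Quadratic portion: coefficients chosen so the curve joins the linear part
    # at `threshold_noise` and reaches 1.0 at the final step.
    quadratic_steps = num_steps - linear_steps
    diff = linear_steps - threshold_noise * num_steps
    quad_coef = diff / (linear_steps * quadratic_steps**2)
    lin_coef = threshold_noise / linear_steps - 2 * diff / quadratic_steps**2
    const = quad_coef * linear_steps**2
    quadratic = [quad_coef * i**2 + lin_coef * i + const for i in range(linear_steps, num_steps)]
    sigmas = np.array(linear + quadratic + [1.0])
    # Flip so sigmas run from 1.0 down to 0.0 (inverse-sigma convention, assumed).
    return 1.0 - sigmas


print(linear_quadratic_sigmas(8))
```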
I think we'll have to take a deeper look into this soon, because the current training runs don't seem to have worked: the output videos are random noise.
I'll try and take a better look soon as well!
Yes, exactly. I am looking into it too.
Let's jam on this and make the script work asap! Tysm for this
@@ -0,0 +1,474 @@
import argparse
Okay to have a separate file for now for faster iterations. We will be refactoring the repo with a more modular API in the future anyway
training/mochi-1/dataset.py (Outdated)
logger = get_logger(__name__)


# TODO (sayakpaul): probably not all buckets are needed for Mochi-1?
These are just the default buckets used when you don't specify any via CLI args, so I don't think we need to worry here.
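For illustration, the kind of CLI override being referred to could look roughly like this; the flag name and default bucket values below are made up, not the script's actual arguments:

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag: lets users trim the bucket list instead of relying on defaults.
parser.add_argument(
    "--frame_buckets",
    nargs="+",
    type=int,
    default=[16, 24, 32, 48, 64],
    help="Frame-count buckets videos are resized/padded into; override to keep only what Mochi-1 needs.",
)
args = parser.parse_args(["--frame_buckets", "37"])
print(args.frame_buckets)  # [37]
```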
r=args.rank,
lora_alpha=args.lora_alpha,
init_lora_weights=True,
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
Not required at the moment, and this comes more from the diffusion training community that regularly finetunes image models, but finetuning certain layers can make a model worse. Typically, you would want to understand what each layer does to the video by removing that layer and trying to run inference. We should try to thoroughly understand this for CogVideoX and Mochi, and find which layers it makes sense to finetune (and give users this configurability instead of finetuning all layers by default) for, say, aesthetics, new concepts, temporality improvements, stylized effects, etc. Would you be interested in doing this analysis for Mochi? I can take it up for Cog.
Later, yes. But we should first try to find the simplest reasonable setup. I will make it configurable right now.
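A minimal sketch of what that configurability might look like (the `--target_modules` flag name is an assumption; the default list just mirrors the attention projections already targeted in this PR):

```python
import argparse

from peft import LoraConfig

parser = argparse.ArgumentParser()
# Hypothetical flag: lets users choose which sub-modules receive LoRA adapters.
parser.add_argument(
    "--target_modules",
    nargs="+",
    type=str,
    default=["to_k", "to_q", "to_v", "to_out.0"],
    help="Module name suffixes to apply LoRA to (e.g. attention projections, feed-forward layers).",
)
parser.add_argument("--rank", type=int, default=16)
parser.add_argument("--lora_alpha", type=int, default=16)
args = parser.parse_args([])

# Build the LoRA config from CLI args instead of a hard-coded module list.
transformer_lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.lora_alpha,
    init_lora_weights=True,
    target_modules=args.target_modules,
)
```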
@a-r-r-o-w I have pushed a couple of updates that better reuse the existing [...]. I have also addressed some of your other comments. Will now look into [...].
I decided to pad videos that have fewer frames than the nearest frame bucket.
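A rough sketch of what that padding could look like, assuming a hypothetical `pad_to_nearest_bucket` helper and a repeat-last-frame strategy (the actual dataset code may do this differently):

```python
import torch


def pad_to_nearest_bucket(frames: torch.Tensor, frame_buckets: list) -> torch.Tensor:
    """Pad a video tensor of shape [F, C, H, W] up to the nearest frame bucket
    by repeating the last frame (sketch only)."""
    num_frames = frames.shape[0]
    # Pick the smallest bucket that can hold the video; fall back to the largest.
    eligible = [b for b in frame_buckets if b >= num_frames]
    target = min(eligible) if eligible else max(frame_buckets)
    if num_frames >= target:
        return frames[:target]
    pad = frames[-1:].repeat(target - num_frames, 1, 1, 1)
    return torch.cat([frames, pad], dim=0)


# Example: a 33-frame video padded up to a 37-frame bucket.
video = torch.randn(33, 3, 480, 848)
padded = pad_to_nearest_bucket(video, [37, 49, 61])
print(padded.shape)  # torch.Size([37, 3, 480, 848])
```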
@sayakpaul Did you manage to find a fix for the random noise in the output videos?
Yeah. Will push my updates soon |
Thanks Sayak! Just some minor comments and we should be good to merge
We can work on the refactors and generic-fication when we find the time later on
```
--validation_epochs 1 \
```

We haven't rigorously tested it, but without validation enabled, this script should run in under 40 GB of GPU VRAM.
This is for how many frames btw?
37 frames (similar to the original example).
Also, could you point me to the logs from your most promising run so far?
https://wandb.ai/sayakpaul/mochi-1-lora/runs/hu344t6o

Here's one run with the original trainer but on the same dataset: we can see that the loss dynamics are the same. Additionally, here's a sample intermediate derived with the original trainer: 0_1600.mp4

The one we see here is not too bad: https://wandb.ai/sayakpaul/mochi-1-lora/runs/hu344t6o. Of course, I suspect the other quality issues will go away once huggingface/diffusers#10033 is merged. But there is a blocker with that: https://huggingface.slack.com/archives/C065E480NN9/p1732808679447279?thread_ts=1732688413.727359&cid=C065E480NN9. LMK if you have any other questions.
Co-authored-by: Aryan <aryan@huggingface.co>
Wow, these look very promising! The quality will definitely improve once Dhruv's PR is in, both for training and inference. I think we're good to merge then.
@a-r-r-o-w thanks! Do you think it'd be prudent to get huggingface/diffusers#10031 in, as it was also critical?
A minimal and simple reimplementation of the Mochi-1 fine-tuner, but with `diffusers` and `peft`. Follow the `README.md` file added in this PR. Successful runs will be at https://wandb.ai/sayakpaul/mochi-1-lora.