
[feat] add Mochi-1 trainer #90

Merged
merged 31 commits into from
Nov 29, 2024

Conversation


@sayakpaul sayakpaul commented Nov 17, 2024

A minimal and simple reimplementation of the Mochi-1 fine-tuner, built with diffusers and peft.

Follow the README.md file added in this PR.

Successful runs will be at https://wandb.ai/sayakpaul/mochi-1-lora.

@sayakpaul sayakpaul requested a review from a-r-r-o-w November 17, 2024 04:49
@sayakpaul sayakpaul changed the title some minor updates [feat] add Mochi-1 trainer Nov 18, 2024
torch.cuda.empty_cache()
torch.cuda.synchronize(accelerator.device)

def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
sayakpaul (Collaborator, Author)

This probably needs to be revisited because Mochi-1 follows an inverted sigma scheme, and the linear-quadratic sigma schedule may also need to be incorporated.
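For context, here is a minimal, hypothetical sketch of what a linear-quadratic sigma schedule could look like: linear spacing up to a small `threshold_noise`, then a quadratic ramp to 1.0. The function and parameter names are illustrative only and not taken from this PR or the original Mochi-1 repo.

```python
def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None):
    """Illustrative sigma schedule: linear for the first `linear_steps`,
    then quadratic up to 1.0. Not the exact schedule from the Mochi-1 repo."""
    if linear_steps is None:
        linear_steps = num_steps // 2
    # Linear segment: evenly spaced sigmas from 0 up to (but excluding) threshold_noise.
    linear = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    # Quadratic segment: ramps from threshold_noise to exactly 1.0.
    quadratic_steps = num_steps - linear_steps
    quadratic = [
        threshold_noise + (1.0 - threshold_noise) * (i / quadratic_steps) ** 2
        for i in range(1, quadratic_steps + 1)
    ]
    return linear + quadratic
```

Since Mochi-1 uses an inverted sigma convention, the schedule above would still need to be flipped, e.g. `[1.0 - s for s in sigmas]`, before being used for training or sampling.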

a-r-r-o-w (Owner)

I think we'll have to take a deeper look into this soon, because the current training runs don't seem to have worked: the output videos are random noise.

I'll try and take a better look soon as well!

sayakpaul (Collaborator, Author)

Yes, exactly. I am looking into it too.

@a-r-r-o-w a-r-r-o-w (Owner) left a comment

Let's jam on this and make the script work asap! Tysm for this

@@ -0,0 +1,474 @@
import argparse
a-r-r-o-w (Owner)

Okay to have a separate file for now for faster iterations. We will be refactoring the repo with a more modular API in the future anyway.


logger = get_logger(__name__)

# TODO (sayakpaul): probably not all buckets are needed for Mochi-1?
a-r-r-o-w (Owner)

These are just default buckets used when you don't specify any via CLI args, so I don't think we need to worry here.

training/mochi-1/dataset.py (outdated, resolved)
training/mochi-1/dataset.py (outdated, resolved)
training/mochi-1/prepare_dataset.py (outdated, resolved)
r=args.rank,
lora_alpha=args.lora_alpha,
init_lora_weights=True,
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
a-r-r-o-w (Owner)

Not required at the moment, and this comes more from the diffusion training community that regularly fine-tunes image models, but fine-tuning certain layers can make a model worse. Typically, you'd want to understand what each layer contributes to the video by ablating it (removing the layer and running inference). We should try to thoroughly understand this for CogVideoX and Mochi, and find which layers it makes sense to fine-tune for, say, aesthetics, new concepts, temporality improvements, stylized effects, etc. — and give users this configurability instead of fine-tuning all layers by default. Would you be interested in doing this analysis for Mochi? I can take it up for Cog.

sayakpaul (Collaborator, Author)

Later, yes. First, though, we should try to find the simplest setup that is still reasonable. I will make it configurable right now.
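A minimal sketch of what "making it configurable" could look like: parse a comma-separated CLI value and fall back to the attention projections currently hardcoded in the PR. The flag name `--lora_target_modules` and the helper are hypothetical, not taken from this PR.

```python
# Default matches the attention projections hardcoded in the PR's LoraConfig.
DEFAULT_TARGET_MODULES = ["to_k", "to_q", "to_v", "to_out.0"]

def parse_target_modules(cli_value=None):
    """Parse a comma-separated --lora_target_modules value (hypothetical flag);
    fall back to the default attention projections when unset."""
    if not cli_value:
        return list(DEFAULT_TARGET_MODULES)
    return [m.strip() for m in cli_value.split(",") if m.strip()]
```

The resulting list would then be passed as `target_modules=` to the `LoraConfig` shown above.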

training/mochi-1/text_to_video_lora.py (outdated, resolved)
training/mochi-1/args.py (outdated, resolved)
@sayakpaul (Collaborator, Author)

@a-r-r-o-w I have pushed a couple of updates that better reuse the existing dataset.py module and remove a significant number of lines of code. I had to write a fresh prepare_dataset.py, though, because we have some major differences, mainly regarding prompt_attention_mask.

I have also addressed some of your other comments.

Will now look into sigmas.
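For readers following along, a rough sketch of why prompt_attention_mask matters when precomputing prompt embeddings: padded token positions should not contribute, so their embedding rows get zeroed out. This is a pure-Python stand-in for illustration (the helper name is hypothetical; real code would operate on torch tensors).

```python
def mask_prompt_embeds(prompt_embeds, prompt_attention_mask):
    """Zero out embedding rows at padded token positions (mask value 0).

    prompt_embeds: list of per-token embedding rows (lists of floats).
    prompt_attention_mask: list of 0/1 flags, 1 = real token, 0 = padding.
    """
    return [
        row if keep else [0.0] * len(row)
        for row, keep in zip(prompt_embeds, prompt_attention_mask)
    ]
```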

@sayakpaul (Collaborator, Author)

I decided to pad videos that have fewer frames than the nearest frame bucket.
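A minimal sketch of that padding strategy, under the assumption that padding is done by repeating the last frame and that videos longer than every bucket are truncated (the function name and the truncation fallback are illustrative, not confirmed by this PR):

```python
def pad_frames_to_bucket(frames, frame_buckets):
    """Pad a frame list (by repeating the last frame) up to the smallest
    bucket that can hold it; truncate to the largest bucket otherwise."""
    eligible = [b for b in sorted(frame_buckets) if b >= len(frames)]
    if not eligible:
        # Video is longer than every bucket: truncate (assumed fallback).
        return frames[: max(frame_buckets)]
    target = eligible[0]
    return frames + [frames[-1]] * (target - len(frames))
```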

@TrickyBarrel

@sayakpaul Did you manage to find a fix for the random noise in the output videos?

@sayakpaul (Collaborator, Author)

Yeah. Will push my updates soon.

@sayakpaul sayakpaul marked this pull request as ready for review November 29, 2024 05:19
@a-r-r-o-w a-r-r-o-w (Owner) left a comment

Thanks Sayak! Just some minor comments and we should be good to merge

We can work on the refactors and generic-fication when we find the time later on.

training/mochi-1/args.py (outdated, resolved)
- --validation_epochs 1 \
```

We haven't tested rigorously, but with validation disabled, this script should run in under 40 GB of GPU VRAM.
a-r-r-o-w (Owner)

This is for how many frames btw?

sayakpaul (Collaborator, Author)

37 frames (similar to the original example).

training/mochi-1/train.sh (outdated, resolved)
training/mochi-1/text_to_video_lora.py (outdated, resolved)
training/mochi-1/text_to_video_lora.py (outdated, resolved)
training/mochi-1/text_to_video_lora.py (outdated, resolved)
training/mochi-1/text_to_video_lora.py (resolved)
@a-r-r-o-w (Owner)

Also, could you point me to the logs from your most promising run so far?

@sayakpaul (Collaborator, Author)

> Also could you point me to the logs from your most promising run so far?

https://wandb.ai/sayakpaul/mochi-1-lora/runs/hu344t6o

Here's one run with the original trainer but on the same dataset:
https://wandb.ai/sayakpaul/mochi-1-lora/runs/de4iwnvf

We can see that the loss dynamics are the same. Additionally, here's a sample intermediate result produced with the original trainer:

0_1600.mp4

The one we see here is not too bad: https://wandb.ai/sayakpaul/mochi-1-lora/runs/hu344t6o. Of course, I suspect the other quality issues will go away once huggingface/diffusers#10033 is merged. But there is a blocker with that: https://huggingface.slack.com/archives/C065E480NN9/p1732808679447279?thread_ts=1732688413.727359&cid=C065E480NN9.

LMK if you have any other questions.

@a-r-r-o-w (Owner)

Wow, these look very promising! The quality will definitely improve once Dhruv's PR is in, both for training and inference. I think we're good to merge then.

@sayakpaul (Collaborator, Author)

@a-r-r-o-w thanks! Do you think it'd be prudent to get huggingface/diffusers#10031 in, as it was also critical?

@sayakpaul sayakpaul merged commit d10963f into main Nov 29, 2024
@sayakpaul sayakpaul deleted the mochi-1-tuner branch November 29, 2024 09:26