feat(mindone/diffusers): enable multi-card inference for Flux2Pipeline (ZeRO-3 sharding) #1446
base: master
Conversation
Summary of Changes
Hello @Cui-yshoho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the MindSpore Diffusers library by integrating the FLUX.2 model with robust support for distributed inference. It addresses the memory demands of large models by leveraging ZeRO-3 sharding.
Code Review
This pull request introduces support for multi-card inference for the Flux2Pipeline using ZeRO-3 sharding, which is a significant feature given the large size of the FLUX.2 model. The changes are comprehensive, adding the necessary pipeline, transformer model, VAE, and LoRA loading components for the Flux2 architecture. The implementation appears well-structured and follows existing patterns in the codebase. I've identified a minor area for code simplification in one of the utility functions to improve readability and remove redundancy. Overall, this is a solid contribution.
| if "img" in modality_block_name: | ||
| # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v} | ||
| to_q_weight, to_k_weight, to_v_weight = mint.chunk(fused_qkv_weight, 3, dim=0) | ||
| new_q_name = "attn.to_q" | ||
| new_k_name = "attn.to_k" | ||
| new_v_name = "attn.to_v" | ||
| elif "txt" in modality_block_name: | ||
| # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj} | ||
| to_q_weight, to_k_weight, to_v_weight = mint.chunk(fused_qkv_weight, 3, dim=0) | ||
| new_q_name = "attn.add_q_proj" | ||
| new_k_name = "attn.add_k_proj" | ||
| new_v_name = "attn.add_v_proj" |
The mint.chunk operation to split fused_qkv_weight is performed at the beginning of the if "qkv" in within_block_name: block. The subsequent calls to mint.chunk within the if "img" in modality_block_name: and elif "txt" in modality_block_name: blocks are redundant as they re-assign the same chunked tensors. You can remove these redundant calls to simplify the code and improve clarity.
| if "img" in modality_block_name: | |
| # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v} | |
| to_q_weight, to_k_weight, to_v_weight = mint.chunk(fused_qkv_weight, 3, dim=0) | |
| new_q_name = "attn.to_q" | |
| new_k_name = "attn.to_k" | |
| new_v_name = "attn.to_v" | |
| elif "txt" in modality_block_name: | |
| # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj} | |
| to_q_weight, to_k_weight, to_v_weight = mint.chunk(fused_qkv_weight, 3, dim=0) | |
| new_q_name = "attn.add_q_proj" | |
| new_k_name = "attn.add_k_proj" | |
| new_v_name = "attn.add_v_proj" | |
| if "img" in modality_block_name: | |
| # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v} | |
| new_q_name = "attn.to_q" | |
| new_k_name = "attn.to_k" | |
| new_v_name = "attn.to_v" | |
| elif "txt" in modality_block_name: | |
| # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj} | |
| new_q_name = "attn.add_q_proj" | |
| new_k_name = "attn.add_k_proj" | |
| new_v_name = "attn.add_v_proj" |
Code Review
This pull request introduces support for multi-card inference for the Flux2Pipeline using ZeRO-3 sharding, which is a significant and well-implemented feature. The changes are comprehensive, including the addition of new models, pipelines, documentation, and utilities for LoRA and checkpoint conversion. The code is generally of high quality. I've provided a few minor suggestions for code cleanup, typo fixes, and clarification on certain assumptions to further improve the codebase. Great work on this complex feature!
```python
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    def __init__(self, *args, **kwargs):
        deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."  # noqa: E501
        deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
```
```python
if img_ids.ndim == 3:
    img_ids = img_ids[0]
if txt_ids.ndim == 3:
    txt_ids = txt_ids[0]
```
This logic assumes that if img_ids or txt_ids are batched (i.e., have 3 dimensions), all items in the batch share the same IDs, and it proceeds by just taking the first item (img_ids[0]). This could lead to unexpected behavior if the IDs actually differ across the batch. It would be safer to add a comment to clarify this behavior for future developers.
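For instance, a short note along these lines (wording is only a suggestion, not code from the PR) would make the assumption explicit:

```python
if img_ids.ndim == 3:
    # NOTE: assumes all samples in the batch share the same image positional IDs;
    # only the first item's IDs are used downstream (e.g., for the positional embeddings).
    img_ids = img_ids[0]
if txt_ids.ndim == 3:
    # Same assumption for the text token IDs.
    txt_ids = txt_ids[0]
```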
Force-pushed from 990e8a4 to d6af027.
What does this PR do?
Description
This PR introduces support for running `Flux2Pipeline` inference under a distributed setup in MindSpore.
Core Implementation & Rationale
Due to the large size of the FLUX.2 model weights, multi-card execution is mandatory, as the model cannot fit onto a single card.
The implementation achieves memory efficiency and parallelism by:
1. Initializing the process group and setting `DATA_PARALLEL` mode.
2. Using the `prepare_train_network` utility with `zero_stage=3` to apply ZeRO-3 sharding to the memory-heavy `transformer` and `text_encoder` modules.
3. The provided script (see "Sample Code Included" below) is a minimal working example of distributed inference.
Sample Code Included
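The PR's original example script is not reproduced here; the following is a minimal sketch of what such a script could look like. The import path and call signature of `prepare_train_network`, the model repo name, and the prompt are assumptions and may differ from the actual PR.

```python
import mindspore as ms
from mindspore.communication import get_group_size, get_rank, init

from mindone.diffusers import Flux2Pipeline  # pipeline added by this PR
from mindone.trainers.zero import prepare_train_network  # import path assumed

# 1. Initialize the process group and set DATA_PARALLEL mode.
ms.set_context(mode=ms.PYNATIVE_MODE)
init()
ms.set_auto_parallel_context(parallel_mode="data_parallel", device_num=get_group_size())

# Load the pipeline in half precision to reduce per-card memory pressure.
pipe = Flux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev",  # repo name assumed
    mindspore_dtype=ms.bfloat16,
)

# 2. Apply ZeRO-3 sharding to the memory-heavy modules.
#    NOTE: the real helper may require additional arguments (it is a training utility);
#    the calls below are simplified for illustration.
pipe.transformer = prepare_train_network(pipe.transformer, zero_stage=3)
pipe.text_encoder = prepare_train_network(pipe.text_encoder, zero_stage=3)

# 3. Run inference; each rank writes its own output file.
image = pipe(
    prompt="A photo of an astronaut riding a horse on Mars",
    num_inference_steps=28,
)[0][0]
image.save(f"flux2_zero3_rank{get_rank()}.png")
```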
Usage
To execute the provided script (e.g., saved as `net.py`), use the `msrun` launch utility. This example starts the script on two workers/cards:
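A typical two-card `msrun` invocation is shown below; the exact flags used in the PR's example may differ.

```shell
msrun --worker_num=2 --local_worker_num=2 --join=True net.py
```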
MindSpore 2.7.1 Warning: Users running on MindSpore version 2.7.1 might encounter an `AttributeError: 'NoneType' object has no attribute 'total_instance_count'`. This is a harmless warning that does not affect the final image output and will be resolved in a subsequent MindSpore release.
Weight Loading Optimization: If you experience slow weight loading times, we recommend merging the optimization introduced in PR #1422.
Before submitting
What's New. Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@xxx