[DRAFT] Optimize MoE Routing via torch.sort Indices DType Injection
#2343
base: main
Conversation
Integrates out-variant index dtype injection in MoE routing logic to reduce memory bandwidth and peak footprint. References pytorch/pytorch#170978.

Signed-off-by: Taeksang Kim <ts.kim@hyperaccel.ai>
Hi @voidbag! Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
def _indices_dtype_by_sort_size(data: Tensor, sort_dim=-1) -> torch.dtype:
Can this type of util also be put into PyTorch core, or embedded into argsort?
Thanks for your comments! I actually explored integrating this into the standard argsort utility during the initial design of the ATen sort optimizations. However, maintainer feedback at the time emphasized that modifying the signatures of fundamental operators like sort/argsort posed backward-compatibility risks.
To align with that upstream preference for API stability, I decided to use the out variant instead. It allows us to support dynamic indices effectively without the friction of a core signature change. I'm open to moving this into a shared utility if we find a way to do so!
[Updated]
Here is the code snippet implementing this function for the down/up cast case:
https://github.com/voidbag/pytorch/blob/76357cc5d9e02e5ac830712bbe348444224df937/aten/src/ATen/native/Sorting.cpp#L966-L975
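For reference, here is a minimal Python sketch of what such a helper could look like; the name matches the one in this PR, but the thresholds and short-circuit order are illustrative rather than copied from the linked C++:

```python
import torch
from torch import Tensor


def _indices_dtype_by_sort_size(data: Tensor, sort_dim: int = -1) -> torch.dtype:
    """Pick the narrowest integer dtype that can index every position
    along the sorted dimension. Thresholds are illustrative."""
    n = data.size(sort_dim)
    if n <= torch.iinfo(torch.int16).max:
        return torch.int16
    if n <= torch.iinfo(torch.int32).max:
        return torch.int32
    return torch.int64
```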
[DRAFT] Optimize MoE Routing via `torch.sort` Indices DType Injection

Important
This PR is a Draft and is blocked on the merge of upstream pytorch/pytorch#170978.
Description
This PR optimizes the Mixture of Experts (MoE) routing hot paths within `deepep.py` and `moe.py`. By utilizing the new out-variant index dtype injection introduced in the upstream PyTorch PR, we can significantly reduce the memory overhead and latency of token-to-expert assignment.

HBM and DDR memory are premium resources in modern training clusters. Traditionally, `torch.argsort` defaults to `int64` indices, which is overkill for most expert-routing scales. This patch allows us to use `int32` (and eventually smaller types), directly addressing the "memory wall" bottleneck in high-performance compute. Once PyTorch advanced indexing supports every integer type, MoE can enjoy the full benefits of `indices_dtype`.
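To make the hot path concrete, here is a simplified sketch of the usual token-to-expert grouping pattern; the variable names and top-1 routing are illustrative, not the exact `torchtitan` code:

```python
import torch

# scores: [num_tokens, num_experts] router logits; top-1 routing for brevity
scores = torch.randn(8192, 64)
expert_ids = scores.argmax(dim=-1)  # [num_tokens]

# Group tokens by expert: torch.sort/argsort allocate int64 indices by
# default, which is the 64-bit buffer this PR aims to shrink.
sorted_expert_ids, token_order = torch.sort(expert_ids)
print(token_order.dtype)  # torch.int64 on current releases
```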
📊 Detailed Performance Benchmark Results & Latency Analysis
Key Improvements

- Reduced the routing index dtype from `int64` to `int32`.
- Indices are produced directly as `int32`, ensuring the resulting indices are immediately usable in advanced indexing operations (e.g., token shuffling) without requiring a costly `.to(torch.long)` cast.

✅ Status & Todo
- `_indices_dtype_by_sort_size` logic into `torchtitan.distributed.deepep`.
- `TokenReorderer` in `torchtitan.models.moe`.

💡 Implementation Note
By using the `out=` variant with a pre-allocated `int32` tensor, we avoid the implicit allocation of a 64-bit index buffer. This is particularly beneficial for the large-scale token shuffles used in `torchtitan`, ensuring peak memory is kept to a minimum.
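A minimal sketch of this pattern, assuming the upstream PR lands (current PyTorch releases require the indices out-tensor of `sort` to be `int64`):

```python
import torch

expert_ids = torch.randint(0, 64, (8192,))

# Pre-allocate both out tensors; the int32 indices buffer halves index
# memory traffic relative to the default int64 allocation.
values = torch.empty_like(expert_ids)
indices = torch.empty(expert_ids.shape, dtype=torch.int32)

# Requires pytorch/pytorch#170978: current PyTorch releases expect the
# indices out-tensor to be int64 and would raise here.
torch.sort(expert_ids, out=(values, indices))
```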