Conversation

@voidbag (Contributor) commented Feb 8, 2026

[DRAFT] Optimize MoE Routing via torch.sort Indices DType Injection

Important

This PR is a Draft and is blocked by the merging of upstream pytorch/pytorch#170978.

Description

This PR optimizes the Mixture of Experts (MoE) routing hot paths in deepep.py and moe.py. By using the new out-variant index dtype injection introduced in the upstream PyTorch PR, we can significantly reduce the memory overhead and latency of token-to-expert assignment.

HBM and DDR memory are premium resources in modern training clusters. Traditionally, torch.argsort defaults to int64 indices, which is overkill for most expert routing scales. This patch allows us to use int32 (and eventually smaller types), directly addressing the "memory wall" bottleneck in high-performance compute.

Once PyTorch advanced indexing supports every integer index dtype, MoE routing will be able to take full advantage of indices_dtype.

📊 Detailed Performance Benchmark Results & Latency Analysis


Key Improvements

  • Memory Efficiency: Reduces the memory footprint of routing sort indices by 50% (2x improvement) by switching from int64 to int32.
  • Index Compatibility: Uses int32 to ensure the resulting indices are immediately usable in advanced indexing operations (e.g., token shuffling) without requiring a costly .to(torch.long) cast.
  • Throughput: Up to 2.24x speedup and up to 3.33x memory footprint savings in the benchmarks above, by minimizing data movement across the memory bus.
  • Zero-Overhead Scaling: Enables "free" performance gains by matching index precision to the actual requirements of the sort dimension.

✅ Status & Todo

  • Integrate _indices_dtype_by_sort_size logic into torchtitan.distributed.deepep.
  • Update TokenReorderer in torchtitan.models.moe.
  • Blocked by: pytorch/pytorch#170978
  • (optional) Once the sort change is merged, topk can be extended to support indices_dtype in its out-variant op.

💡 Implementation Note

By using the out= variant with a pre-allocated int32 tensor, we avoid the implicit allocation of a 64-bit index buffer. This is particularly beneficial for the large-scale token shuffles used in torchtitan, keeping peak memory to a minimum.
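A minimal sketch of the out= pattern described above. Note that accepting an int32 indices buffer depends on the upstream pytorch/pytorch#170978 being merged; on stock PyTorch the out= indices must be int64, so this sketch falls back accordingly:

```python
import torch

x = torch.randn(4, 1024)

# With pytorch/pytorch#170978, the out= variant of torch.sort accepts an
# int32 indices buffer; stock PyTorch requires int64, hence the fallback.
values = torch.empty_like(x)
indices = torch.empty(x.shape, dtype=torch.int32)
try:
    torch.sort(x, dim=-1, out=(values, indices))
except RuntimeError:
    # Upstream PR not merged yet: retry with the default int64 indices.
    indices = torch.empty(x.shape, dtype=torch.int64)
    torch.sort(x, dim=-1, out=(values, indices))
```

Pre-allocating both output buffers also lets callers reuse them across routing steps, so no per-step index allocation happens at all.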

Integrates out-variant index dtype injection in MoE routing logic
to reduce memory bandwidth and peak footprint.

References pytorch/pytorch#170978

Signed-off-by: Taeksang Kim <ts.kim@hyperaccel.ai>
@meta-cla bot commented Feb 8, 2026

Hi @voidbag!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@meta-cla bot commented Feb 8, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 8, 2026
nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)


def _indices_dtype_by_sort_size(data: Tensor, sort_dim=-1) -> torch.dtype:
Contributor
Can this type of util also be put into pytorch core, or embedded into argsort?

@voidbag (Contributor, Author) replied Feb 9, 2026

Thanks for your comments! I actually explored integrating this into the standard argsort utility during the initial design of the aten sort optimizations. However, maintainer feedback at the time emphasized that modifying the signatures of fundamental operators like sort/argsort posed backward-compatibility risks.

To align with that upstream preference for API stability, I decided to use the out variant instead. It allows us to support dynamic indices effectively without the friction of a core signature change. I'm open to moving this into a shared utility if we find a way to do so!

[Updated]
Here is the code snippet implementing this (handling the down/up cast case):
https://github.com/voidbag/pytorch/blob/76357cc5d9e02e5ac830712bbe348444224df937/aten/src/ATen/native/Sorting.cpp#L966-L975
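The linked C++ snippet roughly corresponds to the following Python-level behavior: sort with the kernel's native int64 indices, then cast down to the requested dtype. This is a sketch of the fallback semantics only, not the actual aten code:

```python
import torch


def sort_with_indices_dtype(x, dim=-1, indices_dtype=torch.int32):
    # Down-cast fallback sketch: compute the sort with the kernel's
    # native int64 indices, then cast to the caller-requested dtype.
    # The aten fast path instead writes the narrow dtype directly.
    values, indices = torch.sort(x, dim=dim)
    return values, indices.to(indices_dtype)
```

The fast path avoids the temporary int64 buffer entirely; the cast shown here only matters when the kernel cannot emit the narrow dtype natively.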

