28 changes: 27 additions & 1 deletion torchtitan/distributed/deepep/deepep.py
@@ -317,6 +317,24 @@ def get_buffer(group: ProcessGroup, hidden_bytes: int) -> Buffer:
     return _buffer
 
+
+def _indices_dtype_by_sort_size(data: Tensor, sort_dim=-1) -> torch.dtype:
+    sort_size = data.size(sort_dim)
+    indices_dtype = torch.long
+    if sort_size - 1 <= torch.iinfo(torch.int8).max:
+        indices_dtype = torch.int8
+    elif sort_size - 1 <= torch.iinfo(torch.uint8).max:
+        indices_dtype = torch.uint8
+    elif sort_size - 1 <= torch.iinfo(torch.int16).max:
+        indices_dtype = torch.int16
+    elif sort_size - 1 <= torch.iinfo(torch.uint16).max:
+        indices_dtype = torch.uint16
+    elif sort_size - 1 <= torch.iinfo(torch.int32).max:
+        indices_dtype = torch.int32
+    else:
+        indices_dtype = torch.long
+    return indices_dtype
+
 
 def _permute_tokens(
     hidden_states: torch.Tensor,
     dispatched_indices: torch.Tensor,
@@ -343,7 +361,15 @@ def _permute_tokens(
     valid_scores = dispatched_scores[mask]
 
     # Repeat each token by its valid count and select tokens in expert order
-    sort_order = torch.argsort(valid_expert_ids, stable=True)
+
+    # Current torch indexing supports only int64 and int32 index tensors
+    # If other integer dtypes were supported for indexing by the torch native ops, the optimization below could be enabled
+    # indices_dtype = _indices_dtype_by_sort_size(valid_expert_ids) if valid_expert_ids.is_cuda or valid_expert_ids.is_cpu else torch.long
+    # indices_dtype = indices_dtype.to(torch.int32)  # an additional copy is required under the indexing dtype constraint (int32, int64)
+    indices_dtype = torch.int32  # Free performance improvement
+    sort_order = torch.empty((), device=valid_expert_ids.device, dtype=indices_dtype)
+    torch.argsort(valid_expert_ids, stable=True, out=sort_order)
+
     permuted_indices = torch.arange(
         len(hidden_states), device=hidden_states.device
     ).repeat_interleave(mask.sum(dim=1))[sort_order]
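For readers skimming the diff: `_indices_dtype_by_sort_size` simply picks the smallest integer dtype whose maximum value can hold the largest possible sort index, `sort_size - 1`. Below is a minimal standalone sketch of that mapping, runnable on stock PyTorch; the function is an illustrative rewrite that takes the size directly instead of a tensor, and the sample sizes are made up. Since indexing currently accepts only int32/int64 index tensors (per the comment in the diff), the PR pins the `out` dtype to int32 rather than using the smaller dtypes this helper can return.

```python
import torch


def indices_dtype_by_sort_size(sort_size: int) -> torch.dtype:
    # Smallest integer dtype whose max value can represent sort_size - 1,
    # mirroring the if/elif chain in _indices_dtype_by_sort_size above.
    for dtype in (torch.int8, torch.uint8, torch.int16, torch.uint16, torch.int32):
        if sort_size - 1 <= torch.iinfo(dtype).max:
            return dtype
    return torch.long


for n in (100, 300, 40_000, 1 << 20, 1 << 33):
    print(n, indices_dtype_by_sort_size(n))
# 100        -> torch.int8   (99 <= 127)
# 300        -> torch.int16  (299 > 255 but <= 32767)
# 40000      -> torch.uint16 (39999 > 32767 but <= 65535)
# 1048576    -> torch.int32
# 8589934592 -> torch.int64
```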
30 changes: 28 additions & 2 deletions torchtitan/models/moe/moe.py
@@ -356,6 +356,24 @@ def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
 
+
+def _indices_dtype_by_sort_size(data: Tensor, sort_dim=-1) -> torch.dtype:
Review comment (Contributor):
Can this type of util also be put into PyTorch core, or embedded into argsort?

Reply from @voidbag (Contributor, Author), Feb 9, 2026:
Thanks for your comments! I actually explored integrating this into the standard argsort utility during the initial design of the aten sort optimizations. However, maintainer feedback at the time emphasized that modifying the signatures of fundamental operators like sort/argsort posed backward-compatibility risks.

To align with that upstream preference for API stability, I decided to use the out variant instead. It allows us to support dynamic indices effectively without the friction of a core signature change. I'm open to moving this into a shared utility if we find a way to do so!

[Updated]
Here is the code snippet that performs this function for the down/up cast of the indices dtype (a sketch of the resulting out-variant usage, with a stock-PyTorch fallback, follows this file's diff):
https://github.com/voidbag/pytorch/blob/76357cc5d9e02e5ac830712bbe348444224df937/aten/src/ATen/native/Sorting.cpp#L966-L975

+    sort_size = data.size(sort_dim)
+    indices_dtype = torch.long
+    if sort_size - 1 <= torch.iinfo(torch.int8).max:
+        indices_dtype = torch.int8
+    elif sort_size - 1 <= torch.iinfo(torch.uint8).max:
+        indices_dtype = torch.uint8
+    elif sort_size - 1 <= torch.iinfo(torch.int16).max:
+        indices_dtype = torch.int16
+    elif sort_size - 1 <= torch.iinfo(torch.uint16).max:
+        indices_dtype = torch.uint16
+    elif sort_size - 1 <= torch.iinfo(torch.int32).max:
+        indices_dtype = torch.int32
+    else:
+        indices_dtype = torch.long
+    return indices_dtype
+
 
 # NOTE: the reason we make this a stateless module is to support
 # expert_tensor_parallel_degree=1 with consistent TP/EP APIs.
 class TokenReorderer(nn.Module):
@@ -403,9 +421,17 @@ def forward(
 
         # Reorder the token indices to match the order of the experts
        # token_indices_experts_sorted shape (bs*slen*top_k,)
-        token_indices_experts_sorted = torch.argsort(
-            selected_experts_indices.view(-1), stable=True
+        to_sort = selected_experts_indices.view(-1)
+        # Current torch indexing supports only int64 and int32 index tensors
+        # If other integer dtypes were supported for indexing by the torch native ops, the optimization below could be enabled
+        # indices_dtype = _indices_dtype_by_sort_size(to_sort) if to_sort.is_cuda or to_sort.is_cpu else torch.long
+        # indices_dtype = indices_dtype.to(torch.int32)  # an additional copy is required under the indexing dtype constraint (int32, int64)
+        indices_dtype = torch.int32  # Free performance improvement
+
+        token_indices_experts_sorted = torch.empty(
+            (), device=to_sort.device, dtype=indices_dtype
         )
+        torch.argsort(to_sort, stable=True, out=token_indices_experts_sorted)

         top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted]

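For context on the `out=` variant discussed in the review thread above: stock `torch.argsort` does not expose an `out` parameter, so the calls in this diff presumably depend on the aten sort change linked by the author (voidbag/pytorch, Sorting.cpp). The sketch below shows the intended pattern with a fallback that runs on stock PyTorch; the tensor shapes, the `try/except`, and every name not present in the diff are illustrative assumptions, not part of the PR.

```python
import torch

num_tokens, top_k = 8, 4
selected_experts_indices = torch.randint(0, 16, (num_tokens, top_k))
top_scores = torch.rand(num_tokens, top_k)
to_sort = selected_experts_indices.view(-1)

try:
    # Pattern assumed by this PR: a patched argsort writes the permutation
    # directly into a preallocated int32 tensor, so no int64 intermediate
    # (and no follow-up cast) is materialized.
    token_indices_experts_sorted = torch.empty(
        (), device=to_sort.device, dtype=torch.int32
    )
    torch.argsort(to_sort, stable=True, out=token_indices_experts_sorted)
except TypeError:
    # Stock-PyTorch fallback: argsort returns int64 indices; both int64 and
    # int32 index tensors are accepted by the advanced indexing below.
    token_indices_experts_sorted = torch.argsort(to_sort, stable=True)

top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted]
print(top_scores_experts_sorted.shape)  # torch.Size([32])
```

On builds that have the out variant, the win is skipping the int64 index allocation and the extra copy down to int32; on stock builds the path above degrades to today's behavior.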