diff --git a/torchtitan/distributed/deepep/deepep.py b/torchtitan/distributed/deepep/deepep.py
index bf15fc2b99..702709a032 100644
--- a/torchtitan/distributed/deepep/deepep.py
+++ b/torchtitan/distributed/deepep/deepep.py
@@ -317,6 +317,20 @@ def get_buffer(group: ProcessGroup, hidden_bytes: int) -> Buffer:
     return _buffer
 
 
+def _indices_dtype_by_sort_size(data: torch.Tensor, sort_dim: int = -1) -> torch.dtype:
+    """Return the smallest integer dtype able to index positions along ``sort_dim``.
+
+    Currently unused: torch.argsort only produces int64 indices, so this is
+    kept for when PyTorch supports narrower sort/index dtypes.
+    """
+    # Largest value an index along sort_dim can take.
+    max_index = data.size(sort_dim) - 1
+    for dtype in (torch.int8, torch.uint8, torch.int16, torch.uint16, torch.int32):
+        if max_index <= torch.iinfo(dtype).max:
+            return dtype
+    return torch.long
+
+
 def _permute_tokens(
     hidden_states: torch.Tensor,
     dispatched_indices: torch.Tensor,
@@ -343,7 +357,10 @@ def _permute_tokens(
     valid_scores = dispatched_scores[mask]
 
     # Repeat each token by its valid count and select tokens in expert order
-    sort_order = torch.argsort(valid_expert_ids, stable=True)
+    # NOTE: torch.argsort has no `out=` parameter and always returns int64
+    # indices, so a narrower index dtype (_indices_dtype_by_sort_size) cannot
+    # be requested here without an extra copy that would negate the win.
+    sort_order = torch.argsort(valid_expert_ids, stable=True)
     permuted_indices = torch.arange(
         len(hidden_states), device=hidden_states.device
     ).repeat_interleave(mask.sum(dim=1))[sort_order]
diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
index 39b8d4bc92..46e2e74481 100644
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -356,6 +356,20 @@ def init_weights(self, init_std: float):
     nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
 
 
+def _indices_dtype_by_sort_size(data: torch.Tensor, sort_dim: int = -1) -> torch.dtype:
+    """Return the smallest integer dtype able to index positions along ``sort_dim``.
+
+    Currently unused: torch.argsort only produces int64 indices, so this is
+    kept for when PyTorch supports narrower sort/index dtypes.
+    """
+    # Largest value an index along sort_dim can take.
+    max_index = data.size(sort_dim) - 1
+    for dtype in (torch.int8, torch.uint8, torch.int16, torch.uint16, torch.int32):
+        if max_index <= torch.iinfo(dtype).max:
+            return dtype
+    return torch.long
+
+
 # NOTE: the reason we make this a stateless module is to support
 # expert_tensor_parallel_degree=1 with consistent TP/EP APIs.
 class TokenReorderer(nn.Module):
@@ -403,9 +417,12 @@ def forward(
         # Reorder the token indices to match the order of the experts
         # token_indices_experts_sorted shape (bs*slen*top_k,)
+        # NOTE: torch.argsort has no `out=` parameter and always returns int64
+        # indices, so a narrower index dtype (_indices_dtype_by_sort_size)
+        # cannot be requested here without an extra copy.
         token_indices_experts_sorted = torch.argsort(
             selected_experts_indices.view(-1), stable=True
         )
         top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted]