6 changes: 3 additions & 3 deletions torchtitan/__init__.py
@@ -5,9 +5,9 @@
# LICENSE file in the root directory of this source tree.

# Import to register quantization modules.
import torchtitan.components.quantization # noqa: F401
#import torchtitan.components.quantization # noqa: F401

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.experiments # noqa: F401
import torchtitan.models # noqa: F401
#import torchtitan.experiments # noqa: F401
#import torchtitan.models # noqa: F401
8 changes: 4 additions & 4 deletions torchtitan/experiments/__init__.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torchtitan.experiments.llama4 # noqa: F401
import torchtitan.experiments.qwen3
import torchtitan.experiments.simple_fsdp # noqa: F401
import torchtitan.experiments.vlm # noqa: F401
#import torchtitan.experiments.llama4 # noqa: F401
#import torchtitan.experiments.qwen3
#import torchtitan.experiments.simple_fsdp # noqa: F401
#import torchtitan.experiments.vlm # noqa: F401
189 changes: 71 additions & 118 deletions torchtitan/experiments/deepseek_v3/model.py
@@ -44,22 +44,12 @@

from attn_mask_utils import _prepare_4d_causal_attention_mask

from group_gemms import (
DSGroupGEMM,
ManualLoopGroupGEMM,
TorchAOBF16GroupGEMM,
TorchBF16GroupGEMM,
TorchFP8GroupGEMM,
TritonCGBF16GroupGEMM,
)

from model_config import ModelArgs
from symm_mem_recipes import OnDeviceAllToAllV
from torch import nn
from torch.distributed._functional_collectives import all_to_all_single_autograd

from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import ALIGN_SIZE_M
ALIGN_SIZE_M = 8


# Get model parallel subgroup by name:
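The hard-coded `ALIGN_SIZE_M = 8` replaces the constant previously imported from the torchao grouped-GEMM kernels; it is later passed as `major_align` to `all_to_all_vdev_2d`, so every expert's segment in the gather buffer starts on an 8-row boundary, which is what the grouped GEMM expects. A minimal sketch of the alignment rule, with hypothetical token counts (the kernel performs this padding internally):

```python
# Illustrative only: pad each expert's received-token count so that every
# expert segment starts at a multiple of ALIGN_SIZE_M.
ALIGN_SIZE_M = 8

def align_up(n: int, align: int = ALIGN_SIZE_M) -> int:
    """Round n up to the next multiple of `align`."""
    return (n + align - 1) // align * align

recv_per_expert = [5, 13, 8]                      # hypothetical counts
padded = [align_up(n) for n in recv_per_expert]   # [8, 16, 8]
```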
@@ -472,11 +462,6 @@ class MoE(nn.Module):
token_send_buf: Optional[torch.Tensor] = None
token_gather_buf: Optional[torch.Tensor] = None

# Group GEMM strategies
group_gemm_strategies = None
# which group gemm to use?
group_mm = "manual" # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch", , "torchao", "tritoncg", "manual"]

def __init__(self, config):
super().__init__()
self.config = config
@@ -512,50 +497,8 @@ def __init__(self, config):
config=config, intermediate_size=intermediate_size
)

# Group Gemm
# Initialize group GEMM strategies if not already loaded
if MoE.group_gemm_strategies is None:
MoE._initialize_group_gemm_strategies()

assert (
MoE.group_mm in MoE.group_gemm_strategies
), f"selected group gemm {self.group_mm} is not available!"
# keep active gg ready
self.group_gemm_instance = MoE.group_gemm_strategies[MoE.group_mm]
self._buffer_initialized = False

@classmethod
def _initialize_group_gemm_strategies(cls):
"""Initialize available group GEMM strategies"""
cls.group_gemm_strategies = {
# torch._group_MM
"torch": TorchBF16GroupGEMM(MLP.act_fn),
# torch.mm with looping
"manual": ManualLoopGroupGEMM(MLP.act_fn),
"torchao": (
TorchAOBF16GroupGEMM(MLP.act_fn)
if TorchAOBF16GroupGEMM.is_available()
else None
),
"torchfp8": (
TorchFP8GroupGEMM(MLP.act_fn)
if TorchFP8GroupGEMM.is_available()
else None
),
"dsgemm": (
DSGroupGEMM(MLP.act_fn, use_triton_quant=True)
if DSGroupGEMM.is_available()
else None
),
"tritoncg": (
TritonCGBF16GroupGEMM(
MLP.act_fn,
)
if TritonCGBF16GroupGEMM.is_available()
else None
),
}

def combine_experts(self, submod_name: str):
all_weights = []
for expert in self.experts.values():
@@ -565,12 +508,7 @@ def combine_experts(self, submod_name: str):
lin.weight = None

# let the group gemm strategy prep the final weight layout
combined_weight = self.group_gemm_instance.arrange_expert_weights(
all_weights, submod_name, self
)

if combined_weight is None:
raise NotImplementedError("expert weights not handled by group gemmm")
combined_weight = torch.stack(all_weights)

self.register_parameter(f"{submod_name}_weight", nn.Parameter(combined_weight))
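With the group-GEMM strategy objects removed, `combine_experts` now just stacks the per-expert `nn.Linear` weights into a single 3-D parameter, which is the layout `torch._grouped_mm` consumes further down. A small shape-level sketch, with assumed sizes:

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration: 4 local experts, hidden=16, intermediate=32.
experts = [nn.Linear(16, 32, bias=False) for _ in range(4)]

all_weights = [e.weight for e in experts]   # each [32, 16] = [out, in]
combined = torch.stack(all_weights)         # [4, 32, 16] = [n_experts, out, in]

# Registered as e.g. "gate_proj_weight"; it is transposed to
# [n_experts, in, out] right before the grouped GEMM call.
```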

@@ -599,10 +537,17 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
if MoE.token_send_buf is not None:
return

self.group_name = self.ep_group.group_name

symm_mem.set_backend("NVSHMEM")
symm_mem.enable_symm_mem_for_group("0")
symm_mem.enable_symm_mem_for_group(self.group_name)

# Input buffer for DP-to-EP shuffle
MoE.token_send_buf = symm_mem.empty(
self.config.max_seq_len
* self.num_experts_per_tok, # seq len * top k (flattened)
* self.num_experts_per_tok # seq len * top k (flattened)
* overflow,
self.config.hidden_size, # hidden dim
dtype=dtype,
device=device,
@@ -617,6 +562,15 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
device=device,
)

nsplits = self.config.n_routed_experts
MoE.in_splits = symm_mem.empty(nsplits, dtype=torch.int64, device=device)
MoE.out_splits_offsets = symm_mem.empty(
(2, nsplits), dtype=torch.int64, device=device
)
MoE.combine_out_splits_offsets = symm_mem.empty(
(2, nsplits), dtype=torch.int64, device=device
)
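Besides the token buffers, `setup_symm_mem` now allocates three small int64 buffers that carry the split metadata for the NVSHMEM shuffles: `in_splits` (how many tokens this rank sends to each global expert) and two `(2, n_routed_experts)` buffers whose rows, going by how they are used below, hold split sizes and start offsets for the dispatch and combine all-to-alls. A hedged standalone sketch of the allocation pattern; the module path and setup are assumptions, and it would have to run inside an initialized multi-GPU job with NVSHMEM available:

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem  # assumed module path

dist.init_process_group("nccl")
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
torch.cuda.set_device(device)

n_routed_experts = 8  # assumed global expert count

symm_mem.set_backend("NVSHMEM")
# WORLD stands in for the expert-parallel group here.
symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)

# Tokens this rank sends to each global expert.
in_splits = symm_mem.empty(n_routed_experts, dtype=torch.int64, device=device)
# Row 0: split sizes, row 1: start offsets, one entry per (local expert, source rank) chunk.
out_splits_offsets = symm_mem.empty((2, n_routed_experts), dtype=torch.int64, device=device)
combine_out_splits_offsets = symm_mem.empty((2, n_routed_experts), dtype=torch.int64, device=device)
```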

def get_send_buf(self):
# [Why detach?] During a first forward-backward step, the buffer would
# be included in a computational graph. In a second step, autograd will
@@ -791,13 +745,37 @@ def sort_tokens(self, x, topk_ids, topk_weights):
def _run_group_gemm(self, contig_tokens, m_sizes, m_offsets):
"""Run the appropriate group GEMM implementation based on configuration"""

try:
return self.group_gemm_strategies[self.group_mm].execute(
contig_tokens, m_sizes, m_offsets, self
)
except Exception as e:
# Flag the error
print(f"Error using {self.group_mm} strategy: {e}")
# Get weights
w_gate = self.get_parameter("gate_proj_weight")
w_up = self.get_parameter("up_proj_weight")
w_down = self.get_parameter("down_proj_weight")

# Run first two GEMMs (gate and up projections)
gate_proj = torch._grouped_mm(
contig_tokens,
w_gate.transpose(-2, -1),
m_offsets,
out_dtype=torch.bfloat16,
)
up_proj = torch._grouped_mm(
contig_tokens,
w_up.transpose(-2, -1),
m_offsets,
out_dtype=torch.bfloat16,
)

# Apply activation
hidden_outputs = self.activation_function(gate_proj) * up_proj

# Run the third GEMM (down projection)
hidden_outputs = torch._grouped_mm(
hidden_outputs,
w_down.transpose(-2, -1),
m_offsets,
out_dtype=torch.bfloat16,
)

return hidden_outputs
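`_run_group_gemm` no longer dispatches to a pluggable strategy; it calls `torch._grouped_mm` directly, with the stacked expert weights transposed to `[n_experts, in, out]` and `m_offsets` holding the cumulative end row of each expert's token segment (int32, aligned to `ALIGN_SIZE_M`). A minimal shape-level sketch of one such call; the sizes are assumptions and it presumes a GPU/build where `torch._grouped_mm` supports bf16:

```python
import torch

# Assumed sizes: 3 local experts, hidden=16, intermediate=32, and 8 token
# rows per expert segment (already aligned to ALIGN_SIZE_M).
tokens    = torch.randn(24, 16, dtype=torch.bfloat16, device="cuda")       # [T, hidden]
w_gate    = torch.randn(3, 32, 16, dtype=torch.bfloat16, device="cuda")    # [E, out, in]
m_offsets = torch.tensor([8, 16, 24], dtype=torch.int32, device="cuda")    # cumulative end rows

# Rows [0:8) use expert 0's weight, [8:16) expert 1, [16:24) expert 2.
gate = torch._grouped_mm(
    tokens, w_gate.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16
)
print(gate.shape)  # torch.Size([24, 32])
```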

def moe_on_device(self, x, topk_ids, topk_weight):
(
@@ -814,65 +792,40 @@ def moe_on_device(self, x, topk_ids, topk_weight):
# band", which is not part of the actual data. Thus no gradient is
# needed.

# Sum the tokens over local experts, then we get tokens per EP rank,
# which is the input splits
with torch.no_grad():
tokens_per_expert_group = tokens_per_expert.new_empty(
tokens_per_expert.shape[0]
)
dist.all_to_all_single(
tokens_per_expert_group, tokens_per_expert, group=self.ep_group
)
input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
MoE.in_splits.copy_(tokens_per_expert.view(-1))

# Move input to the `token_send_buf` symm mem
token_send_buf = self.get_send_buf()
token_send_buf[: token_indices.shape[0]].copy_(sorted_tokens)
# Note: `out=` avoids copy, but it is not differentiable
# torch.index_select(x, 0, idxs // topk_ids.shape[1], out=token_send_buf[: idxs.shape[0]])
token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
token_send_buf,
input_splits,
self.ep_group,
)
token_gather_buf = self.get_gather_buf()

# We need to permute the received tokens so that tokens for the same expert are contiguous.
# This part prepares a 1D tensor `permuted_indices` for such permutation.
# This part doesn't need gradient.
with torch.no_grad():
permuted_indices, m_sizes, m_offsets = generate_permute_indices(
tokens_per_expert_group,
self.experts_per_rank,
self.ep_size,
token_gather_buf.shape[0],
ALIGN_SIZE_M,
)
# Dispatch the tokens
torch.ops.symm_mem.all_to_all_vdev_2d(
token_send_buf, token_gather_buf, MoE.in_splits, MoE.out_splits_offsets, self.group_name, major_align=ALIGN_SIZE_M
)

# Permute the received tokens so that tokens for the same expert are contiguous.
contig_tokens = token_gather_buf[permuted_indices]
m_offsets = torch.empty(self.experts_per_rank, dtype=MoE.in_splits.dtype, device=MoE.in_splits.device)
exclusive_offsets = MoE.out_splits_offsets[1].view(self.experts_per_rank, -1)
m_offsets[:-1].copy_(exclusive_offsets[1: , 0])
m_offsets[-1] = exclusive_offsets[-1, 0] + MoE.in_splits[-1]
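The `m_offsets` needed by the grouped GEMM are rebuilt from the offsets row of `out_splits_offsets`: viewed as `[experts_per_rank, ep_size]`, each entry is the start offset of one (local expert, source rank) chunk, so every expert except the last ends where the next expert's first chunk starts, and the last expert's end is patched separately. A tiny worked example of that indexing, with hypothetical numbers:

```python
import torch

# experts_per_rank = 2, ep_size = 2, so out_splits_offsets[1] would hold
# 4 start offsets, one per (local expert, source rank) chunk.
offsets_row = torch.tensor([0, 8, 24, 32], dtype=torch.int64)
exclusive_offsets = offsets_row.view(2, 2)   # [[0, 8], [24, 32]]

# Every expert except the last ends where the next expert's first chunk starts:
print(exclusive_offsets[1:, 0])              # tensor([24]) -> end offset for expert 0
# The last expert's end is formed above from exclusive_offsets[-1, 0]
# plus a final split length taken from MoE.in_splits.
```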

# group gemm - handle all three group gemms (up, gate, down for all experts)
hidden_outputs = self._run_group_gemm(
contig_tokens,
m_sizes,
m_offsets,
processed_tokens = self._run_group_gemm(
token_gather_buf,
MoE.out_splits_offsets[0],
m_offsets.to(torch.int32),
)

# Prepare buffer for tokens processed by experts
processed_tokens = self.get_gather_buf()

# Move into Symmetric Memory for the return shuffle
processed_tokens[permuted_indices] = hidden_outputs
token_gather_buf.copy_(processed_tokens)

# Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
# The input/output splits are just a reverse of the previous shuffle.
token_return_buf, _ = OnDeviceAllToAllV.apply(
processed_tokens,
output_splits,
self.ep_group,
# Combine the tokens
# `out_splits_offsets` from shuffle is exactly the `input_splits_offsets` for combine
# `out` data from shuffle is exactly the `input` data for combine
torch.ops.symm_mem.all_to_all_vdev_2d_offset(
token_gather_buf, token_send_buf, MoE.out_splits_offsets, MoE.combine_out_splits_offsets, self.group_name
)

returned_tokens = token_return_buf[:seqlen_sorted_tokens]
returned_tokens = token_send_buf[:seqlen_sorted_tokens]
output_tokens = torch.empty_like(returned_tokens)
output_tokens[token_indices] = returned_tokens
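After the combine shuffle, `token_send_buf` holds the processed tokens back in the sorted (expert-major) order, and the final index assignment with `token_indices` restores the original token order. A tiny self-contained sketch of that unpermute pattern:

```python
import torch

# Toy unpermute: sorted row i originally lived at position token_indices[i].
sorted_vals   = torch.tensor([10.0, 20.0, 30.0, 40.0])
token_indices = torch.tensor([2, 0, 3, 1])

restored = torch.empty_like(sorted_vals)
restored[token_indices] = sorted_vals
print(restored)  # tensor([20., 40., 10., 30.])
```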
