10 changes: 9 additions & 1 deletion torchtitan/experiments/__init__.py
@@ -5,5 +5,13 @@
# LICENSE file in the root directory of this source tree.

_supported_experiments = frozenset(
["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
[
"flux",
"llama4",
"qwen3",
"simple_fsdp.llama3",
"simple_fsdp.deepseek_v3",
"vlm",
"gpt_oss",
]
)
19 changes: 19 additions & 0 deletions torchtitan/experiments/gpt_oss/README.md
@@ -0,0 +1,19 @@
# gpt-oss Model in torchtitan

## Quick Start
```bash
CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./run_train.sh
```

## Supported Features
- FSDP/HSDP, TP, EP, ETP
- Grouped matrix multiplication for efficient computation
- SwiGLU activation
- Multi-head attention with sliding-window mask and attention sink (see the sketch below)
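
The sliding-window part of the attention mask lends itself to FlexAttention's `mask_mod` interface. A minimal sketch (window size and sequence lengths are placeholders, not this model's settings; the learned attention-sink term is handled separately inside the model):

```python
from torch.nn.attention.flex_attention import create_block_mask

WINDOW = 128  # assumed sliding-window size, for illustration only

def sliding_window_causal(b, h, q_idx, kv_idx):
    # Attend only to past tokens within the window.
    causal = q_idx >= kv_idx
    in_window = (q_idx - kv_idx) <= WINDOW
    return causal & in_window

# Build a block mask for 1024 query/key tokens (use device="cuda" in practice).
block_mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu"
)
```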


## TODO
1. More parallelism support: CP, PP
2. Weight conversion to/from HF checkpoints (StateDictAdapter)
3. Forward parity verification
4. CI support
92 changes: 92 additions & 0 deletions torchtitan/experiments/gpt_oss/__init__.py
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.moe import MoEArgs

from torchtitan.protocols.train_spec import TrainSpec

from .infra.parallelize import parallelize_gptoss
from .model.args import GptOssModelArgs
from .model.model import GptOssModel

__all__ = [
"parallelize_gptoss",
"GptOssModelArgs",
"GptOssModel",
"gptoss_configs",
]


gptoss_configs = {
"debugmodel": GptOssModelArgs(
hidden_size=256,
num_hidden_layers=4,
moe_args=MoEArgs(
num_experts=8,
num_shared_experts=0,
score_func="softmax",
route_norm=False,
route_scale=1.0,
score_before_experts=False,
top_k=4,
use_grouped_mm=True,
load_balance_coeff=1e-3,
),
use_flex_attn=True,
attn_mask_type="causal",
),
"20b": GptOssModelArgs(
num_hidden_layers=24,
moe_args=MoEArgs(
num_experts=32,
num_shared_experts=0,
score_func="softmax",
route_norm=False,
route_scale=1.0,
score_before_experts=False,
top_k=4,
use_grouped_mm=True,
load_balance_coeff=1e-3,
),
),
"120b": GptOssModelArgs(
num_hidden_layers=36,
moe_args=MoEArgs(
num_experts=128,
num_shared_experts=0,
score_func="softmax",
route_norm=False,
route_scale=1.0,
score_before_experts=False,
top_k=4,
use_grouped_mm=True,
load_balance_coeff=1e-3,
),
),
}


def get_train_spec() -> TrainSpec:
return TrainSpec(
name="gpt_oss",
model_cls=GptOssModel,
model_args=gptoss_configs,
parallelize_fn=parallelize_gptoss,
pipelining_fn=None,
build_optimizers_fn=build_optimizers_with_moe_load_balancing,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
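
For orientation, a minimal sketch of reading the flavor table above; it only inspects the dataclasses defined in this file:

```python
from torchtitan.experiments.gpt_oss import get_train_spec, gptoss_configs

spec = get_train_spec()
print(spec.name)  # "gpt_oss"

debug_args = gptoss_configs["debugmodel"]
print(debug_args.num_hidden_layers)     # 4
print(debug_args.moe_args.num_experts)  # 8
print(debug_args.moe_args.top_k)        # 4
```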
194 changes: 194 additions & 0 deletions torchtitan/experiments/gpt_oss/infra/expert_parallel.py
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch
import torch.nn as nn
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torchtitan.distributed.expert_parallel import ExpertParallel


# implementation of Tensor Parallel for the GroupedExperts in MoE
class TensorParallel(ParallelStyle):
def _partition_fn(self, name, module, device_mesh):
module.register_parameter(
"mlp1_weight",
nn.Parameter(
distribute_tensor(module.mlp1_weight, device_mesh, [Shard(2)])
),
) # Column-wise sharding
module.register_parameter(
"mlp1_bias",
nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])),
) # Column-wise sharding
module.register_parameter(
"mlp2_weight",
nn.Parameter(
distribute_tensor(module.mlp2_weight, device_mesh, [Shard(1)])
),
) # Row-wise sharding
module.register_parameter(
"mlp2_bias",
nn.Parameter(
distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()])
),
) # Replicate

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
self._partition_fn,
)


# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ExpertParallel):
def __init__(
self,
tp_mesh: DeviceMesh,
ep_mesh: DeviceMesh,
):
super().__init__()
# TODO: we have to pass in the meshes in addition to the [ep, tp] device_mesh,
# as DeviceMesh doesn't support slicing from a submesh.
self.tp_mesh = tp_mesh
self.ep_mesh = ep_mesh

def _token_dispatch(self, mod, inputs, device_mesh):
# token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
return super()._token_dispatch(mod, inputs, self.ep_mesh)

def _partition_fn_2d(self, name, mod, ep_tp_mesh):
mod.register_parameter(
"mlp1_weight",
nn.Parameter(
distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(2)])
),
) # Column-wise sharding
mod.register_parameter(
"mlp1_bias",
nn.Parameter(
distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])
),
) # Column-wise sharding
mod.register_parameter(
"mlp2_weight",
nn.Parameter(
distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(1)])
),
) # Row-wise sharding
mod.register_parameter(
"mlp2_bias",
nn.Parameter(
distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()])
),
) # Replicate

def _token_combine(self, mod, routed_output, device_mesh):
# token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
return super()._token_combine(mod, routed_output, self.ep_mesh)

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
partition_fn=self._partition_fn_2d,
input_fn=self._token_dispatch,
output_fn=self._token_combine,
)
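
# Hypothetical wiring sketch (mesh names and sizes are assumptions, not from
# this file): constructing ExpertTensorParallel with explicit ep/tp submeshes,
# since the combined [ep, tp] mesh cannot be re-sliced inside the style.
#
#   from torch.distributed.device_mesh import init_device_mesh
#   from torch.distributed.tensor.parallel import parallelize_module
#
#   world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("ep", "tp"))
#   style = ExpertTensorParallel(tp_mesh=world_mesh["tp"], ep_mesh=world_mesh["ep"])
#   parallelize_module(experts, world_mesh["ep", "tp"], style)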


# TODO(jianiw): This needs to be merged with expert_parallel
def expert_parallel(func: Callable) -> Callable:
"""
This is a wrapper applied to the GroupedExperts computation, serving
the following three purposes:
1. Convert parameters from DTensors to plain Tensors, to work with
dynamic-shape inputs which cannot be easily expressed as DTensors.
2. In Expert Parallel, apply the generate_permute_indices kernel to
permute the inputs to be ordered by local experts (see the _token_dispatch
function in ExpertParallel) and permute the outputs back.
3. In order to use torch._grouped_mm, we need to make sure the number of
tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices
kernel also helps achieve this via padding, without incurring synchronization
between device and host. Note that this padding is unnecessary overhead when wrapping
the for-loop implementation of GroupedExperts, which does not need it.

Among the above:
1 and 2 are needed only when expert_parallel_degree > 1.
3 is needed even for single-device computation.
2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
"""

def wrapper(
mlp1_weight: torch.Tensor,
mlp1_bias: torch.Tensor,
mlp2_weight: torch.Tensor,
mlp2_bias: torch.Tensor,
swiglu_limit: float,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
if isinstance(mlp1_weight, DTensor):
mlp1_weight = mlp1_weight.to_local()
mlp1_bias = mlp1_bias.to_local()
mlp2_weight = mlp2_weight.to_local()
mlp2_bias = mlp2_bias.to_local()

if num_tokens_per_expert is not None:
from torchtitan.experiments.kernels.moe.indices import (
generate_permute_indices,
)

experts_per_ep_rank = mlp1_weight.shape[0]
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

ALIGN_SIZE_M = 16
with torch.no_grad():
(
permuted_indices,
num_tokens_per_expert,
_, # offsets,
) = generate_permute_indices(
num_tokens_per_expert,
experts_per_ep_rank,
num_ep_ranks,
x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M,
ALIGN_SIZE_M,
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]

out = func(
mlp1_weight,
mlp1_bias,
mlp2_weight,
mlp2_bias,
swiglu_limit,
x,
num_tokens_per_expert,
)

if num_tokens_per_expert is not None:
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]

return out

return wrapper
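
For reference, a minimal sketch of how the decorator above is meant to be applied; the function name and body here are placeholders, not this PR's experts implementation:

```python
import torch

from torchtitan.experiments.gpt_oss.infra.expert_parallel import expert_parallel


@expert_parallel
def run_experts(
    mlp1_weight: torch.Tensor,
    mlp1_bias: torch.Tensor,
    mlp2_weight: torch.Tensor,
    mlp2_bias: torch.Tensor,
    swiglu_limit: float,
    x: torch.Tensor,
    num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
    # Placeholder body: a real implementation would run per-expert grouped
    # matrix multiplications over the (already permuted and padded) tokens in x.
    return x
```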