[WIP] Experimental implementation of gpt-oss (grouped GEMM MoE + FlexAttention sink/sliding) #1559
Open: KhoomeiK wants to merge 8 commits into pytorch:main from KhoomeiK:khoomeik/upstream-tt-gpt-oss
Commits (8):
bc749ba  gptoss experimental support
6656665  Merge branch 'main' into khoomeik/upstream-tt-gpt-oss
c2c9ed1  clean up tentative licensing
b9d3196  training fixes: expert load balancing, TP for sinks + experts, EP wor…
4186ad8  only assert sdpa backends if using sdpa; improve conversion script
a489a13  fixed conversion script with param by param
6e2d96a  Merge branch 'main' into khoomeik/upstream-tt-gpt-oss
d0a54fb  new lse-based flexattn implementation for sinks
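
The last commit refers to an LSE-based FlexAttention path for attention sinks. As background, here is a minimal sketch of the rescaling trick (illustrative only, not the PR's code; shapes and the per-head sink parameter are assumptions). In the PR the output/log-sum-exp pair would presumably come from flex_attention(..., return_lse=True) rather than from the plain attention below.

# Sketch of LSE-based attention sinks (illustrative, not the PR's implementation).
# A sink is a virtual key with a learned logit and zero value: adding it only
# grows the softmax denominator, so the real-token output is rescaled by
# sigmoid(lse - sink_logit).
import torch

B, H, S, D = 2, 4, 128, 64                     # batch, heads, seq len, head dim (assumed)
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
sink_logit = torch.zeros(H)                    # learned per-head sink logit (assumed shape)

scores = q @ k.transpose(-2, -1) / D**0.5      # [B, H, S, S]
lse = torch.logsumexp(scores, dim=-1)          # [B, H, S]
out = torch.softmax(scores, dim=-1) @ v        # attention over real tokens only

# denominator becomes exp(lse) + exp(sink_logit) => scale by sigmoid(lse - sink_logit)
out = out * torch.sigmoid(lse - sink_logit.view(1, H, 1)).unsqueeze(-1)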
@@ -0,0 +1,55 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from .infra.optimizer import build_gptoss_optimizers

from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_gptoss
from .model.args import GptOssModelArgs
from .model.model import GptOssModel

__all__ = [
    "parallelize_gptoss",
    "GptOssModelArgs",
    "GptOssModel",
    "gptoss_configs",
]


gptoss_configs = {
    "debugmodel": GptOssModelArgs(
        hidden_size=256,
        num_hidden_layers=4,
        use_flex_attn=False,
        use_grouped_mm=False,
    ),
    "20b": GptOssModelArgs(
        num_hidden_layers=24,
        num_local_experts=32,
    ),
    "120b": GptOssModelArgs(
        num_hidden_layers=36,
        num_local_experts=128,
    ),
}


register_train_spec(
    TrainSpec(
        name="gpt_oss",
        cls=GptOssModel,
        config=gptoss_configs,
        parallelize_fn=parallelize_gptoss,
        pipelining_fn=None,
        build_optimizers_fn=build_gptoss_optimizers,  # use optimizer hooks to update expert weights
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
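
For orientation, a brief usage sketch (illustrative, not part of the PR): the registered name "gpt_oss" plus a flavor string selects one of the GptOssModelArgs entries above, and the TrainSpec tells the torchtitan trainer which build functions to use. The import path assumes gptoss_configs is exported from the torchtitan.experiments.gpt_oss package.

# Illustrative sketch only; flavor names index into gptoss_configs.
from torchtitan.experiments.gpt_oss import gptoss_configs

args = gptoss_configs["20b"]
print(args.num_hidden_layers, args.num_local_experts)  # 24 32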
torchtitan/experiments/gpt_oss/infra/expert_parallel.py (297 additions, 0 deletions)
@@ -0,0 +1,297 @@
from functools import partial
from typing import Callable

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._functional_collectives import all_to_all_single_autograd
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement


# implementation of Tensor Parallel for the GroupedExperts in MoE
class TensorParallel(ParallelStyle):
    def _partition_fn(self, name, module, device_mesh):
        module.register_parameter(
            "mlp1_weight",
            nn.Parameter(distribute_tensor(module.mlp1_weight, device_mesh, [Shard(2)])),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp1_bias",
            nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp2_weight",
            nn.Parameter(distribute_tensor(module.mlp2_weight, device_mesh, [Shard(1)])),
        )  # Row-wise sharding
        module.register_parameter(
            "mlp2_bias",
            nn.Parameter(distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()])),
        )  # Replicate

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
        )


# NOTE: This is to achieve replicate computation on the gate module in the MoE router.
# It does nothing other than (1) setting the module parameters as DTensors on the given mesh
# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back.
# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh,
# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation.
class NoParallel(ParallelStyle):
    def __init__(
        self,
        *,
        input_layout: Placement | None = None,
        output_layout: Placement | None = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layout = input_layout or Replicate()
        self.output_layout = output_layout or Replicate()
        self.desired_input_layout = Replicate()
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, (input_layout,), run_check=False
            )

        if input_layout != desired_input_layout:
            input_tensor = input_tensor.redistribute(
                placements=(desired_input_layout,), async_op=True
            )
        return (input_tensor, *inputs[1:])

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            None,
            partial(
                self._prepare_input_fn, self.input_layout, self.desired_input_layout
            ),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )


class ExpertParallel(ParallelStyle):
    def __init__(self):
        super().__init__()
        self.input_splits = None
        self.output_splits = None

    # performing all-to-all dispatch on the input
    def _token_dispatch(self, mod, inputs, device_mesh):
        # annotate module input placements/sharding with input_layouts
        routed_input, num_tokens_per_expert = inputs

        # generate the input splits and output splits for all-to-all
        with torch.no_grad():
            num_tokens_per_expert_group = num_tokens_per_expert.new_empty(
                num_tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                num_tokens_per_expert_group,
                num_tokens_per_expert,
                group=device_mesh.get_group(),
            )
            # NOTE: this would incur a device-to-host sync
            self.input_splits = (
                num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist()
            )
            self.output_splits = (
                num_tokens_per_expert_group.view(device_mesh.shape[0], -1)
                .sum(dim=1)
                .tolist()
            )

        # perform all-to-all
        routed_input = all_to_all_single_autograd(
            routed_input,
            self.output_splits,
            self.input_splits,
            device_mesh.get_group(),
        )

        # NOTE: After this all-to-all, the routed input is put on the proper EP rank.
        # However, num_tokens_per_expert_group is not of the final target format
        # [#tokens for local expert 0, #tokens for local expert 1, ...]
        # Rather, it is of the format
        # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
        #  #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
        # We need to perform another shuffle to get the correct format -- this is done via the function
        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
        # each expert gets locally is a multiple of ALIGN_SIZE_M.

        return routed_input, num_tokens_per_expert_group

    @staticmethod
    def _partition_fn(name, mod, device_mesh):
        # shard on the expert dimension
        for name, param in mod.named_parameters(recurse=False):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
            mod.register_parameter(name, dist_param)

    # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
        routed_output = all_to_all_single_autograd(
            routed_output,
            self.input_splits,
            self.output_splits,
            device_mesh.get_group(),
        )
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=ExpertParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )


# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ExpertParallel):
    def __init__(
        self,
        tp_mesh: DeviceMesh,
        ep_mesh: DeviceMesh,
    ):
        super().__init__()
        # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh,
        # as DeviceMesh doesn't support slicing from a submesh.
        self.tp_mesh = tp_mesh
        self.ep_mesh = ep_mesh

    def _token_dispatch(self, mod, inputs, device_mesh):
        # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
        return super()._token_dispatch(mod, inputs, self.ep_mesh)

    def _partition_fn_2d(self, name, mod, ep_tp_mesh):
        mod.register_parameter(
            "mlp1_weight",
            nn.Parameter(distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(2)])),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp1_bias",
            nn.Parameter(distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])),
        )  # Row-wise sharding
        mod.register_parameter(
            "mlp2_weight",
            nn.Parameter(distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)])),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp2_bias",
            nn.Parameter(distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Shard(1)])),
        )  # Row-wise sharding
Review comment (on lines +204 to +211): these two should be …
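
For intuition, a toy shape calculation (all dimensions invented, not taken from the PR) of what the mlp1 placements in _partition_fn_2d imply for per-rank shards on an [ep, tp] mesh:

# Toy illustration of 2-D expert/tensor sharding; every size here is made up.
num_experts, d_in, d_out = 32, 2880, 5760
ep_degree, tp_degree = 4, 2

# mlp1_weight assumed [num_experts, d_in, d_out]; [Shard(0), Shard(2)] splits
# experts across EP ranks and output columns across TP ranks.
local_mlp1_weight = (num_experts // ep_degree, d_in, d_out // tp_degree)
# mlp1_bias assumed [num_experts, d_out]; [Shard(0), Shard(1)] splits the same columns.
local_mlp1_bias = (num_experts // ep_degree, d_out // tp_degree)
print(local_mlp1_weight, local_mlp1_bias)  # (8, 2880, 2880) (8, 2880)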
    def _token_combine(self, mod, routed_output, device_mesh):
        # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
        return super()._token_combine(mod, routed_output, self.ep_mesh)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=self._partition_fn_2d,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )


def expert_parallel(func: Callable) -> Callable:
    """
    This is a wrapper applied to the GroupedExperts computation, serving
    the following three purposes:
    1. Convert parameters from DTensors to plain Tensors, to work with
       dynamic-shape inputs which cannot be easily expressed as DTensors.
    2. In Expert Parallel, apply the generate_permute_indices kernel to
       permute the inputs to be ordered by local experts (see the _token_dispatch
       function in ExpertParallel) and permute the outputs back.
    3. In order to use torch._grouped_mm, we need to make sure the number of
       tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices
       kernel also helps achieve this via padding, without incurring synchronization
       between device and host. Note that this will create side effects when wrapping
       the for-loop implementation of GroupedExperts, as it does not need padding.

    Among the above:
    1 and 2 are needed only when expert_parallel_degree > 1.
    3 is needed even for single-device computation.
    2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
    """

    def wrapper(
        mlp1_weight: torch.Tensor,
        mlp1_bias: torch.Tensor,
        mlp2_weight: torch.Tensor,
        mlp2_bias: torch.Tensor,
        x: torch.Tensor,
        num_tokens_per_expert: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if isinstance(mlp1_weight, DTensor):
            mlp1_weight = mlp1_weight.to_local()
            mlp1_bias = mlp1_bias.to_local()
            mlp2_weight = mlp2_weight.to_local()
            mlp2_bias = mlp2_bias.to_local()

        if num_tokens_per_expert is not None:
            from torchtitan.experiments.kernels.moe.indices import (
                generate_permute_indices,
            )

            experts_per_ep_rank = mlp1_weight.shape[0]
            num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

            ALIGN_SIZE_M = 16
            with torch.no_grad():
                (
                    permuted_indices,
                    num_tokens_per_expert,
                    _,  # offsets,
                ) = generate_permute_indices(
                    num_tokens_per_expert,
                    experts_per_ep_rank,
                    num_ep_ranks,
                    x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M,
                    ALIGN_SIZE_M,
                )

            x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
            input_shape = x.shape
            x = x[permuted_indices, :]

        out = func(mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias, x, num_tokens_per_expert)

        if num_tokens_per_expert is not None:
            out_unpermuted = out.new_empty(input_shape)
            out_unpermuted[permuted_indices, :] = out
            out = out_unpermuted[:-1]

        return out

    return wrapper
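
To make the dispatch bookkeeping concrete, here is a small single-process illustration (made-up counts, no process group) of how ExpertParallel._token_dispatch derives its all-to-all splits, and why the received counts still need the generate_permute_indices shuffle:

# Toy example of ExpertParallel's split bookkeeping; counts are invented.
import torch

ep_degree, experts_per_rank = 2, 2                   # 4 global experts, 2 per EP rank
num_tokens_per_expert = torch.tensor([3, 1, 4, 2])   # tokens this rank routed to experts 0..3

# tokens this rank sends to each EP rank (experts 0-1 live on rank 0, 2-3 on rank 1)
input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1).tolist()
print(input_splits)  # [4, 6]

# After the counts all-to-all, each rank holds counts grouped by *sender*:
# [expert 0 from rank 0, expert 1 from rank 0, expert 0 from rank 1, expert 1 from rank 1]
# so tokens for the same local expert are not contiguous -- hence the extra
# permutation (and ALIGN_SIZE_M padding) done by generate_permute_indices.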
Review comment: I'm in the middle of a refactor #1569. We could do a rebase after it lands.