[WIP] Experimental implementation of gpt-oss (grouped GEMM MoE + FlexAttention sink/sliding) #1559
Conversation
SequenceParallel,
)

if torch.__version__ >= "2.9":
We assume the user is on PyTorch nightly, so this won't be an issue.
return model


def apply_non_moe_tp(
There's a duplicated definition of apply_non_moe_tp.
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools.logging import logger

# from transformers.models.gpt_oss.modeling_gpt_oss import GPT_OSS_PRETRAINED_INIT_CONFIGURATION
Please remove this
from .moe import MoE

# TODO: may be able to remove this once parallelized properly
def convert_submodules_to_bf16(
What is this function for? Are you trying to apply mixed-precision training with it? torchtitan currently handles this in torchtitan/torchtitan/distributed/utils.py, line 217 in d240be0 (def maybe_enable_amp).
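For reference, a minimal sketch of the autocast-based pattern that comment points at; the function name mirrors the torchtitan reference above, but this signature is illustrative, not the actual one:

```python
import contextlib
import torch

def maybe_enable_amp(enabled: bool, dtype: torch.dtype = torch.bfloat16, device_type: str = "cuda"):
    # Illustrative sketch: return an autocast context when AMP is requested,
    # otherwise a no-op context. The real maybe_enable_amp in
    # torchtitan/distributed/utils.py takes different arguments.
    return torch.autocast(device_type=device_type, dtype=dtype) if enabled else contextlib.nullcontext()

# Usage: wrap the forward/loss computation instead of converting submodules to bf16.
# with maybe_enable_amp(True):
#     loss = model(batch).sum()
```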
This file could further be turned into a state_dict_adapter instead of a script, e.g. https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/models/deepseek_v3/model/state_dict_adapter.py.
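For illustration, a rough sketch of the adapter idea; the class name and the key mappings below are hypothetical, and the real torchtitan adapters follow the StateDictAdapter protocol shown in the link above:

```python
import re
import torch

class GptOssStateDictAdapter:
    """Hypothetical sketch: rename HF checkpoint keys to torchtitan keys."""

    # Regex pattern (HF name) -> replacement (torchtitan name); mappings are placeholders.
    from_hf_map = {
        r"model\.layers\.(\d+)\.self_attn\.q_proj\.weight": r"layers.\1.attention.wq.weight",
        # ... remaining mappings elided ...
    }

    def from_hf(self, hf_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        out = {}
        for name, tensor in hf_sd.items():
            for pattern, repl in self.from_hf_map.items():
                if re.fullmatch(pattern, name):
                    out[re.sub(pattern, repl, name)] = tensor
                    break
            else:
                out[name] = tensor  # keep unmapped keys as-is
        return out
```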
Thank you for the nice implementation! I learned quite a lot from this PR.
I left some comments on how to make it more consistent with existing torchtitan (sorry if some comments from @wwwjn and me are duplicates, as we reviewed independently). Let's discuss a way to move forward.
self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, dim * 2)))
self.mlp1_bias = nn.Parameter(torch.empty((num_experts, dim * 2)))
self.mlp2_weight = nn.Parameter(torch.empty((num_experts, dim, dim)))
self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim)))
The in dim and out dim are not clear from the naming here. Could you update this so that:
- it's clearer which is the in dim and which is the out dim, either by renaming the args or by adding comments;
- it follows the style of #1517 ([llama4] store expert weights such that we can transpose before grouped mm to have col-major memory layout), making the output dim first and transposing before the actual grouped_mm, for GEMM efficiency (see the sketch below).
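A small sketch of the #1517 convention being suggested, with torch.bmm standing in for the actual grouped GEMM (shapes here are made up for illustration):

```python
import torch
import torch.nn as nn

num_experts, dim, tokens_per_expert = 4, 64, 16

# Store weights output-dim-first: (num_experts, out_dim, in_dim).
mlp1_weight = nn.Parameter(torch.empty(num_experts, dim * 2, dim))
nn.init.trunc_normal_(mlp1_weight, std=0.02)

x = torch.randn(num_experts, tokens_per_expert, dim)  # tokens already grouped per expert

# Transpose to (num_experts, in_dim, out_dim) right before the matmul so the
# grouped GEMM sees a column-major-friendly layout; bmm is just a stand-in here.
h = torch.bmm(x, mlp1_weight.transpose(-2, -1))
assert h.shape == (num_experts, tokens_per_expert, dim * 2)
```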
I'm in the middle of a refactor #1569. We could do a rebase after it lands.
Can we reuse the one in torchtitan/components/optimizer.py?
if torch.__version__ >= "2.9":
    from torch.distributed.tensor.parallel import PrepareModuleInputOutput
else:
    print(f"Since torch version {torch.__version__} < 2.9, PrepareModuleInputOutput is not available and MoE EP TP will fail.")
Let's remove this, as we always assume the latest PyTorch.
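That is, assuming a recent nightly, the guarded import above would simply become:

```python
from torch.distributed.tensor.parallel import PrepareModuleInputOutput
```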
)


def apply_non_moe_tp(
It seems there are two functions with this name in this file.
_, H_kv, S_kv, _ = k.shape

# regular (no-sink) mask + no extra KV col
mask_key = (sliding_window, S_q, S_kv)
For inference this might change every time a new token is generated, but for training it will always be a cache hit after the first two layers, right? One is with sliding window, the other is without.
    mask_mod = FlexAttention._get_sliding_window_mask_mod(sliding_window)
else:
    mask_mod = FlexAttention._get_causal_mask_mod()
block_mask = create_block_mask(
One can follow the style of existing torchtitan code and create the block mask once per iteration, outside the transformer blocks (which are the components torch.compile applies to).
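A minimal sketch of that structure (names and shapes are illustrative, not the PR's API):

```python
from torch.nn.attention.flex_attention import create_block_mask

def causal_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

seq_len = 256
# Built once per training iteration, outside the (compiled) transformer blocks.
causal_block_mask = create_block_mask(
    causal_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cuda"
)
# The sliding-window mask would be built here the same way; each block then just
# consumes the precomputed mask it needs, e.g. block(x, block_mask=causal_block_mask).
```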
v: torch.Tensor,
scale: float | None = None,
sink_weights: torch.Tensor | None = None,
sliding_window: int = 0,
This could be None. In general, I feel we can refactor the code so that sliding_window is not passed into forward; instead, we can build the attention layer with the sliding-window info in it and create the mask there. Part of the motivation is to make it work with the CP implementation in #1707.
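Roughly what that refactor could look like (an illustrative sketch, not the PR's actual classes):

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim: int, n_heads: int, sliding_window: int | None = None):
        super().__init__()
        # The layer owns its sliding-window setting; None means full causal attention.
        self.sliding_window = sliding_window
        self.n_heads = n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        # ... wk, wv, wo elided ...

    def forward(self, x: torch.Tensor, block_mask=None) -> torch.Tensor:
        # The mask (sliding-window or causal) is built outside, per iteration, from
        # self.sliding_window and passed in; forward takes no sliding_window argument.
        ...
```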
w = sink_weights  # [H]
scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1)  # [B,H,S,1]
out = out * scale
oh very nice implementation using results from FlexAttention!
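For readers following along, a sketch of why this rescaling is exact, assuming the sink is a per-head learned logit $w$ that joins the softmax but contributes no value:

$$
\text{scale} = \frac{\sum_i e^{s_i}}{\sum_i e^{s_i} + e^{w}} = \frac{e^{\mathrm{lse}}}{e^{\mathrm{lse}} + e^{w}} = \sigma(\mathrm{lse} - w),
$$

so multiplying FlexAttention's output (a softmax over the real keys only) by this factor reproduces attention with the sink column included.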
self.attn_mask_type = attn_mask_type
self.fixed_block_size = fixed_block_size

self.mask_cache = {}
We can try to reuse FlexAttention.used_attn_mask_types.
# regular (no-sink) mask + no extra KV col
mask_key = (sliding_window, S_q, S_kv)
if mask_key not in self.mask_cache:
Sliding window can be used alone, without sink_weights, right? We should create the sliding-window mask the same way as the other masks.
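For example, it can be composed out of mask_mods so it goes through the same create_block_mask path as the causal mask (a sketch; and_masks is the combinator in torch.nn.attention.flex_attention):

```python
from torch.nn.attention.flex_attention import and_masks, create_block_mask

def causal_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def make_window_mod(window: int):
    def window_mod(b, h, q_idx, kv_idx):
        return q_idx - kv_idx < window
    return window_mod

sliding_causal_mod = and_masks(causal_mod, make_window_mod(128))
block_mask = create_block_mask(
    sliding_causal_mod, B=None, H=None, Q_LEN=256, KV_LEN=256, device="cuda"
)
```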
scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1)  # [B,H,S,1]
out = out * scale

out = out.to(q.dtype)
Why do we need this? Are we upcasting during the scaling?
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
Since you are here, maybe we should consider forcing all args after v to be keyword-only.
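That is, something like this, where the bare `*` is the only change (parameter names copied from the diff above, with sliding_window made Optional per the earlier comment; the class name is illustrative):

```python
import torch

class FlexAttentionWrapper(torch.nn.Module):
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,  # everything after v must now be passed by keyword
        scale: float | None = None,
        sink_weights: torch.Tensor | None = None,
        sliding_window: int | None = None,
    ) -> torch.Tensor:
        ...
```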
if self.use_flex_attn:
    self.attn = build_attention(True, model_args.attn_mask_type)
else:
    # NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed
Can you add more context? cc @drisspg
When will this PR be merged?
This PR currently implements gpt-oss in torchtitan such that the forward pass & sampling from 20B match the HuggingFace reference implementation. You can run torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py to verify this.

A few notes:
- Set use_grouped_mm=False to use the for-loop implementation.
- Set use_flex_attn=False for sampling (FlexAttention sampling still has an issue), though I intend to resolve this soon.

Thanks @Chillee for pointing me to FlexAttention and @vwxyzjn for tips on the DeepSeek reference implementation!
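For anyone trying the fallbacks above, a hypothetical sketch of how the two flags might be toggled; the actual model-args class and field locations in this PR may differ:

```python
from dataclasses import dataclass

@dataclass
class GptOssModelArgs:  # hypothetical name, for illustration only
    use_grouped_mm: bool = True   # False -> per-expert for-loop MoE (slower, easier to debug)
    use_flex_attn: bool = True    # False -> plain SDPA attention path (e.g. for sampling)

debug_args = GptOssModelArgs(use_grouped_mm=False, use_flex_attn=False)
```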