
Conversation

KhoomeiK

This PR currently implements gpt-oss in torchtitan such that the forward pass & sampling from the 20B checkpoint match the HuggingFace reference implementation. You can run torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py to verify this.

A few notes:

  • It defaults to torch 2.9 grouped GEMMs for the MoE. If you can't upgrade to torch 2.9, please set use_grouped_mm=False to use the for-loop implementation (see the config sketch after this list).
  • It defaults to my FlexAttention implementation of sinks + sliding window, which currently runs into CUDA errors when sampling more than one token. If you care about sampling, please set use_flex_attn=False; I intend to resolve this soon.
  • Parallelism is unimplemented, so training is untested as of now. This is obviously my highest priority, but I wanted to get a baseline implementation out for the community.
  • The 120B checkpoint remains untested, but architecture differences are minimal so I don't foresee any issues.
  • I haven't yet spent any time adding support for the tiktokenizer; please use the HuggingFace tokenizer for now.
  • I still have some open questions regarding mixed precision training and what the best numerics recipe is here.
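
For reference, here is a minimal sketch of how the two fallback toggles mentioned above might be wired. The dataclass name and field defaults are illustrative assumptions, not the PR's actual code; only the flag names mirror the ones referenced in this description.

from dataclasses import dataclass

@dataclass
class GptOssModelArgs:  # hypothetical name, for illustration only
    use_grouped_mm: bool = True   # torch 2.9 grouped GEMMs for the MoE experts
    use_flex_attn: bool = True    # FlexAttention sinks + sliding window

# Pinned to torch < 2.9? Fall back to the for-loop MoE:
args = GptOssModelArgs(use_grouped_mm=False)

# Need multi-token sampling today? Fall back to the non-Flex attention path:
args = GptOssModelArgs(use_flex_attn=False)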

Thanks @Chillee for pointing me to FlexAttention and @vwxyzjn for tips on the DeepSeek reference implementation!


meta-cla bot commented Aug 13, 2025

Hi @KhoomeiK!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

SequenceParallel,
)

if torch.__version__ >= "2.9":
Contributor

We assume users are on the PyTorch nightly build, so this won't be an issue.

return model


def apply_non_moe_tp(
Contributor

There's a duplicated definition of apply_non_moe_tp.

from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools.logging import logger

# from transformers.models.gpt_oss.modeling_gpt_oss import GPT_OSS_PRETRAINED_INIT_CONFIGURATION
Contributor

Please remove this

from .moe import MoE

# TODO: may be able to remove this once parallelized properly
def convert_submodules_to_bf16(
Contributor

What is this function for? Are you trying to apply mixed precision training by adding this function? Currently titan handles this in maybe_enable_amp, and in train.py: https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L460.
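
For context, here is a generic torch.autocast sketch of the AMP route. This is the standard PyTorch pattern, not necessarily the exact wiring in maybe_enable_amp / train.py, and the function and argument names are illustrative.

import torch

def training_step(model, batch, loss_fn, optimizer):
    # Run only the forward pass under autocast instead of converting
    # submodules to bf16 by hand.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(batch["input"])
        loss = loss_fn(output, batch["label"])
    loss.backward()          # backward runs outside the autocast region
    optimizer.step()
    optimizer.zero_grad()
    return loss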

Contributor

This file could further be turned into a state_dict_adapter instead of a script, e.g. https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/models/deepseek_v3/model/state_dict_adapter.py.

Contributor
@tianyu-l tianyu-l left a comment

Thank you for the nice implementation! I learned quite a lot from this PR.

I left some comments on how to make it more consistent with existing torchtitan (sorry if some comments from @wwwjn and me are duplicates, as we reviewed independently). Let's discuss a way to move forward.

Comment on lines +32 to +35
self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, dim * 2)))
self.mlp1_bias = nn.Parameter(torch.empty((num_experts, dim * 2)))
self.mlp2_weight = nn.Parameter(torch.empty((num_experts, dim, dim)))
self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim)))
Contributor

The in_dim and out_dim are not clear from the naming here. Could you update so that

  1. it's clearer which is the in dim and which is the out dim, either by renaming the args or by adding comments
  2. it follows the style of "[llama4] store expert weights such that we can transpose before grouped mm to have col-major memory layout" #1517, i.e. make the output dim first and transpose before the actual grouped_mm, for GEMM efficiency (see the sketch below)
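
A minimal sketch of the suggested layout, with illustrative names and shapes, and a plain batched matmul standing in for the grouped GEMM:

import torch
import torch.nn as nn

num_experts, in_dim, out_dim, tokens_per_expert = 4, 64, 128, 16

# Store expert weights as (num_experts, out_dim, in_dim) ...
w1 = nn.Parameter(torch.empty(num_experts, out_dim, in_dim))
nn.init.trunc_normal_(w1, std=0.02)

# ... and transpose right before the matmul, so the kernel sees the
# column-major-friendly layout described in #1517.
x = torch.randn(num_experts, tokens_per_expert, in_dim)  # tokens grouped per expert
h = torch.bmm(x, w1.transpose(-2, -1))                   # (num_experts, tokens, out_dim)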

Contributor

I'm in the middle of a refactor #1569. We could do a rebase after it lands.

Contributor

Can we reuse the one in torchtitan/components/optimizer.py?

Comment on lines +13 to +16
if torch.__version__ >= "2.9":
from torch.distributed.tensor.parallel import PrepareModuleInputOutput
else:
print(f"Since torch version {torch.__version__} < 2.9, PrepareModuleInputOutput is not available and MoE EP TP will fail.")
Contributor

Let's remove this, as we always assume the latest PyTorch.

)


def apply_non_moe_tp(
Contributor

It seems there are two functions with this name in this file.

_, H_kv, S_kv, _ = k.shape

# regular (no-sink) mask + no extra KV col
mask_key = (sliding_window, S_q, S_kv)
Contributor

For inference this might change every time a new token is generated, but for training it will always be a cache hit after the first two layers, right? One is with sliding window, the other is without.

mask_mod = FlexAttention._get_sliding_window_mask_mod(sliding_window)
else:
mask_mod = FlexAttention._get_causal_mask_mod()
block_mask = create_block_mask(
Contributor

One can follow the style of existing torchtitan code to create the block mask once for each iteration, outside the transformer blocks (which are the components torch.compile applies to).
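
A sketch of that pattern, assuming two mask variants (causal and sliding-window) built once per iteration and reused by every layer; the mask_mod helpers and the 128-token window are illustrative:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window_mask(window: int):
    def mask_mod(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx < window)
    return mask_mod

seq_len = 1024
masks = {
    "causal": create_block_mask(causal_mask, None, None, seq_len, seq_len),
    "sliding": create_block_mask(sliding_window_mask(128), None, None, seq_len, seq_len),
}

# Each attention layer then just picks its precomputed mask:
# out = flex_attention(q, k, v, block_mask=masks["sliding" if use_swa else "causal"])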

v: torch.Tensor,
scale: float | None = None,
sink_weights: torch.Tensor | None = None,
sliding_window: int = 0,
Contributor

This could be None.
In general, I feel we can refactor the code so that sliding_window isn't passed in during forward; instead, we can build the attention layer with the sliding window info in it and create the mask there.

Part of the motivation is to make it work with the CP implementation in #1707.
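
A rough sketch of that refactor, with illustrative names: the window size lives on the module, and the per-iteration block mask is passed in rather than rebuilt inside forward.

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import BlockMask, flex_attention

class FlexAttentionLayer(nn.Module):
    def __init__(self, sliding_window: int | None = None):
        super().__init__()
        # None means plain causal attention; otherwise the layer owns its window size.
        self.sliding_window = sliding_window

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: BlockMask | None = None,
    ) -> torch.Tensor:
        # The block mask is built once per iteration from self.sliding_window,
        # outside the transformer blocks, and simply consumed here.
        return flex_attention(q, k, v, block_mask=block_mask)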

Comment on lines +124 to +126
w = sink_weights # [H]
scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1) # [B,H,S,1]
out = out * scale
Contributor

oh very nice implementation using results from FlexAttention!
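
For readers following along: a sink with per-head logit w adds exp(w) to the softmax denominator, so the corrected output is the plain attention output scaled by exp(lse) / (exp(lse) + exp(w)) = sigmoid(lse - w). A standalone sketch of that correction using FlexAttention's return_lse, with arbitrary example shapes:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
sink_weights = torch.randn(H, device="cuda")  # one learned sink logit per head

out, lse = flex_attention(q, k, v, return_lse=True)       # lse: [B, H, S]
scale = torch.sigmoid(lse - sink_weights.view(1, -1, 1))   # [B, H, S]
out = out * scale.unsqueeze(-1).to(out.dtype)              # rescale as above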

self.attn_mask_type = attn_mask_type
self.fixed_block_size = fixed_block_size

self.mask_cache = {}
Contributor

We can try to reuse FlexAttention.used_attn_mask_types.


# regular (no-sink) mask + no extra KV col
mask_key = (sliding_window, S_q, S_kv)
if mask_key not in self.mask_cache:
Contributor

Sliding window can be used alone without sink_weights, right? We should create the sliding window mask the same way as the other masks.

scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1) # [B,H,S,1]
out = out * scale

out = out.to(q.dtype)
Contributor

Why do we need this? Do we upcast during the scaling?

self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
Contributor

Since you are here, maybe we should consider forcing all args after v to be keyword-only (see the sketch below).
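
For concreteness, a sketch of what that signature change could look like; the parameter names are copied from the surrounding snippet, and the bare * is what makes everything after v keyword-only.

import torch

class Attention:
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,  # everything below must be passed by keyword
        scale: float | None = None,
        sink_weights: torch.Tensor | None = None,
        sliding_window: int | None = None,
    ) -> torch.Tensor:
        ...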

if self.use_flex_attn:
self.attn = build_attention(True, model_args.attn_mask_type)
else:
# NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed
Contributor

Can you add more context? cc @drisspg


Opdoop commented Sep 23, 2025

When will this PR be merged?

@meta-cla meta-cla bot added the CLA Signed label on Sep 25, 2025