
fix(gemma3): use GeGLU activation instead of SwiGLU#1825

Open
leofan-lab wants to merge 2 commits into THUDM:main from leofan-lab:fix/gemma3-activation-geglu

Conversation

@leofan-lab

Problem

The upstream mbridge base config (_build_base_config) hardcodes activation_func=F.silu with gated_linear_unit=True, which yields SwiGLU. However, the entire Gemma family (1/2/3/4) uses GeGLU (gelu_pytorch_tanh combined with a GLU).

This causes incorrect MLP outputs when running Gemma3 through Megatron's native MLP path.
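For intuition (this sketch is not from the PR), the two activations diverge noticeably away from zero, so a gated MLP pretrained with one produces wrong outputs when run with the other. A minimal pure-Python comparison, using the standard SiLU formula and the tanh approximation of GELU that gelu_pytorch_tanh implements:

```python
import math

def silu(x):
    # SwiGLU's activation: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def gelu_tanh(x):
    # gelu_pytorch_tanh: tanh approximation of GELU, as used by Gemma
    return 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

# The functions agree only near the origin; e.g. at x=1.0,
# silu gives ~0.731 while gelu_tanh gives ~0.841.
for x in (-2.0, -1.0, 1.0, 2.0):
    print(f"x={x:+.1f}  silu={silu(x):+.4f}  gelu_tanh={gelu_tanh(x):+.4f}")
```

A per-element gap of this size, applied across every gated MLP in every layer, is consistent with the large loss gap reported below.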

Evidence

Discovered during Gemma4 parity testing against HuggingFace:

  • Before fix: 19.4% loss gap, per-layer MLP cosine similarity = 0.81
  • After fix: 0.29% loss gap, per-layer cosine = 0.9999, all 60 layers match

Sources confirming GeGLU for Gemma:

  • Google blog: "approximated GeGLU non-linearity"
  • HF config default: hidden_activation='gelu_pytorch_tanh' for Gemma 1/2/3/4
  • megatron-bridge v0.3.0 (the successor to mbridge) already uses fast_gelu for Gemma3

Fix

Adds slime_plugins/mbridge/gemma3.py that overrides _build_config to use functools.partial(F.gelu, approximate="tanh") instead of F.silu.
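The override can be sketched as follows. This is a hypothetical, torch-free illustration: the class names (BaseBridge, Gemma3Bridge) and the dict-shaped config are stand-ins for the actual mbridge API, and the pure-Python gelu_tanh stands in for functools.partial(F.gelu, approximate="tanh").

```python
import math

def gelu_tanh(x):
    # Stand-in for functools.partial(F.gelu, approximate="tanh")
    return 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

class BaseBridge:
    """Stand-in for the upstream mbridge base class."""
    def _build_base_config(self):
        # Upstream default: SwiGLU (silu activation + gated linear unit)
        return {
            "activation_func": lambda x: x / (1.0 + math.exp(-x)),
            "gated_linear_unit": True,
        }

class Gemma3Bridge(BaseBridge):
    """Stand-in for the plugin added in slime_plugins/mbridge/gemma3.py."""
    def _build_config(self):
        cfg = self._build_base_config()
        # Gemma uses GeGLU: keep the GLU gating, swap silu for tanh-gelu
        cfg["activation_func"] = gelu_tanh
        return cfg
```

The key point is that only the activation function changes; gated_linear_unit stays True, so the MLP remains a GLU and checkpoint weights load unchanged.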

leofan-lab and others added 2 commits April 10, 2026 08:33
The upstream mbridge base config hardcodes F.silu (SwiGLU) for all models,
but the entire Gemma family (1/2/3/4) uses gelu_pytorch_tanh (GeGLU).

This causes incorrect MLP outputs when running Gemma3 through Megatron's
native MLP path, as the activation function doesn't match what the model
was pretrained with.

Evidence from Gemma4 parity testing:
- Before fix: 19.4% loss gap vs HuggingFace, cos=0.81 per-layer MLP output
- After fix: 0.29% loss gap, cos=0.9999 per-layer, all 60 layers match

Sources confirming GeGLU for Gemma:
- Google blog: developers.googleblog.com/en/gemma-explained-new-in-gemma-2/
- HF config: hidden_activation='gelu_pytorch_tanh' for Gemma 1/2/3/4
- New megatron-bridge (v0.3.0) already uses fast_gelu for Gemma3
