[JAX] Fused layers argument default values changed #2347
base: main
Conversation
Greptile Overview
Greptile Summary
This PR aligns JAX layer default arguments with the PyTorch implementation by changing three key defaults:
Default Argument Changes:
- Activation function: Changed from `relu` to `gelu` in `LayerNormMLP` and `TransformerLayer`
- Intermediate dropout: Changed from `0.1` to `0.0` in `LayerNormMLP` and `TransformerLayer`
- Return layernorm output: Changed from `True` to `False` in `LayerNormDenseGeneral` and `LayerNormMLP`
Key Benefits:
- Enables fused codepath by default (via `return_layernorm_output=False`)
- Removes unnecessary dropout layer that doesn't exist in the PyTorch implementation
- Aligns activation function with PyTorch's default `gelu`
- Improves cross-framework consistency
Additional Changes:
- Added new JAX quickstart tutorial (notebook and utilities)
- Updated CI workflow to use specific PyTorch CUDA 13.0 index
The changes are well-documented in the PR description and align with the stated goal of matching PyTorch defaults while enabling better performance through fused operations.
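For users who relied on the previous defaults, the old behavior can still be requested explicitly. Below is a minimal sketch; only `activations`, `intermediate_dropout_rate`, and `return_layernorm_output` are confirmed by this PR, while `intermediate_dim` is an assumed argument name used for illustration.

```python
# Hedged sketch: pinning the pre-PR defaults on LayerNormMLP so upgrading past this
# change does not alter model behavior. `intermediate_dim` is an assumed argument
# name; the three pinned arguments below are the ones whose defaults this PR changes.
import transformer_engine.jax.flax as te_flax

mlp = te_flax.LayerNormMLP(
    intermediate_dim=4096,          # assumed name for the MLP hidden width
    activations=("relu",),          # old default; new default is ("gelu",)
    intermediate_dropout_rate=0.1,  # old default; new default is 0.0
    return_layernorm_output=True,   # old default; new default is False
)
# Initialization and application then follow the usual Flax pattern
# (mlp.init(...) / mlp.apply(...)).
```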
Confidence Score: 5/5
- This PR is safe to merge with no identified issues
- All changes are straightforward default argument updates that improve JAX/PyTorch consistency. The PyTorch implementation was verified to use `gelu` activation and the changes align with the PR description. The new tutorial files are documentation additions that pose no risk. No logical errors or breaking changes detected.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/flax/module.py | 5/5 | Changed default args in LayerNormDenseGeneral and LayerNormMLP: return_layernorm_output from True to False, activations from relu to gelu, and intermediate_dropout_rate from 0.1 to 0.0 to match PyTorch defaults |
| transformer_engine/jax/flax/transformer.py | 5/5 | Changed TransformerLayer defaults: intermediate_dropout from 0.1 to 0.0 and mlp_activations from relu to gelu to align with PyTorch implementation |
Sequence Diagram
```mermaid
sequenceDiagram
    participant User
    participant JAX Layer
    participant LayerNormMLP
    participant LayerNormDenseGeneral
    participant TransformerLayer
    User->>JAX Layer: Initialize with default args
    Note over JAX Layer: Before PR: activation='relu'<br/>intermediate_dropout=0.1<br/>return_layernorm_output=True
    User->>JAX Layer: Initialize with default args (After PR)
    Note over JAX Layer: After PR: activation='gelu'<br/>intermediate_dropout=0.0<br/>return_layernorm_output=False
    JAX Layer->>LayerNormMLP: Apply with new defaults
    Note over LayerNormMLP: activations=('gelu',)<br/>intermediate_dropout_rate=0.0<br/>return_layernorm_output=False
    JAX Layer->>LayerNormDenseGeneral: Apply with new defaults
    Note over LayerNormDenseGeneral: return_layernorm_output=False<br/>Enables fused codepath
    JAX Layer->>TransformerLayer: Apply with new defaults
    Note over TransformerLayer: mlp_activations=('gelu',)<br/>intermediate_dropout=0.0<br/>Matches PyTorch behavior
```
1 file reviewed, no comments
LGTM pending CI, thanks for making this change! Since this PR is smaller, it is likely to be ready to merge before the getting started tutorial, so you may want to rebase this on main instead.
```diff
 mha_kernel_init: Initializer = None
 mlp_kernel_init: Initializer = None
-mlp_activations: Sequence[str] = ("relu",)
+mlp_activations: Sequence[str] = ("gelu",)
```
TE/PyTorch seems to not have a default activation type and afaict requires the user to explicitly set it
https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L204
For now, we can keep this PR as is to make the activation default consistent across TE/JAX, but in the future we could try making it a required argument with no default.
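If the activation ever becomes a required argument, a Flax module can express that simply by declaring the field without a default. The sketch below is purely illustrative: a toy module showing the dataclass mechanics, not the actual TE/JAX class.

```python
# Illustrative only: a toy Flax module where the activation has no default, so the
# caller must choose one explicitly. Field names mimic the TE/JAX layer, but this
# is not the real implementation.
from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp

class ToyLayerNormMLP(nn.Module):
    intermediate_dim: int        # required: no default
    activations: Sequence[str]   # required: caller must pick, e.g. ("gelu",)
    intermediate_dropout_rate: float = 0.0

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        h = nn.LayerNorm()(x)
        h = nn.Dense(self.intermediate_dim)(h)
        act = {"gelu": nn.gelu, "relu": nn.relu}[self.activations[0]]
        h = act(h)
        h = nn.Dropout(self.intermediate_dropout_rate)(h, deterministic=deterministic)
        return nn.Dense(x.shape[-1])(h)

layer = ToyLayerNormMLP(intermediate_dim=256, activations=("gelu",))
x = jnp.ones((2, 16, 64))
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)
```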
…fter FC1 to 0, and return_layernorm_output to False Signed-off-by: tdophung <tdophung@nvidia.com>
af4a836 to 76e4262 (Compare)
Greptile Overview
Greptile Summary
This PR aligns JAX layer default parameter values with the PyTorch implementation by changing three key defaults:
- Activation function: Changed from `'relu'` to `'gelu'` in `LayerNormMLP` and `TransformerLayer`
- Intermediate dropout: Changed from `0.1` to `0.0` in both `LayerNormMLP` (as `intermediate_dropout_rate`) and `TransformerLayer` (as `intermediate_dropout`)
- Return layernorm output: Changed from `True` to `False` in `LayerNormDenseGeneral` and `LayerNormMLP`
The most impactful change is return_layernorm_output=False, which enables the fused FP8 codepath by default when FP8 is enabled (see lines 698-702 and 1021-1025 in module.py). This optimization fuses LayerNorm with subsequent Dense/MLP operations for better performance.
All changes are consistent with PyTorch defaults verified in:
- `transformer_engine/pytorch/module/layernorm_mlp.py:1539` (`activation="gelu"`)
- `transformer_engine/pytorch/module/layernorm_mlp.py:1543` (`return_layernorm_output=False`)
- `transformer_engine/pytorch/transformer.py:312` (`activation="gelu"`)
The intermediate dropout change is justified per the PR description: "this dropout layer does not exist on the PyTorch side, so we should have it turned off by default".
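A hedged sketch of how the fused codepath would typically be reached under the new defaults: `fp8_autocast` exists in `transformer_engine.jax`, but the `features` constructor argument and the plain init/apply usage inside the context are assumptions, and FP8 metadata handling is elided.

```python
# Hedged sketch: with return_layernorm_output=False now the default, enabling FP8
# makes the layer eligible for the fused LayerNorm + Dense path described above.
# `features` is an assumed argument name, and the FP8 meta/state bookkeeping that a
# real training loop needs is omitted here.
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

layer = te_flax.LayerNormDenseGeneral(
    features=1024,  # assumed name for the output width
    # return_layernorm_output defaults to False after this PR -> fused path eligible
)

x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)
with te.fp8_autocast(enabled=True):
    variables = layer.init(jax.random.PRNGKey(0), x)
    out = layer.apply(variables, x)  # layernorm output is not returned separately
```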
Confidence Score: 5/5
- This PR is safe to merge with no risk - it only changes default parameter values to match the PyTorch implementation
- Score reflects that these are straightforward default value changes with clear justification: (1) Aligns JAX and PyTorch implementations for consistency, (2) Enables optimized fused codepath by default, (3) All changes are backward-compatible since users can still explicitly set the old values, (4) Changes are well-documented in both code and PR description
- No files require special attention - all changes are intentional default value updates
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/flax/module.py | 5/5 | Changes default values for LayerNormDenseGeneral and LayerNormMLP: return_layernorm_output changed from True to False (enables fused FP8 codepath), activations changed from ('relu',) to ('gelu',), and intermediate_dropout_rate changed from 0.1 to 0.0 to match PyTorch implementation |
| transformer_engine/jax/flax/transformer.py | 5/5 | Changes default values for TransformerLayer: intermediate_dropout changed from 0.1 to 0.0 and mlp_activations changed from ('relu',) to ('gelu',) to match PyTorch implementation |
Sequence Diagram
```mermaid
sequenceDiagram
    participant User
    participant JAX Layer
    participant FP8 Config
    participant Fused Path
    participant Unfused Path
    User->>JAX Layer: Create LayerNormMLP/TransformerLayer
    Note over JAX Layer: New defaults:<br/>activations=('gelu',)<br/>intermediate_dropout=0.0<br/>return_layernorm_output=False
    User->>JAX Layer: Forward pass with input
    JAX Layer->>FP8 Config: Check is_fp8_enabled()
    alt FP8 Enabled & return_layernorm_output=False
        FP8 Config-->>JAX Layer: True
        JAX Layer->>Fused Path: Use fused layernorm_mlp/layernorm_dense
        Note over Fused Path: Optimized fusion:<br/>LayerNorm + Dense<br/>or LayerNorm + MLP
        Fused Path-->>JAX Layer: Output (no layernorm_output)
    else FP8 Disabled or return_layernorm_output=True
        FP8 Config-->>JAX Layer: False
        JAX Layer->>Unfused Path: Separate layernorm + dense ops
        Note over Unfused Path: Standard path:<br/>LayerNorm first<br/>Then Dense/MLP
        Unfused Path-->>JAX Layer: Output + optional layernorm_output
    end
    JAX Layer->>JAX Layer: Apply GELU activation (was ReLU)
    JAX Layer->>JAX Layer: Skip intermediate dropout (rate=0.0)
    JAX Layer-->>User: Final output
```
2 files reviewed, no comments
/te-ci L2 jax
…ues instead of relying on newer default values Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
/te_ci L2 jax
Greptile Overview
Greptile Summary
Changed JAX fused layer default argument values to align with PyTorch implementation and enable fused codepaths by default.
Key Changes:
- Default activation function changed from `relu` to `gelu` in `LayerNormMLP`, `LayerNormDenseGeneral`, and `TransformerLayer`
- Default `intermediate_dropout_rate` changed from `0.1` to `0.0` (this dropout doesn't exist in PyTorch and will be removed when Rubin comes)
- Default `return_layernorm_output` changed from `True` to `False` to enable fused layernorm and MLP/DenseGeneral by default
- Test files updated to explicitly set `return_layernorm_output=True` where needed to preserve existing test behavior
- Test utility classes updated to match the new defaults for consistency
Impact:
This is a breaking change for users who rely on default values. Existing code that doesn't explicitly specify these parameters will now use gelu activation, no intermediate dropout, and fused layernorm paths.
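Callers who prefer the previous behavior can opt back into it by passing the old values explicitly, as in the hedged sketch below; `mlp_activations` and `intermediate_dropout` come from this PR's diff, while the other constructor arguments are assumed for illustration.

```python
# Hedged sketch: explicitly requesting the pre-PR TransformerLayer behavior.
# Only mlp_activations and intermediate_dropout are confirmed by this PR;
# the remaining arguments are assumptions used to make the example concrete.
import transformer_engine.jax.flax as te_flax

layer = te_flax.TransformerLayer(
    hidden_size=1024,            # assumed
    mlp_hidden_size=4096,        # assumed
    num_attention_heads=16,      # assumed
    mlp_activations=("relu",),   # old default; new default is ("gelu",)
    intermediate_dropout=0.1,    # old default; new default is 0.0
)
```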
Confidence Score: 5/5
- This PR is safe to merge with minimal risk
- The changes are well-structured with proper test updates. The first commit changed defaults in main modules, and the second commit properly fixed failing tests by explicitly specifying the old behavior. The changes align with PyTorch implementation for consistency. All default value changes are documented in the PR description and are intentional improvements.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| tests/jax/test_distributed_layernorm_mlp.py | 5/5 | Added explicit return_layernorm_output=True to test cases to maintain expected test behavior after default changed to False |
| tests/jax/utils.py | 5/5 | Updated test utility classes with new defaults: activation changed from relu to gelu, intermediate_dropout changed from 0.1 to 0.0 |
Sequence Diagram
```mermaid
sequenceDiagram
    participant Dev as Developer
    participant Module as JAX Modules<br/>(module.py, transformer.py)
    participant Tests as Test Suite<br/>(test_distributed_layernorm_mlp.py)
    participant Utils as Test Utils<br/>(utils.py)
    Note over Dev,Utils: First Commit: Default Value Changes
    Dev->>Module: Change default activation: relu → gelu
    Dev->>Module: Change default intermediate_dropout: 0.1 → 0.0
    Dev->>Module: Change default return_layernorm_output: True → False
    Note over Tests,Utils: Tests fail due to new defaults
    Note over Dev,Utils: Second Commit: Test Fixes
    Dev->>Tests: Add explicit return_layernorm_output=True<br/>to preserve test expectations
    Dev->>Utils: Update MlpBlock defaults:<br/>activation: relu → gelu<br/>intermediate_dropout: 0.1 → 0.0
    Dev->>Utils: Update EncoderLayer defaults:<br/>activation: relu → gelu<br/>intermediate_dropout: 0.1 → 0.0
    Dev->>Utils: Update DecoderLayer defaults:<br/>activation: relu → gelu<br/>intermediate_dropout: 0.1 → 0.0
    Note over Tests: Tests pass with explicit parameters
```
2 files reviewed, no comments
Description
Changing default arguments in some JAX layers
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Changes some default argument values to match the PyTorch implementation and to enable the fused codepath by default.
Checklist: