
Conversation


@tdophung tdophung commented Nov 5, 2025

Description

Changing default arguments in some JAX layers

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

Changes some default argument values to match the PyTorch implementation and enables the fused codepath by default:

  • Default activation is now gelu instead of relu
  • return_layernorm_output now defaults to False, enabling the fused layernorm + MLP/DenseGeneral path by default
  • intermediate_dropout_rate now defaults to 0.0: this dropout layer does not exist on the PyTorch side, so it should also be turned off by default on the JAX side. It will be removed completely when Rubin comes. (A short sketch of the new vs. old defaults follows this list.)
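
For illustration, a minimal sketch of the new defaults and of how callers can pin the previous behavior (module and argument names are those referenced in this PR; any other constructor arguments are omitted and assumed to keep their defaults):

    import transformer_engine.jax.flax as te_flax

    # New defaults after this PR: gelu activation, no intermediate dropout,
    # and return_layernorm_output=False (fused layernorm + MLP path).
    mlp = te_flax.LayerNormMLP()

    # Callers that depended on the previous behavior can restore it explicitly.
    mlp_old_behavior = te_flax.LayerNormMLP(
        activations=("relu",),          # previous default
        intermediate_dropout_rate=0.1,  # previous default
        return_layernorm_output=True,   # previous default (unfused path)
    )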

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR aligns JAX layer default arguments with the PyTorch implementation by changing three key defaults:

Default Argument Changes:

  • Activation function: Changed from relu to gelu in LayerNormMLP and TransformerLayer
  • Intermediate dropout: Changed from 0.1 to 0.0 in LayerNormMLP and TransformerLayer
  • Return layernorm output: Changed from True to False in LayerNormDenseGeneral and LayerNormMLP

Key Benefits:

  • Enables fused codepath by default (via return_layernorm_output=False)
  • Disables by default the intermediate dropout layer that doesn't exist in the PyTorch implementation
  • Aligns activation function with PyTorch's default gelu
  • Improves cross-framework consistency

Additional Changes:

  • Added new JAX quickstart tutorial (notebook and utilities)
  • Updated CI workflow to use specific PyTorch CUDA 13.0 index

The changes are well-documented in the PR description and align with the stated goal of matching PyTorch defaults while enabling better performance through fused operations.

Confidence Score: 5/5

  • This PR is safe to merge with no identified issues
  • All changes are straightforward default argument updates that improve JAX/PyTorch consistency. The PyTorch implementation was verified to use gelu activation and the changes align perfectly with the PR description. The new tutorial files are documentation additions that pose no risk. No logical errors or breaking changes detected.
  • No files require special attention

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/jax/flax/module.py | 5/5 | Changed default args in LayerNormDenseGeneral and LayerNormMLP: return_layernorm_output from True to False, activations from relu to gelu, and intermediate_dropout_rate from 0.1 to 0.0 to match PyTorch defaults
transformer_engine/jax/flax/transformer.py | 5/5 | Changed TransformerLayer defaults: intermediate_dropout from 0.1 to 0.0 and mlp_activations from relu to gelu to align with PyTorch implementation

Sequence Diagram

sequenceDiagram
    participant User
    participant JAX Layer
    participant LayerNormMLP
    participant LayerNormDenseGeneral
    participant TransformerLayer
    
    User->>JAX Layer: Initialize with default args
    Note over JAX Layer: Before PR: activation='relu'<br/>intermediate_dropout=0.1<br/>return_layernorm_output=True
    
    User->>JAX Layer: Initialize with default args (After PR)
    Note over JAX Layer: After PR: activation='gelu'<br/>intermediate_dropout=0.0<br/>return_layernorm_output=False
    
    JAX Layer->>LayerNormMLP: Apply with new defaults
    Note over LayerNormMLP: activations=('gelu',)<br/>intermediate_dropout_rate=0.0<br/>return_layernorm_output=False
    
    JAX Layer->>LayerNormDenseGeneral: Apply with new defaults
    Note over LayerNormDenseGeneral: return_layernorm_output=False<br/>Enables fused codepath
    
    JAX Layer->>TransformerLayer: Apply with new defaults
    Note over TransformerLayer: mlp_activations=('gelu',)<br/>intermediate_dropout=0.0<br/>Matches PyTorch behavior

1 file reviewed, no comments


@tdophung tdophung changed the title from "Teddy/fused layers default args" to "[JAX] Fused layers argument default values changed" on Nov 5, 2025
@jberchtold-nvidia
Collaborator

LGTM pending CI, thanks for making this change! Since this PR is smaller, it is likely to be ready to merge before the getting-started tutorial, so you may want to rebase this on main instead.

  mha_kernel_init: Initializer = None
  mlp_kernel_init: Initializer = None
- mlp_activations: Sequence[str] = ("relu",)
+ mlp_activations: Sequence[str] = ("gelu",)

TE/PyTorch seems to not have a default activation type and afaict requires the user to explicitly set it
https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L204

For now, we can keep this PR as is to make the activation default consistent across TE/JAX, but in the future we could try making it a required argument with no default.
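
Purely illustrative and not part of this PR: if the default were dropped in a follow-up, the Flax field would simply lose its default value, forcing callers to choose an activation explicitly (field names mirror the JAX module touched here; the real module has many more fields):

    from typing import Sequence
    from flax import linen as nn

    class LayerNormMLP(nn.Module):
        # No default: callers must pick the activation explicitly,
        # mirroring the TE/PyTorch requirement discussed above.
        activations: Sequence[str]
        intermediate_dropout_rate: float = 0.0
        return_layernorm_output: bool = False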

…fter FC1 to 0, and return_layernorm_output to False

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/fused_layers_default_args branch from af4a836 to 76e4262 on November 5, 2025 18:46
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR aligns JAX layer default parameter values with the PyTorch implementation by changing three key defaults:

  • Activation function: Changed from 'relu' to 'gelu' in LayerNormMLP and TransformerLayer
  • Intermediate dropout: Changed from 0.1 to 0.0 in both LayerNormMLP (as intermediate_dropout_rate) and TransformerLayer (as intermediate_dropout)
  • Return layernorm output: Changed from True to False in LayerNormDenseGeneral and LayerNormMLP

The most impactful change is return_layernorm_output=False, which enables the fused FP8 codepath by default when FP8 is enabled (see lines 698-702 and 1021-1025 in module.py). This optimization fuses LayerNorm with subsequent Dense/MLP operations for better performance.
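
As a rough usage sketch of how that plays out for a caller on FP8-capable hardware (assuming fp8_autocast from transformer_engine.jax and the flax LayerNormMLP module; the exact call signature and return structure follow the module's documentation, not this sketch):

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax import fp8_autocast
    import transformer_engine.jax.flax as te_flax

    # With the new default return_layernorm_output=False, running under
    # fp8_autocast lets LayerNormMLP take the fused LayerNorm + MLP codepath;
    # setting return_layernorm_output=True would force the unfused path.
    model = te_flax.LayerNormMLP()
    x = jnp.ones((4, 128, 512), jnp.bfloat16)

    with fp8_autocast(enabled=True):
        variables = model.init(jax.random.PRNGKey(0), x, deterministic=True)
        out = model.apply(variables, x, deterministic=True)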

All changes are consistent with PyTorch defaults verified in:

  • transformer_engine/pytorch/module/layernorm_mlp.py:1539 (activation="gelu")
  • transformer_engine/pytorch/module/layernorm_mlp.py:1543 (return_layernorm_output=False)
  • transformer_engine/pytorch/transformer.py:312 (activation="gelu")

The intermediate dropout change is justified per the PR description: "this dropout layer does not exist on the PyTorch side, so we should have it turned off by default".

Confidence Score: 5/5

  • This PR is safe to merge with no risk - it only changes default parameter values to match the PyTorch implementation
  • Score reflects that these are straightforward default value changes with clear justification: (1) Aligns JAX and PyTorch implementations for consistency, (2) Enables optimized fused codepath by default, (3) All changes are backward-compatible since users can still explicitly set the old values, (4) Changes are well-documented in both code and PR description
  • No files require special attention - all changes are intentional default value updates

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/jax/flax/module.py | 5/5 | Changes default values for LayerNormDenseGeneral and LayerNormMLP: return_layernorm_output changed from True to False (enables fused FP8 codepath), activations changed from ('relu',) to ('gelu',), and intermediate_dropout_rate changed from 0.1 to 0.0 to match PyTorch implementation
transformer_engine/jax/flax/transformer.py | 5/5 | Changes default values for TransformerLayer: intermediate_dropout changed from 0.1 to 0.0 and mlp_activations changed from ('relu',) to ('gelu',) to match PyTorch implementation

Sequence Diagram

sequenceDiagram
    participant User
    participant JAX Layer
    participant FP8 Config
    participant Fused Path
    participant Unfused Path
    
    User->>JAX Layer: Create LayerNormMLP/TransformerLayer
    Note over JAX Layer: New defaults:<br/>activations=('gelu',)<br/>intermediate_dropout=0.0<br/>return_layernorm_output=False
    
    User->>JAX Layer: Forward pass with input
    JAX Layer->>FP8 Config: Check is_fp8_enabled()
    
    alt FP8 Enabled & return_layernorm_output=False
        FP8 Config-->>JAX Layer: True
        JAX Layer->>Fused Path: Use fused layernorm_mlp/layernorm_dense
        Note over Fused Path: Optimized fusion:<br/>LayerNorm + Dense<br/>or LayerNorm + MLP
        Fused Path-->>JAX Layer: Output (no layernorm_output)
    else FP8 Disabled or return_layernorm_output=True
        FP8 Config-->>JAX Layer: False
        JAX Layer->>Unfused Path: Separate layernorm + dense ops
        Note over Unfused Path: Standard path:<br/>LayerNorm first<br/>Then Dense/MLP
        Unfused Path-->>JAX Layer: Output + optional layernorm_output
    end
    
    JAX Layer->>JAX Layer: Apply GELU activation (was ReLU)
    JAX Layer->>JAX Layer: Skip intermediate dropout (rate=0.0)
    JAX Layer-->>User: Final output

2 files reviewed, no comments



tdophung commented Nov 6, 2025

/te-ci L2 jax

tdophung and others added 2 commits November 7, 2025 17:05
…ues instead of relying on newer default values

Signed-off-by: tdophung <tdophung@nvidia.com>

tdophung commented Nov 8, 2025

/te_ci L2 jax

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Changed JAX fused layer default argument values to align with PyTorch implementation and enable fused codepaths by default.

Key Changes:

  • Default activation function changed from relu to gelu in LayerNormMLP, LayerNormDenseGeneral, and TransformerLayer
  • Default intermediate_dropout_rate changed from 0.1 to 0.0 (this dropout doesn't exist in PyTorch, will be removed when Rubin comes)
  • Default return_layernorm_output changed from True to False to enable fused layernorm and MLP/DenseGeneral by default
  • Test files updated to explicitly set return_layernorm_output=True where needed to preserve existing test behavior
  • Test utility classes updated to match new defaults for consistency

Impact:
This is a breaking change for users who rely on default values. Existing code that doesn't explicitly specify these parameters will now use gelu activation, no intermediate dropout, and fused layernorm paths.
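
For users affected by this, a hypothetical migration sketch pinning the previous TransformerLayer defaults (argument names are those used in this PR; other arguments are omitted):

    import transformer_engine.jax.flax as te_flax

    # Explicitly restore the pre-PR TransformerLayer defaults if your model
    # was tuned against relu activations and intermediate dropout.
    layer = te_flax.TransformerLayer(
        mlp_activations=("relu",),  # previous default
        intermediate_dropout=0.1,   # previous default
    )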

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are well-structured with proper test updates. The first commit changed defaults in main modules, and the second commit properly fixed failing tests by explicitly specifying the old behavior. The changes align with PyTorch implementation for consistency. All default value changes are documented in the PR description and are intentional improvements.
  • No files require special attention

Important Files Changed

File Analysis

Filename | Score | Overview
tests/jax/test_distributed_layernorm_mlp.py | 5/5 | Added explicit return_layernorm_output=True to test cases to maintain expected test behavior after the default changed to False
tests/jax/utils.py | 5/5 | Updated test utility classes with new defaults: activation changed from relu to gelu, intermediate_dropout changed from 0.1 to 0.0

Sequence Diagram

sequenceDiagram
    participant Dev as Developer
    participant Module as JAX Modules<br/>(module.py, transformer.py)
    participant Tests as Test Suite<br/>(test_distributed_layernorm_mlp.py)
    participant Utils as Test Utils<br/>(utils.py)
    
    Note over Dev,Utils: First Commit: Default Value Changes
    Dev->>Module: Change default activation: relu → gelu
    Dev->>Module: Change default intermediate_dropout: 0.1 → 0.0
    Dev->>Module: Change default return_layernorm_output: True → False
    
    Note over Tests,Utils: Tests fail due to new defaults
    
    Note over Dev,Utils: Second Commit: Test Fixes
    Dev->>Tests: Add explicit return_layernorm_output=True<br/>to preserve test expectations
    Dev->>Utils: Update MlpBlock defaults:<br/>activation: relu → gelu<br/>intermediate_dropout: 0.1 → 0.0
    Dev->>Utils: Update EncoderLayer defaults:<br/>activation: relu → gelu<br/>intermediate_dropout: 0.1 → 0.0
    Dev->>Utils: Update DecoderLayer defaults:<br/>activation: relu → gelu<br/>intermediate_dropout: 0.1 → 0.0
    
    Note over Tests: Tests pass with explicit parameters

2 files reviewed, no comments

