
Conversation

Collaborator

@tdophung tdophung commented Oct 28, 2025

Description

This change adds a new quick start notebook for JAX, mirroring the Transformer architecture of the PyTorch guide so that users familiar with that guide can easily follow along. It contains 4 iterations of the layer with different training step time durations:

  1. Pure JAX/Flax implementation
  2. Basic TE implementation without any fused layer
  3. TE implementation with mixed fused and unfused layers
  4. Full TransformerLayer from TE

This might also include changes to how Sphinx displays this content on the HTML docs page in later commits.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added 1 quickstart guide written in Jax

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • [DOES NOT APPLY] I have added tests that prove my fix is effective or that my feature works
  • [DOES NOT APPLY] New and existing unit tests pass locally with my changes

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The current commit introduces only whitespace cleanup to docs/examples/quickstart_jax_utils.py, removing trailing spaces from multiple lines (48, 60, 80, and others). While these formatting changes align with the repository's pre-commit hooks (black, clang-format), they have no functional impact on the code. However, the file contains a critical pre-existing issue: three functions (is_bf16_supported(), is_fp8_supported(), is_mxfp8_supported()) call get_device_compute_capability(0) without importing this function, which will cause NameError exceptions at runtime. This is a blocking issue that must be resolved before the documentation can be safely used.

Important Files Changed

Filename Score Overview
docs/examples/quickstart_jax_utils.py 1/5 Whitespace cleanup only, but file contains critical missing import for get_device_compute_capability used in three functions

Confidence score: 1/5

  • This PR is not safe to merge due to a critical missing import that will cause runtime failures
  • Score reflects the presence of a blocking bug: get_device_compute_capability is called in lines 18, 25, and 32 but never imported, causing immediate NameError exceptions when any of the three capability-checking functions are invoked
  • Pay close attention to docs/examples/quickstart_jax_utils.py - add the missing import statement (likely from transformer_engine.jax import get_device_compute_capability or similar) and verify all three functions work correctly before merging
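For illustration, a minimal sketch of what such a fix could look like. The import path follows the review's own guess ("or similar"), and the integer SM-number return convention of get_device_compute_capability is an assumption, not a confirmed TE API detail:

```python
# Hedged sketch only: import location and return convention are assumptions
# taken from the review comment above.
from transformer_engine.jax import get_device_compute_capability  # assumed path

def is_bf16_supported():
    # bfloat16 requires Ampere (SM 8.0) or newer
    return get_device_compute_capability(0) >= 80

def is_fp8_supported():
    # FP8 requires Ada/Hopper (SM 8.9 / 9.0) or newer
    return get_device_compute_capability(0) >= 89

def is_mxfp8_supported():
    # MXFP8 requires Blackwell (SM 10.0) or newer
    return get_device_compute_capability(0) >= 100
```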

1 file reviewed, 1 comment


Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The recent changes address the missing import issue that was previously identified. The quickstart_jax_utils.py file now includes the necessary import for get_device_compute_capability, which resolves the undefined reference that would have caused runtime errors when executing the capability check logic.

PR Description Notes:

  • "implemtation" should be "implementation" (in iteration #1)

Important Files Changed

Filename Score Overview
docs/examples/quickstart_jax_utils.py 2/5 Added import for get_device_compute_capability to fix the previously identified missing import issue

Confidence score: 2/5

  • This PR contains critical bugs that will cause runtime failures and should not be merged as-is
  • Score reflects severe functional issues in gradient computation logic: the speedometer function has argument count mismatches (calls train_step with 5 args but function expects 4), and the VJP-based train step has broken forward/backward pass implementations that will fail at runtime
  • Pay close attention to quickstart_jax_utils.py lines 38-48 (argument mismatch in speedometer), lines 67-81 (gradient function parameter issues), and lines 96-107 (broken VJP implementation); also review quickstart_jax.ipynb for hardcoded absolute paths at lines 60-61 that will break for other users

2 files reviewed, 14 comments


Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The recent commits address several issues identified in previous reviews, including fixing typos, correcting import statements, and improving variable naming. The most critical remaining issues are: (1) the signature mismatch in train_step_fn where 5 arguments are passed but only 4 are expected (lines 37-39, 45-47 in utils), (2) the vjp_fn call receiving only 1 cotangent when it expects 3 (line 102 in utils), and (3) the closure issues in forward_fn within the VJP implementation. These changes integrate into the documentation suite by providing a JAX/Flax parallel to the existing PyTorch quickstart guide, allowing users to see equivalent implementations across both frameworks.

Important Files Changed

Filename Score Overview
docs/examples/quickstart_jax_utils.py 2/5 Utility functions for JAX quickstart with critical signature mismatches in train step functions and VJP implementation issues
docs/examples/quickstart_jax.ipynb 4/5 Comprehensive JAX/Flax quickstart notebook demonstrating four progressive Transformer implementations with timing benchmarks

Confidence score: 2/5

  • This PR contains critical logic errors that will prevent the code from running correctly, particularly around function signatures and VJP usage
  • Score reflects multiple runtime-breaking issues: the train_step_fn argument count mismatch (5 passed vs 4 expected), vjp_fn expecting 3 cotangents but receiving only 1, and forward_fn closure issues that will cause failures when executed
  • Pay close attention to docs/examples/quickstart_jax_utils.py lines 37-47 (speedometer train_step_fn calls), lines 53-82 (create_train_step_fn signature), and lines 98-117 (VJP implementation)
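For reference, a minimal sketch of the jax.vjp invariant these reviews point at: vjp_fn takes exactly as many cotangents as the forward function has outputs. The names model_apply_fn, variables, inp, and output_grad are placeholders, not the actual utility code:

```python
import jax
import jax.numpy as jnp

def make_vjp_train_step(model_apply_fn):
    @jax.jit
    def train_step_fn(variables, inp, output_grad):
        def forward_fn(v, x):
            # single array output -> vjp_fn will expect exactly one cotangent
            return model_apply_fn(v, x)

        out, vjp_fn = jax.vjp(forward_fn, variables, inp)
        var_grads, inp_grads = vjp_fn(output_grad)  # one cotangent for one output
        loss = jnp.sum(out * output_grad)           # surrogate scalar for reporting
        return loss, var_grads

    return train_step_fn
```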

2 files reviewed, 4 comments


Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the most recent changes made since the last review, not the entire PR. The PR adds a comprehensive JAX/Flax quickstart notebook (docs/examples/quickstart_jax.ipynb) and supporting utilities (docs/examples/quickstart_jax_utils.py) that mirror the existing PyTorch guide. The new documentation demonstrates progressive optimization of Transformer layers through four iterations: pure JAX/Flax baseline, basic Transformer Engine modules, mixed fused/unfused TE modules, and full TransformerLayer with FP8 support. Performance benchmarks are included showing training step time improvements from 27.9ms (baseline) to 8.0ms (FP8-optimized). The utilities file provides performance measurement tools, JIT-compiled training step creators using both jax.grad and explicit VJP approaches, baseline attention/MLP implementations, and parameter-sharing functions to ensure fair comparisons between Flax and TE models.

Important Files Changed

Filename Score Overview
docs/examples/quickstart_jax.ipynb 4.5/5 Comprehensive quickstart notebook demonstrating four progressive optimization stages with embedded performance benchmarks; minor type annotation issues (layernorm_eps declared as int instead of float on lines 360, 528)
docs/examples/quickstart_jax_utils.py 2/5 Utility functions for benchmarking, training step creation, and parameter sharing; contains critical logic errors including function signature mismatch in speedometer (calls with 5 args, function expects 4), incorrect parameter usage in create_train_step_fn_vjp where forward_fn and vjp_fn don't use provided parameters, unused parameters (model_init_fn, key), and shallow copy issues in parameter sharing functions

Confidence score: 2/5

  • This PR contains critical logic errors in the utilities file that will cause runtime failures when the notebook is executed
  • Score reflects multiple function signature mismatches, incorrect parameter handling in VJP-based training steps, and shallow copying issues that could lead to unintended mutation of nested parameter dictionaries; the notebook itself is well-structured but depends entirely on the broken utilities
  • Pay close attention to docs/examples/quickstart_jax_utils.py lines 31-47 (signature mismatch in speedometer), lines 98-116 (forward_fn/vjp_fn parameter usage), and lines 214-438 (all three parameter-sharing functions use shallow copy() on nested dicts)
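On the shallow-copy point above, a minimal sketch of two safer alternatives for duplicating a nested params dict before mutating it; the params structure here is illustrative:

```python
import copy
import jax
import jax.numpy as jnp

params = {"attention": {"qkv": {"kernel": jnp.ones((1024, 3072))}},
          "mlp": {"wi": {"kernel": jnp.ones((1024, 4096))}}}

# Option 1: deep-copy every nesting level before editing
shared = copy.deepcopy(params)

# Option 2: rebuild the pytree, which creates fresh containers at every level
shared = jax.tree_util.tree_map(lambda leaf: leaf, params)

# Either way, reassigning shared["mlp"]["wi"]["kernel"] no longer touches `params`
```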

2 files reviewed, no comments


tdophung and others added 2 commits October 28, 2025 11:34
Signed-off-by: tdophung <tdophung@nvidia.com>
…NVIDIA#2308)

Change the pyTorch installation to CUDA 13 in Build All GitHub action to
match the version in the JAX container

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the tdophung/jax_quickstart_documentation branch from 2796e91 to 733d61b on October 28, 2025 18:39
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds a comprehensive JAX-specific quickstart tutorial notebook that demonstrates four progressive implementations of a transformer layer: pure JAX/Flax, basic Transformer Engine (TE) modules, fused TE modules, and the full TE TransformerLayer with optional FP8 support. The notebook includes performance benchmarking utilities and parameter-sharing helpers to enable fair comparisons between implementations. A minor build infrastructure change updates the PyTorch dependency installation in the GitHub Actions workflow to explicitly use CUDA 13.0 wheels, ensuring compatibility with the JAX container environment. This documentation effort mirrors the structure of the existing PyTorch quickstart guide, providing JAX users with a dedicated migration and optimization path.

Important Files Changed

Filename Score Overview
.github/workflows/build.yml 5/5 Split torch installation into separate command with explicit CUDA 13.0 index URL
docs/examples/quickstart_jax.ipynb 3/5 Added comprehensive JAX quickstart notebook with 4 progressive transformer implementations
docs/examples/quickstart_jax_utils.py 2/5 Added utility functions for benchmarking and parameter sharing between JAX and TE models

Confidence score: 2/5

  • This PR contains critical runtime errors that will prevent the tutorial from executing successfully
  • The main issue is function signature mismatch in quickstart_jax_utils.py: speedometer() calls train_step_fn with 5 arguments but create_train_step_fn() returns a function expecting only 4 parameters, causing a guaranteed TypeError
  • Additional concerns include variable scope issues in the notebook (line 795 references te_transformer_params_template which may not be in scope after FP8 initialization) and unused parameters indicating incomplete implementation
  • Pay close attention to docs/examples/quickstart_jax_utils.py lines 37-46 (function call mismatch) and the notebook cells around lines 787-801 (variable scoping)
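To make the flagged mismatch concrete, here is a minimal sketch (with assumed names and a Flax-style apply signature) in which the benchmarking loop and the JIT-compiled step it calls agree on the argument list by construction:

```python
import time
import jax
import jax.numpy as jnp

def create_train_step_fn(model_apply_fn):
    @jax.jit
    def train_step_fn(variables, inp, output_grad, dropout_key):
        def loss_fn(v):
            # model_apply_fn and its rngs argument are assumed placeholders
            out = model_apply_fn(v, inp, rngs={"dropout": dropout_key})
            return jnp.sum(out * output_grad)  # surrogate loss whose gradient is output_grad
        return jax.value_and_grad(loss_fn)(variables)
    return train_step_fn

def speedometer(model_apply_fn, variables, inp, output_grad,
                warmup_iters=10, timing_iters=50):
    train_step_fn = create_train_step_fn(model_apply_fn)
    dropout_key = jax.random.PRNGKey(0)
    # call site passes exactly the four arguments the step function declares
    for _ in range(warmup_iters):
        loss, grads = train_step_fn(variables, inp, output_grad, dropout_key)
    loss.block_until_ready()
    start = time.time()
    for _ in range(timing_iters):
        loss, grads = train_step_fn(variables, inp, output_grad, dropout_key)
    loss.block_until_ready()
    print(f"Mean time: {(time.time() - start) / timing_iters * 1000:.3f} ms")
```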

Sequence Diagram

sequenceDiagram
    participant User
    participant Notebook as quickstart_jax.ipynb
    participant Utils as quickstart_jax_utils.py
    participant JAX as JAX/Flax
    participant TE as Transformer Engine

    User->>Notebook: Execute notebook cells
    Notebook->>JAX: Import jax, jax.numpy, flax.linen
    Notebook->>TE: Import transformer_engine.jax
    Notebook->>Utils: Import quickstart_jax_utils
    
    Note over Notebook,JAX: 1. Build BasicTransformerLayer (pure JAX/Flax)
    Notebook->>JAX: Initialize BasicTransformerLayer
    JAX-->>Notebook: Return initialized model
    Notebook->>JAX: init(key, x) to create params
    JAX-->>Notebook: Return params
    Notebook->>Utils: speedometer(model_apply_fn, variables, input, output_grad)
    Utils->>Utils: create_train_step_fn() - JIT compile fwd/bwd
    Utils->>JAX: Run warmup iterations
    Utils->>JAX: Run timing iterations (forward + backward)
    JAX-->>Utils: Return loss and gradients
    Utils-->>Notebook: Print mean time
    
    Note over Notebook,TE: 2. Build BasicTETransformerLayer (TE modules)
    Notebook->>TE: Initialize BasicTETransformerLayer
    TE-->>Notebook: Return TE model
    Notebook->>TE: init(key, x) to create TE params
    TE-->>Notebook: Return TE params template
    Notebook->>Utils: share_parameters_with_basic_te_model()
    Utils-->>Notebook: Return shared params
    Notebook->>Utils: speedometer() with TE model
    Utils->>TE: Run timing iterations with TE layers
    TE-->>Utils: Return loss and gradients
    Utils-->>Notebook: Print improved mean time
    
    Note over Notebook,TE: 3. Build FusedTETransformerLayer (fused ops)
    Notebook->>TE: Initialize FusedTETransformerLayer
    TE-->>Notebook: Return fused TE model
    Notebook->>TE: init() with LayerNormDenseGeneral, LayerNormMLP
    TE-->>Notebook: Return fused params
    Notebook->>Utils: share_fused_parameters_with_basic_te_model()
    Utils-->>Notebook: Return shared fused params
    Notebook->>Utils: speedometer() with fused model
    Utils->>TE: Run timing with fused kernels
    TE-->>Utils: Return loss and gradients
    Utils-->>Notebook: Print further improved time
    
    Note over Notebook,TE: 4. Use TE TransformerLayer with FP8
    Notebook->>TE: Create fp8_recipe (DelayedScaling)
    Notebook->>TE: fp8_autocast(enabled=True, fp8_recipe)
    Notebook->>TE: Initialize TransformerLayer within fp8_autocast
    TE-->>Notebook: Return params with fp8_metas collection
    Notebook->>Utils: share_parameters_with_transformerlayer_te_model()
    Utils-->>Notebook: Return shared params with fp8_metas
    Notebook->>Utils: speedometer() with fp8_autocast_kwargs
    Utils->>TE: Run timing with FP8 precision
    TE-->>Utils: Return loss and gradients (FP8)
    Utils-->>Notebook: Print fastest mean time
    
    Notebook-->>User: Display performance comparison

3 files reviewed, 1 comment


Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review. The developer has cleaned up the JAX quickstart notebook by removing execution artifacts (output cells, execution counts), fixing typos and type annotations, and correcting markdown link syntax. The quickstart_jax_utils.py file was updated to remove an unused key parameter from train_step_fn calls and add a 2025 copyright header. However, a critical logic error remains unfixed in the notebook at line 772: the code references te_transformer_params_template which is undefined in the FP8 context (it was only defined outside the FP8 block in line 671). This variable should be te_transformer_params instead.

Important Files Changed

Filename Score Overview
docs/examples/quickstart_jax.ipynb 2/5 Cleaned notebook execution artifacts and fixed typos/types, but introduced critical undefined variable bug in FP8 initialization block (line 772)
docs/examples/quickstart_jax_utils.py 3/5 Removed unused key parameter from train step functions and added 2025 copyright; shallow copy operations may allow unintended parameter mutation

Confidence score: 2/5

  • This PR cannot be merged safely due to a critical logic error that will cause runtime failure.
  • Score reflects one critical bug (undefined variable te_transformer_params_template at line 772 in the notebook) that will break FP8 initialization when users try to execute the notebook. The developer needs to replace te_transformer_params_template with te_transformer_params in the FP8 context. Additionally, the shallow copy operations in the utils file may cause subtle parameter-sharing bugs if nested dictionaries are mutated.
  • The notebook file (quickstart_jax.ipynb) requires immediate attention - specifically line 772 in the FP8 initialization block where an undefined variable is referenced.

2 files reviewed, 2 comments


@tdophung tdophung force-pushed the tdophung/jax_quickstart_documentation branch from d9859d1 to b742244 on October 28, 2025 21:19
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the latest changes to two documentation files: docs/examples/quickstart_jax.ipynb and docs/examples/quickstart_jax_utils.py. Recent commits have addressed previous formatting and documentation issues (markdown links, type annotations, typos), but critical runtime bugs remain unresolved. The quickstart notebook demonstrates progressive optimization of a JAX/Flax transformer using Transformer Engine, but the utility file contains a function signature mismatch that will cause immediate failures when users attempt to run the notebook.

Important Files Changed

Filename Score Overview
docs/examples/quickstart_jax.ipynb 2/5 Cleared execution artifacts and fixed minor issues, but contains a critical NameError on line 772 where undefined variable te_transformer_params_template is referenced in FP8 initialization block
docs/examples/quickstart_jax_utils.py 1/5 Added benchmarking utilities for JAX/Flax models with critical signature mismatch—train_step_fn called with 4 arguments (lines 41, 47) while the function signatures additionally require dropout_key; will fail immediately at runtime

Confidence score: 1/5

  • This PR will cause immediate runtime failures when users attempt to follow the quickstart guide due to function signature mismatches and undefined variable references
  • Score reflects two critical blocking issues: (1) the train_step_fn signature mismatch in quickstart_jax_utils.py where calls on lines 41 and 47 provide 4 arguments but the underlying functions require dropout_key as a parameter, and (2) the undefined te_transformer_params_template reference on line 772 of the notebook that will raise NameError
  • Both utility file functions (create_train_step_fn and create_train_step_fn_vjp) and the FP8 initialization block in the notebook require immediate correction before this documentation can be published

2 files reviewed, no comments


" )\n",
" x = attention(q, k, v, attention_mask, deterministic=deterministic)\n",
" \n",
" # Attention built-in. Comment out if not used\n",
Collaborator

Can we add a bit more detail to this comment, something like: If you would like to use Flax's built-in attention implementation, you can use this instead

Want to clarify to users that this is an alternative option and not some debug code we left in by mistake

Member

I don't think this comment really got resolved. Let me make a suggestion how I think it could look like instead.

@phu0ngng phu0ngng requested review from phu0ngng and ptrendx October 30, 2025 00:17
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR introduces a comprehensive JAX/Flax quickstart tutorial that mirrors the PyTorch quickstart guide. The tutorial demonstrates four progressive implementations of a Transformer layer, showing performance improvements from pure JAX/Flax (~28ms) to fully optimized TE with FP8 (~8ms).

Key Changes:

  • New quickstart_jax.ipynb notebook with 4 transformer implementations (pure JAX, basic TE, fused TE, full TE with FP8)
  • Supporting utility file quickstart_jax_utils.py with benchmarking functions and custom attention/MLP modules
  • Workflow update to use specific PyTorch CUDA index URL for build consistency

Issues Found:

  • Several typos in documentation (already flagged in previous comments)
  • Malformed markdown link in notebook cell
  • Unused datasets import
  • Minor documentation link format issues

Positive Aspects:

  • Tutorial provides clear progressive optimization path
  • Benchmarking utilities are well-structured and correct
  • Good pedagogical approach with 4 iterations showing incremental improvements
  • Code is functionally correct despite some previous incorrect review comments

Confidence Score: 4/5

  • This PR is safe to merge after addressing minor documentation issues
  • Score reflects that this is a documentation-only PR with correct functional code. The issues found are primarily cosmetic (typos, formatting) rather than logical errors. Several previous comments incorrectly flagged non-issues (e.g., VJP signature, missing imports that don't exist). The actual issues are: unused import, malformed markdown link, and typos - all easily fixable and non-blocking for a draft PR.
  • docs/examples/quickstart_jax.ipynb needs attention for typos and formatting issues. The utility file and workflow changes are solid.

Important Files Changed

File Analysis

Filename Score Overview
.github/workflows/build.yml 5/5 Workflow updated to install torch from specific CUDA index URL - improves build consistency
docs/examples/quickstart_jax.ipynb 3/5 New JAX quickstart tutorial with 4 transformer implementations; contains typos and malformed markdown link already flagged
docs/examples/quickstart_jax_utils.py 4/5 Utility functions for speedometer benchmarking and custom attention/MLP modules; code is correct despite some previous incorrect comments

Sequence Diagram

sequenceDiagram
    participant User as User/Tutorial
    participant Basic as BasicTransformerLayer
    participant BasicTE as BasicTETransformerLayer
    participant Fused as FusedTETransformerLayer
    participant Full as TransformerLayer (TE)
    participant Utils as quickstart_jax_utils
    
    User->>Basic: Initialize with pure JAX/Flax
    Basic->>Basic: nn.LayerNorm + nn.Dense (QKV)
    Basic->>Utils: DotProductAttention
    Utils-->>Basic: attention output
    Basic->>Basic: nn.Dense (projection) + nn.Dropout
    Basic->>Utils: BasicMLP
    Utils-->>Basic: MLP output
    Basic-->>User: ~28ms training step
    
    User->>BasicTE: Replace with TE modules
    BasicTE->>BasicTE: te_flax.LayerNorm + te_flax.DenseGeneral
    BasicTE->>Utils: DotProductAttention
    Utils-->>BasicTE: attention output
    BasicTE->>BasicTE: te_flax.DenseGeneral + nn.Dropout
    BasicTE->>BasicTE: BasicTEMLP (te_flax.DenseGeneral)
    BasicTE-->>User: ~17ms training step
    
    User->>Fused: Use fused TE operations
    Fused->>Fused: te_flax.LayerNormDenseGeneral (fused)
    Fused->>Utils: DotProductAttention
    Utils-->>Fused: attention output
    Fused->>Fused: te_flax.DenseGeneral + nn.Dropout
    Fused->>Fused: te_flax.LayerNormMLP (fused)
    Fused-->>User: ~18ms training step
    
    User->>Full: Use full TE TransformerLayer
    Full->>Full: Complete optimized implementation
    Full-->>User: ~12ms training step
    
    User->>Full: Enable FP8 with fp8_autocast
    Full->>Full: FP8 precision compute
    Full-->>User: ~8ms training step

3 files reviewed, 1 comment


@phu0ngng phu0ngng changed the title [DRAFT] Jax (separate from PyTorch) Quickstart documentation [JAX] Quickstart documentation Oct 31, 2025
Member

ptrendx commented Oct 31, 2025

There seems to be some issue with the commits mixed up. Could you rebase on top of the current main @tdophung ?

"- `DotProductAttention`: `DotProductAttention` from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n",
"- `Projection`: `nn.Dense` (JAX/Flax)\n",
"- `Dropout`: `nn.Dropout` (JAX/Flax)\n",
"- `MLP`: `BasicMLP` from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n",
Member

@ksivaman putting on your radar that this will require more substitutions when we create the docs package.

"- `nn.Dropout`: JAX/Flax Dropout\n",
"- `BasicMLP`: Custom MLP from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n",
"\n",
"<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n",
Member

I don't like this sentence - it sounds too informal. A quick pass with copilot gave this instead:

The code snippet below also includes commented examples demonstrating how to utilize the built-in attention sub-layer from either pure Flax or TE Flax. These alternatives are provided should you prefer them over the custom attention implementation found in quickstart_jax_utils.py. This reference is intended to offer insight into the general structure and approach used for implementing attention mechanisms in our source code.

I don't really like that one either (although it sounds better), but maybe we could use it to get something in between.

Collaborator

What is the purpose of having both the flax.linen.MultiHeadDotProductAttention and the self-built attention? If the self-built attention is implemented correctly, they should have exactly the same ops as the linen one, right?

I think it will be less confusing to simply start with the linen attention module directly, then swap it with the TE one. In that way, we can expose those two modules in this tutorial directly rather than hiding their details in the util file.

Collaborator Author

I wonder if it is helpful to have the self-built attention, since it shows how attention works and also stays close to the PyTorch tutorial.
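For illustration, a hedged sketch of the two interchangeable options being discussed in this thread. The TE import path is the one the author quotes later in this PR; the TE constructor arguments shown here are assumptions rather than a verified signature:

```python
import flax.linen as nn
from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention

# Flax's built-in multi-head attention module
flax_attention = nn.MultiHeadDotProductAttention(num_heads=16, dropout_rate=0.1)

# TE's attention module (argument names assumed; transpose_batch_sequence matches
# the [seq, batch, hidden] layout discussed elsewhere in this PR)
te_attention = TEDotProductAttention(
    head_dim=64,
    num_attention_heads=16,
    attn_mask_type="no_mask",
    transpose_batch_sequence=True,
)
```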

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Adds comprehensive JAX quickstart documentation mirroring the PyTorch tutorial structure, demonstrating progressive performance optimization through four iterations of a Transformer layer implementation.

Key Changes:

  • New Jupyter notebook quickstart_jax.ipynb (699 lines) with hands-on tutorial
  • Supporting utility file quickstart_jax_utils.py (267 lines) containing reusable components
  • CI workflow update to separate PyTorch installation with explicit CUDA 13.0 index URL

Tutorial Progression:

  1. Pure JAX/Flax baseline (~26ms) - Custom attention and MLP implementations
  2. Basic TE modules (~16ms) - Using te_flax.LayerNorm and DenseGeneral
  3. Fused TE modules (~16ms) - Using LayerNormDenseGeneral and LayerNormMLP
  4. Full TE TransformerLayer (~11ms) - Complete optimized implementation
  5. FP8 enabled (~7ms) - 73% speedup over baseline with FP8 precision

Issues Found:

  • Multiple typos in documentation ('similicity', 'attetntion', 'inistead', 'linnen', 'buiil-in')
  • Hardcoded absolute path os.chdir('/workspace/docs/examples') will break for users
  • Malformed markdown link syntax [quickstart_jax_utils.py] (**)(quickstart_jax_utils.py)
  • Extra quote in string 'no_mask'' should be 'no_mask'
  • Unused datasets import in notebook
  • layernorm_eps defined as int instead of float (style issue, doesn't affect functionality)

Confidence Score: 3/5

  • Safe to merge after fixing documentation issues - no functional bugs found, only typos and hardcoded paths that will impact user experience
  • Score of 3 reflects that while the code is functionally correct and the tutorial successfully demonstrates TE capabilities, there are several documentation quality issues that should be fixed: hardcoded paths will break for users running the notebook outside the expected environment, multiple typos reduce professionalism, and malformed markdown links affect navigation. The CI change is safe and improves dependency management.
  • docs/examples/quickstart_jax.ipynb requires fixing hardcoded paths (line 59) and typos throughout; docs/examples/quickstart_jax_utils.py needs typo corrections

Important Files Changed

File Analysis

Filename Score Overview
.github/workflows/build.yml 5/5 Splits PyTorch installation into separate command with explicit CUDA index URL for better dependency management
docs/examples/quickstart_jax_utils.py 3/5 New utility file with attention implementations and speedometer function; contains several typos and documentation formatting issues flagged in previous comments
docs/examples/quickstart_jax.ipynb 3/5 New JAX quickstart tutorial notebook demonstrating progressive optimization with TE; contains hardcoded paths, typos, and malformed markdown links flagged in previous comments

Sequence Diagram

sequenceDiagram
    participant User
    participant Notebook as quickstart_jax.ipynb
    participant Utils as quickstart_jax_utils.py
    participant JAX as JAX/Flax
    participant TE as TransformerEngine

    User->>Notebook: Run tutorial cells
    Notebook->>JAX: Import jax, jax.numpy, flax.linen
    Notebook->>Utils: Import quickstart_jax_utils
    Notebook->>TE: Import transformer_engine.jax
    
    Note over Notebook: Iteration 1: Pure JAX/Flax
    Notebook->>JAX: Define BasicTransformerLayer
    Notebook->>Utils: Use AttentionWrapper(CUSTOM_DOT_PRODUCT)
    Notebook->>Utils: Use BasicMLP
    Notebook->>Utils: speedometer() - measure baseline
    
    Note over Notebook: Iteration 2: Basic TE
    Notebook->>TE: Use te_flax.LayerNorm, DenseGeneral
    Notebook->>Utils: Use AttentionWrapper(TE_FLAX_MULTIHEAD)
    Notebook->>Utils: speedometer() - faster performance
    
    Note over Notebook: Iteration 3: Fused TE
    Notebook->>TE: Use LayerNormDenseGeneral, LayerNormMLP
    Notebook->>Utils: speedometer() - similar to iteration 2
    
    Note over Notebook: Iteration 4: Full TE TransformerLayer
    Notebook->>TE: Use te_flax.TransformerLayer
    Notebook->>Utils: speedometer() - best performance
    
    Note over Notebook: Enable FP8
    Notebook->>TE: fp8_autocast context manager
    Notebook->>TE: Initialize with FP8 recipe
    Notebook->>Utils: speedometer(fp8_autocast_kwargs) - fastest
    
    Utils->>JAX: create_train_step_fn()
    Utils->>JAX: jax.value_and_grad() for gradients
    Utils->>JAX: jax.jit() for compilation
    Utils-->>Notebook: Return timing results

3 files reviewed, no comments


Collaborator

pggPL commented Nov 4, 2025

  1. It's not being rendered in the docs now, so I cannot see it - you need to put it into some proper .rst file.
  2. I had a discussion with @ptrendx yesterday and he said that you considered using tabs like in d2l.ai. I did something like this here using sphinx-tabs package. You can download the docs from here and go to section features -> low precision training -> custom recipe to see it. Tbh it is 100% cursor generated, so the code may not make sense, but it looks nice. Idk if you want to use it, just letting you know.
[Screenshot attached: 2025-11-04 17:57:00]

Collaborator Author

tdophung commented Nov 5, 2025

I wonder why I cannot reply to Pawel's suggestion to do tabbing in HTML. But I intend for this PR to just be the notebook first and to work on a separate PR for displaying it in HTML, since that will also touch the Python quickstart guide. This PR is just for a functional ipynb.

…some footnotes in the quick start notebook

Signed-off-by: tdophung <tdophung@nvidia.com>
…compatibility with speedometer

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…Layer

Signed-off-by: tdophung <tdophung@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: tdophung <tdophung@nvidia.com>
…e and unfused TE impls to achieve the same performance (removing extra dropout layer in fused layers). Also some minor wording changes

Signed-off-by: tdophung <tdophung@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the tdophung/jax_quickstart_documentation branch from b70bddf to c6b6ca3 on November 5, 2025 00:47
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Adds comprehensive JAX quickstart documentation demonstrating progressive optimization of Transformer layers using TransformerEngine. The tutorial walks through 4 implementations showing increasing performance gains (26ms → 7ms).

Key Changes:

  • New Jupyter notebook (quickstart_jax.ipynb) with 4 progressive transformer implementations
  • Supporting utility module (quickstart_jax_utils.py) with attention wrappers and performance benchmarking
  • CI workflow update for PyTorch installation with explicit CUDA index URL

Implementation Progression:

  1. BasicTransformerLayer: Pure JAX/Flax baseline (~26ms per step)
  2. BasicTETransformerLayer: Using TE's DenseGeneral and LayerNorm modules (~16ms)
  3. FusedTETransformerLayer: Leveraging fused kernels like LayerNormDenseGeneral (~16ms)
  4. Full TransformerLayer with FP8: Complete TE layer with FP8 precision (~7ms)

Strengths:

  • Well-structured tutorial mirroring PyTorch quickstart for consistency
  • Clear performance comparisons with timing measurements
  • Good documentation of FP8 initialization requirements
  • Comprehensive utility code with multiple attention implementations

Previous Review Feedback Addressed:
Most syntax issues (typos, formatting) and several logic concerns from previous reviews appear to have been addressed in recent commits, including fixing attention implementation inconsistencies and adding proper enum for attention types.

Confidence Score: 4/5

  • This documentation PR is safe to merge with minor issues already flagged in previous reviews
  • Score reflects that this is a documentation-only PR with comprehensive tutorial content. The code examples are well-structured and functional. Previous review comments have identified syntax issues (typos, formatting) that should be addressed, but these are non-blocking for a documentation PR. The utility code is correct and the notebook provides educational value. The CI workflow change is sensible and low-risk.
  • All files are in good shape. The notebook may have minor typos/formatting issues noted in previous comments that should be cleaned up before merge.

Important Files Changed

File Analysis

Filename Score Overview
docs/examples/quickstart_jax.ipynb 4/5 Comprehensive JAX quickstart notebook with 4 progressive transformer implementations. Most previous syntax/typo issues appear addressed. Minor formatting improvements needed.
docs/examples/quickstart_jax_utils.py 5/5 Well-structured utility module with attention implementations and performance testing. Code is correct and follows best practices.

Sequence Diagram

sequenceDiagram
    participant User as User/Notebook
    participant Utils as quickstart_jax_utils
    participant JAX as JAX/Flax
    participant TE as TransformerEngine
    
    Note over User: 1. BasicTransformerLayer (Pure JAX/Flax)
    User->>JAX: Initialize BasicTransformerLayer
    JAX-->>User: params
    User->>Utils: speedometer(model, params, data)
    Utils->>Utils: create_train_step_fn()
    Utils->>JAX: Forward/Backward Pass
    JAX-->>Utils: loss, gradients
    Utils-->>User: Mean time: ~26ms
    
    Note over User: 2. BasicTETransformerLayer (TE modules)
    User->>TE: Initialize with DenseGeneral, LayerNorm
    TE-->>User: te_params
    User->>Utils: speedometer(model, params, data)
    Utils->>TE: Forward/Backward with TE modules
    TE-->>Utils: loss, gradients
    Utils-->>User: Mean time: ~16ms
    
    Note over User: 3. FusedTETransformerLayer (Fused TE)
    User->>TE: Initialize with LayerNormDenseGeneral, LayerNormMLP
    TE-->>User: fused_params
    User->>Utils: speedometer(model, params, data)
    Utils->>TE: Forward/Backward with fused kernels
    TE-->>Utils: loss, gradients
    Utils-->>User: Mean time: ~16ms
    
    Note over User: 4. TransformerLayer with FP8
    User->>TE: fp8_autocast context
    User->>TE: Initialize TransformerLayer
    TE-->>User: params + fp8_metas
    User->>Utils: speedometer(model, params, data, fp8_enabled=True)
    Utils->>TE: Forward/Backward with FP8 precision
    TE-->>Utils: loss, gradients
    Utils-->>User: Mean time: ~7ms

2 files reviewed, no comments


"- `nn.Dropout`: JAX/Flax Dropout\n",
"- `BasicMLP`: Custom MLP from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n",
"\n",
"<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of having both the flax.linen.MultiHeadDotProductAttention and the self-built attention? If the self-built attention is implemented correctly, they should have exactly the same ops as the linen one, right?

I think it will be less confusing to simply start with the linen attention module directly, then swap it with the TE one. In that way, we can expose those two modules in this tutorial directly rather than hiding their details in the util file.

"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 16.173095703125 ms\n"
Collaborator

I suggest we add an option to switch between the two Attention modules in the call function, then we have 2 blocks of time measurement, one for the FusedTETransformerLayer + linen.DPA, one for FusedTETransformerLayer + te_flax.DPA.

Collaborator Author

In the latest version, I changed the structure of the tutorial to bring FP8 up and Fused down as suggested before, and added the comparison between BasicTETransformerLayer + linen.DPA and BasicTETransformerLayer + te_flax.DPA already, so I do not think it is necessary to do the same for Fused.

Also, since Fused is now only under the FP8 section, and there is no XLA FFI for linen.DPA with fp8 autocast, I could not have done this anyway. Please let me know what you think.

"id": "a45c12c8",
"metadata": {},
"source": [
"Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:"
Collaborator

@phu0ngng phu0ngng Nov 5, 2025

This is quite misleading, as it is right now.

It shows that the TE Transformer layer is 11ms while the FusedTETransformerLayer is 16ms (which was measured with linen.DPA). So with my above suggestions, we should be able to show that TE Transformer and FusedTETransformerLayer have the same runtime.

I don't think the TransformerLayer could be any more performant than the FusedTETransformerLayer, as the performant part comes from the fact that we try to fuse kernels together in the fused modules. In other words, it should not be different if users use our TransformerLayer or make their own FusedTETransformerLayer by stacking the fused modules.

So the TransformerLayer is there just for convenient usage.

Collaborator Author

Currently the FusedTETransformerLayer and the BasicTETransformerLayer are both at 16ms, and they both already use te_flax.DPA (from transformer_engine.jax.flax.transformer import DotProductAttention), NOT linen.DPA.
Before this, I was not aware that the TE Transformer layer (11ms) is not supposed to be much faster than the other two. I can take a look at why the other two are at 16ms even with te_flax.DPA, but please comment if you see anything more apparent.

Collaborator Author

I have figured out why this was the case. Simply put, I passed the wrong input shape into the TransformerLayer while passing the correct shape into the Fused version.
I initialized the input in the shape PyTorch would expect, [seq_len, batch, hidden], while the TransformerLayer takes [batch, seq_len, hidden] by default. In te_flax.DPA I correctly set transpose_batch_sequence=True to account for this, but in TransformerLayer I did not. Therefore max_seqlen was set wrong and TransformerLayer did a lot less work than it was supposed to (batch=4, seqlen=2048).

This is resolved in the most recent commit.
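A small sketch of the layout issue described above, using the shapes quoted in the thread (batch=4, seqlen=2048); variable names are illustrative:

```python
import jax.numpy as jnp

batch, seq_len, hidden = 4, 2048, 1024

# PyTorch-style layout used for the input
x_sbh = jnp.zeros((seq_len, batch, hidden))    # [seq, batch, hidden]

# te_flax.DotProductAttention was configured with transpose_batch_sequence=True,
# so it interprets x_sbh correctly. TransformerLayer defaults to [batch, seq, hidden],
# so it needs either a transposed input ...
x_bsh = jnp.transpose(x_sbh, (1, 0, 2))        # [batch, seq, hidden]
# ... or the layer configured for sequence-first inputs, so that the effective
# sequence length is 2048 rather than the batch size.
```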

Collaborator

@tdophung now that the performance matches, let's remove "it provides the highest degree of performance optimization:" since TransformerLayer and individual usage of fused LayerNormMLP, etc. are equivalent

Collaborator

This is unrelated and will go away after rebasing, right?

Collaborator Author

right

"\n",
"## Overview\n",
"\n",
"Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n",
Collaborator

It provides support for 8-bit floating point (FP8) precision on Hopper GPUs

We should update this to a newer message. I cc'd you on the same discussion about our TE GitHub description now that we support nvfp4. You can copy whichever wording we decide there into here

Collaborator

@jberchtold-nvidia jberchtold-nvidia Nov 6, 2025

This is the wording that the GitHub repo description got updated to

8-bit and 4-bit floating point (FP8 and FP4) precision on Hopper, Ada and Blackwell GPUs

Collaborator Author

I have already updated it to
"It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs"

which is more precise because we do not have 4-bit precision on Hopper or Ada?

Collaborator

Oh nice! Missed that. I agree, that sounds good, thanks!

"\n",
"<b>Important: FP8 Metadata Initialization</b>\n",
"\n",
"When using FP8, the model **must be initialized within the `fp8_autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `fp8_autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n",
Collaborator

Is this description of fp8_metas copied from the PyTorch tutorial? I think PyTorch needs this metadata to store all their scaling factors for every recipe since it's not easily supported in PyTorch. But in TE/JAX, we don't need this for scaling factors. We only need this for recipes that store state across iterations, which is only DelayedScaling's amax history and scale

We also do have some partial support for creating quantizers outside of autocast contexts, so I don't think the ScopeCollectionNotFound error applies to JAX. But if you have observed, lmk and I can take a look. Thanks!

Collaborator Author

This was not copied from PyTorch tutorial.
I actually observed this while creating the JAX tutorial, hence the warning. Please take a look!

Collaborator

Good to know, I'll take a look
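For readers following along, a minimal sketch of the initialization pattern the quoted notebook cell describes (hyperparameters, shapes, and recipe settings here are illustrative assumptions):

```python
import jax
import jax.numpy as jnp
from transformer_engine.common import recipe
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

fp8_recipe = recipe.DelayedScaling(margin=0, amax_history_len=16)

params_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
x = jnp.zeros((4, 2048, 1024), dtype=jnp.bfloat16)  # [batch, seq, hidden]

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    layer = te_flax.TransformerLayer(
        hidden_size=1024,
        mlp_hidden_size=4096,
        num_attention_heads=16,
    )
    # Initializing inside the autocast context creates the FP8 state collections
    # (e.g. fp8_metas for DelayedScaling) alongside the regular params.
    variables = layer.init({"params": params_key, "dropout": dropout_key}, x)
```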

…ch, ...] instead of [batch, sequence,...]

Signed-off-by: tdophung <tdophung@nvidia.com>
…or fuse to take effect because quantization exist as suggested. Also make TransformerLayer perf get closer to Fused by setting hidden_dropout=0

Signed-off-by: tdophung <tdophung@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds a comprehensive JAX quickstart tutorial that mirrors the existing PyTorch guide, demonstrating progressive optimization of a Transformer layer through 5 iterations:

  • Iteration 1: Pure JAX/Flax baseline (28.61ms)
  • Iteration 2: Basic TE modules with DenseGeneral and LayerNorm (17.11ms, ~40% speedup)
  • Iteration 3: TE with FP8 precision (13.34ms, ~53% speedup)
  • Iteration 4: Fused TE modules with LayerNormDenseGeneral and LayerNormMLP (13.05ms, ~54% speedup)
  • Iteration 5: Full TransformerLayer module (12.95ms, ~55% speedup)

The tutorial includes a helper utilities file (quickstart_jax_utils.py) that provides modular attention implementations and benchmarking tools. The code is well-structured with clear explanations of each optimization step.

Previous review comments have already identified and addressed the main issues: typos in documentation, formatting inconsistencies, and minor style improvements. The workflow change updates PyTorch installation to use CUDA 13.0 index.

Confidence Score: 4/5

  • Safe to merge after addressing the previously flagged syntax/style issues
  • This is a documentation-only PR adding a tutorial notebook. The code is functional (as evidenced by the execution outputs) and well-structured. Previous comments have already identified the main issues (typos, formatting, style suggestions). No critical logical errors or security concerns exist. Score of 4 reflects that while the content is solid, the previously flagged corrections should be applied before merge.
  • Both files need minor cleanup for typos and style improvements already flagged in previous comments

Important Files Changed

File Analysis

Filename Score Overview
docs/examples/quickstart_jax.ipynb 4/5 Comprehensive JAX quickstart tutorial with 4 implementation iterations, minor typos already flagged
docs/examples/quickstart_jax_utils.py 4/5 Helper utilities for JAX quickstart notebook, provides attention and training wrappers with proper modularity

Sequence Diagram

sequenceDiagram
    participant User
    participant Notebook as quickstart_jax.ipynb
    participant Utils as quickstart_jax_utils.py
    participant JAX as JAX/Flax
    participant TE as TransformerEngine

    User->>Notebook: Start tutorial
    
    Note over Notebook,JAX: Iteration 1: Pure JAX/Flax
    Notebook->>JAX: Create BasicTransformerLayer
    JAX->>Utils: Use AttentionWrapper (CUSTOM_DOT_PRODUCT)
    Utils->>JAX: Return attention output
    JAX->>Notebook: Return layer output
    Notebook->>Utils: Call speedometer() for benchmarking
    Utils->>Notebook: Report 28.61ms per step
    
    Note over Notebook,TE: Iteration 2: Basic TE modules
    Notebook->>TE: Create BasicTETransformerLayer
    TE->>Utils: Use AttentionWrapper (TE_FLAX_MULTIHEAD)
    Utils->>TE: Return TE attention output
    TE->>Notebook: Return layer output
    Notebook->>Utils: Call speedometer()
    Utils->>Notebook: Report 17.11ms per step
    
    Note over Notebook,TE: Iteration 3: TE with FP8
    Notebook->>TE: Wrap in fp8_autocast context
    TE->>TE: Initialize FP8 metadata
    Notebook->>TE: Apply BasicTETransformerLayer with FP8
    TE->>Notebook: Return FP8-optimized output
    Notebook->>Utils: Call speedometer() with FP8
    Utils->>Notebook: Report 13.34ms per step
    
    Note over Notebook,TE: Iteration 4: Fused TE modules
    Notebook->>TE: Create FusedTETransformerLayer
    TE->>TE: Use LayerNormDenseGeneral & LayerNormMLP
    TE->>Notebook: Return fused layer output
    Notebook->>Utils: Call speedometer() with FP8
    Utils->>Notebook: Report 13.05ms per step
    
    Note over Notebook,TE: Iteration 5: Full TransformerLayer
    Notebook->>TE: Create te_flax.TransformerLayer
    TE->>TE: Apply complete optimized layer
    TE->>Notebook: Return optimized output
    Notebook->>Utils: Call speedometer() with FP8
    Utils->>Notebook: Report 12.95ms per step
    
    User->>Notebook: Complete tutorial

2 files reviewed, no comments


…ll of BasicTETransformerLayer and demonstrated the difference in runtime between using flax and using te's attention implementation

Signed-off-by: tdophung <tdophung@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds a comprehensive JAX/Flax quickstart guide for Transformer Engine, mirroring the PyTorch tutorial. It demonstrates 4 progressive implementations of a Transformer encoder layer with increasing optimization: pure Flax, basic TE modules, partially fused TE modules, and fully fused TransformerLayer.

Key Changes:

  • New Jupyter notebook quickstart_jax.ipynb with performance benchmarks showing ~2x speedup progression from pure Flax (25.5ms) to full TE with FP8 (13ms)
  • Utility file quickstart_jax_utils.py with attention wrappers, MLP implementations, and training speedometer
  • Modified CI workflow to install PyTorch from specific CUDA index

Issues Found:

  • Critical: FusedTETransformerLayer references undefined self.attention_implementation attribute (line 334) - will cause runtime failure
  • Multiple issues already flagged in previous comments include typos, hardcoded paths, import issues, and parameter mismatches in utility functions

Confidence Score: 2/5

  • This PR has critical runtime errors that will prevent the notebook from executing successfully
  • While the documentation quality is good and demonstrates useful features, there's a critical bug where FusedTETransformerLayer uses an undefined attribute. Combined with numerous other issues flagged in previous comments (import errors, parameter mismatches, typos), this PR needs significant fixes before merge
  • Both docs/examples/quickstart_jax.ipynb and docs/examples/quickstart_jax_utils.py require attention - the notebook has the critical undefined attribute bug and the utils file has multiple parameter mismatch issues

Important Files Changed

File Analysis

Filename Score Overview
docs/examples/quickstart_jax.ipynb 2/5 Tutorial notebook with critical bug: FusedTETransformerLayer uses undefined attention_implementation attribute. Also contains multiple typos and hardcoded paths flagged in previous reviews
docs/examples/quickstart_jax_utils.py 3/5 Utility functions with several parameter mismatches and unused parameters already flagged. Core logic appears sound but needs cleanup for robustness

Sequence Diagram

sequenceDiagram
    participant User
    participant Notebook
    participant TransformerLayer
    participant Speedometer
    participant TrainStepFn
    participant FP8Autocast
    participant JAXBackend

    User->>Notebook: Initialize model config
    Notebook->>TransformerLayer: init(key, x, attention_mask)
    TransformerLayer-->>Notebook: Return initialized params
    
    User->>Speedometer: speedometer(model_apply_fn, variables, input, output_grad)
    Speedometer->>TrainStepFn: create_train_step_fn(model_apply_fn, fp8_autocast_kwargs)
    TrainStepFn-->>Speedometer: Return JIT-compiled train_step_fn
    
    loop Warmup Iterations
        Speedometer->>TrainStepFn: train_step_fn(variables, input, output_grad, key)
        TrainStepFn->>FP8Autocast: Enter fp8_autocast context
        FP8Autocast->>TransformerLayer: apply(variables, input)
        TransformerLayer->>JAXBackend: Forward pass (LayerNorm, QKV, Attention, MLP)
        JAXBackend-->>TransformerLayer: Output tensor
        TransformerLayer-->>FP8Autocast: Return output
        FP8Autocast->>JAXBackend: Compute gradients (VJP)
        JAXBackend-->>FP8Autocast: param_grads, other_grads
        FP8Autocast-->>TrainStepFn: loss, gradients
        TrainStepFn-->>Speedometer: loss, gradients
    end
    
    loop Timing Iterations
        Speedometer->>TrainStepFn: train_step_fn(variables, input, output_grad, key)
        Note over TrainStepFn,JAXBackend: Same forward/backward flow as warmup
        TrainStepFn-->>Speedometer: loss, gradients
    end
    
    Speedometer-->>User: Print mean time per iteration
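
For readers following the diagram above, here is a minimal, hypothetical sketch of that speedometer flow: a JIT-compiled train step that runs the forward pass under fp8_autocast and pulls gradients back with jax.vjp. The names (model_apply_fn, variables, fp8_autocast_kwargs, the warmup/timing counts) are assumptions taken from this conversation, not the exact signatures in quickstart_jax_utils.py, and real FP8 training would also need to thread TE's quantization-state collections, which the tutorial's utility presumably handles.

```python
import time

import jax
import jax.numpy as jnp
import transformer_engine.jax as te


def create_train_step_fn(model_apply_fn, fp8_autocast_kwargs):
    """Build a jitted step: forward under fp8_autocast, backward via jax.vjp."""

    @jax.jit
    def train_step(variables, inp, output_grad):
        def forward(params):
            # The context only changes behavior when FP8 is enabled in the kwargs.
            with te.fp8_autocast(**fp8_autocast_kwargs):
                return model_apply_fn({**variables, "params": params}, inp)

        out, vjp_fn = jax.vjp(forward, variables["params"])
        (param_grads,) = vjp_fn(output_grad)   # backprop the supplied output gradient
        loss = jnp.vdot(out, output_grad)      # stand-in scalar, for reporting only
        return loss, param_grads

    return train_step


def speedometer(model_apply_fn, variables, inp, output_grad,
                fp8_autocast_kwargs=None, warmup_iters=10, timing_iters=50):
    """Print the mean time per training step, mirroring the flow in the diagram."""
    step = create_train_step_fn(model_apply_fn,
                                fp8_autocast_kwargs or {"enabled": False})
    for _ in range(warmup_iters):              # compile and warm up
        loss, _ = step(variables, inp, output_grad)
    jax.block_until_ready(loss)

    start = time.perf_counter()
    for _ in range(timing_iters):
        loss, _ = step(variables, inp, output_grad)
    jax.block_until_ready(loss)
    print(f"Mean time: {(time.perf_counter() - start) / timing_iters * 1e3:.2f} ms")
```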

1 file reviewed, 1 comment


…ayer

Signed-off-by: tdophung <tdophung@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Nov 6, 2025

Greptile Overview

Updated On: 2025-11-06 20:06:38 UTC

Greptile Summary

Adds comprehensive JAX quickstart documentation with a Jupyter notebook tutorial and utility functions. The tutorial demonstrates four progressive implementations of a Transformer layer, from pure JAX/Flax to fully-fused TE modules with FP8 support.

Major issues found:

  • Multiple syntax errors and typos throughout the notebook, several of which will cause runtime failures ('no_mask'', similicity, attetntion, inistead, linnen, buiil-in)
  • Hardcoded absolute paths that will break for users outside the development environment
  • Missing FusedTETransformerLayer.attention_implementation attribute causing runtime errors
  • Malformed markdown links
  • Logic issues with undefined variables and function parameter mismatches in utility file

Minor issues:

  • Unused imports and parameters in utility functions
  • Type inconsistencies (layernorm_eps as int instead of float)
  • Incorrect API documentation link format

Confidence Score: 2/5

  • This PR has multiple critical syntax and logic errors that will prevent the tutorial from running successfully for users
  • Score reflects numerous syntax errors, typos, hardcoded paths, and missing attributes that were flagged in previous comments but remain unresolved. These issues will cause immediate runtime failures when users try to follow the tutorial
  • Pay close attention to docs/examples/quickstart_jax.ipynb which contains most of the critical syntax and logic errors

Important Files Changed

File Analysis

| Filename | Score | Overview |
|----------|-------|----------|
| docs/examples/quickstart_jax.ipynb | 2/5 | New JAX quickstart tutorial with multiple syntax errors, typos, hardcoded paths, and logical issues that will break for users |
| docs/examples/quickstart_jax_utils.py | 3/5 | New utility file with minor code style issues and some parameter inconsistencies but mostly functional |
| .github/workflows/build.yml | 5/5 | Updated PyTorch installation to use specific CUDA index URL - straightforward and safe change |

Sequence Diagram

sequenceDiagram
    participant User
    participant Notebook as quickstart_jax.ipynb
    participant Utils as quickstart_jax_utils.py
    participant JAX/Flax
    participant TE as TransformerEngine

    User->>Notebook: Run tutorial
    Notebook->>Utils: Import helper functions
    
    Note over Notebook: Step 1: Basic Transformer
    Notebook->>JAX/Flax: Create BasicTransformerLayer
    Notebook->>Utils: Use AttentionWrapper(FLAX_LINEN_MULTIHEAD)
    Utils->>JAX/Flax: Use nn.MultiHeadDotProductAttention
    Notebook->>Utils: speedometer() to measure performance
    
    Note over Notebook: Step 2: Basic TE
    Notebook->>TE: Create BasicTETransformerLayer
    Notebook->>Utils: Use AttentionWrapper(TE_FLAX_MULTIHEAD)
    Utils->>TE: Use TEDotProductAttention
    Notebook->>Utils: speedometer() with FP8 disabled
    
    Note over Notebook: Step 3: TE with FP8
    Notebook->>TE: fp8_autocast(enabled=True)
    Notebook->>TE: Initialize model within autocast
    Notebook->>Utils: speedometer() with FP8 enabled
    
    Note over Notebook: Step 4: Fused TE
    Notebook->>TE: Create FusedTETransformerLayer
    Note over TE: Uses LayerNormDenseGeneral,<br/>LayerNormMLP
    Notebook->>Utils: speedometer() with FP8 enabled
    
    Note over Notebook: Step 5: Full TransformerLayer
    Notebook->>TE: Use te_flax.TransformerLayer
    Note over TE: Highest optimization level
    Notebook->>Utils: speedometer() with FP8 enabled
    
    Utils->>User: Performance metrics for each step

Contributor

@greptile-apps greptile-apps bot left a comment


3 files reviewed, no comments


…mpl only, removing last mention of Pytorch

Signed-off-by: tdophung <tdophung@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Adds comprehensive JAX quickstart tutorial demonstrating progressive optimization with Transformer Engine. The tutorial walks through 4 implementation iterations of a Transformer encoder layer, showing performance improvements from pure Flax (17.7ms) to fully optimized TE with FP8 (9.7ms).

Major changes:

  • New Jupyter notebook tutorial (quickstart_jax.ipynb) with 16 code cells demonstrating TE integration
  • Utility module (quickstart_jax_utils.py) providing benchmarking functions for forward/backward pass timing
  • Unrelated CI workflow modification to PyTorch installation method

Issues found:

  • Multiple typos in documentation text and code comments (similicity, attetntion, inistead, linnen, buiil-in, QKVoutput)
  • Hardcoded absolute path in notebook will break for users running from different directories
  • Malformed markdown link syntax in notebook
  • Logic error referencing undefined variable in FP8 section
  • Function signature mismatches in utility file between train_step_fn definition and calls
  • Unused parameters (model_init_fn, key) in utility functions
  • CI workflow change appears unrelated to JAX documentation PR

Structure: The tutorial effectively mirrors the PyTorch quickstart guide's architecture, making it easy for users familiar with the PyTorch version to follow along with the JAX implementation.

Confidence Score: 2/5

  • This PR has multiple logic errors and typos that will affect user experience and potentially cause runtime failures
  • Score reflects several critical issues: hardcoded paths that will break for users, logic errors with undefined variables that could cause runtime failures, function signature mismatches in utility code, and numerous typos throughout the documentation. While the tutorial structure and approach are sound, these implementation issues need to be resolved before merge.
  • Pay close attention to docs/examples/quickstart_jax.ipynb (hardcoded paths, undefined variable references, typos) and docs/examples/quickstart_jax_utils.py (function signature mismatches). Also verify whether the .github/workflows/build.yml change is intentional.

Important Files Changed

File Analysis

| Filename | Score | Overview |
|----------|-------|----------|
| docs/examples/quickstart_jax.ipynb | 2/5 | JAX quickstart tutorial with multiple typos, hardcoded paths, malformed markdown links, and logic errors referencing undefined variables |
| docs/examples/quickstart_jax_utils.py | 2/5 | Utility functions for speedometer benchmarking with function signature mismatches and unused parameters |
| .github/workflows/build.yml | 3/5 | CI workflow modification to PyTorch installation - appears unrelated to JAX documentation changes |

Sequence Diagram

sequenceDiagram
    participant User
    participant Notebook as quickstart_jax.ipynb
    participant Utils as quickstart_jax_utils.py
    participant Flax as Flax Modules
    participant TE as Transformer Engine

    User->>Notebook: Run tutorial
    Notebook->>Flax: Create BasicTransformerLayer
    Flax-->>Notebook: Pure JAX/Flax implementation
    Notebook->>Utils: speedometer() benchmark
    Utils->>Utils: create_train_step_fn()
    Utils->>Flax: Forward/backward pass
    Flax-->>Utils: Timing results
    Utils-->>Notebook: Baseline performance
    
    Notebook->>TE: Create BasicTETransformerLayer
    TE-->>Notebook: TE modules (DenseGeneral, LayerNorm)
    Notebook->>Utils: speedometer() benchmark
    Utils->>TE: Forward/backward pass
    TE-->>Utils: Improved timing
    Utils-->>Notebook: TE performance
    
    Notebook->>TE: Enable FP8 with fp8_autocast
    TE-->>Notebook: FP8-enabled model
    Notebook->>Utils: speedometer() with FP8
    Utils->>TE: FP8 forward/backward pass
    TE-->>Utils: FP8 timing
    Utils-->>Notebook: FP8 performance gains
    
    Notebook->>TE: Create FusedTETransformerLayer
    TE-->>Notebook: Fused operations
    Notebook->>Utils: speedometer() benchmark
    Utils->>TE: Fused forward/backward pass
    TE-->>Utils: Fused timing
    Utils-->>Notebook: Best performance
    
    Notebook->>TE: Use TransformerLayer (full TE)
    TE-->>Notebook: Fully optimized layer
    Notebook->>Utils: Final benchmark
    Utils->>TE: Optimized forward/backward
    TE-->>Utils: Optimal timing
    Utils-->>Notebook: Maximum performance

2 files reviewed, 1 comment


Comment on lines +86 to +87
run: |
pip install pybind11[global] einops onnxscript
Contributor


logic: unrelated change to CI workflow in documentation PR

This PyTorch installation change seems unrelated to the JAX quickstart documentation. Should this be in a separate PR?

"- `nn.Dropout`: JAX/Flax Dropout\n",
"- `BasicMLP`: Custom MLP implemented from `nn.Dense\n",
"\n",
"<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n",
Collaborator


"<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n",

This can be removed now that we've removed the custom attention impl

"- `DotProductAttention`: Custom attention from [quickstart_jax_utils.py] (**)(quickstart_jax_utils.py)\n",
"- `nn.Dense`: JAX/Flax Dense layer for projection\n",
"- `nn.Dropout`: JAX/Flax Dropout\n",
"- `BasicMLP`: Custom MLP implemented from `nn.Dense\n",
Collaborator


nit: missing closing backtick ` on nn.Dense here

"source": [
"import transformer_engine.jax as te\n",
"import transformer_engine.jax.flax as te_flax\n",
"from transformer_engine.jax.quantize import is_fp8_available, ScalingMode"
Collaborator


from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

I don't see is_fp8_available or ScalingMode used later in the notebook. Let's remove them

"id": "c9a101d3",
"metadata": {},
"source": [
"## Enabling FP8\n",
Collaborator


nit: Here and below, can we replace FP8 -> Quantization unless we're specifically talking about an FP8-specific recipe? For the first mention of quantization (this title), we can mention FP8 and FP4 explicitly: "Enabling Quantization (FP8 or FP4)".
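
For context on the excerpt below, which already passes an fp8_recipe into te.fp8_autocast, here is a hedged sketch of how a delayed-scaling FP8 recipe could be constructed. The recipe class and arguments shown are assumptions for illustration; the notebook's actual recipe choice is not confirmed here.

```python
# Hypothetical recipe setup; the notebook may configure this differently.
from transformer_engine.common import recipe
import transformer_engine.jax as te

# Delayed scaling with the hybrid format (E4M3 forward, E5M2 backward).
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# Quantized modules are then initialized and applied inside the context,
# e.g. (model, key, and x are placeholders, not defined in this sketch):
# with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
#     params = model.init(key, x, attention_mask=None, deterministic=False)
```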

"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" basic_te_params = basic_te_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
" y = basic_te_transformer.apply(basic_te_params, x, attention_mask=None, deterministic=True)\n",
Collaborator


Why do we need this line? I don't see y used later

y = basic_te_transformer.apply(basic_te_params, x, attention_mask=None, deterministic=True)

"\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" fused_te_params = fused_te_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
" y = fused_te_transformer.apply(fused_te_params, x, attention_mask=None, deterministic=True)\n",
Collaborator


Again here, I don't see y used later. Can this be removed, or is there other reasoning I'm missing? Thanks!

"id": "a45c12c8",
"metadata": {},
"source": [
"Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:"
Collaborator


@tdophung now that the performance matches, let's remove "it provides the highest degree of performance optimization:" since TransformerLayer and individual usage of fused LayerNormMLP, etc. are equivalent

"\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" te_transformer_params = te_transformer.init(key, x, deterministic=False)\n",
" y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)"
Collaborator


Same question as above for y = ...

"- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html\n",
"\n",
"## Let's build a Transformer encoder layer (*)!\n",
"<small>(*) _This was based upon the GPT decoder layer, but for the sake of simplicity and mirroring the PyTorch tutorial whose defaults are without any attention mask, we are setting attention mask here also to 0, making the attention basically an encoder, which does not exist in the GPT architecture. However, since the code support any attention mask here in the TransformerLayer (later in this guide), we will leave it to the audience to try experimenting with different attention masks._</small>\n",
Collaborator


Since we are no longer tied to the PyTorch tutorial and no longer have the custom attention, how hard would it be to use a causal mask so we can remove this caveat? Looks like Flax linen has a built-in function for creating the mask here: https://flax.readthedocs.io/en/v0.6.11/api_reference/_autosummary/flax.linen.make_causal_mask.html
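
As a rough sketch of that suggestion (shapes and variable names are illustrative assumptions, not the notebook's code): flax.linen.make_causal_mask builds a [batch, 1, seq_len, seq_len] causal mask from any array shaped [batch, seq_len]. Note that Flax's convention is 1 = "may attend", so depending on the mask convention the TE module expects, the mask may need to be inverted.

```python
import jax.numpy as jnp
import flax.linen as nn

batch_size, seq_len = 4, 128  # illustrative shapes only

# make_causal_mask only looks at the shape of its input, not the values.
# Result has shape [batch_size, 1, seq_len, seq_len].
causal_mask = nn.make_causal_mask(jnp.ones((batch_size, seq_len)))

# If the attention layer expects a boolean "masked-out" tensor instead of
# Flax's "may attend" convention, invert it first:
masked_out = jnp.logical_not(causal_mask.astype(bool))
```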
