[JAX] Quickstart documentation #2310
Conversation
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The current commit introduces only whitespace cleanup to docs/examples/quickstart_jax_utils.py, removing trailing spaces from multiple lines (48, 60, 80, and others). While these formatting changes align with the repository's pre-commit hooks (black, clang-format), they have no functional impact on the code. However, the file contains a critical pre-existing issue: three functions (is_bf16_supported(), is_fp8_supported(), is_mxfp8_supported()) call get_device_compute_capability(0) without importing this function, which will cause NameError exceptions at runtime. This is a blocking issue that must be resolved before the documentation can be safely used.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax_utils.py | 1/5 | Whitespace cleanup only, but file contains critical missing import for get_device_compute_capability used in three functions |
Confidence score: 1/5
- This PR is not safe to merge due to a critical missing import that will cause runtime failures
- Score reflects the presence of a blocking bug: `get_device_compute_capability` is called on lines 18, 25, and 32 but never imported, causing immediate `NameError` exceptions when any of the three capability-checking functions are invoked
- Pay close attention to `docs/examples/quickstart_jax_utils.py`: add the missing import statement (likely `from transformer_engine.jax import get_device_compute_capability` or similar) and verify all three functions work correctly before merging
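A hypothetical sketch of the fix this review asks for; the import path is the reviewer's own guess ("or similar"), and the comparison threshold and return type of the helper are assumptions, not confirmed by the PR.

```python
from transformer_engine.jax import get_device_compute_capability  # assumed location

def is_fp8_supported() -> bool:
    # Mirrors the capability-check pattern described above; the threshold
    # and the helper's return type are assumptions, not verified.
    return get_device_compute_capability(0) >= 90
```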
1 file reviewed, 1 comment
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The recent changes address the missing import issue that was previously identified. The quickstart_jax_utils.py file now includes the necessary import for get_device_compute_capability, which resolves the undefined reference that would have caused runtime errors when executing the capability check logic.
PR Description Notes:
- "implemtation" should be "implementation" (in iteration #1)
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax_utils.py | 2/5 | Added import for get_device_compute_capability to fix the previously identified missing import issue |
Confidence score: 2/5
- This PR contains critical bugs that will cause runtime failures and should not be merged as-is
- Score reflects severe functional issues in gradient computation logic: the `speedometer` function has argument count mismatches (calls train_step with 5 args but the function expects 4), and the VJP-based train step has broken forward/backward pass implementations that will fail at runtime
- Pay close attention to `quickstart_jax_utils.py` lines 38-48 (argument mismatch in speedometer), lines 67-81 (gradient function parameter issues), and lines 96-107 (broken VJP implementation); also review `quickstart_jax.ipynb` for hardcoded absolute paths at lines 60-61 that will break for other users
2 files reviewed, 14 comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The recent commits address several issues identified in previous reviews, including fixing typos, correcting import statements, and improving variable naming. The most critical remaining issues are: (1) the signature mismatch in train_step_fn where 5 arguments are passed but only 4 are expected (lines 37-39, 45-47 in utils), (2) the vjp_fn call receiving only 1 cotangent when it expects 3 (line 102 in utils), and (3) the closure issues in forward_fn within the VJP implementation. These changes integrate into the documentation suite by providing a JAX/Flax parallel to the existing PyTorch quickstart guide, allowing users to see equivalent implementations across both frameworks.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax_utils.py | 2/5 | Utility functions for JAX quickstart with critical signature mismatches in train step functions and VJP implementation issues |
| docs/examples/quickstart_jax.ipynb | 4/5 | Comprehensive JAX/Flax quickstart notebook demonstrating four progressive Transformer implementations with timing benchmarks |
Confidence score: 2/5
- This PR contains critical logic errors that will prevent the code from running correctly, particularly around function signatures and VJP usage
- Score reflects multiple runtime-breaking issues: the train_step_fn argument count mismatch (5 passed vs 4 expected), vjp_fn expecting 3 cotangents but receiving only 1, and forward_fn closure issues that will cause failures when executed
- Pay close attention to docs/examples/quickstart_jax_utils.py lines 37-47 (speedometer train_step_fn calls), lines 53-82 (create_train_step_fn signature), and lines 98-117 (VJP implementation)
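For readers following along, here is a minimal, hypothetical sketch of how the four-argument signature discussed above could be kept consistent between the JIT-compiled train step and its caller. It assumes `model_apply_fn` behaves like a Flax `Module.apply`; it is not the PR's actual code.

```python
import jax
import jax.numpy as jnp

def create_train_step_fn(model_apply_fn):
    """Hypothetical helper; `model_apply_fn` is assumed to be a Flax-style apply."""

    def loss_fn(variables, x, output_grad, dropout_key):
        y = model_apply_fn(variables, x, rngs={"dropout": dropout_key})
        # Contract the output with the incoming gradient to get a scalar "loss".
        return jnp.sum(y * output_grad)

    @jax.jit
    def train_step_fn(variables, x, output_grad, dropout_key):
        # Differentiates with respect to `variables` (the first argument).
        loss, grads = jax.value_and_grad(loss_fn)(variables, x, output_grad, dropout_key)
        return loss, grads

    return train_step_fn

# The caller (e.g. a speedometer) must then use the same four arguments:
# loss, grads = train_step_fn(variables, x, output_grad, dropout_key)
```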
2 files reviewed, 4 comments
Greptile Overview
Greptile Summary
This review covers only the most recent changes made since the last review, not the entire PR. The PR adds a comprehensive JAX/Flax quickstart notebook (docs/examples/quickstart_jax.ipynb) and supporting utilities (docs/examples/quickstart_jax_utils.py) that mirror the existing PyTorch guide. The new documentation demonstrates progressive optimization of Transformer layers through four iterations: pure JAX/Flax baseline, basic Transformer Engine modules, mixed fused/unfused TE modules, and full TransformerLayer with FP8 support. Performance benchmarks are included showing training step time improvements from 27.9ms (baseline) to 8.0ms (FP8-optimized). The utilities file provides performance measurement tools, JIT-compiled training step creators using both jax.grad and explicit VJP approaches, baseline attention/MLP implementations, and parameter-sharing functions to ensure fair comparisons between Flax and TE models.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 4.5/5 | Comprehensive quickstart notebook demonstrating four progressive optimization stages with embedded performance benchmarks; minor type annotation issues (layernorm_eps declared as int instead of float on lines 360, 528) |
| docs/examples/quickstart_jax_utils.py | 2/5 | Utility functions for benchmarking, training step creation, and parameter sharing; contains critical logic errors including function signature mismatch in speedometer (calls with 5 args, function expects 4), incorrect parameter usage in create_train_step_fn_vjp where forward_fn and vjp_fn don't use provided parameters, unused parameters (model_init_fn, key), and shallow copy issues in parameter sharing functions |
Confidence score: 2/5
- This PR contains critical logic errors in the utilities file that will cause runtime failures when the notebook is executed
- Score reflects multiple function signature mismatches, incorrect parameter handling in VJP-based training steps, and shallow copying issues that could lead to unintended mutation of nested parameter dictionaries; the notebook itself is well-structured but depends entirely on the broken utilities
- Pay close attention to `docs/examples/quickstart_jax_utils.py` lines 31-47 (signature mismatch in `speedometer`), lines 98-116 (`forward_fn`/`vjp_fn` parameter usage), and lines 214-438 (all three parameter-sharing functions use shallow `copy()` on nested dicts)
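For context on the cotangent count mentioned in these reviews, here is a hedged sketch of a VJP-based train step, assuming the model returns a single output array so `jax.vjp` expects exactly one cotangent; the names mirror the review but the implementation is illustrative, not the PR's.

```python
import jax
import jax.numpy as jnp

def create_train_step_fn_vjp(model_apply_fn):
    """Illustrative only; assumes `variables` is a dict with a "params" entry."""

    @jax.jit
    def train_step_fn(variables, x, output_grad, dropout_key):
        def forward_fn(params):
            # Close over everything except the differentiated parameters.
            return model_apply_fn({**variables, "params": params}, x,
                                  rngs={"dropout": dropout_key})

        out, vjp_fn = jax.vjp(forward_fn, variables["params"])
        # One primal output, therefore exactly one cotangent.
        (param_grads,) = vjp_fn(output_grad)
        loss = jnp.sum(out * output_grad)
        return loss, param_grads

    return train_step_fn
```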
2 files reviewed, no comments
Signed-off-by: tdophung <tdophung@nvidia.com>
…NVIDIA#2308) Change the pyTorch installation to CUDA 13 in Build All GitHub action to match the version in the JAX container Signed-off-by: Przemek Tredak <ptredak@nvidia.com> Signed-off-by: tdophung <tdophung@nvidia.com>
2796e91 to 733d61b (Compare)
Greptile Overview
Greptile Summary
This PR adds a comprehensive JAX-specific quickstart tutorial notebook that demonstrates four progressive implementations of a transformer layer: pure JAX/Flax, basic Transformer Engine (TE) modules, fused TE modules, and the full TE TransformerLayer with optional FP8 support. The notebook includes performance benchmarking utilities and parameter-sharing helpers to enable fair comparisons between implementations. A minor build infrastructure change updates the PyTorch dependency installation in the GitHub Actions workflow to explicitly use CUDA 13.0 wheels, ensuring compatibility with the JAX container environment. This documentation effort mirrors the structure of the existing PyTorch quickstart guide, providing JAX users with a dedicated migration and optimization path.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| .github/workflows/build.yml | 5/5 | Split torch installation into separate command with explicit CUDA 13.0 index URL |
| docs/examples/quickstart_jax.ipynb | 3/5 | Added comprehensive JAX quickstart notebook with 4 progressive transformer implementations |
| docs/examples/quickstart_jax_utils.py | 2/5 | Added utility functions for benchmarking and parameter sharing between JAX and TE models |
Confidence score: 2/5
- This PR contains critical runtime errors that will prevent the tutorial from executing successfully
- The main issue is a function signature mismatch in `quickstart_jax_utils.py`: `speedometer()` calls `train_step_fn` with 5 arguments but `create_train_step_fn()` returns a function expecting only 4 parameters, causing a guaranteed TypeError
- Additional concerns include variable scope issues in the notebook (line 795 references `te_transformer_params_template`, which may not be in scope after FP8 initialization) and unused parameters indicating incomplete implementation
- Pay close attention to `docs/examples/quickstart_jax_utils.py` lines 37-46 (function call mismatch) and the notebook cells around lines 787-801 (variable scoping)
Sequence Diagram
sequenceDiagram
participant User
participant Notebook as quickstart_jax.ipynb
participant Utils as quickstart_jax_utils.py
participant JAX as JAX/Flax
participant TE as Transformer Engine
User->>Notebook: Execute notebook cells
Notebook->>JAX: Import jax, jax.numpy, flax.linen
Notebook->>TE: Import transformer_engine.jax
Notebook->>Utils: Import quickstart_jax_utils
Note over Notebook,JAX: 1. Build BasicTransformerLayer (pure JAX/Flax)
Notebook->>JAX: Initialize BasicTransformerLayer
JAX-->>Notebook: Return initialized model
Notebook->>JAX: init(key, x) to create params
JAX-->>Notebook: Return params
Notebook->>Utils: speedometer(model_apply_fn, variables, input, output_grad)
Utils->>Utils: create_train_step_fn() - JIT compile fwd/bwd
Utils->>JAX: Run warmup iterations
Utils->>JAX: Run timing iterations (forward + backward)
JAX-->>Utils: Return loss and gradients
Utils-->>Notebook: Print mean time
Note over Notebook,TE: 2. Build BasicTETransformerLayer (TE modules)
Notebook->>TE: Initialize BasicTETransformerLayer
TE-->>Notebook: Return TE model
Notebook->>TE: init(key, x) to create TE params
TE-->>Notebook: Return TE params template
Notebook->>Utils: share_parameters_with_basic_te_model()
Utils-->>Notebook: Return shared params
Notebook->>Utils: speedometer() with TE model
Utils->>TE: Run timing iterations with TE layers
TE-->>Utils: Return loss and gradients
Utils-->>Notebook: Print improved mean time
Note over Notebook,TE: 3. Build FusedTETransformerLayer (fused ops)
Notebook->>TE: Initialize FusedTETransformerLayer
TE-->>Notebook: Return fused TE model
Notebook->>TE: init() with LayerNormDenseGeneral, LayerNormMLP
TE-->>Notebook: Return fused params
Notebook->>Utils: share_fused_parameters_with_basic_te_model()
Utils-->>Notebook: Return shared fused params
Notebook->>Utils: speedometer() with fused model
Utils->>TE: Run timing with fused kernels
TE-->>Utils: Return loss and gradients
Utils-->>Notebook: Print further improved time
Note over Notebook,TE: 4. Use TE TransformerLayer with FP8
Notebook->>TE: Create fp8_recipe (DelayedScaling)
Notebook->>TE: fp8_autocast(enabled=True, fp8_recipe)
Notebook->>TE: Initialize TransformerLayer within fp8_autocast
TE-->>Notebook: Return params with fp8_metas collection
Notebook->>Utils: share_parameters_with_transformerlayer_te_model()
Utils-->>Notebook: Return shared params with fp8_metas
Notebook->>Utils: speedometer() with fp8_autocast_kwargs
Utils->>TE: Run timing with FP8 precision
TE-->>Utils: Return loss and gradients (FP8)
Utils-->>Notebook: Print fastest mean time
Notebook-->>User: Display performance comparison
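The warmup/timing flow in this diagram can be summarized by a small sketch like the following; it illustrates the pattern with assumed argument names and is not the utility file's actual code.

```python
import time
import jax

def speedometer(train_step_fn, variables, x, output_grad, dropout_key,
                warmup_iters=10, timing_iters=50):
    # Warmup triggers JIT compilation so it is excluded from the measurement.
    for _ in range(warmup_iters):
        loss, grads = train_step_fn(variables, x, output_grad, dropout_key)
    jax.block_until_ready((loss, grads))

    start = time.time()
    for _ in range(timing_iters):
        loss, grads = train_step_fn(variables, x, output_grad, dropout_key)
    # Block on async dispatch so the timer measures real device work.
    jax.block_until_ready((loss, grads))
    print(f"Mean time: {(time.time() - start) / timing_iters * 1000:.3f} ms")
```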
3 files reviewed, 1 comment
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review. The developer has cleaned up the JAX quickstart notebook by removing execution artifacts (output cells, execution counts), fixing typos and type annotations, and correcting markdown link syntax. The quickstart_jax_utils.py file was updated to remove an unused key parameter from train_step_fn calls and add a 2025 copyright header. However, a critical logic error remains unfixed in the notebook at line 772: the code references te_transformer_params_template which is undefined in the FP8 context (it was only defined outside the FP8 block in line 671). This variable should be te_transformer_params instead.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 2/5 | Cleaned notebook execution artifacts and fixed typos/types, but introduced critical undefined variable bug in FP8 initialization block (line 772) |
| docs/examples/quickstart_jax_utils.py | 3/5 | Removed unused key parameter from train step functions and added 2025 copyright; shallow copy operations may allow unintended parameter mutation |
Confidence score: 2/5
- This PR cannot be merged safely due to a critical logic error that will cause runtime failure.
- Score reflects one critical bug (undefined variable `te_transformer_params_template` at line 772 in the notebook) that will break FP8 initialization when users try to execute the notebook. The developer needs to replace `te_transformer_params_template` with `te_transformer_params` in the FP8 context. Additionally, the shallow copy operations in the utils file may cause subtle parameter-sharing bugs if nested dictionaries are mutated.
- The notebook file (quickstart_jax.ipynb) requires immediate attention, specifically line 772 in the FP8 initialization block where an undefined variable is referenced.
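To illustrate the shallow-copy concern raised here (a toy example, not the PR's parameter-sharing code): a top-level copy of a nested params dict still shares the inner dicts, so in-place edits leak back to the source.

```python
import copy

params = {"params": {"dense": {"kernel": [1.0, 2.0]}}}

shallow = dict(params)        # top-level copy only; inner dicts are shared
deep = copy.deepcopy(params)  # fully independent nested structure

shallow["params"]["dense"]["kernel"][0] = 0.0
print(params["params"]["dense"]["kernel"][0])  # 0.0 -- the mutation leaked through
print(deep["params"]["dense"]["kernel"][0])    # 1.0 -- unaffected
```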
2 files reviewed, 2 comments
d9859d1 to b742244 (Compare)
Greptile Overview
Greptile Summary
This review covers only the latest changes to two documentation files: docs/examples/quickstart_jax.ipynb and docs/examples/quickstart_jax_utils.py. Recent commits have addressed previous formatting and documentation issues (markdown links, type annotations, typos), but critical runtime bugs remain unresolved. The quickstart notebook demonstrates progressive optimization of a JAX/Flax transformer using Transformer Engine, but the utility file contains a function signature mismatch that will cause immediate failures when users attempt to run the notebook.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 2/5 | Cleared execution artifacts and fixed minor issues, but contains a critical NameError on line 772 where undefined variable te_transformer_params_template is referenced in FP8 initialization block |
| docs/examples/quickstart_jax_utils.py | 1/5 | Added benchmarking utilities for JAX/Flax models with critical signature mismatch—train_step_fn called with 4 arguments (lines 41, 47) but function signatures expect 4 parameters (dropout_key required); will fail immediately at runtime |
Confidence score: 1/5
- This PR will cause immediate runtime failures when users attempt to follow the quickstart guide due to function signature mismatches and undefined variable references
- Score reflects two critical blocking issues: (1) the `train_step_fn` signature mismatch in `quickstart_jax_utils.py`, where calls on lines 41 and 47 provide 4 arguments but the underlying functions require `dropout_key` as a parameter, and (2) the undefined `te_transformer_params_template` reference on line 772 of the notebook that will raise NameError
- Both utility file functions (`create_train_step_fn` and `create_train_step_fn_vjp`) and the FP8 initialization block in the notebook require immediate correction before this documentation can be published
2 files reviewed, no comments
docs/examples/quickstart_jax.ipynb
Outdated
| " )\n", | ||
| " x = attention(q, k, v, attention_mask, deterministic=deterministic)\n", | ||
| " \n", | ||
| " # Attention built-in. Comment out if not used\n", |
Can we add a bit more detail to this comment, something like: If you would like to use Flax's built-in attention implementation, you can use this instead
Want to clarify to users that this is an alternative option and not some debug code we left in by mistake
I don't think this comment really got resolved. Let me make a suggestion for how I think it could look instead.
Greptile Overview
Greptile Summary
This PR introduces a comprehensive JAX/Flax quickstart tutorial that mirrors the PyTorch quickstart guide. The tutorial demonstrates four progressive implementations of a Transformer layer, showing performance improvements from pure JAX/Flax (~28ms) to fully optimized TE with FP8 (~8ms).
Key Changes:
- New `quickstart_jax.ipynb` notebook with 4 transformer implementations (pure JAX, basic TE, fused TE, full TE with FP8)
- Supporting utility file `quickstart_jax_utils.py` with benchmarking functions and custom attention/MLP modules
- Workflow update to use specific PyTorch CUDA index URL for build consistency
Issues Found:
- Several typos in documentation (already flagged in previous comments)
- Malformed markdown link in notebook cell
- Unused `datasets` import
- Minor documentation link format issues
Positive Aspects:
- Tutorial provides clear progressive optimization path
- Benchmarking utilities are well-structured and correct
- Good pedagogical approach with 4 iterations showing incremental improvements
- Code is functionally correct despite some previous incorrect review comments
Confidence Score: 4/5
- This PR is safe to merge after addressing minor documentation issues
- Score reflects that this is a documentation-only PR with correct functional code. The issues found are primarily cosmetic (typos, formatting) rather than logical errors. Several previous comments incorrectly flagged non-issues (e.g., VJP signature, missing imports that don't exist). The actual issues are: unused import, malformed markdown link, and typos - all easily fixable and non-blocking for a draft PR.
- docs/examples/quickstart_jax.ipynb needs attention for typos and formatting issues. The utility file and workflow changes are solid.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| .github/workflows/build.yml | 5/5 | Workflow updated to install torch from specific CUDA index URL - improves build consistency |
| docs/examples/quickstart_jax.ipynb | 3/5 | New JAX quickstart tutorial with 4 transformer implementations; contains typos and malformed markdown link already flagged |
| docs/examples/quickstart_jax_utils.py | 4/5 | Utility functions for speedometer benchmarking and custom attention/MLP modules; code is correct despite some previous incorrect comments |
Sequence Diagram
sequenceDiagram
participant User as User/Tutorial
participant Basic as BasicTransformerLayer
participant BasicTE as BasicTETransformerLayer
participant Fused as FusedTETransformerLayer
participant Full as TransformerLayer (TE)
participant Utils as quickstart_jax_utils
User->>Basic: Initialize with pure JAX/Flax
Basic->>Basic: nn.LayerNorm + nn.Dense (QKV)
Basic->>Utils: DotProductAttention
Utils-->>Basic: attention output
Basic->>Basic: nn.Dense (projection) + nn.Dropout
Basic->>Utils: BasicMLP
Utils-->>Basic: MLP output
Basic-->>User: ~28ms training step
User->>BasicTE: Replace with TE modules
BasicTE->>BasicTE: te_flax.LayerNorm + te_flax.DenseGeneral
BasicTE->>Utils: DotProductAttention
Utils-->>BasicTE: attention output
BasicTE->>BasicTE: te_flax.DenseGeneral + nn.Dropout
BasicTE->>BasicTE: BasicTEMLP (te_flax.DenseGeneral)
BasicTE-->>User: ~17ms training step
User->>Fused: Use fused TE operations
Fused->>Fused: te_flax.LayerNormDenseGeneral (fused)
Fused->>Utils: DotProductAttention
Utils-->>Fused: attention output
Fused->>Fused: te_flax.DenseGeneral + nn.Dropout
Fused->>Fused: te_flax.LayerNormMLP (fused)
Fused-->>User: ~18ms training step
User->>Full: Use full TE TransformerLayer
Full->>Full: Complete optimized implementation
Full-->>User: ~12ms training step
User->>Full: Enable FP8 with fp8_autocast
Full->>Full: FP8 precision compute
Full-->>User: ~8ms training step
3 files reviewed, 1 comment
There seems to be some issue with the commits getting mixed up. Could you rebase on top of the current main, @tdophung?
docs/examples/quickstart_jax.ipynb
Outdated
| "- `DotProductAttention`: `DotProductAttention` from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n", | ||
| "- `Projection`: `nn.Dense` (JAX/Flax)\n", | ||
| "- `Dropout`: `nn.Dropout` (JAX/Flax)\n", | ||
| "- `MLP`: `BasicMLP` from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n", |
@ksivaman putting on your radar that this will require more substitutions when we create the docs package.
| "- `nn.Dropout`: JAX/Flax Dropout\n", | ||
| "- `BasicMLP`: Custom MLP from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n", | ||
| "\n", | ||
| "<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n", |
I don't like this sentence - it sounds too informal. A quick pass with copilot gave this instead:
The code snippet below also includes commented examples demonstrating how to utilize the built-in attention sub-layer from either pure Flax or TE Flax. These alternatives are provided should you prefer them over the custom attention implementation found in quickstart_jax_utils.py. This reference is intended to offer insight into the general structure and approach used for implementing attention mechanisms in our source code.
I don't really like that one either (although it sounds better), but maybe we could use it to get something in between.
What is the purpose of having both the flax.linen.MultiHeadDotProductAttention and the self-built attention? If the self-built attention is implemented correctly, they should have exactly the same ops as the linen one, right?
I think it will be less confusing to simply start with the linen attention module directly, then swap it with the TE one. In that way, we can expose those two modules in this tutorial directly rather than hiding their details in the util file.
I wonder if it is helpful to have the self-built attention, since it shows how attention works and also stays close to the PyTorch tutorial
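For readers unfamiliar with what the "self-built" attention computes, a minimal scaled dot-product attention in JAX looks roughly like this; it is a generic illustration, not the attention implementation shipped in quickstart_jax_utils.py.

```python
import jax.numpy as jnp
from jax import nn

def dot_product_attention(q, k, v, mask=None):
    # q, k, v: [batch, seq, heads, head_dim]; mask broadcastable to [batch, heads, q_len, k_len]
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    if mask is not None:
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```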
Greptile Overview
Greptile Summary
Adds comprehensive JAX quickstart documentation mirroring the PyTorch tutorial structure, demonstrating progressive performance optimization through four iterations of a Transformer layer implementation.
Key Changes:
- New Jupyter notebook `quickstart_jax.ipynb` (699 lines) with hands-on tutorial
- Supporting utility file `quickstart_jax_utils.py` (267 lines) containing reusable components
- CI workflow update to separate PyTorch installation with explicit CUDA 13.0 index URL
Tutorial Progression:
- Pure JAX/Flax baseline (~26ms) - Custom attention and MLP implementations
- Basic TE modules (~16ms) - Using `te_flax.LayerNorm` and `DenseGeneral`
- Fused TE modules (~16ms) - Using `LayerNormDenseGeneral` and `LayerNormMLP`
- Full TE TransformerLayer (~11ms) - Complete optimized implementation
- FP8 enabled (~7ms) - 73% speedup over baseline with FP8 precision
Issues Found:
- Multiple typos in documentation ('similicity', 'attetntion', 'inistead', 'linnen', 'buiil-in')
- Hardcoded absolute path `os.chdir('/workspace/docs/examples')` will break for users
- Malformed markdown link syntax `[quickstart_jax_utils.py] (**)(quickstart_jax_utils.py)`
- Extra quote in string: `'no_mask''` should be `'no_mask'`
- Unused `datasets` import in notebook
- `layernorm_eps` defined as `int` instead of `float` (style issue, doesn't affect functionality)
Confidence Score: 3/5
- Safe to merge after fixing documentation issues - no functional bugs found, only typos and hardcoded paths that will impact user experience
- Score of 3 reflects that while the code is functionally correct and the tutorial successfully demonstrates TE capabilities, there are several documentation quality issues that should be fixed: hardcoded paths will break for users running the notebook outside the expected environment, multiple typos reduce professionalism, and malformed markdown links affect navigation. The CI change is safe and improves dependency management.
- docs/examples/quickstart_jax.ipynb requires fixing hardcoded paths (line 59) and typos throughout; docs/examples/quickstart_jax_utils.py needs typo corrections
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| .github/workflows/build.yml | 5/5 | Splits PyTorch installation into separate command with explicit CUDA index URL for better dependency management |
| docs/examples/quickstart_jax_utils.py | 3/5 | New utility file with attention implementations and speedometer function; contains several typos and documentation formatting issues flagged in previous comments |
| docs/examples/quickstart_jax.ipynb | 3/5 | New JAX quickstart tutorial notebook demonstrating progressive optimization with TE; contains hardcoded paths, typos, and malformed markdown links flagged in previous comments |
Sequence Diagram
sequenceDiagram
participant User
participant Notebook as quickstart_jax.ipynb
participant Utils as quickstart_jax_utils.py
participant JAX as JAX/Flax
participant TE as TransformerEngine
User->>Notebook: Run tutorial cells
Notebook->>JAX: Import jax, jax.numpy, flax.linen
Notebook->>Utils: Import quickstart_jax_utils
Notebook->>TE: Import transformer_engine.jax
Note over Notebook: Iteration 1: Pure JAX/Flax
Notebook->>JAX: Define BasicTransformerLayer
Notebook->>Utils: Use AttentionWrapper(CUSTOM_DOT_PRODUCT)
Notebook->>Utils: Use BasicMLP
Notebook->>Utils: speedometer() - measure baseline
Note over Notebook: Iteration 2: Basic TE
Notebook->>TE: Use te_flax.LayerNorm, DenseGeneral
Notebook->>Utils: Use AttentionWrapper(TE_FLAX_MULTIHEAD)
Notebook->>Utils: speedometer() - faster performance
Note over Notebook: Iteration 3: Fused TE
Notebook->>TE: Use LayerNormDenseGeneral, LayerNormMLP
Notebook->>Utils: speedometer() - similar to iteration 2
Note over Notebook: Iteration 4: Full TE TransformerLayer
Notebook->>TE: Use te_flax.TransformerLayer
Notebook->>Utils: speedometer() - best performance
Note over Notebook: Enable FP8
Notebook->>TE: fp8_autocast context manager
Notebook->>TE: Initialize with FP8 recipe
Notebook->>Utils: speedometer(fp8_autocast_kwargs) - fastest
Utils->>JAX: create_train_step_fn()
Utils->>JAX: jax.value_and_grad() for gradients
Utils->>JAX: jax.jit() for compilation
Utils-->>Notebook: Return timing results
3 files reviewed, no comments
Wonder why I cannot reply to Pawel's suggestion to do tabbing in HTML. But I intend for this PR to just be the notebook first, and to handle the HTML display in a separate PR since that will also touch the Python quickstart guide. This PR is just for a functional ipynb.
…some footnotes in the quick start notebook Signed-off-by: tdophung <tdophung@nvidia.com>
…comaptibility with speedometer Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…Layer Signed-off-by: tdophung <tdophung@nvidia.com> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: tdophung <tdophung@nvidia.com>
…e and unfused TE impls to achieve same performance (removing extra dropout layer in fused layers. Also some minor wording changes Signed-off-by: tdophung <tdophung@nvidia.com> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: tdophung <tdophung@nvidia.com>
b70bddf to c6b6ca3 (Compare)
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
Adds comprehensive JAX quickstart documentation demonstrating progressive optimization of Transformer layers using TransformerEngine. The tutorial walks through 4 implementations showing increasing performance gains (26ms → 7ms).
Key Changes:
- New Jupyter notebook (`quickstart_jax.ipynb`) with 4 progressive transformer implementations
- Supporting utility module (`quickstart_jax_utils.py`) with attention wrappers and performance benchmarking
- CI workflow update for PyTorch installation with explicit CUDA index URL
Implementation Progression:
- BasicTransformerLayer: Pure JAX/Flax baseline (~26ms per step)
- BasicTETransformerLayer: Using TE's `DenseGeneral` and `LayerNorm` modules (~16ms)
- FusedTETransformerLayer: Leveraging fused kernels like `LayerNormDenseGeneral` (~16ms)
- Full TransformerLayer with FP8: Complete TE layer with FP8 precision (~7ms)
Strengths:
- Well-structured tutorial mirroring PyTorch quickstart for consistency
- Clear performance comparisons with timing measurements
- Good documentation of FP8 initialization requirements
- Comprehensive utility code with multiple attention implementations
Previous Review Feedback Addressed:
Most syntax issues (typos, formatting) and several logic concerns from previous reviews appear to have been addressed in recent commits, including fixing attention implementation inconsistencies and adding proper enum for attention types.
Confidence Score: 4/5
- This documentation PR is safe to merge with minor issues already flagged in previous reviews
- Score reflects that this is a documentation-only PR with comprehensive tutorial content. The code examples are well-structured and functional. Previous review comments have identified syntax issues (typos, formatting) that should be addressed, but these are non-blocking for a documentation PR. The utility code is correct and the notebook provides educational value. The CI workflow change is sensible and low-risk.
- All files are in good shape. The notebook may have minor typos/formatting issues noted in previous comments that should be cleaned up before merge.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 4/5 | Comprehensive JAX quickstart notebook with 4 progressive transformer implementations. Most previous syntax/typo issues appear addressed. Minor formatting improvements needed. |
| docs/examples/quickstart_jax_utils.py | 5/5 | Well-structured utility module with attention implementations and performance testing. Code is correct and follows best practices. |
Sequence Diagram
sequenceDiagram
participant User as User/Notebook
participant Utils as quickstart_jax_utils
participant JAX as JAX/Flax
participant TE as TransformerEngine
Note over User: 1. BasicTransformerLayer (Pure JAX/Flax)
User->>JAX: Initialize BasicTransformerLayer
JAX-->>User: params
User->>Utils: speedometer(model, params, data)
Utils->>Utils: create_train_step_fn()
Utils->>JAX: Forward/Backward Pass
JAX-->>Utils: loss, gradients
Utils-->>User: Mean time: ~26ms
Note over User: 2. BasicTETransformerLayer (TE modules)
User->>TE: Initialize with DenseGeneral, LayerNorm
TE-->>User: te_params
User->>Utils: speedometer(model, params, data)
Utils->>TE: Forward/Backward with TE modules
TE-->>Utils: loss, gradients
Utils-->>User: Mean time: ~16ms
Note over User: 3. FusedTETransformerLayer (Fused TE)
User->>TE: Initialize with LayerNormDenseGeneral, LayerNormMLP
TE-->>User: fused_params
User->>Utils: speedometer(model, params, data)
Utils->>TE: Forward/Backward with fused kernels
TE-->>Utils: loss, gradients
Utils-->>User: Mean time: ~16ms
Note over User: 4. TransformerLayer with FP8
User->>TE: fp8_autocast context
User->>TE: Initialize TransformerLayer
TE-->>User: params + fp8_metas
User->>Utils: speedometer(model, params, data, fp8_enabled=True)
Utils->>TE: Forward/Backward with FP8 precision
TE-->>Utils: loss, gradients
Utils-->>User: Mean time: ~7ms
2 files reviewed, no comments
| "- `nn.Dropout`: JAX/Flax Dropout\n", | ||
| "- `BasicMLP`: Custom MLP from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n", | ||
| "\n", | ||
| "<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n", |
What is the purpose of having both the flax.linen.MultiHeadDotProductAttention and the self-built attention? If the self-built attention is implemented correctly, they should have exactly the same ops as the linen one, right?
I think it will be less confusing to simply start with the linen attention module directly, then swap it with the TE one. In that way, we can expose those two modules in this tutorial directly rather than hiding their details in the util file.
docs/examples/quickstart_jax.ipynb
Outdated
| "name": "stdout", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "Mean time: 16.173095703125 ms\n" |
I suggest we add an option to switch between the two Attention modules in the call function, then we have 2 blocks of time measurement, one for the FusedTETransformerLayer + linen.DPA, one for FusedTETransformerLayer + te_flax.DPA.
In the latest version, I changed the structure of the tutorial to bring FP8 up and Fused down as suggested before, and added the comparison between BasicTETransformerLayer + linen.DPA and BasicTETransformerLayer + te_flax.DPA already, so I do not think it is necessary to do the same for Fused.
Also, since Fused is now only under the FP8 section and there is no XLA FFI for linen.DPA with fp8 autocast, I could not have done this anyway. Please let me know what you think.
| "id": "a45c12c8", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:" |
This is quite misleading, as it is right now.
It shows that the TE Transformer layer is 11ms while the FusedTETransformerLayer is 16ms (which was measured with linen.DPA). So with my above suggestions, we should be able to show that TE Transformer and FusedTETransformerLayer have the same runtime.
I don't think the TransformerLayer could be any more performant than the FusedTETransformerLayer, as the performant part comes from the fact that we try to fuse kernels together in the fused modules. In other words, it should not be different if users use our TransformerLayer or make their own FusedTETransformerLayer by stacking the fused modules.
So the TransformerLayer is there just for convenient usage.
Currently the FusedTETransformerLayer and the BasicTETransformerLayer are both at 16ms, and they both already use te_flax.DPA (from transformer_engine.jax.flax.transformer import DotProductAttention), NOT linen.DPA.
Before this, I was not aware that the TE TransformerLayer (11ms) is not supposed to be much faster than the other two. I can try to take a look at why that is the case (that the other two are at 16ms even with te_flax.DPA). But please comment if you see anything more apparent.
I have figured out why this was the case. Simply put, I passed the wrong expected input shape into the TransformerLayer while passing the correct expected shape into the Fused version.
I initialized the input in the same shape as what PyTorch would expect, [seq_len, batch, hidden], while the TransformerLayer takes [batch, seq_len, hidden] by default. In te_flax.DPA I correctly set transpose_batch_sequence=True in the argument to account for this, but in TransformerLayer I did not. Therefore max_seqlen was set wrong and TransformerLayer did a lot less work than it was supposed to (batch=4, seqlen=2048).
This is fixed as of the most recent commit.
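A rough sketch of the shape-convention fix described above; the attribute names follow transformer_engine.jax.flax as I understand them, so treat the exact signature as an assumption rather than the notebook's final code.

```python
import transformer_engine.jax.flax as te_flax

# Data laid out PyTorch-style as [seq_len, batch, hidden] needs the layer to be
# told so; otherwise it assumes [batch, seq_len, hidden] and max_seqlen is wrong.
layer = te_flax.TransformerLayer(
    hidden_size=1024,
    mlp_hidden_size=4096,
    num_attention_heads=16,
    transpose_batch_sequence=True,  # input is [seq_len, batch, hidden]
)
```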
@tdophung now that the performance matches, let's remove "it provides the highest degree of performance optimization:" since TransformerLayer and individual usage of fused LayerNormMLP, etc. are equivalent
This is unrelated and will go away after rebasing, right?
right
docs/examples/quickstart_jax.ipynb
Outdated
| "\n", | ||
| "## Overview\n", | ||
| "\n", | ||
| "Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n", |
It provides support for 8-bit floating point (FP8) precision on Hopper GPUs
We should update this to a newer message. I cc'd you on the same discussion about our TE GitHub description now that we support nvfp4. You can copy whichever wording we decide there into here
This is the wording that the GitHub repo description got updated to
8-bit and 4-bit floating point (FP8 and FP4) precision on Hopper, Ada and Blackwell GPUs
I have already updated it to
"It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs"
which is more precise because we do not have 4-bit precision on Hopper or Ada?
Oh nice! Missed that. I agree, that sounds good, thanks!
docs/examples/quickstart_jax.ipynb
Outdated
| "\n", | ||
| "<b>Important: FP8 Metadata Initialization</b>\n", | ||
| "\n", | ||
| "When using FP8, the model **must be initialized within the `fp8_autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `fp8_autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n", |
Is this description of fp8_metas copied from the PyTorch tutorial? I think PyTorch needs this metadata to store all their scaling factors for every recipe since it's not easily supported in PyTorch. But in TE/JAX, we don't need this for scaling factors. We only need this for recipes that store state across iterations, which is only DelayedScaling's amax history and scale
We also do have some partial support for creating quantizers outside of autocast contexts, so I don't think the ScopeCollectionNotFound error applies to JAX. But if you have observed, lmk and I can take a look. Thanks!
This was not copied from PyTorch tutorial.
I actually observed this while creating the JAX tutorial, hence the warning. Please take a look!
Good to know, I'll take a look
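For reference, the initialization pattern the notebook and the sequence diagrams above describe looks roughly like this; the module paths and arguments follow transformer_engine.jax, but the exact signatures here are assumptions, not a verified excerpt of the PR.

```python
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common.recipe import DelayedScaling

fp8_recipe = DelayedScaling()
key = jax.random.PRNGKey(0)
x = jnp.zeros((4, 2048, 1024), dtype=jnp.bfloat16)  # [batch, seq, hidden]

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    layer = te_flax.TransformerLayer(
        hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16
    )
    # Initializing inside the context is what creates the `fp8_metas`
    # collection discussed in this thread (for the DelayedScaling recipe).
    variables = layer.init(key, x)
```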
…ch, ...] instead of [batch, sequence,...] Signed-off-by: tdophung <tdophung@nvidia.com>
…or fuse to take effect because quantization exist as suggested. Also make TransformerLayer perf get closer to Fused by setting hidden_dropout=0 Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile Overview
Greptile Summary
This PR adds a comprehensive JAX quickstart tutorial that mirrors the existing PyTorch guide, demonstrating progressive optimization of a Transformer layer through 5 iterations:
- Iteration 1: Pure JAX/Flax baseline (28.61ms)
- Iteration 2: Basic TE modules with `DenseGeneral` and `LayerNorm` (17.11ms, ~40% speedup)
- Iteration 3: TE with FP8 precision (13.34ms, ~53% speedup)
- Iteration 4: Fused TE modules with `LayerNormDenseGeneral` and `LayerNormMLP` (13.05ms, ~54% speedup)
- Iteration 5: Full `TransformerLayer` module (12.95ms, ~55% speedup)
The tutorial includes a helper utilities file (quickstart_jax_utils.py) that provides modular attention implementations and benchmarking tools. The code is well-structured with clear explanations of each optimization step.
Previous review comments have already identified and addressed the main issues: typos in documentation, formatting inconsistencies, and minor style improvements. The workflow change updates PyTorch installation to use CUDA 13.0 index.
Confidence Score: 4/5
- Safe to merge after addressing the previously flagged syntax/style issues
- This is a documentation-only PR adding a tutorial notebook. The code is functional (as evidenced by the execution outputs) and well-structured. Previous comments have already identified the main issues (typos, formatting, style suggestions). No critical logical errors or security concerns exist. Score of 4 reflects that while the content is solid, the previously flagged corrections should be applied before merge.
- Both files need minor cleanup for typos and style improvements already flagged in previous comments
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 4/5 | Comprehensive JAX quickstart tutorial with 4 implementation iterations, minor typos already flagged |
| docs/examples/quickstart_jax_utils.py | 4/5 | Helper utilities for JAX quickstart notebook, provides attention and training wrappers with proper modularity |
Sequence Diagram
sequenceDiagram
participant User
participant Notebook as quickstart_jax.ipynb
participant Utils as quickstart_jax_utils.py
participant JAX as JAX/Flax
participant TE as TransformerEngine
User->>Notebook: Start tutorial
Note over Notebook,JAX: Iteration 1: Pure JAX/Flax
Notebook->>JAX: Create BasicTransformerLayer
JAX->>Utils: Use AttentionWrapper (CUSTOM_DOT_PRODUCT)
Utils->>JAX: Return attention output
JAX->>Notebook: Return layer output
Notebook->>Utils: Call speedometer() for benchmarking
Utils->>Notebook: Report 28.61ms per step
Note over Notebook,TE: Iteration 2: Basic TE modules
Notebook->>TE: Create BasicTETransformerLayer
TE->>Utils: Use AttentionWrapper (TE_FLAX_MULTIHEAD)
Utils->>TE: Return TE attention output
TE->>Notebook: Return layer output
Notebook->>Utils: Call speedometer()
Utils->>Notebook: Report 17.11ms per step
Note over Notebook,TE: Iteration 3: TE with FP8
Notebook->>TE: Wrap in fp8_autocast context
TE->>TE: Initialize FP8 metadata
Notebook->>TE: Apply BasicTETransformerLayer with FP8
TE->>Notebook: Return FP8-optimized output
Notebook->>Utils: Call speedometer() with FP8
Utils->>Notebook: Report 13.34ms per step
Note over Notebook,TE: Iteration 4: Fused TE modules
Notebook->>TE: Create FusedTETransformerLayer
TE->>TE: Use LayerNormDenseGeneral & LayerNormMLP
TE->>Notebook: Return fused layer output
Notebook->>Utils: Call speedometer() with FP8
Utils->>Notebook: Report 13.05ms per step
Note over Notebook,TE: Iteration 5: Full TransformerLayer
Notebook->>TE: Create te_flax.TransformerLayer
TE->>TE: Apply complete optimized layer
TE->>Notebook: Return optimized output
Notebook->>Utils: Call speedometer() with FP8
Utils->>Notebook: Report 12.95ms per step
User->>Notebook: Complete tutorial
2 files reviewed, no comments
…ll of BasicTETransformerLayer and demonstrated difference in runtime between using flax and using te's attetion implementation Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile Overview
Greptile Summary
This PR adds a comprehensive JAX/Flax quickstart guide for Transformer Engine, mirroring the PyTorch tutorial. It demonstrates 4 progressive implementations of a Transformer encoder layer with increasing optimization: pure Flax, basic TE modules, partially fused TE modules, and fully fused TransformerLayer.
Key Changes:
- New Jupyter notebook `quickstart_jax.ipynb` with performance benchmarks showing ~2x speedup progression from pure Flax (25.5ms) to full TE with FP8 (13ms)
- Utility file `quickstart_jax_utils.py` with attention wrappers, MLP implementations, and training speedometer
- Modified CI workflow to install PyTorch from specific CUDA index
Issues Found:
- Critical: `FusedTETransformerLayer` references undefined `self.attention_implementation` attribute (line 334) - will cause runtime failure
Confidence Score: 2/5
- This PR has critical runtime errors that will prevent the notebook from executing successfully
- While the documentation quality is good and demonstrates useful features, there's a critical bug where `FusedTETransformerLayer` uses an undefined attribute. Combined with numerous other issues flagged in previous comments (import errors, parameter mismatches, typos), this PR needs significant fixes before merge
- Both `docs/examples/quickstart_jax.ipynb` and `docs/examples/quickstart_jax_utils.py` require attention - the notebook has the critical undefined attribute bug and the utils file has multiple parameter mismatch issues
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 2/5 | Tutorial notebook with critical bug: FusedTETransformerLayer uses undefined attention_implementation attribute. Also contains multiple typos and hardcoded paths flagged in previous reviews |
| docs/examples/quickstart_jax_utils.py | 3/5 | Utility functions with several parameter mismatches and unused parameters already flagged. Core logic appears sound but needs cleanup for robustness |
Sequence Diagram
sequenceDiagram
participant User
participant Notebook
participant TransformerLayer
participant Speedometer
participant TrainStepFn
participant FP8Autocast
participant JAXBackend
User->>Notebook: Initialize model config
Notebook->>TransformerLayer: init(key, x, attention_mask)
TransformerLayer-->>Notebook: Return initialized params
User->>Speedometer: speedometer(model_apply_fn, variables, input, output_grad)
Speedometer->>TrainStepFn: create_train_step_fn(model_apply_fn, fp8_autocast_kwargs)
TrainStepFn-->>Speedometer: Return JIT-compiled train_step_fn
loop Warmup Iterations
Speedometer->>TrainStepFn: train_step_fn(variables, input, output_grad, key)
TrainStepFn->>FP8Autocast: Enter fp8_autocast context
FP8Autocast->>TransformerLayer: apply(variables, input)
TransformerLayer->>JAXBackend: Forward pass (LayerNorm, QKV, Attention, MLP)
JAXBackend-->>TransformerLayer: Output tensor
TransformerLayer-->>FP8Autocast: Return output
FP8Autocast->>JAXBackend: Compute gradients (VJP)
JAXBackend-->>FP8Autocast: param_grads, other_grads
FP8Autocast-->>TrainStepFn: loss, gradients
TrainStepFn-->>Speedometer: loss, gradients
end
loop Timing Iterations
Speedometer->>TrainStepFn: train_step_fn(variables, input, output_grad, key)
Note over TrainStepFn,JAXBackend: Same forward/backward flow as warmup
TrainStepFn-->>Speedometer: loss, gradients
end
Speedometer-->>User: Print mean time per iteration
1 file reviewed, 1 comment
…ayer Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile Overview
Updated On: 2025-11-06 20:06:38 UTC
Greptile Summary
Adds comprehensive JAX quickstart documentation with a Jupyter notebook tutorial and utility functions. The tutorial demonstrates four progressive implementations of a Transformer layer, from pure JAX/Flax to fully-fused TE modules with FP8 support.
Major issues found:
Minor issues:
Confidence Score: 2/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User
participant Notebook as quickstart_jax.ipynb
participant Utils as quickstart_jax_utils.py
participant JAX/Flax
participant TE as TransformerEngine
User->>Notebook: Run tutorial
Notebook->>Utils: Import helper functions
Note over Notebook: Step 1: Basic Transformer
Notebook->>JAX/Flax: Create BasicTransformerLayer
Notebook->>Utils: Use AttentionWrapper(FLAX_LINEN_MULTIHEAD)
Utils->>JAX/Flax: Use nn.MultiHeadDotProductAttention
Notebook->>Utils: speedometer() to measure performance
Note over Notebook: Step 2: Basic TE
Notebook->>TE: Create BasicTETransformerLayer
Notebook->>Utils: Use AttentionWrapper(TE_FLAX_MULTIHEAD)
Utils->>TE: Use TEDotProductAttention
Notebook->>Utils: speedometer() with FP8 disabled
Note over Notebook: Step 3: TE with FP8
Notebook->>TE: fp8_autocast(enabled=True)
Notebook->>TE: Initialize model within autocast
Notebook->>Utils: speedometer() with FP8 enabled
Note over Notebook: Step 4: Fused TE
Notebook->>TE: Create FusedTETransformerLayer
Note over TE: Uses LayerNormDenseGeneral,<br/>LayerNormMLP
Notebook->>Utils: speedometer() with FP8 enabled
Note over Notebook: Step 5: Full TransformerLayer
Notebook->>TE: Use te_flax.TransformerLayer
Note over TE: Highest optimization level
Notebook->>Utils: speedometer() with FP8 enabled
Utils->>User: Performance metrics for each step
3 files reviewed, no comments
…mpl only, removing last mention of Pytorch Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile Overview
Greptile Summary
Adds comprehensive JAX quickstart tutorial demonstrating progressive optimization with Transformer Engine. The tutorial walks through 4 implementation iterations of a Transformer encoder layer, showing performance improvements from pure Flax (17.7ms) to fully optimized TE with FP8 (9.7ms).
Major changes:
- New Jupyter notebook tutorial (`quickstart_jax.ipynb`) with 16 code cells demonstrating TE integration
- Utility module (`quickstart_jax_utils.py`) providing benchmarking functions for forward/backward pass timing
- Unrelated CI workflow modification to PyTorch installation method
Issues found:
- Multiple typos in documentation text and code comments (`similicity`, `attetntion`, `inistead`, `linnen`, `buiil-in`, `QKVoutput`)
- Malformed markdown link syntax in notebook
- Logic error referencing undefined variable in FP8 section
- Function signature mismatches in utility file between
train_step_fndefinition and calls - Unused parameters (
model_init_fn,key) in utility functions - CI workflow change appears unrelated to JAX documentation PR
Structure: The tutorial effectively mirrors the PyTorch quickstart guide's architecture, making it easy for users familiar with the PyTorch version to follow along with the JAX implementation.
Confidence Score: 2/5
- This PR has multiple logic errors and typos that will affect user experience and potentially cause runtime failures
- Score reflects several critical issues: hardcoded paths that will break for users, logic errors with undefined variables that could cause runtime failures, function signature mismatches in utility code, and numerous typos throughout the documentation. While the tutorial structure and approach are sound, these implementation issues need to be resolved before merge.
- Pay close attention to `docs/examples/quickstart_jax.ipynb` (hardcoded paths, undefined variable references, typos) and `docs/examples/quickstart_jax_utils.py` (function signature mismatches). Also verify whether the `.github/workflows/build.yml` change is intentional.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 2/5 | JAX quickstart tutorial with multiple typos, hardcoded paths, malformed markdown links, and logic errors referencing undefined variables |
| docs/examples/quickstart_jax_utils.py | 2/5 | Utility functions for speedometer benchmarking with function signature mismatches and unused parameters |
| .github/workflows/build.yml | 3/5 | CI workflow modification to PyTorch installation - appears unrelated to JAX documentation changes |
Sequence Diagram
sequenceDiagram
participant User
participant Notebook as quickstart_jax.ipynb
participant Utils as quickstart_jax_utils.py
participant Flax as Flax Modules
participant TE as Transformer Engine
User->>Notebook: Run tutorial
Notebook->>Flax: Create BasicTransformerLayer
Flax-->>Notebook: Pure JAX/Flax implementation
Notebook->>Utils: speedometer() benchmark
Utils->>Utils: create_train_step_fn()
Utils->>Flax: Forward/backward pass
Flax-->>Utils: Timing results
Utils-->>Notebook: Baseline performance
Notebook->>TE: Create BasicTETransformerLayer
TE-->>Notebook: TE modules (DenseGeneral, LayerNorm)
Notebook->>Utils: speedometer() benchmark
Utils->>TE: Forward/backward pass
TE-->>Utils: Improved timing
Utils-->>Notebook: TE performance
Notebook->>TE: Enable FP8 with fp8_autocast
TE-->>Notebook: FP8-enabled model
Notebook->>Utils: speedometer() with FP8
Utils->>TE: FP8 forward/backward pass
TE-->>Utils: FP8 timing
Utils-->>Notebook: FP8 performance gains
Notebook->>TE: Create FusedTETransformerLayer
TE-->>Notebook: Fused operations
Notebook->>Utils: speedometer() benchmark
Utils->>TE: Fused forward/backward pass
TE-->>Utils: Fused timing
Utils-->>Notebook: Best performance
Notebook->>TE: Use TransformerLayer (full TE)
TE-->>Notebook: Fully optimized layer
Notebook->>Utils: Final benchmark
Utils->>TE: Optimized forward/backward
TE-->>Utils: Optimal timing
Utils-->>Notebook: Maximum performance
2 files reviewed, 1 comment
run: |
  pip install pybind11[global] einops onnxscript
logic: unrelated change to CI workflow in documentation PR
This PyTorch installation change seems unrelated to the JAX quickstart documentation. Should this be in a separate PR?
| "- `nn.Dropout`: JAX/Flax Dropout\n", | ||
| "- `BasicMLP`: Custom MLP implemented from `nn.Dense\n", | ||
| "\n", | ||
| "<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n", |
"<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n",
This can be removed now that we've removed the custom attention impl
| "- `DotProductAttention`: Custom attention from [quickstart_jax_utils.py] (**)(quickstart_jax_utils.py)\n", | ||
| "- `nn.Dense`: JAX/Flax Dense layer for projection\n", | ||
| "- `nn.Dropout`: JAX/Flax Dropout\n", | ||
| "- `BasicMLP`: Custom MLP implemented from `nn.Dense\n", |
nit: missing closing backtick ` on nn.Dense here
| "source": [ | ||
| "import transformer_engine.jax as te\n", | ||
| "import transformer_engine.jax.flax as te_flax\n", | ||
| "from transformer_engine.jax.quantize import is_fp8_available, ScalingMode" |
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
I don't see is_fp8_available or ScalingMode used later in the notebook. Let's remove them
| "id": "c9a101d3", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## Enabling FP8\n", |
nit: Here and below can we replace FP8 -> Quantization unless specifically talking about an fp8-specific recipe. For the first mention of quantization, this title, we can mention FP8 and FP4 explicitly "Enabling Quantization (FP8 or FP4)"
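As a concrete illustration of that naming, the sketch below parameterizes the recipe passed to `fp8_autocast`; `DelayedScaling` and `Format` come from `transformer_engine.common.recipe`, while the availability of other (FP4/MXFP8) recipes depends on the TE version and GPU, so treat the specific recipe choice as an assumption rather than the notebook's code.

```python
# Sketch: choose a quantization recipe and hand it to fp8_autocast.
# DelayedScaling/Format come from transformer_engine.common.recipe; the model
# init/apply calls are the ones already shown in the notebook cells.
import transformer_engine.jax as te
from transformer_engine.common import recipe

quant_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

with te.fp8_autocast(enabled=True, fp8_recipe=quant_recipe):
    pass  # model.init(...) / model.apply(...) as in the quoted cells
```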
| "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", | ||
| " basic_te_params = basic_te_transformer.init(key, x, attention_mask=None, deterministic=False)\n", | ||
| "\n", | ||
| " y = basic_te_transformer.apply(basic_te_params, x, attention_mask=None, deterministic=True)\n", |
Why do we need this line? I don't see y used later
y = basic_te_transformer.apply(basic_te_params, x, attention_mask=None, deterministic=True)
| "\n", | ||
| "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", | ||
| " fused_te_params = fused_te_transformer.init(key, x, attention_mask=None, deterministic=False)\n", | ||
| " y = fused_te_transformer.apply(fused_te_params, x, attention_mask=None, deterministic=True)\n", |
Again here, I don't see y used later. Can this be removed, or is there other reasoning I'm missing? Thanks!
| "id": "a45c12c8", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:" |
@tdophung now that the performance matches, let's remove "it provides the highest degree of performance optimization:" since TransformerLayer and individual usage of fused LayerNormMLP, etc. are equivalent
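For reference, here is a rough sketch of the `TransformerLayer` configuration under discussion; the keyword names (`hidden_size`, `mlp_hidden_size`, `num_attention_heads`) and the input layout are assumed from TE's Flax API and should be checked against the module docstring rather than read as the notebook's exact code.

```python
# Sketch only: te_flax.TransformerLayer standing in for the fused
# LayerNormDenseGeneral/LayerNormMLP stack (keyword names assumed, not verified).
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

hidden, heads, seq, batch = 1024, 16, 128, 4
layer = te_flax.TransformerLayer(
    hidden_size=hidden,
    mlp_hidden_size=4 * hidden,
    num_attention_heads=heads,
)

x = jnp.ones((batch, seq, hidden), dtype=jnp.bfloat16)  # layout assumption; check transpose_batch_sequence
with te.fp8_autocast(enabled=True):
    params = layer.init(jax.random.PRNGKey(0), x, deterministic=False)
```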
| "\n", | ||
| "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", | ||
| " te_transformer_params = te_transformer.init(key, x, deterministic=False)\n", | ||
| " y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)" |
Same question as above for y = ...
| "- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html\n", | ||
| "\n", | ||
| "## Let's build a Transformer encoder layer (*)!\n", | ||
| "<small>(*) _This was based upon the GPT decoder layer, but for the sake of simplicity and mirroring the PyTorch tutorial whose defaults are without any attention mask, we are setting attention mask here also to 0, making the attention basically an encoder, which does not exist in the GPT architecture. However, since the code support any attention mask here in the TransformerLayer (later in this guide), we will leave it to the audience to try experimenting with different attention masks._</small>\n", |
Since we are no longer tied to the PyTorch tutorial and no longer have the custom attention, how hard would it be to use a causal mask so we can remove this caveat? Looks like Flax linen has a built-in function for creating the mask here: https://flax.readthedocs.io/en/v0.6.11/api_reference/_autosummary/flax.linen.make_causal_mask.html
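A minimal sketch of what that could look like with the built-in helper (the shapes and the `nn.MultiHeadDotProductAttention` wiring here are illustrative, not code from the PR):

```python
# Sketch: build a causal mask with Flax's helper and pass it to the built-in
# attention module (illustrative shapes only).
import jax
import jax.numpy as jnp
import flax.linen as nn

batch, seq_len, hidden = 4, 128, 256
x = jnp.ones((batch, seq_len, hidden))

# make_causal_mask only needs an array with shape (..., seq_len)
mask = nn.make_causal_mask(jnp.ones((batch, seq_len)))  # -> (batch, 1, seq_len, seq_len)

attn = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=hidden)
params = attn.init(jax.random.PRNGKey(0), x, x, mask=mask)
y = attn.apply(params, x, x, mask=mask)  # causally masked self-attention
```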

Description
This change adds a new quickstart notebook for JAX, mirroring the same Transformer architecture as the PyTorch guide so that users familiar with the PyTorch guide can easily follow along. It contains 4 iterations of the layer, from a pure JAX/Flax baseline to a fully optimized TE implementation with FP8, each with different training step durations.
Later commits might also include changes to how Sphinx displays this content on the HTML docs page.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: