
Conversation


@jberchtold-nvidia jberchtold-nvidia commented Nov 26, 2025

Description

Adds a new tutorial notebook demonstrating how to use TE/JAX quantization in an existing model framework.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add new tutorial notebook for quantization integration into an existing model framework

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

greptile-apps bot commented Nov 26, 2025

Greptile Overview

Greptile Summary

This PR adds a comprehensive tutorial (te_jax_integration.ipynb) demonstrating how to integrate TransformerEngine quantization into existing JAX/Flax frameworks. The key changes include:

  • New tutorial notebook: Shows step-by-step integration of TE quantization using make_dot_general_cls() to create drop-in replacements for standard Flax Dense layers
  • New public API functions: wrap_function_in_te_state_module() and make_dot_general_cls() exposed in transformer_engine.jax.flax for easier framework integration
  • Refactored RNG handling: Updated quickstart_jax_utils.py to support multiple RNG keys via a dictionary (rngs parameter) instead of a single dropout_key, improving flexibility for recipes like NVFP4 that require additional RNG streams (see the sketch after this list)
  • Updated examples: All quickstart examples consistently use the new RNG dictionary format
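
As a rough illustration of the RNG-dictionary convention above: the "dropout" stream name is standard Flax, while "sr_rng" (for NVFP4 stochastic rounding) is an assumed name taken from this thread's discussion, not a confirmed API.

```python
import jax

# Split one root key into the streams the model needs. "dropout" is the
# standard Flax stream name; "sr_rng" is an assumption based on this thread.
root_key = jax.random.PRNGKey(0)
dropout_key, sr_key = jax.random.split(root_key)
rngs = {"dropout": dropout_key, "sr_rng": sr_key}

# Per the review notes, an empty dict is the default, matching the old
# single-dropout_key behavior.
# out = model.apply(variables, batch, rngs=rngs)  # hypothetical call site
```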

The implementation is clean and non-invasive: existing model parameters and initialization remain unchanged, and only the GEMM operation is replaced with TE's quantized version. The tutorial demonstrates a ~1.8x speedup on the test configuration.
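
For readers skimming this thread, here is a minimal sketch of that integration pattern, assuming the make_dot_general_cls() API exposed by this PR and Flax's standard dot_general_cls hook on nn.Dense; the notebook itself is the authoritative reference.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from transformer_engine.common import recipe
from transformer_engine.jax.flax import make_dot_general_cls

# Build a dot_general replacement from a quantization recipe (DelayedScaling
# is TE's FP8 delayed-scaling recipe; other recipes should plug in the same way).
te_dot_general_cls = make_dot_general_cls(recipe.DelayedScaling())

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Parameters and initialization are unchanged; only the GEMM inside
        # each Dense is routed through TE's quantized implementation.
        x = nn.Dense(4096, dot_general_cls=te_dot_general_cls)(x)
        x = nn.gelu(x)
        return nn.Dense(1024, dot_general_cls=te_dot_general_cls)(x)

model = MLP()
x = jnp.ones((32, 1024), jnp.bfloat16)
variables = model.init(jax.random.PRNGKey(0), x)
# Quantizer state (amax history, scales) lives in extra Flax collections, so
# it may need to be marked mutable during training steps.
y, updated_state = model.apply(variables, x, mutable=True)
```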

Confidence Score: 5/5

  • This PR is safe to merge: it adds documentation and public API wrappers without breaking changes
  • The changes are well contained to documentation and utility improvements. The new API functions are well documented, with clear docstrings and examples. The RNG refactoring is backwards-compatible in practice (an empty dict is the default). All changes follow existing patterns and conventions in the codebase.
  • No files require special attention; the typo in quickstart_jax_utils.py:79 should be fixed, but it is minor

Important Files Changed

| Filename | Score | Overview |
| --- | --- | --- |
| docs/examples/te_jax_integration.ipynb | 5/5 | New tutorial demonstrating TE/JAX integration with quantization recipes. Well-structured with clear examples and performance comparisons. |
| docs/examples/quickstart_jax_utils.py | 4/5 | Refactored to support multiple RNG keys via dictionary. Added a helper function for RNG splitting. Contains a typo in a comment. |
| transformer_engine/jax/flax/module.py | 5/5 | Added two new functions for wrapping custom functions with TE quantization support. Well-documented with clear examples. |

Sequence Diagram

sequenceDiagram
    participant User
    participant FlaxModule as Flax Dense Layer
    participant TEWrapper as TE Wrapper Module
    participant Quantizer as TE Quantizer Set
    participant GEMM as TE Dense (Quantized GEMM)

    User->>FlaxModule: Initialize with dot_general_cls=te_dot_general_cls
    FlaxModule->>TEWrapper: Create wrapper instance
    TEWrapper->>Quantizer: Initialize quantizer state (amax_history, scales)
    
    User->>FlaxModule: Forward pass with input x
    FlaxModule->>TEWrapper: Call dot_general(x, kernel, dims)
    TEWrapper->>Quantizer: Get quantizer set
    TEWrapper->>GEMM: dense(x, kernel, quantizer_set)
    GEMM->>Quantizer: Quantize x, kernel, grad
    Quantizer-->>GEMM: Return quantized tensors
    GEMM-->>TEWrapper: Return output
    TEWrapper-->>FlaxModule: Return output
    FlaxModule-->>User: Return final output
    
    Note over Quantizer: State updates (amax_history, scales)<br/>stored in Flax variables


@greptile-apps greptile-apps bot left a comment


1 file reviewed, 2 comments


Comment on lines +234 to +236
"utils.speedometer(\n",
" model_apply_fn=flax_transformer.apply,\n",
" variables=params,\n",

style: Unresolved TODO comment about fixing sr_rng setup for NVFP4. Should this be resolved before publishing the tutorial?

"cell_type": "code",
"execution_count": 3,
"id": "8b44649d",
"metadata": {},
syntax: Typo in comment: signma should be sigma

Suggested change
# grad_target = derivative of L (loss fn) over y (output) = signma(L)/signma(y)
# grad_target = derivative of L (loss fn) over y (output) = sigma(L)/sigma(y)

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


…oat16 cast

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

@greptile-apps greptile-apps bot left a comment


5 files reviewed, 2 comments


This method does a couple things:
1. Wraps the given function in a context that specifies MaxText's physical mesh axes to TransformerEngine. This ensures our collective operations in TransformerEngine are using the correct axes.

style: Docstring incorrectly mentions MaxText-specific behavior. This is generic TE functionality, not MaxText-specific.

Suggested change
1. Wraps the given function in a context that specifies MaxText's physical mesh axes to TransformerEngine. This ensures our collective operations in TransformerEngine are using the correct axes.
1. Wraps the given function in a Flax linen module context to support TransformerEngine quantization operations.
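
For context, a rough sketch of how this wrapper might be used, inferred from the docstring above and the sequence diagram earlier in this thread; the argument names and order are assumptions, not the confirmed signature.

```python
from transformer_engine.common import recipe
from transformer_engine.jax.flax import wrap_function_in_te_state_module

# A function that relies on TE quantizer state. The dense(x, kernel,
# quantizer_set) shape mirrors the call shown in the sequence diagram.
def quantized_dense(x, kernel, quantizer_set):
    from transformer_engine.jax.dense import dense  # TE's quantized GEMM
    return dense(x, kernel, quantizer_set=quantizer_set)

# Wrapping yields a Flax linen module whose variables carry the quantizer
# state (amax_history, scales), so ordinary init/apply manage it alongside
# the model's parameters.
QuantizedDense = wrap_function_in_te_state_module(
    quantized_dense, recipe.DelayedScaling()
)
```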



def make_dot_general_cls(quantization_recipe):
"""Placeholder for dot_general implementation in subclasses."""

style: Misleading docstring - says "Placeholder" but this is a complete implementation, not a placeholder.

Suggested change
"""Placeholder for dot_general implementation in subclasses."""
"""Creates a dot_general_cls that wraps JAX dense operations with TransformerEngine quantization."""

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

@greptile-apps greptile-apps bot left a comment


Additional Comments (1)

  1. docs/examples/quickstart_jax_utils.py, line 79

    syntax: Typo: signma should be sigma

5 files reviewed, 1 comment


" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
" \n",
" # Attention using Flax's MultiHeadDotProductAttention\n",
" attention = nn.MultiHeadDotProductAttention(\n",
@jberchtold-nvidia (Collaborator, Author) commented:

TODO: pull in Teddy's DPA changes from #2429

@jberchtold-nvidia jberchtold-nvidia changed the title [Draft][JAX] Tutorial for integration TE/JAX quantization into an existing framework [JAX] Tutorial for integration TE/JAX quantization into an existing framework Nov 26, 2025
@jberchtold-nvidia jberchtold-nvidia changed the title [JAX] Tutorial for integration TE/JAX quantization into an existing framework [JAX] Add tutorial for integrating TE/JAX quantization into an existing framework Nov 26, 2025
@jberchtold-nvidia
Collaborator Author

/te-ci jax

