[JAX] Add tutorial for integrating TE/JAX quantization into an existing framework #2423
base: main
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Overview

Greptile Summary

This PR adds a comprehensive tutorial (…) for integrating TE/JAX quantization into an existing framework.
The implementation is clean and non-invasive: existing model parameters and initialization remain unchanged, and only the GEMM operation is replaced with TE's quantized version. The tutorial demonstrates a ~1.8x speedup on the test configuration.

Confidence Score: 5/5
Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant FlaxModule as Flax Dense Layer
    participant TEWrapper as TE Wrapper Module
    participant Quantizer as TE Quantizer Set
    participant GEMM as TE Dense (Quantized GEMM)
    User->>FlaxModule: Initialize with dot_general_cls=te_dot_general_cls
    FlaxModule->>TEWrapper: Create wrapper instance
    TEWrapper->>Quantizer: Initialize quantizer state (amax_history, scales)
    User->>FlaxModule: Forward pass with input x
    FlaxModule->>TEWrapper: Call dot_general(x, kernel, dims)
    TEWrapper->>Quantizer: Get quantizer set
    TEWrapper->>GEMM: dense(x, kernel, quantizer_set)
    GEMM->>Quantizer: Quantize x, kernel, grad
    Quantizer-->>GEMM: Return quantized tensors
    GEMM-->>TEWrapper: Return output
    TEWrapper-->>FlaxModule: Return output
    FlaxModule-->>User: Return final output
    Note over Quantizer: State updates (amax_history, scales)<br/>stored in Flax variables
```
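For readers unfamiliar with the hook the diagram relies on, here is a minimal, self-contained sketch of how a Flax layer hands its contraction to a replaceable `dot_general_cls` module. It assumes a recent Flax version that exposes the `dot_general_cls` argument on `nn.Dense`; the passthrough module below is only a stand-in for the TE wrapper, which would instead call TE's quantized dense and keep amax_history/scales as Flax variables.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax import lax


class PassthroughDotGeneral(nn.Module):
    """Stand-in for the TE wrapper: same call contract, no quantization."""

    @nn.compact
    def __call__(self, lhs, rhs, dimension_numbers, precision=None, **kwargs):
        # Flax calls this with (inputs, kernel, dims); a TE wrapper would
        # quantize lhs/rhs here and dispatch to the quantized GEMM instead.
        return lax.dot_general(lhs, rhs, dimension_numbers, precision=precision)


# Existing model code changes only at the dot_general_cls hook.
dense = nn.Dense(features=16, dot_general_cls=PassthroughDotGeneral)
x = jnp.ones((4, 8))
variables = dense.init(jax.random.PRNGKey(0), x)
y = dense.apply(variables, x)
print(y.shape)  # (4, 16)
```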
1 file reviewed, 2 comments
| "utils.speedometer(\n", | ||
| " model_apply_fn=flax_transformer.apply,\n", | ||
| " variables=params,\n", |
style: Unresolved TODO comment about fixing sr_rng setup for NVFP4. Should this be resolved before publishing the tutorial?
| "cell_type": "code", | ||
| "execution_count": 3, | ||
| "id": "8b44649d", | ||
| "metadata": {}, |
syntax: Typo in comment: signma should be sigma
| "metadata": {}, | |
| # grad_target = derivative of L (loss fn) over y (output) = sigma(L)/sigma(y) |
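The suggested comment defines `grad_target` as the derivative of the loss with respect to the layer output. As a quick illustration (the names and toy loss below are made up, not the notebook's), this is the cotangent JAX would propagate into the GEMM's backward pass:

```python
import jax
import jax.numpy as jnp


def toy_loss(y):
    # Hypothetical loss over a layer output y.
    return jnp.mean(y ** 2)


y = jnp.ones((4, 16))
grad_target = jax.grad(toy_loss)(y)  # dL/dy, same shape as y
print(grad_target.shape)  # (4, 16)
```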
1 file reviewed, no comments
…oat16 cast

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
5 files reviewed, 2 comments
This method does a couple things:
1. Wraps the given function in a context that specifies MaxText's physical mesh axes to TransformerEngine. This ensures our collective operations in TransformerEngine are using the correct axes.
style: Docstring incorrectly mentions MaxText-specific behavior. This is generic TE functionality, not MaxText-specific.
Suggested change:
- 1. Wraps the given function in a context that specifies MaxText's physical mesh axes to TransformerEngine. This ensures our collective operations in TransformerEngine are using the correct axes.
+ 1. Wraps the given function in a Flax linen module context to support TransformerEngine quantization operations.
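For context on what such a wrapper does, here is a minimal sketch assuming TE/JAX's `fp8_autocast` and `MeshResource` APIs and a `DelayedScaling` recipe. The wrapper in this PR may differ (it may also enter a Flax linen module context, as the suggestion notes), and the helper name and mesh axis names below are placeholders:

```python
import functools

import transformer_engine.jax as te
from transformer_engine.common import recipe


def with_te_context(fn, dp_axis="data", tp_axis="tensor"):
    """Hypothetical helper: run `fn` under a TE autocast/mesh context."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        # Tell TE which logical mesh axes map to data/tensor parallelism.
        mesh_resource = te.MeshResource(dp_resource=dp_axis, tp_resource=tp_axis)
        with te.fp8_autocast(enabled=True,
                             fp8_recipe=recipe.DelayedScaling(),
                             mesh_resource=mesh_resource):
            return fn(*args, **kwargs)

    return wrapped
```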
def make_dot_general_cls(quantization_recipe):
    """Placeholder for dot_general implementation in subclasses."""
style: Misleading docstring - says "Placeholder" but this is a complete implementation, not a placeholder.
| """Placeholder for dot_general implementation in subclasses.""" | |
| """Creates a dot_general_cls that wraps JAX dense operations with TransformerEngine quantization.""" |
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Additional Comments (1)
- docs/examples/quickstart_jax_utils.py, line 79 (link): syntax: Typo: `signma` should be `sigma`.
5 files reviewed, 1 comment
| " v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n", | ||
| " \n", | ||
| " # Attention using Flax's MultiHeadDotProductAttention\n", | ||
| " attention = nn.MultiHeadDotProductAttention(\n", |
TODO: pull in Teddy's DPA changes from #2429
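As background for the snippet above (and the pending DPA change), here is a small self-contained example of calling Flax's `MultiHeadDotProductAttention` for self-attention. The sizes are illustrative rather than the tutorial's, and newer Flax versions also accept separate `inputs_k`/`inputs_v` arguments:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

batch, seq, hidden_size, num_heads = 2, 16, 64, 4  # illustrative sizes


class ToyAttention(nn.Module):
    hidden_size: int
    num_heads: int

    @nn.compact
    def __call__(self, x):
        attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.hidden_size,
        )
        # Passing x twice gives self-attention (query and key/value streams are equal).
        return attention(x, x)


x = jnp.ones((batch, seq, hidden_size))
y, params = ToyAttention(hidden_size, num_heads).init_with_output(jax.random.PRNGKey(0), x)
print(y.shape)  # (2, 16, 64)
```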
/te-ci jax |
Description
Adds a new notebook that is a tutorial for using TE/JAX quantization in an existing model framework.
Type of change
Changes
Checklist: