[JAX] Support for checkpointing quantizations #2356
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Overview
Greptile Summary
This PR adds support for checkpointing quantization operations in JAX by introducing an optional checkpoint_name parameter throughout the quantizer hierarchy and applying JAX's checkpoint_name function to primitive outputs.
Key Changes:
- Added a `checkpoint_name: Optional[str]` field to all quantizer classes (`Quantizer`, `DelayedScaleQuantizer`, `NVFP4Quantizer`, `GroupedQuantizer`)
- Updated `QuantizerFactory.create()` and `QuantizerFactory.create_set()` to accept and propagate `checkpoint_name`
- Modified quantization primitives in `activation.py`, `normalization.py`, and `quantization.py` to apply `checkpoint_name()` to outputs when set
- Added a `quantization_checkpoint_name` parameter to the Flax modules (`DenseGeneral`, `LayerNormDenseGeneral`, `LayerNormMLP`)
Critical Issue Found:
- In `NVFP4Quantizer.tree_flatten()` (lines 631-638), `checkpoint_name` is inserted before `use_rht` in the `aux_data` tuple, but the dataclass field order has `checkpoint_name` as the 5th field (from the parent) and `use_rht` as the 6th field. This will cause incorrect deserialization when `tree_unflatten` unpacks the tuple.
Confidence Score: 1/5
- This PR contains a critical serialization bug that will cause runtime failures
- The bug in NVFP4Quantizer.tree_flatten will cause incorrect field assignment during deserialization, leading to runtime errors when quantizers are serialized/deserialized. The checkpoint_name field order in aux_data doesn't match the dataclass field order
- transformer_engine/jax/quantize/quantizer.py requires immediate attention to fix the tree_flatten field order
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/quantize/quantizer.py | 1/5 | Added checkpoint_name parameter throughout quantizer hierarchy; critical bug in NVFP4Quantizer.tree_flatten where checkpoint_name is inserted before use_rht, breaking serialization/deserialization |
| transformer_engine/jax/cpp_extensions/activation.py | 4/5 | Added checkpoint support for activation quantizations by applying checkpoint_name to primitive outputs; implementation looks correct |
| transformer_engine/jax/cpp_extensions/normalization.py | 4/5 | Added checkpoint support for normalization quantizations by applying checkpoint_name to primitive outputs; implementation looks correct |
| transformer_engine/jax/cpp_extensions/quantization.py | 4/5 | Added checkpoint support for quantization operations by applying checkpoint_name to primitive outputs; implementation looks correct |
| transformer_engine/jax/flax/module.py | 5/5 | Added quantization_checkpoint_name parameter to DenseGeneral, LayerNormDenseGeneral, and LayerNormMLP modules, properly threaded through to quantizer creation |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Module as Flax Module<br/>(DenseGeneral/LayerNorm)
participant Factory as QuantizerFactory
participant Quantizer as Quantizer<br/>(DelayedScale/NVFP4)
participant Primitive as JAX Primitive<br/>(ActLu/Norm/Quantize)
participant Checkpoint as checkpoint_name
User->>Module: Call with quantization_checkpoint_name
Module->>Factory: generate_quantizer_set(checkpoint_name)
Factory->>Quantizer: create quantizer with checkpoint_name
Quantizer-->>Module: return quantizer_set
Module->>Primitive: bind() with inputs
Primitive-->>Module: return prim_outputs (tuple)
alt checkpoint_name is not None
Module->>Checkpoint: checkpoint_name(prim_outputs, name)
Checkpoint-->>Module: return checkpointed outputs
end
Module->>Module: unpack outputs tuple
Module->>Quantizer: update(amax)
Module-->>User: return quantized tensor
5 files reviewed, 1 comment
```python
aux_data = (
    self.q_dtype,
    self.scaling_mode,
    self.q_layout,
    self.data_layout,
    self.checkpoint_name,
    self.use_rht,
)
```
logic: checkpoint_name inserted before use_rht in aux_data breaks deserialization. When tree_unflatten unpacks *aux_data, the 6th element (use_rht) will be assigned to the 5th field (checkpoint_name), causing a mismatch.
Suggested change:
```diff
 aux_data = (
     self.q_dtype,
     self.scaling_mode,
     self.q_layout,
     self.data_layout,
-    self.checkpoint_name,
     self.use_rht,
+    self.checkpoint_name,
 )
```
This comment is incorrect. checkpoint_name is the 5th field here since it is inherited from its parent and is 5th in the flattened list.
You're absolutely right! Looking at the tree_unflatten method, it expects the parameters in a specific order. The NVFP4Quantizer inherits from Quantizer, so checkpoint_name should be in the same position as defined in the parent class.
In the base Quantizer class, the tree_flatten method puts checkpoint_name as the 5th element (index 4), but in NVFP4Quantizer, you've placed use_rht in that position and moved checkpoint_name to the 6th position. This will cause deserialization issues.
```python
aux_data = (
    self.q_dtype,
    self.scaling_mode,
    self.q_layout,
    self.data_layout,
    self.checkpoint_name,
    self.use_rht,
)
```
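To make the field-ordering constraint discussed in this thread concrete, here is a minimal, dependency-free sketch. The classes below are hypothetical simplified stand-ins, not the actual TE quantizers: `tree_unflatten` rebuilds the object with a positional unpack of `aux_data`, so the tuple must follow dataclass field order, including fields inherited from the parent.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical, simplified stand-ins for the TE quantizer classes, used only
# to illustrate the ordering rule discussed above.

@dataclass
class Quantizer:
    q_dtype: str
    scaling_mode: str
    q_layout: str
    data_layout: str
    checkpoint_name: Optional[str] = None  # 5th field, inherited by subclasses

@dataclass
class NVFP4Quantizer(Quantizer):
    use_rht: bool = False  # 6th field, declared after the inherited ones

    def tree_flatten(self):
        # Matches field order: checkpoint_name (5th) before use_rht (6th)
        return (), (self.q_dtype, self.scaling_mode, self.q_layout,
                    self.data_layout, self.checkpoint_name, self.use_rht)

    def tree_flatten_swapped(self):
        # Incorrect ordering, shown for comparison: use_rht comes first
        return (), (self.q_dtype, self.scaling_mode, self.q_layout,
                    self.data_layout, self.use_rht, self.checkpoint_name)

    @classmethod
    def tree_unflatten(cls, aux_data, _children):
        return cls(*aux_data)  # positional unpack: ordering is load-bearing

q = NVFP4Quantizer("fp4", "nvfp4", "rowwise", "N",
                   checkpoint_name="quant", use_rht=True)
good = NVFP4Quantizer.tree_unflatten(q.tree_flatten()[1], ())
bad = NVFP4Quantizer.tree_unflatten(q.tree_flatten_swapped()[1], ())
# good round-trips; bad assigns use_rht's value to checkpoint_name and vice versa
```

This is why the field-order argument, rather than the order the fields were added in, decides where `checkpoint_name` belongs in the tuple.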
/te-ci jax |
Greptile Overview
Greptile Summary
Adds optional checkpointing support for quantization operations to avoid rematerialization in frameworks like MaxText. The implementation adds a checkpoint_name field to all quantizer classes and wraps primitive outputs with jax.ad_checkpoint.checkpoint_name when specified.
Key changes:
- Added a `checkpoint_name` field to the base `Quantizer` class and all subclasses (`DelayedScaleQuantizer`, `NVFP4Quantizer`, `GroupedQuantizer`)
- Updated `tree_flatten` methods to include `checkpoint_name` in serialization for all quantizer types
- Modified primitive wrappers in `activation.py`, `normalization.py`, and `quantization.py` to conditionally wrap outputs with `checkpoint_name`
- Extended the Flax modules (`DenseGeneral`, `LayerNormDenseGeneral`, `LayerNormMLP`) to accept and pass through a `quantization_checkpoint_name` parameter
- Updated `QuantizerFactory.create` and `QuantizerFactory.create_set` to accept and propagate `checkpoint_name`
Issues found:
- Critical serialization bug in `NVFP4Quantizer.tree_flatten` where `checkpoint_name` is inserted before `use_rht` in the `aux_data` tuple, breaking deserialization order (already flagged in a previous comment)
Confidence Score: 2/5
- This PR has a critical serialization bug that will break NVFP4Quantizer deserialization
- The implementation pattern is solid and consistent across all files, but the `NVFP4Quantizer.tree_flatten` method has incorrect field ordering in `aux_data` that will cause deserialization to fail or produce incorrect quantizer state. The `checkpoint_name` field is inserted before `use_rht`, but `tree_unflatten` expects them in the order they appear in the class definition
- `transformer_engine/jax/quantize/quantizer.py` - Fix the `NVFP4Quantizer` serialization order and add tests for quantizer serialization/deserialization
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/quantize/quantizer.py | 2/5 | Adds checkpoint_name field to quantizer classes and updates serialization, but has critical deserialization bug in NVFP4Quantizer where field order mismatch will cause incorrect unpacking |
| transformer_engine/jax/cpp_extensions/activation.py | 5/5 | Adds checkpoint_name wrapping around primitive outputs for act_lu and quantize_dact_dbias functions, consistent implementation with proper tuple unpacking |
| transformer_engine/jax/cpp_extensions/normalization.py | 5/5 | Adds checkpoint_name wrapping for layernorm_fwd and rmsnorm_fwd primitives, consistent with activation.py pattern |
| transformer_engine/jax/cpp_extensions/quantization.py | 5/5 | Adds checkpoint_name wrapping to _quantize_dbias_impl function, following the same pattern as other primitive wrappers |
| transformer_engine/jax/flax/module.py | 5/5 | Adds quantization_checkpoint_name parameter to Dense, LayerNormDenseGeneral, and LayerNormMLP modules, properly threading it through to quantizer creation |
Sequence Diagram
sequenceDiagram
participant User as User/Framework
participant Module as Flax Module<br/>(Dense/LayerNorm)
participant Factory as QuantizerFactory
participant Quantizer as Quantizer
participant Primitive as Primitive Wrapper<br/>(activation/norm/quant)
participant JAX as JAX checkpoint_name
User->>Module: Call with quantization_checkpoint_name
Module->>Factory: generate_quantizer_set(checkpoint_name)
Factory->>Quantizer: create quantizer with checkpoint_name field
Quantizer-->>Module: return quantizer_set
Module->>Primitive: pass quantizer to primitive wrapper
Primitive->>Primitive: execute primitive.bind()
Primitive->>Primitive: check if quantizer.checkpoint_name is set
alt checkpoint_name is set
Primitive->>JAX: checkpoint_name(prim_outputs, name)
JAX-->>Primitive: wrapped outputs (checkpointed)
else checkpoint_name is None
Primitive->>Primitive: use unwrapped outputs
end
Primitive->>Primitive: unpack outputs tuple
Primitive->>Quantizer: update(amax)
Primitive-->>Module: return quantized tensor
5 files reviewed, no comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
/te-ci jax |
Greptile Overview
Greptile Summary
This PR adds optional checkpointing support for quantizations to avoid recomputation in frameworks like MaxText.
Key changes:
- Added a `checkpoint_name` field to the `Quantizer` base class and all quantizer implementations
- Updated `QuantizerFactory.create()` and `QuantizerFactory.create_set()` to accept and propagate `checkpoint_name`
- Modified the Flax modules (`DenseGeneral`, `LayerNormDenseGeneral`, `LayerNormMLP`) to accept a `quantization_checkpoint_name` parameter
- Wrapped primitive outputs in `jax.ad_checkpoint.checkpoint_name()` when a checkpoint name is provided
- Added a comprehensive test validating that `checkpoint_name` appears in the jaxpr for all supported recipes
Issues found:
- Critical: Backward compatibility broken in `NVFP4Quantizer` and `GroupedQuantizer` serialization due to `checkpoint_name` field placement in the `aux_data` tuple
Confidence Score: 2/5
- This PR has a critical backward compatibility issue that will break deserialization of existing checkpoints
- The implementation logic is sound and well-tested, but the backward compatibility break in quantizer serialization is a critical issue that could cause production failures when loading old checkpoints
- Pay close attention to `transformer_engine/jax/quantize/quantizer.py` - the serialization format changes must be fixed to maintain backward compatibility
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/quantize/quantizer.py | 2/5 | Added checkpoint_name field to quantizers, but placement in aux_data breaks backward compatibility with serialized quantizers |
| transformer_engine/jax/flax/module.py | 5/5 | Added quantization_checkpoint_name parameter to Dense, LayerNormDenseGeneral, and LayerNormMLP modules - correctly threaded through to quantizer factory |
| tests/jax/test_recipe_characteristics.py | 5/5 | Added comprehensive test for quantization checkpointing across all recipes - test validates jaxpr contains expected checkpoint_name primitives |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Module as Flax Module<br/>(Dense/LayerNormMLP)
participant Factory as QuantizerFactory
participant Quantizer as Quantizer<br/>(with checkpoint_name)
participant Primitive as JAX Primitive<br/>(quantize/act_lu/norm)
participant Checkpoint as jax.ad_checkpoint.checkpoint_name
User->>Module: Create module with quantization_checkpoint_name="quant"
Module->>Factory: generate_quantizer_set(checkpoint_name="quant")
Factory->>Quantizer: Create quantizers with checkpoint_name field
Quantizer-->>Module: Return QuantizerSet
User->>Module: Forward pass (apply)
Module->>Primitive: Call quantize/act_lu/norm with quantizer
Primitive->>Primitive: Execute quantization primitive
Primitive-->>Primitive: Get prim_outputs (tuple of tensors)
alt checkpoint_name is not None
Primitive->>Checkpoint: checkpoint_name(prim_outputs, "quant")
Checkpoint-->>Primitive: Wrapped outputs (for JAX rematerialization)
end
Primitive->>Primitive: Unpack tuple to individual tensors
Primitive-->>Module: Return quantized tensors
Module-->>User: Forward pass result
1 file reviewed, 2 comments
```python
aux_data = (
    self.q_dtype,
    self.scaling_mode,
    self.q_layout,
    self.data_layout,
    self.checkpoint_name,
    self.use_rht,
)
```
logic: breaks backward compatibility for serialized quantizers. Old serialized NVFP4Quantizer objects have 5 elements in aux_data, new version has 6. When deserializing old checkpoints, use_rht will be assigned to checkpoint_name position.
Suggested change:
```diff
 aux_data = (
     self.q_dtype,
     self.scaling_mode,
     self.q_layout,
     self.data_layout,
-    self.checkpoint_name,
     self.use_rht,
+    self.checkpoint_name,
 )
```
```python
aux_data = (
    self.q_dtype,
    self.scaling_mode,
    self.q_layout,
    self.data_layout,
    self.checkpoint_name,
    self.n_groups,
)
```
logic: same backward compatibility issue - inserting checkpoint_name before n_groups breaks deserialization of old GroupedQuantizer objects.
Suggested change:
```diff
 aux_data = (
     self.q_dtype,
     self.scaling_mode,
     self.q_layout,
     self.data_layout,
-    self.checkpoint_name,
     self.n_groups,
+    self.checkpoint_name,
 )
```
```python
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
```
I don't think we need to checkpoint dbias and updated_amax.
Greptile Overview
Greptile Summary
Adds optional checkpointing support for quantizations in JAX to prevent rematerialization in frameworks like MaxText. The implementation adds a checkpoint_name field to quantizers and wraps primitive outputs with jax.ad_checkpoint.checkpoint_name when specified.
Key changes:
- Added a `checkpoint_name` field to all quantizer classes (`Quantizer`, `DelayedScaleQuantizer`, `NVFP4Quantizer`, `GroupedQuantizer`)
- Updated `QuantizerFactory.create` and `QuantizerFactory.create_set` to accept and pass through `checkpoint_name`
- Modified primitive wrappers in `activation.py`, `normalization.py`, and `quantization.py` to wrap outputs with `checkpoint_name` when set
- Added a `quantization_checkpoint_name` parameter to the Flax modules (`DenseGeneral`, `LayerNormDenseGeneral`, `LayerNormMLP`)
- Added comprehensive test coverage for the checkpoint functionality
Critical issue:
- The serialization format changes break backward compatibility. Old serialized checkpoints cannot be loaded with the new code due to mismatched `aux_data` tuple lengths in `tree_flatten`/`tree_unflatten`
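One way to sidestep this class of break (a sketch under stated assumptions, not the PR's actual approach; the class below is a simplified hypothetical stand-in) is to append new fields at the end of `aux_data` and give them a default, so an old, shorter tuple still unpacks cleanly:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical simplified stand-in for GroupedQuantizer: appending the new
# field at the END of aux_data, with a default, keeps old 5-tuples loadable.

@dataclass
class GroupedQuantizer:
    q_dtype: str
    scaling_mode: str
    q_layout: str
    data_layout: str
    n_groups: int
    checkpoint_name: Optional[str] = None  # new field, appended last

    def tree_flatten(self):
        return (), (self.q_dtype, self.scaling_mode, self.q_layout,
                    self.data_layout, self.n_groups, self.checkpoint_name)

    @classmethod
    def tree_unflatten(cls, aux_data, _children):
        # Old checkpoints carry a 5-tuple; checkpoint_name then falls back
        # to its default instead of swallowing another field's value.
        return cls(*aux_data)

old_aux = ("fp8", "delayed", "rowwise", "N", 4)           # pre-change format
new_aux = ("fp8", "delayed", "rowwise", "N", 4, "quant")  # new format
old_q = GroupedQuantizer.tree_unflatten(old_aux, ())
new_q = GroupedQuantizer.tree_unflatten(new_aux, ())
```

Inserting the new field mid-tuple, by contrast, shifts every later element and silently corrupts deserialized state.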
Confidence Score: 2/5
- Has backward compatibility issues with serialized checkpoints that will cause runtime errors
- The feature implementation is solid and well-tested, but the serialization format changes in quantizer.py break backward compatibility for existing checkpoints. This is a critical issue that needs resolution before merging.
- transformer_engine/jax/quantize/quantizer.py requires attention for backward compatibility fixes in tree_flatten/tree_unflatten methods
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/quantize/quantizer.py | 1/5 | Adds checkpoint_name field to quantizers but breaks backward compatibility for serialized checkpoints in NVFP4Quantizer, GroupedQuantizer, and DelayedScaleQuantizer |
| transformer_engine/jax/cpp_extensions/activation.py | 5/5 | Correctly wraps primitive outputs with checkpoint_name when set, following consistent pattern |
| transformer_engine/jax/cpp_extensions/normalization.py | 5/5 | Correctly wraps primitive outputs with checkpoint_name when set, following consistent pattern |
| transformer_engine/jax/cpp_extensions/quantization.py | 5/5 | Correctly wraps primitive outputs with checkpoint_name when set, following consistent pattern |
| transformer_engine/jax/flax/module.py | 5/5 | Adds quantization_checkpoint_name parameter to Flax modules and passes it through to QuantizerFactory correctly |
| tests/jax/test_recipe_characteristics.py | 5/5 | Comprehensive test coverage verifying checkpoint_name functionality across all supported quantization recipes |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Module as Flax Module<br/>(Dense/LayerNormMLP)
participant Factory as QuantizerFactory
participant Quantizer as Quantizer
participant Primitive as Primitive Wrapper<br/>(quantize/act_lu/norm)
participant JAX as JAX Checkpoint
User->>Module: Create module with<br/>quantization_checkpoint_name
Module->>Factory: generate_quantizer_set(checkpoint_name)
Factory->>Quantizer: create quantizers with<br/>checkpoint_name field
Quantizer-->>Module: quantizer_set
User->>Module: Forward/Backward pass
Module->>Primitive: Call with quantizer
Primitive->>Primitive: bind primitive operation
Primitive->>Primitive: Get prim_outputs
alt checkpoint_name is set
Primitive->>JAX: checkpoint_name(prim_outputs, name)
JAX-->>Primitive: wrapped outputs
end
Primitive->>Primitive: Unpack outputs
Primitive->>Quantizer: update(amax)
Primitive-->>Module: ScaledTensor results
6 files reviewed, no comments
This reverts commit f7b7849. Signed-off-by: JAX Toolbox <jax@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Overview
Greptile Summary
This PR adds support for checkpointing quantizations in JAX by introducing an optional `checkpoint_name`.
Key Changes:
Critical Issue:
Confidence Score: 2/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Module as Flax Module<br/>(Dense/LayerNormMLP)
participant QFactory as QuantizerFactory
participant Quantizer as Quantizer<br/>(with checkpoint_name)
participant Primitive as Primitive Wrappers<br/>(quantize/norm/act)
participant STFactory as ScaledTensorFactory
participant JAX as JAX checkpoint_name
User->>Module: Initialize with quantization_checkpoint_name
Module->>QFactory: generate_quantizer_set(checkpoint_name)
QFactory->>Quantizer: create quantizer(checkpoint_name=name)
Quantizer-->>Module: quantizer_set
User->>Module: Forward pass
Module->>Primitive: quantize/norm+q/act+q(quantizer)
Primitive->>STFactory: create(checkpoint_name=quantizer.checkpoint_name)
alt checkpoint_name is not None
STFactory->>JAX: jax_checkpoint_name(tensor, name)
JAX-->>STFactory: checkpointed tensor
end
STFactory-->>Primitive: ScaledTensor
Primitive-->>Module: quantized output
Module-->>User: result
9 files reviewed, no comments
```python
quantizer.update(updated_amax)
```

```python
# pylint: disable=unexpected-keyword-arg
```
The linter wasn't able to understand the wrapper annotation that I created to wrap a function's outputs in a checkpoint name. It kept giving this error:

```
E1123: Unexpected keyword argument 'checkpoint_name' in staticmethod call (unexpected-keyword-arg)
```

despite the code being functional and correctly checkpointing when tested in unit tests and MaxText.
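The pylint behavior described above can be reproduced with a dependency-free sketch. The decorator below is a hypothetical stand-in for the wrapper (names like `with_checkpoint_name` and `quantize` are illustrative, not TE's actual API): it injects a `checkpoint_name` keyword that the wrapped function never declared, so static analysis of the original signature flags E1123 even though the call succeeds at runtime.

```python
import functools

def with_checkpoint_name(fn):
    """Hypothetical sketch of an output-wrapping decorator: it accepts a
    checkpoint_name keyword that the wrapped function did not declare."""
    @functools.wraps(fn)
    def wrapper(*args, checkpoint_name=None, **kwargs):
        outputs = fn(*args, **kwargs)
        if checkpoint_name is not None:
            # Real code would call jax.ad_checkpoint.checkpoint_name here;
            # we just tag the outputs so the sketch stays dependency-free.
            outputs = tuple((checkpoint_name, o) for o in outputs)
        return outputs
    return wrapper

@with_checkpoint_name
def quantize(x):
    return (x, x * 2)

# Pylint inspects quantize's declared signature (preserved by
# functools.wraps) and reports checkpoint_name as unexpected,
# even though the wrapped call works fine:
out = quantize(3, checkpoint_name="quant")  # → (("quant", 3), ("quant", 6))
```

This is why a targeted `# pylint: disable=unexpected-keyword-arg` at the call site is a reasonable workaround here.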
9 files reviewed, no comments
Description
Some frameworks, such as MaxText, require values to be checkpointed explicitly via JAX's `checkpoint_name`; otherwise, computations will be rematerialized. This PR adds optional support for checkpointing quantizations to avoid recomputation if desired.
Corresponding MaxText PR to utilize this functionality: nvjax-svc-0/maxtext#29
Performance numbers are available in that PR's description.
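The intended downstream usage can be sketched as follows (a minimal example; the function `layer` and the tag `"quant"` are illustrative, not MaxText's actual code). Tagging an intermediate with `jax.ad_checkpoint.checkpoint_name` lets a rematerialization policy keep exactly that value for the backward pass:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x):
    # Tag the quantized-style intermediate so a remat policy can save it.
    q = checkpoint_name(jnp.sin(x), "quant")
    return jnp.sum(q * q)

# Under jax.checkpoint, only values tagged "quant" are stored for the
# backward pass; untagged intermediates are rematerialized.
layer_remat = jax.checkpoint(
    layer, policy=jax.checkpoint_policies.save_only_these_names("quant")
)

grad = jax.grad(layer_remat)(jnp.ones(4))
# d/dx sum(sin(x)^2) = 2*sin(x)*cos(x) = sin(2x)
```

Without the tag, the same policy would force the quantization to be recomputed during the backward pass, which is the overhead this PR lets frameworks avoid.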
Type of change
Changes
- Quantizers now store an optional `checkpoint_name` that will be used in downstream checkpointing. I decided to store this on the quantizer directly instead of passing it as an additional argument in APIs, as this functionality is highly coupled to a quantizer's functionality. Additionally, there are many codepaths in quantization, and it would be easy to make a mistake in one codepath, forget to pass through this checkpoint name, and silently skip checkpointing.
- Quantized primitive outputs are wrapped with `jax.ad_checkpoint.checkpoint_name` when a name is set.
Checklist: