
Conversation

@jberchtold-nvidia (Collaborator) commented Nov 6, 2025

Description

Some frameworks, such as MaxText, require values to be explicitly tagged for checkpointing via JAX's checkpoint_name; otherwise those computations are rematerialized during the backward pass. This PR adds optional support for checkpointing quantization outputs so they can be saved rather than recomputed, if desired.

Corresponding MaxText PR to utilize this functionality: nvjax-svc-0/maxtext#29
Performance numbers are available in that PR's description.
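For context, a minimal sketch of the JAX mechanism this PR builds on (the tag name and toy layer below are illustrative, not TE code): jax.ad_checkpoint.checkpoint_name tags an intermediate value, and a remat policy such as save_only_these_names then saves the tagged value instead of recomputing it in the backward pass.

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    def layer(x):
        # Tag an intermediate so a remat policy can elect to save it by name.
        y = checkpoint_name(jnp.tanh(x), "quantized_act")
        return (y * y).sum()

    # Save only values tagged "quantized_act" for the backward pass;
    # everything else inside the rematerialized region is recomputed.
    policy = jax.checkpoint_policies.save_only_these_names("quantized_act")
    grad_fn = jax.grad(jax.checkpoint(layer, policy=policy))
    print(grad_fn(jnp.ones((4,))))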

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Update quantizers and QuantizerFactory to support an optional checkpoint_name that is used in downstream checkpointing. The name is stored on the quantizer itself rather than passed as an extra argument through the APIs, both because this functionality is tightly coupled to the quantizer and because quantization has many codepaths where forgetting to thread an extra argument through one of them would silently skip checkpointing.
  • Flax modules that create quantizer sets (Dense, LayerNormDenseGeneral, and LayerNormMLP) have been updated to take an optional quantization checkpoint name
  • Primitive wrapper functions for quantization, activation+quantization, and norm+quantization now wrap their outputs in jax.ad_checkpoint.checkpoint_name when this name is set (see the sketch after this list)
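A hedged sketch of that wrapper pattern, using a stand-in quantize function since the real TE primitives are not reproduced here (_fake_quantize and the output tuple layout are assumptions):

    from typing import Optional

    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    # Stand-in for a TE quantization primitive: returns quantized data, the
    # inverse scale, and the running amax. Purely illustrative.
    def _fake_quantize(x):
        amax = jnp.max(jnp.abs(x))
        scale = 127.0 / amax
        return jnp.round(x * scale).astype(jnp.int8), 1.0 / scale, amax

    def quantize_wrapper(x, ckpt_name: Optional[str]):
        prim_outputs = _fake_quantize(x)
        if ckpt_name is not None:
            # checkpoint_name accepts a pytree, so the whole output tuple is
            # tagged in one call.
            prim_outputs = checkpoint_name(prim_outputs, ckpt_name)
        data, scale_inv, amax = prim_outputs
        return data, scale_inv, amax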

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps (bot, Contributor) left a comment

Greptile Summary

This PR adds support for checkpointing quantization operations in JAX by introducing an optional checkpoint_name parameter throughout the quantizer hierarchy and applying JAX's checkpoint_name function to primitive outputs.

Key Changes:

  • Added checkpoint_name: Optional[str] field to all quantizer classes (Quantizer, DelayedScaleQuantizer, NVFP4Quantizer, GroupedQuantizer)
  • Updated QuantizerFactory.create() and QuantizerFactory.create_set() to accept and propagate checkpoint_name
  • Modified quantization primitives in activation.py, normalization.py, and quantization.py to apply checkpoint_name() to outputs when set
  • Added quantization_checkpoint_name parameter to Flax modules (DenseGeneral, LayerNormDenseGeneral, LayerNormMLP); a usage sketch follows this list
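A hedged usage sketch of that parameter; the feature size, input shape, and other constructor arguments are illustrative:

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import DenseGeneral

    # Quantized intermediates produced inside this layer are tagged
    # "te_quant", so a remat policy can save them by name.
    layer = DenseGeneral(features=1024, quantization_checkpoint_name="te_quant")
    params = layer.init(jax.random.PRNGKey(0), jnp.ones((8, 512)))
    out = layer.apply(params, jnp.ones((8, 512)))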

Critical Issue Found:

  • In NVFP4Quantizer.tree_flatten() (lines 631-638), checkpoint_name is inserted before use_rht in the aux_data tuple, but the dataclass field order has checkpoint_name as the 5th field (from the parent) and use_rht as the 6th. This will cause incorrect deserialization when tree_unflatten unpacks the tuple

Confidence Score: 1/5

  • This PR contains a critical serialization bug that will cause runtime failures
  • The bug in NVFP4Quantizer.tree_flatten will cause incorrect field assignment during deserialization, leading to runtime errors when quantizers are serialized/deserialized. The checkpoint_name field order in aux_data doesn't match the dataclass field order
  • transformer_engine/jax/quantize/quantizer.py requires immediate attention to fix the tree_flatten field order

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/jax/quantize/quantizer.py | 1/5 | Added checkpoint_name parameter throughout quantizer hierarchy; critical bug in NVFP4Quantizer.tree_flatten where checkpoint_name is inserted before use_rht, breaking serialization/deserialization
transformer_engine/jax/cpp_extensions/activation.py | 4/5 | Added checkpoint support for activation quantizations by applying checkpoint_name to primitive outputs; implementation looks correct
transformer_engine/jax/cpp_extensions/normalization.py | 4/5 | Added checkpoint support for normalization quantizations by applying checkpoint_name to primitive outputs; implementation looks correct
transformer_engine/jax/cpp_extensions/quantization.py | 4/5 | Added checkpoint support for quantization operations by applying checkpoint_name to primitive outputs; implementation looks correct
transformer_engine/jax/flax/module.py | 5/5 | Added quantization_checkpoint_name parameter to DenseGeneral, LayerNormDenseGeneral, and LayerNormMLP modules, properly threaded through to quantizer creation

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Module as Flax Module<br/>(DenseGeneral/LayerNorm)
    participant Factory as QuantizerFactory
    participant Quantizer as Quantizer<br/>(DelayedScale/NVFP4)
    participant Primitive as JAX Primitive<br/>(ActLu/Norm/Quantize)
    participant Checkpoint as checkpoint_name

    User->>Module: Call with quantization_checkpoint_name
    Module->>Factory: generate_quantizer_set(checkpoint_name)
    Factory->>Quantizer: create quantizer with checkpoint_name
    Quantizer-->>Module: return quantizer_set
    
    Module->>Primitive: bind() with inputs
    Primitive-->>Module: return prim_outputs (tuple)
    
    alt checkpoint_name is not None
        Module->>Checkpoint: checkpoint_name(prim_outputs, name)
        Checkpoint-->>Module: return checkpointed outputs
    end
    
    Module->>Module: unpack outputs tuple
    Module->>Quantizer: update(amax)
    Module-->>User: return quantized tensor

5 files reviewed, 1 comment


Comment on lines +631 to +638

    aux_data = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.checkpoint_name,
        self.use_rht,
    )
greptile-apps (bot):

logic: checkpoint_name inserted before use_rht in aux_data breaks deserialization. When tree_unflatten unpacks *aux_data, the 6th element (use_rht) will be assigned to the 5th field (checkpoint_name), causing a mismatch.

Suggested change:

    # before
    aux_data = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.checkpoint_name,
        self.use_rht,
    )
    # after
    aux_data = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.use_rht,
        self.checkpoint_name,
    )

@jberchtold-nvidia (Collaborator, Author) replied:

This comment is incorrect. checkpoint_name is the 5th field here since it is inherited from the parent class and is therefore 5th in the flattened field list.
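To make the positional coupling concrete, here is a minimal illustrative pytree sketch (field names mirror the discussion; the classes are not TE code). tree_unflatten reconstructs the object with cls(*aux_data), so the tuple must follow the dataclass __init__ order, with inherited fields first:

    from dataclasses import dataclass
    from typing import Optional

    import jax

    @dataclass
    class Parent:
        q_dtype: str
        checkpoint_name: Optional[str] = None  # inherited field

    @jax.tree_util.register_pytree_node_class
    @dataclass
    class Child(Parent):
        use_rht: bool = False  # subclass field, placed after inherited ones

        def tree_flatten(self):
            # aux_data must follow __init__ order: parent fields, then child's.
            return (), (self.q_dtype, self.checkpoint_name, self.use_rht)

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            del children  # no array leaves in this sketch
            return cls(*aux_data)  # positional: order must match __init__

    c = Child(q_dtype="fp8", checkpoint_name="quant", use_rht=True)
    leaves, treedef = jax.tree_util.tree_flatten(c)
    roundtrip = jax.tree_util.tree_unflatten(treedef, leaves)
    assert roundtrip.use_rht is True and roundtrip.checkpoint_name == "quant"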

greptile-apps (bot) replied:

You're absolutely right! Looking at the tree_unflatten method, it expects the parameters in a specific order. The NVFP4Quantizer inherits from Quantizer, so checkpoint_name should be in the same position as defined in the parent class.

In the base Quantizer class, the tree_flatten method puts checkpoint_name as the 5th element (index 4), but in NVFP4Quantizer, you've placed use_rht in that position and moved checkpoint_name to the 6th position. This will cause deserialization issues.


@jberchtold-nvidia (Collaborator, Author) commented:

/te-ci jax

@greptile-apps (bot, Contributor) left a comment

Greptile Summary

Adds optional checkpointing support for quantization operations to avoid rematerialization in frameworks like MaxText. The implementation adds a checkpoint_name field to all quantizer classes and wraps primitive outputs with jax.ad_checkpoint.checkpoint_name when specified.

Key changes:

  • Added checkpoint_name field to base Quantizer class and all subclasses (DelayedScaleQuantizer, NVFP4Quantizer, GroupedQuantizer)
  • Updated tree_flatten methods to include checkpoint_name in serialization for all quantizer types
  • Modified primitive wrappers in activation.py, normalization.py, and quantization.py to conditionally wrap outputs with checkpoint_name
  • Extended Flax modules (DenseGeneral, LayerNormDenseGeneral, LayerNormMLP) to accept and pass through quantization_checkpoint_name parameter
  • Updated QuantizerFactory.create and QuantizerFactory.create_set to accept and propagate checkpoint_name

Issues found:

  • Critical serialization bug in NVFP4Quantizer.tree_flatten where checkpoint_name is inserted before use_rht in the aux_data tuple, breaking deserialization order (already flagged in previous comment)

Confidence Score: 2/5

  • This PR has a critical serialization bug that will break NVFP4Quantizer deserialization
  • The implementation pattern is solid and consistent across all files, but the NVFP4Quantizer.tree_flatten method has incorrect field ordering in aux_data that will cause deserialization to fail or produce incorrect quantizer state. The checkpoint_name field is inserted before use_rht, but tree_unflatten expects them in the order they appear in the class definition
  • transformer_engine/jax/quantize/quantizer.py - Fix NVFP4Quantizer serialization order and add tests for quantizer serialization/deserialization

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/jax/quantize/quantizer.py | 2/5 | Adds checkpoint_name field to quantizer classes and updates serialization, but has a critical deserialization bug in NVFP4Quantizer where a field-order mismatch will cause incorrect unpacking
transformer_engine/jax/cpp_extensions/activation.py | 5/5 | Adds checkpoint_name wrapping around primitive outputs for act_lu and quantize_dact_dbias functions; consistent implementation with proper tuple unpacking
transformer_engine/jax/cpp_extensions/normalization.py | 5/5 | Adds checkpoint_name wrapping for layernorm_fwd and rmsnorm_fwd primitives, consistent with the activation.py pattern
transformer_engine/jax/cpp_extensions/quantization.py | 5/5 | Adds checkpoint_name wrapping to _quantize_dbias_impl, following the same pattern as the other primitive wrappers
transformer_engine/jax/flax/module.py | 5/5 | Adds quantization_checkpoint_name parameter to Dense, LayerNormDenseGeneral, and LayerNormMLP modules, properly threading it through to quantizer creation

Sequence Diagram

sequenceDiagram
    participant User as User/Framework
    participant Module as Flax Module<br/>(Dense/LayerNorm)
    participant Factory as QuantizerFactory
    participant Quantizer as Quantizer
    participant Primitive as Primitive Wrapper<br/>(activation/norm/quant)
    participant JAX as JAX checkpoint_name

    User->>Module: Call with quantization_checkpoint_name
    Module->>Factory: generate_quantizer_set(checkpoint_name)
    Factory->>Quantizer: create quantizer with checkpoint_name field
    Quantizer-->>Module: return quantizer_set
    Module->>Primitive: pass quantizer to primitive wrapper
    Primitive->>Primitive: execute primitive.bind()
    Primitive->>Primitive: check if quantizer.checkpoint_name is set
    alt checkpoint_name is set
        Primitive->>JAX: checkpoint_name(prim_outputs, name)
        JAX-->>Primitive: wrapped outputs (checkpointed)
    else checkpoint_name is None
        Primitive->>Primitive: use unwrapped outputs
    end
    Primitive->>Primitive: unpack outputs tuple
    Primitive->>Quantizer: update(amax)
    Primitive-->>Module: return quantized tensor

5 files reviewed, no comments


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia (Collaborator, Author) commented:

/te-ci jax

@greptile-apps (bot, Contributor) left a comment

Greptile Summary

This PR adds optional checkpointing support for quantizations to avoid recomputation in frameworks like MaxText.

Key changes:

  • Added checkpoint_name field to Quantizer base class and all quantizer implementations
  • Updated QuantizerFactory.create() and QuantizerFactory.create_set() to accept and propagate checkpoint_name
  • Modified Flax modules (DenseGeneral, LayerNormDenseGeneral, LayerNormMLP) to accept quantization_checkpoint_name parameter
  • Wrapped primitive outputs in jax.ad_checkpoint.checkpoint_name() when checkpoint name is provided
  • Added comprehensive test validating that checkpoint_name appears in the jaxpr for all supported recipes (see the sketch after this list)
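A hedged sketch of that jaxpr-based check; the real test exercises the TE modules, while this toy function only demonstrates the inspection technique:

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    def fn(x):
        # Tagged values show up as name primitives in the traced jaxpr.
        return checkpoint_name(x * 2.0, "quantize_out").sum()

    jaxpr = jax.make_jaxpr(fn)(jnp.ones((4,)))
    assert "quantize_out" in str(jaxpr)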

Issues found:

  • Critical: Backward compatibility broken in NVFP4Quantizer and GroupedQuantizer serialization due to checkpoint_name field placement in aux_data tuple

Confidence Score: 2/5

  • This PR has a critical backward compatibility issue that will break deserialization of existing checkpoints
  • The implementation logic is sound and well-tested, but the backward compatibility break in quantizer serialization is a critical issue that could cause production failures when loading old checkpoints
  • Pay close attention to transformer_engine/jax/quantize/quantizer.py - the serialization format changes must be fixed to maintain backward compatibility

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/jax/quantize/quantizer.py | 2/5 | Added checkpoint_name field to quantizers, but its placement in aux_data breaks backward compatibility with serialized quantizers
transformer_engine/jax/flax/module.py | 5/5 | Added quantization_checkpoint_name parameter to Dense, LayerNormDenseGeneral, and LayerNormMLP modules; correctly threaded through to the quantizer factory
tests/jax/test_recipe_characteristics.py | 5/5 | Added comprehensive test for quantization checkpointing across all recipes; validates that the jaxpr contains the expected checkpoint_name primitives

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Module as Flax Module<br/>(Dense/LayerNormMLP)
    participant Factory as QuantizerFactory
    participant Quantizer as Quantizer<br/>(with checkpoint_name)
    participant Primitive as JAX Primitive<br/>(quantize/act_lu/norm)
    participant Checkpoint as jax.ad_checkpoint.checkpoint_name

    User->>Module: Create module with quantization_checkpoint_name="quant"
    Module->>Factory: generate_quantizer_set(checkpoint_name="quant")
    Factory->>Quantizer: Create quantizers with checkpoint_name field
    Quantizer-->>Module: Return QuantizerSet
    
    User->>Module: Forward pass (apply)
    Module->>Primitive: Call quantize/act_lu/norm with quantizer
    Primitive->>Primitive: Execute quantization primitive
    Primitive-->>Primitive: Get prim_outputs (tuple of tensors)
    
    alt checkpoint_name is not None
        Primitive->>Checkpoint: checkpoint_name(prim_outputs, "quant")
        Checkpoint-->>Primitive: Wrapped outputs (for JAX rematerialization)
    end
    
    Primitive->>Primitive: Unpack tuple to individual tensors
    Primitive-->>Module: Return quantized tensors
    Module-->>User: Forward pass result

1 file reviewed, 2 comments


Comment on lines +631 to +638

    aux_data = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.checkpoint_name,
        self.use_rht,
    )
greptile-apps (bot):

logic: breaks backward compatibility for serialized quantizers. Old serialized NVFP4Quantizer objects have 5 elements in aux_data, while the new version has 6. When deserializing old checkpoints, use_rht will be assigned to the checkpoint_name position.

Suggested change:

    # before
    aux_data = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.checkpoint_name,
        self.use_rht,
    )
    # after
    aux_data = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.use_rht,
        self.checkpoint_name,
    )

Comment on lines +917 to +924

    aux_data = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.checkpoint_name,
        self.n_groups,
    )
greptile-apps (bot):

logic: same backward compatibility issue - inserting checkpoint_name before n_groups breaks deserialization of old GroupedQuantizer objects.

Suggested change:

    # before
    aux_data = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.checkpoint_name,
        self.n_groups,
    )
    # after
    aux_data = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.n_groups,
        self.checkpoint_name,
    )

@phu0ngng phu0ngng self-requested a review November 10, 2025 14:50
    rowwise_scale_inv,
    colwise_scale_inv,
    updated_amax,
    dbias,
@phu0ngng (Collaborator) commented Nov 10, 2025:

I don't think we need to checkpoint dbias and updated_amax.
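A hypothetical sketch of the narrower tagging this suggests: only the tensors the backward pass reuses get tagged, while dbias and updated_amax stay untagged and remain free to be rematerialized. The argument names follow the visible diff fragment; the function itself is illustrative, not TE code.

    from jax.ad_checkpoint import checkpoint_name

    def tag_saved_outputs(rowwise_scale_inv, colwise_scale_inv,
                          updated_amax, dbias, name):
        # Tag only what the backward pass consumes...
        rowwise_scale_inv = checkpoint_name(rowwise_scale_inv, name)
        colwise_scale_inv = checkpoint_name(colwise_scale_inv, name)
        # ...and leave dbias/updated_amax untagged so a save_only_these_names
        # policy will not force them to be stored.
        return rowwise_scale_inv, colwise_scale_inv, updated_amax, dbias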

@greptile-apps (bot, Contributor) left a comment

Greptile Summary

Adds optional checkpointing support for quantizations in JAX to prevent rematerialization in frameworks like MaxText. The implementation adds a checkpoint_name field to quantizers and wraps primitive outputs with jax.ad_checkpoint.checkpoint_name when specified.

Key changes:

  • Added checkpoint_name field to all quantizer classes (Quantizer, DelayedScaleQuantizer, NVFP4Quantizer, GroupedQuantizer)
  • Updated QuantizerFactory.create and QuantizerFactory.create_set to accept and pass through checkpoint_name
  • Modified primitive wrappers in activation.py, normalization.py, and quantization.py to wrap outputs with checkpoint_name when set
  • Added quantization_checkpoint_name parameter to Flax modules (DenseGeneral, LayerNormDenseGeneral, LayerNormMLP)
  • Added comprehensive test coverage for checkpoint functionality

Critical issue:

  • The serialization format changes break backward compatibility. Old serialized checkpoints cannot be loaded with the new code due to mismatched aux_data tuple lengths in tree_flatten/tree_unflatten

Confidence Score: 2/5

  • Has backward compatibility issues with serialized checkpoints that will cause runtime errors
  • The feature implementation is solid and well-tested, but the serialization format changes in quantizer.py break backward compatibility for existing checkpoints. This is a critical issue that needs resolution before merging.
  • transformer_engine/jax/quantize/quantizer.py requires attention for backward compatibility fixes in tree_flatten/tree_unflatten methods

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/jax/quantize/quantizer.py | 1/5 | Adds checkpoint_name field to quantizers but breaks backward compatibility for serialized checkpoints in NVFP4Quantizer, GroupedQuantizer, and DelayedScaleQuantizer
transformer_engine/jax/cpp_extensions/activation.py | 5/5 | Correctly wraps primitive outputs with checkpoint_name when set, following a consistent pattern
transformer_engine/jax/cpp_extensions/normalization.py | 5/5 | Correctly wraps primitive outputs with checkpoint_name when set, following a consistent pattern
transformer_engine/jax/cpp_extensions/quantization.py | 5/5 | Correctly wraps primitive outputs with checkpoint_name when set, following a consistent pattern
transformer_engine/jax/flax/module.py | 5/5 | Adds quantization_checkpoint_name parameter to Flax modules and passes it through to QuantizerFactory correctly
tests/jax/test_recipe_characteristics.py | 5/5 | Comprehensive test coverage verifying checkpoint_name functionality across all supported quantization recipes

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Module as Flax Module<br/>(Dense/LayerNormMLP)
    participant Factory as QuantizerFactory
    participant Quantizer as Quantizer
    participant Primitive as Primitive Wrapper<br/>(quantize/act_lu/norm)
    participant JAX as JAX Checkpoint

    User->>Module: Create module with<br/>quantization_checkpoint_name
    Module->>Factory: generate_quantizer_set(checkpoint_name)
    Factory->>Quantizer: create quantizers with<br/>checkpoint_name field
    Quantizer-->>Module: quantizer_set
    
    User->>Module: Forward/Backward pass
    Module->>Primitive: Call with quantizer
    Primitive->>Primitive: bind primitive operation
    Primitive->>Primitive: Get prim_outputs
    
    alt checkpoint_name is set
        Primitive->>JAX: checkpoint_name(prim_outputs, name)
        JAX-->>Primitive: wrapped outputs
    end
    
    Primitive->>Primitive: Unpack outputs
    Primitive->>Quantizer: update(amax)
    Primitive-->>Module: ScaledTensor results

6 files reviewed, no comments


nvjax and others added 2 commits November 10, 2025 11:37
This reverts commit f7b7849.

Signed-off-by: JAX Toolbox <jax@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps (bot, Contributor) commented Nov 10, 2025

Greptile Summary

This PR adds support for checkpointing quantizations in JAX by introducing an optional checkpoint_name parameter throughout the quantization pipeline. When provided, quantized tensors are wrapped with jax.ad_checkpoint.checkpoint_name to prevent rematerialization during gradient computation, which is critical for frameworks like MaxText.

Key Changes:

  • Added checkpoint_name field to all quantizer classes (Quantizer, NVFP4Quantizer, GroupedQuantizer, etc.)
  • Updated QuantizerFactory to accept and propagate checkpoint_name when creating quantizer sets
  • Modified Flax modules (Dense, LayerNormDenseGeneral, LayerNormMLP) to accept quantization_checkpoint_name parameter
  • Implemented a wrap_in_checkpoint_name decorator to conditionally apply JAX checkpointing to ScaledTensorFactory outputs (a hedged sketch follows this list)
  • Updated all primitive wrappers (quantization, normalization, activation) to pass checkpoint_name through to tensor creation
  • Added comprehensive test coverage verifying checkpoint propagation across all recipes
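A hedged sketch of what such a decorator could look like; the real TE implementation may differ. It injects an optional checkpoint_name keyword and tags the wrapped function's outputs only when the name is set:

    import functools

    from jax.ad_checkpoint import checkpoint_name as jax_checkpoint_name

    def wrap_in_checkpoint_name(fn):
        @functools.wraps(fn)
        def wrapper(*args, checkpoint_name=None, **kwargs):
            out = fn(*args, **kwargs)
            if checkpoint_name is not None:
                # checkpoint_name accepts pytrees, so tuples of tensors work.
                out = jax_checkpoint_name(out, checkpoint_name)
            return out
        return wrapper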

Critical Issue:
The serialization format change in NVFP4Quantizer and GroupedQuantizer breaks backward compatibility. The new checkpoint_name field is inserted before existing fields (use_rht, n_groups) in aux_data, which will cause deserialization failures when loading checkpoints saved with older versions.

Confidence Score: 2/5

  • This PR has critical backward compatibility issues that will break existing checkpoints
  • The implementation is thorough and well-tested, but the serialization breaking changes in NVFP4Quantizer and GroupedQuantizer (inserting checkpoint_name before use_rht/n_groups in aux_data) will cause deserialization failures for existing saved models. This is a blocking issue that needs to be fixed before merge
  • Pay close attention to transformer_engine/jax/quantize/quantizer.py - the serialization format must be fixed to maintain backward compatibility

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/jax/quantize/quantizer.py | 1/5 | Added checkpoint_name field to quantizers but breaks backward compatibility by inserting it before existing fields in aux_data serialization, causing deserialization errors for NVFP4Quantizer and GroupedQuantizer
transformer_engine/jax/quantize/tensor.py | 5/5 | Added wrap_in_checkpoint_name decorator to wrap ScaledTensorFactory outputs with JAX checkpointing; clean implementation with proper handling of optional checkpoint names
transformer_engine/jax/flax/module.py | 5/5 | Added quantization_checkpoint_name parameter to Dense, LayerNormDenseGeneral, and LayerNormMLP modules, passing it through to quantizer sets

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Module as Flax Module<br/>(Dense/LayerNormMLP)
    participant QFactory as QuantizerFactory
    participant Quantizer as Quantizer<br/>(with checkpoint_name)
    participant Primitive as Primitive Wrappers<br/>(quantize/norm/act)
    participant STFactory as ScaledTensorFactory
    participant JAX as JAX checkpoint_name

    User->>Module: Initialize with quantization_checkpoint_name
    Module->>QFactory: generate_quantizer_set(checkpoint_name)
    QFactory->>Quantizer: create quantizer(checkpoint_name=name)
    Quantizer-->>Module: quantizer_set

    User->>Module: Forward pass
    Module->>Primitive: quantize/norm+q/act+q(quantizer)
    Primitive->>STFactory: create(checkpoint_name=quantizer.checkpoint_name)
    
    alt checkpoint_name is not None
        STFactory->>JAX: jax_checkpoint_name(tensor, name)
        JAX-->>STFactory: checkpointed tensor
    end
    
    STFactory-->>Primitive: ScaledTensor
    Primitive-->>Module: quantized output
    Module-->>User: result

@greptile-apps (bot, Contributor) left a comment

9 files reviewed, no comments


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

    quantizer.update(updated_amax)

    # pylint: disable=unexpected-keyword-arg
@jberchtold-nvidia (Collaborator, Author) commented:

The linter wasn't able to understand the wrapper annotation I created to wrap a function's outputs in a checkpoint name. It kept giving this error:

E1123: Unexpected keyword argument 'checkpoint_name' in staticmethod call (unexpected-keyword-arg)

despite the code being functional and checkpointing correctly in unit tests and in MaxText.
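For illustration, a minimal reproduction of the pattern pylint trips over (the class and argument names are assumptions, reusing the decorator sketch from earlier): pylint infers the signature of the undecorated staticmethod, so the decorator-injected keyword triggers E1123 even though the call works at runtime.

    import functools

    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name as jax_checkpoint_name

    def wrap_in_checkpoint_name(fn):  # same sketch as above
        @functools.wraps(fn)
        def wrapper(*args, checkpoint_name=None, **kwargs):
            out = fn(*args, **kwargs)
            return jax_checkpoint_name(out, checkpoint_name) if checkpoint_name else out
        return wrapper

    class ScaledTensorFactory:
        @staticmethod
        @wrap_in_checkpoint_name
        def create(data, scale_inv):
            return data, scale_inv

    # Runs fine, but pylint flags the injected keyword:
    # pylint: disable=unexpected-keyword-arg
    out = ScaledTensorFactory.create(jnp.ones(4), jnp.ones(1), checkpoint_name="quant")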

@greptile-apps (bot, Contributor) left a comment

9 files reviewed, no comments

