[PyTorch][NVFP4][MOE] NVFP4 Grouped Hadamard Amax Kernel #2351
base: main
Conversation
Greptile Overview
Greptile Summary
This PR implements a grouped/multi-tensor NVFP4 quantization kernel with Hadamard transform (RHT) to reduce CPU overhead for MoE workloads. The key changes include:
- New fused kernel path in `cast.cpp`: implements `multi_tensor_quantize_nvfp4_impl()`, which processes multiple tensor chunks in a grouped fashion using `nvte_multi_hadamard_transform_amax()` for amax computation and fused RHT operations
- Alignment requirement bump: NVFP4 padding increased from 32 to 64 elements per the grouped kernel's constraints (enforced via a 64-multiple check on split_sections)
- Refactored alignment logic: introduced the `get_align_size_for_quantization()` helper to centralize alignment logic (16 for FP8, 32 for MXFP8, 64 for NVFP4)
- Fixed amax dtype bug: changed the amax tensor dtype from `torch::kUInt8` to `torch::kFloat32` in bulk allocation (lines 809, 864)
- Comprehensive test coverage: a new test file validates grouped quantization against a reference implementation across various edge cases (zero tokens, uneven splits, transposed outputs)
The implementation correctly handles empty tensor chunks, validates input constraints (128-column multiple, bfloat16 only, no 2D quantization), and falls back to non-fused path when requirements aren't met.
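For orientation, here is a minimal usage sketch of the grouped path described above. The module paths, the `NVFP4Quantizer` constructor arguments, and the `split_quantize` binding name are assumptions based on this summary, not verified against the branch.

```python
# Illustrative sketch only: import paths and constructor arguments are assumed,
# not taken from the PR diff.
import torch
import transformer_engine.pytorch.cpp_extensions as tex                     # assumed binding module
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer   # assumed class path

split_sections = [128, 0, 320, 64]   # each a multiple of 64; zero-token chunks are allowed
x = torch.randn(sum(split_sections), 1024,                  # 1024 % 128 == 0
                dtype=torch.bfloat16, device="cuda")        # bfloat16 required for the RHT path
quantizers = [NVFP4Quantizer(rowwise=True, columnwise=True)  # assumed signature
              for _ in split_sections]

# One call quantizes all chunks; eligible inputs take the fused grouped kernel,
# anything else falls back to the per-tensor path.
nvfp4_chunks = tex.split_quantize(x, split_sections, quantizers)
```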
Confidence Score: 4/5
- This PR is generally safe to merge with minor caveats around incomplete features
- The implementation is well-structured with proper validation checks and comprehensive test coverage. Fixed a critical dtype bug (amax as float32). However, stochastic rounding remains unimplemented (TODO at line 152), and the PR is marked as DRAFT with several TODO items in the description. The grouped kernel requires strict alignment (64-multiple) which is enforced. Core functionality appears sound based on the test suite.
- `transformer_engine/pytorch/csrc/extensions/cast.cpp` - verify stochastic rounding is disabled in production use until implemented
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/nvfp4/test_nvfp4_group_quantize.py | 5/5 | New comprehensive test file for NVFP4 grouped quantization with RHT, covering edge cases like zero tokens and uneven splits. |
| tests/pytorch/test_numerics.py | 5/5 | Refactored to use get_align_size_for_quantization() helper function, replacing hardcoded alignment values for better maintainability. |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 4/5 | Major implementation of NVFP4 grouped quantization with fused Hadamard transform, fixed amax dtype bug, added dimension checks for 64-multiple alignment. |
Sequence Diagram
sequenceDiagram
participant User
participant split_quantize
participant multi_tensor_quantize_impl
participant multi_tensor_quantize_nvfp4_impl
participant nvte_multi_hadamard_transform_amax
participant Quantize Loop
User->>split_quantize: Input tensor + split_sections + quantizers
split_quantize->>split_quantize: Split input into chunks
split_quantize->>split_quantize: Check if NVFP4 fused kernel eligible
Note over split_quantize: Requires: 64-multiple splits, NVFP4_1D_SCALING
split_quantize->>multi_tensor_quantize_impl: Pass chunks + quantizers
multi_tensor_quantize_impl->>multi_tensor_quantize_impl: Validate scaling mode consistency
multi_tensor_quantize_impl->>multi_tensor_quantize_nvfp4_impl: NVFP4 fused path
multi_tensor_quantize_nvfp4_impl->>nvte_multi_hadamard_transform_amax: Compute amax with RHT
Note over nvte_multi_hadamard_transform_amax: Computes rowwise & columnwise amax
nvte_multi_hadamard_transform_amax-->>multi_tensor_quantize_nvfp4_impl: amax values populated
multi_tensor_quantize_nvfp4_impl->>Quantize Loop: For each chunk
loop Each tensor chunk
Quantize Loop->>Quantize Loop: Skip if numel == 0
Quantize Loop->>Quantize Loop: Rowwise quantization (nvte_quantize_v2)
Quantize Loop->>Quantize Loop: Columnwise with RHT fusion (nvte_hadamard_transform_cast_fusion_columnwise)
end
Quantize Loop-->>multi_tensor_quantize_nvfp4_impl: Quantized chunks
multi_tensor_quantize_nvfp4_impl-->>split_quantize: Output list
split_quantize-->>User: List of NVFP4Tensor objects
3 files reviewed, 1 comment
if (quantizer->stochastic_rounding) {
  // TODO: implement stochastic rounding support for multi tensor
}
style: Stochastic rounding not implemented yet for multi-tensor - verify that stochastic_rounding is disabled in tests, or this will silently skip the rounding.
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
b8a1722 to c6c11ef
/te_ci L1
Greptile Overview
Greptile Summary
This PR implements a grouped NVFP4 Hadamard Amax kernel to reduce CPU overhead and quantization cost for mixture-of-experts (MOE) workloads. The implementation adds a high-performance CUDA kernel using TMA (Tensor Memory Accelerator) and tensor cores.
Key Changes
- New grouped Hadamard kernel (`multi_hadamard_transform.cu`): implements `MultiHadamardAmaxTmaKernel` with ping-pong buffering, tensor core operations, and multi-tensor amax reduction
- Extended Python API (`cast.cpp`): added the `multi_tensor_quantize_nvfp4_impl` function supporting grouped quantization with stochastic rounding
- Padding alignment change: bumped the NVFP4 M-dimension padding requirement from 32 to 64 for optimal kernel performance
- Utility function: added `get_align_size_for_quantization()` to centralize alignment logic across the codebase
- Stochastic rounding fix: a recent commit (c6c11ef) properly implements stochastic rounding support for grouped operations
- Comprehensive tests: Added test coverage for grouped quantization with various edge cases (zero tokens, uneven splits, stochastic rounding)
Important Constraints
- Hardware requirement: Kernel requires SM 10.0+ (Blackwell architecture)
- Alignment requirement: split sections must be multiples of 64 (breaking change from 32)
- Limitations: Pre-RHT amax and non-RHT modes not supported
- MegatronLM coordination: PR description notes MegatronLM needs to update padding from 32 to 64
Architecture
The implementation uses a sophisticated multi-stage pipeline:
- Zero-initialize amax buffers for all tensors
- Launch grouped Hadamard transform kernel with TMA loads and tensor core MMA operations
- Compute both pre-RHT amax (rowwise) and post-RHT amax (columnwise transpose)
- Perform rowwise and/or columnwise quantization using populated amax values
Confidence Score: 3/5
- This PR requires coordination with downstream projects and has pending TODO items before it should be merged
- The implementation is technically sound with good test coverage, and stochastic rounding is properly implemented. However, the PR is explicitly marked as DRAFT with several unfinished TODO items: (1) code cleanup and more testing needed, (2) stochastic rounding was only just fixed in the latest commit, (3) MegatronLM needs to update its padding from 32 to 64. The breaking change in alignment requirements (32→64) needs coordination with downstream users. Additionally, the SM 10.0+ hardware requirement is a significant constraint that should be clearly communicated.
- `transformer_engine/pytorch/csrc/extensions/cast.cpp` needs verification that the stochastic rounding fix is complete. Downstream projects (especially MegatronLM) need to update padding requirements before this can merge.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu | 4/5 | New grouped Hadamard transform kernel implementation with TMA. Uses tensor cores and advanced PTX instructions. Complex ping-pong buffering and amax reduction logic appears correct. Main concern: requires SM 10.0+ (Blackwell) which limits portability. |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 3/5 | Adds the multi_tensor_quantize_nvfp4_impl function with stochastic rounding support. Stochastic rounding was recently fixed in commit c6c11ef. RNG state generation uses a per-tensor loop (a TODO notes that the bulk API would be better). Validates that split_sections are multiples of 64. |
| tests/pytorch/nvfp4/test_nvfp4_group_quantize.py | 4/5 | Comprehensive test coverage for grouped quantization including edge cases (zero tokens, uneven splits). Tests validate against reference implementation with exact equality checks. Good test structure. |
| tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py | 4/5 | New test file for stochastic rounding validation in grouped quantization. Includes reference implementations for RHT and FP4 dequantization. Tests both single and grouped quantization paths. |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant PyAPI as tex.split_quantize
participant Cast as cast.cpp
participant MultiImpl as multi_tensor_quantize_nvfp4_impl
participant Hadamard as nvte_multi_hadamard_transform_amax
participant Kernel as MultiHadamardAmaxTmaKernel
participant Quantize as nvte_quantize_v2 / RHT fusion
User->>PyAPI: tex.split_quantize(input, split_sections, quantizers)
PyAPI->>Cast: split_quantize (C++ extension)
Cast->>Cast: Split input tensor along dim 0
Cast->>Cast: Allocate bulk NVFP4 output tensors
Cast->>MultiImpl: multi_tensor_quantize_nvfp4_impl()
alt Stochastic Rounding Enabled
MultiImpl->>MultiImpl: Generate RNG states for each tensor
MultiImpl->>MultiImpl: Configure quant_config with RNG
end
MultiImpl->>Hadamard: nvte_multi_hadamard_transform_amax(input, outputs, split_sections)
Note over Hadamard: Compute pre-RHT amax & columnwise RHT amax
Hadamard->>Kernel: MultiZeroAmaxKernel (initialize amax buffers)
Hadamard->>Kernel: MultiHadamardAmaxTmaKernel (SM 10.0+)
Note over Kernel: TMA load with ping-pong buffering
Note over Kernel: Tensor core MMA (16x16x16)
Note over Kernel: Compute amax for pre-RHT and RHT(transpose)
Kernel-->>Hadamard: Amax values written to output tensors
Hadamard-->>MultiImpl: Return with populated amax buffers
loop For each tensor split
alt Rowwise Quantization
MultiImpl->>Quantize: nvte_quantize_v2 (rowwise)
end
alt Columnwise Quantization with RHT
MultiImpl->>Quantize: nvte_hadamard_transform_cast_fusion_columnwise
Note over Quantize: RHT transform + quantize in single kernel
end
end
MultiImpl-->>Cast: Quantized outputs
Cast-->>PyAPI: NVFP4Tensor list
PyAPI-->>User: Return quantized tensors
10 files reviewed, 1 comment
if (quantizer->stochastic_rounding) {
  // TODO(zhongbo): remove the for loop of generating rng states with a single call
  // with rng_elts_per_thread = 1024 * num_tensors
  // Change to the bulk generate rng states api when grouped quantize is available
  const size_t rng_elts_per_thread = 1024;  // Wild guess, probably can be tightened
  auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
  rng_states_tensor = torch::empty({2 * num_tensors}, opts);

  for (size_t i = 0; i < num_tensors; ++i) {
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
    at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
    int64_t *rng_state_ptr = static_cast<int64_t *>(rng_states_tensor.data_ptr()) + i * 2;
    philox_unpack(philox_args, rng_state_ptr);
    te_rng_state_list.push_back(makeTransformerEngineTensor(
        static_cast<void *>(rng_state_ptr), std::vector<size_t>{2}, DType::kInt64));
    quant_config_list[i].set_rng_state(te_rng_state_list[i].data());
    quant_config_list[i].set_stochastic_rounding(true);
  }
style: Consider optimizing RNG state generation by using bulk API instead of per-tensor loop (as noted in TODO). This would improve performance for multi-tensor quantization.
Greptile Overview
Greptile Summary
This PR adds a utility function get_align_size_for_quantization() to centralize the logic for determining alignment sizes based on quantization recipes, replacing duplicated inline logic across the codebase.
Key changes:
- Added the `get_align_size_for_quantization()` function to return the alignment size: 32 for MXFP8, 64 for NVFP4, 16 for other recipes
- Important: NVFP4 alignment padding increased from 32 to 64 (as mentioned in the PR description and related MegatronLM PRs)
- Updated the `Fp8Padding` and `Fp8Unpadding` modules to use the new function
- Updated test files to use the centralized function instead of hardcoded values
This refactoring improves maintainability by centralizing the alignment logic in one place.
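A minimal sketch of what the helper plausibly looks like, assuming the `Recipe` predicates shown in the sequence diagram below; this is not the verbatim implementation from `transformer_engine/pytorch/quantization.py`.

```python
# Minimal sketch, assuming recipe.mxfp8()/recipe.nvfp4() predicates exist
# (as the sequence diagram below indicates).
def get_align_size_for_quantization(recipe) -> int:
    """Alignment (in rows) required when padding token counts for quantization."""
    if recipe.mxfp8():
        return 32
    if recipe.nvfp4():
        return 64   # bumped from 32 for the grouped NVFP4 kernel
    return 16       # FP8 and other recipes

# Typical use: round per-expert token counts up to the required multiple.
align = 64
m_splits = [100, 0, 257]
padded_m_splits = [(m + align - 1) // align * align for m in m_splits]  # -> [128, 0, 320]
```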
Confidence Score: 5/5
- This PR is safe to merge - it's a straightforward refactoring that centralizes alignment size logic
- The changes are minimal and low-risk: adds a simple utility function that replaces duplicated inline logic with centralized code. The function has clear control flow with no complex edge cases. The NVFP4 padding change from 32 to 64 is intentional and documented in the PR description with corresponding MegatronLM changes.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/quantization.py | 5/5 | Added utility function get_align_size_for_quantization to centralize alignment size logic (32 for MXFP8, 64 for NVFP4, 16 default). NVFP4 padding increased from 32 to 64. |
Sequence Diagram
sequenceDiagram
participant User
participant Fp8Padding/Fp8Unpadding
participant FP8GlobalStateManager
participant get_align_size_for_quantization
participant Recipe
User->>Fp8Padding/Fp8Unpadding: forward(inp, m_splits)
alt align_size is None
Fp8Padding/Fp8Unpadding->>FP8GlobalStateManager: get_fp8_recipe()
FP8GlobalStateManager-->>Fp8Padding/Fp8Unpadding: recipe
Fp8Padding/Fp8Unpadding->>get_align_size_for_quantization: get_align_size_for_quantization(recipe)
get_align_size_for_quantization->>Recipe: recipe.mxfp8()
Recipe-->>get_align_size_for_quantization: bool
alt recipe.mxfp8() == True
get_align_size_for_quantization-->>Fp8Padding/Fp8Unpadding: 32
else
get_align_size_for_quantization->>Recipe: recipe.nvfp4()
Recipe-->>get_align_size_for_quantization: bool
alt recipe.nvfp4() == True
get_align_size_for_quantization-->>Fp8Padding/Fp8Unpadding: 64
else
get_align_size_for_quantization-->>Fp8Padding/Fp8Unpadding: 16
end
end
end
Fp8Padding/Fp8Unpadding->>Fp8Padding/Fp8Unpadding: Calculate padded_m_splits
Fp8Padding/Fp8Unpadding-->>User: output
1 file reviewed, 1 comment
    return get_default_fp8_recipe()


def get_align_size_for_quantization(recipe: Recipe):
style: Missing return type annotation - should specify -> int
- def get_align_size_for_quantization(recipe: Recipe):
+ def get_align_size_for_quantization(recipe: Recipe) -> int:
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Greptile Overview
Greptile Summary
Implements grouped NVFP4 quantization with Hadamard transform and amax computation to reduce CPU overhead and quantization costs for MOE workloads.
Key Changes
- Added the `multi_tensor_quantize_nvfp4_impl` function (~190 lines) that handles batched NVFP4 quantization with RHT
- Extended `multi_tensor_quantize_impl` to route NVFP4 quantizers to the new fused kernel path
- Implemented stochastic rounding support for multi-tensor quantization with per-tensor RNG state generation
- Fixed critical bug: changed the amax tensor dtype from `kUInt8` to `kFloat32` in the bulk allocation functions (lines 830, 885)
- Added validation checks: columns must be a multiple of 128, split_sections must be multiples of 64
- Modified function signatures to pass full input tensor alongside split tensors for grouped amax computation
Implementation Details
- Uses `nvte_multi_hadamard_transform_amax` for efficient batched amax computation across row/column-wise modes
- Supports both rowwise and columnwise quantization with RHT fusion via `nvte_hadamard_transform_cast_fusion_columnwise`
- Requires bfloat16 input for the RHT path (validation at line 188); a sketch of these eligibility checks follows this list
- Currently only supports post-RHT amax (pre-RHT amax and no-RHT paths intentionally unimplemented)
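The eligibility constraints above (128-column multiple, 64-multiple splits, bfloat16 input), expressed as a hedged Python sketch; the real checks are done in C++ inside `cast.cpp`, and the function name here is illustrative.

```python
# Hedged Python rendering of the checks described above; identifiers are illustrative.
import torch

def grouped_nvfp4_path_eligible(x: torch.Tensor, split_sections: list[int]) -> bool:
    _, cols = x.shape
    if x.dtype != torch.bfloat16:                 # RHT path requires bfloat16 input
        return False                              # -> fall back to per-tensor kernels
    if cols % 128 != 0:                           # column count must be a 128 multiple
        return False
    if any(s % 64 != 0 for s in split_sections):  # each split must be a 64 multiple
        return False
    return True                                   # eligible for the fused grouped kernel
```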
Issues Found
- Missing null check after `dynamic_cast` at line 372 - could cause a segfault if the cast fails
- Assumes all quantizers have identical stochastic rounding settings without verification
Confidence Score: 3/5
- This PR has a critical null pointer dereference risk that must be addressed before merging
- Score reflects one critical bug (missing null check on dynamic_cast) that could cause runtime crashes, plus minor assumption issues. The core implementation appears sound with proper validation checks, stochastic rounding support is correctly implemented, and the amax dtype bug was already fixed. However, the missing null check is a blocking issue.
- Pay close attention to line 372 in `cast.cpp`, where the `dynamic_cast` result must be null-checked before use
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 4/5 | Adds NVFP4 grouped Hadamard transform with amax kernel. Includes stochastic rounding support, bulk RNG state generation, multi-tensor amax computation, and fused RHT+quantization. Fixed amax dtype from kUInt8 to kFloat32. |
Sequence Diagram
sequenceDiagram
participant Client
participant split_quantize
participant multi_tensor_quantize_impl
participant multi_tensor_quantize_nvfp4_impl
participant CUDA_Kernels
Client->>split_quantize: Input tensor + split_sections + quantizers
split_quantize->>split_quantize: Split input tensor along dim 0
split_quantize->>split_quantize: Create TensorWrapper for full input
split_quantize->>split_quantize: Bulk allocate NVFP4 output tensors
split_quantize->>multi_tensor_quantize_impl: input_wrapper + input_list + split_sections
multi_tensor_quantize_impl->>multi_tensor_quantize_impl: Check scaling mode consistency
multi_tensor_quantize_impl->>multi_tensor_quantize_impl: Verify split_sections % 64 == 0
multi_tensor_quantize_impl->>multi_tensor_quantize_nvfp4_impl: Route to NVFP4 fused kernel
multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Validate inputs (cols % 128, no 2D, no amax reduction)
multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Create QuantizationConfigWrapper list
alt Stochastic Rounding Enabled
multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Allocate RNG states tensor [2*num_tensors]
loop For each tensor
multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Generate Philox RNG state
multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Set RNG state in quant_config_list[i]
end
end
alt RHT Enabled
multi_tensor_quantize_nvfp4_impl->>CUDA_Kernels: nvte_multi_hadamard_transform_amax
CUDA_Kernels-->>multi_tensor_quantize_nvfp4_impl: Compute rowwise & columnwise amax
loop For each tensor
alt Rowwise quantization
multi_tensor_quantize_nvfp4_impl->>CUDA_Kernels: nvte_quantize_v2 (identity)
end
alt Columnwise quantization
multi_tensor_quantize_nvfp4_impl->>CUDA_Kernels: nvte_hadamard_transform_cast_fusion_columnwise
end
end
else RHT Disabled
multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Error: multi-amax without RHT not supported
end
multi_tensor_quantize_nvfp4_impl-->>Client: Return quantized tensors
1 file reviewed, 2 comments
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      auto nvfp4_quantizer = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp_list[0].get());
logic: dynamic_cast can return nullptr if cast fails - needs null check before dereferencing
      auto nvfp4_quantizer = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp_list[0].get());
      NVTE_CHECK(nvfp4_quantizer != nullptr, "Failed to cast to NVFP4Quantizer");
  at::Tensor rng_states_tensor;

  // assumes one quantizer doing RS means all quantizers doing RS
  if (quantizer->stochastic_rounding) {
style: Assumes all quantizers have same stochastic_rounding setting - verify quantizers are consistent or check each individually
Greptile Overview
Greptile Summary
This PR implements grouped NVFP4 Hadamard transform kernels to reduce CPU overhead for MOE quantization. The implementation refactors shared utility functions into hadamard_transform_utils.cuh and adds a new multi_hadamard_transform.cu file with a TMA-based kernel that processes multiple tensor splits in a single launch.
Key changes:
- Refactored Hadamard utility functions into shared header for code reuse
- Added `MultiHadamardAmaxTmaKernel`, which computes amax values across multiple tensor splits using prefix-sum based indexing
- Extended `multi_tensor_quantize_impl` to support NVFP4 quantizers with a 64-multiple split_sections requirement
- Fixed amax buffer dtypes from `torch::kUInt8` to `torch::kFloat32` in bulk allocation
Critical issues found:
- Overflow check happens after addition, allowing signed integer overflow (UB)
- Missing bounds check for the `tensor_id` lookup in the kernel - can access out-of-bounds memory if `global_offset_y` exceeds the total range
Note: PR is marked as DRAFT and author lists cleanup/testing as pending action items.
Confidence Score: 2/5
- Not safe to merge - contains critical memory safety and integer overflow bugs
- Score of 2 due to two critical logic errors: (1) the overflow check after the addition allows undefined behavior, (2) a missing bounds check in the kernel can cause out-of-bounds memory access. These must be fixed before merging. The refactoring and new kernel implementation are sound, but the safety issues prevent a higher score.
- Pay close attention to `transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu` - it contains both critical bugs that need immediate fixing
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu | 2/5 | New grouped NVFP4 Hadamard kernel with TMA - found critical overflow/bounds check issues that need fixing before merge |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 3/5 | Added NVFP4 multi-tensor quantization support - complex logic with proper validation but amax dtype fix needs verification |
Sequence Diagram
sequenceDiagram
participant User
participant PyTorch as split_quantize()
participant Cast as multi_tensor_quantize_impl()
participant NVFP4 as multi_tensor_quantize_nvfp4_impl()
participant Hadamard as multi_hadamard_transform_amax()
participant Kernel as MultiHadamardAmaxTmaKernel
participant Quantize as nvte_quantize_v2()
User->>PyTorch: split_quantize(tensor, split_sections, quantizers)
PyTorch->>PyTorch: Create input_list by splitting tensor
PyTorch->>Cast: multi_tensor_quantize_impl(single_input, input_list, quantizers, split_sections)
Cast->>Cast: Check if all quantizers are NVFP4
Cast->>Cast: Validate split_sections (must be 64 multiple)
alt Fused NVFP4 Kernel Path
Cast->>NVFP4: multi_tensor_quantize_nvfp4_impl()
NVFP4->>NVFP4: Generate RNG states for stochastic rounding
NVFP4->>Hadamard: nvte_multi_hadamard_transform_amax()
Hadamard->>Hadamard: Build MultiAmaxArgs with split_sections_range
Hadamard->>Kernel: Launch MultiZeroAmaxKernel (zero out amax buffers)
Hadamard->>Kernel: Launch MultiHadamardAmaxTmaKernel
Kernel->>Kernel: For each block, find tensor_id from global_offset_y
Kernel->>Kernel: Compute RHT and calculate amax values
Kernel->>Kernel: Atomic updates to output amax buffers
Kernel-->>NVFP4: Return (amax populated)
loop For each tensor
NVFP4->>Quantize: nvte_quantize_v2() for rowwise (if enabled)
NVFP4->>Quantize: nvte_hadamard_transform_cast_fusion_columnwise() for columnwise
end
NVFP4-->>Cast: Return quantized tensors
else Individual Quantization Path
loop For each tensor
Cast->>Quantize: quantizer.quantize(input, output)
end
end
Cast-->>PyTorch: Return output_list
PyTorch-->>User: Return quantized tensors
4 files reviewed, 2 comments
      kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
          kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
      // check overflow
      NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0,
                 "split_sections_range overflow the int32_t");
logic: overflow check only detects sign flip, not actual overflow before it happens - if sum is near INT_MAX, addition can overflow before the check
      // check overflow before addition
      NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors] <= INT_MAX - split_sections[i],
                 "split_sections_range would overflow int32_t");
      kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
          kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
  // calculate the global offset in Y direction to access the correct amax buffer
  int global_offset_y = blockIdx.y * CHUNK_DIM_Y;
  int tensor_id = 0;
  while (args.split_sections_range[tensor_id + 1] <= global_offset_y) {
    ++tensor_id;
  }
logic: tensor_id could go out of bounds if global_offset_y >= split_sections_range[num_tensors] - while loop doesn't check upper bound
  // calculate the global offset in Y direction to access the correct amax buffer
  int global_offset_y = blockIdx.y * CHUNK_DIM_Y;
  int tensor_id = 0;
  while (tensor_id < args.num_tensors && args.split_sections_range[tensor_id + 1] <= global_offset_y) {
    ++tensor_id;
  }
  // Bounds check after loop
  if (tensor_id >= args.num_tensors) {
    return;  // or handle error
  }
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Greptile Overview
Greptile Summary
This PR implements a grouped Hadamard amax kernel for NVFP4 quantization to reduce CPU overhead in MOE (Mixture of Experts) workloads. The implementation enables multi-tensor processing where a single input tensor is split into multiple sections, and amax (absolute maximum) values are computed for each section with optional Hadamard transform.
Key Changes:
- Added `multi_hadamard_transform.cu`, implementing `MultiHadamardAmaxTmaKernel` for computing amax across multiple tensor splits
- Extended `cast.cpp` with `multi_tensor_quantize_nvfp4_impl` to orchestrate grouped quantization with stochastic rounding support
- Added the comprehensive test suite `test_nvfp4_group_quantize.py`, covering edge cases (zero tokens, uneven splits)
- New API functions `nvte_multi_hadamard_transform_amax` and `nvte_multi_tensor_amax` for grouped operations
- Enabled optionally disabling RHT (Randomized Hadamard Transform) to compute only amax values
Implementation Approach:
The kernel treats multiple logical tensors as a single contiguous input, using split_sections_range (prefix sum) to map block coordinates to the correct tensor ID and amax buffer. This allows efficient GPU utilization while maintaining per-tensor quantization semantics.
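A small Python model of that mapping, assuming `split_sections_range` is the prefix sum of `split_sections` starting at 0; the names mirror the kernel arguments, but this is only an illustration of the indexing, not the CUDA implementation.

```python
# Model of the kernel's block -> tensor mapping; not the CUDA code itself.
from itertools import accumulate

split_sections = [128, 0, 320, 64]
split_sections_range = [0] + list(accumulate(split_sections))   # [0, 128, 128, 448, 512]

def tensor_id_for_row(global_offset_y: int) -> int:
    """Return which split owns the row block starting at global_offset_y."""
    tensor_id = 0
    # The upper bound mirrors the bounds-check fix suggested in the review comments.
    while (tensor_id < len(split_sections)
           and split_sections_range[tensor_id + 1] <= global_offset_y):
        tensor_id += 1
    return tensor_id

assert tensor_id_for_row(0) == 0      # first split
assert tensor_id_for_row(130) == 2    # skips the zero-row split
assert tensor_id_for_row(448) == 3    # last split
```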
Confidence Score: 4/5
- Safe to merge with minor bounds check improvement recommended
- The implementation is well-structured with comprehensive tests covering edge cases. The tensor_id calculation has a potential out-of-bounds issue flagged in previous comments that should be verified. Stochastic rounding works correctly but could benefit from the bulk API optimization noted in TODOs. All action items from the PR description are marked complete.
- Pay attention to `multi_hadamard_transform.cu` lines 236-241 for the tensor_id bounds check issue
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu | 4/5 | New CUDA kernel for multi-tensor Hadamard transform with amax reduction. Implements efficient grouped quantization but has a minor bounds check issue in tensor_id calculation. |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 4/5 | Implements multi-tensor NVFP4 quantization with stochastic rounding support. Stochastic rounding per-tensor loop could be optimized to bulk API as noted in TODO. |
Sequence Diagram
sequenceDiagram
participant PyTorch as PyTorch Layer
participant CPP as cast.cpp
participant MultiAmax as multi_hadamard_transform_amax
participant CUDA as MultiHadamardAmaxTmaKernel
participant Quantize as nvte_quantize_v2/cast_fusion
PyTorch->>CPP: split_quantize(input, split_sections, quantizers)
CPP->>CPP: Validate inputs & create split_sections
CPP->>CPP: Setup RNG states for stochastic rounding
alt with_rht enabled
CPP->>MultiAmax: nvte_multi_hadamard_transform_amax(input, outputs, split_sections)
MultiAmax->>MultiAmax: Build MultiAmaxArgs with split_sections_range
MultiAmax->>CUDA: MultiZeroAmaxKernel (initialize amax buffers)
MultiAmax->>CUDA: MultiHadamardAmaxTmaKernel (compute amax per split)
CUDA->>CUDA: Calculate tensor_id from blockIdx.y offset
CUDA->>CUDA: Compute RHT and amax for each split section
CUDA-->>MultiAmax: Populated amax buffers
else without_rht
CPP->>MultiAmax: nvte_multi_tensor_amax(input, outputs, split_sections)
MultiAmax->>CUDA: Calculate amax without RHT
MultiAmax->>CUDA: Copy amax to both rowwise and columnwise
end
CPP->>CPP: Loop through each split
alt rowwise_usage enabled
CPP->>Quantize: nvte_quantize_v2(input_chunk, output, config)
Quantize-->>CPP: Rowwise quantized data
end
alt columnwise_usage enabled
CPP->>Quantize: nvte_hadamard_transform_cast_fusion_columnwise
Quantize-->>CPP: Columnwise quantized data
end
CPP-->>PyTorch: List of quantized outputs
4 files reviewed, no comments
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Greptile Overview
Greptile Summary
This PR implements grouped NVFP4 Hadamard Amax kernels to reduce CPU overhead and quantization cost for MOE workloads.
Key Changes:
- Added `multi_tensor_quantize_nvfp4_impl` in `cast.cpp` to handle grouped quantization of multiple tensors with a single kernel call
- Implemented the `nvte_multi_hadamard_transform_amax` and `nvte_multi_tensor_amax` kernels in the new `multi_hadamard_transform.cu` file
- Refactored Hadamard transform utilities into `hadamard_transform_utils.cuh` for code reuse
- Added a comprehensive test suite in `test_nvfp4_group_quantize.py` covering edge cases (zero tokens, uneven splits)
- Updated alignment requirements from 32 to 64 for NVFP4 quantization via the `get_align_size_for_quantization` helper
- Added stochastic rounding support for grouped quantization (with per-tensor RNG state generation)
Technical Implementation:
The grouped kernel batches multiple tensor quantizations by:
- Computing amax values for all tensors in a single kernel launch, using `split_sections_range` to map blocks to tensors
- Performing rowwise/columnwise quantization with optional RHT (Random Hadamard Transform) fusion
- Supporting both with-RHT (post-RHT amax) and without-RHT (pre-RHT amax copy) modes
Testing:
Extensive parametric tests validate correctness against per-tensor reference implementations across multiple dimensions, quantization modes, and edge cases.
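A hedged sketch of that test structure; the import paths, constructor arguments, quantizer call convention, and tensor attribute names are assumptions, and the actual assertions live in `test_nvfp4_group_quantize.py`.

```python
# Sketch of the grouped-vs-reference comparison; identifiers marked "assumed"
# are illustrative and may differ from the actual test file.
import pytest
import torch
import transformer_engine.pytorch.cpp_extensions as tex                     # assumed
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer   # assumed

@pytest.mark.parametrize("split_sections", [[128, 0, 320, 64], [64] * 8])
def test_grouped_matches_per_tensor(split_sections):
    x = torch.randn(sum(split_sections), 512, dtype=torch.bfloat16, device="cuda")
    quantizers = [NVFP4Quantizer(rowwise=True, columnwise=True)              # assumed signature
                  for _ in split_sections]

    grouped = tex.split_quantize(x, split_sections, quantizers)              # fused grouped path
    reference = [q(chunk) for q, chunk in                                    # per-tensor path
                 zip(quantizers, torch.split(x, split_sections, dim=0))]

    for got, ref in zip(grouped, reference):
        # Exact equality on the packed FP4 payload; the attribute name is assumed.
        torch.testing.assert_close(got._rowwise_data, ref._rowwise_data, rtol=0, atol=0)
```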
Confidence Score: 4/5
- This PR is safe to merge with minor review of previously flagged items
- The implementation is well-tested with comprehensive coverage including edge cases. Previous reviews have already identified optimization opportunities (stochastic rounding bulk API, dynamic_cast null checks) and potential bounds-checking improvements. The core logic is sound - the grouped kernel architecture properly handles tensor mapping via split_sections_range, and tests validate correctness. Confidence is 4/5 rather than 5/5 due to previously flagged issues that should be acknowledged before merge.
- `transformer_engine/pytorch/csrc/extensions/cast.cpp` - review the previously flagged dynamic_cast and stochastic rounding comments
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/nvfp4/test_nvfp4_group_quantize.py | 5/5 | New comprehensive test file for grouped NVFP4 quantization, validates correctness against reference implementation for various edge cases including zero tokens |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 4/5 | Adds multi_tensor_quantize_nvfp4_impl for grouped NVFP4 quantization with RHT support. Previous comments address stochastic rounding TODOs and dynamic_cast safety checks |
Sequence Diagram
sequenceDiagram
participant User as Python/PyTorch
participant SQ as tex.split_quantize
participant MTQ as multi_tensor_quantize_impl
participant NVFP4 as multi_tensor_quantize_nvfp4_impl
participant Amax as Amax Computation Kernel
participant Quant as Quantization Kernel
User->>SQ: split_quantize(input, split_sections, quantizers)
SQ->>MTQ: Call with input_list & quantizers
MTQ->>NVFP4: Route to NVFP4 implementation
Note over NVFP4: Setup RNG states for stochastic rounding (per-tensor)
alt With RHT enabled
NVFP4->>Amax: nvte_multi_hadamard_transform_amax
Note over Amax: Compute rowwise amax & columnwise amax(RHT(input.T))
Amax-->>NVFP4: Return amax values
loop For each tensor split
NVFP4->>Quant: nvte_quantize_v2 (rowwise)
NVFP4->>Quant: nvte_hadamard_transform_cast_fusion_columnwise
Quant-->>NVFP4: Quantized data + scales
end
else Without RHT
NVFP4->>Amax: nvte_multi_tensor_amax
Note over Amax: Compute amax for all splits, copy to columnwise
Amax-->>NVFP4: Return amax values
loop For each tensor split
NVFP4->>Quant: nvte_quantize_v2 (rowwise & columnwise)
Quant-->>NVFP4: Quantized data + scales
end
end
NVFP4-->>User: List of quantized outputs (data, scales, amax)
2 files reviewed, no comments
Description
This PR is one of several ongoing grouped-kernel efforts for NVFP4 to reduce CPU overhead and quantization cost.
This PR is ready for code review
Action items:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: