@zhongbozhu (Collaborator) commented on Nov 6, 2025

Description

This PR is one of several ongoing grouped-kernel efforts for NVFP4, aimed at reducing CPU overhead and quantization cost.

This PR is ready for code review

Action items:

  • Clean up code, more testing
  • Enable optionally disabling RHT so only the amax is computed (@timmoon10: FP8 CS can potentially use this too).
  • Fix stochastic rounding in the grouped kernel
  • Let MegatronLM also bump the NVFP4 M-dim padding from 32 to 64; see the PRs to MCore main and MCore dev.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zhongbozhu self-assigned this on Nov 6, 2025
@greptile-apps bot left a comment

Greptile Summary

This PR implements a grouped/multi-tensor NVFP4 quantization kernel with Hadamard transform (RHT) to reduce CPU overhead for MoE workloads. The key changes include:

  • New fused kernel path in cast.cpp: Implements multi_tensor_quantize_nvfp4_impl() that processes multiple tensor chunks in a grouped fashion using nvte_multi_hadamard_transform_amax() for amax computation and fused RHT operations
  • Alignment requirement bump: NVFP4 padding increased from 32 to 64 elements per the grouped kernel's constraints (enforced via 64-multiple check on split_sections)
  • Refactored alignment logic: Introduced get_align_size_for_quantization() helper to centralize alignment logic (16 for FP8, 32 for MXFP8, 64 for NVFP4)
  • Fixed amax dtype bug: Changed amax tensor dtype from torch::kUInt8 to torch::kFloat32 in bulk allocation (lines 809, 864)
  • Comprehensive test coverage: New test file validates grouped quantization against reference implementation across various edge cases (zero tokens, uneven splits, transposed outputs)

The implementation correctly handles empty tensor chunks, validates input constraints (128-column multiple, bfloat16 only, no 2D quantization), and falls back to non-fused path when requirements aren't met.
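For context, a minimal usage sketch of the grouped path from the Python side. This is an illustration, not code from the PR: the extension module name (transformer_engine_torch) and the NVFP4Quantizer constructor arguments are assumptions based on the diagram and file list below.

import torch
import transformer_engine_torch as tex  # assumed extension module name
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer  # assumed import path

# One bf16 buffer holding all expert tokens, split along dim 0.
# For the fused kernel path each split must be a multiple of 64 rows and the
# column count a multiple of 128; otherwise the non-fused fallback is taken.
inp = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
split_sections = [64, 128, 0, 64]  # zero-token experts are allowed

quantizers = [NVFP4Quantizer(rowwise=True, columnwise=True)  # args assumed
              for _ in split_sections]
out_list = tex.split_quantize(inp, split_sections, quantizers)  # one NVFP4Tensor per split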

Confidence Score: 4/5

  • This PR is generally safe to merge with minor caveats around incomplete features
  • The implementation is well-structured with proper validation checks and comprehensive test coverage. Fixed a critical dtype bug (amax as float32). However, stochastic rounding remains unimplemented (TODO at line 152), and the PR is marked as DRAFT with several TODO items in the description. The grouped kernel requires strict alignment (64-multiple) which is enforced. Core functionality appears sound based on the test suite.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp - verify stochastic rounding is disabled in production use until implemented

File Analysis

  • tests/pytorch/nvfp4/test_nvfp4_group_quantize.py (5/5): New comprehensive test file for NVFP4 grouped quantization with RHT, covering edge cases like zero tokens and uneven splits.
  • tests/pytorch/test_numerics.py (5/5): Refactored to use the get_align_size_for_quantization() helper, replacing hardcoded alignment values for better maintainability.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp (4/5): Major implementation of NVFP4 grouped quantization with fused Hadamard transform; fixed the amax dtype bug and added dimension checks for 64-multiple alignment.

Sequence Diagram

sequenceDiagram
    participant User
    participant split_quantize
    participant multi_tensor_quantize_impl
    participant multi_tensor_quantize_nvfp4_impl
    participant nvte_multi_hadamard_transform_amax
    participant Quantize Loop

    User->>split_quantize: Input tensor + split_sections + quantizers
    split_quantize->>split_quantize: Split input into chunks
    split_quantize->>split_quantize: Check if NVFP4 fused kernel eligible
    Note over split_quantize: Requires: 64-multiple splits, NVFP4_1D_SCALING
    split_quantize->>multi_tensor_quantize_impl: Pass chunks + quantizers
    multi_tensor_quantize_impl->>multi_tensor_quantize_impl: Validate scaling mode consistency
    multi_tensor_quantize_impl->>multi_tensor_quantize_nvfp4_impl: NVFP4 fused path
    
    multi_tensor_quantize_nvfp4_impl->>nvte_multi_hadamard_transform_amax: Compute amax with RHT
    Note over nvte_multi_hadamard_transform_amax: Computes rowwise & columnwise amax
    nvte_multi_hadamard_transform_amax-->>multi_tensor_quantize_nvfp4_impl: amax values populated
    
    multi_tensor_quantize_nvfp4_impl->>Quantize Loop: For each chunk
    loop Each tensor chunk
        Quantize Loop->>Quantize Loop: Skip if numel == 0
        Quantize Loop->>Quantize Loop: Rowwise quantization (nvte_quantize_v2)
        Quantize Loop->>Quantize Loop: Columnwise with RHT fusion (nvte_hadamard_transform_cast_fusion_columnwise)
    end
    
    Quantize Loop-->>multi_tensor_quantize_nvfp4_impl: Quantized chunks
    multi_tensor_quantize_nvfp4_impl-->>split_quantize: Output list
    split_quantize-->>User: List of NVFP4Tensor objects

3 files reviewed, 1 comment


Comment on lines 151 to 153

if (quantizer->stochastic_rounding) {
  // TODO: implement stochastic rounding support for multi tensor
}
style: Stochastic rounding not implemented yet for multi-tensor - verify that stochastic_rounding is disabled in tests, or this will silently skip the rounding.
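A guard along these lines in the test setup would make the limitation explicit (a sketch, assuming the Python-side quantizers expose the same stochastic_rounding flag as the C++ snippet above):

def assert_no_stochastic_rounding(quantizers):
    # Fail loudly instead of silently skipping the rounding.
    for q in quantizers:
        assert not q.stochastic_rounding, (
            "stochastic rounding is not yet implemented for the multi-tensor path")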

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu force-pushed the zhongbo/mult_RHT_amax branch from b8a1722 to c6c11ef on November 7, 2025, 19:58
@zhongbozhu (Collaborator, Author) commented:
/te_ci L1

@greptile-apps bot left a comment

Greptile Summary

This PR implements a grouped NVFP4 Hadamard amax kernel to reduce CPU overhead and quantization cost for mixture-of-experts (MoE) workloads. The implementation adds a high-performance CUDA kernel using TMA (Tensor Memory Accelerator) and tensor cores.

Key Changes

  • New grouped Hadamard kernel (multi_hadamard_transform.cu): Implements MultiHadamardAmaxTmaKernel with ping-pong buffering, tensor core operations, and multi-tensor amax reduction
  • Extended Python API (cast.cpp): Added multi_tensor_quantize_nvfp4_impl function supporting grouped quantization with stochastic rounding
  • Padding alignment change: Bumped the NVFP4 M-dimension padding requirement from 32 to 64 elements for optimal kernel performance
  • Utility function: Added get_align_size_for_quantization() to centralize alignment logic across the codebase
  • Stochastic rounding fix: Recent commit (c6c11ef) properly implements stochastic rounding support for grouped operations
  • Comprehensive tests: Added test coverage for grouped quantization with various edge cases (zero tokens, uneven splits, stochastic rounding)

Important Constraints

  • Hardware requirement: Kernel requires SM 10.0+ (Blackwell architecture)
  • Alignment requirement: Split sections must be multiples of 64 (a breaking change from 32)
  • Limitations: Pre-RHT amax and non-RHT modes not supported
  • MegatronLM coordination: PR description notes MegatronLM needs to update padding from 32 to 64
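Taken together, these rules amount to a small eligibility check before choosing the fused path. The sketch below only restates the constraints listed above; the function name and sm_major parameter are hypothetical:

import torch

def fused_nvfp4_path_eligible(split_sections, num_cols, dtype, sm_major):
    # Hypothetical guard mirroring the constraints in this review.
    return (
        sm_major >= 10                                # Blackwell (SM 10.0+) only
        and dtype == torch.bfloat16                   # RHT path requires bf16 input
        and num_cols % 128 == 0                       # column count multiple of 128
        and all(s % 64 == 0 for s in split_sections)  # split sections multiple of 64
    )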

Architecture

The implementation uses a sophisticated multi-stage pipeline:

  1. Zero-initialize amax buffers for all tensors
  2. Launch grouped Hadamard transform kernel with TMA loads and tensor core MMA operations
  3. Compute both pre-RHT amax (rowwise) and post-RHT amax (columnwise transpose)
  4. Perform rowwise and/or columnwise quantization using populated amax values

Confidence Score: 3/5

  • This PR requires coordination with downstream projects and has pending TODO items before it should be merged
  • The implementation is technically sound with good test coverage, and stochastic rounding is properly implemented. However, the PR is explicitly marked as DRAFT with several unfinished TODO items: (1) code cleanup and more testing needed, (2) stochastic rounding was only just fixed in the latest commit, (3) MegatronLM needs to update its padding from 32 to 64. The breaking change in alignment requirements (32→64) needs coordination with downstream users, and the SM 10.0+ hardware requirement is a significant constraint that should be clearly communicated.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp needs verification that stochastic rounding fix is complete. Downstream projects (especially MegatronLM) need to update padding requirements before this can merge.

File Analysis

  • transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu (4/5): New grouped Hadamard transform kernel implementation with TMA. Uses tensor cores and advanced PTX instructions. The complex ping-pong buffering and amax reduction logic appears correct. Main concern: requires SM 10.0+ (Blackwell), which limits portability.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp (3/5): Adds the multi_tensor_quantize_nvfp4_impl function with stochastic rounding support. Stochastic rounding was recently fixed in commit c6c11ef. RNG state generation uses a per-tensor loop (a TODO notes that a bulk API would be better). Validates that split_sections are multiples of 64.
  • tests/pytorch/nvfp4/test_nvfp4_group_quantize.py (4/5): Comprehensive test coverage for grouped quantization, including edge cases (zero tokens, uneven splits). Tests validate against a reference implementation with exact equality checks. Good test structure.
  • tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py (4/5): New test file for stochastic rounding validation in grouped quantization. Includes reference implementations for RHT and FP4 dequantization. Tests both single and grouped quantization paths.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant PyAPI as tex.split_quantize
    participant Cast as cast.cpp
    participant MultiImpl as multi_tensor_quantize_nvfp4_impl
    participant Hadamard as nvte_multi_hadamard_transform_amax
    participant Kernel as MultiHadamardAmaxTmaKernel
    participant Quantize as nvte_quantize_v2 / RHT fusion

    User->>PyAPI: tex.split_quantize(input, split_sections, quantizers)
    PyAPI->>Cast: split_quantize (C++ extension)
    Cast->>Cast: Split input tensor along dim 0
    Cast->>Cast: Allocate bulk NVFP4 output tensors
    Cast->>MultiImpl: multi_tensor_quantize_nvfp4_impl()
    
    alt Stochastic Rounding Enabled
        MultiImpl->>MultiImpl: Generate RNG states for each tensor
        MultiImpl->>MultiImpl: Configure quant_config with RNG
    end
    
    MultiImpl->>Hadamard: nvte_multi_hadamard_transform_amax(input, outputs, split_sections)
    Note over Hadamard: Compute pre-RHT amax & columnwise RHT amax
    Hadamard->>Kernel: MultiZeroAmaxKernel (initialize amax buffers)
    Hadamard->>Kernel: MultiHadamardAmaxTmaKernel (SM 10.0+)
    
    Note over Kernel: TMA load with ping-pong buffering
    Note over Kernel: Tensor core MMA (16x16x16)
    Note over Kernel: Compute amax for pre-RHT and RHT(transpose)
    Kernel-->>Hadamard: Amax values written to output tensors
    
    Hadamard-->>MultiImpl: Return with populated amax buffers
    
    loop For each tensor split
        alt Rowwise Quantization
            MultiImpl->>Quantize: nvte_quantize_v2 (rowwise)
        end
        
        alt Columnwise Quantization with RHT
            MultiImpl->>Quantize: nvte_hadamard_transform_cast_fusion_columnwise
            Note over Quantize: RHT transform + quantize in single kernel
        end
    end
    
    MultiImpl-->>Cast: Quantized outputs
    Cast-->>PyAPI: NVFP4Tensor list
    PyAPI-->>User: Return quantized tensors

10 files reviewed, 1 comment


Comment on lines 161 to 179

if (quantizer->stochastic_rounding) {
  // TODO(zhongbo): remove the for loop of generating rng states with a single call
  // with rng_elts_per_thread = 1024 * num_tensors
  // Change to the bulk generate rng states api when grouped quantize is available
  const size_t rng_elts_per_thread = 1024;  // Wild guess, probably can be tightened
  auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
  rng_states_tensor = torch::empty({2 * num_tensors}, opts);

  for (size_t i = 0; i < num_tensors; ++i) {
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
    at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
    int64_t *rng_state_ptr = static_cast<int64_t *>(rng_states_tensor.data_ptr()) + i * 2;
    philox_unpack(philox_args, rng_state_ptr);
    te_rng_state_list.push_back(makeTransformerEngineTensor(
        static_cast<void *>(rng_state_ptr), std::vector<size_t>{2}, DType::kInt64));
    quant_config_list[i].set_rng_state(te_rng_state_list[i].data());
    quant_config_list[i].set_stochastic_rounding(true);
  }
style: Consider optimizing RNG state generation by using bulk API instead of per-tensor loop (as noted in TODO). This would improve performance for multi-tensor quantization.

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@greptile-apps bot left a comment

Greptile Summary

This PR adds a utility function get_align_size_for_quantization() to centralize the logic for determining alignment sizes based on quantization recipes, replacing duplicated inline logic across the codebase.

Key changes:

  • Added get_align_size_for_quantization() function to return alignment size: 32 for MXFP8, 64 for NVFP4, 16 for other recipes
  • Important: NVFP4 alignment padding increased from 32 to 64 (as mentioned in PR description and related MegatronLM PRs)
  • Updated Fp8Padding and Fp8Unpadding modules to use the new function
  • Updated test files to use the centralized function instead of hardcoded values

This refactoring improves maintainability by centralizing the alignment logic in one place.
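The helper itself is small; here is a sketch consistent with the behavior described above (the actual implementation in transformer_engine/pytorch/quantization.py may differ in detail, and the import path is an assumption):

from transformer_engine.common.recipe import Recipe  # assumed import path

def get_align_size_for_quantization(recipe: Recipe) -> int:
    # Row-alignment (padding) size implied by the quantization recipe.
    if recipe.mxfp8():
        return 32
    if recipe.nvfp4():
        return 64  # bumped from 32 by this PR
    return 16  # FP8 and other recipes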

Confidence Score: 5/5

  • This PR is safe to merge - it's a straightforward refactoring that centralizes alignment size logic
  • The changes are minimal and low-risk: adds a simple utility function that replaces duplicated inline logic with centralized code. The function has clear control flow with no complex edge cases. The NVFP4 padding change from 32 to 64 is intentional and documented in the PR description with corresponding MegatronLM changes.
  • No files require special attention

File Analysis

  • transformer_engine/pytorch/quantization.py (5/5): Added utility function get_align_size_for_quantization to centralize alignment size logic (32 for MXFP8, 64 for NVFP4, 16 default). NVFP4 padding increased from 32 to 64.

Sequence Diagram

sequenceDiagram
    participant User
    participant Fp8Padding/Fp8Unpadding
    participant FP8GlobalStateManager
    participant get_align_size_for_quantization
    participant Recipe
    
    User->>Fp8Padding/Fp8Unpadding: forward(inp, m_splits)
    
    alt align_size is None
        Fp8Padding/Fp8Unpadding->>FP8GlobalStateManager: get_fp8_recipe()
        FP8GlobalStateManager-->>Fp8Padding/Fp8Unpadding: recipe
        Fp8Padding/Fp8Unpadding->>get_align_size_for_quantization: get_align_size_for_quantization(recipe)
        get_align_size_for_quantization->>Recipe: recipe.mxfp8()
        Recipe-->>get_align_size_for_quantization: bool
        
        alt recipe.mxfp8() == True
            get_align_size_for_quantization-->>Fp8Padding/Fp8Unpadding: 32
        else
            get_align_size_for_quantization->>Recipe: recipe.nvfp4()
            Recipe-->>get_align_size_for_quantization: bool
            
            alt recipe.nvfp4() == True
                get_align_size_for_quantization-->>Fp8Padding/Fp8Unpadding: 64
            else
                get_align_size_for_quantization-->>Fp8Padding/Fp8Unpadding: 16
            end
        end
    end
    
    Fp8Padding/Fp8Unpadding->>Fp8Padding/Fp8Unpadding: Calculate padded_m_splits
    Fp8Padding/Fp8Unpadding-->>User: output

1 file reviewed, 1 comment


return get_default_fp8_recipe()


def get_align_size_for_quantization(recipe: Recipe):
style: Missing return type annotation - should specify -> int

Suggested change
def get_align_size_for_quantization(recipe: Recipe):
def get_align_size_for_quantization(recipe: Recipe) -> int:

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@greptile-apps bot left a comment

Greptile Summary

Implements grouped NVFP4 quantization with Hadamard transform and amax computation to reduce CPU overhead and quantization cost for mixture-of-experts (MoE) workloads.

Key Changes

  • Added multi_tensor_quantize_nvfp4_impl function (~190 lines) that handles batched NVFP4 quantization with RHT
  • Extended multi_tensor_quantize_impl to route NVFP4 quantizers to the new fused kernel path
  • Implemented stochastic rounding support for multi-tensor quantization with per-tensor RNG state generation
  • Fixed critical bug: changed amax tensor dtype from kUInt8 to kFloat32 in bulk allocation functions (lines 830, 885)
  • Added validation checks: columns must be multiple of 128, split_sections must be multiple of 64
  • Modified function signatures to pass full input tensor alongside split tensors for grouped amax computation

Implementation Details

  • Uses nvte_multi_hadamard_transform_amax for efficient batched amax computation across row/column-wise modes
  • Supports both rowwise and columnwise quantization with RHT fusion via nvte_hadamard_transform_cast_fusion_columnwise
  • Requires bfloat16 input for RHT path (validation at line 188)
  • Currently only supports post-RHT amax (pre-RHT amax and no-RHT paths intentionally unimplemented)

Issues Found

  • Missing null check after dynamic_cast at line 372 - could cause segfault if cast fails
  • Assumes all quantizers have identical stochastic rounding settings without verification

Confidence Score: 3/5

  • This PR has a critical null pointer dereference risk that must be addressed before merging
  • Score reflects one critical bug (missing null check on dynamic_cast) that could cause runtime crashes, plus minor assumption issues. The core implementation appears sound with proper validation checks, stochastic rounding support is correctly implemented, and the amax dtype bug was already fixed. However, the missing null check is a blocking issue.
  • Pay close attention to line 372 in cast.cpp where the dynamic_cast result must be null-checked before use

File Analysis

  • transformer_engine/pytorch/csrc/extensions/cast.cpp (4/5): Adds the NVFP4 grouped Hadamard transform with amax kernel. Includes stochastic rounding support, bulk RNG state generation, multi-tensor amax computation, and fused RHT+quantization. Fixed amax dtype from kUInt8 to kFloat32.

Sequence Diagram

sequenceDiagram
    participant Client
    participant split_quantize
    participant multi_tensor_quantize_impl
    participant multi_tensor_quantize_nvfp4_impl
    participant CUDA_Kernels

    Client->>split_quantize: Input tensor + split_sections + quantizers
    split_quantize->>split_quantize: Split input tensor along dim 0
    split_quantize->>split_quantize: Create TensorWrapper for full input
    split_quantize->>split_quantize: Bulk allocate NVFP4 output tensors
    split_quantize->>multi_tensor_quantize_impl: input_wrapper + input_list + split_sections
    
    multi_tensor_quantize_impl->>multi_tensor_quantize_impl: Check scaling mode consistency
    multi_tensor_quantize_impl->>multi_tensor_quantize_impl: Verify split_sections % 64 == 0
    multi_tensor_quantize_impl->>multi_tensor_quantize_nvfp4_impl: Route to NVFP4 fused kernel
    
    multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Validate inputs (cols % 128, no 2D, no amax reduction)
    multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Create QuantizationConfigWrapper list
    
    alt Stochastic Rounding Enabled
        multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Allocate RNG states tensor [2*num_tensors]
        loop For each tensor
            multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Generate Philox RNG state
            multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Set RNG state in quant_config_list[i]
        end
    end
    
    alt RHT Enabled
        multi_tensor_quantize_nvfp4_impl->>CUDA_Kernels: nvte_multi_hadamard_transform_amax
        CUDA_Kernels-->>multi_tensor_quantize_nvfp4_impl: Compute rowwise & columnwise amax
        
        loop For each tensor
            alt Rowwise quantization
                multi_tensor_quantize_nvfp4_impl->>CUDA_Kernels: nvte_quantize_v2 (identity)
            end
            alt Columnwise quantization
                multi_tensor_quantize_nvfp4_impl->>CUDA_Kernels: nvte_hadamard_transform_cast_fusion_columnwise
            end
        end
    else RHT Disabled
        multi_tensor_quantize_nvfp4_impl->>multi_tensor_quantize_nvfp4_impl: Error: multi-amax without RHT not supported
    end
    
    multi_tensor_quantize_nvfp4_impl-->>Client: Return quantized tensors

1 file reviewed, 2 comments


  break;
}
case NVTE_NVFP4_1D_SCALING: {
  auto nvfp4_quantizer = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp_list[0].get());
logic: dynamic_cast can return nullptr if cast fails - needs null check before dereferencing

Suggested change
auto nvfp4_quantizer = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp_list[0].get());
auto nvfp4_quantizer = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp_list[0].get());
NVTE_CHECK(nvfp4_quantizer != nullptr, "Failed to cast to NVFP4Quantizer");

at::Tensor rng_states_tensor;

// assumes one quantizer doing RS means all quantizers doing RS
if (quantizer->stochastic_rounding) {
style: Assumes all quantizers have same stochastic_rounding setting - verify quantizers are consistent or check each individually

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@greptile-apps bot left a comment

Greptile Summary

This PR implements grouped NVFP4 Hadamard transform kernels to reduce CPU overhead for MoE quantization. The implementation refactors shared utility functions into hadamard_transform_utils.cuh and adds a new multi_hadamard_transform.cu file with a TMA-based kernel that processes multiple tensor splits in a single launch.

Key changes:

  • Refactored Hadamard utility functions into shared header for code reuse
  • Added MultiHadamardAmaxTmaKernel that computes amax values across multiple tensor splits using prefix-sum based indexing
  • Extended multi_tensor_quantize_impl to support NVFP4 quantizers with 64-multiple split_sections requirement
  • Fixed amax buffer dtypes from torch::kUInt8 to torch::kFloat32 in bulk allocation

Critical issues found:

  • Overflow check happens after addition, allowing signed integer overflow (UB)
  • Missing bounds check for tensor_id lookup in kernel - can access out-of-bounds memory if global_offset_y exceeds total range

Note: PR is marked as DRAFT and author lists cleanup/testing as pending action items.

Confidence Score: 2/5

  • Not safe to merge - contains critical memory safety and integer overflow bugs
  • Score of 2 due to two critical logic errors: (1) overflow check after addition allows undefined behavior, (2) missing bounds check in kernel can cause out-of-bounds memory access. These must be fixed before merging. Refactoring and new kernel implementation are sound, but safety issues prevent higher score.
  • Pay close attention to transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu - contains both critical bugs that need immediate fixing

File Analysis

  • transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu (2/5): New grouped NVFP4 Hadamard kernel with TMA; critical overflow and bounds-check issues found that need fixing before merge.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp (3/5): Added NVFP4 multi-tensor quantization support; complex logic with proper validation, but the amax dtype fix needs verification.

Sequence Diagram

sequenceDiagram
    participant User
    participant PyTorch as split_quantize()
    participant Cast as multi_tensor_quantize_impl()
    participant NVFP4 as multi_tensor_quantize_nvfp4_impl()
    participant Hadamard as multi_hadamard_transform_amax()
    participant Kernel as MultiHadamardAmaxTmaKernel
    participant Quantize as nvte_quantize_v2()

    User->>PyTorch: split_quantize(tensor, split_sections, quantizers)
    PyTorch->>PyTorch: Create input_list by splitting tensor
    PyTorch->>Cast: multi_tensor_quantize_impl(single_input, input_list, quantizers, split_sections)
    
    Cast->>Cast: Check if all quantizers are NVFP4
    Cast->>Cast: Validate split_sections (must be 64 multiple)
    
    alt Fused NVFP4 Kernel Path
        Cast->>NVFP4: multi_tensor_quantize_nvfp4_impl()
        NVFP4->>NVFP4: Generate RNG states for stochastic rounding
        NVFP4->>Hadamard: nvte_multi_hadamard_transform_amax()
        Hadamard->>Hadamard: Build MultiAmaxArgs with split_sections_range
        Hadamard->>Kernel: Launch MultiZeroAmaxKernel (zero out amax buffers)
        Hadamard->>Kernel: Launch MultiHadamardAmaxTmaKernel
        Kernel->>Kernel: For each block, find tensor_id from global_offset_y
        Kernel->>Kernel: Compute RHT and calculate amax values
        Kernel->>Kernel: Atomic updates to output amax buffers
        Kernel-->>NVFP4: Return (amax populated)
        
        loop For each tensor
            NVFP4->>Quantize: nvte_quantize_v2() for rowwise (if enabled)
            NVFP4->>Quantize: nvte_hadamard_transform_cast_fusion_columnwise() for columnwise
        end
        NVFP4-->>Cast: Return quantized tensors
    else Individual Quantization Path
        loop For each tensor
            Cast->>Quantize: quantizer.quantize(input, output)
        end
    end
    
    Cast-->>PyTorch: Return output_list
    PyTorch-->>User: Return quantized tensors

4 files reviewed, 2 comments


Comment on lines +416 to +420

kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
    kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
// check overflow
NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0,
           "split_sections_range overflow the int32_t");
logic: overflow check only detects sign flip, not actual overflow before it happens - if sum is near INT_MAX, addition can overflow before the check

Suggested change
kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
    kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
// check overflow
NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0,
           "split_sections_range overflow the int32_t");
// check overflow before addition
NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors] <=
               INT_MAX - split_sections[i],
           "split_sections_range would overflow int32_t");
kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
    kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];

Comment on lines +214 to +219

// calculate the global offset in Y direction to access the correct amax buffer
int global_offset_y = blockIdx.y * CHUNK_DIM_Y;
int tensor_id = 0;
while (args.split_sections_range[tensor_id + 1] <= global_offset_y) {
  ++tensor_id;
}
logic: tensor_id could go out of bounds if global_offset_y >= split_sections_range[num_tensors] - while loop doesn't check upper bound

Suggested change
// calculate the global offset in Y direction to access the correct amax buffer
int global_offset_y = blockIdx.y * CHUNK_DIM_Y;
int tensor_id = 0;
while (args.split_sections_range[tensor_id + 1] <= global_offset_y) {
  ++tensor_id;
}
// calculate the global offset in Y direction to access the correct amax buffer
int global_offset_y = blockIdx.y * CHUNK_DIM_Y;
int tensor_id = 0;
while (tensor_id < args.num_tensors &&
       args.split_sections_range[tensor_id + 1] <= global_offset_y) {
  ++tensor_id;
}
// Bounds check after loop
if (tensor_id >= args.num_tensors) {
  return;  // or handle error
}

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@greptile-apps bot left a comment

Greptile Summary

This PR implements a grouped Hadamard amax kernel for NVFP4 quantization to reduce CPU overhead in MoE (Mixture of Experts) workloads. The implementation enables multi-tensor processing, where a single input tensor is split into multiple sections and amax (absolute maximum) values are computed for each section with an optional Hadamard transform.

Key Changes:

  • Added multi_hadamard_transform.cu implementing MultiHadamardAmaxTmaKernel for computing amax across multiple tensor splits
  • Extended cast.cpp with multi_tensor_quantize_nvfp4_impl to orchestrate grouped quantization with stochastic rounding support
  • Added comprehensive test suite test_nvfp4_group_quantize.py covering edge cases (zero tokens, uneven splits)
  • New API functions nvte_multi_hadamard_transform_amax and nvte_multi_tensor_amax for grouped operations
  • Enabled optional RHT (Randomized Hadamard Transform) disabling to compute only amax values

Implementation Approach:
The kernel treats multiple logical tensors as a single contiguous input, using split_sections_range (prefix sum) to map block coordinates to the correct tensor ID and amax buffer. This allows efficient GPU utilization while maintaining per-tensor quantization semantics.
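The block-to-tensor mapping is a plain prefix-sum lookup. A host-side Python illustration of the device-side logic (names mirror the kernel arguments quoted earlier in this conversation):

from itertools import accumulate

split_sections = [64, 128, 0, 64]
# split_sections_range[i] is the first global row of tensor i;
# split_sections_range[-1] is the total number of rows.
split_sections_range = [0] + list(accumulate(split_sections))

def tensor_id_for_row(global_offset_y: int) -> int:
    # Linear scan, as the kernel does per block; like the kernel, this
    # assumes global_offset_y is below the total row count.
    tensor_id = 0
    while split_sections_range[tensor_id + 1] <= global_offset_y:
        tensor_id += 1
    return tensor_id

assert tensor_id_for_row(63) == 0
assert tensor_id_for_row(64) == 1
assert tensor_id_for_row(192) == 3  # the zero-row tensor 2 owns no rows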

Confidence Score: 4/5

  • Safe to merge with minor bounds check improvement recommended
  • The implementation is well-structured with comprehensive tests covering edge cases. The tensor_id calculation has a potential out-of-bounds issue flagged in previous comments that should be verified. Stochastic rounding works correctly but could benefit from the bulk API optimization noted in TODOs. All action items from the PR description are marked complete.
  • Pay attention to multi_hadamard_transform.cu line 236-241 for the tensor_id bounds check issue

File Analysis

  • transformer_engine/common/hadamard_transform/multi_hadamard_transform.cu (4/5): New CUDA kernel for multi-tensor Hadamard transform with amax reduction. Implements efficient grouped quantization but has a minor bounds-check issue in the tensor_id calculation.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp (4/5): Implements multi-tensor NVFP4 quantization with stochastic rounding support. The per-tensor stochastic rounding loop could be optimized to the bulk API, as noted in a TODO.

Sequence Diagram

sequenceDiagram
    participant PyTorch as PyTorch Layer
    participant CPP as cast.cpp
    participant MultiAmax as multi_hadamard_transform_amax
    participant CUDA as MultiHadamardAmaxTmaKernel
    participant Quantize as nvte_quantize_v2/cast_fusion

    PyTorch->>CPP: split_quantize(input, split_sections, quantizers)
    CPP->>CPP: Validate inputs & create split_sections
    CPP->>CPP: Setup RNG states for stochastic rounding
    
    alt with_rht enabled
        CPP->>MultiAmax: nvte_multi_hadamard_transform_amax(input, outputs, split_sections)
        MultiAmax->>MultiAmax: Build MultiAmaxArgs with split_sections_range
        MultiAmax->>CUDA: MultiZeroAmaxKernel (initialize amax buffers)
        MultiAmax->>CUDA: MultiHadamardAmaxTmaKernel (compute amax per split)
        CUDA->>CUDA: Calculate tensor_id from blockIdx.y offset
        CUDA->>CUDA: Compute RHT and amax for each split section
        CUDA-->>MultiAmax: Populated amax buffers
    else without_rht
        CPP->>MultiAmax: nvte_multi_tensor_amax(input, outputs, split_sections)
        MultiAmax->>CUDA: Calculate amax without RHT
        MultiAmax->>CUDA: Copy amax to both rowwise and columnwise
    end
    
    CPP->>CPP: Loop through each split
    
    alt rowwise_usage enabled
        CPP->>Quantize: nvte_quantize_v2(input_chunk, output, config)
        Quantize-->>CPP: Rowwise quantized data
    end
    
    alt columnwise_usage enabled
        CPP->>Quantize: nvte_hadamard_transform_cast_fusion_columnwise
        Quantize-->>CPP: Columnwise quantized data
    end
    
    CPP-->>PyTorch: List of quantized outputs

4 files reviewed, no comments


Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@greptile-apps bot left a comment

Greptile Summary

This PR implements grouped NVFP4 Hadamard amax kernels to reduce CPU overhead and quantization cost for MoE workloads.

Key Changes:

  • Added multi_tensor_quantize_nvfp4_impl in cast.cpp to handle grouped quantization of multiple tensors with a single kernel call
  • Implemented nvte_multi_hadamard_transform_amax and nvte_multi_tensor_amax kernels in new multi_hadamard_transform.cu file
  • Refactored Hadamard transform utilities into hadamard_transform_utils.cuh for code reuse
  • Added comprehensive test suite in test_nvfp4_group_quantize.py covering edge cases (zero tokens, uneven splits)
  • Updated alignment requirements from 32 to 64 for NVFP4 quantization via get_align_size_for_quantization helper
  • Added stochastic rounding support for grouped quantization (with per-tensor RNG state generation)

Technical Implementation:
The grouped kernel batches multiple tensor quantizations by:

  1. Computing amax values for all tensors in a single kernel launch using split_sections_range to map blocks to tensors
  2. Performing rowwise/columnwise quantization with optional RHT (Random Hadamard Transform) fusion
  3. Supporting both with-RHT (post-RHT amax) and without-RHT (pre-RHT amax copy) modes

Testing:
Extensive parametric tests validate correctness against per-tensor reference implementations across multiple dimensions, quantization modes, and edge cases.
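In spirit, each test case compares the grouped path against the single-tensor path chunk by chunk, something like the sketch below (the quantize/dequantize accessor names are assumptions; the real checks live in test_nvfp4_group_quantize.py):

import torch

def check_group_quantize(inp, split_sections, quantizers, tex):
    # Grouped path under test.
    grouped = tex.split_quantize(inp, split_sections, quantizers)

    # Reference: quantize each chunk independently.
    offset = 0
    for section, q, out in zip(split_sections, quantizers, grouped):
        chunk = inp[offset:offset + section]
        offset += section
        if section == 0:
            continue  # zero-token experts produce empty outputs
        ref = q.quantize(chunk)  # assumed single-tensor API
        # Exact equality, as described above.
        torch.testing.assert_close(out.dequantize(), ref.dequantize(),
                                   rtol=0, atol=0)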

Confidence Score: 4/5

  • This PR is safe to merge with minor review of previously flagged items
  • The implementation is well-tested with comprehensive test coverage including edge cases. Previous reviews have already identified optimization opportunities (stochastic rounding bulk API, dynamic_cast null checks) and potential bounds checking improvements. The core logic is sound - grouped kernel architecture properly handles tensor mapping via split_sections_range, and tests validate correctness. Confidence is 4/5 rather than 5/5 due to previously flagged issues that should be acknowledged before merge
  • transformer_engine/pytorch/csrc/extensions/cast.cpp - review previously flagged dynamic_cast and stochastic rounding comments

File Analysis

  • tests/pytorch/nvfp4/test_nvfp4_group_quantize.py (5/5): New comprehensive test file for grouped NVFP4 quantization; validates correctness against a reference implementation for various edge cases, including zero tokens.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp (4/5): Adds multi_tensor_quantize_nvfp4_impl for grouped NVFP4 quantization with RHT support. Previous comments address stochastic rounding TODOs and dynamic_cast safety checks.

Sequence Diagram

sequenceDiagram
    participant User as Python/PyTorch
    participant SQ as tex.split_quantize
    participant MTQ as multi_tensor_quantize_impl
    participant NVFP4 as multi_tensor_quantize_nvfp4_impl
    participant Amax as Amax Computation Kernel
    participant Quant as Quantization Kernel
    
    User->>SQ: split_quantize(input, split_sections, quantizers)
    SQ->>MTQ: Call with input_list & quantizers
    MTQ->>NVFP4: Route to NVFP4 implementation
    
    Note over NVFP4: Setup RNG states for stochastic rounding (per-tensor)
    
    alt With RHT enabled
        NVFP4->>Amax: nvte_multi_hadamard_transform_amax
        Note over Amax: Compute rowwise amax & columnwise amax(RHT(input.T))
        Amax-->>NVFP4: Return amax values
        
        loop For each tensor split
            NVFP4->>Quant: nvte_quantize_v2 (rowwise)
            NVFP4->>Quant: nvte_hadamard_transform_cast_fusion_columnwise
            Quant-->>NVFP4: Quantized data + scales
        end
    else Without RHT
        NVFP4->>Amax: nvte_multi_tensor_amax
        Note over Amax: Compute amax for all splits, copy to columnwise
        Amax-->>NVFP4: Return amax values
        
        loop For each tensor split
            NVFP4->>Quant: nvte_quantize_v2 (rowwise & columnwise)
            Quant-->>NVFP4: Quantized data + scales
        end
    end
    
    NVFP4-->>User: List of quantized outputs (data, scales, amax)

2 files reviewed, no comments
