
Conversation

@Oleg-Goncharov (Collaborator) commented Oct 31, 2025

Description

This PR adds a persistent gated MXFP8 kernel optimized for rowwise scaling, SwiGLU activation (FWD and BWD), and BF16/FP16 input tensors. The kernel uses the "Cluster Launch Control" feature of the Blackwell architecture (sm100f family) to enable persistence with dynamic scheduling.
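For orientation, a persistent scheduling loop built on Cluster Launch Control has roughly the following shape. This is a hedged sketch: the helper names mirror the wrappers this PR adds to ptx.cuh, but their signatures and the barrier handling here are assumptions for illustration, not the PR's actual code.

```cuda
#include <cstdint>

// Hypothetical stand-ins for the PTX wrappers added in this PR (names echo
// the PR; signatures are assumed for the sketch).
__device__ void clc_try_cancel_async(uint64_t *response, uint64_t *mbar);
__device__ void mbarrier_wait(uint64_t *mbar);
__device__ int2 clc_get_cancelled_cta_2d_id(const uint64_t *response);
__device__ void process_chunk(int cta_x, int cta_y);  // TMA load -> SwiGLU -> MXFP8 store

// Each CTA starts with its own launch coordinates, then repeatedly "steals"
// the coordinates of cancelled, not-yet-launched CTAs until the hardware
// scheduler has none left to hand out.
__global__ void persistent_kernel_sketch() {
  __shared__ alignas(16) uint64_t clc_response[2];  // 16 B try_cancel response
  __shared__ uint64_t mbar;                         // tracks the async CLC query

  int2 cta = make_int2(blockIdx.x, blockIdx.y);
  while (true) {
    process_chunk(cta.x, cta.y);
    clc_try_cancel_async(clc_response, &mbar);  // request the next work unit
    mbarrier_wait(&mbar);                       // response has landed in SMEM
    cta = clc_get_cancelled_cta_2d_id(clc_response);
    if (cta.x < 0) break;  // no more work to claim
  }
}
```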

Performance measured on the PreNYX cluster

[Chart: Optimized Gated MXFP8 SwiGLU kernel throughput]

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a specialized kernel
  • Added the logic to use it when the conditions are met

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Summary by CodeRabbit

Release Notes

  • Performance Improvements

    • Enhanced MXFP8 quantization with specialized processing for BF16/FP16 inputs
    • Optimized memory alignment calculations and transfer mechanisms for tensor operations
  • Refactor

    • Updated internal kernel synchronization and memory management utilities to support advanced synchronization patterns

@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR introduces a specialized persistent CUDA kernel optimized for MXFP8 rowwise quantization of gated SwiGLU activations on Blackwell architecture (sm100f family). The implementation achieves significant performance improvements by leveraging architecture-specific features.

Key Changes

  • New specialized kernel (gated_mxfp8_rowwise_swiglu.cuh): 785-line implementation with hand-tuned PTX assembly for BF16/FP16 inputs
  • Persistent scheduling: Uses Blackwell's Cluster Launch Control for dynamic work distribution across CTAs
  • Optimized compute path: Fused SwiGLU backward/forward computation with inline PTX for mul.f32x2, max.xorsign.abs, and vectorized conversions
  • Memory optimization: Double-buffered TMA operations with prefetching (1 stage ahead) to overlap compute and memory transfers (a conceptual skeleton follows this list)
  • Infrastructure additions: New PTX helpers for cluster control, 2D CTA cancellation queries, and shared memory alignment
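The prefetching pattern mentioned above follows the usual stage-ahead shape. A conceptual skeleton, assuming hypothetical issue_tma_load / wait_ready / compute_and_store helpers (the kernel's actual mbarrier choreography is more involved):

```cuda
// Two shared-memory buffers: while stage s computes out of buf[s % 2], the
// TMA engine is already filling buf[(s + 1) % 2] for the next stage.
template <typename Buf>
__device__ void pipeline_sketch(Buf (&buf)[2], int num_stages) {
  issue_tma_load(buf[0], /*stage=*/0);          // prime the pipeline
  for (int s = 0; s < num_stages; ++s) {
    if (s + 1 < num_stages) {
      issue_tma_load(buf[(s + 1) & 1], s + 1);  // prefetch one stage ahead
    }
    wait_ready(buf[s & 1]);                     // barrier on the current buffer
    compute_and_store(buf[s & 1], s);           // SwiGLU + MXFP8 quantize + store
  }
}
```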

Dispatch Logic

The new kernel activates when ALL conditions are met:

  1. Activation function is SwiGLU (silu, not the clamped variant)
  2. Input dtype is BF16 or FP16
  3. Scaling mode is rowwise only

Otherwise, the dispatcher falls back to the existing generic kernel, which supports bidimensional and columnwise scaling.
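For reference, the math behind the fused SwiGLU path is standard; the hand-tuned PTX computes the equivalent of the following scalar reference (a sketch, not the PR's code):

```cuda
#include <cmath>

// FWD: out = silu(x) * g, with silu(x) = x * sigmoid(x).
// BWD: given the upstream gradient dy,
//   d_x = dy * g * dsilu(x)   and   d_g = dy * silu(x),
// where dsilu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
__device__ inline float silu_ref(float x) { return x / (1.0f + expf(-x)); }

__device__ inline float dsilu_ref(float x) {
  const float s = 1.0f / (1.0f + expf(-x));
  return s * (1.0f + x * (1.0f - s));
}
```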

Confidence Score: 4/5

  • Safe to merge with minor considerations for architecture-specific validation
  • The implementation is well-structured with proper architecture guards, but lacks unit tests. The kernel is correctly gated behind multiple condition checks and falls back gracefully. PTX assembly appears sound with proper register management and no obvious memory safety issues. Main risk is architecture-specific behavior on Blackwell hardware requiring validation.
  • Pay attention to gated_mxfp8_rowwise_swiglu.cuh for runtime validation on actual Blackwell hardware, especially the persistent scheduling and cluster launch control paths

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/common/cast/mxfp8/specialized/gated_mxfp8_rowwise_swiglu.cuh | 4/5 | New 785-line specialized kernel for MXFP8 rowwise SwiGLU with Blackwell persistent scheduling. Implements optimized PTX assembly for BF16/FP16 forward/backward passes using cluster launch control and TMA operations. |
| transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | 5/5 | Added dispatch logic (lines 700-710) to route SwiGLU rowwise BF16/FP16 operations to the optimized kernel when conditions match (non-clamped SwiGLU, rowwise scaling). |
| transformer_engine/common/util/ptx.cuh | 5/5 | Added PTX infrastructure for Blackwell persistent kernels: cluster launch control functions, 2D CTA cancellation queries, shared memory alignment helper, and optimized BF16/FP16 load-convert functions. |
| transformer_engine/common/cast/core/common.cuh | 5/5 | Refactored TMA_bytes constant to use TMA_GMEM_ALIGNMENT (line 33) and added align_smem_ptr_per_TMA_requirements device function (lines 37-42) for dynamic shared memory alignment. |
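The align_smem_ptr_per_TMA_requirements helper mentioned for common.cuh likely amounts to a round-up to TMA's 128-byte shared-memory alignment; a minimal sketch under that assumption (the PR's actual implementation may differ):

```cuda
#include <cstdint>

constexpr uintptr_t TMA_SHMEM_ALIGNMENT = 128;  // TMA requires 128 B-aligned SMEM

// Round a shared-memory pointer up to the next TMA-compatible boundary.
__device__ inline uint8_t *align_smem_ptr_sketch(uint8_t *p) {
  uintptr_t v = reinterpret_cast<uintptr_t>(p);
  v = (v + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
  return reinterpret_cast<uint8_t *>(v);
}
```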

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Dispatcher
    participant OptKernel
    participant GenKernel
    participant TMA
    participant ClusterCtrl

    User->>Dispatcher: Call quantize gated
    Dispatcher->>Dispatcher: Check conditions
    
    alt Optimized path
        Dispatcher->>OptKernel: Launch rowwise kernel
        OptKernel->>TMA: Prefetch initial data
        
        loop Process chunks
            OptKernel->>ClusterCtrl: Request next work
            ClusterCtrl-->>OptKernel: Assign CTA coords
            
            loop Each stage
                OptKernel->>TMA: Async prefetch
                OptKernel->>OptKernel: Wait on barrier
                OptKernel->>OptKernel: Compute activation
                OptKernel->>OptKernel: Quantize to MXFP8
                OptKernel->>TMA: Write results
            end
        end
        
        OptKernel-->>User: Return output
        
    else Generic path
        Dispatcher->>GenKernel: Launch generic kernel
        GenKernel->>TMA: Standard processing
        GenKernel-->>User: Return output
    end
```

4 files reviewed, no comments


@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR adds a specialized persistent CUDA kernel optimized for gated MXFP8 quantization with rowwise scaling and SwiGLU activation on Blackwell architecture (sm_100+). The implementation leverages Cluster Launch Control for dynamic scheduling across CTAs.

Key changes:

  • New persistent kernel in gated_mxfp8_rowwise_swiglu.cuh using TMA bulk transfers and mbarrier synchronization
  • PTX wrappers for Blackwell-specific cluster launch control primitives and relaxed/acquire memory ordering variants
  • Dispatch logic routes BF16/FP16 inputs with rowwise scaling and SwiGLU to the optimized kernel
  • Helper function for TMA-aligned shared memory pointer alignment

Implementation notes:

  • Kernel uses persistent thread blocks with dynamic work assignment via clusterlaunchcontrol_try_cancel
  • Single-iteration prefetching pipeline with double-buffering for input/output data
  • Optimized PTX assembly for BWD SwiGLU computation using SIMD instructions (a minimal wrapper sketch follows this list)
  • Variable shadowing issue at line 257 (previously flagged) has been resolved in commit 25ed43e
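As an illustration of the SIMD instructions involved, a wrapper for the packed mul.f32x2 instruction (Blackwell-class PTX) could look like the sketch below; the register constraints and function name are assumptions, not the PR's code:

```cuda
#include <cstdint>

// Multiplies two packed pairs of fp32 values in a single instruction.
// Requires a Blackwell-class target; guard with architecture checks in real code.
__device__ inline float2 mul_f32x2_sketch(float2 a, float2 b) {
  float2 d;
  asm("mul.f32x2 %0, %1, %2;"
      : "=l"(reinterpret_cast<uint64_t &>(d))
      : "l"(reinterpret_cast<uint64_t &>(a)), "l"(reinterpret_cast<uint64_t &>(b)));
  return d;
}
```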

Confidence Score: 4/5

  • This PR is safe to merge with minor risk - the optimized kernel is architecture-specific and well-isolated
  • Score reflects that the implementation is sound with proper fallback paths, but uses advanced Blackwell-specific features (persistent kernels, cluster launch control) that are less battle-tested than standard approaches. The variable shadowing issue was already fixed. The kernel is properly gated behind architecture checks and only executes for specific input configurations.
  • No files require special attention - the implementation follows existing patterns in the codebase

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/common/cast/mxfp8/specialized/gated_mxfp8_rowwise_swiglu.cuh | 4/5 | New optimized persistent kernel for MXFP8 SwiGLU with Blackwell-specific cluster launch control. Variable shadowing issue already fixed. |
| transformer_engine/common/util/ptx.cuh | 5/5 | Added PTX wrappers for mbarrier operations and cluster launch control primitives for Blackwell architecture. |
| transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | 5/5 | Added dispatch logic to route eligible workloads to the optimized kernel based on data type, scaling type, and activation function. |
| transformer_engine/common/cast/core/common.cuh | 5/5 | Added helper function for TMA-aligned shared memory pointer alignment and updated to use TMA alignment constants. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Host
    participant DispatchLogic as Dispatch Logic<br/>(gated_mxfp8.cuh)
    participant Kernel as Persistent Kernel<br/>(gated_mxfp8_rowwise_swiglu.cuh)
    participant TMA as TMA Engine
    participant CLC as Cluster Launch Control

    Host->>DispatchLogic: quantize_gated(input, grad, output)
    DispatchLogic->>DispatchLogic: Check conditions<br/>(BF16/FP16, rowwise, SwiGLU)
    
    alt Optimized path eligible
        DispatchLogic->>Kernel: Launch persistent kernel grid
        
        loop Persistent execution
            Kernel->>Kernel: Initialize mbarriers
            Kernel->>CLC: Request next work unit
            CLC-->>Kernel: Assign CTA coordinates (X, Y)
            
            alt Work available
                loop For each stage in chunk
                    par Async prefetch
                        Kernel->>TMA: Bulk copy (act, gate, grad)
                        TMA-->>Kernel: Data to shared memory
                    and Compute on current buffer
                        Kernel->>Kernel: Load from shared memory<br/>(ld_shared_cvt_f32x2)
                        Kernel->>Kernel: Compute SwiGLU + quantize<br/>(via PTX assembly)
                        Kernel->>Kernel: Compute rowwise scales
                        Kernel->>TMA: Store quantized output
                    end
                end
                Kernel->>Kernel: Store scaling factors to global
            else No more work
                CLC-->>Kernel: Return (-1, -1)
                Kernel->>Kernel: Exit loop
            end
        end
        
        Kernel-->>DispatchLogic: Complete
    else Fallback to general kernel
        DispatchLogic->>DispatchLogic: Use standard quantize_gated
    end
    
    DispatchLogic-->>Host: Return
```

4 files reviewed, no comments


@Oleg-Goncharov (Collaborator, Author) commented:

/te-ci

@ptrendx (Member) commented Oct 31, 2025

@coderabbitai full review

@coderabbitai bot commented Oct 31, 2025

✅ Actions performed

Full review triggered.

@coderabbitai bot commented Oct 31, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

The changes introduce dynamically calculated TMA alignment, add a specialized gated MXFP8 rowwise SwiGLU quantization kernel with advanced memory management and PTX operations, implement an early-exit optimization in the main gated path that delegates eligible SwiGLU cases to the rowwise kernel, and expand the PTX utilities with CTA-shared barrier primitives and cluster-control functions.

Changes

| Cohort / File(s) | Change Summary |
| --- | --- |
| **TMA Alignment and Helpers**<br>transformer_engine/common/cast/core/common.cuh | Updated dimensions_supported_by_TMA to calculate alignment dynamically from TMA_GMEM_ALIGNMENT and the data type's bit width instead of fixed constants; added align_smem_ptr_per_TMA_requirements() device helper to align pointers to TMA_SHMEM_ALIGNMENT. |
| **Gated MXFP8 Optimizations**<br>transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | Added an early-exit condition in quantize_gated that detects forward/backward SwiGLU with BF16/FP16 inputs and ROWWISE scaling and delegates to quantize_gated_rowwise() before the non-optimized path; existing logic remains intact for other cases. |
| **Specialized Gated SwiGLU Rowwise Kernel**<br>transformer_engine/common/cast/mxfp8/specialized/gated_mxfp8_rowwise_swiglu.cuh | Introduces a complete specialized gated rowwise MXFP8 quantization kernel with dual-buffered global-to-shared memory transfer, PTX-tuned forward/backward SwiGLU computations, per-element amax tracking, scaling-factor derivation, and output write-back; includes the host launcher quantize_gated_rowwise() and device helpers for activation/gate calculations. |
| **PTX Barrier and Cluster Utilities**<br>transformer_engine/common/util/ptx.cuh | Added CTA-shared mbarrier primitives (mbarrier_arrive_*_cta_shared_cta, mbarrier_wait_parity_*_cta_shared_cta), cluster launch control functions (clusterlaunchcontrol_try_cancel_async_shared_cta_mbarrier_complete_tx_bytes, get_cancelled_cta_2D_id), and a templated ld_shared_cvt_f32x2() for shared-memory load-and-convert with bf16x2/fp16x2 specializations. |
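The "per-element amax tracking" and "scaling-factor derivation" in the kernel row above follow the standard MX recipe: each 32-element block along a row gets one E8M0 scale derived from the block's absolute maximum. A hedged sketch of that derivation (the kernel's exact rounding and clamping may differ):

```cuda
#include <cstdint>

// Pick the E8M0 scale 2^e with e = floor(log2(amax)) - 8, so that scaled
// values fit within E4M3's maximum normal (448 < 2^9). E8M0 stores the
// exponent with a bias of 127.
__device__ inline uint8_t e8m0_scale_sketch(float block_amax) {
  const int fp32_biased_exp = (__float_as_int(block_amax) >> 23) & 0xFF;
  const int unbiased = fp32_biased_exp - 127;  // floor(log2(amax)) for normals
  int e = unbiased - 8;                        // leave headroom for E4M3's range
  e = max(-127, min(127, e));                  // clamp to E8M0's representable range
  return static_cast<uint8_t>(e + 127);        // re-bias for storage
}
```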

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host Code
    participant dispatch as quantize_gated<br/>(Dispatcher)
    participant rowwise as quantize_gated_rowwise<br/>(Launcher)
    participant kernel as quantize_gated_mxfp8<br/>_rowwise_kernel
    participant gmem as Global Memory
    participant smem as Shared Memory

    Host->>dispatch: Call quantize_gated()
    dispatch->>dispatch: Check: SwiGLU + BF16/FP16 + ROWWISE?
    alt Early-Exit Path (Optimized)
        dispatch->>rowwise: Delegate to quantize_gated_rowwise()
        rowwise->>rowwise: Configure TMA, shared memory, grid/block
        rowwise->>kernel: Launch kernel
        kernel->>gmem: Load activations, gates
        kernel->>smem: TMA transfer to shared memory
        kernel->>kernel: Compute fwd/bwd gated activation
        kernel->>kernel: Calculate per-element amax
        kernel->>kernel: Derive scaling factors (SFs)
        kernel->>gmem: Write quantized output (act, gate)
        kernel->>gmem: Write SFs per chunk
    else Fallback Path
        dispatch->>dispatch: Execute non-rowwise logic
        dispatch->>gmem: Process via existing path
    end
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Areas requiring extra attention:

  • gated_mxfp8_rowwise_swiglu.cuh: Extensive new kernel logic with double-buffering, PTX-based SwiGLU computations, amax tracking, and scaling-factor derivation; verify correctness of memory synchronization, barrier placement, and output formatting.
  • ptx.cuh: New CTA-shared barrier variants and cluster control primitives; validate PTX assembly semantics, architecture-specific guards (SM_10.0+), and error handling for unsupported architectures.
  • gated_mxfp8.cuh: Early-exit path must correctly distinguish SwiGLU variants (including ClampedSwiGLU exclusion) and ensure rowwise kernel receives correct data types and scaling mode.
  • common.cuh: Verify TMA alignment calculation with dynamic formula matches hardware requirements and that pointer alignment helper handles edge cases.

Poem

🐰 Barriers sync and SwiGLU gates align,
Rowwise kernels dance through shared-memory time,
PTX tuned fast, with clusters in control,
Early exits glide—optimization's soul!
A hoppy hop through quantization's race,
Transformer speedups light up the space! 🚀

Pre-merge checks

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, below the required threshold of 80.00%. | Run @coderabbitai generate docstrings to improve coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The PR title "[Common] Added an optimized gated rowwise MXFP8 SwiGLU kernel" accurately and specifically describes the primary change: a specialized kernel for gated rowwise MXFP8 quantization with SwiGLU activation (new file gated_mxfp8_rowwise_swiglu.cuh). Supporting changes in ptx.cuh, common.cuh, and gated_mxfp8.cuh are infrastructure that enables the kernel. The title is concise, avoids vague terminology, and makes the primary objective clear in the git history. |


@coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e7227af and 25ed43e.

📒 Files selected for processing (4)
  • transformer_engine/common/cast/core/common.cuh (1 hunks)
  • transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh (2 hunks)
  • transformer_engine/common/cast/mxfp8/specialized/gated_mxfp8_rowwise_swiglu.cuh (1 hunks)
  • transformer_engine/common/util/ptx.cuh (4 hunks)

Comment on lines +700 to +712
```cpp
// Optimized BWD/FWD SwiGLU MXFP8 Rowwise kernels for BF16/FP16 inputs
if constexpr (!std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
  const bool is_fwd_swiglu = !IS_BWD && (ActOP == &silu<fp32, fp32>);
  const bool is_bwd_swiglu =
      IS_BWD && (ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>);
  const bool is_supported_data_type =
      (gated_input.dtype() == DType::kFloat16) || (gated_input.dtype() == DType::kBFloat16);
  const bool is_supported_scaling_type = scaling_type == ScalingType::ROWWISE;
  if (is_supported_data_type && is_supported_scaling_type && (is_fwd_swiglu || is_bwd_swiglu)) {
    quantize_gated_rowwise<IS_BWD, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
    return;
  }
}
```

⚠️ Potential issue | 🔴 Critical

Guard rowwise path for non-Blackwell GPUs

This early-exit now forwards BF16/FP16 rowwise SwiGLU traffic to the new persistent kernel, but nothing here verifies that we are running on a Blackwell-class device. On Hopper (sm90) and earlier, the new kernel executes clusterlaunchcontrol_* instructions; those helpers immediately call NVTE_DEVICE_ERROR when ARCH_BLACKWELL_FAMILY is false, so the kernel traps and the previously working path is no longer usable. Cluster Launch Control is a Blackwell-only feature, so we need to detect the architecture before choosing this fast path and fall back to the legacy implementation otherwise. (docs.jax.dev)

Please gate this call (e.g., by checking the runtime compute capability or ARCH_BLACKWELL_FAMILY) and keep the existing fallback on non-Blackwell GPUs so we don’t crash on Hopper deployments.
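A sketch of the kind of guard being suggested, using a plain CUDA runtime query (illustrative; the actual fix may reuse an existing Transformer Engine helper):

```cpp
// Only take the fast path on the Blackwell family (SM major == 10);
// otherwise fall through to the legacy kernel.
int device = 0, sm_major = 0;
NVTE_CHECK_CUDA(cudaGetDevice(&device));
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
const bool is_blackwell_family = (sm_major == 10);

if (is_blackwell_family && is_supported_data_type && is_supported_scaling_type &&
    (is_fwd_swiglu || is_bwd_swiglu)) {
  quantize_gated_rowwise<IS_BWD, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
  return;
}
```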

