[Common] Added an optimized gated rowwise MXFP8 SwiGLU kernel #2328
Conversation
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR introduces a specialized persistent CUDA kernel optimized for MXFP8 rowwise quantization of gated SwiGLU activations on Blackwell architecture (sm100f family). The implementation achieves significant performance improvements by leveraging architecture-specific features.
Key Changes
- New specialized kernel (`gated_mxfp8_rowwise_swiglu.cuh`): 785-line implementation with hand-tuned PTX assembly for BF16/FP16 inputs
- Persistent scheduling: Uses Blackwell's Cluster Launch Control for dynamic work distribution across CTAs
- Optimized compute path: Fused SwiGLU backward/forward computation with inline PTX for `mul.f32x2`, `max.xorsign.abs`, and vectorized conversions
- Memory optimization: Double-buffered TMA operations with prefetching (1 stage ahead) to overlap compute and memory transfers
- Infrastructure additions: New PTX helpers for cluster control, 2D CTA cancellation queries, and shared memory alignment
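The "1 stage ahead" prefetch pattern above can be illustrated with a toy host-side schedule (purely illustrative; the real kernel uses TMA async copies synchronized with mbarriers):

```cpp
#include <string>
#include <vector>

// Toy model of the double-buffered pipeline: the copy for stage s+1 is
// issued before computing on stage s, so transfer and compute overlap.
std::vector<std::string> schedule(int num_stages) {
  std::vector<std::string> log;
  log.push_back("copy 0 -> buf0");  // prime the pipeline
  for (int s = 0; s < num_stages; ++s) {
    if (s + 1 < num_stages) {
      log.push_back("copy " + std::to_string(s + 1) + " -> buf" +
                    std::to_string((s + 1) % 2));
    }
    log.push_back("compute " + std::to_string(s) + " on buf" +
                  std::to_string(s % 2));
  }
  return log;
}
```

Each stage alternates between two shared-memory buffers (`s % 2`), which is what lets a two-buffer ring hide one copy behind each compute step.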
Dispatch Logic
The new kernel activates when ALL conditions are met:
- Activation function is SwiGLU (silu, not clamped variant)
- Input dtype is BF16 or FP16
- Scaling mode is rowwise only
Otherwise, falls back to the existing generic kernel that supports bidimensional and columnwise scaling.
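The dispatch conditions can be sketched as a standalone predicate (a hedged sketch; the enum and function names here mirror TE's style but are illustrative, not the actual symbols):

```cpp
#include <cassert>

enum class DType { kFloat16, kBFloat16, kFloat32 };
enum class ScalingType { ROWWISE, COLWISE, BIDIMENSIONAL };

// Hypothetical predicate mirroring the dispatch conditions described above:
// non-clamped SwiGLU, BF16/FP16 input, and rowwise-only scaling.
bool use_optimized_swiglu_kernel(bool is_clamped, bool is_swiglu,
                                 DType dtype, ScalingType scaling) {
  const bool supported_dtype =
      dtype == DType::kFloat16 || dtype == DType::kBFloat16;
  return !is_clamped && is_swiglu && supported_dtype &&
         scaling == ScalingType::ROWWISE;
}
```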
Confidence Score: 4/5
- Safe to merge with minor considerations for architecture-specific validation
- The implementation is well-structured with proper architecture guards, but lacks unit tests. The kernel is correctly gated behind multiple condition checks and falls back gracefully. PTX assembly appears sound with proper register management and no obvious memory safety issues. Main risk is architecture-specific behavior on Blackwell hardware requiring validation.
- Pay attention to `gated_mxfp8_rowwise_swiglu.cuh` for runtime validation on actual Blackwell hardware, especially the persistent scheduling and cluster launch control paths
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/cast/mxfp8/specialized/gated_mxfp8_rowwise_swiglu.cuh | 4/5 | New 785-line specialized kernel for MXFP8 rowwise SwiGLU with Blackwell persistent scheduling. Implements optimized PTX assembly for BF16/FP16 forward/backward passes using cluster launch control and TMA operations. |
| transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | 5/5 | Added dispatch logic (lines 700-710) to route SwiGLU rowwise BF16/FP16 operations to the optimized kernel when conditions match (non-clamped SwiGLU, rowwise scaling). |
| transformer_engine/common/util/ptx.cuh | 5/5 | Added PTX infrastructure for Blackwell persistent kernels: cluster launch control functions, 2D CTA cancellation queries, shared memory alignment helper, and optimized BF16/FP16 load-convert functions. |
| transformer_engine/common/cast/core/common.cuh | 5/5 | Refactored TMA_bytes constant to use TMA_GMEM_ALIGNMENT (line 33) and added align_smem_ptr_per_TMA_requirements device function (lines 37-42) for dynamic shared memory alignment. |
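The shared-memory alignment helper mentioned for `common.cuh` can be sketched on the host as a standard align-up operation; this assumes a 128-byte TMA alignment requirement (the actual constant and device-side helper live in the TE sources):

```cpp
#include <cstdint>

// Assumed TMA alignment requirement (illustrative; TE defines its own constant).
constexpr std::uintptr_t kTmaAlignment = 128;

// Round an address up to the next multiple of kTmaAlignment.
// Works because kTmaAlignment is a power of two.
inline std::uintptr_t align_up(std::uintptr_t addr) {
  return (addr + kTmaAlignment - 1) & ~(kTmaAlignment - 1);
}
```

In the kernel this is applied to the dynamic shared-memory base pointer before carving out TMA-destined buffers.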
Sequence Diagram
```mermaid
sequenceDiagram
participant User
participant Dispatcher
participant OptKernel
participant GenKernel
participant TMA
participant ClusterCtrl
User->>Dispatcher: Call quantize gated
Dispatcher->>Dispatcher: Check conditions
alt Optimized path
Dispatcher->>OptKernel: Launch rowwise kernel
OptKernel->>TMA: Prefetch initial data
loop Process chunks
OptKernel->>ClusterCtrl: Request next work
ClusterCtrl-->>OptKernel: Assign CTA coords
loop Each stage
OptKernel->>TMA: Async prefetch
OptKernel->>OptKernel: Wait on barrier
OptKernel->>OptKernel: Compute activation
OptKernel->>OptKernel: Quantize to MXFP8
OptKernel->>TMA: Write results
end
end
OptKernel-->>User: Return output
else Generic path
Dispatcher->>GenKernel: Launch generic kernel
GenKernel->>TMA: Standard processing
GenKernel-->>User: Return output
end
```
4 files reviewed, no comments
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Greptile Overview
Greptile Summary
This PR adds a specialized persistent CUDA kernel optimized for gated MXFP8 quantization with rowwise scaling and SwiGLU activation on Blackwell architecture (sm_100+). The implementation leverages Cluster Launch Control for dynamic scheduling across CTAs.
Key changes:
- New persistent kernel in `gated_mxfp8_rowwise_swiglu.cuh` using TMA bulk transfers and mbarrier synchronization
- PTX wrappers for Blackwell-specific cluster launch control primitives and relaxed/acquire memory ordering variants
- Dispatch logic routes BF16/FP16 inputs with rowwise scaling and SwiGLU to the optimized kernel
- Helper function for TMA-aligned shared memory pointer alignment
Implementation notes:
- Kernel uses persistent thread blocks with dynamic work assignment via `clusterlaunchcontrol_try_cancel`
- Single-iteration prefetching pipeline with double-buffering for input/output data
- Optimized PTX assembly for BWD SwiGLU computation using SIMD instructions
- Variable shadowing issue at line 257 (previously flagged) has been resolved in commit 25ed43e
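For reference, the scalar math behind the fused FWD/BWD SwiGLU path can be written out as follows (a sketch only; the actual kernel evaluates this on BF16/FP16 element pairs via inline PTX):

```cpp
#include <cmath>

// silu(x) = x * sigmoid(x); its derivative is sigma(x) * (1 + x * (1 - sigma(x))).
float sigmoidf(float x) { return 1.0f / (1.0f + std::exp(-x)); }
float silu(float x) { return x * sigmoidf(x); }
float dsilu(float x) {
  const float s = sigmoidf(x);
  return s * (1.0f + x * (1.0f - s));
}

// Forward gated activation: out = silu(act) * gate.
float swiglu_fwd(float act, float gate) { return silu(act) * gate; }

// Backward: given incoming grad, the two partials are
//   d(act)  = grad * gate * dsilu(act)
//   d(gate) = grad * silu(act)
void swiglu_bwd(float grad, float act, float gate, float* dact, float* dgate) {
  *dact = grad * gate * dsilu(act);
  *dgate = grad * silu(act);
}
```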
Confidence Score: 4/5
- This PR is safe to merge with minor risk - the optimized kernel is architecture-specific and well-isolated
- Score reflects that the implementation is sound with proper fallback paths, but uses advanced Blackwell-specific features (persistent kernels, cluster launch control) that are less battle-tested than standard approaches. The variable shadowing issue was already fixed. The kernel is properly gated behind architecture checks and only executes for specific input configurations.
- No files require special attention - the implementation follows existing patterns in the codebase
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/cast/mxfp8/specialized/gated_mxfp8_rowwise_swiglu.cuh | 4/5 | New optimized persistent kernel for MXFP8 SwiGLU with Blackwell-specific cluster launch control. Variable shadowing issue already fixed. |
| transformer_engine/common/util/ptx.cuh | 5/5 | Added PTX wrappers for mbarrier operations and cluster launch control primitives for Blackwell architecture. |
| transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | 5/5 | Added dispatch logic to route eligible workloads to optimized kernel based on data type, scaling type, and activation function. |
| transformer_engine/common/cast/core/common.cuh | 5/5 | Added helper function for TMA-aligned shared memory pointer alignment and updated to use TMA alignment constants. |
Sequence Diagram
```mermaid
sequenceDiagram
participant Host
participant DispatchLogic as Dispatch Logic<br/>(gated_mxfp8.cuh)
participant Kernel as Persistent Kernel<br/>(gated_mxfp8_rowwise_swiglu.cuh)
participant TMA as TMA Engine
participant CLC as Cluster Launch Control
Host->>DispatchLogic: quantize_gated(input, grad, output)
DispatchLogic->>DispatchLogic: Check conditions<br/>(BF16/FP16, rowwise, SwiGLU)
alt Optimized path eligible
DispatchLogic->>Kernel: Launch persistent kernel grid
loop Persistent execution
Kernel->>Kernel: Initialize mbarriers
Kernel->>CLC: Request next work unit
CLC-->>Kernel: Assign CTA coordinates (X, Y)
alt Work available
loop For each stage in chunk
par Async prefetch
Kernel->>TMA: Bulk copy (act, gate, grad)
TMA-->>Kernel: Data to shared memory
and Compute on current buffer
Kernel->>Kernel: Load from shared memory<br/>(ld_shared_cvt_f32x2)
Kernel->>Kernel: Compute SwiGLU + quantize<br/>(via PTX assembly)
Kernel->>Kernel: Compute rowwise scales
Kernel->>TMA: Store quantized output
end
end
Kernel->>Kernel: Store scaling factors to global
else No more work
CLC-->>Kernel: Return (-1, -1)
Kernel->>Kernel: Exit loop
end
end
Kernel-->>DispatchLogic: Complete
else Fallback to general kernel
DispatchLogic->>DispatchLogic: Use standard quantize_gated
end
DispatchLogic-->>Host: Return
```
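The "compute rowwise scales" step can be sketched per 32-element MX block as below. This assumes an MX-style power-of-two scale chosen so the block amax fits into FP8 E4M3 (max representable value 448); the kernel's exact rounding and E8M0 encoding details may differ:

```cpp
#include <algorithm>
#include <cmath>

// Largest finite value representable in FP8 E4M3.
constexpr float kE4M3Max = 448.0f;

// Derive a power-of-two scale for one block so that amax / scale <= E4M3 max.
float block_scale(const float* x, int n) {
  float amax = 0.0f;
  for (int i = 0; i < n; ++i) amax = std::max(amax, std::fabs(x[i]));
  if (amax == 0.0f) return 1.0f;  // all-zero block: any scale works
  const int e = static_cast<int>(std::ceil(std::log2(amax / kE4M3Max)));
  return std::ldexp(1.0f, e);  // scale = 2^e
}
```

Each element is then divided by the block's scale and cast to E4M3, while the scale's exponent is stored as the block's E8M0 scaling factor.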
4 files reviewed, no comments
/te-ci

@coderabbitai full review

✅ Actions performed: Full review triggered.
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

The changes introduce TMA alignment improvements with dynamic calculation, add a specialized gated MXFP8 rowwise SwiGLU quantization kernel with advanced memory management and PTX operations, implement an early-exit optimization in the main gated path to delegate SwiGLU cases to the rowwise kernel, and expand PTX utilities with CTA-shared barrier primitives and cluster control functions.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Host as Host Code
participant dispatch as quantize_gated<br/>(Dispatcher)
participant rowwise as quantize_gated_rowwise<br/>(Launcher)
participant kernel as quantize_gated_mxfp8<br/>_rowwise_kernel
participant gmem as Global Memory
participant smem as Shared Memory
Host->>dispatch: Call quantize_gated()
dispatch->>dispatch: Check: SwiGLU + BF16/FP16 + ROWWISE?
alt Early-Exit Path (Optimized)
dispatch->>rowwise: Delegate to quantize_gated_rowwise()
rowwise->>rowwise: Configure TMA, shared memory, grid/block
rowwise->>kernel: Launch kernel
kernel->>gmem: Load activations, gates
kernel->>smem: TMA transfer to shared memory
kernel->>kernel: Compute fwd/bwd gated activation
kernel->>kernel: Calculate per-element amax
kernel->>kernel: Derive scaling factors (SFs)
kernel->>gmem: Write quantized output (act, gate)
kernel->>gmem: Write SFs per chunk
else Fallback Path
dispatch->>dispatch: Execute non-rowwise logic
dispatch->>gmem: Process via existing path
end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Areas requiring extra attention:
Pre-merge checks: ❌ 1 failed check (warning), ✅ 2 passed.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- `transformer_engine/common/cast/core/common.cuh` (1 hunks)
- `transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh` (2 hunks)
- `transformer_engine/common/cast/mxfp8/specialized/gated_mxfp8_rowwise_swiglu.cuh` (1 hunks)
- `transformer_engine/common/util/ptx.cuh` (4 hunks)
```cpp
// Optimized BWD/FWD SwiGLU MXFP8 Rowwise kernels for BF16/FP16 inputs
if constexpr (!std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
  const bool is_fwd_swiglu = !IS_BWD && (ActOP == &silu<fp32, fp32>);
  const bool is_bwd_swiglu =
      IS_BWD && (ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>);
  const bool is_supported_data_type =
      (gated_input.dtype() == DType::kFloat16) || (gated_input.dtype() == DType::kBFloat16);
  const bool is_supported_scaling_type = scaling_type == ScalingType::ROWWISE;
  if (is_supported_data_type && is_supported_scaling_type && (is_fwd_swiglu || is_bwd_swiglu)) {
    quantize_gated_rowwise<IS_BWD, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
    return;
  }
}
```
Guard rowwise path for non-Blackwell GPUs
This early-exit now forwards BF16/FP16 rowwise SwiGLU traffic to the new persistent kernel, but nothing here verifies that we are running on a Blackwell-class device. On Hopper (sm90) and earlier, the new kernel executes clusterlaunchcontrol_* instructions; those helpers immediately call NVTE_DEVICE_ERROR when ARCH_BLACKWELL_FAMILY is false, so the kernel traps and the previously working path is no longer usable. Cluster Launch Control is a Blackwell-only feature, so we need to detect the architecture before choosing this fast path and fall back to the legacy implementation otherwise. (docs.jax.dev)
Please gate this call (e.g., by checking the runtime compute capability or ARCH_BLACKWELL_FAMILY) and keep the existing fallback on non-Blackwell GPUs so we don’t crash on Hopper deployments.
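The requested guard reduces to a compute-capability check; a minimal sketch of the decision logic (in real host code the major version would come from `cudaDeviceGetAttribute` with `cudaDevAttrComputeCapabilityMajor`; the function name here is hypothetical):

```cpp
// Cluster Launch Control is a Blackwell-only feature; sm_100 reports
// compute capability major version 10, while Hopper (sm_90) reports 9.
bool eligible_for_cluster_launch_control(int cc_major) {
  return cc_major >= 10;  // Blackwell family and newer
}
```

Combining this with the existing dtype/scaling/activation checks keeps Hopper deployments on the legacy kernel instead of trapping in `NVTE_DEVICE_ERROR`.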
Description

This PR adds a persistent gated MXFP8 kernel optimized for rowwise scaling, SwiGLU activation (FWD and BWD) and BF16/FP16 input tensors. The kernel uses the "Cluster Launch Control" feature of the Blackwell arch (`sm100f` family) to enable persistency with dynamic scheduling.

Performance measured on PreNYX cluster
Type of change
Changes
Checklist:
Summary by CodeRabbit
Release Notes
Performance Improvements
Refactor