
Conversation

@QiZhangNV
Contributor

Description

Introduces a CUTLASS-based grouped GEMM implementation that reads m_splits directly on the device.

This optimization removes the need for device-to-host data transfers and synchronization in MCore, while allowing the number of quantization kernels to be reduced to one.

The kernel is fully compatible with CUDA Graphs.

Key points:
• Does not break the existing API. The operator accepts m_splits as either a torch.Tensor (on CPU or GPU) or a Python list; see the usage sketch after this list.
• Reduces CPU overhead, especially for large expert counts, by using a single quantization kernel instead of one per GEMM.
• Currently supports only MXFP8.
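
A minimal usage sketch of the two accepted forms of m_splits (illustrative only: the module arguments, shapes, dtypes, and recipe below are assumptions, not code from this PR):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

# Illustrative sketch, not code from this PR: module arguments, shapes, and the
# int32 dtype for m_splits are assumptions. MXFP8 needs Blackwell (SM 10.0).
num_gemms, in_features, out_features = 4, 1024, 4096
layer = te.GroupedLinear(num_gemms, in_features, out_features, bias=False,
                         params_dtype=torch.bfloat16, device="cuda")

tokens_per_expert = [256, 512, 128, 384]  # multiples of 128, sum = 1280
inp = torch.randn(sum(tokens_per_expert), in_features,
                  device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    # Existing behavior: m_splits as a Python list (converted internally).
    out = layer(inp, tokens_per_expert)

    # New behavior: a CUDA tensor selects the device-initiated CUTLASS path,
    # so the split sizes never have to be copied back to the host.
    m_splits = torch.tensor(tokens_per_expert, dtype=torch.int32, device="cuda")
    out = layer(inp, m_splits)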

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change m_splits from List[int] to torch.Tensor; a List[int] is still accepted and is converted to a tensor internally (see the sketch after this list)
  • Add te_general_device_initiated_grouped_gemm
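
On the Python side, the backward-compatible handling could look roughly like the sketch below (the helper name, the int32 dtype, and the dtype assertion are illustrative, pieced together from the commit notes about pinned memory and the m_splits dtype check, not copied from this PR):

from typing import List, Union

import torch

def normalize_m_splits(m_splits: Union[List[int], torch.Tensor]) -> torch.Tensor:
    """Accept either form of m_splits and hand a tensor to the GEMM path."""
    if isinstance(m_splits, list):
        # Pinned host memory keeps a later H2D copy asynchronous
        # (see the "Use pinned memory instead of pageable memory" commit below).
        return torch.tensor(m_splits, dtype=torch.int32, pin_memory=True)
    assert m_splits.dtype in (torch.int32, torch.int64), "unexpected m_splits dtype"
    return m_splits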

Unit Test

pytest -v -s tests/pytorch/test_numerics.py::test_grouped_linear_accuracy_cutlass_device

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Code clean & Add Check & Fix Arch

Change cutlass submodule to https

Support BF16

Fix assertion

Fix when all local experts have no tokens

Fix

Support save_original_input for cutlass backend

Fix remove cudaMallocAsync & modify CUTLASS config

Pass nullptr if C is not needed

Tune kernel Performance

Add dtype check for m_split

Optimize setGroupedGemmWgradArguments when fuse_wgrad_accumulation=false

Support partial wgrad accumulate when using cutlass backend

Use torch.empty() instead of torch.zeros() for wgrad_list

Fix IMA when CUDA Graph is enabled

Use arg wgrad_accumulation_mask to handle partial wgrad accumulation

Use bitmap for partial wgrad accumulate to avoid cudaMemcpyAsync

Allow m_splits to be List, convert to torch tensor

Use pinned memory instead of pageable memory

Refactor and add dispatcher

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR introduces device-initiated grouped GEMM support that eliminates CPU-GPU synchronization overhead for MoE (Mixture of Experts) workloads by reading m_splits directly on the device.

Key Changes:

  • Modified m_splits parameter from List[int] to torch.Tensor, maintaining backward compatibility by auto-converting lists
  • Added new CUTLASS-based kernel path (nvte_device_cutlass_grouped_gemm and nvte_device_cutlass_grouped_gemm_wgrad) for Blackwell GPUs (SM 10.0)
  • Implemented device-side argument preparation that reads m_splits tensor on GPU, avoiding D2H transfer
  • Added support for partial weight gradient accumulation via wgrad_accumulation_mask parameter
  • Uses pinned host memory buffer for CUDA Graph compatibility when transferring weight/scale factor addresses

Critical Issues Found:

  1. Race condition in global buffer index (gemm.cpp:563): The pinned_host_buffer_index global variable lacks thread safety and never resets, causing buffer overflow after multiple calls
  2. Multiple typos: Variable name m_splits_on_devie should be m_splits_on_device in several locations
  3. Complex lambda expression (grouped_linear.py:265-267): Nested immediately-invoked lambda makes code maintenance difficult

Limitations:

  • Device-initiated path only supports MXFP8 format on Blackwell GPUs
  • No bias support when m_splits is on device
  • Requires m dimension alignment to 128 for MXFP8 (increased from 32)
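
As an illustration of the last limitation, every per-expert row count has to be a multiple of 128; a caller that pads the counts would also have to pad the corresponding activation rows upstream. A minimal sketch (not code from this PR):

def pad_m_splits(counts, alignment=128):
    """Round each per-expert row count up to the required alignment."""
    return [((c + alignment - 1) // alignment) * alignment for c in counts]

padded = pad_m_splits([100, 260, 0, 384])
assert padded == [128, 384, 0, 384]
assert all(c % 128 == 0 for c in padded)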

Confidence Score: 2/5

  • This PR has critical concurrency bugs that will cause failures in production
  • The global pinned_host_buffer_index variable creates a race condition and memory-corruption risk. Without a synchronization or reset mechanism, the buffer index grows unbounded and will overflow the workspace. In addition, multiple typos in variable names suggest insufficient review and testing.
  • transformer_engine/pytorch/csrc/extensions/gemm.cpp requires immediate attention to fix the global buffer index race condition before merging

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/pytorch/module/grouped_linear.py | 3/5 | Changed m_splits from List[int] to torch.Tensor; added device-initiated path with m_splits_on_device flag; complex lambda expression for conditional wgrad accumulation
transformer_engine/pytorch/csrc/extensions/gemm.cpp | 2/5 | Implements te_general_device_initiated_grouped_gemm with a global pinned_host_buffer_index for CUDA Graph support; manages H2D copies for weight/SF addresses
transformer_engine/common/gemm/cutlass_device_grouped_gemm.cu | 4/5 | New CUTLASS kernel implementation for device-initiated grouped GEMM; supports fprop/dgrad/wgrad with MXFP8; includes WgradAccumulatePolicy for partial accumulation

Sequence Diagram

sequenceDiagram
    participant User
    participant GroupedLinear
    participant _GroupedLinear
    participant gemm.py
    participant gemm.cpp
    participant CUTLASS_CUDA
    
    User->>GroupedLinear: forward(inp, m_splits)
    Note over GroupedLinear: Convert m_splits to tensor if list
    GroupedLinear->>_GroupedLinear: forward()
    
    alt m_splits on device
        Note over _GroupedLinear: Single quantize (no split)
        _GroupedLinear->>_GroupedLinear: tex.split_quantize(inp, [total_size], quantizers[:1])
    else m_splits on CPU
        Note over _GroupedLinear: Split quantize per expert
        _GroupedLinear->>_GroupedLinear: tex.split_quantize(inp, m_splits.tolist(), quantizers)
    end
    
    _GroupedLinear->>gemm.py: general_grouped_gemm(A, B, out, m_splits, m_splits_on_device)
    
    alt m_splits_on_device
        gemm.py->>gemm.cpp: te_general_device_initiated_grouped_gemm()
        gemm.cpp->>gemm.cpp: Prepare B/SF address arrays in pinned memory
        gemm.cpp->>gemm.cpp: H2D copy (async, non-blocking)
        gemm.cpp->>CUTLASS_CUDA: nvte_device_cutlass_grouped_gemm()
        CUTLASS_CUDA->>CUTLASS_CUDA: setGroupedGemmArguments kernel (reads m_splits on device)
        CUTLASS_CUDA->>CUTLASS_CUDA: CUTLASS grouped GEMM kernel
    else m_splits on CPU
        gemm.py->>gemm.cpp: te_general_grouped_gemm()
        gemm.cpp->>gemm.cpp: cuBLAS multi-stream grouped GEMM
    end
    
    gemm.cpp-->>gemm.py: result
    gemm.py-->>_GroupedLinear: out
    _GroupedLinear-->>GroupedLinear: out
    GroupedLinear-->>User: result

11 files reviewed, 8 comments


// reuse and overwrite the same buffer.
// A global variable is used because the function doesn't know how many instances there are and which
// instance is calling.
int pinned_host_buffer_index = 0;

logic: global variable without synchronization creates race condition

The global pinned_host_buffer_index is not thread-safe. If multiple operators call this function concurrently (even from different streams), they will increment the same global variable, leading to:

  1. Buffer corruption as operators overwrite each other's data
  2. Index overflow beyond workspace bounds

Even with CUDA Graph capture being single-threaded per graph, eager mode or multiple graphs could cause issues.

Suggested change
int pinned_host_buffer_index = 0;
// Use atomic operations or per-instance tracking
std::atomic<int> pinned_host_buffer_index = 0;

Consider using thread-local storage or passing buffer offset through the call chain instead.

Comment on lines 265 to +267
  ctx.main_grad_funcs = [
-     lambda j=i: weights[j].main_grad for i in range(num_gemms)
+     (lambda idx=i: (lambda: weights[idx].main_grad) if (wgrad_accumulation_mask is None or wgrad_accumulation_mask[idx]) else (lambda: None))()
+     for i in range(num_gemms)

style: overly complex lambda creates hard-to-debug closure

This nested lambda with conditional execution (lambda idx=i: (lambda: ...) if condition else (lambda: None))() is executed immediately and difficult to understand. The inner lambdas capture weights[idx] but the outer lambda is called immediately with ().

Suggested change (replace the immediately-invoked nested lambda with a named helper):

def make_grad_func(idx):
    if wgrad_accumulation_mask is None or wgrad_accumulation_mask[idx]:
        return lambda: weights[idx].main_grad
    else:
        return lambda: None

ctx.main_grad_funcs = [make_grad_func(i) for i in range(num_gemms)]

save_original_input,
*weights_and_biases,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring

syntax: typo: m_splits_on_devie should be m_splits_on_device

Suggested change
# pylint: disable=missing-function-docstring
m_splits_on_device = m_splits.is_cuda

Comment on lines +138 to +139
assert fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(), "Only MXFP8 is supported when m_splits is on devie"
# Cannot split because the m_splits is not available on host.

syntax: assertion message contains typo: devie should be device

Suggested change
assert fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(), "Only MXFP8 is supported when m_splits is on devie"
# Cannot split because the m_splits is not available on host.
assert fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(), "Only MXFP8 is supported when m_splits is on device"

Comment on lines 334 to +335
if ctx.use_bias:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
assert not ctx.m_splits_on_devie, "bias is not supported when m_splits is on devie"

syntax: assertion message contains typo: devie should be device

Suggested change
if ctx.use_bias:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
assert not ctx.m_splits_on_devie, "bias is not supported when m_splits is on devie"
assert not ctx.m_splits_on_device, "bias is not supported when m_splits is on device"

cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits
)

assert ctx.fp8 and ctx.fp8_recipe.mxfp8(), "Only MXFP8 is supported when m_splits is on devie"

syntax: assertion message contains typo: devie should be device

Suggested change
assert ctx.fp8 and ctx.fp8_recipe.mxfp8(), "Only MXFP8 is supported when m_splits is on devie"
assert ctx.fp8 and ctx.fp8_recipe.mxfp8(), "Only MXFP8 is supported when m_splits is on device"

# Device buffer for cutlass arguments and kernel
torch.empty(get_cutlass_grouped_gemm_workspace_size_bytes(), dtype=torch.uint8, device="cuda"),
# TODO: Only allocate pinned buffer when cuda graph is enabled
# Host pinned buffer for the source of H2D copy of cutlass arguments

style: pinned memory buffer allocated unconditionally wastes memory

The comment says "TODO: Only allocate pinned buffer when cuda graph is enabled" but the buffer is always allocated. Pinned memory is a limited resource (typically capped at a fraction of system RAM).

Consider lazy allocation or checking torch.cuda.is_current_stream_capturing() state.
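
One possible shape for that lazy allocation, sketched in Python (the function name, default capacity, and int64 dtype are illustrative; only the NVTE_CUTLASS_HOST_PINNED_U64_CAPACITY variable comes from this PR):

import os

import torch

_pinned_host_buffer = None

def get_pinned_host_buffer() -> torch.Tensor:
    """Allocate the pinned host staging buffer on first use instead of unconditionally."""
    global _pinned_host_buffer
    if _pinned_host_buffer is None:
        # The default capacity is a placeholder; the env var is the one the
        # bounds check in the next comment refers to.
        capacity = int(os.getenv("NVTE_CUTLASS_HOST_PINNED_U64_CAPACITY", "1024"))
        _pinned_host_buffer = torch.empty(capacity, dtype=torch.int64, pin_memory=True)
    return _pinned_host_buffer

Gating on torch.cuda.is_current_stream_capturing(), as suggested, would need care, since the first call to this path may happen outside graph capture.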

Comment on lines +628 to +634
NVTE_CHECK(pinned_host_buffer_index + num_gemms * 2 <= workspace[1].size(0),
           "Pinned host buffer out of bounds, please increase the capacity by setting "
           "NVTE_CUTLASS_HOST_PINNED_U64_CAPACITY. "
           "Current buffer size: ",
           workspace[1].size(0));
inputB_and_SF_addrs = workspace[1].narrow(0, pinned_host_buffer_index, num_gemms * 2);
pinned_host_buffer_index += num_gemms * 2;

logic: buffer index never resets causing accumulation across calls

The pinned_host_buffer_index increments on every call but is never reset to 0. After multiple forward/backward passes, it will:

  1. Exceed workspace bounds
  2. Trigger the error check on line 628

For CUDA Graph, the index should reset at graph capture start or be managed per-graph instance.
