Add device-initiated Grouped GEMM supporting m_splits on device #2360
Conversation
Commits:
- Code clean & Add Check & Fix Arch
- Change cutlass submodule to https
- Support BF16
- Fix assertion
- Fix when all local experts have no tokens
- Fix
- Support save_original_input for cutlass backend
- Fix
- remove cudaMallocAsync & modify CUTLASS config
- Pass nullptr if C is not needed
- Tune kernel performance
- Add dtype check for m_split
- Optimize setGroupedGemmWgradArguments when fuse_wgrad_accumulation=false
- Support partial wgrad accumulate when using cutlass backend
- use torch.empty() instead of torch.zeros for wgrad_list
- Fix IMA when enable cuda graph
- Use arg wgrad_accumulation_mask to handle partial wgrad accumulate
- Use bitmap for partial wgrad accumulate to avoid cudaMemcpyAsync
- Allow m_splits to be List, convert to torch tensor
- Use pinned memory instead of pageable memory
- Refactor and add dispatcher
Greptile Overview
Greptile Summary
This PR introduces device-initiated grouped GEMM support that eliminates CPU-GPU synchronization overhead for MoE (Mixture of Experts) workloads by reading m_splits directly on the device.
Key Changes:
- Modified the `m_splits` parameter from `List[int]` to `torch.Tensor`, maintaining backward compatibility by auto-converting lists (a minimal conversion sketch follows this list)
- Added a new CUTLASS-based kernel path (`nvte_device_cutlass_grouped_gemm` and `nvte_device_cutlass_grouped_gemm_wgrad`) for Blackwell GPUs (SM 10.0)
- Implemented device-side argument preparation that reads the `m_splits` tensor on GPU, avoiding a D2H transfer
- Added support for partial weight gradient accumulation via the `wgrad_accumulation_mask` parameter
- Uses a pinned host memory buffer for CUDA Graph compatibility when transferring weight/scale-factor addresses
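The list-to-tensor compatibility shim can be pictured with a minimal sketch; this is an illustration of the described behavior, not the module's actual code, and the dtype shown is an assumption:

```python
import torch

def normalize_m_splits(m_splits):
    """Accept m_splits as a Python list or a torch.Tensor.

    Lists are auto-converted to a CPU integer tensor (host path); a tensor
    already on the GPU selects the device-initiated path.
    """
    if isinstance(m_splits, (list, tuple)):
        m_splits = torch.tensor(m_splits, dtype=torch.int32)  # dtype is illustrative
    return m_splits, m_splits.is_cuda

splits, on_device = normalize_m_splits([128, 256, 128])  # host path, on_device == False
```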
Critical Issues Found:
- Race condition in global buffer index (`gemm.cpp:563`): the `pinned_host_buffer_index` global variable lacks thread safety and never resets, causing buffer overflow after multiple calls
- Multiple typos: the variable name `m_splits_on_devie` should be `m_splits_on_device` in several locations
- Complex lambda expression (`grouped_linear.py:265-267`): a nested immediately-invoked lambda makes the code hard to maintain
Limitations:
- Device-initiated path only supports MXFP8 format on Blackwell GPUs
- No bias support when `m_splits` is on device
- Requires the `m` dimension to be aligned to 128 for MXFP8 (increased from 32; a padding sketch follows this list)
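To make the alignment requirement concrete, here is a hedged sketch of a hypothetical padding helper (not part of this PR); callers would also need to pad the corresponding activation rows:

```python
import torch

MXFP8_M_ALIGNMENT = 128  # required m alignment on the device-initiated path (was 32)

def pad_m_splits(m_splits: torch.Tensor, alignment: int = MXFP8_M_ALIGNMENT) -> torch.Tensor:
    """Round each per-expert row count up to the nearest multiple of `alignment`."""
    return (m_splits + alignment - 1) // alignment * alignment

print(pad_m_splits(torch.tensor([100, 260, 0, 384])))  # tensor([128, 384,   0, 384])
```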
Confidence Score: 2/5
- This PR has critical concurrency bugs that will cause failures in production
- The global `pinned_host_buffer_index` variable creates a race condition and memory-corruption risk. Without proper synchronization or a reset mechanism, the buffer index grows unbounded and will overflow workspace memory. Additionally, multiple typos in variable names indicate insufficient review/testing.
- `transformer_engine/pytorch/csrc/extensions/gemm.cpp` requires immediate attention to fix the global buffer index race condition before merging.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/grouped_linear.py | 3/5 | Changed m_splits from List[int] to torch.Tensor, added device-initiated path with m_splits_on_device flag, complex lambda expression for conditional wgrad accumulation |
| transformer_engine/pytorch/csrc/extensions/gemm.cpp | 2/5 | Implements te_general_device_initiated_grouped_gemm with global pinned_host_buffer_index for CUDA Graph support, manages H2D copies for weight/SF addresses |
| transformer_engine/common/gemm/cutlass_device_grouped_gemm.cu | 4/5 | New CUTLASS kernel implementation for device-initiated grouped GEMM, supports fprop/dgrad/wgrad with MXFP8, includes WgradAccumulatePolicy for partial accumulation |
Sequence Diagram
sequenceDiagram
participant User
participant GroupedLinear
participant _GroupedLinear
participant gemm.py
participant gemm.cpp
participant CUTLASS_CUDA
User->>GroupedLinear: forward(inp, m_splits)
Note over GroupedLinear: Convert m_splits to tensor if list
GroupedLinear->>_GroupedLinear: forward()
alt m_splits on device
Note over _GroupedLinear: Single quantize (no split)
_GroupedLinear->>_GroupedLinear: tex.split_quantize(inp, [total_size], quantizers[:1])
else m_splits on CPU
Note over _GroupedLinear: Split quantize per expert
_GroupedLinear->>_GroupedLinear: tex.split_quantize(inp, m_splits.tolist(), quantizers)
end
_GroupedLinear->>gemm.py: general_grouped_gemm(A, B, out, m_splits, m_splits_on_device)
alt m_splits_on_device
gemm.py->>gemm.cpp: te_general_device_initiated_grouped_gemm()
gemm.cpp->>gemm.cpp: Prepare B/SF address arrays in pinned memory
gemm.cpp->>gemm.cpp: H2D copy (async, non-blocking)
gemm.cpp->>CUTLASS_CUDA: nvte_device_cutlass_grouped_gemm()
CUTLASS_CUDA->>CUTLASS_CUDA: setGroupedGemmArguments kernel (reads m_splits on device)
CUTLASS_CUDA->>CUTLASS_CUDA: CUTLASS grouped GEMM kernel
else m_splits on CPU
gemm.py->>gemm.cpp: te_general_grouped_gemm()
gemm.cpp->>gemm.cpp: cuBLAS multi-stream grouped GEMM
end
gemm.cpp-->>gemm.py: result
gemm.py-->>_GroupedLinear: out
_GroupedLinear-->>GroupedLinear: out
GroupedLinear-->>User: result
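The quantization difference between the two branches above can be summarized in a short sketch; `quantize` is a placeholder for the MXFP8 quantization kernel, not the `tex.split_quantize` extension itself:

```python
import torch

def quantize(chunk: torch.Tensor) -> torch.Tensor:
    # Placeholder for the MXFP8 quantization kernel; returns its input unchanged here.
    return chunk

def quantize_for_grouped_gemm(inp: torch.Tensor, m_splits: torch.Tensor):
    if m_splits.is_cuda:
        # Device path: a single quantization kernel over the whole input;
        # per-expert boundaries are read from m_splits on the GPU by the GEMM.
        return [quantize(inp)]
    # Host path: one quantization kernel per expert slice.
    return [quantize(chunk) for chunk in torch.split(inp, m_splits.tolist())]
```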
11 files reviewed, 8 comments
| // reuse and overwrite the same buffer. |
| // A global variable is used because the function doesn't know how many instances there are and which |
| // instance is calling. |
| int pinned_host_buffer_index = 0; |
logic: global variable without synchronization creates race condition
The global pinned_host_buffer_index is not thread-safe. If multiple operators call this function concurrently (even from different streams), they will increment the same global variable, leading to:
- Buffer corruption as operators overwrite each other's data
- Index overflow beyond workspace bounds
Even with CUDA Graph capture being single-threaded per graph, eager mode or multiple graphs could cause issues.
| int pinned_host_buffer_index = 0; |
| // Use atomic operations or per-instance tracking |
| std::atomic<int> pinned_host_buffer_index = 0; |
Consider using thread-local storage or passing buffer offset through the call chain instead.
| ctx.main_grad_funcs = [ |
| lambda j=i: weights[j].main_grad for i in range(num_gemms) |
| (lambda idx=i: (lambda: weights[idx].main_grad) if (wgrad_accumulation_mask is None or wgrad_accumulation_mask[idx]) else (lambda: None))() |
| for i in range(num_gemms) |
style: overly complex lambda creates hard-to-debug closure
This nested lambda with conditional execution (lambda idx=i: (lambda: ...) if condition else (lambda: None))() is executed immediately and is difficult to understand. The inner lambdas capture weights[idx], but the outer lambda is invoked immediately with ().
| ctx.main_grad_funcs = [ |
| lambda j=i: weights[j].main_grad for i in range(num_gemms) |
| (lambda idx=i: (lambda: weights[idx].main_grad) if (wgrad_accumulation_mask is None or wgrad_accumulation_mask[idx]) else (lambda: None))() |
| for i in range(num_gemms) |
| def make_grad_func(idx): |
|     if wgrad_accumulation_mask is None or wgrad_accumulation_mask[idx]: |
|         return lambda: weights[idx].main_grad |
|     else: |
|         return lambda: None |
| ctx.main_grad_funcs = [make_grad_func(i) for i in range(num_gemms)] |
| save_original_input, |
| *weights_and_biases, |
| ) -> torch.Tensor: |
| # pylint: disable=missing-function-docstring |
syntax: typo: m_splits_on_devie should be m_splits_on_device
| # pylint: disable=missing-function-docstring |
| m_splits_on_device = m_splits.is_cuda |
| assert fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(), "Only MXFP8 is supported when m_splits is on devie" |
| # Cannot split because the m_splits is not available on host. |
syntax: assertion message contains typo: devie should be device
| assert fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(), "Only MXFP8 is supported when m_splits is on devie" |
| # Cannot split because the m_splits is not available on host. |
| assert fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(), "Only MXFP8 is supported when m_splits is on device" |
| if ctx.use_bias: |
| grad_output_mats = torch.split(grad_output_view, ctx.m_splits) |
| assert not ctx.m_splits_on_devie, "bias is not supported when m_splits is on devie" |
syntax: assertion message contains typo: devie should be device
| if ctx.use_bias: |
| grad_output_mats = torch.split(grad_output_view, ctx.m_splits) |
| assert not ctx.m_splits_on_devie, "bias is not supported when m_splits is on devie" |
| assert not ctx.m_splits_on_device, "bias is not supported when m_splits is on device" |
| cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits |
| ) |
| assert ctx.fp8 and ctx.fp8_recipe.mxfp8(), "Only MXFP8 is supported when m_splits is on devie" |
syntax: assertion message contains typo: devie should be device
| assert ctx.fp8 and ctx.fp8_recipe.mxfp8(), "Only MXFP8 is supported when m_splits is on devie" |
| assert ctx.fp8 and ctx.fp8_recipe.mxfp8(), "Only MXFP8 is supported when m_splits is on device" |
| # Device buffer for cutlass arguments and kernel |
| torch.empty(get_cutlass_grouped_gemm_workspace_size_bytes(), dtype=torch.uint8, device="cuda"), |
| # TODO: Only allocate pinned buffer when cuda graph is enabled |
| # Host pinned buffer for the source of H2D copy of cutlass arguments |
style: pinned memory buffer allocated unconditionally wastes memory
The comment says "TODO: Only allocate pinned buffer when cuda graph is enabled" but the buffer is always allocated. Pinned memory is a limited resource (typically capped at a fraction of system RAM).
Consider lazy allocation or checking torch.cuda.is_current_stream_capturing() state.
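A hedged sketch of that lazy-allocation idea using only public PyTorch APIs; the capacity constant and the choice of element type are assumptions, not this PR's code:

```python
import torch

_PINNED_CAPACITY_BYTES = 64 * 1024  # assumed capacity, analogous to NVTE_CUTLASS_HOST_PINNED_U64_CAPACITY
_pinned_host_buffer = None          # allocated on first use rather than unconditionally

def get_pinned_host_buffer() -> torch.Tensor:
    """Allocate the pinned staging buffer lazily, on the first call that needs it."""
    global _pinned_host_buffer
    if _pinned_host_buffer is None:
        _pinned_host_buffer = torch.empty(_PINNED_CAPACITY_BYTES, dtype=torch.uint8, pin_memory=True)
    return _pinned_host_buffer

def capture_in_progress() -> bool:
    """True while a CUDA Graph is being captured on the current stream."""
    return torch.cuda.is_current_stream_capturing()
```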
| NVTE_CHECK(pinned_host_buffer_index + num_gemms * 2 <= workspace[1].size(0), |
| "Pinned host buffer out of bounds, please increase the capacity by setting " |
| "NVTE_CUTLASS_HOST_PINNED_U64_CAPACITY. " |
| "Current buffer size: ", |
| workspace[1].size(0)); |
| inputB_and_SF_addrs = workspace[1].narrow(0, pinned_host_buffer_index, num_gemms * 2); |
| pinned_host_buffer_index += num_gemms * 2; |
logic: buffer index never resets causing accumulation across calls
The pinned_host_buffer_index increments on every call but is never reset to 0. After multiple forward/backward passes, it will:
- Exceed workspace bounds
- Trigger the error check on line 628
For CUDA Graph, the index should reset at graph capture start or be managed per-graph instance.
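A toy model of the failure mode, in plain Python rather than the extension code: without a reset the index keeps growing across training steps until the capacity check trips, while resetting it per iteration (or per graph capture) keeps it bounded.

```python
PINNED_U64_CAPACITY = 256  # illustrative capacity
NUM_GEMMS = 32             # each call consumes num_gemms * 2 slots (B and SF addresses)

index = 0
for step in range(10):     # repeated forward/backward calls
    if index + NUM_GEMMS * 2 > PINNED_U64_CAPACITY:
        print(f"step {step}: pinned host buffer out of bounds at index {index}")
        break
    index += NUM_GEMMS * 2
# Resetting `index = 0` at the start of each iteration (or when graph capture
# begins) lets the same capacity be reused indefinitely.
```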
Description
Introduces a CUTLASS-based grouped GEMM implementation that reads m_splits directly on the device.
This optimization removes the need for device-to-host data transfers and synchronization in MCore, while allowing the number of quantization kernels to be reduced to one.
The kernel is fully compatible with CUDA Graphs.
Key points:
• Does not break the existing API. The operator now accepts m_splits as either a torch.Tensor (on CPU or GPU) or a Python list.
• Reduces CPU overhead, especially for large expert counts, by using a single quantization kernel instead of one per GEMM.
• Currently supports only MXFP8 (see the usage sketch below).
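A usage sketch of the new path, under the assumption that `GroupedLinear` is constructed and called as shown; shapes, dtypes, and the recipe setup are illustrative, and the device-initiated branch requires a Blackwell GPU with MXFP8:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

num_gemms, in_features, out_features = 4, 1024, 4096
layer = te.GroupedLinear(num_gemms, in_features, out_features,
                         bias=False, params_dtype=torch.bfloat16).cuda()

# Per-expert row counts; each split is a multiple of 128 for the MXFP8 device path.
m_splits = torch.tensor([256, 128, 384, 256], dtype=torch.int32, device="cuda")
inp = torch.randn(1024, in_features, device="cuda", dtype=torch.bfloat16)  # 1024 == m_splits.sum()

with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    # A CUDA m_splits tensor selects the device-initiated CUTLASS path;
    # passing a Python list (or a CPU tensor) keeps the existing cuBLAS path.
    out = layer(inp, m_splits)
```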
Type of change
Changes
Please list the changes introduced in this PR:
- `m_splits` changed from `List[int]` to `torch.Tensor`, but it can still run correctly with `List[int]` (internally converted to a tensor)
- Added `te_general_device_initiated_grouped_gemm`

Unit Test
pytest -v -s tests/pytorch/test_numerics.py::test_grouped_linear_accuracy_cutlass_device

Checklist: