[JAX][Draft] Async issuing D2H memcpy for grouped_gemm group_sizes array #2213
Conversation
I think it is a good improvement for now. We should probably provide a `GroupedLayerNormMLP` VJP op, which encloses the `grouped_gemm_copy_group_sizes` function and the `use_async_d2h_group_sizes` option, so that we don't expose these two to users, as they can be pretty bug-prone.
"supported number ", max_num_gemms, " to be downloaded in advance."); | ||
host_num_gemms = num_gemms; | ||
// Wait for current compute stream to finish | ||
cudaStream_t compute_stream_0 = nvte_get_compute_stream(0); |
@mingxu1067 could you check if this causes the same stream sync issue as last time, when we used the `compute_stream(0)` instead of the stream given by XLA?
Just a note: this part follows the logic in https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/gemm/cublaslt_gemm.cu#L915
```cpp
auto init = [&]() {
  // Event used to signal completion of the async D2H copy of group_sizes
  NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event));
  // Pinned host buffer so the D2H memcpy of up to max_num_gemms sizes can be issued asynchronously
  NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms));
};
```
If this causes any issues, we could consider moving this allocation into the FFI prepare phase.
/te-ci JAX L0
Description
This is a draft PR for saving some work in progress and for discussion.

Recently we used TE/JAX's `grouped_gemm()` interface for a MoE model's inference. Nsys shows a GPU bubble while `grouped_gemm()` is copying the `group_sizes` array from device to host. This was a known issue when we designed the `grouped_gemm()` interface. Its performance impact on training and on the inference prefill stage is relatively small, but it cannot be ignored in the inference decode stage. This draft aims to partially address the bubble issue.

Our target model uses MLP-MoE, i.e., each expert is an MLP layer. After fusing GEMMs, each MLP-MoE layer needs two `grouped_gemm()` calls with the same `group_sizes` array. This PR allows issuing an async D2H copy of the `group_sizes` array before entering `grouped_gemm()`, so that `grouped_gemm()` can reuse the downloaded `group_sizes`. We have validated the correctness of this implementation in our target model.

This PR does not solve the issue of `grouped_gemm()` breaking CUDA graph capture, since the async copy mode still needs to call `cudaEventSynchronize()`. Furthermore, in our implementation for the target model, the D2H memcpy does not overlap with the operations that copy and dispatch tokens to experts, since those JAX-native operations are captured and executed in a CUDA graph, while the async D2H copy does not support CUDA graph.

@phu0ngng @mingxu1067 Please let me know your comments and suggestions. Much appreciated!
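For clarity, here is a minimal sketch of the intended JAX-side call pattern described above. The import path, the `moe_mlp` wrapper, the GELU activation, and the exact argument order of `grouped_gemm()` and `grouped_gemm_copy_group_sizes()` are assumptions for illustration and may differ from the actual TE/JAX API in this PR:

```python
import jax
import transformer_engine.jax as te  # import path assumed

def moe_mlp(tokens, w_up, w_down, group_sizes):
    # Issue the async D2H copy of group_sizes once, before entering grouped_gemm().
    te.grouped_gemm_copy_group_sizes(group_sizes)
    # Both grouped GEMMs of the fused MLP expert reuse the already-downloaded
    # group_sizes instead of each performing its own blocking D2H copy.
    h = te.grouped_gemm(tokens, w_up, group_sizes, use_async_d2h_group_sizes=True)
    h = jax.nn.gelu(h)
    return te.grouped_gemm(h, w_down, group_sizes, use_async_d2h_group_sizes=True)
```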
Type of change
Changes
- Added `GroupedGemmCopySizesPrimitive` for async copying of `group_sizes` from GPU to host
- Added the `use_async_d2h_group_sizes` option for `grouped_gemm()`; the default value is `False`, so the original code path will be used by default

Checklist: