From 20df39d58745e4a2d4a4dca1350c0684bcdb24b1 Mon Sep 17 00:00:00 2001
From: ghostplant
Date: Thu, 22 Feb 2024 10:57:51 +0000
Subject: [PATCH] narrow down NCCL transactions for all2all-v & allgather-v

---
 tutel/custom/custom_kernel.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tutel/custom/custom_kernel.cpp b/tutel/custom/custom_kernel.cpp
index 4f81796..d776cfb 100644
--- a/tutel/custom/custom_kernel.cpp
+++ b/tutel/custom/custom_kernel.cpp
@@ -469,8 +469,8 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
   auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
   auto stream = at::cuda::getCurrentCUDAStream();
 
-  ncclGroupStart();
   for (int k = 0; k < ins.size(); ++k) {
+    ncclGroupStart();
     auto* in_buff = ins[k].data_ptr();
     auto* out_buff = outs[k].data_ptr();
     auto dtype = ins[k].dtype();
@@ -485,8 +485,8 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
       in_offset += in_sizes[i] * size;
       out_offset += out_sizes[i] * size;
     }
+    ncclGroupEnd();
   }
-  ncclGroupEnd();
 }
 
 static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &out_sizes_) {
@@ -496,8 +496,8 @@ static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std:
   auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
   auto stream = at::cuda::getCurrentCUDAStream();
 
-  ncclGroupStart();
   for (int k = 0; k < ins.size(); ++k) {
+    ncclGroupStart();
     auto* in_buff = ins[k].data_ptr();
     auto* out_buff = outs[k].data_ptr();
     auto dtype = ins[k].dtype();
@@ -513,8 +513,8 @@ static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std:
       ncclRecv((char*)out_buff + out_offset, out_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
       out_offset += out_sizes[i] * size;
     }
+    ncclGroupEnd();
   }
-  ncclGroupEnd();
 }
 
 static std::vector<torch::Tensor> nccl_all_to_all_scatter_async(