diff --git a/tutel/custom/custom_kernel.cpp b/tutel/custom/custom_kernel.cpp
index 4f81796..d776cfb 100644
--- a/tutel/custom/custom_kernel.cpp
+++ b/tutel/custom/custom_kernel.cpp
@@ -469,8 +469,8 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
   auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
   auto stream = at::cuda::getCurrentCUDAStream();
 
-  ncclGroupStart();
   for (int k = 0; k < ins.size(); ++k) {
+    ncclGroupStart();
     auto* in_buff = ins[k].data_ptr();
     auto* out_buff = outs[k].data_ptr();
     auto dtype = ins[k].dtype();
@@ -485,8 +485,8 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
       in_offset += in_sizes[i] * size;
       out_offset += out_sizes[i] * size;
     }
+    ncclGroupEnd();
   }
-  ncclGroupEnd();
 }
 
 static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &out_sizes_) {
@@ -496,8 +496,8 @@ static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std:
   auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
   auto stream = at::cuda::getCurrentCUDAStream();
 
-  ncclGroupStart();
   for (int k = 0; k < ins.size(); ++k) {
+    ncclGroupStart();
     auto* in_buff = ins[k].data_ptr();
     auto* out_buff = outs[k].data_ptr();
     auto dtype = ins[k].dtype();
@@ -513,8 +513,8 @@ static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std:
       ncclRecv((char*)out_buff + out_offset, out_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
       out_offset += out_sizes[i] * size;
     }
+    ncclGroupEnd();
   }
-  ncclGroupEnd();
 }
 
 static std::vector<torch::Tensor> nccl_all_to_all_scatter_async(
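
Note (illustrative sketch, not part of the patch): the snippet below shows the per-tensor NCCL grouping pattern that the diff switches to, pulled out into a standalone helper for readability. The helper name and its parameters (in_buff, out_buff, byte_in_sizes, byte_out_sizes, world_size, comm, stream) are placeholders for the file's shared_nccl_comm, shared_world_size, the current CUDA stream, and the size arrays already scaled to bytes; the real code keeps this logic inline inside batch_all_to_all_v.

// Illustrative sketch only (not the patched file): one ncclGroupStart()/
// ncclGroupEnd() pair per tensor, so that tensor's sends and receives are
// fused and issued together when the group closes.
// Placeholders: comm ~ shared_nccl_comm, world_size ~ shared_world_size,
// byte_in_sizes/byte_out_sizes ~ in_sizes[i] * size / out_sizes[i] * size.
#include <nccl.h>
#include <cuda_runtime.h>

static void all_to_all_v_single(const char* in_buff, char* out_buff,
                                const unsigned int* byte_in_sizes,
                                const unsigned int* byte_out_sizes,
                                int world_size, ncclComm_t comm,
                                cudaStream_t stream) {
  ncclGroupStart();                      // open one group for this tensor
  size_t in_offset = 0, out_offset = 0;
  for (int i = 0; i < world_size; ++i) {
    // Exchange this tensor's slice with peer i; byte counts travel as ncclInt8.
    ncclSend(in_buff + in_offset, byte_in_sizes[i], ncclInt8, i, comm, stream);
    ncclRecv(out_buff + out_offset, byte_out_sizes[i], ncclInt8, i, comm, stream);
    in_offset += byte_in_sizes[i];
    out_offset += byte_out_sizes[i];
  }
  ncclGroupEnd();                        // close the group; the p2p ops launch here
}

Compared with the layout before the patch (a single group wrapping the whole loop over tensors), this issues each tensor's all-to-all as its own fused NCCL operation.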