Skip to content

Commit

Permalink
narrow down NCCL transactions for all2all-v & allgather-v
Browse files (browse the repository at this point in the history)
  • Loading branch information
ghostplant committed Feb 22, 2024
1 parent 758750d commit 20df39d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tutel/custom/custom_kernel.cpp
Original file line number Diff line number Diff line change
@@ -469,8 +469,8 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream();

- ncclGroupStart();
  for (int k = 0; k < ins.size(); ++k) {
+ ncclGroupStart();
auto* in_buff = ins[k].data_ptr();
auto* out_buff = outs[k].data_ptr();
auto dtype = ins[k].dtype();
@@ -485,8 +485,8 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
in_offset += in_sizes[i] * size;
out_offset += out_sizes[i] * size;
}
+ ncclGroupEnd();
  }
- ncclGroupEnd();
}

static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &out_sizes_) {
@@ -496,8 +496,8 @@ static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std:
auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream();

- ncclGroupStart();
  for (int k = 0; k < ins.size(); ++k) {
+ ncclGroupStart();
auto* in_buff = ins[k].data_ptr();
auto* out_buff = outs[k].data_ptr();
auto dtype = ins[k].dtype();
@@ -513,8 +513,8 @@ static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std:
ncclRecv((char*)out_buff + out_offset, out_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
out_offset += out_sizes[i] * size;
}
+ ncclGroupEnd();
  }
- ncclGroupEnd();
}

static std::vector<torch::Tensor> nccl_all_to_all_scatter_async(
Expand Down

0 comments on commit 20df39d

Please sign in to comment.