enable message size larger than 4GB for all_to_all_v/all_gather_v (#228)
ghostplant authored Mar 23, 2024
1 parent 20df39d commit 338fa08
Showing 3 changed files with 13 additions and 11 deletions.
16 changes: 8 additions & 8 deletions tutel/custom/custom_kernel.cpp
@@ -463,10 +463,10 @@ static torch::Tensor& nccl_stream_acquire(torch::Tensor &tensor, int idx) {
 static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &in_sizes_, const torch::Tensor &out_sizes_) {
   AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL");

-  auto in_sizes_cpu = in_sizes_.to(torch::kCPU).to(torch::kInt32);
-  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt32);
-  auto* in_sizes = (unsigned int*)in_sizes_cpu.data_ptr();
-  auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
+  auto in_sizes_cpu = in_sizes_.to(torch::kCPU).to(torch::kInt64);
+  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt64);
+  auto* in_sizes = (unsigned long long*)in_sizes_cpu.data_ptr();
+  auto* out_sizes = (unsigned long long*)out_sizes_cpu.data_ptr();
   auto stream = at::cuda::getCurrentCUDAStream();

   for (int k = 0; k < ins.size(); ++k) {
@@ -478,7 +478,7 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
     AT_ASSERTM(size > 0, "Data type of input tensors for batch_all_to_all_v are not recognized.");
     AT_ASSERTM(k == 0 || ins[0].numel() == ins[k].numel(), "Tensor instances within batch_all_to_all_v are supposed to share same length.");

-    int in_offset = 0, out_offset = 0;
+    unsigned long long in_offset = 0, out_offset = 0;
     for (int i = 0; i < shared_world_size; ++i) {
       ncclSend((char*)in_buff + in_offset, in_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
       ncclRecv((char*)out_buff + out_offset, out_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
@@ -492,8 +492,8 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
 static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &out_sizes_) {
   AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL");

-  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt32);
-  auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
+  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt64);
+  auto* out_sizes = (unsigned long long*)out_sizes_cpu.data_ptr();
   auto stream = at::cuda::getCurrentCUDAStream();

   for (int k = 0; k < ins.size(); ++k) {
@@ -505,7 +505,7 @@ static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std:
     AT_ASSERTM(size > 0, "Data type of input tensors for batch_all_gather_v are not recognized.");
     AT_ASSERTM(k == 0 || ins[0].numel() == ins[k].numel(), "Tensor instances within batch_all_gather_v are supposed to share same length.");

-    int out_offset = 0;
+    unsigned long long out_offset = 0;
     for (int i = 0; i < shared_world_size; ++i) {
       if (out_sizes[shared_world_rank])
         ncclSend((char*)in_buff, out_sizes[shared_world_rank] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
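Why the widening matters: the per-peer byte counts handed to ncclSend/ncclRecv are in_sizes[i] * size, and in_offset/out_offset accumulate those products across peers, so 32-bit unsigned storage wraps once a payload reaches 2^32 bytes (4 GiB). A minimal Python sketch of that wraparound, purely illustrative and not part of the commit (the element size and counts are made up):

    # Emulate the old 32-bit unsigned size arithmetic to show where 4 GiB breaks it.
    elem_bytes = 2                          # e.g. one fp16 element
    count_to_peer = 3 * 1024**3             # 3 Gi elements destined for a single peer
    true_bytes = count_to_peer * elem_bytes     # 6 GiB really needs to be sent
    as_uint32 = true_bytes & 0xFFFFFFFF         # what an `unsigned int` field would keep
    print(true_bytes, as_uint32)            # 6442450944 vs. 2147483648 -> silently truncated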
4 changes: 2 additions & 2 deletions tutel/examples/nccl_all_to_all_v.py
@@ -12,10 +12,10 @@

 if parallel_env.global_rank == 0:
     input = torch.tensor([10, 10, 10, 10, 10], device=local_device)
-    send_counts = torch.tensor([1, 4], dtype=torch.int32, device=local_device)
+    send_counts = torch.tensor([1, 4], dtype=torch.int64, device=local_device)
 else:
     input = torch.tensor([20, 20, 20], device=local_device)
-    send_counts = torch.tensor([2, 1], dtype=torch.int32, device=local_device)
+    send_counts = torch.tensor([2, 1], dtype=torch.int64, device=local_device)

 print(f'Device-{parallel_env.global_rank} sends: {[input,]}')

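As a hand-worked check of what this two-rank example exchanges (my single-process emulation, not output captured from the commit): each rank splits its input by its send_counts, ships chunk i to rank i, and the receiver concatenates chunks in source-rank order, matching the offset loop in the C++ kernel above. A torch sketch of that semantics, runnable on CPU with no NCCL involved:

    import torch

    # Single-process emulation of the 2-rank all_to_all_v example above.
    inputs = [torch.tensor([10, 10, 10, 10, 10]), torch.tensor([20, 20, 20])]
    send_counts = [[1, 4], [2, 1]]          # per-destination counts, as in the example

    # Destination rank d receives chunk d from every source rank, in source-rank order.
    chunks = [list(x.split(c)) for x, c in zip(inputs, send_counts)]
    received = [torch.cat([chunks[src][dst] for src in range(2)]) for dst in range(2)]
    print(received)   # [tensor([10, 20, 20]), tensor([10, 10, 10, 10, 20])]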
4 changes: 3 additions & 1 deletion tutel/impls/communicate.py
@@ -211,7 +211,9 @@ def batch_all_to_all_v(datas, partition_sizes, group=None):
     assert type(datas) in (tuple, list), "data type for batch_all_to_all_v() is not a list of tensors"
     in_sizes = partition_sizes
     if type(in_sizes) != torch.Tensor:
-        in_sizes = torch.tensor(in_sizes, dtype=torch.int32, device=datas[0].device)
+        in_sizes = torch.tensor(in_sizes, dtype=torch.int64, device=datas[0].device)
+    else:
+        in_sizes = in_sizes.to(torch.int64)
     world_size = get_world_size(group)
     assert in_sizes.numel() == world_size
     if world_size == 1:
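The Python-side change is pure input normalization: a plain list of partition sizes becomes an int64 tensor, and an already-built tensor (including older int32 ones) is cast to int64 before the sizes reach the C++ kernel. A standalone sketch of that behaviour; the helper name is mine, not tutel's:

    import torch

    def normalize_partition_sizes(partition_sizes, device='cpu'):
        # Mirrors the updated wrapper: lists and int32 tensors both end up as int64.
        if not isinstance(partition_sizes, torch.Tensor):
            return torch.tensor(partition_sizes, dtype=torch.int64, device=device)
        return partition_sizes.to(torch.int64)

    print(normalize_partition_sizes([1, 4]).dtype)                                   # torch.int64
    print(normalize_partition_sizes(torch.tensor([2, 1], dtype=torch.int32)).dtype)  # torch.int64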
