diff --git a/tutel/custom/custom_kernel.cpp b/tutel/custom/custom_kernel.cpp
index d776cfb..0648539 100644
--- a/tutel/custom/custom_kernel.cpp
+++ b/tutel/custom/custom_kernel.cpp
@@ -463,10 +463,10 @@ static torch::Tensor& nccl_stream_acquire(torch::Tensor &tensor, int idx) {
 static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &in_sizes_, const torch::Tensor &out_sizes_) {
   AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL");
 
-  auto in_sizes_cpu = in_sizes_.to(torch::kCPU).to(torch::kInt32);
-  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt32);
-  auto* in_sizes = (unsigned int*)in_sizes_cpu.data_ptr();
-  auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
+  auto in_sizes_cpu = in_sizes_.to(torch::kCPU).to(torch::kInt64);
+  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt64);
+  auto* in_sizes = (unsigned long long*)in_sizes_cpu.data_ptr();
+  auto* out_sizes = (unsigned long long*)out_sizes_cpu.data_ptr();
   auto stream = at::cuda::getCurrentCUDAStream();
 
   for (int k = 0; k < ins.size(); ++k) {
@@ -478,7 +478,7 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
     AT_ASSERTM(size > 0, "Data type of input tensors for batch_all_to_all_v are not recognized.");
     AT_ASSERTM(k == 0 || ins[0].numel() == ins[k].numel(), "Tensor instances within batch_all_to_all_v are supposed to share same length.");
 
-    int in_offset = 0, out_offset = 0;
+    unsigned long long in_offset = 0, out_offset = 0;
     for (int i = 0; i < shared_world_size; ++i) {
       ncclSend((char*)in_buff + in_offset, in_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
       ncclRecv((char*)out_buff + out_offset, out_sizes[i] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
@@ -492,8 +492,8 @@ static void batch_all_to_all_v(const std::vector<torch::Tensor> &ins, const std:
 static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std::vector<torch::Tensor> &outs, const torch::Tensor &out_sizes_) {
   AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL");
 
-  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt32);
-  auto* out_sizes = (unsigned int*)out_sizes_cpu.data_ptr();
+  auto out_sizes_cpu = out_sizes_.to(torch::kCPU).to(torch::kInt64);
+  auto* out_sizes = (unsigned long long*)out_sizes_cpu.data_ptr();
   auto stream = at::cuda::getCurrentCUDAStream();
 
   for (int k = 0; k < ins.size(); ++k) {
@@ -505,7 +505,7 @@ static void batch_all_gather_v(const std::vector<torch::Tensor> &ins, const std:
     AT_ASSERTM(size > 0, "Data type of input tensors for batch_all_gather_v are not recognized.");
     AT_ASSERTM(k == 0 || ins[0].numel() == ins[k].numel(), "Tensor instances within batch_all_gather_v are supposed to share same length.");
 
-    int out_offset = 0;
+    unsigned long long out_offset = 0;
     for (int i = 0; i < shared_world_size; ++i) {
       if (out_sizes[shared_world_rank])
         ncclSend((char*)in_buff, out_sizes[shared_world_rank] * size, ncclInt8, i, (ncclComm_t)shared_nccl_comm, stream);
diff --git a/tutel/examples/nccl_all_to_all_v.py b/tutel/examples/nccl_all_to_all_v.py
index fcfcc51..92311ef 100755
--- a/tutel/examples/nccl_all_to_all_v.py
+++ b/tutel/examples/nccl_all_to_all_v.py
@@ -12,10 +12,10 @@
 
 if parallel_env.global_rank == 0:
     input = torch.tensor([10, 10, 10, 10, 10], device=local_device)
-    send_counts = torch.tensor([1, 4], dtype=torch.int32, device=local_device)
+    send_counts = torch.tensor([1, 4], dtype=torch.int64, device=local_device)
 else:
     input = torch.tensor([20, 20, 20], device=local_device)
-    send_counts = torch.tensor([2, 1], dtype=torch.int32, device=local_device)
+    send_counts = torch.tensor([2, 1], dtype=torch.int64, device=local_device)
 
 print(f'Device-{parallel_env.global_rank} sends: {[input,]}')
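Background note (not part of the patch): the move from int32 to int64 counts matters because a 32-bit count silently wraps once a rank exchanges more than 2**31 - 1 elements with a peer. A minimal Python sketch of the failure mode this change avoids, using a hypothetical per-peer element count:

    import torch

    elements = 3 * 10**9  # hypothetical per-peer count, beyond the int32 range
    wrapped = torch.tensor([elements]).to(torch.int32)   # integer casts wrap silently
    kept = torch.tensor([elements], dtype=torch.int64)
    print(wrapped.item())  # -1294967296: a corrupted send/recv size
    print(kept.item())     # 3000000000: preserved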