From 846e23684f7ba071798336cc71df873928db2145 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Tue, 11 Jun 2024 16:16:42 +0800
Subject: [PATCH] compatible with torch 2.3 and above

---
 unicore/distributed/utils.py | 6 +++++-
 unicore/trainer.py           | 6 +++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/unicore/distributed/utils.py b/unicore/distributed/utils.py
index 278819a..43980ed 100644
--- a/unicore/distributed/utils.py
+++ b/unicore/distributed/utils.py
@@ -327,7 +327,11 @@ def all_gather_list(data, group=None, max_size=16384):
         not hasattr(all_gather_list, "_buffer")
         or all_gather_list._buffer.numel() < buffer_size
     ):
-        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
+        all_gather_list._buffer = torch.tensor(
+            data=[0] * buffer_size,  # Initialize with zeros
+            dtype=torch.uint8,  # Byte tensor corresponds to uint8
+            device='cuda'  # Specify the device as CUDA
+        )
         all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
     buffer = all_gather_list._buffer
     buffer.zero_()
diff --git a/unicore/trainer.py b/unicore/trainer.py
index cf85c4f..07ef447 100644
--- a/unicore/trainer.py
+++ b/unicore/trainer.py
@@ -87,7 +87,11 @@ def __init__(self, args, task, model, loss):
         self._wrapped_model = None
 
         if self.cuda and self.data_parallel_world_size > 1:
-            self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
+            self._grad_norm_buf = torch.tensor(
+                data=[0.0] * self.data_parallel_world_size,  # Initialize with zeros
+                dtype=torch.double,  # Set the desired data type
+                device='cuda'
+            )
         else:
             self._grad_norm_buf = None
 
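
Note (not part of the patch): a minimal alternative sketch, assuming torch >= 2.0 with a
CUDA build. torch.zeros produces the same zero-filled CUDA tensors as the torch.tensor
calls above without first materializing a Python list, which matters for the per-rank
gather buffer of max_size * world_size bytes. The names below (buffer_size, world_size)
are illustrative stand-ins for the values used in unicore/distributed/utils.py and
unicore/trainer.py, not identifiers from the patch itself.

    import torch

    buffer_size = 16384 * 4  # example: max_size * data-parallel world size
    world_size = 4           # example data-parallel world size

    # stand-in for torch.cuda.ByteTensor(buffer_size)
    gather_buffer = torch.zeros(buffer_size, dtype=torch.uint8, device='cuda')

    # stand-in for torch.cuda.DoubleTensor(self.data_parallel_world_size)
    grad_norm_buf = torch.zeros(world_size, dtype=torch.double, device='cuda')

Both buffers are overwritten before being read (all_gather_list calls buffer.zero_() and
the gradient-norm buffer is filled on each check), so zero-initialization is safe here.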