compatible with torch 2.3 and above
Your Name committed Jun 11, 2024
1 parent 8cc441a commit 846e236
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 5 additions & 1 deletion unicore/distributed/utils.py
@@ -327,7 +327,11 @@ def all_gather_list(data, group=None, max_size=16384):
         not hasattr(all_gather_list, "_buffer")
         or all_gather_list._buffer.numel() < buffer_size
     ):
-        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
+        all_gather_list._buffer = torch.tensor(
+            data=[0] * buffer_size,  # Initialize with zeros
+            dtype=torch.uint8,  # Byte tensor corresponds to uint8
+            device='cuda'  # Specify the device as CUDA
+        )
         all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
     buffer = all_gather_list._buffer
     buffer.zero_()
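As a side note (not part of the commit): the hunk above replaces the legacy torch.cuda.ByteTensor constructor, which newer PyTorch releases no longer support, with an explicit torch.tensor(...) call. A minimal sketch of the equivalent allocation, with buffer_size chosen as an illustrative value (in all_gather_list it is derived from max_size) and assuming a CUDA device is available, as in the original code:

    import torch

    buffer_size = 16384  # illustrative; all_gather_list computes this from max_size

    # Replacement used in the diff:
    buf_a = torch.tensor(data=[0] * buffer_size, dtype=torch.uint8, device="cuda")

    # Equivalent alternative that skips building the intermediate Python list:
    buf_b = torch.zeros(buffer_size, dtype=torch.uint8, device="cuda")

    assert buf_a.shape == buf_b.shape and buf_a.dtype == buf_b.dtype

Since the surrounding code calls buffer.zero_() before every use, any zero- or uninitialized allocation of the right shape, dtype, and device behaves the same here.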
6 changes: 5 additions & 1 deletion unicore/trainer.py
@@ -87,7 +87,11 @@ def __init__(self, args, task, model, loss):
         self._wrapped_model = None

         if self.cuda and self.data_parallel_world_size > 1:
-            self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
+            self._grad_norm_buf = torch.tensor(
+                data=[0.0] * self.data_parallel_world_size,  # Initialize with zeros or appropriate values
+                dtype=torch.double,  # Set the desired data type
+                device='cuda'
+            )
         else:
             self._grad_norm_buf = None

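The trainer.py hunk applies the same migration to the per-rank gradient-norm buffer, swapping the removed torch.cuda.DoubleTensor constructor for torch.tensor(...). A minimal sketch (not repository code), using an illustrative world size of 4 and assuming a CUDA device:

    import torch

    world_size = 4  # illustrative data-parallel world size

    # Replacement used in the diff:
    grad_norm_buf = torch.tensor(
        data=[0.0] * world_size,
        dtype=torch.double,
        device="cuda",
    )

    # A more direct alternative with the same shape, dtype, and device:
    grad_norm_buf = torch.zeros(world_size, dtype=torch.double, device="cuda")

torch.zeros allocates the zero-filled buffer on the GPU without first materializing a Python list, which may matter for larger buffers; for a buffer of world-size length the difference is negligible.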
