diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py
index 08c401d..6c0a4f7 100644
--- a/axonn/intra_layer/__init__.py
+++ b/axonn/intra_layer/__init__.py
@@ -43,11 +43,19 @@ def gather(
 OVERLAP_REDUCE_SCATTER = False
 OVERLAP_ALL_REDUCE = False
 ALL_GATHER_ITERATOR = None
-ALL_GATHER_DTYPE = torch.bfloat16
+ALL_GATHER_DTYPE = torch.float32
+REDUCE_SCATTER_DTYPE = torch.float32
 
 handles = []
 pending_grad_accumulations = []
 weights_cache = {}
 
+def set_all_gather_dtype(dtype):
+    global ALL_GATHER_DTYPE
+    ALL_GATHER_DTYPE = dtype
+
+def set_reduce_scatter_dtype(dtype):
+    global REDUCE_SCATTER_DTYPE
+    REDUCE_SCATTER_DTYPE = dtype
 def register_handle(handle):
     # ToDo: This might be unnecesary since
diff --git a/axonn/intra_layer/communication.py b/axonn/intra_layer/communication.py
index a6c3265..1021924 100644
--- a/axonn/intra_layer/communication.py
+++ b/axonn/intra_layer/communication.py
@@ -44,12 +44,14 @@ def _gather(input_, dim, process_group=None, cache=False):
         input_ = input_.contiguous()
         # Size and dimension.
         rank = dist.get_rank(process_group)
+
+        from axonn.intra_layer import ALL_GATHER_DTYPE
 
         tensor_list = [
-            torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
+            torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(dist.get_world_size(process_group))
         ]
         tensor_list[rank] = input_
-        dist.all_gather(tensor_list, input_, group=process_group)
+        dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group)
 
         # Note: torch.cat already creates a contiguous tensor.
         output = torch.cat(tensor_list, dim=dim).contiguous()
@@ -70,17 +72,20 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
     assert input_.shape[dim] % total_chunks == 0
     tensor_shape = list(input_.shape)
     tensor_shape[dim] //= total_chunks
+
+    from axonn.intra_layer import REDUCE_SCATTER_DTYPE
+
     output = torch.empty(
-        tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
+        tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
     )
 
     if hasattr(torch.distributed, "reduce_scatter_tensor"):
         handle = torch.distributed.reduce_scatter_tensor(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
     else:
         handle = torch.distributed._reduce_scatter_base(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
 
     if overlap_comm:
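
Usage note (not part of the patch): a minimal sketch of how the new setters could be called from training code. Only set_all_gather_dtype, set_reduce_scatter_dtype, and their float32 defaults come from this diff; the specific dtype choices below are illustrative assumptions.

import torch
from axonn.intra_layer import set_all_gather_dtype, set_reduce_scatter_dtype

# Both collectives default to torch.float32 after this patch.
# Illustrative choice: communicate all-gathered weights in bfloat16 to cut
# bandwidth, while keeping the gradient reduce-scatter in float32.
set_all_gather_dtype(torch.bfloat16)
set_reduce_scatter_dtype(torch.float32)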