diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py index 5d08cbc..08c401d 100644 --- a/axonn/intra_layer/__init__.py +++ b/axonn/intra_layer/__init__.py @@ -43,6 +43,7 @@ def gather( OVERLAP_REDUCE_SCATTER = False OVERLAP_ALL_REDUCE = False ALL_GATHER_ITERATOR = None +ALL_GATHER_DTYPE = torch.bfloat16 handles = [] pending_grad_accumulations = [] weights_cache = {} @@ -99,10 +100,10 @@ def trigger_async_all_gathers(model): assert weight.ndim == 1 output_shape = weight.shape[0] * world_size all_gathered_weight = torch.empty( - output_shape, dtype=weight.dtype, device=weight.device + output_shape, dtype=ALL_GATHER_DTYPE, device=weight.device ) handle = dist.all_gather_into_tensor( - all_gathered_weight, weight, group=process_group, async_op=True + all_gathered_weight, weight.to(ALL_GATHER_DTYPE), group=process_group, async_op=True ) weights_cache[weight] = [all_gathered_weight, handle] yield