diff --git a/axonn/axonn.py b/axonn/axonn.py
index bafc46f..e6c2f97 100644
--- a/axonn/axonn.py
+++ b/axonn/axonn.py
@@ -97,6 +97,7 @@ def init(
     G_data: int,
     G_intra_r: int = 1,
     G_intra_c: int = 1,
+    G_intra_d: int = 1,
     gpus_per_node: Optional[int] = None,
     mixed_precision=False,
     float16_allreduce=True,
@@ -128,13 +129,14 @@ def init(
     global comm_handle, is_initialized, computation_dtype, _float16_all_reduce
     global _cpu_offload, _use_bf16, _mixed_precision, loss_scale
     comm_handle = communication_handle(
-        G_inter, G_data, G_intra_r, G_intra_c, gpus_per_node
+        G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
     )
     config.G_inter = G_inter
     config.G_data = G_data
-    config.G_intra = G_intra_r * G_intra_c
+    config.G_intra = G_intra_r * G_intra_c * G_intra_d
     config.G_intra_r = G_intra_r
     config.G_intra_c = G_intra_c
+    config.G_intra_d = G_intra_d
     config.inter_layer_parallel_rank = comm_handle.inter_layer_parallel_rank
     config.data_parallel_rank = comm_handle.data_parallel_rank
     config.intra_layer_parallel_rank = comm_handle.intra_layer_parallel_rank
diff --git a/axonn/communication.py b/axonn/communication.py
index 060f1ef..4a29554 100644
--- a/axonn/communication.py
+++ b/axonn/communication.py
@@ -4,8 +4,15 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 import os
-from mpi4py import MPI
+
+try:
+    from mpi4py import MPI
+
+    MPI4PY = True
+except ImportError:
+    MPI4PY = False
 import torch
+import numpy as np


 class communication_handle:
@@ -15,7 +22,13 @@
     """

     def __init__(
-        self, G_inter: int, G_data: int, G_intra_r=1, G_intra_c=1, gpus_per_node=None
+        self,
+        G_inter: int,
+        G_data: int,
+        G_intra_r=1,
+        G_intra_c=1,
+        G_intra_d=1,
+        gpus_per_node=None,
     ):
         """Constructor for the communication handle

@@ -24,22 +37,30 @@ def __init__(
             G_data (int): number of GPUs used for data parallelism
             gpus_per_node (int, optional): number of GPUs per node, if not
             provided this is inferred using pytorch
-            G_intra (int): degree of intra-layer parallelism. Note that
-            the user is supposed to implement their intra-layer parallel
-            kernels. AxoNN will just create communication groups for
-            intra-layer parallelism
+            G_intra_r (int): number of GPUs in the row intra-layer parallel dimension
+            G_intra_c (int): number of GPUs in the column intra-layer parallel dimension
+            G_intra_d (int): number of GPUs in the depth intra-layer parallel dimension
         """
-        self.world_rank = MPI.COMM_WORLD.Get_rank()
-        self.world_size = MPI.COMM_WORLD.Get_size()
-        G_intra = G_intra_r * G_intra_c
+        if not torch.distributed.is_initialized():
+            assert MPI4PY, ("either install mpi4py and launch via mpirun/srun, "
+                            "or initialize torch.distributed outside axonn")
+            self.world_rank = MPI.COMM_WORLD.Get_rank()
+            self.world_size = MPI.COMM_WORLD.Get_size()
+        else:
+            self.world_rank = torch.distributed.get_rank()
+            self.world_size = torch.distributed.get_world_size()
+
+        G_intra = G_intra_r * G_intra_c * G_intra_d
        assert (
             G_inter * G_data * G_intra == self.world_size
-        ), "The product of G_inter and G_data should be equal to the number of GPUs"
+        ), ("The product of G_inter, G_intra_r, G_intra_c, G_intra_d, and "
+            f"G_data should be equal to the number of GPUs - {self.world_size}")
         self.G_intra = G_intra
         self.G_inter = G_inter
         self.G_data = G_data
         self.G_intra_r = G_intra_r
         self.G_intra_c = G_intra_c
+        self.G_intra_d = G_intra_d

         # infer gpus per node if not provided
         self.gpus_per_node = (
@@ -51,15 +72,34 @@ def __init__(
         self.intra_layer_column_parallel_rank = (
             self.intra_layer_parallel_rank % G_intra_c
         )
-        self.intra_layer_row_parallel_rank = self.intra_layer_parallel_rank // G_intra_c
+        self.intra_layer_row_parallel_rank = (
+            self.intra_layer_parallel_rank // G_intra_c
+        ) % G_intra_r
+        self.intra_layer_depth_parallel_rank = self.intra_layer_parallel_rank // (
+            G_intra_c * G_intra_r
+        )
+
         self.inter_layer_parallel_rank = (self.world_rank // G_intra) % G_inter
         self.data_parallel_rank = self.world_rank // (G_inter * G_intra)

         # create communicator for point-to-point(MPI) communication
         colour = self.intra_layer_parallel_rank + G_intra * self.data_parallel_rank
-        # this needs to be checked
-        self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
-        assert self.p2p_mpi_comm.Get_size() == G_inter
+
+        if G_inter > 1:
+            # this needs to be checked
+            if MPI4PY:
+                self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
+                assert self.p2p_mpi_comm.Get_size() == G_inter
+            else:
+                self.p2p_mpi_comm = None
+                print(
+                    "Warning: AxoNN's implementation of inter-layer "
+                    "parallelism (pipelining) requires mpi4py, which wasn't found. "
+                    "You will have to use an external implementation "
+                    "of pipeline parallelism."
+                )
+        else:
+            self.p2p_mpi_comm = None

         # create communicator for collective (NCCL) communication
         if not torch.distributed.is_initialized():
@@ -89,37 +129,63 @@ def __init__(
             self.coll_nccl_comm = ith_jth_data_parallel_group

         # create communicators for intra-layer parallelism
-        for i in range(G_data):
-            for j in range(G_inter):
+        for i_ in range(G_data):
+            for j_ in range(G_inter):
                 ranks_in_ith_jth_intra_layer_group = [
-                    i * G_inter * G_intra + j * G_intra + k for k in range(G_intra)
+                    i_ * G_inter * G_intra + j_ * G_intra + k for k in range(G_intra)
                 ]
-
                 ith_jth_intra_layer_group = torch.distributed.new_group(
                     ranks=ranks_in_ith_jth_intra_layer_group, backend="nccl"
                 )
                 if self.world_rank in ranks_in_ith_jth_intra_layer_group:
                     self.intra_layer_group = ith_jth_intra_layer_group
+
+                assert (
+                    len(ranks_in_ith_jth_intra_layer_group)
+                    == G_intra_r * G_intra_c * G_intra_d
+                )
+
+                ranks_in_ith_jth_intra_layer_group = np.array(
+                    ranks_in_ith_jth_intra_layer_group
+                ).reshape(G_intra_d, G_intra_r, G_intra_c)

                 # form row and column tensor parallel groups
-                # G_intra_r x G_intra_c
-                assert len(ranks_in_ith_jth_intra_layer_group) == G_intra_r * G_intra_c
-                intra_layer_ranks = ranks_in_ith_jth_intra_layer_group
+                # G_intra_d x G_intra_r x G_intra_c
+
+                # inner
+                for i in range(G_intra_d):
+                    for j in range(G_intra_r):
+                        group_members = list(
+                            ranks_in_ith_jth_intra_layer_group[i, j, :]
+                        )
+                        group = torch.distributed.new_group(
+                            ranks=group_members, backend="nccl"
+                        )
+                        if self.world_rank in group_members:
+                            self.inner_intra_layer_parallel_group = group
+
+                # outer
+                for i in range(G_intra_d):
+                    for j in range(G_intra_c):
+                        group_members = list(
+                            ranks_in_ith_jth_intra_layer_group[i, :, j]
+                        )
+                        group = torch.distributed.new_group(
+                            ranks=group_members, backend="nccl"
+                        )
+                        if self.world_rank in group_members:
+                            self.outer_intra_layer_parallel_group = group
+
+                # depth
                 for i in range(G_intra_r):
-                    offset = i * G_intra_c
-                    group_members = intra_layer_ranks[offset : offset + G_intra_c]
-                    group = torch.distributed.new_group(
-                        ranks=group_members, backend="nccl"
-                    )
-                    if self.world_rank in group_members:
-                        self.inner_intra_layer_parallel_group = group
-
-                for i in range(G_intra_c):
-                    group_members = intra_layer_ranks[i::G_intra_c]
-                    group = torch.distributed.new_group(
-                        ranks=group_members, backend="nccl"
-                    )
-                    if self.world_rank in group_members:
-                        self.outer_intra_layer_parallel_group = group
+                    for j in range(G_intra_c):
+                        group_members = list(
+                            ranks_in_ith_jth_intra_layer_group[:, i, j]
+                        )
+                        group = torch.distributed.new_group(
+                            ranks=group_members, backend="nccl"
+                        )
+                        if self.world_rank in group_members:
+                            self.depth_intra_layer_parallel_group = group

     def _torch_to_mpi(self, tensor: torch.Tensor):
         """Converts a PyTorch tensor into an mpi4py compatible array using its
diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py
index cbe63a7..f4c9edf 100644
--- a/axonn/intra_layer/__init__.py
+++ b/axonn/intra_layer/__init__.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from .fully_connected import Linear  # noqa: F401
 from .conv import Conv2d as Tensor_Parallel_Conv2d  # noqa: F401

@@ -5,21 +6,154 @@
 from .gradient_normalization import clip_grad_norm_  # noqa: F401

 from axonn import axonn as ax
+import torch
+import torch.distributed as dist


-def drop(x, transpose=False, dim=-1):
+def drop(x, transpose=False, dim=-1, batch_dim=0):
     if not transpose:
         group = ax.comm_handle.inner_intra_layer_parallel_group
     else:
         group = ax.comm_handle.outer_intra_layer_parallel_group

-    return Drop.apply(x, group, dim)
+    x = Drop.apply(x, group, dim)
+    x = Drop.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
+    return x


-def gather(x, transpose=False, dim=-1):
+def gather(x, transpose=False, dim=-1, batch_dim=0):
     if not transpose:
         group = ax.comm_handle.inner_intra_layer_parallel_group
     else:
         group = ax.comm_handle.outer_intra_layer_parallel_group

-    return Gather.apply(x, group, dim)
+    x = Gather.apply(x, group, dim)
+    x = Gather.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
+    return x
+
+
+OVERLAP_REDUCE_SCATTER = False
+OVERLAP_ALL_REDUCE = False
+CACHE_WEIGHTS = False
+ALL_GATHER_ITERATOR = None
+handles = []
+pending_grad_accumulations = []
+weights_cache = {}
+
+
+def register_handle(handle):
+    # ToDo: This might be unnecessary since
+    # we are calling synchronize in clear_handles
+    global handles
+    handles.append(handle)
+
+
+def clear_handles():
+    global handles
+    torch.cuda.synchronize()
+    handles = []
+
+
+def accumulate_later(param, grad):
+    global pending_grad_accumulations
+    pending_grad_accumulations.append([param, grad])
+
+
+@torch.no_grad()
+def accumulate():
+    global pending_grad_accumulations
+    for param, grad in pending_grad_accumulations:
+        if param.grad is None:
+            param.grad = grad
+        else:
+            param.grad.add_(grad)
+
+    pending_grad_accumulations = []
+
+
+def clear_weights_cache():
+    global weights_cache
+    weights_cache = {}
+
+
+def trigger_async_all_gathers(model):
+    global weights_cache
+    for module in model.modules():
+        if isinstance(module, Linear):
+            weight = module.weight
+            if weight not in weights_cache:
+                # only trigger all gathers if not in cache
+                process_group = module.depth_group
+                world_size = dist.get_world_size(process_group)
+                if world_size == 1:
+                    all_gathered_weight = weight
+                    handle = None
+                else:
+                    assert weight.ndim == 1
+                    output_shape = weight.shape[0] * world_size
+                    all_gathered_weight = torch.empty(
+                        output_shape, dtype=weight.dtype, device=weight.device
+                    )
+                    handle = dist.all_gather_into_tensor(
+                        all_gathered_weight, weight, group=process_group, async_op=True
+                    )
+                weights_cache[weight] = [all_gathered_weight, handle]
+            yield
+
+
+def enqueue_next_all_gather():
+    global ALL_GATHER_ITERATOR
+    assert ALL_GATHER_ITERATOR is not None
+    try:
+        next(ALL_GATHER_ITERATOR)
+    except StopIteration:
+        pass
+
+
+def retrieve_all_gathered_weight(weight):
+    global CACHE_WEIGHTS, ALL_GATHER_ITERATOR
+    assert weight in weights_cache
+    all_gathered_weight, handle = weights_cache[weight]
+    if ALL_GATHER_ITERATOR is not None:
+        enqueue_next_all_gather()
+    return all_gathered_weight, handle
+
+
+@contextmanager
+def optimize_communication(
+    overlap_all_reduce=True,
+    overlap_reduce_scatter=False,
+    cache_weights=False,
+    overlap_all_gather=False,
+    model=None,
+    *args,
+    **kwargs
+):
+    global OVERLAP_ALL_REDUCE, OVERLAP_REDUCE_SCATTER, CACHE_WEIGHTS
+    global ALL_GATHER_ITERATOR
+    OVERLAP_ALL_REDUCE = overlap_all_reduce
+    OVERLAP_REDUCE_SCATTER = overlap_reduce_scatter
+
+    CACHE_WEIGHTS = cache_weights
+
+    if overlap_all_gather:
+        if model is None:
+            raise ValueError(
+                "You need to pass your model as an argument - "
+                "optimize_communication(..., model=model, ...) "
+                "if overlap_all_gather is True"
+            )
+        assert (
+            cache_weights
+        ), "all gathers can only be overlapped if cache_weights is True"
+        ALL_GATHER_ITERATOR = trigger_async_all_gathers(model)
+        enqueue_next_all_gather()
+
+    try:
+        yield None
+    finally:
+        clear_handles()
+        accumulate()
+        OVERLAP_ALL_REDUCE = False
+        OVERLAP_REDUCE_SCATTER = False
+        ALL_GATHER_ITERATOR = None
diff --git a/axonn/intra_layer/communication.py b/axonn/intra_layer/communication.py
index 62e765c..0056c4d 100644
--- a/axonn/intra_layer/communication.py
+++ b/axonn/intra_layer/communication.py
@@ -1,10 +1,16 @@
 import torch.distributed as dist
 import torch
+import axonn


-def _all_reduce(input_, process_group=None):
+def _all_reduce(input_, process_group=None, overlap_comm=False):
+    input_ = input_.contiguous()
     if dist.get_world_size(process_group) > 1:
-        dist.all_reduce(input_.contiguous(), group=process_group)
+        handle = dist.all_reduce(
+            input_.contiguous(), group=process_group, async_op=overlap_comm
+        )
+        if overlap_comm:
+            axonn.intra_layer.register_handle(handle)
     return input_


@@ -21,24 +27,61 @@ def _drop(input_, dim, process_group=None):
     return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)


-def _gather(input_, dim, process_group=None):
+def _gather(input_, dim, process_group=None, cache=False):
     """Gather tensors and concatenate them along a dimension"""
     if dist.get_world_size(process_group) == 1:
         return input_

-    input_ = input_.contiguous()
-    # Size and dimension.
-    rank = dist.get_rank(process_group)
+    if input_ in axonn.intra_layer.weights_cache:
+        output, handle = axonn.intra_layer.retrieve_all_gathered_weight(input_)
+        if handle is not None:
+            handle.wait()
+            axonn.intra_layer.weights_cache[input_][1] = None
+    else:
+        input_ = input_.contiguous()
+        # Size and dimension.
+        rank = dist.get_rank(process_group)
+
+        tensor_list = [
+            torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
+        ]
+        tensor_list[rank] = input_
+        dist.all_gather(tensor_list, input_, group=process_group)
+
+        # Note: torch.cat already creates a contiguous tensor.
+        output = torch.cat(tensor_list, dim=dim).contiguous()
+
+        if cache:
+            axonn.intra_layer.weights_cache[input_] = output, None
+
+    return output

-    tensor_list = [
-        torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
-    ]
-    tensor_list[rank] = input_
-    dist.all_gather(tensor_list, input_, group=process_group)

-    # Note: torch.cat already creates a contiguous tensor.
-    output = torch.cat(tensor_list, dim=dim).contiguous()
+
+def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
+    assert dim == 0, "reduce scatter only implemented for dim=0"
+
+    if dist.get_world_size(process_group) == 1:
+        return input_
+
+    total_chunks = dist.get_world_size(process_group)
+    assert input_.shape[dim] % total_chunks == 0
+    tensor_shape = list(input_.shape)
+    tensor_shape[dim] //= total_chunks
+    output = torch.empty(
+        tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
+    )
+
+    if hasattr(torch.distributed, "reduce_scatter_tensor"):
+        handle = torch.distributed.reduce_scatter_tensor(
+            output, input_, group=process_group, async_op=overlap_comm
+        )
+    else:
+        handle = torch.distributed._reduce_scatter_base(
+            output, input_, group=process_group, async_op=overlap_comm
+        )
+    if overlap_comm:
+        axonn.intra_layer.register_handle(handle)
     return output


@@ -58,17 +101,24 @@ def backward(ctx, grad_output):

 class BackwardAllReduce(torch.autograd.Function):
     @staticmethod
-    def symbolic(graph, input_, process_group=None):
+    def symbolic(graph, input_, process_group=None, overlap_comm=False):
         return input_

     @staticmethod
-    def forward(ctx, input_, process_group=None):
+    def forward(ctx, input_, process_group=None, overlap_comm=False):
         ctx.process_group = process_group
+        ctx.overlap_comm = overlap_comm
+        ctx.input = input_
         return input_

     @staticmethod
     def backward(ctx, grad_output):
-        return _all_reduce(grad_output, ctx.process_group), None
+        grad_input = _all_reduce(grad_output, ctx.process_group, ctx.overlap_comm)
+        if not ctx.overlap_comm:
+            return grad_input, None, None
+        else:
+            axonn.intra_layer.accumulate_later(ctx.input, grad_input)
+            return None, None, None


 class Drop(torch.autograd.Function):
@@ -109,3 +159,49 @@ def backward(ctx, grad_output):
             None,
             None,
         )
+
+
+class ForwardGather_BackwardReduceScatter(torch.autograd.Function):
+    @staticmethod
+    def symbolic(
+        graph,
+        input_,
+        process_group=None,
+        dim=0,
+        overlap_comm=False,
+        cache_all_gather=False,
+    ):
+        return _gather(input_, dim=dim, process_group=process_group)
+
+    @staticmethod
+    def forward(
+        ctx,
+        input_,
+        process_group=None,
+        dim=0,
+        overlap_comm=False,
+        cache_all_gather=False,
+    ):
+        assert dim == 0
+        ctx.process_group = process_group
+        ctx.dim = dim
+        ctx.overlap_comm = overlap_comm
+        ctx.input = input_
+        return _gather(
+            input_, dim=dim, process_group=process_group, cache=cache_all_gather
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        assert ctx.dim == 0
+        grad_input = _reduce_scatter(
+            grad_output,
+            dim=ctx.dim,
+            process_group=ctx.process_group,
+            overlap_comm=ctx.overlap_comm,
+        )
+        if not ctx.overlap_comm:
+            return (grad_input, None, None, None, None)
+        else:
+            axonn.intra_layer.accumulate_later(ctx.input, grad_input)
+            return None, None, None, None, None
diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py
index faa7715..d677648 100644
--- a/axonn/intra_layer/fully_connected.py
+++ b/axonn/intra_layer/fully_connected.py
@@ -1,38 +1,54 @@
-from axonn import axonn as ax
 import torch.distributed as dist
 import torch
-from .communication import Drop, Gather
 from torch.autograd import Function
 from torch.cuda.amp import custom_fwd, custom_bwd
 import math

+from axonn import axonn as ax
+import axonn
+from .communication import (
+    Drop,
+    Gather,
+    ForwardGather_BackwardReduceScatter,
+    BackwardAllReduce,
+)
+

 def divide(a, b):
     assert a % b == 0
     return a // b


+@torch.no_grad()
 def extract_local_params_from_full_params(
-    full_params, out_features_group, in_features_group
+    params, out_features_group, in_features_group, depth_group
 ):
-    params = Drop.apply(torch.t(full_params).contiguous(), out_features_group)
-    params = torch.t(params).contiguous()
     params = Drop.apply(params, in_features_group)
+    params = Drop.apply(torch.t(params).contiguous(), out_features_group)
+    params = torch.t(params).contiguous()
+    params = Drop.apply(params.reshape(-1), depth_group)  # create 1D view
     return params


 @torch.no_grad()
 def initialize_params(
-    out_features, in_features, out_features_group, in_features_group, init_method
+    out_features,
+    in_features,
+    out_features_group,
+    in_features_group,
+    depth_group,
+    init_method,
+    init_device="cuda",
 ):
-    params = torch.empty((out_features, in_features))
+    params = torch.empty((out_features, in_features), device=init_device)
     init_method(params)
     params = extract_local_params_from_full_params(
-        params, out_features_group, in_features_group
-    )
+        params, out_features_group, in_features_group, depth_group
+    ).cpu()
     return params


+@torch.no_grad()
 def default_init_method(weight):
     return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
@@ -47,12 +63,29 @@ def forward(
         forward_all_reduce_group,
         backward_all_reduce_group,
         backward_comm_async,
+        forward_comm_async,
     ):
         ctx.save_for_backward(input_, weight)
         ctx.backward_all_reduce_group = backward_all_reduce_group
         ctx.backward_comm_async = backward_comm_async
-        output = input_.matmul(weight.t())
-        dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
+        if not forward_comm_async:
+            output = input_.matmul(weight.t())
+            dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
+        else:
+            assert input_.shape[0] % 2 == 0
+            input_chunks = torch.chunk(input_, 2)  # each chunk is a view of the tensor
+            output_shape = list(input_.shape)
+            output_shape[-1] = weight.shape[0]
+            outputs = []
+            outputs.append(input_chunks[0].matmul(weight.t()))
+            handle = dist.all_reduce(
+                outputs[-1], group=forward_all_reduce_group, async_op=True
+            )
+            outputs.append(input_chunks[1].matmul(weight.t()))
+            dist.all_reduce(outputs[-1], group=forward_all_reduce_group, async_op=False)
+            handle.wait()  # this call might be unnecessary
+            output = torch.cat(outputs)
+
         return output

     @staticmethod
@@ -75,7 +108,7 @@ def backward(ctx, grad_output):
         )
         if handle and ctx.backward_comm_async:
             handle.wait()
-        return grad_input, grad_weight, None, None, None
+        return grad_input, grad_weight, None, None, None, None


 class Linear(torch.nn.Module):
@@ -88,21 +121,20 @@ def __init__(
         bias=True,
         skip_bias_add=False,
         init_method=None,
-        async_comm_in_backward_pass=True,
         **kwargs
     ):
         super(Linear, self).__init__()

         self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group
         self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group
+        self.depth_group = ax.comm_handle.depth_intra_layer_parallel_group

         self.inner_group_size = dist.get_world_size(self.inner_group)
         self.outer_group_size = dist.get_world_size(self.outer_group)
+        self.depth_group_size = dist.get_world_size(self.depth_group)

         self.in_features = in_features
         self.out_features = out_features

-        self.async_comm_in_backward_pass = async_comm_in_backward_pass
-
         if init_method is None:
             init_method = default_init_method
@@ -116,6 +148,7 @@ def __init__(
                     in_features,
                     self.outer_group,
                     self.inner_group,
+                    self.depth_group,
                     init_method,
                 )
             else:
@@ -128,6 +161,7 @@ def __init__(
                     in_features,
                     self.inner_group,
                     self.outer_group,
+                    self.depth_group,
                     init_method,
                 )

@@ -171,30 +205,47 @@ def get_output_feature_size(self):
         return self.local_out_features

     def forward(self, x, scatter_input=True, gather_output=True):
+        # gather the full weight from the depth parallel group
+        # (the reduce-scatter happens in the backward pass)
+        weight = ForwardGather_BackwardReduceScatter.apply(
+            self.weight,
+            self.depth_group,
+            0,
+            axonn.intra_layer.OVERLAP_REDUCE_SCATTER,
+            axonn.intra_layer.CACHE_WEIGHTS,
+        ).reshape(self.local_out_features, self.local_in_features)
+
         if not self.transpose:
             if scatter_input:
                 x = Drop.apply(x, self.inner_group)
+                x = Drop.apply(x, self.depth_group, 0)
             x = AsyncLinear.apply(
                 x,
-                self.weight,
+                weight,
                 self.inner_group,
                 self.outer_group,
-                self.async_comm_in_backward_pass,
+                axonn.intra_layer.OVERLAP_ALL_REDUCE,
+                False,
             )
             if gather_output:
                 x = Gather.apply(x, self.outer_group)
+                x = Gather.apply(x, self.depth_group, 0)
         else:
             if scatter_input:
                 x = Drop.apply(x, self.outer_group)
+                x = Drop.apply(x, self.depth_group, 0)
+
             x = AsyncLinear.apply(
                 x,
-                self.weight,
+                weight,
                 self.outer_group,
                 self.inner_group,
-                self.async_comm_in_backward_pass,
+                axonn.intra_layer.OVERLAP_ALL_REDUCE,
+                False,
             )
             if gather_output:
                 x = Gather.apply(x, self.inner_group)
+                x = Gather.apply(x, self.depth_group, 0)

         if self.bias is None:
             return x
@@ -202,22 +253,28 @@ def forward(self, x, scatter_input=True, gather_output=True):
             bias = self.bias
             if gather_output:
                 bias = Gather.apply(
-                    self.bias,
+                    bias,
                     self.outer_group if not self.transpose else self.inner_group,
                 )
+            else:
+                bias = BackwardAllReduce.apply(
+                    bias, self.depth_group, axonn.intra_layer.OVERLAP_REDUCE_SCATTER
+                )
             if self.skip_bias_add:
                 return x, bias
             else:
                 return x + bias

     def _is_full_weight_matrix(self, weight):
-        return (weight.size(0) == self.out_features) and (
-            weight.size(1) == self.in_features
+        return (
+            weight.ndim == 2
+            and weight.size(0) == self.out_features
+            and weight.size(1) == self.in_features
         )

     def _is_sharded_weight_matrix(self, weight):
-        return (weight.size(0) == self.local_out_features) and (
-            weight.size(1) == self.local_in_features
+        return weight.ndim == 1 and weight.size(0) == divide(
+            self.local_out_features * self.local_in_features, self.depth_group_size
         )

     @torch.no_grad()
@@ -245,9 +302,10 @@ def _modified_load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
                     self.outer_group,
                 )
                 weight = extract_local_params_from_full_params(
-                    weight, out_features_group, in_features_group
+                    weight, out_features_group, in_features_group, self.depth_group
                 )
-            state_dict[prefix + "weight"] = weight
+
+            state_dict[prefix + "weight"] = weight

         if self.bias is not None:
             bias = (
diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py
index 5fed505..29409d2 100644
--- a/axonn/tests/test_intra_layer_fc.py
+++ b/axonn/tests/test_intra_layer_fc.py
@@ -2,15 +2,22 @@
 import pytest
 from axonn import axonn as ax
 from axonn.intra_layer.communication import _drop, _gather
-from axonn.intra_layer import Linear, clip_grad_norm_
+from axonn.intra_layer import (
+    Linear,
+    clip_grad_norm_,
+    optimize_communication,
+    clear_weights_cache,
+)


 @pytest.mark.mpi
 @pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)])
-@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
+@pytest.mark.parametrize(
+    "G_intra_r, G_intra_c, G_intra_d", [(2, 1, 1), (1, 2, 1), (1, 1, 2)]
+)
 @pytest.mark.parametrize("easy_tp", [False, True])
 @pytest.mark.parametrize("bias", [False, True])
-def test_fw_pass(G_intra_r, G_intra_c, B, H, easy_tp, bias):
+def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
     # These tests are in fp-32
     torch.manual_seed(42)
     ax.init(
@@ -18,18 +25,24 @@
         G_inter=1,
         G_intra_r=G_intra_r,
         G_intra_c=G_intra_c,
+        G_intra_d=G_intra_d,
     )
     X = torch.randn(B, H).cuda() * 0.01

     inner_group = ax.comm_handle.inner_intra_layer_parallel_group
     outer_group = ax.comm_handle.outer_intra_layer_parallel_group
+    depth_group = ax.comm_handle.depth_intra_layer_parallel_group

     if not easy_tp:  # manually divide input
         X_local = _drop(
             X, 1, inner_group
         )  # divide colunns of X along the inner tensor group
+        # manually divide the batch dimension as well
+        X_local = _drop(
+            X_local, 0, depth_group
+        )  # divide rows of X along the depth tensor group
     else:
         X_local = X

@@ -46,13 +59,9 @@
     Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
     if not easy_tp:  # gather output manually
         Y_parallel = _gather(Y_local.clone(), 1, outer_group)
+        Y_parallel = _gather(Y_parallel.clone(), 0, depth_group)
     else:
         Y_parallel = Y_local
-    # sequential FW pass
-    weight_sequential = _gather(
-        _gather(layer.weight, 1, inner_group), 0, outer_group
-    )
-    layer_sequential.weight.copy_(weight_sequential)
     Y_sequential = layer_sequential(X)

     assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"
@@ -60,17 +69,20 @@

 @pytest.mark.mpi
 @pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)])
-@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
-@pytest.mark.parametrize("async_comm_in_backward_pass", [True, False])
+@pytest.mark.parametrize(
+    "G_intra_r, G_intra_c, G_intra_d", [(2, 1, 1), (1, 2, 1), (1, 1, 2)]
+)
+@pytest.mark.parametrize("comm_opt_level", [0, 4])
 @pytest.mark.parametrize("easy_tp", [False, True])
 @pytest.mark.parametrize("clip_grad_norm", [-1, 1e-3])
-@pytest.mark.parametrize("bias", [False])
+@pytest.mark.parametrize("bias", [False, True])
 def test_bw_pass(
     G_intra_r,
     G_intra_c,
+    G_intra_d,
     B,
     H,
-    async_comm_in_backward_pass,
+    comm_opt_level,
     easy_tp,
     clip_grad_norm,
     bias,
@@ -82,19 +94,20 @@
         G_inter=1,
         G_intra_r=G_intra_r,
         G_intra_c=G_intra_c,
+        G_intra_d=G_intra_d,
     )
     X = torch.randn(B, H).cuda() * 0.01
     Y_grad = torch.randn(B, H).cuda() * 0.01

     inner_group = ax.comm_handle.inner_intra_layer_parallel_group
     outer_group = ax.comm_handle.outer_intra_layer_parallel_group
+    depth_group = ax.comm_handle.depth_intra_layer_parallel_group

     # parallel backward pass
     layer = Linear(
         in_features=H,
         out_features=H,
         bias=bias,
-        async_comm_in_backward_pass=async_comm_in_backward_pass,
     ).cuda()
     layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda()
@@ -107,19 +120,31 @@
         X_local = (
             _drop(X, 1, inner_group).detach().clone()
         )  # divide colunns of X along the inner tensor group
+        X_local = (
+            _drop(X_local, 0, depth_group).detach().clone()
+        )  # divide rows of X along the depth tensor group
     else:
         X_local = X

     X_local.requires_grad = True

-    Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
-
     if not easy_tp:
-        Y_local_grad = _drop(Y_grad, 1, outer_group)
+        Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone()
+        Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone()
     else:
         Y_local_grad = Y_grad

-    Y_local.backward(Y_local_grad)
+    with optimize_communication(
+        overlap_all_reduce=comm_opt_level >= 1,
+        overlap_reduce_scatter=comm_opt_level >= 2,
+        cache_weights=comm_opt_level >= 3,
+        overlap_all_gather=comm_opt_level == 4,
+        model=layer,
+    ):
+        Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
+        Y_local.backward(Y_local_grad)
+    if comm_opt_level >= 3:
+        clear_weights_cache()

     # sequential backward pass
     X.requires_grad = True
     Y_sequential = layer_sequential(X)
@@ -132,7 +157,8 @@
     )

     if not easy_tp:
-        X_grad_parallel = _gather(X_local.grad, 1, inner_group)
+        X_grad_parallel = _gather(X_local.grad, 0, depth_group)
+        X_grad_parallel = _gather(X_grad_parallel, 1, inner_group)
     else:
         X_grad_parallel = X_local.grad

@@ -140,9 +166,14 @@
         X_grad_parallel, X.grad
     ), "BW Pass - gradients of input do not match"

+    weight_grad_parallel = _gather(layer.weight.grad, 0, depth_group).reshape(
+        layer.local_out_features, layer.local_in_features
+    )
+
     weight_grad_parallel = _gather(
-        _gather(layer.weight.grad, 1, inner_group), 0, outer_group
+        _gather(weight_grad_parallel, 1, inner_group), 0, outer_group
     )
+
     assert torch.allclose(
         weight_grad_parallel, layer_sequential.weight.grad
     ), "BW Pass - gradients of weight do not match"
@@ -155,15 +186,14 @@

 if __name__ == "__main__":
-    test_fw_pass(G_intra_r=2, G_intra_c=1, B=4, H=256, easy_tp=True, bias=True)
     test_bw_pass(
-        G_intra_r=2,
+        G_intra_r=1,
         G_intra_c=1,
-        B=4,
+        G_intra_d=2,
+        B=2,
         H=256,
-        async_comm_in_backward_pass=True,
-        easy_tp=True,
-        clip_grad_norm=0.01,
+        comm_opt_level=0,
+        easy_tp=False,
+        clip_grad_norm=-1,
         bias=True,
     )
-    print("finished")
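
Below is a minimal, hypothetical usage sketch of the API surface this patch adds (the G_intra_d argument to ax.init and the optimize_communication context manager), mirroring the calls exercised in axonn/tests/test_intra_layer_fc.py. The layer sizes, batch size, optimizer, and the assumed 8-GPU launch are illustrative placeholders and are not part of the patch.

# Hypothetical driver script; not part of the diff above.
import torch
from axonn import axonn as ax
from axonn.intra_layer import Linear, optimize_communication, clear_weights_cache

# 2 (rows) x 2 (columns) x 2 (depth) tensor parallelism, no data or pipeline
# parallelism; assumes the script is launched on 8 GPUs (e.g. via mpirun/srun).
ax.init(G_data=1, G_inter=1, G_intra_r=2, G_intra_c=2, G_intra_d=2)

layer = Linear(in_features=1024, out_features=1024).cuda()
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-4)

X = torch.randn(64, 1024).cuda()
with optimize_communication(
    overlap_all_reduce=True,
    overlap_reduce_scatter=True,
    cache_weights=True,
    overlap_all_gather=True,  # requires cache_weights=True and model=...
    model=layer,
):
    # scatter_input / gather_output default to True ("easy" tensor parallelism)
    Y = layer(X)
    # pending asynchronous reductions are flushed when the context exits
    Y.sum().backward()
clear_weights_cache()  # drop cached all-gathered weights before they change
optimizer.step()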