From 9cf6bb6468faad447833dde7145b046bcafd0bd0 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Sun, 26 Nov 2023 11:44:39 -0800
Subject: [PATCH] separate args for overlap_all_reduce and overlap_reduce_scatter

---
 axonn/intra_layer/__init__.py        | 85 +++++++++++++++++++++++++---
 axonn/intra_layer/communication.py   | 32 ++++++++---
 axonn/intra_layer/fully_connected.py | 10 +++-
 3 files changed, 110 insertions(+), 17 deletions(-)

diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py
index df775f8..ee76cf8 100644
--- a/axonn/intra_layer/__init__.py
+++ b/axonn/intra_layer/__init__.py
@@ -7,6 +7,7 @@
 
 from axonn import axonn as ax
 import torch
+import torch.distributed as dist
 
 
 def drop(x, transpose=False, dim=-1, batch_dim=0):
@@ -31,12 +32,15 @@ def gather(x, transpose=False, dim=-1, batch_dim=0):
     return x
 
 
-OVERLAP_COMM = False
+OVERLAP_REDUCE_SCATTER = False
+OVERLAP_ALL_REDUCE = False
 CACHE_WEIGHTS = False
+ALL_GATHER_ITERATOR = None
 
 handles = []
 pending_grad_accumulations = []
 weights_cache = {}
+
 def register_handle(handle):
     # ToDo: This might be unnecesary since
     # we are calling synchronize in clear_handles
@@ -66,21 +70,88 @@ def accumulate():
     pending_grad_accumulations = []
 
 
+
 def clear_weights_cache():
     global weights_cache
     weights_cache = {}
 
 
+
+def trigger_async_all_gathers(model):
+    global weights_cache
+    for module in model.modules():
+        if isinstance(module, Linear):
+            weight = module.weight
+            if weight not in weights_cache:
+                # only trigger all gathers if not in cache
+                process_group = module.depth_group
+                world_size = dist.get_world_size(process_group)
+                if world_size == 1:
+                    all_gathered_weight = weight
+                    handle = None
+                else:
+                    assert weight.ndim == 1
+                    output_shape = weight.shape[0] * world_size
+                    all_gathered_weight = torch.empty(
+                        output_shape, dtype=weight.dtype, device=weight.device
+                    )
+                    handle = dist.all_gather_into_tensor(
+                        all_gathered_weight, weight, group=process_group, async_op=True
+                    )
+                weights_cache[weight] = [all_gathered_weight, handle]
+        yield
+
+
+def enqueue_next_all_gather():
+    global ALL_GATHER_ITERATOR
+    assert ALL_GATHER_ITERATOR is not None
+    try:
+        next(ALL_GATHER_ITERATOR)
+    except StopIteration:
+        pass
+
+
+def retrieve_all_gathered_weight(weight):
+    global CACHE_WEIGHTS, ALL_GATHER_ITERATOR
+    assert weight in weights_cache
+    all_gathered_weight, handle = weights_cache[weight]
+    if ALL_GATHER_ITERATOR is not None:
+        enqueue_next_all_gather()
+    return all_gathered_weight, handle
+
+
 @contextmanager
-def optimize_communication(cache_weights=False, *args, **kwargs):
-    global OVERLAP_COMM, CACHE_WEIGHTS
-    OVERLAP_COMM = True
+def optimize_communication(
+    overlap_all_reduce=True, overlap_reduce_scatter=False, cache_weights=False, overlap_all_gather=False, model=None, *args, **kwargs
+):
+    global OVERLAP_ALL_REDUCE, OVERLAP_REDUCE_SCATTER, CACHE_WEIGHTS, ALL_GATHER_ITERATOR
+    OVERLAP_ALL_REDUCE = overlap_all_reduce
+    OVERLAP_REDUCE_SCATTER = overlap_reduce_scatter
+
     if (not cache_weights) and (CACHE_WEIGHTS):
-        raise ValueError("Attempting to set cache_weights to False, when it was earlier set to True. This can lead to erroneous behaviours. Either always use cache_weights=False or cache_weights=True")
-    CACHE_WEIGHTS=cache_weights
+        raise ValueError(
+            "Attempting to set cache_weights to False, when it was earlier set to True."
+            "This can lead to erroneous behaviour. Either always use cache_weights=False or cache_weights=True"
+        )
+    CACHE_WEIGHTS = cache_weights
+
+    if overlap_all_gather:
+        if model is None:
+            raise ValueError(
+                "You need to pass your model as an argument - "
+                "optimize_communication(...,model=model, ...)"
+                "if overlap_all_gather is True"
+            )
+        assert (
+            cache_weights
+        ), "all gathers can only be overlapped if cache_weights is True"
+        ALL_GATHER_ITERATOR = trigger_async_all_gathers(model)
+        enqueue_next_all_gather()
     try:
         yield None
     finally:
         clear_handles()
         accumulate()
-        OVERLAP_COMM = False
+        OVERLAP_ALL_REDUCE = False
+        OVERLAP_REDUCE_SCATTER = False
+        ALL_GATHER_ITERATOR = None
diff --git a/axonn/intra_layer/communication.py b/axonn/intra_layer/communication.py
index 82527c2..313a4de 100644
--- a/axonn/intra_layer/communication.py
+++ b/axonn/intra_layer/communication.py
@@ -28,8 +28,10 @@ def _gather(input_, dim, process_group=None, cache=False):
         return input_
 
     if input_ in axonn.intra_layer.weights_cache:
-        output = axonn.intra_layer.weights_cache[input_]
-
+        output, handle = axonn.intra_layer.retrieve_all_gathered_weight(input_)
+        if handle is not None:
+            handle.wait()
+            axonn.intra_layer.weights_cache[input_][1] = None
     else:
         input_ = input_.contiguous()
         # Size and dimension.
@@ -44,8 +46,8 @@
 
         # Note: torch.cat already creates a contiguous tensor.
         output = torch.cat(tensor_list, dim=dim).contiguous()
-    if cache:
-        axonn.intra_layer.weights_cache[input_] = output
+        if cache:
+            axonn.intra_layer.weights_cache[input_] = output, None
     return output
 
 
@@ -142,17 +144,33 @@ def backward(ctx, grad_output):
 
 class ForwardGather_BackwardReduceScatter(torch.autograd.Function):
     @staticmethod
-    def symbolic(graph, input_, process_group=None, dim=0, overlap_comm=False, cache_all_gather=False):
+    def symbolic(
+        graph,
+        input_,
+        process_group=None,
+        dim=0,
+        overlap_comm=False,
+        cache_all_gather=False,
+    ):
         return _gather(input_, dim=dim, process_group=process_group)
 
     @staticmethod
-    def forward(ctx, input_, process_group=None, dim=0, overlap_comm=False, cache_all_gather=False):
+    def forward(
+        ctx,
+        input_,
+        process_group=None,
+        dim=0,
+        overlap_comm=False,
+        cache_all_gather=False,
+    ):
         assert dim == 0
         ctx.process_group = process_group
         ctx.dim = dim
         ctx.overlap_comm = overlap_comm
         ctx.input = input_
-        return _gather(input_, dim=dim, process_group=process_group, cache=cache_all_gather)
+        return _gather(
+            input_, dim=dim, process_group=process_group, cache=cache_all_gather
+        )
 
     @staticmethod
     def backward(ctx, grad_output):
diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py
index 3dcfe4c..b24c350 100644
--- a/axonn/intra_layer/fully_connected.py
+++ b/axonn/intra_layer/fully_connected.py
@@ -203,7 +203,11 @@ def forward(self, x, scatter_input=True, gather_output=True):
         # gather weights from depth parallel group
         # reduce scatter in the backward pass
         weight = ForwardGather_BackwardReduceScatter.apply(
-            self.weight, self.depth_group, 0, axonn.intra_layer.OVERLAP_COMM, axonn.intra_layer.CACHE_WEIGHTS
+            self.weight,
+            self.depth_group,
+            0,
+            axonn.intra_layer.OVERLAP_REDUCE_SCATTER,
+            axonn.intra_layer.CACHE_WEIGHTS,
         ).reshape(self.local_out_features, self.local_in_features)
 
         if not self.transpose:
@@ -215,7 +219,7 @@ def forward(self, x, scatter_input=True, gather_output=True):
                 weight,
                 self.inner_group,
                 self.outer_group,
-                axonn.intra_layer.OVERLAP_COMM,
+                axonn.intra_layer.OVERLAP_ALL_REDUCE,
                 False,
             )
             if gather_output:
@@ -231,7 +235,7 @@ def forward(self, x, scatter_input=True, gather_output=True):
                 weight,
                 self.outer_group,
                 self.inner_group,
-                axonn.intra_layer.OVERLAP_COMM,
+                axonn.intra_layer.OVERLAP_ALL_REDUCE,
                 False,
             )
             if gather_output:
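
Usage note (not part of the patch): a minimal sketch of how the arguments introduced above could be driven from a training step. The names `model`, `batch`, and `loss_fn` are hypothetical placeholders, and it is assumed that AxoNN's process groups have already been initialized (e.g. via axonn.axonn.init) and that the model's layers are axonn.intra_layer.Linear.

    from axonn.intra_layer import optimize_communication, clear_weights_cache

    # overlap_all_gather requires cache_weights=True and the model object,
    # per the checks added in optimize_communication above.
    with optimize_communication(
        overlap_all_reduce=True,
        overlap_reduce_scatter=True,
        overlap_all_gather=True,
        cache_weights=True,
        model=model,  # hypothetical: your tensor-parallel model
    ):
        loss = loss_fn(model(batch))  # hypothetical forward pass and loss
        loss.backward()

    # The all-gathered weights stay in weights_cache across iterations;
    # drop them once the parameters change (e.g. after an optimizer step)
    # so stale copies are not reused.
    clear_weights_cache()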