From 44620db87fcb0946ee3af3c995cf8646538eafce Mon Sep 17 00:00:00 2001 From: Anish Bhupalam Date: Sat, 2 Nov 2024 07:54:19 -0700 Subject: [PATCH 1/3] adding more timers to axonn --- axonn/intra_layer/asym_communication.py | 2 +- axonn/intra_layer/communication.py | 22 +++++++++++++--------- axonn/intra_layer/fully_connected.py | 21 ++++++++++++++++++++- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/axonn/intra_layer/asym_communication.py b/axonn/intra_layer/asym_communication.py index 454f320..865da17 100644 --- a/axonn/intra_layer/asym_communication.py +++ b/axonn/intra_layer/asym_communication.py @@ -168,8 +168,8 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group ) torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group) + ax.get_timers().stop("alltoallv") ax.get_timers().stop("gather-channels-scatter-batch") - ax.get_timers().stop("alltoallv") return torch.cat(recv_tensors, dim=-1) diff --git a/axonn/intra_layer/communication.py b/axonn/intra_layer/communication.py index a436c46..094864f 100644 --- a/axonn/intra_layer/communication.py +++ b/axonn/intra_layer/communication.py @@ -6,9 +6,10 @@ import torch.distributed as dist import torch import axonn.intra_layer.overlap_communication as overlap_communication - +from axonn import axonn as ax def _all_reduce(input_, process_group=None, overlap_comm=False): + ax.get_timers().start("all-reduce") input_ = input_.contiguous() if dist.get_world_size(process_group) > 1: handle = dist.all_reduce( @@ -16,6 +17,7 @@ def _all_reduce(input_, process_group=None, overlap_comm=False): ) if overlap_comm: overlap_communication.register_handle(handle) + ax.get_timers().stop("all-reduce") return input_ @@ -23,20 +25,19 @@ def _drop(input_, dim, process_group=None): """Divide a tensor among the tensor parallel ranks""" if dist.get_world_size(process_group) == 1: return input_ - + ax.get_timers().start("drop") total_chunks = dist.get_world_size(process_group) this_chunk = dist.get_rank(process_group) assert input_.shape[dim] % total_chunks == 0 chunk_size = input_.shape[dim] // total_chunks - + ax.get_timers().stop("drop") return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size) - def _gather(input_, dim, process_group=None, cache=False): """Gather tensors and concatenate them along a dimension""" if dist.get_world_size(process_group) == 1: return input_ - + ax.get_timers().start("gather") if input_ in overlap_communication.weights_cache: output, handle = overlap_communication.retrieve_all_gathered_weight( input_, delete=not cache @@ -61,16 +62,16 @@ def _gather(input_, dim, process_group=None, cache=False): if cache: overlap_communication.weights_cache[input_] = output, None - + ax.get_timers().stop("gather") return output def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False): assert dim == 0, "reduce scatter only implemented for dim=0" - + if dist.get_world_size(process_group) == 1: return input_ - + ax.get_timers().start("reduce-scatter") total_chunks = dist.get_world_size(process_group) assert input_.shape[dim] % total_chunks == 0 tensor_shape = list(input_.shape) @@ -79,6 +80,7 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False): tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device() ) + ax.get_timers().start("reduce-scatter-dist") if hasattr(torch.distributed, "reduce_scatter_tensor"): handle = torch.distributed.reduce_scatter_tensor( output, input_, group=process_group, async_op=overlap_comm @@ -87,9 +89,11 @@ 
def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False): handle = torch.distributed._reduce_scatter_base( output, input_, group=process_group, async_op=overlap_comm ) - + + ax.get_timers().stop("reduce-scatter-dist") if overlap_comm: overlap_communication.register_handle(handle) + ax.get_timers().stop("reduce-scatter") return output diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index af75590..d31fb05 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -106,6 +106,7 @@ def forward( local_weight_shape, cache_weights, ): + ax.get_timers().start("forward-async") original_weight = weight weight = _gather( weight, dim=0, process_group=depth_parallel_group, cache=cache_weights @@ -115,14 +116,17 @@ def forward( ctx.backward_all_reduce_group = backward_all_reduce_group ctx.depth_parallel_group = depth_parallel_group ctx.shape = local_weight_shape + ax.get_timers().start("compute") output = input_.matmul(weight.t()) + ax.get_timers().stop("compute") dist.all_reduce(output, group=forward_all_reduce_group, async_op=False) - + ax.get_timers().stop("forward-async") return output @staticmethod @version_aware_custom_bwd def backward(ctx, grad_output): + ax.get_timers().start("backward-async") input_, original_weight = ctx.saved_tensors weight = _gather( original_weight, dim=0, process_group=ctx.depth_parallel_group, cache=False @@ -137,18 +141,22 @@ def backward(ctx, grad_output): grad_input, grad_weight = None, None if ctx.needs_input_grad[0]: + ax.get_timers().start("compute") grad_input = grad_output.matmul(weight) + ax.get_timers().stop("compute") handle = dist.all_reduce( grad_input, group=ctx.backward_all_reduce_group, async_op=overlap_all_reduce, ) if ctx.needs_input_grad[1]: + ax.get_timers().start("compute") grad_weight = ( grad_output.reshape(-1, grad_output.shape[-1]) .t() .mm(input_.view(-1, input_.shape[-1])) ) + ax.get_timers().stop("compute") grad_weight = grad_weight.reshape(-1) grad_weight = _reduce_scatter( @@ -163,27 +171,35 @@ def backward(ctx, grad_output): if overlap_reduce_scatter and ctx.needs_input_grad[1]: overlap_communication.accumulate_later(original_weight, grad_weight) grad_weight = None # weight gradients are not ready yet + ax.get_timers().stop("backward-async") return grad_input, grad_weight, None, None, None, None, None, None, None else: grad_input, grad_weight = None, None if ctx.needs_input_grad[1]: + ax.get_timers().start("compute") grad_weight = ( grad_output.reshape(-1, grad_output.shape[-1]) .t() .mm(input_.view(-1, input_.shape[-1])) ).reshape(-1) + ax.get_timers().stop("compute") + #ax.get_timers().start("reduce-scatter") grad_weight = _reduce_scatter( grad_weight, dim=0, process_group=ctx.depth_parallel_group, overlap_comm=True, ) + #ax.get_timers().stop("reduce-scatter") overlap_communication.accumulate_later(original_weight, grad_weight) grad_weight = None # weight gradients are not ready yet if ctx.needs_input_grad[0]: + ax.get_timers().start("compute") grad_input = grad_output.matmul(weight) + ax.get_timers().stop("compute") + ax.get_timers().stop("backward-async") return grad_input, grad_weight, None, None, None, None, None, None, None @@ -305,6 +321,7 @@ def forward( x, cache_weights_in_all_gather=False, ): + ax.get_timers().start("forward-linear") original_shape_x = x.shape x = x.reshape(-1, x.shape[-1]) weight = self.weight @@ -345,11 +362,13 @@ def forward( x = x.reshape(*original_shape_x[:-1], x.shape[-1]) if self.bias is None: + 
ax.get_timers().stop("forward-linear") return x else: bias = self.bias if not self.expert_mode: bias = Gather.apply(bias, self.outer_group) + ax.get_timers().stop("forward-linear") if self.skip_bias_add: return x, bias else: From e080fff1977048f56c0b83730f3ab10806ab89f0 Mon Sep 17 00:00:00 2001 From: Anish Bhupalam Date: Sat, 2 Nov 2024 09:42:15 -0700 Subject: [PATCH 2/3] removing comments --- axonn/intra_layer/fully_connected.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index d31fb05..a95a37e 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -184,14 +184,12 @@ def backward(ctx, grad_output): .mm(input_.view(-1, input_.shape[-1])) ).reshape(-1) ax.get_timers().stop("compute") - #ax.get_timers().start("reduce-scatter") grad_weight = _reduce_scatter( grad_weight, dim=0, process_group=ctx.depth_parallel_group, overlap_comm=True, ) - #ax.get_timers().stop("reduce-scatter") overlap_communication.accumulate_later(original_weight, grad_weight) grad_weight = None # weight gradients are not ready yet From fd24d0162011c1b12c1e3df63a800d92032b71ed Mon Sep 17 00:00:00 2001 From: Anish Bhupalam Date: Sun, 3 Nov 2024 15:40:04 -0800 Subject: [PATCH 3/3] fix formatting --- axonn/intra_layer/asym_communication.py | 2 +- axonn/intra_layer/communication.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/axonn/intra_layer/asym_communication.py b/axonn/intra_layer/asym_communication.py index 865da17..821a291 100644 --- a/axonn/intra_layer/asym_communication.py +++ b/axonn/intra_layer/asym_communication.py @@ -168,7 +168,7 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group ) torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group) - ax.get_timers().stop("alltoallv") + ax.get_timers().stop("alltoallv") ax.get_timers().stop("gather-channels-scatter-batch") return torch.cat(recv_tensors, dim=-1) diff --git a/axonn/intra_layer/communication.py b/axonn/intra_layer/communication.py index 094864f..e3839c3 100644 --- a/axonn/intra_layer/communication.py +++ b/axonn/intra_layer/communication.py @@ -8,6 +8,7 @@ import axonn.intra_layer.overlap_communication as overlap_communication from axonn import axonn as ax + def _all_reduce(input_, process_group=None, overlap_comm=False): ax.get_timers().start("all-reduce") input_ = input_.contiguous() @@ -33,6 +34,7 @@ def _drop(input_, dim, process_group=None): ax.get_timers().stop("drop") return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size) + def _gather(input_, dim, process_group=None, cache=False): """Gather tensors and concatenate them along a dimension""" if dist.get_world_size(process_group) == 1: @@ -68,7 +70,7 @@ def _gather(input_, dim, process_group=None, cache=False): def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False): assert dim == 0, "reduce scatter only implemented for dim=0" - + if dist.get_world_size(process_group) == 1: return input_ ax.get_timers().start("reduce-scatter") @@ -89,7 +91,7 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False): handle = torch.distributed._reduce_scatter_base( output, input_, group=process_group, async_op=overlap_comm ) - + ax.get_timers().stop("reduce-scatter-dist") if overlap_comm: overlap_communication.register_handle(handle)