From 5f1490c4e35d067024d833a97e214de782fbb531 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Tue, 24 Oct 2023 14:31:44 -0400
Subject: [PATCH] Intra-layer - Overlap communication in backward pass (#44)

* overlap weight grad compute with activation grad communication
---
 axonn/intra_layer/fully_connected.py | 118 ++++++++++++++++++---------
 axonn/tests/test_intra_layer_fc.py   |  14 ++--
 2 files changed, 89 insertions(+), 43 deletions(-)

diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py
index d61e2b4..dd758fb 100644
--- a/axonn/intra_layer/fully_connected.py
+++ b/axonn/intra_layer/fully_connected.py
@@ -1,7 +1,9 @@
 from axonn import axonn as ax
 import torch.distributed as dist
 import torch
-from .communication import ForwardAllReduce, BackwardAllReduce, Drop
+from .communication import Drop
+from torch.autograd import Function
+import math
 
 
 def divide(a, b):
@@ -21,6 +23,49 @@ def initialize_params(
     return params
 
 
+class AsyncLinear(Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input_,
+        weight,
+        forward_all_reduce_group,
+        backward_all_reduce_group,
+        backward_comm_async,
+    ):
+        ctx.save_for_backward(input_, weight)
+        ctx.backward_all_reduce_group = backward_all_reduce_group
+        ctx.backward_comm_async = backward_comm_async
+        output = input_.matmul(weight.t())
+        dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight = ctx.saved_tensors
+        handle = None
+        if ctx.needs_input_grad[0]:
+            grad_input = grad_output.matmul(weight)
+            handle = dist.all_reduce(
+                grad_input,
+                group=ctx.backward_all_reduce_group,
+                async_op=ctx.backward_comm_async,
+            )
+        if ctx.needs_input_grad[1]:
+            grad_weight = (
+                grad_output.view(-1, grad_output.shape[-1])
+                .t()
+                .mm(input_.view(-1, input_.shape[-1]))
+            )
+        if handle and ctx.backward_comm_async:
+            handle.wait()
+        return grad_input, grad_weight, None, None, None
+
+
+def default_init_method(weight):
+    return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+
+
 class Linear(torch.nn.Module):
     def __init__(
         self,
@@ -30,6 +75,7 @@ def __init__(
         transpose=False,
         skip_bias_add=False,
         init_method=None,
+        async_comm_in_backward_pass=True,
         **kwargs
     ):
         super(Linear, self).__init__()
@@ -38,44 +84,37 @@ def __init__(
         self.inner_group_size = dist.get_world_size(self.inner_group)
         self.outer_group_size = dist.get_world_size(self.outer_group)
 
+        self.async_comm_in_backward_pass = async_comm_in_backward_pass
+
+        if init_method is None:
+            init_method = default_init_method
 
         if not transpose:
             assert in_features % self.inner_group_size == 0
             assert out_features % self.outer_group_size == 0
             self.local_in_features = divide(in_features, self.inner_group_size)
             self.local_out_features = divide(out_features, self.outer_group_size)
-            if init_method:
-                initial_params = initialize_params(
-                    out_features,
-                    in_features,
-                    self.outer_group,
-                    self.inner_group,
-                    init_method,
-                )
+            initial_params = initialize_params(
+                out_features,
+                in_features,
+                self.outer_group,
+                self.inner_group,
+                init_method,
+            )
         else:
             assert out_features % self.inner_group_size == 0
             assert in_features % self.outer_group_size == 0
             self.local_in_features = divide(in_features, self.outer_group_size)
            self.local_out_features = divide(out_features, self.inner_group_size)
-            if init_method:
-                initial_params = initialize_params(
-                    out_features,
-                    in_features,
-                    self.inner_group,
-                    self.outer_group,
-                    init_method,
-                )
-
-        self.linear = torch.nn.Linear(
-            in_features=self.local_in_features,
-            out_features=self.local_out_features,
-            *args,
-            **kwargs,
-            bias=False
-        )
+            initial_params = initialize_params(
+                out_features,
+                in_features,
+                self.inner_group,
+                self.outer_group,
+                init_method,
+            )
 
-        if init_method:
-            self.linear.weight.data.copy_(initial_params)
+        self.weight = torch.nn.Parameter(initial_params, requires_grad=True)
 
         self.bias = torch.nn.Parameter(
             torch.zeros(
@@ -90,18 +129,21 @@ def get_output_feature_size(self):
 
     def forward(self, x):
         if not self.transpose:
-            if x.size(-1) == self.local_in_features * self.inner_group_size:
-                x = Drop.apply(x, self.inner_group)
-            x = BackwardAllReduce.apply(x, self.outer_group)
-            x = self.linear(x)
-            x = ForwardAllReduce.apply(x, self.inner_group)
+            x = AsyncLinear.apply(
+                x,
+                self.weight,
+                self.inner_group,
+                self.outer_group,
+                self.async_comm_in_backward_pass,
+            )
         else:
-            if x.size(-1) == self.local_in_features * self.outer_group_size:
-                x = Drop.apply(x, self.outer_group)
-            x = BackwardAllReduce.apply(x, self.inner_group)
-            x = self.linear(x)
-            x = ForwardAllReduce.apply(x, self.outer_group)
-
+            x = AsyncLinear.apply(
+                x,
+                self.weight,
+                self.outer_group,
+                self.inner_group,
+                self.async_comm_in_backward_pass,
+            )
         if self.skip_bias_add:
             return x, self.bias
         else:
diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py
index 7216103..f3f9d34 100644
--- a/axonn/tests/test_intra_layer_fc.py
+++ b/axonn/tests/test_intra_layer_fc.py
@@ -40,7 +40,7 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
         in_features=H, out_features=H, bias=False
     ).cuda()
     weight_sequential = _gather(
-        _gather(layer.linear.weight, 1, inner_group), 0, outer_group
+        _gather(layer.weight, 1, inner_group), 0, outer_group
     )
     layer_sequential.weight.copy_(weight_sequential)
     Y_sequential = layer_sequential(X)
@@ -51,7 +51,8 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
 @pytest.mark.mpi
 @pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)])
 @pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
-def test_bw_pass(G_intra_r, G_intra_c, B, H):
+@pytest.mark.parametrize("async_comm_in_backward_pass", [True, False])
+def test_bw_pass(G_intra_r, G_intra_c, B, H, async_comm_in_backward_pass):
     # These tests are in fp-32
     torch.manual_seed(42)
     ax.init(
@@ -68,7 +69,10 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
 
     # parallel backward pass
     layer = Tensor_Parallel_Linear(
-        in_features=H, out_features=H, skip_bias_add=True
+        in_features=H,
+        out_features=H,
+        skip_bias_add=True,
+        async_comm_in_backward_pass=async_comm_in_backward_pass,
     ).cuda()
     X_local = (
         _drop(X, 1, inner_group).detach().clone()
@@ -82,7 +86,7 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
     layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=False).cuda()
     with torch.no_grad():
         weight_sequential = _gather(
-            _gather(layer.linear.weight, 1, inner_group), 0, outer_group
+            _gather(layer.weight, 1, inner_group), 0, outer_group
         )
         layer_sequential.weight.copy_(weight_sequential)
     X.requires_grad = True
@@ -95,7 +99,7 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
     ), "BW Pass - gradients of input do not match"
 
     weight_grad_parallel = _gather(
-        _gather(layer.linear.weight.grad, 1, inner_group), 0, outer_group
+        _gather(layer.weight.grad, 1, inner_group), 0, outer_group
     )
     assert torch.allclose(
         weight_grad_parallel, layer_sequential.weight.grad
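
Note (not part of the patch): the core idea in AsyncLinear.backward above is to
overlap the all-reduce of the input gradient with the computation of the weight
gradient, which needs only local tensors. Below is a minimal sketch of that
pattern, assuming torch.distributed has already been initialized (e.g. with the
NCCL backend) and the tensors live on the local GPU; the function name is made
up purely for illustration.

    import torch
    import torch.distributed as dist

    def backward_with_overlap(grad_output, input_, weight, group, overlap=True):
        # Start the all-reduce of the input gradient. With async_op=True this
        # returns a work handle immediately instead of blocking.
        grad_input = grad_output.matmul(weight)
        handle = dist.all_reduce(grad_input, group=group, async_op=overlap)

        # While that communication is in flight, compute the weight gradient,
        # which depends only on tensors already resident on this rank.
        grad_weight = (
            grad_output.view(-1, grad_output.shape[-1])
            .t()
            .mm(input_.view(-1, input_.shape[-1]))
        )

        # Only now wait for the communication to complete.
        if overlap and handle is not None:
            handle.wait()
        return grad_input, grad_weight

Passing overlap=False makes the all-reduce blocking, which is the behaviour
selected by async_comm_in_backward_pass=False in the Linear constructor above.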
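
Note (not part of the patch): a hypothetical usage sketch of the new
async_comm_in_backward_pass flag. The import path follows the file modified
above (axonn/intra_layer/fully_connected.py); ax.init's arguments are not shown
in this patch and are therefore omitted here, and the shapes are made up.

    import torch
    from axonn import axonn as ax
    from axonn.intra_layer.fully_connected import Linear

    # ax.init(...) must already have been called to set up the intra-layer
    # process groups that Linear partitions its weight over.

    layer = Linear(
        in_features=1024,
        out_features=1024,
        skip_bias_add=True,
        async_comm_in_backward_pass=True,  # default; False selects blocking comm
    ).cuda()

    # The layer expects the locally partitioned input width.
    x = torch.randn(16, layer.local_in_features, device="cuda")
    y, bias = layer(x)      # skip_bias_add=True returns (output, bias)
    y.sum().backward()      # backward runs AsyncLinear with async all-reduce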