diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py
index af7de7c..cb58973 100644
--- a/axonn/intra_layer/fully_connected.py
+++ b/axonn/intra_layer/fully_connected.py
@@ -1,10 +1,10 @@
 from axonn import axonn as ax
 import torch.distributed as dist
 import torch
+from .communication import Drop, Gather
+from torch.autograd import Function
 import math
-from .communication import ForwardAllReduce, BackwardAllReduce, Drop, Gather
-
 
 def divide(a, b):
     assert a % b == 0
     return a // b
@@ -21,6 +21,49 @@ def initialize_params(
     params = Drop.apply(params, in_features_group)
     return params
 
+class AsyncLinear(Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input_,
+        weight,
+        forward_all_reduce_group,
+        backward_all_reduce_group,
+        backward_comm_async,
+    ):
+        ctx.save_for_backward(input_, weight)
+        ctx.backward_all_reduce_group = backward_all_reduce_group
+        ctx.backward_comm_async = backward_comm_async
+        output = input_.matmul(weight.t())
+        dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight = ctx.saved_tensors
+        handle = None
+        if ctx.needs_input_grad[0]:
+            grad_input = grad_output.matmul(weight)
+            handle = dist.all_reduce(
+                grad_input,
+                group=ctx.backward_all_reduce_group,
+                async_op=ctx.backward_comm_async,
+            )
+        if ctx.needs_input_grad[1]:
+            grad_weight = (
+                grad_output.view(-1, grad_output.shape[-1])
+                .t()
+                .mm(input_.view(-1, input_.shape[-1]))
+            )
+        if handle and ctx.backward_comm_async:
+            handle.wait()
+        return grad_input, grad_weight, None, None, None
+
+
+def default_init_method(weight):
+    return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+
+
 class Linear(torch.nn.Module):
     def __init__(
         self,
@@ -30,6 +73,7 @@ def __init__(
         transpose=False,
         skip_bias_add=False,
         init_method=None,
+        async_comm_in_backward_pass=True,
         **kwargs
     ):
         super(Linear, self).__init__()
@@ -38,11 +82,12 @@ def __init__(
 
         self.inner_group_size = dist.get_world_size(self.inner_group)
         self.outer_group_size = dist.get_world_size(self.outer_group)
-
+
+        self.async_comm_in_backward_pass = async_comm_in_backward_pass
+
         if init_method is None:
-            ## this is the same as pytorch 2.1
-            init_method = lambda weight : torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
-
+            init_method = default_init_method
+
         if not transpose:
             assert in_features % self.inner_group_size == 0
             assert out_features % self.outer_group_size == 0
@@ -68,16 +113,8 @@ def __init__(
                 init_method,
             )
 
-        self.linear = torch.nn.Linear(
-            in_features=self.local_in_features,
-            out_features=self.local_out_features,
-            *args,
-            **kwargs,
-            bias=False
-        )
-        if init_method:
-            self.linear.weight.data.copy_(initial_params)
+        self.weight = torch.nn.Parameter(initial_params, requires_grad=True)
 
-        setattr(self.linear.weight, "is_tensor_parallel", True)
+        setattr(self.weight, "is_tensor_parallel", True)
@@ -96,26 +133,33 @@ def forward(self, x, scatter_input=True, gather_output=True):
         if not self.transpose:
             if scatter_input:
                 x = Drop.apply(x, self.inner_group)
-            x = BackwardAllReduce.apply(x, self.outer_group)
-            x = self.linear(x)
-            x = ForwardAllReduce.apply(x, self.inner_group)
+            x = AsyncLinear.apply(
+                x,
+                self.weight,
+                self.inner_group,
+                self.outer_group,
+                self.async_comm_in_backward_pass,
+            )
             if gather_output:
                 x = Gather.apply(x, self.outer_group)
         else:
             if scatter_input:
                 x = Drop.apply(x, self.outer_group)
-            x = BackwardAllReduce.apply(x, self.inner_group)
-            x = self.linear(x)
-            x = ForwardAllReduce.apply(x, self.outer_group)
+            x = AsyncLinear.apply(
+                x,
+                self.weight,
+                self.outer_group,
+                self.inner_group,
+                self.async_comm_in_backward_pass,
+            )
             if gather_output:
                 x = Gather.apply(x, self.inner_group)
-
+
         bias = self.bias
         if gather_output:
             bias = Gather.apply(
                 self.bias, self.outer_group if not self.transpose else self.inner_group
             )
-
         if self.skip_bias_add:
             return x, bias
         else:
diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py
index 4f600f8..a48498c 100644
--- a/axonn/tests/test_intra_layer_fc.py
+++ b/axonn/tests/test_intra_layer_fc.py
@@ -40,7 +40,7 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
         in_features=H, out_features=H, bias=False
     ).cuda()
     weight_sequential = _gather(
-        _gather(layer.linear.weight, 1, inner_group), 0, outer_group
+        _gather(layer.weight, 1, inner_group), 0, outer_group
     )
     layer_sequential.weight.copy_(weight_sequential)
     Y_sequential = layer_sequential(X)
@@ -51,7 +51,8 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
 @pytest.mark.mpi
 @pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)])
 @pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
-def test_bw_pass(G_intra_r, G_intra_c, B, H):
+@pytest.mark.parametrize("async_comm_in_backward_pass", [True, False])
+def test_bw_pass(G_intra_r, G_intra_c, B, H, async_comm_in_backward_pass):
     # These tests are in fp-32
     torch.manual_seed(42)
     ax.init(
@@ -68,7 +69,10 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
 
     # parallel backward pass
     layer = Tensor_Parallel_Linear(
-        in_features=H, out_features=H, skip_bias_add=True
+        in_features=H,
+        out_features=H,
+        skip_bias_add=True,
+        async_comm_in_backward_pass=async_comm_in_backward_pass,
     ).cuda()
     X_local = (
         _drop(X, 1, inner_group).detach().clone()
@@ -82,7 +86,7 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
     layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=False).cuda()
     with torch.no_grad():
         weight_sequential = _gather(
-            _gather(layer.linear.weight, 1, inner_group), 0, outer_group
+            _gather(layer.weight, 1, inner_group), 0, outer_group
         )
         layer_sequential.weight.copy_(weight_sequential)
     X.requires_grad = True
@@ -95,7 +99,7 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
     ), "BW Pass - gradients of input do not match"
 
     weight_grad_parallel = _gather(
-        _gather(layer.linear.weight.grad, 1, inner_group), 0, outer_group
+        _gather(layer.weight.grad, 1, inner_group), 0, outer_group
     )
    assert torch.allclose(
         weight_grad_parallel, layer_sequential.weight.grad
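
Note on the technique: the heart of this change is the communication/computation overlap in AsyncLinear.backward: the all-reduce on grad_input is launched asynchronously and the purely local grad_weight GEMM runs while that collective is in flight, with handle.wait() called only afterwards. Below is a minimal standalone sketch of the same pattern using only torch and torch.distributed; the function name and tensor shapes are illustrative, and it assumes a process group has already been initialized (e.g. under torchrun).

# Minimal sketch of the overlap pattern used by AsyncLinear.backward.
# Assumes torch.distributed is already initialized; names and shapes are
# illustrative only.
import torch
import torch.distributed as dist


def overlapped_linear_backward(grad_output, input_, weight, group, async_comm=True):
    # grad w.r.t. the input needs an all-reduce across the tensor-parallel group;
    # launch it asynchronously so the weight-gradient GEMM can run underneath it
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=async_comm)

    # grad w.r.t. the weight is purely local, so it overlaps with the collective
    grad_weight = (
        grad_output.view(-1, grad_output.shape[-1])
        .t()
        .mm(input_.view(-1, input_.shape[-1]))
    )

    if async_comm:
        handle.wait()  # grad_input must be fully reduced before it is returned
    return grad_input, grad_weight

Calling wait() only after the GEMM is what hides the all-reduce latency; with async_comm=False the collective blocks before the GEMM starts, which matches the behavior selected by async_comm_in_backward_pass=False in this diff.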
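
Usage note: a hypothetical sketch of the new async_comm_in_backward_pass flag, mirroring the updated test_bw_pass. The import path of Tensor_Parallel_Linear and the keyword arguments to ax.init are assumptions inferred from the test file rather than shown in this diff, and the parallel configuration and tensor sizes are illustrative; run with one process per GPU.

# Hypothetical usage sketch; import path and ax.init keywords are assumed.
import torch
from axonn import axonn as ax
from axonn.intra_layer import Tensor_Parallel_Linear  # assumed import path

ax.init(G_data=1, G_inter=1, G_intra_r=2, G_intra_c=1)  # assumed kwargs: 2-way tensor parallelism

layer = Tensor_Parallel_Linear(
    in_features=128,
    out_features=128,
    skip_bias_add=True,
    # True (the default) overlaps the grad_input all-reduce with the
    # grad_weight GEMM in backward; False uses blocking communication.
    async_comm_in_backward_pass=True,
).cuda()

X = torch.randn(16, 128, device="cuda", requires_grad=True)
Y, bias = layer(X)   # skip_bias_add=True makes forward return (output, bias)
Y.sum().backward()   # gradients flow through AsyncLinear.backward
print(layer.weight.grad.shape)  # gradient of the local weight shard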