From 89107ab02b79e4ffd463e91b760ca49b6fef0b36 Mon Sep 17 00:00:00 2001 From: Prajwal Singhania Date: Fri, 6 Oct 2023 12:55:23 -0700 Subject: [PATCH 1/5] ILP Conv Layer support --- axonn/intra_layer/__init__.py | 9 ++++---- axonn/intra_layer/communication.py | 22 +++++++++--------- axonn/intra_layer/conv.py | 36 ++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 14 deletions(-) create mode 100644 axonn/intra_layer/conv.py diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py index 57efc32..5a71fcf 100644 --- a/axonn/intra_layer/__init__.py +++ b/axonn/intra_layer/__init__.py @@ -1,20 +1,21 @@ from .fully_connected import Linear as Tensor_Parallel_Linear # noqa: F401 +from .conv import Conv2d as Tensor_Parallel_Conv2d # noqa: F401 from .communication import Drop, Gather from axonn import axonn as ax -def drop(x, transpose=False): +def drop(x, transpose=False, dim=-1): if not transpose: group = ax.comm_handle.inner_intra_layer_parallel_group else: group = ax.comm_handle.outer_intra_layer_parallel_group - return Drop.apply(x, group) + return Drop.apply(x, group, dim) -def gather(x, transpose=False): +def gather(x, transpose=False, dim=-1): if not transpose: group = ax.comm_handle.inner_intra_layer_parallel_group else: group = ax.comm_handle.outer_intra_layer_parallel_group - return Gather.apply(x, group) + return Gather.apply(x, group, dim) diff --git a/axonn/intra_layer/communication.py b/axonn/intra_layer/communication.py index 0d3977c..21fb103 100644 --- a/axonn/intra_layer/communication.py +++ b/axonn/intra_layer/communication.py @@ -73,29 +73,31 @@ def backward(ctx, grad_output): class Drop(torch.autograd.Function): @staticmethod - def symbolic(graph, input_, process_group=None): - return _drop(input_, dim=-1, process_group=process_group) + def symbolic(graph, input_, process_group=None, dim=-1) : + return _drop(input_, dim=dim, process_group=process_group) @staticmethod - def forward(ctx, input_, process_group=None): + def forward(ctx, input_, process_group=None, dim=-1): ctx.process_group = process_group - return _drop(input_, dim=-1, process_group=process_group) + ctx.dim = dim + return _drop(input_, dim=dim, process_group=process_group) @staticmethod def backward(ctx, grad_output): - return _gather(grad_output, dim=-1, process_group=ctx.process_group), None + return _gather(grad_output, dim=ctx.dim, process_group=ctx.process_group), None, None class Gather(torch.autograd.Function): @staticmethod - def symbolic(graph, input_, process_group=None): - return _gather(input_, dim=-1, process_group=process_group) + def symbolic(graph, input_, process_group=None, dim=-1): + return _gather(input_, dim=dim, process_group=process_group) @staticmethod - def forward(ctx, input_, process_group=None): + def forward(ctx, input_, process_group=None, dim=-1,): ctx.process_group = process_group - return _gather(input_, dim=-1, process_group=process_group) + ctx.dim = dim + return _gather(input_, dim=dim, process_group=process_group) @staticmethod def backward(ctx, grad_output): - return _drop(grad_output, dim=-1, process_group=ctx.process_group), None + return _drop(grad_output, dim=ctx.dim, process_group=ctx.process_group), None, None diff --git a/axonn/intra_layer/conv.py b/axonn/intra_layer/conv.py new file mode 100644 index 0000000..6bdd12e --- /dev/null +++ b/axonn/intra_layer/conv.py @@ -0,0 +1,36 @@ +from axonn import axonn as ax +import torch.distributed as dist +import torch +from .communication import ForwardAllReduce, BackwardAllReduce, Drop + +class 
Conv2d(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, *args, transpose=False, **kwargs): + super(Conv2d, self).__init__() + self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group + self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group + + if transpose: + ordered_groups = [self.outer_group, self.inner_group] + else: + ordered_groups = [self.inner_group, self.outer_group] + + self.group_sizes = [dist.get_world_size(group=group) for group in ordered_groups] + self.ordered_groups = ordered_groups + self.in_channels, self.out_channels = in_channels, out_channels + + + assert in_channels % self.group_sizes[0] == 0 + assert out_channels % self.group_sizes[1] == 0 + + self.conv = torch.nn.Conv2d( + in_channels=in_channels // self.group_sizes[0], + out_channels=out_channels // self.group_sizes[1], + kernel_size=kernel_size, + **kwargs) + + + def forward(self, x): + x = BackwardAllReduce.apply(x, self.ordered_groups[1]) + h = self.conv(x) + h = ForwardAllReduce.apply(h, self.ordered_groups[0]) + return h From 950d4cfd1aa85140970d799ded5b2cf372a5e791 Mon Sep 17 00:00:00 2001 From: Prajwal Singhania Date: Mon, 9 Oct 2023 19:17:14 -0700 Subject: [PATCH 2/5] ILP Conv Layer: Rebased to develop; Added skip_bias_add and init_method params; Black linting --- axonn/intra_layer/__init__.py | 2 +- axonn/intra_layer/communication.py | 16 +++++-- axonn/intra_layer/conv.py | 76 +++++++++++++++++++++++------- axonn/intra_layer/utils.py | 3 ++ 4 files changed, 75 insertions(+), 22 deletions(-) create mode 100644 axonn/intra_layer/utils.py diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py index 5a71fcf..7180e9e 100644 --- a/axonn/intra_layer/__init__.py +++ b/axonn/intra_layer/__init__.py @@ -1,5 +1,5 @@ from .fully_connected import Linear as Tensor_Parallel_Linear # noqa: F401 -from .conv import Conv2d as Tensor_Parallel_Conv2d # noqa: F401 +from .conv import Conv2d as Tensor_Parallel_Conv2d # noqa: F401 from .communication import Drop, Gather from axonn import axonn as ax diff --git a/axonn/intra_layer/communication.py b/axonn/intra_layer/communication.py index 21fb103..62e765c 100644 --- a/axonn/intra_layer/communication.py +++ b/axonn/intra_layer/communication.py @@ -73,7 +73,7 @@ def backward(ctx, grad_output): class Drop(torch.autograd.Function): @staticmethod - def symbolic(graph, input_, process_group=None, dim=-1) : + def symbolic(graph, input_, process_group=None, dim=-1): return _drop(input_, dim=dim, process_group=process_group) @staticmethod @@ -84,7 +84,11 @@ def forward(ctx, input_, process_group=None, dim=-1): @staticmethod def backward(ctx, grad_output): - return _gather(grad_output, dim=ctx.dim, process_group=ctx.process_group), None, None + return ( + _gather(grad_output, dim=ctx.dim, process_group=ctx.process_group), + None, + None, + ) class Gather(torch.autograd.Function): @@ -93,11 +97,15 @@ def symbolic(graph, input_, process_group=None, dim=-1): return _gather(input_, dim=dim, process_group=process_group) @staticmethod - def forward(ctx, input_, process_group=None, dim=-1,): + def forward(ctx, input_, process_group=None, dim=-1): ctx.process_group = process_group ctx.dim = dim return _gather(input_, dim=dim, process_group=process_group) @staticmethod def backward(ctx, grad_output): - return _drop(grad_output, dim=ctx.dim, process_group=ctx.process_group), None, None + return ( + _drop(grad_output, dim=ctx.dim, process_group=ctx.process_group), + None, + None, + ) diff --git a/axonn/intra_layer/conv.py 
b/axonn/intra_layer/conv.py index 6bdd12e..1b01e8c 100644 --- a/axonn/intra_layer/conv.py +++ b/axonn/intra_layer/conv.py @@ -2,35 +2,77 @@ import torch.distributed as dist import torch from .communication import ForwardAllReduce, BackwardAllReduce, Drop +from .utils import divide + + +@torch.no_grad() +def initialize_params( + out_channels, in_channels, kernel_size, outer_group, inner_group, init_method +): + params = torch.empty((out_channels, in_channels, kernel_size, kernel_size)) + init_method(params) + params = Drop.apply(params, outer_group, 0) + params = Drop.apply(params, inner_group, 1) + return params + class Conv2d(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, *args, transpose=False, **kwargs): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + *args, + transpose=False, + skip_bias_add=False, + init_method=None, + **kwargs + ): super(Conv2d, self).__init__() - self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group - self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group - if transpose: - ordered_groups = [self.outer_group, self.inner_group] + if not transpose: + self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group + self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group else: - ordered_groups = [self.inner_group, self.outer_group] + self.outer_group = ax.comm_handle.inner_intra_layer_parallel_group + self.inner_group = ax.comm_handle.outer_intra_layer_parallel_group - self.group_sizes = [dist.get_world_size(group=group) for group in ordered_groups] - self.ordered_groups = ordered_groups - self.in_channels, self.out_channels = in_channels, out_channels - + self.inner_group_size = dist.get_world_size(self.inner_group) + self.outer_group_size = dist.get_world_size(self.outer_group) - assert in_channels % self.group_sizes[0] == 0 - assert out_channels % self.group_sizes[1] == 0 + self.in_channels = divide(in_channels, self.inner_group_size) + self.out_channels = divide(out_channels, self.outer_group_size) self.conv = torch.nn.Conv2d( - in_channels=in_channels // self.group_sizes[0], - out_channels=out_channels // self.group_sizes[1], + in_channels=self.in_channels, + out_channels=self.out_channels, kernel_size=kernel_size, - **kwargs) + bias=False, + **kwargs + ) + if init_method: + initial_params = initialize_params( + out_channels, + in_channels, + kernel_size, + self.outer_group, + self.inner_group, + init_method, + ) + self.conv.weight.data.copy_(initial_params) + + self.skip_bias_add = skip_bias_add + + if not self.skip_bias_add: + self.bias = torch.nn.Parameter(torch.zeros(self.out_channels)) def forward(self, x): - x = BackwardAllReduce.apply(x, self.ordered_groups[1]) + x = BackwardAllReduce.apply(x, self.outer_group) h = self.conv(x) - h = ForwardAllReduce.apply(h, self.ordered_groups[0]) + h = ForwardAllReduce.apply(h, self.inner_group) + if self.skip_bias_add: + return h + else: + return h + self.bias.view(1, -1, 1, 1) return h diff --git a/axonn/intra_layer/utils.py b/axonn/intra_layer/utils.py new file mode 100644 index 0000000..b890b40 --- /dev/null +++ b/axonn/intra_layer/utils.py @@ -0,0 +1,3 @@ +def divide(a, b): + assert a % b == 0 + return a // b From 81ab602d66418ecae55c45f31b737eebc95fb7ef Mon Sep 17 00:00:00 2001 From: Prajwal Singhania Date: Mon, 23 Oct 2023 09:30:10 -0700 Subject: [PATCH 3/5] ILP Conv Layer : CI added --- axonn/intra_layer/conv.py | 7 +- axonn/tests/test_conv_layer.py | 128 +++++++++++++++++++++++++++++++++ 2 files changed, 131 
insertions(+), 4 deletions(-) create mode 100644 axonn/tests/test_conv_layer.py diff --git a/axonn/intra_layer/conv.py b/axonn/intra_layer/conv.py index 1b01e8c..682d666 100644 --- a/axonn/intra_layer/conv.py +++ b/axonn/intra_layer/conv.py @@ -62,17 +62,16 @@ def __init__( ) self.conv.weight.data.copy_(initial_params) - self.skip_bias_add = skip_bias_add + self.bias = torch.nn.Parameter(torch.zeros(self.out_channels)) - if not self.skip_bias_add: - self.bias = torch.nn.Parameter(torch.zeros(self.out_channels)) + self.skip_bias_add = skip_bias_add def forward(self, x): x = BackwardAllReduce.apply(x, self.outer_group) h = self.conv(x) h = ForwardAllReduce.apply(h, self.inner_group) if self.skip_bias_add: - return h + return h, self.bias else: return h + self.bias.view(1, -1, 1, 1) return h diff --git a/axonn/tests/test_conv_layer.py b/axonn/tests/test_conv_layer.py new file mode 100644 index 0000000..93755e0 --- /dev/null +++ b/axonn/tests/test_conv_layer.py @@ -0,0 +1,128 @@ +import torch +import pytest +from axonn import axonn as ax +from axonn.intra_layer.communication import _drop, _gather +from axonn.intra_layer import Tensor_Parallel_Conv2d, drop, gather + +torch.use_deterministic_algorithms(True) +torch.backends.cudnn.benchmark = False + + +@pytest.mark.mpi +@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)]) +@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)]) +def test_fw_pass(G_intra_r, G_intra_c, H, W, C): + # These tests are in fp-32 + torch.manual_seed(42) + ax.init( + G_data=1, + G_inter=1, + G_intra_r=G_intra_r, + G_intra_c=G_intra_c, + ) + + X = torch.randn(1, C, H, W).cuda() * 0.01 + + inner_group = ax.comm_handle.inner_intra_layer_parallel_group + outer_group = ax.comm_handle.outer_intra_layer_parallel_group + + X_local = _drop( + X, 1, inner_group + ) # divide channels of X along the inner tensor group + layer = Tensor_Parallel_Conv2d( + in_channels=C, + out_channels=2 * C, + kernel_size=5, + skip_bias_add=True + ).cuda() + + with torch.no_grad(): + # parallel FW pass + Y_local, _ = layer(X_local) + Y_parallel = _gather(Y_local.clone(), 1, outer_group) + + # sequential FW pass + layer_sequential = torch.nn.Conv2d( + in_channels=C, + out_channels=C * 2, + kernel_size=5, + bias=False, + ).cuda() + weight_sequential = _gather( + _gather(layer.conv.weight, 1, inner_group), 0, outer_group + ) + layer_sequential.weight.copy_(weight_sequential) + Y_sequential = layer_sequential(X) + + assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match" + + +@pytest.mark.mpi +@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)]) +@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)]) +def test_bw_pass(G_intra_r, G_intra_c, H, W, C): + # These tests are in fp-32 + torch.manual_seed(42) + ax.init( + G_data=1, + G_inter=1, + G_intra_r=G_intra_r, + G_intra_c=G_intra_c, + ) + X = torch.randn(1, C, H, W).cuda() * 0.01 + Y_grad = torch.randn(1, 2 * C, H - 4, W - 4).cuda() * 0.01 + + inner_group = ax.comm_handle.inner_intra_layer_parallel_group + outer_group = ax.comm_handle.outer_intra_layer_parallel_group + + # parallel backward pass + layer = Tensor_Parallel_Conv2d( + in_channels= C, + out_channels=2 * C, + kernel_size=5, + skip_bias_add=True + ).cuda() + X_local = ( + _drop(X, 1, inner_group).detach().clone() + ) # divide input channels of X along the inner tensor group + X_local.requires_grad = True + Y_local, _ = layer(X_local) + Y_local_grad = _drop(Y_grad, 1, outer_group) + 
Y_local.backward(Y_local_grad) + + # sequential backward pass + layer_sequential = torch.nn.Conv2d( + in_channels=C, + out_channels=C * 2, + kernel_size=5, + bias=False, + ).cuda() + with torch.no_grad(): + weight_sequential = _gather( + _gather(layer.conv.weight, 1, inner_group), 0, outer_group + ) + layer_sequential.weight.copy_(weight_sequential) + X.requires_grad = True + Y_sequential = layer_sequential(X) + Y_sequential.backward(Y_grad) + + X_grad_parallel = _gather(X_local.grad, 1, inner_group) + torch.set_printoptions(threshold=10000) + #print (X_grad_parallel) + #print (X.grad) + + #print (torch.allclose( + # X_grad_parallel, X.grad)) + + assert torch.allclose( + X_grad_parallel, X.grad + ), "BW Pass - gradients of input do not match" + + weight_grad_parallel = _gather( + _gather(layer.conv.weight.grad, 1, inner_group), 0, outer_group + ) + assert torch.allclose( + weight_grad_parallel, layer_sequential.weight.grad + ), "BW Pass - gradients of weight do not match" + +test_bw_pass(1,2,64,64,4) From 5944083cbe7552bc6fc4ee51b2c2b8cf7920ec6c Mon Sep 17 00:00:00 2001 From: Prajwal Singhania Date: Mon, 23 Oct 2023 16:45:16 -0400 Subject: [PATCH 4/5] ILP Conv Layer: Renamed test_conv_layer.py to test_intra_layer_conv.py; Fixed test --- ...conv_layer.py => test_intra_layer_conv.py} | 39 +++++++++---------- 1 file changed, 18 insertions(+), 21 deletions(-) rename axonn/tests/{test_conv_layer.py => test_intra_layer_conv.py} (78%) diff --git a/axonn/tests/test_conv_layer.py b/axonn/tests/test_intra_layer_conv.py similarity index 78% rename from axonn/tests/test_conv_layer.py rename to axonn/tests/test_intra_layer_conv.py index 93755e0..99423f8 100644 --- a/axonn/tests/test_conv_layer.py +++ b/axonn/tests/test_intra_layer_conv.py @@ -2,10 +2,7 @@ import pytest from axonn import axonn as ax from axonn.intra_layer.communication import _drop, _gather -from axonn.intra_layer import Tensor_Parallel_Conv2d, drop, gather - -torch.use_deterministic_algorithms(True) -torch.backends.cudnn.benchmark = False +from axonn.intra_layer import Tensor_Parallel_Conv2d @pytest.mark.mpi @@ -14,6 +11,13 @@ def test_fw_pass(G_intra_r, G_intra_c, H, W, C): # These tests are in fp-32 torch.manual_seed(42) + # Need to remove all non-determinism from convolutions + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + # This is required because TF32 cores only look at the first 10 bits of mantissa + torch.backends.cudnn.allow_tf32 = False + ax.init( G_data=1, G_inter=1, @@ -30,10 +34,7 @@ def test_fw_pass(G_intra_r, G_intra_c, H, W, C): X, 1, inner_group ) # divide channels of X along the inner tensor group layer = Tensor_Parallel_Conv2d( - in_channels=C, - out_channels=2 * C, - kernel_size=5, - skip_bias_add=True + in_channels=C, out_channels=2 * C, kernel_size=5, skip_bias_add=True ).cuda() with torch.no_grad(): @@ -62,7 +63,14 @@ def test_fw_pass(G_intra_r, G_intra_c, H, W, C): @pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)]) def test_bw_pass(G_intra_r, G_intra_c, H, W, C): # These tests are in fp-32 + # Need to remove all non-determinism from convolutions torch.manual_seed(42) + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + # This is required because TF32 cores only look at the first 10 bits of mantissa + torch.backends.cudnn.allow_tf32 = False + ax.init( G_data=1, G_inter=1, @@ -77,13 +85,10 @@ def test_bw_pass(G_intra_r, G_intra_c, H, W, C): # 
parallel backward pass layer = Tensor_Parallel_Conv2d( - in_channels= C, - out_channels=2 * C, - kernel_size=5, - skip_bias_add=True + in_channels=C, out_channels=2 * C, kernel_size=5, skip_bias_add=True ).cuda() X_local = ( - _drop(X, 1, inner_group).detach().clone() + _drop(X, 1, inner_group).detach().clone() ) # divide input channels of X along the inner tensor group X_local.requires_grad = True Y_local, _ = layer(X_local) @@ -107,12 +112,6 @@ def test_bw_pass(G_intra_r, G_intra_c, H, W, C): Y_sequential.backward(Y_grad) X_grad_parallel = _gather(X_local.grad, 1, inner_group) - torch.set_printoptions(threshold=10000) - #print (X_grad_parallel) - #print (X.grad) - - #print (torch.allclose( - # X_grad_parallel, X.grad)) assert torch.allclose( X_grad_parallel, X.grad @@ -124,5 +123,3 @@ def test_bw_pass(G_intra_r, G_intra_c, H, W, C): assert torch.allclose( weight_grad_parallel, layer_sequential.weight.grad ), "BW Pass - gradients of weight do not match" - -test_bw_pass(1,2,64,64,4) From ce607eb36158cb4bd70e047f1fc308d40b6d25c4 Mon Sep 17 00:00:00 2001 From: Prajwal Singhania Date: Mon, 23 Oct 2023 16:59:41 -0400 Subject: [PATCH 5/5] ILP Conv Layer: Added Unit tests to CI pipeline --- .github/workflows/nvidia-rtx-3090-tests.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/nvidia-rtx-3090-tests.yaml b/.github/workflows/nvidia-rtx-3090-tests.yaml index 266841c..a9c43a7 100644 --- a/.github/workflows/nvidia-rtx-3090-tests.yaml +++ b/.github/workflows/nvidia-rtx-3090-tests.yaml @@ -45,9 +45,12 @@ jobs: - name: Install AxoNN run: | pip install -r requirements.txt - - name: Run unit intra-layer unit tests + - name: Run intra-layer FC unit tests run: | mpirun -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_fc.py + - name: Run intra-layer Conv unit tests + run: | + mpirun -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_conv.py - name: Uninstall AxoNN run: | pip uninstall --yes axonn
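Note (illustrative, not part of the patches): the dim argument added to Drop and Gather in PATCH 1/5 lets callers shard or reassemble a tensor along any dimension, which initialize_params in conv.py uses to slice the full weight along dim 0 (output channels, outer group) and dim 1 (input channels, inner group). Below is a minimal single-process analogue of that behaviour, with torch.chunk / torch.cat standing in for the actual process-group communication; the sizes and variable names are made up for the example.

import torch

world_size = 2                       # hypothetical inner tensor-parallel group size
weight = torch.randn(16, 8, 5, 5)    # a full conv weight: (out_ch, in_ch, kH, kW)

# Drop along dim=1 keeps one input-channel shard per rank (rank r keeps chunk r),
# mirroring Drop.apply(params, inner_group, 1) in initialize_params
shards = torch.chunk(weight, world_size, dim=1)

# Gather along the same dim concatenates the shards back into the full tensor
assert torch.equal(torch.cat(shards, dim=1), weight)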
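Note (illustrative, not part of the patches): the parallel forward pass in conv.py relies on a convolution being linear in its input channels. Each rank convolves only its input-channel shard with the matching weight slice, and the partial outputs are summed, which is the sum ForwardAllReduce performs over the inner group; the output-channel split over the outer group is then recovered by a gather along the channel dimension, as the unit test does with _gather(Y_local, 1, outer_group). Below is a minimal single-process sketch of the input-channel identity, with hypothetical sizes and no distributed setup; only standard PyTorch calls are used.

import torch

torch.manual_seed(0)
C_in, C_out, K, n_ranks = 8, 16, 5, 2   # hypothetical sizes; C_in must be divisible by n_ranks

full = torch.nn.Conv2d(C_in, C_out, K, bias=False)
x = torch.randn(1, C_in, 32, 32)

with torch.no_grad():
    partial_sum = torch.zeros(1, C_out, 32 - K + 1, 32 - K + 1)
    for r in range(n_ranks):
        lo, hi = r * C_in // n_ranks, (r + 1) * C_in // n_ranks
        local = torch.nn.Conv2d(C_in // n_ranks, C_out, K, bias=False)
        # each simulated "rank" holds the weight slice for its input channels,
        # i.e. the shard that Drop along dim 1 would leave on that rank
        local.weight.copy_(full.weight[:, lo:hi])
        # partial convolution over the local input-channel shard
        partial_sum += local(x[:, lo:hi])

    # summing the partial outputs reproduces the full convolution; across ranks
    # this sum is what ForwardAllReduce provides in Conv2d.forward
    assert torch.allclose(partial_sum, full(x), atol=1e-5)

The same additivity, run in reverse through autograd, is why Conv2d.forward applies BackwardAllReduce to the input: each rank in the outer group sees the same input shard but different output channels, so their input gradients are partial terms that must be summed.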