ILP Conv Layer : CI added
prajwal1210 committed Oct 23, 2023
1 parent dcf5272 commit 81ab602
Showing 2 changed files with 131 additions and 4 deletions.
7 changes: 3 additions & 4 deletions axonn/intra_layer/conv.py
@@ -62,17 +62,16 @@ def __init__(
         )
         self.conv.weight.data.copy_(initial_params)
 
-        self.skip_bias_add = skip_bias_add
+        self.bias = torch.nn.Parameter(torch.zeros(self.out_channels))
 
-        if not self.skip_bias_add:
-            self.bias = torch.nn.Parameter(torch.zeros(self.out_channels))
+        self.skip_bias_add = skip_bias_add
 
     def forward(self, x):
         x = BackwardAllReduce.apply(x, self.outer_group)
         h = self.conv(x)
         h = ForwardAllReduce.apply(h, self.inner_group)
         if self.skip_bias_add:
-            return h
+            return h, self.bias
         else:
             return h + self.bias.view(1, -1, 1, 1)
         return h
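
With this change the bias parameter is always created in __init__, and skip_bias_add only controls whether forward adds it to the output or returns it alongside the output for the caller to apply later. A minimal usage sketch of the two conventions (illustrative only: the shapes are hypothetical, skip_bias_add is passed explicitly, and AxoNN is assumed to be initialized with ax.init as in the tests below):

# Illustrative sketch, not part of the commit; assumes ax.init(...) has already
# set up the intra-layer process groups as in the tests below.
import torch
from axonn.intra_layer import Tensor_Parallel_Conv2d

x = torch.randn(1, 4, 64, 64).cuda()

# skip_bias_add=False: the bias is added inside forward
layer = Tensor_Parallel_Conv2d(
    in_channels=4, out_channels=8, kernel_size=5, skip_bias_add=False
).cuda()
y = layer(x)

# skip_bias_add=True: forward returns (output, bias) so the caller can fuse the
# bias add with a later operation, or ignore it as the tests below do
layer = Tensor_Parallel_Conv2d(
    in_channels=4, out_channels=8, kernel_size=5, skip_bias_add=True
).cuda()
y, bias = layer(x)
y = y + bias.view(1, -1, 1, 1)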
128 changes: 128 additions & 0 deletions axonn/tests/test_conv_layer.py
@@ -0,0 +1,128 @@
import torch
import pytest
from axonn import axonn as ax
from axonn.intra_layer.communication import _drop, _gather
from axonn.intra_layer import Tensor_Parallel_Conv2d, drop, gather

torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False


@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
def test_fw_pass(G_intra_r, G_intra_c, H, W, C):
    # These tests are in fp-32
    torch.manual_seed(42)
    ax.init(
        G_data=1,
        G_inter=1,
        G_intra_r=G_intra_r,
        G_intra_c=G_intra_c,
    )

    X = torch.randn(1, C, H, W).cuda() * 0.01

    inner_group = ax.comm_handle.inner_intra_layer_parallel_group
    outer_group = ax.comm_handle.outer_intra_layer_parallel_group

    X_local = _drop(
        X, 1, inner_group
    )  # divide channels of X along the inner tensor group
    layer = Tensor_Parallel_Conv2d(
        in_channels=C,
        out_channels=2 * C,
        kernel_size=5,
        skip_bias_add=True,
    ).cuda()

    with torch.no_grad():
        # parallel FW pass
        Y_local, _ = layer(X_local)
        Y_parallel = _gather(Y_local.clone(), 1, outer_group)

        # sequential FW pass
        layer_sequential = torch.nn.Conv2d(
            in_channels=C,
            out_channels=C * 2,
            kernel_size=5,
            bias=False,
        ).cuda()
        weight_sequential = _gather(
            _gather(layer.conv.weight, 1, inner_group), 0, outer_group
        )
        layer_sequential.weight.copy_(weight_sequential)
        Y_sequential = layer_sequential(X)

    assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"

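Both tests shard tensors with _drop and reassemble them with _gather from axonn.intra_layer.communication. A rough sketch of the semantics the tests rely on, written with plain torch.distributed calls (illustrative only, not AxoNN's actual implementation): _drop keeps only the calling rank's chunk of the given dimension within a process group, and _gather all-gathers the chunks back along that dimension.

# Illustrative sketch only -- not AxoNN's implementation of _drop/_gather.
import torch
import torch.distributed as dist

def drop_sketch(x, dim, group):
    # keep this rank's slice of dimension `dim`
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    return torch.chunk(x, world_size, dim=dim)[rank].contiguous()

def gather_sketch(x, dim, group):
    # all-gather every rank's slice and stitch them back together along `dim`
    world_size = dist.get_world_size(group)
    parts = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(parts, x.contiguous(), group=group)
    return torch.cat(parts, dim=dim)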

@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
    # These tests are in fp-32
    torch.manual_seed(42)
    ax.init(
        G_data=1,
        G_inter=1,
        G_intra_r=G_intra_r,
        G_intra_c=G_intra_c,
    )
    X = torch.randn(1, C, H, W).cuda() * 0.01
    # a 5x5 kernel with no padding shrinks each spatial dim by 4, hence H - 4, W - 4
    Y_grad = torch.randn(1, 2 * C, H - 4, W - 4).cuda() * 0.01

    inner_group = ax.comm_handle.inner_intra_layer_parallel_group
    outer_group = ax.comm_handle.outer_intra_layer_parallel_group

    # parallel backward pass
    layer = Tensor_Parallel_Conv2d(
        in_channels=C,
        out_channels=2 * C,
        kernel_size=5,
        skip_bias_add=True,
    ).cuda()
    X_local = (
        _drop(X, 1, inner_group).detach().clone()
    )  # divide input channels of X along the inner tensor group
    X_local.requires_grad = True
    Y_local, _ = layer(X_local)
    Y_local_grad = _drop(Y_grad, 1, outer_group)
    Y_local.backward(Y_local_grad)

    # sequential backward pass
    layer_sequential = torch.nn.Conv2d(
        in_channels=C,
        out_channels=C * 2,
        kernel_size=5,
        bias=False,
    ).cuda()
    with torch.no_grad():
        weight_sequential = _gather(
            _gather(layer.conv.weight, 1, inner_group), 0, outer_group
        )
        layer_sequential.weight.copy_(weight_sequential)
    X.requires_grad = True
    Y_sequential = layer_sequential(X)
    Y_sequential.backward(Y_grad)

    X_grad_parallel = _gather(X_local.grad, 1, inner_group)

    assert torch.allclose(
        X_grad_parallel, X.grad
    ), "BW Pass - gradients of input do not match"

    weight_grad_parallel = _gather(
        _gather(layer.conv.weight.grad, 1, inner_group), 0, outer_group
    )
    assert torch.allclose(
        weight_grad_parallel, layer_sequential.weight.grad
    ), "BW Pass - gradients of weight do not match"


if __name__ == "__main__":
    # run one configuration directly (expects the matching number of MPI ranks/GPUs)
    test_bw_pass(1, 2, 64, 64, 4)
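
The property both tests exercise is that a convolution is linear in its input channels: if the weight is split along in_channels across the inner group and along out_channels across the outer group, each rank's partial output summed over the inner group (the ForwardAllReduce in conv.py) equals the full convolution. A single-process sketch of that identity, with hypothetical shapes and no MPI required:

# Single-process sanity sketch of the identity the parallel layer relies on.
import torch
import torch.nn.functional as F

C, H, W, K = 4, 64, 64, 5
x = torch.randn(1, C, H, W)
w = torch.randn(2 * C, C, K, K)  # full weight: (out_channels, in_channels, kH, kW)

full = F.conv2d(x, w)

# split the input channels (an "inner group" of size 2) and sum the partial
# outputs, mirroring the all-reduce over the inner group in the forward pass
x0, x1 = torch.chunk(x, 2, dim=1)
w0, w1 = torch.chunk(w, 2, dim=1)
partial_sum = F.conv2d(x0, w0) + F.conv2d(x1, w1)

print(torch.allclose(full, partial_sum, atol=1e-6))  # expected: True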
