Commit 8211724

ILP Conv Layer: Added rmse matching in unit tests; fixed gather output; added bias and easy_tp tests in conv
prajwal1210 committed Jan 9, 2024
1 parent 620127f commit 8211724
Showing 2 changed files with 80 additions and 39 deletions.
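
For context, the "rmse matching" in the commit message replaces strict elementwise comparison with a root-mean-square-error check against a fixed epsilon. A minimal standalone sketch of that criterion, mirroring the norm_allclose helper added in the test file below (the name rmse_close is illustrative, not part of the commit):

    import math
    import torch

    def rmse_close(X: torch.Tensor, Y: torch.Tensor, epsilon: float = 1e-6) -> bool:
        # Root-mean-square error over all elements; tolerant of the small
        # reduction-order differences that a strict elementwise allclose can flag.
        rmse = math.sqrt(torch.mean(torch.square(X - Y)).item())
        return rmse < epsilon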
5 changes: 2 additions & 3 deletions axonn/intra_layer/conv.py
@@ -5,7 +5,6 @@
from .communication import ForwardAllReduce, BackwardAllReduce, Drop, Gather, ForwardGather_BackwardReduceScatter
from .utils import divide, default_init_method


@torch.no_grad()
def initialize_params(
out_channels, in_channels, kernel_size, outer_group, inner_group, depth_group, init_method, init_device="cuda"
@@ -123,9 +122,9 @@ def forward(

if gather_output:
# Gather input across the in_channels dimension on the inner_group
x = Gather.apply(x, self.inner_group, 1)
h = Gather.apply(h, self.outer_group, 1)
# Gather input across the batch dimension on the depth_group
x = Gather.apply(x, self.depth_group, 0)
h = Gather.apply(h, self.depth_group, 0)

if self.bias is None:
return h
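The test changes below parametrize each test over easy_tp, exercising both ways of calling Tensor_Parallel_Conv2d. A rough sketch of the two call paths, assuming a layer and inputs set up as in the tests (forward_both_ways is an illustrative helper, not part of the commit):

    import torch
    from axonn.intra_layer import Tensor_Parallel_Conv2d

    def forward_both_ways(layer: Tensor_Parallel_Conv2d, X_full: torch.Tensor, X_local: torch.Tensor):
        # Manual path: the caller has already sharded the input (channels on the
        # inner group, batch on the depth group) and gathers the output itself.
        Y_local = layer(X_local, scatter_input=False, gather_output=False)

        # "easy_tp" path: the layer scatters the full input and gathers the full
        # output, so the caller only handles unsharded tensors.
        Y_full = layer(X_full, scatter_input=True, gather_output=True)
        return Y_local, Y_full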
114 changes: 78 additions & 36 deletions axonn/tests/test_intra_layer_conv.py
@@ -3,20 +3,38 @@
from axonn import axonn as ax
from axonn.intra_layer.communication import _drop, _gather
from axonn.intra_layer import Tensor_Parallel_Conv2d, sync_gradients
import math
import torch.distributed as dist

def log_dist(msg, ranks=[]):
assert dist.is_initialized()
if dist.get_rank() in ranks:
print(f"Rank {dist.get_rank()} : {msg}")

def norm_allclose(X, Y):
epsilon = 1e-6
squared_diff = torch.square(X - Y)
mse = torch.mean(squared_diff).item()
rmse = math.sqrt(mse)

log_dist(f"RMSE:{rmse}", [0])
log_dist(f"L2Norm:{torch.norm(X - Y, 2)}", [0])

if rmse < epsilon:
return True
else:
return False

@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("B", [2, 4, 16])
@pytest.mark.parametrize("G_intra_r, G_intra_c, G_intra_d", [(1, 2, 1), (2, 1, 1), (1, 1, 2)])
def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C):
@pytest.mark.parametrize("easy_tp", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
# These tests are in fp-32
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Need to remove all non-determinism from convolutions
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
@@ -38,34 +56,42 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C):
outer_group = ax.comm_handle.outer_intra_layer_parallel_group
depth_group = ax.comm_handle.depth_intra_layer_parallel_group

X_local = _drop(
X, 1, inner_group
) # divide channels of X along the inner tensor group
X_local = _drop(
X_local, 0, depth_group
) # divide input channels of X along the depth tensor group
if not easy_tp:
X_local = _drop(
X, 1, inner_group
) # divide channels of X along the inner tensor group
X_local = _drop(
X_local, 0, depth_group
) # divide input channels of X along the depth tensor group
else:
X_local = X

layer = Tensor_Parallel_Conv2d(
in_channels=C, out_channels=2 * C, kernel_size=5, skip_bias_add=True
in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias
).cuda()

with torch.no_grad():
# parallel FW pass
Y_local, _ = layer(X_local, scatter_input=False, gather_output=False)
Y_parallel = _gather(Y_local.clone(), 1, outer_group)
Y_parallel = _gather(Y_parallel.clone(), 0, depth_group)
Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
if not easy_tp:
Y_parallel = _gather(Y_local.clone(), 1, outer_group)
Y_parallel = _gather(Y_parallel.clone(), 0, depth_group)
else:
Y_parallel = Y_local

# sequential FW pass
layer_sequential = torch.nn.Conv2d(
in_channels=C,
out_channels=C * 2,
kernel_size=5,
bias=False,
bias=bias,
).cuda()
weight_sequential = _gather(
_gather(_gather(layer.weight, 0, depth_group).reshape(layer.local_out_channels, layer.local_in_channels, layer.kernel_size, layer.kernel_size), 1, inner_group), 0, outer_group
)
layer_sequential.weight.copy_(weight_sequential)
if bias:
layer_sequential.bias.zero_()
Y_sequential = layer_sequential(X)

assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"
@@ -75,10 +101,13 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C):
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("B", [2, 4, 16])
@pytest.mark.parametrize("G_intra_r, G_intra_c, G_intra_d", [(1, 2, 1), (2, 1, 1), (1, 1, 2)])
def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C):
@pytest.mark.parametrize("easy_tp", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
# These tests are in fp-32
# Need to remove all non-determinism from convolutions
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
@@ -101,60 +130,73 @@ def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C):

# parallel backward pass
layer = Tensor_Parallel_Conv2d(
in_channels=C, out_channels=2 * C, kernel_size=5, bias=False
in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias
).cuda()

X_local = (
_drop(X, 1, inner_group).detach().clone()
) # divide input channels of X along the inner tensor group
X_local = (
_drop(X_local, 0, depth_group).detach().clone()
) # divide input channels of X along the depth tensor group
if not easy_tp:
X_local = (
_drop(X, 1, inner_group).detach().clone()
) # divide input channels of X along the inner tensor group
X_local = (
_drop(X_local, 0, depth_group).detach().clone()
) # divide input channels of X along the depth tensor group
else:
X_local = X

X_local.requires_grad = True
Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone()
Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone()
if not easy_tp:
Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone()
Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone()
else:
Y_local_grad = Y_grad


Y_local = layer(X_local, scatter_input=False, gather_output=False)
Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
Y_local.backward(Y_local_grad)

sync_gradients(layer)
if not easy_tp:
sync_gradients(layer)

# sequential backward pass
layer_sequential = torch.nn.Conv2d(
in_channels=C,
out_channels=C * 2,
kernel_size=5,
bias=False,
bias=bias,
).cuda()
with torch.no_grad():
weight_sequential = _gather(
_gather(_gather(layer.weight, 0, depth_group).reshape(layer.local_out_channels, layer.local_in_channels, layer.kernel_size, layer.kernel_size), 1, inner_group), 0, outer_group
)
layer_sequential.weight.copy_(weight_sequential)
if bias:
layer_sequential.bias.zero_()
X.requires_grad = True
Y_sequential = layer_sequential(X)
Y_sequential.backward(Y_grad)

X_grad_parallel = _gather(X_local.grad, 0, depth_group)
X_grad_parallel = _gather(X_grad_parallel, 1, inner_group)
if not easy_tp:
X_grad_parallel = _gather(X_local.grad, 0, depth_group)
X_grad_parallel = _gather(X_grad_parallel, 1, inner_group)
else:
X_grad_parallel = X_local.grad

assert torch.allclose(
assert norm_allclose(
X_grad_parallel, X.grad
), "BW Pass - gradients of input do not match"
), f"BW Pass - gradients of input do not match"

weight_grad_parallel = _gather(
_gather(_gather(layer.weight.grad, 0, depth_group).reshape(layer.local_out_channels, layer.local_in_channels, layer.kernel_size, layer.kernel_size), 1, inner_group), 0, outer_group
)

#log_dist(weight_grad_parallel.size(), [0])
#log_dist(layer_sequential.weight.grad.size(), [0])

#log_dist(weight_grad_parallel - layer_sequential.weight.grad, [0])

assert torch.allclose(
assert norm_allclose(
weight_grad_parallel, layer_sequential.weight.grad
), "BW Pass - gradients of weight do not match"

if bias:
bias_grad_parallel = _gather(layer.bias.grad, 0, outer_group)
assert norm_allclose(
bias_grad_parallel, layer_sequential.bias.grad
), f"BW Pass - gradients of bias do not match"


#test_bw_pass(1, 1, 2, 4, 64, 64, 8)
