
Commit

ILP Conv Layer: Formatting changes
prajwal1210 committed Jan 9, 2024
1 parent 8211724 commit 6d0838a
Showing 3 changed files with 105 additions and 41 deletions.
78 changes: 56 additions & 22 deletions axonn/intra_layer/conv.py
@@ -2,14 +2,30 @@
import torch.distributed as dist
import torch
import math
from .communication import ForwardAllReduce, BackwardAllReduce, Drop, Gather, ForwardGather_BackwardReduceScatter
from .utils import divide, default_init_method
from .communication import (
ForwardAllReduce,
BackwardAllReduce,
Drop,
Gather,
ForwardGather_BackwardReduceScatter,
)
from .utils import divide


@torch.no_grad()
def initialize_params(
out_channels, in_channels, kernel_size, outer_group, inner_group, depth_group, init_method, init_device="cuda"
out_channels,
in_channels,
kernel_size,
outer_group,
inner_group,
depth_group,
init_method,
init_device="cuda",
):
params = torch.empty((out_channels, in_channels, kernel_size, kernel_size), device=init_device)
params = torch.empty(
(out_channels, in_channels, kernel_size, kernel_size), device=init_device
)
init_method(params)
params = Drop.apply(params, outer_group, 0)
params = Drop.apply(params, inner_group, 1)
@@ -18,6 +34,11 @@ def initialize_params(
return params


@torch.no_grad()
def default_init_method(weight):
return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))


class Conv2d(torch.nn.Module):
def __init__(
self,
@@ -29,10 +50,10 @@ def __init__(
bias=True,
skip_bias_add=False,
init_method=None,
stride=1,
padding=0,
dilation=1,
groups=1,
stride=1,
padding=0,
dilation=1,
groups=1,
):
super(Conv2d, self).__init__()

@@ -72,11 +93,13 @@ def __init__(
setattr(
self.weight,
"process_group_for_norm_reduction",
ax.comm_handle.intra_layer_group, # What is intra_layer_group?
ax.comm_handle.intra_layer_group, # What is intra_layer_group?
)

if bias:
self.bias = torch.nn.Parameter(torch.zeros(self.local_out_channels), requires_grad=True)
self.bias = torch.nn.Parameter(
torch.zeros(self.local_out_channels), requires_grad=True
)
setattr(self.bias, "is_tensor_parallel", True)
setattr(self.bias, "needs_gradient_sync", True)
setattr(
@@ -92,44 +115,55 @@ def __init__(
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups

self.groups = groups

def forward(
self,
self,
x,
scatter_input = True,
gather_output = True,
cache_weights_in_all_gather=False
scatter_input=True,
gather_output=True,
cache_weights_in_all_gather=False,
):
# Gather weights from depth parallel group
weight = ForwardGather_BackwardReduceScatter.apply(
self.weight,
self.depth_group,
0,
).reshape(self.local_out_channels, self.local_in_channels, self.kernel_size, self.kernel_size)

).reshape(
self.local_out_channels,
self.local_in_channels,
self.kernel_size,
self.kernel_size,
)

if scatter_input:
# Drop input across the in_channels dimension on the inner_group
x = Drop.apply(x, self.inner_group, 1)
x = Drop.apply(x, self.inner_group, 1)
# Drop input across the batch dimension on the depth_group
x = Drop.apply(x, self.depth_group, 0)

x = BackwardAllReduce.apply(x, self.outer_group)
h = torch.nn.functional.conv2d(x, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
h = torch.nn.functional.conv2d(
x,
weight,
bias=None,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
h = ForwardAllReduce.apply(h, self.inner_group)

if gather_output:
# Gather input across the in_channels dimension on the inner_group
h = Gather.apply(h, self.outer_group, 1)
# Gather input across the batch dimension on the depth_group
h = Gather.apply(h, self.depth_group, 0)
h = Gather.apply(h, self.depth_group, 0)

if self.bias is None:
return h
else:
bias = self.bias # Why do we need this extra copy?
bias = self.bias # Why do we need this extra copy?
if gather_output:
bias = Gather.apply(bias, self.outer_group)

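For context, here is a hypothetical usage sketch of the tensor-parallel Conv2d defined above. It is not part of this commit: it assumes AxoNN's intra-layer process groups have already been initialized elsewhere (the initialization call is omitted), that the constructor takes the usual in_channels/out_channels/kernel_size arguments elided from this hunk, and that the class is importable from axonn.intra_layer.conv. Only the scatter_input, gather_output, and cache_weights_in_all_gather keywords appear in the diff itself.

# Hypothetical usage sketch -- not part of this commit.
# Assumes torch.distributed and AxoNN's outer/inner/depth process groups
# are already set up, and that the constructor takes the usual
# in_channels/out_channels/kernel_size arguments (elided in the hunk above).
import torch
from axonn.intra_layer.conv import Conv2d

layer = Conv2d(in_channels=8, out_channels=16, kernel_size=3, bias=True).cuda()
X = torch.randn(4, 8, 64, 64, device="cuda")  # full, unsharded input

# "Easy" tensor parallelism: the layer drops (shards) the input across the
# inner and depth groups and gathers the full output back, matching
# forward() above with scatter_input=True and gather_output=True.
Y = layer(X, scatter_input=True, gather_output=True)

# Alternatively, feed an already-sharded input and keep the output sharded:
# Y_local = layer(X_local, scatter_input=False, gather_output=False)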
7 changes: 0 additions & 7 deletions axonn/intra_layer/utils.py
@@ -1,10 +1,3 @@
import torch
import math

def divide(a, b):
assert a % b == 0
return a // b

@torch.no_grad()
def default_init_method(weight):
return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
61 changes: 49 additions & 12 deletions axonn/tests/test_intra_layer_conv.py
@@ -6,15 +6,17 @@
import math
import torch.distributed as dist


def log_dist(msg, ranks=[]):
assert dist.is_initialized()
if dist.get_rank() in ranks:
print(f"Rank {dist.get_rank()} : {msg}")


def norm_allclose(X, Y):
epsilon = 1e-6
squared_diff = torch.square(X - Y)
mse = torch.mean(squared_diff).item()
mse = torch.mean(squared_diff).item()
rmse = math.sqrt(mse)

log_dist(f"RMSE:{rmse}", [0])
@@ -25,10 +27,13 @@ def norm_allclose(X, Y):
else:
return False


@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("B", [2, 4, 16])
@pytest.mark.parametrize("G_intra_r, G_intra_c, G_intra_d", [(1, 2, 1), (2, 1, 1), (1, 1, 2)])
@pytest.mark.parametrize(
"G_intra_r, G_intra_c, G_intra_d", [(1, 2, 1), (2, 1, 1), (1, 1, 2)]
)
@pytest.mark.parametrize("easy_tp", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
@@ -87,7 +92,18 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
bias=bias,
).cuda()
weight_sequential = _gather(
_gather(_gather(layer.weight, 0, depth_group).reshape(layer.local_out_channels, layer.local_in_channels, layer.kernel_size, layer.kernel_size), 1, inner_group), 0, outer_group
_gather(
_gather(layer.weight, 0, depth_group).reshape(
layer.local_out_channels,
layer.local_in_channels,
layer.kernel_size,
layer.kernel_size,
),
1,
inner_group,
),
0,
outer_group,
)
layer_sequential.weight.copy_(weight_sequential)
if bias:
@@ -100,7 +116,9 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("B", [2, 4, 16])
@pytest.mark.parametrize("G_intra_r, G_intra_c, G_intra_d", [(1, 2, 1), (2, 1, 1), (1, 1, 2)])
@pytest.mark.parametrize(
"G_intra_r, G_intra_c, G_intra_d", [(1, 2, 1), (2, 1, 1), (1, 1, 2)]
)
@pytest.mark.parametrize("easy_tp", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
@@ -142,15 +160,14 @@ def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
) # divide input channels of X along the depth tensor group
else:
X_local = X

X_local.requires_grad = True
if not easy_tp:
Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone()
Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone()
else:
Y_local_grad = Y_grad


Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
Y_local.backward(Y_local_grad)

@@ -166,7 +183,18 @@ def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
).cuda()
with torch.no_grad():
weight_sequential = _gather(
_gather(_gather(layer.weight, 0, depth_group).reshape(layer.local_out_channels, layer.local_in_channels, layer.kernel_size, layer.kernel_size), 1, inner_group), 0, outer_group
_gather(
_gather(layer.weight, 0, depth_group).reshape(
layer.local_out_channels,
layer.local_in_channels,
layer.kernel_size,
layer.kernel_size,
),
1,
inner_group,
),
0,
outer_group,
)
layer_sequential.weight.copy_(weight_sequential)
if bias:
@@ -183,10 +211,21 @@ def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):

assert norm_allclose(
X_grad_parallel, X.grad
), f"BW Pass - gradients of input do not match"
), "BW Pass - gradients of input do not match"

weight_grad_parallel = _gather(
_gather(_gather(layer.weight.grad, 0, depth_group).reshape(layer.local_out_channels, layer.local_in_channels, layer.kernel_size, layer.kernel_size), 1, inner_group), 0, outer_group
_gather(
_gather(layer.weight.grad, 0, depth_group).reshape(
layer.local_out_channels,
layer.local_in_channels,
layer.kernel_size,
layer.kernel_size,
),
1,
inner_group,
),
0,
outer_group,
)

assert norm_allclose(
@@ -197,6 +236,4 @@ def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
bias_grad_parallel = _gather(layer.bias.grad, 0, outer_group)
assert norm_allclose(
bias_grad_parallel, layer_sequential.bias.grad
), f"BW Pass - gradients of bias do not match"


), "BW Pass - gradients of bias do not match"
