Intra-layer - Overlap communication in backward pass (#44)
* Overlap weight-gradient computation with activation-gradient communication
siddharth9820 authored Oct 24, 2023
1 parent d69dbb9 commit 5f1490c
Showing 2 changed files with 89 additions and 43 deletions.
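The core of the change is the new AsyncLinear autograd function in fully_connected.py: the all-reduce of the activation (input) gradient is launched asynchronously, the weight gradient is computed while that communication is in flight, and the handle is only waited on afterwards. Below is a distilled, single-process sketch of that pattern; the function name, process-group setup, and tensor shapes are illustrative, not AxoNN's actual API.

import torch
import torch.distributed as dist


def backward_with_overlap(grad_output, input_, weight, group, async_comm=True):
    # activation-gradient term: needs an all-reduce across the tensor-parallel group
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=async_comm)
    # weight-gradient term: purely local compute, overlapped with the all-reduce above
    grad_weight = (
        grad_output.view(-1, grad_output.shape[-1])
        .t()
        .mm(input_.view(-1, input_.shape[-1]))
    )
    if async_comm:
        handle.wait()  # grad_input must be fully reduced before it is used
    return grad_input, grad_weight


if __name__ == "__main__":
    # single-rank gloo group, only so that the all_reduce call is executable;
    # the address/port and world size are placeholders for this sketch
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    x = torch.randn(4, 8)   # (batch, local_in_features)
    w = torch.randn(6, 8)   # (local_out_features, local_in_features)
    gy = torch.randn(4, 6)  # upstream gradient w.r.t. the layer output
    gx, gw = backward_with_overlap(gy, x, w, dist.group.WORLD)
    print(gx.shape, gw.shape)  # torch.Size([4, 8]) torch.Size([6, 8])
    dist.destroy_process_group()

With async_comm=False the same code degenerates to blocking communication, which is the behaviour selected by the new async_comm_in_backward_pass=False flag.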
118 changes: 80 additions & 38 deletions axonn/intra_layer/fully_connected.py
@@ -1,7 +1,9 @@
from axonn import axonn as ax
import torch.distributed as dist
import torch
from .communication import ForwardAllReduce, BackwardAllReduce, Drop
from .communication import Drop
from torch.autograd import Function
import math


def divide(a, b):
@@ -21,6 +23,49 @@ def initialize_params(
return params


class AsyncLinear(Function):
@staticmethod
def forward(
ctx,
input_,
weight,
forward_all_reduce_group,
backward_all_reduce_group,
backward_comm_async,
):
ctx.save_for_backward(input_, weight)
ctx.backward_all_reduce_group = backward_all_reduce_group
ctx.backward_comm_async = backward_comm_async
output = input_.matmul(weight.t())
dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
return output

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
handle = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight)
handle = dist.all_reduce(
grad_input,
group=ctx.backward_all_reduce_group,
async_op=ctx.backward_comm_async,
)
if ctx.needs_input_grad[1]:
grad_weight = (
grad_output.view(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
)
if handle and ctx.backward_comm_async:
handle.wait()
return grad_input, grad_weight, None, None, None


def default_init_method(weight):
return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))


class Linear(torch.nn.Module):
def __init__(
self,
@@ -30,6 +75,7 @@ def __init__(
transpose=False,
skip_bias_add=False,
init_method=None,
async_comm_in_backward_pass=True,
**kwargs
):
super(Linear, self).__init__()
@@ -38,44 +84,37 @@ def __init__(

self.inner_group_size = dist.get_world_size(self.inner_group)
self.outer_group_size = dist.get_world_size(self.outer_group)
self.async_comm_in_backward_pass = async_comm_in_backward_pass

if init_method is None:
init_method = default_init_method

if not transpose:
assert in_features % self.inner_group_size == 0
assert out_features % self.outer_group_size == 0
self.local_in_features = divide(in_features, self.inner_group_size)
self.local_out_features = divide(out_features, self.outer_group_size)
if init_method:
initial_params = initialize_params(
out_features,
in_features,
self.outer_group,
self.inner_group,
init_method,
)
initial_params = initialize_params(
out_features,
in_features,
self.outer_group,
self.inner_group,
init_method,
)
else:
assert out_features % self.inner_group_size == 0
assert in_features % self.outer_group_size == 0
self.local_in_features = divide(in_features, self.outer_group_size)
self.local_out_features = divide(out_features, self.inner_group_size)
if init_method:
initial_params = initialize_params(
out_features,
in_features,
self.inner_group,
self.outer_group,
init_method,
)

self.linear = torch.nn.Linear(
in_features=self.local_in_features,
out_features=self.local_out_features,
*args,
**kwargs,
bias=False
)
initial_params = initialize_params(
out_features,
in_features,
self.inner_group,
self.outer_group,
init_method,
)

if init_method:
self.linear.weight.data.copy_(initial_params)
self.weight = torch.nn.Parameter(initial_params, requires_grad=True)

self.bias = torch.nn.Parameter(
torch.zeros(
@@ -90,18 +129,21 @@ def get_output_feature_size(self):

def forward(self, x):
if not self.transpose:
if x.size(-1) == self.local_in_features * self.inner_group_size:
x = Drop.apply(x, self.inner_group)
x = BackwardAllReduce.apply(x, self.outer_group)
x = self.linear(x)
x = ForwardAllReduce.apply(x, self.inner_group)
x = AsyncLinear.apply(
x,
self.weight,
self.inner_group,
self.outer_group,
self.async_comm_in_backward_pass,
)
else:
if x.size(-1) == self.local_in_features * self.outer_group_size:
x = Drop.apply(x, self.outer_group)
x = BackwardAllReduce.apply(x, self.inner_group)
x = self.linear(x)
x = ForwardAllReduce.apply(x, self.outer_group)

x = AsyncLinear.apply(
x,
self.weight,
self.outer_group,
self.inner_group,
self.async_comm_in_backward_pass,
)
if self.skip_bias_add:
return x, self.bias
else:
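For reference, the flattened grad_weight computation in AsyncLinear.backward (view, transpose, mm) is the standard weight-gradient formula for a linear layer. The short check below compares it against what autograd produces for a plain matmul; the shapes are arbitrary and chosen only to exercise the 3-D flattening. This is a sanity sketch, not part of the commit.

import torch

torch.manual_seed(0)
input_ = torch.randn(2, 3, 8, requires_grad=True)  # (batch, seq, in_features)
weight = torch.randn(6, 8, requires_grad=True)     # (out_features, in_features)

out = input_.matmul(weight.t())
grad_output = torch.randn_like(out)
out.backward(grad_output)

with torch.no_grad():
    # same view/transpose/mm sequence as in AsyncLinear.backward above
    manual_grad_weight = (
        grad_output.view(-1, grad_output.shape[-1])
        .t()
        .mm(input_.view(-1, input_.shape[-1]))
    )
assert torch.allclose(weight.grad, manual_grad_weight, atol=1e-5)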
14 changes: 9 additions & 5 deletions axonn/tests/test_intra_layer_fc.py
@@ -40,7 +40,7 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
in_features=H, out_features=H, bias=False
).cuda()
weight_sequential = _gather(
_gather(layer.linear.weight, 1, inner_group), 0, outer_group
_gather(layer.weight, 1, inner_group), 0, outer_group
)
layer_sequential.weight.copy_(weight_sequential)
Y_sequential = layer_sequential(X)
@@ -51,7 +51,8 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
@pytest.mark.mpi
@pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)])
@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
def test_bw_pass(G_intra_r, G_intra_c, B, H):
@pytest.mark.parametrize("async_comm_in_backward_pass", [True, False])
def test_bw_pass(G_intra_r, G_intra_c, B, H, async_comm_in_backward_pass):
# These tests are in fp-32
torch.manual_seed(42)
ax.init(
@@ -68,7 +69,10 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):

# parallel backward pass
layer = Tensor_Parallel_Linear(
in_features=H, out_features=H, skip_bias_add=True
in_features=H,
out_features=H,
skip_bias_add=True,
async_comm_in_backward_pass=async_comm_in_backward_pass,
).cuda()
X_local = (
_drop(X, 1, inner_group).detach().clone()
@@ -82,7 +86,7 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=False).cuda()
with torch.no_grad():
weight_sequential = _gather(
_gather(layer.linear.weight, 1, inner_group), 0, outer_group
_gather(layer.weight, 1, inner_group), 0, outer_group
)
layer_sequential.weight.copy_(weight_sequential)
X.requires_grad = True
@@ -95,7 +99,7 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
), "BW Pass - gradients of input do not match"

weight_grad_parallel = _gather(
_gather(layer.linear.weight.grad, 1, inner_group), 0, outer_group
_gather(layer.weight.grad, 1, inner_group), 0, outer_group
)
assert torch.allclose(
weight_grad_parallel, layer_sequential.weight.grad
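The forward-pass test checks that the sharded layer matches a dense torch.nn.Linear once the weight shards are gathered. The single-process sketch below shows the underlying identity along the input-feature (inner-group) dimension: each simulated rank multiplies its input slice by its weight shard, and summing the partial outputs (the role of the forward all-reduce) reproduces the dense matmul. Group sizes and shapes are made up, and no process groups are involved; this is an illustration of what the tests verify, not AxoNN code.

import torch

torch.manual_seed(0)
B, H_in, H_out, inner_group_size = 4, 8, 6, 2

x = torch.randn(B, H_in)
w = torch.randn(H_out, H_in)   # full, unsharded weight
y_dense = x.matmul(w.t())      # reference output

shard = H_in // inner_group_size
partials = []
for r in range(inner_group_size):
    x_r = x[:, r * shard:(r + 1) * shard]   # Drop: this rank's slice of the input
    w_r = w[:, r * shard:(r + 1) * shard]   # this rank's weight shard
    partials.append(x_r.matmul(w_r.t()))    # local partial output
y_parallel = torch.stack(partials).sum(0)   # the forward all-reduce is just this sum

assert torch.allclose(y_dense, y_parallel, atol=1e-5)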
