Merge branch 'develop' into easy-tensor-parallelism
siddharth9820 authored Oct 24, 2023
2 parents a003277 + 5f1490c commit 270b7dc
Showing 2 changed files with 76 additions and 28 deletions.
90 changes: 67 additions & 23 deletions axonn/intra_layer/fully_connected.py
@@ -1,10 +1,10 @@
 from axonn import axonn as ax
 import torch.distributed as dist
 import torch
+from .communication import Drop, Gather
+from torch.autograd import Function
 import math
 
-from .communication import ForwardAllReduce, BackwardAllReduce, Drop, Gather
-
 def divide(a, b):
     assert a % b == 0
     return a // b
@@ -21,6 +21,49 @@ def initialize_params(
     params = Drop.apply(params, in_features_group)
     return params
 
+class AsyncLinear(Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input_,
+        weight,
+        forward_all_reduce_group,
+        backward_all_reduce_group,
+        backward_comm_async,
+    ):
+        ctx.save_for_backward(input_, weight)
+        ctx.backward_all_reduce_group = backward_all_reduce_group
+        ctx.backward_comm_async = backward_comm_async
+        output = input_.matmul(weight.t())
+        dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight = ctx.saved_tensors
+        handle = None
+        if ctx.needs_input_grad[0]:
+            grad_input = grad_output.matmul(weight)
+            handle = dist.all_reduce(
+                grad_input,
+                group=ctx.backward_all_reduce_group,
+                async_op=ctx.backward_comm_async,
+            )
+        if ctx.needs_input_grad[1]:
+            grad_weight = (
+                grad_output.view(-1, grad_output.shape[-1])
+                .t()
+                .mm(input_.view(-1, input_.shape[-1]))
+            )
+        if handle and ctx.backward_comm_async:
+            handle.wait()
+        return grad_input, grad_weight, None, None, None
+
+
+def default_init_method(weight):
+    return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+
+
 class Linear(torch.nn.Module):
     def __init__(
         self,
@@ -30,6 +73,7 @@ def __init__(
         transpose=False,
         skip_bias_add=False,
         init_method=None,
+        async_comm_in_backward_pass=True,
         **kwargs
     ):
         super(Linear, self).__init__()
@@ -38,11 +82,12 @@ def __init__(
 
         self.inner_group_size = dist.get_world_size(self.inner_group)
         self.outer_group_size = dist.get_world_size(self.outer_group)
-
+
+        self.async_comm_in_backward_pass = async_comm_in_backward_pass
+
         if init_method is None:
-            ## this is the same as pytorch 2.1
-            init_method = lambda weight : torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
-
+            init_method = default_init_method
+
         if not transpose:
             assert in_features % self.inner_group_size == 0
             assert out_features % self.outer_group_size == 0
@@ -68,16 +113,8 @@ def __init__(
                init_method,
            )
 
-        self.linear = torch.nn.Linear(
-            in_features=self.local_in_features,
-            out_features=self.local_out_features,
-            *args,
-            **kwargs,
-            bias=False
-        )
 
-        if init_method:
-            self.linear.weight.data.copy_(initial_params)
+        self.weight = torch.nn.Parameter(initial_params, requires_grad=True)
 
         setattr(self.linear.weight, "is_tensor_parallel", True)
 
@@ -96,26 +133,33 @@ def forward(self, x, scatter_input=True, gather_output=True):
         if not self.transpose:
             if scatter_input:
                 x = Drop.apply(x, self.inner_group)
-            x = BackwardAllReduce.apply(x, self.outer_group)
-            x = self.linear(x)
-            x = ForwardAllReduce.apply(x, self.inner_group)
+            x = AsyncLinear.apply(
+                x,
+                self.weight,
+                self.inner_group,
+                self.outer_group,
+                self.async_comm_in_backward_pass,
+            )
             if gather_output:
                 x = Gather.apply(x, self.outer_group)
         else:
             if scatter_input:
                 x = Drop.apply(x, self.outer_group)
-            x = BackwardAllReduce.apply(x, self.inner_group)
-            x = self.linear(x)
-            x = ForwardAllReduce.apply(x, self.outer_group)
+            x = AsyncLinear.apply(
+                x,
+                self.weight,
+                self.outer_group,
+                self.inner_group,
+                self.async_comm_in_backward_pass,
+            )
             if gather_output:
                 x = Gather.apply(x, self.inner_group)
 
         bias = self.bias
         if gather_output:
             bias = Gather.apply(
                 self.bias, self.outer_group if not self.transpose else self.inner_group
             )
 
         if self.skip_bias_add:
             return x, bias
         else:
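The backward method of the new AsyncLinear overlaps communication with computation: it launches the all-reduce of the input gradient asynchronously and computes the weight gradient while that reduction is in flight, waiting only at the end. The following is a minimal standalone sketch of that overlap pattern in plain PyTorch, not code from this commit; it assumes torch.distributed has already been initialized (for example via torchrun) and that every rank holds same-shaped shards.

import torch
import torch.distributed as dist


def backward_with_overlap(grad_output, input_, weight, group=None):
    # Start reducing the input gradient across the group without blocking.
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=True)

    # Compute the weight gradient while the all-reduce is still in flight.
    grad_weight = (
        grad_output.reshape(-1, grad_output.shape[-1])
        .t()
        .mm(input_.reshape(-1, input_.shape[-1]))
    )

    # Wait only after the independent work is done, so the communication is hidden.
    handle.wait()
    return grad_input, grad_weight

With async_comm_in_backward_pass=False the layer issues the same all-reduce with async_op=False, so the call blocks and there is nothing to wait on; the updated test below exercises both settings.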
14 changes: 9 additions & 5 deletions axonn/tests/test_intra_layer_fc.py
@@ -40,7 +40,7 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
             in_features=H, out_features=H, bias=False
         ).cuda()
         weight_sequential = _gather(
-            _gather(layer.linear.weight, 1, inner_group), 0, outer_group
+            _gather(layer.weight, 1, inner_group), 0, outer_group
         )
         layer_sequential.weight.copy_(weight_sequential)
         Y_sequential = layer_sequential(X)
@@ -51,7 +51,8 @@
 @pytest.mark.mpi
 @pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)])
 @pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
-def test_bw_pass(G_intra_r, G_intra_c, B, H):
+@pytest.mark.parametrize("async_comm_in_backward_pass", [True, False])
+def test_bw_pass(G_intra_r, G_intra_c, B, H, async_comm_in_backward_pass):
     # These tests are in fp-32
     torch.manual_seed(42)
     ax.init(
@@ -68,7 +69,10 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
 
     # parallel backward pass
    layer = Tensor_Parallel_Linear(
-        in_features=H, out_features=H, skip_bias_add=True
+        in_features=H,
+        out_features=H,
+        skip_bias_add=True,
+        async_comm_in_backward_pass=async_comm_in_backward_pass,
     ).cuda()
     X_local = (
         _drop(X, 1, inner_group).detach().clone()
@@ -82,7 +86,7 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
     layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=False).cuda()
     with torch.no_grad():
         weight_sequential = _gather(
-            _gather(layer.linear.weight, 1, inner_group), 0, outer_group
+            _gather(layer.weight, 1, inner_group), 0, outer_group
         )
         layer_sequential.weight.copy_(weight_sequential)
     X.requires_grad = True
@@ -95,7 +99,7 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
     ), "BW Pass - gradients of input do not match"
 
     weight_grad_parallel = _gather(
-        _gather(layer.linear.weight.grad, 1, inner_group), 0, outer_group
+        _gather(layer.weight.grad, 1, inner_group), 0, outer_group
     )
     assert torch.allclose(
         weight_grad_parallel, layer_sequential.weight.grad
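Both tests rebuild the full weight from its shards with two nested _gather calls (along dim 1 over the inner group, then dim 0 over the outer group) before comparing against a plain torch.nn.Linear. The _gather helper lives in the test utilities and is not part of this diff; the sketch below shows one plausible way such a dimension-wise gather could be written, assuming an initialized process group and equal shard shapes on every rank (the name gather_along_dim is illustrative, not AxoNN's).

import torch
import torch.distributed as dist


def gather_along_dim(tensor, dim, group=None):
    # Collect every rank's shard of the tensor.
    world_size = dist.get_world_size(group=group)
    shards = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(shards, tensor.contiguous(), group=group)
    # Stitch the shards back together along the sharded dimension.
    return torch.cat(shards, dim=dim)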
