Commit

reformat and correct minor errors
siddharth9820 committed Oct 24, 2023
1 parent a2b10aa commit 2e74e9f
Showing 2 changed files with 23 additions and 7 deletions.
23 changes: 18 additions & 5 deletions axonn/intra_layer/fully_connected.py
@@ -3,6 +3,7 @@
 import torch
 from .communication import Drop
 from torch.autograd import Function
+import math


 def divide(a, b):
@@ -61,6 +62,10 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, None, None, None


+def default_init_method(weight):
+    return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+
+
 class Linear(torch.nn.Module):
     def __init__(
         self,
@@ -70,7 +75,7 @@ def __init__(
         transpose=False,
         skip_bias_add=False,
         init_method=None,
-        async_comm_in_backward_pass=True
+        async_comm_in_backward_pass=True,
         **kwargs
     ):
         super(Linear, self).__init__()
@@ -79,10 +84,10 @@ def __init__(

         self.inner_group_size = dist.get_world_size(self.inner_group)
         self.outer_group_size = dist.get_world_size(self.outer_group)
-        self.async_comm_in_backward_pass=async_comm_in_backward_pass
+        self.async_comm_in_backward_pass = async_comm_in_backward_pass

         if init_method is None:
-            init_method = lambda weight : torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+            init_method = default_init_method

         if not transpose:
             assert in_features % self.inner_group_size == 0
@@ -125,11 +130,19 @@ def get_output_feature_size(self):
     def forward(self, x):
         if not self.transpose:
             x = AsyncLinear.apply(
-                x, self.weight, self.inner_group, self.outer_group, self.async_comm_in_backward_pass
+                x,
+                self.weight,
+                self.inner_group,
+                self.outer_group,
+                self.async_comm_in_backward_pass,
             )
         else:
             x = AsyncLinear.apply(
-                x, self.weight, self.outer_group, self.inner_group, self.async_comm_in_backward_pass
+                x,
+                self.weight,
+                self.outer_group,
+                self.inner_group,
+                self.async_comm_in_backward_pass,
             )
         if self.skip_bias_add:
             return x, self.bias
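The `default_init_method` helper added above reproduces PyTorch's stock `nn.Linear` weight initialization and replaces the earlier inline lambda (presumably to keep the default lint-clean and picklable). A minimal standalone sketch of what it does; the tensor shape below is made up for illustration and is not from the commit:

```python
import math

import torch


def default_init_method(weight):
    # Same scheme torch.nn.Linear.reset_parameters uses:
    # in-place Kaiming-uniform initialization with a = sqrt(5),
    # i.e. values drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
    return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))


# Hypothetical weight shard; real shapes come from the tensor-parallel Linear.
weight = torch.empty(256, 128)
default_init_method(weight)
print(weight.abs().max().item())  # < 1 / sqrt(128) ≈ 0.088
```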
7 changes: 5 additions & 2 deletions axonn/tests/test_intra_layer_fc.py
@@ -52,7 +52,7 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
 @pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)])
 @pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
 @pytest.mark.parametrize("async_comm_in_backward_pass", [True, False])
-def test_bw_pass(G_intra_r, G_intra_c, B, H):
+def test_bw_pass(G_intra_r, G_intra_c, B, H, async_comm_in_backward_pass):
     # These tests are in fp-32
     torch.manual_seed(42)
     ax.init(
@@ -69,7 +69,10 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):

     # parallel backward pass
     layer = Tensor_Parallel_Linear(
-        in_features=H, out_features=H, skip_bias_add=True, async_comm_in_backward_pass=async_comm_in_backward_pass
+        in_features=H,
+        out_features=H,
+        skip_bias_add=True,
+        async_comm_in_backward_pass=async_comm_in_backward_pass,
     ).cuda()
     X_local = (
         _drop(X, 1, inner_group).detach().clone()
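The test-side change threads the `async_comm_in_backward_pass` parametrization into the signature of `test_bw_pass`, so the matching `@pytest.mark.parametrize` decorator has an argument to inject; without it, pytest errors out at collection time. A hypothetical, self-contained sketch of that mechanism (not code from the repository):

```python
import pytest


@pytest.mark.parametrize("async_comm_in_backward_pass", [True, False])
def test_flag_is_injected(async_comm_in_backward_pass):
    # pytest generates one test per parametrized value and passes it
    # through the argument of the same name, which is why the name
    # must also appear in the test function's signature.
    assert isinstance(async_comm_in_backward_pass, bool)
```

The real test presumably still requires a multi-process GPU launch (the `G_intra_r, G_intra_c` grids of (1, 2) and (2, 1) imply at least two workers); the launcher command is not part of this diff.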
