
Commit e338e88
ILP Conv Layer : Added communication optimisations to CI tests
prajwal1210 committed Jan 17, 2024
1 parent eecf387 commit e338e88
Showing 2 changed files with 22 additions and 5 deletions.
axonn/intra_layer/conv.py: 3 changes (2 additions, 1 deletion)
@@ -126,11 +126,12 @@ def forward(
         cache_weights_in_all_gather=False,
     ):
         # Gather weights from depth parallel group
+        # TODO: We should make the OVERLAP_REDUCE_SCATTER flag part of axonn.axonn
         weight = ForwardGather_BackwardReduceScatter.apply(
             self.weight,
             self.depth_group,
             0,
-            axonn.intra_layer.OVERLAP_REDUCE_SCATTER,  # TODO: We should ideally make the flag part of axonn.axonn instead of just axonn
+            axonn.intra_layer.OVERLAP_REDUCE_SCATTER,
             cache_weights_in_all_gather,
         ).reshape(
             self.local_out_channels,
axonn/tests/test_intra_layer_conv.py: 24 changes (20 additions, 4 deletions)
@@ -2,7 +2,12 @@
 import pytest
 from axonn import axonn as ax
 from axonn.intra_layer.communication import _drop, _gather
-from axonn.intra_layer import Tensor_Parallel_Conv2d, sync_gradients
+from axonn.intra_layer import (
+    Tensor_Parallel_Conv2d,
+    optimize_communication,
+    clear_weights_cache,
+    sync_gradients,
+)
 import math
 import torch.distributed as dist
 
@@ -121,7 +126,10 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
 )
 @pytest.mark.parametrize("easy_tp", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
-def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
+@pytest.mark.parametrize("comm_opt_level", [0, 3])
+def test_bw_pass(
+    G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias, comm_opt_level
+):
     # These tests are in fp-32
     # Need to remove all non-determinism from convolutions
     torch.manual_seed(42)
@@ -168,11 +176,19 @@ def test_bw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
     else:
         Y_local_grad = Y_grad
 
-    Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
-    Y_local.backward(Y_local_grad)
+    with optimize_communication(
+        overlap_reduce_scatter=comm_opt_level >= 1,
+        cache_weights=comm_opt_level >= 2,
+        overlap_all_gather=comm_opt_level == 3,
+        model_object_for_overlapping_allgathers=layer,
+    ):
+        Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
+        Y_local.backward(Y_local_grad)
 
     if not easy_tp:
         sync_gradients(layer)
+    if comm_opt_level >= 3:
+        clear_weights_cache()
 
     # sequential backward pass
     layer_sequential = torch.nn.Conv2d(
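For context, the pattern this commit adds to the test wraps the forward and backward passes in the optimize_communication context manager and clears the weights cache afterward when caching was enabled. Below is a minimal sketch of that pattern, assuming `layer`, `X_local`, and `Y_local_grad` are constructed as in test_bw_pass; the comm_opt_level value and the scatter/gather flags are illustrative, mirroring the test's parametrization rather than a prescribed usage.

# Sketch of the communication-optimisation pattern exercised by the updated test.
# Assumes `layer`, `X_local`, and `Y_local_grad` exist as in test_bw_pass.
# Mapping used in the test: >=1 overlaps the reduce-scatter, >=2 caches weights
# across all-gathers, ==3 also overlaps the all-gather.
from axonn.intra_layer import optimize_communication, clear_weights_cache

comm_opt_level = 3  # illustrative; the test parametrizes over [0, 3]

with optimize_communication(
    overlap_reduce_scatter=comm_opt_level >= 1,
    cache_weights=comm_opt_level >= 2,
    overlap_all_gather=comm_opt_level == 3,
    model_object_for_overlapping_allgathers=layer,
):
    Y_local = layer(X_local, scatter_input=True, gather_output=True)
    Y_local.backward(Y_local_grad)

if comm_opt_level >= 3:
    clear_weights_cache()  # release weights cached while all-gathers were optimised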

