Commit 454f031
ILP Conv Layer: Fixed the overlapping all_gather optimisation; renamed Tensor_Parallel_Conv2d to Conv2d
prajwal1210 committed Jan 17, 2024
1 parent b653760 commit 454f031
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions axonn/intra_layer/__init__.py
@@ -1,6 +1,6 @@
 from contextlib import contextmanager
 from .fully_connected import Linear  # noqa: F401
-from .conv import Conv2d as Tensor_Parallel_Conv2d  # noqa: F401
+from .conv import Conv2d  # noqa: F401

 from .communication import Drop, Gather
 from .gradient_normalization import clip_grad_norm_  # noqa: F401
@@ -86,7 +86,7 @@ def clear_weights_cache():
 def trigger_async_all_gathers(model):
     global weights_cache
     for module in model.modules():
-        if isinstance(module, Linear):
+        if isinstance(module, Linear) or isinstance(module, Conv2d):
             weight = module.weight
             if weight not in weights_cache:
                 # only trigger all gathers if not in cache
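The hunk above is the functional fix: before this commit, trigger_async_all_gathers matched only Linear, so Conv2d weights never had their all-gathers pre-triggered and cached. For orientation, here is a minimal sketch of the cache-and-overlap pattern using raw torch.distributed; apart from the Linear/Conv2d check and the weights_cache dict mirrored from the hunk, the buffer handling and helper names are assumptions, not AxoNN's actual internals.

import torch
import torch.distributed as dist
from torch.nn import Conv2d, Linear

weights_cache = {}  # weight tensor -> (gathered buffer, async work handle)

def trigger_async_all_gathers(model, process_group=None):
    for module in model.modules():
        # The commit's fix: match Conv2d as well as Linear, so conv
        # weights also get a pre-triggered, cached all-gather.
        if isinstance(module, (Linear, Conv2d)):
            weight = module.weight
            if weight not in weights_cache:
                # only trigger all gathers if not in cache
                world_size = dist.get_world_size(process_group)
                buf = torch.empty(
                    world_size * weight.numel(),
                    dtype=weight.dtype,
                    device=weight.device,
                )
                # async_op=True returns a work handle immediately; the
                # gather proceeds in the background and overlaps with
                # compute until handle.wait() is called.
                handle = dist.all_gather_into_tensor(
                    buf, weight.contiguous().view(-1),
                    group=process_group, async_op=True,
                )
                weights_cache[weight] = (buf, handle)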
2 changes: 1 addition & 1 deletion axonn/intra_layer/conv.py
@@ -167,7 +167,7 @@ def forward(
         if self.bias is None:
             return h
         else:
-            bias = self.bias  # Why do we need this extra copy?
+            bias = self.bias
             if gather_output:
                 bias = Gather.apply(bias, self.outer_group)
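For context, Gather here is a torch.autograd.Function from axonn.intra_layer.communication, used above to reassemble bias shards when gather_output is set. A sketch of what such a gather-with-backward typically looks like follows; it is illustrative only and not AxoNN's implementation (the dim-0 sharding and the slice-based backward are assumptions).

import torch
import torch.distributed as dist

class Gather(torch.autograd.Function):
    """All-gather shards along dim 0; backward keeps the local shard."""

    @staticmethod
    def forward(ctx, x, group=None):
        ctx.group = group
        world_size = dist.get_world_size(group)
        shards = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(shards, x.contiguous(), group=group)
        return torch.cat(shards, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        rank = dist.get_rank(ctx.group)
        # The gradient of an all-gather w.r.t. the local shard is the
        # matching slice of the output gradient.
        return grad_output.chunk(world_size, dim=0)[rank], None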
6 changes: 3 additions & 3 deletions axonn/tests/test_intra_layer_conv.py
@@ -3,7 +3,7 @@
 from axonn import axonn as ax
 from axonn.intra_layer.communication import _drop, _gather
 from axonn.intra_layer import (
-    Tensor_Parallel_Conv2d,
+    Conv2d,
     optimize_communication,
     clear_weights_cache,
     sync_gradients,
@@ -76,7 +76,7 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
     else:
         X_local = X

-    layer = Tensor_Parallel_Conv2d(
+    layer = Conv2d(
         in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias
     ).cuda()

@@ -155,7 +155,7 @@ def test_bw_pass(
     depth_group = ax.comm_handle.depth_intra_layer_parallel_group

     # parallel backward pass
-    layer = Tensor_Parallel_Conv2d(
+    layer = Conv2d(
         in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias
     ).cuda()

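With the rename in place, tests and user code import Conv2d straight from axonn.intra_layer, as the updated imports above show. A hedged end-to-end sketch mirroring the layer construction in the tests: the ax.init arguments and the optimize_communication keyword arguments are assumptions for illustration, not AxoNN's documented API.

import torch
from axonn import axonn as ax
from axonn.intra_layer import Conv2d, optimize_communication, clear_weights_cache

# Assumed initialisation: a 2x2 intra-layer tensor-parallel config.
ax.init(G_inter=1, G_data=1, G_intra_r=2, G_intra_c=2)

layer = Conv2d(in_channels=64, out_channels=128, kernel_size=5, bias=True).cuda()
X_local = torch.randn(4, 64, 32, 32, device="cuda")

# With this commit, the overlap/caching path also covers Conv2d weights.
with optimize_communication(overlap_all_gather=True, model=layer):
    Y_local = layer(X_local)

clear_weights_cache()  # release the cached gathered weights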
