Commit 81ab602
1 parent dcf5272
Showing 2 changed files with 131 additions and 4 deletions.
@@ -0,0 +1,128 @@
import torch
import pytest
from axonn import axonn as ax
from axonn.intra_layer.communication import _drop, _gather
from axonn.intra_layer import Tensor_Parallel_Conv2d, drop, gather

torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False

@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
def test_fw_pass(G_intra_r, G_intra_c, H, W, C):
    # These tests are in fp-32
    torch.manual_seed(42)
    ax.init(
        G_data=1,
        G_inter=1,
        G_intra_r=G_intra_r,
        G_intra_c=G_intra_c,
    )

    X = torch.randn(1, C, H, W).cuda() * 0.01

    inner_group = ax.comm_handle.inner_intra_layer_parallel_group
    outer_group = ax.comm_handle.outer_intra_layer_parallel_group

    X_local = _drop(
        X, 1, inner_group
    )  # divide channels of X along the inner tensor group
    layer = Tensor_Parallel_Conv2d(
        in_channels=C,
        out_channels=2 * C,
        kernel_size=5,
        skip_bias_add=True,
    ).cuda()

    with torch.no_grad():
        # parallel FW pass
        Y_local, _ = layer(X_local)
        Y_parallel = _gather(Y_local.clone(), 1, outer_group)

        # sequential FW pass
        layer_sequential = torch.nn.Conv2d(
            in_channels=C,
            out_channels=C * 2,
            kernel_size=5,
            bias=False,
        ).cuda()
        weight_sequential = _gather(
            _gather(layer.conv.weight, 1, inner_group), 0, outer_group
        )
        layer_sequential.weight.copy_(weight_sequential)
        Y_sequential = layer_sequential(X)

    assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"

@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
    # These tests are in fp-32
    torch.manual_seed(42)
    ax.init(
        G_data=1,
        G_inter=1,
        G_intra_r=G_intra_r,
        G_intra_c=G_intra_c,
    )
    X = torch.randn(1, C, H, W).cuda() * 0.01
    Y_grad = torch.randn(1, 2 * C, H - 4, W - 4).cuda() * 0.01

    inner_group = ax.comm_handle.inner_intra_layer_parallel_group
    outer_group = ax.comm_handle.outer_intra_layer_parallel_group

    # parallel backward pass
    layer = Tensor_Parallel_Conv2d(
        in_channels=C,
        out_channels=2 * C,
        kernel_size=5,
        skip_bias_add=True,
    ).cuda()
    X_local = (
        _drop(X, 1, inner_group).detach().clone()
    )  # divide input channels of X along the inner tensor group
    X_local.requires_grad = True
    Y_local, _ = layer(X_local)
    Y_local_grad = _drop(Y_grad, 1, outer_group)
    Y_local.backward(Y_local_grad)

    # sequential backward pass
    layer_sequential = torch.nn.Conv2d(
        in_channels=C,
        out_channels=C * 2,
        kernel_size=5,
        bias=False,
    ).cuda()
    with torch.no_grad():
        # reassemble the full conv weight from its shards: input-channel shards
        # are gathered over the inner group (dim 1), output-channel shards over
        # the outer group (dim 0)
        weight_sequential = _gather(
            _gather(layer.conv.weight, 1, inner_group), 0, outer_group
        )
        layer_sequential.weight.copy_(weight_sequential)
    X.requires_grad = True
    Y_sequential = layer_sequential(X)
    Y_sequential.backward(Y_grad)

    X_grad_parallel = _gather(X_local.grad, 1, inner_group)

    assert torch.allclose(
        X_grad_parallel, X.grad
    ), "BW Pass - gradients of input do not match"

    weight_grad_parallel = _gather(
        _gather(layer.conv.weight.grad, 1, inner_group), 0, outer_group
    )
    assert torch.allclose(
        weight_grad_parallel, layer_sequential.weight.grad
    ), "BW Pass - gradients of weight do not match"

if __name__ == "__main__":
    # allow running a single configuration directly, outside of pytest
    test_bw_pass(1, 2, 64, 64, 4)
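For readers unfamiliar with the communication helpers exercised above, here is a minimal single-process sketch of what _drop and _gather are assumed to do: shard a tensor along one dimension within a process group, and concatenate the shards back along that dimension. The names drop_like and gather_like, the shapes, and the world_size are hypothetical stand-ins for illustration, not part of AxoNN's API; the real helpers operate over distributed process groups rather than a local list of shards.

import torch


def drop_like(x, dim, rank, world_size):
    # local analogue of _drop: keep only this rank's shard of x along `dim`
    return torch.chunk(x, world_size, dim=dim)[rank]


def gather_like(shards, dim):
    # local analogue of _gather: concatenate per-rank shards back along `dim`
    return torch.cat(shards, dim=dim)


C, H, W, world_size = 4, 8, 8, 2
X = torch.randn(1, C, H, W)

# dropping along the channel dim and gathering again is a round trip
shards = [drop_like(X, 1, r, world_size) for r in range(world_size)]
assert torch.equal(gather_like(shards, dim=1), X)

Under this reading, each rank of the inner group convolves a slice of the input channels and each rank of the outer group produces a slice of the output channels, which is why the tests gather Y along dim 1 over the outer group and gather the weight along dim 1 (inner group) and dim 0 (outer group) before comparing against the sequential torch.nn.Conv2d reference.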