ILP Conv Layer: Renamed test_conv_layer.py to test_intra_layer_conv.py; Fixed test
prajwal1210 committed Oct 23, 2023
1 parent 81ab602 commit f940ef7
Showing 1 changed file with 14 additions and 11 deletions.
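For readers skimming the diff below: the fix amounts to configuring per-test determinism before comparing the parallel and sequential convolutions in fp-32. A minimal sketch of that setup follows; the helper name make_deterministic is hypothetical and not part of this commit, the settings themselves mirror the lines added inside each test.

import torch


def make_deterministic(seed: int = 42):
    # Seed PyTorch and remove non-determinism from convolutions so parallel
    # and sequential results can be compared with torch.allclose.
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # TF32 keeps only 10 mantissa bits, which would break fp-32 comparisons.
    torch.backends.cudnn.allow_tf32 = False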
@@ -4,16 +4,19 @@
from axonn.intra_layer.communication import _drop, _gather
from axonn.intra_layer import Tensor_Parallel_Conv2d, drop, gather

torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False


@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
def test_fw_pass(G_intra_r, G_intra_c, H, W, C):
    # These tests are in fp-32
    torch.manual_seed(42)
    # Need to remove all non-determinism from convolutions
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # This is required because TF32 cores use only the first 10 bits of the mantissa, which introduces precision errors
    torch.backends.cudnn.allow_tf32 = False

    ax.init(
        G_data=1,
        G_inter=1,
@@ -62,7 +65,14 @@ def test_fw_pass(G_intra_r, G_intra_c, H, W, C):
@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
    # These tests are in fp-32
    # Need to remove all non-determinism from convolutions
    torch.manual_seed(42)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # This is required because TF32 cores use only the first 10 bits of the mantissa, which introduces precision errors
    torch.backends.cudnn.allow_tf32 = False

    ax.init(
        G_data=1,
        G_inter=1,
@@ -107,12 +117,6 @@ def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
    Y_sequential.backward(Y_grad)

    X_grad_parallel = _gather(X_local.grad, 1, inner_group)
    torch.set_printoptions(threshold=10000)
    #print (X_grad_parallel)
    #print (X.grad)

    #print (torch.allclose(
    #    X_grad_parallel, X.grad))

    assert torch.allclose(
        X_grad_parallel, X.grad
@@ -125,4 +129,3 @@ def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
        weight_grad_parallel, layer_sequential.weight.grad
    ), "BW Pass - gradients of weight do not match"

test_bw_pass(1,2,64,64,4)
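As an aside on the comparison logic above: _gather(X_local.grad, 1, inner_group) reassembles the channel-sharded gradient before it is checked against the sequential result. Below is a rough stand-in built on plain torch.distributed; the actual _gather lives in axonn.intra_layer.communication and may be implemented differently.

import torch
import torch.distributed as dist


def gather_along_dim(local_tensor, dim, group):
    # Hypothetical stand-in: collect every rank's shard and concatenate
    # along `dim`, so the full tensor can be compared with torch.allclose.
    world_size = dist.get_world_size(group=group)
    shards = [torch.empty_like(local_tensor) for _ in range(world_size)]
    dist.all_gather(shards, local_tensor.contiguous(), group=group)
    return torch.cat(shards, dim=dim)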
