From f940ef7b33c9eb86e1c257c3cc49b8d1fa80a85f Mon Sep 17 00:00:00 2001
From: Prajwal Singhania
Date: Mon, 23 Oct 2023 16:45:16 -0400
Subject: [PATCH] ILP Conv Layer: Renamed test_conv_layer.py to
 test_intra_layer_conv.py; Fixed test

---
 ...conv_layer.py => test_intra_layer_conv.py} | 25 +++++++++++--------
 1 file changed, 14 insertions(+), 11 deletions(-)
 rename axonn/tests/{test_conv_layer.py => test_intra_layer_conv.py} (83%)

diff --git a/axonn/tests/test_conv_layer.py b/axonn/tests/test_intra_layer_conv.py
similarity index 83%
rename from axonn/tests/test_conv_layer.py
rename to axonn/tests/test_intra_layer_conv.py
index 93755e0..dac3e79 100644
--- a/axonn/tests/test_conv_layer.py
+++ b/axonn/tests/test_intra_layer_conv.py
@@ -4,16 +4,19 @@
 from axonn.intra_layer.communication import _drop, _gather
 from axonn.intra_layer import Tensor_Parallel_Conv2d, drop, gather
 
-torch.use_deterministic_algorithms(True)
-torch.backends.cudnn.benchmark = False
-
-
 @pytest.mark.mpi
 @pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
 @pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
 def test_fw_pass(G_intra_r, G_intra_c, H, W, C):
     # These tests are in fp-32
     torch.manual_seed(42)
+    # Need to remove all non-determinism from convolutions
+    torch.use_deterministic_algorithms(True)
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+    # This is required because TF32 cores only look at the first 10 bits of mantissa thus introducing precision errors
+    torch.backends.cudnn.allow_tf32 = False
+
     ax.init(
         G_data=1,
         G_inter=1,
@@ -62,7 +65,14 @@ def test_fw_pass(G_intra_r, G_intra_c, H, W, C):
 @pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
 def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
     # These tests are in fp-32
+    # Need to remove all non-determinism from convolutions
     torch.manual_seed(42)
+    torch.use_deterministic_algorithms(True)
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+    # This is required because TF32 cores only look at the first 10 bits of mantissa thus introducing precision errors
+    torch.backends.cudnn.allow_tf32 = False
+
     ax.init(
         G_data=1,
         G_inter=1,
@@ -107,12 +117,6 @@ def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
     Y_sequential.backward(Y_grad)
 
     X_grad_parallel = _gather(X_local.grad, 1, inner_group)
-    torch.set_printoptions(threshold=10000)
-    #print (X_grad_parallel)
-    #print (X.grad)
-
-    #print (torch.allclose(
-    #    X_grad_parallel, X.grad))
 
     assert torch.allclose(
         X_grad_parallel, X.grad
@@ -125,4 +129,3 @@ def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
         weight_grad_parallel, layer_sequential.weight.grad
     ), "BW Pass - gradients of weight do not match"
 
-test_bw_pass(1,2,64,64,4)
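
Note (not part of the patch above): the determinism setup is now duplicated verbatim in
test_fw_pass and test_bw_pass. A possible follow-up, sketched here under the assumption
that both tests stay in this one module and that an autouse pytest fixture is acceptable,
would be to centralize those settings. The fixture name deterministic_convolutions is
hypothetical and not part of AxoNN.

    import pytest
    import torch


    @pytest.fixture(autouse=True)
    def deterministic_convolutions():
        # Apply the same settings the patch adds inside each test:
        # remove all non-determinism from convolutions ...
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # ... and disable TF32, which keeps only 10 mantissa bits and would
        # otherwise introduce precision errors in the fp-32 comparisons.
        torch.backends.cudnn.allow_tf32 = False
        yield
        # Undo the global toggles so other test modules see default behavior.
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.allow_tf32 = True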