From 1d1a77983112e8d5fcd3bcaed770b89700bf3a99 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Sat, 4 May 2024 10:08:27 +0200 Subject: [PATCH] Add tests for non-cubes and prime*size dimensions --- tests/test_transpose.py | 71 +++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 8aedd7d..561e376 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -52,26 +52,26 @@ def create_spmd_array(global_shape, pdims): pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100 pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100 -params = [(size, 1), (1, size), pencil_1, pencil_2] -#params = [(size, 1), (1, size)] +decomp = [(size, 1), (1, size), pencil_1, pencil_2] +global_shapes = [(4, 8, 16), (4, 4, 4), (29 * size, 19 * size, 17 * size) + ] # Cubes, non-cubes and primes + +# Cartesian product tests @pytest.mark.parametrize("pdims", - params) # Test with Slab and Pencil decompositions -def test_tranpose(pdims): + decomp) # Test with Slab and Pencil decompositions +@pytest.mark.parametrize("global_shape", + global_shapes) # Test cubes, non-cubes and primes +def test_tranpose(pdims, global_shape): """ Goes from an array of shape [z,y,x] # What we call an x pencil to [x,z,y] # what we call a y pencil """ print("*" * 80) - print(f"Testing with pdims {pdims}") - - global_shape = (4, 4, 4) # These sizes are prime numbers x size of the pmesh - #global_shape = (29 * size, 19 * size, 17 * size) - print(f"Global shape is {global_shape}") + print(f"Testing with pdims {pdims} and global shape {global_shape}") global_array, mesh = create_spmd_array(global_shape, pdims) - gathered_array = multihost_utils.process_allgather(global_array, tiled=True) with mesh: jd_tranposed_xy = transposeXtoY(global_array) jd_tranposed_yz = transposeYtoZ(jd_tranposed_xy) @@ -98,6 +98,16 @@ def test_tranpose(pdims): print(f"JD tranposed yz sharding {jd_tranposed_yz.sharding.spec}") print(f"JD tranposed xy sharding {jd_tranposed_xy.sharding.spec}") + print(f"JD tranposed zy sharding {jd_tranposed_zy.sharding.spec}") + print(f"JD tranposed yx sharding {jd_tranposed_yx.sharding.spec}") + + assert global_array.sharding.spec == P('z', 'y') + assert jd_tranposed_xy.sharding.spec == transposed_sharding + assert jd_tranposed_yz.sharding.spec == original_sharding + assert jd_tranposed_zy.sharding.spec == transposed_sharding + assert jd_tranposed_yx.sharding.spec == original_sharding + + gathered_array = multihost_utils.process_allgather(global_array, tiled=True) gathered_jd_xy = multihost_utils.process_allgather( jd_tranposed_xy, tiled=True) @@ -108,42 +118,25 @@ def test_tranpose(pdims): gathered_jd_yx = multihost_utils.process_allgather( jd_tranposed_yx, tiled=True) - assert jd_tranposed_xy.sharding.spec == transposed_sharding - assert jd_tranposed_yz.sharding.spec == original_sharding - assert jd_tranposed_zy.sharding.spec == transposed_sharding - assert jd_tranposed_yx.sharding.spec == original_sharding - # Explanation : - - # For pencils - # Tranposing forward is a shift axis to the right so ZYX to XZY to YXZ (2 0 1) # Tranposing backward is a shift axis to the left so YXZ to XZY to ZYX (1 2 0) # Double Tranposing from ZYX to YXZ is double (2 0 1) so (1 2 0) - # For slabs it is a bit more complicated - - # Tranposing from X to Y is a X Y tranpose so ZYX to ZXY (0 2 1) and tranposing back is the same (0 2 1) - # Tranposing from Y to Z is a Y Z tranpose so ZXY to YXZ (2 0 1) and tranposing back is the same (2 0 1) - # a double tranpose is a Z X tranpose so YXZ to ZYX (0 1 2) - - # Every tranpose, also tranposes the pdim grid from P('Z', 'Y') to P('Y', 'Z') or vise versa - - forward_tranpose = [1, 2, 0] if 1 in pdims else [2, 0, 1] - forward_pencils = [2, 0, 1] - backward_tranpose = [2, 0, 1] if 1 in pdims else [1, 2, 0] - backward_pencils = [1, 2, 0] - double_back = [0, 1, 2] if 1 in pdims else [1, 2, 0] + forward_tranpose = [2, 0, 1] + backward_tranpose = [1, 2, 0] + double_forward = [1, 2, 0] + # # Test X to Y transpose # It tranposes ZYX to XZY so from 0 1 2 to 2 0 1 assert_array_equal(gathered_array.transpose(forward_tranpose), gathered_jd_xy) # ********************************************* # Test Y to Z transpose # It tranposes XZY to YXZ so from 0 1 2 to 2 0 1 again - assert_array_equal(gathered_jd_xy.transpose(forward_pencils), gathered_jd_yz) + assert_array_equal(gathered_jd_xy.transpose(forward_tranpose), gathered_jd_yz) # and from the global array ZYX to YXZ so from 0 1 2 to 1 2 0 - assert_array_equal(gathered_array.transpose(double_back), gathered_jd_yz) + assert_array_equal(gathered_array.transpose(double_forward), gathered_jd_yz) # ********************************************* # Test Z to Y transpose # It tranposes YXZ to XZY so from 0 1 2 to 1 2 0 @@ -154,18 +147,20 @@ def test_tranpose(pdims): # ********************************************* # Test Y to X transpose # It tranposes XZY to ZYX so from 0 1 2 to 1 2 0 - assert_array_equal(gathered_jd_zy.transpose(backward_pencils), gathered_jd_yx) + assert_array_equal( + gathered_jd_zy.transpose(backward_tranpose), gathered_jd_yx) # The X pencils should match in forward and backward transposes (original array) assert_array_equal(gathered_jd_yx, gathered_array) print(f"Pdims {pdims} are ok!") +# Cartesian product tests @pytest.mark.parametrize("pdims", - params) # Test with Slab and Pencil decompositions -def test_tranpose_grad(pdims): - - global_shape = (4, 4, 4) # These sizes are prime numbers x size of the pmesh + decomp) # Test with Slab and Pencil decompositions +@pytest.mark.parametrize("global_shape", + global_shapes) # Test cubes, non-cubes and primes +def test_tranpose_grad(pdims, global_shape): global_array, mesh = create_spmd_array(global_shape, pdims)