Skip to content

Commit

Permalink
Add tests for non-cubes and prime*size dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed May 31, 2024
1 parent af3953d commit 1d1a779
Showing 1 changed file with 33 additions and 38 deletions.
71 changes: 33 additions & 38 deletions tests/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,26 @@ def create_spmd_array(global_shape, pdims):

pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100
pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100
params = [(size, 1), (1, size), pencil_1, pencil_2]
#params = [(size, 1), (1, size)]

decomp = [(size, 1), (1, size), pencil_1, pencil_2]
global_shapes = [(4, 8, 16), (4, 4, 4), (29 * size, 19 * size, 17 * size)
] # Cubes, non-cubes and primes


# Cartesian product tests
@pytest.mark.parametrize("pdims",
params) # Test with Slab and Pencil decompositions
def test_tranpose(pdims):
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_tranpose(pdims, global_shape):
""" Goes from an array of shape [z,y,x] # What we call an x pencil
to [x,z,y] # what we call a y pencil
"""
print("*" * 80)
print(f"Testing with pdims {pdims}")

global_shape = (4, 4, 4) # These sizes are prime numbers x size of the pmesh
#global_shape = (29 * size, 19 * size, 17 * size)
print(f"Global shape is {global_shape}")
print(f"Testing with pdims {pdims} and global shape {global_shape}")

global_array, mesh = create_spmd_array(global_shape, pdims)

gathered_array = multihost_utils.process_allgather(global_array, tiled=True)
with mesh:
jd_tranposed_xy = transposeXtoY(global_array)
jd_tranposed_yz = transposeYtoZ(jd_tranposed_xy)
Expand All @@ -98,6 +98,16 @@ def test_tranpose(pdims):

print(f"JD tranposed yz sharding {jd_tranposed_yz.sharding.spec}")
print(f"JD tranposed xy sharding {jd_tranposed_xy.sharding.spec}")
print(f"JD tranposed zy sharding {jd_tranposed_zy.sharding.spec}")
print(f"JD tranposed yx sharding {jd_tranposed_yx.sharding.spec}")

assert global_array.sharding.spec == P('z', 'y')
assert jd_tranposed_xy.sharding.spec == transposed_sharding
assert jd_tranposed_yz.sharding.spec == original_sharding
assert jd_tranposed_zy.sharding.spec == transposed_sharding
assert jd_tranposed_yx.sharding.spec == original_sharding

gathered_array = multihost_utils.process_allgather(global_array, tiled=True)

gathered_jd_xy = multihost_utils.process_allgather(
jd_tranposed_xy, tiled=True)
Expand All @@ -108,42 +118,25 @@ def test_tranpose(pdims):
gathered_jd_yx = multihost_utils.process_allgather(
jd_tranposed_yx, tiled=True)

assert jd_tranposed_xy.sharding.spec == transposed_sharding
assert jd_tranposed_yz.sharding.spec == original_sharding
assert jd_tranposed_zy.sharding.spec == transposed_sharding
assert jd_tranposed_yx.sharding.spec == original_sharding

# Explanation :

# For pencils

# Transposing forward shifts the axes to the right, so ZYX to XZY to YXZ (2 0 1)
# Transposing backward shifts the axes to the left, so YXZ to XZY to ZYX (1 2 0)
# Double transposing from ZYX to YXZ applies (2 0 1) twice, which is (1 2 0)

# For slabs it is a bit more complicated

# Transposing from X to Y is an X-Y transpose, so ZYX to ZXY (0 2 1), and transposing back is the same (0 2 1)
# Transposing from Y to Z is a Y-Z transpose, so ZXY to YXZ (2 0 1), and transposing back is the same (2 0 1)
# A double transpose is a Z-X transpose, so YXZ to ZYX (0 1 2)

# Every transpose also transposes the pdim grid from P('Z', 'Y') to P('Y', 'Z') or vice versa

forward_tranpose = [1, 2, 0] if 1 in pdims else [2, 0, 1]
forward_pencils = [2, 0, 1]
backward_tranpose = [2, 0, 1] if 1 in pdims else [1, 2, 0]
backward_pencils = [1, 2, 0]
double_back = [0, 1, 2] if 1 in pdims else [1, 2, 0]
forward_tranpose = [2, 0, 1]
backward_tranpose = [1, 2, 0]
double_forward = [1, 2, 0]

#
# Test X to Y transpose
# It transposes ZYX to XZY, so from 0 1 2 to 2 0 1
assert_array_equal(gathered_array.transpose(forward_tranpose), gathered_jd_xy)
# *********************************************
# Test Y to Z transpose
# It transposes XZY to YXZ, so from 0 1 2 to 2 0 1 again
assert_array_equal(gathered_jd_xy.transpose(forward_pencils), gathered_jd_yz)
assert_array_equal(gathered_jd_xy.transpose(forward_tranpose), gathered_jd_yz)
# and from the global array ZYX to YXZ so from 0 1 2 to 1 2 0
assert_array_equal(gathered_array.transpose(double_back), gathered_jd_yz)
assert_array_equal(gathered_array.transpose(double_forward), gathered_jd_yz)
# *********************************************
# Test Z to Y transpose
# It transposes YXZ to XZY, so from 0 1 2 to 1 2 0
Expand All @@ -154,18 +147,20 @@ def test_tranpose(pdims):
# *********************************************
# Test Y to X transpose
# It transposes XZY to ZYX, so from 0 1 2 to 1 2 0
assert_array_equal(gathered_jd_zy.transpose(backward_pencils), gathered_jd_yx)
assert_array_equal(
gathered_jd_zy.transpose(backward_tranpose), gathered_jd_yx)
# The X pencils should match in forward and backward transposes (original array)
assert_array_equal(gathered_jd_yx, gathered_array)

print(f"Pdims {pdims} are ok!")


# Cartesian product tests
@pytest.mark.parametrize("pdims",
params) # Test with Slab and Pencil decompositions
def test_tranpose_grad(pdims):

global_shape = (4, 4, 4) # These sizes are prime numbers x size of the pmesh
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_tranpose_grad(pdims, global_shape):

global_array, mesh = create_spmd_array(global_shape, pdims)

Expand Down

0 comments on commit 1d1a779

Please sign in to comment.