diff --git a/tests/test_transpose.py b/tests/test_transpose.py
index 8aedd7d..561e376 100644
--- a/tests/test_transpose.py
+++ b/tests/test_transpose.py
@@ -52,26 +52,26 @@ def create_spmd_array(global_shape, pdims):
 
 pencil_1 = (size // 2, size // (size // 2))  # 2x2 for V100 and 4x2 for A100
 pencil_2 = (size // (size // 2), size // 2)  # 2x2 for V100 and 2x4 for A100
-params = [(size, 1), (1, size), pencil_1, pencil_2]
-#params = [(size, 1), (1, size)]
 
+decomp = [(size, 1), (1, size), pencil_1, pencil_2]
+global_shapes = [(4, 8, 16), (4, 4, 4), (29 * size, 19 * size, 17 * size)
+                ]  # Cubes, non-cubes and primes
 
+
+# Cartesian product tests
 @pytest.mark.parametrize("pdims",
-                         params)  # Test with Slab and Pencil decompositions
-def test_tranpose(pdims):
+                         decomp)  # Test with Slab and Pencil decompositions
+@pytest.mark.parametrize("global_shape",
+                         global_shapes)  # Test cubes, non-cubes and primes
+def test_tranpose(pdims, global_shape):
   """ Goes from an array of shape [z,y,x] # What we call an x pencil
     to [x,z,y] # what we call a y pencil
     """
   print("*" * 80)
-  print(f"Testing with pdims {pdims}")
-
-  global_shape = (4, 4, 4)  # These sizes are prime numbers x size of the pmesh
-  #global_shape = (29 * size, 19 * size, 17 * size)
-  print(f"Global shape is {global_shape}")
+  print(f"Testing with pdims {pdims} and global shape {global_shape}")
 
   global_array, mesh = create_spmd_array(global_shape, pdims)
 
-  gathered_array = multihost_utils.process_allgather(global_array, tiled=True)
   with mesh:
     jd_tranposed_xy = transposeXtoY(global_array)
     jd_tranposed_yz = transposeYtoZ(jd_tranposed_xy)
@@ -98,6 +98,16 @@ def test_tranpose(pdims):
 
   print(f"JD tranposed yz sharding {jd_tranposed_yz.sharding.spec}")
   print(f"JD tranposed xy sharding {jd_tranposed_xy.sharding.spec}")
+  print(f"JD tranposed zy sharding {jd_tranposed_zy.sharding.spec}")
+  print(f"JD tranposed yx sharding {jd_tranposed_yx.sharding.spec}")
+
+  assert global_array.sharding.spec == P('z', 'y')
+  assert jd_tranposed_xy.sharding.spec == transposed_sharding
+  assert jd_tranposed_yz.sharding.spec == original_sharding
+  assert jd_tranposed_zy.sharding.spec == transposed_sharding
+  assert jd_tranposed_yx.sharding.spec == original_sharding
+
+  gathered_array = multihost_utils.process_allgather(global_array, tiled=True)
 
   gathered_jd_xy = multihost_utils.process_allgather(
       jd_tranposed_xy, tiled=True)
@@ -108,42 +118,25 @@ def test_tranpose(pdims):
   gathered_jd_yx = multihost_utils.process_allgather(
       jd_tranposed_yx, tiled=True)
 
-  assert jd_tranposed_xy.sharding.spec == transposed_sharding
-  assert jd_tranposed_yz.sharding.spec == original_sharding
-  assert jd_tranposed_zy.sharding.spec == transposed_sharding
-  assert jd_tranposed_yx.sharding.spec == original_sharding
-
   # Explanation :
-
-  # For pencils
-
   # Tranposing forward is a shift axis to the right so ZYX to XZY to YXZ (2 0 1)
   # Tranposing backward is a shift axis to the left so YXZ to XZY to ZYX (1 2 0)
   # Double Tranposing from ZYX to YXZ is double (2 0 1) so  (1 2 0)
 
-  # For slabs it is a bit more complicated
-
-  # Tranposing from X to Y is a X Y tranpose so ZYX to ZXY (0 2 1) and tranposing back is the same (0 2 1)
-  # Tranposing from Y to Z is a Y Z tranpose so ZXY to YXZ (2 0 1) and tranposing back is the same (2 0 1)
-  # a double tranpose is a Z X tranpose so YXZ to ZYX (0 1 2)
-
-  # Every tranpose, also tranposes the pdim grid from P('Z', 'Y') to P('Y', 'Z') or vise versa
-
-  forward_tranpose = [1, 2, 0] if 1 in pdims else [2, 0, 1]
-  forward_pencils = [2, 0, 1]
-  backward_tranpose = [2, 0, 1] if 1 in pdims else [1, 2, 0]
-  backward_pencils = [1, 2, 0]
-  double_back = [0, 1, 2] if 1 in pdims else [1, 2, 0]
+  forward_tranpose = [2, 0, 1]
+  backward_tranpose = [1, 2, 0]
+  double_forward = [1, 2, 0]
 
+  #
   # Test X to Y transpose
   # It tranposes ZYX to XZY so from 0 1 2 to 2 0 1
   assert_array_equal(gathered_array.transpose(forward_tranpose), gathered_jd_xy)
   # *********************************************
   # Test Y to Z transpose
   # It tranposes XZY to YXZ so from 0 1 2 to 2 0 1 again
-  assert_array_equal(gathered_jd_xy.transpose(forward_pencils), gathered_jd_yz)
+  assert_array_equal(gathered_jd_xy.transpose(forward_tranpose), gathered_jd_yz)
   # and from the global array ZYX to YXZ so from 0 1 2 to 1 2 0
-  assert_array_equal(gathered_array.transpose(double_back), gathered_jd_yz)
+  assert_array_equal(gathered_array.transpose(double_forward), gathered_jd_yz)
   # *********************************************
   # Test Z to Y transpose
   # It tranposes YXZ to XZY so from 0 1 2 to 1 2 0
@@ -154,18 +147,20 @@ def test_tranpose(pdims):
   # *********************************************
   # Test Y to X transpose
   # It tranposes XZY to ZYX so from 0 1 2 to 1 2 0
-  assert_array_equal(gathered_jd_zy.transpose(backward_pencils), gathered_jd_yx)
+  assert_array_equal(
+      gathered_jd_zy.transpose(backward_tranpose), gathered_jd_yx)
   # The X pencils should match in forward and backward transposes (original array)
   assert_array_equal(gathered_jd_yx, gathered_array)
 
   print(f"Pdims {pdims} are ok!")
 
 
+# Cartesian product tests
 @pytest.mark.parametrize("pdims",
-                         params)  # Test with Slab and Pencil decompositions
-def test_tranpose_grad(pdims):
-
-  global_shape = (4, 4, 4)  # These sizes are prime numbers x size of the pmesh
+                         decomp)  # Test with Slab and Pencil decompositions
+@pytest.mark.parametrize("global_shape",
+                         global_shapes)  # Test cubes, non-cubes and primes
+def test_tranpose_grad(pdims, global_shape):
 
   global_array, mesh = create_spmd_array(global_shape, pdims)