diff --git a/tests/layout_test.py b/tests/layout_test.py
index 7972d44d304a..acbba120bfba 100644
--- a/tests/layout_test.py
+++ b/tests/layout_test.py
@@ -40,11 +40,13 @@ def tearDownModule():
 class LayoutTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(['tpu']):
-      self.skipTest("Layouts do not work on CPU and GPU backends yet.")
+    if not jtu.test_device_matches(['tpu', 'gpu']):
+      self.skipTest("Layouts do not work on CPU backend yet.")
     super().setUp()
 
   def test_auto_layout(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape1 = (128, 128)
     shape2 = (128, 128)
@@ -110,6 +112,8 @@ def init(x, y):
     self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)
 
   def test_default_layout(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (4, 4, 2)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -149,6 +153,8 @@ def f(x):
                        out_shardings=DLL.AUTO).lower(sds).compile()
 
   def test_in_layouts_out_layouts(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (8, 8)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -173,6 +179,8 @@ def f(x):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
 
   def test_sharding_and_layouts(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (4, 8)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -235,6 +243,8 @@ def f(x, y):
     compiled(*arrs)
 
   def test_aot_layout_mismatch(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (256, 4, 2)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -404,6 +414,9 @@ def f(x):
     self.assertArraysEqual(out, inp.T)
 
   def test_device_put_user_concrete_layout(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
+
     shape = (8, 128)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
     dll = DLL(major_to_minor=(1, 0))
diff --git a/tests/memories_test.py b/tests/memories_test.py
index 7559f9f724d4..cb6f931865f3 100644
--- a/tests/memories_test.py
+++ b/tests/memories_test.py
@@ -190,8 +190,8 @@ def test_default_memory_kind(self):
 class DevicePutTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(["tpu"]):
-      self.skipTest("Memories do not work on CPU and GPU backends yet.")
+    if not jtu.test_device_matches(["tpu", "gpu"]):
+      self.skipTest("Memories do not work on CPU backend yet.")
     super().setUp()
 
   def _check_device_put_addressable_shards(
@@ -215,6 +215,8 @@ def test_error_transfer_to_memory_kind_outside_jit(self):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_host_to_hbm(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
     np_inp = np.arange(16).reshape(8, 2)
@@ -229,6 +231,8 @@ def test_device_put_host_to_hbm(self, host_memory_kind: str):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_hbm_to_host(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
     inp = jnp.arange(16).reshape(8, 2)
@@ -246,6 +250,8 @@ def test_device_put_hbm_to_host(self, host_memory_kind: str):
   def test_device_put_different_device_and_memory_host_to_hbm(
       self, host_memory_kind: str
   ):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     if jax.device_count() < 3:
       raise unittest.SkipTest("Test requires >=3 devices")
 
@@ -266,6 +272,8 @@ def test_device_put_different_device_and_memory_host_to_hbm(
   def test_device_put_different_device_and_memory_hbm_to_host(
       self, host_memory_kind: str
   ):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     if jax.device_count() < 3:
       raise unittest.SkipTest("Test requires >=3 devices")
 
@@ -285,6 +293,8 @@ def test_device_put_different_device_and_memory_hbm_to_host(
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_on_different_device_with_the_same_memory_kind(
       self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     if len(jax.devices()) < 2:
       raise unittest.SkipTest("Test requires >=2 devices.")
 
@@ -331,6 +341,8 @@ def test_device_put_on_different_device_with_the_same_memory_kind(
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_numpy_array(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     np_inp = np.arange(16).reshape(8, 2)
     s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device")
@@ -345,6 +357,8 @@ def test_device_put_numpy_array(self, host_memory_kind: str):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_numpy_scalar(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     np_inp = np.float32(8)
     s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
     s_host = s_hbm.with_memory_kind(host_memory_kind)
@@ -358,6 +372,8 @@ def test_device_put_numpy_scalar(self, host_memory_kind: str):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_python_scalar(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     py_scalar = float(8)
     s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
     s_host = s_hbm.with_memory_kind(host_memory_kind)
@@ -372,6 +388,8 @@ def test_device_put_python_scalar(self, host_memory_kind: str):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_python_int(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     py_inp = 8
     s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
     s_host = s_hbm.with_memory_kind(host_memory_kind)
@@ -399,6 +417,8 @@ def f(a, b):
         out, np_inp * np_inp, s_dev, "device")
 
   def test_parameter_streaming(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     _, s_host, np_inp, inp_host = _create_inputs(
         (8, 2), P("x", "y"), mem_kind="pinned_host")
     s_dev = s_host.with_memory_kind('device')
@@ -422,6 +442,8 @@ def f(a, b):
         out2, np_inp * np_inp * 2, s_host, 'pinned_host')
 
   def test_parameter_streaming_with_scalar_and_constant(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     scalar_inp = 1
     s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")
@@ -569,6 +591,8 @@ def f(x):
         out_host, np_inp * 2, s_host, 'pinned_host')
 
   def test_output_streaming_inside_scan(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
       self.skipTest("This test requires an xla_version >= 2.")
     mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z"))