From 80365858f92a7ebf183f858dbf2a3d26253f23c2 Mon Sep 17 00:00:00 2001
From: Ayaka
Date: Tue, 15 Oct 2024 07:17:41 -0700
Subject: [PATCH] [Pallas TPU] Add lowerings for bf16 `jnp.ceil` and
 `jnp.floor` in TPU v6+

This PR is similar to https://github.com/jax-ml/jax/pull/24284

Note that `np.testing.assert_allclose()` is changed to `self.assertAllClose()`
because the latter is a wrapper with bfloat16 support.

PiperOrigin-RevId: 686094131
---
 tests/pallas/ops_test.py | 32 +++++++++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 5 deletions(-)

diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 34b6fac469b7..2e7a9b53c416 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -742,7 +742,7 @@ def kernel(x_ref, o_ref):
           [jnp.abs, jnp.negative],
           ["int16", "int32", "int64", "float16", "float32", "float64"],
       ),
-      ([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]),
+      ([jnp.ceil, jnp.floor], ["bfloat16", "float32", "float64", "int32"]),
       (
           [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
           ["float16", "float32", "float64"],
@@ -767,8 +767,23 @@ def test_elementwise(self, fn, dtype):
     if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
       self.skipTest("64-bit types require x64_enabled")
 
-    if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
-      self.skipTest("16-bit types are not supported on TPU")
+    if jtu.test_device_matches(["tpu"]) and dtype in ("int16", "float16"):
+      self.skipTest("int16 and float16 are not supported on TPU")
+
+    if (
+        jtu.test_device_matches(["gpu"])
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
+
+    if (
+        jtu.test_device_matches(["tpu"])
+        and not jtu.is_device_tpu_at_least(6)
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")
 
     # TODO(b/370578663): implement these lowerings on TPU
     if jtu.test_device_matches(["tpu"]) and fn in (
@@ -784,7 +799,7 @@ def kernel(x_ref, o_ref):
       o_ref[:] = fn(x_ref[...])
 
     x = jnp.array([0.42, 2.4]).astype(dtype)
-    np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
+    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)
 
   @parameterized.named_parameters(
       (f"{fn.__name__}_{dtype}", fn, dtype)
@@ -798,6 +813,13 @@ def test_elementwise_scalar(self, fn, dtype):
     if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
       self.skipTest("16-bit types are not supported on TPU")
 
+    if (
+        jtu.test_device_matches(["gpu"])
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
+
     if (
         jtu.test_device_matches(["tpu"])
         and fn == lax.population_count
@@ -826,7 +848,7 @@ def kernel(x_ref, o_ref):
       o_ref[1] = fn(x_ref[1])
 
     x = jnp.array([0.42, 2.4]).astype(dtype)
-    np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
+    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)
 
   def test_abs_weak_type(self):
     # see https://github.com/jax-ml/jax/issues/23191
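
Usage sketch (not part of the patch): a minimal standalone Pallas kernel that
exercises the bf16 `jnp.ceil` lowering covered by the updated test. It assumes
a TPU v6+ backend; `ceil_kernel` and `bf16_ceil` are illustrative names, not
from the diff.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def ceil_kernel(x_ref, o_ref):
      # Elementwise ceil on a bfloat16 block; this is the lowering the
      # patched test now exercises on TPU v6+.
      o_ref[...] = jnp.ceil(x_ref[...])

    @jax.jit
    def bf16_ceil(x):
      return pl.pallas_call(
          ceil_kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      )(x)

    x = jnp.array([0.42, 2.4], dtype=jnp.bfloat16)
    print(bf16_ceil(x))  # expected: [1, 3], matching jnp.ceil(x)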