[Pallas TPU] Add lowerings for bf16 jnp.ceil and jnp.floor in TPU v6+

This PR is similar to #24284

Note that `np.testing.assert_allclose()` is changed to `self.assertAllClose()` because the latter is a wrapper with bfloat16 support.
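A rough sketch of what a bfloat16-aware comparison has to do: upcast bf16 operands before delegating to NumPy. This is illustrative only; `assert_allclose_bf16` is a hypothetical helper, not the actual `jtu.JaxTestCase.assertAllClose` implementation.

```python
# Illustrative sketch only: a bfloat16-aware allclose that upcasts
# bf16 operands to float32 before delegating to NumPy's check, which
# is the kind of handling self.assertAllClose provides and plain
# np.testing.assert_allclose does not.
import numpy as np
import jax.numpy as jnp


def assert_allclose_bf16(actual, desired, rtol=1e-6, atol=0.0):
  actual, desired = np.asarray(actual), np.asarray(desired)
  if actual.dtype == jnp.bfloat16:  # upcast bf16 before comparing
    actual = actual.astype(np.float32)
  if desired.dtype == jnp.bfloat16:
    desired = desired.astype(np.float32)
  np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)


x = jnp.ceil(jnp.array([0.42, 2.4], dtype=jnp.bfloat16))
assert_allclose_bf16(x, np.array([1.0, 3.0], dtype=np.float32))
```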

PiperOrigin-RevId: 686094131
ayaka14732 authored and Google-ML-Automation committed Oct 18, 2024
1 parent 8615556 commit 8036585
tests/pallas/ops_test.py (32 changes: 27 additions & 5 deletions)
@@ -742,7 +742,7 @@ def kernel(x_ref, o_ref):
[jnp.abs, jnp.negative],
["int16", "int32", "int64", "float16", "float32", "float64"],
),
([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]),
([jnp.ceil, jnp.floor], ["bfloat16", "float32", "float64", "int32"]),
(
[jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
["float16", "float32", "float64"],
@@ -767,8 +767,23 @@ def test_elementwise(self, fn, dtype):
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")
if jtu.test_device_matches(["tpu"]) and dtype in ("int16", "float16"):
self.skipTest("int16 and float16 are not supported on TPU")

if (
jtu.test_device_matches(["gpu"])
and fn in (jnp.ceil, jnp.floor)
and dtype == "bfloat16"
):
self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")

if (
jtu.test_device_matches(["tpu"])
and not jtu.is_device_tpu_at_least(6)
and fn in (jnp.ceil, jnp.floor)
and dtype == "bfloat16"
):
self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")

# TODO(b/370578663): implement these lowerings on TPU
if jtu.test_device_matches(["tpu"]) and fn in (
@@ -784,7 +799,7 @@ def kernel(x_ref, o_ref):
o_ref[:] = fn(x_ref[...])

x = jnp.array([0.42, 2.4]).astype(dtype)
-    np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
+    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)

@parameterized.named_parameters(
(f"{fn.__name__}_{dtype}", fn, dtype)
@@ -798,6 +813,13 @@ def test_elementwise_scalar(self, fn, dtype):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

+    if (
+        jtu.test_device_matches(["gpu"])
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
+
if (
jtu.test_device_matches(["tpu"])
and fn == lax.population_count
@@ -826,7 +848,7 @@ def kernel(x_ref, o_ref):
o_ref[1] = fn(x_ref[1])

x = jnp.array([0.42, 2.4]).astype(dtype)
-    np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
+    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)

def test_abs_weak_type(self):
# see https://github.com/jax-ml/jax/issues/23191

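For reference, a minimal standalone sketch of the capability this change exercises: applying `jnp.ceil` to a bfloat16 input inside a Pallas kernel. This mirrors the kernel used by `test_elementwise` but is not part of the commit, and it assumes a backend with a bf16 ceil lowering (TPU v6+ per the `jtu.is_device_tpu_at_least(6)` guard added above).

```python
# Standalone sketch (not part of this commit): bf16 ceil in a Pallas
# kernel, mirroring the kernel used by test_elementwise.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def ceil_kernel(x_ref, o_ref):
  o_ref[...] = jnp.ceil(x_ref[...])


x = jnp.array([0.42, 2.4], dtype=jnp.bfloat16)
out = pl.pallas_call(
    ceil_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)
print(out)  # expected [1, 3] in bfloat16 on hardware with the lowering
```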