[Pallas TPU] Add lowerings for bf16 jnp.ceil and jnp.floor in TPU v6+ #24310

Open
wants to merge 1 commit into main
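For context, the lowering this PR enables can be exercised with a small Pallas kernel. The sketch below is illustrative only (the shapes, values, and the kernel name ceil_kernel are assumptions, not taken from the PR); it applies jnp.ceil to a bfloat16 input, the case that previously had no Pallas TPU lowering for that dtype:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def ceil_kernel(x_ref, o_ref):
  # Element-wise ceil inside the kernel; with this PR it should lower for
  # bfloat16 on TPU v6+, alongside the existing float32/float64/int32 support.
  o_ref[...] = jnp.ceil(x_ref[...])

x = jnp.array([0.42, 2.4], dtype=jnp.bfloat16)  # illustrative values
out = pl.pallas_call(
    ceil_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)
np.testing.assert_allclose(
    np.asarray(out, dtype=np.float32),
    np.asarray(jnp.ceil(x), dtype=np.float32),
)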
32 changes: 27 additions & 5 deletions tests/pallas/ops_test.py
@@ -742,7 +742,7 @@ def kernel(x_ref, o_ref):
         [jnp.abs, jnp.negative],
         ["int16", "int32", "int64", "float16", "float32", "float64"],
     ),
-    ([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]),
+    ([jnp.ceil, jnp.floor], ["bfloat16", "float32", "float64", "int32"]),
     (
         [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
         ["float16", "float32", "float64"],
@@ -767,8 +767,23 @@ def test_elementwise(self, fn, dtype):
     if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
       self.skipTest("64-bit types require x64_enabled")
 
-    if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
-      self.skipTest("16-bit types are not supported on TPU")
+    if jtu.test_device_matches(["tpu"]) and dtype in ("int16", "float16"):
+      self.skipTest("int16 and float16 are not supported on TPU")
+
+    if (
+        jtu.test_device_matches(["gpu"])
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
+
+    if (
+        jtu.test_device_matches(["tpu"])
+        and not jtu.is_device_tpu_at_least(6)
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")
 
     # TODO(b/370578663): implement these lowerings on TPU
     if jtu.test_device_matches(["tpu"]) and fn in (
@@ -784,7 +799,7 @@ def kernel(x_ref, o_ref):
       o_ref[:] = fn(x_ref[...])
 
     x = jnp.array([0.42, 2.4]).astype(dtype)
-    np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
+    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)
 
   @parameterized.named_parameters(
       (f"{fn.__name__}_{dtype}", fn, dtype)
@@ -798,6 +813,13 @@ def test_elementwise_scalar(self, fn, dtype):
     if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
       self.skipTest("16-bit types are not supported on TPU")
 
+    if (
+        jtu.test_device_matches(["gpu"])
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
+
     if (
         jtu.test_device_matches(["tpu"])
         and fn == lax.population_count
@@ -826,7 +848,7 @@ def kernel(x_ref, o_ref):
       o_ref[1] = fn(x_ref[1])
 
     x = jnp.array([0.42, 2.4]).astype(dtype)
-    np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
+    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)
 
   def test_abs_weak_type(self):
     # see https://github.com/jax-ml/jax/issues/23191
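The version gate in the new skip conditions also suggests how downstream code might guard bf16 ceil/floor until TPU v6+ can be assumed. The helper below is a hypothetical user-level fallback, not part of this PR: it upcasts to float32 where a native bfloat16 lowering is unavailable.

import jax.numpy as jnp

def ceil_bf16_safe(x, bf16_ceil_supported: bool):
  # Hypothetical workaround (assumption, not from the PR): compute ceil via
  # float32 when the backend or TPU generation lacks a bfloat16 lowering.
  if x.dtype == jnp.bfloat16 and not bf16_ceil_supported:
    return jnp.ceil(x.astype(jnp.float32)).astype(jnp.bfloat16)
  return jnp.ceil(x)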