From 026690987f6a259eaaad1c8999f8e444d7cf8338 Mon Sep 17 00:00:00 2001
From: Ayaka
Date: Wed, 16 Oct 2024 09:29:12 -0700
Subject: [PATCH] [Pallas TPU] Fix lowering for `jnp.remainder`

Fixes https://github.com/jax-ml/jax/issues/24027

PiperOrigin-RevId: 686535642
---
 jax/_src/pallas/mosaic/lowering.py |  4 ++--
 tests/pallas/ops_test.py           | 21 +++++++--------------
 2 files changed, 9 insertions(+), 16 deletions(-)

diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index e775637d81e8..18b73a66ca23 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -1945,11 +1945,11 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
 def _rem_lowering_rule(ctx: LoweringRuleContext, x, y):
   x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
   (aval_out,) = ctx.avals_out
-  if jnp.issubdtype(aval_out.dtype, jnp.integer):
+  if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
     return arith.remsi(x, y)
   if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
     return arith.remui(x, y)
-  elif jnp.issubdtype(aval_out.dtype, jnp.floating):
+  if jnp.issubdtype(aval_out.dtype, jnp.floating):
     return arith.remf(x, y)
   raise NotImplementedError(aval_out.dtype)
 
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 34b6fac469b7..6ba98e64b73c 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -1034,16 +1034,12 @@ def test_binary(self, f, dtype):
     if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
       self.skipTest("16-bit types are not supported on TPU")
 
-    # TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
+    # TODO(ayx): Fix this on TPU
     if (
         jtu.test_device_matches(["tpu"])
-        and f == jnp.remainder
-        and not self.INTERPRET
+        and f in (jnp.floor_divide, jnp.subtract)
+        and dtype == "uint32"
     ):
-      self.skipTest("jnp.remainder on TPU is only supported in interpret mode")
-
-    # TODO(ayx): fix this on TPU
-    if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
       self.skipTest("Not supported on TPU")
 
     @functools.partial(
@@ -1070,16 +1066,13 @@ def test_binary_scalar(self, f, dtype):
       self.skipTest("Test only supported on TPU.")
     if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
       self.skipTest("16-bit types are not supported on TPU")
-    # TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
+
+    # TODO(ayx): Fix this on TPU
     if (
         jtu.test_device_matches(["tpu"])
-        and f == jnp.remainder
-        and not self.INTERPRET
+        and f in (jnp.floor_divide, jnp.subtract)
+        and dtype == "uint32"
     ):
-      self.skipTest("jnp.remainder on TPU is only supported in interpret mode")
-
-    # TODO: skipped due to https://github.com/jax-ml/jax/issues/23972
-    if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
       self.skipTest("Not supported on TPU")
 
     @functools.partial(
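
For reference, a minimal sketch (not part of the patch) of the kind of Pallas TPU kernel that exercises the `_rem_lowering_rule` changed above, based on the report in https://github.com/jax-ml/jax/issues/24027. The kernel name, shapes, and dtype are illustrative assumptions, not taken from the repository's tests.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def rem_kernel(x_ref, y_ref, o_ref):
      # jnp.remainder lowers through _rem_lowering_rule. Before this patch,
      # every integer dtype matched the jnp.integer branch and was routed to
      # arith.remsi, so the arith.remui branch for unsigned ints was never
      # reached; the patch splits the check into signedinteger/unsignedinteger.
      o_ref[...] = jnp.remainder(x_ref[...], y_ref[...])

    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(8, 128)
    y = jnp.full((8, 128), 3, dtype=jnp.int32)
    # Pass interpret=True to pl.pallas_call to run this off-TPU.
    out = pl.pallas_call(
        rem_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
    )(x, y)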