Skip to content

Commit

Permalink
[Pallas TPU] Fix lowering for jnp.remainder
Browse files Browse the repository at this point in the history
Fixes #24027

PiperOrigin-RevId: 686535642
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Oct 19, 2024
1 parent 884f1dc commit 0266909
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 16 deletions.
4 changes: 2 additions & 2 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,11 +1945,11 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
def _rem_lowering_rule(ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
(aval_out,) = ctx.avals_out
if jnp.issubdtype(aval_out.dtype, jnp.integer):
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
return arith.remsi(x, y)
if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
return arith.remui(x, y)
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
if jnp.issubdtype(aval_out.dtype, jnp.floating):
return arith.remf(x, y)
raise NotImplementedError(aval_out.dtype)

Expand Down
21 changes: 7 additions & 14 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,16 +1034,12 @@ def test_binary(self, f, dtype):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

# TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
# TODO(ayx): Fix this on TPU
if (
jtu.test_device_matches(["tpu"])
and f == jnp.remainder
and not self.INTERPRET
and f in (jnp.floor_divide, jnp.subtract)
and dtype == "uint32"
):
self.skipTest("jnp.remainder on TPU is only supported in interpret mode")

# TODO(ayx): fix this on TPU
if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
self.skipTest("Not supported on TPU")

@functools.partial(
Expand All @@ -1070,16 +1066,13 @@ def test_binary_scalar(self, f, dtype):
self.skipTest("Test only supported on TPU.")
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")
# TODO: skipped due to https://github.com/jax-ml/jax/issues/24027

# TODO(ayx): Fix this on TPU
if (
jtu.test_device_matches(["tpu"])
and f == jnp.remainder
and not self.INTERPRET
and f in (jnp.floor_divide, jnp.subtract)
and dtype == "uint32"
):
self.skipTest("jnp.remainder on TPU is only supported in interpret mode")

# TODO: skipped due to https://github.com/jax-ml/jax/issues/23972
if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
self.skipTest("Not supported on TPU")

@functools.partial(
Expand Down

0 comments on commit 0266909

Please sign in to comment.