From 4bb1f1a550062aafcc0e519d184181cc19ddc579 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Fri, 15 Mar 2024 05:49:08 -0700
Subject: [PATCH] Update the lowering for div_p to require f32/f64 for
 floating point inputs

PTX has no div instruction for other floating point types. See
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div.

PiperOrigin-RevId: 616100981
---
 jax/_src/pallas/triton/lowering.py |  4 ++--
 tests/pallas/pallas_test.py        | 32 ++++++++++++++++++++++++++----
 2 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 4b92442afa91..84c6b0a4ca47 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -835,7 +835,7 @@ def _mul(x: ir.Value, y: ir.Value) -> ir.Value:
 def _floordiv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
   assert x.type == y.type, (str(x.type), str(y.type))
   x_element_type = _element_type(x.type)
-  if isinstance(x_element_type, ir.FloatType):
+  if isinstance(x_element_type, (ir.F32Type, ir.F64Type)):
     return arith_dialect.divf(x, y)
   if not isinstance(x_element_type, ir.IntegerType):
     raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
@@ -852,7 +852,7 @@ def _truediv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
     x_element_type = ir.F32Type.get()
     x = _int_float_cast(x, x_element_type, signed=signed)
     y = _int_float_cast(y, x_element_type, signed=signed)
-  if isinstance(x_element_type, ir.FloatType):
+  if isinstance(x_element_type, (ir.F32Type, ir.F64Type)):
     return arith_dialect.divf(x, y)
   raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
 
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 94d87d487d9a..71e2d01ed222 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -40,12 +40,14 @@ from jax.experimental.pallas.ops import rms_norm
 from jax.experimental.pallas.ops import softmax
 
 try:
+  from jax._src.pallas.triton.lowering import LoweringError
   from jax._src.pallas.triton.pallas_call_registration import (
       compile_jaxpr,
       _TRITON_COMPILE_VIA_XLA,
   )
   from jax.experimental.pallas import gpu as plgpu
 except ModuleNotFoundError:
+  LoweringError = Exception
   compile_jaxpr = None
   _TRITON_COMPILE_VIA_XLA = None
 import numpy as np
@@ -1634,19 +1636,41 @@ def isnan(x_ref, o_ref):
     x = x.at[3].set(jnp.nan)
     np.testing.assert_allclose(isnan(x), jnp.isnan(x))
 
-  def test_true_divide(self):
+  @parameterized.parameters(
+      ("int32", "float32"),
+      ("float32", "float32"),
+  )
+  def test_true_divide(self, dtype, out_dtype):
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
+        out_shape=jax.ShapeDtypeStruct((8,), out_dtype),
         grid=1,
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
 
-    x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7], dtype=jnp.int32)
-    y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4], dtype=jnp.int32)
+    x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype)
+    y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype)
     np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y))
 
+  @parameterized.parameters("float16", "bfloat16")
+  def test_true_divide_unsupported(self, dtype):
+    if self.INTERPRET:
+      self.skipTest("No lowering in interpreter mode")
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((2,), dtype),
+        grid=1,
+    )
+    def kernel(x_ref, y_ref, o_ref):
+      o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
+
+    x = jnp.array([2.4, 4.2]).astype(dtype)
+    y = jnp.array([4.2, 2.4]).astype(dtype)
+    with self.assertRaises(LoweringError):
+      kernel(x, y)
+
   BINARY_OPS = [
       ([jnp.floor_divide], ["int32", "uint32"]),
       (
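
Note (not part of the patch): once the lowering only emits arith.divf for f32/f64,
a Pallas Triton kernel that needs to divide float16 or bfloat16 data has to do the
division at higher precision. The sketch below is one hypothetical workaround
kernel, not code from the JAX repository; the name div_f16_kernel and the sample
inputs are illustrative. It upcasts the operands to float32, divides there, and
casts the result back, and it assumes a GPU where the Triton lowering is used
(passing interpret=True to pl.pallas_call should let you check the numerics on CPU).

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float16),
    grid=1,
)
def div_f16_kernel(x_ref, y_ref, o_ref):
  # jnp.true_divide directly on the f16 values would fail to lower after this
  # change, so upcast to f32, divide (lowered to arith.divf on f32), cast back.
  x = x_ref[...].astype(jnp.float32)
  y = y_ref[...].astype(jnp.float32)
  o_ref[...] = (x / y).astype(jnp.float16)


x = jnp.arange(1.0, 9.0, dtype=jnp.float16)
y = jnp.full((8,), 2.0, dtype=jnp.float16)
print(div_f16_kernel(x, y))  # [0.5, 1.0, ..., 4.0]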