Commit 56fa2ad

Update the lowering for div_p to require f32/f64 for floating point inputs

PTX has no div instruction for other floating point types.
See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div.

PiperOrigin-RevId: 616100981
superbobry authored and jax authors committed Mar 15, 2024
1 parent c94ea14 commit 56fa2ad
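
For illustration only (not part of this commit): the sketch below shows one way a Pallas kernel can keep dividing float16 values after this change, by upcasting to float32 around the division so the lowering hits PTX div.f32. The divide_f16 helper, shapes, and values are made up for the example, and it assumes a CUDA backend where Pallas lowers through Triton.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def divide_f16(x, y):
  def kernel(x_ref, y_ref, o_ref):
    # Upcast so the division lowers to PTX div.f32, then cast the result back.
    x32 = x_ref[...].astype(jnp.float32)
    y32 = y_ref[...].astype(jnp.float32)
    o_ref[...] = (x32 / y32).astype(jnp.float16)

  return pl.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float16),
      grid=1,
  )(x, y)


x = jnp.array([2.4, 4.2], dtype=jnp.float16)
y = jnp.array([4.2, 2.4], dtype=jnp.float16)
print(divide_f16(x, y))  # division happens in float32, stored as float16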
Showing 2 changed files with 30 additions and 6 deletions.
jax/_src/pallas/triton/lowering.py (2 additions, 2 deletions)
@@ -835,7 +835,7 @@ def _mul(x: ir.Value, y: ir.Value) -> ir.Value:
 def _floordiv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
   assert x.type == y.type, (str(x.type), str(y.type))
   x_element_type = _element_type(x.type)
-  if isinstance(x_element_type, ir.FloatType):
+  if isinstance(x_element_type, (ir.F32Type, ir.F64Type)):
     return arith_dialect.divf(x, y)
   if not isinstance(x_element_type, ir.IntegerType):
     raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
@@ -852,7 +852,7 @@ def _truediv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
     x_element_type = ir.F32Type.get()
     x = _int_float_cast(x, x_element_type, signed=signed)
     y = _int_float_cast(y, x_element_type, signed=signed)
-  if isinstance(x_element_type, ir.FloatType):
+  if isinstance(x_element_type, (ir.F32Type, ir.F64Type)):
     return arith_dialect.divf(x, y)
   raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")

tests/pallas/pallas_test.py (28 additions, 4 deletions)
@@ -40,12 +40,14 @@
 from jax.experimental.pallas.ops import rms_norm
 from jax.experimental.pallas.ops import softmax
 try:
+  from jax._src.pallas.triton.lowering import LoweringError
   from jax._src.pallas.triton.pallas_call_registration import (
       compile_jaxpr,
       _TRITON_COMPILE_VIA_XLA,
   )
   from jax.experimental.pallas import gpu as plgpu
 except ModuleNotFoundError:
+  LoweringError = Exception
   compile_jaxpr = None
   _TRITON_COMPILE_VIA_XLA = None
 import numpy as np
@@ -1634,19 +1636,41 @@ def isnan(x_ref, o_ref):
     x = x.at[3].set(jnp.nan)
     np.testing.assert_allclose(isnan(x), jnp.isnan(x))
 
-  def test_true_divide(self):
+  @parameterized.parameters(
+      ("int32", "float32"),
+      ("float32", "float32"),
+  )
+  def test_true_divide(self, dtype, out_dtype):
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
+        out_shape=jax.ShapeDtypeStruct((8,), out_dtype),
         grid=1,
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
 
-    x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7], dtype=jnp.int32)
-    y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4], dtype=jnp.int32)
+    x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype)
+    y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype)
     np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y))
 
+  @parameterized.parameters("float16", "bfloat16")
+  def test_true_divide_unsupported(self, dtype):
+    if self.INTERPRET:
+      self.skipTest("No lowering in interpreter mode")
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((2,), dtype),
+        grid=1,
+    )
+    def kernel(x_ref, y_ref, o_ref):
+      o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
+
+    x = jnp.array([2.4, 4.2]).astype(dtype)
+    y = jnp.array([4.2, 2.4]).astype(dtype)
+    with self.assertRaises(LoweringError):
+      kernel(x, y)
+
   BINARY_OPS = [
       ([jnp.floor_divide], ["int32", "uint32"]),
       (
