
Add support for shape polymorphism in ffi_lowering and move lu_pivots_to_permutation lowering out of jaxlib.

The lowering logic for all jaxlib custom calls is currently split between JAX and jaxlib, for reasons that are harder to justify now that the compiled calls are split between jaxlib and the relevant plugins. As part of my project to update these calls and simplify the lowering logic, it makes sense to consolidate the lowering rules in JAX rather than jaxlib, since the logic is now the same for both GPU and CPU. This update tackles a simple kernel as a test case for what that would look like.

Since the full lowering rule is now implemented in JAX, we can take advantage of the MLIR helpers available there, in particular `jex.ffi.ffi_lowering`, which I needed to update to support shape polymorphism (see the sketch below).

Of note: I think it is safe (in a compatibility sense) to delete the lowering code from jaxlib, but it does mean that it won't be possible to lower this operation when `jax.__version__ < jaxlib.__version__`. I think this is okay given our compatibility guarantees, but I'd love a sanity check on that!

Another note: this doesn't actually change the lowered HLO for this op, so we don't need to worry about export compatibility.
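
For reference, here is a minimal sketch of the consolidated pattern, mirroring the linalg.py diff below. `lu_pivots_to_permutation_p` is the primitive already defined in that file, only the CUDA registration is shown, and the surrounding setup is illustrative rather than a prescribed recipe:

from functools import partial

import numpy as np

from jax._src.extend import ffi
from jax._src.interpreters import mlir

def _lu_pivots_to_permutation_gpu_lowering(platform, ctx, pivots, *,
                                           permutation_size):
  # ffi_lowering builds the mlir.custom_call, filling in default layouts and,
  # with this change, dynamic result_shapes whenever an output dimension is
  # symbolic under shape polymorphism.
  rule = ffi.ffi_lowering(f"{platform}_lu_pivots_to_permutation")
  return rule(ctx, pivots, permutation_size=np.int32(permutation_size))

mlir.register_lowering(
    lu_pivots_to_permutation_p,
    partial(_lu_pivots_to_permutation_gpu_lowering, "cu"),
    platform="cuda")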

PiperOrigin-RevId: 664680250
dfm authored and jax authors committed Aug 19, 2024
1 parent 05792c9 commit dad2f57
Showing 4 changed files with 17 additions and 46 deletions.
5 changes: 5 additions & 0 deletions jax/_src/extend/ffi.py
@@ -142,6 +142,11 @@ def _lowering(
     kwargs["operand_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_in) # pytype: disable=attribute-error
   if result_layouts is None:
     kwargs["result_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_out)
+  if "result_shapes" not in kwargs and not all(
+      core.is_constant_shape(aval.shape) for aval in ctx.avals_out):
+    kwargs["result_shapes"] = [
+        mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, aval.shape))
+        for aval in ctx.avals_out]
 
   return mlir.custom_call(call_target_name, operands=operands, **kwargs).results # type: ignore

22 changes: 9 additions & 13 deletions jax/_src/lax/linalg.py
@@ -31,6 +31,7 @@
 from jax._src import dtypes
 from jax._src.core import (
     Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape)
+from jax._src.extend import ffi
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -1190,16 +1191,13 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
   return lu_pivots_to_permutation_p.bind(
       x, permutation_size=permutation_size), 0
 
-def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
+def _lu_pivots_to_permutation_gpu_lowering(platform, ctx, pivots, *,
                                            permutation_size):
-  # TODO(danfm): Remove once jaxlib 0.4.32 is the minimum version.
-  if jaxlib_version >= (0, 4, 32):
-    pivots_aval, = ctx.avals_in
-    pivots_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, pivots_aval.shape)
-    kwargs = dict(pivots_shape_vals=pivots_shape_vals)
-  else:
-    kwargs = {}
-  return lowering(pivots, permutation_size=permutation_size, **kwargs)
+  rule = ffi.ffi_lowering(f"{platform}_lu_pivots_to_permutation")
+  return rule(ctx, pivots,
+              # TODO(b/358275922): remove unused parameter 12 weeks after
+              # the release of jaxlib v0.4.32.
+              permutation_size=np.int32(permutation_size))
 
 
 lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
@@ -1215,13 +1213,11 @@ def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
     mlir.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False))
 mlir.register_lowering(
     lu_pivots_to_permutation_p,
-    partial(_lu_pivots_to_permutation_gpu_lowering,
-            gpu_linalg.cuda_lu_pivots_to_permutation),
+    partial(_lu_pivots_to_permutation_gpu_lowering, "cu"),
     platform='cuda')
 mlir.register_lowering(
     lu_pivots_to_permutation_p,
-    partial(_lu_pivots_to_permutation_gpu_lowering,
-            gpu_linalg.hip_lu_pivots_to_permutation),
+    partial(_lu_pivots_to_permutation_gpu_lowering, "hip"),
     platform='rocm')
 
 # LU decomposition
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/tests/primitives_test.py
@@ -183,6 +183,8 @@ def test_primitive_coverage(self):
        continue
      if p.name == "pallas_call":
        continue
+     if p.name == "ffi_call":
+       continue
      if p.name == "tpu_custom_call":
        continue
      if p.name == "custom_partitioning":
34 changes: 1 addition & 33 deletions jaxlib/gpu_linalg.py
@@ -20,7 +20,7 @@
 
 import jaxlib.mlir.ir as ir
 
-from .hlo_helpers import custom_call, mk_result_types_and_shapes
+from .hlo_helpers import custom_call
 from .gpu_common_utils import GpuLibNotLinkedError
 
 from jaxlib import xla_client
@@ -61,38 +61,6 @@
 _prod = lambda xs: functools.reduce(operator.mul, xs, 1)
 
 
-def _lu_pivots_to_permutation_hlo(platform, pivots, *, permutation_size,
-                                  pivots_shape_vals):
-  """Kernel for the transformation of pivots to permutations on GPU."""
-  typ = ir.RankedTensorType(pivots.type)
-  i32_type = ir.IntegerType.get_signless(32)
-  assert typ.element_type == i32_type, typ
-  assert len(pivots_shape_vals) >= 1
-
-  pivots_layout = tuple(range(len(pivots_shape_vals) - 1, -1, -1))
-  permutations_layout = pivots_layout
-  permutations_dims = (*pivots_shape_vals[:-1], permutation_size)
-  result_types, result_shapes = mk_result_types_and_shapes(
-      [(permutations_dims, i32_type)])
-  return custom_call(
-      f"{platform}_lu_pivots_to_permutation",
-      api_version=4,
-      operands=[pivots],
-      operand_layouts=[pivots_layout],
-      result_types=result_types,
-      result_shapes=result_shapes,
-      result_layouts=[permutations_layout],
-      # TODO(b/358275922): remove backend_config 12 weeks after release of
-      # jaxlib v0.4.32.
-      backend_config=dict(
-          permutation_size=ir.IntegerAttr.get(i32_type, permutation_size),
-      ),
-  ).results
-
-cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu")
-hip_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "hip")
-
-
 def _cholesky_update_hlo(platform, gpu_linalg, r_matrix, w_vector, dtype):
   """Cholesky update."""
   del platform
