# Opt-in switch read by as_tpu_kernel: when it is set (and layout passes are
# needed), the module is run through _lower_tpu_kernel before serialization
# and needs_layout_passes is passed on as False.
_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state(
    name="mosaic_use_python_pipeline",
    default=False,
    help=(
        "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel"
        " is called (for Pallas, this happens at JAX lowering time), instead of"
        " later within XLA."
    ),
)


def _lower_tpu_kernel(
    module: ir.Module,
    hardware_generation: int,
) -> ir.Module:
  """Runs the initial Mosaic pass pipeline on a copy of `module`.

  The input is verified, then re-parsed into a fresh module inside the input's
  context; every pass below runs on that copy. Stages, in order:

    1. HLO legalization (only when the `jax_mosaic_allow_hlo` config is set).
    2. `tpu-infer-memref-layout`.
    3. `canonicalize` + `cse`.
    4. `debug-assert-insertion` (only when the `xla_mosaic_on_device_checks`
       flag is exactly "bounds").
    5. `tpu-infer-vector-layout`.
    6. `tpu-apply-vector-layout`.
    7. A final `canonicalize`.

  The copy is handed to `dump_mlir` before the first stage and after each one.

  Args:
    module: The MLIR module to lower.
    hardware_generation: The TPU hardware generation to target. It is forwarded
      to the layout passes and selects the MXU size given to
      `tpu-apply-vector-layout`.

  Returns:
    The transformed copy of `module`.

  Raises:
    ValueError: If `module` fails MLIR verification, or if
      `xla_mosaic_on_device_checks` is set to anything other than "bounds".
  """
  try:
    module.operation.verify()
  except ir.MLIRError as err:
    raise ValueError("The compiled module fails MLIR verification") from err

  with module.context as ctx, module.operation.location:
    ctx.append_dialect_registry(mlir.upstream_dialects)
    ctx.load_all_available_dialects()
    tpu.register_dialect(ctx)
    mhlo.register_mhlo_dialect(ctx)
    mhlo.register_mhlo_passes()

    # The passes rewrite the IR in place, so operate on a round-tripped copy
    # rather than on the module object the caller handed in.
    serialized = module.operation.get_asm(binary=True, enable_debug_info=True)
    lowered = ir.Module.parse(serialized)
    dump_mlir(lowered, "original")

    def run_stage(passes, dump_name):
      # Wrap the given passes in a builtin.module pipeline, run it on the
      # copy, and dump the resulting IR under `dump_name`.
      manager = PassManager.parse(f"builtin.module({','.join(passes)})")
      manager.run(lowered.operation)
      dump_mlir(lowered, dump_name)

    if _MOSAIC_ALLOW_HLO.value:
      # Run hlo dialect conversion: hlo -> linalg -> vector.
      run_stage(
          [
              "hlo-legalize-to-arithmetic",
              "func.func(hlo-legalize-to-linalg)",
              "func.func(linalg-vectorization)",
          ],
          "post-hlo-conversion",
      )

    run_stage(
        [
            f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
        ],
        "post-infer-memref-layout",
    )

    run_stage(["canonicalize", "cse"], "post-simplify")

    requested_checks = FLAGS["xla_mosaic_on_device_checks"].value
    if requested_checks:
      categories = set(requested_checks.split(","))
      # We only support one kind of checks now. Splitting a non-empty string
      # always yields at least one entry, so anything other than exactly
      # {"bounds"} is an error.
      if categories != {"bounds"}:
        categories.discard("bounds")
        raise ValueError(
            f"Unrecognized on-device check categories: {', '.join(categories)}"
        )
      run_stage(
          ["func.func(debug-assert-insertion)"], "post-assert-insertion"
      )

    run_stage(
        ["func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})"],
        "post-infer-vector-layout",
    )

    # Generations below 6 are given a 128 MXU size; later ones get 256.
    mxu_size = 128 if hardware_generation < 6 else 256
    apply_vector_layout = (
        "func.func(tpu-apply-vector-layout{sublane-count=8 lane-count=128"
        f" hardware-generation={hardware_generation}"
        f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}"
        "})"
    )
    run_stage([apply_vector_layout], "post-apply-vector-layout")

    run_stage(["canonicalize"], "pre-lower-to-llo")

    return lowered