From be2eb495860c5b4755b80e129dc40b5322fb17ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Thu, 14 Mar 2024 17:50:38 -0700 Subject: [PATCH] [Mosaic] Restore Python pipeline and add a CLI flag to run it. We decided to expose a Python alternative again to make it easier for OSS users to see and customize the pipeline. The default is still to run the pipeline from XLA. The original one was removed in cl/596464480 and cl/597332393. PiperOrigin-RevId: 615959867 --- jax/BUILD | 2 + jax/_src/tpu_custom_call.py | 122 +++++++++++++++++++++++++++++++++++- 2 files changed, 121 insertions(+), 3 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 0ebfca2ba35f..57b253307953 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -826,6 +826,8 @@ pytype_strict_library( "//jax/_src/lib", ] + if_building_jaxlib([ "//jaxlib/mlir:ir", + "//jaxlib/mlir:mhlo_dialect", + "//jaxlib/mlir:pass_manager", "//jaxlib/mlir:stablehlo_dialect", ]) + py_deps("numpy") + py_deps("absl/flags"), ) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index ce879ca041a2..4ea1694464db 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -30,18 +30,30 @@ import jax from jax import core from jax._src import config +from jax._src import sharding_impls +from jax._src.interpreters import mlir from jax._src.lib import tpu from jax._src.lib import xla_client from jax._src.lib.mlir.dialects import hlo -from jax._src.interpreters import mlir -from jax._src import sharding_impls from jax.interpreters import xla from jaxlib.mlir import ir +from jaxlib.mlir.dialects import mhlo from jaxlib.mlir.dialects import stablehlo +from jaxlib.mlir.passmanager import PassManager import numpy as np FLAGS = flags.FLAGS +_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state( + name="mosaic_use_python_pipeline", + default=False, + help=( + "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel" + " is called (for Pallas, this happens at JAX lowering time), instead of" + " later within XLA." + ), +) + _MOSAIC_ALLOW_HLO = config.define_bool_state( name="jax_mosaic_allow_hlo", default=False, @@ -250,6 +262,105 @@ def _tpu_custom_call_lowering( platform="tpu") +def _lower_tpu_kernel( + module: ir.Module, + hardware_generation: int, +) -> ir.Module: + """Runs MLIR passes lowering the given module to an MLIR module. + + Uses Python versions of infer-memref-layout and apply-vector-layout. + + Args: + module: The MLIR module to lower. + hardware_generation: The TPU hardware generation to target. + + Returns: + An MLIR module implementing the kernel. + """ + try: + module.operation.verify() + except ir.MLIRError as e: + raise ValueError("The compiled module fails MLIR verification") from e + + with module.context as ctx, module.operation.location as _: + + ctx.append_dialect_registry(mlir.upstream_dialects) + ctx.load_all_available_dialects() + tpu.register_dialect(ctx) + mhlo.register_mhlo_dialect(ctx) + mhlo.register_mhlo_passes() + + # We'll mutate the module, so clone it + module = ir.Module.parse( + module.operation.get_asm(binary=True, enable_debug_info=True) + ) + dump_mlir(module, "original") + + if _MOSAIC_ALLOW_HLO.value: + # Run hlo dialect conversion: hlo -> linalg -> vector. + pipeline = [ + "hlo-legalize-to-arithmetic", + "func.func(hlo-legalize-to-linalg)", + "func.func(linalg-vectorization)", + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-hlo-conversion") + + pipeline = [ + f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})" + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-infer-memref-layout") + + pipeline = [ + "canonicalize", + "cse", + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-simplify") + + if checks := FLAGS["xla_mosaic_on_device_checks"].value: + checks = set(checks.split(",")) + if checks == {"bounds"}: # We only support one kind of checks now. + pipeline = PassManager.parse( + "builtin.module(func.func(debug-assert-insertion))" + ) + pipeline.run(module.operation) + dump_mlir(module, "post-assert-insertion") + elif checks: + checks.discard("bounds") + raise ValueError( + f"Unrecognized on-device check categories: {', '.join(checks)}" + ) + + pipeline = [ + "func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})", + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-infer-vector-layout") + + mxu_size = 128 if hardware_generation < 6 else 256 + pipeline = [ + "func.func(tpu-apply-vector-layout{sublane-count=8 lane-count=128" + f" hardware-generation={hardware_generation}" + f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}" + "})" + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + dump_mlir(module, "post-apply-vector-layout") + + pipeline = PassManager.parse("builtin.module(canonicalize)") + pipeline.run(module.operation) + dump_mlir(module, "pre-lower-to-llo") + + return module + + def as_tpu_kernel( module: ir.Module, out_type: Any, @@ -279,6 +390,11 @@ def as_tpu_kernel( has_communication, has_custom_barrier = tpu.private_has_communication( module.operation ) + needs_layout_passes = not device_type + if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value: + module = _lower_tpu_kernel(module, hardware_generation) + needs_layout_passes = False + bytecode_buffer = io.BytesIO() module.operation.write_bytecode(bytecode_buffer, desired_version=0) asm = bytecode_buffer.getvalue() @@ -290,7 +406,7 @@ def as_tpu_kernel( asm, out_type, needs_hlo_passes=_MOSAIC_ALLOW_HLO.value, - needs_layout_passes=not device_type, + needs_layout_passes=needs_layout_passes, device_type=device_type, has_communication=has_communication, has_custom_barrier=has_custom_barrier,