Skip to content

Commit

Permalink
[Mosaic] Restore Python pipeline and add a CLI flag to run it.
Browse files Browse the repository at this point in the history
We decided to expose a Python alternative again to make it easier for OSS users to see and customize the pipeline. The default is still to run the pipeline from XLA.

The original one was removed in cl/596464480 and cl/597332393.

PiperOrigin-RevId: 615959867
  • Loading branch information
tlongeri authored and jax authors committed Mar 19, 2024
1 parent df9cefa commit be2eb49
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 3 deletions.
2 changes: 2 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,8 @@ pytype_strict_library(
"//jax/_src/lib",
] + if_building_jaxlib([
"//jaxlib/mlir:ir",
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:stablehlo_dialect",
]) + py_deps("numpy") + py_deps("absl/flags"),
)
Expand Down
122 changes: 119 additions & 3 deletions jax/_src/tpu_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,30 @@
import jax
from jax import core
from jax._src import config
from jax._src import sharding_impls
from jax._src.interpreters import mlir
from jax._src.lib import tpu
from jax._src.lib import xla_client
from jax._src.lib.mlir.dialects import hlo
from jax._src.interpreters import mlir
from jax._src import sharding_impls
from jax.interpreters import xla
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import mhlo
from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.passmanager import PassManager
import numpy as np

FLAGS = flags.FLAGS

_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state(
name="mosaic_use_python_pipeline",
default=False,
help=(
"Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel"
" is called (for Pallas, this happens at JAX lowering time), instead of"
" later within XLA."
),
)

_MOSAIC_ALLOW_HLO = config.define_bool_state(
name="jax_mosaic_allow_hlo",
default=False,
Expand Down Expand Up @@ -250,6 +262,105 @@ def _tpu_custom_call_lowering(
platform="tpu")


def _lower_tpu_kernel(
module: ir.Module,
hardware_generation: int,
) -> ir.Module:
"""Runs MLIR passes lowering the given module to an MLIR module.
Uses Python versions of infer-memref-layout and apply-vector-layout.
Args:
module: The MLIR module to lower.
hardware_generation: The TPU hardware generation to target.
Returns:
An MLIR module implementing the kernel.
"""
try:
module.operation.verify()
except ir.MLIRError as e:
raise ValueError("The compiled module fails MLIR verification") from e

with module.context as ctx, module.operation.location as _:

ctx.append_dialect_registry(mlir.upstream_dialects)
ctx.load_all_available_dialects()
tpu.register_dialect(ctx)
mhlo.register_mhlo_dialect(ctx)
mhlo.register_mhlo_passes()

# We'll mutate the module, so clone it
module = ir.Module.parse(
module.operation.get_asm(binary=True, enable_debug_info=True)
)
dump_mlir(module, "original")

if _MOSAIC_ALLOW_HLO.value:
# Run hlo dialect conversion: hlo -> linalg -> vector.
pipeline = [
"hlo-legalize-to-arithmetic",
"func.func(hlo-legalize-to-linalg)",
"func.func(linalg-vectorization)",
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-hlo-conversion")

pipeline = [
f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-infer-memref-layout")

pipeline = [
"canonicalize",
"cse",
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-simplify")

if checks := FLAGS["xla_mosaic_on_device_checks"].value:
checks = set(checks.split(","))
if checks == {"bounds"}: # We only support one kind of checks now.
pipeline = PassManager.parse(
"builtin.module(func.func(debug-assert-insertion))"
)
pipeline.run(module.operation)
dump_mlir(module, "post-assert-insertion")
elif checks:
checks.discard("bounds")
raise ValueError(
f"Unrecognized on-device check categories: {', '.join(checks)}"
)

pipeline = [
"func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})",
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-infer-vector-layout")

mxu_size = 128 if hardware_generation < 6 else 256
pipeline = [
"func.func(tpu-apply-vector-layout{sublane-count=8 lane-count=128"
f" hardware-generation={hardware_generation}"
f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}"
"})"
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-apply-vector-layout")

pipeline = PassManager.parse("builtin.module(canonicalize)")
pipeline.run(module.operation)
dump_mlir(module, "pre-lower-to-llo")

return module


def as_tpu_kernel(
module: ir.Module,
out_type: Any,
Expand Down Expand Up @@ -279,6 +390,11 @@ def as_tpu_kernel(
has_communication, has_custom_barrier = tpu.private_has_communication(
module.operation
)
needs_layout_passes = not device_type
if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value:
module = _lower_tpu_kernel(module, hardware_generation)
needs_layout_passes = False

bytecode_buffer = io.BytesIO()
module.operation.write_bytecode(bytecode_buffer, desired_version=0)
asm = bytecode_buffer.getvalue()
Expand All @@ -290,7 +406,7 @@ def as_tpu_kernel(
asm,
out_type,
needs_hlo_passes=_MOSAIC_ALLOW_HLO.value,
needs_layout_passes=not device_type,
needs_layout_passes=needs_layout_passes,
device_type=device_type,
has_communication=has_communication,
has_custom_barrier=has_custom_barrier,
Expand Down

0 comments on commit be2eb49

Please sign in to comment.