Skip to content

Commit

Permalink
[Mosaic] Parameterize the number of lanes and sublanes in TPU dialects.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684392184
  • Loading branch information
Google-ML-Automation committed Oct 10, 2024
1 parent 351187d commit 81a95f7
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 96 deletions.
22 changes: 17 additions & 5 deletions jax/_src/tpu_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def _tpu_custom_call_lowering(
def _lower_tpu_kernel(
module: ir.Module,
hardware_generation: int,
target_shape: tuple[int, int],
) -> ir.Module:
"""Runs MLIR passes lowering the given module to an MLIR module.
Expand All @@ -283,6 +284,7 @@ def _lower_tpu_kernel(
Args:
module: The MLIR module to lower.
hardware_generation: The TPU hardware generation to target.
target_shape: The target shape of (sublane_count, lane_count).
Returns:
An MLIR module implementing the kernel.
Expand Down Expand Up @@ -312,11 +314,16 @@ def _lower_tpu_kernel(
pipeline.run(module.operation)
dump_mlir(module, "post-hlo-conversion")

sl_cnt, l_cnt = target_shape
# Note: we don't pass the TpuTilingFlags here, since we don't know the
# tiling decisions made by the compiler / what flags are enabled at this
# point, so we assume everything can be tiled up to default tiling.
pipeline = [
f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
"func.func(tpu-infer-memref-layout{"
f" hardware-generation={hardware_generation}"
f" sublane-count={sl_cnt}"
f" lane-count={l_cnt}"
"})"
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
Expand Down Expand Up @@ -357,14 +364,16 @@ def _lower_tpu_kernel(
dump_mlir(module, "post-canonicalize-mosaic")

pipeline = [
"func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})",
(
"func.func(tpu-infer-vector-layout{"
f" sublane-count={sl_cnt} lane-count={l_cnt}"
"})"
),
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-infer-vector-layout")

sl_cnt = 8
l_cnt = 128
mxu_size = 128 if hardware_generation < 6 else 256
pipeline = [
"func.func(tpu-apply-vector-layout{"
Expand Down Expand Up @@ -414,7 +423,10 @@ def _lower_mosaic_module_to_asm(
"tpu_custom_call cannot be lowered on a machine without TPUs "
"when mosaic_use_python_pipeline=True.")
hardware_generation = int(device_kind[len("TPU v")])
module = _lower_tpu_kernel(module, hardware_generation)
# TODO(b/369418606): Infer the target shape from the hardware generation.
module = _lower_tpu_kernel(
module, hardware_generation, target_shape=(8, 128)
)
needs_hlo_passes = False
needs_layout_passes = False
prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects
Expand Down
16 changes: 4 additions & 12 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,8 @@ class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
let mnemonic = mnemonic_;
}

def TPU_Vreg : Type<
And<[IsVectorTypePred,
Or<[
And<[
CPred<"llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128}">,
CPred<"llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth() == 32">
]>,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{"
"8, 128, 32 / ::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth()}">,
]>
]>,
"native-sized vreg", "::mlir::VectorType">;
// TODO(b/369418606): Find a way to verify the vreg size. The previous
// definition hard-coded an {8, 128} (x packing) shape check, which no longer
// holds now that sublane/lane counts are parameterized, so for now any
// vector type is accepted as a vreg.
def TPU_Vreg : Type<IsVectorTypePred, "native-sized vreg", "::mlir::VectorType">;

class TPU_Type<string name, string mnemonic_, list<Trait> traits = []>
: TypeDef<TPU_Dialect, name, traits> {
Expand Down Expand Up @@ -738,6 +728,8 @@ def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncO
// If hardware_generation is not set, the default value of -1 will crash on
// runOnOperation.
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
Option<"tpu_tiling_flags", "tpu-tiling-flags", "::mlir::tpu::TpuTilingFlags", /*default=*/"::mlir::tpu::TpuTilingFlags{}", "">,
];
}
Expand Down
6 changes: 4 additions & 2 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ struct ApplyVectorLayoutContext {
std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);

std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
int hardware_generation = -1, const TpuTilingFlags &tpu_tiling_flags = {});
int hardware_generation = -1,
std::array<int64_t, 2> target_shape = {8, 128},
const TpuTilingFlags &tpu_tiling_flags = {});

std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
int hardware_generation = -1);

std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
int lane_count = 128, int sublane_count = 8);
std::array<int64_t, 2> target_shape = {8, 128});

std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{});
Expand Down
18 changes: 9 additions & 9 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ FailureOr<TypedValue<MemRefType>> getInternalScratch(
FAILUREOR_ASSIGN_OR_RETURN(
MemRefType scratch_ref_ty,
inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation,
/*tpu_tiling_flags=*/{}, sublane_tiling));
ctx.target_shape, /*tpu_tiling_flags=*/{}, sublane_tiling));
return builder.create<tpu::GetInternalScratchOp>(loc, scratch_ref_ty)
.getResult();
}
Expand Down Expand Up @@ -490,7 +490,7 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
MemRefType arg_type,
inferMemref(
MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
ctx.hardware_generation, /*tpu_tiling_flags=*/{}));
ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{}));
const BlockArgument argument = entry_block.insertArgument(
entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx));
const FunctionType func_ty = func.getFunctionType();
Expand Down Expand Up @@ -5821,8 +5821,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
if (try_replicate_rows && packing == 1 &&
*(vregs.dimensions().end() - 2) == 1 &&
src.offsets() == LayoutOffsets{0, 0} &&
src.tiling() == std::array<int64_t, 2>{1, 128} &&
dst_tiling == std::array<int64_t, 2>{8, 128}) {
src.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
dst_tiling == ctx.target_shape) {
xla::Array<Value> retiled(dst_tiles_shape);
retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
Expand All @@ -5839,9 +5839,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
return std::pair(dst, std::move(retiled));
}
// (8,128) -> (8 * packing,128) tiling change for packed type.
if (bitwidth < 32 && 32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{8, 128} &&
dst_tiling == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * dst.packing(),
ctx.target_shape[1]}) {
// Note: for int4, retiling with scratch is always faster.
if (bitwidth != 4 || !has_enough_scratch) {
xla::Array<Value> retiled(dst_tiles_shape);
Expand Down Expand Up @@ -5883,8 +5883,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// match corresponding elements without shifting. It's just that
// the tiles are not adjacent (no contiguous vreg slice).
if (bitwidth < 32 && 32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{1, 128 * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, 128}) {
src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
// To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
// 4 sublanes and 2 lanes (this is convenient to keep the example small
// yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
Expand Down
Loading

0 comments on commit 81a95f7

Please sign in to comment.