From 57e34e1a2ce4610a93d87931b78b9c6898718549 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Mon, 26 Feb 2024 13:30:11 -0800 Subject: [PATCH] [Mosaic][NFC] Use `TypedValue` instead of `Value` for applicable arguments/return values in `disassemble` and `relayout` Ideally we would prefer `TypedValue` everywhere possible for static type checking. However, I tried changing the type for arrays of vregs from `xla::Array<Value>` to `xla::Array<TypedValue<VectorType>>` and ran into issues because MLIR support for arrays/ranges of `TypedValue`s seems lacking. For example, I can't find a good way to get a `ValueRange` (which many op constructors take) from an array of `TypedValue`s without creating an intermediate vector of `Value`s. Perhaps an unsafe cast would work if we make the (probably not guaranteed) assumption that `sizeof(TypedValue<VectorType>)` equals `sizeof(Value)`. Also note that MLIR itself uses untyped `Value`s for ranges of op results and operands even when the op definition declares them to be of a specific type. PiperOrigin-RevId: 610509743 --- .../dialect/tpu/integrations/c/tpu_dialect.cc | 11 +- jaxlib/mosaic/dialect/tpu/tpu.td | 8 +- .../tpu/transforms/apply_vector_layout.cc | 100 +++++++++--------- .../tpu/transforms/apply_vector_layout.h | 11 +- 4 files changed, 71 insertions(+), 59 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index de3ed6c57f88..4ae2d738d93b 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -349,8 +349,10 @@ MlirTpuValueArray mlirTpuDisassemble(MlirTpuInsertionPoint insertion_point, MlirTpuVectorLayout layout, MlirValue val, MlirTpuI64TargetTuple target_shape) { mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point); + // This cast will fail and assert if the caller passed a non-vector + auto vector_val = mlir::cast>(unwrap(val)); mlir::FailureOr> failure_or_vals = - 
mlir::tpu::disassemble(builder, *unwrap(layout), unwrap(val), + mlir::tpu::disassemble(builder, *unwrap(layout), vector_val, unwrap(target_shape)); if (failed(failure_or_vals)) { return {{nullptr, 0}, nullptr}; @@ -371,8 +373,11 @@ MlirValue mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val, MlirTpuVectorLayout src, MlirTpuVectorLayout dst, MlirTpuI64TargetTuple target_shape) { mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point); - mlir::FailureOr failure_or_new_val = mlir::tpu::relayout( - builder, unwrap(val), *unwrap(src), *unwrap(dst), unwrap(target_shape)); + // This cast will fail and assert if the caller passed a non-vector + auto vector_val = mlir::cast>(unwrap(val)); + mlir::FailureOr> failure_or_new_val = + mlir::tpu::relayout(builder, vector_val, *unwrap(src), *unwrap(dst), + unwrap(target_shape)); if (failed(failure_or_new_val)) { return {nullptr}; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 0371f5c56c9e..a50b3e0b3b8a 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -325,16 +325,16 @@ def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> { } def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> { - let arguments = (ins Variadic:$input); - let results = (outs AnyType:$output); + let arguments = (ins Variadic:$input); + let results = (outs AnyVector:$output); let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; } def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> { - let arguments = (ins AnyType:$input); - let results = (outs Variadic:$output); + let arguments = (ins AnyVector:$input); + let results = (outs Variadic:$output); let hasCanonicalizeMethod = 1; let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 
7ba68c2a6d21..221f64483324 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -590,9 +590,11 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, SmallVector> in_vreg_arrays; in_vreg_arrays.reserve(num_operands); for (unsigned i = 0; i < num_operands; ++i) { - FAILUREOR_ASSIGN_OR_RETURN(xla::Array tile_array, - disassemble(builder, *layouts_in[i], - op.getOperand(i), ctx.target_shape)); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array tile_array, + disassemble(builder, *layouts_in[i], + cast>(op.getOperand(i)), + ctx.target_shape)); in_vreg_arrays.emplace_back(std::move(tile_array)); } @@ -653,15 +655,16 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, const VectorLayout &layout_in, const VectorLayout &layout_out) { ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); - auto result_ty = cast(op.getResult().getType()); - auto source_ty = cast(op.getIn().getType()); + const auto result_ty = cast(op.getResult().getType()); + auto source = cast>(op.getIn()); + const auto source_ty = source.getType(); if (layout_out.bitwidth() != 32) { return op.emitOpError( "Not implemented: Only extensions to 32-bit supported"); } FAILUREOR_ASSIGN_OR_RETURN( const xla::Array input_vregs, - disassemble(builder, layout_in, op.getIn(), ctx.target_shape)); + disassemble(builder, layout_in, source, ctx.target_shape)); xla::Array output_vregs( layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape)); FAILUREOR_ASSIGN_OR_RETURN( @@ -762,7 +765,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, auto result_ty = cast(op.getResult().getType()); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array input_vregs, - disassemble(builder, layout_in, op.getIn(), ctx.target_shape)); + disassemble(builder, layout_in, cast>(op.getIn()), + ctx.target_shape)); xla::Array output_vregs( layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape)); if 
(layout_in.bitwidth() != 32) { @@ -905,13 +909,13 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op, } continue; } - if (auto vty = dyn_cast(operand.getType())) { + if (auto vector_operand = dyn_cast>(operand)) { if (!layout.has_value()) { return op.emitOpError("Expected layout for vector operand"); } FAILUREOR_ASSIGN_OR_RETURN( const xla::Array tiles, - disassemble(builder, *layout, operand, ctx.target_shape)); + disassemble(builder, *layout, vector_operand, ctx.target_shape)); unrolled_args.append(tiles.begin(), tiles.end()); } else { if (layout.has_value()) { @@ -1098,12 +1102,12 @@ LogicalResult scf_yield_rule(RewriteContext &ctx, Operation &op, SmallVector unrolled; for (auto [operand, layout] : llvm::zip_equal(yield_op.getOperands(), layouts_in)) { - if (auto vty = dyn_cast(operand.getType())) { + if (auto vector_operand = dyn_cast>(operand)) { // When the operand has vector type, disassemble the operand. TPU_ASSERT_OP(layout.has_value()); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array tiles, - disassemble(builder, *layout, operand, ctx.target_shape)); + disassemble(builder, *layout, vector_operand, ctx.target_shape)); unrolled.append(tiles.begin(), tiles.end()); } else { TPU_ASSERT_OP(!layout.has_value()); @@ -1745,7 +1749,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, for (Value operand : concatenate_op.getOperands()) { FAILUREOR_ASSIGN_OR_RETURN( xla::Array t, - disassemble(builder, layout, operand, ctx.target_shape)); + disassemble(builder, layout, cast>(operand), + ctx.target_shape)); tiles.emplace_back(std::move(t)); } const xla::Array res_tiles = concatenate(tiles, dimension); @@ -2227,7 +2232,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, const VectorType dst_ty = broadcast_op.getResult().getType(); const SmallVector dst_tiles_shape = layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape); - if (auto src_ty = dyn_cast(broadcast_op.getSourceType())) { + if (auto src = 
dyn_cast>(broadcast_op.getSource())) { + VectorType src_ty = src.getType(); TPU_ASSERT_OP(maybe_layout_in.has_value()); const VectorLayout &layout_in = *maybe_layout_in; if (layout_in.implicit_dim() != layout_out.implicit_dim()) { @@ -2301,8 +2307,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_tiles, - disassemble(builder, layout_in, broadcast_op.getSource(), - ctx.target_shape)); + disassemble(builder, layout_in, src, ctx.target_shape)); xla::Array dst_tiles(dst_tiles_shape); if (no_op) { SmallVector reshape_dims(expand_rank, 1); @@ -2666,10 +2671,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, return multi_reduction_op.emitOpError( "Not implemented: Can only reduce into vectors"); } - if (!layouts_out.front().has_value()) { - // Shouldn't be empty since result is a vector - return op.emitOpError("Expected non-null output layout"); - } + // Op definition enforces that accumulator type must match result type + auto acc = cast>(multi_reduction_op.getAcc()); + TPU_ASSERT_OP(layouts_out.front().has_value()); const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims(); SmallVector dims; @@ -2686,11 +2690,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, } FAILUREOR_ASSIGN_OR_RETURN( const xla::Array acc_vregs, - disassemble(builder, acc_layout, multi_reduction_op.getAcc(), - ctx.target_shape)); - const Value acc_vreg = *acc_vregs.begin(); - auto acc_def = - dyn_cast_if_present(acc_vreg.getDefiningOp()); + disassemble(builder, acc_layout, acc, ctx.target_shape)); + auto acc_def = dyn_cast_if_present( + acc_vregs.begin()->getDefiningOp()); if (acc_def == nullptr) { return multi_reduction_op.emitOpError( "Not implemented: Only constant accumulator supported"); @@ -2838,7 +2840,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, } xla::Array reduced_vregs = src_vregs.Slice(src_slice_start, 
src_slice_end); - std::optional acc; + std::optional acc_vreg; auto reduction_status = reduced_vregs.EachStatus( [&](const absl::Span red_idx, Value *const src_vreg) { @@ -2860,17 +2862,17 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, return absl::UnknownError(""); } Value vreg = failure_or_vreg.value(); - if (!acc.has_value()) { - acc = vreg; + if (!acc_vreg.has_value()) { + acc_vreg = vreg; } else { switch (tpu_kind) { case tpu::ReductionKind::SUM: - acc = builder.create(vreg.getLoc(), *acc, - vreg); + acc_vreg = builder.create(vreg.getLoc(), + *acc_vreg, vreg); break; case tpu::ReductionKind::MAX: - acc = builder.create(vreg.getLoc(), *acc, - vreg); + acc_vreg = builder.create( + vreg.getLoc(), *acc_vreg, vreg); break; } } @@ -2879,16 +2881,16 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, if (!reduction_status.ok()) { return reduction_status; } - TPU_ASSERT_OP(acc.has_value()); + TPU_ASSERT_OP(acc_vreg.has_value()); if (reduces[1]) { - acc = builder.create(multi_reduction_op->getLoc(), - *acc, 1, tpu_kind); + acc_vreg = builder.create( + multi_reduction_op->getLoc(), *acc_vreg, 1, tpu_kind); } if (reduces[0]) { - acc = builder.create(multi_reduction_op->getLoc(), - *acc, 0, tpu_kind); + acc_vreg = builder.create( + multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind); } - *dst_vreg = *acc; + *dst_vreg = *acc_vreg; return absl::OkStatus(); }); if (!all_results_ok.ok()) { @@ -3478,9 +3480,10 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty, // Returns: // An ndarray of MLIR values representing the tiling of val given by layout. 
FailureOr> disassemble( - OpBuilder &builder, const VectorLayout &layout, const Value val, + OpBuilder &builder, const VectorLayout &layout, + const TypedValue val, const std::array target_shape) { - const auto vty = cast(val.getType()); + const auto vty = val.getType(); const auto op_result = dyn_cast(val); if (op_result == nullptr) { return failure(); @@ -3869,15 +3872,15 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx, } // TODO(apaszke): Test this function properly -FailureOr relayout(OpBuilder &builder, Value v, VectorLayout src, - const VectorLayout &dst, - const std::array target_shape) { +FailureOr> relayout( + OpBuilder &builder, TypedValue v, VectorLayout src, + const VectorLayout &dst, const std::array target_shape) { const int8_t bitwidth = src.bitwidth(); if (bitwidth != dst.bitwidth()) { return emitError(v.getLoc(), "Can't change bitwidth during a relayout"); } const int packing = src.packing(); - VectorType vty = cast(v.getType()); + VectorType vty = v.getType(); FAILUREOR_ASSIGN_OR_RETURN(xla::Array src_tiles, disassemble(builder, src, v, target_shape)); SmallVector dst_tiles_shape = @@ -4202,16 +4205,17 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { for (auto [idx, tup] : llvm::enumerate(llvm::zip(op.getOperands(), layouts_in))) { auto [operand, li] = tup; - auto vty = dyn_cast(operand.getType()); - TPU_ASSERT_EQ_OP(vty != nullptr, li.has_value()); - if (vty == nullptr) { + auto vector_operand = dyn_cast>(operand); + TPU_ASSERT_EQ_OP(vector_operand != nullptr, li.has_value()); + if (vector_operand == nullptr) { continue; } + auto vty = vector_operand.getType(); // The operand should always be an Operation (and not a BlockArgument) // since we expect the FuncOp to have only memrefs and semaphores as // arguments. 
- auto op_result = dyn_cast(operand); + auto op_result = dyn_cast(vector_operand); if (op_result == nullptr) { return op.emitError("Expected operand to be an operation result"); } @@ -4227,7 +4231,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { } OpBuilder builder(&op); FAILUREOR_ASSIGN_OR_RETURN(Value new_v, - relayout(builder, operand, /*src=*/*lo, + relayout(builder, vector_operand, /*src=*/*lo, /*dst=*/*li, ctx.target_shape)); op.setOperand(idx, new_v); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h index bb8f547b40dd..7dcff2c1c916 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h @@ -29,7 +29,8 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty, const xla::Array &vals, std::array target_shape); FailureOr> disassemble(OpBuilder &builder, - const VectorLayout &layout, Value val, + const VectorLayout &layout, + TypedValue val, std::array target_shape); // Rewrites the operation according to its layout annotations. @@ -55,9 +56,11 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op); // // Returns: // A new MLIR vector value, laid out as requested by dst. -FailureOr relayout(OpBuilder &builder, Value v, VectorLayout src, - const VectorLayout &dst, - std::array target_shape); +FailureOr> relayout(OpBuilder &builder, + TypedValue v, + VectorLayout src, + const VectorLayout &dst, + std::array target_shape); } // namespace mlir::tpu