From 7880b2c8586eade00a4aa5ac11007317a61e376c Mon Sep 17 00:00:00 2001
From: Max191 <44243577+Max191@users.noreply.github.com>
Date: Wed, 7 Feb 2024 14:11:11 -0500
Subject: [PATCH] [mlir] Add direct vectorization lowering for `tensor.pack`
 ops (#78660)

This PR adds a direct vectorization lowering of `tensor.pack` into
`mask(vector.transfer_read)`->`vector.shape_cast`->`vector.transpose`->
`vector.transfer_write`.
---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |   8 +
 .../TransformOps/LinalgTransformOps.cpp       |   2 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  36 +--
 .../Linalg/Transforms/Vectorization.cpp       | 236 +++++++++++++++---
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       |  29 +++
 mlir/test/Dialect/Linalg/vectorization.mlir   | 116 ++++++++-
 6 files changed, 357 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 04b4de4a33a5..fe9b16cb44b3 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,6 +32,14 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// Given a tensor::PackOp, compute the permutation vector to shuffle the
+/// packed shape into the shape before any outer or inner permutations have
+/// been applied.
+/// i.e. for a pack from an ABCD layout to an ABCDba:
+/// The packed shape would be ABCDba.
+/// The pre-permutation shape would be AaBbCD.
+SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
+
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6431bbd25396..585fd14b40d7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3152,7 +3152,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 02bc3e672bf7..596b7c50c1e4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -233,31 +233,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   rewriter.setInsertionPoint(packOp);
 
   // 2. Compute the permutation vector to shuffle packed shape into the shape
-  // before any outer or inner permutations have been applied. The permutation
-  // can be obtained from two permutations:
-  //   a) Compute the permutation vector to move the last `numPackedDims` into
-  //      the `innerPosDims` of a shape of rank `packedRank`.
-  //   b) Compute the permutation vector to move outer dims if the pack op
-  //      has outer_dims_perm.
-  // Apply (b) permutation on (a) permutation to get the final permutation.
-  int64_t numPackedDims = packOp.getInnerDimsPos().size();
-  int64_t packedRank = packedTensorType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  // before any outer or inner permutations have been applied.
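+  // E.g. for a pack from an ABCD layout to ABCDba, this is the permutation
+  // that shuffles the packed ABCDba shape back to the strip-mined AaBbCD
+  // shape (see the documentation of tensor::getPackInverseDestPermutation).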
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
-  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
-  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
-  if (!outerPerm.empty())
-    applyPermutationToVector(outerPos, outerPerm);
-  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
-      packedRank, packingMetadata.outerPositions, outerPos);
-
-  SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
-  applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
+  SmallVector<int64_t> packedToStripMinedShapePerm =
+      tensor::getPackInverseDestPermutation(packOp);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
@@ -304,10 +284,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                       DBGS() << "packedShape: ");
       DBGSNL();
-      llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
-      DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
-                                      DBGS() << "innerPositionsPerm: ");
-      DBGSNL();
       llvm::interleaveComma(packedToStripMinedShapePerm,
                             DBGS() << "packedToStripMinedShapePerm: ");
       DBGSNL(); llvm::interleaveComma(
@@ -332,9 +308,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   auto emptyOp =
       rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
   // Offsets.
-  SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
+                                  rewriter.getIndexAttr(0));
   // Strides.
-  SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+  SmallVector<OpFoldResult> ones(packOp.getDestRank(),
+                                 rewriter.getIndexAttr(1));
   SmallVector<OpFoldResult> sizes =
       tensor::getMixedSizes(rewriter, loc, packOp.getDest());
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0707625819d1..2bd6929fea61 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,10 +19,16 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -30,7 +36,9 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <optional>
 #include <type_traits>
@@ -1393,6 +1401,164 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Given a tensor::PackOp, return the `dest` shape before any packing
+/// permutations.
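+/// E.g. for the pack in the tests below (outer_dims_perm = [1, 2, 0],
+/// inner_dims_pos = [2, 1], inner_tiles = [16, 2]), a `destShape` of
+/// 4x1x32x16x2 yields the tiled shape 32x4x2x1x16.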
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
+                                              ArrayRef<int64_t> destShape) {
+  return applyPermutation(destShape,
+                          tensor::getPackInverseDestPermutation(packOp));
+}
+
+/// Create a TransferReadOp from `source` with static shape `readShape`. If the
+/// vector type for the read is not the same as the type of `source`, then a
+/// mask is created on the read.
+static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                    Value source, ArrayRef<int64_t> readShape,
+                                    Value padValue) {
+  assert(llvm::none_of(readShape,
+                       [](int64_t s) { return s == ShapedType::kDynamic; }));
+  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
+  assert(sourceShape.size() == readShape.size());
+  auto maskType = VectorType::get(readShape, builder.getI1Type());
+  auto vectorType = VectorType::get(readShape, padValue.getType());
+  int64_t readRank = readShape.size();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto transferReadOp = builder.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/source,
+      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(readRank, true));
+  if (llvm::equal(readShape, sourceShape)) {
+    return transferReadOp;
+  }
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(builder, loc, source);
+  Value mask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  return mlir::vector::maskOperation(builder, transferReadOp, mask)
+      ->getResult(0);
+}
+
+/// Given an input, the mixed destSizes, and the vector sizes for
+/// vectorization, create an empty destination tensor and create a
+/// TransferWriteOp from the input to the empty tensor. If the destination
+/// shape is not the same as the inputVectorSizes for the first
+/// rank(inputVectorSizes) dims, then create a mask for the write.
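+/// E.g. (matching the dynamic pack test below), with inputVectorSizes =
+/// [4, 1] and a destination of type tensor<?x?x16x2xf32>, the write is
+/// masked with a vector<4x1x16x2xi1> mask built from the reified
+/// destination sizes.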
+static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+                                           Value input,
+                                           SmallVector<OpFoldResult> destSizes,
+                                           ArrayRef<int64_t> inputVectorSizes) {
+  auto inputType = cast<VectorType>(input.getType());
+  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
+                                               inputType.getElementType());
+  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Operation *write = builder.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/input,
+      /*source=*/dest,
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  assert(llvm::none_of(
+             destShape.drop_front(inputVectorSizes.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVectorSizes may be dynamic");
+  bool needMaskForWrite = !llvm::equal(
+      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
+  if (needMaskForWrite) {
+    SmallVector<int64_t> writeMaskShape;
+    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
+                          destShape.end());
+    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
+    Value maskForWrite =
+        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+    write = mlir::vector::maskOperation(builder, write, maskForWrite);
+  }
+  return write;
+}
+
+/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
+/// padding value into:
+/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+/// As in the following example:
+///
+/// %pack = tensor.pack %src padding_value(%pad : f32)
+///     inner_dims_pos = [2, 1] inner_tiles = [16, 2]
+///     into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
+///
+/// This pack would be vectorized to:
+///
+/// %load = vector.mask %mask {
+///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
+///         {in_bounds = [true, true, true]} :
+///         tensor<32x7x16xf32>, vector<32x8x16xf32>
+/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
+///     to vector<32x4x2x1x16xf32>
+/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+/// %write = vector.transfer_write %transpose,
+///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
+///     {in_bounds = [true, true, true, true, true]}
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  auto padValue = packOp.getPaddingValue();
+  if (!padValue) {
+    padValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+  }
+  ReifiedRankedShapedTypeDims reifiedReturnShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedReturnShapes);
+  (void)status; // prevent unused variable warning on non-assert builds.
+  assert(succeeded(status) && "failed to reify result shapes");
+
+  // Create masked TransferReadOp.
+  SmallVector<int64_t> inputShape(inputVectorSizes);
+  auto innerTiles = packOp.getStaticInnerTiles();
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(inputShape,
+                             invertPermutationVector(outerDimsPerm));
+  for (auto [idx, size] : enumerate(innerTiles))
+    inputShape[innerDimsPos[idx]] *= size;
+  auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
+                                           inputShape, padValue);
+
+  // Create ShapeCastOp.
+  SmallVector<int64_t> destShape(inputVectorSizes);
+  destShape.append(innerTiles.begin(), innerTiles.end());
+  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
+                                       packOp.getDestType().getElementType());
+  auto shapeCastOp =
+      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
+
+  // Create TransposeOp.
+  auto destPermutation =
+      invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
+  auto transposeOp = rewriter.create<vector::TransposeOp>(
+      loc, shapeCastOp.getResult(), destPermutation);
+
+  // Create TransferWriteOp.
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
+                               reifiedReturnShapes[0], inputVectorSizes);
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1402,9 +1568,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                        SmallVectorImpl<Value> &newResults) {
   auto padValue = padOp.getConstantPaddingValue();
   Location loc = padOp.getLoc();
-  int64_t rank = inputVectorSizes.size();
-  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
 
   // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
   OpBuilder::InsertionGuard g(rewriter);
@@ -1416,36 +1579,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
-  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
-                                                  padValue.getType());
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(rewriter, loc, padOp.getSource());
-  Value mask =
-      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/padOp.getSource(),
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  auto maskedOp = cast<vector::MaskOp>(
-      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/maskedOp->getResult(0),
-      /*source=*/emptyOp,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  bool needMaskForWrite = llvm::any_of(
-      llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
-      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
-  if (needMaskForWrite) {
-    Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
-        loc, maskType, reifiedReturnShapes[0]);
-    write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
-  }
+  auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
+                                           inputVectorSizes, padValue);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1585,6 +1722,32 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+/// TODO: Use a matcher to check for a constant padding value.
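+/// Preconditions checked below: the padding value, if present, must be a
+/// constant; the vector sizes must be valid for the leading source-rank
+/// dims of the dest shape; and all inner_tiles must be constant.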
+static LogicalResult
+vectorizePackOpPrecondition(tensor::PackOp packOp,
+                            ArrayRef<int64_t> inputVectorSizes) {
+  auto padValue = packOp.getPaddingValue();
+  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+    LDBG("pad value is not constant: " << packOp << "\n");
+    return failure();
+  }
+
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+  if (failed(isValidMaskedInputVector(
+          resultTensorShape.take_front(packOp.getSourceRank()),
+          inputVectorSizes)))
+    return failure();
+
+  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
+        return !getConstantIntValue(v).has_value();
+      })) {
+    LDBG("inner_tiles must be constant: " << packOp << "\n");
+    return failure();
+  }
+
+  return success();
+}
+
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {
@@ -1644,6 +1807,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
+      .Case<tensor::PackOp>([&](auto packOp) {
+        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -1732,6 +1898,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                           results);
           })
+          .Case<tensor::PackOp>([&](auto packOp) {
+            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+                                           results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 24cbceb3d117..f20008a1ed2b 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -73,6 +73,35 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
   return transposedTensorType;
 }
 
+SmallVector<int64_t>
+mlir::tensor::getPackInverseDestPermutation(PackOp packOp) {
+  // The permutation can be obtained from two permutations:
+  //   a) Compute the permutation vector to move the last `numPackedDims` into
+  //      the `innerPosDims` of a shape of rank `packedRank`.
+  //   b) Compute the permutation vector to move outer dims if the pack op
+  //      has outer_dims_perm.
+  // Apply (b) permutation on (a) permutation to get the final permutation.
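+  // E.g. for an ABCD -> ABCDba pack (see the documentation of this function
+  // in Tensor/Utils/Utils.h), (a) computes the permutation taking ABCDba to
+  // AaBbCD, and (b) accounts for any outer_dims_perm on the outer dims.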
+  int64_t numPackedDims = packOp.getInnerDimsPos().size();
+  int64_t packedRank = packOp.getDestType().getRank();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata = computePackingMetadata(
+      packOp.getDestType().getRank(), packOp.getInnerDimsPos());
+  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+  if (!outerPerm.empty())
+    applyPermutationToVector(outerPos, outerPerm);
+  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
+      packedRank, packingMetadata.outerPositions, outerPos);
+
+  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
+  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
+  return packInverseDestPermutation;
+}
+
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d5fb0cbb9c72..5d1bef478ee9 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -426,16 +426,17 @@ func.func @test_masked_vectorize_pad(
 {
   //  CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
   //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  //  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
+  //  CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
   //      CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
-  //  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
   //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
-  // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
+  // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_0]], %[[c0_0]]], %[[c42]]
   // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
-  //      CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+  //  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
+  //      CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_1]], %[[c0_1]]]
   // CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
   %cst = arith.constant 42.43 : f32
   %c0 = arith.constant 0 : index
@@ -467,18 +468,19 @@ func.func @test_masked_vectorize_dynamic_pad(
   //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
   //  CHECK-DAG: %[[res_d0:.+]] = affine.apply #[[MAP]]()
   //  CHECK-DAG: %[[res_d1:.+]] = affine.apply #[[MAP]]()
-  //  CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
+  //      CHECK: %[[c0_2:.*]] = arith.constant 0 : index
   //      CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
-  //  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
   //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
   // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
+  //  CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
+  //  CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
   //      CHECK: %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
   //      CHECK: %[[masked_write:.*]] = vector.mask %[[mask_2]] {
-  // CHECK-SAME:   vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+  // CHECK-SAME:   vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
   // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
   //      CHECK: return %[[masked_write]] : tensor<?x?xf32>
   %cst = arith.constant 42.43 : f32
@@ -501,6 +503,106 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
+  %pack = tensor.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
+  return %pack : tensor<4x1x32x16x2xf32>
+}
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//      CHECK: %[[read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32>
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x32x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32>
+//      CHECK: return %[[write]] : tensor<4x1x32x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1, 32] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+//  CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
+//  CHECK-DAG: %[[c15:.*]] = arith.constant 15 : index
+//      CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c7]], %[[c15]] : vector<32x8x16xi1>
+//      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true, true]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+//      CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
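+// Pack with dynamic outer dims: the destination sizes are reified and the
+// transfer_write below is masked to them.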
+func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+  %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+  return %pack : tensor<?x?x16x2xf32>
+}
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+//  CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?x16x2xf32>
+//  CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?x16x2xf32>
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c1_0:.*]] = arith.constant 1 : index
+//  CHECK-DAG: %[[d0_0:.*]] = tensor.dim {{.*}} %[[c0_0]] : tensor<?x?xf32>
+//  CHECK-DAG: %[[d1_0:.*]] = tensor.dim {{.*}} %[[c1_0]] : tensor<?x?xf32>
+//      CHECK: %[[mask:.*]] = vector.create_mask %[[d0_0]], %[[d1_0]] : vector<8x16xi1>
+//      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+//  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+//  CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[d0]], %[[d1]]) : tensor<?x?x16x2xf32>
+//      CHECK: %[[mask_0:.*]] = vector.create_mask %[[d0]], %[[d1]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
+//      CHECK: %[[masked_write:.*]] = vector.mask %[[mask_0]] {
+// CHECK-SAME:   vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
+//      CHECK: return %[[masked_write]] : tensor<?x?x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
                 outs(%C: memref<?x?xf32>)