[mlir] Add direct vectorization lowering for tensor.pack ops (#78660)
This PR adds a direct vectorization lowering of `tensor.pack` into `mask(vector.transfer_read)` -> `vector.shape_cast` -> `vector.transpose` -> `vector.transfer_write`.
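For orientation, a hedged sketch of driving the new lowering through the transform dialect; the payload names and `vector_sizes` below are assumptions matching the doc-comment example in this diff, and the exact `transform.structured.vectorize` syntax may vary across MLIR revisions:

```mlir
// Match every tensor.pack in the payload and vectorize it with explicit
// vector sizes for the leading (outer) dest dims, here assumed [32, 4, 1].
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %root
        : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %pack vector_sizes [32, 4, 1]
        : !transform.any_op
    transform.yield
  }
}
```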
Max191 authored Feb 7, 2024
1 parent 347ab99 commit 7880b2c
Showing 6 changed files with 357 additions and 70 deletions.
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,6 +32,14 @@ FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);

/// Given a tensor::PackOp, compute the permutation vector to shuffle the
/// packed shape into the shape before any outer or inner permutations have
/// been applied.
/// e.g. for a pack from an ABCD layout to an ABCDba layout:
/// The packed shape would be ABCDba.
/// The pre-permutation shape would be AaBbCD.
SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);

/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
/// the same shape.
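A hypothetical worked example for the `getPackInverseDestPermutation` doc comment above (illustration only, not part of this diff): packing dims `A` and `B` of an `ABCD` tensor yields a packed `ABCDba` shape, and the returned permutation shuffles it back into the interleaved `AaBbCD` shape.

```mlir
// inner_dims_pos = [1, 0]: tile b (size 8) tiles dim B, tile a (size 4)
// tiles dim A; the tiles are appended to the dest in that order.
%pack = tensor.pack %src inner_dims_pos = [1, 0] inner_tiles = [8, 4]
    into %dst : tensor<16x32x64x128xf32> -> tensor<4x4x64x128x8x4xf32>
// Packed (dest) shape:    A  B  C  D  b  a   =  4x4x64x128x8x4
// Pre-permutation shape:  A  a  B  b  C  D   =  4x4x4x8x64x128
// getPackInverseDestPermutation would return [0, 5, 1, 4, 2, 3].
```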
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3152,7 +3152,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(

// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
36 changes: 7 additions & 29 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -233,31 +233,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
rewriter.setInsertionPoint(packOp);

// 2. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied. The permutation
// can be obtained from two permutations:
// a) Compute the permutation vector to move the last `numPackedDims` into
// the `innerPosDims` of a shape of rank `packedRank`.
// b) Compute the permutation vector to move outer dims if the pack op
// has outer_dims_perm.
// Apply (b) permutation on (a) permutation to get the final permutation.
int64_t numPackedDims = packOp.getInnerDimsPos().size();
int64_t packedRank = packedTensorType.getRank();
auto lastDims = llvm::to_vector(
llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
// before any outer or inner permutations have been applied.
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
packedRank, lastDims, packingMetadata.insertPositions);

SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
if (!outerPerm.empty())
applyPermutationToVector(outerPos, outerPerm);
SmallVector<int64_t> outerPositionPerm = computePermutationVector(
packedRank, packingMetadata.outerPositions, outerPos);

SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
SmallVector<int64_t> packedToStripMinedShapePerm =
tensor::getPackInverseDestPermutation(packOp);

// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
@@ -304,10 +284,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
DBGS() << "packedShape: ");
DBGSNL();
llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
DBGS() << "innerPositionsPerm: ");
DBGSNL();
llvm::interleaveComma(packedToStripMinedShapePerm,
DBGS() << "packedToStripMinedShapePerm: ");
DBGSNL(); llvm::interleaveComma(
@@ -332,9 +308,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
auto emptyOp =
rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
// Offsets.
SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
rewriter.getIndexAttr(0));
// Strides.
SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> ones(packOp.getDestRank(),
rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, packOp.getDest());

236 changes: 203 additions & 33 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,18 +19,26 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <type_traits>
@@ -1393,6 +1401,164 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}

/// Given a tensor::PackOp, return the `dest` shape before any packing
/// permutations.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
ArrayRef<int64_t> destShape) {
return applyPermutation(destShape,
tensor::getPackInverseDestPermutation(packOp));
}

/// Create a TransferReadOp from `source` with static shape `readShape`. If the
/// vector type for the read is not the same as the type of `source`, then a
/// mask is created on the read.
static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source, ArrayRef<int64_t> readShape,
Value padValue) {
assert(llvm::none_of(readShape,
[](int64_t s) { return s == ShapedType::kDynamic; }));
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
assert(sourceShape.size() == readShape.size());
auto maskType = VectorType::get(readShape, builder.getI1Type());
auto vectorType = VectorType::get(readShape, padValue.getType());
int64_t readRank = readShape.size();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
auto transferReadOp = builder.create<vector::TransferReadOp>(
loc,
/*vectorType=*/vectorType,
/*source=*/source,
/*indices=*/SmallVector<Value>(readRank, zero),
/*padding=*/padValue,
/*inBounds=*/SmallVector<bool>(readRank, true));
if (llvm::equal(readShape, sourceShape)) {
return transferReadOp;
}
SmallVector<OpFoldResult> mixedSourceDims =
tensor::getMixedSizes(builder, loc, source);
Value mask =
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
->getResult(0);
}

/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
/// create an empty destination tensor and create a TransferWriteOp from the
/// input to the empty tensor. If the destination shape is not the same as the
/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
/// mask for the write.
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
Value input,
SmallVector<OpFoldResult> destSizes,
ArrayRef<int64_t> inputVectorSizes) {
auto inputType = cast<VectorType>(input.getType());
Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
inputType.getElementType());
int64_t rank = cast<ShapedType>(dest.getType()).getRank();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Operation *write = builder.create<vector::TransferWriteOp>(
loc,
/*vector=*/input,
/*source=*/dest,
/*indices=*/SmallVector<Value>(rank, zero),
/*inBounds=*/SmallVector<bool>(rank, true));
auto destShape = cast<ShapedType>(dest.getType()).getShape();
assert(llvm::none_of(
destShape.drop_front(inputVectorSizes.size()),
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
"Only dims aligned with inputVectorSizes may be dynamic");
bool needMaskForWrite = !llvm::equal(
inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
if (needMaskForWrite) {
SmallVector<int64_t> writeMaskShape;
writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
destShape.end());
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
Value maskForWrite =
builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
write = mlir::vector::maskOperation(builder, write, maskForWrite);
}
return write;
}
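For reference, a hedged sketch of the IR `createWriteOrMaskedWrite` produces when the leading dest size does not match the vector sizes (all shapes and value names here are hypothetical):

```mlir
// inputVectorSizes = [4, 1]; reified dest sizes = [%d0, 1, 16, 2], so the
// write mask covers the full rank: [4, 1] ++ [16, 2].
%empty = tensor.empty(%d0) : tensor<?x1x16x2xf32>
%mask = vector.create_mask %d0, %c1, %c16, %c2 : vector<4x1x16x2xi1>
%write = vector.mask %mask {
  vector.transfer_write %input, %empty[%c0, %c0, %c0, %c0]
      {in_bounds = [true, true, true, true]}
      : vector<4x1x16x2xf32>, tensor<?x1x16x2xf32>
} : vector<4x1x16x2xi1> -> tensor<?x1x16x2xf32>
```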

/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
/// padding value into:
/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
/// As in the following example:
///
/// %pack = tensor.pack %src padding_value(%pad : f32)
///     inner_dims_pos = [2, 1] inner_tiles = [16, 2]
///     into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
///
/// This pack would be vectorized to:
///
/// %load = vector.mask %mask {
/// vector.transfer_read %src[%c0, %c0, %c0], %cst
/// {in_bounds = [true, true, true]} :
/// tensor<32x7x16xf32>, vector<32x8x16xf32>
/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
/// to vector<32x4x2x1x16xf32>
/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
/// %write = vector.transfer_write %transpose,
/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
/// {in_bounds = [true, true, true, true, true]}
/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);

Location loc = packOp.getLoc();
auto padValue = packOp.getPaddingValue();
if (!padValue) {
padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
}
ReifiedRankedShapedTypeDims reifiedReturnShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds.
assert(succeeded(status) && "failed to reify result shapes");

// Create masked TransferReadOp.
SmallVector<int64_t> inputShape(inputVectorSizes);
auto innerTiles = packOp.getStaticInnerTiles();
auto innerDimsPos = packOp.getInnerDimsPos();
auto outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty())
applyPermutationToVector(inputShape,
invertPermutationVector(outerDimsPerm));
for (auto [idx, size] : enumerate(innerTiles))
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
inputShape, padValue);

// Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
destShape.append(innerTiles.begin(), innerTiles.end());
auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
packOp.getDestType().getElementType());
auto shapeCastOp =
rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

// Create TransposeOp.
auto destPermutation =
invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
auto transposeOp = rewriter.create<vector::TransposeOp>(
loc, shapeCastOp.getResult(), destPermutation);

// Create TransferWriteOp.
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
reifiedReturnShapes[0], inputVectorSizes);
newResults.push_back(write->getResult(0));
return success();
}

/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1402,9 +1568,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
SmallVectorImpl<Value> &newResults) {
auto padValue = padOp.getConstantPaddingValue();
Location loc = padOp.getLoc();
int64_t rank = inputVectorSizes.size();
auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());

// transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
OpBuilder::InsertionGuard g(rewriter);
@@ -1416,36 +1579,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
padValue.getType());
SmallVector<OpFoldResult> mixedSourceDims =
tensor::getMixedSizes(rewriter, loc, padOp.getSource());
Value mask =
rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc,
/*vectorType=*/vectorType,
/*source=*/padOp.getSource(),
/*indices=*/SmallVector<Value>(rank, zero),
/*padding=*/padValue,
/*inBounds=*/SmallVector<bool>(rank, true));
auto maskedOp = cast<vector::MaskOp>(
mlir::vector::maskOperation(rewriter, transferReadOp, mask));
Operation *write = rewriter.create<vector::TransferWriteOp>(
loc,
/*vector=*/maskedOp->getResult(0),
/*source=*/emptyOp,
/*indices=*/SmallVector<Value>(rank, zero),
/*inBounds=*/SmallVector<bool>(rank, true));
bool needMaskForWrite = llvm::any_of(
llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
[](auto it) { return std::get<0>(it) != std::get<1>(it); });
if (needMaskForWrite) {
Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
loc, maskType, reifiedReturnShapes[0]);
write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
}
auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
inputVectorSizes, padValue);
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1585,6 +1722,32 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success();
}

/// TODO: Use a matcher to check for a constant padding value.
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
LDBG("pad value is not constant: " << packOp << "\n");
return failure();
}

ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
if (failed(isValidMaskedInputVector(
resultTensorShape.take_front(packOp.getSourceRank()),
inputVectorSizes)))
return failure();

if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
return !getConstantIntValue(v).has_value();
})) {
LDBG("inner_tiles must be constant: " << packOp << "\n");
return failure();
}

return success();
}
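As an illustration of what this precondition rejects (a hypothetical op, not taken from the tests): a pack whose `inner_tiles` entry is an SSA value rather than a constant cannot be vectorized, because the tiled vector shape would be unknown at rewrite time.

```mlir
// Rejected: %tile is not a compile-time constant, so inner_tiles is dynamic.
%pack = tensor.pack %src inner_dims_pos = [1] inner_tiles = [%tile]
    into %dst : tensor<32x128xf32> -> tensor<32x?x?xf32>
```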

static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
@@ -1644,6 +1807,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PadOp>([&](auto padOp) {
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
.Case<tensor::PackOp>([&](auto packOp) {
return vectorizePackOpPrecondition(packOp, inputVectorSizes);
})
.Default([](auto) { return failure(); });
}

@@ -1732,6 +1898,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
results);
})
.Case<tensor::PackOp>([&](auto packOp) {
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
results);
})
.Default([](auto) { return failure(); });

if (failed(vectorizeResult)) {