[mlir][tensor] Add canonicalization to fold consecutive tensor.pad ops (llvm#107302)

`tensor.pad(tensor.pad)` with the same constant padding value can be combined into a single `tensor.pad` whose low and high padding amounts are the sums of the two ops' respective low and high amounts.
qedawkins authored Sep 9, 2024
1 parent ea92045 commit 6cc3bf7
Showing 2 changed files with 161 additions and 1 deletion.
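As a minimal sketch of the fold in action (the function name, shapes, and padding amounts below are hypothetical, not taken from this patch), `mlir-opt --canonicalize` should collapse a chain of two pads that yield the same constant into a single pad with summed low/high amounts:

```mlir
// Before canonicalization: two tensor.pad ops yielding the same constant.
func.func @pad_chain(%src: tensor<4x4xf32>) -> tensor<8x9xf32> {
  %cst = arith.constant 0.0 : f32
  %a = tensor.pad %src low[1, 0] high[1, 2] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<4x4xf32> to tensor<6x6xf32>
  %b = tensor.pad %a low[0, 1] high[2, 2] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<6x6xf32> to tensor<8x9xf32>
  return %b : tensor<8x9xf32>
}
// Expected result: one pad with low[1+0, 0+1] and high[1+2, 2+2], i.e.
//   tensor.pad %src low[1, 1] high[3, 4] : tensor<4x4xf32> to tensor<8x9xf32>
```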
80 changes: 79 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3402,12 +3402,90 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
}
};

/// Folds a chain of `tensor.pad` ops with the same constant padding value.
///
/// Example:
///
/// ```mlir
///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
///       tensor.yield %val
///     } : tensor<1x2xf32> to tensor<1x5xf32>
///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
///       tensor.yield %val
///     } : tensor<1x5xf32> to tensor<4x7xf32>
/// ```
///
/// folds into:
///
/// ```mlir
///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
///       tensor.yield %val
///     } : tensor<1x2xf32> to tensor<4x7xf32>
/// ```
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the two tensor.pad ops' padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor.pad ops elementwise;
    // fully static sums fold to attributes, mixed sums become affine.apply
    // ops.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    // Create a single pad over the producer's source and move the consumer's
    // padding-value region into it (both regions yield the same value).
    auto newPadOp = rewriter.create<tensor::PadOp>(
        padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
        newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

} // namespace

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}

/// Return the padding value of the PadOp if it is constant. In this context,
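The pattern sums each low/high pair with `affine::makeComposedFoldedAffineApply`, so it also handles dynamic padding sizes: fully static sums stay attributes, while mixed static/dynamic sums materialize as `affine.apply` ops. A hedged sketch of the dynamic case (hypothetical values `%src`, `%n`, `%c`, mirroring the test below):

```mlir
// Producer pad with a dynamic low amount, consumer pad with a dynamic high
// amount; both regions yield the same %c : f32.
%0 = tensor.pad %src low[%n, 0] high[0, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %c : f32
} : tensor<?x8xf32> to tensor<?x9xf32>
%1 = tensor.pad %0 low[2, 0] high[%n, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %c : f32
} : tensor<?x9xf32> to tensor<?x10xf32>

// Expected fold: static sums (0 + 0, 1 + 1) remain attributes, %n + 0
// simplifies back to %n, and the mixed sum 2 + %n becomes an affine.apply:
//   %low0 = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%n]
//   tensor.pad %src low[%low0, 0] high[%n, 2]
//       : tensor<?x8xf32> to tensor<?x10xf32>
```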
82 changes: 82 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1964,6 +1964,88 @@ func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,

// -----

// CHECK-LABEL: func @merge_constant_padding
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
// CHECK: tensor.yield %[[PADVAL]]
// CHECK: return %[[PAD]]
func.func @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
  ^bb0(%b0: index, %b1 : index):
    tensor.yield %pad_value : f32
  } : tensor<2x3xf32> to tensor<4x4xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
  ^bb0(%b2: index, %b3 : index):
    tensor.yield %pad_value : f32
  } : tensor<4x4xf32> to tensor<7x8xf32>
  return %pad1 : tensor<7x8xf32>
}

// -----

// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1)>
// CHECK-LABEL: func @merge_constant_padding_dynamic
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index
// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
// CHECK: %[[HIGH:.+]] = affine.apply #[[$MAP]]()[%[[IDX]]]
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
// CHECK: tensor.yield %[[PADVAL]]
// CHECK: return %[[PAD]]
func.func @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
  %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
  ^bb0(%b0: index, %b1 : index):
    tensor.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
  ^bb0(%b2: index, %b3 : index):
    tensor.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  return %pad1 : tensor<?x?xf32>
}

// -----

// Verify that folding does not happen if it would drop a nofold attribute
// CHECK-LABEL: func @dont_merge_constant_padding_nofold
// CHECK: tensor.pad {{.*}} nofold
// CHECK: tensor.pad
func.func @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 nofold low[1, 1] high[1, 0] {
  ^bb0(%b0: index, %b1 : index):
    tensor.yield %pad_value : f32
  } : tensor<2x3xf32> to tensor<4x4xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
  ^bb0(%b2: index, %b3 : index):
    tensor.yield %pad_value : f32
  } : tensor<4x4xf32> to tensor<7x8xf32>
  return %pad1 : tensor<7x8xf32>
}

// -----

// Verify that folding does not happen if the padding values differ
// CHECK-LABEL: func @dont_merge_constant_padding_different_vals
// CHECK: tensor.pad
// CHECK: tensor.pad
func.func @dont_merge_constant_padding_different_vals(
    %arg0: tensor<2x3xf32>,
    %pad_value0: f32,
    %pad_value1: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
  ^bb0(%b0: index, %b1 : index):
    tensor.yield %pad_value0 : f32
  } : tensor<2x3xf32> to tensor<4x4xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
  ^bb0(%b2: index, %b3 : index):
    tensor.yield %pad_value1 : f32
  } : tensor<4x4xf32> to tensor<7x8xf32>
  return %pad1 : tensor<7x8xf32>
}

// -----

// CHECK-LABEL: func @fold_collapse_shape_from_elements
func.func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
// CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>
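For reference, tests in this file are run through the canonicalizer with FileCheck; a typical invocation looks like the following (hedged: the file's actual RUN line sits outside the displayed hunk):

```mlir
// Hypothetical RUN line; the real one is not shown in this diff:
// RUN: mlir-opt %s --canonicalize --split-input-file | FileCheck %s
```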
