Skip to content

Commit 6c9aad0

Browse files
authored
[Codegen][GPU] Improve forall hoisting pattern for single trip loops (#18418)
For single-trip `scf.forall` loops, the `tensor.extract_slice` on the output can be folded away, which causes the forall-loop hoisting pattern to fail. The single-trip loops themselves cannot be folded away when they carry processor-ID mappings, because such loops may lower to an `scf.if`. This patch therefore extends the loop hoisting pattern to handle single-trip loops whose output `tensor.extract_slice` has already been folded away.
1 parent edc5d5e commit 6c9aad0

File tree

3 files changed

+119
-25
lines changed

3 files changed

+119
-25
lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
#include "mlir/Support/LogicalResult.h"
2121
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2222

23+
#define DEBUG_TYPE "iree-gpu-fuse-and-hoist-parallel-loops"
24+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
25+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
26+
2327
namespace mlir::iree_compiler::IREE::GPU {
2428

2529
#define GEN_PASS_DEF_FUSEANDHOISTPARALLELLOOPSPASS
@@ -192,6 +196,8 @@ struct FuseTilableForallConsumers final
192196
void FuseAndHoistParallelLoopsPass::runOnOperation() {
193197
MLIRContext *context = &getContext();
194198

199+
FunctionOpInterface funcOp = getOperation();
200+
195201
// First run the hoisting and fusion patterns.
196202
{
197203
RewritePatternSet patterns(context);
@@ -200,12 +206,13 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
200206
patterns.add<FuseForalls>(context);
201207
patterns.add<FuseTilableForallConsumers>(context);
202208
populateForallLoopHoistingPattern(patterns);
203-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
204-
std::move(patterns)))) {
209+
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
205210
return signalPassFailure();
206211
}
207212
}
208213

214+
LDBG("After fusing and hoisting loops\n" << funcOp);
215+
209216
// After hoisting parallel loops, try to fuse in any newly revealed consumers
210217
// and destinations.
211218
// TODO: Move the consumer fusion pattern to an explicit worklist rather than
@@ -216,24 +223,26 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
216223
patterns.add<FuseTilableForallConsumers>(context);
217224
tensor::populateFoldTensorEmptyPatterns(patterns);
218225
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
219-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
220-
std::move(patterns)))) {
226+
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
221227
return signalPassFailure();
222228
}
223229
}
224230

231+
LDBG("After fusing new consumers\n" << funcOp);
232+
225233
// Finally try to do any new producer fusions.
226234
{
227235
RewritePatternSet patterns(context);
228236
patterns.add<FuseTilableDestinationProducers>(context);
229237
patterns.add<FuseTilableSliceProducers>(context);
230238
tensor::populateFoldTensorEmptyPatterns(patterns);
231239
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
232-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
233-
std::move(patterns)))) {
240+
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
234241
return signalPassFailure();
235242
}
236243
}
244+
245+
LDBG("After fusing new producers\n" << funcOp);
237246
}
238247

239248
} // namespace mlir::iree_compiler::IREE::GPU

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,3 +304,39 @@ func.func @multi_hoist_with_other_ops_in_loop(%2: tensor<128x128xf16>, %3: tenso
304304
// CHECK: scf.forall.in_parallel
305305
// CHECK: scf.forall.in_parallel
306306
// CHECK: return
307+
308+
// -----
309+
310+
func.func @hoist_with_single_trip_loops(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
311+
%c4 = arith.constant 4 : index
312+
%c128 = arith.constant 128 : index
313+
%c0 = arith.constant 0 : index
314+
%empty = tensor.empty() : tensor<128x128xf16>
315+
%8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %empty) -> (tensor<128x128xf16>) {
316+
%9 = scf.forall (%arg2, %arg3) in (1, 1) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf16>) {
317+
%extracted_slice = tensor.extract_slice %arg4[%arg2, %arg3] [128, 128] [1, 1] : tensor<128x128xf16> to tensor<128x128xf16>
318+
%10 = scf.forall (%arg5, %arg6) in (1, 1) shared_outs(%arg7 = %extracted_slice) -> (tensor<128x128xf16>) {
319+
%16 = linalg.copy ins(%arg7 : tensor<128x128xf16>) outs(%2 : tensor<128x128xf16>) -> tensor<128x128xf16>
320+
scf.forall.in_parallel {
321+
tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [128, 128] [1, 1] : tensor<128x128xf16> into tensor<128x128xf16>
322+
}
323+
} {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
324+
scf.forall.in_parallel {
325+
tensor.parallel_insert_slice %10 into %arg4[%arg2, %arg3] [128, 128] [1, 1] : tensor<128x128xf16> into tensor<128x128xf16>
326+
}
327+
} {mapping = [#gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_0>]}
328+
scf.yield %9 : tensor<128x128xf16>
329+
}
330+
return %8 : tensor<128x128xf16>
331+
}
332+
333+
// CHECK-LABEL: func @hoist_with_single_trip_loops
334+
// CHECK-SAME: %[[I0:[A-Za-z0-9]+]]: tensor<128x128xf16>
335+
// CHECK-SAME: %[[I1:[A-Za-z0-9]+]]: tensor<128x128xf16>
336+
// CHECK: scf.forall
337+
// CHECK: scf.forall
338+
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} -> (tensor<128x128xf16>)
339+
// CHECK: linalg.copy
340+
// CHECK: scf.forall.in_parallel
341+
// CHECK: scf.forall.in_parallel
342+
// CHECK: return

compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,10 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
10921092
rewriter.moveOpBefore(op, &forallBody->getOperations().front());
10931093
}
10941094

1095+
bool isSingleTripLoop = forallOp.isNormalized() &&
1096+
llvm::all_of(forallOp.getStaticUpperBound(),
1097+
[](int64_t i) { return i == 1; });
1098+
10951099
// Step 2. Collect the set of tensor.parallel_insert_slice ops in the
10961100
// terminator and their paired extract_slice ops from the for loop iter arg.
10971101
SmallVector<Operation *> sliceOperandProducers;
@@ -1106,7 +1110,8 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
11061110
scf::InParallelOp parallelTerminator = forallOp.getTerminator();
11071111
SmallVector<tensor::ParallelInsertSliceOp> terminators(
11081112
forallOp.getNumResults());
1109-
SmallVector<tensor::ExtractSliceOp> pairedSlices(forallOp.getNumResults());
1113+
SmallVector<std::optional<tensor::ExtractSliceOp>> pairedSlices(
1114+
forallOp.getNumResults(), std::nullopt);
11101115
int64_t numInductionVars = forallOp.getInductionVars().size();
11111116
for (auto &yieldingOp : parallelTerminator.getYieldingOps()) {
11121117
auto parallelInsert = cast<tensor::ParallelInsertSliceOp>(&yieldingOp);
@@ -1117,28 +1122,58 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
11171122
if (user == parallelInsert)
11181123
continue;
11191124
auto maybeSlice = dyn_cast<tensor::ExtractSliceOp>(user);
1120-
// Fail if the destination has more users than a direct insert and
1121-
// extract slice.
11221125
if (!maybeSlice) {
1123-
return failure();
1126+
// Fail if the destination has more users than a direct insert and
1127+
// extract slice unless it is a single trip loop.
1128+
if (!isSingleTripLoop) {
1129+
return failure();
1130+
}
1131+
continue;
11241132
}
1125-
// Require a single extract per destination.
1133+
// Require at most one extract per destination.
11261134
if (destSlice) {
11271135
return failure();
11281136
}
11291137
destSlice = maybeSlice;
11301138
}
1139+
11311140
// Verify they operate on equivalent subsets, ensuring the slices are
11321141
// hoistable. It is still possible to hoist the loop if this is not true,
11331142
// however in such cases we likely formed the loops in the wrong order.
1134-
if (!cast<SubsetOpInterface>(*destSlice)
1135-
.operatesOnEquivalentSubset(
1136-
cast<SubsetOpInterface>(*parallelInsert),
1137-
[](Value v1, Value v2) { return v1 == v2; })) {
1143+
if (destSlice && !cast<SubsetOpInterface>(*destSlice)
1144+
.operatesOnEquivalentSubset(
1145+
cast<SubsetOpInterface>(*parallelInsert),
1146+
[](Value v1, Value v2) { return v1 == v2; })) {
11381147
return failure();
11391148
}
1140-
terminators[destBbArg.getArgNumber() - numInductionVars] = parallelInsert;
1141-
pairedSlices[destBbArg.getArgNumber() - numInductionVars] = destSlice;
1149+
1150+
auto isOverwritingFullDestination =
1151+
[](tensor::ParallelInsertSliceOp insert) {
1152+
// TODO: Handle rank reducing case.
1153+
if (insert.getSourceType().getRank() !=
1154+
insert.getDestType().getRank()) {
1155+
return false;
1156+
}
1157+
for (auto [dim, size] : llvm::enumerate(insert.getMixedSizes())) {
1158+
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
1159+
{size}, {insert.getDest(), static_cast<int64_t>(dim)});
1160+
if (failed(equalDimSize) || !*equalDimSize)
1161+
return false;
1162+
}
1163+
return true;
1164+
};
1165+
1166+
// For single trip loops, verify that the parallel_insert_slice is
1167+
// overwriting the full destination.
1168+
if (!destSlice && !isOverwritingFullDestination(parallelInsert)) {
1169+
return failure();
1170+
}
1171+
1172+
int64_t argId = destBbArg.getArgNumber() - numInductionVars;
1173+
terminators[argId] = parallelInsert;
1174+
if (destSlice) {
1175+
pairedSlices[argId] = destSlice;
1176+
}
11421177

11431178
// Collect all of the offset/size/stride operands for both slices and
11441179
// compute a backwards slice of the program from them. Fail if any of
@@ -1148,10 +1183,12 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
11481183
parallelInsert.getOperands().begin() +
11491184
parallelInsert.getOffsetSizeAndStrideStartOperandIndex(),
11501185
parallelInsert.getOperands().end());
1151-
sliceOperands.insert(
1152-
destSlice.getOperands().begin() +
1153-
destSlice.getOffsetSizeAndStrideStartOperandIndex(),
1154-
destSlice.getOperands().end());
1186+
if (destSlice) {
1187+
sliceOperands.insert(
1188+
destSlice.getOperands().begin() +
1189+
destSlice.getOffsetSizeAndStrideStartOperandIndex(),
1190+
destSlice.getOperands().end());
1191+
}
11551192
for (Value operand : sliceOperands) {
11561193
if (auto bbArg = dyn_cast<BlockArgument>(operand)) {
11571194
if (bbArg.getOwner()->getParentOp() == loop) {
@@ -1200,8 +1237,15 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
12001237
OpBuilder::InsertionGuard g(rewriter);
12011238
rewriter.setInsertionPoint(newForallOp.getTerminator());
12021239
SmallVector<Value> newInits;
1203-
for (auto slice : pairedSlices) {
1204-
newInits.push_back(slice.getResult());
1240+
for (auto [iterArgId, slice] : llvm::enumerate(pairedSlices)) {
1241+
if (slice) {
1242+
newInits.push_back(slice.value().getResult());
1243+
continue;
1244+
}
1245+
1246+
// If there is no paired slice (for a single trip count loop) then
1247+
// use the iter arg of the forall op directly.
1248+
newInits.push_back(newForallOp.getRegionIterArgs()[iterArgId]);
12051249
}
12061250
// Step 4. Create a new for loop with new inits for the result of the
12071251
// extracted slices.
@@ -1224,7 +1268,10 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
12241268
// args.
12251269
for (auto [hoistedSlice, iterArg] :
12261270
llvm::zip_equal(pairedSlices, newLoop.getRegionIterArgs())) {
1227-
rewriter.replaceAllUsesExcept(hoistedSlice, iterArg, newLoop);
1271+
if (hoistedSlice) {
1272+
rewriter.replaceAllUsesExcept(hoistedSlice.value(), iterArg,
1273+
newLoop);
1274+
}
12281275
}
12291276

12301277
// Create the terminator for the new loop using the sources of the
@@ -1243,7 +1290,9 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
12431290
rewriter.moveOpBefore(sliceOperandProducer, newLoop);
12441291
}
12451292
for (auto slice : pairedSlices) {
1246-
rewriter.moveOpBefore(slice, newLoop);
1293+
if (slice) {
1294+
rewriter.moveOpBefore(slice.value(), newLoop);
1295+
}
12471296
}
12481297

12491298
// Create the new terminator for the hoisted forall loop using the results

0 commit comments

Comments
 (0)