@@ -1092,6 +1092,10 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
1092
1092
rewriter.moveOpBefore (op, &forallBody->getOperations ().front ());
1093
1093
}
1094
1094
1095
+ bool isSingleTripLoop = forallOp.isNormalized () &&
1096
+ llvm::all_of (forallOp.getStaticUpperBound (),
1097
+ [](int64_t i) { return i == 1 ; });
1098
+
1095
1099
// Step 2. Collect the set of tensor.parallel_insert_slice ops in the
1096
1100
// terminator and their paired extract_slice ops from the for loop iter arg.
1097
1101
SmallVector<Operation *> sliceOperandProducers;
@@ -1106,7 +1110,8 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
1106
1110
scf::InParallelOp parallelTerminator = forallOp.getTerminator ();
1107
1111
SmallVector<tensor::ParallelInsertSliceOp> terminators (
1108
1112
forallOp.getNumResults ());
1109
- SmallVector<tensor::ExtractSliceOp> pairedSlices (forallOp.getNumResults ());
1113
+ SmallVector<std::optional<tensor::ExtractSliceOp>> pairedSlices (
1114
+ forallOp.getNumResults (), std::nullopt);
1110
1115
int64_t numInductionVars = forallOp.getInductionVars ().size ();
1111
1116
for (auto &yieldingOp : parallelTerminator.getYieldingOps ()) {
1112
1117
auto parallelInsert = cast<tensor::ParallelInsertSliceOp>(&yieldingOp);
@@ -1117,28 +1122,58 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
1117
1122
if (user == parallelInsert)
1118
1123
continue ;
1119
1124
auto maybeSlice = dyn_cast<tensor::ExtractSliceOp>(user);
1120
- // Fail if the destination has more users than a direct insert and
1121
- // extract slice.
1122
1125
if (!maybeSlice) {
1123
- return failure ();
1126
+ // Fail if the destination has more users than a direct insert and
1127
+ // extract slice unless it is a single trip loop.
1128
+ if (!isSingleTripLoop) {
1129
+ return failure ();
1130
+ }
1131
+ continue ;
1124
1132
}
1125
- // Require a single extract per destination.
1133
+ // Require at most one extract per destination.
1126
1134
if (destSlice) {
1127
1135
return failure ();
1128
1136
}
1129
1137
destSlice = maybeSlice;
1130
1138
}
1139
+
1131
1140
// Verify they operate on equivalent subsets, ensuring the slices are
1132
1141
// hoistable. It is still possible to hoist the loop if this is not true,
1133
1142
// however in such cases we likely formed the loops in the wrong order.
1134
- if (!cast<SubsetOpInterface>(*destSlice)
1135
- .operatesOnEquivalentSubset (
1136
- cast<SubsetOpInterface>(*parallelInsert),
1137
- [](Value v1, Value v2) { return v1 == v2; })) {
1143
+ if (destSlice && !cast<SubsetOpInterface>(*destSlice)
1144
+ .operatesOnEquivalentSubset (
1145
+ cast<SubsetOpInterface>(*parallelInsert),
1146
+ [](Value v1, Value v2) { return v1 == v2; })) {
1138
1147
return failure ();
1139
1148
}
1140
- terminators[destBbArg.getArgNumber () - numInductionVars] = parallelInsert;
1141
- pairedSlices[destBbArg.getArgNumber () - numInductionVars] = destSlice;
1149
+
1150
// Local predicate over a tensor.parallel_insert_slice: returns true only when
// every size of the inserted slice is provably equal to the corresponding
// dimension of the insert's destination. Only sizes are compared in this
// helper; offsets and strides are not inspected here.
+ auto isOverwritingFullDestination =
1151
+ [](tensor::ParallelInsertSliceOp insert) {
1152
+ // TODO: Handle rank reducing case.
// Until the TODO above is addressed, a source/destination rank mismatch is
// reported as "not a full overwrite".
1153
+ if (insert.getSourceType ().getRank () !=
1154
+ insert.getDestType ().getRank ()) {
1155
+ return false ;
1156
+ }
1157
// Walk the mixed (static or dynamic) sizes together with their dimension
// index and compare each against that dimension of the destination value.
+ for (auto [dim, size] : llvm::enumerate (insert.getMixedSizes ())) {
1158
+ FailureOr<bool > equalDimSize = ValueBoundsConstraintSet::areEqual (
1159
+ {size}, {insert.getDest (), static_cast <int64_t >(dim)});
1160
// A failed (inconclusive) comparison is treated the same as "not equal".
+ if (failed (equalDimSize) || !*equalDimSize)
1161
+ return false ;
1162
+ }
1163
// All sizes matched their destination dimensions.
+ return true ;
1164
+ };
1165
+
1166
+ // For single trip loops, verify that the parallel_insert_slice is
1167
+ // overwriting the full destination.
1168
+ if (!destSlice && !isOverwritingFullDestination (parallelInsert)) {
1169
+ return failure ();
1170
+ }
1171
+
1172
+ int64_t argId = destBbArg.getArgNumber () - numInductionVars;
1173
+ terminators[argId] = parallelInsert;
1174
+ if (destSlice) {
1175
+ pairedSlices[argId] = destSlice;
1176
+ }
1142
1177
1143
1178
// Collect all of the offset/size/stride operands for both slices and
1144
1179
// compute a backwards slice of the program from them. Fail if any of
@@ -1148,10 +1183,12 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
1148
1183
parallelInsert.getOperands ().begin () +
1149
1184
parallelInsert.getOffsetSizeAndStrideStartOperandIndex (),
1150
1185
parallelInsert.getOperands ().end ());
1151
- sliceOperands.insert (
1152
- destSlice.getOperands ().begin () +
1153
- destSlice.getOffsetSizeAndStrideStartOperandIndex (),
1154
- destSlice.getOperands ().end ());
1186
+ if (destSlice) {
1187
+ sliceOperands.insert (
1188
+ destSlice.getOperands ().begin () +
1189
+ destSlice.getOffsetSizeAndStrideStartOperandIndex (),
1190
+ destSlice.getOperands ().end ());
1191
+ }
1155
1192
for (Value operand : sliceOperands) {
1156
1193
if (auto bbArg = dyn_cast<BlockArgument>(operand)) {
1157
1194
if (bbArg.getOwner ()->getParentOp () == loop) {
@@ -1200,8 +1237,15 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
1200
1237
OpBuilder::InsertionGuard g (rewriter);
1201
1238
rewriter.setInsertionPoint (newForallOp.getTerminator ());
1202
1239
SmallVector<Value> newInits;
1203
- for (auto slice : pairedSlices) {
1204
- newInits.push_back (slice.getResult ());
1240
+ for (auto [iterArgId, slice] : llvm::enumerate (pairedSlices)) {
1241
+ if (slice) {
1242
+ newInits.push_back (slice.value ().getResult ());
1243
+ continue ;
1244
+ }
1245
+
1246
+ // If there is no paired slice (for a single trip count loop) then
1247
+ // use the iter arg of the forall op directly.
1248
+ newInits.push_back (newForallOp.getRegionIterArgs ()[iterArgId]);
1205
1249
}
1206
1250
// Step 4. Create a new for loop with new inits for the result of the
1207
1251
// extracted slices.
@@ -1224,7 +1268,10 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
1224
1268
// args.
1225
1269
for (auto [hoistedSlice, iterArg] :
1226
1270
llvm::zip_equal (pairedSlices, newLoop.getRegionIterArgs ())) {
1227
- rewriter.replaceAllUsesExcept (hoistedSlice, iterArg, newLoop);
1271
+ if (hoistedSlice) {
1272
+ rewriter.replaceAllUsesExcept (hoistedSlice.value (), iterArg,
1273
+ newLoop);
1274
+ }
1228
1275
}
1229
1276
1230
1277
// Create the terminator for the new loop using the sources of the
@@ -1243,7 +1290,9 @@ struct HoistForallFromFor : public OpRewritePattern<scf::ForOp> {
1243
1290
rewriter.moveOpBefore (sliceOperandProducer, newLoop);
1244
1291
}
1245
1292
for (auto slice : pairedSlices) {
1246
- rewriter.moveOpBefore (slice, newLoop);
1293
+ if (slice) {
1294
+ rewriter.moveOpBefore (slice.value (), newLoop);
1295
+ }
1247
1296
}
1248
1297
1249
1298
// Create the new terminator for the hoisted forall loop using the results
0 commit comments