@@ -217,9 +217,10 @@ lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
217
217
return newEmptyOp;
218
218
}
219
219
220
- /// Converts a linalg::GenericOp with encoded inputs into the packed domain.
221
- /// The `genericOp` must have all parallel iterator types and a single output
222
- /// with an identity indexing map.
220
+ /// Converts a linalg::GenericOp with encoded inputs into the packed domain,
221
+ /// with an optional swizzle expansion and permutation if applicable. The
222
+ /// `genericOp` must have all parallel iterator types and a single output with
223
+ /// an identity indexing map.
223
224
static FailureOr<Operation *> lowerGenericOpWithEncoding (
224
225
RewriterBase &rewriter, linalg::GenericOp genericOp,
225
226
ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
@@ -230,30 +231,119 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
230
231
return rewriter.notifyMatchFailure (genericOp,
231
232
" Output indexing map is not identity" );
232
233
}
234
+ // Step 1: Retrieve the output encoding materialization information and
235
+ // compute the new indexing maps for the packed and potentially swizzled
236
+ // layout. This consists of outer dimension and inner dimension permutation
237
+ // vectors for the packing and an expanded result dimension permutation vector
238
+ // for the optional swizzling. This assumes that the output map is identity,
239
+ // and that all iterator types are parallel.
240
+ //
241
+ // Running example:
242
+ //
243
+ // Given the following output layout:
244
+ //
245
+ // outputType: tensor<2x128x64xf32>
246
+ // outputPackInfo: innerDimsPos = [1, 2],
247
+ // innerTileSizes = [128, 16]
248
+ // outerDimsPerm = [0, 1, 2]
249
+ // outputSwizzle: expandShape = [[4, 8, 4], [4, 4]]
250
+ // permutation = [1, 4, 0, 2, 3]
251
+ //
252
+ // Retrieve and compute the permutation vectors for the packing outer and
253
+ // inner dimension permutation and for the expanded swizzle permutation. Then,
254
+ // calculate the permutation that would transform the swizzled output
255
+ // dimension map into the identity dimension map. This is the inverse swizzle
256
+ // permutation.
257
+ //
258
+ // outInverseOuterDimsPerm: [0, 1, 2]
259
+ // outInnerDimsPos: [1, 2]
260
+ // outSwizzlePerm: [0, 1, 2, 4, 7, 3, 5, 6]
261
+ // invOutSwizzlePerm: [0, 1, 2, 5, 3, 6, 7, 4]
233
262
MaterializeEncodingInfo outMaterializeEncodingInfo =
234
263
typeConverter.getEncodingInfo (
235
264
cast<RankedTensorType>(outputOperand->get ().getType ()));
236
265
if (IREE::Codegen::isIdentityLayout (outMaterializeEncodingInfo)) {
237
- return rewriter.notifyMatchFailure (
238
- genericOp, " MaterializeEncodingInfo failed for output" );
239
- }
240
- if (outMaterializeEncodingInfo.swizzle ) {
241
- return rewriter.notifyMatchFailure (
242
- genericOp, " generic op lowering does not support swizzle yet" );
266
+ return dropEncodingAndCloneOp (rewriter, genericOp.getOperation (),
267
+ convertedInputOperands,
268
+ convertedOutputOperands);
243
269
}
244
270
245
271
auto convertedResultType =
246
272
cast<RankedTensorType>(convertedOutputOperands[0 ].getType ());
247
273
SmallVector<utils::IteratorType> iteratorTypes (convertedResultType.getRank (),
248
274
utils::IteratorType::parallel);
249
- // Compute the new indexing maps for the packed layout. This assumes that
250
- // the output map is identity, and that all iterator types are parallel.
251
- SmallVector<int64_t > outInnerDimsPos =
252
- outMaterializeEncodingInfo.innerDimsPos ;
275
+
253
276
SmallVector<int64_t > outInverseOuterDimsPerm =
254
277
invertPermutationVector (outMaterializeEncodingInfo.outerDimsPerm );
278
+ ArrayRef<int64_t > outInnerDimsPos = outMaterializeEncodingInfo.innerDimsPos ;
279
+ SmallVector<int64_t > outSwizzlePerm =
280
+ llvm::to_vector (llvm::seq<int64_t >(0 , convertedResultType.getRank ()));
281
+ if (outMaterializeEncodingInfo.swizzle .has_value ()) {
282
+ const int outRank =
283
+ cast<RankedTensorType>(outputOperand->get ().getType ()).getRank ();
284
+ SmallVector<int64_t > transposePerm =
285
+ llvm::to_vector (llvm::seq<int64_t >(0 , outRank));
286
+ for (auto perm : outMaterializeEncodingInfo.swizzle ->permutation ) {
287
+ transposePerm.push_back (outRank + perm);
288
+ }
289
+ applyPermutationToVector (outSwizzlePerm, transposePerm);
290
+ }
291
+ SmallVector<int64_t > invOutSwizzlePerm =
292
+ invertPermutationVector (outSwizzlePerm);
293
+
294
+ // Calculate the running offset for every dimension position for easy lookup
295
+ // when calculating the packed result dimensions for every operand.
296
+ // Example:
297
+ // expandShape == [[4, 8, 4], [4, 4]]
298
+ // In this case:
299
+ // outOffsetForDimsPos == [0, 3]
300
+ // So that whenever we need the real dimension for an entry (`outerIndex`,
301
+ // `innerIndex`) in the 2D expanded shape vector, we can calculate it as:
302
+ // dim(outerIndex, innerIndex) = outOffsetForDimsPos[outerIndex] +
303
+ // innerIndex
304
+ SmallVector<int64_t > outOffsetForDimsPos (outInnerDimsPos.size (), 0 );
305
+ if (outMaterializeEncodingInfo.swizzle .has_value ()) {
306
+ int64_t runningSize = 0 ;
307
+ for (size_t i = 0 ; i < outInnerDimsPos.size (); i++) {
308
+ outOffsetForDimsPos[i] = runningSize;
309
+ runningSize += outMaterializeEncodingInfo.swizzle ->expandShape [i].size ();
310
+ }
311
+ }
312
+
255
313
SmallVector<AffineMap> packedIndexingMaps;
256
314
for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
315
+ // Step 2: Retrieve the encoding for every input operand and perform the
316
+ // outer dimension permutation, inner dimension expansion and permutation,
317
+ // swizzle expansion and swizzle permutation.
318
+ //
319
+ // Running example:
320
+ //
321
+ // Given the input layout and indexing maps:
322
+ //
323
+ // inputType: tensor<2x64xf32>
324
+ // inputPackInfo: innerDimsPos = [1]
325
+ // innerTileSizes = [16]
326
+ // outerDimsPerm = [0, 1]
327
+ // inputSwizzle: expandShape = [[4, 4]]
328
+ // permutation = [1, 0]
329
+ // inputMap: [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
330
+ // affine_map<(d0, d1, d2) -> (d0, d2)>]
331
+ //
332
+ // 1. Calculate the result dimensions from the indexing maps and perform the
333
+ // outer dimension permutation:
334
+ //
335
+ // packedResultDims: [0, 2]
336
+ //
337
+ // 2. Perform inner dimension expansion, permutation and optional swizzle
338
+ // expansion in one go. In this example, the inner dimension (64) would be
339
+ // expanded into 4x16 based on `innerDimsPos` and `innerTileSizes` above,
340
+ // and then expanded to 4x4x4 based on the swizzle.
341
+ //
342
+ // packedResultDims: [0, 2, 6, 7]
343
+ //
344
+ // 3. Perform the swizzle permutation:
345
+ //
346
+ // packedResultDims: [0, 2, 7, 6]
257
347
MaterializeEncodingInfo materializeEncodingInfo =
258
348
typeConverter.getEncodingInfo (
259
349
cast<RankedTensorType>(inputOperand->get ().getType ()));
@@ -277,14 +367,72 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
277
367
for (auto [idx, pos] : llvm::enumerate (innerDimsPos)) {
278
368
auto dimPos = cast<AffineDimExpr>(inputMap.getResult (pos)).getPosition ();
279
369
for (auto [tileIdx, outDim] : llvm::enumerate (outInnerDimsPos)) {
280
- if (dimPos == outDim) {
370
+ if (dimPos != outDim) {
371
+ continue ;
372
+ }
373
+ if (!materializeEncodingInfo.swizzle .has_value ()) {
281
374
packedResultDims.push_back (outputMap.getNumDims () + tileIdx);
375
+ continue ;
282
376
}
377
+ // In case of a layout with swizzle, an expanded set of dimensions
378
+ // needs to be appended as specified by the swizzle's `expandShape`
379
+ // field. Note that the dimension index should be offset by the
380
+ // calculated output starting offset as every dimension is now
381
+ // transformed into an expanded sequence of indices and the correct
382
+ // dimension index is:
383
+ // outOffsetForDimsPos[tileIdx] + innerIndex
384
+ assert (idx < materializeEncodingInfo.swizzle ->expandShape .size () &&
385
+ " `innerDimsPos` index should not exceed the swizzle's "
386
+ " `expandShape` size" );
387
+ const size_t dimSize =
388
+ materializeEncodingInfo.swizzle ->expandShape [idx].size ();
389
+ const int64_t outIdxOffset =
390
+ outputMap.getNumDims () + outOffsetForDimsPos[tileIdx];
391
+ for (size_t i = 0 ; i < dimSize; i++) {
392
+ packedResultDims.push_back (outIdxOffset + i);
393
+ }
394
+ }
395
+ }
396
+ // In case of a layout with swizzle, the packed result dimensions need
397
+ // to be transposed according to the swizzle's permutation vector.
398
+ if (materializeEncodingInfo.swizzle .has_value ()) {
399
+ int inRank =
400
+ cast<RankedTensorType>(inputOperand->get ().getType ()).getRank ();
401
+ SmallVector<int64_t > transposePerm =
402
+ llvm::to_vector (llvm::seq<int64_t >(0 , inRank));
403
+ for (auto perm : materializeEncodingInfo.swizzle ->permutation ) {
404
+ transposePerm.push_back (inRank + perm);
283
405
}
406
+ applyPermutationToVector (packedResultDims, transposePerm);
284
407
}
408
+
409
+ // Step 3: Calculate the final packed result dimensions through the inverse
410
+ // result dimensions permutation map. This effectively linearizes the packed
411
+ // result dimensions with respect to the output dimensions. For example, if
412
+ // the permuted output dimensions are [D0, D2, D1], this will transform all
413
+ // packed operand result dimensions with the permutation map that would make
414
+ // the output dimensions the identity map [D0, D1, D2], i.e. {D0 -> D0, D1
415
+ // -> D2, D2 -> D1}. Suppose that the operand dimensions are [D0, D2], this
416
+ // operation would transform it into [D0, D1] to align with the output
417
+ // identity map.
418
+ //
419
+ // Running example:
420
+ //
421
+ // The packed and swizzled result dimensions for the input operand:
422
+ //
423
+ // packedResultDims: [0, 2, 7, 6]
424
+ //
425
+ // Now we need to account for swizzled output result dimensions being
426
+ // linearized to the identity map. This can be achieved by applying
427
+ // `invOutSwizzlePerm` ([0, 1, 2, 5, 3, 6, 7, 4]):
428
+ //
429
+ // finalPackedResultDims: [0, 2, 4, 7]
430
+ SmallVector<int64_t > finalPackedResultDims = llvm::map_to_vector (
431
+ packedResultDims, [&](int64_t r) { return invOutSwizzlePerm[r]; });
432
+
285
433
// Create the packed indexing map.
286
434
SmallVector<AffineExpr> packedResultExprs =
287
- llvm::map_to_vector (packedResultDims , [&](int64_t dim) {
435
+ llvm::map_to_vector (finalPackedResultDims , [&](int64_t dim) {
288
436
return rewriter.getAffineDimExpr (dim);
289
437
});
290
438
auto packedInputMap = AffineMap::get (
0 commit comments