
Commit eb8c1ed

[DT] Add generic op materialization pattern for GPU

Signed-off-by: Jorn Tuyls <jorn.tuyls@gmail.com>

1 parent 2dd6e83

11 files changed, +468 -26 lines

compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp

+85 -10
@@ -234,12 +234,9 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
       typeConverter.getEncodingInfo(
           cast<RankedTensorType>(outputOperand->get().getType()));
   if (IREE::Codegen::isIdentityLayout(outMaterializeEncodingInfo)) {
-    return rewriter.notifyMatchFailure(
-        genericOp, "MaterializeEncodingInfo failed for output");
-  }
-  if (outMaterializeEncodingInfo.swizzle) {
-    return rewriter.notifyMatchFailure(
-        genericOp, "generic op lowering does not support swizzle yet");
+    return dropEncodingAndCloneOp(rewriter, genericOp.getOperation(),
+                                  convertedInputOperands,
+                                  convertedOutputOperands);
   }
 
   auto convertedResultType =
@@ -248,8 +245,41 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
       utils::IteratorType::parallel);
   // Compute the new indexing maps for the packed layout. This assumes that
   // the output map is identity, and that all iterator types are parallel.
-  SmallVector<int64_t> outInnerDimsPos =
-      outMaterializeEncodingInfo.innerDimsPos;
+  ArrayRef<int64_t> outInnerDimsPos = outMaterializeEncodingInfo.innerDimsPos;
+  SmallVector<int64_t> outResultDimsPerm =
+      llvm::to_vector(llvm::seq<int64_t>(0, convertedResultType.getRank()));
+  if (outMaterializeEncodingInfo.swizzle.has_value()) {
+    int outRank =
+        cast<RankedTensorType>(outputOperand->get().getType()).getRank();
+    SmallVector<int64_t> transposePerm =
+        llvm::to_vector(llvm::seq<int64_t>(0, outRank));
+    for (auto perm : outMaterializeEncodingInfo.swizzle->permutation) {
+      transposePerm.push_back(outRank + perm);
+    }
+    applyPermutationToVector(outResultDimsPerm, transposePerm);
+  }
+  SmallVector<int64_t> invOutResultDimsPerm =
+      invertPermutationVector(outResultDimsPerm);
+
+  // Calculate the running offset for every dimension position for easy lookup
+  // when calculating the packed result dimensions for every operand.
+  // Example:
+  //   expandShape == [[4, 8, 4], [4, 4]]
+  // In this case:
+  //   outOffsetForDimsPos == [0, 3]
+  // So that whenever we need the real dimension for an entry (`outerIndex`,
+  // `innerIndex`) in the 2D expanded shape vector, we can calculate it as:
+  //   dim(outerIndex, innerIndex) = outOffsetForDimsPos[outerIndex] +
+  //                                 innerIndex
+  SmallVector<int64_t> outOffsetForDimsPos(outInnerDimsPos.size(), 0);
+  if (outMaterializeEncodingInfo.swizzle.has_value()) {
+    int64_t runningSize = 0;
+    for (size_t i = 0; i < outInnerDimsPos.size(); i++) {
+      outOffsetForDimsPos[i] = runningSize;
+      runningSize += outMaterializeEncodingInfo.swizzle->expandShape[i].size();
+    }
+  }
+
   SmallVector<int64_t> outInverseOuterDimsPerm =
       invertPermutationVector(outMaterializeEncodingInfo.outerDimsPerm);
   SmallVector<AffineMap> packedIndexingMaps;
@@ -277,14 +307,59 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
     for (auto [idx, pos] : llvm::enumerate(innerDimsPos)) {
       auto dimPos = cast<AffineDimExpr>(inputMap.getResult(pos)).getPosition();
       for (auto [tileIdx, outDim] : llvm::enumerate(outInnerDimsPos)) {
-        if (dimPos == outDim) {
+        if (dimPos != outDim) {
+          continue;
+        }
+        if (!materializeEncodingInfo.swizzle.has_value()) {
           packedResultDims.push_back(outputMap.getNumDims() + tileIdx);
+          continue;
+        }
+        // In case of a layout with swizzle, an expanded set of dimensions
+        // needs to be appended as specified by the swizzle's `expandShape`
+        // field. Note that the dimension index should be offset by the
+        // calculated output starting offset as every dimension is now
+        // transformed into an expanded sequence of indices and the correct
+        // dimension index is:
+        //   outOffsetForDimsPos[tileIdx] + innerIndex
+        assert(idx < materializeEncodingInfo.swizzle->expandShape.size() &&
               "`innerDimsPos` index should not exceed the swizzle's "
               "`expandShape` size");
+        const size_t dimSize =
+            materializeEncodingInfo.swizzle->expandShape[idx].size();
+        const int64_t outIdxOffset =
+            outputMap.getNumDims() + outOffsetForDimsPos[tileIdx];
+        for (size_t i = 0; i < dimSize; i++) {
+          packedResultDims.push_back(outIdxOffset + i);
         }
       }
     }
+    // In case of a layout with swizzle, the packed result dimensions need
+    // to be transposed according to the swizzle's permutation vector.
+    if (materializeEncodingInfo.swizzle.has_value()) {
+      int inRank =
+          cast<RankedTensorType>(inputOperand->get().getType()).getRank();
+      SmallVector<int64_t> transposePerm =
+          llvm::to_vector(llvm::seq<int64_t>(0, inRank));
+      for (auto perm : materializeEncodingInfo.swizzle->permutation) {
+        transposePerm.push_back(inRank + perm);
+      }
+      applyPermutationToVector(packedResultDims, transposePerm);
+    }
+    // Calculate the final packed result dimensions through the inverse result
+    // dimensions permutation map. This effectively linearizes the packed result
+    // dimensions with respect to the output dimensions. For example, if the
+    // permuted output dimensions are [D0, D2, D1], this will transform all
+    // packed operand result dimensions with the permutation map that would make
+    // the output dimensions the identity map [D0, D1, D2], i.e. {D0 -> D0, D1
+    // -> D2, D2 -> D1}. Suppose that the operand dimensions are [D0, D2]; this
+    // operation would transform them into [D0, D1] to align with the output
+    // identity map.
+    SmallVector<int64_t> finalPackedResultDims = llvm::map_to_vector(
+        packedResultDims, [&](int64_t r) { return invOutResultDimsPerm[r]; });
+
     // Create the packed indexing map.
     SmallVector<AffineExpr> packedResultExprs =
-        llvm::map_to_vector(packedResultDims, [&](int64_t dim) {
+        llvm::map_to_vector(finalPackedResultDims, [&](int64_t dim) {
           return rewriter.getAffineDimExpr(dim);
         });
     auto packedInputMap = AffineMap::get(
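
The running-offset computation above is small enough to check in isolation. The following is a minimal standalone C++ sketch of the same bookkeeping, using plain std:: containers in place of the LLVM/MLIR types (the names mirror the patch, but the program itself is illustrative only):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Mirrors the swizzle's `expandShape` from the example in the comment:
  // each inner tile dimension expands into a group of dimensions.
  std::vector<std::vector<int64_t>> expandShape = {{4, 8, 4}, {4, 4}};

  // Running offset of each group's first dimension within the flattened
  // expanded shape.
  std::vector<int64_t> outOffsetForDimsPos(expandShape.size(), 0);
  int64_t runningSize = 0;
  for (size_t i = 0; i < expandShape.size(); i++) {
    outOffsetForDimsPos[i] = runningSize;
    runningSize += static_cast<int64_t>(expandShape[i].size());
  }

  // Prints "0 3", so dim(outerIndex, innerIndex) =
  // outOffsetForDimsPos[outerIndex] + innerIndex.
  for (int64_t offset : outOffsetForDimsPos)
    std::cout << offset << " ";
  std::cout << "\n";
}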

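The linearization step through invOutResultDimsPerm works the same way in miniature: invert the output permutation once, then remap every packed operand dimension through it. A standalone sketch of the [D0, D2, D1] example from the comment (again with plain std:: containers; invert stands in for MLIR's invertPermutationVector):

#include <cstdint>
#include <iostream>
#include <vector>

// Inverse of a permutation: inv[perm[i]] = i.
static std::vector<int64_t> invert(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); i++)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  // Permuted output dimensions [D0, D2, D1].
  std::vector<int64_t> outResultDimsPerm = {0, 2, 1};
  std::vector<int64_t> invOutResultDimsPerm = invert(outResultDimsPerm);

  // Operand dimensions [D0, D2] align to [D0, D1] under the inverse map.
  std::vector<int64_t> packedResultDims = {0, 2};
  for (int64_t r : packedResultDims)
    std::cout << "D" << invOutResultDimsPerm[r] << " "; // prints "D0 D1"
  std::cout << "\n";
}
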
compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir

+54
@@ -1254,3 +1254,57 @@ func.func @missing_user_indexing_maps() {
 // CHECK-DAG: %[[STORE_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
 // CHECK-DAG: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[LOAD_BINDING]]{{.+}} -> tensor<255x513xf32>
 // CHECK-DAG: flow.dispatch.tensor.store %[[LOAD]], %[[STORE_BINDING]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [[affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2) -> (d0, d2)>], affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]>
+func.func @dequantization() {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x128x64xi8, #encoding>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+  %6 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  %7 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x128x64xi8, #encoding>> -> tensor<2x128x64xi8, #encoding>
+  %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+  %9 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+  %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %8, %9 : tensor<2x128x64xi8, #encoding>, tensor<2x64xf32, #encoding_bcast>, tensor<2x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+  ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
+    %21 = arith.extui %in : i8 to i32
+    %22 = arith.uitofp %21 : i32 to f32
+    %23 = arith.subf %22, %in_1 : f32
+    %24 = arith.mulf %23, %in_0 : f32
+    linalg.yield %24 : f32
+  } -> tensor<2x128x64xf32, #encoding>
+  flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  return
+}
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d2, d4, d7)>
+// CHECK-LABEL: func.func @dequantization()
+// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x1x4x8x4x4x4x4xi8>>
+// CHECK-DAG: %[[LHS_SCALES_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x4x4xf32>>
+// CHECK-DAG: %[[LHS_ZPS_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(2) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x4x4xf32>>
+// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(3) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x1x4x8x4x4x4x4xf32>>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]], offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [2, 1, 4, 8, 4, 4, 4, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x1x4x8x4x4x4x4xi8>> -> tensor<2x1x4x8x4x4x4x4xi8>
+// CHECK-DAG: %[[LHS_SCALES:.+]] = flow.dispatch.tensor.load %[[LHS_SCALES_BINDING]], offsets = [0, 0, 0, 0], sizes = [2, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x4x4xf32>> -> tensor<2x4x4x4xf32>
+// CHECK-DAG: %[[LHS_ZPS:.+]] = flow.dispatch.tensor.load %[[LHS_ZPS_BINDING]], offsets = [0, 0, 0, 0], sizes = [2, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x4x4xf32>> -> tensor<2x4x4x4xf32>
+// CHECK-DAG: %[[EMPTY_LHS:.+]] = tensor.empty() : tensor<2x1x4x8x4x4x4x4xf32>
+// CHECK-DAG: %[[LHS_DEQUANT:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP1]], #[[$MAP]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:  ins(%[[LHS]], %[[LHS_SCALES]], %[[LHS_ZPS]] : tensor<2x1x4x8x4x4x4x4xi8>, tensor<2x4x4x4xf32>, tensor<2x4x4x4xf32>)
+// CHECK-SAME:  outs(%[[EMPTY_LHS]] : tensor<2x1x4x8x4x4x4x4xf32>)
+// CHECK:     arith.extui
+// CHECK:     arith.uitofp
+// CHECK:     arith.subf
+// CHECK:     arith.mulf
+// CHECK:     flow.dispatch.tensor.store %[[LHS_DEQUANT]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [2, 1, 4, 8, 4, 4, 4, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : tensor<2x1x4x8x4x4x4x4xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x1x4x8x4x4x4x4xf32>>

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp

+83 -1
@@ -4,8 +4,10 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
+#include <numeric>
+
 #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
@@ -398,6 +400,86 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
   return crossThreadOuterSwizzle;
 }
 
+/// Remove the expanded dimensions for this index and update the permutation by
+/// erasing the removed dimensions' indices and adjusting existing larger
+/// indices accordingly.
+static void remove(TileSwizzle &swizzle, size_t idx) {
+  assert(idx < swizzle.expandShape.size() && "idx out of bounds");
+  const size_t startIdx = std::accumulate(
+      std::begin(swizzle.expandShape), std::begin(swizzle.expandShape) + idx, 0,
+      [](size_t idx, const TileSwizzle::ExpandShapeDimVectorType &dims)
+          -> size_t { return idx + dims.size(); });
+  const size_t endIdx = startIdx + swizzle.expandShape[idx].size();
+  swizzle.expandShape.erase(swizzle.expandShape.begin() + idx);
+  SmallVector<int64_t> newPermutation;
+  for (const int64_t &p : swizzle.permutation) {
+    if (p < startIdx) {
+      newPermutation.push_back(p);
+    } else if (p >= endIdx) {
+      newPermutation.push_back(p - (endIdx - startIdx));
+    }
+  }
+  swizzle.permutation = newPermutation;
+}
+
+FailureOr<TileSwizzle> getEncodingSwizzle(IREE::Encoding::EncodingAttr encoding,
+                                          IREE::GPU::DataTiledMMAAttr mma,
+                                          IREE::GPU::MMAFragment fragment) {
+  TileSwizzle swizzle = getSwizzle(mma, fragment);
+  FailureOr<linalg::ContractionDimensions> cDims =
+      getEncodingContractionDims(encoding);
+  if (failed(cDims)) {
+    return failure();
+  }
+  // The following expects M, N, K, and Batch sizes of at most 1 for now.
+  // TODO: Extend this to multiple M/N/K/Batch dims.
+  assert(cDims->m.size() <= 1 && cDims->n.size() <= 1 && cDims->k.size() == 1 &&
+         cDims->batch.size() <= 1 &&
+         "Expected at most one M, N, K, and Batch dimension");
+  std::optional<unsigned> mDim =
+      cDims->m.empty() ? std::nullopt
+                       : encoding.mapDimToOperandIndex(cDims->m[0]);
+  std::optional<unsigned> nDim =
+      cDims->n.empty() ? std::nullopt
+                       : encoding.mapDimToOperandIndex(cDims->n[0]);
+  std::optional<unsigned> kDim = encoding.mapDimToOperandIndex(cDims->k[0]);
+  switch (fragment) {
+  case IREE::GPU::MMAFragment::Lhs:
+    // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
+    // Dimensions are removed from last to first to ensure correctness.
+    if (!kDim.has_value()) {
+      remove(swizzle, 1);
+    }
+    if (!cDims->m.empty() && !mDim.has_value()) {
+      remove(swizzle, 0);
+    }
+    break;
+  case IREE::GPU::MMAFragment::Rhs:
+    // B-matrix (RHS). Since the pack ops already took care of transposing B,
+    // source dimensions are N (index 0) and K (index 1).
+    // Dimensions are removed from last to first to ensure correctness.
+    if (!kDim.has_value()) {
+      remove(swizzle, 1);
+    }
+    if (!cDims->n.empty() && !nDim.has_value()) {
+      remove(swizzle, 0);
+    }
+    break;
+  case IREE::GPU::MMAFragment::Acc:
+    // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
+    // 1).
+    // Dimensions are removed from last to first to ensure correctness.
+    if (!cDims->n.empty() && !nDim.has_value()) {
+      remove(swizzle, 1);
+    }
+    if (!cDims->m.empty() && !mDim.has_value()) {
+      remove(swizzle, 0);
+    }
+    break;
+  }
+  return swizzle;
+}
+
 TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
                                 IREE::GPU::MMAFragment fragment) {
   auto swizzle =
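
The index bookkeeping in remove is easy to validate with a small standalone example. Suppose a swizzle whose expandShape flattens to three dimensions in two groups; removing group 0 erases flattened indices 0 and 1 and shifts index 2 down by two. A C++ sketch with plain std:: containers standing in for TileSwizzle (the shape values are made up for illustration):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Group 0 covers flattened indices {0, 1}; group 1 covers {2}.
  std::vector<std::vector<int64_t>> expandShape = {{4, 8}, {16}};
  std::vector<int64_t> permutation = {2, 0, 1};

  // remove(swizzle, /*idx=*/0): erase group 0, i.e. the flattened range [0, 2).
  const int64_t startIdx = 0;
  const int64_t endIdx = startIdx + static_cast<int64_t>(expandShape[0].size());
  expandShape.erase(expandShape.begin());
  std::vector<int64_t> newPermutation;
  for (int64_t p : permutation) {
    if (p < startIdx)
      newPermutation.push_back(p);
    else if (p >= endIdx)
      newPermutation.push_back(p - (endIdx - startIdx));
  }

  // expandShape == {{16}}; newPermutation == {0}: entries 0 and 1 were
  // dropped, and entry 2 was shifted down by the removed range's size.
  for (int64_t p : newPermutation)
    std::cout << p << " ";
  std::cout << "\n";
}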

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h

+7
@@ -33,6 +33,13 @@ Codegen::TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
 Codegen::TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
                                 IREE::GPU::MMAFragment fragment);
 
+/// Returns the swizzle for the data-tiled-mma tile, based on the `fragment`
+/// and contraction dimensions required from the `encoding`.
+FailureOr<Codegen::TileSwizzle>
+getEncodingSwizzle(IREE::Encoding::EncodingAttr encoding,
+                   IREE::GPU::DataTiledMMAAttr mma,
+                   IREE::GPU::MMAFragment fragment);
+
 } // namespace mlir::iree_compiler::IREE::GPU
 
 #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPUTILESWIZZLEUTILS_H_

compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp

-14
@@ -70,20 +70,6 @@ static void transposeInPlace(MaterializeEncodingInfo &info) {
   transpose(info.outerDimsPerm);
 }
 
-static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
-                                         ValueRange convertedInputOperands,
-                                         ValueRange convertedOutputOperands) {
-  SmallVector<Value> operands;
-  operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
-  operands.append(convertedOutputOperands.begin(),
-                  convertedOutputOperands.end());
-  return mlir::clone(
-      builder, op,
-      {cast<RankedTensorType>(convertedOutputOperands[0].getType())
-           .dropEncoding()},
-      operands);
-}
-
 static RankedTensorType
 getExpandedType(RankedTensorType type, bool isBatched, bool isTransposed,
                 SmallVectorImpl<ReassociationIndices> &ri) {

compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp

+6 -1
@@ -342,7 +342,12 @@ struct GPUDeviceEncodingLayoutResolverAttrInterface
     info = std::move(maybeEncodingInfo.value());
     auto fragment = static_cast<IREE::GPU::MMAFragment>(
        encoding.getOperandIndex().getInt());
-    info.swizzle = getSwizzle(mma, fragment);
+    FailureOr<Codegen::TileSwizzle> maybeSwizzle =
+        getEncodingSwizzle(encoding, mma, fragment);
+    if (failed(maybeSwizzle)) {
+      return info;
+    }
+    info.swizzle = std::move(maybeSwizzle.value());
     return info;
   }

compiler/src/iree/compiler/Codegen/Utils/Utils.cpp

+14
@@ -1158,6 +1158,20 @@ OpFoldResult convertByteOffsetToElementOffset(RewriterBase &rewriter,
   }
 }
 
+Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
+                                  ValueRange convertedInputOperands,
+                                  ValueRange convertedOutputOperands) {
+  SmallVector<Value> operands;
+  operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
+  operands.append(convertedOutputOperands.begin(),
+                  convertedOutputOperands.end());
+  return mlir::clone(
+      builder, op,
+      {cast<RankedTensorType>(convertedOutputOperands[0].getType())
+           .dropEncoding()},
+      operands);
+}
+
 LogicalResult isArgmaxOp(linalg::GenericOp genericOp) {
   // Check for 2 results(value, index), and 1 input
   if (genericOp.getNumDpsInits() != 2) {

compiler/src/iree/compiler/Codegen/Utils/Utils.h

+5
@@ -211,6 +211,11 @@ OpFoldResult convertByteOffsetToElementOffset(RewriterBase &rewriter,
                                               OpFoldResult byteOffset,
                                               Type elementType);
 
+/// Clone an operation and drop all encodings.
+Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
+                                  ValueRange convertedInputOperands,
+                                  ValueRange convertedOutputOperands);
+
 /// Check if a linalg.generic is representing an argmax operation.
 LogicalResult isArgmaxOp(linalg::GenericOp genericOp);
