Commit 43d5a50

[DT] Add generic op materialization pattern for GPU
Signed-off-by: Jorn Tuyls <jorn.tuyls@gmail.com>
1 parent 82ee6ad commit 43d5a50

File tree

11 files changed: +546 -31 lines changed


compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp

+163 -15
@@ -217,9 +217,10 @@ lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
   return newEmptyOp;
 }
 
-/// Converts a linalg::GenericOp with encoded inputs into the packed domain.
-/// The `genericOp` must have all parallel iterator types and a single output
-/// with an identity indexing map.
+/// Converts a linalg::GenericOp with encoded inputs into the packed domain,
+/// with an optional swizzle expansion and permutation if applicable. The
+/// `genericOp` must have all parallel iterator types and a single output with
+/// an identity indexing map.
 static FailureOr<Operation *> lowerGenericOpWithEncoding(
     RewriterBase &rewriter, linalg::GenericOp genericOp,
     ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
@@ -230,30 +231,119 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
     return rewriter.notifyMatchFailure(genericOp,
                                        "Output indexing map is not identity");
   }
+  // Step 1: Retrieve the output encoding materialization information and
+  // compute the new indexing maps for the packed and potentially swizzled
+  // layout. This consists of outer and inner dimension permutation vectors
+  // for the packing and an expanded result dimension permutation vector for
+  // the optional swizzling. This assumes that the output map is identity,
+  // and that all iterator types are parallel.
+  //
+  // Running example:
+  //
+  // Given the following output layout:
+  //
+  //   outputType: tensor<2x128x64xf32>
+  //   outputPackInfo: innerDimsPos = [1, 2],
+  //                   innerTileSizes = [128, 16]
+  //                   outerDimsPerm = [0, 1, 2]
+  //   outputSwizzle: expandShape = [[4, 8, 4], [4, 4]]
+  //                  permutation = [1, 4, 0, 2, 3]
+  //
+  // Retrieve and compute the permutation vectors for the packing outer and
+  // inner dimension permutation and for the expanded swizzle permutation.
+  // Then, calculate the permutation that would transform the swizzled output
+  // dimension map into the identity dimension map. This is the inverse
+  // swizzle permutation.
+  //
+  //   outInverseOuterDimsPerm: [0, 1, 2]
+  //   outInnerDimsPos: [1, 2]
+  //   outSwizzlePerm: [0, 1, 2, 4, 7, 3, 5, 6]
+  //   invOutSwizzlePerm: [0, 1, 2, 5, 3, 6, 7, 4]
   MaterializeEncodingInfo outMaterializeEncodingInfo =
       typeConverter.getEncodingInfo(
           cast<RankedTensorType>(outputOperand->get().getType()));
   if (IREE::Codegen::isIdentityLayout(outMaterializeEncodingInfo)) {
-    return rewriter.notifyMatchFailure(
-        genericOp, "MaterializeEncodingInfo failed for output");
-  }
-  if (outMaterializeEncodingInfo.swizzle) {
-    return rewriter.notifyMatchFailure(
-        genericOp, "generic op lowering does not support swizzle yet");
+    return dropEncodingAndCloneOp(rewriter, genericOp.getOperation(),
+                                  convertedInputOperands,
+                                  convertedOutputOperands);
   }
 
   auto convertedResultType =
       cast<RankedTensorType>(convertedOutputOperands[0].getType());
   SmallVector<utils::IteratorType> iteratorTypes(convertedResultType.getRank(),
                                                  utils::IteratorType::parallel);
-  // Compute the new indexing maps for the packed layout. This assumes that
-  // the output map is identity, and that all iterator types are parallel.
-  SmallVector<int64_t> outInnerDimsPos =
-      outMaterializeEncodingInfo.innerDimsPos;
+
   SmallVector<int64_t> outInverseOuterDimsPerm =
       invertPermutationVector(outMaterializeEncodingInfo.outerDimsPerm);
+  ArrayRef<int64_t> outInnerDimsPos = outMaterializeEncodingInfo.innerDimsPos;
+  SmallVector<int64_t> outSwizzlePerm =
+      llvm::to_vector(llvm::seq<int64_t>(0, convertedResultType.getRank()));
+  if (outMaterializeEncodingInfo.swizzle.has_value()) {
+    const int outRank =
+        cast<RankedTensorType>(outputOperand->get().getType()).getRank();
+    SmallVector<int64_t> transposePerm =
+        llvm::to_vector(llvm::seq<int64_t>(0, outRank));
+    for (auto perm : outMaterializeEncodingInfo.swizzle->permutation) {
+      transposePerm.push_back(outRank + perm);
+    }
+    applyPermutationToVector(outSwizzlePerm, transposePerm);
+  }
+  SmallVector<int64_t> invOutSwizzlePerm =
+      invertPermutationVector(outSwizzlePerm);
+
+  // Calculate the running offset for every dimension position for easy lookup
+  // when calculating the packed result dimensions for every operand.
+  // Example:
+  //   expandShape == [[4, 8, 4], [4, 4]]
+  // In this case:
+  //   outOffsetForDimsPos == [0, 3]
+  // So that whenever we need the real dimension for an entry (`outerIndex`,
+  // `innerIndex`) in the 2D expanded shape vector, we can calculate it as:
+  //   dim(outerIndex, innerIndex) = outOffsetForDimsPos[outerIndex] +
+  //                                 innerIndex
+  SmallVector<int64_t> outOffsetForDimsPos(outInnerDimsPos.size(), 0);
+  if (outMaterializeEncodingInfo.swizzle.has_value()) {
+    int64_t runningSize = 0;
+    for (size_t i = 0; i < outInnerDimsPos.size(); i++) {
+      outOffsetForDimsPos[i] = runningSize;
+      runningSize += outMaterializeEncodingInfo.swizzle->expandShape[i].size();
+    }
+  }
+
   SmallVector<AffineMap> packedIndexingMaps;
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+    // Step 2: Retrieve the encoding for every input operand and perform the
+    // outer dimension permutation, inner dimension expansion and permutation,
+    // swizzle expansion and swizzle permutation.
+    //
+    // Running example:
+    //
+    // Given the input layout and indexing maps:
+    //
+    //   inputType: tensor<2x64xf32>
+    //   innerPackInfo: innerDimsPos = [1]
+    //                  innerTileSizes = [16]
+    //                  outerDimsPerm = [0, 1]
+    //   innerSwizzle: expandShape = [[4, 4]]
+    //                 permutation = [1, 0]
+    //   inputMap: [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+    //              affine_map<(d0, d1, d2) -> (d0, d2)>]
+    //
+    // 1. Calculate the result dimensions from the indexing maps and perform the
+    // outer dimension permutation:
+    //
+    //   packedResultDims: [0, 2]
+    //
+    // 2. Perform inner dimension expansion, permutation and optional swizzle
+    // expansion in one go. In this example, the inner dimension (64) would be
+    // expanded into 4x16 based on `innerDimsPos` and `innerTileSizes` above,
+    // and then expanded to 4x4x4 based on the swizzle.
+    //
+    //   packedResultDims: [0, 2, 6, 7]
+    //
+    // 3. Perform the swizzle permutation:
+    //
+    //   packedResultDims: [0, 2, 7, 6]
     MaterializeEncodingInfo materializeEncodingInfo =
         typeConverter.getEncodingInfo(
             cast<RankedTensorType>(inputOperand->get().getType()));
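
As a concrete check of the Step 1 bookkeeping in the hunk above, here is a minimal standalone sketch in plain C++ (std::vector instead of llvm::SmallVector; `applyPermutation` and `invertPermutation` are local stand-ins written for this example, not the MLIR utilities the patch calls) that reproduces the running-example values for outSwizzlePerm, invOutSwizzlePerm and outOffsetForDimsPos.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Local stand-in for mlir::applyPermutationToVector: result[i] = v[perm[i]].
static std::vector<int64_t> applyPermutation(const std::vector<int64_t> &v,
                                             const std::vector<int64_t> &perm) {
  std::vector<int64_t> result(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    result[i] = v[perm[i]];
  return result;
}

// Local stand-in for mlir::invertPermutationVector: inv[perm[i]] = i.
static std::vector<int64_t> invertPermutation(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  // Output layout of the running example: tensor<2x128x64xf32>, packed to
  // rank 8 (3 outer dims + 5 swizzle-expanded inner dims).
  const int64_t outRank = 3;
  const int64_t packedRank = 8;
  const std::vector<std::vector<int64_t>> expandShape = {{4, 8, 4}, {4, 4}};
  const std::vector<int64_t> swizzlePermutation = {1, 4, 0, 2, 3};

  // outSwizzlePerm: identity over the outer dims, swizzle permutation over
  // the expanded inner dims (offset by outRank).
  std::vector<int64_t> outSwizzlePerm(packedRank);
  std::iota(outSwizzlePerm.begin(), outSwizzlePerm.end(), int64_t{0});
  std::vector<int64_t> transposePerm(outRank);
  std::iota(transposePerm.begin(), transposePerm.end(), int64_t{0});
  for (int64_t p : swizzlePermutation)
    transposePerm.push_back(outRank + p);
  outSwizzlePerm = applyPermutation(outSwizzlePerm, transposePerm);
  std::vector<int64_t> invOutSwizzlePerm = invertPermutation(outSwizzlePerm);

  // Running offset into the expanded shape, one entry per innerDimsPos entry.
  std::vector<int64_t> outOffsetForDimsPos;
  int64_t runningSize = 0;
  for (const auto &group : expandShape) {
    outOffsetForDimsPos.push_back(runningSize);
    runningSize += static_cast<int64_t>(group.size());
  }

  assert((outSwizzlePerm == std::vector<int64_t>{0, 1, 2, 4, 7, 3, 5, 6}));
  assert((invOutSwizzlePerm == std::vector<int64_t>{0, 1, 2, 5, 3, 6, 7, 4}));
  assert((outOffsetForDimsPos == std::vector<int64_t>{0, 3}));
  std::cout << "Step 1 running-example values reproduced\n";
  return 0;
}

Compiled with any C++17 compiler, the asserts reproduce exactly the vectors quoted in the Step 1 comment block.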
@@ -277,14 +367,72 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
     for (auto [idx, pos] : llvm::enumerate(innerDimsPos)) {
       auto dimPos = cast<AffineDimExpr>(inputMap.getResult(pos)).getPosition();
       for (auto [tileIdx, outDim] : llvm::enumerate(outInnerDimsPos)) {
-        if (dimPos == outDim) {
+        if (dimPos != outDim) {
+          continue;
+        }
+        if (!materializeEncodingInfo.swizzle.has_value()) {
           packedResultDims.push_back(outputMap.getNumDims() + tileIdx);
+          continue;
         }
+        // In case of a layout with swizzle, an expanded set of dimensions
+        // needs to be appended as specified by the swizzle's `expandShape`
+        // field. Note that the dimension index should be offset by the
+        // calculated output starting offset as every dimension is now
+        // transformed into an expanded sequence of indices and the correct
+        // dimension index is:
+        //   outOffsetForDimsPos[tileIdx] + innerIndex
+        assert(idx < materializeEncodingInfo.swizzle->expandShape.size() &&
+               "`innerDimsPos` index should not exceed the swizzle's "
+               "`expandShape` size");
+        const size_t dimSize =
+            materializeEncodingInfo.swizzle->expandShape[idx].size();
+        const int64_t outIdxOffset =
+            outputMap.getNumDims() + outOffsetForDimsPos[tileIdx];
+        for (size_t i = 0; i < dimSize; i++) {
+          packedResultDims.push_back(outIdxOffset + i);
+        }
+      }
+    }
+    // In case of a layout with swizzle, the packed result dimensions need
+    // to be transposed according to the swizzle's permutation vector.
+    if (materializeEncodingInfo.swizzle.has_value()) {
+      int inRank =
+          cast<RankedTensorType>(inputOperand->get().getType()).getRank();
+      SmallVector<int64_t> transposePerm =
+          llvm::to_vector(llvm::seq<int64_t>(0, inRank));
+      for (auto perm : materializeEncodingInfo.swizzle->permutation) {
+        transposePerm.push_back(inRank + perm);
       }
+      applyPermutationToVector(packedResultDims, transposePerm);
     }
+
+    // Step 3: Calculate the final packed result dimensions through the inverse
+    // result dimensions permutation map. This effectively linearizes the packed
+    // result dimensions with respect to the output dimensions. For example, if
+    // the permuted output dimensions are [D0, D2, D1], this will transform all
+    // packed operand result dimensions with the permutation map that would make
+    // the output dimensions the identity map [D0, D1, D2], i.e. {D0 -> D0,
+    // D1 -> D2, D2 -> D1}. If the operand dimensions are [D0, D2], this
+    // operation would transform them into [D0, D1] to align with the output
+    // identity map.
+    //
+    // Running example:
+    //
+    // The packed and swizzled result dimensions for the input operand:
+    //
+    //   packedResultDims: [0, 2, 7, 6]
+    //
+    // Now we need to account for swizzled output result dimensions being
+    // linearized to the identity map. This can be achieved by applying
+    // `invOutSwizzlePerm` ([0, 1, 2, 5, 3, 6, 7, 4]):
+    //
+    //   finalPackedResultDims: [0, 2, 4, 7]
+    SmallVector<int64_t> finalPackedResultDims = llvm::map_to_vector(
+        packedResultDims, [&](int64_t r) { return invOutSwizzlePerm[r]; });
+
     // Create the packed indexing map.
     SmallVector<AffineExpr> packedResultExprs =
-        llvm::map_to_vector(packedResultDims, [&](int64_t dim) {
+        llvm::map_to_vector(finalPackedResultDims, [&](int64_t dim) {
          return rewriter.getAffineDimExpr(dim);
        });
     auto packedInputMap = AffineMap::get(
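
To follow Steps 2 and 3 end to end, the sketch below (same assumptions as the previous snippet; the Step 1 results are hard-coded constants and the loops are simplified stand-ins for the pattern code) traces the broadcast scales operand of the running example, tensor<2x64xf32> with input map (d0, d1, d2) -> (d0, d2), through the inner tile expansion, the input swizzle permutation, and the final mapping through invOutSwizzlePerm.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Step 1 results of the running example (see the previous snippet).
  const int64_t outputNumDims = 3;  // output map is identity over (d0, d1, d2)
  const std::vector<int64_t> outInnerDimsPos = {1, 2};
  const std::vector<int64_t> outOffsetForDimsPos = {0, 3};
  const std::vector<int64_t> invOutSwizzlePerm = {0, 1, 2, 5, 3, 6, 7, 4};

  // Broadcast scales operand: tensor<2x64xf32>, inputMap (d0, d1, d2) -> (d0, d2),
  // innerDimsPos = [1], swizzle expandShape = [[4, 4]], permutation = [1, 0].
  const std::vector<int64_t> inputResultDims = {0, 2};  // dims used by the input map
  const std::vector<int64_t> innerDimsPos = {1};
  const std::vector<std::vector<int64_t>> expandShape = {{4, 4}};
  const std::vector<int64_t> swizzlePermutation = {1, 0};
  const int64_t inRank = 2;

  // Step 2a: outer dims. outerDimsPerm is the identity here, so the packed
  // result dims start out as the input map's result dims.
  std::vector<int64_t> packedResultDims = inputResultDims;

  // Step 2b: append the swizzle-expanded inner tile dims. innerDimsPos[0] = 1
  // refers to d2 of the input map, which matches outInnerDimsPos[1], so the
  // appended dims start at outputNumDims + outOffsetForDimsPos[1] = 6.
  for (size_t idx = 0; idx < innerDimsPos.size(); ++idx) {
    const int64_t dimPos = inputResultDims[innerDimsPos[idx]];
    for (size_t tileIdx = 0; tileIdx < outInnerDimsPos.size(); ++tileIdx) {
      if (dimPos != outInnerDimsPos[tileIdx])
        continue;
      const int64_t outIdxOffset = outputNumDims + outOffsetForDimsPos[tileIdx];
      for (size_t i = 0; i < expandShape[idx].size(); ++i)
        packedResultDims.push_back(outIdxOffset + static_cast<int64_t>(i));
    }
  }
  assert((packedResultDims == std::vector<int64_t>{0, 2, 6, 7}));

  // Step 2c: apply the input swizzle permutation to the trailing dims.
  std::vector<int64_t> transposePerm = {0, 1};
  for (int64_t p : swizzlePermutation)
    transposePerm.push_back(inRank + p);  // -> [0, 1, 3, 2]
  std::vector<int64_t> permuted(packedResultDims.size());
  for (size_t i = 0; i < transposePerm.size(); ++i)
    permuted[i] = packedResultDims[transposePerm[i]];
  packedResultDims = permuted;
  assert((packedResultDims == std::vector<int64_t>{0, 2, 7, 6}));

  // Step 3: map every dim through the inverse output swizzle permutation.
  std::vector<int64_t> finalPackedResultDims;
  for (int64_t r : packedResultDims)
    finalPackedResultDims.push_back(invOutSwizzlePerm[r]);
  assert((finalPackedResultDims == std::vector<int64_t>{0, 2, 4, 7}));
  std::cout << "packed indexing map results: (d0, d2, d4, d7)\n";
  return 0;
}

The final dimensions [0, 2, 4, 7] correspond to the map #[[$MAP1]] = (d0, ..., d7) -> (d0, d2, d4, d7) checked in the new test case below.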

compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir

+54
@@ -1260,3 +1260,57 @@ func.func @missing_user_indexing_maps() {
 // CHECK-DAG: %[[STORE_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
 // CHECK-DAG: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[LOAD_BINDING]]{{.+}} -> tensor<255x513xf32>
 // CHECK-DAG: flow.dispatch.tensor.store %[[LOAD]], %[[STORE_BINDING]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [[affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2) -> (d0, d2)>], affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]>
+func.func @dequantization() {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x128x64xi8, #encoding>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+  %6 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  %7 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x128x64xi8, #encoding>> -> tensor<2x128x64xi8, #encoding>
+  %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+  %9 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+  %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %8, %9 : tensor<2x128x64xi8, #encoding>, tensor<2x64xf32, #encoding_bcast>, tensor<2x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+  ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
+    %21 = arith.extui %in : i8 to i32
+    %22 = arith.uitofp %21 : i32 to f32
+    %23 = arith.subf %22, %in_1 : f32
+    %24 = arith.mulf %23, %in_0 : f32
+    linalg.yield %24 : f32
+  } -> tensor<2x128x64xf32, #encoding>
+  flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  return
+}
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d2, d4, d7)>
+// CHECK-LABEL: func.func @dequantization()
+// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x1x4x8x4x4x4x4xi8>>
+// CHECK-DAG: %[[LHS_SCALES_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x4x4xf32>>
+// CHECK-DAG: %[[LHS_ZPS_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(2) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x4x4xf32>>
+// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(3) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x1x4x8x4x4x4x4xf32>>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]], offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [2, 1, 4, 8, 4, 4, 4, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x1x4x8x4x4x4x4xi8>> -> tensor<2x1x4x8x4x4x4x4xi8>
+// CHECK-DAG: %[[LHS_SCALES:.+]] = flow.dispatch.tensor.load %[[LHS_SCALES_BINDING]], offsets = [0, 0, 0, 0], sizes = [2, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x4x4xf32>> -> tensor<2x4x4x4xf32>
+// CHECK-DAG: %[[LHS_ZPS:.+]] = flow.dispatch.tensor.load %[[LHS_ZPS_BINDING]], offsets = [0, 0, 0, 0], sizes = [2, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x4x4xf32>> -> tensor<2x4x4x4xf32>
+// CHECK-DAG: %[[EMPTY_LHS:.+]] = tensor.empty() : tensor<2x1x4x8x4x4x4x4xf32>
+// CHECK-DAG: %[[LHS_DEQUANT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP1]], #[[$MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[LHS]], %[[LHS_SCALES]], %[[LHS_ZPS]] : tensor<2x1x4x8x4x4x4x4xi8>, tensor<2x4x4x4xf32>, tensor<2x4x4x4xf32>)
+// CHECK-SAME: outs(%[[EMPTY_LHS]] : tensor<2x1x4x8x4x4x4x4xf32>)
+// CHECK: arith.extui
+// CHECK: arith.uitofp
+// CHECK: arith.subf
+// CHECK: arith.mulf
+// CHECK: flow.dispatch.tensor.store %[[LHS_DEQUANT]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [2, 1, 4, 8, 4, 4, 4, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : tensor<2x1x4x8x4x4x4x4xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x1x4x8x4x4x4x4xf32>>
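
For reference, the static shapes in the CHECK lines follow mechanically from the pack and swizzle parameters of the running example. The sketch below rederives them; `packShape` and `swizzleShape` are hypothetical helpers written only for this illustration (outerDimsPerm is identity here, so it is ignored), and the tile sizes and swizzle are taken from the example rather than queried from the GFX942 data-tiling layout logic.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: split every innerDimsPos dim by its tile size and
// append the tile sizes as trailing dims (outerDimsPerm assumed identity).
static std::vector<int64_t> packShape(const std::vector<int64_t> &shape,
                                      const std::vector<int64_t> &innerDimsPos,
                                      const std::vector<int64_t> &innerTileSizes) {
  std::vector<int64_t> packed = shape;
  for (size_t i = 0; i < innerDimsPos.size(); ++i)
    packed[innerDimsPos[i]] = shape[innerDimsPos[i]] / innerTileSizes[i];
  for (int64_t tile : innerTileSizes)
    packed.push_back(tile);
  return packed;
}

// Hypothetical helper: replace the trailing `numTiles` tile dims by their
// swizzle expansion, then permute the expanded dims.
static std::vector<int64_t>
swizzleShape(const std::vector<int64_t> &packed, int64_t numTiles,
             const std::vector<std::vector<int64_t>> &expandShape,
             const std::vector<int64_t> &permutation) {
  std::vector<int64_t> result(packed.begin(), packed.end() - numTiles);
  std::vector<int64_t> expanded;
  for (const auto &group : expandShape)
    expanded.insert(expanded.end(), group.begin(), group.end());
  for (int64_t p : permutation)
    result.push_back(expanded[p]);
  return result;
}

int main() {
  // Output and quantized input operand: tensor<2x128x64>.
  auto out = packShape({2, 128, 64}, /*innerDimsPos=*/{1, 2},
                       /*innerTileSizes=*/{128, 16});  // -> 2x1x4x128x16
  out = swizzleShape(out, /*numTiles=*/2, {{4, 8, 4}, {4, 4}},
                     /*permutation=*/{1, 4, 0, 2, 3});
  assert((out == std::vector<int64_t>{2, 1, 4, 8, 4, 4, 4, 4}));

  // Broadcast scales / zero-point operands: tensor<2x64>.
  auto scales = packShape({2, 64}, /*innerDimsPos=*/{1}, /*innerTileSizes=*/{16});
  scales = swizzleShape(scales, /*numTiles=*/1, {{4, 4}}, /*permutation=*/{1, 0});
  assert((scales == std::vector<int64_t>{2, 4, 4, 4}));

  std::cout << "materialized shapes: 2x1x4x8x4x4x4x4 and 2x4x4x4\n";
  return 0;
}

Both asserts match the binding tensor types in the CHECK-DAG lines above.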
