Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoweringStrategy] Use a more general method to fetch input dims and sizes #1090

Merged
merged 3 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,57 @@ FailureOr<std::array<uint32_t, 3>> getPackedSize(linalg::LinalgOp linalgOp,
return instructionSize;
}

struct InputDimsAndSizes {
SmallVector<unsigned, 2> batchDims;
SmallVector<unsigned, 2> mDims;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personal preference: Use SmallVector<T> instead of SmallVector<T, 2>.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to use the same type as in upstream. In addition, it doesn't seem to work to use SmallVector when calling

SmallVector<unsigned> mDims = contractionDims.m;

It has error: no viable conversion from 'SmallVector<[...], 2>' to 'SmallVector<[...], (default) 12>'

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I didn't realize this is the way upstream does it, that's fine then.

FWIW you could do

SmallVector mDims {contractionDims.m.begin(), contractionDims.m.end()};

SmallVector<unsigned, 2> nDims;
SmallVector<unsigned, 2> kDims;
SmallVector<int64_t, 2> batchSizes;
SmallVector<int64_t, 2> mSizes;
SmallVector<int64_t, 2> nSizes;
SmallVector<int64_t, 2> kSizes;
};

// Infers the contraction dimensions (batch/M/N/K) of `linalgOp` and collects,
// for each group, both the loop indices and the corresponding static loop
// sizes. Emits an op error and fails if the contraction dimensions cannot be
// inferred.
FailureOr<InputDimsAndSizes> getInputDimsAndSizes(linalg::LinalgOp linalgOp) {
  FailureOr<linalg::ContractionDimensions> maybeContractionDims =
      linalg::inferContractionDims(linalgOp);
  if (failed(maybeContractionDims)) {
    return linalgOp.emitOpError("failed to infer the contraction dimensions.");
  }

  linalg::ContractionDimensions contractionDims = *maybeContractionDims;

  InputDimsAndSizes result;
  result.batchDims = contractionDims.batch;
  result.mDims = contractionDims.m;
  result.nDims = contractionDims.n;
  result.kDims = contractionDims.k;

  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  // Every loop of the op must be classified as exactly one of batch/m/n/k;
  // `[[maybe_unused]]` keeps NDEBUG builds warning-free when the assert
  // compiles away.
  [[maybe_unused]] size_t totalNumDims =
      result.batchDims.size() + result.mDims.size() + result.nDims.size() +
      result.kDims.size();
  assert(totalNumDims == loopRanges.size() &&
         ("the total number of dims " + std::to_string(totalNumDims) +
          " is not the same as the number of loops " +
          std::to_string(loopRanges.size()) + ".")
             .c_str());

  // Maps a list of loop indices to their static loop sizes.
  auto gatherSizes = [&loopRanges](ArrayRef<unsigned> indices) {
    SmallVector<int64_t, 2> sizes;
    sizes.reserve(indices.size());
    for (unsigned index : indices) sizes.push_back(loopRanges[index]);
    return sizes;
  };

  result.batchSizes = gatherSizes(result.batchDims);
  result.mSizes = gatherSizes(result.mDims);
  result.nSizes = gatherSizes(result.nDims);
  result.kSizes = gatherSizes(result.kDims);
  return result;
}

// Container class for the tiling at level 0 (the AIE shared memory) and level 1
// (the AIE core) in the M-, N-, and K-dimensions of a matmul operation, using
// the pad-pack approach to tiling a matmul. Also contains the packing sizes for
Expand Down Expand Up @@ -156,25 +207,24 @@ FailureOr<ParameterSetting> ParameterSetting::create(
auto initType =
llvm::cast<ShapedType>(linalgOp.getDpsInitOperand(0)->get().getType());
unsigned nBitsInit = initType.getElementTypeBitWidth();
ArrayRef<int64_t> initShape = initType.getShape();

auto lhsType =
llvm::cast<ShapedType>(linalgOp.getDpsInputOperand(0)->get().getType());
unsigned nBitsLhs = lhsType.getElementTypeBitWidth();
ArrayRef<int64_t> lhsShape = lhsType.getShape();

auto rhsType =
llvm::cast<ShapedType>(linalgOp.getDpsInputOperand(1)->get().getType());
unsigned nBitsRhs = rhsType.getElementTypeBitWidth();

// Shape of the full matmul operation.
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
initShape = initShape.drop_front();
lhsShape = lhsShape.drop_front();
}
const uint64_t M = initShape[0];
const uint64_t N = initShape[1];
const uint64_t K = lhsShape[1];
auto getTotalSize = [](ArrayRef<int64_t> sizes) {
return std::accumulate(sizes.begin(), sizes.end(), 1,
std::multiplies<int64_t>());
};

// Get the shape (M, N, K) of the full Matmul operation.
auto maybeInputDimsAndSizes = getInputDimsAndSizes(linalgOp);
yzhang93 marked this conversation as resolved.
Show resolved Hide resolved
if (failed(maybeInputDimsAndSizes)) return failure();
int64_t M = getTotalSize(maybeInputDimsAndSizes.value().mSizes);
int64_t N = getTotalSize(maybeInputDimsAndSizes.value().nSizes);
int64_t K = getTotalSize(maybeInputDimsAndSizes.value().kSizes);

// If we are conservative with ensuring that tiles A, B, and C fit at the
// different memory levels, we should choose the scale factor based
Expand Down Expand Up @@ -390,25 +440,39 @@ static SmallVector<int64_t> setOuterPermB(bool isMatmulTransposeB,
static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp,
AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols) {
// Scale the L1 K with a factor of 2 compared with the outer dimenions M and N
// to increase the L1 memory usage.
// Scale the L1 K with a factor of 2 compared with the outer dimensions M and
// N to increase the L1 memory usage.
auto maybePackPeelTiling = ParameterSetting::create(
linalgOp, /*isPackPeel=*/true, /*isObjectFifo=*/true, targetDevice,
numRows, numCols, /*kPackScaleL1=*/2);
if (failed(maybePackPeelTiling)) return failure();
auto packPeelTiling = maybePackPeelTiling.value();

// Get M, N, K dimension indices from the input indexing map.
FailureOr<InputDimsAndSizes> maybeInputDimsAndSizes =
getInputDimsAndSizes(linalgOp);
if (failed(maybeInputDimsAndSizes)) return failure();
SmallVector<unsigned, 2> batchDims = maybeInputDimsAndSizes.value().batchDims;
SmallVector<unsigned, 2> mDims = maybeInputDimsAndSizes.value().mDims;
SmallVector<unsigned, 2> nDims = maybeInputDimsAndSizes.value().nDims;
SmallVector<unsigned, 2> kDims = maybeInputDimsAndSizes.value().kDims;
if (mDims.empty() || nDims.empty() || kDims.empty()) {
return linalgOp.emitOpError("failed to fetch m/n/k dims.");
}

AMDAIEDeviceModel deviceModel = getDeviceModel(targetDevice);

// ------------------------------------------------------
// --------------- Set packing config -------------------
// ------------------------------------------------------
MLIRContext *context = entryPointFn.getContext();
unsigned numLoops = linalgOp.getNumLoops();

SmallVector<int64_t> packedSizesL0 = packPeelTiling.getPackSizeL0();
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
packedSizesL0.insert(packedSizesL0.begin(), 0);
}
// Pack level => 1.
SmallVector<int64_t> packedSizesL0(numLoops, 0);
packedSizesL0[mDims.back()] = packPeelTiling.m0Pack;
packedSizesL0[nDims.back()] = packPeelTiling.n0Pack;
packedSizesL0[kDims.back()] = packPeelTiling.k0Pack;

// For matmul, transpose B matrix from [K N n k] to [N K k n]
// For matmul_transpose_b, we don't have to transpose the B matrix,
Expand Down Expand Up @@ -440,17 +504,11 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
outerPerm);

// Pack level => 2.
// packed size for [M, N, K, m, n, k]
SmallVector<int64_t> packedSizesL1 = {0,
0,
0,
packPeelTiling.m1Pack,
packPeelTiling.n1Pack,
packPeelTiling.k1Pack};

if (isa<linalg::BatchMatmulOp>(linalgOp)) {
packedSizesL1.insert(packedSizesL1.begin(), 0);
}
// The number of loops has increased by 3 due to the first level pack.
SmallVector<int64_t> packedSizesL1(numLoops + 3, 0);
packedSizesL1[mDims.back() + 3] = packPeelTiling.m1Pack;
packedSizesL1[nDims.back() + 3] = packPeelTiling.n1Pack;
packedSizesL1[kDims.back() + 3] = packPeelTiling.k1Pack;

// Transpose A matrix from [M K m k m0 k0] to [M K k m m0 k0]
// Transpose C matrix from [M N m n m0 n0] to [M N n m m0 n0]
Expand Down Expand Up @@ -492,18 +550,25 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
bool fitsInL2 = (l2SizeA + l2SizeB + l2SizeInit) <
(deviceModel.getMemTileSizeInBytes() * numCols);
int64_t scaleL0 = !isBatchMatmul && fitsInL2 ? 2 : 1;
SmallVector<int64_t> tileSizeLevel0 = {packPeelTiling.M0 * scaleL0,
packPeelTiling.N0 * scaleL0};
SmallVector<int64_t> tileSizeLevel1 = {numRows, numCols, 0};
SmallVector<int64_t> tileSizeLevel2 = {0, 0, 1};
SmallVector<int64_t> tileSizeLevel3 = {1, 1, 0, 0, 0, 0};

SmallVector<int64_t> tileSizeLevel0(numLoops, 0);
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
tileSizeLevel0.insert(tileSizeLevel0.begin(), 1);
tileSizeLevel1.insert(tileSizeLevel1.begin(), 0);
tileSizeLevel2.insert(tileSizeLevel2.begin(), 0);
tileSizeLevel3.insert(tileSizeLevel3.begin(), 0);
assert(!batchDims.empty() && "expected batch dims not empty");
tileSizeLevel0[batchDims[0]] = 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could assert that batchDims is not empty, like the others.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
tileSizeLevel0[mDims[0]] = packPeelTiling.M0 * scaleL0;
tileSizeLevel0[nDims[0]] = packPeelTiling.N0 * scaleL0;

SmallVector<int64_t> tileSizeLevel1(numLoops, 0);
tileSizeLevel1[mDims[0]] = numRows;
tileSizeLevel1[nDims[0]] = numCols;

SmallVector<int64_t> tileSizeLevel2(numLoops, 0);
tileSizeLevel2[kDims[0]] = 1;

SmallVector<int64_t> tileSizeLevel3(numLoops, 0);
tileSizeLevel3[mDims[0]] = 1;
tileSizeLevel3[nDims[0]] = 1;

TileSizesListType tileSizes = {tileSizeLevel0, tileSizeLevel1, tileSizeLevel2,
tileSizeLevel3};
Expand All @@ -527,15 +592,29 @@ static LogicalResult setRootConfigForPackPeelPipeline(
if (failed(maybePackPeelTiling)) return failure();
auto packPeelTiling = maybePackPeelTiling.value();

// Get M, N, K dimension indices from the input indexing map.
FailureOr<InputDimsAndSizes> maybeInputDimsAndSizes =
getInputDimsAndSizes(linalgOp);
if (failed(maybeInputDimsAndSizes)) return failure();
SmallVector<unsigned, 2> batchDims = maybeInputDimsAndSizes.value().batchDims;
SmallVector<unsigned, 2> mDims = maybeInputDimsAndSizes.value().mDims;
SmallVector<unsigned, 2> nDims = maybeInputDimsAndSizes.value().nDims;
SmallVector<unsigned, 2> kDims = maybeInputDimsAndSizes.value().kDims;
if (mDims.empty() || nDims.empty() || kDims.empty()) {
return linalgOp.emitOpError("failed to fetch m/n/k dims.");
}

// ------------------------------------------------------
// --------------- Set packing config -------------------
// ------------------------------------------------------
MLIRContext *context = entryPointFn.getContext();
unsigned numLoops = linalgOp.getNumLoops();

SmallVector<int64_t> packedSizesL0 = packPeelTiling.getPackSizeL0();
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
packedSizesL0.insert(packedSizesL0.begin(), 0);
}
// Pack level => 1.
SmallVector<int64_t> packedSizesL0(numLoops, 0);
packedSizesL0[mDims.back()] = packPeelTiling.m0Pack;
packedSizesL0[nDims.back()] = packPeelTiling.n0Pack;
packedSizesL0[kDims.back()] = packPeelTiling.k0Pack;

// For matmul, transpose B matrix from [K N n k] to [N K k n]
// For matmul_transpose_b, we don't have to transpose the B matrix,
Expand Down Expand Up @@ -571,17 +650,11 @@ static LogicalResult setRootConfigForPackPeelPipeline(
outerPerm);

// Pack level => 2.
// packed size for [M, N, K, m, n, k]
SmallVector<int64_t> packedSizesL1 = {0,
0,
0,
packPeelTiling.m1Pack,
packPeelTiling.n1Pack,
packPeelTiling.k1Pack};

if (isa<linalg::BatchMatmulOp>(linalgOp)) {
packedSizesL1.insert(packedSizesL1.begin(), 0);
}
// The number of loops has increased by 3 due to the first level pack.
SmallVector<int64_t> packedSizesL1(numLoops + 3, 0);
packedSizesL1[mDims.back() + 3] = packPeelTiling.m1Pack;
packedSizesL1[nDims.back() + 3] = packPeelTiling.n1Pack;
packedSizesL1[kDims.back() + 3] = packPeelTiling.k1Pack;

// Transpose A matrix from [M K m k m0 k0] to [M K k m m0 k0]
// Transpose C matrix from [M N m n m0 n0] to [M N n m m0 n0]
Expand Down Expand Up @@ -611,15 +684,20 @@ static LogicalResult setRootConfigForPackPeelPipeline(
// ------------------------------------------------------
// -------------- Set lowering config -------------------
// ------------------------------------------------------
SmallVector<int64_t> tileSizeLevel0 = {packPeelTiling.M0, packPeelTiling.N0};
SmallVector<int64_t> tileSizeLevel1 = {0, 0, packPeelTiling.K0};
SmallVector<int64_t> tileSizeLevel2 = {1, 1, 0, 0, 0, 0};

SmallVector<int64_t> tileSizeLevel0(numLoops, 0);
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
tileSizeLevel0.insert(tileSizeLevel0.begin(), 1);
tileSizeLevel1.insert(tileSizeLevel1.begin(), 0);
tileSizeLevel2.insert(tileSizeLevel2.begin(), 0);
assert(!batchDims.empty() && "expected batch dims not empty");
tileSizeLevel0[batchDims[0]] = 1;
}
tileSizeLevel0[mDims[0]] = packPeelTiling.M0;
tileSizeLevel0[nDims[0]] = packPeelTiling.N0;

SmallVector<int64_t> tileSizeLevel1(numLoops, 0);
tileSizeLevel1[kDims[0]] = 1;

SmallVector<int64_t> tileSizeLevel2(numLoops, 0);
tileSizeLevel2[mDims[0]] = 1;
tileSizeLevel2[nDims[0]] = 1;

TileSizesListType tileSizes = {tileSizeLevel0, tileSizeLevel1,
tileSizeLevel2};
Expand Down Expand Up @@ -874,16 +952,6 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
assert(!getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(contractionOp) &&
"expected lowering_config is not set");
auto linalgOp = cast<linalg::LinalgOp>(contractionOp.getOperation());
unsigned numLoops = linalgOp.getNumLoops();
{
SmallVector<unsigned> dims;
linalgOp.getReductionDims(dims);
if (dims.size() != 1 || dims[0] != numLoops - 1) {
return linalgOp.emitOpError(
"is expected to have exactly one reduction dim, ")
<< "and that it is the innermost dim (" << numLoops - 1 << ").";
}
}

// TODO (nmeshram) : This needs to be moved in a separate more generalized
// logic. Also, need a flag to experiment between pad based and pack based
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ builtin.module {

// -----

// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0], [0, 0, 1], [1, 1, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [128, 128, 128], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [1, 0]], outerPerm = [[0, 1], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 8, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
Expand All @@ -216,7 +216,7 @@ builtin.module {

// -----

// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[44, 128, 0], [0, 0, 1], [1, 1, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [44, 32, 64], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [1, 0]], outerPerm = [[0, 1], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
Expand Down Expand Up @@ -244,7 +244,7 @@ module {

// CHECK-PAD-PACK{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [0, 0, 256], [32, 32], [0, 0, 4]]>
// CHECK-PAD-PACK{LITERAL}: #packingConfig = #amdaie.packing_config<packing_config = [{packedSizes = [4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [0, 1], [0, 1]], outerPerm = [[1, 0], [1, 0], [1, 0]]}]>
// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128, 0], [0, 0, 1], [1, 1, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [0, 1]], outerPerm = [[0, 1], [0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [0, 1], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
Expand Down
Loading
Loading