Address comments
yzhang93 committed Feb 10, 2025
1 parent 9797d8c commit a33465c
Showing 1 changed file with 32 additions and 23 deletions.
@@ -85,9 +85,11 @@ FailureOr<std::array<uint32_t, 3>> getPackedSize(linalg::LinalgOp linalgOp,
 }
 
 struct InputDimsAndSizes {
+  SmallVector<unsigned, 2> batchDims;
   SmallVector<unsigned, 2> mDims;
   SmallVector<unsigned, 2> nDims;
   SmallVector<unsigned, 2> kDims;
+  SmallVector<int64_t, 2> batchSizes;
   SmallVector<int64_t, 2> mSizes;
   SmallVector<int64_t, 2> nSizes;
   SmallVector<int64_t, 2> kSizes;
@@ -101,6 +103,7 @@ FailureOr<InputDimsAndSizes> getInputDimsAndSizes(linalg::LinalgOp linalgOp) {
   }
 
   linalg::ContractionDimensions contractionDims = *maybeContractionDims;
+  SmallVector<unsigned, 2> batchDims = contractionDims.batch;
   SmallVector<unsigned, 2> mDims = contractionDims.m;
   SmallVector<unsigned, 2> nDims = contractionDims.n;
   SmallVector<unsigned, 2> kDims = contractionDims.k;
@@ -109,21 +112,26 @@ FailureOr<InputDimsAndSizes> getInputDimsAndSizes(linalg::LinalgOp linalgOp) {
   }
 
   SmallVector<int64_t> shapes = linalgOp.getStaticLoopRanges();
-  if (mDims.size() + nDims.size() + kDims.size() > shapes.size()) {
-    return linalgOp.emitOpError(
-        "the total of m/n/k dims is larger than the number of loops.");
-  }
+  size_t totalNumDims =
+      batchDims.size() + mDims.size() + nDims.size() + kDims.size();
+  assert(totalNumDims == shapes.size() &&
+         ("the total number of dims " + std::to_string(totalNumDims) +
+          " is not the same as the number of loops " +
+          std::to_string(shapes.size()) + ".")
+             .c_str());
 
-  auto getSizesAt = [&shapes](const SmallVector<unsigned, 2> &idx) {
+  auto getSizesAt = [&shapes](ArrayRef<unsigned> idx) {
     SmallVector<int64_t, 2> sizes;
-    for (auto i : idx) sizes.push_back(shapes[i]);
+    for (unsigned i : idx) sizes.push_back(shapes[i]);
     return sizes;
   };
 
   InputDimsAndSizes inputDimsAndSizes;
+  inputDimsAndSizes.batchDims = batchDims;
   inputDimsAndSizes.mDims = mDims;
   inputDimsAndSizes.nDims = nDims;
   inputDimsAndSizes.kDims = kDims;
+  inputDimsAndSizes.batchSizes = getSizesAt(batchDims);
   inputDimsAndSizes.mSizes = getSizesAt(mDims);
   inputDimsAndSizes.nSizes = getSizesAt(nDims);
   inputDimsAndSizes.kSizes = getSizesAt(kDims);
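
To make the extended helper concrete: a minimal standalone sketch of the dims-to-sizes mapping it now performs, using plain C++ stand-ins for the MLIR types (ContractionDims mimics linalg::ContractionDimensions, and the loop ranges [B, M, N, K] = [4, 64, 128, 256] are assumed for illustration):

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for linalg::ContractionDimensions: for a linalg.batch_matmul
// with loops [B, M, N, K], batch = {0}, m = {1}, n = {2}, k = {3}.
struct ContractionDims {
  std::vector<unsigned> batch, m, n, k;
};

int main() {
  std::vector<int64_t> shapes = {4, 64, 128, 256};  // getStaticLoopRanges()
  ContractionDims dims{{0}, {1}, {2}, {3}};

  // Mirrors the patch: every loop must be claimed by exactly one group.
  size_t totalNumDims =
      dims.batch.size() + dims.m.size() + dims.n.size() + dims.k.size();
  assert(totalNumDims == shapes.size() && "dims do not cover all loops");

  auto getSizesAt = [&shapes](const std::vector<unsigned> &idx) {
    std::vector<int64_t> sizes;
    for (unsigned i : idx) sizes.push_back(shapes[i]);
    return sizes;
  };

  // batchSizes == {4}, mSizes == {64}, nSizes == {128}, kSizes == {256}.
  auto batchSizes = getSizesAt(dims.batch);
  auto kSizes = getSizesAt(dims.k);
  return (batchSizes[0] == 4 && kSizes[0] == 256) ? 0 : 1;
}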
@@ -209,7 +217,7 @@ FailureOr<ParameterSetting> ParameterSetting::create(
       llvm::cast<ShapedType>(linalgOp.getDpsInputOperand(1)->get().getType());
   unsigned nBitsRhs = rhsType.getElementTypeBitWidth();
 
-  auto getTotalSize = [](const SmallVector<int64_t, 2> &sizes) {
+  auto getTotalSize = [](ArrayRef<int64_t> sizes) {
     return std::accumulate(sizes.begin(), sizes.end(), 1,
                            std::multiplies<int64_t>());
   };
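
One detail worth flagging in getTotalSize: std::accumulate deduces its accumulator type from the initial value, so the literal 1 keeps the running product in int even though the functor is std::multiplies<int64_t>. A hedged sketch of the overflow-safe spelling, with assumed sizes whose product exceeds INT32_MAX:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> sizes = {4096, 4096, 512};  // assumed; product > 2^31

  // Passing int64_t{1} (not 1) keeps the accumulator 64-bit throughout.
  int64_t total = std::accumulate(sizes.begin(), sizes.end(), int64_t{1},
                                  std::multiplies<int64_t>());
  return total == int64_t{4096} * 4096 * 512 ? 0 : 1;
}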
@@ -434,8 +442,7 @@ static SmallVector<int64_t> setOuterPermB(bool isMatmulTransposeB,
 
 static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
     mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp,
-    AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols,
-    uint32_t numLoops) {
+    AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols) {
   // Scale the L1 K with a factor of 2 compared with the outer dimensions M and
   // N to increase the L1 memory usage.
   auto maybePackPeelTiling = ParameterSetting::create(
@@ -445,8 +452,10 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
   auto packPeelTiling = maybePackPeelTiling.value();
 
   // Get M, N, K dimension indices from the input indexing map.
-  auto maybeInputDimsAndSizes = getInputDimsAndSizes(linalgOp);
+  FailureOr<InputDimsAndSizes> maybeInputDimsAndSizes =
+      getInputDimsAndSizes(linalgOp);
   if (failed(maybeInputDimsAndSizes)) return failure();
+  SmallVector<unsigned, 2> batchDims = maybeInputDimsAndSizes.value().batchDims;
   SmallVector<unsigned, 2> mDims = maybeInputDimsAndSizes.value().mDims;
   SmallVector<unsigned, 2> nDims = maybeInputDimsAndSizes.value().nDims;
   SmallVector<unsigned, 2> kDims = maybeInputDimsAndSizes.value().kDims;
@@ -457,6 +466,7 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
   // --------------- Set packing config -------------------
   // ------------------------------------------------------
   MLIRContext *context = entryPointFn.getContext();
+  unsigned numLoops = linalgOp.getNumLoops();
 
   // Pack level => 1.
   SmallVector<int64_t> packedSizesL0(numLoops, 0);
@@ -543,7 +553,7 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
 
   SmallVector<int64_t> tileSizeLevel0(numLoops, 0);
   if (isa<linalg::BatchMatmulOp>(linalgOp)) {
-    tileSizeLevel0[0] = 1;
+    tileSizeLevel0[batchDims[0]] = 1;
   }
   tileSizeLevel0[mDims[0]] = packPeelTiling.M0 * scaleL0;
   tileSizeLevel0[nDims[0]] = packPeelTiling.N0 * scaleL0;
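
The batchDims[0] index is the substance of this hunk: the old code hardcoded index 0, while the inferred batchDims[0] keeps the tiling correct even if the batch loop is not the outermost loop. A small standalone sketch with assumed indices and tile values (numLoops here plays the role of the linalgOp.getNumLoops() derived earlier in the patch):

#include <cstdint>
#include <vector>

int main() {
  unsigned numLoops = 4;  // stand-in for linalgOp.getNumLoops()
  // Inferred dimension indices; a batch loop need not sit at index 0.
  std::vector<unsigned> batchDims = {0}, mDims = {1}, nDims = {2};
  int64_t M0 = 64, N0 = 64, scaleL0 = 2;  // stand-ins for packPeelTiling

  std::vector<int64_t> tileSizeLevel0(numLoops, 0);
  tileSizeLevel0[batchDims[0]] = 1;         // tile the batch loop to 1
  tileSizeLevel0[mDims[0]] = M0 * scaleL0;
  tileSizeLevel0[nDims[0]] = N0 * scaleL0;
  // K is left untiled (0) at this level, as in the visible hunk.
  return tileSizeLevel0[0] == 1 ? 0 : 1;
}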
@@ -572,7 +582,7 @@
 static LogicalResult setRootConfigForPackPeelPipeline(
     mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp,
     LowerToAIEPassPipeline useLowerToAIEPipeline, AMDAIEDevice targetDevice,
-    uint32_t numRows, uint32_t numCols, uint32_t numLoops) {
+    uint32_t numRows, uint32_t numCols) {
   bool isObjectFifo =
       useLowerToAIEPipeline == LowerToAIEPassPipeline::ObjectFifo;
   auto maybePackPeelTiling =
@@ -582,8 +592,10 @@ static LogicalResult setRootConfigForPackPeelPipeline(
   auto packPeelTiling = maybePackPeelTiling.value();
 
   // Get M, N, K dimension indices from the input indexing map.
-  auto maybeInputDimsAndSizes = getInputDimsAndSizes(linalgOp);
+  FailureOr<InputDimsAndSizes> maybeInputDimsAndSizes =
+      getInputDimsAndSizes(linalgOp);
   if (failed(maybeInputDimsAndSizes)) return failure();
+  SmallVector<unsigned, 2> batchDims = maybeInputDimsAndSizes.value().batchDims;
   SmallVector<unsigned, 2> mDims = maybeInputDimsAndSizes.value().mDims;
   SmallVector<unsigned, 2> nDims = maybeInputDimsAndSizes.value().nDims;
   SmallVector<unsigned, 2> kDims = maybeInputDimsAndSizes.value().kDims;
@@ -592,6 +604,7 @@ static LogicalResult setRootConfigForPackPeelPipeline(
   // --------------- Set packing config -------------------
   // ------------------------------------------------------
   MLIRContext *context = entryPointFn.getContext();
+  unsigned numLoops = linalgOp.getNumLoops();
 
   // Pack level => 1.
   SmallVector<int64_t> packedSizesL0(numLoops, 0);
@@ -669,7 +682,7 @@ static LogicalResult setRootConfigForPackPeelPipeline(
   // ------------------------------------------------------
   SmallVector<int64_t> tileSizeLevel0(numLoops, 0);
   if (isa<linalg::BatchMatmulOp>(linalgOp)) {
-    tileSizeLevel0[0] = 1;
+    tileSizeLevel0[batchDims[0]] = 1;
   }
   tileSizeLevel0[mDims[0]] = packPeelTiling.M0;
   tileSizeLevel0[nDims[0]] = packPeelTiling.N0;
@@ -902,8 +915,6 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
                                    uint32_t numCols) {
   assert(!getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(genericOp) &&
          "expected lowering_config is not set");
-  unsigned numLoops = genericOp.getNumLoops();
-  assert(numLoops <= 7 && "expected input number of loops no more than 7");
   if (!isMatmul(genericOp) && !isMatmulTransposeA(genericOp) &&
       !isMatmulTransposeB(genericOp))
     return genericOp.emitOpError(
@@ -912,11 +923,11 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
   if (passPipeline == TilePassPipeline::PackPeelPipeline) {
     return setRootConfigForPackPeelPipeline(entryPointFn, genericOp,
                                             useLowerToAIEPipeline, targetDevice,
-                                            numRows, numCols, numLoops);
+                                            numRows, numCols);
   }
   if (passPipeline == TilePassPipeline::PackPeel4LevelTilingPipeline) {
     return setRootConfigForPackPeel4LevelTilingPipeline(
-        entryPointFn, genericOp, targetDevice, numRows, numCols, numLoops);
+        entryPointFn, genericOp, targetDevice, numRows, numCols);
   }
   if (passPipeline == TilePassPipeline::PadPackPipeline) {
     return setRootConfigForPadPackPipeline(entryPointFn, genericOp,
@@ -936,20 +947,18 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
   assert(!getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(contractionOp) &&
          "expected lowering_config is not set");
   auto linalgOp = cast<linalg::LinalgOp>(contractionOp.getOperation());
-  unsigned numLoops = linalgOp.getNumLoops();
-  assert(numLoops <= 7 && "expected input number of loops no more than 7");
 
   // TODO (nmeshram) : This needs to be moved in a separate more generalized
   // logic. Also, need a flag to experiment between pad based and pack based
   // approach which will have different tile sizes and pass pipelines
   if (passPipeline == TilePassPipeline::PackPeelPipeline) {
     return setRootConfigForPackPeelPipeline(entryPointFn, linalgOp,
                                             useLowerToAIEPipeline, targetDevice,
-                                            numRows, numCols, numLoops);
+                                            numRows, numCols);
   }
   if (passPipeline == TilePassPipeline::PackPeel4LevelTilingPipeline) {
     return setRootConfigForPackPeel4LevelTilingPipeline(
-        entryPointFn, linalgOp, targetDevice, numRows, numCols, numLoops);
+        entryPointFn, linalgOp, targetDevice, numRows, numCols);
   }
   if (passPipeline == TilePassPipeline::PadPackPipeline) {
     return setRootConfigForPadPackPipeline(entryPointFn, linalgOp, targetDevice,
