Skip to content

Commit

Permalink
[SplitLogicalObjFifos] Generalize for loop dependency and stride
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuyls committed Jan 8, 2025
1 parent 0ee8453 commit 1b29d76
Show file tree
Hide file tree
Showing 9 changed files with 612 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIELogicalObjFifoSplittingUtils.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Pass/Pass.h"

Expand All @@ -18,20 +21,114 @@ namespace {
/// Utility struct to represent DMA split information.
struct DmaSplitInfo {
size_t sourceSplitDim{0};
int64_t newSourceStride{1};
size_t targetSplitDim{0};
int64_t newTargetStride{1};
int64_t splitSize{1};
};

/// Utility struct to represent objectFifo split information.
struct ObjFifoSplitInfo {
size_t splitDim{0};
int64_t splitSize{1};
int64_t splitStride{1};
};

using DmaObjFifoPairT =
std::pair<AMDAIE::DmaCpyNdOp, AMDAIE::LogicalObjectFifoFromMemrefOp>;

/// Utility to derive the split stride to be used from a vector of DMA ops by
/// analyzing the offset scales. Will fail if the provided DMA ops don't have a
/// consistent offset scale.
template <CopyOpOperateOn OperateOn>
FailureOr<int64_t> getSplitStride(ArrayRef<AMDAIE::DmaCpyNdOp> dmaOps,
int64_t sizeAfterSplit) {
int64_t splitStride{-1};
for (AMDAIE::DmaCpyNdOp dmaOp : dmaOps) {
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> strides;
if constexpr (OperateOn == CopyOpOperateOn::Source) {
offsets = dmaOp.getSourceMixedOffsets();
strides = dmaOp.getSourceMixedStrides();
} else if constexpr (OperateOn == CopyOpOperateOn::Target) {
offsets = dmaOp.getTargetMixedOffsets();
strides = dmaOp.getTargetMixedStrides();
} else {
assert(false && "Function can only operate on Source or Target");
}
SmallVector<size_t> splitIndices =
getStrideIndicesWithDynamicOrNonZeroOffset(offsets, strides,
sizeAfterSplit);
if (splitIndices.size() > 1)
return dmaOp.emitError() << "multiple split indices found";
int64_t step{-1};
if (splitIndices.empty()) {
step = 1;
} else {
// splitIndices.size() == 1
size_t splitIdx = splitIndices[0];
OpFoldResult offset = offsets[splitIdx];

if (std::optional<int64_t> staticOffset = getConstantIntValue(offset);
staticOffset.has_value()) {
if (staticOffset.value() == 0) continue;
step = 1;
} else if (auto offsetValue = dyn_cast_if_present<Value>(offset)) {
if (isa_and_present<affine::AffineApplyOp>(
offsetValue.getDefiningOp())) {
auto applyOp =
cast<affine::AffineApplyOp>(offsetValue.getDefiningOp());
if (applyOp.getNumOperands() != 1)
return applyOp.emitError() << "mulptiple operands is not supported";
AffineMap affineMap = applyOp.getAffineMap();
RetrieveScaleAndBias retriever;
if (failed(retriever.visit(affineMap.getResult(0)))) {
return applyOp.emitError()
<< "could not retrieve scale and bias from expression: "
<< *applyOp.getOperation();
}
if (!retriever.scale.has_value()) {
return applyOp.emitError()
<< "expected a scale for: " << *applyOp.getOperation();
}
step = retriever.scale.value();
} else if (auto blockArg = dyn_cast<BlockArgument>(offsetValue);
blockArg && isa<LoopLikeOpInterface>(
blockArg.getOwner()->getParentOp())) {
step = 1;
} else {
return dmaOp.emitOpError()
<< "has an offset value that is neither an "
"induction variable nor an affine expression";
}
} else {
return dmaOp.emitOpError()
<< "has an offset that is neither a constant nor an affine "
"expression, which is not supported";
}
}
if (splitStride == -1) {
splitStride = step;
} else if (step != splitStride) {
return dmaOp.emitOpError() << "has an offset step: " << step
<< ", which is different from "
"previous offset steps: "
<< splitStride;
}
}
// If all offsets are zero (or no split index found).
if (splitStride == -1) return 1;
return splitStride;
}

/// Find the logical objectFifo and DMA source/target splitting dimensions for
/// each DMA and objectFifo pair.
///
/// Each pair is handled in the following way:
/// First, compute the objectFifo splitting dimension as the last non-unit shape
/// dimension. Afterwards, depending on which logical objectFifo is being
/// split on, find the outermost dimension in either the source or
/// target access pattern that has:
/// First, compute the objectFifo splitting dimension based on the last non-unit
/// shape dimension and the number of available columns. Afterwards, depending
/// on which logical objectFifo is being split on, find the outermost dimension
/// in either the source or target access pattern that has:
/// - stride == sizeAfterSplit
/// - size != 1
/// This is the splitting dimension to be used on the respective side of the DMA
Expand All @@ -44,8 +141,9 @@ using DmaObjFifoPairT =
LogicalResult collectSplittingDims(
const SmallVector<DmaObjFifoPairT> &dmaObjFifoPairs,
DenseMap<AMDAIE::DmaCpyNdOp, DmaSplitInfo> &dmaSplitInfoMap,
DenseMap<AMDAIE::LogicalObjectFifoFromMemrefOp, size_t>
&objFifoSplitDimMap) {
DenseMap<AMDAIE::LogicalObjectFifoFromMemrefOp, ObjFifoSplitInfo>
&objFifoSplitInfoMap,
int64_t numCols) {
for (auto [dmaOp, objFifo] : dmaObjFifoPairs) {
LLVM_DEBUG(llvm::dbgs() << "dmaOp: " << dmaOp << "\n");
LLVM_DEBUG(llvm::dbgs() << "objFifo: " << objFifo << "\n");
Expand All @@ -62,10 +160,19 @@ LogicalResult collectSplittingDims(
// If all dimensions are unit (1), no splitting can be done, so continue to
// the next pair.
if (objFifoSplitDim >= memrefShape.size()) continue;
int64_t splitDimSize = memrefShape[objFifoSplitDim];
int64_t sizeAfterSplit =
std::accumulate(memrefShape.begin() + objFifoSplitDim + 1,
memrefShape.end(), 1, std::multiplies<>());

// Get the producers and consumers of the current objectFifoOp.
SmallVector<AMDAIE::DmaCpyNdOp> producers;
SmallVector<AMDAIE::DmaCpyNdOp> consumers;
if (failed(getDmaCpyNdOpProducersAndConsumers(objFifo, producers,
consumers))) {
return failure();
}

size_t sourceSplitDim{0};
size_t targetSplitDim{0};
if (dmaOp.getTargetObjectFifo() == objFifo) {
Expand Down Expand Up @@ -101,6 +208,27 @@ LogicalResult collectSplittingDims(
break;
}
}
FailureOr<int64_t> maybeSplitStride =
getSplitStride<CopyOpOperateOn::Source>(consumers, sizeAfterSplit);
if (failed(maybeSplitStride)) {
objFifo.emitOpError()
<< "could not retrieve a split stride from the consumer DMA ops";
}
int64_t splitStride = maybeSplitStride.value();
// Calculate the new source stride to be used for splitting the DMA.
int64_t newSourceStride =
splitStride != 1 ? splitDimSize / splitStride : 1;
LLVM_DEBUG(llvm::dbgs() << "sourceSplitDim: " << sourceSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "targetSplitDim: " << targetSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "newSourceStride: " << newSourceStride << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "objFifoSplitDim: " << objFifoSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitStride: " << splitStride << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitFactor: " << numCols << "\n");
dmaSplitInfoMap[dmaOp] = {sourceSplitDim, newSourceStride, targetSplitDim,
1, numCols};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, numCols, splitStride};
} else if (dmaOp.getSourceObjectFifo() == objFifo) {
// Find outermost dimension in the access pattern that has stride ==
// sizeAfterSplit and size != 1.
Expand Down Expand Up @@ -136,13 +264,28 @@ LogicalResult collectSplittingDims(
break;
}
}
FailureOr<int64_t> maybeSplitStride =
getSplitStride<CopyOpOperateOn::Target>(producers, sizeAfterSplit);
if (failed(maybeSplitStride)) {
objFifo.emitOpError()
<< "could not retrieve a split stride from the consumer DMA ops";
}
int64_t splitStride = maybeSplitStride.value();
// Calculate the new target stride to be used for splitting the DMA.
int64_t newTargetStride =
splitStride != 1 ? splitDimSize / splitStride : 1;
LLVM_DEBUG(llvm::dbgs() << "sourceSplitDim: " << sourceSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "targetSplitDim: " << targetSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "newTargetStride: " << newTargetStride << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "objFifoSplitDim: " << objFifoSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitStride: " << splitStride << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitFactor: " << numCols << "\n");
dmaSplitInfoMap[dmaOp] = {sourceSplitDim, 1, targetSplitDim,
newTargetStride, numCols};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, numCols, splitStride};
}
LLVM_DEBUG(llvm::dbgs() << "sourceSplitDim: " << sourceSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "targetSplitDim: " << targetSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "objFifoSplitDim: " << objFifoSplitDim << "\n");
DmaSplitInfo dmaSplitInfo = {sourceSplitDim, targetSplitDim};
dmaSplitInfoMap[dmaOp] = std::move(dmaSplitInfo);
objFifoSplitDimMap[objFifo] = objFifoSplitDim;
}
return success();
}
Expand All @@ -157,9 +300,6 @@ class AMDAIESplitLogicalObjFifosPass

AMDAIESplitLogicalObjFifosPass() = default;
AMDAIESplitLogicalObjFifosPass(const AMDAIESplitLogicalObjFifosPass &pass){};
AMDAIESplitLogicalObjFifosPass(
const AMDAIESplitLogicalObjFifosOptions &options)
: AMDAIESplitLogicalObjFifosBase(options) {}
void runOnOperation() override;
};

Expand All @@ -168,6 +308,18 @@ void AMDAIESplitLogicalObjFifosPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);

// Retrieve the device model.
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(moduleOp);
std::optional<int64_t> maybeNumColumns = getConfigNumColumns(targetAttr);
if (!maybeNumColumns) {
moduleOp.emitOpError() << "has no number of columns specified in the "
"target attribute configuration. This "
"device-specific information is required to "
"correctly split logical objectFifos.";
return signalPassFailure();
}
int64_t numColumns = maybeNumColumns.value();

// Walk and collect all dma ops between L3 and L2.
SmallVector<AMDAIE::DmaCpyNdOp> l3L2DmaOps;
SmallVector<DmaObjFifoPairT> dmaObjFifoPairs;
Expand All @@ -189,9 +341,10 @@ void AMDAIESplitLogicalObjFifosPass::runOnOperation() {

// Collect the split dimensions for all DMA and ojectFifo pairs.
DenseMap<AMDAIE::DmaCpyNdOp, DmaSplitInfo> dmaSplitInfoMap;
DenseMap<AMDAIE::LogicalObjectFifoFromMemrefOp, size_t> objFifoSplitDimMap;
DenseMap<AMDAIE::LogicalObjectFifoFromMemrefOp, ObjFifoSplitInfo>
objFifoSplitInfoMap;
if (failed(collectSplittingDims(dmaObjFifoPairs, dmaSplitInfoMap,
objFifoSplitDimMap))) {
objFifoSplitInfoMap, numColumns))) {
return signalPassFailure();
}

Expand All @@ -200,16 +353,19 @@ void AMDAIESplitLogicalObjFifosPass::runOnOperation() {
for (auto &&[dmaOp, dmaSplitInfo] : dmaSplitInfoMap) {
auto stridedOp =
cast<AMDAIE::DoublyStridedOpInterface>(dmaOp.getOperation());
if (failed(splitDoublyStridedOp(rewriter, stridedOp,
dmaSplitInfo.sourceSplitDim,
dmaSplitInfo.targetSplitDim, numCols))) {
if (failed(splitDoublyStridedOp(
rewriter, stridedOp, dmaSplitInfo.sourceSplitDim,
dmaSplitInfo.targetSplitDim, dmaSplitInfo.splitSize,
dmaSplitInfo.newSourceStride, dmaSplitInfo.newTargetStride))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to perform splitting of the DMA op: " << dmaOp);
return signalPassFailure();
}
}
for (auto &&[objFifo, splitDim] : objFifoSplitDimMap) {
if (failed(splitLogicalObjectFifo(rewriter, objFifo, splitDim, numCols))) {
for (auto &&[objFifo, splitInfo] : objFifoSplitInfoMap) {
if (failed(splitLogicalObjectFifo(rewriter, objFifo, splitInfo.splitDim,
splitInfo.splitSize,
splitInfo.splitStride))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to perform splitting of objectFifo op");
return signalPassFailure();
Expand All @@ -219,9 +375,8 @@ void AMDAIESplitLogicalObjFifosPass::runOnOperation() {

} // namespace

std::unique_ptr<Pass> createAMDAIESplitLogicalObjFifosPass(
AMDAIESplitLogicalObjFifosOptions options) {
return std::make_unique<AMDAIESplitLogicalObjFifosPass>(options);
std::unique_ptr<Pass> createAMDAIESplitLogicalObjFifosPass() {
return std::make_unique<AMDAIESplitLogicalObjFifosPass>();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,8 @@ void addAMDAIEObjectFifoLoweringPasses(

passManager.addPass(createAMDAIESplitLogicalObjFifosForConnectionReusePass());
// Currently, SplitLogicalObjFifos pass only works for matmul-like ops.
{
if (useTilePipeline == TilePassPipeline::PackPeelPipeline) {
AMDAIESplitLogicalObjFifosOptions splitOptions;
splitOptions.numCols = numCols;
passManager.addPass(createAMDAIESplitLogicalObjFifosPass(splitOptions));
}
}
if (useTilePipeline == TilePassPipeline::PackPeelPipeline)
passManager.addPass(createAMDAIESplitLogicalObjFifosPass());

passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ std::unique_ptr<Pass> createAMDAIERemoveMemorySpacePass();
std::unique_ptr<Pass> createAMDAIESinkIntoCorePass();

/// Create a pass to split logicalobjectfifos for shimTile/memTile distribution.
std::unique_ptr<Pass> createAMDAIESplitLogicalObjFifosPass(
AMDAIESplitLogicalObjFifosOptions options = {});
std::unique_ptr<Pass> createAMDAIESplitLogicalObjFifosPass();

/// Create a pass to split logicalobjectfifos for connection reuse.
std::unique_ptr<Pass> createAMDAIESplitLogicalObjFifosForConnectionReusePass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,6 @@ def AMDAIESplitLogicalObjFifos :
`[1, 2, 32, 32]`, will be split to two `[1, 1, 32, 32]` buffers.
}];
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIESplitLogicalObjFifosPass()";
let options = [
Option<"numCols", "num-cols", "uint32_t", /*default=*/"4",
"Number of columns used in an AIE core array">
];
}

def AMDAIESplitLogicalObjFifosForConnectionReuse :
Expand Down
Loading

0 comments on commit 1b29d76

Please sign in to comment.