Skip to content

Commit

Permalink
[CreateLogicalObjFifoLink] Refactor unsafe walk for non-zero offset r…
Browse files Browse the repository at this point in the history
…emoval (#739)
  • Loading branch information
jtuyls authored Sep 3, 2024
1 parent 07ea41d commit 57a3636
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ namespace mlir::iree_compiler::AMDAIE {
template <CopyOpOperateOn OperateOn>
LogicalResult checkForContiguousAccessPatterns(
ArrayRef<std::pair<DoublyStridedCopyOpInterface, int64_t>> stridedOps) {

for (auto &&[i, stridedOpAndOffset] : llvm::enumerate(stridedOps)) {
DoublyStridedCopyOpInterface stridedOp = stridedOpAndOffset.first;
std::optional<int64_t> extent;
Expand Down Expand Up @@ -77,7 +76,8 @@ LogicalResult checkForContiguousAccessPatterns(
/// block and an error will be emitted if that's not the case.
LogicalResult createLogicalObjectFifoLink(
RewriterBase &rewriter,
AMDAIE::LogicalObjectFifoFromMemrefOp logicalObjectFifo) {
AMDAIE::LogicalObjectFifoFromMemrefOp logicalObjectFifo,
SmallVector<AMDAIE::LogicalObjectFifoLink> &newLinkOps) {
Attribute memSpace = logicalObjectFifo.getMemorySpace();
if (!memSpace || dyn_cast<IntegerAttr>(memSpace).getInt() != 1) {
return success();
Expand Down Expand Up @@ -140,12 +140,10 @@ LogicalResult createLogicalObjectFifoLink(

// Check that access patterns are not overlapping between consumers
// respectively producers.
if (failed(
checkForContiguousAccessPatterns<CopyOpOperateOn::Target>(ins))) {
if (failed(checkForContiguousAccessPatterns<CopyOpOperateOn::Target>(ins))) {
return failure();
}
if (failed(
checkForContiguousAccessPatterns<CopyOpOperateOn::Source>(outs))) {
if (failed(checkForContiguousAccessPatterns<CopyOpOperateOn::Source>(outs))) {
return failure();
}

Expand All @@ -161,8 +159,30 @@ LogicalResult createLogicalObjectFifoLink(
// Insert the `LogicalObjectFifoLink` after the last user operation.
if (lastUserOp) {
rewriter.setInsertionPointAfter(lastUserOp);
rewriter.create<AMDAIE::LogicalObjectFifoLink>(rewriter.getUnknownLoc(),
inResults, outResults);
auto linkOp = rewriter.create<AMDAIE::LogicalObjectFifoLink>(
rewriter.getUnknownLoc(), inResults, outResults);
newLinkOps.push_back(linkOp);
}
return success();
}

LogicalResult discardLinkNonZeroOffsets(RewriterBase &rewriter,
AMDAIE::LogicalObjectFifoLink linkOp) {
for (Value input : linkOp.getIns()) {
if (auto stridedOp = dyn_cast<AMDAIE::DoublyStridedCopyOpInterface>(
input.getDefiningOp())) {
SmallVector<int64_t> shape;
(void)discardAllNonZeroOffsets<CopyOpOperateOn::Target>(rewriter,
stridedOp, shape);
}
}
for (Value output : linkOp.getOuts()) {
if (auto stridedOp = dyn_cast<AMDAIE::DoublyStridedCopyOpInterface>(
output.getDefiningOp())) {
SmallVector<int64_t> shape;
(void)discardAllNonZeroOffsets<CopyOpOperateOn::Source>(rewriter,
stridedOp, shape);
}
}
return success();
}
Expand All @@ -180,10 +200,11 @@ struct AMDAIECreateLogicalObjectFifoLinkPass
Operation *parentOp = getOperation();
IRRewriter rewriter(parentOp->getContext());

SmallVector<AMDAIE::LogicalObjectFifoLink> newLinkOps;
WalkResult res = parentOp->walk(
[&](AMDAIE::LogicalObjectFifoFromMemrefOp logicalObjectFifo) {
if (failed(
createLogicalObjectFifoLink(rewriter, logicalObjectFifo))) {
if (failed(createLogicalObjectFifoLink(rewriter, logicalObjectFifo,
newLinkOps))) {
logicalObjectFifo.emitError() << "couldn't create a link operation";
return WalkResult::interrupt();
}
Expand All @@ -192,24 +213,10 @@ struct AMDAIECreateLogicalObjectFifoLinkPass
if (res.wasInterrupted()) return signalPassFailure();

// Remove all non-zero offsets.
parentOp->walk([&](AMDAIE::LogicalObjectFifoLink linkOp) {
for (Value input : linkOp.getIns()) {
if (auto stridedOp = dyn_cast<AMDAIE::DoublyStridedCopyOpInterface>(
input.getDefiningOp())) {
SmallVector<int64_t> shape;
(void)discardAllNonZeroOffsets<CopyOpOperateOn::Target>(
rewriter, stridedOp, shape);
}
}
for (Value output : linkOp.getOuts()) {
if (auto stridedOp = dyn_cast<AMDAIE::DoublyStridedCopyOpInterface>(
output.getDefiningOp())) {
SmallVector<int64_t> shape;
(void)discardAllNonZeroOffsets<CopyOpOperateOn::Source>(
rewriter, stridedOp, shape);
}
}
});
for (AMDAIE::LogicalObjectFifoLink linkOp : newLinkOps) {
if (failed(discardLinkNonZeroOffsets(rewriter, linkOp)))
return signalPassFailure();
}
}
};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-amdaie-create-logical-objectfifo-link, cse))" --verify-diagnostics %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-amdaie-create-logical-objectfifo-link,cse,canonicalize))" --verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: func.func @link
// CHECK: %[[DMA0:.+]] = amdaie.circular_dma_cpy_nd
Expand Down

0 comments on commit 57a3636

Please sign in to comment.