Skip to content

Commit

Permalink
Use dyn_cast_if_present and isa_and_present with getDefiningOp (#744)
Browse files Browse the repository at this point in the history
- Update `dyn_cast` to `dyn_cast_if_present` everywhere getDefiningOp is
used
- Update `isa` to `isa_and_present ` everywhere getDefiningOp is used
- Update `dyn_cast_or_null` to `dyn_cast_if_present` because the former
is deprecated
- Update `isa_and_nonnull` to `isa_and_present` because although not
deprecated according to comments, the former calls the latter under the
hood and to be consistent with `dyn_cast_if_present`
  • Loading branch information
jtuyls authored Sep 4, 2024
1 parent 57a3636 commit cba6bdf
Show file tree
Hide file tree
Showing 16 changed files with 103 additions and 83 deletions.
23 changes: 14 additions & 9 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ LogicalResult CoreOp::verify() {
}

TileOp CoreOp::getTileOp() {
return dyn_cast<TileOp>(getTile().getDefiningOp());
return dyn_cast_if_present<TileOp>(getTile().getDefiningOp());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -268,11 +268,13 @@ DoublyStridedOpInterface DmaCpyNdOp::createDoublyStridedOp(
}

LogicalObjectFifoFromMemrefOp DmaCpyNdOp::getSourceObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getSource().getDefiningOp());
return dyn_cast_if_present<LogicalObjectFifoFromMemrefOp>(
getSource().getDefiningOp());
};

LogicalObjectFifoFromMemrefOp DmaCpyNdOp::getTargetObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget().getDefiningOp());
return dyn_cast_if_present<LogicalObjectFifoFromMemrefOp>(
getTarget().getDefiningOp());
};

void DmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
Expand Down Expand Up @@ -395,11 +397,13 @@ DoublyStridedOpInterface CircularDmaCpyNdOp::createDoublyStridedOp(
}

LogicalObjectFifoFromMemrefOp CircularDmaCpyNdOp::getSourceObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getSource().getDefiningOp());
return dyn_cast_if_present<LogicalObjectFifoFromMemrefOp>(
getSource().getDefiningOp());
};

LogicalObjectFifoFromMemrefOp CircularDmaCpyNdOp::getTargetObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget().getDefiningOp());
return dyn_cast_if_present<LogicalObjectFifoFromMemrefOp>(
getTarget().getDefiningOp());
};

void CircularDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
Expand All @@ -422,7 +426,8 @@ void LogicalObjectFifoAccessOp::build(OpBuilder &b,

LogicalObjectFifoFromMemrefOp
LogicalObjectFifoAccessOp::getLogicalObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getInput().getDefiningOp());
return dyn_cast_if_present<LogicalObjectFifoFromMemrefOp>(
getInput().getDefiningOp());
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -490,7 +495,7 @@ LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize(
LogicalResult LogicalObjectFifoFromMemrefOp::verify() {
// Check whether the tile arguments are all of type AMDAIE::TileOp
if (llvm::all_of(getTiles(), [](Value result) {
return isa<TileOp>(result.getDefiningOp());
return isa_and_present<TileOp>(result.getDefiningOp());
})) {
return success();
}
Expand Down Expand Up @@ -878,8 +883,8 @@ bool TileOp::tileColumnComparator(AMDAIE::TileOp &a, AMDAIE::TileOp &b) {
}

bool TileOp::tileValueColumnAndRowComparator(Value a, Value b) {
TileOp tileA = dyn_cast<AMDAIE::TileOp>(a.getDefiningOp());
TileOp tileB = dyn_cast<AMDAIE::TileOp>(b.getDefiningOp());
TileOp tileA = cast<AMDAIE::TileOp>(a.getDefiningOp());
TileOp tileB = cast<AMDAIE::TileOp>(b.getDefiningOp());
int64_t colA = getConstantIntValue(tileA.getCol()).value();
int64_t rowA = getConstantIntValue(tileA.getRow()).value();
int64_t colB = getConstantIntValue(tileB.getCol()).value();
Expand Down
12 changes: 6 additions & 6 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def AMDAIE_NpuDmaCpyNdOp: AMDAIE_Op<"npu.dma_cpy_nd", [
}
// Return the input circular dma copy operation.
CircularDmaCpyNdOp getDmaCpyNdOp() {
return dyn_cast<CircularDmaCpyNdOp>(getDma().getDefiningOp());
return dyn_cast_if_present<CircularDmaCpyNdOp>(getDma().getDefiningOp());
}

// Return the source memref type. This is retrieved using information from
Expand Down Expand Up @@ -350,14 +350,14 @@ def AMDAIE_NpuDmaCpyNdOp: AMDAIE_Op<"npu.dma_cpy_nd", [

BdIdOp getSourceBdIdOp() {
Value bdIdValue = getSourceBdId();
if (!bdIdValue || !bdIdValue.getDefiningOp()) return nullptr;
return dyn_cast<BdIdOp>(bdIdValue.getDefiningOp());
if (!bdIdValue) return nullptr;
return dyn_cast_if_present<BdIdOp>(bdIdValue.getDefiningOp());
}

BdIdOp getTargetBdIdOp() {
Value bdIdValue = getTargetBdId();
if (!bdIdValue || !bdIdValue.getDefiningOp()) return nullptr;
return dyn_cast<BdIdOp>(bdIdValue.getDefiningOp());
if (!bdIdValue) return nullptr;
return dyn_cast_if_present<BdIdOp>(bdIdValue.getDefiningOp());
}

// A utility to create a new doubly strided operation from this one with a
Expand Down Expand Up @@ -417,7 +417,7 @@ def AMDAIE_NpuDmaWaitOp: AMDAIE_Op<"npu.dma_wait", []> {
let extraClassDeclaration = [{
// Return the Npu DMA operation argument.
NpuDmaCpyNdOp getDmaOp() {
return dyn_cast<NpuDmaCpyNdOp>(getDma().getDefiningOp());
return dyn_cast_if_present<NpuDmaCpyNdOp>(getDma().getDefiningOp());
}
}];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
return npuDmaOp.emitOpError()
<< "no channel BD ID generator found for tile: " << tile;
}
tileOp = dyn_cast<AMDAIE::TileOp>(tile.getDefiningOp());
tileOp = dyn_cast_if_present<AMDAIE::TileOp>(tile.getDefiningOp());
if (!tileOp) return npuDmaOp.emitOpError() << "no tile op found";
return success();
};
Expand All @@ -65,8 +65,9 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
WalkResult res = controlCodeOp->walk([&](Operation *op) {
if (auto npuDmaOp = dyn_cast<AMDAIE::NpuDmaCpyNdOp>(op)) {
if (npuDmaOp.getSource()) {
auto logicalObjFifo = dyn_cast<AMDAIE::LogicalObjFifoOpInterface>(
npuDmaOp.getSource().getDefiningOp());
auto logicalObjFifo =
dyn_cast_if_present<AMDAIE::LogicalObjFifoOpInterface>(
npuDmaOp.getSource().getDefiningOp());
if (!logicalObjFifo) {
npuDmaOp.emitOpError() << "expected a source logical objectFifo";
return WalkResult::interrupt();
Expand Down Expand Up @@ -96,8 +97,9 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
bdIdOp);
}
if (npuDmaOp.getTarget()) {
auto logicalObjFifo = dyn_cast<AMDAIE::LogicalObjectFifoFromMemrefOp>(
npuDmaOp.getTarget().getDefiningOp());
auto logicalObjFifo =
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
npuDmaOp.getTarget().getDefiningOp());
if (!logicalObjFifo) {
npuDmaOp.emitOpError()
<< "expected a target `amdaie.logicalobjectfifo.from_memref`";
Expand Down Expand Up @@ -132,14 +134,17 @@ LogicalResult assignNpuDmaBdIds(AMDAIE::WorkgroupOp workgroupOp) {
AMDAIE::NpuDmaCpyNdOp npuDmaOp = npuWaitOp.getDmaOp();
AMDAIE::BdIdOp bdIdOp;
if (npuDmaOp.getSourceBdId()) {
bdIdOp = cast<AMDAIE::BdIdOp>(npuDmaOp.getSourceBdId().getDefiningOp());
bdIdOp = dyn_cast_if_present<AMDAIE::BdIdOp>(
npuDmaOp.getSourceBdId().getDefiningOp());
} else if (npuDmaOp.getTargetBdId()) {
bdIdOp = cast<AMDAIE::BdIdOp>(npuDmaOp.getTargetBdId().getDefiningOp());
bdIdOp = dyn_cast_if_present<AMDAIE::BdIdOp>(
npuDmaOp.getTargetBdId().getDefiningOp());
} else {
return WalkResult::advance();
}
if (!bdIdOp) return WalkResult::advance();
auto tileOp = dyn_cast<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
auto tileOp =
dyn_cast_if_present<AMDAIE::TileOp>(bdIdOp.getTile().getDefiningOp());
if (!tileOp) {
bdIdOp.emitOpError() << "doesn't operate on a `amdaie.tile` operation";
return WalkResult::interrupt();
Expand Down Expand Up @@ -169,7 +174,7 @@ class AMDAIEAssignNpuDmaBdIdsPass
}

AMDAIEAssignNpuDmaBdIdsPass() = default;
AMDAIEAssignNpuDmaBdIdsPass(const AMDAIEAssignNpuDmaBdIdsPass &pass) {};
AMDAIEAssignNpuDmaBdIdsPass(const AMDAIEAssignNpuDmaBdIdsPass &pass){};
void runOnOperation() override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ LogicalResult coreLoopUnroll(RewriterBase &rewriter, AMDAIE::CoreOp coreOp) {
llvm::SmallDenseSet<unsigned> depths;
for (auto acqOp :
forOp.getBody()->getOps<AMDAIE::LogicalObjectFifoAcquire>()) {
auto stridedOp = dyn_cast<DoublyStridedCopyOpInterface>(
auto stridedOp = dyn_cast_if_present<DoublyStridedCopyOpInterface>(
acqOp.getDma().getDefiningOp());
if (!stridedOp) {
acqOp.emitOpError()
Expand All @@ -35,9 +35,9 @@ LogicalResult coreLoopUnroll(RewriterBase &rewriter, AMDAIE::CoreOp coreOp) {
}
auto logicalObjFifo =
acqOp.getPort() == LogicalObjectFifoPort::Consume
? dyn_cast<AMDAIE::LogicalObjectFifoFromMemrefOp>(
? dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
stridedOp.getTarget().getDefiningOp())
: dyn_cast<AMDAIE::LogicalObjectFifoFromMemrefOp>(
: dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
stridedOp.getSource().getDefiningOp());
depths.insert(
cast<LogicalObjectFifoType>(logicalObjFifo.getType()).getDepth());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ LogicalResult WorkgroupBuilder::buildForDmaCpyNdOp(
// Check if the source of DmaCpyNd op is from L3 - then source addressing
// will be controlled by the uController and target addressing will stay in
// the circular DMA to be part of the AIE configuration.
auto logicalObjFifo = dyn_cast<AMDAIE::LogicalObjectFifoFromMemrefOp>(
dmaOp.getSource().getDefiningOp());
auto logicalObjFifo =
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
dmaOp.getSource().getDefiningOp());
if (!logicalObjFifo) {
return dmaOp.emitOpError()
<< "`amdaie.logicalobjectfifo.from_memref` expected as source";
Expand All @@ -143,8 +144,9 @@ LogicalResult WorkgroupBuilder::buildForDmaCpyNdOp(
// Check if the target of DmaCpyNd op is from L3 - then target addressing
// will be controlled by the uController and source addressing will stay in
// the circular DMA to be part of the AIE configuration.
auto logicalObjFifo = dyn_cast<AMDAIE::LogicalObjectFifoFromMemrefOp>(
dmaOp.getTarget().getDefiningOp());
auto logicalObjFifo =
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
dmaOp.getTarget().getDefiningOp());
if (!logicalObjFifo) {
return dmaOp.emitOpError()
<< "`amdaie.logicalobjectfifo.from_memref` expected as source";
Expand Down Expand Up @@ -425,7 +427,7 @@ class AMDAIECreateAIEWorkgroupPass
}

AMDAIECreateAIEWorkgroupPass() = default;
AMDAIECreateAIEWorkgroupPass(const AMDAIECreateAIEWorkgroupPass &pass) {};
AMDAIECreateAIEWorkgroupPass(const AMDAIECreateAIEWorkgroupPass &pass){};
void runOnOperation() override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ LogicalResult createLogicalObjectFifoLink(
"has copy-like users not residing in the same block");
}
auto sourceLogicalObjectFifo =
dyn_cast<AMDAIE::LogicalObjectFifoFromMemrefOp>(
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
stridedOp.getSource().getDefiningOp());
if (!lastUserOp || lastUserOp->isBeforeInBlock(stridedOp)) {
lastUserOp = stridedOp;
Expand Down Expand Up @@ -169,16 +169,18 @@ LogicalResult createLogicalObjectFifoLink(
LogicalResult discardLinkNonZeroOffsets(RewriterBase &rewriter,
AMDAIE::LogicalObjectFifoLink linkOp) {
for (Value input : linkOp.getIns()) {
if (auto stridedOp = dyn_cast<AMDAIE::DoublyStridedCopyOpInterface>(
input.getDefiningOp())) {
if (auto stridedOp =
dyn_cast_if_present<AMDAIE::DoublyStridedCopyOpInterface>(
input.getDefiningOp())) {
SmallVector<int64_t> shape;
(void)discardAllNonZeroOffsets<CopyOpOperateOn::Target>(rewriter,
stridedOp, shape);
}
}
for (Value output : linkOp.getOuts()) {
if (auto stridedOp = dyn_cast<AMDAIE::DoublyStridedCopyOpInterface>(
output.getDefiningOp())) {
if (auto stridedOp =
dyn_cast_if_present<AMDAIE::DoublyStridedCopyOpInterface>(
output.getDefiningOp())) {
SmallVector<int64_t> shape;
(void)discardAllNonZeroOffsets<CopyOpOperateOn::Source>(rewriter,
stridedOp, shape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,15 +505,17 @@ class AMDAIEUnrollLocalLoops : public OpRewritePattern<scf::ForOp> {
AMDAIE::LogicalObjectFifoFromMemrefOp target =
dmaOp.getTargetObjectFifo();
rewriter.setInsertionPoint(target);
auto cloneOp = dyn_cast<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.clone(*dmaOp.getTarget().getDefiningOp()));
auto cloneOp =
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.clone(*dmaOp.getTarget().getDefiningOp()));
operandMap.map(target.getOutput(), cloneOp.getOutput());
} else if (sourceMemSpaceInt > targetMemSpaceInt) {
AMDAIE::LogicalObjectFifoFromMemrefOp source =
dmaOp.getSourceObjectFifo();
rewriter.setInsertionPoint(source);
auto cloneOp = dyn_cast<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.clone(*dmaOp.getSource().getDefiningOp()));
auto cloneOp =
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.clone(*dmaOp.getSource().getDefiningOp()));
operandMap.map(source.getOutput(), cloneOp.getOutput());
}
}
Expand Down Expand Up @@ -585,8 +587,10 @@ LogicalResult getUserTiles(

// Only fill in tiles when all sources have tiles.
if (tileIndices.empty()) return failure();
for (Value index : tileIndices)
tileSet.insert(dyn_cast<AMDAIE::TileOp>(index.getDefiningOp()));
for (Value index : tileIndices) {
tileSet.insert(
dyn_cast_if_present<AMDAIE::TileOp>(index.getDefiningOp()));
}
}
}
tiles = tileSet.takeVector();
Expand Down Expand Up @@ -635,10 +639,8 @@ LogicalResult insertLogicalObjectFifoAccess(ModuleOp moduleOp) {
WalkResult res = coreOp->walk([&](Operation *op) {
bool hasAllocOperand = [op]() {
for (Value operand : op->getOperands()) {
Operation *definingOp = operand.getDefiningOp();
if (definingOp && isa<memref::AllocOp>(definingOp)) {
if (isa_and_present<memref::AllocOp>(operand.getDefiningOp()))
return true;
}
}
return false;
}();
Expand Down Expand Up @@ -904,9 +906,10 @@ LogicalResult assignAieTilesAndDistributeLogicalObjectFifos(ModuleOp moduleOp) {
if (memSpace && dyn_cast<IntegerAttr>(memSpace).getInt() != 1)
return WalkResult::advance();

SmallVector<AMDAIE::TileOp> tiles = llvm::map_to_vector(
logicalObjectFifo.getTiles(),
[](Value tile) { return dyn_cast<TileOp>(tile.getDefiningOp()); });
SmallVector<AMDAIE::TileOp> tiles =
llvm::map_to_vector(logicalObjectFifo.getTiles(), [](Value tile) {
return dyn_cast_if_present<TileOp>(tile.getDefiningOp());
});
llvm::sort(tiles.begin(), tiles.end(),
AMDAIE::TileOp::tileColumnComparator);

Expand Down Expand Up @@ -934,7 +937,7 @@ class AMDAIEDistributeCoresAndObjectFifosPass

AMDAIEDistributeCoresAndObjectFifosPass() = default;
AMDAIEDistributeCoresAndObjectFifosPass(
const AMDAIEDistributeCoresAndObjectFifosPass &pass) {};
const AMDAIEDistributeCoresAndObjectFifosPass &pass){};
void runOnOperation() override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ struct SubsumeLoopIntoDMA
// If the offset value is determined by an affine expression, retrieve
// the affine expression's stride scale and calculate the actual
// offset stride.
if (offsetValue.getDefiningOp() &&
isa<affine::AffineApplyOp>(offsetValue.getDefiningOp())) {
if (isa_and_present<affine::AffineApplyOp>(
offsetValue.getDefiningOp())) {
auto applyOp =
cast<affine::AffineApplyOp>(offsetValue.getDefiningOp());
// Retrieve the scale and optional bias from the affine map using an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@ namespace {
static FailureOr<tensor::ExtractSliceOp> getTensorExtractSliceDefiningOp(
Value operand) {
while (Operation *defOp = operand.getDefiningOp()) {
auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(defOp);
auto sliceOp = dyn_cast_if_present<tensor::ExtractSliceOp>(defOp);
if (sliceOp) {
// The producer of sliceOp should be a pack op.
if (isa_and_nonnull<tensor::PackOp>(
if (isa_and_present<tensor::PackOp>(
sliceOp.getSource().getDefiningOp())) {
return sliceOp;
}
if (isa<BlockArgument>(sliceOp.getSource())) {
auto blkArg = dyn_cast<BlockArgument>(sliceOp.getSource());
for (Value blkOperand :
blkArg.getOwner()->getParentOp()->getOperands()) {
if (isa_and_nonnull<tensor::PackOp>(blkOperand.getDefiningOp())) {
if (isa_and_present<tensor::PackOp>(blkOperand.getDefiningOp())) {
return sliceOp;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ LogicalResult insertCoreOps(mlir::ModuleOp moduleOp) {
// Fetch name of the ukernel function to look up its declaration in the
// Symbol table.
StringRef fnName = callOp.getCallee();
auto fnDecl = dyn_cast_or_null<func::FuncOp>(
auto fnDecl = dyn_cast_if_present<func::FuncOp>(
SymbolTable::lookupSymbolIn(moduleOp, fnName));
assert(fnDecl && "expected function declaration");
assert(fnDecl->hasAttr("link_with") &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ static FailureOr<OpFoldResult> updateL3SourceOffset(IRRewriter &rewriter,
Operation *defOpOfL3SourceOffset = l3SourceOffsetVal.getDefiningOp();
Location loc = defOpOfL3SourceOffset->getLoc();
rewriter.setInsertionPoint(defOpOfL3SourceOffset);
if (auto applyOp =
dyn_cast<affine::AffineApplyOp>(defOpOfL3SourceOffset)) {
if (auto applyOp = dyn_cast_if_present<affine::AffineApplyOp>(
defOpOfL3SourceOffset)) {
AffineExpr affineExpr = applyOp.getAffineMap().getResult(0);
AffineMap newAffineMap = createAffineMap(affineExpr, offsetToAdd);
newL3AsSourceOffset =
Expand Down Expand Up @@ -423,7 +423,7 @@ LogicalResult splitLogicalObjectFifos(
op->dropAllUses();
rewriter.eraseOp(op);
}

return success();
}

Expand Down
Loading

0 comments on commit cba6bdf

Please sign in to comment.