diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index dac831af5477..46021a556717 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -954,145 +954,49 @@ struct AffineForOpRaising : public OpRewritePattern { } }; -// struct RemoveIterArgs : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; -// LogicalResult matchAndRewrite(scf::ForOp forOp, -// PatternRewriter &rewriter) const override { -// if (!forOp.getRegion().hasOneBlock()) -// return failure(); -// unsigned numIterArgs = forOp.getNumRegionIterArgs(); -// auto loc = forOp->getLoc(); -// bool changed = false; -// llvm::SetVector removed; -// llvm::MapVector steps; -// auto yield = cast(forOp.getBody()->getTerminator()); -// for (unsigned i = 0; i < numIterArgs; i++) { -// auto ba = forOp.getRegionIterArgs()[i]; -// auto init = forOp.getInits()[i]; -// auto next = yield->getOperand(i); - -// auto increment = next.getDefiningOp(); -// if (!increment) -// continue; - -// Value step = nullptr; -// if (increment.getLhs() == ba) { -// step = increment.getRhs(); -// } else { -// step = increment.getLhs(); -// } -// if (!step) -// continue; - -// // If it dominates the loop entry -// if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion())) -// continue; - -// rewriter.setInsertionPointToStart(forOp.getBody()); -// Value iterNum = rewriter.create( -// loc, forOp.getInductionVar(), forOp.getLowerBound()); -// iterNum = rewriter.create(loc, iterNum, forOp.getStep()); - -// Value replacementIV = rewriter.create(loc, iterNum, step); -// replacementIV = rewriter.create(loc, replacementIV, init); - -// rewriter.replaceAllUsesWith(ba, replacementIV); - -// removed.insert(i); -// steps.insert({i, step}); -// changed = true; -// } - -// if (!changed) -// return failure(); - -// SmallVector newInits; -// for (unsigned i = 0; i < numIterArgs; i++) -// if (!removed.contains(i)) -// newInits.push_back(forOp.getInits()[i]); - -// rewriter.setInsertionPoint(forOp); -// auto newForOp = rewriter.create(loc, forOp.getLowerBound(), -// forOp.getUpperBound(), -// forOp.getStep(), newInits); -// if (!newForOp.getRegion().empty()) -// newForOp.getRegion().front().erase(); -// assert(newForOp.getRegion().empty()); -// rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), -// newForOp.getRegion().begin()); - -// SmallVector newYields; -// for (unsigned i = 0; i < numIterArgs; i++) -// if (!removed.contains(i)) -// newYields.push_back(yield->getOperand(i)); - -// rewriter.setInsertionPoint(yield); -// rewriter.replaceOpWithNewOp(yield, newYields); - -// llvm::BitVector toDelete(numIterArgs + 1); -// for (unsigned i = 0; i < numIterArgs; i++) -// if (removed.contains(i)) -// toDelete[i + 1] = true; -// newForOp.getBody()->eraseArguments(toDelete); - -// rewriter.setInsertionPoint(newForOp); -// unsigned curNewRes = 0; -// for (unsigned i = 0; i < numIterArgs; i++) { -// auto result = forOp->getResult(i); -// if (removed.contains(i)) { -// if (result.use_empty()) -// continue; - -// rewriter.setInsertionPointToStart(forOp.getBody()); -// Value iterNum = rewriter.create( -// loc, forOp.getUpperBound(), forOp.getLowerBound()); -// iterNum = -// rewriter.create(loc, iterNum, forOp.getStep()); - -// Value afterLoop = -// rewriter.create(loc, iterNum, steps[i]); -// afterLoop = -// rewriter.create(loc, afterLoop, forOp.getInits()[i]); - -// rewriter.replaceAllUsesWith(result, afterLoop); -// } else { -// rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++)); -// } -// } - -// rewriter.eraseOp(forOp); +// namespace { +// struct RaiseAffineToLinalg +// : public AffineRaiseToLinalgBase { +// std::shared_ptr patterns; + +// LogicalResult initialize(MLIRContext *context) override { +// RewritePatternSet owningPatterns(context); +// for (auto *dialect : context->getLoadedDialects()) +// dialect->getCanonicalizationPatterns(owningPatterns); +// for (RegisteredOperationName op : context->getRegisteredOperations()) +// op.getCanonicalizationPatterns(owningPatterns, context); + +// owningPatterns.insert(&getContext()); + +// patterns = std::make_shared( +// std::move(owningPatterns)); // return success(); // } +// void runOnOperation() override { +// GreedyRewriteConfig config; +// (void)applyPatternsAndFoldGreedily(getOperation(), *patterns, config); +// } // }; +// } // namespace namespace { struct RaiseAffineToLinalg : public AffineRaiseToLinalgBase { - - std::shared_ptr patterns; - - LogicalResult initialize(MLIRContext *context) override { - RewritePatternSet owningPatterns(context); - for (auto *dialect : context->getLoadedDialects()) - dialect->getCanonicalizationPatterns(owningPatterns); - for (RegisteredOperationName op : context->getRegisteredOperations()) - op.getCanonicalizationPatterns(owningPatterns, context); - - //owningPatterns.insert(&getContext()); - owningPatterns.insert(&getContext()); - - patterns = std::make_shared( - std::move(owningPatterns)); - return success(); - } - void runOnOperation() override { - GreedyRewriteConfig config; - (void)applyPatternsAndFoldGreedily(getOperation(), *patterns, config); - } + void runOnOperation() override; }; } // namespace +void RaiseAffineToLinalg::runOnOperation() { + RewritePatternSet patterns(&getContext()); + // TODO add the existing canonicalization patterns + // + subview of an affine apply -> subview + patterns.insert(&getContext()); + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); +} + namespace mlir { namespace polygeist { std::unique_ptr createRaiseAffineToLinalgPass() {