Skip to content

Commit

Permalink
Temporary reverted pass registeration as the code was failing
Browse files Browse the repository at this point in the history
  • Loading branch information
arpitj1 committed Oct 12, 2024
1 parent 701f25a commit d285fb5
Showing 1 changed file with 32 additions and 128 deletions.
160 changes: 32 additions & 128 deletions lib/polygeist/Passes/RaiseToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -954,145 +954,49 @@ struct AffineForOpRaising : public OpRewritePattern<affine::AffineForOp> {
}
};

// struct RemoveIterArgs : public OpRewritePattern<scf::ForOp> {
// using OpRewritePattern<scf::ForOp>::OpRewritePattern;
// LogicalResult matchAndRewrite(scf::ForOp forOp,
// PatternRewriter &rewriter) const override {
// if (!forOp.getRegion().hasOneBlock())
// return failure();
// unsigned numIterArgs = forOp.getNumRegionIterArgs();
// auto loc = forOp->getLoc();
// bool changed = false;
// llvm::SetVector<unsigned> removed;
// llvm::MapVector<unsigned, Value> steps;
// auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
// for (unsigned i = 0; i < numIterArgs; i++) {
// auto ba = forOp.getRegionIterArgs()[i];
// auto init = forOp.getInits()[i];
// auto next = yield->getOperand(i);

// auto increment = next.getDefiningOp<arith::AddIOp>();
// if (!increment)
// continue;

// Value step = nullptr;
// if (increment.getLhs() == ba) {
// step = increment.getRhs();
// } else {
// step = increment.getLhs();
// }
// if (!step)
// continue;

// // If it dominates the loop entry
// if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion()))
// continue;

// rewriter.setInsertionPointToStart(forOp.getBody());
// Value iterNum = rewriter.create<arith::SubIOp>(
// loc, forOp.getInductionVar(), forOp.getLowerBound());
// iterNum = rewriter.create<arith::DivSIOp>(loc, iterNum, forOp.getStep());

// Value replacementIV = rewriter.create<arith::MulIOp>(loc, iterNum, step);
// replacementIV = rewriter.create<arith::AddIOp>(loc, replacementIV, init);

// rewriter.replaceAllUsesWith(ba, replacementIV);

// removed.insert(i);
// steps.insert({i, step});
// changed = true;
// }

// if (!changed)
// return failure();

// SmallVector<Value> newInits;
// for (unsigned i = 0; i < numIterArgs; i++)
// if (!removed.contains(i))
// newInits.push_back(forOp.getInits()[i]);

// rewriter.setInsertionPoint(forOp);
// auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
// forOp.getUpperBound(),
// forOp.getStep(), newInits);
// if (!newForOp.getRegion().empty())
// newForOp.getRegion().front().erase();
// assert(newForOp.getRegion().empty());
// rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(),
// newForOp.getRegion().begin());

// SmallVector<Value> newYields;
// for (unsigned i = 0; i < numIterArgs; i++)
// if (!removed.contains(i))
// newYields.push_back(yield->getOperand(i));

// rewriter.setInsertionPoint(yield);
// rewriter.replaceOpWithNewOp<scf::YieldOp>(yield, newYields);

// llvm::BitVector toDelete(numIterArgs + 1);
// for (unsigned i = 0; i < numIterArgs; i++)
// if (removed.contains(i))
// toDelete[i + 1] = true;
// newForOp.getBody()->eraseArguments(toDelete);

// rewriter.setInsertionPoint(newForOp);
// unsigned curNewRes = 0;
// for (unsigned i = 0; i < numIterArgs; i++) {
// auto result = forOp->getResult(i);
// if (removed.contains(i)) {
// if (result.use_empty())
// continue;

// rewriter.setInsertionPointToStart(forOp.getBody());
// Value iterNum = rewriter.create<arith::SubIOp>(
// loc, forOp.getUpperBound(), forOp.getLowerBound());
// iterNum =
// rewriter.create<arith::DivSIOp>(loc, iterNum, forOp.getStep());

// Value afterLoop =
// rewriter.create<arith::MulIOp>(loc, iterNum, steps[i]);
// afterLoop =
// rewriter.create<arith::AddIOp>(loc, afterLoop, forOp.getInits()[i]);

// rewriter.replaceAllUsesWith(result, afterLoop);
// } else {
// rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++));
// }
// }

// rewriter.eraseOp(forOp);
// namespace {
// struct RaiseAffineToLinalg
// : public AffineRaiseToLinalgBase<RaiseAffineToLinalg> {

// std::shared_ptr<const FrozenRewritePatternSet> patterns;

// LogicalResult initialize(MLIRContext *context) override {
// RewritePatternSet owningPatterns(context);
// for (auto *dialect : context->getLoadedDialects())
// dialect->getCanonicalizationPatterns(owningPatterns);
// for (RegisteredOperationName op : context->getRegisteredOperations())
// op.getCanonicalizationPatterns(owningPatterns, context);

// owningPatterns.insert<AffineForOpRaising>(&getContext());

// patterns = std::make_shared<FrozenRewritePatternSet>(
// std::move(owningPatterns));
// return success();
// }
// void runOnOperation() override {
// GreedyRewriteConfig config;
// (void)applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
// }
// };
// } // namespace

namespace {
struct RaiseAffineToLinalg
: public AffineRaiseToLinalgBase<RaiseAffineToLinalg> {

std::shared_ptr<const FrozenRewritePatternSet> patterns;

LogicalResult initialize(MLIRContext *context) override {
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);

//owningPatterns.insert<RemoveIterArgs>(&getContext());
owningPatterns.insert<AffineForOpRaising>(&getContext());

patterns = std::make_shared<FrozenRewritePatternSet>(
std::move(owningPatterns));
return success();
}
void runOnOperation() override {
GreedyRewriteConfig config;
(void)applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
}
void runOnOperation() override;
};
} // namespace

void RaiseAffineToLinalg::runOnOperation() {
RewritePatternSet patterns(&getContext());
// TODO add the existing canonicalization patterns
// + subview of an affine apply -> subview
patterns.insert<AffineForOpRaising>(&getContext());
GreedyRewriteConfig config;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}

namespace mlir {
namespace polygeist {
std::unique_ptr<Pass> createRaiseAffineToLinalgPass() {
Expand Down

0 comments on commit d285fb5

Please sign in to comment.