Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions src/enzyme_ad/jax/Passes/AffineCFG.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "Enzyme/MLIR/Dialect/Ops.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
Expand Down Expand Up @@ -1643,6 +1644,55 @@ static void replaceLoad(memref::LoadOp load,
load.erase();
}
*/

/// Rewrites a `memref::AtomicRMWOp` into an `enzyme::AffineAtomicRMWOp`.
///
/// The pattern only fires when every index operand of the RMW passes
/// `isValidIndex` with respect to the scope returned by
/// `getLocalAffineScope`; otherwise it reports failure and leaves the IR
/// untouched. On success the indices are folded into an affine map (one
/// result per memref dimension), and the original op is replaced by the
/// affine form with the same kind, value and memref.
struct MoveRMWToAffine : public OpRewritePattern<memref::AtomicRMWOp> {
  using OpRewritePattern<memref::AtomicRMWOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::AtomicRMWOp rmw,
                                PatternRewriter &rewriter) const override {
    auto scope = getLocalAffineScope(rmw);
    // Bail out before touching the IR if any index cannot be used as an
    // affine operand in this scope.
    for (auto idx : rmw.getIndices()) {
      if (!isValidIndex(idx, scope)) {
        return failure();
      }
    }

    auto memrefType = cast<MemRefType>(rmw.getMemref().getType());
    const unsigned rank = memrefType.getRank();

    // Start from a map with no dimensions and `rank` symbols whose i-th
    // result is simply symbol i, i.e. ()[s0, ..., s{rank-1}] -> (s0, ...).
    // For a zero-dimensional memref this degenerates to () -> ().
    // NOTE(review): the inputs are deliberately created as symbols rather
    // than dims; presumably the composition/canonicalization below sorts
    // them into dims and symbols as appropriate -- confirm against
    // fully2ComposeAffineMapAndOperands.
    SmallVector<AffineExpr, 4> indexExprs;
    indexExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      indexExprs.push_back(rewriter.getAffineSymbolExpr(i));
    auto map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, indexExprs,
                              rewriter.getContext());

    SmallVector<Value, 4> operands = rmw.getIndices();

    // Print the offending op before the assertion below fires so that the
    // failure is diagnosable.
    if (map.getNumInputs() != operands.size()) {
      llvm::errs() << " rmw: " << rmw << "\n";
    }
    auto *parentScope = scope->getParentOp();
    DominanceInfo DI(parentScope);
    assert(map.getNumInputs() == operands.size());
    // Each step below may rewrite both the map and its operand list; the
    // two must stay in sync, hence the repeated arity checks.
    fully2ComposeAffineMapAndOperands(rewriter, &map, &operands, DI, scope);
    assert(map.getNumInputs() == operands.size());
    affine::canonicalizeMapAndOperands(&map, &operands);
    map = recreateExpr(map);
    assert(map.getNumInputs() == operands.size());

    auto affineRmw = enzyme::AffineAtomicRMWOp::create(
        rewriter, rmw.getLoc(), rmw.getValue().getType(), rmw.getKind(),
        rmw.getValue(), rmw.getMemref(), operands, map);
    // Replace through the rewriter (rather than Value::replaceAllUsesWith)
    // so the pattern driver is notified of the use changes and the erase.
    rewriter.replaceOp(rmw, affineRmw.getResult());
    return success();
  }
};

struct MoveLoadToAffine : public OpRewritePattern<memref::LoadOp> {
using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

Expand Down Expand Up @@ -5906,8 +5956,8 @@ void mlir::enzyme::populateAffineCFGPatterns(RewritePatternSet &rpl) {
CanonicalizeIndexCast<IndexCastUIOp>, AffineIfYieldMovementPattern,
/* IndexCastMovement,*/ AffineFixup<affine::AffineLoadOp>,
AffineFixup<affine::AffineStoreOp>, CanonicalizIfBounds,
MoveStoreToAffine, MoveIfToAffine, MoveLoadToAffine, MoveExtToAffine,
MoveSIToFPToAffine, CmpExt, MoveSelectToAffine,
MoveStoreToAffine, MoveIfToAffine, MoveRMWToAffine, MoveLoadToAffine,
MoveExtToAffine, MoveSIToFPToAffine, CmpExt, MoveSelectToAffine,
AffineIfSimplification, AffineIfSimplificationIsl, CombineAffineIfs,
MergeNestedAffineParallelLoops, PrepMergeNestedAffineParallelLoops,
MergeNestedAffineParallelIf, MergeParallelInductions, OptimizeRem,
Expand Down
Loading