From 9406020fc2474de1d5400d18dc090d4b2c2051a2 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov
Date: Fri, 5 Dec 2025 09:12:08 +0900
Subject: [PATCH] Add raising from memref atomic rmw to affine

---
 src/enzyme_ad/jax/Passes/AffineCFG.cpp | 58 +++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 56 insertions(+), 2 deletions(-)

diff --git a/src/enzyme_ad/jax/Passes/AffineCFG.cpp b/src/enzyme_ad/jax/Passes/AffineCFG.cpp
index 4cafae9809..ad18bc5943 100644
--- a/src/enzyme_ad/jax/Passes/AffineCFG.cpp
+++ b/src/enzyme_ad/jax/Passes/AffineCFG.cpp
@@ -1,3 +1,4 @@
+#include "Enzyme/MLIR/Dialect/Ops.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -1643,6 +1644,59 @@ static void replaceLoad(memref::LoadOp load,
   load.erase();
 }
 */
+
+/// Raises `memref.atomic_rmw` to `enzyme::AffineAtomicRMWOp`. Only matches if
+/// every index passes `isValidIndex` for the op's local affine scope; the
+/// indices are then composed into a canonicalized affine map and operands.
+struct MoveRMWToAffine : public OpRewritePattern<memref::AtomicRMWOp> {
+  using OpRewritePattern<memref::AtomicRMWOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::AtomicRMWOp rmw,
+                                PatternRewriter &rewriter) const override {
+    auto scope = getLocalAffineScope(rmw);
+    for (auto idx : rmw.getIndices()) {
+      if (!isValidIndex(idx, scope)) {
+        return failure();
+      }
+    }
+
+    auto memrefType = cast<MemRefType>(rmw.getMemref().getType());
+    int64_t rank = memrefType.getRank();
+
+    // Start from an identity map over `rank` symbols (one per index); for a
+    // zero-dimensional memref this is the empty map () -> ().
+    SmallVector<AffineExpr, 4> dimExprs;
+    dimExprs.reserve(rank);
+    for (unsigned i = 0; i < rank; ++i)
+      dimExprs.push_back(rewriter.getAffineSymbolExpr(i));
+    auto map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs,
+                              rewriter.getContext());
+
+    SmallVector<Value, 4> operands = rmw.getIndices();
+
+    if (map.getNumInputs() != operands.size()) {
+      // rmw->getParentOfType<func::FuncOp>().dump();
+      llvm::errs() << " rmw: " << rmw << "\n";
+    }
+    auto *parentScope = scope->getParentOp();
+    DominanceInfo DI(parentScope);
+    assert(map.getNumInputs() == operands.size());
+    fully2ComposeAffineMapAndOperands(rewriter, &map, &operands, DI, scope);
+    assert(map.getNumInputs() == operands.size());
+    affine::canonicalizeMapAndOperands(&map, &operands);
+    map = recreateExpr(map);
+    assert(map.getNumInputs() == operands.size());
+
+    auto affineRMW = enzyme::AffineAtomicRMWOp::create(
+        rewriter, rmw.getLoc(), rmw.getValue().getType(), rmw.getKind(),
+        rmw.getValue(), rmw.getMemref(), operands, map);
+    // Replace via the rewriter (not Value::replaceAllUsesWith) so the pattern
+    // driver is notified about the updated users of the result.
+    rewriter.replaceOp(rmw, affineRMW.getResult());
+    return success();
+  }
+};
+
 struct MoveLoadToAffine : public OpRewritePattern<memref::LoadOp> {
   using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
 
@@ -5906,8 +5960,8 @@ void mlir::enzyme::populateAffineCFGPatterns(RewritePatternSet &rpl) {
           CanonicalizeIndexCast, AffineIfYieldMovementPattern,
           /* IndexCastMovement,*/ AffineFixup<affine::AffineLoadOp>,
           AffineFixup<affine::AffineStoreOp>, CanonicalizIfBounds,
-          MoveStoreToAffine, MoveIfToAffine, MoveLoadToAffine, MoveExtToAffine,
-          MoveSIToFPToAffine, CmpExt, MoveSelectToAffine,
+          MoveStoreToAffine, MoveIfToAffine, MoveRMWToAffine, MoveLoadToAffine,
+          MoveExtToAffine, MoveSIToFPToAffine, CmpExt, MoveSelectToAffine,
           AffineIfSimplification, AffineIfSimplificationIsl, CombineAffineIfs,
           MergeNestedAffineParallelLoops, PrepMergeNestedAffineParallelLoops,
           MergeNestedAffineParallelIf, MergeParallelInductions, OptimizeRem,