From fed98ee459ebe59f54eff37fef7dbdf96f9bbf46 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Dec 2025 12:33:54 -0600 Subject: [PATCH 1/2] feat: add outside values to while loop argument list --- src/enzyme_ad/jax/Passes/Passes.td | 7 + ...LoopOutsideValuesAddToArgumentListPass.cpp | 141 ++++++++++++++++++ test/lit_tests/loop_all_values_defined.mlir | 67 +++++++++ 3 files changed, 215 insertions(+) create mode 100644 src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp create mode 100644 test/lit_tests/loop_all_values_defined.mlir diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 5bf009201..859c4c30c 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1077,4 +1077,11 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> { ]; } +def WhileLoopOutsideValuesAddToArgumentListPass : Pass< + "while-loop-outside-values-add-to-argument-list"> { + let dependentDialects = [ + "stablehlo::StablehloDialect" + ]; +} + #endif diff --git a/src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp b/src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp new file mode 100644 index 000000000..ea6214cfe --- /dev/null +++ b/src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp @@ -0,0 +1,141 @@ +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_WHILELOOPOUTSIDEVALUESADDTOARGUMENTLISTPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::stablehlo; + +namespace { + +static bool definedOutside(Value v, Operation *op) { + return !op->isAncestor(v.getParentBlock()->getParentOp()); +} + +struct SHLOWhileOpUpdateArgumentListPattern final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::WhileOp whileOp, + PatternRewriter &rewriter) const override { + // Collect values used inside cond/body that are defined outside the WhileOp + SmallVector extraValues; + SmallPtrSet seen; + + auto collectExternal = [&](Region ®ion) { + region.walk([&](Operation *op) { + for (OpOperand &operand : op->getOpOperands()) { + Value v = operand.get(); + if (!v) + continue; + if (definedOutside(v, whileOp) && !seen.contains(v)) { + seen.insert(v); + extraValues.push_back(v); + } + } + }); + }; + + collectExternal(whileOp.getCond()); + collectExternal(whileOp.getBody()); + + if (extraValues.empty()) + return failure(); + + // Build new operand list = existing operands + external values + SmallVector newOperands(whileOp.getOperands().begin(), + whileOp.getOperands().end()); + for (Value v : extraValues) + newOperands.push_back(v); + + SmallVector newResultTypes; + newResultTypes.reserve(newOperands.size()); + for (Value v : newOperands) + newResultTypes.push_back(v.getType()); + + auto newWhile = stablehlo::WhileOp::create(rewriter, whileOp.getLoc(), + newResultTypes, newOperands); + + rewriter.inlineRegionBefore(whileOp.getCond(), newWhile.getCond(), + newWhile.getCond().end()); + rewriter.inlineRegionBefore(whileOp.getBody(), newWhile.getBody(), + newWhile.getBody().end()); + + // Append block arguments for the extra values + Block &condBlock = newWhile.getCond().front(); + Block &bodyBlock = newWhile.getBody().front(); + + unsigned origArgCount = whileOp.getNumOperands(); + SmallVector addedCondArgs, addedBodyArgs; + addedCondArgs.reserve(extraValues.size()); + addedBodyArgs.reserve(extraValues.size()); + for (Value v : extraValues) { + addedCondArgs.push_back(condBlock.addArgument(v.getType(), v.getLoc())); + addedBodyArgs.push_back(bodyBlock.addArgument(v.getType(), v.getLoc())); + } + + // Remap uses of external values inside the regions to the new block args + auto remapRegionUses = [&](Region ®ion, ArrayRef externals, + ArrayRef args) { + region.walk([&](Operation *op) { + for (OpOperand &operand : op->getOpOperands()) { + Value v = operand.get(); + for (auto [ext, arg] : llvm::zip(externals, args)) { + if (v == ext) { + operand.set(arg); + break; + } + } + } + }); + }; + + remapRegionUses(newWhile.getCond(), extraValues, addedCondArgs); + remapRegionUses(newWhile.getBody(), extraValues, addedBodyArgs); + + Operation *terminator = bodyBlock.getTerminator(); + if (!terminator) { + return rewriter.notifyMatchFailure(whileOp, "missing body terminator"); + } + + auto retOp = dyn_cast(terminator); + assert(retOp && "expected stablehlo::ReturnOp"); + + SmallVector newRetVals(retOp.getOperands().begin(), + retOp.getOperands().end()); + for (BlockArgument arg : addedBodyArgs) + newRetVals.push_back(arg); + + rewriter.setInsertionPoint(terminator); + rewriter.replaceOpWithNewOp(terminator, newRetVals); + + for (unsigned i = 0; i < origArgCount; ++i) + rewriter.replaceAllUsesWith(whileOp.getResult(i), newWhile.getResult(i)); + rewriter.eraseOp(whileOp); + return success(); + } +}; + +struct WhileLoopOutsideValuesAddToArgumentListPass + : public enzyme::impl::WhileLoopOutsideValuesAddToArgumentListPassBase< + WhileLoopOutsideValuesAddToArgumentListPass> { + using WhileLoopOutsideValuesAddToArgumentListPassBase:: + WhileLoopOutsideValuesAddToArgumentListPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext()); + walkAndApplyPatterns(getOperation(), std::move(patterns)); + } +}; + +} // namespace diff --git a/test/lit_tests/loop_all_values_defined.mlir b/test/lit_tests/loop_all_values_defined.mlir new file mode 100644 index 000000000..41df8bf13 --- /dev/null +++ b/test/lit_tests/loop_all_values_defined.mlir @@ -0,0 +1,67 @@ +// RUN: enzymexlamlir-opt %s --while-loop-outside-values-add-to-argument-list | FileCheck %s + +func.func @main(%arg0: tensor<25xf32>) -> tensor<13xf32> { + %cst = stablehlo.constant dense<3.000000e+00> : tensor<1xf32> + %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32> + %c = stablehlo.constant dense<1> : tensor + %c_1 = stablehlo.constant dense<5> : tensor + %c_2 = stablehlo.constant dense<2> : tensor + %c_3 = stablehlo.constant dense<0> : tensor + %c_4 = stablehlo.constant dense<10> : tensor + %c_5 = stablehlo.constant dense<1> : tensor + %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<13xf32> + %0:2 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %cst_6) : tensor, tensor<13xf32> + cond { + %1 = stablehlo.compare LT, %iterArg, %c_4 : (tensor, tensor) -> tensor + stablehlo.return %1 : tensor + } do { + %1 = stablehlo.add %c_5, %iterArg : tensor + %2 = stablehlo.multiply %c_2, %1 : tensor + %3 = stablehlo.add %2, %c_1 : tensor + %4 = stablehlo.convert %3 : (tensor) -> tensor + %5 = stablehlo.subtract %4, %c : tensor + %6 = stablehlo.dynamic_slice %arg0, %5, sizes = [1] : (tensor<25xf32>, tensor) -> tensor<1xf32> + %7 = stablehlo.multiply %6, %cst : tensor<1xf32> + %8 = stablehlo.subtract %7, %cst_0 : tensor<1xf32> + %9 = stablehlo.sine %8 : tensor<1xf32> + %10 = stablehlo.add %1, %c_2 : tensor + %11 = stablehlo.convert %10 : (tensor) -> tensor + %12 = stablehlo.subtract %11, %c : tensor + %13 = stablehlo.dynamic_update_slice %iterArg_7, %9, %12 : (tensor<13xf32>, tensor<1xf32>, tensor) -> tensor<13xf32> + stablehlo.return %1, %13 : tensor, tensor<13xf32> + } + return %0#1 : tensor<13xf32> +} + +// CHECK: func.func @main(%arg0: tensor<25xf32>) -> tensor<13xf32> { +// CHECK-NEXT: %cst = stablehlo.constant dense<3.000000e+00> : tensor<1xf32> +// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32> +// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor +// CHECK-NEXT: %c_1 = stablehlo.constant dense<5> : tensor +// CHECK-NEXT: %c_2 = stablehlo.constant dense<2> : tensor +// CHECK-NEXT: %c_3 = stablehlo.constant dense<0> : tensor +// CHECK-NEXT: %c_4 = stablehlo.constant dense<10> : tensor +// CHECK-NEXT: %c_5 = stablehlo.constant dense<1> : tensor +// CHECK-NEXT: %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<13xf32> +// CHECK-NEXT: %0:10 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %cst_6, %iterArg_8 = %c_4, %iterArg_9 = %c_5, %iterArg_10 = %c_2, %iterArg_11 = %c_1, %iterArg_12 = %c, %iterArg_13 = %arg0, %iterArg_14 = %cst, %iterArg_15 = %cst_0) : tensor, tensor<13xf32>, tensor, tensor, tensor, tensor, tensor, tensor<25xf32>, tensor<1xf32>, tensor<1xf32> +// CHECK-NEXT: cond { +// CHECK-NEXT: %1 = stablehlo.compare LT, %iterArg, %iterArg_8 : (tensor, tensor) -> tensor +// CHECK-NEXT: stablehlo.return %1 : tensor +// CHECK-NEXT: } do { +// CHECK-NEXT: %1 = stablehlo.add %iterArg_9, %iterArg : tensor +// CHECK-NEXT: %2 = stablehlo.multiply %iterArg_10, %1 : tensor +// CHECK-NEXT: %3 = stablehlo.add %2, %iterArg_11 : tensor +// CHECK-NEXT: %4 = stablehlo.convert %3 : (tensor) -> tensor +// CHECK-NEXT: %5 = stablehlo.subtract %4, %iterArg_12 : tensor +// CHECK-NEXT: %6 = stablehlo.dynamic_slice %iterArg_13, %5, sizes = [1] : (tensor<25xf32>, tensor) -> tensor<1xf32> +// CHECK-NEXT: %7 = stablehlo.multiply %6, %iterArg_14 : tensor<1xf32> +// CHECK-NEXT: %8 = stablehlo.subtract %7, %iterArg_15 : tensor<1xf32> +// CHECK-NEXT: %9 = stablehlo.sine %8 : tensor<1xf32> +// CHECK-NEXT: %10 = stablehlo.add %1, %iterArg_10 : tensor +// CHECK-NEXT: %11 = stablehlo.convert %10 : (tensor) -> tensor +// CHECK-NEXT: %12 = stablehlo.subtract %11, %iterArg_12 : tensor +// CHECK-NEXT: %13 = stablehlo.dynamic_update_slice %iterArg_7, %9, %12 : (tensor<13xf32>, tensor<1xf32>, tensor) -> tensor<13xf32> +// CHECK-NEXT: stablehlo.return %1, %13, %iterArg_8, %iterArg_9, %iterArg_10, %iterArg_11, %iterArg_12, %iterArg_13, %iterArg_14, %iterArg_15 : tensor, tensor<13xf32>, tensor, tensor, tensor, tensor, tensor, tensor<25xf32>, tensor<1xf32>, tensor<1xf32> +// CHECK-NEXT: } +// CHECK-NEXT: return %0#1 : tensor<13xf32> +// CHECK-NEXT: } From 90beba7b916ec5116b717886621f967c0e2a4512 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 3 Dec 2025 09:20:03 -0500 Subject: [PATCH 2/2] refactor: use `getUsedValuesDefinedAbove` Co-authored-by: Paul Berg --- ...LoopOutsideValuesAddToArgumentListPass.cpp | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp b/src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp index ea6214cfe..e9d1bf1ba 100644 --- a/src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp +++ b/src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp @@ -1,4 +1,5 @@ #include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/RegionUtils.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "src/enzyme_ad/jax/Passes/Passes.h" @@ -27,26 +28,9 @@ struct SHLOWhileOpUpdateArgumentListPattern final LogicalResult matchAndRewrite(stablehlo::WhileOp whileOp, PatternRewriter &rewriter) const override { - // Collect values used inside cond/body that are defined outside the WhileOp - SmallVector extraValues; - SmallPtrSet seen; - - auto collectExternal = [&](Region ®ion) { - region.walk([&](Operation *op) { - for (OpOperand &operand : op->getOpOperands()) { - Value v = operand.get(); - if (!v) - continue; - if (definedOutside(v, whileOp) && !seen.contains(v)) { - seen.insert(v); - extraValues.push_back(v); - } - } - }); - }; - - collectExternal(whileOp.getCond()); - collectExternal(whileOp.getBody()); + SetVector extraValuesSet; + getUsedValuesDefinedAbove(whileOp->getRegions(), extraValuesSet); + SmallVector extraValues = extraValuesSet.takeVector(); if (extraValues.empty()) return failure();