Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1077,4 +1077,11 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
];
}

def WhileLoopOutsideValuesAddToArgumentListPass : Pass<
"while-loop-outside-values-add-to-argument-list"> {
let dependentDialects = [
"stablehlo::StablehloDialect"
];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "stablehlo/dialect/StablehloOps.h"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_WHILELOOPOUTSIDEVALUESADDTOARGUMENTLISTPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::stablehlo;

namespace {

static bool definedOutside(Value v, Operation *op) {
return !op->isAncestor(v.getParentBlock()->getParentOp());
}

struct SHLOWhileOpUpdateArgumentListPattern final
: public OpRewritePattern<stablehlo::WhileOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::WhileOp whileOp,
PatternRewriter &rewriter) const override {
SetVector<Value> extraValuesSet;
getUsedValuesDefinedAbove(whileOp->getRegions(), extraValuesSet);
SmallVector<Value> extraValues = extraValuesSet.takeVector();

if (extraValues.empty())
return failure();

// Build new operand list = existing operands + external values
SmallVector<Value, 8> newOperands(whileOp.getOperands().begin(),
whileOp.getOperands().end());
for (Value v : extraValues)
newOperands.push_back(v);

SmallVector<Type, 8> newResultTypes;
newResultTypes.reserve(newOperands.size());
for (Value v : newOperands)
newResultTypes.push_back(v.getType());

auto newWhile = stablehlo::WhileOp::create(rewriter, whileOp.getLoc(),
newResultTypes, newOperands);

rewriter.inlineRegionBefore(whileOp.getCond(), newWhile.getCond(),
newWhile.getCond().end());
rewriter.inlineRegionBefore(whileOp.getBody(), newWhile.getBody(),
newWhile.getBody().end());

// Append block arguments for the extra values
Block &condBlock = newWhile.getCond().front();
Block &bodyBlock = newWhile.getBody().front();

unsigned origArgCount = whileOp.getNumOperands();
SmallVector<BlockArgument, 8> addedCondArgs, addedBodyArgs;
addedCondArgs.reserve(extraValues.size());
addedBodyArgs.reserve(extraValues.size());
for (Value v : extraValues) {
addedCondArgs.push_back(condBlock.addArgument(v.getType(), v.getLoc()));
addedBodyArgs.push_back(bodyBlock.addArgument(v.getType(), v.getLoc()));
}

// Remap uses of external values inside the regions to the new block args
auto remapRegionUses = [&](Region &region, ArrayRef<Value> externals,
ArrayRef<BlockArgument> args) {
region.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
Value v = operand.get();
for (auto [ext, arg] : llvm::zip(externals, args)) {
if (v == ext) {
operand.set(arg);
break;
}
}
}
});
};

remapRegionUses(newWhile.getCond(), extraValues, addedCondArgs);
remapRegionUses(newWhile.getBody(), extraValues, addedBodyArgs);

Operation *terminator = bodyBlock.getTerminator();
if (!terminator) {
return rewriter.notifyMatchFailure(whileOp, "missing body terminator");
}

auto retOp = dyn_cast<stablehlo::ReturnOp>(terminator);
assert(retOp && "expected stablehlo::ReturnOp");

SmallVector<Value, 8> newRetVals(retOp.getOperands().begin(),
retOp.getOperands().end());
for (BlockArgument arg : addedBodyArgs)
newRetVals.push_back(arg);

rewriter.setInsertionPoint(terminator);
rewriter.replaceOpWithNewOp<stablehlo::ReturnOp>(terminator, newRetVals);

for (unsigned i = 0; i < origArgCount; ++i)
rewriter.replaceAllUsesWith(whileOp.getResult(i), newWhile.getResult(i));
rewriter.eraseOp(whileOp);
return success();
}
};

struct WhileLoopOutsideValuesAddToArgumentListPass
: public enzyme::impl::WhileLoopOutsideValuesAddToArgumentListPassBase<
WhileLoopOutsideValuesAddToArgumentListPass> {
using WhileLoopOutsideValuesAddToArgumentListPassBase::
WhileLoopOutsideValuesAddToArgumentListPassBase;

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<SHLOWhileOpUpdateArgumentListPattern>(patterns.getContext());
walkAndApplyPatterns(getOperation(), std::move(patterns));
}
};

} // namespace
67 changes: 67 additions & 0 deletions test/lit_tests/loop_all_values_defined.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// RUN: enzymexlamlir-opt %s --while-loop-outside-values-add-to-argument-list | FileCheck %s

func.func @main(%arg0: tensor<25xf32>) -> tensor<13xf32> {
%cst = stablehlo.constant dense<3.000000e+00> : tensor<1xf32>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
%c = stablehlo.constant dense<1> : tensor<i32>
%c_1 = stablehlo.constant dense<5> : tensor<i64>
%c_2 = stablehlo.constant dense<2> : tensor<i64>
%c_3 = stablehlo.constant dense<0> : tensor<i64>
%c_4 = stablehlo.constant dense<10> : tensor<i64>
%c_5 = stablehlo.constant dense<1> : tensor<i64>
%cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<13xf32>
%0:2 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %cst_6) : tensor<i64>, tensor<13xf32>
cond {
%1 = stablehlo.compare LT, %iterArg, %c_4 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %1 : tensor<i1>
} do {
%1 = stablehlo.add %c_5, %iterArg : tensor<i64>
%2 = stablehlo.multiply %c_2, %1 : tensor<i64>
%3 = stablehlo.add %2, %c_1 : tensor<i64>
%4 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32>
%5 = stablehlo.subtract %4, %c : tensor<i32>
%6 = stablehlo.dynamic_slice %arg0, %5, sizes = [1] : (tensor<25xf32>, tensor<i32>) -> tensor<1xf32>
%7 = stablehlo.multiply %6, %cst : tensor<1xf32>
%8 = stablehlo.subtract %7, %cst_0 : tensor<1xf32>
%9 = stablehlo.sine %8 : tensor<1xf32>
%10 = stablehlo.add %1, %c_2 : tensor<i64>
%11 = stablehlo.convert %10 : (tensor<i64>) -> tensor<i32>
%12 = stablehlo.subtract %11, %c : tensor<i32>
%13 = stablehlo.dynamic_update_slice %iterArg_7, %9, %12 : (tensor<13xf32>, tensor<1xf32>, tensor<i32>) -> tensor<13xf32>
stablehlo.return %1, %13 : tensor<i64>, tensor<13xf32>
}
return %0#1 : tensor<13xf32>
}

// CHECK: func.func @main(%arg0: tensor<25xf32>) -> tensor<13xf32> {
// CHECK-NEXT: %cst = stablehlo.constant dense<3.000000e+00> : tensor<1xf32>
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<i32>
// CHECK-NEXT: %c_1 = stablehlo.constant dense<5> : tensor<i64>
// CHECK-NEXT: %c_2 = stablehlo.constant dense<2> : tensor<i64>
// CHECK-NEXT: %c_3 = stablehlo.constant dense<0> : tensor<i64>
// CHECK-NEXT: %c_4 = stablehlo.constant dense<10> : tensor<i64>
// CHECK-NEXT: %c_5 = stablehlo.constant dense<1> : tensor<i64>
// CHECK-NEXT: %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<13xf32>
// CHECK-NEXT: %0:10 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %cst_6, %iterArg_8 = %c_4, %iterArg_9 = %c_5, %iterArg_10 = %c_2, %iterArg_11 = %c_1, %iterArg_12 = %c, %iterArg_13 = %arg0, %iterArg_14 = %cst, %iterArg_15 = %cst_0) : tensor<i64>, tensor<13xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i32>, tensor<25xf32>, tensor<1xf32>, tensor<1xf32>
// CHECK-NEXT: cond {
// CHECK-NEXT: %1 = stablehlo.compare LT, %iterArg, %iterArg_8 : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: stablehlo.return %1 : tensor<i1>
// CHECK-NEXT: } do {
// CHECK-NEXT: %1 = stablehlo.add %iterArg_9, %iterArg : tensor<i64>
// CHECK-NEXT: %2 = stablehlo.multiply %iterArg_10, %1 : tensor<i64>
// CHECK-NEXT: %3 = stablehlo.add %2, %iterArg_11 : tensor<i64>
// CHECK-NEXT: %4 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: %5 = stablehlo.subtract %4, %iterArg_12 : tensor<i32>
// CHECK-NEXT: %6 = stablehlo.dynamic_slice %iterArg_13, %5, sizes = [1] : (tensor<25xf32>, tensor<i32>) -> tensor<1xf32>
// CHECK-NEXT: %7 = stablehlo.multiply %6, %iterArg_14 : tensor<1xf32>
// CHECK-NEXT: %8 = stablehlo.subtract %7, %iterArg_15 : tensor<1xf32>
// CHECK-NEXT: %9 = stablehlo.sine %8 : tensor<1xf32>
// CHECK-NEXT: %10 = stablehlo.add %1, %iterArg_10 : tensor<i64>
// CHECK-NEXT: %11 = stablehlo.convert %10 : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: %12 = stablehlo.subtract %11, %iterArg_12 : tensor<i32>
// CHECK-NEXT: %13 = stablehlo.dynamic_update_slice %iterArg_7, %9, %12 : (tensor<13xf32>, tensor<1xf32>, tensor<i32>) -> tensor<13xf32>
// CHECK-NEXT: stablehlo.return %1, %13, %iterArg_8, %iterArg_9, %iterArg_10, %iterArg_11, %iterArg_12, %iterArg_13, %iterArg_14, %iterArg_15 : tensor<i64>, tensor<13xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i32>, tensor<25xf32>, tensor<1xf32>, tensor<1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %0#1 : tensor<13xf32>
// CHECK-NEXT: }
Loading