Skip to content

Commit fed98ee

Browse files
committed
feat: add outside values to while loop argument list
1 parent f603104 commit fed98ee

File tree

3 files changed

+215
-0
lines changed

3 files changed

+215
-0
lines changed

src/enzyme_ad/jax/Passes/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,4 +1077,11 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
10771077
];
10781078
}
10791079

1080+
// Pass registered under the `while-loop-outside-values-add-to-argument-list`
// command-line flag. It depends on the StableHLO dialect being loaded.
def WhileLoopOutsideValuesAddToArgumentListPass : Pass<
    "while-loop-outside-values-add-to-argument-list"> {
  let summary = "Pass values captured by stablehlo.while regions as explicit loop operands";
  let dependentDialects = [
    "stablehlo::StablehloDialect"
  ];
}
1086+
10801087
#endif
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#include "mlir/IR/PatternMatch.h"
2+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
3+
4+
#include "src/enzyme_ad/jax/Passes/Passes.h"
5+
6+
#include "stablehlo/dialect/StablehloOps.h"
7+
8+
namespace mlir {
9+
namespace enzyme {
10+
#define GEN_PASS_DEF_WHILELOOPOUTSIDEVALUESADDTOARGUMENTLISTPASS
11+
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
12+
} // namespace enzyme
13+
} // namespace mlir
14+
15+
using namespace mlir;
16+
using namespace mlir::stablehlo;
17+
18+
namespace {
19+
20+
/// Returns true when `v` is defined outside of `op`, i.e. when `op` is not an
/// ancestor of (and is not itself) the operation that owns the block `v`
/// lives in.
static bool definedOutside(Value v, Operation *op) {
  Operation *owningOp = v.getParentBlock()->getParentOp();
  if (op->isAncestor(owningOp))
    return false;
  return true;
}
23+
24+
struct SHLOWhileOpUpdateArgumentListPattern final
25+
: public OpRewritePattern<stablehlo::WhileOp> {
26+
using OpRewritePattern::OpRewritePattern;
27+
28+
LogicalResult matchAndRewrite(stablehlo::WhileOp whileOp,
29+
PatternRewriter &rewriter) const override {
30+
// Collect values used inside cond/body that are defined outside the WhileOp
31+
SmallVector<Value, 4> extraValues;
32+
SmallPtrSet<Value, 8> seen;
33+
34+
auto collectExternal = [&](Region &region) {
35+
region.walk([&](Operation *op) {
36+
for (OpOperand &operand : op->getOpOperands()) {
37+
Value v = operand.get();
38+
if (!v)
39+
continue;
40+
if (definedOutside(v, whileOp) && !seen.contains(v)) {
41+
seen.insert(v);
42+
extraValues.push_back(v);
43+
}
44+
}
45+
});
46+
};
47+
48+
collectExternal(whileOp.getCond());
49+
collectExternal(whileOp.getBody());
50+
51+
if (extraValues.empty())
52+
return failure();
53+
54+
// Build new operand list = existing operands + external values
55+
SmallVector<Value, 8> newOperands(whileOp.getOperands().begin(),
56+
whileOp.getOperands().end());
57+
for (Value v : extraValues)
58+
newOperands.push_back(v);
59+
60+
SmallVector<Type, 8> newResultTypes;
61+
newResultTypes.reserve(newOperands.size());
62+
for (Value v : newOperands)
63+
newResultTypes.push_back(v.getType());
64+
65+
auto newWhile = stablehlo::WhileOp::create(rewriter, whileOp.getLoc(),
66+
newResultTypes, newOperands);
67+
68+
rewriter.inlineRegionBefore(whileOp.getCond(), newWhile.getCond(),
69+
newWhile.getCond().end());
70+
rewriter.inlineRegionBefore(whileOp.getBody(), newWhile.getBody(),
71+
newWhile.getBody().end());
72+
73+
// Append block arguments for the extra values
74+
Block &condBlock = newWhile.getCond().front();
75+
Block &bodyBlock = newWhile.getBody().front();
76+
77+
unsigned origArgCount = whileOp.getNumOperands();
78+
SmallVector<BlockArgument, 8> addedCondArgs, addedBodyArgs;
79+
addedCondArgs.reserve(extraValues.size());
80+
addedBodyArgs.reserve(extraValues.size());
81+
for (Value v : extraValues) {
82+
addedCondArgs.push_back(condBlock.addArgument(v.getType(), v.getLoc()));
83+
addedBodyArgs.push_back(bodyBlock.addArgument(v.getType(), v.getLoc()));
84+
}
85+
86+
// Remap uses of external values inside the regions to the new block args
87+
auto remapRegionUses = [&](Region &region, ArrayRef<Value> externals,
88+
ArrayRef<BlockArgument> args) {
89+
region.walk([&](Operation *op) {
90+
for (OpOperand &operand : op->getOpOperands()) {
91+
Value v = operand.get();
92+
for (auto [ext, arg] : llvm::zip(externals, args)) {
93+
if (v == ext) {
94+
operand.set(arg);
95+
break;
96+
}
97+
}
98+
}
99+
});
100+
};
101+
102+
remapRegionUses(newWhile.getCond(), extraValues, addedCondArgs);
103+
remapRegionUses(newWhile.getBody(), extraValues, addedBodyArgs);
104+
105+
Operation *terminator = bodyBlock.getTerminator();
106+
if (!terminator) {
107+
return rewriter.notifyMatchFailure(whileOp, "missing body terminator");
108+
}
109+
110+
auto retOp = dyn_cast<stablehlo::ReturnOp>(terminator);
111+
assert(retOp && "expected stablehlo::ReturnOp");
112+
113+
SmallVector<Value, 8> newRetVals(retOp.getOperands().begin(),
114+
retOp.getOperands().end());
115+
for (BlockArgument arg : addedBodyArgs)
116+
newRetVals.push_back(arg);
117+
118+
rewriter.setInsertionPoint(terminator);
119+
rewriter.replaceOpWithNewOp<stablehlo::ReturnOp>(terminator, newRetVals);
120+
121+
for (unsigned i = 0; i < origArgCount; ++i)
122+
rewriter.replaceAllUsesWith(whileOp.getResult(i), newWhile.getResult(i));
123+
rewriter.eraseOp(whileOp);
124+
return success();
125+
}
126+
};
127+
128+
struct WhileLoopOutsideValuesAddToArgumentListPass
129+
: public enzyme::impl::WhileLoopOutsideValuesAddToArgumentListPassBase<
130+
WhileLoopOutsideValuesAddToArgumentListPass> {
131+
using WhileLoopOutsideValuesAddToArgumentListPassBase::
132+
WhileLoopOutsideValuesAddToArgumentListPassBase;
133+
134+
void runOnOperation() override {
135+
RewritePatternSet patterns(&getContext());
136+
patterns.add<SHLOWhileOpUpdateArgumentListPattern>(patterns.getContext());
137+
walkAndApplyPatterns(getOperation(), std::move(patterns));
138+
}
139+
};
140+
141+
} // namespace
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: enzymexlamlir-opt %s --while-loop-outside-values-add-to-argument-list | FileCheck %s
2+
3+
// Test input: a stablehlo.while whose regions use values defined outside the
// loop. The cond region uses %c_4; the body region uses %c_5, %c_2, %c_1, %c,
// %arg0, %cst and %cst_0. Only %c_3 and %cst_6 are passed as loop operands.
func.func @main(%arg0: tensor<25xf32>) -> tensor<13xf32> {
4+
%cst = stablehlo.constant dense<3.000000e+00> : tensor<1xf32>
5+
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
6+
%c = stablehlo.constant dense<1> : tensor<i32>
7+
%c_1 = stablehlo.constant dense<5> : tensor<i64>
8+
%c_2 = stablehlo.constant dense<2> : tensor<i64>
9+
%c_3 = stablehlo.constant dense<0> : tensor<i64>
10+
%c_4 = stablehlo.constant dense<10> : tensor<i64>
11+
%c_5 = stablehlo.constant dense<1> : tensor<i64>
12+
%cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<13xf32>
13+
// Loop carries two values: an i64 counter and a tensor<13xf32> accumulator.
%0:2 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %cst_6) : tensor<i64>, tensor<13xf32>
14+
cond {
15+
// %c_4 is defined outside the loop.
%1 = stablehlo.compare LT, %iterArg, %c_4 : (tensor<i64>, tensor<i64>) -> tensor<i1>
16+
stablehlo.return %1 : tensor<i1>
17+
} do {
18+
%1 = stablehlo.add %c_5, %iterArg : tensor<i64>
19+
%2 = stablehlo.multiply %c_2, %1 : tensor<i64>
20+
%3 = stablehlo.add %2, %c_1 : tensor<i64>
21+
%4 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32>
22+
%5 = stablehlo.subtract %4, %c : tensor<i32>
23+
%6 = stablehlo.dynamic_slice %arg0, %5, sizes = [1] : (tensor<25xf32>, tensor<i32>) -> tensor<1xf32>
24+
%7 = stablehlo.multiply %6, %cst : tensor<1xf32>
25+
%8 = stablehlo.subtract %7, %cst_0 : tensor<1xf32>
26+
%9 = stablehlo.sine %8 : tensor<1xf32>
27+
// %c_2 and %c are each used twice in the body.
%10 = stablehlo.add %1, %c_2 : tensor<i64>
28+
%11 = stablehlo.convert %10 : (tensor<i64>) -> tensor<i32>
29+
%12 = stablehlo.subtract %11, %c : tensor<i32>
30+
%13 = stablehlo.dynamic_update_slice %iterArg_7, %9, %12 : (tensor<13xf32>, tensor<1xf32>, tensor<i32>) -> tensor<13xf32>
31+
stablehlo.return %1, %13 : tensor<i64>, tensor<13xf32>
32+
}
33+
// Only the accumulator result of the loop is returned.
return %0#1 : tensor<13xf32>
34+
}
35+
36+
// CHECK: func.func @main(%arg0: tensor<25xf32>) -> tensor<13xf32> {
37+
// CHECK-NEXT: %cst = stablehlo.constant dense<3.000000e+00> : tensor<1xf32>
38+
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
39+
// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<i32>
40+
// CHECK-NEXT: %c_1 = stablehlo.constant dense<5> : tensor<i64>
41+
// CHECK-NEXT: %c_2 = stablehlo.constant dense<2> : tensor<i64>
42+
// CHECK-NEXT: %c_3 = stablehlo.constant dense<0> : tensor<i64>
43+
// CHECK-NEXT: %c_4 = stablehlo.constant dense<10> : tensor<i64>
44+
// CHECK-NEXT: %c_5 = stablehlo.constant dense<1> : tensor<i64>
45+
// CHECK-NEXT: %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<13xf32>
46+
// CHECK-NEXT: %0:10 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %cst_6, %iterArg_8 = %c_4, %iterArg_9 = %c_5, %iterArg_10 = %c_2, %iterArg_11 = %c_1, %iterArg_12 = %c, %iterArg_13 = %arg0, %iterArg_14 = %cst, %iterArg_15 = %cst_0) : tensor<i64>, tensor<13xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i32>, tensor<25xf32>, tensor<1xf32>, tensor<1xf32>
47+
// CHECK-NEXT: cond {
48+
// CHECK-NEXT: %1 = stablehlo.compare LT, %iterArg, %iterArg_8 : (tensor<i64>, tensor<i64>) -> tensor<i1>
49+
// CHECK-NEXT: stablehlo.return %1 : tensor<i1>
50+
// CHECK-NEXT: } do {
51+
// CHECK-NEXT: %1 = stablehlo.add %iterArg_9, %iterArg : tensor<i64>
52+
// CHECK-NEXT: %2 = stablehlo.multiply %iterArg_10, %1 : tensor<i64>
53+
// CHECK-NEXT: %3 = stablehlo.add %2, %iterArg_11 : tensor<i64>
54+
// CHECK-NEXT: %4 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32>
55+
// CHECK-NEXT: %5 = stablehlo.subtract %4, %iterArg_12 : tensor<i32>
56+
// CHECK-NEXT: %6 = stablehlo.dynamic_slice %iterArg_13, %5, sizes = [1] : (tensor<25xf32>, tensor<i32>) -> tensor<1xf32>
57+
// CHECK-NEXT: %7 = stablehlo.multiply %6, %iterArg_14 : tensor<1xf32>
58+
// CHECK-NEXT: %8 = stablehlo.subtract %7, %iterArg_15 : tensor<1xf32>
59+
// CHECK-NEXT: %9 = stablehlo.sine %8 : tensor<1xf32>
60+
// CHECK-NEXT: %10 = stablehlo.add %1, %iterArg_10 : tensor<i64>
61+
// CHECK-NEXT: %11 = stablehlo.convert %10 : (tensor<i64>) -> tensor<i32>
62+
// CHECK-NEXT: %12 = stablehlo.subtract %11, %iterArg_12 : tensor<i32>
63+
// CHECK-NEXT: %13 = stablehlo.dynamic_update_slice %iterArg_7, %9, %12 : (tensor<13xf32>, tensor<1xf32>, tensor<i32>) -> tensor<13xf32>
64+
// CHECK-NEXT: stablehlo.return %1, %13, %iterArg_8, %iterArg_9, %iterArg_10, %iterArg_11, %iterArg_12, %iterArg_13, %iterArg_14, %iterArg_15 : tensor<i64>, tensor<13xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i32>, tensor<25xf32>, tensor<1xf32>, tensor<1xf32>
65+
// CHECK-NEXT: }
66+
// CHECK-NEXT: return %0#1 : tensor<13xf32>
67+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)