Skip to content

Commit 49f1f69

Browse files
avik-palPangoraw
andauthored
Update src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp
Co-authored-by: Paul Berg <naydex.mc+github@gmail.com>
1 parent fed98ee commit 49f1f69

File tree

1 file changed

+3
-19
lines changed

1 file changed

+3
-19
lines changed

src/enzyme_ad/jax/Passes/WhileLoopOutsideValuesAddToArgumentListPass.cpp

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,10 @@ struct SHLOWhileOpUpdateArgumentListPattern final
2727

2828
LogicalResult matchAndRewrite(stablehlo::WhileOp whileOp,
2929
PatternRewriter &rewriter) const override {
30-
// Collect values used inside cond/body that are defined outside the WhileOp
31-
SmallVector<Value, 4> extraValues;
32-
SmallPtrSet<Value, 8> seen;
30+
SetVector<Value> extraValuesSet;
31+
getUsedValuesDefinedAbove(whileOp->getRegions(), extraValuesSet);
32+
SmallVector<Value> extraValues = extraValuesSet.takeVector();
3333

34-
auto collectExternal = [&](Region &region) {
35-
region.walk([&](Operation *op) {
36-
for (OpOperand &operand : op->getOpOperands()) {
37-
Value v = operand.get();
38-
if (!v)
39-
continue;
40-
if (definedOutside(v, whileOp) && !seen.contains(v)) {
41-
seen.insert(v);
42-
extraValues.push_back(v);
43-
}
44-
}
45-
});
46-
};
47-
48-
collectExternal(whileOp.getCond());
49-
collectExternal(whileOp.getBody());
5034

5135
if (extraValues.empty())
5236
return failure();

0 commit comments

Comments
 (0)