Commit 9a63a40

feat: support intermediate insertions/deletions while copying
test: missing results
feat: try more generalization (partial progress)
fix: uncomment
fix: add comment on how to handle remaining cases
fix: remove unwanted code
revert: unwanted changes
feat: handle transpose without explicit transpose op
chore: run fmt
chore: remove old comment
1 parent: bea6deb
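The headline item here, "handle transpose without explicit transpose op", lets the pattern fire even when the dynamic_slice and the dynamic_update_slice carry the loop induction variable on different dimensions: the rewrite now synthesizes the dimension mapping itself instead of requiring a stablehlo.transpose between the two ops. Below is a minimal standalone sketch of that mapping rule; it assumes the same invariants as the commit, and the helper name mapDsToDusDims plus the hard-coded driver values are illustrative only, not part of the patch.

    // Swapping the two induction dimensions is only a relabeling (an
    // "implicit transpose") when every dimension between them, endpoints
    // included, has slice size 1; otherwise data would actually have to move.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <optional>
    #include <utility>
    #include <vector>

    std::optional<std::vector<int32_t>>
    mapDsToDusDims(int32_t dsIndVarDim, int32_t dusIndVarDim,
                   const std::vector<int64_t> &sliceSizes) {
      std::vector<int32_t> map(sliceSizes.size());
      std::iota(map.begin(), map.end(), 0); // identity (iota) mapping
      if (dsIndVarDim == dusIndVarDim)
        return map; // same dimension: plain slice-and-update
      int32_t lo = std::min(dsIndVarDim, dusIndVarDim);
      int32_t hi = std::max(dsIndVarDim, dusIndVarDim);
      for (int32_t i = lo; i <= hi; i++)
        if (sliceSizes[i] != 1)
          return std::nullopt; // a non-degenerate dim in between: bail out
      std::swap(map[dsIndVarDim], map[dusIndVarDim]);
      return map;
    }

    int main() {
      // DS induction dim 2, DUS induction dim 0, all slice sizes 1:
      // the mapping {2, 1, 0} encodes the implicit transpose.
      if (auto m = mapDsToDusDims(2, 0, {1, 1, 1}))
        for (int32_t d : *m)
          printf("%d ", d); // prints: 2 1 0
      printf("\n");
    }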

File tree

2 files changed: +292 -36 lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 165 additions & 35 deletions
@@ -24299,40 +24299,18 @@ struct WhileIsCopySimplify
           blockArg.getArgNumber() != idx)
         continue;
 
-      // check if update is a DS
-      stablehlo::DynamicSliceOp sliceOp;
-      bool iotaMapping = false;
-      Value sliceOperand;
-      SmallVector<int64_t> mapDStoDUSDim;
-
-      if (sliceOp =
-              dusOp.getUpdate().getDefiningOp<stablehlo::DynamicSliceOp>()) {
-        // simple case where we slice and update
-        iotaMapping = true;
-        mapDStoDUSDim = SmallVector<int64_t>(sliceOp.getStartIndices().size());
-        std::iota(mapDStoDUSDim.begin(), mapDStoDUSDim.end(), 0);
-        sliceOperand = sliceOp.getOperand();
-        if (!isValueAccessibleFromBlock(domInfo, sliceOperand, parentBlock))
-          continue;
-      } else if (auto transposeOp =
-                     dusOp.getUpdate()
-                         .getDefiningOp<stablehlo::TransposeOp>()) {
-        // slice => transpose => update
-        auto tperm = transposeOp.getPermutation();
-        mapDStoDUSDim = SmallVector<int64_t>(tperm.size());
-        for (int i = 0; i < tperm.size(); i++)
-          mapDStoDUSDim[tperm[i]] = i;
-
-        sliceOp =
-            transposeOp.getOperand().getDefiningOp<stablehlo::DynamicSliceOp>();
-        if (!sliceOp)
-          continue;
-        sliceOperand = sliceOp.getOperand();
-        if (!isValueAccessibleFromBlock(domInfo, sliceOperand, parentBlock))
-          continue;
-      } else {
+      int32_t dusInductionVarDim =
+          getInductionVariableDimension(dusOp, inductionVarOffsets, whileOp);
+      if (dusInductionVarDim == -1)
+        continue;
+
+      auto [success, iotaMapping, sliceOp, sliceInductionVarDim, sliceOperand,
+            mapDStoDUSDim] =
+          getIndexMappingInfo(dusOp.getUpdate().getDefiningOp(), domInfo,
+                              parentBlock, rewriter, whileOp,
+                              inductionVarOffsets, dusOp, dusInductionVarDim);
+      if (!success)
         continue;
-      }
 
       bool indicesMatch = true, foundInductionVar = false;
       auto dsShape = cast<ShapedType>(sliceOp.getType()).getShape();
@@ -24345,7 +24323,7 @@ struct WhileIsCopySimplify
       SmallVector<IndexInfo> dusStartIndices(dusOp.getStartIndices().size());
 
       for (size_t i = 0; i < sliceOp.getStartIndices().size(); i++) {
-        int j = mapDStoDUSDim[i];
+        int32_t j = mapDStoDUSDim[i];
 
         auto dsStartIndex = sliceOp.getStartIndices()[i];
         auto dusStartIndex = dusOp.getStartIndices()[j];
@@ -24409,6 +24387,7 @@ struct WhileIsCopySimplify
           rewriter.getDenseI64ArrayAttr(copyInfo.sliceSizes));
 
       auto dusUpdate = sliceOp.getResult();
+
       if (!copyInfo.iotaMapping) {
         SmallVector<int64_t> permutation(copyInfo.mapDStoDUSDim.size());
         for (int i = 0; i < copyInfo.mapDStoDUSDim.size(); i++)
@@ -24444,10 +24423,161 @@ struct WhileIsCopySimplify
     SmallVector<int64_t> sliceSizes;
     SmallVector<IndexInfo> dusStartIndices;
     bool iotaMapping;
-    SmallVector<int64_t> mapDStoDUSDim;
+    SmallVector<int32_t> mapDStoDUSDim;
     unsigned blockArgIdx;
   };
 
+  struct IndexMappingInfo {
+    bool success;
+    bool iotaMapping;
+    stablehlo::DynamicSliceOp sliceOp;
+    int32_t sliceInductionVarDim;
+    Value sliceOperand;
+    SmallVector<int32_t> mapDStoDUSDim;
+  };
+
+  IndexMappingInfo unsupportedIndexMappingInfo() const {
+    return IndexMappingInfo{false, false, nullptr, -1, nullptr, {}};
+  }
+
+  IndexMappingInfo
+  getIndexMappingInfo(Operation *op, DominanceInfo &domInfo, Block *parentBlock,
+                      PatternRewriter &rewriter, stablehlo::WhileOp whileOp,
+                      DenseMap<Value, APInt> &inductionVarOffsets,
+                      stablehlo::DynamicUpdateSliceOp dusOp,
+                      int32_t dusInductionVarDim) const {
+    if (auto sliceOp = dyn_cast<stablehlo::DynamicSliceOp>(op)) {
+      // base case: we have reached the dynamic slice
+      Value sliceOperand = sliceOp.getOperand();
+
+      if (!isValueAccessibleFromBlock(domInfo, sliceOperand, parentBlock))
+        return unsupportedIndexMappingInfo();
+
+      auto inductionVarDim =
+          getInductionVariableDimension(sliceOp, inductionVarOffsets, whileOp);
+      if (inductionVarDim == -1)
+        return unsupportedIndexMappingInfo();
+
+      auto sliceSizes = sliceOp.getSliceSizes();
+
+      SmallVector<int32_t> mapDStoDUSDim(sliceOp.getStartIndices().size(), -1);
+      bool isIotaMapping = false;
+      if (inductionVarDim == dusInductionVarDim) {
+        isIotaMapping = true;
+        std::iota(mapDStoDUSDim.begin(), mapDStoDUSDim.end(), 0);
+      } else {
+        auto minVal = std::min(dusInductionVarDim, inductionVarDim);
+        auto maxVal = std::max(dusInductionVarDim, inductionVarDim);
+
+        for (int32_t i = 0; i < minVal; i++)
+          mapDStoDUSDim[i] = i;
+
+        // every dimension between the two induction dims must have size 1
+        bool allOnes = true;
+        for (int32_t i = minVal; i <= maxVal; i++) {
+          if (sliceSizes[i] != 1) {
+            allOnes = false;
+            break;
+          }
+          mapDStoDUSDim[i] = i;
+        }
+
+        mapDStoDUSDim[dusInductionVarDim] = inductionVarDim;
+        mapDStoDUSDim[inductionVarDim] = dusInductionVarDim;
+
+        if (!allOnes)
+          return unsupportedIndexMappingInfo();
+
+        for (int32_t i = maxVal + 1; i < sliceOp.getStartIndices().size(); i++)
+          mapDStoDUSDim[i] = i;
+      }
+
+      return IndexMappingInfo{true,         isIotaMapping,
+                              sliceOp,      inductionVarDim,
+                              sliceOperand, mapDStoDUSDim};
+    }
+
+    if (auto transposeOp = dyn_cast<stablehlo::TransposeOp>(op)) {
+      // recursive case: apply transpose on the mapped indices
+      auto tperm = transposeOp.getPermutation();
+
+      int32_t mappedDusInductionVarDim = -1;
+      for (int32_t i = 0; i < tperm.size(); i++) {
+        if (tperm[i] == dusInductionVarDim) {
+          mappedDusInductionVarDim = i;
+          break;
+        }
+      }
+
+      auto prevInfo = getIndexMappingInfo(
+          transposeOp.getOperand().getDefiningOp(), domInfo, parentBlock,
+          rewriter, whileOp, inductionVarOffsets, dusOp,
+          mappedDusInductionVarDim);
+      if (!prevInfo.success)
+        return prevInfo;
+
+      // apply transpose on the mapped indices
+      SmallVector<int32_t> newMapping(tperm.size());
+      int32_t sliceInductionVarDim = -1;
+      for (int32_t i = 0; i < tperm.size(); i++) {
+        if (tperm[i] == prevInfo.sliceInductionVarDim) {
+          sliceInductionVarDim = i;
+        }
+        newMapping[tperm[i]] = prevInfo.mapDStoDUSDim[i];
+      }
+
+      return IndexMappingInfo{true,
+                              false,
+                              prevInfo.sliceOp,
+                              sliceInductionVarDim,
+                              prevInfo.sliceOperand,
+                              newMapping};
+    }
+
+    return unsupportedIndexMappingInfo();
+  }
+
+  int32_t
+  getInductionVariableDimension(stablehlo::DynamicSliceOp sliceOp,
+                                DenseMap<Value, APInt> &inductionVarOffsets,
+                                stablehlo::WhileOp whileOp) const {
+    int32_t inductionVarDimension = -1;
+
+    for (size_t i = 0; i < sliceOp.getStartIndices().size(); i++) {
+      auto dsStartIndex = sliceOp.getStartIndices()[i];
+
+      if (!isConstantAcrossLoopIterations(dsStartIndex, whileOp)) {
+        if (inductionVarDimension != -1 || // multiple indices with induction var
+            !inductionVarOffsets.contains(dsStartIndex))
+          return -1;
+
+        inductionVarDimension = i;
+      }
+    }
+
+    return inductionVarDimension;
+  }
+
+  int32_t
+  getInductionVariableDimension(stablehlo::DynamicUpdateSliceOp dusOp,
+                                DenseMap<Value, APInt> &inductionVarOffsets,
+                                stablehlo::WhileOp whileOp) const {
+    int32_t inductionVarDimension = -1;
+
+    for (size_t i = 0; i < dusOp.getStartIndices().size(); i++) {
+      auto dusStartIndex = dusOp.getStartIndices()[i];
+
+      if (!isConstantAcrossLoopIterations(dusStartIndex, whileOp)) {
+        if (inductionVarDimension != -1 || // multiple indices with induction var
+            !inductionVarOffsets.contains(dusStartIndex))
+          return -1;
+
+        inductionVarDimension = i;
+      }
+    }
+
+    return inductionVarDimension;
+  }
+
   SmallVector<Value> indexInfoToValues(Location loc,
                                        ArrayRef<IndexInfo> indices,
                                        PatternRewriter &rewriter) const {
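The recursive transpose case above composes an already-computed DS-to-DUS dimension mapping with a transpose permutation (newMapping[tperm[i]] = prev[i]). A small illustrative sketch of just that composition step follows; the helper name composeWithTranspose is assumed for illustration and is not from the patch.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Mirrors the loop in the new getIndexMappingInfo: the mapping entry
    // computed for position i of the transpose operand is re-filed under
    // tperm[i] in the composed mapping.
    std::vector<int32_t>
    composeWithTranspose(const std::vector<int64_t> &tperm,
                         const std::vector<int32_t> &prevMap) {
      std::vector<int32_t> newMap(tperm.size());
      for (std::size_t i = 0; i < tperm.size(); i++)
        newMap[tperm[i]] = prevMap[i];
      return newMap;
    }

For example, composeWithTranspose({1, 0, 2}, {0, 1, 2}) yields {1, 0, 2}: an identity mapping followed by a dim-0/dim-1 transpose swaps those two dimensions, matching what the deleted inline transpose handling computed with mapDStoDUSDim[tperm[i]] = i.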

test/lit_tests/while_is_copy.mlir

Lines changed: 127 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reshape_dynamic_slice(1);reshape_licm(1);transpose_dynamic_slice;transpose_licm(1)" --transform-interpreter --enzyme-hlo-remove-transform --enzyme-hlo-opt %s | FileCheck %s
 
 module {
   func.func @main(%arg0: tensor<10xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<10xf32>) -> tensor<10xf32> {
@@ -165,3 +165,129 @@ module {
 // CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [2, 0, 1] : (tensor<5x10x4xf32>) -> tensor<4x5x10xf32>
 // CHECK-NEXT: return %2 : tensor<4x5x10xf32>
 // CHECK-NEXT: }
+
+module {
+  func.func @main(%arg0: tensor<10xf64>) -> tensor<10xf64> {
+    %c = stablehlo.constant dense<1> : tensor<i32>
+    %c_0 = stablehlo.constant dense<0> : tensor<i64>
+    %c_1 = stablehlo.constant dense<10> : tensor<i64>
+    %c_2 = stablehlo.constant dense<1> : tensor<i64>
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
+    %0 = stablehlo.reshape %arg0 : (tensor<10xf64>) -> tensor<10x1xf64>
+    %1:2 = stablehlo.while(%iterArg = %c_0, %iterArg_3 = %cst) : tensor<i64>, tensor<10xf64>
+    cond {
+      %2 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %2 : tensor<i1>
+    } do {
+      %2 = stablehlo.add %c_2, %iterArg : tensor<i64>
+      %3 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32>
+      %4 = stablehlo.subtract %3, %c : tensor<i32>
+      %5 = stablehlo.dynamic_slice %0, %iterArg, %c_0, sizes = [1, 1] : (tensor<10x1xf64>, tensor<i64>, tensor<i64>) -> tensor<1x1xf64>
+      %6 = stablehlo.reshape %5 : (tensor<1x1xf64>) -> tensor<1xf64>
+      %7 = stablehlo.dynamic_update_slice %iterArg_3, %6, %4 : (tensor<10xf64>, tensor<1xf64>, tensor<i32>) -> tensor<10xf64>
+      stablehlo.return %2, %7 : tensor<i64>, tensor<10xf64>
+    }
+    return %1#1 : tensor<10xf64>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<10xf64>) -> tensor<10xf64> {
+// CHECK-NEXT: return %arg0 : tensor<10xf64>
+// CHECK-NEXT: }
+
+module {
+  func.func @main(%arg0: tensor<5x4x3xf32>) -> tensor<4x5x3xf32> {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x5x3xf32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<0> : tensor<i64>
+    %c_2 = stablehlo.constant dense<4> : tensor<i64>
+    %c_3 = stablehlo.constant dense<1> : tensor<i64>
+    %0 = stablehlo.broadcast_in_dim %arg0, dims = [2, 0, 3] : (tensor<5x4x3xf32>) -> tensor<4x1x5x3xf32>
+    %1:2 = stablehlo.while(%iterArg = %c_1, %iterArg_4 = %cst) : tensor<i64>, tensor<4x5x3xf32>
+    cond {
+      %2 = stablehlo.compare LT, %iterArg, %c_2 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %2 : tensor<i1>
+    } do {
+      %2 = stablehlo.add %c_3, %iterArg : tensor<i64>
+      %3 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32>
+      %4 = stablehlo.subtract %3, %c_0 : tensor<i32>
+      %5 = stablehlo.dynamic_slice %0, %iterArg, %c_1, %c_1, %c_1, sizes = [1, 1, 5, 3] : (tensor<4x1x5x3xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x5x3xf32>
+      %6 = stablehlo.reshape %5 : (tensor<1x1x5x3xf32>) -> tensor<1x5x3xf32>
+      %7 = stablehlo.dynamic_update_slice %iterArg_4, %6, %4, %c, %c : (tensor<4x5x3xf32>, tensor<1x5x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x5x3xf32>
+      stablehlo.return %2, %7 : tensor<i64>, tensor<4x5x3xf32>
+    }
+    return %1#1 : tensor<4x5x3xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<5x4x3xf32>) -> tensor<4x5x3xf32> {
+// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0, 2] : (tensor<5x4x3xf32>) -> tensor<4x5x3xf32>
+// CHECK-NEXT: return %0 : tensor<4x5x3xf32>
+// CHECK-NEXT: }
+
+module {
+  func.func @main(%arg0: tensor<5x4x3xf32>, %arg1: tensor<3x1x4x1x5xf32>) -> tensor<5x4x3xf32> {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<0> : tensor<i64>
+    %c_2 = stablehlo.constant dense<1> : tensor<i64>
+    %c_3 = stablehlo.constant dense<4> : tensor<i64>
+    %0 = stablehlo.transpose %arg1, dims = [4, 1, 2, 3, 0] : (tensor<3x1x4x1x5xf32>) -> tensor<5x1x4x1x3xf32>
+    %1:2 = stablehlo.while(%iterArg = %c_1, %iterArg_3 = %arg0) : tensor<i64>, tensor<5x4x3xf32>
+    cond {
+      %2 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %2 : tensor<i1>
+    } do {
+      %2 = stablehlo.add %c_2, %iterArg : tensor<i64>
+      %3 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32>
+      %4 = stablehlo.subtract %3, %c_0 : tensor<i32>
+      %5 = stablehlo.dynamic_slice %0, %c, %c, %4, %c, %c, sizes = [5, 1, 1, 1, 3] : (tensor<5x1x4x1x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x1x1x1x3xf32>
+      %6 = stablehlo.reshape %5 : (tensor<5x1x1x1x3xf32>) -> tensor<5x1x3xf32>
+      %7 = stablehlo.dynamic_update_slice %iterArg_3, %6, %c, %4, %c : (tensor<5x4x3xf32>, tensor<5x1x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x4x3xf32>
+      stablehlo.return %2, %7 : tensor<i64>, tensor<5x4x3xf32>
+    }
+    return %1#1 : tensor<5x4x3xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<5x4x3xf32>, %arg1: tensor<3x1x4x1x5xf32>) -> tensor<5x4x3xf32> {
+// CHECK-NEXT: %0 = stablehlo.transpose %arg1, dims = [4, 1, 2, 3, 0] : (tensor<3x1x4x1x5xf32>) -> tensor<5x1x4x1x3xf32>
+// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<5x1x4x1x3xf32>) -> tensor<5x4x3xf32>
+// CHECK-NEXT: return %1 : tensor<5x4x3xf32>
+// CHECK-NEXT: }
+
+module {
+  func.func @main(%arg0: tensor<1x3x4x1x5xf32>, %arg1: tensor<5x4x3xf32>) -> tensor<1x3x4x1x5xf32> {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<0> : tensor<i64>
+    %c_2 = stablehlo.constant dense<1> : tensor<i64>
+    %c_3 = stablehlo.constant dense<3> : tensor<i64>
+    %0:2 = stablehlo.while(%iterArg = %c_1, %iterArg_4 = %arg0) : tensor<i64>, tensor<1x3x4x1x5xf32>
+    cond {
+      %1 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %1 : tensor<i1>
+    } do {
+      %1 = stablehlo.add %c_2, %iterArg : tensor<i64>
+      %2 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32>
+      %3 = stablehlo.subtract %2, %c_0 : tensor<i32>
+      %4 = stablehlo.dynamic_slice %arg1, %c, %3, %c, sizes = [5, 1, 3] : (tensor<5x4x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x1x3xf32>
+      %5 = stablehlo.reshape %4 : (tensor<5x1x3xf32>) -> tensor<5x1x1x3x1xf32>
+      %6 = stablehlo.transpose %5, dims = [4, 3, 2, 1, 0] : (tensor<5x1x1x3x1xf32>) -> tensor<1x3x1x1x5xf32>
+      %7 = stablehlo.dynamic_update_slice %iterArg_4, %6, %c, %c, %3, %c, %c : (tensor<1x3x4x1x5xf32>, tensor<1x3x1x1x5xf32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x3x4x1x5xf32>
+      stablehlo.return %1, %7 : tensor<i64>, tensor<1x3x4x1x5xf32>
+    }
+    return %0#1 : tensor<1x3x4x1x5xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<1x3x4x1x5xf32>, %arg1: tensor<5x4x3xf32>) -> tensor<1x3x4x1x5xf32> {
+// CHECK-NEXT: %0 = stablehlo.slice %arg1 [0:5, 0:3, 0:3] : (tensor<5x4x3xf32>) -> tensor<5x3x3xf32>
+// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<5x3x3xf32>) -> tensor<5x3x1x3x1xf32>
+// CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [4, 3, 2, 1, 0] : (tensor<5x3x1x3x1xf32>) -> tensor<1x3x1x3x5xf32>
+// CHECK-NEXT: %3 = stablehlo.reshape %2 : (tensor<1x3x1x3x5xf32>) -> tensor<1x3x3x1x5xf32>
+// CHECK-NEXT: %4 = stablehlo.slice %arg0 [0:1, 0:3, 3:4, 0:1, 0:5] : (tensor<1x3x4x1x5xf32>) -> tensor<1x3x1x1x5xf32>
+// CHECK-NEXT: %5 = stablehlo.concatenate %3, %4, dim = 2 : (tensor<1x3x3x1x5xf32>, tensor<1x3x1x1x5xf32>) -> tensor<1x3x4x1x5xf32>
+// CHECK-NEXT: return %5 : tensor<1x3x4x1x5xf32>
+// CHECK-NEXT: }
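Worth noting on the last test: the rewrite cannot collapse to a pure copy because the loop writes only 3 of the 4 entries along dimension 2, so the CHECK lines keep slice [3:4] of %arg0 and concatenate it back. A quick sanity check of that index arithmetic (illustrative only, not part of the test suite):

    #include <cstdio>
    #include <set>

    int main() {
      std::set<int> written;
      // The loop runs while iterArg < 3, and the DUS start index along dim 2
      // is (iterArg + 1) - 1 = iterArg, so indices 0..2 get overwritten.
      for (int iterArg = 0; iterArg < 3; iterArg++)
        written.insert((iterArg + 1) - 1);
      // Dimension 2 of %arg0 has extent 4; index 3 keeps its original value.
      for (int i = 0; i < 4; i++)
        printf("index %d: %s\n", i, written.count(i) ? "copied" : "original");
    }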
