Skip to content

Commit 5d8cf69

Browse files
authored
[CombFolds] Don't canonicalize extract(shl(1, x)) if shift is multiply used (#7527)
There is a canonicalization for `exract(c, shl(1, x))` to `x == c` but this canonicalization introduces a bunch of comparision to constants. This harms PPA when bitwidth is large (e.g. 16 bit shift introduce 2^16 icmp op). To prevent such regressions this commit imposes restriction regarding the number of uses for shift.
1 parent 29b1c1c commit 5d8cf69

File tree

2 files changed

+30
-18
lines changed

2 files changed

+30
-18
lines changed

lib/Dialect/Comb/CombFolds.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -693,17 +693,20 @@ LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
693693
// `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
694694
// extracted.
695695
if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
696-
if (auto shlOp = dyn_cast<ShlOp>(inputOp))
697-
if (auto lhsCst = shlOp.getOperand(0).getDefiningOp<hw::ConstantOp>())
698-
if (lhsCst.getValue().isOne()) {
699-
auto newCst = rewriter.create<hw::ConstantOp>(
700-
shlOp.getLoc(),
701-
APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
702-
replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
703-
shlOp->getOperand(1), newCst,
704-
false);
705-
return success();
706-
}
696+
if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
697+
// Don't canonicalize if the shift is multiply used.
698+
if (shlOp->hasOneUse())
699+
if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
700+
if (lhsCst.getValue().isOne()) {
701+
auto newCst = rewriter.create<hw::ConstantOp>(
702+
shlOp.getLoc(),
703+
APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
704+
replaceOpWithNewOpAndCopyName<ICmpOp>(
705+
rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
706+
false);
707+
return success();
708+
}
709+
}
707710

708711
return failure();
709712
}

test/Dialect/Comb/canonicalization.mlir

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,17 +1221,26 @@ hw.module @test1560(in %value: i38, out a: i1) {
12211221
}
12221222

12231223
// CHECK-LABEL: hw.module @extractShift
1224-
hw.module @extractShift(in %arg0 : i4, out o1 : i1, out o2: i1) {
1224+
hw.module @extractShift(in %arg0 : i4, out o1 : i1, out o2: i1, out o3: i1, out o4: i1) {
12251225
%c1 = hw.constant 1: i4
12261226
%0 = comb.shl %c1, %arg0 : i4
1227+
%1 = comb.shl %c1, %arg0 : i4
1228+
%2 = comb.shl %c1, %arg0 : i4
12271229

1228-
// CHECK: %0 = comb.icmp eq %arg0, %c0_i4 : i4
1229-
%1 = comb.extract %0 from 0 : (i4) -> i1
1230+
// CHECK: %[[O1:.+]] = comb.icmp eq %arg0, %c0_i4 : i4
1231+
%3 = comb.extract %0 from 0 : (i4) -> i1
12301232

1231-
// CHECK: %1 = comb.icmp eq %arg0, %c2_i4 : i4
1232-
%2 = comb.extract %0 from 2 : (i4) -> i1
1233-
// CHECK: hw.output %0, %1
1234-
hw.output %1, %2: i1, i1
1233+
// CHECK: %[[O2:.+]] = comb.icmp eq %arg0, %c2_i4 : i4
1234+
%4 = comb.extract %1 from 2 : (i4) -> i1
1235+
1236+
// CHECK: %[[O3:.+]] = comb.extract
1237+
%5 = comb.extract %2 from 2 : (i4) -> i1
1238+
1239+
// CHECK: %[[O4:.+]] = comb.extract
1240+
%6 = comb.extract %2 from 2 : (i4) -> i1
1241+
1242+
// CHECK: hw.output %[[O1]], %[[O2]], %[[O3]], %[[O4]]
1243+
hw.output %3, %4, %5, %6: i1, i1, i1, i1
12351244
}
12361245

12371246
// CHECK-LABEL: hw.module @moduloZeroDividend

0 commit comments

Comments
 (0)