Skip to content

Commit

Permalink
[Comb] Don't try to canonicalize muxes indefinitely (#8023)
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart authored Jan 1, 2025
1 parent afa566a commit 029dbd4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
22 changes: 14 additions & 8 deletions lib/Dialect/Comb/CombFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1921,7 +1921,7 @@ OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
return {};

// mux (c, b, b) -> b
if (getTrueValue() == getFalseValue())
if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
return getTrueValue();
if (auto tv = adaptor.getTrueValue())
if (tv == adaptor.getFalseValue())
Expand Down Expand Up @@ -2183,6 +2183,9 @@ static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
// `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
// `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
if (subMux == op)
return false;

Value otherValue;
Value subCond = subMux.getCond();

Expand Down Expand Up @@ -2514,8 +2517,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
}
}

if (auto falseMux =
dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
if (auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
falseMux && falseMux != op) {
// mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
if (op.getCond() == falseMux.getCond()) {
replaceOpWithNewOpAndCopyName<MuxOp>(
Expand All @@ -2529,8 +2532,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
return success();
}

if (auto trueMux =
dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
if (auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
trueMux && trueMux != op) {
// mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
if (op.getCond() == trueMux.getCond()) {
replaceOpWithNewOpAndCopyName<MuxOp>(
Expand All @@ -2548,7 +2551,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
trueMux.getTrueValue() == falseMux.getTrueValue()) {
trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
falseMux != op) {
auto subMux = rewriter.create<MuxOp>(
rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
Expand All @@ -2562,7 +2566,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
trueMux.getFalseValue() == falseMux.getFalseValue()) {
trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
falseMux != op) {
auto subMux = rewriter.create<MuxOp>(
rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
Expand All @@ -2577,7 +2582,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
trueMux && falseMux &&
trueMux.getTrueValue() == falseMux.getTrueValue() &&
trueMux.getFalseValue() == falseMux.getFalseValue()) {
trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
falseMux != op) {
auto subMux = rewriter.create<MuxOp>(
rewriter.getFusedLoc(
{op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
Expand Down
11 changes: 7 additions & 4 deletions test/Dialect/Comb/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ hw.module @muxConstantsFold(in %cond: i1, out o: i25) {
hw.module @muxCommon(in %cond: i1, in %cond2: i1,
in %arg0 : i32, in %arg1 : i32, in %arg2: i32, in %arg3: i32,
out o1: i32, out o2: i32, out o3: i32, out o4: i32,
out o5: i32, out orResult: i32, out o6: i32, out o7: i32) {
out o5: i32, out orResult: i32, out o6: i32, out o7: i32, out o8 : i1) {
%allones = hw.constant -1 : i32
%notArg0 = comb.xor %arg0, %allones : i32

Expand Down Expand Up @@ -1275,10 +1275,13 @@ hw.module @muxCommon(in %cond: i1, in %cond2: i1,
%1 = comb.mux %cond, %arg1, %arg0 : i32
%o7 = comb.mux %cond2, %1, %arg0 : i32

/// CHECK: [[O8:%.+]] = comb.mux [[O8]], [[O8]], [[O8]] : i1
%o8 = comb.mux %o8, %o8, %o8 : i1

// CHECK: hw.output [[O1]], [[O2]], [[O3]], [[O4]], [[O5]], [[ORRESULT]],
// CHECK: [[O6]], [[O7]]
hw.output %o1, %o2, %o3, %o4, %o5, %orResult, %o6, %o7
: i32, i32, i32, i32, i32, i32, i32, i32
// CHECK: [[O6]], [[O7]], [[O8]]
hw.output %o1, %o2, %o3, %o4, %o5, %orResult, %o6, %o7, %o8
: i32, i32, i32, i32, i32, i32, i32, i32, i1
}

// CHECK-LABEL: @flatten_multi_use_and
Expand Down

0 comments on commit 029dbd4

Please sign in to comment.