Skip to content

Commit

Permalink
Revert "[FIRRTL] Simplify muxes when a particular bit value selects t…
Browse files Browse the repository at this point in the history
…he same value. (#6382)"

This reverts commit e49b521.
  • Loading branch information
nandor committed Nov 9, 2023
1 parent c7d5e31 commit aae068d
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 71 deletions.
62 changes: 0 additions & 62 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1692,68 +1692,6 @@ LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
}
}

// Eliminate unneeded bits in the index. These arise from duplicate values in
// the mux. This is done by slicing the mux into a mux tree.
// multibit_mux(index, {a, b, a, c}) -> multibit_mux(index[0],
// multibit_mux(index[1], a,a), multibit_mux(index[1]}, {b,c})) Search for
// identities in specific bit slices. This is robust to unknown width
// indexes.
for (uint64_t bit = 0,
lastbit = op.getIndex().getType().getBitWidthOrSentinel();
bit < lastbit; ++bit) {
for (int curval = 0; curval <= 1; ++curval) {
// We don't collect values here as the normal case is we don't find a
// match, so we don't want to move data around and do allocations.
Value v;
uint64_t count = 0;
for (uint64_t i = 0, e = op.getInputs().size(); i < e; ++i) {
if (((i >> bit) & 1) != curval)
continue;
++count;
if (!v)
v = op.getInputs()[i];
if (v != op.getInputs()[i]) {
v = {};
break;
}
}
if (!v || count == 1)
continue;
// Found match, collect varying side of the future mux
SmallVector<Value> nonSimple;
for (uint64_t i = 0, e = op.getInputs().size(); i < e; ++i) {
if (((i >> bit) & 1) != curval)
nonSimple.push_back(op.getInputs()[i]);
}
Value indBit = rewriter.createOrFold<BitsPrimOp>(op.getLoc(),
op.getIndex(), bit, bit);
Value indBitRemLow;
if (bit)
indBitRemLow = rewriter.createOrFold<BitsPrimOp>(
op.getLoc(), op.getIndex(), bit - 1, 0);
else
indBitRemLow = rewriter.create<ConstantOp>(
op.getLoc(), IntType::get(op.getContext(), false, 0),
APInt(0U, 0UL));
Value indBitRemHigh;
if (bit == lastbit - 1)
indBitRemHigh = rewriter.create<ConstantOp>(
op.getLoc(), IntType::get(op.getContext(), false, 0),
APInt(0U, 0UL));
else
indBitRemHigh = rewriter.createOrFold<BitsPrimOp>(
op.getLoc(), op.getIndex(), lastbit - 1, bit + 1);
Value indBitRem = rewriter.createOrFold<CatPrimOp>(
op.getLoc(), indBitRemHigh, indBitRemLow);
Value otherSide =
rewriter.create<MultibitMuxOp>(op.getLoc(), indBitRem, nonSimple);
Value high = curval ? otherSide : v;
Value low = curval ? v : otherSide;
replaceOpWithNewOpAndCopyName<MuxPrimOp>(rewriter, op, indBit, high, low);
return success();
}
}

// If the size is 2, canonicalize into a normal mux to introduce more folds.
if (op.getInputs().size() != 2)
return failure();
Expand Down
9 changes: 0 additions & 9 deletions test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3309,13 +3309,4 @@ firrtl.module @Whens(in %clock: !firrtl.clock, in %a: !firrtl.uint<1>, in %reset
}
}

// CHECK-LABEL: firrtl.module @UselessIndexBit
firrtl.module @UselessIndexBit(in %a: !firrtl.uint<3>, out %b: !firrtl.uint<4>, in %c: !firrtl.uint<4>, in %d: !firrtl.uint<4>, in %e: !firrtl.uint<4>) attributes {convention = #firrtl<convention scalarized>} {
%c0_ui4 = firrtl.constant 0 : !firrtl.uint<4> {name = "ttable_2"}
%0 = firrtl.multibit_mux %a, %c0_ui4, %c0_ui4, %c0_ui4, %c0_ui4, %c0_ui4, %c0_ui4, %d, %c : !firrtl.uint<3>, !firrtl.uint<4>
firrtl.strictconnect %b, %0 : !firrtl.uint<4>
// CHECK: firrtl.mux({{.*}}, %d, %c)
// CHECK: firrtl.mux({{.*}}, %c0_ui4, {{.*}})
}

}

0 comments on commit aae068d

Please sign in to comment.