Skip to content

Commit

Permalink
[FIRRTL] FoldUnusedBits: minor cleanup (#7914)
Browse files Browse the repository at this point in the history
* [FIRRTL] FoldUnusedBits: Early break when all bits are read

* [FIRRTL] FoldUnusedBits: Opportunistically fold generated ops

Once the memory has been replaced with the compressed memory, a lot of the
bitselect ops reading from the old memory will be selecting the entire memory.
Use createOrFold to eagerly clean up these "whole range bitselect" ops.
  • Loading branch information
rwy7 authored Nov 28, 2024
1 parent e2ebea5 commit b16f1ad
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2679,7 +2679,7 @@ struct FoldUnusedBits : public mlir::RewritePattern {
// ports whose data/rdata field is used only through bit select ops. The
// bit selects are then used to build a bit-mask. The ops are collected.
SmallVector<BitsPrimOp> readOps;
auto findReadUsers = [&](Value port, StringRef field) {
auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
auto portTy = type_cast<BundleType>(port.getType());
auto fieldIndex = portTy.getElementIndex(field);
assert(fieldIndex && "missing data port");
Expand All @@ -2691,16 +2691,19 @@ struct FoldUnusedBits : public mlir::RewritePattern {

for (auto *user : op->getUsers()) {
auto bits = dyn_cast<BitsPrimOp>(user);
if (!bits) {
usedBits.set();
continue;
}
if (!bits)
return failure();

usedBits.set(bits.getLo(), bits.getHi() + 1);
if (usedBits.all())
return failure();

mapping[bits.getLo()] = 0;
readOps.push_back(bits);
}
}

return success();
};

// Finds the users of write ports. This expects all the data/wdata fields
Expand Down Expand Up @@ -2741,20 +2744,21 @@ struct FoldUnusedBits : public mlir::RewritePattern {
return failure();
continue;
case MemOp::PortKind::Read:
findReadUsers(port, "data");
if (failed(findReadUsers(port, "data")))
return failure();
continue;
case MemOp::PortKind::ReadWrite:
if (failed(findWriteUsers(port, "wdata")))
return failure();
findReadUsers(port, "rdata");
if (failed(findReadUsers(port, "rdata")))
return failure();
continue;
}
llvm_unreachable("unknown port kind");
}

// Perform the transformation is there are some bits missing. Unused
// memories are handled in a different canonicalizer.
if (usedBits.all() || usedBits.none())
// Unused memories are handled in a different canonicalizer.
if (usedBits.none())
return failure();

// Build a mapping of existing indices to compacted ones.
Expand Down Expand Up @@ -2828,9 +2832,13 @@ struct FoldUnusedBits : public mlir::RewritePattern {
rewriter.setInsertionPointAfter(readOp);
auto it = mapping.find(readOp.getLo());
assert(it != mapping.end() && "bit op mapping not found");
rewriter.replaceOpWithNewOp<BitsPrimOp>(
readOp, readOp.getInput(),
// Create a new bit selection from the compressed memory. The new op may
// be folded if we are selecting the entire compressed memory.
auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
readOp.getLoc(), readOp.getInput(),
readOp.getHi() - readOp.getLo() + it->second, it->second);
rewriter.replaceAllUsesWith(readOp, newReadValue);
rewriter.eraseOp(readOp);
}

// Rewrite the writes into a concatenation of slices.
Expand All @@ -2840,11 +2848,11 @@ struct FoldUnusedBits : public mlir::RewritePattern {

Value catOfSlices;
for (auto &[start, end] : ranges) {
Value slice =
rewriter.create<BitsPrimOp>(writeOp.getLoc(), source, end, start);
Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
source, end, start);
if (catOfSlices) {
catOfSlices =
rewriter.create<CatPrimOp>(writeOp.getLoc(), slice, catOfSlices);
catOfSlices = rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(),
slice, catOfSlices);
} else {
catOfSlices = slice;
}
Expand Down

0 comments on commit b16f1ad

Please sign in to comment.