Skip to content

Commit fa71ca2

Browse files
authored
[Comb] Disallow canonicalization across MLIR blocks (#6235)
This is probably the most conservative implementation of this, but should suffice for the current requirements to this change. For any comb operation that has a canonicalizer, guard canonicalization on whether any operands are define outside of the current block. This still allows for (constant) folding, which (to me) still should be a safe thing to do across blocks - if that was not the case, then that would be a strange abstraction (i'd expect an op to be `IsolatedFromAbove` if it wants to prevent constant folding).
1 parent 6a4849d commit fa71ca2

File tree

4 files changed

+142
-5
lines changed

4 files changed

+142
-5
lines changed

lib/Dialect/Comb/CombFolds.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,21 @@ using namespace circt;
2121
using namespace comb;
2222
using namespace matchers;
2323

24+
/// In comb, we assume no knowledge of the semantics of cross-block dataflow. As
25+
/// such, cross-block dataflow is interpreted as a canonicalization barrier.
26+
/// This is a conservative approach which:
27+
/// 1. still allows for efficient canonicalization for the common CIRCT usecase
28+
/// of comb (comb logic nested inside single-block hw.module's)
29+
/// 2. allows comb operations to be used in non-HW container ops - that may use
30+
/// MLIR blocks and regions to represent various forms of hierarchical
31+
/// abstractions, thus allowing comb to compose with other dialects.
32+
static bool hasOperandsOutsideOfBlock(Operation *op) {
33+
Block *thisBlock = op->getBlock();
34+
return llvm::any_of(op->getOperands(), [&](Value operand) {
35+
return operand.getParentBlock() != thisBlock;
36+
});
37+
}
38+
2439
/// Create a new instance of a generic operation that only has value operands,
2540
/// and has a single result value whose type matches the first operand.
2641
///
@@ -242,6 +257,9 @@ static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
242257
//===----------------------------------------------------------------------===//
243258

244259
OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
260+
if (hasOperandsOutsideOfBlock(getOperation()))
261+
return {};
262+
245263
// Replicate one time -> noop.
246264
if (getType().cast<IntegerType>().getWidth() ==
247265
getInput().getType().getIntOrFloatBitWidth())
@@ -269,6 +287,9 @@ OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
269287
}
270288

271289
OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
290+
if (hasOperandsOutsideOfBlock(getOperation()))
291+
return {};
292+
272293
// Constant fold.
273294
if (auto input = adaptor.getInput().dyn_cast_or_null<IntegerAttr>())
274295
return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
@@ -295,6 +316,9 @@ static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
295316
}
296317

297318
OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
319+
if (hasOperandsOutsideOfBlock(getOperation()))
320+
return {};
321+
298322
if (auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
299323
unsigned shift = rhs.getValue().getZExtValue();
300324
unsigned width = getType().getIntOrFloatBitWidth();
@@ -308,6 +332,9 @@ OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
308332
}
309333

310334
LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
335+
if (hasOperandsOutsideOfBlock(&*op))
336+
return failure();
337+
311338
// ShlOp(x, cst) -> Concat(Extract(x), zeros)
312339
APInt value;
313340
if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
@@ -332,6 +359,9 @@ LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
332359
}
333360

334361
OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
362+
if (hasOperandsOutsideOfBlock(getOperation()))
363+
return {};
364+
335365
if (auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
336366
unsigned shift = rhs.getValue().getZExtValue();
337367
if (shift == 0)
@@ -345,6 +375,9 @@ OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
345375
}
346376

347377
LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
378+
if (hasOperandsOutsideOfBlock(&*op))
379+
return failure();
380+
348381
// ShrUOp(x, cst) -> Concat(zeros, Extract(x))
349382
APInt value;
350383
if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
@@ -369,6 +402,9 @@ LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
369402
}
370403

371404
OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
405+
if (hasOperandsOutsideOfBlock(getOperation()))
406+
return {};
407+
372408
if (auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
373409
if (rhs.getValue().getZExtValue() == 0)
374410
return getOperand(0);
@@ -377,6 +413,9 @@ OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
377413
}
378414

379415
LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
416+
if (hasOperandsOutsideOfBlock(&*op))
417+
return failure();
418+
380419
// ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
381420
APInt value;
382421
if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
@@ -406,6 +445,9 @@ LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
406445
//===----------------------------------------------------------------------===//
407446

408447
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
448+
if (hasOperandsOutsideOfBlock(getOperation()))
449+
return {};
450+
409451
// If we are extracting the entire input, then return it.
410452
if (getInput().getType() == getType())
411453
return getInput();
@@ -534,6 +576,9 @@ static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
534576
}
535577

536578
LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
579+
if (hasOperandsOutsideOfBlock(&*op))
580+
return failure();
581+
537582
auto *inputOp = op.getInput().getDefiningOp();
538583

539584
// This turns out to be incredibly expensive. Disable until performance is
@@ -744,6 +789,9 @@ static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
744789
}
745790

746791
OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
792+
if (hasOperandsOutsideOfBlock(getOperation()))
793+
return {};
794+
747795
APInt value = APInt::getAllOnes(getType().cast<IntegerType>().getWidth());
748796

749797
auto inputs = adaptor.getInputs();
@@ -841,6 +889,9 @@ static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
841889
}
842890

843891
LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
892+
if (hasOperandsOutsideOfBlock(&*op))
893+
return failure();
894+
844895
auto inputs = op.getInputs();
845896
auto size = inputs.size();
846897
assert(size > 1 && "expected 2 or more operands, `fold` should handle this");
@@ -974,6 +1025,9 @@ LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
9741025
}
9751026

9761027
OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1028+
if (hasOperandsOutsideOfBlock(getOperation()))
1029+
return {};
1030+
9771031
auto value = APInt::getZero(getType().cast<IntegerType>().getWidth());
9781032
auto inputs = adaptor.getInputs();
9791033
// or(x, 10, 01) -> 11
@@ -1113,6 +1167,9 @@ static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
11131167
}
11141168

11151169
LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1170+
if (hasOperandsOutsideOfBlock(&*op))
1171+
return failure();
1172+
11161173
auto inputs = op.getInputs();
11171174
auto size = inputs.size();
11181175
assert(size > 1 && "expected 2 or more operands");
@@ -1212,6 +1269,9 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
12121269
}
12131270

12141271
OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1272+
if (hasOperandsOutsideOfBlock(getOperation()))
1273+
return {};
1274+
12151275
auto size = getInputs().size();
12161276
auto inputs = adaptor.getInputs();
12171277

@@ -1264,6 +1324,9 @@ static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
12641324
}
12651325

12661326
LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1327+
if (hasOperandsOutsideOfBlock(&*op))
1328+
return failure();
1329+
12671330
auto inputs = op.getInputs();
12681331
auto size = inputs.size();
12691332
assert(size > 1 && "expected 2 or more operands");
@@ -1339,6 +1402,9 @@ LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
13391402
}
13401403

13411404
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1405+
if (hasOperandsOutsideOfBlock(getOperation()))
1406+
return {};
1407+
13421408
// sub(x - x) -> 0
13431409
if (getRhs() == getLhs())
13441410
return getIntAttr(
@@ -1369,6 +1435,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
13691435
}
13701436

13711437
LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1438+
if (hasOperandsOutsideOfBlock(&*op))
1439+
return failure();
1440+
13721441
// sub(x, cst) -> add(x, -cst)
13731442
APInt value;
13741443
if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
@@ -1386,6 +1455,9 @@ LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
13861455
}
13871456

13881457
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1458+
if (hasOperandsOutsideOfBlock(getOperation()))
1459+
return {};
1460+
13891461
auto size = getInputs().size();
13901462

13911463
// add(x) -> x -- noop
@@ -1397,6 +1469,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
13971469
}
13981470

13991471
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
1472+
if (hasOperandsOutsideOfBlock(&*op))
1473+
return failure();
1474+
14001475
auto inputs = op.getInputs();
14011476
auto size = inputs.size();
14021477
assert(size > 1 && "expected 2 or more operands");
@@ -1497,6 +1572,9 @@ LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
14971572
}
14981573

14991574
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1575+
if (hasOperandsOutsideOfBlock(getOperation()))
1576+
return {};
1577+
15001578
auto size = getInputs().size();
15011579
auto inputs = adaptor.getInputs();
15021580

@@ -1521,6 +1599,9 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
15211599
}
15221600

15231601
LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1602+
if (hasOperandsOutsideOfBlock(&*op))
1603+
return failure();
1604+
15241605
auto inputs = op.getInputs();
15251606
auto size = inputs.size();
15261607
assert(size > 1 && "expected 2 or more operands");
@@ -1585,10 +1666,16 @@ static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
15851666
}
15861667

15871668
OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1669+
if (hasOperandsOutsideOfBlock(getOperation()))
1670+
return {};
1671+
15881672
return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
15891673
}
15901674

15911675
OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1676+
if (hasOperandsOutsideOfBlock(getOperation()))
1677+
return {};
1678+
15921679
return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
15931680
}
15941681

@@ -1616,10 +1703,16 @@ static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
16161703
}
16171704

16181705
OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1706+
if (hasOperandsOutsideOfBlock(getOperation()))
1707+
return {};
1708+
16191709
return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
16201710
}
16211711

16221712
OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1713+
if (hasOperandsOutsideOfBlock(getOperation()))
1714+
return {};
1715+
16231716
return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
16241717
}
16251718
//===----------------------------------------------------------------------===//
@@ -1628,6 +1721,9 @@ OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
16281721

16291722
// Constant folding
16301723
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1724+
if (hasOperandsOutsideOfBlock(getOperation()))
1725+
return {};
1726+
16311727
if (getNumOperands() == 1)
16321728
return getOperand(0);
16331729

@@ -1652,6 +1748,9 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
16521748
}
16531749

16541750
LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
1751+
if (hasOperandsOutsideOfBlock(&*op))
1752+
return failure();
1753+
16551754
auto inputs = op.getInputs();
16561755
auto size = inputs.size();
16571756
assert(size > 1 && "expected 2 or more operands");
@@ -1815,6 +1914,9 @@ LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
18151914
//===----------------------------------------------------------------------===//
18161915

18171916
OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1917+
if (hasOperandsOutsideOfBlock(getOperation()))
1918+
return {};
1919+
18181920
// mux (c, b, b) -> b
18191921
if (getTrueValue() == getFalseValue())
18201922
return getTrueValue();
@@ -2264,6 +2366,9 @@ struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
22642366

22652367
LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
22662368
PatternRewriter &rewriter) const {
2369+
if (hasOperandsOutsideOfBlock(&*op))
2370+
return failure();
2371+
22672372
// If the op has a SV attribute, don't optimize it.
22682373
if (hasSVAttributes(op))
22692374
return failure();
@@ -2554,6 +2659,9 @@ struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
25542659

25552660
LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
25562661
PatternRewriter &rewriter) const override {
2662+
if (hasOperandsOutsideOfBlock(&*op))
2663+
return failure();
2664+
25572665
if (foldArrayOfMuxes(op, rewriter))
25582666
return success();
25592667
return failure();
@@ -2633,6 +2741,9 @@ static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
26332741
}
26342742

26352743
OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2744+
if (hasOperandsOutsideOfBlock(getOperation()))
2745+
return {};
2746+
26362747
// gt a, a -> false
26372748
// gte a, a -> true
26382749
if (getLhs() == getRhs()) {
@@ -2908,6 +3019,9 @@ static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
29083019
}
29093020

29103021
LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3022+
if (hasOperandsOutsideOfBlock(&*op))
3023+
return failure();
3024+
29113025
APInt lhs, rhs;
29123026

29133027
// icmp 1, x -> icmp x, 1

test/Dialect/Calyx/remove-comb-groups.mlir

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%don
1212
// CHECK: calyx.assign %eq_reg.write_en = %true : i1
1313
// CHECK: calyx.assign %eq.left = %true : i1
1414
// CHECK: calyx.assign %eq.right = %true : i1
15-
// CHECK: calyx.group_done %eq_reg.done ? %true : i1
15+
// CHECK: %0 = comb.and %eq_reg.done : i1
16+
// CHECK: calyx.group_done %0 ? %true : i1
1617
calyx.comb_group @Cond {
1718
calyx.assign %eq.left = %c1_1 : i1
1819
calyx.assign %eq.right = %c1_1 : i1
@@ -59,7 +60,8 @@ calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%don
5960
// CHECK: calyx.assign %eq_reg.write_en = %true : i1
6061
// CHECK: calyx.assign %eq.left = %true : i1
6162
// CHECK: calyx.assign %eq.right = %true : i1
62-
// CHECK: calyx.group_done %eq_reg.done ? %true : i1
63+
// CHECK: %0 = comb.and %eq_reg.done : i1
64+
// CHECK: calyx.group_done %0 ? %true : i1
6365
// CHECK: }
6466
calyx.comb_group @Cond1 {
6567
calyx.assign %eq.left = %c1_1 : i1
@@ -72,7 +74,8 @@ calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%don
7274
// CHECK: calyx.assign %eq_reg.write_en = %true : i1
7375
// CHECK: calyx.assign %eq.left = %true : i1
7476
// CHECK: calyx.assign %eq.right = %true : i1
75-
// CHECK: calyx.group_done %eq_reg.done ? %true : i1
77+
// CHECK: %0 = comb.and %eq_reg.done : i1
78+
// CHECK: calyx.group_done %0 ? %true : i1
7679
// CHECK: }
7780
calyx.comb_group @Cond2 {
7881
calyx.assign %eq.left = %c1_1 : i1

0 commit comments

Comments
 (0)