@@ -21,6 +21,21 @@ using namespace circt;
21
21
using namespace comb ;
22
22
using namespace matchers ;
23
23
24
+ // / In comb, we assume no knowledge of the semantics of cross-block dataflow. As
25
+ // / such, cross-block dataflow is interpreted as a canonicalization barrier.
26
+ // / This is a conservative approach which:
27
+ // / 1. still allows for efficient canonicalization for the common CIRCT usecase
28
+ // / of comb (comb logic nested inside single-block hw.module's)
29
+ // / 2. allows comb operations to be used in non-HW container ops - that may use
30
+ // / MLIR blocks and regions to represent various forms of hierarchical
31
+ // / abstractions, thus allowing comb to compose with other dialects.
32
+ static bool hasOperandsOutsideOfBlock (Operation *op) {
33
+ Block *thisBlock = op->getBlock ();
34
+ return llvm::any_of (op->getOperands (), [&](Value operand) {
35
+ return operand.getParentBlock () != thisBlock;
36
+ });
37
+ }
38
+
24
39
// / Create a new instance of a generic operation that only has value operands,
25
40
// / and has a single result value whose type matches the first operand.
26
41
// /
@@ -242,6 +257,9 @@ static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
242
257
// ===----------------------------------------------------------------------===//
243
258
244
259
OpFoldResult ReplicateOp::fold (FoldAdaptor adaptor) {
260
+ if (hasOperandsOutsideOfBlock (getOperation ()))
261
+ return {};
262
+
245
263
// Replicate one time -> noop.
246
264
if (getType ().cast <IntegerType>().getWidth () ==
247
265
getInput ().getType ().getIntOrFloatBitWidth ())
@@ -269,6 +287,9 @@ OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
269
287
}
270
288
271
289
OpFoldResult ParityOp::fold (FoldAdaptor adaptor) {
290
+ if (hasOperandsOutsideOfBlock (getOperation ()))
291
+ return {};
292
+
272
293
// Constant fold.
273
294
if (auto input = adaptor.getInput ().dyn_cast_or_null <IntegerAttr>())
274
295
return getIntAttr (APInt (1 , input.getValue ().popcount () & 1 ), getContext ());
@@ -295,6 +316,9 @@ static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
295
316
}
296
317
297
318
OpFoldResult ShlOp::fold (FoldAdaptor adaptor) {
319
+ if (hasOperandsOutsideOfBlock (getOperation ()))
320
+ return {};
321
+
298
322
if (auto rhs = adaptor.getRhs ().dyn_cast_or_null <IntegerAttr>()) {
299
323
unsigned shift = rhs.getValue ().getZExtValue ();
300
324
unsigned width = getType ().getIntOrFloatBitWidth ();
@@ -308,6 +332,9 @@ OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
308
332
}
309
333
310
334
LogicalResult ShlOp::canonicalize (ShlOp op, PatternRewriter &rewriter) {
335
+ if (hasOperandsOutsideOfBlock (&*op))
336
+ return failure ();
337
+
311
338
// ShlOp(x, cst) -> Concat(Extract(x), zeros)
312
339
APInt value;
313
340
if (!matchPattern (op.getRhs (), m_ConstantInt (&value)))
@@ -332,6 +359,9 @@ LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
332
359
}
333
360
334
361
OpFoldResult ShrUOp::fold (FoldAdaptor adaptor) {
362
+ if (hasOperandsOutsideOfBlock (getOperation ()))
363
+ return {};
364
+
335
365
if (auto rhs = adaptor.getRhs ().dyn_cast_or_null <IntegerAttr>()) {
336
366
unsigned shift = rhs.getValue ().getZExtValue ();
337
367
if (shift == 0 )
@@ -345,6 +375,9 @@ OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
345
375
}
346
376
347
377
LogicalResult ShrUOp::canonicalize (ShrUOp op, PatternRewriter &rewriter) {
378
+ if (hasOperandsOutsideOfBlock (&*op))
379
+ return failure ();
380
+
348
381
// ShrUOp(x, cst) -> Concat(zeros, Extract(x))
349
382
APInt value;
350
383
if (!matchPattern (op.getRhs (), m_ConstantInt (&value)))
@@ -369,6 +402,9 @@ LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
369
402
}
370
403
371
404
OpFoldResult ShrSOp::fold (FoldAdaptor adaptor) {
405
+ if (hasOperandsOutsideOfBlock (getOperation ()))
406
+ return {};
407
+
372
408
if (auto rhs = adaptor.getRhs ().dyn_cast_or_null <IntegerAttr>()) {
373
409
if (rhs.getValue ().getZExtValue () == 0 )
374
410
return getOperand (0 );
@@ -377,6 +413,9 @@ OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
377
413
}
378
414
379
415
LogicalResult ShrSOp::canonicalize (ShrSOp op, PatternRewriter &rewriter) {
416
+ if (hasOperandsOutsideOfBlock (&*op))
417
+ return failure ();
418
+
380
419
// ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
381
420
APInt value;
382
421
if (!matchPattern (op.getRhs (), m_ConstantInt (&value)))
@@ -406,6 +445,9 @@ LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
406
445
// ===----------------------------------------------------------------------===//
407
446
408
447
OpFoldResult ExtractOp::fold (FoldAdaptor adaptor) {
448
+ if (hasOperandsOutsideOfBlock (getOperation ()))
449
+ return {};
450
+
409
451
// If we are extracting the entire input, then return it.
410
452
if (getInput ().getType () == getType ())
411
453
return getInput ();
@@ -534,6 +576,9 @@ static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
534
576
}
535
577
536
578
LogicalResult ExtractOp::canonicalize (ExtractOp op, PatternRewriter &rewriter) {
579
+ if (hasOperandsOutsideOfBlock (&*op))
580
+ return failure ();
581
+
537
582
auto *inputOp = op.getInput ().getDefiningOp ();
538
583
539
584
// This turns out to be incredibly expensive. Disable until performance is
@@ -744,6 +789,9 @@ static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
744
789
}
745
790
746
791
OpFoldResult AndOp::fold (FoldAdaptor adaptor) {
792
+ if (hasOperandsOutsideOfBlock (getOperation ()))
793
+ return {};
794
+
747
795
APInt value = APInt::getAllOnes (getType ().cast <IntegerType>().getWidth ());
748
796
749
797
auto inputs = adaptor.getInputs ();
@@ -841,6 +889,9 @@ static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
841
889
}
842
890
843
891
LogicalResult AndOp::canonicalize (AndOp op, PatternRewriter &rewriter) {
892
+ if (hasOperandsOutsideOfBlock (&*op))
893
+ return failure ();
894
+
844
895
auto inputs = op.getInputs ();
845
896
auto size = inputs.size ();
846
897
assert (size > 1 && " expected 2 or more operands, `fold` should handle this" );
@@ -974,6 +1025,9 @@ LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
974
1025
}
975
1026
976
1027
OpFoldResult OrOp::fold (FoldAdaptor adaptor) {
1028
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1029
+ return {};
1030
+
977
1031
auto value = APInt::getZero (getType ().cast <IntegerType>().getWidth ());
978
1032
auto inputs = adaptor.getInputs ();
979
1033
// or(x, 10, 01) -> 11
@@ -1113,6 +1167,9 @@ static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
1113
1167
}
1114
1168
1115
1169
LogicalResult OrOp::canonicalize (OrOp op, PatternRewriter &rewriter) {
1170
+ if (hasOperandsOutsideOfBlock (&*op))
1171
+ return failure ();
1172
+
1116
1173
auto inputs = op.getInputs ();
1117
1174
auto size = inputs.size ();
1118
1175
assert (size > 1 && " expected 2 or more operands" );
@@ -1212,6 +1269,9 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1212
1269
}
1213
1270
1214
1271
OpFoldResult XorOp::fold (FoldAdaptor adaptor) {
1272
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1273
+ return {};
1274
+
1215
1275
auto size = getInputs ().size ();
1216
1276
auto inputs = adaptor.getInputs ();
1217
1277
@@ -1264,6 +1324,9 @@ static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1264
1324
}
1265
1325
1266
1326
LogicalResult XorOp::canonicalize (XorOp op, PatternRewriter &rewriter) {
1327
+ if (hasOperandsOutsideOfBlock (&*op))
1328
+ return failure ();
1329
+
1267
1330
auto inputs = op.getInputs ();
1268
1331
auto size = inputs.size ();
1269
1332
assert (size > 1 && " expected 2 or more operands" );
@@ -1339,6 +1402,9 @@ LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1339
1402
}
1340
1403
1341
1404
OpFoldResult SubOp::fold (FoldAdaptor adaptor) {
1405
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1406
+ return {};
1407
+
1342
1408
// sub(x - x) -> 0
1343
1409
if (getRhs () == getLhs ())
1344
1410
return getIntAttr (
@@ -1369,6 +1435,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1369
1435
}
1370
1436
1371
1437
LogicalResult SubOp::canonicalize (SubOp op, PatternRewriter &rewriter) {
1438
+ if (hasOperandsOutsideOfBlock (&*op))
1439
+ return failure ();
1440
+
1372
1441
// sub(x, cst) -> add(x, -cst)
1373
1442
APInt value;
1374
1443
if (matchPattern (op.getRhs (), m_ConstantInt (&value))) {
@@ -1386,6 +1455,9 @@ LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1386
1455
}
1387
1456
1388
1457
OpFoldResult AddOp::fold (FoldAdaptor adaptor) {
1458
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1459
+ return {};
1460
+
1389
1461
auto size = getInputs ().size ();
1390
1462
1391
1463
// add(x) -> x -- noop
@@ -1397,6 +1469,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1397
1469
}
1398
1470
1399
1471
LogicalResult AddOp::canonicalize (AddOp op, PatternRewriter &rewriter) {
1472
+ if (hasOperandsOutsideOfBlock (&*op))
1473
+ return failure ();
1474
+
1400
1475
auto inputs = op.getInputs ();
1401
1476
auto size = inputs.size ();
1402
1477
assert (size > 1 && " expected 2 or more operands" );
@@ -1497,6 +1572,9 @@ LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
1497
1572
}
1498
1573
1499
1574
OpFoldResult MulOp::fold (FoldAdaptor adaptor) {
1575
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1576
+ return {};
1577
+
1500
1578
auto size = getInputs ().size ();
1501
1579
auto inputs = adaptor.getInputs ();
1502
1580
@@ -1521,6 +1599,9 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1521
1599
}
1522
1600
1523
1601
LogicalResult MulOp::canonicalize (MulOp op, PatternRewriter &rewriter) {
1602
+ if (hasOperandsOutsideOfBlock (&*op))
1603
+ return failure ();
1604
+
1524
1605
auto inputs = op.getInputs ();
1525
1606
auto size = inputs.size ();
1526
1607
assert (size > 1 && " expected 2 or more operands" );
@@ -1585,10 +1666,16 @@ static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1585
1666
}
1586
1667
1587
1668
OpFoldResult DivUOp::fold (FoldAdaptor adaptor) {
1669
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1670
+ return {};
1671
+
1588
1672
return foldDiv<DivUOp, /* isSigned=*/ false >(*this , adaptor.getOperands ());
1589
1673
}
1590
1674
1591
1675
OpFoldResult DivSOp::fold (FoldAdaptor adaptor) {
1676
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1677
+ return {};
1678
+
1592
1679
return foldDiv<DivSOp, /* isSigned=*/ true >(*this , adaptor.getOperands ());
1593
1680
}
1594
1681
@@ -1616,10 +1703,16 @@ static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1616
1703
}
1617
1704
1618
1705
OpFoldResult ModUOp::fold (FoldAdaptor adaptor) {
1706
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1707
+ return {};
1708
+
1619
1709
return foldMod<ModUOp, /* isSigned=*/ false >(*this , adaptor.getOperands ());
1620
1710
}
1621
1711
1622
1712
OpFoldResult ModSOp::fold (FoldAdaptor adaptor) {
1713
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1714
+ return {};
1715
+
1623
1716
return foldMod<ModSOp, /* isSigned=*/ true >(*this , adaptor.getOperands ());
1624
1717
}
1625
1718
// ===----------------------------------------------------------------------===//
@@ -1628,6 +1721,9 @@ OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1628
1721
1629
1722
// Constant folding
1630
1723
OpFoldResult ConcatOp::fold (FoldAdaptor adaptor) {
1724
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1725
+ return {};
1726
+
1631
1727
if (getNumOperands () == 1 )
1632
1728
return getOperand (0 );
1633
1729
@@ -1652,6 +1748,9 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1652
1748
}
1653
1749
1654
1750
LogicalResult ConcatOp::canonicalize (ConcatOp op, PatternRewriter &rewriter) {
1751
+ if (hasOperandsOutsideOfBlock (&*op))
1752
+ return failure ();
1753
+
1655
1754
auto inputs = op.getInputs ();
1656
1755
auto size = inputs.size ();
1657
1756
assert (size > 1 && " expected 2 or more operands" );
@@ -1815,6 +1914,9 @@ LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
1815
1914
// ===----------------------------------------------------------------------===//
1816
1915
1817
1916
OpFoldResult MuxOp::fold (FoldAdaptor adaptor) {
1917
+ if (hasOperandsOutsideOfBlock (getOperation ()))
1918
+ return {};
1919
+
1818
1920
// mux (c, b, b) -> b
1819
1921
if (getTrueValue () == getFalseValue ())
1820
1922
return getTrueValue ();
@@ -2264,6 +2366,9 @@ struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2264
2366
2265
2367
LogicalResult MuxRewriter::matchAndRewrite (MuxOp op,
2266
2368
PatternRewriter &rewriter) const {
2369
+ if (hasOperandsOutsideOfBlock (&*op))
2370
+ return failure ();
2371
+
2267
2372
// If the op has a SV attribute, don't optimize it.
2268
2373
if (hasSVAttributes (op))
2269
2374
return failure ();
@@ -2554,6 +2659,9 @@ struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2554
2659
2555
2660
LogicalResult matchAndRewrite (hw::ArrayCreateOp op,
2556
2661
PatternRewriter &rewriter) const override {
2662
+ if (hasOperandsOutsideOfBlock (&*op))
2663
+ return failure ();
2664
+
2557
2665
if (foldArrayOfMuxes (op, rewriter))
2558
2666
return success ();
2559
2667
return failure ();
@@ -2633,6 +2741,9 @@ static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2633
2741
}
2634
2742
2635
2743
OpFoldResult ICmpOp::fold (FoldAdaptor adaptor) {
2744
+ if (hasOperandsOutsideOfBlock (getOperation ()))
2745
+ return {};
2746
+
2636
2747
// gt a, a -> false
2637
2748
// gte a, a -> true
2638
2749
if (getLhs () == getRhs ()) {
@@ -2908,6 +3019,9 @@ static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
2908
3019
}
2909
3020
2910
3021
LogicalResult ICmpOp::canonicalize (ICmpOp op, PatternRewriter &rewriter) {
3022
+ if (hasOperandsOutsideOfBlock (&*op))
3023
+ return failure ();
3024
+
2911
3025
APInt lhs, rhs;
2912
3026
2913
3027
// icmp 1, x -> icmp x, 1
0 commit comments