|
29 | 29 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
30 | 30 | #include "llvm/ADT/TypeSwitch.h"
|
31 | 31 |
|
| 32 | +#include <type_traits> |
32 | 33 | #include <variant>
|
33 | 34 |
|
34 | 35 | namespace circt {
|
@@ -126,6 +127,19 @@ class PipelineScheduler : public calyx::SchedulerInterface<Scheduleable> {
|
126 | 127 | return pipelineRegs[stage];
|
127 | 128 | }
|
128 | 129 |
|
| 130 | + /// Returns the pipeline register for this value if its defining operation is |
| 131 | + /// a stage, and std::nullopt otherwise. |
| 132 | + std::optional<calyx::RegisterOp> getPipelineRegister(Value value) { |
| 133 | + auto opStage = dyn_cast<LoopSchedulePipelineStageOp>(value.getDefiningOp()); |
| 134 | + if (opStage == nullptr) |
| 135 | + return std::nullopt; |
| 136 | + // The pipeline register for this input value needs to be discovered. |
| 137 | + auto opResult = cast<OpResult>(value); |
| 138 | + unsigned int opNumber = opResult.getResultNumber(); |
| 139 | + auto &stageRegisters = getPipelineRegs(opStage); |
| 140 | + return stageRegisters.find(opNumber)->second; |
| 141 | + } |
| 142 | + |
129 | 143 | /// Add a stage's groups to the pipeline prologue.
|
130 | 144 | void addPipelinePrologue(Operation *op, SmallVector<StringAttr> groupNames) {
|
131 | 145 | pipelinePrologue[op].push_back(groupNames);
|
@@ -306,9 +320,14 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
|
306 | 320 | /// Create assignments to the inputs of the library op.
|
307 | 321 | auto group = createGroupForOp<TGroupOp>(rewriter, op);
|
308 | 322 | rewriter.setInsertionPointToEnd(group.getBodyBlock());
|
309 |
| - for (auto dstOp : enumerate(opInputPorts)) |
310 |
| - rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(), |
311 |
| - op->getOperand(dstOp.index())); |
| 323 | + for (auto dstOp : enumerate(opInputPorts)) { |
| 324 | + Value srcOp = op->getOperand(dstOp.index()); |
| 325 | + std::optional<calyx::RegisterOp> pipelineRegister = |
| 326 | + getState<ComponentLoweringState>().getPipelineRegister(srcOp); |
| 327 | + if (pipelineRegister.has_value()) |
| 328 | + srcOp = pipelineRegister->getOut(); |
| 329 | + rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(), srcOp); |
| 330 | + } |
312 | 331 |
|
313 | 332 | /// Replace the result values of the source operator with the new operator.
|
314 | 333 | for (auto res : enumerate(opOutputPorts)) {
|
@@ -1055,22 +1074,17 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
|
1055 | 1074 | Value value = operand.get();
|
1056 | 1075 |
|
1057 | 1076 | // Get the pipeline register for that result.
|
1058 |
| - auto pipelineRegister = pipelineRegisters[i]; |
| 1077 | + calyx::RegisterOp pipelineRegister = pipelineRegisters[i]; |
| 1078 | + if (std::optional<calyx::RegisterOp> pr = |
| 1079 | + state.getPipelineRegister(value)) { |
| 1080 | + value = pr->getOut(); |
| 1081 | + } |
1059 | 1082 |
|
1060 | 1083 | calyx::GroupOp group;
|
1061 | 1084 | // Get the evaluating group for that value.
|
1062 | 1085 | std::optional<calyx::GroupInterface> evaluatingGroup =
|
1063 | 1086 | state.findEvaluatingGroup(value);
|
1064 | 1087 | if (!evaluatingGroup.has_value()) {
|
1065 |
| - if (auto opStage = |
1066 |
| - dyn_cast<LoopSchedulePipelineStageOp>(value.getDefiningOp())) { |
1067 |
| - // The pipeline register for this input value needs to be discovered. |
1068 |
| - auto opResult = cast<OpResult>(value); |
1069 |
| - unsigned int opNumber = opResult.getResultNumber(); |
1070 |
| - auto &stageRegisters = state.getPipelineRegs(opStage); |
1071 |
| - calyx::RegisterOp opRegister = stageRegisters.find(opNumber)->second; |
1072 |
| - value = opRegister.getOut(); // Pass the `out` wire of this register. |
1073 |
| - } |
1074 | 1088 | if (value.getDefiningOp<calyx::RegisterOp>() == nullptr) {
|
1075 | 1089 | // We add this for any unhandled cases.
|
1076 | 1090 | llvm::errs() << "unexpected: input value: " << value << ", in stage "
|
@@ -1166,8 +1180,9 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
|
1166 | 1180 | }
|
1167 | 1181 | doneOp.getSrcMutable().assign(pipelineRegister.getDone());
|
1168 | 1182 |
|
1169 |
| - // Remove the old register completely. |
1170 |
| - rewriter.eraseOp(tempReg); |
| 1183 | + // Remove the old register if it has no more uses. |
| 1184 | + if (tempReg->use_empty()) |
| 1185 | + rewriter.eraseOp(tempReg); |
1171 | 1186 |
|
1172 | 1187 | return group;
|
1173 | 1188 | }
|
@@ -1534,10 +1549,11 @@ class LoopScheduleToCalyxPass
|
1534 | 1549 | if (runOnce)
|
1535 | 1550 | config.maxIterations = 1;
|
1536 | 1551 |
|
1537 |
| - /// Can't return applyPatternsGreedily. Root isn't |
| 1552 | + /// Can't return applyPatternsAndFoldGreedily. Root isn't |
1538 | 1553 | /// necessarily erased so it will always return failed(). Instead,
|
1539 | 1554 | /// forward the 'succeeded' value from PartialLoweringPatternBase.
|
1540 |
| - (void)applyPatternsGreedily(getOperation(), std::move(pattern), config); |
| 1555 | + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(pattern), |
| 1556 | + config); |
1541 | 1557 | return partialPatternRes;
|
1542 | 1558 | }
|
1543 | 1559 |
|
@@ -1628,6 +1644,9 @@ void LoopScheduleToCalyxPass::runOnOperation() {
|
1628 | 1644 | addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
|
1629 | 1645 | *loweringState);
|
1630 | 1646 |
|
| 1647 | + addGreedyPattern<calyx::DeduplicateParallelOp>(loweringPatterns); |
| 1648 | + addGreedyPattern<calyx::DeduplicateStaticParallelOp>(loweringPatterns); |
| 1649 | + |
1631 | 1650 | /// This pattern performs various SSA replacements that must be done
|
1632 | 1651 | /// after control generation.
|
1633 | 1652 | addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
|
@@ -1665,8 +1684,8 @@ void LoopScheduleToCalyxPass::runOnOperation() {
|
1665 | 1684 | RewritePatternSet cleanupPatterns(&getContext());
|
1666 | 1685 | cleanupPatterns.add<calyx::MultipleGroupDonePattern,
|
1667 | 1686 | calyx::NonTerminatingGroupDonePattern>(&getContext());
|
1668 |
| - if (failed( |
1669 |
| - applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) { |
| 1687 | + if (failed(applyPatternsAndFoldGreedily(getOperation(), |
| 1688 | + std::move(cleanupPatterns)))) { |
1670 | 1689 | signalPassFailure();
|
1671 | 1690 | return;
|
1672 | 1691 | }
|
|
0 commit comments