Skip to content

Commit 4a062a2

Browse files
authored
[LoopScheduleToCalyx] deduplicate groups within a ParOp. (#8055)
1 parent 76c562b commit 4a062a2

File tree

4 files changed

+155
-19
lines changed

4 files changed

+155
-19
lines changed

include/circt/Dialect/Calyx/CalyxLoweringUtils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,23 @@ struct EliminateUnusedCombGroups : mlir::OpRewritePattern<calyx::CombGroupOp> {
756756
PatternRewriter &rewriter) const override;
757757
};
758758

759+
/// Removes duplicate EnableOps in parallel operations.
760+
struct DeduplicateParallelOp : mlir::OpRewritePattern<calyx::ParOp> {
761+
using mlir::OpRewritePattern<calyx::ParOp>::OpRewritePattern;
762+
763+
LogicalResult matchAndRewrite(calyx::ParOp parOp,
764+
PatternRewriter &rewriter) const override;
765+
};
766+
767+
/// Removes duplicate EnableOps in static parallel operations.
768+
struct DeduplicateStaticParallelOp
769+
: mlir::OpRewritePattern<calyx::StaticParOp> {
770+
using mlir::OpRewritePattern<calyx::StaticParOp>::OpRewritePattern;
771+
772+
LogicalResult matchAndRewrite(calyx::StaticParOp parOp,
773+
PatternRewriter &rewriter) const override;
774+
};
775+
759776
/// This pass recursively inlines use-def chains of combinational logic (from
760777
/// non-stateful groups) into groups referenced in the control schedule.
761778
class InlineCombGroups

lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3030
#include "llvm/ADT/TypeSwitch.h"
3131

32+
#include <type_traits>
3233
#include <variant>
3334

3435
namespace circt {
@@ -126,6 +127,19 @@ class PipelineScheduler : public calyx::SchedulerInterface<Scheduleable> {
126127
return pipelineRegs[stage];
127128
}
128129

130+
/// Returns the pipeline register for this value if its defining operation is
131+
/// a stage, and std::nullopt otherwise.
132+
std::optional<calyx::RegisterOp> getPipelineRegister(Value value) {
133+
auto opStage = dyn_cast<LoopSchedulePipelineStageOp>(value.getDefiningOp());
134+
if (opStage == nullptr)
135+
return std::nullopt;
136+
// The pipeline register for this input value needs to be discovered.
137+
auto opResult = cast<OpResult>(value);
138+
unsigned int opNumber = opResult.getResultNumber();
139+
auto &stageRegisters = getPipelineRegs(opStage);
140+
return stageRegisters.find(opNumber)->second;
141+
}
142+
129143
/// Add a stage's groups to the pipeline prologue.
130144
void addPipelinePrologue(Operation *op, SmallVector<StringAttr> groupNames) {
131145
pipelinePrologue[op].push_back(groupNames);
@@ -306,9 +320,14 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
306320
/// Create assignments to the inputs of the library op.
307321
auto group = createGroupForOp<TGroupOp>(rewriter, op);
308322
rewriter.setInsertionPointToEnd(group.getBodyBlock());
309-
for (auto dstOp : enumerate(opInputPorts))
310-
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
311-
op->getOperand(dstOp.index()));
323+
for (auto dstOp : enumerate(opInputPorts)) {
324+
Value srcOp = op->getOperand(dstOp.index());
325+
std::optional<calyx::RegisterOp> pipelineRegister =
326+
getState<ComponentLoweringState>().getPipelineRegister(srcOp);
327+
if (pipelineRegister.has_value())
328+
srcOp = pipelineRegister->getOut();
329+
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(), srcOp);
330+
}
312331

313332
/// Replace the result values of the source operator with the new operator.
314333
for (auto res : enumerate(opOutputPorts)) {
@@ -1055,22 +1074,17 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
10551074
Value value = operand.get();
10561075

10571076
// Get the pipeline register for that result.
1058-
auto pipelineRegister = pipelineRegisters[i];
1077+
calyx::RegisterOp pipelineRegister = pipelineRegisters[i];
1078+
if (std::optional<calyx::RegisterOp> pr =
1079+
state.getPipelineRegister(value)) {
1080+
value = pr->getOut();
1081+
}
10591082

10601083
calyx::GroupOp group;
10611084
// Get the evaluating group for that value.
10621085
std::optional<calyx::GroupInterface> evaluatingGroup =
10631086
state.findEvaluatingGroup(value);
10641087
if (!evaluatingGroup.has_value()) {
1065-
if (auto opStage =
1066-
dyn_cast<LoopSchedulePipelineStageOp>(value.getDefiningOp())) {
1067-
// The pipeline register for this input value needs to be discovered.
1068-
auto opResult = cast<OpResult>(value);
1069-
unsigned int opNumber = opResult.getResultNumber();
1070-
auto &stageRegisters = state.getPipelineRegs(opStage);
1071-
calyx::RegisterOp opRegister = stageRegisters.find(opNumber)->second;
1072-
value = opRegister.getOut(); // Pass the `out` wire of this register.
1073-
}
10741088
if (value.getDefiningOp<calyx::RegisterOp>() == nullptr) {
10751089
// We add this for any unhandled cases.
10761090
llvm::errs() << "unexpected: input value: " << value << ", in stage "
@@ -1166,8 +1180,9 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
11661180
}
11671181
doneOp.getSrcMutable().assign(pipelineRegister.getDone());
11681182

1169-
// Remove the old register completely.
1170-
rewriter.eraseOp(tempReg);
1183+
// Remove the old register if it has no more uses.
1184+
if (tempReg->use_empty())
1185+
rewriter.eraseOp(tempReg);
11711186

11721187
return group;
11731188
}
@@ -1534,10 +1549,11 @@ class LoopScheduleToCalyxPass
15341549
if (runOnce)
15351550
config.maxIterations = 1;
15361551

1537-
/// Can't return applyPatternsGreedily. Root isn't
1552+
/// Can't return applyPatternsAndFoldGreedily. Root isn't
15381553
/// necessarily erased so it will always return failed(). Instead,
15391554
/// forward the 'succeeded' value from PartialLoweringPatternBase.
1540-
(void)applyPatternsGreedily(getOperation(), std::move(pattern), config);
1555+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(pattern),
1556+
config);
15411557
return partialPatternRes;
15421558
}
15431559

@@ -1628,6 +1644,9 @@ void LoopScheduleToCalyxPass::runOnOperation() {
16281644
addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
16291645
*loweringState);
16301646

1647+
addGreedyPattern<calyx::DeduplicateParallelOp>(loweringPatterns);
1648+
addGreedyPattern<calyx::DeduplicateStaticParallelOp>(loweringPatterns);
1649+
16311650
/// This pattern performs various SSA replacements that must be done
16321651
/// after control generation.
16331652
addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
@@ -1665,8 +1684,8 @@ void LoopScheduleToCalyxPass::runOnOperation() {
16651684
RewritePatternSet cleanupPatterns(&getContext());
16661685
cleanupPatterns.add<calyx::MultipleGroupDonePattern,
16671686
calyx::NonTerminatingGroupDonePattern>(&getContext());
1668-
if (failed(
1669-
applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) {
1687+
if (failed(applyPatternsAndFoldGreedily(getOperation(),
1688+
std::move(cleanupPatterns)))) {
16701689
signalPassFailure();
16711690
return;
16721691
}

lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,28 @@ using namespace mlir::arith;
2929
namespace circt {
3030
namespace calyx {
3131

32+
template <typename OpTy>
33+
static LogicalResult deduplicateParallelOperation(OpTy parOp,
34+
PatternRewriter &rewriter) {
35+
auto *body = parOp.getBodyBlock();
36+
if (body->getOperations().size() < 2)
37+
return failure();
38+
39+
LogicalResult result = LogicalResult::failure();
40+
SetVector<StringRef> members;
41+
for (auto &op : make_early_inc_range(*body)) {
42+
auto enableOp = dyn_cast<EnableOp>(&op);
43+
if (enableOp == nullptr)
44+
continue;
45+
bool inserted = members.insert(enableOp.getGroupName());
46+
if (!inserted) {
47+
rewriter.eraseOp(enableOp);
48+
result = LogicalResult::success();
49+
}
50+
}
51+
return result;
52+
}
53+
3254
void appendPortsForExternalMemref(PatternRewriter &rewriter, StringRef memName,
3355
Value memref, unsigned memoryID,
3456
SmallVectorImpl<calyx::PortInfo> &inPorts,
@@ -609,6 +631,22 @@ EliminateUnusedCombGroups::matchAndRewrite(calyx::CombGroupOp combGroupOp,
609631
return success();
610632
}
611633

634+
//===----------------------------------------------------------------------===//
635+
// DeduplicateParallelOperations
636+
//===----------------------------------------------------------------------===//
637+
638+
/// Forwards to the shared deduplication helper for `calyx::ParOp`.
LogicalResult
DeduplicateParallelOp::matchAndRewrite(calyx::ParOp parOp,
                                       PatternRewriter &rewriter) const {
  // The operation type is deduced from `parOp`.
  return deduplicateParallelOperation(parOp, rewriter);
}
643+
644+
/// Forwards to the shared deduplication helper for `calyx::StaticParOp`.
LogicalResult
DeduplicateStaticParallelOp::matchAndRewrite(calyx::StaticParOp parOp,
                                             PatternRewriter &rewriter) const {
  // The operation type is deduced from `parOp`.
  return deduplicateParallelOperation(parOp, rewriter);
}
649+
612650
//===----------------------------------------------------------------------===//
613651
// InlineCombGroups
614652
//===----------------------------------------------------------------------===//
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// RUN: circt-opt %s -lower-loopschedule-to-calyx -canonicalize -split-input-file | FileCheck %s
2+
3+
// This will introduce duplicate groups; these should be subsequently removed during canonicalization.
4+
5+
// CHECK: calyx.while %std_lt_0.out with @bb0_0 {
6+
// CHECK-NEXT: calyx.par {
7+
// CHECK-NEXT: calyx.enable @bb0_1
8+
// CHECK-NEXT: }
9+
// CHECK-NEXT: }
10+
module {
11+
func.func @foo() attributes {} {
12+
%const = arith.constant 1 : index
13+
loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () {
14+
%latch = arith.cmpi ult, %counter, %const : index
15+
loopschedule.register %latch : i1
16+
} do {
17+
%S0 = loopschedule.pipeline.stage start = 0 {
18+
%op = arith.addi %counter, %const : index
19+
loopschedule.register %op : index
20+
} : index
21+
%S1 = loopschedule.pipeline.stage start = 1 {
22+
loopschedule.register %S0: index
23+
} : index
24+
loopschedule.terminator iter_args(%S0), results() : (index) -> ()
25+
}
26+
return
27+
}
28+
}
29+
30+
// -----
31+
32+
// Stage pipeline registers passed directly to the next stage
33+
// should also be updated when used in computations.
34+
35+
// CHECK: calyx.group @bb0_2 {
36+
// CHECK-NEXT: calyx.assign %std_add_1.left = %while_0_arg0_reg.out : i32
37+
// CHECK-NEXT: calyx.assign %std_add_1.right = %c1_i32 : i32
38+
// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.in = %std_add_1.out : i32
39+
// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.write_en = %true : i1
40+
// CHECK-NEXT: calyx.group_done %stage_1_register_0_reg.done : i1
41+
// CHECK-NEXT: }
42+
module {
43+
func.func @foo() attributes {} {
44+
%const = arith.constant 1 : index
45+
loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () {
46+
%latch = arith.cmpi ult, %counter, %const : index
47+
loopschedule.register %latch : i1
48+
} do {
49+
%S0 = loopschedule.pipeline.stage start = 0 {
50+
%op = arith.addi %counter, %const : index
51+
loopschedule.register %op : index
52+
} : index
53+
%S1 = loopschedule.pipeline.stage start = 1 {
54+
%math = arith.addi %S0, %const : index
55+
loopschedule.register %math : index
56+
} : index
57+
loopschedule.terminator iter_args(%S0), results() : (index) -> ()
58+
}
59+
return
60+
}
61+
}
62+

0 commit comments

Comments
 (0)