[Calyx] Lower SCF parallel op to Calyx #7830

Merged: 7 commits, Nov 18, 2024
164 changes: 153 additions & 11 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
@@ -119,10 +119,15 @@ struct CallScheduleable {
func::CallOp callOp;
};

struct ParScheduleable {
/// Parallel operation to schedule.
scf::ParallelOp parOp;
};

/// A variant of types representing scheduleable operations.
using Scheduleable =
std::variant<calyx::GroupOp, WhileScheduleable, ForScheduleable,
IfScheduleable, CallScheduleable>;
IfScheduleable, CallScheduleable, ParScheduleable>;

class IfLoweringStateInterface {
public:
@@ -275,6 +280,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
.template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
/// SCF
scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
scf::ParallelOp, scf::ReduceOp,
/// memref
memref::AllocOp, memref::AllocaOp, memref::LoadOp,
memref::StoreOp,
@@ -338,6 +344,10 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::ForOp forOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::IfOp ifOp) const;
LogicalResult buildOp(PatternRewriter &rewriter,
scf::ReduceOp reduceOp) const;
LogicalResult buildOp(PatternRewriter &rewriter,
scf::ParallelOp parallelOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, CallOp callOp) const;

/// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
@@ -1093,6 +1103,21 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::ReduceOp reduceOp) const {
// We don't handle the reduce operation here and simply return success for
// now, since BuildParGroups would have already emitted an error and exited
// early if a reduce operation was encountered.
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::ParallelOp parOp) const {
getState<ComponentLoweringState>().addBlockScheduleable(
parOp.getOperation()->getBlock(), ParScheduleable{parOp});
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CallOp callOp) const {
std::string instanceName = calyx::getInstanceName(callOp);
@@ -1481,6 +1506,106 @@ class BuildIfGroups : public calyx::FuncOpPartialLoweringPattern {
}
};

class BuildParGroups : public calyx::FuncOpPartialLoweringPattern {
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
partiallyLowerFuncToComp(FuncOp funcOp,
PatternRewriter &rewriter) const override {
WalkResult walkResult = funcOp.walk([&](scf::ParallelOp scfParOp) {
if (!scfParOp.getResults().empty()) {
scfParOp.emitError(
"Reduce operations in scf.parallel is not supported yet");
return WalkResult::interrupt();
}

if (failed(partialEval(rewriter, scfParOp)))
return WalkResult::interrupt();

return WalkResult::advance();
});

return walkResult.wasInterrupted() ? failure() : success();
}

private:
// Partially evaluate/pre-compute all blocks being executed in parallel by
// statically generating loop index combinations.
LogicalResult partialEval(PatternRewriter &rewriter,
scf::ParallelOp scfParOp) const {
assert(scfParOp.getLoopSteps() && "Parallel loop must have steps");
auto *body = scfParOp.getBody();
auto parOpIVs = scfParOp.getInductionVars();
auto steps = scfParOp.getStep();
auto lowerBounds = scfParOp.getLowerBound();
auto upperBounds = scfParOp.getUpperBound();
rewriter.setInsertionPointAfter(scfParOp);
scf::ParallelOp newParOp = scfParOp.cloneWithoutRegions();
auto loc = newParOp.getLoc();
rewriter.insert(newParOp);
OpBuilder insideBuilder(newParOp);
Block *currBlock = nullptr;
auto &region = newParOp.getRegion();
IRMapping operandMap;

// extract lower bounds, upper bounds, and steps as integer index values
SmallVector<int64_t> lbVals, ubVals, stepVals;
for (auto lb : lowerBounds) {
auto lbOp = lb.getDefiningOp<arith::ConstantIndexOp>();
assert(lbOp &&
"Lower bound must be a statically computable constant index");
lbVals.push_back(lbOp.value());
}
for (auto ub : upperBounds) {
auto ubOp = ub.getDefiningOp<arith::ConstantIndexOp>();
assert(ubOp &&
"Upper bound must be a statically computable constant index");
ubVals.push_back(ubOp.value());
}
for (auto step : steps) {
auto stepOp = step.getDefiningOp<arith::ConstantIndexOp>();
assert(stepOp && "Step must be a statically computable constant index");
stepVals.push_back(stepOp.value());
}

// Initialize indices with lower bounds
SmallVector<int64_t> indices = lbVals;

while (true) {
// Create a new block in the region for the current combination of indices
currBlock = &region.emplaceBlock();
insideBuilder.setInsertionPointToEnd(currBlock);

// Map induction variables to constant indices
for (unsigned i = 0; i < indices.size(); ++i) {
Value ivConstant =
insideBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
operandMap.map(parOpIVs[i], ivConstant);
}

for (auto it = body->begin(); it != std::prev(body->end()); ++it)
insideBuilder.clone(*it, operandMap);

// Increment indices using `step`
bool done = false;
for (int dim = indices.size() - 1; dim >= 0; --dim) {
indices[dim] += stepVals[dim];
if (indices[dim] < ubVals[dim])
break;
indices[dim] = lbVals[dim];
if (dim == 0)
// All combinations have been generated
done = true;
}
if (done)
break;
}

rewriter.replaceOp(scfParOp, newParOp);
return success();
}
};

/// Builds a control schedule by traversing the CFG of the function and
/// associating this with the previously created groups.
/// For simplicity, the generated control flow is expanded for all possible
@@ -1512,7 +1637,8 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
getState<ComponentLoweringState>().getBlockScheduleables(block);
auto loc = block->front().getLoc();

if (compBlockScheduleables.size() > 1) {
if (compBlockScheduleables.size() > 1 &&
!isa<scf::ParallelOp>(block->getParentOp())) {
auto seqOp = rewriter.create<calyx::SeqOp>(loc);
parentCtrlBlock = seqOp.getBodyBlock();
}
@@ -1537,18 +1663,30 @@

/// Only schedule the 'after' block. The 'before' block is
/// implicitly scheduled when evaluating the while condition.
LogicalResult res = buildCFGControl(path, rewriter, whileBodyOpBlock,
block, whileOp.getBodyBlock());
if (LogicalResult result =
buildCFGControl(path, rewriter, whileBodyOpBlock, block,
whileOp.getBodyBlock());
result.failed())
return result;

// Insert loop-latch at the end of the while group
rewriter.setInsertionPointToEnd(whileBodyOpBlock);
calyx::GroupOp whileLatchGroup =
getState<ComponentLoweringState>().getWhileLoopLatchGroup(whileOp);
rewriter.create<calyx::EnableOp>(whileLatchGroup.getLoc(),
whileLatchGroup.getName());

if (res.failed())
return res;
} else if (auto *parSchedPtr = std::get_if<ParScheduleable>(&group)) {
auto parOp = parSchedPtr->parOp;
auto calyxParOp = rewriter.create<calyx::ParOp>(parOp.getLoc());
for (auto &innerBlock : parOp.getRegion().getBlocks()) {
rewriter.setInsertionPointToEnd(calyxParOp.getBodyBlock());
auto seqOp = rewriter.create<calyx::SeqOp>(parOp.getLoc());
rewriter.setInsertionPointToEnd(seqOp.getBodyBlock());
if (LogicalResult res = scheduleBasicBlock(
rewriter, path, seqOp.getBodyBlock(), &innerBlock);
res.failed())
return res;
}
} else if (auto *forSchedPtr = std::get_if<ForScheduleable>(&group);
forSchedPtr) {
auto forOp = forSchedPtr->forOp;
@@ -1563,17 +1701,17 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
auto *forBodyOpBlock = forBodyOp.getBodyBlock();

// Schedule the body of the for loop.
LogicalResult res = buildCFGControl(path, rewriter, forBodyOpBlock,
block, forOp.getBodyBlock());
if (LogicalResult res = buildCFGControl(path, rewriter, forBodyOpBlock,
block, forOp.getBodyBlock());
res.failed())
return res;

// Insert loop-latch at the end of the for group.
rewriter.setInsertionPointToEnd(forBodyOpBlock);
calyx::GroupOp forLatchGroup =
getState<ComponentLoweringState>().getForLoopLatchGroup(forOp);
rewriter.create<calyx::EnableOp>(forLatchGroup.getLoc(),
forLatchGroup.getName());
if (res.failed())
return res;
} else if (auto *ifSchedPtr = std::get_if<IfScheduleable>(&group);
ifSchedPtr) {
auto ifOp = ifSchedPtr->ifOp;
@@ -2241,6 +2379,9 @@ void SCFToCalyxPass::runOnOperation() {
/// This pass inlines scf.ExecuteRegionOp's by adding control-flow.
addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);

addOncePattern<BuildParGroups>(loweringPatterns, patternState, funcMap,
*loweringState);

/// This pattern converts all index typed values to an i32 integer.
addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, patternState,
funcMap, *loweringState);
@@ -2270,6 +2411,7 @@ void SCFToCalyxPass::runOnOperation() {

addOncePattern<BuildIfGroups>(loweringPatterns, patternState, funcMap,
*loweringState);

/// This pattern converts operations within basic blocks to Calyx library
/// operators. Combinational operations are assigned inside a
/// calyx::CombGroupOp, and sequential inside calyx::GroupOps.
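As an illustration (not part of this diff), BuildParGroups::partialEval conceptually rewrites the parallel loop from the test below into a multi-block scf.parallel: every statically enumerated induction-variable combination gets its own block, the induction variables are replaced by constants, and the scf.reduce terminator is not cloned. The last dimension is stepped first, so bounds (0, 0) to (3, 2) with steps (2, 1) enumerate (0, 0), (0, 1), (2, 0), (2, 1). Block labels and SSA names are illustrative, and this transient form is not verifier-clean scf; it only exists so that BuildControl can later emit one calyx.seq per block under a calyx.par:

scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c3, %c2) step (%c2, %c1) {
^bb0:  // (%arg2, %arg3) = (0, 0)
  %i0 = arith.constant 0 : index
  %j0 = arith.constant 0 : index
  // ... loop body cloned with %arg2 -> %i0, %arg3 -> %j0 ...
^bb1:  // (0, 1)
  %i1 = arith.constant 0 : index
  %j1 = arith.constant 1 : index
  // ... cloned body ...
^bb2:  // (2, 0)
  // ...
^bb3:  // (2, 1)
  // ...
}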
119 changes: 119 additions & 0 deletions test/Conversion/SCFToCalyx/convert_simple.mlir
@@ -257,3 +257,122 @@ module {
return %1 : f32
}
}

// -----

// Test parallel op lowering

// CHECK: calyx.wires {
// CHECK-DAG: calyx.group @bb0_0 {
// CHECK-DAG: calyx.assign %std_slice_7.in = %c0_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_7.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
// CHECK-DAG: calyx.assign %mem_1.write_en = %false : i1
// CHECK-DAG: calyx.assign %load_0_reg.in = %mem_1.read_data : i32
// CHECK-DAG: calyx.assign %load_0_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_0_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb0_1 {
// CHECK-DAG: calyx.assign %std_slice_6.in = %c0_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_6.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_0_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb1_0 {
// CHECK-DAG: calyx.assign %std_slice_5.in = %c4_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_5.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
// CHECK-DAG: calyx.assign %mem_1.write_en = %false : i1
// CHECK-DAG: calyx.assign %load_1_reg.in = %mem_1.read_data : i32
// CHECK-DAG: calyx.assign %load_1_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_1_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb1_1 {
// CHECK-DAG: calyx.assign %std_slice_4.in = %c1_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_4.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_1_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb2_0 {
// CHECK-DAG: calyx.assign %std_slice_3.in = %c2_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_3.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
// CHECK-DAG: calyx.assign %mem_1.write_en = %false : i1
// CHECK-DAG: calyx.assign %load_2_reg.in = %mem_1.read_data : i32
// CHECK-DAG: calyx.assign %load_2_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_2_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb2_1 {
// CHECK-DAG: calyx.assign %std_slice_2.in = %c4_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_2.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_2_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb3_0 {
// CHECK-DAG: calyx.assign %std_slice_1.in = %c6_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_1.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
// CHECK-DAG: calyx.assign %mem_1.write_en = %false : i1
// CHECK-DAG: calyx.assign %load_3_reg.in = %mem_1.read_data : i32
// CHECK-DAG: calyx.assign %load_3_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_3_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb3_1 {
// CHECK-DAG: calyx.assign %std_slice_0.in = %c5_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_0.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_3_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: }
// CHECK-DAG: calyx.control {
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.par {
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb0_0
// CHECK-DAG: calyx.enable @bb0_1
// CHECK-DAG: }
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb1_0
// CHECK-DAG: calyx.enable @bb1_1
// CHECK-DAG: }
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb2_0
// CHECK-DAG: calyx.enable @bb2_1
// CHECK-DAG: }
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb3_0
// CHECK-DAG: calyx.enable @bb3_1
// CHECK-DAG: }
// CHECK-DAG: }
// CHECK-DAG: }
// CHECK-DAG: }

module {
func.func @main() {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<6xi32>
%alloc_1 = memref.alloc() : memref<6xi32>
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c3, %c2) step (%c2, %c1) {
%4 = arith.shli %arg3, %c2 : index
%5 = arith.addi %4, %arg2 : index
%6 = memref.load %alloc_1[%5] : memref<6xi32>
%7 = arith.shli %arg2, %c1 : index
%8 = arith.addi %7, %arg3 : index
memref.store %6, %alloc[%8] : memref<6xi32>
scf.reduce
}
return
}
}
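
To tie the CHECK lines above to the test input: each of the four pre-computed index combinations (0, 0), (0, 1), (2, 0), (2, 1) becomes one load/store group pair @bbN_0 / @bbN_1, scheduled as its own calyx.seq under the calyx.par. A worked sketch of the address arithmetic for @bb1_*, taking (%arg2, %arg3) = (0, 1) and the shift/add math from the test body:

// load  index = (%arg3 << 2) + %arg2 = (1 << 2) + 0 = 4   -> %c4_i32 in @bb1_0
// store index = (%arg2 << 1) + %arg3 = (0 << 1) + 1 = 1   -> %c1_i32 in @bb1_1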

22 changes: 22 additions & 0 deletions test/Conversion/SCFToCalyx/errors.mlir
@@ -54,3 +54,25 @@ module {
}
}

// -----

module {
func.func @main() -> i32 {
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c0 = arith.constant 0 : index
%cinit = arith.constant 0 : i32
%alloc = memref.alloc() : memref<6xi32>
// expected-error @+1 {{Reduce operations in scf.parallel is not supported yet}}
%r:1 = scf.parallel (%arg2) = (%c0) to (%c3) step (%c1) init (%cinit) -> i32 {
%6 = memref.load %alloc[%arg2] : memref<6xi32>
scf.reduce(%6 : i32) {
^bb0(%lhs : i32, %rhs: i32):
%res = arith.addi %lhs, %rhs : i32
scf.reduce.return %res : i32
}
}
return %r : i32
}
}
