Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Calyx] constant op #7770

Merged
merged 3 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions include/circt/Dialect/Calyx/CalyxHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,
Twine prefix);

calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
ComponentOp component, Type type,
Twine prefix);

/// A helper function to create constants in the HW dialect.
hw::ConstantOp createConstant(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,
Expand Down
29 changes: 13 additions & 16 deletions include/circt/Dialect/Calyx/CalyxPrimitives.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,33 @@ class CalyxPrimitive<string mnemonic, list<Trait> traits = []> :
}

def ConstantOp: CalyxPrimitive<"constant",
[ConstantLike, FirstAttrDerivedResultType,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllTypesMatch<["value", "out"]>
[ConstantLike, DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "integer or floating point constant";
let summary = "constant capable of representing an integer or floating point value";
let description = [{
The `constant` operation produces an SSA value equal to some integer or
floating-point constant specified by an attribute.
The `constant` operation is a wrapper around bit vectors with fixed-size number of bits.
Specific value and intended type should be specified via attribute only.

Example:

```
// Integer constant
%1 = calyx.constant 42 : i32
%1 = calyx.constant <42 : i32> : i32

// Floating point constant
%1 = calyx.constant 42.00+e00 : f32
%1 = calyx.constant <4.2 : f32> : i32
```
}];
let arguments = (ins TypedAttrInterface:$value);
let arguments = (ins SymbolNameAttr:$sym_name, TypedAttrInterface:$value);

let results = (outs SignlessIntegerOrFloatLike:$out);
let results = (outs AnySignlessInteger:$out);

let builders = [
/// Build a ConstantOp from a prebuilt attribute.
OpBuilder <(ins "StringRef":$sym_name, "TypedAttr":$attr)>,
OpBuilder <(ins "StringRef":$sym_name, "Attribute":$attr, "Type":$type)>,
];

let hasFolder = 1;
let assemblyFormat = "attr-dict $value";
let assemblyFormat = "$sym_name ` ` `<` $value `>` attr-dict `:` qualified(type($out))";
let hasVerifier = 1;
}

Expand Down Expand Up @@ -340,7 +337,7 @@ class ArithBinaryFloatingPointLibraryOp<string mnemonic> : ArithBinaryLibraryOp<

def AddFNOp : ArithBinaryFloatingPointLibraryOp<"addFN"> {
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, I1:$subOp,
AnyFloat:$left, AnyFloat:$right, AnySignlessInteger:$roundingMode, AnyFloat:$out,
AnySignlessInteger:$left, AnySignlessInteger:$right, AnySignlessInteger:$roundingMode, AnySignlessInteger:$out,
AnySignlessInteger:$exceptionalFlags, I1:$done);

let extraClassDefinition = [{
Expand All @@ -352,11 +349,11 @@ def AddFNOp : ArithBinaryFloatingPointLibraryOp<"addFN"> {

SmallVector<Direction> $cppClass::portDirections() {
return {Input, Input, Input, Input, Input, Input, Input, Input, Output, Output, Output};
}
}

void $cppClass::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
getCellAsmResultNames(setNameFn, *this, this->portNames());
}
}

bool $cppClass::isCombinational() { return false; }

Expand Down
13 changes: 9 additions & 4 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
// Pass the result from the Operation to the Calyx primitive.
op.getResult().replaceAllUsesWith(out);
auto reg = createRegister(
op.getLoc(), rewriter, getComponent(), width,
op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(),
getState<ComponentLoweringState>().getUniqueName(opName));
// Operation pipelines are not combinational, so a GroupOp is required.
auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
Expand Down Expand Up @@ -687,9 +687,10 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
AddFOp addf) const {
Location loc = addf.getLoc();
Type width = addf.getResult().getType();
IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
five = rewriter.getIntegerType(5);
five = rewriter.getIntegerType(5),
width = rewriter.getIntegerType(
addf.getType().getIntOrFloatBitWidth());
auto addFN =
getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::AddFNOp>(
Expand Down Expand Up @@ -935,8 +936,11 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
getComponent().getBodyBlock()->begin());
} else {
std::string name = getState<ComponentLoweringState>().getUniqueName("cst");
auto floatAttr = cast<FloatAttr>(constOp.getValueAttr());
auto intType =
rewriter.getIntegerType(floatAttr.getType().getIntOrFloatBitWidth());
auto calyxConstOp = rewriter.create<calyx::ConstantOp>(
constOp.getLoc(), name, constOp.getValueAttr());
constOp.getLoc(), name, floatAttr, intType);
calyxConstOp->moveAfter(getComponent().getBodyBlock(),
getComponent().getBodyBlock()->begin());
rewriter.replaceAllUsesWith(constOp, calyxConstOp.getOut());
Expand Down Expand Up @@ -1215,6 +1219,7 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
else
resName = "out" + std::to_string(res.index());
funcOpResultMapping[res.index()] = outPorts.size();

outPorts.push_back(calyx::PortInfo{
rewriter.getStringAttr(resName),
calyx::convIndexType(rewriter, res.value()), calyx::Direction::Output,
Expand Down
15 changes: 9 additions & 6 deletions lib/Dialect/Calyx/CalyxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1980,10 +1980,13 @@ void ConstantOp::getAsmResultNames(

LogicalResult ConstantOp::verify() {
auto type = getType();
// The value's type must match the return type.
if (auto valType = getValue().getType(); valType != type) {
return emitOpError() << "value type " << valType
<< " must match return type: " << type;
assert(isa<IntegerType>(type) && "must be an IntegerType");
// The value's bit width must match the return type bitwidth.
if (auto valTyBitWidth = getValue().getType().getIntOrFloatBitWidth();
valTyBitWidth != type.getIntOrFloatBitWidth()) {
return emitOpError() << "value type bit width" << valTyBitWidth
<< " must match return type: "
<< type.getIntOrFloatBitWidth();
}
// Integer values must be signless.
if (llvm::isa<IntegerType>(type) &&
Expand All @@ -2002,12 +2005,12 @@ OpFoldResult calyx::ConstantOp::fold(FoldAdaptor adaptor) {
}

void calyx::ConstantOp::build(OpBuilder &builder, OperationState &state,
StringRef symName, TypedAttr attr) {
StringRef symName, Attribute attr, Type type) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(symName));
state.addAttribute("value", attr);
SmallVector<Type> types;
types.push_back(attr.getType()); // Out
types.push_back(type); // Out
state.addTypes(types);
}

Expand Down
8 changes: 0 additions & 8 deletions lib/Dialect/Calyx/Transforms/CalyxHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
return builder.create<RegisterOp>(loc, (prefix + "_reg").str(), width);
}

calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
ComponentOp component, Type type,
Twine prefix) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(component.getBodyBlock());
return builder.create<RegisterOp>(loc, (prefix + "_reg").str(), type);
}

hw::ConstantOp createConstant(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,
size_t value) {
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ Value getComponentOutput(calyx::ComponentOp compOp, unsigned outPortIdx) {
Type convIndexType(OpBuilder &builder, Type type) {
if (type.isIndex())
return builder.getI32Type();
if (type.isIntOrFloat() && !type.isInteger())
return builder.getIntegerType(type.getIntOrFloatBitWidth());
return type;
}

Expand Down Expand Up @@ -768,7 +770,7 @@ BuildReturnRegs::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
"unsupported return type");
std::string name = "ret_arg" + std::to_string(argType.index());
auto reg = createRegister(funcOp.getLoc(), rewriter, getComponent(),
convArgType, name);
convArgType.getIntOrFloatBitWidth(), name);
getState().addReturnReg(reg, argType.index());

rewriter.setInsertionPointToStart(
Expand Down
10 changes: 5 additions & 5 deletions test/Conversion/SCFToCalyx/convert_simple.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ module {
// Test integer and floating point constant

// CHECK: calyx.group @ret_assign_0 {
// CHECK-DAG: calyx.assign %ret_arg0_reg.in = %in0 : f32
// CHECK-DAG: calyx.assign %ret_arg0_reg.in = %in0 : i32
// CHECK-DAG: calyx.assign %ret_arg0_reg.write_en = %true : i1
// CHECK-DAG: calyx.assign %ret_arg1_reg.in = %c42_i32 : i32
// CHECK-DAG: calyx.assign %ret_arg1_reg.write_en = %true : i1
// CHECK-DAG: calyx.assign %ret_arg2_reg.in = %cst : f32
// CHECK-DAG: calyx.assign %ret_arg2_reg.in = %cst : i32
// CHECK-DAG: calyx.assign %ret_arg2_reg.write_en = %true : i1
// CHECK-DAG: %0 = comb.and %ret_arg2_reg.done, %ret_arg1_reg.done, %ret_arg0_reg.done : i1
// CHECK-DAG: calyx.group_done %0 ? %true : i1
Expand All @@ -239,9 +239,9 @@ module {
// Test floating point add

// CHECK: calyx.group @bb0_0 {
// CHECK-DAG: calyx.assign %std_addFN_0.left = %in0 : f32
// CHECK-DAG: calyx.assign %std_addFN_0.right = %cst : f32
// CHECK-DAG: calyx.assign %addf_0_reg.in = %std_addFN_0.out : f32
// CHECK-DAG: calyx.assign %std_addFN_0.left = %in0 : i32
// CHECK-DAG: calyx.assign %std_addFN_0.right = %cst : i32
// CHECK-DAG: calyx.assign %addf_0_reg.in = %std_addFN_0.out : i32
// CHECK-DAG: calyx.assign %addf_0_reg.write_en = %std_addFN_0.done : i1
// CHECK-DAG: %0 = comb.xor %std_addFN_0.done, %true : i1
// CHECK-DAG: calyx.assign %std_addFN_0.go = %0 ? %true : i1
Expand Down
30 changes: 15 additions & 15 deletions test/Dialect/Calyx/emit.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,15 @@ module attributes {calyx.entrypoint = "main"} {
// -----

module attributes {calyx.entrypoint = "main"} {
calyx.component @main(%clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: i32, %out1: f32, %done: i1 {done}) {
calyx.component @main(%clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: i32, %out1: i32, %done: i1 {done}) {
// CHECK: cst_0 = std_float_const(0, 32, 4.200000);
%c42_i32 = hw.constant 42 : i32
%cst = calyx.constant {sym_name = "cst_0"} 4.200000e+00 : f32
%cst = calyx.constant @cst_0 <4.200000e+00 : f32> : i32
%true = hw.constant true
%ret_arg1_reg.in, %ret_arg1_reg.write_en, %ret_arg1_reg.clk, %ret_arg1_reg.reset, %ret_arg1_reg.out, %ret_arg1_reg.done = calyx.register @ret_arg1_reg : f32, i1, i1, i1, f32, i1
%ret_arg1_reg.in, %ret_arg1_reg.write_en, %ret_arg1_reg.clk, %ret_arg1_reg.reset, %ret_arg1_reg.out, %ret_arg1_reg.done = calyx.register @ret_arg1_reg : i32, i1, i1, i1, i32, i1
%ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
calyx.wires {
calyx.assign %out1 = %ret_arg1_reg.out : f32
calyx.assign %out1 = %ret_arg1_reg.out : i32
calyx.assign %out0 = %ret_arg0_reg.out : i32

// CHECK-LABEL: group ret_assign_0 {
Expand All @@ -266,7 +266,7 @@ module attributes {calyx.entrypoint = "main"} {
calyx.group @ret_assign_0 {
calyx.assign %ret_arg0_reg.in = %c42_i32 : i32
calyx.assign %ret_arg0_reg.write_en = %true : i1
calyx.assign %ret_arg1_reg.in = %cst : f32
calyx.assign %ret_arg1_reg.in = %cst : i32
calyx.assign %ret_arg1_reg.write_en = %true : i1
%0 = comb.and %ret_arg1_reg.done, %ret_arg0_reg.done : i1
calyx.group_done %0 ? %true : i1
Expand All @@ -285,16 +285,16 @@ module attributes {calyx.entrypoint = "main"} {

module attributes {calyx.entrypoint = "main"} {
// CHECK: import "primitives/float/addFN.futil";
calyx.component @main(%in0: f32, %clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: f32, %done: i1 {done}) {
calyx.component @main(%in0: i32, %clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: i32, %done: i1 {done}) {
// CHECK: std_addFN_0 = std_addFN(8, 24, 32);
%cst = calyx.constant {sym_name = "cst_0"} 4.200000e+00 : f32
%cst = calyx.constant @cst_0 <4.200000e+00 : f32> : i32
%true = hw.constant true
%false = hw.constant false
%addf_0_reg.in, %addf_0_reg.write_en, %addf_0_reg.clk, %addf_0_reg.reset, %addf_0_reg.out, %addf_0_reg.done = calyx.register @addf_0_reg : f32, i1, i1, i1, f32, i1
%std_addFN_0.clk, %std_addFN_0.reset, %std_addFN_0.go, %std_addFN_0.control, %std_addFN_0.subOp, %std_addFN_0.left, %std_addFN_0.right, %std_addFN_0.roundingMode, %std_addFN_0.out, %std_addFN_0.exceptionalFlags, %std_addFN_0.done = calyx.std_addFN @std_addFN_0 : i1, i1, i1, i1, i1, f32, f32, i3, f32, i5, i1
%ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : f32, i1, i1, i1, f32, i1
%addf_0_reg.in, %addf_0_reg.write_en, %addf_0_reg.clk, %addf_0_reg.reset, %addf_0_reg.out, %addf_0_reg.done = calyx.register @addf_0_reg : i32, i1, i1, i1, i32, i1
%std_addFN_0.clk, %std_addFN_0.reset, %std_addFN_0.go, %std_addFN_0.control, %std_addFN_0.subOp, %std_addFN_0.left, %std_addFN_0.right, %std_addFN_0.roundingMode, %std_addFN_0.out, %std_addFN_0.exceptionalFlags, %std_addFN_0.done = calyx.std_addFN @std_addFN_0 : i1, i1, i1, i1, i1, i32, i32, i3, i32, i5, i1
%ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
calyx.wires {
calyx.assign %out0 = %ret_arg0_reg.out : f32
calyx.assign %out0 = %ret_arg0_reg.out : i32

// CHECK-LABEL: group bb0_0 {
// CHECK-NEXT: std_addFN_0.left = in0;
Expand All @@ -306,17 +306,17 @@ module attributes {calyx.entrypoint = "main"} {
// CHECK-NEXT: bb0_0[done] = addf_0_reg.done;
// CHECK-NEXT: }
calyx.group @bb0_0 {
calyx.assign %std_addFN_0.left = %in0 : f32
calyx.assign %std_addFN_0.right = %cst : f32
calyx.assign %addf_0_reg.in = %std_addFN_0.out : f32
calyx.assign %std_addFN_0.left = %in0 : i32
calyx.assign %std_addFN_0.right = %cst : i32
calyx.assign %addf_0_reg.in = %std_addFN_0.out : i32
calyx.assign %addf_0_reg.write_en = %std_addFN_0.done : i1
%0 = comb.xor %std_addFN_0.done, %true : i1
calyx.assign %std_addFN_0.go = %0 ? %true : i1
calyx.assign %std_addFN_0.subOp = %false : i1
calyx.group_done %addf_0_reg.done : i1
}
calyx.group @ret_assign_0 {
calyx.assign %ret_arg0_reg.in = %std_addFN_0.out : f32
calyx.assign %ret_arg0_reg.in = %std_addFN_0.out : i32
calyx.assign %ret_arg0_reg.write_en = %true : i1
calyx.group_done %ret_arg0_reg.done : i1
}
Expand Down
Loading