Skip to content

Commit b2aebf0

Browse files
authored
[Calyx] Binary Floating Point MulF Operator and FloatingPointInterface (#7769)
* support mulf op * introduce floating point interface and use sfinae for get the name of op instance * change Add/MulFN to Add/MulFOpIEEE754 for better naming standard
1 parent ea99ca4 commit b2aebf0

File tree

10 files changed

+239
-31
lines changed

10 files changed

+239
-31
lines changed

include/circt/Dialect/Calyx/CalyxInterfaces.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,22 @@ def IfOpInterface : OpInterface<"IfInterface"> {
337337

338338
let verify = [{ return verifyIf(op); }];
339339
}
340+
341+
def FloatingPointOpInterface: OpInterface<"FloatingPointOpInterface"> {
342+
let cppNamespace = "::circt::calyx";
343+
344+
let description = [{
345+
This is an op interface for Calyx floating point ops.
346+
}];
347+
348+
let methods = [
349+
StaticInterfaceMethod<
350+
"This returns the floating point standard.",
351+
"FloatingPointStandard",
352+
"getFloatingPointStandard">,
353+
StaticInterfaceMethod<
354+
"This returns the Calyx native library name.",
355+
"std::string",
356+
"getCalyxLibraryName">
357+
];
358+
}

include/circt/Dialect/Calyx/CalyxLoweringUtils.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,13 +417,39 @@ class ComponentLoweringStateInterface {
417417
}
418418
}
419419

420+
template <typename T, typename = void>
421+
struct IsFloatingPoint : std::false_type {};
422+
423+
template <typename T>
424+
struct IsFloatingPoint<
425+
T, std::void_t<decltype(std::declval<T>().getFloatingPointStandard())>>
426+
: std::is_same<decltype(std::declval<T>().getFloatingPointStandard()),
427+
FloatingPointStandard> {};
428+
420429
template <typename TLibraryOp>
421430
TLibraryOp getNewLibraryOpInstance(OpBuilder &builder, Location loc,
422431
TypeRange resTypes) {
423432
mlir::IRRewriter::InsertionGuard guard(builder);
424433
Block *body = component.getBodyBlock();
425434
builder.setInsertionPoint(body, body->begin());
426-
auto name = TLibraryOp::getOperationName().split(".").second;
435+
std::string name = TLibraryOp::getOperationName().split(".").second.str();
436+
if constexpr (IsFloatingPoint<TLibraryOp>::value) {
437+
switch (TLibraryOp::getFloatingPointStandard()) {
438+
case FloatingPointStandard::IEEE754: {
439+
constexpr char prefix[] = "ieee754.";
440+
assert(name.find(prefix) == 0 &&
441+
("IEEE754 type operation's name must begin with '" +
442+
std::string(prefix) + "'")
443+
.c_str());
444+
name.erase(0, sizeof(prefix) - 1);
445+
name = llvm::join_items(/*separator=*/"", "std_", name, "FN");
446+
break;
447+
}
448+
449+
default:
450+
llvm_unreachable("Unhandled floating point standard.");
451+
}
452+
}
427453
return builder.create<TLibraryOp>(loc, getUniqueName(name), resTypes);
428454
}
429455

include/circt/Dialect/Calyx/CalyxOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ class Combinational
6767
}
6868
};
6969

70+
enum class FloatingPointStandard {
71+
IEEE754,
72+
};
73+
7074
/// The direction of a Component or Cell port. this is similar to the
7175
/// implementation found in the FIRRTL dialect.
7276
enum Direction { Input = 0, Output = 1 };

include/circt/Dialect/Calyx/CalyxPrimitives.td

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,8 @@ def SeqMemoryOp : CalyxPrimitive<"seq_mem", []> {
257257
}];
258258
}
259259

260-
class CalyxLibraryOp<string mnemonic, list<Trait> traits = []> :
261-
CalyxPrimitive<"std_" # mnemonic, traits> {
260+
class CalyxLibraryOp<string mnemonic, string prefix = "std_", list<Trait> traits = []> :
261+
CalyxPrimitive<prefix # mnemonic, traits> {
262262

263263
let summary = "Defines an operation which maps to a Calyx library primitive";
264264
let description = [{
@@ -289,7 +289,7 @@ class CalyxLibraryOp<string mnemonic, list<Trait> traits = []> :
289289
];
290290
}
291291

292-
class BoolBinaryLibraryOp<string mnemonic> : CalyxLibraryOp<mnemonic, [
292+
class BoolBinaryLibraryOp<string mnemonic> : CalyxLibraryOp<mnemonic, "std_", [
293293
Combinational,
294294
SameTypeConstraint<"left", "right">
295295
]> {
@@ -309,13 +309,13 @@ def SneqLibOp : BoolBinaryLibraryOp<"sneq"> {}
309309
def SgeLibOp : BoolBinaryLibraryOp<"sge"> {}
310310
def SleLibOp : BoolBinaryLibraryOp<"sle"> {}
311311

312-
class ArithBinaryLibraryOp<string mnemonic, list<Trait> traits = []> :
313-
CalyxLibraryOp<mnemonic, !listconcat(traits, [
312+
class ArithBinaryLibraryOp<string mnemonic, string prefix, list<Trait> traits = []> :
313+
CalyxLibraryOp<mnemonic, prefix, !listconcat(traits, [
314314
SameTypeConstraint<"left", "right">
315315
])> {}
316316

317317
class CombinationalArithBinaryLibraryOp<string mnemonic> :
318-
ArithBinaryLibraryOp<mnemonic, [
318+
ArithBinaryLibraryOp<mnemonic, "std_", [
319319
Combinational,
320320
SameTypeConstraint<"left", "out">
321321
]> {
@@ -332,10 +332,13 @@ def AndLibOp : CombinationalArithBinaryLibraryOp<"and"> {}
332332
def OrLibOp : CombinationalArithBinaryLibraryOp<"or"> {}
333333
def XorLibOp : CombinationalArithBinaryLibraryOp<"xor"> {}
334334

335-
class ArithBinaryFloatingPointLibraryOp<string mnemonic> : ArithBinaryLibraryOp<mnemonic, [
336-
SameTypeConstraint<"left", "out">]> {}
335+
class ArithBinaryFloatingPointLibraryOp<string mnemonic> :
336+
ArithBinaryLibraryOp<mnemonic, "", [
337+
DeclareOpInterfaceMethods<FloatingPointOpInterface>,
338+
SameTypeConstraint<"left", "out">
339+
]> {}
337340

338-
def AddFNOp : ArithBinaryFloatingPointLibraryOp<"addFN"> {
341+
def AddFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.add"> {
339342
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, I1:$subOp,
340343
AnySignlessInteger:$left, AnySignlessInteger:$right, AnySignlessInteger:$roundingMode, AnySignlessInteger:$out,
341344
AnySignlessInteger:$exceptionalFlags, I1:$done);
@@ -375,7 +378,42 @@ def AddFNOp : ArithBinaryFloatingPointLibraryOp<"addFN"> {
375378
}];
376379
}
377380

378-
def MuxLibOp : CalyxLibraryOp<"mux", [
381+
def MulFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.mul"> {
382+
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control,
383+
AnySignlessInteger:$left, AnySignlessInteger:$right, AnySignlessInteger:$roundingMode, AnySignlessInteger:$out,
384+
AnySignlessInteger:$exceptionalFlags, I1:$done);
385+
let assemblyFormat = "$sym_name attr-dict `:` qualified(type(results))";
386+
let extraClassDefinition = [{
387+
SmallVector<StringRef> $cppClass::portNames() {
388+
return {clkPort, resetPort, goPort, "control", "left", "right",
389+
"roundingMode", "out", "exceptionalFlags", donePort
390+
};
391+
}
392+
SmallVector<Direction> $cppClass::portDirections() {
393+
return {Input, Input, Input, Input, Input, Input, Input, Output, Output, Output};
394+
}
395+
void $cppClass::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
396+
getCellAsmResultNames(setNameFn, *this, this->portNames());
397+
}
398+
bool $cppClass::isCombinational() { return false; }
399+
SmallVector<DictionaryAttr> $cppClass::portAttributes() {
400+
IntegerAttr isSet = IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
401+
NamedAttrList go, clk, reset, done;
402+
go.append(goPort, isSet);
403+
clk.append(clkPort, isSet);
404+
reset.append(resetPort, isSet);
405+
done.append(donePort, isSet);
406+
return {clk.getDictionary(getContext()), reset.getDictionary(getContext()),
407+
go.getDictionary(getContext()), DictionaryAttr::get(getContext()),
408+
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
409+
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
410+
done.getDictionary(getContext()), DictionaryAttr::get(getContext())
411+
};
412+
}
413+
}];
414+
}
415+
416+
def MuxLibOp : CalyxLibraryOp<"mux", "std_", [
379417
Combinational, SameTypeConstraint<"tru", "fal">, SameTypeConstraint<"tru", "out">
380418
]> {
381419
let results = (outs I1:$cond, AnyType:$tru, AnyType:$fal, AnyType:$out);
@@ -397,7 +435,7 @@ def MuxLibOp : CalyxLibraryOp<"mux", [
397435
}];
398436
}
399437

400-
class ArithBinaryPipeLibraryOp<string mnemonic> : ArithBinaryLibraryOp<mnemonic # "_pipe", [
438+
class ArithBinaryPipeLibraryOp<string mnemonic> : ArithBinaryLibraryOp<mnemonic # "_pipe", "std_", [
401439
SameTypeConstraint<"left", "out">
402440
]> {
403441
let results = (outs I1:$clk, I1:$reset, I1:$go, AnyType:$left, AnyType:$right, AnyType:$out, I1:$done);
@@ -410,7 +448,7 @@ def RemUPipeLibOp : ArithBinaryPipeLibraryOp<"remu"> {}
410448
def RemSPipeLibOp : ArithBinaryPipeLibraryOp<"rems"> {}
411449

412450
class UnaryLibraryOp<string mnemonic, list<Trait> traits = []> :
413-
CalyxLibraryOp<mnemonic, !listconcat(traits, [Combinational])> {
451+
CalyxLibraryOp<mnemonic, "std_", !listconcat(traits, [Combinational])> {
414452
let results = (outs AnyInteger:$in, AnyInteger:$out);
415453
}
416454

lib/Conversion/SCFToCalyx/SCFToCalyx.cpp

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
289289
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
290290
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
291291
/// floating point
292-
AddFOp,
292+
AddFOp, MulFOp,
293293
/// others
294294
SelectOp, IndexCastOp, CallOp>(
295295
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
@@ -325,6 +325,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
325325
LogicalResult buildOp(PatternRewriter &rewriter, RemUIOp op) const;
326326
LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op) const;
327327
LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op) const;
328+
LogicalResult buildOp(PatternRewriter &rewriter, MulFOp op) const;
328329
LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const;
329330
LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const;
330331
LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const;
@@ -449,8 +450,8 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
449450
// The group is done when the register write is complete.
450451
rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());
451452

452-
if (isa<calyx::AddFNOp>(opPipe)) {
453-
auto opFN = cast<calyx::AddFNOp>(opPipe);
453+
if (isa<calyx::AddFOpIEEE754>(opPipe)) {
454+
auto opFOp = cast<calyx::AddFOpIEEE754>(opPipe);
454455
hw::ConstantOp subOp;
455456
if (isa<arith::AddFOp>(op)) {
456457
subOp = createConstant(loc, rewriter, getComponent(), /*width=*/1,
@@ -459,7 +460,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
459460
subOp = createConstant(loc, rewriter, getComponent(), /*width=*/1,
460461
/*subtract=*/1);
461462
}
462-
rewriter.create<calyx::AssignOp>(loc, opFN.getSubOp(), subOp);
463+
rewriter.create<calyx::AssignOp>(loc, opFOp.getSubOp(), subOp);
463464
}
464465

465466
// Register the values for the pipeline.
@@ -701,13 +702,29 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
701702
five = rewriter.getIntegerType(5),
702703
width = rewriter.getIntegerType(
703704
addf.getType().getIntOrFloatBitWidth());
704-
auto addFN =
705+
auto addFOp =
705706
getState<ComponentLoweringState>()
706-
.getNewLibraryOpInstance<calyx::AddFNOp>(
707+
.getNewLibraryOpInstance<calyx::AddFOpIEEE754>(
707708
rewriter, loc,
708709
{one, one, one, one, one, width, width, three, width, five, one});
709-
return buildLibraryBinaryPipeOp<calyx::AddFNOp>(rewriter, addf, addFN,
710-
addFN.getOut());
710+
return buildLibraryBinaryPipeOp<calyx::AddFOpIEEE754>(rewriter, addf, addFOp,
711+
addFOp.getOut());
712+
}
713+
714+
LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
715+
MulFOp mulf) const {
716+
Location loc = mulf.getLoc();
717+
IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
718+
five = rewriter.getIntegerType(5),
719+
width = rewriter.getIntegerType(
720+
mulf.getType().getIntOrFloatBitWidth());
721+
auto mulFOp =
722+
getState<ComponentLoweringState>()
723+
.getNewLibraryOpInstance<calyx::MulFOpIEEE754>(
724+
rewriter, loc,
725+
{one, one, one, one, width, width, three, width, five, one});
726+
return buildLibraryBinaryPipeOp<calyx::MulFOpIEEE754>(rewriter, mulf, mulFOp,
727+
mulFOp.getOut());
711728
}
712729

713730
template <typename TAllocOp>
@@ -2094,7 +2111,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
20942111
ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
20952112
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
20962113
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
2097-
ExtSIOp, CallOp, AddFOp>();
2114+
ExtSIOp, CallOp, AddFOp, MulFOp>();
20982115

20992116
RewritePatternSet legalizePatterns(&getContext());
21002117
legalizePatterns.add<DummyPattern>(&getContext());

lib/Dialect/Calyx/CalyxOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,21 @@ uint32_t CycleOp::getGroupLatency() {
11981198
return group.getLatency();
11991199
}
12001200

1201+
//===----------------------------------------------------------------------===//
1202+
// Floating Point Op
1203+
//===----------------------------------------------------------------------===//
1204+
FloatingPointStandard AddFOpIEEE754::getFloatingPointStandard() {
1205+
return FloatingPointStandard::IEEE754;
1206+
}
1207+
1208+
FloatingPointStandard MulFOpIEEE754::getFloatingPointStandard() {
1209+
return FloatingPointStandard::IEEE754;
1210+
}
1211+
1212+
std::string AddFOpIEEE754::getCalyxLibraryName() { return "std_addFN"; }
1213+
1214+
std::string MulFOpIEEE754::getCalyxLibraryName() { return "std_mulFN"; }
1215+
12011216
//===----------------------------------------------------------------------===//
12021217
// GroupInterface
12031218
//===----------------------------------------------------------------------===//

lib/Dialect/Calyx/Export/CalyxEmitter.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,14 @@ struct ImportTracker {
149149
static constexpr std::string_view sFloat = "float";
150150
return {sFloat};
151151
})
152-
.Case<AddFNOp>([&](auto op) -> FailureOr<StringRef> {
152+
.Case<AddFOpIEEE754>([&](auto op) -> FailureOr<StringRef> {
153153
static constexpr std::string_view sFloatingPoint = "float/addFN";
154154
return {sFloatingPoint};
155155
})
156+
.Case<MulFOpIEEE754>([&](auto op) -> FailureOr<StringRef> {
157+
static constexpr std::string_view sFloatingPoint = "float/mulFN";
158+
return {sFloatingPoint};
159+
})
156160
.Default([&](auto op) {
157161
auto diag = op->emitOpError() << "not supported for emission";
158162
return diag;
@@ -675,7 +679,8 @@ void Emitter::emitComponent(ComponentInterface op) {
675679
emitLibraryPrimTypedByFirstOutputPort(
676680
op, /*calyxLibName=*/{"std_sdiv_pipe"});
677681
})
678-
.Case<AddFNOp>([&](auto op) { emitLibraryFloatingPoint(op); })
682+
.Case<AddFOpIEEE754, MulFOpIEEE754>(
683+
[&](auto op) { emitLibraryFloatingPoint(op); })
679684
.Default([&](auto op) {
680685
emitOpError(op, "not supported for emission inside component");
681686
});
@@ -1019,11 +1024,14 @@ void Emitter::emitLibraryFloatingPoint(Operation *op) {
10191024
return;
10201025
}
10211026

1022-
StringRef opName = op->getName().getStringRef();
1027+
std::string opName;
1028+
if (auto fpOp = dyn_cast<calyx::FloatingPointOpInterface>(op)) {
1029+
opName = fpOp.getCalyxLibraryName();
1030+
}
10231031
indent() << getAttributes(op, /*atFormat=*/true) << cell.instanceName()
1024-
<< space() << equals() << space() << removeCalyxPrefix(opName)
1025-
<< LParen() << expWidth << comma() << sigWidth << comma() << bitWidth
1026-
<< RParen() << semicolonEndL();
1032+
<< space() << equals() << space() << opName << LParen() << expWidth
1033+
<< comma() << sigWidth << comma() << bitWidth << RParen()
1034+
<< semicolonEndL();
10271035
}
10281036

10291037
void Emitter::emitAssignment(AssignOp op) {

lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,8 @@ void InlineCombGroups::recurseInlineCombGroups(
673673
hw::ConstantOp, mlir::arith::ConstantOp, calyx::MultPipeLibOp,
674674
calyx::DivUPipeLibOp, calyx::DivSPipeLibOp, calyx::RemSPipeLibOp,
675675
calyx::RemUPipeLibOp, mlir::scf::WhileOp, calyx::InstanceOp,
676-
calyx::ConstantOp, calyx::AddFNOp>(src.getDefiningOp()))
676+
calyx::ConstantOp, calyx::AddFOpIEEE754, calyx::MulFOpIEEE754>(
677+
src.getDefiningOp()))
677678
continue;
678679

679680
auto srcCombGroup = dyn_cast<calyx::CombGroupOp>(

test/Conversion/SCFToCalyx/convert_simple.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,33 @@ module {
260260

261261
// -----
262262

263+
// Test floating point mul
264+
265+
// CHECK: %cst = calyx.constant @cst_0 <4.200000e+00 : f32> : i32
266+
// CHECK-DAG: %true = hw.constant true
267+
// CHECK-DAG: %mulf_0_reg.in, %mulf_0_reg.write_en, %mulf_0_reg.clk, %mulf_0_reg.reset, %mulf_0_reg.out, %mulf_0_reg.done = calyx.register @mulf_0_reg : i32, i1, i1, i1, i32, i1
268+
// CHECK-DAG: %std_mulFN_0.clk, %std_mulFN_0.reset, %std_mulFN_0.go, %std_mulFN_0.control, %std_mulFN_0.left, %std_mulFN_0.right, %std_mulFN_0.roundingMode, %std_mulFN_0.out, %std_mulFN_0.exceptionalFlags, %std_mulFN_0.done = calyx.ieee754.mul @std_mulFN_0 : i1, i1, i1, i1, i32, i32, i3, i32, i5, i1
269+
// CHECK-DAG: %ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
270+
// CHECK: calyx.group @bb0_0 {
271+
// CHECK-DAG: calyx.assign %std_mulFN_0.left = %in0 : i32
272+
// CHECK-DAG: calyx.assign %std_mulFN_0.right = %cst : i32
273+
// CHECK-DAG: calyx.assign %mulf_0_reg.in = %std_mulFN_0.out : i32
274+
// CHECK-DAG: calyx.assign %mulf_0_reg.write_en = %std_mulFN_0.done : i1
275+
// CHECK-DAG: %0 = comb.xor %std_mulFN_0.done, %true : i1
276+
// CHECK-DAG: calyx.assign %std_mulFN_0.go = %0 ? %true : i1
277+
// CHECK-DAG: calyx.group_done %mulf_0_reg.done : i1
278+
// CHECK-DAG: }
279+
module {
280+
func.func @main(%arg0 : f32) -> f32 {
281+
%0 = arith.constant 4.2 : f32
282+
%1 = arith.mulf %arg0, %0 : f32
283+
284+
return %1 : f32
285+
}
286+
}
287+
288+
// -----
289+
263290
// Test parallel op lowering
264291

265292
// CHECK: calyx.wires {

0 commit comments

Comments
 (0)