From 3d758ab214c16456bdf7fd155818e43be9328511 Mon Sep 17 00:00:00 2001 From: Jiahan Xie Date: Sun, 3 Nov 2024 10:38:17 -0500 Subject: [PATCH] support mulf op --- .../circt/Dialect/Calyx/CalyxPrimitives.td | 34 +++++++++++++ lib/Conversion/SCFToCalyx/SCFToCalyx.cpp | 20 +++++++- lib/Dialect/Calyx/Export/CalyxEmitter.cpp | 7 ++- .../Calyx/Transforms/CalyxLoweringUtils.cpp | 3 +- .../Conversion/SCFToCalyx/convert_simple.mlir | 22 +++++++++ test/Dialect/Calyx/emit.mlir | 48 +++++++++++++++++++ 6 files changed, 130 insertions(+), 4 deletions(-) diff --git a/include/circt/Dialect/Calyx/CalyxPrimitives.td b/include/circt/Dialect/Calyx/CalyxPrimitives.td index 1a9012b5796d..ae19d15d7c63 100644 --- a/include/circt/Dialect/Calyx/CalyxPrimitives.td +++ b/include/circt/Dialect/Calyx/CalyxPrimitives.td @@ -378,6 +378,40 @@ def AddFNOp : ArithBinaryFloatingPointLibraryOp<"addFN"> { }]; } +def MulFNOp : ArithBinaryFloatingPointLibraryOp<"mulFN"> { + let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, + AnyFloat:$left, AnyFloat:$right, AnySignlessInteger:$roundingMode, AnyFloat:$out, + AnySignlessInteger:$exceptionalFlags, I1:$done); + let extraClassDefinition = [{ + SmallVector $cppClass::portNames() { + return {clkPort, resetPort, goPort, "control", "left", "right", + "roundingMode", "out", "exceptionalFlags", donePort + }; + } + SmallVector $cppClass::portDirections() { + return {Input, Input, Input, Input, Input, Input, Input, Output, Output, Output}; + } + void $cppClass::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { + getCellAsmResultNames(setNameFn, *this, this->portNames()); + } + bool $cppClass::isCombinational() { return false; } + SmallVector $cppClass::portAttributes() { + IntegerAttr isSet = IntegerAttr::get(IntegerType::get(getContext(), 1), 1); + NamedAttrList go, clk, reset, done; + go.append(goPort, isSet); + clk.append(clkPort, isSet); + reset.append(resetPort, isSet); + done.append(donePort, isSet); + return {clk.getDictionary(getContext()), reset.getDictionary(getContext()), + go.getDictionary(getContext()), DictionaryAttr::get(getContext()), + DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()), + DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()), + done.getDictionary(getContext()), DictionaryAttr::get(getContext()) + }; + } + }]; +} + def MuxLibOp : CalyxLibraryOp<"mux", [ Combinational, SameTypeConstraint<"tru", "fal">, SameTypeConstraint<"tru", "out"> ]> { diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index e0b041509431..df32a4884f43 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -283,7 +283,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp, MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp, /// floating point - AddFOp, + AddFOp, MulFOp, /// others SelectOp, IndexCastOp, CallOp>( [&](auto op) { return buildOp(rewriter, op).succeeded(); }) @@ -319,6 +319,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { LogicalResult buildOp(PatternRewriter &rewriter, RemUIOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op) const; + LogicalResult buildOp(PatternRewriter &rewriter, MulFOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const; @@ -699,6 +700,21 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, addFN.getOut()); } +LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, + MulFOp mulf) const { + Location loc = mulf.getLoc(); + Type width = mulf.getResult().getType(); + IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3), + five = rewriter.getIntegerType(5); + auto mulFN = + getState() + .getNewLibraryOpInstance( + rewriter, loc, + {one, one, one, one, width, width, three, width, five, one}); + return buildLibraryBinaryPipeOp(rewriter, mulf, mulFN, + mulFN.getOut()); +} + template static LogicalResult buildAllocOp(ComponentLoweringState &componentState, PatternRewriter &rewriter, TAllocOp allocOp) { @@ -1911,7 +1927,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase { ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp, CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp, - ExtSIOp, CallOp, AddFOp>(); + ExtSIOp, CallOp, AddFOp, MulFOp>(); RewritePatternSet legalizePatterns(&getContext()); legalizePatterns.add(&getContext()); diff --git a/lib/Dialect/Calyx/Export/CalyxEmitter.cpp b/lib/Dialect/Calyx/Export/CalyxEmitter.cpp index b115e1250716..28805ddd8882 100644 --- a/lib/Dialect/Calyx/Export/CalyxEmitter.cpp +++ b/lib/Dialect/Calyx/Export/CalyxEmitter.cpp @@ -153,6 +153,10 @@ struct ImportTracker { static constexpr std::string_view sFloatingPoint = "float/addFN"; return {sFloatingPoint}; }) + .Case([&](auto op) -> FailureOr { + static constexpr std::string_view sFloatingPoint = "float/mulFN"; + return {sFloatingPoint}; + }) .Default([&](auto op) { auto diag = op->emitOpError() << "not supported for emission"; return diag; @@ -675,7 +679,8 @@ void Emitter::emitComponent(ComponentInterface op) { emitLibraryPrimTypedByFirstOutputPort( op, /*calyxLibName=*/{"std_sdiv_pipe"}); }) - .Case([&](auto op) { emitLibraryFloatingPoint(op); }) + .Case( + [&](auto op) { emitLibraryFloatingPoint(op); }) .Default([&](auto op) { emitOpError(op, "not supported for emission inside component"); }); diff --git a/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp b/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp index 08d4aa7d72f1..8503c8f5db7f 100644 --- a/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp +++ b/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp @@ -660,7 +660,8 @@ void InlineCombGroups::recurseInlineCombGroups( hw::ConstantOp, mlir::arith::ConstantOp, calyx::MultPipeLibOp, calyx::DivUPipeLibOp, calyx::DivSPipeLibOp, calyx::RemSPipeLibOp, calyx::RemUPipeLibOp, mlir::scf::WhileOp, calyx::InstanceOp, - calyx::ConstantOp, calyx::AddFNOp>(src.getDefiningOp())) + calyx::ConstantOp, calyx::AddFNOp, calyx::MulFNOp>( + src.getDefiningOp())) continue; auto srcCombGroup = dyn_cast( diff --git a/test/Conversion/SCFToCalyx/convert_simple.mlir b/test/Conversion/SCFToCalyx/convert_simple.mlir index 6c603fc95f3c..8579df308c97 100644 --- a/test/Conversion/SCFToCalyx/convert_simple.mlir +++ b/test/Conversion/SCFToCalyx/convert_simple.mlir @@ -257,3 +257,25 @@ module { return %1 : f32 } } + +// ----- + +// Test floating point mul + +// CHECK: calyx.group @bb0_0 { +// CHECK-DAG: calyx.assign %std_mulFN_0.left = %in0 : f32 +// CHECK-DAG: calyx.assign %std_mulFN_0.right = %cst : f32 +// CHECK-DAG: calyx.assign %mulf_0_reg.in = %std_mulFN_0.out : f32 +// CHECK-DAG: calyx.assign %mulf_0_reg.write_en = %std_mulFN_0.done : i1 +// CHECK-DAG: %0 = comb.xor %std_mulFN_0.done, %true : i1 +// CHECK-DAG: calyx.assign %std_mulFN_0.go = %0 ? %true : i1 +// CHECK-DAG: calyx.group_done %mulf_0_reg.done : i1 +// CHECK-DAG: } +module { + func.func @main(%arg0 : f32) -> f32 { + %0 = arith.constant 4.2 : f32 + %1 = arith.mulf %arg0, %0 : f32 + + return %1 : f32 + } +} diff --git a/test/Dialect/Calyx/emit.mlir b/test/Dialect/Calyx/emit.mlir index 13925b0b9033..ea0c0c2c3c45 100644 --- a/test/Dialect/Calyx/emit.mlir +++ b/test/Dialect/Calyx/emit.mlir @@ -331,3 +331,51 @@ module attributes {calyx.entrypoint = "main"} { } } {toplevel} } + + +// ----- + +module attributes {calyx.entrypoint = "main"} { + // CHECK: import "primitives/float/mulFN.futil"; + calyx.component @main(%in0: f32, %clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: f32, %done: i1 {done}) { + // CHECK: std_mulFN_0 = std_mulFN(8, 24, 32); + %cst = calyx.constant {sym_name = "cst_0"} 4.200000e+00 : f32 + %true = hw.constant true + %mulf_0_reg.in, %mulf_0_reg.write_en, %mulf_0_reg.clk, %mulf_0_reg.reset, %mulf_0_reg.out, %mulf_0_reg.done = calyx.register @mulf_0_reg : f32, i1, i1, i1, f32, i1 + %std_mulFN_0.clk, %std_mulFN_0.reset, %std_mulFN_0.go, %std_mulFN_0.control, %std_mulFN_0.left, %std_mulFN_0.right, %std_mulFN_0.roundingMode, %std_mulFN_0.out, %std_mulFN_0.exceptionalFlags, %std_mulFN_0.done = calyx.std_mulFN @std_mulFN_0 : i1, i1, i1, i1, f32, f32, i3, f32, i5, i1 + %ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : f32, i1, i1, i1, f32, i1 + calyx.wires { + calyx.assign %out0 = %ret_arg0_reg.out : f32 + // CHECK-LABEL: group bb0_0 { + // CHECK-NEXT: std_mulFN_0.left = in0; + // CHECK-NEXT: std_mulFN_0.right = cst_0.out; + // CHECK-NEXT: mulf_0_reg.in = std_mulFN_0.out; + // CHECK-NEXT: mulf_0_reg.write_en = std_mulFN_0.done; + // CHECK-NEXT: std_mulFN_0.go = !std_mulFN_0.done ? 1'd1; + // CHECK-NEXT: bb0_0[done] = mulf_0_reg.done; + // CHECK-NEXT: } + calyx.group @bb0_0 { + calyx.assign %std_mulFN_0.left = %in0 : f32 + calyx.assign %std_mulFN_0.right = %cst : f32 + calyx.assign %mulf_0_reg.in = %std_mulFN_0.out : f32 + calyx.assign %mulf_0_reg.write_en = %std_mulFN_0.done : i1 + %0 = comb.xor %std_mulFN_0.done, %true : i1 + calyx.assign %std_mulFN_0.go = %0 ? %true : i1 + calyx.group_done %mulf_0_reg.done : i1 + } + calyx.group @ret_assign_0 { + calyx.assign %ret_arg0_reg.in = %std_mulFN_0.out : f32 + calyx.assign %ret_arg0_reg.write_en = %true : i1 + calyx.group_done %ret_arg0_reg.done : i1 + } + } + calyx.control { + calyx.seq { + calyx.seq { + calyx.enable @bb0_0 + calyx.enable @ret_assign_0 + } + } + } + } {toplevel} +}