Skip to content

Commit

Permalink
support mulf op
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 committed Nov 3, 2024
1 parent 1dc931d commit 3d758ab
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 4 deletions.
34 changes: 34 additions & 0 deletions include/circt/Dialect/Calyx/CalyxPrimitives.td
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,40 @@ def AddFNOp : ArithBinaryFloatingPointLibraryOp<"addFN"> {
}];
}

def MulFNOp : ArithBinaryFloatingPointLibraryOp<"mulFN"> {
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control,
AnyFloat:$left, AnyFloat:$right, AnySignlessInteger:$roundingMode, AnyFloat:$out,
AnySignlessInteger:$exceptionalFlags, I1:$done);
let extraClassDefinition = [{
SmallVector<StringRef> $cppClass::portNames() {
return {clkPort, resetPort, goPort, "control", "left", "right",
"roundingMode", "out", "exceptionalFlags", donePort
};
}
SmallVector<Direction> $cppClass::portDirections() {
return {Input, Input, Input, Input, Input, Input, Input, Output, Output, Output};
}
void $cppClass::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
getCellAsmResultNames(setNameFn, *this, this->portNames());
}
bool $cppClass::isCombinational() { return false; }
SmallVector<DictionaryAttr> $cppClass::portAttributes() {
IntegerAttr isSet = IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
NamedAttrList go, clk, reset, done;
go.append(goPort, isSet);
clk.append(clkPort, isSet);
reset.append(resetPort, isSet);
done.append(donePort, isSet);
return {clk.getDictionary(getContext()), reset.getDictionary(getContext()),
go.getDictionary(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
done.getDictionary(getContext()), DictionaryAttr::get(getContext())
};
}
}];
}

def MuxLibOp : CalyxLibraryOp<"mux", [
Combinational, SameTypeConstraint<"tru", "fal">, SameTypeConstraint<"tru", "out">
]> {
Expand Down
20 changes: 18 additions & 2 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
/// floating point
AddFOp,
AddFOp, MulFOp,
/// others
SelectOp, IndexCastOp, CallOp>(
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
Expand Down Expand Up @@ -319,6 +319,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, RemUIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, MulFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const;
Expand Down Expand Up @@ -699,6 +700,21 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
addFN.getOut());
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
MulFOp mulf) const {
Location loc = mulf.getLoc();
Type width = mulf.getResult().getType();
IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
five = rewriter.getIntegerType(5);
auto mulFN =
getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::MulFNOp>(
rewriter, loc,
{one, one, one, one, width, width, three, width, five, one});
return buildLibraryBinaryPipeOp<calyx::MulFNOp>(rewriter, mulf, mulFN,
mulFN.getOut());
}

template <typename TAllocOp>
static LogicalResult buildAllocOp(ComponentLoweringState &componentState,
PatternRewriter &rewriter, TAllocOp allocOp) {
Expand Down Expand Up @@ -1911,7 +1927,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
ExtSIOp, CallOp, AddFOp>();
ExtSIOp, CallOp, AddFOp, MulFOp>();

RewritePatternSet legalizePatterns(&getContext());
legalizePatterns.add<DummyPattern>(&getContext());
Expand Down
7 changes: 6 additions & 1 deletion lib/Dialect/Calyx/Export/CalyxEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ struct ImportTracker {
static constexpr std::string_view sFloatingPoint = "float/addFN";
return {sFloatingPoint};
})
.Case<MulFNOp>([&](auto op) -> FailureOr<StringRef> {
static constexpr std::string_view sFloatingPoint = "float/mulFN";
return {sFloatingPoint};
})
.Default([&](auto op) {
auto diag = op->emitOpError() << "not supported for emission";
return diag;
Expand Down Expand Up @@ -675,7 +679,8 @@ void Emitter::emitComponent(ComponentInterface op) {
emitLibraryPrimTypedByFirstOutputPort(
op, /*calyxLibName=*/{"std_sdiv_pipe"});
})
.Case<AddFNOp>([&](auto op) { emitLibraryFloatingPoint(op); })
.Case<AddFNOp, MulFNOp>(
[&](auto op) { emitLibraryFloatingPoint(op); })
.Default([&](auto op) {
emitOpError(op, "not supported for emission inside component");
});
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,8 @@ void InlineCombGroups::recurseInlineCombGroups(
hw::ConstantOp, mlir::arith::ConstantOp, calyx::MultPipeLibOp,
calyx::DivUPipeLibOp, calyx::DivSPipeLibOp, calyx::RemSPipeLibOp,
calyx::RemUPipeLibOp, mlir::scf::WhileOp, calyx::InstanceOp,
calyx::ConstantOp, calyx::AddFNOp>(src.getDefiningOp()))
calyx::ConstantOp, calyx::AddFNOp, calyx::MulFNOp>(
src.getDefiningOp()))
continue;

auto srcCombGroup = dyn_cast<calyx::CombGroupOp>(
Expand Down
22 changes: 22 additions & 0 deletions test/Conversion/SCFToCalyx/convert_simple.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,25 @@ module {
return %1 : f32
}
}

// -----

// Test floating point mul

// CHECK: calyx.group @bb0_0 {
// CHECK-DAG: calyx.assign %std_mulFN_0.left = %in0 : f32
// CHECK-DAG: calyx.assign %std_mulFN_0.right = %cst : f32
// CHECK-DAG: calyx.assign %mulf_0_reg.in = %std_mulFN_0.out : f32
// CHECK-DAG: calyx.assign %mulf_0_reg.write_en = %std_mulFN_0.done : i1
// CHECK-DAG: %0 = comb.xor %std_mulFN_0.done, %true : i1
// CHECK-DAG: calyx.assign %std_mulFN_0.go = %0 ? %true : i1
// CHECK-DAG: calyx.group_done %mulf_0_reg.done : i1
// CHECK-DAG: }
module {
func.func @main(%arg0 : f32) -> f32 {
%0 = arith.constant 4.2 : f32
%1 = arith.mulf %arg0, %0 : f32

return %1 : f32
}
}
48 changes: 48 additions & 0 deletions test/Dialect/Calyx/emit.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -331,3 +331,51 @@ module attributes {calyx.entrypoint = "main"} {
}
} {toplevel}
}


// -----

module attributes {calyx.entrypoint = "main"} {
// CHECK: import "primitives/float/mulFN.futil";
calyx.component @main(%in0: f32, %clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: f32, %done: i1 {done}) {
// CHECK: std_mulFN_0 = std_mulFN(8, 24, 32);
%cst = calyx.constant {sym_name = "cst_0"} 4.200000e+00 : f32
%true = hw.constant true
%mulf_0_reg.in, %mulf_0_reg.write_en, %mulf_0_reg.clk, %mulf_0_reg.reset, %mulf_0_reg.out, %mulf_0_reg.done = calyx.register @mulf_0_reg : f32, i1, i1, i1, f32, i1
%std_mulFN_0.clk, %std_mulFN_0.reset, %std_mulFN_0.go, %std_mulFN_0.control, %std_mulFN_0.left, %std_mulFN_0.right, %std_mulFN_0.roundingMode, %std_mulFN_0.out, %std_mulFN_0.exceptionalFlags, %std_mulFN_0.done = calyx.std_mulFN @std_mulFN_0 : i1, i1, i1, i1, f32, f32, i3, f32, i5, i1
%ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : f32, i1, i1, i1, f32, i1
calyx.wires {
calyx.assign %out0 = %ret_arg0_reg.out : f32
// CHECK-LABEL: group bb0_0 {
// CHECK-NEXT: std_mulFN_0.left = in0;
// CHECK-NEXT: std_mulFN_0.right = cst_0.out;
// CHECK-NEXT: mulf_0_reg.in = std_mulFN_0.out;
// CHECK-NEXT: mulf_0_reg.write_en = std_mulFN_0.done;
// CHECK-NEXT: std_mulFN_0.go = !std_mulFN_0.done ? 1'd1;
// CHECK-NEXT: bb0_0[done] = mulf_0_reg.done;
// CHECK-NEXT: }
calyx.group @bb0_0 {
calyx.assign %std_mulFN_0.left = %in0 : f32
calyx.assign %std_mulFN_0.right = %cst : f32
calyx.assign %mulf_0_reg.in = %std_mulFN_0.out : f32
calyx.assign %mulf_0_reg.write_en = %std_mulFN_0.done : i1
%0 = comb.xor %std_mulFN_0.done, %true : i1
calyx.assign %std_mulFN_0.go = %0 ? %true : i1
calyx.group_done %mulf_0_reg.done : i1
}
calyx.group @ret_assign_0 {
calyx.assign %ret_arg0_reg.in = %std_mulFN_0.out : f32
calyx.assign %ret_arg0_reg.write_en = %true : i1
calyx.group_done %ret_arg0_reg.done : i1
}
}
calyx.control {
calyx.seq {
calyx.seq {
calyx.enable @bb0_0
calyx.enable @ret_assign_0
}
}
}
} {toplevel}
}

0 comments on commit 3d758ab

Please sign in to comment.