diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index 49e5b1ec9234..5371fd967ce8 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -182,7 +182,9 @@ def SCFToCalyx : Pass<"lower-scf-to-calyx", "mlir::ModuleOp"> { "Identifier of top-level function to be the entry-point component" " of the Calyx program.">, Option<"ciderSourceLocationMetadata", "cider-source-location-metadata", "bool", "", - "Whether to track source location for the Cider debugger."> + "Whether to track source location for the Cider debugger.">, + Option<"writeJsonOpt", "write-json", "std::string", "", + "Whether to write memory contents to the json file."> ]; } diff --git a/include/circt/Dialect/Calyx/CalyxLoweringUtils.h b/include/circt/Dialect/Calyx/CalyxLoweringUtils.h index de1aaf086c6f..e2694b11def6 100644 --- a/include/circt/Dialect/Calyx/CalyxLoweringUtils.h +++ b/include/circt/Dialect/Calyx/CalyxLoweringUtils.h @@ -27,6 +27,7 @@ #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/JSON.h" #include @@ -450,6 +451,38 @@ class ComponentLoweringStateInterface { return builder.create(loc, getUniqueName(name), resTypes); } + llvm::json::Value &getExtMemData() { return extMemData; } + + const llvm::json::Value &getExtMemData() const { return extMemData; } + + void setDataField(StringRef name, llvm::json::Array data) { + auto *extMemDataObj = extMemData.getAsObject(); + assert(extMemDataObj && "extMemData should be an object"); + + auto &value = (*extMemDataObj)[name.str()]; + llvm::json::Object *obj = value.getAsObject(); + if (!obj) { + value = llvm::json::Object{}; + obj = value.getAsObject(); + } + (*obj)["data"] = llvm::json::Value(std::move(data)); + } + + void setFormat(StringRef name, std::string numType, bool isSigned, + unsigned width) { + auto *extMemDataObj = extMemData.getAsObject(); + assert(extMemDataObj && "extMemData should be an object"); + + auto &value = (*extMemDataObj)[name.str()]; + llvm::json::Object *obj = value.getAsObject(); + if (!obj) { + value = llvm::json::Object{}; + obj = value.getAsObject(); + } + (*obj)["format"] = llvm::json::Object{ + {"numeric_type", numType}, {"is_signed", isSigned}, {"width", width}}; + } + private: /// The component which this lowering state is associated to. calyx::ComponentOp component; @@ -486,6 +519,10 @@ class ComponentLoweringStateInterface { /// A mapping between the callee and the instance. llvm::StringMap instanceMap; + + /// A json file to store external global memory data. See + /// https://docs.calyxir.org/lang/data-format.html?highlight=json#the-data-format + llvm::json::Value extMemData; }; /// An interface for conversion passes that lower Calyx programs. This handles diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index eba68cfc092b..a19029255843 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -30,7 +30,14 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/LogicalResult.h" - +#include "llvm/Support/raw_os_ostream.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include + +#include +#include #include namespace circt { @@ -266,6 +273,14 @@ class ComponentLoweringState : public calyx::ComponentLoweringStateInterface, /// Iterate through the operations of a source function and instantiate /// components or primitives based on the type of the operations. class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { +public: + BuildOpGroups(MLIRContext *context, LogicalResult &resRef, + calyx::PatternApplicationState &patternState, + DenseMap &map, + calyx::CalyxLoweringState &state, + mlir::Pass::Option &writeJsonOpt) + : FuncOpPartialLoweringPattern(context, resRef, patternState, map, state), + writeJson(writeJsonOpt) {} using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern; LogicalResult @@ -283,7 +298,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { scf::ParallelOp, scf::ReduceOp, /// memref memref::AllocOp, memref::AllocaOp, memref::LoadOp, - memref::StoreOp, + memref::StoreOp, memref::GetGlobalOp, /// standard arithmetic AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp, @@ -306,10 +321,32 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { : WalkResult::interrupt(); }); + if (!writeJson.empty()) { + if (auto fileLoc = dyn_cast(funcOp->getLoc())) { + std::string filename = fileLoc.getFilename().str(); + std::filesystem::path path(filename); + std::string jsonFileName = writeJson.append(".json"); + auto outFileName = path.parent_path().append(jsonFileName); + std::ofstream outFile(outFileName); + + if (!outFile.is_open()) { + llvm::errs() << "Unable to open file: " << outFileName + << " for writing\n"; + return failure(); + } + llvm::raw_os_ostream llvmOut(outFile); + llvm::json::OStream jsonOS(llvmOut, 2); + jsonOS.value(getState().getExtMemData()); + jsonOS.flush(); + outFile.close(); + } + } + return success(opBuiltSuccessfully); } private: + mlir::Pass::Option &writeJson; /// Op builder specializations. LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp) const; LogicalResult buildOp(PatternRewriter &rewriter, @@ -341,6 +378,8 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocaOp op) const; + LogicalResult buildOp(PatternRewriter &rewriter, + memref::GetGlobalOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const; LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp) const; @@ -962,6 +1001,82 @@ static LogicalResult buildAllocOp(ComponentLoweringState &componentState, IntegerAttr::get(rewriter.getI1Type(), llvm::APInt(1, 1))); componentState.registerMemoryInterface(allocOp.getResult(), calyx::MemoryInterface(memoryOp)); + + unsigned elmTyBitWidth = memtype.getElementTypeBitWidth(); + assert(elmTyBitWidth <= 64 && "element bitwidth should not exceed 64"); + bool isFloat = !memtype.getElementType().isInteger(); + + auto shape = allocOp.getType().getShape(); + int totalSize = + std::reduce(shape.begin(), shape.end(), 1, std::multiplies()); + // The `totalSize <= 1` check is a hack to: + // https://github.com/llvm/circt/pull/2661, where a multi-dimensional memory + // whose size in some dimension equals 1, e.g. memref<1x1x1x1xi32>, will be + // collapsed to `memref<1xi32>` with `totalSize == 1`. While the above case is + // a trivial fix, Calyx expects 1-dimensional memories in general: + // https://github.com/calyxir/calyx/issues/907 + if (!(shape.size() <= 1 || totalSize <= 1)) { + allocOp.emitError("input memory dimension must be empty or one."); + return failure(); + } + + std::vector flattenedVals(totalSize, 0); + if (isa(allocOp)) { + auto getGlobalOp = cast(allocOp); + auto *symbolTableOp = + getGlobalOp->template getParentWithTrait(); + auto globalOp = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr())); + // Flatten the values in the attribute + auto cstAttr = llvm::dyn_cast_or_null( + globalOp.getConstantInitValue()); + int sizeCount = 0; + for (auto attr : cstAttr.template getValues()) { + assert((isa(attr)) && + "memory attributes must be float or int"); + if (auto fltAttr = dyn_cast(attr)) { + flattenedVals[sizeCount++] = + bit_cast(fltAttr.getValueAsDouble()); + } else { + auto intAttr = dyn_cast(attr); + APInt value = intAttr.getValue(); + flattenedVals[sizeCount++] = *value.getRawData(); + } + } + + rewriter.eraseOp(globalOp); + } + + llvm::json::Array result; + result.reserve(std::max(static_cast(shape.size()), 1)); + + Type elemType = memtype.getElementType(); + bool isSigned = + !elemType.isSignlessInteger() && !elemType.isUnsignedInteger(); + for (uint64_t bitValue : flattenedVals) { + llvm::json::Value value = 0; + if (isFloat) { + // We cast to `double` and let downstream calyx to deal with the actual + // value's precision handling. + value = bit_cast(bitValue); + } else { + APInt apInt(/*numBits=*/elmTyBitWidth, bitValue, isSigned); + // The conditional ternary operation will cause the `value` to interpret + // the underlying data as unsigned regardless `isSigned` or not. + if (isSigned) + value = static_cast(apInt.getSExtValue()); + else + value = apInt.getZExtValue(); + } + result.push_back(std::move(value)); + } + + componentState.setDataField(memoryOp.getName(), result); + std::string numType = + memtype.getElementType().isInteger() ? "bitnum" : "ieee754_float"; + componentState.setFormat(memoryOp.getName(), numType, isSigned, + elmTyBitWidth); + return success(); } @@ -975,6 +1090,12 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, return buildAllocOp(getState(), rewriter, allocOp); } +LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, + memref::GetGlobalOp getGlobalOp) const { + return buildAllocOp(getState(), rewriter, + getGlobalOp); +} + LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp) const { if (yieldOp.getOperands().empty()) { @@ -2644,7 +2765,7 @@ void SCFToCalyxPass::runOnOperation() { /// having a distinct group for each operation, groups are analogous to SSA /// values in the source program. addOncePattern(loweringPatterns, patternState, funcMap, - *loweringState); + *loweringState, writeJsonOpt); /// This pattern traverses the CFG of the program and generates a control /// schedule based on the calyx::GroupOp's which were registered for each diff --git a/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp b/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp index 4423feb91b28..7c8055f215d4 100644 --- a/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp +++ b/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp @@ -318,7 +318,7 @@ BasicLoopInterface::~BasicLoopInterface() = default; ComponentLoweringStateInterface::ComponentLoweringStateInterface( calyx::ComponentOp component) - : component(component) {} + : component(component), extMemData(llvm::json::Object{}) {} ComponentLoweringStateInterface::~ComponentLoweringStateInterface() = default; diff --git a/test/Conversion/SCFToCalyx/write_memory.mlir b/test/Conversion/SCFToCalyx/write_memory.mlir new file mode 100644 index 000000000000..54104f52687d --- /dev/null +++ b/test/Conversion/SCFToCalyx/write_memory.mlir @@ -0,0 +1,93 @@ +// RUN: circt-opt %s --lower-scf-to-calyx="write-json=data" -canonicalize>/dev/null && cat $(dirname %s)/data.json | FileCheck %s + +// CHECK-LABEL: "mem_0": { +// CHECK: "data": [ +// CHECK: 0, +// CHECK: 0, +// CHECK: 0, +// CHECK: 0 +// CHECK: ], +// CHECK: "format": { +// CHECK: "is_signed": true, +// CHECK: "numeric_type": "ieee754_float", +// CHECK: "width": 32 +// CHECK: } +// CHECK: }, + +// CHECK-LABEL: "mem_1": { +// CHECK: "data": [ +// CHECK: 0 +// CHECK: ], +// CHECK: "format": { +// CHECK: "is_signed": true, +// CHECK: "numeric_type": "bitnum", +// CHECK: "width": 8 +// CHECK: } +// CHECK: }, + +// CHECK-LABEL: "mem_2": { +// CHECK: "data": [ +// CHECK: 43, +// CHECK: 8, +// CHECK: 4294967257, +// CHECK: 4294967277, +// CHECK: 70, +// CHECK: 4294967232, +// CHECK: 4294967289, +// CHECK: 4294967269, +// CHECK: 4294967239, +// CHECK: 5 +// CHECK: ], +// CHECK: "format": { +// CHECK: "is_signed": false, +// CHECK: "numeric_type": "bitnum", +// CHECK: "width": 32 +// CHECK: } +// CHECK: }, + +// CHECK-LABEL: "mem_3": { +// CHECK: "data": [ +// CHECK: 0.69999998807907104, +// CHECK: -4.1999998092651367, +// CHECK: 0 +// CHECK: ], +// CHECK: "format": { +// CHECK: "is_signed": true, +// CHECK: "numeric_type": "ieee754_float", +// CHECK: "width": 32 +// CHECK: } +// CHECK: }, + +// CHECK-LABEL: "mem_4": { +// CHECK: "data": [ +// CHECK: -42, +// CHECK: 35 +// CHECK: ], +// CHECK: "format": { +// CHECK: "is_signed": true, +// CHECK: "numeric_type": "bitnum", +// CHECK: "width": 8 +// CHECK: } +// CHECK: } + +module { + memref.global "private" constant @constant_10xi32_0 : memref<10xi32> = dense<[43, 8, -39, -19, 70, -64, -7, -27, -57, 5]> + memref.global "private" constant @constant_2xsi8_0 : memref<2xsi8> = dense<[-42, 35]> + memref.global "private" constant @constant_3xf32 : memref<3xf32> = dense<[0.7, -4.2, 0.0]> + func.func @main(%arg_idx : index) -> i32 { + %alloc = memref.alloc() : memref<4xf32> + %zero_dim_mem = memref.alloca() : memref + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %0 = memref.get_global @constant_10xi32_0 : memref<10xi32> + %ret = memref.load %0[%arg_idx] : memref<10xi32> + %1 = memref.get_global @constant_3xf32 : memref<3xf32> + %2 = memref.load %1[%c1] : memref<3xf32> + memref.store %2, %alloc[%c2] : memref<4xf32> + %3 = memref.get_global @constant_2xsi8_0 : memref<2xsi8> + %4 = memref.load %3[%c1] : memref<2xsi8> + memref.store %4, %zero_dim_mem[] : memref + return %ret : i32 + } +} +