Skip to content

Commit

Permalink
[SCFToCalyx] Lower MemRef GetGlobal and write memory data to json fil…
Browse files Browse the repository at this point in the history
…es (#7301)

* build getglobal as an alloc and write external global memory data to json; we write zero's for alloc and alloca operations
  • Loading branch information
jiahanxie353 authored Dec 5, 2024
1 parent e9150ac commit baba42c
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 5 deletions.
4 changes: 3 additions & 1 deletion include/circt/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def SCFToCalyx : Pass<"lower-scf-to-calyx", "mlir::ModuleOp"> {
"Identifier of top-level function to be the entry-point component"
" of the Calyx program.">,
Option<"ciderSourceLocationMetadata", "cider-source-location-metadata", "bool", "",
"Whether to track source location for the Cider debugger.">
"Whether to track source location for the Cider debugger.">,
Option<"writeJsonOpt", "write-json", "std::string", "",
"Whether to write memory contents to the json file.">
];
}

Expand Down
37 changes: 37 additions & 0 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/JSON.h"

#include <variant>

Expand Down Expand Up @@ -450,6 +451,38 @@ class ComponentLoweringStateInterface {
return builder.create<TLibraryOp>(loc, getUniqueName(name), resTypes);
}

llvm::json::Value &getExtMemData() { return extMemData; }

const llvm::json::Value &getExtMemData() const { return extMemData; }

void setDataField(StringRef name, llvm::json::Array data) {
auto *extMemDataObj = extMemData.getAsObject();
assert(extMemDataObj && "extMemData should be an object");

auto &value = (*extMemDataObj)[name.str()];
llvm::json::Object *obj = value.getAsObject();
if (!obj) {
value = llvm::json::Object{};
obj = value.getAsObject();
}
(*obj)["data"] = llvm::json::Value(std::move(data));
}

void setFormat(StringRef name, std::string numType, bool isSigned,
unsigned width) {
auto *extMemDataObj = extMemData.getAsObject();
assert(extMemDataObj && "extMemData should be an object");

auto &value = (*extMemDataObj)[name.str()];
llvm::json::Object *obj = value.getAsObject();
if (!obj) {
value = llvm::json::Object{};
obj = value.getAsObject();
}
(*obj)["format"] = llvm::json::Object{
{"numeric_type", numType}, {"is_signed", isSigned}, {"width", width}};
}

private:
/// The component which this lowering state is associated to.
calyx::ComponentOp component;
Expand Down Expand Up @@ -486,6 +519,10 @@ class ComponentLoweringStateInterface {

/// A mapping between the callee and the instance.
llvm::StringMap<calyx::InstanceOp> instanceMap;

/// A json file to store external global memory data. See
/// https://docs.calyxir.org/lang/data-format.html?highlight=json#the-data-format
llvm::json::Value extMemData;
};

/// An interface for conversion passes that lower Calyx programs. This handles
Expand Down
127 changes: 124 additions & 3 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"

#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <filesystem>
#include <fstream>

#include <locale>
#include <numeric>
#include <variant>

namespace circt {
Expand Down Expand Up @@ -266,6 +273,14 @@ class ComponentLoweringState : public calyx::ComponentLoweringStateInterface,
/// Iterate through the operations of a source function and instantiate
/// components or primitives based on the type of the operations.
class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
public:
BuildOpGroups(MLIRContext *context, LogicalResult &resRef,
calyx::PatternApplicationState &patternState,
DenseMap<mlir::func::FuncOp, calyx::ComponentOp> &map,
calyx::CalyxLoweringState &state,
mlir::Pass::Option<std::string> &writeJsonOpt)
: FuncOpPartialLoweringPattern(context, resRef, patternState, map, state),
writeJson(writeJsonOpt) {}
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
Expand All @@ -283,7 +298,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
scf::ParallelOp, scf::ReduceOp,
/// memref
memref::AllocOp, memref::AllocaOp, memref::LoadOp,
memref::StoreOp,
memref::StoreOp, memref::GetGlobalOp,
/// standard arithmetic
AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
Expand All @@ -306,10 +321,32 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
: WalkResult::interrupt();
});

if (!writeJson.empty()) {
if (auto fileLoc = dyn_cast<mlir::FileLineColLoc>(funcOp->getLoc())) {
std::string filename = fileLoc.getFilename().str();
std::filesystem::path path(filename);
std::string jsonFileName = writeJson.append(".json");
auto outFileName = path.parent_path().append(jsonFileName);
std::ofstream outFile(outFileName);

if (!outFile.is_open()) {
llvm::errs() << "Unable to open file: " << outFileName
<< " for writing\n";
return failure();
}
llvm::raw_os_ostream llvmOut(outFile);
llvm::json::OStream jsonOS(llvmOut, 2);
jsonOS.value(getState<ComponentLoweringState>().getExtMemData());
jsonOS.flush();
outFile.close();
}
}

return success(opBuiltSuccessfully);
}

private:
mlir::Pass::Option<std::string> &writeJson;
/// Op builder specializations.
LogicalResult buildOp(PatternRewriter &rewriter, scf::YieldOp yieldOp) const;
LogicalResult buildOp(PatternRewriter &rewriter,
Expand Down Expand Up @@ -341,6 +378,8 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocaOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter,
memref::GetGlobalOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp) const;
Expand Down Expand Up @@ -962,6 +1001,82 @@ static LogicalResult buildAllocOp(ComponentLoweringState &componentState,
IntegerAttr::get(rewriter.getI1Type(), llvm::APInt(1, 1)));
componentState.registerMemoryInterface(allocOp.getResult(),
calyx::MemoryInterface(memoryOp));

unsigned elmTyBitWidth = memtype.getElementTypeBitWidth();
assert(elmTyBitWidth <= 64 && "element bitwidth should not exceed 64");
bool isFloat = !memtype.getElementType().isInteger();

auto shape = allocOp.getType().getShape();
int totalSize =
std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
// The `totalSize <= 1` check is a hack to:
// https://github.com/llvm/circt/pull/2661, where a multi-dimensional memory
// whose size in some dimension equals 1, e.g. memref<1x1x1x1xi32>, will be
// collapsed to `memref<1xi32>` with `totalSize == 1`. While the above case is
// a trivial fix, Calyx expects 1-dimensional memories in general:
// https://github.com/calyxir/calyx/issues/907
if (!(shape.size() <= 1 || totalSize <= 1)) {
allocOp.emitError("input memory dimension must be empty or one.");
return failure();
}

std::vector<uint64_t> flattenedVals(totalSize, 0);
if (isa<memref::GetGlobalOp>(allocOp)) {
auto getGlobalOp = cast<memref::GetGlobalOp>(allocOp);
auto *symbolTableOp =
getGlobalOp->template getParentWithTrait<mlir::OpTrait::SymbolTable>();
auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
// Flatten the values in the attribute
auto cstAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(
globalOp.getConstantInitValue());
int sizeCount = 0;
for (auto attr : cstAttr.template getValues<Attribute>()) {
assert((isa<mlir::FloatAttr, mlir::IntegerAttr>(attr)) &&
"memory attributes must be float or int");
if (auto fltAttr = dyn_cast<mlir::FloatAttr>(attr)) {
flattenedVals[sizeCount++] =
bit_cast<uint64_t>(fltAttr.getValueAsDouble());
} else {
auto intAttr = dyn_cast<mlir::IntegerAttr>(attr);
APInt value = intAttr.getValue();
flattenedVals[sizeCount++] = *value.getRawData();
}
}

rewriter.eraseOp(globalOp);
}

llvm::json::Array result;
result.reserve(std::max(static_cast<int>(shape.size()), 1));

Type elemType = memtype.getElementType();
bool isSigned =
!elemType.isSignlessInteger() && !elemType.isUnsignedInteger();
for (uint64_t bitValue : flattenedVals) {
llvm::json::Value value = 0;
if (isFloat) {
// We cast to `double` and let downstream calyx to deal with the actual
// value's precision handling.
value = bit_cast<double>(bitValue);
} else {
APInt apInt(/*numBits=*/elmTyBitWidth, bitValue, isSigned);
// The conditional ternary operation will cause the `value` to interpret
// the underlying data as unsigned regardless `isSigned` or not.
if (isSigned)
value = static_cast<int64_t>(apInt.getSExtValue());
else
value = apInt.getZExtValue();
}
result.push_back(std::move(value));
}

componentState.setDataField(memoryOp.getName(), result);
std::string numType =
memtype.getElementType().isInteger() ? "bitnum" : "ieee754_float";
componentState.setFormat(memoryOp.getName(), numType, isSigned,
elmTyBitWidth);

return success();
}

Expand All @@ -975,6 +1090,12 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
return buildAllocOp(getState<ComponentLoweringState>(), rewriter, allocOp);
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
memref::GetGlobalOp getGlobalOp) const {
return buildAllocOp(getState<ComponentLoweringState>(), rewriter,
getGlobalOp);
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::YieldOp yieldOp) const {
if (yieldOp.getOperands().empty()) {
Expand Down Expand Up @@ -2644,7 +2765,7 @@ void SCFToCalyxPass::runOnOperation() {
/// having a distinct group for each operation, groups are analogous to SSA
/// values in the source program.
addOncePattern<BuildOpGroups>(loweringPatterns, patternState, funcMap,
*loweringState);
*loweringState, writeJsonOpt);

/// This pattern traverses the CFG of the program and generates a control
/// schedule based on the calyx::GroupOp's which were registered for each
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ BasicLoopInterface::~BasicLoopInterface() = default;

ComponentLoweringStateInterface::ComponentLoweringStateInterface(
calyx::ComponentOp component)
: component(component) {}
: component(component), extMemData(llvm::json::Object{}) {}

ComponentLoweringStateInterface::~ComponentLoweringStateInterface() = default;

Expand Down
93 changes: 93 additions & 0 deletions test/Conversion/SCFToCalyx/write_memory.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// RUN: circt-opt %s --lower-scf-to-calyx="write-json=data" -canonicalize>/dev/null && cat $(dirname %s)/data.json | FileCheck %s

// CHECK-LABEL: "mem_0": {
// CHECK: "data": [
// CHECK: 0,
// CHECK: 0,
// CHECK: 0,
// CHECK: 0
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": true,
// CHECK: "numeric_type": "ieee754_float",
// CHECK: "width": 32
// CHECK: }
// CHECK: },

// CHECK-LABEL: "mem_1": {
// CHECK: "data": [
// CHECK: 0
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": true,
// CHECK: "numeric_type": "bitnum",
// CHECK: "width": 8
// CHECK: }
// CHECK: },

// CHECK-LABEL: "mem_2": {
// CHECK: "data": [
// CHECK: 43,
// CHECK: 8,
// CHECK: 4294967257,
// CHECK: 4294967277,
// CHECK: 70,
// CHECK: 4294967232,
// CHECK: 4294967289,
// CHECK: 4294967269,
// CHECK: 4294967239,
// CHECK: 5
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": false,
// CHECK: "numeric_type": "bitnum",
// CHECK: "width": 32
// CHECK: }
// CHECK: },

// CHECK-LABEL: "mem_3": {
// CHECK: "data": [
// CHECK: 0.69999998807907104,
// CHECK: -4.1999998092651367,
// CHECK: 0
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": true,
// CHECK: "numeric_type": "ieee754_float",
// CHECK: "width": 32
// CHECK: }
// CHECK: },

// CHECK-LABEL: "mem_4": {
// CHECK: "data": [
// CHECK: -42,
// CHECK: 35
// CHECK: ],
// CHECK: "format": {
// CHECK: "is_signed": true,
// CHECK: "numeric_type": "bitnum",
// CHECK: "width": 8
// CHECK: }
// CHECK: }

module {
memref.global "private" constant @constant_10xi32_0 : memref<10xi32> = dense<[43, 8, -39, -19, 70, -64, -7, -27, -57, 5]>
memref.global "private" constant @constant_2xsi8_0 : memref<2xsi8> = dense<[-42, 35]>
memref.global "private" constant @constant_3xf32 : memref<3xf32> = dense<[0.7, -4.2, 0.0]>
func.func @main(%arg_idx : index) -> i32 {
%alloc = memref.alloc() : memref<4xf32>
%zero_dim_mem = memref.alloca() : memref<si8>
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%0 = memref.get_global @constant_10xi32_0 : memref<10xi32>
%ret = memref.load %0[%arg_idx] : memref<10xi32>
%1 = memref.get_global @constant_3xf32 : memref<3xf32>
%2 = memref.load %1[%c1] : memref<3xf32>
memref.store %2, %alloc[%c2] : memref<4xf32>
%3 = memref.get_global @constant_2xsi8_0 : memref<2xsi8>
%4 = memref.load %3[%c1] : memref<2xsi8>
memref.store %4, %zero_dim_mem[] : memref<si8>
return %ret : i32
}
}

0 comments on commit baba42c

Please sign in to comment.