Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
  • Loading branch information
AlexandreEichenberger committed Sep 11, 2024
2 parents 01a0ada + 02f45b0 commit 630a97e
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 7 deletions.
5 changes: 3 additions & 2 deletions src/Compiler/CompilerDialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ DialectRegistry registerDialects(ArrayRef<accel::Accelerator::Kind> accels) {
for (auto *accel : accel::Accelerator::getAccelerators())
accel->registerDialects(registry);

if (useOldBufferization)
memref::registerAllocationOpInterfaceExternalModels(registry);
// Register interface needed by both old and new buffer deallocation pass.
memref::registerAllocationOpInterfaceExternalModels(registry);
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);

return registry;
}
Expand Down
17 changes: 12 additions & 5 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ std::vector<std::string> functionsToDecompose; // common for both
std::string opsForCall; // common for both
bool disableKrnlOpFusion; // common for both
bool disableQuantZeroPoint; // common for both
bool disableMemRefPrefetch; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
bool invokeOnnxVersionConverter; // onnx-mlir only
bool preserveLocations; // onnx-mlir only
Expand Down Expand Up @@ -194,7 +195,7 @@ static llvm::cl::list<std::string, std::vector<std::string>>
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> enableONNXHybridPassOpt("onnx-hybrid-pass",
llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n"
llvm::cl::desc("Enable ONNX hybrid pass (default=true).\n"
"Set to 'false' if you want to disable ONNX hybrid pass."),
llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirCommonOptions));
Expand All @@ -207,20 +208,27 @@ static llvm::cl::list<std::string, std::vector<std::string>>

static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
"disable-krnl-op-fusion",
llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n"
llvm::cl::desc("Disable op fusion in onnx-to-krnl pass (default=false).\n"
"Set to 'true' if you want to disable fusion."),
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disable_quantization_zero_point(
"disable-quantization-zero-point",
llvm::cl::desc(
"Disable the use of zero-point in quantization.\n"
"Disable the use of zero-point in quantization (default=false).\n"
"Set to 'true' if you want to disable the use of zero-point\n"
"in dyn/static quantization/dequantization. Default is false."),
"in dyn/static quantization/dequantization."),
llvm::cl::location(disableQuantZeroPoint), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
"disable-memref-prefetch",
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"
"Set to 'true' if you want to disable prefetch."),
llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disableRecomposeOptionOpt("disable-recompose",
llvm::cl::desc("Disable recomposition of ONNX operations."),
llvm::cl::location(disableRecomposeOption), llvm::cl::init(false),
Expand Down Expand Up @@ -1138,7 +1146,6 @@ std::string getLibraryPath() {
// as lrodataScript.
std::string getToolPath(
const std::string &tool, bool flag /*false by default*/) {

if (!flag) {
std::string execDir = llvm::sys::path::parent_path(getExecPath()).str();
llvm::SmallString<8> toolPath(execDir);
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ extern std::vector<std::string> functionsToDecompose; // common for both
extern std::string opsForCall; // common for both
extern bool disableKrnlOpFusion; // common for both
extern bool disableQuantZeroPoint; // common for both
extern bool disableMemRefPrefetch; // common for both
extern EmissionTargetType emissionTarget; // onnx-mlir only
extern bool invokeOnnxVersionConverter; // onnx-mlir only
extern bool preserveLocations; // onnx-mlir only
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/Krnl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_onnx_mlir_library(OMKrnlOps
OMSpecializedKernelOpInterface

LINK_LIBS PUBLIC
OMCompilerOptions
OMONNXOps
MLIRLLVMCommonConversion
MLIRAffineDialect
Expand Down
5 changes: 5 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "llvm/ADT/TypeSwitch.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
Expand Down Expand Up @@ -94,12 +95,16 @@ Value KrnlBuilder::getLinearOffsetIndexIE(

void KrnlBuilder::prefetch(Value memref, ValueRange indices, bool isWrite,
unsigned localityHint, bool isDataCache) {
if (disableMemRefPrefetch)
return;
b().create<KrnlPrefetchOp>(
loc(), memref, indices, isWrite, localityHint, isDataCache);
}

void KrnlBuilder::prefetchIE(Value memref, ArrayRef<IndexExpr> indices,
bool isWrite, unsigned localityHint, bool isDataCache) {
if (disableMemRefPrefetch)
return;
SmallVector<Value, 4> indexValues;
IndexExpr::getValues(indices, indexValues);
b().create<KrnlPrefetchOp>(
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/Mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_onnx_mlir_library(OMMlirDialects
OMSpecializedKernelOpInterface

LINK_LIBS PUBLIC
OMCompilerOptions
MLIRMathDialect
MLIRAffineDialect
MLIRSCFDialect
Expand Down
5 changes: 5 additions & 0 deletions src/Dialect/Mlir/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "llvm/Support/Debug.h"

// Please do not add dependences on ONNX or KRNL dialects.
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Dialect/Mlir/VectorMachineSupport.hpp"

Expand Down Expand Up @@ -1657,12 +1658,16 @@ Value MemRefBuilder::dim(Value val, Value index) const {

void MemRefBuilder::prefetch(Value memref, ValueRange indices, bool isWrite,
unsigned locality, bool isData) {
if (disableMemRefPrefetch)
return;
b().create<memref::PrefetchOp>(
loc(), memref, indices, isWrite, locality, isData);
}

void MemRefBuilder::prefetchIE(Value memref, ArrayRef<IndexExpr> indices,
bool isWrite, unsigned locality, bool isData) {
if (disableMemRefPrefetch)
return;
SmallVector<Value, 4> indexVals;
IndexExpr::getValues(indices, indexVals);
prefetch(memref, indexVals, isWrite, locality, isData);
Expand Down

0 comments on commit 630a97e

Please sign in to comment.