From 02f45b04f0896deec609b3a48ea4de51ab4de0b8 Mon Sep 17 00:00:00 2001 From: Tong Chen Date: Wed, 11 Sep 2024 09:18:22 -0400 Subject: [PATCH] try to use new buffer deallocation (#2919) * implementation Signed-off-by: Chen Tong * comments Signed-off-by: Chen Tong * format Signed-off-by: Chen Tong --------- Signed-off-by: Chen Tong Co-authored-by: Tung D. Le Co-authored-by: Alexandre Eichenberger --- src/Compiler/CompilerDialects.cpp | 5 +++-- src/Compiler/CompilerOptions.cpp | 8 ++++++++ src/Compiler/CompilerOptions.hpp | 1 + src/Dialect/Krnl/CMakeLists.txt | 1 + src/Dialect/Krnl/DialectBuilder.cpp | 5 +++++ src/Dialect/Mlir/CMakeLists.txt | 1 + src/Dialect/Mlir/DialectBuilder.cpp | 5 +++++ 7 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/Compiler/CompilerDialects.cpp b/src/Compiler/CompilerDialects.cpp index a54d014977..87fab41a9c 100644 --- a/src/Compiler/CompilerDialects.cpp +++ b/src/Compiler/CompilerDialects.cpp @@ -46,8 +46,9 @@ DialectRegistry registerDialects(ArrayRef accels) { for (auto *accel : accel::Accelerator::getAccelerators()) accel->registerDialects(registry); - if (useOldBufferization) - memref::registerAllocationOpInterfaceExternalModels(registry); + // Register interface needed by both old and new buffer deallocation pass. + memref::registerAllocationOpInterfaceExternalModels(registry); + arith::registerBufferDeallocationOpInterfaceExternalModels(registry); return registry; } diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index 6d010bd219..03975bb1d0 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both std::vector functionsToDecompose; // common for both std::string opsForCall; // common for both bool disableKrnlOpFusion; // common for both +bool disableMemRefPrefetch; // common for both EmissionTargetType emissionTarget; // onnx-mlir only bool invokeOnnxVersionConverter; // onnx-mlir only bool preserveLocations; // onnx-mlir only @@ -211,6 +212,13 @@ static llvm::cl::opt disableKrnlOpFusionOpt( llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); +static llvm::cl::opt disableMemRefPrefetchOpt( + "disable-memref-prefetch", + llvm::cl::desc("disable generation of memref.prefetch (default=false)\n" + "Set to 'true' if you want to disable prefetch."), + llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirCommonOptions)); + static llvm::cl::opt disableRecomposeOptionOpt("disable-recompose", llvm::cl::desc("Disable recomposition of ONNX operations."), llvm::cl::location(disableRecomposeOption), llvm::cl::init(false), diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index 2ed9f251e1..fe12e4511c 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both extern std::vector functionsToDecompose; // common for both extern std::string opsForCall; // common for both extern bool disableKrnlOpFusion; // common for both +extern bool disableMemRefPrefetch; // common for both extern EmissionTargetType emissionTarget; // onnx-mlir only extern bool invokeOnnxVersionConverter; // onnx-mlir only extern bool preserveLocations; // onnx-mlir only diff --git a/src/Dialect/Krnl/CMakeLists.txt b/src/Dialect/Krnl/CMakeLists.txt index 541437da3a..683e4500dc 100644 --- a/src/Dialect/Krnl/CMakeLists.txt +++ b/src/Dialect/Krnl/CMakeLists.txt @@ -20,6 +20,7 @@ add_onnx_mlir_library(OMKrnlOps OMSpecializedKernelOpInterface LINK_LIBS PUBLIC + OMCompilerOptions OMONNXOps MLIRLLVMCommonConversion MLIRAffineDialect diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp index 4b3a8704c8..b07d7b09d5 100644 --- a/src/Dialect/Krnl/DialectBuilder.cpp +++ b/src/Dialect/Krnl/DialectBuilder.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/TypeSwitch.h" +#include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -94,12 +95,16 @@ Value KrnlBuilder::getLinearOffsetIndexIE( void KrnlBuilder::prefetch(Value memref, ValueRange indices, bool isWrite, unsigned localityHint, bool isDataCache) { + if (disableMemRefPrefetch) + return; b().create( loc(), memref, indices, isWrite, localityHint, isDataCache); } void KrnlBuilder::prefetchIE(Value memref, ArrayRef indices, bool isWrite, unsigned localityHint, bool isDataCache) { + if (disableMemRefPrefetch) + return; SmallVector indexValues; IndexExpr::getValues(indices, indexValues); b().create( diff --git a/src/Dialect/Mlir/CMakeLists.txt b/src/Dialect/Mlir/CMakeLists.txt index 80c45ea5a2..01b4062aa4 100644 --- a/src/Dialect/Mlir/CMakeLists.txt +++ b/src/Dialect/Mlir/CMakeLists.txt @@ -12,6 +12,7 @@ add_onnx_mlir_library(OMMlirDialects OMSpecializedKernelOpInterface LINK_LIBS PUBLIC + OMCompilerOptions MLIRMathDialect MLIRAffineDialect MLIRSCFDialect diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index 546c18aabf..c77dfb5368 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/Debug.h" // Please do not add dependences on ONNX or KRNL dialects. +#include "src/Compiler/CompilerOptions.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" #include "src/Dialect/Mlir/VectorMachineSupport.hpp" @@ -1657,12 +1658,16 @@ Value MemRefBuilder::dim(Value val, Value index) const { void MemRefBuilder::prefetch(Value memref, ValueRange indices, bool isWrite, unsigned locality, bool isData) { + if (disableMemRefPrefetch) + return; b().create( loc(), memref, indices, isWrite, locality, isData); } void MemRefBuilder::prefetchIE(Value memref, ArrayRef indices, bool isWrite, unsigned locality, bool isData) { + if (disableMemRefPrefetch) + return; SmallVector indexVals; IndexExpr::getValues(indices, indexVals); prefetch(memref, indexVals, isWrite, locality, isData);