From 33b466ecff0fcee7ff2b86a0ef257ff49463b2c3 Mon Sep 17 00:00:00 2001 From: Haruki Imai Date: Tue, 5 Nov 2024 23:48:20 +0900 Subject: [PATCH] [NNPA] Memory reduction of stickified constant by stickifying at file writing (#2917) Reduce memory usage for NNPA compilation. Change to set original data, not stickified data, in ZHighConstPropagationPass. Then, in the KrnlToLLVMPass, stickfied data is created and stored in the file, and deleted after writing into the file. --- .../Conversion/ZHighToZLow/ZHighToZLow.cpp | 133 +++++++++----- .../NNPA/Dialect/ZHigh/CMakeLists.txt | 1 + src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td | 5 +- .../NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp | 51 +++++- .../NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp | 8 + .../NNPA/Dialect/ZLow/CMakeLists.txt | 5 + src/Accelerators/NNPA/Dialect/ZLow/ZLow.td | 17 ++ .../NNPA/Dialect/ZLow/ZLowOps.cpp | 77 ++++++++ .../NNPA/Dialect/ZLow/ZLowOps.hpp | 2 + .../NNPA/Transform/ZHigh/CMakeLists.txt | 4 +- .../Transform/ZHigh/ZHighConstPropagation.cpp | 98 ++++------ .../ZLow/ZLowDummyOpForMultiDerefPass.cpp | 2 +- src/Conversion/KrnlToLLVM/CMakeLists.txt | 3 +- .../KrnlToLLVM/ConvertKrnlToLLVM.cpp | 63 ++++--- .../KrnlToLLVM/ConvertKrnlToLLVM.hpp | 5 +- ...nlGlobal.cpp => KrnlGlobalOpInterface.cpp} | 172 ++++++++++-------- src/Dialect/Krnl/CMakeLists.txt | 1 + src/Dialect/Krnl/Krnl.td | 4 +- src/Dialect/Krnl/KrnlOps.cpp | 30 +++ src/Dialect/Krnl/KrnlOps.hpp | 1 + src/Dialect/ONNX/ONNXOps/OpHelper.cpp | 24 +++ src/Dialect/ONNX/ONNXOps/OpHelper.hpp | 3 + src/Interface/CMakeLists.txt | 12 ++ src/Interface/KrnlGlobalOpInterface.cpp | 24 +++ src/Interface/KrnlGlobalOpInterface.hpp | 27 +++ src/Interface/KrnlGlobalOpInterface.td | 88 +++++++++ .../device-placement/emit-zhighir.mlir | 4 +- .../nnpa/conversion/lower-all-to-llvm.mlir | 1 + .../conversion/lower-to-llvm-be/lit.local.cfg | 6 + .../lower-all-to-llvm_be.mlir | 21 +++ .../nnpa/conversion/zhigh-to-zlow/conv.mlir | 2 +- .../nnpa/conversion/zhigh-to-zlow/gru.mlir 
| 4 +- .../nnpa/conversion/zhigh-to-zlow/lstm.mlir | 4 +- .../zhigh-to-zlow/stickified-constant.mlir | 18 +- test/mlir/accelerators/nnpa/driver/ccfd.mlir | 27 ++- .../driver/dense-out-attention-layer.mlir | 2 +- .../constprop.mlir | 94 ++-------- 37 files changed, 707 insertions(+), 336 deletions(-) rename src/Conversion/KrnlToLLVM/{KrnlGlobal.cpp => KrnlGlobalOpInterface.cpp} (66%) create mode 100644 src/Interface/KrnlGlobalOpInterface.cpp create mode 100644 src/Interface/KrnlGlobalOpInterface.hpp create mode 100644 src/Interface/KrnlGlobalOpInterface.td create mode 100644 test/mlir/accelerators/nnpa/conversion/lower-to-llvm-be/lit.local.cfg create mode 100644 test/mlir/accelerators/nnpa/conversion/lower-to-llvm-be/lower-all-to-llvm_be.mlir diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp index adff6073e9..c3fca41393 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp @@ -190,27 +190,47 @@ Value insertAllocOrEmitZeroConstant(ArrayRef dims, affine::normalizeMemRefType(mlir::cast(zMemRefType.value)); // Create a ZHighStickifiedConstantOp. - ZHighStickifiedConstantOp stickifiedConstant = - rewriter.create(loc, resType, - /*value=*/nullptr, - /*alignment=*/rewriter.getI64IntegerAttr(4096)); - - // Use an dense resource attribute to store stickified data. 
- // Attribute type: tensor - int64_t sizeInBytes = - affine::getIntOrFloatMemRefSizeInBytes(resType).value(); - char *rawData = static_cast(malloc(sizeInBytes)); - assert(rawData && "failed to allocate memory for stickified data"); - memset(rawData, 0, sizeInBytes); - DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get( - RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()), - stickifiedConstant.getOperation() - ->getDialect() - ->getNamespace(), // use the dialect as the blob "hint" - HeapAsmResourceBlob::allocateAndCopyWithAlign( - llvm::ArrayRef(rawData, sizeInBytes), alignof(char))); - stickifiedConstant.setValueAttr(valueAttr); - free(rawData); + + // Keep previous implementation about generating stickified data at + // ZHighConstPropagationPass. To use this, comment in and set directive " + // NNPA_ZHIGH_STICKIFIEDCONST_GEN" + // + // #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN + // // Set zero in value attribute as DenseResourceElementsAttribute. + // ZHighStickifiedConstantOp stickifiedConstant = + // rewriter.create(loc, resType, + // /*stickified=*/rewriter.getBoolAttr(true), + // /*value=*/nullptr, + // /*alignment=*/rewriter.getI64IntegerAttr(4096)); + // + // // Use an dense resource attribute to store stickified data. 
+ // // Attribute type: tensor + // int64_t sizeInBytes = + // affine::getIntOrFloatMemRefSizeInBytes(resType).value(); + // char *rawData = static_cast(malloc(sizeInBytes)); + // assert(rawData && "failed to allocate memory for stickified data"); + // memset(rawData, 0, sizeInBytes); + // DenseResourceElementsAttr valueAttr = + // DenseUI8ResourceElementsAttr::get( + // RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()), + // stickifiedConstant.getOperation() + // ->getDialect() + // ->getNamespace(), // use the dialect as the blob "hint" + // HeapAsmResourceBlob::allocateAndCopyWithAlign( + // llvm::ArrayRef(rawData, sizeInBytes), alignof(char))); + // stickifiedConstant.setValueAttr(valueAttr); + // free(rawData); + // #else + + // Set zero in value attribute as SplatElementsAttr. + FloatAttr floatZero = rewriter.getFloatAttr(resType.getElementType(), 0.0); + ZHighStickifiedConstantOp stickifiedConstant = rewriter.create< + ZHighStickifiedConstantOp>(loc, resType, + /*stickified=*/rewriter.getBoolAttr(true), + /*value=*/SplatElementsAttr::get(cast(resType), floatZero), + /*alignment=*/rewriter.getI64IntegerAttr(4096)); + + // #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN res = stickifiedConstant.getResult(); } else { @@ -686,7 +706,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern { }; //===----------------------------------------------------------------------===// -// Lower ZHigh Stickified Constant to KrnlGlobal +// Lower ZHigh Stickified Constant to ZLow Stickified Constant //===----------------------------------------------------------------------===// struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern { @@ -699,7 +719,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); - ZHighStickifiedConstantOp stickifiedConstOp = + ZHighStickifiedConstantOp 
zhighStickifiedConstOp = llvm::dyn_cast(op); // Convert ZTensor type to MemRefType. @@ -713,36 +733,59 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern { affine::normalizeMemRefType(mlir::cast(zMemRefType.value)); ArrayRef normalizedShape = normalizedType.getShape(); - // Get dense resource attribute. - auto blob = mlir::cast( - stickifiedConstOp.getValue().value()) - .getRawHandle() - .getBlob(); - assert(blob && "Expecting dense resource with a valid blob"); - ArrayRef data = blob->getData(); - - // Validate the stickified tensor. - int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType); - memRefSizeInBytes *= normalizedType.getNumElements(); - assert((data.size() == static_cast(memRefSizeInBytes)) && - "The stickified tensor's buffer size and MemRef's size mismatched"); - - // Create a KrnlGlobalOp. - KrnlGlobalOp constantGlobal = - rewriter.create(loc, zMemRefType.value, + // Create ZLowStickifiedConstantOp. + StringAttr layout = + getZTensorLayoutAttr(rewriter, *op->result_type_begin()); + + // Keep previous implementation about generating stickified data at + // ZHighConstPropagationPass. To use this, comment in and set directive " + // NNPA_ZHIGH_STICKIFIEDCONST_GEN" + // + // #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN + // // Lower to KrnlGlobalOp + // // Get dense resource attribute. + // auto blob = mlir::cast( + // zhighStickifiedConstOp.getValue().value()) + // .getRawHandle() + // .getBlob(); + // assert(blob && "Expecting dense resource with a valid blob"); + // ArrayRef data = blob->getData(); + // // Validate the stickified tensor. + // int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType); + // memRefSizeInBytes *= normalizedType.getNumElements(); + // assert((data.size() == static_cast(memRefSizeInBytes)) && + // "The stickified tensor's buffer size and MemRef's size + // mismatched"); + // // Create a KrnlGlobalOp. 
+ // KrnlGlobalOp constantOp = + // rewriter.create(loc, zMemRefType.value, + // /*shape=*/ + // rewriter.getI64ArrayAttr(normalizedShape), + // /*name=*/ + // rewriter.getStringAttr( + // "constant_stickify_" + std::to_string(constantID)), + // /*value=*/zhighStickifiedConstOp.getValueAttr(), + // /*offset=*/nullptr, + // /*alignment=*/zhighStickifiedConstOp.getAlignmentAttr()); + // #else + ZLowStickifiedConstantOp constantOp = + rewriter.create(loc, + mlir::cast(zMemRefType.value), /*shape=*/ rewriter.getI64ArrayAttr(normalizedShape), /*name=*/ rewriter.getStringAttr( "constant_stickify_" + std::to_string(constantID)), - /*value=*/stickifiedConstOp.getValueAttr(), - /*offset=*/nullptr, - /*alignment=*/stickifiedConstOp.getAlignmentAttr()); - + /*stickified=*/zhighStickifiedConstOp.getStickifiedAttr(), + /*value=*/zhighStickifiedConstOp.getValueAttr(), + /*layout=*/layout, + /*offset=*/rewriter.getI64IntegerAttr(0), + /*alignment=*/zhighStickifiedConstOp.getAlignmentAttr()); + // #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN // Increment constant ID: constantID++; - rewriter.replaceOp(op, constantGlobal.getResult()); + rewriter.replaceOp(op, constantOp.getResult()); return success(); } }; diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt index 915ed61717..af5bab1779 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt +++ b/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt @@ -47,6 +47,7 @@ add_onnx_mlir_library(OMZHighOps OMONNXOps # Use ONNXShapeHelper OMLayoutHelper OMShapeHelperOpInterface + OMStickify OMNNPACompilerOptions MLIRIR diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td index d2624138c0..8b17786bcd 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td @@ -862,11 +862,14 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> { let summary = "ZHigh 
Stickified Constant operation"; let description = [{ This operator produces a constant tensor to store stickified data. + `value` attribute has original constant or stickified constant. + `stickified` attribute indicates the `value` is already stickified or not. Stickified data is opaque and must be 4K-aligned. One who produces the stickified data must make sure its size in bytes consistent with the output tensor's size. }]; - let arguments = (ins OptionalAttr:$value, + let arguments = (ins BoolAttr:$stickified, + OptionalAttr:$value, DefaultValuedAttr:$alignment); let results = (outs AnyZTensor:$output); } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp index f5b9ff910f..028b5ac528 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp @@ -12,7 +12,6 @@ #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp" #include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp" -#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" #include "src/Accelerators/NNPA/Support/LayoutHelper.hpp" #include "src/Dialect/ONNX/DialectBuilder.hpp" @@ -482,5 +481,55 @@ IntegerAttr getDefaultSaturation(PatternRewriter &rewriter) { return IntegerAttr(); } +/// MLIR type to zDNN type. 
+zdnn_data_types mlirTypeToZDNNType(Type elementType) { + if (mlir::isa(elementType)) { + FloatType floatTy = mlir::cast(elementType); + if (floatTy.getWidth() == 16) { + return FP16; + } else if (floatTy.getWidth() == 32) { + return FP32; + } else + llvm_unreachable("Unsupported data type."); + } else + llvm_unreachable("Unsupported data type."); +} + +/// Get stickified data from denseElementAttribute +ArrayRef getStickifiedDataOfDenseElemAttr( + DenseElementsAttr denseAttr, StringAttr layout) { + ArrayRef shape = denseAttr.getType().getShape(); + Type elementType = denseAttr.getType().getElementType(); + int rank = shape.size(); + // Read attributes's raw data. + std::vector attrData; + getRawData(denseAttr, attrData); + // Call stickify. + zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc; + // pre-transformed desc. + zdnn_data_layouts zDNNLayout = + convertLayoutAttrToZDNNDataLayout(rank, layout); + // If zDNNLayout is NHWC, we stickify directly from NCHW. + if (zDNNLayout == ZDNN_NHWC) + zDNNLayout = ZDNN_NCHW; + zdnn_data_types zDNNType = onnx_mlir::zhigh::mlirTypeToZDNNType(elementType); + set_info_pre_transformed_desc(&pre_tfrmd_desc, zDNNLayout, zDNNType, shape); + // transformed desc. + zdnn_status status = generate_transformed_desc(&pre_tfrmd_desc, &tfrmd_desc); + assert(status == ZDNN_OK); + // Stick data using the software stickify. 
+ zdnn_ztensor ztensor; + init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor); + status = allochelper_ztensor_alloc(&ztensor); + assert(status == ZDNN_OK); + status = stickify(&ztensor, attrData.data()); + assert(status == ZDNN_OK); + int64_t sizeInBytes = ztensor.buffer_size; + char *rawData = (char *)malloc(sizeInBytes); + memcpy(rawData, ztensor.buffer, sizeInBytes); + allochelper_ztensor_free(&ztensor); + return llvm::ArrayRef(rawData, sizeInBytes); +} + } // namespace zhigh } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp index def0813d7b..4d353950a6 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp @@ -16,6 +16,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" +#include "src/Accelerators/NNPA/Support/Stickify/Stickify.hpp" namespace onnx_mlir { namespace zhigh { @@ -88,6 +89,13 @@ bool hasNNPAUse(mlir::Value v); /// Get saturation settings. mlir::IntegerAttr getDefaultSaturation(mlir::PatternRewriter &rewriter); +/// MLIR type to zDNN type. 
+zdnn_data_types mlirTypeToZDNNType(mlir::Type elementType); + +/// Get stickified data from denseElementAttribute +mlir::ArrayRef getStickifiedDataOfDenseElemAttr( + mlir::DenseElementsAttr denseAttr, mlir::StringAttr layout); + } // namespace zhigh } // namespace onnx_mlir #endif diff --git a/src/Accelerators/NNPA/Dialect/ZLow/CMakeLists.txt b/src/Accelerators/NNPA/Dialect/ZLow/CMakeLists.txt index c259f721c3..99dd227c39 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/CMakeLists.txt +++ b/src/Accelerators/NNPA/Dialect/ZLow/CMakeLists.txt @@ -11,8 +11,13 @@ add_onnx_mlir_library(OMZLowOps DEPENDS OMZLowIncGen OMONNXZLowCombineIncGen + OMKrnlGlobalOpInterface LINK_LIBS PUBLIC MLIRIR OMMlirDialects + OMZHighOps + + ACCEL_INCLUDE_DIRS PRIVATE + ${NNPA_INCLUDE_PATH} ) diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td index 63fcb0704d..4376a3d90b 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td +++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td @@ -44,6 +44,7 @@ def ZMemRef : MemRefOf<[DLF16]>; //===----------------------------------------------------------------------===// include "mlir/Interfaces/SideEffectInterfaces.td" +include "src/Interface/KrnlGlobalOpInterface.td" def ZLowAddOp:ZLow_Op<"add", [MemRefsNormalizable, DeclareOpInterfaceMethods]> { @@ -547,4 +548,20 @@ def ZLowConvertF32ToDLF16VectorOp:ZLow_Op<"vec_f32_to_dlf16", [Pure]> { ]; } +def ZLowStickifiedConstantOp:ZLow_Op<"stickifiedConstant", [MemRefsNormalizable, + DeclareOpInterfaceMethods]> { + let summary = "ZLow Stickified Constant operation."; + let description = [{ + + }]; + let arguments = (ins AnyAttr:$shape, + StrAttr:$name, + BoolAttr:$stickified, + OptionalAttr:$value, + OptionalAttr:$layout, + OptionalAttr:$offset, + DefaultValuedAttr:$alignment); + let results = (outs ZMemRef:$output); +} + #endif // ZLOW_OPS diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp b/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp index 
7526933777..4cf9d79b2b 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp +++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp @@ -12,19 +12,27 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp" #include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp" +#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp" +#include "src/Accelerators/NNPA/Support/Stickify/Stickify.hpp" +#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" using namespace mlir; @@ -358,6 +366,75 @@ void ZLowBatchNormOp::getEffects( SideEffects::DefaultResource::get()); } +/// Create a buffer for constant. Stickified data is +/// created and set if `stickified` attribute is false. +ArrayRef ZLowStickifiedConstantOp::getBuffer() { + MLIRContext *context = getOperation()->getContext(); + PatternRewriter rewriter(context); + ArrayRef ret; + if (getValueAttr()) { + StringAttr layout = getLayoutAttr(); + auto dataAttr = getValue().value(); + if (!getStickified()) { + // The case which the data in value attribute is still not stickified. + // Get the buffer after stickification. + DenseElementsAttr denseAttr = mlir::cast(dataAttr); + ret = + onnx_mlir::zhigh::getStickifiedDataOfDenseElemAttr(denseAttr, layout); + } else { + // Get the buffer from `value` attribute. 
+ int64_t sizeInBytes = getBufferSize(); + char *rawData = (char *)malloc(sizeInBytes); + std::vector attrData; + getRawData(dataAttr, attrData); + memcpy(rawData, attrData.data(), sizeInBytes); + ret = llvm::ArrayRef(rawData, sizeInBytes); + } + } + return ret; +} + +/// Get buffer size from result. +uint64_t ZLowStickifiedConstantOp::getBufferSize() { + const Type type = getOperation()->getResults()[0].getType(); + const MemRefType memRefTy = mlir::cast(type); + auto sizeInBytes = affine::getIntOrFloatMemRefSizeInBytes(memRefTy); + return sizeInBytes.has_value() ? sizeInBytes.value() : 0; +} + +/// Free buffer created by getBuffer(). +void ZLowStickifiedConstantOp::freeBuffer(ArrayRef rawData) { + free(const_cast(rawData.data())); + return; +} + +/// Get a buffer, set/copy it to value attribute, and free the buffer. +void ZLowStickifiedConstantOp::updateValueAttr() { + MLIRContext *context = getOperation()->getContext(); + PatternRewriter rewriter(context); + // Set buffer when the value attribute is still not stickified or is splat + // with dense element attribute. 
+ if (getValueAttr()) { + bool isStickified = getStickified(); + bool isSplat = false; + if (auto denseAttr = mlir::dyn_cast(getValue().value())) + isSplat = denseAttr.isSplat(); + if (!isStickified || isSplat) { + ArrayRef rawData = getBuffer(); + int64_t sizeInBytes = getBufferSize(); + DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get( + RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()), + getOperation() + ->getDialect() + ->getNamespace(), // use the dialect as the blob "hint" + HeapAsmResourceBlob::allocateAndCopyWithAlign( + rawData, alignof(char))); + setValueAttr(valueAttr); + freeBuffer(rawData); + } + } +} + } // namespace zlow } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp b/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp index 2050779dcb..9ebeb64447 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp +++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp @@ -24,6 +24,8 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "src/Interface/KrnlGlobalOpInterface.hpp" + /// Include the auto-generated header files containing the declarations of the /// ZLow dialect and operations. 
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowDialect.hpp.inc" diff --git a/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt index 7f9bfe05ec..dfc7e7f5b0 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt +++ b/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt @@ -12,7 +12,6 @@ add_onnx_mlir_library(OMZHighConstPropagation MLIRRewrite MLIRTransformUtils OMLayoutHelper - OMStickify OMZHighOps OMONNXOps @@ -47,6 +46,9 @@ add_onnx_mlir_library(OMZHighClipToDLFloat MLIRTransformUtils OMZHighOps OMONNXOps + + ACCEL_INCLUDE_DIRS PRIVATE + ${NNPA_INCLUDE_PATH} ) add_onnx_mlir_rewriter(ZHighDecomposeStickUnstick) diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp index a32bacb4c4..62724c1c73 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp +++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp @@ -21,9 +21,11 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp" #include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" #include "src/Accelerators/NNPA/Support/LayoutHelper.hpp" #include "src/Accelerators/NNPA/Support/Stickify/Stickify.hpp" +#include "src/Compiler/CompilerOptions.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" @@ -34,33 +36,6 @@ using namespace onnx_mlir::zhigh; namespace onnx_mlir { namespace zhigh { -/// Get raw data from a dense attribute. 
-static void getRawData(DenseElementsAttr denseAttr, std::vector &data) { - if (!denseAttr.isSplat()) { - data = denseAttr.getRawData(); - } else { - ShapedType denseShapeType = mlir::cast(denseAttr.getType()); - std::vector rawData = denseAttr.getRawData(); - int64_t numElements = denseShapeType.getNumElements(); - for (int i = 0; i < numElements; i++) - data.insert(data.end(), rawData.begin(), rawData.end()); - } -} - -/// MLIR type to zDNN type. -zdnn_data_types mlirTypeToZDNNType(Type elementType) { - if (mlir::isa(elementType)) { - FloatType floatTy = mlir::cast(elementType); - if (floatTy.getWidth() == 16) { - return FP16; - } else if (floatTy.getWidth() == 32) { - return FP32; - } else - llvm_unreachable("Unsupported data type."); - } else - llvm_unreachable("Unsupported data type."); -} - /// Emit a ZHighStikifiedConstant using information from a stickified ztensor. ZHighStickifiedConstantOp emitZHighStickifiedConstant(PatternRewriter &rewriter, Location loc, zdnn_ztensor *ztensor, Type outputType) { @@ -68,6 +43,7 @@ ZHighStickifiedConstantOp emitZHighStickifiedConstant(PatternRewriter &rewriter, // Create a ZHighStickifiedConstantOp. ZHighStickifiedConstantOp stickifiedConstant = rewriter.create(loc, outputType, + /*stickified=*/rewriter.getBoolAttr(true), /*value=*/nullptr, /*alignment=*/rewriter.getI64IntegerAttr(4096)); @@ -91,44 +67,44 @@ ZHighStickifiedConstantOp createConstantForStick(PatternRewriter &rewriter, Value replacingValue, Value input, StringAttr layout) { Location loc = replacingValue.getLoc(); Operation *op = input.getDefiningOp(); - ArrayRef shape = mlir::cast(input.getType()).getShape(); - Type elementType = mlir::cast(input.getType()).getElementType(); - int rank = shape.size(); - // Read dense attributes. DenseElementsAttr dataAttr = mlir::dyn_cast_or_null( op->getAttrOfType<::mlir::Attribute>("value")); assert(dataAttr && "Attribute is null"); - // Read attributes's raw data. 
- std::vector rawData; - getRawData(dataAttr, rawData); - // assert((rawData.size() == (uint64_t)getMemRefSizeInBytes(input)) && - // "Data size mismatched"); - - // Call stickify. - zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc; - // pre-transformed desc. - zdnn_data_layouts zDNNLayout = - convertLayoutAttrToZDNNDataLayout(rank, layout); - // If zDNNLayout is NHWC, we stickify directly from NCHW. - if (zDNNLayout == ZDNN_NHWC) - zDNNLayout = ZDNN_NCHW; - zdnn_data_types zDNNType = mlirTypeToZDNNType(elementType); - set_info_pre_transformed_desc(&pre_tfrmd_desc, zDNNLayout, zDNNType, shape); - // transformed desc. - zdnn_status status = generate_transformed_desc(&pre_tfrmd_desc, &tfrmd_desc); - assert(status == ZDNN_OK); - // Stick data using the software stickify. - zdnn_ztensor ztensor; - init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor); - status = allochelper_ztensor_alloc(&ztensor); - assert(status == ZDNN_OK); - status = stickify(&ztensor, rawData.data()); - assert(status == ZDNN_OK); - // Emit a constant global in ZHigh dialect. - ZHighStickifiedConstantOp constantOp = emitZHighStickifiedConstant( - rewriter, loc, &ztensor, replacingValue.getType()); - + // Keep previous implementation about generating stickified data at + // ZHighConstPropagationPass. To use this, comment in and set directive " + // NNPA_ZHIGH_STICKIFIEDCONST_GEN" + // + // #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN + // // Set stickified data. + // ArrayRef stickifiedData = + // getStickifiedDataOfDenseElemAttr(dataAttr, layout); + // // Create a ZHighStickifiedConstantOp. + // ZHighStickifiedConstantOp constantOp = + // rewriter.create(loc, + // replacingValue.getType(), + // /*stickified=*/rewriter.getBoolAttr(true), + // /*value=*/nullptr, + // /*alignment=*/rewriter.getI64IntegerAttr(4096)); + // + // // Use an dense resource attribute to store stickified data. 
+ // // Attribute type: tensor + // DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get( + // RankedTensorType::get({stickifiedData.size()}, rewriter.getI8Type()), + // constantOp.getOperation() + // ->getDialect() + // ->getNamespace(), // use the dialect as the blob "hint" + // HeapAsmResourceBlob::allocateAndCopyWithAlign( + // stickifiedData, alignof(char))); + // + // constantOp.setValueAttr(valueAttr); + // #else + ZHighStickifiedConstantOp constantOp = + rewriter.create(loc, replacingValue.getType(), + /*stickified=*/rewriter.getBoolAttr(false), + /*value=*/dataAttr, + /*alignment=*/rewriter.getI64IntegerAttr(4096)); + // #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN return constantOp; } diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowDummyOpForMultiDerefPass.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowDummyOpForMultiDerefPass.cpp index b93a4f7688..c7e8ba9a65 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowDummyOpForMultiDerefPass.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowDummyOpForMultiDerefPass.cpp @@ -58,7 +58,7 @@ class ZLowDummyOpForMultiDerefPass ZLowDialect::getDialectNamespace()) { ValueRange operands = op->getOperands(); llvm::SmallSet processed; - for (uint64_t i = 0; i < operands.size() - 1; ++i) { + for (int64_t i = 0; i < (int64_t)operands.size() - 1; ++i) { if (processed.contains(i)) continue; for (uint64_t j = i + 1; j < operands.size(); ++j) { diff --git a/src/Conversion/KrnlToLLVM/CMakeLists.txt b/src/Conversion/KrnlToLLVM/CMakeLists.txt index 92948137be..7a42037e08 100644 --- a/src/Conversion/KrnlToLLVM/CMakeLists.txt +++ b/src/Conversion/KrnlToLLVM/CMakeLists.txt @@ -2,10 +2,10 @@ add_onnx_mlir_library(OMKrnlToLLVM ConvertKrnlToLLVM.cpp + KrnlGlobalOpInterface.cpp KrnlFindIndex.cpp KrnlCall.cpp KrnlEntryPoint.cpp - KrnlGlobal.cpp KrnlInstrument.cpp KrnlMemcpy.cpp KrnlNone.cpp @@ -22,6 +22,7 @@ add_onnx_mlir_library(OMKrnlToLLVM LINK_LIBS PUBLIC OMAccelerator + OMKrnlGlobalOpInterface OMSupport 
MLIRAffineToStandard MLIRArithTransforms diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp index a8d631b2d1..709ede3c34 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp @@ -211,6 +211,17 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns, // Use polynomial approximation for math.{tanh, sin, cos and exp} for better // performance. populateMathPolynomialApproximationPatterns(patterns); + // `arith.maxnumf/arith.minnumf` can be replaced with + // `llvm.intr.maxnum/llvm.intr.minnum` by + // populateArithToLLVMConversionPatterns, or with `arith.cmpf` and + // `arith.select` by populateArithExpandOpsPatterns. Which is applied for + // depends on the order in which the pattterns are applied. Currently, it + // should be replaced with `llvm.intr.maxnum/llvm.intr.minnum` because + // `arith.cmpf` and `arith.select do not work in float16 on ppc64le and cannot + // use SIMD, but currently there is no way to specify the order. From testing, + // following two line generates expected replacement We need to consider to + // specify the order, but we use this workaround for now. + arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); arith::populateArithExpandOpsPatterns(patterns); populateMathToLLVMConversionPatterns(typeConverter, patterns); populateFuncToLLVMConversionPatterns(typeConverter, patterns); @@ -219,7 +230,6 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns, if (enableParallel) { populateOpenMPToLLVMConversionPatterns(typeConverter, patterns); } - arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); krnl::populateKrnlToLLVMConversion(typeConverter, patterns, ctx, @@ -465,8 +475,8 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath, // Check constants with thresholds. 
// Do not count constants whose size is <= singleThreshold. uint64_t totalSize = 0; - SmallVector globalOfInterest; - module.walk([&](KrnlGlobalOp op) { + SmallVector globalOfInterest; + module.walk([&](KrnlGlobalOpInterface op) { // Ignore constants that are return values. bool isReturnedValue = false; for (Operation *user : op.getResult().getUsers()) { @@ -482,22 +492,23 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath, // For an unknown reason, enabling constants of bool caused segfault in the // IBM granite.20B model (The model with KV cache) at 1265 input tokens. // See issue https://github.com/onnx/onnx-mlir/issues/2713. - if (llvm::cast(op->getResult(0).getType()) + if (llvm::cast(op.getResult().getType()) .getElementType() .isInteger(1)) return WalkResult::advance(); // Get raw data from DenseElementsAttr or DenseResourceElementsAttr. - ArrayRef rawData = getRawData(op); - if (rawData.empty()) - return WalkResult::advance(); - - auto valueAttr = mlir::cast(op.getValue().value()); - if (valueAttr.isSplat() || rawData.size() <= singleThreshold) + uint64_t bufferSize = op.getBufferSize(); + if (bufferSize <= singleThreshold) return WalkResult::advance(); + if (op.getValueAttr()) { + auto valueAttr = mlir::cast(op.getValue().value()); + if (valueAttr.isSplat()) + return WalkResult::advance(); + } globalOfInterest.emplace_back(op); - totalSize += rawData.size(); + totalSize += bufferSize; return WalkResult::advance(); }); // Do not use file if the total size of satisfied constants is <= @@ -507,15 +518,16 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath, // Sort constants in the non-descending order of alignment values. // Non-alignment is the smallest value (-1), the others are positive. 
- llvm::sort(globalOfInterest, [&](KrnlGlobalOp left, KrnlGlobalOp right) { - int64_t leftAlign = -1; - int64_t rightAlign = -1; - if (left.getAlignment().has_value()) - leftAlign = left.getAlignment().value(); - if (right.getAlignment().has_value()) - rightAlign = right.getAlignment().value(); - return (leftAlign < rightAlign); - }); + llvm::sort(globalOfInterest, + [&](KrnlGlobalOpInterface left, KrnlGlobalOpInterface right) { + int64_t leftAlign = -1; + int64_t rightAlign = -1; + if (left.getAlignment().has_value()) + leftAlign = left.getAlignment().value(); + if (right.getAlignment().has_value()) + rightAlign = right.getAlignment().value(); + return (leftAlign < rightAlign); + }); // Store each constant into single file. // Constants with the highest alignment will be packed first in the file. @@ -525,8 +537,8 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath, std::ofstream outfile(filepath, std::ios::app | std::ios::binary); uint64_t totalConstSize = 0; for (int64_t i = globalOfInterest.size() - 1; i >= 0; --i) { - KrnlGlobalOp op = globalOfInterest[i]; - ArrayRef rawData = getRawData(op); + KrnlGlobalOpInterface op = globalOfInterest[i]; + ArrayRef rawData = op.getBuffer(); // Get alignment. int64_t alignment = -1; @@ -544,11 +556,11 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath, } op.setOffsetAttr(b.getI64IntegerAttr(totalConstSize)); - op.removeValueAttr(); outfile.write(rawData.data(), rawData.size()); totalConstSize += rawData.size(); + op.removeValueAttr(); + op.freeBuffer(rawData); } - // No constant statisfying thresholds, do not store constants to file. 
if (totalConstSize == 0) return false; @@ -960,7 +972,8 @@ void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter, verifyInputTensors); krnl::populateLoweringKrnlCallOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlFindIndexOpPattern(typeConverter, patterns, ctx); - krnl::populateLoweringKrnlGlobalOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlGlobalOpInterfacePattern( + typeConverter, patterns, ctx); krnl::populateLoweringKrnlInstrumentOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlMemcpyOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlPrintOpPattern(typeConverter, patterns, ctx); diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp index 2309871db4..ed5258516e 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp @@ -68,8 +68,9 @@ void populateLoweringKrnlFindIndexOpPattern( mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); -void populateLoweringKrnlGlobalOpPattern(mlir::LLVMTypeConverter &typeConverter, - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateLoweringKrnlGlobalOpInterfacePattern( + mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, + mlir::MLIRContext *ctx); void populateLoweringKrnlInstrumentOpPattern( mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, diff --git a/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp b/src/Conversion/KrnlToLLVM/KrnlGlobalOpInterface.cpp similarity index 66% rename from src/Conversion/KrnlToLLVM/KrnlGlobal.cpp rename to src/Conversion/KrnlToLLVM/KrnlGlobalOpInterface.cpp index 1c13787ac0..1e579f5595 100644 --- a/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlGlobalOpInterface.cpp @@ -2,13 +2,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -//===------ KrnlGlobal.cpp - Lower 
KrnlGlobalOp ---------------------------===// +//===------ KrnlGlobalOpInterface.cpp - Lower KrnlGlobalOpInterface -------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // -// This file lowers the KrnlGlobalOp operator. +// This file lowers the KrnlGlobalOpInterface. // //===----------------------------------------------------------------------===// @@ -35,33 +35,39 @@ namespace krnl { /// This variable is initizalied inside ConvertKrnlToLLVMPass. extern std::string EXTERNAL_CONSTANT_PREFIX; -class KrnlGlobalOpLowering : public ConvertToLLVMPattern { +class KrnlGlobalOpInterfaceLowering + : public OpInterfaceConversionPattern { + public: - explicit KrnlGlobalOpLowering( + using OpInterfaceConversionPattern< + KrnlGlobalOpInterface>::OpInterfaceConversionPattern; + + explicit KrnlGlobalOpInterfaceLowering( LLVMTypeConverter &typeConverter, MLIRContext *context) - : ConvertToLLVMPattern( - KrnlGlobalOp::getOperationName(), context, typeConverter) {} + : OpInterfaceConversionPattern(typeConverter, context) {} - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + LogicalResult matchAndRewrite(KrnlGlobalOpInterface op, + ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto krnlGlobalOp = llvm::dyn_cast(op); - Location loc = krnlGlobalOp.getLoc(); - MLIRContext *context = krnlGlobalOp.getContext(); + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); MultiDialectBuilder create(rewriter, loc); + const LLVMTypeConverter *llvmTypeConverter = + static_cast(getTypeConverter()); // Basic type. Type llvmI8Ty = IntegerType::get(context, 8); Type llvmI8PtrTy = getPointerType(context, llvmI8Ty); // The element type of the array. 
- const Type type = op->getResult(0).getType(); + const Type type = op.getResult().getType(); const MemRefType memRefTy = mlir::cast(type); const Type constantElementType = - typeConverter->convertType(memRefTy.getElementType()); + llvmTypeConverter->convertType(memRefTy.getElementType()); Type globalType = constantElementType; // The llvm type of the global (example: [2 x [8 x float]]). - const auto shape = mlir::dyn_cast(krnlGlobalOp.getShape()); + const auto shape = mlir::dyn_cast(op.getShape()); if (shape.empty()) globalType = LLVM::LLVMArrayType::get(mlir::cast(globalType), 1); else { @@ -74,16 +80,17 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { LLVM::GlobalOp global; // Pointer to the raw data of the global. Value dataPtr; + // Update value attribute if needed. + op.updateValueAttr(); - if (krnlGlobalOp.getValue().has_value()) { - auto value = krnlGlobalOp.getValue().value(); + if (op.getValue().has_value()) { + auto value = op.getValue().value(); TypeSwitch(value) .Case([&](DenseResourceElementsAttr attr) { - global = - lowerDenseResourceConstant(krnlGlobalOp, globalType, rewriter); + global = lowerDenseResourceConstant(op, globalType, rewriter); }) .Case([&](DenseElementsAttr attr) { - global = lowerDenseConstant(krnlGlobalOp, globalType, rewriter); + global = lowerDenseConstant(op, globalType, rewriter); }) .Default([&](Attribute attr) { llvm_unreachable("Unsupported attribute type"); @@ -91,15 +98,14 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { dataPtr = create.llvm.addressOf(global); } else { // Data are stored on files. - global = lowerGlobalOpWithExternalFiles(krnlGlobalOp, rewriter); + global = lowerGlobalOpWithExternalFiles(op, rewriter); dataPtr = create.llvm.load(llvmI8PtrTy, create.llvm.addressOf(global)); } // Set the global alignment based on the alignment attribute if it exists, // otherwise use the module datalayout info. 
- krnl::setAlignment(global, krnlGlobalOp.getAlignmentAttr(), - krnlGlobalOp->getParentOfType(), rewriter, - *getTypeConverter()); + krnl::setAlignment(global, op.getAlignmentAttr(), + op->getParentOfType(), rewriter, *llvmTypeConverter); // Prepare data to be inserted into a MemRefDescriptor (a struct). MemRefDescriptor memRefDescr = @@ -115,31 +121,32 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { return mlir::cast(a.getValue()[i]).getInt(); } - LLVM::GlobalOp lowerDenseResourceConstant(KrnlGlobalOp &krnlGlobalOp, - Type globalType, ConversionPatternRewriter &rewriter) const { - assert(krnlGlobalOp.getValue().has_value() && - "Expecting KrnlGlobalOp with a valid value"); - assert( - mlir::isa(krnlGlobalOp.getValue().value()) && - "Expecting a global with an dense resource elements attribute"); - - MLIRContext *context = krnlGlobalOp.getContext(); - Location loc = krnlGlobalOp.getLoc(); - ModuleOp module = krnlGlobalOp->getParentOfType(); + LLVM::GlobalOp lowerDenseResourceConstant( + KrnlGlobalOpInterface &globalOpInterface, Type globalType, + ConversionPatternRewriter &rewriter) const { + assert(globalOpInterface.getValue().has_value() && + "Expecting KrnlGlobalOpInterface with a valid value"); + assert(mlir::isa( + globalOpInterface.getValue().value()) && + "Expecting a global with an dense resource elements attribute"); + + MLIRContext *context = globalOpInterface.getContext(); + Location loc = globalOpInterface.getLoc(); + ModuleOp module = globalOpInterface->getParentOfType(); MultiDialectBuilder create(rewriter, loc); OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - auto blob = - mlir::cast(krnlGlobalOp.getValue().value()) - .getRawHandle() - .getBlob(); + auto blob = mlir::cast( + globalOpInterface.getValue().value()) + .getRawHandle() + .getBlob(); assert(blob && "Expecting dense resource with a valid blob"); ArrayRef rawData = blob->getData(); // Check data size. 
- uint64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp); + uint64_t sizeInBytes = computeSizeInBytes(globalOpInterface); assert(((uint64_t)rawData.size() == sizeInBytes) && "Data size mismatch."); StringRef data(rawData.data(), rawData.size()); @@ -147,23 +154,23 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { auto llvmArrayI8Ty = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), sizeInBytes); LLVM::GlobalOp global = create.llvm.globalOp(llvmArrayI8Ty, - /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(), - llvmStringAttr); + /*isConstant=*/true, LLVM::Linkage::Internal, + globalOpInterface.getName(), llvmStringAttr); LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";); return global; } - LLVM::GlobalOp lowerDenseConstant(KrnlGlobalOp &krnlGlobalOp, Type globalType, - ConversionPatternRewriter &rewriter) const { - assert(krnlGlobalOp.getValue().has_value() && - "Expecting KrnlGlobalOp with a valid value"); - assert(mlir::isa(krnlGlobalOp.getValue().value()) && + LLVM::GlobalOp lowerDenseConstant(KrnlGlobalOpInterface &globalOpInterface, + Type globalType, ConversionPatternRewriter &rewriter) const { + assert(globalOpInterface.getValue().has_value() && + "Expecting KrnlGlobalOpInterface with a valid value"); + assert(mlir::isa(globalOpInterface.getValue().value()) && "Expecting a global with an dense elements attribute"); - Location loc = krnlGlobalOp.getLoc(); - ModuleOp module = krnlGlobalOp->getParentOfType(); - MLIRContext *context = krnlGlobalOp.getContext(); + Location loc = globalOpInterface.getLoc(); + ModuleOp module = globalOpInterface->getParentOfType(); + MLIRContext *context = globalOpInterface.getContext(); MultiDialectBuilder create(rewriter, loc); Type llvmI8Ty = IntegerType::get(context, 8); @@ -172,9 +179,9 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { rewriter.setInsertionPointToStart(module.getBody()); DenseElementsAttr denseAttr = - mlir::cast(krnlGlobalOp.getValue().value()); + 
mlir::cast(globalOpInterface.getValue().value()); - uint64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp); + uint64_t sizeInBytes = computeSizeInBytes(globalOpInterface); LLVM::GlobalOp global; if (!(mlir::isa(denseAttr.getElementType())) && !(denseAttr.getElementType().isInteger(1)) && (!denseAttr.isSplat()) && @@ -188,15 +195,15 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { StringRef data(rawData.data(), rawData.size()); StringAttr llvmStringAttr = StringAttr::get(context, data); global = create.llvm.globalOp(llvmArrayI8Ty, - /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(), - llvmStringAttr); + /*isConstant=*/true, LLVM::Linkage::Internal, + globalOpInterface.getName(), llvmStringAttr); } else { if (mlir::isa(denseAttr.getElementType())) - global = lowerStringLiteral(krnlGlobalOp, globalType, rewriter); + global = lowerStringLiteral(globalOpInterface, globalType, rewriter); else global = create.llvm.globalOp(globalType, /*isConstant=*/true, LLVM::Linkage::Internal, - krnlGlobalOp.getName(), krnlGlobalOp.getValue().value()); + globalOpInterface.getName(), globalOpInterface.getValue().value()); } LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";); @@ -204,21 +211,24 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { } LLVM::GlobalOp lowerGlobalOpWithExternalFiles( - KrnlGlobalOp &krnlGlobalOp, ConversionPatternRewriter &rewriter) const { - Location loc = krnlGlobalOp.getLoc(); - MLIRContext *context = krnlGlobalOp.getContext(); - ModuleOp module = krnlGlobalOp.getOperation()->getParentOfType(); + KrnlGlobalOpInterface &globalOpInterface, + ConversionPatternRewriter &rewriter) const { + Location loc = globalOpInterface.getLoc(); + MLIRContext *context = globalOpInterface.getContext(); + ModuleOp module = + globalOpInterface.getOperation()->getParentOfType(); MultiDialectBuilder create(rewriter, loc); Type llvmI8Ty = IntegerType::get(context, 8); Type llvmI8PtrTy = getPointerType(context, llvmI8Ty); Type 
llvmI64Ty = IntegerType::get(context, 64); - auto offset = krnlGlobalOp.getOffset(); - assert(offset.has_value() && "Missing offset value in KrnlGlobalOp"); + auto offset = globalOpInterface.getOffset(); + assert( + offset.has_value() && "Missing offset value in KrnlGlobalOpInterface"); // Data is store in `constants.bin` at offset. - std::string constantName = krnlGlobalOp.getName().str(); + std::string constantName = globalOpInterface.getName().str(); // Emit globals at the begining of the module. OpBuilder::InsertionGuard insertGuard(rewriter); @@ -246,14 +256,14 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { return global; } - uint64_t computeSizeInBytes(KrnlGlobalOp &krnlGlobalOp) const { + uint64_t computeSizeInBytes(KrnlGlobalOpInterface &globalOpInterface) const { // Compute total number of elements. - const auto shape = mlir::dyn_cast(krnlGlobalOp.getShape()); + const auto shape = mlir::dyn_cast(globalOpInterface.getShape()); uint64_t numElements = 1; for (unsigned int i = 0; i < shape.size(); ++i) numElements *= ArrayAttrIntVal(shape, i); - const auto type = krnlGlobalOp.getResult().getType(); + const auto type = globalOpInterface.getResult().getType(); const auto memRefTy = mlir::cast(type); // Special handling for bool. 
@@ -267,8 +277,9 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { MemRefDescriptor createMemRefDescriptor(Value address, MemRefType memRefType, Location loc, OpBuilder &builder) const { Type elementType = memRefType.getElementType(); - const LLVMTypeConverter &typeConverter = *getTypeConverter(); - Type llvmElemType = typeConverter.convertType(elementType); + const LLVMTypeConverter *llvmTypeConverter = + static_cast(getTypeConverter()); + Type llvmElemType = llvmTypeConverter->convertType(elementType); MLIRContext *context = builder.getContext(); MultiDialectBuilder create(builder, loc); @@ -278,21 +289,21 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { Value bitCastOp = create.llvm.bitcast(ptrType, address); // Create llvm MemRef from original MemRef and fill the data pointers. return MemRefDescriptor::fromStaticShape( - builder, loc, typeConverter, memRefType, bitCastOp); + builder, loc, *llvmTypeConverter, memRefType, bitCastOp); } - // Generate a global string for each krnlGlobalOp string value, and store + // Generate a global string for each globalOpInterface string value, and store // the address of the global strings into an array. Return the array address. 
- LLVM::GlobalOp lowerStringLiteral( - KrnlGlobalOp &krnlGlobalOp, Type globalType, OpBuilder &builder) const { - assert(mlir::isa(krnlGlobalOp.getValue().value()) && + LLVM::GlobalOp lowerStringLiteral(KrnlGlobalOpInterface &globalOpInterface, + Type globalType, OpBuilder &builder) const { + assert(mlir::isa(globalOpInterface.getValue().value()) && "Expecting a dense value"); - Location loc = krnlGlobalOp.getLoc(); + Location loc = globalOpInterface.getLoc(); MultiDialectBuilder create(builder, loc); DenseElementsAttr denseAttr = - mlir::cast(krnlGlobalOp.getValue().value()); + mlir::cast(globalOpInterface.getValue().value()); Type i8PtrType = getI8PointerType(builder.getContext()); @@ -322,14 +333,14 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { auto llvmArrayI8Ty = LLVM::LLVMArrayType::get(i8Type, totalSize); LLVM::GlobalOp globalStr = create.llvm.globalOp(llvmArrayI8Ty, /*isConstant=*/true, LLVM::Linkage::Internal, - "om.strArray." + krnlGlobalOp.getName().str(), llvmStringAttr); + "om.strArray." + globalOpInterface.getName().str(), llvmStringAttr); // Generate an LLVM GlobalOps with an initializer region containing one // block. 
auto arrayType = LLVM::LLVMArrayType::get(i8PtrType, offsets.size()); auto global = create.llvm.globalOp(arrayType, - /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(), - Attribute()); + /*isConstant=*/true, LLVM::Linkage::Internal, + globalOpInterface.getName(), Attribute()); Region ®ion = global.getInitializerRegion(); Block *block = builder.createBlock(®ion); @@ -355,9 +366,10 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { } }; -void populateLoweringKrnlGlobalOpPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, MLIRContext *ctx) { - patterns.insert(typeConverter, ctx); +void populateLoweringKrnlGlobalOpInterfacePattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); } } // namespace krnl diff --git a/src/Dialect/Krnl/CMakeLists.txt b/src/Dialect/Krnl/CMakeLists.txt index 683e4500dc..c3c7bf8991 100644 --- a/src/Dialect/Krnl/CMakeLists.txt +++ b/src/Dialect/Krnl/CMakeLists.txt @@ -18,6 +18,7 @@ add_onnx_mlir_library(OMKrnlOps DEPENDS OMKrnlIncGen OMSpecializedKernelOpInterface + OMKrnlGlobalOpInterface LINK_LIBS PUBLIC OMCompilerOptions diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td index c8220dfc53..1d5a015d29 100644 --- a/src/Dialect/Krnl/Krnl.td +++ b/src/Dialect/Krnl/Krnl.td @@ -27,6 +27,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td" include "src/Interface/SpecializedKernelOpInterface.td" +include "src/Interface/KrnlGlobalOpInterface.td" def Krnl_Dialect : Dialect { let name = "krnl"; @@ -406,7 +407,8 @@ def KrnlMemcpyOp : Op { +def KrnlGlobalOp : Op, MemRefsNormalizable]> { let summary = "Krnl global operation"; let description = [{ Operation for holding global data values. 
A global constant can have a diff --git a/src/Dialect/Krnl/KrnlOps.cpp b/src/Dialect/Krnl/KrnlOps.cpp index cec7b2d94d..bdcf2bffba 100644 --- a/src/Dialect/Krnl/KrnlOps.cpp +++ b/src/Dialect/Krnl/KrnlOps.cpp @@ -12,12 +12,14 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/Value.h" @@ -812,6 +814,34 @@ MutableOperandRange KrnlSpecializedKernel::getLoopRefs() { return getLoopsMutable(); } +ArrayRef KrnlGlobalOp::getBuffer() { + ArrayRef ret; + std::vector attrData; + if (getValueAttr()) { + int64_t sizeInBytes = getBufferSize(); + char *rawData = (char *)malloc(sizeInBytes); + auto valueAttr = getValue().value(); + getRawData(valueAttr, attrData); + memcpy(rawData, attrData.data(), sizeInBytes); + ret = llvm::ArrayRef(rawData, sizeInBytes); + } + return ret; +} + +uint64_t KrnlGlobalOp::getBufferSize() { + const Type type = getOperation()->getResults()[0].getType(); + const MemRefType memRefTy = mlir::cast(type); + auto sizeInBytes = affine::getIntOrFloatMemRefSizeInBytes(memRefTy); + return sizeInBytes.has_value() ? 
sizeInBytes.value() : 0; +} + +void KrnlGlobalOp::freeBuffer(ArrayRef rawData) { + free(const_cast(rawData.data())); + return; +} + +void KrnlGlobalOp::updateValueAttr() {} + //===----------------------------------------------------------------------===// // KrnlMatMulOp //===----------------------------------------------------------------------===// diff --git a/src/Dialect/Krnl/KrnlOps.hpp b/src/Dialect/Krnl/KrnlOps.hpp index 661fa7576d..fcf48a395d 100644 --- a/src/Dialect/Krnl/KrnlOps.hpp +++ b/src/Dialect/Krnl/KrnlOps.hpp @@ -21,6 +21,7 @@ #include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Dialect/Krnl/KrnlTypes.hpp" +#include "src/Interface/KrnlGlobalOpInterface.hpp" #include "src/Interface/SpecializedKernelOpInterface.hpp" #include "src/Dialect/Krnl/KrnlDialect.hpp.inc" diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index 520de56339..72479a53ab 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Path.h" @@ -752,6 +753,29 @@ bool hasIntegerPowerExponent(ONNXPowOp *op, int64_t &exponentValue) { return false; } +/// Get raw data from a dense attribute. 
+void getRawData(Attribute dataAttr, std::vector &data) { + TypeSwitch(dataAttr) + .Case([&](DenseElementsAttr denseAttr) { + if (!denseAttr.isSplat()) { + data = denseAttr.getRawData(); + } else { + ShapedType denseShapeType = + mlir::cast(denseAttr.getType()); + std::vector rawData = denseAttr.getRawData(); + int64_t numElements = denseShapeType.getNumElements(); + for (int i = 0; i < numElements; i++) + data.insert(data.end(), rawData.begin(), rawData.end()); + } + }) + .Case( + [&](DenseResourceElementsAttr denseResourceAttr) { + data = denseResourceAttr.getRawHandle().getBlob()->getData(); + }) + .Default( + [&](Attribute attr) { llvm_unreachable("Unsupported data type."); }); +} + //===----------------------------------------------------------------------===// // Support for ReshapeOp. //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp index b084ad5cd6..278a454313 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp @@ -261,6 +261,9 @@ bool isScalarTensor(mlir::Value v); bool hasIntegerPowerExponent(mlir::ONNXPowOp *op, int64_t &exponentValue); +/// Get raw data from a dense attribute. +void getRawData(mlir::Attribute dataAttr, std::vector &data); + //===----------------------------------------------------------------------===// // Support for dim operations. 
//===----------------------------------------------------------------------===// diff --git a/src/Interface/CMakeLists.txt b/src/Interface/CMakeLists.txt index 07a1eb6873..21b76d0f31 100644 --- a/src/Interface/CMakeLists.txt +++ b/src/Interface/CMakeLists.txt @@ -5,6 +5,7 @@ add_onnx_mlir_interface(ShapeHelperOpInterface) add_onnx_mlir_interface(ResultTypeInferenceOpInterface) add_onnx_mlir_interface(HasOnnxSubgraphOpInterface) add_onnx_mlir_interface(SpecializedKernelOpInterface) +add_onnx_mlir_interface(KrnlGlobalOpInterface) add_onnx_mlir_library(OMShapeInferenceOpInterface ShapeInferenceOpInterface.cpp @@ -61,3 +62,14 @@ add_onnx_mlir_library(OMSpecializedKernelOpInterface MLIRIR LLVMSupport ) + +add_onnx_mlir_library(OMKrnlGlobalOpInterface + KrnlGlobalOpInterface.cpp + + DEPENDS + OMKrnlGlobalOpInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRIR + LLVMSupport + ) diff --git a/src/Interface/KrnlGlobalOpInterface.cpp b/src/Interface/KrnlGlobalOpInterface.cpp new file mode 100644 index 0000000000..f54c6222c6 --- /dev/null +++ b/src/Interface/KrnlGlobalOpInterface.cpp @@ -0,0 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-------------------- KrnlGlobalOpInterface.cpp -----------------------===// +//===---------------- KrnlGlobalOp Interface Definition -------------------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains the definition of the Constant Op Interface. 
+// +//===----------------------------------------------------------------------===// + +#include "KrnlGlobalOpInterface.hpp" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// KrnlGlobal Op Interface +//===----------------------------------------------------------------------===// + +#include "src/Interface/KrnlGlobalOpInterface.cpp.inc" diff --git a/src/Interface/KrnlGlobalOpInterface.hpp b/src/Interface/KrnlGlobalOpInterface.hpp new file mode 100644 index 0000000000..c9adafca14 --- /dev/null +++ b/src/Interface/KrnlGlobalOpInterface.hpp @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-------------------- KrnlGlobalOpInterface.hpp -----------------------===// +//===---------------- KrnlGlobal Op Interface Definition ------------------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains the definition of the KrnlGlobal Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifndef ONNX_MLIR_KRNLGLOBALOP_INTERFACE_H +#define ONNX_MLIR_KRNLGLOBALOP_INTERFACE_H + +#include +#include + +#include "mlir/IR/OpDefinition.h" + +/// Include the auto-generated declarations. +#include "src/Interface/KrnlGlobalOpInterface.hpp.inc" + +#endif diff --git a/src/Interface/KrnlGlobalOpInterface.td b/src/Interface/KrnlGlobalOpInterface.td new file mode 100644 index 0000000000..71e2c1a918 --- /dev/null +++ b/src/Interface/KrnlGlobalOpInterface.td @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: Apache-2.0 + +//===-------------------- KrnlGlobalOpInterface.td ------------------------===// +//===---------------- KrnlGlobal Op Interface Definition ------------------===// +// +// Copyright 2024 The IBM Research Authors.
+// +// ============================================================================= +// +// This file contains the TableGen definition of the Constant Op +// Interface Definition. +// +//===----------------------------------------------------------------------===// + +#ifdef KRNLGLOBAL_OP_INTERFACE +#else +#define KRNLGLOBAL_OP_INTERFACE + +include "mlir/IR/OpBase.td" + +def KrnlGlobalOpInterface : OpInterface<"KrnlGlobalOpInterface"> { + let description = [{ + A KrnlGlobalOp-like operation is one that holds a global constant value. It has + `name` attribute, `shape` attribute, `offset` attribute, and `alignment` + attribute. Its content is stored in the `value` attribute, which can be + converted when retrieving. + }]; + + let methods = [ + InterfaceMethod<"Get the buffer for the constant value from value attribute. " + "If conversions are required to get the buffer, they should be " + "done in this method. The constant value is stored in a newly " + "allocated buffer. The buffer needs to be freed after use by " + "using `freeBuffer()`.", + "::mlir::ArrayRef", "getBuffer", (ins ) + >, + InterfaceMethod<"Get the size of the buffer. ", + "uint64_t", "getBufferSize", (ins ) + >, + InterfaceMethod<"Free the buffer for the constant value retrieved from value " + "attribute.", + "void", "freeBuffer", (ins "::mlir::ArrayRef": $buffer) + >, + InterfaceMethod<"Update the `value` attribute by converting existing `value` " + "attribute.
Assume to use getBuffer(), setValueAttr(), and " + "freeBuffer() in this function.", + "void", "updateValueAttr", (ins ) + >, + InterfaceMethod<"Get the value from the attribute.", + "std::optional", "getValue", (ins ) + >, + InterfaceMethod<"Get the `value` attribute.", + "Attribute", "getValueAttr", (ins ) + >, + InterfaceMethod<"Remove value attribute.", + "Attribute", "removeValueAttr", (ins ) + >, + InterfaceMethod<"Get the `alignment` attribute.", + "std::optional", "getAlignment", (ins ) + >, + InterfaceMethod<"Get the attribute for the alignment.", + "IntegerAttr", "getAlignmentAttr", (ins ) + >, + InterfaceMethod<"Get the `shape` attribute.", + "::mlir::Attribute", "getShape", (ins ) + >, + InterfaceMethod<"Get the `name` attribute.", + "::mlir::StringRef", "getName", (ins ) + >, + InterfaceMethod<"Set the offset to the attribute.", + "void", "setOffsetAttr", (ins "::mlir::IntegerAttr": $attr) + >, + InterfaceMethod<"Get the `offset` attribute.", + "std::optional", "getOffset", (ins ) + > + ]; + + let extraClassDeclaration = [{ + /// Return the single result of this op. 
+ ::mlir::Value getResult() { + return getOperation()->getResult(0); + } + }]; + + let cppNamespace = "::mlir"; +} + +#endif // KRNLGLOBAL_OP_INTERFACE diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir index 9b1bd2935d..a03c59b2f4 100644 --- a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir @@ -39,13 +39,13 @@ module { // CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [2, 3, 1, 0]} : (tensor<8x1x5x5xf32>) -> tensor<5x5x1x8xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_9_:%.+]] = "zhigh.Stick"([[VAR_8_]]) {layout = "HWCK"} : (tensor<5x5x1x8xf32>) -> tensor<5x5x1x8xf16, #zhigh.layout<{dataLayout = "HWCK"}>> -// CHECK-DAG: [[VAR_10_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<8xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK-DAG: [[VAR_10_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<[-0.161539719, -0.433835655, 0.091641359, -0.0168522168, -0.0650264397, -0.131737873, 0.0204175506, -0.121110231]> : tensor<8xf32>} : () -> tensor<8xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK: [[VAR_11_:%.+]] = "zhigh.Conv2D"([[VAR_7_]], [[VAR_9_]], [[VAR_10_]]) {act_func = "ACT_RELU", kernel_shape = [5, 5], padding_type = "SAME_PADDING", strides = [1, 1]} : (tensor<1x28x28x1xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<5x5x1x8xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<8xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<1x28x28x8xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK-DAG: [[VAR_12_:%.+]] = "zhigh.MaxPool2D"([[VAR_11_]]) {kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [2, 2]} : (tensor<1x28x28x8xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> 
tensor<1x14x14x8xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Transpose"([[VAR_1_]]) {perm = [2, 3, 1, 0]} : (tensor<16x8x5x5xf32>) -> tensor<5x5x8x16xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_14_:%.+]] = "zhigh.Stick"([[VAR_13_]]) {layout = "HWCK"} : (tensor<5x5x8x16xf32>) -> tensor<5x5x8x16xf16, #zhigh.layout<{dataLayout = "HWCK"}>> -// CHECK-DAG: [[VAR_15_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK-DAG: [[VAR_15_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<[-0.0822488219, -0.108868778, -0.141039595, -0.204869166, -0.17913565, -0.215438381, -0.133805066, -0.195724562, -0.268250644, -0.258212209, -0.0761560649, 0.0132841459, -0.00444464432, -0.414740831, -0.17879115, -0.0386558883]> : tensor<16xf32>} : () -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK: [[VAR_16_:%.+]] = "zhigh.Conv2D"([[VAR_12_]], [[VAR_14_]], [[VAR_15_]]) {act_func = "ACT_RELU", kernel_shape = [5, 5], padding_type = "SAME_PADDING", strides = [1, 1]} : (tensor<1x14x14x8xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<5x5x8x16xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<1x14x14x16xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK: [[VAR_17_:%.+]] = "zhigh.MaxPool2D"([[VAR_16_]]) {kernel_shape = [3, 3], padding_type = "VALID_PADDING", strides = [3, 3]} : (tensor<1x14x14x16xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x4x4x16xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK: [[VAR_18_:%.+]] = "zhigh.Unstick"([[VAR_17_]]) : (tensor<1x4x4x16xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x16x4x4xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir index 
2307680415..4100b41d9e 100644 --- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir @@ -478,3 +478,4 @@ func.func @test_call_zdnn_batchnorm() -> () { // CHECK-LABEL: test_call_zdnn_batchnorm // CHECK: {{.*}} = llvm.call @zdnn_batchnorm(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 } + diff --git a/test/mlir/accelerators/nnpa/conversion/lower-to-llvm-be/lit.local.cfg b/test/mlir/accelerators/nnpa/conversion/lower-to-llvm-be/lit.local.cfg new file mode 100644 index 0000000000..ac7f7ec3e6 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/lower-to-llvm-be/lit.local.cfg @@ -0,0 +1,6 @@ +if sys.byteorder == "little": + config.unsupported = True +else: + config.unsupported = False + +root = config.root diff --git a/test/mlir/accelerators/nnpa/conversion/lower-to-llvm-be/lower-all-to-llvm_be.mlir b/test/mlir/accelerators/nnpa/conversion/lower-to-llvm-be/lower-all-to-llvm_be.mlir new file mode 100644 index 0000000000..0a5c78240a --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/lower-to-llvm-be/lower-all-to-llvm_be.mlir @@ -0,0 +1,21 @@ +// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s + +// ----- + +func.func @test_stickifiedconstant() -> memref<1x1x1x1x32x64xf16> { + %0 = "zlow.stickifiedConstant"() {alignment = 4096 : i64, layout = "2D", name = "constant_stickify_0", offset = 0 : i64, shape = [1, 1, 1, 1, 32, 64], stickified = false, value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]]> : tensor<2x3xf32>} : () -> memref<1x1x1x1x32x64xf16> + return %0 : memref<1x1x1x1x32x64xf16> + + // CHECK: llvm.mlir.global internal constant @constant_stickify{addr_space = 0 : i32, alignment = 4096 : i64} + +} + +// ----- + +func.func @test_stickifiedconstant_allzero() -> memref<1x1x1x1x32x64xf16> { + %0 = "zlow.stickifiedConstant"() 
{alignment = 4096 : i64, layout = "2D", name = "constant_stickify_0", offset = 0 : i64, shape = [1, 1, 1, 1, 32, 64], stickified = true, value = dense<0.000000e+00> : tensor<1x1x1x32x64xf16>} : () -> memref<1x1x1x1x32x64xf16> + return %0 : memref<1x1x1x1x32x64xf16> + + // CHECK: llvm.mlir.global internal constant @constant_stickify_0("\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\
00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00
\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\0
0\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\
00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00
\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\0
0\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00") {addr_space = 0 : i32, alignment = 4096 : i64} + +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir index 2d2983ba07..c409d1f9fd 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir @@ -189,7 +189,7 @@ func.func @conv_same_padding_no_bias_unknown_dims(%arg0: tensor<1x32x32x3xf16, # // CHECK: krnl.store [[VAR_c1_i64_]], [[RES_1_]]{{.}}[[VAR_c4_]]{{.}} : memref<7xi64> // CHECK: krnl.store [[VAR_c32_i64_]], [[RES_1_]]{{.}}[[VAR_c5_]]{{.}} : memref<7xi64> // CHECK: krnl.store [[VAR_c32_i64_]], [[RES_1_]]{{.}}[[VAR_c6_]]{{.}} : memref<7xi64> -// CHECK: [[VAR_2_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_0", shape = [1, 1, 1, 1, 32, 64], value = dense_resource : tensor<4096xi8>} : () -> memref<1x1x1x1x32x64xf16> +// CHECK: [[VAR_2_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_0", offset = 0 : i64, shape = [1, 1, 1, 1, 32, 64], stickified = true, value = dense<0.000000e+00> : memref<1x1x1x1x32x64xf16>} : () -> memref<1x1x1x1x32x64xf16> // CHECK: "zlow.conv2d"([[PARAM_0_]], [[PARAM_1_]], [[VAR_2_]], [[RES_1_]], [[RES_]]) {act_func = "ACT_NONE", kernel_shape = [2, 2], padding_type = "SAME_PADDING", 
strides = [1, 1]} : (memref<1x32x32x3xf16, #map>, memref<2x2x3x1xf16, #map1>, memref<1x1x1x1x32x64xf16>, memref<7xi64>, memref<1x32x32x1xf16, #map>) -> () // CHECK: return [[RES_]] : memref<1x32x32x1xf16, #map> // CHECK: } diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gru.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gru.mlir index b828da2b80..c705168d4d 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gru.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gru.mlir @@ -220,8 +220,8 @@ func.func @gru_no_input_and_hidden_biases(%input : tensor // CHECK: krnl.store [[VAR_c7_i64_]], [[RES_1_]]{{.}}[[VAR_c3_]]{{.}} : memref<5xi64> // CHECK: krnl.store [[VAR_c9_i64_]], [[RES_1_]]{{.}}[[VAR_c4_]]{{.}} : memref<5xi64> -// CHECK-DAG: [[VAR_2_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_0", shape = [1, 3, 1, 1, 32, 64], value = dense_resource : tensor<12288xi8>} : () -> memref<1x3x1x1x32x64xf16> -// CHECK-DAG: [[VAR_3_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_1", shape = [1, 3, 1, 1, 32, 64], value = dense_resource : tensor<12288xi8>} : () -> memref<1x3x1x1x32x64xf16> +// CHECK-DAG: [[VAR_2_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_0", offset = 0 : i64, shape = [1, 3, 1, 1, 32, 64], stickified = true, value = dense<0.000000e+00> : memref<1x3x1x1x32x64xf16>} : () -> memref<1x3x1x1x32x64xf16> +// CHECK-DAG: [[VAR_3_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_1", offset = 0 : i64, shape = [1, 3, 1, 1, 32, 64], stickified = true, value = dense<0.000000e+00> : memref<1x3x1x1x32x64xf16>} : () -> memref<1x3x1x1x32x64xf16> // CHECK-DAG: [[VAR_dim_2_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref // CHECK-DAG: [[VAR_dim_3_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref // CHECK-NOT: separator of consecutive DAGs diff --git 
a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/lstm.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/lstm.mlir index e63d5cee97..a31d70231e 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/lstm.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/lstm.mlir @@ -357,8 +357,8 @@ func.func @lstm_no_input_and_hidden_biases(%input : tensor // CHECK: krnl.store [[VAR_c7_i64_]], [[RES_2_]]{{.}}[[VAR_c3_]]{{.}} : memref<5xi64> // CHECK: krnl.store [[VAR_c9_i64_]], [[RES_2_]]{{.}}[[VAR_c4_]]{{.}} : memref<5xi64> -// CHECK-DAG: [[VAR_2_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_0", shape = [1, 4, 1, 1, 32, 64], value = dense_resource : tensor<16384xi8>} : () -> memref<1x4x1x1x32x64xf16> -// CHECK-DAG: [[VAR_3_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_1", shape = [1, 4, 1, 1, 32, 64], value = dense_resource : tensor<16384xi8>} : () -> memref<1x4x1x1x32x64xf16> +// CHECK-DAG: [[VAR_2_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_0", offset = 0 : i64, shape = [1, 4, 1, 1, 32, 64], stickified = true, value = dense<0.000000e+00> : memref<1x4x1x1x32x64xf16>} : () -> memref<1x4x1x1x32x64xf16> +// CHECK-DAG: [[VAR_3_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_1", offset = 0 : i64, shape = [1, 4, 1, 1, 32, 64], stickified = true, value = dense<0.000000e+00> : memref<1x4x1x1x32x64xf16>} : () -> memref<1x4x1x1x32x64xf16> // CHECK-DAG: [[VAR_dim_3_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref // CHECK-DAG: [[VAR_dim_4_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref // CHECK-NOT: separator of consecutive DAGs diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir index 7bf9766d88..b0188969d0 100644 --- 
a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir @@ -2,24 +2,17 @@ module { func.func @remove_stick_2d() -> tensor<2x3xf32> { - %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>> + %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<[[0., 1., 2.], [3., 4., 5.]]> : tensor<2x3xf32>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>> %1 = "zhigh.Unstick"(%0) : (tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<2x3xf32> return %1 : tensor<2x3xf32> } } -{-# - dialect_resources: { - builtin: { - zhigh: "} - } -#-} // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 32, d0 mod 32, d1 mod 64)> -// CHECK-LABEL: func @remove_stick_2d +// CHECK-LABEL: func.func @remove_stick_2d // CHECK-SAME: () -> memref<2x3xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_0", shape = [1, 1, 1, 1, 32, 64], value = dense_resource : tensor<4096xi8>} : () -> memref<2x3xf16, [[MAP_0_]]> +// CHECK-DAG: [[VAR_0_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, layout = "2D", name = "constant_stickify_0", offset = 0 : i64, shape = [1, 1, 1, 1, 32, 64], stickified = false, value = dense<{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}> : tensor<2x3xf32>} : () -> memref<2x3xf16, #map> // CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x3xf32> @@ -27,8 +20,3 @@ module { // CHECK: return [[RES_]] : memref<2x3xf32> // CHECK: } -// CHECK: dialect_resources: { -// CHECK-NEXT: builtin: { -// CHECK-NEXT: zhigh: 
"0x0100000000003E004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000041004200428000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/test/mlir/accelerators/nnpa/driver/ccfd.mlir b/test/mlir/accelerators/nnpa/driver/ccfd.mlir index 3c66f67da7..690becf58b 100644 --- a/test/mlir/accelerators/nnpa/driver/ccfd.mlir +++ b/test/mlir/accelerators/nnpa/driver/ccfd.mlir @@ -5,15 +5,14 @@ // COM: It is the necessary condition to get the best performance. CHECK-LABEL: func.func @main_graph -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global +CHECK-DAG: zlow.stickifiedConstant +CHECK-DAG: zlow.stickifiedConstant CHECK-DAG: memref.alloc CHECK-NEXT: zlow.stick +CHECK-DAG: zlow.stickifiedConstant -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global +CHECK-DAG: zlow.stickifiedConstant +CHECK-DAG: zlow.stickifiedConstant CHECK-DAG: memref.alloc CHECK-DAG: memref.alloc CHECK-DAG: krnl.global @@ -24,12 +23,10 @@ CHECK-NEXT: zlow.lstm CHECK-NOT: zlow.stick CHECK-NOT: zlow.unstick -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global +CHECK-DAG: zlow.stickifiedConstant +CHECK-DAG: zlow.stickifiedConstant +CHECK-DAG: zlow.stickifiedConstant +CHECK-DAG: zlow.stickifiedConstant CHECK-DAG: memref.alloc CHECK-DAG: memref.alloc CHECK-DAG: krnl.global @@ -40,17 +37,17 @@ CHECK-NEXT: zlow.lstm CHECK-NOT: zlow.stick CHECK-NOT: zlow.unstick -CHECK-DAG: krnl.global +CHECK-DAG: zlow.stickifiedConstant CHECK-DAG: memref.alloc CHECK-DAG: krnl.global -CHECK-DAG: krnl.global +CHECK-DAG: zlow.stickifiedConstant CHECK-NEXT: zlow.matmul // No stick and unstick in between. 
CHECK-NOT: zlow.stick CHECK-NOT: zlow.unstick -CHECK-DAG: krnl.global +CHECK-DAG: zlow.stickifiedConstant CHECK-DAG: memref.alloc CHECK-DAG: krnl.global CHECK-NEXT: zlow.add diff --git a/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir b/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir index 863efd1ee4..11ccc619a0 100644 --- a/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir +++ b/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir @@ -14,6 +14,6 @@ func.func @test_matmul_add_add(%arg0: tensor, %arg1: tensor<768x768 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<768x768xf32>) -> tensor { // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor) -> tensor> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<768x768xf32>) -> tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>> -// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<49152xi8>} : () -> tensor<768xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<5.000000e+00> : tensor<768xf32>} : () -> tensor<768xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<768xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor> } diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir index 609cab1aec..27c53c501b 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir @@ -9,16 +9,11 @@ func.func @remove_stick_1d() -> 
tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}> %res = "zhigh.Stick"(%inp) {layout = "1D"} : (tensor<6xf32>) -> tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}>> return %res : tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}>> - // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<6xf32>} : () -> tensor<6xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "} - // CHECK-NEXT: } } // ----- @@ -31,16 +26,10 @@ func.func @remove_stick_2d() -> tensor<2x3xf32> { %res = "zhigh.Unstick"(%st) {layout = "2D"} : (tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<2x3xf32> return %res : tensor<2x3xf32> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}> : tensor<2x3xf32>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -52,16 +41,10 @@ func.func @remove_stick_2ds() -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D %res = "zhigh.Stick"(%inp) {layout = "2DS"} : (tensor<2x3xf32>) -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2DS"}>> return %res : tensor<2x3xf16, #zhigh.layout<{dataLayout = 
"2DS"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<8192xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2DS"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}> : tensor<2x3xf32>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2DS"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "0x} - // CHECK-NEXT: } } // ----- @@ -73,16 +56,10 @@ func.func @remove_stick_3d() -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3 %res = "zhigh.Stick"(%inp) {layout = "3D"} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3D"}>> return %res : tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3D"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3D"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<{{.}}{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}> : tensor<1x2x3xf32>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3D"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -94,16 +71,10 @@ func.func @remove_stick_3ds() -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = " %res = "zhigh.Stick"(%inp) {layout = "3DS"} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3DS"}>> return %res : tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3DS"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : 
tensor<4096xi8>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3DS"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<{{.}}{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}> : tensor<1x2x3xf32>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -115,16 +86,10 @@ func.func @remove_stick_4d() -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = %res = "zhigh.Stick"(%inp) {layout = "4D"} : (tensor<1x1x2x3xf32>) -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "4D"}>> return %res : tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "4D"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "4D"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<{{.}}[{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}]> : tensor<1x1x2x3xf32>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "4D"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -136,16 +101,10 @@ func.func @remove_stick_nhwc() -> tensor<1x2x3x1xf16, #zhigh.layout<{dataLayout %res = "zhigh.Stick"(%inp) {layout = "NHWC"} : (tensor<1x1x2x3xf32>) -> tensor<1x2x3x1xf16, #zhigh.layout<{dataLayout = "NHWC"}>> return %res : tensor<1x2x3x1xf16, #zhigh.layout<{dataLayout = "NHWC"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<8192xi8>} : () -> tensor<1x2x3x1xf16, 
#zhigh.layout<{dataLayout = "NHWC"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<{{.}}[{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}]> : tensor<1x1x2x3xf32>} : () -> tensor<1x2x3x1xf16, #zhigh.layout<{dataLayout = "NHWC"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "0x} - // CHECK-NEXT: } } // ----- @@ -157,16 +116,10 @@ func.func @remove_stick_nchw() -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout %res = "zhigh.Stick"(%inp) {layout = "NCHW"} : (tensor<1x1x2x3xf32>) -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "NCHW"}>> return %res : tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "NCHW"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<8192xi8>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "NCHW"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<{{.}}[{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}]> : tensor<1x1x2x3xf32>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "NCHW"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "0x} - // CHECK-NEXT: } } // ----- @@ -178,16 +131,10 @@ func.func @remove_stick_cnnk_hwck() -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLa %res = "zhigh.Stick"(%inp) {layout = "HWCK"} : (tensor<1x1x2x3xf32>) -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "HWCK"}>> return %res : tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "HWCK"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "HWCK"}>> + // CHECK-NEXT: 
"zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<{{.}}[{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e+00]{{.}}{{.}}]> : tensor<1x1x2x3xf32>} : () -> tensor<1x1x2x3xf16, #zhigh.layout<{dataLayout = "HWCK"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "" - // CHECK-NEXT: } - // CHECK-NEXT: } } // ----- @@ -202,7 +149,7 @@ func.func @remove_stick_zrh_2d() -> tensor<2x3xf16, #zhigh.layout<{dataLayout = %res = "zhigh.StickForGRU"(%z, %r, %h) : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "ZRH"}>> return %res : tensor<2x3xf16, #zhigh.layout<{dataLayout = "ZRH"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<24576xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "ZRH"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = true, value = dense_resource : tensor<24576xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "ZRH"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.StickForGRU" @@ -226,7 +173,7 @@ func.func @remove_stick_zrh_3d() -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout %res = "zhigh.StickForGRU"(%z, %r, %h) : (tensor<1x2x3xf32>, tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "ZRH"}>> return %res : tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "ZRH"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<12288xi8>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "ZRH"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = true, value = dense_resource : tensor<12288xi8>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "ZRH"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.StickForGRU" @@ 
-251,7 +198,7 @@ func.func @remove_stick_fico_2d() -> tensor<2x3xf16, #zhigh.layout<{dataLayout = %res = "zhigh.StickForLSTM"(%f, %i, %c, %o) : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "FICO"}>> return %res : tensor<2x3xf16, #zhigh.layout<{dataLayout = "FICO"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<32768xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "FICO"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = true, value = dense_resource : tensor<32768xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "FICO"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.StickForLSTM" @@ -276,7 +223,7 @@ func.func @remove_stick_fico_3d() -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout %res = "zhigh.StickForLSTM"(%f, %i, %c, %o) : (tensor<1x2x3xf32>, tensor<1x2x3xf32>, tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "FICO"}>> return %res : tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "FICO"}>> - // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<16384xi8>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "FICO"}>> + // CHECK-NEXT: "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = true, value = dense_resource : tensor<16384xi8>} : () -> tensor<1x2x3xf16, #zhigh.layout<{dataLayout = "FICO"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.StickForLSTM" @@ -297,16 +244,10 @@ func.func @out_of_range_minimum() -> tensor<1xf16, #zhigh.layout<{dataLayout = " %res = "zhigh.Stick"(%inp) {layout = "1D"} : (tensor<1xf32>) -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> return %res : tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> - // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1xf16, 
#zhigh.layout<{dataLayout = "1D"}>> + // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<-3.402820e+38> : tensor<1xf32>} : () -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "} - // CHECK-NEXT: } } // ----- @@ -317,14 +258,9 @@ func.func @out_of_range_maximum() -> tensor<1xf16, #zhigh.layout<{dataLayout = " %res = "zhigh.Stick"(%inp) {layout = "1D"} : (tensor<1xf32>) -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> return %res : tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> - // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<4096xi8>} : () -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> + // CHECK-NEXT: %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, stickified = false, value = dense<3.402820e+38> : tensor<1xf32>} : () -> tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>> // CHECK-NOT: "onnx.Constant" // CHECK-NOT: "zhigh.Stick" - // CHECK: dialect_resources: { - // CHECK-NEXT: builtin: { - // CHECK-NEXT: zhigh: "0x} - // CHECK-NEXT: } }