Skip to content

Commit

Permalink
[NNPA] Memory reduction of stickified constant by stickifying at file…
Browse files Browse the repository at this point in the history
… writing (onnx#2917)

Reduce memory usage for NNPA compilation. Change to set original data, not stickified data, in ZHighConstPropagationPass. Then, in the KrnlToLLVMPass, stickfied data is created and stored in the file, and deleted after writing into the file.
  • Loading branch information
imaihal authored Nov 5, 2024
1 parent f935e3d commit 33b466e
Show file tree
Hide file tree
Showing 37 changed files with 707 additions and 336 deletions.
133 changes: 88 additions & 45 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,27 +190,47 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));

// Create a ZHighStickifiedConstantOp.
ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
/*value=*/nullptr,
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// Use an dense resource attribute to store stickified data.
// Attribute type: tensor<sizeInBytes x i8>
int64_t sizeInBytes =
affine::getIntOrFloatMemRefSizeInBytes(resType).value();
char *rawData = static_cast<char *>(malloc(sizeInBytes));
assert(rawData && "failed to allocate memory for stickified data");
memset(rawData, 0, sizeInBytes);
DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get(
RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
stickifiedConstant.getOperation()
->getDialect()
->getNamespace(), // use the dialect as the blob "hint"
HeapAsmResourceBlob::allocateAndCopyWithAlign(
llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
stickifiedConstant.setValueAttr(valueAttr);
free(rawData);

// Keep previous implementation about generating stickified data at
// ZHighConstPropagationPass. To use this, comment in and set directive "
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Set zero in value attribute as DenseResourceElementsAttribute.
// ZHighStickifiedConstantOp stickifiedConstant =
// rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
// /*stickified=*/rewriter.getBoolAttr(true),
// /*value=*/nullptr,
// /*alignment=*/rewriter.getI64IntegerAttr(4096));
//
// // Use an dense resource attribute to store stickified data.
// // Attribute type: tensor<sizeInBytes x i8>
// int64_t sizeInBytes =
// affine::getIntOrFloatMemRefSizeInBytes(resType).value();
// char *rawData = static_cast<char *>(malloc(sizeInBytes));
// assert(rawData && "failed to allocate memory for stickified data");
// memset(rawData, 0, sizeInBytes);
// DenseResourceElementsAttr valueAttr =
// DenseUI8ResourceElementsAttr::get(
// RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
// stickifiedConstant.getOperation()
// ->getDialect()
// ->getNamespace(), // use the dialect as the blob "hint"
// HeapAsmResourceBlob::allocateAndCopyWithAlign(
// llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
// stickifiedConstant.setValueAttr(valueAttr);
// free(rawData);
// #else

// Set zero in value attribute as SplatElementsAttr.
FloatAttr floatZero = rewriter.getFloatAttr(resType.getElementType(), 0.0);
ZHighStickifiedConstantOp stickifiedConstant = rewriter.create<
ZHighStickifiedConstantOp>(loc, resType,
/*stickified=*/rewriter.getBoolAttr(true),
/*value=*/SplatElementsAttr::get(cast<ShapedType>(resType), floatZero),
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN

res = stickifiedConstant.getResult();
} else {
Expand Down Expand Up @@ -686,7 +706,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
};

//===----------------------------------------------------------------------===//
// Lower ZHigh Stickified Constant to KrnlGlobal
// Lower ZHigh Stickified Constant to ZLow Stickified Constant
//===----------------------------------------------------------------------===//

struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
Expand All @@ -699,7 +719,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
ZHighStickifiedConstantOp stickifiedConstOp =
ZHighStickifiedConstantOp zhighStickifiedConstOp =
llvm::dyn_cast<ZHighStickifiedConstantOp>(op);

// Convert ZTensor type to MemRefType.
Expand All @@ -713,36 +733,59 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();

// Get dense resource attribute.
auto blob = mlir::cast<DenseResourceElementsAttr>(
stickifiedConstOp.getValue().value())
.getRawHandle()
.getBlob();
assert(blob && "Expecting dense resource with a valid blob");
ArrayRef<char> data = blob->getData();

// Validate the stickified tensor.
int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
memRefSizeInBytes *= normalizedType.getNumElements();
assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");

// Create a KrnlGlobalOp.
KrnlGlobalOp constantGlobal =
rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
// Create ZLowStickifiedConstantOp.
StringAttr layout =
getZTensorLayoutAttr(rewriter, *op->result_type_begin());

// Keep previous implementation about generating stickified data at
// ZHighConstPropagationPass. To use this, comment in and set directive "
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Lower to KrnlGlobalOp
// // Get dense resource attribute.
// auto blob = mlir::cast<DenseResourceElementsAttr>(
// zhighStickifiedConstOp.getValue().value())
// .getRawHandle()
// .getBlob();
// assert(blob && "Expecting dense resource with a valid blob");
// ArrayRef<char> data = blob->getData();
// // Validate the stickified tensor.
// int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
// memRefSizeInBytes *= normalizedType.getNumElements();
// assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
// "The stickified tensor's buffer size and MemRef's size
// mismatched");
// // Create a KrnlGlobalOp.
// KrnlGlobalOp constantOp =
// rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
// /*shape=*/
// rewriter.getI64ArrayAttr(normalizedShape),
// /*name=*/
// rewriter.getStringAttr(
// "constant_stickify_" + std::to_string(constantID)),
// /*value=*/zhighStickifiedConstOp.getValueAttr(),
// /*offset=*/nullptr,
// /*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #else
ZLowStickifiedConstantOp constantOp =
rewriter.create<ZLowStickifiedConstantOp>(loc,
mlir::cast<MemRefType>(zMemRefType.value),
/*shape=*/
rewriter.getI64ArrayAttr(normalizedShape),
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*value=*/stickifiedConstOp.getValueAttr(),
/*offset=*/nullptr,
/*alignment=*/stickifiedConstOp.getAlignmentAttr());

/*stickified=*/zhighStickifiedConstOp.getStickifiedAttr(),
/*value=*/zhighStickifiedConstOp.getValueAttr(),
/*layout=*/layout,
/*offset=*/rewriter.getI64IntegerAttr(0),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
// Increment constant ID:
constantID++;

rewriter.replaceOp(op, constantGlobal.getResult());
rewriter.replaceOp(op, constantOp.getResult());
return success();
}
};
Expand Down
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ add_onnx_mlir_library(OMZHighOps
OMONNXOps # Use ONNXShapeHelper
OMLayoutHelper
OMShapeHelperOpInterface
OMStickify
OMNNPACompilerOptions
MLIRIR

Expand Down
5 changes: 4 additions & 1 deletion src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,14 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
let summary = "ZHigh Stickified Constant operation";
let description = [{
This operator produces a constant tensor to store stickified data.
`value` attribute has original constant or stickified constant.
`stickified` attribute indicates the `value` is already stickified or not.
Stickified data is opaque and must be 4K-aligned. One who produces
the stickified data must make sure its size in bytes consistent with
the output tensor's size.
}];
let arguments = (ins OptionalAttr<AnyAttr>:$value,
let arguments = (ins BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs AnyZTensor:$output);
}
Expand Down
51 changes: 50 additions & 1 deletion src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp"
#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
Expand Down Expand Up @@ -482,5 +481,55 @@ IntegerAttr getDefaultSaturation(PatternRewriter &rewriter) {
return IntegerAttr();
}

/// MLIR type to zDNN type.
zdnn_data_types mlirTypeToZDNNType(Type elementType) {
if (mlir::isa<FloatType>(elementType)) {
FloatType floatTy = mlir::cast<FloatType>(elementType);
if (floatTy.getWidth() == 16) {
return FP16;
} else if (floatTy.getWidth() == 32) {
return FP32;
} else
llvm_unreachable("Unsupported data type.");
} else
llvm_unreachable("Unsupported data type.");
}

/// Get stickified data from denseElementAttribute
ArrayRef<char> getStickifiedDataOfDenseElemAttr(
DenseElementsAttr denseAttr, StringAttr layout) {
ArrayRef<int64_t> shape = denseAttr.getType().getShape();
Type elementType = denseAttr.getType().getElementType();
int rank = shape.size();
// Read attributes's raw data.
std::vector<char> attrData;
getRawData(denseAttr, attrData);
// Call stickify.
zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc;
// pre-transformed desc.
zdnn_data_layouts zDNNLayout =
convertLayoutAttrToZDNNDataLayout(rank, layout);
// If zDNNLayout is NHWC, we stickify directly from NCHW.
if (zDNNLayout == ZDNN_NHWC)
zDNNLayout = ZDNN_NCHW;
zdnn_data_types zDNNType = onnx_mlir::zhigh::mlirTypeToZDNNType(elementType);
set_info_pre_transformed_desc(&pre_tfrmd_desc, zDNNLayout, zDNNType, shape);
// transformed desc.
zdnn_status status = generate_transformed_desc(&pre_tfrmd_desc, &tfrmd_desc);
assert(status == ZDNN_OK);
// Stick data using the software stickify.
zdnn_ztensor ztensor;
init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor);
status = allochelper_ztensor_alloc(&ztensor);
assert(status == ZDNN_OK);
status = stickify(&ztensor, attrData.data());
assert(status == ZDNN_OK);
int64_t sizeInBytes = ztensor.buffer_size;
char *rawData = (char *)malloc(sizeInBytes);
memcpy(rawData, ztensor.buffer, sizeInBytes);
allochelper_ztensor_free(&ztensor);
return llvm::ArrayRef(rawData, sizeInBytes);
}

} // namespace zhigh
} // namespace onnx_mlir
8 changes: 8 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Support/Stickify/Stickify.hpp"

namespace onnx_mlir {
namespace zhigh {
Expand Down Expand Up @@ -88,6 +89,13 @@ bool hasNNPAUse(mlir::Value v);
/// Get saturation settings.
mlir::IntegerAttr getDefaultSaturation(mlir::PatternRewriter &rewriter);

/// MLIR type to zDNN type.
zdnn_data_types mlirTypeToZDNNType(mlir::Type elementType);

/// Get stickified data from denseElementAttribute
mlir::ArrayRef<char> getStickifiedDataOfDenseElemAttr(
mlir::DenseElementsAttr denseAttr, mlir::StringAttr layout);

} // namespace zhigh
} // namespace onnx_mlir
#endif
5 changes: 5 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZLow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ add_onnx_mlir_library(OMZLowOps
DEPENDS
OMZLowIncGen
OMONNXZLowCombineIncGen
OMKrnlGlobalOpInterface

LINK_LIBS PUBLIC
MLIRIR
OMMlirDialects
OMZHighOps

ACCEL_INCLUDE_DIRS PRIVATE
${NNPA_INCLUDE_PATH}
)
17 changes: 17 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def ZMemRef : MemRefOf<[DLF16]>;
//===----------------------------------------------------------------------===//

include "mlir/Interfaces/SideEffectInterfaces.td"
include "src/Interface/KrnlGlobalOpInterface.td"

def ZLowAddOp:ZLow_Op<"add", [MemRefsNormalizable,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Expand Down Expand Up @@ -547,4 +548,20 @@ def ZLowConvertF32ToDLF16VectorOp:ZLow_Op<"vec_f32_to_dlf16", [Pure]> {
];
}

def ZLowStickifiedConstantOp:ZLow_Op<"stickifiedConstant", [MemRefsNormalizable,
DeclareOpInterfaceMethods<KrnlGlobalOpInterface>]> {
let summary = "ZLow Stickified Constant operation.";
let description = [{

}];
let arguments = (ins AnyAttr:$shape,
StrAttr:$name,
BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
OptionalAttr<StrAttr>:$layout,
OptionalAttr<I64Attr>:$offset,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs ZMemRef:$output);
}

#endif // ZLOW_OPS
Loading

0 comments on commit 33b466e

Please sign in to comment.