Skip to content

Commit 33b466e

Browse files
authored
[NNPA] Memory reduction of stickified constant by stickifying at file writing (onnx#2917)
Reduce memory usage for NNPA compilation. Change to set original data, not stickified data, in ZHighConstPropagationPass. Then, in the KrnlToLLVMPass, stickfied data is created and stored in the file, and deleted after writing into the file.
1 parent f935e3d commit 33b466e

37 files changed

+707
-336
lines changed

src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp

+88-45
Original file line numberDiff line numberDiff line change
@@ -190,27 +190,47 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
190190
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
191191

192192
// Create a ZHighStickifiedConstantOp.
193-
ZHighStickifiedConstantOp stickifiedConstant =
194-
rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
195-
/*value=*/nullptr,
196-
/*alignment=*/rewriter.getI64IntegerAttr(4096));
197-
198-
// Use an dense resource attribute to store stickified data.
199-
// Attribute type: tensor<sizeInBytes x i8>
200-
int64_t sizeInBytes =
201-
affine::getIntOrFloatMemRefSizeInBytes(resType).value();
202-
char *rawData = static_cast<char *>(malloc(sizeInBytes));
203-
assert(rawData && "failed to allocate memory for stickified data");
204-
memset(rawData, 0, sizeInBytes);
205-
DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get(
206-
RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
207-
stickifiedConstant.getOperation()
208-
->getDialect()
209-
->getNamespace(), // use the dialect as the blob "hint"
210-
HeapAsmResourceBlob::allocateAndCopyWithAlign(
211-
llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
212-
stickifiedConstant.setValueAttr(valueAttr);
213-
free(rawData);
193+
194+
// Keep previous implementation about generating stickified data at
195+
// ZHighConstPropagationPass. To use this, comment in and set directive "
196+
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
197+
//
198+
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
199+
// // Set zero in value attribute as DenseResourceElementsAttribute.
200+
// ZHighStickifiedConstantOp stickifiedConstant =
201+
// rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
202+
// /*stickified=*/rewriter.getBoolAttr(true),
203+
// /*value=*/nullptr,
204+
// /*alignment=*/rewriter.getI64IntegerAttr(4096));
205+
//
206+
// // Use an dense resource attribute to store stickified data.
207+
// // Attribute type: tensor<sizeInBytes x i8>
208+
// int64_t sizeInBytes =
209+
// affine::getIntOrFloatMemRefSizeInBytes(resType).value();
210+
// char *rawData = static_cast<char *>(malloc(sizeInBytes));
211+
// assert(rawData && "failed to allocate memory for stickified data");
212+
// memset(rawData, 0, sizeInBytes);
213+
// DenseResourceElementsAttr valueAttr =
214+
// DenseUI8ResourceElementsAttr::get(
215+
// RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
216+
// stickifiedConstant.getOperation()
217+
// ->getDialect()
218+
// ->getNamespace(), // use the dialect as the blob "hint"
219+
// HeapAsmResourceBlob::allocateAndCopyWithAlign(
220+
// llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
221+
// stickifiedConstant.setValueAttr(valueAttr);
222+
// free(rawData);
223+
// #else
224+
225+
// Set zero in value attribute as SplatElementsAttr.
226+
FloatAttr floatZero = rewriter.getFloatAttr(resType.getElementType(), 0.0);
227+
ZHighStickifiedConstantOp stickifiedConstant = rewriter.create<
228+
ZHighStickifiedConstantOp>(loc, resType,
229+
/*stickified=*/rewriter.getBoolAttr(true),
230+
/*value=*/SplatElementsAttr::get(cast<ShapedType>(resType), floatZero),
231+
/*alignment=*/rewriter.getI64IntegerAttr(4096));
232+
233+
// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
214234

215235
res = stickifiedConstant.getResult();
216236
} else {
@@ -686,7 +706,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
686706
};
687707

688708
//===----------------------------------------------------------------------===//
689-
// Lower ZHigh Stickified Constant to KrnlGlobal
709+
// Lower ZHigh Stickified Constant to ZLow Stickified Constant
690710
//===----------------------------------------------------------------------===//
691711

692712
struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
@@ -699,7 +719,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
699719
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
700720
ConversionPatternRewriter &rewriter) const final {
701721
Location loc = op->getLoc();
702-
ZHighStickifiedConstantOp stickifiedConstOp =
722+
ZHighStickifiedConstantOp zhighStickifiedConstOp =
703723
llvm::dyn_cast<ZHighStickifiedConstantOp>(op);
704724

705725
// Convert ZTensor type to MemRefType.
@@ -713,36 +733,59 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
713733
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
714734
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();
715735

716-
// Get dense resource attribute.
717-
auto blob = mlir::cast<DenseResourceElementsAttr>(
718-
stickifiedConstOp.getValue().value())
719-
.getRawHandle()
720-
.getBlob();
721-
assert(blob && "Expecting dense resource with a valid blob");
722-
ArrayRef<char> data = blob->getData();
723-
724-
// Validate the stickified tensor.
725-
int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
726-
memRefSizeInBytes *= normalizedType.getNumElements();
727-
assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
728-
"The stickified tensor's buffer size and MemRef's size mismatched");
729-
730-
// Create a KrnlGlobalOp.
731-
KrnlGlobalOp constantGlobal =
732-
rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
736+
// Create ZLowStickifiedConstantOp.
737+
StringAttr layout =
738+
getZTensorLayoutAttr(rewriter, *op->result_type_begin());
739+
740+
// Keep previous implementation about generating stickified data at
741+
// ZHighConstPropagationPass. To use this, comment in and set directive "
742+
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
743+
//
744+
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
745+
// // Lower to KrnlGlobalOp
746+
// // Get dense resource attribute.
747+
// auto blob = mlir::cast<DenseResourceElementsAttr>(
748+
// zhighStickifiedConstOp.getValue().value())
749+
// .getRawHandle()
750+
// .getBlob();
751+
// assert(blob && "Expecting dense resource with a valid blob");
752+
// ArrayRef<char> data = blob->getData();
753+
// // Validate the stickified tensor.
754+
// int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
755+
// memRefSizeInBytes *= normalizedType.getNumElements();
756+
// assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
757+
// "The stickified tensor's buffer size and MemRef's size
758+
// mismatched");
759+
// // Create a KrnlGlobalOp.
760+
// KrnlGlobalOp constantOp =
761+
// rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
762+
// /*shape=*/
763+
// rewriter.getI64ArrayAttr(normalizedShape),
764+
// /*name=*/
765+
// rewriter.getStringAttr(
766+
// "constant_stickify_" + std::to_string(constantID)),
767+
// /*value=*/zhighStickifiedConstOp.getValueAttr(),
768+
// /*offset=*/nullptr,
769+
// /*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
770+
// #else
771+
ZLowStickifiedConstantOp constantOp =
772+
rewriter.create<ZLowStickifiedConstantOp>(loc,
773+
mlir::cast<MemRefType>(zMemRefType.value),
733774
/*shape=*/
734775
rewriter.getI64ArrayAttr(normalizedShape),
735776
/*name=*/
736777
rewriter.getStringAttr(
737778
"constant_stickify_" + std::to_string(constantID)),
738-
/*value=*/stickifiedConstOp.getValueAttr(),
739-
/*offset=*/nullptr,
740-
/*alignment=*/stickifiedConstOp.getAlignmentAttr());
741-
779+
/*stickified=*/zhighStickifiedConstOp.getStickifiedAttr(),
780+
/*value=*/zhighStickifiedConstOp.getValueAttr(),
781+
/*layout=*/layout,
782+
/*offset=*/rewriter.getI64IntegerAttr(0),
783+
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
784+
// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
742785
// Increment constant ID:
743786
constantID++;
744787

745-
rewriter.replaceOp(op, constantGlobal.getResult());
788+
rewriter.replaceOp(op, constantOp.getResult());
746789
return success();
747790
}
748791
};

src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ add_onnx_mlir_library(OMZHighOps
4747
OMONNXOps # Use ONNXShapeHelper
4848
OMLayoutHelper
4949
OMShapeHelperOpInterface
50+
OMStickify
5051
OMNNPACompilerOptions
5152
MLIRIR
5253

src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td

+4-1
Original file line numberDiff line numberDiff line change
@@ -862,11 +862,14 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
862862
let summary = "ZHigh Stickified Constant operation";
863863
let description = [{
864864
This operator produces a constant tensor to store stickified data.
865+
`value` attribute has original constant or stickified constant.
866+
`stickified` attribute indicates the `value` is already stickified or not.
865867
Stickified data is opaque and must be 4K-aligned. One who produces
866868
the stickified data must make sure its size in bytes consistent with
867869
the output tensor's size.
868870
}];
869-
let arguments = (ins OptionalAttr<AnyAttr>:$value,
871+
let arguments = (ins BoolAttr:$stickified,
872+
OptionalAttr<AnyAttr>:$value,
870873
DefaultValuedAttr<I64Attr, "4096">:$alignment);
871874
let results = (outs AnyZTensor:$output);
872875
}

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp

+50-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp"
1414
#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
15-
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
1615
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
1716

1817
#include "src/Dialect/ONNX/DialectBuilder.hpp"
@@ -482,5 +481,55 @@ IntegerAttr getDefaultSaturation(PatternRewriter &rewriter) {
482481
return IntegerAttr();
483482
}
484483

484+
/// MLIR type to zDNN type.
485+
zdnn_data_types mlirTypeToZDNNType(Type elementType) {
486+
if (mlir::isa<FloatType>(elementType)) {
487+
FloatType floatTy = mlir::cast<FloatType>(elementType);
488+
if (floatTy.getWidth() == 16) {
489+
return FP16;
490+
} else if (floatTy.getWidth() == 32) {
491+
return FP32;
492+
} else
493+
llvm_unreachable("Unsupported data type.");
494+
} else
495+
llvm_unreachable("Unsupported data type.");
496+
}
497+
498+
/// Get stickified data from denseElementAttribute
499+
ArrayRef<char> getStickifiedDataOfDenseElemAttr(
500+
DenseElementsAttr denseAttr, StringAttr layout) {
501+
ArrayRef<int64_t> shape = denseAttr.getType().getShape();
502+
Type elementType = denseAttr.getType().getElementType();
503+
int rank = shape.size();
504+
// Read attributes's raw data.
505+
std::vector<char> attrData;
506+
getRawData(denseAttr, attrData);
507+
// Call stickify.
508+
zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc;
509+
// pre-transformed desc.
510+
zdnn_data_layouts zDNNLayout =
511+
convertLayoutAttrToZDNNDataLayout(rank, layout);
512+
// If zDNNLayout is NHWC, we stickify directly from NCHW.
513+
if (zDNNLayout == ZDNN_NHWC)
514+
zDNNLayout = ZDNN_NCHW;
515+
zdnn_data_types zDNNType = onnx_mlir::zhigh::mlirTypeToZDNNType(elementType);
516+
set_info_pre_transformed_desc(&pre_tfrmd_desc, zDNNLayout, zDNNType, shape);
517+
// transformed desc.
518+
zdnn_status status = generate_transformed_desc(&pre_tfrmd_desc, &tfrmd_desc);
519+
assert(status == ZDNN_OK);
520+
// Stick data using the software stickify.
521+
zdnn_ztensor ztensor;
522+
init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor);
523+
status = allochelper_ztensor_alloc(&ztensor);
524+
assert(status == ZDNN_OK);
525+
status = stickify(&ztensor, attrData.data());
526+
assert(status == ZDNN_OK);
527+
int64_t sizeInBytes = ztensor.buffer_size;
528+
char *rawData = (char *)malloc(sizeInBytes);
529+
memcpy(rawData, ztensor.buffer, sizeInBytes);
530+
allochelper_ztensor_free(&ztensor);
531+
return llvm::ArrayRef(rawData, sizeInBytes);
532+
}
533+
485534
} // namespace zhigh
486535
} // namespace onnx_mlir

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp

+8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/Builders.h"
1717
#include "mlir/IR/BuiltinAttributes.h"
1818
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
19+
#include "src/Accelerators/NNPA/Support/Stickify/Stickify.hpp"
1920

2021
namespace onnx_mlir {
2122
namespace zhigh {
@@ -88,6 +89,13 @@ bool hasNNPAUse(mlir::Value v);
8889
/// Get saturation settings.
8990
mlir::IntegerAttr getDefaultSaturation(mlir::PatternRewriter &rewriter);
9091

92+
/// MLIR type to zDNN type.
93+
zdnn_data_types mlirTypeToZDNNType(mlir::Type elementType);
94+
95+
/// Get stickified data from denseElementAttribute
96+
mlir::ArrayRef<char> getStickifiedDataOfDenseElemAttr(
97+
mlir::DenseElementsAttr denseAttr, mlir::StringAttr layout);
98+
9199
} // namespace zhigh
92100
} // namespace onnx_mlir
93101
#endif

src/Accelerators/NNPA/Dialect/ZLow/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,13 @@ add_onnx_mlir_library(OMZLowOps
1111
DEPENDS
1212
OMZLowIncGen
1313
OMONNXZLowCombineIncGen
14+
OMKrnlGlobalOpInterface
1415

1516
LINK_LIBS PUBLIC
1617
MLIRIR
1718
OMMlirDialects
19+
OMZHighOps
20+
21+
ACCEL_INCLUDE_DIRS PRIVATE
22+
${NNPA_INCLUDE_PATH}
1823
)

src/Accelerators/NNPA/Dialect/ZLow/ZLow.td

+17
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def ZMemRef : MemRefOf<[DLF16]>;
4444
//===----------------------------------------------------------------------===//
4545

4646
include "mlir/Interfaces/SideEffectInterfaces.td"
47+
include "src/Interface/KrnlGlobalOpInterface.td"
4748

4849
def ZLowAddOp:ZLow_Op<"add", [MemRefsNormalizable,
4950
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@@ -547,4 +548,20 @@ def ZLowConvertF32ToDLF16VectorOp:ZLow_Op<"vec_f32_to_dlf16", [Pure]> {
547548
];
548549
}
549550

551+
def ZLowStickifiedConstantOp:ZLow_Op<"stickifiedConstant", [MemRefsNormalizable,
552+
DeclareOpInterfaceMethods<KrnlGlobalOpInterface>]> {
553+
let summary = "ZLow Stickified Constant operation.";
554+
let description = [{
555+
556+
}];
557+
let arguments = (ins AnyAttr:$shape,
558+
StrAttr:$name,
559+
BoolAttr:$stickified,
560+
OptionalAttr<AnyAttr>:$value,
561+
OptionalAttr<StrAttr>:$layout,
562+
OptionalAttr<I64Attr>:$offset,
563+
DefaultValuedAttr<I64Attr, "4096">:$alignment);
564+
let results = (outs ZMemRef:$output);
565+
}
566+
550567
#endif // ZLOW_OPS

0 commit comments

Comments
 (0)