@@ -190,27 +190,47 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
190
190
affine::normalizeMemRefType (mlir::cast<MemRefType>(zMemRefType.value ));
191
191
192
192
// Create a ZHighStickifiedConstantOp.
193
- ZHighStickifiedConstantOp stickifiedConstant =
194
- rewriter.create <ZHighStickifiedConstantOp>(loc, resType,
195
- /* value=*/ nullptr ,
196
- /* alignment=*/ rewriter.getI64IntegerAttr (4096 ));
197
-
198
- // Use an dense resource attribute to store stickified data.
199
- // Attribute type: tensor<sizeInBytes x i8>
200
- int64_t sizeInBytes =
201
- affine::getIntOrFloatMemRefSizeInBytes (resType).value ();
202
- char *rawData = static_cast <char *>(malloc (sizeInBytes));
203
- assert (rawData && " failed to allocate memory for stickified data" );
204
- memset (rawData, 0 , sizeInBytes);
205
- DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get (
206
- RankedTensorType::get ({sizeInBytes}, rewriter.getI8Type ()),
207
- stickifiedConstant.getOperation ()
208
- ->getDialect ()
209
- ->getNamespace (), // use the dialect as the blob "hint"
210
- HeapAsmResourceBlob::allocateAndCopyWithAlign (
211
- llvm::ArrayRef (rawData, sizeInBytes), alignof (char )));
212
- stickifiedConstant.setValueAttr (valueAttr);
213
- free (rawData);
193
+
194
+ // Keep previous implementation about generating stickified data at
195
+ // ZHighConstPropagationPass. To use this, comment in and set directive "
196
+ // NNPA_ZHIGH_STICKIFIEDCONST_GEN"
197
+ //
198
+ // #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
199
+ // // Set zero in value attribute as DenseResourceElementsAttribute.
200
+ // ZHighStickifiedConstantOp stickifiedConstant =
201
+ // rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
202
+ // /*stickified=*/rewriter.getBoolAttr(true),
203
+ // /*value=*/nullptr,
204
+ // /*alignment=*/rewriter.getI64IntegerAttr(4096));
205
+ //
206
+ // // Use an dense resource attribute to store stickified data.
207
+ // // Attribute type: tensor<sizeInBytes x i8>
208
+ // int64_t sizeInBytes =
209
+ // affine::getIntOrFloatMemRefSizeInBytes(resType).value();
210
+ // char *rawData = static_cast<char *>(malloc(sizeInBytes));
211
+ // assert(rawData && "failed to allocate memory for stickified data");
212
+ // memset(rawData, 0, sizeInBytes);
213
+ // DenseResourceElementsAttr valueAttr =
214
+ // DenseUI8ResourceElementsAttr::get(
215
+ // RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
216
+ // stickifiedConstant.getOperation()
217
+ // ->getDialect()
218
+ // ->getNamespace(), // use the dialect as the blob "hint"
219
+ // HeapAsmResourceBlob::allocateAndCopyWithAlign(
220
+ // llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
221
+ // stickifiedConstant.setValueAttr(valueAttr);
222
+ // free(rawData);
223
+ // #else
224
+
225
+ // Set zero in value attribute as SplatElementsAttr.
226
+ FloatAttr floatZero = rewriter.getFloatAttr (resType.getElementType (), 0.0 );
227
+ ZHighStickifiedConstantOp stickifiedConstant = rewriter.create <
228
+ ZHighStickifiedConstantOp>(loc, resType,
229
+ /* stickified=*/ rewriter.getBoolAttr (true ),
230
+ /* value=*/ SplatElementsAttr::get (cast<ShapedType>(resType), floatZero),
231
+ /* alignment=*/ rewriter.getI64IntegerAttr (4096 ));
232
+
233
+ // #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
214
234
215
235
res = stickifiedConstant.getResult ();
216
236
} else {
@@ -686,7 +706,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
686
706
};
687
707
688
708
// ===----------------------------------------------------------------------===//
689
- // Lower ZHigh Stickified Constant to KrnlGlobal
709
+ // Lower ZHigh Stickified Constant to ZLow Stickified Constant
690
710
// ===----------------------------------------------------------------------===//
691
711
692
712
struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
@@ -699,7 +719,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
699
719
LogicalResult matchAndRewrite (Operation *op, ArrayRef<Value> operands,
700
720
ConversionPatternRewriter &rewriter) const final {
701
721
Location loc = op->getLoc ();
702
- ZHighStickifiedConstantOp stickifiedConstOp =
722
+ ZHighStickifiedConstantOp zhighStickifiedConstOp =
703
723
llvm::dyn_cast<ZHighStickifiedConstantOp>(op);
704
724
705
725
// Convert ZTensor type to MemRefType.
@@ -713,36 +733,59 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
713
733
affine::normalizeMemRefType (mlir::cast<MemRefType>(zMemRefType.value ));
714
734
ArrayRef<int64_t > normalizedShape = normalizedType.getShape ();
715
735
716
- // Get dense resource attribute.
717
- auto blob = mlir::cast<DenseResourceElementsAttr>(
718
- stickifiedConstOp.getValue ().value ())
719
- .getRawHandle ()
720
- .getBlob ();
721
- assert (blob && " Expecting dense resource with a valid blob" );
722
- ArrayRef<char > data = blob->getData ();
723
-
724
- // Validate the stickified tensor.
725
- int64_t memRefSizeInBytes = getMemRefEltSizeInBytes (normalizedType);
726
- memRefSizeInBytes *= normalizedType.getNumElements ();
727
- assert ((data.size () == static_cast <uint64_t >(memRefSizeInBytes)) &&
728
- " The stickified tensor's buffer size and MemRef's size mismatched" );
729
-
730
- // Create a KrnlGlobalOp.
731
- KrnlGlobalOp constantGlobal =
732
- rewriter.create <KrnlGlobalOp>(loc, zMemRefType.value ,
736
+ // Create ZLowStickifiedConstantOp.
737
+ StringAttr layout =
738
+ getZTensorLayoutAttr (rewriter, *op->result_type_begin ());
739
+
740
+ // Keep previous implementation about generating stickified data at
741
+ // ZHighConstPropagationPass. To use this, comment in and set directive "
742
+ // NNPA_ZHIGH_STICKIFIEDCONST_GEN"
743
+ //
744
+ // #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
745
+ // // Lower to KrnlGlobalOp
746
+ // // Get dense resource attribute.
747
+ // auto blob = mlir::cast<DenseResourceElementsAttr>(
748
+ // zhighStickifiedConstOp.getValue().value())
749
+ // .getRawHandle()
750
+ // .getBlob();
751
+ // assert(blob && "Expecting dense resource with a valid blob");
752
+ // ArrayRef<char> data = blob->getData();
753
+ // // Validate the stickified tensor.
754
+ // int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
755
+ // memRefSizeInBytes *= normalizedType.getNumElements();
756
+ // assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
757
+ // "The stickified tensor's buffer size and MemRef's size
758
+ // mismatched");
759
+ // // Create a KrnlGlobalOp.
760
+ // KrnlGlobalOp constantOp =
761
+ // rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
762
+ // /*shape=*/
763
+ // rewriter.getI64ArrayAttr(normalizedShape),
764
+ // /*name=*/
765
+ // rewriter.getStringAttr(
766
+ // "constant_stickify_" + std::to_string(constantID)),
767
+ // /*value=*/zhighStickifiedConstOp.getValueAttr(),
768
+ // /*offset=*/nullptr,
769
+ // /*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
770
+ // #else
771
+ ZLowStickifiedConstantOp constantOp =
772
+ rewriter.create <ZLowStickifiedConstantOp>(loc,
773
+ mlir::cast<MemRefType>(zMemRefType.value ),
733
774
/* shape=*/
734
775
rewriter.getI64ArrayAttr (normalizedShape),
735
776
/* name=*/
736
777
rewriter.getStringAttr (
737
778
" constant_stickify_" + std::to_string (constantID)),
738
- /* value=*/ stickifiedConstOp.getValueAttr (),
739
- /* offset=*/ nullptr ,
740
- /* alignment=*/ stickifiedConstOp.getAlignmentAttr ());
741
-
779
+ /* stickified=*/ zhighStickifiedConstOp.getStickifiedAttr (),
780
+ /* value=*/ zhighStickifiedConstOp.getValueAttr (),
781
+ /* layout=*/ layout,
782
+ /* offset=*/ rewriter.getI64IntegerAttr (0 ),
783
+ /* alignment=*/ zhighStickifiedConstOp.getAlignmentAttr ());
784
+ // #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
742
785
// Increment constant ID:
743
786
constantID++;
744
787
745
- rewriter.replaceOp (op, constantGlobal .getResult ());
788
+ rewriter.replaceOp (op, constantOp .getResult ());
746
789
return success ();
747
790
}
748
791
};
0 commit comments