Skip to content

Commit 35a61d3

Browse files
cjvolzka and tungld authored
Add a flag to turn on/off the lowering of scalar broadcasting binary ops to NNPA (#2778) (#2782)
* Add a flag to turn on/off scalar broadcasting binary op in NNPA Signed-off-by: Tung D. Le <tung@jp.ibm.com> --------- Signed-off-by: Tung D. Le <tung@jp.ibm.com> Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com> (cherry picked from commit 08d4fed) Co-authored-by: Tung D. Le <tung@jp.ibm.com>
1 parent 41e755a commit 35a61d3

File tree

13 files changed

+77
-60
lines changed

13 files changed

+77
-60
lines changed

src/Accelerators/NNPA/Compiler/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,3 @@
1-
get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS)
2-
31
add_onnx_mlir_library(OMNNPACompilerOptions
42
NNPACompilerOptions.cpp
53

@@ -12,7 +10,6 @@ add_onnx_mlir_library(OMNNPACompilerOptions
1210
${NNPA_ONNX_MLIR_BIN_ROOT}
1311

1412
LINK_LIBS PUBLIC
15-
${OMLibs}
1613
OMCompilerOptions
1714

1815
ACCEL_INCLUDE_DIRS PRIVATE
@@ -32,7 +29,6 @@ add_onnx_mlir_library(OMNNPACompilerUtils
3229
${NNPA_ONNX_MLIR_BIN_ROOT}
3330

3431
LINK_LIBS PUBLIC
35-
${OMLibs}
3632
OMNNPACompilerOptions
3733
OMCompilerPasses
3834

src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,13 @@ llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick(
5555
"stick/unstick code. Default is false."),
5656
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));
5757

58+
llvm::cl::opt<bool> nnpaEnableScalarBcastBinary(
59+
"nnpa-enable-scalar-bcast-binary",
60+
llvm::cl::desc("Enable the lowering to NNPA the broadcasting binary ops "
61+
"whose one of the operands is scalar. Currently support "
62+
"ONNXDiv only. Default is false."),
63+
llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));
64+
5865
llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile{
5966
"nnpa-load-device-placement-file",
6067
llvm::cl::desc(

src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,11 +49,13 @@ typedef enum {
4949
} NNPAPlacementHeuristic;
5050

5151
extern llvm::cl::OptionCategory OnnxMlirOptions;
52+
extern llvm::cl::OptionCategory OnnxMlirCommonOptions;
5253
extern llvm::cl::opt<onnx_mlir::NNPAEmissionTargetType> nnpaEmissionTarget;
5354
extern llvm::cl::opt<bool> nnpaClipToDLFloatRange;
5455
extern llvm::cl::opt<bool> nnpaEnableZHighToOnnx;
5556
extern llvm::cl::opt<bool> nnpaEnableZHighDecomposeStickUnstick;
5657
extern llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick;
58+
extern llvm::cl::opt<bool> nnpaEnableScalarBcastBinary;
5759
extern llvm::cl::opt<NNPAPlacementHeuristic> nnpaPlacementHeuristic;
5860
extern llvm::cl::opt<bool> profileZHighIR;
5961
extern llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile;

src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ add_onnx_mlir_library(OMONNXToZHigh
1111
libzdnn
1212

1313
LINK_LIBS PUBLIC
14-
OMCompilerOptions
14+
OMNNPACompilerOptions
1515
OMONNXOps
1616
OMONNXToKrnl
1717
OMZHighOps
@@ -32,7 +32,7 @@ add_onnx_mlir_library(OMRewriteONNXForZHigh
3232
libzdnn
3333

3434
LINK_LIBS PUBLIC
35-
OMCompilerOptions
35+
OMNNPACompilerOptions
3636
OMONNXOps
3737
OMONNXToKrnl
3838
OMZHighOps

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -324,12 +324,19 @@ bool isSuitableForZDNN<ONNXDivOp>(
324324
// Check NNPA level.
325325
if (!isCompatibleWithNNPALevel(NNPA_Z16))
326326
return false;
327-
if (!isF32ScalarConstantTensor(A) && !isValidElementTypeAndRank(A))
327+
// Broadcast with a scalar operand.
328+
if (isEnableScalarBcastBinary()) {
329+
if (isF32ScalarConstantTensor(A) && isValidElementTypeAndRank(B))
330+
return true;
331+
if (isF32ScalarConstantTensor(B) && isValidElementTypeAndRank(A))
332+
return true;
333+
}
334+
// Non-broadcast cases.
335+
if (!isValidElementTypeAndRank(A))
328336
return false;
329-
if (!isF32ScalarConstantTensor(B) && !isValidElementTypeAndRank(B))
337+
if (!isValidElementTypeAndRank(B))
330338
return false;
331-
return isF32ScalarConstantTensor(A) || isF32ScalarConstantTensor(B) ||
332-
dimAnalysis->sameShape(A, B);
339+
return dimAnalysis->sameShape(A, B);
333340
}
334341

335342
/// Check legality for ONNXSum.

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@ include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td"
2929
/// dag benefitsAdded = (addBenefit 0)
3030
/// >;
3131

32+
def IsEnableScalarBcastBinary: Constraint<CPred<"isEnableScalarBcastBinary()">>;
33+
3234
def IsNoneType : Constraint<CPred<"(($_self).getType().isa<NoneType>())">>;
3335

3436
def IsNotNoneType : Constraint<CPred<"(!($_self).getType().isa<NoneType>())">>;
@@ -227,7 +229,7 @@ def replaceONNXDivBroadcastPattern1 : Pat<
227229
(GetScalarF32AttrFromConstant $y),
228230
(NoneLayoutAttr)),
229231
(returnType $s_x))),
230-
[(IsF32ScalarConstantTensor $y)], [],
232+
[(IsEnableScalarBcastBinary), (IsF32ScalarConstantTensor $y)], [],
231233
(addBenefit 1)
232234
>;
233235

@@ -241,7 +243,7 @@ def replaceONNXDivBroadcastPattern2 : Pat<
241243
(NoneLayoutAttr)),
242244
(ZHighStickOp:$s_y $y, (NoneLayoutAttr)),
243245
(returnType $s_y))),
244-
[(IsF32ScalarConstantTensor $x)], [],
246+
[(IsEnableScalarBcastBinary), (IsF32ScalarConstantTensor $x)], [],
245247
(addBenefit 1)
246248
>;
247249

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,14 @@
1414
//===----------------------------------------------------------------------===//
1515

1616
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
17+
#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
1718
#include "src/Dialect/ONNX/DialectBuilder.hpp"
1819

1920
using namespace mlir;
2021
namespace onnx_mlir {
2122

23+
bool isEnableScalarBcastBinary() { return nnpaEnableScalarBcastBinary; }
24+
2225
/// Get transposed tensor by using a permutation array.
2326
Value emitONNXTranspose(
2427
Location loc, PatternRewriter &rewriter, Value x, ArrayRef<int64_t> perms) {

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@ const std::string DEVICE_ATTRIBUTE = "device";
2727
const std::string CPU_DEVICE = "cpu";
2828
const std::string NNPA_DEVICE = "nnpa";
2929

30+
bool isEnableScalarBcastBinary();
31+
3032
template <typename OP_TYPE>
3133
void addDynamicallyLegalOpFor(mlir::ConversionTarget *target,
3234
const onnx_mlir::DimAnalysis *dimAnalysis,

src/Accelerators/NNPA/Pass/NNPAPasses.hpp

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,11 +29,9 @@ std::unique_ptr<mlir::Pass> createDevicePlacementPass(
2929

3030
/// Add pass for lowering ONNX ops to ZHigh ops.
3131
std::unique_ptr<mlir::Pass> createONNXToZHighPass();
32-
std::unique_ptr<mlir::Pass> createONNXToZHighPass();
3332

3433
/// Add pass for rewriting ONNX ops for ZHigh.
3534
std::unique_ptr<mlir::Pass> createRewriteONNXForZHighPass();
36-
std::unique_ptr<mlir::Pass> createRewriteONNXForZHighPass();
3735

3836
/// Add pass for re-construct ONNX ops from ZHigh ops.
3937
std::unique_ptr<mlir::Pass> createZHighToONNXPass();

src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ class KrnlGetLinearOffsetIndexLowering : public ConversionPattern {
5353

5454
auto memrefTy = llvm::dyn_cast<MemRefType>(memref.getType());
5555
int64_t rank = memrefTy.getRank();
56-
assert(mapResults.value().size() == rank && "Invalid indices");
56+
assert((int64_t)mapResults.value().size() == rank && "Invalid indices");
5757

5858
// Only lower this op after the memref is normalized.
5959
if (!memrefTy.getLayout().isIdentity())
Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,44 @@
1+
// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --nnpa-enable-scalar-bcast-binary %s -split-input-file | FileCheck %s
2+
3+
// COM: Division by a scalar in case of dynamic dimensions.
4+
func.func @test_div_unknown_scalar1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
5+
%0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
6+
%1 = "onnx.Div"(%arg0, %0) : (tensor<?x10xf32>, tensor<f32>) -> tensor<*xf32>
7+
"func.return"(%1) : (tensor<*xf32>) -> ()
8+
9+
// CHECK-LABEL: func.func @test_div_unknown_scalar1
10+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
11+
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
12+
// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x10xf32>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
13+
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
14+
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
15+
// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
16+
// CHECK: [[VAR_5_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_4_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
17+
// CHECK: [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_1_]], [[VAR_5_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
18+
// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf32>
19+
// CHECK: return [[VAR_7_]] : tensor<?x10xf32>
20+
// CHECK: }
21+
}
22+
23+
// -----
24+
25+
// COM: Division by a scalar in case of dynamic dimensions.
26+
func.func @test_div_unknown_scalar2(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
27+
%0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
28+
%1 = "onnx.Div"(%0, %arg0) : (tensor<f32>, tensor<?x10xf32>) -> tensor<*xf32>
29+
"func.return"(%1) : (tensor<*xf32>) -> ()
30+
31+
// CHECK-LABEL: func.func @test_div_unknown_scalar2
32+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
33+
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
34+
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
35+
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
36+
// CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
37+
// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_3_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
38+
// CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x10xf32>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
39+
// CHECK: [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_4_]], [[VAR_5_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
40+
// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf32>
41+
// CHECK: return [[VAR_7_]] : tensor<?x10xf32>
42+
// CHECK: }
43+
}
44+

test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir

Lines changed: 0 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -32,50 +32,6 @@ func.func @test_div_3ds(%arg0 : tensor<10x10x10xf32>, %arg1 : tensor<10x10x10xf3
3232

3333
// -----
3434

35-
// COM: Division by a scalar in case of dynamic dimensions.
36-
func.func @test_div_unknown_scalar1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
37-
%0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
38-
%1 = "onnx.Div"(%arg0, %0) : (tensor<?x10xf32>, tensor<f32>) -> tensor<*xf32>
39-
"func.return"(%1) : (tensor<*xf32>) -> ()
40-
41-
// CHECK-LABEL: func.func @test_div_unknown_scalar1
42-
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
43-
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
44-
// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x10xf32>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
45-
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
46-
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
47-
// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
48-
// CHECK: [[VAR_5_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_4_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
49-
// CHECK: [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_1_]], [[VAR_5_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
50-
// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf32>
51-
// CHECK: return [[VAR_7_]] : tensor<?x10xf32>
52-
// CHECK: }
53-
}
54-
55-
// -----
56-
57-
// COM: Division by a scalar in case of dynamic dimensions.
58-
func.func @test_div_unknown_scalar2(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
59-
%0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
60-
%1 = "onnx.Div"(%0, %arg0) : (tensor<f32>, tensor<?x10xf32>) -> tensor<*xf32>
61-
"func.return"(%1) : (tensor<*xf32>) -> ()
62-
63-
// CHECK-LABEL: func.func @test_div_unknown_scalar2
64-
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
65-
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
66-
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
67-
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
68-
// CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
69-
// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_3_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
70-
// CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x10xf32>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
71-
// CHECK: [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_4_]], [[VAR_5_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
72-
// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf32>
73-
// CHECK: return [[VAR_7_]] : tensor<?x10xf32>
74-
// CHECK: }
75-
}
76-
77-
// -----
78-
7935
// COM: Do not lower broadcasting onnx.Div to zHigh.
8036
func.func @test_div_not_lowered_diff_shape(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10xf32>) -> tensor<*xf32> {
8137
%0 = "onnx.Div"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10xf32>) -> tensor<*xf32>

test/mlir/accelerators/nnpa/driver/matmul-div-in-attention-layer.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --printIR %s | FileCheck %s
1+
// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --nnpa-enable-scalar-bcast-binary --printIR %s | FileCheck %s
22

33
// Check whether the compiler can remove unstick/stick so that the output of zdnn matmul is passed directly to zdnn div.
44
func.func @matmul_div(%arg0: tensor<?x12x?x64xf32>) -> tensor<?x?x?x?xf32> {

0 commit comments

Comments
 (0)