Skip to content

Commit

Permalink
Merge branch 'main' into simd-framwork-v1
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandreEichenberger authored Sep 17, 2024
2 parents 7add7d9 + a6ebca0 commit 4836f22
Show file tree
Hide file tree
Showing 12 changed files with 246 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout f142f8afe21bceb00fb495468aa0b5043e98c419 && cd ..
cd llvm-project && git checkout eaa95a1c2bd38332c1a4e634595f29d22b28ffea && cd ..
```

[same-as-file]: <> (utils/build-mlir.sh)
Expand Down
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout f142f8afe21bceb00fb495468aa0b5043e98c419 && cd ..
cd llvm-project && git checkout eaa95a1c2bd38332c1a4e634595f29d22b28ffea && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
Expand Down
17 changes: 13 additions & 4 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both
std::vector<std::string> functionsToDecompose; // common for both
std::string opsForCall; // common for both
bool disableKrnlOpFusion; // common for both
bool disableQuantZeroPoint; // common for both
bool enableKrnlBufferReuse; // common for both
bool disableMemRefPrefetch; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
Expand Down Expand Up @@ -195,7 +196,7 @@ static llvm::cl::list<std::string, std::vector<std::string>>
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> enableONNXHybridPassOpt("onnx-hybrid-pass",
llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n"
llvm::cl::desc("Enable ONNX hybrid pass (default=true).\n"
"Set to 'false' if you want to disable ONNX hybrid pass."),
llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirCommonOptions));
Expand All @@ -208,11 +209,20 @@ static llvm::cl::list<std::string, std::vector<std::string>>

static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
"disable-krnl-op-fusion",
llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n"
llvm::cl::desc("Disable op fusion in onnx-to-krnl pass (default=false).\n"
"Set to 'true' if you want to disable fusion."),
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disable_quantization_zero_point(
"disable-quantization-zero-point",
llvm::cl::desc(
"Disable the use of zero-point in quantization (default=false).\n"
"Set to 'true' if you want to disable the use of zero-point\n"
"in dyn/static quantization/dequantization."),
llvm::cl::location(disableQuantZeroPoint), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
"enable-krnl-buffer-reuse",
llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass"
Expand All @@ -223,7 +233,7 @@ static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(

static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
"disable-memref-prefetch",
llvm::cl::desc("disable generation of memref.prefetch (default=false)\n"
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"
"Set to 'true' if you want to disable prefetch."),
llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));
Expand Down Expand Up @@ -1145,7 +1155,6 @@ std::string getLibraryPath() {
// as lrodataScript.
std::string getToolPath(
const std::string &tool, bool flag /*false by default*/) {

if (!flag) {
std::string execDir = llvm::sys::path::parent_path(getExecPath()).str();
llvm::SmallString<8> toolPath(execDir);
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both
extern std::vector<std::string> functionsToDecompose; // common for both
extern std::string opsForCall; // common for both
extern bool disableKrnlOpFusion; // common for both
extern bool disableQuantZeroPoint; // common for both
extern bool enableKrnlBufferReuse; // common for both
extern bool disableMemRefPrefetch; // common for both
extern EmissionTargetType emissionTarget; // onnx-mlir only
Expand Down
10 changes: 8 additions & 2 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1358,9 +1358,15 @@ Value emitScalarOpFor<ONNXDequantizeLinearOp>(
Value scaleFloat = scalarOperands[1];
Value zeroPointInt = scalarOperands[2];

Value zeroPointFloat = create.math.cast(elementType, zeroPointInt);
Value xFloat = create.math.cast(elementType, XInt);
Value sub = create.math.sub(xFloat, zeroPointFloat);

Value sub;
if (!disableQuantZeroPoint && !isNoneValue(zeroPointInt)) {
Value zeroPointFloat = create.math.cast(elementType, zeroPointInt);
sub = create.math.sub(xFloat, zeroPointFloat);
} else {
sub = xFloat;
}
Value res = create.math.mul(sub, scaleFloat);
return res;
}
Expand Down
23 changes: 15 additions & 8 deletions src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
Expand All @@ -29,7 +30,7 @@ void emitDynamicQuantizationLinearScalarParameters(
ConversionPatternRewriter &rewriter, Location loc, Operation *op,
MemRefType inputType, MemRefType quantizedType, Value input, Value qMin,
Value qMax, Value &scale, Value &zeroPoint, Value &quantizedZeroPoint,
bool enableSIMD, bool enableParallel) {
bool wantZeroPoint, bool enableSIMD, bool enableParallel) {
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);

// Types
Expand Down Expand Up @@ -62,11 +63,15 @@ void emitDynamicQuantizationLinearScalarParameters(
scale = create.math.div(xDiff, boundDiff);

// Compute y_zero_point.
Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale));
// Saturate zero point.
Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax);
// Round zero point.
zeroPoint = create.math.round(saturateZeroPoint);
if (wantZeroPoint) {
Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale));
// Saturate zero point.
Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax);
// Round zero point.
zeroPoint = create.math.round(saturateZeroPoint);
} else {
zeroPoint = zero;
}
quantizedZeroPoint = create.math.cast(quantizedElementType, zeroPoint);
}

Expand Down Expand Up @@ -122,15 +127,17 @@ struct ONNXDynamicQuantizeLinearOpLowering
Value qMin = create.math.constant(elementType, 0.0);
Value scale, zeroPoint, zeroPointInt;

bool wantZeroPoint = !disableQuantZeroPoint;
emitDynamicQuantizationLinearScalarParameters(rewriter, loc, op,
xMemRefType, yMemRefType, X, qMin, qMax, scale, zeroPoint, zeroPointInt,
enableSIMD, enableParallel);
wantZeroPoint, enableSIMD, enableParallel);
create.krnl.store(scale, YScale);
create.krnl.store(zeroPointInt, YZeroPoint);

emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType,
yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale,
zeroPoint, enableSIMD, enableParallel);
zeroPoint, wantZeroPoint /*wanted one, so we have a zero point*/,
enableSIMD, enableParallel);

rewriter.replaceOp(op, {Y, YScale, YZeroPoint});
onnxToKrnlSimdReport(op);
Expand Down
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ void emitQuantizationLinearScalarParameters(
mlir::Operation *op, mlir::MemRefType inputType,
mlir::MemRefType quantizedType, mlir::Value alloc, DimsExpr &allocDims,
mlir::Value input, mlir::Value qMin, mlir::Value qMax, mlir::Value scale,
mlir::Value zeroPoint, bool enableSIMD, bool enableParallel);
mlir::Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
bool enableParallel);

// Scan the input to compute scale, zeroPoint, and quantizedZeroPoint given qMin
// and qMax.
Expand All @@ -32,5 +33,6 @@ void emitDynamicQuantizationLinearScalarParameters(
mlir::Operation *op, mlir::MemRefType inputType,
mlir::MemRefType quantizedType, mlir::Value input, mlir::Value qMin,
mlir::Value qMax, mlir::Value &scale, mlir::Value &zeroPoint,
mlir::Value &quantizedZeroPoint, bool enableSIMD, bool enableParallel);
mlir::Value &quantizedZeroPoint, bool wantZeroPoint, bool enableSIMD,
bool enableParallel);
} // namespace onnx_mlir
24 changes: 18 additions & 6 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
Expand All @@ -26,7 +27,8 @@ namespace onnx_mlir {
void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, MemRefType inputType, MemRefType quantizedType,
Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax,
Value scale, Value zeroPoint, bool enableSIMD, bool enableParallel) {
Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
bool enableParallel) {
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
rewriter, loc);

Expand Down Expand Up @@ -76,7 +78,11 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
// Round
Value roundX = create.math.round(scaleX);
// Adjust
Value adjustX = create.math.add(roundX, zeroPoint);
Value adjustX;
if (hasZeroPoint)
adjustX = create.math.add(roundX, zeroPoint);
else
adjustX = roundX;
// Saturate
Value saturateX = create.math.clip(adjustX, qMin, qMax);
Value res = create.math.cast(quantizedElementType, saturateX);
Expand Down Expand Up @@ -159,15 +165,21 @@ struct ONNXQuantizeLinearOpLowering

// Load y_zero_point.
Value zeroPoint;
bool hasZeroPoint = false;
if (!isNoneValue(YZeroPoint)) {
zeroPoint = create.krnl.load(adaptor.getYZeroPoint());
zeroPoint = create.math.cast(elementType, zeroPoint);
} else
zeroPoint = create.math.constant(elementType, 0.0);

hasZeroPoint = true;
}
if (disableQuantZeroPoint) {
// TODO: should we expect to disable hasZeroPoint forcefully, or generate
// an error if we had a zero point? Right now, just forcefully assert we
// have no zero point, i.e. ignore one even if we had a zero point.
hasZeroPoint = false;
}
emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType,
yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale,
zeroPoint, enableSIMD, enableParallel);
zeroPoint, hasZeroPoint, enableSIMD, enableParallel);

rewriter.replaceOp(op, {Y});
onnxToKrnlSimdReport(op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Adding canonicalize is important here as this is the only way to check the values of the map,
// which are otherwise before the function, and thus are hard to test.

// -----

func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor<f32>, %arg2: tensor<i8>) -> tensor<4xf32> {
%0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xi8>, tensor<f32>, tensor<i8>) -> tensor<4xf32>
return %0 : tensor<4xf32>
Expand All @@ -29,10 +31,12 @@ func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor<f32>, %ar

// -----


func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor<f32>, %arg2: tensor<ui8>) -> tensor<4xf32> {
%0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xui8>, tensor<f32>, tensor<ui8>) -> tensor<4xf32>
return %0 : tensor<4xf32>

// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_dequantizelinear_ui8
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xui8>, [[PARAM_1_:%.+]]: memref<f32>, [[PARAM_2_:%.+]]: memref<ui8>) -> memref<4xf32> {
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32>
Expand All @@ -42,11 +46,11 @@ func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor<f32>, %
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xui8>
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref<f32>
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref<ui8>
// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8
// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8
// CHECK-DAG: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i8 to f32
// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8
// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8
// CHECK: [[VAR_8_:%.+]] = arith.uitofp [[VAR_7_]] : i8 to f32
// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_8_]], [[VAR_6_]] : f32
// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32
// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32
// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
// CHECK: }
Expand Down
Loading

0 comments on commit 4836f22

Please sign in to comment.