From 45f07d58fc1b5fbfa05e3f6124361b462d477111 Mon Sep 17 00:00:00 2001 From: Tong Chen Date: Wed, 4 Dec 2024 13:33:52 -0500 Subject: [PATCH] Transform SequenceAt to split for special cases (#3018) * implement Signed-off-by: chentong319 * test case Signed-off-by: chentong319 * format Signed-off-by: chentong319 * fix Signed-off-by: chentong319 --------- Signed-off-by: chentong319 Co-authored-by: Alexandre Eichenberger --- src/Dialect/ONNX/DialectBuilder.cpp | 1 - src/Dialect/ONNX/ONNXOps/OpHelper.cpp | 3 +- src/Dialect/ONNX/Transforms/Decompose.cpp | 118 ++++++++++++++++++ src/Dialect/ONNX/Transforms/Decompose.td | 18 +++ .../onnx/onnx_decompose_canonicalize.mlir | 43 +++++++ 5 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 test/mlir/onnx/onnx_decompose_canonicalize.mlir diff --git a/src/Dialect/ONNX/DialectBuilder.cpp b/src/Dialect/ONNX/DialectBuilder.cpp index b9382a06b0..1210f89f23 100644 --- a/src/Dialect/ONNX/DialectBuilder.cpp +++ b/src/Dialect/ONNX/DialectBuilder.cpp @@ -444,7 +444,6 @@ TensorType OnnxBuilder::toTensor(Type input) const { } TypeRange OnnxBuilder::toTensors(TypeRange inputs) const { - assert(inputs.size() >= 2 && "Expect at least two inputs"); if (llvm::all_of(inputs, [](Type t) { return (mlir::isa(t)); })) return inputs; assert(llvm::all_of(inputs, [](Type t) { diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index 520de56339..36cefe7675 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -307,7 +307,8 @@ void ArrayAttrIntVals(ArrayAttr a, mlir::SmallVectorImpl &i) { ElementsAttr getElementAttributeFromONNXValue(Value value) { ONNXConstantOp constantOp = getONNXConstantOp(value); - if (constantOp) + // In case the ConstantOp has not been normalized yet + if (constantOp && constantOp.getValueAttr()) return mlir::dyn_cast(constantOp.getValueAttr()); return nullptr; } diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp index e9e455e26b..514943457e 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -333,6 +333,120 @@ bool hasStaticSpatialDims(Value v) { return llvm::none_of(Ds, ShapedType::isDynamic); } +// In the following pattern, a SequenceAt can be replaced with Split +// %seq = onnx.SplitToSequence(%input, %split) {%axis : } +// %res = onnx.SequenceAt(%seq, %position) +// We just try to avoid using the sequence related ops, which are less +// optimized, or even not implemented in onnx-mlir. +// In the targeted use case, %split and %position are constant scalar and the +// tensor of %input and %res have static shape. +// This condition greatly reduces the complexity of code generation to replace +// SequenceAt with split op +// %res = onnx.Split(%input, onnx.expand(%split, %input.shape()[%axis])) +// {%axis : } : %position +// onnx.expand(%split, %input.shape()[%axis]) can be a constant under the +// assumed condition. +// Here %position has to be compiler time constant. +// For multiple SequenceAt from the same SplitToSequence result, the onnx.split +// for different SequenceAt are expected to be merged by optimization. +// Alternatively, Slice can be used +// %res = onnx.Slice(%input, %start, %end, %step) +// The start, and end for slice will be onnx.constant: +// start: %position*%split for %axis, 0 for other dimensionis +// end: (%positiion+1)*%split for %axis, upper bound for other dimension +// step: 1 for all dimensions +// The split approach may have better performance than the alternative slice +// approach, because the slicing is done separately. + +bool canSequenceAtBeReplaced(Value sequenceAtResult) { + if (!hasStaticShape(sequenceAtResult.getType())) + return false; + + ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp(); + + Value inputSequence = op.getInputSequence(); + Value position = op.getPosition(); + + if (!isDenseONNXConstant(position)) + return false; + + // Input sequence should be defined with SplitToSequence + ONNXSplitToSequenceOp splitToSequence = + inputSequence.getDefiningOp(); + if (!splitToSequence) + return false; + + // Check the pattern of the SplitToSequence op + Value input = splitToSequence.getInput(); + if (!hasStaticShape(input.getType())) + return false; + Value split = splitToSequence.getSplit(); + if (!isScalarConstantTensor(split)) + return false; + + return true; +} + +Value replaceSequenceAt( + PatternRewriter &rewriter, Location loc, Value sequenceAtResult) { + ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp(); + + Value inputSequence = op.getInputSequence(); + Value position = op.getPosition(); + + ONNXConstantOp positionConstant = + mlir::cast(position.getDefiningOp()); + int64_t positionInt = getScalarValue(positionConstant); + + ONNXSplitToSequenceOp splitToSequence = + mlir::cast(inputSequence.getDefiningOp()); + + Value input = splitToSequence.getInput(); + Value split = splitToSequence.getSplit(); + + ONNXConstantOp splitConstant = + mlir::cast(split.getDefiningOp()); + int64_t splitInt = getScalarValue(splitConstant); + int64_t axisInt = splitToSequence.getAxis(); + + auto shape = getShape(input.getType()); + + OnnxBuilder create(rewriter, loc); + + Type sequenceElementType = + mlir::cast(inputSequence.getType()).getElementType(); + mlir::SmallVector outputTypes( + shape[axisInt] / splitInt, sequenceElementType); + auto numSplit = create.constantInt64( + mlir::SmallVector(shape[axisInt] / splitInt, splitInt)); + auto resultRange = create.split(outputTypes, input, numSplit, axisInt); + auto rawResult = resultRange[positionInt]; + + if (rawResult.getType() == sequenceAtResult.getType()) + return rawResult; + + // Temporary code for the error in the model generated by torch.onnx.export + // The the dim of the reuslt of SequenceAt op is different from the element + // type of the sequence.. + // My assumption is that the exporter is confused with squeeze and unsqueeze + // followed by the SequenceAt. There are two cases in the model: + // clang-format off + // Case #1: + // %16 = "onnx.SequenceAt"(%14, %15) {onnx_node_name = "n0"} : + // (!onnx.Seq>, tensor) -> tensor<1x100xf32> + // %23 = "onnx.Unsqueeze"(%16, %22) {onnx_node_name = "n2"} : + // (tensor<1x100xf32>, tensor) -> tensor<1x1x100xf32> + // Case#2: + // %67 = "onnx.SequenceAt"(%66, %15) {onnx_node_name = "n0"} : + // (!onnx.Seq>, tensor) -> tensor<1x1x100xf32> + // %71 = "onnx.Sigmoid"(%67) {onnx_node_name = "node_Sigmoid_60"} : + // (tensor<1x1x100xf32>) -> tensor<1x1x100xf32> + // clang-format on + // Thus, the compiler squeeze the tensor if needed. + return create.squeeze( + sequenceAtResult.getType(), rawResult, create.constantInt64(axisInt)); +} + bool shouldDecomposeConvTransposeOp(Value convTransposeResult) { ONNXConvTransposeOp op = mlir::cast(convTransposeResult.getDefiningOp()); @@ -1246,6 +1360,10 @@ void DecomposeONNXToONNXPass::runOnOperation() { return !isConcatFuseMatched(op, shapeOp, transposeOp); }); + target.addDynamicallyLegalOp([](ONNXSequenceAtOp op) { + return !onnx_mlir::canSequenceAtBeReplaced(op.getResult()); + }); + // Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs. target.addDynamicallyLegalOp([](ONNXConstantOp op) { return !(op.getValueFloatAttr() || op.getValueFloatsAttr() || diff --git a/src/Dialect/ONNX/Transforms/Decompose.td b/src/Dialect/ONNX/Transforms/Decompose.td index 3cea294521..00ae9f6ff3 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.td +++ b/src/Dialect/ONNX/Transforms/Decompose.td @@ -71,6 +71,12 @@ def createScalarDenseAttrRank0 def ReshapeElementsAttrToRank0 : NativeCodeCall< "onnx_mlir::OnnxElementsAttrBuilder($0.getContext()).reshape(cast($0), {})">; +def ReplaceSequenceAt : NativeCodeCall< + "onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">; + +def CanSequenceAtBeReplaced : + Constraint, "check whether the SequenceAt can be replaced with split">; + // Create a DenseElementsAttr from a single attribute. def createDenseArrayAttrFromSingleAttr : NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">; @@ -620,4 +626,16 @@ def ConstantOpNormalizationPattern6: Pat< [(AttributeIsNotNull:$stringsAttr)] >; +// Optimize for the pattern coming from torch.nn.LSTM exported from pytorch +// %32 = "onnx.SplitToSequence"(%30, %27) {axis = 0 : si64, keepdims = 0 : si64, onnx_node_name = "n1"} : (tensor<1x1x100xf32>, tensor) -> !onnx.Seq> +// %33 = "onnx.SequenceAt"(%32, %26) {onnx_node_name = "n0"} : (!onnx.Seq>, tensor) -> tensor<1x100xf32> +// When shape and size/axis related value are constant, this sequence of code +// can be translated into onnx.slice + +def ReplaceSequenceAtPattern: Pat< + (ONNXSequenceAtOp:$res $seq, $position), + (ReplaceSequenceAt $res), + [(CanSequenceAtBeReplaced:$res)] +>; + #endif // ONNX_DECOMPOSE diff --git a/test/mlir/onnx/onnx_decompose_canonicalize.mlir b/test/mlir/onnx/onnx_decompose_canonicalize.mlir new file mode 100644 index 0000000000..a132445562 --- /dev/null +++ b/test/mlir/onnx/onnx_decompose_canonicalize.mlir @@ -0,0 +1,43 @@ + +// RUN: onnx-mlir-opt --decompose-onnx --canonicalize %s -split-input-file | FileCheck %s + +// ----- + +// Test one pattern in lstm_no_data.onnx. +// The type of output of SequenceAt is not the same as the element type +// of the input sequence +func.func @sequence_at_squeezed(%arg0 : tensor<1x1x100xf32>) -> tensor<1x100xf32> { + %26 = onnx.Constant dense<0> : tensor + %27 = onnx.Constant dense<1> : tensor + %32 = "onnx.SplitToSequence"(%arg0, %27) {axis = 0 : si64, keepdims = 0 : si64} : (tensor<1x1x100xf32>, tensor) -> !onnx.Seq> + %33 = "onnx.SequenceAt"(%32, %26) : (!onnx.Seq>, tensor) -> tensor<1x100xf32> + return %33: tensor<1x100xf32> +// CHECK-LABEL: func.func @sequence_at_squeezed +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x100xf32>) -> tensor<1x100xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK: [[VAR_2_:%.+]] = "onnx.Split"([[PARAM_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x1x100xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Squeeze"([[VAR_2_]], [[VAR_0_]]) : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x100xf32> +// CHECK: return [[VAR_3_]] : tensor<1x100xf32> +// CHECK: } +} + +func.func @sequence_at_multi(%arg0 : tensor<1x1x400xf32>) -> tensor<1x1x100xf32> { + %15 = onnx.Constant dense<0> : tensor + %38 = onnx.Constant dense<1> : tensor + %65 = onnx.Constant dense<100> : tensor + %66 = "onnx.SplitToSequence"(%arg0, %65) {axis = 2 : si64, keepdims = 1 : si64} : (tensor<1x1x400xf32>, tensor) -> !onnx.Seq> + %67 = "onnx.SequenceAt"(%66, %15) : (!onnx.Seq>, tensor) -> tensor<1x1x100xf32> + %68 = "onnx.SequenceAt"(%66, %38) : (!onnx.Seq>, tensor) -> tensor<1x1x100xf32> + %40 = "onnx.Add"(%67, %68) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32> + return %40: tensor<1x1x100xf32> +// CHECK-LABEL: func.func @sequence_at_multi +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x400xf32>) -> tensor<1x1x100xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<100> : tensor<4xi64> +// CHECK-DAG: [[VAR_1_:%.+]]:4 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>) +// CHECK-DAG: [[VAR_2_:%.+]]:4 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>) +// CHECK: [[VAR_3_:%.+]] = "onnx.Add"([[VAR_1_]]#0, [[VAR_2_]]#1) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32> +// CHECK: return [[VAR_3_]] : tensor<1x1x100xf32> +// CHECK: } +} +