2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -31,6 +31,7 @@
#include "core/optimizer/conv_add_act_fusion.h"
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/mul_add_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/div_mul_fusion.h"
#include "core/optimizer/double_qdq_pairs_remover.h"
@@ -256,6 +257,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
session_options.config_options));
transformers.emplace_back(std::make_unique<MatMulAddFusion>());
transformers.emplace_back(std::make_unique<MulAddFusion>());
transformers.emplace_back(std::make_unique<ReshapeFusion>());
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
session_options.free_dimension_overrides));
198 changes: 198 additions & 0 deletions onnxruntime/core/optimizer/mul_add_fusion.cc
@@ -0,0 +1,198 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/status.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/mul_add_fusion.h"
#include "core/optimizer/utils.h"
#include "core/common/logging/logging.h"
#include "core/framework/data_types.h"
#include "core/framework/tensorprotoutils.h" // For utilities like TensorProtoToMLFloat16 etc.

using namespace ONNX_NAMESPACE;

[GitHub Actions / Optional Lint C++ — cpplint via reviewdog] mul_add_fusion.cc:13: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using namespace onnxruntime::common;

[GitHub Actions / Optional Lint C++ — cpplint via reviewdog] mul_add_fusion.cc:14: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

namespace onnxruntime {

bool IsPatternBatchnorm(const NodeArg* inp,
const ONNX_NAMESPACE::TensorProto* scale,
const ONNX_NAMESPACE::TensorProto* bias) {
// Bail out if the input has no static shape information.
if (inp->Shape() == nullptr) {
return false;
}
int inp_rank = inp->Shape()->dim_size();
int scale_rank = scale->dims_size();
int bias_rank = bias->dims_size();
int max_rank = std::max({inp_rank, scale_rank, bias_rank});

// Currently we do not support mul + add with rank-1 inputs. The input itself must also
// have rank >= 2, since the fused BatchNormalization expects an (N, C, ...) input and
// FuseMulAdd reads the channel dimension at index 1.
if (max_rank <= 1 || inp_rank < 2) {
return false;
}
std::vector<int64_t> broadcast_inp(max_rank);
std::vector<int64_t> broadcast_scale(max_rank);
std::vector<int64_t> broadcast_bias(max_rank);

for (int idx = 0; idx < max_rank; ++idx) {
auto broad_idx = max_rank - 1 - idx;
broadcast_inp[broad_idx] = (idx < inp_rank) ? inp->Shape()->dim(inp_rank - 1 - idx).dim_value() : 1;
broadcast_scale[broad_idx] = (idx < scale_rank) ? scale->dims(scale_rank - 1 - idx) : 1;
broadcast_bias[broad_idx] = (idx < bias_rank) ? bias->dims(bias_rank - 1 - idx) : 1;
}
// broadcast_scale and broadcast_bias should be in the form of [1, num_channel, 1, ..., 1].
// Note: The num_channel can be 1
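// Example (hypothetical shapes): input (4, 8, 32, 32) with scale (1, 8, 1, 1) and bias (8, 1, 1)
// both broadcast to [1, 8, 1, 1], so num_channel = 8 and the pattern matches.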
int64_t num_channel = broadcast_inp[1];
if ((broadcast_scale[0] != 1) || (broadcast_scale[1] != 1 && broadcast_scale[1] != num_channel)) {
return false;
}
if ((broadcast_bias[0] != 1) || (broadcast_bias[1] != 1 && broadcast_bias[1] != num_channel)) {
return false;
}
for (int idx = 2; idx < max_rank; ++idx) {
if (broadcast_scale[idx] != 1 || broadcast_bias[idx] != 1) {
return false;
}
}
return true;
}

bool MulAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
auto& mul_node = node;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
mul_node.GetOutputEdgesCount() != 1) {
return false;
}
const auto& add_node = *mul_node.OutputNodesBegin();
// Make sure the two nodes do not span execution providers.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
(add_node.GetExecutionProviderType() != mul_node.GetExecutionProviderType())) {
return false;
}
// Pattern: Input -> Mul -> Add
if (mul_node.InputDefs().size() != 2 || add_node.InputDefs().size() != 2) {
return false;
}

// Get the constant input of Mul (scale) and of Add (bias).
// Only handle the case where Mul and Add each have exactly one constant and one non-constant input.
bool is_const_mul_in0 = graph_utils::NodeArgIsConstant(graph, *mul_node.InputDefs()[0]);
bool is_const_mul_in1 = graph_utils::NodeArgIsConstant(graph, *mul_node.InputDefs()[1]);
bool is_const_add_in0 = graph_utils::NodeArgIsConstant(graph, *add_node.InputDefs()[0]);
bool is_const_add_in1 = graph_utils::NodeArgIsConstant(graph, *add_node.InputDefs()[1]);
if ((is_const_mul_in0 && is_const_mul_in1) || (!is_const_mul_in0 && !is_const_mul_in1)) {
return false;
}
if ((is_const_add_in0 && is_const_add_in1) || (!is_const_add_in0 && !is_const_add_in1)) {
return false;
}

auto mul_const_idx = is_const_mul_in0 ? 0 : 1;
auto add_const_idx = is_const_add_in0 ? 0 : 1;
return IsPatternBatchnorm(
mul_node.InputDefs()[1 - mul_const_idx],
graph_utils::GetConstantInitializer(graph, mul_node.InputDefs()[mul_const_idx]->Name()),
graph_utils::GetConstantInitializer(graph, add_node.InputDefs()[add_const_idx]->Name()));
}

Status MulAddFusion::FuseMulAdd(Node& node, Graph& graph, bool& modified, const logging::Logger&) const {
auto& mul_node = node;
Node& add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
bool is_const_mul_in0 = graph_utils::NodeArgIsConstant(graph, *mul_node.InputDefs()[0]);
bool is_const_add_in0 = graph_utils::NodeArgIsConstant(graph, *add_node.InputDefs()[0]);
auto mul_const_idx = is_const_mul_in0 ? 0 : 1;
auto mul_non_const_idx = 1 - mul_const_idx;
auto add_const_idx = is_const_add_in0 ? 0 : 1;
// Before layout transformation the channel dimension is at index 1 (e.g. NCHW).
int64_t num_channel = mul_node.InputDefs()[mul_non_const_idx]->Shape()->dim(1).dim_value();

// Process scale and bias. Should be {num_channel}
const auto* scale_tensor_proto = graph_utils::GetConstantInitializer(graph, mul_node.InputDefs()[mul_const_idx]->Name());
const auto* bias_tensor_proto = graph_utils::GetConstantInitializer(graph, add_node.InputDefs()[add_const_idx]->Name());
ORT_ENFORCE(scale_tensor_proto);
ORT_ENFORCE(bias_tensor_proto);
ONNX_NAMESPACE::TensorProto reshaped_scale_proto = *scale_tensor_proto;
ONNX_NAMESPACE::TensorProto reshaped_bias_tensor_proto = *bias_tensor_proto;
reshaped_scale_proto.clear_dims();
reshaped_scale_proto.set_name(scale_tensor_proto->name() + "_reshaped");
reshaped_scale_proto.add_dims(num_channel);
reshaped_bias_tensor_proto.clear_dims();
reshaped_bias_tensor_proto.set_name(bias_tensor_proto->name() + "_reshaped");
reshaped_bias_tensor_proto.add_dims(num_channel);
NodeArg& reshaped_scale_node_arg = graph_utils::AddInitializer(graph, reshaped_scale_proto);
NodeArg& reshaped_bias_node_arg = graph_utils::AddInitializer(graph, reshaped_bias_tensor_proto);

// add initializer of mean as zeros of shape [channel]
Initializer mean_init(
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(mul_node.InputDefs()[mul_non_const_idx]->TypeAsProto()->tensor_type().elem_type()),
graph.GenerateNodeArgName(mul_node.Name() + "_mul_add_fusion_mean"),
gsl::span<const int64_t>({num_channel}));
ONNX_NAMESPACE::TensorProto mean_tensor_proto;
mean_init.ToProto(mean_tensor_proto);
NodeArg& mean_init_node_arg = graph_utils::AddInitializer(graph, mean_tensor_proto);

// add initializer of var as ones of shape [channel]
Initializer var_init(
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(mul_node.InputDefs()[mul_non_const_idx]->TypeAsProto()->tensor_type().elem_type()),
graph.GenerateNodeArgName(add_node.Name() + "_mul_add_fusion_var"),
gsl::span<const int64_t>({num_channel}));
var_init.add(1);
ONNX_NAMESPACE::TensorProto var_tensor_proto;
var_init.ToProto(var_tensor_proto);
NodeArg& var_init_node_arg = graph_utils::AddInitializer(graph, var_tensor_proto);

// add BatchNormalization
Node& bn_node = graph.AddNode(
graph.GenerateNodeName(mul_node.Name() + "/MulAddFusion"),
"BatchNormalization",
"fused Mul and Add",
gsl::span<NodeArg* const>({mul_node.MutableInputDefs()[mul_non_const_idx],
&reshaped_scale_node_arg,
&reshaped_bias_node_arg,
&mean_init_node_arg,
&var_init_node_arg}),
gsl::span<NodeArg* const>({add_node.MutableOutputDefs()[0]}),
nullptr,
kOnnxDomainAlias);
bn_node.SetExecutionProviderType(mul_node.GetExecutionProviderType());
constexpr float eps = 0.0f;
bn_node.SetSinceVersion(9);
bn_node.AddAttribute("epsilon", eps);

auto mul_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(mul_node);
auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(add_node);
// The constant input of Mul is an initializer and therefore has no producer edge, so any
// input edge that exists belongs to the non-constant input and sits at index 0.
if (!mul_input_edges.empty()) {
graph.AddEdge(
mul_input_edges[0].src_node,
bn_node.Index(),
mul_input_edges[0].src_arg_index,
0);
}

graph_utils::GraphEdge::RemoveGraphEdges(graph, mul_input_edges);
graph_utils::GraphEdge::RemoveGraphEdges(graph, add_input_edges);
graph_utils::RemoveNodeOutputEdges(graph, add_node);
graph_utils::ReplaceDownstreamNodeInput(graph, add_node, 0, bn_node, 0);
graph.RemoveNode(mul_node.Index());
graph.RemoveNode(add_node.Index());

modified = true;
return Status::OK();
}

Status MulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const GraphViewer graph_viewer{graph};
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : node_indices) {
auto* node_ptr = graph.GetNode(node_idx);
if (!node_ptr) {
continue;
}

Node& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

if (this->SatisfyCondition(graph, node, logger)) {
ORT_RETURN_IF_ERROR(this->FuseMulAdd(node, graph, modified, logger));
}
}

return Status::OK();
}

} // namespace onnxruntime
45 changes: 45 additions & 0 deletions onnxruntime/core/optimizer/mul_add_fusion.h
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

/**
@Class MulAddFusion

Graph transformer that fuses a Mul node followed by an Add node into a single
BatchNormalization node.

The Mul/Add pair can be fused safely when the constant scale and bias broadcast
only along the channel dimension. The fusion is based on the observation that:

Y = (X * scale) + bias

is mathematically equivalent to a BatchNormalization operation when the
BatchNorm parameters are set to:

mean = 0
var = 1
epsilon = 0

with

BatchNorm(X) = (X - mean) / sqrt(var + epsilon) * scale + bias
             = (X - 0) / sqrt(1 + 0) * scale + bias
             = X * scale + bias

*/
class MulAddFusion : public GraphTransformer {
public:
MulAddFusion() noexcept : GraphTransformer("MulAddFusion") {}

private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const;
Status FuseMulAdd(Node& node, Graph& graph, bool& modified, const logging::Logger&) const;
};

} // namespace onnxruntime
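Side note (not part of this change): the equivalence sketched in the header comment can be checked numerically with a tiny standalone program; the constants below are arbitrary examples.

```cpp
// Minimal sanity check: y = x * scale + bias  vs.  BatchNorm(x) with
// mean = 0, var = 1, epsilon = 0. The chosen values are arbitrary.
#include <cassert>
#include <cmath>

int main() {
  const float x = 2.5f, scale = 0.5f, bias = 1.25f;
  const float mean = 0.0f, var = 1.0f, epsilon = 0.0f;
  const float mul_add = x * scale + bias;
  const float batch_norm = (x - mean) / std::sqrt(var + epsilon) * scale + bias;
  assert(std::fabs(mul_add - batch_norm) < 1e-6f);  // the two forms agree
  return 0;
}
```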
64 changes: 64 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -58,6 +58,7 @@
#include "core/optimizer/isinf_reducesum_fusion.h"
#include "core/optimizer/label_encoder_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/mul_add_fusion.h"
#include "core/optimizer/matmul_bn_fusion.h"
#include "core/optimizer/matmul_nbits_fusion.h"
#include "core/optimizer/matmul_integer_to_float.h"
@@ -2753,6 +2754,69 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_PreserveAttentionPattern) {
ASSERT_EQ(op_count_qnn_ep["Gemm"], op_count_before["Gemm"] + expected_fusions2);
}

// Test case for MulAddFusion: Mul followed by Add can fuse into BatchNormalization
TEST_F(GraphTransformationTests, MulAddFusion_ToBatchNormalization) {
// Input shape (N, C, H, W) for the test
const std::vector<int64_t> input_shape = {4, 8, 32, 32};

auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<float>(input_shape);
// Scale for Mul: should have same dimensions as input or be broadcastable
// For BatchNormalization, scale/bias typically match the channel dimension
auto* mul_scale_arg = builder.MakeInitializer<float>({1, 8, 1, 1}, 0.0f, 1.0f); // Scale for each channel
auto* add_bias_arg = builder.MakeInitializer<float>({1, 8, 1, 1}, 0.0f, 1.0f); // Bias for each channel
auto* mul_out = builder.MakeIntermediate();
auto* output_arg = builder.MakeOutput();

builder.AddNode("Mul", {input_arg, mul_scale_arg}, {mul_out});
builder.AddNode("Add", {mul_out, add_bias_arg}, {output_arg});
};

// Pre-graph checker: Verify initial graph has Mul and Add
auto pre_graph_checker = [](Graph& graph) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_to_count["Mul"] == 1);
TEST_RETURN_IF_NOT(op_to_count["Add"] == 1);
TEST_RETURN_IF_NOT(op_to_count["BatchNormalization"] == 0); // Should not have BN initially
return Status::OK();
};

// Post-graph checker: Verify graph after fusion has BatchNormalization
auto post_graph_checker = [](Graph& graph) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_to_count["Mul"] == 0);
TEST_RETURN_IF_NOT(op_to_count["Add"] == 0);
TEST_RETURN_IF_NOT(op_to_count["BatchNormalization"] == 1); // Should have BN after fusion

// Additional checks for the fused BatchNormalization node
for (const Node& node : graph.Nodes()) {
if (node.OpType() == "BatchNormalization") {
// BatchNormalization has 5 inputs: X, scale, B, mean, var
TEST_RETURN_IF_NOT(node.InputDefs().size() == 5);

// Verify the output shape matches the original input shape
auto output_shape_proto = node.OutputDefs()[0]->Shape();
TEST_RETURN_IF_NOT(output_shape_proto != nullptr);
auto shape = utils::GetTensorShapeFromTensorShapeProto(*output_shape_proto);
TEST_RETURN_IF_NOT(shape.NumDimensions() == 4);
TEST_RETURN_IF_NOT(shape[0] == 4 && shape[1] == 8 && shape[2] == 32 && shape[3] == 32);
}
}
return Status::OK();
};

// Instantiate the MulAddFusion transformer
std::unique_ptr<GraphTransformer> transformer = std::make_unique<MulAddFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(
build_test_case, 7,
*logger_,
std::move(transformer),
TransformerLevel::Level1,
1,
pre_graph_checker,
post_graph_checker));
}
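The test above covers the positive path. A complementary negative case could assert that a scale varying over H/W does not fuse; a hedged sketch reusing the same helpers (hypothetical, not part of this change):

```cpp
// Hypothetical negative case (sketch): a per-pixel scale does not match the
// channel-only broadcast pattern, so Mul/Add must be left untouched.
TEST_F(GraphTransformationTests, MulAddFusion_NoFusionForNonChannelScale) {
  auto build_test_case = [](ModelTestBuilder& builder) {
    auto* input_arg = builder.MakeInput<float>({4, 8, 32, 32});
    auto* mul_scale_arg = builder.MakeInitializer<float>({1, 1, 32, 32}, 0.0f, 1.0f);  // varies over H/W
    auto* add_bias_arg = builder.MakeInitializer<float>({1, 8, 1, 1}, 0.0f, 1.0f);
    auto* mul_out = builder.MakeIntermediate();
    auto* output_arg = builder.MakeOutput();
    builder.AddNode("Mul", {input_arg, mul_scale_arg}, {mul_out});
    builder.AddNode("Add", {mul_out, add_bias_arg}, {output_arg});
  };

  // The graph should look the same before and after the transformer runs.
  auto graph_checker = [](Graph& graph) {
    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
    TEST_RETURN_IF_NOT(op_to_count["Mul"] == 1);
    TEST_RETURN_IF_NOT(op_to_count["Add"] == 1);
    TEST_RETURN_IF_NOT(op_to_count["BatchNormalization"] == 0);
    return Status::OK();
  };

  std::unique_ptr<GraphTransformer> transformer = std::make_unique<MulAddFusion>();
  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 7, *logger_, std::move(transformer),
                                        TransformerLevel::Level1, 1, graph_checker, graph_checker));
}
```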

#ifndef DISABLE_CONTRIB_OPS
TEST_F(GraphTransformationTests, Gemm_Relu_three_input) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx";
3 changes: 3 additions & 0 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc
@@ -4517,6 +4517,9 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) {
// changes during the layout transformation process.
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));

// MulAddFusion fuses Mul and Add into BatchNormalization, which would block the Transpose pushing exercised here, so disable it for this test.
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableSpecifiedOptimizers, "MulAddFusion"));

using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;

// set the test EP to support all ops in the model so that the layout transform applies to all nodes
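For reference, a user who hits a similar interaction can disable the new transformer for their own session in the same way. A minimal sketch using the public C++ API, assuming the config key string corresponds to kOrtSessionOptionsDisableSpecifiedOptimizers ("optimization.disable_specified_optimizers"):

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::SessionOptions session_options;
  // Comma-separated list of graph transformers to disable by name; the key string
  // is assumed to match kOrtSessionOptionsDisableSpecifiedOptimizers.
  session_options.AddConfigEntry("optimization.disable_specified_optimizers", "MulAddFusion");
  // ... create the Ort::Session with session_options as usual ...
  return 0;
}
```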