From 82420eefa5d142ed7d70b4d96db5fad7346f143e Mon Sep 17 00:00:00 2001
From: qti-hungjuiw
Date: Tue, 21 Oct 2025 23:10:52 +0800
Subject: [PATCH 1/2] Add MulAddFusion core optimizer in L1 Transformer

- Add MulAddFusion to fuse Mul and Add into BatchNormalization
- Add corresponding unit test
- Fix TransposeOptimizerTests.QnnTransposeReshape
---
 .../core/optimizer/graph_transformer_utils.cc |   2 +
 onnxruntime/core/optimizer/mul_add_fusion.cc  | 198 ++++++++++++++++++
 onnxruntime/core/optimizer/mul_add_fusion.h   |  45 ++++
 .../test/optimizer/graph_transform_test.cc    |  64 ++++++
 .../optimizer/transpose_optimizer_test.cc     |   3 +
 5 files changed, 312 insertions(+)
 create mode 100644 onnxruntime/core/optimizer/mul_add_fusion.cc
 create mode 100644 onnxruntime/core/optimizer/mul_add_fusion.h

diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 3680127ed4793..7364bbca23034 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -31,6 +31,7 @@
 #include "core/optimizer/conv_add_act_fusion.h"
 #include "core/optimizer/conv_add_fusion.h"
 #include "core/optimizer/conv_bn_fusion.h"
+#include "core/optimizer/mul_add_fusion.h"
 #include "core/optimizer/conv_mul_fusion.h"
 #include "core/optimizer/div_mul_fusion.h"
 #include "core/optimizer/double_qdq_pairs_remover.h"
@@ -256,6 +257,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
                                                                   session_options.config_options));
       transformers.emplace_back(std::make_unique<MatMulAddFusion>());
+      transformers.emplace_back(std::make_unique<MulAddFusion>());
       transformers.emplace_back(std::make_unique<ReshapeFusion>());
       transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
           session_options.free_dimension_overrides));
diff --git a/onnxruntime/core/optimizer/mul_add_fusion.cc b/onnxruntime/core/optimizer/mul_add_fusion.cc
new file mode 100644
index 0000000000000..0b28dd6ee7fed
--- /dev/null
+++ b/onnxruntime/core/optimizer/mul_add_fusion.cc
@@ -0,0 +1,198 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/status.h"
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/mul_add_fusion.h"
+#include "core/optimizer/utils.h"
+#include "core/common/logging/logging.h"
+#include "core/framework/data_types.h"
+#include "core/framework/tensorprotoutils.h"  // For utilities like TensorProtoToMLFloat16 etc.
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+
+bool IsPatternBatchnorm(const NodeArg* inp,
+                        const ONNX_NAMESPACE::TensorProto* scale,
+                        const ONNX_NAMESPACE::TensorProto* bias) {
+  int inp_rank = inp->Shape()->dim_size();
+  int scale_rank = scale->dims_size();
+  int bias_rank = bias->dims_size();
+  int max_rank = std::max({inp_rank, scale_rank, bias_rank});
+  // Currently we do not support mul + add with rank-1 inputs
+  if (max_rank <= 1) {
+    return false;
+  }
+  std::vector<int64_t> broadcast_inp(max_rank);
+  std::vector<int64_t> broadcast_scale(max_rank);
+  std::vector<int64_t> broadcast_bias(max_rank);
+  for (int idx = 0; idx < max_rank; ++idx) {
+    auto broad_idx = max_rank - 1 - idx;
+    broadcast_inp[broad_idx] = (idx < inp_rank) ? inp->Shape()->dim(inp_rank - 1 - idx).dim_value() : 1;
+    broadcast_scale[broad_idx] = (idx < scale_rank) ? scale->dims(scale_rank - 1 - idx) : 1;
+    broadcast_bias[broad_idx] = (idx < bias_rank) ? bias->dims(bias_rank - 1 - idx) : 1;
+  }
+  // broadcast_scale and broadcast_bias should be in the form of [1, num_channel, 1, ..., 1].
+  // Note: The num_channel can be 1
+  int64_t num_channel = broadcast_inp[1];
+  if ((broadcast_scale[0] != 1) || (broadcast_scale[1] != 1 && broadcast_scale[1] != num_channel)) {
+    return false;
+  }
+  if ((broadcast_bias[0] != 1) || (broadcast_bias[1] != 1 && broadcast_bias[1] != num_channel)) {
+    return false;
+  }
+  for (int idx = 2; idx < max_rank; ++idx) {
+    if (broadcast_scale[idx] != 1 || broadcast_bias[idx] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool MulAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
+  auto& mul_node = node;
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
+      mul_node.GetOutputEdgesCount() != 1) {
+    return false;
+  }
+  const auto& add_node = *mul_node.OutputNodesBegin();
+  // Make sure the two nodes do not span execution providers.
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
+      (add_node.GetExecutionProviderType() != mul_node.GetExecutionProviderType())) {
+    return false;
+  }
+  // Pattern: Input -> Mul -> Add
+  if (mul_node.InputDefs().size() != 2 || add_node.InputDefs().size() != 2) {
+    return false;
+  }
+
+  // Get the second input of Mul (scale) and Add (bias).
+  // Handle only the case where Mul and Add each have exactly one constant and one non-constant input.
+  bool is_const_mul_in0 = graph_utils::NodeArgIsConstant(graph, *mul_node.InputDefs()[0]);
+  bool is_const_mul_in1 = graph_utils::NodeArgIsConstant(graph, *mul_node.InputDefs()[1]);
+  bool is_const_add_in0 = graph_utils::NodeArgIsConstant(graph, *add_node.InputDefs()[0]);
+  bool is_const_add_in1 = graph_utils::NodeArgIsConstant(graph, *add_node.InputDefs()[1]);
+  if ((is_const_mul_in0 && is_const_mul_in1) || (!is_const_mul_in0 && !is_const_mul_in1)) {
+    return false;
+  }
+  if ((is_const_add_in0 && is_const_add_in1) || (!is_const_add_in0 && !is_const_add_in1)) {
+    return false;
+  }
+
+  auto mul_const_idx = is_const_mul_in0 ? 0 : 1;
+  auto add_const_idx = is_const_add_in0 ? 0 : 1;
+  return IsPatternBatchnorm(
+      mul_node.InputDefs()[1 - mul_const_idx],
+      graph_utils::GetConstantInitializer(graph, mul_node.InputDefs()[mul_const_idx]->Name()),
+      graph_utils::GetConstantInitializer(graph, add_node.InputDefs()[add_const_idx]->Name()));
+}
+
+Status MulAddFusion::FuseMulAdd(Node& node, Graph& graph, bool& modified, const logging::Logger&) const {
+  auto& mul_node = node;
+  Node& add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
+  bool is_const_mul_in0 = graph_utils::NodeArgIsConstant(graph, *mul_node.InputDefs()[0]);
+  bool is_const_add_in0 = graph_utils::NodeArgIsConstant(graph, *add_node.InputDefs()[0]);
+  auto mul_const_idx = is_const_mul_in0 ? 0 : 1;
+  auto mul_non_const_idx = 1 - mul_const_idx;
+  auto add_const_idx = is_const_add_in0 ? 0 : 1;
+  // Before layout transform, channel is the 1st dimension
+  int64_t num_channel = mul_node.InputDefs()[mul_non_const_idx]->Shape()->dim(1).dim_value();
+
+  // Reshape the scale and bias initializers to shape {num_channel}, as expected by BatchNormalization
+  const auto* scale_tensor_proto = graph_utils::GetConstantInitializer(graph, mul_node.InputDefs()[mul_const_idx]->Name());
+  const auto* bias_tensor_proto = graph_utils::GetConstantInitializer(graph, add_node.InputDefs()[add_const_idx]->Name());
+  ORT_ENFORCE(scale_tensor_proto);
+  ORT_ENFORCE(bias_tensor_proto);
+  ONNX_NAMESPACE::TensorProto reshaped_scale_proto = *scale_tensor_proto;
+  ONNX_NAMESPACE::TensorProto reshaped_bias_tensor_proto = *bias_tensor_proto;
+  reshaped_scale_proto.clear_dims();
+  reshaped_scale_proto.set_name(scale_tensor_proto->name() + "_reshaped");
+  reshaped_scale_proto.add_dims(num_channel);
+  reshaped_bias_tensor_proto.clear_dims();
+  reshaped_bias_tensor_proto.set_name(bias_tensor_proto->name() + "_reshaped");
+  reshaped_bias_tensor_proto.add_dims(num_channel);
+  NodeArg& reshaped_scale_node_arg = graph_utils::AddInitializer(graph, reshaped_scale_proto);
+  NodeArg& reshaped_bias_node_arg = graph_utils::AddInitializer(graph, reshaped_bias_tensor_proto);
+
+  // add initializer of mean as zeros of shape [channel]
+  Initializer mean_init(
+      static_cast<ONNX_NAMESPACE::TensorProto_DataType>(mul_node.InputDefs()[mul_non_const_idx]->TypeAsProto()->tensor_type().elem_type()),
+      graph.GenerateNodeArgName(mul_node.Name() + "_mul_add_fusion_mean"),
+      gsl::span<const int64_t>({num_channel}));
+  ONNX_NAMESPACE::TensorProto mean_tensor_proto;
+  mean_init.ToProto(mean_tensor_proto);
+  NodeArg& mean_init_node_arg = graph_utils::AddInitializer(graph, mean_tensor_proto);
+
+  // add initializer of var as ones of shape [channel]
+  Initializer var_init(
+      static_cast<ONNX_NAMESPACE::TensorProto_DataType>(mul_node.InputDefs()[mul_non_const_idx]->TypeAsProto()->tensor_type().elem_type()),
+      graph.GenerateNodeArgName(add_node.Name() + "_mul_add_fusion_var"),
+      gsl::span<const int64_t>({num_channel}));
+  var_init.add(1);
+  ONNX_NAMESPACE::TensorProto var_tensor_proto;
+  var_init.ToProto(var_tensor_proto);
+  NodeArg& var_init_node_arg = graph_utils::AddInitializer(graph, var_tensor_proto);
+
+  // add BatchNormalization
+  Node& bn_node = graph.AddNode(
+      graph.GenerateNodeName(mul_node.Name() + "/MulAddFusion"),
+      "BatchNormalization",
+      "fused Mul and Add",
+      gsl::span<NodeArg* const>({mul_node.MutableInputDefs()[mul_non_const_idx],
+                                 &reshaped_scale_node_arg,
+                                 &reshaped_bias_node_arg,
+                                 &mean_init_node_arg,
+                                 &var_init_node_arg}),
+      gsl::span<NodeArg* const>({add_node.MutableOutputDefs()[0]}),
+      nullptr,
+      kOnnxDomainAlias);
+  bn_node.SetExecutionProviderType(mul_node.GetExecutionProviderType());
+  constexpr float eps = 0.0f;
+  bn_node.SetSinceVersion(9);
+  bn_node.AddAttribute("epsilon", eps);
+
+  auto mul_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(mul_node);
+  auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(add_node);
+  if (!graph_utils::IsGraphInput(graph, mul_node.InputDefs()[mul_non_const_idx])) {
+    graph.AddEdge(
+        mul_input_edges[mul_non_const_idx].src_node,
+        bn_node.Index(),
+        mul_input_edges[mul_non_const_idx].src_arg_index,
+        0);
+  }
+
+  graph_utils::GraphEdge::RemoveGraphEdges(graph, mul_input_edges);
+  graph_utils::GraphEdge::RemoveGraphEdges(graph, add_input_edges);
+  graph_utils::RemoveNodeOutputEdges(graph, add_node);
+  graph_utils::ReplaceDownstreamNodeInput(graph, add_node, 0, bn_node, 0);
+  graph.RemoveNode(mul_node.Index());
+  graph.RemoveNode(add_node.Index());
+
+  modified = true;
+  return Status::OK();
+}
+
+Status MulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  const GraphViewer graph_viewer{graph};
+  const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
+  for (const auto node_idx : node_indices) {
+    auto* node_ptr = graph.GetNode(node_idx);
+    if (!node_ptr) {
+      continue;
+    }
+
+    Node& node = *node_ptr;
+    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+    if (this->SatisfyCondition(graph, node, logger)) {
+      ORT_RETURN_IF_ERROR(this->FuseMulAdd(node, graph, modified, logger));
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/mul_add_fusion.h b/onnxruntime/core/optimizer/mul_add_fusion.h
new file mode 100644
index 0000000000000..f5fbfc6591e3b
--- /dev/null
+++ b/onnxruntime/core/optimizer/mul_add_fusion.h
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+@Class MulAddFusion
+
+Graph transformer that fuses a Mul followed by an Add into a single BatchNormalization node.
+
+Determines whether a Mul followed by an Add can be safely fused into a
+BatchNormalization node. The fusion is based on the observation that:
+
+    Y = (X * scale) + bias
+
+is mathematically equivalent to a BatchNormalization operation when the
+BatchNorm parameters are set to:
+
+    mean = 0
+    var = 1
+    epsilon = 0
+
+with
+
+    BatchNorm(X) = (X - mean) / sqrt(var + epsilon) * scale + bias
+                 = (X - 0) / sqrt(1 + 0) * scale + bias
+                 = X * scale + bias
+
+*/
+class MulAddFusion : public GraphTransformer {
+ public:
+  MulAddFusion() noexcept : GraphTransformer("MulAddFusion") {}
+
+ private:
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+
+  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const;
+  Status FuseMulAdd(Node& node, Graph& graph, bool& modified, const logging::Logger&) const;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 5ad5811cfd7bc..8382b86fa1875 100644
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -58,6 +58,7 @@
 #include "core/optimizer/isinf_reducesum_fusion.h"
 #include "core/optimizer/label_encoder_fusion.h"
 #include "core/optimizer/matmul_add_fusion.h"
+#include "core/optimizer/mul_add_fusion.h"
 #include "core/optimizer/matmul_bn_fusion.h"
 #include "core/optimizer/matmul_nbits_fusion.h"
 #include "core/optimizer/matmul_integer_to_float.h"
@@ -2753,6 +2754,69 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_PreserveAttentionPattern) {
   ASSERT_EQ(op_count_qnn_ep["Gemm"], op_count_before["Gemm"] + expected_fusions2);
 }
 
+// Test case for MulAddFusion: Mul followed by Add can fuse into BatchNormalization
+TEST_F(GraphTransformationTests, MulAddFusion_ToBatchNormalization) {
+  // Input shape (N, C, H, W) for the test
+  const std::vector<int64_t> input_shape = {{4, 8, 32, 32}};  // N, C, H, W for example
+
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    auto* input_arg = builder.MakeInput<float>(input_shape);
+    // Scale for Mul: should have the same dimensions as the input or be broadcastable.
+    // For BatchNormalization, scale/bias typically match the channel dimension.
+    auto* mul_scale_arg = builder.MakeInitializer<float>({1, 8, 1, 1}, 0.0f, 1.0f);  // Scale for each channel
+    auto* add_bias_arg = builder.MakeInitializer<float>({1, 8, 1, 1}, 0.0f, 1.0f);   // Bias for each channel
+    auto* mul_out = builder.MakeIntermediate();
+    auto* output_arg = builder.MakeOutput();
+
+    builder.AddNode("Mul", {input_arg, mul_scale_arg}, {mul_out});
+    builder.AddNode("Add", {mul_out, add_bias_arg}, {output_arg});
+  };
+
+  // Pre-graph checker: Verify initial graph has Mul and Add
+  auto pre_graph_checker = [](Graph& graph) {
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_to_count["Mul"] == 1);
+    TEST_RETURN_IF_NOT(op_to_count["Add"] == 1);
+    TEST_RETURN_IF_NOT(op_to_count["BatchNormalization"] == 0);  // Should not have BN initially
+    return Status::OK();
+  };
+
+  // Post-graph checker: Verify graph after fusion has BatchNormalization
+  auto post_graph_checker = [](Graph& graph) {
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_to_count["Mul"] == 0);
+    TEST_RETURN_IF_NOT(op_to_count["Add"] == 0);
+    TEST_RETURN_IF_NOT(op_to_count["BatchNormalization"] == 1);  // Should have BN after fusion
+
+    // Additional checks for the fused BatchNormalization node
+    for (const Node& node : graph.Nodes()) {
+      if (node.OpType() == "BatchNormalization") {
+        // BatchNormalization has 5 inputs: X, scale, B, mean, var
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 5);
+
+        // Verify the output shape matches the original input shape
+        auto output_shape_proto = node.OutputDefs()[0]->Shape();
+        TEST_RETURN_IF_NOT(output_shape_proto != nullptr);
+        auto shape = utils::GetTensorShapeFromTensorShapeProto(*output_shape_proto);
+        TEST_RETURN_IF_NOT(shape.NumDimensions() == 4);
+        TEST_RETURN_IF_NOT(shape[0] == 4 && shape[1] == 8 && shape[2] == 32 && shape[3] == 32);
+      }
+    }
+    return Status::OK();
+  };
+
+  // Instantiate the MulAddFusion transformer
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<MulAddFusion>();
+  ASSERT_STATUS_OK(TestGraphTransformer(
+      build_test_case, 7,
+      *logger_,
+      std::move(transformer),
+      TransformerLevel::Level1,
+      1,
+      pre_graph_checker,
+      post_graph_checker));
+}
+
 #ifndef DISABLE_CONTRIB_OPS
 TEST_F(GraphTransformationTests, Gemm_Relu_three_input) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx";
diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
index 3cc7eb8f3675f..06c523145ebdd 100644
--- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
@@ -4517,6 +4517,9 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) {
   // changes during the layout transformation process.
   ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
 
+  // MulAddFusion fuses Mul and Add into BatchNormalization, which prevents the Transpose from being pushed through.
+  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableSpecifiedOptimizers, "MulAddFusion"));
+
   using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
 
   // set the test EP to support all ops in the model so that the layout transform applies to all nodes

From e182ee16f7f03d065ef41d3fcb409c7f7d1614c5 Mon Sep 17 00:00:00 2001
From: qti-hungjuiw
Date: Sun, 26 Oct 2025 18:30:51 +0800
Subject: [PATCH 2/2] Fix crash issue on Mul with nullptr shape and add logging
 messages

---
 onnxruntime/core/optimizer/mul_add_fusion.cc | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/core/optimizer/mul_add_fusion.cc b/onnxruntime/core/optimizer/mul_add_fusion.cc
index 0b28dd6ee7fed..d0c816e0cde4e 100644
--- a/onnxruntime/core/optimizer/mul_add_fusion.cc
+++ b/onnxruntime/core/optimizer/mul_add_fusion.cc
@@ -17,13 +17,18 @@ namespace onnxruntime {
 
 bool IsPatternBatchnorm(const NodeArg* inp,
                         const ONNX_NAMESPACE::TensorProto* scale,
-                        const ONNX_NAMESPACE::TensorProto* bias) {
+                        const ONNX_NAMESPACE::TensorProto* bias,
+                        const logging::Logger& logger) {
+  if (!inp->Shape()) {
+    LOGS(logger, VERBOSE) << "Skip MulAddFusion since " << inp->Name() << " has a nullptr shape";
+    return false;
+  }
   int inp_rank = inp->Shape()->dim_size();
   int scale_rank = scale->dims_size();
   int bias_rank = bias->dims_size();
   int max_rank = std::max({inp_rank, scale_rank, bias_rank});
-  // Currently we do not support mul + add with rank-1 inputs
   if (max_rank <= 1) {
+    LOGS(logger, VERBOSE) << "Skip MulAddFusion since the max rank among " << inp->Name() << ", " << scale->name() << " and " << bias->name() << " is <= 1";
     return false;
   }
   std::vector<int64_t> broadcast_inp(max_rank);
@@ -39,20 +44,23 @@ bool IsPatternBatchnorm(const NodeArg* inp,
   // Note: The num_channel can be 1
   int64_t num_channel = broadcast_inp[1];
   if ((broadcast_scale[0] != 1) || (broadcast_scale[1] != 1 && broadcast_scale[1] != num_channel)) {
+    LOGS(logger, VERBOSE) << "Skip MulAddFusion since " << scale->name() << " has an unsupported shape.";
     return false;
   }
   if ((broadcast_bias[0] != 1) || (broadcast_bias[1] != 1 && broadcast_bias[1] != num_channel)) {
+    LOGS(logger, VERBOSE) << "Skip MulAddFusion since " << bias->name() << " has an unsupported shape.";
     return false;
   }
   for (int idx = 2; idx < max_rank; ++idx) {
     if (broadcast_scale[idx] != 1 || broadcast_bias[idx] != 1) {
+      LOGS(logger, VERBOSE) << "Skip MulAddFusion since " << scale->name() << " or " << bias->name() << " has an unsupported shape.";
       return false;
     }
   }
   return true;
 }
 
-bool MulAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
+bool MulAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
   auto& mul_node = node;
   if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
       mul_node.GetOutputEdgesCount() != 1) {
@@ -66,6 +74,7 @@ bool MulAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const
   }
   // Pattern: Input -> Mul -> Add
   if (mul_node.InputDefs().size() != 2 || add_node.InputDefs().size() != 2) {
+    LOGS(logger, VERBOSE) << "Skip MulAddFusion since " << mul_node.Name() << " or " << add_node.Name() << " does not have exactly 2 inputs.";
     return false;
   }
 
@@ -76,9 +85,11 @@ bool MulAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const
   bool is_const_add_in0 = graph_utils::NodeArgIsConstant(graph, *add_node.InputDefs()[0]);
   bool is_const_add_in1 = graph_utils::NodeArgIsConstant(graph, *add_node.InputDefs()[1]);
   if ((is_const_mul_in0 && is_const_mul_in1) || (!is_const_mul_in0 && !is_const_mul_in1)) {
+    LOGS(logger, VERBOSE) << "Skip MulAddFusion since " << mul_node.Name() << " should have exactly 1 constant and 1 non-constant input.";
     return false;
   }
   if ((is_const_add_in0 && is_const_add_in1) || (!is_const_add_in0 && !is_const_add_in1)) {
+    LOGS(logger, VERBOSE) << "Skip MulAddFusion since " << add_node.Name() << " should have exactly 1 constant and 1 non-constant input.";
     return false;
   }
 
@@ -87,7 +98,8 @@ bool MulAddFusion::SatisfyCondition(const Graph& graph, const Node& node, const
   return IsPatternBatchnorm(
       mul_node.InputDefs()[1 - mul_const_idx],
       graph_utils::GetConstantInitializer(graph, mul_node.InputDefs()[mul_const_idx]->Name()),
-      graph_utils::GetConstantInitializer(graph, add_node.InputDefs()[add_const_idx]->Name()));
+      graph_utils::GetConstantInitializer(graph, add_node.InputDefs()[add_const_idx]->Name()),
+      logger);
 }
 
 Status MulAddFusion::FuseMulAdd(Node& node, Graph& graph, bool& modified, const logging::Logger&) const {
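
Reviewer note (not part of the patches): the identity that the mul_add_fusion.h
comment relies on can be sanity-checked in isolation. Below is a minimal
standalone C++17 sketch; the tensor values and the tiny NCHW layout are made up
purely for illustration.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      // NCHW input with N=1, C=2, H=1, W=3, plus per-channel scale and bias.
      const std::vector<float> x = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
      const std::vector<float> scale = {0.5f, 2.f};
      const std::vector<float> bias = {1.f, -1.f};
      // The exact values MulAddFusion synthesizes for the BatchNorm inputs.
      const float mean = 0.f, var = 1.f, eps = 0.f;
      for (int c = 0; c < 2; ++c) {
        for (int i = 0; i < 3; ++i) {
          const float v = x[c * 3 + i];
          const float bn = (v - mean) / std::sqrt(var + eps) * scale[c] + bias[c];
          const float mul_add = v * scale[c] + bias[c];
          // Both columns print identical values, so diff is always 0.
          std::printf("bn=%g mul_add=%g diff=%g\n", bn, mul_add, bn - mul_add);
        }
      }
      return 0;
    }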
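
The shape test in IsPatternBatchnorm can likewise be prototyped outside the
graph machinery. The sketch below is a simplified re-implementation for
reviewers, not the patch's code: it right-aligns the operand shapes the way
ONNX multidirectional broadcasting does, then checks for the per-channel
[1, C, 1, ..., 1] form. The function and variable names are invented here.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // True when `scale` and `bias`, right-aligned against `inp`, reduce to the
    // per-channel form [1, C, 1, ..., 1] that BatchNormalization expects.
    bool LooksLikeBatchnorm(const std::vector<int64_t>& inp,
                            const std::vector<int64_t>& scale,
                            const std::vector<int64_t>& bias) {
      const int max_rank = static_cast<int>(std::max({inp.size(), scale.size(), bias.size()}));
      if (max_rank <= 1) return false;  // rank-1 patterns are rejected, as in the patch
      auto aligned = [max_rank](const std::vector<int64_t>& dims) {
        std::vector<int64_t> out(max_rank, 1);  // pad missing leading dims with 1
        for (size_t i = 0; i < dims.size(); ++i)
          out[max_rank - 1 - i] = dims[dims.size() - 1 - i];
        return out;
      };
      const auto a_inp = aligned(inp), a_scale = aligned(scale), a_bias = aligned(bias);
      const int64_t c = a_inp[1];  // channel dim before layout transform
      auto per_channel = [c](const std::vector<int64_t>& d) {
        if (d[0] != 1 || (d[1] != 1 && d[1] != c)) return false;
        for (size_t i = 2; i < d.size(); ++i)
          if (d[i] != 1) return false;
        return true;
      };
      return per_channel(a_scale) && per_channel(a_bias);
    }

    int main() {
      // {8, 1, 1} right-aligns to [1, 8, 1, 1], so this is fusable: prints 1.
      std::printf("%d\n", LooksLikeBatchnorm({4, 8, 32, 32}, {1, 8, 1, 1}, {8, 1, 1}));
      // A spatial dim in the scale breaks the per-channel form: prints 0.
      std::printf("%d\n", LooksLikeBatchnorm({4, 8, 32, 32}, {1, 8, 32, 1}, {1, 8, 1, 1}));
      return 0;
    }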
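
Since the fusion runs by default at Level 1, clients can opt out per session
the same way the updated transpose-optimizer test does. Below is a sketch
against the public C++ API under the assumption that the standard
onnxruntime_cxx_api.h and onnxruntime_session_options_config_keys.h headers
are available; the model path is a placeholder.

    #include "onnxruntime_cxx_api.h"
    #include "onnxruntime_session_options_config_keys.h"

    int main() {
      Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
      Ort::SessionOptions so;
      // Disable MulAddFusion by name; all other Level-1 transformers still run.
      so.AddConfigEntry(kOrtSessionOptionsDisableSpecifiedOptimizers, "MulAddFusion");
      Ort::Session session(env, "model.onnx", so);  // "model.onnx" is a placeholder
      return 0;
    }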