Skip to content

Commit

Permalink
Add contrib groupnorm (#3678)
Browse files Browse the repository at this point in the history
GroupNorm Contrib Operator from Microsoft's Contrib set: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.GroupNorm

Needed for additional support on optimized Stable diffusion models outlined here: #3538

This operator maps almost all inputs 1:1 with our existing groupNormalization operator and renames a few fields/attributes
  • Loading branch information
TedThemistokleous authored Feb 13, 2025
1 parent f18e695 commit 851bcd2
Show file tree
Hide file tree
Showing 25 changed files with 787 additions and 16 deletions.
84 changes: 75 additions & 9 deletions src/onnx/parse_groupnorm.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -25,43 +25,96 @@
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/permutation.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

// Insert a transpose converting between channels-first (NC...) and
// channels-last (N...C) layouts. With invert == false the permutation moves
// the channel axis to the back (NC... -> N...C); with invert == true the
// inverse permutation is applied (N...C -> NC...).
static instruction_ref
apply_channels_last_perm(const onnx_parser::node_info& info, instruction_ref ins, bool invert)
{
    const auto rank = ins->get_shape().ndim();
    // perm = {0, 2, 3, ..., rank-1, 1}: batch axis stays, spatial axes shift
    // left by one, channel axis goes last.
    std::vector<int64_t> perm(rank, 0);
    std::iota(perm.begin() + 1, perm.end() - 1, 2);
    perm.back() = 1;
    auto layout = invert ? invert_permutation(perm) : perm;
    return info.add_instruction(make_op("transpose", {{"permutation", layout}}), ins);
}

struct parse_groupnorm : op_parser<parse_groupnorm>
{
std::vector<op_desc> operators() const { return {{"GroupNormalization"}}; }
std::vector<op_desc> operators() const
{
return {{"GroupNormalization", "GroupNorm"}, {"GroupNorm", "Contrib_GroupNorm"}};
}

instruction_ref parse(const op_desc& /*opd*/,
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
bool is_contrib = (opd.op_name == ("Contrib_GroupNorm"));

float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
size_t num_groups;
if(contains(info.attributes, "num_groups"))
if(contains(info.attributes, "num_groups") or contains(info.attributes, "groups"))
{
num_groups = parser.parse_value(info.attributes.at("num_groups")).at<size_t>();
if(is_contrib)
{
num_groups =
std::abs(parser.parse_value(info.attributes.at("groups")).at<int64_t>());
}
else
{
num_groups =
std::abs(parser.parse_value(info.attributes.at("num_groups")).at<int64_t>());
}
}
else
{
MIGRAPHX_THROW("PARSE_GROUPNORM: num_groups must be available");
}

bool is_channels_last = false;
if(is_contrib)
{ // default state for GroupNorm Contrib op
is_channels_last = true;
if(contains(info.attributes, "channels_last"))
{
is_channels_last =
(1 == parser.parse_value(info.attributes.at("channels_last")).at<size_t>());
}
}

bool silu_activation = false;
if(contains(info.attributes, "activation") and is_contrib)
{
silu_activation =
(1 == parser.parse_value(info.attributes.at("activation")).at<size_t>());
}
else if(is_contrib)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: activation must be available");
}

if(args.size() != 3)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input count");
}

auto x = args.at(0);
auto scale = args.at(1);
auto bias = args.at(2);
// Adjust channels from channels_last -> NCHW if channels_last is set for the contrib op
auto x = args.at(0);
if(is_channels_last and is_contrib)
{
x = apply_channels_last_perm(info, x, true);
}

auto scale = args.at(1); // gamma in the GroupNorm contrib case
auto bias = args.at(2); // beta in the GroupNorm contrib case

auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
Expand Down Expand Up @@ -120,7 +173,20 @@ struct parse_groupnorm : op_parser<parse_groupnorm>
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
auto y = info.add_instruction(make_op("add"), scaled, bias_bcast);
return info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);
auto output = info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);

// Convert to NCHW -> channels_last for contrib GroupNorm
if(is_channels_last and is_contrib)
{
output = apply_channels_last_perm(info, output, false);
}
if(silu_activation)
{
// SiLU activation is just out = x * sigmoid(x)
auto sigmoid = info.add_instruction(make_op("sigmoid"), output);
output = info.add_instruction(make_op("mul"), output, sigmoid);
}
return output;
}
};

Expand Down
106 changes: 106 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4862,6 +4862,112 @@ def group_norm_invalid_bias_shape_test():
return group_norm_test([1, 4, 3, 3], [2], [3], [1, 4, 3, 3], 2)


def group_norm_contrib_test(x_dims,
                            gamma_dims,
                            beta_dims,
                            y_dims,
                            num_groups,
                            activation,
                            channels_last,
                            eps_value=1e-5,
                            dtype=TensorProto.FLOAT):
    """Build a Microsoft-contrib GroupNorm node with its x/gamma/beta inputs
    and y output value_infos, returned as ([nodes], [inputs], [outputs])."""
    inputs = [
        helper.make_tensor_value_info(name, dtype, dims)
        for name, dims in (('x', x_dims), ('gamma', gamma_dims), ('beta',
                                                                  beta_dims))
    ]
    outputs = [helper.make_tensor_value_info('y', dtype, y_dims)]

    node = onnx.helper.make_node('GroupNorm',
                                 inputs=['x', 'gamma', 'beta'],
                                 outputs=['y'],
                                 activation=activation,
                                 channels_last=channels_last,
                                 groups=num_groups,
                                 epsilon=eps_value)

    return ([node], inputs, outputs)


@onnx_test()
def group_norm_contrib_3d_test():
    # 3-D input, 2 groups, no activation, channels-first (channels_last=0).
    return group_norm_contrib_test(x_dims=[1, 4, 2],
                                   gamma_dims=[2],
                                   beta_dims=[2],
                                   y_dims=[1, 4, 2],
                                   num_groups=2,
                                   activation=0,
                                   channels_last=0)


@onnx_test()
def group_norm_contrib_3d_channel_last_test():
    # 3-D input, 2 groups, no activation, channels-last layout.
    return group_norm_contrib_test(x_dims=[1, 4, 2],
                                   gamma_dims=[2],
                                   beta_dims=[2],
                                   y_dims=[1, 4, 2],
                                   num_groups=2,
                                   activation=0,
                                   channels_last=1)


@onnx_test()
def group_norm_contrib_3d_channel_last_half_test():
    # Same as the channels-last 3-D case, but with fp16 tensors.
    return group_norm_contrib_test(x_dims=[1, 4, 2],
                                   gamma_dims=[2],
                                   beta_dims=[2],
                                   y_dims=[1, 4, 2],
                                   num_groups=2,
                                   activation=0,
                                   channels_last=1,
                                   dtype=TensorProto.FLOAT16)


@onnx_test()
def group_norm_contrib_3d_channel_last_bf16_test():
    # Same as the channels-last 3-D case, but with bfloat16 tensors.
    return group_norm_contrib_test(x_dims=[1, 4, 2],
                                   gamma_dims=[2],
                                   beta_dims=[2],
                                   y_dims=[1, 4, 2],
                                   num_groups=2,
                                   activation=0,
                                   channels_last=1,
                                   dtype=TensorProto.BFLOAT16)


@onnx_test()
def group_norm_contrib_silu_3d_test():
    # 3-D input with SiLU activation enabled (activation=1), channels-first.
    return group_norm_contrib_test(x_dims=[1, 4, 2],
                                   gamma_dims=[2],
                                   beta_dims=[2],
                                   y_dims=[1, 4, 2],
                                   num_groups=2,
                                   activation=1,
                                   channels_last=0)


@onnx_test()
def group_norm_contrib_channels_last_3d_test():
    # NOTE(review): identical arguments to group_norm_contrib_3d_channel_last_test
    # above — consider removing one of the two generators.
    return group_norm_contrib_test(x_dims=[1, 4, 2],
                                   gamma_dims=[2],
                                   beta_dims=[2],
                                   y_dims=[1, 4, 2],
                                   num_groups=2,
                                   activation=0,
                                   channels_last=1)


@onnx_test()
def group_norm_contrib_channels_last_4d_test():
    # 4-D channels-last input, 2 groups, no activation.
    return group_norm_contrib_test(x_dims=[1, 3, 3, 4],
                                   gamma_dims=[2],
                                   beta_dims=[2],
                                   y_dims=[1, 3, 3, 4],
                                   num_groups=2,
                                   activation=0,
                                   channels_last=1)


@onnx_test()
def group_norm_contrib_channels_last_and_silu_3d_test():
    # Channels-last layout combined with SiLU activation.
    return group_norm_contrib_test(x_dims=[1, 4, 2],
                                   gamma_dims=[2],
                                   beta_dims=[2],
                                   y_dims=[1, 4, 2],
                                   num_groups=2,
                                   activation=1,
                                   channels_last=1)


@onnx_test()
def group_norm_contrib_no_activation_attr_test():
    """Negative test: a contrib GroupNorm node missing the required
    'activation' attribute; the parser is expected to throw on load."""
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4, 2])
    gamma = helper.make_tensor_value_info('gamma', TensorProto.FLOAT, [2])
    beta = helper.make_tensor_value_info('beta', TensorProto.FLOAT, [2])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4, 2])

    # Fix: input name was 'Beta', which names no graph input — it must match
    # the 'beta' value_info so the model fails only on the missing attribute.
    node = onnx.helper.make_node('GroupNorm',
                                 inputs=['x', 'gamma', 'beta'],
                                 outputs=['y'],
                                 channels_last=0,
                                 groups=2)

    return ([node], [x, gamma, beta], [y])


@onnx_test()
def group_norm_contrib_no_num_groups_attr_test():
    """Negative test: a contrib GroupNorm node missing the required
    'groups' attribute; the parser is expected to throw on load."""
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4, 2])
    gamma = helper.make_tensor_value_info('gamma', TensorProto.FLOAT, [2])
    beta = helper.make_tensor_value_info('beta', TensorProto.FLOAT, [2])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4, 2])

    # Fix: input name was 'Beta', which names no graph input — it must match
    # the 'beta' value_info so the model fails only on the missing attribute.
    node = onnx.helper.make_node('GroupNorm',
                                 inputs=['x', 'gamma', 'beta'],
                                 outputs=['y'],
                                 activation=0,
                                 channels_last=0)

    return ([node], [x, gamma, beta], [y])


@onnx_test()
def group_query_attention_test():
qkv = helper.make_tensor_value_info('qkv', TensorProto.FLOAT16,
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added test/onnx/group_norm_contrib_3d_test.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
29 changes: 29 additions & 0 deletions test/onnx/group_norm_contrib_channels_last_and_silu_3d_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

1group_norm_contrib_channels_last_and_silu_3d_test:�
i
x
gamma
betay" GroupNorm*

activation�*
channels_last�*
epsilon��'7�*
groups�1group_norm_contrib_channels_last_and_silu_3d_testZ
x



Z
gamma


Z
beta


b
y



B
Expand Down
Binary file not shown.
Binary file not shown.
Binary file added test/onnx/group_norm_contrib_silu_3d_test.onnx
Binary file not shown.
8 changes: 5 additions & 3 deletions test/onnx/include/onnx_test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,16 @@ make_group_norm(const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& reshape_dims,
const std::vector<int64_t>& reduce_axes,
const float eps_value = 1e-5f,
const migraphx::shape::type_t dtype = migraphx::shape::float_type)
const migraphx::shape::type_t dtype = migraphx::shape::float_type,
const std::string& param1_name = "scale",
const std::string& param2_name = "bias")
{
migraphx::program p;
auto* mm = p.get_main_module();

auto x = mm->add_parameter("x", {dtype, input_dims});
auto scale = mm->add_parameter("scale", {dtype, scale_dims});
auto bias = mm->add_parameter("bias", {dtype, bias_dims});
auto scale = mm->add_parameter(param1_name, {dtype, scale_dims});
auto bias = mm->add_parameter(param2_name, {dtype, bias_dims});

auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});

Expand Down
10 changes: 6 additions & 4 deletions test/onnx/include/onnx_verify_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -33,7 +33,9 @@ template <typename T = float>
std::vector<T> norm_test(const std::vector<size_t>& x_dims,
std::vector<T>& scale,
std::vector<T>& bias,
migraphx::program p)
migraphx::program p,
const std::string& scale_str = std::string{"scale"},
const std::string& bias_str = std::string{"bias"})
{
p.compile(migraphx::make_target("ref"));

Expand All @@ -46,8 +48,8 @@ std::vector<T> norm_test(const std::vector<size_t>& x_dims,

migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_x, x.data());
pp["scale"] = migraphx::argument(s_s, scale.data());
pp["bias"] = migraphx::argument(s_b, bias.data());
pp[scale_str] = migraphx::argument(s_s, scale.data());
pp[bias_str] = migraphx::argument(s_b, bias.data());

auto result = p.eval(pp).back();

Expand Down
31 changes: 31 additions & 0 deletions test/onnx/parse/group_norm_contrib_3d_no_activation_err_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

// Parsing a contrib GroupNorm model that lacks the required 'activation'
// attribute must raise a parse error.
TEST_CASE(group_norm_contrib_no_activation_err_test)
{
    auto parse_model = [] { read_onnx("group_norm_contrib_no_activation_attr_test.onnx"); };
    EXPECT(test::throws(parse_model));
}
31 changes: 31 additions & 0 deletions test/onnx/parse/group_norm_contrib_3d_no_num_groups_err_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

// Parsing a contrib GroupNorm model that lacks the required 'groups'
// attribute must raise a parse error.
TEST_CASE(group_norm_contrib_no_num_groups_err_test)
{
    auto parse_model = [] { read_onnx("group_norm_contrib_no_num_groups_attr_test.onnx"); };
    EXPECT(test::throws(parse_model));
}
Loading

0 comments on commit 851bcd2

Please sign in to comment.