Skip to content

Commit

Permalink
Quantizelinear nearbyint fix (#3819)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Feb 18, 2025
1 parent 770bc4a commit c8b73b1
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 22 deletions.
15 changes: 12 additions & 3 deletions src/include/migraphx/op/quantizelinear.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -87,8 +87,17 @@ struct quantizelinear
auto min_value = std::numeric_limits<quant_type>::lowest();
auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) {
double quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
static_cast<double>(zero_pts[i]);
double quantized;
if constexpr(std::is_integral<quant_type>{})
{
quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
static_cast<double>(zero_pts[i]);
}
else
{
quantized = static_cast<double>(input[i]) / static_cast<double>(scales[i]) +
static_cast<double>(zero_pts[i]);
}
output[i] = std::max(static_cast<double>(min_value),
std::min(static_cast<double>(max_value), quantized));
});
Expand Down
11 changes: 8 additions & 3 deletions src/rewrite_quantization.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -47,7 +47,12 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);

instruction_ref add_zero_point = div;
if(shape::is_integral(ins->get_shape().type()))
{
add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
}

if(ins->inputs().size() == 3)
{
Expand Down
55 changes: 40 additions & 15 deletions test/ref/quantizelinear.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
Expand Down Expand Up @@ -85,59 +86,83 @@ TEST_CASE(quantizelinear_2)
}

// Reference-target test of quantizelinear producing the fp8 type DType: a program with
// a single quantizelinear instruction over literal inputs is compiled for the "ref"
// target, evaluated, and its output compared against a gold vector.
// NOTE(review): this span is a rendered diff, so superseded and current lines appear
// side by side (two function names, two `z`/`results_vector`/`gold` definitions) --
// confirm the final form against the repository.
template <class DType>
void quantizelinear_3()
void quantizelinear_fp8e4m3()
{
migraphx::shape xs{migraphx::shape::float_type, {2, 2, 2}};
migraphx::shape zs{migraphx::shape::get_type<DType>{}, {2, 2, 2}};
std::vector<float> xv = {0.5, 0.75, -0.4375, 0.6875, -0.9375, -0.9375, 0.625, -0.5625};
std::vector<float> sv = {0.25, 0.75, 0.5625, 0.4375, 0.8125, -0.6875, 0.875, -0.0625};
std::vector<float> zv{0.6875, 0.75, -0.75, 0.5, -0.0625, 0.0625, -0.375, 0.25};
// Zero points are converted to DType up front so the literal and the gold
// computation below both use the already-converted values.
std::vector<float> tmp = {0.6875, 0.75, -0.75, 0.5, -0.0625, 0.0625, -0.375, 0.25};
std::vector<DType> zero_pts;
std::transform(
tmp.begin(), tmp.end(), std::back_inserter(zero_pts), [](auto x) { return DType(x); });
// Builds the program under test: three literals feeding one quantizelinear op.
auto create_program = [&]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_literal(xs, xv);
auto s = mm->add_literal(xs, sv);
auto z = mm->add_literal(zs, zv);
auto z = mm->add_literal(zs, zero_pts);
mm->add_instruction(migraphx::make_op("quantizelinear"), x, s, z);
return p;
};

migraphx::program p1 = create_program();
p1.compile(migraphx::make_target("ref"));
auto result = p1.eval({}).back();
std::vector<float> results_vector(8);
std::vector<DType> results_vector(8);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.75, 1.75, -1.75, 2.5, -1, 1, 0.625, 9};
// Gold values are computed without any rounding step: x / scale, clamped to
// DType's [lowest, max] range, plus the zero point, then converted to DType.
std::vector<DType> gold;
auto min_value = std::numeric_limits<DType>::lowest();
auto max_value = std::numeric_limits<DType>::max();
// NOTE(review): `i` is int while xv.size() is unsigned -- signed/unsigned comparison.
for(int i = 0; i < xv.size(); ++i)
{
double quantized = xv.at(i) / sv.at(i);
// NOTE(review): the clamp is applied here before the zero point is added on the
// next line -- confirm this ordering matches the operator being tested.
quantized = std::max(static_cast<double>(min_value),
std::min(static_cast<double>(max_value), quantized));
gold.push_back(DType(quantized + zero_pts.at(i)));
}
EXPECT(results_vector == gold);
}
TEST_CASE_REGISTER(quantizelinear_3<migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(quantizelinear_3<migraphx::fp8::fp8e4m3fn>);
TEST_CASE_REGISTER(quantizelinear_fp8e4m3<migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(quantizelinear_fp8e4m3<migraphx::fp8::fp8e4m3fn>);

// Reference-target test of quantizelinear producing the fp8 type DType: a program with
// a single quantizelinear instruction over literal inputs is compiled for the "ref"
// target, evaluated, and its output compared against a gold vector.
// NOTE(review): this span is a rendered diff, so superseded and current lines appear
// side by side (two function names, two `z`/`results_vector`/`gold` definitions) --
// confirm the final form against the repository.
template <class DType>
void quantizelinear_4()
void quantizelinear_fp8e5m2()
{
migraphx::shape xs{migraphx::shape::float_type, {2, 2, 2}};
migraphx::shape zs{migraphx::shape::get_type<DType>{}, {2, 2, 2}};
std::vector<float> xv = {0.5, 0.75, -0.4375, 0.625, -0.875, -0.875, 0.625, -0.5};
std::vector<float> sv = {0.25, 0.75, 0.5, 0.4375, 0.875, -0.625, 0.875, -0.0625};
std::vector<float> zv{0.625, 0.75, -0.75, 0.5, -0.0625, 0.0625, -0.375, 0.25};
// Zero points are converted to DType up front so the literal and the gold
// computation below both use the already-converted values.
std::vector<float> tmp = {0.6875, 0.75, -0.75, 0.5, -0.0625, 0.0625, -0.375, 0.25};
std::vector<DType> zero_pts;
std::transform(
tmp.begin(), tmp.end(), std::back_inserter(zero_pts), [](auto x) { return DType(x); });
// Builds the program under test: three literals feeding one quantizelinear op.
auto create_program = [&]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_literal(xs, xv);
auto s = mm->add_literal(xs, sv);
auto z = mm->add_literal(zs, zv);
auto z = mm->add_literal(zs, zero_pts);
mm->add_instruction(migraphx::make_op("quantizelinear"), x, s, z);
return p;
};

migraphx::program p1 = create_program();
p1.compile(migraphx::make_target("ref"));
auto result = p1.eval({}).back();
std::vector<float> results_vector(8);
std::vector<DType> results_vector(8);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5, 1.75, -1.75, 1.5, -1, 1, 0.625, 8};
// Gold values are computed without any rounding step: x / scale, clamped to
// DType's [lowest, max] range, plus the zero point, then converted to DType.
std::vector<DType> gold;
auto min_value = std::numeric_limits<DType>::lowest();
auto max_value = std::numeric_limits<DType>::max();
// NOTE(review): `i` is int while xv.size() is unsigned -- signed/unsigned comparison.
for(int i = 0; i < xv.size(); ++i)
{
double quantized = xv.at(i) / sv.at(i);
// NOTE(review): the clamp is applied here before the zero point is added on the
// next line -- confirm this ordering matches the operator being tested.
quantized = std::max(static_cast<double>(min_value),
std::min(static_cast<double>(max_value), quantized));
gold.push_back(DType(quantized + zero_pts.at(i)));
}
EXPECT(results_vector == gold);
}
TEST_CASE_REGISTER(quantizelinear_4<migraphx::fp8::fp8e5m2fnuz>);
TEST_CASE_REGISTER(quantizelinear_4<migraphx::fp8::fp8e5m2>);
TEST_CASE_REGISTER(quantizelinear_fp8e5m2<migraphx::fp8::fp8e5m2fnuz>);
TEST_CASE_REGISTER(quantizelinear_fp8e5m2<migraphx::fp8::fp8e5m2>);
116 changes: 115 additions & 1 deletion test/rewrite_quantization_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -120,4 +120,118 @@ TEST_CASE(dequantizelinear)
EXPECT(none_of(*p2.get_main_module(), &is_dequantizelinear));
}

// Quantizing to an integral type (int8): after run_pass (defined elsewhere in this
// file) is applied, the program must equal the hand-built expansion
//   div -> nearbyint -> add(converted zero point) -> clip(min, max) -> convert(int8),
// i.e. the expansion must contain a nearbyint instruction.
TEST_CASE(quantize_to_integral_type)
{
    migraphx::shape xs{migraphx::shape::float_type, {2, 3, 3}};
    std::vector<float> xv = {
        -300, 600, 129, -1000, 4, 3, -6, 600, 550, -300, 600, 129, -1000, 4, 3, -6, 600, 550};
    migraphx::shape ss{migraphx::shape::float_type, {2, 3, 3}};
    std::vector<float> sv = {2, 2, 2, 4, 4, 4, 6, 6, 6, 2, 2, 2, 4, 4, 4, 6, 6, 6};
    migraphx::shape zs{migraphx::shape::int8_type, {2, 3, 3}};
    // Element type matches the int8 zero-point shape declared on the line above.
    std::vector<int8_t> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

    // Program handed to the pass: a single quantizelinear over three literals.
    migraphx::program p_run;
    {
        auto* mm = p_run.get_main_module();
        auto x   = mm->add_literal(xs, xv);
        auto s   = mm->add_literal(ss, sv);
        auto z   = mm->add_literal(zs, zv);
        mm->add_instruction(migraphx::make_op("quantizelinear"), x, s, z);
    }

    // Hand-written expansion the pass is expected to produce.
    migraphx::program p_expected;
    {
        auto* mm        = p_expected.get_main_module();
        auto x          = mm->add_literal(xs, xv);
        auto s          = mm->add_literal(ss, sv);
        auto z          = mm->add_literal(zs, zv);
        auto div        = mm->add_instruction(migraphx::make_op("div"), x, s);
        auto nearby_int = mm->add_instruction(migraphx::make_op("nearbyint"), div);
        auto zero_point = mm->add_instruction(
            migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), z);
        auto add_zero_point = mm->add_instruction(migraphx::make_op("add"), nearby_int, zero_point);

        // Saturation bounds are read from the quantized (zero-point) type.
        double max_quant = 0;
        double min_quant = 0;
        auto zp_shape    = add_zero_point->get_shape();
        zs.visit_type([&](auto qt) {
            max_quant = qt.max();
            min_quant = qt.min();
        });
        auto min_arg =
            mm->add_literal(migraphx::literal{migraphx::shape{zp_shape.type()}, {min_quant}});
        auto max_arg =
            mm->add_literal(migraphx::literal{migraphx::shape{zp_shape.type()}, {max_quant}});
        min_arg = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", zp_shape.lens()}}), min_arg);
        max_arg = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", zp_shape.lens()}}), max_arg);
        auto saturate =
            mm->add_instruction(migraphx::make_op("clip"), {add_zero_point, min_arg, max_arg});
        mm->add_instruction(
            migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), saturate);
    }

    run_pass(*p_run.get_main_module());
    EXPECT(p_run == p_expected);
}

// Quantizing to a floating-point type (fp8e4m3fn): after run_pass (defined elsewhere
// in this file) is applied, the program must equal the hand-built expansion
//   div -> add(converted zero point) -> clip(min, max) -> convert(fp8e4m3fn),
// which contains no nearbyint instruction.
TEST_CASE(quantize_to_floating_point_type)
{
    using fp8_type = migraphx::fp8::fp8e4m3fn;

    migraphx::shape float_shape{migraphx::shape::float_type, {2, 2, 2}};
    migraphx::shape quant_shape{migraphx::shape::get_type<fp8_type>{}, {2, 2, 2}};
    std::vector<float> data_values  = {0.5, 0.75, -0.4375, 0.6875, -0.9375, -0.9375, 0.625, -0.5625};
    std::vector<float> scale_values = {0.25, 0.75, 0.5625, 0.4375, 0.8125, -0.6875, 0.875, -0.0625};
    std::vector<float> zp_values    = {0.6875, 0.75, -0.75, 0.5, -0.0625, 0.0625, -0.375, 0.25};

    // Convert the float zero points to fp8 one element at a time.
    std::vector<fp8_type> zero_points;
    for(float value : zp_values)
    {
        zero_points.push_back(fp8_type(value));
    }

    // Program handed to the pass: a single quantizelinear over three literals.
    migraphx::program actual;
    {
        auto* mod  = actual.get_main_module();
        auto data  = mod->add_literal(float_shape, data_values);
        auto scale = mod->add_literal(float_shape, scale_values);
        auto zp    = mod->add_literal(quant_shape, zero_points);
        mod->add_instruction(migraphx::make_op("quantizelinear"), data, scale, zp);
    }

    // Hand-written expansion the pass is expected to produce.
    migraphx::program expected;
    {
        auto* mod     = expected.get_main_module();
        auto data     = mod->add_literal(float_shape, data_values);
        auto scale    = mod->add_literal(float_shape, scale_values);
        auto zp       = mod->add_literal(quant_shape, zero_points);
        auto quotient = mod->add_instruction(migraphx::make_op("div"), data, scale);
        auto zp_float = mod->add_instruction(
            migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), zp);
        auto shifted = mod->add_instruction(migraphx::make_op("add"), quotient, zp_float);

        // Saturation bounds are read from the quantized (zero-point) type.
        double upper_bound = 0;
        double lower_bound = 0;
        auto shifted_shape = shifted->get_shape();
        quant_shape.visit_type([&](auto qt) {
            upper_bound = qt.max();
            lower_bound = qt.min();
        });

        auto lower_lit = mod->add_literal(
            migraphx::literal{migraphx::shape{shifted_shape.type()}, {lower_bound}});
        auto upper_lit = mod->add_literal(
            migraphx::literal{migraphx::shape{shifted_shape.type()}, {upper_bound}});
        auto lower_bcast = mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", shifted_shape.lens()}}), lower_lit);
        auto upper_bcast = mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", shifted_shape.lens()}}), upper_lit);
        auto clipped =
            mod->add_instruction(migraphx::make_op("clip"), {shifted, lower_bcast, upper_bcast});
        mod->add_instruction(
            migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fn_type}}),
            clipped);
    }

    run_pass(*actual.get_main_module());
    EXPECT(actual == expected);
}

// Test-binary entry point: hands the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}

0 comments on commit c8b73b1

Please sign in to comment.