Fix NaN propagation for float16 min and max operators #22161

Merged · 7 commits · Sep 24, 2024
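PR summary: the float16 (`MLFloat16`/`BFloat16`) Min and Max kernels previously compared values without regard to NaN, so a NaN operand could be silently dropped. This change aligns the half-precision paths with the NaN-propagating behavior the `float`/`double` specializations already have (visible as `isnan(...)` context lines in the CUDA diff below). A minimal standalone sketch of the target semantics — illustrative only, not ONNX Runtime code:

```cpp
// NaN-propagating minimum: if either operand is NaN, the result is NaN.
// Note that std::fmin/std::fmax instead *ignore* a single NaN operand.
#include <cmath>
#include <cstdio>
#include <limits>

float min_propagate_nan(float a, float b) {
  return (std::isnan(a) || std::isnan(b)) ? std::numeric_limits<float>::quiet_NaN()
                                          : (a < b ? a : b);
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  std::printf("%f\n", min_propagate_nan(1.0f, 2.0f));  // 1.000000
  std::printf("%f\n", min_propagate_nan(nan, 2.0f));   // nan (std::fmin would give 2.0)
  return 0;
}
```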
16 changes: 10 additions & 6 deletions onnxruntime/core/providers/cpu/math/element_wise_ops.cc
@@ -757,9 +757,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
       EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);
 
       if (is_min) {
-        output_vec_map = input_1_vec_map.min(static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
+        output_vec_map = input_1_vec_map.template min<Eigen::PropagateNaN>(
+            static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
       } else {
-        output_vec_map = input_1_vec_map.max(static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
+        output_vec_map = input_1_vec_map.template max<Eigen::PropagateNaN>(
+            static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
       }
     },
     [](BroadcastHelper& per_iter_bh) {
@@ -772,9 +774,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
       EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);
 
       if (is_min) {
-        output_vec_map = input_0_vec_map.min(static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
+        output_vec_map = input_0_vec_map.template min<Eigen::PropagateNaN>(
+            static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
       } else {
-        output_vec_map = input_0_vec_map.max(static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
+        output_vec_map = input_0_vec_map.template max<Eigen::PropagateNaN>(
+            static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
       }
     },
     [](BroadcastHelper& per_iter_bh) {
@@ -790,9 +794,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
       EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);
 
       if (is_min) {
-        output_vec_map = input_0_vec_map.min(input_1_vec_map);
+        output_vec_map = input_0_vec_map.template min<Eigen::PropagateNaN>(input_1_vec_map);
       } else {
-        output_vec_map = input_0_vec_map.max(input_1_vec_map);
+        output_vec_map = input_0_vec_map.template max<Eigen::PropagateNaN>(input_1_vec_map);
       }
     }};
 
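The CPU fix above switches Eigen's coefficient-wise `min`/`max` to their NaN-propagating overloads. A minimal sketch of the difference, assuming Eigen 3.4+ (which introduced the `Eigen::PropagateNaN`/`Eigen::PropagateNumbers` policies); the `.template` qualifier in the diff is only needed because the kernel lambdas sit in a template-dependent context:

```cpp
#include <Eigen/Core>
#include <iostream>
#include <limits>

int main() {
  Eigen::ArrayXf a(3), b(3);
  const float nan = std::numeric_limits<float>::quiet_NaN();
  a << 1.0f, nan, 3.0f;
  b << 2.0f, 2.0f, nan;

  // Default policy (PropagateFast): the result for NaN operands is unspecified.
  std::cout << a.min(b).transpose() << "\n";
  // PropagateNaN: any NaN operand yields NaN, which is what the kernels now use.
  std::cout << a.min<Eigen::PropagateNaN>(b).transpose() << "\n";  // 1 nan nan
  return 0;
}
```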
61 changes: 57 additions & 4 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -10,13 +10,10 @@
 #include <math.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/shared_inc/cuda_call.h"
 
-#if CUDA_VERSION >= 11000
-#include <cuda_bf16.h>
-#endif
-
 namespace onnxruntime {
 namespace cuda {

[cpplint, common.cuh:13 — Found C system header after C++ system header. Should be: common.h, c system, c++ system, other. (build/include_order)]

Expand Down Expand Up @@ -347,6 +344,21 @@
 template <>
 __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); }
 
+#define ISNAN_HALF(v__) static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&v__) & ~MLFloat16::kSignMask) \
+  > MLFloat16::kPositiveInfinityBits
+
+#define ISNAN_BFLOAT16(v__) static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&v__) & ~BFloat16::kSignMask) \
+  > BFloat16::kPositiveInfinityBits
+
+// CUDART_NAN_BF16 and CUDART_NAN_FP16 constants were only added in CUDA 12.2,
+// so define our own equivalent constants to support older versions.
+// Note that there is no consistent canonical NaN for FP16 and BF16;
+// CUDA uses 0x7FFF for both, but ONNX Runtime uses 0x7E00 and 0x7FC1
+// for FP16 and BF16 respectively
+// (see Float16Impl::kPositiveQNaNBits and BFloat16Impl::kPositiveQNaNBits).
+#define NAN_HALF __ushort_as_half((unsigned short)0x7FFFU)
+#define NAN_BFLOAT16 BFloat16::FromBits((uint16_t)0x7FFFU)
+
 template <typename T>
 __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; }

[cpplint, common.cuh:359 — Use int16/int64/etc, rather than the C type short. (runtime/int)]

@@ -360,6 +372,24 @@
   return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : (a < b ? a : b);
 }
 
+template <>
+__device__ __inline__ half _Min(half a, half b) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
+  return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a < b ? a : b);
+#else
+  return __hmin_nan(a, b);
+#endif
+}
+
+template <>
+__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
+  return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a < b ? a : b);
+#else
+  return BFloat16(__hmin_nan((__nv_bfloat16)a, (__nv_bfloat16)b));
+#endif
+}
+
 template <typename T>
 __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }

[cpplint, common.cuh:377 and :386 — Lines should be <= 120 characters long. (whitespace/line_length)]

@@ -373,6 +403,29 @@
   return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : (a > b ? a : b);
 }
 
+template <>
+__device__ __inline__ half _Max(half a, half b) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
+  return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a > b ? a : b);
+#else
+  return __hmax_nan(a, b);
+#endif
+}
+
+template <>
+__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
+  return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a > b ? a : b);
+#else
+  return BFloat16(__hmax_nan((__nv_bfloat16)a, (__nv_bfloat16)b));
+#endif
+}
+
+#undef ISNAN_HALF
+#undef ISNAN_BFLOAT16
+#undef NAN_HALF
+#undef NAN_BFLOAT16
+
 template <typename T>
 __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }

[cpplint, common.cuh:408 and :417 — Lines should be <= 120 characters long. (whitespace/line_length)]

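Where the NaN-propagating `__hmin_nan`/`__hmax_nan` intrinsics cannot be used (per the `#if`: compute capability below 8.0 combined with a pre-12.2 toolkit), the fallback classifies NaN straight from the bit pattern: clear the sign bit, then check whether the remaining bits exceed the positive-infinity encoding (exponent all ones, non-zero mantissa). A host-side sketch of the same test using the standard IEEE binary16 and bfloat16 encodings (the constant names here are illustrative, not the ORT ones):

```cpp
#include <cstdint>
#include <cstdio>

// fp16: 1 sign bit, 5 exponent bits, 10 mantissa bits -> +inf is 0x7C00.
constexpr uint16_t kFp16SignMask = 0x8000;
constexpr uint16_t kFp16PosInfBits = 0x7C00;
// bf16: 1 sign bit, 8 exponent bits, 7 mantissa bits -> +inf is 0x7F80.
constexpr uint16_t kBf16SignMask = 0x8000;
constexpr uint16_t kBf16PosInfBits = 0x7F80;

bool IsNaNFp16(uint16_t bits) {
  // NaN iff the magnitude bits exceed the infinity pattern.
  return static_cast<uint16_t>(bits & ~kFp16SignMask) > kFp16PosInfBits;
}

bool IsNaNBf16(uint16_t bits) {
  return static_cast<uint16_t>(bits & ~kBf16SignMask) > kBf16PosInfBits;
}

int main() {
  std::printf("%d\n", IsNaNFp16(0x7E00));  // 1: canonical fp16 quiet NaN
  std::printf("%d\n", IsNaNFp16(0xFC00));  // 0: -inf is not NaN
  std::printf("%d\n", IsNaNBf16(0x7FC1));  // 1: a bf16 quiet NaN
  return 0;
}
```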
209 changes: 124 additions & 85 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -1787,54 +1787,90 @@ TEST(MathOpTest, Min_12_MLFloat16_Scalar1) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Min_12_MLFloat16_MatrixVector) {
-  OpTester test("Min", 12);
-  test.AddInput<MLFloat16>("data_0", {3, 3},
-                           MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                          -0.5f, 0.0f, -2.0f,
-                                          0.5f, 0.0f, 2.0f}));
-  test.AddInput<MLFloat16>("data_1", {3, 1},
-                           MakeMLFloat16({0.0f, -1.0f, 1.0f}));
-  test.AddOutput<MLFloat16>("min", {3, 3},
-                            MakeMLFloat16({0.0f, 0.0f, 0.0f,
-                                           -1.0f, -1.0f, -2.0f,
-                                           0.5f, 0.0f, 1.0f}));
-  if (nullptr != DefaultCpuExecutionProvider()) {
+void TestFloat16MinMax(
+    const char* op_name,
+    const std::vector<int64_t>& lhs_dim,
+    const std::initializer_list<float>& lhs_values,
+    const std::vector<int64_t>& rhs_dim,
+    const std::initializer_list<float>& rhs_values,
+    const std::vector<int64_t>& out_dim,
+    const std::initializer_list<float>& out_values) {
+  {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCpuExecutionProvider());
+    if (nullptr != DefaultCpuExecutionProvider()) {
+      execution_providers.push_back(DefaultCpuExecutionProvider());
+    }
+    if (nullptr != DefaultCudaExecutionProvider()) {
+      execution_providers.push_back(DefaultCudaExecutionProvider());
+    }
+    OpTester test(op_name, 13);
+    test.AddInput<MLFloat16>("data_0", lhs_dim, MakeMLFloat16(lhs_values));
+    test.AddInput<MLFloat16>("data_1", rhs_dim, MakeMLFloat16(rhs_values));
+    test.AddOutput<MLFloat16>("output", out_dim, MakeMLFloat16(out_values));
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
   if (nullptr != DefaultCudaExecutionProvider()) {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultCudaExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-}
-
-TEST(MathOpTest, Min_12_MLFloat16_VectorMatrix) {
-  OpTester test("Min", 12);
-  test.AddInput<MLFloat16>("data_0", {3, 1},
-                           MakeMLFloat16({0.0f, -1.0f, 1.0f}));
-  test.AddInput<MLFloat16>("data_1", {3, 4},
-                           MakeMLFloat16({1.0f, 1.0f, 1.0f, -1.0f,
-                                          -0.5f, 0.0f, -2.0f, -1.25f,
-                                          0.5f, 0.0f, 2.0f, 1.5f}));
-  test.AddOutput<MLFloat16>("min", {3, 4},
-                            MakeMLFloat16({0.0f, 0.0f, 0.0f, -1.0f,
-                                           -1.0f, -1.0f, -2.0f, -1.25f,
-                                           0.5f, 0.0f, 1.0f, 1.0f}));
-  if (nullptr != DefaultCpuExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCpuExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-  if (nullptr != DefaultCudaExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCudaExecutionProvider());
+    OpTester test(op_name, 13);
+    test.AddInput<BFloat16>("data_0", lhs_dim, MakeBFloat16(lhs_values));
+    test.AddInput<BFloat16>("data_1", rhs_dim, MakeBFloat16(rhs_values));
+    test.AddOutput<BFloat16>("output", out_dim, MakeBFloat16(out_values));
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   }
 }
 
+TEST(MathOpTest, Min_13_Float16_MatrixVector) {
+  TestFloat16MinMax("Min",
+                    {3, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -2.0f,
+                     0.5f, 0.0f, 2.0f},
+                    {3, 1}, {0.0f, -1.0f, 1.0f},
+                    {3, 3},
+                    {0.0f, 0.0f, 0.0f,
+                     -1.0f, -1.0f, -2.0f,
+                     0.5f, 0.0f, 1.0f});
+}
+
+TEST(MathOpTest, Min_13_Float16_VectorMatrix) {
+  TestFloat16MinMax("Min",
+                    {3, 1}, {0.0f, -1.0f, 1.0f},
+                    {3, 4},
+                    {1.0f, 1.0f, 1.0f, -1.0f,
+                     -0.5f, 0.0f, -2.0f, -1.25f,
+                     0.5f, 0.0f, 2.0f, 1.5f},
+                    {3, 4},
+                    {0.0f, 0.0f, 0.0f, -1.0f,
+                     -1.0f, -1.0f, -2.0f, -1.25f,
+                     0.5f, 0.0f, 1.0f, 1.0f});
+}
+
+TEST(MathOpTest, Min_13_Float16_Nan) {
+  TestFloat16MinMax("Min",
+                    {4, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f, 0.5f},
+                    {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits<float>::quiet_NaN()},
+                    {4, 1},
+                    {-1.0f, std::numeric_limits<float>::quiet_NaN(), 0.25f, std::numeric_limits<float>::quiet_NaN()});
+}
+
+TEST(MathOpTest, Min_13_Float16_Nan_with_scalar) {
+  TestFloat16MinMax("Min",
+                    {3, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f},
+                    {1}, {0.25f},
+                    {3, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 0.25f});
+}
+
+TEST(MathOpTest, Min_13_Float16_with_scalar_Nan) {
+  TestFloat16MinMax("Min",
+                    {3, 1}, {-0.5f, 1.0f, 1.5f},
+                    {1}, {std::numeric_limits<float>::quiet_NaN()},
+                    {3, 1},
+                    {std::numeric_limits<float>::quiet_NaN(),
+                     std::numeric_limits<float>::quiet_NaN(),
+                     std::numeric_limits<float>::quiet_NaN()});
+}
+
 TEST(MathOpTest, Max_6) {
   OpTester test("Max", 6);
   std::vector<int64_t> dims{3, 3};
@@ -2185,54 +2221,57 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
-TEST(MathOpTest, Max_12_MLFloat16_MatrixVector) {
-  OpTester test("Max", 12);
-  test.AddInput<MLFloat16>("data_0", {4, 3},
-                           MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                          -0.5f, 0.0f, -2.0f,
-                                          0.0f, 0.5f, 0.75f,
-                                          0.5f, 0.0f, 2.0f}));
-  test.AddInput<MLFloat16>("data_1", {4, 1},
-                           MakeMLFloat16({0.0f, -1.0f, 0.5f, 1.0f}));
-  test.AddOutput<MLFloat16>("max", {4, 3},
-                            MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                           -0.5f, 0.0f, -1.0f,
-                                           0.5f, 0.5f, 0.75f,
-                                           1.0f, 1.0f, 2.0f}));
-  if (nullptr != DefaultCpuExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCpuExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-  if (nullptr != DefaultCudaExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCudaExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-}
-
-TEST(MathOpTest, Max_12_MLFloat16_VectorMatrix) {
-  OpTester test("Max", 12);
-  test.AddInput<MLFloat16>("data_0", {3, 1},
-                           MakeMLFloat16({0.0f, -1.0f, 1.0f}));
-  test.AddInput<MLFloat16>("data_1", {3, 3},
-                           MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                          -0.5f, 0.0f, -2.0f,
-                                          0.5f, 0.0f, 2.0f}));
-  test.AddOutput<MLFloat16>("max", {3, 3},
-                            MakeMLFloat16({1.0f, 1.0f, 1.0f,
-                                           -0.5f, 0.0f, -1.0f,
-                                           1.0f, 1.0f, 2.0f}));
-  if (nullptr != DefaultCpuExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCpuExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
-  if (nullptr != DefaultCudaExecutionProvider()) {
-    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    execution_providers.push_back(DefaultCudaExecutionProvider());
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-  }
+TEST(MathOpTest, Max_13_Float16_MatrixVector) {
+  TestFloat16MinMax("Max",
+                    {4, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -2.0f,
+                     0.0f, 0.5f, 0.75f,
+                     0.5f, 0.0f, 2.0f},
+                    {4, 1}, {0.0f, -1.0f, 0.5f, 1.0f},
+                    {4, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -1.0f,
+                     0.5f, 0.5f, 0.75f,
+                     1.0f, 1.0f, 2.0f});
+}
+
+TEST(MathOpTest, Max_13_Float16_VectorMatrix) {
+  TestFloat16MinMax("Max",
+                    {3, 1}, {0.0f, -1.0f, 1.0f},
+                    {3, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -2.0f,
+                     0.5f, 0.0f, 2.0f},
+                    {3, 3},
+                    {1.0f, 1.0f, 1.0f,
+                     -0.5f, 0.0f, -1.0f,
+                     1.0f, 1.0f, 2.0f});
+}
+
+TEST(MathOpTest, Max_13_Float16_Nan) {
+  TestFloat16MinMax("Max",
+                    {4, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f, 0.5f},
+                    {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits<float>::quiet_NaN()},
+                    {4, 1},
+                    {0.5f, std::numeric_limits<float>::quiet_NaN(), 1.0f, std::numeric_limits<float>::quiet_NaN()});
+}
+
+TEST(MathOpTest, Max_13_Float16_Nan_with_scalar) {
+  TestFloat16MinMax("Max",
+                    {3, 1}, {-1.0f, std::numeric_limits<float>::quiet_NaN(), 1.0f},
+                    {1}, {0.25f},
+                    {3, 1}, {0.25f, std::numeric_limits<float>::quiet_NaN(), 1.0f});
+}
+
+TEST(MathOpTest, Max_13_Float16_with_scalar_Nan) {
+  TestFloat16MinMax("Max",
+                    {3, 1}, {-0.5f, 1.0f, 1.5f},
+                    {1}, {std::numeric_limits<float>::quiet_NaN()},
+                    {3, 1},
+                    {std::numeric_limits<float>::quiet_NaN(),
+                     std::numeric_limits<float>::quiet_NaN(),
+                     std::numeric_limits<float>::quiet_NaN()});
 }
 
 TEST(MathOpTest, Not) {
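As a cross-check of the expectations encoded in `Max_13_Float16_Nan` above: applying the NaN-propagating rule elementwise to the two input columns reproduces the expected output column. A standalone sketch with plain `float` standing in for `MLFloat16`:

```cpp
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Reference semantics: NaN in either operand yields NaN, otherwise the larger value.
float max_propagate_nan(float a, float b) {
  return (std::isnan(a) || std::isnan(b)) ? std::numeric_limits<float>::quiet_NaN()
                                          : (a > b ? a : b);
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const std::vector<float> lhs = {-1.0f, nan, 1.0f, 0.5f};
  const std::vector<float> rhs = {0.5f, 1.0f, 0.25f, nan};
  for (size_t i = 0; i < lhs.size(); ++i) {
    std::printf("%f\n", max_propagate_nan(lhs[i], rhs[i]));  // 0.5, nan, 1.0, nan
  }
  return 0;
}
```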