Skip to content

Commit

Permalink
Search reductions use correct branch for float16
Browse files Browse the repository at this point in the history
constexpr branch logic accounted for floating point types but not sycl::half,
which meant NaNs were not propagating for float16 data
  • Loading branch information
ndgrigorian committed Nov 3, 2023
1 parent 5709f99 commit 119d43d
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions dpctl/tensor/libtensor/include/kernels/reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3476,7 +3476,9 @@ struct SequentialSearchReduction
idx_val = static_cast<outT>(m);
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val < red_val || std::isnan(val)) {
red_val = val;
idx_val = static_cast<outT>(m);
Expand All @@ -3501,7 +3503,9 @@ struct SequentialSearchReduction
idx_val = static_cast<outT>(m);
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val > red_val || std::isnan(val)) {
red_val = val;
idx_val = static_cast<outT>(m);
Expand Down Expand Up @@ -3789,7 +3793,9 @@ struct CustomSearchReduction
}
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val < local_red_val || std::isnan(val)) {
local_red_val = val;
if constexpr (!First) {
Expand Down Expand Up @@ -3833,7 +3839,9 @@ struct CustomSearchReduction
}
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val > local_red_val || std::isnan(val)) {
local_red_val = val;
if constexpr (!First) {
Expand Down Expand Up @@ -3876,7 +3884,9 @@ struct CustomSearchReduction
? local_idx
: idx_identity_;
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
// equality does not hold for NaNs, so check here
local_idx =
(red_val_over_wg == local_red_val || std::isnan(local_red_val))
Expand Down

0 comments on commit 119d43d

Please sign in to comment.