From 11ecba8e282a3d34f1883819d2ed08010bba6036 Mon Sep 17 00:00:00 2001 From: ndgrigorian <46709016+ndgrigorian@users.noreply.github.com> Date: Wed, 1 Nov 2023 10:28:39 -0700 Subject: [PATCH] Fix search reductions giving incorrect results for F-contiguous inputs (#1462) * Fixes correctness regression in search functions ``py_search_over_axis`` no longer calls the ``axis1`` contiguous variant ``py_search_over_axis`` now only calls ``axis0`` variant wh * Adds tests for fixed search reduction behavior --- .../source/reductions/reduction_over_axis.hpp | 13 +++---------- dpctl/tests/test_usm_ndarray_reductions.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp index aa46f1c02a..f1b924dd47 100644 --- a/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp +++ b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp @@ -874,14 +874,11 @@ std::pair py_search_over_axis( int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); // handle special case when both reduction and iteration are 1D contiguous - // and can be done with atomics bool is_src_c_contig = src.is_c_contiguous(); bool is_dst_c_contig = dst.is_c_contiguous(); bool is_src_f_contig = src.is_f_contiguous(); - if ((is_src_c_contig && is_dst_c_contig) || - (is_src_f_contig && dst_nelems == 1)) - { + if (is_src_c_contig && is_dst_c_contig) { auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid]; if (fn != nullptr) { size_t iter_nelems = dst_nelems; @@ -903,9 +900,7 @@ std::pair py_search_over_axis( reduction_over_axis_contig_ev); } } - else if (is_src_f_contig && - ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) - { + else if (is_src_f_contig && dst_nd == 1) { auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid]; if (fn != nullptr) { size_t iter_nelems = dst_nelems; @@ -983,11 +978,9 @@ std::pair py_search_over_axis( if ((reduction_nd == 1) && (iteration_nd == 1)) { bool mat_reduce_over_axis1 = false; bool mat_reduce_over_axis0 = false; - bool array_reduce_all_elems = false; size_t iter_nelems = dst_nelems; if (compact_reduction_src_strides[0] == 1) { - array_reduce_all_elems = (simplified_iteration_shape[0] == 1); mat_reduce_over_axis1 = (simplified_iteration_dst_strides[0] == 1) && (static_cast(simplified_iteration_src_strides[0]) == @@ -1000,7 +993,7 @@ std::pair py_search_over_axis( (simplified_iteration_src_strides[0] == 1); } - if (mat_reduce_over_axis1 || array_reduce_all_elems) { + if (mat_reduce_over_axis1) { auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid]; if (fn != nullptr) { sycl::event reduction_over_axis1_contig_ev = diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index 45afb26aac..cbfd6baec6 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -265,6 +265,22 @@ def test_argmax_argmin_identities(): assert dpt.argmin(x) == 0 +@pytest.mark.parametrize("order", ["C", "F"]) +def test_argmax_axis0_axis1(order): + get_queue_or_skip() + + x = dpt.asarray([[1, 2, 3], [6, 5, 4]], dtype="i4", order=order) + assert dpt.argmax(x) == 3 + + res = dpt.argmax(x, axis=0) + expected = dpt.asarray([1, 1, 1], dtype=res.dtype) + assert dpt.all(res == expected) + + res = dpt.argmax(x, axis=1) + expected = dpt.asarray([2, 0], dtype=res.dtype) + assert dpt.all(res == expected) + + def test_reduction_arg_validation(): get_queue_or_skip()