diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index b9e2918c8c..40e5bd282d 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -50,12 +50,18 @@ namespace tensor namespace kernels { +template struct needs_workaround +{ + static constexpr bool value = + std::is_same_v> && + (std::is_same_v || std::is_same_v); +}; + template struct can_use_reduce_over_group { static constexpr bool value = sycl::has_known_identity::value && - !std::is_same_v && !std::is_same_v && - !std::is_same_v>; + !needs_workaround::value; }; template ()); + max_max_wg, d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -1339,7 +1345,7 @@ sycl::event reduction_over_group_temps_strided_impl( static_cast(remaining_reduction_nelems)}; ResIndexerT res_iter_indexer{iter_nd, iter_res_offset, /* shape */ iter_shape_and_strides, - /*s trides */ iter_shape_and_strides + + /* strides */ iter_shape_and_strides + 2 * iter_nd}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -1424,8 +1430,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( py::ssize_t reduction_arg_offset, const std::vector &depends) { - const argTy *arg_tp = reinterpret_cast(arg_cp); - resTy *res_tp = reinterpret_cast(res_cp); + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; constexpr resTy identity_val = su_ns::Identity::value; @@ -1437,7 +1444,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( // max_max_wg prevents running out of resources on CPU constexpr size_t max_max_wg = 2048; size_t max_wg = std::min( - max_max_wg, d.get_info()); + max_max_wg, d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -1767,8 +1774,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( py::ssize_t reduction_arg_offset, const std::vector &depends) { - const argTy *arg_tp = reinterpret_cast(arg_cp); - resTy *res_tp = reinterpret_cast(res_cp); + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; constexpr resTy identity_val = su_ns::Identity::value; @@ -1780,7 +1788,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( // max_max_wg prevents running out of resources on CPU constexpr size_t max_max_wg = 2048; size_t max_wg = std::min( - max_max_wg, d.get_info()); + max_max_wg, d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -3875,8 +3883,9 @@ sycl::event search_over_group_temps_strided_impl( constexpr size_t preferrered_reductions_per_wi = 4; // max_max_wg prevents running out of resources on CPU - size_t max_wg = std::min( - size_t(2048), d.get_info()); + size_t max_wg = + std::min(size_t(2048), + d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -4258,8 +4267,9 @@ sycl::event search_axis1_over_group_temps_contig_impl( py::ssize_t reduction_arg_offset, const std::vector &depends) { - const argTy *arg_tp = reinterpret_cast(arg_cp); - resTy *res_tp = reinterpret_cast(res_cp); + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; constexpr argTy identity_val = su_ns::Identity::value; constexpr resTy idx_identity_val = su_ns::Identity::value; @@ -4270,8 +4280,9 @@ sycl::event search_axis1_over_group_temps_contig_impl( constexpr size_t preferrered_reductions_per_wi = 8; // max_max_wg prevents running out of resources on CPU - size_t max_wg = std::min( - size_t(2048), d.get_info()); + size_t max_wg = + std::min(size_t(2048), + d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -4635,8 +4646,9 @@ sycl::event search_axis0_over_group_temps_contig_impl( py::ssize_t reduction_arg_offset, const std::vector &depends) { - const argTy *arg_tp = reinterpret_cast(arg_cp); - resTy *res_tp = reinterpret_cast(res_cp); + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; constexpr argTy identity_val = su_ns::Identity::value; constexpr resTy idx_identity_val = su_ns::Identity::value; @@ -4647,8 +4659,9 @@ sycl::event search_axis0_over_group_temps_contig_impl( constexpr size_t preferrered_reductions_per_wi = 8; // max_max_wg prevents running out of resources on CPU - size_t max_wg = std::min( - size_t(2048), d.get_info()); + size_t max_wg = + std::min(size_t(2048), + d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index a4e202f073..fbfd9547e1 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -43,6 +43,7 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(arg_dtype, q) + # test reduction for C-contiguous input m = dpt.ones(100, dtype=arg_dtype) r = dpt.sum(m) @@ -55,12 +56,20 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype): assert r.dtype.kind == "f" elif m.dtype.kind == "c": assert r.dtype.kind == "c" + assert dpt.all(r == 100) + # test reduction for strided input m = dpt.ones(200, dtype=arg_dtype)[:1:-2] r = dpt.sum(m) assert dpt.all(r == 99) + # test reduction for strided input which can be simplified + # to contiguous computation + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.sum(dpt.flip(m)) + assert dpt.all(r == 100) + @pytest.mark.parametrize("arg_dtype", _all_dtypes) @pytest.mark.parametrize("out_dtype", _all_dtypes[1:]) diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index 73cf9459a7..56059e54b8 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -169,6 +169,17 @@ def test_search_reduction_kernels(arg_dtype): m = dpt.argmax(x) assert m == idx + # test case of strided input mapping to contig + # implementation + m = dpt.argmax(dpt.flip(x)) + assert m == x.size - 1 - idx + + # test case of strided implementation + y = dpt.ones(2 * x.size, dtype=arg_dtype, sycl_queue=q) + y[::2] = x + m = dpt.argmax(y) + assert m == 2 * idx + x = dpt.reshape(x, (24, 1025)) x[idx_tup[0], :] = 3