From b3e9465c30c92977f76ad956334f7d1e7b9352f2 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 26 Oct 2023 12:34:34 -0500 Subject: [PATCH 1/6] Implementations of reductions for contiguous case must take offsets into account --- .../libtensor/include/kernels/reductions.hpp | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index b9e2918c8c..d5fddce6ed 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -1339,7 +1339,7 @@ sycl::event reduction_over_group_temps_strided_impl( static_cast(remaining_reduction_nelems)}; ResIndexerT res_iter_indexer{iter_nd, iter_res_offset, /* shape */ iter_shape_and_strides, - /*s trides */ iter_shape_and_strides + + /* strides */ iter_shape_and_strides + 2 * iter_nd}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -1424,8 +1424,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( py::ssize_t reduction_arg_offset, const std::vector &depends) { - const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp); - resTy *res_tp = reinterpret_cast<resTy *>(res_cp); + const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset; constexpr resTy identity_val = su_ns::Identity::value; @@ -1767,8 +1768,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( py::ssize_t reduction_arg_offset, const std::vector &depends) { - const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp); - resTy *res_tp = reinterpret_cast<resTy *>(res_cp); + const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset; constexpr resTy identity_val = su_ns::Identity::value; @@ -4258,8 +4260,9 @@ sycl::event search_axis1_over_group_temps_contig_impl( py::ssize_t 
reduction_arg_offset, const std::vector &depends) { - const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp); - resTy *res_tp = reinterpret_cast<resTy *>(res_cp); + const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset; constexpr argTy identity_val = su_ns::Identity::value; constexpr resTy idx_identity_val = su_ns::Identity::value; @@ -4635,8 +4638,9 @@ sycl::event search_axis0_over_group_temps_contig_impl( py::ssize_t reduction_arg_offset, const std::vector &depends) { - const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp); - resTy *res_tp = reinterpret_cast<resTy *>(res_cp); + const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset; constexpr argTy identity_val = su_ns::Identity::value; constexpr resTy idx_identity_val = su_ns::Identity::value; From c63c5451ec92620956d106df8f68ec5d8bc58680 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 26 Oct 2023 12:35:11 -0500 Subject: [PATCH 2/6] Expand test to cover non-contig. 
input that can be simplified into one --- dpctl/tests/test_tensor_sum.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index a4e202f073..fbfd9547e1 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -43,6 +43,7 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(arg_dtype, q) + # test reduction for C-contiguous input m = dpt.ones(100, dtype=arg_dtype) r = dpt.sum(m) @@ -55,12 +56,20 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype): assert r.dtype.kind == "f" elif m.dtype.kind == "c": assert r.dtype.kind == "c" + assert dpt.all(r == 100) + # test reduction for strided input m = dpt.ones(200, dtype=arg_dtype)[:1:-2] r = dpt.sum(m) assert dpt.all(r == 99) + # test reduction for strided input which can be simplified + # to contiguous computation + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.sum(dpt.flip(m)) + assert dpt.all(r == 100) + @pytest.mark.parametrize("arg_dtype", _all_dtypes) @pytest.mark.parametrize("out_dtype", _all_dtypes[1:]) From e92d1f9ad3a138c9c85e181c0e31e7293a4b8eb2 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 26 Oct 2023 12:43:05 -0500 Subject: [PATCH 3/6] Add tests for strided input where contig implementation is applicable --- dpctl/tests/test_usm_ndarray_reductions.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index 73cf9459a7..d500ce26b6 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -169,6 +169,14 @@ def test_search_reduction_kernels(arg_dtype): m = dpt.argmax(x) assert m == idx + m = dpt.argmax(dpt.flip(x)) + assert m == x.size - 1 - idx + + y = dpt.ones(2 * x.size, dtype=arg_dtype, sycl_queue=q) + y[::2] = x + m = dpt.argmax(y) + assert m == 2 * idx + x = dpt.reshape(x, (24, 
1025)) x[idx_tup[0], :] = 3 From 702b707250d9ab226704d6ce3fee7e2307d0fdbb Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 26 Oct 2023 10:59:54 -0700 Subject: [PATCH 4/6] Added comments to the test file --- dpctl/tests/test_usm_ndarray_reductions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index d500ce26b6..56059e54b8 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -169,9 +169,12 @@ def test_search_reduction_kernels(arg_dtype): m = dpt.argmax(x) assert m == idx + # test case of strided input mapping to contig + # implementation m = dpt.argmax(dpt.flip(x)) assert m == x.size - 1 - idx + # test case of strided implementation y = dpt.ones(2 * x.size, dtype=arg_dtype, sycl_queue=q) y[::2] = x m = dpt.argmax(y) From dcb566a424a902af8ccb7e96021a3899e98c925b Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 26 Oct 2023 15:24:17 -0500 Subject: [PATCH 5/6] Corrected logical error in can_use_reduce_over_group trait implementation --- dpctl/tensor/libtensor/include/kernels/reductions.hpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index d5fddce6ed..884a7c5461 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -50,12 +50,18 @@ namespace tensor namespace kernels { +template struct needs_workaround +{ + static constexpr bool value = + std::is_same_v> && + (std::is_same_v || std::is_same_v); +}; + template struct can_use_reduce_over_group { static constexpr bool value = sycl::has_known_identity::value && - !std::is_same_v && !std::is_same_v && - !std::is_same_v>; + !needs_workaround::value; }; template Date: Thu, 26 Oct 2023 17:42:26 -0500 Subject: [PATCH 6/6] The taper optimization in 
tree-reduction which causes problems with CUDA. The optimization should not use max-work-group-size, so as to leave some of the SLM memory to the RT. --- .../libtensor/include/kernels/reductions.hpp | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index 884a7c5461..40e5bd282d 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -1094,7 +1094,7 @@ sycl::event reduction_over_group_temps_strided_impl( // max_max_wg prevents running out of resources on CPU constexpr size_t max_max_wg = 2048; size_t max_wg = std::min( - max_max_wg, d.get_info()); + max_max_wg, d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -1444,7 +1444,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( // max_max_wg prevents running out of resources on CPU constexpr size_t max_max_wg = 2048; size_t max_wg = std::min( - max_max_wg, d.get_info()); + max_max_wg, d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -1788,7 +1788,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( // max_max_wg prevents running out of resources on CPU constexpr size_t max_max_wg = 2048; size_t max_wg = std::min( - max_max_wg, d.get_info()); + max_max_wg, d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -3883,8 +3883,9 @@ sycl::event search_over_group_temps_strided_impl( constexpr size_t preferrered_reductions_per_wi = 4; // max_max_wg prevents running out of resources on CPU - size_t max_wg = std::min( - size_t(2048), d.get_info()); + size_t max_wg = + std::min(size_t(2048), + d.get_info() / 2); size_t 
reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -4279,8 +4280,9 @@ sycl::event search_axis1_over_group_temps_contig_impl( constexpr size_t preferrered_reductions_per_wi = 8; // max_max_wg prevents running out of resources on CPU - size_t max_wg = std::min( - size_t(2048), d.get_info()); + size_t max_wg = + std::min(size_t(2048), + d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { @@ -4657,8 +4659,9 @@ sycl::event search_axis0_over_group_temps_contig_impl( constexpr size_t preferrered_reductions_per_wi = 8; // max_max_wg prevents running out of resources on CPU - size_t max_wg = std::min( - size_t(2048), d.get_info()); + size_t max_wg = + std::min(size_t(2048), + d.get_info() / 2); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {