From b3e9465c30c92977f76ad956334f7d1e7b9352f2 Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
Date: Thu, 26 Oct 2023 12:34:34 -0500
Subject: [PATCH 1/6] Implementations of reductions for the contiguous case
 must take offsets into account

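The axis0/axis1 contiguous entry points cast the incoming char pointers
to typed pointers but previously ignored the element offsets passed in
alongside them. Below is a minimal sketch of the adjustment pattern this
patch applies; the helper name apply_offsets is illustrative only, and
std::ptrdiff_t stands in for the py::ssize_t offsets of the real
signatures:

    #include <cstddef>

    // Offsets are counted in elements of argTy / resTy (they are added
    // to typed pointers), not in bytes.
    template <typename argTy, typename resTy>
    void apply_offsets(const char *arg_cp, char *res_cp,
                       std::ptrdiff_t iter_arg_offset,
                       std::ptrdiff_t reduction_arg_offset,
                       std::ptrdiff_t iter_res_offset,
                       const argTy *&arg_tp, resTy *&res_tp)
    {
        arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
                 iter_arg_offset + reduction_arg_offset;
        res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
    }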
---
 .../libtensor/include/kernels/reductions.hpp  | 22 +++++++++++--------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp
index b9e2918c8c..d5fddce6ed 100644
--- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp
@@ -1339,7 +1339,7 @@ sycl::event reduction_over_group_temps_strided_impl(
                 static_cast<py::ssize_t>(remaining_reduction_nelems)};
             ResIndexerT res_iter_indexer{iter_nd, iter_res_offset,
                                          /* shape */ iter_shape_and_strides,
-                                         /*s trides */ iter_shape_and_strides +
+                                         /* strides */ iter_shape_and_strides +
                                              2 * iter_nd};
 
             InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
@@ -1424,8 +1424,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
     py::ssize_t reduction_arg_offset,
     const std::vector<sycl::event> &depends)
 {
-    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
-    resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
+    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
+                          iter_arg_offset + reduction_arg_offset;
+    resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
 
     constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
 
@@ -1767,8 +1768,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
     py::ssize_t reduction_arg_offset,
     const std::vector<sycl::event> &depends)
 {
-    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
-    resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
+    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
+                          iter_arg_offset + reduction_arg_offset;
+    resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
 
     constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
 
@@ -4258,8 +4260,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
     py::ssize_t reduction_arg_offset,
     const std::vector<sycl::event> &depends)
 {
-    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
-    resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
+    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
+                          iter_arg_offset + reduction_arg_offset;
+    resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
 
     constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
     constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
@@ -4635,8 +4638,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
     py::ssize_t reduction_arg_offset,
     const std::vector<sycl::event> &depends)
 {
-    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
-    resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
+    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
+                          iter_arg_offset + reduction_arg_offset;
+    resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
 
     constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
     constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;

From c63c5451ec92620956d106df8f68ec5d8bc58680 Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
Date: Thu, 26 Oct 2023 12:35:11 -0500
Subject: [PATCH 2/6] Expand test to cover non-contiguous input that can be
 simplified to a contiguous one

---
 dpctl/tests/test_tensor_sum.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py
index a4e202f073..fbfd9547e1 100644
--- a/dpctl/tests/test_tensor_sum.py
+++ b/dpctl/tests/test_tensor_sum.py
@@ -43,6 +43,7 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
     q = get_queue_or_skip()
     skip_if_dtype_not_supported(arg_dtype, q)
 
+    # test reduction for C-contiguous input
     m = dpt.ones(100, dtype=arg_dtype)
     r = dpt.sum(m)
 
@@ -55,12 +56,20 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
         assert r.dtype.kind == "f"
     elif m.dtype.kind == "c":
         assert r.dtype.kind == "c"
+
     assert dpt.all(r == 100)
 
+    # test reduction for strided input
     m = dpt.ones(200, dtype=arg_dtype)[:1:-2]
     r = dpt.sum(m)
     assert dpt.all(r == 99)
 
+    # test reduction for strided input which can be simplified
+    # to contiguous computation
+    m = dpt.ones(100, dtype=arg_dtype)
+    r = dpt.sum(dpt.flip(m))
+    assert dpt.all(r == 100)
+
 
 @pytest.mark.parametrize("arg_dtype", _all_dtypes)
 @pytest.mark.parametrize("out_dtype", _all_dtypes[1:])

From e92d1f9ad3a138c9c85e181c0e31e7293a4b8eb2 Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
Date: Thu, 26 Oct 2023 12:43:05 -0500
Subject: [PATCH 3/6] Add tests for strided input where the contiguous
 implementation is applicable

---
 dpctl/tests/test_usm_ndarray_reductions.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py
index 73cf9459a7..d500ce26b6 100644
--- a/dpctl/tests/test_usm_ndarray_reductions.py
+++ b/dpctl/tests/test_usm_ndarray_reductions.py
@@ -169,6 +169,14 @@ def test_search_reduction_kernels(arg_dtype):
     m = dpt.argmax(x)
     assert m == idx
 
+    m = dpt.argmax(dpt.flip(x))
+    assert m == x.size - 1 - idx
+
+    y = dpt.ones(2 * x.size, dtype=arg_dtype, sycl_queue=q)
+    y[::2] = x
+    m = dpt.argmax(y)
+    assert m == 2 * idx
+
     x = dpt.reshape(x, (24, 1025))
 
     x[idx_tup[0], :] = 3

From 702b707250d9ab226704d6ce3fee7e2307d0fdbb Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
Date: Thu, 26 Oct 2023 10:59:54 -0700
Subject: [PATCH 4/6] Added comments to the test file

---
 dpctl/tests/test_usm_ndarray_reductions.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py
index d500ce26b6..56059e54b8 100644
--- a/dpctl/tests/test_usm_ndarray_reductions.py
+++ b/dpctl/tests/test_usm_ndarray_reductions.py
@@ -169,9 +169,12 @@ def test_search_reduction_kernels(arg_dtype):
     m = dpt.argmax(x)
     assert m == idx
 
+    # test case of strided input mapping to contig
+    # implementation
     m = dpt.argmax(dpt.flip(x))
     assert m == x.size - 1 - idx
 
+    # test case of strided implementation
     y = dpt.ones(2 * x.size, dtype=arg_dtype, sycl_queue=q)
     y[::2] = x
     m = dpt.argmax(y)

From dcb566a424a902af8ccb7e96021a3899e98c925b Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
Date: Thu, 26 Oct 2023 15:24:17 -0500
Subject: [PATCH 5/6] Corrected logical error in can_use_reduce_over_group
 trait implementation

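The previous expression disabled reduce_over_group for reductions over
any 64-bit integer type and for any sycl::multiplies reduction, while
the intent was to only work around sycl::multiplies over 64-bit
integers. A self-contained sketch of the corrected behaviour follows;
the trait definitions are copied from the hunk below, and the
static_asserts are illustrative, assuming a SYCL 2020 implementation
(which guarantees a known identity for sycl::plus and sycl::multiplies
over arithmetic types):

    #include <cstdint>
    #include <sycl/sycl.hpp>
    #include <type_traits>

    template <typename ReductionOpT, typename T> struct needs_workaround
    {
        static constexpr bool value =
            std::is_same_v<ReductionOpT, sycl::multiplies<T>> &&
            (std::is_same_v<T, std::int64_t> ||
             std::is_same_v<T, std::uint64_t>);
    };

    template <typename ReductionOpT, typename T>
    struct can_use_reduce_over_group
    {
        static constexpr bool value =
            sycl::has_known_identity<ReductionOpT, T>::value &&
            !needs_workaround<ReductionOpT, T>::value;
    };

    // plus over 64-bit integers is no longer excluded ...
    static_assert(can_use_reduce_over_group<sycl::plus<std::int64_t>,
                                            std::int64_t>::value);
    // ... while multiplies over 64-bit integers still takes the
    // workaround path
    static_assert(!can_use_reduce_over_group<sycl::multiplies<std::int64_t>,
                                             std::int64_t>::value);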
---
 dpctl/tensor/libtensor/include/kernels/reductions.hpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp
index d5fddce6ed..884a7c5461 100644
--- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp
@@ -50,12 +50,18 @@ namespace tensor
 namespace kernels
 {
 
+template <typename ReductionOpT, typename T> struct needs_workaround
+{
+    static constexpr bool value =
+        std::is_same_v<ReductionOpT, sycl::multiplies<T>> &&
+        (std::is_same_v<T, std::int64_t> || std::is_same_v<T, std::uint64_t>);
+};
+
 template <typename ReductionOpT, typename T> struct can_use_reduce_over_group
 {
     static constexpr bool value =
         sycl::has_known_identity<ReductionOpT, T>::value &&
-        !std::is_same_v<T, std::int64_t> && !std::is_same_v<T, std::uint64_t> &&
-        !std::is_same_v<ReductionOpT, sycl::multiplies<T>>;
+        !needs_workaround<ReductionOpT, T>::value;
 };
 
 template <typename argT,

From bfba152db0ee15e2caaeac9c2dda6c74da628645 Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
Date: Thu, 26 Oct 2023 17:42:26 -0500
Subject: [PATCH 6/6] Fix the taper optimization in tree-reduction which
 caused problems with CUDA

The optimization should not use the full max-work-group-size; halving it
leaves the runtime some of the SLM memory.
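
A minimal sketch of the adjusted work-group sizing that the hunks below
apply at each call site; choose_max_wg is a hypothetical helper name
used only for illustration:

    #include <algorithm>
    #include <cstddef>
    #include <sycl/sycl.hpp>

    std::size_t choose_max_wg(const sycl::device &d)
    {
        // cap to avoid running out of resources on CPU devices ...
        constexpr std::size_t max_max_wg = 2048;
        // ... and take only half of the device maximum so the runtime
        // keeps some SLM headroom (matters on CUDA devices)
        return std::min(
            max_max_wg,
            d.get_info<sycl::info::device::max_work_group_size>() / 2);
    }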
---
 .../libtensor/include/kernels/reductions.hpp  | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp
index 884a7c5461..40e5bd282d 100644
--- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp
@@ -1094,7 +1094,7 @@ sycl::event reduction_over_group_temps_strided_impl(
     // max_max_wg prevents running out of resources on CPU
     constexpr size_t max_max_wg = 2048;
     size_t max_wg = std::min(
-        max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
+        max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);
 
     size_t reductions_per_wi(preferrered_reductions_per_wi);
     if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1444,7 +1444,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
     // max_max_wg prevents running out of resources on CPU
     constexpr size_t max_max_wg = 2048;
     size_t max_wg = std::min(
-        max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
+        max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);
 
     size_t reductions_per_wi(preferrered_reductions_per_wi);
     if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1788,7 +1788,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
     // max_max_wg prevents running out of resources on CPU
     constexpr size_t max_max_wg = 2048;
     size_t max_wg = std::min(
-        max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
+        max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);
 
     size_t reductions_per_wi(preferrered_reductions_per_wi);
     if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -3883,8 +3883,9 @@ sycl::event search_over_group_temps_strided_impl(
 
     constexpr size_t preferrered_reductions_per_wi = 4;
     // max_max_wg prevents running out of resources on CPU
-    size_t max_wg = std::min(
-        size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
+    size_t max_wg =
+        std::min(size_t(2048),
+                 d.get_info<sycl::info::device::max_work_group_size>() / 2);
 
     size_t reductions_per_wi(preferrered_reductions_per_wi);
     if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4279,8 +4280,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
 
     constexpr size_t preferrered_reductions_per_wi = 8;
     // max_max_wg prevents running out of resources on CPU
-    size_t max_wg = std::min(
-        size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
+    size_t max_wg =
+        std::min(size_t(2048),
+                 d.get_info<sycl::info::device::max_work_group_size>() / 2);
 
     size_t reductions_per_wi(preferrered_reductions_per_wi);
     if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4657,8 +4659,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
 
     constexpr size_t preferrered_reductions_per_wi = 8;
     // max_max_wg prevents running out of resources on CPU
-    size_t max_wg = std::min(
-        size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
+    size_t max_wg =
+        std::min(size_t(2048),
+                 d.get_info<sycl::info::device::max_work_group_size>() / 2);
 
     size_t reductions_per_wi(preferrered_reductions_per_wi);
     if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {