From c0bdba778e227b87ede16bab3b143ad09159e65c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Tue, 31 Dec 2024 10:16:11 -0600 Subject: [PATCH] Replace map_back_impl in sort_utils Change kernel to process few data elements in the work-item. --- .../include/kernels/sorting/sort_utils.hpp | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp index f62a6c3fa0..a81f528852 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp @@ -105,14 +105,35 @@ sycl::event map_back_impl(sycl::queue &exec_q, std::size_t row_size, const std::vector &dependent_events) { + constexpr std::uint32_t lws = 64; + constexpr std::uint32_t n_wi = 4; + const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws); + + sycl::range<1> lRange{lws}; + sycl::range<1> gRange{n_groups * lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(dependent_events); - cgh.parallel_for( - sycl::range<1>(nelems), [=](sycl::id<1> id) { - const IndexTy linear_index = flat_index_data[id]; - reduced_index_data[id] = (linear_index % row_size); - }); + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_max_local_range()[0]; + + const std::size_t start_id = (gid - lane_id) * n_wi + lane_id; + +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + const std::size_t data_id = start_id + i * sg_size; + + if (data_id < nelems) { + const IndexTy linear_index = flat_index_data[data_id]; + reduced_index_data[data_id] = (linear_index % row_size); + } + } + }); }); return map_back_ev;