Skip to content

Commit

Permalink
Replace map_back_impl in sort_utils
Browse files Browse the repository at this point in the history
Change the kernel so that each work-item processes a few data elements.
  • Loading branch information
oleksandr-pavlyk committed Dec 31, 2024
1 parent 7811078 commit c0bdba7
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,35 @@ sycl::event map_back_impl(sycl::queue &exec_q,
std::size_t row_size,
const std::vector<sycl::event> &dependent_events)
{
constexpr std::uint32_t lws = 64;
constexpr std::uint32_t n_wi = 4;
const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws);

sycl::range<1> lRange{lws};
sycl::range<1> gRange{n_groups * lws};
sycl::nd_range<1> ndRange{gRange, lRange};

sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(dependent_events);

cgh.parallel_for<KernelName>(
sycl::range<1>(nelems), [=](sycl::id<1> id) {
const IndexTy linear_index = flat_index_data[id];
reduced_index_data[id] = (linear_index % row_size);
});
cgh.parallel_for<KernelName>(ndRange, [=](sycl::nd_item<1> it) {
const std::size_t gid = it.get_global_linear_id();
const auto &sg = it.get_sub_group();
const std::uint32_t lane_id = sg.get_local_id()[0];
const std::uint32_t sg_size = sg.get_max_local_range()[0];

const std::size_t start_id = (gid - lane_id) * n_wi + lane_id;

#pragma unroll
for (std::uint32_t i = 0; i < n_wi; ++i) {
const std::size_t data_id = start_id + i * sg_size;

if (data_id < nelems) {
const IndexTy linear_index = flat_index_data[data_id];
reduced_index_data[data_id] = (linear_index % row_size);
}
}
});
});

return map_back_ev;
Expand Down

0 comments on commit c0bdba7

Please sign in to comment.