diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp index f3b5030c48..316b8e26a6 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp @@ -46,6 +46,11 @@ namespace merge_sort_detail using namespace dpctl::tensor::kernels::search_sorted_detail; +size_t get_merge_segment_size(const sycl::device &dev) +{ + return dev.has(sycl::aspect::cpu) ? 32 : 4; +} + /*! @brief Merge two contiguous sorted segments */ template void merge_impl(const std::size_t offset, @@ -580,7 +585,7 @@ merge_sorted_block_contig_impl(sycl::queue &q, // experimentally determined value // size of segments worked upon by each work-item during merging const sycl::device &dev = q.get_device(); - const size_t segment_size = (dev.has(sycl::aspect::cpu)) ? 32 : 4; + const size_t segment_size = get_merge_segment_size(dev); const size_t chunk_size = (sorted_block_size < segment_size) ? sorted_block_size : segment_size; diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp index 11f3e851fb..f6fd239c4d 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -36,6 +36,7 @@ #include "kernels/dpctl_tensor_types.hpp" #include "merge_sort.hpp" #include "radix_sort.hpp" +#include "search_sorted_detail.hpp" #include "utils/sycl_alloc_utils.hpp" #include @@ -247,14 +248,17 @@ sycl::event topk_merge_impl( // This assumption permits doing away with using a loop assert(sorted_block_size % lws == 0); + using search_sorted_detail::quotient_ceil; const std::size_t n_segments = - merge_sort_detail::quotient_ceil(axis_nelems, - sorted_block_size); + quotient_ceil(axis_nelems, sorted_block_size); - // round k up for the later merge kernel + // round k up for the later merge kernel if necessary + const std::size_t round_k_to = + merge_sort_detail::get_merge_segment_size(dev); std::size_t k_rounded = - merge_sort_detail::quotient_ceil(k, elems_per_wi) * - elems_per_wi; + (k < round_k_to) + ? k + : quotient_ceil(k, round_k_to) * round_k_to; // get length of tail for alloc size auto rem = axis_nelems % sorted_block_size; @@ -322,8 +326,7 @@ sycl::event topk_merge_impl( sycl::group_barrier(it.get_group()); const std::size_t chunk = - merge_sort_detail::quotient_ceil( - sorted_block_size, lws); + quotient_ceil(sorted_block_size, lws); const std::size_t chunk_start_idx = lid * chunk; const std::size_t chunk_end_idx =