Skip to content

Commit

Permalink
Fix bug in top_k partial merge sort implementation
Browse files Browse the repository at this point in the history
rounded value of k must be divisible by the merge sort chunk size
  • Loading branch information
ndgrigorian committed Dec 13, 2024
1 parent 5b0b80f commit 8bcb100
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ namespace merge_sort_detail

using namespace dpctl::tensor::kernels::search_sorted_detail;

size_t get_merge_segment_size(const sycl::device &dev)
{
return dev.has(sycl::aspect::cpu) ? 32 : 4;
}

/*! @brief Merge two contiguous sorted segments */
template <typename InAcc, typename OutAcc, typename Compare>
void merge_impl(const std::size_t offset,
Expand Down Expand Up @@ -580,7 +585,7 @@ merge_sorted_block_contig_impl(sycl::queue &q,
// experimentally determined value
// size of segments worked upon by each work-item during merging
const sycl::device &dev = q.get_device();
const size_t segment_size = (dev.has(sycl::aspect::cpu)) ? 32 : 4;
const size_t segment_size = get_merge_segment_size(dev);

const size_t chunk_size =
(sorted_block_size < segment_size) ? sorted_block_size : segment_size;
Expand Down
17 changes: 10 additions & 7 deletions dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "kernels/dpctl_tensor_types.hpp"
#include "merge_sort.hpp"
#include "radix_sort.hpp"
#include "search_sorted_detail.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include <sycl/ext/oneapi/sub_group_mask.hpp>

Expand Down Expand Up @@ -247,14 +248,17 @@ sycl::event topk_merge_impl(
// This assumption permits doing away with using a loop
assert(sorted_block_size % lws == 0);

using search_sorted_detail::quotient_ceil;
const std::size_t n_segments =
merge_sort_detail::quotient_ceil<std::size_t>(axis_nelems,
sorted_block_size);
quotient_ceil<std::size_t>(axis_nelems, sorted_block_size);

// round k up for the later merge kernel
// round k up for the later merge kernel if necessary
const std::size_t round_k_to =
merge_sort_detail::get_merge_segment_size(dev);
std::size_t k_rounded =
merge_sort_detail::quotient_ceil<std::size_t>(k, elems_per_wi) *
elems_per_wi;
(k < round_k_to)
? k
: quotient_ceil<std::size_t>(k, round_k_to) * round_k_to;

// get length of tail for alloc size
auto rem = axis_nelems % sorted_block_size;
Expand Down Expand Up @@ -322,8 +326,7 @@ sycl::event topk_merge_impl(
sycl::group_barrier(it.get_group());

const std::size_t chunk =
merge_sort_detail::quotient_ceil<std::size_t>(
sorted_block_size, lws);
quotient_ceil<std::size_t>(sorted_block_size, lws);

const std::size_t chunk_start_idx = lid * chunk;
const std::size_t chunk_end_idx =
Expand Down

0 comments on commit 8bcb100

Please sign in to comment.