Skip to content

Commit

Permalink
Add implementation of top_k using radix sort
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Dec 5, 2024
1 parent a338383 commit a56e21c
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 63 deletions.
202 changes: 152 additions & 50 deletions dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#include "kernels/dpctl_tensor_types.hpp"
#include "merge_sort.hpp"
#include "radix_sort.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include <sycl/ext/oneapi/sub_group_mask.hpp>

Expand Down Expand Up @@ -70,31 +71,25 @@ void scale_topk_params(const std::uint64_t nelems_per_slm,
} // namespace topk_detail

template <typename T1, typename T2, typename T3>
class populate_index_data_full_sort_krn;
class topk_populate_index_data_krn;

template <typename T1, typename T2, typename T3>
class topk_map_to_rows_full_sort_krn;

template <typename T1, typename T2, typename T3> class populate_index_data_krn;

template <typename T1, typename T2, typename T3> class topk_map_to_rows_krn;
class topk_full_merge_map_back_krn;

template <typename argTy, typename IndexTy, typename CompT>
sycl::event topk_full_sort_impl(
sycl::queue &exec_q,
std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows in a
// matrix when sorting over rows)
std::size_t sort_nelems, // size of each array to sort (length of rows,
// i.e. number of columns)
std::size_t k,
const argTy *arg_tp,
argTy *vals_tp,
IndexTy *inds_tp,
const CompT &comp,
const std::vector<sycl::event> &depends)
sycl::event
topk_full_merge_sort_impl(sycl::queue &exec_q,
std::size_t iter_nelems, // number of sub-arrays
std::size_t axis_nelems, // size of each sub-array
std::size_t k,
const argTy *arg_tp,
argTy *vals_tp,
IndexTy *inds_tp,
const CompT &comp,
const std::vector<sycl::event> &depends)
{
IndexTy *index_data =
sycl::malloc_device<IndexTy>(iter_nelems * sort_nelems, exec_q);
sycl::malloc_device<IndexTy>(iter_nelems * axis_nelems, exec_q);
if (index_data == nullptr) {
throw std::runtime_error("Unable to allocate device_memory");
}
Expand All @@ -103,10 +98,10 @@ sycl::event topk_full_sort_impl(
exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

auto const &range = sycl::range<1>(iter_nelems * sort_nelems);
auto const &range = sycl::range<1>(iter_nelems * axis_nelems);

using KernelName =
populate_index_data_full_sort_krn<argTy, IndexTy, CompT>;
topk_populate_index_data_krn<argTy, IndexTy, CompT>;

cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
std::size_t i = id[0];
Expand All @@ -118,34 +113,33 @@ sycl::event topk_full_sort_impl(
// Sort segments of the array
sycl::event base_sort_ev =
merge_sort_detail::sort_over_work_group_contig_impl(
exec_q, iter_nelems, sort_nelems, index_data, index_data, comp,
exec_q, iter_nelems, axis_nelems, index_data, index_data, comp,
sorted_block_size, // modified in place with size of sorted block
// size
{populate_indexed_data_ev});

// Merge segments in parallel until all elements are sorted
sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl(
exec_q, iter_nelems, sort_nelems, index_data, comp, sorted_block_size,
exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size,
{base_sort_ev});

sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(merges_ev);

using KernelName =
topk_map_to_rows_full_sort_krn<argTy, IndexTy, CompT>;
using KernelName = topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;

cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
std::size_t gid = id[0];

std::size_t iter_gid = gid / k;
std::size_t axis_gid = gid - (iter_gid * k);

std::size_t src_idx = iter_gid * sort_nelems + axis_gid;
std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
std::size_t dst_idx = iter_gid * k + axis_gid;

auto res_ind = index_data[src_idx];
vals_tp[dst_idx] = arg_tp[res_ind];
inds_tp[dst_idx] = res_ind % sort_nelems;
inds_tp[dst_idx] = res_ind % axis_nelems;
});
});

Expand All @@ -162,29 +156,32 @@ sycl::event topk_full_sort_impl(
return cleanup_host_task_event;
};

template <typename T1, typename T2, typename T3>
class topk_partial_merge_map_back_krn;

template <typename T1, typename T2, typename Comp>
class topk_over_work_group_krn;

template <typename argTy,
typename IndexTy,
typename ValueComp = std::less<argTy>>
sycl::event
topk_impl(sycl::queue &exec_q,
std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
// in a matrix when sorting over rows)
std::size_t axis_nelems, // size of each array to sort (length of
// rows, i.e. number of columns)
std::size_t k,
const char *arg_cp,
char *vals_cp,
char *inds_cp,
dpctl::tensor::ssize_t iter_arg_offset,
dpctl::tensor::ssize_t iter_vals_offset,
dpctl::tensor::ssize_t iter_inds_offset,
dpctl::tensor::ssize_t axis_arg_offset,
dpctl::tensor::ssize_t axis_vals_offset,
dpctl::tensor::ssize_t axis_inds_offset,
const std::vector<sycl::event> &depends)
sycl::event topk_merge_impl(
sycl::queue &exec_q,
std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
// in a matrix when sorting over rows)
std::size_t axis_nelems, // size of each array to sort (length of
// rows, i.e. number of columns)
std::size_t k,
const char *arg_cp,
char *vals_cp,
char *inds_cp,
dpctl::tensor::ssize_t iter_arg_offset,
dpctl::tensor::ssize_t iter_vals_offset,
dpctl::tensor::ssize_t iter_inds_offset,
dpctl::tensor::ssize_t axis_arg_offset,
dpctl::tensor::ssize_t axis_vals_offset,
dpctl::tensor::ssize_t axis_inds_offset,
const std::vector<sycl::event> &depends)
{
if (axis_nelems < k) {
throw std::runtime_error("Invalid sort axis size for value of k");
Expand All @@ -201,8 +198,9 @@ topk_impl(sycl::queue &exec_q,
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};

if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2) {
return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k, arg_tp,
vals_tp, inds_tp, index_comp, depends);
return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, k,
arg_tp, vals_tp, inds_tp, index_comp,
depends);
}
else {
using PartialKernelName =
Expand Down Expand Up @@ -269,9 +267,9 @@ topk_impl(sycl::queue &exec_q,
if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size ||
alloc_len >= axis_nelems / 2)
{
return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k,
arg_tp, vals_tp, inds_tp, index_comp,
depends);
return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems,
k, arg_tp, vals_tp, inds_tp,
index_comp, depends);
}

IndexTy *index_data =
Expand Down Expand Up @@ -399,7 +397,8 @@ topk_impl(sycl::queue &exec_q,
sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(merges_ev);

using KernelName = topk_map_to_rows_krn<argTy, IndexTy, ValueComp>;
using KernelName =
topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;

cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
std::size_t gid = id[0];
Expand Down Expand Up @@ -430,6 +429,109 @@ topk_impl(sycl::queue &exec_q,
}
}

template <typename T1, typename T2> class topk_iota_krn;

template <typename T1, typename T2> class topk_radix_map_back_krn;

template <typename argTy, typename IndexTy>
sycl::event topk_radix_impl(sycl::queue &exec_q,
std::size_t iter_nelems, // number of sub-arrays
std::size_t axis_nelems, // size of each sub-array
std::size_t k,
bool ascending,
const char *arg_cp,
char *vals_cp,
char *inds_cp,
dpctl::tensor::ssize_t iter_arg_offset,
dpctl::tensor::ssize_t iter_vals_offset,
dpctl::tensor::ssize_t iter_inds_offset,
dpctl::tensor::ssize_t axis_arg_offset,
dpctl::tensor::ssize_t axis_vals_offset,
dpctl::tensor::ssize_t axis_inds_offset,
const std::vector<sycl::event> &depends)
{
if (axis_nelems < k) {
throw std::runtime_error("Invalid sort axis size for value of k");
}

const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + axis_arg_offset;
argTy *vals_tp = reinterpret_cast<argTy *>(vals_cp) + iter_vals_offset +
axis_vals_offset;
IndexTy *inds_tp = reinterpret_cast<IndexTy *>(inds_cp) + iter_inds_offset +
axis_inds_offset;

const std::size_t total_nelems = iter_nelems * axis_nelems;
const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
IndexTy *workspace = sycl::malloc_device<IndexTy>(
padded_total_nelems + total_nelems, exec_q);

IndexTy *tmp_tp = sycl::malloc_device<IndexTy>(total_nelems, exec_q);

if (nullptr == workspace || nullptr == tmp_tp) {
throw std::runtime_error(
"Not enough device memory for radix sort topk");
}

using IdentityProjT = radix_sort_details::IdentityProj;
using IndexedProjT =
radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
const IndexedProjT proj_op{arg_tp, IdentityProjT{}};

sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using KernelName = topk_iota_krn<argTy, IndexTy>;

cgh.parallel_for<KernelName>(
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
size_t i = id[0];
IndexTy sort_id = static_cast<IndexTy>(i);
workspace[i] = sort_id;
});
});

sycl::event radix_sort_ev =
radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op,
ascending, {iota_ev});

// Write out top k of the temporary
sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(radix_sort_ev);

using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;

cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
std::size_t gid = id[0];

std::size_t iter_gid = gid / k;
std::size_t axis_gid = gid - (iter_gid * k);

std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
std::size_t dst_idx = iter_gid * k + axis_gid;

IndexTy res_ind = tmp_tp[src_idx];
vals_tp[dst_idx] = arg_tp[res_ind];
inds_tp[dst_idx] = res_ind % axis_nelems;
});
});

sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(write_topk_ev);

const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task([ctx, workspace, tmp_tp] {
sycl_free_noexcept(workspace, ctx);
sycl_free_noexcept(tmp_tp, ctx);
});
});

return cleanup_ev;
}

} // end of namespace kernels
} // end of namespace tensor
} // end of namespace dpctl
54 changes: 41 additions & 13 deletions dpctl/tensor/libtensor/source/sorting/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ static topk_impl_fn_ptr_t topk_dispatch_vector[td_ns::num_types];
namespace
{

template <typename T, typename = void>
struct use_radix_sort : public std::false_type
{
};

template <typename T>
struct use_radix_sort<
T,
std::enable_if_t<std::disjunction<std::is_same<T, bool>,
std::is_same<T, std::uint8_t>,
std::is_same<T, std::int8_t>,
std::is_same<T, std::uint16_t>,
std::is_same<T, std::int16_t>>::value>>
: public std::true_type
{
};

template <typename argTy, typename IndexTy>
sycl::event
topk_caller(sycl::queue &exec_q,
Expand All @@ -96,22 +113,33 @@ topk_caller(sycl::queue &exec_q,
py::ssize_t axis_inds_offset,
const std::vector<sycl::event> &depends)
{
using dpctl::tensor::kernels::topk_impl;
if (largest) {
using CompTy =
typename dpctl::tensor::py_internal::DescendingSorter<argTy>::type;
return topk_impl<argTy, IndexTy, CompTy>(
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
iter_arg_offset, iter_vals_offset, iter_inds_offset,
if constexpr (use_radix_sort<argTy>::value) {
using dpctl::tensor::kernels::topk_radix_impl;
auto ascending = !largest;
return topk_radix_impl<argTy, IndexTy>(
exec_q, iter_nelems, axis_nelems, k, ascending, arg_cp, vals_cp,
inds_cp, iter_arg_offset, iter_vals_offset, iter_inds_offset,
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
}
else {
using CompTy =
typename dpctl::tensor::py_internal::AscendingSorter<argTy>::type;
return topk_impl<argTy, IndexTy, CompTy>(
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
iter_arg_offset, iter_vals_offset, iter_inds_offset,
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
using dpctl::tensor::kernels::topk_merge_impl;
if (largest) {
using CompTy =
typename dpctl::tensor::py_internal::DescendingSorter<
argTy>::type;
return topk_merge_impl<argTy, IndexTy, CompTy>(
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
iter_arg_offset, iter_vals_offset, iter_inds_offset,
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
}
else {
using CompTy = typename dpctl::tensor::py_internal::AscendingSorter<
argTy>::type;
return topk_merge_impl<argTy, IndexTy, CompTy>(
exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
iter_arg_offset, iter_vals_offset, iter_inds_offset,
axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
}
}
}

Expand Down

0 comments on commit a56e21c

Please sign in to comment.