diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp
index f709a1737d..11f3e851fb 100644
--- a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp
@@ -35,6 +35,7 @@
 #include "kernels/dpctl_tensor_types.hpp"
 #include "merge_sort.hpp"
+#include "radix_sort.hpp"
 #include "utils/sycl_alloc_utils.hpp"
 
 #include <...>
@@ -70,31 +71,25 @@ void scale_topk_params(const std::uint64_t nelems_per_slm,
 } // namespace topk_detail
 
 template <typename T1, typename T2, typename T3>
-class populate_index_data_full_sort_krn;
+class topk_populate_index_data_krn;
 
 template <typename T1, typename T2, typename T3>
-class topk_map_to_rows_full_sort_krn;
-
-template <typename T1, typename T2, typename T3> class populate_index_data_krn;
-
-template <typename T1, typename T2, typename T3> class topk_map_to_rows_krn;
+class topk_full_merge_map_back_krn;
 
 template <typename argTy, typename IndexTy, typename CompT>
-sycl::event topk_full_sort_impl(
-    sycl::queue &exec_q,
-    std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows in a
-                             // matrix when sorting over rows)
-    std::size_t sort_nelems, // size of each array to sort (length of rows,
-                             // i.e. number of columns)
-    std::size_t k,
-    const argTy *arg_tp,
-    argTy *vals_tp,
-    IndexTy *inds_tp,
-    const CompT &comp,
-    const std::vector<sycl::event> &depends)
+sycl::event
+topk_full_merge_sort_impl(sycl::queue &exec_q,
+                          std::size_t iter_nelems, // number of sub-arrays
+                          std::size_t axis_nelems, // size of each sub-array
+                          std::size_t k,
+                          const argTy *arg_tp,
+                          argTy *vals_tp,
+                          IndexTy *inds_tp,
+                          const CompT &comp,
+                          const std::vector<sycl::event> &depends)
 {
     IndexTy *index_data =
-        sycl::malloc_device<IndexTy>(iter_nelems * sort_nelems, exec_q);
+        sycl::malloc_device<IndexTy>(iter_nelems * axis_nelems, exec_q);
     if (index_data == nullptr) {
         throw std::runtime_error("Unable to allocate device_memory");
     }
@@ -103,10 +98,10 @@ sycl::event topk_full_sort_impl(
         exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(depends);
 
-            auto const &range = sycl::range<1>(iter_nelems * sort_nelems);
+            auto const &range = sycl::range<1>(iter_nelems * axis_nelems);
 
             using KernelName =
-                populate_index_data_full_sort_krn<argTy, IndexTy, CompT>;
+                topk_populate_index_data_krn<argTy, IndexTy, CompT>;
 
             cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
                 std::size_t i = id[0];
@@ -118,21 +113,20 @@ sycl::event topk_full_sort_impl(
     // Sort segments of the array
     sycl::event base_sort_ev =
         merge_sort_detail::sort_over_work_group_contig_impl(
-            exec_q, iter_nelems, sort_nelems, index_data, index_data, comp,
+            exec_q, iter_nelems, axis_nelems, index_data, index_data, comp,
             sorted_block_size, // modified in place with size of sorted block
                                // size
             {populate_indexed_data_ev});
 
     // Merge segments in parallel until all elements are sorted
     sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl(
-        exec_q, iter_nelems, sort_nelems, index_data, comp, sorted_block_size,
+        exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size,
         {base_sort_ev});
 
     sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(merges_ev);
 
-        using KernelName =
-            topk_map_to_rows_full_sort_krn<argTy, IndexTy, CompT>;
+        using KernelName = topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
 
         cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
             std::size_t gid = id[0];
@@ -140,12 +134,12 @@
             std::size_t iter_gid = gid / k;
             std::size_t axis_gid = gid - (iter_gid * k);
 
-            std::size_t src_idx = iter_gid * sort_nelems + axis_gid;
+            std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
             std::size_t dst_idx = iter_gid * k + axis_gid;
 
             auto res_ind = index_data[src_idx];
             vals_tp[dst_idx] = arg_tp[res_ind];
-            inds_tp[dst_idx] = res_ind % sort_nelems;
+            inds_tp[dst_idx] = res_ind % axis_nelems;
         });
     });
 
@@ -162,29 +156,32 @@
     return cleanup_host_task_event;
 };
 
+template <typename T1, typename T2, typename T3>
+class topk_partial_merge_map_back_krn;
+
 template <typename T1, typename T2, typename T3, typename T4>
 class topk_over_work_group_krn;
 
 template <typename argTy,
           typename IndexTy,
           typename ValueComp = std::less<argTy>>
-sycl::event
-topk_impl(sycl::queue &exec_q,
-          std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
-                                   // in a matrix when sorting over rows)
-          std::size_t axis_nelems, // size of each array to sort (length of
-                                   // rows, i.e. number of columns)
-          std::size_t k,
-          const char *arg_cp,
-          char *vals_cp,
-          char *inds_cp,
-          dpctl::tensor::ssize_t iter_arg_offset,
-          dpctl::tensor::ssize_t iter_vals_offset,
-          dpctl::tensor::ssize_t iter_inds_offset,
-          dpctl::tensor::ssize_t axis_arg_offset,
-          dpctl::tensor::ssize_t axis_vals_offset,
-          dpctl::tensor::ssize_t axis_inds_offset,
-          const std::vector<sycl::event> &depends)
+sycl::event topk_merge_impl(
+    sycl::queue &exec_q,
+    std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
+                             // in a matrix when sorting over rows)
+    std::size_t axis_nelems, // size of each array to sort (length of
+                             // rows, i.e. number of columns)
+    std::size_t k,
+    const char *arg_cp,
+    char *vals_cp,
+    char *inds_cp,
+    dpctl::tensor::ssize_t iter_arg_offset,
+    dpctl::tensor::ssize_t iter_vals_offset,
+    dpctl::tensor::ssize_t iter_inds_offset,
+    dpctl::tensor::ssize_t axis_arg_offset,
+    dpctl::tensor::ssize_t axis_vals_offset,
+    dpctl::tensor::ssize_t axis_inds_offset,
+    const std::vector<sycl::event> &depends)
 {
     if (axis_nelems < k) {
         throw std::runtime_error("Invalid sort axis size for value of k");
@@ -201,8 +198,9 @@ topk_impl(sycl::queue &exec_q,
     const IndexComp index_comp{arg_tp, ValueComp{}};
 
     if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2) {
-        return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k, arg_tp,
-                                   vals_tp, inds_tp, index_comp, depends);
+        return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, k,
+                                         arg_tp, vals_tp, inds_tp, index_comp,
+                                         depends);
     }
     else {
         using PartialKernelName =
@@ -269,9 +267,9 @@ topk_impl(sycl::queue &exec_q,
         if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size ||
             alloc_len >= axis_nelems / 2)
         {
-            return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k,
-                                       arg_tp, vals_tp, inds_tp, index_comp,
-                                       depends);
+            return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems,
+                                             k, arg_tp, vals_tp, inds_tp,
+                                             index_comp, depends);
         }
 
         IndexTy *index_data =
@@ -399,7 +397,8 @@ topk_impl(sycl::queue &exec_q,
         sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(merges_ev);
 
-            using KernelName = topk_map_to_rows_krn<argTy, IndexTy, IndexComp>;
+            using KernelName =
+                topk_partial_merge_map_back_krn<argTy, IndexTy, IndexComp>;
 
             cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
                 std::size_t gid = id[0];
@@ -430,6 +429,109 @@
     }
 }
 
+template <typename T1, typename T2> class topk_iota_krn;
+
+template <typename T1, typename T2> class topk_radix_map_back_krn;
+
+template <typename argTy, typename IndexTy>
+sycl::event topk_radix_impl(sycl::queue &exec_q,
+                            std::size_t iter_nelems, // number of sub-arrays
+                            std::size_t axis_nelems, // size of each sub-array
+                            std::size_t k,
+                            bool ascending,
+                            const char *arg_cp,
+                            char *vals_cp,
+                            char *inds_cp,
+                            dpctl::tensor::ssize_t iter_arg_offset,
+                            dpctl::tensor::ssize_t iter_vals_offset,
+                            dpctl::tensor::ssize_t iter_inds_offset,
+                            dpctl::tensor::ssize_t axis_arg_offset,
+                            dpctl::tensor::ssize_t axis_vals_offset,
+                            dpctl::tensor::ssize_t axis_inds_offset,
+                            const std::vector<sycl::event> &depends)
+{
+    if (axis_nelems < k) {
+        throw std::runtime_error("Invalid sort axis size for value of k");
+    }
+
+    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
+                          iter_arg_offset + axis_arg_offset;
+    argTy *vals_tp = reinterpret_cast<argTy *>(vals_cp) + iter_vals_offset +
+                     axis_vals_offset;
+    IndexTy *inds_tp = reinterpret_cast<IndexTy *>(inds_cp) + iter_inds_offset +
+                       axis_inds_offset;
+
+    const std::size_t total_nelems = iter_nelems * axis_nelems;
+    const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
+    IndexTy *workspace = sycl::malloc_device<IndexTy>(
+        padded_total_nelems + total_nelems, exec_q);
+
+    IndexTy *tmp_tp = sycl::malloc_device<IndexTy>(total_nelems, exec_q);
+
+    if (nullptr == workspace || nullptr == tmp_tp) {
+        throw std::runtime_error(
+            "Not enough device memory for radix sort topk");
+    }
+
+    using IdentityProjT = radix_sort_details::IdentityProj;
+    using IndexedProjT =
+        radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
+    const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
+
+    sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        using KernelName = topk_iota_krn<argTy, IndexTy>;
+
+        cgh.parallel_for<KernelName>(
+            sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
+                size_t i = id[0];
+                IndexTy sort_id = static_cast<IndexTy>(i);
+                workspace[i] = sort_id;
+            });
+    });
+
+    sycl::event radix_sort_ev =
+        radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
+            exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op,
+            ascending, {iota_ev});
+
+    // Write out top k of the temporary
+    sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(radix_sort_ev);
+
+        using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;
+
+        cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
+            std::size_t gid = id[0];
+
+            std::size_t iter_gid = gid / k;
+            std::size_t axis_gid = gid - (iter_gid * k);
+
+            std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
+            std::size_t dst_idx = iter_gid * k + axis_gid;
+
+            IndexTy res_ind = tmp_tp[src_idx];
+            vals_tp[dst_idx] = arg_tp[res_ind];
+            inds_tp[dst_idx] = res_ind % axis_nelems;
+        });
+    });
+
+    sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(write_topk_ev);
+
+        const sycl::context &ctx = exec_q.get_context();
+
+        using dpctl::tensor::alloc_utils::sycl_free_noexcept;
+        cgh.host_task([ctx, workspace, tmp_tp] {
+            sycl_free_noexcept(workspace, ctx);
+            sycl_free_noexcept(tmp_tp, ctx);
+        });
+    });
+
+    return cleanup_ev;
+}
+
 } // end of namespace kernels
 } // end of namespace tensor
 } // end of namespace dpctl
diff --git a/dpctl/tensor/libtensor/source/sorting/topk.cpp b/dpctl/tensor/libtensor/source/sorting/topk.cpp
index fe5eac9acb..d3a9e951ac 100644
--- a/dpctl/tensor/libtensor/source/sorting/topk.cpp
+++ b/dpctl/tensor/libtensor/source/sorting/topk.cpp
@@ -76,6 +76,23 @@ static topk_impl_fn_ptr_t topk_dispatch_vector[td_ns::num_types];
 namespace
 {
 
+template <typename T, typename = void>
+struct use_radix_sort : public std::false_type
+{
+};
+
+template <typename T>
+struct use_radix_sort<
+    T,
+    std::enable_if_t<std::disjunction<std::is_same<T, bool>,
+                                      std::is_same<T, std::uint8_t>,
+                                      std::is_same<T, std::int8_t>,
+                                      std::is_same<T, std::uint16_t>,
+                                      std::is_same<T, std::int16_t>>::value>>
+    : public std::true_type
+{
+};
+
 template <typename argTy, typename IndexTy>
 sycl::event
 topk_caller(sycl::queue &exec_q,
@@ -96,22 +113,33 @@ topk_caller(sycl::queue &exec_q,
             py::ssize_t axis_inds_offset,
             const std::vector<sycl::event> &depends)
 {
-    using dpctl::tensor::kernels::topk_impl;
-    if (largest) {
-        using CompTy =
-            typename dpctl::tensor::py_internal::DescendingSorter<argTy>::type;
-        return topk_impl<argTy, IndexTy, CompTy>(
-            exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
-            iter_arg_offset, iter_vals_offset, iter_inds_offset,
+    if constexpr (use_radix_sort<argTy>::value) {
+        using dpctl::tensor::kernels::topk_radix_impl;
+        auto ascending = !largest;
+        return topk_radix_impl<argTy, IndexTy>(
+            exec_q, iter_nelems, axis_nelems, k, ascending, arg_cp, vals_cp,
+            inds_cp, iter_arg_offset, iter_vals_offset, iter_inds_offset,
             axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
     }
     else {
-        using CompTy =
-            typename dpctl::tensor::py_internal::AscendingSorter<argTy>::type;
-        return topk_impl<argTy, IndexTy, CompTy>(
-            exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
-            iter_arg_offset, iter_vals_offset, iter_inds_offset,
-            axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
+        using dpctl::tensor::kernels::topk_merge_impl;
+        if (largest) {
+            using CompTy =
+                typename dpctl::tensor::py_internal::DescendingSorter<
+                    argTy>::type;
+            return topk_merge_impl<argTy, IndexTy, CompTy>(
+                exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
+                iter_arg_offset, iter_vals_offset, iter_inds_offset,
+                axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
+        }
+        else {
+            using CompTy = typename dpctl::tensor::py_internal::AscendingSorter<
+                argTy>::type;
+            return topk_merge_impl<argTy, IndexTy, CompTy>(
+                exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
+                iter_arg_offset, iter_vals_offset, iter_inds_offset,
+                axis_arg_offset, axis_vals_offset, axis_inds_offset, depends);
+        }
     }
 }
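
Note: both the merge-sort and radix-sort paths in this patch follow the same scheme: populate one flat index array over all rows, argsort it against the values, then gather the first k indices of each row, reducing each flat index modulo axis_nelems back to a within-row position. For reference, a minimal serial host-side sketch of that scheme; the top_k_per_row helper below is hypothetical and not part of dpctl.

// Hypothetical illustration (not dpctl code): serial analogue of the
// populate-index -> argsort -> map-back scheme used by the kernels above.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

template <typename T>
std::pair<std::vector<T>, std::vector<std::size_t>>
top_k_per_row(const std::vector<T> &a, // iter_nelems * axis_nelems values
              std::size_t iter_nelems, // number of rows
              std::size_t axis_nelems, // length of each row
              std::size_t k,
              bool largest)
{
    // Populate one flat index space over all rows, as the populate/iota
    // kernels do on device.
    std::vector<std::size_t> idx(iter_nelems * axis_nelems);
    std::iota(idx.begin(), idx.end(), std::size_t{0});

    std::vector<T> vals(iter_nelems * k);
    std::vector<std::size_t> inds(iter_nelems * k);

    for (std::size_t row = 0; row < iter_nelems; ++row) {
        const auto first = idx.begin() + row * axis_nelems;
        const auto last = first + axis_nelems;
        // Order indices by the values they point at; partial_sort is the
        // serial stand-in for the device-wide merge/radix argsort.
        std::partial_sort(first, first + k, last,
                          [&](std::size_t i, std::size_t j) {
                              return largest ? a[j] < a[i] : a[i] < a[j];
                          });
        for (std::size_t j = 0; j < k; ++j) {
            const std::size_t res_ind = first[j];
            vals[row * k + j] = a[res_ind];
            // Flat index -> within-row index, mirroring the
            // `res_ind % axis_nelems` step in the map-back kernels.
            inds[row * k + j] = res_ind % axis_nelems;
        }
    }
    return {std::move(vals), std::move(inds)};
}

Because row r owns the flat indices [r * axis_nelems, (r + 1) * axis_nelems), the final modulo recovers the column without any per-row bookkeeping.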
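The dispatch added in topk.cpp is resolved at compile time: use_radix_sort<T> is true only for bool and the 8/16-bit integer types, and if constexpr discards the non-taken branch, so each instantiation of topk_caller compiles exactly one sort family. A self-contained sketch of that pattern under the same assumptions; topk_backend is a placeholder standing in for topk_caller's kernel launches.

// Hypothetical illustration (not dpctl code): trait + if-constexpr dispatch,
// with a string result standing in for the actual kernel launch.
#include <cstdint>
#include <type_traits>

template <typename T, typename = void>
struct use_radix_sort : public std::false_type
{
};

// Radix path is selected only for bool and 8/16-bit integers, presumably
// because only a few radix passes are needed for such narrow keys.
template <typename T>
struct use_radix_sort<
    T,
    std::enable_if_t<std::disjunction<std::is_same<T, bool>,
                                      std::is_same<T, std::uint8_t>,
                                      std::is_same<T, std::int8_t>,
                                      std::is_same<T, std::uint16_t>,
                                      std::is_same<T, std::int16_t>>::value>>
    : public std::true_type
{
};

template <typename argTy> const char *topk_backend(bool largest)
{
    if constexpr (use_radix_sort<argTy>::value) {
        // The radix sorter takes an ascending flag rather than a comparator,
        // hence `ascending = !largest` in the patch.
        const bool ascending = !largest;
        return ascending ? "radix, ascending" : "radix, descending";
    }
    else {
        // The merge-sort path instead instantiates an ascending or
        // descending comparator type.
        return largest ? "merge, descending comparator"
                       : "merge, ascending comparator";
    }
}

// topk_backend<std::int16_t>(true) -> "radix, descending"
// topk_backend<float>(true)        -> "merge, descending comparator"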