diff --git a/dpctl/tensor/libtensor/source/sorting/topk.cpp b/dpctl/tensor/libtensor/source/sorting/topk.cpp index dea20fd494..544f0d6f5e 100644 --- a/dpctl/tensor/libtensor/source/sorting/topk.cpp +++ b/dpctl/tensor/libtensor/source/sorting/topk.cpp @@ -112,13 +112,17 @@ sycl::event topk_caller(sycl::queue &exec_q, { if constexpr (use_radix_sort::value) { using dpctl::tensor::kernels::topk_radix_impl; - auto ascending = !largest; - return topk_radix_impl( - exec_q, iter_nelems, axis_nelems, k, ascending, arg_cp, vals_cp, - inds_cp, iter_arg_offset, iter_vals_offset, iter_inds_offset, - axis_arg_offset, axis_vals_offset, axis_inds_offset, depends); + const auto ascending = !largest; + + if (axis_nelems >= 16384) { + return topk_radix_impl( + exec_q, iter_nelems, axis_nelems, k, ascending, arg_cp, vals_cp, + inds_cp, iter_arg_offset, iter_vals_offset, iter_inds_offset, + axis_arg_offset, axis_vals_offset, axis_inds_offset, depends); + } } - else { + + { using dpctl::tensor::kernels::topk_merge_impl; if (largest) { using CompTy =