diff --git a/dpctl/tensor/libtensor/source/sorting/topk.cpp b/dpctl/tensor/libtensor/source/sorting/topk.cpp index dea20fd494..7a7e5b4782 100644 --- a/dpctl/tensor/libtensor/source/sorting/topk.cpp +++ b/dpctl/tensor/libtensor/source/sorting/topk.cpp @@ -110,9 +110,9 @@ sycl::event topk_caller(sycl::queue &exec_q, py::ssize_t axis_inds_offset, const std::vector &depends) { - if constexpr (use_radix_sort::value) { + if (use_radix_sort::value && (axis_nelems >= 16384)) { using dpctl::tensor::kernels::topk_radix_impl; - auto ascending = !largest; + const auto ascending = !largest; return topk_radix_impl( exec_q, iter_nelems, axis_nelems, k, ascending, arg_cp, vals_cp, inds_cp, iter_arg_offset, iter_vals_offset, iter_inds_offset,