diff --git a/CHANGELOG.md b/CHANGELOG.md index fec5e79f17..e7138a6a73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +* Added `dpctl.tensor.top_k` per Python Array API specification: [#1921](https://github.com/IntelPython/dpctl/pull/1921) + ### Changed * Improved performance of copy-and-cast operations from `numpy.ndarray` to `tensor.usm_ndarray` for contiguous inputs [gh-1829](https://github.com/IntelPython/dpctl/pull/1829) diff --git a/docs/doc_sources/api_reference/dpctl/tensor.sorting_functions.rst b/docs/doc_sources/api_reference/dpctl/tensor.sorting_functions.rst index ae1605d988..ef20f4654c 100644 --- a/docs/doc_sources/api_reference/dpctl/tensor.sorting_functions.rst +++ b/docs/doc_sources/api_reference/dpctl/tensor.sorting_functions.rst @@ -10,3 +10,4 @@ Sorting functions argsort sort + top_k diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 882233b30c..75976f1b1f 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -115,6 +115,7 @@ set(_sorting_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp ) set(_sorting_radix_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 581c3d54f6..490cdfd23c 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -199,7 +199,7 @@ unique_inverse, unique_values, ) -from ._sorting import argsort, sort +from ._sorting import argsort, sort, top_k from ._testing import allclose from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type @@ -387,4 +387,5 @@ "DLDeviceType", "take_along_axis", "put_along_axis", + "top_k", ] diff --git a/dpctl/tensor/_sorting.py b/dpctl/tensor/_sorting.py index d5026a6ee8..19cd2f2836 100644 --- a/dpctl/tensor/_sorting.py +++ b/dpctl/tensor/_sorting.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import operator +from typing import NamedTuple + import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti import dpctl.utils as du @@ -24,6 +27,7 @@ _argsort_descending, _sort_ascending, _sort_descending, + _topk, ) from ._tensor_sorting_radix_impl import ( _radix_argsort_ascending, @@ -267,3 +271,166 @@ def argsort(x, axis=-1, descending=False, stable=True, kind=None): inv_perm = sorted(range(nd), key=lambda d: perm[d]) res = dpt.permute_dims(res, inv_perm) return res + + +def _get_top_k_largest(mode): + modes = {"largest": True, "smallest": False} + try: + return modes[mode] + except KeyError: + raise ValueError( + f"`mode` must be `largest` or `smallest`. Got `{mode}`." + ) + + +class TopKResult(NamedTuple): + values: dpt.usm_ndarray + indices: dpt.usm_ndarray + + +def top_k(x, k, /, *, axis=None, mode="largest"): + """top_k(x, k, axis=None, mode="largest") + + Returns the `k` largest or smallest values and their indices in the input + array `x` along the specified axis `axis`. + + Args: + x (usm_ndarray): + input array. + k (int): + number of elements to find. Must be a positive integer value. + axis (Optional[int]): + axis along which to search. If `None`, the search will be performed + over the flattened array. Default: ``None``. 
+ mode (Literal["largest", "smallest"]): + search mode. Must be one of the following modes: + + - `"largest"`: return the `k` largest elements. + - `"smallest"`: return the `k` smallest elements. + + Default: `"largest"`. + + Returns: + tuple[usm_ndarray, usm_ndarray] + a namedtuple `(values, indices)` whose + + * first element `values` will be an array containing the `k` + largest or smallest elements of `x`. The array has the same data + type as `x`. If `axis` was `None`, `values` will be a + one-dimensional array with shape `(k,)` and otherwise, `values` + will have shape `x.shape[:axis] + (k,) + x.shape[axis+1:]` + * second element `indices` will be an array containing indices of + `x` that result in `values`. The array will have the same shape + as `values` and will have the default array index data type. + """ + largest = _get_top_k_largest(mode) + if not isinstance(x, dpt.usm_ndarray): + raise TypeError( + f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}" + ) + + k = operator.index(k) + if k < 0: + raise ValueError("`k` must be a positive integer value") + + nd = x.ndim + if axis is None: + sz = x.size + if nd == 0: + if k > 1: + raise ValueError(f"`k`={k} is out of bounds 1") + return TopKResult( + dpt.copy(x, order="C"), + dpt.zeros_like( + x, dtype=ti.default_device_index_type(x.sycl_queue) + ), + ) + arr = x + n_search_dims = None + res_sh = k + else: + axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis") + sz = x.shape[axis] + a1 = axis + 1 + if a1 == nd: + perm = list(range(nd)) + arr = x + else: + perm = [i for i in range(nd) if i != axis] + [ + axis, + ] + arr = dpt.permute_dims(x, perm) + n_search_dims = 1 + res_sh = arr.shape[: nd - 1] + (k,) + + if k > sz: + raise ValueError(f"`k`={k} is out of bounds {sz}") + + exec_q = x.sycl_queue + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + res_usm_type = arr.usm_type + if arr.flags.c_contiguous: + vals = dpt.empty( + res_sh, + dtype=arr.dtype, + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + inds = dpt.empty( + res_sh, + dtype=ti.default_device_index_type(exec_q), + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + ht_ev, impl_ev = _topk( + src=arr, + trailing_dims_to_search=n_search_dims, + k=k, + largest=largest, + vals=vals, + inds=inds, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, impl_ev) + else: + tmp = dpt.empty_like(arr, order="C") + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + vals = dpt.empty( + res_sh, + dtype=arr.dtype, + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + inds = dpt.empty( + res_sh, + dtype=ti.default_device_index_type(exec_q), + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + ht_ev, impl_ev = _topk( + src=tmp, + trailing_dims_to_search=n_search_dims, + k=k, + largest=largest, + vals=vals, + inds=inds, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, impl_ev) + if axis is not None and a1 != nd: + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + vals = dpt.permute_dims(vals, inv_perm) + inds = dpt.permute_dims(inds, inv_perm) + + return TopKResult(vals, inds) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp index f3b5030c48..dbf40c10fe 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp +++ 
b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp @@ -33,6 +33,7 @@ #include "kernels/dpctl_tensor_types.hpp" #include "kernels/sorting/search_sorted_detail.hpp" +#include "kernels/sorting/sort_utils.hpp" namespace dpctl { @@ -811,20 +812,12 @@ sycl::event stable_argsort_axis1_contig_impl( const size_t total_nelems = iter_nelems * sort_nelems; - sycl::event populate_indexed_data_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; - const sycl::range<1> range{total_nelems}; + using IotaKernelName = populate_index_data_krn; - using KernelName = - populate_index_data_krn; - - cgh.parallel_for(range, [=](sycl::id<1> id) { - size_t i = id[0]; - res_tp[i] = static_cast(i); - }); - }); + sycl::event populate_indexed_data_ev = iota_impl( + exec_q, res_tp, total_nelems, depends); // Sort segments of the array sycl::event base_sort_ev = @@ -839,21 +832,11 @@ sycl::event stable_argsort_axis1_contig_impl( exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size, {base_sort_ev}); - sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(merges_ev); - - auto temp_acc = - merge_sort_detail::GetReadOnlyAccess{}(res_tp, - cgh); - - using KernelName = index_map_to_rows_krn; + using MapBackKernelName = index_map_to_rows_krn; + using dpctl::tensor::kernels::sort_utils_detail::map_back_impl; - const sycl::range<1> range{total_nelems}; - - cgh.parallel_for(range, [=](sycl::id<1> id) { - res_tp[id] = (temp_acc[id] % sort_nelems); - }); - }); + sycl::event write_out_ev = map_back_impl( + exec_q, total_nelems, res_tp, res_tp, sort_nelems, {merges_ev}); return write_out_ev; } diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp index dc3da24315..15f22b334e 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp @@ -38,6 +38,7 @@ #include #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/sort_utils.hpp" #include "utils/sycl_alloc_utils.hpp" namespace dpctl @@ -62,6 +63,48 @@ class radix_sort_reorder_peer_kernel; template class radix_sort_reorder_kernel; +/*! 
@brief Computes smallest exponent such that `n <= (1 << exponent)` */ +template && + sizeof(SizeT) == sizeof(std::uint64_t), + int> = 0> +std::uint32_t ceil_log2(SizeT n) +{ + // if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b + // floor_log2(q * 2^b + r) == floor_log2(q * 2^b) == q + floor_log2(n1) + // ceil_log2(n) == 1 + floor_log2(n-1) + if (n <= 1) + return std::uint32_t{1}; + + std::uint32_t exp{1}; + --n; + if (n >= (SizeT{1} << 32)) { + n >>= 32; + exp += 32; + } + if (n >= (SizeT{1} << 16)) { + n >>= 16; + exp += 16; + } + if (n >= (SizeT{1} << 8)) { + n >>= 8; + exp += 8; + } + if (n >= (SizeT{1} << 4)) { + n >>= 4; + exp += 4; + } + if (n >= (SizeT{1} << 2)) { + n >>= 2; + exp += 2; + } + if (n >= (SizeT{1} << 1)) { + n >>= 1; + ++exp; + } + return exp; +} + //---------------------------------------------------------- // bitwise order-preserving conversions to unsigned integers //---------------------------------------------------------- @@ -95,16 +138,20 @@ template order_preserving_cast(IntT val) { using UIntT = std::make_unsigned_t; - // ascending_mask: 100..0 - constexpr UIntT ascending_mask = - (UIntT(1) << std::numeric_limits::digits); - // descending_mask: 011..1 - constexpr UIntT descending_mask = (std::numeric_limits::max() >> 1); - - constexpr UIntT mask = (is_ascending) ? ascending_mask : descending_mask; const UIntT uint_val = sycl::bit_cast(val); - return (uint_val ^ mask); + if constexpr (is_ascending) { + // ascending_mask: 100..0 + constexpr UIntT ascending_mask = + (UIntT(1) << std::numeric_limits::digits); + return (uint_val ^ ascending_mask); + } + else { + // descending_mask: 011..1 + constexpr UIntT descending_mask = + (std::numeric_limits::max() >> 1); + return (uint_val ^ descending_mask); + } } template std::uint16_t order_preserving_cast(sycl::half val) @@ -1003,10 +1050,10 @@ template class radix_sort_one_wg_krn; template + std::uint16_t req_sub_group_size = (block_size < 4 ? 
32 : 16)> struct subgroup_radix_sort { private: @@ -1020,8 +1067,8 @@ struct subgroup_radix_sort public: template sycl::event operator()(sycl::queue &exec_q, - size_t n_iters, - size_t n_to_sort, + std::size_t n_iters, + std::size_t n_to_sort, ValueT *input_ptr, OutputT *output_ptr, ProjT proj_op, @@ -1118,8 +1165,8 @@ struct subgroup_radix_sort }; static_assert(wg_size <= 1024); - static constexpr uint16_t bin_count = (1 << radix); - static constexpr uint16_t counter_buf_sz = wg_size * bin_count + 1; + static constexpr std::uint16_t bin_count = (1 << radix); + static constexpr std::uint16_t counter_buf_sz = wg_size * bin_count + 1; enum class temp_allocations { @@ -1135,7 +1182,7 @@ struct subgroup_radix_sort assert(n <= (SizeT(1) << 16)); constexpr auto req_slm_size_counters = - counter_buf_sz * sizeof(uint32_t); + counter_buf_sz * sizeof(std::uint16_t); const auto &dev = exec_q.get_device(); @@ -1144,7 +1191,7 @@ struct subgroup_radix_sort const std::size_t max_slm_size = dev.template get_info() / 2; - const auto n_uniform = 1 << (std::uint32_t(std::log2(n - 1)) + 1); + const auto n_uniform = 1 << ceil_log2(n); const auto req_slm_size_val = sizeof(T) * n_uniform; return ((req_slm_size_val + req_slm_size_counters) <= max_slm_size) @@ -1170,9 +1217,9 @@ struct subgroup_radix_sort typename SLM_value_tag, typename SLM_counter_tag> sycl::event operator()(sycl::queue &exec_q, - size_t n_iters, - size_t n_batch_size, - size_t n_values, + std::size_t n_iters, + std::size_t n_batch_size, + std::size_t n_values, InputT *input_arr, OutputT *output_arr, const ProjT &proj_op, @@ -1186,7 +1233,7 @@ struct subgroup_radix_sort assert(n_values <= static_cast(block_size) * static_cast(wg_size)); - uint16_t n = static_cast(n_values); + const std::uint16_t n = static_cast(n_values); static_assert(std::is_same_v, OutputT>); using ValueT = OutputT; @@ -1195,17 +1242,36 @@ struct subgroup_radix_sort TempBuf buf_val( n_batch_size, static_cast(block_size * wg_size)); - TempBuf buf_count( + TempBuf buf_count( n_batch_size, static_cast(counter_buf_sz)); sycl::range<1> lRange{wg_size}; sycl::event sort_ev; - std::vector deps = depends; + std::vector deps{depends}; + + const std::size_t n_batches = + (n_iters + n_batch_size - 1) / n_batch_size; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + const auto &krn = kb.get_kernel(kernel_id); + + const std::uint32_t krn_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); - std::size_t n_batches = (n_iters + n_batch_size - 1) / n_batch_size; + // due to a bug in CPU device implementation, an additional + // synchronization is necessary for short sub-group sizes + const bool work_around_needed = + exec_q.get_device().has(sycl::aspect::cpu) && + (krn_sg_size < 16); - for (size_t batch_id = 0; batch_id < n_batches; ++batch_id) { + for (std::size_t batch_id = 0; batch_id < n_batches; ++batch_id) { const std::size_t block_start = batch_id * n_batch_size; @@ -1221,6 +1287,7 @@ struct subgroup_radix_sort sort_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(deps); + cgh.use_kernel_bundle(kb); // allocation to use for value exchanges auto exchange_acc = buf_val.get_acc(cgh); @@ -1244,49 +1311,49 @@ struct subgroup_radix_sort const std::size_t iter_exchange_offset = iter_id * exchange_acc_iter_stride; - uint16_t wi = ndit.get_local_linear_id(); - uint16_t begin_bit = 0; + 
std::uint16_t wi = ndit.get_local_linear_id(); + std::uint16_t begin_bit = 0; - constexpr uint16_t end_bit = + constexpr std::uint16_t end_bit = number_of_bits_in_type(); -// copy from input array into values + // copy from input array into values #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t id = wi * block_size + i; - if (id < n) - values[i] = std::move( - this_input_arr[iter_val_offset + - static_cast( - id)]); + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t id = wi * block_size + i; + values[i] = + (id < n) ? this_input_arr[iter_val_offset + id] + : ValueT{}; } while (true) { // indices for indirect access in the "re-order" // phase - uint16_t indices[block_size]; + std::uint16_t indices[block_size]; { // pointers to bucket's counters - uint32_t *counters[block_size]; + std::uint16_t *counters[block_size]; // counting phase auto pcounter = get_accessor_pointer(counter_acc) + - static_cast(wi) + - iter_counter_offset; + (wi + iter_counter_offset); -// initialize counters + // initialize counters #pragma unroll - for (uint16_t i = 0; i < bin_count; ++i) - pcounter[i * wg_size] = std::uint32_t{0}; + for (std::uint16_t i = 0; i < bin_count; ++i) + pcounter[i * wg_size] = std::uint16_t{0}; sycl::group_barrier(ndit.get_group()); if (is_ascending) { #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t id = wi * block_size + i; - constexpr uint16_t bin_mask = + for (std::uint16_t i = 0; i < block_size; + ++i) + { + const std::uint16_t id = + wi * block_size + i; + constexpr std::uint16_t bin_mask = bin_count - 1; // points to the padded element, i.e. id @@ -1295,7 +1362,7 @@ struct subgroup_radix_sort default_out_of_range_bin_id = bin_mask; - const uint16_t bin = + const std::uint16_t bin = (id < n) ? get_bucket_id( order_preserving_cast< @@ -1309,13 +1376,21 @@ struct subgroup_radix_sort counters[i] = &pcounter[bin * wg_size]; indices[i] = *counters[i]; *counters[i] = indices[i] + 1; + + if (work_around_needed) { + sycl::group_barrier( + ndit.get_group()); + } } } else { #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t id = wi * block_size + i; - constexpr uint16_t bin_mask = + for (std::uint16_t i = 0; i < block_size; + ++i) + { + const std::uint16_t id = + wi * block_size + i; + constexpr std::uint16_t bin_mask = bin_count - 1; // points to the padded element, i.e. id @@ -1324,7 +1399,7 @@ struct subgroup_radix_sort default_out_of_range_bin_id = bin_mask; - const uint16_t bin = + const std::uint16_t bin = (id < n) ? 
get_bucket_id( order_preserving_cast< @@ -1338,6 +1413,11 @@ struct subgroup_radix_sort counters[i] = &pcounter[bin * wg_size]; indices[i] = *counters[i]; *counters[i] = indices[i] + 1; + + if (work_around_needed) { + sycl::group_barrier( + ndit.get_group()); + } } } @@ -1347,37 +1427,32 @@ struct subgroup_radix_sort { // scan contiguous numbers - uint16_t bin_sum[bin_count]; - bin_sum[0] = - counter_acc[iter_counter_offset + - static_cast( - wi * bin_count)]; + std::uint16_t bin_sum[bin_count]; + const std::size_t counter_offset0 = + iter_counter_offset + wi * bin_count; + bin_sum[0] = counter_acc[counter_offset0]; #pragma unroll - for (uint16_t i = 1; i < bin_count; ++i) + for (std::uint16_t i = 1; i < bin_count; + ++i) bin_sum[i] = bin_sum[i - 1] + - counter_acc - [iter_counter_offset + - static_cast( - wi * bin_count + i)]; + counter_acc[counter_offset0 + i]; sycl::group_barrier(ndit.get_group()); // exclusive scan local sum - uint16_t sum_scan = + std::uint16_t sum_scan = sycl::exclusive_scan_over_group( ndit.get_group(), bin_sum[bin_count - 1], - sycl::plus()); + sycl::plus()); // add to local sum, generate exclusive scan result #pragma unroll - for (uint16_t i = 0; i < bin_count; ++i) - counter_acc[iter_counter_offset + - static_cast( - wi * bin_count + i + - 1)] = + for (std::uint16_t i = 0; i < bin_count; + ++i) + counter_acc[counter_offset0 + i + 1] = sum_scan + bin_sum[i]; if (wi == 0) @@ -1388,11 +1463,13 @@ struct subgroup_radix_sort } #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { + for (std::uint16_t i = 0; i < block_size; ++i) { // a global index is a local offset plus a // global base index indices[i] += *counters[i]; } + + sycl::group_barrier(ndit.get_group()); } begin_bit += radix; @@ -1400,43 +1477,36 @@ struct subgroup_radix_sort // "re-order" phase sycl::group_barrier(ndit.get_group()); if (begin_bit >= end_bit) { -// the last iteration - writing out the result + // the last iteration - writing out the result #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t r = indices[i]; + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t r = indices[i]; if (r < n) { - // move the values to source range and - // destroy the values - this_output_arr - [iter_val_offset + - static_cast(r)] = - std::move(values[i]); + this_output_arr[iter_val_offset + r] = + values[i]; } } return; } -// data exchange + // data exchange #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t r = indices[i]; + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t r = indices[i]; if (r < n) - exchange_acc[iter_exchange_offset + - static_cast(r)] = - std::move(values[i]); + exchange_acc[iter_exchange_offset + r] = + values[i]; } sycl::group_barrier(ndit.get_group()); #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t id = wi * block_size + i; + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t id = wi * block_size + i; if (id < n) - values[i] = std::move( - exchange_acc[iter_exchange_offset + - static_cast( - id)]); + values[i] = + exchange_acc[iter_exchange_offset + id]; } sycl::group_barrier(ndit.get_group()); @@ -1601,11 +1671,11 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, using CountT = std::uint32_t; // memory for storing count and offset values - CountT *count_ptr = - sycl::malloc_device(n_iters * n_counts, exec_q); - if (nullptr == count_ptr) { - throw std::runtime_error("Could not allocate USM-device memory"); - } + auto count_owner = 
+ dpctl::tensor::alloc_utils::smart_malloc_device( + n_iters * n_counts, exec_q); + + CountT *count_ptr = count_owner.get(); constexpr std::uint32_t zero_radix_iter{0}; @@ -1618,25 +1688,17 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, n_counts, count_ptr, proj_op, is_ascending, depends); - sort_ev = exec_q.submit([=](sycl::handler &cgh) { - cgh.depends_on(sort_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); }); - }); + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, count_owner); return sort_ev; } - ValueT *tmp_arr = - sycl::malloc_device(n_iters * n_to_sort, exec_q); - if (nullptr == tmp_arr) { - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - sycl_free_noexcept(count_ptr, exec_q); - throw std::runtime_error("Could not allocate USM-device memory"); - } + auto tmp_arr_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + n_iters * n_to_sort, exec_q); + + ValueT *tmp_arr = tmp_arr_owner.get(); // iterations per each bucket assert("Number of iterations must be even" && radix_iters % 2 == 0); @@ -1670,17 +1732,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, } } - sort_ev = exec_q.submit([=](sycl::handler &cgh) { - cgh.depends_on(sort_ev); - - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, count_ptr, tmp_arr]() { - sycl_free_noexcept(tmp_arr, ctx); - sycl_free_noexcept(count_ptr, ctx); - }); - }); + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, tmp_arr_owner, count_owner); } return sort_ev; @@ -1725,10 +1778,10 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q, const bool sort_ascending, // number of sub-arrays to sort (num. of rows in a // matrix when sorting over rows) - size_t iter_nelems, + std::size_t iter_nelems, // size of each array to sort (length of rows, // i.e. number of columns) - size_t sort_nelems, + std::size_t sort_nelems, const char *arg_cp, char *res_cp, ssize_t iter_arg_offset, @@ -1764,10 +1817,10 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, const bool sort_ascending, // number of sub-arrays to sort (num. of // rows in a matrix when sorting over rows) - size_t iter_nelems, + std::size_t iter_nelems, // size of each array to sort (length of // rows, i.e. 
number of columns) - size_t sort_nelems, + std::size_t sort_nelems, const char *arg_cp, char *res_cp, ssize_t iter_arg_offset, @@ -1782,57 +1835,38 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; const std::size_t total_nelems = iter_nelems * sort_nelems; - const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; - IndexTy *workspace = sycl::malloc_device( - padded_total_nelems + total_nelems, exec_q); + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device(total_nelems, + exec_q); - if (nullptr == workspace) { - throw std::runtime_error("Could not allocate workspace on device"); - } + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); using IdentityProjT = radix_sort_details::IdentityProj; using IndexedProjT = radix_sort_details::IndexedProj; const IndexedProjT proj_op{arg_tp, IdentityProjT{}}; - sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using IotaKernelName = radix_argsort_iota_krn; - using KernelName = radix_argsort_iota_krn; + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; - cgh.parallel_for( - sycl::range<1>(total_nelems), [=](sycl::id<1> id) { - size_t i = id[0]; - IndexTy sort_id = static_cast(i); - workspace[i] = sort_id; - }); - }); + sycl::event iota_ev = iota_impl( + exec_q, workspace, total_nelems, depends); sycl::event radix_sort_ev = radix_sort_details::parallel_radix_sort_impl( exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op, sort_ascending, {iota_ev}); - sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(radix_sort_ev); + using MapBackKernelName = radix_argsort_index_write_out_krn; + using dpctl::tensor::kernels::sort_utils_detail::map_back_impl; - using KernelName = radix_argsort_index_write_out_krn; + sycl::event map_back_ev = map_back_impl( + exec_q, total_nelems, res_tp, res_tp, sort_nelems, {radix_sort_ev}); - cgh.parallel_for( - sycl::range<1>(total_nelems), [=](sycl::id<1> id) { - IndexTy linear_index = res_tp[id]; - res_tp[id] = (linear_index % sort_nelems); - }); - }); - - sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(map_back_ev); - - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); }); - }); + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {map_back_ev}, workspace_owner); return cleanup_ev; } diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp new file mode 100644 index 0000000000..d1f166f945 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp @@ -0,0 +1,145 @@ +//=== sorting.hpp - Implementation of sorting kernels ---*-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor sort/argsort operations. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include + +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace sort_utils_detail +{ + +namespace syclexp = sycl::ext::oneapi::experimental; + +template +sycl::event iota_impl(sycl::queue &exec_q, + T *data, + std::size_t nelems, + const std::vector &dependent_events) +{ + constexpr std::uint32_t lws = 256; + constexpr std::uint32_t n_wi = 4; + const std::size_t n_groups = (nelems + n_wi * lws - 1) / (n_wi * lws); + + sycl::range<1> gRange{n_groups * lws}; + sycl::range<1> lRange{lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_events); + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + + const std::size_t offset = (gid - lane_id) * n_wi; + const std::uint32_t max_sgSize = sg.get_max_local_range()[0]; + + std::array stripe{}; +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + stripe[i] = T(offset + lane_id + i * max_sgSize); + } + + if (offset + n_wi * max_sgSize < nelems) { + constexpr auto group_ls_props = syclexp::properties{ + syclexp::data_placement_striped + // , syclexp::full_group + }; + + auto out_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&data[offset]); + + syclexp::group_store(sg, sycl::span{&stripe[0], n_wi}, + out_multi_ptr, group_ls_props); + } + else { + for (std::size_t idx = offset + lane_id; idx < nelems; + idx += max_sgSize) + { + data[idx] = T(idx); + } + } + }); + }); + + return e; +} + +template +sycl::event map_back_impl(sycl::queue &exec_q, + std::size_t nelems, + const IndexTy *flat_index_data, + IndexTy *reduced_index_data, + std::size_t row_size, + const std::vector &dependent_events) +{ + constexpr std::uint32_t lws = 64; + constexpr std::uint32_t n_wi = 4; + const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws); + + sycl::range<1> lRange{lws}; + sycl::range<1> gRange{n_groups * lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_events); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_max_local_range()[0]; + + const std::size_t start_id = (gid - lane_id) * n_wi + lane_id; + +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + const std::size_t data_id = start_id + i * sg_size; + + if (data_id < nelems) { + const IndexTy linear_index = flat_index_data[data_id]; + reduced_index_data[data_id] = (linear_index % row_size); + } + } + }); + }); + + return map_back_ev; +} + +} // end of namespace sort_utils_detail +} // end of namespace kernels +} // end of namespace tensor +} // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp 
b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp new file mode 100644 index 0000000000..1828c88eb4 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -0,0 +1,527 @@ +//=== topk.hpp - Implementation of topk kernels ---*-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor topk operation. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/search_sorted_detail.hpp" +#include "kernels/sorting/sort_utils.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +namespace topk_detail +{ + +void scale_topk_params(const std::uint64_t nelems_per_slm, + const std::size_t sub_groups_per_work_group, + const std::uint32_t elems_per_wi, + const std::vector &sg_sizes, + std::size_t &lws, + std::size_t &nelems_wg_sorts) +{ + for (auto it = sg_sizes.rbegin(); it != sg_sizes.rend(); ++it) { + auto sg_size = *it; + lws = sub_groups_per_work_group * sg_size; + nelems_wg_sorts = elems_per_wi * lws; + if (nelems_wg_sorts < nelems_per_slm) { + return; + } + } + // should never reach + throw std::runtime_error("Could not construct top k kernel parameters"); +} + +template +sycl::event write_out_impl(sycl::queue &exec_q, + std::size_t iter_nelems, + std::size_t k, + const argTy *arg_tp, + const IndexTy *index_data, + std::size_t iter_index_stride, + std::size_t axis_nelems, + argTy *vals_tp, + IndexTy *inds_tp, + const std::vector &depends) +{ + constexpr std::uint32_t lws = 64; + constexpr std::uint32_t n_wi = 4; + const std::size_t nelems = iter_nelems * k; + const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws); + + sycl::range<1> lRange{lws}; + sycl::range<1> gRange{n_groups * lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_max_local_range()[0]; + + const std::size_t start_id = (gid - lane_id) * n_wi + lane_id; + +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + const std::size_t data_id = start_id + i * sg_size; + + if (data_id < nelems) { + const std::size_t iter_id = data_id / k; + + /* + const std::size_t axis_gid = data_id - (iter_gid * k); + const std::size_t src_idx = iter_gid * iter_index_stride + + axis_gid; + */ 
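+ // note: equivalent to iter_id * iter_index_stride + (data_id - iter_id * k),
+ // i.e. the two-step form shown in the commented-out block above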
+ const std::size_t src_idx = + data_id + iter_id * (iter_index_stride - k); + + const IndexTy res_ind = index_data[src_idx]; + const argTy v = arg_tp[res_ind]; + + const std::size_t dst_idx = data_id; + vals_tp[dst_idx] = v; + inds_tp[dst_idx] = (res_ind % axis_nelems); + } + } + }); + }); + + return write_out_ev; +} + +} // namespace topk_detail + +template +class topk_populate_index_data_krn; + +template +class topk_full_merge_map_back_krn; + +template +sycl::event +topk_full_merge_sort_impl(sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays + std::size_t axis_nelems, // size of each sub-array + std::size_t k, + const argTy *arg_tp, + argTy *vals_tp, + IndexTy *inds_tp, + const CompT &comp, + const std::vector &depends) +{ + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * axis_nelems, exec_q); + // extract USM pointer + IndexTy *index_data = index_data_owner.get(); + + using IotaKernelName = topk_populate_index_data_krn; + + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; + + sycl::event populate_indexed_data_ev = iota_impl( + exec_q, index_data, iter_nelems * axis_nelems, depends); + + std::size_t sorted_block_size; + // Sort segments of the array + sycl::event base_sort_ev = + merge_sort_detail::sort_over_work_group_contig_impl( + exec_q, iter_nelems, axis_nelems, index_data, index_data, comp, + sorted_block_size, // modified in place with size of sorted block + // size + {populate_indexed_data_ev}); + + // Merge segments in parallel until all elements are sorted + sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl( + exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size, + {base_sort_ev}); + + using WriteOutKernelName = + topk_full_merge_map_back_krn; + + sycl::event write_out_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, index_data, axis_nelems, + axis_nelems, vals_tp, inds_tp, {merges_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {write_out_ev}, + index_data_owner); + + return cleanup_host_task_event; +}; + +template +class topk_partial_merge_map_back_krn; + +template +class topk_over_work_group_krn; + +template > +sycl::event topk_merge_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows + // in a matrix when sorting over rows) + std::size_t axis_nelems, // size of each array to sort (length of + // rows, i.e. 
number of columns) + std::size_t k, + const char *arg_cp, + char *vals_cp, + char *inds_cp, + dpctl::tensor::ssize_t iter_arg_offset, + dpctl::tensor::ssize_t iter_vals_offset, + dpctl::tensor::ssize_t iter_inds_offset, + dpctl::tensor::ssize_t axis_arg_offset, + dpctl::tensor::ssize_t axis_vals_offset, + dpctl::tensor::ssize_t axis_inds_offset, + const std::vector &depends) +{ + if (axis_nelems < k) { + throw std::runtime_error("Invalid sort axis size for value of k"); + } + + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + axis_arg_offset; + argTy *vals_tp = reinterpret_cast(vals_cp) + iter_vals_offset + + axis_vals_offset; + IndexTy *inds_tp = reinterpret_cast(inds_cp) + iter_inds_offset + + axis_inds_offset; + + using dpctl::tensor::kernels::IndexComp; + const IndexComp index_comp{arg_tp, ValueComp{}}; + + if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2) { + return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, k, + arg_tp, vals_tp, inds_tp, index_comp, + depends); + } + else { + using PartialKernelName = + topk_over_work_group_krn; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + auto krn = kb.get_kernel(kernel_id); + + const std::uint32_t max_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + const std::uint64_t device_local_memory_size = + dev.get_info(); + + // leave 512 bytes of local memory for RT + const std::uint64_t safety_margin = 512; + + const std::uint64_t nelems_per_slm = + (device_local_memory_size - safety_margin) / (2 * sizeof(IndexTy)); + + constexpr std::uint32_t sub_groups_per_work_group = 4; + const std::uint32_t elems_per_wi = dev.has(sycl::aspect::cpu) ? 8 : 2; + + std::size_t lws = sub_groups_per_work_group * max_sg_size; + + std::size_t sorted_block_size = elems_per_wi * lws; + if (sorted_block_size > nelems_per_slm) { + const std::vector sg_sizes = + dev.get_info(); + topk_detail::scale_topk_params( + nelems_per_slm, sub_groups_per_work_group, elems_per_wi, + sg_sizes, + lws, // modified by reference + sorted_block_size // modified by reference + ); + } + + // This assumption permits doing away with using a loop + assert(sorted_block_size % lws == 0); + + using search_sorted_detail::quotient_ceil; + const std::size_t n_segments = + quotient_ceil(axis_nelems, sorted_block_size); + + // round k up for the later merge kernel if necessary + const std::size_t round_k_to = dev.has(sycl::aspect::cpu) ? 32 : 4; + std::size_t k_rounded = + (k < round_k_to) + ? k + : quotient_ceil(k, round_k_to) * round_k_to; + + // get length of tail for alloc size + auto rem = axis_nelems % sorted_block_size; + auto alloc_len = (rem && rem < k_rounded) + ? 
rem + k_rounded * (n_segments - 1) + : k_rounded * n_segments; + + // if allocation would be sufficiently large or k is larger than + // elements processed, use full sort + if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size || + alloc_len >= axis_nelems / 2) + { + return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, + k, arg_tp, vals_tp, inds_tp, + index_comp, depends); + } + + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * alloc_len, exec_q); + // get raw USM pointer + IndexTy *index_data = index_data_owner.get(); + + // no need to populate index data: SLM will be populated with default + // values + + sycl::event base_sort_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.use_kernel_bundle(kb); + + sycl::range<1> global_range{iter_nelems * n_segments * lws}; + sycl::range<1> local_range{lws}; + + sycl::range<1> slm_range{sorted_block_size}; + sycl::local_accessor work_space(slm_range, cgh); + sycl::local_accessor scratch_space(slm_range, cgh); + + sycl::nd_range<1> ndRange(global_range, local_range); + + cgh.parallel_for( + ndRange, [=](sycl::nd_item<1> it) { + const std::size_t group_id = it.get_group_linear_id(); + const std::size_t iter_id = group_id / n_segments; + const std::size_t segment_id = + group_id - iter_id * n_segments; + const std::size_t lid = it.get_local_linear_id(); + + const std::size_t segment_start_idx = + segment_id * sorted_block_size; + const std::size_t segment_end_idx = std::min( + segment_start_idx + sorted_block_size, axis_nelems); + const std::size_t wg_chunk_size = + segment_end_idx - segment_start_idx; + + // load input into SLM + for (std::size_t array_id = segment_start_idx + lid; + array_id < segment_end_idx; array_id += lws) + { + IndexTy v = (array_id < axis_nelems) + ? 
iter_id * axis_nelems + array_id + : IndexTy{}; + work_space[array_id - segment_start_idx] = v; + } + sycl::group_barrier(it.get_group()); + + const std::size_t chunk = + quotient_ceil(sorted_block_size, lws); + + const std::size_t chunk_start_idx = lid * chunk; + const std::size_t chunk_end_idx = + sycl::min(chunk_start_idx + chunk, wg_chunk_size); + + merge_sort_detail::leaf_sort_impl( + work_space, chunk_start_idx, chunk_end_idx, index_comp); + + sycl::group_barrier(it.get_group()); + + bool data_in_temp = false; + std::size_t n_chunks_merged = 1; + + // merge chunk while n_chunks_merged * chunk < wg_chunk_size + const std::size_t max_chunks_merged = + 1 + ((wg_chunk_size - 1) / chunk); + for (; n_chunks_merged < max_chunks_merged; + data_in_temp = !data_in_temp, n_chunks_merged *= 2) + { + const std::size_t nelems_sorted_so_far = + n_chunks_merged * chunk; + const std::size_t q = (lid / n_chunks_merged); + const std::size_t start_1 = sycl::min( + 2 * nelems_sorted_so_far * q, wg_chunk_size); + const std::size_t end_1 = sycl::min( + start_1 + nelems_sorted_so_far, wg_chunk_size); + const std::size_t end_2 = sycl::min( + end_1 + nelems_sorted_so_far, wg_chunk_size); + const std::size_t offset = + chunk * (lid - q * n_chunks_merged); + + if (data_in_temp) { + merge_sort_detail::merge_impl( + offset, scratch_space, work_space, start_1, + end_1, end_2, start_1, index_comp, chunk); + } + else { + merge_sort_detail::merge_impl( + offset, work_space, scratch_space, start_1, + end_1, end_2, start_1, index_comp, chunk); + } + sycl::group_barrier(it.get_group()); + } + + // output assumed to be structured as (iter_nelems, + // alloc_len) + const std::size_t k_segment_start_idx = + segment_id * k_rounded; + const std::size_t k_segment_end_idx = std::min( + k_segment_start_idx + k_rounded, alloc_len); + const auto &out_src = + (data_in_temp) ? 
scratch_space : work_space; + for (std::size_t array_id = k_segment_start_idx + lid; + array_id < k_segment_end_idx; array_id += lws) + { + if (lid < k_rounded) { + index_data[iter_id * alloc_len + array_id] = + out_src[array_id - k_segment_start_idx]; + } + } + }); + }); + + // Merge segments in parallel until all elements are sorted + sycl::event merges_ev = + merge_sort_detail::merge_sorted_block_contig_impl( + exec_q, iter_nelems, alloc_len, index_data, index_comp, + k_rounded, {base_sort_ev}); + + // Write out top k of the merge-sorted memory + using WriteOutKernelName = + topk_partial_merge_map_back_krn; + + sycl::event write_topk_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, index_data, alloc_len, + axis_nelems, vals_tp, inds_tp, {merges_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, index_data_owner); + + return cleanup_host_task_event; + } +} + +template class topk_iota_krn; + +template class topk_radix_map_back_krn; + +template +sycl::event topk_radix_impl(sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays + std::size_t axis_nelems, // size of each sub-array + std::size_t k, + bool ascending, + const char *arg_cp, + char *vals_cp, + char *inds_cp, + dpctl::tensor::ssize_t iter_arg_offset, + dpctl::tensor::ssize_t iter_vals_offset, + dpctl::tensor::ssize_t iter_inds_offset, + dpctl::tensor::ssize_t axis_arg_offset, + dpctl::tensor::ssize_t axis_vals_offset, + dpctl::tensor::ssize_t axis_inds_offset, + const std::vector &depends) +{ + if (axis_nelems < k) { + throw std::runtime_error("Invalid sort axis size for value of k"); + } + + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + axis_arg_offset; + argTy *vals_tp = reinterpret_cast(vals_cp) + iter_vals_offset + + axis_vals_offset; + IndexTy *inds_tp = reinterpret_cast(inds_cp) + iter_inds_offset + + axis_inds_offset; + + const std::size_t total_nelems = iter_nelems * axis_nelems; + const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + padded_total_nelems + total_nelems, exec_q); + + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); + IndexTy *tmp_tp = workspace + padded_total_nelems; + + using IdentityProjT = radix_sort_details::IdentityProj; + using IndexedProjT = + radix_sort_details::IndexedProj; + const IndexedProjT proj_op{arg_tp, IdentityProjT{}}; + + using IotaKernelName = topk_iota_krn; + + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; + + sycl::event iota_ev = iota_impl( + exec_q, workspace, total_nelems, depends); + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op, + ascending, {iota_ev}); + + // Write out top k of the temporary + using WriteOutKernelName = topk_radix_map_back_krn; + + sycl::event write_topk_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, tmp_tp, axis_nelems, axis_nelems, + vals_tp, inds_tp, {radix_sort_ev}); + + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, workspace_owner); + + return cleanup_ev; +} + +} // end of namespace kernels +} // end of namespace tensor +} // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp index 3ad5f6f36a..f67e1bba1f 100644 --- 
a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp @@ -28,6 +28,10 @@ #include #include +#include +#include +#include +#include #include "sycl/sycl.hpp" @@ -73,11 +77,137 @@ void sycl_free_noexcept(T *ptr, const sycl::context &ctx) noexcept } } -template void sycl_free_noexcept(T *ptr, sycl::queue &q) noexcept +template +void sycl_free_noexcept(T *ptr, const sycl::queue &q) noexcept { sycl_free_noexcept(ptr, q.get_context()); } +class USMDeleter +{ +private: + sycl::context ctx_; + +public: + USMDeleter(const sycl::queue &q) : ctx_(q.get_context()) {} + USMDeleter(const sycl::context &ctx) : ctx_(ctx) {} + + template void operator()(T *ptr) const + { + sycl_free_noexcept(ptr, ctx_); + } +}; + +template +std::unique_ptr +smart_malloc(std::size_t count, + const sycl::queue &q, + sycl::usm::alloc kind, + const sycl::property_list &propList = {}) +{ + T *ptr = sycl::malloc(count, q, kind, propList); + if (nullptr == ptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + + auto usm_deleter = USMDeleter(q); + return std::unique_ptr(ptr, usm_deleter); +} + +template +std::unique_ptr +smart_malloc_device(std::size_t count, + const sycl::queue &q, + const sycl::property_list &propList = {}) +{ + return smart_malloc(count, q, sycl::usm::alloc::device, propList); +} + +template +std::unique_ptr +smart_malloc_shared(std::size_t count, + const sycl::queue &q, + const sycl::property_list &propList = {}) +{ + return smart_malloc(count, q, sycl::usm::alloc::shared, propList); +} + +template +std::unique_ptr +smart_malloc_host(std::size_t count, + const sycl::queue &q, + const sycl::property_list &propList = {}) +{ + return smart_malloc(count, q, sycl::usm::alloc::host, propList); +} + +namespace +{ +template struct valid_smart_ptr : public std::false_type +{ +}; + +template +struct valid_smart_ptr &> + : public std::is_same +{ +}; + +template +struct valid_smart_ptr> + : public std::is_same +{ +}; + +// base case +template struct all_valid_smart_ptrs +{ + static constexpr bool value = true; +}; + +template +struct all_valid_smart_ptrs +{ + static constexpr bool value = valid_smart_ptr::value && + (all_valid_smart_ptrs::value); +}; +} // namespace + +template +sycl::event async_smart_free(sycl::queue &exec_q, + const std::vector &depends, + Args &&...args) +{ + constexpr std::size_t n = sizeof...(Args); + static_assert( + n > 0, "async_smart_free requires at least one smart pointer argument"); + + static_assert( + all_valid_smart_ptrs::value, + "async_smart_free requires unique_ptr created with smart_malloc"); + + std::vector ptrs; + ptrs.reserve(n); + (ptrs.push_back(reinterpret_cast(args.get())), ...); + + std::vector dels; + dels.reserve(n); + (dels.push_back(args.get_deleter()), ...); + + sycl::event ht_e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.host_task([ptrs, dels]() { + for (size_t i = 0; i < ptrs.size(); ++i) { + dels[i](ptrs[i]); + } + }); + }); + (args.release(), ...); + + return ht_e; +} + } // end of namespace alloc_utils } // end of namespace tensor } // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/topk.cpp b/dpctl/tensor/libtensor/source/sorting/topk.cpp new file mode 100644 index 0000000000..dea20fd494 --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/topk.cpp @@ -0,0 +1,317 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===--------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dpctl4pybind11.hpp" +#include +#include +#include + +#include "utils/math_utils.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +#include "kernels/sorting/topk.hpp" +#include "rich_comparisons.hpp" +#include "topk.hpp" + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +typedef sycl::event (*topk_impl_fn_ptr_t)(sycl::queue &, + std::size_t, + std::size_t, + std::size_t, + bool, + const char *, + char *, + char *, + py::ssize_t, + py::ssize_t, + py::ssize_t, + py::ssize_t, + py::ssize_t, + py::ssize_t, + const std::vector &); + +static topk_impl_fn_ptr_t topk_dispatch_vector[td_ns::num_types]; + +namespace +{ + +template +struct use_radix_sort : public std::false_type +{ +}; + +template +struct use_radix_sort< + T, + std::enable_if_t, + std::is_same, + std::is_same, + std::is_same, + std::is_same>::value>> + : public std::true_type +{ +}; + +template +sycl::event topk_caller(sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays + std::size_t axis_nelems, // size of each sub-array + std::size_t k, + bool largest, + const char *arg_cp, + char *vals_cp, + char *inds_cp, + py::ssize_t iter_arg_offset, + py::ssize_t iter_vals_offset, + py::ssize_t iter_inds_offset, + py::ssize_t axis_arg_offset, + py::ssize_t axis_vals_offset, + py::ssize_t axis_inds_offset, + const std::vector &depends) +{ + if constexpr (use_radix_sort::value) { + using dpctl::tensor::kernels::topk_radix_impl; + auto ascending = !largest; + return topk_radix_impl( + exec_q, iter_nelems, axis_nelems, k, ascending, arg_cp, vals_cp, + inds_cp, iter_arg_offset, iter_vals_offset, iter_inds_offset, + axis_arg_offset, axis_vals_offset, axis_inds_offset, depends); + } + else { + using dpctl::tensor::kernels::topk_merge_impl; + if (largest) { + using CompTy = + typename dpctl::tensor::py_internal::DescendingSorter< + argTy>::type; + return topk_merge_impl( + exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp, + iter_arg_offset, iter_vals_offset, iter_inds_offset, + axis_arg_offset, axis_vals_offset, axis_inds_offset, depends); + } + else { + using CompTy = typename dpctl::tensor::py_internal::AscendingSorter< + argTy>::type; + return topk_merge_impl( + exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp, + iter_arg_offset, iter_vals_offset, iter_inds_offset, + axis_arg_offset, axis_vals_offset, axis_inds_offset, depends); + } + } +} + +} // namespace + +std::pair +py_topk(const dpctl::tensor::usm_ndarray &src, + std::optional trailing_dims_to_search, + const std::size_t k, + 
const bool largest, + const dpctl::tensor::usm_ndarray &vals, + const dpctl::tensor::usm_ndarray &inds, + sycl::queue &exec_q, + const std::vector &depends) +{ + int src_nd = src.get_ndim(); + int vals_nd = vals.get_ndim(); + int inds_nd = inds.get_ndim(); + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *vals_shape_ptr = vals.get_shape_raw(); + const py::ssize_t *inds_shape_ptr = inds.get_shape_raw(); + + std::size_t axis_nelems(1); + std::size_t iter_nelems(1); + if (trailing_dims_to_search.has_value()) { + if (src_nd != vals_nd || src_nd != inds_nd) { + throw py::value_error("The input and output arrays must have " + "the same array ranks"); + } + + auto trailing_dims = trailing_dims_to_search.value(); + int iter_nd = src_nd - trailing_dims; + if (trailing_dims <= 0 || iter_nd < 0) { + throw py::value_error( + "trailing_dims_to_search must be positive, but no " + "greater than rank of the array being searched"); + } + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < iter_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == vals_shape_ptr[i] && + src_shape_i == inds_shape_ptr[i]); + iter_nelems *= static_cast(src_shape_i); + } + + if (!same_shapes) { + throw py::value_error( + "Destination shape does not match the input shape"); + } + + std::size_t vals_k(1); + std::size_t inds_k(1); + for (int i = iter_nd; i < src_nd; ++i) { + axis_nelems *= static_cast(src_shape_ptr[i]); + vals_k *= static_cast(vals_shape_ptr[i]); + inds_k *= static_cast(inds_shape_ptr[i]); + } + + bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k); + if (!valid_k) { + throw py::value_error("The value of k is invalid for the input and " + "destination arrays"); + } + } + else { + if (vals_nd != 1 || inds_nd != 1) { + throw py::value_error("Output arrays must be one-dimensional"); + } + + for (int i = 0; i < src_nd; ++i) { + axis_nelems *= static_cast(src_shape_ptr[i]); + } + + bool valid_k = (axis_nelems >= k && + static_cast(vals_shape_ptr[0]) == k && + static_cast(inds_shape_ptr[0]) == k); + if (!valid_k) { + throw py::value_error("The value of k is invalid for the input and " + "destination arrays"); + } + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, vals, inds})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vals); + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(inds); + + if ((iter_nelems == 0) || (axis_nelems == 0)) { + // Nothing to do + return std::make_pair(sycl::event(), sycl::event()); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, vals) || overlap(src, inds)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vals, + k * iter_nelems); + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(inds, + k * iter_nelems); + + int src_typenum = src.get_typenum(); + int vals_typenum = vals.get_typenum(); + int inds_typenum = inds.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int vals_typeid = array_types.typenum_to_lookup_id(vals_typenum); + int inds_typeid = array_types.typenum_to_lookup_id(inds_typenum); + + if (src_typeid != vals_typeid) { + throw py::value_error("Input array and vals array must have " + "the same data type"); + } + 
+    if (inds_typeid != static_cast<int>(td_ns::typenum_t::INT64)) {
+        throw py::value_error("Inds array must have data type int64");
+    }
+
+    bool is_src_c_contig = src.is_c_contiguous();
+    bool is_vals_c_contig = vals.is_c_contiguous();
+    bool is_inds_c_contig = inds.is_c_contiguous();
+
+    if (is_src_c_contig && is_vals_c_contig && is_inds_c_contig) {
+        static constexpr py::ssize_t zero_offset = py::ssize_t(0);
+
+        auto fn = topk_dispatch_vector[src_typeid];
+
+        sycl::event comp_ev =
+            fn(exec_q, iter_nelems, axis_nelems, k, largest, src.get_data(),
+               vals.get_data(), inds.get_data(), zero_offset, zero_offset,
+               zero_offset, zero_offset, zero_offset, zero_offset, depends);
+
+        sycl::event keep_args_alive_ev =
+            dpctl::utils::keep_args_alive(exec_q, {src, vals, inds}, {comp_ev});
+
+        return std::make_pair(keep_args_alive_ev, comp_ev);
+    }
+
+    return std::make_pair(sycl::event(), sycl::event());
+}
+
+template <typename fnT, typename argTy> struct TopKFactory
+{
+    fnT get()
+    {
+        using IdxT = std::int64_t;
+        return topk_caller<argTy, IdxT>;
+    }
+};
+
+void init_topk_dispatch_vectors(void)
+{
+    td_ns::DispatchVectorBuilder<topk_impl_fn_ptr_t, TopKFactory,
+                                 td_ns::num_types>
+        dvb;
+    dvb.populate_dispatch_vector(topk_dispatch_vector);
+}
+
+void init_topk_functions(py::module_ m)
+{
+    dpctl::tensor::py_internal::init_topk_dispatch_vectors();
+
+    m.def("_topk", &py_topk, py::arg("src"), py::arg("trailing_dims_to_search"),
+          py::arg("k"), py::arg("largest"), py::arg("vals"), py::arg("inds"),
+          py::arg("sycl_queue"), py::arg("depends") = py::list());
+}
+
+} // end of namespace py_internal
+} // end of namespace tensor
+} // end of namespace dpctl
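
Illustrative only (not part of the change set): the shape contract enforced by the validation above, seen from the Python API. Array sizes and contents below are arbitrary, and a default-selected device is assumed.

import dpctl.tensor as dpt

x = dpt.reshape(dpt.arange(24, dtype="i4"), (4, 6))

# axis=1 searches the trailing dimension of length 6: iteration dimensions
# are preserved and the searched dimension is replaced by k
r = dpt.top_k(x, 2, axis=1)
assert r.values.shape == (4, 2) and r.indices.shape == (4, 2)

# axis=None searches the flattened array, so vals/inds are one-dimensional
flat = dpt.top_k(x, 2)
assert flat.values.shape == (2,) and flat.indices.shape == (2,)

# requesting k larger than the searched dimension raises ValueError
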
diff --git a/dpctl/tensor/libtensor/source/sorting/topk.hpp b/dpctl/tensor/libtensor/source/sorting/topk.hpp
new file mode 100644
index 0000000000..37042457da
--- /dev/null
+++ b/dpctl/tensor/libtensor/source/sorting/topk.hpp
@@ -0,0 +1,42 @@
+//
+// Data Parallel Control (dpctl)
+//
+// Copyright 2020-2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===--------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_impl
+/// extension.
+//===--------------------------------------------------------------------===//
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace dpctl
+{
+namespace tensor
+{
+namespace py_internal
+{
+
+extern void init_topk_functions(py::module_);
+
+} // namespace py_internal
+} // namespace tensor
+} // namespace dpctl
diff --git a/dpctl/tensor/libtensor/source/tensor_sorting.cpp b/dpctl/tensor/libtensor/source/tensor_sorting.cpp
index 52d3ab67b4..09deac7b30 100644
--- a/dpctl/tensor/libtensor/source/tensor_sorting.cpp
+++ b/dpctl/tensor/libtensor/source/tensor_sorting.cpp
@@ -28,6 +28,7 @@
 #include "sorting/merge_argsort.hpp"
 #include "sorting/merge_sort.hpp"
 #include "sorting/searchsorted.hpp"
+#include "sorting/topk.hpp"
 
 namespace py = pybind11;
 
@@ -36,4 +37,5 @@ PYBIND11_MODULE(_tensor_sorting_impl, m)
     dpctl::tensor::py_internal::init_merge_sort_functions(m);
     dpctl::tensor::py_internal::init_merge_argsort_functions(m);
     dpctl::tensor::py_internal::init_searchsorted_functions(m);
+    dpctl::tensor::py_internal::init_topk_functions(m);
 }
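
The new tests below exercise this registration path through dpctl.tensor.top_k. As a rough cross-check (a sketch, not part of the change set; the helper name is ours), results can be compared against the existing argsort-based path; tied elements may yield different indices, so only values are compared.

import dpctl.tensor as dpt


def topk_reference(x, k, mode="largest"):
    # 1-D reference built on the existing argsort API; tie-breaking of
    # indices may differ from dpt.top_k, so compare values only
    order = dpt.argsort(x, descending=(mode == "largest"), stable=True)[:k]
    return x[order], order


a = dpt.roll(
    dpt.concat((dpt.ones(16, dtype="i4"), dpt.zeros(16, dtype="i4"))), 5
)
ref_vals, _ = topk_reference(a, 3)
r = dpt.top_k(a, 3, mode="largest")
assert dpt.all(dpt.sort(r.values) == dpt.sort(ref_vals))
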
diff --git a/dpctl/tests/test_usm_ndarray_top_k.py b/dpctl/tests/test_usm_ndarray_top_k.py
new file mode 100644
index 0000000000..a27853d8c8
--- /dev/null
+++ b/dpctl/tests/test_usm_ndarray_top_k.py
@@ -0,0 +1,315 @@
+# Data Parallel Control (dpctl)
+#
+# Copyright 2020-2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+import dpctl.tensor as dpt
+from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
+
+
+def _expected_largest_inds(inp, n, shift, k):
+    "Computed expected top_k indices for mode='largest'"
+    assert k < n
+    ones_start_id = shift % (2 * n)
+
+    alloc_dev = inp.device
+
+    if ones_start_id < n:
+        expected_inds = dpt.arange(
+            ones_start_id, ones_start_id + k, dtype="i8", device=alloc_dev
+        )
+    else:
+        # wrap-around
+        ones_end_id = (ones_start_id + n) % (2 * n)
+        if ones_end_id >= k:
+            expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
+        else:
+            expected_inds = dpt.concat(
+                (
+                    dpt.arange(ones_end_id, dtype="i8", device=alloc_dev),
+                    dpt.arange(
+                        ones_start_id,
+                        ones_start_id + k - ones_end_id,
+                        dtype="i8",
+                        device=alloc_dev,
+                    ),
+                )
+            )
+
+    return expected_inds
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "i1",
+        "u1",
+        "i2",
+        "u2",
+        "i4",
+        "u4",
+        "i8",
+        "u8",
+        "f2",
+        "f4",
+        "f8",
+        "c8",
+        "c16",
+    ],
+)
+@pytest.mark.parametrize("n", [33, 43, 255, 511, 1021, 8193])
+def test_top_k_1d_largest(dtype, n):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    shift, k = 734, 5
+    o = dpt.ones(n, dtype=dtype)
+    z = dpt.zeros(n, dtype=dtype)
+    oz = dpt.concat((o, z))
+    inp = dpt.roll(oz, shift)
+
+    expected_inds = _expected_largest_inds(oz, n, shift, k)
+
+    s = dpt.top_k(inp, k, mode="largest")
+    assert s.values.shape == (k,)
+    assert s.values.dtype == inp.dtype
+    assert s.indices.shape == (k,)
+    assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values
+    assert dpt.all(s.values == inp[s.indices]), s.indices
+    assert dpt.all(s.indices == expected_inds), (s.indices, expected_inds)
+
+
+def _expected_smallest_inds(inp, n, shift, k):
+    "Computed expected top_k indices for mode='smallest'"
+    assert k < n
+    zeros_start_id = (n + shift) % (2 * n)
+    zeros_end_id = (shift) % (2 * n)
+
+    alloc_dev = inp.device
+
+    if zeros_start_id < zeros_end_id:
+        expected_inds = dpt.arange(
+            zeros_start_id, zeros_start_id + k, dtype="i8", device=alloc_dev
+        )
+    else:
+        if zeros_end_id >= k:
+            expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
+        else:
+            expected_inds = dpt.concat(
+                (
+                    dpt.arange(zeros_end_id, dtype="i8", device=alloc_dev),
+                    dpt.arange(
+                        zeros_start_id,
+                        zeros_start_id + k - zeros_end_id,
+                        dtype="i8",
+                        device=alloc_dev,
+                    ),
+                )
+            )
+
+    return expected_inds
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "i1",
+        "u1",
+        "i2",
+        "u2",
+        "i4",
+        "u4",
+        "i8",
+        "u8",
+        "f2",
+        "f4",
+        "f8",
+        "c8",
+        "c16",
+    ],
+)
+@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
+def test_top_k_1d_smallest(dtype, n):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    shift, k = 734, 5
+    o = dpt.ones(n, dtype=dtype)
+    z = dpt.zeros(n, dtype=dtype)
+    oz = dpt.concat((o, z))
+    inp = dpt.roll(oz, shift)
+
+    expected_inds = _expected_smallest_inds(oz, n, shift, k)
+
+    s = dpt.top_k(inp, k, mode="smallest")
+    assert s.values.shape == (k,)
+    assert s.values.dtype == inp.dtype
+    assert s.indices.shape == (k,)
+    assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values
+    assert dpt.all(s.values == inp[s.indices]), s.indices
+    assert dpt.all(s.indices == expected_inds), (s.indices, expected_inds)
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        # skip short types to ensure that m*n can be represented
+        # in the type
+        "i4",
+        "u4",
+        "i8",
+        "u8",
+        "f2",
+        "f4",
+        "f8",
+        "c8",
+        "c16",
+    ],
+)
+@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
+def test_top_k_2d_largest(dtype, n):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    m, k = 8, 3
+    if dtype == "f2" and m * n > 2000:
+        pytest.skip(
+            "f2 can not distinguish between large integers used in this test"
+        )
+
+    x = dpt.reshape(dpt.arange(m * n, dtype=dtype), (m, n))
+
+    r = dpt.top_k(x, k, axis=1)
+
+    assert r.values.shape == (m, k)
+    assert r.indices.shape == (m, k)
+    expected_inds = dpt.reshape(dpt.arange(n, dtype=r.indices.dtype), (1, n))[
+        :, -k:
+    ]
+    assert expected_inds.shape == (1, k)
+    assert dpt.all(
+        dpt.sort(r.indices, axis=1) == dpt.sort(expected_inds, axis=1)
+    ), (r.indices, expected_inds)
+    expected_vals = x[:, -k:]
+    assert dpt.all(
+        dpt.sort(r.values, axis=1) == dpt.sort(expected_vals, axis=1)
+    )
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        # skip short types to ensure that m*n can be represented
+        # in the type
+        "i4",
+        "u4",
+        "i8",
+        "u8",
+        "f2",
+        "f4",
+        "f8",
+        "c8",
+        "c16",
+    ],
+)
+@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
+def test_top_k_2d_smallest(dtype, n):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dtype, q)
+
+    m, k = 8, 3
+    if dtype == "f2" and m * n > 2000:
+        pytest.skip(
+            "f2 can not distinguish between large integers used in this test"
+        )
+
+    x = dpt.reshape(dpt.arange(m * n, dtype=dtype), (m, n))
+
+    r = dpt.top_k(x, k, axis=1, mode="smallest")
+
+    assert r.values.shape == (m, k)
+    assert r.indices.shape == (m, k)
+    expected_inds = dpt.reshape(dpt.arange(n, dtype=r.indices.dtype), (1, n))[
+        :, :k
+    ]
+    assert dpt.all(
+        dpt.sort(r.indices, axis=1) == dpt.sort(expected_inds, axis=1)
+    )
+    assert dpt.all(dpt.sort(r.values, axis=1) == dpt.sort(x[:, :k], axis=1))
+
+
+def test_top_k_0d():
+    get_queue_or_skip()
+
+    a = dpt.ones(tuple(), dtype="i4")
+    assert a.ndim == 0
+    assert a.size == 1
+
+    r = dpt.top_k(a, 1)
+    assert r.values == a
+    assert r.indices == dpt.zeros_like(a, dtype=r.indices.dtype)
+
+
+def test_top_k_noncontig():
+    get_queue_or_skip()
+
+    a = dpt.arange(256, dtype=dpt.int32)[::2]
+    r = dpt.top_k(a, 3)
+
+    assert dpt.all(dpt.sort(r.values) == dpt.asarray([250, 252, 254])), r.values
+    assert dpt.all(
+        dpt.sort(r.indices) == dpt.asarray([125, 126, 127])
+    ), r.indices
+
+
+def test_top_k_axis0():
+    get_queue_or_skip()
+
+    m, n, k = 128, 8, 3
+    x = dpt.reshape(dpt.arange(m * n, dtype=dpt.int32), (m, n))
+
+    r = dpt.top_k(x, k, axis=0, mode="smallest")
+    assert r.values.shape == (k, n)
+    assert r.indices.shape == (k, n)
+    expected_inds = dpt.reshape(dpt.arange(m, dtype=r.indices.dtype), (m, 1))[
+        :k, :
+    ]
+    assert dpt.all(
+        dpt.sort(r.indices, axis=0) == dpt.sort(expected_inds, axis=0)
+    )
+    assert dpt.all(dpt.sort(r.values, axis=0) == dpt.sort(x[:k, :], axis=0))
+
+
+def test_top_k_validation():
+    get_queue_or_skip()
+    x = dpt.ones(10, dtype=dpt.int64)
+    with pytest.raises(ValueError):
+        # k must be positive
+        dpt.top_k(x, -1)
+    with pytest.raises(TypeError):
+        # argument should be usm_ndarray
+        dpt.top_k(list(), 2)
+    x2 = dpt.reshape(x, (2, 5))
+    with pytest.raises(ValueError):
+        # k must not exceed array dimension
+        # along specified axis
+        dpt.top_k(x2, 100, axis=1)
+    with pytest.raises(ValueError):
+        # for 0d arrays, k must be 1
+        dpt.top_k(x[0], 2)
+    with pytest.raises(ValueError):
+        # mode must be "largest", or "smallest"
+        dpt.top_k(x, 2, mode="invalid")
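
A brief usage sketch of the new API (input values are arbitrary; the commented outputs follow from the documented semantics, and the order of returned elements is not guaranteed):

import dpctl.tensor as dpt

x = dpt.asarray([7, 1, 9, 4, 9, 0], dtype="i4")

largest = dpt.top_k(x, 2)  # mode="largest" is the default
smallest = dpt.top_k(x, 2, mode="smallest")

print(dpt.asnumpy(largest.values))  # e.g. [9 9]
print(dpt.asnumpy(smallest.values))  # e.g. [0 1]

# the result is a namedtuple, so it also unpacks
vals, inds = largest
assert dpt.all(x[inds] == vals)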