Implements top_k in dpctl.tensor #1921

Merged (30 commits, Jan 5, 2025)

Commits
0b353bf
Implements `top_k`
ndgrigorian Dec 4, 2024
f809573
Add `top_k` to doc sources
ndgrigorian Dec 5, 2024
b59e59d
Use `std::size_t` for `k` instead of `py::ssize_t`
ndgrigorian Dec 5, 2024
b4d7ba4
Add implementation of `top_k` using radix sort
ndgrigorian Dec 5, 2024
c37f1de
Reduce code duplication by using std::optional in py_topk directly
ndgrigorian Dec 11, 2024
1f0e9fd
Clean up top_k docstring
ndgrigorian Dec 11, 2024
fa74e5e
Move top_k into _sorting.py
ndgrigorian Dec 12, 2024
cd83881
top_k raises when k > 1 and input is a scalar
ndgrigorian Dec 13, 2024
70c2ef0
Fix bug in top_k partial merge sort implementation
ndgrigorian Dec 13, 2024
65f14be
Add test file for top_k functionality
oleksandr-pavlyk Dec 14, 2024
275827a
Add sort_utils.hpp with iota_impl
oleksandr-pavlyk Dec 20, 2024
cd1243f
Simplify write-out kernels in topk implementation (avoid recomputing …
oleksandr-pavlyk Dec 22, 2024
823c201
Add explicit read into registers
oleksandr-pavlyk Dec 22, 2024
3a4d8ab
Use unique_ptr as temporary owner of USM allocation
oleksandr-pavlyk Dec 20, 2024
b15e302
Factor out write-out kernel into separate detail function
oleksandr-pavlyk Dec 25, 2024
4ef615e
Typo: smart_malloc_jost -> smart_malloc_host
oleksandr-pavlyk Dec 25, 2024
2608aea
Reimplemented map_back_impl to process a few elements per work-item
oleksandr-pavlyk Dec 25, 2024
15da303
Replace use of std::log2(size_t_value)
oleksandr-pavlyk Dec 25, 2024
fb9bc43
Remove dead code disabled via preprocessor conditional
oleksandr-pavlyk Dec 27, 2024
6cf36b1
Add information displayed on failure, renamed variables
oleksandr-pavlyk Dec 25, 2024
4ec252e
Add comment to explain ceil_log2 algo
oleksandr-pavlyk Dec 27, 2024
1436522
Add static asserts to async_smart_free
oleksandr-pavlyk Dec 27, 2024
f0e3245
Skip top_k test known to fail for some CPU architectures
ndgrigorian Dec 28, 2024
8f4dcdd
Fixed blunder in work-item id to data_id computation
oleksandr-pavlyk Dec 31, 2024
b6806ba
Replace map_back_impl in sort_utils
oleksandr-pavlyk Dec 31, 2024
b47d9f3
Use get_global_linear_id instead of get_global_id and rely on implici…
oleksandr-pavlyk Jan 2, 2025
0493485
Counters in one-workgroup kernel to use uint16_t instead of uint32_t
oleksandr-pavlyk Jan 2, 2025
5125e11
Apply work-around for failing tests with CPU device and short sub-groups
oleksandr-pavlyk Jan 4, 2025
505b64c
Remove skipping of tests for i1/i2 dtypes since work-around was applied
oleksandr-pavlyk Dec 28, 2024
8c6abf5
Add entry to changelog for top_k
ndgrigorian Jan 5, 2025
Files changed
@@ -10,3 +10,4 @@ Sorting functions

argsort
sort
top_k
1 change: 1 addition & 0 deletions dpctl/tensor/CMakeLists.txt
@@ -115,6 +115,7 @@ set(_sorting_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp
)
set(_sorting_radix_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
3 changes: 2 additions & 1 deletion dpctl/tensor/__init__.py
@@ -199,7 +199,7 @@
unique_inverse,
unique_values,
)
from ._sorting import argsort, sort
from ._sorting import argsort, sort, top_k
from ._testing import allclose
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type

@@ -387,4 +387,5 @@
"DLDeviceType",
"take_along_axis",
"put_along_axis",
"top_k",
]
167 changes: 167 additions & 0 deletions dpctl/tensor/_sorting.py
@@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from typing import NamedTuple

import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as du
@@ -24,6 +27,7 @@
_argsort_descending,
_sort_ascending,
_sort_descending,
_topk,
)
from ._tensor_sorting_radix_impl import (
_radix_argsort_ascending,
@@ -267,3 +271,166 @@ def argsort(x, axis=-1, descending=False, stable=True, kind=None):
inv_perm = sorted(range(nd), key=lambda d: perm[d])
res = dpt.permute_dims(res, inv_perm)
return res


def _get_top_k_largest(mode):
modes = {"largest": True, "smallest": False}
try:
return modes[mode]
except KeyError:
raise ValueError(
f"`mode` must be `largest` or `smallest`. Got `{mode}`."
)


class TopKResult(NamedTuple):
values: dpt.usm_ndarray
indices: dpt.usm_ndarray


def top_k(x, k, /, *, axis=None, mode="largest"):
"""top_k(x, k, axis=None, mode="largest")

Returns the `k` largest or smallest values and their indices in the input
array `x` along the specified axis `axis`.

Args:
x (usm_ndarray):
input array.
k (int):
number of elements to find. Must be a positive integer value.
axis (Optional[int]):
axis along which to search. If `None`, the search will be performed
over the flattened array. Default: ``None``.
mode (Literal["largest", "smallest"]):
search mode. Must be one of the following modes:

- `"largest"`: return the `k` largest elements.
- `"smallest"`: return the `k` smallest elements.

Default: `"largest"`.

Returns:
tuple[usm_ndarray, usm_ndarray]
a namedtuple `(values, indices)` whose

* first element `values` will be an array containing the `k`
largest or smallest elements of `x`. The array has the same data
type as `x`. If `axis` was `None`, `values` will be a
one-dimensional array with shape `(k,)` and otherwise, `values`
will have shape `x.shape[:axis] + (k,) + x.shape[axis+1:]`
* second element `indices` will be an array containing indices of
`x` that result in `values`. The array will have the same shape
as `values` and will have the default array index data type.
"""
largest = _get_top_k_largest(mode)
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(
f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
)

k = operator.index(k)
if k < 0:
raise ValueError("`k` must be a positive integer value")

nd = x.ndim
if axis is None:
sz = x.size
if nd == 0:
if k > 1:
raise ValueError(f"`k`={k} is out of bounds 1")
return TopKResult(
dpt.copy(x, order="C"),
dpt.zeros_like(
x, dtype=ti.default_device_index_type(x.sycl_queue)
),
)
arr = x
n_search_dims = None
res_sh = k
else:
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
sz = x.shape[axis]
a1 = axis + 1
if a1 == nd:
perm = list(range(nd))
arr = x
else:
perm = [i for i in range(nd) if i != axis] + [
axis,
]
arr = dpt.permute_dims(x, perm)
n_search_dims = 1
res_sh = arr.shape[: nd - 1] + (k,)

if k > sz:
raise ValueError(f"`k`={k} is out of bounds {sz}")

exec_q = x.sycl_queue
_manager = du.SequentialOrderManager[exec_q]
dep_evs = _manager.submitted_events

res_usm_type = arr.usm_type
if arr.flags.c_contiguous:
vals = dpt.empty(
res_sh,
dtype=arr.dtype,
usm_type=res_usm_type,
order="C",
sycl_queue=exec_q,
)
inds = dpt.empty(
res_sh,
dtype=ti.default_device_index_type(exec_q),
usm_type=res_usm_type,
order="C",
sycl_queue=exec_q,
)
ht_ev, impl_ev = _topk(
src=arr,
trailing_dims_to_search=n_search_dims,
k=k,
largest=largest,
vals=vals,
inds=inds,
sycl_queue=exec_q,
depends=dep_evs,
)
_manager.add_event_pair(ht_ev, impl_ev)
else:
tmp = dpt.empty_like(arr, order="C")
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs
)
_manager.add_event_pair(ht_ev, copy_ev)
vals = dpt.empty(
res_sh,
dtype=arr.dtype,
usm_type=res_usm_type,
order="C",
sycl_queue=exec_q,
)
inds = dpt.empty(
res_sh,
dtype=ti.default_device_index_type(exec_q),
usm_type=res_usm_type,
order="C",
sycl_queue=exec_q,
)
ht_ev, impl_ev = _topk(
src=tmp,
trailing_dims_to_search=n_search_dims,
k=k,
largest=largest,
vals=vals,
inds=inds,
sycl_queue=exec_q,
depends=[copy_ev],
)
_manager.add_event_pair(ht_ev, impl_ev)
if axis is not None and a1 != nd:
inv_perm = sorted(range(nd), key=lambda d: perm[d])
vals = dpt.permute_dims(vals, inv_perm)
inds = dpt.permute_dims(inds, inv_perm)

return TopKResult(vals, inds)
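
For reference, a short usage sketch of the new function as exposed through `dpctl.tensor` (the printed orderings are illustrative; the docstring above does not promise a particular ordering of the returned elements):

```python
import dpctl.tensor as dpt

x = dpt.asarray([[7, 1, 5, 9], [2, 8, 3, 4]])

# Three largest values over the flattened array (axis=None is the default)
res = dpt.top_k(x, 3)
print(dpt.asnumpy(res.values))   # e.g. [9 8 7]
print(dpt.asnumpy(res.indices))  # positions in the flattened array

# Two smallest values along the last axis
vals, inds = dpt.top_k(x, 2, axis=1, mode="smallest")
print(dpt.asnumpy(vals))         # e.g. [[1 5]
                                 #       [2 3]]
```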
35 changes: 9 additions & 26 deletions dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
@@ -33,6 +33,7 @@

#include "kernels/dpctl_tensor_types.hpp"
#include "kernels/sorting/search_sorted_detail.hpp"
#include "kernels/sorting/sort_utils.hpp"

namespace dpctl
{
@@ -811,20 +812,12 @@ sycl::event stable_argsort_axis1_contig_impl(

const size_t total_nelems = iter_nelems * sort_nelems;

-    sycl::event populate_indexed_data_ev =
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(depends);
-
-            const sycl::range<1> range{total_nelems};
-
-            using KernelName =
-                populate_index_data_krn<argTy, IndexTy, ValueComp>;
-
-            cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
-                size_t i = id[0];
-                res_tp[i] = static_cast<IndexTy>(i);
-            });
-        });
+    using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
+
+    using IotaKernelName = populate_index_data_krn<argTy, IndexTy, ValueComp>;
+
+    sycl::event populate_indexed_data_ev = iota_impl<IotaKernelName, IndexTy>(
+        exec_q, res_tp, total_nelems, depends);

// Sort segments of the array
sycl::event base_sort_ev =
@@ -839,21 +832,11 @@ sycl::event stable_argsort_axis1_contig_impl(
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
{base_sort_ev});

-    sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(merges_ev);
-
-        auto temp_acc =
-            merge_sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp,
-                                                                     cgh);
-
-        using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
-
-        const sycl::range<1> range{total_nelems};
-
-        cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
-            res_tp[id] = (temp_acc[id] % sort_nelems);
-        });
-    });
+    using MapBackKernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
+    using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
+
+    sycl::event write_out_ev = map_back_impl<MapBackKernelName, IndexTy>(
+        exec_q, total_nelems, res_tp, res_tp, sort_nelems, {merges_ev});

return write_out_ev;
}
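
To make the refactor above easier to follow, the two kernels folded into the `sort_utils.hpp` helpers compute, element for element, the following (a plain-Python sketch with illustrative names `iota` and `map_back`; the real helpers are parallel SYCL kernels, and per the commit log `map_back_impl` processes several elements per work-item):

```python
def iota(res, n):
    # populate_index_data kernel: fill the index buffer with 0, 1, ..., n-1
    for i in range(n):
        res[i] = i
    return res


def map_back(res, n, row_size):
    # write-out kernel: turn positions in the flattened (iter x sort) buffer
    # into per-row indices, i.e. res[i] %= row_size
    for i in range(n):
        res[i] = res[i] % row_size
    return res
```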