Skip to content

Commit 0811f1b

Browse files
Initial check in of sort/argsort per array API spec
Implemented data-parallel merge-sort algorithm.
1 parent 07dea79 commit 0811f1b

File tree

11 files changed

+1799
-0
lines changed

11 files changed

+1799
-0
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ set(_reduction_sources
118118
set(_boolean_reduction_sources
119119
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
120120
)
121+
set(_sorting_sources
122+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp
123+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp
124+
)
121125
set(_tensor_impl_sources
122126
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_ctors.cpp
123127
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
@@ -148,6 +152,10 @@ set(_tensor_reductions_impl_sources
148152
${_boolean_reduction_sources}
149153
${_reduction_sources}
150154
)
155+
set(_tensor_sorting_impl_sources
156+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
157+
${_sorting_sources}
158+
)
151159

152160
set(_py_trgts)
153161

@@ -166,6 +174,11 @@ pybind11_add_module(${python_module_name} MODULE ${_tensor_reductions_impl_sourc
166174
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_reductions_impl_sources})
167175
list(APPEND _py_trgts ${python_module_name})
168176

177+
set(python_module_name _tensor_sorting_impl)
178+
pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources})
179+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources})
180+
list(APPEND _py_trgts ${python_module_name})
181+
169182
set(_clang_prefix "")
170183
if (WIN32)
171184
set(_clang_prefix "/clang:")
@@ -179,6 +192,7 @@ set(_no_fast_math_sources
179192
list(APPEND _no_fast_math_sources
180193
${_elementwise_sources}
181194
${_reduction_sources}
195+
${_sorting_sources}
182196
)
183197

184198
foreach(_src_fn ${_no_fast_math_sources})

dpctl/tensor/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
reduce_hypot,
180180
sum,
181181
)
182+
from ._sorting import argsort, sort
182183
from ._testing import allclose
183184

184185
__all__ = [
@@ -346,4 +347,6 @@
346347
"__array_namespace_info__",
347348
"reciprocal",
348349
"angle",
350+
"sort",
351+
"argsort",
349352
]

dpctl/tensor/_sorting.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from numpy.core.numeric import normalize_axis_index
2+
3+
import dpctl
4+
import dpctl.tensor as dpt
5+
import dpctl.tensor._tensor_impl as ti
6+
7+
from ._tensor_sorting_impl import (
8+
_argsort_ascending,
9+
_argsort_descending,
10+
_sort_ascending,
11+
_sort_descending,
12+
)
13+
14+
15+
def sort(x, axis=-1, descending=False, stable=False):
    """Return a sorted copy of an array along the given axis.

    Args:
        x (usm_ndarray):
            Input array.
        axis (int, optional):
            Axis along which to sort. Negative values count from the last
            axis. Default: ``-1``.
        descending (bool, optional):
            If ``True`` the descending-order kernel is dispatched,
            otherwise the ascending-order one. Default: ``False``.
        stable (bool, optional):
            Accepted for interface compatibility. This function does not
            consult the value; the same kernel is used either way.
            Default: ``False``.

    Returns:
        usm_ndarray:
            A newly allocated array with the shape and data type of ``x``
            holding the elements of ``x`` sorted along ``axis``. For a
            zero-dimensional input a copy of ``x`` is returned.

    Raises:
        TypeError:
            If ``x`` is not a :class:`dpctl.tensor.usm_ndarray`.

    ``axis`` is validated with ``normalize_axis_index``, which rejects
    out-of-range values.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )
    nd = x.ndim
    if nd == 0:
        # A zero-dimensional array has nothing to reorder. Validate
        # ``axis`` as if the array were one-dimensional (so that the
        # default ``axis=-1`` is accepted) and hand back a copy.
        normalize_axis_index(axis, ndim=1, msg_prefix="axis")
        return dpt.copy(x, order="C")
    axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
    a1 = axis + 1
    if a1 == nd:
        perm = list(range(nd))
        arr = x
    else:
        # The kernel sorts along trailing dimensions, so move the
        # requested axis to the last position.
        perm = [i for i in range(nd) if i != axis] + [axis]
        arr = dpt.permute_dims(x, perm)
    exec_q = x.sycl_queue
    host_tasks_list = []
    impl_fn = _sort_descending if descending else _sort_ascending
    if arr.flags.c_contiguous:
        src = arr
        deps = []
    else:
        # Stage the data in a C-contiguous temporary before sorting.
        src = dpt.empty_like(arr, order="C")
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=src, sycl_queue=exec_q
        )
        host_tasks_list.append(ht_ev)
        # The sort must not start until the staging copy has finished.
        deps = [copy_ev]
    res = dpt.empty_like(arr, order="C")
    ht_ev, _ = impl_fn(
        src=src,
        trailing_dims_to_sort=1,
        dst=res,
        sycl_queue=exec_q,
        depends=deps,
    )
    host_tasks_list.append(ht_ev)
    if a1 != nd:
        # Undo the axis move so the result has the layout of ``x``.
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(res, inv_perm)
    # Block until all submitted host tasks complete; this also keeps the
    # temporary alive for as long as the device may be using it.
    dpctl.SyclEvent.wait_for(host_tasks_list)
    return res
60+
61+
62+
def argsort(x, axis=-1, descending=False, stable=False):
    """Return the indices that sort an array along the given axis.

    Args:
        x (usm_ndarray):
            Input array.
        axis (int, optional):
            Axis along which to sort. Negative values count from the last
            axis. Default: ``-1``.
        descending (bool, optional):
            If ``True`` the descending-order kernel is dispatched,
            otherwise the ascending-order one. Default: ``False``.
        stable (bool, optional):
            Accepted for interface compatibility. This function does not
            consult the value; the same kernel is used either way.
            Default: ``False``.

    Returns:
        usm_ndarray:
            A newly allocated array with the shape of ``x`` whose data
            type is the default index type of the device associated with
            ``x.sycl_queue``. For a zero-dimensional input the result is
            a zero-filled array, since the only valid index is ``0``.

    Raises:
        TypeError:
            If ``x`` is not a :class:`dpctl.tensor.usm_ndarray`.

    ``axis`` is validated with ``normalize_axis_index``, which rejects
    out-of-range values.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )
    nd = x.ndim
    exec_q = x.sycl_queue
    index_dt = ti.default_device_index_type(exec_q)
    if nd == 0:
        # A zero-dimensional array has a single element at index 0.
        # Validate ``axis`` as if the array were one-dimensional so that
        # the default ``axis=-1`` is accepted.
        normalize_axis_index(axis, ndim=1, msg_prefix="axis")
        return dpt.zeros_like(x, dtype=index_dt, order="C")
    axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
    a1 = axis + 1
    if a1 == nd:
        perm = list(range(nd))
        arr = x
    else:
        # The kernel sorts along trailing dimensions, so move the
        # requested axis to the last position.
        perm = [i for i in range(nd) if i != axis] + [axis]
        arr = dpt.permute_dims(x, perm)
    host_tasks_list = []
    impl_fn = _argsort_descending if descending else _argsort_ascending
    if arr.flags.c_contiguous:
        src = arr
        deps = []
    else:
        # Stage the data in a C-contiguous temporary before sorting.
        src = dpt.empty_like(arr, order="C")
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=src, sycl_queue=exec_q
        )
        host_tasks_list.append(ht_ev)
        # The sort must not start until the staging copy has finished.
        deps = [copy_ev]
    res = dpt.empty_like(arr, dtype=index_dt, order="C")
    ht_ev, _ = impl_fn(
        src=src,
        trailing_dims_to_sort=1,
        dst=res,
        sycl_queue=exec_q,
        depends=deps,
    )
    host_tasks_list.append(ht_ev)
    if a1 != nd:
        # Undo the axis move so the result has the layout of ``x``.
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(res, inv_perm)
    # Block until all submitted host tasks complete; this also keeps the
    # temporary alive for as long as the device may be using it.
    dpctl.SyclEvent.wait_for(host_tasks_list)
    return res

0 commit comments

Comments
 (0)