Skip to content

Commit 210500f

Browse files
Add information displayed on failure, renamed variables
Add check of computed against expected indices
1 parent 54fa239 commit 210500f

File tree

1 file changed

+131
-26
lines changed

1 file changed

+131
-26
lines changed

dpctl/tests/test_usm_ndarray_top_k.py

Lines changed: 131 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,38 @@
2020
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2121

2222

23+
def _expected_largest_inds(inp, n, shift, k):
24+
"Computed expected top_k indices for mode='largest'"
25+
assert k < n
26+
ones_start_id = shift % (2 * n)
27+
28+
alloc_dev = inp.device
29+
30+
if ones_start_id < n:
31+
expected_inds = dpt.arange(
32+
ones_start_id, ones_start_id + k, dtype="i8", device=alloc_dev
33+
)
34+
else:
35+
# wrap-around
36+
ones_end_id = (ones_start_id + n) % (2 * n)
37+
if ones_end_id >= k:
38+
expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
39+
else:
40+
expected_inds = dpt.concat(
41+
(
42+
dpt.arange(ones_end_id, dtype="i8", device=alloc_dev),
43+
dpt.arange(
44+
ones_start_id,
45+
ones_start_id + k - ones_end_id,
46+
dtype="i8",
47+
device=alloc_dev,
48+
),
49+
)
50+
)
51+
52+
return expected_inds
53+
54+
2355
@pytest.mark.parametrize(
2456
"dtype",
2557
[
@@ -38,23 +70,57 @@
3870
"c16",
3971
],
4072
)
41-
@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193])
42-
def test_topk_1d_largest(dtype, n):
73+
@pytest.mark.parametrize("n", [33, 43, 255, 511, 1021, 8193])
74+
def test_top_k_1d_largest(dtype, n):
4375
q = get_queue_or_skip()
4476
skip_if_dtype_not_supported(dtype, q)
4577

78+
shift, k = 734, 5
4679
o = dpt.ones(n, dtype=dtype)
4780
z = dpt.zeros(n, dtype=dtype)
48-
zo = dpt.concat((o, z))
49-
inp = dpt.roll(zo, 734)
50-
k = 5
81+
oz = dpt.concat((o, z))
82+
inp = dpt.roll(oz, shift)
83+
84+
expected_inds = _expected_largest_inds(oz, n, shift, k)
5185

5286
s = dpt.top_k(inp, k, mode="largest")
5387
assert s.values.shape == (k,)
5488
assert s.values.dtype == inp.dtype
5589
assert s.indices.shape == (k,)
56-
assert dpt.all(s.values == dpt.ones(k, dtype=dtype))
57-
assert dpt.all(s.values == inp[s.indices])
90+
assert dpt.all(s.indices == expected_inds)
91+
assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values
92+
assert dpt.all(s.values == inp[s.indices]), s.indices
93+
94+
95+
def _expected_smallest_inds(inp, n, shift, k):
96+
"Computed expected top_k indices for mode='smallest'"
97+
assert k < n
98+
zeros_start_id = (n + shift) % (2 * n)
99+
zeros_end_id = (shift) % (2 * n)
100+
101+
alloc_dev = inp.device
102+
103+
if zeros_start_id < zeros_end_id:
104+
expected_inds = dpt.arange(
105+
zeros_start_id, zeros_start_id + k, dtype="i8", device=alloc_dev
106+
)
107+
else:
108+
if zeros_end_id >= k:
109+
expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
110+
else:
111+
expected_inds = dpt.concat(
112+
(
113+
dpt.arange(zeros_end_id, dtype="i8", device=alloc_dev),
114+
dpt.arange(
115+
zeros_start_id,
116+
zeros_start_id + k - zeros_end_id,
117+
dtype="i8",
118+
device=alloc_dev,
119+
),
120+
)
121+
)
122+
123+
return expected_inds
58124

59125

60126
@pytest.mark.parametrize(
@@ -75,41 +141,80 @@ def test_topk_1d_largest(dtype, n):
75141
"c16",
76142
],
77143
)
78-
@pytest.mark.parametrize("n", [33, 255, 257, 513, 1021, 8193])
79-
def test_topk_1d_smallest(dtype, n):
144+
@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
145+
def test_top_k_1d_smallest(dtype, n):
80146
q = get_queue_or_skip()
81147
skip_if_dtype_not_supported(dtype, q)
82148

149+
shift, k = 734, 5
83150
o = dpt.ones(n, dtype=dtype)
84151
z = dpt.zeros(n, dtype=dtype)
85-
zo = dpt.concat((o, z))
86-
inp = dpt.roll(zo, 734)
87-
k = 5
152+
oz = dpt.concat((o, z))
153+
inp = dpt.roll(oz, shift)
154+
155+
expected_inds = _expected_smallest_inds(oz, n, shift, k)
88156

89157
s = dpt.top_k(inp, k, mode="smallest")
90158
assert s.values.shape == (k,)
91159
assert s.values.dtype == inp.dtype
92160
assert s.indices.shape == (k,)
93-
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype))
94-
assert dpt.all(s.values == inp[s.indices])
161+
assert dpt.all(s.indices == expected_inds)
162+
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values
163+
assert dpt.all(s.values == inp[s.indices]), s.indices
95164

96165

97166
# triage failing top k radix implementation on CPU
98167
# replicates from Python behavior of radix sort topk implementation
99-
@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193])
100-
def test_topk_largest_1d_radix_i1_255(n):
168+
@pytest.mark.parametrize(
169+
"n",
170+
[
171+
33,
172+
34,
173+
35,
174+
36,
175+
37,
176+
38,
177+
39,
178+
40,
179+
41,
180+
42,
181+
43,
182+
44,
183+
45,
184+
46,
185+
47,
186+
48,
187+
49,
188+
50,
189+
61,
190+
137,
191+
255,
192+
511,
193+
1021,
194+
8193,
195+
],
196+
)
197+
def test_top_k_largest_1d_radix_i1(n):
101198
get_queue_or_skip()
102199
dt = "i1"
103200

201+
shift, k = 734, 5
104202
o = dpt.ones(n, dtype=dt)
105203
z = dpt.zeros(n, dtype=dt)
106-
zo = dpt.concat((o, z))
107-
inp = dpt.roll(zo, 734)
108-
k = 5
109-
110-
sorted = dpt.copy(dpt.sort(inp, descending=True, kind="radixsort")[:k])
111-
argsorted = dpt.copy(
112-
dpt.argsort(inp, descending=True, kind="radixsort")[:k]
113-
)
114-
assert dpt.all(sorted == dpt.ones(k, dtype=dt))
115-
assert dpt.all(sorted == inp[argsorted])
204+
oz = dpt.concat((o, z))
205+
inp = dpt.roll(oz, shift)
206+
207+
expected_inds = _expected_largest_inds(oz, n, shift, k)
208+
209+
sorted_v = dpt.sort(inp, descending=True, kind="radixsort")
210+
argsorted = dpt.argsort(inp, descending=True, kind="radixsort")
211+
212+
assert dpt.all(sorted_v == inp[argsorted])
213+
214+
topk_vals = dpt.copy(sorted_v[:k])
215+
topk_inds = dpt.copy(argsorted[:k])
216+
217+
assert dpt.all(topk_vals == dpt.ones(k, dtype=dt))
218+
assert dpt.all(topk_inds == expected_inds)
219+
220+
assert dpt.all(topk_vals == inp[topk_inds]), topk_inds

0 commit comments

Comments
 (0)