20
20
from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
21
21
22
22
23
+ def _expected_largest_inds (inp , n , shift , k ):
24
+ "Computed expected top_k indices for mode='largest'"
25
+ assert k < n
26
+ ones_start_id = shift % (2 * n )
27
+
28
+ alloc_dev = inp .device
29
+
30
+ if ones_start_id < n :
31
+ expected_inds = dpt .arange (
32
+ ones_start_id , ones_start_id + k , dtype = "i8" , device = alloc_dev
33
+ )
34
+ else :
35
+ # wrap-around
36
+ ones_end_id = (ones_start_id + n ) % (2 * n )
37
+ if ones_end_id >= k :
38
+ expected_inds = dpt .arange (k , dtype = "i8" , device = alloc_dev )
39
+ else :
40
+ expected_inds = dpt .concat (
41
+ (
42
+ dpt .arange (ones_end_id , dtype = "i8" , device = alloc_dev ),
43
+ dpt .arange (
44
+ ones_start_id ,
45
+ ones_start_id + k - ones_end_id ,
46
+ dtype = "i8" ,
47
+ device = alloc_dev ,
48
+ ),
49
+ )
50
+ )
51
+
52
+ return expected_inds
53
+
54
+
23
55
@pytest .mark .parametrize (
24
56
"dtype" ,
25
57
[
38
70
"c16" ,
39
71
],
40
72
)
41
- @pytest .mark .parametrize ("n" , [33 , 255 , 511 , 1021 , 8193 ])
42
- def test_topk_1d_largest (dtype , n ):
73
+ @pytest .mark .parametrize ("n" , [33 , 43 , 255 , 511 , 1021 , 8193 ])
74
+ def test_top_k_1d_largest (dtype , n ):
43
75
q = get_queue_or_skip ()
44
76
skip_if_dtype_not_supported (dtype , q )
45
77
78
+ shift , k = 734 , 5
46
79
o = dpt .ones (n , dtype = dtype )
47
80
z = dpt .zeros (n , dtype = dtype )
48
- zo = dpt .concat ((o , z ))
49
- inp = dpt .roll (zo , 734 )
50
- k = 5
81
+ oz = dpt .concat ((o , z ))
82
+ inp = dpt .roll (oz , shift )
83
+
84
+ expected_inds = _expected_largest_inds (oz , n , shift , k )
51
85
52
86
s = dpt .top_k (inp , k , mode = "largest" )
53
87
assert s .values .shape == (k ,)
54
88
assert s .values .dtype == inp .dtype
55
89
assert s .indices .shape == (k ,)
56
- assert dpt .all (s .values == dpt .ones (k , dtype = dtype ))
57
- assert dpt .all (s .values == inp [s .indices ])
90
+ assert dpt .all (s .indices == expected_inds )
91
+ assert dpt .all (s .values == dpt .ones (k , dtype = dtype )), s .values
92
+ assert dpt .all (s .values == inp [s .indices ]), s .indices
93
+
94
+
95
+ def _expected_smallest_inds (inp , n , shift , k ):
96
+ "Computed expected top_k indices for mode='smallest'"
97
+ assert k < n
98
+ zeros_start_id = (n + shift ) % (2 * n )
99
+ zeros_end_id = (shift ) % (2 * n )
100
+
101
+ alloc_dev = inp .device
102
+
103
+ if zeros_start_id < zeros_end_id :
104
+ expected_inds = dpt .arange (
105
+ zeros_start_id , zeros_start_id + k , dtype = "i8" , device = alloc_dev
106
+ )
107
+ else :
108
+ if zeros_end_id >= k :
109
+ expected_inds = dpt .arange (k , dtype = "i8" , device = alloc_dev )
110
+ else :
111
+ expected_inds = dpt .concat (
112
+ (
113
+ dpt .arange (zeros_end_id , dtype = "i8" , device = alloc_dev ),
114
+ dpt .arange (
115
+ zeros_start_id ,
116
+ zeros_start_id + k - zeros_end_id ,
117
+ dtype = "i8" ,
118
+ device = alloc_dev ,
119
+ ),
120
+ )
121
+ )
122
+
123
+ return expected_inds
58
124
59
125
60
126
@pytest .mark .parametrize (
@@ -75,41 +141,80 @@ def test_topk_1d_largest(dtype, n):
75
141
"c16" ,
76
142
],
77
143
)
78
- @pytest .mark .parametrize ("n" , [33 , 255 , 257 , 513 , 1021 , 8193 ])
79
- def test_topk_1d_smallest (dtype , n ):
144
+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
145
+ def test_top_k_1d_smallest (dtype , n ):
80
146
q = get_queue_or_skip ()
81
147
skip_if_dtype_not_supported (dtype , q )
82
148
149
+ shift , k = 734 , 5
83
150
o = dpt .ones (n , dtype = dtype )
84
151
z = dpt .zeros (n , dtype = dtype )
85
- zo = dpt .concat ((o , z ))
86
- inp = dpt .roll (zo , 734 )
87
- k = 5
152
+ oz = dpt .concat ((o , z ))
153
+ inp = dpt .roll (oz , shift )
154
+
155
+ expected_inds = _expected_smallest_inds (oz , n , shift , k )
88
156
89
157
s = dpt .top_k (inp , k , mode = "smallest" )
90
158
assert s .values .shape == (k ,)
91
159
assert s .values .dtype == inp .dtype
92
160
assert s .indices .shape == (k ,)
93
- assert dpt .all (s .values == dpt .zeros (k , dtype = dtype ))
94
- assert dpt .all (s .values == inp [s .indices ])
161
+ assert dpt .all (s .indices == expected_inds )
162
+ assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s .values
163
+ assert dpt .all (s .values == inp [s .indices ]), s .indices
95
164
96
165
97
166
# triage failing top k radix implementation on CPU
98
167
# replicates from Python behavior of radix sort topk implementation
99
- @pytest .mark .parametrize ("n" , [33 , 255 , 511 , 1021 , 8193 ])
100
- def test_topk_largest_1d_radix_i1_255 (n ):
168
+ @pytest .mark .parametrize (
169
+ "n" ,
170
+ [
171
+ 33 ,
172
+ 34 ,
173
+ 35 ,
174
+ 36 ,
175
+ 37 ,
176
+ 38 ,
177
+ 39 ,
178
+ 40 ,
179
+ 41 ,
180
+ 42 ,
181
+ 43 ,
182
+ 44 ,
183
+ 45 ,
184
+ 46 ,
185
+ 47 ,
186
+ 48 ,
187
+ 49 ,
188
+ 50 ,
189
+ 61 ,
190
+ 137 ,
191
+ 255 ,
192
+ 511 ,
193
+ 1021 ,
194
+ 8193 ,
195
+ ],
196
+ )
197
+ def test_top_k_largest_1d_radix_i1 (n ):
101
198
get_queue_or_skip ()
102
199
dt = "i1"
103
200
201
+ shift , k = 734 , 5
104
202
o = dpt .ones (n , dtype = dt )
105
203
z = dpt .zeros (n , dtype = dt )
106
- zo = dpt .concat ((o , z ))
107
- inp = dpt .roll (zo , 734 )
108
- k = 5
109
-
110
- sorted = dpt .copy (dpt .sort (inp , descending = True , kind = "radixsort" )[:k ])
111
- argsorted = dpt .copy (
112
- dpt .argsort (inp , descending = True , kind = "radixsort" )[:k ]
113
- )
114
- assert dpt .all (sorted == dpt .ones (k , dtype = dt ))
115
- assert dpt .all (sorted == inp [argsorted ])
204
+ oz = dpt .concat ((o , z ))
205
+ inp = dpt .roll (oz , shift )
206
+
207
+ expected_inds = _expected_largest_inds (oz , n , shift , k )
208
+
209
+ sorted_v = dpt .sort (inp , descending = True , kind = "radixsort" )
210
+ argsorted = dpt .argsort (inp , descending = True , kind = "radixsort" )
211
+
212
+ assert dpt .all (sorted_v == inp [argsorted ])
213
+
214
+ topk_vals = dpt .copy (sorted_v [:k ])
215
+ topk_inds = dpt .copy (argsorted [:k ])
216
+
217
+ assert dpt .all (topk_vals == dpt .ones (k , dtype = dt ))
218
+ assert dpt .all (topk_inds == expected_inds )
219
+
220
+ assert dpt .all (topk_vals == inp [topk_inds ]), topk_inds
0 commit comments