diff --git a/dpctl/tests/test_usm_ndarray_top_k.py b/dpctl/tests/test_usm_ndarray_top_k.py index e6218d846e..a029db005c 100644 --- a/dpctl/tests/test_usm_ndarray_top_k.py +++ b/dpctl/tests/test_usm_ndarray_top_k.py @@ -20,6 +20,38 @@ from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported +def _expected_largest_inds(inp, n, shift, k): + "Computed expected top_k indices for mode='largest'" + assert k < n + ones_start_id = shift % (2 * n) + + alloc_dev = inp.device + + if ones_start_id < n: + expected_inds = dpt.arange( + ones_start_id, ones_start_id + k, dtype="i8", device=alloc_dev + ) + else: + # wrap-around + ones_end_id = (ones_start_id + n) % (2 * n) + if ones_end_id >= k: + expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev) + else: + expected_inds = dpt.concat( + ( + dpt.arange(ones_end_id, dtype="i8", device=alloc_dev), + dpt.arange( + ones_start_id, + ones_start_id + k - ones_end_id, + dtype="i8", + device=alloc_dev, + ), + ) + ) + + return expected_inds + + @pytest.mark.parametrize( "dtype", [ @@ -38,25 +70,59 @@ "c16", ], ) -@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193]) -def test_topk_1d_largest(dtype, n): +@pytest.mark.parametrize("n", [33, 43, 255, 511, 1021, 8193]) +def test_top_k_1d_largest(dtype, n): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) + shift, k = 734, 5 o = dpt.ones(n, dtype=dtype) z = dpt.zeros(n, dtype=dtype) oz = dpt.concat((o, z)) - inp = dpt.roll(oz, 734) - k = 5 + inp = dpt.roll(oz, shift) + + expected_inds = _expected_largest_inds(oz, n, shift, k) s = dpt.top_k(inp, k, mode="largest") assert s.values.shape == (k,) assert s.values.dtype == inp.dtype assert s.indices.shape == (k,) + assert dpt.all(s.indices == expected_inds) assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values assert dpt.all(s.values == inp[s.indices]), s.indices +def _expected_smallest_inds(inp, n, shift, k): + "Computed expected top_k indices for mode='smallest'" + assert k < n + zeros_start_id = (n + shift) % (2 * n) + zeros_end_id = (shift) % (2 * n) + + alloc_dev = inp.device + + if zeros_start_id < zeros_end_id: + expected_inds = dpt.arange( + zeros_start_id, zeros_start_id + k, dtype="i8", device=alloc_dev + ) + else: + if zeros_end_id >= k: + expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev) + else: + expected_inds = dpt.concat( + ( + dpt.arange(zeros_end_id, dtype="i8", device=alloc_dev), + dpt.arange( + zeros_start_id, + zeros_start_id + k - zeros_end_id, + dtype="i8", + device=alloc_dev, + ), + ) + ) + + return expected_inds + + @pytest.mark.parametrize( "dtype", [ @@ -75,37 +141,70 @@ def test_topk_1d_largest(dtype, n): "c16", ], ) -@pytest.mark.parametrize("n", [33, 255, 257, 513, 1021, 8193]) -def test_topk_1d_smallest(dtype, n): +@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193]) +def test_top_k_1d_smallest(dtype, n): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) + shift, k = 734, 5 o = dpt.ones(n, dtype=dtype) z = dpt.zeros(n, dtype=dtype) oz = dpt.concat((o, z)) - inp = dpt.roll(oz, 734) - k = 5 + inp = dpt.roll(oz, shift) + + expected_inds = _expected_smallest_inds(oz, n, shift, k) s = dpt.top_k(inp, k, mode="smallest") assert s.values.shape == (k,) assert s.values.dtype == inp.dtype assert s.indices.shape == (k,) + assert dpt.all(s.indices == expected_inds) assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values assert dpt.all(s.values == inp[s.indices]), s.indices # triage failing top k radix implementation on CPU # replicates from Python behavior of radix sort topk implementation -@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193]) -def test_topk_largest_1d_radix_i1(n): +@pytest.mark.parametrize( + "n", + [ + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 61, + 137, + 255, + 511, + 1021, + 8193, + ], +) +def test_top_k_largest_1d_radix_i1(n): get_queue_or_skip() dt = "i1" + shift, k = 734, 5 o = dpt.ones(n, dtype=dt) z = dpt.zeros(n, dtype=dt) oz = dpt.concat((o, z)) - inp = dpt.roll(oz, 734) - k = 5 + inp = dpt.roll(oz, shift) + + expected_inds = _expected_largest_inds(oz, n, shift, k) sorted_v = dpt.sort(inp, descending=True, kind="radixsort") argsorted = dpt.argsort(inp, descending=True, kind="radixsort") @@ -116,4 +215,6 @@ def test_topk_largest_1d_radix_i1(n): topk_inds = dpt.copy(argsorted[:k]) assert dpt.all(topk_vals == dpt.ones(k, dtype=dt)) + assert dpt.all(topk_inds == expected_inds) + assert dpt.all(topk_vals == inp[topk_inds]), topk_inds