17
17
#define _KOKKOSKERNELS_SORTING_HPP
18
18
19
19
#include " Kokkos_Core.hpp"
20
+ #include " Kokkos_Sort.hpp"
20
21
#include " KokkosKernels_SimpleUtils.hpp" // for kk_exclusive_parallel_prefix_sum
21
22
#include " KokkosKernels_ExecSpaceUtils.hpp" // for kk_is_gpu_exec_space
22
23
#include < type_traits>
@@ -59,30 +60,13 @@ KOKKOS_INLINE_FUNCTION void SerialRadixSort2(ValueType* values, ValueType* value
59
60
// Team-level parallel sorting (callable inside any TeamPolicy kernel)
60
61
// -------------------------------------------------------------------
61
62
62
- // Comparison based sorting that uses the entire team (described by mem) to sort
63
- // raw array according to the comparator.
64
- template <typename Ordinal, typename ValueType, typename TeamMember,
65
- typename Comparator = Impl::DefaultComparator<ValueType>>
66
- KOKKOS_INLINE_FUNCTION void TeamBitonicSort (ValueType* values, Ordinal n, const TeamMember mem,
67
- const Comparator& comp = Comparator());
68
-
69
- // Same as SerialRadixSort, but also permutes perm[0...n] as it sorts
70
- // values[0...n].
71
- template <typename Ordinal, typename ValueType, typename PermType, typename TeamMember,
72
- typename Comparator = Impl::DefaultComparator<ValueType>>
73
- KOKKOS_INLINE_FUNCTION void TeamBitonicSort2 (ValueType* values, PermType* perm, Ordinal n, const TeamMember mem,
74
- const Comparator& comp = Comparator());
75
-
76
63
namespace Impl {
77
64
78
65
// Functor that sorts a view on one team
79
66
template <typename View, typename Ordinal, typename TeamMember, typename Comparator>
80
67
struct BitonicSingleTeamFunctor {
81
68
BitonicSingleTeamFunctor (View& v_, const Comparator& comp_) : v(v_), comp(comp_) {}
82
- KOKKOS_INLINE_FUNCTION void operator ()(const TeamMember t) const {
83
- KokkosKernels::TeamBitonicSort<Ordinal, typename View::value_type, TeamMember, Comparator>(v.data (), v.extent (0 ), t,
84
- comp);
85
- };
69
+ KOKKOS_INLINE_FUNCTION void operator ()(const TeamMember t) const { Kokkos::Experimental::sort_team (t, v, comp); };
86
70
View v;
87
71
Comparator comp;
88
72
};
@@ -97,8 +81,7 @@ struct BitonicChunkFunctor {
97
81
Ordinal chunkStart = chunk * chunkSize;
98
82
Ordinal n = chunkSize;
99
83
if (chunkStart + n > Ordinal (v.extent (0 ))) n = v.extent (0 ) - chunkStart;
100
- KokkosKernels::TeamBitonicSort<Ordinal, typename View::value_type, TeamMember, Comparator>(v.data () + chunkStart, n,
101
- t, comp);
84
+ Kokkos::Experimental::sort_team (t, Kokkos::subview (v, Kokkos::make_pair (chunkStart, chunkStart + n)), comp);
102
85
};
103
86
View v;
104
87
Comparator comp;
@@ -217,10 +200,11 @@ void bitonicSort(View v, const Comparator& comp) {
217
200
Ordinal npot = 1 ;
218
201
while (npot < n) npot <<= 1 ;
219
202
// Partition the data equally among fixed number of teams
220
- Ordinal chunkSize = 512 ;
221
- Ordinal numTeams = npot / chunkSize;
203
+ Ordinal chunkSize = 512 ;
204
+ Ordinal numTeamsChunkSort = (n + chunkSize - 1 ) / chunkSize;
205
+ Ordinal numTeams = npot / chunkSize;
222
206
// First, sort within teams
223
- Kokkos::parallel_for (team_policy (numTeams , Kokkos::AUTO ()),
207
+ Kokkos::parallel_for (team_policy (numTeamsChunkSort , Kokkos::AUTO ()),
224
208
Impl::BitonicChunkFunctor<View, Ordinal, team_member, Comparator>(v, comp, chunkSize));
225
209
for (int teamsPerBox = 2 ; teamsPerBox <= npot / chunkSize; teamsPerBox *= 2 ) {
226
210
Ordinal boxSize = teamsPerBox * chunkSize;
@@ -388,165 +372,23 @@ KOKKOS_INLINE_FUNCTION void SerialRadixSort2(ValueType* values, ValueType* value
388
372
// trivially-copyable) Pros: In-place, plenty of parallelism for GPUs, and
389
373
// memory references are coalesced Con: O(n log^2(n)) serial time is bad on CPUs
390
374
// Good diagram of the algorithm at https://en.wikipedia.org/wiki/Bitonic_sorter
391
- template <typename Ordinal, typename ValueType, typename TeamMember, typename Comparator>
392
- KOKKOS_INLINE_FUNCTION void TeamBitonicSort (ValueType* values, Ordinal n, const TeamMember mem,
393
- const Comparator& comp) {
394
- // Algorithm only works on power-of-two input size only.
395
- // If n is not a power-of-two, will implicitly pretend
396
- // that values[i] for i >= n is just the max for ValueType, so it never gets
397
- // swapped
398
- Ordinal npot = 1 ;
399
- Ordinal levels = 0 ;
400
- while (npot < n) {
401
- levels++;
402
- npot <<= 1 ;
403
- }
404
- for (Ordinal i = 0 ; i < levels; i++) {
405
- for (Ordinal j = 0 ; j <= i; j++) {
406
- // n/2 pairs of items are compared in parallel
407
- Kokkos::parallel_for (Kokkos::TeamVectorRange (mem, npot / 2 ), [=](const Ordinal t) {
408
- // How big are the brown/pink boxes?
409
- Ordinal boxSize = Ordinal (2 ) << (i - j);
410
- // Which box contains this thread?
411
- Ordinal boxID = t >> (i - j); // t * 2 / boxSize;
412
- Ordinal boxStart = boxID << (1 + i - j); // boxID * boxSize
413
- Ordinal boxOffset = t - (boxStart >> 1 ); // t - boxID * boxSize /
414
- // 2;
415
- Ordinal elem1 = boxStart + boxOffset;
416
- if (j == 0 ) {
417
- // first phase (brown box): within a block, compare with the
418
- // opposite value in the box
419
- Ordinal elem2 = boxStart + boxSize - 1 - boxOffset;
420
- if (elem2 < n) {
421
- // both elements in bounds, so compare them and swap if out of
422
- // order
423
- if (comp (values[elem2], values[elem1])) {
424
- ValueType temp = values[elem1];
425
- values[elem1] = values[elem2];
426
- values[elem2] = temp;
427
- }
428
- }
429
- } else {
430
- // later phases (pink box): within a block, compare with fixed
431
- // distance (boxSize / 2) apart
432
- Ordinal elem2 = elem1 + boxSize / 2 ;
433
- if (elem2 < n) {
434
- if (comp (values[elem2], values[elem1])) {
435
- ValueType temp = values[elem1];
436
- values[elem1] = values[elem2];
437
- values[elem2] = temp;
438
- }
439
- }
440
- }
441
- });
442
- mem.team_barrier ();
443
- }
444
- }
445
- }
446
-
447
- // Sort "values", while applying the same swaps to "perm"
448
- template <typename Ordinal, typename ValueType, typename PermType, typename TeamMember, typename Comparator>
449
- KOKKOS_INLINE_FUNCTION void TeamBitonicSort2 (ValueType* values, PermType* perm, Ordinal n, const TeamMember mem,
450
- const Comparator& comp) {
451
- // Algorithm only works on power-of-two input size only.
452
- // If n is not a power-of-two, will implicitly pretend
453
- // that values[i] for i >= n is just the max for ValueType, so it never gets
454
- // swapped
455
- Ordinal npot = 1 ;
456
- Ordinal levels = 0 ;
457
- while (npot < n) {
458
- levels++;
459
- npot <<= 1 ;
460
- }
461
- for (Ordinal i = 0 ; i < levels; i++) {
462
- for (Ordinal j = 0 ; j <= i; j++) {
463
- // n/2 pairs of items are compared in parallel
464
- Kokkos::parallel_for (Kokkos::TeamVectorRange (mem, npot / 2 ), [=](const Ordinal t) {
465
- // How big are the brown/pink boxes?
466
- Ordinal boxSize = Ordinal (2 ) << (i - j);
467
- // Which box contains this thread?
468
- Ordinal boxID = t >> (i - j); // t * 2 / boxSize;
469
- Ordinal boxStart = boxID << (1 + i - j); // boxID * boxSize
470
- Ordinal boxOffset = t - (boxStart >> 1 ); // t - boxID * boxSize /
471
- // 2;
472
- Ordinal elem1 = boxStart + boxOffset;
473
- if (j == 0 ) {
474
- // first phase (brown box): within a block, compare with the
475
- // opposite value in the box
476
- Ordinal elem2 = boxStart + boxSize - 1 - boxOffset;
477
- if (elem2 < n) {
478
- // both elements in bounds, so compare them and swap if out of
479
- // order
480
- if (comp (values[elem2], values[elem1])) {
481
- ValueType temp1 = values[elem1];
482
- values[elem1] = values[elem2];
483
- values[elem2] = temp1;
484
- PermType temp2 = perm[elem1];
485
- perm[elem1] = perm[elem2];
486
- perm[elem2] = temp2;
487
- }
488
- }
489
- } else {
490
- // later phases (pink box): within a block, compare with fixed
491
- // distance (boxSize / 2) apart
492
- Ordinal elem2 = elem1 + boxSize / 2 ;
493
- if (elem2 < n) {
494
- if (comp (values[elem2], values[elem1])) {
495
- ValueType temp1 = values[elem1];
496
- values[elem1] = values[elem2];
497
- values[elem2] = temp1;
498
- PermType temp2 = perm[elem1];
499
- perm[elem1] = perm[elem2];
500
- perm[elem2] = temp2;
501
- }
502
- }
503
- }
504
- });
505
- mem.team_barrier ();
506
- }
507
- }
508
- }
509
-
510
- // For backward compatibility: keep the public interface accessible in
511
- // KokkosKernels::Impl::
512
- namespace Impl {
513
-
514
- template <typename View, typename ExecSpace, typename Ordinal,
515
- typename Comparator = Impl::DefaultComparator<typename View::value_type>>
516
- [[deprecated]] void bitonicSort (View v, const Comparator& comp = Comparator()) {
517
- KokkosKernels::bitonicSort<View, ExecSpace, Ordinal, Comparator>(v, comp);
518
- }
519
-
520
- template <typename Ordinal, typename ValueType>
521
- [[deprecated]] KOKKOS_INLINE_FUNCTION void SerialRadixSort (ValueType* values, ValueType* valuesAux, Ordinal n) {
522
- KokkosKernels::SerialRadixSort<Ordinal, ValueType>(values, valuesAux, n);
523
- }
524
-
525
- // Same as SerialRadixSort, but also permutes perm[0...n] as it sorts
526
- // values[0...n].
527
- template <typename Ordinal, typename ValueType, typename PermType>
528
- [[deprecated]] KOKKOS_INLINE_FUNCTION void SerialRadixSort2 (ValueType* values, ValueType* valuesAux, PermType* perm,
529
- PermType* permAux, Ordinal n) {
530
- KokkosKernels::SerialRadixSort2<Ordinal, ValueType, PermType>(values, valuesAux, perm, permAux, n);
531
- }
532
-
533
375
template <typename Ordinal, typename ValueType, typename TeamMember,
534
376
typename Comparator = Impl::DefaultComparator<ValueType>>
535
- [[deprecated]] KOKKOS_INLINE_FUNCTION void TeamBitonicSort (ValueType* values, Ordinal n, const TeamMember mem,
536
- const Comparator& comp = Comparator()) {
537
- KokkosKernels::TeamBitonicSort<Ordinal, ValueType, TeamMember, Comparator>(values, n, mem, comp);
377
+ [[deprecated(" Use Kokkos::Experimental::sort_team instead" )]] KOKKOS_INLINE_FUNCTION void TeamBitonicSort (
378
+ ValueType* values, Ordinal n, const TeamMember mem, const Comparator& comp = Comparator()) {
379
+ Kokkos::View<ValueType*, Kokkos::AnonymousSpace> valuesView (values, n);
380
+ Kokkos::Experimental::sort_team (mem, valuesView, comp);
538
381
}
539
382
540
- // Same as SerialRadixSort, but also permutes perm[0...n] as it sorts
541
- // values[0...n].
383
+ // Sort "values", while applying the same swaps to "perm"
542
384
template <typename Ordinal, typename ValueType, typename PermType, typename TeamMember,
543
385
typename Comparator = Impl::DefaultComparator<ValueType>>
544
- [[deprecated]] KOKKOS_INLINE_FUNCTION void TeamBitonicSort2 (ValueType* values, PermType* perm, Ordinal n,
545
- const TeamMember mem,
546
- const Comparator& comp = Comparator()) {
547
- KokkosKernels::TeamBitonicSort2<Ordinal, ValueType, PermType, TeamMember, Comparator>(values, perm, n, mem, comp);
386
+ [[deprecated(" Use Kokkos::Experimental::sort_by_key_team instead" )]] KOKKOS_INLINE_FUNCTION void TeamBitonicSort2 (
387
+ ValueType* values, PermType* perm, Ordinal n, const TeamMember mem, const Comparator& comp = Comparator()) {
388
+ Kokkos::View<ValueType*, Kokkos::AnonymousSpace> valuesView (values, n);
389
+ Kokkos::View<PermType*, Kokkos::AnonymousSpace> permView (perm, n);
390
+ Kokkos::Experimental::sort_by_key_team (mem, valuesView, permView, comp);
548
391
}
549
- } // namespace Impl
550
392
551
393
} // namespace KokkosKernels
552
394
0 commit comments