Skip to content

Commit 44020d6

Browse files
Cherry pick binary search fix for 6.0 (#345)
Co-authored-by: Lőrinc Serfőző <mfep@users.noreply.github.com>
1 parent 5d9e939 commit 44020d6

File tree

3 files changed

+108
-6
lines changed

3 files changed

+108
-6
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ Full documentation for rocThrust is available at [https://rocthrust.readthedocs.
1313
### Removed
1414
- Removed cub symlink from the root of the repository.
1515
- Removed support for deprecated macros (THRUST_DEVICE_BACKEND and THRUST_HOST_BACKEND).
16+
### Fixed
17+
- Fixed a segmentation fault when binary search / upper bound / lower bound / equal range was invoked with `hip_rocprim::execute_on_stream_base` policy.
1618
### Known issues
1719
- For NVIDIA backend, `NV_IF_TARGET` and `THRUST_RDC_ENABLED` intend to substitute the `THRUST_HAS_CUDART` macro, which is now no longer used in Thrust (provided for legacy support only). However, there is no `THRUST_RDC_ENABLED` macro available for the HIP backend, so some branches in Thrust's code may be unreachable in the HIP backend.
1820

test/test_binary_search.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,29 @@ TEST(BinarySearchTests, TestScalarEqualRangeDispatchImplicit)
683683
ASSERT_EQ(13, vec.front());
684684
}
685685

686+
TEST(BinarySearchTests, TestEqualRangeExecutionPolicy)
687+
{
688+
using thrust_exec_policy_t
689+
= thrust::detail::execute_with_allocator<thrust::device_allocator<char>,
690+
thrust::hip_rocprim::execute_on_stream_base>;
691+
692+
constexpr int data[] = {1, 2, 3, 4, 4, 5, 6, 7, 8, 9};
693+
constexpr size_t size = sizeof(data) / sizeof(data[0]);
694+
constexpr int key = 4;
695+
thrust::device_vector<int> d_data(data, data + size);
696+
697+
thrust::pair<thrust::device_vector<int>::iterator, thrust::device_vector<int>::iterator> range
698+
= thrust::equal_range(
699+
thrust_exec_policy_t(thrust::hip_rocprim::execute_on_stream_base<thrust_exec_policy_t>(
700+
hipStreamPerThread),
701+
thrust::device_allocator<char>()),
702+
d_data.begin(),
703+
d_data.end(),
704+
key);
705+
706+
ASSERT_EQ(*range.first, 4);
707+
ASSERT_EQ(*range.second, 5);
708+
}
686709

687710
__global__
688711
THRUST_HIP_LAUNCH_BOUNDS_DEFAULT

thrust/system/hip/detail/binary_search.h

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,12 +464,38 @@ HaystackIt lower_bound(execution_policy<Derived>& policy,
464464
values_type values(policy, 1);
465465
results_type result(policy, 1);
466466

467-
values[0] = value;
467+
{
468+
typedef typename thrust::iterator_system<const T*>::type value_in_system_t;
469+
value_in_system_t value_in_system;
470+
using thrust::system::detail::generic::select_system;
471+
thrust::copy_n(
472+
select_system(
473+
thrust::detail::derived_cast(thrust::detail::strip_const(value_in_system)),
474+
thrust::detail::derived_cast(thrust::detail::strip_const(policy))),
475+
&value,
476+
1,
477+
values.begin());
478+
}
468479

469480
__binary_search::lower_bound(
470481
policy, first, last, values.begin(), values.end(), result.begin(), compare_op);
471482

472-
return first + result[0];
483+
difference_type h_result;
484+
{
485+
typedef
486+
typename thrust::iterator_system<difference_type*>::type result_out_system_t;
487+
result_out_system_t result_out_system;
488+
using thrust::system::detail::generic::select_system;
489+
thrust::copy_n(
490+
select_system(thrust::detail::derived_cast(thrust::detail::strip_const(policy)),
491+
thrust::detail::derived_cast(
492+
thrust::detail::strip_const(result_out_system))),
493+
result.begin(),
494+
1,
495+
&h_result);
496+
}
497+
498+
return first + h_result;
473499
}
474500

475501
__device__
@@ -524,13 +550,39 @@ HaystackIt upper_bound(execution_policy<Derived>& policy,
524550
values_type values(policy, 1);
525551
results_type result(policy, 1);
526552

527-
values[0] = value;
553+
{
554+
typedef typename thrust::iterator_system<const T*>::type value_in_system_t;
555+
value_in_system_t value_in_system;
556+
using thrust::system::detail::generic::select_system;
557+
thrust::copy_n(
558+
select_system(
559+
thrust::detail::derived_cast(thrust::detail::strip_const(value_in_system)),
560+
thrust::detail::derived_cast(thrust::detail::strip_const(policy))),
561+
&value,
562+
1,
563+
values.begin());
564+
}
528565

529566
__binary_search::upper_bound(
530567
policy, first, last, values.begin(), values.end(), result.begin(), compare_op
531568
);
532569

533-
return first + result[0];
570+
difference_type h_result;
571+
{
572+
typedef
573+
typename thrust::iterator_system<difference_type*>::type result_out_system_t;
574+
result_out_system_t result_out_system;
575+
using thrust::system::detail::generic::select_system;
576+
thrust::copy_n(
577+
select_system(thrust::detail::derived_cast(thrust::detail::strip_const(policy)),
578+
thrust::detail::derived_cast(
579+
thrust::detail::strip_const(result_out_system))),
580+
result.begin(),
581+
1,
582+
&h_result);
583+
}
584+
585+
return first + h_result;
534586
}
535587

536588
__device__
@@ -583,13 +635,38 @@ bool binary_search(execution_policy<Derived>& policy,
583635
values_type values(policy, 1);
584636
results_type result(policy, 1);
585637

586-
values[0] = value;
638+
{
639+
typedef typename thrust::iterator_system<const T*>::type value_in_system_t;
640+
value_in_system_t value_in_system;
641+
using thrust::system::detail::generic::select_system;
642+
thrust::copy_n(
643+
select_system(
644+
thrust::detail::derived_cast(thrust::detail::strip_const(value_in_system)),
645+
thrust::detail::derived_cast(thrust::detail::strip_const(policy))),
646+
&value,
647+
1,
648+
values.begin());
649+
}
587650

588651
__binary_search::binary_search(
589652
policy, first, last, values.begin(), values.end(), result.begin(), compare_op
590653
);
591654

592-
return result[0] != 0;
655+
int h_result;
656+
{
657+
typedef typename thrust::iterator_system<int*>::type result_out_system_t;
658+
result_out_system_t result_out_system;
659+
using thrust::system::detail::generic::select_system;
660+
thrust::copy_n(
661+
select_system(thrust::detail::derived_cast(thrust::detail::strip_const(policy)),
662+
thrust::detail::derived_cast(
663+
thrust::detail::strip_const(result_out_system))),
664+
result.begin(),
665+
1,
666+
&h_result);
667+
}
668+
669+
return h_result != 0;
593670
}
594671

595672
__device__

0 commit comments

Comments
 (0)