Skip to content

Commit 2d9b6ff

Browse files
misccoPointKernel
andauthored
Fix a sync bug in stream_ref::wait (#1238) (#1283)
We were calling the wrong function Co-authored-by: Yunsong Wang <yunsongw@nvidia.com>
1 parent c4eda1a commit 2d9b6ff

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

libcudacxx/include/cuda/stream_ref

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ public:
139139
*/
140140
void wait() const
141141
{
142-
const auto __result = ::cudaStreamQuery(get());
142+
const auto __result = ::cudaStreamSynchronize(get());
143143
switch (__result)
144144
{
145145
case ::cudaSuccess:

libcudacxx/test/libcudacxx/cuda/stream_ref/stream_ref.wait.pass.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@
1414

1515
#include <cuda/stream_ref>
1616
#include <cuda/std/cassert>
17+
#include <atomic>
18+
#include <chrono>
19+
#include <thread>
20+
21+
void CUDART_CB callback(cudaStream_t, cudaError_t, void* flag)
22+
{
23+
std::chrono::milliseconds sleep_duration{1000};
24+
std::this_thread::sleep_for(sleep_duration);
25+
assert(!reinterpret_cast<std::atomic_flag*>(flag)->test_and_set());
26+
}
1727

1828
void test_wait(cuda::stream_ref& ref) {
1929
#ifndef _LIBCUDACXX_NO_EXCEPTIONS
@@ -31,8 +41,11 @@ int main(int argc, char** argv) {
3141
NV_IF_TARGET(NV_IS_HOST,( // passing case
3242
cudaStream_t stream;
3343
cudaStreamCreate(&stream);
44+
std::atomic_flag flag = ATOMIC_FLAG_INIT;
45+
cudaStreamAddCallback(stream, callback, &flag, 0);
3446
cuda::stream_ref ref{stream};
3547
test_wait(ref);
48+
assert(flag.test_and_set());
3649
cudaStreamDestroy(stream);
3750
))
3851

0 commit comments

Comments
 (0)