From a4f5499503b41c90797c16f367aef805c8b2fa7c Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Wed, 22 Oct 2025 12:14:15 +0200 Subject: [PATCH 01/10] Update oomph for hwmalloc heap-config branch --- ext/oomph | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/oomph b/ext/oomph index 37db2ca5..d53e1ede 160000 --- a/ext/oomph +++ b/ext/oomph @@ -1 +1 @@ -Subproject commit 37db2ca5c7c11050b66fcd11c90e7436f0b8ff39 +Subproject commit d53e1edef937b9a6a102dbced958edd382461954 From 4746b35f73507f05a8fbc3b4011a2d8b7f346ac3 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 24 Oct 2025 11:46:11 +0200 Subject: [PATCH 02/10] Update oomph submodules --- ext/oomph | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/oomph b/ext/oomph index d53e1ede..7055358c 160000 --- a/ext/oomph +++ b/ext/oomph @@ -1 +1 @@ -Subproject commit d53e1edef937b9a6a102dbced958edd382461954 +Subproject commit 7055358ca2a9c136a61603448a7b93c49d54ee3e From d26c05ae640190f46e309cd916eeb72f5398a263 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 24 Oct 2025 14:28:35 +0200 Subject: [PATCH 03/10] Refactor pack/unpack kernels --- include/ghex/unstructured/user_concepts.hpp | 157 +++++++++++++++----- 1 file changed, 121 insertions(+), 36 deletions(-) diff --git a/include/ghex/unstructured/user_concepts.hpp b/include/ghex/unstructured/user_concepts.hpp index 280872a2..66becac5 100644 --- a/include/ghex/unstructured/user_concepts.hpp +++ b/include/ghex/unstructured/user_concepts.hpp @@ -454,39 +454,71 @@ class data_descriptor #ifdef GHEX_CUDACC -#define GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK 32 +#define GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_X 32 +#define GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_Y 8 template __global__ void -pack_kernel(const T* values, const std::size_t local_indices_size, +pack_kernel_levels_first(const T* values, const std::size_t local_indices_size, const std::size_t* local_indices, const std::size_t levels, T* buffer, - const std::size_t index_stride, const std::size_t level_stride, - const std::size_t buffer_index_stride, const std::size_t buffer_level_stride) + const std::size_t index_stride, const std::size_t buffer_index_stride) +{ + const std::size_t level = threadIdx.x + (blockIdx.x * blockDim.x); + const std::size_t idx = threadIdx.y + (blockIdx.y * blockDim.y); + + if (idx < local_indices_size && level < levels) + { + auto const local_index = local_indices[idx]; + buffer[idx * buffer_index_stride + level] = values[local_index * index_stride + level]; + } +} + +template +__global__ void +pack_kernel_levels_last(const T* values, const std::size_t local_indices_size, + const std::size_t* local_indices, const std::size_t levels, T* buffer, + const std::size_t level_stride, const std::size_t buffer_level_stride) { const std::size_t idx = threadIdx.x + (blockIdx.x * blockDim.x); - if (idx < local_indices_size) + const std::size_t level = threadIdx.y + (blockIdx.y * blockDim.y); + + if (idx < local_indices_size && level < levels) { - for (std::size_t level = 0; level < levels; ++level) - { - buffer[idx * buffer_index_stride + level * buffer_level_stride] = values[local_indices[idx] * index_stride + level * level_stride]; - } + auto const local_index = local_indices[idx]; + buffer[idx + level * buffer_level_stride] = values[local_index + level * level_stride]; } } template __global__ void -unpack_kernel(const T* buffer, const std::size_t local_indices_size, +unpack_kernel_levels_first(const T* buffer, const std::size_t local_indices_size, const std::size_t* local_indices, const std::size_t levels, T* values, - const std::size_t index_stride, const std::size_t level_stride, - const std::size_t buffer_index_stride, const std::size_t buffer_level_stride) + + const std::size_t index_stride, const std::size_t buffer_index_stride) +{ + const std::size_t level = threadIdx.x + (blockIdx.x * blockDim.x); + const std::size_t idx = threadIdx.y + (blockIdx.y * blockDim.y); + + if (idx < local_indices_size && level < levels) + { + auto const local_index = local_indices[idx]; + values[local_index * index_stride + level] = buffer[idx * buffer_index_stride + level]; + } +} + +template +__global__ void +unpack_kernel_levels_last(const T* buffer, const std::size_t local_indices_size, + const std::size_t* local_indices, const std::size_t levels, T* values, + const std::size_t level_stride, const std::size_t buffer_level_stride) { const std::size_t idx = threadIdx.x + (blockIdx.x * blockDim.x); - if (idx < local_indices_size) + const std::size_t level = threadIdx.y + (blockIdx.y * blockDim.y); + + if (idx < local_indices_size && level < levels) { - for (std::size_t level = 0; level < levels; ++level) - { - values[local_indices[idx] * index_stride + level * level_stride] = buffer[idx * buffer_index_stride + level * buffer_level_stride]; - } + auto const local_index = local_indices[idx]; + values[local_index + level * level_stride] = buffer[idx + level * buffer_level_stride]; } } @@ -522,7 +554,8 @@ class data_descriptor * @param outer_stride outer dimension's stride measured in number of elements of type T (special value 0: no padding) * @param device_id device id*/ data_descriptor(const domain_descriptor_type& domain, value_type* field, - std::size_t levels = 1u, bool levels_first = true, std::size_t outer_stride = 0u, device_id_type device_id = arch_traits::current_id()) + std::size_t levels = 1u, bool levels_first = true, std::size_t outer_stride = 0u, + device_id_type device_id = arch_traits::current_id()) : m_device_id{device_id} , m_domain_id{domain.domain_id()} , m_domain_size{domain.size()} @@ -549,34 +582,86 @@ class data_descriptor template void pack(value_type* buffer, const IndexContainer& c, void* stream_ptr) { + const dim3 threads_per_block(GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_X, + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_Y); + for (const auto& is : c) { - const int n_blocks = - static_cast(std::ceil(static_cast(is.local_indices().size()) / - GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK)); - const std::size_t buffer_index_stride = m_levels_first ? m_levels : 1u; - const std::size_t buffer_level_stride = m_levels_first ? 1u : is.local_indices().size(); - pack_kernel<<(stream_ptr))>>>(m_values, - is.local_indices().size(), is.local_indices().data(), m_levels, buffer, - m_index_stride, m_level_stride, buffer_index_stride, buffer_level_stride); + if (m_levels_first) + { + const int blocks_levels = static_cast( + std::ceil(static_cast(m_levels) / + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_X)); + const int blocks_indices = static_cast( + std::ceil(static_cast(is.local_indices().size()) / + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_Y)); + + const dim3 blocks(blocks_levels, blocks_indices); + + pack_kernel_levels_first<<(stream_ptr))>>>(m_values, + is.local_indices().size(), is.local_indices().data(), m_levels, buffer, + m_index_stride, m_levels); + } + else + { + const int blocks_indices = static_cast( + std::ceil(static_cast(is.local_indices().size()) / + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_X)); + const int blocks_levels = static_cast( + std::ceil(static_cast(m_levels) / + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_Y)); + + const dim3 blocks(blocks_indices, blocks_levels); + + pack_kernel_levels_last<<(stream_ptr))>>>(m_values, + is.local_indices().size(), is.local_indices().data(), m_levels, buffer, + m_level_stride, is.local_indices().size()); + } } } template void unpack(const value_type* buffer, const IndexContainer& c, void* stream_ptr) { + const dim3 threads_per_block(GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_X, + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_Y); + for (const auto& is : c) { - const int n_blocks = - static_cast(std::ceil(static_cast(is.local_indices().size()) / - GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK)); - const std::size_t buffer_index_stride = m_levels_first ? m_levels : 1u; - const std::size_t buffer_level_stride = m_levels_first ? 1u : is.local_indices().size(); - unpack_kernel<<(stream_ptr))>>>(buffer, - is.local_indices().size(), is.local_indices().data(), m_levels, m_values, - m_index_stride, m_level_stride, buffer_index_stride, buffer_level_stride); + if (m_levels_first) + { + const int blocks_levels = static_cast( + std::ceil(static_cast(m_levels) / + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_X)); + const int blocks_indices = static_cast( + std::ceil(static_cast(is.local_indices().size()) / + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_Y)); + + const dim3 blocks(blocks_levels, blocks_indices); + + unpack_kernel_levels_first<<(stream_ptr))>>>(buffer, + is.local_indices().size(), is.local_indices().data(), m_levels, m_values, + m_index_stride, m_levels); + } + else + { + const int blocks_indices = static_cast( + std::ceil(static_cast(is.local_indices().size()) / + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_X)); + const int blocks_levels = static_cast( + std::ceil(static_cast(m_levels) / + GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_Y)); + + const dim3 blocks(blocks_indices, blocks_levels); + + unpack_kernel_levels_last<<(stream_ptr))>>>(buffer, + is.local_indices().size(), is.local_indices().data(), m_levels, m_values, + m_level_stride, is.local_indices().size()); + } } } }; From 99fe0a0c457c70bfb8f47de60f3b2f775754d4c6 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 24 Oct 2025 16:15:24 +0200 Subject: [PATCH 04/10] Try 1d block for pack/unpack --- include/ghex/unstructured/user_concepts.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ghex/unstructured/user_concepts.hpp b/include/ghex/unstructured/user_concepts.hpp index 66becac5..01214880 100644 --- a/include/ghex/unstructured/user_concepts.hpp +++ b/include/ghex/unstructured/user_concepts.hpp @@ -455,7 +455,7 @@ class data_descriptor #ifdef GHEX_CUDACC #define GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_X 32 -#define GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_Y 8 +#define GHEX_UNSTRUCTURED_SERIALIZATION_THREADS_PER_BLOCK_Y 1 template __global__ void From 527d590234399052f7bb9ef77d8c876fc5fb1ff6 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Wed, 15 Oct 2025 10:20:56 +0200 Subject: [PATCH 05/10] Add dumb nccl implementation --- CMakeLists.txt | 1 + cmake/config.hpp.in | 1 + cmake/ghex_external_dependencies.cmake | 9 + include/ghex/communication_object.hpp | 263 ++++++++++++++++++++++--- include/ghex/packer.hpp | 111 +++++++++++ 5 files changed, 358 insertions(+), 27 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index dae28823..630e82a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,6 +52,7 @@ set(GHEX_ENABLE_ATLAS_BINDINGS OFF CACHE BOOL "Set to true to build with Atlas b set(GHEX_BUILD_FORTRAN OFF CACHE BOOL "True if FORTRAN bindings shall be built") set(GHEX_BUILD_PYTHON_BINDINGS OFF CACHE BOOL "Set to true to build Python bindings") set(GHEX_WITH_TESTING OFF CACHE BOOL "True if tests shall be built") +set(GHEX_USE_NCCL ON CACHE BOOL "Use NCCL") # --------------------------------------------------------------------- # Common includes diff --git a/cmake/config.hpp.in b/cmake/config.hpp.in index 69761668..9f370893 100644 --- a/cmake/config.hpp.in +++ b/cmake/config.hpp.in @@ -21,6 +21,7 @@ #cmakedefine GHEX_USE_XPMEM #cmakedefine GHEX_USE_XPMEM_ACCESS_GUARD #cmakedefine GHEX_USE_GPU +#cmakedefine GHEX_USE_NCCL #define GHEX_GPU_MODE @ghex_gpu_mode@ #cmakedefine GHEX_GPU_MODE_EMULATE #define @GHEX_DEVICE@ diff --git a/cmake/ghex_external_dependencies.cmake b/cmake/ghex_external_dependencies.cmake index 32c40fe4..6779e406 100644 --- a/cmake/ghex_external_dependencies.cmake +++ b/cmake/ghex_external_dependencies.cmake @@ -94,6 +94,15 @@ if (GHEX_USE_XPMEM) find_package(XPMEM REQUIRED) endif() + +# --------------------------------------------------------------------- +# nccl setup +# --------------------------------------------------------------------- +if(GHEX_USE_NCCL) + link_libraries("-lnccl") + # include_directories("") +endif() + # --------------------------------------------------------------------- # parmetis setup # --------------------------------------------------------------------- diff --git a/include/ghex/communication_object.hpp b/include/ghex/communication_object.hpp index d49cd1a4..b2bbb4b7 100644 --- a/include/ghex/communication_object.hpp +++ b/include/ghex/communication_object.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -23,6 +24,9 @@ #include #include #include +#ifdef GHEX_USE_NCCL +#include +#endif namespace ghex { @@ -207,23 +211,106 @@ class communication_object using disable_if_buffer_info = std::enable_if_t::value, R>; private: // members + ghex::util::moved_bit m_moved; bool m_valid; communicator_type m_comm; memory_type m_mem; std::vector m_send_reqs; std::vector m_recv_reqs; + ncclComm_t m_nccl_comm; public: // ctors communication_object(context& c) : m_valid(false) , m_comm(c.transport_context()->get_communicator()) { + // ncclConfig_t config = NCCL_CONFIG_INITIALIZER; + // config.blocking = 0; + ncclUniqueId id; + if (m_comm.rank() == 0) { + ncclGetUniqueId(&id); + } + MPI_Comm mpi_comm = m_comm.mpi_comm(); + + // std::ostringstream msg; + // msg << "doing MPI_Bcast on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; + // std::cerr << msg.str(); + + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, mpi_comm); + // TODO: Is this needed? + MPI_Barrier(mpi_comm); + + // std::ostringstream msg_done; + // msg_done << "finished MPI_Bcast on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; + // std::cerr << msg_done.str(); + + // std::ostringstream msg_init; + // msg_init << "initializing nccl communicator on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; + // std::cerr << msg_init.str(); + + // GHEX_CHECK_NCCL_RESULT(ncclCommInitRankConfig(&m_nccl_comm, m_comm.size(), id, m_comm.rank(), &config)); + GHEX_CHECK_NCCL_RESULT(ncclCommInitRank(&m_nccl_comm, m_comm.size(), id, m_comm.rank())); + ncclResult_t state; + do { + // std::ostringstream msg_ready; + // msg_ready << "checking if nccl communicator init is still in progress on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; + // std::cerr << msg_ready.str(); + + GHEX_CHECK_NCCL_RESULT(ncclCommGetAsyncError(m_nccl_comm, &state)); + } while(state == ncclInProgress); + + // std::ostringstream msg_init_done; + // msg_init_done << "nccl communicator init done on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; + // std::cerr << msg_init_done.str(); + // GHEX_CHECK_CUDA_RESULT(cudaDeviceSynchronize()); + } + ~communication_object() noexcept { + // TODO: nothrow + // std::ostringstream msg_destroy; + // msg_destroy << "~communication_object destroying nccl communicator"; + // if (m_moved) { + // msg_destroy << ", comm is valid\n"; + // GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaDeviceSynchronize()); + // GHEX_CHECK_NCCL_RESULT_NO_THROW(ncclCommDestroy(m_nccl_comm)); + // } else { + // msg_destroy << ", comm is moved, skipping ncclCommDestroy\n"; + // } + // std::cerr << msg_destroy.str(); } communication_object(const communication_object&) = delete; communication_object(communication_object&&) = default; communicator_type& communicator() { return m_comm; } + private: + template + void nccl_exchange_impl(buffer_info_type... buffer_infos) { + // GHEX_CHECK_CUDA_RESULT(cudaDeviceSynchronize()); + // pack + // send + // std::cerr << "starting packing\n"; + for_each(m_mem, [this](std::size_t, auto& m) { + using arch_type = typename std::remove_reference_t::arch_type; + packer::pack2_nccl(m, m_send_reqs, m_comm); + }); + // std::cerr << "packing done\n"; + + // std::cerr << "starting group\n"; + ncclGroupStart(); + post_sends_nccl(); + + // recv + // unpack + // std::cerr << "starting recvs\n"; + post_recvs_nccl(); + // std::cerr << "recvs done\n"; + ncclGroupEnd(); + // std::cerr << "ending group\n"; + unpack_nccl(); + // GHEX_CHECK_CUDA_RESULT(cudaDeviceSynchronize()); + } + + public: // exchange arbitrary field-device-pattern combinations /** @brief non-blocking exchange of halo data * @tparam Archs list of device types @@ -233,9 +320,20 @@ class communication_object template [[nodiscard]] handle_type exchange(buffer_info_type... buffer_infos) { + // std::cerr << "using first exchange overload\n"; exchange_impl(buffer_infos...); - post_recvs(); - pack(); + nccl_exchange_impl(); + // // TODO: Assymetry here. + // // + // // post_recvs iterates through memory and fields here in the + // // communication object, installs callbacks for unpacking per field + // // (though one loop remains inside unpack). + // // + // // pack passes send_reqs and comm to pack, which does the iterating and + // // installing callback. pack, however, waits for packs to complete to + // // trigger sends. + // post_recvs(); + // pack(); return {this}; } @@ -248,6 +346,7 @@ class communication_object [[nodiscard]] disable_if_buffer_info exchange( Iterator first, Iterator last) { + // std::cerr << "using exchange_u overload\n"; // call special function for a single range return exchange_u(first, last); } @@ -266,6 +365,7 @@ class communication_object [[nodiscard]] disable_if_buffer_info exchange( Iterator0 first0, Iterator0 last0, Iterator1 first1, Iterator1 last1, Iterators... iters) { + // std::cerr << "using exchange with iterators overload\n"; static_assert( sizeof...(Iterators) % 2 == 0, "need even number of iteratiors: (begin,end) pairs"); // call helper function to turn iterators into pairs of iterators @@ -278,9 +378,12 @@ class communication_object template [[nodiscard]] handle_type exchange(std::pair... iter_pairs) { + // std::cerr << "using private exchange with iterators overload\n"; + exchange_impl(iter_pairs...); - post_recvs(); - pack(); + nccl_exchange_impl(); + // post_recvs(); + // pack(); return {this}; } @@ -304,6 +407,7 @@ class communication_object #endif exchange_u(Iterator first, Iterator last) { + // std::cerr << "using private exchange_u with iterators overload\n"; // call exchange with a pair of iterators return exchange(std::make_pair(first, last)); } @@ -355,6 +459,7 @@ class communication_object template void exchange_impl(std::pair... iter_pairs) { + // std::cerr << "using first exchange_impl overload\n"; const std::tuple...> iter_pairs_t{iter_pairs...}; if (m_valid) throw std::runtime_error("earlier exchange operation was not finished"); @@ -386,12 +491,14 @@ class communication_object mem, it->get_pattern(), field_ptr, my_dom_id, it->device_id(), tag_offset); } }); + // std::cerr << "done in first exchange_impl overload\n"; } // helper function to set up communicaton buffers (compile-time case) template void exchange_impl(buffer_info_type... buffer_infos) { + // std::cerr << "using second exchange_impl overload\n"; // check that arguments are compatible using test_t = pattern_container; static_assert( @@ -462,6 +569,108 @@ class communication_object }); } + void post_sends_nccl() + { + for_each(m_mem, [this](std::size_t, auto& map) { + for (auto& p0 : map.send_memory) + { + const auto device_id = p0.first; + for (auto& p1 : p0.second) + { + if (p1.second.size > 0u) + { + device::guard g(p1.second.buffer); + GHEX_CHECK_NCCL_RESULT(ncclSend(static_cast(g.data()), p1.second.buffer.size() /* * sizeof(typename decltype(p1.second.buffer)::value_type) */, ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); + } + } + } + }); + } + + void post_recvs_nccl() + { + for_each(m_mem, [this](std::size_t, auto& m) { + using arch_type = typename std::remove_reference_t::arch_type; + for (auto& p0 : m.recv_memory) + { + const auto device_id = p0.first; + for (auto& p1 : p0.second) + { + if (p1.second.size > 0u) + { + if (!p1.second.buffer || p1.second.buffer.size() != p1.second.size +#if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE) + || p1.second.buffer.device_id() != device_id +#endif + ) + // std::cerr << "post_recvs_nccl: making message\n"; + p1.second.buffer = arch_traits::make_message( + m_comm, p1.second.size, device_id); + // std::cerr << "post_recvs_nccl: triggering ncclRecv\n"; + // std::cerr << "post_recvs_nccl: ptr is " << static_cast(p1.second.buffer.device_data()) << "\n"; + GHEX_CHECK_NCCL_RESULT(ncclRecv(p1.second.buffer.device_data(), p1.second.buffer.size() /* * sizeof(typename decltype(p1.second.buffer)::value_type) */, ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); + // std::cerr << "post_recvs_nccl: triggered ncclRecv\n"; + device::guard g(p1.second.buffer); + // std::cerr << "post_recvs_nccl: triggering unpack\n"; + // TODO: This doesn't seem to happen after the recv, schedule outside ncclCommGroup? + // packer::unpack(p1.second, g.data()); + // std::cerr << "post_recvs_nccl: triggered unpack\n"; + + // use callbacks for unpacking + // m_recv_reqs.push_back(m_comm.recv(p1.second.buffer, p1.second.rank, + // p1.second.tag, + // [ptr](context::message_type& m, context::rank_type, context::tag_type) { + // device::guard g(m); + // packer::unpack(*ptr, g.data()); + // })); + } + } + } + }); + } + + void unpack_nccl() + { + for_each(m_mem, [this](std::size_t, auto& m) { + using arch_type = typename std::remove_reference_t::arch_type; + for (auto& p0 : m.recv_memory) + { + const auto device_id = p0.first; + for (auto& p1 : p0.second) + { + if (p1.second.size > 0u) + { + if (!p1.second.buffer || p1.second.buffer.size() != p1.second.size +#if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE) + || p1.second.buffer.device_id() != device_id +#endif + ) + // std::cerr << "post_recvs_nccl: making message\n"; + p1.second.buffer = arch_traits::make_message( + m_comm, p1.second.size, device_id); + // std::cerr << "post_recvs_nccl: triggering ncclRecv\n"; + // std::cerr << "post_recvs_nccl: ptr is " << static_cast(p1.second.buffer.device_data()) << "\n"; + // GHEX_CHECK_NCCL_RESULT(ncclRecv(p1.second.buffer.device_data(), p1.second.buffer.size() /* * sizeof(typename decltype(p1.second.buffer)::value_type) */, ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); + // std::cerr << "post_recvs_nccl: triggered ncclRecv\n"; + device::guard g(p1.second.buffer); + // std::cerr << "post_recvs_nccl: triggering unpack\n"; + // TODO: This doesn't seem to happen after the recv, schedule outside ncclCommGroup? + packer::unpack(p1.second, g.data()); + // std::cerr << "post_recvs_nccl: triggered unpack\n"; + + // use callbacks for unpacking + // m_recv_reqs.push_back(m_comm.recv(p1.second.buffer, p1.second.rank, + // p1.second.tag, + // [ptr](context::message_type& m, context::rank_type, context::tag_type) { + // device::guard g(m); + // packer::unpack(*ptr, g.data()); + // })); + } + } + } + }); + } + void pack() { for_each(m_mem, [this](std::size_t, auto& m) { @@ -473,38 +682,38 @@ class communication_object private: // wait functions void progress() { - if (!m_valid) return; - m_comm.progress(); + // if (!m_valid) return; + // m_comm.progress(); } bool is_ready() { - if (!m_valid) return true; - if (m_comm.is_ready()) - { -#ifdef GHEX_CUDACC - sync_streams(); -#endif - clear(); - return true; - } - m_comm.progress(); - if (m_comm.is_ready()) - { -#ifdef GHEX_CUDACC - sync_streams(); -#endif - clear(); - return true; - } + // if (!m_valid) return true; +// if (m_comm.is_ready()) +// { +// #ifdef GHEX_CUDACC +// sync_streams(); +// #endif +// clear(); +// return true; +// } +// m_comm.progress(); +// if (m_comm.is_ready()) +// { +// #ifdef GHEX_CUDACC +// sync_streams(); +// #endif +// clear(); +// return true; +// } return false; } void wait() { - if (!m_valid) return; - // wait for data to arrive (unpack callback will be invoked) - m_comm.wait_all(); +// if (!m_valid) return; +// // wait for data to arrive (unpack callback will be invoked) +// m_comm.wait_all(); #ifdef GHEX_CUDACC sync_streams(); #endif diff --git a/include/ghex/packer.hpp b/include/ghex/packer.hpp index a1475ad7..b70409ef 100644 --- a/include/ghex/packer.hpp +++ b/include/ghex/packer.hpp @@ -20,6 +20,19 @@ #include #endif +#ifdef GHEX_USE_NCCL +#include + +#define GHEX_CHECK_NCCL_RESULT(x) \ + if (x != ncclSuccess && x != ncclInProgress) \ + throw std::runtime_error(std::string("nccl call failed (") + std::to_string(x) + "):" + ncclGetErrorString(x)); +#define GHEX_CHECK_NCCL_RESULT_NO_THROW(x) \ + if (x != ncclSuccess && x != ncclInProgress) { \ + std::cerr << "nccl call failed (" << std::to_string(x) << "): " << ncclGetErrorString(x) << '\n'; \ + std::terminate(); \ + } +#endif + #include namespace ghex @@ -51,6 +64,15 @@ struct packer } } + template + static void pack2(Map& map, Requests& send_reqs, Communicator& comm) + { + pack(map, send_reqs, comm); + } + + template + static void pack2_nccl(Map&, Requests&, Communicator&) {} + template static void unpack(Buffer& buffer, unsigned char* data) { @@ -163,6 +185,95 @@ struct packer }); } + template + static void pack2_nccl(Map& map, Requests&, Communicator& comm) + { +#if 0 + constexpr std::size_t num_extra_streams{32}; + static std::vector streams(num_extra_streams); + static std::size_t stream_index{0}; +#endif + + constexpr std::size_t num_events{128}; + static std::vector events(num_events); + static std::size_t event_index{0}; + + // Assume that send memory synchronizes with the default + // stream so schedule pack kernels after an event on the + // default stream. + cudaEvent_t& e = events[event_index].get(); + event_index = (event_index + 1) % num_events; + GHEX_CHECK_CUDA_RESULT(cudaEventRecord(e, 0)); + for (auto& p0 : map.send_memory) + { + const auto device_id = p0.first; + for (auto& p1 : p0.second) + { + if (p1.second.size > 0u) + { + if (!p1.second.buffer || p1.second.buffer.size() != p1.second.size || + p1.second.buffer.device_id() != device_id) + { + // std::cerr << "pack2_nccl: making message\n"; + p1.second.buffer = + arch_traits::make_message(comm, p1.second.size, device_id); + } + + device::guard g(p1.second.buffer); +#if 0 + int count = 0; +#endif + // Make sure stream used for packing synchronizes with the + // default stream. + GHEX_CHECK_CUDA_RESULT(cudaStreamWaitEvent(p1.second.m_stream.get(), e)); + for (const auto& fb : p1.second.field_infos) + { + // TODO: + // 1. launch pack kernels on separate streams for all data + // 1. (alternative) pack them all into the same kernel + // 2. trigger the send from a cuda host function + // 3. don't wait for futures here, but mixed with polling mpi for receives +#if 0 + if (count == 0) { +#endif + // std::cerr << "pack2_nccl: calling pack call_back\n"; + fb.call_back(g.data() + fb.offset, *fb.index_container, (void*)(&p1.second.m_stream.get())); +#if 0 + } else { + cudaStream_t& s = streams[stream_index].get(); + stream_index = (stream_index + 1) % num_extra_streams; + + cudaEvent_t& e = events[event_index].get(); + event_index = (event_index + 1) % num_events; + + fb.call_back(g.data() + fb.offset, *fb.index_container, (void*)(&s)); + + // Use the main stream only to synchronize. Launch + // the work on a separate stream and insert an event + // to allow waiting for all work on the main stream. + GHEX_CHECK_CUDA_RESULT(cudaEventRecord(e, s)); + GHEX_CHECK_CUDA_RESULT(cudaStreamWaitEvent(p1.second.m_stream.get(), e)); + } + ++count; +#endif + } + + // Warning: tag is not used. Messages have to be correctly ordered. + // This is just for debugging, don't do mpi and nccl send + // std::cerr << "pack2_nccl: triggering mpi_isend\n"; + // comm.send(p1.second.buffer, p1.second.rank, p1.second.tag); + // std::cerr << "pack2_nccl: triggering ncclSend\n"; + // std::cerr << "pack2_nccl: ptr is " << static_cast(p1.second.buffer.device_data()) << "\n"; + // std::cerr << "pack2_nccl: g.data() is " << static_cast(g.data()) << "\n"; + // std::cerr << "pack2_nccl: size is " << p1.second.buffer.size() << "\n"; + // std::cerr << "pack2_nccl: ptr on device " << p1.second.buffer.on_device() << "\n"; + // GHEX_CHECK_NCCL_RESULT(ncclSend(static_cast(g.data()), p1.second.buffer.size() /* * sizeof(typename decltype(p1.second.buffer)::value_type) */, ncclChar, p1.second.rank, nccl_comm, p1.second.m_stream.get())); + // std::cerr << "pack2_nccl: triggered ncclSend\n"; + } + } + } + } + template static void unpack(Buffer& buffer, unsigned char* data) { From 78879bb3397127e0320afe4e15902a385401976e Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 24 Oct 2025 19:15:16 +0200 Subject: [PATCH 06/10] Add back cuda event class --- include/ghex/device/cuda/stream.hpp | 32 ++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/include/ghex/device/cuda/stream.hpp b/include/ghex/device/cuda/stream.hpp index 5aa75ef0..eb5ea37a 100644 --- a/include/ghex/device/cuda/stream.hpp +++ b/include/ghex/device/cuda/stream.hpp @@ -19,17 +19,41 @@ namespace ghex { namespace device { +struct cuda_event { + cudaEvent_t m_event; + ghex::util::moved_bit m_moved; + + cuda_event() { + GHEX_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming)) + } + cuda_event(const cuda_event&) = delete; + cuda_event& operator=(const cuda_event&) = delete; + cuda_event(cuda_event&& other) = default; + cuda_event& operator=(cuda_event&&) = default; + + ~cuda_event() + { + if (!m_moved) + { + GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event)) + } + } + + operator bool() const noexcept { return m_moved; } + operator cudaEvent_t() const noexcept { return m_event; } + cudaEvent_t& get() noexcept { return m_event; } + const cudaEvent_t& get() const noexcept { return m_event; } +}; + /** @brief thin wrapper around a cuda stream */ struct stream { cudaStream_t m_stream; - cudaEvent_t m_event; ghex::util::moved_bit m_moved; stream() { GHEX_CHECK_CUDA_RESULT(cudaStreamCreateWithFlags(&m_stream, cudaStreamNonBlocking)) - GHEX_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming)) } stream(const stream&) = delete; @@ -42,7 +66,6 @@ struct stream if (!m_moved) { GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaStreamDestroy(m_stream)) - GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event)) } } @@ -55,9 +78,8 @@ struct stream void sync() { - GHEX_CHECK_CUDA_RESULT(cudaEventRecord(m_event, m_stream)) // busy wait here - GHEX_CHECK_CUDA_RESULT(cudaEventSynchronize(m_event)) + GHEX_CHECK_CUDA_RESULT(cudaStreamSynchronize(m_stream)) } }; } // namespace device From f314a1cbdf431e69df77c6751939d95c093d0fd6 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 24 Oct 2025 20:26:22 +0200 Subject: [PATCH 07/10] Add TODO for nccl in cmake --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 630e82a3..1f3ffa41 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,6 +52,7 @@ set(GHEX_ENABLE_ATLAS_BINDINGS OFF CACHE BOOL "Set to true to build with Atlas b set(GHEX_BUILD_FORTRAN OFF CACHE BOOL "True if FORTRAN bindings shall be built") set(GHEX_BUILD_PYTHON_BINDINGS OFF CACHE BOOL "Set to true to build Python bindings") set(GHEX_WITH_TESTING OFF CACHE BOOL "True if tests shall be built") +# TODO: Add FindNCCL.cmake module. set(GHEX_USE_NCCL ON CACHE BOOL "Use NCCL") # --------------------------------------------------------------------- From ab0dfd0a91d46bf553931e1473b9ac1628fcaad0 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 24 Oct 2025 20:36:10 +0200 Subject: [PATCH 08/10] Clean up nccl parts --- include/ghex/communication_object.hpp | 207 +++++++++----------------- 1 file changed, 71 insertions(+), 136 deletions(-) diff --git a/include/ghex/communication_object.hpp b/include/ghex/communication_object.hpp index b2bbb4b7..0d5db041 100644 --- a/include/ghex/communication_object.hpp +++ b/include/ghex/communication_object.hpp @@ -24,6 +24,7 @@ #include #include #include + #ifdef GHEX_USE_NCCL #include #endif @@ -214,68 +215,37 @@ class communication_object ghex::util::moved_bit m_moved; bool m_valid; communicator_type m_comm; +#ifdef GHEX_USE_NCCL + ncclComm_t m_nccl_comm; +#endif memory_type m_mem; std::vector m_send_reqs; std::vector m_recv_reqs; - ncclComm_t m_nccl_comm; public: // ctors communication_object(context& c) : m_valid(false) , m_comm(c.transport_context()->get_communicator()) { - // ncclConfig_t config = NCCL_CONFIG_INITIALIZER; - // config.blocking = 0; ncclUniqueId id; if (m_comm.rank() == 0) { ncclGetUniqueId(&id); } MPI_Comm mpi_comm = m_comm.mpi_comm(); - // std::ostringstream msg; - // msg << "doing MPI_Bcast on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; - // std::cerr << msg.str(); - MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, mpi_comm); - // TODO: Is this needed? - MPI_Barrier(mpi_comm); - - // std::ostringstream msg_done; - // msg_done << "finished MPI_Bcast on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; - // std::cerr << msg_done.str(); - // std::ostringstream msg_init; - // msg_init << "initializing nccl communicator on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; - // std::cerr << msg_init.str(); - - // GHEX_CHECK_NCCL_RESULT(ncclCommInitRankConfig(&m_nccl_comm, m_comm.size(), id, m_comm.rank(), &config)); GHEX_CHECK_NCCL_RESULT(ncclCommInitRank(&m_nccl_comm, m_comm.size(), id, m_comm.rank())); ncclResult_t state; do { - // std::ostringstream msg_ready; - // msg_ready << "checking if nccl communicator init is still in progress on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; - // std::cerr << msg_ready.str(); - GHEX_CHECK_NCCL_RESULT(ncclCommGetAsyncError(m_nccl_comm, &state)); } while(state == ncclInProgress); - - // std::ostringstream msg_init_done; - // msg_init_done << "nccl communicator init done on rank " << m_comm.rank() << "/" << m_comm.size() << '\n'; - // std::cerr << msg_init_done.str(); - // GHEX_CHECK_CUDA_RESULT(cudaDeviceSynchronize()); } ~communication_object() noexcept { - // TODO: nothrow - // std::ostringstream msg_destroy; - // msg_destroy << "~communication_object destroying nccl communicator"; - // if (m_moved) { - // msg_destroy << ", comm is valid\n"; - // GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaDeviceSynchronize()); - // GHEX_CHECK_NCCL_RESULT_NO_THROW(ncclCommDestroy(m_nccl_comm)); - // } else { - // msg_destroy << ", comm is moved, skipping ncclCommDestroy\n"; - // } - // std::cerr << msg_destroy.str(); + if (!m_moved) { + GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaDeviceSynchronize()); + GHEX_CHECK_NCCL_RESULT_NO_THROW(ncclCommDestroy(m_nccl_comm)); + } } communication_object(const communication_object&) = delete; communication_object(communication_object&&) = default; @@ -285,29 +255,14 @@ class communication_object private: template void nccl_exchange_impl(buffer_info_type... buffer_infos) { - // GHEX_CHECK_CUDA_RESULT(cudaDeviceSynchronize()); - // pack - // send - // std::cerr << "starting packing\n"; - for_each(m_mem, [this](std::size_t, auto& m) { - using arch_type = typename std::remove_reference_t::arch_type; - packer::pack2_nccl(m, m_send_reqs, m_comm); - }); - // std::cerr << "packing done\n"; + pack_nccl(); - // std::cerr << "starting group\n"; ncclGroupStart(); post_sends_nccl(); - - // recv - // unpack - // std::cerr << "starting recvs\n"; post_recvs_nccl(); - // std::cerr << "recvs done\n"; ncclGroupEnd(); - // std::cerr << "ending group\n"; + unpack_nccl(); - // GHEX_CHECK_CUDA_RESULT(cudaDeviceSynchronize()); } @@ -320,20 +275,13 @@ class communication_object template [[nodiscard]] handle_type exchange(buffer_info_type... buffer_infos) { - // std::cerr << "using first exchange overload\n"; exchange_impl(buffer_infos...); +#ifdef GHEX_USE_NCCL nccl_exchange_impl(); - // // TODO: Assymetry here. - // // - // // post_recvs iterates through memory and fields here in the - // // communication object, installs callbacks for unpacking per field - // // (though one loop remains inside unpack). - // // - // // pack passes send_reqs and comm to pack, which does the iterating and - // // installing callback. pack, however, waits for packs to complete to - // // trigger sends. - // post_recvs(); - // pack(); +#else + post_recvs(); + pack(); +#endif return {this}; } @@ -346,8 +294,6 @@ class communication_object [[nodiscard]] disable_if_buffer_info exchange( Iterator first, Iterator last) { - // std::cerr << "using exchange_u overload\n"; - // call special function for a single range return exchange_u(first, last); } @@ -365,7 +311,6 @@ class communication_object [[nodiscard]] disable_if_buffer_info exchange( Iterator0 first0, Iterator0 last0, Iterator1 first1, Iterator1 last1, Iterators... iters) { - // std::cerr << "using exchange with iterators overload\n"; static_assert( sizeof...(Iterators) % 2 == 0, "need even number of iteratiors: (begin,end) pairs"); // call helper function to turn iterators into pairs of iterators @@ -378,12 +323,13 @@ class communication_object template [[nodiscard]] handle_type exchange(std::pair... iter_pairs) { - // std::cerr << "using private exchange with iterators overload\n"; - exchange_impl(iter_pairs...); +#ifdef GHEX_USE_NCCL nccl_exchange_impl(); - // post_recvs(); - // pack(); +#else + post_recvs(); + pack(); +#endif return {this}; } @@ -407,7 +353,6 @@ class communication_object #endif exchange_u(Iterator first, Iterator last) { - // std::cerr << "using private exchange_u with iterators overload\n"; // call exchange with a pair of iterators return exchange(std::make_pair(first, last)); } @@ -459,7 +404,6 @@ class communication_object template void exchange_impl(std::pair... iter_pairs) { - // std::cerr << "using first exchange_impl overload\n"; const std::tuple...> iter_pairs_t{iter_pairs...}; if (m_valid) throw std::runtime_error("earlier exchange operation was not finished"); @@ -491,14 +435,12 @@ class communication_object mem, it->get_pattern(), field_ptr, my_dom_id, it->device_id(), tag_offset); } }); - // std::cerr << "done in first exchange_impl overload\n"; } // helper function to set up communicaton buffers (compile-time case) template void exchange_impl(buffer_info_type... buffer_infos) { - // std::cerr << "using second exchange_impl overload\n"; // check that arguments are compatible using test_t = pattern_container; static_assert( @@ -580,7 +522,11 @@ class communication_object if (p1.second.size > 0u) { device::guard g(p1.second.buffer); - GHEX_CHECK_NCCL_RESULT(ncclSend(static_cast(g.data()), p1.second.buffer.size() /* * sizeof(typename decltype(p1.second.buffer)::value_type) */, ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); + // TODO: Check why element size isn't relevant for the + // buffer size (also for recv). + GHEX_CHECK_NCCL_RESULT( + ncclSend(static_cast(g.data()), p1.second.buffer.size(), + ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); } } } @@ -603,32 +549,25 @@ class communication_object || p1.second.buffer.device_id() != device_id #endif ) - // std::cerr << "post_recvs_nccl: making message\n"; p1.second.buffer = arch_traits::make_message( m_comm, p1.second.size, device_id); - // std::cerr << "post_recvs_nccl: triggering ncclRecv\n"; - // std::cerr << "post_recvs_nccl: ptr is " << static_cast(p1.second.buffer.device_data()) << "\n"; - GHEX_CHECK_NCCL_RESULT(ncclRecv(p1.second.buffer.device_data(), p1.second.buffer.size() /* * sizeof(typename decltype(p1.second.buffer)::value_type) */, ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); - // std::cerr << "post_recvs_nccl: triggered ncclRecv\n"; - device::guard g(p1.second.buffer); - // std::cerr << "post_recvs_nccl: triggering unpack\n"; - // TODO: This doesn't seem to happen after the recv, schedule outside ncclCommGroup? - // packer::unpack(p1.second, g.data()); - // std::cerr << "post_recvs_nccl: triggered unpack\n"; - - // use callbacks for unpacking - // m_recv_reqs.push_back(m_comm.recv(p1.second.buffer, p1.second.rank, - // p1.second.tag, - // [ptr](context::message_type& m, context::rank_type, context::tag_type) { - // device::guard g(m); - // packer::unpack(*ptr, g.data()); - // })); + GHEX_CHECK_NCCL_RESULT( + ncclRecv(p1.second.buffer.device_data(), p1.second.buffer.size(), + ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); } } } }); } + void pack_nccl() + { + for_each(m_mem, [this](std::size_t, auto& m) { + using arch_type = typename std::remove_reference_t::arch_type; + packer::pack2_nccl(m, m_send_reqs, m_comm); + }); + } + void unpack_nccl() { for_each(m_mem, [this](std::size_t, auto& m) { @@ -645,26 +584,10 @@ class communication_object || p1.second.buffer.device_id() != device_id #endif ) - // std::cerr << "post_recvs_nccl: making message\n"; p1.second.buffer = arch_traits::make_message( m_comm, p1.second.size, device_id); - // std::cerr << "post_recvs_nccl: triggering ncclRecv\n"; - // std::cerr << "post_recvs_nccl: ptr is " << static_cast(p1.second.buffer.device_data()) << "\n"; - // GHEX_CHECK_NCCL_RESULT(ncclRecv(p1.second.buffer.device_data(), p1.second.buffer.size() /* * sizeof(typename decltype(p1.second.buffer)::value_type) */, ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); - // std::cerr << "post_recvs_nccl: triggered ncclRecv\n"; device::guard g(p1.second.buffer); - // std::cerr << "post_recvs_nccl: triggering unpack\n"; - // TODO: This doesn't seem to happen after the recv, schedule outside ncclCommGroup? packer::unpack(p1.second, g.data()); - // std::cerr << "post_recvs_nccl: triggered unpack\n"; - - // use callbacks for unpacking - // m_recv_reqs.push_back(m_comm.recv(p1.second.buffer, p1.second.rank, - // p1.second.tag, - // [ptr](context::message_type& m, context::rank_type, context::tag_type) { - // device::guard g(m); - // packer::unpack(*ptr, g.data()); - // })); } } } @@ -682,40 +605,52 @@ class communication_object private: // wait functions void progress() { - // if (!m_valid) return; - // m_comm.progress(); +#ifdef GHEX_USE_NCCL + // TODO: No progress needed? +#else + if (!m_valid) return; + m_comm.progress(); +#endif } bool is_ready() { - // if (!m_valid) return true; -// if (m_comm.is_ready()) -// { -// #ifdef GHEX_CUDACC -// sync_streams(); -// #endif -// clear(); -// return true; -// } -// m_comm.progress(); -// if (m_comm.is_ready()) -// { -// #ifdef GHEX_CUDACC -// sync_streams(); -// #endif -// clear(); -// return true; -// } +#ifdef GHEX_USE_NCCL + // TODO: Check if streams are idle? +#else + if (!m_valid) return true; + if (m_comm.is_ready()) + { +#ifdef GHEX_CUDACC + sync_streams(); +#endif + clear(); + return true; + } + m_comm.progress(); + if (m_comm.is_ready()) + { +#ifdef GHEX_CUDACC + sync_streams(); +#endif + clear(); + return true; + } +#endif return false; } void wait() { -// if (!m_valid) return; -// // wait for data to arrive (unpack callback will be invoked) -// m_comm.wait_all(); +#ifdef GHEX_USE_NCCL + // TODO: Wait for stream? +#else + if (!m_valid) return; + // wait for data to arrive (unpack callback will be invoked) + m_comm.wait_all(); #ifdef GHEX_CUDACC sync_streams(); +#endif #endif clear(); } From 4b5833fe09dca3cc040358f6146a1c0240fd9c97 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 24 Oct 2025 23:30:33 +0200 Subject: [PATCH 09/10] Small fix to stream syncing with nccl --- include/ghex/communication_object.hpp | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/include/ghex/communication_object.hpp b/include/ghex/communication_object.hpp index 0d5db041..6e0f4420 100644 --- a/include/ghex/communication_object.hpp +++ b/include/ghex/communication_object.hpp @@ -642,15 +642,13 @@ class communication_object void wait() { -#ifdef GHEX_USE_NCCL - // TODO: Wait for stream? -#else +#ifndef GHEX_USE_NCCL if (!m_valid) return; // wait for data to arrive (unpack callback will be invoked) m_comm.wait_all(); +#endif #ifdef GHEX_CUDACC sync_streams(); -#endif #endif clear(); } @@ -659,6 +657,10 @@ class communication_object private: // synchronize (unpacking) streams void sync_streams() { + constexpr std::size_t num_events{128}; + static std::vector events(num_events); + static std::size_t event_index{0}; + using gpu_mem_t = buffer_memory; auto& m = std::get(m_mem); for (auto& p0 : m.recv_memory) @@ -667,7 +669,18 @@ class communication_object { if (p1.second.size > 0u) { +#ifdef GHEX_USE_NCCL + // Instead of doing a blocking wait, create events on each + // stream that the default stream waits for. This assumes + // that all kernels that need the unpacked data will use or + // synchronize with the default stream. + cudaEvent_t& e = events[event_index].get(); + event_index = (event_index + 1) % num_events; + GHEX_CHECK_CUDA_RESULT(cudaEventRecord(e, p1.second.m_stream.get())); + GHEX_CHECK_CUDA_RESULT(cudaStreamWaitEvent(0, e)); +#else p1.second.m_stream.sync(); +#endif } } } From ee1b85193f73bf8abb083cc4e3293fcb6858074a Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 24 Oct 2025 23:31:05 +0200 Subject: [PATCH 10/10] Update test to disable cpu exchange with nccl --- test/unstructured/test_user_concepts.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/unstructured/test_user_concepts.cpp b/test/unstructured/test_user_concepts.cpp index 35e4d0a3..31ce2f28 100644 --- a/test/unstructured/test_user_concepts.cpp +++ b/test/unstructured/test_user_concepts.cpp @@ -273,6 +273,7 @@ test_data_descriptor(ghex::context& ctxt, std::size_t levels, bool levels_first) // application data auto& d = local_domains[0]; ghex::test::util::memory field(d.size()*levels, 0); +#ifndef GHEX_USE_NCCL initialize_data(d, field, levels, levels_first); data_descriptor_cpu_int_type data{d, field, levels, levels_first}; @@ -283,6 +284,7 @@ test_data_descriptor(ghex::context& ctxt, std::size_t levels, bool levels_first) // check exchanged data check_exchanged_data(d, field, patterns[0], levels, levels_first); +#endif #ifdef GHEX_CUDACC // application data @@ -293,6 +295,9 @@ test_data_descriptor(ghex::context& ctxt, std::size_t levels, bool levels_first) EXPECT_NO_THROW(co.exchange(patterns(data_gpu)).wait()); auto h_gpu = co.exchange(patterns(data_gpu)); +#ifdef GHEX_USE_NCCL + cudaDeviceSynchronize(); +#endif h_gpu.wait(); // check exchanged data