From 1ff1218a4a359f642957885b016ddb48a9c2328f Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Mon, 3 Nov 2025 16:01:00 +0100 Subject: [PATCH 01/25] Add first dummy version of NCCL backend Mostly just copy MPI implementation to a new directory, not functional. --- cmake/oomph_nccl.cmake | 18 +++ src/nccl/CMakeLists.txt | 9 ++ src/nccl/channel_base.hpp | 77 +++++++++++ src/nccl/communicator.hpp | 128 ++++++++++++++++++ src/nccl/context.cpp | 34 +++++ src/nccl/context.hpp | 86 ++++++++++++ src/nccl/handle.hpp | 31 +++++ src/nccl/lock_cache.hpp | 53 ++++++++ src/nccl/nccl_communicator.hpp | 56 ++++++++ src/nccl/nccl_error.hpp | 33 +++++ src/nccl/recv_channel.hpp | 77 +++++++++++ src/nccl/region.hpp | 84 ++++++++++++ src/nccl/request.hpp | 37 ++++++ src/nccl/request_queue.hpp | 233 +++++++++++++++++++++++++++++++++ src/nccl/request_state.hpp | 97 ++++++++++++++ src/nccl/rma_context.hpp | 84 ++++++++++++ src/nccl/send_channel.hpp | 46 +++++++ 17 files changed, 1183 insertions(+) create mode 100644 cmake/oomph_nccl.cmake create mode 100644 src/nccl/CMakeLists.txt create mode 100644 src/nccl/channel_base.hpp create mode 100644 src/nccl/communicator.hpp create mode 100644 src/nccl/context.cpp create mode 100644 src/nccl/context.hpp create mode 100644 src/nccl/handle.hpp create mode 100644 src/nccl/lock_cache.hpp create mode 100644 src/nccl/nccl_communicator.hpp create mode 100644 src/nccl/nccl_error.hpp create mode 100644 src/nccl/recv_channel.hpp create mode 100644 src/nccl/region.hpp create mode 100644 src/nccl/request.hpp create mode 100644 src/nccl/request_queue.hpp create mode 100644 src/nccl/request_state.hpp create mode 100644 src/nccl/rma_context.hpp create mode 100644 src/nccl/send_channel.hpp diff --git a/cmake/oomph_nccl.cmake b/cmake/oomph_nccl.cmake new file mode 100644 index 00000000..7528f820 --- /dev/null +++ b/cmake/oomph_nccl.cmake @@ -0,0 +1,18 @@ +# set all NCCL related options and values + +#------------------------------------------------------------------------------ +# Enable NCCL support +#------------------------------------------------------------------------------ +set(OOMPH_WITH_NCCL OFF CACHE BOOL "Build with NCCL backend") + +if (OOMPH_WITH_NCCL) + # find_package(NCCL REQUIRED) + add_library(oomph_nccl SHARED) + add_library(oomph::nccl ALIAS oomph_nccl) + oomph_shared_lib_options(oomph_nccl) + # target_link_libraries(oomph_nccl PUBLIC NCCL::NCCL) + install(TARGETS oomph_nccl + EXPORT oomph-targets + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) +endif() diff --git a/src/nccl/CMakeLists.txt b/src/nccl/CMakeLists.txt new file mode 100644 index 00000000..9d006c15 --- /dev/null +++ b/src/nccl/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(oomph_private_nccl_headers INTERFACE) +target_include_directories(oomph_private_nccl_headers INTERFACE + "$") +target_link_libraries(oomph_nccl PRIVATE oomph_private_nccl_headers) + +list(TRANSFORM oomph_sources PREPEND ${CMAKE_CURRENT_SOURCE_DIR}/../ + OUTPUT_VARIABLE oomph_sources_nccl) +target_sources(oomph_nccl PRIVATE ${oomph_sources_nccl}) +target_sources(oomph_nccl PRIVATE context.cpp) diff --git a/src/nccl/channel_base.hpp b/src/nccl/channel_base.hpp new file mode 100644 index 00000000..f8751a44 --- /dev/null +++ b/src/nccl/channel_base.hpp @@ -0,0 +1,77 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +// paths relative to backend +#include + +namespace oomph +{ +class channel_base +{ + protected: + using heap_type = context_impl::heap_type; + using pointer = heap_type::pointer; + using handle_type = typename pointer::handle_type; + using key_type = typename handle_type::key_type; + using flag_basic_type = key_type; + using flag_type = flag_basic_type volatile; + + protected: + //heap_type& m_heap; + std::size_t m_size; + std::size_t m_T_size; + std::size_t m_levels; + std::size_t m_capacity; + communicator::rank_type m_remote_rank; + communicator::tag_type m_tag; + bool m_connected = false; + MPI_Request m_init_req; + + public: + channel_base(/*heap_type& h,*/ std::size_t size, std::size_t T_size, + communicator::rank_type remote_rank, communicator::tag_type tag, std::size_t levels) + //: m_heap{h} + : m_size{size} + , m_T_size{T_size} + , m_levels{levels} + , m_capacity{levels} + , m_remote_rank{remote_rank} + , m_tag{tag} + { + } + + void connect() + { + OOMPH_CHECK_MPI_RESULT(MPI_Wait(&m_init_req, MPI_STATUS_IGNORE)); + m_connected = true; + } + + protected: + // index of flag in buffer (in units of flag_basic_type) + std::size_t flag_offset() const noexcept + { + return (m_size * m_T_size + 2 * sizeof(flag_basic_type) - 1) / sizeof(flag_basic_type) - 1; + } + // number of elements of type T (including padding) + std::size_t buffer_size() const noexcept + { + return ((flag_offset() + 1) * sizeof(flag_basic_type) + m_T_size - 1) / m_T_size; + } + // pointer to flag location for a given buffer + void* flag_ptr(void* ptr) const noexcept + { + return (void*)((char*)ptr + flag_offset() * sizeof(flag_basic_type)); + } +}; + +} // namespace oomph diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp new file mode 100644 index 00000000..0022b157 --- /dev/null +++ b/src/nccl/communicator.hpp @@ -0,0 +1,128 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +// paths relative to backend +#include <../communicator_base.hpp> +#include <../device_guard.hpp> +#include +#include + +namespace oomph +{ +class communicator_impl : public communicator_base +{ + public: + context_impl* m_context; + request_queue m_send_reqs; + request_queue m_recv_reqs; + + communicator_impl(context_impl* ctxt) + : communicator_base(ctxt) + , m_context(ctxt) + { + } + + auto& get_heap() noexcept { return m_context->get_heap(); } + + mpi_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, + tag_type tag) + { + MPI_Request r; + const_device_guard dg(ptr); + OOMPH_CHECK_MPI_RESULT(MPI_Isend(dg.data(), size, MPI_BYTE, dst, tag, mpi_comm(), &r)); + return {r}; + } + + mpi_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, + tag_type tag) + { + MPI_Request r; + device_guard dg(ptr); + OOMPH_CHECK_MPI_RESULT(MPI_Irecv(dg.data(), size, MPI_BYTE, src, tag, mpi_comm(), &r)); + return {r}; + } + + send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, + tag_type tag, util::unique_function&& cb, + std::size_t* scheduled) + { + auto req = send(ptr, size, dst, tag); + if (!has_reached_recursion_depth() && req.is_ready()) + { + auto inc = recursion(); + cb(dst, tag); + return {}; + } + else + { + auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, + std::move(cb), req); + s->create_self_ref(); + m_send_reqs.enqueue(s.get()); + return {std::move(s)}; + } + } + + recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, + tag_type tag, util::unique_function&& cb, + std::size_t* scheduled) + { + auto req = recv(ptr, size, src, tag); + if (!has_reached_recursion_depth() && req.is_ready()) + { + auto inc = recursion(); + cb(src, tag); + return {}; + } + else + { + auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, + std::move(cb), req); + s->create_self_ref(); + m_recv_reqs.enqueue(s.get()); + return {std::move(s)}; + } + } + + shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, + rank_type src, tag_type tag, util::unique_function&& cb, + std::atomic* scheduled) + { + auto req = recv(ptr, size, src, tag); + if (!m_context->has_reached_recursion_depth() && req.is_ready()) + { + auto inc = m_context->recursion(); + cb(src, tag); + return {}; + } + else + { + auto s = std::make_shared(m_context, this, scheduled, src, + tag, std::move(cb), req); + s->create_self_ref(); + m_context->m_req_queue.enqueue(s.get()); + return {std::move(s)}; + } + } + + void progress() + { + m_send_reqs.progress(); + m_recv_reqs.progress(); + m_context->progress(); + } + + bool cancel_recv(detail::request_state* s) { return m_recv_reqs.cancel(s); } +}; + +} // namespace oomph diff --git a/src/nccl/context.cpp b/src/nccl/context.cpp new file mode 100644 index 00000000..9f3273d4 --- /dev/null +++ b/src/nccl/context.cpp @@ -0,0 +1,34 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ + +// paths relative to backend +#include +#include + +namespace oomph +{ +communicator_impl* +context_impl::get_communicator() +{ + auto comm = new communicator_impl{this}; + m_comms_set.insert(comm); + return comm; +} + +const char *context_impl::get_transport_option(const std::string &opt) { + if (opt == "name") { + return "mpi"; + } + else { + return "unspecified"; + } +} + +} // namespace oomph diff --git a/src/nccl/context.hpp b/src/nccl/context.hpp new file mode 100644 index 00000000..cc542392 --- /dev/null +++ b/src/nccl/context.hpp @@ -0,0 +1,86 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +#include + +// paths relative to backend +#include +#include <../context_base.hpp> +#include +#include + +namespace oomph +{ +class context_impl : public context_base +{ + public: + using region_type = region; + using device_region_type = region; + using heap_type = hwmalloc::heap; + + private: + heap_type m_heap; + detail::nccl_comm m_comm; + + public: + shared_request_queue m_req_queue; + + public: + context_impl(MPI_Comm comm, bool thread_safe, hwmalloc::heap_config const& heap_config) + : context_base(comm, thread_safe) + , m_heap{this, heap_config} + //, m_rma_context{m_mpi_comm} + , m_comm{mpi_comm{comm}} + { + } + + context_impl(context_impl const&) = delete; + context_impl(context_impl&&) = delete; + + region make_region(void* ptr) const { return {ptr}; } + + auto& get_heap() noexcept { return m_heap; } + + communicator_impl* get_communicator(); + + void progress() { m_req_queue.progress(); } + + bool cancel_recv(detail::shared_request_state* r) { + // TODO: Ignore? Can't undo kernel launches. + } + + unsigned int num_tag_bits() const noexcept { + // TODO: Important? Can't use tags with NCCL. + return 32; + } + + const char* get_transport_option(const std::string& opt); +}; + +template<> +inline region +register_memory(context_impl& c, void* ptr, std::size_t) +{ + return c.make_region(ptr); +} + +#if OOMPH_ENABLE_DEVICE +template<> +inline region +register_device_memory(context_impl& c, int, void* ptr, std::size_t) +{ + return c.make_region(ptr); +} +#endif + +} // namespace oomph diff --git a/src/nccl/handle.hpp b/src/nccl/handle.hpp new file mode 100644 index 00000000..179c2686 --- /dev/null +++ b/src/nccl/handle.hpp @@ -0,0 +1,31 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +struct handle +{ + using key_type = MPI_Aint; + + void* m_ptr; + std::size_t m_size; + + key_type get_remote_key() const noexcept + { + MPI_Aint address; + OOMPH_CHECK_MPI_RESULT_NOEXCEPT(MPI_Get_address(m_ptr, &address)); + return address; + //return ((char*)m_ptr - MPI_BOTTOM); + } +}; +} // namespace oomph diff --git a/src/nccl/lock_cache.hpp b/src/nccl/lock_cache.hpp new file mode 100644 index 00000000..1b61f46f --- /dev/null +++ b/src/nccl/lock_cache.hpp @@ -0,0 +1,53 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include + +#include +#include + +namespace oomph +{ +class lock_cache +{ + private: + MPI_Win m_win; + std::set m_ranks; + std::mutex m_mutex; + + public: + lock_cache(MPI_Win win) noexcept + : m_win(win) + { + } + + lock_cache(lock_cache const&) = delete; + + ~lock_cache() + { + for (auto r : m_ranks) MPI_Win_unlock(r, m_win); + } + + void lock(rank_type r) + { + std::lock_guard l(m_mutex); + + auto it = m_ranks.find(r); + if (it == m_ranks.end()) + { + m_ranks.insert(r); + OOMPH_CHECK_MPI_RESULT(MPI_Win_lock(MPI_LOCK_SHARED, r, 0, m_win)); + } + } +}; + +} // namespace oomph diff --git a/src/nccl/nccl_communicator.hpp b/src/nccl/nccl_communicator.hpp new file mode 100644 index 00000000..5e05ccc9 --- /dev/null +++ b/src/nccl/nccl_communicator.hpp @@ -0,0 +1,56 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include + +#include +#include <../mpi_comm.hpp> + +#include + +namespace oomph::detail +{ +class nccl_comm +{ + ncclComm_t m_comm; + oomph::util::moved_bit m_moved; + + public: + nccl_communicator(mpi_comm mpi_comm) + { + ncclUniqueId id; + if (mpi_comm.rank() == 0) { OOMPH_CHECK_NCCL_RESULT(ncclGetUniqueId(&id)); } + + OOMPH_CHECK_MPI_RESULT(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, mpi_comm.get())); + + OOMPH_CHECK_NCCL_RESULT(ncclCommInitRank(&m_comm, mpi_comm.size(), id, mpi_comm.rank())); + ncclResult_t result; + do { + OOMPH_CHECK_NCCL_RESULT(ncclCommGetAsyncError(m_comm, &result)); + } + } + nccl_comm(nccl_comm&&) noexcept = default; + nccl_comm& operator=(nccl_comm&&) noexcept = default; + nccl_comm(nccl_comm const&) = delete; + nccl_comm& operator=(nccl_comm const&) = delete; + ~nccl_comm() noexcept + { + if (!m_moved) + { + // TODO + // OOMPH_CHECK_CUDA_RESULT_NOEXCEPT(cudaDeviceSynchronize()); + cudaDeviceSynchronize(); + OOMPH_CHECK_NCCL_RESULT_NOEXCEPT(ncclCommDestroy(m_comm)); + } + } +}; +} // namespace oomph::detail diff --git a/src/nccl/nccl_error.hpp b/src/nccl/nccl_error.hpp new file mode 100644 index 00000000..ac4da242 --- /dev/null +++ b/src/nccl/nccl_error.hpp @@ -0,0 +1,33 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +// TODO: Print error string and code. +#ifdef NDEBUG +#define OOMPH_CHECK_NCCL_RESULT(x) x; +#define OOMPH_CHECK_NCCL_RESULT_NOEXCEPT(x) x; +#else +#include +#include +#include +#define OOMPH_CHECK_NCCL_RESULT(x) \ + if (x != ncclSuccess && x != ncclInProgress) \ + throw std::runtime_error("OOMPH Error: NCCL Call failed " + std::string(#x) + " in " + \ + std::string(__FILE__) + ":" + std::to_string(__LINE__)); +#define OOMPH_CHECK_NCCL_RESULT_NOEXCEPT(x) \ + if (x != ncclSuccess && x != ncclInProgress) \ + { \ + std::cerr << "OOMPH Error: NCCL Call failed " << std::string(#x) << " in " \ + << std::string(__FILE__) << ":" << std::to_string(__LINE__) << std::endl; \ + std::terminate(); \ + } +#endif diff --git a/src/nccl/recv_channel.hpp b/src/nccl/recv_channel.hpp new file mode 100644 index 00000000..87b0c269 --- /dev/null +++ b/src/nccl/recv_channel.hpp @@ -0,0 +1,77 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include + +// paths relative to backend +#include + +namespace oomph +{ +class recv_channel_impl : public channel_base +{ + using base = channel_base; + using flag_basic_type = typename base::flag_basic_type; + using flag_type = typename base::flag_type; + using pointer = typename base::pointer; + using handle_type = typename base::handle_type; + using key_type = typename base::key_type; + + private: + communicator::impl* m_comm; + pointer m_buffer; + key_type m_local_key; + + public: + recv_channel_impl(communicator::impl* impl_, std::size_t size, std::size_t T_size, + communicator::rank_type src, communicator::tag_type tag, std::size_t levels) + : base(size, T_size, src, tag, levels) + , m_comm(impl_) + , m_buffer{m_comm->get_heap().allocate( + levels * base::buffer_size() * T_size, hwmalloc::numa().local_node())} + , m_local_key{m_buffer.handle().get_remote_key()} + { + m_comm->m_context->lock(src); + OOMPH_CHECK_MPI_RESULT(MPI_Isend(&m_local_key, sizeof(key_type), MPI_BYTE, + base::m_remote_rank, base::m_tag, m_comm->get_comm(), &(base::m_init_req))); + } + recv_channel_impl(recv_channel_impl const&) = delete; + recv_channel_impl(recv_channel_impl&&) = delete; + + ~recv_channel_impl() + { + } + + //void connect() {} + + std::size_t capacity() + { + return base::m_capacity; + } + + void* get(std::size_t& index) + { + index = 0; + return nullptr; + } + + void release(std::size_t index) + { + } +}; + +void release_recv_channel_buffer(recv_channel_impl* rc, std::size_t index) +{ + rc->release(index); +} + +} // namespace oomph diff --git a/src/nccl/region.hpp b/src/nccl/region.hpp new file mode 100644 index 00000000..78154a00 --- /dev/null +++ b/src/nccl/region.hpp @@ -0,0 +1,84 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +// paths relative to backend +#include + +namespace oomph +{ +class region +{ + public: + using handle_type = handle; + + private: + void* m_ptr; + + public: + region(void* ptr) + : m_ptr{ptr} + { + } + + region(region const&) = delete; + + region(region&& r) noexcept + : m_ptr{std::exchange(r.m_ptr, nullptr)} + { + } + + // get a handle to some portion of the region + handle_type get_handle(std::size_t offset, std::size_t size) + { + return {(void*)((char*)m_ptr + offset), size}; + } +}; + +class rma_region +{ + public: + using handle_type = handle; + + private: + MPI_Comm m_comm; + MPI_Win m_win; + void* m_ptr; + + public: + rma_region(MPI_Comm comm, MPI_Win win, void* ptr, std::size_t size) + : m_comm{comm} + , m_win{win} + , m_ptr{ptr} + { + OOMPH_CHECK_MPI_RESULT(MPI_Win_attach(m_win, ptr, size)); + } + + rma_region(rma_region const&) = delete; + + rma_region(rma_region&& r) noexcept + : m_comm{r.m_comm} + , m_win{r.m_win} + , m_ptr{std::exchange(r.m_ptr, nullptr)} + { + } + + ~rma_region() + { + if (m_ptr) MPI_Win_detach(m_win, m_ptr); + } + + // get a handle to some portion of the region + handle_type get_handle(std::size_t offset, std::size_t size) + { + return {(void*)((char*)m_ptr + offset), size}; + } +}; +} // namespace oomph diff --git a/src/nccl/request.hpp b/src/nccl/request.hpp new file mode 100644 index 00000000..a126143b --- /dev/null +++ b/src/nccl/request.hpp @@ -0,0 +1,37 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +struct mpi_request +{ + MPI_Request m_req; + + bool is_ready() + { + int flag; + OOMPH_CHECK_MPI_RESULT(MPI_Test(&m_req, &flag, MPI_STATUS_IGNORE)); + return flag; + } + + bool cancel() + { + OOMPH_CHECK_MPI_RESULT(MPI_Cancel(&m_req)); + MPI_Status st; + OOMPH_CHECK_MPI_RESULT(MPI_Wait(&m_req, &st)); + int flag = false; + OOMPH_CHECK_MPI_RESULT(MPI_Test_cancelled(&st, &flag)); + return flag; + } +}; +} // namespace oomph diff --git a/src/nccl/request_queue.hpp b/src/nccl/request_queue.hpp new file mode 100644 index 00000000..bc44e415 --- /dev/null +++ b/src/nccl/request_queue.hpp @@ -0,0 +1,233 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include + +// paths relative to backend +#include + +namespace oomph +{ + +class request_queue +{ + private: + using element_type = detail::request_state; + using queue_type = std::vector; + + private: // members + queue_type m_queue; + queue_type m_ready_queue; + bool in_progress = false; + std::vector reqs; + std::vector indices; + + public: // ctors + request_queue() + { + m_queue.reserve(256); + m_ready_queue.reserve(256); + } + + public: // member functions + std::size_t size() const noexcept { return m_queue.size(); } + + void enqueue(element_type* e) + { + e->m_index = m_queue.size(); + m_queue.push_back(e); + } + + int progress() + { + if (in_progress) return 0; + in_progress = true; + + const auto qs = size(); + if (qs == 0) + { + in_progress = false; + return 0; + } + + m_ready_queue.clear(); + + m_ready_queue.reserve(qs); + //reqs.resize(0); + reqs.clear(); + reqs.reserve(qs); + indices.resize(qs + 1); + + std::transform(m_queue.begin(), m_queue.end(), std::back_inserter(reqs), + [](auto e) { return e->m_req.m_req; }); + + int outcount; + OOMPH_CHECK_MPI_RESULT( + MPI_Testsome(qs, reqs.data(), &outcount, indices.data(), MPI_STATUSES_IGNORE)); + + if (outcount == 0) + { + in_progress = false; + return 0; + } + + indices[outcount] = qs; + + std::size_t k = 0; + std::size_t j = 0; + for (std::size_t i = 0; i < qs; ++i) + { + auto e = m_queue[i]; + if ((int)i == indices[k]) + { + m_ready_queue.push_back(e); + ++k; + } + else if (i > j) + { + e->m_index = j; + m_queue[j] = e; + ++j; + } + else + { + ++j; + } + } + m_queue.erase(m_queue.end() - m_ready_queue.size(), m_queue.end()); + + int completed = m_ready_queue.size(); + for (auto e : m_ready_queue) + { + auto ptr = e->release_self_ref(); + e->invoke_cb(); + } + + in_progress = false; + return completed; + } + + bool cancel(element_type* e) + { + auto const index = e->m_index; + if (m_queue[index]->m_req.cancel()) + { + auto ptr = e->release_self_ref(); + e->set_canceled(); + if (index + 1 < m_queue.size()) + { + m_queue[index] = m_queue.back(); + m_queue[index]->m_index = index; + } + m_queue.pop_back(); + return true; + } + else + return false; + } +}; + +class shared_request_queue +{ + private: + using element_type = detail::shared_request_state; + using queue_type = boost::lockfree::queue, + boost::lockfree::allocator>>; + + private: // members + queue_type m_queue; + std::atomic m_size; + + public: // ctors + shared_request_queue() + : m_queue(256) + , m_size(0) + { + } + + public: // member functions + std::size_t size() const noexcept { return m_size.load(); } + + void enqueue(element_type* e) + { + m_queue.push(e); + ++m_size; + } + + int progress() + { + static thread_local bool in_progress = false; + static thread_local std::vector m_local_queue; + int found = 0; + + if (in_progress) return 0; + in_progress = true; + + element_type* e; + while (m_queue.pop(e)) + { + if (e->m_req.is_ready()) + { + found = 1; + break; + } + else + { + m_local_queue.push_back(e); + } + } + + for (auto x : m_local_queue) m_queue.push(x); + m_local_queue.clear(); + + if (found) + { + auto ptr = e->release_self_ref(); + e->invoke_cb(); + --m_size; + } + + in_progress = false; + return found; + } + + bool cancel(element_type* e) + { + static thread_local std::vector m_local_queue; + m_local_queue.clear(); + + bool canceled = false; + m_queue.consume_all( + [q = &m_local_queue, e, &canceled](element_type* x) + { + if (e == x) + { + if (e->m_req.cancel()) + { + auto ptr = e->release_self_ref(); + e->set_canceled(); + canceled = true; + } + else + q->push_back(x); + } + else + q->push_back(x); + }); + + for (auto x : m_local_queue) m_queue.push(x); + + return canceled; + } +}; + +} // namespace oomph diff --git a/src/nccl/request_state.hpp b/src/nccl/request_state.hpp new file mode 100644 index 00000000..da69eb95 --- /dev/null +++ b/src/nccl/request_state.hpp @@ -0,0 +1,97 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +// paths relative to backend +#include <../request_state_base.hpp> +#include + +namespace oomph +{ +namespace detail +{ +struct request_state +: public util::enable_shared_from_this +, public request_state_base +{ + using base = request_state_base; + using shared_ptr_t = util::unsafe_shared_ptr; + + mpi_request m_req; + shared_ptr_t m_self_ptr; + std::size_t m_index; + + request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, std::size_t* scheduled, + rank_type rank, tag_type tag, cb_type&& cb, mpi_request m) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_req{m} + { + } + + void progress(); + + bool cancel(); + + void create_self_ref() + { + // create a self-reference cycle!! + // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } +}; + +struct shared_request_state +: public std::enable_shared_from_this +, public request_state_base +{ + using base = request_state_base; + using shared_ptr_t = std::shared_ptr; + + mpi_request m_req; + shared_ptr_t m_self_ptr; + + shared_request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, + std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb, + mpi_request m) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_req{m} + { + } + + void progress(); + + bool cancel(); + + void create_self_ref() + { + // create a self-reference cycle!! + // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } +}; + +} // namespace detail +} // namespace oomph diff --git a/src/nccl/rma_context.hpp b/src/nccl/rma_context.hpp new file mode 100644 index 00000000..aec295f0 --- /dev/null +++ b/src/nccl/rma_context.hpp @@ -0,0 +1,84 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include + +// paths relative to backend +#include +#include + +namespace oomph +{ +class rma_context +{ + public: + using region_type = rma_region; + using device_region_type = rma_region; + using heap_type = hwmalloc::heap; + + private: + struct mpi_win_holder + { + MPI_Win m; + ~mpi_win_holder() { MPI_Win_free(&m); } + }; + + private: + MPI_Comm m_mpi_comm; + mpi_win_holder m_win; + heap_type m_heap; + std::unique_ptr m_lock_cache; + + public: + rma_context(MPI_Comm comm) + : m_mpi_comm{comm} + , m_heap{this} + { + MPI_Info info; + OOMPH_CHECK_MPI_RESULT(MPI_Info_create(&info)); + OOMPH_CHECK_MPI_RESULT(MPI_Info_set(info, "no_locks", "false")); + OOMPH_CHECK_MPI_RESULT(MPI_Win_create_dynamic(info, m_mpi_comm, &(m_win.m))); + MPI_Info_free(&info); + OOMPH_CHECK_MPI_RESULT(MPI_Win_fence(0, m_win.m)); + m_lock_cache = std::make_unique(m_win.m); + } + rma_context(context_impl const&) = delete; + rma_context(context_impl&&) = delete; + + rma_region make_region(void* ptr, std::size_t size) const + { + return {m_mpi_comm, m_win.m, ptr, size}; + } + + auto get_window() const noexcept { return m_win.m; } + auto& get_heap() noexcept { return m_heap; } + void lock(rank_type r) { m_lock_cache->lock(r); } +}; + +template<> +inline rma_region +register_memory(rma_context& c, void* ptr, std::size_t size) +{ + return c.make_region(ptr, size); +} + +#if OOMPH_ENABLE_DEVICE +template<> +inline rma_region +register_device_memory(rma_context& c, int, void* ptr, std::size_t size) +{ + return c.make_region(ptr, size); +} +#endif + +} // namespace oomph diff --git a/src/nccl/send_channel.hpp b/src/nccl/send_channel.hpp new file mode 100644 index 00000000..caa95b74 --- /dev/null +++ b/src/nccl/send_channel.hpp @@ -0,0 +1,46 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include +#include + +// paths relative to backend +#include + +namespace oomph +{ +class send_channel_impl : public channel_base +{ + using base = channel_base; + using flag_basic_type = typename base::flag_basic_type; + using flag_type = typename base::flag_type; + using pointer = typename base::pointer; + using handle_type = typename base::handle_type; + using key_type = typename base::key_type; + + communicator::impl* m_comm; + key_type m_remote_key; + + public: + send_channel_impl(communicator::impl* impl_, std::size_t size, std::size_t T_size, + communicator::rank_type dst, communicator::tag_type tag, std::size_t levels) + : base(size, T_size, dst, tag, levels) + , m_comm(impl_) + { + m_comm->m_context->lock(dst); + OOMPH_CHECK_MPI_RESULT(MPI_Irecv(&m_remote_key, sizeof(key_type), MPI_BYTE, + base::m_remote_rank, base::m_tag, m_comm->get_comm(), &(base::m_init_req))); + } + send_channel_impl(send_channel_impl const&) = delete; + send_channel_impl(send_channel_impl&&) = delete; + +}; + +} // namespace oomph From 76a8d174493ab13417b2e6555f8e819f9e1e44e3 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Mon, 3 Nov 2025 16:33:06 +0100 Subject: [PATCH 02/25] Clean up some unnecessary nccl files and try to port more mpi functionality to nccl --- include/oomph/channel/send_channel.hpp | 1 + src/nccl/channel_base.hpp | 77 -------- src/nccl/communicator.hpp | 133 +++++++------- src/nccl/context.cpp | 2 +- src/nccl/context.hpp | 6 +- src/nccl/handle.hpp | 31 ---- src/nccl/lock_cache.hpp | 53 ------ src/nccl/recv_channel.hpp | 77 -------- src/nccl/region.hpp | 84 --------- src/nccl/request.hpp | 26 +-- src/nccl/request_queue.hpp | 233 ------------------------- src/nccl/request_state.hpp | 8 +- src/nccl/rma_context.hpp | 84 --------- src/nccl/send_channel.hpp | 46 ----- 14 files changed, 83 insertions(+), 778 deletions(-) delete mode 100644 src/nccl/channel_base.hpp delete mode 100644 src/nccl/handle.hpp delete mode 100644 src/nccl/lock_cache.hpp delete mode 100644 src/nccl/recv_channel.hpp delete mode 100644 src/nccl/region.hpp delete mode 100644 src/nccl/request_queue.hpp delete mode 100644 src/nccl/rma_context.hpp delete mode 100644 src/nccl/send_channel.hpp diff --git a/include/oomph/channel/send_channel.hpp b/include/oomph/channel/send_channel.hpp index c6fb75d7..e60778f1 100644 --- a/include/oomph/channel/send_channel.hpp +++ b/include/oomph/channel/send_channel.hpp @@ -7,6 +7,7 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ +// TODO: Needed for a completely backend implementation? Skip for NCCL? #pragma once #include diff --git a/src/nccl/channel_base.hpp b/src/nccl/channel_base.hpp deleted file mode 100644 index f8751a44..00000000 --- a/src/nccl/channel_base.hpp +++ /dev/null @@ -1,77 +0,0 @@ -/* - * ghex-org - * - * Copyright (c) 2014-2023, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - */ -#pragma once - -#include - -// paths relative to backend -#include - -namespace oomph -{ -class channel_base -{ - protected: - using heap_type = context_impl::heap_type; - using pointer = heap_type::pointer; - using handle_type = typename pointer::handle_type; - using key_type = typename handle_type::key_type; - using flag_basic_type = key_type; - using flag_type = flag_basic_type volatile; - - protected: - //heap_type& m_heap; - std::size_t m_size; - std::size_t m_T_size; - std::size_t m_levels; - std::size_t m_capacity; - communicator::rank_type m_remote_rank; - communicator::tag_type m_tag; - bool m_connected = false; - MPI_Request m_init_req; - - public: - channel_base(/*heap_type& h,*/ std::size_t size, std::size_t T_size, - communicator::rank_type remote_rank, communicator::tag_type tag, std::size_t levels) - //: m_heap{h} - : m_size{size} - , m_T_size{T_size} - , m_levels{levels} - , m_capacity{levels} - , m_remote_rank{remote_rank} - , m_tag{tag} - { - } - - void connect() - { - OOMPH_CHECK_MPI_RESULT(MPI_Wait(&m_init_req, MPI_STATUS_IGNORE)); - m_connected = true; - } - - protected: - // index of flag in buffer (in units of flag_basic_type) - std::size_t flag_offset() const noexcept - { - return (m_size * m_T_size + 2 * sizeof(flag_basic_type) - 1) / sizeof(flag_basic_type) - 1; - } - // number of elements of type T (including padding) - std::size_t buffer_size() const noexcept - { - return ((flag_offset() + 1) * sizeof(flag_basic_type) + m_T_size - 1) / m_T_size; - } - // pointer to flag location for a given buffer - void* flag_ptr(void* ptr) const noexcept - { - return (void*)((char*)ptr + flag_offset() * sizeof(flag_basic_type)); - } -}; - -} // namespace oomph diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp index 0022b157..7aa2899b 100644 --- a/src/nccl/communicator.hpp +++ b/src/nccl/communicator.hpp @@ -9,13 +9,16 @@ */ #pragma once +#include + #include // paths relative to backend #include <../communicator_base.hpp> #include <../device_guard.hpp> #include -#include +// #include +#include namespace oomph { @@ -23,8 +26,8 @@ class communicator_impl : public communicator_base { public: context_impl* m_context; - request_queue m_send_reqs; - request_queue m_recv_reqs; + // request_queue m_send_reqs; + // request_queue m_recv_reqs; communicator_impl(context_impl* ctxt) : communicator_base(ctxt) @@ -34,64 +37,66 @@ class communicator_impl : public communicator_base auto& get_heap() noexcept { return m_context->get_heap(); } - mpi_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag) + nccl_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, + [[maybe_unused]] tag_type tag) { - MPI_Request r; + // TODO: Stream? Currently assume 0. const_device_guard dg(ptr); - OOMPH_CHECK_MPI_RESULT(MPI_Isend(dg.data(), size, MPI_BYTE, dst, tag, mpi_comm(), &r)); - return {r}; + OOMPH_CHECK_NCCL_RESULT( + ncclSend(dg.data(), size, ncclChar, dst, m_context->m_comm.get(), 0)); + // TODO: Return event to stream? Return void? + return {}; } - mpi_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - tag_type tag) + nccl_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, + [[maybe_unused]] tag_type tag) { - MPI_Request r; + // TODO: Stream? Currently assume 0. device_guard dg(ptr); - OOMPH_CHECK_MPI_RESULT(MPI_Irecv(dg.data(), size, MPI_BYTE, src, tag, mpi_comm(), &r)); - return {r}; + OOMPH_CHECK_NCCL_RESULT( + ncclRecv(dg.data(), size, ncclChar, src, m_context->m_comm.get(), 0)); + // TODO: Return event to stream? Return void? + return {}; } send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + tag_type tag, util::unique_function&& cb, std::size_t* scheduled) { auto req = send(ptr, size, dst, tag); - if (!has_reached_recursion_depth() && req.is_ready()) - { - auto inc = recursion(); - cb(dst, tag); - return {}; - } - else - { - auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, - std::move(cb), req); - s->create_self_ref(); - m_send_reqs.enqueue(s.get()); - return {std::move(s)}; - } + // if (!has_reached_recursion_depth() && req.is_ready()) + // { + // auto inc = recursion(); + // cb(dst, tag); + // return {}; + // } + // else + // { + // TODO: Do we want to support callbacks for NCCL communication? How should this be structured? + auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb), req); + // s->create_self_ref(); + // TODO: Callback ignored. + // m_send_reqs.enqueue(s.get()); + return {std::move(s)}; + // } } recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + tag_type tag, util::unique_function&& cb, std::size_t* scheduled) { auto req = recv(ptr, size, src, tag); - if (!has_reached_recursion_depth() && req.is_ready()) - { - auto inc = recursion(); - cb(src, tag); - return {}; - } - else - { - auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, - std::move(cb), req); - s->create_self_ref(); - m_recv_reqs.enqueue(s.get()); - return {std::move(s)}; - } + // if (!has_reached_recursion_depth() && req.is_ready()) + // { + // auto inc = recursion(); + // cb(src, tag); + // return {}; + // } + // else + // { + auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb), req); + // s->create_self_ref(); + // m_recv_reqs.enqueue(s.get()); + return {std::move(s)}; + // } } shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, @@ -99,30 +104,32 @@ class communicator_impl : public communicator_base std::atomic* scheduled) { auto req = recv(ptr, size, src, tag); - if (!m_context->has_reached_recursion_depth() && req.is_ready()) - { - auto inc = m_context->recursion(); - cb(src, tag); - return {}; - } - else - { - auto s = std::make_shared(m_context, this, scheduled, src, - tag, std::move(cb), req); - s->create_self_ref(); - m_context->m_req_queue.enqueue(s.get()); - return {std::move(s)}; - } + // if (!m_context->has_reached_recursion_depth() && req.is_ready()) + // { + // auto inc = m_context->recursion(); + // cb(src, tag); + // return {}; + // } + // else + // { + auto s = std::make_shared(m_context, this, scheduled, src, + tag, std::move(cb), req); + // s->create_self_ref(); + // m_context->m_req_queue.enqueue(s.get()); + return {std::move(s)}; + // } } void progress() { - m_send_reqs.progress(); - m_recv_reqs.progress(); - m_context->progress(); + // Nothing to do to progress NCCL. Just wait for GPU to finish. } - bool cancel_recv(detail::request_state* s) { return m_recv_reqs.cancel(s); } + bool cancel_recv(detail::request_state* s) + { + // TODO: NCCL does not allow cancellation? + return false; + } }; } // namespace oomph diff --git a/src/nccl/context.cpp b/src/nccl/context.cpp index 9f3273d4..32074ba1 100644 --- a/src/nccl/context.cpp +++ b/src/nccl/context.cpp @@ -24,7 +24,7 @@ context_impl::get_communicator() const char *context_impl::get_transport_option(const std::string &opt) { if (opt == "name") { - return "mpi"; + return "nccl"; } else { return "unspecified"; diff --git a/src/nccl/context.hpp b/src/nccl/context.hpp index cc542392..606f9f29 100644 --- a/src/nccl/context.hpp +++ b/src/nccl/context.hpp @@ -16,7 +16,6 @@ // paths relative to backend #include #include <../context_base.hpp> -#include #include namespace oomph @@ -36,11 +35,10 @@ class context_impl : public context_base shared_request_queue m_req_queue; public: - context_impl(MPI_Comm comm, bool thread_safe, hwmalloc::heap_config const& heap_config) + context_impl(ncclComm_t comm, bool thread_safe, hwmalloc::heap_config const& heap_config) : context_base(comm, thread_safe) , m_heap{this, heap_config} - //, m_rma_context{m_mpi_comm} - , m_comm{mpi_comm{comm}} + , m_comm{nccl_comm{comm}} { } diff --git a/src/nccl/handle.hpp b/src/nccl/handle.hpp deleted file mode 100644 index 179c2686..00000000 --- a/src/nccl/handle.hpp +++ /dev/null @@ -1,31 +0,0 @@ -/* - * ghex-org - * - * Copyright (c) 2014-2023, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - */ -#pragma once - -#include - -namespace oomph -{ -struct handle -{ - using key_type = MPI_Aint; - - void* m_ptr; - std::size_t m_size; - - key_type get_remote_key() const noexcept - { - MPI_Aint address; - OOMPH_CHECK_MPI_RESULT_NOEXCEPT(MPI_Get_address(m_ptr, &address)); - return address; - //return ((char*)m_ptr - MPI_BOTTOM); - } -}; -} // namespace oomph diff --git a/src/nccl/lock_cache.hpp b/src/nccl/lock_cache.hpp deleted file mode 100644 index 1b61f46f..00000000 --- a/src/nccl/lock_cache.hpp +++ /dev/null @@ -1,53 +0,0 @@ -/* - * ghex-org - * - * Copyright (c) 2014-2023, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - */ -#pragma once - -#include -#include - -#include -#include - -namespace oomph -{ -class lock_cache -{ - private: - MPI_Win m_win; - std::set m_ranks; - std::mutex m_mutex; - - public: - lock_cache(MPI_Win win) noexcept - : m_win(win) - { - } - - lock_cache(lock_cache const&) = delete; - - ~lock_cache() - { - for (auto r : m_ranks) MPI_Win_unlock(r, m_win); - } - - void lock(rank_type r) - { - std::lock_guard l(m_mutex); - - auto it = m_ranks.find(r); - if (it == m_ranks.end()) - { - m_ranks.insert(r); - OOMPH_CHECK_MPI_RESULT(MPI_Win_lock(MPI_LOCK_SHARED, r, 0, m_win)); - } - } -}; - -} // namespace oomph diff --git a/src/nccl/recv_channel.hpp b/src/nccl/recv_channel.hpp deleted file mode 100644 index 87b0c269..00000000 --- a/src/nccl/recv_channel.hpp +++ /dev/null @@ -1,77 +0,0 @@ -/* - * ghex-org - * - * Copyright (c) 2014-2023, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - */ -#pragma once - -#include -#include - -// paths relative to backend -#include - -namespace oomph -{ -class recv_channel_impl : public channel_base -{ - using base = channel_base; - using flag_basic_type = typename base::flag_basic_type; - using flag_type = typename base::flag_type; - using pointer = typename base::pointer; - using handle_type = typename base::handle_type; - using key_type = typename base::key_type; - - private: - communicator::impl* m_comm; - pointer m_buffer; - key_type m_local_key; - - public: - recv_channel_impl(communicator::impl* impl_, std::size_t size, std::size_t T_size, - communicator::rank_type src, communicator::tag_type tag, std::size_t levels) - : base(size, T_size, src, tag, levels) - , m_comm(impl_) - , m_buffer{m_comm->get_heap().allocate( - levels * base::buffer_size() * T_size, hwmalloc::numa().local_node())} - , m_local_key{m_buffer.handle().get_remote_key()} - { - m_comm->m_context->lock(src); - OOMPH_CHECK_MPI_RESULT(MPI_Isend(&m_local_key, sizeof(key_type), MPI_BYTE, - base::m_remote_rank, base::m_tag, m_comm->get_comm(), &(base::m_init_req))); - } - recv_channel_impl(recv_channel_impl const&) = delete; - recv_channel_impl(recv_channel_impl&&) = delete; - - ~recv_channel_impl() - { - } - - //void connect() {} - - std::size_t capacity() - { - return base::m_capacity; - } - - void* get(std::size_t& index) - { - index = 0; - return nullptr; - } - - void release(std::size_t index) - { - } -}; - -void release_recv_channel_buffer(recv_channel_impl* rc, std::size_t index) -{ - rc->release(index); -} - -} // namespace oomph diff --git a/src/nccl/region.hpp b/src/nccl/region.hpp deleted file mode 100644 index 78154a00..00000000 --- a/src/nccl/region.hpp +++ /dev/null @@ -1,84 +0,0 @@ -/* - * ghex-org - * - * Copyright (c) 2014-2023, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - */ -#pragma once - -// paths relative to backend -#include - -namespace oomph -{ -class region -{ - public: - using handle_type = handle; - - private: - void* m_ptr; - - public: - region(void* ptr) - : m_ptr{ptr} - { - } - - region(region const&) = delete; - - region(region&& r) noexcept - : m_ptr{std::exchange(r.m_ptr, nullptr)} - { - } - - // get a handle to some portion of the region - handle_type get_handle(std::size_t offset, std::size_t size) - { - return {(void*)((char*)m_ptr + offset), size}; - } -}; - -class rma_region -{ - public: - using handle_type = handle; - - private: - MPI_Comm m_comm; - MPI_Win m_win; - void* m_ptr; - - public: - rma_region(MPI_Comm comm, MPI_Win win, void* ptr, std::size_t size) - : m_comm{comm} - , m_win{win} - , m_ptr{ptr} - { - OOMPH_CHECK_MPI_RESULT(MPI_Win_attach(m_win, ptr, size)); - } - - rma_region(rma_region const&) = delete; - - rma_region(rma_region&& r) noexcept - : m_comm{r.m_comm} - , m_win{r.m_win} - , m_ptr{std::exchange(r.m_ptr, nullptr)} - { - } - - ~rma_region() - { - if (m_ptr) MPI_Win_detach(m_win, m_ptr); - } - - // get a handle to some portion of the region - handle_type get_handle(std::size_t offset, std::size_t size) - { - return {(void*)((char*)m_ptr + offset), size}; - } -}; -} // namespace oomph diff --git a/src/nccl/request.hpp b/src/nccl/request.hpp index a126143b..4d3ee0eb 100644 --- a/src/nccl/request.hpp +++ b/src/nccl/request.hpp @@ -9,29 +9,13 @@ */ #pragma once -#include - namespace oomph { -struct mpi_request +struct nccl_request { - MPI_Request m_req; - - bool is_ready() - { - int flag; - OOMPH_CHECK_MPI_RESULT(MPI_Test(&m_req, &flag, MPI_STATUS_IGNORE)); - return flag; - } - - bool cancel() - { - OOMPH_CHECK_MPI_RESULT(MPI_Cancel(&m_req)); - MPI_Status st; - OOMPH_CHECK_MPI_RESULT(MPI_Wait(&m_req, &st)); - int flag = false; - OOMPH_CHECK_MPI_RESULT(MPI_Test_cancelled(&st, &flag)); - return flag; - } + // TODO: Ready when group has completed? Check stream or event? + bool is_ready() { return true; } + // TODO: No cancellation with NCCL? + bool cancel() { return false; } }; } // namespace oomph diff --git a/src/nccl/request_queue.hpp b/src/nccl/request_queue.hpp deleted file mode 100644 index bc44e415..00000000 --- a/src/nccl/request_queue.hpp +++ /dev/null @@ -1,233 +0,0 @@ -/* - * ghex-org - * - * Copyright (c) 2014-2023, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - */ -#pragma once - -#include -#include - -// paths relative to backend -#include - -namespace oomph -{ - -class request_queue -{ - private: - using element_type = detail::request_state; - using queue_type = std::vector; - - private: // members - queue_type m_queue; - queue_type m_ready_queue; - bool in_progress = false; - std::vector reqs; - std::vector indices; - - public: // ctors - request_queue() - { - m_queue.reserve(256); - m_ready_queue.reserve(256); - } - - public: // member functions - std::size_t size() const noexcept { return m_queue.size(); } - - void enqueue(element_type* e) - { - e->m_index = m_queue.size(); - m_queue.push_back(e); - } - - int progress() - { - if (in_progress) return 0; - in_progress = true; - - const auto qs = size(); - if (qs == 0) - { - in_progress = false; - return 0; - } - - m_ready_queue.clear(); - - m_ready_queue.reserve(qs); - //reqs.resize(0); - reqs.clear(); - reqs.reserve(qs); - indices.resize(qs + 1); - - std::transform(m_queue.begin(), m_queue.end(), std::back_inserter(reqs), - [](auto e) { return e->m_req.m_req; }); - - int outcount; - OOMPH_CHECK_MPI_RESULT( - MPI_Testsome(qs, reqs.data(), &outcount, indices.data(), MPI_STATUSES_IGNORE)); - - if (outcount == 0) - { - in_progress = false; - return 0; - } - - indices[outcount] = qs; - - std::size_t k = 0; - std::size_t j = 0; - for (std::size_t i = 0; i < qs; ++i) - { - auto e = m_queue[i]; - if ((int)i == indices[k]) - { - m_ready_queue.push_back(e); - ++k; - } - else if (i > j) - { - e->m_index = j; - m_queue[j] = e; - ++j; - } - else - { - ++j; - } - } - m_queue.erase(m_queue.end() - m_ready_queue.size(), m_queue.end()); - - int completed = m_ready_queue.size(); - for (auto e : m_ready_queue) - { - auto ptr = e->release_self_ref(); - e->invoke_cb(); - } - - in_progress = false; - return completed; - } - - bool cancel(element_type* e) - { - auto const index = e->m_index; - if (m_queue[index]->m_req.cancel()) - { - auto ptr = e->release_self_ref(); - e->set_canceled(); - if (index + 1 < m_queue.size()) - { - m_queue[index] = m_queue.back(); - m_queue[index]->m_index = index; - } - m_queue.pop_back(); - return true; - } - else - return false; - } -}; - -class shared_request_queue -{ - private: - using element_type = detail::shared_request_state; - using queue_type = boost::lockfree::queue, - boost::lockfree::allocator>>; - - private: // members - queue_type m_queue; - std::atomic m_size; - - public: // ctors - shared_request_queue() - : m_queue(256) - , m_size(0) - { - } - - public: // member functions - std::size_t size() const noexcept { return m_size.load(); } - - void enqueue(element_type* e) - { - m_queue.push(e); - ++m_size; - } - - int progress() - { - static thread_local bool in_progress = false; - static thread_local std::vector m_local_queue; - int found = 0; - - if (in_progress) return 0; - in_progress = true; - - element_type* e; - while (m_queue.pop(e)) - { - if (e->m_req.is_ready()) - { - found = 1; - break; - } - else - { - m_local_queue.push_back(e); - } - } - - for (auto x : m_local_queue) m_queue.push(x); - m_local_queue.clear(); - - if (found) - { - auto ptr = e->release_self_ref(); - e->invoke_cb(); - --m_size; - } - - in_progress = false; - return found; - } - - bool cancel(element_type* e) - { - static thread_local std::vector m_local_queue; - m_local_queue.clear(); - - bool canceled = false; - m_queue.consume_all( - [q = &m_local_queue, e, &canceled](element_type* x) - { - if (e == x) - { - if (e->m_req.cancel()) - { - auto ptr = e->release_self_ref(); - e->set_canceled(); - canceled = true; - } - else - q->push_back(x); - } - else - q->push_back(x); - }); - - for (auto x : m_local_queue) m_queue.push(x); - - return canceled; - } -}; - -} // namespace oomph diff --git a/src/nccl/request_state.hpp b/src/nccl/request_state.hpp index da69eb95..0eb061de 100644 --- a/src/nccl/request_state.hpp +++ b/src/nccl/request_state.hpp @@ -26,12 +26,12 @@ struct request_state using base = request_state_base; using shared_ptr_t = util::unsafe_shared_ptr; - mpi_request m_req; + nccl_request m_req; shared_ptr_t m_self_ptr; std::size_t m_index; request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, std::size_t* scheduled, - rank_type rank, tag_type tag, cb_type&& cb, mpi_request m) + rank_type rank, tag_type tag, cb_type&& cb, nccl_request m) : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} , m_req{m} { @@ -63,12 +63,12 @@ struct shared_request_state using base = request_state_base; using shared_ptr_t = std::shared_ptr; - mpi_request m_req; + nccl_request m_req; shared_ptr_t m_self_ptr; shared_request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb, - mpi_request m) + nccl_request m) : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} , m_req{m} { diff --git a/src/nccl/rma_context.hpp b/src/nccl/rma_context.hpp deleted file mode 100644 index aec295f0..00000000 --- a/src/nccl/rma_context.hpp +++ /dev/null @@ -1,84 +0,0 @@ -/* - * ghex-org - * - * Copyright (c) 2014-2023, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - */ -#pragma once - -#include -#include -#include - -// paths relative to backend -#include -#include - -namespace oomph -{ -class rma_context -{ - public: - using region_type = rma_region; - using device_region_type = rma_region; - using heap_type = hwmalloc::heap; - - private: - struct mpi_win_holder - { - MPI_Win m; - ~mpi_win_holder() { MPI_Win_free(&m); } - }; - - private: - MPI_Comm m_mpi_comm; - mpi_win_holder m_win; - heap_type m_heap; - std::unique_ptr m_lock_cache; - - public: - rma_context(MPI_Comm comm) - : m_mpi_comm{comm} - , m_heap{this} - { - MPI_Info info; - OOMPH_CHECK_MPI_RESULT(MPI_Info_create(&info)); - OOMPH_CHECK_MPI_RESULT(MPI_Info_set(info, "no_locks", "false")); - OOMPH_CHECK_MPI_RESULT(MPI_Win_create_dynamic(info, m_mpi_comm, &(m_win.m))); - MPI_Info_free(&info); - OOMPH_CHECK_MPI_RESULT(MPI_Win_fence(0, m_win.m)); - m_lock_cache = std::make_unique(m_win.m); - } - rma_context(context_impl const&) = delete; - rma_context(context_impl&&) = delete; - - rma_region make_region(void* ptr, std::size_t size) const - { - return {m_mpi_comm, m_win.m, ptr, size}; - } - - auto get_window() const noexcept { return m_win.m; } - auto& get_heap() noexcept { return m_heap; } - void lock(rank_type r) { m_lock_cache->lock(r); } -}; - -template<> -inline rma_region -register_memory(rma_context& c, void* ptr, std::size_t size) -{ - return c.make_region(ptr, size); -} - -#if OOMPH_ENABLE_DEVICE -template<> -inline rma_region -register_device_memory(rma_context& c, int, void* ptr, std::size_t size) -{ - return c.make_region(ptr, size); -} -#endif - -} // namespace oomph diff --git a/src/nccl/send_channel.hpp b/src/nccl/send_channel.hpp deleted file mode 100644 index caa95b74..00000000 --- a/src/nccl/send_channel.hpp +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ghex-org - * - * Copyright (c) 2014-2023, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - */ - -#include -#include - -// paths relative to backend -#include - -namespace oomph -{ -class send_channel_impl : public channel_base -{ - using base = channel_base; - using flag_basic_type = typename base::flag_basic_type; - using flag_type = typename base::flag_type; - using pointer = typename base::pointer; - using handle_type = typename base::handle_type; - using key_type = typename base::key_type; - - communicator::impl* m_comm; - key_type m_remote_key; - - public: - send_channel_impl(communicator::impl* impl_, std::size_t size, std::size_t T_size, - communicator::rank_type dst, communicator::tag_type tag, std::size_t levels) - : base(size, T_size, dst, tag, levels) - , m_comm(impl_) - { - m_comm->m_context->lock(dst); - OOMPH_CHECK_MPI_RESULT(MPI_Irecv(&m_remote_key, sizeof(key_type), MPI_BYTE, - base::m_remote_rank, base::m_tag, m_comm->get_comm(), &(base::m_init_req))); - } - send_channel_impl(send_channel_impl const&) = delete; - send_channel_impl(send_channel_impl&&) = delete; - -}; - -} // namespace oomph From f185ce660a2b541970512094b0714b794134f6da Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Thu, 27 Nov 2025 14:50:38 +0100 Subject: [PATCH 03/25] Add todos --- include/oomph/communicator.hpp | 3 +++ src/nccl/communicator.hpp | 2 ++ 2 files changed, 5 insertions(+) diff --git a/include/oomph/communicator.hpp b/include/oomph/communicator.hpp index 71d9908c..ff668390 100644 --- a/include/oomph/communicator.hpp +++ b/include/oomph/communicator.hpp @@ -100,6 +100,9 @@ class communicator bool is_ready() const noexcept { + // TODO: Would prefer not to count sends/recvs for NCCL. Prefer to check + // if stream or event is done (sends/recvs should be submitted in + // groups). return (scheduled_sends() == 0) && (scheduled_recvs() == 0) && (scheduled_shared_recvs() == 0); } diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp index 7aa2899b..d6159412 100644 --- a/src/nccl/communicator.hpp +++ b/src/nccl/communicator.hpp @@ -75,6 +75,8 @@ class communicator_impl : public communicator_base auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb), req); // s->create_self_ref(); // TODO: Callback ignored. + // TODO: Have to respect `scheduled`. Needs to be incremented before + // send and decremeted when done. // m_send_reqs.enqueue(s.get()); return {std::move(s)}; // } From d4909b3b77804ac8ef1147ed139156334be5bd1c Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Wed, 3 Dec 2025 13:36:23 +0100 Subject: [PATCH 04/25] Update nccl support --- CMakeLists.txt | 5 ++ cmake/oomph_nccl.cmake | 4 +- include/oomph/communicator.hpp | 88 ++++++++++++++++++---------------- src/CMakeLists.txt | 4 ++ src/communicator.cpp | 30 +++++++++--- src/mpi/communicator.hpp | 21 ++++---- src/nccl/communicator.hpp | 31 +++++++----- src/nccl/context.hpp | 19 ++++---- src/nccl/nccl_communicator.hpp | 6 ++- 9 files changed, 128 insertions(+), 80 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ca924a0e..d6a7ab73 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,6 +80,11 @@ include(oomph_ucx) # --------------------------------------------------------------------- include(oomph_libfabric) +# --------------------------------------------------------------------- +# oomph NCCL variant +# --------------------------------------------------------------------- +include(oomph_nccl) + # --------------------------------------------------------------------- # main src subdir # --------------------------------------------------------------------- diff --git a/cmake/oomph_nccl.cmake b/cmake/oomph_nccl.cmake index 7528f820..eda57055 100644 --- a/cmake/oomph_nccl.cmake +++ b/cmake/oomph_nccl.cmake @@ -6,11 +6,11 @@ set(OOMPH_WITH_NCCL OFF CACHE BOOL "Build with NCCL backend") if (OOMPH_WITH_NCCL) - # find_package(NCCL REQUIRED) + find_package(NCCL REQUIRED) add_library(oomph_nccl SHARED) add_library(oomph::nccl ALIAS oomph_nccl) oomph_shared_lib_options(oomph_nccl) - # target_link_libraries(oomph_nccl PUBLIC NCCL::NCCL) + target_link_libraries(oomph_nccl PUBLIC NCCL::nccl) install(TARGETS oomph_nccl EXPORT oomph-targets LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/include/oomph/communicator.hpp b/include/oomph/communicator.hpp index ff668390..29716253 100644 --- a/include/oomph/communicator.hpp +++ b/include/oomph/communicator.hpp @@ -98,6 +98,8 @@ class communicator return m_state->m_shared_scheduled_recvs->load(); } + bool is_stream_aware() const noexcept; + bool is_ready() const noexcept { // TODO: Would prefer not to count sends/recvs for NCCL. Prefer to check @@ -146,6 +148,10 @@ class communicator } #endif + // TODO: const noexcept? + void start_group(); + void end_group(); + // no callback versions // ==================== @@ -153,33 +159,33 @@ class communicator // ---- template - recv_request recv(message_buffer& msg, rank_type src, tag_type tag) + recv_request recv(message_buffer& msg, rank_type src, tag_type tag, void* stream = nullptr) { assert(msg); return recv(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), src, tag, - util::unique_function([](rank_type, tag_type) {})); + util::unique_function([](rank_type, tag_type) {}), stream); } // shared_recv // ----------- template - shared_recv_request shared_recv(message_buffer& msg, rank_type src, tag_type tag) + shared_recv_request shared_recv(message_buffer& msg, rank_type src, tag_type tag, void* stream = nullptr) { assert(msg); return shared_recv(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), src, tag, - util::unique_function([](rank_type, tag_type) {})); + util::unique_function([](rank_type, tag_type) {}), stream); } // send // ---- template - send_request send(message_buffer const& msg, rank_type dst, tag_type tag) + send_request send(message_buffer const& msg, rank_type dst, tag_type tag, void* stream = nullptr) { assert(msg); return send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), dst, tag, - util::unique_function([](rank_type, tag_type) {})); + util::unique_function([](rank_type, tag_type) {}), stream); } // send_multi @@ -187,7 +193,7 @@ class communicator template send_multi_request send_multi(message_buffer const& msg, rank_type const* neighs, - std::size_t neighs_size, tag_type tag) + std::size_t neighs_size, tag_type tag, void* stream = nullptr) { assert(msg); auto mrs = m_state->make_multi_request_state(neighs_size); @@ -195,21 +201,21 @@ class communicator { send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), neighs[i], tag, util::unique_function( - [mrs](rank_type, tag_type) { --(mrs->m_counter); })); + [mrs](rank_type, tag_type) { --(mrs->m_counter); }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer const& msg, - std::vector const& neighs, tag_type tag) + std::vector const& neighs, tag_type tag, void* stream = nullptr) { return send_multi(msg, neighs.data(), neighs.size(), tag); } template send_multi_request send_multi(message_buffer const& msg, rank_type const* neighs, - tag_type const* tags, std::size_t neighs_size) + tag_type const* tags, std::size_t neighs_size, void* stream = nullptr) { assert(msg); auto mrs = m_state->make_multi_request_state(neighs_size); @@ -217,14 +223,14 @@ class communicator { send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), neighs[i], tags[i], util::unique_function( - [mrs](rank_type, tag_type) { --(mrs->m_counter); })); + [mrs](rank_type, tag_type) { --(mrs->m_counter); }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer const& msg, - std::vector const& neighs, std::vector const& tags) + std::vector const& neighs, std::vector const& tags, void* stream = nullptr) { assert(neighs.size() == tags.size()); return send_multi(msg, neighs.data(), tags.data(), neighs.size()); @@ -237,7 +243,7 @@ class communicator // ---- template - recv_request recv(message_buffer&& msg, rank_type src, tag_type tag, CallBack&& callback) + recv_request recv(message_buffer&& msg, rank_type src, tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK(CallBack) assert(msg); @@ -245,11 +251,11 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - cb_rref{std::forward(callback), std::move(msg)})); + cb_rref{std::forward(callback), std::move(msg)}), stream); } template - recv_request recv(message_buffer& msg, rank_type src, tag_type tag, CallBack&& callback) + recv_request recv(message_buffer& msg, rank_type src, tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_REF(CallBack) assert(msg); @@ -257,7 +263,7 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - cb_lref{std::forward(callback), &msg})); + cb_lref{std::forward(callback), &msg}), stream); } // shared_recv @@ -265,7 +271,7 @@ class communicator template shared_recv_request shared_recv(message_buffer&& msg, rank_type src, tag_type tag, - CallBack&& callback) + CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK(CallBack) assert(msg); @@ -273,12 +279,12 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return shared_recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - cb_rref{std::forward(callback), std::move(msg)})); + cb_rref{std::forward(callback), std::move(msg)}), stream); } template shared_recv_request shared_recv(message_buffer& msg, rank_type src, tag_type tag, - CallBack&& callback) + CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_REF(CallBack) assert(msg); @@ -286,14 +292,14 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return shared_recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - cb_lref{std::forward(callback), &msg})); + cb_lref{std::forward(callback), &msg}), stream); } // send // ---- template - send_request send(message_buffer&& msg, rank_type dst, tag_type tag, CallBack&& callback) + send_request send(message_buffer&& msg, rank_type dst, tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK(CallBack) assert(msg); @@ -301,11 +307,11 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return send(m_ptr, s * sizeof(T), dst, tag, util::unique_function( - cb_rref{std::forward(callback), std::move(msg)})); + cb_rref{std::forward(callback), std::move(msg)}), stream); } template - send_request send(message_buffer& msg, rank_type dst, tag_type tag, CallBack&& callback) + send_request send(message_buffer& msg, rank_type dst, tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_REF(CallBack) assert(msg); @@ -313,12 +319,12 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return send(m_ptr, s * sizeof(T), dst, tag, util::unique_function( - cb_lref{std::forward(callback), &msg})); + cb_lref{std::forward(callback), &msg}), stream); } template send_request send(message_buffer const& msg, rank_type dst, tag_type tag, - CallBack&& callback) + CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_CONST_REF(CallBack) assert(msg); @@ -326,7 +332,7 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return send(m_ptr, s * sizeof(T), dst, tag, util::unique_function( - cb_lref_const{std::forward(callback), &msg})); + cb_lref_const{std::forward(callback), &msg}), stream); } // send_multi @@ -334,7 +340,7 @@ class communicator template send_multi_request send_multi(message_buffer&& msg, std::vector neighs, - tag_type tag, CallBack&& callback) + tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI(CallBack) assert(msg); @@ -352,14 +358,14 @@ class communicator callback(message_buffer(std::move(mrs->m_msg), mrs->m_msg_size), std::move(mrs->m_neighs), t); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer&& msg, std::vector neighs, - std::vector tags, CallBack&& callback) + std::vector tags, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_TAGS(CallBack) assert(msg); @@ -380,14 +386,14 @@ class communicator callback(message_buffer(std::move(mrs->m_msg), mrs->m_msg_size), std::move(mrs->m_neighs), mrs->m_tags); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer& msg, std::vector neighs, - tag_type tag, CallBack&& callback) + tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_REF(CallBack) assert(msg); @@ -405,14 +411,14 @@ class communicator callback(*reinterpret_cast*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), t); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer& msg, std::vector neighs, - std::vector tags, CallBack&& callback) + std::vector tags, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_REF_TAGS(CallBack) assert(msg); @@ -432,14 +438,14 @@ class communicator callback(*reinterpret_cast*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), std::move(mrs->m_tags)); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer const& msg, std::vector neighs, - tag_type tag, CallBack&& callback) + tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_CONST_REF(CallBack) assert(msg); @@ -457,14 +463,14 @@ class communicator callback(*reinterpret_cast const*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), t); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer const& msg, std::vector neighs, - std::vector tags, CallBack&& callback) + std::vector tags, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_CONST_REF_TAGS(CallBack) assert(msg); @@ -484,7 +490,7 @@ class communicator callback(*reinterpret_cast const*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), std::move(mrs->m_tags)); } - })); + }), stream); } return {std::move(mrs)}; } @@ -502,13 +508,13 @@ class communicator #endif send_request send(detail::message_buffer::heap_ptr_impl const* m_ptr, std::size_t size, - rank_type dst, tag_type tag, util::unique_function&& cb); + rank_type dst, tag_type tag, util::unique_function&& cb, void* stream); recv_request recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb); + tag_type tag, util::unique_function&& cb, void* stream); shared_recv_request shared_recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, - rank_type src, tag_type tag, util::unique_function&& cb); + rank_type src, tag_type tag, util::unique_function&& cb, void* stream); }; } // namespace oomph diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ffc2d2b0..affb05cc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,3 +22,7 @@ endif() if (OOMPH_WITH_LIBFABRIC) add_subdirectory(libfabric) endif() + +if (OOMPH_WITH_NCCL) + add_subdirectory(nccl) +endif() diff --git a/src/communicator.cpp b/src/communicator.cpp index 823042cc..4b764fa8 100644 --- a/src/communicator.cpp +++ b/src/communicator.cpp @@ -45,34 +45,52 @@ communicator::mpi_comm() const noexcept return m_state->m_impl->mpi_comm(); } +bool +communicator::is_stream_aware() const noexcept +{ + return m_state->m_impl->is_stream_aware(); +} + void communicator::progress() { m_state->m_impl->progress(); } +void +communicator::start_group() +{ + return m_state->m_impl->start_group(); +} + +void +communicator::end_group() +{ + return m_state->m_impl->end_group(); +} + send_request communicator::send(detail::message_buffer::heap_ptr_impl const* m_ptr, std::size_t size, - rank_type dst, tag_type tag, util::unique_function&& cb) + rank_type dst, tag_type tag, util::unique_function&& cb, void* stream) { return m_state->m_impl->send(m_ptr->m, size, dst, tag, std::move(cb), - &(m_state->scheduled_sends)); + &(m_state->scheduled_sends), stream); } recv_request communicator::recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb) + tag_type tag, util::unique_function&& cb, void* stream) { return m_state->m_impl->recv(m_ptr->m, size, src, tag, std::move(cb), - &(m_state->scheduled_recvs)); + &(m_state->scheduled_recvs), stream); } shared_recv_request communicator::shared_recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, - rank_type src, tag_type tag, util::unique_function&& cb) + rank_type src, tag_type tag, util::unique_function&& cb, void* stream) { return m_state->m_impl->shared_recv(m_ptr->m, size, src, tag, std::move(cb), - m_state->m_shared_scheduled_recvs); + m_state->m_shared_scheduled_recvs, stream); } detail::message_buffer diff --git a/src/mpi/communicator.hpp b/src/mpi/communicator.hpp index 0022b157..b47c47b8 100644 --- a/src/mpi/communicator.hpp +++ b/src/mpi/communicator.hpp @@ -34,8 +34,13 @@ class communicator_impl : public communicator_base auto& get_heap() noexcept { return m_context->get_heap(); } + bool is_stream_aware() const noexcept { return false; } + + void start_group() {} + void end_group() {} + mpi_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag) + tag_type tag, void*) // TODO: Stream ignored, not stream-aware. Separate interface? { MPI_Request r; const_device_guard dg(ptr); @@ -44,7 +49,7 @@ class communicator_impl : public communicator_base } mpi_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - tag_type tag) + tag_type tag, void*) { MPI_Request r; device_guard dg(ptr); @@ -54,9 +59,9 @@ class communicator_impl : public communicator_base send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, void* stream) { - auto req = send(ptr, size, dst, tag); + auto req = send(ptr, size, dst, tag, stream); if (!has_reached_recursion_depth() && req.is_ready()) { auto inc = recursion(); @@ -75,9 +80,9 @@ class communicator_impl : public communicator_base recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, void* stream) { - auto req = recv(ptr, size, src, tag); + auto req = recv(ptr, size, src, tag, stream); if (!has_reached_recursion_depth() && req.is_ready()) { auto inc = recursion(); @@ -96,9 +101,9 @@ class communicator_impl : public communicator_base shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function&& cb, - std::atomic* scheduled) + std::atomic* scheduled, void* stream) { - auto req = recv(ptr, size, src, tag); + auto req = recv(ptr, size, src, tag, stream); if (!m_context->has_reached_recursion_depth() && req.is_ready()) { auto inc = m_context->recursion(); diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp index d6159412..170cbc6b 100644 --- a/src/nccl/communicator.hpp +++ b/src/nccl/communicator.hpp @@ -19,6 +19,7 @@ #include // #include #include +#include namespace oomph { @@ -37,32 +38,36 @@ class communicator_impl : public communicator_base auto& get_heap() noexcept { return m_context->get_heap(); } + bool is_stream_aware() const noexcept { return true; } + + void start_group() { OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); } + + void end_group() { OOMPH_CHECK_NCCL_RESULT(ncclGroupEnd()); } + nccl_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - [[maybe_unused]] tag_type tag) + [[maybe_unused]] tag_type tag, void* stream) { - // TODO: Stream? Currently assume 0. const_device_guard dg(ptr); OOMPH_CHECK_NCCL_RESULT( - ncclSend(dg.data(), size, ncclChar, dst, m_context->m_comm.get(), 0)); + ncclSend(dg.data(), size, ncclChar, dst, m_context->get_comm(), static_cast(stream))); // TODO: Return event to stream? Return void? return {}; } nccl_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - [[maybe_unused]] tag_type tag) + [[maybe_unused]] tag_type tag, void* stream) { - // TODO: Stream? Currently assume 0. device_guard dg(ptr); OOMPH_CHECK_NCCL_RESULT( - ncclRecv(dg.data(), size, ncclChar, src, m_context->m_comm.get(), 0)); + ncclRecv(dg.data(), size, ncclChar, src, m_context->get_comm(), static_cast(stream))); // TODO: Return event to stream? Return void? return {}; } send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag, util::unique_function&& cb, std::size_t* scheduled) + tag_type tag, util::unique_function&& cb, std::size_t* scheduled, void* stream) { - auto req = send(ptr, size, dst, tag); + auto req = send(ptr, size, dst, tag, stream); // if (!has_reached_recursion_depth() && req.is_ready()) // { // auto inc = recursion(); @@ -83,9 +88,9 @@ class communicator_impl : public communicator_base } recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb, std::size_t* scheduled) + tag_type tag, util::unique_function&& cb, std::size_t* scheduled, void* stream) { - auto req = recv(ptr, size, src, tag); + auto req = recv(ptr, size, src, tag, stream); // if (!has_reached_recursion_depth() && req.is_ready()) // { // auto inc = recursion(); @@ -103,9 +108,9 @@ class communicator_impl : public communicator_base shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function&& cb, - std::atomic* scheduled) + std::atomic* scheduled, void* stream) { - auto req = recv(ptr, size, src, tag); + auto req = recv(ptr, size, src, tag, stream); // if (!m_context->has_reached_recursion_depth() && req.is_ready()) // { // auto inc = m_context->recursion(); @@ -127,7 +132,7 @@ class communicator_impl : public communicator_base // Nothing to do to progress NCCL. Just wait for GPU to finish. } - bool cancel_recv(detail::request_state* s) + bool cancel_recv(detail::request_state*) { // TODO: NCCL does not allow cancellation? return false; diff --git a/src/nccl/context.hpp b/src/nccl/context.hpp index 606f9f29..2cfd4dec 100644 --- a/src/nccl/context.hpp +++ b/src/nccl/context.hpp @@ -9,6 +9,7 @@ */ #pragma once +#include #include #include @@ -16,7 +17,7 @@ // paths relative to backend #include #include <../context_base.hpp> -#include +#include namespace oomph { @@ -32,29 +33,31 @@ class context_impl : public context_base detail::nccl_comm m_comm; public: - shared_request_queue m_req_queue; - - public: - context_impl(ncclComm_t comm, bool thread_safe, hwmalloc::heap_config const& heap_config) + context_impl(MPI_Comm comm, bool thread_safe, hwmalloc::heap_config const& heap_config) : context_base(comm, thread_safe) , m_heap{this, heap_config} - , m_comm{nccl_comm{comm}} + , m_comm{oomph::detail::nccl_comm{comm}} { } context_impl(context_impl const&) = delete; context_impl(context_impl&&) = delete; + ncclComm_t get_comm() const noexcept { return m_comm.get(); } + region make_region(void* ptr) const { return {ptr}; } auto& get_heap() noexcept { return m_heap; } communicator_impl* get_communicator(); - void progress() { m_req_queue.progress(); } + void progress() { + // NCCL will make progress on its own. Or deadlock. + } - bool cancel_recv(detail::shared_request_state* r) { + bool cancel_recv(detail::shared_request_state*) { // TODO: Ignore? Can't undo kernel launches. + return false; } unsigned int num_tag_bits() const noexcept { diff --git a/src/nccl/nccl_communicator.hpp b/src/nccl/nccl_communicator.hpp index 5e05ccc9..71944fe6 100644 --- a/src/nccl/nccl_communicator.hpp +++ b/src/nccl/nccl_communicator.hpp @@ -25,7 +25,7 @@ class nccl_comm oomph::util::moved_bit m_moved; public: - nccl_communicator(mpi_comm mpi_comm) + nccl_comm(mpi_comm mpi_comm) { ncclUniqueId id; if (mpi_comm.rank() == 0) { OOMPH_CHECK_NCCL_RESULT(ncclGetUniqueId(&id)); } @@ -36,7 +36,7 @@ class nccl_comm ncclResult_t result; do { OOMPH_CHECK_NCCL_RESULT(ncclCommGetAsyncError(m_comm, &result)); - } + } while (result == ncclInProgress); } nccl_comm(nccl_comm&&) noexcept = default; nccl_comm& operator=(nccl_comm&&) noexcept = default; @@ -52,5 +52,7 @@ class nccl_comm OOMPH_CHECK_NCCL_RESULT_NOEXCEPT(ncclCommDestroy(m_comm)); } } + + ncclComm_t get() const noexcept { return m_comm; } }; } // namespace oomph::detail From 2349474cec96dcb1f3bf2983e3ae8dd98af3b6fa Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Thu, 18 Dec 2025 14:37:41 +0100 Subject: [PATCH 05/25] Slightly more working nccl backend with events as requests and lots of debugging --- cmake/FindNCCL.cmake | 73 ++++ include/oomph/communicator.hpp | 3 +- src/mpi/request.hpp | 1 + src/nccl/communicator.hpp | 134 ++++--- src/nccl/context.hpp | 10 +- src/nccl/cuda_error.hpp | 29 ++ src/nccl/handle.hpp | 20 ++ src/nccl/nccl_error.hpp | 20 +- src/nccl/region.hpp | 45 +++ src/nccl/request.hpp | 21 +- src/nccl/request_queue.hpp | 165 +++++++++ src/nccl/request_state.hpp | 6 +- src/request.cpp | 5 + src/request_state_base.hpp | 6 +- test/CMakeLists.txt | 6 + test/test_send_recv.cpp | 630 +++++++++++++++++---------------- 16 files changed, 805 insertions(+), 369 deletions(-) create mode 100644 cmake/FindNCCL.cmake create mode 100644 src/nccl/cuda_error.hpp create mode 100644 src/nccl/handle.hpp create mode 100644 src/nccl/region.hpp create mode 100644 src/nccl/request_queue.hpp diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake new file mode 100644 index 00000000..d5beae56 --- /dev/null +++ b/cmake/FindNCCL.cmake @@ -0,0 +1,73 @@ +# This is from https://github.com/pytorch/gloo/blob/main/cmake/Modules/Findnccl.cmake. +# TODO: Check that license is compatible. + +# Try to find NCCL +# +# The following variables are optionally searched for defaults +# NCCL_ROOT_DIR: Base directory where all NCCL components are found +# NCCL_INCLUDE_DIR: Directory where NCCL header is found +# NCCL_LIB_DIR: Directory where NCCL library is found +# +# The following are set after configuration is done: +# NCCL_FOUND +# NCCL_INCLUDE_DIRS +# NCCL_LIBRARIES +# +# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks +# install NCCL in the same location as the CUDA toolkit. +# See https://github.com/caffe2/caffe2/issues/1601 + +set(NCCL_ROOT_DIR $ENV{NCCL_ROOT_DIR} CACHE PATH "Folder contains NVIDIA NCCL") + +find_path(NCCL_INCLUDE_DIR + NAMES nccl.h + HINTS + ${NCCL_INCLUDE_DIR} + ${NCCL_ROOT_DIR} + ${NCCL_ROOT_DIR}/include + ${CUDA_TOOLKIT_ROOT_DIR}/include) + +if ($ENV{USE_STATIC_NCCL}) + message(STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library") + set(NCCL_LIBNAME "libnccl_static.a") +else() + set(NCCL_LIBNAME "nccl") +endif() + +find_library(NCCL_LIBRARY + NAMES ${NCCL_LIBNAME} + HINTS + ${NCCL_LIB_DIR} + ${NCCL_ROOT_DIR} + ${NCCL_ROOT_DIR}/lib + ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu + ${NCCL_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib64) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY) + +if (NCCL_FOUND) + set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIR}/nccl.h") + message(STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}") + file (STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED + REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1) + if (NCCL_MAJOR_VERSION_DEFINED) + string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" "" + NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED}) + message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}") + endif() + set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) + set(NCCL_LIBRARIES ${NCCL_LIBRARY}) + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) + + if(NOT TARGET NCCL::nccl AND NCCL_FOUND) + add_library(NCCL::nccl SHARED IMPORTED) + set_target_properties(NCCL::nccl PROPERTIES + IMPORTED_LOCATION ${NCCL_LIBRARIES} + INTERFACE_INCLUDE_DIRECTORIES ${NCCL_INCLUDE_DIRS} + ) + endif() +endif() + diff --git a/include/oomph/communicator.hpp b/include/oomph/communicator.hpp index 29716253..6aebb200 100644 --- a/include/oomph/communicator.hpp +++ b/include/oomph/communicator.hpp @@ -102,6 +102,7 @@ class communicator bool is_ready() const noexcept { + std::cerr << "communicator::is_ready()\n"; // TODO: Would prefer not to count sends/recvs for NCCL. Prefer to check // if stream or event is done (sends/recvs should be submitted in // groups). @@ -210,7 +211,7 @@ class communicator send_multi_request send_multi(message_buffer const& msg, std::vector const& neighs, tag_type tag, void* stream = nullptr) { - return send_multi(msg, neighs.data(), neighs.size(), tag); + return send_multi(msg, neighs.data(), neighs.size(), tag, stream); } template diff --git a/src/mpi/request.hpp b/src/mpi/request.hpp index a126143b..d87356f8 100644 --- a/src/mpi/request.hpp +++ b/src/mpi/request.hpp @@ -19,6 +19,7 @@ struct mpi_request bool is_ready() { + std::cerr << "mpi_request::is_ready\n"; int flag; OOMPH_CHECK_MPI_RESULT(MPI_Test(&m_req, &flag, MPI_STATUS_IGNORE)); return flag; diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp index 170cbc6b..eb391e8b 100644 --- a/src/nccl/communicator.hpp +++ b/src/nccl/communicator.hpp @@ -9,17 +9,20 @@ */ #pragma once +#include +#include + #include #include // paths relative to backend -#include <../communicator_base.hpp> -#include <../device_guard.hpp> -#include -// #include -#include -#include +#include "../communicator_base.hpp" +#include "../device_guard.hpp" +#include "./context.hpp" +#include "./request.hpp" +#include "./request_state.hpp" +#include "./request_queue.hpp" namespace oomph { @@ -27,8 +30,11 @@ class communicator_impl : public communicator_base { public: context_impl* m_context; - // request_queue m_send_reqs; - // request_queue m_recv_reqs; + request_queue m_send_reqs; + request_queue m_recv_reqs; + bool m_in_group = false; + std::optional m_group_event; + cudaStream_t m_last_stream; communicator_impl(context_impl* ctxt) : communicator_base(ctxt) @@ -40,70 +46,96 @@ class communicator_impl : public communicator_base bool is_stream_aware() const noexcept { return true; } - void start_group() { OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); } + void start_group() { + OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); + m_in_group = true; + + // TODO: Correct flags etc. + cudaEvent_t event; + cudaEventCreate(&event); + std::cerr << "created group event " << event << "\n"; + m_group_event = event; + } - void end_group() { OOMPH_CHECK_NCCL_RESULT(ncclGroupEnd()); } + void end_group() { + m_in_group = false; + OOMPH_CHECK_NCCL_RESULT(ncclGroupEnd()); + + // All streams used in a NCCL group synchronize with the end of the group. + // We arbitrarily pick the last stream to synchronize against. + OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(m_group_event.value(), m_last_stream)); + // TODO: Release event. + } nccl_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, [[maybe_unused]] tag_type tag, void* stream) { + std::cerr << "nccl::send\n"; + const_device_guard dg(ptr); OOMPH_CHECK_NCCL_RESULT( ncclSend(dg.data(), size, ncclChar, dst, m_context->get_comm(), static_cast(stream))); - // TODO: Return event to stream? Return void? - return {}; + + if (m_in_group) { + m_last_stream = static_cast(stream); + // Store event now, but record it when group ends + // TODO: Have to make sure it's safe to query event early. + std::cerr << "using group event " << m_group_event.value() << "\n"; + return {m_group_event.value()}; + } else { + // TODO: Correct flags etc. + // TODO: Free event. + cudaEvent_t event; + cudaEventCreate(&event); + OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(event, static_cast(stream))); + return {event}; + } } nccl_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, [[maybe_unused]] tag_type tag, void* stream) { + std::cerr << "nccl::recv\n"; + device_guard dg(ptr); OOMPH_CHECK_NCCL_RESULT( ncclRecv(dg.data(), size, ncclChar, src, m_context->get_comm(), static_cast(stream))); - // TODO: Return event to stream? Return void? - return {}; + + if (m_in_group) { + m_last_stream = static_cast(stream); + // Store event now, but record it when group ends + std::cerr << "using group event " << m_group_event.value() << "\n"; + return {m_group_event.value()}; + } else { + // TODO: Correct flags etc. + // TODO: Free event. + cudaEvent_t event; + cudaEventCreate(&event); + OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(event, static_cast(stream))); + return {event}; + } } send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, tag_type tag, util::unique_function&& cb, std::size_t* scheduled, void* stream) { auto req = send(ptr, size, dst, tag, stream); - // if (!has_reached_recursion_depth() && req.is_ready()) - // { - // auto inc = recursion(); - // cb(dst, tag); - // return {}; - // } - // else - // { - // TODO: Do we want to support callbacks for NCCL communication? How should this be structured? + // TODO: Do early checking? auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb), req); - // s->create_self_ref(); - // TODO: Callback ignored. - // TODO: Have to respect `scheduled`. Needs to be incremented before - // send and decremeted when done. - // m_send_reqs.enqueue(s.get()); + s->create_self_ref(); + m_send_reqs.enqueue(s.get()); return {std::move(s)}; - // } } recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function&& cb, std::size_t* scheduled, void* stream) { auto req = recv(ptr, size, src, tag, stream); - // if (!has_reached_recursion_depth() && req.is_ready()) - // { - // auto inc = recursion(); - // cb(src, tag); - // return {}; - // } - // else - // { + // TODO: Do early checking? auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb), req); - // s->create_self_ref(); - // m_recv_reqs.enqueue(s.get()); + s->create_self_ref(); + m_recv_reqs.enqueue(s.get()); return {std::move(s)}; - // } } shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, @@ -111,25 +143,23 @@ class communicator_impl : public communicator_base std::atomic* scheduled, void* stream) { auto req = recv(ptr, size, src, tag, stream); - // if (!m_context->has_reached_recursion_depth() && req.is_ready()) - // { - // auto inc = m_context->recursion(); - // cb(src, tag); - // return {}; - // } - // else - // { + // TODO: Do early checking? auto s = std::make_shared(m_context, this, scheduled, src, tag, std::move(cb), req); - // s->create_self_ref(); - // m_context->m_req_queue.enqueue(s.get()); + s->create_self_ref(); + m_context->m_req_queue.enqueue(s.get()); return {std::move(s)}; - // } } void progress() { - // Nothing to do to progress NCCL. Just wait for GPU to finish. + std::cerr << "nccl communicator::progress\n"; + // Communication progresses independently, but requests must be marked + // ready and callbacks must be invoked. + m_send_reqs.progress(); + m_recv_reqs.progress(); + m_context->progress(); + // std::this_thread::sleep_for(std::chrono::seconds(1)); } bool cancel_recv(detail::request_state*) diff --git a/src/nccl/context.hpp b/src/nccl/context.hpp index 2cfd4dec..2dbd850f 100644 --- a/src/nccl/context.hpp +++ b/src/nccl/context.hpp @@ -18,6 +18,7 @@ #include #include <../context_base.hpp> #include +#include "./request_queue.hpp" namespace oomph { @@ -32,12 +33,19 @@ class context_impl : public context_base heap_type m_heap; detail::nccl_comm m_comm; + public: + shared_request_queue m_req_queue; + public: context_impl(MPI_Comm comm, bool thread_safe, hwmalloc::heap_config const& heap_config) : context_base(comm, thread_safe) , m_heap{this, heap_config} , m_comm{oomph::detail::nccl_comm{comm}} { + if (thread_safe) { + // TODO: Appropriate? + throw std::runtime_error("NCCL not supported with thread_safe = true"); + } } context_impl(context_impl const&) = delete; @@ -52,7 +60,7 @@ class context_impl : public context_base communicator_impl* get_communicator(); void progress() { - // NCCL will make progress on its own. Or deadlock. + m_req_queue.progress(); } bool cancel_recv(detail::shared_request_state*) { diff --git a/src/nccl/cuda_error.hpp b/src/nccl/cuda_error.hpp new file mode 100644 index 00000000..04cb8166 --- /dev/null +++ b/src/nccl/cuda_error.hpp @@ -0,0 +1,29 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include +#include +#include + +#define OOMPH_CHECK_CUDA_RESULT(x) \ + if (x != cudaSuccess) \ + throw std::runtime_error("OOMPH Error: CUDA Call failed " + std::string(#x) + " (" + \ + std::string(cudaGetErrorString(x)) + ") in " + \ + std::string(__FILE__) + ":" + std::to_string(__LINE__)); + +#define OOMPH_CHECK_CUDA_RESULT_NO_THROW(x) \ + try { OOMPH_CHECK_CUDA_RESULT(x) } \ + catch (const std::exception& e) { \ + std::cerr << e.what() << std::endl; \ + std::terminate(); \ + } diff --git a/src/nccl/handle.hpp b/src/nccl/handle.hpp new file mode 100644 index 00000000..086f001f --- /dev/null +++ b/src/nccl/handle.hpp @@ -0,0 +1,20 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +namespace oomph +{ +struct handle +{ + void* m_ptr; + std::size_t m_size; +}; +} // namespace oomph + diff --git a/src/nccl/nccl_error.hpp b/src/nccl/nccl_error.hpp index ac4da242..6488a028 100644 --- a/src/nccl/nccl_error.hpp +++ b/src/nccl/nccl_error.hpp @@ -20,14 +20,20 @@ #include #include #define OOMPH_CHECK_NCCL_RESULT(x) \ - if (x != ncclSuccess && x != ncclInProgress) \ - throw std::runtime_error("OOMPH Error: NCCL Call failed " + std::string(#x) + " in " + \ - std::string(__FILE__) + ":" + std::to_string(__LINE__)); + { \ + ncclResult_t r = x; \ + if (r != ncclSuccess && r != ncclInProgress) \ + throw std::runtime_error("OOMPH Error: NCCL Call failed " + std::string(#x) + " = " + std::to_string(r) + " (\"" + ncclGetErrorString(r) + "\") in " + \ + std::string(__FILE__) + ":" + std::to_string(__LINE__)); \ + } #define OOMPH_CHECK_NCCL_RESULT_NOEXCEPT(x) \ - if (x != ncclSuccess && x != ncclInProgress) \ { \ - std::cerr << "OOMPH Error: NCCL Call failed " << std::string(#x) << " in " \ - << std::string(__FILE__) << ":" << std::to_string(__LINE__) << std::endl; \ - std::terminate(); \ + ncclResult_t r = x; \ + if (r != ncclSuccess && r != ncclInProgress) \ + { \ + std::cerr << "OOMPH Error: NCCL Call failed " << std::string(#x) << " in " \ + << std::string(__FILE__) << ":" << std::to_string(__LINE__) << std::endl; \ + std::terminate(); \ + } \ } #endif diff --git a/src/nccl/region.hpp b/src/nccl/region.hpp new file mode 100644 index 00000000..c7593cb0 --- /dev/null +++ b/src/nccl/region.hpp @@ -0,0 +1,45 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +// paths relative to backend +#include + +namespace oomph +{ +class region +{ + public: + using handle_type = handle; + + private: + void* m_ptr; + + public: + region(void* ptr) + : m_ptr{ptr} + { + } + + region(region const&) = delete; + + region(region&& r) noexcept + : m_ptr{std::exchange(r.m_ptr, nullptr)} + { + } + + // get a handle to some portion of the region + handle_type get_handle(std::size_t offset, std::size_t size) + { + return {(void*)((char*)m_ptr + offset), size}; + } +}; +} // namespace oomph + diff --git a/src/nccl/request.hpp b/src/nccl/request.hpp index 4d3ee0eb..333e772b 100644 --- a/src/nccl/request.hpp +++ b/src/nccl/request.hpp @@ -9,13 +9,32 @@ */ #pragma once +#include + +#include "./cuda_error.hpp" + namespace oomph { struct nccl_request { // TODO: Ready when group has completed? Check stream or event? - bool is_ready() { return true; } + bool is_ready() { + std::cerr << "checking if request is ready\n"; + cudaError_t res = cudaEventQuery(m_event); + std::cerr << "request " << m_event << " is in state " << res << "\n"; + if (res == cudaSuccess) { + return true; + } else if (res == cudaErrorNotReady) { + return false; + } else { + OOMPH_CHECK_CUDA_RESULT(res); + return false; + } + } // TODO: No cancellation with NCCL? bool cancel() { return false; } + + // TODO: Use wrapper class + cudaEvent_t m_event; }; } // namespace oomph diff --git a/src/nccl/request_queue.hpp b/src/nccl/request_queue.hpp new file mode 100644 index 00000000..12bb5ba6 --- /dev/null +++ b/src/nccl/request_queue.hpp @@ -0,0 +1,165 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include + +// paths relative to backend +#include + +namespace oomph +{ + +class request_queue +{ + private: + using element_type = detail::request_state; + using queue_type = std::vector; + + private: // members + queue_type m_queue; + bool in_progress = false; + + public: // ctors + request_queue() + { + m_queue.reserve(256); + } + + public: // member functions + std::size_t size() const noexcept { return m_queue.size(); } + + void enqueue(element_type* e) + { + e->m_index = m_queue.size(); + m_queue.push_back(e); + } + + int progress() + { + std::cerr << "nccl request_queue::progress\n"; + if (in_progress) return 0; + in_progress = true; + + const auto qs = size(); + if (qs == 0) + { + in_progress = false; + return 0; + } + + auto erase_begin = std::remove_if( + m_queue.begin(), m_queue.end(), + [](auto& req) { + std::cerr << "checking if request ready with event " << req->m_req.m_event << "\n"; + if (req->m_req.is_ready()) { + auto ptr = req->release_self_ref(); + std::cerr << "invoking callback on req: " << req << "\n"; + req->invoke_cb(); + return true; + } else { + return false; + } + } + ); + auto completed = std::distance(erase_begin, m_queue.end()); + if (completed != 0) { + std::cerr << "completed " << completed << " requests\n"; + } + m_queue.erase(erase_begin, m_queue.end()); + + in_progress = false; + return completed; + } + + bool cancel(element_type*) + { + // No cancellation with NCCL. + return false; + } +}; + +class shared_request_queue +{ + private: + using element_type = detail::shared_request_state; + using queue_type = boost::lockfree::queue, + boost::lockfree::allocator>>; + + private: // members + queue_type m_queue; + std::atomic m_size; + + public: // ctors + shared_request_queue() + : m_queue(256) + , m_size(0) + { + } + + public: // member functions + std::size_t size() const noexcept { return m_size.load(); } + + void enqueue(element_type* e) + { + m_queue.push(e); + ++m_size; + } + + int progress() + { + std::cerr << "nccl shared_request_queue::progress\n"; + + static thread_local bool in_progress = false; + static thread_local std::vector m_local_queue; + int found = 0; + + if (in_progress) return 0; + in_progress = true; + + element_type* e; + while (m_queue.pop(e)) + { + if (e->m_req.is_ready()) + { + std::cerr << "found ready request in shared queue\n"; + found = 1; + break; + } + else + { + m_local_queue.push_back(e); + } + } + + for (auto x : m_local_queue) m_queue.push(x); + m_local_queue.clear(); + + if (found) + { + auto ptr = e->release_self_ref(); + e->invoke_cb(); + --m_size; + } + + in_progress = false; + return found; + } + + bool cancel(element_type*) + { + // No cancellation with NCCL. + return false; + } +}; + +} // namespace oomph diff --git a/src/nccl/request_state.hpp b/src/nccl/request_state.hpp index 0eb061de..eb21cd8a 100644 --- a/src/nccl/request_state.hpp +++ b/src/nccl/request_state.hpp @@ -12,8 +12,8 @@ #include // paths relative to backend -#include <../request_state_base.hpp> -#include +#include "../request_state_base.hpp" +#include "./request.hpp" namespace oomph { @@ -35,6 +35,7 @@ struct request_state : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} , m_req{m} { + std::cerr << "creating nccl request_state\n"; } void progress(); @@ -72,6 +73,7 @@ struct shared_request_state : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} , m_req{m} { + std::cerr << "creating nccl shared_request_state\n"; } void progress(); diff --git a/src/request.cpp b/src/request.cpp index 972650f3..9749979d 100644 --- a/src/request.cpp +++ b/src/request.cpp @@ -24,6 +24,7 @@ namespace oomph bool send_request::is_ready() const noexcept { + std::cerr << "send_request::is_ready()\n"; if (!m) return true; return m->is_ready(); } @@ -46,7 +47,9 @@ send_request::wait() bool recv_request::is_ready() const noexcept { + // std::cerr << "recv_request::is_ready()\n"; if (!m) return true; + // std::cerr << "recv_request::is_ready, checking impl m->is_ready()\n"; return m->is_ready(); } @@ -83,6 +86,7 @@ recv_request::cancel() bool shared_recv_request::is_ready() const noexcept { + std::cerr << "shared_recv_request::is_ready()\n"; if (!m) return true; return m->is_ready(); } @@ -120,6 +124,7 @@ shared_recv_request::cancel() bool send_multi_request::is_ready() const noexcept { + std::cerr << "send_multi_request::is_ready()\n"; if (!m) return true; return (m->m_counter == 0); } diff --git a/src/request_state_base.hpp b/src/request_state_base.hpp index c0a6598a..1383dce6 100644 --- a/src/request_state_base.hpp +++ b/src/request_state_base.hpp @@ -88,12 +88,16 @@ struct request_state_base ++(*m_scheduled); } - bool is_ready() const noexcept { return traits::load(m_ready); } + bool is_ready() const noexcept { + // std::cerr << "request_state_base::is_ready()\n"; + return traits::load(m_ready); + } bool is_canceled() const noexcept { return traits::load(m_canceled); } void invoke_cb() { + std::cerr << "invoke_cb, setting m_ready to true\n"; m_cb(m_rank, m_tag); --(*m_scheduled); traits::store(m_ready, true); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5217bbaf..10dd3bbe 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -86,4 +86,10 @@ if (OOMPH_WITH_LIBFABRIC) endforeach() endif() +if (OOMPH_WITH_NCCL) + foreach(t ${parallel_tests}) + reg_parallel_test(${t} nccl 4) + endforeach() +endif() + add_subdirectory(bindings) diff --git a/test/test_send_recv.cpp b/test/test_send_recv.cpp index 0cfd1170..3943bb47 100644 --- a/test/test_send_recv.cpp +++ b/test/test_send_recv.cpp @@ -182,8 +182,11 @@ template void launch_test(Func f) { + std::cerr << "launch_test\n"; + // single threaded { + std::cerr << "single threaded\n"; oomph::context ctxt(MPI_COMM_WORLD, false); reset_counters(); f(ctxt, SIZE, 0, 1, false); @@ -192,20 +195,22 @@ launch_test(Func f) } // multi threaded - { - oomph::context ctxt(MPI_COMM_WORLD, true); - std::vector threads; - threads.reserve(NTHREADS); - reset_counters(); - for (int i = 0; i < NTHREADS; ++i) - threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, false}); - for (auto& t : threads) t.join(); - threads.clear(); - reset_counters(); - for (int i = 0; i < NTHREADS; ++i) - threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, true}); - for (auto& t : threads) t.join(); - } + // TODO: Don't run for NCCL, run for others. + // { + // std::cerr << "multi threaded\n"; + // oomph::context ctxt(MPI_COMM_WORLD, true); + // std::vector threads; + // threads.reserve(NTHREADS); + // reset_counters(); + // for (int i = 0; i < NTHREADS; ++i) + // threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, false}); + // for (auto& t : threads) t.join(); + // threads.clear(); + // reset_counters(); + // for (int i = 0; i < NTHREADS; ++i) + // threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, true}); + // for (auto& t : threads) t.join(); + // } } // no callback @@ -217,42 +222,57 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, Env env(ctxt, size, tid, num_threads, user_alloc); // use is_ready() -> must manually progress the communicator + std::cerr << "test_send_recv 1\n"; for (int i = 0; i < NITERS; i++) { + std::cerr << "iteration " << i << "\n"; + env.comm.start_group(); auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); auto sreq = env.comm.send(env.smsg, env.speer_rank, env.tag); + env.comm.end_group(); + std::cerr << "rreq.is_ready() = " << rreq.is_ready() << '\n'; + std::cerr << "sreq.is_ready() = " << sreq.is_ready() << '\n'; while (!(rreq.is_ready() && sreq.is_ready())) { + std::cerr << "calling env.comm.progress()\n"; env.comm.progress(); }; EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); } + std::cerr << "test_send_recv 1 done\n"; + std::cerr << "test_send_recv 2\n"; // use test() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); auto sreq = env.comm.send(env.smsg, env.speer_rank, env.tag); + env.comm.end_group(); while (!(rreq.test() && sreq.test())) {}; EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); } - // use wait() -> communicator is progressed automatically - for (int i = 0; i < NITERS; i++) - { - auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); - env.comm.send(env.smsg, env.speer_rank, env.tag).wait(); - rreq.wait(); - EXPECT_TRUE(env.check_recv_buffer()); - env.fill_recv_buffer(); - } + // std::cerr << "test_send_recv 3\n"; + // // use wait() -> communicator is progressed automatically + // for (int i = 0; i < NITERS; i++) + // { + // env.comm.start_group(); + // auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); + // env.comm.send(env.smsg, env.speer_rank, env.tag).wait(); + // env.comm.end_group(); + // rreq.wait(); + // EXPECT_TRUE(env.check_recv_buffer()); + // env.fill_recv_buffer(); + // } } TEST_F(mpi_test_fixture, send_recv) { - launch_test(test_send_recv); + // TODO: Only device tests with NCCL. + // launch_test(test_send_recv); #if HWMALLOC_ENABLE_DEVICE launch_test(test_send_recv); #endif @@ -279,8 +299,10 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa // use is_ready() -> must manually progress the communicator for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.recv(env.rmsg, env.rpeer_rank, 1, recv_callback); auto sh = env.comm.send(env.smsg, env.speer_rank, 1, send_callback); + env.comm.end_group(); while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -317,283 +339,283 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa EXPECT_EQ(sent, NITERS); } -TEST_F(mpi_test_fixture, send_recv_cb) -{ - launch_test(test_send_recv_cb); -#if HWMALLOC_ENABLE_DEVICE - launch_test(test_send_recv_cb); -#endif -} - -// callback: pass by r-value reference (give up ownership) -// ======================================================= -template -void -test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, - bool user_alloc) -{ - using rank_type = test_environment::rank_type; - using tag_type = test_environment::tag_type; - using message = test_environment::message; - - Env env(ctxt, size, tid, num_threads, user_alloc); - - volatile int received = 0; - volatile int sent = 0; - - auto send_callback = [&](message msg, rank_type, tag_type) - { - ++sent; - env.smsg = std::move(msg); - }; - auto recv_callback = [&](message msg, rank_type, tag_type) - { - ++received; - env.rmsg = std::move(msg); - }; - - // use is_ready() -> must manually progress the communicator - for (int i = 0; i < NITERS; i++) - { - auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); - auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); - while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } - EXPECT_TRUE(env.check_recv_buffer()); - env.fill_recv_buffer(); - } - EXPECT_EQ(received, NITERS); - EXPECT_EQ(sent, NITERS); - - received = 0; - sent = 0; - // use test() -> communicator is progressed automatically - for (int i = 0; i < NITERS; i++) - { - auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); - auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); - while (!rh.test() || !sh.test()) {} - EXPECT_TRUE(env.check_recv_buffer()); - env.fill_recv_buffer(); - } - EXPECT_EQ(received, NITERS); - EXPECT_EQ(sent, NITERS); - - received = 0; - sent = 0; - // use wait() -> communicator is progressed automatically - for (int i = 0; i < NITERS; i++) - { - auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); - env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback).wait(); - rh.wait(); - EXPECT_TRUE(env.check_recv_buffer()); - env.fill_recv_buffer(); - } - EXPECT_EQ(received, NITERS); - EXPECT_EQ(sent, NITERS); -} - -TEST_F(mpi_test_fixture, send_recv_cb_disown) -{ - launch_test(test_send_recv_cb_disown); -#if HWMALLOC_ENABLE_DEVICE - launch_test(test_send_recv_cb_disown); -#endif -} - -// callback: pass by r-value reference (give up ownership), shared recv -// ==================================================================== -template -void -test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, - bool user_alloc) -{ - using rank_type = test_environment::rank_type; - using tag_type = test_environment::tag_type; - using message = test_environment::message; - - Env env(ctxt, size, tid, num_threads, user_alloc); - - thread_id = env.thread_id; - - //volatile int received = 0; - volatile int sent = 0; - - auto send_callback = [&](message msg, rank_type, tag_type) - { - ++sent; - env.smsg = std::move(msg); - }; - auto recv_callback = [&](message msg, rank_type, tag_type) - { - //std::cout << thread_id << " " << env.thread_id << std::endl; - //if (thread_id != env.thread_id) std::cout << "other thread picked up callback" << std::endl; - //else std::cout << "my thread picked up callback" << std::endl; - env.rmsg = std::move(msg); - ++shared_received[env.thread_id]; - }; - - // use is_ready() -> must manually progress the communicator - for (int i = 0; i < NITERS; i++) - { - auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); - auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); - while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } - EXPECT_TRUE(env.rmsg); - EXPECT_TRUE(env.check_recv_buffer()); - env.fill_recv_buffer(); - } - EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); - EXPECT_EQ(sent, NITERS); - - shared_received[env.thread_id].store(0); - sent = 0; - // use test() -> communicator is progressed automatically - for (int i = 0; i < NITERS; i++) - { - auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); - auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); - while (!rh.test() || !sh.test()) {} - EXPECT_TRUE(env.check_recv_buffer()); - env.fill_recv_buffer(); - } - EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); - EXPECT_EQ(sent, NITERS); - - shared_received[env.thread_id].store(0); - sent = 0; - // use wait() -> communicator is progressed automatically - for (int i = 0; i < NITERS; i++) - { - auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); - env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback).wait(); - rh.wait(); - EXPECT_TRUE(env.check_recv_buffer()); - env.fill_recv_buffer(); - } - EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); - EXPECT_EQ(sent, NITERS); -} - -TEST_F(mpi_test_fixture, send_shared_recv_cb_disown) -{ - launch_test(test_send_shared_recv_cb_disown); -#if HWMALLOC_ENABLE_DEVICE - launch_test(test_send_shared_recv_cb_disown); -#endif -} - -// callback: pass by l-value reference, and resubmit -// ================================================= -template -void -test_send_recv_cb_resubmit(oomph::context& ctxt, std::size_t size, int tid, int num_threads, - bool user_alloc) -{ - using rank_type = test_environment::rank_type; - using tag_type = test_environment::tag_type; - using message = test_environment::message; - - Env env(ctxt, size, tid, num_threads, user_alloc); - - volatile int received = 0; - volatile int sent = 0; - - struct recursive_send_callback - { - Env& env; - volatile int& sent; - - void operator()(message& msg, rank_type dst, tag_type tag) - { - ++sent; - if (sent < NITERS) env.comm.send(msg, dst, tag, recursive_send_callback{*this}); - } - }; - - struct recursive_recv_callback - { - Env& env; - volatile int& received; - - void operator()(message& msg, rank_type src, tag_type tag) - { - ++received; - EXPECT_TRUE(env.check_recv_buffer()); - env.fill_recv_buffer(); - if (received < NITERS) env.comm.recv(msg, src, tag, recursive_recv_callback{*this}); - } - }; - - env.comm.recv(env.rmsg, env.rpeer_rank, 1, recursive_recv_callback{env, received}); - env.comm.send(env.smsg, env.speer_rank, 1, recursive_send_callback{env, sent}); - - while (sent < NITERS || received < NITERS) { env.comm.progress(); }; -} - -TEST_F(mpi_test_fixture, send_recv_cb_resubmit) -{ - launch_test(test_send_recv_cb_resubmit); -#if HWMALLOC_ENABLE_DEVICE - launch_test(test_send_recv_cb_resubmit); -#endif -} - -// callback: pass by r-value reference (give up ownership), and resubmit -// ===================================================================== -template -void -test_send_recv_cb_resubmit_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, - bool user_alloc) -{ - using rank_type = test_environment::rank_type; - using tag_type = test_environment::tag_type; - using message = test_environment::message; - - Env env(ctxt, size, tid, num_threads, user_alloc); - - volatile int received = 0; - volatile int sent = 0; - - struct recursive_send_callback - { - Env& env; - volatile int& sent; - - void operator()(message msg, rank_type dst, tag_type tag) - { - ++sent; - if (sent < NITERS) - env.comm.send(std::move(msg), dst, tag, recursive_send_callback{*this}); - } - }; - - struct recursive_recv_callback - { - Env& env; - volatile int& received; - - void operator()(message msg, rank_type src, tag_type tag) - { - ++received; - env.rmsg = std::move(msg); - EXPECT_TRUE(env.check_recv_buffer()); - env.fill_recv_buffer(); - if (received < NITERS) - env.comm.recv(std::move(env.rmsg), src, tag, recursive_recv_callback{*this}); - } - }; - - env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recursive_recv_callback{env, received}); - env.comm.send(std::move(env.smsg), env.speer_rank, 1, recursive_send_callback{env, sent}); - - while (sent < NITERS || received < NITERS) { env.comm.progress(); }; -} - -TEST_F(mpi_test_fixture, send_recv_cb_resubmit_disown) -{ - launch_test(test_send_recv_cb_resubmit_disown); -#if HWMALLOC_ENABLE_DEVICE - launch_test(test_send_recv_cb_resubmit_disown); -#endif -} +// TEST_F(mpi_test_fixture, send_recv_cb) +// { +// launch_test(test_send_recv_cb); +// #if HWMALLOC_ENABLE_DEVICE +// launch_test(test_send_recv_cb); +// #endif +// } +// +// // callback: pass by r-value reference (give up ownership) +// // ======================================================= +// template +// void +// test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, +// bool user_alloc) +// { +// using rank_type = test_environment::rank_type; +// using tag_type = test_environment::tag_type; +// using message = test_environment::message; +// +// Env env(ctxt, size, tid, num_threads, user_alloc); +// +// volatile int received = 0; +// volatile int sent = 0; +// +// auto send_callback = [&](message msg, rank_type, tag_type) +// { +// ++sent; +// env.smsg = std::move(msg); +// }; +// auto recv_callback = [&](message msg, rank_type, tag_type) +// { +// ++received; +// env.rmsg = std::move(msg); +// }; +// +// // use is_ready() -> must manually progress the communicator +// for (int i = 0; i < NITERS; i++) +// { +// auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); +// auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); +// while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } +// EXPECT_TRUE(env.check_recv_buffer()); +// env.fill_recv_buffer(); +// } +// EXPECT_EQ(received, NITERS); +// EXPECT_EQ(sent, NITERS); +// +// received = 0; +// sent = 0; +// // use test() -> communicator is progressed automatically +// for (int i = 0; i < NITERS; i++) +// { +// auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); +// auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); +// while (!rh.test() || !sh.test()) {} +// EXPECT_TRUE(env.check_recv_buffer()); +// env.fill_recv_buffer(); +// } +// EXPECT_EQ(received, NITERS); +// EXPECT_EQ(sent, NITERS); +// +// received = 0; +// sent = 0; +// // use wait() -> communicator is progressed automatically +// for (int i = 0; i < NITERS; i++) +// { +// auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); +// env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback).wait(); +// rh.wait(); +// EXPECT_TRUE(env.check_recv_buffer()); +// env.fill_recv_buffer(); +// } +// EXPECT_EQ(received, NITERS); +// EXPECT_EQ(sent, NITERS); +// } +// +// TEST_F(mpi_test_fixture, send_recv_cb_disown) +// { +// launch_test(test_send_recv_cb_disown); +// #if HWMALLOC_ENABLE_DEVICE +// launch_test(test_send_recv_cb_disown); +// #endif +// } +// +// // callback: pass by r-value reference (give up ownership), shared recv +// // ==================================================================== +// template +// void +// test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, +// bool user_alloc) +// { +// using rank_type = test_environment::rank_type; +// using tag_type = test_environment::tag_type; +// using message = test_environment::message; +// +// Env env(ctxt, size, tid, num_threads, user_alloc); +// +// thread_id = env.thread_id; +// +// //volatile int received = 0; +// volatile int sent = 0; +// +// auto send_callback = [&](message msg, rank_type, tag_type) +// { +// ++sent; +// env.smsg = std::move(msg); +// }; +// auto recv_callback = [&](message msg, rank_type, tag_type) +// { +// //std::cout << thread_id << " " << env.thread_id << std::endl; +// //if (thread_id != env.thread_id) std::cout << "other thread picked up callback" << std::endl; +// //else std::cout << "my thread picked up callback" << std::endl; +// env.rmsg = std::move(msg); +// ++shared_received[env.thread_id]; +// }; +// +// // use is_ready() -> must manually progress the communicator +// for (int i = 0; i < NITERS; i++) +// { +// auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); +// auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); +// while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } +// EXPECT_TRUE(env.rmsg); +// EXPECT_TRUE(env.check_recv_buffer()); +// env.fill_recv_buffer(); +// } +// EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); +// EXPECT_EQ(sent, NITERS); +// +// shared_received[env.thread_id].store(0); +// sent = 0; +// // use test() -> communicator is progressed automatically +// for (int i = 0; i < NITERS; i++) +// { +// auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); +// auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); +// while (!rh.test() || !sh.test()) {} +// EXPECT_TRUE(env.check_recv_buffer()); +// env.fill_recv_buffer(); +// } +// EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); +// EXPECT_EQ(sent, NITERS); +// +// shared_received[env.thread_id].store(0); +// sent = 0; +// // use wait() -> communicator is progressed automatically +// for (int i = 0; i < NITERS; i++) +// { +// auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); +// env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback).wait(); +// rh.wait(); +// EXPECT_TRUE(env.check_recv_buffer()); +// env.fill_recv_buffer(); +// } +// EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); +// EXPECT_EQ(sent, NITERS); +// } +// +// TEST_F(mpi_test_fixture, send_shared_recv_cb_disown) +// { +// launch_test(test_send_shared_recv_cb_disown); +// #if HWMALLOC_ENABLE_DEVICE +// launch_test(test_send_shared_recv_cb_disown); +// #endif +// } +// +// // callback: pass by l-value reference, and resubmit +// // ================================================= +// template +// void +// test_send_recv_cb_resubmit(oomph::context& ctxt, std::size_t size, int tid, int num_threads, +// bool user_alloc) +// { +// using rank_type = test_environment::rank_type; +// using tag_type = test_environment::tag_type; +// using message = test_environment::message; +// +// Env env(ctxt, size, tid, num_threads, user_alloc); +// +// volatile int received = 0; +// volatile int sent = 0; +// +// struct recursive_send_callback +// { +// Env& env; +// volatile int& sent; +// +// void operator()(message& msg, rank_type dst, tag_type tag) +// { +// ++sent; +// if (sent < NITERS) env.comm.send(msg, dst, tag, recursive_send_callback{*this}); +// } +// }; +// +// struct recursive_recv_callback +// { +// Env& env; +// volatile int& received; +// +// void operator()(message& msg, rank_type src, tag_type tag) +// { +// ++received; +// EXPECT_TRUE(env.check_recv_buffer()); +// env.fill_recv_buffer(); +// if (received < NITERS) env.comm.recv(msg, src, tag, recursive_recv_callback{*this}); +// } +// }; +// +// env.comm.recv(env.rmsg, env.rpeer_rank, 1, recursive_recv_callback{env, received}); +// env.comm.send(env.smsg, env.speer_rank, 1, recursive_send_callback{env, sent}); +// +// while (sent < NITERS || received < NITERS) { env.comm.progress(); }; +// } +// +// TEST_F(mpi_test_fixture, send_recv_cb_resubmit) +// { +// launch_test(test_send_recv_cb_resubmit); +// #if HWMALLOC_ENABLE_DEVICE +// launch_test(test_send_recv_cb_resubmit); +// #endif +// } +// +// // callback: pass by r-value reference (give up ownership), and resubmit +// // ===================================================================== +// template +// void +// test_send_recv_cb_resubmit_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, +// bool user_alloc) +// { +// using rank_type = test_environment::rank_type; +// using tag_type = test_environment::tag_type; +// using message = test_environment::message; +// +// Env env(ctxt, size, tid, num_threads, user_alloc); +// +// volatile int received = 0; +// volatile int sent = 0; +// +// struct recursive_send_callback +// { +// Env& env; +// volatile int& sent; +// +// void operator()(message msg, rank_type dst, tag_type tag) +// { +// ++sent; +// if (sent < NITERS) +// env.comm.send(std::move(msg), dst, tag, recursive_send_callback{*this}); +// } +// }; +// +// struct recursive_recv_callback +// { +// Env& env; +// volatile int& received; +// +// void operator()(message msg, rank_type src, tag_type tag) +// { +// ++received; +// env.rmsg = std::move(msg); +// EXPECT_TRUE(env.check_recv_buffer()); +// env.fill_recv_buffer(); +// if (received < NITERS) +// env.comm.recv(std::move(env.rmsg), src, tag, recursive_recv_callback{*this}); +// } +// }; +// +// env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recursive_recv_callback{env, received}); +// env.comm.send(std::move(env.smsg), env.speer_rank, 1, recursive_send_callback{env, sent}); +// +// while (sent < NITERS || received < NITERS) { env.comm.progress(); }; +// } +// +// TEST_F(mpi_test_fixture, send_recv_cb_resubmit_disown) +// { +// launch_test(test_send_recv_cb_resubmit_disown); +// #if HWMALLOC_ENABLE_DEVICE +// launch_test(test_send_recv_cb_resubmit_disown); +// #endif +// } From d3a4b042923028abc502f4418ced0239e994aba0 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Thu, 18 Dec 2025 14:47:49 +0100 Subject: [PATCH 06/25] Enable one more nccl test --- test/test_send_recv.cpp | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/test/test_send_recv.cpp b/test/test_send_recv.cpp index 3943bb47..de7826f9 100644 --- a/test/test_send_recv.cpp +++ b/test/test_send_recv.cpp @@ -255,18 +255,21 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, env.fill_recv_buffer(); } - // std::cerr << "test_send_recv 3\n"; - // // use wait() -> communicator is progressed automatically - // for (int i = 0; i < NITERS; i++) - // { - // env.comm.start_group(); - // auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); - // env.comm.send(env.smsg, env.speer_rank, env.tag).wait(); - // env.comm.end_group(); - // rreq.wait(); - // EXPECT_TRUE(env.check_recv_buffer()); - // env.fill_recv_buffer(); - // } + std::cerr << "test_send_recv 3\n"; + // use wait() -> communicator is progressed automatically + for (int i = 0; i < NITERS; i++) + { + env.comm.start_group(); + auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); + // TODO: The sreq.wait was previously called immediately. With NCCL + // groups can't call wait so early (communication hasn't started yet). + auto sreq = env.comm.send(env.smsg, env.speer_rank, env.tag); + env.comm.end_group(); + sreq.wait(); + rreq.wait(); + EXPECT_TRUE(env.check_recv_buffer()); + env.fill_recv_buffer(); + } } TEST_F(mpi_test_fixture, send_recv) From 4deaf8278452a0c0c1009eef66db2aa75469647e Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 19 Dec 2025 14:20:02 +0100 Subject: [PATCH 07/25] Remove TODOs --- include/oomph/communicator.hpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/include/oomph/communicator.hpp b/include/oomph/communicator.hpp index 6aebb200..182b43b7 100644 --- a/include/oomph/communicator.hpp +++ b/include/oomph/communicator.hpp @@ -102,10 +102,6 @@ class communicator bool is_ready() const noexcept { - std::cerr << "communicator::is_ready()\n"; - // TODO: Would prefer not to count sends/recvs for NCCL. Prefer to check - // if stream or event is done (sends/recvs should be submitted in - // groups). return (scheduled_sends() == 0) && (scheduled_recvs() == 0) && (scheduled_shared_recvs() == 0); } @@ -149,7 +145,6 @@ class communicator } #endif - // TODO: const noexcept? void start_group(); void end_group(); From 69a46aa5d8efbffa7400ce4ee50da58297f30ead Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 19 Dec 2025 14:27:29 +0100 Subject: [PATCH 08/25] Add is_stream_aware, start_group, end_group to all backends --- src/libfabric/communicator.hpp | 5 +++++ src/mpi/communicator.hpp | 2 +- src/ucx/communicator.hpp | 5 +++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/libfabric/communicator.hpp b/src/libfabric/communicator.hpp index ff8fc945..68bcbf7e 100644 --- a/src/libfabric/communicator.hpp +++ b/src/libfabric/communicator.hpp @@ -75,6 +75,11 @@ class communicator_impl : public communicator_base // -------------------------------------------------------------------- auto& get_heap() noexcept { return m_context->get_heap(); } + bool is_stream_aware() const noexcept { return false; } + + void start_group() {} + void end_group() {} + // -------------------------------------------------------------------- /// generate a tag with 0xRRRRRRRRtttttttt rank, tag. /// original tag can be 32bits, then we add 32bits of rank info. diff --git a/src/mpi/communicator.hpp b/src/mpi/communicator.hpp index b47c47b8..eebe4286 100644 --- a/src/mpi/communicator.hpp +++ b/src/mpi/communicator.hpp @@ -40,7 +40,7 @@ class communicator_impl : public communicator_base void end_group() {} mpi_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag, void*) // TODO: Stream ignored, not stream-aware. Separate interface? + tag_type tag, void*) { MPI_Request r; const_device_guard dg(ptr); diff --git a/src/ucx/communicator.hpp b/src/ucx/communicator.hpp index dcb4a4ac..f90943a4 100644 --- a/src/ucx/communicator.hpp +++ b/src/ucx/communicator.hpp @@ -70,6 +70,11 @@ class communicator_impl : public communicator_base auto& get_heap() noexcept { return m_context->get_heap(); } + bool is_stream_aware() const noexcept { return false; } + + void start_group() {} + void end_group() {} + void progress() { while (ucp_worker_progress(m_send_worker->get())) {} From 56a0159c14c8e2c9207bac17e39f4a5265f4fa5e Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 19 Dec 2025 16:09:58 +0100 Subject: [PATCH 09/25] Clean up nccl event/request handling --- src/nccl/communicator.hpp | 74 +++++++++++++++++++------------------- src/nccl/request.hpp | 25 +++++-------- src/nccl/request_queue.hpp | 2 +- src/nccl/request_state.hpp | 4 +-- 4 files changed, 48 insertions(+), 57 deletions(-) diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp index eb391e8b..f7888c87 100644 --- a/src/nccl/communicator.hpp +++ b/src/nccl/communicator.hpp @@ -21,8 +21,8 @@ #include "../device_guard.hpp" #include "./context.hpp" #include "./request.hpp" -#include "./request_state.hpp" #include "./request_queue.hpp" +#include "./request_state.hpp" namespace oomph { @@ -32,10 +32,16 @@ class communicator_impl : public communicator_base context_impl* m_context; request_queue m_send_reqs; request_queue m_recv_reqs; - bool m_in_group = false; - std::optional m_group_event; - cudaStream_t m_last_stream; + private: + struct group_info { + detail::group_cuda_event m_event{}; + cudaStream_t m_last_stream{}; + }; + + std::optional m_group_info; + + public: communicator_impl(context_impl* ctxt) : communicator_base(ctxt) , m_context(ctxt) @@ -47,24 +53,23 @@ class communicator_impl : public communicator_base bool is_stream_aware() const noexcept { return true; } void start_group() { - OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); - m_in_group = true; + assert(!m_group_info.has_value()); - // TODO: Correct flags etc. - cudaEvent_t event; - cudaEventCreate(&event); - std::cerr << "created group event " << event << "\n"; - m_group_event = event; + OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); + m_group_info.emplace(); + std::cerr << "started group\n"; + std::cerr << "group_info: " << static_cast(m_group_info->m_event.get()) << "\n"; } void end_group() { - m_in_group = false; + assert(m_group_info.has_value()); + OOMPH_CHECK_NCCL_RESULT(ncclGroupEnd()); // All streams used in a NCCL group synchronize with the end of the group. // We arbitrarily pick the last stream to synchronize against. - OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(m_group_event.value(), m_last_stream)); - // TODO: Release event. + m_group_info->m_event.record(m_group_info->m_last_stream); + m_group_info.reset(); } nccl_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, @@ -76,19 +81,15 @@ class communicator_impl : public communicator_base OOMPH_CHECK_NCCL_RESULT( ncclSend(dg.data(), size, ncclChar, dst, m_context->get_comm(), static_cast(stream))); - if (m_in_group) { - m_last_stream = static_cast(stream); + if (m_group_info.has_value()) { + m_group_info->m_last_stream = static_cast(stream); // Store event now, but record it when group ends - // TODO: Have to make sure it's safe to query event early. - std::cerr << "using group event " << m_group_event.value() << "\n"; - return {m_group_event.value()}; + std::cerr << "using group event " << m_group_info->m_event.get() << "\n"; + return {m_group_info->m_event}; } else { - // TODO: Correct flags etc. - // TODO: Free event. - cudaEvent_t event; - cudaEventCreate(&event); - OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(event, static_cast(stream))); - return {event}; + detail::cuda_event event; + event.record(static_cast(stream)); + return {std::move(event)}; } } @@ -101,18 +102,15 @@ class communicator_impl : public communicator_base OOMPH_CHECK_NCCL_RESULT( ncclRecv(dg.data(), size, ncclChar, src, m_context->get_comm(), static_cast(stream))); - if (m_in_group) { - m_last_stream = static_cast(stream); + if (m_group_info.has_value()) { + m_group_info->m_last_stream = static_cast(stream); // Store event now, but record it when group ends - std::cerr << "using group event " << m_group_event.value() << "\n"; - return {m_group_event.value()}; + std::cerr << "using group event " << m_group_info->m_event.get() << "\n"; + return {m_group_info->m_event}; } else { - // TODO: Correct flags etc. - // TODO: Free event. - cudaEvent_t event; - cudaEventCreate(&event); - OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(event, static_cast(stream))); - return {event}; + detail::cuda_event event; + event.record(static_cast(stream)); + return {std::move(event)}; } } @@ -121,7 +119,7 @@ class communicator_impl : public communicator_base { auto req = send(ptr, size, dst, tag, stream); // TODO: Do early checking? - auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb), req); + auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb), std::move(req)); s->create_self_ref(); m_send_reqs.enqueue(s.get()); return {std::move(s)}; @@ -132,7 +130,7 @@ class communicator_impl : public communicator_base { auto req = recv(ptr, size, src, tag, stream); // TODO: Do early checking? - auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb), req); + auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb), std::move(req)); s->create_self_ref(); m_recv_reqs.enqueue(s.get()); return {std::move(s)}; @@ -145,7 +143,7 @@ class communicator_impl : public communicator_base auto req = recv(ptr, size, src, tag, stream); // TODO: Do early checking? auto s = std::make_shared(m_context, this, scheduled, src, - tag, std::move(cb), req); + tag, std::move(cb), std::move(req)); s->create_self_ref(); m_context->m_req_queue.enqueue(s.get()); return {std::move(s)}; diff --git a/src/nccl/request.hpp b/src/nccl/request.hpp index 333e772b..37a037ee 100644 --- a/src/nccl/request.hpp +++ b/src/nccl/request.hpp @@ -9,32 +9,25 @@ */ #pragma once +#include + #include #include "./cuda_error.hpp" +#include "./cuda_event.hpp" namespace oomph { struct nccl_request { - // TODO: Ready when group has completed? Check stream or event? bool is_ready() { - std::cerr << "checking if request is ready\n"; - cudaError_t res = cudaEventQuery(m_event); - std::cerr << "request " << m_event << " is in state " << res << "\n"; - if (res == cudaSuccess) { - return true; - } else if (res == cudaErrorNotReady) { - return false; - } else { - OOMPH_CHECK_CUDA_RESULT(res); - return false; - } + return std::visit([](auto& event) { + return event.is_ready(); + }, m_event); } - // TODO: No cancellation with NCCL? - bool cancel() { return false; } - // TODO: Use wrapper class - cudaEvent_t m_event; + // We can store either a single event for a particular request, or a shared + // event that signals the end of a NCCL group. + std::variant m_event; }; } // namespace oomph diff --git a/src/nccl/request_queue.hpp b/src/nccl/request_queue.hpp index 12bb5ba6..f5738563 100644 --- a/src/nccl/request_queue.hpp +++ b/src/nccl/request_queue.hpp @@ -60,7 +60,7 @@ class request_queue auto erase_begin = std::remove_if( m_queue.begin(), m_queue.end(), [](auto& req) { - std::cerr << "checking if request ready with event " << req->m_req.m_event << "\n"; + // std::cerr << "checking if request ready with event " << req->m_req.m_event << "\n"; if (req->m_req.is_ready()) { auto ptr = req->release_self_ref(); std::cerr << "invoking callback on req: " << req << "\n"; diff --git a/src/nccl/request_state.hpp b/src/nccl/request_state.hpp index eb21cd8a..26a5d758 100644 --- a/src/nccl/request_state.hpp +++ b/src/nccl/request_state.hpp @@ -33,7 +33,7 @@ struct request_state request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, std::size_t* scheduled, rank_type rank, tag_type tag, cb_type&& cb, nccl_request m) : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} - , m_req{m} + , m_req{std::move(m)} { std::cerr << "creating nccl request_state\n"; } @@ -71,7 +71,7 @@ struct shared_request_state std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb, nccl_request m) : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} - , m_req{m} + , m_req{std::move(m)} { std::cerr << "creating nccl shared_request_state\n"; } From c8e91c1d2bcd798574d19cb71452a4830064c76b Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 19 Dec 2025 16:11:49 +0100 Subject: [PATCH 10/25] Remove debugging print --- src/mpi/request.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mpi/request.hpp b/src/mpi/request.hpp index d87356f8..a126143b 100644 --- a/src/mpi/request.hpp +++ b/src/mpi/request.hpp @@ -19,7 +19,6 @@ struct mpi_request bool is_ready() { - std::cerr << "mpi_request::is_ready\n"; int flag; OOMPH_CHECK_MPI_RESULT(MPI_Test(&m_req, &flag, MPI_STATUS_IGNORE)); return flag; From 90933b3d918b0521e683ffb4d7950a4cc23e5739 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Fri, 19 Dec 2025 17:27:34 +0100 Subject: [PATCH 11/25] cleap --- src/nccl/communicator.hpp | 58 +-- src/nccl/context.cpp | 2 +- src/nccl/context.hpp | 17 +- src/nccl/cuda_error.hpp | 3 +- src/nccl/handle.hpp | 4 +- src/nccl/nccl_communicator.hpp | 17 +- src/nccl/nccl_error.hpp | 31 +- src/nccl/region.hpp | 4 +- src/nccl/request.hpp | 8 +- src/nccl/request_queue.hpp | 34 +- src/nccl/request_state.hpp | 12 +- src/request_state_base.hpp | 2 +- test/test_send_recv.cpp | 631 +++++++++++++++++---------------- 13 files changed, 421 insertions(+), 402 deletions(-) diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp index f7888c87..af68ffaa 100644 --- a/src/nccl/communicator.hpp +++ b/src/nccl/communicator.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2025, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -20,9 +20,9 @@ #include "../communicator_base.hpp" #include "../device_guard.hpp" #include "./context.hpp" -#include "./request.hpp" -#include "./request_queue.hpp" -#include "./request_state.hpp" +#include "request.hpp" +#include "request_queue.hpp" +#include "request_state.hpp" namespace oomph { @@ -35,10 +35,29 @@ class communicator_impl : public communicator_base private: struct group_info { + // A shared CUDA event used for synchronization at the end of the NCCL + // group. All streams used within the group are waited for before the + // group kernel starts and all streams can be used to wait for the + // completion of the group kernel. From + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/streams.html: + // + // NCCL allows for using multiple streams within a group call. This will + // enforce a stream dependency of all streams before the NCCL kernel + // starts and block all streams until the NCCL kernel completes. + // + // It will behave as if the NCCL group operation was posted on every + // stream, but given it is a single operation, it will cause a global + // synchronization point between the streams. detail::group_cuda_event m_event{}; + + // We arbitrarily use the last stream used within a group to synchronize + // the whole group. cudaStream_t m_last_stream{}; }; + // NCCL group information. When no group is active this is std::nullopt. + // When a group is active it contains information used for synchronizing + // with the end of the group kernel. std::optional m_group_info; public: @@ -57,8 +76,9 @@ class communicator_impl : public communicator_base OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); m_group_info.emplace(); - std::cerr << "started group\n"; - std::cerr << "group_info: " << static_cast(m_group_info->m_event.get()) << "\n"; + + // std::cerr << "started group\n"; + // std::cerr << "group_info: " << static_cast(m_group_info->m_event.get()) << "\n"; } void end_group() { @@ -75,7 +95,7 @@ class communicator_impl : public communicator_base nccl_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, [[maybe_unused]] tag_type tag, void* stream) { - std::cerr << "nccl::send\n"; + // std::cerr << "nccl::send\n"; const_device_guard dg(ptr); OOMPH_CHECK_NCCL_RESULT( @@ -83,8 +103,9 @@ class communicator_impl : public communicator_base if (m_group_info.has_value()) { m_group_info->m_last_stream = static_cast(stream); - // Store event now, but record it when group ends - std::cerr << "using group event " << m_group_info->m_event.get() << "\n"; + // std::cerr << "using group event " << m_group_info->m_event.get() << "\n"; + // The event is stored now, but recorded only in end_group. Until + // an event has been recorded the event is never ready. return {m_group_info->m_event}; } else { detail::cuda_event event; @@ -96,7 +117,7 @@ class communicator_impl : public communicator_base nccl_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, [[maybe_unused]] tag_type tag, void* stream) { - std::cerr << "nccl::recv\n"; + // std::cerr << "nccl::recv\n"; device_guard dg(ptr); OOMPH_CHECK_NCCL_RESULT( @@ -104,8 +125,9 @@ class communicator_impl : public communicator_base if (m_group_info.has_value()) { m_group_info->m_last_stream = static_cast(stream); - // Store event now, but record it when group ends - std::cerr << "using group event " << m_group_info->m_event.get() << "\n"; + // std::cerr << "using group event " << m_group_info->m_event.get() << "\n"; + // The event is stored now, but recorded only in end_group. Until + // an event has been recorded the event is never ready. return {m_group_info->m_event}; } else { detail::cuda_event event; @@ -118,7 +140,6 @@ class communicator_impl : public communicator_base tag_type tag, util::unique_function&& cb, std::size_t* scheduled, void* stream) { auto req = send(ptr, size, dst, tag, stream); - // TODO: Do early checking? auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb), std::move(req)); s->create_self_ref(); m_send_reqs.enqueue(s.get()); @@ -129,7 +150,6 @@ class communicator_impl : public communicator_base tag_type tag, util::unique_function&& cb, std::size_t* scheduled, void* stream) { auto req = recv(ptr, size, src, tag, stream); - // TODO: Do early checking? auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb), std::move(req)); s->create_self_ref(); m_recv_reqs.enqueue(s.get()); @@ -141,7 +161,6 @@ class communicator_impl : public communicator_base std::atomic* scheduled, void* stream) { auto req = recv(ptr, size, src, tag, stream); - // TODO: Do early checking? auto s = std::make_shared(m_context, this, scheduled, src, tag, std::move(cb), std::move(req)); s->create_self_ref(); @@ -151,20 +170,15 @@ class communicator_impl : public communicator_base void progress() { - std::cerr << "nccl communicator::progress\n"; + // std::cerr << "nccl communicator::progress\n"; // Communication progresses independently, but requests must be marked // ready and callbacks must be invoked. m_send_reqs.progress(); m_recv_reqs.progress(); m_context->progress(); - // std::this_thread::sleep_for(std::chrono::seconds(1)); } - bool cancel_recv(detail::request_state*) - { - // TODO: NCCL does not allow cancellation? - return false; - } + bool cancel_recv(detail::request_state*) { return false; } }; } // namespace oomph diff --git a/src/nccl/context.cpp b/src/nccl/context.cpp index 32074ba1..ef36973c 100644 --- a/src/nccl/context.cpp +++ b/src/nccl/context.cpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2025, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. diff --git a/src/nccl/context.hpp b/src/nccl/context.hpp index 2dbd850f..2cdacea8 100644 --- a/src/nccl/context.hpp +++ b/src/nccl/context.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2025, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -43,7 +43,6 @@ class context_impl : public context_base , m_comm{oomph::detail::nccl_comm{comm}} { if (thread_safe) { - // TODO: Appropriate? throw std::runtime_error("NCCL not supported with thread_safe = true"); } } @@ -59,19 +58,9 @@ class context_impl : public context_base communicator_impl* get_communicator(); - void progress() { - m_req_queue.progress(); - } - - bool cancel_recv(detail::shared_request_state*) { - // TODO: Ignore? Can't undo kernel launches. - return false; - } + void progress() { m_req_queue.progress(); } - unsigned int num_tag_bits() const noexcept { - // TODO: Important? Can't use tags with NCCL. - return 32; - } + bool cancel_recv(detail::shared_request_state*) { return false; } const char* get_transport_option(const std::string& opt); }; diff --git a/src/nccl/cuda_error.hpp b/src/nccl/cuda_error.hpp index 04cb8166..0a785b3e 100644 --- a/src/nccl/cuda_error.hpp +++ b/src/nccl/cuda_error.hpp @@ -10,9 +10,10 @@ #pragma once #include +#include #include #include -#include + #include #define OOMPH_CHECK_CUDA_RESULT(x) \ diff --git a/src/nccl/handle.hpp b/src/nccl/handle.hpp index 086f001f..16eb0651 100644 --- a/src/nccl/handle.hpp +++ b/src/nccl/handle.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2025, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -9,6 +9,8 @@ */ #pragma once +#include + namespace oomph { struct handle diff --git a/src/nccl/nccl_communicator.hpp b/src/nccl/nccl_communicator.hpp index 71944fe6..82c75774 100644 --- a/src/nccl/nccl_communicator.hpp +++ b/src/nccl/nccl_communicator.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2025, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -9,13 +9,14 @@ */ #pragma once +#include + #include #include -#include -#include <../mpi_comm.hpp> - -#include +#include "../mpi_comm.hpp" +#include "cuda_error.hpp" +#include "nccl_error.hpp" namespace oomph::detail { @@ -46,10 +47,8 @@ class nccl_comm { if (!m_moved) { - // TODO - // OOMPH_CHECK_CUDA_RESULT_NOEXCEPT(cudaDeviceSynchronize()); - cudaDeviceSynchronize(); - OOMPH_CHECK_NCCL_RESULT_NOEXCEPT(ncclCommDestroy(m_comm)); + OOMPH_CHECK_CUDA_RESULT_NO_THROW(cudaDeviceSynchronize()); + OOMPH_CHECK_NCCL_RESULT_NO_THROW(ncclCommDestroy(m_comm)); } } diff --git a/src/nccl/nccl_error.hpp b/src/nccl/nccl_error.hpp index 6488a028..44423e92 100644 --- a/src/nccl/nccl_error.hpp +++ b/src/nccl/nccl_error.hpp @@ -9,31 +9,24 @@ */ #pragma once -#include - -// TODO: Print error string and code. -#ifdef NDEBUG -#define OOMPH_CHECK_NCCL_RESULT(x) x; -#define OOMPH_CHECK_NCCL_RESULT_NOEXCEPT(x) x; -#else #include #include #include + +#include + #define OOMPH_CHECK_NCCL_RESULT(x) \ { \ ncclResult_t r = x; \ if (r != ncclSuccess && r != ncclInProgress) \ - throw std::runtime_error("OOMPH Error: NCCL Call failed " + std::string(#x) + " = " + std::to_string(r) + " (\"" + ncclGetErrorString(r) + "\") in " + \ - std::string(__FILE__) + ":" + std::to_string(__LINE__)); \ + throw std::runtime_error("OOMPH Error: NCCL Call failed " + std::string(#x) + " = " + \ + std::to_string(r) + " (\"" + ncclGetErrorString(r) + \ + "\") in " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__)); \ } -#define OOMPH_CHECK_NCCL_RESULT_NOEXCEPT(x) \ - { \ - ncclResult_t r = x; \ - if (r != ncclSuccess && r != ncclInProgress) \ - { \ - std::cerr << "OOMPH Error: NCCL Call failed " << std::string(#x) << " in " \ - << std::string(__FILE__) << ":" << std::to_string(__LINE__) << std::endl; \ - std::terminate(); \ - } \ +#define OOMPH_CHECK_NCCL_RESULT_NO_THROW(x) \ + try { OOMPH_CHECK_NCCL_RESULT(x) } \ + catch (const std::exception& e) { \ + std::cerr << e.what() << std::endl; \ + std::terminate(); \ } -#endif diff --git a/src/nccl/region.hpp b/src/nccl/region.hpp index c7593cb0..5bb7b2ba 100644 --- a/src/nccl/region.hpp +++ b/src/nccl/region.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2025, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -10,7 +10,7 @@ #pragma once // paths relative to backend -#include +#include "handle.hpp" namespace oomph { diff --git a/src/nccl/request.hpp b/src/nccl/request.hpp index 37a037ee..4e9e1884 100644 --- a/src/nccl/request.hpp +++ b/src/nccl/request.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2025, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -13,8 +13,8 @@ #include -#include "./cuda_error.hpp" -#include "./cuda_event.hpp" +#include "cuda_error.hpp" +#include "cuda_event.hpp" namespace oomph { @@ -26,7 +26,7 @@ struct nccl_request }, m_event); } - // We can store either a single event for a particular request, or a shared + // We store either a single event for a particular request, or a shared // event that signals the end of a NCCL group. std::variant m_event; }; diff --git a/src/nccl/request_queue.hpp b/src/nccl/request_queue.hpp index f5738563..e0806392 100644 --- a/src/nccl/request_queue.hpp +++ b/src/nccl/request_queue.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2025, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -11,14 +11,14 @@ #include #include + #include // paths relative to backend -#include +#include "request_state.hpp" namespace oomph { - class request_queue { private: @@ -46,7 +46,8 @@ class request_queue int progress() { - std::cerr << "nccl request_queue::progress\n"; + // std::cerr << "nccl request_queue::progress\n"; + if (in_progress) return 0; in_progress = true; @@ -63,7 +64,7 @@ class request_queue // std::cerr << "checking if request ready with event " << req->m_req.m_event << "\n"; if (req->m_req.is_ready()) { auto ptr = req->release_self_ref(); - std::cerr << "invoking callback on req: " << req << "\n"; + // std::cerr << "invoking callback on req: " << req << "\n"; req->invoke_cb(); return true; } else { @@ -72,20 +73,16 @@ class request_queue } ); auto completed = std::distance(erase_begin, m_queue.end()); - if (completed != 0) { - std::cerr << "completed " << completed << " requests\n"; - } + // if (completed != 0) { + // std::cerr << "completed " << completed << " requests\n"; + // } m_queue.erase(erase_begin, m_queue.end()); in_progress = false; return completed; } - bool cancel(element_type*) - { - // No cancellation with NCCL. - return false; - } + bool cancel(element_type*) { return false; } }; class shared_request_queue @@ -117,7 +114,7 @@ class shared_request_queue int progress() { - std::cerr << "nccl shared_request_queue::progress\n"; + // std::cerr << "nccl shared_request_queue::progress\n"; static thread_local bool in_progress = false; static thread_local std::vector m_local_queue; @@ -131,7 +128,7 @@ class shared_request_queue { if (e->m_req.is_ready()) { - std::cerr << "found ready request in shared queue\n"; + // std::cerr << "found ready request in shared queue\n"; found = 1; break; } @@ -155,11 +152,6 @@ class shared_request_queue return found; } - bool cancel(element_type*) - { - // No cancellation with NCCL. - return false; - } + bool cancel(element_type*) { return false; } }; - } // namespace oomph diff --git a/src/nccl/request_state.hpp b/src/nccl/request_state.hpp index 26a5d758..edbd383e 100644 --- a/src/nccl/request_state.hpp +++ b/src/nccl/request_state.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2025, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -13,11 +13,9 @@ // paths relative to backend #include "../request_state_base.hpp" -#include "./request.hpp" +#include "request.hpp" -namespace oomph -{ -namespace detail +namespace oomph::detail { struct request_state : public util::enable_shared_from_this @@ -35,7 +33,7 @@ struct request_state : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} , m_req{std::move(m)} { - std::cerr << "creating nccl request_state\n"; + // std::cerr << "creating nccl request_state\n"; } void progress(); @@ -94,6 +92,4 @@ struct shared_request_state return std::move(m_self_ptr); } }; - -} // namespace detail } // namespace oomph diff --git a/src/request_state_base.hpp b/src/request_state_base.hpp index 1383dce6..9110aa71 100644 --- a/src/request_state_base.hpp +++ b/src/request_state_base.hpp @@ -97,7 +97,7 @@ struct request_state_base void invoke_cb() { - std::cerr << "invoke_cb, setting m_ready to true\n"; + // std::cerr << "invoke_cb, setting m_ready to true\n"; m_cb(m_rank, m_tag); --(*m_scheduled); traits::store(m_ready, true); diff --git a/test/test_send_recv.cpp b/test/test_send_recv.cpp index de7826f9..c4fb38bf 100644 --- a/test/test_send_recv.cpp +++ b/test/test_send_recv.cpp @@ -195,22 +195,27 @@ launch_test(Func f) } // multi threaded - // TODO: Don't run for NCCL, run for others. - // { - // std::cerr << "multi threaded\n"; - // oomph::context ctxt(MPI_COMM_WORLD, true); - // std::vector threads; - // threads.reserve(NTHREADS); - // reset_counters(); - // for (int i = 0; i < NTHREADS; ++i) - // threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, false}); - // for (auto& t : threads) t.join(); - // threads.clear(); - // reset_counters(); - // for (int i = 0; i < NTHREADS; ++i) - // threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, true}); - // for (auto& t : threads) t.join(); - // } + try { + std::cerr << "multi threaded\n"; + oomph::context ctxt(MPI_COMM_WORLD, true); + std::vector threads; + threads.reserve(NTHREADS); + reset_counters(); + for (int i = 0; i < NTHREADS; ++i) + threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, false}); + for (auto& t : threads) t.join(); + threads.clear(); + reset_counters(); + for (int i = 0; i < NTHREADS; ++i) + threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, true}); + for (auto& t : threads) t.join(); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } // no callback @@ -274,8 +279,7 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, TEST_F(mpi_test_fixture, send_recv) { - // TODO: Only device tests with NCCL. - // launch_test(test_send_recv); + launch_test(test_send_recv); #if HWMALLOC_ENABLE_DEVICE launch_test(test_send_recv); #endif @@ -318,8 +322,10 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa // use test() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.recv(env.rmsg, env.rpeer_rank, 1, recv_callback); auto sh = env.comm.send(env.smsg, env.speer_rank, 1, send_callback); + env.comm.end_group(); while (!rh.test() || !sh.test()) {} EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -332,8 +338,11 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa // use wait() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.recv(env.rmsg, env.rpeer_rank, 1, recv_callback); - env.comm.send(env.smsg, env.speer_rank, 1, send_callback).wait(); + auto sh = env.comm.send(env.smsg, env.speer_rank, 1, send_callback); + env.comm.end_group(); + sh.wait(); rh.wait(); EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -342,283 +351,307 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa EXPECT_EQ(sent, NITERS); } -// TEST_F(mpi_test_fixture, send_recv_cb) -// { -// launch_test(test_send_recv_cb); -// #if HWMALLOC_ENABLE_DEVICE -// launch_test(test_send_recv_cb); -// #endif -// } -// -// // callback: pass by r-value reference (give up ownership) -// // ======================================================= -// template -// void -// test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, -// bool user_alloc) -// { -// using rank_type = test_environment::rank_type; -// using tag_type = test_environment::tag_type; -// using message = test_environment::message; -// -// Env env(ctxt, size, tid, num_threads, user_alloc); -// -// volatile int received = 0; -// volatile int sent = 0; -// -// auto send_callback = [&](message msg, rank_type, tag_type) -// { -// ++sent; -// env.smsg = std::move(msg); -// }; -// auto recv_callback = [&](message msg, rank_type, tag_type) -// { -// ++received; -// env.rmsg = std::move(msg); -// }; -// -// // use is_ready() -> must manually progress the communicator -// for (int i = 0; i < NITERS; i++) -// { -// auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); -// auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); -// while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } -// EXPECT_TRUE(env.check_recv_buffer()); -// env.fill_recv_buffer(); -// } -// EXPECT_EQ(received, NITERS); -// EXPECT_EQ(sent, NITERS); -// -// received = 0; -// sent = 0; -// // use test() -> communicator is progressed automatically -// for (int i = 0; i < NITERS; i++) -// { -// auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); -// auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); -// while (!rh.test() || !sh.test()) {} -// EXPECT_TRUE(env.check_recv_buffer()); -// env.fill_recv_buffer(); -// } -// EXPECT_EQ(received, NITERS); -// EXPECT_EQ(sent, NITERS); -// -// received = 0; -// sent = 0; -// // use wait() -> communicator is progressed automatically -// for (int i = 0; i < NITERS; i++) -// { -// auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); -// env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback).wait(); -// rh.wait(); -// EXPECT_TRUE(env.check_recv_buffer()); -// env.fill_recv_buffer(); -// } -// EXPECT_EQ(received, NITERS); -// EXPECT_EQ(sent, NITERS); -// } -// -// TEST_F(mpi_test_fixture, send_recv_cb_disown) -// { -// launch_test(test_send_recv_cb_disown); -// #if HWMALLOC_ENABLE_DEVICE -// launch_test(test_send_recv_cb_disown); -// #endif -// } -// -// // callback: pass by r-value reference (give up ownership), shared recv -// // ==================================================================== -// template -// void -// test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, -// bool user_alloc) -// { -// using rank_type = test_environment::rank_type; -// using tag_type = test_environment::tag_type; -// using message = test_environment::message; -// -// Env env(ctxt, size, tid, num_threads, user_alloc); -// -// thread_id = env.thread_id; -// -// //volatile int received = 0; -// volatile int sent = 0; -// -// auto send_callback = [&](message msg, rank_type, tag_type) -// { -// ++sent; -// env.smsg = std::move(msg); -// }; -// auto recv_callback = [&](message msg, rank_type, tag_type) -// { -// //std::cout << thread_id << " " << env.thread_id << std::endl; -// //if (thread_id != env.thread_id) std::cout << "other thread picked up callback" << std::endl; -// //else std::cout << "my thread picked up callback" << std::endl; -// env.rmsg = std::move(msg); -// ++shared_received[env.thread_id]; -// }; -// -// // use is_ready() -> must manually progress the communicator -// for (int i = 0; i < NITERS; i++) -// { -// auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); -// auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); -// while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } -// EXPECT_TRUE(env.rmsg); -// EXPECT_TRUE(env.check_recv_buffer()); -// env.fill_recv_buffer(); -// } -// EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); -// EXPECT_EQ(sent, NITERS); -// -// shared_received[env.thread_id].store(0); -// sent = 0; -// // use test() -> communicator is progressed automatically -// for (int i = 0; i < NITERS; i++) -// { -// auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); -// auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); -// while (!rh.test() || !sh.test()) {} -// EXPECT_TRUE(env.check_recv_buffer()); -// env.fill_recv_buffer(); -// } -// EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); -// EXPECT_EQ(sent, NITERS); -// -// shared_received[env.thread_id].store(0); -// sent = 0; -// // use wait() -> communicator is progressed automatically -// for (int i = 0; i < NITERS; i++) -// { -// auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); -// env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback).wait(); -// rh.wait(); -// EXPECT_TRUE(env.check_recv_buffer()); -// env.fill_recv_buffer(); -// } -// EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); -// EXPECT_EQ(sent, NITERS); -// } -// -// TEST_F(mpi_test_fixture, send_shared_recv_cb_disown) -// { -// launch_test(test_send_shared_recv_cb_disown); -// #if HWMALLOC_ENABLE_DEVICE -// launch_test(test_send_shared_recv_cb_disown); -// #endif -// } -// -// // callback: pass by l-value reference, and resubmit -// // ================================================= -// template -// void -// test_send_recv_cb_resubmit(oomph::context& ctxt, std::size_t size, int tid, int num_threads, -// bool user_alloc) -// { -// using rank_type = test_environment::rank_type; -// using tag_type = test_environment::tag_type; -// using message = test_environment::message; -// -// Env env(ctxt, size, tid, num_threads, user_alloc); -// -// volatile int received = 0; -// volatile int sent = 0; -// -// struct recursive_send_callback -// { -// Env& env; -// volatile int& sent; -// -// void operator()(message& msg, rank_type dst, tag_type tag) -// { -// ++sent; -// if (sent < NITERS) env.comm.send(msg, dst, tag, recursive_send_callback{*this}); -// } -// }; -// -// struct recursive_recv_callback -// { -// Env& env; -// volatile int& received; -// -// void operator()(message& msg, rank_type src, tag_type tag) -// { -// ++received; -// EXPECT_TRUE(env.check_recv_buffer()); -// env.fill_recv_buffer(); -// if (received < NITERS) env.comm.recv(msg, src, tag, recursive_recv_callback{*this}); -// } -// }; -// -// env.comm.recv(env.rmsg, env.rpeer_rank, 1, recursive_recv_callback{env, received}); -// env.comm.send(env.smsg, env.speer_rank, 1, recursive_send_callback{env, sent}); -// -// while (sent < NITERS || received < NITERS) { env.comm.progress(); }; -// } -// -// TEST_F(mpi_test_fixture, send_recv_cb_resubmit) -// { -// launch_test(test_send_recv_cb_resubmit); -// #if HWMALLOC_ENABLE_DEVICE -// launch_test(test_send_recv_cb_resubmit); -// #endif -// } -// -// // callback: pass by r-value reference (give up ownership), and resubmit -// // ===================================================================== -// template -// void -// test_send_recv_cb_resubmit_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, -// bool user_alloc) -// { -// using rank_type = test_environment::rank_type; -// using tag_type = test_environment::tag_type; -// using message = test_environment::message; -// -// Env env(ctxt, size, tid, num_threads, user_alloc); -// -// volatile int received = 0; -// volatile int sent = 0; -// -// struct recursive_send_callback -// { -// Env& env; -// volatile int& sent; -// -// void operator()(message msg, rank_type dst, tag_type tag) -// { -// ++sent; -// if (sent < NITERS) -// env.comm.send(std::move(msg), dst, tag, recursive_send_callback{*this}); -// } -// }; -// -// struct recursive_recv_callback -// { -// Env& env; -// volatile int& received; -// -// void operator()(message msg, rank_type src, tag_type tag) -// { -// ++received; -// env.rmsg = std::move(msg); -// EXPECT_TRUE(env.check_recv_buffer()); -// env.fill_recv_buffer(); -// if (received < NITERS) -// env.comm.recv(std::move(env.rmsg), src, tag, recursive_recv_callback{*this}); -// } -// }; -// -// env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recursive_recv_callback{env, received}); -// env.comm.send(std::move(env.smsg), env.speer_rank, 1, recursive_send_callback{env, sent}); -// -// while (sent < NITERS || received < NITERS) { env.comm.progress(); }; -// } -// -// TEST_F(mpi_test_fixture, send_recv_cb_resubmit_disown) -// { -// launch_test(test_send_recv_cb_resubmit_disown); -// #if HWMALLOC_ENABLE_DEVICE -// launch_test(test_send_recv_cb_resubmit_disown); -// #endif -// } +TEST_F(mpi_test_fixture, send_recv_cb) +{ + launch_test(test_send_recv_cb); +#if HWMALLOC_ENABLE_DEVICE + launch_test(test_send_recv_cb); +#endif +} + +// callback: pass by r-value reference (give up ownership) +// ======================================================= +template +void +test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, + bool user_alloc) +{ + using rank_type = test_environment::rank_type; + using tag_type = test_environment::tag_type; + using message = test_environment::message; + + Env env(ctxt, size, tid, num_threads, user_alloc); + + volatile int received = 0; + volatile int sent = 0; + + auto send_callback = [&](message msg, rank_type, tag_type) + { + ++sent; + env.smsg = std::move(msg); + }; + auto recv_callback = [&](message msg, rank_type, tag_type) + { + ++received; + env.rmsg = std::move(msg); + }; + + // use is_ready() -> must manually progress the communicator + for (int i = 0; i < NITERS; i++) + { + env.comm.start_group(); + auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); + auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); + while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } + EXPECT_TRUE(env.check_recv_buffer()); + env.fill_recv_buffer(); + } + EXPECT_EQ(received, NITERS); + EXPECT_EQ(sent, NITERS); + + received = 0; + sent = 0; + // use test() -> communicator is progressed automatically + for (int i = 0; i < NITERS; i++) + { + env.comm.start_group(); + auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); + auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); + while (!rh.test() || !sh.test()) {} + EXPECT_TRUE(env.check_recv_buffer()); + env.fill_recv_buffer(); + } + EXPECT_EQ(received, NITERS); + EXPECT_EQ(sent, NITERS); + + received = 0; + sent = 0; + // use wait() -> communicator is progressed automatically + for (int i = 0; i < NITERS; i++) + { + env.comm.start_group(); + auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); + auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); + sh.wait(); + rh.wait(); + EXPECT_TRUE(env.check_recv_buffer()); + env.fill_recv_buffer(); + } + EXPECT_EQ(received, NITERS); + EXPECT_EQ(sent, NITERS); +} + +TEST_F(mpi_test_fixture, send_recv_cb_disown) +{ + launch_test(test_send_recv_cb_disown); +#if HWMALLOC_ENABLE_DEVICE + launch_test(test_send_recv_cb_disown); +#endif +} + +// callback: pass by r-value reference (give up ownership), shared recv +// ==================================================================== +template +void +test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, + bool user_alloc) +{ + using rank_type = test_environment::rank_type; + using tag_type = test_environment::tag_type; + using message = test_environment::message; + + Env env(ctxt, size, tid, num_threads, user_alloc); + + thread_id = env.thread_id; + + //volatile int received = 0; + volatile int sent = 0; + + auto send_callback = [&](message msg, rank_type, tag_type) + { + ++sent; + env.smsg = std::move(msg); + }; + auto recv_callback = [&](message msg, rank_type, tag_type) + { + //std::cout << thread_id << " " << env.thread_id << std::endl; + //if (thread_id != env.thread_id) std::cout << "other thread picked up callback" << std::endl; + //else std::cout << "my thread picked up callback" << std::endl; + env.rmsg = std::move(msg); + ++shared_received[env.thread_id]; + }; + + // use is_ready() -> must manually progress the communicator + for (int i = 0; i < NITERS; i++) + { + env.comm.start_group(); + auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); + auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); + while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } + EXPECT_TRUE(env.rmsg); + EXPECT_TRUE(env.check_recv_buffer()); + env.fill_recv_buffer(); + } + EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); + EXPECT_EQ(sent, NITERS); + + shared_received[env.thread_id].store(0); + sent = 0; + // use test() -> communicator is progressed automatically + for (int i = 0; i < NITERS; i++) + { + env.comm.start_group(); + auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); + auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); + while (!rh.test() || !sh.test()) {} + EXPECT_TRUE(env.check_recv_buffer()); + env.fill_recv_buffer(); + } + EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); + EXPECT_EQ(sent, NITERS); + + shared_received[env.thread_id].store(0); + sent = 0; + // use wait() -> communicator is progressed automatically + for (int i = 0; i < NITERS; i++) + { + env.comm.start_group(); + auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); + auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); + sh.wait(); + rh.wait(); + EXPECT_TRUE(env.check_recv_buffer()); + env.fill_recv_buffer(); + } + EXPECT_EQ(shared_received[env.thread_id].load(), NITERS); + EXPECT_EQ(sent, NITERS); +} + +TEST_F(mpi_test_fixture, send_shared_recv_cb_disown) +{ + launch_test(test_send_shared_recv_cb_disown); +#if HWMALLOC_ENABLE_DEVICE + launch_test(test_send_shared_recv_cb_disown); +#endif +} + +// callback: pass by l-value reference, and resubmit +// ================================================= +template +void +test_send_recv_cb_resubmit(oomph::context& ctxt, std::size_t size, int tid, int num_threads, + bool user_alloc) +{ + if (ctxt.get_transport_option("name") == std::string("nccl")) { + // Skip for NCCL. Recursive comms hangs. TODO: Does it have to hang? + return; + } + + using rank_type = test_environment::rank_type; + using tag_type = test_environment::tag_type; + using message = test_environment::message; + + Env env(ctxt, size, tid, num_threads, user_alloc); + + volatile int received = 0; + volatile int sent = 0; + + struct recursive_send_callback + { + Env& env; + volatile int& sent; + + void operator()(message& msg, rank_type dst, tag_type tag) + { + ++sent; + if (sent < NITERS) env.comm.send(msg, dst, tag, recursive_send_callback{*this}); + } + }; + + struct recursive_recv_callback + { + Env& env; + volatile int& received; + + void operator()(message& msg, rank_type src, tag_type tag) + { + ++received; + EXPECT_TRUE(env.check_recv_buffer()); + env.fill_recv_buffer(); + if (received < NITERS) env.comm.recv(msg, src, tag, recursive_recv_callback{*this}); + } + }; + + env.comm.recv(env.rmsg, env.rpeer_rank, 1, recursive_recv_callback{env, received}); + env.comm.send(env.smsg, env.speer_rank, 1, recursive_send_callback{env, sent}); + + while (sent < NITERS || received < NITERS) { env.comm.progress(); }; +} + +TEST_F(mpi_test_fixture, send_recv_cb_resubmit) +{ + launch_test(test_send_recv_cb_resubmit); +#if HWMALLOC_ENABLE_DEVICE + launch_test(test_send_recv_cb_resubmit); +#endif +} + +// callback: pass by r-value reference (give up ownership), and resubmit +// ===================================================================== +template +void +test_send_recv_cb_resubmit_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads, + bool user_alloc) +{ + if (ctxt.get_transport_option("name") == std::string("nccl")) { + // Skip for NCCL. Recursive comms hangs. TODO: Does it have to hang? + return; + } + + using rank_type = test_environment::rank_type; + using tag_type = test_environment::tag_type; + using message = test_environment::message; + + Env env(ctxt, size, tid, num_threads, user_alloc); + + volatile int received = 0; + volatile int sent = 0; + + struct recursive_send_callback + { + Env& env; + volatile int& sent; + + void operator()(message msg, rank_type dst, tag_type tag) + { + ++sent; + if (sent < NITERS) + env.comm.send(std::move(msg), dst, tag, recursive_send_callback{*this}); + } + }; + + struct recursive_recv_callback + { + Env& env; + volatile int& received; + + void operator()(message msg, rank_type src, tag_type tag) + { + ++received; + env.rmsg = std::move(msg); + EXPECT_TRUE(env.check_recv_buffer()); + env.fill_recv_buffer(); + if (received < NITERS) + env.comm.recv(std::move(env.rmsg), src, tag, recursive_recv_callback{*this}); + } + }; + + env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recursive_recv_callback{env, received}); + env.comm.send(std::move(env.smsg), env.speer_rank, 1, recursive_send_callback{env, sent}); + + while (sent < NITERS || received < NITERS) { env.comm.progress(); }; +} + +TEST_F(mpi_test_fixture, send_recv_cb_resubmit_disown) +{ + launch_test(test_send_recv_cb_resubmit_disown); +#if HWMALLOC_ENABLE_DEVICE + launch_test(test_send_recv_cb_resubmit_disown); +#endif +} From f8c3258ebc41e22ffd031769e8bc9c4c0f879d41 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Mon, 22 Dec 2025 12:16:01 +0100 Subject: [PATCH 12/25] Clean up and disable some tests with NCCL --- src/nccl/request_state.hpp | 2 +- src/request.cpp | 6 +- test/test_barrier.cpp | 166 +++++++++++++++++++++++-------------- test/test_cancel.cpp | 76 +++++++++++------ test/test_context.cpp | 98 ++++++++++++---------- test/test_locality.cpp | 1 + test/test_send_recv.cpp | 65 ++++++++------- 7 files changed, 244 insertions(+), 170 deletions(-) diff --git a/src/nccl/request_state.hpp b/src/nccl/request_state.hpp index edbd383e..bba9ce1f 100644 --- a/src/nccl/request_state.hpp +++ b/src/nccl/request_state.hpp @@ -71,7 +71,7 @@ struct shared_request_state : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} , m_req{std::move(m)} { - std::cerr << "creating nccl shared_request_state\n"; + // std::cerr << "creating nccl shared_request_state\n"; } void progress(); diff --git a/src/request.cpp b/src/request.cpp index 9749979d..1f210769 100644 --- a/src/request.cpp +++ b/src/request.cpp @@ -24,7 +24,7 @@ namespace oomph bool send_request::is_ready() const noexcept { - std::cerr << "send_request::is_ready()\n"; + // std::cerr << "send_request::is_ready()\n"; if (!m) return true; return m->is_ready(); } @@ -86,7 +86,7 @@ recv_request::cancel() bool shared_recv_request::is_ready() const noexcept { - std::cerr << "shared_recv_request::is_ready()\n"; + // std::cerr << "shared_recv_request::is_ready()\n"; if (!m) return true; return m->is_ready(); } @@ -124,7 +124,7 @@ shared_recv_request::cancel() bool send_multi_request::is_ready() const noexcept { - std::cerr << "send_multi_request::is_ready()\n"; + // std::cerr << "send_multi_request::is_ready()\n"; if (!m) return true; return (m->m_counter == 0); } diff --git a/test/test_barrier.cpp b/test/test_barrier.cpp index 3016c091..bc2ac320 100644 --- a/test/test_barrier.cpp +++ b/test/test_barrier.cpp @@ -55,98 +55,138 @@ class test_barrier TEST_F(mpi_test_fixture, in_node1) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); - - oomph::test_barrier{b}.test_in_node1(ctxt); + try { + auto ctxt = context(MPI_COMM_WORLD, true); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); + + oomph::test_barrier{b}.test_in_node1(ctxt); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } TEST_F(mpi_test_fixture, in_barrier_1) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + try { + auto ctxt = context(MPI_COMM_WORLD, true); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); - auto comm = ctxt.get_communicator(); - auto comm2 = ctxt.get_communicator(); + auto comm = ctxt.get_communicator(); + auto comm2 = ctxt.get_communicator(); - for (int i = 0; i < 20; i++) { b.rank_barrier(); } + for (int i = 0; i < 20; i++) { b.rank_barrier(); } + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } TEST_F(mpi_test_fixture, in_barrier) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + try { + auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); - auto work = [&]() - { - auto comm = ctxt.get_communicator(); - auto comm2 = ctxt.get_communicator(); - for (int i = 0; i < 10; i++) + auto work = [&]() { - comm.progress(); - b.thread_barrier(); - } - }; + auto comm = ctxt.get_communicator(); + auto comm2 = ctxt.get_communicator(); + for (int i = 0; i < 10; i++) + { + comm.progress(); + b.thread_barrier(); + } + }; - std::vector ths; - for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work}); } - for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + std::vector ths; + for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work}); } + for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } TEST_F(mpi_test_fixture, full_barrier) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + try { + auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); - auto work = [&]() - { - auto comm = ctxt.get_communicator(); - auto comm3 = ctxt.get_communicator(); - for (int i = 0; i < 10; i++) { b(); } - }; + auto work = [&]() + { + auto comm = ctxt.get_communicator(); + auto comm3 = ctxt.get_communicator(); + for (int i = 0; i < 10; i++) { b(); } + }; - std::vector ths; - for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work}); } - for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + std::vector ths; + for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work}); } + for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } TEST_F(mpi_test_fixture, full_barrier_sendrecv) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + try { + auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); - auto work = [&](int tid) - { - auto comm = ctxt.get_communicator(); - auto comm2 = ctxt.get_communicator(); - int s_rank = (tid < 3) ? comm.rank() : ((comm.rank() + 1) % comm.size()); - int s_tag = comm.rank() * 10 + tid; - int r_rank = (tid > 0) ? comm.rank() : ((comm.rank() + comm.size() - 1) % comm.size()); - int r_tag = (tid > 0) ? (comm.rank() * 10 + tid - 1) : (r_rank * 10 + n_threads - 1); - - auto s_buffer = comm.make_buffer(1000); - auto r_buffer = comm.make_buffer(1000); - for (auto& x : s_buffer) x = s_tag; - auto r_req = comm.recv(r_buffer, r_rank, r_tag); - auto s_req = comm.send(s_buffer, s_rank, s_tag); - b(); - while (!(r_req.test() && s_req.test())) {}; - b(); - }; - - std::vector ths; - for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work, i}); } - for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + auto work = [&](int tid) + { + auto comm = ctxt.get_communicator(); + auto comm2 = ctxt.get_communicator(); + int s_rank = (tid < 3) ? comm.rank() : ((comm.rank() + 1) % comm.size()); + int s_tag = comm.rank() * 10 + tid; + int r_rank = (tid > 0) ? comm.rank() : ((comm.rank() + comm.size() - 1) % comm.size()); + int r_tag = (tid > 0) ? (comm.rank() * 10 + tid - 1) : (r_rank * 10 + n_threads - 1); + + auto s_buffer = comm.make_buffer(1000); + auto r_buffer = comm.make_buffer(1000); + for (auto& x : s_buffer) x = s_tag; + auto r_req = comm.recv(r_buffer, r_rank, r_tag); + auto s_req = comm.send(s_buffer, s_rank, s_tag); + b(); + while (!(r_req.test() && s_req.test())) {}; + b(); + }; + + std::vector ths; + for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work, i}); } + for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } diff --git a/test/test_cancel.cpp b/test/test_cancel.cpp index f00ed737..4c5b41e7 100644 --- a/test/test_cancel.cpp +++ b/test/test_cancel.cpp @@ -65,6 +65,10 @@ TEST_F(mpi_test_fixture, test_cancel_request) { using namespace oomph; auto ctxt = context(MPI_COMM_WORLD, false); + if (ctxt.get_transport_option("name") == std::string("nccl")) { + // NCCL does not support cancellation + return; + } auto comm = ctxt.get_communicator(); test_1(comm, 1); test_1(comm, 32); @@ -74,19 +78,27 @@ TEST_F(mpi_test_fixture, test_cancel_request) TEST_F(mpi_test_fixture, test_cancel_request_mt) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - - std::vector threads; - threads.reserve(n_threads); - for (size_t i = 0; i < n_threads; ++i) - threads.push_back(std::thread{[&ctxt, i]() { - auto comm = ctxt.get_communicator(); - test_1(comm, 1, i); - test_1(comm, 32, i); - test_1(comm, 4096, i); - }}); - for (auto& t : threads) t.join(); + try { + auto ctxt = context(MPI_COMM_WORLD, true); + std::size_t n_threads = 4; + + std::vector threads; + threads.reserve(n_threads); + for (size_t i = 0; i < n_threads; ++i) + threads.push_back(std::thread{[&ctxt, i]() { + auto comm = ctxt.get_communicator(); + test_1(comm, 1, i); + test_1(comm, 32, i); + test_1(comm, 4096, i); + }}); + for (auto& t : threads) t.join(); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } void @@ -145,6 +157,10 @@ TEST_F(mpi_test_fixture, test_cancel_cb) { using namespace oomph; auto ctxt = context(MPI_COMM_WORLD, false); + if (ctxt.get_transport_option("name") == std::string("nccl")) { + // NCCL does not support cancellation + return; + } auto comm = ctxt.get_communicator(); test_2(comm, 1); test_2(comm, 32); @@ -154,17 +170,25 @@ TEST_F(mpi_test_fixture, test_cancel_cb) TEST_F(mpi_test_fixture, test_cancel_cb_mt) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - - std::vector threads; - threads.reserve(n_threads); - for (size_t i = 0; i < n_threads; ++i) - threads.push_back(std::thread{[&ctxt, i]() { - auto comm = ctxt.get_communicator(); - test_2(comm, 1, i); - test_2(comm, 32, i); - test_2(comm, 4096, i); - }}); - for (auto& t : threads) t.join(); + try { + auto ctxt = context(MPI_COMM_WORLD, true); + std::size_t n_threads = 4; + + std::vector threads; + threads.reserve(n_threads); + for (size_t i = 0; i < n_threads; ++i) + threads.push_back(std::thread{[&ctxt, i]() { + auto comm = ctxt.get_communicator(); + test_2(comm, 1, i); + test_2(comm, 32, i); + test_2(comm, 4096, i); + }}); + for (auto& t : threads) t.join(); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } diff --git a/test/test_context.cpp b/test/test_context.cpp index 930c248a..6813b68d 100644 --- a/test/test_context.cpp +++ b/test/test_context.cpp @@ -20,57 +20,65 @@ const int num_threads = 4; TEST_F(mpi_test_fixture, context_ordered) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + try { + auto ctxt = context(MPI_COMM_WORLD, true); - //auto func = [&ctxt](int tid) - //{ - // auto comm = ctxt.get_communicator(); - // auto smsg_1 = comm.make_buffer(size); - // auto smsg_2 = comm.make_buffer(size); - // auto rmsg_1 = comm.make_buffer(size); - // auto rmsg_2 = comm.make_buffer(size); - // bool sent_1 = false; - // bool sent_2 = false; - // if (comm.rank() == 0) - // { - // const int payload_offset = 1 + tid; - // for (unsigned int i = 0; i < size; ++i) - // { - // smsg_1[i] = i + payload_offset; - // smsg_2[i] = i + payload_offset + 1; - // } - // std::vector neighs(comm.size()>1 ? comm.size() - 1 : 1, 0); - // for (int i = 1; i < comm.size(); ++i) neighs[i - 1] = i; + //auto func = [&ctxt](int tid) + //{ + // auto comm = ctxt.get_communicator(); + // auto smsg_1 = comm.make_buffer(size); + // auto smsg_2 = comm.make_buffer(size); + // auto rmsg_1 = comm.make_buffer(size); + // auto rmsg_2 = comm.make_buffer(size); + // bool sent_1 = false; + // bool sent_2 = false; + // if (comm.rank() == 0) + // { + // const int payload_offset = 1 + tid; + // for (unsigned int i = 0; i < size; ++i) + // { + // smsg_1[i] = i + payload_offset; + // smsg_2[i] = i + payload_offset + 1; + // } + // std::vector neighs(comm.size()>1 ? comm.size() - 1 : 1, 0); + // for (int i = 1; i < comm.size(); ++i) neighs[i - 1] = i; - // comm.send_multi(std::move(smsg_1), neighs, tid, - // [&sent_1](decltype(smsg_1), std::vector, tag_type) { sent_1 = true; }); + // comm.send_multi(std::move(smsg_1), neighs, tid, + // [&sent_1](decltype(smsg_1), std::vector, tag_type) { sent_1 = true; }); - // comm.send_multi(std::move(smsg_2), neighs, tid, - // [&sent_2](decltype(smsg_2), std::vector, tag_type) { sent_2 = true; }); + // comm.send_multi(std::move(smsg_2), neighs, tid, + // [&sent_2](decltype(smsg_2), std::vector, tag_type) { sent_2 = true; }); - // } - // if (comm.rank() > 0 || comm.size() == 1) - // { - // // ordered sends/recvs with same tag should arrive in order - // comm.recv(rmsg_1, 0, tid).wait(); - // comm.recv(rmsg_2, 0, tid).wait(); + // } + // if (comm.rank() > 0 || comm.size() == 1) + // { + // // ordered sends/recvs with same tag should arrive in order + // comm.recv(rmsg_1, 0, tid).wait(); + // comm.recv(rmsg_2, 0, tid).wait(); - // // check message - // const int payload_offset = 1 + tid; - // for (unsigned int i = 0; i < size; ++i) - // { - // EXPECT_EQ(rmsg_1[i], i + payload_offset); - // EXPECT_EQ(rmsg_2[i], i + payload_offset + 1); - // } - // } - // if (comm.rank() == 0) - // while (!sent_1 || !sent_2) { comm.progress(); } - //}; + // // check message + // const int payload_offset = 1 + tid; + // for (unsigned int i = 0; i < size; ++i) + // { + // EXPECT_EQ(rmsg_1[i], i + payload_offset); + // EXPECT_EQ(rmsg_2[i], i + payload_offset + 1); + // } + // } + // if (comm.rank() == 0) + // while (!sent_1 || !sent_2) { comm.progress(); } + //}; - //std::vector threads; - //threads.reserve(num_threads); - //for (int i = 0; i < num_threads; ++i) threads.push_back(std::thread{func, i}); - //for (auto& t : threads) t.join(); + //std::vector threads; + //threads.reserve(num_threads); + //for (int i = 0; i < num_threads; ++i) threads.push_back(std::thread{func, i}); + //for (auto& t : threads) t.join(); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } //TEST_F(mpi_test_fixture, context_multi) diff --git a/test/test_locality.cpp b/test/test_locality.cpp index 80e5e1ab..ce0ea126 100644 --- a/test/test_locality.cpp +++ b/test/test_locality.cpp @@ -42,6 +42,7 @@ TEST_F(mpi_test_fixture, locality_enumerate) gethostname(my_host_name.data(), HOST_NAME_MAX + 1); for (int r = 0; r < comm.size(); ++r) { + // TODO: Can this be made to work with NCCL? if (r == comm.rank()) { for (int rr = 0; rr < comm.size(); ++rr) diff --git a/test/test_send_recv.cpp b/test/test_send_recv.cpp index c4fb38bf..87e5b4ec 100644 --- a/test/test_send_recv.cpp +++ b/test/test_send_recv.cpp @@ -182,11 +182,11 @@ template void launch_test(Func f) { - std::cerr << "launch_test\n"; + // std::cerr << "launch_test\n"; // single threaded { - std::cerr << "single threaded\n"; + // std::cerr << "single threaded\n"; oomph::context ctxt(MPI_COMM_WORLD, false); reset_counters(); f(ctxt, SIZE, 0, 1, false); @@ -194,28 +194,28 @@ launch_test(Func f) f(ctxt, SIZE, 0, 1, true); } - // multi threaded - try { - std::cerr << "multi threaded\n"; - oomph::context ctxt(MPI_COMM_WORLD, true); - std::vector threads; - threads.reserve(NTHREADS); - reset_counters(); - for (int i = 0; i < NTHREADS; ++i) - threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, false}); - for (auto& t : threads) t.join(); - threads.clear(); - reset_counters(); - for (int i = 0; i < NTHREADS; ++i) - threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, true}); - for (auto& t : threads) t.join(); - } catch (std::runtime_error const& e) { - if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { - EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); - } else { - throw e; - } - } +// // multi threaded +// try { +// // std::cerr << "multi threaded\n"; +// oomph::context ctxt(MPI_COMM_WORLD, true); +// std::vector threads; +// threads.reserve(NTHREADS); +// reset_counters(); +// for (int i = 0; i < NTHREADS; ++i) +// threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, false}); +// for (auto& t : threads) t.join(); +// threads.clear(); +// reset_counters(); +// for (int i = 0; i < NTHREADS; ++i) +// threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, true}); +// for (auto& t : threads) t.join(); +// } catch (std::runtime_error const& e) { +// if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { +// EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); +// } else { +// throw e; +// } +// } } // no callback @@ -227,27 +227,27 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, Env env(ctxt, size, tid, num_threads, user_alloc); // use is_ready() -> must manually progress the communicator - std::cerr << "test_send_recv 1\n"; + // std::cerr << "test_send_recv 1\n"; for (int i = 0; i < NITERS; i++) { - std::cerr << "iteration " << i << "\n"; + // std::cerr << "iteration " << i << "\n"; env.comm.start_group(); auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); auto sreq = env.comm.send(env.smsg, env.speer_rank, env.tag); env.comm.end_group(); - std::cerr << "rreq.is_ready() = " << rreq.is_ready() << '\n'; - std::cerr << "sreq.is_ready() = " << sreq.is_ready() << '\n'; + // std::cerr << "rreq.is_ready() = " << rreq.is_ready() << '\n'; + // std::cerr << "sreq.is_ready() = " << sreq.is_ready() << '\n'; while (!(rreq.is_ready() && sreq.is_ready())) { - std::cerr << "calling env.comm.progress()\n"; + // std::cerr << "calling env.comm.progress()\n"; env.comm.progress(); }; EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); } - std::cerr << "test_send_recv 1 done\n"; + // std::cerr << "test_send_recv 1 done\n"; - std::cerr << "test_send_recv 2\n"; + // std::cerr << "test_send_recv 2\n"; // use test() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { @@ -260,7 +260,7 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, env.fill_recv_buffer(); } - std::cerr << "test_send_recv 3\n"; + // std::cerr << "test_send_recv 3\n"; // use wait() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { @@ -353,6 +353,7 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa TEST_F(mpi_test_fixture, send_recv_cb) { + // TODO: With aws-ofi-nccl, the second init segfaults. Why? launch_test(test_send_recv_cb); #if HWMALLOC_ENABLE_DEVICE launch_test(test_send_recv_cb); From 5908b18285699d5bf36dec4b6df38450bc9027f1 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Mon, 22 Dec 2025 12:18:46 +0100 Subject: [PATCH 13/25] Remove TODO --- include/oomph/channel/send_channel.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/oomph/channel/send_channel.hpp b/include/oomph/channel/send_channel.hpp index e60778f1..c6fb75d7 100644 --- a/include/oomph/channel/send_channel.hpp +++ b/include/oomph/channel/send_channel.hpp @@ -7,7 +7,6 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -// TODO: Needed for a completely backend implementation? Skip for NCCL? #pragma once #include From e7f6fbb50283938c113800de25852ac390570694 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Mon, 22 Dec 2025 12:32:33 +0100 Subject: [PATCH 14/25] Add missing cuda_event.hpp file --- src/nccl/cuda_event.hpp | 89 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 src/nccl/cuda_event.hpp diff --git a/src/nccl/cuda_event.hpp b/src/nccl/cuda_event.hpp new file mode 100644 index 00000000..83acbfd9 --- /dev/null +++ b/src/nccl/cuda_event.hpp @@ -0,0 +1,89 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +#include + +#include "cuda_error.hpp" + +namespace oomph::detail { +struct cuda_event { + cudaEvent_t m_event; + oomph::util::moved_bit m_moved; + bool m_recorded{false}; + + cuda_event() { + OOMPH_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming)); + // std::cerr << "created a cuda_event with value " << m_event << "\n"; + } + cuda_event(cuda_event&& other) noexcept = default; + cuda_event& operator=(cuda_event&& other) noexcept = default; + cuda_event(const cuda_event&) = delete; + cuda_event& operator=(const cuda_event&) = delete; + ~cuda_event() noexcept { + if (!m_moved) { + OOMPH_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event)); + } + } + + void record(cudaStream_t stream) { + assert(!m_moved); + OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(m_event, stream)); + m_recorded = true; + } + + bool is_ready() { + // std::cerr << "checking if request is ready\n"; + if (m_moved || !m_recorded) { + return false; + } + + cudaError_t res = cudaEventQuery(m_event); + // std::cerr << "request " << m_event << " is in state " << res << "\n"; + if (res == cudaSuccess) { + return true; + } else if (res == cudaErrorNotReady) { + return false; + } else { + OOMPH_CHECK_CUDA_RESULT(res); + return false; + } + } + + cudaEvent_t get() { + assert(!m_moved); + return m_event; + } +}; + +struct group_cuda_event { + std::shared_ptr m_event; + + group_cuda_event() : m_event(std::make_shared()) {} + group_cuda_event(const group_cuda_event&) = default; + group_cuda_event& operator=(const group_cuda_event&) = default; + group_cuda_event(group_cuda_event&&) = default; + group_cuda_event& operator=(group_cuda_event&&) = default; + + void record(cudaStream_t stream) { + m_event->record(stream); + } + + bool is_ready() { + return m_event->is_ready(); + } + + cudaEvent_t get() { + return m_event->get(); + } +}; +} From 8eb0cec6e0c5cb9609ae2d3f39f748f486eac556 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Mon, 22 Dec 2025 13:16:31 +0100 Subject: [PATCH 15/25] Update hwmalloc --- ext/hwmalloc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/hwmalloc b/ext/hwmalloc index 2078a51e..c3ddc35f 160000 --- a/ext/hwmalloc +++ b/ext/hwmalloc @@ -1 +1 @@ -Subproject commit 2078a51ef862ba22705f3c28f4d399d78980604b +Subproject commit c3ddc35f58ad6709388c209dfaec59b1ff40d472 From a6810db193d05a42578eea3f918086a5a8179072 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Mon, 22 Dec 2025 15:39:24 +0100 Subject: [PATCH 16/25] Minor cleanup --- src/nccl/communicator.hpp | 1 + src/nccl/context.cpp | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp index af68ffaa..3785f8bc 100644 --- a/src/nccl/communicator.hpp +++ b/src/nccl/communicator.hpp @@ -10,6 +10,7 @@ #pragma once #include +#include #include #include diff --git a/src/nccl/context.cpp b/src/nccl/context.cpp index ef36973c..657cd542 100644 --- a/src/nccl/context.cpp +++ b/src/nccl/context.cpp @@ -9,8 +9,8 @@ */ // paths relative to backend -#include -#include +#include "context.hpp" +#include "communicator.hpp" namespace oomph { From 8a854e47e6fac1c04eb13322f0758a4f7bdc41a7 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Mon, 22 Dec 2025 15:47:35 +0100 Subject: [PATCH 17/25] More cleanup --- src/nccl/context.hpp | 8 +++--- src/request.cpp | 5 ---- src/request_state_base.hpp | 6 +---- test/test_send_recv.cpp | 54 +++++++++++++++----------------------- 4 files changed, 26 insertions(+), 47 deletions(-) diff --git a/src/nccl/context.hpp b/src/nccl/context.hpp index 2cdacea8..4f78eded 100644 --- a/src/nccl/context.hpp +++ b/src/nccl/context.hpp @@ -15,10 +15,10 @@ #include // paths relative to backend -#include -#include <../context_base.hpp> -#include -#include "./request_queue.hpp" +#include "../context_base.hpp" +#include "nccl_communicator.hpp" +#include "region.hpp" +#include "request_queue.hpp" namespace oomph { diff --git a/src/request.cpp b/src/request.cpp index 1f210769..972650f3 100644 --- a/src/request.cpp +++ b/src/request.cpp @@ -24,7 +24,6 @@ namespace oomph bool send_request::is_ready() const noexcept { - // std::cerr << "send_request::is_ready()\n"; if (!m) return true; return m->is_ready(); } @@ -47,9 +46,7 @@ send_request::wait() bool recv_request::is_ready() const noexcept { - // std::cerr << "recv_request::is_ready()\n"; if (!m) return true; - // std::cerr << "recv_request::is_ready, checking impl m->is_ready()\n"; return m->is_ready(); } @@ -86,7 +83,6 @@ recv_request::cancel() bool shared_recv_request::is_ready() const noexcept { - // std::cerr << "shared_recv_request::is_ready()\n"; if (!m) return true; return m->is_ready(); } @@ -124,7 +120,6 @@ shared_recv_request::cancel() bool send_multi_request::is_ready() const noexcept { - // std::cerr << "send_multi_request::is_ready()\n"; if (!m) return true; return (m->m_counter == 0); } diff --git a/src/request_state_base.hpp b/src/request_state_base.hpp index 9110aa71..c0a6598a 100644 --- a/src/request_state_base.hpp +++ b/src/request_state_base.hpp @@ -88,16 +88,12 @@ struct request_state_base ++(*m_scheduled); } - bool is_ready() const noexcept { - // std::cerr << "request_state_base::is_ready()\n"; - return traits::load(m_ready); - } + bool is_ready() const noexcept { return traits::load(m_ready); } bool is_canceled() const noexcept { return traits::load(m_canceled); } void invoke_cb() { - // std::cerr << "invoke_cb, setting m_ready to true\n"; m_cb(m_rank, m_tag); --(*m_scheduled); traits::store(m_ready, true); diff --git a/test/test_send_recv.cpp b/test/test_send_recv.cpp index 87e5b4ec..97aca9f0 100644 --- a/test/test_send_recv.cpp +++ b/test/test_send_recv.cpp @@ -182,11 +182,8 @@ template void launch_test(Func f) { - // std::cerr << "launch_test\n"; - // single threaded { - // std::cerr << "single threaded\n"; oomph::context ctxt(MPI_COMM_WORLD, false); reset_counters(); f(ctxt, SIZE, 0, 1, false); @@ -194,28 +191,27 @@ launch_test(Func f) f(ctxt, SIZE, 0, 1, true); } -// // multi threaded -// try { -// // std::cerr << "multi threaded\n"; -// oomph::context ctxt(MPI_COMM_WORLD, true); -// std::vector threads; -// threads.reserve(NTHREADS); -// reset_counters(); -// for (int i = 0; i < NTHREADS; ++i) -// threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, false}); -// for (auto& t : threads) t.join(); -// threads.clear(); -// reset_counters(); -// for (int i = 0; i < NTHREADS; ++i) -// threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, true}); -// for (auto& t : threads) t.join(); -// } catch (std::runtime_error const& e) { -// if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { -// EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); -// } else { -// throw e; -// } -// } + // multi threaded + try { + oomph::context ctxt(MPI_COMM_WORLD, true); + std::vector threads; + threads.reserve(NTHREADS); + reset_counters(); + for (int i = 0; i < NTHREADS; ++i) + threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, false}); + for (auto& t : threads) t.join(); + threads.clear(); + reset_counters(); + for (int i = 0; i < NTHREADS; ++i) + threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, true}); + for (auto& t : threads) t.join(); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw e; + } + } } // no callback @@ -227,27 +223,20 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, Env env(ctxt, size, tid, num_threads, user_alloc); // use is_ready() -> must manually progress the communicator - // std::cerr << "test_send_recv 1\n"; for (int i = 0; i < NITERS; i++) { - // std::cerr << "iteration " << i << "\n"; env.comm.start_group(); auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); auto sreq = env.comm.send(env.smsg, env.speer_rank, env.tag); env.comm.end_group(); - // std::cerr << "rreq.is_ready() = " << rreq.is_ready() << '\n'; - // std::cerr << "sreq.is_ready() = " << sreq.is_ready() << '\n'; while (!(rreq.is_ready() && sreq.is_ready())) { - // std::cerr << "calling env.comm.progress()\n"; env.comm.progress(); }; EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); } - // std::cerr << "test_send_recv 1 done\n"; - // std::cerr << "test_send_recv 2\n"; // use test() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { @@ -260,7 +249,6 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, env.fill_recv_buffer(); } - // std::cerr << "test_send_recv 3\n"; // use wait() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { From 863750d6bc8d16ba462ca3446a22caeaa59e6c82 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Tue, 6 Jan 2026 11:01:28 +0100 Subject: [PATCH 18/25] Add missing stream argument --- include/oomph/communicator.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/oomph/communicator.hpp b/include/oomph/communicator.hpp index 182b43b7..0d73448c 100644 --- a/include/oomph/communicator.hpp +++ b/include/oomph/communicator.hpp @@ -229,7 +229,7 @@ class communicator std::vector const& neighs, std::vector const& tags, void* stream = nullptr) { assert(neighs.size() == tags.size()); - return send_multi(msg, neighs.data(), tags.data(), neighs.size()); + return send_multi(msg, neighs.data(), tags.data(), neighs.size(), stream); } // callback versions From ea4742b1ba048918985773de32014a192bc46585 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Tue, 6 Jan 2026 11:07:35 +0100 Subject: [PATCH 19/25] Add dummy stream parameter to libfabric and ucx backends --- src/libfabric/communicator.hpp | 7 ++++--- src/ucx/communicator.hpp | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/libfabric/communicator.hpp b/src/libfabric/communicator.hpp index 68bcbf7e..7dfa01b5 100644 --- a/src/libfabric/communicator.hpp +++ b/src/libfabric/communicator.hpp @@ -174,7 +174,7 @@ class communicator_impl : public communicator_base // -------------------------------------------------------------------- send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, oomph::tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, void*) { [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); std::uint64_t stag = make_tag64(tag, /*this->rank(), */ this->m_context->get_context_tag()); @@ -247,7 +247,7 @@ class communicator_impl : public communicator_base recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, oomph::tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, void*) { [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); std::uint64_t stag = make_tag64(tag, /*src, */ this->m_context->get_context_tag()); @@ -300,7 +300,8 @@ class communicator_impl : public communicator_base shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, oomph::tag_type tag, util::unique_function&& cb, - std::atomic* scheduled) + std::atomic* scheduled, + void*) { [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); std::uint64_t stag = make_tag64(tag, /*src, */ this->m_context->get_context_tag()); diff --git a/src/ucx/communicator.hpp b/src/ucx/communicator.hpp index f90943a4..1689618d 100644 --- a/src/ucx/communicator.hpp +++ b/src/ucx/communicator.hpp @@ -129,7 +129,7 @@ class communicator_impl : public communicator_base send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, void*) { const auto& ep = m_send_worker->connect(dst); const auto stag = @@ -191,7 +191,7 @@ class communicator_impl : public communicator_base recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, void*) { const auto rtag = (communicator::any_source == src) @@ -263,7 +263,7 @@ class communicator_impl : public communicator_base shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function&& cb, - std::atomic* scheduled) + std::atomic* scheduled, void*) { const auto rtag = (communicator::any_source == src) From 729460f6085721c4faa28e35efd51d0855eedab8 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Tue, 6 Jan 2026 11:12:22 +0100 Subject: [PATCH 20/25] Remove TODO from FindNCCL.cmake --- cmake/FindNCCL.cmake | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake index d5beae56..b5327377 100644 --- a/cmake/FindNCCL.cmake +++ b/cmake/FindNCCL.cmake @@ -1,5 +1,4 @@ -# This is from https://github.com/pytorch/gloo/blob/main/cmake/Modules/Findnccl.cmake. -# TODO: Check that license is compatible. +# From https://github.com/pytorch/gloo/blob/main/cmake/Modules/Findnccl.cmake. # Try to find NCCL # From 3757868e09a98c2c490250c91cd269692e8fd364 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Tue, 6 Jan 2026 11:19:51 +0100 Subject: [PATCH 21/25] Remove TODO from test_locality.cpp --- test/test_locality.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_locality.cpp b/test/test_locality.cpp index ce0ea126..80e5e1ab 100644 --- a/test/test_locality.cpp +++ b/test/test_locality.cpp @@ -42,7 +42,6 @@ TEST_F(mpi_test_fixture, locality_enumerate) gethostname(my_host_name.data(), HOST_NAME_MAX + 1); for (int r = 0; r < comm.size(); ++r) { - // TODO: Can this be made to work with NCCL? if (r == comm.rank()) { for (int rr = 0; rr < comm.size(); ++rr) From fb91491d70decc5a22ca43d3d18ee90fd049ddf5 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Wed, 7 Jan 2026 17:41:30 +0100 Subject: [PATCH 22/25] Add event pool and cached cuda event helper --- src/nccl/CMakeLists.txt | 2 +- src/nccl/cached_cuda_event.hpp | 51 +++++++++++++++++++++++++++ src/nccl/communicator.hpp | 16 +++------ src/nccl/cuda_event.hpp | 41 +++++++--------------- src/nccl/cuda_event_pool.cpp | 18 ++++++++++ src/nccl/cuda_event_pool.hpp | 63 ++++++++++++++++++++++++++++++++++ src/nccl/group_cuda_event.hpp | 43 +++++++++++++++++++++++ src/nccl/request.hpp | 3 +- 8 files changed, 195 insertions(+), 42 deletions(-) create mode 100644 src/nccl/cached_cuda_event.hpp create mode 100644 src/nccl/cuda_event_pool.cpp create mode 100644 src/nccl/cuda_event_pool.hpp create mode 100644 src/nccl/group_cuda_event.hpp diff --git a/src/nccl/CMakeLists.txt b/src/nccl/CMakeLists.txt index 9d006c15..51bc3105 100644 --- a/src/nccl/CMakeLists.txt +++ b/src/nccl/CMakeLists.txt @@ -6,4 +6,4 @@ target_link_libraries(oomph_nccl PRIVATE oomph_private_nccl_headers) list(TRANSFORM oomph_sources PREPEND ${CMAKE_CURRENT_SOURCE_DIR}/../ OUTPUT_VARIABLE oomph_sources_nccl) target_sources(oomph_nccl PRIVATE ${oomph_sources_nccl}) -target_sources(oomph_nccl PRIVATE context.cpp) +target_sources(oomph_nccl PRIVATE context.cpp cuda_event_pool.cpp) diff --git a/src/nccl/cached_cuda_event.hpp b/src/nccl/cached_cuda_event.hpp new file mode 100644 index 00000000..9a9a2720 --- /dev/null +++ b/src/nccl/cached_cuda_event.hpp @@ -0,0 +1,51 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include "cuda_event.hpp" +#include "cuda_event_pool.hpp" + +namespace oomph::detail { +// A cuda_event backed by a cuda_event_pool. +// +// Same semantics as cuda_event, but the event is retrieved from a static +// cuda_event_pool on construction and returned to the pool on destruction. +struct cached_cuda_event { + cuda_event m_event; + + cached_cuda_event() : m_event(get_cuda_event_pool().pop()) {} + cached_cuda_event(cached_cuda_event&& other) noexcept = default; + cached_cuda_event& operator=(cached_cuda_event&& other) noexcept = default; + cached_cuda_event(const cached_cuda_event&) = default; + cached_cuda_event& operator=(const cached_cuda_event&) = default; + ~cached_cuda_event() noexcept { + if (m_event) { + get_cuda_event_pool().push(std::move(m_event)); + } + } + + operator bool() noexcept { + return bool(m_event); + } + + void record(cudaStream_t stream) { + return m_event.record(stream); + } + + bool is_ready() const { + return m_event.is_ready(); + } + + cudaEvent_t get() { + return m_event.get(); + } +}; +} + diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp index 3785f8bc..4f39328b 100644 --- a/src/nccl/communicator.hpp +++ b/src/nccl/communicator.hpp @@ -21,6 +21,8 @@ #include "../communicator_base.hpp" #include "../device_guard.hpp" #include "./context.hpp" +#include "cached_cuda_event.hpp" +#include "group_cuda_event.hpp" #include "request.hpp" #include "request_queue.hpp" #include "request_state.hpp" @@ -77,9 +79,6 @@ class communicator_impl : public communicator_base OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); m_group_info.emplace(); - - // std::cerr << "started group\n"; - // std::cerr << "group_info: " << static_cast(m_group_info->m_event.get()) << "\n"; } void end_group() { @@ -96,20 +95,17 @@ class communicator_impl : public communicator_base nccl_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, [[maybe_unused]] tag_type tag, void* stream) { - // std::cerr << "nccl::send\n"; - const_device_guard dg(ptr); OOMPH_CHECK_NCCL_RESULT( ncclSend(dg.data(), size, ncclChar, dst, m_context->get_comm(), static_cast(stream))); if (m_group_info.has_value()) { m_group_info->m_last_stream = static_cast(stream); - // std::cerr << "using group event " << m_group_info->m_event.get() << "\n"; // The event is stored now, but recorded only in end_group. Until // an event has been recorded the event is never ready. return {m_group_info->m_event}; } else { - detail::cuda_event event; + detail::cached_cuda_event event; event.record(static_cast(stream)); return {std::move(event)}; } @@ -118,20 +114,17 @@ class communicator_impl : public communicator_base nccl_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, [[maybe_unused]] tag_type tag, void* stream) { - // std::cerr << "nccl::recv\n"; - device_guard dg(ptr); OOMPH_CHECK_NCCL_RESULT( ncclRecv(dg.data(), size, ncclChar, src, m_context->get_comm(), static_cast(stream))); if (m_group_info.has_value()) { m_group_info->m_last_stream = static_cast(stream); - // std::cerr << "using group event " << m_group_info->m_event.get() << "\n"; // The event is stored now, but recorded only in end_group. Until // an event has been recorded the event is never ready. return {m_group_info->m_event}; } else { - detail::cuda_event event; + detail::cached_cuda_event event; event.record(static_cast(stream)); return {std::move(event)}; } @@ -171,7 +164,6 @@ class communicator_impl : public communicator_base void progress() { - // std::cerr << "nccl communicator::progress\n"; // Communication progresses independently, but requests must be marked // ready and callbacks must be invoked. m_send_reqs.progress(); diff --git a/src/nccl/cuda_event.hpp b/src/nccl/cuda_event.hpp index 83acbfd9..58ae677a 100644 --- a/src/nccl/cuda_event.hpp +++ b/src/nccl/cuda_event.hpp @@ -9,6 +9,8 @@ */ #pragma once +#include + #include #include @@ -16,6 +18,10 @@ #include "cuda_error.hpp" namespace oomph::detail { +// RAII wrapper for a cudaEvent_t. +// +// Move-only wrapper around cudaEvent_t that automatically destroys the +// underlying event on destruction. Can be used to record events on streams. struct cuda_event { cudaEvent_t m_event; oomph::util::moved_bit m_moved; @@ -23,7 +29,6 @@ struct cuda_event { cuda_event() { OOMPH_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming)); - // std::cerr << "created a cuda_event with value " << m_event << "\n"; } cuda_event(cuda_event&& other) noexcept = default; cuda_event& operator=(cuda_event&& other) noexcept = default; @@ -35,27 +40,29 @@ struct cuda_event { } } + operator bool() noexcept { + return !m_moved; + } + void record(cudaStream_t stream) { assert(!m_moved); OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(m_event, stream)); m_recorded = true; } - bool is_ready() { - // std::cerr << "checking if request is ready\n"; + bool is_ready() const { if (m_moved || !m_recorded) { return false; } cudaError_t res = cudaEventQuery(m_event); - // std::cerr << "request " << m_event << " is in state " << res << "\n"; if (res == cudaSuccess) { return true; } else if (res == cudaErrorNotReady) { return false; } else { - OOMPH_CHECK_CUDA_RESULT(res); - return false; + OOMPH_CHECK_CUDA_RESULT(res); + return false; } } @@ -64,26 +71,4 @@ struct cuda_event { return m_event; } }; - -struct group_cuda_event { - std::shared_ptr m_event; - - group_cuda_event() : m_event(std::make_shared()) {} - group_cuda_event(const group_cuda_event&) = default; - group_cuda_event& operator=(const group_cuda_event&) = default; - group_cuda_event(group_cuda_event&&) = default; - group_cuda_event& operator=(group_cuda_event&&) = default; - - void record(cudaStream_t stream) { - m_event->record(stream); - } - - bool is_ready() { - return m_event->is_ready(); - } - - cudaEvent_t get() { - return m_event->get(); - } -}; } diff --git a/src/nccl/cuda_event_pool.cpp b/src/nccl/cuda_event_pool.cpp new file mode 100644 index 00000000..57607cdc --- /dev/null +++ b/src/nccl/cuda_event_pool.cpp @@ -0,0 +1,18 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include "cuda_event_pool.hpp" + +namespace oomph::detail { +cuda_event_pool& get_cuda_event_pool() { + static cuda_event_pool pool{128}; + return pool; +} +} diff --git a/src/nccl/cuda_event_pool.hpp b/src/nccl/cuda_event_pool.hpp new file mode 100644 index 00000000..56bbfa01 --- /dev/null +++ b/src/nccl/cuda_event_pool.hpp @@ -0,0 +1,63 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include + +#include + +#include + +#include "cuda_error.hpp" +#include "cuda_event.hpp" + +namespace oomph::detail { +// Pool of cuda_events. +// +// Simple wrapper over a vector of cuda_events. Events can be popped from the +// pool. New events are created if the pool is empty. Events can be returned to +// the pool for reuse. Events do not need to originate from the pool. Not +// thread-safe. +class cuda_event_pool +{ + private: + std::vector m_events; + + public: + cuda_event_pool(std::size_t expected_pool_size) + : m_events(expected_pool_size) + { + } + + cuda_event_pool(const cuda_event_pool&) = delete; + cuda_event_pool& operator=(const cuda_event_pool&) = delete; + cuda_event_pool(cuda_event_pool&& other) noexcept = delete; + cuda_event_pool& operator=(cuda_event_pool&&) noexcept = delete; + + public: + cuda_event pop() { + if (m_events.empty()) { + return {}; + } else { + auto event{std::move(m_events.back())}; + m_events.pop_back(); + return event; + } + } + + void push(cuda_event&& event) { m_events.push_back(std::move(event)); } + void clear() { m_events.clear(); } +}; + +// Get a static instance of a cuda_event_pool. +cuda_event_pool& get_cuda_event_pool(); +} diff --git a/src/nccl/group_cuda_event.hpp b/src/nccl/group_cuda_event.hpp new file mode 100644 index 00000000..f3cfd44c --- /dev/null +++ b/src/nccl/group_cuda_event.hpp @@ -0,0 +1,43 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +#include "cached_cuda_event.hpp" + +namespace oomph::detail { +// A shared cuda_event suitable for use with NCCL groups. +// +// A cached_cuda_event stored in a shared_ptr for shared usage between multiple +// requests. +struct group_cuda_event { + std::shared_ptr m_event; + + group_cuda_event() : m_event(std::make_shared()) {} + group_cuda_event(const group_cuda_event&) = default; + group_cuda_event& operator=(const group_cuda_event&) = default; + group_cuda_event(group_cuda_event&&) = default; + group_cuda_event& operator=(group_cuda_event&&) = default; + + void record(cudaStream_t stream) { + m_event->record(stream); + } + + bool is_ready() { + return m_event->is_ready(); + } + + cudaEvent_t get() { + return m_event->get(); + } +}; +} + diff --git a/src/nccl/request.hpp b/src/nccl/request.hpp index 4e9e1884..16223391 100644 --- a/src/nccl/request.hpp +++ b/src/nccl/request.hpp @@ -15,6 +15,7 @@ #include "cuda_error.hpp" #include "cuda_event.hpp" +#include "group_cuda_event.hpp" namespace oomph { @@ -28,6 +29,6 @@ struct nccl_request // We store either a single event for a particular request, or a shared // event that signals the end of a NCCL group. - std::variant m_event; + std::variant m_event; }; } // namespace oomph From 4c295b0eb982425838bb0f47957b117d41c10b35 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Thu, 8 Jan 2026 12:41:05 +0100 Subject: [PATCH 23/25] Remove duplicate key in clang-format config --- .clang-format | 1 - 1 file changed, 1 deletion(-) diff --git a/.clang-format b/.clang-format index e941415f..71fc4868 100644 --- a/.clang-format +++ b/.clang-format @@ -27,7 +27,6 @@ BreakBeforeBraces: Allman # ConstructorInitializerAllOnOneLineOrOnePerLine: false BreakConstructorInitializers: BeforeComma ConstructorInitializerIndentWidth: 0 -BreakInheritanceList: BeforeComma #AllowShortBlocksOnASingleLine: Always AllowShortBlocksOnASingleLine: true AllowShortCaseLabelsOnASingleLine: false From fe3904dc5975502502a69626699504c6d30c1568 Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Thu, 8 Jan 2026 12:41:40 +0100 Subject: [PATCH 24/25] Format nccl files --- src/nccl/cached_cuda_event.hpp | 59 ++++++++---------- src/nccl/communicator.hpp | 109 ++++++++++++++++++--------------- src/nccl/context.cpp | 12 ++-- src/nccl/context.hpp | 6 +- src/nccl/cuda_error.hpp | 8 ++- src/nccl/cuda_event.hpp | 86 +++++++++++++------------- src/nccl/cuda_event_pool.cpp | 13 ++-- src/nccl/cuda_event_pool.hpp | 20 +++--- src/nccl/group_cuda_event.hpp | 36 +++++------ src/nccl/handle.hpp | 1 - src/nccl/nccl_error.hpp | 10 ++- src/nccl/region.hpp | 1 - src/nccl/request.hpp | 7 +-- src/nccl/request_queue.hpp | 29 ++++----- src/nccl/request_state.hpp | 2 +- 15 files changed, 202 insertions(+), 197 deletions(-) diff --git a/src/nccl/cached_cuda_event.hpp b/src/nccl/cached_cuda_event.hpp index 9a9a2720..a62c3f3c 100644 --- a/src/nccl/cached_cuda_event.hpp +++ b/src/nccl/cached_cuda_event.hpp @@ -12,40 +12,35 @@ #include "cuda_event.hpp" #include "cuda_event_pool.hpp" -namespace oomph::detail { +namespace oomph::detail +{ // A cuda_event backed by a cuda_event_pool. // // Same semantics as cuda_event, but the event is retrieved from a static // cuda_event_pool on construction and returned to the pool on destruction. -struct cached_cuda_event { - cuda_event m_event; - - cached_cuda_event() : m_event(get_cuda_event_pool().pop()) {} - cached_cuda_event(cached_cuda_event&& other) noexcept = default; - cached_cuda_event& operator=(cached_cuda_event&& other) noexcept = default; - cached_cuda_event(const cached_cuda_event&) = default; - cached_cuda_event& operator=(const cached_cuda_event&) = default; - ~cached_cuda_event() noexcept { - if (m_event) { - get_cuda_event_pool().push(std::move(m_event)); - } - } - - operator bool() noexcept { - return bool(m_event); - } - - void record(cudaStream_t stream) { - return m_event.record(stream); - } - - bool is_ready() const { - return m_event.is_ready(); - } - - cudaEvent_t get() { - return m_event.get(); - } +struct cached_cuda_event +{ + cuda_event m_event; + + cached_cuda_event() + : m_event(get_cuda_event_pool().pop()) + { + } + cached_cuda_event(cached_cuda_event&& other) noexcept = default; + cached_cuda_event& operator=(cached_cuda_event&& other) noexcept = default; + cached_cuda_event(const cached_cuda_event&) = default; + cached_cuda_event& operator=(const cached_cuda_event&) = default; + ~cached_cuda_event() noexcept + { + if (m_event) { get_cuda_event_pool().push(std::move(m_event)); } + } + + operator bool() noexcept { return bool(m_event); } + + void record(cudaStream_t stream) { return m_event.record(stream); } + + bool is_ready() const { return m_event.is_ready(); } + + cudaEvent_t get() { return m_event.get(); } }; -} - +} // namespace oomph::detail diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp index 4f39328b..089ad8d1 100644 --- a/src/nccl/communicator.hpp +++ b/src/nccl/communicator.hpp @@ -37,25 +37,26 @@ class communicator_impl : public communicator_base request_queue m_recv_reqs; private: - struct group_info { - // A shared CUDA event used for synchronization at the end of the NCCL - // group. All streams used within the group are waited for before the - // group kernel starts and all streams can be used to wait for the - // completion of the group kernel. From - // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/streams.html: - // - // NCCL allows for using multiple streams within a group call. This will - // enforce a stream dependency of all streams before the NCCL kernel - // starts and block all streams until the NCCL kernel completes. - // - // It will behave as if the NCCL group operation was posted on every - // stream, but given it is a single operation, it will cause a global - // synchronization point between the streams. - detail::group_cuda_event m_event{}; - - // We arbitrarily use the last stream used within a group to synchronize - // the whole group. - cudaStream_t m_last_stream{}; + struct group_info + { + // A shared CUDA event used for synchronization at the end of the NCCL + // group. All streams used within the group are waited for before the + // group kernel starts and all streams can be used to wait for the + // completion of the group kernel. From + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/streams.html: + // + // NCCL allows for using multiple streams within a group call. This will + // enforce a stream dependency of all streams before the NCCL kernel + // starts and block all streams until the NCCL kernel completes. + // + // It will behave as if the NCCL group operation was posted on every + // stream, but given it is a single operation, it will cause a global + // synchronization point between the streams. + detail::group_cuda_event m_event{}; + + // We arbitrarily use the last stream used within a group to synchronize + // the whole group. + cudaStream_t m_last_stream{}; }; // NCCL group information. When no group is active this is std::nullopt. @@ -74,37 +75,42 @@ class communicator_impl : public communicator_base bool is_stream_aware() const noexcept { return true; } - void start_group() { - assert(!m_group_info.has_value()); + void start_group() + { + assert(!m_group_info.has_value()); - OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); - m_group_info.emplace(); + OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); + m_group_info.emplace(); } - void end_group() { - assert(m_group_info.has_value()); + void end_group() + { + assert(m_group_info.has_value()); - OOMPH_CHECK_NCCL_RESULT(ncclGroupEnd()); + OOMPH_CHECK_NCCL_RESULT(ncclGroupEnd()); - // All streams used in a NCCL group synchronize with the end of the group. - // We arbitrarily pick the last stream to synchronize against. - m_group_info->m_event.record(m_group_info->m_last_stream); - m_group_info.reset(); + // All streams used in a NCCL group synchronize with the end of the group. + // We arbitrarily pick the last stream to synchronize against. + m_group_info->m_event.record(m_group_info->m_last_stream); + m_group_info.reset(); } nccl_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, [[maybe_unused]] tag_type tag, void* stream) { const_device_guard dg(ptr); - OOMPH_CHECK_NCCL_RESULT( - ncclSend(dg.data(), size, ncclChar, dst, m_context->get_comm(), static_cast(stream))); + OOMPH_CHECK_NCCL_RESULT(ncclSend(dg.data(), size, ncclChar, dst, m_context->get_comm(), + static_cast(stream))); - if (m_group_info.has_value()) { + if (m_group_info.has_value()) + { m_group_info->m_last_stream = static_cast(stream); - // The event is stored now, but recorded only in end_group. Until - // an event has been recorded the event is never ready. + // The event is stored now, but recorded only in end_group. Until + // an event has been recorded the event is never ready. return {m_group_info->m_event}; - } else { + } + else + { detail::cached_cuda_event event; event.record(static_cast(stream)); return {std::move(event)}; @@ -115,15 +121,18 @@ class communicator_impl : public communicator_base [[maybe_unused]] tag_type tag, void* stream) { device_guard dg(ptr); - OOMPH_CHECK_NCCL_RESULT( - ncclRecv(dg.data(), size, ncclChar, src, m_context->get_comm(), static_cast(stream))); + OOMPH_CHECK_NCCL_RESULT(ncclRecv(dg.data(), size, ncclChar, src, m_context->get_comm(), + static_cast(stream))); - if (m_group_info.has_value()) { + if (m_group_info.has_value()) + { m_group_info->m_last_stream = static_cast(stream); - // The event is stored now, but recorded only in end_group. Until - // an event has been recorded the event is never ready. + // The event is stored now, but recorded only in end_group. Until + // an event has been recorded the event is never ready. return {m_group_info->m_event}; - } else { + } + else + { detail::cached_cuda_event event; event.record(static_cast(stream)); return {std::move(event)}; @@ -131,20 +140,24 @@ class communicator_impl : public communicator_base } send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag, util::unique_function&& cb, std::size_t* scheduled, void* stream) + tag_type tag, util::unique_function&& cb, std::size_t* scheduled, + void* stream) { auto req = send(ptr, size, dst, tag, stream); - auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb), std::move(req)); + auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb), + std::move(req)); s->create_self_ref(); m_send_reqs.enqueue(s.get()); return {std::move(s)}; } recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb, std::size_t* scheduled, void* stream) + tag_type tag, util::unique_function&& cb, std::size_t* scheduled, + void* stream) { auto req = recv(ptr, size, src, tag, stream); - auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb), std::move(req)); + auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb), + std::move(req)); s->create_self_ref(); m_recv_reqs.enqueue(s.get()); return {std::move(s)}; @@ -164,8 +177,8 @@ class communicator_impl : public communicator_base void progress() { - // Communication progresses independently, but requests must be marked - // ready and callbacks must be invoked. + // Communication progresses independently, but requests must be marked + // ready and callbacks must be invoked. m_send_reqs.progress(); m_recv_reqs.progress(); m_context->progress(); diff --git a/src/nccl/context.cpp b/src/nccl/context.cpp index 657cd542..9d0e0477 100644 --- a/src/nccl/context.cpp +++ b/src/nccl/context.cpp @@ -22,13 +22,11 @@ context_impl::get_communicator() return comm; } -const char *context_impl::get_transport_option(const std::string &opt) { - if (opt == "name") { - return "nccl"; - } - else { - return "unspecified"; - } +const char* +context_impl::get_transport_option(const std::string& opt) +{ + if (opt == "name") { return "nccl"; } + else { return "unspecified"; } } } // namespace oomph diff --git a/src/nccl/context.hpp b/src/nccl/context.hpp index 4f78eded..e87b44e2 100644 --- a/src/nccl/context.hpp +++ b/src/nccl/context.hpp @@ -30,7 +30,7 @@ class context_impl : public context_base using heap_type = hwmalloc::heap; private: - heap_type m_heap; + heap_type m_heap; detail::nccl_comm m_comm; public: @@ -42,9 +42,7 @@ class context_impl : public context_base , m_heap{this, heap_config} , m_comm{oomph::detail::nccl_comm{comm}} { - if (thread_safe) { - throw std::runtime_error("NCCL not supported with thread_safe = true"); - } + if (thread_safe) { throw std::runtime_error("NCCL not supported with thread_safe = true"); } } context_impl(context_impl const&) = delete; diff --git a/src/nccl/cuda_error.hpp b/src/nccl/cuda_error.hpp index 0a785b3e..baf9a17b 100644 --- a/src/nccl/cuda_error.hpp +++ b/src/nccl/cuda_error.hpp @@ -23,8 +23,12 @@ std::string(__FILE__) + ":" + std::to_string(__LINE__)); #define OOMPH_CHECK_CUDA_RESULT_NO_THROW(x) \ - try { OOMPH_CHECK_CUDA_RESULT(x) } \ - catch (const std::exception& e) { \ + try \ + { \ + OOMPH_CHECK_CUDA_RESULT(x) \ + } \ + catch (const std::exception& e) \ + { \ std::cerr << e.what() << std::endl; \ std::terminate(); \ } diff --git a/src/nccl/cuda_event.hpp b/src/nccl/cuda_event.hpp index 58ae677a..01762bea 100644 --- a/src/nccl/cuda_event.hpp +++ b/src/nccl/cuda_event.hpp @@ -17,58 +17,58 @@ #include "cuda_error.hpp" -namespace oomph::detail { +namespace oomph::detail +{ // RAII wrapper for a cudaEvent_t. // // Move-only wrapper around cudaEvent_t that automatically destroys the // underlying event on destruction. Can be used to record events on streams. -struct cuda_event { - cudaEvent_t m_event; - oomph::util::moved_bit m_moved; - bool m_recorded{false}; +struct cuda_event +{ + cudaEvent_t m_event; + oomph::util::moved_bit m_moved; + bool m_recorded{false}; - cuda_event() { - OOMPH_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming)); - } - cuda_event(cuda_event&& other) noexcept = default; - cuda_event& operator=(cuda_event&& other) noexcept = default; - cuda_event(const cuda_event&) = delete; - cuda_event& operator=(const cuda_event&) = delete; - ~cuda_event() noexcept { - if (!m_moved) { - OOMPH_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event)); - } - } - - operator bool() noexcept { - return !m_moved; - } + cuda_event() + { + OOMPH_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming)); + } + cuda_event(cuda_event&& other) noexcept = default; + cuda_event& operator=(cuda_event&& other) noexcept = default; + cuda_event(const cuda_event&) = delete; + cuda_event& operator=(const cuda_event&) = delete; + ~cuda_event() noexcept + { + if (!m_moved) { OOMPH_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event)); } + } - void record(cudaStream_t stream) { - assert(!m_moved); - OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(m_event, stream)); - m_recorded = true; - } + operator bool() noexcept { return !m_moved; } - bool is_ready() const { - if (m_moved || !m_recorded) { - return false; + void record(cudaStream_t stream) + { + assert(!m_moved); + OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(m_event, stream)); + m_recorded = true; } - cudaError_t res = cudaEventQuery(m_event); - if (res == cudaSuccess) { - return true; - } else if (res == cudaErrorNotReady) { - return false; - } else { - OOMPH_CHECK_CUDA_RESULT(res); - return false; + bool is_ready() const + { + if (m_moved || !m_recorded) { return false; } + + cudaError_t res = cudaEventQuery(m_event); + if (res == cudaSuccess) { return true; } + else if (res == cudaErrorNotReady) { return false; } + else + { + OOMPH_CHECK_CUDA_RESULT(res); + return false; + } } - } - cudaEvent_t get() { - assert(!m_moved); - return m_event; - } + cudaEvent_t get() + { + assert(!m_moved); + return m_event; + } }; -} +} // namespace oomph::detail diff --git a/src/nccl/cuda_event_pool.cpp b/src/nccl/cuda_event_pool.cpp index 57607cdc..4b180dcb 100644 --- a/src/nccl/cuda_event_pool.cpp +++ b/src/nccl/cuda_event_pool.cpp @@ -10,9 +10,12 @@ #include "cuda_event_pool.hpp" -namespace oomph::detail { -cuda_event_pool& get_cuda_event_pool() { - static cuda_event_pool pool{128}; - return pool; -} +namespace oomph::detail +{ +cuda_event_pool& +get_cuda_event_pool() +{ + static cuda_event_pool pool{128}; + return pool; } +} // namespace oomph::detail diff --git a/src/nccl/cuda_event_pool.hpp b/src/nccl/cuda_event_pool.hpp index 56bbfa01..d669b83b 100644 --- a/src/nccl/cuda_event_pool.hpp +++ b/src/nccl/cuda_event_pool.hpp @@ -20,7 +20,8 @@ #include "cuda_error.hpp" #include "cuda_event.hpp" -namespace oomph::detail { +namespace oomph::detail +{ // Pool of cuda_events. // // Simple wrapper over a vector of cuda_events. Events can be popped from the @@ -44,13 +45,14 @@ class cuda_event_pool cuda_event_pool& operator=(cuda_event_pool&&) noexcept = delete; public: - cuda_event pop() { - if (m_events.empty()) { - return {}; - } else { - auto event{std::move(m_events.back())}; - m_events.pop_back(); - return event; + cuda_event pop() + { + if (m_events.empty()) { return {}; } + else + { + auto event{std::move(m_events.back())}; + m_events.pop_back(); + return event; } } @@ -60,4 +62,4 @@ class cuda_event_pool // Get a static instance of a cuda_event_pool. cuda_event_pool& get_cuda_event_pool(); -} +} // namespace oomph::detail diff --git a/src/nccl/group_cuda_event.hpp b/src/nccl/group_cuda_event.hpp index f3cfd44c..eb60e80b 100644 --- a/src/nccl/group_cuda_event.hpp +++ b/src/nccl/group_cuda_event.hpp @@ -13,31 +13,29 @@ #include "cached_cuda_event.hpp" -namespace oomph::detail { +namespace oomph::detail +{ // A shared cuda_event suitable for use with NCCL groups. // // A cached_cuda_event stored in a shared_ptr for shared usage between multiple // requests. -struct group_cuda_event { - std::shared_ptr m_event; +struct group_cuda_event +{ + std::shared_ptr m_event; - group_cuda_event() : m_event(std::make_shared()) {} - group_cuda_event(const group_cuda_event&) = default; - group_cuda_event& operator=(const group_cuda_event&) = default; - group_cuda_event(group_cuda_event&&) = default; - group_cuda_event& operator=(group_cuda_event&&) = default; + group_cuda_event() + : m_event(std::make_shared()) + { + } + group_cuda_event(const group_cuda_event&) = default; + group_cuda_event& operator=(const group_cuda_event&) = default; + group_cuda_event(group_cuda_event&&) = default; + group_cuda_event& operator=(group_cuda_event&&) = default; - void record(cudaStream_t stream) { - m_event->record(stream); - } + void record(cudaStream_t stream) { m_event->record(stream); } - bool is_ready() { - return m_event->is_ready(); - } + bool is_ready() { return m_event->is_ready(); } - cudaEvent_t get() { - return m_event->get(); - } + cudaEvent_t get() { return m_event->get(); } }; -} - +} // namespace oomph::detail diff --git a/src/nccl/handle.hpp b/src/nccl/handle.hpp index 16eb0651..9527592e 100644 --- a/src/nccl/handle.hpp +++ b/src/nccl/handle.hpp @@ -19,4 +19,3 @@ struct handle std::size_t m_size; }; } // namespace oomph - diff --git a/src/nccl/nccl_error.hpp b/src/nccl/nccl_error.hpp index 44423e92..ca4cbe3b 100644 --- a/src/nccl/nccl_error.hpp +++ b/src/nccl/nccl_error.hpp @@ -22,11 +22,15 @@ throw std::runtime_error("OOMPH Error: NCCL Call failed " + std::string(#x) + " = " + \ std::to_string(r) + " (\"" + ncclGetErrorString(r) + \ "\") in " + std::string(__FILE__) + ":" + \ - std::to_string(__LINE__)); \ + std::to_string(__LINE__)); \ } #define OOMPH_CHECK_NCCL_RESULT_NO_THROW(x) \ - try { OOMPH_CHECK_NCCL_RESULT(x) } \ - catch (const std::exception& e) { \ + try \ + { \ + OOMPH_CHECK_NCCL_RESULT(x) \ + } \ + catch (const std::exception& e) \ + { \ std::cerr << e.what() << std::endl; \ std::terminate(); \ } diff --git a/src/nccl/region.hpp b/src/nccl/region.hpp index 5bb7b2ba..71a84f87 100644 --- a/src/nccl/region.hpp +++ b/src/nccl/region.hpp @@ -42,4 +42,3 @@ class region } }; } // namespace oomph - diff --git a/src/nccl/request.hpp b/src/nccl/request.hpp index 16223391..e6c24c7d 100644 --- a/src/nccl/request.hpp +++ b/src/nccl/request.hpp @@ -21,10 +21,9 @@ namespace oomph { struct nccl_request { - bool is_ready() { - return std::visit([](auto& event) { - return event.is_ready(); - }, m_event); + bool is_ready() + { + return std::visit([](auto& event) { return event.is_ready(); }, m_event); } // We store either a single event for a particular request, or a shared diff --git a/src/nccl/request_queue.hpp b/src/nccl/request_queue.hpp index e0806392..8a491daf 100644 --- a/src/nccl/request_queue.hpp +++ b/src/nccl/request_queue.hpp @@ -26,14 +26,11 @@ class request_queue using queue_type = std::vector; private: // members - queue_type m_queue; - bool in_progress = false; + queue_type m_queue; + bool in_progress = false; public: // ctors - request_queue() - { - m_queue.reserve(256); - } + request_queue() { m_queue.reserve(256); } public: // member functions std::size_t size() const noexcept { return m_queue.size(); } @@ -58,20 +55,19 @@ class request_queue return 0; } - auto erase_begin = std::remove_if( - m_queue.begin(), m_queue.end(), - [](auto& req) { + auto erase_begin = std::remove_if(m_queue.begin(), m_queue.end(), + [](auto& req) + { // std::cerr << "checking if request ready with event " << req->m_req.m_event << "\n"; - if (req->m_req.is_ready()) { + if (req->m_req.is_ready()) + { auto ptr = req->release_self_ref(); // std::cerr << "invoking callback on req: " << req << "\n"; req->invoke_cb(); return true; - } else { - return false; } - } - ); + else { return false; } + }); auto completed = std::distance(erase_begin, m_queue.end()); // if (completed != 0) { // std::cerr << "completed " << completed << " requests\n"; @@ -132,10 +128,7 @@ class shared_request_queue found = 1; break; } - else - { - m_local_queue.push_back(e); - } + else { m_local_queue.push_back(e); } } for (auto x : m_local_queue) m_queue.push(x); diff --git a/src/nccl/request_state.hpp b/src/nccl/request_state.hpp index bba9ce1f..92b3358c 100644 --- a/src/nccl/request_state.hpp +++ b/src/nccl/request_state.hpp @@ -92,4 +92,4 @@ struct shared_request_state return std::move(m_self_ptr); } }; -} // namespace oomph +} // namespace oomph::detail From 4ea3bef0f2880f5a0d911d51df6858855319592c Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Thu, 8 Jan 2026 12:49:55 +0100 Subject: [PATCH 25/25] Remove debug prints --- src/nccl/request_queue.hpp | 10 ---------- src/nccl/request_state.hpp | 2 -- 2 files changed, 12 deletions(-) diff --git a/src/nccl/request_queue.hpp b/src/nccl/request_queue.hpp index 8a491daf..669f4be2 100644 --- a/src/nccl/request_queue.hpp +++ b/src/nccl/request_queue.hpp @@ -43,8 +43,6 @@ class request_queue int progress() { - // std::cerr << "nccl request_queue::progress\n"; - if (in_progress) return 0; in_progress = true; @@ -58,20 +56,15 @@ class request_queue auto erase_begin = std::remove_if(m_queue.begin(), m_queue.end(), [](auto& req) { - // std::cerr << "checking if request ready with event " << req->m_req.m_event << "\n"; if (req->m_req.is_ready()) { auto ptr = req->release_self_ref(); - // std::cerr << "invoking callback on req: " << req << "\n"; req->invoke_cb(); return true; } else { return false; } }); auto completed = std::distance(erase_begin, m_queue.end()); - // if (completed != 0) { - // std::cerr << "completed " << completed << " requests\n"; - // } m_queue.erase(erase_begin, m_queue.end()); in_progress = false; @@ -110,8 +103,6 @@ class shared_request_queue int progress() { - // std::cerr << "nccl shared_request_queue::progress\n"; - static thread_local bool in_progress = false; static thread_local std::vector m_local_queue; int found = 0; @@ -124,7 +115,6 @@ class shared_request_queue { if (e->m_req.is_ready()) { - // std::cerr << "found ready request in shared queue\n"; found = 1; break; } diff --git a/src/nccl/request_state.hpp b/src/nccl/request_state.hpp index 92b3358c..e0ac23a6 100644 --- a/src/nccl/request_state.hpp +++ b/src/nccl/request_state.hpp @@ -33,7 +33,6 @@ struct request_state : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} , m_req{std::move(m)} { - // std::cerr << "creating nccl request_state\n"; } void progress(); @@ -71,7 +70,6 @@ struct shared_request_state : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} , m_req{std::move(m)} { - // std::cerr << "creating nccl shared_request_state\n"; } void progress();