Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ set(GHEX_ENABLE_ATLAS_BINDINGS OFF CACHE BOOL "Set to true to build with Atlas b
set(GHEX_BUILD_FORTRAN OFF CACHE BOOL "True if FORTRAN bindings shall be built")
set(GHEX_BUILD_PYTHON_BINDINGS OFF CACHE BOOL "Set to true to build Python bindings")
set(GHEX_WITH_TESTING OFF CACHE BOOL "True if tests shall be built")
# TODO: Add FindNCCL.cmake module.
set(GHEX_USE_NCCL ON CACHE BOOL "Use NCCL")

# ---------------------------------------------------------------------
# Common includes
Expand Down
1 change: 1 addition & 0 deletions cmake/config.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#cmakedefine GHEX_USE_XPMEM
#cmakedefine GHEX_USE_XPMEM_ACCESS_GUARD
#cmakedefine GHEX_USE_GPU
#cmakedefine GHEX_USE_NCCL
#define GHEX_GPU_MODE @ghex_gpu_mode@
#cmakedefine GHEX_GPU_MODE_EMULATE
#define @GHEX_DEVICE@
Expand Down
9 changes: 9 additions & 0 deletions cmake/ghex_external_dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ if (GHEX_USE_XPMEM)
find_package(XPMEM REQUIRED)
endif()


# ---------------------------------------------------------------------
# nccl setup
# ---------------------------------------------------------------------
if(GHEX_USE_NCCL)
    # TODO: replace this with a proper FindNCCL.cmake module.
    # Locate the NCCL library explicitly instead of injecting a bare
    # "-lnccl" linker flag, so configuration fails early with a clear
    # message when NCCL is not available (previously the failure only
    # surfaced at link time).
    find_library(NCCL_LIBRARY nccl)
    if(NOT NCCL_LIBRARY)
        message(FATAL_ERROR "GHEX_USE_NCCL is ON but the NCCL library (libnccl) was not found. Set LIBRARY_PATH/CMAKE_PREFIX_PATH or disable GHEX_USE_NCCL.")
    endif()
    link_libraries(${NCCL_LIBRARY})
    # Headers: add the include directory if nccl.h is not on a default path.
    find_path(NCCL_INCLUDE_DIR nccl.h)
    if(NCCL_INCLUDE_DIR)
        include_directories(${NCCL_INCLUDE_DIR})
    endif()
endif()
Comment on lines +98 to +104
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a hack. Add proper FindNCCL.cmake module.

This can be tested with the ICON uenv by manually setting `export LIBRARY_PATH=/user-environment/env/default/lib64:/user-environment/env/default/lib`.


# ---------------------------------------------------------------------
# parmetis setup
# ---------------------------------------------------------------------
Expand Down
159 changes: 158 additions & 1 deletion include/ghex/communication_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ghex/config.hpp>
#include <ghex/context.hpp>
#include <ghex/util/for_each.hpp>
#include <ghex/util/moved_bit.hpp>
#include <ghex/util/test_eq.hpp>
#include <ghex/pattern_container.hpp>
#include <ghex/device/stream.hpp>
Expand All @@ -24,6 +25,10 @@
#include <stdio.h>
#include <functional>

#ifdef GHEX_USE_NCCL
#include <nccl.h>
#endif

namespace ghex
{
// forward declaration for optimization on regular grids
Expand Down Expand Up @@ -207,8 +212,12 @@ class communication_object
using disable_if_buffer_info = std::enable_if_t<!is_buffer_info<T>::value, R>;

private: // members
ghex::util::moved_bit m_moved;
bool m_valid;
communicator_type m_comm;
#ifdef GHEX_USE_NCCL
ncclComm_t m_nccl_comm;
#endif
Comment on lines +215 to +220
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to oomph.

memory_type m_mem;
std::vector<send_request_type> m_send_reqs;
std::vector<recv_request_type> m_recv_reqs;
Expand All @@ -218,12 +227,45 @@ class communication_object
: m_valid(false)
, m_comm(c.transport_context()->get_communicator())
{
ncclUniqueId id;
if (m_comm.rank() == 0) {
ncclGetUniqueId(&id);
}
MPI_Comm mpi_comm = m_comm.mpi_comm();

MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, mpi_comm);

GHEX_CHECK_NCCL_RESULT(ncclCommInitRank(&m_nccl_comm, m_comm.size(), id, m_comm.rank()));
ncclResult_t state;
do {
GHEX_CHECK_NCCL_RESULT(ncclCommGetAsyncError(m_nccl_comm, &state));
} while(state == ncclInProgress);
}
/** @brief Destructor: tears down the NCCL communicator owned by this object.
  * The member m_nccl_comm is declared only under GHEX_USE_NCCL, so the
  * teardown must be guarded the same way — the original body referenced
  * m_nccl_comm unconditionally and would not compile with NCCL disabled.
  * The device-wide synchronize ensures all NCCL work on this communicator
  * has drained before ncclCommDestroy is called. */
~communication_object() noexcept {
    if (!m_moved) {
#ifdef GHEX_USE_NCCL
        GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaDeviceSynchronize());
        GHEX_CHECK_NCCL_RESULT_NO_THROW(ncclCommDestroy(m_nccl_comm));
#endif
    }
}
communication_object(const communication_object&) = delete;
communication_object(communication_object&&) = default;

communicator_type& communicator() { return m_comm; }

private:
/** @brief Runs a complete halo exchange over NCCL: pack to device buffers,
  * submit all sends and receives inside one NCCL group (so they are issued
  * as a single fused operation), then unpack.
  *
  * NOTE(review): the buffer_infos parameters are currently unused; the
  * buffers come from m_mem, populated by the preceding exchange_impl()
  * call — confirm whether they should be dropped or forwarded.
  *
  * Changed vs. original: ncclGroupStart()/ncclGroupEnd() results are now
  * checked with GHEX_CHECK_NCCL_RESULT, consistent with every other NCCL
  * call in this file. */
template<typename... Archs, typename... Fields>
void nccl_exchange_impl(buffer_info_type<Archs, Fields>... buffer_infos) {
    pack_nccl();

    GHEX_CHECK_NCCL_RESULT(ncclGroupStart());
    post_sends_nccl();
    post_recvs_nccl();
    GHEX_CHECK_NCCL_RESULT(ncclGroupEnd());

    unpack_nccl();
}
Comment on lines +255 to +266
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add customization point or similar to allow doing this with NCCL?



public: // exchange arbitrary field-device-pattern combinations
/** @brief non-blocking exchange of halo data
* @tparam Archs list of device types
Expand All @@ -234,8 +276,12 @@ class communication_object
// Registers the given buffer infos, then starts the halo exchange.
// Returns a handle whose wait()/is_ready() complete the exchange.
[[nodiscard]] handle_type exchange(buffer_info_type<Archs, Fields>... buffer_infos)
{
// Set up send/recv buffer layout in m_mem from the buffer infos.
exchange_impl(buffer_infos...);
#ifdef GHEX_USE_NCCL
// NCCL path: pack, grouped send/recv, unpack — all stream-ordered.
// NOTE(review): buffer_infos are not forwarded; nccl_exchange_impl
// reads everything from m_mem — confirm this is intentional.
nccl_exchange_impl();
#else
post_recvs();
pack();
#endif
return {this};
}

Expand All @@ -248,7 +294,6 @@ class communication_object
[[nodiscard]] disable_if_buffer_info<Iterator, handle_type> exchange(
Iterator first, Iterator last)
{
// call special function for a single range
return exchange_u(first, last);
}

Expand Down Expand Up @@ -279,8 +324,12 @@ class communication_object
// Starts a halo exchange over several iterator ranges of buffer infos
// (one begin/end pair per range). Mirrors the variadic buffer_info
// overload above.
[[nodiscard]] handle_type exchange(std::pair<Iterators, Iterators>... iter_pairs)
{
// Set up send/recv buffer layout in m_mem from the ranges.
exchange_impl(iter_pairs...);
#ifdef GHEX_USE_NCCL
// NCCL path: pack, grouped send/recv, unpack — all stream-ordered.
nccl_exchange_impl();
#else
post_recvs();
pack();
#endif
return {this};
}

Expand Down Expand Up @@ -462,6 +511,89 @@ class communication_object
});
}

/** @brief Issues an ncclSend for every non-empty send buffer on every
  * device, on that buffer's own CUDA stream.
  * Intended to be called between ncclGroupStart()/ncclGroupEnd() together
  * with post_recvs_nccl(), so sends and receives are matched as one group.
  *
  * Changed vs. original: removed the unused local `device_id` (the send
  * path, unlike the recv path, never (re)allocates buffers and so never
  * needs the device id). */
void post_sends_nccl()
{
    for_each(m_mem, [this](std::size_t, auto& map) {
        for (auto& p0 : map.send_memory)
        {
            for (auto& p1 : p0.second)
            {
                if (p1.second.size > 0u)
                {
                    // guard presumably pins/activates the buffer's device
                    // memory for the duration of the call — confirm.
                    device::guard g(p1.second.buffer);
                    // TODO: Check why element size isn't relevant for the
                    // buffer size (also for recv).
                    GHEX_CHECK_NCCL_RESULT(
                        ncclSend(static_cast<const void*>(g.data()), p1.second.buffer.size(),
                            ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get()));
                }
            }
        }
    });
}

// Issues an ncclRecv for every non-empty recv slot on every device, on
// that slot's own CUDA stream. Allocates (or re-allocates) the recv
// buffer first if it is missing, has the wrong size, or — on GPU builds —
// lives on the wrong device. Intended to be called between
// ncclGroupStart()/ncclGroupEnd() together with post_sends_nccl().
void post_recvs_nccl()
{
for_each(m_mem, [this](std::size_t, auto& m) {
using arch_type = typename std::remove_reference_t<decltype(m)>::arch_type;
for (auto& p0 : m.recv_memory)
{
// Key of the outer map is the device id the buffer must live on.
const auto device_id = p0.first;
for (auto& p1 : p0.second)
{
if (p1.second.size > 0u)
{
// (Re)allocate the message buffer if absent or mismatched.
if (!p1.second.buffer || p1.second.buffer.size() != p1.second.size
#if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE)
|| p1.second.buffer.device_id() != device_id
#endif
)
p1.second.buffer = arch_traits<arch_type>::make_message(
m_comm, p1.second.size, device_id);
// TODO (see post_sends_nccl): count is the byte size of the
// whole buffer; element size is folded into ncclChar.
GHEX_CHECK_NCCL_RESULT(
ncclRecv(p1.second.buffer.device_data(), p1.second.buffer.size(),
ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get()));
}
}
}
});
}

/** @brief Packs all field halos into contiguous send buffers, once per
  * architecture held in m_mem, via the architecture-specific packer. */
void pack_nccl()
{
    for_each(m_mem, [this](std::size_t, auto& mem) {
        using arch_t = typename std::decay_t<decltype(mem)>::arch_type;
        packer<arch_t>::pack2_nccl(mem, m_send_reqs, m_comm);
    });
}

// Unpacks every non-empty recv buffer back into its destination fields,
// per architecture, using the architecture-specific packer.
//
// NOTE(review): the (re)allocation branch below appears to be copied from
// post_recvs_nccl(). At unpack time the buffer should already exist and
// hold the received data; if this branch ever fires, a freshly allocated
// buffer would be unpacked without any received content. Confirm whether
// it is dead code that can be removed.
void unpack_nccl()
{
for_each(m_mem, [this](std::size_t, auto& m) {
using arch_type = typename std::remove_reference_t<decltype(m)>::arch_type;
for (auto& p0 : m.recv_memory)
{
// Key of the outer map is the device id the buffer must live on.
const auto device_id = p0.first;
for (auto& p1 : p0.second)
{
if (p1.second.size > 0u)
{
if (!p1.second.buffer || p1.second.buffer.size() != p1.second.size
#if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE)
|| p1.second.buffer.device_id() != device_id
#endif
)
p1.second.buffer = arch_traits<arch_type>::make_message(
m_comm, p1.second.size, device_id);
device::guard g(p1.second.buffer);
packer<arch_type>::unpack(p1.second, g.data());
}
}
}
});
}

void pack()
{
for_each(m_mem, [this](std::size_t, auto& m) {
Expand All @@ -473,12 +605,19 @@ class communication_object
private: // wait functions
// Advances the underlying transport. With NCCL all communication is
// enqueued on CUDA streams, so there is (presumably) nothing to poll
// here; without NCCL, progress is delegated to the communicator.
void progress()
{
#ifdef GHEX_USE_NCCL
// TODO: No progress needed?
#else
// Nothing to do before the first exchange has been started.
if (!m_valid) return;
m_comm.progress();
#endif
}

bool is_ready()
{
#ifdef GHEX_USE_NCCL
// TODO: Check if streams are idle?
#else
if (!m_valid) return true;
if (m_comm.is_ready())
{
Expand All @@ -497,14 +636,17 @@ class communication_object
clear();
return true;
}
#endif
return false;
}

void wait()
{
#ifndef GHEX_USE_NCCL
if (!m_valid) return;
// wait for data to arrive (unpack callback will be invoked)
m_comm.wait_all();
#endif
#ifdef GHEX_CUDACC
sync_streams();
#endif
Expand All @@ -515,6 +657,10 @@ class communication_object
private: // synchronize (unpacking) streams
void sync_streams()
{
constexpr std::size_t num_events{128};
static std::vector<device::cuda_event> events(num_events);
static std::size_t event_index{0};

using gpu_mem_t = buffer_memory<gpu>;
auto& m = std::get<gpu_mem_t>(m_mem);
for (auto& p0 : m.recv_memory)
Expand All @@ -523,7 +669,18 @@ class communication_object
{
if (p1.second.size > 0u)
{
#ifdef GHEX_USE_NCCL
// Instead of doing a blocking wait, create events on each
// stream that the default stream waits for. This assumes
// that all kernels that need the unpacked data will use or
// synchronize with the default stream.
cudaEvent_t& e = events[event_index].get();
event_index = (event_index + 1) % num_events;
GHEX_CHECK_CUDA_RESULT(cudaEventRecord(e, p1.second.m_stream.get()));
GHEX_CHECK_CUDA_RESULT(cudaStreamWaitEvent(0, e));
#else
p1.second.m_stream.sync();
#endif
}
}
}
Expand Down
32 changes: 27 additions & 5 deletions include/ghex/device/cuda/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,41 @@ namespace ghex
{
namespace device
{
/** @brief RAII wrapper around a cudaEvent_t.
  * Move-only: the moved_bit marks a moved-from instance so its destructor
  * skips cudaEventDestroy. Timing is disabled because these events are
  * used purely for stream ordering/synchronization. */
struct cuda_event {
cudaEvent_t m_event;
ghex::util::moved_bit m_moved;

cuda_event() {
GHEX_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming))
}
// Non-copyable: the event handle must have exactly one owner.
cuda_event(const cuda_event&) = delete;
cuda_event& operator=(const cuda_event&) = delete;
// Default move is correct: m_event is copied and moved_bit flags the
// source so only the destination destroys the event.
cuda_event(cuda_event&& other) = default;
cuda_event& operator=(cuda_event&&) = default;

~cuda_event()
{
if (!m_moved)
{
GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event))
}
}

// NOTE(review): returns true when the object HAS been moved from, which
// inverts the usual "is valid" meaning of operator bool on RAII
// handles — confirm whether this matches the convention used by
// device::stream or should be `!m_moved`.
operator bool() const noexcept { return m_moved; }
operator cudaEvent_t() const noexcept { return m_event; }
cudaEvent_t& get() noexcept { return m_event; }
const cudaEvent_t& get() const noexcept { return m_event; }
};

Comment on lines +22 to +47
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separate file.

/** @brief thin wrapper around a cuda stream */
struct stream
{
cudaStream_t m_stream;
cudaEvent_t m_event;
ghex::util::moved_bit m_moved;

stream()
{
GHEX_CHECK_CUDA_RESULT(cudaStreamCreateWithFlags(&m_stream, cudaStreamNonBlocking))
GHEX_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming))
}

stream(const stream&) = delete;
Expand All @@ -42,7 +66,6 @@ struct stream
if (!m_moved)
{
GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaStreamDestroy(m_stream))
GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event))
}
}

Expand All @@ -55,9 +78,8 @@ struct stream

void sync()
{
GHEX_CHECK_CUDA_RESULT(cudaEventRecord(m_event, m_stream))
// busy wait here
GHEX_CHECK_CUDA_RESULT(cudaEventSynchronize(m_event))
GHEX_CHECK_CUDA_RESULT(cudaStreamSynchronize(m_stream))
}
};
} // namespace device
Expand Down
Loading