-
Notifications
You must be signed in to change notification settings - Fork 16
Do halo exchanges with NCCL #185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a4f5499
4746b35
d26c05a
99fe0a0
527d590
78879bb
f314a1c
ab0dfd0
4b5833f
ee1b851
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| +1 −1 | ext/hwmalloc | |
| +3 −2 | include/oomph/context.hpp | |
| +3 −2 | src/context.cpp | |
| +6 −3 | src/libfabric/context.cpp | |
| +2 −2 | src/libfabric/context.hpp | |
| +4 −3 | src/mpi/context.hpp | |
| +4 −3 | src/ucx/context.hpp |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| #include <ghex/config.hpp> | ||
| #include <ghex/context.hpp> | ||
| #include <ghex/util/for_each.hpp> | ||
| #include <ghex/util/moved_bit.hpp> | ||
| #include <ghex/util/test_eq.hpp> | ||
| #include <ghex/pattern_container.hpp> | ||
| #include <ghex/device/stream.hpp> | ||
|
|
@@ -24,6 +25,10 @@ | |
| #include <stdio.h> | ||
| #include <functional> | ||
|
|
||
| #ifdef GHEX_USE_NCCL | ||
| #include <nccl.h> | ||
| #endif | ||
|
|
||
| namespace ghex | ||
| { | ||
| // forward declaration for optimization on regular grids | ||
|
|
@@ -207,8 +212,12 @@ class communication_object | |
| using disable_if_buffer_info = std::enable_if_t<!is_buffer_info<T>::value, R>; | ||
|
|
||
| private: // members | ||
| ghex::util::moved_bit m_moved; | ||
| bool m_valid; | ||
| communicator_type m_comm; | ||
| #ifdef GHEX_USE_NCCL | ||
| ncclComm_t m_nccl_comm; | ||
| #endif | ||
|
Comment on lines
+215
to
+220
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Move this to oomph. |
||
| memory_type m_mem; | ||
| std::vector<send_request_type> m_send_reqs; | ||
| std::vector<recv_request_type> m_recv_reqs; | ||
|
|
@@ -218,12 +227,45 @@ class communication_object | |
| : m_valid(false) | ||
| , m_comm(c.transport_context()->get_communicator()) | ||
| { | ||
| ncclUniqueId id; | ||
| if (m_comm.rank() == 0) { | ||
| ncclGetUniqueId(&id); | ||
| } | ||
| MPI_Comm mpi_comm = m_comm.mpi_comm(); | ||
|
|
||
| MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, mpi_comm); | ||
|
|
||
| GHEX_CHECK_NCCL_RESULT(ncclCommInitRank(&m_nccl_comm, m_comm.size(), id, m_comm.rank())); | ||
| ncclResult_t state; | ||
| do { | ||
| GHEX_CHECK_NCCL_RESULT(ncclCommGetAsyncError(m_nccl_comm, &state)); | ||
| } while(state == ncclInProgress); | ||
| } | ||
| ~communication_object() noexcept { | ||
| if (!m_moved) { | ||
| GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaDeviceSynchronize()); | ||
| GHEX_CHECK_NCCL_RESULT_NO_THROW(ncclCommDestroy(m_nccl_comm)); | ||
| } | ||
| } | ||
| communication_object(const communication_object&) = delete; | ||
| communication_object(communication_object&&) = default; | ||
|
|
||
| communicator_type& communicator() { return m_comm; } | ||
|
|
||
| private: | ||
| template<typename... Archs, typename... Fields> | ||
| void nccl_exchange_impl(buffer_info_type<Archs, Fields>... buffer_infos) { | ||
| pack_nccl(); | ||
|
|
||
| ncclGroupStart(); | ||
| post_sends_nccl(); | ||
| post_recvs_nccl(); | ||
| ncclGroupEnd(); | ||
|
|
||
| unpack_nccl(); | ||
| } | ||
|
Comment on lines
+255
to
+266
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add a customization point or similar to allow doing this with NCCL? |
||
|
|
||
|
|
||
| public: // exchange arbitrary field-device-pattern combinations | ||
| /** @brief non-blocking exchange of halo data | ||
| * @tparam Archs list of device types | ||
|
|
@@ -234,8 +276,12 @@ class communication_object | |
| [[nodiscard]] handle_type exchange(buffer_info_type<Archs, Fields>... buffer_infos) | ||
| { | ||
| exchange_impl(buffer_infos...); | ||
| #ifdef GHEX_USE_NCCL | ||
| nccl_exchange_impl(); | ||
| #else | ||
| post_recvs(); | ||
| pack(); | ||
| #endif | ||
| return {this}; | ||
| } | ||
|
|
||
|
|
@@ -248,7 +294,6 @@ class communication_object | |
| [[nodiscard]] disable_if_buffer_info<Iterator, handle_type> exchange( | ||
| Iterator first, Iterator last) | ||
| { | ||
| // call special function for a single range | ||
| return exchange_u(first, last); | ||
| } | ||
|
|
||
|
|
@@ -279,8 +324,12 @@ class communication_object | |
| [[nodiscard]] handle_type exchange(std::pair<Iterators, Iterators>... iter_pairs) | ||
| { | ||
| exchange_impl(iter_pairs...); | ||
| #ifdef GHEX_USE_NCCL | ||
| nccl_exchange_impl(); | ||
| #else | ||
| post_recvs(); | ||
| pack(); | ||
| #endif | ||
| return {this}; | ||
| } | ||
|
|
||
|
|
@@ -462,6 +511,89 @@ class communication_object | |
| }); | ||
| } | ||
|
|
||
| void post_sends_nccl() | ||
| { | ||
| for_each(m_mem, [this](std::size_t, auto& map) { | ||
| for (auto& p0 : map.send_memory) | ||
| { | ||
| const auto device_id = p0.first; | ||
| for (auto& p1 : p0.second) | ||
| { | ||
| if (p1.second.size > 0u) | ||
| { | ||
| device::guard g(p1.second.buffer); | ||
| // TODO: Check why element size isn't relevant for the | ||
| // buffer size (also for recv). | ||
| GHEX_CHECK_NCCL_RESULT( | ||
| ncclSend(static_cast<const void*>(g.data()), p1.second.buffer.size(), | ||
| ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); | ||
| } | ||
| } | ||
| } | ||
| }); | ||
| } | ||
|
|
||
| void post_recvs_nccl() | ||
| { | ||
| for_each(m_mem, [this](std::size_t, auto& m) { | ||
| using arch_type = typename std::remove_reference_t<decltype(m)>::arch_type; | ||
| for (auto& p0 : m.recv_memory) | ||
| { | ||
| const auto device_id = p0.first; | ||
| for (auto& p1 : p0.second) | ||
| { | ||
| if (p1.second.size > 0u) | ||
| { | ||
| if (!p1.second.buffer || p1.second.buffer.size() != p1.second.size | ||
| #if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE) | ||
| || p1.second.buffer.device_id() != device_id | ||
| #endif | ||
| ) | ||
| p1.second.buffer = arch_traits<arch_type>::make_message( | ||
| m_comm, p1.second.size, device_id); | ||
| GHEX_CHECK_NCCL_RESULT( | ||
| ncclRecv(p1.second.buffer.device_data(), p1.second.buffer.size(), | ||
| ncclChar, p1.second.rank, m_nccl_comm, p1.second.m_stream.get())); | ||
| } | ||
| } | ||
| } | ||
| }); | ||
| } | ||
|
|
||
| void pack_nccl() | ||
| { | ||
| for_each(m_mem, [this](std::size_t, auto& m) { | ||
| using arch_type = typename std::remove_reference_t<decltype(m)>::arch_type; | ||
| packer<arch_type>::pack2_nccl(m, m_send_reqs, m_comm); | ||
| }); | ||
| } | ||
|
|
||
| void unpack_nccl() | ||
| { | ||
| for_each(m_mem, [this](std::size_t, auto& m) { | ||
| using arch_type = typename std::remove_reference_t<decltype(m)>::arch_type; | ||
| for (auto& p0 : m.recv_memory) | ||
| { | ||
| const auto device_id = p0.first; | ||
| for (auto& p1 : p0.second) | ||
| { | ||
| if (p1.second.size > 0u) | ||
| { | ||
| if (!p1.second.buffer || p1.second.buffer.size() != p1.second.size | ||
| #if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE) | ||
| || p1.second.buffer.device_id() != device_id | ||
| #endif | ||
| ) | ||
| p1.second.buffer = arch_traits<arch_type>::make_message( | ||
| m_comm, p1.second.size, device_id); | ||
| device::guard g(p1.second.buffer); | ||
| packer<arch_type>::unpack(p1.second, g.data()); | ||
| } | ||
| } | ||
| } | ||
| }); | ||
| } | ||
|
|
||
| void pack() | ||
| { | ||
| for_each(m_mem, [this](std::size_t, auto& m) { | ||
|
|
@@ -473,12 +605,19 @@ class communication_object | |
| private: // wait functions | ||
| void progress() | ||
| { | ||
| #ifdef GHEX_USE_NCCL | ||
| // TODO: No progress needed? | ||
| #else | ||
| if (!m_valid) return; | ||
| m_comm.progress(); | ||
| #endif | ||
| } | ||
|
|
||
| bool is_ready() | ||
| { | ||
| #ifdef GHEX_USE_NCCL | ||
| // TODO: Check if streams are idle? | ||
| #else | ||
| if (!m_valid) return true; | ||
| if (m_comm.is_ready()) | ||
| { | ||
|
|
@@ -497,14 +636,17 @@ class communication_object | |
| clear(); | ||
| return true; | ||
| } | ||
| #endif | ||
| return false; | ||
| } | ||
|
|
||
| void wait() | ||
| { | ||
| #ifndef GHEX_USE_NCCL | ||
| if (!m_valid) return; | ||
| // wait for data to arrive (unpack callback will be invoked) | ||
| m_comm.wait_all(); | ||
| #endif | ||
| #ifdef GHEX_CUDACC | ||
| sync_streams(); | ||
| #endif | ||
|
|
@@ -515,6 +657,10 @@ class communication_object | |
| private: // synchronize (unpacking) streams | ||
| void sync_streams() | ||
| { | ||
| constexpr std::size_t num_events{128}; | ||
| static std::vector<device::cuda_event> events(num_events); | ||
| static std::size_t event_index{0}; | ||
|
|
||
| using gpu_mem_t = buffer_memory<gpu>; | ||
| auto& m = std::get<gpu_mem_t>(m_mem); | ||
| for (auto& p0 : m.recv_memory) | ||
|
|
@@ -523,7 +669,18 @@ class communication_object | |
| { | ||
| if (p1.second.size > 0u) | ||
| { | ||
| #ifdef GHEX_USE_NCCL | ||
| // Instead of doing a blocking wait, create events on each | ||
| // stream that the default stream waits for. This assumes | ||
| // that all kernels that need the unpacked data will use or | ||
| // synchronize with the default stream. | ||
| cudaEvent_t& e = events[event_index].get(); | ||
| event_index = (event_index + 1) % num_events; | ||
| GHEX_CHECK_CUDA_RESULT(cudaEventRecord(e, p1.second.m_stream.get())); | ||
| GHEX_CHECK_CUDA_RESULT(cudaStreamWaitEvent(0, e)); | ||
| #else | ||
| p1.second.m_stream.sync(); | ||
| #endif | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,17 +19,41 @@ namespace ghex | |
| { | ||
| namespace device | ||
| { | ||
| struct cuda_event { | ||
| cudaEvent_t m_event; | ||
| ghex::util::moved_bit m_moved; | ||
|
|
||
| cuda_event() { | ||
| GHEX_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming)) | ||
| } | ||
| cuda_event(const cuda_event&) = delete; | ||
| cuda_event& operator=(const cuda_event&) = delete; | ||
| cuda_event(cuda_event&& other) = default; | ||
| cuda_event& operator=(cuda_event&&) = default; | ||
|
|
||
| ~cuda_event() | ||
| { | ||
| if (!m_moved) | ||
| { | ||
| GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event)) | ||
| } | ||
| } | ||
|
|
||
| operator bool() const noexcept { return m_moved; } | ||
| operator cudaEvent_t() const noexcept { return m_event; } | ||
| cudaEvent_t& get() noexcept { return m_event; } | ||
| const cudaEvent_t& get() const noexcept { return m_event; } | ||
| }; | ||
|
|
||
|
Comment on lines
+22
to
+47
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Move this to a separate file. |
||
| /** @brief thin wrapper around a cuda stream */ | ||
| struct stream | ||
| { | ||
| cudaStream_t m_stream; | ||
| cudaEvent_t m_event; | ||
| ghex::util::moved_bit m_moved; | ||
|
|
||
| stream() | ||
| { | ||
| GHEX_CHECK_CUDA_RESULT(cudaStreamCreateWithFlags(&m_stream, cudaStreamNonBlocking)) | ||
| GHEX_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming)) | ||
| } | ||
|
|
||
| stream(const stream&) = delete; | ||
|
|
@@ -42,7 +66,6 @@ struct stream | |
| if (!m_moved) | ||
| { | ||
| GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaStreamDestroy(m_stream)) | ||
| GHEX_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event)) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -55,9 +78,8 @@ struct stream | |
|
|
||
| void sync() | ||
| { | ||
| GHEX_CHECK_CUDA_RESULT(cudaEventRecord(m_event, m_stream)) | ||
| // busy wait here | ||
| GHEX_CHECK_CUDA_RESULT(cudaEventSynchronize(m_event)) | ||
| GHEX_CHECK_CUDA_RESULT(cudaStreamSynchronize(m_stream)) | ||
| } | ||
| }; | ||
| } // namespace device | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
This is a hack. Add a proper
`FindNCCL.cmake` module. This can be tested using the icon uenv by manually setting
`export LIBRARY_PATH=/user-environment/env/default/lib64:/user-environment/env/default/lib`.