diff --git a/include/ghex/transport_layer/continuation_communicator.hpp b/include/ghex/transport_layer/continuation_communicator.hpp new file mode 100644 index 0000000..6478289 --- /dev/null +++ b/include/ghex/transport_layer/continuation_communicator.hpp @@ -0,0 +1,369 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_TL_CONTINUATION_COMMUNICATOR_HPP +#define INCLUDED_GHEX_TL_CONTINUATION_COMMUNICATOR_HPP + +#include +#include "./callback_communicator.hpp" + +namespace gridtools{ + namespace ghex { + namespace tl { + + // implementation details here: + namespace cont_detail { + + // shared request state + struct request_state + { + // volatile is needed to prevent the compiler + // from optimizing away the check of this member + volatile bool m_ready = false; + bool is_ready() const noexcept { return m_ready; } + }; + + // simple request class which is returned from send and recv calls + struct request + { + std::shared_ptr m_request_state; + bool is_ready() const noexcept { return m_request_state->is_ready(); } + }; + + // type-erased message + struct any_message + { + using value_type = unsigned char; + + struct iface + { + virtual unsigned char* data() noexcept = 0; + virtual const unsigned char* data() const noexcept = 0; + virtual std::size_t size() const noexcept = 0; + virtual ~iface() {} + }; + + template + struct holder final : public iface + { + using value_type = typename Message::value_type; + Message m_message; + holder(Message&& m): m_message{std::move(m)} {} + + unsigned char* data() noexcept override { return reinterpret_cast(m_message.data()); } + const unsigned char* data() const noexcept override { return reinterpret_cast(m_message.data()); } + std::size_t size() const noexcept override { return sizeof(value_type)*m_message.size(); } + }; + + std::unique_ptr m_ptr; + + template + any_message(Message&& m) : m_ptr{std::make_unique>(std::move(m))} {} + any_message(any_message&&) = default; + + unsigned char* data() noexcept { return m_ptr->data(); } + const unsigned char* data() const noexcept { return m_ptr->data(); } + std::size_t size() const noexcept { return m_ptr->size(); } + }; + + // simple wrapper around an l-value reference message (stores pointer and size) + template + struct ref_message + { + using value_type = T;//unsigned char; + T* m_data; + std::size_t m_size; + T* data() noexcept { return m_data; } + const T* data() const noexcept { return m_data; } + std::size_t size() const noexcept { return m_size; } + }; + + // simple shared message which is internally used for send_multi + template + struct shared_message + { + using value_type = typename Message::value_type; + std::shared_ptr m_message; + + shared_message(Message&& m) : m_message{std::make_shared(std::move(m))} {} + shared_message(const shared_message&) = default; + shared_message(shared_message&&) = default; + + value_type* data() noexcept { return m_message->data(); } + const value_type* data() const noexcept { return m_message->data(); } + std::size_t size() const noexcept { return m_message->size(); } + }; + + // type-erased future + struct any_future + { + struct iface + { + virtual bool ready() = 0; + virtual ~iface() {} + }; + + template + struct holder final : public iface + { + Future m_future; + holder() = default; + holder(Future&& fut): m_future{std::move(fut)} {} + bool ready() override { return m_future.ready(); } + }; + + std::unique_ptr m_ptr; + + template + any_future(Future&& fut) : m_ptr{std::make_unique>(std::move(fut))} {} + + bool ready() { return m_ptr->ready(); } + }; + + } // namespace cont_detail + + + + // thread-safe shared communicator which handles callbacks + // note: no templates, everything is type-erased + // relies on future-based basic communicator which is passed for every send/recv + class continuation_communicator + { + public: // member types + + using tag_type = int; + using rank_type = int; + // this is the message type returned in the callback: + using message_type = cont_detail::any_message; + // returned from send/recv to check for completion + using request = cont_detail::request; + + private: // member types + + // wrapper for messages passed by l-value reference + template + using ref_message = cont_detail::ref_message; + + // necessary meta information for each send/receive operation + struct element_type + { + using message_arg_type = message_type; + std::function m_cb; + rank_type m_rank; + tag_type m_tag; + cont_detail::any_future m_future; + message_type m_msg; + std::shared_ptr m_request_state; + }; + // we need thread-safe queues + using lock_free_alloc_t = boost::lockfree::allocator>; + using send_container_type = boost::lockfree::queue>; + using recv_container_type = boost::lockfree::queue>; + + private: // members + + send_container_type m_sends; + recv_container_type m_recvs; + + public: // ctors + + continuation_communicator() : m_sends(128), m_recvs(128) {} + continuation_communicator(const continuation_communicator&) = delete; + continuation_communicator(continuation_communicator&&) = default; + ~continuation_communicator() { /* TODO: consume all*/ } + + public: // send + + // use basic comm to post the send and place the callback in a queue + // returns a request to check for completion + // takes ownership of message if it is an r-value reference! + template + request send(Comm& comm, Message&& msg, rank_type dst, tag_type tag, CallBack&& cb) + { + GHEX_CHECK_CALLBACK + using is_rvalue = std::is_rvalue_reference(msg))>; + return send(comm, std::forward(msg), dst, tag, std::forward(cb), is_rvalue()); + } + + // no-callback version + template + request send(Comm& comm, Message&& msg, rank_type dst, tag_type tag) + { + return send(comm, std::forward(msg), dst, tag, [](message_type,rank_type,tag_type){}); + } + + public: // send multi + + // use basic comm to post the sends and place the callback in a queue + // returns a vector of request to check for completion + // takes ownership of message if it is an r-value reference! + // internally transforms the callback (and the message if moved in) into shared objects + template + std::vector send_multi(Comm& comm, Message&& msg, const Neighs& neighs, tag_type tag, CallBack&& cb) + { + GHEX_CHECK_CALLBACK + using is_rvalue = std::is_rvalue_reference(msg))>; + return send_multi(comm, std::forward(msg), neighs, tag, std::forward(cb), is_rvalue()); + } + + // no-callback version + template + std::vector send_multi(Comm& comm, Message&& msg, const Neighs& neighs, tag_type tag) + { + return send_multi(comm, std::forward(msg), neighs, tag, [](message_type,rank_type,tag_type){}); + } + + public: // receive + + // use basic comm to post the recv and place the callback in a queue + // returns a request to check for completion + // takes ownership of message if it is an r-value reference! + template + request recv(Comm& comm, Message&& msg, rank_type src, tag_type tag, CallBack&& cb) + { + GHEX_CHECK_CALLBACK + using is_rvalue = std::is_rvalue_reference(msg))>; + return recv(comm, std::forward(msg), src, tag, std::forward(cb), is_rvalue()); + } + + // no-callback version + template + request recv(Comm& comm, Message&& msg, rank_type src, tag_type tag) + { + return recv(comm, std::forward(msg), src, tag, [](message_type,rank_type,tag_type){}); + } + + public: // progress + + // progress the ques and return the number of progressed callbacks + std::size_t progress() + { + std::size_t num_completed = 0u; + num_completed += run(m_sends); + num_completed += run(m_recvs); + return num_completed; + } + + + private: // implementation + + template + request send(Comm& comm, Message& msg, rank_type dst, tag_type tag, CallBack&& cb, std::false_type) + { + using V = typename Message::value_type; + request req{std::make_shared()}; + auto fut = comm.send(msg,dst,tag); + auto element_ptr = new element_type{std::forward(cb), dst, tag, std::move(fut), + ref_message{msg.data(),msg.size()}, req.m_request_state}; + while (!m_sends.push(element_ptr)) {} + return req; + } + + template + request send(Comm& comm, Message&& msg, rank_type dst, tag_type tag, CallBack&& cb, std::true_type) + { + request req{std::make_shared()}; + auto fut = comm.send(msg,dst,tag); + auto element_ptr = new element_type{std::forward(cb), dst, tag, std::move(fut), + std::move(msg), req.m_request_state}; + while (!m_sends.push(element_ptr)) {} + return req; + } + + template + request recv(Comm& comm, Message& msg, rank_type src, tag_type tag, CallBack&& cb, std::false_type) + { + using V = typename Message::value_type; + request req{std::make_shared()}; + auto fut = comm.recv(msg,src,tag); + auto element_ptr = new element_type{std::forward(cb), src, tag, std::move(fut), + ref_message{msg.data(),msg.size()}, req.m_request_state}; + while (!m_recvs.push(element_ptr)) {} + return req; + } + + template + request recv(Comm& comm, Message&& msg, rank_type src, tag_type tag, CallBack&& cb, std::true_type) + { + request req{std::make_shared()}; + auto fut = comm.recv(msg,src,tag); + auto element_ptr = new element_type{std::forward(cb), src, tag, std::move(fut), + std::move(msg), req.m_request_state}; + while (!m_recvs.push(element_ptr)) {} + return req; + } + + template + std::vector send_multi(Comm& comm, Message& msg, const Neighs& neighs, tag_type tag, CallBack&& cb, std::false_type) + { + using cb_type = typename std::remove_cv::type>::type; + auto cb_ptr = std::make_shared( std::forward(cb) ); + std::vector reqs; + for (auto id : neighs) + reqs.push_back( send(comm, msg, id, tag, + [cb_ptr](message_type m, rank_type r, tag_type t) + { + // if (cb_ptr->use_count == 1) + (*cb_ptr)(std::move(m),r,t); + }) ); + return reqs; + } + + template + std::vector send_multi(Comm& comm, Message&& msg, const Neighs& neighs, tag_type tag, CallBack&& cb, std::true_type) + { + using cb_type = typename std::remove_cv::type>::type; + auto cb_ptr = std::make_shared( std::forward(cb) ); + cont_detail::shared_message s_msg{std::move(msg)}; + std::vector reqs; + for (auto id : neighs) + { + auto s_msg_cpy = s_msg; + reqs.push_back( send(comm, std::move(s_msg_cpy), id, tag, + [cb_ptr](message_type m, rank_type r, tag_type t) + { + // if (cb_ptr->use_count == 1) + (*cb_ptr)(std::move(m),r,t); + }) ); + } + return reqs; + } + + template + std::size_t run(Queue& d) + { + element_type* ptr = nullptr; + if (d.pop(ptr)) + { + if (ptr->m_future.ready()) + { + // call the callback + ptr->m_cb(std::move(ptr->m_msg), ptr->m_rank, ptr->m_tag); + // make request ready + ptr->m_request_state->m_ready = true; + delete ptr; + return 1u; + } + else + { + while( !d.push(ptr) ) {} + return 0u; + } + } + else return 0u; + } + }; + + } // namespace tl + } // namespace ghex +}// namespace gridtools + +#endif/*INCLUDED_GHEX_TL_CONTINUATION_COMMUNICATOR_HPP */ + diff --git a/tests/transport/CMakeLists.txt b/tests/transport/CMakeLists.txt index 6052534..15fb46b 100644 --- a/tests/transport/CMakeLists.txt +++ b/tests/transport/CMakeLists.txt @@ -1,5 +1,5 @@ -set(_tests test_low_level test_low_level_x test_send_multi test_cancel_request test_attach_detach) +set(_tests test_low_level test_low_level_x test_send_multi test_cancel_request test_attach_detach test_ts) foreach(t_ ${_tests}) diff --git a/tests/transport/test_ts.cpp b/tests/transport/test_ts.cpp new file mode 100644 index 0000000..ad5246f --- /dev/null +++ b/tests/transport/test_ts.cpp @@ -0,0 +1,396 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#include +#include +#include +#include +#include +#include +#include + +#include + +using comm_t = gridtools::ghex::tl::communicator; +using cont_comm_t = gridtools::ghex::tl::continuation_communicator; +using msg_type = gridtools::ghex::tl::message_buffer<>; + +std::atomic num_completed; + +// ring-communication using arbitrary number of threads for communication and progressing +// each rank has num_comm_threads threads +// each rank has num_progress_threads which progress the communication and execute the callbacks +// +// num_comm_threads send to the right rank +// num_comm_threads receive from the left rank +// +// the messages are passed as l-value references (GHEX does not take ownership) +// there is one exception to show-case the usage of moving in the message, but this does not alter the semantics of +// this test +// +// there are two modes: +// - wait mode: + each thread waits until the send and receive are finished +// + this is done using the requests returned by the communicator +// +// - nowait mode: + send and receives are posted, the function returns immediately +// +void test_ring(std::size_t num_progress_threads, std::size_t num_comm_threads, bool wait) +{ + num_completed.store(0u); + + // use basic communicator to establish neighbors + comm_t comm; + const int rank = comm.rank(); + const int r_rank = (rank+1)%comm.size(); + const int l_rank = (rank+comm.size()-1)%comm.size(); + + // shared callback communicator + cont_comm_t cont_comm; + + // per-thread objects + std::vector send_msgs; + std::vector recv_msgs; + std::vector comms; + for (std::size_t i=0; i()[0] = rank; + send_msgs.back().data()[1] = i; + recv_msgs.back().data()[0] = -1; + recv_msgs.back().data()[1] = -1; + } + + // total number of sends and receives + std::size_t num_requests = 2*num_comm_threads; + + // lambda which places send and receive calls + auto send_recv_func_nowait = + [&cont_comm,l_rank,r_rank](comm_t& c, int tag, msg_type& recv_msg, msg_type& send_msg) + { + cont_comm.recv(c, recv_msg,l_rank,tag, + [](cont_comm_t::message_type m, int r, int t) { + std::cout << "received from " << r << " with tag " << t << " and size " << m.size() << std::endl; }); + + // give up ownership of some message + // this is just to illustrate the functionality and syntax + msg_type another_msg(4096); + another_msg.data()[0] = send_msg.data()[0]; + another_msg.data()[1] = send_msg.data()[1]; + cont_comm.send(c, std::move(another_msg),r_rank,tag, + [](cont_comm_t::message_type m, int r, int t) { + std::cout << "sent to " << r << " with tag " << t << " and size " << m.size() << std::endl; }); + }; + + // lambda which places send and receive calls and waits for completion + auto send_recv_func_wait = + [&cont_comm,l_rank,r_rank](comm_t& c, int tag, msg_type& recv_msg, msg_type& send_msg) + { + auto recv_req = cont_comm.recv(c, recv_msg,l_rank,tag, + [](cont_comm_t::message_type m, int r, int t) { + std::cout << "received from " << r << " with tag " << t << " and size " << m.size() << std::endl; }); + + auto send_req = cont_comm.send(c, send_msg,r_rank,tag, + [](cont_comm_t::message_type m, int r, int t) { + std::cout << "sent to " << r << " with tag " << t << " and size " << m.size() << std::endl; }); + while ( !(recv_req.is_ready() && send_req.is_ready()) ) {} + }; + + // lambda which progresses the queues + auto progress_func = + [&cont_comm, num_requests]() + { + while(num_completed < num_requests) + num_completed += cont_comm.progress(); + }; + + // make threads + std::vector threads; + threads.reserve(num_progress_threads+num_comm_threads); + + for (std::size_t i=0; i()[0] == l_rank); + EXPECT_TRUE(recv_msgs[i].data()[1] == (int)i); + } + + comm.barrier(); +} + + +// send multiple messages from rank 0 (broadcast) +// and repost the same message after it's been sent +// each rank has num_comm_threads threads +// each rank has num_progress_threads which progress the communication and execute the callbacks +// +// rank 0: num_comm_threads send twice to each rank (using a repost) +// other ranks: num_comm_threads receive twice from rank 0 +// +// there are two modes: +// - wait mode: + each thread waits until the first round of communication has finished and then reposts +// + this is done using the requests returned by the communicator +// + the messages are passed as l-value references (GHEX does not take ownership) +// +// - nowait mode: + rank 0: each thread submits a send_multi. Another thread from the progress thread-pool +// executes the callback and resubmits a send_multi after the message has been sent +// to all other ranks. +// + other ranks: two receives are posted (with two different recv messages), the function returns immediately +// + the messages are passed as r-value references (GHEX takes ownership) +// +void test_send_multi(std::size_t num_progress_threads, std::size_t num_comm_threads, bool wait) +{ + num_completed.store(0u); + + // use basic communicator to establish neighbors + comm_t comm; + + // shared callback communicator + cont_comm_t cont_comm; + + // per-thread objects + std::vector comms; + std::vector num_reps; // used for counting in the first nowait send callback + for (std::size_t i=0; i neighbor_ranks; + for (int i=1; i()[0] = tag; + // get a vector of requests + auto reqs = cont_comm.send_multi(c, msg, neighbor_ranks, tag, + [](cont_comm_t::message_type m, int r, int t) + { + std::cout << "sent to " << r << " with tag " << t << " and size " << m.size() << std::endl; + }); + // wait until all requests are done + bool finished = false; + while (!finished) + { + bool f = true; + for (auto& r : reqs) + f = f && r.is_ready(); + finished = f; + } + std::cout << "reposting" << std::endl; + reqs = cont_comm.send_multi(c, msg, neighbor_ranks, tag+num_comm_threads, + [](cont_comm_t::message_type m, int r, int t) + { + std::cout << "sent to " << r << " with tag " << t << " and size " << m.size() << std::endl; + }); + // wait until all requests are done + // this is important since otherwise the message will go out of scope and is destroyed + // and that would lead to corruption since we passed the message as l-value reference + finished = false; + while (!finished) + { + bool f = true; + for (auto& r : reqs) + f = f && r.is_ready(); + finished = f; + } + }; + + // nowait mode + auto send_multi_nowait = + [&cont_comm,&neighbor_ranks,num_comm_threads](comm_t& c, int tag, int& num_reps_i) + { + const int s = neighbor_ranks.size(); + msg_type msg(4096); + msg.data()[0] = tag; + // no return value is required, message is moved in + cont_comm.send_multi(c, std::move(msg), neighbor_ranks, tag, + [&num_reps_i,s,&c,&cont_comm,neighbor_ranks,num_comm_threads](cont_comm_t::message_type m, int r, int t) + { + std::cout << "sent to " << r << " with tag " << t << " and size " << m.size() << std::endl; + ++num_reps_i; + // check if the message has been sent to all ranks + if (num_reps_i%s == 0) + { + std::cout << "reposting" << std::endl; + // note the move in the repost: + // it recommended to always move inside a callback since this is safe in all cases! + // here it is actually required to move and not doing so will lead to bad bad things. + cont_comm.send_multi(c, std::move(m), neighbor_ranks, t+num_comm_threads); + } + }); + }; + + // make threads + std::vector threads; + threads.reserve(num_progress_threads+num_comm_threads); + + for (std::size_t i=0; i()[0] = -1; + // get a request as return value + auto req = cont_comm.recv(c, msg, 0, tag, + [](cont_comm_t::message_type m, int, int t) + { + EXPECT_TRUE(reinterpret_cast(m.data())[0] == t); + }); + // wait on the request + while (!req.is_ready()){} + msg.data()[0] = -1; + req = cont_comm.recv(c, msg, 0, tag+num_comm_threads, + [num_comm_threads](cont_comm_t::message_type m, int, int t) + { + EXPECT_TRUE(reinterpret_cast(m.data())[0] == (int)(t-num_comm_threads)); + }); + // wait until the requests is ready + // this is important since otherwise the message will go out of scope and is destroyed + // and that would lead to corruption since we passed the message as l-value reference + while (!req.is_ready()){} + }; + + // nowait mode + auto recv_nowait = + [&cont_comm,num_comm_threads](comm_t& c, int tag) + { + msg_type msg(4096); + msg.data()[0] = -1; + // no return value is required, message is moved in + cont_comm.recv(c, std::move(msg), 0, tag, + [](cont_comm_t::message_type m, int, int t) + { + EXPECT_TRUE(reinterpret_cast(m.data())[0] == t); + }); + // another message is created + msg_type msg2(4096); + msg2.data()[0] = -1; + // no return value is required, message is moved in + cont_comm.recv(c, std::move(msg2), 0, tag+num_comm_threads, + [num_comm_threads](cont_comm_t::message_type m, int, int t) + { + EXPECT_TRUE(reinterpret_cast(m.data())[0] == (int)(t-num_comm_threads)); + }); + }; + + // make threads + std::vector threads; + threads.reserve(num_progress_threads+num_comm_threads); + + for (std::size_t i=0; i