Skip to content

Commit

Permalink
Small MPI refactor (#58)
Browse files Browse the repository at this point in the history
* MPI wrapper can now operate on trivially copyable types using their object representation

* StaggeredAllGather introduced as MPI pattern
  • Loading branch information
kubagalecki authored Jan 20, 2024
1 parent 682c682 commit cc95825
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 176 deletions.
62 changes: 42 additions & 20 deletions include/l3ster/comm/MpiComm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ L3STER_MPI_TYPE_MAPPING_STRUCT(unsigned long long, MPI_UNSIGNED_LONG_LONG) // NO
L3STER_MPI_TYPE_MAPPING_STRUCT(float, MPI_FLOAT) // NOLINT
L3STER_MPI_TYPE_MAPPING_STRUCT(double, MPI_DOUBLE) // NOLINT
L3STER_MPI_TYPE_MAPPING_STRUCT(long double, MPI_LONG_DOUBLE) // NOLINT
L3STER_MPI_TYPE_MAPPING_STRUCT(std::byte, MPI_BYTE) // NOLINT

inline void
handleMPIError(int error, std::string_view err_msg, std::source_location src_loc = std::source_location::current())
Expand All @@ -55,11 +56,12 @@ handleMPIError(int error, std::string_view err_msg, std::source_location src_loc
#define L3STER_INVOKE_MPI(fun__, ...) comm::handleMPIError(fun__(__VA_ARGS__), "Call to " #fun__ " failed")

template < typename T >
concept MpiType_c = requires { MpiType< std::remove_cvref_t< T > >::value(); };
concept MpiBuiltinType_c = requires { MpiType< std::remove_cvref_t< T > >::value(); };

template < typename T >
concept MpiBuf_c = std::ranges::contiguous_range< T > and std::ranges::sized_range< T > and
MpiType_c< std::ranges::range_value_t< T > >;
(MpiBuiltinType_c< std::ranges::range_value_t< T > > or
std::is_trivially_copyable_v< std::ranges::range_value_t< T > >);
template < typename T >
concept MpiBorrowedBuf_c = MpiBuf_c< T > and std::ranges::borrowed_range< T >;

Expand All @@ -68,7 +70,7 @@ concept MpiOutputIterator_c =
std::output_iterator< It, std::ranges::range_value_t< Buf > > and
std::same_as< std::iter_value_t< It >, std::ranges::range_value_t< Buf > > and std::contiguous_iterator< It >;

template < MpiType_c T >
template < MpiBuiltinType_c T >
struct MpiBufView
{
MPI_Datatype type;
Expand All @@ -79,9 +81,23 @@ struct MpiBufView
template < MpiBuf_c Buffer >
auto parseMpiBuf(Buffer&& buf)
{
    // Describe a contiguous buffer as an MpiBufView: (MPI datatype, data pointer, element count).
    using element_t = std::remove_reference_t< std::ranges::range_value_t< Buffer > >;
    if constexpr (MpiBuiltinType_c< element_t >)
    {
        // Element type maps directly onto an MPI builtin datatype
        return MpiBufView{MpiType< element_t >::value(),
                          std::ranges::data(buf),
                          static_cast< int >(std::ranges::size(buf))};
    }
    else
    {
        // Trivially copyable, non-builtin element type: view the buffer via its object
        // representation (bytes) and recurse into the builtin (std::byte) branch.
        // Const-ness of the range dictates whether a writable byte view is permitted.
        constexpr bool is_writable =
            not std::is_const_v< std::remove_reference_t< decltype(*std::ranges::begin(buf)) > >;
        if constexpr (is_writable)
            return parseMpiBuf(std::as_writable_bytes(std::span{buf}));
        else
            return parseMpiBuf(std::as_bytes(std::span{buf}));
    }
}
} // namespace comm

Expand All @@ -93,16 +109,14 @@ class MpiComm
public:
friend class MpiComm;

template < comm::MpiType_c T >
template < comm::MpiBuiltinType_c T >
[[nodiscard]] auto numElems() const -> int;

[[nodiscard]] int getSource() const { return m_status.MPI_SOURCE; }
[[nodiscard]] int getTag() const { return m_status.MPI_TAG; }
[[nodiscard]] int getError() const { return m_status.MPI_ERROR; }

private:
auto getHandle() -> MPI_Status* { return &m_status; }

MPI_Status m_status;
};
static_assert(std::is_standard_layout_v< Status >);
Expand All @@ -112,6 +126,7 @@ class MpiComm
public:
friend class MpiComm;

Request() = default;
Request(const Request&) = delete;
Request& operator=(const Request&) = delete;
inline Request(Request&&) noexcept;
Expand All @@ -138,11 +153,8 @@ class MpiComm
requires std::same_as< std::ranges::range_value_t< RequestRange >, Request >;

private:
Request() = default;
int waitImpl() noexcept { return MPI_Wait(&m_request, MPI_STATUS_IGNORE); }

auto getHandle() -> MPI_Request* { return &m_request; }

MPI_Request m_request = MPI_REQUEST_NULL;
};
static_assert(std::is_standard_layout_v< Request >);
Expand Down Expand Up @@ -203,6 +215,8 @@ class MpiComm
void allReduce(Data&& data, It out_it, MPI_Op op) const;
template < comm::MpiBuf_c Data, comm::MpiOutputIterator_c< Data > It >
void gather(Data&& data, It out_it, int root) const;
template < comm::MpiBuf_c Data, comm::MpiOutputIterator_c< Data > It >
void allGather(Data&& data, It out_it) const;
template < comm::MpiBuf_c Data >
void broadcast(Data&& data, int root) const;

Expand Down Expand Up @@ -235,7 +249,7 @@ class MpiComm
MPI_Comm m_comm = MPI_COMM_NULL;
};

template < comm::MpiType_c T >
template < comm::MpiBuiltinType_c T >
auto MpiComm::Status::numElems() const -> int
{
int retval{};
Expand Down Expand Up @@ -336,7 +350,7 @@ auto MpiComm::FileHandle::readAtAsync(Data&& read_range, MPI_Offset offset) cons
{
const auto [datatype, buf_begin, buf_size] = comm::parseMpiBuf(read_range);
MpiComm::Request request;
L3STER_INVOKE_MPI(MPI_File_iread_at, m_file, offset, buf_begin, buf_size, datatype, request.getHandle());
L3STER_INVOKE_MPI(MPI_File_iread_at, m_file, offset, buf_begin, buf_size, datatype, &request.m_request);
return request;
}

Expand All @@ -351,7 +365,7 @@ auto MpiComm::FileHandle::writeAtAsync(Data&& write_range, MPI_Offset offset) co
std::ranges::data(write_range),
util::exactIntegerCast< int >(std::ranges::size(write_range)),
datatype,
request.getHandle());
&request.m_request);
return request;
}

Expand Down Expand Up @@ -406,7 +420,7 @@ auto MpiComm::probeAsync(int source, int tag) const -> std::pair< Status, bool >
{
auto retval = std::pair< Status, bool >{};
int flag{};
L3STER_INVOKE_MPI(MPI_Iprobe, source, tag, m_comm, &flag, retval.first.getHandle());
L3STER_INVOKE_MPI(MPI_Iprobe, source, tag, m_comm, &flag, &retval.first.m_status);
retval.second = flag;
return retval;
}
Expand Down Expand Up @@ -437,8 +451,16 @@ template < comm::MpiBuf_c Data, comm::MpiOutputIterator_c< Data > It >
void MpiComm::gather(Data&& data, It out_it, int root) const
{
const auto [datatype, buf_begin, buf_size] = comm::parseMpiBuf(data);
L3STER_INVOKE_MPI(
MPI_Gather, buf_begin, buf_size, datatype, std::addressof(*out_it), buf_size, datatype, root, m_comm);
const auto out_ptr = std::addressof(*out_it);
L3STER_INVOKE_MPI(MPI_Gather, buf_begin, buf_size, datatype, out_ptr, buf_size, datatype, root, m_comm);
}

template < comm::MpiBuf_c Data, comm::MpiOutputIterator_c< Data > It >
void MpiComm::allGather(Data&& data, It out_it) const
{
    // Gather an equal-sized contribution from every rank into the range at out_it, on all ranks.
    // out_it must have room for (communicator size) * (local buffer size) elements.
    const auto [send_type, send_ptr, send_count] = comm::parseMpiBuf(data);
    L3STER_INVOKE_MPI(
        MPI_Allgather, send_ptr, send_count, send_type, std::addressof(*out_it), send_count, send_type, m_comm);
}

template < comm::MpiBuf_c Data >
Expand All @@ -453,7 +475,7 @@ auto MpiComm::broadcastAsync(Data&& data, int root) const -> MpiComm::Request
{
const auto [datatype, buf_begin, buf_size] = comm::parseMpiBuf(data);
auto request = Request{};
L3STER_INVOKE_MPI(MPI_Ibcast, buf_begin, buf_size, datatype, root, m_comm, request.getHandle());
L3STER_INVOKE_MPI(MPI_Ibcast, buf_begin, buf_size, datatype, root, m_comm, &request.m_request);
return request;
}

Expand All @@ -471,7 +493,7 @@ auto MpiComm::allToAllAsync(SendBuf&& send_buf, RecvBuf&& recv_buf) const -> Mpi
const int n_elems = send_size / getSize();
auto request = Request{};
L3STER_INVOKE_MPI(
MPI_Ialltoall, send_begin, n_elems, send_type, recv_begin, n_elems, recv_type, m_comm, request.getHandle());
MPI_Ialltoall, send_begin, n_elems, send_type, recv_begin, n_elems, recv_type, m_comm, &request.m_request);
return request;
}

Expand Down
55 changes: 12 additions & 43 deletions include/l3ster/dofs/DofIntervals.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define L3STER_DOFS_DOFINTERVALS_HPP

#include "l3ster/dofs/NodeCondensation.hpp"
#include "l3ster/util/Algorithm.hpp"
#include "l3ster/util/BitsetManip.hpp"
#include "l3ster/util/Caliper.hpp"

Expand Down Expand Up @@ -82,12 +83,11 @@ void serializeDofIntervals(const node_interval_vector_t< n_fields >& inter
}

template < size_t n_fields >
void deserializeDofIntervals(const std::ranges::sized_range auto& serial_data, auto out_it)
requires std::same_as< std::ranges::range_value_t< std::decay_t< decltype(serial_data) > >, unsigned long long > and
std::output_iterator< decltype(out_it), node_interval_t< n_fields > >
void deserializeDofIntervals(std::span< const unsigned long long > serial_data,
std::output_iterator< node_interval_t< n_fields > > auto out_it)
{
constexpr auto n_ulls = util::bitsetNUllongs< n_fields >();
for (auto data_it = std::ranges::begin(serial_data); data_it != std::ranges::end(serial_data);)
for (auto data_it = serial_data.begin(); data_it != serial_data.end();)
{
auto delims = std::array< n_id_t, 2 >{};
auto serial_fieldcov = std::array< unsigned long long, n_ulls >{};
Expand All @@ -101,53 +101,22 @@ template < size_t n_fields >
auto gatherGlobalDofIntervals(const MpiComm& comm, const node_interval_vector_t< n_fields >& local_intervals)
-> node_interval_vector_t< n_fields >
{
constexpr size_t serial_interval_size = util::bitsetNUllongs< n_fields >() + 2;
const size_t n_intervals_local = local_intervals.size();
const auto comm_size = comm.getSize();
const auto my_rank = comm.getRank();

const size_t max_n_intervals_global = std::invoke([&] {
size_t retval{};
comm.allReduce(std::views::single(n_intervals_local), &retval, MPI_MAX);
return retval;
});
const size_t max_msg_size = max_n_intervals_global * serial_interval_size + 1u;
auto serial_local_intervals = std::invoke([&] {
auto retval = util::ArrayOwner< unsigned long long >(max_msg_size);
retval.front() = n_intervals_local;
serializeDofIntervals(local_intervals, std::next(retval.begin()));
constexpr size_t serial_interval_size = util::bitsetNUllongs< n_fields >() + 2;
const auto my_rank = comm.getRank();
const auto serialized_local_intervals = std::invoke([&] {
auto retval = util::ArrayOwner< unsigned long long >(serial_interval_size * local_intervals.size());
serializeDofIntervals(local_intervals, retval.begin());
return retval;
});

node_interval_vector_t< n_fields > retval;
retval.reserve(comm_size * max_n_intervals_global);
auto proc_buf = util::ArrayOwner< unsigned long long >(max_msg_size);
const auto process_data = [&](int sender_rank) {
retval.reserve(comm.getSize() * local_intervals.size());
const auto process_received = [&](std::span< const unsigned long long > received_intervals, int sender_rank) {
if (sender_rank != my_rank)
{
const auto received_intervals =
proc_buf | std::views::drop(1) | std::views::take(proc_buf.front() * serial_interval_size);
deserializeDofIntervals< n_fields >(received_intervals, std::back_inserter(retval));
}
else
std::ranges::copy(local_intervals, std::back_inserter(retval));
};

auto msg_buf =
my_rank == 0 ? std::move(serial_local_intervals) : util::ArrayOwner< unsigned long long >(max_msg_size);
auto request = comm.broadcastAsync(msg_buf, 0);
for (int root_rank = 1; root_rank < comm_size; ++root_rank)
{
request.wait();
std::swap(msg_buf, proc_buf);
if (my_rank == root_rank)
msg_buf = std::move(serial_local_intervals);
request = comm.broadcastAsync(msg_buf, root_rank);
process_data(root_rank - 1);
}
request.wait();
std::swap(msg_buf, proc_buf);
process_data(comm_size - 1);
util::staggeredAllGather(comm, std::span{serialized_local_intervals}, process_received);
return retval;
}

Expand Down
Loading

0 comments on commit cc95825

Please sign in to comment.