diff --git a/.clang-format b/.clang-format
index e941415f..ce920879 100644
--- a/.clang-format
+++ b/.clang-format
@@ -1,93 +1,135 @@
+# Copyright (c) 2016 Thomas Heller
+# Copyright (c) 2016-2018 Hartmut Kaiser
+#
+# SPDX-License-Identifier: BSL-1.0
+# Distributed under the Boost Software License, Version 1.0. (See accompanying
+# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+
+# This .clang-format file is a suggested configuration file for formatting
+# source code for the pika project.
+#
+# Here are a few guidelines on how to use this file.
+#
+# - You should use this file to create the initial formatting for new files.
+#
+# - Please separate edits which are pure formatting into isolated commits,
+# keeping them distinct from edits that change any of the code.
+#
+# - Please do _not_ configure your editor to automatically format the source
+# file while saving edits to disk.
+# - Please do _not_ reformat a full source file without dire need.
+
+# PLEASE NOTE: This file requires clang-format V18.0
+
---
-Language: Cpp
-# BasedOnStyle: LLVM
-#AccessModifierOffset: -4
-AccessModifierOffset: -2
+AccessModifierOffset: -4
AlignAfterOpenBracket: DontAlign
AlignConsecutiveAssignments: false
-AlignConsecutiveDeclarations: true
-AlignConsecutiveMacros: true
-#AlignConsecutiveAssignments: false
-#AlignConsecutiveDeclarations: false
-#AlignConsecutiveMacros: false
AlignEscapedNewlines: Right
-AlignOperands: true
+AlignOperands: false
AlignTrailingComments: true
-AllowAllArgumentsOnNextLine: false
AllowAllParametersOfDeclarationOnNextLine: false
-BinPackArguments: true
-BinPackParameters: true
-BreakBeforeBraces: Allman
-#BreakBeforeBraces: Attach
-#BreakBeforeBraces: Custom
-#BraceWrapping:
-# AfterCaseLabel: true
-# AfterClass: true
-# BreakBeforeBinaryOperators: All
-# ConstructorInitializerAllOnOneLineOrOnePerLine: false
-BreakConstructorInitializers: BeforeComma
-ConstructorInitializerIndentWidth: 0
-BreakInheritanceList: BeforeComma
-#AllowShortBlocksOnASingleLine: Always
-AllowShortBlocksOnASingleLine: true
-AllowShortCaseLabelsOnASingleLine: false
+AllowShortBlocksOnASingleLine: Always
+AllowShortCaseLabelsOnASingleLine: true
+AllowShortEnumsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
-AllowShortIfStatementsOnASingleLine: Always
+AllowShortIfStatementsOnASingleLine: WithoutElse
AllowShortLambdasOnASingleLine: All
AllowShortLoopsOnASingleLine: true
-#AlwaysBreakAfterReturnType: None
-AlwaysBreakAfterReturnType: TopLevelDefinitions
-# PenaltyReturnTypeOnItsOwnLine: 1
+AlwaysBreakAfterDefinitionReturnType: None
+AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
-AlwaysBreakTemplateDeclarations: Yes
-BreakBeforeTernaryOperators: true
-BreakInheritanceList: BeforeComma
-BreakStringLiterals: false
+AlwaysBreakTemplateDeclarations: true
+BinPackArguments: true
+BinPackParameters: true
+BraceWrapping:
+ AfterCaseLabel: true
+ AfterClass: true
+ AfterControlStatement: true
+ AfterEnum: true
+ AfterFunction: true
+ AfterNamespace: false
+ AfterStruct: true
+ AfterUnion: true
+ BeforeCatch: true
+ BeforeElse: true
+ IndentBraces: false
+BreakAfterAttributes: Leave
+BreakBeforeBinaryOperators: None
+BreakBeforeBraces: Custom
+BreakBeforeInheritanceComma: true
+BreakBeforeInlineASMColon: OnlyMultiline
+BreakBeforeTernaryOperators: false
+BreakConstructorInitializersBeforeComma: true
+BreakStringLiterals: true
ColumnLimit: 100
-#ColumnLimit: 120
-# CommentPragmas
-CompactNamespaces: false
+CommentPragmas: "///" +CompactNamespaces: true +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 2 ContinuationIndentWidth: 4 Cpp11BracedListStyle: true -#--DeriveLineEnding: true -#--DerivePointerAlignment: true +DerivePointerAlignment: false +#ExperimentalAutoDetectBinPacking: true # Do weird reformatting FixNamespaceComments: true -# ForEachMacros: [for_each] -# IncludeBlocks: Regroup -IndentCaseLabels: true -#IndentGotoLabels: false -#IndentPPDirectives: BeforeHash -IndentPPDirectives: None +# ForEachMacros: [''] +IncludeCategories: + - Regex: '^' + Priority: 1 + - Regex: '^' + Priority: 2 + - Regex: '^' + Priority: 3 + - Regex: '^' + Priority: 4 + - Regex: '^' + Priority: 5 + - Regex: '^<.*' + Priority: 6 + - Regex: '.*' + Priority: 7 +# IncludeIsMainRegex: '' +IndentCaseLabels: false IndentWidth: 4 IndentWrappedFunctionNames: false +IndentPPDirectives: AfterHash +InsertBraces: false +IntegerLiteralSeparator: + Binary: 4 + Decimal: 0 + Hex: 4 KeepEmptyLinesAtTheStartOfBlocks: false +Language: Cpp +# MacroBlockBegin: '' +# MacroBlockEnd: '' MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None +NamespaceIndentation: All +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 20 PointerAlignment: Left +PPIndentWidth: 1 ReflowComments: false -SortIncludes: false -SortUsingDeclarations: false -SpaceAfterCStyleCast: false -SpaceAfterLogicalNot: false -SpaceAfterTemplateKeyword: false +QualifierAlignment: Right +RequiresClausePosition: OwnLine +RequiresExpressionIndentation: OuterScope +SortIncludes: true +SpaceAfterCStyleCast: true +SpaceAfterTemplateKeyword: true SpaceBeforeAssignmentOperators: true -SpaceBeforeCpp11BracedList: false -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -#SpaceBeforeSquareBrackets: false SpaceInEmptyParentheses: false -SpacesBeforeTrailingComments: 1 +SpacesBeforeTrailingComments: 4 SpacesInAngles: false SpacesInCStyleCastParentheses: false -# SpacesInConditionalStatement: false SpacesInContainerLiterals: false SpacesInParentheses: false SpacesInSquareBrackets: false -#Standard: Latest +Standard: Cpp11 TabWidth: 4 UseTab: Never - - - +... diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml new file mode 100644 index 00000000..52ea587e --- /dev/null +++ b/.github/workflows/clang-format.yml @@ -0,0 +1,20 @@ +name: CI + +on: + push: + pull_request: + branches: + - main + +jobs: + clang-format-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: DoozyX/clang-format-lint-action@v0.20 + with: + source: "." + exclude: "./ext" + extensions: "hpp,cpp,hpp.in,cu" + clangFormatVersion: 18 + style: file diff --git a/benchmarks/accumulator.hpp b/benchmarks/accumulator.hpp index 3c111fd2..97bb0ead 100644 --- a/benchmarks/accumulator.hpp +++ b/benchmarks/accumulator.hpp @@ -9,156 +9,154 @@ */ #pragma once -#include #include #include #include +#include #include #include -namespace oomph -{ -/** @brief accumulates samples and computes basic statistics on-line in a numerically stable fashion. */ -class accumulator -{ - public: // member types - using size_type = std::size_t; - using value_type = double; +namespace oomph { + /** @brief accumulates samples and computes basic statistics on-line in a numerically stable fashion. 
*/
+ class accumulator
+ {
+ public: // member types
+ using size_type = std::size_t;
+ using value_type = double;
- private: // members
- size_type m_num_samples = 0u;
- value_type m_min = std::numeric_limits<value_type>::max();
- value_type m_max = std::numeric_limits<value_type>::min();
- value_type m_mean = 0;
- value_type m_variance = 0;
+ private: // members
+ size_type m_num_samples = 0u;
+ value_type m_min = std::numeric_limits<value_type>::max();
+ value_type m_max = std::numeric_limits<value_type>::min();
+ value_type m_mean = 0;
+ value_type m_variance = 0;
- public: // ctors
- accumulator() noexcept = default;
- accumulator(const accumulator&) noexcept = default;
- accumulator(accumulator&&) noexcept = default;
- accumulator(size_type num_samples_, value_type min_, value_type max_, value_type mean_,
- value_type variance_) noexcept
- : m_num_samples(num_samples_)
- , m_min(min_)
- , m_max(max_)
- , m_mean(mean_)
- , m_variance(variance_)
- {
- }
- accumulator& operator=(const accumulator&) noexcept = default;
- accumulator& operator=(accumulator&&) noexcept = default;
+ public: // ctors
+ accumulator() noexcept = default;
+ accumulator(accumulator const&) noexcept = default;
+ accumulator(accumulator&&) noexcept = default;
+ accumulator(size_type num_samples_, value_type min_, value_type max_, value_type mean_,
+ value_type variance_) noexcept
+ : m_num_samples(num_samples_)
+ , m_min(min_)
+ , m_max(max_)
+ , m_mean(mean_)
+ , m_variance(variance_)
+ {
+ }
+ accumulator& operator=(accumulator const&) noexcept = default;
+ accumulator& operator=(accumulator&&) noexcept = default;
- public: // return statistics
- inline size_type num_samples() const noexcept { return m_num_samples; }
- inline value_type min() const noexcept { return m_min; }
- inline value_type max() const noexcept { return m_max; }
- inline value_type mean() const noexcept { return m_mean; }
- inline value_type variance() const noexcept
- {
- return ((m_num_samples > 1) ? (m_variance / (m_num_samples - 1)) : 0);
- }
- inline value_type stddev() const noexcept
- {
- return (m_num_samples > 1 ? std::sqrt(variance()) : 0);
- }
- inline value_type sum() const noexcept { return m_num_samples * m_mean; }
+ public: // return statistics
+ inline size_type num_samples() const noexcept { return m_num_samples; }
+ inline value_type min() const noexcept { return m_min; }
+ inline value_type max() const noexcept { return m_max; }
+ inline value_type mean() const noexcept { return m_mean; }
+ inline value_type variance() const noexcept
+ {
+ return ((m_num_samples > 1) ? (m_variance / (m_num_samples - 1)) : 0);
+ }
+ inline value_type stddev() const noexcept
+ {
+ return (m_num_samples > 1 ? std::sqrt(variance()) : 0);
+ }
+ inline value_type sum() const noexcept { return m_num_samples * m_mean; }
- public: // add samples
- /** @brief accumulate samples
+ public: // add samples
+ /** @brief accumulate samples
* @tparam InputIterator Iterator type over a sample range
* @param first iterator pointing to the first sample
* @param last iterator pointing to one-after-last sample
* @return reference to this object */
- template<typename InputIterator>
- inline accumulator& operator()(InputIterator first, InputIterator last) noexcept
- {
- for (auto sample_ptr = first; sample_ptr != last; ++sample_ptr) (*this)(*sample_ptr);
- return *this;
- }
+ template<typename InputIterator>
+ inline accumulator& operator()(InputIterator first, InputIterator last) noexcept
+ {
+ for (auto sample_ptr = first; sample_ptr != last; ++sample_ptr) (*this)(*sample_ptr);
+ return *this;
+ }
- /** @brief accumulate one sample
+ /** @brief accumulate one sample
* @param sample a sample
* @return reference to this object */
- inline accumulator& operator()(value_type sample) noexcept
- {
- m_min = std::min(m_min, sample);
- m_max = std::max(m_max, sample);
- const value_type delta = sample - m_mean;
- m_mean += delta / (++m_num_samples);
- m_variance += delta * (sample - m_mean);
- return *this;
- }
+ inline accumulator& operator()(value_type sample) noexcept
+ {
+ m_min = std::min(m_min, sample);
+ m_max = std::max(m_max, sample);
+ value_type const delta = sample - m_mean;
+ m_mean += delta / (++m_num_samples);
+ m_variance += delta * (sample - m_mean);
+ return *this;
+ }
- /** @brief accumulate another accumulator
+ /** @brief accumulate another accumulator
* @param other another accumulator object
* @return reference to this object */
- inline accumulator& operator()(const accumulator& other) noexcept
- {
- if (other.m_num_samples == 0) return *this;
- if (m_num_samples == 0)
+ inline accumulator& operator()(accumulator const& other) noexcept
{
- m_num_samples = other.m_num_samples;
- m_min = other.m_min;
- m_max = other.m_max;
- m_mean = other.m_mean;
- m_variance = other.m_variance;
+ if (other.m_num_samples == 0) return *this;
+ if (m_num_samples == 0)
+ {
+ m_num_samples = other.m_num_samples;
+ m_min = other.m_min;
+ m_max = other.m_max;
+ m_mean = other.m_mean;
+ m_variance = other.m_variance;
+ return *this;
+ }
+ m_min = std::min(m_min, other.min());
+ m_max = std::max(m_max, other.max());
+ auto const delta = other.m_mean - m_mean;
+ auto const num_samples_new = m_num_samples + other.m_num_samples;
+ m_mean += (delta * other.m_num_samples) / num_samples_new;
+ m_variance += other.m_variance +
+ (delta * delta * m_num_samples * other.m_num_samples) / num_samples_new;
+ m_num_samples = num_samples_new;
return *this;
}
- m_min = std::min(m_min, other.min());
- m_max = std::max(m_max, other.max());
- const auto delta = other.m_mean - m_mean;
- const auto num_samples_new = m_num_samples + other.m_num_samples;
- m_mean += (delta * other.m_num_samples) / num_samples_new;
- m_variance += other.m_variance +
- (delta * delta * m_num_samples * other.m_num_samples) / num_samples_new;
- m_num_samples = num_samples_new;
- return *this;
- }
- public:
- /** @brief reset accumulator */
- inline void clear() noexcept
- {
- m_num_samples = 0;
- m_min = std::numeric_limits<value_type>::max();
- m_max = std::numeric_limits<value_type>::min();
- m_mean = 0;
- m_variance = 0;
- }
+ public:
+ /** @brief reset accumulator */
+ inline void clear() noexcept
+ {
+ m_num_samples = 0;
+ m_min = std::numeric_limits<value_type>::max();
+ m_max = std::numeric_limits<value_type>::min();
+ m_mean = 0;
+ m_variance = 0;
+ }
- public: - /** @brief print info to output stream */ - template> - friend std::basic_ostream& operator<<( - std::basic_ostream& os, const accumulator& acc) - { - os << "[" << acc.min() << "," << acc.mean() << "," << acc.max() << "] (" << acc.stddev() - << "," << acc.num_samples() << ")"; - return os; - } -}; + public: + /** @brief print info to output stream */ + template > + friend std::basic_ostream& + operator<<(std::basic_ostream& os, accumulator const& acc) + { + os << "[" << acc.min() << "," << acc.mean() << "," << acc.max() << "] (" << acc.stddev() + << "," << acc.num_samples() << ")"; + return os; + } + }; -/** @brief all-reduce accumulators over the MPI group defined by the communicator + /** @brief all-reduce accumulators over the MPI group defined by the communicator * @param acc accumulator local to each rank * @param comm MPI communicator * @return combined allocator incorporating all samples */ -accumulator -reduce(const accumulator& acc, MPI_Comm comm) -{ - int rank, size; - MPI_Comm_rank(comm, &rank); - MPI_Comm_size(comm, &size); - std::vector accs; - if (rank == 0) { accs.resize(size); } - MPI_Gather(reinterpret_cast(&acc), sizeof(accumulator), MPI_BYTE, - reinterpret_cast(accs.data()), sizeof(accumulator), MPI_BYTE, 0, comm); - accumulator acc_all; - if (rank == 0) + accumulator reduce(accumulator const& acc, MPI_Comm comm) { - for (const auto x : accs) { acc_all(x); } + int rank, size; + MPI_Comm_rank(comm, &rank); + MPI_Comm_size(comm, &size); + std::vector accs; + if (rank == 0) { accs.resize(size); } + MPI_Gather(reinterpret_cast(&acc), sizeof(accumulator), MPI_BYTE, + reinterpret_cast(accs.data()), sizeof(accumulator), MPI_BYTE, 0, comm); + accumulator acc_all; + if (rank == 0) + { + for (auto const x : accs) { acc_all(x); } + } + MPI_Bcast(reinterpret_cast(&acc_all), sizeof(accumulator), MPI_BYTE, 0, comm); + return acc_all; } - MPI_Bcast(reinterpret_cast(&acc_all), sizeof(accumulator), MPI_BYTE, 0, comm); - return acc_all; -} -} // namespace oomph +} // namespace oomph diff --git a/benchmarks/args.hpp b/benchmarks/args.hpp index 2d9145d3..4848e226 100644 --- a/benchmarks/args.hpp +++ b/benchmarks/args.hpp @@ -14,55 +14,50 @@ #include #ifndef OOMPH_BENCHMARKS_PURE_MPI -#include +# include #endif #ifdef OOMPH_BENCHMARKS_MT -#include +# include #endif -namespace oomph -{ -struct args -{ - bool is_valid = true; - int n_iter = 0; - int n_secs = 5; - int buff_size = 0; - int inflight = 0; - int num_threads = 1; - - args(int argc, char** argv, bool timed = false) +namespace oomph { + struct args { - if (argc != 4) + bool is_valid = true; + int n_iter = 0; + int n_secs = 5; + int buff_size = 0; + int inflight = 0; + int num_threads = 1; + + args(int argc, char** argv, bool timed = false) { - is_valid = false; + if (argc != 4) + { + is_valid = false; #ifndef OOMPH_BENCHMARKS_PURE_MPI - if (argc == 2 && !std::strcmp(argv[1], "-c")) print_config(); + if (argc == 2 && !std::strcmp(argv[1], "-c")) print_config(); #endif - } - else - { - if (timed) { - n_secs = std::atoi(argv[1]); - } - else { - n_iter = std::atoi(argv[1]); } - buff_size = std::atoi(argv[2]); - inflight = std::atoi(argv[3]); + else + { + if (timed) { n_secs = std::atoi(argv[1]); } + else { n_iter = std::atoi(argv[1]); } + buff_size = std::atoi(argv[2]); + inflight = std::atoi(argv[3]); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp parallel - { -#pragma omp master - num_threads = omp_get_num_threads(); - } +# pragma omp parallel + { +# pragma omp master + num_threads = omp_get_num_threads(); + } #endif + } } - } - 
operator bool() const noexcept { return is_valid; } -}; + operator bool() const noexcept { return is_valid; } + }; -} // namespace oomph +} // namespace oomph diff --git a/benchmarks/bench_p2p_bi_cb_avail_mt.cpp b/benchmarks/bench_p2p_bi_cb_avail_mt.cpp index 5cd3c2c9..94a41b3f 100644 --- a/benchmarks/bench_p2p_bi_cb_avail_mt.cpp +++ b/benchmarks/bench_p2p_bi_cb_avail_mt.cpp @@ -7,19 +7,18 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -#include #include -#include "./mpi_environment.hpp" +#include +#include #include "./args.hpp" +#include "./mpi_environment.hpp" #include "./timer.hpp" #include "./utils.hpp" -#include -const char *syncmode = "callback"; -const char *waitmode = "avail"; +char const* syncmode = "callback"; +char const* waitmode = "avail"; -int -main(int argc, char** argv) +int main(int argc, char** argv) { using namespace oomph; using message = oomph::message_buffer; @@ -33,13 +32,13 @@ main(int argc, char** argv) context ctxt(MPI_COMM_WORLD, multi_threaded); barrier b(ctxt, cmd_args.num_threads); - timer t0; - timer t1; + timer t0; + timer t1; - const auto inflight = cmd_args.inflight; - const auto num_threads = cmd_args.num_threads; - const auto buff_size = cmd_args.buff_size; - const auto niter = cmd_args.n_iter; + auto const inflight = cmd_args.inflight; + auto const num_threads = cmd_args.num_threads; + auto const buff_size = cmd_args.buff_size; + auto const niter = cmd_args.n_iter; if (env.rank == 0) { @@ -61,32 +60,30 @@ main(int argc, char** argv) #endif #ifdef OOMPH_BENCHMARKS_MT -#pragma omp parallel +# pragma omp parallel #endif { - auto comm = ctxt.get_communicator(); - const auto rank = comm.rank(); - const auto size = comm.size(); - const auto thread_id = THREADID; - const auto peer_rank = (rank + 1) % size; - - int comm_cnt = 0, nlsend_cnt = 0, nlrecv_cnt = 0, submit_cnt = 0, submit_recv_cnt = 0; - int last_received = 0; - int last_sent = 0; - int dbg = 0, sdbg = 0, rdbg = 0; - int lsent = 0, lrecv = 0; - const int delta_i = niter / 10; - - auto send_callback = [inflight, &nlsend_cnt, &comm_cnt, &sent]( - message&, int, int tag) { + auto comm = ctxt.get_communicator(); + auto const rank = comm.rank(); + auto const size = comm.size(); + auto const thread_id = THREADID; + auto const peer_rank = (rank + 1) % size; + + int comm_cnt = 0, nlsend_cnt = 0, nlrecv_cnt = 0, submit_cnt = 0, submit_recv_cnt = 0; + int last_received = 0; + int last_sent = 0; + int dbg = 0, sdbg = 0, rdbg = 0; + int lsent = 0, lrecv = 0; + int const delta_i = niter / 10; + + auto send_callback = [inflight, &nlsend_cnt, &comm_cnt, &sent](message&, int, int tag) { int pthr = tag / inflight; if (pthr != THREADID) nlsend_cnt++; comm_cnt++; sent++; }; - auto recv_callback = [inflight, &nlrecv_cnt, &comm_cnt, &received]( - message&, int, int tag) { + auto recv_callback = [inflight, &nlrecv_cnt, &comm_cnt, &received](message&, int, int tag) { int pthr = tag / inflight; if (pthr != THREADID) nlrecv_cnt++; comm_cnt++; @@ -94,10 +91,12 @@ main(int argc, char** argv) }; if (thread_id == 0 && rank == 0) - { std::cout << "\n\nrunning test " << __FILE__ << "\n\n"; }; + { + std::cout << "\n\nrunning test " << __FILE__ << "\n\n"; + }; - std::vector smsgs(inflight); - std::vector rmsgs(inflight); + std::vector smsgs(inflight); + std::vector rmsgs(inflight); std::vector sreqs(inflight); std::vector rreqs(inflight); for (int j = 0; j < inflight; j++) @@ -125,8 +124,8 @@ main(int argc, char** argv) dbg = 0; std::cout << rank << " total bwdt MB/s: " << 
((received - last_received + sent - last_sent) * size * - (double)buff_size / 2) / - t0.stoc() + (double) buff_size / 2) / + t0.stoc() << "\n"; t0.tic(); last_received = received; @@ -156,8 +155,8 @@ main(int argc, char** argv) comm.recv(rmsgs[j], peer_rank, thread_id * inflight + j, recv_callback); lrecv++; } - else comm.progress(); - + else + comm.progress(); // if(lsent < lrecv+2*inflight && sent < niter && smsgs[j].use_count() == 1) if (lsent < lrecv + 2 * inflight && sent < niter && (sreqs[j].test())) @@ -169,7 +168,8 @@ main(int argc, char** argv) comm.send(smsgs[j], peer_rank, thread_id * inflight + j, send_callback); lsent++; } - else comm.progress(); + else + comm.progress(); } } @@ -177,8 +177,8 @@ main(int argc, char** argv) if (thread_id == 0 && rank == 0) { - const auto t = t1.stoc(); - double bw = ((double)niter*size*buff_size)/t; + auto const t = t1.stoc(); + double bw = ((double) niter * size * buff_size) / t; // clang-format off std::cout << "time: " << t / 1000000 << "s\n"; std::cout << "final MB/s: " << bw << "\n"; @@ -200,7 +200,7 @@ main(int argc, char** argv) b(); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp critical +# pragma omp critical #endif { std::cout << "rank " << rank << " thread " << thread_id << " sends submitted " @@ -218,8 +218,7 @@ main(int argc, char** argv) int send_complete = 0; // complete all posted sends - do - { + do { comm.progress(); // check if we have completed all our posted sends if (!send_complete) @@ -251,17 +250,17 @@ main(int argc, char** argv) // Notify the peer and keep submitting recvs until we get his notification. send_request sf; recv_request rf; - auto smsg = comm.make_buffer(1); - auto rmsg = comm.make_buffer(1); + auto smsg = comm.make_buffer(1); + auto rmsg = comm.make_buffer(1); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp master +# pragma omp master #endif { - sf = comm.send( - smsg, peer_rank, 0x80000); //, [](communicator_type::message_type, int, int){}); - rf = comm.recv( - rmsg, peer_rank, 0x80000); //, [](communicator_type::message_type, int, int){}); + sf = comm.send(smsg, peer_rank, + 0x8'0000); //, [](communicator_type::message_type, int, int){}); + rf = comm.recv(rmsg, peer_rank, + 0x8'0000); //, [](communicator_type::message_type, int, int){}); } while (tail_recv == 0) @@ -279,7 +278,7 @@ main(int argc, char** argv) } #ifdef OOMPH_BENCHMARKS_MT -#pragma omp master +# pragma omp master #endif { if (rf.test()) tail_recv = 1; diff --git a/benchmarks/bench_p2p_bi_cb_wait_mt.cpp b/benchmarks/bench_p2p_bi_cb_wait_mt.cpp index 24a9d88c..904d5ac1 100644 --- a/benchmarks/bench_p2p_bi_cb_wait_mt.cpp +++ b/benchmarks/bench_p2p_bi_cb_wait_mt.cpp @@ -7,19 +7,18 @@ * Please, refer to the LICENSE file in the root directory. 
* SPDX-License-Identifier: BSD-3-Clause */ -#include #include -#include "./mpi_environment.hpp" +#include +#include #include "./args.hpp" +#include "./mpi_environment.hpp" #include "./timer.hpp" #include "./utils.hpp" -#include -const char *syncmode = "callback"; -const char *waitmode = "wait"; +char const* syncmode = "callback"; +char const* waitmode = "wait"; -int -main(int argc, char** argv) +int main(int argc, char** argv) { using namespace oomph; using message = oomph::message_buffer; @@ -33,13 +32,13 @@ main(int argc, char** argv) context ctxt(MPI_COMM_WORLD, multi_threaded); barrier b(ctxt, cmd_args.num_threads); - timer t0; - timer t1; + timer t0; + timer t1; - const auto inflight = cmd_args.inflight; - const auto num_threads = cmd_args.num_threads; - const auto buff_size = cmd_args.buff_size; - const auto niter = cmd_args.n_iter; + auto const inflight = cmd_args.inflight; + auto const num_threads = cmd_args.num_threads; + auto const buff_size = cmd_args.buff_size; + auto const niter = cmd_args.n_iter; if (env.rank == 0) { @@ -57,30 +56,28 @@ main(int argc, char** argv) #endif #ifdef OOMPH_BENCHMARKS_MT -#pragma omp parallel +# pragma omp parallel #endif { - auto comm = ctxt.get_communicator(); - const auto rank = comm.rank(); - const auto size = comm.size(); - const auto thread_id = THREADID; - const auto peer_rank = (rank + 1) % size; - - int comm_cnt = 0, nlsend_cnt = 0, nlrecv_cnt = 0; - int i = 0, dbg = 0; - int last_i = 0; - const int delta_i = niter / 10; - - auto send_callback = [inflight, &nlsend_cnt, &comm_cnt, &sent]( - message&, int, int tag) { + auto comm = ctxt.get_communicator(); + auto const rank = comm.rank(); + auto const size = comm.size(); + auto const thread_id = THREADID; + auto const peer_rank = (rank + 1) % size; + + int comm_cnt = 0, nlsend_cnt = 0, nlrecv_cnt = 0; + int i = 0, dbg = 0; + int last_i = 0; + int const delta_i = niter / 10; + + auto send_callback = [inflight, &nlsend_cnt, &comm_cnt, &sent](message&, int, int tag) { int pthr = tag / inflight; if (pthr != THREADID) nlsend_cnt++; comm_cnt++; sent++; }; - auto recv_callback = [inflight, &nlrecv_cnt, &comm_cnt, &received]( - message&, int, int tag) { + auto recv_callback = [inflight, &nlrecv_cnt, &comm_cnt, &received](message&, int, int tag) { int pthr = tag / inflight; if (pthr != THREADID) nlrecv_cnt++; comm_cnt++; @@ -88,10 +85,12 @@ main(int argc, char** argv) }; if (thread_id == 0 && rank == 0) - { std::cout << "\n\nrunning test " << __FILE__ << "\n\n"; }; + { + std::cout << "\n\nrunning test " << __FILE__ << "\n\n"; + }; - std::vector smsgs(inflight); - std::vector rmsgs(inflight); + std::vector smsgs(inflight); + std::vector rmsgs(inflight); std::vector sreqs(inflight); std::vector rreqs(inflight); for (int j = 0; j < inflight; j++) @@ -117,14 +116,14 @@ main(int argc, char** argv) // ghex barrier not needed here (all comm finished), and VERY SLOW // barrier.in_node(comm); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp barrier +# pragma omp barrier #endif if (thread_id == 0 && dbg >= delta_i) { dbg = 0; std::cout << rank << " total bwdt MB/s: " - << ((i - last_i) * size * (double)buff_size) / t0.stoc() << "\n"; + << ((i - last_i) * size * (double) buff_size) / t0.stoc() << "\n"; t0.tic(); last_i = i; } @@ -147,7 +146,7 @@ main(int argc, char** argv) // ghex barrier not needed here (all comm finished), and VERY SLOW // barrier.in_node(comm); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp barrier +# pragma omp barrier #endif sent = 0; received = 0; @@ -157,8 +156,8 @@ main(int argc, char** argv) if 
(thread_id == 0 && rank == 0) { - const auto t = t1.stoc(); - double bw = ((double)niter*size*buff_size)/t; + auto const t = t1.stoc(); + double bw = ((double) niter * size * buff_size) / t; // clang-format off std::cout << "time: " << t / 1000000 << "s\n"; std::cout << "final MB/s: " << bw << "\n"; @@ -180,7 +179,7 @@ main(int argc, char** argv) b(); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp critical +# pragma omp critical #endif { std::cout << "rank " << rank << " thread " << thread_id << " serviced " << comm_cnt diff --git a/benchmarks/bench_p2p_bi_ft_avail_mt.cpp b/benchmarks/bench_p2p_bi_ft_avail_mt.cpp index aae0e1bf..2d82c9c7 100644 --- a/benchmarks/bench_p2p_bi_ft_avail_mt.cpp +++ b/benchmarks/bench_p2p_bi_ft_avail_mt.cpp @@ -7,19 +7,18 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -#include #include -#include "./mpi_environment.hpp" +#include +#include #include "./args.hpp" +#include "./mpi_environment.hpp" #include "./timer.hpp" #include "./utils.hpp" -#include -const char *syncmode = "future"; -const char *waitmode = "avail"; +char const* syncmode = "future"; +char const* waitmode = "avail"; -int -main(int argc, char** argv) +int main(int argc, char** argv) { using namespace oomph; using message = oomph::message_buffer; @@ -33,13 +32,13 @@ main(int argc, char** argv) context ctxt(MPI_COMM_WORLD, multi_threaded); barrier b(ctxt, cmd_args.num_threads); - timer t0; - timer t1; + timer t0; + timer t1; - const auto inflight = cmd_args.inflight; - const auto num_threads = cmd_args.num_threads; - const auto buff_size = cmd_args.buff_size; - const auto niter = cmd_args.n_iter; + auto const inflight = cmd_args.inflight; + auto const num_threads = cmd_args.num_threads; + auto const buff_size = cmd_args.buff_size; + auto const niter = cmd_args.n_iter; if (env.rank == 0) { @@ -61,26 +60,28 @@ main(int argc, char** argv) #endif #ifdef OOMPH_BENCHMARKS_MT -#pragma omp parallel +# pragma omp parallel #endif { - auto comm = ctxt.get_communicator(); - const auto rank = comm.rank(); - const auto size = comm.size(); - const auto thread_id = THREADID; - const auto peer_rank = (rank + 1) % size; + auto comm = ctxt.get_communicator(); + auto const rank = comm.rank(); + auto const size = comm.size(); + auto const thread_id = THREADID; + auto const peer_rank = (rank + 1) % size; - int dbg = 0, sdbg = 0, rdbg = 0; - int last_received = 0; - int last_sent = 0; - int lsent = 0, lrecv = 0; - const int delta_i = niter / 10; + int dbg = 0, sdbg = 0, rdbg = 0; + int last_received = 0; + int last_sent = 0; + int lsent = 0, lrecv = 0; + int const delta_i = niter / 10; if (thread_id == 0 && rank == 0) - { std::cout << "\n\nrunning test " << __FILE__ << "\n\n"; }; + { + std::cout << "\n\nrunning test " << __FILE__ << "\n\n"; + }; - std::vector smsgs(inflight); - std::vector rmsgs(inflight); + std::vector smsgs(inflight); + std::vector rmsgs(inflight); std::vector sreqs(inflight); std::vector rreqs(inflight); for (int j = 0; j < inflight; j++) @@ -128,8 +129,8 @@ main(int argc, char** argv) dbg = 0; std::cout << rank << " total bwdt MB/s: " << ((received - last_received + sent - last_sent) * size * - (double)buff_size / 2) / - t0.stoc() + (double) buff_size / 2) / + t0.stoc() << "\n"; t0.tic(); last_received = received; @@ -144,7 +145,8 @@ main(int argc, char** argv) dbg += num_threads; rreqs[j] = comm.recv(rmsgs[j], peer_rank, thread_id * inflight + j); } - else comm.progress(); + else + comm.progress(); if (lsent < lrecv + 2 * inflight && sent < niter && 
(sreqs[j].test())) { @@ -154,7 +156,8 @@ main(int argc, char** argv) dbg += num_threads; sreqs[j] = comm.send(smsgs[j], peer_rank, thread_id * inflight + j); } - else comm.progress(); + else + comm.progress(); } } @@ -162,8 +165,8 @@ main(int argc, char** argv) if (thread_id == 0 && rank == 0) { - const auto t = t1.stoc(); - double bw = ((double)niter*size*buff_size)/t; + auto const t = t1.stoc(); + double bw = ((double) niter * size * buff_size) / t; // clang-format off std::cout << "time: " << t / 1000000 << "s\n"; std::cout << "final MB/s: " << bw << "\n"; @@ -193,8 +196,7 @@ main(int argc, char** argv) int send_complete = 0; // complete all posted sends - do - { + do { comm.progress(); // check if we have completed all our posted sends if (!send_complete) @@ -215,7 +217,9 @@ main(int argc, char** argv) for (int j = 0; j < inflight; j++) { if (rreqs[j].test()) - { rreqs[j] = comm.recv(rmsgs[j], peer_rank, thread_id * inflight + j); } + { + rreqs[j] = comm.recv(rmsgs[j], peer_rank, thread_id * inflight + j); + } } } while (tail_send != num_threads); @@ -223,15 +227,15 @@ main(int argc, char** argv) // Notify the peer and keep submitting recvs until we get his notification. send_request sf; recv_request rf; - auto smsg = comm.make_buffer(1); - auto rmsg = comm.make_buffer(1); + auto smsg = comm.make_buffer(1); + auto rmsg = comm.make_buffer(1); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp master +# pragma omp master #endif { - sf = comm.send(smsg, peer_rank, 0x80000); - rf = comm.recv(rmsg, peer_rank, 0x80000); + sf = comm.send(smsg, peer_rank, 0x8'0000); + rf = comm.recv(rmsg, peer_rank, 0x8'0000); } while (tail_recv == 0) @@ -242,11 +246,13 @@ main(int argc, char** argv) for (int j = 0; j < inflight; j++) { if (rreqs[j].test()) - { rreqs[j] = comm.recv(rmsgs[j], peer_rank, thread_id * inflight + j); } + { + rreqs[j] = comm.recv(rmsgs[j], peer_rank, thread_id * inflight + j); + } } #ifdef OOMPH_BENCHMARKS_MT -#pragma omp master +# pragma omp master #endif { if (rf.test()) tail_recv = 1; diff --git a/benchmarks/bench_p2p_bi_ft_wait_mt.cpp b/benchmarks/bench_p2p_bi_ft_wait_mt.cpp index c8485906..79d74271 100644 --- a/benchmarks/bench_p2p_bi_ft_wait_mt.cpp +++ b/benchmarks/bench_p2p_bi_ft_wait_mt.cpp @@ -7,19 +7,18 @@ * Please, refer to the LICENSE file in the root directory. 
* SPDX-License-Identifier: BSD-3-Clause */ -#include #include -#include "./mpi_environment.hpp" +#include +#include #include "./args.hpp" +#include "./mpi_environment.hpp" #include "./timer.hpp" #include "./utils.hpp" -#include -const char *syncmode = "future"; -const char *waitmode = "wait"; +char const* syncmode = "future"; +char const* waitmode = "wait"; -int -main(int argc, char** argv) +int main(int argc, char** argv) { using namespace oomph; using message = oomph::message_buffer; @@ -33,13 +32,13 @@ main(int argc, char** argv) context ctxt(MPI_COMM_WORLD, multi_threaded); barrier b(ctxt, cmd_args.num_threads); - timer t0; - timer t1; + timer t0; + timer t1; - const auto inflight = cmd_args.inflight; - const auto num_threads = cmd_args.num_threads; - const auto buff_size = cmd_args.buff_size; - const auto niter = cmd_args.n_iter; + auto const inflight = cmd_args.inflight; + auto const num_threads = cmd_args.num_threads; + auto const buff_size = cmd_args.buff_size; + auto const niter = cmd_args.n_iter; if (env.rank == 0) { @@ -49,26 +48,28 @@ main(int argc, char** argv) } #ifdef OOMPH_BENCHMARKS_MT -#pragma omp parallel +# pragma omp parallel #endif { - auto comm = ctxt.get_communicator(); - const auto rank = comm.rank(); - const auto size = comm.size(); - const auto thread_id = THREADID; - const auto peer_rank = (rank + 1) % size; - - int dbg = 0; - int sent = 0, received = 0; - int last_received = 0; - int last_sent = 0; - const int delta_i = niter / 10; + auto comm = ctxt.get_communicator(); + auto const rank = comm.rank(); + auto const size = comm.size(); + auto const thread_id = THREADID; + auto const peer_rank = (rank + 1) % size; + + int dbg = 0; + int sent = 0, received = 0; + int last_received = 0; + int last_sent = 0; + int const delta_i = niter / 10; if (thread_id == 0 && rank == 0) - { std::cout << "\n\nrunning test " << __FILE__ << "\n\n"; }; + { + std::cout << "\n\nrunning test " << __FILE__ << "\n\n"; + }; - std::vector smsgs(inflight); - std::vector rmsgs(inflight); + std::vector smsgs(inflight); + std::vector rmsgs(inflight); std::vector sreqs(inflight); std::vector rreqs(inflight); for (int j = 0; j < inflight; j++) @@ -95,8 +96,8 @@ main(int argc, char** argv) dbg = 0; std::cout << rank << " total bwdt MB/s: " << ((received - last_received + sent - last_sent) * size * - (double)buff_size / 2) / - t0.stoc() + (double) buff_size / 2) / + t0.stoc() << "\n"; t0.tic(); last_received = received; @@ -116,7 +117,7 @@ main(int argc, char** argv) comm.wait_all(); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp barrier +# pragma omp barrier #endif ///* wait for all */ @@ -131,8 +132,8 @@ main(int argc, char** argv) if (thread_id == 0 && rank == 0) { - const auto t = t1.stoc(); - double bw = ((double)niter*size*buff_size)/t; + auto const t = t1.stoc(); + double bw = ((double) niter * size * buff_size) / t; // clang-format off std::cout << "time: " << t / 1000000 << "s\n"; std::cout << "final MB/s: " << bw << "\n"; diff --git a/benchmarks/bench_p2p_pp_ft_avail_mt.cpp b/benchmarks/bench_p2p_pp_ft_avail_mt.cpp index 98301e3e..0783588e 100644 --- a/benchmarks/bench_p2p_pp_ft_avail_mt.cpp +++ b/benchmarks/bench_p2p_pp_ft_avail_mt.cpp @@ -7,20 +7,20 @@ * Please, refer to the LICENSE file in the root directory. 
* SPDX-License-Identifier: BSD-3-Clause */ -#include #include -#include "./mpi_environment.hpp" +#include #include "./args.hpp" +#include "./mpi_environment.hpp" #include "./timer.hpp" #include "./utils.hpp" // #include +#include #include #include #include -#include #ifdef OOMPH_BENCHMARKS_MT -#include +# include #endif // enable cleaned up debugging output @@ -35,8 +35,7 @@ std::cerr << temp.str(); } // clang-format on -std::string -print_send_recv_info(std::tuple& tup) +std::string print_send_recv_info(std::tuple& tup) { std::stringstream temp; temp << " Sends Posted " << std::get<0>(tup) << " Sends Completed " << std::get<1>(tup) @@ -45,20 +44,20 @@ print_send_recv_info(std::tuple& tup) return temp.str(); } -const char* syncmode = "future"; -const char* waitmode = "avail"; +char const* syncmode = "future"; +char const* waitmode = "avail"; std::atomic sends_posted(0); std::atomic sends_completed(0); std::atomic receives_posted(0); // keep track of sends on a thread local basis -template +template struct alignas(64) msg_tracker { using message = oomph::message_buffer; std::vector msgs; - std::vector reqs; + std::vector reqs; // msg_tracker() = default; // @@ -75,8 +74,7 @@ struct alignas(64) msg_tracker } }; -int -main(int argc, char* argv[]) +int main(int argc, char* argv[]) { using namespace oomph; using message = oomph::message_buffer; @@ -90,13 +88,13 @@ main(int argc, char* argv[]) context ctxt(MPI_COMM_WORLD, multi_threaded); barrier b(ctxt, cmd_args.num_threads); - timer t0; - timer t1; + timer t0; + timer t1; - const auto inflight = cmd_args.inflight; - const auto num_threads = cmd_args.num_threads; - const auto n_secs = cmd_args.n_secs; - const auto buff_size = cmd_args.buff_size; + auto const inflight = cmd_args.inflight; + auto const num_threads = cmd_args.num_threads; + auto const n_secs = cmd_args.n_secs; + auto const buff_size = cmd_args.buff_size; if (env.rank == 0) { @@ -106,8 +104,8 @@ main(int argc, char* argv[]) } // How often do we display debug msgs - const int debug_freq = 5; - const int msecond = (1000 * n_secs) / debug_freq; + int const debug_freq = 5; + int const msecond = (1000 * n_secs) / debug_freq; // true when time exceeded std::atomic time_up = false; @@ -120,27 +118,27 @@ main(int argc, char* argv[]) // so we can post a "done sends" message std::atomic threads_completed(0); // only one thread is reponsible for the "done sends" msg - std::atomic master_thread(-1); + std::atomic master_thread(-1); std::atomic sends_complete_checked_flag = false; std::atomic num_messages_expected = std::numeric_limits::max() / 2; // int mode; - double elapsed; + double elapsed; oomph::timer ttimer; #ifdef OOMPH_BENCHMARKS_MT -#pragma omp parallel +# pragma omp parallel #endif { // ---------------------------------------------------------------- // variables in parallel section are thread local // ---------------------------------------------------------------- - auto comm = ctxt.get_communicator(); - const auto rank = comm.rank(); - const auto size = comm.size(); - const auto thread_id = THREADID; - const auto peer_rank = (rank + 1) % size; + auto comm = ctxt.get_communicator(); + auto const rank = comm.rank(); + auto const size = comm.size(); + auto const thread_id = THREADID; + auto const peer_rank = (rank + 1) % size; // track sends/recvs msg_tracker sends; @@ -150,14 +148,14 @@ main(int argc, char* argv[]) // when all threads have finished sending, // we use these to sync total msg count between ranks - message done_send = comm.make_buffer(sizeof(int)); - message done_recv 
= comm.make_buffer(sizeof(int)); + message done_send = comm.make_buffer(sizeof(int)); + message done_recv = comm.make_buffer(sizeof(int)); send_request fsend; recv_request frecv; // NB. these are thread local - bool thread_sends_complete = false; // true when thread completed sends - bool thread_sends_complete_flag = false; // true after thread signals counter + bool thread_sends_complete = false; // true when thread completed sends + bool thread_sends_complete_flag = false; // true after thread signals counter // loop for allowed time : sending and receiving do { @@ -253,11 +251,11 @@ main(int argc, char* argv[]) // number of messages sent by the peer + (inflight*num_threads) // then all messages sent by them have been received. } while (!sends_complete_checked_flag || - receives_posted != (num_messages_expected + inflight * num_threads)); + receives_posted != (num_messages_expected + inflight * num_threads)); -// buffered_out("rank: " << rank << "\tthread " -// << " Done" << thread_id << "\tsend: " << sends_posted -// << "\trecv: " << receives_posted); + // buffered_out("rank: " << rank << "\tthread " + // << " Done" << thread_id << "\tsend: " << sends_posted + // << "\trecv: " << receives_posted); // barrier + progress here before final checks b.thread_barrier(); @@ -272,7 +270,8 @@ main(int argc, char* argv[]) { if (!recvs.reqs[j].test()) { - if (recvs.reqs[j].cancel()) receives_posted--; + if (recvs.reqs[j].cancel()) + receives_posted--; else throw std::runtime_error("Receive cancel failed"); } @@ -297,7 +296,7 @@ main(int argc, char* argv[]) // total traffic is amount sends_posted in both directions if (rank == 0 && thread_id == 0) { - double bw = ((double)(sends_posted + receives_posted) * buff_size) / elapsed; + double bw = ((double) (sends_posted + receives_posted) * buff_size) / elapsed; // clang-format off std::cout << "time: " << elapsed/1000000 << "s\n"; std::cout << "final MB/s: " << bw << "\n"; diff --git a/benchmarks/mpi_environment.hpp b/benchmarks/mpi_environment.hpp index 7affabd3..ee5cd418 100644 --- a/benchmarks/mpi_environment.hpp +++ b/benchmarks/mpi_environment.hpp @@ -9,43 +9,39 @@ */ #pragma once -#include #include +#include -namespace oomph -{ -struct mpi_environment -{ - int size; - int rank; - - mpi_environment(bool thread_safe, int& argc, char**& argv) +namespace oomph { + struct mpi_environment { - int mode; - if (thread_safe) + int size; + int rank; + + mpi_environment(bool thread_safe, int& argc, char**& argv) { - MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &mode); - if (mode != MPI_THREAD_MULTIPLE) + int mode; + if (thread_safe) { - std::cerr << "MPI_THREAD_MULTIPLE not supported by MPI, aborting\n"; - std::terminate(); + MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &mode); + if (mode != MPI_THREAD_MULTIPLE) + { + std::cerr << "MPI_THREAD_MULTIPLE not supported by MPI, aborting\n"; + std::terminate(); + } } + else { MPI_Init_thread(&argc, &argv, MPI_THREAD_SINGLE, &mode); } + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); } - else - { - MPI_Init_thread(&argc, &argv, MPI_THREAD_SINGLE, &mode); - } - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - } - mpi_environment(mpi_environment const&) = delete; + mpi_environment(mpi_environment const&) = delete; - ~mpi_environment() - { - MPI_Barrier(MPI_COMM_WORLD); - MPI_Finalize(); - } -}; + ~mpi_environment() + { + MPI_Barrier(MPI_COMM_WORLD); + MPI_Finalize(); + } + }; -} // namespace oomph +} // namespace oomph diff --git 
a/benchmarks/mpi_p2p_bi_avail_mt.cpp b/benchmarks/mpi_p2p_bi_avail_mt.cpp index 4bc1dedd..a52f919e 100644 --- a/benchmarks/mpi_p2p_bi_avail_mt.cpp +++ b/benchmarks/mpi_p2p_bi_avail_mt.cpp @@ -7,28 +7,27 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ +#include #include #include #include -#include // do not include OOMPH functionality #define OOMPH_BENCHMARKS_PURE_MPI -#include "./mpi_environment.hpp" #include "./args.hpp" +#include "./mpi_environment.hpp" #include "./timer.hpp" #include "./utils.hpp" #ifdef OOMPH_BENCHMARKS_MT -#include +# include #endif /* OOMPH_BENCHMARKS_MT */ -int -main(int argc, char* argv[]) +int main(int argc, char* argv[]) { using namespace oomph; - int rank, size, peer_rank; + int rank, size, peer_rank; timer t0, t1; int last_received = 0; @@ -48,9 +47,9 @@ main(int argc, char* argv[]) mpi_environment env(multi_threaded, argc, argv); if (env.size != 2) return exit(argv[0]); - const auto inflight = cmd_args.inflight; - const auto buff_size = cmd_args.buff_size; - const auto niter = cmd_args.n_iter; + auto const inflight = cmd_args.inflight; + auto const buff_size = cmd_args.buff_size; + auto const niter = cmd_args.n_iter; if (env.rank == 0) { @@ -60,18 +59,18 @@ main(int argc, char* argv[]) } #ifdef OOMPH_BENCHMARKS_MT -#pragma omp parallel +# pragma omp parallel #endif { - int thrid = 0, nthr = 1; - MPI_Comm mpi_comm = MPI_COMM_NULL; + int thrid = 0, nthr = 1; + MPI_Comm mpi_comm = MPI_COMM_NULL; unsigned char** sbuffers = new unsigned char*[inflight]; unsigned char** rbuffers = new unsigned char*[inflight]; - MPI_Request* sreq = new MPI_Request[inflight]; - MPI_Request* rreq = new MPI_Request[inflight]; + MPI_Request* sreq = new MPI_Request[inflight]; + MPI_Request* rreq = new MPI_Request[inflight]; #ifdef OOMPH_BENCHMARKS_MT -#pragma omp master +# pragma omp master #endif { MPI_Comm_rank(MPI_COMM_WORLD, &rank); @@ -90,7 +89,7 @@ main(int argc, char* argv[]) { if (thrid == tid) { MPI_Comm_dup(MPI_COMM_WORLD, &mpi_comm); } #ifdef OOMPH_BENCHMARKS_MT -#pragma omp barrier +# pragma omp barrier #endif } @@ -106,7 +105,7 @@ main(int argc, char* argv[]) if (thrid == 0) MPI_Barrier(mpi_comm); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp barrier +# pragma omp barrier #endif if (thrid == 0) @@ -171,7 +170,7 @@ main(int argc, char* argv[]) { dbg = 0; t0.vtoc(header, - (double)(received - last_received + sent - last_sent) * size * buff_size / 2); + (double) (received - last_received + sent - last_sent) * size * buff_size / 2); t0.tic(); last_received = received; last_sent = sent; @@ -201,16 +200,16 @@ main(int argc, char* argv[]) } } #endif - delete []sbuffers; - delete []rbuffers; + delete[] sbuffers; + delete[] rbuffers; + } } -} MPI_Barrier(MPI_COMM_WORLD); if (rank == 1) { t1.vtoc(); - t1.vtoc("final ", (double)niter * size * buff_size); + t1.vtoc("final ", (double) niter * size * buff_size); } return 0; diff --git a/benchmarks/mpi_p2p_bi_wait_mt.cpp b/benchmarks/mpi_p2p_bi_wait_mt.cpp index c19d69e6..b716a74f 100644 --- a/benchmarks/mpi_p2p_bi_wait_mt.cpp +++ b/benchmarks/mpi_p2p_bi_wait_mt.cpp @@ -7,28 +7,27 @@ * Please, refer to the LICENSE file in the root directory. 
* SPDX-License-Identifier: BSD-3-Clause */ +#include #include #include #include -#include // do not include OOMPH functionality #define OOMPH_BENCHMARKS_PURE_MPI -#include "./mpi_environment.hpp" #include "./args.hpp" +#include "./mpi_environment.hpp" #include "./timer.hpp" #include "./utils.hpp" #ifdef OOMPH_BENCHMARKS_MT -#include +# include #endif /* OOMPH_BENCHMARKS_MT */ -int -main(int argc, char* argv[]) +int main(int argc, char* argv[]) { using namespace oomph; - int rank, size, peer_rank; + int rank, size, peer_rank; timer t0, t1; args cmd_args(argc, argv); @@ -38,9 +37,9 @@ main(int argc, char* argv[]) mpi_environment env(multi_threaded, argc, argv); if (env.size != 2) return exit(argv[0]); - const auto inflight = cmd_args.inflight; - const auto buff_size = cmd_args.buff_size; - const auto niter = cmd_args.n_iter; + auto const inflight = cmd_args.inflight; + auto const buff_size = cmd_args.buff_size; + auto const niter = cmd_args.n_iter; if (env.rank == 0) { @@ -50,18 +49,18 @@ main(int argc, char* argv[]) } #ifdef OOMPH_BENCHMARKS_MT -#pragma omp parallel +# pragma omp parallel #endif { - int thrid = 0, nthr = 1; - MPI_Comm mpi_comm = MPI_COMM_NULL; + int thrid = 0, nthr = 1; + MPI_Comm mpi_comm = MPI_COMM_NULL; unsigned char** sbuffers = new unsigned char*[inflight]; unsigned char** rbuffers = new unsigned char*[inflight]; - MPI_Request* sreq = new MPI_Request[inflight]; - MPI_Request* rreq = new MPI_Request[inflight]; + MPI_Request* sreq = new MPI_Request[inflight]; + MPI_Request* rreq = new MPI_Request[inflight]; #ifdef OOMPH_BENCHMARKS_MT -#pragma omp master +# pragma omp master #endif { MPI_Comm_rank(MPI_COMM_WORLD, &rank); @@ -80,7 +79,7 @@ main(int argc, char* argv[]) { if (thrid == tid) { MPI_Comm_dup(MPI_COMM_WORLD, &mpi_comm); } #ifdef OOMPH_BENCHMARKS_MT -#pragma omp barrier +# pragma omp barrier #endif } @@ -96,7 +95,7 @@ main(int argc, char* argv[]) if (thrid == 0) MPI_Barrier(mpi_comm); #ifdef OOMPH_BENCHMARKS_MT -#pragma omp barrier +# pragma omp barrier #endif if (thrid == 0) @@ -108,8 +107,8 @@ main(int argc, char* argv[]) << "\n"; } - int i = 0, dbg = 0; - int last_i = 0; + int i = 0, dbg = 0; + int last_i = 0; char header[256]; snprintf(header, 256, "%d total bwdt ", rank); while (i < niter) @@ -117,7 +116,7 @@ main(int argc, char* argv[]) if (thrid == 0 && dbg >= (niter / 10)) { dbg = 0; - t0.vtoc(header, (double)(i - last_i) * size * buff_size); + t0.vtoc(header, (double) (i - last_i) * size * buff_size); t0.tic(); last_i = i; } @@ -145,18 +144,18 @@ main(int argc, char* argv[]) } #endif #ifdef OOMPH_BENCHMARKS_MT -#pragma omp barrier +# pragma omp barrier #endif } - delete []sbuffers; - delete []rbuffers; + delete[] sbuffers; + delete[] rbuffers; } MPI_Barrier(MPI_COMM_WORLD); if (rank == 1) { t1.vtoc(); - t1.vtoc("final ", (double)niter * size * buff_size); + t1.vtoc("final ", (double) niter * size * buff_size); } return 0; diff --git a/benchmarks/timer.hpp b/benchmarks/timer.hpp index e3fe42ae..fa8576d7 100644 --- a/benchmarks/timer.hpp +++ b/benchmarks/timer.hpp @@ -9,94 +9,92 @@ */ #pragma once -#include "./accumulator.hpp" #include +#include "./accumulator.hpp" -namespace oomph -{ -/** @brief timer with built-in statistics */ -class timer : public accumulator -{ - private: // member types - using base = accumulator; - using clock_type = std::chrono::high_resolution_clock; - using time_point = typename clock_type::time_point; - - private: // members - time_point m_time_point = clock_type::now(); - - public: // ctors - timer() = default; - timer(const 
base& b) - : base(b) - { - } - timer(base&& b) - : base(std::move(b)) +namespace oomph { + /** @brief timer with built-in statistics */ + class timer : public accumulator { - } - timer(const timer&) noexcept = default; - timer(timer&&) noexcept = default; - timer& operator=(const timer&) noexcept = default; - timer& operator=(timer&&) noexcept = default; + private: // member types + using base = accumulator; + using clock_type = std::chrono::high_resolution_clock; + using time_point = typename clock_type::time_point; - public: // time functions - /** @brief start timings */ - inline void tic() noexcept { m_time_point = clock_type::now(); } - /** @brief stop timings */ - inline double stoc() noexcept - { - return std::chrono::duration_cast( - clock_type::now() - m_time_point) - .count(); - } + private: // members + time_point m_time_point = clock_type::now(); - /** @brief stop timings */ - inline double toc() noexcept - { - const auto t = - std::chrono::duration_cast(clock_type::now() - m_time_point) - .count(); - this->operator()(t); - return t; - } + public: // ctors + timer() = default; + timer(base const& b) + : base(b) + { + } + timer(base&& b) + : base(std::move(b)) + { + } + timer(timer const&) noexcept = default; + timer(timer&&) noexcept = default; + timer& operator=(timer const&) noexcept = default; + timer& operator=(timer&&) noexcept = default; - /** @brief stop timings, verbose: print measured time */ - inline void vtoc() noexcept - { - double t = - std::chrono::duration_cast(clock_type::now() - m_time_point) + public: // time functions + /** @brief start timings */ + inline void tic() noexcept { m_time_point = clock_type::now(); } + /** @brief stop timings */ + inline double stoc() noexcept + { + return std::chrono::duration_cast( + clock_type::now() - m_time_point) .count(); - std::cout << "time: " << t / 1000000 << "s\n"; - } + } - /** @brief stop timings, verbose: print measured time and bandwidth */ - inline void vtoc(const char* header, long bytes) noexcept - { - double t = - std::chrono::duration_cast(clock_type::now() - m_time_point) - .count(); - std::cout << header << " MB/s: " << bytes / t << "\n"; - } + /** @brief stop timings */ + inline double toc() noexcept + { + auto const t = std::chrono::duration_cast( + clock_type::now() - m_time_point) + .count(); + this->operator()(t); + return t; + } - /** @brief stop and start another timing period */ - inline void toc_tic() noexcept - { - auto t2 = clock_type::now(); - this->operator()( - std::chrono::duration_cast(t2 - m_time_point).count()); - m_time_point = t2; - } -}; + /** @brief stop timings, verbose: print measured time */ + inline void vtoc() noexcept + { + double t = std::chrono::duration_cast( + clock_type::now() - m_time_point) + .count(); + std::cout << "time: " << t / 1000000 << "s\n"; + } + + /** @brief stop timings, verbose: print measured time and bandwidth */ + inline void vtoc(char const* header, long bytes) noexcept + { + double t = std::chrono::duration_cast( + clock_type::now() - m_time_point) + .count(); + std::cout << header << " MB/s: " << bytes / t << "\n"; + } + + /** @brief stop and start another timing period */ + inline void toc_tic() noexcept + { + auto t2 = clock_type::now(); + this->operator()( + std::chrono::duration_cast(t2 - m_time_point).count()); + m_time_point = t2; + } + }; -/** @brief all-reduce timers over the MPI group defined by the communicator + /** @brief all-reduce timers over the MPI group defined by the communicator * @param t timer local to each rank * @param comm MPI 
communicator * @return combined timer incorporating statistics over all timings */ -timer -reduce(const timer& t, MPI_Comm comm) -{ - return reduce(static_cast(t), comm); -} + timer reduce(timer const& t, MPI_Comm comm) + { + return reduce(static_cast(t), comm); + } -} // namespace oomph +} // namespace oomph diff --git a/benchmarks/utils.hpp b/benchmarks/utils.hpp index fa1def60..33384de3 100644 --- a/benchmarks/utils.hpp +++ b/benchmarks/utils.hpp @@ -12,19 +12,18 @@ #include #ifdef OOMPH_BENCHMARKS_MT -#define THREADID omp_get_thread_num() +# define THREADID omp_get_thread_num() #else -#define THREADID 0 +# define THREADID 0 #endif namespace oomph { -inline int -exit(char const* executable) -{ - std::cerr << "Usage: " << executable << " [niter] [msg_size] [inflight]" << std::endl; - std::cerr << " run with 2 MPI processes: e.g.: mpirun -np 2 ..." << std::endl; - return 1; -} + inline int exit(char const* executable) + { + std::cerr << "Usage: " << executable << " [niter] [msg_size] [inflight]" << std::endl; + std::cerr << " run with 2 MPI processes: e.g.: mpirun -np 2 ..." << std::endl; + return 1; + } -} // namespace oomph +} // namespace oomph diff --git a/bindings/fortran/communicator_bind.cpp b/bindings/fortran/communicator_bind.cpp index 03dc2f7a..bdf11b6b 100644 --- a/bindings/fortran/communicator_bind.cpp +++ b/bindings/fortran/communicator_bind.cpp @@ -7,269 +7,304 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -#include -#include #include -#include #include +#include +#include +#include using namespace oomph::fort; using communicator_type = oomph::communicator; -namespace oomph { - namespace fort { - - /* fortran-side user callback */ - typedef void (*f_callback)(void *mesg, int rank, int tag, void *user_data); - - struct callback { - f_callback cb; - void *user_data = nullptr; - callback(f_callback pcb, void *puser_data = nullptr) : cb{pcb}, user_data{puser_data} {} - void operator() (message_type message, int rank, int tag) const { - if(cb) cb(&message, rank, tag, user_data); - } - }; - - struct callback_ref { - f_callback cb; - void *user_data = nullptr; - callback_ref(f_callback pcb, void *puser_data = nullptr) : cb{pcb}, user_data{puser_data} {} - void operator() (message_type &message, int rank, int tag) const { - if(cb) cb(&message, rank, tag, user_data); - } - }; - - struct callback_multi { - f_callback cb; - void *user_data = nullptr; - callback_multi(f_callback pcb, void *puser_data = nullptr) : cb{pcb}, user_data{puser_data} {} - void operator() (message_type message, std::vector, std::vector) const { - if(cb) cb(&message, -1, -1, user_data); - } - }; - - struct callback_multi_ref { - f_callback cb; - void *user_data = nullptr; - callback_multi_ref(f_callback pcb, void *puser_data = nullptr) : cb{pcb}, user_data{puser_data} {} - void operator() (message_type &message, std::vector, std::vector) const { - if(cb) cb(&message, -1, -1, user_data); - } - }; - } -} - - -extern "C" -void* oomph_get_communicator() +namespace oomph { namespace fort { + + /* fortran-side user callback */ + typedef void (*f_callback)(void* mesg, int rank, int tag, void* user_data); + + struct callback + { + f_callback cb; + void* user_data = nullptr; + callback(f_callback pcb, void* puser_data = nullptr) + : cb{pcb} + , user_data{puser_data} + { + } + void operator()(message_type message, int rank, int tag) const + { + if (cb) cb(&message, rank, tag, user_data); + } + }; + + struct callback_ref + { + f_callback cb; + void* user_data = 
nullptr; + callback_ref(f_callback pcb, void* puser_data = nullptr) + : cb{pcb} + , user_data{puser_data} + { + } + void operator()(message_type& message, int rank, int tag) const + { + if (cb) cb(&message, rank, tag, user_data); + } + }; + + struct callback_multi + { + f_callback cb; + void* user_data = nullptr; + callback_multi(f_callback pcb, void* puser_data = nullptr) + : cb{pcb} + , user_data{puser_data} + { + } + void operator()(message_type message, std::vector, std::vector) const + { + if (cb) cb(&message, -1, -1, user_data); + } + }; + + struct callback_multi_ref + { + f_callback cb; + void* user_data = nullptr; + callback_multi_ref(f_callback pcb, void* puser_data = nullptr) + : cb{pcb} + , user_data{puser_data} + { + } + void operator()(message_type& message, std::vector, std::vector) const + { + if (cb) cb(&message, -1, -1, user_data); + } + }; +}} // namespace oomph::fort + +extern "C" void* oomph_get_communicator() { return new obj_wrapper(get_context().get_communicator()); } -extern "C" -int oomph_comm_rank(obj_wrapper *wrapper) +extern "C" int oomph_comm_rank(obj_wrapper* wrapper) { return get_object_ptr_unsafe(wrapper)->rank(); } -extern "C" -int oomph_comm_size(obj_wrapper *wrapper) +extern "C" int oomph_comm_size(obj_wrapper* wrapper) { return get_object_ptr_unsafe(wrapper)->size(); } -extern "C" -void oomph_comm_progress(obj_wrapper *wrapper) +extern "C" void oomph_comm_progress(obj_wrapper* wrapper) { get_object_ptr_unsafe(wrapper)->progress(); } - /* SEND requests */ -extern "C" -void oomph_comm_post_send(obj_wrapper *wcomm, message_type *message, int rank, int tag, frequest_type *freq) +extern "C" void oomph_comm_post_send( + obj_wrapper* wcomm, message_type* message, int rank, int tag, frequest_type* freq) { - if(nullptr==message || nullptr==wcomm){ - std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message || nullptr == wcomm) + { + std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ + << ". Terminating.\n"; std::terminate(); } - communicator_type *comm = get_object_ptr_unsafe(wcomm); + communicator_type* comm = get_object_ptr_unsafe(wcomm); auto req = comm->send(*message, rank, tag); - new(freq->data) decltype(req)(std::move(req)); + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = false; } -extern "C" -void oomph_comm_post_send_cb_wrapped(obj_wrapper *wcomm, message_type *message, int rank, int tag, f_callback cb, frequest_type *freq, void *user_data) +extern "C" void oomph_comm_post_send_cb_wrapped(obj_wrapper* wcomm, message_type* message, int rank, + int tag, f_callback cb, frequest_type* freq, void* user_data) { - if(nullptr==message || nullptr==wcomm){ - std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message || nullptr == wcomm) + { + std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ + << ". 
Terminating.\n"; std::terminate(); } - communicator_type *comm = get_object_ptr_unsafe(wcomm); + communicator_type* comm = get_object_ptr_unsafe(wcomm); auto req = comm->send(*message, rank, tag, callback_ref{cb, user_data}); - if(!freq) return; - new(freq->data) decltype(req)(std::move(req)); + if (!freq) return; + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = false; } -extern "C" -void oomph_comm_send_cb_wrapped(obj_wrapper *wcomm, message_type **message_ref, int rank, int tag, f_callback cb, frequest_type *freq, void *user_data) +extern "C" void oomph_comm_send_cb_wrapped(obj_wrapper* wcomm, message_type** message_ref, int rank, + int tag, f_callback cb, frequest_type* freq, void* user_data) { - if(nullptr==message_ref || nullptr==wcomm || nullptr == *message_ref){ - std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message_ref || nullptr == wcomm || nullptr == *message_ref) + { + std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ + << ". Terminating.\n"; std::terminate(); } - communicator_type *comm = get_object_ptr_unsafe(wcomm); - + communicator_type* comm = get_object_ptr_unsafe(wcomm); + auto req = comm->send(std::move(**message_ref), rank, tag, callback{cb, user_data}); *message_ref = nullptr; - if(!freq) return; - new(freq->data) decltype(req)(std::move(req)); + if (!freq) return; + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = false; } - /* SEND_MULTI requests */ -extern "C" -void oomph_comm_post_send_multi_wrapped(obj_wrapper *wcomm, message_type *message, int *ranks, int nranks, int *tags, frequest_type *freq) +extern "C" void oomph_comm_post_send_multi_wrapped(obj_wrapper* wcomm, message_type* message, + int* ranks, int nranks, int* tags, frequest_type* freq) { - if(nullptr==message || nullptr==wcomm){ - std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message || nullptr == wcomm) + { + std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ + << ". Terminating.\n"; std::terminate(); } - communicator_type *comm = get_object_ptr_unsafe(wcomm); + communicator_type* comm = get_object_ptr_unsafe(wcomm); std::vector ranks_array(nranks); - ranks_array.assign(ranks, ranks+nranks); + ranks_array.assign(ranks, ranks + nranks); std::vector tags_array(nranks); - tags_array.assign(tags, tags+nranks); + tags_array.assign(tags, tags + nranks); auto req = comm->send_multi(*message, ranks_array, tags_array); - new(freq->data) decltype(req)(std::move(req)); + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = false; } -extern "C" -void oomph_comm_post_send_multi_cb_wrapped(obj_wrapper *wcomm, message_type *message, int *ranks, int nranks, int *tags, f_callback cb, frequest_type *freq, void *user_data) +extern "C" void oomph_comm_post_send_multi_cb_wrapped(obj_wrapper* wcomm, message_type* message, + int* ranks, int nranks, int* tags, f_callback cb, frequest_type* freq, void* user_data) { - if(nullptr==message || nullptr==wcomm){ - std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message || nullptr == wcomm) + { + std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ + << ". 
Terminating.\n"; std::terminate(); } - communicator_type *comm = get_object_ptr_unsafe(wcomm); + communicator_type* comm = get_object_ptr_unsafe(wcomm); std::vector ranks_array(nranks); - ranks_array.assign(ranks, ranks+nranks); + ranks_array.assign(ranks, ranks + nranks); std::vector tags_array(nranks); - tags_array.assign(tags, tags+nranks); + tags_array.assign(tags, tags + nranks); - auto req = comm->send_multi(*message, ranks_array, tags_array, callback_multi_ref{cb, user_data}); - if(!freq) return; - new(freq->data) decltype(req)(std::move(req)); + auto req = + comm->send_multi(*message, ranks_array, tags_array, callback_multi_ref{cb, user_data}); + if (!freq) return; + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = false; } -extern "C" -void oomph_comm_send_multi_cb_wrapped(obj_wrapper *wcomm, message_type **message_ref, int *ranks, int nranks, int *tags, f_callback cb, frequest_type *freq, void *user_data) +extern "C" void oomph_comm_send_multi_cb_wrapped(obj_wrapper* wcomm, message_type** message_ref, + int* ranks, int nranks, int* tags, f_callback cb, frequest_type* freq, void* user_data) { - if(nullptr==message_ref || nullptr==wcomm || nullptr == *message_ref){ - std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message_ref || nullptr == wcomm || nullptr == *message_ref) + { + std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ + << ". Terminating.\n"; std::terminate(); } - communicator_type *comm = get_object_ptr_unsafe(wcomm); + communicator_type* comm = get_object_ptr_unsafe(wcomm); std::vector ranks_array(nranks); - ranks_array.assign(ranks, ranks+nranks); + ranks_array.assign(ranks, ranks + nranks); std::vector tags_array(nranks); - tags_array.assign(tags, tags+nranks); + tags_array.assign(tags, tags + nranks); - auto req = comm->send_multi(std::move(**message_ref), ranks_array, tags_array, callback_multi{cb, user_data}); + auto req = comm->send_multi( + std::move(**message_ref), ranks_array, tags_array, callback_multi{cb, user_data}); *message_ref = nullptr; - if(!freq) return; - new(freq->data) decltype(req)(std::move(req)); + if (!freq) return; + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = false; } - /* RECV requests */ -extern "C" -void oomph_comm_post_recv(obj_wrapper *wcomm, message_type *message, int rank, int tag, frequest_type *freq) +extern "C" void oomph_comm_post_recv( + obj_wrapper* wcomm, message_type* message, int rank, int tag, frequest_type* freq) { - if(nullptr==message || nullptr==wcomm){ - std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message || nullptr == wcomm) + { + std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ + << ". 
Terminating.\n"; std::terminate(); } - communicator_type *comm = get_object_ptr_unsafe(wcomm); + communicator_type* comm = get_object_ptr_unsafe(wcomm); auto req = comm->recv(*message, rank, tag); - new(freq->data) decltype(req)(std::move(req)); + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = true; } -extern "C" -void oomph_comm_post_recv_cb_wrapped(obj_wrapper *wcomm, message_type *message, int rank, int tag, f_callback cb, frequest_type *freq, void *user_data) +extern "C" void oomph_comm_post_recv_cb_wrapped(obj_wrapper* wcomm, message_type* message, int rank, + int tag, f_callback cb, frequest_type* freq, void* user_data) { - if(nullptr==message || nullptr==wcomm){ - std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message || nullptr == wcomm) + { + std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ + << ". Terminating.\n"; std::terminate(); } - communicator_type *comm = get_object_ptr_unsafe(wcomm); + communicator_type* comm = get_object_ptr_unsafe(wcomm); auto req = comm->recv(*message, rank, tag, callback_ref{cb, user_data}); - if(!freq) return; - new(freq->data) decltype(req)(std::move(req)); + if (!freq) return; + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = true; } -extern "C" -void oomph_comm_recv_cb_wrapped(obj_wrapper *wcomm, message_type **message_ref, int rank, int tag, f_callback cb, frequest_type *freq, void *user_data) +extern "C" void oomph_comm_recv_cb_wrapped(obj_wrapper* wcomm, message_type** message_ref, int rank, + int tag, f_callback cb, frequest_type* freq, void* user_data) { - if(nullptr==message_ref || nullptr==wcomm || nullptr == *message_ref){ - std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message_ref || nullptr == wcomm || nullptr == *message_ref) + { + std::cerr << "ERROR: NULL message or communicator in " << __FUNCTION__ + << ". Terminating.\n"; std::terminate(); } - - communicator_type *comm = get_object_ptr_unsafe(wcomm); + + communicator_type* comm = get_object_ptr_unsafe(wcomm); auto req = comm->recv(std::move(**message_ref), rank, tag, callback{cb, user_data}); *message_ref = nullptr; - if(!freq) return; - new(freq->data) decltype(req)(std::move(req)); + if (!freq) return; + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = true; } - /* resubmission of recv requests from inside callbacks */ -extern "C" -void oomph_comm_resubmit_recv_wrapped(obj_wrapper *wcomm, message_type *message, int rank, int tag, f_callback cb, frequest_type *freq, void *user_data) +extern "C" void oomph_comm_resubmit_recv_wrapped(obj_wrapper* wcomm, message_type* message, + int rank, int tag, f_callback cb, frequest_type* freq, void* user_data) { - if(nullptr==message || nullptr==wcomm){ - std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ << ". Terminating.\n"; + if (nullptr == message || nullptr == wcomm) + { + std::cerr << "ERROR: trying to submit a NULL message or communicator in " << __FUNCTION__ + << ". 
Terminating.\n"; std::terminate(); } - communicator_type *comm = get_object_ptr_unsafe(wcomm); + communicator_type* comm = get_object_ptr_unsafe(wcomm); auto req = comm->recv(std::move(*message), rank, tag, callback{cb, user_data}); - if(!freq) return; - new(freq->data) decltype(req)(std::move(req)); + if (!freq) return; + new (freq->data) decltype(req)(std::move(req)); freq->recv_request = true; } diff --git a/bindings/fortran/context_bind.cpp b/bindings/fortran/context_bind.cpp index e5b74aa0..f5d02ef9 100644 --- a/bindings/fortran/context_bind.cpp +++ b/bindings/fortran/context_bind.cpp @@ -7,47 +7,33 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -#include -#include #include +#include #include +#include #include #include -namespace -{ -oomph::context* oomph_context; +namespace { + oomph::context* oomph_context; #if OOMPH_ENABLE_BARRIER -oomph::barrier* oomph_barrier_obj = nullptr; + oomph::barrier* oomph_barrier_obj = nullptr; #endif -} +} // namespace -namespace oomph -{ -namespace fort -{ -context& -get_context() -{ - return *oomph_context; -} +namespace oomph { namespace fort { + context& get_context() { return *oomph_context; } #if OOMPH_ENABLE_BARRIER -#pragma message "barrier is enabled" - oomph::barrier& - barrier() - { - return *oomph_barrier_obj; - } +# pragma message "barrier is enabled" + oomph::barrier& barrier() { return *oomph_barrier_obj; } #else -#pragma message "barrier is disabled" +# pragma message "barrier is disabled" #endif int nthreads = 1; -} // namespace fort -} // namespace oomph +}} // namespace oomph::fort -extern "C" void -oomph_init(int nthreads, MPI_Fint fcomm) +extern "C" void oomph_init(int nthreads, MPI_Fint fcomm) { /* the fortran-side mpi communicator must be translated to C */ MPI_Comm ccomm = MPI_Comm_f2c(fcomm); @@ -58,8 +44,7 @@ oomph_init(int nthreads, MPI_Fint fcomm) #endif } -extern "C" void -oomph_finalize() +extern "C" void oomph_finalize() { delete oomph_context; #if OOMPH_ENABLE_BARRIER @@ -67,33 +52,18 @@ oomph_finalize() #endif } -extern "C" int -oomph_get_current_cpu() -{ - return sched_getcpu(); -} +extern "C" int oomph_get_current_cpu() { return sched_getcpu(); } -extern "C" int -oomph_get_ncpus() -{ - return get_nprocs_conf(); -} +extern "C" int oomph_get_ncpus() { return get_nprocs_conf(); } #if OOMPH_ENABLE_BARRIER -extern "C" void -oomph_barrier(int type) +extern "C" void oomph_barrier(int type) { switch (type) - { - case (oomph::fort::OomphBarrierThread): - oomph::fort::barrier().thread_barrier(); - break; - case (oomph::fort::OomphBarrierRank): - oomph::fort::barrier().rank_barrier(); - break; - default: - oomph::fort::barrier()(); - break; - } + { + case (oomph::fort::OomphBarrierThread): oomph::fort::barrier().thread_barrier(); break; + case (oomph::fort::OomphBarrierRank): oomph::fort::barrier().rank_barrier(); break; + default: oomph::fort::barrier()(); break; + } } #endif diff --git a/bindings/fortran/context_bind.hpp b/bindings/fortran/context_bind.hpp index 617d288d..578ceba9 100644 --- a/bindings/fortran/context_bind.hpp +++ b/bindings/fortran/context_bind.hpp @@ -10,10 +10,6 @@ #pragma once #include -namespace oomph -{ -namespace fort -{ -context& get_context(); -} // namespace fort -} // namespace oomph +namespace oomph { namespace fort { + context& get_context(); +}} // namespace oomph::fort diff --git a/bindings/fortran/message_bind.cpp b/bindings/fortran/message_bind.cpp index b7950ca6..8c1128f4 100644 --- a/bindings/fortran/message_bind.cpp +++ 
b/bindings/fortran/message_bind.cpp @@ -8,54 +8,53 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include +#include -extern "C" -void *oomph_message_new(std::size_t size, int allocator_type) +extern "C" void* oomph_message_new(std::size_t size, int allocator_type) { - void *wmessage = nullptr; - auto &context = oomph::fort::get_context(); + void* wmessage = nullptr; + auto& context = oomph::fort::get_context(); - switch(allocator_type){ + switch (allocator_type) + { case oomph::fort::OomphAllocatorHost: - { - wmessage = new message_type(std::move(context.make_buffer(size))); - break; - } + { + wmessage = new message_type(std::move(context.make_buffer(size))); + break; + } #if HWMALLOC_ENABLE_DEVICE case oomph::fort::OomphAllocatorDevice: - { - wmessage = new message_type(std::move(context.make_device_buffer(size))); - break; - } + { + wmessage = new message_type(std::move(context.make_device_buffer(size))); + break; + } #endif default: - { - std::cerr << "BINDINGS: " << __FUNCTION__ << ": unsupported allocator type: " << allocator_type << "\n"; - std::terminate(); - break; - } + { + std::cerr << "BINDINGS: " << __FUNCTION__ + << ": unsupported allocator type: " << allocator_type << "\n"; + std::terminate(); + break; + } } return wmessage; } -extern "C" -void oomph_message_free(message_type **message_ref) +extern "C" void oomph_message_free(message_type** message_ref) { if (nullptr == message_ref) return; delete *message_ref; *message_ref = nullptr; } -extern "C" -void oomph_message_zero(message_type *message) +extern "C" void oomph_message_zero(message_type* message) { if (nullptr == message) return; unsigned char* __restrict data = message->data(); @@ -63,10 +62,10 @@ void oomph_message_zero(message_type *message) std::memset(data, 0, size); } -extern "C" -unsigned char *oomph_message_data_wrapped(message_type *message, std::size_t *size) +extern "C" unsigned char* oomph_message_data_wrapped(message_type* message, std::size_t* size) { - if (nullptr == message) { + if (nullptr == message) + { *size = 0; return nullptr; } diff --git a/bindings/fortran/object_wrapper.cpp b/bindings/fortran/object_wrapper.cpp index 5bbf5b76..fb9d7c1d 100644 --- a/bindings/fortran/object_wrapper.cpp +++ b/bindings/fortran/object_wrapper.cpp @@ -7,11 +7,10 @@ * Please, refer to the LICENSE file in the root directory. 
* SPDX-License-Identifier: BSD-3-Clause */ -#include #include +#include -extern "C" void -oomph_obj_free(oomph::fort::obj_wrapper** wrapper_ref) +extern "C" void oomph_obj_free(oomph::fort::obj_wrapper** wrapper_ref) { auto wrapper = *wrapper_ref; diff --git a/bindings/fortran/object_wrapper.hpp b/bindings/fortran/object_wrapper.hpp index 684a0fd6..9f585644 100644 --- a/bindings/fortran/object_wrapper.hpp +++ b/bindings/fortran/object_wrapper.hpp @@ -11,53 +11,48 @@ #include -namespace oomph -{ -namespace fort -{ -class obj_wrapper -{ - public: - /** base class for stored object type */ - struct obj_storage_base +namespace oomph { namespace fort { + class obj_wrapper { - virtual ~obj_storage_base() = default; - }; + public: + /** base class for stored object type */ + struct obj_storage_base + { + virtual ~obj_storage_base() = default; + }; - /** actual object storage */ - template - struct obj_storage : obj_storage_base - { - T m_obj; - obj_storage(T const& obj) - : m_obj(obj) + /** actual object storage */ + template + struct obj_storage : obj_storage_base { - } - obj_storage(T&& obj) - : m_obj(std::move(obj)) + T m_obj; + obj_storage(T const& obj) + : m_obj(obj) + { + } + obj_storage(T&& obj) + : m_obj(std::move(obj)) + { + } + }; + + std::unique_ptr m_obj_storage; + + obj_wrapper(obj_wrapper&&) = default; + + template ::type> + obj_wrapper(Arg&& arg) + : m_obj_storage(new obj_storage(std::forward(arg))) { } }; - std::unique_ptr m_obj_storage; - - obj_wrapper(obj_wrapper&&) = default; - - template::type> - obj_wrapper(Arg&& arg) - : m_obj_storage(new obj_storage(std::forward(arg))) + /** get the object without performing type checks: + * assume that has already been done in Fortran and the cast is legal */ + template + T* get_object_ptr_unsafe(obj_wrapper* src) { + return &reinterpret_cast*>(src->m_obj_storage.get())->m_obj; } -}; -/** get the object without performing type checks: - * assume that has already been done in Fortran and the cast is legal */ -template -T* -get_object_ptr_unsafe(obj_wrapper* src) -{ - return &reinterpret_cast*>(src->m_obj_storage.get())->m_obj; -} - -} // namespace fort -} // namespace oomph +}} // namespace oomph::fort diff --git a/bindings/fortran/request_bind.cpp b/bindings/fortran/request_bind.cpp index 4de4041f..4f1aa2ad 100644 --- a/bindings/fortran/request_bind.cpp +++ b/bindings/fortran/request_bind.cpp @@ -9,49 +9,54 @@ */ #include "request_bind.hpp" -extern "C" -bool oomph_request_test(oomph::fort::frequest_type *freq) +extern "C" bool oomph_request_test(oomph::fort::frequest_type* freq) { - if(freq->recv_request){ - oomph::recv_request *req = reinterpret_cast(freq->data); + if (freq->recv_request) + { + oomph::recv_request* req = reinterpret_cast(freq->data); return req->test(); - } else { - oomph::send_request *req = reinterpret_cast(freq->data); + } + else + { + oomph::send_request* req = reinterpret_cast(freq->data); return req->test(); } } -extern "C" -bool oomph_request_ready(oomph::fort::frequest_type *freq) +extern "C" bool oomph_request_ready(oomph::fort::frequest_type* freq) { - if(freq->recv_request){ - oomph::recv_request *req = reinterpret_cast(freq->data); + if (freq->recv_request) + { + oomph::recv_request* req = reinterpret_cast(freq->data); return req->is_ready(); - } else { - oomph::send_request *req = reinterpret_cast(freq->data); + } + else + { + oomph::send_request* req = reinterpret_cast(freq->data); return req->is_ready(); } } -extern "C" -void oomph_request_wait(oomph::fort::frequest_type *freq) +extern "C" void 
oomph_request_wait(oomph::fort::frequest_type* freq) { - if(freq->recv_request){ - oomph::recv_request *req = reinterpret_cast(freq->data); + if (freq->recv_request) + { + oomph::recv_request* req = reinterpret_cast(freq->data); req->wait(); - } else { - oomph::send_request *req = reinterpret_cast(freq->data); + } + else + { + oomph::send_request* req = reinterpret_cast(freq->data); req->wait(); } } -extern "C" -bool oomph_request_cancel(oomph::fort::frequest_type *freq) +extern "C" bool oomph_request_cancel(oomph::fort::frequest_type* freq) { - if(freq->recv_request){ - oomph::recv_request *req = reinterpret_cast(freq->data); + if (freq->recv_request) + { + oomph::recv_request* req = reinterpret_cast(freq->data); return req->cancel(); - } else { - return false; } + else { return false; } } diff --git a/bindings/fortran/request_bind.hpp b/bindings/fortran/request_bind.hpp index c4f605c3..467c8afb 100644 --- a/bindings/fortran/request_bind.hpp +++ b/bindings/fortran/request_bind.hpp @@ -10,17 +10,16 @@ #ifndef INCLUDED_OOMPH_FORTRAN_REQUEST_BIND_HPP #define INCLUDED_OOMPH_FORTRAN_REQUEST_BIND_HPP +#include #include #include -#include -namespace oomph { - namespace fort { - struct frequest_type { - int8_t data[OOMPH_REQUEST_SIZE]; - bool recv_request; - }; - } -} +namespace oomph { namespace fort { + struct frequest_type + { + int8_t data[OOMPH_REQUEST_SIZE]; + bool recv_request; + }; +}} // namespace oomph::fort #endif /* INCLUDED_OOMPH_FORTRAN_REQUEST_BIND_HPP */ diff --git a/bindings/fortran/sizes.cpp b/bindings/fortran/sizes.cpp index e09c8c84..a202ff03 100644 --- a/bindings/fortran/sizes.cpp +++ b/bindings/fortran/sizes.cpp @@ -13,9 +13,11 @@ int main() { std::cout << "#ifndef OOMPH_SIZES_H_INCLUDED\n"; - std::cout << "#define OOMPH_SIZES_H_INCLUDED\n"; - std::cout << "\n"; - size_t rsize = sizeof(oomph::send_request)>sizeof(oomph::recv_request)?sizeof(oomph::send_request):sizeof(oomph::recv_request); + std::cout << "#define OOMPH_SIZES_H_INCLUDED\n"; + std::cout << "\n"; + size_t rsize = sizeof(oomph::send_request) > sizeof(oomph::recv_request) ? 
+ sizeof(oomph::send_request) : + sizeof(oomph::recv_request); std::cout << "#define OOMPH_REQUEST_SIZE " << rsize << "\n"; std::cout << "\n"; std::cout << "#endif /* OOMPH_SIZES_H_INCLUDED */\n"; diff --git a/cmake/config.hpp.in b/cmake/config.hpp.in index 458b038a..1101a9f3 100644 --- a/cmake/config.hpp.in +++ b/cmake/config.hpp.in @@ -15,20 +15,23 @@ #define OOMPH_ENABLE_DEVICE HWMALLOC_ENABLE_DEVICE #define OOMPH_DEVICE_RUNTIME HWMALLOC_DEVICE_RUNTIME #if defined(HWMALLOC_DEVICE_HIP) -# define OOMPH_DEVICE_HIP +# define OOMPH_DEVICE_HIP #elif defined(HWMALLOC_DEVICE_CUDA) -# define OOMPH_DEVICE_CUDA +# define OOMPH_DEVICE_CUDA #elif defined(HWMALLOC_DEVICE_EMULATE) -# define OOMPH_DEVICE_EMULATE +# define OOMPH_DEVICE_EMULATE #else -# define OOMPH_DEVICE_NONE +# define OOMPH_DEVICE_NONE #endif #cmakedefine01 OOMPH_USE_FAST_PIMPL #cmakedefine01 OOMPH_ENABLE_BARRIER + +// clang-format off #define OOMPH_RECURSION_DEPTH @OOMPH_RECURSION_DEPTH@ #define OOMPH_VERSION @OOMPH_VERSION_NUMERIC@ #define OOMPH_VERSION_MAJOR @OOMPH_VERSION_MAJOR@ #define OOMPH_VERSION_MINOR @OOMPH_VERSION_MINOR@ #define OOMPH_VERSION_PATCH @OOMPH_VERSION_PATCH@ +// clang-format on diff --git a/cmake/oomph_defs.hpp.in b/cmake/oomph_defs.hpp.in index 70ae8732..eb06fa87 100644 --- a/cmake/oomph_defs.hpp.in +++ b/cmake/oomph_defs.hpp.in @@ -11,19 +11,19 @@ #include -namespace oomph -{ - namespace fort +namespace oomph { namespace fort { + // clang-format off + using fp_type = @OOMPH_FORTRAN_FP@; + // clang-format on + typedef enum { - using fp_type = @OOMPH_FORTRAN_FP@; - typedef enum { - OomphBarrierGlobal=1, - OomphBarrierThread=2, - OomphBarrierRank=3 - } oomph_barrier_type; - typedef enum { - OomphAllocatorHost=1, - OomphAllocatorDevice=2 - } oomph_allocator_type; - } -} + OomphBarrierGlobal = 1, + OomphBarrierThread = 2, + OomphBarrierRank = 3 + } oomph_barrier_type; + typedef enum + { + OomphAllocatorHost = 1, + OomphAllocatorDevice = 2 + } oomph_allocator_type; +}} // namespace oomph::fort diff --git a/include/oomph/barrier.hpp b/include/oomph/barrier.hpp index 2f09af17..7d56ce76 100644 --- a/include/oomph/barrier.hpp +++ b/include/oomph/barrier.hpp @@ -12,12 +12,11 @@ #include #if OOMPH_ENABLE_BARRIER -#include -#include +# include +# include -namespace oomph -{ -/** +namespace oomph { + /** The barrier object synchronize threads or ranks, or both. When synchronizing ranks, it also progress the communicator. @@ -37,54 +36,54 @@ performed as usual. This is why the barrier is split into is_node1 and in_node2. in_node1 returns true to the thread selected to run the rank_barrier in the full barrier. 
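
A minimal usage sketch (illustrative only, not part of this change; `ctx` stands for an
already-constructed oomph::context and `n_threads` for the number of participating threads
per rank, both placeholders):

    oomph::barrier b(ctx, n_threads);

    // called by every participating thread of a rank:
    b.thread_barrier();   // synchronize only the threads within this rank

    // called by exactly one thread per rank:
    b.rank_barrier();     // synchronize ranks; also progresses the communicator while waiting

    // called by every participating thread on every rank:
    b();                  // full barrier across all threads and all ranks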
*/ -class barrier -{ - private: // members - std::size_t m_threads; - mutable std::atomic b_count{0}; - mutable std::atomic b_count2; - MPI_Comm m_mpi_comm; - context_impl const* m_context; + class barrier + { + private: // members + std::size_t m_threads; + mutable std::atomic b_count{0}; + mutable std::atomic b_count2; + MPI_Comm m_mpi_comm; + context_impl const* m_context; - friend class test_barrier; + friend class test_barrier; - public: // ctors - barrier(context const& c, size_t n_threads = 1); - barrier(const barrier&) = delete; - barrier(barrier&&) = delete; + public: // ctors + barrier(context const& c, size_t n_threads = 1); + barrier(barrier const&) = delete; + barrier(barrier&&) = delete; - public: // public member functions - int size() const noexcept { return m_threads; } + public: // public member functions + int size() const noexcept { return m_threads; } - /** This is the most general barrier, it synchronize threads and ranks. */ - void operator()() const; + /** This is the most general barrier, it synchronize threads and ranks. */ + void operator()() const; - /** + /** * This function can be used to synchronize ranks. * Only one thread per rank must call this function. * If other threads exist, they hace to be synchronized separately, * maybe using the in_node function. */ - void rank_barrier() const; + void rank_barrier() const; - /** + /** * This function synchronize the threads in a rank. The number of threads that need to participate * is indicated in the construction of the barrier object, whose reference is shared among the * participating threads. */ - void thread_barrier() const - { - in_node1(); - in_node2(); - } + void thread_barrier() const + { + in_node1(); + in_node2(); + } - private: - bool in_node1() const; + private: + bool in_node1() const; - void in_node2() const; -}; + void in_node2() const; + }; -} // namespace oomph +} // namespace oomph #else -#pragma message("barrier is not enabled in this configuration") -#endif // OOMPH_ENABLE_BARRIER +# pragma message("barrier is not enabled in this configuration") +#endif // OOMPH_ENABLE_BARRIER diff --git a/include/oomph/channel/channel.hpp b/include/oomph/channel/channel.hpp index b7279a33..1b3e51ce 100644 --- a/include/oomph/channel/channel.hpp +++ b/include/oomph/channel/channel.hpp @@ -9,5 +9,5 @@ */ #pragma once -#include #include +#include diff --git a/include/oomph/channel/recv_channel.hpp b/include/oomph/channel/recv_channel.hpp index 8ac22d2e..f29deb46 100644 --- a/include/oomph/channel/recv_channel.hpp +++ b/include/oomph/channel/recv_channel.hpp @@ -9,150 +9,148 @@ */ #pragma once -#include #include +#include -namespace oomph -{ -class recv_channel_impl; - -void release_recv_channel_buffer(recv_channel_impl*, std::size_t); +namespace oomph { + class recv_channel_impl; -class recv_channel_base -{ - protected: - util::heap_pimpl m_impl; + void release_recv_channel_buffer(recv_channel_impl*, std::size_t); - recv_channel_base(communicator& comm, std::size_t size, std::size_t T_size, - communicator::rank_type src, communicator::tag_type tag, std::size_t levels); + class recv_channel_base + { + protected: + util::heap_pimpl m_impl; - ~recv_channel_base(); + recv_channel_base(communicator& comm, std::size_t size, std::size_t T_size, + communicator::rank_type src, communicator::tag_type tag, std::size_t levels); - void* get(std::size_t& index); + ~recv_channel_base(); - recv_channel_impl* get_impl() noexcept; + void* get(std::size_t& index); - public: - void connect(); + recv_channel_impl* get_impl() 
noexcept; - std::size_t capacity(); -}; + public: + void connect(); -template -class recv_channel : public recv_channel_base -{ - using base = recv_channel_base; + std::size_t capacity(); + }; - public: - class buffer + template + class recv_channel : public recv_channel_base { - // message_buffer m_buffer; - friend class recv_channel; - - private: - T* m_ptr = nullptr; - std::size_t m_size; - std::size_t m_index; - recv_channel_impl* m_recv_channel_impl; - - public: - buffer() = default; - - buffer(buffer&& other) - : m_ptr{std::exchange(other.m_ptr, nullptr)} - , m_size{other.m_size} - , m_index{other.m_index} - , m_recv_channel_impl{other.m_recv_channel_impl} - { - } + using base = recv_channel_base; - buffer& operator=(buffer&& other) + public: + class buffer { - release(); - m_ptr = std::exchange(other.m_ptr, nullptr); - m_size = other.m_size; - m_index = other.m_index; - m_recv_channel_impl = other.m_recv_channel_impl; - return *this; - } + // message_buffer m_buffer; + friend class recv_channel; + + private: + T* m_ptr = nullptr; + std::size_t m_size; + std::size_t m_index; + recv_channel_impl* m_recv_channel_impl; + + public: + buffer() = default; + + buffer(buffer&& other) + : m_ptr{std::exchange(other.m_ptr, nullptr)} + , m_size{other.m_size} + , m_index{other.m_index} + , m_recv_channel_impl{other.m_recv_channel_impl} + { + } - ~buffer() { release(); } + buffer& operator=(buffer&& other) + { + release(); + m_ptr = std::exchange(other.m_ptr, nullptr); + m_size = other.m_size; + m_index = other.m_index; + m_recv_channel_impl = other.m_recv_channel_impl; + return *this; + } - private: - buffer(T* ptr, std::size_t size_, std::size_t index, recv_channel_impl* rc) - : m_ptr{ptr} - , m_size{size_} - , m_index{index} - , m_recv_channel_impl{rc} - { - } + ~buffer() { release(); } + + private: + buffer(T* ptr, std::size_t size_, std::size_t index, recv_channel_impl* rc) + : m_ptr{ptr} + , m_size{size_} + , m_index{index} + , m_recv_channel_impl{rc} + { + } - public: - operator bool() const noexcept { return m_ptr; } + public: + operator bool() const noexcept { return m_ptr; } - std::size_t size() const noexcept { return m_size; } + std::size_t size() const noexcept { return m_size; } - T* data() noexcept { return m_ptr; } - T const* data() const noexcept { return m_ptr; } - T* begin() noexcept { return data(); } - T const* begin() const noexcept { return data(); } - T* end() noexcept { return data() + size(); } - T const* end() const noexcept { return data() + size(); } - T const* cbegin() const noexcept { return data(); } - T const* cend() const noexcept { return data() + size(); } + T* data() noexcept { return m_ptr; } + T const* data() const noexcept { return m_ptr; } + T* begin() noexcept { return data(); } + T const* begin() const noexcept { return data(); } + T* end() noexcept { return data() + size(); } + T const* end() const noexcept { return data() + size(); } + T const* cbegin() const noexcept { return data(); } + T const* cend() const noexcept { return data() + size(); } - void release() - { - if (m_ptr) + void release() { - release_recv_channel_buffer(m_recv_channel_impl, m_index); - m_ptr = nullptr; + if (m_ptr) + { + release_recv_channel_buffer(m_recv_channel_impl, m_index); + m_ptr = nullptr; + } } + }; + + //class request + //{ + // class impl; + // bool is_ready_local(); + // bool is_ready_remote(); + // void wait_local(); + // void wait_remote(); + //}; + + private: + std::size_t m_size; + + public: + recv_channel(communicator& comm, std::size_t size, 
communicator::rank_type src, + communicator::tag_type tag, std::size_t levels) + : base(comm, size, sizeof(T), src, tag, levels) + , m_size{size} + { } - }; - - //class request - //{ - // class impl; - // bool is_ready_local(); - // bool is_ready_remote(); - // void wait_local(); - // void wait_remote(); - //}; - - private: - std::size_t m_size; - - public: - recv_channel(communicator& comm, std::size_t size, communicator::rank_type src, - communicator::tag_type tag, std::size_t levels) - : base(comm, size, sizeof(T), src, tag, levels) - , m_size{size} - { - } - recv_channel(recv_channel const&) = delete; - recv_channel(recv_channel&&) = default; + recv_channel(recv_channel const&) = delete; + recv_channel(recv_channel&&) = default; - //void connect(); + //void connect(); - //std::size_t capacity() const noexcept; + //std::size_t capacity() const noexcept; - //buffer make_buffer(); + //buffer make_buffer(); - buffer get() - { - T* ptr = nullptr; - std::size_t index; - do + buffer get() { - ptr = (T*)base::get(index); + T* ptr = nullptr; + std::size_t index; + do { + ptr = (T*) base::get(index); + } + //while (!ptr); + while (false); + return {ptr, m_size, index, base::get_impl()}; } - //while (!ptr); - while (false); - return {ptr, m_size, index, base::get_impl()}; - } - //request put(buffer& b); -}; + //request put(buffer& b); + }; -} // namespace oomph +} // namespace oomph diff --git a/include/oomph/channel/send_channel.hpp b/include/oomph/channel/send_channel.hpp index c6fb75d7..0178cb5d 100644 --- a/include/oomph/channel/send_channel.hpp +++ b/include/oomph/channel/send_channel.hpp @@ -9,63 +9,62 @@ */ #pragma once -#include #include +#include -namespace oomph -{ -class send_channel_impl; +namespace oomph { + class send_channel_impl; -class send_channel_base -{ - protected: - util::heap_pimpl m_impl; + class send_channel_base + { + protected: + util::heap_pimpl m_impl; - send_channel_base(communicator& comm, std::size_t size, std::size_t T_size, - communicator::rank_type dst, communicator::tag_type tag, std::size_t levels); + send_channel_base(communicator& comm, std::size_t size, std::size_t T_size, + communicator::rank_type dst, communicator::tag_type tag, std::size_t levels); - ~send_channel_base(); + ~send_channel_base(); - public: - void connect(); -}; + public: + void connect(); + }; -template -class send_channel : public send_channel_base -{ - using base = send_channel_base; + template + class send_channel : public send_channel_base + { + using base = send_channel_base; - public: - //class buffer - //{ - // message_buffer m_buffer; - //}; + public: + //class buffer + //{ + // message_buffer m_buffer; + //}; - //class request - //{ - // class impl; - // bool is_ready_local(); - // bool is_ready_remote(); - // void wait_local(); - // void wait_remote(); - //}; + //class request + //{ + // class impl; + // bool is_ready_local(); + // bool is_ready_remote(); + // void wait_local(); + // void wait_remote(); + //}; - public: - send_channel(communicator& comm, std::size_t size, communicator::rank_type dst, - communicator::tag_type tag, std::size_t levels) - : base(comm, size, sizeof(T), dst, tag, levels) - { - } - send_channel(send_channel const&) = delete; - send_channel(send_channel&&) = default; + public: + send_channel(communicator& comm, std::size_t size, communicator::rank_type dst, + communicator::tag_type tag, std::size_t levels) + : base(comm, size, sizeof(T), dst, tag, levels) + { + } + send_channel(send_channel const&) = delete; + send_channel(send_channel&&) = default; - 
//void connect(); + //void connect(); - //std::size_t capacity() const noexcept; + //std::size_t capacity() const noexcept; - //buffer make_buffer(); + //buffer make_buffer(); - //request put(buffer& b); -}; + //request put(buffer& b); + }; -} // namespace oomph +} // namespace oomph diff --git a/include/oomph/communicator.hpp b/include/oomph/communicator.hpp index 71d9908c..e671c934 100644 --- a/include/oomph/communicator.hpp +++ b/include/oomph/communicator.hpp @@ -9,503 +9,497 @@ */ #pragma once -#include -#include #include -#include #include +#include +#include #include #include -#include #include +#include #include #include +#include -namespace oomph -{ +namespace oomph { -class context; + class context; -class communicator -{ - private: - friend class context; + class communicator + { + private: + friend class context; - public: - using impl_type = communicator_impl; + public: + using impl_type = communicator_impl; - public: - static constexpr rank_type any_source = -1; - static constexpr tag_type any_tag = -1; + public: + static constexpr rank_type any_source = -1; + static constexpr tag_type any_tag = -1; - private: - template - struct cb_rref - { - std::decay_t cb; - message_buffer m; + private: + template + struct cb_rref + { + std::decay_t cb; + message_buffer m; - void operator()(rank_type r, tag_type t) { cb(std::move(m), r, t); } - }; + void operator()(rank_type r, tag_type t) { cb(std::move(m), r, t); } + }; - template - struct cb_lref - { - std::decay_t cb; - message_buffer* m; + template + struct cb_lref + { + std::decay_t cb; + message_buffer* m; - void operator()(rank_type r, tag_type t) { cb(*m, r, t); } - }; + void operator()(rank_type r, tag_type t) { cb(*m, r, t); } + }; - template - struct cb_lref_const - { - std::decay_t cb; - message_buffer const* m; + template + struct cb_lref_const + { + std::decay_t cb; + message_buffer const* m; - void operator()(rank_type r, tag_type t) { cb(*m, r, t); } - }; + void operator()(rank_type r, tag_type t) { cb(*m, r, t); } + }; - private: - util::unsafe_shared_ptr m_state; + private: + util::unsafe_shared_ptr m_state; - private: - communicator(impl_type* impl_, std::atomic* shared_scheduled_recvs) - : m_state{util::make_shared(impl_, shared_scheduled_recvs)} - { - } + private: + communicator(impl_type* impl_, std::atomic* shared_scheduled_recvs) + : m_state{util::make_shared(impl_, shared_scheduled_recvs)} + { + } - communicator(util::unsafe_shared_ptr s) noexcept - : m_state{s} - { - } - - public: - communicator(communicator const&) = delete; - communicator(communicator&& other) = default; - communicator& operator=(communicator const&) = delete; - communicator& operator=(communicator&& other) = default; - - public: - rank_type rank() const noexcept; - rank_type size() const noexcept; - MPI_Comm mpi_comm() const noexcept; - bool is_local(rank_type rank) const noexcept; - std::size_t scheduled_sends() const noexcept { return m_state->scheduled_sends; } - std::size_t scheduled_recvs() const noexcept { return m_state->scheduled_recvs; } - std::size_t scheduled_shared_recvs() const noexcept - { - return m_state->m_shared_scheduled_recvs->load(); - } + communicator(util::unsafe_shared_ptr s) noexcept + : m_state{s} + { + } - bool is_ready() const noexcept - { - return (scheduled_sends() == 0) && (scheduled_recvs() == 0) && - (scheduled_shared_recvs() == 0); - } + public: + communicator(communicator const&) = delete; + communicator(communicator&& other) = default; + communicator& operator=(communicator const&) = delete; + communicator& 
operator=(communicator&& other) = default; + + public: + rank_type rank() const noexcept; + rank_type size() const noexcept; + MPI_Comm mpi_comm() const noexcept; + bool is_local(rank_type rank) const noexcept; + std::size_t scheduled_sends() const noexcept { return m_state->scheduled_sends; } + std::size_t scheduled_recvs() const noexcept { return m_state->scheduled_recvs; } + std::size_t scheduled_shared_recvs() const noexcept + { + return m_state->m_shared_scheduled_recvs->load(); + } - void wait_all() - { - while (!is_ready()) { progress(); } - } + bool is_ready() const noexcept + { + return (scheduled_sends() == 0) && (scheduled_recvs() == 0) && + (scheduled_shared_recvs() == 0); + } - template - message_buffer make_buffer(std::size_t size) - { - return {make_buffer_core(size * sizeof(T)), size}; - } + void wait_all() + { + while (!is_ready()) { progress(); } + } - template - message_buffer make_buffer(T* ptr, std::size_t size) - { - return {make_buffer_core(ptr, size * sizeof(T)), size}; - } + template + message_buffer make_buffer(std::size_t size) + { + return {make_buffer_core(size * sizeof(T)), size}; + } + + template + message_buffer make_buffer(T* ptr, std::size_t size) + { + return {make_buffer_core(ptr, size * sizeof(T)), size}; + } #if OOMPH_ENABLE_DEVICE - template - message_buffer make_device_buffer(std::size_t size, int id = hwmalloc::get_device_id()) - { - return {make_buffer_core(size * sizeof(T), id), size}; - } + template + message_buffer make_device_buffer(std::size_t size, int id = hwmalloc::get_device_id()) + { + return {make_buffer_core(size * sizeof(T), id), size}; + } - template - message_buffer make_device_buffer(T* device_ptr, std::size_t size, - int id = hwmalloc::get_device_id()) - { - return {make_buffer_core(device_ptr, size * sizeof(T), id), size}; - } + template + message_buffer + make_device_buffer(T* device_ptr, std::size_t size, int id = hwmalloc::get_device_id()) + { + return {make_buffer_core(device_ptr, size * sizeof(T), id), size}; + } - template - message_buffer make_device_buffer(T* ptr, T* device_ptr, std::size_t size, - int id = hwmalloc::get_device_id()) - { - return {make_buffer_core(ptr, device_ptr, size * sizeof(T), id), size}; - } + template + message_buffer make_device_buffer( + T* ptr, T* device_ptr, std::size_t size, int id = hwmalloc::get_device_id()) + { + return {make_buffer_core(ptr, device_ptr, size * sizeof(T), id), size}; + } #endif - // no callback versions - // ==================== + // no callback versions + // ==================== - // recv - // ---- + // recv + // ---- - template - recv_request recv(message_buffer& msg, rank_type src, tag_type tag) - { - assert(msg); - return recv(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), src, tag, - util::unique_function([](rank_type, tag_type) {})); - } + template + recv_request recv(message_buffer& msg, rank_type src, tag_type tag) + { + assert(msg); + return recv(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), src, tag, + util::unique_function([](rank_type, tag_type) {})); + } - // shared_recv - // ----------- + // shared_recv + // ----------- - template - shared_recv_request shared_recv(message_buffer& msg, rank_type src, tag_type tag) - { - assert(msg); - return shared_recv(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), src, tag, - util::unique_function([](rank_type, tag_type) {})); - } + template + shared_recv_request shared_recv(message_buffer& msg, rank_type src, tag_type tag) + { + assert(msg); + return shared_recv(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), src, tag, + 
util::unique_function([](rank_type, tag_type) {})); + } - // send - // ---- + // send + // ---- - template - send_request send(message_buffer const& msg, rank_type dst, tag_type tag) - { - assert(msg); - return send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), dst, tag, - util::unique_function([](rank_type, tag_type) {})); - } + template + send_request send(message_buffer const& msg, rank_type dst, tag_type tag) + { + assert(msg); + return send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), dst, tag, + util::unique_function([](rank_type, tag_type) {})); + } - // send_multi - // ---------- + // send_multi + // ---------- - template - send_multi_request send_multi(message_buffer const& msg, rank_type const* neighs, - std::size_t neighs_size, tag_type tag) - { - assert(msg); - auto mrs = m_state->make_multi_request_state(neighs_size); - for (std::size_t i = 0; i < neighs_size; ++i) + template + send_multi_request send_multi(message_buffer const& msg, rank_type const* neighs, + std::size_t neighs_size, tag_type tag) { - send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), neighs[i], tag, - util::unique_function( - [mrs](rank_type, tag_type) { --(mrs->m_counter); })); + assert(msg); + auto mrs = m_state->make_multi_request_state(neighs_size); + for (std::size_t i = 0; i < neighs_size; ++i) + { + send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), neighs[i], tag, + util::unique_function( + [mrs](rank_type, tag_type) { --(mrs->m_counter); })); + } + return {std::move(mrs)}; } - return {std::move(mrs)}; - } - template - send_multi_request send_multi(message_buffer const& msg, - std::vector const& neighs, tag_type tag) - { - return send_multi(msg, neighs.data(), neighs.size(), tag); - } + template + send_multi_request + send_multi(message_buffer const& msg, std::vector const& neighs, tag_type tag) + { + return send_multi(msg, neighs.data(), neighs.size(), tag); + } - template - send_multi_request send_multi(message_buffer const& msg, rank_type const* neighs, - tag_type const* tags, std::size_t neighs_size) - { - assert(msg); - auto mrs = m_state->make_multi_request_state(neighs_size); - for (std::size_t i = 0; i < neighs_size; ++i) + template + send_multi_request send_multi(message_buffer const& msg, rank_type const* neighs, + tag_type const* tags, std::size_t neighs_size) { - send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), neighs[i], tags[i], - util::unique_function( - [mrs](rank_type, tag_type) { --(mrs->m_counter); })); + assert(msg); + auto mrs = m_state->make_multi_request_state(neighs_size); + for (std::size_t i = 0; i < neighs_size; ++i) + { + send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), neighs[i], tags[i], + util::unique_function( + [mrs](rank_type, tag_type) { --(mrs->m_counter); })); + } + return {std::move(mrs)}; } - return {std::move(mrs)}; - } - template - send_multi_request send_multi(message_buffer const& msg, - std::vector const& neighs, std::vector const& tags) - { - assert(neighs.size() == tags.size()); - return send_multi(msg, neighs.data(), tags.data(), neighs.size()); - } + template + send_multi_request send_multi(message_buffer const& msg, + std::vector const& neighs, std::vector const& tags) + { + assert(neighs.size() == tags.size()); + return send_multi(msg, neighs.data(), tags.data(), neighs.size()); + } - // callback versions - // ================= + // callback versions + // ================= - // recv - // ---- + // recv + // ---- - template - recv_request recv(message_buffer&& msg, rank_type src, tag_type tag, CallBack&& callback) - { - 
OOMPH_CHECK_CALLBACK(CallBack) - assert(msg); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - return recv(m_ptr, s * sizeof(T), src, tag, - util::unique_function( - cb_rref{std::forward(callback), std::move(msg)})); - } - - template - recv_request recv(message_buffer& msg, rank_type src, tag_type tag, CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_REF(CallBack) - assert(msg); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - return recv(m_ptr, s * sizeof(T), src, tag, - util::unique_function( - cb_lref{std::forward(callback), &msg})); - } - - // shared_recv - // ----------- - - template - shared_recv_request shared_recv(message_buffer&& msg, rank_type src, tag_type tag, - CallBack&& callback) - { - OOMPH_CHECK_CALLBACK(CallBack) - assert(msg); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - return shared_recv(m_ptr, s * sizeof(T), src, tag, - util::unique_function( - cb_rref{std::forward(callback), std::move(msg)})); - } - - template - shared_recv_request shared_recv(message_buffer& msg, rank_type src, tag_type tag, - CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_REF(CallBack) - assert(msg); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - return shared_recv(m_ptr, s * sizeof(T), src, tag, - util::unique_function( - cb_lref{std::forward(callback), &msg})); - } - - // send - // ---- - - template - send_request send(message_buffer&& msg, rank_type dst, tag_type tag, CallBack&& callback) - { - OOMPH_CHECK_CALLBACK(CallBack) - assert(msg); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - return send(m_ptr, s * sizeof(T), dst, tag, - util::unique_function( - cb_rref{std::forward(callback), std::move(msg)})); - } - - template - send_request send(message_buffer& msg, rank_type dst, tag_type tag, CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_REF(CallBack) - assert(msg); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - return send(m_ptr, s * sizeof(T), dst, tag, - util::unique_function( - cb_lref{std::forward(callback), &msg})); - } - - template - send_request send(message_buffer const& msg, rank_type dst, tag_type tag, - CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_CONST_REF(CallBack) - assert(msg); - const auto s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - return send(m_ptr, s * sizeof(T), dst, tag, - util::unique_function( - cb_lref_const{std::forward(callback), &msg})); - } - - // send_multi - // ---------- - - template - send_multi_request send_multi(message_buffer&& msg, std::vector neighs, - tag_type tag, CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_MULTI(CallBack) - assert(msg); - auto const s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - auto mrs = m_state->make_multi_request_state(std::move(neighs), std::move(msg)); - for (auto dst : mrs->m_neighs) - { - send(m_ptr, s * sizeof(T), dst, tag, + template + recv_request recv(message_buffer&& msg, rank_type src, tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + return recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - [mrs, callback](rank_type, tag_type t) - { - if (--(mrs->m_counter) == 0ul) - { - callback(message_buffer(std::move(mrs->m_msg), mrs->m_msg_size), - std::move(mrs->m_neighs), t); - } - })); + cb_rref{std::forward(callback), std::move(msg)})); } - return {std::move(mrs)}; - } - template - send_multi_request send_multi(message_buffer&& msg, std::vector neighs, 
- std::vector tags, CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_MULTI_TAGS(CallBack) - assert(msg); - assert(neighs.size() == tags.size()); - auto const s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - auto mrs = - m_state->make_multi_request_state(std::move(neighs), std::move(tags), std::move(msg)); - const auto n = mrs->m_neighs.size(); - for (std::size_t i = 0; i < n; ++i) - { - send(m_ptr, s * sizeof(T), mrs->m_neighs[i], mrs->m_tags[i], + template + recv_request recv(message_buffer& msg, rank_type src, tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_REF(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + return recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - [mrs, callback](rank_type, tag_type) - { - if (--(mrs->m_counter) == 0ul) - { - callback(message_buffer(std::move(mrs->m_msg), mrs->m_msg_size), - std::move(mrs->m_neighs), mrs->m_tags); - } - })); + cb_lref{std::forward(callback), &msg})); } - return {std::move(mrs)}; - } - template - send_multi_request send_multi(message_buffer& msg, std::vector neighs, - tag_type tag, CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_MULTI_REF(CallBack) - assert(msg); - auto const s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - auto mrs = m_state->make_multi_request_state(std::move(neighs), msg); - for (auto dst : mrs->m_neighs) - { - send(m_ptr, s * sizeof(T), dst, tag, + // shared_recv + // ----------- + + template + shared_recv_request + shared_recv(message_buffer&& msg, rank_type src, tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + return shared_recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - [mrs, callback](rank_type, tag_type t) - { - if (--(mrs->m_counter) == 0ul) - { - callback(*reinterpret_cast*>(mrs->m_msg_ptr), - std::move(mrs->m_neighs), t); - } - })); + cb_rref{std::forward(callback), std::move(msg)})); } - return {std::move(mrs)}; - } - template - send_multi_request send_multi(message_buffer& msg, std::vector neighs, - std::vector tags, CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_MULTI_REF_TAGS(CallBack) - assert(msg); - assert(neighs.size() == tags.size()); - auto const s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - auto mrs = m_state->make_multi_request_state(std::move(neighs), std::move(tags), msg); - const auto n = mrs->m_neighs.size(); - for (std::size_t i = 0; i < n; ++i) - { - send(m_ptr, s * sizeof(T), mrs->m_neighs[i], mrs->m_tags[i], + template + shared_recv_request + shared_recv(message_buffer& msg, rank_type src, tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_REF(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + return shared_recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - [mrs, callback](rank_type, tag_type) - { - if (--(mrs->m_counter) == 0ul) - { - callback(*reinterpret_cast*>(mrs->m_msg_ptr), - std::move(mrs->m_neighs), std::move(mrs->m_tags)); - } - })); + cb_lref{std::forward(callback), &msg})); } - return {std::move(mrs)}; - } - template - send_multi_request send_multi(message_buffer const& msg, std::vector neighs, - tag_type tag, CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_MULTI_CONST_REF(CallBack) - assert(msg); - auto const s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - auto mrs = m_state->make_multi_request_state(std::move(neighs), msg); - for (auto dst : mrs->m_neighs) - { - send(m_ptr, s * 
sizeof(T), dst, tag, + // send + // ---- + + template + send_request send(message_buffer&& msg, rank_type dst, tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + return send(m_ptr, s * sizeof(T), dst, tag, + util::unique_function( + cb_rref{std::forward(callback), std::move(msg)})); + } + + template + send_request send(message_buffer& msg, rank_type dst, tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_REF(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + return send(m_ptr, s * sizeof(T), dst, tag, util::unique_function( - [mrs, callback](rank_type, tag_type t) - { + cb_lref{std::forward(callback), &msg})); + } + + template + send_request + send(message_buffer const& msg, rank_type dst, tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_CONST_REF(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + return send(m_ptr, s * sizeof(T), dst, tag, + util::unique_function( + cb_lref_const{std::forward(callback), &msg})); + } + + // send_multi + // ---------- + + template + send_multi_request send_multi(message_buffer&& msg, std::vector neighs, + tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_MULTI(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + auto mrs = m_state->make_multi_request_state(std::move(neighs), std::move(msg)); + for (auto dst : mrs->m_neighs) + { + send(m_ptr, s * sizeof(T), dst, tag, + util::unique_function( + [mrs, callback](rank_type, tag_type t) { + if (--(mrs->m_counter) == 0ul) + { + callback(message_buffer(std::move(mrs->m_msg), mrs->m_msg_size), + std::move(mrs->m_neighs), t); + } + })); + } + return {std::move(mrs)}; + } + + template + send_multi_request send_multi(message_buffer&& msg, std::vector neighs, + std::vector tags, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_MULTI_TAGS(CallBack) + assert(msg); + assert(neighs.size() == tags.size()); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + auto mrs = m_state->make_multi_request_state( + std::move(neighs), std::move(tags), std::move(msg)); + auto const n = mrs->m_neighs.size(); + for (std::size_t i = 0; i < n; ++i) + { + send(m_ptr, s * sizeof(T), mrs->m_neighs[i], mrs->m_tags[i], + util::unique_function( + [mrs, callback](rank_type, tag_type) { + if (--(mrs->m_counter) == 0ul) + { + callback(message_buffer(std::move(mrs->m_msg), mrs->m_msg_size), + std::move(mrs->m_neighs), mrs->m_tags); + } + })); + } + return {std::move(mrs)}; + } + + template + send_multi_request send_multi(message_buffer& msg, std::vector neighs, + tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_MULTI_REF(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + auto mrs = m_state->make_multi_request_state(std::move(neighs), msg); + for (auto dst : mrs->m_neighs) + { + send(m_ptr, s * sizeof(T), dst, tag, + util::unique_function( + [mrs, callback](rank_type, tag_type t) { + if (--(mrs->m_counter) == 0ul) + { + callback(*reinterpret_cast*>(mrs->m_msg_ptr), + std::move(mrs->m_neighs), t); + } + })); + } + return {std::move(mrs)}; + } + + template + send_multi_request send_multi(message_buffer& msg, std::vector neighs, + std::vector tags, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_MULTI_REF_TAGS(CallBack) + assert(msg); + assert(neighs.size() == tags.size()); + auto const s = msg.size(); + auto m_ptr 
= msg.m.m_heap_ptr.get(); + auto mrs = m_state->make_multi_request_state(std::move(neighs), std::move(tags), msg); + auto const n = mrs->m_neighs.size(); + for (std::size_t i = 0; i < n; ++i) + { + send(m_ptr, s * sizeof(T), mrs->m_neighs[i], mrs->m_tags[i], + util::unique_function( + [mrs, callback](rank_type, tag_type) { + if (--(mrs->m_counter) == 0ul) + { + callback(*reinterpret_cast*>(mrs->m_msg_ptr), + std::move(mrs->m_neighs), std::move(mrs->m_tags)); + } + })); + } + return {std::move(mrs)}; + } + + template + send_multi_request send_multi(message_buffer const& msg, std::vector neighs, + tag_type tag, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_MULTI_CONST_REF(CallBack) + assert(msg); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + auto mrs = m_state->make_multi_request_state(std::move(neighs), msg); + for (auto dst : mrs->m_neighs) + { + send(m_ptr, s * sizeof(T), dst, tag, + util::unique_function([mrs, callback]( + rank_type, tag_type t) { if (--(mrs->m_counter) == 0ul) { callback(*reinterpret_cast const*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), t); } })); + } + return {std::move(mrs)}; } - return {std::move(mrs)}; - } - template - send_multi_request send_multi(message_buffer const& msg, std::vector neighs, - std::vector tags, CallBack&& callback) - { - OOMPH_CHECK_CALLBACK_MULTI_CONST_REF_TAGS(CallBack) - assert(msg); - assert(neighs.size() == tags.size()); - auto const s = msg.size(); - auto m_ptr = msg.m.m_heap_ptr.get(); - auto mrs = m_state->make_multi_request_state(std::move(neighs), std::move(tags), msg); - const auto n = mrs->m_neighs.size(); - for (std::size_t i = 0; i < n; ++i) - { - send(m_ptr, s * sizeof(T), mrs->m_neighs[i], mrs->m_tags[i], - util::unique_function( - [mrs, callback](rank_type, tag_type) - { + template + send_multi_request send_multi(message_buffer const& msg, std::vector neighs, + std::vector tags, CallBack&& callback) + { + OOMPH_CHECK_CALLBACK_MULTI_CONST_REF_TAGS(CallBack) + assert(msg); + assert(neighs.size() == tags.size()); + auto const s = msg.size(); + auto m_ptr = msg.m.m_heap_ptr.get(); + auto mrs = m_state->make_multi_request_state(std::move(neighs), std::move(tags), msg); + auto const n = mrs->m_neighs.size(); + for (std::size_t i = 0; i < n; ++i) + { + send(m_ptr, s * sizeof(T), mrs->m_neighs[i], mrs->m_tags[i], + util::unique_function([mrs, callback]( + rank_type, tag_type) { if (--(mrs->m_counter) == 0ul) { callback(*reinterpret_cast const*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), std::move(mrs->m_tags)); } })); + } + return {std::move(mrs)}; } - return {std::move(mrs)}; - } - void progress(); + void progress(); - private: - detail::message_buffer make_buffer_core(std::size_t size); - detail::message_buffer make_buffer_core(void* ptr, std::size_t size); + private: + detail::message_buffer make_buffer_core(std::size_t size); + detail::message_buffer make_buffer_core(void* ptr, std::size_t size); #if OOMPH_ENABLE_DEVICE - detail::message_buffer make_buffer_core(std::size_t size, int device_id); - detail::message_buffer make_buffer_core(void* device_ptr, std::size_t size, int device_id); - detail::message_buffer make_buffer_core(void* ptr, void* device_ptr, std::size_t size, - int device_id); + detail::message_buffer make_buffer_core(std::size_t size, int device_id); + detail::message_buffer make_buffer_core(void* device_ptr, std::size_t size, int device_id); + detail::message_buffer make_buffer_core( + void* ptr, void* device_ptr, std::size_t size, int device_id); #endif - send_request 
send(detail::message_buffer::heap_ptr_impl const* m_ptr, std::size_t size, - rank_type dst, tag_type tag, util::unique_function&& cb); + send_request send(detail::message_buffer::heap_ptr_impl const* m_ptr, std::size_t size, + rank_type dst, tag_type tag, util::unique_function&& cb); - recv_request recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb); + recv_request recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, + rank_type src, tag_type tag, util::unique_function&& cb); - shared_recv_request shared_recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, - rank_type src, tag_type tag, util::unique_function&& cb); -}; + shared_recv_request shared_recv(detail::message_buffer::heap_ptr_impl* m_ptr, + std::size_t size, rank_type src, tag_type tag, + util::unique_function&& cb); + }; -} // namespace oomph +} // namespace oomph diff --git a/include/oomph/context.hpp b/include/oomph/context.hpp index ac5f66f4..97ed4807 100644 --- a/include/oomph/context.hpp +++ b/include/oomph/context.hpp @@ -11,116 +11,115 @@ #include #include +#include #include #include -#include -#include #include +#include -namespace oomph -{ -class context_impl; -class barrier; -class context -{ - friend class barrier; - - public: - using pimpl = util::heap_pimpl; - - public: - struct schedule +namespace oomph { + class context_impl; + class barrier; + class context { - std::atomic scheduled_sends = 0; - std::atomic scheduled_recvs = 0; - }; + friend class barrier; - private: - util::mpi_comm_holder m_mpi_comm; - pimpl m; - std::unique_ptr m_schedule; + public: + using pimpl = util::heap_pimpl; - public: - context(MPI_Comm comm, bool thread_safe = true, bool message_pool_never_free = false, - std::size_t message_pool_reserve = 1); + public: + struct schedule + { + std::atomic scheduled_sends = 0; + std::atomic scheduled_recvs = 0; + }; - context(context const&) = delete; + private: + util::mpi_comm_holder m_mpi_comm; + pimpl m; + std::unique_ptr m_schedule; - context(context&&) noexcept = default; + public: + context(MPI_Comm comm, bool thread_safe = true, bool message_pool_never_free = false, + std::size_t message_pool_reserve = 1); - context& operator=(context const&) = delete; + context(context const&) = delete; - context& operator=(context&&) noexcept = default; + context(context&&) noexcept = default; - ~context(); + context& operator=(context const&) = delete; - public: - rank_type rank() const noexcept; + context& operator=(context&&) noexcept = default; - rank_type size() const noexcept; + ~context(); - rank_type local_rank() const noexcept; + public: + rank_type rank() const noexcept; - rank_type local_size() const noexcept; + rank_type size() const noexcept; - MPI_Comm mpi_comm() const noexcept { return m_mpi_comm.get(); } + rank_type local_rank() const noexcept; - template - message_buffer make_buffer(std::size_t size) - { - return {make_buffer_core(size * sizeof(T)), size}; - } + rank_type local_size() const noexcept; - template - message_buffer make_buffer(T* ptr, std::size_t size) - { - return {make_buffer_core(ptr, size * sizeof(T)), size}; - } + MPI_Comm mpi_comm() const noexcept { return m_mpi_comm.get(); } -#if OOMPH_ENABLE_DEVICE - template - message_buffer make_device_buffer(std::size_t size, int id = hwmalloc::get_device_id()) - { - return {make_buffer_core(size * sizeof(T), id), size}; - } + template + message_buffer make_buffer(std::size_t size) + { + return {make_buffer_core(size * sizeof(T)), 
size}; + } - template - message_buffer make_device_buffer(T* device_ptr, std::size_t size, - int id = hwmalloc::get_device_id()) - { - return {make_buffer_core(device_ptr, size * sizeof(T), id), size}; - } + template + message_buffer make_buffer(T* ptr, std::size_t size) + { + return {make_buffer_core(ptr, size * sizeof(T)), size}; + } - template - message_buffer make_device_buffer(T* ptr, T* device_ptr, std::size_t size, - int id = hwmalloc::get_device_id()) - { - return {make_buffer_core(ptr, device_ptr, size * sizeof(T), id), size}; - } +#if OOMPH_ENABLE_DEVICE + template + message_buffer make_device_buffer(std::size_t size, int id = hwmalloc::get_device_id()) + { + return {make_buffer_core(size * sizeof(T), id), size}; + } + + template + message_buffer + make_device_buffer(T* device_ptr, std::size_t size, int id = hwmalloc::get_device_id()) + { + return {make_buffer_core(device_ptr, size * sizeof(T), id), size}; + } + + template + message_buffer make_device_buffer( + T* ptr, T* device_ptr, std::size_t size, int id = hwmalloc::get_device_id()) + { + return {make_buffer_core(ptr, device_ptr, size * sizeof(T), id), size}; + } #endif - communicator get_communicator(); //unsigned int tag_range = 0); + communicator get_communicator(); //unsigned int tag_range = 0); - //unsigned int num_tag_ranges() const noexcept { return m_tag_range_factory.num_ranges(); } + //unsigned int num_tag_ranges() const noexcept { return m_tag_range_factory.num_ranges(); } - const char* get_transport_option(const std::string& opt); + char const* get_transport_option(std::string const& opt); - private: - detail::message_buffer make_buffer_core(std::size_t size); - detail::message_buffer make_buffer_core(void* ptr, std::size_t size); + private: + detail::message_buffer make_buffer_core(std::size_t size); + detail::message_buffer make_buffer_core(void* ptr, std::size_t size); #if OOMPH_ENABLE_DEVICE - detail::message_buffer make_buffer_core(std::size_t size, int device_id); - detail::message_buffer make_buffer_core(void* device_ptr, std::size_t size, int device_id); - detail::message_buffer make_buffer_core(void* ptr, void* device_ptr, std::size_t size, - int device_id); + detail::message_buffer make_buffer_core(std::size_t size, int device_id); + detail::message_buffer make_buffer_core(void* device_ptr, std::size_t size, int device_id); + detail::message_buffer make_buffer_core( + void* ptr, void* device_ptr, std::size_t size, int device_id); #endif -}; + }; -template -typename Context::region_type register_memory(Context&, void*, std::size_t); + template + typename Context::region_type register_memory(Context&, void*, std::size_t); #if OOMPH_ENABLE_DEVICE -template -typename Context::device_region_type register_device_memory(Context&, int, void*, std::size_t); + template + typename Context::device_region_type register_device_memory(Context&, int, void*, std::size_t); #endif -} // namespace oomph +} // namespace oomph diff --git a/include/oomph/detail/communicator_helper.hpp b/include/oomph/detail/communicator_helper.hpp index 6e0e97d5..8335c6eb 100644 --- a/include/oomph/detail/communicator_helper.hpp +++ b/include/oomph/detail/communicator_helper.hpp @@ -11,6 +11,7 @@ #include #include +#include #include #include //#include @@ -33,7 +34,7 @@ #define OOMPH_CHECK_CALLBACK_MSG_REF \ static_assert(std::is_same&>::value || \ - std::is_same const&>::value, \ + std::is_same const&>::value, \ "first callback argument type is not an l-value reference to a message_buffer"); #define OOMPH_CHECK_CALLBACK_MSG_CONST_REF \ @@ 
-41,129 +42,107 @@ "first callback argument type is not a const l-value reference to a message_buffer"); #define OOMPH_CHECK_CALLBACK(CALLBACK) \ - { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) \ - OOMPH_CHECK_CALLBACK_MSG \ - } + {OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) OOMPH_CHECK_CALLBACK_MSG} #define OOMPH_CHECK_CALLBACK_MULTI(CALLBACK) \ - { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, tag_type) \ - OOMPH_CHECK_CALLBACK_MSG \ - } + {OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, tag_type) OOMPH_CHECK_CALLBACK_MSG} #define OOMPH_CHECK_CALLBACK_MULTI_TAGS(CALLBACK) \ - { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, std::vector) \ - OOMPH_CHECK_CALLBACK_MSG \ - } + {OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, std::vector) \ + OOMPH_CHECK_CALLBACK_MSG} #define OOMPH_CHECK_CALLBACK_REF(CALLBACK) \ - { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) \ - OOMPH_CHECK_CALLBACK_MSG_REF \ - } + {OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) OOMPH_CHECK_CALLBACK_MSG_REF} #define OOMPH_CHECK_CALLBACK_MULTI_REF(CALLBACK) \ - { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, tag_type) \ - OOMPH_CHECK_CALLBACK_MSG_REF \ - } + {OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, tag_type) \ + OOMPH_CHECK_CALLBACK_MSG_REF} #define OOMPH_CHECK_CALLBACK_MULTI_REF_TAGS(CALLBACK) \ - { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, std::vector) \ - OOMPH_CHECK_CALLBACK_MSG_REF \ - } + {OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, std::vector) \ + OOMPH_CHECK_CALLBACK_MSG_REF} #define OOMPH_CHECK_CALLBACK_CONST_REF(CALLBACK) \ - { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) \ - OOMPH_CHECK_CALLBACK_MSG_CONST_REF \ - } + {OOMPH_CHECK_CALLBACK_F(CALLBACK, rank_type, tag_type) OOMPH_CHECK_CALLBACK_MSG_CONST_REF} #define OOMPH_CHECK_CALLBACK_MULTI_CONST_REF(CALLBACK) \ - { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, tag_type) \ - OOMPH_CHECK_CALLBACK_MSG_CONST_REF \ - } + {OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, tag_type) \ + OOMPH_CHECK_CALLBACK_MSG_CONST_REF} #define OOMPH_CHECK_CALLBACK_MULTI_CONST_REF_TAGS(CALLBACK) \ - { \ - OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, std::vector) \ - OOMPH_CHECK_CALLBACK_MSG_CONST_REF \ - } - -namespace oomph -{ -class communicator_impl; - -namespace detail -{ -struct communicator_state -{ - using impl_type = communicator_impl; - impl_type* m_impl; - std::atomic* m_shared_scheduled_recvs; - util::pool_factory m_mrs_factory; - std::size_t scheduled_sends = 0; - std::size_t scheduled_recvs = 0; - - communicator_state(impl_type* impl_, std::atomic* shared_scheduled_recvs); - ~communicator_state(); - communicator_state(communicator_state const&) = delete; - communicator_state(communicator_state&&) = delete; - communicator_state& operator=(communicator_state const&) = delete; - communicator_state& operator=(communicator_state&&) = delete; - - auto make_multi_request_state(std::size_t ns) { return m_mrs_factory.make(m_impl, ns); } - - template - auto make_multi_request_state(std::vector&& neighs, - oomph::message_buffer const& msg) - { - return m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), std::vector{}, - msg.size(), &msg); - } - - template - auto make_multi_request_state(std::vector&& neighs, std::vector&& tags, - oomph::message_buffer const& msg) - { - return m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), std::move(tags), - msg.size(), &msg); - } - - template - auto make_multi_request_state(std::vector&& neighs, oomph::message_buffer& msg) - { - return 
m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), std::vector{}, - msg.size(), &msg); - } - - template - auto make_multi_request_state(std::vector&& neighs, std::vector&& tags, - oomph::message_buffer& msg) - { - return m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), std::move(tags), - msg.size(), &msg); - } - - template - auto make_multi_request_state(std::vector&& neighs, oomph::message_buffer&& msg) - { - return m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), std::vector{}, - msg.size(), nullptr, std::move(msg.m)); - } - - template - auto make_multi_request_state(std::vector&& neighs, std::vector&& tags, - oomph::message_buffer&& msg) - { - return m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), std::move(tags), - msg.size(), nullptr, std::move(msg.m)); - } -}; - -} // namespace detail -} // namespace oomph + {OOMPH_CHECK_CALLBACK_F(CALLBACK, std::vector, std::vector) \ + OOMPH_CHECK_CALLBACK_MSG_CONST_REF} + +namespace oomph { + class communicator_impl; + + namespace detail { + struct communicator_state + { + using impl_type = communicator_impl; + impl_type* m_impl; + std::atomic* m_shared_scheduled_recvs; + util::pool_factory m_mrs_factory; + std::size_t scheduled_sends = 0; + std::size_t scheduled_recvs = 0; + + communicator_state(impl_type* impl_, std::atomic* shared_scheduled_recvs); + ~communicator_state(); + communicator_state(communicator_state const&) = delete; + communicator_state(communicator_state&&) = delete; + communicator_state& operator=(communicator_state const&) = delete; + communicator_state& operator=(communicator_state&&) = delete; + + auto make_multi_request_state(std::size_t ns) { return m_mrs_factory.make(m_impl, ns); } + + template + auto make_multi_request_state( + std::vector&& neighs, oomph::message_buffer const& msg) + { + return m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), + std::vector{}, msg.size(), &msg); + } + + template + auto make_multi_request_state(std::vector&& neighs, + std::vector&& tags, oomph::message_buffer const& msg) + { + return m_mrs_factory.make( + m_impl, neighs.size(), std::move(neighs), std::move(tags), msg.size(), &msg); + } + + template + auto + make_multi_request_state(std::vector&& neighs, oomph::message_buffer& msg) + { + return m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), + std::vector{}, msg.size(), &msg); + } + + template + auto make_multi_request_state(std::vector&& neighs, + std::vector&& tags, oomph::message_buffer& msg) + { + return m_mrs_factory.make( + m_impl, neighs.size(), std::move(neighs), std::move(tags), msg.size(), &msg); + } + + template + auto make_multi_request_state( + std::vector&& neighs, oomph::message_buffer&& msg) + { + return m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), + std::vector{}, msg.size(), nullptr, std::move(msg.m)); + } + + template + auto make_multi_request_state(std::vector&& neighs, + std::vector&& tags, oomph::message_buffer&& msg) + { + return m_mrs_factory.make(m_impl, neighs.size(), std::move(neighs), std::move(tags), + msg.size(), nullptr, std::move(msg.m)); + } + }; + + } // namespace detail +} // namespace oomph diff --git a/include/oomph/detail/message_buffer.hpp b/include/oomph/detail/message_buffer.hpp index 8f8408da..07ffa0ca 100644 --- a/include/oomph/detail/message_buffer.hpp +++ b/include/oomph/detail/message_buffer.hpp @@ -12,61 +12,56 @@ #include #include -#include #include +#include -namespace oomph -{ -namespace detail -{ -class message_buffer -{ - public: - class heap_ptr_impl; 
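// Illustration (not part of the patch): the multi-request state made above is
// shared by all per-neighbor sends of one send_multi call; each completion
// decrements m_counter and the user callback fires exactly once, when the
// counter reaches zero.  A stripped-down sketch of that pattern, using
// illustrative names only:
#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

inline std::vector<std::function<void()>> fan_out(std::size_t n, std::function<void()> on_all_done)
{
    auto counter = std::make_shared<std::size_t>(n);
    std::vector<std::function<void()>> handlers;
    for (std::size_t i = 0; i < n; ++i)
        handlers.push_back([counter, on_all_done] {
            if (--(*counter) == 0ul) on_all_done();    // only the last completion fires it
        });
    return handlers;    // the transport would invoke each handler exactly once
}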
- //using pimpl = util::pimpl; - using pimpl = util::heap_pimpl; - - public: - void* m_ptr = nullptr; - pimpl m_heap_ptr; - - public: - ~message_buffer(); - - message_buffer() noexcept = default; - - template - message_buffer(VoidPtr ptr) - : m_ptr{ptr.get()} - , m_heap_ptr(ptr) +namespace oomph { namespace detail { + class message_buffer { - } + public: + class heap_ptr_impl; + //using pimpl = util::pimpl; + using pimpl = util::heap_pimpl; - message_buffer(message_buffer&& other) noexcept - : m_ptr{std::exchange(other.m_ptr, nullptr)} - , m_heap_ptr{std::move(other.m_heap_ptr)} - { - } + public: + void* m_ptr = nullptr; + pimpl m_heap_ptr; + + public: + ~message_buffer(); - message_buffer& operator=(message_buffer&&); + message_buffer() noexcept = default; - operator bool() const noexcept { return m_ptr; } + template + message_buffer(VoidPtr ptr) + : m_ptr{ptr.get()} + , m_heap_ptr(ptr) + { + } - bool on_device() const noexcept; + message_buffer(message_buffer&& other) noexcept + : m_ptr{std::exchange(other.m_ptr, nullptr)} + , m_heap_ptr{std::move(other.m_heap_ptr)} + { + } + + message_buffer& operator=(message_buffer&&); + + operator bool() const noexcept { return m_ptr; } + + bool on_device() const noexcept; #if OOMPH_ENABLE_DEVICE - void* device_data() noexcept; - void const* device_data() const noexcept; + void* device_data() noexcept; + void const* device_data() const noexcept; - int device_id() const noexcept; + int device_id() const noexcept; - void clone_to_device(std::size_t count); - void clone_to_host(std::size_t count); + void clone_to_device(std::size_t count); + void clone_to_host(std::size_t count); #endif - void clear(); -}; - -} // namespace detail + void clear(); + }; -} // namespace oomph +}} // namespace oomph::detail diff --git a/include/oomph/message_buffer.hpp b/include/oomph/message_buffer.hpp index a4527d5d..b1039f4b 100644 --- a/include/oomph/message_buffer.hpp +++ b/include/oomph/message_buffer.hpp @@ -10,88 +10,86 @@ #pragma once #include -#include #include #include +#include -namespace oomph -{ - -namespace detail -{ -struct communicator_state; -} - -template -class message_buffer -{ - public: - using value_type = T; - - private: - friend class context; - friend class communicator; - friend struct detail::communicator_state; - - private: - detail::message_buffer m; - std::size_t m_size; - - private: - message_buffer(detail::message_buffer&& m_, std::size_t size_) - : m{std::move(m_)} - , m_size{size_} - { - } - - public: - message_buffer() = default; - message_buffer(message_buffer&&) = default; - message_buffer& operator=(message_buffer&&) = default; +namespace oomph { - template - message_buffer(message_buffer&& other) noexcept - : m{std::move(other.m)} - , m_size{(other.m_size * sizeof(U)) / sizeof(T)} - { + namespace detail { + struct communicator_state; } - template - message_buffer& operator=(message_buffer&& other) noexcept + template + class message_buffer { - m = std::move(other.m); - m_size = (other.m_size * sizeof(U)) / sizeof(T); - return *this; - } - - public: - operator bool() const noexcept { return m; } - - std::size_t size() const noexcept { return m_size; } - - T* data() noexcept { return (T*)m.m_ptr; } - T const* data() const noexcept { return (T const*)m.m_ptr; } - T* begin() noexcept { return data(); } - T const* begin() const noexcept { return data(); } - T* end() noexcept { return data() + size(); } - T const* end() const noexcept { return data() + size(); } - T const* cbegin() const noexcept { return data(); } - T const* cend() const 
noexcept { return data() + size(); } - - T& operator[](std::size_t i) noexcept { return *(data() + i); } - T const& operator[](std::size_t i) const noexcept { return *(data() + i); } - - bool on_device() const noexcept { return m.on_device(); } + public: + using value_type = T; + + private: + friend class context; + friend class communicator; + friend struct detail::communicator_state; + + private: + detail::message_buffer m; + std::size_t m_size; + + private: + message_buffer(detail::message_buffer&& m_, std::size_t size_) + : m{std::move(m_)} + , m_size{size_} + { + } + + public: + message_buffer() = default; + message_buffer(message_buffer&&) = default; + message_buffer& operator=(message_buffer&&) = default; + + template + message_buffer(message_buffer&& other) noexcept + : m{std::move(other.m)} + , m_size{(other.m_size * sizeof(U)) / sizeof(T)} + { + } + + template + message_buffer& operator=(message_buffer&& other) noexcept + { + m = std::move(other.m); + m_size = (other.m_size * sizeof(U)) / sizeof(T); + return *this; + } + + public: + operator bool() const noexcept { return m; } + + std::size_t size() const noexcept { return m_size; } + + T* data() noexcept { return (T*) m.m_ptr; } + T const* data() const noexcept { return (T const*) m.m_ptr; } + T* begin() noexcept { return data(); } + T const* begin() const noexcept { return data(); } + T* end() noexcept { return data() + size(); } + T const* end() const noexcept { return data() + size(); } + T const* cbegin() const noexcept { return data(); } + T const* cend() const noexcept { return data() + size(); } + + T& operator[](std::size_t i) noexcept { return *(data() + i); } + T const& operator[](std::size_t i) const noexcept { return *(data() + i); } + + bool on_device() const noexcept { return m.on_device(); } #if OOMPH_ENABLE_DEVICE - T* device_data() noexcept { return (T*)m.device_data(); } - T const* device_data() const noexcept { return (T*)m.device_data(); } + T* device_data() noexcept { return (T*) m.device_data(); } + T const* device_data() const noexcept { return (T*) m.device_data(); } - int device_id() const noexcept { return m.device_id(); } + int device_id() const noexcept { return m.device_id(); } - void clone_to_device() { m.clone_to_device(m_size * sizeof(T)); } - void clone_to_host() { m.clone_to_host(m_size * sizeof(T)); } + void clone_to_device() { m.clone_to_device(m_size * sizeof(T)); } + void clone_to_host() { m.clone_to_host(m_size * sizeof(T)); } #endif -}; + }; -} // namespace oomph +} // namespace oomph diff --git a/include/oomph/request.hpp b/include/oomph/request.hpp index 3f9a8e53..a69d46ef 100644 --- a/include/oomph/request.hpp +++ b/include/oomph/request.hpp @@ -10,144 +10,142 @@ #pragma once #include -#include #include #include +#include + +namespace oomph { + + class communicator_impl; + + namespace detail { + // fwd declarations + struct request_state; + struct shared_request_state; + + struct multi_request_state + { + communicator_impl* m_comm; + std::size_t m_counter; + std::vector m_neighs = std::vector(); + std::vector m_tags = std::vector(); + std::size_t m_msg_size = 0ul; + void* m_msg_ptr = nullptr; + oomph::detail::message_buffer m_msg = oomph::detail::message_buffer(); + }; + } // namespace detail -namespace oomph -{ - -class communicator_impl; - -namespace detail -{ -// fwd declarations -struct request_state; -struct shared_request_state; - -struct multi_request_state -{ - communicator_impl* m_comm; - std::size_t m_counter; - std::vector m_neighs = std::vector(); - std::vector m_tags = 
std::vector(); - std::size_t m_msg_size = 0ul; - void* m_msg_ptr = nullptr; - oomph::detail::message_buffer m_msg = oomph::detail::message_buffer(); -}; -} // namespace detail - -class send_request -{ - protected: - using state_type = detail::request_state; - friend class communicator; - friend class communicator_impl; - - util::unsafe_shared_ptr m; - - send_request(util::unsafe_shared_ptr s) noexcept - : m{std::move(s)} + class send_request { - } - - public: - send_request() = default; - send_request(send_request const&) = delete; - send_request(send_request&&) = default; - send_request& operator=(send_request const&) = delete; - send_request& operator=(send_request&&) = default; - - public: - bool is_ready() const noexcept; - bool test(); - void wait(); -}; - -class recv_request -{ - protected: - using state_type = detail::request_state; - friend class communicator; - friend class communicator_impl; - - util::unsafe_shared_ptr m; - - recv_request(util::unsafe_shared_ptr s) noexcept - : m{std::move(s)} + protected: + using state_type = detail::request_state; + friend class communicator; + friend class communicator_impl; + + util::unsafe_shared_ptr m; + + send_request(util::unsafe_shared_ptr s) noexcept + : m{std::move(s)} + { + } + + public: + send_request() = default; + send_request(send_request const&) = delete; + send_request(send_request&&) = default; + send_request& operator=(send_request const&) = delete; + send_request& operator=(send_request&&) = default; + + public: + bool is_ready() const noexcept; + bool test(); + void wait(); + }; + + class recv_request { - } - - public: - recv_request() = default; - recv_request(recv_request const&) = delete; - recv_request(recv_request&&) = default; - recv_request& operator=(recv_request const&) = delete; - recv_request& operator=(recv_request&&) = default; - - public: - bool is_ready() const noexcept; - bool is_canceled() const noexcept; - bool test(); - void wait(); - bool cancel(); -}; - -class shared_recv_request -{ - private: - using state_type = detail::shared_request_state; - friend class communicator; - friend class communicator_impl; - - private: - std::shared_ptr m; - - shared_recv_request(std::shared_ptr s) noexcept - : m{std::move(s)} + protected: + using state_type = detail::request_state; + friend class communicator; + friend class communicator_impl; + + util::unsafe_shared_ptr m; + + recv_request(util::unsafe_shared_ptr s) noexcept + : m{std::move(s)} + { + } + + public: + recv_request() = default; + recv_request(recv_request const&) = delete; + recv_request(recv_request&&) = default; + recv_request& operator=(recv_request const&) = delete; + recv_request& operator=(recv_request&&) = default; + + public: + bool is_ready() const noexcept; + bool is_canceled() const noexcept; + bool test(); + void wait(); + bool cancel(); + }; + + class shared_recv_request { - } - - public: - shared_recv_request() = default; - shared_recv_request(shared_recv_request const&) = default; - shared_recv_request(shared_recv_request&&) = default; - shared_recv_request& operator=(shared_recv_request const&) = default; - shared_recv_request& operator=(shared_recv_request&&) = default; - - public: - bool is_ready() const noexcept; - bool is_canceled() const noexcept; - bool test(); - void wait(); - bool cancel(); -}; - -class send_multi_request -{ - protected: - using state_type = detail::multi_request_state; - friend class communicator; - friend class communicator_impl; - - util::unsafe_shared_ptr m; - - send_multi_request(util::unsafe_shared_ptr s) 
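// Usage sketch (illustrative, not part of the patch): the request classes above
// are move-only completion handles exposing is_ready()/test()/wait(), plus
// cancel() for receives.  A typical completion loop, assuming test() only
// checks readiness while communicator::progress() advances the transport:
#include <vector>

inline void wait_all(oomph::communicator& comm, std::vector<oomph::recv_request>& reqs)
{
    for (auto& r : reqs)
        while (!r.test()) comm.progress();
}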
noexcept - : m{std::move(s)} + private: + using state_type = detail::shared_request_state; + friend class communicator; + friend class communicator_impl; + + private: + std::shared_ptr m; + + shared_recv_request(std::shared_ptr s) noexcept + : m{std::move(s)} + { + } + + public: + shared_recv_request() = default; + shared_recv_request(shared_recv_request const&) = default; + shared_recv_request(shared_recv_request&&) = default; + shared_recv_request& operator=(shared_recv_request const&) = default; + shared_recv_request& operator=(shared_recv_request&&) = default; + + public: + bool is_ready() const noexcept; + bool is_canceled() const noexcept; + bool test(); + void wait(); + bool cancel(); + }; + + class send_multi_request { - } - - public: - send_multi_request() = default; - send_multi_request(send_multi_request const&) = delete; - send_multi_request(send_multi_request&&) = default; - send_multi_request& operator=(send_multi_request const&) = delete; - send_multi_request& operator=(send_multi_request&&) = default; - - public: - bool is_ready() const noexcept; - bool test(); - void wait(); -}; - -} // namespace oomph + protected: + using state_type = detail::multi_request_state; + friend class communicator; + friend class communicator_impl; + + util::unsafe_shared_ptr m; + + send_multi_request(util::unsafe_shared_ptr s) noexcept + : m{std::move(s)} + { + } + + public: + send_multi_request() = default; + send_multi_request(send_multi_request const&) = delete; + send_multi_request(send_multi_request&&) = default; + send_multi_request& operator=(send_multi_request const&) = delete; + send_multi_request& operator=(send_multi_request&&) = default; + + public: + bool is_ready() const noexcept; + bool test(); + void wait(); + }; + +} // namespace oomph diff --git a/include/oomph/types.hpp b/include/oomph/types.hpp index 59a0ff35..23c9c99b 100644 --- a/include/oomph/types.hpp +++ b/include/oomph/types.hpp @@ -9,10 +9,9 @@ */ #pragma once -namespace oomph -{ +namespace oomph { -using tag_type = int; -using rank_type = int; + using tag_type = int; + using rank_type = int; -} // namespace oomph +} // namespace oomph diff --git a/include/oomph/util/heap_pimpl.hpp b/include/oomph/util/heap_pimpl.hpp index d4a58fd7..e6f1ba63 100644 --- a/include/oomph/util/heap_pimpl.hpp +++ b/include/oomph/util/heap_pimpl.hpp @@ -11,37 +11,33 @@ #include -namespace oomph -{ -namespace util -{ -template -class heap_pimpl -{ - private: - std::unique_ptr m; +namespace oomph { namespace util { + template + class heap_pimpl + { + private: + std::unique_ptr m; - public: - ~heap_pimpl(); - heap_pimpl() noexcept; - heap_pimpl(T* ptr) noexcept; - template - heap_pimpl(Args&&... args); - heap_pimpl(heap_pimpl const&) = delete; - heap_pimpl(heap_pimpl&&) noexcept; - heap_pimpl& operator=(heap_pimpl const&) = delete; - heap_pimpl& operator=(heap_pimpl&&) noexcept; + public: + ~heap_pimpl(); + heap_pimpl() noexcept; + heap_pimpl(T* ptr) noexcept; + template + heap_pimpl(Args&&... 
args); + heap_pimpl(heap_pimpl const&) = delete; + heap_pimpl(heap_pimpl&&) noexcept; + heap_pimpl& operator=(heap_pimpl const&) = delete; + heap_pimpl& operator=(heap_pimpl&&) noexcept; - T* operator->() noexcept; - T const* operator->() const noexcept; - T& operator*() noexcept; - T const& operator*() const noexcept; - T* get() noexcept; - T const* get() const noexcept; -}; + T* operator->() noexcept; + T const* operator->() const noexcept; + T& operator*() noexcept; + T const& operator*() const noexcept; + T* get() noexcept; + T const* get() const noexcept; + }; -template -heap_pimpl make_heap_pimpl(Args&&... args); + template + heap_pimpl make_heap_pimpl(Args&&... args); -} // namespace util -} // namespace oomph +}} // namespace oomph::util diff --git a/include/oomph/util/moved_bit.hpp b/include/oomph/util/moved_bit.hpp index 27ad3084..3e0e2bb6 100644 --- a/include/oomph/util/moved_bit.hpp +++ b/include/oomph/util/moved_bit.hpp @@ -11,34 +11,30 @@ #include -namespace oomph -{ -namespace util -{ -struct moved_bit -{ - bool m_moved = false; - - moved_bit() = default; - moved_bit(bool state) noexcept - : m_moved{state} - { - } - moved_bit(const moved_bit&) = default; - moved_bit(moved_bit&& other) noexcept - : m_moved{std::exchange(other.m_moved, true)} +namespace oomph { namespace util { + struct moved_bit { - } + bool m_moved = false; - moved_bit& operator=(const moved_bit&) = default; - moved_bit& operator=(moved_bit&& other) noexcept - { - m_moved = std::exchange(other.m_moved, true); - return *this; - } + moved_bit() = default; + moved_bit(bool state) noexcept + : m_moved{state} + { + } + moved_bit(moved_bit const&) = default; + moved_bit(moved_bit&& other) noexcept + : m_moved{std::exchange(other.m_moved, true)} + { + } + + moved_bit& operator=(moved_bit const&) = default; + moved_bit& operator=(moved_bit&& other) noexcept + { + m_moved = std::exchange(other.m_moved, true); + return *this; + } - operator bool() const { return m_moved; } -}; + operator bool() const { return m_moved; } + }; -} // namespace util -} // namespace oomph +}} // namespace oomph::util diff --git a/include/oomph/util/mpi_clone_comm.hpp b/include/oomph/util/mpi_clone_comm.hpp index 765e73cc..4ebaca13 100644 --- a/include/oomph/util/mpi_clone_comm.hpp +++ b/include/oomph/util/mpi_clone_comm.hpp @@ -11,17 +11,12 @@ #include -namespace oomph -{ -namespace util -{ -inline MPI_Comm -mpi_clone_comm(MPI_Comm mpi_comm) -{ - MPI_Comm new_comm; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_dup(mpi_comm, &new_comm)); - return new_comm; -} +namespace oomph { namespace util { + inline MPI_Comm mpi_clone_comm(MPI_Comm mpi_comm) + { + MPI_Comm new_comm; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_dup(mpi_comm, &new_comm)); + return new_comm; + } -} // namespace util -} // namespace oomph +}} // namespace oomph::util diff --git a/include/oomph/util/mpi_comm_holder.hpp b/include/oomph/util/mpi_comm_holder.hpp index 51650ddd..9e35397a 100644 --- a/include/oomph/util/mpi_comm_holder.hpp +++ b/include/oomph/util/mpi_comm_holder.hpp @@ -12,30 +12,26 @@ #include #include -namespace oomph -{ -namespace util -{ -class mpi_comm_holder -{ - private: - MPI_Comm m; - moved_bit m_moved; - - public: - mpi_comm_holder(MPI_Comm comm) - : m{util::mpi_clone_comm(comm)} - { - } - mpi_comm_holder(mpi_comm_holder const&) = delete; - mpi_comm_holder(mpi_comm_holder&&) noexcept = default; - mpi_comm_holder& operator=(mpi_comm_holder const&) = delete; - mpi_comm_holder& operator=(mpi_comm_holder&&) noexcept = default; - ~mpi_comm_holder() noexcept +namespace oomph { 
namespace util { + class mpi_comm_holder { - if (!m_moved) MPI_Comm_free(&m); - } - MPI_Comm get() const noexcept { return m; } -}; -} // namespace util -} // namespace oomph + private: + MPI_Comm m; + moved_bit m_moved; + + public: + mpi_comm_holder(MPI_Comm comm) + : m{util::mpi_clone_comm(comm)} + { + } + mpi_comm_holder(mpi_comm_holder const&) = delete; + mpi_comm_holder(mpi_comm_holder&&) noexcept = default; + mpi_comm_holder& operator=(mpi_comm_holder const&) = delete; + mpi_comm_holder& operator=(mpi_comm_holder&&) noexcept = default; + ~mpi_comm_holder() noexcept + { + if (!m_moved) MPI_Comm_free(&m); + } + MPI_Comm get() const noexcept { return m; } + }; +}} // namespace oomph::util diff --git a/include/oomph/util/mpi_error.hpp b/include/oomph/util/mpi_error.hpp index 0f09c217..52d91cc0 100644 --- a/include/oomph/util/mpi_error.hpp +++ b/include/oomph/util/mpi_error.hpp @@ -12,21 +12,21 @@ #include #ifdef NDEBUG -#define OOMPH_CHECK_MPI_RESULT(x) x; -#define OOMPH_CHECK_MPI_RESULT_NOEXCEPT(x) x; +# define OOMPH_CHECK_MPI_RESULT(x) x; +# define OOMPH_CHECK_MPI_RESULT_NOEXCEPT(x) x; #else -#include -#include -#include -#define OOMPH_CHECK_MPI_RESULT(x) \ - if (x != MPI_SUCCESS) \ - throw std::runtime_error("OOMPH Error: MPI Call failed " + std::string(#x) + " in " + \ - std::string(__FILE__) + ":" + std::to_string(__LINE__)); -#define OOMPH_CHECK_MPI_RESULT_NOEXCEPT(x) \ - if (x != MPI_SUCCESS) \ - { \ - std::cerr << "OOMPH Error: MPI Call failed " << std::string(#x) << " in " \ - << std::string(__FILE__) << ":" << std::to_string(__LINE__) << std::endl; \ - std::terminate(); \ - } +# include +# include +# include +# define OOMPH_CHECK_MPI_RESULT(x) \ + if (x != MPI_SUCCESS) \ + throw std::runtime_error("OOMPH Error: MPI Call failed " + std::string(#x) + " in " + \ + std::string(__FILE__) + ":" + std::to_string(__LINE__)); +# define OOMPH_CHECK_MPI_RESULT_NOEXCEPT(x) \ + if (x != MPI_SUCCESS) \ + { \ + std::cerr << "OOMPH Error: MPI Call failed " << std::string(#x) << " in " \ + << std::string(__FILE__) << ":" << std::to_string(__LINE__) << std::endl; \ + std::terminate(); \ + } #endif diff --git a/include/oomph/util/pimpl.hpp b/include/oomph/util/pimpl.hpp index 413c858d..3a75caa1 100644 --- a/include/oomph/util/pimpl.hpp +++ b/include/oomph/util/pimpl.hpp @@ -13,41 +13,35 @@ #include #if OOMPH_USE_FAST_PIMPL -#include "./stack_pimpl.hpp" +# include "./stack_pimpl.hpp" #else -#include "./heap_pimpl.hpp" +# include "./heap_pimpl.hpp" #endif -namespace oomph -{ -namespace util -{ +namespace oomph { namespace util { #if OOMPH_USE_FAST_PIMPL -template::value> -using pimpl = stack_pimpl; + template ::value> + using pimpl = stack_pimpl; -template::value, typename... Args> -pimpl -make_pimpl(Args&&... args) -{ - return make_stack_pimpl(std::forward(args)...); -} + template ::value, typename... Args> + pimpl make_pimpl(Args&&... args) + { + return make_stack_pimpl(std::forward(args)...); + } #else -template -using pimpl = heap_pimpl; + template + using pimpl = heap_pimpl; -template -pimpl -make_pimpl(Args&&... args) -{ - return make_heap_pimpl(std::forward(args)...); -} + template + pimpl make_pimpl(Args&&... 
args) + { + return make_heap_pimpl(std::forward(args)...); + } #endif -} // namespace util -} // namespace oomph +}} // namespace oomph::util diff --git a/include/oomph/util/placement_new.hpp b/include/oomph/util/placement_new.hpp index 3662c1e0..b3344dc2 100644 --- a/include/oomph/util/placement_new.hpp +++ b/include/oomph/util/placement_new.hpp @@ -9,35 +9,29 @@ */ #pragma once -#include #include +#include #include #if defined(NDEBUG) -#define OOMPH_DEBUG_ARG(TYPE, NAME) TYPE +# define OOMPH_DEBUG_ARG(TYPE, NAME) TYPE #else -#define OOMPH_DEBUG_ARG(TYPE, NAME) TYPE NAME +# define OOMPH_DEBUG_ARG(TYPE, NAME) TYPE NAME #endif -namespace oomph -{ -namespace util -{ -template -inline T* -placement_new(void* buffer, OOMPH_DEBUG_ARG(std::size_t, size), Args&&... args) -{ - assert(sizeof(T) <= size); - assert(std::align(std::alignment_of::value, sizeof(T), buffer, size) == buffer); - return new (buffer) T{std::forward(args)...}; -} +namespace oomph { namespace util { + template + inline T* placement_new(void* buffer, OOMPH_DEBUG_ARG(std::size_t, size), Args&&... args) + { + assert(sizeof(T) <= size); + assert(std::align(std::alignment_of::value, sizeof(T), buffer, size) == buffer); + return new (buffer) T{std::forward(args)...}; + } -template -inline void -placement_delete(void* buffer) -{ - reinterpret_cast(buffer)->~T(); -} + template + inline void placement_delete(void* buffer) + { + reinterpret_cast(buffer)->~T(); + } -} // namespace util -} // namespace oomph +}} // namespace oomph::util diff --git a/include/oomph/util/pool_allocator.hpp b/include/oomph/util/pool_allocator.hpp index 17be8e1f..311fb373 100644 --- a/include/oomph/util/pool_allocator.hpp +++ b/include/oomph/util/pool_allocator.hpp @@ -9,60 +9,54 @@ */ #pragma once +#include #include #include -#include -namespace oomph -{ -namespace util -{ +namespace oomph { namespace util { -template -struct pool_allocator -{ - using value_type = T; - using pool_type = boost::pool; + template + struct pool_allocator + { + using value_type = T; + using pool_type = boost::pool; - pool_type* _p; + pool_type* _p; - constexpr pool_allocator(pool_type* p) noexcept - : _p{p} - { - } + constexpr pool_allocator(pool_type* p) noexcept + : _p{p} + { + } - template - constexpr pool_allocator(const pool_allocator& other) noexcept - : _p{other._p} - { - } + template + constexpr pool_allocator(pool_allocator const& other) noexcept + : _p{other._p} + { + } #ifdef NDEBUG - [[nodiscard]] T* allocate(std::size_t) + [[nodiscard]] T* allocate(std::size_t) #else - [[nodiscard]] T* allocate(std::size_t n) + [[nodiscard]] T* allocate(std::size_t n) #endif - { - assert(_p->get_requested_size() >= sizeof(T) * n); - if (auto ptr = static_cast(_p->malloc())) return ptr; - throw std::bad_alloc(); - } + { + assert(_p->get_requested_size() >= sizeof(T) * n); + if (auto ptr = static_cast(_p->malloc())) return ptr; + throw std::bad_alloc(); + } - void deallocate(T* p, std::size_t /*n*/) noexcept { _p->free(p); } -}; + void deallocate(T* p, std::size_t /*n*/) noexcept { _p->free(p); } + }; -template -bool -operator==(const pool_allocator&, const pool_allocator&) -{ - return true; -} -template -bool -operator!=(const pool_allocator&, const pool_allocator&) -{ - return false; -} + template + bool operator==(pool_allocator const&, pool_allocator const&) + { + return true; + } + template + bool operator!=(pool_allocator const&, pool_allocator const&) + { + return false; + } -} // namespace util -} // namespace oomph +}} // namespace oomph::util diff --git 
a/include/oomph/util/pool_factory.hpp b/include/oomph/util/pool_factory.hpp index bcd8a5f7..a7209bb1 100644 --- a/include/oomph/util/pool_factory.hpp +++ b/include/oomph/util/pool_factory.hpp @@ -12,42 +12,38 @@ #include #include -namespace oomph -{ -namespace util -{ - -template -struct pool_factory -{ - public: - using value_type = T; - using ptr_type = unsafe_shared_ptr; - - private: - using allocator_type = pool_allocator; - using pool_type = typename allocator_type::pool_type; - - pool_type m_pool; - - public: - pool_factory() - : m_pool{ptr_type::template allocation_size()} - { - } - - pool_factory(pool_factory const&) = delete; - pool_factory(pool_factory&&) = delete; - pool_factory& operator=(pool_factory const&) = delete; - pool_factory& operator=(pool_factory&&) = delete; +namespace oomph { namespace util { - template - ptr_type make(Args&&... args) + template + struct pool_factory { - return oomph::util::allocate_shared(allocator_type(&m_pool), - std::forward(args)...); - } -}; - -} // namespace util -} // namespace oomph + public: + using value_type = T; + using ptr_type = unsafe_shared_ptr; + + private: + using allocator_type = pool_allocator; + using pool_type = typename allocator_type::pool_type; + + pool_type m_pool; + + public: + pool_factory() + : m_pool{ptr_type::template allocation_size()} + { + } + + pool_factory(pool_factory const&) = delete; + pool_factory(pool_factory&&) = delete; + pool_factory& operator=(pool_factory const&) = delete; + pool_factory& operator=(pool_factory&&) = delete; + + template + ptr_type make(Args&&... args) + { + return oomph::util::allocate_shared( + allocator_type(&m_pool), std::forward(args)...); + } + }; + +}} // namespace oomph::util diff --git a/include/oomph/util/stack_pimpl.hpp b/include/oomph/util/stack_pimpl.hpp index 9a093796..a8f98849 100644 --- a/include/oomph/util/stack_pimpl.hpp +++ b/include/oomph/util/stack_pimpl.hpp @@ -11,38 +11,34 @@ #include -namespace oomph -{ -namespace util -{ -template::value> -class stack_pimpl -{ - private: - util::stack_storage m; +namespace oomph { namespace util { + template ::value> + class stack_pimpl + { + private: + util::stack_storage m; - public: - ~stack_pimpl(); - stack_pimpl() noexcept; - template - stack_pimpl(Args&&... args); - stack_pimpl(stack_pimpl const&) = delete; - stack_pimpl(stack_pimpl&&) noexcept; - stack_pimpl& operator=(stack_pimpl const&) = delete; - stack_pimpl& operator=(stack_pimpl&&) noexcept; + public: + ~stack_pimpl(); + stack_pimpl() noexcept; + template + stack_pimpl(Args&&... args); + stack_pimpl(stack_pimpl const&) = delete; + stack_pimpl(stack_pimpl&&) noexcept; + stack_pimpl& operator=(stack_pimpl const&) = delete; + stack_pimpl& operator=(stack_pimpl&&) noexcept; - T* operator->() noexcept; - T const* operator->() const noexcept; - T& operator*() noexcept; - T const& operator*() const noexcept; - T* get() noexcept; - T const* get() const noexcept; -}; + T* operator->() noexcept; + T const* operator->() const noexcept; + T& operator*() noexcept; + T const& operator*() const noexcept; + T* get() noexcept; + T const* get() const noexcept; + }; -template::value, typename... Args> -stack_pimpl make_stack_pimpl(Args&&... args); + template ::value, typename... Args> + stack_pimpl make_stack_pimpl(Args&&... 
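// Usage sketch (illustrative, not part of the patch): pool_factory combines a
// boost::pool sized for one control block with pool_allocator and hands out
// request state as unsafe_shared_ptr, i.e. with cheap, non-atomic reference
// counting.  The state type below is illustrative only:
struct example_state
{
    int value;
};

inline void pool_factory_example()
{
    oomph::util::pool_factory<example_state> factory;
    auto p = factory.make(42);    // unsafe_shared_ptr<example_state> from the pool
    auto q = p;                   // shared ownership, non-atomic use count
    // the pool slot is recycled once both p and q are gone
}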
args); -} // namespace util -} // namespace oomph +}} // namespace oomph::util diff --git a/include/oomph/util/stack_storage.hpp b/include/oomph/util/stack_storage.hpp index 634a2263..2d9d806e 100644 --- a/include/oomph/util/stack_storage.hpp +++ b/include/oomph/util/stack_storage.hpp @@ -9,113 +9,107 @@ */ #pragma once +#include #include #include -#include -namespace oomph -{ -namespace util -{ -namespace detail -{ -template -inline void -compare_size() -{ - static_assert(BufferSize >= Size, "buffer size is too small"); - static_assert(BufferAlignment >= Alignment, "buffer alignment not big enough"); -} +namespace oomph { namespace util { + namespace detail { + template + inline void compare_size() + { + static_assert(BufferSize >= Size, "buffer size is too small"); + static_assert(BufferAlignment >= Alignment, "buffer alignment not big enough"); + } -template -struct size_comparer -{ - inline size_comparer() - { - // going through one additional layer to get good error messages - compare_size(); - } -}; -} // namespace detail + template + struct size_comparer + { + inline size_comparer() + { + // going through one additional layer to get good error messages + compare_size(); + } + }; + } // namespace detail -template::value> -class stack_storage -{ - private: - using aligned_storage = std::aligned_storage_t; + template ::value> + class stack_storage + { + private: + using aligned_storage = std::aligned_storage_t; - private: - aligned_storage m_impl; - bool m_empty = false; + private: + aligned_storage m_impl; + bool m_empty = false; - public: - stack_storage() - : m_empty{true} - { - } + public: + stack_storage() + : m_empty{true} + { + } - template - stack_storage(Args&&... args) - { - placement_new(&m_impl, BufferSize, std::forward(args)...); - } - stack_storage(stack_storage const&) = delete; - stack_storage(stack_storage&& other) - : m_empty{other.m_empty} - { - if (!m_empty) placement_new(&m_impl, BufferSize, std::move(*other.get())); - } - stack_storage& operator=(stack_storage const&) = delete; - stack_storage& operator=(stack_storage&& other) - { - if (m_empty && other.m_empty) { return *this; } - else if (m_empty && !other.m_empty) + template + stack_storage(Args&&... 
args) { - placement_new(&m_impl, BufferSize, std::move(*other.get())); - m_empty = false; - return *this; + placement_new(&m_impl, BufferSize, std::forward(args)...); } - else if (!m_empty && other.m_empty) + stack_storage(stack_storage const&) = delete; + stack_storage(stack_storage&& other) + : m_empty{other.m_empty} { - placement_delete(&m_impl); - m_empty = true; - return *this; + if (!m_empty) placement_new(&m_impl, BufferSize, std::move(*other.get())); } - else + stack_storage& operator=(stack_storage const&) = delete; + stack_storage& operator=(stack_storage&& other) { - *get() = std::move(*other.get()); - return *this; + if (m_empty && other.m_empty) { return *this; } + else if (m_empty && !other.m_empty) + { + placement_new(&m_impl, BufferSize, std::move(*other.get())); + m_empty = false; + return *this; + } + else if (!m_empty && other.m_empty) + { + placement_delete(&m_impl); + m_empty = true; + return *this; + } + else + { + *get() = std::move(*other.get()); + return *this; + } + } + ~stack_storage() + { + detail::size_comparer s{}; + if (!m_empty) placement_delete(&m_impl); } - } - ~stack_storage() - { - detail::size_comparer s{}; - if (!m_empty) placement_delete(&m_impl); - } - T* get() noexcept - { - assert(!m_empty); - return reinterpret_cast(&m_impl); - } + T* get() noexcept + { + assert(!m_empty); + return reinterpret_cast(&m_impl); + } - T const* get() const noexcept - { - assert(!m_empty); - return reinterpret_cast(&m_impl); - } + T const* get() const noexcept + { + assert(!m_empty); + return reinterpret_cast(&m_impl); + } - //T release() - //{ - // T t{std::move(*get())}; - // placement_delete(&m_impl); - // m_empty = true; - // return std::move(t); - //} -}; + //T release() + //{ + // T t{std::move(*get())}; + // placement_delete(&m_impl); + // m_empty = true; + // return std::move(t); + //} + }; -} // namespace util -} // namespace oomph +}} // namespace oomph::util diff --git a/include/oomph/util/unique_function.hpp b/include/oomph/util/unique_function.hpp index 58a19e09..0898a76c 100644 --- a/include/oomph/util/unique_function.hpp +++ b/include/oomph/util/unique_function.hpp @@ -10,225 +10,226 @@ #pragma once #include -#include #include +#include #include #include -namespace oomph -{ -namespace util -{ -template -class unique_function; - -namespace detail -{ -template -struct unique_function -{ - virtual R invoke(Args&&... args) = 0; - virtual ~unique_function(){}; - - virtual void move_construct(void*) = 0; -}; - -template -struct unique_function_impl : unique_function -{ - using this_type = unique_function_impl; - - Func func; - - template - unique_function_impl(F&& f) - : func{std::move(f)} - { - } +namespace oomph { namespace util { + template + class unique_function; - virtual R invoke(Args&&... args) final override { return func(std::forward(args)...); } + namespace detail { + template + struct unique_function + { + virtual R invoke(Args&&... args) = 0; + virtual ~unique_function() {}; - virtual void move_construct(void* addr) final override - { - ::new (addr) this_type{std::move(func)}; - } -}; + virtual void move_construct(void*) = 0; + }; -// specialization for void function -template -struct unique_function_impl : unique_function -{ - using this_type = unique_function_impl; + template + struct unique_function_impl : unique_function + { + using this_type = unique_function_impl; + + Func func; + + template + unique_function_impl(F&& f) + : func{std::move(f)} + { + } + + virtual R invoke(Args&&... 
args) final override + { + return func(std::forward(args)...); + } + + virtual void move_construct(void* addr) final override + { + ::new (addr) this_type{std::move(func)}; + } + }; + + // specialization for void function + template + struct unique_function_impl : unique_function + { + using this_type = unique_function_impl; - Func func; + Func func; - template - unique_function_impl(F&& f) - : func{std::move(f)} - { - } + template + unique_function_impl(F&& f) + : func{std::move(f)} + { + } - virtual void invoke(Args&&... args) final override { func(std::forward(args)...); } + virtual void invoke(Args&&... args) final override + { + func(std::forward(args)...); + } - virtual void move_construct(void* addr) final override - { - ::new (addr) this_type{std::move(func)}; - } -}; - -} // namespace detail - -// a function object wrapper a la std::function but for move-only types -// which uses small buffer optimization -template -class unique_function -{ - private: // member types - // abstract base class - using interface_t = detail::unique_function; - - // buffer size and type - static constexpr std::size_t sbo_size = S; - using buffer_t = std::aligned_storage_t::value>; - - // variant holds 3 alternatives: - // - empty state - // - heap allocated function objects - // - stack buffer for small function objects (sbo) - using holder_t = std::variant; - - private: // members - holder_t holder; - - private: // helper templates for type inspection - // return type - template - using result_t = std::result_of_t; - // concrete type for allocation - template - using concrete_t = detail::unique_function_impl, R, Args...>; - // F can be invoked with Args and return type can be converted to R - template - using has_signature_t = decltype((R)(std::declval>())); - // is already a unique function - template - using is_unique_function_t = std::is_same, unique_function>; - // differentiate small and large function objects - template - using enable_if_large_function_t = - std::enable_if_t::value && (sizeof(std::decay_t) > sbo_size), - bool>; - template - using enable_if_small_function_t = - std::enable_if_t::value && (sizeof(std::decay_t) <= sbo_size), - bool>; - - public: // ctors - // construct empty - unique_function() noexcept = default; - - // deleted copy ctors - unique_function(unique_function const&) = delete; - unique_function& operator=(unique_function const&) = delete; - - // construct from large function - template, enable_if_large_function_t = true> - unique_function(F&& f) - : holder{std::in_place_type_t{}, new concrete_t(std::move(f))} - { - static_assert(std::is_rvalue_reference_v, "argument is not an r-value reference"); - } + virtual void move_construct(void* addr) final override + { + ::new (addr) this_type{std::move(func)}; + } + }; - // construct from small function - template, enable_if_small_function_t = true> - unique_function(F&& f) - : holder{std::in_place_type_t{}} - { - static_assert(std::is_rvalue_reference_v, "argument is not an r-value reference"); - ::new (&std::get<2>(holder)) concrete_t(std::forward(f)); - } + } // namespace detail - // move construct from unique_function - unique_function(unique_function&& other) noexcept - : holder{std::move(other.holder)} + // a function object wrapper a la std::function but for move-only types + // which uses small buffer optimization + template + class unique_function { - move_construct(other.holder); - } + private: // member types + // abstract base class + using interface_t = detail::unique_function; + + // buffer size and type + 
static constexpr std::size_t sbo_size = S; + using buffer_t = + std::aligned_storage_t::value>; + + // variant holds 3 alternatives: + // - empty state + // - heap allocated function objects + // - stack buffer for small function objects (sbo) + using holder_t = std::variant; + + private: // members + holder_t holder; + + private: // helper templates for type inspection + // return type + template + using result_t = std::result_of_t; + // concrete type for allocation + template + using concrete_t = detail::unique_function_impl, R, Args...>; + // F can be invoked with Args and return type can be converted to R + template + using has_signature_t = decltype((R) (std::declval>())); + // is already a unique function + template + using is_unique_function_t = std::is_same, unique_function>; + // differentiate small and large function objects + template + using enable_if_large_function_t = std::enable_if_t< + !is_unique_function_t::value && (sizeof(std::decay_t) > sbo_size), bool>; + template + using enable_if_small_function_t = std::enable_if_t< + !is_unique_function_t::value && (sizeof(std::decay_t) <= sbo_size), bool>; + + public: // ctors + // construct empty + unique_function() noexcept = default; + + // deleted copy ctors + unique_function(unique_function const&) = delete; + unique_function& operator=(unique_function const&) = delete; + + // construct from large function + template , enable_if_large_function_t = true> + unique_function(F&& f) + : holder{std::in_place_type_t{}, new concrete_t(std::move(f))} + { + static_assert(std::is_rvalue_reference_v, "argument is not an r-value reference"); + } - // move assign from unique_function - unique_function& operator=(unique_function&& other) noexcept - { - destroy(); - holder = std::move(other.holder); - move_construct(other.holder); - return *this; - } - - // move assign from large function - template, enable_if_large_function_t = true> - unique_function& operator=(F&& f) noexcept - { - static_assert(std::is_rvalue_reference_v, "argument is not an r-value reference"); - destroy(); - holder.template emplace(new concrete_t(std::move(f))); - return *this; - } - - // move assign from small function - template, enable_if_small_function_t = true> - unique_function& operator=(F&& f) noexcept - { - static_assert(std::is_rvalue_reference_v, "argument is not an r-value reference"); - destroy(); - holder.template emplace(); - ::new (&std::get<2>(holder)) concrete_t(std::forward(f)); - return *this; - } + // construct from small function + template , enable_if_small_function_t = true> + unique_function(F&& f) + : holder{std::in_place_type_t{}} + { + static_assert(std::is_rvalue_reference_v, "argument is not an r-value reference"); + ::new (&std::get<2>(holder)) concrete_t(std::forward(f)); + } - ~unique_function() { destroy(); } + // move construct from unique_function + unique_function(unique_function&& other) noexcept + : holder{std::move(other.holder)} + { + move_construct(other.holder); + } - public: // member functions - R operator()(Args... 
args) const { return get()->invoke(std::forward(args)...); } + // move assign from unique_function + unique_function& operator=(unique_function&& other) noexcept + { + destroy(); + holder = std::move(other.holder); + move_construct(other.holder); + return *this; + } - operator bool() const noexcept { return (holder.index() != 0); } + // move assign from large function + template , enable_if_large_function_t = true> + unique_function& operator=(F&& f) noexcept + { + static_assert(std::is_rvalue_reference_v, "argument is not an r-value reference"); + destroy(); + holder.template emplace(new concrete_t(std::move(f))); + return *this; + } - private: // helper functions - static interface_t* get_from_buffer(holder_t const& h) noexcept - { - return std::launder(reinterpret_cast(&const_cast(std::get<2>(h)))); - } + // move assign from small function + template , enable_if_small_function_t = true> + unique_function& operator=(F&& f) noexcept + { + static_assert(std::is_rvalue_reference_v, "argument is not an r-value reference"); + destroy(); + holder.template emplace(); + ::new (&std::get<2>(holder)) concrete_t(std::forward(f)); + return *this; + } - static interface_t* get_from_buffer(holder_t& h) noexcept - { - return std::launder(reinterpret_cast(&std::get<2>(h))); - } + ~unique_function() { destroy(); } - interface_t* get() const noexcept - { - return (holder.index() == 2) ? get_from_buffer(holder) : std::get<1>(holder); - } + public: // member functions + R operator()(Args... args) const { return get()->invoke(std::forward(args)...); } - void destroy() - { - // delete from heap - if (holder.index() == 1) delete std::get<1>(holder); - // delete from stack buffer - if (holder.index() == 2) std::destroy_at(get_from_buffer(holder)); - } + operator bool() const noexcept { return (holder.index() != 0); } - void move_construct(holder_t& other_holder) - { - // explicitly move if function is stored in stack buffer - if (other_holder.index() == 2) + private: // helper functions + static interface_t* get_from_buffer(holder_t const& h) noexcept + { + return std::launder( + reinterpret_cast(&const_cast(std::get<2>(h)))); + } + + static interface_t* get_from_buffer(holder_t& h) noexcept + { + return std::launder(reinterpret_cast(&std::get<2>(h))); + } + + interface_t* get() const noexcept + { + return (holder.index() == 2) ? 
get_from_buffer(holder) : std::get<1>(holder); + } + + void destroy() + { + // delete from heap + if (holder.index() == 1) delete std::get<1>(holder); + // delete from stack buffer + if (holder.index() == 2) std::destroy_at(get_from_buffer(holder)); + } + + void move_construct(holder_t& other_holder) { - interface_t* ptr = get_from_buffer(other_holder); - ptr->move_construct(&std::get<2>(holder)); - std::destroy_at(ptr); + // explicitly move if function is stored in stack buffer + if (other_holder.index() == 2) + { + interface_t* ptr = get_from_buffer(other_holder); + ptr->move_construct(&std::get<2>(holder)); + std::destroy_at(ptr); + } + // reset to empty state + other_holder = std::monostate{}; } - // reset to empty state - other_holder = std::monostate{}; - } -}; + }; -} // namespace util -} // namespace oomph +}} // namespace oomph::util diff --git a/include/oomph/util/unsafe_shared_ptr.hpp b/include/oomph/util/unsafe_shared_ptr.hpp index ae42160a..44e302b9 100644 --- a/include/oomph/util/unsafe_shared_ptr.hpp +++ b/include/oomph/util/unsafe_shared_ptr.hpp @@ -9,207 +9,203 @@ */ #pragma once -#include -#include #include +#include +#include -namespace oomph -{ -namespace util -{ - -template -class enable_shared_from_this; - -namespace detail -{ - -template -struct control_block -{ - std::size_t m_ref_count = 1ul; - T* m_ptr = nullptr; - - virtual void free() = 0; -}; - -template -struct control_block_impl : public control_block -{ - using this_type = control_block_impl; - using base_type = control_block; - using alloc_t = typename std::allocator_traits::template rebind_alloc; - using traits = std::allocator_traits; - - alloc_t m_alloc; - T m_t; - - template - control_block_impl(Alloc const a, Args&&... args) - : base_type() - , m_alloc{a} - , m_t{std::forward(args)...} - { - this->m_ptr = &m_t; - set_shared_from_this(); - } - - void free() override final - { - auto a = m_alloc; - m_alloc.~alloc_t(); - m_t.~T(); - traits::deallocate(a, this, 1); - } - - template, D>::value, bool> = true> - void set_shared_from_this() - { - } - - template, D>::value, bool> = true> - void set_shared_from_this() - { - m_t._shared_from_this_cb = this; - } -}; - -} // namespace detail - -template -class unsafe_shared_ptr -{ - template - friend class enable_shared_from_this; - - private: - using block_t = detail::control_block; - - public: - template - static constexpr std::size_t allocation_size() - { - return sizeof(detail::control_block_impl); - } - - private: - block_t* m = nullptr; - - private: - unsafe_shared_ptr(block_t* m_) noexcept - : m{m_} - { - if (m) ++m->m_ref_count; - } - - public: - template - unsafe_shared_ptr(Alloc const& alloc, Args&&... 
args) - { - using block_impl_t = detail::control_block_impl; - using alloc_t = typename block_impl_t::alloc_t; - using traits = std::allocator_traits; - - alloc_t a(alloc); - m = traits::allocate(a, 1); - ::new (m) block_impl_t(a, std::forward(args)...); - } - - public: - unsafe_shared_ptr() noexcept = default; - - unsafe_shared_ptr(unsafe_shared_ptr const& other) noexcept - : m{other.m} +namespace oomph { namespace util { + + template + class enable_shared_from_this; + + namespace detail { + + template + struct control_block + { + std::size_t m_ref_count = 1ul; + T* m_ptr = nullptr; + + virtual void free() = 0; + }; + + template + struct control_block_impl : public control_block + { + using this_type = control_block_impl; + using base_type = control_block; + using alloc_t = + typename std::allocator_traits::template rebind_alloc; + using traits = std::allocator_traits; + + alloc_t m_alloc; + T m_t; + + template + control_block_impl(Alloc const a, Args&&... args) + : base_type() + , m_alloc{a} + , m_t{std::forward(args)...} + { + this->m_ptr = &m_t; + set_shared_from_this(); + } + + void free() override final + { + auto a = m_alloc; + m_alloc.~alloc_t(); + m_t.~T(); + traits::deallocate(a, this, 1); + } + + template , D>::value, bool> = + true> + void set_shared_from_this() + { + } + + template , D>::value, bool> = + true> + void set_shared_from_this() + { + m_t._shared_from_this_cb = this; + } + }; + + } // namespace detail + + template + class unsafe_shared_ptr { - if (m) ++m->m_ref_count; - } - - unsafe_shared_ptr(unsafe_shared_ptr&& other) noexcept - : m{std::exchange(other.m, nullptr)} + template + friend class enable_shared_from_this; + + private: + using block_t = detail::control_block; + + public: + template + static constexpr std::size_t allocation_size() + { + return sizeof(detail::control_block_impl); + } + + private: + block_t* m = nullptr; + + private: + unsafe_shared_ptr(block_t* m_) noexcept + : m{m_} + { + if (m) ++m->m_ref_count; + } + + public: + template + unsafe_shared_ptr(Alloc const& alloc, Args&&... args) + { + using block_impl_t = detail::control_block_impl; + using alloc_t = typename block_impl_t::alloc_t; + using traits = std::allocator_traits; + + alloc_t a(alloc); + m = traits::allocate(a, 1); + ::new (m) block_impl_t(a, std::forward(args)...); + } + + public: + unsafe_shared_ptr() noexcept = default; + + unsafe_shared_ptr(unsafe_shared_ptr const& other) noexcept + : m{other.m} + { + if (m) ++m->m_ref_count; + } + + unsafe_shared_ptr(unsafe_shared_ptr&& other) noexcept + : m{std::exchange(other.m, nullptr)} + { + } + + unsafe_shared_ptr& operator=(unsafe_shared_ptr const& other) noexcept + { + destroy(); + m = other.m; + if (m) ++m->m_ref_count; + return *this; + } + + unsafe_shared_ptr& operator=(unsafe_shared_ptr&& other) noexcept + { + destroy(); + m = std::exchange(other.m, nullptr); + return *this; + } + + ~unsafe_shared_ptr() { destroy(); } + + operator bool() const noexcept { return (bool) m; } + + T* get() const noexcept { return m->m_ptr; } + + T* operator->() const noexcept { return m->m_ptr; } + + T& operator*() const noexcept { return *(m->m_ptr); } + + std::size_t use_count() const noexcept { return m ? m->m_ref_count : 0ul; } + + private: + void destroy() noexcept + { + if (!m) return; + if (--m->m_ref_count == 0) m->free(); + m = nullptr; + } + }; + + template + unsafe_shared_ptr make_shared(Args&&... 
args) { + return {std::allocator{}, std::forward(args)...}; } - unsafe_shared_ptr& operator=(unsafe_shared_ptr const& other) noexcept + template + unsafe_shared_ptr allocate_shared(Alloc const& alloc, Args&&... args) { - destroy(); - m = other.m; - if (m) ++m->m_ref_count; - return *this; + return {alloc, std::forward(args)...}; } - unsafe_shared_ptr& operator=(unsafe_shared_ptr&& other) noexcept + template + class enable_shared_from_this { - destroy(); - m = std::exchange(other.m, nullptr); - return *this; - } - - ~unsafe_shared_ptr() { destroy(); } - - operator bool() const noexcept { return (bool)m; } - - T* get() const noexcept { return m->m_ptr; } - - T* operator->() const noexcept { return m->m_ptr; } - - T& operator*() const noexcept { return *(m->m_ptr); } - - std::size_t use_count() const noexcept { return m ? m->m_ref_count : 0ul; } - - private: - void destroy() noexcept - { - if (!m) return; - if (--m->m_ref_count == 0) m->free(); - m = nullptr; - } -}; - -template -unsafe_shared_ptr -make_shared(Args&&... args) -{ - return {std::allocator{}, std::forward(args)...}; -} - -template -unsafe_shared_ptr -allocate_shared(Alloc const& alloc, Args&&... args) -{ - return {alloc, std::forward(args)...}; -} - -template -class enable_shared_from_this -{ - template - friend struct detail::control_block_impl; - - private: - detail::control_block* _shared_from_this_cb = nullptr; - - public: - enable_shared_from_this() noexcept {} - enable_shared_from_this(enable_shared_from_this const&) noexcept {} - - protected: - enable_shared_from_this& operator=(enable_shared_from_this const&) noexcept - { - _shared_from_this_cb = nullptr; - } - - public: - unsafe_shared_ptr shared_from_this() - { - assert(((bool)_shared_from_this_cb) && "not created by a unsafe_shared_ptr"); - return {_shared_from_this_cb}; - } - - private: - D* derived() noexcept { return static_cast(this); } -}; - -} // namespace util -} // namespace oomph + template + friend struct detail::control_block_impl; + + private: + detail::control_block* _shared_from_this_cb = nullptr; + + public: + enable_shared_from_this() noexcept {} + enable_shared_from_this(enable_shared_from_this const&) noexcept {} + + protected: + enable_shared_from_this& operator=(enable_shared_from_this const&) noexcept + { + _shared_from_this_cb = nullptr; + } + + public: + unsafe_shared_ptr shared_from_this() + { + assert(((bool) _shared_from_this_cb) && "not created by a unsafe_shared_ptr"); + return {_shared_from_this_cb}; + } + + private: + D* derived() noexcept { return static_cast(this); } + }; + +}} // namespace oomph::util diff --git a/include/oomph/utils.hpp b/include/oomph/utils.hpp index 5fde6b5e..bfd77edb 100644 --- a/include/oomph/utils.hpp +++ b/include/oomph/utils.hpp @@ -9,7 +9,6 @@ */ #pragma once -namespace oomph -{ -void print_config(); +namespace oomph { + void print_config(); } diff --git a/src/barrier.cpp b/src/barrier.cpp index 9481d297..03c31bf7 100644 --- a/src/barrier.cpp +++ b/src/barrier.cpp @@ -8,76 +8,73 @@ * SPDX-License-Identifier: BSD-3-Clause */ #include -#include #include +#include #if OOMPH_ENABLE_BARRIER // paths relative to backend -#include <../communicator_set.hpp> - -namespace oomph -{ +# include <../communicator_set.hpp> -barrier::barrier(context const& c, size_t n_threads) -: m_threads{n_threads} -, b_count2{m_threads} -, m_mpi_comm{c.mpi_comm()} -, m_context{c.m.get()} -{ -} +namespace oomph { -void -barrier::operator()() const -{ - if (in_node1()) rank_barrier(); - else - while (b_count2 == m_threads) 
communicator_set::get().progress(m_context); - in_node2(); -} + barrier::barrier(context const& c, size_t n_threads) + : m_threads{n_threads} + , b_count2{m_threads} + , m_mpi_comm{c.mpi_comm()} + , m_context{c.m.get()} + { + } -void -barrier::rank_barrier() const -{ - MPI_Request req = MPI_REQUEST_NULL; - int flag; - MPI_Ibarrier(m_mpi_comm, &req); - while (true) + void barrier::operator()() const { - communicator_set::get().progress(m_context); - MPI_Test(&req, &flag, MPI_STATUS_IGNORE); - if (flag) break; + if (in_node1()) + rank_barrier(); + else + while (b_count2 == m_threads) communicator_set::get().progress(m_context); + in_node2(); } -} -bool -barrier::in_node1() const -{ - size_t expected = b_count; - while (!b_count.compare_exchange_weak(expected, expected + 1, std::memory_order_relaxed)) - expected = b_count; - if (expected == m_threads - 1) + void barrier::rank_barrier() const { - b_count.store(0); - return true; + MPI_Request req = MPI_REQUEST_NULL; + int flag; + MPI_Ibarrier(m_mpi_comm, &req); + while (true) + { + communicator_set::get().progress(m_context); + MPI_Test(&req, &flag, MPI_STATUS_IGNORE); + if (flag) break; + } } - else + + bool barrier::in_node1() const { - while (b_count != 0) communicator_set::get().progress(m_context); - return false; + size_t expected = b_count; + while (!b_count.compare_exchange_weak(expected, expected + 1, std::memory_order_relaxed)) + expected = b_count; + if (expected == m_threads - 1) + { + b_count.store(0); + return true; + } + else + { + while (b_count != 0) communicator_set::get().progress(m_context); + return false; + } } -} -void -barrier::in_node2() const -{ - size_t ex = b_count2; - while (!b_count2.compare_exchange_weak(ex, ex - 1, std::memory_order_relaxed)) ex = b_count2; - if (ex == 1) { b_count2.store(m_threads); } - else + void barrier::in_node2() const { - while (b_count2 != m_threads) communicator_set::get().progress(m_context); + size_t ex = b_count2; + while (!b_count2.compare_exchange_weak(ex, ex - 1, std::memory_order_relaxed)) + ex = b_count2; + if (ex == 1) { b_count2.store(m_threads); } + else + { + while (b_count2 != m_threads) communicator_set::get().progress(m_context); + } } -} -} // namespace oomph +} // namespace oomph #endif diff --git a/src/common/print_config.cpp b/src/common/print_config.cpp index cc0320bc..2e55e72d 100644 --- a/src/common/print_config.cpp +++ b/src/common/print_config.cpp @@ -9,16 +9,14 @@ */ #include -namespace oomph -{ -void -print_config() -{ - std::cout << std::endl; - std::cout << " -- OOMPH compile configuration:" << std::endl; - std::cout << std::endl; +namespace oomph { + void print_config() + { + std::cout << std::endl; + std::cout << " -- OOMPH compile configuration:" << std::endl; + std::cout << std::endl; #include - std::cout << std::endl; -} + std::cout << std::endl; + } -} // namespace oomph +} // namespace oomph diff --git a/src/common/rank_topology.cpp b/src/common/rank_topology.cpp index 682f0067..9cf4a2df 100644 --- a/src/common/rank_topology.cpp +++ b/src/common/rank_topology.cpp @@ -9,27 +9,26 @@ */ #include "../rank_topology.hpp" -namespace oomph -{ -rank_topology::rank_topology(MPI_Comm comm) -: m_comm(comm) -{ - // get rank from comm - int rank; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(comm, &rank)); - // split comm into shared memory comms - const int key = rank; - OOMPH_CHECK_MPI_RESULT( - MPI_Comm_split_type(comm, MPI_COMM_TYPE_SHARED, key, MPI_INFO_NULL, &m_shared_comm)); - // get rank within shared memory comm and its size - 
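// For reference, a minimal standalone sketch of the pattern rank_topology implements in this
// hunk: split the world communicator into per-node (shared-memory) communicators, gather the
// world ranks sharing the node, and keep them in a set so locality queries become a lookup.
// The function and variable names below are illustrative only and are not part of the oomph API;
// error checking (OOMPH_CHECK_MPI_RESULT in the real code) is omitted for brevity.
#include <mpi.h>
#include <set>
#include <vector>

std::set<int> gather_node_local_ranks(MPI_Comm world)
{
    int world_rank;
    MPI_Comm_rank(world, &world_rank);

    // one communicator per shared-memory node; key = world rank keeps the ordering stable
    MPI_Comm node_comm;
    MPI_Comm_split_type(world, MPI_COMM_TYPE_SHARED, world_rank, MPI_INFO_NULL, &node_comm);

    int node_size;
    MPI_Comm_size(node_comm, &node_size);

    // every rank on the node contributes its world rank
    std::vector<int> ranks(node_size);
    MPI_Allgather(&world_rank, 1, MPI_INT, ranks.data(), 1, MPI_INT, node_comm);

    MPI_Comm_free(&node_comm);
    return std::set<int>(ranks.begin(), ranks.end());    // locality test: set.count(r) != 0
}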
OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(m_shared_comm, &m_rank)); - int size; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(m_shared_comm, &size)); - // gather rank info from all ranks within shared comm - std::vector ranks(size); - MPI_Allgather(&rank, 1, MPI_INT, ranks.data(), 1, MPI_INT, m_shared_comm); - // insert into set - for (auto r : ranks) m_rank_set.insert(r); - OOMPH_CHECK_MPI_RESULT(MPI_Comm_free(&m_shared_comm)); -} -} //namespace oomph +namespace oomph { + rank_topology::rank_topology(MPI_Comm comm) + : m_comm(comm) + { + // get rank from comm + int rank; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(comm, &rank)); + // split comm into shared memory comms + int const key = rank; + OOMPH_CHECK_MPI_RESULT( + MPI_Comm_split_type(comm, MPI_COMM_TYPE_SHARED, key, MPI_INFO_NULL, &m_shared_comm)); + // get rank within shared memory comm and its size + OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(m_shared_comm, &m_rank)); + int size; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(m_shared_comm, &size)); + // gather rank info from all ranks within shared comm + std::vector ranks(size); + MPI_Allgather(&rank, 1, MPI_INT, ranks.data(), 1, MPI_INT, m_shared_comm); + // insert into set + for (auto r : ranks) m_rank_set.insert(r); + OOMPH_CHECK_MPI_RESULT(MPI_Comm_free(&m_shared_comm)); + } +} //namespace oomph diff --git a/src/common/thread_id.cpp b/src/common/thread_id.cpp index a1f4c80a..0774069c 100644 --- a/src/common/thread_id.cpp +++ b/src/common/thread_id.cpp @@ -9,30 +9,26 @@ */ #include "../thread_id.hpp" -namespace oomph -{ -namespace -{ -std::uintptr_t* -alloc_tid_m() -{ - auto ptr = new std::uintptr_t{}; - *ptr = (std::uintptr_t)ptr; - return ptr; -} -} // namespace +namespace oomph { + namespace { + std::uintptr_t* alloc_tid_m() + { + auto ptr = new std::uintptr_t{}; + *ptr = (std::uintptr_t) ptr; + return ptr; + } + } // namespace -thread_id::thread_id() -: m{alloc_tid_m()} -{ -} + thread_id::thread_id() + : m{alloc_tid_m()} + { + } -thread_id::~thread_id() { delete m; } + thread_id::~thread_id() { delete m; } -thread_id const& -tid() -{ - static thread_local thread_id id; - return id; -} -} // namespace oomph + thread_id const& tid() + { + static thread_local thread_id id; + return id; + } +} // namespace oomph diff --git a/src/communicator.cpp b/src/communicator.cpp index 823042cc..f3e19cd9 100644 --- a/src/communicator.cpp +++ b/src/communicator.cpp @@ -11,99 +11,78 @@ #include // paths relative to backend -#include -#include #include <../message_buffer.hpp> #include <../util/heap_pimpl_src.hpp> +#include +#include OOMPH_INSTANTIATE_HEAP_PIMPL(oomph::detail::message_buffer::heap_ptr_impl) -namespace oomph -{ - -rank_type -communicator::rank() const noexcept -{ - return m_state->m_impl->rank(); -} - -rank_type -communicator::size() const noexcept -{ - return m_state->m_impl->size(); -} - -bool -communicator::is_local(rank_type rank) const noexcept -{ - return m_state->m_impl->is_local(rank); -} - -MPI_Comm -communicator::mpi_comm() const noexcept -{ - return m_state->m_impl->mpi_comm(); -} - -void -communicator::progress() -{ - m_state->m_impl->progress(); -} - -send_request -communicator::send(detail::message_buffer::heap_ptr_impl const* m_ptr, std::size_t size, - rank_type dst, tag_type tag, util::unique_function&& cb) -{ - return m_state->m_impl->send(m_ptr->m, size, dst, tag, std::move(cb), - &(m_state->scheduled_sends)); -} - -recv_request -communicator::recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb) -{ - return 
m_state->m_impl->recv(m_ptr->m, size, src, tag, std::move(cb), - &(m_state->scheduled_recvs)); -} - -shared_recv_request -communicator::shared_recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, - rank_type src, tag_type tag, util::unique_function&& cb) -{ - return m_state->m_impl->shared_recv(m_ptr->m, size, src, tag, std::move(cb), - m_state->m_shared_scheduled_recvs); -} - -detail::message_buffer -communicator::make_buffer_core(std::size_t size) -{ - return m_state->m_impl->get_heap().allocate(size, hwmalloc::numa().local_node()); -} - -detail::message_buffer -communicator::make_buffer_core(void* ptr, std::size_t size) -{ - return m_state->m_impl->get_heap().register_user_allocation(ptr, size); -} +namespace oomph { + + rank_type communicator::rank() const noexcept { return m_state->m_impl->rank(); } + + rank_type communicator::size() const noexcept { return m_state->m_impl->size(); } + + bool communicator::is_local(rank_type rank) const noexcept + { + return m_state->m_impl->is_local(rank); + } + + MPI_Comm communicator::mpi_comm() const noexcept { return m_state->m_impl->mpi_comm(); } + + void communicator::progress() { m_state->m_impl->progress(); } + + send_request communicator::send(detail::message_buffer::heap_ptr_impl const* m_ptr, + std::size_t size, rank_type dst, tag_type tag, + util::unique_function&& cb) + { + return m_state->m_impl->send( + m_ptr->m, size, dst, tag, std::move(cb), &(m_state->scheduled_sends)); + } + + recv_request communicator::recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, + rank_type src, tag_type tag, util::unique_function&& cb) + { + return m_state->m_impl->recv( + m_ptr->m, size, src, tag, std::move(cb), &(m_state->scheduled_recvs)); + } + + shared_recv_request communicator::shared_recv(detail::message_buffer::heap_ptr_impl* m_ptr, + std::size_t size, rank_type src, tag_type tag, + util::unique_function&& cb) + { + return m_state->m_impl->shared_recv( + m_ptr->m, size, src, tag, std::move(cb), m_state->m_shared_scheduled_recvs); + } + + detail::message_buffer communicator::make_buffer_core(std::size_t size) + { + return m_state->m_impl->get_heap().allocate(size, hwmalloc::numa().local_node()); + } + + detail::message_buffer communicator::make_buffer_core(void* ptr, std::size_t size) + { + return m_state->m_impl->get_heap().register_user_allocation(ptr, size); + } #if OOMPH_ENABLE_DEVICE -detail::message_buffer -communicator::make_buffer_core(std::size_t size, int id) -{ - return m_state->m_impl->get_heap().allocate(size, hwmalloc::numa().local_node(), id); -} - -detail::message_buffer -communicator::make_buffer_core(void* device_ptr, std::size_t size, int device_id) -{ - return m_state->m_impl->get_heap().register_user_allocation(device_ptr, device_id, size); -} - -detail::message_buffer -communicator::make_buffer_core(void* ptr, void* device_ptr, std::size_t size, int device_id) -{ - return m_state->m_impl->get_heap().register_user_allocation(ptr, device_ptr, device_id, size); -} + detail::message_buffer communicator::make_buffer_core(std::size_t size, int id) + { + return m_state->m_impl->get_heap().allocate(size, hwmalloc::numa().local_node(), id); + } + + detail::message_buffer communicator::make_buffer_core( + void* device_ptr, std::size_t size, int device_id) + { + return m_state->m_impl->get_heap().register_user_allocation(device_ptr, device_id, size); + } + + detail::message_buffer communicator::make_buffer_core( + void* ptr, void* device_ptr, std::size_t size, int device_id) + { + return 
m_state->m_impl->get_heap().register_user_allocation( + ptr, device_ptr, device_id, size); + } #endif -} // namespace oomph +} // namespace oomph diff --git a/src/communicator_base.hpp b/src/communicator_base.hpp index 69337bc6..c273be62 100644 --- a/src/communicator_base.hpp +++ b/src/communicator_base.hpp @@ -15,38 +15,37 @@ #include <../context_base.hpp> #include <../increment_guard.hpp> -namespace oomph -{ -template -class communicator_base -{ - public: - using pool_factory_type = util::pool_factory; - using recursion_increment = increment_guard; - - protected: - context_base* m_context; - pool_factory_type m_req_state_factory; - std::size_t m_recursion_depth = 0u; - - communicator_base(context_base* ctxt) - : m_context(ctxt) +namespace oomph { + template + class communicator_base { - } - - public: - rank_type rank() const noexcept { return m_context->rank(); } - rank_type size() const noexcept { return m_context->size(); } - MPI_Comm mpi_comm() const noexcept { return m_context->get_comm(); } - rank_topology const& topology() const noexcept { return m_context->topology(); } - void release() { m_context->deregister_communicator(static_cast(this)); } - bool is_local(rank_type rank) const noexcept { return topology().is_local(rank); } - - bool has_reached_recursion_depth() const noexcept - { - return m_recursion_depth > OOMPH_RECURSION_DEPTH; - } - - recursion_increment recursion() noexcept { return {m_recursion_depth}; } -}; -} // namespace oomph + public: + using pool_factory_type = util::pool_factory; + using recursion_increment = increment_guard; + + protected: + context_base* m_context; + pool_factory_type m_req_state_factory; + std::size_t m_recursion_depth = 0u; + + communicator_base(context_base* ctxt) + : m_context(ctxt) + { + } + + public: + rank_type rank() const noexcept { return m_context->rank(); } + rank_type size() const noexcept { return m_context->size(); } + MPI_Comm mpi_comm() const noexcept { return m_context->get_comm(); } + rank_topology const& topology() const noexcept { return m_context->topology(); } + void release() { m_context->deregister_communicator(static_cast(this)); } + bool is_local(rank_type rank) const noexcept { return topology().is_local(rank); } + + bool has_reached_recursion_depth() const noexcept + { + return m_recursion_depth > OOMPH_RECURSION_DEPTH; + } + + recursion_increment recursion() noexcept { return {m_recursion_depth}; } + }; +} // namespace oomph diff --git a/src/communicator_set.cpp b/src/communicator_set.cpp index dd613d83..b35804ca 100644 --- a/src/communicator_set.cpp +++ b/src/communicator_set.cpp @@ -11,52 +11,40 @@ // paths relative to backend #if OOMPH_ENABLE_BARRIER -#include <../communicator_set_impl.hpp> +# include <../communicator_set_impl.hpp> #else -#include <../communicator_set_noop.hpp> +# include <../communicator_set_noop.hpp> #endif #include <../message_buffer.hpp> #include <../util/heap_pimpl_src.hpp> OOMPH_INSTANTIATE_HEAP_PIMPL(oomph::communicator_set::impl) -namespace oomph -{ - -communicator_set& -communicator_set::get() -{ - static communicator_set s; - return s; -} - -communicator_set::communicator_set() -: m_impl{util::make_heap_pimpl()} -{ -} - -void -communicator_set::insert(context_impl const* ctxt, communicator_impl* comm) -{ - m_impl->insert(ctxt, comm); -} - -void -communicator_set::erase(context_impl const* ctxt, communicator_impl* comm) -{ - m_impl->erase(ctxt, comm); -} - -void -communicator_set::erase(context_impl const* ctxt) -{ - m_impl->erase(ctxt); -} - -void 
-communicator_set::progress(context_impl const* ctxt) -{ - m_impl->progress(ctxt); -} - -} // namespace oomph +namespace oomph { + + communicator_set& communicator_set::get() + { + static communicator_set s; + return s; + } + + communicator_set::communicator_set() + : m_impl{util::make_heap_pimpl()} + { + } + + void communicator_set::insert(context_impl const* ctxt, communicator_impl* comm) + { + m_impl->insert(ctxt, comm); + } + + void communicator_set::erase(context_impl const* ctxt, communicator_impl* comm) + { + m_impl->erase(ctxt, comm); + } + + void communicator_set::erase(context_impl const* ctxt) { m_impl->erase(ctxt); } + + void communicator_set::progress(context_impl const* ctxt) { m_impl->progress(ctxt); } + +} // namespace oomph diff --git a/src/communicator_set.hpp b/src/communicator_set.hpp index 18189ef6..a2546d1a 100644 --- a/src/communicator_set.hpp +++ b/src/communicator_set.hpp @@ -10,40 +10,39 @@ #pragma once #include -#include #include +#include #include #include -namespace oomph -{ +namespace oomph { -// singleton -class communicator_set -{ - private: - struct impl; - util::heap_pimpl m_impl; + // singleton + class communicator_set + { + private: + struct impl; + util::heap_pimpl m_impl; - private: - communicator_set(); - communicator_set(communicator_set const&) = delete; - communicator_set& operator=(communicator_set const&) = delete; + private: + communicator_set(); + communicator_set(communicator_set const&) = delete; + communicator_set& operator=(communicator_set const&) = delete; - public: - ~communicator_set() = default; + public: + ~communicator_set() = default; - public: - static communicator_set& get(); + public: + static communicator_set& get(); - public: - void insert(context_impl const* ctxt, communicator_impl* comm); + public: + void insert(context_impl const* ctxt, communicator_impl* comm); - void erase(context_impl const* ctxt, communicator_impl* comm); + void erase(context_impl const* ctxt, communicator_impl* comm); - void erase(context_impl const* ctxt); + void erase(context_impl const* ctxt); - void progress(context_impl const* ctxt); -}; + void progress(context_impl const* ctxt); + }; -} // namespace oomph +} // namespace oomph diff --git a/src/communicator_set_impl.hpp b/src/communicator_set_impl.hpp index 698a3031..4157e9f2 100644 --- a/src/communicator_set_impl.hpp +++ b/src/communicator_set_impl.hpp @@ -9,56 +9,55 @@ */ #pragma once +#include #include #include -#include // paths relative to backend -#include -#include #include <../communicator_set.hpp> #include <../thread_id.hpp> +#include +#include -namespace oomph -{ - -struct communicator_set::impl -{ - using set_type = std::set; - using map_type = std::map; - using mutex = std::mutex; - using lock_guard = std::lock_guard; - - mutex m_mtx; - std::map m_map; - - void insert(context_impl const* ctxt, communicator_impl* comm) - { - auto const& _tid = tid(); - lock_guard lock(m_mtx); - m_map[ctxt][_tid].insert(comm); - } - - void erase(context_impl const* ctxt, communicator_impl* comm) - { - auto const& _tid = tid(); - lock_guard lock(m_mtx); - m_map[ctxt][_tid].erase(comm); - } - - void erase(context_impl const* ctxt) - { - lock_guard lock(m_mtx); - m_map.erase(ctxt); - } +namespace oomph { - void progress(context_impl const* ctxt) + struct communicator_set::impl { - auto const& _tid = tid(); - lock_guard lock(m_mtx); - auto& s = m_map[ctxt][_tid]; - for (auto c : s) c->progress(); - } -}; - -} // namespace oomph + using set_type = std::set; + using map_type = std::map; + using mutex = 
std::mutex; + using lock_guard = std::lock_guard; + + mutex m_mtx; + std::map m_map; + + void insert(context_impl const* ctxt, communicator_impl* comm) + { + auto const& _tid = tid(); + lock_guard lock(m_mtx); + m_map[ctxt][_tid].insert(comm); + } + + void erase(context_impl const* ctxt, communicator_impl* comm) + { + auto const& _tid = tid(); + lock_guard lock(m_mtx); + m_map[ctxt][_tid].erase(comm); + } + + void erase(context_impl const* ctxt) + { + lock_guard lock(m_mtx); + m_map.erase(ctxt); + } + + void progress(context_impl const* ctxt) + { + auto const& _tid = tid(); + lock_guard lock(m_mtx); + auto& s = m_map[ctxt][_tid]; + for (auto c : s) c->progress(); + } + }; + +} // namespace oomph diff --git a/src/communicator_set_noop.hpp b/src/communicator_set_noop.hpp index 01ae8c45..0744f6cb 100644 --- a/src/communicator_set_noop.hpp +++ b/src/communicator_set_noop.hpp @@ -12,18 +12,17 @@ // paths relative to backend #include <../communicator_set.hpp> -namespace oomph -{ +namespace oomph { -struct communicator_set::impl -{ - void insert(context_impl const*, communicator_impl*) {} + struct communicator_set::impl + { + void insert(context_impl const*, communicator_impl*) {} - void erase(context_impl const*, communicator_impl*) {} + void erase(context_impl const*, communicator_impl*) {} - void erase(context_impl const*) {} + void erase(context_impl const*) {} - void progress(context_impl const*) {} -}; + void progress(context_impl const*) {} + }; -} // namespace oomph +} // namespace oomph diff --git a/src/communicator_state.cpp b/src/communicator_state.cpp index a200a4ba..8f687d50 100644 --- a/src/communicator_state.cpp +++ b/src/communicator_state.cpp @@ -8,32 +8,28 @@ * SPDX-License-Identifier: BSD-3-Clause */ #include -#include #include +#include // paths relative to backend -#include #include <../communicator_set.hpp> +#include -namespace oomph -{ -namespace detail -{ -communicator_state::communicator_state(impl_type* impl_, - std::atomic* shared_scheduled_recvs) - //, util::tag_range tr, util::tag_range rtr) -: m_impl{impl_} -, m_shared_scheduled_recvs{shared_scheduled_recvs} -//, m_tag_range(tr) -//, m_reserved_tag_range(rtr) -{ - communicator_set::get().insert(m_impl->m_context, m_impl); -} +namespace oomph { namespace detail { + communicator_state::communicator_state( + impl_type* impl_, std::atomic* shared_scheduled_recvs) + //, util::tag_range tr, util::tag_range rtr) + : m_impl{impl_} + , m_shared_scheduled_recvs{shared_scheduled_recvs} + //, m_tag_range(tr) + //, m_reserved_tag_range(rtr) + { + communicator_set::get().insert(m_impl->m_context, m_impl); + } -communicator_state::~communicator_state() -{ - communicator_set::get().erase(m_impl->m_context, m_impl); - m_impl->release(); -} -} // namespace detail -} // namespace oomph + communicator_state::~communicator_state() + { + communicator_set::get().erase(m_impl->m_context, m_impl); + m_impl->release(); + } +}} // namespace oomph::detail diff --git a/src/context.cpp b/src/context.cpp index e6be5ad0..78da8b96 100644 --- a/src/context.cpp +++ b/src/context.cpp @@ -11,97 +11,75 @@ #include // paths relative to backend -#include -#include -#include <../message_buffer.hpp> #include <../communicator_set.hpp> +#include <../message_buffer.hpp> #include <../util/heap_pimpl_src.hpp> +#include +#include OOMPH_INSTANTIATE_HEAP_PIMPL(oomph::context_impl) OOMPH_INSTANTIATE_HEAP_PIMPL(oomph::detail::message_buffer::heap_ptr_impl) -namespace oomph -{ - -context::context(MPI_Comm comm, bool thread_safe, //unsigned int num_tag_ranges, - bool 
message_pool_never_free, std::size_t message_pool_reserve) -: m_mpi_comm{comm} -, m(m_mpi_comm.get(), thread_safe, message_pool_never_free, message_pool_reserve) -, m_schedule{std::make_unique()} -//, m_tag_range_factory(num_tag_ranges, m->num_tag_bits()) -{ -} - -context::~context() { communicator_set::get().erase(m.get()); } - -communicator -context::get_communicator()//unsigned int tr) -{ - return {m->get_communicator(), &(m_schedule->scheduled_recvs)}; - //, m_tag_range_factory.create(tr), - // m_tag_range_factory.create(tr, true)}; -} - -rank_type -context::rank() const noexcept -{ - return m->rank(); -} - -rank_type -context::size() const noexcept -{ - return m->size(); -} - -rank_type -context::local_rank() const noexcept -{ - return m->topology().local_rank(); -} - -rank_type -context::local_size() const noexcept -{ - return m->topology().local_size(); -} - -const char* -context::get_transport_option(const std::string& opt) -{ - return m->get_transport_option(opt); -} - -detail::message_buffer -context::make_buffer_core(std::size_t size) -{ - return m->get_heap().allocate(size, hwmalloc::numa().local_node()); -} - -detail::message_buffer -context::make_buffer_core(void* ptr, std::size_t size) -{ - return m->get_heap().register_user_allocation(ptr, size); -} +namespace oomph { + + context::context(MPI_Comm comm, bool thread_safe, //unsigned int num_tag_ranges, + bool message_pool_never_free, std::size_t message_pool_reserve) + : m_mpi_comm{comm} + , m(m_mpi_comm.get(), thread_safe, message_pool_never_free, message_pool_reserve) + , m_schedule{std::make_unique()} + //, m_tag_range_factory(num_tag_ranges, m->num_tag_bits()) + { + } + + context::~context() { communicator_set::get().erase(m.get()); } + + communicator context::get_communicator() //unsigned int tr) + { + return {m->get_communicator(), &(m_schedule->scheduled_recvs)}; + //, m_tag_range_factory.create(tr), + // m_tag_range_factory.create(tr, true)}; + } + + rank_type context::rank() const noexcept { return m->rank(); } + + rank_type context::size() const noexcept { return m->size(); } + + rank_type context::local_rank() const noexcept { return m->topology().local_rank(); } + + rank_type context::local_size() const noexcept { return m->topology().local_size(); } + + char const* context::get_transport_option(std::string const& opt) + { + return m->get_transport_option(opt); + } + + detail::message_buffer context::make_buffer_core(std::size_t size) + { + return m->get_heap().allocate(size, hwmalloc::numa().local_node()); + } + + detail::message_buffer context::make_buffer_core(void* ptr, std::size_t size) + { + return m->get_heap().register_user_allocation(ptr, size); + } #if OOMPH_ENABLE_DEVICE -detail::message_buffer -context::make_buffer_core(std::size_t size, int id) -{ - return m->get_heap().allocate(size, hwmalloc::numa().local_node(), id); -} - -detail::message_buffer -context::make_buffer_core(void* device_ptr, std::size_t size, int device_id) -{ - return m->get_heap().register_user_allocation(device_ptr, device_id, size); -} - -detail::message_buffer -context::make_buffer_core(void* ptr, void* device_ptr, std::size_t size, int device_id) -{ - return m->get_heap().register_user_allocation(ptr, device_ptr, device_id, size); -} + detail::message_buffer context::make_buffer_core(std::size_t size, int id) + { + return m->get_heap().allocate(size, hwmalloc::numa().local_node(), id); + } + + detail::message_buffer context::make_buffer_core( + void* device_ptr, std::size_t size, int device_id) + { + return 
m->get_heap().register_user_allocation(device_ptr, device_id, size); + } + + detail::message_buffer context::make_buffer_core( + void* ptr, void* device_ptr, std::size_t size, int device_id) + { + return m->get_heap().register_user_allocation(ptr, device_ptr, device_id, size); + } #endif -} // namespace oomph +} // namespace oomph diff --git a/src/context_base.hpp b/src/context_base.hpp index df81bf31..d3f9be4a 100644 --- a/src/context_base.hpp +++ b/src/context_base.hpp @@ -9,59 +9,59 @@ */ #pragma once -#include #include +#include #include // paths relative to backend +#include <../increment_guard.hpp> #include <../mpi_comm.hpp> -#include <../unique_ptr_set.hpp> #include <../rank_topology.hpp> -#include <../increment_guard.hpp> +#include <../unique_ptr_set.hpp> -namespace oomph -{ -class context_base -{ - public: - using recursion_increment = increment_guard>; +namespace oomph { + class context_base + { + public: + using recursion_increment = increment_guard>; - protected: - mpi_comm m_mpi_comm; - bool const m_thread_safe; - rank_topology const m_rank_topology; - unique_ptr_set m_comms_set; - std::atomic m_recursion_depth = 0u; + protected: + mpi_comm m_mpi_comm; + bool const m_thread_safe; + rank_topology const m_rank_topology; + unique_ptr_set m_comms_set; + std::atomic m_recursion_depth = 0u; - public: - context_base(MPI_Comm comm, bool thread_safe) - : m_mpi_comm{comm} - , m_thread_safe{thread_safe} - , m_rank_topology(comm) - { - int mpi_thread_safety; - OOMPH_CHECK_MPI_RESULT(MPI_Query_thread(&mpi_thread_safety)); - if (m_thread_safe && !(mpi_thread_safety == MPI_THREAD_MULTIPLE)) - throw std::runtime_error("oomph: MPI is not thread safe!"); - else if (!m_thread_safe && !(mpi_thread_safety == MPI_THREAD_SINGLE) && rank() == 0) - std::cerr << "oomph warning: MPI thread safety is higher than required" << std::endl; - } + public: + context_base(MPI_Comm comm, bool thread_safe) + : m_mpi_comm{comm} + , m_thread_safe{thread_safe} + , m_rank_topology(comm) + { + int mpi_thread_safety; + OOMPH_CHECK_MPI_RESULT(MPI_Query_thread(&mpi_thread_safety)); + if (m_thread_safe && !(mpi_thread_safety == MPI_THREAD_MULTIPLE)) + throw std::runtime_error("oomph: MPI is not thread safe!"); + else if (!m_thread_safe && !(mpi_thread_safety == MPI_THREAD_SINGLE) && rank() == 0) + std::cerr << "oomph warning: MPI thread safety is higher than required" + << std::endl; + } - public: - rank_type rank() const noexcept { return m_mpi_comm.rank(); } - rank_type size() const noexcept { return m_mpi_comm.size(); } - rank_topology const& topology() const noexcept { return m_rank_topology; } - MPI_Comm get_comm() const noexcept { return m_mpi_comm; } - bool thread_safe() const noexcept { return m_thread_safe; } + public: + rank_type rank() const noexcept { return m_mpi_comm.rank(); } + rank_type size() const noexcept { return m_mpi_comm.size(); } + rank_topology const& topology() const noexcept { return m_rank_topology; } + MPI_Comm get_comm() const noexcept { return m_mpi_comm; } + bool thread_safe() const noexcept { return m_thread_safe; } - void deregister_communicator(communicator_impl* c) { m_comms_set.remove(c); } + void deregister_communicator(communicator_impl* c) { m_comms_set.remove(c); } - bool has_reached_recursion_depth() const noexcept - { - return m_recursion_depth > OOMPH_RECURSION_DEPTH; - } + bool has_reached_recursion_depth() const noexcept + { + return m_recursion_depth > OOMPH_RECURSION_DEPTH; + } - recursion_increment recursion() noexcept { return {m_recursion_depth}; } -}; + 
recursion_increment recursion() noexcept { return {m_recursion_depth}; } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/debug.hpp b/src/debug.hpp index 1ee5db8c..ebe074f0 100644 --- a/src/debug.hpp +++ b/src/debug.hpp @@ -13,28 +13,25 @@ #include #if (OOMPH_DEBUG_LEVEL >= 2) -#define OOMPH_LOG(msg, ...) \ - do \ - { \ - time_t tm = time(NULL); \ - char* stm = ctime(&tm); \ - stm[strlen(stm) - 1] = 0; \ - (void)fprintf(stderr, "%s %s:%d " msg "\n", stm, __FILE__, __LINE__, ##__VA_ARGS__); \ - (void)fflush(stderr); \ - } while (0); +# define OOMPH_LOG(msg, ...) \ + do { \ + time_t tm = time(NULL); \ + char* stm = ctime(&tm); \ + stm[strlen(stm) - 1] = 0; \ + (void) fprintf(stderr, "%s %s:%d " msg "\n", stm, __FILE__, __LINE__, ##__VA_ARGS__); \ + (void) fflush(stderr); \ + } while (0); #else -#define OOMPH_LOG(msg, ...) \ - do \ - { \ - } while (0); +# define OOMPH_LOG(msg, ...) \ + do { \ + } while (0); #endif #define OOMPH_WARN(msg, ...) \ - do \ - { \ + do { \ time_t tm = time(NULL); \ - char* stm = ctime(&tm); \ + char* stm = ctime(&tm); \ stm[strlen(stm) - 1] = 0; \ - (void)fprintf( \ + (void) fprintf( \ stderr, "%s WARNING: %s:%d " msg "\n", stm, __FILE__, __LINE__, ##__VA_ARGS__); \ } while (0); diff --git a/src/device_guard.hpp b/src/device_guard.hpp index 30f6763e..4b67eebb 100644 --- a/src/device_guard.hpp +++ b/src/device_guard.hpp @@ -12,84 +12,71 @@ #include #include -namespace oomph -{ -struct device_guard_base -{ - bool m_on_device; - int m_new_device_id; - int m_current_device_id; - - device_guard_base(bool on_device = false, int new_id = 0) - : m_on_device{on_device} - , m_new_device_id{new_id} +namespace oomph { + struct device_guard_base { + bool m_on_device; + int m_new_device_id; + int m_current_device_id; + + device_guard_base(bool on_device = false, int new_id = 0) + : m_on_device{on_device} + , m_new_device_id{new_id} + { #if OOMPH_ENABLE_DEVICE - m_current_device_id = hwmalloc::get_device_id(); - if (m_on_device && (m_current_device_id != m_new_device_id)) - hwmalloc::set_device_id(m_new_device_id); + m_current_device_id = hwmalloc::get_device_id(); + if (m_on_device && (m_current_device_id != m_new_device_id)) + hwmalloc::set_device_id(m_new_device_id); #endif - } + } - device_guard_base(device_guard_base const&) = delete; + device_guard_base(device_guard_base const&) = delete; - ~device_guard_base() - { + ~device_guard_base() + { #if OOMPH_ENABLE_DEVICE - if (m_on_device && (m_current_device_id != m_new_device_id)) - hwmalloc::set_device_id(m_current_device_id); + if (m_on_device && (m_current_device_id != m_new_device_id)) + hwmalloc::set_device_id(m_current_device_id); #endif - } -}; + } + }; -struct device_guard : public device_guard_base -{ - void* m_ptr; + struct device_guard : public device_guard_base + { + void* m_ptr; - template - device_guard(Pointer& ptr) + template + device_guard(Pointer& ptr) #if OOMPH_ENABLE_DEVICE - : device_guard_base(ptr.on_device(), ptr.device_id()) - , m_ptr - { - ptr.on_device() ? ptr.device_ptr() : ptr.get() - } + : device_guard_base(ptr.on_device(), ptr.device_id()) + , m_ptr{ptr.on_device() ? 
ptr.device_ptr() : ptr.get()} #else - : device_guard_base() - , m_ptr - { - ptr.get() - } + : device_guard_base() + , m_ptr{ptr.get()} #endif - { - } + { + } - void* data() const noexcept { return m_ptr; } -}; + void* data() const noexcept { return m_ptr; } + }; -struct const_device_guard : public device_guard_base -{ - void const* m_ptr; + struct const_device_guard : public device_guard_base + { + void const* m_ptr; - template - const_device_guard(Pointer const& ptr) + template + const_device_guard(Pointer const& ptr) #if OOMPH_ENABLE_DEVICE - : device_guard_base(ptr.on_device(), ptr.device_id()) - , m_ptr - { - ptr.on_device() ? ptr.device_ptr() : ptr.get() - } + : device_guard_base(ptr.on_device(), ptr.device_id()) + , m_ptr{ptr.on_device() ? ptr.device_ptr() : ptr.get()} #else - : device_guard_base() - , m_ptr - { - ptr.get() - } + : device_guard_base() + , m_ptr{ptr.get()} #endif - { - } + { + } - void const* data() const noexcept { return m_ptr; } -}; + void const* data() const noexcept { return m_ptr; } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/increment_guard.hpp b/src/increment_guard.hpp index 538925ae..622f67c4 100644 --- a/src/increment_guard.hpp +++ b/src/increment_guard.hpp @@ -9,35 +9,34 @@ */ #pragma once -namespace oomph -{ +namespace oomph { -template -class increment_guard -{ - private: - T* m = nullptr; - - public: - increment_guard(T& r) noexcept - : m{&r} + template + class increment_guard { - ++(*m); - } + private: + T* m = nullptr; - increment_guard(increment_guard&& other) noexcept - : m(other.m) - { - other.m = nullptr; - } + public: + increment_guard(T& r) noexcept + : m{&r} + { + ++(*m); + } - increment_guard(increment_guard const&) = delete; - increment_guard& operator=(increment_guard const&) = delete; - increment_guard& operator=(increment_guard&&) = delete; + increment_guard(increment_guard&& other) noexcept + : m(other.m) + { + other.m = nullptr; + } - ~increment_guard() - { - if (m) --(*m); - } -}; -} // namespace oomph + increment_guard(increment_guard const&) = delete; + increment_guard& operator=(increment_guard const&) = delete; + increment_guard& operator=(increment_guard&&) = delete; + + ~increment_guard() + { + if (m) --(*m); + } + }; +} // namespace oomph diff --git a/src/libfabric/communicator.hpp b/src/libfabric/communicator.hpp index ff8fc945..a38419dc 100644 --- a/src/libfabric/communicator.hpp +++ b/src/libfabric/communicator.hpp @@ -14,108 +14,109 @@ #include -#include #include +#include // paths relative to backend #include <../communicator_base.hpp> #include <../device_guard.hpp> +#include +#include #include #include -#include -#include - -namespace oomph -{ -using operation_context = libfabric::operation_context; +namespace oomph { -using tag_disp = NS_DEBUG::detail::hex<12, uintptr_t>; + using operation_context = libfabric::operation_context; -template -inline /*constexpr*/ NS_DEBUG::print_threshold com_deb("COMMUNI"); + using tag_disp = NS_DEBUG::detail::hex<12, uintptr_t>; -static NS_DEBUG::enable_print com_err("COMMUNI"); + template + inline /*constexpr*/ NS_DEBUG::print_threshold com_deb("COMMUNI"); -class communicator_impl : public communicator_base -{ - using tag_type = std::uint64_t; - // - using segment_type = libfabric::memory_segment; - using region_type = segment_type::handle_type; + static NS_DEBUG::enable_print com_err("COMMUNI"); - using callback_queue = boost::lockfree::queue, boost::lockfree::allocator>>; - - public: - context_impl* m_context; - libfabric::endpoint_wrapper m_tx_endpoint; - 
libfabric::endpoint_wrapper m_rx_endpoint; - // - callback_queue m_send_cb_queue; - callback_queue m_recv_cb_queue; - callback_queue m_recv_cb_cancel; - - // -------------------------------------------------------------------- - communicator_impl(context_impl* ctxt) - : communicator_base(ctxt) - , m_context(ctxt) - , m_send_cb_queue(128) - , m_recv_cb_queue(128) - , m_recv_cb_cancel(8) + class communicator_impl : public communicator_base { - LF_DEB(com_deb<9>, debug(NS_DEBUG::str<>("MPI_comm"), NS_DEBUG::ptr(mpi_comm()))); - m_tx_endpoint = m_context->get_controller()->get_tx_endpoint(); - m_rx_endpoint = m_context->get_controller()->get_rx_endpoint(); - } + using tag_type = std::uint64_t; + // + using segment_type = libfabric::memory_segment; + using region_type = segment_type::handle_type; + + using callback_queue = boost::lockfree::queue, boost::lockfree::allocator>>; + + public: + context_impl* m_context; + libfabric::endpoint_wrapper m_tx_endpoint; + libfabric::endpoint_wrapper m_rx_endpoint; + // + callback_queue m_send_cb_queue; + callback_queue m_recv_cb_queue; + callback_queue m_recv_cb_cancel; + + // -------------------------------------------------------------------- + communicator_impl(context_impl* ctxt) + : communicator_base(ctxt) + , m_context(ctxt) + , m_send_cb_queue(128) + , m_recv_cb_queue(128) + , m_recv_cb_cancel(8) + { + LF_DEB(com_deb<9>, debug(NS_DEBUG::str<>("MPI_comm"), NS_DEBUG::ptr(mpi_comm()))); + m_tx_endpoint = m_context->get_controller()->get_tx_endpoint(); + m_rx_endpoint = m_context->get_controller()->get_rx_endpoint(); + } - // -------------------------------------------------------------------- - ~communicator_impl() { clear_callback_queues(); } + // -------------------------------------------------------------------- + ~communicator_impl() { clear_callback_queues(); } - // -------------------------------------------------------------------- - auto& get_heap() noexcept { return m_context->get_heap(); } + // -------------------------------------------------------------------- + auto& get_heap() noexcept { return m_context->get_heap(); } - // -------------------------------------------------------------------- - /// generate a tag with 0xRRRRRRRRtttttttt rank, tag. - /// original tag can be 32bits, then we add 32bits of rank info. - inline std::uint64_t make_tag64(std::uint32_t tag, /*std::uint32_t rank, */ std::uintptr_t ctxt) - { - return (((ctxt & 0x0000000000FFFFFF) << 24) | ((std::uint64_t(tag) & 0x0000000000FFFFFF))); - } + // -------------------------------------------------------------------- + /// generate a tag with 0xRRRRRRRRtttttttt rank, tag. + /// original tag can be 32bits, then we add 32bits of rank info. + inline std::uint64_t make_tag64( + std::uint32_t tag, /*std::uint32_t rank, */ std::uintptr_t ctxt) + { + return (((ctxt & 0x0000'0000'00FF'FFFF) << 24) | + ((std::uint64_t(tag) & 0x0000'0000'00FF'FFFF))); + } - // -------------------------------------------------------------------- - template - inline void execute_fi_function(Func F, const char* msg, Args&&... args) - { - bool ok = false; - while (!ok) + // -------------------------------------------------------------------- + template + inline void execute_fi_function(Func F, char const* msg, Args&&... 
args) { - ssize_t ret = F(std::forward(args)...); - if (ret == 0) { return; } - else if (ret == -FI_EAGAIN) - { - // com_deb<9>.error("Reposting", msg); - // no point stressing the system - m_context->get_controller()->poll_for_work_completions(this); - } - else if (ret == -FI_ENOENT) + bool ok = false; + while (!ok) { - // if a node has failed, we can recover - // @TODO : put something better here - com_err.error("No destination endpoint, terminating."); - std::terminate(); + ssize_t ret = F(std::forward(args)...); + if (ret == 0) { return; } + else if (ret == -FI_EAGAIN) + { + // com_deb<9>.error("Reposting", msg); + // no point stressing the system + m_context->get_controller()->poll_for_work_completions(this); + } + else if (ret == -FI_ENOENT) + { + // if a node has failed, we can recover + // @TODO : put something better here + com_err.error("No destination endpoint, terminating."); + std::terminate(); + } + else if (ret) { throw NS_LIBFABRIC::fabric_error(int(ret), msg); } } - else if (ret) { throw NS_LIBFABRIC::fabric_error(int(ret), msg); } } - } - // -------------------------------------------------------------------- - // this takes a pinned memory region and sends it - void send_tagged_region(region_type const& send_region, std::size_t size, fi_addr_t dst_addr_, - uint64_t tag_, operation_context* ctxt) - { - [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - // clang-format off + // -------------------------------------------------------------------- + // this takes a pinned memory region and sends it + void send_tagged_region(region_type const& send_region, std::size_t size, + fi_addr_t dst_addr_, uint64_t tag_, operation_context* ctxt) + { + [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); + // clang-format off LF_DEB(com_deb<9>, debug(NS_DEBUG::str<>("send_tagged_region"), "->", NS_DEBUG::dec<2>(dst_addr_), @@ -123,22 +124,24 @@ class communicator_impl : public communicator_base "tag", tag_disp(tag_), "context", NS_DEBUG::ptr(ctxt), "tx endpoint", NS_DEBUG::ptr(m_tx_endpoint.get_ep()))); - // clang-format on - execute_fi_function(fi_tsend, "fi_tsend", m_tx_endpoint.get_ep(), send_region.get_address(), - size, send_region.get_local_key(), dst_addr_, tag_, ctxt); - } + // clang-format on + execute_fi_function(fi_tsend, "fi_tsend", m_tx_endpoint.get_ep(), + send_region.get_address(), size, send_region.get_local_key(), dst_addr_, tag_, + ctxt); + } - // -------------------------------------------------------------------- - // this takes a pinned memory region and sends it using inject instead of send - void inject_tagged_region(region_type const& send_region, std::size_t size, fi_addr_t dst_addr_, - uint64_t tag_) - { - [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - // clang-format on - LF_DEB(com_deb<9>, - debug(NS_DEBUG::str<>("inject tagged"), "->", NS_DEBUG::dec<2>(dst_addr_), send_region, - "tag", tag_disp(tag_), "tx endpoint", NS_DEBUG::ptr(m_tx_endpoint.get_ep()))); - // clang-format off + // -------------------------------------------------------------------- + // this takes a pinned memory region and sends it using inject instead of send + void inject_tagged_region( + region_type const& send_region, std::size_t size, fi_addr_t dst_addr_, uint64_t tag_) + { + [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); + // clang-format on + LF_DEB(com_deb<9>, + debug(NS_DEBUG::str<>("inject tagged"), "->", NS_DEBUG::dec<2>(dst_addr_), + send_region, "tag", 
tag_disp(tag_), "tx endpoint", + NS_DEBUG::ptr(m_tx_endpoint.get_ep()))); + // clang-format off execute_fi_function(fi_tinject, "fi_tinject", m_tx_endpoint.get_ep(), send_region.get_address(), size, dst_addr_, tag_); } @@ -159,62 +162,65 @@ class communicator_impl : public communicator_base "tag", tag_disp(tag_), "context", NS_DEBUG::ptr(ctxt), "rx endpoint", NS_DEBUG::ptr(m_rx_endpoint.get_ep()))); - // clang-format on - constexpr uint64_t ignore = 0; - execute_fi_function(fi_trecv, "fi_trecv", m_rx_endpoint.get_ep(), recv_region.get_address(), - size, recv_region.get_local_key(), src_addr_, tag_, ignore, ctxt); - // if (l.owns_lock()) l.unlock(); - } + // clang-format on + constexpr uint64_t ignore = 0; + execute_fi_function(fi_trecv, "fi_trecv", m_rx_endpoint.get_ep(), + recv_region.get_address(), size, recv_region.get_local_key(), src_addr_, tag_, + ignore, ctxt); + // if (l.owns_lock()) l.unlock(); + } - // -------------------------------------------------------------------- - send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - oomph::tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) - { - [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - std::uint64_t stag = make_tag64(tag, /*this->rank(), */ this->m_context->get_context_tag()); + // -------------------------------------------------------------------- + send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, + rank_type dst, oomph::tag_type tag, + util::unique_function&& cb, std::size_t* scheduled) + { + [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); + std::uint64_t stag = + make_tag64(tag, /*this->rank(), */ this->m_context->get_context_tag()); #if OOMPH_ENABLE_DEVICE - auto const& reg = ptr.on_device() ? ptr.device_handle() : ptr.handle(); + auto const& reg = ptr.on_device() ? 
ptr.device_handle() : ptr.handle(); #else - auto const& reg = ptr.handle(); + auto const& reg = ptr.handle(); #endif #ifdef EXTRA_SIZE_CHECKS - if (size != reg.get_size()) - { - LF_DEB(com_err, error(NS_DEBUG::str<>("send mismatch"), "size", NS_DEBUG::hex<6>(size), - "reg size", NS_DEBUG::hex<6>(reg.get_size()))); - } -#endif - m_context->get_controller()->sends_posted_++; - - // use optimized inject if msg is very small - if (size <= m_context->get_controller()->get_tx_inject_size()) - { - inject_tagged_region(reg, size, fi_addr_t(dst), stag); - if (!has_reached_recursion_depth()) + if (size != reg.get_size()) { - auto inc = recursion(); - cb(dst, tag); - return {}; + LF_DEB(com_err, + error(NS_DEBUG::str<>("send mismatch"), "size", NS_DEBUG::hex<6>(size), + "reg size", NS_DEBUG::hex<6>(reg.get_size()))); } - else +#endif + m_context->get_controller()->sends_posted_++; + + // use optimized inject if msg is very small + if (size <= m_context->get_controller()->get_tx_inject_size()) { - // construct request which is also an operation context - auto s = - m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb)); - s->create_self_ref(); - while (!m_send_cb_queue.push(s.get())) {} - return {std::move(s)}; + inject_tagged_region(reg, size, fi_addr_t(dst), stag); + if (!has_reached_recursion_depth()) + { + auto inc = recursion(); + cb(dst, tag); + return {}; + } + else + { + // construct request which is also an operation context + auto s = m_req_state_factory.make( + m_context, this, scheduled, dst, tag, std::move(cb)); + s->create_self_ref(); + while (!m_send_cb_queue.push(s.get())) {} + return {std::move(s)}; + } } - } - // construct request which is also an operation context - auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb)); - s->create_self_ref(); + // construct request which is also an operation context + auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb)); + s->create_self_ref(); - // clang-format off + // clang-format off LF_DEB(com_deb<9>, debug(NS_DEBUG::str<>("Send"), "thisrank", NS_DEBUG::dec<>(rank()), @@ -234,39 +240,40 @@ class communicator_impl : public communicator_base NS_DEBUG::mem_crc32(reg.get_address(), size, "CRC32"))); } #endif - // clang-format on + // clang-format on - send_tagged_region(reg, size, fi_addr_t(dst), stag, &(s->m_operation_context)); - return {std::move(s)}; - } + send_tagged_region(reg, size, fi_addr_t(dst), stag, &(s->m_operation_context)); + return {std::move(s)}; + } - recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - oomph::tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) - { - [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - std::uint64_t stag = make_tag64(tag, /*src, */ this->m_context->get_context_tag()); + recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, + oomph::tag_type tag, util::unique_function&& cb, + std::size_t* scheduled) + { + [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); + std::uint64_t stag = make_tag64(tag, /*src, */ this->m_context->get_context_tag()); #if OOMPH_ENABLE_DEVICE - auto const& reg = ptr.on_device() ? ptr.device_handle() : ptr.handle(); + auto const& reg = ptr.on_device() ? 
ptr.device_handle() : ptr.handle(); #else - auto const& reg = ptr.handle(); + auto const& reg = ptr.handle(); #endif #ifdef EXTRA_SIZE_CHECKS - if (size != reg.get_size()) - { - LF_DEB(com_err, error(NS_DEBUG::str<>("recv mismatch"), "size", NS_DEBUG::hex<6>(size), - "reg size", NS_DEBUG::hex<6>(reg.get_size()))); - } + if (size != reg.get_size()) + { + LF_DEB(com_err, + error(NS_DEBUG::str<>("recv mismatch"), "size", NS_DEBUG::hex<6>(size), + "reg size", NS_DEBUG::hex<6>(reg.get_size()))); + } #endif - m_context->get_controller()->recvs_posted_++; + m_context->get_controller()->recvs_posted_++; - // construct request which is also an operation context - auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb)); - s->create_self_ref(); + // construct request which is also an operation context + auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb)); + s->create_self_ref(); - // clang-format off + // clang-format off LF_DEB(com_deb<9>, debug(NS_DEBUG::str<>("recv"), "thisrank", NS_DEBUG::dec<>(rank()), @@ -286,41 +293,42 @@ class communicator_impl : public communicator_base NS_DEBUG::mem_crc32(reg.get_address(), size, "CRC32"))); } #endif - // clang-format on + // clang-format on - recv_tagged_region(reg, size, fi_addr_t(src), stag, &(s->m_operation_context)); - return {std::move(s)}; - } + recv_tagged_region(reg, size, fi_addr_t(src), stag, &(s->m_operation_context)); + return {std::move(s)}; + } - shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, - rank_type src, oomph::tag_type tag, - util::unique_function&& cb, - std::atomic* scheduled) - { - [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - std::uint64_t stag = make_tag64(tag, /*src, */ this->m_context->get_context_tag()); + shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, + rank_type src, oomph::tag_type tag, + util::unique_function&& cb, + std::atomic* scheduled) + { + [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); + std::uint64_t stag = make_tag64(tag, /*src, */ this->m_context->get_context_tag()); #if OOMPH_ENABLE_DEVICE - auto const& reg = ptr.on_device() ? ptr.device_handle() : ptr.handle(); + auto const& reg = ptr.on_device() ? 
ptr.device_handle() : ptr.handle(); #else - auto const& reg = ptr.handle(); + auto const& reg = ptr.handle(); #endif #ifdef EXTRA_SIZE_CHECKS - if (size != reg.get_size()) - { - LF_DEB(com_err, error(NS_DEBUG::str<>("recv mismatch"), "size", NS_DEBUG::hex<6>(size), - "reg size", NS_DEBUG::hex<6>(reg.get_size()))); - } + if (size != reg.get_size()) + { + LF_DEB(com_err, + error(NS_DEBUG::str<>("recv mismatch"), "size", NS_DEBUG::hex<6>(size), + "reg size", NS_DEBUG::hex<6>(reg.get_size()))); + } #endif - m_context->get_controller()->recvs_posted_++; + m_context->get_controller()->recvs_posted_++; - // construct request which is also an operation context - auto s = std::make_shared(m_context, this, scheduled, src, - tag, std::move(cb)); - s->create_self_ref(); + // construct request which is also an operation context + auto s = std::make_shared( + m_context, this, scheduled, src, tag, std::move(cb)); + s->create_self_ref(); - // clang-format off + // clang-format off LF_DEB(com_deb<9>, debug(NS_DEBUG::str<>("shared_recv"), "thisrank", NS_DEBUG::dec<>(rank()), @@ -333,102 +341,97 @@ class communicator_impl : public communicator_base "reg size", NS_DEBUG::hex<6>(reg.get_size()), "op_ctx", NS_DEBUG::ptr(&(s->m_operation_context)), "req", NS_DEBUG::ptr(s.get()))); - // clang-format on + // clang-format on - recv_tagged_region(reg, size, fi_addr_t(src), stag, &(s->m_operation_context)); - m_context->get_controller()->poll_recv_queue(m_rx_endpoint.get_rx_cq(), this); - return {std::move(s)}; - } + recv_tagged_region(reg, size, fi_addr_t(src), stag, &(s->m_operation_context)); + m_context->get_controller()->poll_recv_queue(m_rx_endpoint.get_rx_cq(), this); + return {std::move(s)}; + } - void progress() - { - m_context->get_controller()->poll_for_work_completions(this); - clear_callback_queues(); - } + void progress() + { + m_context->get_controller()->poll_for_work_completions(this); + clear_callback_queues(); + } - void clear_callback_queues() - { - // work through ready callbacks, which were pushed to the queue - // (by other threads) - m_send_cb_queue.consume_all( - [](oomph::detail::request_state* req) - { + void clear_callback_queues() + { + // work through ready callbacks, which were pushed to the queue + // (by other threads) + m_send_cb_queue.consume_all([](oomph::detail::request_state* req) { [[maybe_unused]] auto scp = com_deb<9>.scope("m_send_cb_queue.consume_all", NS_DEBUG::ptr(req)); auto ptr = req->release_self_ref(); req->invoke_cb(); }); - m_recv_cb_queue.consume_all( - [](oomph::detail::request_state* req) - { + m_recv_cb_queue.consume_all([](oomph::detail::request_state* req) { [[maybe_unused]] auto scp = com_deb<9>.scope("m_recv_cb_queue.consume_all", NS_DEBUG::ptr(req)); auto ptr = req->release_self_ref(); req->invoke_cb(); }); - m_context->m_recv_cb_queue.consume_all( - [](detail::shared_request_state* req) - { + m_context->m_recv_cb_queue.consume_all([](detail::shared_request_state* req) { auto ptr = req->release_self_ref(); req->invoke_cb(); }); - } + } - // Cancel is a problem with libfabric because fi_cancel is asynchronous. - // The item to be cancelled will either complete with CANCELLED status - // or will complete as usual (ie before the cancel could take effect) - // - // We can only be certain if we poll until the completion happens - // or attach a callback to the cancel notification which is not supported - // by oomph. 
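// For reference, a compact standalone illustration of the drain-and-restash pattern that
// cancel_recv applies to the cancellation queue: pop entries until our own request is found,
// stash unrelated entries, and push them back so their owners can still observe them.
// std::queue and the opaque `request` type stand in for the lock-free boost queue and
// detail::request_state; this is a simplified, single-threaded sketch, not the oomph code.
#include <queue>
#include <stack>

struct request;    // opaque stand-in for the real request state

bool find_cancelled(std::queue<request*>& cancelled, request* mine)
{
    bool found = false;
    std::stack<request*> others;    // cancellations belonging to other requests
    while (!found && !cancelled.empty())
    {
        request* r = cancelled.front();
        cancelled.pop();
        if (r == mine) found = true;
        else others.push(r);
    }
    // return unrelated entries so they remain visible to their owners
    while (!others.empty())
    {
        cancelled.push(others.top());
        others.pop();
    }
    return found;
}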
- bool cancel_recv(detail::request_state* s) - { - // get the original message operation context - operation_context* op_ctx = &(s->m_operation_context); + // Cancel is a problem with libfabric because fi_cancel is asynchronous. + // The item to be cancelled will either complete with CANCELLED status + // or will complete as usual (ie before the cancel could take effect) + // + // We can only be certain if we poll until the completion happens + // or attach a callback to the cancel notification which is not supported + // by oomph. + bool cancel_recv(detail::request_state* s) + { + // get the original message operation context + operation_context* op_ctx = &(s->m_operation_context); - // submit the cancellation request - bool ok = (fi_cancel(&m_rx_endpoint.get_ep()->fid, op_ctx) == 0); - LF_DEB(com_deb<9>, - debug(NS_DEBUG::str<>("Cancel"), "ok", ok, "op_ctx", NS_DEBUG::ptr(op_ctx))); + // submit the cancellation request + bool ok = (fi_cancel(&m_rx_endpoint.get_ep()->fid, op_ctx) == 0); + LF_DEB(com_deb<9>, + debug(NS_DEBUG::str<>("Cancel"), "ok", ok, "op_ctx", NS_DEBUG::ptr(op_ctx))); - // if the cancel operation failed completely, return - if (!ok) return false; + // if the cancel operation failed completely, return + if (!ok) return false; - bool found = false; - while (!found) - { - m_context->get_controller()->poll_recv_queue(m_rx_endpoint.get_rx_cq(), this); - // otherwise, poll until we know if it worked - std::stack temp_stack; - detail::request_state* temp; - while (!found && m_recv_cb_cancel.pop(temp)) + bool found = false; + while (!found) { - if (temp == s) + m_context->get_controller()->poll_recv_queue(m_rx_endpoint.get_rx_cq(), this); + // otherwise, poll until we know if it worked + std::stack temp_stack; + detail::request_state* temp; + while (!found && m_recv_cb_cancel.pop(temp)) { - // our recv was cancelled correctly - found = true; - LF_DEB(com_deb<9>, debug(NS_DEBUG::str<>("Cancel"), "succeeded", "op_ctx", - NS_DEBUG::ptr(op_ctx))); - auto ptr = s->release_self_ref(); - s->set_canceled(); + if (temp == s) + { + // our recv was cancelled correctly + found = true; + LF_DEB(com_deb<9>, + debug(NS_DEBUG::str<>("Cancel"), "succeeded", "op_ctx", + NS_DEBUG::ptr(op_ctx))); + auto ptr = s->release_self_ref(); + s->set_canceled(); + } + else + { + // a different cancel operation + temp_stack.push(temp); + } } - else + // return any weird unhandled cancels back to the queue + while (!temp_stack.empty()) { - // a different cancel operation - temp_stack.push(temp); + auto temp = temp_stack.top(); + temp_stack.pop(); + m_recv_cb_cancel.push(temp); } } - // return any weird unhandled cancels back to the queue - while (!temp_stack.empty()) - { - auto temp = temp_stack.top(); - temp_stack.pop(); - m_recv_cb_cancel.push(temp); - } + return found; } - return found; - } -}; + }; -} // namespace oomph +} // namespace oomph diff --git a/src/libfabric/context.cpp b/src/libfabric/context.cpp index 9365be8a..cb7757a2 100644 --- a/src/libfabric/context.cpp +++ b/src/libfabric/context.cpp @@ -11,84 +11,82 @@ // #include // paths relative to backend -#include -#include #include #include +#include +#include -namespace oomph -{ -// cppcheck-suppress ConfigurationNotChecked -static NS_DEBUG::enable_print src_deb("__SRC__"); +namespace oomph { + // cppcheck-suppress ConfigurationNotChecked + static NS_DEBUG::enable_print src_deb("__SRC__"); -using controller_type = libfabric::controller; + using controller_type = libfabric::controller; -context_impl::context_impl(MPI_Comm comm, bool thread_safe, 
bool message_pool_never_free, - std::size_t message_pool_reserve) -: context_base(comm, thread_safe) -, m_heap{this, message_pool_never_free, message_pool_reserve} -, m_recv_cb_queue(128) -, m_recv_cb_cancel(8) -{ - int rank, size; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(comm, &rank)); - OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(comm, &size)); + context_impl::context_impl(MPI_Comm comm, bool thread_safe, bool message_pool_never_free, + std::size_t message_pool_reserve) + : context_base(comm, thread_safe) + , m_heap{this, message_pool_never_free, message_pool_reserve} + , m_recv_cb_queue(128) + , m_recv_cb_cancel(8) + { + int rank, size; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(comm, &rank)); + OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(comm, &size)); - m_ctxt_tag = reinterpret_cast(this); - OOMPH_CHECK_MPI_RESULT(MPI_Bcast(&m_ctxt_tag, 1, MPI_UINT64_T, 0, comm)); - LF_DEB(src_deb, debug(NS_DEBUG::str<>("Broadcast"), "rank", debug::dec<3>(rank), "context", - debug::ptr(m_ctxt_tag))); + m_ctxt_tag = reinterpret_cast(this); + OOMPH_CHECK_MPI_RESULT(MPI_Bcast(&m_ctxt_tag, 1, MPI_UINT64_T, 0, comm)); + LF_DEB(src_deb, + debug(NS_DEBUG::str<>("Broadcast"), "rank", debug::dec<3>(rank), "context", + debug::ptr(m_ctxt_tag))); - // TODO fix the thread safety - // problem: controller is a singleton and has problems when 2 contexts are created in the - // following order: single threaded first, then multi-threaded after - //int threads = thread_safe ? std::thread::hardware_concurrency() : 1; - //int threads = std::thread::hardware_concurrency(); - int threads = boost::thread::physical_concurrency(); - m_controller = init_libfabric_controller(this, comm, rank, size, threads); - m_domain = m_controller->get_domain(); -} + // TODO fix the thread safety + // problem: controller is a singleton and has problems when 2 contexts are created in the + // following order: single threaded first, then multi-threaded after + //int threads = thread_safe ? 
std::thread::hardware_concurrency() : 1; + //int threads = std::thread::hardware_concurrency(); + int threads = boost::thread::physical_concurrency(); + m_controller = init_libfabric_controller(this, comm, rank, size, threads); + m_domain = m_controller->get_domain(); + } -communicator_impl* -context_impl::get_communicator() -{ - auto comm = new communicator_impl{this}; - m_comms_set.insert(comm); - return comm; -} + communicator_impl* context_impl::get_communicator() + { + auto comm = new communicator_impl{this}; + m_comms_set.insert(comm); + return comm; + } -const char* -context_impl::get_transport_option(const std::string& opt) -{ - if (opt == "name") { return "libfabric"; } - else if (opt == "progress") { return libfabric_progress_string(); } - else if (opt == "endpoint") { return libfabric_endpoint_string(); } - else if (opt == "rendezvous_threshold") + char const* context_impl::get_transport_option(std::string const& opt) { - static char buffer[32]; - std::string temp = std::to_string(m_controller->rendezvous_threshold()); - strncpy(buffer, temp.c_str(), std::min(size_t(31), std::strlen(temp.c_str()))); - return buffer; + if (opt == "name") { return "libfabric"; } + else if (opt == "progress") { return libfabric_progress_string(); } + else if (opt == "endpoint") { return libfabric_endpoint_string(); } + else if (opt == "rendezvous_threshold") + { + static char buffer[32]; + std::string temp = std::to_string(m_controller->rendezvous_threshold()); + strncpy(buffer, temp.c_str(), std::min(size_t(31), std::strlen(temp.c_str()))); + return buffer; + } + else { return "unspecified"; } } - else { return "unspecified"; } -} -std::shared_ptr -context_impl::init_libfabric_controller(oomph::context_impl* /*ctx*/, MPI_Comm comm, int rank, - int size, int threads) -{ - // only allow one thread to pass, make other wait - static std::mutex m_init_mutex; - std::lock_guard lock(m_init_mutex); - static std::shared_ptr instance(nullptr); - if (!instance.get()) + std::shared_ptr context_impl::init_libfabric_controller( + oomph::context_impl* /*ctx*/, MPI_Comm comm, int rank, int size, int threads) { - LF_DEB(src_deb, debug(NS_DEBUG::str<>("New Controller"), "rank", debug::dec<3>(rank), - "size", debug::dec<3>(size), "threads", debug::dec<3>(threads))); - instance.reset(new controller_type()); - instance->initialize(HAVE_LIBFABRIC_PROVIDER, rank == 0, size, threads, comm); + // only allow one thread to pass, make other wait + static std::mutex m_init_mutex; + std::lock_guard lock(m_init_mutex); + static std::shared_ptr instance(nullptr); + if (!instance.get()) + { + LF_DEB(src_deb, + debug(NS_DEBUG::str<>("New Controller"), "rank", debug::dec<3>(rank), "size", + debug::dec<3>(size), "threads", debug::dec<3>(threads))); + instance.reset(new controller_type()); + instance->initialize(HAVE_LIBFABRIC_PROVIDER, rank == 0, size, threads, comm); + } + return instance; } - return instance; -} -} // namespace oomph +} // namespace oomph diff --git a/src/libfabric/context.hpp b/src/libfabric/context.hpp index 256cb9fb..7a936223 100644 --- a/src/libfabric/context.hpp +++ b/src/libfabric/context.hpp @@ -9,8 +9,8 @@ */ #pragma once -#include #include +#include #include #include @@ -19,138 +19,142 @@ // paths relative to backend #include <../context_base.hpp> -#include #include +#include #include -namespace oomph -{ - -static NS_DEBUG::enable_print ctx_deb("CONTEXT"); - -using controller_type = libfabric::controller; - -class context_impl : public context_base -{ - public: - using region_type = 
libfabric::memory_segment; - using domain_type = region_type::provider_domain; - using device_region_type = libfabric::memory_segment; - using heap_type = hwmalloc::heap; - using callback_queue = boost::lockfree::queue, boost::lockfree::allocator>>; - - private: - heap_type m_heap; - domain_type* m_domain; - std::shared_ptr m_controller; - std::uintptr_t m_ctxt_tag; - - public: - // -------------------------------------------------- - // create a singleton ptr to a libfabric controller that - // can be shared between oomph context objects - static std::shared_ptr init_libfabric_controller(oomph::context_impl* ctx, - MPI_Comm comm, int rank, int size, int threads); - - // queue for shared recv callbacks - callback_queue m_recv_cb_queue; - // queue for canceled shared recv requests - callback_queue m_recv_cb_cancel; - - public: - context_impl(MPI_Comm comm, bool thread_safe, bool message_pool_never_free, - std::size_t message_pool_reserve); - context_impl(context_impl const&) = delete; - context_impl(context_impl&&) = delete; - - region_type make_region(void* const ptr, std::size_t size, int device_id) +namespace oomph { + + static NS_DEBUG::enable_print ctx_deb("CONTEXT"); + + using controller_type = libfabric::controller; + + class context_impl : public context_base { - if (m_controller->get_mrbind()) + public: + using region_type = libfabric::memory_segment; + using domain_type = region_type::provider_domain; + using device_region_type = libfabric::memory_segment; + using heap_type = hwmalloc::heap; + using callback_queue = boost::lockfree::queue, boost::lockfree::allocator>>; + + private: + heap_type m_heap; + domain_type* m_domain; + std::shared_ptr m_controller; + std::uintptr_t m_ctxt_tag; + + public: + // -------------------------------------------------- + // create a singleton ptr to a libfabric controller that + // can be shared between oomph context objects + static std::shared_ptr init_libfabric_controller( + oomph::context_impl* ctx, MPI_Comm comm, int rank, int size, int threads); + + // queue for shared recv callbacks + callback_queue m_recv_cb_queue; + // queue for canceled shared recv requests + callback_queue m_recv_cb_cancel; + + public: + context_impl(MPI_Comm comm, bool thread_safe, bool message_pool_never_free, + std::size_t message_pool_reserve); + context_impl(context_impl const&) = delete; + context_impl(context_impl&&) = delete; + + region_type make_region(void* const ptr, std::size_t size, int device_id) { - void* endpoint = m_controller->get_rx_endpoint().get_ep(); - return libfabric::memory_segment(m_domain, ptr, size, true, endpoint, device_id); + if (m_controller->get_mrbind()) + { + void* endpoint = m_controller->get_rx_endpoint().get_ep(); + return libfabric::memory_segment(m_domain, ptr, size, true, endpoint, device_id); + } + else + { + return libfabric::memory_segment(m_domain, ptr, size, false, nullptr, device_id); + } } - else { return libfabric::memory_segment(m_domain, ptr, size, false, nullptr, device_id); } - } - auto& get_heap() noexcept { return m_heap; } + auto& get_heap() noexcept { return m_heap; } - communicator_impl* get_communicator(); + communicator_impl* get_communicator(); - // we must modify all tags to use 32bits of context ptr for uniqueness - inline std::uintptr_t get_context_tag() { return m_ctxt_tag; } + // we must modify all tags to use 32bits of context ptr for uniqueness + inline std::uintptr_t get_context_tag() { return m_ctxt_tag; } - inline controller_type* get_controller() /*const */ { return m_controller.get(); } - const 
char* get_transport_option(const std::string& opt); + inline controller_type* get_controller() /*const */ { return m_controller.get(); } + char const* get_transport_option(std::string const& opt); - void progress() { get_controller()->poll_for_work_completions(nullptr); } + void progress() { get_controller()->poll_for_work_completions(nullptr); } - bool cancel_recv(detail::shared_request_state* s) - { - // get the original message operation context - auto op_ctx = &(s->m_operation_context); + bool cancel_recv(detail::shared_request_state* s) + { + // get the original message operation context + auto op_ctx = &(s->m_operation_context); - // submit the cancellation request - bool ok = (fi_cancel(&(get_controller()->get_rx_endpoint().get_ep()->fid), op_ctx) == 0); + // submit the cancellation request + bool ok = + (fi_cancel(&(get_controller()->get_rx_endpoint().get_ep()->fid), op_ctx) == 0); - // if the cancel operation failed completely, return - if (!ok) return false; + // if the cancel operation failed completely, return + if (!ok) return false; - bool found = false; - while (!found) - { - get_controller()->poll_recv_queue(get_controller()->get_rx_endpoint().get_rx_cq(), - nullptr); - // otherwise, poll until we know if it worked - std::stack temp_stack; - detail::shared_request_state* temp; - while (!found && m_recv_cb_cancel.pop(temp)) + bool found = false; + while (!found) { - if (temp == s) + get_controller()->poll_recv_queue( + get_controller()->get_rx_endpoint().get_rx_cq(), nullptr); + // otherwise, poll until we know if it worked + std::stack temp_stack; + detail::shared_request_state* temp; + while (!found && m_recv_cb_cancel.pop(temp)) { - // our recv was cancelled correctly - found = true; - LF_DEB(oomph::ctx_deb, debug(NS_DEBUG::str<>("Cancel shared"), "succeeded", - "op_ctx", NS_DEBUG::ptr(op_ctx))); - auto ptr = s->release_self_ref(); - s->set_canceled(); + if (temp == s) + { + // our recv was cancelled correctly + found = true; + LF_DEB(oomph::ctx_deb, + debug(NS_DEBUG::str<>("Cancel shared"), "succeeded", "op_ctx", + NS_DEBUG::ptr(op_ctx))); + auto ptr = s->release_self_ref(); + s->set_canceled(); + } + else + { + // a different cancel operation + temp_stack.push(temp); + } } - else + // return any weird unhandled cancels back to the queue + while (!temp_stack.empty()) { - // a different cancel operation - temp_stack.push(temp); + auto temp = temp_stack.top(); + temp_stack.pop(); + m_recv_cb_cancel.push(temp); } } - // return any weird unhandled cancels back to the queue - while (!temp_stack.empty()) - { - auto temp = temp_stack.top(); - temp_stack.pop(); - m_recv_cb_cancel.push(temp); - } + return found; } - return found; - } - unsigned int num_tag_bits() const noexcept { return 32; } -}; + unsigned int num_tag_bits() const noexcept { return 32; } + }; -// -------------------------------------------------------------------- -template<> -inline oomph::libfabric::memory_segment -register_memory(oomph::context_impl& c, void* const ptr, std::size_t size) -{ - return c.make_region(ptr, size, -2); -} + // -------------------------------------------------------------------- + template <> + inline oomph::libfabric::memory_segment + register_memory(oomph::context_impl& c, void* const ptr, std::size_t size) + { + return c.make_region(ptr, size, -2); + } #if OOMPH_ENABLE_DEVICE -template<> -inline oomph::libfabric::memory_segment -register_device_memory(context_impl& c, int device_id, void* ptr, std::size_t size) -{ - return c.make_region(ptr, size, device_id); -} + template <> + 
inline oomph::libfabric::memory_segment register_device_memory( + context_impl& c, int device_id, void* ptr, std::size_t size) + { + return c.make_region(ptr, size, device_id); + } #endif -} // namespace oomph +} // namespace oomph diff --git a/src/libfabric/controller.hpp b/src/libfabric/controller.hpp index 5becc148..95e3ad17 100644 --- a/src/libfabric/controller.hpp +++ b/src/libfabric/controller.hpp @@ -35,428 +35,436 @@ #include #include // -#include "oomph_libfabric_defines.hpp" +#include "controller_base.hpp" #include "fabric_error.hpp" #include "locality.hpp" #include "memory_region.hpp" +#include "oomph_libfabric_defines.hpp" #include "operation_context.hpp" -#include "controller_base.hpp" // #include // #include -namespace NS_DEBUG -{ -// cppcheck-suppress ConfigurationNotChecked +namespace NS_DEBUG { + // cppcheck-suppress ConfigurationNotChecked -using namespace oomph::debug; -template -inline /*constexpr*/ NS_DEBUG::print_threshold cnt_deb("CONTROL"); -// -static NS_DEBUG::enable_print cnt_err("CONTROL"); -} // namespace NS_DEBUG - -namespace oomph::libfabric -{ - -class controller : public controller_base -{ - public: - // -------------------------------------------------------------------- - controller() - : controller_base() - { - } + using namespace oomph::debug; + template + inline /*constexpr*/ NS_DEBUG::print_threshold cnt_deb("CONTROL"); + // + static NS_DEBUG::enable_print cnt_err("CONTROL"); +} // namespace NS_DEBUG - // -------------------------------------------------------------------- - void initialize_derived(std::string const&, bool, int, size_t, MPI_Comm mpi_comm) - { - // Broadcast address of all endpoints to all ranks - // and fill address vector with info - exchange_addresses(av_, mpi_comm); - } +namespace oomph::libfabric { - // -------------------------------------------------------------------- - constexpr fi_threading threadlevel_flags() + class controller : public controller_base { + public: + // -------------------------------------------------------------------- + controller() + : controller_base() + { + } + + // -------------------------------------------------------------------- + void initialize_derived(std::string const&, bool, int, size_t, MPI_Comm mpi_comm) + { + // Broadcast address of all endpoints to all ranks + // and fill address vector with info + exchange_addresses(av_, mpi_comm); + } + + // -------------------------------------------------------------------- + constexpr fi_threading threadlevel_flags() + { #if defined(HAVE_LIBFABRIC_GNI) /*|| defined(HAVE_LIBFABRIC_CXI)*/ - return FI_THREAD_ENDPOINT; + return FI_THREAD_ENDPOINT; #else - return FI_THREAD_SAFE; + return FI_THREAD_SAFE; #endif - } + } - // -------------------------------------------------------------------- - constexpr uint64_t caps_flags() - { + // -------------------------------------------------------------------- + constexpr uint64_t caps_flags() + { #if OOMPH_ENABLE_DEVICE && !defined(HAVE_LIBFABRIC_TCP) - std::int64_t hmem_flags = FI_HMEM; + std::int64_t hmem_flags = FI_HMEM; #else - std::int64_t hmem_flags = 0; + std::int64_t hmem_flags = 0; #endif - return hmem_flags | FI_MSG | FI_TAGGED | FI_RMA | FI_READ | FI_WRITE | FI_RECV | FI_SEND | - FI_TRANSMIT | FI_REMOTE_READ | FI_REMOTE_WRITE; - } - - // -------------------------------------------------------------------- - // we do not need to perform any special actions on init (to contact root node) - void setup_root_node_address(struct fi_info* /*info*/) {} + return hmem_flags | FI_MSG | FI_TAGGED | FI_RMA | FI_READ 
| FI_WRITE | FI_RECV | + FI_SEND | FI_TRANSMIT | FI_REMOTE_READ | FI_REMOTE_WRITE; + } - // -------------------------------------------------------------------- - // send address to rank 0 and receive array of all localities - void MPI_exchange_localities(fid_av* av, MPI_Comm comm, int rank, int size) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnt_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - std::vector localities(size * locality_defs::array_size, 0); - // - if (rank > 0) - { - LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("sending here"), iplocality(here_), - "size", locality_defs::array_size)); - /*int err = */ MPI_Send(here_.fabric_data(), locality_defs::array_size, MPI_CHAR, - 0, // dst rank - 0, // tag - comm); + // -------------------------------------------------------------------- + // we do not need to perform any special actions on init (to contact root node) + void setup_root_node_address(struct fi_info* /*info*/) {} - LF_DEB(NS_DEBUG::cnt_deb<9>, - debug(debug::str<>("receiving all"), "size", locality_defs::array_size)); - - MPI_Status status; - /*err = */ MPI_Recv(localities.data(), size * locality_defs::array_size, MPI_CHAR, - 0, // src rank - 0, // tag - comm, &status); - LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("received addresses"))); - } - else + // -------------------------------------------------------------------- + // send address to rank 0 and receive array of all localities + void MPI_exchange_localities(fid_av* av, MPI_Comm comm, int rank, int size) { - LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("receiving addresses"))); - memcpy(&localities[0], here_.fabric_data(), locality_defs::array_size); - for (int i = 1; i < size; ++i) + [[maybe_unused]] auto scp = NS_DEBUG::cnt_deb<9>.scope(NS_DEBUG::ptr(this), __func__); + std::vector localities(size * locality_defs::array_size, 0); + // + if (rank > 0) { LF_DEB(NS_DEBUG::cnt_deb<9>, - debug(debug::str<>("receiving address"), debug::dec<>(i))); + debug(debug::str<>("sending here"), iplocality(here_), "size", + locality_defs::array_size)); + /*int err = */ MPI_Send(here_.fabric_data(), locality_defs::array_size, MPI_CHAR, + 0, // dst rank + 0, // tag + comm); + + LF_DEB(NS_DEBUG::cnt_deb<9>, + debug(debug::str<>("receiving all"), "size", locality_defs::array_size)); + MPI_Status status; - /*int err = */ MPI_Recv(&localities[i * locality_defs::array_size], - size * locality_defs::array_size, MPI_CHAR, - i, // src rank - 0, // tag + /*err = */ MPI_Recv(localities.data(), size * locality_defs::array_size, MPI_CHAR, + 0, // src rank + 0, // tag comm, &status); - LF_DEB(NS_DEBUG::cnt_deb<9>, - debug(debug::str<>("received address"), debug::dec<>(i))); + LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("received addresses"))); + } + else + { + LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("receiving addresses"))); + memcpy(&localities[0], here_.fabric_data(), locality_defs::array_size); + for (int i = 1; i < size; ++i) + { + LF_DEB(NS_DEBUG::cnt_deb<9>, + debug(debug::str<>("receiving address"), debug::dec<>(i))); + MPI_Status status; + /*int err = */ MPI_Recv(&localities[i * locality_defs::array_size], + size * locality_defs::array_size, MPI_CHAR, + i, // src rank + 0, // tag + comm, &status); + LF_DEB(NS_DEBUG::cnt_deb<9>, + debug(debug::str<>("received address"), debug::dec<>(i))); + } + + LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("sending all"))); + for (int i = 1; i < size; ++i) + { + LF_DEB( + NS_DEBUG::cnt_deb<9>, debug(debug::str<>("sending to"), debug::dec<>(i))); + /*int err = */ 
MPI_Send(&localities[0], size * locality_defs::array_size, + MPI_CHAR, + i, // dst rank + 0, // tag + comm); + } } - LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("sending all"))); - for (int i = 1; i < size; ++i) + // all ranks should now have a full localities vector + LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("populating vector"))); + for (int i = 0; i < size; ++i) { - LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("sending to"), debug::dec<>(i))); - /*int err = */ MPI_Send(&localities[0], size * locality_defs::array_size, MPI_CHAR, - i, // dst rank - 0, // tag - comm); + locality temp; + int offset = i * locality_defs::array_size; + memcpy(temp.fabric_data_writable(), &localities[offset], locality_defs::array_size); + insert_address(av, temp); } } - // all ranks should now have a full localities vector - LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("populating vector"))); - for (int i = 0; i < size; ++i) + // -------------------------------------------------------------------- + // if we did not bootstrap, then fetch the list of all localities + // and insert each one into the address vector + void exchange_addresses(fid_av* av, MPI_Comm mpi_comm) { - locality temp; - int offset = i * locality_defs::array_size; - memcpy(temp.fabric_data_writable(), &localities[offset], locality_defs::array_size); - insert_address(av, temp); - } - } + [[maybe_unused]] auto scp = NS_DEBUG::cnt_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - // -------------------------------------------------------------------- - // if we did not bootstrap, then fetch the list of all localities - // and insert each one into the address vector - void exchange_addresses(fid_av* av, MPI_Comm mpi_comm) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnt_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - - int rank, size; - MPI_Comm_rank(mpi_comm, &rank); - MPI_Comm_size(mpi_comm, &size); + int rank, size; + MPI_Comm_rank(mpi_comm, &rank); + MPI_Comm_size(mpi_comm, &size); - LF_DEB(NS_DEBUG::cnt_deb<9>, - debug(debug::str<>("initialize_localities"), size, "localities")); + LF_DEB(NS_DEBUG::cnt_deb<9>, + debug(debug::str<>("initialize_localities"), size, "localities")); - MPI_exchange_localities(av, mpi_comm, rank, size); - debug_print_av_vector(size); - LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("Done localities"))); - } + MPI_exchange_localities(av, mpi_comm, rank, size); + debug_print_av_vector(size); + LF_DEB(NS_DEBUG::cnt_deb<9>, debug(debug::str<>("Done localities"))); + } - // -------------------------------------------------------------------- - inline constexpr bool bypass_tx_lock() - { + // -------------------------------------------------------------------- + inline constexpr bool bypass_tx_lock() + { #if defined(HAVE_LIBFABRIC_GNI) - return true; + return true; #elif defined(HAVE_LIBFABRIC_CXI) - // @todo : cxi provider is not yet thread safe using scalable endpoints - return false; + // @todo : cxi provider is not yet thread safe using scalable endpoints + return false; #else - return (threadlevel_flags() == FI_THREAD_SAFE || + return (threadlevel_flags() == FI_THREAD_SAFE || endpoint_type_ == endpoint_type::threadlocalTx); #endif - } + } - // -------------------------------------------------------------------- - inline controller_base::unique_lock get_tx_lock() - { - if (bypass_tx_lock()) return unique_lock(); - return unique_lock(send_mutex_); - } + // -------------------------------------------------------------------- + inline controller_base::unique_lock get_tx_lock() + { + if (bypass_tx_lock()) 
return unique_lock(); + return unique_lock(send_mutex_); + } - // -------------------------------------------------------------------- - inline controller_base::unique_lock try_tx_lock() - { - if (bypass_tx_lock()) return unique_lock(); - return unique_lock(send_mutex_, std::try_to_lock_t{}); - } + // -------------------------------------------------------------------- + inline controller_base::unique_lock try_tx_lock() + { + if (bypass_tx_lock()) return unique_lock(); + return unique_lock(send_mutex_, std::try_to_lock_t{}); + } - // -------------------------------------------------------------------- - inline constexpr bool bypass_rx_lock() - { + // -------------------------------------------------------------------- + inline constexpr bool bypass_rx_lock() + { #ifdef HAVE_LIBFABRIC_GNI - return true; + return true; #else - return ( - threadlevel_flags() == FI_THREAD_SAFE || endpoint_type_ == endpoint_type::scalableTxRx); + return (threadlevel_flags() == FI_THREAD_SAFE || + endpoint_type_ == endpoint_type::scalableTxRx); #endif - } + } - // -------------------------------------------------------------------- - inline controller_base::unique_lock get_rx_lock() - { - if (bypass_rx_lock()) return unique_lock(); - return unique_lock(recv_mutex_); - } + // -------------------------------------------------------------------- + inline controller_base::unique_lock get_rx_lock() + { + if (bypass_rx_lock()) return unique_lock(); + return unique_lock(recv_mutex_); + } - // -------------------------------------------------------------------- - inline controller_base::unique_lock try_rx_lock() - { - if (bypass_rx_lock()) return unique_lock(); - return unique_lock(recv_mutex_, std::try_to_lock_t{}); - } + // -------------------------------------------------------------------- + inline controller_base::unique_lock try_rx_lock() + { + if (bypass_rx_lock()) return unique_lock(); + return unique_lock(recv_mutex_, std::try_to_lock_t{}); + } - // -------------------------------------------------------------------- - int poll_send_queue(fid_cq* send_cq, void* user_data) - { + // -------------------------------------------------------------------- + int poll_send_queue(fid_cq* send_cq, void* user_data) + { #ifdef EXCESSIVE_POLLING_BACKOFF_MICRO_S - std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now(); - if (std::chrono::duration_cast(now - send_poll_stamp).count() < - EXCESSIVE_POLLING_BACKOFF_MICRO_S) - return 0; - send_poll_stamp = now; + std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - send_poll_stamp) + .count() < EXCESSIVE_POLLING_BACKOFF_MICRO_S) + return 0; + send_poll_stamp = now; #endif - int ret; - fi_cq_msg_entry entry[max_completions_array_limit_]; - assert(max_completions_per_poll_ <= max_completions_array_limit_); - { - auto lock = try_tx_lock(); + int ret; + fi_cq_msg_entry entry[max_completions_array_limit_]; + assert(max_completions_per_poll_ <= max_completions_array_limit_); + { + auto lock = try_tx_lock(); - // if we're not threadlocal and didn't get the lock, - // then another thread is polling now, just exit - if (!bypass_tx_lock() && !lock.owns_lock()) { return -1; } + // if we're not threadlocal and didn't get the lock, + // then another thread is polling now, just exit + if (!bypass_tx_lock() && !lock.owns_lock()) { return -1; } - static auto polling = - NS_DEBUG::cnt_deb<9>.make_timer(1, debug::str<>("poll send queue")); - LF_DEB(NS_DEBUG::cnt_deb<9>, timed(polling, NS_DEBUG::ptr(send_cq))); + 
static auto polling = + NS_DEBUG::cnt_deb<9>.make_timer(1, debug::str<>("poll send queue")); + LF_DEB(NS_DEBUG::cnt_deb<9>, timed(polling, NS_DEBUG::ptr(send_cq))); - // poll for completions - { - ret = fi_cq_read(send_cq, &entry[0], max_completions_per_poll_); - } - // if there is an error, retrieve it - if (ret == -FI_EAVAIL) - { - struct fi_cq_err_entry e = {}; - int err_sz = fi_cq_readerr(send_cq, &e, 0); - (void)err_sz; - - // flags might not be set correctly - if ((e.flags & (FI_MSG | FI_SEND | FI_TAGGED)) != 0) + // poll for completions { - NS_DEBUG::cnt_err.error("txcq Error FI_EAVAIL for " - "FI_SEND with len", - debug::hex<6>(e.len), "context", NS_DEBUG::ptr(e.op_context), "code", - NS_DEBUG::dec<3>(e.err), "flags", debug::bin<16>(e.flags), "error", - fi_cq_strerror(send_cq, e.prov_errno, e.err_data, (char*)e.buf, e.len)); + ret = fi_cq_read(send_cq, &entry[0], max_completions_per_poll_); } - else if ((e.flags & FI_RMA) != 0) + // if there is an error, retrieve it + if (ret == -FI_EAVAIL) { - NS_DEBUG::cnt_err.error("txcq Error FI_EAVAIL for " - "FI_RMA with len", - debug::hex<6>(e.len), "context", NS_DEBUG::ptr(e.op_context), "code", - NS_DEBUG::dec<3>(e.err), "flags", debug::bin<16>(e.flags), "error", - fi_cq_strerror(send_cq, e.prov_errno, e.err_data, (char*)e.buf, e.len)); + struct fi_cq_err_entry e = {}; + int err_sz = fi_cq_readerr(send_cq, &e, 0); + (void) err_sz; + + // flags might not be set correctly + if ((e.flags & (FI_MSG | FI_SEND | FI_TAGGED)) != 0) + { + NS_DEBUG::cnt_err.error("txcq Error FI_EAVAIL for " + "FI_SEND with len", + debug::hex<6>(e.len), "context", NS_DEBUG::ptr(e.op_context), "code", + NS_DEBUG::dec<3>(e.err), "flags", debug::bin<16>(e.flags), "error", + fi_cq_strerror( + send_cq, e.prov_errno, e.err_data, (char*) e.buf, e.len)); + } + else if ((e.flags & FI_RMA) != 0) + { + NS_DEBUG::cnt_err.error("txcq Error FI_EAVAIL for " + "FI_RMA with len", + debug::hex<6>(e.len), "context", NS_DEBUG::ptr(e.op_context), "code", + NS_DEBUG::dec<3>(e.err), "flags", debug::bin<16>(e.flags), "error", + fi_cq_strerror( + send_cq, e.prov_errno, e.err_data, (char*) e.buf, e.len)); + } + operation_context* handler = reinterpret_cast(e.op_context); + handler->handle_error(e); + return 0; } - operation_context* handler = reinterpret_cast(e.op_context); - handler->handle_error(e); - return 0; } - } - // - // exit possibly locked region and process each completion - // - if (ret > 0) - { - int processed = 0; - for (int i = 0; i < ret; ++i) + // + // exit possibly locked region and process each completion + // + if (ret > 0) { - ++sends_complete; - LF_DEB(NS_DEBUG::cnt_deb<9>, - debug(debug::str<>("Completion"), i, debug::dec<2>(i), "txcq flags", - fi_tostr(&entry[i].flags, FI_TYPE_CQ_EVENT_FLAGS), "(", - debug::dec<>(entry[i].flags), ")", "context", - NS_DEBUG::ptr(entry[i].op_context), "length", debug::hex<6>(entry[i].len))); - if ((entry[i].flags & (FI_TAGGED | FI_SEND | FI_MSG)) != 0) + int processed = 0; + for (int i = 0; i < ret; ++i) { + ++sends_complete; LF_DEB(NS_DEBUG::cnt_deb<9>, - debug(debug::str<>("Completion"), "txcq tagged send completion", - NS_DEBUG::ptr(entry[i].op_context))); - - operation_context* handler = - reinterpret_cast(entry[i].op_context); - processed += handler->handle_tagged_send_completion(user_data); - } - else - { - NS_DEBUG::cnt_err.error("Received an unknown txcq completion", - debug::dec<>(entry[i].flags), debug::bin<64>(entry[i].flags)); - std::terminate(); + debug(debug::str<>("Completion"), i, debug::dec<2>(i), "txcq flags", + 
fi_tostr(&entry[i].flags, FI_TYPE_CQ_EVENT_FLAGS), "(", + debug::dec<>(entry[i].flags), ")", "context", + NS_DEBUG::ptr(entry[i].op_context), "length", + debug::hex<6>(entry[i].len))); + if ((entry[i].flags & (FI_TAGGED | FI_SEND | FI_MSG)) != 0) + { + LF_DEB(NS_DEBUG::cnt_deb<9>, + debug(debug::str<>("Completion"), "txcq tagged send completion", + NS_DEBUG::ptr(entry[i].op_context))); + + operation_context* handler = + reinterpret_cast(entry[i].op_context); + processed += handler->handle_tagged_send_completion(user_data); + } + else + { + NS_DEBUG::cnt_err.error("Received an unknown txcq completion", + debug::dec<>(entry[i].flags), debug::bin<64>(entry[i].flags)); + std::terminate(); + } } + return processed; } - return processed; - } - else if (ret == 0 || ret == -FI_EAGAIN) - { - // do nothing, we will try again on the next check + else if (ret == 0 || ret == -FI_EAGAIN) + { + // do nothing, we will try again on the next check + } + else { NS_DEBUG::cnt_err.error("unknown error in completion txcq read"); } + return 0; } - else { NS_DEBUG::cnt_err.error("unknown error in completion txcq read"); } - return 0; - } - // -------------------------------------------------------------------- - int poll_recv_queue(fid_cq* rx_cq, void* user_data) - { + // -------------------------------------------------------------------- + int poll_recv_queue(fid_cq* rx_cq, void* user_data) + { #ifdef EXCESSIVE_POLLING_BACKOFF_MICRO_S - std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now(); - if (std::chrono::duration_cast(now - recv_poll_stamp).count() < - EXCESSIVE_POLLING_BACKOFF_MICRO_S) - return 0; - recv_poll_stamp = now; + std::chrono::steady_clock::time_point now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - recv_poll_stamp) + .count() < EXCESSIVE_POLLING_BACKOFF_MICRO_S) + return 0; + recv_poll_stamp = now; #endif - int ret; - fi_cq_msg_entry entry[max_completions_array_limit_]; - assert(max_completions_per_poll_ <= max_completions_array_limit_); - { - auto lock = get_rx_lock(); + int ret; + fi_cq_msg_entry entry[max_completions_array_limit_]; + assert(max_completions_per_poll_ <= max_completions_array_limit_); + { + auto lock = get_rx_lock(); - // if we're not threadlocal and didn't get the lock, - // then another thread is polling now, just exit - if (!bypass_rx_lock() && !lock.owns_lock()) { return -1; } + // if we're not threadlocal and didn't get the lock, + // then another thread is polling now, just exit + if (!bypass_rx_lock() && !lock.owns_lock()) { return -1; } - static auto polling = - NS_DEBUG::cnt_deb<2>.make_timer(1, debug::str<>("poll recv queue")); - LF_DEB(NS_DEBUG::cnt_deb<2>, timed(polling, NS_DEBUG::ptr(rx_cq))); + static auto polling = + NS_DEBUG::cnt_deb<2>.make_timer(1, debug::str<>("poll recv queue")); + LF_DEB(NS_DEBUG::cnt_deb<2>, timed(polling, NS_DEBUG::ptr(rx_cq))); - // poll for completions - { - ret = fi_cq_read(rx_cq, &entry[0], max_completions_per_poll_); - } - // if there is an error, retrieve it - if (ret == -FI_EAVAIL) - { - // read the full error status - struct fi_cq_err_entry e = {}; - int err_sz = fi_cq_readerr(rx_cq, &e, 0); - (void)err_sz; - // from the manpage 'man 3 fi_cq_readerr' - if (e.err == FI_ECANCELED) + // poll for completions { - LF_DEB(NS_DEBUG::cnt_deb<1>, - debug(debug::str<>("rxcq Cancelled"), "flags", debug::hex<6>(e.flags), - "len", debug::hex<6>(e.len), "context", NS_DEBUG::ptr(e.op_context))); - // the request was cancelled, we can simply exit - // as the canceller will have doone any 
cleanup needed - operation_context* handler = reinterpret_cast(e.op_context); - handler->handle_cancelled(); - return 0; + ret = fi_cq_read(rx_cq, &entry[0], max_completions_per_poll_); } - else if (e.err != FI_SUCCESS) + // if there is an error, retrieve it + if (ret == -FI_EAVAIL) { - NS_DEBUG::cnt_err.error(debug::str<>("poll_recv_queue"), "error code", - debug::dec<>(-e.err), "flags", debug::hex<6>(e.flags), "len", - debug::hex<6>(e.len), "context", NS_DEBUG::ptr(e.op_context), "error msg", - fi_cq_strerror(rx_cq, e.prov_errno, e.err_data, (char*)e.buf, e.len)); + // read the full error status + struct fi_cq_err_entry e = {}; + int err_sz = fi_cq_readerr(rx_cq, &e, 0); + (void) err_sz; + // from the manpage 'man 3 fi_cq_readerr' + if (e.err == FI_ECANCELED) + { + LF_DEB(NS_DEBUG::cnt_deb<1>, + debug(debug::str<>("rxcq Cancelled"), "flags", debug::hex<6>(e.flags), + "len", debug::hex<6>(e.len), "context", + NS_DEBUG::ptr(e.op_context))); + // the request was cancelled, we can simply exit + // as the canceller will have doone any cleanup needed + operation_context* handler = + reinterpret_cast(e.op_context); + handler->handle_cancelled(); + return 0; + } + else if (e.err != FI_SUCCESS) + { + NS_DEBUG::cnt_err.error(debug::str<>("poll_recv_queue"), "error code", + debug::dec<>(-e.err), "flags", debug::hex<6>(e.flags), "len", + debug::hex<6>(e.len), "context", NS_DEBUG::ptr(e.op_context), + "error msg", + fi_cq_strerror(rx_cq, e.prov_errno, e.err_data, (char*) e.buf, e.len)); + } + operation_context* handler = reinterpret_cast(e.op_context); + if (handler) handler->handle_error(e); + return 0; } - operation_context* handler = reinterpret_cast(e.op_context); - if (handler) handler->handle_error(e); - return 0; } - } - // - // release the lock and process each completion - // - if (ret > 0) - { - int processed = 0; - for (int i = 0; i < ret; ++i) + // + // release the lock and process each completion + // + if (ret > 0) { - ++recvs_complete; - LF_DEB(NS_DEBUG::cnt_deb<2>, - debug(debug::str<>("Completion"), i, "rxcq flags", - fi_tostr(&entry[i].flags, FI_TYPE_CQ_EVENT_FLAGS), "(", - debug::dec<>(entry[i].flags), ")", "context", - NS_DEBUG::ptr(entry[i].op_context), "length", debug::hex<6>(entry[i].len))); - if ((entry[i].flags & (FI_TAGGED | FI_RECV)) != 0) + int processed = 0; + for (int i = 0; i < ret; ++i) { + ++recvs_complete; LF_DEB(NS_DEBUG::cnt_deb<2>, - debug(debug::str<>("Completion"), "rxcq tagged recv completion", - NS_DEBUG::ptr(entry[i].op_context))); - - operation_context* handler = - reinterpret_cast(entry[i].op_context); - processed += handler->handle_tagged_recv_completion(user_data); - } - else - { - NS_DEBUG::cnt_err.error("Received an unknown rxcq completion", - debug::dec<>(entry[i].flags), debug::bin<64>(entry[i].flags)); - std::terminate(); + debug(debug::str<>("Completion"), i, "rxcq flags", + fi_tostr(&entry[i].flags, FI_TYPE_CQ_EVENT_FLAGS), "(", + debug::dec<>(entry[i].flags), ")", "context", + NS_DEBUG::ptr(entry[i].op_context), "length", + debug::hex<6>(entry[i].len))); + if ((entry[i].flags & (FI_TAGGED | FI_RECV)) != 0) + { + LF_DEB(NS_DEBUG::cnt_deb<2>, + debug(debug::str<>("Completion"), "rxcq tagged recv completion", + NS_DEBUG::ptr(entry[i].op_context))); + + operation_context* handler = + reinterpret_cast(entry[i].op_context); + processed += handler->handle_tagged_recv_completion(user_data); + } + else + { + NS_DEBUG::cnt_err.error("Received an unknown rxcq completion", + debug::dec<>(entry[i].flags), debug::bin<64>(entry[i].flags)); + std::terminate(); + 
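The dispatch logic above relies on a pointer round-trip: the address of a request's operation_context is handed to libfabric when the operation is posted, and the identical pointer comes back in the completion (or error) entry, where it is cast back to recover the handler. A reduced sketch of that single step follows; dispatch_rx_completion is an illustrative name, while operation_context and handle_tagged_recv_completion are the type and member used in the code above.

// 'entry' is one fi_cq_msg_entry obtained from fi_cq_read on the rx completion
// queue; 'user_data' is forwarded unchanged, exactly as in poll_recv_queue.
inline int dispatch_rx_completion(fi_cq_msg_entry const& entry, void* user_data)
{
    auto* handler = reinterpret_cast<operation_context*>(entry.op_context);
    if ((entry.flags & (FI_TAGGED | FI_RECV)) != 0)
        return handler->handle_tagged_recv_completion(user_data);
    return 0;    // the real loop treats any other completion as a fatal error
}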
} } + return processed; + } + else if (ret == 0 || ret == -FI_EAGAIN) + { + // do nothing, we will try again on the next check } - return processed; + else { NS_DEBUG::cnt_err.error("unknown error in completion rxcq read"); } + return 0; } - else if (ret == 0 || ret == -FI_EAGAIN) + + // Jobs started using mpi don't have this info + struct fi_info* set_src_dst_addresses(struct fi_info* info, bool tx) { - // do nothing, we will try again on the next check + (void) info; // unused variable warning + (void) tx; // unused variable warning + + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("fi_dupinfo"))); + struct fi_info* hints = fi_dupinfo(info); + if (!hints) throw NS_LIBFABRIC::fabric_error(0, "fi_dupinfo"); + // clear any Rx address data that might be set + // free(hints->src_addr); + // hints->src_addr = nullptr; + // hints->src_addrlen = 0; + free(hints->dest_addr); + hints->dest_addr = nullptr; + hints->dest_addrlen = 0; + return hints; } - else { NS_DEBUG::cnt_err.error("unknown error in completion rxcq read"); } - return 0; - } + }; - // Jobs started using mpi don't have this info - struct fi_info* set_src_dst_addresses(struct fi_info* info, bool tx) - { - (void)info; // unused variable warning - (void)tx; // unused variable warning - - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("fi_dupinfo"))); - struct fi_info* hints = fi_dupinfo(info); - if (!hints) throw NS_LIBFABRIC::fabric_error(0, "fi_dupinfo"); - // clear any Rx address data that might be set - // free(hints->src_addr); - // hints->src_addr = nullptr; - // hints->src_addrlen = 0; - free(hints->dest_addr); - hints->dest_addr = nullptr; - hints->dest_addrlen = 0; - return hints; - } -}; - -} // namespace oomph::libfabric +} // namespace oomph::libfabric diff --git a/src/libfabric/controller_base.hpp b/src/libfabric/controller_base.hpp index e1ce377e..a5eb1705 100644 --- a/src/libfabric/controller_base.hpp +++ b/src/libfabric/controller_base.hpp @@ -53,15 +53,13 @@ // ---------------------------------------- // auto progress (libfabric thread) or manual // ---------------------------------------- -static fi_progress -libfabric_progress_type() +static fi_progress libfabric_progress_type() { if (std::getenv("LIBFABRIC_AUTO_PROGRESS") == nullptr) return FI_PROGRESS_MANUAL; return FI_PROGRESS_AUTO; } -static const char* -libfabric_progress_string() +static char const* libfabric_progress_string() { if (libfabric_progress_type() == FI_PROGRESS_AUTO) return "auto"; return "manual"; @@ -93,8 +91,7 @@ enum class endpoint_type : int // ---------------------------------------- // single endpoint or separate for send/recv // ---------------------------------------- -static endpoint_type -libfabric_endpoint_type() +static endpoint_type libfabric_endpoint_type() { auto env_str = std::getenv("LIBFABRIC_ENDPOINT_TYPE"); if (env_str == nullptr) return endpoint_type::single; @@ -114,8 +111,7 @@ libfabric_endpoint_type() return endpoint_type::single; } -static const char* -libfabric_endpoint_string() +static char const* libfabric_endpoint_string() { auto lf_ep_type = libfabric_endpoint_type(); if (lf_ep_type == endpoint_type::multiple) return "multiple"; @@ -128,8 +124,7 @@ libfabric_endpoint_string() // ---------------------------------------- // number of completions to handle per poll // ---------------------------------------- -static int -libfabric_completions_per_poll() +static int libfabric_completions_per_poll() { auto env_str = std::getenv("LIBFABRIC_POLL_SIZE"); if (env_str != nullptr) @@ -148,8 +143,7 @@ libfabric_completions_per_poll() 
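The helpers in this header translate a handful of environment variables (LIBFABRIC_AUTO_PROGRESS, LIBFABRIC_ENDPOINT_TYPE, LIBFABRIC_POLL_SIZE and, just below, LIBFABRIC_RENDEZVOUS_THRESHOLD) into backend settings. A small sketch of how an application might inspect the resulting configuration through context_impl::get_transport_option, shown earlier in this diff; 'ctx' is a hypothetical, already-constructed context_impl and the printing function is illustrative only.

#include <iostream>

inline void print_transport_settings(oomph::context_impl& ctx)
{
    std::cout << "transport:            " << ctx.get_transport_option("name") << '\n'      // "libfabric"
              << "progress:             " << ctx.get_transport_option("progress") << '\n'  // "auto" or "manual"
              << "endpoint:             " << ctx.get_transport_option("endpoint") << '\n'
              << "rendezvous threshold: " << ctx.get_transport_option("rendezvous_threshold") << '\n';
}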
// ---------------------------------------- // Eager/Rendezvous threshold // ---------------------------------------- -static int -libfabric_rendezvous_threshold(int def_val) +static int libfabric_rendezvous_threshold(int def_val) { auto env_str = std::getenv("LIBFABRIC_RENDEZVOUS_THRESHOLD"); if (env_str != nullptr) @@ -170,9 +164,9 @@ libfabric_rendezvous_threshold(int def_val) // Needed on Cray for GNI extensions // ------------------------------------------------ #ifdef HAVE_LIBFABRIC_GNI -#include "rdma/fi_ext_gni.h" +# include "rdma/fi_ext_gni.h" //#define OOMPH_GNI_REG "none" -#define OOMPH_GNI_REG "internal" +# define OOMPH_GNI_REG "internal" //#define OOMPH_GNI_REG "udreg" static std::vector> gni_strs = { @@ -213,19 +207,18 @@ static std::vector> gni_ints = { // api 2.0, then we ask for that, but the cxi legacy library on daint only supports 1.15, // so drop back to that version if needed #if defined(OOMPH_LIBFABRIC_V1_API) -#define LIBFABRIC_FI_VERSION_MAJOR 1 -#define LIBFABRIC_FI_VERSION_MINOR 15 +# define LIBFABRIC_FI_VERSION_MAJOR 1 +# define LIBFABRIC_FI_VERSION_MINOR 15 #else -#define LIBFABRIC_FI_VERSION_MAJOR 2 -#define LIBFABRIC_FI_VERSION_MINOR 0 +# define LIBFABRIC_FI_VERSION_MAJOR 2 +# define LIBFABRIC_FI_VERSION_MINOR 0 #endif -namespace NS_DEBUG -{ -// cppcheck-suppress ConfigurationNotChecked -static NS_DEBUG::enable_print cnb_deb("CONBASE"); -static NS_DEBUG::enable_print cnb_err("CONBASE"); -} // namespace NS_DEBUG +namespace NS_DEBUG { + // cppcheck-suppress ConfigurationNotChecked + static NS_DEBUG::enable_print cnb_deb("CONBASE"); + static NS_DEBUG::enable_print cnb_err("CONBASE"); +} // namespace NS_DEBUG /** @brief a class to return the number of progressed callbacks */ struct progress_status @@ -237,7 +230,7 @@ struct progress_status int num_sends() const noexcept { return m_num_sends; } int num_recvs() const noexcept { return m_num_recvs; } - progress_status& operator+=(const progress_status& other) noexcept + progress_status& operator+=(progress_status const& other) noexcept { m_num_sends += other.m_num_sends; m_num_recvs += other.m_num_recvs; @@ -245,814 +238,822 @@ struct progress_status } }; -namespace NS_LIBFABRIC -{ -/// A wrapper around fi_close that reports any error -/// Because we use so many handles, we must be careful to -/// delete them all before closing resources that use them -template -void -fidclose(Handle fid, const char* msg) -{ - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("closing"), msg)); - int ret = fi_close(fid); - if (ret == -FI_EBUSY) { throw NS_LIBFABRIC::fabric_error(ret, "fi_close EBUSY"); } - else if (ret == FI_SUCCESS) { return; } - throw NS_LIBFABRIC::fabric_error(ret, "fi_close error"); -} - -/// when using thread local endpoints, we encapsulate things that -/// are needed to manage an endpoint -struct endpoint_wrapper -{ - private: - friend class controller; - - fid_ep* ep_ = nullptr; - fid_cq* rq_ = nullptr; - fid_cq* tq_ = nullptr; - const char* name_ = nullptr; - - public: - endpoint_wrapper() {} - endpoint_wrapper(fid_ep* ep, fid_cq* rq, fid_cq* tq, const char* name) - : ep_(ep) - , rq_(rq) - , tq_(tq) - , name_(name) +namespace NS_LIBFABRIC { + /// A wrapper around fi_close that reports any error + /// Because we use so many handles, we must be careful to + /// delete them all before closing resources that use them + template + void fidclose(Handle fid, char const* msg) { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, name_); + LF_DEB(NS_DEBUG::cnb_deb, 
debug(debug::str<>("closing"), msg)); + int ret = fi_close(fid); + if (ret == -FI_EBUSY) { throw NS_LIBFABRIC::fabric_error(ret, "fi_close EBUSY"); } + else if (ret == FI_SUCCESS) { return; } + throw NS_LIBFABRIC::fabric_error(ret, "fi_close error"); } - // to keep boost::lockfree happy, we need these copy operators - endpoint_wrapper(const endpoint_wrapper& ep) = default; - endpoint_wrapper& operator=(const endpoint_wrapper& ep) = default; - - void cleanup() + /// when using thread local endpoints, we encapsulate things that + /// are needed to manage an endpoint + struct endpoint_wrapper { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, name_); - if (ep_) - { - fidclose(&ep_->fid, "endpoint"); - ep_ = nullptr; - } - if (rq_) + private: + friend class controller; + + fid_ep* ep_ = nullptr; + fid_cq* rq_ = nullptr; + fid_cq* tq_ = nullptr; + char const* name_ = nullptr; + + public: + endpoint_wrapper() {} + endpoint_wrapper(fid_ep* ep, fid_cq* rq, fid_cq* tq, char const* name) + : ep_(ep) + , rq_(rq) + , tq_(tq) + , name_(name) { - fidclose(&rq_->fid, "rq"); - rq_ = nullptr; - } - if (tq_) - { - fidclose(&tq_->fid, "tq"); - tq_ = nullptr; + [[maybe_unused]] auto scp = + NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, name_); } - } - - inline fid_ep* get_ep() { return ep_; } - inline fid_cq* get_rx_cq() { return rq_; } - inline fid_cq* get_tx_cq() { return tq_; } - inline void set_tx_cq(fid_cq* cq) { tq_ = cq; } - inline const char* get_name() { return name_; } -}; - -using region_type = NS_MEMORY::memory_handle; -using endpoint_context_pool = - boost::lockfree::queue>; - -struct stack_endpoint -{ - endpoint_wrapper endpoint_; - endpoint_context_pool* pool_; - // - stack_endpoint() - : endpoint_() - , pool_(nullptr) - { - } - // - stack_endpoint(fid_ep* ep, fid_cq* rq, fid_cq* tq, const char* name, - endpoint_context_pool* pool) - : endpoint_(ep, rq, tq, name) - , pool_(pool) - { - } - // - stack_endpoint& operator=(stack_endpoint&& other) - { - endpoint_ = std::move(other.endpoint_); - pool_ = std::exchange(other.pool_, nullptr); - return *this; - } - - ~stack_endpoint() - { - if (!pool_) return; - LF_DEB(NS_DEBUG::cnb_deb, - trace(debug::str<>("Scalable Ep"), "used push", "ep", NS_DEBUG::ptr(get_ep()), "tx cq", - NS_DEBUG::ptr(get_tx_cq()), "rx cq", NS_DEBUG::ptr(get_rx_cq()))); - pool_->push(endpoint_); - } - - inline fid_ep* get_ep() { return endpoint_.get_ep(); } - inline fid_cq* get_rx_cq() { return endpoint_.get_rx_cq(); } + // to keep boost::lockfree happy, we need these copy operators + endpoint_wrapper(endpoint_wrapper const& ep) = default; + endpoint_wrapper& operator=(endpoint_wrapper const& ep) = default; - inline fid_cq* get_tx_cq() { return endpoint_.get_tx_cq(); } -}; - -struct endpoints_lifetime_manager -{ - // threadlocal endpoints - static inline thread_local stack_endpoint tl_tx_; - static inline thread_local stack_endpoint tl_stx_; - static inline thread_local stack_endpoint tl_srx_; - // non threadlocal endpoints, tx/rx - endpoint_wrapper ep_tx_; - endpoint_wrapper ep_rx_; -}; - -template -class controller_base -{ - public: - typedef std::mutex mutex_type; - typedef std::lock_guard scoped_lock; - typedef std::unique_lock unique_lock; + void cleanup() + { + [[maybe_unused]] auto scp = + NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, name_); + if (ep_) + { + fidclose(&ep_->fid, "endpoint"); + ep_ = nullptr; + } + if (rq_) + { + fidclose(&rq_->fid, "rq"); + rq_ = nullptr; + } + if (tq_) + { + fidclose(&tq_->fid, "tq"); + 
tq_ = nullptr; + } + } - protected: - // For threadlocal/scalable endpoints, - // we use a dedicated threadlocal endpoint wrapper - std::unique_ptr eps_; + inline fid_ep* get_ep() { return ep_; } + inline fid_cq* get_rx_cq() { return rq_; } + inline fid_cq* get_tx_cq() { return tq_; } + inline void set_tx_cq(fid_cq* cq) { tq_ = cq; } + inline char const* get_name() { return name_; } + }; + using region_type = NS_MEMORY::memory_handle; using endpoint_context_pool = boost::lockfree::queue>; - endpoint_context_pool tx_endpoints_; - endpoint_context_pool rx_endpoints_; - - struct fi_info* fabric_info_; - struct fid_fabric* fabric_; - struct fid_domain* fabric_domain_; - struct fid_pep* ep_passive_; - - struct fid_av* av_; - endpoint_type endpoint_type_; - - locality here_; - locality root_; - // used during queue creation setup and during polling - mutex_type controller_mutex_; - - // used to protect send/recv resources - alignas(64) mutex_type send_mutex_; - alignas(64) mutex_type recv_mutex_; - - std::size_t tx_inject_size_; - std::size_t tx_attr_size_; - std::size_t rx_attr_size_; - - uint32_t max_completions_per_poll_; - uint32_t msg_rendezvous_threshold_; - inline static constexpr uint32_t max_completions_array_limit_ = 256; + struct stack_endpoint + { + endpoint_wrapper endpoint_; + endpoint_context_pool* pool_; + // + stack_endpoint() + : endpoint_() + , pool_(nullptr) + { + } + // + stack_endpoint( + fid_ep* ep, fid_cq* rq, fid_cq* tq, char const* name, endpoint_context_pool* pool) + : endpoint_(ep, rq, tq, name) + , pool_(pool) + { + } + // + stack_endpoint& operator=(stack_endpoint&& other) + { + endpoint_ = std::move(other.endpoint_); + pool_ = std::exchange(other.pool_, nullptr); + return *this; + } - static inline thread_local std::chrono::steady_clock::time_point send_poll_stamp; - static inline thread_local std::chrono::steady_clock::time_point recv_poll_stamp; + ~stack_endpoint() + { + if (!pool_) return; + LF_DEB(NS_DEBUG::cnb_deb, + trace(debug::str<>("Scalable Ep"), "used push", "ep", NS_DEBUG::ptr(get_ep()), + "tx cq", NS_DEBUG::ptr(get_tx_cq()), "rx cq", NS_DEBUG::ptr(get_rx_cq()))); + pool_->push(endpoint_); + } - // set if FI_MR_LOCAL is required (local access requires binding) - bool mrlocal = false; - // set if FI_MR_ENDPOINT is required (per endpoint memory binding) - bool mrbind = false; - // set if FI_MR_HRMEM provider requires heterogeneous memory registration - bool mrhmem = false; + inline fid_ep* get_ep() { return endpoint_.get_ep(); } - public: - bool get_mrbind() { return mrbind; } + inline fid_cq* get_rx_cq() { return endpoint_.get_rx_cq(); } - public: - NS_LIBFABRIC::simple_counter sends_posted_; - NS_LIBFABRIC::simple_counter recvs_posted_; - NS_LIBFABRIC::simple_counter sends_readied_; - NS_LIBFABRIC::simple_counter recvs_readied_; - NS_LIBFABRIC::simple_counter sends_complete; - NS_LIBFABRIC::simple_counter recvs_complete; + inline fid_cq* get_tx_cq() { return endpoint_.get_tx_cq(); } + }; - void finvoke(const char* msg, const char* err, int ret) + struct endpoints_lifetime_manager { - LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>(msg))); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, err); - } + // threadlocal endpoints + static inline thread_local stack_endpoint tl_tx_; + static inline thread_local stack_endpoint tl_stx_; + static inline thread_local stack_endpoint tl_srx_; + // non threadlocal endpoints, tx/rx + endpoint_wrapper ep_tx_; + endpoint_wrapper ep_rx_; + }; - public: - // -------------------------------------------------------------------- 
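stack_endpoint above implements a borrow/return idiom: an endpoint_wrapper is popped from a boost::lockfree queue when a thread needs to transmit and pushed back automatically when the stack_endpoint goes out of scope, which is also why endpoint_wrapper keeps its copy operations. The sketch below spells out the same idiom without the RAII wrapper; the function name is illustrative, while endpoint_context_pool and endpoint_wrapper are the types defined in this header.

#include <boost/lockfree/queue.hpp>

inline void with_pooled_endpoint(endpoint_context_pool& pool)
{
    endpoint_wrapper ep;          // default-constructed; copyable, as the lock-free queue requires
    if (!pool.pop(ep)) return;    // no endpoint available right now
    // ... post sends on ep.get_ep(), poll completions on ep.get_tx_cq() ...
    pool.push(ep);                // stack_endpoint performs this push in its destructor
}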
- controller_base() - : eps_(nullptr) - , tx_endpoints_(1) - , rx_endpoints_(1) - , fabric_info_(nullptr) - , fabric_(nullptr) - , fabric_domain_(nullptr) - , ep_passive_(nullptr) - , av_(nullptr) - , tx_inject_size_(0) - , tx_attr_size_(0) - , rx_attr_size_(0) - , max_completions_per_poll_(1) - , msg_rendezvous_threshold_(0x4000) - , sends_posted_(0) - , recvs_posted_(0) - , sends_readied_(0) - , recvs_readied_(0) - , sends_complete(0) - , recvs_complete(0) + template + class controller_base { - } + public: + typedef std::mutex mutex_type; + typedef std::lock_guard scoped_lock; + typedef std::unique_lock unique_lock; + + protected: + // For threadlocal/scalable endpoints, + // we use a dedicated threadlocal endpoint wrapper + std::unique_ptr eps_; + + using endpoint_context_pool = + boost::lockfree::queue>; + endpoint_context_pool tx_endpoints_; + endpoint_context_pool rx_endpoints_; + + struct fi_info* fabric_info_; + struct fid_fabric* fabric_; + struct fid_domain* fabric_domain_; + struct fid_pep* ep_passive_; + + struct fid_av* av_; + endpoint_type endpoint_type_; + + locality here_; + locality root_; + + // used during queue creation setup and during polling + mutex_type controller_mutex_; + + // used to protect send/recv resources + alignas(64) mutex_type send_mutex_; + alignas(64) mutex_type recv_mutex_; + + std::size_t tx_inject_size_; + std::size_t tx_attr_size_; + std::size_t rx_attr_size_; + + uint32_t max_completions_per_poll_; + uint32_t msg_rendezvous_threshold_; + inline static constexpr uint32_t max_completions_array_limit_ = 256; + + static inline thread_local std::chrono::steady_clock::time_point send_poll_stamp; + static inline thread_local std::chrono::steady_clock::time_point recv_poll_stamp; + + // set if FI_MR_LOCAL is required (local access requires binding) + bool mrlocal = false; + // set if FI_MR_ENDPOINT is required (per endpoint memory binding) + bool mrbind = false; + // set if FI_MR_HRMEM provider requires heterogeneous memory registration + bool mrhmem = false; + + public: + bool get_mrbind() { return mrbind; } + + public: + NS_LIBFABRIC::simple_counter sends_posted_; + NS_LIBFABRIC::simple_counter recvs_posted_; + NS_LIBFABRIC::simple_counter sends_readied_; + NS_LIBFABRIC::simple_counter recvs_readied_; + NS_LIBFABRIC::simple_counter sends_complete; + NS_LIBFABRIC::simple_counter recvs_complete; + + void finvoke(char const* msg, char const* err, int ret) + { + LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>(msg))); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, err); + } - // -------------------------------------------------------------------- - // clean up all resources - ~controller_base() - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - unsigned int messages_handled_ = 0; - unsigned int rma_reads_ = 0; - unsigned int recv_deletes_ = 0; + public: + // -------------------------------------------------------------------- + controller_base() + : eps_(nullptr) + , tx_endpoints_(1) + , rx_endpoints_(1) + , fabric_info_(nullptr) + , fabric_(nullptr) + , fabric_domain_(nullptr) + , ep_passive_(nullptr) + , av_(nullptr) + , tx_inject_size_(0) + , tx_attr_size_(0) + , rx_attr_size_(0) + , max_completions_per_poll_(1) + , msg_rendezvous_threshold_(0x4000) + , sends_posted_(0) + , recvs_posted_(0) + , sends_readied_(0) + , recvs_readied_(0) + , sends_complete(0) + , recvs_complete(0) + { + } - LF_DEB(NS_DEBUG::cnb_deb, - debug(debug::str<>("counters"), "Received messages", debug::dec<>(messages_handled_), - "Total reads", 
debug::dec<>(rma_reads_), "Total deletes", - debug::dec<>(recv_deletes_), "deletes error", - debug::dec<>(messages_handled_ - recv_deletes_))); + // -------------------------------------------------------------------- + // clean up all resources + ~controller_base() + { + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + unsigned int messages_handled_ = 0; + unsigned int rma_reads_ = 0; + unsigned int recv_deletes_ = 0; - tx_endpoints_.consume_all([](auto&& ep) { ep.cleanup(); }); - rx_endpoints_.consume_all([](auto&& ep) { ep.cleanup(); }); + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("counters"), "Received messages", + debug::dec<>(messages_handled_), "Total reads", debug::dec<>(rma_reads_), + "Total deletes", debug::dec<>(recv_deletes_), "deletes error", + debug::dec<>(messages_handled_ - recv_deletes_))); - // No cleanup threadlocals : done by consume_all cleanup above - // eps_->tl_tx_.endpoint_.cleanup(); - // eps_->tl_stx_.endpoint_.cleanup(); - // eps_->tl_srx_.endpoint_.cleanup(); + tx_endpoints_.consume_all([](auto&& ep) { ep.cleanup(); }); + rx_endpoints_.consume_all([](auto&& ep) { ep.cleanup(); }); - // non threadlocal endpoints, tx/rx - eps_->ep_tx_.cleanup(); - eps_->ep_rx_.cleanup(); + // No cleanup threadlocals : done by consume_all cleanup above + // eps_->tl_tx_.endpoint_.cleanup(); + // eps_->tl_stx_.endpoint_.cleanup(); + // eps_->tl_srx_.endpoint_.cleanup(); - // Cleanup endpoints - eps_.reset(nullptr); + // non threadlocal endpoints, tx/rx + eps_->ep_tx_.cleanup(); + eps_->ep_rx_.cleanup(); - // delete adddress vector - fidclose(&av_->fid, "Address Vector"); + // Cleanup endpoints + eps_.reset(nullptr); - try - { - fidclose(&fabric_domain_->fid, "Domain"); - } - catch (fabric_error& e) - { - std::cout << "fabric domain close failed : Ensure all RMA " - "objects are freed before program termination" - << std::endl; - } - fidclose(&fabric_->fid, "Fabric"); + // delete adddress vector + fidclose(&av_->fid, "Address Vector"); - // clean up - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("freeing fabric_info"))); + try + { + fidclose(&fabric_domain_->fid, "Domain"); + } + catch (fabric_error& e) + { + std::cout << "fabric domain close failed : Ensure all RMA " + "objects are freed before program termination" + << std::endl; + } + fidclose(&fabric_->fid, "Fabric"); - fi_freeinfo(fabric_info_); - } + // clean up + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("freeing fabric_info"))); - // -------------------------------------------------------------------- - // setup an endpoint for receiving messages, - // usually an rx endpoint is shared by all threads - endpoint_wrapper create_rx_endpoint(struct fid_domain* domain, struct fi_info* info, - struct fid_av* av) - { - auto ep_rx = new_endpoint_active(domain, info, false); + fi_freeinfo(fabric_info_); + } - // bind address vector - bind_address_vector_to_endpoint(ep_rx, av); + // -------------------------------------------------------------------- + // setup an endpoint for receiving messages, + // usually an rx endpoint is shared by all threads + endpoint_wrapper create_rx_endpoint( + struct fid_domain* domain, struct fi_info* info, struct fid_av* av) + { + auto ep_rx = new_endpoint_active(domain, info, false); - // create a completion queue for the rx endpoint - info->rx_attr->op_flags |= FI_COMPLETION; - auto rx_cq = create_completion_queue(domain, info->rx_attr->size, "rx"); + // bind address vector + bind_address_vector_to_endpoint(ep_rx, av); - // bind CQ to endpoint - 
bind_queue_to_endpoint(ep_rx, rx_cq, FI_RECV, "rx"); - return endpoint_wrapper(ep_rx, rx_cq, nullptr, "rx"); - } + // create a completion queue for the rx endpoint + info->rx_attr->op_flags |= FI_COMPLETION; + auto rx_cq = create_completion_queue(domain, info->rx_attr->size, "rx"); - // -------------------------------------------------------------------- - // initialize the basic fabric/domain/name - template - void initialize(std::string const& provider, bool rootnode, int size, size_t threads, - Args&&... args) - { - LF_DEB(NS_DEBUG::cnb_deb, eval([]() { std::cout.setf(std::ios::unitbuf); })); - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + // bind CQ to endpoint + bind_queue_to_endpoint(ep_rx, rx_cq, FI_RECV, "rx"); + return endpoint_wrapper(ep_rx, rx_cq, nullptr, "rx"); + } - max_completions_per_poll_ = libfabric_completions_per_poll(); - LF_DEB(NS_DEBUG::cnb_err, - debug(debug::str<>("Poll completions"), debug::dec<3>(max_completions_per_poll_))); + // -------------------------------------------------------------------- + // initialize the basic fabric/domain/name + template + void initialize( + std::string const& provider, bool rootnode, int size, size_t threads, Args&&... args) + { + LF_DEB(NS_DEBUG::cnb_deb, eval([]() { std::cout.setf(std::ios::unitbuf); })); + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - uint32_t default_val = (threads == 1) ? 0x400 : 0x4000; - msg_rendezvous_threshold_ = libfabric_rendezvous_threshold(default_val); - LF_DEB(NS_DEBUG::cnb_err, - debug(debug::str<>("Rendezvous threshold"), debug::hex<4>(msg_rendezvous_threshold_))); + max_completions_per_poll_ = libfabric_completions_per_poll(); + LF_DEB(NS_DEBUG::cnb_err, + debug(debug::str<>("Poll completions"), debug::dec<3>(max_completions_per_poll_))); - endpoint_type_ = static_cast(libfabric_endpoint_type()); - LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("Endpoints"), libfabric_endpoint_string())); + uint32_t default_val = (threads == 1) ? 
0x400 : 0x4000; + msg_rendezvous_threshold_ = libfabric_rendezvous_threshold(default_val); + LF_DEB(NS_DEBUG::cnb_err, + debug(debug::str<>("Rendezvous threshold"), + debug::hex<4>(msg_rendezvous_threshold_))); - eps_ = std::make_unique(); + endpoint_type_ = static_cast(libfabric_endpoint_type()); + LF_DEB( + NS_DEBUG::cnb_err, debug(debug::str<>("Endpoints"), libfabric_endpoint_string())); - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Threads"), debug::dec<3>(threads))); + eps_ = std::make_unique(); - open_fabric(provider, threads, rootnode); + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Threads"), debug::dec<3>(threads))); - // create an address vector that will be bound to (all) endpoints - av_ = create_address_vector(fabric_info_, size, threads); + open_fabric(provider, threads, rootnode); - // we need an rx endpoint in all cases except scalable rx - if (endpoint_type_ != endpoint_type::scalableTxRx) - { - // setup an endpoint for receiving messages - // rx endpoint is typically shared by all threads - eps_->ep_rx_ = create_rx_endpoint(fabric_domain_, fabric_info_, av_); - } + // create an address vector that will be bound to (all) endpoints + av_ = create_address_vector(fabric_info_, size, threads); - if (endpoint_type_ == endpoint_type::single) - { - // always bind a tx cq to the rx endpoint for single endpoint type - auto tx_cq = bind_tx_queue_to_rx_endpoint(fabric_info_, eps_->ep_rx_.get_ep()); - eps_->ep_rx_.set_tx_cq(tx_cq); - } - else if (endpoint_type_ != endpoint_type::scalableTxRx) - { + // we need an rx endpoint in all cases except scalable rx + if (endpoint_type_ != endpoint_type::scalableTxRx) + { + // setup an endpoint for receiving messages + // rx endpoint is typically shared by all threads + eps_->ep_rx_ = create_rx_endpoint(fabric_domain_, fabric_info_, av_); + } + + if (endpoint_type_ == endpoint_type::single) + { + // always bind a tx cq to the rx endpoint for single endpoint type + auto tx_cq = bind_tx_queue_to_rx_endpoint(fabric_info_, eps_->ep_rx_.get_ep()); + eps_->ep_rx_.set_tx_cq(tx_cq); + } + else if (endpoint_type_ != endpoint_type::scalableTxRx) + { #if defined(HAVE_LIBFABRIC_SOCKETS) || defined(HAVE_LIBFABRIC_TCP) || \ defined(HAVE_LIBFABRIC_VERBS) || defined(HAVE_LIBFABRIC_CXI) || defined(HAVE_LIBFABRIC_EFA) - // it appears that the rx endpoint cannot be enabled if it does not - // have a Tx CQ (at least when using sockets), so we create a dummy - // Tx CQ and bind it just to stop libfabric from triggering an error. - // The tx_cq won't actually be used because the user will get the real - // tx endpoint which will have the correct cq bound to it - auto dummy_cq = bind_tx_queue_to_rx_endpoint(fabric_info_, eps_->ep_rx_.get_ep()); - eps_->ep_rx_.set_tx_cq(dummy_cq); + // it appears that the rx endpoint cannot be enabled if it does not + // have a Tx CQ (at least when using sockets), so we create a dummy + // Tx CQ and bind it just to stop libfabric from triggering an error. 
+ // The tx_cq won't actually be used because the user will get the real + // tx endpoint which will have the correct cq bound to it + auto dummy_cq = bind_tx_queue_to_rx_endpoint(fabric_info_, eps_->ep_rx_.get_ep()); + eps_->ep_rx_.set_tx_cq(dummy_cq); #endif - } + } - if (endpoint_type_ == endpoint_type::multiple) - { - // create a separate Tx endpoint for sending messages - // note that the CQ needs FI_RECV even though its a Tx cq to keep - // some providers happy as they trigger an error if an endpoint - // has no Rx cq attached (appears to be a progress related bug) - auto ep_tx = new_endpoint_active(fabric_domain_, fabric_info_, true); - - // create a completion queue for tx endpoint - fabric_info_->tx_attr->op_flags |= (FI_INJECT_COMPLETE | FI_COMPLETION); - auto tx_cq = - create_completion_queue(fabric_domain_, fabric_info_->tx_attr->size, "tx multiple"); - - bind_queue_to_endpoint(ep_tx, tx_cq, FI_TRANSMIT | FI_RECV, "tx multiple"); - bind_address_vector_to_endpoint(ep_tx, av_); - enable_endpoint(ep_tx, "tx multiple"); - - // combine endpoints and CQ into wrapper for convenience - eps_->ep_tx_ = endpoint_wrapper(ep_tx, nullptr, tx_cq, "tx multiple"); - } - else if (endpoint_type_ == endpoint_type::threadlocalTx) - { - // each thread creates a Tx endpoint on first call to get_tx_endpoint() - } - else if (endpoint_type_ == endpoint_type::scalableTx || - endpoint_type_ == endpoint_type::scalableTxRx) - { - // setup tx contexts for each possible thread - size_t threads_allocated = 0; - auto ep_sx = new_endpoint_scalable(fabric_domain_, fabric_info_, true /*Tx*/, threads, - threads_allocated); + if (endpoint_type_ == endpoint_type::multiple) + { + // create a separate Tx endpoint for sending messages + // note that the CQ needs FI_RECV even though its a Tx cq to keep + // some providers happy as they trigger an error if an endpoint + // has no Rx cq attached (appears to be a progress related bug) + auto ep_tx = new_endpoint_active(fabric_domain_, fabric_info_, true); - LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>("scalable endpoint ok"), - "Contexts allocated", debug::dec<4>(threads_allocated))); + // create a completion queue for tx endpoint + fabric_info_->tx_attr->op_flags |= (FI_INJECT_COMPLETE | FI_COMPLETION); + auto tx_cq = create_completion_queue( + fabric_domain_, fabric_info_->tx_attr->size, "tx multiple"); - finvoke("fi_scalable_ep_bind AV", "fi_scalable_ep_bind", - fi_scalable_ep_bind(ep_sx, &av_->fid, 0)); + bind_queue_to_endpoint(ep_tx, tx_cq, FI_TRANSMIT | FI_RECV, "tx multiple"); + bind_address_vector_to_endpoint(ep_tx, av_); + enable_endpoint(ep_tx, "tx multiple"); - // prepare the stack for insertions - tx_endpoints_.reserve(threads_allocated); - // - for (unsigned int i = 0; i < threads_allocated; i++) + // combine endpoints and CQ into wrapper for convenience + eps_->ep_tx_ = endpoint_wrapper(ep_tx, nullptr, tx_cq, "tx multiple"); + } + else if (endpoint_type_ == endpoint_type::threadlocalTx) { - [[maybe_unused]] auto scp = - NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), "scalable", debug::dec<4>(i)); + // each thread creates a Tx endpoint on first call to get_tx_endpoint() + } + else if (endpoint_type_ == endpoint_type::scalableTx || + endpoint_type_ == endpoint_type::scalableTxRx) + { + // setup tx contexts for each possible thread + size_t threads_allocated = 0; + auto ep_sx = new_endpoint_scalable( + fabric_domain_, fabric_info_, true /*Tx*/, threads, threads_allocated); + + LF_DEB(NS_DEBUG::cnb_deb, + trace(debug::str<>("scalable endpoint ok"), "Contexts 
allocated", + debug::dec<4>(threads_allocated))); - // For threadlocal/scalable endpoints, tx/rx resources - fid_ep* scalable_ep_tx; - fid_cq* scalable_cq_tx; + finvoke("fi_scalable_ep_bind AV", "fi_scalable_ep_bind", + fi_scalable_ep_bind(ep_sx, &av_->fid, 0)); - // Create a Tx context, cq, bind and enable - finvoke("create tx context", "fi_tx_context", - fi_tx_context(ep_sx, i, NULL, &scalable_ep_tx, NULL)); - scalable_cq_tx = create_completion_queue(fabric_domain_, - fabric_info_->tx_attr->size, "tx scalable"); - bind_queue_to_endpoint(scalable_ep_tx, scalable_cq_tx, FI_TRANSMIT, "tx scalable"); - enable_endpoint(scalable_ep_tx, "tx scalable"); + // prepare the stack for insertions + tx_endpoints_.reserve(threads_allocated); + // + for (unsigned int i = 0; i < threads_allocated; i++) + { + [[maybe_unused]] auto scp = + NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), "scalable", debug::dec<4>(i)); + + // For threadlocal/scalable endpoints, tx/rx resources + fid_ep* scalable_ep_tx; + fid_cq* scalable_cq_tx; + + // Create a Tx context, cq, bind and enable + finvoke("create tx context", "fi_tx_context", + fi_tx_context(ep_sx, i, NULL, &scalable_ep_tx, NULL)); + scalable_cq_tx = create_completion_queue( + fabric_domain_, fabric_info_->tx_attr->size, "tx scalable"); + bind_queue_to_endpoint( + scalable_ep_tx, scalable_cq_tx, FI_TRANSMIT, "tx scalable"); + enable_endpoint(scalable_ep_tx, "tx scalable"); + + endpoint_wrapper tx(scalable_ep_tx, nullptr, scalable_cq_tx, "tx scalable"); + LF_DEB(NS_DEBUG::cnb_deb, + trace(debug::str<>("Scalable Ep"), "initial tx push", "ep", + NS_DEBUG::ptr(tx.get_ep()), "tx cq", NS_DEBUG::ptr(tx.get_tx_cq()), + "rx cq", NS_DEBUG::ptr(tx.get_rx_cq()))); + tx_endpoints_.push(tx); + } - endpoint_wrapper tx(scalable_ep_tx, nullptr, scalable_cq_tx, "tx scalable"); - LF_DEB(NS_DEBUG::cnb_deb, - trace(debug::str<>("Scalable Ep"), "initial tx push", "ep", - NS_DEBUG::ptr(tx.get_ep()), "tx cq", NS_DEBUG::ptr(tx.get_tx_cq()), "rx cq", - NS_DEBUG::ptr(tx.get_rx_cq()))); - tx_endpoints_.push(tx); + eps_->ep_tx_ = endpoint_wrapper(ep_sx, nullptr, nullptr, "rx scalable"); } - eps_->ep_tx_ = endpoint_wrapper(ep_sx, nullptr, nullptr, "rx scalable"); + // once enabled we can get the address + enable_endpoint(eps_->ep_rx_.get_ep(), "rx here"); + here_ = get_endpoint_address(&eps_->ep_rx_.get_ep()->fid); + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("setting 'here'"), iplocality(here_))); + + // // if we are using scalable endpoints, then setup tx/rx contexts + // // we will us a single endpoint for all Tx/Rx contexts + // if (endpoint_type_ == endpoint_type::scalableTx || + // endpoint_type_ == endpoint_type::scalableTxRx) + // { + + // // thread slots might not be same as what we asked for + // size_t threads_allocated = 0; + // auto ep_sx = new_endpoint_scalable(fabric_domain_, fabric_info_, true /*Tx*/, threads, + // threads_allocated); + // if (!ep_sx) + // throw NS_LIBFABRIC::fabric_error(FI_EOTHER, "fi_scalable endpoint creation failed"); + + // LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>("scalable endpoint ok"), + // "Contexts allocated", debug::dec<4>(threads_allocated))); + + // // prepare the stack for insertions + // tx_endpoints_.reserve(threads_allocated); + // rx_endpoints_.reserve(threads_allocated); + // // + // for (unsigned int i = 0; i < threads_allocated; i++) + // { + // [[maybe_unused]] auto scp = + // NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), "scalable", debug::dec<4>(i)); + + // // For threadlocal/scalable endpoints, tx/rx resources + // fid_ep* 
scalable_ep_tx; + // fid_cq* scalable_cq_tx; + //// fid_ep* scalable_ep_rx; + //// fid_cq* scalable_cq_rx; + + // // Tx context setup + // finvoke("create tx context", "fi_tx_context", + // fi_tx_context(ep_sx, i, NULL, &scalable_ep_tx, NULL)); + + // scalable_cq_tx = create_completion_queue(fabric_domain_, + // fabric_info_->tx_attr->size, "tx scalable"); + + // bind_queue_to_endpoint(scalable_ep_tx, scalable_cq_tx, FI_TRANSMIT, "tx scalable"); + + // enable_endpoint(scalable_ep_tx, "tx scalable"); + + // endpoint_wrapper tx(scalable_ep_tx, nullptr, scalable_cq_tx, "tx scalable"); + // LF_DEB(NS_DEBUG::cnb_deb, + // trace(debug::str<>("Scalable Ep"), "initial tx push", "ep", + // NS_DEBUG::ptr(tx.get_ep()), "tx cq", NS_DEBUG::ptr(tx.get_tx_cq()), "rx cq", + // NS_DEBUG::ptr(tx.get_rx_cq()))); + // tx_endpoints_.push(tx); + + // // Rx contexts + //// finvoke("create rx context", "fi_rx_context", + //// fi_rx_context(ep_sx, i, NULL, &scalable_ep_rx, NULL)); + + //// scalable_cq_rx = + //// create_completion_queue(fabric_domain_, fabric_info_->rx_attr->size, "rx"); + + //// bind_queue_to_endpoint(scalable_ep_rx, scalable_cq_rx, FI_RECV, "rx scalable"); + + //// enable_endpoint(scalable_ep_rx, "rx scalable"); + + //// endpoint_wrapper rx(scalable_ep_rx, scalable_cq_rx, nullptr, "rx scalable"); + //// LF_DEB(NS_DEBUG::cnb_deb, + //// trace(debug::str<>("Scalable Ep"), "initial rx push", "ep", + //// NS_DEBUG::ptr(rx.get_ep()), "tx cq", NS_DEBUG::ptr(rx.get_tx_cq()), "rx cq", + //// NS_DEBUG::ptr(rx.get_rx_cq()))); + //// rx_endpoints_.push(rx); + // } + + // finvoke("fi_scalable_ep_bind AV", "fi_scalable_ep_bind", + // fi_scalable_ep_bind(ep_sx, &av_->fid, 0)); + + // eps_->ep_tx_ = endpoint_wrapper(ep_sx, nullptr, nullptr, "rx scalable"); + + return static_cast(this)->initialize_derived( + provider, rootnode, size, threads, std::forward(args)...); } - // once enabled we can get the address - enable_endpoint(eps_->ep_rx_.get_ep(), "rx here"); - here_ = get_endpoint_address(&eps_->ep_rx_.get_ep()->fid); - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("setting 'here'"), iplocality(here_))); - - // // if we are using scalable endpoints, then setup tx/rx contexts - // // we will us a single endpoint for all Tx/Rx contexts - // if (endpoint_type_ == endpoint_type::scalableTx || - // endpoint_type_ == endpoint_type::scalableTxRx) - // { - - // // thread slots might not be same as what we asked for - // size_t threads_allocated = 0; - // auto ep_sx = new_endpoint_scalable(fabric_domain_, fabric_info_, true /*Tx*/, threads, - // threads_allocated); - // if (!ep_sx) - // throw NS_LIBFABRIC::fabric_error(FI_EOTHER, "fi_scalable endpoint creation failed"); - - // LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>("scalable endpoint ok"), - // "Contexts allocated", debug::dec<4>(threads_allocated))); - - // // prepare the stack for insertions - // tx_endpoints_.reserve(threads_allocated); - // rx_endpoints_.reserve(threads_allocated); - // // - // for (unsigned int i = 0; i < threads_allocated; i++) - // { - // [[maybe_unused]] auto scp = - // NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), "scalable", debug::dec<4>(i)); - - // // For threadlocal/scalable endpoints, tx/rx resources - // fid_ep* scalable_ep_tx; - // fid_cq* scalable_cq_tx; - //// fid_ep* scalable_ep_rx; - //// fid_cq* scalable_cq_rx; - - // // Tx context setup - // finvoke("create tx context", "fi_tx_context", - // fi_tx_context(ep_sx, i, NULL, &scalable_ep_tx, NULL)); - - // scalable_cq_tx = create_completion_queue(fabric_domain_, - // 
fabric_info_->tx_attr->size, "tx scalable"); - - // bind_queue_to_endpoint(scalable_ep_tx, scalable_cq_tx, FI_TRANSMIT, "tx scalable"); - - // enable_endpoint(scalable_ep_tx, "tx scalable"); - - // endpoint_wrapper tx(scalable_ep_tx, nullptr, scalable_cq_tx, "tx scalable"); - // LF_DEB(NS_DEBUG::cnb_deb, - // trace(debug::str<>("Scalable Ep"), "initial tx push", "ep", - // NS_DEBUG::ptr(tx.get_ep()), "tx cq", NS_DEBUG::ptr(tx.get_tx_cq()), "rx cq", - // NS_DEBUG::ptr(tx.get_rx_cq()))); - // tx_endpoints_.push(tx); - - // // Rx contexts - //// finvoke("create rx context", "fi_rx_context", - //// fi_rx_context(ep_sx, i, NULL, &scalable_ep_rx, NULL)); - - //// scalable_cq_rx = - //// create_completion_queue(fabric_domain_, fabric_info_->rx_attr->size, "rx"); - - //// bind_queue_to_endpoint(scalable_ep_rx, scalable_cq_rx, FI_RECV, "rx scalable"); - - //// enable_endpoint(scalable_ep_rx, "rx scalable"); - - //// endpoint_wrapper rx(scalable_ep_rx, scalable_cq_rx, nullptr, "rx scalable"); - //// LF_DEB(NS_DEBUG::cnb_deb, - //// trace(debug::str<>("Scalable Ep"), "initial rx push", "ep", - //// NS_DEBUG::ptr(rx.get_ep()), "tx cq", NS_DEBUG::ptr(rx.get_tx_cq()), "rx cq", - //// NS_DEBUG::ptr(rx.get_rx_cq()))); - //// rx_endpoints_.push(rx); - // } - - // finvoke("fi_scalable_ep_bind AV", "fi_scalable_ep_bind", - // fi_scalable_ep_bind(ep_sx, &av_->fid, 0)); - - // eps_->ep_tx_ = endpoint_wrapper(ep_sx, nullptr, nullptr, "rx scalable"); - - return static_cast(this)->initialize_derived(provider, rootnode, size, threads, - std::forward(args)...); - } - - // -------------------------------------------------------------------- - constexpr uint64_t caps_flags() { return static_cast(this)->caps_flags(); } + // -------------------------------------------------------------------- + constexpr uint64_t caps_flags() { return static_cast(this)->caps_flags(); } - // -------------------------------------------------------------------- - constexpr fi_threading threadlevel_flags() - { - return static_cast(this)->threadlevel_flags(); - } + // -------------------------------------------------------------------- + constexpr fi_threading threadlevel_flags() + { + return static_cast(this)->threadlevel_flags(); + } - // -------------------------------------------------------------------- - constexpr std::int64_t memory_registration_mode_flags() - { - std::int64_t base_flags = FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY; + // -------------------------------------------------------------------- + constexpr std::int64_t memory_registration_mode_flags() + { + std::int64_t base_flags = FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY; #if OOMPH_ENABLE_DEVICE - base_flags = base_flags | FI_MR_HMEM; + base_flags = base_flags | FI_MR_HMEM; #endif - base_flags = base_flags | FI_MR_LOCAL; + base_flags = base_flags | FI_MR_LOCAL; #if defined(HAVE_LIBFABRIC_CXI) - return base_flags | FI_MR_MMU_NOTIFY | FI_MR_ENDPOINT; + return base_flags | FI_MR_MMU_NOTIFY | FI_MR_ENDPOINT; #elif defined(HAVE_LIBFABRIC_EFA) - return base_flags | FI_MR_MMU_NOTIFY | FI_MR_ENDPOINT; + return base_flags | FI_MR_MMU_NOTIFY | FI_MR_ENDPOINT; #else - return base_flags; + return base_flags; #endif - } - - // -------------------------------------------------------------------- - uint32_t rendezvous_threshold() { return msg_rendezvous_threshold_; } - // -------------------------------------------------------------------- - // initialize the basic fabric/domain/name - void open_fabric(std::string const& provider, int threads, bool rootnode) - { - 
[[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + } - struct fi_info* fabric_hints_ = fi_allocinfo(); - if (!fabric_hints_) + // -------------------------------------------------------------------- + uint32_t rendezvous_threshold() { return msg_rendezvous_threshold_; } + // -------------------------------------------------------------------- + // initialize the basic fabric/domain/name + void open_fabric(std::string const& provider, int threads, bool rootnode) { - throw NS_LIBFABRIC::fabric_error(-1, "Failed to allocate fabric hints"); - } + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Here locality"), iplocality(here_))); + struct fi_info* fabric_hints_ = fi_allocinfo(); + if (!fabric_hints_) + { + throw NS_LIBFABRIC::fabric_error(-1, "Failed to allocate fabric hints"); + } + + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Here locality"), iplocality(here_))); #if defined(HAVE_LIBFABRIC_SOCKETS) || defined(HAVE_LIBFABRIC_TCP) || defined(HAVE_LIBFABRIC_VERBS) - fabric_hints_->addr_format = FI_SOCKADDR_IN; + fabric_hints_->addr_format = FI_SOCKADDR_IN; #elif defined(HAVE_LIBFABRIC_EFA) - fabric_hints_->addr_format = FI_ADDR_EFA; + fabric_hints_->addr_format = FI_ADDR_EFA; #endif - fabric_hints_->caps = caps_flags(); + fabric_hints_->caps = caps_flags(); - fabric_hints_->mode = FI_CONTEXT /*| FI_MR_LOCAL*/; - if (provider.c_str() == std::string("tcp")) - { - fabric_hints_->fabric_attr->prov_name = - strdup(std::string(provider + ";ofi_rxm").c_str()); - } - else if (provider.c_str() == std::string("verbs")) - { - fabric_hints_->fabric_attr->prov_name = - strdup(std::string(provider + ";ofi_rxm").c_str()); - } - else { fabric_hints_->fabric_attr->prov_name = strdup(provider.c_str()); } - LF_DEB(NS_DEBUG::cnb_deb, - debug(debug::str<>("fabric provider"), fabric_hints_->fabric_attr->prov_name)); + fabric_hints_->mode = FI_CONTEXT /*| FI_MR_LOCAL*/; + if (provider.c_str() == std::string("tcp")) + { + fabric_hints_->fabric_attr->prov_name = + strdup(std::string(provider + ";ofi_rxm").c_str()); + } + else if (provider.c_str() == std::string("verbs")) + { + fabric_hints_->fabric_attr->prov_name = + strdup(std::string(provider + ";ofi_rxm").c_str()); + } + else { fabric_hints_->fabric_attr->prov_name = strdup(provider.c_str()); } + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("fabric provider"), fabric_hints_->fabric_attr->prov_name)); - fabric_hints_->domain_attr->mr_mode = memory_registration_mode_flags(); + fabric_hints_->domain_attr->mr_mode = memory_registration_mode_flags(); - // Enable/Disable the use of progress threads - auto progress = libfabric_progress_type(); - fabric_hints_->domain_attr->control_progress = progress; - fabric_hints_->domain_attr->data_progress = progress; - LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("progress"), libfabric_progress_string())); + // Enable/Disable the use of progress threads + auto progress = libfabric_progress_type(); + fabric_hints_->domain_attr->control_progress = progress; + fabric_hints_->domain_attr->data_progress = progress; + LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("progress"), libfabric_progress_string())); - if (threads > 1) - { - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("FI_THREAD_FID"))); - // Enable thread safe mode (Does not work with psm2 provider) - // fabric_hints_->domain_attr->threading = FI_THREAD_SAFE; - //fabric_hints_->domain_attr->threading = FI_THREAD_FID; - fabric_hints_->domain_attr->threading = 
threadlevel_flags(); - } - else - { - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("FI_THREAD_DOMAIN"))); - // we serialize everything - fabric_hints_->domain_attr->threading = FI_THREAD_DOMAIN; - } + if (threads > 1) + { + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("FI_THREAD_FID"))); + // Enable thread safe mode (Does not work with psm2 provider) + // fabric_hints_->domain_attr->threading = FI_THREAD_SAFE; + //fabric_hints_->domain_attr->threading = FI_THREAD_FID; + fabric_hints_->domain_attr->threading = threadlevel_flags(); + } + else + { + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("FI_THREAD_DOMAIN"))); + // we serialize everything + fabric_hints_->domain_attr->threading = FI_THREAD_DOMAIN; + } - // Enable resource management - fabric_hints_->domain_attr->resource_mgmt = FI_RM_ENABLED; + // Enable resource management + fabric_hints_->domain_attr->resource_mgmt = FI_RM_ENABLED; - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("fabric endpoint"), "RDM")); - fabric_hints_->ep_attr->type = FI_EP_RDM; + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("fabric endpoint"), "RDM")); + fabric_hints_->ep_attr->type = FI_EP_RDM; - uint64_t flags = 0; - LF_DEB(NS_DEBUG::cnb_deb, - debug(debug::str<>("get fabric info"), "FI_VERSION", - debug::dec(LIBFABRIC_FI_VERSION_MAJOR), debug::dec(LIBFABRIC_FI_VERSION_MINOR))); + uint64_t flags = 0; + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("get fabric info"), "FI_VERSION", + debug::dec(LIBFABRIC_FI_VERSION_MAJOR), + debug::dec(LIBFABRIC_FI_VERSION_MINOR))); - int ret = fi_getinfo(FI_VERSION(LIBFABRIC_FI_VERSION_MAJOR, LIBFABRIC_FI_VERSION_MINOR), - nullptr, nullptr, flags, fabric_hints_, &fabric_info_); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "Failed to get fabric info"); + int ret = fi_getinfo(FI_VERSION(LIBFABRIC_FI_VERSION_MAJOR, LIBFABRIC_FI_VERSION_MINOR), + nullptr, nullptr, flags, fabric_hints_, &fabric_info_); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "Failed to get fabric info"); - if (rootnode) - { - LF_DEB(NS_DEBUG::cnb_err, - trace(debug::str<>("Fabric info"), "\n", fi_tostr(fabric_info_, FI_TYPE_INFO))); - } + if (rootnode) + { + LF_DEB(NS_DEBUG::cnb_err, + trace(debug::str<>("Fabric info"), "\n", fi_tostr(fabric_info_, FI_TYPE_INFO))); + } - bool context = (fabric_hints_->mode & FI_CONTEXT) != 0; - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Requires FI_CONTEXT"), context)); + bool context = (fabric_hints_->mode & FI_CONTEXT) != 0; + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Requires FI_CONTEXT"), context)); - mrlocal = (fabric_hints_->domain_attr->mr_mode & FI_MR_LOCAL) != 0; - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Requires FI_MR_LOCAL"), mrlocal)); + mrlocal = (fabric_hints_->domain_attr->mr_mode & FI_MR_LOCAL) != 0; + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Requires FI_MR_LOCAL"), mrlocal)); - mrbind = (fabric_hints_->domain_attr->mr_mode & FI_MR_ENDPOINT) != 0; - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Requires FI_MR_ENDPOINT"), mrbind)); + mrbind = (fabric_hints_->domain_attr->mr_mode & FI_MR_ENDPOINT) != 0; + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Requires FI_MR_ENDPOINT"), mrbind)); - /* Check if provider requires heterogeneous memory registration */ - mrhmem = (fabric_hints_->domain_attr->mr_mode & FI_MR_HMEM) != 0; - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Requires FI_MR_HMEM"), mrhmem)); + /* Check if provider requires heterogeneous memory registration */ + mrhmem = (fabric_hints_->domain_attr->mr_mode & FI_MR_HMEM) != 0; + LF_DEB(NS_DEBUG::cnb_deb, 
debug(debug::str<>("Requires FI_MR_HMEM"), mrhmem)); - bool mrhalloc = (fabric_hints_->domain_attr->mr_mode & FI_MR_ALLOCATED) != 0; - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Requires FI_MR_ALLOCATED"), mrhalloc)); + bool mrhalloc = (fabric_hints_->domain_attr->mr_mode & FI_MR_ALLOCATED) != 0; + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Requires FI_MR_ALLOCATED"), mrhalloc)); - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Creating fi_fabric"))); - ret = fi_fabric(fabric_info_->fabric_attr, &fabric_, nullptr); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "Failed to get fi_fabric"); + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Creating fi_fabric"))); + ret = fi_fabric(fabric_info_->fabric_attr, &fabric_, nullptr); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "Failed to get fi_fabric"); - // Allocate a domain. - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Allocating domain"))); - ret = fi_domain(fabric_, fabric_info_, &fabric_domain_, nullptr); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_domain"); + // Allocate a domain. + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Allocating domain"))); + ret = fi_domain(fabric_, fabric_info_, &fabric_domain_, nullptr); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_domain"); #if defined(HAVE_LIBFABRIC_GNI) - { - [[maybe_unused]] auto scp = - NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), "GNI memory registration block"); - - LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("-------"), "GNI String values")); - // Dump out all vars for debug purposes - for (auto& gni_data : gni_strs) { - _set_check_domain_op_value(gni_data.first, 0, gni_data.second.c_str(), - false); - } - LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("-------"), "GNI Int values")); - for (auto& gni_data : gni_ints) - { - _set_check_domain_op_value(gni_data.first, 0, gni_data.second.c_str(), - false); - } - LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("-------"))); - - // -------------------------- - // GNI_MR_CACHE - // set GNI mem reg to be either none, internal or udreg - // - _set_check_domain_op_value(GNI_MR_CACHE, const_cast(OOMPH_GNI_REG), - "GNI_MR_CACHE"); - - // -------------------------- - // GNI_MR_UDREG_REG_LIMIT - // Experiments showed default value of 2048 too high if - // launching multiple clients on one node - // - int32_t udreg_limit = 0x0800; // 0x0400 = 1024, 0x0800 = 2048 - _set_check_domain_op_value(GNI_MR_UDREG_REG_LIMIT, udreg_limit, - "GNI_MR_UDREG_REG_LIMIT"); - - // -------------------------- - // GNI_MR_CACHE_LAZY_DEREG - // Enable lazy deregistration in MR cache - // - int32_t enable = 1; - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("setting GNI_MR_CACHE_LAZY_DEREG"))); - _set_check_domain_op_value(GNI_MR_CACHE_LAZY_DEREG, enable, - "GNI_MR_CACHE_LAZY_DEREG"); + [[maybe_unused]] auto scp = + NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), "GNI memory registration block"); - // -------------------------- - // GNI_MSG_RENDEZVOUS_THRESHOLD (c.f. 
GNI_RMA_RDMA_THRESHOLD) - // - int32_t thresh = msg_rendezvous_threshold_; - _set_check_domain_op_value(GNI_MSG_RENDEZVOUS_THRESHOLD, thresh, - "GNI_MSG_RENDEZVOUS_THRESHOLD"); - } + LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("-------"), "GNI String values")); + // Dump out all vars for debug purposes + for (auto& gni_data : gni_strs) + { + _set_check_domain_op_value( + gni_data.first, 0, gni_data.second.c_str(), false); + } + LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("-------"), "GNI Int values")); + for (auto& gni_data : gni_ints) + { + _set_check_domain_op_value( + gni_data.first, 0, gni_data.second.c_str(), false); + } + LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("-------"))); + + // -------------------------- + // GNI_MR_CACHE + // set GNI mem reg to be either none, internal or udreg + // + _set_check_domain_op_value( + GNI_MR_CACHE, const_cast(OOMPH_GNI_REG), "GNI_MR_CACHE"); + + // -------------------------- + // GNI_MR_UDREG_REG_LIMIT + // Experiments showed default value of 2048 too high if + // launching multiple clients on one node + // + int32_t udreg_limit = 0x0800; // 0x0400 = 1024, 0x0800 = 2048 + _set_check_domain_op_value( + GNI_MR_UDREG_REG_LIMIT, udreg_limit, "GNI_MR_UDREG_REG_LIMIT"); + + // -------------------------- + // GNI_MR_CACHE_LAZY_DEREG + // Enable lazy deregistration in MR cache + // + int32_t enable = 1; + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("setting GNI_MR_CACHE_LAZY_DEREG"))); + _set_check_domain_op_value( + GNI_MR_CACHE_LAZY_DEREG, enable, "GNI_MR_CACHE_LAZY_DEREG"); + + // -------------------------- + // GNI_MSG_RENDEZVOUS_THRESHOLD (c.f. GNI_RMA_RDMA_THRESHOLD) + // + int32_t thresh = msg_rendezvous_threshold_; + _set_check_domain_op_value( + GNI_MSG_RENDEZVOUS_THRESHOLD, thresh, "GNI_MSG_RENDEZVOUS_THRESHOLD"); + } #endif - tx_inject_size_ = fabric_info_->tx_attr->inject_size; + tx_inject_size_ = fabric_info_->tx_attr->inject_size; - // the number of preposted receives, and sender queue depth - // is set by querying the tx/tx attr sizes - tx_attr_size_ = std::min(size_t(512), fabric_info_->tx_attr->size / 2); - rx_attr_size_ = std::min(size_t(512), fabric_info_->rx_attr->size / 2); - fi_freeinfo(fabric_hints_); - } + // the number of preposted receives, and sender queue depth + // is set by querying the tx/tx attr sizes + tx_attr_size_ = std::min(size_t(512), fabric_info_->tx_attr->size / 2); + rx_attr_size_ = std::min(size_t(512), fabric_info_->rx_attr->size / 2); + fi_freeinfo(fabric_hints_); + } - // -------------------------------------------------------------------- - struct fi_info* set_src_dst_addresses(struct fi_info* info, bool tx) - { - return static_cast(this)->set_src_dst_addresses(info, tx); - } + // -------------------------------------------------------------------- + struct fi_info* set_src_dst_addresses(struct fi_info* info, bool tx) + { + return static_cast(this)->set_src_dst_addresses(info, tx); + } #ifdef HAVE_LIBFABRIC_GNI - // -------------------------------------------------------------------- - // Special GNI extensions to disable memory registration cache - - // if set is false, the old value is returned and nothing is set - template - int _set_check_domain_op_value(int op, T value, const char* info, bool set = true) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - static struct fi_gni_ops_domain* gni_domain_ops = nullptr; - int ret = 0; + // -------------------------------------------------------------------- + // Special GNI extensions to disable memory registration 
cache - if (gni_domain_ops == nullptr) + // if set is false, the old value is returned and nothing is set + template + int _set_check_domain_op_value(int op, T value, char const* info, bool set = true) { - ret = fi_open_ops(&fabric_domain_->fid, FI_GNI_DOMAIN_OPS_1, 0, (void**)&gni_domain_ops, - nullptr); - LF_DEB(NS_DEBUG::cnb_deb, - debug(debug::str<>("gni open ops"), (ret == 0 ? "OK" : "FAIL"), - NS_DEBUG::ptr(gni_domain_ops))); - } + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + static struct fi_gni_ops_domain* gni_domain_ops = nullptr; + int ret = 0; - // if open was ok and set flag is present, then set value - if (ret == 0 && set) - { - ret = gni_domain_ops->set_val(&fabric_domain_->fid, (dom_ops_val_t)(op), - reinterpret_cast(&value)); + if (gni_domain_ops == nullptr) + { + ret = fi_open_ops(&fabric_domain_->fid, FI_GNI_DOMAIN_OPS_1, 0, + (void**) &gni_domain_ops, nullptr); + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("gni open ops"), (ret == 0 ? "OK" : "FAIL"), + NS_DEBUG::ptr(gni_domain_ops))); + } - LF_DEB(NS_DEBUG::cnb_deb, - debug(debug::str<>("gni set ops val"), value, (ret == 0 ? "OK" : "FAIL"))); - } + // if open was ok and set flag is present, then set value + if (ret == 0 && set) + { + ret = gni_domain_ops->set_val( + &fabric_domain_->fid, (dom_ops_val_t) (op), reinterpret_cast(&value)); - // Get the value (so we can check that the value we set is now returned) - T new_value; - ret = gni_domain_ops->get_val(&fabric_domain_->fid, (dom_ops_val_t)(op), &new_value); - if constexpr (std::is_integral::value) - { - LF_DEB(NS_DEBUG::cnb_err, debug(debug::str<>("gni op val"), (ret == 0 ? "OK" : "FAIL"), - info, debug::hex<8>(new_value))); - } - else - { - LF_DEB(NS_DEBUG::cnb_err, - debug(debug::str<>("gni op val"), (ret == 0 ? "OK" : "FAIL"), info, new_value)); - } - // - if (ret) throw NS_LIBFABRIC::fabric_error(ret, std::string("setting ") + info); + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("gni set ops val"), value, (ret == 0 ? "OK" : "FAIL"))); + } - return ret; - } + // Get the value (so we can check that the value we set is now returned) + T new_value; + ret = gni_domain_ops->get_val(&fabric_domain_->fid, (dom_ops_val_t) (op), &new_value); + if constexpr (std::is_integral::value) + { + LF_DEB(NS_DEBUG::cnb_err, + debug(debug::str<>("gni op val"), (ret == 0 ? "OK" : "FAIL"), info, + debug::hex<8>(new_value))); + } + else + { + LF_DEB(NS_DEBUG::cnb_err, + debug(debug::str<>("gni op val"), (ret == 0 ? 
"OK" : "FAIL"), info, new_value)); + } + // + if (ret) throw NS_LIBFABRIC::fabric_error(ret, std::string("setting ") + info); + + return ret; + } #endif - // -------------------------------------------------------------------- - struct fid_ep* new_endpoint_active(struct fid_domain* domain, struct fi_info* info, bool tx) - { - // don't allow multiple threads to call endpoint create at the same time - scoped_lock lock(controller_mutex_); + // -------------------------------------------------------------------- + struct fid_ep* new_endpoint_active(struct fid_domain* domain, struct fi_info* info, bool tx) + { + // don't allow multiple threads to call endpoint create at the same time + scoped_lock lock(controller_mutex_); - // make sure src_addr/dst_addr are set accordingly - // and we do not create two endpoint with the same src address - struct fi_info* hints = set_src_dst_addresses(info, tx); + // make sure src_addr/dst_addr are set accordingly + // and we do not create two endpoint with the same src address + struct fi_info* hints = set_src_dst_addresses(info, tx); - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - LF_DEB(NS_DEBUG::cnb_deb, - debug(debug::str<>("Got info mode"), (info->mode & FI_NOTIFY_FLAGS_ONLY))); + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("Got info mode"), (info->mode & FI_NOTIFY_FLAGS_ONLY))); - struct fid_ep* ep; - int ret = fi_endpoint(domain, hints, &ep, nullptr); - if (ret) - { - throw NS_LIBFABRIC::fabric_error(ret, "fi_endpoint (too many threadlocal " - "endpoints?)"); + struct fid_ep* ep; + int ret = fi_endpoint(domain, hints, &ep, nullptr); + if (ret) + { + throw NS_LIBFABRIC::fabric_error(ret, + "fi_endpoint (too many threadlocal " + "endpoints?)"); + } + fi_freeinfo(hints); + LF_DEB( + NS_DEBUG::cnb_deb, debug(debug::str<>("new_endpoint_active"), NS_DEBUG::ptr(ep))); + return ep; } - fi_freeinfo(hints); - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("new_endpoint_active"), NS_DEBUG::ptr(ep))); - return ep; - } - // -------------------------------------------------------------------- - struct fid_ep* new_endpoint_scalable(struct fid_domain* domain, struct fi_info* info, bool tx, - size_t threads, size_t& threads_allocated) - { - // don't allow multiple threads to call endpoint create at the same time - scoped_lock lock(controller_mutex_); + // -------------------------------------------------------------------- + struct fid_ep* new_endpoint_scalable(struct fid_domain* domain, struct fi_info* info, + bool tx, size_t threads, size_t& threads_allocated) + { + // don't allow multiple threads to call endpoint create at the same time + scoped_lock lock(controller_mutex_); - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("fi_dupinfo"))); - struct fi_info* hints = fi_dupinfo(info); - if (!hints) throw NS_LIBFABRIC::fabric_error(0, "fi_dupinfo"); + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("fi_dupinfo"))); + struct fi_info* hints = fi_dupinfo(info); + if (!hints) throw NS_LIBFABRIC::fabric_error(0, "fi_dupinfo"); - int flags = 0; - struct fi_info* new_hints = nullptr; - int ret = fi_getinfo(FI_VERSION(LIBFABRIC_FI_VERSION_MAJOR, LIBFABRIC_FI_VERSION_MINOR), - nullptr, nullptr, flags, hints, &new_hints); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_getinfo"); + int 
flags = 0; + struct fi_info* new_hints = nullptr; + int ret = fi_getinfo(FI_VERSION(LIBFABRIC_FI_VERSION_MAJOR, LIBFABRIC_FI_VERSION_MINOR), + nullptr, nullptr, flags, hints, &new_hints); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_getinfo"); - // Check the optimal number of TX/RX contexts supported by the provider - size_t context_count = 0; - if (tx) { context_count = std::min(new_hints->domain_attr->tx_ctx_cnt, threads); } - else { context_count = std::min(new_hints->domain_attr->rx_ctx_cnt, threads); } + // Check the optimal number of TX/RX contexts supported by the provider + size_t context_count = 0; + if (tx) { context_count = std::min(new_hints->domain_attr->tx_ctx_cnt, threads); } + else { context_count = std::min(new_hints->domain_attr->rx_ctx_cnt, threads); } - // clang-format off + // clang-format off LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>("scalable endpoint"), "Tx", tx, @@ -1060,440 +1061,451 @@ class controller_base "tx_ctx_cnt", debug::dec<3>(new_hints->domain_attr->tx_ctx_cnt), "rx_ctx_cnt", debug::dec<3>(new_hints->domain_attr->rx_ctx_cnt), "context_count", debug::dec<3>(context_count))); - // clang-format on - - threads_allocated = context_count; - new_hints->ep_attr->tx_ctx_cnt = context_count; - new_hints->ep_attr->rx_ctx_cnt = context_count; - - struct fid_ep* ep; - ret = fi_scalable_ep(domain, new_hints, &ep, nullptr); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_scalable_ep"); - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("new_endpoint_scalable"), NS_DEBUG::ptr(ep))); - fi_freeinfo(hints); - return ep; - } - - // -------------------------------------------------------------------- - endpoint_wrapper& get_rx_endpoint() - { - static auto rx = NS_DEBUG::cnb_deb.make_timer(1, debug::str<>("get_rx_endpoint")); - LF_DEB(NS_DEBUG::cnb_deb, timed(rx)); + // clang-format on + + threads_allocated = context_count; + new_hints->ep_attr->tx_ctx_cnt = context_count; + new_hints->ep_attr->rx_ctx_cnt = context_count; + + struct fid_ep* ep; + ret = fi_scalable_ep(domain, new_hints, &ep, nullptr); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_scalable_ep"); + LF_DEB( + NS_DEBUG::cnb_deb, debug(debug::str<>("new_endpoint_scalable"), NS_DEBUG::ptr(ep))); + fi_freeinfo(hints); + return ep; + } - if (endpoint_type_ == endpoint_type::scalableTxRx) + // -------------------------------------------------------------------- + endpoint_wrapper& get_rx_endpoint() { - if (eps_->tl_srx_.get_ep() == nullptr) + static auto rx = NS_DEBUG::cnb_deb.make_timer(1, debug::str<>("get_rx_endpoint")); + LF_DEB(NS_DEBUG::cnb_deb, timed(rx)); + + if (endpoint_type_ == endpoint_type::scalableTxRx) { - endpoint_wrapper ep; - bool ok = rx_endpoints_.pop(ep); - if (!ok) + if (eps_->tl_srx_.get_ep() == nullptr) { - // clang-format off + endpoint_wrapper ep; + bool ok = rx_endpoints_.pop(ep); + if (!ok) + { + // clang-format off LF_DEB(NS_DEBUG::cnb_deb, error(debug::str<>("Scalable Ep"), "pop rx", "ep", NS_DEBUG::ptr(ep.get_ep()), "tx cq", NS_DEBUG::ptr(ep.get_tx_cq()), "rx cq", NS_DEBUG::ptr(ep.get_rx_cq()))); - // clang-format on - throw std::runtime_error("rx endpoint wrapper pop fail"); + // clang-format on + throw std::runtime_error("rx endpoint wrapper pop fail"); + } + eps_->tl_srx_ = stack_endpoint( + ep.get_ep(), ep.get_rx_cq(), ep.get_tx_cq(), ep.get_name(), &rx_endpoints_); + LF_DEB(NS_DEBUG::cnb_deb, + trace(debug::str<>("Scalable Ep"), "pop rx", "ep", + NS_DEBUG::ptr(eps_->tl_srx_.get_ep()), "tx cq", + NS_DEBUG::ptr(eps_->tl_srx_.get_tx_cq()), "rx cq", + 
NS_DEBUG::ptr(eps_->tl_srx_.get_rx_cq()))); } - eps_->tl_srx_ = stack_endpoint(ep.get_ep(), ep.get_rx_cq(), ep.get_tx_cq(), - ep.get_name(), &rx_endpoints_); - LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>("Scalable Ep"), "pop rx", "ep", - NS_DEBUG::ptr(eps_->tl_srx_.get_ep()), "tx cq", - NS_DEBUG::ptr(eps_->tl_srx_.get_tx_cq()), "rx cq", - NS_DEBUG::ptr(eps_->tl_srx_.get_rx_cq()))); + return eps_->tl_srx_.endpoint_; } - return eps_->tl_srx_.endpoint_; + // otherwise just return the normal Rx endpoint + return eps_->ep_rx_; } - // otherwise just return the normal Rx endpoint - return eps_->ep_rx_; - } - // -------------------------------------------------------------------- - endpoint_wrapper& get_tx_endpoint() - { - if (endpoint_type_ == endpoint_type::threadlocalTx) + // -------------------------------------------------------------------- + endpoint_wrapper& get_tx_endpoint() { - if (eps_->tl_tx_.get_ep() == nullptr) + if (endpoint_type_ == endpoint_type::threadlocalTx) { - [[maybe_unused]] auto scp = - NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, "threadlocal"); - - // create a completion queue for tx endpoint - fabric_info_->tx_attr->op_flags |= (FI_INJECT_COMPLETE | FI_COMPLETION); - auto tx_cq = create_completion_queue(fabric_domain_, fabric_info_->tx_attr->size, - "tx threadlocal"); - - // setup an endpoint for sending messages - // note that the CQ needs FI_RECV even though its a Tx cq to keep - // some providers happy as they trigger an error if an endpoint - // has no Rx cq attached (progress bug) - auto ep_tx = new_endpoint_active(fabric_domain_, fabric_info_, true); - bind_queue_to_endpoint(ep_tx, tx_cq, FI_TRANSMIT | FI_RECV, "tx threadlocal"); - bind_address_vector_to_endpoint(ep_tx, av_); - enable_endpoint(ep_tx, "tx threadlocal"); - - // set threadlocal endpoint wrapper - LF_DEB(NS_DEBUG::cnb_deb, - trace(debug::str<>("Threadlocal Ep"), "create Tx", "ep", NS_DEBUG::ptr(ep_tx), - "tx cq", NS_DEBUG::ptr(tx_cq), "rx cq", NS_DEBUG::ptr(nullptr))); - // for cleaning up at termination - endpoint_wrapper ep(ep_tx, nullptr, tx_cq, "tx threadlocal"); - tx_endpoints_.push(ep); - eps_->tl_tx_ = stack_endpoint(ep_tx, nullptr, tx_cq, "threadlocal", nullptr); + if (eps_->tl_tx_.get_ep() == nullptr) + { + [[maybe_unused]] auto scp = + NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, "threadlocal"); + + // create a completion queue for tx endpoint + fabric_info_->tx_attr->op_flags |= (FI_INJECT_COMPLETE | FI_COMPLETION); + auto tx_cq = create_completion_queue( + fabric_domain_, fabric_info_->tx_attr->size, "tx threadlocal"); + + // setup an endpoint for sending messages + // note that the CQ needs FI_RECV even though its a Tx cq to keep + // some providers happy as they trigger an error if an endpoint + // has no Rx cq attached (progress bug) + auto ep_tx = new_endpoint_active(fabric_domain_, fabric_info_, true); + bind_queue_to_endpoint(ep_tx, tx_cq, FI_TRANSMIT | FI_RECV, "tx threadlocal"); + bind_address_vector_to_endpoint(ep_tx, av_); + enable_endpoint(ep_tx, "tx threadlocal"); + + // set threadlocal endpoint wrapper + LF_DEB(NS_DEBUG::cnb_deb, + trace(debug::str<>("Threadlocal Ep"), "create Tx", "ep", + NS_DEBUG::ptr(ep_tx), "tx cq", NS_DEBUG::ptr(tx_cq), "rx cq", + NS_DEBUG::ptr(nullptr))); + // for cleaning up at termination + endpoint_wrapper ep(ep_tx, nullptr, tx_cq, "tx threadlocal"); + tx_endpoints_.push(ep); + eps_->tl_tx_ = stack_endpoint(ep_tx, nullptr, tx_cq, "threadlocal", nullptr); + } + return eps_->tl_tx_.endpoint_; } - return eps_->tl_tx_.endpoint_; - } - 
else if (endpoint_type_ == endpoint_type::scalableTx || - endpoint_type_ == endpoint_type::scalableTxRx) - { - if (eps_->tl_stx_.get_ep() == nullptr) + else if (endpoint_type_ == endpoint_type::scalableTx || + endpoint_type_ == endpoint_type::scalableTxRx) { - endpoint_wrapper ep; - bool ok = tx_endpoints_.pop(ep); - if (!ok) + if (eps_->tl_stx_.get_ep() == nullptr) { + endpoint_wrapper ep; + bool ok = tx_endpoints_.pop(ep); + if (!ok) + { + LF_DEB(NS_DEBUG::cnb_deb, + error(debug::str<>("Scalable Ep"), "pop tx", "ep", + NS_DEBUG::ptr(ep.get_ep()), "tx cq", NS_DEBUG::ptr(ep.get_tx_cq()), + "rx cq", NS_DEBUG::ptr(ep.get_rx_cq()))); + throw std::runtime_error("tx endpoint wrapper pop fail"); + } + eps_->tl_stx_ = stack_endpoint( + ep.get_ep(), ep.get_rx_cq(), ep.get_tx_cq(), ep.get_name(), &tx_endpoints_); LF_DEB(NS_DEBUG::cnb_deb, - error(debug::str<>("Scalable Ep"), "pop tx", "ep", - NS_DEBUG::ptr(ep.get_ep()), "tx cq", NS_DEBUG::ptr(ep.get_tx_cq()), - "rx cq", NS_DEBUG::ptr(ep.get_rx_cq()))); - throw std::runtime_error("tx endpoint wrapper pop fail"); + trace(debug::str<>("Scalable Ep"), "pop tx", "ep", + NS_DEBUG::ptr(eps_->tl_stx_.get_ep()), "tx cq", + NS_DEBUG::ptr(eps_->tl_stx_.get_tx_cq()), "rx cq", + NS_DEBUG::ptr(eps_->tl_stx_.get_rx_cq()))); } - eps_->tl_stx_ = stack_endpoint(ep.get_ep(), ep.get_rx_cq(), ep.get_tx_cq(), - ep.get_name(), &tx_endpoints_); - LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>("Scalable Ep"), "pop tx", "ep", - NS_DEBUG::ptr(eps_->tl_stx_.get_ep()), "tx cq", - NS_DEBUG::ptr(eps_->tl_stx_.get_tx_cq()), "rx cq", - NS_DEBUG::ptr(eps_->tl_stx_.get_rx_cq()))); + return eps_->tl_stx_.endpoint_; } - return eps_->tl_stx_.endpoint_; + else if (endpoint_type_ == endpoint_type::multiple) { return eps_->ep_tx_; } + // single : shared tx/rx endpoint + return eps_->ep_rx_; } - else if (endpoint_type_ == endpoint_type::multiple) { return eps_->ep_tx_; } - // single : shared tx/rx endpoint - return eps_->ep_rx_; - } - - // -------------------------------------------------------------------- - void bind_address_vector_to_endpoint(struct fid_ep* endpoint, struct fid_av* av) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Binding AV"), "to", NS_DEBUG::ptr(endpoint))); - int ret = fi_ep_bind(endpoint, &av->fid, 0); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "bind address_vector"); - } - - // -------------------------------------------------------------------- - void bind_queue_to_endpoint(struct fid_ep* endpoint, struct fid_cq*& cq, uint32_t cqtype, - const char* type) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, type); - LF_DEB(NS_DEBUG::cnb_deb, - debug(debug::str<>("Binding CQ"), "to", NS_DEBUG::ptr(endpoint), type)); - int ret = fi_ep_bind(endpoint, &cq->fid, cqtype); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "bind cq"); - } + // -------------------------------------------------------------------- + void bind_address_vector_to_endpoint(struct fid_ep* endpoint, struct fid_av* av) + { + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - // -------------------------------------------------------------------- - fid_cq* bind_tx_queue_to_rx_endpoint(struct fi_info* info, struct fid_ep* ep) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - info->tx_attr->op_flags |= (FI_INJECT_COMPLETE | FI_COMPLETION); - fid_cq* tx_cq = create_completion_queue(fabric_domain_, 
info->tx_attr->size, "tx->rx"); - // shared send/recv endpoint - bind send cq to the recv endpoint - bind_queue_to_endpoint(ep, tx_cq, FI_TRANSMIT, "tx->rx bug fix"); - return tx_cq; - } + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("Binding AV"), "to", NS_DEBUG::ptr(endpoint))); + int ret = fi_ep_bind(endpoint, &av->fid, 0); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "bind address_vector"); + } - // -------------------------------------------------------------------- - void enable_endpoint(struct fid_ep* endpoint, const char* type) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, type); + // -------------------------------------------------------------------- + void bind_queue_to_endpoint( + struct fid_ep* endpoint, struct fid_cq*& cq, uint32_t cqtype, char const* type) + { + [[maybe_unused]] auto scp = + NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, type); - LF_DEB(NS_DEBUG::cnb_deb, - debug(debug::str<>("Enabling endpoint"), NS_DEBUG::ptr(endpoint))); - int ret = fi_enable(endpoint); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_enable"); - } + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("Binding CQ"), "to", NS_DEBUG::ptr(endpoint), type)); + int ret = fi_ep_bind(endpoint, &cq->fid, cqtype); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "bind cq"); + } - // -------------------------------------------------------------------- - locality get_endpoint_address(struct fid* id) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + // -------------------------------------------------------------------- + fid_cq* bind_tx_queue_to_rx_endpoint(struct fi_info* info, struct fid_ep* ep) + { + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + info->tx_attr->op_flags |= (FI_INJECT_COMPLETE | FI_COMPLETION); + fid_cq* tx_cq = create_completion_queue(fabric_domain_, info->tx_attr->size, "tx->rx"); + // shared send/recv endpoint - bind send cq to the recv endpoint + bind_queue_to_endpoint(ep, tx_cq, FI_TRANSMIT, "tx->rx bug fix"); + return tx_cq; + } - locality::locality_data local_addr; - std::size_t addrlen = locality_defs::array_size; - int ret = fi_getname(id, local_addr.data(), &addrlen); - if (ret || (addrlen > locality_defs::array_size)) + // -------------------------------------------------------------------- + void enable_endpoint(struct fid_ep* endpoint, char const* type) { - std::string err = - std::to_string(addrlen) + "=" + std::to_string(locality_defs::array_size); - NS_LIBFABRIC::fabric_error(ret, "fi_getname - size error or other problem " + err); + [[maybe_unused]] auto scp = + NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, type); + + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("Enabling endpoint"), NS_DEBUG::ptr(endpoint))); + int ret = fi_enable(endpoint); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_enable"); } - // optimized out when debug logging is false - if constexpr (NS_DEBUG::cnb_deb.is_enabled()) + // -------------------------------------------------------------------- + locality get_endpoint_address(struct fid* id) { - std::stringstream temp1; - for (std::size_t i = 0; i < locality_defs::array_length; ++i) + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + + locality::locality_data local_addr; + std::size_t addrlen = locality_defs::array_size; + int ret = fi_getname(id, local_addr.data(), &addrlen); + if (ret || (addrlen > locality_defs::array_size)) { - temp1 << 
debug::ipaddr(&local_addr[i]) << " - "; + std::string err = + std::to_string(addrlen) + "=" + std::to_string(locality_defs::array_size); + NS_LIBFABRIC::fabric_error(ret, "fi_getname - size error or other problem " + err); } - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("raw address data"), "size", - debug::dec<>(addrlen), " : ", temp1.str().c_str())); - std::stringstream temp2; - for (std::size_t i = 0; i < locality_defs::array_length; ++i) + // optimized out when debug logging is false + if constexpr (NS_DEBUG::cnb_deb.is_enabled()) { - temp2 << debug::hex<8>(local_addr[i]) << " - "; + std::stringstream temp1; + for (std::size_t i = 0; i < locality_defs::array_length; ++i) + { + temp1 << debug::ipaddr(&local_addr[i]) << " - "; + } + + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("raw address data"), "size", debug::dec<>(addrlen), " : ", + temp1.str().c_str())); + std::stringstream temp2; + for (std::size_t i = 0; i < locality_defs::array_length; ++i) + { + temp2 << debug::hex<8>(local_addr[i]) << " - "; + } + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("raw address data"), temp2.str().c_str())); } - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("raw address data"), temp2.str().c_str())); + return locality(local_addr); } - return locality(local_addr); - } - // -------------------------------------------------------------------- - fid_pep* create_passive_endpoint(struct fid_fabric* fabric, struct fi_info* info) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + // -------------------------------------------------------------------- + fid_pep* create_passive_endpoint(struct fid_fabric* fabric, struct fi_info* info) + { + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - struct fid_pep* ep; - int ret = fi_passive_ep(fabric, info, &ep, nullptr); - if (ret) { throw NS_LIBFABRIC::fabric_error(ret, "Failed to create fi_passive_ep"); } - return ep; - } + struct fid_pep* ep; + int ret = fi_passive_ep(fabric, info, &ep, nullptr); + if (ret) { throw NS_LIBFABRIC::fabric_error(ret, "Failed to create fi_passive_ep"); } + return ep; + } - // -------------------------------------------------------------------- - inline const locality& here() const { return here_; } + // -------------------------------------------------------------------- + inline locality const& here() const { return here_; } - // -------------------------------------------------------------------- - inline const fi_addr_t& fi_address() const { return here_.fi_address(); } + // -------------------------------------------------------------------- + inline fi_addr_t const& fi_address() const { return here_.fi_address(); } - // -------------------------------------------------------------------- - inline void setHere(const locality& val) { here_ = val; } + // -------------------------------------------------------------------- + inline void setHere(locality const& val) { here_ = val; } - // -------------------------------------------------------------------- - inline const locality& root() const { return root_; } + // -------------------------------------------------------------------- + inline locality const& root() const { return root_; } - // -------------------------------------------------------------------- - inline struct fid_domain* get_domain() const { return fabric_domain_; } + // -------------------------------------------------------------------- + inline struct fid_domain* get_domain() const { return fabric_domain_; } - // 
-------------------------------------------------------------------- - inline std::size_t get_rma_protocol_size() { return 65536; } + // -------------------------------------------------------------------- + inline std::size_t get_rma_protocol_size() { return 65536; } #ifdef DISABLE_FI_INJECT - // -------------------------------------------------------------------- - inline std::size_t get_tx_inject_size() { return 0; } + // -------------------------------------------------------------------- + inline std::size_t get_tx_inject_size() { return 0; } #else - // -------------------------------------------------------------------- - inline std::size_t get_tx_inject_size() { return tx_inject_size_; } + // -------------------------------------------------------------------- + inline std::size_t get_tx_inject_size() { return tx_inject_size_; } #endif - // -------------------------------------------------------------------- - inline std::size_t get_tx_size() { return tx_attr_size_; } + // -------------------------------------------------------------------- + inline std::size_t get_tx_size() { return tx_attr_size_; } - // -------------------------------------------------------------------- - inline std::size_t get_rx_size() { return rx_attr_size_; } + // -------------------------------------------------------------------- + inline std::size_t get_rx_size() { return rx_attr_size_; } - // -------------------------------------------------------------------- - // returns true when all connections have been disconnected and none are active - inline bool isTerminated() - { - return false; - //return (qp_endpoint_map_.size() == 0); - } + // -------------------------------------------------------------------- + // returns true when all connections have been disconnected and none are active + inline bool isTerminated() + { + return false; + //return (qp_endpoint_map_.size() == 0); + } - // -------------------------------------------------------------------- - void debug_print_av_vector(std::size_t N) - { - locality addr; - std::size_t addrlen = locality_defs::array_size; - for (std::size_t i = 0; i < N; ++i) + // -------------------------------------------------------------------- + void debug_print_av_vector(std::size_t N) { - int ret = fi_av_lookup(av_, fi_addr_t(i), addr.fabric_data_writable(), &addrlen); - addr.set_fi_address(fi_addr_t(i)); - if ((ret == 0) && (addrlen == locality_defs::array_size)) - { - LF_DEB(NS_DEBUG::cnb_deb, - debug(debug::str<>("address vector"), debug::dec<3>(i), iplocality(addr))); - } - else + locality addr; + std::size_t addrlen = locality_defs::array_size; + for (std::size_t i = 0; i < N; ++i) { - LF_DEB(NS_DEBUG::cnb_err, - error(debug::str<>("address length"), debug::dec<3>(addrlen), - debug::dec<3>(locality_defs::array_size))); - throw std::runtime_error("debug_print_av_vector : address vector " - "traversal failure"); + int ret = fi_av_lookup(av_, fi_addr_t(i), addr.fabric_data_writable(), &addrlen); + addr.set_fi_address(fi_addr_t(i)); + if ((ret == 0) && (addrlen == locality_defs::array_size)) + { + LF_DEB(NS_DEBUG::cnb_deb, + debug(debug::str<>("address vector"), debug::dec<3>(i), iplocality(addr))); + } + else + { + LF_DEB(NS_DEBUG::cnb_err, + error(debug::str<>("address length"), debug::dec<3>(addrlen), + debug::dec<3>(locality_defs::array_size))); + throw std::runtime_error("debug_print_av_vector : address vector " + "traversal failure"); + } } } - } - // -------------------------------------------------------------------- - inline constexpr bool bypass_tx_lock() - 
{ + // -------------------------------------------------------------------- + inline constexpr bool bypass_tx_lock() + { #if defined(HAVE_LIBFABRIC_GNI) - return true; + return true; #elif defined(HAVE_LIBFABRIC_CXI) - // @todo : cxi provider is not yet thread safe using scalable endpoints - return false; + // @todo : cxi provider is not yet thread safe using scalable endpoints + return false; #else - return (threadlevel_flags() == FI_THREAD_SAFE || + return (threadlevel_flags() == FI_THREAD_SAFE || endpoint_type_ == endpoint_type::threadlocalTx); #endif - } + } - // -------------------------------------------------------------------- - inline controller_base::unique_lock get_tx_lock() - { - if (bypass_tx_lock()) return unique_lock(); - return unique_lock(send_mutex_); - } + // -------------------------------------------------------------------- + inline controller_base::unique_lock get_tx_lock() + { + if (bypass_tx_lock()) return unique_lock(); + return unique_lock(send_mutex_); + } - // -------------------------------------------------------------------- - inline controller_base::unique_lock try_tx_lock() - { - if (bypass_tx_lock()) return unique_lock(); - return unique_lock(send_mutex_, std::try_to_lock_t{}); - } + // -------------------------------------------------------------------- + inline controller_base::unique_lock try_tx_lock() + { + if (bypass_tx_lock()) return unique_lock(); + return unique_lock(send_mutex_, std::try_to_lock_t{}); + } - // -------------------------------------------------------------------- - inline constexpr bool bypass_rx_lock() - { + // -------------------------------------------------------------------- + inline constexpr bool bypass_rx_lock() + { #ifdef HAVE_LIBFABRIC_GNI - return true; + return true; #else - return ( - threadlevel_flags() == FI_THREAD_SAFE || endpoint_type_ == endpoint_type::scalableTxRx); + return (threadlevel_flags() == FI_THREAD_SAFE || + endpoint_type_ == endpoint_type::scalableTxRx); #endif - } + } - // -------------------------------------------------------------------- - inline controller_base::unique_lock get_rx_lock() - { - if (bypass_rx_lock()) return unique_lock(); - return unique_lock(recv_mutex_); - } + // -------------------------------------------------------------------- + inline controller_base::unique_lock get_rx_lock() + { + if (bypass_rx_lock()) return unique_lock(); + return unique_lock(recv_mutex_); + } - // -------------------------------------------------------------------- - inline controller_base::unique_lock try_rx_lock() - { - if (bypass_rx_lock()) return unique_lock(); - return unique_lock(recv_mutex_, std::try_to_lock_t{}); - } + // -------------------------------------------------------------------- + inline controller_base::unique_lock try_rx_lock() + { + if (bypass_rx_lock()) return unique_lock(); + return unique_lock(recv_mutex_, std::try_to_lock_t{}); + } - // -------------------------------------------------------------------- - progress_status poll_for_work_completions(void* user_data) - { - progress_status p{0, 0}; - bool retry = false; - do { - // sends - uint32_t nsend = static_cast(this)->poll_send_queue( - get_tx_endpoint().get_tx_cq(), user_data); - p.m_num_sends += nsend; - retry = (nsend == max_completions_per_poll_); - // recvs - uint32_t nrecv = static_cast(this)->poll_recv_queue( - get_rx_endpoint().get_rx_cq(), user_data); - p.m_num_recvs += nrecv; - retry |= (nrecv == max_completions_per_poll_); - } while (retry); - return p; - } + // 
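The lock helpers here let callers take the send/receive mutex only when the provider is not natively thread safe; when bypass_tx_lock() holds, the returned unique_lock is empty and costs nothing. A sketch of how a send path might use it, assuming a caller-supplied poster that wraps the actual fi_tsend call and reports -FI_EAGAIN when the queue is busy.

#include <rdma/fi_errno.h>

// 'Controller' stands in for a controller_base specialization; 'post' is any
// callable wrapping the real fi_tsend(...) invocation
template <typename Controller, typename Post>
int guarded_tagged_send(Controller& c, Post&& post)
{
    auto lock = c.get_tx_lock();     // empty lock when the provider is FI_THREAD_SAFE
    int ret = 0;
    do {
        ret = post();                // e.g. fi_tsend(ep, buf, len, desc, dest, tag, ctx)
    } while (ret == -FI_EAGAIN);     // retry while the tx queue is full
    return ret;
}
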
-------------------------------------------------------------------- + progress_status poll_for_work_completions(void* user_data) + { + progress_status p{0, 0}; + bool retry = false; + do { + // sends + uint32_t nsend = static_cast(this)->poll_send_queue( + get_tx_endpoint().get_tx_cq(), user_data); + p.m_num_sends += nsend; + retry = (nsend == max_completions_per_poll_); + // recvs + uint32_t nrecv = static_cast(this)->poll_recv_queue( + get_rx_endpoint().get_rx_cq(), user_data); + p.m_num_recvs += nrecv; + retry |= (nrecv == max_completions_per_poll_); + } while (retry); + return p; + } - // -------------------------------------------------------------------- - inline int poll_send_queue(fid_cq* tx_cq, void* user_data) - { - return static_cast(this)->poll_send_queue(tx_cq, user_data); - } + // -------------------------------------------------------------------- + inline int poll_send_queue(fid_cq* tx_cq, void* user_data) + { + return static_cast(this)->poll_send_queue(tx_cq, user_data); + } - // -------------------------------------------------------------------- - inline int poll_recv_queue(fid_cq* rx_cq, void* user_data) - { - return static_cast(this)->poll_recv_queue(rx_cq, user_data); - } + // -------------------------------------------------------------------- + inline int poll_recv_queue(fid_cq* rx_cq, void* user_data) + { + return static_cast(this)->poll_recv_queue(rx_cq, user_data); + } - // -------------------------------------------------------------------- - struct fid_cq* create_completion_queue(struct fid_domain* domain, size_t size, const char* type) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, type); - - struct fid_cq* cq; - fi_cq_attr cq_attr = {}; - cq_attr.format = FI_CQ_FORMAT_MSG; - cq_attr.wait_obj = FI_WAIT_NONE; - cq_attr.wait_cond = FI_CQ_COND_NONE; - cq_attr.size = size; - cq_attr.flags = 0 /*FI_COMPLETION*/; - LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>("CQ size"), debug::dec<4>(size))); - // open completion queue on fabric domain and set context to null - int ret = fi_cq_open(domain, &cq_attr, &cq, nullptr); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_cq_open"); - return cq; - } + // -------------------------------------------------------------------- + struct fid_cq* create_completion_queue( + struct fid_domain* domain, size_t size, char const* type) + { + [[maybe_unused]] auto scp = + NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__, type); + + struct fid_cq* cq; + fi_cq_attr cq_attr = {}; + cq_attr.format = FI_CQ_FORMAT_MSG; + cq_attr.wait_obj = FI_WAIT_NONE; + cq_attr.wait_cond = FI_CQ_COND_NONE; + cq_attr.size = size; + cq_attr.flags = 0 /*FI_COMPLETION*/; + LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>("CQ size"), debug::dec<4>(size))); + // open completion queue on fabric domain and set context to null + int ret = fi_cq_open(domain, &cq_attr, &cq, nullptr); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_cq_open"); + return cq; + } - // -------------------------------------------------------------------- - fid_av* create_address_vector(struct fi_info* info, int N, int num_rx_contexts) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + // -------------------------------------------------------------------- + fid_av* create_address_vector(struct fi_info* info, int N, int num_rx_contexts) + { + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); - fid_av* av; - fi_av_attr av_attr = {fi_av_type(0), 0, 0, 0, nullptr, nullptr, 
0}; + fid_av* av; + fi_av_attr av_attr = {fi_av_type(0), 0, 0, 0, nullptr, nullptr, 0}; - // number of addresses expected - av_attr.count = N; + // number of addresses expected + av_attr.count = N; - // number of receive contexts used - int rx_ctx_bits = 0; + // number of receive contexts used + int rx_ctx_bits = 0; #ifdef RX_CONTEXTS_SUPPORT - while (num_rx_contexts >> ++rx_ctx_bits) - ; - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("rx_ctx_bits"), rx_ctx_bits)); + while (num_rx_contexts >> ++rx_ctx_bits); + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("rx_ctx_bits"), rx_ctx_bits)); #endif - av_attr.rx_ctx_bits = rx_ctx_bits; - // if contexts is nonzero, then we are using a single scalable endpoint - av_attr.ep_per_node = (num_rx_contexts > 0) ? 2 : 0; - - if (info->domain_attr->av_type != FI_AV_UNSPEC) - { - av_attr.type = info->domain_attr->av_type; - } - else - { - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("map FI_AV_TABLE"))); - av_attr.type = FI_AV_TABLE; - } + av_attr.rx_ctx_bits = rx_ctx_bits; + // if contexts is nonzero, then we are using a single scalable endpoint + av_attr.ep_per_node = (num_rx_contexts > 0) ? 2 : 0; - LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Creating AV"))); - int ret = fi_av_open(fabric_domain_, &av_attr, &av, nullptr); - if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_av_open"); - return av; - } + if (info->domain_attr->av_type != FI_AV_UNSPEC) + { + av_attr.type = info->domain_attr->av_type; + } + else + { + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("map FI_AV_TABLE"))); + av_attr.type = FI_AV_TABLE; + } - // -------------------------------------------------------------------- - locality insert_address(const locality& address) { return insert_address(av_, address); } + LF_DEB(NS_DEBUG::cnb_deb, debug(debug::str<>("Creating AV"))); + int ret = fi_av_open(fabric_domain_, &av_attr, &av, nullptr); + if (ret) throw NS_LIBFABRIC::fabric_error(ret, "fi_av_open"); + return av; + } - // -------------------------------------------------------------------- - locality insert_address(fid_av* av, const locality& address) - { - [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + // -------------------------------------------------------------------- + locality insert_address(locality const& address) { return insert_address(av_, address); } - LF_DEB(NS_DEBUG::cnb_deb, - trace(debug::str<>("inserting AV"), iplocality(address), NS_DEBUG::ptr(av))); - fi_addr_t fi_addr = 0xffffffff; - int ret = fi_av_insert(av, address.fabric_data(), 1, &fi_addr, 0, nullptr); - if (ret < 0) { throw NS_LIBFABRIC::fabric_error(ret, "fi_av_insert"); } - else if (ret == 0) + // -------------------------------------------------------------------- + locality insert_address(fid_av* av, locality const& address) { - NS_DEBUG::cnb_deb.error("fi_av_insert called with existing address"); - NS_LIBFABRIC::fabric_error(ret, "fi_av_insert did not return 1"); + [[maybe_unused]] auto scp = NS_DEBUG::cnb_deb.scope(NS_DEBUG::ptr(this), __func__); + + LF_DEB(NS_DEBUG::cnb_deb, + trace(debug::str<>("inserting AV"), iplocality(address), NS_DEBUG::ptr(av))); + fi_addr_t fi_addr = 0xffff'ffff; + int ret = fi_av_insert(av, address.fabric_data(), 1, &fi_addr, 0, nullptr); + if (ret < 0) { throw NS_LIBFABRIC::fabric_error(ret, "fi_av_insert"); } + else if (ret == 0) + { + NS_DEBUG::cnb_deb.error("fi_av_insert called with existing address"); + NS_LIBFABRIC::fabric_error(ret, "fi_av_insert did not return 1"); + } + // address was generated correctly, now update the locality 
with the fi_addr + locality new_locality(address, fi_addr); + LF_DEB(NS_DEBUG::cnb_deb, + trace(debug::str<>("AV add"), "rank", debug::dec<>(fi_addr), + iplocality(new_locality), "fi_addr", debug::hex<4>(fi_addr))); + return new_locality; } - // address was generated correctly, now update the locality with the fi_addr - locality new_locality(address, fi_addr); - LF_DEB(NS_DEBUG::cnb_deb, trace(debug::str<>("AV add"), "rank", debug::dec<>(fi_addr), - iplocality(new_locality), "fi_addr", debug::hex<4>(fi_addr))); - return new_locality; - } -}; + }; -} // namespace NS_LIBFABRIC +} // namespace NS_LIBFABRIC diff --git a/src/libfabric/fabric_error.hpp b/src/libfabric/fabric_error.hpp index 0f2db4c1..325975a7 100644 --- a/src/libfabric/fabric_error.hpp +++ b/src/libfabric/fabric_error.hpp @@ -10,43 +10,41 @@ #pragma once #include -#include #include +#include // #include // #include "oomph_libfabric_defines.hpp" -namespace NS_DEBUG -{ -// cppcheck-suppress ConfigurationNotChecked -static NS_DEBUG::enable_print err_deb("ERROR__"); -} // namespace NS_DEBUG +namespace NS_DEBUG { + // cppcheck-suppress ConfigurationNotChecked + static NS_DEBUG::enable_print err_deb("ERROR__"); +} // namespace NS_DEBUG -namespace NS_LIBFABRIC -{ +namespace NS_LIBFABRIC { -class fabric_error : public std::runtime_error -{ - public: - // -------------------------------------------------------------------- - fabric_error(int err, const std::string& msg) - : std::runtime_error(std::string(fi_strerror(-err)) + msg) - , error_(err) + class fabric_error : public std::runtime_error { - NS_DEBUG::err_deb.error(msg, ":", fi_strerror(-err)); - std::terminate(); - } + public: + // -------------------------------------------------------------------- + fabric_error(int err, std::string const& msg) + : std::runtime_error(std::string(fi_strerror(-err)) + msg) + , error_(err) + { + NS_DEBUG::err_deb.error(msg, ":", fi_strerror(-err)); + std::terminate(); + } - fabric_error(int err) - : std::runtime_error(fi_strerror(-err)) - , error_(-err) - { - NS_DEBUG::err_deb.error(what()); - std::terminate(); - } + fabric_error(int err) + : std::runtime_error(fi_strerror(-err)) + , error_(-err) + { + NS_DEBUG::err_deb.error(what()); + std::terminate(); + } - int error_; -}; + int error_; + }; -} // namespace NS_LIBFABRIC +} // namespace NS_LIBFABRIC diff --git a/src/libfabric/libfabric_defines_template.hpp b/src/libfabric/libfabric_defines_template.hpp index 64c04944..efd2bb67 100644 --- a/src/libfabric/libfabric_defines_template.hpp +++ b/src/libfabric/libfabric_defines_template.hpp @@ -14,26 +14,26 @@ // some namespaces for the lib and for debugging are setup correctly #define NS_LIBFABRIC oomph::libfabric -#define NS_MEMORY oomph::libfabric -#define NS_DEBUG oomph::debug +#define NS_MEMORY oomph::libfabric +#define NS_DEBUG oomph::debug #ifndef LF_DEB -#define LF_DEB(printer, Expr) \ - if constexpr (printer.is_enabled()) { printer.Expr; }; +# define LF_DEB(printer, Expr) \ + if constexpr (printer.is_enabled()) { printer.Expr; }; #endif #define LFSOURCE_DIR "@OOMPH_SRC_LIBFABRIC_DIR@" -#define LFPRINT_HPP "@OOMPH_SRC_LIBFABRIC_DIR@/print.hpp" -#define LFCOUNT_HPP "@OOMPH_SRC_LIBFABRIC_DIR@/simple_counter.hpp" +#define LFPRINT_HPP "@OOMPH_SRC_LIBFABRIC_DIR@/print.hpp" +#define LFCOUNT_HPP "@OOMPH_SRC_LIBFABRIC_DIR@/simple_counter.hpp" // oomph has a debug print helper file in the main source tree #if __has_include(LFPRINT_HPP) -#include LFPRINT_HPP -#define has_debug 1 +# include LFPRINT_HPP +# define has_debug 1 #endif #if 
__has_include(LFCOUNT_HPP) -#include LFCOUNT_HPP +# include LFCOUNT_HPP #endif #endif diff --git a/src/libfabric/locality.cpp b/src/libfabric/locality.cpp index 487912f5..ff23eeb5 100644 --- a/src/libfabric/locality.cpp +++ b/src/libfabric/locality.cpp @@ -10,27 +10,22 @@ #include -namespace oomph -{ -namespace libfabric -{ +namespace oomph { namespace libfabric { -// ------------------------------------------------------------------ -// format as ip address, port, libfabric address -// ------------------------------------------------------------------ -iplocality::iplocality(const locality& l) -: data(l) -{ -} + // ------------------------------------------------------------------ + // format as ip address, port, libfabric address + // ------------------------------------------------------------------ + iplocality::iplocality(locality const& l) + : data(l) + { + } -std::ostream& -operator<<(std::ostream& os, const iplocality& p) -{ - os << std::dec << NS_DEBUG::ipaddr(p.data.fabric_data()) << " - " - << NS_DEBUG::ipaddr(p.data.ip_address()) << ":" << NS_DEBUG::dec<>(p.data.port()) << " (" - << NS_DEBUG::dec<>(p.data.fi_address()) << ") "; - return os; -} + std::ostream& operator<<(std::ostream& os, iplocality const& p) + { + os << std::dec << NS_DEBUG::ipaddr(p.data.fabric_data()) << " - " + << NS_DEBUG::ipaddr(p.data.ip_address()) << ":" << NS_DEBUG::dec<>(p.data.port()) << " (" + << NS_DEBUG::dec<>(p.data.fi_address()) << ") "; + return os; + } -} // namespace libfabric -} // namespace oomph +}} // namespace oomph::libfabric diff --git a/src/libfabric/locality.hpp b/src/libfabric/locality.hpp index 74f6b290..84f5ddc2 100644 --- a/src/libfabric/locality.hpp +++ b/src/libfabric/locality.hpp @@ -15,243 +15,238 @@ #include #include // -#include -#include #include +#include +#include // #include "oomph_libfabric_defines.hpp" // Different providers use different address formats that we must accommodate // in our locality object. 
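For the sockaddr-based providers (verbs, tcp, sockets, psm2) the locality defined below simply aliases a 16-byte sockaddr_in with an array of uint32_t, so the IPv4 address and port can be read straight back out of the raw bytes. A standalone sketch of that packing, using only standard POSIX headers (sizes assume a typical Linux platform):

#include <arpa/inet.h>
#include <netinet/in.h>
#include <array>
#include <cstdint>
#include <cstring>

// pack an ip/port pair the same way the sockets-mode locality constructor does
std::array<std::uint32_t, 4> pack_address(char const* ip, std::uint16_t port)
{
    sockaddr_in sa;
    std::memset(&sa, 0, sizeof(sa));
    sa.sin_family = AF_INET;
    sa.sin_port = htons(port);
    inet_pton(AF_INET, ip, &sa.sin_addr);

    std::array<std::uint32_t, 4> data{};          // 16 bytes, matching the locality size above
    std::memcpy(data.data(), &sa, sizeof(sa));    // sizeof(sockaddr_in) is 16 on Linux
    return data;
}

// recover the fields the way ip_address()/port() do
std::uint32_t ip_of(std::array<std::uint32_t, 4> const& data)
{
    sockaddr_in sa;
    std::memcpy(&sa, data.data(), sizeof(sa));
    return sa.sin_addr.s_addr;                    // stored in network byte order
}
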
#ifdef HAVE_LIBFABRIC_GNI -#define HAVE_LIBFABRIC_LOCALITY_SIZE 48 +# define HAVE_LIBFABRIC_LOCALITY_SIZE 48 #endif #ifdef HAVE_LIBFABRIC_CXI -#ifdef HAVE_LIBFABRIC_CXI_1_15 -#define HAVE_LIBFABRIC_LOCALITY_SIZE sizeof(int) -#else -#define HAVE_LIBFABRIC_LOCALITY_SIZE sizeof(long int) -#endif +# ifdef HAVE_LIBFABRIC_CXI_1_15 +# define HAVE_LIBFABRIC_LOCALITY_SIZE sizeof(int) +# else +# define HAVE_LIBFABRIC_LOCALITY_SIZE sizeof(long int) +# endif #endif #ifdef HAVE_LIBFABRIC_EFA -#define HAVE_LIBFABRIC_LOCALITY_SIZE 32 +# define HAVE_LIBFABRIC_LOCALITY_SIZE 32 #endif #if defined(HAVE_LIBFABRIC_VERBS) || defined(HAVE_LIBFABRIC_TCP) || \ defined(HAVE_LIBFABRIC_SOCKETS) || defined(HAVE_LIBFABRIC_PSM2) -#define HAVE_LIBFABRIC_LOCALITY_SIZE 16 -#define HAVE_LIBFABRIC_LOCALITY_SOCKADDR +# define HAVE_LIBFABRIC_LOCALITY_SIZE 16 +# define HAVE_LIBFABRIC_LOCALITY_SOCKADDR #endif -namespace oomph -{ -// cppcheck-suppress ConfigurationNotChecked -static NS_DEBUG::enable_print loc_deb("LOCALTY"); -} // namespace oomph - -namespace oomph -{ -namespace libfabric -{ - -struct locality; - -// ------------------------------------------------------------------ -// format as ip address, port, libfabric address -// ------------------------------------------------------------------ -struct iplocality -{ - const locality& data; - iplocality(const locality& a); - friend std::ostream& operator<<(std::ostream& os, const iplocality& p); -}; - -// -------------------------------------------------------------------- -// Locality, in this structure we store the information required by -// libfabric to make a connection to another node. -// With libfabric 1.4.x the array contains the fabric ip address stored -// as the second uint32_t in the array. For this reason we use an -// array of uint32_t rather than uint8_t/char so we can easily access -// the ip for debug/validation purposes -// -------------------------------------------------------------------- -namespace locality_defs -{ -// the number of 32bit ints stored in our array -const uint32_t array_size = HAVE_LIBFABRIC_LOCALITY_SIZE; -const uint32_t array_length = HAVE_LIBFABRIC_LOCALITY_SIZE / 4; -} // namespace locality_defs - -struct locality -{ - // array type of our locality data - typedef std::array locality_data; - - static const char* type() { return "libfabric"; } - - explicit locality(const locality_data& in_data) - { - std::memcpy(&data_[0], &in_data[0], locality_defs::array_size); - fi_address_ = 0; - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("expl constructing"), iplocality((*this)))); - } - - locality() - { - std::memset(&data_[0], 0x00, locality_defs::array_size); - fi_address_ = 0; - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("default construct"), iplocality((*this)))); - } - - locality(const locality& other) - : data_(other.data_) - , fi_address_(other.fi_address_) - { - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("copy construct"), iplocality((*this)))); - } - - locality(const locality& other, fi_addr_t addr) - : data_(other.data_) - , fi_address_(addr) - { - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("copy fi construct"), iplocality((*this)))); - } - - locality(locality&& other) - : data_(std::move(other.data_)) - , fi_address_(other.fi_address_) - { - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("move construct"), iplocality((*this)))); - } - - // provided to support sockets mode bootstrap - explicit locality(const std::string& address, const std::string& portnum) - { - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("explicit construct"), address, ":", portnum)); - // - struct 
sockaddr_in socket_data; - memset(&socket_data, 0, sizeof(socket_data)); - socket_data.sin_family = AF_INET; - socket_data.sin_port = htons(std::stol(portnum)); - inet_pton(AF_INET, address.c_str(), &(socket_data.sin_addr)); - // - std::memcpy(&data_[0], &socket_data, locality_defs::array_size); - fi_address_ = 0; - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("string constructing"), iplocality((*this)))); - } - - // some condition marking this locality as valid - explicit inline operator bool() const - { - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("bool operator"), iplocality((*this)))); - return (ip_address() != 0); - } - - inline bool valid() const - { - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("valid operator"), iplocality((*this)))); - return (ip_address() != 0); - } - - locality& operator=(const locality& other) - { - data_ = other.data_; - fi_address_ = other.fi_address_; - LF_DEB(loc_deb, - trace(NS_DEBUG::str<>("copy operator"), iplocality(*this), iplocality(other))); - return *this; - } - - bool operator==(const locality& other) - { - LF_DEB(loc_deb, - trace(NS_DEBUG::str<>("equality operator"), iplocality(*this), iplocality(other))); - return std::memcmp(&data_, &other.data_, locality_defs::array_size) == 0; - } - - bool less_than(const locality& other) - { - LF_DEB(loc_deb, - trace(NS_DEBUG::str<>("less operator"), iplocality(*this), iplocality(other))); - if (ip_address() < other.ip_address()) return true; - if (ip_address() == other.ip_address()) return port() < other.port(); - return false; - } - - const uint32_t& ip_address() const - { +namespace oomph { + // cppcheck-suppress ConfigurationNotChecked + static NS_DEBUG::enable_print loc_deb("LOCALTY"); +} // namespace oomph + +namespace oomph { namespace libfabric { + + struct locality; + + // ------------------------------------------------------------------ + // format as ip address, port, libfabric address + // ------------------------------------------------------------------ + struct iplocality + { + locality const& data; + iplocality(locality const& a); + friend std::ostream& operator<<(std::ostream& os, iplocality const& p); + }; + + // -------------------------------------------------------------------- + // Locality, in this structure we store the information required by + // libfabric to make a connection to another node. + // With libfabric 1.4.x the array contains the fabric ip address stored + // as the second uint32_t in the array. 
For this reason we use an + // array of uint32_t rather than uint8_t/char so we can easily access + // the ip for debug/validation purposes + // -------------------------------------------------------------------- + namespace locality_defs { + // the number of 32bit ints stored in our array + uint32_t const array_size = HAVE_LIBFABRIC_LOCALITY_SIZE; + uint32_t const array_length = HAVE_LIBFABRIC_LOCALITY_SIZE / 4; + } // namespace locality_defs + + struct locality + { + // array type of our locality data + typedef std::array locality_data; + + static char const* type() { return "libfabric"; } + + explicit locality(locality_data const& in_data) + { + std::memcpy(&data_[0], &in_data[0], locality_defs::array_size); + fi_address_ = 0; + LF_DEB(loc_deb, trace(NS_DEBUG::str<>("expl constructing"), iplocality((*this)))); + } + + locality() + { + std::memset(&data_[0], 0x00, locality_defs::array_size); + fi_address_ = 0; + LF_DEB(loc_deb, trace(NS_DEBUG::str<>("default construct"), iplocality((*this)))); + } + + locality(locality const& other) + : data_(other.data_) + , fi_address_(other.fi_address_) + { + LF_DEB(loc_deb, trace(NS_DEBUG::str<>("copy construct"), iplocality((*this)))); + } + + locality(locality const& other, fi_addr_t addr) + : data_(other.data_) + , fi_address_(addr) + { + LF_DEB(loc_deb, trace(NS_DEBUG::str<>("copy fi construct"), iplocality((*this)))); + } + + locality(locality&& other) + : data_(std::move(other.data_)) + , fi_address_(other.fi_address_) + { + LF_DEB(loc_deb, trace(NS_DEBUG::str<>("move construct"), iplocality((*this)))); + } + + // provided to support sockets mode bootstrap + explicit locality(std::string const& address, std::string const& portnum) + { + LF_DEB(loc_deb, trace(NS_DEBUG::str<>("explicit construct"), address, ":", portnum)); + // + struct sockaddr_in socket_data; + memset(&socket_data, 0, sizeof(socket_data)); + socket_data.sin_family = AF_INET; + socket_data.sin_port = htons(std::stol(portnum)); + inet_pton(AF_INET, address.c_str(), &(socket_data.sin_addr)); + // + std::memcpy(&data_[0], &socket_data, locality_defs::array_size); + fi_address_ = 0; + LF_DEB(loc_deb, trace(NS_DEBUG::str<>("string constructing"), iplocality((*this)))); + } + + // some condition marking this locality as valid + explicit inline operator bool() const + { + LF_DEB(loc_deb, trace(NS_DEBUG::str<>("bool operator"), iplocality((*this)))); + return (ip_address() != 0); + } + + inline bool valid() const + { + LF_DEB(loc_deb, trace(NS_DEBUG::str<>("valid operator"), iplocality((*this)))); + return (ip_address() != 0); + } + + locality& operator=(locality const& other) + { + data_ = other.data_; + fi_address_ = other.fi_address_; + LF_DEB(loc_deb, + trace(NS_DEBUG::str<>("copy operator"), iplocality(*this), iplocality(other))); + return *this; + } + + bool operator==(locality const& other) + { + LF_DEB(loc_deb, + trace(NS_DEBUG::str<>("equality operator"), iplocality(*this), iplocality(other))); + return std::memcmp(&data_, &other.data_, locality_defs::array_size) == 0; + } + + bool less_than(locality const& other) + { + LF_DEB(loc_deb, + trace(NS_DEBUG::str<>("less operator"), iplocality(*this), iplocality(other))); + if (ip_address() < other.ip_address()) return true; + if (ip_address() == other.ip_address()) return port() < other.port(); + return false; + } + + uint32_t const& ip_address() const + { #if defined(HAVE_LIBFABRIC_LOCALITY_SOCKADDR) - return reinterpret_cast(data_.data())->sin_addr.s_addr; + return reinterpret_cast(data_.data())->sin_addr.s_addr; #elif 
defined(HAVE_LIBFABRIC_GNI) - return data_[0]; + return data_[0]; #elif defined(HAVE_LIBFABRIC_CXI) - return data_[0]; + return data_[0]; #elif defined(HAVE_LIBFABRIC_EFA) - return data_[0]; + return data_[0]; #else - throw fabric_error(0, "unsupported fabric provider, please fix ASAP"); + throw fabric_error(0, "unsupported fabric provider, please fix ASAP"); #endif - } + } - static const uint32_t& ip_address(const locality_data& data) - { + static uint32_t const& ip_address(locality_data const& data) + { #if defined(HAVE_LIBFABRIC_LOCALITY_SOCKADDR) - return reinterpret_cast(&data)->sin_addr.s_addr; + return reinterpret_cast(&data)->sin_addr.s_addr; #elif defined(HAVE_LIBFABRIC_GNI) - return data[0]; + return data[0]; #elif defined(HAVE_LIBFABRIC_CXI) - return data[0]; + return data[0]; #elif defined(HAVE_LIBFABRIC_EFA) - return data[0]; + return data[0]; #else - throw fabric_error(0, "unsupported fabric provider, please fix ASAP"); + throw fabric_error(0, "unsupported fabric provider, please fix ASAP"); #endif - } - - inline const fi_addr_t& fi_address() const { return fi_address_; } - - inline void set_fi_address(fi_addr_t fi_addr) { fi_address_ = fi_addr; } - - inline uint16_t port() const - { - uint16_t port = 256 * reinterpret_cast(data_.data())[2] + - reinterpret_cast(data_.data())[3]; - return port; - } - - inline const void* fabric_data() const { return data_.data(); } - - inline char* fabric_data_writable() { return reinterpret_cast(data_.data()); } - - private: - friend bool operator==(locality const& lhs, locality const& rhs) - { - LF_DEB(loc_deb, - trace(NS_DEBUG::str<>("equality friend"), iplocality(lhs), iplocality(rhs))); - return ((lhs.data_ == rhs.data_) && (lhs.fi_address_ == rhs.fi_address_)); - } - - friend bool operator<(locality const& lhs, locality const& rhs) - { - const uint32_t& a1 = lhs.ip_address(); - const uint32_t& a2 = rhs.ip_address(); - const fi_addr_t& f1 = lhs.fi_address(); - const fi_addr_t& f2 = rhs.fi_address(); - LF_DEB(loc_deb, trace(NS_DEBUG::str<>("less friend"), iplocality(lhs), iplocality(rhs))); - return (a1 < a2) || (a1 == a2 && f1 < f2); - } - - friend std::ostream& operator<<(std::ostream& os, locality const& loc) - { - for (uint32_t i = 0; i < locality_defs::array_length; ++i) { os << loc.data_[i]; } - return os; - } - - private: - locality_data data_; - fi_addr_t fi_address_; -}; - -} // namespace libfabric -} // namespace oomph + } + + inline fi_addr_t const& fi_address() const { return fi_address_; } + + inline void set_fi_address(fi_addr_t fi_addr) { fi_address_ = fi_addr; } + + inline uint16_t port() const + { + uint16_t port = 256 * reinterpret_cast(data_.data())[2] + + reinterpret_cast(data_.data())[3]; + return port; + } + + inline void const* fabric_data() const { return data_.data(); } + + inline char* fabric_data_writable() { return reinterpret_cast(data_.data()); } + + private: + friend bool operator==(locality const& lhs, locality const& rhs) + { + LF_DEB(loc_deb, + trace(NS_DEBUG::str<>("equality friend"), iplocality(lhs), iplocality(rhs))); + return ((lhs.data_ == rhs.data_) && (lhs.fi_address_ == rhs.fi_address_)); + } + + friend bool operator<(locality const& lhs, locality const& rhs) + { + uint32_t const& a1 = lhs.ip_address(); + uint32_t const& a2 = rhs.ip_address(); + fi_addr_t const& f1 = lhs.fi_address(); + fi_addr_t const& f2 = rhs.fi_address(); + LF_DEB( + loc_deb, trace(NS_DEBUG::str<>("less friend"), iplocality(lhs), iplocality(rhs))); + return (a1 < a2) || (a1 == a2 && f1 < f2); + } + + friend std::ostream& 
operator<<(std::ostream& os, locality const& loc) + { + for (uint32_t i = 0; i < locality_defs::array_length; ++i) { os << loc.data_[i]; } + return os; + } + + private: + locality_data data_; + fi_addr_t fi_address_; + }; + +}} // namespace oomph::libfabric diff --git a/src/libfabric/memory_region.hpp b/src/libfabric/memory_region.hpp index 0cd5c4a7..f1eb5326 100644 --- a/src/libfabric/memory_region.hpp +++ b/src/libfabric/memory_region.hpp @@ -18,20 +18,19 @@ #include #include -#include "oomph_libfabric_defines.hpp" #include "fabric_error.hpp" +#include "oomph_libfabric_defines.hpp" #ifdef OOMPH_ENABLE_DEVICE -#include +# include #endif // ------------------------------------------------------------------ -namespace NS_MEMORY -{ +namespace NS_MEMORY { -static NS_DEBUG::enable_print mrn_deb("REGION_"); + static NS_DEBUG::enable_print mrn_deb("REGION_"); -/* + /* struct fi_mr_attr { union { const struct iovec *mr_iov; @@ -60,342 +59,356 @@ struct fi_mr_attr { */ -// This is the only part of the code that actually -// calls libfabric functions -struct region_provider -{ - // The internal memory region handle - using provider_region = struct fid_mr; - using provider_domain = struct fid_domain; - - // register region - static inline int fi_register_memory(provider_domain* pd, int device_id, const void* buf, - size_t len, uint64_t access_flags, uint64_t offset, uint64_t request_key, - struct fid_mr** mr) + // This is the only part of the code that actually + // calls libfabric functions + struct region_provider { - [[maybe_unused]] auto scp = - NS_MEMORY::mrn_deb.scope(__func__, NS_DEBUG::ptr(buf), NS_DEBUG::dec<>(len), device_id); - // - struct iovec addresses = {/*.iov_base = */ const_cast(buf), /*.iov_len = */ len}; - fi_mr_attr attr = { - /*.mr_iov = */ &addresses, - /*.iov_count = */ 1, - /*.access = */ access_flags, - /*.offset = */ offset, - /*.requested_key = */ request_key, - /*.context = */ nullptr, - /*.auth_key_size = */ 0, - /*.auth_key = */ nullptr, - /*.iface = */ FI_HMEM_SYSTEM, - /*.device = */ {0}, + // The internal memory region handle + using provider_region = struct fid_mr; + using provider_domain = struct fid_domain; + + // register region + static inline int fi_register_memory(provider_domain* pd, int device_id, void const* buf, + size_t len, uint64_t access_flags, uint64_t offset, uint64_t request_key, + struct fid_mr** mr) + { + [[maybe_unused]] auto scp = NS_MEMORY::mrn_deb.scope( + __func__, NS_DEBUG::ptr(buf), NS_DEBUG::dec<>(len), device_id); + // + struct iovec addresses = {/*.iov_base = */ const_cast(buf), /*.iov_len = */ len}; + fi_mr_attr attr = { + /*.mr_iov = */ &addresses, + /*.iov_count = */ 1, + /*.access = */ access_flags, + /*.offset = */ offset, + /*.requested_key = */ request_key, + /*.context = */ nullptr, + /*.auth_key_size = */ 0, + /*.auth_key = */ nullptr, + /*.iface = */ FI_HMEM_SYSTEM, + /*.device = */ {0}, #if (FI_MAJOR_VERSION > 1) || ((FI_MAJOR_VERSION == 1) && FI_MINOR_VERSION > 17) - /*.hmem_data = */ nullptr, + /*.hmem_data = */ nullptr, #endif #if (FI_MAJOR_VERSION >= 2) - /*page_size = */ static_cast(sysconf(_SC_PAGESIZE)), - /*base_mr = */ nullptr, - /*sub_mr_cnt = */ 0, - }; + /*page_size = */ static_cast(sysconf(_SC_PAGESIZE)), + /*base_mr = */ nullptr, + /*sub_mr_cnt = */ 0, + }; #else - }; + }; #endif - if (device_id >= 0) - { + if (device_id >= 0) + { #ifdef OOMPH_ENABLE_DEVICE - attr.device.cuda = device_id; - int handle = hwmalloc::get_device_id(); - attr.device.cuda = handle; -#if defined(OOMPH_DEVICE_CUDA) - attr.iface = 
FI_HMEM_CUDA; - LF_DEB(NS_MEMORY::mrn_deb, - trace(NS_DEBUG::str<>("CUDA"), "set device id", device_id, handle)); -#elif defined(OOMPH_DEVICE_HIP) - attr.iface = FI_HMEM_ROCR; - LF_DEB(NS_MEMORY::mrn_deb, - trace(NS_DEBUG::str<>("HIP"), "set device id", device_id, handle)); -#endif + attr.device.cuda = device_id; + int handle = hwmalloc::get_device_id(); + attr.device.cuda = handle; +# if defined(OOMPH_DEVICE_CUDA) + attr.iface = FI_HMEM_CUDA; + LF_DEB(NS_MEMORY::mrn_deb, + trace(NS_DEBUG::str<>("CUDA"), "set device id", device_id, handle)); +# elif defined(OOMPH_DEVICE_HIP) + attr.iface = FI_HMEM_ROCR; + LF_DEB(NS_MEMORY::mrn_deb, + trace(NS_DEBUG::str<>("HIP"), "set device id", device_id, handle)); +# endif #endif + } + uint64_t flags = 0; + int ret = fi_mr_regattr(pd, &attr, flags, mr); + if (ret) { throw NS_LIBFABRIC::fabric_error(int(ret), "register_memory"); } + return ret; } - uint64_t flags = 0; - int ret = fi_mr_regattr(pd, &attr, flags, mr); - if (ret) { throw NS_LIBFABRIC::fabric_error(int(ret), "register_memory"); } - return ret; - } - // unregister region - static inline int unregister_memory(provider_region* region) { return fi_close(®ion->fid); } - - // Default registration flags for this provider - static inline constexpr int access_flags() - { - return FI_READ | FI_WRITE | FI_RECV | FI_SEND /*| FI_REMOTE_READ | FI_REMOTE_WRITE*/; - } + // unregister region + static inline int unregister_memory(provider_region* region) + { + return fi_close(®ion->fid); + } - // Get the local descriptor of the memory region. - static inline void* get_local_key(provider_region* const region) { return fi_mr_desc(region); } + // Default registration flags for this provider + static inline constexpr int access_flags() + { + return FI_READ | FI_WRITE | FI_RECV | FI_SEND /*| FI_REMOTE_READ | FI_REMOTE_WRITE*/; + } - // Get the remote key of the memory region. - static inline uint64_t get_remote_key(provider_region* const region) - { - return fi_mr_key(region); - } -}; + // Get the local descriptor of the memory region. + static inline void* get_local_key(provider_region* const region) + { + return fi_mr_desc(region); + } -// -------------------------------------------------------------------- -// This is a handle to a small chunk of memory that has been registered -// as part of a much larger allocation (a memory_segment) -struct memory_handle -{ - // -------------------------------------------------------------------- - using provider_region = region_provider::provider_region; + // Get the remote key of the memory region. 
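A hypothetical caller of the region_provider wrapper above (assuming this header is included): register a plain host buffer (device_id < 0, so no FI_HMEM interface is selected), pull out the local descriptor and remote key needed for tagged sends or RMA, then close the registration. The domain pointer is assumed to come from the controller's get_domain(); the wrapper itself throws a fabric_error on failure, so no return code checking is shown.

void register_host_buffer(NS_MEMORY::region_provider::provider_domain* pd, void* buf, size_t len)
{
    fid_mr* mr = nullptr;
    NS_MEMORY::region_provider::fi_register_memory(pd, /*device_id=*/-1, buf, len,
        NS_MEMORY::region_provider::access_flags(), /*offset=*/0, /*request_key=*/0, &mr);

    void*    desc = NS_MEMORY::region_provider::get_local_key(mr);     // used as the fi_tsend desc
    uint64_t rkey = NS_MEMORY::region_provider::get_remote_key(mr);    // advertised to remote peers
    (void) desc;
    (void) rkey;

    NS_MEMORY::region_provider::unregister_memory(mr);                 // fi_close on the mr fid
}
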
+ static inline uint64_t get_remote_key(provider_region* const region) + { + return fi_mr_key(region); + } + }; // -------------------------------------------------------------------- - // Default constructor creates unusable handle(region) - memory_handle() - : address_{nullptr} - , region_{nullptr} - , size_{0} - , used_space_{0} - { - } - memory_handle(memory_handle const&) noexcept = default; - memory_handle& operator=(memory_handle const&) noexcept = default; - - memory_handle(provider_region* region, unsigned char* addr, - std::size_t size /*, uint32_t flags*/) noexcept - : address_{addr} - , region_{region} - , size_{uint32_t(size)} - , used_space_{0} + // This is a handle to a small chunk of memory that has been registered + // as part of a much larger allocation (a memory_segment) + struct memory_handle { - // LF_DEB(NS_MEMORY::mrn_deb, - // trace(NS_DEBUG::str<>("memory_handle"), *this)); - } + // -------------------------------------------------------------------- + using provider_region = region_provider::provider_region; + + // -------------------------------------------------------------------- + // Default constructor creates unusable handle(region) + memory_handle() + : address_{nullptr} + , region_{nullptr} + , size_{0} + , used_space_{0} + { + } + memory_handle(memory_handle const&) noexcept = default; + memory_handle& operator=(memory_handle const&) noexcept = default; + + memory_handle(provider_region* region, unsigned char* addr, + std::size_t size /*, uint32_t flags*/) noexcept + : address_{addr} + , region_{region} + , size_{uint32_t(size)} + , used_space_{0} + { + // LF_DEB(NS_MEMORY::mrn_deb, + // trace(NS_DEBUG::str<>("memory_handle"), *this)); + } - // -------------------------------------------------------------------- - // move constructor, clear other region so that it is not unregistered twice - memory_handle(memory_handle&& other) noexcept - : address_{other.address_} - , region_{std::exchange(other.region_, nullptr)} - , size_{other.size_} - , used_space_{other.used_space_} - { - } + // -------------------------------------------------------------------- + // move constructor, clear other region so that it is not unregistered twice + memory_handle(memory_handle&& other) noexcept + : address_{other.address_} + , region_{std::exchange(other.region_, nullptr)} + , size_{other.size_} + , used_space_{other.used_space_} + { + } - // -------------------------------------------------------------------- - // move assignment, clear other region so that it is not unregistered twice - memory_handle& operator=(memory_handle&& other) noexcept - { - address_ = other.address_; - region_ = std::exchange(other.region_, nullptr); - size_ = other.size_; - used_space_ = other.used_space_; - return *this; - } + // -------------------------------------------------------------------- + // move assignment, clear other region so that it is not unregistered twice + memory_handle& operator=(memory_handle&& other) noexcept + { + address_ = other.address_; + region_ = std::exchange(other.region_, nullptr); + size_ = other.size_; + used_space_ = other.used_space_; + return *this; + } - // -------------------------------------------------------------------- - // Return the address of this memory region block. - inline unsigned char* get_address(void) const { return address_; } + // -------------------------------------------------------------------- + // Return the address of this memory region block. 
+ inline unsigned char* get_address(void) const { return address_; } - // -------------------------------------------------------------------- - // Get the local descriptor of the memory region. - inline void* get_local_key(void) const { return region_provider::get_local_key(region_); } + // -------------------------------------------------------------------- + // Get the local descriptor of the memory region. + inline void* get_local_key(void) const { return region_provider::get_local_key(region_); } - // -------------------------------------------------------------------- - // Get the remote key of the memory region. - inline uint64_t get_remote_key(void) const { return region_provider::get_remote_key(region_); } + // -------------------------------------------------------------------- + // Get the remote key of the memory region. + inline uint64_t get_remote_key(void) const + { + return region_provider::get_remote_key(region_); + } - // -------------------------------------------------------------------- - // Get the size of the memory chunk usable by this memory region, - // this may be smaller than the value returned by get_length - // if the region is a sub region (partial region) within another block - inline uint64_t get_size(void) const { return size_; } + // -------------------------------------------------------------------- + // Get the size of the memory chunk usable by this memory region, + // this may be smaller than the value returned by get_length + // if the region is a sub region (partial region) within another block + inline uint64_t get_size(void) const { return size_; } - // -------------------------------------------------------------------- - // Get the size used by a message in the memory region. - inline uint32_t get_message_length(void) const { return used_space_; } + // -------------------------------------------------------------------- + // Get the size used by a message in the memory region. + inline uint32_t get_message_length(void) const { return used_space_; } - // -------------------------------------------------------------------- - // Set the size used by a message in the memory region. - inline void set_message_length(uint32_t length) { used_space_ = length; } + // -------------------------------------------------------------------- + // Set the size used by a message in the memory region. + inline void set_message_length(uint32_t length) { used_space_ = length; } - // -------------------------------------------------------------------- - void release_region() noexcept { region_ = nullptr; } + // -------------------------------------------------------------------- + void release_region() noexcept { region_ = nullptr; } - // -------------------------------------------------------------------- - // return the underlying libfabric region handle - inline provider_region* get_region() const { return region_; } + // -------------------------------------------------------------------- + // return the underlying libfabric region handle + inline provider_region* get_region() const { return region_; } - // -------------------------------------------------------------------- - // Deregister the memory region. - // returns 0 when successful, -1 otherwise - int deregister(void) const - { - if (region_ /*&& !get_user_region()*/) + // -------------------------------------------------------------------- + // Deregister the memory region. 
+ // returns 0 when successful, -1 otherwise + int deregister(void) const { - LF_DEB(NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("release"), region_)); - // - if (region_provider::unregister_memory(region_)) + if (region_ /*&& !get_user_region()*/) { - LF_DEB(NS_MEMORY::mrn_deb, error("fi_close mr failed")); - return -1; + LF_DEB(NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("release"), region_)); + // + if (region_provider::unregister_memory(region_)) + { + LF_DEB(NS_MEMORY::mrn_deb, error("fi_close mr failed")); + return -1; + } + else + { + LF_DEB( + NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("de-Registered region"), *this)); + } + region_ = nullptr; } - else - { - LF_DEB(NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("de-Registered region"), *this)); - } - region_ = nullptr; + return 0; } - return 0; - } - // -------------------------------------------------------------------- - friend std::ostream& operator<<(std::ostream& os, memory_handle const& region) - { - (void)region; + // -------------------------------------------------------------------- + friend std::ostream& operator<<(std::ostream& os, memory_handle const& region) + { + (void) region; #if 1 || has_debug - os << "region " - << NS_DEBUG::ptr(®ion) - //<< " fi_region " << NS_DEBUG::ptr(region.region_) - << " address " << NS_DEBUG::ptr(region.address_) << " size " - << NS_DEBUG::hex<6>(region.size_) - //<< " used_space " << NS_DEBUG::hex<6>(region.used_space_/*size_*/) - << " loc key " - << NS_DEBUG::ptr( - region.region_ ? region_provider::get_local_key(region.region_) : nullptr) - << " rem key " - << NS_DEBUG::ptr(region.region_ ? region_provider::get_remote_key(region.region_) : 0); - ///// clang-format off - ///// clang-format on + os << "region " + << NS_DEBUG::ptr(®ion) + //<< " fi_region " << NS_DEBUG::ptr(region.region_) + << " address " << NS_DEBUG::ptr(region.address_) << " size " + << NS_DEBUG::hex<6>(region.size_) + //<< " used_space " << NS_DEBUG::hex<6>(region.used_space_/*size_*/) + << " loc key " + << NS_DEBUG::ptr( + region.region_ ? region_provider::get_local_key(region.region_) : nullptr) + << " rem key " + << NS_DEBUG::ptr( + region.region_ ? region_provider::get_remote_key(region.region_) : 0); + ///// clang-format off + ///// clang-format on #endif - return os; - } - - protected: - // This gives the start address of this region. - // This is the address that should be used for data storage - unsigned char* address_; + return os; + } - // The hardware level handle to the region (as returned from libfabric fi_mr_reg) - mutable provider_region* region_; + protected: + // This gives the start address of this region. + // This is the address that should be used for data storage + unsigned char* address_; - // The (maximum available) size of the memory buffer - uint32_t size_; + // The hardware level handle to the region (as returned from libfabric fi_mr_reg) + mutable provider_region* region_; - // Space used by a message in the memory region. - // This may be smaller/less than the size available if more space - // was allocated than it turns out was needed - mutable uint32_t used_space_; -}; + // The (maximum available) size of the memory buffer + uint32_t size_; -// -------------------------------------------------------------------- -// a memory segment is a pinned block of memory that has been specialized -// by a particular region provider. Each provider (infiniband, libfabric, -// other) has a different definition for the object and the protection -// domain used to limit access. 
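The split introduced here is: one memory_segment registers (and later deregisters) a large pinned block, while cheap memory_handle views are carved out of it at offsets and share the same fid_mr. A usage sketch, with the domain and endpoint arguments assumed to come from the controller, device_id -1 for host memory, and bind_mr left off since it is only needed by providers that require fi_mr_bind/fi_mr_enable:

#include <vector>

void segment_example(NS_MEMORY::region_provider::provider_domain* pd, void* ep)
{
    std::vector<unsigned char> pool(1 << 20);                      // 1 MiB backing allocation

    // register the whole pool once
    NS_MEMORY::memory_segment segment(pd, pool.data(), pool.size(),
        /*bind_mr=*/false, ep, /*device_id=*/-1);

    // hand out a 4 KiB chunk at offset 4 KiB - no new registration happens here,
    // the handle reuses the segment's region and its cached keys
    auto chunk = segment.get_handle(/*offset=*/4096, /*size=*/4096);
    chunk.set_message_length(128);                                 // bytes actually used

    // chunk.get_local_key()/get_remote_key() are what the send/RMA paths consume
}   // ~memory_segment() deregisters the pool
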
-// -------------------------------------------------------------------- -struct memory_segment : public memory_handle -{ - using provider_domain = region_provider::provider_domain; - using provider_region = region_provider::provider_region; - using handle_type = memory_handle; + // Space used by a message in the memory region. + // This may be smaller/less than the size available if more space + // was allocated than it turns out was needed + mutable uint32_t used_space_; + }; // -------------------------------------------------------------------- - memory_segment(provider_region* region, unsigned char* address, unsigned char* base_address, - uint64_t size) - : memory_handle(region, address, size) - , base_addr_(base_address) - { - } - + // a memory segment is a pinned block of memory that has been specialized + // by a particular region provider. Each provider (infiniband, libfabric, + // other) has a different definition for the object and the protection + // domain used to limit access. // -------------------------------------------------------------------- - // move constructor, clear other region - memory_segment(memory_segment&& other) noexcept - : memory_handle(std::move(other)) - , base_addr_{std::exchange(other.base_addr_, nullptr)} + struct memory_segment : public memory_handle { - } + using provider_domain = region_provider::provider_domain; + using provider_region = region_provider::provider_region; + using handle_type = memory_handle; + + // -------------------------------------------------------------------- + memory_segment(provider_region* region, unsigned char* address, unsigned char* base_address, + uint64_t size) + : memory_handle(region, address, size) + , base_addr_(base_address) + { + } - // -------------------------------------------------------------------- - // move assignment, clear other region - memory_segment& operator=(memory_segment&& other) noexcept - { - memory_handle(std::move(other)); - region_ = std::exchange(other.region_, nullptr); - return *this; - } + // -------------------------------------------------------------------- + // move constructor, clear other region + memory_segment(memory_segment&& other) noexcept + : memory_handle(std::move(other)) + , base_addr_{std::exchange(other.base_addr_, nullptr)} + { + } - // -------------------------------------------------------------------- - // construct a memory region object by registering an existing address buffer - // we do not cache local/remote keys here because memory segments are only - // used by the heap to store chunks and the user will always receive - // a memory_handle - which does have keys cached - memory_segment(provider_domain* pd, const void* buffer, const uint64_t length, bool bind_mr, - void* ep, int device_id) - { - // an rma key counter to keep some providers (CXI) happy - static std::atomic key = 0; - // - address_ = static_cast(const_cast(buffer)); - size_ = length; - used_space_ = length; - region_ = nullptr; - // - base_addr_ = memory_handle::address_; - LF_DEB(NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("memory_segment"), *this, device_id)); - - int ret = region_provider::fi_register_memory(pd, device_id, buffer, length, - region_provider::access_flags(), 0, key++, &(region_)); - if (!ret) + // -------------------------------------------------------------------- + // move assignment, clear other region + memory_segment& operator=(memory_segment&& other) noexcept { - LF_DEB(NS_MEMORY::mrn_deb, - trace(NS_DEBUG::str<>("Registered region"), "device", device_id, *this)); + 
memory_handle(std::move(other)); + region_ = std::exchange(other.region_, nullptr); + return *this; } - if (bind_mr) + // -------------------------------------------------------------------- + // construct a memory region object by registering an existing address buffer + // we do not cache local/remote keys here because memory segments are only + // used by the heap to store chunks and the user will always receive + // a memory_handle - which does have keys cached + memory_segment(provider_domain* pd, void const* buffer, uint64_t const length, bool bind_mr, + void* ep, int device_id) { - ret = fi_mr_bind(region_, (struct fid*)ep, 0); - if (ret) { throw NS_LIBFABRIC::fabric_error(int(ret), "fi_mr_bind"); } - else { LF_DEB(NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("Bound region"), *this)); } + // an rma key counter to keep some providers (CXI) happy + static std::atomic key = 0; + // + address_ = static_cast(const_cast(buffer)); + size_ = length; + used_space_ = length; + region_ = nullptr; + // + base_addr_ = memory_handle::address_; + LF_DEB(NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("memory_segment"), *this, device_id)); - ret = fi_mr_enable(region_); - if (ret) { throw NS_LIBFABRIC::fabric_error(int(ret), "fi_mr_enable"); } - else { LF_DEB(NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("Enabled region"), *this)); } + int ret = region_provider::fi_register_memory(pd, device_id, buffer, length, + region_provider::access_flags(), 0, key++, &(region_)); + if (!ret) + { + LF_DEB(NS_MEMORY::mrn_deb, + trace(NS_DEBUG::str<>("Registered region"), "device", device_id, *this)); + } + + if (bind_mr) + { + ret = fi_mr_bind(region_, (struct fid*) ep, 0); + if (ret) { throw NS_LIBFABRIC::fabric_error(int(ret), "fi_mr_bind"); } + else { LF_DEB(NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("Bound region"), *this)); } + + ret = fi_mr_enable(region_); + if (ret) { throw NS_LIBFABRIC::fabric_error(int(ret), "fi_mr_enable"); } + else + { + LF_DEB(NS_MEMORY::mrn_deb, trace(NS_DEBUG::str<>("Enabled region"), *this)); + } + } } - } - // -------------------------------------------------------------------- - // destroy the region and memory according to flag settings - ~memory_segment() { deregister(); } + // -------------------------------------------------------------------- + // destroy the region and memory according to flag settings + ~memory_segment() { deregister(); } - handle_type get_handle(std::size_t offset, std::size_t size) const noexcept - { - return memory_handle(region_, base_addr_ + offset, size); - } + handle_type get_handle(std::size_t offset, std::size_t size) const noexcept + { + return memory_handle(region_, base_addr_ + offset, size); + } - // -------------------------------------------------------------------- - // Get the address of the base memory region. - // This is the address of the memory allocated from the system - inline unsigned char* get_base_address(void) const { return base_addr_; } + // -------------------------------------------------------------------- + // Get the address of the base memory region. 
+ // This is the address of the memory allocated from the system + inline unsigned char* get_base_address(void) const { return base_addr_; } - // -------------------------------------------------------------------- - friend std::ostream& operator<<(std::ostream& os, memory_segment const& region) - { - (void)region; + // -------------------------------------------------------------------- + friend std::ostream& operator<<(std::ostream& os, memory_segment const& region) + { + (void) region; #if has_debug - // clang-format off + // clang-format off os << *static_cast(®ion) << " base address " << NS_DEBUG::ptr(region.base_addr_); - // clang-format on + // clang-format on #endif - return os; - } + return os; + } - public: - // this is the base address of the memory registered by this segment - // individual memory_handles are offset from this address - unsigned char* base_addr_; -}; + public: + // this is the base address of the memory registered by this segment + // individual memory_handles are offset from this address + unsigned char* base_addr_; + }; -} // namespace NS_MEMORY +} // namespace NS_MEMORY diff --git a/src/libfabric/operation_context.cpp b/src/libfabric/operation_context.cpp index ce5081dd..8c8d277f 100644 --- a/src/libfabric/operation_context.cpp +++ b/src/libfabric/operation_context.cpp @@ -8,49 +8,52 @@ * SPDX-License-Identifier: BSD-3-Clause */ // paths relative to backend -#include -#include #include #include +#include +#include -namespace oomph::libfabric -{ -void -operation_context::handle_cancelled() -{ - [[maybe_unused]] auto scp = opctx_deb<1>.scope(NS_DEBUG::ptr(this), __func__); - // enqueue the cancelled/callback - if (std::holds_alternative(m_req)) - { - // regular (non-shared) recv - auto s = std::get(m_req); - while (!(s->m_comm->m_recv_cb_cancel.push(s))) {} - } - else if (std::holds_alternative(m_req)) +namespace oomph::libfabric { + void operation_context::handle_cancelled() { - // shared recv - auto s = std::get(m_req); - while (!(s->m_ctxt->m_recv_cb_cancel.push(s))) {} + [[maybe_unused]] auto scp = opctx_deb<1>.scope(NS_DEBUG::ptr(this), __func__); + // enqueue the cancelled/callback + if (std::holds_alternative(m_req)) + { + // regular (non-shared) recv + auto s = std::get(m_req); + while (!(s->m_comm->m_recv_cb_cancel.push(s))) {} + } + else if (std::holds_alternative(m_req)) + { + // shared recv + auto s = std::get(m_req); + while (!(s->m_ctxt->m_recv_cb_cancel.push(s))) {} + } + else { throw std::runtime_error("Request state invalid in handle_cancelled"); } } - else { throw std::runtime_error("Request state invalid in handle_cancelled"); } -} -int -operation_context::handle_tagged_recv_completion_impl(void* user_data) -{ - [[maybe_unused]] auto scp = opctx_deb<1>.scope(NS_DEBUG::ptr(this), __func__); - if (std::holds_alternative(m_req)) + int operation_context::handle_tagged_recv_completion_impl(void* user_data) { - // regular (non-shared) recv - auto s = std::get(m_req); - //if (std::this_thread::get_id() == thread_id_) - if (reinterpret_cast(user_data) == s->m_comm) + [[maybe_unused]] auto scp = opctx_deb<1>.scope(NS_DEBUG::ptr(this), __func__); + if (std::holds_alternative(m_req)) { - if (!s->m_comm->has_reached_recursion_depth()) + // regular (non-shared) recv + auto s = std::get(m_req); + //if (std::this_thread::get_id() == thread_id_) + if (reinterpret_cast(user_data) == s->m_comm) { - auto inc = s->m_comm->recursion(); - auto ptr = s->release_self_ref(); - s->invoke_cb(); + if (!s->m_comm->has_reached_recursion_depth()) + { + auto inc = 
s->m_comm->recursion(); + auto ptr = s->release_self_ref(); + s->invoke_cb(); + } + else + { + // enqueue the callback + while (!(s->m_comm->m_recv_cb_queue.push(s))) {} + } } else { @@ -58,82 +61,76 @@ operation_context::handle_tagged_recv_completion_impl(void* user_data) while (!(s->m_comm->m_recv_cb_queue.push(s))) {} } } - else - { - // enqueue the callback - while (!(s->m_comm->m_recv_cb_queue.push(s))) {} - } - } - else if (std::holds_alternative(m_req)) - { - // shared recv - auto s = std::get(m_req); - if (!s->m_comm->m_context->has_reached_recursion_depth()) + else if (std::holds_alternative(m_req)) { - auto inc = s->m_comm->m_context->recursion(); - auto ptr = s->release_self_ref(); - s->invoke_cb(); - } - else - { - // enqueue the callback - while (!(s->m_comm->m_context->m_recv_cb_queue.push(s))) {} - } - } - else - { - detail::request_state** req = reinterpret_cast(&m_req); - LF_DEB(NS_MEMORY::opctx_deb<9>, - error(NS_DEBUG::str<>("invalid request_state"), this, "request", NS_DEBUG::ptr(req))); - throw std::runtime_error("Request state invalid in handle_tagged_recv"); - } - return 1; -} - -int -operation_context::handle_tagged_send_completion_impl(void* user_data) -{ - if (std::holds_alternative(m_req)) - { - // regular (non-shared) recv - auto s = std::get(m_req); - if (reinterpret_cast(user_data) == s->m_comm) - { - if (!s->m_comm->has_reached_recursion_depth()) + // shared recv + auto s = std::get(m_req); + if (!s->m_comm->m_context->has_reached_recursion_depth()) { - auto inc = s->m_comm->recursion(); + auto inc = s->m_comm->m_context->recursion(); auto ptr = s->release_self_ref(); s->invoke_cb(); } else { // enqueue the callback - while (!(s->m_comm->m_send_cb_queue.push(s))) {} + while (!(s->m_comm->m_context->m_recv_cb_queue.push(s))) {} } } else { - // enqueue the callback - while (!(s->m_comm->m_send_cb_queue.push(s))) {} + detail::request_state** req = reinterpret_cast(&m_req); + LF_DEB(NS_MEMORY::opctx_deb<9>, + error( + NS_DEBUG::str<>("invalid request_state"), this, "request", NS_DEBUG::ptr(req))); + throw std::runtime_error("Request state invalid in handle_tagged_recv"); } + return 1; } - else if (std::holds_alternative(m_req)) + + int operation_context::handle_tagged_send_completion_impl(void* user_data) { - // shared recv - auto s = std::get(m_req); - if (!s->m_comm->m_context->has_reached_recursion_depth()) + if (std::holds_alternative(m_req)) { - auto inc = s->m_comm->m_context->recursion(); - auto ptr = s->release_self_ref(); - s->invoke_cb(); + // regular (non-shared) recv + auto s = std::get(m_req); + if (reinterpret_cast(user_data) == s->m_comm) + { + if (!s->m_comm->has_reached_recursion_depth()) + { + auto inc = s->m_comm->recursion(); + auto ptr = s->release_self_ref(); + s->invoke_cb(); + } + else + { + // enqueue the callback + while (!(s->m_comm->m_send_cb_queue.push(s))) {} + } + } + else + { + // enqueue the callback + while (!(s->m_comm->m_send_cb_queue.push(s))) {} + } } - else + else if (std::holds_alternative(m_req)) { - // enqueue the callback - while (!(s->m_comm->m_context->m_recv_cb_queue.push(s))) {} + // shared recv + auto s = std::get(m_req); + if (!s->m_comm->m_context->has_reached_recursion_depth()) + { + auto inc = s->m_comm->m_context->recursion(); + auto ptr = s->release_self_ref(); + s->invoke_cb(); + } + else + { + // enqueue the callback + while (!(s->m_comm->m_context->m_recv_cb_queue.push(s))) {} + } } + else { throw std::runtime_error("Request state invalid in handle_tagged_send"); } + return 1; } - else { throw 
std::runtime_error("Request state invalid in handle_tagged_send"); } - return 1; -} -} // namespace oomph::libfabric +} // namespace oomph::libfabric diff --git a/src/libfabric/operation_context.hpp b/src/libfabric/operation_context.hpp index ad106e6a..0f6b5103 100644 --- a/src/libfabric/operation_context.hpp +++ b/src/libfabric/operation_context.hpp @@ -15,39 +15,38 @@ // #include "operation_context_base.hpp" // -namespace oomph::libfabric -{ - -template -inline /*constexpr*/ NS_DEBUG::print_threshold opctx_deb("OP__CXT"); - -// This struct holds the ready state of a future -// we must also store the context used in libfabric, in case -// a request is cancelled - fi_cancel(...) needs it -struct operation_context : public operation_context_base -{ - std::variant m_req; - - template - operation_context(RequestState* req) - : operation_context_base() - , m_req{req} - { - [[maybe_unused]] auto scp = - opctx_deb<9>.scope(NS_DEBUG::ptr(this), __func__, "request", req); - } - - // -------------------------------------------------------------------- - // When a completion returns FI_ECANCELED, this is called - void handle_cancelled(); +namespace oomph::libfabric { - // -------------------------------------------------------------------- - // Called when a tagged recv completes - int handle_tagged_recv_completion_impl(void* user_data); + template + inline /*constexpr*/ NS_DEBUG::print_threshold opctx_deb("OP__CXT"); - // -------------------------------------------------------------------- - // Called when a tagged send completes - int handle_tagged_send_completion_impl(void* user_data); -}; - -} // namespace oomph::libfabric + // This struct holds the ready state of a future + // we must also store the context used in libfabric, in case + // a request is cancelled - fi_cancel(...) needs it + struct operation_context : public operation_context_base + { + std::variant m_req; + + template + operation_context(RequestState* req) + : operation_context_base() + , m_req{req} + { + [[maybe_unused]] auto scp = + opctx_deb<9>.scope(NS_DEBUG::ptr(this), __func__, "request", req); + } + + // -------------------------------------------------------------------- + // When a completion returns FI_ECANCELED, this is called + void handle_cancelled(); + + // -------------------------------------------------------------------- + // Called when a tagged recv completes + int handle_tagged_recv_completion_impl(void* user_data); + + // -------------------------------------------------------------------- + // Called when a tagged send completes + int handle_tagged_send_completion_impl(void* user_data); + }; + +} // namespace oomph::libfabric diff --git a/src/libfabric/operation_context_base.hpp b/src/libfabric/operation_context_base.hpp index e5156f99..5de5c386 100644 --- a/src/libfabric/operation_context_base.hpp +++ b/src/libfabric/operation_context_base.hpp @@ -12,85 +12,84 @@ #include #include "oomph_libfabric_defines.hpp" -namespace NS_LIBFABRIC -{ +namespace NS_LIBFABRIC { -class controller; + class controller; -static NS_DEBUG::enable_print ctx_bas("CTXBASE"); + static NS_DEBUG::enable_print ctx_bas("CTXBASE"); -// This struct holds the ready state of a future -// we must also store the context used in libfabric, in case -// a request is cancelled - fi_cancel(...) 
needs it -template -struct operation_context_base -{ - private: - // libfabric requires some space for it's internal bookkeeping - // so the first member of this struct must be fi_context - fi_context context_reserved_space; - - public: - operation_context_base() - : context_reserved_space() + // This struct holds the ready state of a future + // we must also store the context used in libfabric, in case + // a request is cancelled - fi_cancel(...) needs it + template + struct operation_context_base { - [[maybe_unused]] auto scp = ctx_bas.scope(NS_DEBUG::ptr(this), __func__); - } + private: + // libfabric requires some space for it's internal bookkeeping + // so the first member of this struct must be fi_context + fi_context context_reserved_space; - // error - void handle_error(struct fi_cq_err_entry& err) - { - static_cast(this)->handle_error_impl(err); - } - void handle_error_impl(struct fi_cq_err_entry& /*err*/) { std::terminate(); } + public: + operation_context_base() + : context_reserved_space() + { + [[maybe_unused]] auto scp = ctx_bas.scope(NS_DEBUG::ptr(this), __func__); + } - void handle_cancelled() { static_cast(this)->handle_cancelled_impl(); } - void handle_cancelled_impl() { std::terminate(); } + // error + void handle_error(struct fi_cq_err_entry& err) + { + static_cast(this)->handle_error_impl(err); + } + void handle_error_impl(struct fi_cq_err_entry& /*err*/) { std::terminate(); } - // send - int handle_send_completion() - { - return static_cast(this)->handle_send_completion_impl(); - } - int handle_send_completion_impl() { return 0; } + void handle_cancelled() { static_cast(this)->handle_cancelled_impl(); } + void handle_cancelled_impl() { std::terminate(); } - // tagged send - int handle_tagged_send_completion(void* user_data) - { - return static_cast(this)->handle_tagged_send_completion_impl(user_data); - } - int handle_tagged_send_completion_impl(void* /*user_data*/) { return 0; } + // send + int handle_send_completion() + { + return static_cast(this)->handle_send_completion_impl(); + } + int handle_send_completion_impl() { return 0; } - // recv - int handle_recv_completion(std::uint64_t len) - { - return static_cast(this)->handle_recv_completion_impl(len); - } - int handle_recv_completion_impl(std::uint64_t /*len*/) { return 0; } + // tagged send + int handle_tagged_send_completion(void* user_data) + { + return static_cast(this)->handle_tagged_send_completion_impl(user_data); + } + int handle_tagged_send_completion_impl(void* /*user_data*/) { return 0; } - // tagged recv - int handle_tagged_recv_completion(void* user_data) - { - return static_cast(this)->handle_tagged_recv_completion_impl(user_data); - } - int handle_tagged_recv_completion_impl(bool /*threadlocal*/) { return 0; } + // recv + int handle_recv_completion(std::uint64_t len) + { + return static_cast(this)->handle_recv_completion_impl(len); + } + int handle_recv_completion_impl(std::uint64_t /*len*/) { return 0; } - void handle_rma_read_completion() - { - static_cast(this)->handle_rma_read_completion_impl(); - } - void handle_rma_read_completion_impl() {} + // tagged recv + int handle_tagged_recv_completion(void* user_data) + { + return static_cast(this)->handle_tagged_recv_completion_impl(user_data); + } + int handle_tagged_recv_completion_impl(bool /*threadlocal*/) { return 0; } - // unknown sender = new connection - int handle_new_connection(controller* ctrl, std::uint64_t len) - { - return static_cast(this)->handle_new_connection_impl(ctrl, len); - } - int handle_new_connection_impl(controller*, 
std::uint64_t) { return 0; } -}; + void handle_rma_read_completion() + { + static_cast(this)->handle_rma_read_completion_impl(); + } + void handle_rma_read_completion_impl() {} -// provided so that a pointer can be cast to this and the operation_context_type queried -struct unspecialized_context : public operation_context_base -{ -}; -} // namespace NS_LIBFABRIC + // unknown sender = new connection + int handle_new_connection(controller* ctrl, std::uint64_t len) + { + return static_cast(this)->handle_new_connection_impl(ctrl, len); + } + int handle_new_connection_impl(controller*, std::uint64_t) { return 0; } + }; + + // provided so that a pointer can be cast to this and the operation_context_type queried + struct unspecialized_context : public operation_context_base + { + }; +} // namespace NS_LIBFABRIC diff --git a/src/libfabric/print.hpp b/src/libfabric/print.hpp index cf8de408..73c37c41 100644 --- a/src/libfabric/print.hpp +++ b/src/libfabric/print.hpp @@ -27,12 +27,12 @@ #include // #if defined(__linux) || defined(linux) || defined(__linux__) -#include -#include +# include +# include #elif defined(__APPLE__) -#include -#include -#define environ (*_NSGetEnviron()) +# include +# include +# define environ (*_NSGetEnviron()) #else extern char** environ; #endif @@ -78,665 +78,648 @@ extern char** environ; // ------------------------------------------------------------ /// \cond NODETAIL -namespace NS_DEBUG -{ - -// ------------------------------------------------------------------ -// format as zero padded int -// ------------------------------------------------------------------ -namespace detail -{ - -template -struct dec -{ - constexpr dec(T const& v) - : data_(v) - { - } +namespace NS_DEBUG { - T const& data_; + // ------------------------------------------------------------------ + // format as zero padded int + // ------------------------------------------------------------------ + namespace detail { - friend std::ostream& operator<<(std::ostream& os, dec const& d) - { - os << std::right << std::setfill('0') << std::setw(N) << std::noshowbase << std::dec - << d.data_; - return os; - } -}; -} // namespace detail - -template -constexpr detail::dec -dec(T const& v) -{ - return detail::dec(v); -} - -// ------------------------------------------------------------------ -// format as pointer -// ------------------------------------------------------------------ -struct ptr -{ - ptr(void const* v) - : data_(v) - { - } - ptr(std::uintptr_t const v) - : data_(reinterpret_cast(v)) - { - } - void const* data_; - friend std::ostream& operator<<(std::ostream& os, ptr const& d) - { - os << std::right << "0x" << std::setfill('0') << std::setw(12) << std::noshowbase - << std::hex << reinterpret_cast(d.data_); - return os; - } -}; - -// ------------------------------------------------------------------ -// format as zero padded hex -// ------------------------------------------------------------------ -namespace detail -{ - -template -struct hex; - -template -struct hex::value>::type> -{ - constexpr hex(T const& v) - : data_(v) - { - } - T const& data_; - friend std::ostream& operator<<(std::ostream& os, const hex& d) - { - os << std::right << "0x" << std::setfill('0') << std::setw(N) << std::noshowbase << std::hex - << d.data_; - return os; - } -}; + template + struct dec + { + constexpr dec(T const& v) + : data_(v) + { + } -template -struct hex::value>::type> -{ - constexpr hex(T const& v) - : data_(v) - { - } - T const& data_; - friend std::ostream& operator<<(std::ostream& os, const hex& d) - { - 
os << std::right << std::setw(N) << std::noshowbase << std::hex << d.data_; - return os; - } -}; -} // namespace detail - -template -constexpr detail::hex -hex(T const& v) -{ - return detail::hex(v); -} - -// ------------------------------------------------------------------ -// format as binary bits -// ------------------------------------------------------------------ -namespace detail -{ - -template -struct bin -{ - constexpr bin(T const& v) - : data_(v) - { - } - T const& data_; - friend std::ostream& operator<<(std::ostream& os, const bin& d) + T const& data_; + + friend std::ostream& operator<<(std::ostream& os, dec const& d) + { + os << std::right << std::setfill('0') << std::setw(N) << std::noshowbase << std::dec + << d.data_; + return os; + } + }; + } // namespace detail + + template + constexpr detail::dec dec(T const& v) { - os << std::bitset(d.data_); - return os; + return detail::dec(v); } -}; -} // namespace detail - -template -constexpr detail::bin -bin(T const& v) -{ - return detail::bin(v); -} - -// ------------------------------------------------------------------ -// format as padded string -// ------------------------------------------------------------------ -template -struct str -{ - constexpr str(char const* v) - : data_(v) + + // ------------------------------------------------------------------ + // format as pointer + // ------------------------------------------------------------------ + struct ptr { - } + ptr(void const* v) + : data_(v) + { + } + ptr(std::uintptr_t const v) + : data_(reinterpret_cast(v)) + { + } + void const* data_; + friend std::ostream& operator<<(std::ostream& os, ptr const& d) + { + os << std::right << "0x" << std::setfill('0') << std::setw(12) << std::noshowbase + << std::hex << reinterpret_cast(d.data_); + return os; + } + }; + + // ------------------------------------------------------------------ + // format as zero padded hex + // ------------------------------------------------------------------ + namespace detail { - char const* data_; + template + struct hex; - friend std::ostream& operator<<(std::ostream& os, str const& d) + template + struct hex::value>::type> + { + constexpr hex(T const& v) + : data_(v) + { + } + T const& data_; + friend std::ostream& operator<<(std::ostream& os, hex const& d) + { + os << std::right << "0x" << std::setfill('0') << std::setw(N) << std::noshowbase + << std::hex << d.data_; + return os; + } + }; + + template + struct hex::value>::type> + { + constexpr hex(T const& v) + : data_(v) + { + } + T const& data_; + friend std::ostream& operator<<(std::ostream& os, hex const& d) + { + os << std::right << std::setw(N) << std::noshowbase << std::hex << d.data_; + return os; + } + }; + } // namespace detail + + template + constexpr detail::hex hex(T const& v) { - os << std::left << std::setfill(' ') << std::setw(N) << d.data_; - return os; + return detail::hex(v); } -}; - -// ------------------------------------------------------------------ -// format as ip address -// ------------------------------------------------------------------ -struct ipaddr -{ - ipaddr(const void* a) - : data_(reinterpret_cast(a)) - , ipdata_(0) + + // ------------------------------------------------------------------ + // format as binary bits + // ------------------------------------------------------------------ + namespace detail { + + template + struct bin + { + constexpr bin(T const& v) + : data_(v) + { + } + T const& data_; + friend std::ostream& operator<<(std::ostream& os, bin const& d) + { + os << std::bitset(d.data_); + 
return os; + } + }; + } // namespace detail + + template + constexpr detail::bin bin(T const& v) { + return detail::bin(v); } - ipaddr(const uint32_t a) - : data_(reinterpret_cast(&ipdata_)) - , ipdata_(a) + + // ------------------------------------------------------------------ + // format as padded string + // ------------------------------------------------------------------ + template + struct str { - } - const uint8_t* data_; - const uint32_t ipdata_; + constexpr str(char const* v) + : data_(v) + { + } + + char const* data_; + + friend std::ostream& operator<<(std::ostream& os, str const& d) + { + os << std::left << std::setfill(' ') << std::setw(N) << d.data_; + return os; + } + }; - friend std::ostream& operator<<(std::ostream& os, ipaddr const& p) + // ------------------------------------------------------------------ + // format as ip address + // ------------------------------------------------------------------ + struct ipaddr { - os << std::dec << int(p.data_[0]) << "." << int(p.data_[1]) << "." << int(p.data_[2]) << "." - << int(p.data_[3]); - return os; - } -}; - -// ------------------------------------------------------------------ -// helper fuction for printing CRC32 -// ------------------------------------------------------------------ -inline uint32_t -crc32(const void* address, size_t length) -{ - boost::crc_32_type result; - result.process_bytes(address, length); - return result.checksum(); -} - -// ------------------------------------------------------------------ -// helper fuction for printing short memory dump and crc32 -// useful for debugging corruptions in buffers during -// rma or other transfers -// ------------------------------------------------------------------ -struct mem_crc32 -{ - mem_crc32(const void* a, std::size_t len, const char* txt) - : addr_(reinterpret_cast(a)) - , len_(len) - , txt_(txt) + ipaddr(void const* a) + : data_(reinterpret_cast(a)) + , ipdata_(0) + { + } + ipaddr(uint32_t const a) + : data_(reinterpret_cast(&ipdata_)) + , ipdata_(a) + { + } + uint8_t const* data_; + uint32_t const ipdata_; + + friend std::ostream& operator<<(std::ostream& os, ipaddr const& p) + { + os << std::dec << int(p.data_[0]) << "." << int(p.data_[1]) << "." << int(p.data_[2]) + << "." 
<< int(p.data_[3]); + return os; + } + }; + + // ------------------------------------------------------------------ + // helper fuction for printing CRC32 + // ------------------------------------------------------------------ + inline uint32_t crc32(void const* address, size_t length) { + boost::crc_32_type result; + result.process_bytes(address, length); + return result.checksum(); } - const std::uint8_t* addr_; - const std::size_t len_; - const char* txt_; - friend std::ostream& operator<<(std::ostream& os, mem_crc32 const& p) + + // ------------------------------------------------------------------ + // helper fuction for printing short memory dump and crc32 + // useful for debugging corruptions in buffers during + // rma or other transfers + // ------------------------------------------------------------------ + struct mem_crc32 { - const std::uint8_t* byte = static_cast(p.addr_); - os << "Memory:"; - os << " address " << ptr(p.addr_) << " length " << hex<6, std::size_t>(p.len_) - << " CRC32:" << hex<8, std::size_t>(crc32(p.addr_, p.len_)) << "\n"; - size_t i = 0; - while (i < std::min(size_t(128), p.len_)) - { - os << "0x"; - for (int j = 7; j >= 0; j--) + mem_crc32(void const* a, std::size_t len, char const* txt) + : addr_(reinterpret_cast(a)) + , len_(len) + , txt_(txt) + { + } + std::uint8_t const* addr_; + std::size_t const len_; + char const* txt_; + friend std::ostream& operator<<(std::ostream& os, mem_crc32 const& p) + { + std::uint8_t const* byte = static_cast(p.addr_); + os << "Memory:"; + os << " address " << ptr(p.addr_) << " length " << hex<6, std::size_t>(p.len_) + << " CRC32:" << hex<8, std::size_t>(crc32(p.addr_, p.len_)) << "\n"; + size_t i = 0; + while (i < std::min(size_t(128), p.len_)) { - os << std::hex << std::setfill('0') << std::setw(2) - << (((i + j) > p.len_) ? (int)0 : (int)byte[i + j]); + os << "0x"; + for (int j = 7; j >= 0; j--) + { + os << std::hex << std::setfill('0') << std::setw(2) + << (((i + j) > p.len_) ? (int) 0 : (int) byte[i + j]); + } + i += 8; + if (i % 32 == 0) + os << std::endl; + else + os << " "; } - i += 8; - if (i % 32 == 0) os << std::endl; - else - os << " "; + os << ": " << p.txt_; + return os; } - os << ": " << p.txt_; - return os; - } -}; - -namespace detail -{ - -template -void -tuple_print(std::ostream& os, TupleType const& t, std::index_sequence) -{ - (..., (os << (I == 0 ? "" : " ") << std::get(t))); -} - -template -void -tuple_print(std::ostream& os, const std::tuple& t) -{ - tuple_print(os, t, std::make_index_sequence()); -} -} // namespace detail - -namespace detail -{ - -// ------------------------------------------------------------------ -// helper class for printing thread ID -// ------------------------------------------------------------------ -struct current_thread_print_helper -{ -}; - -inline std::ostream& -operator<<(std::ostream& os, current_thread_print_helper const&) -{ - os << hex<12, std::thread::id>(std::this_thread::get_id()) + }; + + namespace detail { + + template + void tuple_print(std::ostream& os, TupleType const& t, std::index_sequence) + { + (..., (os << (I == 0 ? 
"" : " ") << std::get(t))); + } + + template + void tuple_print(std::ostream& os, std::tuple const& t) + { + tuple_print(os, t, std::make_index_sequence()); + } + } // namespace detail + + namespace detail { + + // ------------------------------------------------------------------ + // helper class for printing thread ID + // ------------------------------------------------------------------ + struct current_thread_print_helper + { + }; + + inline std::ostream& operator<<(std::ostream& os, current_thread_print_helper const&) + { + os << hex<12, std::thread::id>(std::this_thread::get_id()) #ifdef DEBUGGING_PRINT_LINUX - << " cpu " << debug::dec<3, int>(sched_getcpu()) << " "; + << " cpu " << debug::dec<3, int>(sched_getcpu()) << " "; #else - << " cpu " - << "--- "; + << " cpu " + << "--- "; #endif - return os; -} - -// ------------------------------------------------------------------ -// helper class for printing time since start -// ------------------------------------------------------------------ -struct hostname_print_helper -{ - const char* get_hostname() const - { - static bool initialized = false; - static char hostname_[20]; - if (!initialized) - { - initialized = true; - gethostname(hostname_, std::size_t(12)); - std::string temp = "(" + std::to_string(guess_rank()) + ")"; - std::strcat(hostname_, temp.c_str()); + return os; } - return hostname_; - } - int guess_rank() const - { - std::vector env_strings{"_RANK=", "_NODEID="}; - for (char** current = environ; *current; current++) + // ------------------------------------------------------------------ + // helper class for printing time since start + // ------------------------------------------------------------------ + struct hostname_print_helper { - auto e = std::string(*current); - for (auto s : env_strings) + char const* get_hostname() const + { + static bool initialized = false; + static char hostname_[20]; + if (!initialized) + { + initialized = true; + gethostname(hostname_, std::size_t(12)); + std::string temp = "(" + std::to_string(guess_rank()) + ")"; + std::strcat(hostname_, temp.c_str()); + } + return hostname_; + } + + int guess_rank() const { - auto pos = e.find(s); - if (pos != std::string::npos) + std::vector env_strings{"_RANK=", "_NODEID="}; + for (char** current = environ; *current; current++) { - //std::cout << "Got a rank string : " << e << std::endl; - return std::stoi(e.substr(pos + s.size(), 5)); + auto e = std::string(*current); + for (auto s : env_strings) + { + auto pos = e.find(s); + if (pos != std::string::npos) + { + //std::cout << "Got a rank string : " << e << std::endl; + return std::stoi(e.substr(pos + s.size(), 5)); + } + } } + return -1; } + }; + + inline std::ostream& operator<<(std::ostream& os, hostname_print_helper const& h) + { + os << debug::str<13>(h.get_hostname()) << " "; + return os; } - return -1; - } -}; - -inline std::ostream& -operator<<(std::ostream& os, hostname_print_helper const& h) -{ - os << debug::str<13>(h.get_hostname()) << " "; - return os; -} - -// ------------------------------------------------------------------ -// helper class for printing time since start -// ------------------------------------------------------------------ -struct current_time_print_helper -{ -}; - -inline std::ostream& -operator<<(std::ostream& os, current_time_print_helper const&) -{ - using namespace std::chrono; - static steady_clock::time_point log_t_start = steady_clock::now(); - // - auto now = steady_clock::now(); - auto nowt = duration_cast(now - log_t_start).count(); - // - os << 
debug::dec<10>(nowt) << " "; - return os; -} - -template -void -display(char const* prefix, Args const&... args) -{ - // using a temp stream object with a single copy to cout at the end - // prevents multiple threads from injecting overlapping text - std::stringstream tempstream; - tempstream << prefix << detail::current_time_print_helper() - << detail::current_thread_print_helper() << detail::hostname_print_helper(); - ((tempstream << args << " "), ...); - tempstream << "\n"; - std::cout << tempstream.str() << std::flush; -} - -template -void -debug(Args const&... args) -{ - display(" ", args...); -} - -template -void -warning(Args const&... args) -{ - display(" ", args...); -} - -template -void -error(Args const&... args) -{ - display(" ", args...); -} - -template -void -scope(Args const&... args) -{ - display(" ", args...); -} - -template -void -trace(Args const&... args) -{ - display(" ", args...); -} - -template -void -timed(Args const&... args) -{ - display(" ", args...); -} -} // namespace detail - -template -struct scoped_var -{ - // capture tuple elements by reference - no temp vars in constructor please - char const* prefix_; - std::tuple const message_; - std::string buffered_msg; - - // - scoped_var(char const* p, Args const&... args) - : prefix_(p) - , message_(args...) - { - std::stringstream tempstream; - detail::tuple_print(tempstream, message_); - buffered_msg = tempstream.str(); - detail::display(" ", prefix_, debug::str<>(">> enter <<"), tempstream.str()); - } - ~scoped_var() { detail::display(" ", prefix_, debug::str<>("<< leave >>"), buffered_msg); } -}; - -template -struct timed_var -{ - mutable std::chrono::steady_clock::time_point time_start_; - double const delay_; - std::tuple const message_; - // - timed_var(double const& delay, Args const&... args) - : time_start_(std::chrono::steady_clock::now()) - , delay_(delay) - , message_(args...) - { - } + // ------------------------------------------------------------------ + // helper class for printing time since start + // ------------------------------------------------------------------ + struct current_time_print_helper + { + }; - bool elapsed(std::chrono::steady_clock::time_point const& now) const - { - double elapsed_ = - std::chrono::duration_cast>(now - time_start_).count(); + inline std::ostream& operator<<(std::ostream& os, current_time_print_helper const&) + { + using namespace std::chrono; + static steady_clock::time_point log_t_start = steady_clock::now(); + // + auto now = steady_clock::now(); + auto nowt = duration_cast(now - log_t_start).count(); + // + os << debug::dec<10>(nowt) << " "; + return os; + } - if (elapsed_ > delay_) + template + void display(char const* prefix, Args const&... args) { - time_start_ = now; - return true; + // using a temp stream object with a single copy to cout at the end + // prevents multiple threads from injecting overlapping text + std::stringstream tempstream; + tempstream << prefix << detail::current_time_print_helper() + << detail::current_thread_print_helper() << detail::hostname_print_helper(); + ((tempstream << args << " "), ...); + tempstream << "\n"; + std::cout << tempstream.str() << std::flush; } - return false; - } - friend std::ostream& operator<<(std::ostream& os, timed_var const& ti) - { - detail::tuple_print(os, ti.message_); - return os; - } -}; + template + void debug(Args const&... 
args) + { + display(" ", args...); + } -/////////////////////////////////////////////////////////////////////////// -template -struct enable_print; + template + void warning(Args const&... args) + { + display(" ", args...); + } -// when false, debug statements should produce no code -template<> -struct enable_print -{ - constexpr enable_print(const char*) {} + template + void error(Args const&... args) + { + display(" ", args...); + } - constexpr bool is_enabled() const { return false; } + template + void scope(Args const&... args) + { + display(" ", args...); + } - template - constexpr void debug(Args const&...) const - { - } + template + void trace(Args const&... args) + { + display(" ", args...); + } - template - constexpr void warning(Args const&...) const - { - } + template + void timed(Args const&... args) + { + display(" ", args...); + } + } // namespace detail - template - constexpr void trace(Args const&...) const + template + struct scoped_var { - } + // capture tuple elements by reference - no temp vars in constructor please + char const* prefix_; + std::tuple const message_; + std::string buffered_msg; - template - constexpr void error(Args const&...) const - { - } + // + scoped_var(char const* p, Args const&... args) + : prefix_(p) + , message_(args...) + { + std::stringstream tempstream; + detail::tuple_print(tempstream, message_); + buffered_msg = tempstream.str(); + detail::display(" ", prefix_, debug::str<>(">> enter <<"), tempstream.str()); + } - template - constexpr void timed(Args const&...) const - { - } + ~scoped_var() + { + detail::display(" ", prefix_, debug::str<>("<< leave >>"), buffered_msg); + } + }; + + template + struct timed_var + { + mutable std::chrono::steady_clock::time_point time_start_; + double const delay_; + std::tuple const message_; + // + timed_var(double const& delay, Args const&... args) + : time_start_(std::chrono::steady_clock::now()) + , delay_(delay) + , message_(args...) + { + } - template - constexpr void array(std::string const&, std::vector const&) const - { - } + bool elapsed(std::chrono::steady_clock::time_point const& now) const + { + double elapsed_ = + std::chrono::duration_cast>(now - time_start_) + .count(); - template - constexpr void array(std::string const&, std::array const&) const - { - } + if (elapsed_ > delay_) + { + time_start_ = now; + return true; + } + return false; + } - template - constexpr void array(std::string const&, Iter, Iter) const - { - } + friend std::ostream& operator<<(std::ostream& os, timed_var const& ti) + { + detail::tuple_print(os, ti.message_); + return os; + } + }; - template - constexpr bool scope(Args const&...) - { - return true; - } + /////////////////////////////////////////////////////////////////////////// + template + struct enable_print; - template - constexpr bool declare_variable(Args const&...) const + // when false, debug statements should produce no code + template <> + struct enable_print { - return true; - } + constexpr enable_print(char const*) {} - template - constexpr void set(T&, V const&) - { - } + constexpr bool is_enabled() const { return false; } - // @todo, return void so that timers have zero footprint when disabled - template - constexpr int make_timer(const double, Args const&...) const - { - return 0; - } + template + constexpr void debug(Args const&...) 
const + { + } - template - constexpr bool eval(Expr const&) - { - return true; - } -}; - -// when true, debug statements produce valid output -template<> -struct enable_print -{ - private: - char const* prefix_; - - public: - constexpr enable_print() - : prefix_("") - { - } + template + constexpr void warning(Args const&...) const + { + } - constexpr enable_print(const char* p) - : prefix_(p) - { - } + template + constexpr void trace(Args const&...) const + { + } - constexpr bool is_enabled() const { return true; } + template + constexpr void error(Args const&...) const + { + } - template - constexpr void debug(Args const&... args) const - { - detail::debug(prefix_, args...); - } + template + constexpr void timed(Args const&...) const + { + } - template - constexpr void warning(Args const&... args) const - { - detail::warning(prefix_, args...); - } + template + constexpr void array(std::string const&, std::vector const&) const + { + } - template - constexpr void trace(Args const&... args) const - { - detail::trace(prefix_, args...); - } + template + constexpr void array(std::string const&, std::array const&) const + { + } - template - constexpr void error(Args const&... args) const - { - detail::error(prefix_, args...); - } + template + constexpr void array(std::string const&, Iter, Iter) const + { + } - template - scoped_var scope(Args const&... args) - { - return scoped_var(prefix_, args...); - } + template + constexpr bool scope(Args const&...) + { + return true; + } - template - void timed(timed_var const& init, Args const&... args) const - { - auto now = std::chrono::steady_clock::now(); - if (init.elapsed(now)) { detail::timed(prefix_, init, args...); } - } + template + constexpr bool declare_variable(Args const&...) const + { + return true; + } - template - void array(std::string const& name, std::vector const& v) const - { - std::cout << str<20>(name.c_str()) << ": {" << debug::dec<4>(v.size()) << "} : "; - std::copy(std::begin(v), std::end(v), std::ostream_iterator(std::cout, ", ")); - std::cout << "\n"; - } + template + constexpr void set(T&, V const&) + { + } - template - void array(std::string const& name, const std::array& v) const - { - std::cout << str<20>(name.c_str()) << ": {" << debug::dec<4>(v.size()) << "} : "; - std::copy(std::begin(v), std::end(v), std::ostream_iterator(std::cout, ", ")); - std::cout << "\n"; - } + // @todo, return void so that timers have zero footprint when disabled + template + constexpr int make_timer(double const, Args const&...) const + { + return 0; + } - template - void array(std::string const& name, Iter begin, Iter end) const - { - std::cout << str<20>(name.c_str()) << ": {" << debug::dec<4>(std::distance(begin, end)) - << "} : "; - std::copy(begin, end, - std::ostream_iterator::value_type>(std::cout, - ", ")); - std::cout << std::endl; - } + template + constexpr bool eval(Expr const&) + { + return true; + } + }; - template - T declare_variable(Args const&... args) const + // when true, debug statements produce valid output + template <> + struct enable_print { - return T(args...); - } + private: + char const* prefix_; - template - void set(T& var, V const& val) - { - var = val; - } + public: + constexpr enable_print() + : prefix_("") + { + } + + constexpr enable_print(char const* p) + : prefix_(p) + { + } + + constexpr bool is_enabled() const { return true; } + + template + constexpr void debug(Args const&... args) const + { + detail::debug(prefix_, args...); + } + + template + constexpr void warning(Args const&... 
args) const + { + detail::warning(prefix_, args...); + } + + template + constexpr void trace(Args const&... args) const + { + detail::trace(prefix_, args...); + } - template - timed_var make_timer(const double delay, const Args... args) const + template + constexpr void error(Args const&... args) const + { + detail::error(prefix_, args...); + } + + template + scoped_var scope(Args const&... args) + { + return scoped_var(prefix_, args...); + } + + template + void timed(timed_var const& init, Args const&... args) const + { + auto now = std::chrono::steady_clock::now(); + if (init.elapsed(now)) { detail::timed(prefix_, init, args...); } + } + + template + void array(std::string const& name, std::vector const& v) const + { + std::cout << str<20>(name.c_str()) << ": {" << debug::dec<4>(v.size()) << "} : "; + std::copy(std::begin(v), std::end(v), std::ostream_iterator(std::cout, ", ")); + std::cout << "\n"; + } + + template + void array(std::string const& name, std::array const& v) const + { + std::cout << str<20>(name.c_str()) << ": {" << debug::dec<4>(v.size()) << "} : "; + std::copy(std::begin(v), std::end(v), std::ostream_iterator(std::cout, ", ")); + std::cout << "\n"; + } + + template + void array(std::string const& name, Iter begin, Iter end) const + { + std::cout << str<20>(name.c_str()) << ": {" << debug::dec<4>(std::distance(begin, end)) + << "} : "; + std::copy(begin, end, + std::ostream_iterator::value_type>( + std::cout, ", ")); + std::cout << std::endl; + } + + template + T declare_variable(Args const&... args) const + { + return T(args...); + } + + template + void set(T& var, V const& val) + { + var = val; + } + + template + timed_var make_timer(double const delay, Args const... args) const + { + return timed_var(delay, args...); + } + + template + auto eval(Expr const& e) + { + return e(); + } + }; + + // ------------------------------------------------------------------ + // helper for N>M true/false + // ------------------------------------------------------------------ + template + struct check_level : std::integral_constant { - return timed_var(delay, args...); - } + }; - template - auto eval(Expr const& e) + template + struct print_threshold : enable_print::value> { - return e(); - } -}; - -// ------------------------------------------------------------------ -// helper for N>M true/false -// ------------------------------------------------------------------ -template -struct check_level : std::integral_constant -{ -}; - -template -struct print_threshold : enable_print::value> -{ - using base_type = enable_print::value>; - // inherit constructor - using base_type::base_type; -}; - -} // namespace NS_DEBUG + using base_type = enable_print::value>; + // inherit constructor + using base_type::base_type; + }; + +} // namespace NS_DEBUG /// \endcond diff --git a/src/libfabric/request_state.hpp b/src/libfabric/request_state.hpp index d00e0367..58f15dd5 100644 --- a/src/libfabric/request_state.hpp +++ b/src/libfabric/request_state.hpp @@ -13,90 +13,88 @@ #include "../request_state_base.hpp" #include "./operation_context.hpp" -namespace oomph -{ -namespace detail -{ - -struct request_state -: public util::enable_shared_from_this -, public request_state_base -{ - using base = request_state_base; - using shared_ptr_t = util::unsafe_shared_ptr; - using operation_context = libfabric::operation_context; - - operation_context m_operation_context; - util::unsafe_shared_ptr m_self_ptr; - - request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, std::size_t* scheduled, - 
rank_type rank, tag_type tag, cb_type&& cb) - : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} - , m_operation_context{this} - { - } - - void progress(); - - bool cancel(); - - void create_self_ref() - { - // create a self-reference cycle!! - // this is useful if we only keep a raw pointer around internally, which still is supposed - // to keep the object alive - m_self_ptr = shared_from_this(); - } - - shared_ptr_t release_self_ref() noexcept - { - assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); - return std::move(m_self_ptr); - } -}; - -struct shared_request_state -: public std::enable_shared_from_this -, public request_state_base -{ - using base = request_state_base; - using shared_ptr_t = std::shared_ptr; - using operation_context = libfabric::operation_context; - - operation_context m_operation_context; - std::shared_ptr m_self_ptr; - - shared_request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, - std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb) - : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} - , m_operation_context{this} - { - [[maybe_unused]] auto scp = libfabric::opctx_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - } +namespace oomph { namespace detail { - ~shared_request_state() + struct request_state + : public util::enable_shared_from_this + , public request_state_base { - [[maybe_unused]] auto scp = libfabric::opctx_deb<9>.scope(NS_DEBUG::ptr(this), __func__); - } - - void progress(); - - bool cancel(); - - void create_self_ref() + using base = request_state_base; + using shared_ptr_t = util::unsafe_shared_ptr; + using operation_context = libfabric::operation_context; + + operation_context m_operation_context; + util::unsafe_shared_ptr m_self_ptr; + + request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, + std::size_t* scheduled, rank_type rank, tag_type tag, cb_type&& cb) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_operation_context{this} + { + } + + void progress(); + + bool cancel(); + + void create_self_ref() + { + // create a self-reference cycle!! + // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool) m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } + }; + + struct shared_request_state + : public std::enable_shared_from_this + , public request_state_base { - // create a self-reference cycle!! 
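// ---------------------------------------------------------------------
// Illustrative sketch (not from this code base): the create_self_ref() /
// release_self_ref() pair in request_state keeps a request alive while only
// a raw pointer is stored in a queue or handed to the transport layer.  The
// minimal example below reproduces the same keep-alive idea with the
// standard std::enable_shared_from_this; the names `task`, `submit` and
// `complete` are invented for illustration.

#include <cassert>
#include <memory>
#include <utility>

struct task : std::enable_shared_from_this<task>
{
    std::shared_ptr<task> m_self;    // deliberate self-reference cycle

    void create_self_ref() { m_self = shared_from_this(); }

    std::shared_ptr<task> release_self_ref() noexcept
    {
        assert(m_self && "doesn't own a self-reference!");
        return std::move(m_self);    // breaks the cycle again
    }
};

// usage: the owner may drop its shared_ptr right after submission; the raw
// pointer given to the transport stays valid until the completion handler
// calls release_self_ref() and lets the returned shared_ptr expire.
inline task* submit()
{
    auto t = std::make_shared<task>();
    t->create_self_ref();
    return t.get();    // the object outlives this scope via m_self
}

inline void complete(task* t)
{
    auto keep_alive = t->release_self_ref();
    // ... invoke the callback while keep_alive holds the last reference ...
}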
- // this is useful if we only keep a raw pointer around internally, which still is supposed - // to keep the object alive - m_self_ptr = shared_from_this(); - } - - shared_ptr_t release_self_ref() noexcept - { - assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); - return std::move(m_self_ptr); - } -}; - -} // namespace detail -} // namespace oomph + using base = request_state_base; + using shared_ptr_t = std::shared_ptr; + using operation_context = libfabric::operation_context; + + operation_context m_operation_context; + std::shared_ptr m_self_ptr; + + shared_request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, + std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_operation_context{this} + { + [[maybe_unused]] auto scp = + libfabric::opctx_deb<9>.scope(NS_DEBUG::ptr(this), __func__); + } + + ~shared_request_state() + { + [[maybe_unused]] auto scp = + libfabric::opctx_deb<9>.scope(NS_DEBUG::ptr(this), __func__); + } + + void progress(); + + bool cancel(); + + void create_self_ref() + { + // create a self-reference cycle!! + // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool) m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } + }; + +}} // namespace oomph::detail diff --git a/src/libfabric/simple_counter.hpp b/src/libfabric/simple_counter.hpp index f44eac92..26ecf8d5 100644 --- a/src/libfabric/simple_counter.hpp +++ b/src/libfabric/simple_counter.hpp @@ -12,13 +12,13 @@ #include "oomph_libfabric_defines.hpp" // #include -#include #include +#include #ifdef OOMPH_LIBFABRIC_HAVE_PERFORMANCE_COUNTERS -#define PERFORMANCE_COUNTER_ENABLED true +# define PERFORMANCE_COUNTER_ENABLED true #else -#define PERFORMANCE_COUNTER_ENABLED false +# define PERFORMANCE_COUNTER_ENABLED false #endif // @@ -29,90 +29,86 @@ // the performance counter that will simply do nothing when disabled - but // still allow code that uses the counters in arithmetic to compile. 
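// ---------------------------------------------------------------------
// Illustrative sketch (not part of the patch): simple_counter below is
// specialised on a compile-time flag so that counter arithmetic still
// compiles when performance counters are disabled.  This stand-alone
// example reproduces the idea with invented names (`stat_counter`,
// `STATS_ENABLED`); the real class in this file carries more operators.

#include <atomic>
#include <cstdint>
#include <iostream>

#ifndef STATS_ENABLED
# define STATS_ENABLED true
#endif

template <typename T, bool Enabled = STATS_ENABLED>
struct stat_counter;

// enabled: a thin wrapper around std::atomic
template <typename T>
struct stat_counter<T, true>
{
    std::atomic<T> value_{T()};
    operator T() const { return value_.load(std::memory_order_relaxed); }
    T operator++() { return ++value_; }
    T operator+=(T const& rhs) { return value_ += rhs; }
};

// disabled: every operation collapses to a constant, so call sites
// still compile but generate (next to) no code
template <typename T>
struct stat_counter<T, false>
{
    operator T() const { return T(); }
    T operator++() { return T(); }
    T operator+=(T const&) { return T(); }
};

int main()
{
    stat_counter<std::uint64_t> messages;
    ++messages;
    messages += 41;
    std::cout << "messages sent: " << std::uint64_t(messages) << "\n";
}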
// -namespace oomph -{ -namespace libfabric -{ -template::value>> -struct simple_counter -{ -}; - -// -------------------------------------------------------------------- -// specialization for performance counters Enabled -// we provide an atomic that can be incremented or added/subtracted to -template -struct simple_counter -{ - simple_counter() - : value_{T()} +namespace oomph { namespace libfabric { + template ::value>> + struct simple_counter { - } + }; - simple_counter(const T& init) - : value_{init} + // -------------------------------------------------------------------- + // specialization for performance counters Enabled + // we provide an atomic that can be incremented or added/subtracted to + template + struct simple_counter { - } + simple_counter() + : value_{T()} + { + } - inline operator T() const { return value_; } + simple_counter(T const& init) + : value_{init} + { + } - inline T operator=(const T& x) { return value_ = x; } + inline operator T() const { return value_; } - inline T operator++() { return ++value_; } + inline T operator=(T const& x) { return value_ = x; } - inline T operator++(int x) { return (value_ += x); } + inline T operator++() { return ++value_; } - inline T operator+=(const T& rhs) { return (value_ += rhs); } + inline T operator++(int x) { return (value_ += x); } - inline T operator--() { return --value_; } + inline T operator+=(T const& rhs) { return (value_ += rhs); } - inline T operator--(int x) { return (value_ -= x); } + inline T operator--() { return --value_; } - inline T operator-=(const T& rhs) { return (value_ -= rhs); } + inline T operator--(int x) { return (value_ -= x); } - friend std::ostream& operator<<(std::ostream& os, const simple_counter& x) - { - os << x.value_; - return os; - } + inline T operator-=(T const& rhs) { return (value_ -= rhs); } - std::atomic value_; -}; + friend std::ostream& operator<<(std::ostream& os, simple_counter const& x) + { + os << x.value_; + return os; + } -// -------------------------------------------------------------------- -// specialization for performance counters Disabled -// just return dummy values so that arithmetic operations compile ok -template -struct simple_counter -{ - simple_counter() {} + std::atomic value_; + }; - simple_counter(const T&) {} + // -------------------------------------------------------------------- + // specialization for performance counters Disabled + // just return dummy values so that arithmetic operations compile ok + template + struct simple_counter + { + simple_counter() {} - inline operator T() const { return 0; } + simple_counter(T const&) {} - // inline bool operator==(const T&) { return true; } + inline operator T() const { return 0; } - inline T operator=(const T&) { return 0; } + // inline bool operator==(const T&) { return true; } - inline T operator++() { return 0; } + inline T operator=(T const&) { return 0; } - inline T operator++(int) { return 0; } + inline T operator++() { return 0; } - inline T operator+=(const T&) { return 0; } + inline T operator++(int) { return 0; } - inline T operator--() { return 0; } + inline T operator+=(T const&) { return 0; } - inline T operator--(int) { return 0; } + inline T operator--() { return 0; } - inline T operator-=(const T&) { return 0; } + inline T operator--(int) { return 0; } - friend std::ostream& operator<<(std::ostream& os, const simple_counter&) - { - os << "undefined"; - return os; - } -}; -} // namespace libfabric -} // namespace oomph + inline T operator-=(T const&) { return 0; } + + friend std::ostream& 
operator<<(std::ostream& os, simple_counter const&) + { + os << "undefined"; + return os; + } + }; +}} // namespace oomph::libfabric diff --git a/src/message_buffer.cpp b/src/message_buffer.cpp index 9bf9eacf..2cd097a6 100644 --- a/src/message_buffer.cpp +++ b/src/message_buffer.cpp @@ -15,69 +15,45 @@ OOMPH_INSTANTIATE_HEAP_PIMPL(oomph::detail::message_buffer::heap_ptr_impl) -namespace oomph -{ -namespace detail -{ +namespace oomph { namespace detail { -message_buffer::~message_buffer() -{ - if (m_ptr) m_heap_ptr->release(); -} + message_buffer::~message_buffer() + { + if (m_ptr) m_heap_ptr->release(); + } -message_buffer& -message_buffer::operator=(message_buffer&& other) -{ - if (m_ptr) m_heap_ptr->release(); - m_ptr = std::exchange(other.m_ptr, nullptr); - m_heap_ptr = std::move(other.m_heap_ptr); - return *this; -} + message_buffer& message_buffer::operator=(message_buffer&& other) + { + if (m_ptr) m_heap_ptr->release(); + m_ptr = std::exchange(other.m_ptr, nullptr); + m_heap_ptr = std::move(other.m_heap_ptr); + return *this; + } -bool -message_buffer::on_device() const noexcept -{ - return m_heap_ptr->m.on_device(); -} + bool message_buffer::on_device() const noexcept { return m_heap_ptr->m.on_device(); } #if OOMPH_ENABLE_DEVICE -void* -message_buffer::device_data() noexcept -{ - return m_heap_ptr->m.device_ptr(); -} + void* message_buffer::device_data() noexcept { return m_heap_ptr->m.device_ptr(); } -void const* -message_buffer::device_data() const noexcept -{ - return m_heap_ptr->m.device_ptr(); -} + void const* message_buffer::device_data() const noexcept { return m_heap_ptr->m.device_ptr(); } -int -message_buffer::device_id() const noexcept -{ - return m_heap_ptr->m.device_id(); -} + int message_buffer::device_id() const noexcept { return m_heap_ptr->m.device_id(); } -void -message_buffer::clone_to_device(std::size_t count) -{ - hwmalloc::memcpy_to_device(m_heap_ptr->m.device_ptr(), m_ptr, count); -} + void message_buffer::clone_to_device(std::size_t count) + { + hwmalloc::memcpy_to_device(m_heap_ptr->m.device_ptr(), m_ptr, count); + } -void -message_buffer::clone_to_host(std::size_t count) -{ - hwmalloc::memcpy_to_host(m_ptr, m_heap_ptr->m.device_ptr(), count); -} + void message_buffer::clone_to_host(std::size_t count) + { + hwmalloc::memcpy_to_host(m_ptr, m_heap_ptr->m.device_ptr(), count); + } #endif -void -message_buffer::clear() -{ - m_ptr = nullptr; - m_heap_ptr = context_impl::heap_type::pointer{nullptr}; -} + void message_buffer::clear() + { + m_ptr = nullptr; + m_heap_ptr = context_impl::heap_type::pointer{nullptr}; + } -} // namespace detail -} // namespace oomph +}} // namespace oomph::detail diff --git a/src/message_buffer.hpp b/src/message_buffer.hpp index c4e38b77..ccccab77 100644 --- a/src/message_buffer.hpp +++ b/src/message_buffer.hpp @@ -13,21 +13,17 @@ #include // paths relative to backend -#include #include +#include -namespace oomph -{ -namespace detail -{ -using heap_ptr = typename context_impl::heap_type::pointer; +namespace oomph { namespace detail { + using heap_ptr = typename context_impl::heap_type::pointer; -class message_buffer::heap_ptr_impl -{ - public: - heap_ptr m; - void release() { m.release(); } -}; + class message_buffer::heap_ptr_impl + { + public: + heap_ptr m; + void release() { m.release(); } + }; -} // namespace detail -} // namespace oomph +}} // namespace oomph::detail diff --git a/src/mpi/channel_base.hpp b/src/mpi/channel_base.hpp index f8751a44..3f2386e8 100644 --- a/src/mpi/channel_base.hpp +++ b/src/mpi/channel_base.hpp @@ -14,64 
+14,64 @@ // paths relative to backend #include -namespace oomph -{ -class channel_base -{ - protected: - using heap_type = context_impl::heap_type; - using pointer = heap_type::pointer; - using handle_type = typename pointer::handle_type; - using key_type = typename handle_type::key_type; - using flag_basic_type = key_type; - using flag_type = flag_basic_type volatile; +namespace oomph { + class channel_base + { + protected: + using heap_type = context_impl::heap_type; + using pointer = heap_type::pointer; + using handle_type = typename pointer::handle_type; + using key_type = typename handle_type::key_type; + using flag_basic_type = key_type; + using flag_type = flag_basic_type volatile; - protected: - //heap_type& m_heap; - std::size_t m_size; - std::size_t m_T_size; - std::size_t m_levels; - std::size_t m_capacity; - communicator::rank_type m_remote_rank; - communicator::tag_type m_tag; - bool m_connected = false; - MPI_Request m_init_req; + protected: + //heap_type& m_heap; + std::size_t m_size; + std::size_t m_T_size; + std::size_t m_levels; + std::size_t m_capacity; + communicator::rank_type m_remote_rank; + communicator::tag_type m_tag; + bool m_connected = false; + MPI_Request m_init_req; - public: - channel_base(/*heap_type& h,*/ std::size_t size, std::size_t T_size, - communicator::rank_type remote_rank, communicator::tag_type tag, std::size_t levels) - //: m_heap{h} - : m_size{size} - , m_T_size{T_size} - , m_levels{levels} - , m_capacity{levels} - , m_remote_rank{remote_rank} - , m_tag{tag} - { - } + public: + channel_base(/*heap_type& h,*/ std::size_t size, std::size_t T_size, + communicator::rank_type remote_rank, communicator::tag_type tag, std::size_t levels) + //: m_heap{h} + : m_size{size} + , m_T_size{T_size} + , m_levels{levels} + , m_capacity{levels} + , m_remote_rank{remote_rank} + , m_tag{tag} + { + } - void connect() - { - OOMPH_CHECK_MPI_RESULT(MPI_Wait(&m_init_req, MPI_STATUS_IGNORE)); - m_connected = true; - } + void connect() + { + OOMPH_CHECK_MPI_RESULT(MPI_Wait(&m_init_req, MPI_STATUS_IGNORE)); + m_connected = true; + } - protected: - // index of flag in buffer (in units of flag_basic_type) - std::size_t flag_offset() const noexcept - { - return (m_size * m_T_size + 2 * sizeof(flag_basic_type) - 1) / sizeof(flag_basic_type) - 1; - } - // number of elements of type T (including padding) - std::size_t buffer_size() const noexcept - { - return ((flag_offset() + 1) * sizeof(flag_basic_type) + m_T_size - 1) / m_T_size; - } - // pointer to flag location for a given buffer - void* flag_ptr(void* ptr) const noexcept - { - return (void*)((char*)ptr + flag_offset() * sizeof(flag_basic_type)); - } -}; + protected: + // index of flag in buffer (in units of flag_basic_type) + std::size_t flag_offset() const noexcept + { + return (m_size * m_T_size + 2 * sizeof(flag_basic_type) - 1) / sizeof(flag_basic_type) - + 1; + } + // number of elements of type T (including padding) + std::size_t buffer_size() const noexcept + { + return ((flag_offset() + 1) * sizeof(flag_basic_type) + m_T_size - 1) / m_T_size; + } + // pointer to flag location for a given buffer + void* flag_ptr(void* ptr) const noexcept + { + return (void*) ((char*) ptr + flag_offset() * sizeof(flag_basic_type)); + } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/mpi/communicator.hpp b/src/mpi/communicator.hpp index 0022b157..231cb94f 100644 --- a/src/mpi/communicator.hpp +++ b/src/mpi/communicator.hpp @@ -17,112 +17,111 @@ #include #include -namespace oomph -{ -class communicator_impl : public 
communicator_base -{ - public: - context_impl* m_context; - request_queue m_send_reqs; - request_queue m_recv_reqs; - - communicator_impl(context_impl* ctxt) - : communicator_base(ctxt) - , m_context(ctxt) +namespace oomph { + class communicator_impl : public communicator_base { - } - - auto& get_heap() noexcept { return m_context->get_heap(); } + public: + context_impl* m_context; + request_queue m_send_reqs; + request_queue m_recv_reqs; - mpi_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag) - { - MPI_Request r; - const_device_guard dg(ptr); - OOMPH_CHECK_MPI_RESULT(MPI_Isend(dg.data(), size, MPI_BYTE, dst, tag, mpi_comm(), &r)); - return {r}; - } + communicator_impl(context_impl* ctxt) + : communicator_base(ctxt) + , m_context(ctxt) + { + } - mpi_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - tag_type tag) - { - MPI_Request r; - device_guard dg(ptr); - OOMPH_CHECK_MPI_RESULT(MPI_Irecv(dg.data(), size, MPI_BYTE, src, tag, mpi_comm(), &r)); - return {r}; - } + auto& get_heap() noexcept { return m_context->get_heap(); } - send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) - { - auto req = send(ptr, size, dst, tag); - if (!has_reached_recursion_depth() && req.is_ready()) + mpi_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, + rank_type dst, tag_type tag) { - auto inc = recursion(); - cb(dst, tag); - return {}; + MPI_Request r; + const_device_guard dg(ptr); + OOMPH_CHECK_MPI_RESULT(MPI_Isend(dg.data(), size, MPI_BYTE, dst, tag, mpi_comm(), &r)); + return {r}; } - else + + mpi_request recv( + context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag) { - auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, - std::move(cb), req); - s->create_self_ref(); - m_send_reqs.enqueue(s.get()); - return {std::move(s)}; + MPI_Request r; + device_guard dg(ptr); + OOMPH_CHECK_MPI_RESULT(MPI_Irecv(dg.data(), size, MPI_BYTE, src, tag, mpi_comm(), &r)); + return {r}; } - } - recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) - { - auto req = recv(ptr, size, src, tag); - if (!has_reached_recursion_depth() && req.is_ready()) + send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, + rank_type dst, tag_type tag, util::unique_function&& cb, + std::size_t* scheduled) { - auto inc = recursion(); - cb(src, tag); - return {}; + auto req = send(ptr, size, dst, tag); + if (!has_reached_recursion_depth() && req.is_ready()) + { + auto inc = recursion(); + cb(dst, tag); + return {}; + } + else + { + auto s = m_req_state_factory.make( + m_context, this, scheduled, dst, tag, std::move(cb), req); + s->create_self_ref(); + m_send_reqs.enqueue(s.get()); + return {std::move(s)}; + } } - else + + recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, + tag_type tag, util::unique_function&& cb, + std::size_t* scheduled) { - auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, - std::move(cb), req); - s->create_self_ref(); - m_recv_reqs.enqueue(s.get()); - return {std::move(s)}; + auto req = recv(ptr, size, src, tag); + if (!has_reached_recursion_depth() && req.is_ready()) + { + auto inc = recursion(); + cb(src, tag); + return {}; + } + else + { + auto s = 
m_req_state_factory.make( + m_context, this, scheduled, src, tag, std::move(cb), req); + s->create_self_ref(); + m_recv_reqs.enqueue(s.get()); + return {std::move(s)}; + } } - } - shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, - rank_type src, tag_type tag, util::unique_function&& cb, - std::atomic* scheduled) - { - auto req = recv(ptr, size, src, tag); - if (!m_context->has_reached_recursion_depth() && req.is_ready()) + shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, + rank_type src, tag_type tag, util::unique_function&& cb, + std::atomic* scheduled) { - auto inc = m_context->recursion(); - cb(src, tag); - return {}; + auto req = recv(ptr, size, src, tag); + if (!m_context->has_reached_recursion_depth() && req.is_ready()) + { + auto inc = m_context->recursion(); + cb(src, tag); + return {}; + } + else + { + auto s = std::make_shared( + m_context, this, scheduled, src, tag, std::move(cb), req); + s->create_self_ref(); + m_context->m_req_queue.enqueue(s.get()); + return {std::move(s)}; + } } - else + + void progress() { - auto s = std::make_shared(m_context, this, scheduled, src, - tag, std::move(cb), req); - s->create_self_ref(); - m_context->m_req_queue.enqueue(s.get()); - return {std::move(s)}; + m_send_reqs.progress(); + m_recv_reqs.progress(); + m_context->progress(); } - } - - void progress() - { - m_send_reqs.progress(); - m_recv_reqs.progress(); - m_context->progress(); - } - bool cancel_recv(detail::request_state* s) { return m_recv_reqs.cancel(s); } -}; + bool cancel_recv(detail::request_state* s) { return m_recv_reqs.cancel(s); } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/mpi/context.cpp b/src/mpi/context.cpp index 9f3273d4..a0e41fd1 100644 --- a/src/mpi/context.cpp +++ b/src/mpi/context.cpp @@ -9,26 +9,21 @@ */ // paths relative to backend -#include #include +#include -namespace oomph -{ -communicator_impl* -context_impl::get_communicator() -{ - auto comm = new communicator_impl{this}; - m_comms_set.insert(comm); - return comm; -} - -const char *context_impl::get_transport_option(const std::string &opt) { - if (opt == "name") { - return "mpi"; +namespace oomph { + communicator_impl* context_impl::get_communicator() + { + auto comm = new communicator_impl{this}; + m_comms_set.insert(comm); + return comm; } - else { - return "unspecified"; + + char const* context_impl::get_transport_option(std::string const& opt) + { + if (opt == "name") { return "mpi"; } + else { return "unspecified"; } } -} -} // namespace oomph +} // namespace oomph diff --git a/src/mpi/context.hpp b/src/mpi/context.hpp index e31bc73e..b3549f04 100644 --- a/src/mpi/context.hpp +++ b/src/mpi/context.hpp @@ -13,92 +13,89 @@ // paths relative to backend #include <../context_base.hpp> -#include #include +#include -namespace oomph -{ -class context_impl : public context_base -{ - public: - using region_type = region; - using device_region_type = region; - using heap_type = hwmalloc::heap; - - private: - heap_type m_heap; - //rma_context m_rma_context; - unsigned int m_n_tag_bits; - - public: - shared_request_queue m_req_queue; - - public: - context_impl(MPI_Comm comm, bool thread_safe, bool message_pool_never_free, - std::size_t message_pool_reserve) - : context_base(comm, thread_safe) - , m_heap{this, message_pool_never_free, message_pool_reserve} - //, m_rma_context{m_mpi_comm} +namespace oomph { + class context_impl : public context_base { - // get largest allowed tag value - int flag; - int* tag_ub; - 
MPI_Comm_get_attr(this->get_comm(), MPI_TAG_UB, &tag_ub, &flag); - unsigned int max_tag = flag ? *tag_ub : 32767; - - // compute bit mask - unsigned long tmp = max_tag; - unsigned long mask = 1u; - m_n_tag_bits = 0; - while (tmp > 0) + public: + using region_type = region; + using device_region_type = region; + using heap_type = hwmalloc::heap; + + private: + heap_type m_heap; + //rma_context m_rma_context; + unsigned int m_n_tag_bits; + + public: + shared_request_queue m_req_queue; + + public: + context_impl(MPI_Comm comm, bool thread_safe, bool message_pool_never_free, + std::size_t message_pool_reserve) + : context_base(comm, thread_safe) + , m_heap{this, message_pool_never_free, message_pool_reserve} + //, m_rma_context{m_mpi_comm} { - ++m_n_tag_bits; - tmp >>= 1; - mask <<= 1; + // get largest allowed tag value + int flag; + int* tag_ub; + MPI_Comm_get_attr(this->get_comm(), MPI_TAG_UB, &tag_ub, &flag); + unsigned int max_tag = flag ? *tag_ub : 32767; + + // compute bit mask + unsigned long tmp = max_tag; + unsigned long mask = 1u; + m_n_tag_bits = 0; + while (tmp > 0) + { + ++m_n_tag_bits; + tmp >>= 1; + mask <<= 1; + } + mask -= 1; + + // If bit mask is larger than max tag value, then we have some strange upper bound which is + // not at a power of 2 boundary and we reduce the maximum to the next lower power of 2. + if (mask > max_tag) --m_n_tag_bits; } - mask -= 1; - - // If bit mask is larger than max tag value, then we have some strange upper bound which is - // not at a power of 2 boundary and we reduce the maximum to the next lower power of 2. - if (mask > max_tag) --m_n_tag_bits; - } - context_impl(context_impl const&) = delete; - context_impl(context_impl&&) = delete; + context_impl(context_impl const&) = delete; + context_impl(context_impl&&) = delete; - region make_region(void* ptr) const { return {ptr}; } + region make_region(void* ptr) const { return {ptr}; } - auto& get_heap() noexcept { return m_heap; } + auto& get_heap() noexcept { return m_heap; } - //auto get_window() const noexcept { return m_rma_context.get_window(); } - //auto& get_rma_heap() noexcept { return m_rma_context.get_heap(); } - //void lock(rank_type r) { m_rma_context.lock(r); } + //auto get_window() const noexcept { return m_rma_context.get_window(); } + //auto& get_rma_heap() noexcept { return m_rma_context.get_heap(); } + //void lock(rank_type r) { m_rma_context.lock(r); } - communicator_impl* get_communicator(); + communicator_impl* get_communicator(); - void progress() { m_req_queue.progress(); } + void progress() { m_req_queue.progress(); } - bool cancel_recv(detail::shared_request_state* r) { return m_req_queue.cancel(r); } + bool cancel_recv(detail::shared_request_state* r) { return m_req_queue.cancel(r); } - unsigned int num_tag_bits() const noexcept { return m_n_tag_bits; } + unsigned int num_tag_bits() const noexcept { return m_n_tag_bits; } - const char* get_transport_option(const std::string& opt); -}; + char const* get_transport_option(std::string const& opt); + }; -template<> -inline region -register_memory(context_impl& c, void* ptr, std::size_t) -{ - return c.make_region(ptr); -} + template <> + inline region register_memory(context_impl& c, void* ptr, std::size_t) + { + return c.make_region(ptr); + } #if OOMPH_ENABLE_DEVICE -template<> -inline region -register_device_memory(context_impl& c, int, void* ptr, std::size_t) -{ - return c.make_region(ptr); -} + template <> + inline region register_device_memory(context_impl& c, int, void* ptr, std::size_t) + { + return c.make_region(ptr); + 
} #endif -} // namespace oomph +} // namespace oomph diff --git a/src/mpi/handle.hpp b/src/mpi/handle.hpp index 179c2686..f06b1ccf 100644 --- a/src/mpi/handle.hpp +++ b/src/mpi/handle.hpp @@ -11,21 +11,20 @@ #include -namespace oomph -{ -struct handle -{ - using key_type = MPI_Aint; +namespace oomph { + struct handle + { + using key_type = MPI_Aint; - void* m_ptr; - std::size_t m_size; + void* m_ptr; + std::size_t m_size; - key_type get_remote_key() const noexcept - { - MPI_Aint address; - OOMPH_CHECK_MPI_RESULT_NOEXCEPT(MPI_Get_address(m_ptr, &address)); - return address; - //return ((char*)m_ptr - MPI_BOTTOM); - } -}; -} // namespace oomph + key_type get_remote_key() const noexcept + { + MPI_Aint address; + OOMPH_CHECK_MPI_RESULT_NOEXCEPT(MPI_Get_address(m_ptr, &address)); + return address; + //return ((char*)m_ptr - MPI_BOTTOM); + } + }; +} // namespace oomph diff --git a/src/mpi/lock_cache.hpp b/src/mpi/lock_cache.hpp index 1b61f46f..f9c5c549 100644 --- a/src/mpi/lock_cache.hpp +++ b/src/mpi/lock_cache.hpp @@ -9,45 +9,44 @@ */ #pragma once -#include #include +#include -#include #include +#include -namespace oomph -{ -class lock_cache -{ - private: - MPI_Win m_win; - std::set m_ranks; - std::mutex m_mutex; - - public: - lock_cache(MPI_Win win) noexcept - : m_win(win) +namespace oomph { + class lock_cache { - } - - lock_cache(lock_cache const&) = delete; + private: + MPI_Win m_win; + std::set m_ranks; + std::mutex m_mutex; + + public: + lock_cache(MPI_Win win) noexcept + : m_win(win) + { + } - ~lock_cache() - { - for (auto r : m_ranks) MPI_Win_unlock(r, m_win); - } + lock_cache(lock_cache const&) = delete; - void lock(rank_type r) - { - std::lock_guard l(m_mutex); + ~lock_cache() + { + for (auto r : m_ranks) MPI_Win_unlock(r, m_win); + } - auto it = m_ranks.find(r); - if (it == m_ranks.end()) + void lock(rank_type r) { - m_ranks.insert(r); - OOMPH_CHECK_MPI_RESULT(MPI_Win_lock(MPI_LOCK_SHARED, r, 0, m_win)); + std::lock_guard l(m_mutex); + + auto it = m_ranks.find(r); + if (it == m_ranks.end()) + { + m_ranks.insert(r); + OOMPH_CHECK_MPI_RESULT(MPI_Win_lock(MPI_LOCK_SHARED, r, 0, m_win)); + } } - } -}; + }; -} // namespace oomph +} // namespace oomph diff --git a/src/mpi/recv_channel.hpp b/src/mpi/recv_channel.hpp index 87b0c269..13d0cbf7 100644 --- a/src/mpi/recv_channel.hpp +++ b/src/mpi/recv_channel.hpp @@ -15,63 +15,55 @@ // paths relative to backend #include -namespace oomph -{ -class recv_channel_impl : public channel_base -{ - using base = channel_base; - using flag_basic_type = typename base::flag_basic_type; - using flag_type = typename base::flag_type; - using pointer = typename base::pointer; - using handle_type = typename base::handle_type; - using key_type = typename base::key_type; +namespace oomph { + class recv_channel_impl : public channel_base + { + using base = channel_base; + using flag_basic_type = typename base::flag_basic_type; + using flag_type = typename base::flag_type; + using pointer = typename base::pointer; + using handle_type = typename base::handle_type; + using key_type = typename base::key_type; - private: - communicator::impl* m_comm; - pointer m_buffer; - key_type m_local_key; + private: + communicator::impl* m_comm; + pointer m_buffer; + key_type m_local_key; - public: - recv_channel_impl(communicator::impl* impl_, std::size_t size, std::size_t T_size, - communicator::rank_type src, communicator::tag_type tag, std::size_t levels) - : base(size, T_size, src, tag, levels) - , m_comm(impl_) - , m_buffer{m_comm->get_heap().allocate( - levels * 
base::buffer_size() * T_size, hwmalloc::numa().local_node())} - , m_local_key{m_buffer.handle().get_remote_key()} - { - m_comm->m_context->lock(src); - OOMPH_CHECK_MPI_RESULT(MPI_Isend(&m_local_key, sizeof(key_type), MPI_BYTE, - base::m_remote_rank, base::m_tag, m_comm->get_comm(), &(base::m_init_req))); - } - recv_channel_impl(recv_channel_impl const&) = delete; - recv_channel_impl(recv_channel_impl&&) = delete; + public: + recv_channel_impl(communicator::impl* impl_, std::size_t size, std::size_t T_size, + communicator::rank_type src, communicator::tag_type tag, std::size_t levels) + : base(size, T_size, src, tag, levels) + , m_comm(impl_) + , m_buffer{m_comm->get_heap().allocate( + levels * base::buffer_size() * T_size, hwmalloc::numa().local_node())} + , m_local_key{m_buffer.handle().get_remote_key()} + { + m_comm->m_context->lock(src); + OOMPH_CHECK_MPI_RESULT(MPI_Isend(&m_local_key, sizeof(key_type), MPI_BYTE, + base::m_remote_rank, base::m_tag, m_comm->get_comm(), &(base::m_init_req))); + } + recv_channel_impl(recv_channel_impl const&) = delete; + recv_channel_impl(recv_channel_impl&&) = delete; - ~recv_channel_impl() - { - } - - //void connect() {} - - std::size_t capacity() - { - return base::m_capacity; - } + ~recv_channel_impl() {} - void* get(std::size_t& index) - { - index = 0; - return nullptr; - } + //void connect() {} + + std::size_t capacity() { return base::m_capacity; } - void release(std::size_t index) + void* get(std::size_t& index) + { + index = 0; + return nullptr; + } + + void release(std::size_t index) {} + }; + + void release_recv_channel_buffer(recv_channel_impl* rc, std::size_t index) { + rc->release(index); } -}; - -void release_recv_channel_buffer(recv_channel_impl* rc, std::size_t index) -{ - rc->release(index); -} -} // namespace oomph +} // namespace oomph diff --git a/src/mpi/region.hpp b/src/mpi/region.hpp index 78154a00..7a18eb73 100644 --- a/src/mpi/region.hpp +++ b/src/mpi/region.hpp @@ -12,73 +12,72 @@ // paths relative to backend #include -namespace oomph -{ -class region -{ - public: - using handle_type = handle; - - private: - void* m_ptr; - - public: - region(void* ptr) - : m_ptr{ptr} +namespace oomph { + class region { - } + public: + using handle_type = handle; - region(region const&) = delete; + private: + void* m_ptr; - region(region&& r) noexcept - : m_ptr{std::exchange(r.m_ptr, nullptr)} - { - } + public: + region(void* ptr) + : m_ptr{ptr} + { + } - // get a handle to some portion of the region - handle_type get_handle(std::size_t offset, std::size_t size) - { - return {(void*)((char*)m_ptr + offset), size}; - } -}; + region(region const&) = delete; -class rma_region -{ - public: - using handle_type = handle; + region(region&& r) noexcept + : m_ptr{std::exchange(r.m_ptr, nullptr)} + { + } - private: - MPI_Comm m_comm; - MPI_Win m_win; - void* m_ptr; + // get a handle to some portion of the region + handle_type get_handle(std::size_t offset, std::size_t size) + { + return {(void*) ((char*) m_ptr + offset), size}; + } + }; - public: - rma_region(MPI_Comm comm, MPI_Win win, void* ptr, std::size_t size) - : m_comm{comm} - , m_win{win} - , m_ptr{ptr} + class rma_region { - OOMPH_CHECK_MPI_RESULT(MPI_Win_attach(m_win, ptr, size)); - } + public: + using handle_type = handle; - rma_region(rma_region const&) = delete; + private: + MPI_Comm m_comm; + MPI_Win m_win; + void* m_ptr; - rma_region(rma_region&& r) noexcept - : m_comm{r.m_comm} - , m_win{r.m_win} - , m_ptr{std::exchange(r.m_ptr, nullptr)} - { - } + public: + rma_region(MPI_Comm comm, MPI_Win 
win, void* ptr, std::size_t size) + : m_comm{comm} + , m_win{win} + , m_ptr{ptr} + { + OOMPH_CHECK_MPI_RESULT(MPI_Win_attach(m_win, ptr, size)); + } - ~rma_region() - { - if (m_ptr) MPI_Win_detach(m_win, m_ptr); - } + rma_region(rma_region const&) = delete; - // get a handle to some portion of the region - handle_type get_handle(std::size_t offset, std::size_t size) - { - return {(void*)((char*)m_ptr + offset), size}; - } -}; -} // namespace oomph + rma_region(rma_region&& r) noexcept + : m_comm{r.m_comm} + , m_win{r.m_win} + , m_ptr{std::exchange(r.m_ptr, nullptr)} + { + } + + ~rma_region() + { + if (m_ptr) MPI_Win_detach(m_win, m_ptr); + } + + // get a handle to some portion of the region + handle_type get_handle(std::size_t offset, std::size_t size) + { + return {(void*) ((char*) m_ptr + offset), size}; + } + }; +} // namespace oomph diff --git a/src/mpi/request.hpp b/src/mpi/request.hpp index a126143b..39642c66 100644 --- a/src/mpi/request.hpp +++ b/src/mpi/request.hpp @@ -11,27 +11,26 @@ #include -namespace oomph -{ -struct mpi_request -{ - MPI_Request m_req; - - bool is_ready() +namespace oomph { + struct mpi_request { - int flag; - OOMPH_CHECK_MPI_RESULT(MPI_Test(&m_req, &flag, MPI_STATUS_IGNORE)); - return flag; - } + MPI_Request m_req; - bool cancel() - { - OOMPH_CHECK_MPI_RESULT(MPI_Cancel(&m_req)); - MPI_Status st; - OOMPH_CHECK_MPI_RESULT(MPI_Wait(&m_req, &st)); - int flag = false; - OOMPH_CHECK_MPI_RESULT(MPI_Test_cancelled(&st, &flag)); - return flag; - } -}; -} // namespace oomph + bool is_ready() + { + int flag; + OOMPH_CHECK_MPI_RESULT(MPI_Test(&m_req, &flag, MPI_STATUS_IGNORE)); + return flag; + } + + bool cancel() + { + OOMPH_CHECK_MPI_RESULT(MPI_Cancel(&m_req)); + MPI_Status st; + OOMPH_CHECK_MPI_RESULT(MPI_Wait(&m_req, &st)); + int flag = false; + OOMPH_CHECK_MPI_RESULT(MPI_Test_cancelled(&st, &flag)); + return flag; + } + }; +} // namespace oomph diff --git a/src/mpi/request_queue.hpp b/src/mpi/request_queue.hpp index bc44e415..ee032c32 100644 --- a/src/mpi/request_queue.hpp +++ b/src/mpi/request_queue.hpp @@ -9,206 +9,197 @@ */ #pragma once -#include #include +#include // paths relative to backend #include -namespace oomph -{ - -class request_queue -{ - private: - using element_type = detail::request_state; - using queue_type = std::vector; - - private: // members - queue_type m_queue; - queue_type m_ready_queue; - bool in_progress = false; - std::vector reqs; - std::vector indices; +namespace oomph { - public: // ctors - request_queue() + class request_queue { - m_queue.reserve(256); - m_ready_queue.reserve(256); - } - - public: // member functions - std::size_t size() const noexcept { return m_queue.size(); } - - void enqueue(element_type* e) - { - e->m_index = m_queue.size(); - m_queue.push_back(e); - } + private: + using element_type = detail::request_state; + using queue_type = std::vector; + + private: // members + queue_type m_queue; + queue_type m_ready_queue; + bool in_progress = false; + std::vector reqs; + std::vector indices; + + public: // ctors + request_queue() + { + m_queue.reserve(256); + m_ready_queue.reserve(256); + } - int progress() - { - if (in_progress) return 0; - in_progress = true; + public: // member functions + std::size_t size() const noexcept { return m_queue.size(); } - const auto qs = size(); - if (qs == 0) + void enqueue(element_type* e) { - in_progress = false; - return 0; + e->m_index = m_queue.size(); + m_queue.push_back(e); } - m_ready_queue.clear(); + int progress() + { + if (in_progress) return 0; + in_progress = true; - 
m_ready_queue.reserve(qs); - //reqs.resize(0); - reqs.clear(); - reqs.reserve(qs); - indices.resize(qs + 1); + auto const qs = size(); + if (qs == 0) + { + in_progress = false; + return 0; + } - std::transform(m_queue.begin(), m_queue.end(), std::back_inserter(reqs), - [](auto e) { return e->m_req.m_req; }); + m_ready_queue.clear(); - int outcount; - OOMPH_CHECK_MPI_RESULT( - MPI_Testsome(qs, reqs.data(), &outcount, indices.data(), MPI_STATUSES_IGNORE)); + m_ready_queue.reserve(qs); + //reqs.resize(0); + reqs.clear(); + reqs.reserve(qs); + indices.resize(qs + 1); - if (outcount == 0) - { - in_progress = false; - return 0; - } + std::transform(m_queue.begin(), m_queue.end(), std::back_inserter(reqs), + [](auto e) { return e->m_req.m_req; }); - indices[outcount] = qs; + int outcount; + OOMPH_CHECK_MPI_RESULT( + MPI_Testsome(qs, reqs.data(), &outcount, indices.data(), MPI_STATUSES_IGNORE)); - std::size_t k = 0; - std::size_t j = 0; - for (std::size_t i = 0; i < qs; ++i) - { - auto e = m_queue[i]; - if ((int)i == indices[k]) + if (outcount == 0) { - m_ready_queue.push_back(e); - ++k; + in_progress = false; + return 0; } - else if (i > j) + + indices[outcount] = qs; + + std::size_t k = 0; + std::size_t j = 0; + for (std::size_t i = 0; i < qs; ++i) { - e->m_index = j; - m_queue[j] = e; - ++j; + auto e = m_queue[i]; + if ((int) i == indices[k]) + { + m_ready_queue.push_back(e); + ++k; + } + else if (i > j) + { + e->m_index = j; + m_queue[j] = e; + ++j; + } + else { ++j; } } - else + m_queue.erase(m_queue.end() - m_ready_queue.size(), m_queue.end()); + + int completed = m_ready_queue.size(); + for (auto e : m_ready_queue) { - ++j; + auto ptr = e->release_self_ref(); + e->invoke_cb(); } - } - m_queue.erase(m_queue.end() - m_ready_queue.size(), m_queue.end()); - int completed = m_ready_queue.size(); - for (auto e : m_ready_queue) - { - auto ptr = e->release_self_ref(); - e->invoke_cb(); + in_progress = false; + return completed; } - in_progress = false; - return completed; - } - - bool cancel(element_type* e) - { - auto const index = e->m_index; - if (m_queue[index]->m_req.cancel()) + bool cancel(element_type* e) { - auto ptr = e->release_self_ref(); - e->set_canceled(); - if (index + 1 < m_queue.size()) + auto const index = e->m_index; + if (m_queue[index]->m_req.cancel()) { - m_queue[index] = m_queue.back(); - m_queue[index]->m_index = index; + auto ptr = e->release_self_ref(); + e->set_canceled(); + if (index + 1 < m_queue.size()) + { + m_queue[index] = m_queue.back(); + m_queue[index]->m_index = index; + } + m_queue.pop_back(); + return true; } - m_queue.pop_back(); - return true; + else + return false; } - else - return false; - } -}; - -class shared_request_queue -{ - private: - using element_type = detail::shared_request_state; - using queue_type = boost::lockfree::queue, - boost::lockfree::allocator>>; - - private: // members - queue_type m_queue; - std::atomic m_size; - - public: // ctors - shared_request_queue() - : m_queue(256) - , m_size(0) - { - } - - public: // member functions - std::size_t size() const noexcept { return m_size.load(); } + }; - void enqueue(element_type* e) + class shared_request_queue { - m_queue.push(e); - ++m_size; - } + private: + using element_type = detail::shared_request_state; + using queue_type = boost::lockfree::queue, boost::lockfree::allocator>>; + + private: // members + queue_type m_queue; + std::atomic m_size; + + public: // ctors + shared_request_queue() + : m_queue(256) + , m_size(0) + { + } - int progress() - { - static thread_local bool in_progress 
= false; - static thread_local std::vector m_local_queue; - int found = 0; + public: // member functions + std::size_t size() const noexcept { return m_size.load(); } - if (in_progress) return 0; - in_progress = true; + void enqueue(element_type* e) + { + m_queue.push(e); + ++m_size; + } - element_type* e; - while (m_queue.pop(e)) + int progress() { - if (e->m_req.is_ready()) + static thread_local bool in_progress = false; + static thread_local std::vector m_local_queue; + int found = 0; + + if (in_progress) return 0; + in_progress = true; + + element_type* e; + while (m_queue.pop(e)) { - found = 1; - break; + if (e->m_req.is_ready()) + { + found = 1; + break; + } + else { m_local_queue.push_back(e); } } - else + + for (auto x : m_local_queue) m_queue.push(x); + m_local_queue.clear(); + + if (found) { - m_local_queue.push_back(e); + auto ptr = e->release_self_ref(); + e->invoke_cb(); + --m_size; } - } - - for (auto x : m_local_queue) m_queue.push(x); - m_local_queue.clear(); - if (found) - { - auto ptr = e->release_self_ref(); - e->invoke_cb(); - --m_size; + in_progress = false; + return found; } - in_progress = false; - return found; - } - - bool cancel(element_type* e) - { - static thread_local std::vector m_local_queue; - m_local_queue.clear(); + bool cancel(element_type* e) + { + static thread_local std::vector m_local_queue; + m_local_queue.clear(); - bool canceled = false; - m_queue.consume_all( - [q = &m_local_queue, e, &canceled](element_type* x) - { + bool canceled = false; + m_queue.consume_all([q = &m_local_queue, e, &canceled](element_type* x) { if (e == x) { if (e->m_req.cancel()) @@ -224,10 +215,10 @@ class shared_request_queue q->push_back(x); }); - for (auto x : m_local_queue) m_queue.push(x); + for (auto x : m_local_queue) m_queue.push(x); - return canceled; - } -}; + return canceled; + } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/mpi/request_state.hpp b/src/mpi/request_state.hpp index da69eb95..140fd370 100644 --- a/src/mpi/request_state.hpp +++ b/src/mpi/request_state.hpp @@ -15,83 +15,79 @@ #include <../request_state_base.hpp> #include -namespace oomph -{ -namespace detail -{ -struct request_state -: public util::enable_shared_from_this -, public request_state_base -{ - using base = request_state_base; - using shared_ptr_t = util::unsafe_shared_ptr; - - mpi_request m_req; - shared_ptr_t m_self_ptr; - std::size_t m_index; - - request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, std::size_t* scheduled, - rank_type rank, tag_type tag, cb_type&& cb, mpi_request m) - : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} - , m_req{m} +namespace oomph { namespace detail { + struct request_state + : public util::enable_shared_from_this + , public request_state_base { - } - - void progress(); - - bool cancel(); - - void create_self_ref() + using base = request_state_base; + using shared_ptr_t = util::unsafe_shared_ptr; + + mpi_request m_req; + shared_ptr_t m_self_ptr; + std::size_t m_index; + + request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, + std::size_t* scheduled, rank_type rank, tag_type tag, cb_type&& cb, mpi_request m) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_req{m} + { + } + + void progress(); + + bool cancel(); + + void create_self_ref() + { + // create a self-reference cycle!! 
+ // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool) m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } + }; + + struct shared_request_state + : public std::enable_shared_from_this + , public request_state_base { - // create a self-reference cycle!! - // this is useful if we only keep a raw pointer around internally, which still is supposed - // to keep the object alive - m_self_ptr = shared_from_this(); - } - - shared_ptr_t release_self_ref() noexcept - { - assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); - return std::move(m_self_ptr); - } -}; - -struct shared_request_state -: public std::enable_shared_from_this -, public request_state_base -{ - using base = request_state_base; - using shared_ptr_t = std::shared_ptr; - - mpi_request m_req; - shared_ptr_t m_self_ptr; - - shared_request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, - std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb, - mpi_request m) - : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} - , m_req{m} - { - } - - void progress(); - - bool cancel(); - - void create_self_ref() - { - // create a self-reference cycle!! - // this is useful if we only keep a raw pointer around internally, which still is supposed - // to keep the object alive - m_self_ptr = shared_from_this(); - } - - shared_ptr_t release_self_ref() noexcept - { - assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); - return std::move(m_self_ptr); - } -}; - -} // namespace detail -} // namespace oomph + using base = request_state_base; + using shared_ptr_t = std::shared_ptr; + + mpi_request m_req; + shared_ptr_t m_self_ptr; + + shared_request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, + std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb, + mpi_request m) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_req{m} + { + } + + void progress(); + + bool cancel(); + + void create_self_ref() + { + // create a self-reference cycle!! 
+ // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool) m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } + }; + +}} // namespace oomph::detail diff --git a/src/mpi/rma_context.hpp b/src/mpi/rma_context.hpp index aec295f0..c9a20052 100644 --- a/src/mpi/rma_context.hpp +++ b/src/mpi/rma_context.hpp @@ -9,76 +9,74 @@ */ #pragma once -#include #include +#include #include // paths relative to backend -#include #include +#include -namespace oomph -{ -class rma_context -{ - public: - using region_type = rma_region; - using device_region_type = rma_region; - using heap_type = hwmalloc::heap; - - private: - struct mpi_win_holder +namespace oomph { + class rma_context { - MPI_Win m; - ~mpi_win_holder() { MPI_Win_free(&m); } - }; + public: + using region_type = rma_region; + using device_region_type = rma_region; + using heap_type = hwmalloc::heap; - private: - MPI_Comm m_mpi_comm; - mpi_win_holder m_win; - heap_type m_heap; - std::unique_ptr m_lock_cache; + private: + struct mpi_win_holder + { + MPI_Win m; + ~mpi_win_holder() { MPI_Win_free(&m); } + }; - public: - rma_context(MPI_Comm comm) - : m_mpi_comm{comm} - , m_heap{this} - { - MPI_Info info; - OOMPH_CHECK_MPI_RESULT(MPI_Info_create(&info)); - OOMPH_CHECK_MPI_RESULT(MPI_Info_set(info, "no_locks", "false")); - OOMPH_CHECK_MPI_RESULT(MPI_Win_create_dynamic(info, m_mpi_comm, &(m_win.m))); - MPI_Info_free(&info); - OOMPH_CHECK_MPI_RESULT(MPI_Win_fence(0, m_win.m)); - m_lock_cache = std::make_unique(m_win.m); - } - rma_context(context_impl const&) = delete; - rma_context(context_impl&&) = delete; + private: + MPI_Comm m_mpi_comm; + mpi_win_holder m_win; + heap_type m_heap; + std::unique_ptr m_lock_cache; - rma_region make_region(void* ptr, std::size_t size) const - { - return {m_mpi_comm, m_win.m, ptr, size}; - } + public: + rma_context(MPI_Comm comm) + : m_mpi_comm{comm} + , m_heap{this} + { + MPI_Info info; + OOMPH_CHECK_MPI_RESULT(MPI_Info_create(&info)); + OOMPH_CHECK_MPI_RESULT(MPI_Info_set(info, "no_locks", "false")); + OOMPH_CHECK_MPI_RESULT(MPI_Win_create_dynamic(info, m_mpi_comm, &(m_win.m))); + MPI_Info_free(&info); + OOMPH_CHECK_MPI_RESULT(MPI_Win_fence(0, m_win.m)); + m_lock_cache = std::make_unique(m_win.m); + } + rma_context(context_impl const&) = delete; + rma_context(context_impl&&) = delete; - auto get_window() const noexcept { return m_win.m; } - auto& get_heap() noexcept { return m_heap; } - void lock(rank_type r) { m_lock_cache->lock(r); } -}; + rma_region make_region(void* ptr, std::size_t size) const + { + return {m_mpi_comm, m_win.m, ptr, size}; + } + + auto get_window() const noexcept { return m_win.m; } + auto& get_heap() noexcept { return m_heap; } + void lock(rank_type r) { m_lock_cache->lock(r); } + }; -template<> -inline rma_region -register_memory(rma_context& c, void* ptr, std::size_t size) -{ - return c.make_region(ptr, size); -} + template <> + inline rma_region register_memory(rma_context& c, void* ptr, std::size_t size) + { + return c.make_region(ptr, size); + } #if OOMPH_ENABLE_DEVICE -template<> -inline rma_region -register_device_memory(rma_context& c, int, void* ptr, std::size_t size) -{ - return c.make_region(ptr, size); -} + template <> + inline rma_region + register_device_memory(rma_context& c, int, void* ptr, std::size_t size) + { + return c.make_region(ptr, size); + } #endif -} // namespace oomph +} 
// namespace oomph diff --git a/src/mpi/send_channel.hpp b/src/mpi/send_channel.hpp index caa95b74..1e15e2e6 100644 --- a/src/mpi/send_channel.hpp +++ b/src/mpi/send_channel.hpp @@ -14,33 +14,31 @@ // paths relative to backend #include -namespace oomph -{ -class send_channel_impl : public channel_base -{ - using base = channel_base; - using flag_basic_type = typename base::flag_basic_type; - using flag_type = typename base::flag_type; - using pointer = typename base::pointer; - using handle_type = typename base::handle_type; - using key_type = typename base::key_type; - - communicator::impl* m_comm; - key_type m_remote_key; - - public: - send_channel_impl(communicator::impl* impl_, std::size_t size, std::size_t T_size, - communicator::rank_type dst, communicator::tag_type tag, std::size_t levels) - : base(size, T_size, dst, tag, levels) - , m_comm(impl_) +namespace oomph { + class send_channel_impl : public channel_base { - m_comm->m_context->lock(dst); - OOMPH_CHECK_MPI_RESULT(MPI_Irecv(&m_remote_key, sizeof(key_type), MPI_BYTE, - base::m_remote_rank, base::m_tag, m_comm->get_comm(), &(base::m_init_req))); - } - send_channel_impl(send_channel_impl const&) = delete; - send_channel_impl(send_channel_impl&&) = delete; + using base = channel_base; + using flag_basic_type = typename base::flag_basic_type; + using flag_type = typename base::flag_type; + using pointer = typename base::pointer; + using handle_type = typename base::handle_type; + using key_type = typename base::key_type; + + communicator::impl* m_comm; + key_type m_remote_key; -}; + public: + send_channel_impl(communicator::impl* impl_, std::size_t size, std::size_t T_size, + communicator::rank_type dst, communicator::tag_type tag, std::size_t levels) + : base(size, T_size, dst, tag, levels) + , m_comm(impl_) + { + m_comm->m_context->lock(dst); + OOMPH_CHECK_MPI_RESULT(MPI_Irecv(&m_remote_key, sizeof(key_type), MPI_BYTE, + base::m_remote_rank, base::m_tag, m_comm->get_comm(), &(base::m_init_req))); + } + send_channel_impl(send_channel_impl const&) = delete; + send_channel_impl(send_channel_impl&&) = delete; + }; -} // namespace oomph +} // namespace oomph diff --git a/src/mpi_comm.hpp b/src/mpi_comm.hpp index cfc48ca6..106ae91d 100644 --- a/src/mpi_comm.hpp +++ b/src/mpi_comm.hpp @@ -12,39 +12,38 @@ #include #include -namespace oomph -{ -class mpi_comm -{ - private: - MPI_Comm m_comm; - rank_type m_rank; - rank_type m_size; - - public: - mpi_comm(MPI_Comm comm) - : m_comm{comm} - , m_rank{[](MPI_Comm c) { - int r; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(c, &r)); - return r; - }(comm)} - , m_size{[](MPI_Comm c) { - int s; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(c, &s)); - return s; - }(comm)} +namespace oomph { + class mpi_comm { - } + private: + MPI_Comm m_comm; + rank_type m_rank; + rank_type m_size; + + public: + mpi_comm(MPI_Comm comm) + : m_comm{comm} + , m_rank{[](MPI_Comm c) { + int r; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(c, &r)); + return r; + }(comm)} + , m_size{[](MPI_Comm c) { + int s; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(c, &s)); + return s; + }(comm)} + { + } - mpi_comm(mpi_comm const&) = default; - mpi_comm& operator=(mpi_comm const&) = default; + mpi_comm(mpi_comm const&) = default; + mpi_comm& operator=(mpi_comm const&) = default; - rank_type rank() const noexcept { return m_rank; } - rank_type size() const noexcept { return m_size; } + rank_type rank() const noexcept { return m_rank; } + rank_type size() const noexcept { return m_size; } - operator MPI_Comm() const noexcept { return m_comm; } - MPI_Comm get() 
const noexcept { return m_comm; } -}; + operator MPI_Comm() const noexcept { return m_comm; } + MPI_Comm get() const noexcept { return m_comm; } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/rank_topology.hpp b/src/rank_topology.hpp index e357bc6c..1bdbcf06 100644 --- a/src/rank_topology.hpp +++ b/src/rank_topology.hpp @@ -9,48 +9,47 @@ */ #pragma once -#include -#include #include +#include +#include -namespace oomph -{ -/** @brief Class representing node (shared memory) topology. */ -class rank_topology -{ - public: // member types - using set_type = std::unordered_set; - using size_type = set_type::size_type; - - private: // members - MPI_Comm m_comm; - MPI_Comm m_shared_comm; - int m_rank; - std::unordered_set m_rank_set; - - public: // ctors - /** @brief construct from MPI communicator */ - rank_topology(MPI_Comm comm); - rank_topology(const rank_topology&) = default; - rank_topology(rank_topology&&) noexcept = default; - rank_topology& operator=(const rank_topology&) = default; - rank_topology& operator=(rank_topology&&) noexcept = default; - - public: // member functions - /** @brief return whether rank is located on this node */ - bool is_local(int rank) const noexcept { return m_rank_set.find(rank) != m_rank_set.end(); } - - /** @brief return number of ranks on this node */ - size_type local_size() const noexcept { return m_rank_set.size(); } - - /** @brief return ranks on this node */ - const set_type& local_ranks() const noexcept { return m_rank_set; } - - /** @brief return local rank number */ - int local_rank() const noexcept { return m_rank; } - - /** @brief return raw mpi communicator */ - auto mpi_comm() const noexcept { return m_comm; } -}; - -} //namespace oomph +namespace oomph { + /** @brief Class representing node (shared memory) topology. 
*/ + class rank_topology + { + public: // member types + using set_type = std::unordered_set; + using size_type = set_type::size_type; + + private: // members + MPI_Comm m_comm; + MPI_Comm m_shared_comm; + int m_rank; + std::unordered_set m_rank_set; + + public: // ctors + /** @brief construct from MPI communicator */ + rank_topology(MPI_Comm comm); + rank_topology(rank_topology const&) = default; + rank_topology(rank_topology&&) noexcept = default; + rank_topology& operator=(rank_topology const&) = default; + rank_topology& operator=(rank_topology&&) noexcept = default; + + public: // member functions + /** @brief return whether rank is located on this node */ + bool is_local(int rank) const noexcept { return m_rank_set.find(rank) != m_rank_set.end(); } + + /** @brief return number of ranks on this node */ + size_type local_size() const noexcept { return m_rank_set.size(); } + + /** @brief return ranks on this node */ + set_type const& local_ranks() const noexcept { return m_rank_set; } + + /** @brief return local rank number */ + int local_rank() const noexcept { return m_rank; } + + /** @brief return raw mpi communicator */ + auto mpi_comm() const noexcept { return m_comm; } + }; + +} //namespace oomph diff --git a/src/request.cpp b/src/request.cpp index 972650f3..0d39d333 100644 --- a/src/request.cpp +++ b/src/request.cpp @@ -11,158 +11,125 @@ #include // paths relative to backend -#include -#include #include <../message_buffer.hpp> #include <../util/heap_pimpl_src.hpp> +#include +#include OOMPH_INSTANTIATE_HEAP_PIMPL(oomph::context_impl) -namespace oomph -{ - -bool -send_request::is_ready() const noexcept -{ - if (!m) return true; - return m->is_ready(); -} - -bool -send_request::test() -{ - if (!m || m->is_ready()) return true; - m->progress(); - return m->is_ready(); -} - -void -send_request::wait() -{ - if (!m) return; - while (!m->is_ready()) m->progress(); -} - -bool -recv_request::is_ready() const noexcept -{ - if (!m) return true; - return m->is_ready(); -} - -bool -recv_request::is_canceled() const noexcept -{ - if (!m) return true; - return m->is_canceled(); -} - -bool -recv_request::test() -{ - if (!m || m->is_ready()) return true; - m->progress(); - return m->is_ready(); -} - -void -recv_request::wait() -{ - if (!m) return; - while (!m->is_ready()) m->progress(); -} - -bool -recv_request::cancel() -{ - if (!m) return false; - if (m->is_ready()) return false; - return m->cancel(); -} - -bool -shared_recv_request::is_ready() const noexcept -{ - if (!m) return true; - return m->is_ready(); -} - -bool -shared_recv_request::is_canceled() const noexcept -{ - if (!m) return true; - return m->is_canceled(); -} - -bool -shared_recv_request::test() -{ - if (!m || m->is_ready()) return true; - m->progress(); - return m->is_ready(); -} - -void -shared_recv_request::wait() -{ - if (!m) return; - while (!m->is_ready()) m->progress(); -} - -bool -shared_recv_request::cancel() -{ - if (!m) return false; - if (m->is_ready()) return false; - return m->cancel(); -} - -bool -send_multi_request::is_ready() const noexcept -{ - if (!m) return true; - return (m->m_counter == 0); -} - -bool -send_multi_request::test() -{ - if (!m) return true; - if (m->m_counter == 0) return true; - m->m_comm->progress(); - return (m->m_counter == 0); -} - -void -send_multi_request::wait() -{ - if (!m) return; - if (m->m_counter == 0) return; - while (m->m_counter > 0) m->m_comm->progress(); -} - -void -detail::request_state::progress() -{ - m_comm->progress(); -} - -bool -detail::request_state::cancel() -{ - 
return m_comm->cancel_recv(this); -} - -void -detail::shared_request_state::progress() -{ - m_ctxt->progress(); -} - -bool -detail::shared_request_state::cancel() -{ - return m_ctxt->cancel_recv(this); -} - -} // namespace oomph +namespace oomph { + + bool send_request::is_ready() const noexcept + { + if (!m) return true; + return m->is_ready(); + } + + bool send_request::test() + { + if (!m || m->is_ready()) return true; + m->progress(); + return m->is_ready(); + } + + void send_request::wait() + { + if (!m) return; + while (!m->is_ready()) m->progress(); + } + + bool recv_request::is_ready() const noexcept + { + if (!m) return true; + return m->is_ready(); + } + + bool recv_request::is_canceled() const noexcept + { + if (!m) return true; + return m->is_canceled(); + } + + bool recv_request::test() + { + if (!m || m->is_ready()) return true; + m->progress(); + return m->is_ready(); + } + + void recv_request::wait() + { + if (!m) return; + while (!m->is_ready()) m->progress(); + } + + bool recv_request::cancel() + { + if (!m) return false; + if (m->is_ready()) return false; + return m->cancel(); + } + + bool shared_recv_request::is_ready() const noexcept + { + if (!m) return true; + return m->is_ready(); + } + + bool shared_recv_request::is_canceled() const noexcept + { + if (!m) return true; + return m->is_canceled(); + } + + bool shared_recv_request::test() + { + if (!m || m->is_ready()) return true; + m->progress(); + return m->is_ready(); + } + + void shared_recv_request::wait() + { + if (!m) return; + while (!m->is_ready()) m->progress(); + } + + bool shared_recv_request::cancel() + { + if (!m) return false; + if (m->is_ready()) return false; + return m->cancel(); + } + + bool send_multi_request::is_ready() const noexcept + { + if (!m) return true; + return (m->m_counter == 0); + } + + bool send_multi_request::test() + { + if (!m) return true; + if (m->m_counter == 0) return true; + m->m_comm->progress(); + return (m->m_counter == 0); + } + + void send_multi_request::wait() + { + if (!m) return; + if (m->m_counter == 0) return; + while (m->m_counter > 0) m->m_comm->progress(); + } + + void detail::request_state::progress() { m_comm->progress(); } + + bool detail::request_state::cancel() { return m_comm->cancel_recv(this); } + + void detail::shared_request_state::progress() { m_ctxt->progress(); } + + bool detail::shared_request_state::cancel() { return m_ctxt->cancel_recv(this); } + +} // namespace oomph diff --git a/src/request_state_base.hpp b/src/request_state_base.hpp index c0a6598a..985eaa15 100644 --- a/src/request_state_base.hpp +++ b/src/request_state_base.hpp @@ -11,102 +11,97 @@ #include -namespace oomph -{ -namespace detail -{ - -template -struct request_state_traits -{ - template - using type = T; - - template - static inline void store(T& dst, T const& v) noexcept - { - dst = v; - } +namespace oomph { namespace detail { - template - static inline T load(T const& src) noexcept - { - return src; - } -}; - -template<> -struct request_state_traits -{ - template - using type = std::atomic; - - template - static inline void store(type& dst, T const& v) noexcept + template + struct request_state_traits { - dst.store(v); - } - - template - static inline T load(type const& src) noexcept + template + using type = T; + + template + static inline void store(T& dst, T const& v) noexcept + { + dst = v; + } + + template + static inline T load(T const& src) noexcept + { + return src; + } + }; + + template <> + struct request_state_traits { - return src.load(); - } -}; - -template -struct 
request_state_base -{ - using traits = request_state_traits; - using context_type = oomph::context_impl; - using communicator_type = oomph::communicator_impl; - using cb_type = oomph::util::unique_function; - - template - using type = typename traits::template type; - - context_type* m_ctxt; - communicator_type* m_comm; - type* m_scheduled; - rank_type m_rank; - tag_type m_tag; - cb_type m_cb; - type m_ready; - type m_canceled; - - request_state_base(context_type* ctxt, communicator_type* comm, type* scheduled, - rank_type rank, tag_type tag, cb_type&& cb) - : m_ctxt{ctxt} - , m_comm{comm} - , m_scheduled{scheduled} - , m_rank{rank} - , m_tag{tag} - , m_cb{std::move(cb)} - , m_ready(false) - , m_canceled(false) + template + using type = std::atomic; + + template + static inline void store(type& dst, T const& v) noexcept + { + dst.store(v); + } + + template + static inline T load(type const& src) noexcept + { + return src.load(); + } + }; + + template + struct request_state_base { - ++(*m_scheduled); - } - - bool is_ready() const noexcept { return traits::load(m_ready); } - - bool is_canceled() const noexcept { return traits::load(m_canceled); } - - void invoke_cb() - { - m_cb(m_rank, m_tag); - --(*m_scheduled); - traits::store(m_ready, true); - } - - void set_canceled() - { - --(*m_scheduled); - traits::store(m_ready, true); - traits::store(m_canceled, true); - } -}; - -} // namespace detail - -} // namespace oomph + using traits = request_state_traits; + using context_type = oomph::context_impl; + using communicator_type = oomph::communicator_impl; + using cb_type = oomph::util::unique_function; + + template + using type = typename traits::template type; + + context_type* m_ctxt; + communicator_type* m_comm; + type* m_scheduled; + rank_type m_rank; + tag_type m_tag; + cb_type m_cb; + type m_ready; + type m_canceled; + + request_state_base(context_type* ctxt, communicator_type* comm, + type* scheduled, rank_type rank, tag_type tag, cb_type&& cb) + : m_ctxt{ctxt} + , m_comm{comm} + , m_scheduled{scheduled} + , m_rank{rank} + , m_tag{tag} + , m_cb{std::move(cb)} + , m_ready(false) + , m_canceled(false) + { + ++(*m_scheduled); + } + + bool is_ready() const noexcept { return traits::load(m_ready); } + + bool is_canceled() const noexcept { return traits::load(m_canceled); } + + void invoke_cb() + { + m_cb(m_rank, m_tag); + --(*m_scheduled); + traits::store(m_ready, true); + } + + void set_canceled() + { + --(*m_scheduled); + traits::store(m_ready, true); + traits::store(m_canceled, true); + } + }; + +}} // namespace oomph::detail diff --git a/src/thread_id.hpp b/src/thread_id.hpp index afd43a67..ecbcac41 100644 --- a/src/thread_id.hpp +++ b/src/thread_id.hpp @@ -11,30 +11,38 @@ #include -namespace oomph -{ -class thread_id -{ - using id_type = std::uintptr_t const; +namespace oomph { + class thread_id + { + using id_type = std::uintptr_t const; - private: - id_type* const m; + private: + id_type* const m; - public: - thread_id(); - ~thread_id(); - thread_id(thread_id const&) = delete; - thread_id(thread_id&) = delete; - thread_id& operator=(thread_id const&) = delete; - thread_id& operator=(thread_id&&) = delete; + public: + thread_id(); + ~thread_id(); + thread_id(thread_id const&) = delete; + thread_id(thread_id&) = delete; + thread_id& operator=(thread_id const&) = delete; + thread_id& operator=(thread_id&&) = delete; - public: - friend bool operator==(thread_id const& a, thread_id const& b) noexcept { return (a.m == b.m); } - friend bool operator!=(thread_id const& a, thread_id const& b) 
noexcept { return (a.m != b.m); } - friend bool operator<(thread_id const& a, thread_id const& b) noexcept { return (a.m < b.m); } + public: + friend bool operator==(thread_id const& a, thread_id const& b) noexcept + { + return (a.m == b.m); + } + friend bool operator!=(thread_id const& a, thread_id const& b) noexcept + { + return (a.m != b.m); + } + friend bool operator<(thread_id const& a, thread_id const& b) noexcept + { + return (a.m < b.m); + } - operator std::uintptr_t() const& noexcept { return *m; } -}; + operator std::uintptr_t() const& noexcept { return *m; } + }; -thread_id const& tid(); -} // namespace oomph + thread_id const& tid(); +} // namespace oomph diff --git a/src/ucx/address.hpp b/src/ucx/address.hpp index 7fa63904..418dd6a7 100644 --- a/src/ucx/address.hpp +++ b/src/ucx/address.hpp @@ -10,73 +10,72 @@ #pragma once #include -#include #include +#include #include // paths relative to backend #include -namespace oomph -{ -struct address_t -{ - std::vector m_buffer; - - address_t() = default; - - address_t(std::size_t length) - : m_buffer(length) - { - } - - template - address_t(ForwardIterator first, ForwardIterator last) - : m_buffer(first, last) - { - } - - address_t(std::vector&& buffer) - : m_buffer(std::move(buffer)) - { - } - - address_t(const address_t& other) = default; - address_t& operator=(const address_t& other) = default; - address_t(address_t&&) noexcept = default; - address_t& operator=(address_t&&) noexcept = default; - - std::size_t size() const noexcept { return m_buffer.size(); } - - const unsigned char* data() const noexcept { return m_buffer.data(); } - unsigned char* data() noexcept { return m_buffer.data(); } - - const ucp_address_t* get() const noexcept - { - return reinterpret_cast(m_buffer.data()); - } - - auto begin() const noexcept { return m_buffer.begin(); } - auto begin() noexcept { return m_buffer.begin(); } - auto cbegin() const noexcept { return m_buffer.cbegin(); } - - auto end() const noexcept { return m_buffer.end(); } - auto end() noexcept { return m_buffer.end(); } - auto cend() const noexcept { return m_buffer.cend(); } - - unsigned char operator[](std::size_t i) const noexcept { return m_buffer[i]; } - unsigned char& operator[](std::size_t i) noexcept { return m_buffer[i]; } - - template> - friend std::basic_ostream& operator<<( - std::basic_ostream& os, const address_t& addr) +namespace oomph { + struct address_t { - os << "address{"; - os << std::hex; - for (auto c : addr) os << (unsigned int)c; - os << std::dec << "}"; - return os; - } -}; - -} // namespace oomph + std::vector m_buffer; + + address_t() = default; + + address_t(std::size_t length) + : m_buffer(length) + { + } + + template + address_t(ForwardIterator first, ForwardIterator last) + : m_buffer(first, last) + { + } + + address_t(std::vector&& buffer) + : m_buffer(std::move(buffer)) + { + } + + address_t(address_t const& other) = default; + address_t& operator=(address_t const& other) = default; + address_t(address_t&&) noexcept = default; + address_t& operator=(address_t&&) noexcept = default; + + std::size_t size() const noexcept { return m_buffer.size(); } + + unsigned char const* data() const noexcept { return m_buffer.data(); } + unsigned char* data() noexcept { return m_buffer.data(); } + + ucp_address_t const* get() const noexcept + { + return reinterpret_cast(m_buffer.data()); + } + + auto begin() const noexcept { return m_buffer.begin(); } + auto begin() noexcept { return m_buffer.begin(); } + auto cbegin() const noexcept { return m_buffer.cbegin(); } + + 
auto end() const noexcept { return m_buffer.end(); } + auto end() noexcept { return m_buffer.end(); } + auto cend() const noexcept { return m_buffer.cend(); } + + unsigned char operator[](std::size_t i) const noexcept { return m_buffer[i]; } + unsigned char& operator[](std::size_t i) noexcept { return m_buffer[i]; } + + template > + friend std::basic_ostream& + operator<<(std::basic_ostream& os, address_t const& addr) + { + os << "address{"; + os << std::hex; + for (auto c : addr) os << (unsigned int) c; + os << std::dec << "}"; + return os; + } + }; + +} // namespace oomph diff --git a/src/ucx/address_db.hpp b/src/ucx/address_db.hpp index ff1609ba..01f87e3c 100644 --- a/src/ucx/address_db.hpp +++ b/src/ucx/address_db.hpp @@ -12,53 +12,52 @@ // paths relative to backend #include -namespace oomph -{ -struct type_erased_address_db_t -{ - struct iface +namespace oomph { + struct type_erased_address_db_t { - virtual rank_type rank() = 0; - virtual rank_type size() = 0; - virtual int est_size() = 0; - virtual void init(const address_t&) = 0; - virtual address_t find(rank_type) = 0; - virtual ~iface() {} - }; - - template - struct impl_t final : public iface - { - Impl m_impl; - impl_t(const Impl& impl) - : m_impl{impl} + struct iface { - } - impl_t(Impl&& impl) - : m_impl{std::move(impl)} + virtual rank_type rank() = 0; + virtual rank_type size() = 0; + virtual int est_size() = 0; + virtual void init(address_t const&) = 0; + virtual address_t find(rank_type) = 0; + virtual ~iface() {} + }; + + template + struct impl_t final : public iface { - } - rank_type rank() override { return m_impl.rank(); } - rank_type size() override { return m_impl.size(); } - int est_size() override { return m_impl.est_size(); } - void init(const address_t& addr) override { m_impl.init(addr); } - address_t find(rank_type rank) override { return m_impl.find(rank); } - }; + Impl m_impl; + impl_t(Impl const& impl) + : m_impl{impl} + { + } + impl_t(Impl&& impl) + : m_impl{std::move(impl)} + { + } + rank_type rank() override { return m_impl.rank(); } + rank_type size() override { return m_impl.size(); } + int est_size() override { return m_impl.est_size(); } + void init(address_t const& addr) override { m_impl.init(addr); } + address_t find(rank_type rank) override { return m_impl.find(rank); } + }; - std::unique_ptr m_impl; + std::unique_ptr m_impl; - template - type_erased_address_db_t(Impl&& impl) - : m_impl{std::make_unique>>>( - std::forward(impl))} - { - } + template + type_erased_address_db_t(Impl&& impl) + : m_impl{std::make_unique>>>( + std::forward(impl))} + { + } - inline rank_type rank() const { return m_impl->rank(); } - inline rank_type size() const { return m_impl->size(); } - inline int est_size() const { return m_impl->est_size(); } - inline void init(const address_t& addr) { m_impl->init(addr); } - inline address_t find(rank_type rank) { return m_impl->find(rank); } -}; + inline rank_type rank() const { return m_impl->rank(); } + inline rank_type size() const { return m_impl->size(); } + inline int est_size() const { return m_impl->est_size(); } + inline void init(address_t const& addr) { m_impl->init(addr); } + inline address_t find(rank_type rank) { return m_impl->find(rank); } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/address_db_mpi.hpp b/src/ucx/address_db_mpi.hpp index 9b6eecc7..edd0d4aa 100644 --- a/src/ucx/address_db_mpi.hpp +++ b/src/ucx/address_db_mpi.hpp @@ -15,78 +15,77 @@ #include // paths relative to backend -#include -#include #include +#include +#include 
-namespace oomph -{ -struct address_db_mpi -{ - using key_t = rank_type; - using value_t = address_t; +namespace oomph { + struct address_db_mpi + { + using key_t = rank_type; + using value_t = address_t; - MPI_Comm m_mpi_comm; - const key_t m_rank; - const key_t m_size; + MPI_Comm m_mpi_comm; + key_t const m_rank; + key_t const m_size; - value_t m_value; - std::map m_address_map; + value_t m_value; + std::map m_address_map; - address_db_mpi(MPI_Comm comm) - : m_mpi_comm{comm} - , m_rank{[](MPI_Comm c) { - int r; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(c, &r)); - return r; - }(comm)} - , m_size{[](MPI_Comm c) { - int s; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(c, &s)); - return s; - }(comm)} - { - } + address_db_mpi(MPI_Comm comm) + : m_mpi_comm{comm} + , m_rank{[](MPI_Comm c) { + int r; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(c, &r)); + return r; + }(comm)} + , m_size{[](MPI_Comm c) { + int s; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(c, &s)); + return s; + }(comm)} + { + } - address_db_mpi(const address_db_mpi&) = delete; - address_db_mpi(address_db_mpi&&) = default; + address_db_mpi(address_db_mpi const&) = delete; + address_db_mpi(address_db_mpi&&) = default; - key_t rank() const noexcept { return m_rank; } - key_t size() const noexcept { return m_size; } - int est_size() const noexcept { return m_size; } + key_t rank() const noexcept { return m_rank; } + key_t size() const noexcept { return m_size; } + int est_size() const noexcept { return m_size; } - value_t find(key_t k) - { - auto it = m_address_map.find(k); - if (it != m_address_map.end()) { return it->second; } - throw std::runtime_error("Cound not find peer address in the MPI address xdatabase."); - } + value_t find(key_t k) + { + auto it = m_address_map.find(k); + if (it != m_address_map.end()) { return it->second; } + throw std::runtime_error("Cound not find peer address in the MPI address xdatabase."); + } - void init(const value_t& addr) - { - m_value = addr; - m_address_map[m_rank] = addr; - for (key_t r = 0; r < m_size; ++r) + void init(value_t const& addr) { - if (r == m_rank) - { - std::size_t size = m_value.size(); - OOMPH_CHECK_MPI_RESULT( - MPI_Bcast(&size, sizeof(std::size_t), MPI_BYTE, r, m_mpi_comm)); - OOMPH_CHECK_MPI_RESULT( - MPI_Bcast(m_value.data(), m_value.size(), MPI_BYTE, r, m_mpi_comm)); - } - else + m_value = addr; + m_address_map[m_rank] = addr; + for (key_t r = 0; r < m_size; ++r) { - std::size_t size; - OOMPH_CHECK_MPI_RESULT( - MPI_Bcast(&size, sizeof(std::size_t), MPI_BYTE, r, m_mpi_comm)); - value_t addr(size); - OOMPH_CHECK_MPI_RESULT(MPI_Bcast(addr.data(), size, MPI_BYTE, r, m_mpi_comm)); - m_address_map[r] = addr; + if (r == m_rank) + { + std::size_t size = m_value.size(); + OOMPH_CHECK_MPI_RESULT( + MPI_Bcast(&size, sizeof(std::size_t), MPI_BYTE, r, m_mpi_comm)); + OOMPH_CHECK_MPI_RESULT( + MPI_Bcast(m_value.data(), m_value.size(), MPI_BYTE, r, m_mpi_comm)); + } + else + { + std::size_t size; + OOMPH_CHECK_MPI_RESULT( + MPI_Bcast(&size, sizeof(std::size_t), MPI_BYTE, r, m_mpi_comm)); + value_t addr(size); + OOMPH_CHECK_MPI_RESULT(MPI_Bcast(addr.data(), size, MPI_BYTE, r, m_mpi_comm)); + m_address_map[r] = addr; + } } } - } -}; + }; -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/address_db_pmi.hpp b/src/ucx/address_db_pmi.hpp index 8dc1a17d..204a54c4 100644 --- a/src/ucx/address_db_pmi.hpp +++ b/src/ucx/address_db_pmi.hpp @@ -9,109 +9,110 @@ */ #pragma once +#include #include -#include #include -#include +#include // paths relative to backend -#include -#include #include +#include 
+#include -namespace oomph -{ -struct address_db_pmi -{ - // PMI interface to obtain peer addresses - // per-communicator instance used to store/query connections - pmi pmi_impl; +namespace oomph { + struct address_db_pmi + { + // PMI interface to obtain peer addresses + // per-communicator instance used to store/query connections + pmi pmi_impl; - using key_t = rank_type; - using value_t = address_t; + using key_t = rank_type; + using value_t = address_t; - MPI_Comm m_mpi_comm; - std::vector m_rank_map; + MPI_Comm m_mpi_comm; + std::vector m_rank_map; - // these should be PMIx ranks. might need remaping to MPI ranks - key_t m_rank; - key_t m_size; - std::string m_suffix; - std::string m_key; + // these should be PMIx ranks. might need remaping to MPI ranks + key_t m_rank; + key_t m_size; + std::string m_suffix; + std::string m_key; - auto suffix() const noexcept { return m_suffix; } + auto suffix() const noexcept { return m_suffix; } - int make_instance() - { - static int _instance = 0; - const auto ret = _instance++; - return ret; - } - - address_db_pmi(MPI_Comm comm) - : m_mpi_comm{comm} - , m_suffix(std::string("_") + std::to_string(make_instance())) - , m_key("ghex-rank-address" + m_suffix) - { - m_rank = pmi_impl.rank(); - m_size = pmi_impl.size(); - - int mpi_rank{[](MPI_Comm c) { - int r; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(c, &r)); - return r; - }(comm)}; - int mpi_size{[](MPI_Comm c) { - int s; - OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(c, &s)); - return s; - }(comm)}; - - if (m_size != mpi_size) - { throw std::runtime_error("PMIx and MPI sizes are different. Bailing out."); } - - // map MPI communicator ranks to PMIx ranks - m_rank_map.resize(mpi_size); - for (int i = 0; i < mpi_size; i++) m_rank_map[i] = m_rank; - OOMPH_CHECK_MPI_RESULT(MPI_Alltoall( - MPI_IN_PLACE, 0, MPI_BYTE, m_rank_map.data(), sizeof(key_t), MPI_BYTE, comm)); - - // from now on use the MPI communicator rank outsize, and remap for internal use - m_rank = mpi_rank; - } - - address_db_pmi(const address_db_pmi&) = delete; - address_db_pmi(address_db_pmi&&) = default; - - key_t rank() const noexcept { return m_rank; } - key_t size() const noexcept { return m_size; } - int est_size() const noexcept { return m_size; } - - value_t find(key_t k) - { - // ranks coming from outside are MPI ranks - remap to PMIx - k = m_rank_map[k]; - try + int make_instance() { - return pmi_impl.get(k, m_key); + static int _instance = 0; + auto const ret = _instance++; + return ret; } - catch (std::runtime_error& err) + + address_db_pmi(MPI_Comm comm) + : m_mpi_comm{comm} + , m_suffix(std::string("_") + std::to_string(make_instance())) + , m_key("ghex-rank-address" + m_suffix) { - std::string msg = - std::string("PMIx could not find peer address: ") + std::string(err.what()); - throw std::runtime_error(msg); + m_rank = pmi_impl.rank(); + m_size = pmi_impl.size(); + + int mpi_rank{[](MPI_Comm c) { + int r; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_rank(c, &r)); + return r; + }(comm)}; + int mpi_size{[](MPI_Comm c) { + int s; + OOMPH_CHECK_MPI_RESULT(MPI_Comm_size(c, &s)); + return s; + }(comm)}; + + if (m_size != mpi_size) + { + throw std::runtime_error("PMIx and MPI sizes are different. 
Bailing out."); + } + + // map MPI communicator ranks to PMIx ranks + m_rank_map.resize(mpi_size); + for (int i = 0; i < mpi_size; i++) m_rank_map[i] = m_rank; + OOMPH_CHECK_MPI_RESULT(MPI_Alltoall( + MPI_IN_PLACE, 0, MPI_BYTE, m_rank_map.data(), sizeof(key_t), MPI_BYTE, comm)); + + // from now on use the MPI communicator rank outsize, and remap for internal use + m_rank = mpi_rank; } - } - void init(const value_t& addr) - { - std::vector data(addr.data(), addr.data() + addr.size()); - pmi_impl.set(m_key, data); + address_db_pmi(address_db_pmi const&) = delete; + address_db_pmi(address_db_pmi&&) = default; + + key_t rank() const noexcept { return m_rank; } + key_t size() const noexcept { return m_size; } + int est_size() const noexcept { return m_size; } + + value_t find(key_t k) + { + // ranks coming from outside are MPI ranks - remap to PMIx + k = m_rank_map[k]; + try + { + return pmi_impl.get(k, m_key); + } + catch (std::runtime_error& err) + { + std::string msg = + std::string("PMIx could not find peer address: ") + std::string(err.what()); + throw std::runtime_error(msg); + } + } - // TODO: we have to call an explicit PMIx Fence due to - // https://github.com/open-mpi/ompi/issues/6982 - pmi_impl.exchange(); - } -}; + void init(value_t const& addr) + { + std::vector data(addr.data(), addr.data() + addr.size()); + pmi_impl.set(m_key, data); + + // TODO: we have to call an explicit PMIx Fence due to + // https://github.com/open-mpi/ompi/issues/6982 + pmi_impl.exchange(); + } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/communicator.hpp b/src/ucx/communicator.hpp index dcb4a4ac..04c2c0eb 100644 --- a/src/ucx/communicator.hpp +++ b/src/ucx/communicator.hpp @@ -16,462 +16,453 @@ // paths relative to backend #include <../communicator_base.hpp> #include <../device_guard.hpp> -#include #include +#include -namespace oomph -{ - -class communicator_impl : public communicator_base -{ - public: - using worker_type = worker_t; - template - using lockfree_queue = boost::lockfree::queue, - boost::lockfree::allocator>>; - - using recv_req_queue_type = lockfree_queue; - - public: - context_impl* m_context; - bool const m_thread_safe; - worker_type* m_recv_worker; - worker_type* m_send_worker; - ucx_mutex& m_mutex; - recv_req_queue_type m_send_req_queue; - recv_req_queue_type m_recv_req_queue; - recv_req_queue_type m_cancel_recv_req_queue; - std::vector m_cancel_recv_req_vec; - - public: - communicator_impl(context_impl* ctxt, bool thread_safe, worker_type* recv_worker, - worker_type* send_worker, ucx_mutex& mtx) - : communicator_base(ctxt) - , m_context(ctxt) - , m_thread_safe{thread_safe} - , m_recv_worker{recv_worker} - , m_send_worker{send_worker} - , m_mutex{mtx} - , m_send_req_queue(128) - , m_recv_req_queue(128) - , m_cancel_recv_req_queue(128) - { - } +namespace oomph { - ~communicator_impl() + class communicator_impl : public communicator_base { - // schedule all endpoints for closing - for (auto& kvp : m_send_worker->m_endpoint_cache) + public: + using worker_type = worker_t; + template + using lockfree_queue = boost::lockfree::queue, + boost::lockfree::allocator>>; + + using recv_req_queue_type = lockfree_queue; + + public: + context_impl* m_context; + bool const m_thread_safe; + worker_type* m_recv_worker; + worker_type* m_send_worker; + ucx_mutex& m_mutex; + recv_req_queue_type m_send_req_queue; + recv_req_queue_type m_recv_req_queue; + recv_req_queue_type m_cancel_recv_req_queue; + std::vector m_cancel_recv_req_vec; + + public: + communicator_impl(context_impl* 
ctxt, bool thread_safe, worker_type* recv_worker, + worker_type* send_worker, ucx_mutex& mtx) + : communicator_base(ctxt) + , m_context(ctxt) + , m_thread_safe{thread_safe} + , m_recv_worker{recv_worker} + , m_send_worker{send_worker} + , m_mutex{mtx} + , m_send_req_queue(128) + , m_recv_req_queue(128) + , m_cancel_recv_req_queue(128) + { + } + + ~communicator_impl() { - m_send_worker->m_endpoint_handles.push_back(kvp.second.close()); - m_send_worker->m_endpoint_handles.back().progress(); + // schedule all endpoints for closing + for (auto& kvp : m_send_worker->m_endpoint_cache) + { + m_send_worker->m_endpoint_handles.push_back(kvp.second.close()); + m_send_worker->m_endpoint_handles.back().progress(); + } } - } - auto& get_heap() noexcept { return m_context->get_heap(); } + auto& get_heap() noexcept { return m_context->get_heap(); } - void progress() - { - while (ucp_worker_progress(m_send_worker->get())) {} - if (m_thread_safe) + void progress() { + while (ucp_worker_progress(m_send_worker->get())) {} + if (m_thread_safe) + { #ifdef OOMPH_UCX_USE_SPIN_LOCK - // this is really important for large-scale multithreading: check if still is - sched_yield(); + // this is really important for large-scale multithreading: check if still is + sched_yield(); #endif - { - // progress recv worker in locked region - //ucx_lock lock(m_mutex); - //while (ucp_worker_progress(m_recv_worker->get())) {} - for (unsigned int i = 0; i < 10; ++i) { - if (m_mutex.try_lock()) + // progress recv worker in locked region + //ucx_lock lock(m_mutex); + //while (ucp_worker_progress(m_recv_worker->get())) {} + for (unsigned int i = 0; i < 10; ++i) { - auto p = ucp_worker_progress(m_recv_worker->get()); - m_mutex.unlock(); - if (!p) break; + if (m_mutex.try_lock()) + { + auto p = ucp_worker_progress(m_recv_worker->get()); + m_mutex.unlock(); + if (!p) break; + } } } } - } - else - { - while (ucp_worker_progress(m_recv_worker->get())) {} - } - // work through ready send callbacks - m_send_req_queue.consume_all( - [](detail::request_state* req) + else { + while (ucp_worker_progress(m_recv_worker->get())) {} + } + // work through ready send callbacks + m_send_req_queue.consume_all([](detail::request_state* req) { auto ptr = req->release_self_ref(); req->invoke_cb(); }); - // work through ready recv callbacks, which were pushed to the queue by other threads - // (including this thread) - if (m_thread_safe) - m_recv_req_queue.consume_all( - [](detail::request_state* req) - { + // work through ready recv callbacks, which were pushed to the queue by other threads + // (including this thread) + if (m_thread_safe) + m_recv_req_queue.consume_all([](detail::request_state* req) { auto ptr = req->release_self_ref(); req->invoke_cb(); }); - m_context->m_recv_req_queue.consume_all( - [](detail::shared_request_state* req) - { + m_context->m_recv_req_queue.consume_all([](detail::shared_request_state* req) { auto ptr = req->release_self_ref(); req->invoke_cb(); }); - } - - send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) - { - const auto& ep = m_send_worker->connect(dst); - const auto stag = - ((std::uint_fast64_t)tag << OOMPH_UCX_TAG_BITS) | (std::uint_fast64_t)(rank()); - - ucs_status_ptr_t ret; - { - // device is set according to message memory: needed? 
- const_device_guard dg(ptr); - - ret = ucp_tag_send_nb(ep.get(), // destination - dg.data(), // buffer - size, // buffer size - ucp_dt_make_contig(1), // data type - stag, // tag - &communicator_impl::send_callback); // callback function pointer } - if (reinterpret_cast(ret) == UCS_OK) + send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, + rank_type dst, tag_type tag, util::unique_function&& cb, + std::size_t* scheduled) { - // send operation is completed immediately - if (!has_reached_recursion_depth()) - { - auto inc = recursion(); - // call the callback - cb(dst, tag); - return {}; - // request is freed by ucx internally - } - else + auto const& ep = m_send_worker->connect(dst); + auto const stag = + ((std::uint_fast64_t) tag << OOMPH_UCX_TAG_BITS) | (std::uint_fast64_t)(rank()); + + ucs_status_ptr_t ret; { - // allocate request_state - auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, - std::move(cb), ret, m_mutex); - s->create_self_ref(); - // push callback to the queue - enqueue_send(s.get()); - return {std::move(s)}; - // request is freed by ucx internally + // device is set according to message memory: needed? + const_device_guard dg(ptr); + + ret = ucp_tag_send_nb(ep.get(), // destination + dg.data(), // buffer + size, // buffer size + ucp_dt_make_contig(1), // data type + stag, // tag + &communicator_impl::send_callback); // callback function pointer } - } - else if (!UCS_PTR_IS_ERR(ret)) - { - // send operation was scheduled - // allocate request_state - auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, - std::move(cb), ret, m_mutex); - s->create_self_ref(); - // attach necessary data to the request - request_data::construct(ret, s.get()); - return {std::move(s)}; - } - else - { - // an error occurred - throw std::runtime_error("oomph: ucx error - send operation failed"); - } - } - - recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) - { - const auto rtag = - (communicator::any_source == src) - ? ((std::uint_fast64_t)tag << OOMPH_UCX_TAG_BITS) - : ((std::uint_fast64_t)tag << OOMPH_UCX_TAG_BITS) | (std::uint_fast64_t)(src); - - const auto rtag_mask = (communicator::any_source == src) - ? (OOMPH_UCX_TAG_MASK | OOMPH_UCX_ANY_SOURCE_MASK) - : (OOMPH_UCX_TAG_MASK | OOMPH_UCX_SPECIFIC_SOURCE_MASK); - - if (m_thread_safe) m_mutex.lock(); - ucs_status_ptr_t ret; - { - // device is set according to message memory: needed? 
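// Worked example of the tag matching set up above (illustrative; values assume
// OOMPH_UCX_TAG_BITS == 32 and the masks defined in context.hpp):
//   tag == 7, src == 3            -> rtag      = 0x0000'0007'0000'0003,
//                                    rtag_mask = 0xffff'ffff'ffff'ffff (every bit must match)
//   tag == 7, src == any_source   -> rtag      = 0x0000'0007'0000'0000,
//                                    rtag_mask = 0xffff'ffff'0000'0000 (low 32 sender bits ignored)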
- device_guard dg(ptr); - - ret = ucp_tag_recv_nb(m_recv_worker->get(), // worker - dg.data(), // buffer - size, // buffer size - ucp_dt_make_contig(1), // data type - rtag, // tag - rtag_mask, // tag mask - &communicator_impl::recv_callback); // callback function pointer - } - if (!UCS_PTR_IS_ERR(ret)) - { - if (UCS_INPROGRESS != ucp_request_check_status(ret)) + if (reinterpret_cast(ret) == UCS_OK) { - // early completed - ucp_request_free(ret); - if (m_thread_safe) m_mutex.unlock(); + // send operation is completed immediately if (!has_reached_recursion_depth()) { auto inc = recursion(); - cb(src, tag); + // call the callback + cb(dst, tag); return {}; + // request is freed by ucx internally } else { // allocate request_state - auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, - std::move(cb), ret, m_mutex); + auto s = m_req_state_factory.make( + m_context, this, scheduled, dst, tag, std::move(cb), ret, m_mutex); s->create_self_ref(); // push callback to the queue - enqueue_recv(s.get()); + enqueue_send(s.get()); return {std::move(s)}; + // request is freed by ucx internally } } - else + else if (!UCS_PTR_IS_ERR(ret)) { - // recv operation was scheduled + // send operation was scheduled // allocate request_state - auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, - std::move(cb), ret, m_mutex); + auto s = m_req_state_factory.make( + m_context, this, scheduled, dst, tag, std::move(cb), ret, m_mutex); s->create_self_ref(); // attach necessary data to the request request_data::construct(ret, s.get()); - if (m_thread_safe) m_mutex.unlock(); return {std::move(s)}; } + else + { + // an error occurred + throw std::runtime_error("oomph: ucx error - send operation failed"); + } } - else + + recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, + tag_type tag, util::unique_function&& cb, + std::size_t* scheduled) { - // an error occurred - throw std::runtime_error("oomph: ucx error - recv operation failed"); - } - } + auto const rtag = (communicator::any_source == src) ? + ((std::uint_fast64_t) tag << OOMPH_UCX_TAG_BITS) : + ((std::uint_fast64_t) tag << OOMPH_UCX_TAG_BITS) | (std::uint_fast64_t)(src); - shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, - rank_type src, tag_type tag, util::unique_function&& cb, - std::atomic* scheduled) - { - const auto rtag = - (communicator::any_source == src) - ? ((std::uint_fast64_t)tag << OOMPH_UCX_TAG_BITS) - : ((std::uint_fast64_t)tag << OOMPH_UCX_TAG_BITS) | (std::uint_fast64_t)(src); + auto const rtag_mask = (communicator::any_source == src) ? + (OOMPH_UCX_TAG_MASK | OOMPH_UCX_ANY_SOURCE_MASK) : + (OOMPH_UCX_TAG_MASK | OOMPH_UCX_SPECIFIC_SOURCE_MASK); - const auto rtag_mask = (communicator::any_source == src) - ? (OOMPH_UCX_TAG_MASK | OOMPH_UCX_ANY_SOURCE_MASK) - : (OOMPH_UCX_TAG_MASK | OOMPH_UCX_SPECIFIC_SOURCE_MASK); + if (m_thread_safe) m_mutex.lock(); + ucs_status_ptr_t ret; + { + // device is set according to message memory: needed? + device_guard dg(ptr); + + ret = ucp_tag_recv_nb(m_recv_worker->get(), // worker + dg.data(), // buffer + size, // buffer size + ucp_dt_make_contig(1), // data type + rtag, // tag + rtag_mask, // tag mask + &communicator_impl::recv_callback); // callback function pointer + } - if (m_thread_safe) m_mutex.lock(); - ucs_status_ptr_t ret; - { - // device is set according to message memory: needed? 
- device_guard dg(ptr); - - ret = ucp_tag_recv_nb(m_recv_worker->get(), // worker - dg.data(), // buffer - size, // buffer size - ucp_dt_make_contig(1), // data type - rtag, // tag - rtag_mask, // tag mask - &communicator_impl::recv_callback); // callback function pointer + if (!UCS_PTR_IS_ERR(ret)) + { + if (UCS_INPROGRESS != ucp_request_check_status(ret)) + { + // early completed + ucp_request_free(ret); + if (m_thread_safe) m_mutex.unlock(); + if (!has_reached_recursion_depth()) + { + auto inc = recursion(); + cb(src, tag); + return {}; + } + else + { + // allocate request_state + auto s = m_req_state_factory.make( + m_context, this, scheduled, src, tag, std::move(cb), ret, m_mutex); + s->create_self_ref(); + // push callback to the queue + enqueue_recv(s.get()); + return {std::move(s)}; + } + } + else + { + // recv operation was scheduled + // allocate request_state + auto s = m_req_state_factory.make( + m_context, this, scheduled, src, tag, std::move(cb), ret, m_mutex); + s->create_self_ref(); + // attach necessary data to the request + request_data::construct(ret, s.get()); + if (m_thread_safe) m_mutex.unlock(); + return {std::move(s)}; + } + } + else + { + // an error occurred + throw std::runtime_error("oomph: ucx error - recv operation failed"); + } } - if (!UCS_PTR_IS_ERR(ret)) + shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, + rank_type src, tag_type tag, util::unique_function&& cb, + std::atomic* scheduled) { - if (UCS_INPROGRESS != ucp_request_check_status(ret)) + auto const rtag = (communicator::any_source == src) ? + ((std::uint_fast64_t) tag << OOMPH_UCX_TAG_BITS) : + ((std::uint_fast64_t) tag << OOMPH_UCX_TAG_BITS) | (std::uint_fast64_t)(src); + + auto const rtag_mask = (communicator::any_source == src) ? + (OOMPH_UCX_TAG_MASK | OOMPH_UCX_ANY_SOURCE_MASK) : + (OOMPH_UCX_TAG_MASK | OOMPH_UCX_SPECIFIC_SOURCE_MASK); + + if (m_thread_safe) m_mutex.lock(); + ucs_status_ptr_t ret; { - // early completed - ucp_request_free(ret); - if (m_thread_safe) m_mutex.unlock(); - if (!m_context->has_reached_recursion_depth()) + // device is set according to message memory: needed? 
+ device_guard dg(ptr); + + ret = ucp_tag_recv_nb(m_recv_worker->get(), // worker + dg.data(), // buffer + size, // buffer size + ucp_dt_make_contig(1), // data type + rtag, // tag + rtag_mask, // tag mask + &communicator_impl::recv_callback); // callback function pointer + } + + if (!UCS_PTR_IS_ERR(ret)) + { + if (UCS_INPROGRESS != ucp_request_check_status(ret)) { - auto inc = m_context->recursion(); - cb(src, tag); - return {}; + // early completed + ucp_request_free(ret); + if (m_thread_safe) m_mutex.unlock(); + if (!m_context->has_reached_recursion_depth()) + { + auto inc = m_context->recursion(); + cb(src, tag); + return {}; + } + else + { + // allocate shared request_state + auto s = std::make_shared( + m_context, this, scheduled, src, tag, std::move(cb), ret, m_mutex); + s->create_self_ref(); + m_context->enqueue_recv(s.get()); + return {std::move(s)}; + } } else { + // recv operation was scheduled // allocate shared request_state - auto s = std::make_shared(m_context, this, - scheduled, src, tag, std::move(cb), ret, m_mutex); + auto s = std::make_shared( + m_context, this, scheduled, src, tag, std::move(cb), ret, m_mutex); s->create_self_ref(); - m_context->enqueue_recv(s.get()); + // attach necessary data to the request + request_data::construct(ret, s.get()); + if (m_thread_safe) m_mutex.unlock(); return {std::move(s)}; } } else { - // recv operation was scheduled - // allocate shared request_state - auto s = std::make_shared(m_context, this, scheduled, - src, tag, std::move(cb), ret, m_mutex); - s->create_self_ref(); - // attach necessary data to the request - request_data::construct(ret, s.get()); - if (m_thread_safe) m_mutex.unlock(); - return {std::move(s)}; + // an error occurred + throw std::runtime_error("oomph: ucx error - recv operation failed"); } } - else + + void enqueue_send(detail::request_state* d) { - // an error occurred - throw std::runtime_error("oomph: ucx error - recv operation failed"); + while (!m_send_req_queue.push(d)) {} } - } - void enqueue_send(detail::request_state* d) - { - while (!m_send_req_queue.push(d)) {} - } - - void enqueue_recv(detail::request_state* d) - { - while (!m_recv_req_queue.push(d)) {} - } + void enqueue_recv(detail::request_state* d) + { + while (!m_recv_req_queue.push(d)) {} + } - void enqueue_cancel_recv(detail::request_state* d) - { - while (!m_cancel_recv_req_queue.push(d)) {} - } + void enqueue_cancel_recv(detail::request_state* d) + { + while (!m_cancel_recv_req_queue.push(d)) {} + } - inline static void send_callback(void* ucx_req, ucs_status_t status) - { - auto& req_data = *request_data::get(ucx_req); - if (status == UCS_OK) + inline static void send_callback(void* ucx_req, ucs_status_t status) { - // invoke callback - if (req_data.m_req) - { - auto req = req_data.m_req; - auto ptr = req->release_self_ref(); - req->invoke_cb(); - } - else + auto& req_data = *request_data::get(ucx_req); + if (status == UCS_OK) { - auto req = req_data.m_shared_req; - auto ptr = req->release_self_ref(); - req->invoke_cb(); + // invoke callback + if (req_data.m_req) + { + auto req = req_data.m_req; + auto ptr = req->release_self_ref(); + req->invoke_cb(); + } + else + { + auto req = req_data.m_shared_req; + auto ptr = req->release_self_ref(); + req->invoke_cb(); + } } - } - // else: cancelled - do nothing - cancel for sends does not exist + // else: cancelled - do nothing - cancel for sends does not exist - // destroy request - req_data.destroy(); - ucp_request_free(ucx_req); - } + // destroy request + req_data.destroy(); + 
ucp_request_free(ucx_req); + } - // this callback is called within a locked region - inline static void recv_callback(void* ucx_req, ucs_status_t status, - ucp_tag_recv_info_t* /*info*/) - { - auto& req_data = *request_data::get(ucx_req); - if (status == UCS_OK) + // this callback is called within a locked region + inline static void recv_callback( + void* ucx_req, ucs_status_t status, ucp_tag_recv_info_t* /*info*/) { - // return if early completion - if (req_data.empty()) return; - - if (req_data.m_req) + auto& req_data = *request_data::get(ucx_req); + if (status == UCS_OK) { - // normal recv - auto req = req_data.m_req; - if (req->m_ctxt->thread_safe()) + // return if early completion + if (req_data.empty()) return; + + if (req_data.m_req) { - // multi-threaded case - // free request here - req_data.destroy(); - ucp_request_free(ucx_req); - // enqueue request on the issuing communicator - // this guarantees that only the communicator on which the receive was issued - // will invoke the callback - req->m_comm->enqueue_recv(req); + // normal recv + auto req = req_data.m_req; + if (req->m_ctxt->thread_safe()) + { + // multi-threaded case + // free request here + req_data.destroy(); + ucp_request_free(ucx_req); + // enqueue request on the issuing communicator + // this guarantees that only the communicator on which the receive was issued + // will invoke the callback + req->m_comm->enqueue_recv(req); + } + else + { + // single-threaded case + // free request here + req_data.destroy(); + ucp_request_free(ucx_req); + // call the callback directly from here + auto ptr = req->release_self_ref(); + req->invoke_cb(); + } } else { - // single-threaded case + // shared recv + auto req = req_data.m_shared_req; // free request here req_data.destroy(); ucp_request_free(ucx_req); - // call the callback directly from here - auto ptr = req->release_self_ref(); - req->invoke_cb(); + // enqueue request on the context + req->m_ctxt->enqueue_recv(req); } } - else + else if (status == UCS_ERR_CANCELED) { - // shared recv - auto req = req_data.m_shared_req; - // free request here - req_data.destroy(); - ucp_request_free(ucx_req); - // enqueue request on the context - req->m_ctxt->enqueue_recv(req); + // receive was cancelled + if (req_data.m_req) + req_data.m_req->m_comm->enqueue_cancel_recv(req_data.m_req); + else + req_data.m_shared_req->m_ctxt->enqueue_cancel_recv(req_data.m_shared_req); } - } - else if (status == UCS_ERR_CANCELED) - { - // receive was cancelled - if (req_data.m_req) req_data.m_req->m_comm->enqueue_cancel_recv(req_data.m_req); else - req_data.m_shared_req->m_ctxt->enqueue_cancel_recv(req_data.m_shared_req); - } - else - { - // an error occurred - throw std::runtime_error("oomph: ucx error - recv message truncated"); + { + // an error occurred + throw std::runtime_error("oomph: ucx error - recv message truncated"); + } } - } - // Note: at this time, send requests cannot be canceled in UCX (1.7.0rc1) - // https://github.com/openucx/ucx/issues/1162 - //bool cancel_recv_cb(recv_request const& req) - bool cancel_recv(detail::request_state* s) - { - if (m_thread_safe) m_mutex.lock(); - ucp_request_cancel(m_recv_worker->get(), s->m_ucx_ptr); - //if (m_thread_safe) m_mutex.unlock(); - // The ucx callback will still be executed after the cancel. However, the status argument - // will indicate whether the cancel was successful. 
- // Progress the receive worker in order to execute the ucx callback - //if (m_thread_safe) m_mutex.lock(); - while (ucp_worker_progress(m_recv_worker->get())) {} - if (m_thread_safe) m_mutex.unlock(); - // check whether the cancelled callback was enqueued by consuming all queued cancelled - // callbacks and putting them in a temporary vector - bool found = false; - m_cancel_recv_req_vec.clear(); - m_cancel_recv_req_queue.consume_all( - [this, s, &found](detail::request_state* r) - { - if (r == s) found = true; + // Note: at this time, send requests cannot be canceled in UCX (1.7.0rc1) + // https://github.com/openucx/ucx/issues/1162 + //bool cancel_recv_cb(recv_request const& req) + bool cancel_recv(detail::request_state* s) + { + if (m_thread_safe) m_mutex.lock(); + ucp_request_cancel(m_recv_worker->get(), s->m_ucx_ptr); + //if (m_thread_safe) m_mutex.unlock(); + // The ucx callback will still be executed after the cancel. However, the status argument + // will indicate whether the cancel was successful. + // Progress the receive worker in order to execute the ucx callback + //if (m_thread_safe) m_mutex.lock(); + while (ucp_worker_progress(m_recv_worker->get())) {} + if (m_thread_safe) m_mutex.unlock(); + // check whether the cancelled callback was enqueued by consuming all queued cancelled + // callbacks and putting them in a temporary vector + bool found = false; + m_cancel_recv_req_vec.clear(); + m_cancel_recv_req_queue.consume_all([this, s, &found](detail::request_state* r) { + if (r == s) + found = true; else m_cancel_recv_req_vec.push_back(r); }); - // re-enqueue all callbacks which were not identical with the current callback - for (auto x : m_cancel_recv_req_vec) - while (!m_cancel_recv_req_queue.push(x)) {} + // re-enqueue all callbacks which were not identical with the current callback + for (auto x : m_cancel_recv_req_vec) + while (!m_cancel_recv_req_queue.push(x)) {} - // delete callback here if it was actually cancelled - if (found) - { - auto ptr = s->release_self_ref(); - s->set_canceled(); - void* ucx_req = s->m_ucx_ptr; - // destroy request - request_data::get(ucx_req)->destroy(); - if (m_thread_safe) m_mutex.lock(); - ucp_request_free(ucx_req); - if (m_thread_safe) m_mutex.unlock(); + // delete callback here if it was actually cancelled + if (found) + { + auto ptr = s->release_self_ref(); + s->set_canceled(); + void* ucx_req = s->m_ucx_ptr; + // destroy request + request_data::get(ucx_req)->destroy(); + if (m_thread_safe) m_mutex.lock(); + ucp_request_free(ucx_req); + if (m_thread_safe) m_mutex.unlock(); + } + return found; } - return found; - } -}; + }; -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/config.hpp b/src/ucx/config.hpp index 0fedb2ff..3f091b14 100644 --- a/src/ucx/config.hpp +++ b/src/ucx/config.hpp @@ -16,24 +16,21 @@ // paths relative to backend #include #ifdef OOMPH_UCX_USE_PMI -#include +# include #else -#include +# include #endif #ifdef OOMPH_UCX_USE_SPIN_LOCK -#include -namespace oomph -{ -using ucx_mutex = pthread_spin::mutex; +# include +namespace oomph { + using ucx_mutex = pthread_spin::mutex; } #else -namespace oomph -{ -using ucx_mutex = std::mutex; +namespace oomph { + using ucx_mutex = std::mutex; } #endif -namespace oomph -{ -using ucx_lock = std::lock_guard; +namespace oomph { + using ucx_lock = std::lock_guard; } diff --git a/src/ucx/context.cpp b/src/ucx/context.cpp index 8a93faea..339ebc54 100644 --- a/src/ucx/context.cpp +++ b/src/ucx/context.cpp @@ -10,96 +10,94 @@ #include #ifndef NDEBUG -#include +# include #endif // 
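The cancel_recv() members in communicator.hpp above and in context.hpp below share one idiom for deciding whether a cancelled receive's callback actually reached the cancel queue: drain the lock-free queue, remember whether the target was seen, and push every other element back. A self-contained sketch of that idiom (remove_one and the parameter names are illustrative, not part of oomph):

#include <boost/lockfree/queue.hpp>

#include <vector>

// Drain q, report whether `target` had been enqueued, and re-enqueue the rest.
template <typename T>
bool remove_one(boost::lockfree::queue<T*>& q, T* target)
{
    bool found = false;
    std::vector<T*> keep;
    q.consume_all([&](T* p) {
        if (p == target) found = true;
        else keep.push_back(p);
    });
    for (T* p : keep)
        while (!q.push(p)) {}
    return found;
}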
paths relative to backend -#include #include +#include -namespace oomph -{ -communicator_impl* -context_impl::get_communicator() -{ - auto send_worker = std::make_unique(get(), m_db, - (m_thread_safe ? UCS_THREAD_MODE_SERIALIZED : UCS_THREAD_MODE_SINGLE)); - auto send_worker_ptr = send_worker.get(); - if (m_thread_safe) +namespace oomph { + communicator_impl* context_impl::get_communicator() { - ucx_lock l(m_mutex); - m_workers.push_back(std::move(send_worker)); + auto send_worker = std::make_unique( + get(), m_db, (m_thread_safe ? UCS_THREAD_MODE_SERIALIZED : UCS_THREAD_MODE_SINGLE)); + auto send_worker_ptr = send_worker.get(); + if (m_thread_safe) + { + ucx_lock l(m_mutex); + m_workers.push_back(std::move(send_worker)); + } + else { m_workers.push_back(std::move(send_worker)); } + auto comm = + new communicator_impl{this, m_thread_safe, m_worker.get(), send_worker_ptr, m_mutex}; + m_comms_set.insert(comm); + return comm; } - else { m_workers.push_back(std::move(send_worker)); } - auto comm = - new communicator_impl{this, m_thread_safe, m_worker.get(), send_worker_ptr, m_mutex}; - m_comms_set.insert(comm); - return comm; -} -context_impl::~context_impl() -{ - // issue a barrier to sync all contexts - MPI_Barrier(m_mpi_comm); + context_impl::~context_impl() + { + // issue a barrier to sync all contexts + MPI_Barrier(m_mpi_comm); - const auto t0 = std::chrono::system_clock::now(); - double elapsed = 0.0; - static constexpr double t_timeout = 1000; + auto const t0 = std::chrono::system_clock::now(); + double elapsed = 0.0; + static constexpr double t_timeout = 1000; - // close endpoints while also progressing the receive worker - std::vector handles; - for (auto& w_ptr : m_workers) - for (auto& h : w_ptr->m_endpoint_handles) handles.push_back(std::move(h)); + // close endpoints while also progressing the receive worker + std::vector handles; + for (auto& w_ptr : m_workers) + for (auto& h : w_ptr->m_endpoint_handles) handles.push_back(std::move(h)); - std::vector tmp; - tmp.reserve(handles.size()); + std::vector tmp; + tmp.reserve(handles.size()); - while (handles.size() != 0u && elapsed < t_timeout) - { - for (auto& h : handles) + while (handles.size() != 0u && elapsed < t_timeout) { - ucp_worker_progress(m_worker->m_worker); - if (!h.ready()) tmp.push_back(std::move(h)); + for (auto& h : handles) + { + ucp_worker_progress(m_worker->m_worker); + if (!h.ready()) tmp.push_back(std::move(h)); + } + handles.swap(tmp); + tmp.clear(); + elapsed = + std::chrono::duration(std::chrono::system_clock::now() - t0) + .count(); } - handles.swap(tmp); - tmp.clear(); - elapsed = std::chrono::duration(std::chrono::system_clock::now() - t0) - .count(); - } - if (handles.size() > 0) - { + if (handles.size() > 0) + { #ifndef NDEBUG - std::cerr << "WARNING: timeout waiting for UCX endpoint close" << std::endl; + std::cerr << "WARNING: timeout waiting for UCX endpoint close" << std::endl; #endif - // free all requests for the unclosed endpoints - for (auto& h : handles) ucp_request_free(h.m_status); - } + // free all requests for the unclosed endpoints + for (auto& h : handles) ucp_request_free(h.m_status); + } - // issue another non-blocking barrier while progressing the receive worker in order to flush all - // remaining (remote) endpoints which are connected to this receive worker - MPI_Request req; - int flag; - MPI_Ibarrier(m_mpi_comm, &req); - while (true) - { - ucp_worker_progress(m_worker->m_worker); - MPI_Test(&req, &flag, MPI_STATUS_IGNORE); - if (flag) break; - } + // issue another non-blocking barrier 
while progressing the receive worker in order to flush all + // remaining (remote) endpoints which are connected to this receive worker + MPI_Request req; + int flag; + MPI_Ibarrier(m_mpi_comm, &req); + while (true) + { + ucp_worker_progress(m_worker->m_worker); + MPI_Test(&req, &flag, MPI_STATUS_IGNORE); + if (flag) break; + } - // receive worker should not have connected to any endpoint - assert(m_worker->m_endpoint_cache.size() == 0); + // receive worker should not have connected to any endpoint + assert(m_worker->m_endpoint_cache.size() == 0); - // another MPI barrier to be sure - MPI_Barrier(m_mpi_comm); -} + // another MPI barrier to be sure + MPI_Barrier(m_mpi_comm); + } -const char* -context_impl::get_transport_option(const std::string& opt) -{ - if (opt == "name") { return "ucx"; } - else { return "unspecified"; } -} + char const* context_impl::get_transport_option(std::string const& opt) + { + if (opt == "name") { return "ucx"; } + else { return "unspecified"; } + } -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/context.hpp b/src/ucx/context.hpp index 231068cf..cd3e53db 100644 --- a/src/ucx/context.hpp +++ b/src/ucx/context.hpp @@ -9,8 +9,8 @@ */ #pragma once -#include #include +#include #include @@ -19,243 +19,239 @@ // paths relative to backend #include <../context_base.hpp> #include <./config.hpp> -#include +#include #include -#include -#include #include -#include - -namespace oomph -{ -#define OOMPH_UCX_TAG_BITS 32 -#define OOMPH_UCX_RANK_BITS 32 -#define OOMPH_UCX_ANY_SOURCE_MASK 0x0000000000000000ul -#define OOMPH_UCX_SPECIFIC_SOURCE_MASK 0x00000000fffffffful -#define OOMPH_UCX_TAG_MASK 0xffffffff00000000ul - -class context_impl : public context_base -{ - public: // member types - using region_type = region; - using device_region_type = region; - using heap_type = hwmalloc::heap; - using worker_type = worker_t; - - private: // member types - struct ucp_context_h_holder - { - ucp_context_h m_context; - ~ucp_context_h_holder() { ucp_cleanup(m_context); } - }; - - using worker_vector = std::vector>; - - template - using lockfree_queue = boost::lockfree::queue, - boost::lockfree::allocator>>; - - using recv_req_queue_type = lockfree_queue; - - private: // members - type_erased_address_db_t m_db; - ucp_context_h_holder m_context; - heap_type m_heap; - rma_context m_rma_context; - std::size_t m_req_size; - std::unique_ptr m_worker; // shared, serialized - per rank - std::vector> m_workers; - - public: - ucx_mutex m_mutex; - recv_req_queue_type m_recv_req_queue; - recv_req_queue_type m_cancel_recv_req_queue; +#include +#include +#include - friend struct worker_t; +namespace oomph { +#define OOMPH_UCX_TAG_BITS 32 +#define OOMPH_UCX_RANK_BITS 32 +#define OOMPH_UCX_ANY_SOURCE_MASK 0x0000'0000'0000'0000ul +#define OOMPH_UCX_SPECIFIC_SOURCE_MASK 0x0000'0000'ffff'fffful +#define OOMPH_UCX_TAG_MASK 0xffff'ffff'0000'0000ul - public: // ctors - context_impl(MPI_Comm mpi_c, bool thread_safe, bool message_pool_never_free, - std::size_t message_pool_reserve) - : context_base(mpi_c, thread_safe) + class context_impl : public context_base + { + public: // member types + using region_type = region; + using device_region_type = region; + using heap_type = hwmalloc::heap; + using worker_type = worker_t; + + private: // member types + struct ucp_context_h_holder + { + ucp_context_h m_context; + ~ucp_context_h_holder() { ucp_cleanup(m_context); } + }; + + using worker_vector = std::vector>; + + template + using lockfree_queue = boost::lockfree::queue, + 
boost::lockfree::allocator>>; + + using recv_req_queue_type = lockfree_queue; + + private: // members + type_erased_address_db_t m_db; + ucp_context_h_holder m_context; + heap_type m_heap; + rma_context m_rma_context; + std::size_t m_req_size; + std::unique_ptr m_worker; // shared, serialized - per rank + std::vector> m_workers; + + public: + ucx_mutex m_mutex; + recv_req_queue_type m_recv_req_queue; + recv_req_queue_type m_cancel_recv_req_queue; + + friend struct worker_t; + + public: // ctors + context_impl(MPI_Comm mpi_c, bool thread_safe, bool message_pool_never_free, + std::size_t message_pool_reserve) + : context_base(mpi_c, thread_safe) #if defined OOMPH_UCX_USE_PMI - , m_db(address_db_pmi(context_base::m_mpi_comm)) + , m_db(address_db_pmi(context_base::m_mpi_comm)) #else - , m_db(address_db_mpi(context_base::m_mpi_comm)) + , m_db(address_db_mpi(context_base::m_mpi_comm)) #endif - , m_heap{this, message_pool_never_free, message_pool_reserve} - , m_rma_context() - , m_recv_req_queue(128) - , m_cancel_recv_req_queue(128) - { - // read run-time context - ucp_config_t* config_ptr; - OOMPH_CHECK_UCX_RESULT(ucp_config_read(NULL, NULL, &config_ptr)); - - // set parameters - ucp_params_t context_params; - // define valid fields - context_params.field_mask = - UCP_PARAM_FIELD_FEATURES // features - | UCP_PARAM_FIELD_TAG_SENDER_MASK // mask which gets sender endpoint from a tag - | UCP_PARAM_FIELD_MT_WORKERS_SHARED // multi-threaded context: thread safety - | UCP_PARAM_FIELD_ESTIMATED_NUM_EPS // estimated number of endpoints for this context - | UCP_PARAM_FIELD_REQUEST_SIZE // size of reserved space in a non-blocking request - | UCP_PARAM_FIELD_REQUEST_INIT // initialize request memory - ; - - // features - context_params.features = UCP_FEATURE_TAG // tag matching - | UCP_FEATURE_RMA // RMA access support - ; - // thread safety - // this should be true if we have per-thread workers, - // otherwise, if one worker is shared by all thread, it should be false - // requires benchmarking. - // This flag indicates if this context is shared by multiple workers from different threads. - // If so, this context needs thread safety support; otherwise, the context does not need to - // provide thread safety. For example, if the context is used by single worker, and that - // worker is shared by multiple threads, this context does not need thread safety; if the - // context is used by worker 1 and worker 2, and worker 1 is used by thread 1 and worker 2 - // is used by thread 2, then this context needs thread safety. Note that actual thread mode - // may be different from mode passed to ucp_init. To get actual thread mode use - // ucp_context_query. - //context_params.mt_workers_shared = true; - context_params.mt_workers_shared = this->m_thread_safe; - // estimated number of connections - // affects transport selection criteria and theresulting performance - context_params.estimated_num_eps = m_db.est_size(); - // mask - // mask which specifies particular bits of the tag which can uniquely identify - // the sender (UCP endpoint) in tagged operations. 
- //context_params.tag_sender_mask = 0x00000000fffffffful; - context_params.tag_sender_mask = 0xfffffffffffffffful; - // additional usable request size - context_params.request_size = request_data_size::value; - // initialize a valid request_data object within the ucx provided memory - context_params.request_init = &request_data::init; - - // initialize UCP - OOMPH_CHECK_UCX_RESULT(ucp_init(&context_params, config_ptr, &m_context.m_context)); - ucp_config_release(config_ptr); - - // check the actual parameters - ucp_context_attr_t attr; - attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE | // internal request size - UCP_ATTR_FIELD_THREAD_MODE; // thread safety - ucp_context_query(m_context.m_context, &attr); - m_req_size = attr.request_size; - if (this->m_thread_safe && attr.thread_mode != UCS_THREAD_MODE_MULTI) - throw std::runtime_error("ucx cannot be used with multi-threaded context"); - - // make shared worker - // use single-threaded UCX mode, as per developer advice - // https://github.com/openucx/ucx/issues/4609 - m_worker.reset(new worker_type{get(), m_db, UCS_THREAD_MODE_SINGLE}); - - // intialize database - m_db.init(m_worker->address()); - - m_rma_context.set_ucp_context(m_context.m_context); - } + , m_heap{this, message_pool_never_free, message_pool_reserve} + , m_rma_context() + , m_recv_req_queue(128) + , m_cancel_recv_req_queue(128) + { + // read run-time context + ucp_config_t* config_ptr; + OOMPH_CHECK_UCX_RESULT(ucp_config_read(NULL, NULL, &config_ptr)); + + // set parameters + ucp_params_t context_params; + // define valid fields + context_params.field_mask = UCP_PARAM_FIELD_FEATURES // features + | UCP_PARAM_FIELD_TAG_SENDER_MASK // mask which gets sender endpoint from a tag + | UCP_PARAM_FIELD_MT_WORKERS_SHARED // multi-threaded context: thread safety + | + UCP_PARAM_FIELD_ESTIMATED_NUM_EPS // estimated number of endpoints for this context + | + UCP_PARAM_FIELD_REQUEST_SIZE // size of reserved space in a non-blocking request + | UCP_PARAM_FIELD_REQUEST_INIT // initialize request memory + ; + + // features + context_params.features = UCP_FEATURE_TAG // tag matching + | UCP_FEATURE_RMA // RMA access support + ; + // thread safety + // this should be true if we have per-thread workers, + // otherwise, if one worker is shared by all thread, it should be false + // requires benchmarking. + // This flag indicates if this context is shared by multiple workers from different threads. + // If so, this context needs thread safety support; otherwise, the context does not need to + // provide thread safety. For example, if the context is used by single worker, and that + // worker is shared by multiple threads, this context does not need thread safety; if the + // context is used by worker 1 and worker 2, and worker 1 is used by thread 1 and worker 2 + // is used by thread 2, then this context needs thread safety. Note that actual thread mode + // may be different from mode passed to ucp_init. To get actual thread mode use + // ucp_context_query. + //context_params.mt_workers_shared = true; + context_params.mt_workers_shared = this->m_thread_safe; + // estimated number of connections + // affects transport selection criteria and theresulting performance + context_params.estimated_num_eps = m_db.est_size(); + // mask + // mask which specifies particular bits of the tag which can uniquely identify + // the sender (UCP endpoint) in tagged operations. 
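// Illustrative note on the sender mask chosen below: with the layout used by the
// communicator (stag = (tag << OOMPH_UCX_TAG_BITS) | rank) the sender rank occupies
// the low OOMPH_UCX_RANK_BITS bits, i.e. OOMPH_UCX_SPECIFIC_SOURCE_MASK
// (0x0000'0000'ffff'fffful) already covers it; the all-ones value set here is a
// strict superset of that.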
+ //context_params.tag_sender_mask = 0x00000000fffffffful; + context_params.tag_sender_mask = 0xffff'ffff'ffff'fffful; + // additional usable request size + context_params.request_size = request_data_size::value; + // initialize a valid request_data object within the ucx provided memory + context_params.request_init = &request_data::init; + + // initialize UCP + OOMPH_CHECK_UCX_RESULT(ucp_init(&context_params, config_ptr, &m_context.m_context)); + ucp_config_release(config_ptr); + + // check the actual parameters + ucp_context_attr_t attr; + attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE | // internal request size + UCP_ATTR_FIELD_THREAD_MODE; // thread safety + ucp_context_query(m_context.m_context, &attr); + m_req_size = attr.request_size; + if (this->m_thread_safe && attr.thread_mode != UCS_THREAD_MODE_MULTI) + throw std::runtime_error("ucx cannot be used with multi-threaded context"); + + // make shared worker + // use single-threaded UCX mode, as per developer advice + // https://github.com/openucx/ucx/issues/4609 + m_worker.reset(new worker_type{get(), m_db, UCS_THREAD_MODE_SINGLE}); + + // intialize database + m_db.init(m_worker->address()); + + m_rma_context.set_ucp_context(m_context.m_context); + } - ~context_impl(); + ~context_impl(); - context_impl(context_impl&&) = delete; - context_impl& operator=(context_impl&&) = delete; + context_impl(context_impl&&) = delete; + context_impl& operator=(context_impl&&) = delete; - ucp_context_h get() const noexcept { return m_context.m_context; } + ucp_context_h get() const noexcept { return m_context.m_context; } - region make_region(void* ptr) { return {ptr}; } + region make_region(void* ptr) { return {ptr}; } - auto& get_heap() noexcept { return m_heap; } + auto& get_heap() noexcept { return m_heap; } - communicator_impl* get_communicator(); + communicator_impl* get_communicator(); - void progress() - { - //{ - // ucx_lock lock(m_mutex); - // while (ucp_worker_progress(m_worker->get())) {} - //} - if (m_mutex.try_lock()) + void progress() { - ucp_worker_progress(m_worker->get()); - m_mutex.unlock(); - } - m_recv_req_queue.consume_all( - [](detail::shared_request_state* req) + //{ + // ucx_lock lock(m_mutex); + // while (ucp_worker_progress(m_worker->get())) {} + //} + if (m_mutex.try_lock()) { + ucp_worker_progress(m_worker->get()); + m_mutex.unlock(); + } + m_recv_req_queue.consume_all([](detail::shared_request_state* req) { auto ptr = req->release_self_ref(); req->invoke_cb(); }); - } - - void enqueue_recv(detail::shared_request_state* d) - { - while (!m_recv_req_queue.push(d)) {} - } + } - void enqueue_cancel_recv(detail::shared_request_state* d) - { - while (!m_cancel_recv_req_queue.push(d)) {} - } + void enqueue_recv(detail::shared_request_state* d) + { + while (!m_recv_req_queue.push(d)) {} + } - bool cancel_recv(detail::shared_request_state* s) - { - if (m_thread_safe) m_mutex.lock(); - ucp_request_cancel(m_worker->get(), s->m_ucx_ptr); - while (ucp_worker_progress(m_worker->get())) {} - // check whether the cancelled callback was enqueued by consuming all queued cancelled - // callbacks and putting them in a temporary vector - static thread_local bool found = false; - static thread_local std::vector m_cancel_recv_req_vec; - m_cancel_recv_req_vec.clear(); - m_cancel_recv_req_queue.consume_all( - [this, s, found_ptr = &found](detail::shared_request_state* r) - { - if (r == s) *found_ptr = true; - else - m_cancel_recv_req_vec.push_back(r); - }); - // re-enqueue all callbacks which were not identical with the current callback - for 
(auto x : m_cancel_recv_req_vec) - while (!m_cancel_recv_req_queue.push(x)) {} - if (m_thread_safe) m_mutex.unlock(); + void enqueue_cancel_recv(detail::shared_request_state* d) + { + while (!m_cancel_recv_req_queue.push(d)) {} + } - // delete callback here if it was actually cancelled - if (found) + bool cancel_recv(detail::shared_request_state* s) { - auto ptr = s->release_self_ref(); - s->set_canceled(); - void* ucx_req = s->m_ucx_ptr; - // destroy request - request_data::get(ucx_req)->destroy(); if (m_thread_safe) m_mutex.lock(); - ucp_request_free(ucx_req); + ucp_request_cancel(m_worker->get(), s->m_ucx_ptr); + while (ucp_worker_progress(m_worker->get())) {} + // check whether the cancelled callback was enqueued by consuming all queued cancelled + // callbacks and putting them in a temporary vector + static thread_local bool found = false; + static thread_local std::vector m_cancel_recv_req_vec; + m_cancel_recv_req_vec.clear(); + m_cancel_recv_req_queue.consume_all( + [this, s, found_ptr = &found](detail::shared_request_state* r) { + if (r == s) + *found_ptr = true; + else + m_cancel_recv_req_vec.push_back(r); + }); + // re-enqueue all callbacks which were not identical with the current callback + for (auto x : m_cancel_recv_req_vec) + while (!m_cancel_recv_req_queue.push(x)) {} if (m_thread_safe) m_mutex.unlock(); + + // delete callback here if it was actually cancelled + if (found) + { + auto ptr = s->release_self_ref(); + s->set_canceled(); + void* ucx_req = s->m_ucx_ptr; + // destroy request + request_data::get(ucx_req)->destroy(); + if (m_thread_safe) m_mutex.lock(); + ucp_request_free(ucx_req); + if (m_thread_safe) m_mutex.unlock(); + } + return found; } - return found; - } - const char* get_transport_option(const std::string& opt); + char const* get_transport_option(std::string const& opt); - unsigned int num_tag_bits() const noexcept { return OOMPH_UCX_TAG_BITS; } -}; + unsigned int num_tag_bits() const noexcept { return OOMPH_UCX_TAG_BITS; } + }; -template<> -inline region -register_memory(context_impl& c, void* ptr, std::size_t) -{ - return c.make_region(ptr); -} + template <> + inline region register_memory(context_impl& c, void* ptr, std::size_t) + { + return c.make_region(ptr); + } #if OOMPH_ENABLE_DEVICE -template<> -inline region -register_device_memory(context_impl& c, int, void* ptr, std::size_t) -{ - return c.make_region(ptr); -} + template <> + inline region register_device_memory(context_impl& c, int, void* ptr, std::size_t) + { + return c.make_region(ptr); + } #endif -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/endpoint.hpp b/src/ucx/endpoint.hpp index a15b9e43..eea86695 100644 --- a/src/ucx/endpoint.hpp +++ b/src/ucx/endpoint.hpp @@ -15,96 +15,96 @@ // paths relative to backend #include -namespace oomph -{ -#define OOMPH_ANY_SOURCE (int)-1 - -struct endpoint_t -{ - rank_type m_rank; - ucp_ep_h m_ep; - ucp_worker_h m_worker; - util::moved_bit m_moved; - - endpoint_t() noexcept - : m_moved(true) - { - } - endpoint_t(rank_type rank, ucp_worker_h local_worker, const address_t& remote_worker_address) - : m_rank(rank) - , m_worker{local_worker} - { - ucp_ep_params_t ep_params; - ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ep_params.address = remote_worker_address.get(); - OOMPH_CHECK_UCX_RESULT(ucp_ep_create(local_worker, &ep_params, &(m_ep))); - } - - endpoint_t(const endpoint_t&) = delete; - endpoint_t& operator=(const endpoint_t&) = delete; - endpoint_t(endpoint_t&& other) noexcept = default; - endpoint_t& operator=(endpoint_t&& 
other) = delete; - - struct close_handle +namespace oomph { +#define OOMPH_ANY_SOURCE (int) -1 + + struct endpoint_t { - bool m_done; - ucp_worker_h m_ucp_worker; - ucs_status_ptr_t m_status; + rank_type m_rank; + ucp_ep_h m_ep; + ucp_worker_h m_worker; + util::moved_bit m_moved; - close_handle() - : m_done{true} + endpoint_t() noexcept + : m_moved(true) { } - - close_handle(ucp_worker_h worker, ucs_status_ptr_t status) - : m_done{false} - , m_ucp_worker{worker} - , m_status{status} + endpoint_t( + rank_type rank, ucp_worker_h local_worker, address_t const& remote_worker_address) + : m_rank(rank) + , m_worker{local_worker} { + ucp_ep_params_t ep_params; + ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; + ep_params.address = remote_worker_address.get(); + OOMPH_CHECK_UCX_RESULT(ucp_ep_create(local_worker, &ep_params, &(m_ep))); } - close_handle(close_handle&& other) - : m_done{std::exchange(other.m_done, true)} - , m_ucp_worker{other.m_ucp_worker} - , m_status{other.m_status} - { - } + endpoint_t(endpoint_t const&) = delete; + endpoint_t& operator=(endpoint_t const&) = delete; + endpoint_t(endpoint_t&& other) noexcept = default; + endpoint_t& operator=(endpoint_t&& other) = delete; - bool ready() + struct close_handle { - progress(); - return m_done; - } + bool m_done; + ucp_worker_h m_ucp_worker; + ucs_status_ptr_t m_status; - void progress() - { - if (!m_done) + close_handle() + : m_done{true} + { + } + + close_handle(ucp_worker_h worker, ucs_status_ptr_t status) + : m_done{false} + , m_ucp_worker{worker} + , m_status{status} { - ucp_worker_progress(m_ucp_worker); - if (UCS_OK != ucp_request_check_status(m_status)) + } + + close_handle(close_handle&& other) + : m_done{std::exchange(other.m_done, true)} + , m_ucp_worker{other.m_ucp_worker} + , m_status{other.m_status} + { + } + + bool ready() + { + progress(); + return m_done; + } + + void progress() + { + if (!m_done) { - ucp_request_free(m_status); - m_done = true; + ucp_worker_progress(m_ucp_worker); + if (UCS_OK != ucp_request_check_status(m_status)) + { + ucp_request_free(m_status); + m_done = true; + } } } + }; + + close_handle close() + { + if (m_moved) return {}; + ucs_status_ptr_t ret = ucp_ep_close_nb(m_ep, UCP_EP_CLOSE_MODE_FLUSH); + if (UCS_OK == reinterpret_cast(ret)) return {}; + if (UCS_PTR_IS_ERR(ret)) return {}; + return {m_worker, ret}; } + + //operator bool() const noexcept { return m_moved; } + operator ucp_ep_h() const noexcept { return m_ep; } + + rank_type rank() const noexcept { return m_rank; } + ucp_ep_h& get() noexcept { return m_ep; } + ucp_ep_h const& get() const noexcept { return m_ep; } }; - close_handle close() - { - if (m_moved) return {}; - ucs_status_ptr_t ret = ucp_ep_close_nb(m_ep, UCP_EP_CLOSE_MODE_FLUSH); - if (UCS_OK == reinterpret_cast(ret)) return {}; - if (UCS_PTR_IS_ERR(ret)) return {}; - return {m_worker, ret}; - } - - //operator bool() const noexcept { return m_moved; } - operator ucp_ep_h() const noexcept { return m_ep; } - - rank_type rank() const noexcept { return m_rank; } - ucp_ep_h& get() noexcept { return m_ep; } - const ucp_ep_h& get() const noexcept { return m_ep; } -}; - -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/error.hpp b/src/ucx/error.hpp index df9ddc73..23df7a9e 100644 --- a/src/ucx/error.hpp +++ b/src/ucx/error.hpp @@ -12,20 +12,20 @@ #include #ifdef NDEBUG -#define OOMPH_CHECK_UCX_RESULT(x) x; -#define OOMPH_CHECK_UCX_RESULT_NOEXCEPT(x) x; +# define OOMPH_CHECK_UCX_RESULT(x) x; +# define OOMPH_CHECK_UCX_RESULT_NOEXCEPT(x) x; #else -#include 
-#include -#define OOMPH_CHECK_UCX_RESULT(x) \ - if (x != UCS_OK) \ - throw std::runtime_error("OOMPH Error: UCX Call failed " + std::string(#x) + " in " + \ - std::string(__FILE__) + ":" + std::to_string(__LINE__)); -#define OOMPH_CHECK_UCX_RESULT_NOEXCEPT(x) \ - if (x != UCX_OK) \ - { \ - std::cerr << "OOMPH Error: UCX Call failed " << std::string(#x) << " in " \ - << std::string(__FILE__) << ":" << std::to_string(__LINE__) << std::endl; \ - std::terminate(); \ - } +# include +# include +# define OOMPH_CHECK_UCX_RESULT(x) \ + if (x != UCS_OK) \ + throw std::runtime_error("OOMPH Error: UCX Call failed " + std::string(#x) + " in " + \ + std::string(__FILE__) + ":" + std::to_string(__LINE__)); +# define OOMPH_CHECK_UCX_RESULT_NOEXCEPT(x) \ + if (x != UCX_OK) \ + { \ + std::cerr << "OOMPH Error: UCX Call failed " << std::string(#x) << " in " \ + << std::string(__FILE__) << ":" << std::to_string(__LINE__) << std::endl; \ + std::terminate(); \ + } #endif diff --git a/src/ucx/handle.hpp b/src/ucx/handle.hpp index 77ad4c62..54e9f079 100644 --- a/src/ucx/handle.hpp +++ b/src/ucx/handle.hpp @@ -12,12 +12,11 @@ // paths relative to backend #include -namespace oomph -{ -struct handle -{ - void* m_ptr; - std::size_t m_size; -}; +namespace oomph { + struct handle + { + void* m_ptr; + std::size_t m_size; + }; -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/pmi.hpp b/src/ucx/pmi.hpp index 0642df95..a90ac4c1 100644 --- a/src/ucx/pmi.hpp +++ b/src/ucx/pmi.hpp @@ -9,10 +9,10 @@ */ #pragma once -#include +#include #include #include -#include +#include #include @@ -21,162 +21,163 @@ // paths relative to backend #include -namespace oomph -{ -class pmi -{ - private: - util::moved_bit m_moved; - pmix_proc_t allproc; - pmix_proc_t myproc; - int32_t nprocs; - - public: - using rank_type = int; - using size_type = int; - - public: - pmi() +namespace oomph { + class pmi { - int rc; - pmix_value_t* pvalue; + private: + util::moved_bit m_moved; + pmix_proc_t allproc; + pmix_proc_t myproc; + int32_t nprocs; + + public: + using rank_type = int; + using size_type = int; + + public: + pmi() + { + int rc; + pmix_value_t* pvalue; - if (PMIX_SUCCESS != (rc = PMIx_Init(&myproc, NULL, 0))) - { throw std::runtime_error("PMIx_Init failed with code " + std::to_string(rc)); } - if (myproc.rank == 0) OOMPH_LOG("%d PMIx initialized", myproc.rank); + if (PMIX_SUCCESS != (rc = PMIx_Init(&myproc, NULL, 0))) + { + throw std::runtime_error("PMIx_Init failed with code " + std::to_string(rc)); + } + if (myproc.rank == 0) OOMPH_LOG("%d PMIx initialized", myproc.rank); - /* job-related info is found in our nspace, assigned to the + /* job-related info is found in our nspace, assigned to the * wildcard rank as it doesn't relate to a specific rank. 
Setup * a name to retrieve such values */ - PMIX_PROC_CONSTRUCT(&allproc); - // (void)strncpy(allproc.nspace, myproc.nspace, PMIX_MAX_NSLEN); - std::memcpy(allproc.nspace, myproc.nspace, PMIX_MAX_NSLEN); - allproc.rank = PMIX_RANK_WILDCARD; - - /* get the number of procs in our job */ - if (PMIX_SUCCESS != (rc = PMIx_Get(&allproc, PMIX_JOB_SIZE, NULL, 0, &pvalue))) - { - std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); - throw std::runtime_error("Client ns " + nspace + " rank " + - std::to_string(myproc.rank) + - ": PMIx_Get job size failed: " + std::to_string(rc) + "\n"); + PMIX_PROC_CONSTRUCT(&allproc); + // (void)strncpy(allproc.nspace, myproc.nspace, PMIX_MAX_NSLEN); + std::memcpy(allproc.nspace, myproc.nspace, PMIX_MAX_NSLEN); + allproc.rank = PMIX_RANK_WILDCARD; + + /* get the number of procs in our job */ + if (PMIX_SUCCESS != (rc = PMIx_Get(&allproc, PMIX_JOB_SIZE, NULL, 0, &pvalue))) + { + std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); + throw std::runtime_error("Client ns " + nspace + " rank " + + std::to_string(myproc.rank) + + ": PMIx_Get job size failed: " + std::to_string(rc) + "\n"); + } + nprocs = pvalue->data.uint32; + PMIX_VALUE_RELEASE(pvalue); } - nprocs = pvalue->data.uint32; - PMIX_VALUE_RELEASE(pvalue); - } - pmi(const pmi&) = delete; - pmi(pmi&&) = default; - pmi& operator=(const pmi&) = delete; - pmi& operator=(pmi&&) = default; + pmi(pmi const&) = delete; + pmi(pmi&&) = default; + pmi& operator=(pmi const&) = delete; + pmi& operator=(pmi&&) = default; - ~pmi() - { - int rc; - if (m_moved) return; - if (PMIX_SUCCESS != (rc = PMIx_Finalize(NULL, 0))) - { - OOMPH_WARN( - "Client ns %s rank %d:PMIx_Finalize failed: %d\n", myproc.nspace, myproc.rank, rc); - } - else + ~pmi() { - if (myproc.rank == 0) OOMPH_LOG("%d PMIx finalized", myproc.rank); + int rc; + if (m_moved) return; + if (PMIX_SUCCESS != (rc = PMIx_Finalize(NULL, 0))) + { + OOMPH_WARN("Client ns %s rank %d:PMIx_Finalize failed: %d\n", myproc.nspace, + myproc.rank, rc); + } + else + { + if (myproc.rank == 0) OOMPH_LOG("%d PMIx finalized", myproc.rank); + } } - } - rank_type rank() { return myproc.rank; } + rank_type rank() { return myproc.rank; } - size_type size() { return nprocs; } + size_type size() { return nprocs; } - void set(const std::string key, const std::vector data) - { - int rc; - pmix_value_t value; - - PMIX_VALUE_CONSTRUCT(&value); - value.type = PMIX_BYTE_OBJECT; - value.data.bo.bytes = (char*)data.data(); - value.data.bo.size = data.size(); - if (PMIX_SUCCESS != (rc = PMIx_Put(PMIX_GLOBAL, key.c_str(), &value))) + void set(std::string const key, std::vector const data) { - std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); - throw std::runtime_error("Client ns " + nspace + " rank " + - std::to_string(myproc.rank) + ": PMIx_Put failed: " + key + - " " + std::to_string(rc) + "\n"); + int rc; + pmix_value_t value; + + PMIX_VALUE_CONSTRUCT(&value); + value.type = PMIX_BYTE_OBJECT; + value.data.bo.bytes = (char*) data.data(); + value.data.bo.size = data.size(); + if (PMIX_SUCCESS != (rc = PMIx_Put(PMIX_GLOBAL, key.c_str(), &value))) + { + std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); + throw std::runtime_error("Client ns " + nspace + " rank " + + std::to_string(myproc.rank) + ": PMIx_Put failed: " + key + " " + + std::to_string(rc) + "\n"); + } + + /* protect the data */ + value.data.bo.bytes = NULL; + value.data.bo.size = 0; + PMIX_VALUE_DESTRUCT(&value); + OOMPH_LOG("PMIx_Put on %s", 
key.c_str()); + + if (PMIX_SUCCESS != (rc = PMIx_Commit())) + { + std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); + throw std::runtime_error("Client ns " + nspace + " rank " + + std::to_string(myproc.rank) + ": PMIx_Commit failed: " + key + " " + + std::to_string(rc) + "\n"); + } + OOMPH_LOG("PMIx_Commit on %s", key.c_str()); } - /* protect the data */ - value.data.bo.bytes = NULL; - value.data.bo.size = 0; - PMIX_VALUE_DESTRUCT(&value); - OOMPH_LOG("PMIx_Put on %s", key.c_str()); - - if (PMIX_SUCCESS != (rc = PMIx_Commit())) + std::vector get(uint32_t peer_rank, std::string const key) { - std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); - throw std::runtime_error("Client ns " + nspace + " rank " + - std::to_string(myproc.rank) + ": PMIx_Commit failed: " + key + - " " + std::to_string(rc) + "\n"); + int rc; + pmix_proc_t proc; + pmix_value_t* pvalue; + + PMIX_PROC_CONSTRUCT(&proc); + // (void)strncpy(proc.nspace, myproc.nspace, PMIX_MAX_NSLEN); + std::memcpy(proc.nspace, myproc.nspace, PMIX_MAX_NSLEN); + proc.rank = peer_rank; + if (PMIX_SUCCESS != (rc = PMIx_Get(&proc, key.c_str(), NULL, 0, &pvalue))) + { + std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); + throw std::runtime_error("Client ns " + nspace + " rank " + + std::to_string(myproc.rank) + ": PMIx_Get " + key + ": " + std::to_string(rc) + + "\n"); + } + if (pvalue->type != PMIX_BYTE_OBJECT) + { + std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); + throw std::runtime_error("Client ns " + nspace + " rank " + + std::to_string(myproc.rank) + ": PMIx_Get " + key + ": " + std::to_string(rc) + + "\n"); + } + + /* get the returned data */ + std::vector data( + pvalue->data.bo.bytes, pvalue->data.bo.bytes + pvalue->data.bo.size); + + /* free the PMIx data */ + PMIX_VALUE_RELEASE(pvalue); + PMIX_PROC_DESTRUCT(&proc); + + OOMPH_LOG("PMIx_get %s returned %zi bytes", key.c_str(), data.size()); + + return data; } - OOMPH_LOG("PMIx_Commit on %s", key.c_str()); - } - std::vector get(uint32_t peer_rank, const std::string key) - { - int rc; - pmix_proc_t proc; - pmix_value_t* pvalue; - - PMIX_PROC_CONSTRUCT(&proc); - // (void)strncpy(proc.nspace, myproc.nspace, PMIX_MAX_NSLEN); - std::memcpy(proc.nspace, myproc.nspace, PMIX_MAX_NSLEN); - proc.rank = peer_rank; - if (PMIX_SUCCESS != (rc = PMIx_Get(&proc, key.c_str(), NULL, 0, &pvalue))) - { - std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); - throw std::runtime_error("Client ns " + nspace + " rank " + - std::to_string(myproc.rank) + ": PMIx_Get " + key + ": " + - std::to_string(rc) + "\n"); - } - if (pvalue->type != PMIX_BYTE_OBJECT) - { - std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); - throw std::runtime_error("Client ns " + nspace + " rank " + - std::to_string(myproc.rank) + ": PMIx_Get " + key + ": " + - std::to_string(rc) + "\n"); - } - - /* get the returned data */ - std::vector data( - pvalue->data.bo.bytes, pvalue->data.bo.bytes + pvalue->data.bo.size); - - /* free the PMIx data */ - PMIX_VALUE_RELEASE(pvalue); - PMIX_PROC_DESTRUCT(&proc); - - OOMPH_LOG("PMIx_get %s returned %zi bytes", key.c_str(), data.size()); - - return data; - } - - void exchange() - { - int rc; - pmix_info_t info; - bool flag; - - PMIX_INFO_CONSTRUCT(&info); - flag = true; - PMIX_INFO_LOAD(&info, PMIX_COLLECT_DATA, &flag, PMIX_BOOL); - if (PMIX_SUCCESS != (rc = PMIx_Fence(&allproc, 1, &info, 1))) + void exchange() { - std::string nspace(myproc.nspace, myproc.nspace + 
strlen(myproc.nspace)); - throw std::runtime_error("Client ns " + nspace + " rank " + - std::to_string(myproc.rank) + - ": PMIx_Fence failed: " + std::to_string(rc) + "\n"); + int rc; + pmix_info_t info; + bool flag; + + PMIX_INFO_CONSTRUCT(&info); + flag = true; + PMIX_INFO_LOAD(&info, PMIX_COLLECT_DATA, &flag, PMIX_BOOL); + if (PMIX_SUCCESS != (rc = PMIx_Fence(&allproc, 1, &info, 1))) + { + std::string nspace(myproc.nspace, myproc.nspace + strlen(myproc.nspace)); + throw std::runtime_error("Client ns " + nspace + " rank " + + std::to_string(myproc.rank) + ": PMIx_Fence failed: " + std::to_string(rc) + + "\n"); + } + PMIX_INFO_DESTRUCT(&info); } - PMIX_INFO_DESTRUCT(&info); - } -}; -} // namespace oomph + }; +} // namespace oomph diff --git a/src/ucx/pthread_spin_mutex.hpp b/src/ucx/pthread_spin_mutex.hpp index 2e191904..23ad704c 100644 --- a/src/ucx/pthread_spin_mutex.hpp +++ b/src/ucx/pthread_spin_mutex.hpp @@ -11,30 +11,26 @@ #include -namespace oomph -{ -namespace pthread_spin -{ -class mutex -{ - private: // members - pthread_spinlock_t m_lock; +namespace oomph { namespace pthread_spin { + class mutex + { + private: // members + pthread_spinlock_t m_lock; - public: - mutex() noexcept { pthread_spin_init(&m_lock, PTHREAD_PROCESS_PRIVATE); } - mutex(const mutex&) = delete; - mutex(mutex&&) = delete; - ~mutex() { pthread_spin_destroy(&m_lock); } + public: + mutex() noexcept { pthread_spin_init(&m_lock, PTHREAD_PROCESS_PRIVATE); } + mutex(mutex const&) = delete; + mutex(mutex&&) = delete; + ~mutex() { pthread_spin_destroy(&m_lock); } - inline bool try_lock() noexcept { return (pthread_spin_trylock(&m_lock) == 0); } + inline bool try_lock() noexcept { return (pthread_spin_trylock(&m_lock) == 0); } - inline void lock() noexcept - { - while (!try_lock()) { sched_yield(); } - } + inline void lock() noexcept + { + while (!try_lock()) { sched_yield(); } + } - inline void unlock() noexcept { pthread_spin_unlock(&m_lock); } -}; + inline void unlock() noexcept { pthread_spin_unlock(&m_lock); } + }; -} // namespace pthread_spin -} // namespace oomph +}} // namespace oomph::pthread_spin diff --git a/src/ucx/region.hpp b/src/ucx/region.hpp index 033dca35..8d43aa91 100644 --- a/src/ucx/region.hpp +++ b/src/ucx/region.hpp @@ -14,57 +14,56 @@ // paths relative to backend #include -namespace oomph -{ -class region -{ - public: - using handle_type = handle; - - private: - void* m_ptr; - - public: - region(void* ptr) - : m_ptr{ptr} +namespace oomph { + class region { - } + public: + using handle_type = handle; - region(region const&) = delete; + private: + void* m_ptr; - region(region&& r) noexcept - : m_ptr{std::exchange(r.m_ptr, nullptr)} - { - } + public: + region(void* ptr) + : m_ptr{ptr} + { + } - // get a handle to some portion of the region - handle_type get_handle(std::size_t offset, std::size_t size) - { - return {(void*)((char*)m_ptr + offset), size}; - } -}; - -class rma_region -{ - public: - using handle_type = handle; - - private: - ucp_context_h m_ucp_context; - void* m_ptr; - std::size_t m_size; - ucp_mem_h m_memh; - - public: - rma_region(ucp_context_h ctxt, void* ptr, std::size_t size, bool gpu = false) - : m_ucp_context{ctxt} - , m_ptr{ptr} - , m_size{size} + region(region const&) = delete; + + region(region&& r) noexcept + : m_ptr{std::exchange(r.m_ptr, nullptr)} + { + } + + // get a handle to some portion of the region + handle_type get_handle(std::size_t offset, std::size_t size) + { + return {(void*) ((char*) m_ptr + offset), size}; + } + }; + + class rma_region { - 
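// rma_region registers the range [ptr, ptr + size) with UCX via ucp_mem_map in its
// constructor and releases it with ucp_mem_unmap in its destructor. A minimal sketch of
// the intended use, based on the make_region factory shown further down in this patch
// (rma_context.hpp); variable names are illustrative only:
//     auto r = rma_ctxt.make_region(ptr, size, /*gpu=*/false);  // maps the memory with UCX
//     auto h = r.get_handle(0, size);                           // handle to a sub-range
//     // the mapping is released automatically when r goes out of scope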
ucp_mem_map_params_t params; + public: + using handle_type = handle; + + private: + ucp_context_h m_ucp_context; + void* m_ptr; + std::size_t m_size; + ucp_mem_h m_memh; + + public: + rma_region(ucp_context_h ctxt, void* ptr, std::size_t size, bool gpu = false) + : m_ucp_context{ctxt} + , m_ptr{ptr} + , m_size{size} + { + ucp_mem_map_params_t params; - // enable fields - /* clang-format off */ + // enable fields + /* clang-format off */ params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS // enable address field | UCP_MEM_MAP_PARAM_FIELD_LENGTH // enable length field @@ -73,51 +72,51 @@ class rma_region | UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE // enable memory type field #endif ; - /* clang-format on */ + /* clang-format on */ - // set fields - params.address = ptr; - params.length = size; -#if (UCP_API_VERSION >= 17432576) // version >= 1.10 - params.memory_type = UCS_MEMORY_TYPE_HOST; + // set fields + params.address = ptr; + params.length = size; +#if (UCP_API_VERSION >= 17432576) // version >= 1.10 + params.memory_type = UCS_MEMORY_TYPE_HOST; #endif - // special treatment for gpu memory + // special treatment for gpu memory #if OOMPH_ENABLE_DEVICE | !defined(OOMPH_DEVICE_EMULATE) - if (gpu) - { -#if (UCP_API_VERSION >= 17432576) // version >= 1.10 -#if defined(OOMPH_DEVICE_CUDA) - params.memory_type = UCS_MEMORY_TYPE_CUDA; -#elif defined(OOMPH_DEVICE_HIP) - params.memory_type = UCS_MEMORY_TYPE_ROCM; -#endif -#endif - } + if (gpu) + { +# if (UCP_API_VERSION >= 17432576) // version >= 1.10 +# if defined(OOMPH_DEVICE_CUDA) + params.memory_type = UCS_MEMORY_TYPE_CUDA; +# elif defined(OOMPH_DEVICE_HIP) + params.memory_type = UCS_MEMORY_TYPE_ROCM; +# endif +# endif + } #endif - // register memory - OOMPH_CHECK_UCX_RESULT(ucp_mem_map(m_ucp_context, ¶ms, &m_memh)); - } + // register memory + OOMPH_CHECK_UCX_RESULT(ucp_mem_map(m_ucp_context, ¶ms, &m_memh)); + } - rma_region(rma_region const&) = delete; - rma_region(rma_region&& r) noexcept - : m_ucp_context{r.m_ucp_context} - , m_ptr{std::exchange(r.m_ptr, nullptr)} - , m_size{r.m_size} - , m_memh{r.m_memh} - { - } - ~rma_region() - { - if (m_ptr) { ucp_mem_unmap(m_ucp_context, m_memh); } - } + rma_region(rma_region const&) = delete; + rma_region(rma_region&& r) noexcept + : m_ucp_context{r.m_ucp_context} + , m_ptr{std::exchange(r.m_ptr, nullptr)} + , m_size{r.m_size} + , m_memh{r.m_memh} + { + } + ~rma_region() + { + if (m_ptr) { ucp_mem_unmap(m_ucp_context, m_memh); } + } - // get a handle to some portion of the region - handle_type get_handle(std::size_t offset, std::size_t size) - { - return {(void*)((char*)m_ptr + offset), size}; - } -}; + // get a handle to some portion of the region + handle_type get_handle(std::size_t offset, std::size_t size) + { + return {(void*) ((char*) m_ptr + offset), size}; + } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/request_data.hpp b/src/ucx/request_data.hpp index 9c716658..f0796d7b 100644 --- a/src/ucx/request_data.hpp +++ b/src/ucx/request_data.hpp @@ -12,55 +12,55 @@ // paths relative to backend #include -namespace oomph -{ -class communicator_impl; +namespace oomph { + class communicator_impl; -struct request_data -{ - detail::request_state* m_req; - detail::shared_request_state* m_shared_req; - //bool m_empty; - - void destroy() + struct request_data { - //m_comm = nullptr; - //m_cb.~cb_t(); - //m_empty = true; - m_req = nullptr; - m_shared_req = nullptr; - } + detail::request_state* m_req; + detail::shared_request_state* m_shared_req; + //bool m_empty; - bool empty() const 
noexcept { return !((bool)m_req || (bool)m_shared_req); } + void destroy() + { + //m_comm = nullptr; + //m_cb.~cb_t(); + //m_empty = true; + m_req = nullptr; + m_shared_req = nullptr; + } - static request_data* construct(void* ptr, detail::request_state* req) - { - return ::new (get_impl(ptr)) request_data{req, nullptr}; - } + bool empty() const noexcept { return !((bool) m_req || (bool) m_shared_req); } - static request_data* construct(void* ptr, detail::shared_request_state* req) - { - return ::new (get_impl(ptr)) request_data{nullptr, req}; - } + static request_data* construct(void* ptr, detail::request_state* req) + { + return ::new (get_impl(ptr)) request_data{req, nullptr}; + } - // return pointer to an instance from ucx provided storage pointer - static request_data* get(void* ptr) { return std::launder(get_impl(ptr)); } + static request_data* construct(void* ptr, detail::shared_request_state* req) + { + return ::new (get_impl(ptr)) request_data{nullptr, req}; + } - // initialize request on pristine request data allocated by ucx - static void init(void* ptr) { get(ptr)->destroy(); } + // return pointer to an instance from ucx provided storage pointer + static request_data* get(void* ptr) { return std::launder(get_impl(ptr)); } - private: - static request_data* get_impl(void* ptr) - { - // alignment mask - static constexpr std::uintptr_t mask = ~(alignof(request_data) - 1u); - return reinterpret_cast( - (reinterpret_cast((unsigned char*)ptr) + alignof(request_data) - 1) & - mask); - } -}; + // initialize request on pristine request data allocated by ucx + static void init(void* ptr) { get(ptr)->destroy(); } + + private: + static request_data* get_impl(void* ptr) + { + // alignment mask + static constexpr std::uintptr_t mask = ~(alignof(request_data) - 1u); + return reinterpret_cast( + (reinterpret_cast((unsigned char*) ptr) + alignof(request_data) - + 1) & + mask); + } + }; -using request_data_size = - std::integral_constant; + using request_data_size = + std::integral_constant; -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/request_state.hpp b/src/ucx/request_state.hpp index 4549b8c8..0f9ffee3 100644 --- a/src/ucx/request_state.hpp +++ b/src/ucx/request_state.hpp @@ -15,86 +15,83 @@ #include <../request_state_base.hpp> #include -namespace oomph -{ -namespace detail -{ -struct request_state -: public util::enable_shared_from_this -, public request_state_base -{ - using base = request_state_base; - using shared_ptr_t = util::unsafe_shared_ptr; - - void* m_ucx_ptr; - ucx_mutex& m_mutex; - shared_ptr_t m_self_ptr; - - request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, std::size_t* scheduled, - rank_type rank, tag_type tag, cb_type&& cb, void* ucx_ptr, ucx_mutex& mtx) - : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} - , m_ucx_ptr{ucx_ptr} - , m_mutex{mtx} +namespace oomph { namespace detail { + struct request_state + : public util::enable_shared_from_this + , public request_state_base { - } - - void progress(); - - bool cancel(); - - void create_self_ref() + using base = request_state_base; + using shared_ptr_t = util::unsafe_shared_ptr; + + void* m_ucx_ptr; + ucx_mutex& m_mutex; + shared_ptr_t m_self_ptr; + + request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, + std::size_t* scheduled, rank_type rank, tag_type tag, cb_type&& cb, void* ucx_ptr, + ucx_mutex& mtx) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_ucx_ptr{ucx_ptr} + , m_mutex{mtx} + { + } + + void progress(); + + bool cancel(); + + void 
create_self_ref() + { + // create a self-reference cycle!! + // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool) m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } + }; + + struct shared_request_state + : public std::enable_shared_from_this + , public request_state_base { - // create a self-reference cycle!! - // this is useful if we only keep a raw pointer around internally, which still is supposed - // to keep the object alive - m_self_ptr = shared_from_this(); - } - - shared_ptr_t release_self_ref() noexcept - { - assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); - return std::move(m_self_ptr); - } -}; - -struct shared_request_state -: public std::enable_shared_from_this -, public request_state_base -{ - using base = request_state_base; - using shared_ptr_t = std::shared_ptr; - - void* m_ucx_ptr; - ucx_mutex& m_mutex; - shared_ptr_t m_self_ptr; - - shared_request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, - std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb, - void* ucx_ptr, ucx_mutex& mtx) - : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} - , m_ucx_ptr{ucx_ptr} - , m_mutex{mtx} - { - } - - void progress(); - - bool cancel(); - - void create_self_ref() - { - // create a self-reference cycle!! - // this is useful if we only keep a raw pointer around internally, which still is supposed - // to keep the object alive - m_self_ptr = shared_from_this(); - } - - shared_ptr_t release_self_ref() noexcept - { - assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); - return std::move(m_self_ptr); - } -}; - -} // namespace detail -} // namespace oomph + using base = request_state_base; + using shared_ptr_t = std::shared_ptr; + + void* m_ucx_ptr; + ucx_mutex& m_mutex; + shared_ptr_t m_self_ptr; + + shared_request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, + std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb, + void* ucx_ptr, ucx_mutex& mtx) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_ucx_ptr{ucx_ptr} + , m_mutex{mtx} + { + } + + void progress(); + + bool cancel(); + + void create_self_ref() + { + // create a self-reference cycle!! 
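// A minimal sketch of the intended life cycle; the completion path is an assumption based
// on the surrounding UCX callback machinery and is not shown here:
//     req->create_self_ref();               // request now keeps itself alive
//     void* user_data = req.get();          // raw pointer handed to UCX
//     // ... UCX completes the operation and the callback runs ...
//     auto keep = req->release_self_ref();  // gives up the self-reference
//     // the request is destroyed once 'keep' and any remaining owners go away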
+ // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool) m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } + }; + +}} // namespace oomph::detail diff --git a/src/ucx/rma_context.hpp b/src/ucx/rma_context.hpp index 62da088c..c0b5dd73 100644 --- a/src/ucx/rma_context.hpp +++ b/src/ucx/rma_context.hpp @@ -9,9 +9,9 @@ */ #pragma once +#include #include #include -#include #include @@ -19,51 +19,49 @@ #include #include -namespace oomph -{ -class rma_context -{ - public: - using region_type = rma_region; - using device_region_type = rma_region; - using heap_type = hwmalloc::heap; +namespace oomph { + class rma_context + { + public: + using region_type = rma_region; + using device_region_type = rma_region; + using heap_type = hwmalloc::heap; - private: - heap_type m_heap; - ucp_context_h m_context; + private: + heap_type m_heap; + ucp_context_h m_context; - public: - rma_context() - : m_heap{this} - { - } - rma_context(context_impl const&) = delete; - rma_context(context_impl&&) = delete; + public: + rma_context() + : m_heap{this} + { + } + rma_context(context_impl const&) = delete; + rma_context(context_impl&&) = delete; - rma_region make_region(void* ptr, std::size_t size, bool gpu = false) - { - return {m_context, ptr, size, gpu}; - } + rma_region make_region(void* ptr, std::size_t size, bool gpu = false) + { + return {m_context, ptr, size, gpu}; + } - auto& get_heap() noexcept { return m_heap; } + auto& get_heap() noexcept { return m_heap; } - void set_ucp_context(ucp_context_h c) { m_context = c; } -}; + void set_ucp_context(ucp_context_h c) { m_context = c; } + }; -template<> -inline rma_region -register_memory(rma_context& c, void* ptr, std::size_t size) -{ - return c.make_region(ptr, size); -} + template <> + inline rma_region register_memory(rma_context& c, void* ptr, std::size_t size) + { + return c.make_region(ptr, size); + } #if OOMPH_ENABLE_DEVICE -template<> -inline rma_region -register_device_memory(rma_context& c, int, void* ptr, std::size_t size) -{ - return c.make_region(ptr, size, true); -} + template <> + inline rma_region + register_device_memory(rma_context& c, int, void* ptr, std::size_t size) + { + return c.make_region(ptr, size, true); + } #endif -} // namespace oomph +} // namespace oomph diff --git a/src/ucx/worker.hpp b/src/ucx/worker.hpp index 94e589b7..669c2bd5 100644 --- a/src/ucx/worker.hpp +++ b/src/ucx/worker.hpp @@ -9,118 +9,117 @@ */ #pragma once -#include #include +#include #include #include // paths relative to backend -#include #include +#include -namespace oomph -{ -struct worker_t -{ - struct ucp_worker_handle +namespace oomph { + struct worker_t { - ucp_worker_h m_worker; - util::moved_bit m_moved; - - ucp_worker_handle() noexcept - : m_moved{true} + struct ucp_worker_handle { - } - ucp_worker_handle(const ucp_worker_handle&) = delete; - ucp_worker_handle& operator=(const ucp_worker_handle&) = delete; - ucp_worker_handle(ucp_worker_handle&& other) noexcept = default; - - ucp_worker_handle& operator=(ucp_worker_handle&& other) noexcept + ucp_worker_h m_worker; + util::moved_bit m_moved; + + ucp_worker_handle() noexcept + : m_moved{true} + { + } + ucp_worker_handle(ucp_worker_handle const&) = delete; + ucp_worker_handle& operator=(ucp_worker_handle const&) = delete; + ucp_worker_handle(ucp_worker_handle&& other) noexcept = default; + + ucp_worker_handle& 
operator=(ucp_worker_handle&& other) noexcept + { + destroy(); + m_worker.~ucp_worker_h(); + ::new ((void*) (&m_worker)) ucp_worker_h{other.m_worker}; + m_moved = std::move(other.m_moved); + return *this; + } + + ~ucp_worker_handle() { destroy(); } + + static void empty_send_cb(void*, ucs_status_t) {} + + void destroy() noexcept + { + if (!m_moved) ucp_worker_destroy(m_worker); + } + + operator ucp_worker_h() const noexcept { return m_worker; } + + ucp_worker_h& get() noexcept { return m_worker; } + ucp_worker_h const& get() const noexcept { return m_worker; } + }; + + using ep_handle_vector = std::vector; + using cache_type = std::unordered_map; + //using mutex_t = pthread_spin::recursive_mutex; + + //const mpi::rank_topology& m_rank_topology; + type_erased_address_db_t& m_db; + rank_type m_rank; + rank_type m_size; + ucp_worker_handle m_worker; + address_t m_address; + ep_handle_vector m_endpoint_handles; + cache_type m_endpoint_cache; + //int m_progressed_sends = 0; + //mutex_t* m_mutex_ptr = nullptr; + //volatile int m_progressed_recvs = 0; + //volatile int m_progressed_cancels = 0; + + worker_t(ucp_context_h ucp_handle, type_erased_address_db_t& db /*, mutex_t& mm*/, + ucs_thread_mode_t mode /*, const mpi::rank_topology& t*/) + //: m_rank_topology(t) + : m_db{db} + , m_rank{m_db.rank()} + , m_size{m_db.size()} + //, m_mutex_ptr{&mm} { - destroy(); - m_worker.~ucp_worker_h(); - ::new ((void*)(&m_worker)) ucp_worker_h{other.m_worker}; - m_moved = std::move(other.m_moved); - return *this; + ucp_worker_params_t params; + params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + params.thread_mode = mode; + OOMPH_CHECK_UCX_RESULT(ucp_worker_create(ucp_handle, ¶ms, &m_worker.get())); + ucp_address_t* worker_address; + std::size_t address_length; + OOMPH_CHECK_UCX_RESULT( + ucp_worker_get_address(m_worker.get(), &worker_address, &address_length)); + m_address = address_t{reinterpret_cast(worker_address), + reinterpret_cast(worker_address) + address_length}; + ucp_worker_release_address(m_worker.get(), worker_address); + m_worker.m_moved = false; } - ~ucp_worker_handle() { destroy(); } + worker_t(worker_t const&) = delete; + worker_t(worker_t&& other) noexcept = default; + worker_t& operator=(worker_t const&) = delete; + worker_t& operator=(worker_t&&) noexcept = delete; - static void empty_send_cb(void*, ucs_status_t) {} - - void destroy() noexcept + rank_type rank() const noexcept { return m_rank; } + rank_type size() const noexcept { return m_size; } + inline ucp_worker_h get() const noexcept { return m_worker.get(); } + address_t address() const noexcept { return m_address; } + inline endpoint_t const& connect(rank_type rank) { - if (!m_moved) ucp_worker_destroy(m_worker); + auto it = m_endpoint_cache.find(rank); + if (it != m_endpoint_cache.end()) return it->second; + auto addr = m_db.find(rank); + auto p = m_endpoint_cache.insert( + std::make_pair(rank, endpoint_t{rank, m_worker.get(), addr})); + return p.first->second; } - operator ucp_worker_h() const noexcept { return m_worker; } + //mutex_t& mutex() { return *m_mutex_ptr; } - ucp_worker_h& get() noexcept { return m_worker; } - const ucp_worker_h& get() const noexcept { return m_worker; } + //const mpi::rank_topology& rank_topology() const noexcept { return m_rank_topology; } }; - using ep_handle_vector = std::vector; - using cache_type = std::unordered_map; - //using mutex_t = pthread_spin::recursive_mutex; - - //const mpi::rank_topology& m_rank_topology; - type_erased_address_db_t& m_db; - rank_type m_rank; - rank_type m_size; - 
ucp_worker_handle m_worker; - address_t m_address; - ep_handle_vector m_endpoint_handles; - cache_type m_endpoint_cache; - //int m_progressed_sends = 0; - //mutex_t* m_mutex_ptr = nullptr; - //volatile int m_progressed_recvs = 0; - //volatile int m_progressed_cancels = 0; - - worker_t(ucp_context_h ucp_handle, type_erased_address_db_t& db /*, mutex_t& mm*/, - ucs_thread_mode_t mode /*, const mpi::rank_topology& t*/) - //: m_rank_topology(t) - : m_db{db} - , m_rank{m_db.rank()} - , m_size{m_db.size()} - //, m_mutex_ptr{&mm} - { - ucp_worker_params_t params; - params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - params.thread_mode = mode; - OOMPH_CHECK_UCX_RESULT(ucp_worker_create(ucp_handle, ¶ms, &m_worker.get())); - ucp_address_t* worker_address; - std::size_t address_length; - OOMPH_CHECK_UCX_RESULT( - ucp_worker_get_address(m_worker.get(), &worker_address, &address_length)); - m_address = address_t{reinterpret_cast(worker_address), - reinterpret_cast(worker_address) + address_length}; - ucp_worker_release_address(m_worker.get(), worker_address); - m_worker.m_moved = false; - } - - worker_t(const worker_t&) = delete; - worker_t(worker_t&& other) noexcept = default; - worker_t& operator=(const worker_t&) = delete; - worker_t& operator=(worker_t&&) noexcept = delete; - - rank_type rank() const noexcept { return m_rank; } - rank_type size() const noexcept { return m_size; } - inline ucp_worker_h get() const noexcept { return m_worker.get(); } - address_t address() const noexcept { return m_address; } - inline const endpoint_t& connect(rank_type rank) - { - auto it = m_endpoint_cache.find(rank); - if (it != m_endpoint_cache.end()) return it->second; - auto addr = m_db.find(rank); - auto p = - m_endpoint_cache.insert(std::make_pair(rank, endpoint_t{rank, m_worker.get(), addr})); - return p.first->second; - } - - //mutex_t& mutex() { return *m_mutex_ptr; } - - //const mpi::rank_topology& rank_topology() const noexcept { return m_rank_topology; } -}; - -} // namespace oomph +} // namespace oomph diff --git a/src/unique_ptr_set.hpp b/src/unique_ptr_set.hpp index a288f5d7..e0eb6e89 100644 --- a/src/unique_ptr_set.hpp +++ b/src/unique_ptr_set.hpp @@ -9,57 +9,56 @@ */ #pragma once +#include #include #include -#include -namespace oomph -{ -template> -class unique_ptr_set -{ - public: - using pointer = T*; +namespace oomph { + template > + class unique_ptr_set + { + public: + using pointer = T*; - private: - std::set m_ptrs; - std::unique_ptr m_mutex; - Deleter m_deleter; + private: + std::set m_ptrs; + std::unique_ptr m_mutex; + Deleter m_deleter; - public: - unique_ptr_set(Deleter d = Deleter{}) - : m_mutex{std::make_unique()} - , m_deleter{d} - { - } + public: + unique_ptr_set(Deleter d = Deleter{}) + : m_mutex{std::make_unique()} + , m_deleter{d} + { + } - unique_ptr_set(unique_ptr_set&&) = default; - unique_ptr_set& operator=(unique_ptr_set&&) = default; + unique_ptr_set(unique_ptr_set&&) = default; + unique_ptr_set& operator=(unique_ptr_set&&) = default; - ~unique_ptr_set() - { - if (m_mutex) - for (auto p : m_ptrs) destroy(p); - } + ~unique_ptr_set() + { + if (m_mutex) + for (auto p : m_ptrs) destroy(p); + } - public: - void insert(pointer p) - { - m_mutex->lock(); - m_ptrs.insert(p); - m_mutex->unlock(); - } + public: + void insert(pointer p) + { + m_mutex->lock(); + m_ptrs.insert(p); + m_mutex->unlock(); + } - void remove(pointer p) - { - m_mutex->lock(); - m_ptrs.erase(m_ptrs.find(p)); - destroy(p); - m_mutex->unlock(); - } + void remove(pointer p) + { + m_mutex->lock(); + 
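// insert() and remove() both take the mutex, so ownership can be registered and reclaimed
// from different threads. Typical use, assuming the elided Deleter template parameter
// defaults to a plain delete ("resource" is a placeholder type):
//     oomph::unique_ptr_set<resource> ptrs;
//     auto* p = new resource{};
//     ptrs.insert(p);     // the set now owns p
//     ptrs.remove(p);     // erases p from the set and invokes the deleter on it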
m_ptrs.erase(m_ptrs.find(p)); + destroy(p); + m_mutex->unlock(); + } - private: - void destroy(pointer p) { m_deleter(p); } -}; + private: + void destroy(pointer p) { m_deleter(p); } + }; -} // namespace oomph +} // namespace oomph diff --git a/src/util/heap_pimpl_src.hpp b/src/util/heap_pimpl_src.hpp index 89c6dd27..f577bea2 100644 --- a/src/util/heap_pimpl_src.hpp +++ b/src/util/heap_pimpl_src.hpp @@ -9,89 +9,78 @@ */ #pragma once -#include #include +#include -namespace oomph -{ -namespace util -{ - -template -heap_pimpl -make_heap_pimpl(Args&&... args) -{ - return {new T{std::forward(args)...}}; -} - -template -heap_pimpl::~heap_pimpl() = default; - -template -heap_pimpl::heap_pimpl() noexcept = default; - -template -heap_pimpl::heap_pimpl(T* ptr) noexcept -: m{ptr} -{ -} - -template -template -heap_pimpl::heap_pimpl(Args&&... args) -: m{new T{std::forward(args)...}} -{ -} - -template -heap_pimpl::heap_pimpl(heap_pimpl&&) noexcept = default; - -template -heap_pimpl& heap_pimpl::operator=(heap_pimpl&&) noexcept = default; - -template -T* -heap_pimpl::operator->() noexcept -{ - return m.get(); -} - -template -T const* -heap_pimpl::operator->() const noexcept -{ - return m.get(); -} - -template -T& -heap_pimpl::operator*() noexcept -{ - return *m.get(); -} - -template -T const& -heap_pimpl::operator*() const noexcept -{ - return *m.get(); -} - -template -T* -heap_pimpl::get() noexcept -{ - return m.get(); -} - -template -T const* -heap_pimpl::get() const noexcept -{ - return m.get(); -} - -} // namespace util -} // namespace oomph +namespace oomph { namespace util { + + template + heap_pimpl make_heap_pimpl(Args&&... args) + { + return {new T{std::forward(args)...}}; + } + + template + heap_pimpl::~heap_pimpl() = default; + + template + heap_pimpl::heap_pimpl() noexcept = default; + + template + heap_pimpl::heap_pimpl(T* ptr) noexcept + : m{ptr} + { + } + + template + template + heap_pimpl::heap_pimpl(Args&&... 
args) + : m{new T{std::forward(args)...}} + { + } + + template + heap_pimpl::heap_pimpl(heap_pimpl&&) noexcept = default; + + template + heap_pimpl& heap_pimpl::operator=(heap_pimpl&&) noexcept = default; + + template + T* heap_pimpl::operator->() noexcept + { + return m.get(); + } + + template + T const* heap_pimpl::operator->() const noexcept + { + return m.get(); + } + + template + T& heap_pimpl::operator*() noexcept + { + return *m.get(); + } + + template + T const& heap_pimpl::operator*() const noexcept + { + return *m.get(); + } + + template + T* heap_pimpl::get() noexcept + { + return m.get(); + } + + template + T const* heap_pimpl::get() const noexcept + { + return m.get(); + } + +}} // namespace oomph::util #define OOMPH_INSTANTIATE_HEAP_PIMPL(T) template class ::oomph::util::heap_pimpl; diff --git a/src/util/pimpl_src.hpp b/src/util/pimpl_src.hpp index 02834980..550aa221 100644 --- a/src/util/pimpl_src.hpp +++ b/src/util/pimpl_src.hpp @@ -11,9 +11,9 @@ #include #if OOMPH_USE_FAST_PIMPL -# include "./stack_pimpl_src.hpp" -//# define OOMPH_INSTANTIATE_PIMPL(T) OOMPH_INSTANTIATE_STACK_PIMPL(T) +# include "./stack_pimpl_src.hpp" +//# define OOMPH_INSTANTIATE_PIMPL(T) OOMPH_INSTANTIATE_STACK_PIMPL(T) #else -# include "./heap_pimpl_src.hpp" -//# define OOMPH_INSTANTIATE_PIMPL(T) OOMPH_INSTANTIATE_HEAP_PIMPL(T) +# include "./heap_pimpl_src.hpp" +//# define OOMPH_INSTANTIATE_PIMPL(T) OOMPH_INSTANTIATE_HEAP_PIMPL(T) #endif diff --git a/src/util/stack_pimpl_src.hpp b/src/util/stack_pimpl_src.hpp index 9c0ef52a..90feba23 100644 --- a/src/util/stack_pimpl_src.hpp +++ b/src/util/stack_pimpl_src.hpp @@ -9,83 +9,72 @@ */ #pragma once -#include #include +#include -namespace oomph -{ -namespace util -{ - -template -stack_pimpl -make_stack_pimpl(Args&&... args) -{ - return {T{std::forward(args)...}}; -} - -template -stack_pimpl::~stack_pimpl() = default; - -template -stack_pimpl::stack_pimpl() noexcept = default; - -template -template -stack_pimpl::stack_pimpl(Args&&... args) -: m{std::forward(args)...} -{ -} - -template -stack_pimpl::stack_pimpl(stack_pimpl&&) noexcept = default; - -template -stack_pimpl& stack_pimpl::operator=(stack_pimpl&&) noexcept = default; - -template -T* -stack_pimpl::operator->() noexcept -{ - return m.get(); -} - -template -T const* -stack_pimpl::operator->() const noexcept -{ - return m.get(); -} - -template -T& -stack_pimpl::operator*() noexcept -{ - return *m.get(); -} - -template -T const& -stack_pimpl::operator*() const noexcept -{ - return *m.get(); -} - -template -T* -stack_pimpl::get() noexcept -{ - return m.get(); -} - -template -T const* -stack_pimpl::get() const noexcept -{ - return m.get(); -} - -} // namespace util -} // namespace oomph - -#define OOMPH_INSTANTIATE_STACK_PIMPL(T, B, A) template class ::oomph::util::stack_pimpl; +namespace oomph { namespace util { + + template + stack_pimpl make_stack_pimpl(Args&&... args) + { + return {T{std::forward(args)...}}; + } + + template + stack_pimpl::~stack_pimpl() = default; + + template + stack_pimpl::stack_pimpl() noexcept = default; + + template + template + stack_pimpl::stack_pimpl(Args&&... 
args) + : m{std::forward(args)...} + { + } + + template + stack_pimpl::stack_pimpl(stack_pimpl&&) noexcept = default; + + template + stack_pimpl& stack_pimpl::operator=(stack_pimpl&&) noexcept = default; + + template + T* stack_pimpl::operator->() noexcept + { + return m.get(); + } + + template + T const* stack_pimpl::operator->() const noexcept + { + return m.get(); + } + + template + T& stack_pimpl::operator*() noexcept + { + return *m.get(); + } + + template + T const& stack_pimpl::operator*() const noexcept + { + return *m.get(); + } + + template + T* stack_pimpl::get() noexcept + { + return m.get(); + } + + template + T const* stack_pimpl::get() const noexcept + { + return m.get(); + } + +}} // namespace oomph::util + +#define OOMPH_INSTANTIATE_STACK_PIMPL(T, B, A) template class ::oomph::util::stack_pimpl; diff --git a/test/ctor_stats.hpp b/test/ctor_stats.hpp index 9763405e..4eed2257 100644 --- a/test/ctor_stats.hpp +++ b/test/ctor_stats.hpp @@ -9,9 +9,9 @@ */ #pragma once -#include #include #include +#include #include // helper classes and functions for keeping track of lifetime @@ -53,11 +53,11 @@ struct ctor_stats_data struct ctor_stats { - ctor_stats_data* data; + ctor_stats_data* data; oomph::util::moved_bit moved; ctor_stats(ctor_stats_data& d) - : data{&d} + : data{&d} { ++(data->n_ctor); ++(data->alloc_ref_count); @@ -78,10 +78,11 @@ struct ctor_stats ctor_stats& operator=(ctor_stats const&) = delete; ctor_stats(ctor_stats&& other) - : data{other.data} - , moved{std::move(other.moved)} + : data{other.data} + , moved{std::move(other.moved)} { - if (!moved) ++(data->n_move_ctor); + if (!moved) + ++(data->n_move_ctor); else ++(data->n_move_ctor_of_moved); } @@ -91,13 +92,15 @@ struct ctor_stats data = other.data; if (!moved) { - if (!other.moved) ++(data->n_move_assign); + if (!other.moved) + ++(data->n_move_assign); else ++(data->n_move_assign_of_moved); } else { - if (!other.moved) ++(data->n_move_assign_to_moved); + if (!other.moved) + ++(data->n_move_assign_to_moved); else ++(data->n_move_assign_of_moved_to_moved); } @@ -115,7 +118,7 @@ struct function_registry { std::map m_data; - template + template F make(std::string const& id) { return F(m_data[id]); diff --git a/test/mpi_runner/gtest_main_mpi.cpp b/test/mpi_runner/gtest_main_mpi.cpp index d9ecf4ee..a62623b6 100644 --- a/test/mpi_runner/gtest_main_mpi.cpp +++ b/test/mpi_runner/gtest_main_mpi.cpp @@ -11,8 +11,7 @@ #include #include "./mpi_listener.hpp" -GTEST_API_ int -main(int argc, char** argv) +GTEST_API_ int main(int argc, char** argv) { int required = MPI_THREAD_MULTIPLE; int provided; diff --git a/test/mpi_runner/mpi_listener.hpp b/test/mpi_runner/mpi_listener.hpp index 3950577c..41627bfc 100644 --- a/test/mpi_runner/mpi_listener.hpp +++ b/test/mpi_runner/mpi_listener.hpp @@ -32,42 +32,42 @@ class mpi_listener : public testing::EmptyTestEventListener { - private: +private: using UnitTest = testing::UnitTest; using TestCase = testing::TestCase; using TestInfo = testing::TestInfo; using TestPartResult = testing::TestPartResult; - int rank_ = 0; - int size_ = 0; + int rank_ = 0; + int size_ = 0; std::ofstream fid_; - char buffer_[1024]; - int test_case_failures_ = 0; - int test_case_tests_ = 0; - int test_failures_ = 0; + char buffer_[1024]; + int test_case_failures_ = 0; + int test_case_tests_ = 0; + int test_failures_ = 0; bool does_print() const { return rank_ == 0; } - void print(const char* s) + void print(char const* s) { fid_ << s; if (does_print()) { std::cout << s; } } - void print(const std::string& s) { 
print(s.c_str()); } + void print(std::string const& s) { print(s.c_str()); } /// convenience function that handles the logic of using snprintf /// and forwarding the results to file and/or stdout. /// /// TODO : it might be an idea to use a resizeable buffer - template - void printf_helper(const char* s, Args&&... args) + template + void printf_helper(char const* s, Args&&... args) { snprintf(buffer_, sizeof(buffer_), s, std::forward(args)...); print(buffer_); } - public: +public: mpi_listener(std::string f_base = "") { MPI_Comm_rank(MPI_COMM_WORLD, &rank_); @@ -79,17 +79,17 @@ class mpi_listener : public testing::EmptyTestEventListener if (!fid_) { throw std::runtime_error("PID:" + std::to_string(rank_) + " could not open file " + - fname + " for test output"); + fname + " for test output"); } } /// Messages that are printed at the start and end of the test program. /// i.e. once only. - virtual void OnTestProgramStart(const UnitTest&) override + virtual void OnTestProgramStart(UnitTest const&) override { printf_helper("*** test output for rank %d of %d\n\n", rank_, size_); } - virtual void OnTestProgramEnd(const UnitTest&) override + virtual void OnTestProgramEnd(UnitTest const&) override { printf_helper("*** end test output for rank %d of %d\n", rank_, size_); } @@ -98,12 +98,12 @@ class mpi_listener : public testing::EmptyTestEventListener /// On startup a counter that counts the number of tests that fail in /// this test case is initialized to zero, and will be incremented for each /// test that fails. - virtual void OnTestCaseStart(const TestCase&) override + virtual void OnTestCaseStart(TestCase const&) override { test_case_failures_ = 0; test_case_tests_ = 0; } - virtual void OnTestCaseEnd(const TestCase& test_case) override + virtual void OnTestCaseEnd(TestCase const& test_case) override { if (test_case_failures_) { @@ -119,21 +119,21 @@ class mpi_listener : public testing::EmptyTestEventListener } // Called before a test starts. - virtual void OnTestStart(const TestInfo& test_info) override + virtual void OnTestStart(TestInfo const& test_info) override { printf_helper(" TEST %s::%s\n", test_info.test_case_name(), test_info.name()); test_failures_ = 0; } // Called after a failed assertion or a SUCCEED() invocation. - virtual void OnTestPartResult(const TestPartResult& test_part_result) override + virtual void OnTestPartResult(TestPartResult const& test_part_result) override { - const char* banner = + char const* banner = "--------------------------------------------------------------------------------"; // indent all lines in the summary by 4 spaces std::string summary = " " + std::string(test_part_result.summary()); - auto pos = summary.find("\n"); + auto pos = summary.find("\n"); while (pos != summary.size() && pos != std::string::npos) { summary.replace(pos, 1, "\n "); @@ -149,7 +149,7 @@ class mpi_listener : public testing::EmptyTestEventListener } // Called after a test ends. 
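// This listener is meant to be installed from the MPI-aware test main (after MPI has been
// initialised), using the standard GoogleTest listener API; the report file base name is an
// arbitrary example, and removing the default printer (to avoid duplicate per-rank output)
// is an assumption about the intended setup:
//     ::testing::InitGoogleTest(&argc, argv);
//     auto& listeners = ::testing::UnitTest::GetInstance()->listeners();
//     delete listeners.Release(listeners.default_result_printer());
//     listeners.Append(new mpi_listener("oomph_test"));
//     int result = RUN_ALL_TESTS();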
- virtual void OnTestEnd(const TestInfo&) override + virtual void OnTestEnd(TestInfo const&) override { test_case_tests_++; diff --git a/test/mpi_runner/mpi_test_fixture.hpp b/test/mpi_runner/mpi_test_fixture.hpp index ce4769d5..b30a7a92 100644 --- a/test/mpi_runner/mpi_test_fixture.hpp +++ b/test/mpi_runner/mpi_test_fixture.hpp @@ -25,7 +25,7 @@ struct mpi_test_fixture : public ::testing::Test //void TearDown() {} - protected: +protected: int world_rank; int world_size; }; diff --git a/test/reporting_allocator.hpp b/test/reporting_allocator.hpp index f75ea593..4f6ab821 100644 --- a/test/reporting_allocator.hpp +++ b/test/reporting_allocator.hpp @@ -10,19 +10,19 @@ #pragma once #include -#include -#include #include +#include +#include -template +template struct reporting_allocator { using value_type = T; reporting_allocator() noexcept {} - template - constexpr reporting_allocator(const reporting_allocator&) noexcept + template + constexpr reporting_allocator(reporting_allocator const&) noexcept { } @@ -46,7 +46,7 @@ struct reporting_allocator std::free(p); } - private: +private: void report(T* p, std::size_t n, bool alloc = true) const { std::cout << (alloc ? "Alloc: " : "Dealloc: ") << sizeof(T) * n << " bytes at " << std::hex @@ -54,15 +54,13 @@ struct reporting_allocator } }; -template -bool -operator==(const reporting_allocator&, const reporting_allocator&) +template +bool operator==(reporting_allocator const&, reporting_allocator const&) { return true; } -template -bool -operator!=(const reporting_allocator&, const reporting_allocator&) +template +bool operator!=(reporting_allocator const&, reporting_allocator const&) { return false; } diff --git a/test/test_barrier.cpp b/test/test_barrier.cpp index 3016c091..c1379de8 100644 --- a/test/test_barrier.cpp +++ b/test/test_barrier.cpp @@ -7,14 +7,14 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -#include -#include #include -#include "./mpi_runner/mpi_test_fixture.hpp" -#include #include -#include +#include #include +#include +#include +#include +#include "./mpi_runner/mpi_test_fixture.hpp" TEST_F(mpi_test_fixture, rank_barrier) { @@ -27,37 +27,35 @@ TEST_F(mpi_test_fixture, rank_barrier) for (int i = 0; i < 20; i++) { b.rank_barrier(); } } -namespace oomph -{ -class test_barrier -{ - public: - barrier& br; - - void test_in_node1(context& ctxt) +namespace oomph { + class test_barrier { - std::vector innode1_out(br.size()); + public: + barrier& br; - auto work = [&](int id) + void test_in_node1(context& ctxt) { - auto comm = ctxt.get_communicator(); - auto comm2 = ctxt.get_communicator(); - innode1_out[id] = br.in_node1() ? 1 : 0; - }; - std::vector ths; - for (int i = 0; i < br.size(); ++i) { ths.push_back(std::thread{work, i}); } - for (int i = 0; i < br.size(); ++i) { ths[i].join(); } - EXPECT_EQ(std::accumulate(innode1_out.begin(), innode1_out.end(), 0), 1); - } -}; -} // namespace oomph + std::vector innode1_out(br.size()); + + auto work = [&](int id) { + auto comm = ctxt.get_communicator(); + auto comm2 = ctxt.get_communicator(); + innode1_out[id] = br.in_node1() ? 
1 : 0; + }; + std::vector ths; + for (int i = 0; i < br.size(); ++i) { ths.push_back(std::thread{work, i}); } + for (int i = 0; i < br.size(); ++i) { ths[i].join(); } + EXPECT_EQ(std::accumulate(innode1_out.begin(), innode1_out.end(), 0), 1); + } + }; +} // namespace oomph TEST_F(mpi_test_fixture, in_node1) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + auto ctxt = context(MPI_COMM_WORLD, true); std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + barrier b(ctxt, n_threads); oomph::test_barrier{b}.test_in_node1(ctxt); } @@ -65,9 +63,9 @@ TEST_F(mpi_test_fixture, in_node1) TEST_F(mpi_test_fixture, in_barrier_1) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + auto ctxt = context(MPI_COMM_WORLD, true); std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + barrier b(ctxt, n_threads); auto comm = ctxt.get_communicator(); auto comm2 = ctxt.get_communicator(); @@ -81,10 +79,9 @@ TEST_F(mpi_test_fixture, in_barrier) auto ctxt = context(MPI_COMM_WORLD, true); std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + barrier b(ctxt, n_threads); - auto work = [&]() - { + auto work = [&]() { auto comm = ctxt.get_communicator(); auto comm2 = ctxt.get_communicator(); for (int i = 0; i < 10; i++) @@ -105,10 +102,9 @@ TEST_F(mpi_test_fixture, full_barrier) auto ctxt = context(MPI_COMM_WORLD, true); std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + barrier b(ctxt, n_threads); - auto work = [&]() - { + auto work = [&]() { auto comm = ctxt.get_communicator(); auto comm3 = ctxt.get_communicator(); for (int i = 0; i < 10; i++) { b(); } @@ -125,16 +121,15 @@ TEST_F(mpi_test_fixture, full_barrier_sendrecv) auto ctxt = context(MPI_COMM_WORLD, true); std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + barrier b(ctxt, n_threads); - auto work = [&](int tid) - { + auto work = [&](int tid) { auto comm = ctxt.get_communicator(); auto comm2 = ctxt.get_communicator(); - int s_rank = (tid < 3) ? comm.rank() : ((comm.rank() + 1) % comm.size()); - int s_tag = comm.rank() * 10 + tid; - int r_rank = (tid > 0) ? comm.rank() : ((comm.rank() + comm.size() - 1) % comm.size()); - int r_tag = (tid > 0) ? (comm.rank() * 10 + tid - 1) : (r_rank * 10 + n_threads - 1); + int s_rank = (tid < 3) ? comm.rank() : ((comm.rank() + 1) % comm.size()); + int s_tag = comm.rank() * 10 + tid; + int r_rank = (tid > 0) ? comm.rank() : ((comm.rank() + comm.size() - 1) % comm.size()); + int r_tag = (tid > 0) ? (comm.rank() * 10 + tid - 1) : (r_rank * 10 + n_threads - 1); auto s_buffer = comm.make_buffer(1000); auto r_buffer = comm.make_buffer(1000); diff --git a/test/test_cancel.cpp b/test/test_cancel.cpp index f00ed737..f2ace3e8 100644 --- a/test/test_cancel.cpp +++ b/test/test_cancel.cpp @@ -7,30 +7,28 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -#include #include -#include "./mpi_runner/mpi_test_fixture.hpp" -#include #include -#include +#include +#include #include +#include +#include "./mpi_runner/mpi_test_fixture.hpp" -void -test_1(oomph::communicator& comm, unsigned int size, int thread_id = 0) +void test_1(oomph::communicator& comm, unsigned int size, int thread_id = 0) { EXPECT_TRUE(comm.size() > 0); auto msg = comm.make_buffer(size); if (comm.rank() == 0) { - std::vector dsts(comm.size()>1 ? comm.size()-1 : 1, 0); + std::vector dsts(comm.size() > 1 ? 
comm.size() - 1 : 1, 0); for (unsigned int i = 0; i < size; ++i) msg[i] = i; - for (int d = 1; d threads; @@ -89,8 +87,7 @@ TEST_F(mpi_test_fixture, test_cancel_request_mt) for (auto& t : threads) t.join(); } -void -test_2(oomph::communicator& comm, unsigned int size, int thread_id = 0) +void test_2(oomph::communicator& comm, unsigned int size, int thread_id = 0) { EXPECT_TRUE(comm.size() > 0); auto msg = comm.make_buffer(size); @@ -98,9 +95,9 @@ test_2(oomph::communicator& comm, unsigned int size, int thread_id = 0) if (comm.rank() == 0) { - std::vector dsts(comm.size()>1 ? comm.size()-1 : 1, 0); + std::vector dsts(comm.size() > 1 ? comm.size() - 1 : 1, 0); for (unsigned int i = 0; i < size; ++i) msg[i] = i; - for (int d = 1; d threads; diff --git a/test/test_context.cpp b/test/test_context.cpp index 930c248a..3dd501e6 100644 --- a/test/test_context.cpp +++ b/test/test_context.cpp @@ -7,15 +7,15 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -#include #include -#include "./mpi_runner/mpi_test_fixture.hpp" -#include #include +#include +#include #include +#include "./mpi_runner/mpi_test_fixture.hpp" -const std::size_t size = 1024; -const int num_threads = 4; +std::size_t const size = 1024; +int const num_threads = 4; TEST_F(mpi_test_fixture, context_ordered) { diff --git a/test/test_locality.cpp b/test/test_locality.cpp index 80e5e1ab..bf55403b 100644 --- a/test/test_locality.cpp +++ b/test/test_locality.cpp @@ -7,19 +7,18 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -#include +#include #include -#include "./mpi_runner/mpi_test_fixture.hpp" -#include #include -#include -#include #include -#include - +#include +#include +#include +#include +#include "./mpi_runner/mpi_test_fixture.hpp" #ifdef __APPLE__ -#define HOST_NAME_MAX _POSIX_HOST_NAME_MAX +# define HOST_NAME_MAX _POSIX_HOST_NAME_MAX #endif // test locality by collecting all local ranks @@ -45,7 +44,9 @@ TEST_F(mpi_test_fixture, locality_enumerate) if (r == comm.rank()) { for (int rr = 0; rr < comm.size(); ++rr) - { local_ranks[rr] = comm.is_local(rr) ? 1 : 0; } + { + local_ranks[rr] = comm.is_local(rr) ? 1 : 0; + } for (int rr = 0; rr < comm.size(); ++rr) { if (rr != comm.rank()) @@ -57,14 +58,16 @@ TEST_F(mpi_test_fixture, locality_enumerate) } else { - const int is_neighbor = comm.is_local(r) ? 1 : 0; + int const is_neighbor = comm.is_local(r) ? 1 : 0; comm.recv(local_ranks, r, 0).wait(); comm.recv(other_host_name, r, 1).wait(); EXPECT_EQ(is_neighbor, local_ranks[comm.rank()]); if (is_neighbor) for (int rr = 0; rr < comm.size(); ++rr) - { EXPECT_EQ((comm.is_local(rr) ? 1 : 0), local_ranks[rr]); } - const int equal_hosts = + { + EXPECT_EQ((comm.is_local(rr) ? 1 : 0), local_ranks[rr]); + } + int const equal_hosts = (std::strcmp(my_host_name.data(), other_host_name.data()) == 0) ? 1 : 0; if (is_neighbor == 1) { EXPECT_EQ(equal_hosts, 1); } } diff --git a/test/test_send_multi.cpp b/test/test_send_multi.cpp index 7d219182..a6190d34 100644 --- a/test/test_send_multi.cpp +++ b/test/test_send_multi.cpp @@ -7,36 +7,33 @@ * Please, refer to the LICENSE file in the root directory. 
* SPDX-License-Identifier: BSD-3-Clause */ -#include -#include -#include "./mpi_runner/mpi_test_fixture.hpp" -#include #include +#include #include +#include +#include #include +#include "./mpi_runner/mpi_test_fixture.hpp" -const int SIZE = 1000000; +int const SIZE = 1000000; -template -void -reset_msg(M& msg) +template +void reset_msg(M& msg) { for (std::size_t i = 0; i < msg.size(); ++i) msg[i] = -1; } -template -void -init_msg(M& msg) +template +void init_msg(M& msg) { for (std::size_t i = 0; i < msg.size(); ++i) msg[i] = i; } -template -bool -check_msg(M const& msg) +template +bool check_msg(M const& msg) { bool ok = true; - for (std::size_t i = 0; i < msg.size(); ++i) ok = ok && (msg[i] == (int)i); + for (std::size_t i = 0; i < msg.size(); ++i) ok = ok && (msg[i] == (int) i); return ok; } diff --git a/test/test_send_recv.cpp b/test/test_send_recv.cpp index 0cfd1170..08afefe8 100644 --- a/test/test_send_recv.cpp +++ b/test/test_send_recv.cpp @@ -7,16 +7,16 @@ * Please, refer to the LICENSE file in the root directory. * SPDX-License-Identifier: BSD-3-Clause */ -#include +#include #include -#include "./mpi_runner/mpi_test_fixture.hpp" -#include #include +#include +#include #include -#include +#include "./mpi_runner/mpi_test_fixture.hpp" -#define NITERS 50 -#define SIZE 64 +#define NITERS 50 +#define SIZE 64 #define NTHREADS 4 std::vector> shared_received(NTHREADS); @@ -33,22 +33,22 @@ struct test_environment_base using tag_type = oomph::tag_type; using message = oomph::message_buffer; - oomph::context& ctxt; + oomph::context& ctxt; oomph::communicator comm; - rank_type speer_rank; - rank_type rpeer_rank; - int thread_id; - int num_threads; - tag_type tag; + rank_type speer_rank; + rank_type rpeer_rank; + int thread_id; + int num_threads; + tag_type tag; test_environment_base(oomph::context& c, int tid, int num_t) - : ctxt(c) - , comm(ctxt.get_communicator()) - , speer_rank((comm.rank() + 1) % comm.size()) - , rpeer_rank((comm.rank() + comm.size() - 1) % comm.size()) - , thread_id(tid) - , num_threads(num_t) - , tag(tid) + : ctxt(c) + , comm(ctxt.get_communicator()) + , speer_rank((comm.rank() + 1) % comm.size()) + , rpeer_rank((comm.rank() + comm.size() - 1) % comm.size()) + , thread_id(tid) + , num_threads(num_t) + , tag(tid) { } }; @@ -57,25 +57,26 @@ struct test_environment : public test_environment_base { using base = test_environment_base; - static auto make_buffer(oomph::communicator& comm, std::size_t size, bool user_alloc, - rank_type* ptr) + static auto make_buffer( + oomph::communicator& comm, std::size_t size, bool user_alloc, rank_type* ptr) { - if (user_alloc) return comm.make_buffer(ptr, size); + if (user_alloc) + return comm.make_buffer(ptr, size); else return comm.make_buffer(size); } std::vector raw_smsg; std::vector raw_rmsg; - message smsg; - message rmsg; + message smsg; + message rmsg; test_environment(oomph::context& c, std::size_t size, int tid, int num_t, bool user_alloc) - : base(c, tid, num_t) - , raw_smsg(user_alloc ? size : 0) - , raw_rmsg(user_alloc ? size : 0) - , smsg(make_buffer(comm, size, user_alloc, raw_smsg.data())) - , rmsg(make_buffer(comm, size, user_alloc, raw_rmsg.data())) + : base(c, tid, num_t) + , raw_smsg(user_alloc ? size : 0) + , raw_rmsg(user_alloc ? 
size : 0)
+      , smsg(make_buffer(comm, size, user_alloc, raw_smsg.data()))
+      , rmsg(make_buffer(comm, size, user_alloc, raw_rmsg.data()))
     {
         fill_send_buffer();
         fill_recv_buffer();
@@ -104,10 +105,11 @@ struct test_environment_device : public test_environment_base
 {
     using base = test_environment_base;
 
-    static auto make_buffer(oomph::communicator& comm, std::size_t size, bool user_alloc,
-        rank_type* device_ptr)
+    static auto make_buffer(
+        oomph::communicator& comm, std::size_t size, bool user_alloc, rank_type* device_ptr)
     {
-        if (user_alloc) return comm.make_device_buffer(device_ptr, size, 0);
+        if (user_alloc)
+            return comm.make_device_buffer(device_ptr, size, 0);
         else return comm.make_device_buffer(size, 0);
     }
 
@@ -120,37 +122,37 @@ struct test_environment_device : public test_environment_base
             if (size) m_ptr = hwmalloc::device_malloc(size * sizeof(rank_type));
         }
         device_allocation(device_allocation&& other)
-        : m_ptr{std::exchange(other.m_ptr, nullptr)}
+          : m_ptr{std::exchange(other.m_ptr, nullptr)}
         {
         }
         ~device_allocation()
         {
-#ifndef OOMPH_TEST_LEAK_GPU_MEMORY
+# ifndef OOMPH_TEST_LEAK_GPU_MEMORY
             if (m_ptr) hwmalloc::device_free(m_ptr);
-#endif
+# endif
         }
-        rank_type* get() const noexcept { return (rank_type*)m_ptr; }
+        rank_type* get() const noexcept { return (rank_type*) m_ptr; }
     };
 
     device_allocation raw_device_smsg;
     device_allocation raw_device_rmsg;
-    message           smsg;
-    message           rmsg;
-
-    test_environment_device(oomph::context& c, std::size_t size, int tid, int num_t,
-        bool user_alloc)
-    : base(c, tid, num_t)
-#ifndef OOMPH_TEST_LEAK_GPU_MEMORY
-    , raw_device_smsg(user_alloc ? size : 0)
-    , raw_device_rmsg(user_alloc ? size : 0)
-    , smsg(make_buffer(comm, size, user_alloc, raw_device_smsg.get()))
-    , rmsg(make_buffer(comm, size, user_alloc, raw_device_rmsg.get()))
-#else
-    , raw_device_smsg(size)
-    , raw_device_rmsg(size)
-    , smsg(make_buffer(comm, size, user_alloc, raw_device_smsg.get()))
-    , rmsg(make_buffer(comm, size, user_alloc, raw_device_rmsg.get()))
-#endif
+    message smsg;
+    message rmsg;
+
+    test_environment_device(
+        oomph::context& c, std::size_t size, int tid, int num_t, bool user_alloc)
+      : base(c, tid, num_t)
+# ifndef OOMPH_TEST_LEAK_GPU_MEMORY
+      , raw_device_smsg(user_alloc ? size : 0)
+      , raw_device_rmsg(user_alloc ? size : 0)
+      , smsg(make_buffer(comm, size, user_alloc, raw_device_smsg.get()))
+      , rmsg(make_buffer(comm, size, user_alloc, raw_device_rmsg.get()))
+# else
+      , raw_device_smsg(size)
+      , raw_device_rmsg(size)
+      , smsg(make_buffer(comm, size, user_alloc, raw_device_smsg.get()))
+      , rmsg(make_buffer(comm, size, user_alloc, raw_device_rmsg.get()))
+# endif
     {
         fill_send_buffer();
         fill_recv_buffer();
@@ -178,9 +180,8 @@ struct test_environment_device : public test_environment_base
 };
 #endif
 
-template
-void
-launch_test(Func f)
+template
+void launch_test(Func f)
 {
     // single threaded
     {
@@ -193,7 +194,7 @@ launch_test(Func f)
 
     // multi threaded
     {
-        oomph::context ctxt(MPI_COMM_WORLD, true);
+        oomph::context ctxt(MPI_COMM_WORLD, true);
         std::vector threads;
         threads.reserve(NTHREADS);
         reset_counters();
@@ -210,9 +211,9 @@ launch_test(Func f)
 // no callback
 // ===========
 
-template
-void
-test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, bool user_alloc)
+template
+void test_send_recv(
+    oomph::context& ctxt, std::size_t size, int tid, int num_threads, bool user_alloc)
 {
     Env env(ctxt, size, tid, num_threads, user_alloc);
 
@@ -221,10 +222,7 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads,
     {
         auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag);
        auto sreq = env.comm.send(env.smsg, env.speer_rank, env.tag);
-        while (!(rreq.is_ready() && sreq.is_ready()))
-        {
-            env.comm.progress();
-        };
+        while (!(rreq.is_ready() && sreq.is_ready())) { env.comm.progress(); };
         EXPECT_TRUE(env.check_recv_buffer());
         env.fill_recv_buffer();
     }
@@ -260,9 +258,9 @@ TEST_F(mpi_test_fixture, send_recv)
 // callback: pass by l-value reference
 // ===================================
 
-template
-void
-test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threads, bool user_alloc)
+template
+void test_send_recv_cb(
+    oomph::context& ctxt, std::size_t size, int tid, int num_threads, bool user_alloc)
 {
     using rank_type = test_environment::rank_type;
     using tag_type = test_environment::tag_type;
@@ -270,8 +268,8 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa
     Env env(ctxt, size, tid, num_threads, user_alloc);
 
-    volatile int received = 0;
-    volatile int sent = 0;
+    int volatile received = 0;
+    int volatile sent = 0;
 
     auto send_callback = [&](message const&, rank_type, tag_type) { ++sent; };
     auto recv_callback = [&](message&, rank_type, tag_type) { ++received; };
@@ -327,10 +325,9 @@ TEST_F(mpi_test_fixture, send_recv_cb)
 // callback: pass by r-value reference (give up ownership)
 // =======================================================
 
-template
-void
-test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads,
-    bool user_alloc)
+template
+void test_send_recv_cb_disown(
+    oomph::context& ctxt, std::size_t size, int tid, int num_threads, bool user_alloc)
 {
     using rank_type = test_environment::rank_type;
     using tag_type = test_environment::tag_type;
@@ -338,16 +335,14 @@ test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int nu
     Env env(ctxt, size, tid, num_threads, user_alloc);
 
-    volatile int received = 0;
-    volatile int sent = 0;
+    int volatile received = 0;
+    int volatile sent = 0;
 
-    auto send_callback = [&](message msg, rank_type, tag_type)
-    {
+    auto send_callback = [&](message msg, rank_type, tag_type) {
         ++sent;
         env.smsg = std::move(msg);
     };
-    auto recv_callback = [&](message msg, rank_type, tag_type)
-    {
+    auto recv_callback = [&](message msg, rank_type, tag_type) {
         ++received;
         env.rmsg = std::move(msg);
     };
@@ -403,10 +398,9 @@ TEST_F(mpi_test_fixture, send_recv_cb_disown)
 // callback: pass by r-value reference (give up ownership), shared recv
 // ====================================================================
 
-template
-void
-test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads,
-    bool user_alloc)
+template
+void test_send_shared_recv_cb_disown(
+    oomph::context& ctxt, std::size_t size, int tid, int num_threads, bool user_alloc)
 {
     using rank_type = test_environment::rank_type;
     using tag_type = test_environment::tag_type;
@@ -417,15 +411,13 @@ test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid,
     thread_id = env.thread_id;
 
     //volatile int received = 0;
-    volatile int sent = 0;
+    int volatile sent = 0;
 
-    auto send_callback = [&](message msg, rank_type, tag_type)
-    {
+    auto send_callback = [&](message msg, rank_type, tag_type) {
         ++sent;
         env.smsg = std::move(msg);
     };
-    auto recv_callback = [&](message msg, rank_type, tag_type)
-    {
+    auto recv_callback = [&](message msg, rank_type, tag_type) {
         //std::cout << thread_id << " " << env.thread_id << std::endl;
         //if (thread_id != env.thread_id) std::cout << "other thread picked up callback" << std::endl;
         //else std::cout << "my thread picked up callback" << std::endl;
@@ -485,10 +477,9 @@ TEST_F(mpi_test_fixture, send_shared_recv_cb_disown)
 // callback: pass by l-value reference, and resubmit
 // =================================================
 
-template
-void
-test_send_recv_cb_resubmit(oomph::context& ctxt, std::size_t size, int tid, int num_threads,
-    bool user_alloc)
+template
+void test_send_recv_cb_resubmit(
+    oomph::context& ctxt, std::size_t size, int tid, int num_threads, bool user_alloc)
 {
     using rank_type = test_environment::rank_type;
     using tag_type = test_environment::tag_type;
@@ -496,13 +487,13 @@ test_send_recv_cb_resubmit(oomph::context& ctxt, std::size_t size, int tid, int
     Env env(ctxt, size, tid, num_threads, user_alloc);
 
-    volatile int received = 0;
-    volatile int sent = 0;
+    int volatile received = 0;
+    int volatile sent = 0;
 
     struct recursive_send_callback
     {
-        Env&          env;
-        volatile int& sent;
+        Env& env;
+        int volatile& sent;
 
         void operator()(message& msg, rank_type dst, tag_type tag)
         {
@@ -513,8 +504,8 @@ test_send_recv_cb_resubmit(oomph::context& ctxt, std::size_t size, int tid, int
     struct recursive_recv_callback
     {
-        Env&          env;
-        volatile int& received;
+        Env& env;
+        int volatile& received;
 
         void operator()(message& msg, rank_type src, tag_type tag)
         {
@@ -541,10 +532,9 @@ TEST_F(mpi_test_fixture, send_recv_cb_resubmit)
 // callback: pass by r-value reference (give up ownership), and resubmit
 // =====================================================================
 
-template
-void
-test_send_recv_cb_resubmit_disown(oomph::context& ctxt, std::size_t size, int tid, int num_threads,
-    bool user_alloc)
+template
+void test_send_recv_cb_resubmit_disown(
+    oomph::context& ctxt, std::size_t size, int tid, int num_threads, bool user_alloc)
 {
     using rank_type = test_environment::rank_type;
     using tag_type = test_environment::tag_type;
@@ -552,13 +542,13 @@ test_send_recv_cb_resubmit_disown(oomph::context& ctxt, std::size_t size, int ti
     Env env(ctxt, size, tid, num_threads, user_alloc);
 
-    volatile int received = 0;
-    volatile int sent = 0;
+    int volatile received = 0;
+    int volatile sent = 0;
 
     struct recursive_send_callback
     {
-        Env&          env;
-        volatile int& sent;
+        Env& env;
+        int volatile& sent;
 
         void operator()(message msg, rank_type dst, tag_type tag)
         {
@@ -570,8 +560,8 @@ test_send_recv_cb_resubmit_disown(oomph::context& ctxt, std::size_t size, int ti
     struct recursive_recv_callback
     {
-        Env&          env;
-        volatile int& received;
+        Env& env;
+        int volatile& received;
 
         void operator()(message msg, rank_type src, tag_type tag)
         {
diff --git a/test/test_unique_function.cpp b/test/test_unique_function.cpp
index 8e18b86e..d404f2fd 100644
--- a/test/test_unique_function.cpp
+++ b/test/test_unique_function.cpp
@@ -20,7 +20,7 @@ struct simple_function
     int i = 0;
 
     simple_function(int i_ = 0)
-    : i{i_}
+      : i{i_}
     {
     }
 
@@ -39,9 +39,9 @@ struct simple_function
 
 TEST(unqiue_function, simple_function)
 {
-    simple_function f1(1);
-    simple_function f2(2);
-    simple_function f3(0);
+    simple_function f1(1);
+    simple_function f2(2);
+    simple_function f3(0);
     oomph::util::unique_function uf1{std::move(f1)};
     EXPECT_EQ(1, uf1(3));
     EXPECT_EQ(3, uf1(4));
@@ -53,18 +53,16 @@ TEST(unqiue_function, simple_function)
     EXPECT_EQ(3, uf1(4));
 }
 
-
 void test_stats(ctor_stats_data const& stats, int n_ctor, int n_dtor, int n_dtor_of_moved,
     int n_move_ctor, int n_calls);
 
-
 // small function which fits within the stack buffer
 struct small_function
 {
     ctor_stats stats;
 
     small_function(ctor_stats_data& d)
-    : stats{d}
+      : stats{d}
     {
     }
 
@@ -87,7 +85,7 @@ struct large_function : public small_function
     std::array buffer;
 
     large_function(ctor_stats_data& d)
-    : small_function(d)
+      : small_function(d)
     {
     }
 };
@@ -95,9 +93,8 @@ struct large_function : public small_function
 // test (move) constructor
 // =======================
 
-template
-void
-test_ctor(function_registry& registry)
+template
+void test_ctor(function_registry& registry)
 {
     using namespace oomph::util;
 
@@ -107,14 +104,14 @@ test_ctor(function_registry& registry)
         uf();
     }
     {
-        auto f = registry.template make("b_F1_0");
+        auto f = registry.template make("b_F1_0");
         unique_function uf(std::move(f));
         uf();
         uf();
     }
     {
-        auto f1 = registry.template make("c_F1_0");
-        auto f2 = registry.template make("c_F1_1");
+        auto f1 = registry.template make("c_F1_0");
+        auto f2 = registry.template make("c_F1_1");
         unique_function uf;
         uf = std::move(f1);
         uf();
@@ -122,8 +119,8 @@ test_ctor(function_registry& registry)
     {
-        auto f1 = registry.template make("d_F1_0");
-        auto f2 = registry.template make("d_F2_0");
+        auto f1 = registry.template make("d_F1_0");
+        auto f2 = registry.template make("d_F2_0");
         unique_function uf;
         uf = std::move(f1);
         uf();
@@ -169,9 +166,8 @@ TEST(unqiue_function, ctor_large)
 // test move assign
 // ================
 
-template
-void
-test_move(function_registry& registry)
+template
+void test_move(function_registry& registry)
 {
     using namespace oomph::util;
 
@@ -248,8 +244,7 @@ TEST(unqiue_function, move_large)
 }
 
 // implementation of check function
-void
-test_stats(ctor_stats_data const& stats, int n_ctor, int n_dtor, int n_dtor_of_moved,
+void test_stats(ctor_stats_data const& stats, int n_ctor, int n_dtor, int n_dtor_of_moved,
     int n_move_ctor, int n_calls)
 {
     EXPECT_EQ(stats.n_ctor, n_ctor);
diff --git a/test/test_unsafe_shared_ptr.cpp b/test/test_unsafe_shared_ptr.cpp
index a592f6b1..77ae34ea 100644
--- a/test/test_unsafe_shared_ptr.cpp
+++ b/test/test_unsafe_shared_ptr.cpp
@@ -15,11 +15,11 @@ struct my_int
 {
     ctor_stats m_stats;
-    int        m_i;
+    int m_i;
 
     my_int(ctor_stats_data& d, int i)
-    : m_stats{d}
-    , m_i(i)
+      : m_stats{d}
+      , m_i(i)
     {
     }