From de4bc05be5fe736d28f462366c027b5a8d44c46b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20B=C3=B6sch?= <48126478+boeschf@users.noreply.github.com> Date: Thu, 7 Oct 2021 22:44:03 +0200 Subject: [PATCH 1/6] syntax --- include/oomph/context.hpp | 4 ++++ src/CMakeLists.txt | 1 + src/mpi/context.hpp | 5 +++++ src/src.cpp | 6 ++++++ test/CMakeLists.txt | 3 ++- 5 files changed, 18 insertions(+), 1 deletion(-) diff --git a/include/oomph/context.hpp b/include/oomph/context.hpp index 23092972..3ea29fc4 100644 --- a/include/oomph/context.hpp +++ b/include/oomph/context.hpp @@ -15,6 +15,7 @@ #include #include #include +#include namespace oomph { @@ -60,6 +61,9 @@ class context communicator get_communicator(); + template + void map_tensor(void* ptr); + private: detail::message_buffer make_buffer_core(std::size_t size); #if HWMALLOC_ENABLE_DEVICE diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f125e6d9..4f392827 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,5 +1,6 @@ target_sources(oomph_common PRIVATE barrier.cpp) target_sources(oomph_common PRIVATE rank_topology.cpp) +#target_sources(oomph_common PRIVATE tensor.cpp) if (OOMPH_WITH_MPI) add_subdirectory(mpi) diff --git a/src/mpi/context.hpp b/src/mpi/context.hpp index b8bcbe3f..5de4f285 100644 --- a/src/mpi/context.hpp +++ b/src/mpi/context.hpp @@ -47,6 +47,11 @@ class context_impl : public context_base void lock(communicator::rank_type r) { m_rma_context.lock(r); } communicator_impl* get_communicator(); + + template + void register_tensor(impl::tensor const & t) + { + } }; template<> diff --git a/src/src.cpp b/src/src.cpp index c64f38ad..72cc8d46 100644 --- a/src/src.cpp +++ b/src/src.cpp @@ -38,6 +38,12 @@ context::get_communicator() return {m->get_communicator()}; } +template<> +void context::map_tensor<1>(void* ptr) +{ + m->register_tensor(make_tensor<1>(ptr)); +} + /////////////////////////////// // communicator // /////////////////////////////// diff --git a/test/CMakeLists.txt 
b/test/CMakeLists.txt index b512a7bf..acbbd946 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -7,7 +7,8 @@ add_subdirectory(mpi_runner) # --------------------------------------------------------------------- # list of tests to be executed -set(parallel_tests test_context test_send_recv test_send_multi test_cancel test_barrier test_locality) +set(parallel_tests test_context test_send_recv test_send_multi test_cancel test_barrier + test_locality test_tensor) # creates an object library (i.e. *.o file) function(compile_test t_) From dba82c08926b254b607c3852a5d4de325c22fea4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20B=C3=B6sch?= <48126478+boeschf@users.noreply.github.com> Date: Tue, 2 Nov 2021 09:39:19 +0100 Subject: [PATCH 2/6] first trial, cpu only, no rma yet --- cmake/oomph_common.cmake | 4 + include/oomph/communicator.hpp | 22 +- include/oomph/context.hpp | 7 +- include/oomph/tensor/detail/map.hpp | 74 +++++ include/oomph/tensor/detail/terminal.hpp | 231 ++++++++++++++++ include/oomph/tensor/layout.hpp | 57 ++++ include/oomph/tensor/map.hpp | 82 ++++++ include/oomph/tensor/map_fwd.hpp | 21 ++ include/oomph/tensor/range.hpp | 48 ++++ include/oomph/tensor/receiver.hpp | 67 +++++ include/oomph/tensor/sender.hpp | 68 +++++ include/oomph/tensor/vector.hpp | 206 ++++++++++++++ src/mpi/context.hpp | 5 - src/src.cpp | 59 +--- test/test_tensor.cpp | 331 +++++++++++++++++++++++ 15 files changed, 1214 insertions(+), 68 deletions(-) create mode 100644 include/oomph/tensor/detail/map.hpp create mode 100644 include/oomph/tensor/detail/terminal.hpp create mode 100644 include/oomph/tensor/layout.hpp create mode 100644 include/oomph/tensor/map.hpp create mode 100644 include/oomph/tensor/map_fwd.hpp create mode 100644 include/oomph/tensor/range.hpp create mode 100644 include/oomph/tensor/receiver.hpp create mode 100644 include/oomph/tensor/sender.hpp create mode 100644 include/oomph/tensor/vector.hpp create mode 100644 test/test_tensor.cpp diff --git 
a/cmake/oomph_common.cmake b/cmake/oomph_common.cmake index a4966724..8e4435fb 100644 --- a/cmake/oomph_common.cmake +++ b/cmake/oomph_common.cmake @@ -7,9 +7,13 @@ mark_as_advanced(OOMPH_USE_FAST_PIMPL) # --------------------------------------------------------------------- # compiler and linker flags # --------------------------------------------------------------------- +#set(cxx_lang "$") function(oomph_target_compile_options target) set_target_properties(${target} PROPERTIES INTERFACE_POSITION_INDEPENDENT_CODE ON) target_compile_options(${target} PRIVATE -Wall -Wextra -Wpedantic) + #target_compile_options(${target} PRIVATE + # $<${cxx_lang}:$> + #) endfunction() function(oomph_target_link_options target) diff --git a/include/oomph/communicator.hpp b/include/oomph/communicator.hpp index 6ef5ae1d..9e4f74c7 100644 --- a/include/oomph/communicator.hpp +++ b/include/oomph/communicator.hpp @@ -23,12 +23,14 @@ namespace oomph { -class context; -class send_channel_base; -class recv_channel_base; - +class context_impl; class communicator_impl; +namespace detail +{ +communicator get_communicator(context_impl* c); +} // namespace detail + class communicator { public: @@ -42,9 +44,7 @@ class communicator static constexpr tag_type any_tag = -1; private: - friend class context; - friend class send_channel_base; - friend class recv_channel_base; + friend communicator detail::get_communicator(context_impl*); struct schedule { @@ -202,7 +202,7 @@ class communicator // ==================== template - [[nodiscard]] recv_request recv(message_buffer& msg, rank_type src, tag_type tag) + recv_request recv(message_buffer& msg, rank_type src, tag_type tag) { assert(msg); auto& scheduled = m_schedule->scheduled_recvs; @@ -213,7 +213,7 @@ class communicator } template - [[nodiscard]] send_request send(message_buffer const& msg, rank_type dst, tag_type tag) + send_request send(message_buffer const& msg, rank_type dst, tag_type tag) { assert(msg); auto& scheduled = 
m_schedule->scheduled_sends; @@ -224,8 +224,8 @@ class communicator } template - [[nodiscard]] send_request send_multi(message_buffer const& msg, - std::vector const& neighs, tag_type tag) + send_request send_multi(message_buffer const& msg, std::vector const& neighs, + tag_type tag) { assert(msg); auto& scheduled = m_schedule->scheduled_sends; diff --git a/include/oomph/context.hpp b/include/oomph/context.hpp index 55d11183..f15e3832 100644 --- a/include/oomph/context.hpp +++ b/include/oomph/context.hpp @@ -13,9 +13,9 @@ #include #include #include +#include #include #include -#include namespace oomph { @@ -81,8 +81,9 @@ class context communicator get_communicator(); - template - void map_tensor(void* ptr); + template + tensor::map map_tensor( + tensor::vector const& extents, T* first, T* last); private: detail::message_buffer make_buffer_core(std::size_t size); diff --git a/include/oomph/tensor/detail/map.hpp b/include/oomph/tensor/detail/map.hpp new file mode 100644 index 00000000..946edca1 --- /dev/null +++ b/include/oomph/tensor/detail/map.hpp @@ -0,0 +1,74 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +namespace tensor +{ +namespace detail +{ +template +class map /* Strided view of a dim()-dimensional tensor: keeps the base pointer, the extents and the per-dimension element strides; it defines no destructor and never frees m_data. */ +{ + public: + static constexpr std::size_t dim() noexcept { return Layout::max_arg + 1; }; + + using vec = vector; + /* index of the dimension that the constructor gives stride 1 */ + protected: + static constexpr std::size_t s_stride_1_dim = Layout::find(dim() - 1); + + protected: + T* m_data; + std::size_t m_num_elements; + vec m_extents; + vec m_strides; + std::size_t m_line_size; + /* Views the span first..last, where last is inclusive (the span length is computed as (last + 1) - first). Elements of the span beyond product(extents) are treated as padding and divided evenly over the num_lines - 1 gaps between stride-1 lines; the assert checks that the padded line is a whole number of T. m_line_size is that padded line length in elements, and the remaining strides are built up from it by multiplying with the extents in Layout order. */ + public: + map(vec const& extents, T* first, T* last) + : m_data{first} + , m_num_elements{product(extents)} + , m_extents{extents} + { + auto const stride_1_extent = m_extents[s_stride_1_dim]; + auto const num_lines = m_num_elements / stride_1_extent; + auto const total_padding = (((last + 1) - first) - m_num_elements) * sizeof(T); + std::size_t const padding = (num_lines == 1) ? 0 : total_padding / (num_lines - 1); + m_strides[s_stride_1_dim] = 1; + std::size_t s = stride_1_extent * sizeof(T) + padding; + assert((s / sizeof(T)) * sizeof(T) == s); + s /= sizeof(T); + m_line_size = s; + for (std::size_t i = 1; i < dim(); ++i) + { + m_strides[Layout::find(dim() - 1 - i)] = s; + s *= m_extents[Layout::find(dim() - 1 - i)]; + } + } + + map(map const&) noexcept = default; + map(map&&) noexcept = default; + map& operator=(map const&) noexcept = default; + map& operator=(map&&) noexcept = default; + /* read-only accessors; get_address offsets m_data by dot(coord, m_strides), so coord and strides are counted in elements of T */ + public: + auto const& strides() const noexcept { return m_strides; } + vec const& extents() const noexcept { return m_extents; } + std::size_t line_size() const noexcept { return m_line_size; } + T* get_address(vec coord) const noexcept { return m_data + dot(coord, m_strides); } +}; + +} // namespace detail +} // namespace tensor +} // namespace oomph diff --git a/include/oomph/tensor/detail/terminal.hpp b/include/oomph/tensor/detail/terminal.hpp new file mode 100644 index 00000000..efaa43af --- /dev/null +++ b/include/oomph/tensor/detail/terminal.hpp @@ -0,0 +1,231 @@ +/* + * ghex-org + * + *
Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace oomph +{ +namespace tensor +{ +namespace detail +{ +template +class terminal; + +template +class terminal> +{ + public: + using map_type = map; + static constexpr std::size_t dim() noexcept { return map_type::dim(); }; + using vec = vector; + using range_type = range; + + static constexpr std::size_t last_dim = Layout::find(0); + + protected: + struct transport_range + { + range_type m_range; + message_buffer m_message; + int m_rank; + int m_tag; + bool m_direct; + }; + + struct serialization_range + { + range_type m_range; + T* m_ptr; + }; + + protected: + map_type m_map; + std::unique_ptr m_comm; + std::vector m_transport_ranges; + std::vector m_serialization_ranges; + bool m_connected = false; + + public: + template + terminal(Map& m) + : m_map{m} + , m_comm{std::make_unique(oomph::detail::get_communicator(m.m_context))} + { + } + + terminal(terminal&&) noexcept = default; + terminal& operator=(terminal&&) noexcept = default; + + public: + void add_range(range_type const& view, int rank, int tag) + { + assert(!m_connected); + //x==*, y==1, z==1, w==1 -> direct + //x==Nx, y==*, z==1, w==1 -> direct + //x==Nx, y==Ny z==*, w==1 -> direct + //x==Nx, y==Ny z==Nz, w==* -> direct + + bool found_subset = (view.extents()[last_dim] != 1); + bool direct = true; + for (std::size_t d = 1; d < dim(); ++d) + { + auto const D = Layout::find(d); + if (found_subset && view.extents()[D] != m_map.extents()[D]) + { + direct = false; + break; + } + else if (view.extents()[D] != 1) + { + found_subset = true; + } + } + + if (direct) + { + auto ext = view.extents(); + ext[Layout::find(dim() - 1)] = m_map.line_size(); + auto const n_elements = product(ext); + m_transport_ranges.push_back(transport_range{view, + 
m_comm->make_buffer(m_map.get_address(view.first()), n_elements), rank, tag, + true}); + } + else + { + auto const n_elements = product(view.extents()); + auto const n_elements_slice = n_elements / view.extents()[last_dim]; + m_transport_ranges.push_back( + transport_range{view, m_comm->make_buffer(n_elements), rank, tag, false}); + T* ptr = m_transport_ranges.back().m_message.data(); + /* split the view into slices of extent 1 along last_dim; each slice gets its own n_elements_slice-sized window of the message buffer */ + std::size_t const first_k = view.first()[last_dim]; + std::size_t const last_k = first_k + view.extents()[last_dim]; + for (std::size_t k = first_k; k < last_k; ++k) + { + auto first = view.first(); + first[last_dim] = k; + auto ext = view.extents(); + ext[last_dim] = 1; + + m_serialization_ranges.push_back(serialization_range{range_type{first, ext}, ptr}); + ptr += n_elements_slice; + } + /* keep all slices ordered lexicographically by their first coordinate, comparing dimension Layout::find(0) first */ + std::sort(m_serialization_ranges.begin(), m_serialization_ranges.end(), + [](auto const& a, auto const& b) + { + auto const& first_a = a.m_range.first(); + auto const& first_b = b.m_range.first(); + for (std::size_t d = 0; d < dim(); ++d) + { + if (first_a[Layout::find(d)] < first_b[Layout::find(d)]) return true; + if (first_a[Layout::find(d)] > first_b[Layout::find(d)]) return false; + } + return false; /* all keys equal: a is not less than b; std::sort needs a strict weak ordering, so comp(x, x) must be false */ + }); + } + } + + protected: + void connect() + { + assert(!m_connected); + m_connected = true; + } + + T* serialize(serialization_range const& r, T* dst, vec coord, + std::integral_constant) + { + static constexpr std::size_t D = Layout::find(dim() - 1); + T const* src = m_map.get_address(coord); + std::size_t const n = r.m_range.extents()[D]; + for (std::size_t i = 0; i < n; ++i) dst[i] = src[i]; + + return dst + n; + } + + T const* serialize(serialization_range const& r, T const* src, vec coord, + std::integral_constant) + { + static constexpr std::size_t D = Layout::find(dim() - 1); + T* dst = m_map.get_address(coord); + std::size_t const n = r.m_range.extents()[D]; + for (std::size_t i = 0; i < n; ++i) dst[i] = src[i]; + return src + n; + } + + template + Ptr serialize(serialization_range
const& r, Ptr ptr, vec coord, + std::integral_constant) + { + static constexpr std::size_t D = Layout::find(N); + std::size_t const first = r.m_range.first()[D]; + std::size_t const last = first + r.m_range.extents()[D]; + for (; coord[D] < last; ++coord[D]) + ptr = serialize(r, ptr, coord, std::integral_constant{}); + return ptr; + } + + struct pack_handle + { + bool is_ready() const noexcept { return true; } + void wait() {} + }; + + pack_handle pack() + { + assert(m_connected); + for (auto& r : m_serialization_ranges) + serialize(r, r.m_ptr, r.m_range.first(), std::integral_constant()); + return {}; + } + + pack_handle unpack() + { + assert(m_connected); + for (auto& r : m_serialization_ranges) + serialize(r, (T const*)r.m_ptr, r.m_range.first(), + std::integral_constant()); + return {}; + } + + struct handle + { + communicator* m_comm; + bool is_ready() const noexcept { return m_comm->is_ready(); } + void progress() { m_comm->progress(); } + void wait() { m_comm->wait_all(); } + }; + + handle send() + { + assert(m_connected); + for (auto& r : m_transport_ranges) m_comm->send(r.m_message, r.m_rank, r.m_tag); + return {m_comm.get()}; + } + + handle recv() + { + assert(m_connected); + for (auto& r : m_transport_ranges) m_comm->recv(r.m_message, r.m_rank, r.m_tag); + return {m_comm.get()}; + } +}; + +} // namespace detail +} // namespace tensor +} // namespace oomph diff --git a/include/oomph/tensor/layout.hpp b/include/oomph/tensor/layout.hpp new file mode 100644 index 00000000..1a9eacd5 --- /dev/null +++ b/include/oomph/tensor/layout.hpp @@ -0,0 +1,57 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +namespace tensor +{ +template +struct layout /* Compile-time index permutation given by the pack I...; the static_assert below rejects packs that are not unique, contiguous and starting from 0. */ +{ + static constexpr std::size_t N = sizeof...(I); + static constexpr std::size_t max_arg = N - 1; + /* mp11 type lists used to validate the pack (its sorted copy is compared in the static_assert) and to build reverse_lookup via mp_find */ + using arg_list = boost::mp11::mp_list...>; + template + using sorter = boost::mp11::mp_bool<(A::value < B::value)>; + using sorted_arg_list = boost::mp11::mp_sort; + using lookup_list = boost::mp11::mp_iota_c; + static_assert(std::is_same>::value, + "arguments must be unique, contiguous and starting from 0"); + template + using F = boost::mp11::mp_find; + using reverse_lookup = boost::mp11::mp_transform; + + // Get the position of the element with value `i` in the layout + static constexpr std::size_t find(int i) + { + return find_impl(i, boost::mp11::make_index_sequence{}); + } + + // Get the value of the element at position `i` in the layout + static constexpr int at(std::size_t i) + { + std::size_t const ri[] = {I...}; + return ri[i]; + } + /* helper for find(): expands the index sequence into a local table through mp_at_c and returns entry i; i is not range-checked */ + template + static constexpr std::size_t find_impl(int i, boost::mp11::index_sequence) + { + std::size_t const ri[] = {boost::mp11::mp_at_c::value...}; + return ri[i]; + } +}; + +} // namespace tensor +} // namespace oomph diff --git a/include/oomph/tensor/map.hpp b/include/oomph/tensor/map.hpp new file mode 100644 index 00000000..7dcb45e8 --- /dev/null +++ b/include/oomph/tensor/map.hpp @@ -0,0 +1,82 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory.
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include + +namespace oomph +{ +namespace tensor +{ +namespace detail +{ +template +class terminal; +} // namespace detail + +template +class map : public detail::map /* Move-only tensor map that additionally remembers the context_impl it belongs to; the constructor is private, so instances come from friends such as context::map_tensor. */ +{ + private: + friend class oomph::context; + friend class detail::terminal>; + using base = detail::map; + + public: + using base::dim; + using vec = typename base::vec; + + private: + context_impl* m_context; + + private: + map(context_impl* c, vec const& extents, T* first, T* last) + : base(extents, first, last) + , m_context{c} + { + } + + public: + map(map const&) = delete; + map& operator=(map const&) = delete; + /* move construction takes over other's context pointer and leaves other with nullptr, so other's destructor has nothing to deregister */ + map(map&& other) noexcept + : base(std::move(other)) + , m_context{std::exchange(other.m_context, nullptr)} + { + } + /* move assignment: deregister the current mapping, take over other's base state and context pointer, and return *this as an assignment operator must */ + map& operator=(map&& other) noexcept + { + deregister(); + static_cast(*this) = std::move(other); + m_context = std::exchange(other.m_context, nullptr); + return *this; + } + + ~map() { deregister(); } + + private: + void deregister() + { + if (m_context) {} + } +}; + +} // namespace tensor + +template +tensor::map +context::map_tensor(tensor::vector const& extents, T* first, + T* last) +{ + return {m.get(), extents, first, last}; +} +} // namespace oomph diff --git a/include/oomph/tensor/map_fwd.hpp b/include/oomph/tensor/map_fwd.hpp new file mode 100644 index 00000000..aab703b2 --- /dev/null +++ b/include/oomph/tensor/map_fwd.hpp @@ -0,0 +1,21 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory.
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +namespace tensor +{ +template +class map; +} // namespace tensor +} // namespace oomph diff --git a/include/oomph/tensor/range.hpp b/include/oomph/tensor/range.hpp new file mode 100644 index 00000000..e875d5b1 --- /dev/null +++ b/include/oomph/tensor/range.hpp @@ -0,0 +1,48 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +namespace tensor +{ +template +class range /* Value type describing a box of indices by three vec members: the first coordinate, the extents and the increments. */ +{ + public: + using vec = vector; + //template + //friend class oomph::detail::map; + + private: + vec m_first; + vec m_extents; + vec m_increments; + /* increments defaults to make_uniform(1), i.e. the value 1 in every component */ + public: + constexpr range(vec const& first, vec const& extents, + vec const& increments = make_uniform((std::size_t)1)) + : m_first{first} + , m_extents{extents} + , m_increments{increments} + { + } + + constexpr range(range const&) noexcept = default; + range& operator=(range const&) noexcept = default; + /* read-only accessors for the three stored vectors */ + vec const& first() const noexcept { return m_first; } + vec const& extents() const noexcept { return m_extents; } + vec const& increments() const noexcept { return m_increments; } +}; +} // namespace tensor +} // namespace oomph diff --git a/include/oomph/tensor/receiver.hpp b/include/oomph/tensor/receiver.hpp new file mode 100644 index 00000000..36bd300c --- /dev/null +++ b/include/oomph/tensor/receiver.hpp @@ -0,0 +1,67 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory.
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include + +namespace oomph +{ +namespace tensor +{ +template +class receiver; + /* Receiving end for a tensor map: a thin wrapper whose members forward to the privately inherited detail::terminal. */ +template +class receiver> : private detail::terminal> +{ + private: + using base = detail::terminal>; + + public: + using map_type = map; + using range_type = typename base::range_type; + using pack_handle = typename base::pack_handle; + using handle = typename base::handle; + + public: + receiver(map_type& m) + : base(m) + { + } + + receiver(receiver&&) noexcept = default; + receiver& operator=(receiver&&) noexcept = default; + /* registers a source range together with the peer rank and tag via base::add_range; returns *this so calls can be chained */ + receiver& add_src(range_type const& view, int rank, int tag) + { + base::add_range(view, rank, tag); + return *this; + } + /* forwards to base::connect(), which asserts it has not been called before; returns *this for chaining */ + receiver& connect() + { + base::connect(); + return *this; + } + /* unpack() and recv() return the result of base::unpack() and base::recv() unchanged */ + pack_handle unpack() { return base::unpack(); } + + handle recv() { return base::recv(); } +}; + /* convenience factory: list-initializes a receiver from the given map */ +template +receiver> +make_receiver(map& m) +{ + return {m}; +} +} // namespace tensor +} // namespace oomph diff --git a/include/oomph/tensor/sender.hpp b/include/oomph/tensor/sender.hpp new file mode 100644 index 00000000..33f7b61c --- /dev/null +++ b/include/oomph/tensor/sender.hpp @@ -0,0 +1,68 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory.
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include + +namespace oomph +{ +namespace tensor +{ +template +class sender; + /* Sending end for a tensor map: a thin wrapper whose members forward to the privately inherited detail::terminal. */ +template +class sender> : private detail::terminal> +{ + private: + using base = detail::terminal>; + + public: + using map_type = map; + using range_type = typename base::range_type; + using pack_handle = typename base::pack_handle; + using handle = typename base::handle; + + public: + sender(map_type& m) + : base(m) + { + } + + sender(sender&&) noexcept = default; + sender& operator=(sender&&) noexcept = default; + /* registers a destination range together with the peer rank and tag via base::add_range; returns *this so calls can be chained */ + sender& add_dst(range_type const& view, int rank, int tag) + { + base::add_range(view, rank, tag); + return *this; + } + /* forwards to base::connect(), which asserts it has not been called before; returns *this for chaining */ + sender& connect() + { + base::connect(); + return *this; + } + /* pack() and send() return the result of base::pack() and base::send() unchanged */ + pack_handle pack() { return base::pack(); } + + handle send() { return base::send(); } +}; + /* convenience factory: list-initializes a sender from the given map */ +template +sender> +make_sender(map& m) +{ + return {m}; +} + +} // namespace tensor +} // namespace oomph diff --git a/include/oomph/tensor/vector.hpp b/include/oomph/tensor/vector.hpp new file mode 100644 index 00000000..db67a3a8 --- /dev/null +++ b/include/oomph/tensor/vector.hpp @@ -0,0 +1,206 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory.
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +namespace tensor +{ +template +struct vector /* Fixed-size aggregate of N values of type T with iterator access and element-wise compound assignment; each compound operator has a vector overload (component by component) and a scalar overload (same operand for every component). */ +{ + T m_data[N]; + + static constexpr std::size_t size() noexcept { return N; } + /* converting assignment: copies v component by component and returns *this */ + template + vector& operator=(vector const& v) + { + for (std::size_t i = 0; i < N; ++i) m_data[i] = v[i]; + return *this; + } + + constexpr T const* cbegin() const { return m_data; } + constexpr T const* begin() const { return m_data; } + constexpr T* begin() { return m_data; } + + constexpr T const* cend() const { return cbegin() + N; } + constexpr T const* end() const { return begin() + N; } + constexpr T* end() { return begin() + N; } + + constexpr const T* data() const noexcept { return m_data; } + constexpr T* data() noexcept { return m_data; } + /* unchecked element access: the const overload returns a copy, the non-const overload a reference */ + constexpr T operator[](std::size_t i) const noexcept { return m_data[i]; } + T& operator[](std::size_t i) noexcept { return m_data[i]; } + + template + vector& operator+=(vector const& v) noexcept + { + for (std::size_t i = 0; i < N; ++i) m_data[i] += v[i]; + return *this; + } + + template + vector& operator+=(U const& u) noexcept + { + for (std::size_t i = 0; i < N; ++i) m_data[i] += u; + return *this; + } + /* element-wise subtraction of another vector */ + template + vector& operator-=(vector const& v) noexcept + { + for (std::size_t i = 0; i < N; ++i) m_data[i] -= v[i]; + return *this; + } + /* subtraction of the same scalar from every component */ + template + vector& operator-=(U const& u) noexcept + { + for (std::size_t i = 0; i < N; ++i) m_data[i] -= u; + return *this; + } + + template + vector& operator*=(vector const& v) noexcept + { + for (std::size_t i = 0; i < N; ++i) m_data[i] *= v[i]; + return *this; + } + + template + vector& operator*=(U const& u) noexcept + { + for (std::size_t i = 0; i < N; ++i) m_data[i] *= u; + return *this; + } + + template + vector& operator/=(vector const& v) noexcept + { + for (std::size_t i = 0; i < N; ++i) m_data[i] /= v[i]; + return *this; + } + + template + vector& operator/=(U const& u) noexcept + { + for (std::size_t i = 0; i < N; ++i) m_data[i] /= u; + return *this;
+ } +}; + +template +inline constexpr auto +add(vector const& a, vector const& b, std::index_sequence) noexcept +{ + using R = std::remove_reference_t() + std::declval())>; + return R{(a[I] + b[I])...}; +} + +template +inline constexpr auto +operator+(vector const& a, vector const& b) noexcept +{ + return add(a, b, std::make_index_sequence{}); +} + +template +inline constexpr auto +subtract(vector const& a, vector const& b, std::index_sequence) noexcept +{ + using R = std::remove_reference_t() - std::declval())>; + return R{(a[I] - b[I])...}; +} + +template +inline constexpr auto +operator-(vector const& a, vector const& b) noexcept +{ + return subtract(a, b, std::make_index_sequence{}); +} + +template +inline constexpr auto +multiply(vector const& a, vector const& b, std::index_sequence) noexcept +{ + using R = std::remove_reference_t() * std::declval())>; + return R{(a[I] * b[I])...}; +} + +template +inline constexpr auto +operator*(vector const& a, vector const& b) noexcept +{ + return multiply(a, b, std::make_index_sequence{}); +} + +template +inline constexpr auto +divide(vector const& a, vector const& b, std::index_sequence) noexcept +{ + using R = std::remove_reference_t() / std::declval())>; + return R{(a[I] / b[I])...}; +} + +template +inline constexpr auto +operator/(vector const& a, vector const& b) noexcept +{ + return divide(a, b, std::make_index_sequence{}); +} + +template +inline constexpr auto +dot(vector const& a, vector const& b, std::index_sequence) noexcept +{ + return (... + (a[I] * b[I])); +} + +template +inline constexpr auto +dot(vector const& a, vector const& b) noexcept +{ + return dot(a, b, std::make_index_sequence{}); +} + +template +inline constexpr auto +product(vector const& v, std::index_sequence) noexcept +{ + return (... 
* v[I]); +} + +template +inline constexpr auto +product(vector const& v) noexcept +{ + return product(v, std::make_index_sequence{}); +} + +template +inline constexpr vector +make_uniform(T const& value, std::index_sequence) noexcept +{ + T const* const vref = &value; + return {vref[I - I]...}; +} + +template +inline constexpr vector +make_uniform(T const& value) noexcept +{ + return make_uniform(value, std::make_index_sequence{}); +} + +} // namespace tensor +} // namespace oomph diff --git a/src/mpi/context.hpp b/src/mpi/context.hpp index 5de4f285..b8bcbe3f 100644 --- a/src/mpi/context.hpp +++ b/src/mpi/context.hpp @@ -47,11 +47,6 @@ class context_impl : public context_base void lock(communicator::rank_type r) { m_rma_context.lock(r); } communicator_impl* get_communicator(); - - template - void register_tensor(impl::tensor const & t) - { - } }; template<> diff --git a/src/src.cpp b/src/src.cpp index 5137becb..05b15d46 100644 --- a/src/src.cpp +++ b/src/src.cpp @@ -35,13 +35,7 @@ context::~context() = default; communicator context::get_communicator() { - return {m->get_communicator()}; -} - -template<> -void context::map_tensor<1>(void* ptr) -{ - m->register_tensor(make_tensor<1>(ptr)); + return detail::get_communicator(m.get()); } /////////////////////////////// @@ -87,6 +81,15 @@ communicator::progress() m_impl->progress(); } +namespace detail +{ +communicator +get_communicator(context_impl* c) +{ + return {c->get_communicator()}; +} +} // namespace detail + /////////////////////////////// // message_buffer // /////////////////////////////// @@ -314,46 +317,4 @@ recv_request::cancel() return res; } -///////////////////////////////// -//// send_channel_base // -///////////////////////////////// -// -//send_channel_base::~send_channel_base() = default; -// -//void -//send_channel_base::connect() -//{ -// m_impl->connect(); -//} -// -///////////////////////////////// -//// recv_channel_base // -///////////////////////////////// -// 
-//recv_channel_base::~recv_channel_base() = default; -// -//void -//recv_channel_base::connect() -//{ -// m_impl->connect(); -//} -// -//std::size_t -//recv_channel_base::capacity() -//{ -// return m_impl->capacity(); -//} -// -//void* -//recv_channel_base::get(std::size_t& index) -//{ -// return m_impl->get(index); -//} -// -//recv_channel_impl* -//recv_channel_base::get_impl() noexcept -//{ -// return &(*m_impl); -//} - } // namespace oomph diff --git a/test/test_tensor.cpp b/test/test_tensor.cpp new file mode 100644 index 00000000..4662760b --- /dev/null +++ b/test/test_tensor.cpp @@ -0,0 +1,331 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include +#include +#include + +#include +#include "./mpi_runner/mpi_test_fixture.hpp" +#include +#include +#include +#include +#include + +template +void +test_layout() +{ + using namespace oomph; + using namespace oomph::tensor; + static_assert(layout::at(0) == I, ""); + static_assert(layout::at(1) == J, ""); + static_assert(layout::at(2) == K, ""); + + static constexpr std::size_t pos_0 = (I == 0) ? 0 : ((J == 0) ? 1 : 2); + static constexpr std::size_t pos_1 = (I == 1) ? 0 : ((J == 1) ? 1 : 2); + static constexpr std::size_t pos_2 = (I == 2) ? 0 : ((J == 2) ? 
1 : 2); + + static_assert(layout::find(0) == pos_0, ""); + static_assert(layout::find(1) == pos_1, ""); + static_assert(layout::find(2) == pos_2, ""); +} + +TEST(layout, ctor) +{ + using namespace oomph; + using namespace oomph::tensor; + + test_layout<0, 1, 2>(); + test_layout<0, 2, 1>(); + test_layout<1, 0, 2>(); + test_layout<1, 2, 0>(); + test_layout<2, 0, 1>(); + test_layout<2, 1, 0>(); + //test_layout<2,2,0>(); +} + +template +void +test_tensor_strides(std::size_t X, std::size_t Y, std::size_t Z, std::size_t padding = 0) +{ + using namespace oomph; + using T = double; + using layout_t = tensor::layout; + + std::size_t dims[] = {X, Y, Z}; + + std::vector data( + (layout_t::find(2) == 0 + ? (X + padding) * Y * Z + : (layout_t::find(2) == 1 ? (Y + padding) * X * Z : (Z + padding) * X * Y)) - + padding); + + tensor::detail::map m({X, Y, Z}, data.data(), (data.data() + data.size()) - 1); + + EXPECT_EQ(m.strides()[layout_t::find(2)], 1); + EXPECT_EQ(m.strides()[layout_t::find(1)], dims[layout_t::find(2)] + padding); + EXPECT_EQ(m.strides()[layout_t::find(0)], + dims[layout_t::find(1)] * (dims[layout_t::find(2)] + padding)); +} + +TEST(tensor, strides) +{ + test_tensor_strides<2, 1, 0>(2, 3, 5); + test_tensor_strides<2, 1, 0>(2, 3, 5, 1); + test_tensor_strides<2, 1, 0>(2, 3, 5, 3); + test_tensor_strides<2, 0, 1>(2, 3, 5); + test_tensor_strides<2, 0, 1>(2, 3, 5, 1); + test_tensor_strides<2, 0, 1>(2, 3, 5, 3); + test_tensor_strides<1, 2, 0>(2, 3, 5); + test_tensor_strides<1, 2, 0>(2, 3, 5, 1); + test_tensor_strides<1, 2, 0>(2, 3, 5, 3); + test_tensor_strides<1, 0, 2>(2, 3, 5); + test_tensor_strides<1, 0, 2>(2, 3, 5, 1); + test_tensor_strides<1, 0, 2>(2, 3, 5, 3); +} + +TEST_F(mpi_test_fixture, ctor) +{ + using namespace oomph; + auto ctxt = context(MPI_COMM_WORLD, false); + + std::size_t const halo = 2; + std::size_t const x = 2; + std::size_t const y = 3; + std::size_t const z = 5; + + std::vector data((x + 2 * halo) * (y + 2 * halo) * (z + 2 * halo), world_rank); + 
+ data[halo + 0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 0))] = 9000; + data[halo + 1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 0))] = 9001; + data[halo + 0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 0))] = 9010; + data[halo + 1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 0))] = 9011; + data[halo + 0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 0))] = 9020; + data[halo + 1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 0))] = 9021; + + data[halo + 0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 1))] = 9100; + data[halo + 1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 1))] = 9101; + data[halo + 0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 1))] = 9110; + data[halo + 1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 1))] = 9111; + data[halo + 0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 1))] = 9120; + data[halo + 1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 1))] = 9121; + + data[halo + 0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 2))] = 9200; + data[halo + 1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 2))] = 9201; + data[halo + 0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 2))] = 9210; + data[halo + 1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 2))] = 9211; + data[halo + 0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 2))] = 9220; + data[halo + 1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 2))] = 9221; + + data[halo + 0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 3))] = 9300; + data[halo + 1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 3))] = 9301; + data[halo + 0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 3))] = 9310; + data[halo + 1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 3))] = 9311; + data[halo + 0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 3))] = 9320; + data[halo + 1 + (x + 2 * halo) * 
(halo + 2 + (y + 2 * halo) * (halo + 3))] = 9321; + + data[halo + 0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 4))] = 9400; + data[halo + 1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 4))] = 9401; + data[halo + 0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 4))] = 9410; + data[halo + 1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 4))] = 9411; + data[halo + 0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 4))] = 9420; + data[halo + 1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 4))] = 9421; + + using layout_t = tensor::layout<2, 1, 0>; + using map_t = tensor::map; + + map_t t = ctxt.map_tensor({x + 2 * halo, y + 2 * halo, z + 2 * halo}, data.data(), + (data.data() + data.size()) - 1); + + tensor::sender> s = tensor::make_sender(t); + tensor::receiver> r = tensor::make_receiver(t); + + if (world_rank == 0) + { + s + // y-z plane, +x direction + .add_dst({{halo + x - halo, halo, halo}, {halo, y, z}}, 1, 0) + // x-y plane, -z direction + .add_dst({{halo, halo, halo}, {x, y, halo}}, 2, 0) + // whole x-y plane, -z direction + .add_dst({{0, 0, halo}, {x + 2 * halo, y + 2 * halo, halo}}, 3, 0) + .connect(); + + s.pack().wait(); + s.send().wait(); + } + if (world_rank == 1) + { + r.add_src({{0, halo, halo}, {halo, y, z}}, 0, 0).connect(); + r.recv().wait(); + r.unpack().wait(); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 0))], 9000); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 0))], 9001); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 0))], 9010); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 0))], 9011); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 0))], 9020); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 0))], 9021); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 1))], 9100); + EXPECT_EQ(data[1 + 
(x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 1))], 9101); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 1))], 9110); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 1))], 9111); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 1))], 9120); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 1))], 9121); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 2))], 9200); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 2))], 9201); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 2))], 9210); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 2))], 9211); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 2))], 9220); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 2))], 9221); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 3))], 9300); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 3))], 9301); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 3))], 9310); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 3))], 9311); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 3))], 9320); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 3))], 9321); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 4))], 9400); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 4))], 9401); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 4))], 9410); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 4))], 9411); + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 4))], 9420); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 
4))], 9421); + } + if (world_rank == 2) + { + r.add_src({{halo, halo, halo + z}, {x, y, halo}}, 0, 0).connect(); + r.recv().wait(); + r.unpack().wait(); + + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 5))], 9000); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 5))], 9001); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 5))], 9010); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 5))], 9011); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 5))], 9020); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 5))], 9021); + + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 6))], 9100); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 6))], 9101); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 6))], 9110); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 6))], 9111); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 6))], 9120); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 6))], 9121); + } + if (world_rank == 3) + { + r.add_src({{0, 0, halo + z}, {x + 2 * halo, y + 2 * halo, halo}}, 0, 0).connect(); + r.recv().wait(); + r.unpack().wait(); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[4 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 5))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (1 + (y + 2 * halo) * 
(halo + 5))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[4 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 5))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 5))], 9000); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 5))], 9001); + EXPECT_EQ(data[4 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 5))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 5))], 9010); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 5))], 9011); + EXPECT_EQ(data[4 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 5))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 5))], 9020); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 5))], 9021); + EXPECT_EQ(data[4 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * 
(halo + 5))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[4 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 5))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[4 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 5))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 5))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[4 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (0 + (y + 2 * halo) * (halo + 6))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[4 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (1 + (y + 2 * halo) * (halo + 6))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo 
+ 0 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 6))], 9100); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 6))], 9101); + EXPECT_EQ(data[4 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (halo + 0 + (y + 2 * halo) * (halo + 6))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 6))], 9110); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 6))], 9111); + EXPECT_EQ(data[4 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (halo + 1 + (y + 2 * halo) * (halo + 6))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 6))], 9120); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 6))], 9121); + EXPECT_EQ(data[4 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (halo + 2 + (y + 2 * halo) * (halo + 6))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[4 + (x + 2 * halo) * (5 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (5 + 
(y + 2 * halo) * (halo + 6))], 0); + + EXPECT_EQ(data[0 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[1 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 0 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[halo + 1 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[4 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 6))], 0); + EXPECT_EQ(data[5 + (x + 2 * halo) * (6 + (y + 2 * halo) * (halo + 6))], 0); + } +} From bd41baac4d04d70bd13fe05c2f2dd8e2885ba6c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20B=C3=B6sch?= <48126478+boeschf@users.noreply.github.com> Date: Tue, 2 Nov 2021 14:04:42 +0100 Subject: [PATCH 3/6] c++14 compatible vector products --- include/oomph/tensor/vector.hpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/include/oomph/tensor/vector.hpp b/include/oomph/tensor/vector.hpp index db67a3a8..0cd7a03b 100644 --- a/include/oomph/tensor/vector.hpp +++ b/include/oomph/tensor/vector.hpp @@ -159,6 +159,7 @@ operator/(vector const& a, vector const& b) noexcept return divide(a, b, std::make_index_sequence{}); } +#if __cplusplus >= 201703L template inline constexpr auto dot(vector const& a, vector const& b, std::index_sequence) noexcept @@ -172,7 +173,18 @@ dot(vector const& a, vector const& b) noexcept { return dot(a, b, std::make_index_sequence{}); } +#else +template +inline constexpr auto +dot(vector const& a, vector const& b) noexcept +{ + auto r = a[0] * b[0]; + for (std::size_t i = 1; i < N; ++i) r += a[i] * b[i]; + return r; +} +#endif +#if __cplusplus >= 201703L template inline constexpr auto product(vector const& v, std::index_sequence) noexcept @@ -186,6 +198,16 @@ product(vector const& v) noexcept { return product(v, std::make_index_sequence{}); } +#else +template +inline constexpr auto +product(vector const& v) noexcept +{ + T r = v[0]; + for (std::size_t i = 1; i < N; ++i) r *= v[i]; + return r; +} 
+#endif template inline constexpr vector From 4883968f91c08e89dc4eab12464c831f0819da2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20B=C3=B6sch?= <48126478+boeschf@users.noreply.github.com> Date: Tue, 2 Nov 2021 17:24:41 +0100 Subject: [PATCH 4/6] stage extension --- include/oomph/tensor/detail/terminal.hpp | 134 ++++++++++++++++++----- include/oomph/tensor/receiver.hpp | 8 +- include/oomph/tensor/sender.hpp | 8 +- test/test_tensor.cpp | 14 ++- 4 files changed, 124 insertions(+), 40 deletions(-) diff --git a/include/oomph/tensor/detail/terminal.hpp b/include/oomph/tensor/detail/terminal.hpp index efaa43af..26c40ac8 100644 --- a/include/oomph/tensor/detail/terminal.hpp +++ b/include/oomph/tensor/detail/terminal.hpp @@ -14,6 +14,8 @@ #include #include #include +#include +#include #include namespace oomph @@ -39,11 +41,11 @@ class terminal> protected: struct transport_range { - range_type m_range; - message_buffer m_message; - int m_rank; - int m_tag; - bool m_direct; + range_type m_range; + std::shared_ptr> m_message; + int m_rank; + int m_tag; + bool m_direct; }; struct serialization_range @@ -52,18 +54,74 @@ class terminal> T* m_ptr; }; + struct buffer_cache + { + communicator* m_comm; + std::vector>> m_messages; + std::map> m_available_idx; + + buffer_cache(communicator* c) noexcept + : m_comm{c} + { + } + + buffer_cache(buffer_cache&&) noexcept = default; + buffer_cache& operator=(buffer_cache&&) noexcept = default; + + std::shared_ptr> operator()(std::size_t size, std::size_t stage) + { + std::set* index_set = nullptr; + auto it = m_available_idx.find(stage); + if (it == m_available_idx.end()) + { + index_set = &(m_available_idx[stage]); + for (std::size_t i = 0; i < m_messages.size(); ++i) index_set->insert(i); + } + else + { + index_set = &(it->second); + } + + auto m_it = std::find_if(index_set->begin(), index_set->end(), + [this, size](std::size_t i) { return (m_messages[i]->size() == size); }); + + if (m_it == index_set->end()) + { + 
m_messages.push_back( + std::make_shared>(m_comm->make_buffer(size))); + for (auto& s : m_available_idx) s.second.insert(m_messages.size() - 1); + index_set->erase(m_messages.size() - 1); + return m_messages.back(); + } + else + { + auto res = *m_it; + index_set->erase(m_it); + return m_messages[res]; + } + } + }; + + struct stage_t + { + std::vector m_transport_ranges; + std::vector m_serialization_ranges; + }; + protected: - map_type m_map; - std::unique_ptr m_comm; - std::vector m_transport_ranges; - std::vector m_serialization_ranges; - bool m_connected = false; + map_type m_map; + std::unique_ptr m_comm; + std::map m_stages; + buffer_cache m_buffer_cache; + std::vector m_stage_lu; + bool m_connected = false; public: template terminal(Map& m) : m_map{m} , m_comm{std::make_unique(oomph::detail::get_communicator(m.m_context))} + , m_buffer_cache{m_comm.get()} { } @@ -71,7 +129,7 @@ class terminal> terminal& operator=(terminal&&) noexcept = default; public: - void add_range(range_type const& view, int rank, int tag) + void add_range(range_type const& view, int rank, int tag, std::size_t stage) { assert(!m_connected); //x==*, y==1, z==1, w==1 -> direct @@ -95,22 +153,26 @@ class terminal> } } + auto& s = m_stages[stage]; + if (direct) { auto ext = view.extents(); ext[Layout::find(dim() - 1)] = m_map.line_size(); auto const n_elements = product(ext); - m_transport_ranges.push_back(transport_range{view, - m_comm->make_buffer(m_map.get_address(view.first()), n_elements), rank, tag, - true}); + s.m_transport_ranges.push_back(transport_range{view, + std::make_shared>( + m_comm->make_buffer(m_map.get_address(view.first()), n_elements)), + rank, tag, true}); } else { auto const n_elements = product(view.extents()); auto const n_elements_slice = n_elements / view.extents()[last_dim]; - m_transport_ranges.push_back( - transport_range{view, m_comm->make_buffer(n_elements), rank, tag, false}); - T* ptr = m_transport_ranges.back().m_message.data(); + 
s.m_transport_ranges.push_back(transport_range{view, + //m_comm->make_buffer(n_elements), + m_buffer_cache(n_elements, stage), rank, tag, false}); + T* ptr = s.m_transport_ranges.back().m_message->data(); std::size_t const first_k = view.first()[last_dim]; std::size_t const last_k = first_k + view.extents()[last_dim]; @@ -121,11 +183,12 @@ class terminal> auto ext = view.extents(); ext[last_dim] = 1; - m_serialization_ranges.push_back(serialization_range{range_type{first, ext}, ptr}); + s.m_serialization_ranges.push_back( + serialization_range{range_type{first, ext}, ptr}); ptr += n_elements_slice; } - std::sort(m_serialization_ranges.begin(), m_serialization_ranges.end(), + std::sort(s.m_serialization_ranges.begin(), s.m_serialization_ranges.end(), [](auto const& a, auto const& b) { auto const& first_a = a.m_range.first(); @@ -144,9 +207,18 @@ class terminal> void connect() { assert(!m_connected); + assert(m_stages.size() > 0); + m_stage_lu = std::vector(m_stages.rbegin()->first + 1, nullptr); + for (auto& kvp : m_stages) m_stage_lu[kvp.first] = &(kvp.second); m_connected = true; } + stage_t* get_stage(std::size_t stage) + { + assert(stage < m_stage_lu.size()); + return m_stage_lu[stage]; + } + T* serialize(serialization_range const& r, T* dst, vec coord, std::integral_constant) { @@ -186,18 +258,22 @@ class terminal> void wait() {} }; - pack_handle pack() + pack_handle pack(std::size_t stage) { assert(m_connected); - for (auto& r : m_serialization_ranges) + stage_t* s = get_stage(stage); + if (!s) return {}; + for (auto& r : s->m_serialization_ranges) serialize(r, r.m_ptr, r.m_range.first(), std::integral_constant()); return {}; } - pack_handle unpack() + pack_handle unpack(std::size_t stage) { assert(m_connected); - for (auto& r : m_serialization_ranges) + stage_t* s = get_stage(stage); + if (!s) return {}; + for (auto& r : s->m_serialization_ranges) serialize(r, (T const*)r.m_ptr, r.m_range.first(), std::integral_constant()); return {}; @@ -211,17 +287,21 @@ 
class terminal> void wait() { m_comm->wait_all(); } }; - handle send() + handle send(std::size_t stage) { assert(m_connected); - for (auto& r : m_transport_ranges) m_comm->send(r.m_message, r.m_rank, r.m_tag); + stage_t* s = get_stage(stage); + if (!s) return {m_comm.get()}; + for (auto& r : s->m_transport_ranges) m_comm->send(*r.m_message, r.m_rank, r.m_tag); return {m_comm.get()}; } - handle recv() + handle recv(std::size_t stage) { assert(m_connected); - for (auto& r : m_transport_ranges) m_comm->recv(r.m_message, r.m_rank, r.m_tag); + stage_t* s = get_stage(stage); + if (!s) return {m_comm.get()}; + for (auto& r : s->m_transport_ranges) m_comm->recv(*r.m_message, r.m_rank, r.m_tag); return {m_comm.get()}; } }; diff --git a/include/oomph/tensor/receiver.hpp b/include/oomph/tensor/receiver.hpp index 36bd300c..8dd25218 100644 --- a/include/oomph/tensor/receiver.hpp +++ b/include/oomph/tensor/receiver.hpp @@ -40,9 +40,9 @@ class receiver> : private detail::terminal receiver(receiver&&) noexcept = default; receiver& operator=(receiver&&) noexcept = default; - receiver& add_src(range_type const& view, int rank, int tag) + receiver& add_src(range_type const& view, int rank, int tag, std::size_t stage = 0) { - base::add_range(view, rank, tag); + base::add_range(view, rank, tag, stage); return *this; } @@ -52,9 +52,9 @@ class receiver> : private detail::terminal return *this; } - pack_handle unpack() { return base::unpack(); } + pack_handle unpack(std::size_t stage = 0) { return base::unpack(stage); } - handle recv() { return base::recv(); } + handle recv(std::size_t stage = 0) { return base::recv(stage); } }; template diff --git a/include/oomph/tensor/sender.hpp b/include/oomph/tensor/sender.hpp index 33f7b61c..3230a4e9 100644 --- a/include/oomph/tensor/sender.hpp +++ b/include/oomph/tensor/sender.hpp @@ -40,9 +40,9 @@ class sender> : private detail::terminal> sender(sender&&) noexcept = default; sender& operator=(sender&&) noexcept = default; - sender& 
add_dst(range_type const& view, int rank, int tag) + sender& add_dst(range_type const& view, int rank, int tag, std::size_t stage = 0) { - base::add_range(view, rank, tag); + base::add_range(view, rank, tag, stage); return *this; } @@ -52,9 +52,9 @@ class sender> : private detail::terminal> return *this; } - pack_handle pack() { return base::pack(); } + pack_handle pack(std::size_t stage = 0) { return base::pack(stage); } - handle send() { return base::send(); } + handle send(std::size_t stage = 0) { return base::send(stage); } }; template diff --git a/test/test_tensor.cpp b/test/test_tensor.cpp index 4662760b..19de9ca0 100644 --- a/test/test_tensor.cpp +++ b/test/test_tensor.cpp @@ -153,15 +153,19 @@ TEST_F(mpi_test_fixture, ctor) { s // y-z plane, +x direction - .add_dst({{halo + x - halo, halo, halo}, {halo, y, z}}, 1, 0) + .add_dst({{halo + x - halo, halo, halo}, {halo, y, z}}, 1, 0,1) // x-y plane, -z direction - .add_dst({{halo, halo, halo}, {x, y, halo}}, 2, 0) + .add_dst({{halo, halo, halo}, {x, y, halo}}, 2, 0,0) // whole x-y plane, -z direction - .add_dst({{0, 0, halo}, {x + 2 * halo, y + 2 * halo, halo}}, 3, 0) + .add_dst({{0, 0, halo}, {x + 2 * halo, y + 2 * halo, halo}}, 3, 0,2) .connect(); - s.pack().wait(); - s.send().wait(); + s.pack(1).wait(); + s.send(1).wait(); + s.pack(0).wait(); + s.send(0).wait(); + s.pack(2).wait(); + s.send(2).wait(); } if (world_rank == 1) { From 4e0412bd3f98d05fdeb98b4cb6cdbf1dc57170b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20B=C3=B6sch?= <48126478+boeschf@users.noreply.github.com> Date: Wed, 3 Nov 2021 11:32:55 +0100 Subject: [PATCH 5/6] buffer cache --- include/oomph/tensor/buffer_cache.hpp | 31 ++++++ include/oomph/tensor/detail/buffer_cache.hpp | 104 +++++++++++++++++++ include/oomph/tensor/detail/terminal.hpp | 76 ++++---------- include/oomph/tensor/receiver.hpp | 15 ++- include/oomph/tensor/sender.hpp | 15 ++- test/test_tensor.cpp | 6 +- 6 files changed, 187 insertions(+), 60 deletions(-) create mode 100644 
include/oomph/tensor/buffer_cache.hpp create mode 100644 include/oomph/tensor/detail/buffer_cache.hpp diff --git a/include/oomph/tensor/buffer_cache.hpp b/include/oomph/tensor/buffer_cache.hpp new file mode 100644 index 00000000..16cd2e49 --- /dev/null +++ b/include/oomph/tensor/buffer_cache.hpp @@ -0,0 +1,31 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +namespace tensor +{ +template +struct buffer_cache +{ + std::shared_ptr> m_cache; + + buffer_cache() + : m_cache{std::make_shared>()} + { + } + buffer_cache(buffer_cache const&) = default; + buffer_cache& operator=(buffer_cache const&) = default; +}; +} // namespace tensor +} // namespace oomph diff --git a/include/oomph/tensor/detail/buffer_cache.hpp b/include/oomph/tensor/detail/buffer_cache.hpp new file mode 100644 index 00000000..b6a58a24 --- /dev/null +++ b/include/oomph/tensor/detail/buffer_cache.hpp @@ -0,0 +1,104 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace oomph +{ +namespace tensor +{ +namespace detail +{ +template +struct buffer_cache +{ + std::vector>> m_messages; + std::map> m_available_idx; + + buffer_cache() noexcept = default; + buffer_cache(buffer_cache&&) noexcept = default; + buffer_cache& operator=(buffer_cache&&) noexcept = default; + + std::shared_ptr> operator()(communicator& comm, std::size_t size, Id id) + { + std::set* index_set = nullptr; + auto it = m_available_idx.find(id); + if (it == m_available_idx.end()) + { + index_set = &(m_available_idx[id]); + for (std::size_t i = 0; i < m_messages.size(); ++i) index_set->insert(i); + } + else + { + index_set = &(it->second); + } + + auto m_it = std::find_if(index_set->begin(), index_set->end(), + [this, size](std::size_t i) { return (m_messages[i]->size() == size); }); + + if (m_it == index_set->end()) + { + m_messages.push_back(std::make_shared>(comm.make_buffer(size))); + for (auto& s : m_available_idx) s.second.insert(m_messages.size() - 1); + index_set->erase(m_messages.size() - 1); + return m_messages.back(); + } + else + { + auto res = *m_it; + index_set->erase(m_it); + return m_messages[res]; + } + } + + template + std::shared_ptr> operator()(communicator& comm, std::size_t size, Id id, + buffer_cache& c2, Id2 id2) + { + std::set* index_set = nullptr; + auto it = m_available_idx.find(id); + if (it == m_available_idx.end()) + { + index_set = &(m_available_idx[id]); + for (std::size_t i = 0; i < m_messages.size(); ++i) index_set->insert(i); + } + else + { + index_set = &(it->second); + } + + auto m_it = std::find_if(index_set->begin(), index_set->end(), + [this, size](std::size_t i) { return (m_messages[i]->size() == size); }); + + if (m_it == index_set->end()) + { + auto res = c2(comm, size, id2); + m_messages.push_back(res); + for (auto& s : m_available_idx) s.second.insert(m_messages.size() - 1); + 
index_set->erase(m_messages.size() - 1); + return res; + } + else + { + auto res = *m_it; + index_set->erase(m_it); + return m_messages[res]; + } + } +}; + +} // namespace detail +} // namespace tensor +} // namespace oomph diff --git a/include/oomph/tensor/detail/terminal.hpp b/include/oomph/tensor/detail/terminal.hpp index 26c40ac8..c44da783 100644 --- a/include/oomph/tensor/detail/terminal.hpp +++ b/include/oomph/tensor/detail/terminal.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -54,54 +55,6 @@ class terminal> T* m_ptr; }; - struct buffer_cache - { - communicator* m_comm; - std::vector>> m_messages; - std::map> m_available_idx; - - buffer_cache(communicator* c) noexcept - : m_comm{c} - { - } - - buffer_cache(buffer_cache&&) noexcept = default; - buffer_cache& operator=(buffer_cache&&) noexcept = default; - - std::shared_ptr> operator()(std::size_t size, std::size_t stage) - { - std::set* index_set = nullptr; - auto it = m_available_idx.find(stage); - if (it == m_available_idx.end()) - { - index_set = &(m_available_idx[stage]); - for (std::size_t i = 0; i < m_messages.size(); ++i) index_set->insert(i); - } - else - { - index_set = &(it->second); - } - - auto m_it = std::find_if(index_set->begin(), index_set->end(), - [this, size](std::size_t i) { return (m_messages[i]->size() == size); }); - - if (m_it == index_set->end()) - { - m_messages.push_back( - std::make_shared>(m_comm->make_buffer(size))); - for (auto& s : m_available_idx) s.second.insert(m_messages.size() - 1); - index_set->erase(m_messages.size() - 1); - return m_messages.back(); - } - else - { - auto res = *m_it; - index_set->erase(m_it); - return m_messages[res]; - } - } - }; - struct stage_t { std::vector m_transport_ranges; @@ -109,19 +62,28 @@ class terminal> }; protected: - map_type m_map; - std::unique_ptr m_comm; - std::map m_stages; - buffer_cache m_buffer_cache; - std::vector m_stage_lu; - bool m_connected = false; + map_type m_map; + 
std::unique_ptr m_comm; + std::map m_stages; + std::shared_ptr> m_top_buffer_cache; + buffer_cache m_buffer_cache; + std::vector m_stage_lu; + bool m_connected = false; public: template terminal(Map& m) : m_map{m} , m_comm{std::make_unique(oomph::detail::get_communicator(m.m_context))} - , m_buffer_cache{m_comm.get()} + , m_top_buffer_cache{std::make_shared>()} + { + } + + template + terminal(Map& m, std::shared_ptr> c) + : m_map{m} + , m_comm{std::make_unique(oomph::detail::get_communicator(m.m_context))} + , m_top_buffer_cache{c} { } @@ -171,7 +133,9 @@ class terminal> auto const n_elements_slice = n_elements / view.extents()[last_dim]; s.m_transport_ranges.push_back(transport_range{view, //m_comm->make_buffer(n_elements), - m_buffer_cache(n_elements, stage), rank, tag, false}); + //m_buffer_cache(*m_comm, n_elements, stage), + m_buffer_cache(*m_comm, n_elements, stage, *m_top_buffer_cache, m_comm.get()), rank, + tag, false}); T* ptr = s.m_transport_ranges.back().m_message->data(); std::size_t const first_k = view.first()[last_dim]; diff --git a/include/oomph/tensor/receiver.hpp b/include/oomph/tensor/receiver.hpp index 8dd25218..52dcdd22 100644 --- a/include/oomph/tensor/receiver.hpp +++ b/include/oomph/tensor/receiver.hpp @@ -9,8 +9,9 @@ */ #pragma once -#include #include +#include +#include namespace oomph { @@ -37,6 +38,11 @@ class receiver> : private detail::terminal { } + receiver(map_type& m, buffer_cache const& c) + : base(m, c.m_cache) + { + } + receiver(receiver&&) noexcept = default; receiver& operator=(receiver&&) noexcept = default; @@ -63,5 +69,12 @@ make_receiver(map& m) { return {m}; } + +template +receiver> +make_receiver(map& m, buffer_cache const& c) +{ + return {m, c}; +} } // namespace tensor } // namespace oomph diff --git a/include/oomph/tensor/sender.hpp b/include/oomph/tensor/sender.hpp index 3230a4e9..09892d56 100644 --- a/include/oomph/tensor/sender.hpp +++ b/include/oomph/tensor/sender.hpp @@ -9,8 +9,9 @@ */ #pragma once -#include 
#include +#include +#include namespace oomph { @@ -37,6 +38,11 @@ class sender> : private detail::terminal> { } + sender(map_type& m, buffer_cache const& c) + : base(m, c.m_cache) + { + } + sender(sender&&) noexcept = default; sender& operator=(sender&&) noexcept = default; @@ -64,5 +70,12 @@ make_sender(map& m) return {m}; } +template +sender> +make_sender(map& m, buffer_cache const& c) +{ + return {m, c}; +} + } // namespace tensor } // namespace oomph diff --git a/test/test_tensor.cpp b/test/test_tensor.cpp index 19de9ca0..299c8d2a 100644 --- a/test/test_tensor.cpp +++ b/test/test_tensor.cpp @@ -146,8 +146,10 @@ TEST_F(mpi_test_fixture, ctor) map_t t = ctxt.map_tensor({x + 2 * halo, y + 2 * halo, z + 2 * halo}, data.data(), (data.data() + data.size()) - 1); - tensor::sender> s = tensor::make_sender(t); - tensor::receiver> r = tensor::make_receiver(t); + tensor::buffer_cache s_cache; + tensor::buffer_cache r_cache; + tensor::sender> s = tensor::make_sender(t, s_cache); + tensor::receiver> r = tensor::make_receiver(t, r_cache); if (world_rank == 0) { From e17936ac3c645a9589e4e2499d44754d271864e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20B=C3=B6sch?= <48126478+boeschf@users.noreply.github.com> Date: Thu, 4 Nov 2021 08:50:11 +0100 Subject: [PATCH 6/6] runtime layout --- include/oomph/tensor/layout.hpp | 6 +- include/oomph/tensor/rt_layout.hpp | 100 +++++++++++++++++++++++++++++ test/test_tensor.cpp | 41 ++++++++---- 3 files changed, 133 insertions(+), 14 deletions(-) create mode 100644 include/oomph/tensor/rt_layout.hpp diff --git a/include/oomph/tensor/layout.hpp b/include/oomph/tensor/layout.hpp index 1a9eacd5..f97ee68e 100644 --- a/include/oomph/tensor/layout.hpp +++ b/include/oomph/tensor/layout.hpp @@ -33,20 +33,20 @@ struct layout using reverse_lookup = boost::mp11::mp_transform; // Get the position of the element with value `i` in the layout - static constexpr std::size_t find(int i) + static constexpr std::size_t find(std::size_t i) { return 
find_impl(i, boost::mp11::make_index_sequence{}); } // Get the value of the element at position `i` in the layout - static constexpr int at(std::size_t i) + static constexpr std::size_t at(std::size_t i) { std::size_t const ri[] = {I...}; return ri[i]; } template - static constexpr std::size_t find_impl(int i, boost::mp11::index_sequence) + static constexpr std::size_t find_impl(std::size_t i, boost::mp11::index_sequence) { std::size_t const ri[] = {boost::mp11::mp_at_c::value...}; return ri[i]; diff --git a/include/oomph/tensor/rt_layout.hpp b/include/oomph/tensor/rt_layout.hpp new file mode 100644 index 00000000..61d5a617 --- /dev/null +++ b/include/oomph/tensor/rt_layout.hpp @@ -0,0 +1,100 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +namespace tensor +{ +class rt_layout +{ + private: + std::vector const m; + + public: + rt_layout(std::initializer_list l) + : m{l} + { + } + + rt_layout(std::size_t* first, std::size_t* last) + : m(first, last) + { + } + + rt_layout(std::size_t* first, std::size_t count) + : rt_layout(first, first + count) + { + } + + rt_layout(rt_layout const&) = default; + + rt_layout(rt_layout&&) noexcept = default; + + public: + friend bool operator==(rt_layout const& a, rt_layout const& b) + { + std::size_t const s_a = a.size(); + std::size_t const s_b = b.size(); + if (s_a != s_b) return false; + for (std::size_t i = 0; i < s_a; ++i) + if (a.at(i) != b.at(i)) return false; + return true; + } + + template + bool is_equal() const noexcept + { + std::size_t const s = m.size(); + if (Layout::max_arg + 1 != s) return false; + for (std::size_t i = 0; i < s; ++i) + if (Layout::at(i) != m[i]) return false; + return true; + } + + public: + std::size_t const* data() const noexcept { return m.data(); } + std::size_t size() const noexcept { return 
m.size(); } + + public: + // Get the position of the element with value `i` in the layout + std::size_t find(std::size_t i) const noexcept + { + std::size_t const s = m.size(); + for (std::size_t j = 0; j < s; ++j) + if (m[j] == i) return j; + return s; + } + + // Get the value of the element at position `i` in the layout + std::size_t at(std::size_t i) const noexcept { return m[i]; } +}; + +namespace detail +{ +template +rt_layout make_rt_layout(std::index_sequence) noexcept +{ + return {{Layout::at(Is)...}}; +} + +} // namespace detail + +template +rt_layout +make_rt_layout() noexcept +{ + return detail::make_rt_layout(std::make_index_sequence{}); +} + +} // namespace tensor +} // namespace oomph diff --git a/test/test_tensor.cpp b/test/test_tensor.cpp index 299c8d2a..e61c6c57 100644 --- a/test/test_tensor.cpp +++ b/test/test_tensor.cpp @@ -9,6 +9,7 @@ */ #include +#include #include #include @@ -26,17 +27,35 @@ test_layout() { using namespace oomph; using namespace oomph::tensor; - static_assert(layout::at(0) == I, ""); - static_assert(layout::at(1) == J, ""); - static_assert(layout::at(2) == K, ""); + using layout_t = layout; + static_assert(layout_t::at(0) == I, ""); + static_assert(layout_t::at(1) == J, ""); + static_assert(layout_t::at(2) == K, ""); static constexpr std::size_t pos_0 = (I == 0) ? 0 : ((J == 0) ? 1 : 2); static constexpr std::size_t pos_1 = (I == 1) ? 0 : ((J == 1) ? 1 : 2); static constexpr std::size_t pos_2 = (I == 2) ? 0 : ((J == 2) ? 
1 : 2); - static_assert(layout::find(0) == pos_0, ""); - static_assert(layout::find(1) == pos_1, ""); - static_assert(layout::find(2) == pos_2, ""); + static_assert(layout_t::find(0) == pos_0, ""); + static_assert(layout_t::find(1) == pos_1, ""); + static_assert(layout_t::find(2) == pos_2, ""); + + auto l = make_rt_layout(); + + EXPECT_EQ(l.at(0), layout_t::at(0)); + EXPECT_EQ(l.at(1), layout_t::at(1)); + EXPECT_EQ(l.at(2), layout_t::at(2)); + + EXPECT_EQ(l.find(0), layout_t::find(0)); + EXPECT_EQ(l.find(1), layout_t::find(1)); + EXPECT_EQ(l.find(2), layout_t::find(2)); + + EXPECT_TRUE( l.template is_equal() ); + EXPECT_FALSE( (l.template is_equal>()) ); + EXPECT_FALSE( (l.template is_equal>()) ); + EXPECT_FALSE( (l.template is_equal>()) ); + EXPECT_FALSE( (l.template is_equal>()) ); + EXPECT_FALSE( (l.template is_equal>()) ); } TEST(layout, ctor) @@ -146,8 +165,8 @@ TEST_F(mpi_test_fixture, ctor) map_t t = ctxt.map_tensor({x + 2 * halo, y + 2 * halo, z + 2 * halo}, data.data(), (data.data() + data.size()) - 1); - tensor::buffer_cache s_cache; - tensor::buffer_cache r_cache; + tensor::buffer_cache s_cache; + tensor::buffer_cache r_cache; tensor::sender> s = tensor::make_sender(t, s_cache); tensor::receiver> r = tensor::make_receiver(t, r_cache); @@ -155,11 +174,11 @@ TEST_F(mpi_test_fixture, ctor) { s // y-z plane, +x direction - .add_dst({{halo + x - halo, halo, halo}, {halo, y, z}}, 1, 0,1) + .add_dst({{halo + x - halo, halo, halo}, {halo, y, z}}, 1, 0, 1) // x-y plane, -z direction - .add_dst({{halo, halo, halo}, {x, y, halo}}, 2, 0,0) + .add_dst({{halo, halo, halo}, {x, y, halo}}, 2, 0, 0) // whole x-y plane, -z direction - .add_dst({{0, 0, halo}, {x + 2 * halo, y + 2 * halo, halo}}, 3, 0,2) + .add_dst({{0, 0, halo}, {x + 2 * halo, y + 2 * halo, halo}}, 3, 0, 2) .connect(); s.pack(1).wait();