Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 128 additions & 120 deletions benchmarks/comm_2_test_halo_exchange_3D_generic_full.cpp

Large diffs are not rendered by default.

122 changes: 65 additions & 57 deletions benchmarks/simple_comm_test_halo_exchange_3D_generic_full.cpp

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions include/ghex/communication_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,15 @@ namespace gridtools {

/** @brief communication object constructor
* @param p pattern*/
communication_object(const Pattern& p) :
communication_object(const Pattern& p, communicator_t comm) :
m_pattern{p},
m_send_halos{m_pattern.send_halos()},
m_receive_halos{m_pattern.recv_halos()},
m_n_send_halos{m_send_halos.size()},
m_n_receive_halos(m_receive_halos.size()),
m_send_buffers(m_n_send_halos),
m_receive_buffers(m_n_receive_halos),
m_communicator{m_pattern.communicator()} {
m_communicator{comm} {

for (const auto& halo : m_send_halos) {
const auto& domain_id = halo.first;
Expand Down
75 changes: 40 additions & 35 deletions include/ghex/communication_object_2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,25 @@ namespace gridtools {
namespace ghex {

// forward declaration
template<typename Transport, typename GridType, typename DomainIdType>
template<typename Communicator, typename GridType, typename DomainIdType>
class communication_object;

/** @brief handle type for waiting on asynchronous communication processes.
* The wait function is stored in a member.
* @tparam Transport message transport type
* @tparam GridType grid tag type
* @tparam DomainIdType domain id type*/
template<typename Transport, typename GridType, typename DomainIdType>
template<typename Communicator, typename GridType, typename DomainIdType>
class communication_handle
{
private: // friend class

friend class communication_object<Transport,GridType,DomainIdType>;
friend class communication_object<Communicator,GridType,DomainIdType>;

private: // member types

using co_t = communication_object<Transport,GridType,DomainIdType>;
using communicator_type = tl::communicator<Transport>;
using co_t = communication_object<Communicator,GridType,DomainIdType>;
using communicator_type = Communicator;

private: // members

Expand Down Expand Up @@ -89,30 +89,30 @@ namespace gridtools {
* @tparam Transport message transport type
* @tparam GridType grid tag type
* @tparam DomainIdType domain id type*/
template<typename Transport, typename GridType, typename DomainIdType>
template<typename Communicator, typename GridType, typename DomainIdType>
class communication_object
{
private: // friend class

friend class communication_handle<Transport,GridType,DomainIdType>;
friend class communication_handle<Communicator,GridType,DomainIdType>;

public: // member types

/** @brief handle type returned by exhange operation */
using handle_type = communication_handle<Transport,GridType,DomainIdType>;
using transport_type = Transport;
using handle_type = communication_handle<Communicator,GridType,DomainIdType>;
//using transport_type = Transport;
using grid_type = GridType;
using domain_id_type = DomainIdType;
using pattern_type = pattern<Transport,GridType,DomainIdType>;
using pattern_container_type = pattern_container<Transport,GridType,DomainIdType>;
using this_type = communication_object<Transport,GridType,DomainIdType>;
using pattern_type = pattern<Communicator,GridType,DomainIdType>;
using pattern_container_type = pattern_container<Communicator,GridType,DomainIdType>;
using this_type = communication_object<Communicator,GridType,DomainIdType>;

template<typename D, typename F>
using buffer_info_type = buffer_info<pattern_type,D,F>;

private: // member types

using communicator_type = typename handle_type::communicator_type;
using communicator_type = Communicator; //typename handle_type::communicator_type;
using address_type = typename communicator_type::address_type;
using index_container_type = typename pattern_type::index_container_type;
using pack_function_type = std::function<void(void*,const index_container_type&, void*)>;
Expand Down Expand Up @@ -191,12 +191,16 @@ namespace gridtools {
private: // members

bool m_valid;
communicator_type m_comm;
memory_type m_mem;
std::vector<typename communicator_type::template future<void>> m_send_futures;

public: // ctors

communication_object() : m_valid(false) {}
communication_object(communicator_type comm)
: m_valid(false)
, m_comm(comm)
{}
communication_object(const communication_object&) = delete;
communication_object(communication_object&&) = default;

Expand All @@ -221,7 +225,7 @@ namespace gridtools {
[[nodiscard]] handle_type exchange(buffer_info_type<Archs,Fields>... buffer_infos)
{
// check that arguments are compatible
using test_t = pattern_container<transport_type,grid_type,domain_id_type>;
using test_t = pattern_container<communicator_type,grid_type,domain_id_type>;
static_assert(detail::test_eq_t<test_t, typename buffer_info_type<Archs,Fields>::pattern_container_type...>::value,
"patterns are not compatible with this communication object");
if (m_valid)
Expand Down Expand Up @@ -257,9 +261,9 @@ namespace gridtools {
allocate<arch_type,value_type>(mem, bi->get_pattern(), field_ptr, my_dom_id, bi->device_id(), tag_offsets[i]);
++i;
});
handle_type h(std::get<0>(buffer_info_tuple)->get_pattern().communicator(), [this](){this->wait();});
post_recvs(h.m_comm);
pack(h.m_comm);
handle_type h(m_comm, [this](){this->wait();});
post_recvs();
pack();
return h;
}

Expand All @@ -275,8 +279,8 @@ namespace gridtools {
[[nodiscard]] handle_type exchange(buffer_info_type<Arch,Field>* first, std::size_t length)
{
auto h = exchange_impl(first, length);
post_recvs(h.m_comm);
pack(h.m_comm);
post_recvs();
pack();
return h;
}

Expand All @@ -293,10 +297,10 @@ namespace gridtools {
using field_type = std::remove_reference_t<decltype(first->get_field())>;
using value_type = typename field_type::value_type;
auto h = exchange_impl(first, length);
post_recvs(h.m_comm);
post_recvs();
h.m_wait_fct = [this](){this->wait_u<value_type,field_type>();};
memory_t& mem = std::get<memory_t>(m_mem);
packer<gpu>::template pack_u<value_type,field_type>(mem, m_send_futures, h.m_comm);
packer<gpu>::template pack_u<value_type,field_type>(mem, m_send_futures, m_comm);
return h;
}
#endif
Expand All @@ -316,7 +320,7 @@ namespace gridtools {
[[nodiscard]] handle_type exchange_impl(buffer_info_type<Arch,Field>* first, std::size_t length)
{
// check that arguments are compatible
using test_t = pattern_container<transport_type,grid_type,domain_id_type>;
using test_t = pattern_container<communicator_type,grid_type,domain_id_type>;
static_assert(std::is_same<test_t, typename buffer_info_type<Arch,Field>::pattern_container_type>::value,
"patterns are not compatible with this communication object");
if (m_valid)
Expand Down Expand Up @@ -344,12 +348,12 @@ namespace gridtools {
const auto my_dom_id =(first+k)->get_field().domain_id();
allocate<Arch,value_type>(mem, (first+k)->get_pattern(), field_ptr, my_dom_id, (first+k)->device_id(), tag_offset);
}
return handle_type(first->get_pattern().communicator(), [this](){this->wait();});
return handle_type(m_comm, [this](){this->wait();});
}

void post_recvs(communicator_type& comm)
void post_recvs()
{
detail::for_each(m_mem, [this,&comm](auto& m)
detail::for_each(m_mem, [this](auto& m)
{
for (auto& p0 : m.recv_memory)
{
Expand All @@ -361,19 +365,19 @@ namespace gridtools {
m.m_recv_futures.emplace_back(
typename std::remove_reference_t<decltype(m)>::future_type{
&p1.second,
comm.recv(p1.second.buffer, p1.second.address, p1.second.tag).m_handle});
m_comm.recv(p1.second.buffer, p1.second.address, p1.second.tag).m_handle});
}
}
}
});
}

void pack(communicator_type& comm)
void pack()
{
detail::for_each(m_mem, [this,&comm](auto& m)
detail::for_each(m_mem, [this](auto& m)
{
using arch_type = typename std::remove_reference_t<decltype(m)>::arch_type;
packer<arch_type>::pack(m,m_send_futures,comm);
packer<arch_type>::pack(m,m_send_futures,m_comm);
});
}

Expand Down Expand Up @@ -529,12 +533,13 @@ namespace gridtools {
* @tparam PatternContainer pattern type
* @return communication object */
template<typename PatternContainer>
auto make_communication_object()
auto make_communication_object(typename PatternContainer::value_type::communicator_type comm)
{
using transport_type = typename PatternContainer::value_type::communicator_type::transport_type;
using grid_type = typename PatternContainer::value_type::grid_type;
using domain_id_type = typename PatternContainer::value_type::domain_id_type;
return communication_object<transport_type,grid_type,domain_id_type>();
//using transport_type = typename PatternContainer::value_type::communicator_type::transport_type;
using communicator_type = typename PatternContainer::value_type::communicator_type;
using grid_type = typename PatternContainer::value_type::grid_type;
using domain_id_type = typename PatternContainer::value_type::domain_id_type;
return communication_object<communicator_type,grid_type,domain_id_type>(comm);
}

} // namespace ghex
Expand Down
2 changes: 1 addition & 1 deletion include/ghex/glue/gridtools/make_gt_pattern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace gridtools {
using halo_gen_type = typename Grid::domain_descriptor_type::halo_generator_type;
auto halo_gen = halo_gen_type(first,last, std::forward<Halos>(halos), grid.m_periodic);

return make_pattern<structured::grid>(grid.m_setup_comm, grid.m_comm, halo_gen, grid.m_domains);
return make_pattern<structured::grid>(grid.m_context, halo_gen, grid.m_domains);
}

} // namespace ghex
Expand Down
38 changes: 20 additions & 18 deletions include/ghex/glue/gridtools/processor_grid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,29 @@ namespace gridtools {

namespace ghex {

template<typename Transport>
template<typename Context>
struct gt_grid
{
using domain_descriptor_type = structured::domain_descriptor<int,3>;
using domain_id_type = typename domain_descriptor_type::domain_id_type;
MPI_Comm m_setup_comm;
tl::communicator<Transport> m_comm;
Context& m_context;
//MPI_Comm m_setup_comm;
//tl::communicator<Transport> m_comm;
std::vector<domain_descriptor_type> m_domains;
std::array<int, 3> m_global_extents;
std::array<bool, 3> m_periodic;
};

template<typename Layout = ::gridtools::layout_map<0,1,2>, typename Array0, typename Array1>
gt_grid<tl::mpi_tag>
make_gt_processor_grid(const Array0& local_extents, const Array1& periodicity, MPI_Comm cart_comm)
template<typename Layout = ::gridtools::layout_map<0,1,2>, typename Context, typename Array0, typename Array1>
gt_grid<Context>
make_gt_processor_grid(Context& context, const Array0& local_extents, const Array1& periodicity)
{
int dims[3];
int periods[3];
int coords[3];
MPI_Cart_get(cart_comm, 3, dims, periods, coords);
MPI_Cart_get(context.world(), 3, dims, periods, coords);
int rank;
MPI_Cart_rank(cart_comm, coords, &rank);
MPI_Cart_rank(context.world(), coords, &rank);

std::array<bool, 3> periodic;
std::copy(periodicity.begin(), periodicity.end(), periodic.begin());
Expand All @@ -56,18 +57,18 @@ namespace gridtools {
{
int coords_i[3] = {i,0,0};
int rank_i;
MPI_Cart_rank(cart_comm, coords_i, &rank_i);
MPI_Cart_rank(context.world(), coords_i, &rank_i);
if (coords[0]==i && coords[1]==0 && coords[2]==0)
{
// broadcast
int lext = local_extents[0];
extents_x[i] = lext;
MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, cart_comm);
MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, context.world());
}
else
{
// recv
MPI_Bcast(&extents_x[i], sizeof(int), MPI_BYTE, rank_i, cart_comm);
MPI_Bcast(&extents_x[i], sizeof(int), MPI_BYTE, rank_i, context.world());
}
}
std::partial_sum(extents_x.begin(), extents_x.end(), extents_x.begin());
Expand All @@ -77,18 +78,18 @@ namespace gridtools {
{
int coords_i[3] = {0,i,0};
int rank_i;
MPI_Cart_rank(cart_comm, coords_i, &rank_i);
MPI_Cart_rank(context.world(), coords_i, &rank_i);
if (coords[1]==i && coords[0]==0 && coords[2]==0)
{
// broadcast
int lext = local_extents[1];
extents_y[i] = lext;
MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, cart_comm);
MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, context.world());
}
else
{
// recv
MPI_Bcast(&extents_y[i], sizeof(int), MPI_BYTE, rank_i, cart_comm);
MPI_Bcast(&extents_y[i], sizeof(int), MPI_BYTE, rank_i, context.world());
}
}
std::partial_sum(extents_y.begin(), extents_y.end(), extents_y.begin());
Expand All @@ -98,18 +99,18 @@ namespace gridtools {
{
int coords_i[3] = {0,0,i};
int rank_i;
MPI_Cart_rank(cart_comm, coords_i, &rank_i);
MPI_Cart_rank(context.world(), coords_i, &rank_i);
if (coords[2]==i && coords[0]==0 && coords[1]==0)
{
// broadcast
int lext = local_extents[2];
extents_z[i] = lext;
MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, cart_comm);
MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, context.world());
}
else
{
// recv
MPI_Bcast(&extents_z[i], sizeof(int), MPI_BYTE, rank_i, cart_comm);
MPI_Bcast(&extents_z[i], sizeof(int), MPI_BYTE, rank_i, context.world());
}
}
std::partial_sum(extents_z.begin(), extents_z.end(), extents_z.begin());
Expand Down Expand Up @@ -143,7 +144,8 @@ namespace gridtools {

structured::domain_descriptor<int,3> local_domain{rank, global_first, global_last};

return {cart_comm, tl::communicator<tl::mpi_tag>{cart_comm}, {local_domain}, global_extents, periodic};
//return {cart_comm, tl::communicator<tl::mpi_tag>{cart_comm}, {local_domain}, global_extents, periodic};
return {context, {local_domain}, global_extents, periodic};

}

Expand Down
Loading