diff --git a/benchmarks/comm_2_test_halo_exchange_3D_generic_full.cpp b/benchmarks/comm_2_test_halo_exchange_3D_generic_full.cpp index 53e5239..ff8fb45 100644 --- a/benchmarks/comm_2_test_halo_exchange_3D_generic_full.cpp +++ b/benchmarks/comm_2_test_halo_exchange_3D_generic_full.cpp @@ -24,7 +24,8 @@ #include #include #include -#include +#include +#include #include #include @@ -33,6 +34,10 @@ #include #endif +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; + namespace halo_exchange_3D_generic_full { using timer_type = gridtools::ghex::timer; @@ -85,8 +90,8 @@ namespace halo_exchange_3D_generic_full { } } - template - bool run(ST &file, + template + bool run(ST &file, context_type& context, Comm comm, int DIM1, int DIM2, int DIM3, @@ -110,7 +115,7 @@ namespace halo_exchange_3D_generic_full { int H3p3, triple_t *_a, triple_t *_b, - triple_t *_c, bool use_gpu, gridtools::ghex::tl::mpi::communicator_base& world) + triple_t *_c, bool use_gpu) { // compute total domain const std::array g_first{ 0, 0, 0}; @@ -128,7 +133,7 @@ namespace halo_exchange_3D_generic_full { // define local domain domain_descriptor_type local_domain{ - world.rank(),//comm.rank(), + context.world().rank(),//comm.rank(), std::array{coords[0]*DIM1,coords[1]*DIM2,coords[2]*DIM3}, std::array{(coords[0]+1)*DIM1-1,(coords[1]+1)*DIM2-1,(coords[2]+1)*DIM3-1}}; std::vector local_domains{local_domain}; @@ -152,13 +157,13 @@ namespace halo_exchange_3D_generic_full { #endif // make patterns - auto pattern_1 = gridtools::ghex::make_pattern(world, halo_gen_1, local_domains); + auto pattern_1 = gridtools::ghex::make_pattern(context, halo_gen_1, local_domains); #ifndef GHEX_1_PATTERN_BENCHMARK - auto pattern_2 = gridtools::ghex::make_pattern(world, halo_gen_2, local_domains); - auto pattern_3 = gridtools::ghex::make_pattern(world, halo_gen_3, local_domains); + auto pattern_2 = 
gridtools::ghex::make_pattern(context, halo_gen_2, local_domains); + auto pattern_3 = gridtools::ghex::make_pattern(context, halo_gen_3, local_domains); #endif // communication object - auto co = gridtools::ghex::make_communication_object(); + auto co = gridtools::ghex::make_communication_object(comm); file << "Proc: (" << coords[0] << ", " << coords[1] << ", " << coords[2] << ")\n"; @@ -264,7 +269,7 @@ namespace halo_exchange_3D_generic_full { std::array{H1m3,H2m3,H3m3}, std::array{(DIM1 + H1m3 + H1p3), (DIM2 + H2m3 + H2p3), (DIM3 + H3m3 + H3p3)}); - world.barrier(); + MPI_Barrier(context.world()); // do all the stuff here file << " LOCAL MEAN STD MIN MAX" << std::endl; @@ -279,7 +284,7 @@ namespace halo_exchange_3D_generic_full { { timer_type t_0; timer_type t_1; - world.barrier(); + MPI_Barrier(context.world()); t_0.tic(); auto h = co.exchange( #ifndef GHEX_1_PATTERN_BENCHMARK @@ -295,14 +300,14 @@ namespace halo_exchange_3D_generic_full { t_1.tic(); h.wait(); t_1.toc(); - world.barrier(); + MPI_Barrier(context.world()); timer_type t; t(t_0.sum()+t_1.sum()); - auto t_0_all = gridtools::ghex::reduce(t_0,world); - auto t_1_all = gridtools::ghex::reduce(t_1,world); - auto t_all = gridtools::ghex::reduce(t,world); + auto t_0_all = gridtools::ghex::reduce(t_0,context.world()); + auto t_1_all = gridtools::ghex::reduce(t_1,context.world()); + auto t_all = gridtools::ghex::reduce(t,context.world()); if (k >= k_start) { t_0_local(t_0); @@ -401,7 +406,7 @@ namespace halo_exchange_3D_generic_full { delete[] gpu_c; #endif - world.barrier(); + MPI_Barrier(context.world()); } else @@ -409,7 +414,7 @@ namespace halo_exchange_3D_generic_full { auto field1 = a; auto field2 = b; auto field3 = c; - world.barrier(); + MPI_Barrier(context.world()); file << " LOCAL MEAN STD MIN MAX" << std::endl; timer_type t_0_local; @@ -423,7 +428,7 @@ namespace halo_exchange_3D_generic_full { { timer_type t_0; timer_type t_1; - world.barrier(); + MPI_Barrier(context.world()); t_0.tic(); auto h = 
co.exchange( #ifndef GHEX_1_PATTERN_BENCHMARK @@ -439,14 +444,14 @@ namespace halo_exchange_3D_generic_full { t_1.tic(); h.wait(); t_1.toc(); - world.barrier(); + MPI_Barrier(context.world()); timer_type t; t(t_0.sum()+t_1.sum()); - auto t_0_all = gridtools::ghex::reduce(t_0,world); - auto t_1_all = gridtools::ghex::reduce(t_1,world); - auto t_all = gridtools::ghex::reduce(t,world); + auto t_0_all = gridtools::ghex::reduce(t_0,context.world()); + auto t_1_all = gridtools::ghex::reduce(t_1,context.world()); + auto t_all = gridtools::ghex::reduce(t,context.world()); if (k >= k_start) { t_0_local(t_0); @@ -511,7 +516,7 @@ namespace halo_exchange_3D_generic_full { << std::endl; //file << std::endl << std::endl; - world.barrier(); + MPI_Barrier(context.world()); } @@ -682,7 +687,7 @@ namespace halo_exchange_3D_generic_full { int H3p3) { gridtools::ghex::tl::mpi::communicator_base world; - //std::cout << world.rank() << " " << world.size() << "\n"; + //std::cout << context.world().rank() << " " << context.world().size() << "\n"; std::stringstream ss; ss << world.rank(); @@ -700,6 +705,9 @@ namespace halo_exchange_3D_generic_full { MPI_Cart_create(world, 3, dims, period, false, &CartComm); MPI_Cart_get(CartComm, 3, dims, period, coords); + + context_type context(1, CartComm); + auto comm = context.get_communicator(context.get_token()); /* Each process will hold a tile of size (DIM1+2*H)x(DIM2+2*H)x(DIM3+2*H). 
The DIM1xDIM2xDIM3 area inside @@ -744,7 +752,7 @@ namespace halo_exchange_3D_generic_full { "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -768,12 +776,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -797,12 +805,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -826,12 +834,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -855,12 +863,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -884,12 +892,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -913,12 +921,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, 
DIM3, @@ -942,12 +950,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -971,7 +979,7 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "---------------------------------------------------\n"; file << "Permutation 0,2,1\n"; @@ -979,7 +987,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1003,12 +1011,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1032,12 +1040,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1061,12 +1069,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1090,12 +1098,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1119,12 +1127,12 @@ 
namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1148,12 +1156,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1177,12 +1185,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1206,7 +1214,7 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "---------------------------------------------------\n"; file << "Permutation 1,0,2\n"; @@ -1214,7 +1222,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1238,12 +1246,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1267,12 +1275,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1296,12 +1304,12 @@ namespace 
halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1325,12 +1333,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1354,12 +1362,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1383,12 +1391,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1412,12 +1420,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1441,7 +1449,7 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "---------------------------------------------------\n"; file << "Permutation 1,2,0\n"; @@ -1449,7 +1457,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1473,12 +1481,12 @@ namespace halo_exchange_3D_generic_full { 
H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1502,12 +1510,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1531,12 +1539,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1560,12 +1568,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1589,12 +1597,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1618,12 +1626,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1647,12 +1655,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H31, " "_a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed 
&& run(file, context, comm, DIM1, DIM2, DIM3, @@ -1676,7 +1684,7 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "---------------------------------------------------\n"; file << "Permutation 2,0,1\n"; @@ -1684,7 +1692,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1708,12 +1716,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1737,12 +1745,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1766,12 +1774,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1795,12 +1803,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1824,12 +1832,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, 
DIM1, DIM2, DIM3, @@ -1853,12 +1861,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1882,12 +1890,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1911,7 +1919,7 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "---------------------------------------------------\n"; file << "Permutation 2,1,0\n"; @@ -1919,7 +1927,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1943,12 +1951,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -1972,12 +1980,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -2001,12 +2009,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -2030,12 
+2038,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -2059,12 +2067,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -2088,12 +2096,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -2117,12 +2125,12 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed && run(file, + passed = passed && run(file, context, comm, DIM1, DIM2, DIM3, @@ -2146,7 +2154,7 @@ namespace halo_exchange_3D_generic_full { H3p3, _a, _b, - _c, use_gpu, world); + _c, use_gpu); file << "---------------------------------------------------\n"; delete[] _a; diff --git a/benchmarks/simple_comm_test_halo_exchange_3D_generic_full.cpp b/benchmarks/simple_comm_test_halo_exchange_3D_generic_full.cpp index c995912..f8c7940 100644 --- a/benchmarks/simple_comm_test_halo_exchange_3D_generic_full.cpp +++ b/benchmarks/simple_comm_test_halo_exchange_3D_generic_full.cpp @@ -31,9 +31,14 @@ #include #include #include -#include +#include +#include #include "../utils/triplet.hpp" +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; + /* CPU data descriptor */ 
template class my_data_desc { @@ -128,8 +133,8 @@ namespace halo_exchange_3D_generic_full { typedef double T2; typedef long long int T3; - template - bool run(ST &file, + template + bool run(ST &file, context_type& context, Comm comm, int DIM1, int DIM2, int DIM3, @@ -183,18 +188,18 @@ namespace halo_exchange_3D_generic_full { auto halo_gen_2 = halo_generator_t{g_first, g_last, halos_2, periodic}; auto halo_gen_3 = halo_generator_t{g_first, g_last, halos_3, periodic}; - auto patterns_1 = gridtools::ghex::make_pattern(CartComm, halo_gen_1, local_domains); - auto patterns_2 = gridtools::ghex::make_pattern(CartComm, halo_gen_2, local_domains); - auto patterns_3 = gridtools::ghex::make_pattern(CartComm, halo_gen_3, local_domains); + auto patterns_1 = gridtools::ghex::make_pattern(context, halo_gen_1, local_domains); + auto patterns_2 = gridtools::ghex::make_pattern(context, halo_gen_2, local_domains); + auto patterns_3 = gridtools::ghex::make_pattern(context, halo_gen_3, local_domains); using communication_object_t = gridtools::ghex::communication_object; // same type for all patterns std::vector cos_1; - for (const auto& p : patterns_1) cos_1.push_back(communication_object_t{p}); + for (const auto& p : patterns_1) cos_1.push_back(communication_object_t{p,comm}); std::vector cos_2; - for (const auto& p : patterns_2) cos_2.push_back(communication_object_t{p}); + for (const auto& p : patterns_2) cos_2.push_back(communication_object_t{p,comm}); std::vector cos_3; - for (const auto& p : patterns_3) cos_3.push_back(communication_object_t{p}); + for (const auto& p : patterns_3) cos_3.push_back(communication_object_t{p,comm}); array, layoutmap> a( _a, (DIM1 + H1m1 + H1p1), (DIM2 + H2m1 + H2p1), (DIM3 + H3m1 + H3p1)); @@ -509,6 +514,9 @@ namespace halo_exchange_3D_generic_full { MPI_Cart_create(MPI_COMM_WORLD, 3, dims, period, false, &CartComm); MPI_Cart_get(CartComm, 3, dims, period, coords); + + context_type context(1, CartComm); + auto comm = 
context.get_communicator(context.get_token()); /* Each process will hold a tile of size (DIM1+2*H)x(DIM2+2*H)x(DIM3+2*H). The DIM1xDIM2xDIM3 area inside @@ -553,7 +561,7 @@ namespace halo_exchange_3D_generic_full { bool passed = true; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -582,7 +590,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -612,7 +620,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -642,7 +650,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -672,7 +680,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -702,7 +710,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -732,7 +740,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -762,7 +770,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed and run(file, + 
passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -795,7 +803,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -825,7 +833,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -855,7 +863,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -885,7 +893,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -915,7 +923,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -945,7 +953,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -975,7 +983,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1005,7 +1013,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and 
run(file, context, comm, DIM1, DIM2, DIM3, @@ -1038,7 +1046,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1068,7 +1076,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1098,7 +1106,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1128,7 +1136,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1158,7 +1166,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1188,7 +1196,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1218,7 +1226,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1248,7 +1256,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and 
run(file, context, comm, DIM1, DIM2, DIM3, @@ -1281,7 +1289,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1310,7 +1318,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1339,7 +1347,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1368,7 +1376,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1397,7 +1405,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1426,7 +1434,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1455,7 +1463,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1484,7 +1492,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H31, " "_a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and 
run(file, context, comm, DIM1, DIM2, DIM3, @@ -1516,7 +1524,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1545,7 +1553,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1574,7 +1582,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1603,7 +1611,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1632,7 +1640,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1661,7 +1669,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1691,7 +1699,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1721,7 +1729,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and 
run(file, context, comm, DIM1, DIM2, DIM3, @@ -1754,7 +1762,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1784,7 +1792,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1814,7 +1822,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1844,7 +1852,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1874,7 +1882,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, " "_c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1903,7 +1911,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1932,7 +1940,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, _a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and run(file, context, comm, DIM1, DIM2, DIM3, @@ -1961,7 +1969,7 @@ namespace halo_exchange_3D_generic_full { file << "run(file, DIM1, DIM2, DIM3, H1m, H1p, H2m, H2p, H3m, H3p, " "_a, " "_b, _c)\n"; - passed = passed and run(file, + passed = passed and 
run(file, context, comm, DIM1, DIM2, DIM3, diff --git a/include/ghex/communication_object.hpp b/include/ghex/communication_object.hpp index 345456f..ce8cd0d 100644 --- a/include/ghex/communication_object.hpp +++ b/include/ghex/communication_object.hpp @@ -203,7 +203,7 @@ namespace gridtools { /** @brief communication object constructor * @param p pattern*/ - communication_object(const Pattern& p) : + communication_object(const Pattern& p, communicator_t comm) : m_pattern{p}, m_send_halos{m_pattern.send_halos()}, m_receive_halos{m_pattern.recv_halos()}, @@ -211,7 +211,7 @@ namespace gridtools { m_n_receive_halos(m_receive_halos.size()), m_send_buffers(m_n_send_halos), m_receive_buffers(m_n_receive_halos), - m_communicator{m_pattern.communicator()} { + m_communicator{comm} { for (const auto& halo : m_send_halos) { const auto& domain_id = halo.first; diff --git a/include/ghex/communication_object_2.hpp b/include/ghex/communication_object_2.hpp index a891cad..b6ef7eb 100644 --- a/include/ghex/communication_object_2.hpp +++ b/include/ghex/communication_object_2.hpp @@ -28,7 +28,7 @@ namespace gridtools { namespace ghex { // forward declaration - template + template class communication_object; /** @brief handle type for waiting on asynchronous communication processes. 
@@ -36,17 +36,17 @@ namespace gridtools { * @tparam Transport message transport type * @tparam GridType grid tag type * @tparam DomainIdType domain id type*/ - template + template class communication_handle { private: // friend class - friend class communication_object; + friend class communication_object; private: // member types - using co_t = communication_object; - using communicator_type = tl::communicator; + using co_t = communication_object; + using communicator_type = Communicator; private: // members @@ -89,30 +89,30 @@ namespace gridtools { * @tparam Transport message transport type * @tparam GridType grid tag type * @tparam DomainIdType domain id type*/ - template + template class communication_object { private: // friend class - friend class communication_handle; + friend class communication_handle; public: // member types /** @brief handle type returned by exhange operation */ - using handle_type = communication_handle; - using transport_type = Transport; + using handle_type = communication_handle; + //using transport_type = Transport; using grid_type = GridType; using domain_id_type = DomainIdType; - using pattern_type = pattern; - using pattern_container_type = pattern_container; - using this_type = communication_object; + using pattern_type = pattern; + using pattern_container_type = pattern_container; + using this_type = communication_object; template using buffer_info_type = buffer_info; private: // member types - using communicator_type = typename handle_type::communicator_type; + using communicator_type = Communicator; //typename handle_type::communicator_type; using address_type = typename communicator_type::address_type; using index_container_type = typename pattern_type::index_container_type; using pack_function_type = std::function; @@ -191,12 +191,16 @@ namespace gridtools { private: // members bool m_valid; + communicator_type m_comm; memory_type m_mem; std::vector> m_send_futures; public: // ctors - communication_object() : m_valid(false) 
{} + communication_object(communicator_type comm) + : m_valid(false) + , m_comm(comm) + {} communication_object(const communication_object&) = delete; communication_object(communication_object&&) = default; @@ -221,7 +225,7 @@ namespace gridtools { [[nodiscard]] handle_type exchange(buffer_info_type... buffer_infos) { // check that arguments are compatible - using test_t = pattern_container; + using test_t = pattern_container; static_assert(detail::test_eq_t::pattern_container_type...>::value, "patterns are not compatible with this communication object"); if (m_valid) @@ -257,9 +261,9 @@ namespace gridtools { allocate(mem, bi->get_pattern(), field_ptr, my_dom_id, bi->device_id(), tag_offsets[i]); ++i; }); - handle_type h(std::get<0>(buffer_info_tuple)->get_pattern().communicator(), [this](){this->wait();}); - post_recvs(h.m_comm); - pack(h.m_comm); + handle_type h(m_comm, [this](){this->wait();}); + post_recvs(); + pack(); return h; } @@ -275,8 +279,8 @@ namespace gridtools { [[nodiscard]] handle_type exchange(buffer_info_type* first, std::size_t length) { auto h = exchange_impl(first, length); - post_recvs(h.m_comm); - pack(h.m_comm); + post_recvs(); + pack(); return h; } @@ -293,10 +297,10 @@ namespace gridtools { using field_type = std::remove_reference_tget_field())>; using value_type = typename field_type::value_type; auto h = exchange_impl(first, length); - post_recvs(h.m_comm); + post_recvs(); h.m_wait_fct = [this](){this->wait_u();}; memory_t& mem = std::get(m_mem); - packer::template pack_u(mem, m_send_futures, h.m_comm); + packer::template pack_u(mem, m_send_futures, m_comm); return h; } #endif @@ -316,7 +320,7 @@ namespace gridtools { [[nodiscard]] handle_type exchange_impl(buffer_info_type* first, std::size_t length) { // check that arguments are compatible - using test_t = pattern_container; + using test_t = pattern_container; static_assert(std::is_same::pattern_container_type>::value, "patterns are not compatible with this communication object"); if 
(m_valid) @@ -344,12 +348,12 @@ namespace gridtools { const auto my_dom_id =(first+k)->get_field().domain_id(); allocate(mem, (first+k)->get_pattern(), field_ptr, my_dom_id, (first+k)->device_id(), tag_offset); } - return handle_type(first->get_pattern().communicator(), [this](){this->wait();}); + return handle_type(m_comm, [this](){this->wait();}); } - void post_recvs(communicator_type& comm) + void post_recvs() { - detail::for_each(m_mem, [this,&comm](auto& m) + detail::for_each(m_mem, [this](auto& m) { for (auto& p0 : m.recv_memory) { @@ -361,19 +365,19 @@ namespace gridtools { m.m_recv_futures.emplace_back( typename std::remove_reference_t::future_type{ &p1.second, - comm.recv(p1.second.buffer, p1.second.address, p1.second.tag).m_handle}); + m_comm.recv(p1.second.buffer, p1.second.address, p1.second.tag).m_handle}); } } } }); } - void pack(communicator_type& comm) + void pack() { - detail::for_each(m_mem, [this,&comm](auto& m) + detail::for_each(m_mem, [this](auto& m) { using arch_type = typename std::remove_reference_t::arch_type; - packer::pack(m,m_send_futures,comm); + packer::pack(m,m_send_futures,m_comm); }); } @@ -529,12 +533,13 @@ namespace gridtools { * @tparam PatternContainer pattern type * @return communication object */ template - auto make_communication_object() + auto make_communication_object(typename PatternContainer::value_type::communicator_type comm) { - using transport_type = typename PatternContainer::value_type::communicator_type::transport_type; - using grid_type = typename PatternContainer::value_type::grid_type; - using domain_id_type = typename PatternContainer::value_type::domain_id_type; - return communication_object(); + //using transport_type = typename PatternContainer::value_type::communicator_type::transport_type; + using communicator_type = typename PatternContainer::value_type::communicator_type; + using grid_type = typename PatternContainer::value_type::grid_type; + using domain_id_type = typename 
PatternContainer::value_type::domain_id_type; + return communication_object(comm); } } // namespace ghex diff --git a/include/ghex/glue/gridtools/make_gt_pattern.hpp b/include/ghex/glue/gridtools/make_gt_pattern.hpp index af25b96..2c0a567 100644 --- a/include/ghex/glue/gridtools/make_gt_pattern.hpp +++ b/include/ghex/glue/gridtools/make_gt_pattern.hpp @@ -26,7 +26,7 @@ namespace gridtools { using halo_gen_type = typename Grid::domain_descriptor_type::halo_generator_type; auto halo_gen = halo_gen_type(first,last, std::forward(halos), grid.m_periodic); - return make_pattern(grid.m_setup_comm, grid.m_comm, halo_gen, grid.m_domains); + return make_pattern(grid.m_context, halo_gen, grid.m_domains); } } // namespace ghex diff --git a/include/ghex/glue/gridtools/processor_grid.hpp b/include/ghex/glue/gridtools/processor_grid.hpp index 281d2c4..9c43e53 100644 --- a/include/ghex/glue/gridtools/processor_grid.hpp +++ b/include/ghex/glue/gridtools/processor_grid.hpp @@ -24,28 +24,29 @@ namespace gridtools { namespace ghex { - template + template struct gt_grid { using domain_descriptor_type = structured::domain_descriptor; using domain_id_type = typename domain_descriptor_type::domain_id_type; - MPI_Comm m_setup_comm; - tl::communicator m_comm; + Context& m_context; + //MPI_Comm m_setup_comm; + //tl::communicator m_comm; std::vector m_domains; std::array m_global_extents; std::array m_periodic; }; - template, typename Array0, typename Array1> - gt_grid - make_gt_processor_grid(const Array0& local_extents, const Array1& periodicity, MPI_Comm cart_comm) + template, typename Context, typename Array0, typename Array1> + gt_grid + make_gt_processor_grid(Context& context, const Array0& local_extents, const Array1& periodicity) { int dims[3]; int periods[3]; int coords[3]; - MPI_Cart_get(cart_comm, 3, dims, periods, coords); + MPI_Cart_get(context.world(), 3, dims, periods, coords); int rank; - MPI_Cart_rank(cart_comm, coords, &rank); + MPI_Cart_rank(context.world(), coords, &rank); 
std::array periodic; std::copy(periodicity.begin(), periodicity.end(), periodic.begin()); @@ -56,18 +57,18 @@ namespace gridtools { { int coords_i[3] = {i,0,0}; int rank_i; - MPI_Cart_rank(cart_comm, coords_i, &rank_i); + MPI_Cart_rank(context.world(), coords_i, &rank_i); if (coords[0]==i && coords[1]==0 && coords[2]==0) { // broadcast int lext = local_extents[0]; extents_x[i] = lext; - MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, cart_comm); + MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, context.world()); } else { // recv - MPI_Bcast(&extents_x[i], sizeof(int), MPI_BYTE, rank_i, cart_comm); + MPI_Bcast(&extents_x[i], sizeof(int), MPI_BYTE, rank_i, context.world()); } } std::partial_sum(extents_x.begin(), extents_x.end(), extents_x.begin()); @@ -77,18 +78,18 @@ namespace gridtools { { int coords_i[3] = {0,i,0}; int rank_i; - MPI_Cart_rank(cart_comm, coords_i, &rank_i); + MPI_Cart_rank(context.world(), coords_i, &rank_i); if (coords[1]==i && coords[0]==0 && coords[2]==0) { // broadcast int lext = local_extents[1]; extents_y[i] = lext; - MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, cart_comm); + MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, context.world()); } else { // recv - MPI_Bcast(&extents_y[i], sizeof(int), MPI_BYTE, rank_i, cart_comm); + MPI_Bcast(&extents_y[i], sizeof(int), MPI_BYTE, rank_i, context.world()); } } std::partial_sum(extents_y.begin(), extents_y.end(), extents_y.begin()); @@ -98,18 +99,18 @@ namespace gridtools { { int coords_i[3] = {0,0,i}; int rank_i; - MPI_Cart_rank(cart_comm, coords_i, &rank_i); + MPI_Cart_rank(context.world(), coords_i, &rank_i); if (coords[2]==i && coords[0]==0 && coords[1]==0) { // broadcast int lext = local_extents[2]; extents_z[i] = lext; - MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, cart_comm); + MPI_Bcast(&lext, sizeof(int), MPI_BYTE, rank_i, context.world()); } else { // recv - MPI_Bcast(&extents_z[i], sizeof(int), MPI_BYTE, rank_i, cart_comm); + MPI_Bcast(&extents_z[i], sizeof(int), MPI_BYTE, rank_i, 
context.world()); } } std::partial_sum(extents_z.begin(), extents_z.end(), extents_z.begin()); @@ -143,7 +144,8 @@ namespace gridtools { structured::domain_descriptor local_domain{rank, global_first, global_last}; - return {cart_comm, tl::communicator{cart_comm}, {local_domain}, global_extents, periodic}; + //return {cart_comm, tl::communicator{cart_comm}, {local_domain}, global_extents, periodic}; + return {context, {local_domain}, global_extents, periodic}; } diff --git a/include/ghex/pattern.hpp b/include/ghex/pattern.hpp index 1ec5101..d54d1ac 100644 --- a/include/ghex/pattern.hpp +++ b/include/ghex/pattern.hpp @@ -11,9 +11,10 @@ #ifndef INCLUDED_GHEX_PATTERN_HPP #define INCLUDED_GHEX_PATTERN_HPP -#include "./transport_layer/mpi/setup.hpp" -#include "./transport_layer/mpi/communicator.hpp" +//#include "./transport_layer/mpi/setup.hpp" +//#include "./transport_layer/mpi/communicator.hpp" #include "./buffer_info.hpp" +#include "./transport_layer/context.hpp" namespace gridtools { @@ -26,19 +27,19 @@ namespace gridtools { } // namespace detail // forward declaration - template + template class pattern; /** @brief an iterable holding communication patterns (one pattern per domain) * @tparam Transport transport protocol * @tparam GridType indicates structured/unstructured grids * @tparam DomainIdType type to uniquely identify partail (local) domains*/ - template + template class pattern_container { public: // member tyes /** @brief pattern type this object is holding */ - using value_type = pattern; + using value_type = pattern; private: // private member types using data_type = std::vector; @@ -81,43 +82,25 @@ namespace gridtools { data_type m_patterns; int m_max_tag; }; - - namespace detail { - // implementation detail - template - auto make_pattern(tl::mpi::setup_communicator& setup_comm, tl::communicator& comm, HaloGenerator&& hgen, DomainRange&& d_range) - { - using grid_type = typename GridType::template type::value_type>; - return 
detail::make_pattern_impl::apply(setup_comm, comm, std::forward(hgen), std::forward(d_range)); - } - } // namespace detail - - // helper function if transport protocol is also MPI - template - auto make_pattern(MPI_Comm mpi_comm, HaloGenerator&& hgen, DomainRange&& d_range) - { - tl::communicator mpi_comm_{mpi_comm}; - tl::mpi::setup_communicator setup_comm(mpi_comm); - return detail::make_pattern(setup_comm, mpi_comm_, hgen, d_range); - } - + /** * @brief construct a pattern for each domain and establish neighbor relationships * @tparam GridType indicates structured/unstructured grids * @tparam Transport transport protocol + * @tparam ThreadPrimitives threading primitives (locks etc.) * @tparam HaloGenerator function object which takes a domain as argument * @tparam DomainRange a range type holding domains - * @param mpi_comm MPI communicator (used for establishing network topology) - * @param comm custom communicator used in the actual exchange operations + * @param context transport layer context * @param hgen receive halo generator function object (emits iteration spaces (global coordinates) or index lists (global indices) * @param d_range range of local domains * @return iterable of patterns (one per domain) */ - template - auto make_pattern(MPI_Comm mpi_comm, tl::communicator& comm, HaloGenerator&& hgen, DomainRange&& d_range) + template + auto make_pattern(tl::context& context, HaloGenerator&& hgen, DomainRange&& d_range) { - tl::mpi::setup_communicator setup_comm(mpi_comm); - return detail::make_pattern(setup_comm, comm, hgen, d_range); + using grid_type = typename GridType::template type::value_type>; + return detail::make_pattern_impl::apply(context, std::forward(hgen), std::forward(d_range)); + } } // namespace ghex diff --git a/include/ghex/structured/pattern.hpp b/include/ghex/structured/pattern.hpp index eb31745..a679b3c 100644 --- a/include/ghex/structured/pattern.hpp +++ b/include/ghex/structured/pattern.hpp @@ -28,19 +28,19 @@ namespace gridtools { 
* @tparam Transport transport protocol * @tparam CoordinateArrayType coordinate-like array type * @tparam DomainIdType domain id type*/ - template - class pattern,DomainIdType> + template + class pattern,DomainIdType> { public: // member types using grid_type = structured::detail::grid; - using this_type = pattern; + using this_type = pattern; using coordinate_type = typename grid_type::coordinate_type; using coordinate_element_type = typename grid_type::coordinate_element_type; using dimension = typename grid_type::dimension; - using communicator_type = tl::communicator; + using communicator_type = Communicator; using address_type = typename communicator_type::address_type; using domain_id_type = DomainIdType; - using pattern_container_type = pattern_container; + using pattern_container_type = pattern_container; // this struct holds the first and the last coordinate (inclusive) // of a hypercube in N-dimensional space. @@ -157,10 +157,9 @@ namespace gridtools { return s; } - friend class pattern_container; + friend class pattern_container; private: // members - communicator_type m_comm; iteration_space_pair m_domain; coordinate_type m_global_first; coordinate_type m_global_last; @@ -170,8 +169,8 @@ namespace gridtools { pattern_container_type* m_container; public: // ctors - pattern(communicator_type& comm, const iteration_space_pair& domain, const extended_domain_id_type& id) - : m_comm(comm), m_domain(domain), m_id(id) {} + pattern(const iteration_space_pair& domain, const extended_domain_id_type& id) + : m_domain(domain), m_id(id) {} pattern(const pattern&) = default; pattern(pattern&&) = default; @@ -182,8 +181,6 @@ namespace gridtools { const map_type& recv_halos() const noexcept { return m_recv_map; } domain_id_type domain_id() const noexcept { return m_id.id; } extended_domain_id_type extended_domain_id() const noexcept { return m_id; } - communicator_type& communicator() noexcept { return m_comm; } - const communicator_type& communicator() const noexcept { 
return m_comm; } const pattern_container_type& container() const noexcept { return *m_container; } coordinate_type& global_first() noexcept { return m_global_first; } coordinate_type& global_last() noexcept { return m_global_last; } @@ -209,20 +206,24 @@ namespace gridtools { template struct make_pattern_impl<::gridtools::ghex::structured::detail::grid> { - template - static auto apply(tl::mpi::setup_communicator& comm, tl::communicator& new_comm, HaloGenerator&& hgen, DomainRange&& d_range) + template + static auto apply(tl::context& context, HaloGenerator&& hgen, DomainRange&& d_range) { // typedefs + using context_type = tl::context; using domain_type = typename std::remove_reference_t::value_type; using domain_id_type = typename domain_type::domain_id_type; using grid_type = ::gridtools::ghex::structured::detail::grid; - using pattern_type = pattern; + using communicator_type = typename context_type::communicator_type; + using pattern_type = pattern; using iteration_space = typename pattern_type::iteration_space; using iteration_space_pair = typename pattern_type::iteration_space_pair; using coordinate_type = typename pattern_type::coordinate_type; using extended_domain_id_type = typename pattern_type::extended_domain_id_type; // get this address from new communicator + auto comm = context.get_setup_communicator(); + auto new_comm = context.get_serial_communicator(); auto my_address = new_comm.address(); // set up domain ids, extents and recv halos @@ -244,7 +245,7 @@ namespace gridtools { iteration_space{coordinate_type{d.first()}-coordinate_type{d.first()}, coordinate_type{d.last()} -coordinate_type{d.first()}}, iteration_space{coordinate_type{d.first()}, coordinate_type{d.last()}}} ); - my_patterns.emplace_back( new_comm, my_domain_extents.back(), my_domain_ids.back() ); + my_patterns.emplace_back( /*new_comm,*/ my_domain_extents.back(), my_domain_ids.back() ); // make space for more halos my_generated_recv_halos.resize(my_generated_recv_halos.size()+1); // 
generate recv halos: invoke halo generator @@ -570,7 +571,7 @@ namespace gridtools { } } - return pattern_container(std::move(my_patterns), m_max_tag); + return pattern_container(std::move(my_patterns), m_max_tag); } }; diff --git a/include/ghex/threads/atomic/mutex.hpp b/include/ghex/threads/atomic/mutex.hpp new file mode 100644 index 0000000..4eee98d --- /dev/null +++ b/include/ghex/threads/atomic/mutex.hpp @@ -0,0 +1,62 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_THREADS_MUTEX_HPP +#define INCLUDED_GHEX_THREADS_MUTEX_HPP + +#include +#include + +namespace gridtools { + namespace ghex { + namespace threads { + namespace atomic { + + class atomic_mutex + { + private: // members + std::atomic m_flag; + public: + atomic_mutex() noexcept : m_flag(0) {} + atomic_mutex(const atomic_mutex&) = delete; + atomic_mutex(atomic_mutex&&) = delete; + + inline bool try_lock() noexcept + { + bool expected = false; + return m_flag.compare_exchange_weak(expected, true, std::memory_order_relaxed); + } + + inline bool try_unlock() noexcept + { + bool expected = true; + return m_flag.compare_exchange_weak(expected, false, std::memory_order_relaxed); + } + + inline void lock() noexcept + { + while (!try_lock()) {} + } + + inline void unlock() noexcept + { + while (!try_unlock()) {} + } + }; + + template + using lock_guard = std::lock_guard; + } // namespace atomic + } // namespace threads + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_GHEX_THREADS_MUTEX_HPP */ + diff --git a/include/ghex/threads/atomic/primitives.hpp b/include/ghex/threads/atomic/primitives.hpp new file mode 100644 index 0000000..abda84b --- /dev/null +++ b/include/ghex/threads/atomic/primitives.hpp @@ -0,0 +1,193 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. 
+ * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_THREADS_ATOMIC_PRIMITIVES_HPP +#define INCLUDED_GHEX_THREADS_ATOMIC_PRIMITIVES_HPP + +#include +#include +#include "./mutex.hpp" + +namespace gridtools { + namespace ghex { + namespace threads { + namespace atomic { + + template + using void_return_type = typename std::enable_if< + std::is_same,void>::value, + void>::type; + + template + using return_type = typename std::enable_if< + !std::is_same,void>::value, + boost::callable_traits::return_type_t>::type; + +#ifndef GHEX_THREAD_SINGLE + struct primitives + { + public: // member types + using id_type = int; + + class token + { + private: // members + id_type m_id; + int m_epoch = 0; + bool m_selected = false; + + friend primitives; + + token(id_type id, int epoch) noexcept + : m_id(id), m_epoch(epoch), m_selected(id==0?true:false) + {} + + public: // ctors + token(const token&) = delete; + token(token&&) = default; + + public: // member functions + id_type id() const noexcept { return m_id; } + }; + + using mutex_type = atomic_mutex; + using lock_type = lock_guard; + + private: // members + const int m_num_threads; + std::atomic m_ids; + mutable volatile int m_epoch; + mutable std::atomic b_count; + mutable mutex_type m_mutex; + + public: // ctors + primitives(int num_threads) noexcept + : m_num_threads(num_threads) + , m_ids(0) + , m_epoch(0) + , b_count(0) + {} + + primitives(const primitives&) = delete; + primitives(primitives&&) = delete; + + public: // public member functions + inline token get_token() noexcept + { + return {(int)m_ids++,0}; + } + + inline void barrier(token& t) const noexcept + { + int expected = b_count; + while (!b_count.compare_exchange_weak(expected, expected+1, std::memory_order_relaxed)) + expected = b_count; + t.m_epoch ^= 1; + t.m_selected = (expected?false:true); + if (expected == m_num_threads-1) + { + b_count.store(0); + m_epoch ^= 1; + } + 
while(t.m_epoch != m_epoch) {} + } + + template + inline void single(token& t, F && f) const noexcept + { + if (t.m_selected) { + f(); + } + } + + template + inline void master(token& t, F && f) const noexcept + { + if (t.m_id == 0) { + f(); + } + } + + template + inline void_return_type critical(F && f) const noexcept + { + lock_type l(m_mutex); + f(); + } + + template + inline return_type critical(F && f) const noexcept + { + lock_type l(m_mutex); + return f(); + } + }; +#else + struct primitives + { + public: // member types + using id_type = int; + + class token + { + private: // members + id_type m_id; + + friend primitives; + + token(id_type id) noexcept + : m_id(id) + {} + + public: // ctors + token(const token&) = delete; + token(token&&) = default; + + public: // member functions + id_type id() const noexcept { return m_id; } + }; + + using mutex_type = atomic_mutex; + using lock_type = lock_guard; + + private: // members + + public: // ctors + primitives(int=1) noexcept + {} + + primitives(const primitives&) = delete; + primitives(primitives&&) = delete; + + public: // public member functions + inline token get_token() noexcept { return {0}; } + + inline void barrier(token& t) noexcept {} + + template + inline void single(token& t, F && f) const noexcept { f(); } + + template + inline void master(token& t, F && f) const noexcept { f(); } + + template + inline void_return_type critical(F && f) const noexcept { f(); } + + template + inline return_type critical(F && f) const noexcept { return f(); } + }; +#endif + } // namespace atomic + } // namespace threads + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_GHEX_THREADS_ATOMIC_PRIMITIVES_HPP */ + diff --git a/include/ghex/transport_layer/communicator.hpp b/include/ghex/transport_layer/communicator.hpp index 96ef884..a2d446c 100644 --- a/include/ghex/transport_layer/communicator.hpp +++ b/include/ghex/transport_layer/communicator.hpp @@ -19,8 +19,8 @@ namespace gridtools { /** @brief 
communicator class which exposes basic communication primitives * @tparam TransportTag transport protocol tag */ - template - class communicator; + //template + //class communicator; // concept // ------- diff --git a/include/ghex/transport_layer/config.hpp b/include/ghex/transport_layer/config.hpp new file mode 100644 index 0000000..8e304a7 --- /dev/null +++ b/include/ghex/transport_layer/config.hpp @@ -0,0 +1,17 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_TL_CONFIG_HPP +#define INCLUDED_TL_CONFIG_HPP + + + +#endif /* INCLUDED_TL_CONFIG_HPP */ + diff --git a/include/ghex/transport_layer/context.hpp b/include/ghex/transport_layer/context.hpp new file mode 100644 index 0000000..6570cbd --- /dev/null +++ b/include/ghex/transport_layer/context.hpp @@ -0,0 +1,208 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_TL_CONTEXT_HPP +#define INCLUDED_GHEX_TL_CONTEXT_HPP + +#include +#include +#include "./config.hpp" +#include "./mpi/setup.hpp" + +namespace gridtools { + namespace ghex { + namespace tl { + + template + class transport_context; + + class mpi_world + { + private: + MPI_Comm m_comm; + int m_rank; + int m_size; + bool m_owning = false; + + mpi_world(int& argc, char**& argv) + { +#if defined(GHEX_THREAD_SINGLE) + MPI_Init(&argc, &argv); +#elif defined(GHEX_MPI_USE_GHEX_LOCKS) + int provided; + int res = MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided); + if (res == MPI_ERR_OTHER) + throw std::runtime_error("MPI init failed"); + if (provided < MPI_THREAD_SERIALIZED) + throw std::runtime_error("MPI does not support required threading level"); +#else + int provided; + int res = MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); + if (res == MPI_ERR_OTHER) + throw std::runtime_error("MPI init failed"); + if (provided < MPI_THREAD_MULTIPLE) + throw std::runtime_error("MPI does not support required threading level"); +#endif + m_comm = MPI_COMM_WORLD; + MPI_Comm_rank(m_comm, &m_rank); + MPI_Comm_size(m_comm, &m_size); + m_owning = true; + } + + mpi_world(MPI_Comm comm) + : m_comm(comm) + , m_rank{ [comm]() { int r; MPI_Comm_rank(comm, &r); return r; }() } + , m_size{ [comm]() { int r; MPI_Comm_size(comm, &r); return r; }() } + {} + + mpi_world(const mpi_world&) = delete; + mpi_world(mpi_world&&) = delete; + + // forward declaration + template + friend class parallel_context; + + public: + ~mpi_world() + { + if (m_owning) MPI_Finalize(); + } + + public: + inline int rank() const noexcept { return m_rank; } + inline int size() const noexcept { return m_size; } + operator MPI_Comm() const noexcept { return m_comm; } + }; + + template + class parallel_context + { + public: // members + using thread_primitives_type = ThreadPrimitives; + using thread_token = typename 
thread_primitives_type::token; + + private: + mpi_world m_world; + thread_primitives_type m_thread_primitives; + + private: + template + parallel_context(int num_threads, int& argc, char**& argv, Args&&...) noexcept + : m_world(argc,argv) + , m_thread_primitives(num_threads) + {} + + template + parallel_context(int num_threads, MPI_Comm comm, Args&&...) noexcept + : m_world(comm) + , m_thread_primitives(num_threads) + {} + + parallel_context(const parallel_context&) = delete; + parallel_context(parallel_context&&) = delete; + + // forward declaration + template + friend class context; + + public: + // thread-safe + const mpi_world& world() const { return m_world; } + // thread-safe + const thread_primitives_type& thread_primitives() const { return m_thread_primitives; } + // thread-safe + void barrier(thread_token& t) const + { + m_thread_primitives.barrier(t); + m_thread_primitives.single(t, [this]() { MPI_Barrier(m_world.m_comm); } ); + m_thread_primitives.barrier(t); + } + }; + + template + class context + { + public: // member types + using tag = TransportTag; + using transport_context_type = transport_context; + using communicator_type = typename transport_context_type::communicator_type; + using parallel_context_type = parallel_context; + using thread_primitives_type = typename parallel_context_type::thread_primitives_type; + using thread_token = typename parallel_context_type::thread_token; + + private: + parallel_context_type m_parallel_context; + transport_context_type m_transport_context; + + public: + template + context(int num_threads, Args&&... 
args) + : m_parallel_context{num_threads, std::forward(args)...} + , m_transport_context{m_parallel_context, std::forward(args)...} + {} + + context(const context&) = delete; + context(context&&) = delete; + + public: + + const mpi_world& world() const noexcept + { + return m_parallel_context.world(); + } + + const thread_primitives_type& thread_primitives() const noexcept + { + return m_parallel_context.thread_primitives(); + } + + mpi::setup_communicator get_setup_communicator() + { + return mpi::setup_communicator(m_parallel_context.world()); + } + + // per-rank communicator (recv) + // thread-safe + communicator_type get_serial_communicator() + { + return m_parallel_context.m_thread_primitives.critical( + [this]() mutable { return m_transport_context.get_serial_communicator(); } + ); + } + + // per-thread communicator (send) + // thread-safe + communicator_type get_communicator(const thread_token& t) + { + return m_parallel_context.m_thread_primitives.critical( + [this,&t]() mutable { return m_transport_context.get_communicator(t.id()); } + ); + } + + // thread-safe + thread_token get_token() noexcept + { + return m_parallel_context.m_thread_primitives.get_token(); + } + + // thread-safe + void barrier(thread_token& t) const + { + m_parallel_context.barrier(t); + } + + }; + + } // namespace tl + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_GHEX_TL_CONTEXT_HPP */ + diff --git a/include/ghex/transport_layer/mpi/communicator.hpp b/include/ghex/transport_layer/mpi/communicator.hpp index 0174123..62bcf0a 100644 --- a/include/ghex/transport_layer/mpi/communicator.hpp +++ b/include/ghex/transport_layer/mpi/communicator.hpp @@ -15,167 +15,179 @@ #include "../communicator.hpp" #include "./communicator_base.hpp" #include "./future.hpp" -#include "./communicator_traits.hpp" +#include "../context.hpp" namespace gridtools { namespace ghex { namespace tl { + + template + struct transport_context; + + namespace mpi { - /** Mpi communicator which exposes basic 
non-blocking transport functionality and - * returns futures to await said transports to complete. */ - template<> - class communicator - : public mpi::communicator_base - { - public: - using transport_type = mpi_tag; - using base_type = mpi::communicator_base; - using address_type = typename base_type::rank_type; - using rank_type = typename base_type::rank_type; - using size_type = typename base_type::size_type; - using tag_type = typename base_type::tag_type; - using request = mpi::request; - using status = mpi::status; - template - using future = mpi::future; - using traits = mpi::communicator_traits; - - public: - - communicator(const traits& t = traits{}) : base_type{t.communicator()} {} - communicator(const base_type& c) : base_type{c} {} - communicator(const MPI_Comm& c) : base_type{c} {} + /** Mpi communicator which exposes basic non-blocking transport functionality and + * returns futures to await said transports to complete. */ + template + class communicator// + : public communicator_base + { + public: + using transport_type = mpi_tag; + using base_type = mpi::communicator_base; + using address_type = typename base_type::rank_type; + using rank_type = typename base_type::rank_type; + using size_type = typename base_type::size_type; + using tag_type = typename base_type::tag_type; + using request = request_t; + using status = status_t; + template + using future = future_t; + + public: + + using transport_context_type = transport_context; + transport_context_type* m_transport_context; + int m_thread_id; + + communicator(const MPI_Comm& c, transport_context_type* tc, int thread_id = -1) + : base_type{c} + , m_transport_context{tc} + , m_thread_id{thread_id} + {} + + communicator(const communicator&) = default; + communicator(communicator&&) noexcept = default; + + communicator& operator=(const communicator&) = default; + communicator& operator=(communicator&&) noexcept = default; + + /** @return address of this process */ + address_type address() const { 
return rank(); } + + public: // send + + /** @brief non-blocking send + * @tparam Message a container type + * @param msg source container + * @param dest destination rank + * @param tag message tag + * @return completion handle */ + template + [[nodiscard]] future send(const Message& msg, rank_type dest, tag_type tag) const + { + request req; + GHEX_CHECK_MPI_RESULT( + MPI_Isend(reinterpret_cast(msg.data()),sizeof(typename Message::value_type)*msg.size(), + MPI_BYTE, dest, tag, *this, &req.get()) + ); + return req; + } - communicator(const communicator&) = default; - communicator(communicator&&) noexcept = default; + public: // recv + + /** @brief non-blocking receive + * @tparam Message a container type + * @param msg destination container + * @param source source rank + * @param tag message tag + * @return completion handle */ + template + [[nodiscard]] future recv(Message& msg, rank_type source, tag_type tag) const + { + request req; + GHEX_CHECK_MPI_RESULT( + MPI_Irecv(reinterpret_cast(msg.data()),sizeof(typename Message::value_type)*msg.size(), + MPI_BYTE, source, tag, *this, &req.get())); + return req; + } - communicator& operator=(const communicator&) = default; - communicator& operator=(communicator&&) noexcept = default; + /** @brief non-blocking receive which allocates the container within this function and returns it + * in the future + * @tparam Message a container type + * @tparam Args additional argument types for construction of Message + * @param n number of elements to be received + * @param source source rank + * @param tag message tag + * @param args additional arguments to be passed to new container of type Message at construction + * @return completion handle with message as payload */ + template + [[nodiscard]] future recv(int n, rank_type source, tag_type tag, Args&& ...args) const + { + Message msg{n, std::forward(args)...}; + return { std::move(msg), recv(msg, source, tag).m_handle }; - /** @return address of this process */ - address_type 
address() const { return rank(); } + } - public: // send + /** @brief non-blocking receive which matches any tag from the given source. If a match is found, it + * allocates the container of type Message within this function and returns it in the future. + * The container size will be set according to the matched receive operation. + * @tparam Message a container type + * @tparam Args additional argument types for construction of Message + * @param source source rank + * @param args additional arguments to be passed to new container of type Message at construction + * @return optional which may hold a future< std::tuple > */ + template + [[nodiscard]] auto recv_any_tag(rank_type source, Args&& ...args) const + { + return recv_any(source, MPI_ANY_TAG, std::forward(args)...); + } - /** @brief non-blocking send - * @tparam Message a container type - * @param msg source container - * @param dest destination rank - * @param tag message tag - * @return completion handle */ - template - [[nodiscard]] future send(const Message& msg, rank_type dest, tag_type tag) const - { - request req; - GHEX_CHECK_MPI_RESULT( - MPI_Isend(reinterpret_cast(msg.data()),sizeof(typename Message::value_type)*msg.size(), - MPI_BYTE, dest, tag, *this, &req.get()) - ); - return req; - } - - public: // recv - - /** @brief non-blocking receive - * @tparam Message a container type - * @param msg destination container - * @param source source rank - * @param tag message tag - * @return completion handle */ - template - [[nodiscard]] future recv(Message& msg, rank_type source, tag_type tag) const - { - request req; - GHEX_CHECK_MPI_RESULT( - MPI_Irecv(reinterpret_cast(msg.data()),sizeof(typename Message::value_type)*msg.size(), - MPI_BYTE, source, tag, *this, &req.get())); - return req; - } - - /** @brief non-blocking receive which allocates the container within this function and returns it - * in the future - * @tparam Message a container type - * @tparam Args additional argument types for
construction of Message - * @param n number of elements to be received - * @param source source rank - * @param tag message tag - * @param args additional arguments to be passed to new container of type Message at construction - * @return completion handle with message as payload */ - template - [[nodiscard]] future recv(int n, rank_type source, tag_type tag, Args&& ...args) const - { - Message msg{n, std::forward(args)...}; - return { std::move(msg), recv(msg, source, tag).m_handle }; - - } - - /** @brief non-blocking receive which maches any tag from the given source. If a match is found, it - * allocates the container of type Message within this function and returns it in the future. - * The container size will be set according to the matched receive operation. - * @tparam Message a container type - * @tparam Args additional argument types for construction of Message - * @param source source rank - * @param args additional arguments to be passed to new container of type Message at construction - * @return optional which may hold a future< std::tuple > */ - template - [[nodiscard]] auto recv_any_tag(rank_type source, Args&& ...args) const - { - return recv_any(source, MPI_ANY_TAG, std::forward(args)...); - } - - /** @brief non-blocking receive which maches any source using the given tag. If a match is found, it - * allocates the container of type Message within this function and returns it in the future. - * The container size will be set according to the matched receive operation. 
- * @tparam Message a container type - * @tparam Args additional argument types for construction of Message - * @param tag message tag - * @param args additional arguments to be passed to new container of type Message at construction - * @return optional which may hold a future< std::tuple > */ - template - [[nodiscard]] auto recv_any_source(tag_type tag, Args&& ...args) const - { - return recv_any(MPI_ANY_SOURCE, tag, std::forward(args)...); - } - - /** @brief non-blocking receive which maches any source and any tag. If a match is found, it - * allocates the container of type Message within this function and returns it in the future. - * The container size will be set according to the matched receive operation. - * @tparam Message a container type - * @tparam Args additional argument types for construction of Message - * @param tag message tag - * @param args additional arguments to be passed to new container of type Message at construction - * @return optional which may hold a future< std::tuple > */ - template - [[nodiscard]] auto recv_any_source_any_tag(Args&& ...args) const - { - return recv_any(MPI_ANY_SOURCE, MPI_ANY_TAG, std::forward(args)...); - } + /** @brief non-blocking receive which matches any source using the given tag. If a match is found, it + * allocates the container of type Message within this function and returns it in the future. + * The container size will be set according to the matched receive operation.
+ * @tparam Message a container type + * @tparam Args additional argument types for construction of Message + * @param tag message tag + * @param args additional arguments to be passed to new container of type Message at construction + * @return optional which may hold a future< std::tuple > */ + template + [[nodiscard]] auto recv_any_source(tag_type tag, Args&& ...args) const + { + return recv_any(MPI_ANY_SOURCE, tag, std::forward(args)...); + } - private: // implementation + /** @brief non-blocking receive which matches any source and any tag. If a match is found, it + * allocates the container of type Message within this function and returns it in the future. + * The container size will be set according to the matched receive operation. + * @tparam Message a container type + * @tparam Args additional argument types for construction of Message + * @param tag message tag + * @param args additional arguments to be passed to new container of type Message at construction + * @return optional which may hold a future< std::tuple > */ + template + [[nodiscard]] auto recv_any_source_any_tag(Args&& ...args) const + { + return recv_any(MPI_ANY_SOURCE, MPI_ANY_TAG, std::forward(args)...); + } - template - [[nodiscard]] boost::optional< future< std::tuple > > - recv_any(rank_type source, tag_type tag, Args&& ...args) const - { - MPI_Message mpi_msg; - status st; - int flag = 0; - GHEX_CHECK_MPI_RESULT(MPI_Improbe(source, tag, *this, &flag, &mpi_msg, &st.get())); - if (flag) + private: // implementation + + template + [[nodiscard]] boost::optional< future< std::tuple > > + recv_any(rank_type source, tag_type tag, Args&& ...args) const { - int count; - GHEX_CHECK_MPI_RESULT(MPI_Get_count(&st.get(), MPI_CHAR, &count)); - Message msg(count/sizeof(typename Message::value_type), std::forward(args)...); - request req; - GHEX_CHECK_MPI_RESULT(MPI_Imrecv(msg.data(), count, MPI_CHAR, &mpi_msg, &req.get())); - using future_t = future>; - return future_t{ std::make_tuple(std::move(msg),
st.source(), st.tag()), std::move(req) }; + MPI_Message mpi_msg; + status st; + int flag = 0; + GHEX_CHECK_MPI_RESULT(MPI_Improbe(source, tag, *this, &flag, &mpi_msg, &st.get())); + if (flag) + { + int count; + GHEX_CHECK_MPI_RESULT(MPI_Get_count(&st.get(), MPI_CHAR, &count)); + Message msg(count/sizeof(typename Message::value_type), std::forward(args)...); + request req; + GHEX_CHECK_MPI_RESULT(MPI_Imrecv(msg.data(), count, MPI_CHAR, &mpi_msg, &req.get())); + using future_t = future>; + return future_t{ std::make_tuple(std::move(msg), st.source(), st.tag()), std::move(req) }; + } + return boost::none; } - return boost::none; - } - }; + }; + + } // namespace mpi } // namespace tl diff --git a/include/ghex/transport_layer/mpi/communicator_traits.hpp b/include/ghex/transport_layer/mpi/communicator_traits.hpp deleted file mode 100644 index e335760..0000000 --- a/include/ghex/transport_layer/mpi/communicator_traits.hpp +++ /dev/null @@ -1,45 +0,0 @@ -/* - * GridTools - * - * Copyright (c) 2019, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. 
- * SPDX-License-Identifier: BSD-3-Clause - */ - -#ifndef GHEX_MPI_COMMUNICATOR_TRAITS_HPP -#define GHEX_MPI_COMMUNICATOR_TRAITS_HPP - -//#include -#include "./communicator_base.hpp" - -namespace gridtools -{ -namespace ghex -{ -namespace tl { -namespace mpi -{ - -struct communicator_traits -{ - communicator_base m_comm; - - communicator_traits(MPI_Comm comm) - : m_comm{comm} - { } - - communicator_traits() - : m_comm{MPI_COMM_WORLD} - { } - - MPI_Comm communicator() const { return m_comm; } -}; - -} // namespace mpi -} // namespace tl -} // namespace ghex -} // namespace gridtools - -#endif diff --git a/include/ghex/transport_layer/mpi/context.hpp b/include/ghex/transport_layer/mpi/context.hpp new file mode 100644 index 0000000..7ab6d96 --- /dev/null +++ b/include/ghex/transport_layer/mpi/context.hpp @@ -0,0 +1,53 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_TL_MPI_CONTEXT_HPP +#define INCLUDED_TL_MPI_CONTEXT_HPP + +#include "../context.hpp" +#include "./communicator.hpp" + +namespace gridtools { + namespace ghex { + namespace tl { + + template + struct transport_context + { + using communicator_type = mpi::communicator; + + parallel_context& m_parallel_context; + + template + transport_context(parallel_context& pc, Args&&...) 
+ : m_parallel_context(pc) + {} + + communicator_type get_serial_communicator() + { + return {(MPI_Comm)(m_parallel_context.world()),this}; + } + + communicator_type get_communicator(int thread_id) + { + return {(MPI_Comm)(m_parallel_context.world()),this, thread_id}; + } + + }; + + //using mpi_context = context; + + } + } +} + +#endif /* INCLUDED_TL_MPI_CONTEXT_HPP */ + + diff --git a/include/ghex/transport_layer/mpi/future.hpp b/include/ghex/transport_layer/mpi/future.hpp index 542db95..ea7b9fc 100644 --- a/include/ghex/transport_layer/mpi/future.hpp +++ b/include/ghex/transport_layer/mpi/future.hpp @@ -20,22 +20,22 @@ namespace gridtools{ /** @brief future template for non-blocking communication */ template - struct future + struct future_t { using value_type = T; - using handle_type = request; + using handle_type = request_t; value_type m_data; handle_type m_handle; - future(value_type&& data, handle_type&& h) + future_t(value_type&& data, handle_type&& h) : m_data(std::move(data)) , m_handle(std::move(h)) {} - future(const future&) = delete; - future(future&&) = default; - future& operator=(const future&) = delete; - future& operator=(future&&) = default; + future_t(const future_t&) = delete; + future_t(future_t&&) = default; + future_t& operator=(const future_t&) = delete; + future_t& operator=(future_t&&) = default; void wait() noexcept { @@ -72,20 +72,20 @@ namespace gridtools{ }; template<> - struct future + struct future_t { - using handle_type = request; + using handle_type = request_t; handle_type m_handle; - future() noexcept = default; - future(handle_type&& h) + future_t() noexcept = default; + future_t(handle_type&& h) : m_handle(std::move(h)) {} - future(const future&) = delete; - future(future&&) = default; - future& operator=(const future&) = delete; - future& operator=(future&&) = default; + future_t(const future_t&) = delete; + future_t(future_t&&) = default; + future_t& operator=(const future_t&) = delete; + future_t& operator=(future_t&&) = 
default; void wait() noexcept { diff --git a/include/ghex/transport_layer/mpi/request.hpp b/include/ghex/transport_layer/mpi/request.hpp index b35051f..9069d20 100644 --- a/include/ghex/transport_layer/mpi/request.hpp +++ b/include/ghex/transport_layer/mpi/request.hpp @@ -20,7 +20,7 @@ namespace gridtools{ namespace mpi { /** @brief thin wrapper around MPI_Request */ - struct request + struct request_t { GHEX_C_STRUCT(req_type, MPI_Request) req_type m_req = MPI_REQUEST_NULL; diff --git a/include/ghex/transport_layer/mpi/setup.hpp b/include/ghex/transport_layer/mpi/setup.hpp index f153c62..87d718b 100644 --- a/include/ghex/transport_layer/mpi/setup.hpp +++ b/include/ghex/transport_layer/mpi/setup.hpp @@ -26,10 +26,11 @@ namespace gridtools{ { public: using base_type = communicator_base; - using handle_type = request; + using handle_type = request_t; using address_type = base_type::rank_type; + using status = status_t; template - using future = future; + using future = future_t; public: setup_communicator(const MPI_Comm& comm) : base_type{comm} {} diff --git a/include/ghex/transport_layer/mpi/status.hpp b/include/ghex/transport_layer/mpi/status.hpp index f3116c4..e8af320 100644 --- a/include/ghex/transport_layer/mpi/status.hpp +++ b/include/ghex/transport_layer/mpi/status.hpp @@ -20,7 +20,7 @@ namespace gridtools{ namespace mpi { /** @brief thin wrapper around MPI_Request */ - struct status + struct status_t { GHEX_C_STRUCT(stat_type, MPI_Status) stat_type m_status; diff --git a/tests/communication_object.cpp b/tests/communication_object.cpp index 5b1e0a3..21242fb 100644 --- a/tests/communication_object.cpp +++ b/tests/communication_object.cpp @@ -20,10 +20,14 @@ #include #include #include -#include -#include +#include +#include #include "../utils/triplet.hpp" +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; + /* CPU data descriptor */ template @@ 
-97,8 +101,7 @@ TEST(communication_object, constructor) { using coordinate_t = domain_descriptor_t::coordinate_type; using halo_generator_t = domain_descriptor_t::halo_generator_type; //gridtools::structured::halo_generator; - gridtools::ghex::tl::mpi::communicator_base world; - gridtools::ghex::tl::communicator comm{world}; + context_type context(1, MPI_COMM_WORLD); /* Problem sizes */ const int d1 = 2; @@ -121,22 +124,23 @@ TEST(communication_object, constructor) { std::vector local_domains; domain_descriptor_t my_domain_1{ - comm.rank(), - coordinate_t{(comm.rank() % d1 ) * DIM1 , (comm.rank() / d1) * DIM2 , 0}, - coordinate_t{(comm.rank() % d1 + 1) * DIM1 - 1, (comm.rank() / d1 + 1) * DIM2 - 1, DIM3-1} + context.world().rank(), + coordinate_t{(context.world().rank() % d1 ) * DIM1 , (context.world().rank() / d1) * DIM2 , 0}, + coordinate_t{(context.world().rank() % d1 + 1) * DIM1 - 1, (context.world().rank() / d1 + 1) * DIM2 - 1, DIM3-1} }; local_domains.push_back(my_domain_1); auto halo_gen = halo_generator_t{g_first, g_last, halos, periodic}; - auto patterns = gridtools::ghex::make_pattern(world, halo_gen, local_domains); + auto patterns = gridtools::ghex::make_pattern(context, halo_gen, local_domains); using communication_object_t = gridtools::ghex::communication_object; + auto comm = context.get_communicator(context.get_token()); std::vector cos; for (const auto& p : patterns) { EXPECT_NO_THROW( - cos.push_back(communication_object_t{p}); + cos.push_back(communication_object_t{p,comm}); ); } @@ -150,8 +154,7 @@ TEST(communication_object, exchange) { using halo_generator_t = domain_descriptor_t::halo_generator_type; //gridtools::structured_halo_generator; using layout_map_type = gridtools::layout_map<2, 1, 0>; - gridtools::ghex::tl::mpi::communicator_base world; - gridtools::ghex::tl::communicator comm{world}; + context_type context(1, MPI_COMM_WORLD); /* Problem sizes */ const int d1 = 2; @@ -170,26 +173,27 @@ TEST(communication_object, exchange) { const 
std::array g_last{d1*DIM1-1, d2*DIM2-1, d3*DIM3-1}; const std::array halos{H1m, H1p, H2m, H2p, H3m, H3p}; const std::array periodic{true, true, true}; - int coords[3]{comm.rank() % d1, comm.rank() / d1, 0}; // rank in cartesian coordinates + int coords[3]{context.world().rank() % d1, context.world().rank() / d1, 0}; // rank in cartesian coordinates std::vector local_domains; domain_descriptor_t my_domain_1{ - comm.rank(), - coordinate_t{(comm.rank() % d1 ) * DIM1 , (comm.rank() / d1) * DIM2 , 0}, - coordinate_t{(comm.rank() % d1 + 1) * DIM1 - 1, (comm.rank() / d1 + 1) * DIM2 - 1, DIM3-1} + context.world().rank(), + coordinate_t{(context.world().rank() % d1 ) * DIM1 , (context.world().rank() / d1) * DIM2 , 0}, + coordinate_t{(context.world().rank() % d1 + 1) * DIM1 - 1, (context.world().rank() / d1 + 1) * DIM2 - 1, DIM3-1} }; local_domains.push_back(my_domain_1); auto halo_gen = halo_generator_t{g_first, g_last, halos, periodic}; - auto patterns = gridtools::ghex::make_pattern(world, halo_gen, local_domains); + auto patterns = gridtools::ghex::make_pattern(context, halo_gen, local_domains); using communication_object_t = gridtools::ghex::communication_object; + auto comm = context.get_communicator(context.get_token()); std::vector cos; for (const auto& p : patterns) { - cos.push_back(communication_object_t{p}); + cos.push_back(communication_object_t{p,comm}); } triple_t* _values_1 = new triple_t[(DIM1 + H1m + H1p) * (DIM2 + H2m + H2p) * (DIM3 + H3m + H3p)]; @@ -255,8 +259,7 @@ TEST(communication_object, exchange_asymmetric_halos) { using halo_generator_t = domain_descriptor_t::halo_generator_type; //gridtools::structured_halo_generator; using layout_map_type = gridtools::layout_map<2, 1, 0>; - gridtools::ghex::tl::mpi::communicator_base world; - gridtools::ghex::tl::communicator comm{world}; + context_type context(1, MPI_COMM_WORLD); /* Problem sizes */ const int d1 = 2; @@ -275,26 +278,27 @@ TEST(communication_object, exchange_asymmetric_halos) { const std::array 
g_last{d1*DIM1-1, d2*DIM2-1, d3*DIM3-1}; const std::array halos{H1m, H1p, H2m, H2p, H3m, H3p}; const std::array periodic{true, true, true}; - int coords[3]{comm.rank() % d1, comm.rank() / d1, 0}; // rank in cartesian coordinates + int coords[3]{context.world().rank() % d1, context.world().rank() / d1, 0}; // rank in cartesian coordinates std::vector local_domains; domain_descriptor_t my_domain_1{ - comm.rank(), - coordinate_t{(comm.rank() % d1 ) * DIM1 , (comm.rank() / d1) * DIM2 , 0}, - coordinate_t{(comm.rank() % d1 + 1) * DIM1 - 1, (comm.rank() / d1 + 1) * DIM2 - 1, DIM3-1} + context.world().rank(), + coordinate_t{(context.world().rank() % d1 ) * DIM1 , (context.world().rank() / d1) * DIM2 , 0}, + coordinate_t{(context.world().rank() % d1 + 1) * DIM1 - 1, (context.world().rank() / d1 + 1) * DIM2 - 1, DIM3-1} }; local_domains.push_back(my_domain_1); auto halo_gen = halo_generator_t{g_first, g_last, halos, periodic}; - auto patterns = gridtools::ghex::make_pattern(world, halo_gen, local_domains); + auto patterns = gridtools::ghex::make_pattern(context, halo_gen, local_domains); using communication_object_t = gridtools::ghex::communication_object; + auto comm = context.get_communicator(context.get_token()); std::vector cos; for (const auto& p : patterns) { - cos.push_back(communication_object_t{p}); + cos.push_back(communication_object_t{p,comm}); } triple_t* _values_1 = new triple_t[(DIM1 + H1m + H1p) * (DIM2 + H2m + H2p) * (DIM3 + H3m + H3p)]; @@ -360,8 +364,7 @@ TEST(communication_object, exchange_multiple_fields) { using halo_generator_t = domain_descriptor_t::halo_generator_type; //gridtools::structured_halo_generator; using layout_map_type = gridtools::layout_map<2, 1, 0>; - gridtools::ghex::tl::mpi::communicator_base world; - gridtools::ghex::tl::communicator comm{world}; + context_type context(1, MPI_COMM_WORLD); /* Problem sizes */ const int d1 = 2; @@ -381,26 +384,27 @@ TEST(communication_object, exchange_multiple_fields) { const std::array halos{H1m, 
H1p, H2m, H2p, H3m, H3p}; const std::array periodic{true, true, true}; const int add = 1; - int coords[3]{comm.rank() % d1, comm.rank() / d1, 0}; // rank in cartesian coordinates + int coords[3]{context.world().rank() % d1, context.world().rank() / d1, 0}; // rank in cartesian coordinates std::vector local_domains; domain_descriptor_t my_domain_1{ - comm.rank(), - coordinate_t{(comm.rank() % d1 ) * DIM1 , (comm.rank() / d1) * DIM2 , 0}, - coordinate_t{(comm.rank() % d1 + 1) * DIM1 - 1, (comm.rank() / d1 + 1) * DIM2 - 1, DIM3-1} + context.world().rank(), + coordinate_t{(context.world().rank() % d1 ) * DIM1 , (context.world().rank() / d1) * DIM2 , 0}, + coordinate_t{(context.world().rank() % d1 + 1) * DIM1 - 1, (context.world().rank() / d1 + 1) * DIM2 - 1, DIM3-1} }; local_domains.push_back(my_domain_1); auto halo_gen = halo_generator_t{g_first, g_last, halos, periodic}; - auto patterns = gridtools::ghex::make_pattern(world, halo_gen, local_domains); + auto patterns = gridtools::ghex::make_pattern(context, halo_gen, local_domains); using communication_object_t = gridtools::ghex::communication_object; + auto comm = context.get_communicator(context.get_token()); std::vector cos; for (const auto& p : patterns) { - cos.push_back(communication_object_t{p}); + cos.push_back(communication_object_t{p,comm}); } triple_t* _values_1 = new triple_t[(DIM1 + H1m + H1p) * (DIM2 + H2m + H2p) * (DIM3 + H3m + H3p)]; @@ -496,8 +500,7 @@ TEST(communication_object, multithreading) { using halo_generator_t = domain_descriptor_t::halo_generator_type; //gridtools::structured_halo_generator; using layout_map_type = gridtools::layout_map<2, 1, 0>; - gridtools::ghex::tl::mpi::communicator_base world; - gridtools::ghex::tl::communicator comm{world}; + context_type context(1, MPI_COMM_WORLD); /* Problem sizes */ const int d1 = 2; @@ -516,33 +519,34 @@ TEST(communication_object, multithreading) { const std::array g_last{d1 * DIM1 * 2 - 1, d2 * DIM2 - 1, d3 * DIM3 - 1}; const std::array halos{H1m, 
H1p, H2m, H2p, H3m, H3p}; const std::array periodic{true, true, true}; - int coords[3]{comm.rank() % d1, comm.rank() / d1, 0}; // rank in cartesian coordinates + int coords[3]{context.world().rank() % d1, context.world().rank() / d1, 0}; // rank in cartesian coordinates std::vector local_domains; domain_descriptor_t my_domain_1{ - comm.rank() * 2, - coordinate_t{(comm.rank() % d1) * DIM1 * 2 , (comm.rank() / d1 ) * DIM2 , 0 }, - coordinate_t{(comm.rank() % d1) * DIM1 * 2 + DIM1 - 1 , (comm.rank() / d1 + 1) * DIM2 - 1, DIM3-1} + context.world().rank() * 2, + coordinate_t{(context.world().rank() % d1) * DIM1 * 2 , (context.world().rank() / d1 ) * DIM2 , 0 }, + coordinate_t{(context.world().rank() % d1) * DIM1 * 2 + DIM1 - 1 , (context.world().rank() / d1 + 1) * DIM2 - 1, DIM3-1} }; local_domains.push_back(my_domain_1); domain_descriptor_t my_domain_2{ - comm.rank() * 2 + 1, - coordinate_t{(comm.rank() % d1) * DIM1 * 2 + DIM1 , (comm.rank() / d1 ) * DIM2 , 0 }, - coordinate_t{(comm.rank() % d1) * DIM1 * 2 + DIM1 * 2 - 1, (comm.rank() / d1 + 1) * DIM2 - 1, DIM3-1} + context.world().rank() * 2 + 1, + coordinate_t{(context.world().rank() % d1) * DIM1 * 2 + DIM1 , (context.world().rank() / d1 ) * DIM2 , 0 }, + coordinate_t{(context.world().rank() % d1) * DIM1 * 2 + DIM1 * 2 - 1, (context.world().rank() / d1 + 1) * DIM2 - 1, DIM3-1} }; local_domains.push_back(my_domain_2); auto halo_gen = halo_generator_t{g_first, g_last, halos, periodic}; - auto patterns = gridtools::ghex::make_pattern(world, halo_gen, local_domains); + auto patterns = gridtools::ghex::make_pattern(context, halo_gen, local_domains); using communication_object_t = gridtools::ghex::communication_object; + auto comm = context.get_communicator(context.get_token()); std::vector cos; for (const auto& p : patterns) { - cos.push_back(communication_object_t{p}); + cos.push_back(communication_object_t{p,comm}); } triple_t* _values_1 = new triple_t[(DIM1 + H1m + H1p) * (DIM2 + H2m + H2p) * (DIM3 + H3m + H3p)]; @@ 
-652,8 +656,7 @@ TEST(communication_object, multithreading_multiple_fileds) { using halo_generator_t = domain_descriptor_t::halo_generator_type; //gridtools::structured_halo_generator; using layout_map_type = gridtools::layout_map<2, 1, 0>; - gridtools::ghex::tl::mpi::communicator_base world; - gridtools::ghex::tl::communicator comm{world}; + context_type context(1, MPI_COMM_WORLD); /* Problem sizes */ const int d1 = 2; @@ -672,33 +675,34 @@ TEST(communication_object, multithreading_multiple_fileds) { const std::array g_last{d1 * DIM1 * 2 - 1, d2 * DIM2 - 1, d3 * DIM3 - 1}; const std::array halos{H1m, H1p, H2m, H2p, H3m, H3p}; const std::array periodic{true, true, true}; - int coords[3]{comm.rank() % d1, comm.rank() / d1, 0}; // rank in cartesian coordinates + int coords[3]{context.world().rank() % d1, context.world().rank() / d1, 0}; // rank in cartesian coordinates std::vector local_domains; domain_descriptor_t my_domain_1{ - comm.rank() * 2, - coordinate_t{(comm.rank() % d1) * DIM1 * 2 , (comm.rank() / d1 ) * DIM2 , 0 }, - coordinate_t{(comm.rank() % d1) * DIM1 * 2 + DIM1 - 1 , (comm.rank() / d1 + 1) * DIM2 - 1, DIM3-1} + context.world().rank() * 2, + coordinate_t{(context.world().rank() % d1) * DIM1 * 2 , (context.world().rank() / d1 ) * DIM2 , 0 }, + coordinate_t{(context.world().rank() % d1) * DIM1 * 2 + DIM1 - 1 , (context.world().rank() / d1 + 1) * DIM2 - 1, DIM3-1} }; local_domains.push_back(my_domain_1); domain_descriptor_t my_domain_2{ - comm.rank() * 2 + 1, - coordinate_t{(comm.rank() % d1) * DIM1 * 2 + DIM1 , (comm.rank() / d1 ) * DIM2 , 0 }, - coordinate_t{(comm.rank() % d1) * DIM1 * 2 + DIM1 * 2 - 1, (comm.rank() / d1 + 1) * DIM2 - 1, DIM3-1} + context.world().rank() * 2 + 1, + coordinate_t{(context.world().rank() % d1) * DIM1 * 2 + DIM1 , (context.world().rank() / d1 ) * DIM2 , 0 }, + coordinate_t{(context.world().rank() % d1) * DIM1 * 2 + DIM1 * 2 - 1, (context.world().rank() / d1 + 1) * DIM2 - 1, DIM3-1} }; local_domains.push_back(my_domain_2); 
auto halo_gen = halo_generator_t{g_first, g_last, halos, periodic}; - auto patterns = gridtools::ghex::make_pattern(world, halo_gen, local_domains); + auto patterns = gridtools::ghex::make_pattern(context, halo_gen, local_domains); using communication_object_t = gridtools::ghex::communication_object; + auto comm = context.get_communicator(context.get_token()); std::vector cos; for (const auto& p : patterns) { - cos.push_back(communication_object_t{p}); + cos.push_back(communication_object_t{p,comm}); } triple_t* _values_1_1 = new triple_t[(DIM1 + H1m + H1p) * (DIM2 + H2m + H2p) * (DIM3 + H3m + H3p)]; diff --git a/tests/communication_object_2.cpp b/tests/communication_object_2.cpp index 4d1e3d2..01cd6ac 100644 --- a/tests/communication_object_2.cpp +++ b/tests/communication_object_2.cpp @@ -8,27 +8,18 @@ * SPDX-License-Identifier: BSD-3-Clause * */ -//#define STANDALONE - -//#define SERIAL_SPLIT -//#define MULTI_THREADED_EXCHANGE -//#define MULTI_THREADED_EXCHANGE_THREADS -//#define MULTI_THREADED_EXCHANGE_ASYNC_ASYNC -//#define MULTI_THREADED_EXCHANGE_ASYNC_DEFERRED -//#define MULTI_THREADED_EXCHANGE_ASYNC_ASYNC_WAIT #include #include -#include +#include +#include #include #include #include #include -#ifndef STANDALONE #include -#endif #include #ifdef __CUDACC__ @@ -44,6 +35,10 @@ __global__ void print_kernel() { } #endif +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; + template using array_type = gridtools::array; @@ -145,25 +140,24 @@ bool test_values(const Domain& d, const Halos& halos, const Periodic& periodic, } -#ifndef STANDALONE TEST(communication_object_2, exchange) +{ +#if defined(GHEX_TEST_SERIAL) || defined(GHEX_TEST_SERIAL_VECTOR) || defined(GHEX_TEST_SERIAL_SPLIT) || defined(GHEX_TEST_SERIAL_SPLIT_VECTOR) + context_type context(1, MPI_COMM_WORLD); #else -bool test0() + context_type context(2, MPI_COMM_WORLD); #endif -{ - 
//gridtools::ghex::mpi::mpi_comm mpi_comm; - gridtools::ghex::tl::mpi::communicator_base mpi_comm; #ifdef __CUDACC__ int num_devices_per_node; cudaGetDeviceCount(&num_devices_per_node); MPI_Comm raw_local_comm; - MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, mpi_comm.rank(), MPI_INFO_NULL, &raw_local_comm); + MPI_Comm_split_type(context.world(), MPI_COMM_TYPE_SHARED, context.world().rank(), MPI_INFO_NULL, &raw_local_comm); gridtools::ghex::tl::mpi::communicator_base local_comm(raw_local_comm, gridtools::ghex::tl::mpi::comm_take_ownership); if (local_comm.rank()>>(); cudaDeviceSynchronize(); @@ -172,19 +166,16 @@ bool test0() #ifdef GHEX_EMULATE_GPU int num_devices_per_node = 1; MPI_Comm raw_local_comm; - MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, mpi_comm.rank(), MPI_INFO_NULL, &raw_local_comm); + MPI_Comm_split_type(context.world(), MPI_COMM_TYPE_SHARED, context.world().rank(), MPI_INFO_NULL, &raw_local_comm); gridtools::ghex::tl::mpi::communicator_base local_comm(raw_local_comm, gridtools::ghex::tl::mpi::comm_take_ownership); if (local_comm.rank() comm{mpi_comm}; - // local portion per domain const std::array local_ext{10,15,20}; //const std::array local_ext{4,3,2}; @@ -213,7 +204,7 @@ bool test0() // compute total domain const std::array g_first{ 0, 0, 0}; - const std::array g_last {local_ext[0]*4-1, ((comm.size()-1)/2+1)*local_ext[1]-1, local_ext[2]-1}; + const std::array g_last {local_ext[0]*4-1, ((context.world().size()-1)/2+1)*local_ext[1]-1, local_ext[2]-1}; // maximum halo const std::array offset{3,3,3}; // local size including potential halos @@ -249,13 +240,13 @@ bool test0() // add local domains std::vector local_domains; local_domains.push_back( domain_descriptor_type{ - comm.rank()*2, - std::array{ ((comm.rank()%2)*2 )*local_ext[0], (comm.rank()/2 )*local_ext[1], 0}, - std::array{ ((comm.rank()%2)*2+1)*local_ext[0]-1, (comm.rank()/2+1)*local_ext[1]-1, local_ext[2]-1}}); + context.world().rank()*2, + std::array{ 
((context.world().rank()%2)*2 )*local_ext[0], (context.world().rank()/2 )*local_ext[1], 0}, + std::array{ ((context.world().rank()%2)*2+1)*local_ext[0]-1, (context.world().rank()/2+1)*local_ext[1]-1, local_ext[2]-1}}); local_domains.push_back( domain_descriptor_type{ - comm.rank()*2+1, - std::array{ ((comm.rank()%2)*2+1)*local_ext[0], (comm.rank()/2 )*local_ext[1], 0}, - std::array{ ((comm.rank()%2)*2+2)*local_ext[0]-1, (comm.rank()/2+1)*local_ext[1]-1, local_ext[2]-1}}); + context.world().rank()*2+1, + std::array{ ((context.world().rank()%2)*2+1)*local_ext[0], (context.world().rank()/2 )*local_ext[1], 0}, + std::array{ ((context.world().rank()%2)*2+2)*local_ext[0]-1, (context.world().rank()/2+1)*local_ext[1]-1, local_ext[2]-1}}); // halo generators std::array halos1{0,0,1,0,1,2}; @@ -264,13 +255,8 @@ bool test0() auto halo_gen2 = domain_descriptor_type::halo_generator_type(g_first, g_last, halos2, periodic); // make patterns - auto pattern1 = gridtools::ghex::make_pattern(comm, halo_gen1, local_domains); - auto pattern2 = gridtools::ghex::make_pattern(comm, halo_gen2, local_domains); - - // communication object - auto co = gridtools::ghex::make_communication_object(); - auto co_1 = gridtools::ghex::make_communication_object(); - auto co_2 = gridtools::ghex::make_communication_object(); + auto pattern1 = gridtools::ghex::make_pattern(context, halo_gen1, local_domains); + auto pattern2 = gridtools::ghex::make_pattern(context, halo_gen2, local_domains); // wrap raw fields auto field_1a = gridtools::ghex::wrap_field(local_domains[0].domain_id(), field_1a_raw.data(), offset, local_ext_buffer); @@ -288,41 +274,6 @@ bool test0() fill_values(local_domains[0], field_3a); fill_values(local_domains[1], field_3b); - //// print arrays - //std::cout.flush(); - //comm.barrier(); - //for (int r=0; r(context.get_communicator(context.get_token())); co.bexchange( pattern1(field_1a_gpu), pattern1(field_1b), @@ -433,6 +385,7 @@ bool test0() pattern1(field_3b) ); #else + auto co = 
gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co.bexchange( pattern1(field_1a_gpu), pattern1(field_1b_gpu), @@ -444,6 +397,7 @@ bool test0() #endif #endif #ifdef GHEX_TEST_SERIAL_VECTOR + auto co = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); std::vector> field_vec{ pattern1(field_1a_gpu), pattern1(field_1b_gpu), @@ -455,11 +409,15 @@ bool test0() #endif #ifdef GHEX_TEST_SERIAL_SPLIT + auto token = context.get_token(); + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(token)); // non-blocking variant auto h1 = co_1.exchange(pattern1(field_1a_gpu), pattern2(field_2a_gpu), pattern1(field_3a_gpu)); #ifdef GHEX_HYBRID_TESTS + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(token)); auto h2 = co_2.exchange(pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)); #else + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(token)); auto h2 = co_2.exchange(pattern1(field_1b_gpu), pattern2(field_2b_gpu), pattern1(field_3b_gpu)); #endif // ... overlap communication (packing, posting) with computation here @@ -468,6 +426,9 @@ bool test0() h2.wait(); #endif #ifdef GHEX_TEST_SERIAL_SPLIT_VECTOR + auto token = context.get_token(); + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(token)); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(token)); std::vector> field_vec_a{ pattern1(field_1a_gpu), pattern2(field_2a_gpu), @@ -485,24 +446,25 @@ bool test0() #endif #ifdef GHEX_TEST_THREADS - auto func = [](decltype(co)& co_, auto... bis) + auto func = [&context](auto... 
bis) { + auto co_ = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co_.bexchange(bis...); }; // packing and posting may be done concurrently // waiting and unpacking may be done concurrently std::vector threads; - threads.push_back(std::thread{func, std::ref(co_1), + threads.push_back(std::thread{func, pattern1(field_1a_gpu), pattern2(field_2a_gpu), pattern1(field_3a_gpu)}); #ifdef GHEX_HYBRID_TESTS - threads.push_back(std::thread{func, std::ref(co_2), + threads.push_back(std::thread{func, pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)}); #else - threads.push_back(std::thread{func, std::ref(co_2), + threads.push_back(std::thread{func, pattern1(field_1b_gpu), pattern2(field_2b_gpu), pattern1(field_3b_gpu)}); @@ -512,8 +474,9 @@ bool test0() #endif #ifdef GHEX_TEST_THREADS_VECTOR using field_vec_type = std::vector>; - auto func = [](decltype(co)& co_, field_vec_type& vec) + auto func = [&context](field_vec_type& vec) { + auto co_ = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co_.exchange(vec.data(), vec.size()).wait(); }; // packing and posting may be done concurrently @@ -527,31 +490,32 @@ bool test0() pattern1(field_1b_gpu), pattern2(field_2b_gpu), pattern1(field_3b_gpu)}; - threads.push_back(std::thread{func, std::ref(co_1), std::ref(field_vec_a)}); - threads.push_back(std::thread{func, std::ref(co_2), std::ref(field_vec_b)}); + threads.push_back(std::thread{func, std::ref(field_vec_a)}); + threads.push_back(std::thread{func, std::ref(field_vec_b)}); // ... overlap communication with computation here for (auto& t : threads) t.join(); #endif #ifdef GHEX_TEST_ASYNC_ASYNC - auto func = [](decltype(co)& co_, auto... bis) + auto func = [&context](auto... 
bis) { + auto co_ = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co_.bexchange(bis...); }; // packing and posting may be done concurrently // waiting and unpacking may be done concurrently auto policy = std::launch::async; - auto future_1 = std::async(policy, func, std::ref(co_1), + auto future_1 = std::async(policy, func, pattern1(field_1a_gpu), pattern2(field_2a_gpu), pattern1(field_3a_gpu)); #ifdef GHEX_HYBRID_TESTS - auto future_2 = std::async(policy, func, std::ref(co_2), + auto future_2 = std::async(policy, func, pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)); #else - auto future_2 = std::async(policy, func, std::ref(co_2), + auto future_2 = std::async(policy, func, pattern1(field_1b_gpu), pattern2(field_2b_gpu), pattern1(field_3b_gpu)); @@ -562,8 +526,9 @@ bool test0() #endif #ifdef GHEX_TEST_ASYNC_ASYNC_VECTOR using field_vec_type = std::vector>; - auto func = [](decltype(co)& co_, field_vec_type& vec) + auto func = [](field_vec_type& vec) { + auto co_ = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co_.exchange(vec.data(), vec.size()).wait(); }; // packing and posting may be done concurrently @@ -577,32 +542,34 @@ bool test0() pattern1(field_1b_gpu), pattern2(field_2b_gpu), pattern1(field_3b_gpu)}; - auto future_1 = std::async(policy, func, std::ref(co_1), std::ref(field_vec_a)); - auto future_2 = std::async(policy, func, std::ref(co_2), std::ref(field_vec_b)); + auto future_1 = std::async(policy, func, std::ref(field_vec_a)); + auto future_2 = std::async(policy, func, std::ref(field_vec_b)); // ... overlap communication with computation here future_1.wait(); future_2.wait(); #endif #ifdef GHEX_TEST_ASYNC_DEFERRED - auto func_h = [](decltype(co)& co_, auto... bis) + auto func_h = [](auto co_, auto... 
bis) { - return co_.exchange(bis...); + return co_->exchange(bis...); }; + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); // packing and posting serially on current thread // waiting and unpacking serially on current thread auto policy = std::launch::deferred; - auto future_1 = std::async(policy, func_h, std::ref(co_1), + auto future_1 = std::async(policy, func_h, &co_1, pattern1(field_1a_gpu), pattern2(field_2a_gpu), pattern1(field_3a_gpu)); #ifdef GHEX_HYBRID_TESTS - auto future_2 = std::async(policy, func_h, std::ref(co_2), + auto future_2 = std::async(policy, func_h, &co_2, pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)); #else - auto future_2 = std::async(policy, func_h, std::ref(co_2), + auto future_2 = std::async(policy, func_h, &co_2, pattern1(field_1b_gpu), pattern2(field_2b_gpu), pattern1(field_3b_gpu)); @@ -617,10 +584,12 @@ bool test0() #endif #ifdef GHEX_TEST_ASYNC_DEFERRED_VECTOR using field_vec_type = std::vector>; - auto func_h = [](decltype(co)& co_, field_vec_type& vec) + auto func_h = [](auto co_, field_vec_type& vec) { - return co_.exchange(vec.data(), vec.size()); + return co_->exchange(vec.data(), vec.size()); }; + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); // packing and posting may be done concurrently // waiting and unpacking may be done concurrently auto policy = std::launch::deferred; @@ -632,8 +601,8 @@ bool test0() pattern1(field_1b_gpu), pattern2(field_2b_gpu), pattern1(field_3b_gpu)}; - auto future_1 = std::async(policy, func_h, std::ref(co_1), std::ref(field_vec_a)); - auto future_2 = std::async(policy, func_h, std::ref(co_2), std::ref(field_vec_b)); + auto future_1 = std::async(policy, func_h, &co_1, 
std::ref(field_vec_a)); + auto future_2 = std::async(policy, func_h, &co_2, std::ref(field_vec_b)); // deferred policy: essentially serial on current thread auto h1 = future_1.get(); auto h2 = future_2.get(); @@ -644,24 +613,26 @@ bool test0() #endif #ifdef GHEX_TEST_ASYNC_ASYNC_WAIT - auto func_h = [](decltype(co)& co_, auto... bis) + auto func_h = [](auto co_, auto... bis) { - return co_.exchange(bis...); + return co_->exchange(bis...); }; + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); // packing and posting may be done concurrently // waiting and unpacking serially auto policy = std::launch::async; - auto future_1 = std::async(policy, func_h, std::ref(co_1), + auto future_1 = std::async(policy, func_h, &co_1, pattern1(field_1a_gpu), pattern2(field_2a_gpu), pattern1(field_3a_gpu)); #ifdef GHEX_HYBRID_TESTS - auto future_2 = std::async(policy, func_h, std::ref(co_2), + auto future_2 = std::async(policy, func_h, &co_2, pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)); #else - auto future_2 = std::async(policy, func_h, std::ref(co_2), + auto future_2 = std::async(policy, func_h, &co_2, pattern1(field_1b_gpu), pattern2(field_2b_gpu), pattern1(field_3b_gpu)); @@ -673,10 +644,12 @@ bool test0() #endif #ifdef GHEX_TEST_ASYNC_ASYNC_WAIT_VECTOR using field_vec_type = std::vector>; - auto func_h = [](decltype(co)& co_, field_vec_type& vec) + auto func_h = [](auto co_, field_vec_type& vec) { - return co_.exchange(vec.data(), vec.size()); + return co_->exchange(vec.data(), vec.size()); }; + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); // packing and posting may be done concurrently // waiting and unpacking may be done concurrently auto policy = 
std::launch::async; @@ -688,8 +661,8 @@ bool test0() pattern1(field_1b_gpu), pattern2(field_2b_gpu), pattern1(field_3b_gpu)}; - auto future_1 = std::async(policy, func_h, std::ref(co_1), std::ref(field_vec_a)); - auto future_2 = std::async(policy, func_h, std::ref(co_2), std::ref(field_vec_b)); + auto future_1 = std::async(policy, func_h, &co_1, std::ref(field_vec_a)); + auto future_2 = std::async(policy, func_h, &co_2, std::ref(field_vec_b)); // ... overlap communication (packing, posting) with computation here // waiting and unpacking is serial here future_1.get().wait(); @@ -745,6 +718,7 @@ bool test0() // exchange #ifdef GHEX_TEST_SERIAL // blocking variant + auto co = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co.bexchange( pattern1(field_1a), pattern1(field_1b), @@ -755,6 +729,7 @@ bool test0() ); #endif #ifdef GHEX_TEST_SERIAL_VECTOR + auto co = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); std::vector> field_vec{ pattern1(field_1a), pattern1(field_1b), @@ -767,6 +742,9 @@ bool test0() #ifdef GHEX_TEST_SERIAL_SPLIT // non-blocking variant + auto token = context.get_token(); + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(token)); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(token)); auto h1 = co_1.exchange(pattern1(field_1a), pattern2(field_2a), pattern1(field_3a)); auto h2 = co_2.exchange(pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)); // ... 
overlap communication (packing, posting) with computation here @@ -775,6 +753,9 @@ bool test0() h2.wait(); #endif #ifdef GHEX_TEST_SERIAL_SPLIT_VECTOR + auto token = context.get_token(); + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(token)); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(token)); std::vector> field_vec_a{ pattern1(field_1a), pattern2(field_2a), @@ -792,18 +773,19 @@ bool test0() #endif #ifdef GHEX_TEST_THREADS - auto func = [](decltype(co)& co_, auto... bis) + auto func = [&context](auto... bis) { + auto co_ = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co_.bexchange(bis...); }; // packing and posting may be done concurrently // waiting and unpacking may be done concurrently std::vector threads; - threads.push_back(std::thread{func, std::ref(co_1), + threads.push_back(std::thread{func, pattern1(field_1a), pattern2(field_2a), pattern1(field_3a)}); - threads.push_back(std::thread{func, std::ref(co_2), + threads.push_back(std::thread{func, pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)}); @@ -812,8 +794,9 @@ bool test0() #endif #ifdef GHEX_TEST_THREADS_VECTOR using field_vec_type = std::vector>; - auto func = [](decltype(co)& co_, field_vec_type& vec) + auto func = [&context](field_vec_type& vec) { + auto co_ = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co_.exchange(vec.data(), vec.size()).wait(); }; // packing and posting may be done concurrently @@ -827,25 +810,26 @@ bool test0() pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)}; - threads.push_back(std::thread{func, std::ref(co_1), std::ref(field_vec_a)}); - threads.push_back(std::thread{func, std::ref(co_2), std::ref(field_vec_b)}); + threads.push_back(std::thread{func, std::ref(field_vec_a)}); + threads.push_back(std::thread{func, std::ref(field_vec_b)}); // ... 
overlap communication with computation here for (auto& t : threads) t.join(); #endif #ifdef GHEX_TEST_ASYNC_ASYNC - auto func = [](decltype(co)& co_, auto... bis) + auto func = [&context](auto... bis) { + auto co_ = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co_.bexchange(bis...); }; // packing and posting may be done concurrently // waiting and unpacking may be done concurrently auto policy = std::launch::async; - auto future_1 = std::async(policy, func, std::ref(co_1), + auto future_1 = std::async(policy, func, pattern1(field_1a), pattern2(field_2a), pattern1(field_3a)); - auto future_2 = std::async(policy, func, std::ref(co_2), + auto future_2 = std::async(policy, func, pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)); @@ -855,8 +839,9 @@ bool test0() #endif #ifdef GHEX_TEST_ASYNC_ASYNC_VECTOR using field_vec_type = std::vector>; - auto func = [](decltype(co)& co_, field_vec_type& vec) + auto func = [&context](field_vec_type& vec) { + auto co_ = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); co_.exchange(vec.data(), vec.size()).wait(); }; // packing and posting may be done concurrently @@ -870,26 +855,28 @@ bool test0() pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)}; - auto future_1 = std::async(policy, func, std::ref(co_1), std::ref(field_vec_a)); - auto future_2 = std::async(policy, func, std::ref(co_2), std::ref(field_vec_b)); + auto future_1 = std::async(policy, func, std::ref(field_vec_a)); + auto future_2 = std::async(policy, func, std::ref(field_vec_b)); // ... overlap communication with computation here future_1.wait(); future_2.wait(); #endif #ifdef GHEX_TEST_ASYNC_DEFERRED - auto func_h = [](decltype(co)& co_, auto... bis) + auto func_h = [](auto co_, auto... 
bis) { - return co_.exchange(bis...); + return co_->exchange(bis...); }; + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); // packing and posting serially on current thread // waiting and unpacking serially on current thread auto policy = std::launch::deferred; - auto future_1 = std::async(policy, func_h, std::ref(co_1), + auto future_1 = std::async(policy, func_h, &co_1, pattern1(field_1a), pattern2(field_2a), pattern1(field_3a)); - auto future_2 = std::async(policy, func_h, std::ref(co_2), + auto future_2 = std::async(policy, func_h, &co_2, pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)); @@ -903,10 +890,12 @@ bool test0() #endif #ifdef GHEX_TEST_ASYNC_DEFERRED_VECTOR using field_vec_type = std::vector>; - auto func_h = [](decltype(co)& co_, field_vec_type& vec) + auto func_h = [](auto co_, field_vec_type& vec) { - return co_.exchange(vec.data(), vec.size()); + return co_->exchange(vec.data(), vec.size()); }; + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); // packing and posting may be done concurrently // waiting and unpacking may be done concurrently auto policy = std::launch::deferred; @@ -918,8 +907,8 @@ bool test0() pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)}; - auto future_1 = std::async(policy, func_h, std::ref(co_1), std::ref(field_vec_a)); - auto future_2 = std::async(policy, func_h, std::ref(co_2), std::ref(field_vec_b)); + auto future_1 = std::async(policy, func_h, &co_1, std::ref(field_vec_a)); + auto future_2 = std::async(policy, func_h, &co_2, std::ref(field_vec_b)); // deferred policy: essentially serial on current thread auto h1 = future_1.get(); auto h2 = future_2.get(); @@ -930,18 +919,20 @@ bool 
test0() #endif #ifdef GHEX_TEST_ASYNC_ASYNC_WAIT - auto func_h = [](decltype(co)& co_, auto... bis) + auto func_h = [](auto co_, auto... bis) { - return co_.exchange(bis...); + return co_->exchange(bis...); }; + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); // packing and posting may be done concurrently // waiting and unpacking serially auto policy = std::launch::async; - auto future_1 = std::async(policy, func_h, std::ref(co_1), + auto future_1 = std::async(policy, func_h, &co_1, pattern1(field_1a), pattern2(field_2a), pattern1(field_3a)); - auto future_2 = std::async(policy, func_h, std::ref(co_2), + auto future_2 = std::async(policy, func_h, &co_2, pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)); @@ -952,10 +943,12 @@ bool test0() #endif #ifdef GHEX_TEST_ASYNC_ASYNC_WAIT_VECTOR using field_vec_type = std::vector>; - auto func_h = [](decltype(co)& co_, field_vec_type& vec) + auto func_h = [](auto co_, field_vec_type& vec) { - return co_.exchange(vec.data(), vec.size()); + return co_->exchange(vec.data(), vec.size()); }; + auto co_1 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); + auto co_2 = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); // packing and posting may be done concurrently // waiting and unpacking may be done concurrently auto policy = std::launch::async; @@ -967,8 +960,8 @@ bool test0() pattern1(field_1b), pattern2(field_2b), pattern1(field_3b)}; - auto future_1 = std::async(policy, func_h, std::ref(co_1), std::ref(field_vec_a)); - auto future_2 = std::async(policy, func_h, std::ref(co_2), std::ref(field_vec_b)); + auto future_1 = std::async(policy, func_h, &co_1, std::ref(field_vec_a)); + auto future_2 = std::async(policy, func_h, &co_2, std::ref(field_vec_b)); // ... 
overlap communication (packing, posting) with computation here // waiting and unpacking is serial here future_1.get().wait(); @@ -977,49 +970,13 @@ bool test0() } - //// print arrays - //std::cout.flush(); - //comm.barrier(); - //for (int r=0; r(local_domains[0], halos1, periodic, g_first, g_last, field_1a, comm); - passed = passed && test_values(local_domains[1], halos1, periodic, g_first, g_last, field_1b, comm); - passed = passed && test_values(local_domains[0], halos2, periodic, g_first, g_last, field_2a, comm); - passed = passed && test_values(local_domains[1], halos2, periodic, g_first, g_last, field_2b, comm); - passed = passed && test_values(local_domains[0], halos1, periodic, g_first, g_last, field_3a, comm); - passed = passed && test_values(local_domains[1], halos1, periodic, g_first, g_last, field_3b, comm); + passed = passed && test_values(local_domains[0], halos1, periodic, g_first, g_last, field_1a, context.world()); + passed = passed && test_values(local_domains[1], halos1, periodic, g_first, g_last, field_1b, context.world()); + passed = passed && test_values(local_domains[0], halos2, periodic, g_first, g_last, field_2a, context.world()); + passed = passed && test_values(local_domains[1], halos2, periodic, g_first, g_last, field_2b, context.world()); + passed = passed && test_values(local_domains[0], halos1, periodic, g_first, g_last, field_3a, context.world()); + passed = passed && test_values(local_domains[1], halos1, periodic, g_first, g_last, field_3b, context.world()); #ifdef STANDALONE if (passed) @@ -1032,32 +989,4 @@ bool test0() #endif } -#ifdef STANDALONE -#include -int main(int argc, char* argv[]) -{ - //MPI_Init(&argc,&argv); -#ifdef MULTI_THREADED_EXCHANGE - int provided; - int res = MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); - if (res == MPI_ERR_OTHER) - { - throw std::runtime_error("MPI init failed"); - } - if (provided < MPI_THREAD_MULTIPLE) - { - throw std::runtime_error("MPI does not support threading"); - } 
-#else - boost::mpi::environment env(argc, argv); -#endif - - auto passed = test0(); - -#ifdef MULTI_THREADED_EXCHANGE - MPI_Finalize(); -#endif - return 0; -} -#endif diff --git a/tests/data_store_test.cpp b/tests/data_store_test.cpp index 518437e..ae21f77 100644 --- a/tests/data_store_test.cpp +++ b/tests/data_store_test.cpp @@ -11,10 +11,15 @@ #include #include -#include +#include +#include #include #include +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; + TEST(data_store, make) { const int Nx0 = 10; @@ -38,10 +43,11 @@ TEST(data_store, make) MPI_Cart_create(MPI_COMM_WORLD, 3, &dimensions[0], period, false, &CartComm); const std::array extents{Nx0,Ny0,Nz0}; - //auto grid = gridtools::make_gt_processor_grid>(extents, periodicity, CartComm); - auto grid = gridtools::ghex::make_gt_processor_grid(extents, periodicity, CartComm); + context_type context(1, CartComm); + + auto grid = gridtools::ghex::make_gt_processor_grid(context, extents, periodicity); auto pattern1 = gridtools::ghex::make_gt_pattern(grid, std::array{1,1,1,1,0,0}); - auto co = gridtools::ghex::make_communication_object(); + auto co = gridtools::ghex::make_communication_object(context.get_communicator(context.get_token())); using host_backend_t = gridtools::backend::mc; using host_storage_info_t = gridtools::storage_traits::storage_info_t<0, 3, halo_t>; diff --git a/tests/transport/test_attach_detach.cpp b/tests/transport/test_attach_detach.cpp index 5413665..659cca9 100644 --- a/tests/transport/test_attach_detach.cpp +++ b/tests/transport/test_attach_detach.cpp @@ -10,14 +10,18 @@ */ #include -#include +#include +#include #include #include #include +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; using allocator_type = std::allocator; -using comm_type = 
gridtools::ghex::tl::communicator; +using comm_type = context_type::communicator_type; using callback_comm_type = gridtools::ghex::tl::callback_communicator; //using callback_comm_type = gridtools::ghex::tl::callback_communicator_ts; using message_type = typename callback_comm_type::message_type; @@ -27,7 +31,9 @@ const unsigned int SIZE = 1<<12; TEST(attach, attach_progress) { bool ok = true; - comm_type comm; + + context_type context(1,MPI_COMM_WORLD); + auto comm = context.get_communicator(context.get_token()); callback_comm_type cb_comm(comm); int cb_count = 0; @@ -66,7 +72,8 @@ TEST(attach, attach_progress) TEST(detach, detach_wait) { bool ok = true; - comm_type comm; + context_type context(1,MPI_COMM_WORLD); + auto comm = context.get_communicator(context.get_token()); callback_comm_type cb_comm(comm); int cb_count = 0; @@ -111,7 +118,8 @@ TEST(detach, detach_wait) TEST(detach, detach_cancel_unexpected) { bool ok = true; - comm_type comm; + context_type context(1,MPI_COMM_WORLD); + auto comm = context.get_communicator(context.get_token()); callback_comm_type cb_comm(comm); message_type send_msg{SIZE}; diff --git a/tests/transport/test_cancel_request.cpp b/tests/transport/test_cancel_request.cpp index c8ac731..3de7b96 100644 --- a/tests/transport/test_cancel_request.cpp +++ b/tests/transport/test_cancel_request.cpp @@ -1,5 +1,6 @@ #include -#include +#include +#include #include #include #include @@ -11,10 +12,15 @@ template using callback_comm_t = gridtools::ghex::tl::callback_communicator; //using callback_comm_t = gridtools::ghex::tl::callback_communicator_ts; +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; + int rank; const unsigned int SIZE = 1<<12; -bool test_simple(gridtools::ghex::tl::communicator &comm, int rank) { +template +bool test_simple(Comm& comm, int rank) { using allocator_type = std::allocator; using smsg_type = 
gridtools::ghex::tl::shared_message_buffer; @@ -55,7 +61,8 @@ bool test_simple(gridtools::ghex::tl::communicator } -bool test_single(gridtools::ghex::tl::communicator &comm, int rank) { +template +bool test_single(Comm& comm, int rank) { using allocator_type = std::allocator; using smsg_type = gridtools::ghex::tl::shared_message_buffer; @@ -138,7 +145,8 @@ class call_back { } }; -bool test_send_10(gridtools::ghex::tl::communicator &comm, int rank) { +template +bool test_send_10(Comm& comm, int rank) { using allocator_type = std::allocator; using smsg_type = gridtools::ghex::tl::shared_message_buffer; @@ -187,31 +195,21 @@ TEST(transport, check_mpi_ranks_eq_4) { } TEST(transport, cancel_requests_reposting) { - - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - - gridtools::ghex::tl::communicator comm; - - EXPECT_TRUE(test_send_10(comm, rank)); - + context_type context(1,MPI_COMM_WORLD); + auto comm = context.get_communicator(context.get_token()); + EXPECT_TRUE(test_send_10(comm, context.world().rank())); } TEST(transport, cancel_requests_simple) { - - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - - gridtools::ghex::tl::communicator comm; - - EXPECT_TRUE(test_simple(comm, rank)); + context_type context(1,MPI_COMM_WORLD); + auto comm = context.get_communicator(context.get_token()); + EXPECT_TRUE(test_simple(comm, context.world().rank())); } TEST(transport, cancel_single_request) { - - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - - gridtools::ghex::tl::communicator comm; - - EXPECT_TRUE(test_single(comm, rank)); + context_type context(1,MPI_COMM_WORLD); + auto comm = context.get_communicator(context.get_token()); + EXPECT_TRUE(test_single(comm, context.world().rank())); } diff --git a/tests/transport/test_low_level.cpp b/tests/transport/test_low_level.cpp index 5ddda43..5824a18 100644 --- a/tests/transport/test_low_level.cpp +++ b/tests/transport/test_low_level.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include #include #include @@ -9,6 +11,10 @@ template using callback_comm_t = 
gridtools::ghex::tl::callback_communicator; //using callback_comm_t = gridtools::ghex::tl::callback_communicator_ts; +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; + int rank; /** @@ -16,7 +22,10 @@ int rank; */ void test1() { - gridtools::ghex::tl::communicator sr; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); std::vector smsg = {1,2,3,4,5,6,7,8,9,10}; std::vector rmsg(10); @@ -50,7 +59,10 @@ void test1() { } void test2() { - gridtools::ghex::tl::communicator sr; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); using allocator_type = std::allocator; using smsg_type = gridtools::ghex::tl::shared_message_buffer; @@ -96,7 +108,10 @@ void test2() { } void test1_mesg() { - gridtools::ghex::tl::communicator sr; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); gridtools::ghex::tl::message_buffer<> smsg{40}; @@ -136,7 +151,10 @@ void test1_mesg() { } void test2_mesg() { - gridtools::ghex::tl::communicator sr; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); using allocator_type = std::allocator; using smsg_type = gridtools::ghex::tl::shared_message_buffer; using comm_type = std::remove_reference_t; @@ -189,7 +207,10 @@ void test2_mesg() { } void test1_shared_mesg() { - gridtools::ghex::tl::communicator sr; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); gridtools::ghex::tl::message_buffer<> smsg{40}; diff --git 
a/tests/transport/test_low_level_x.cpp b/tests/transport/test_low_level_x.cpp index cca6d85..de586be 100644 --- a/tests/transport/test_low_level_x.cpp +++ b/tests/transport/test_low_level_x.cpp @@ -1,5 +1,6 @@ #include -#include +#include +#include #include #include #include @@ -10,6 +11,10 @@ template using callback_comm_t = gridtools::ghex::tl::callback_communicator; //using callback_comm_t = gridtools::ghex::tl::callback_communicator_ts; +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; + /** * Simple Send recv on two ranks. * P0 sends a message to P1 and receive from P1, @@ -19,8 +24,11 @@ using callback_comm_t = gridtools::ghex::tl::callback_communicator; int rank; auto test1() { - using comm_type = gridtools::ghex::tl::communicator; - comm_type sr; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); + using comm_type = std::remove_reference_t; std::vector smsg = {0,0,0,0,1,0,0,0,2,0,0,0,3,0,0,0,4,0,0,0,5,0,0,0,6,0,0,0,7,0,0,0,8,0,0,0,9,0,0,0}; std::vector rmsg(40, 40); @@ -54,12 +62,15 @@ auto test1() { } auto test2() { - using sr_comm_type = gridtools::ghex::tl::communicator; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); + using sr_comm_type = std::remove_reference_t; using allocator_type = std::allocator; using smsg_type = gridtools::ghex::tl::shared_message_buffer; using cb_comm_type = callback_comm_t; - sr_comm_type sr; cb_comm_type cb_comm(sr); std::vector smsg = {0,0,0,0,1,0,0,0,2,0,0,0,3,0,0,0,4,0,0,0,5,0,0,0,6,0,0,0,7,0,0,0,8,0,0,0,9,0,0,0}; @@ -99,7 +110,11 @@ auto test2() { } auto test1_mesg() { - gridtools::ghex::tl::communicator sr; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + 
EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); + using sr_comm_type = std::remove_reference_t; gridtools::ghex::tl::message_buffer<> smsg{40}; @@ -111,7 +126,7 @@ auto test1_mesg() { gridtools::ghex::tl::message_buffer<> rmsg{40}; - gridtools::ghex::tl::communicator::future rfut; + sr_comm_type::future rfut; if ( rank == 0 ) { sr.send(smsg, 1, 1).get(); @@ -141,12 +156,15 @@ auto test1_mesg() { } auto test2_mesg() { - using sr_comm_type = gridtools::ghex::tl::communicator; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); + using sr_comm_type = std::remove_reference_t; using allocator_type = std::allocator; using smsg_type = gridtools::ghex::tl::shared_message_buffer; using cb_comm_type = callback_comm_t; - sr_comm_type sr; cb_comm_type cb_comm(sr); gridtools::ghex::tl::message_buffer<> smsg{40}; @@ -194,7 +212,11 @@ auto test2_mesg() { } auto test1_shared_mesg() { - gridtools::ghex::tl::communicator sr; + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); + using sr_comm_type = std::remove_reference_t; gridtools::ghex::tl::shared_message_buffer<> smsg{40}; int* data = smsg.data(); @@ -205,7 +227,7 @@ auto test1_shared_mesg() { gridtools::ghex::tl::shared_message_buffer<> rmsg{40}; - gridtools::ghex::tl::communicator::future rfut; + sr_comm_type::future rfut; if ( rank == 0 ) { auto sf = sr.send(smsg, 1, 1); diff --git a/tests/transport/test_send_multi.cpp b/tests/transport/test_send_multi.cpp index 67e538a..2b05e9a 100644 --- a/tests/transport/test_send_multi.cpp +++ b/tests/transport/test_send_multi.cpp @@ -1,5 +1,6 @@ #include -#include +#include +#include #include #include @@ -9,6 +10,10 @@ template using callback_comm_t = gridtools::ghex::tl::callback_communicator; //using callback_comm_t = 
gridtools::ghex::tl::callback_communicator_ts; +using transport = gridtools::ghex::tl::mpi_tag; +using threading = gridtools::ghex::threads::atomic::primitives; +using context_type = gridtools::ghex::tl::context; + const int SIZE = 4000000; int mpi_rank; @@ -22,13 +27,13 @@ TEST(transport, send_multi) { EXPECT_EQ(size, 4); } + context_type context(1,MPI_COMM_WORLD); + auto token = context.get_token(); + mpi_rank = context.world().rank(); + context.barrier(token); - MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank); - - MPI_Barrier(MPI_COMM_WORLD); - - using comm_type = gridtools::ghex::tl::communicator; - comm_type comm; + auto comm = context.get_communicator(token); + using comm_type = std::remove_reference_t; using allocator_type = std::allocator; using smsg_type = gridtools::ghex::tl::shared_message_buffer;