Skip to content

Commit

Permalink
Use execution::unpack from pika 0.17.0 to unpack tuples in permutations implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
msimberg committed Aug 7, 2023
1 parent 68e3064 commit 37d1f1a
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions include/dlaf/permutations/general/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ void Permutations<B, D, T, C>::call(const SizeType i_begin, const SizeType i_end
std::move(mat_out_tiles));
};

auto permute_fn = [subm_dist](const auto i_perm, const auto& args) {
auto& [splits, index_tile_futs, mat_in_tiles, mat_out_tiles] = args;
auto permute_fn = [subm_dist](const auto i_perm, const auto& splits, const auto& index_tile_futs,
const auto& mat_in_tiles, const auto& mat_out_tiles) {
const TileElementIndex zero(0, 0);
const SizeType* perm_arr = index_tile_futs[0].get().ptr(zero);
const GlobalElementIndex out_begin{0, 0};
Expand All @@ -173,7 +173,7 @@ void Permutations<B, D, T, C>::call(const SizeType i_begin, const SizeType i_end
ex::start_detached(std::move(sender) |
dlaf::internal::transform(dlaf::internal::Policy<B>(),
std::move(setup_permute_fn)) |
ex::bulk(subm_dist.size().get<C>(), std::move(permute_fn)));
ex::unpack() | ex::bulk(subm_dist.size().get<C>(), std::move(permute_fn)));
}
else {
#if defined(DLAF_WITH_GPU)
Expand Down Expand Up @@ -393,8 +393,8 @@ void applyPackingIndex(const matrix::Distribution& subm_dist, IndexMapSender&& i
std::move(mat_out_tiles));
};

auto permute_fn = [subm_dist](const auto i_perm, const auto& args) {
auto& [splits, index_tile_futs, mat_in_tiles, mat_out_tiles] = args;
auto permute_fn = [subm_dist](const auto i_perm, const auto& splits, const auto& index_tile_futs,
const auto& mat_in_tiles, const auto& mat_out_tiles) {
TileElementIndex zero(0, 0);
const SizeType* perm_arr = index_tile_futs[0].get().ptr(zero);
const GlobalElementIndex out_begin{0, 0};
Expand All @@ -411,7 +411,7 @@ void applyPackingIndex(const matrix::Distribution& subm_dist, IndexMapSender&& i
ex::start_detached(std::move(sender) |
dlaf::internal::transform(dlaf::internal::Policy<Backend::MC>(),
std::move(setup_permute_fn)) |
ex::bulk(subm_dist.size().get<C>(), permute_fn));
ex::unpack() | ex::bulk(subm_dist.size().get<C>(), permute_fn));
}
else {
#if defined(DLAF_WITH_GPU)
Expand Down Expand Up @@ -480,8 +480,9 @@ void unpackLocalOnCPU(const matrix::Distribution& subm_dist, const matrix::Distr
std::move(mat_out_tiles));
};

auto permutations_unpack_local_f = [subm_dist](const auto i_perm, const auto& args) {
auto& [a, b, splits, perm_offseted, mat_in_tiles, mat_out_tiles] = args;
auto permutations_unpack_local_f = [subm_dist](const auto i_perm, const auto a, const auto b,
const auto& splits, const auto& perm_offseted,
const auto& mat_in_tiles, const auto& mat_out_tiles) {
const SizeType* perm_arr = perm_offseted.data();

// [a, b)
Expand All @@ -493,13 +494,13 @@ void unpackLocalOnCPU(const matrix::Distribution& subm_dist, const matrix::Distr
}
};

ex::start_detached(ex::when_all(std::forward<SendCountsSender>(send_counts),
std::forward<RecvCountsSender>(recv_counts),
std::forward<UnpackingIndexSender>(unpacking_index),
std::forward<MatSendSender>(mat_send),
std::forward<MatOutSender>(mat_out)) |
di::transform(di::Policy<Backend::MC>(), std::move(setup_unpack_local_f)) |
ex::bulk(subm_dist.size().get<C>(), std::move(permutations_unpack_local_f)));
ex::start_detached(
ex::when_all(std::forward<SendCountsSender>(send_counts),
std::forward<RecvCountsSender>(recv_counts),
std::forward<UnpackingIndexSender>(unpacking_index),
std::forward<MatSendSender>(mat_send), std::forward<MatOutSender>(mat_out)) |
di::transform(di::Policy<Backend::MC>(), std::move(setup_unpack_local_f)) | ex::unpack() |
ex::bulk(subm_dist.size().get<C>(), std::move(permutations_unpack_local_f)));
}

template <class T, Coord C, class RecvCountsSender, class UnpackingIndexSender, class MatRecvSender,
Expand Down Expand Up @@ -531,8 +532,9 @@ void unpackOthersOnCPU(const matrix::Distribution& subm_dist, const matrix::Dist
std::move(mat_out_tiles));
};

auto permutations_unpack_f = [subm_dist](const auto i_perm, const auto& args) {
auto& [a, b, splits, index_tile_futs, mat_in_tiles, mat_out_tiles] = args;
auto permutations_unpack_f = [subm_dist](const auto i_perm, const auto a, const auto b,
const auto& splits, const auto& index_tile_futs,
const auto& mat_in_tiles, const auto& mat_out_tiles) {
const SizeType* perm_arr = index_tile_futs[0].get().ptr();

// [0, a) and [b, end)
Expand All @@ -548,7 +550,7 @@ void unpackOthersOnCPU(const matrix::Distribution& subm_dist, const matrix::Dist
std::forward<UnpackingIndexSender>(unpacking_index),
std::forward<MatRecvSender>(mat_recv),
std::forward<MatOutSender>(mat_out)) |
di::transform(di::Policy<Backend::MC>(), std::move(setup_unpack_f)) |
di::transform(di::Policy<Backend::MC>(), std::move(setup_unpack_f)) | ex::unpack() |
ex::bulk(subm_dist.size().get<C>(), std::move(permutations_unpack_f)));
}

Expand Down

0 comments on commit 37d1f1a

Please sign in to comment.