Skip to content

Commit

Permalink
Use communicator grid pipelines in tridiagonal eigensolver rot
Browse files Browse the repository at this point in the history
  • Loading branch information
msimberg committed Sep 26, 2023
1 parent 63152c0 commit 787dfef
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
4 changes: 1 addition & 3 deletions include/dlaf/eigensolver/tridiag_solver/merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -1463,9 +1463,7 @@ void mergeDistSubproblems(comm::CommunicatorGrid& grid,
//
// Note: i_split is unique
const comm::IndexT_MPI tag = to_int(i_split);
// TODO: No cloning of communicator! It probably should use the pipeline.
applyGivensRotationsToMatrixColumns(grid.rowCommunicator(), tag, i_begin, i_end, std::move(rots),
ws.e0);
applyGivensRotationsToMatrixColumns(grid, row_task_chain, tag, i_begin, i_end, std::move(rots), ws.e0);
// Placeholder for rearranging the eigenvectors: (local permutation)
copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws.e1);

Expand Down
39 changes: 20 additions & 19 deletions include/dlaf/eigensolver/tridiag_solver/rot.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ namespace dlaf::eigensolver::internal {
namespace wrapper {

template <Device D, class T>
void sendCol(comm::Communicator& comm, const comm::IndexT_MPI rank_dest, const comm::IndexT_MPI tag,
const T* col_data, const SizeType n, MPI_Request* req) {
void sendCol(const comm::Communicator& comm, const comm::IndexT_MPI rank_dest,
const comm::IndexT_MPI tag, const T* col_data, const SizeType n, MPI_Request* req) {
static_assert(D == Device::CPU, "This function works just with CPU memory.");

DLAF_MPI_CHECK_ERROR(MPI_Isend(col_data, static_cast<int>(n), dlaf::comm::mpi_datatype<T>::type,
rank_dest, tag, comm, req));
}

template <Device D, class T>
void recvCol(comm::Communicator& comm, const comm::IndexT_MPI rank_dest, const comm::IndexT_MPI tag,
T* col_data, const SizeType n, MPI_Request* req) {
void recvCol(const comm::Communicator& comm, const comm::IndexT_MPI rank_dest,
const comm::IndexT_MPI tag, T* col_data, const SizeType n, MPI_Request* req) {
static_assert(D == Device::CPU, "This function works just with CPU memory.");

DLAF_MPI_CHECK_ERROR(MPI_Irecv(col_data, static_cast<int>(n), dlaf::comm::mpi_datatype<T>::type,
Expand Down Expand Up @@ -194,7 +194,7 @@ void applyGivensRotationsToMatrixColumns(const SizeType i_begin, const SizeType
/// Apply Givens rotations to the tiles of the distributed square sub-matrix identified by the tile
/// index range [i_begin, i_end).
///
/// @param comm_row row communicator
/// @param comm_row_chain row communicator pipeline
/// @param tag is used for all communications happening over @p comm_row_chain
/// @param i_begin global tile index for both row and column identifying the start of the sub-matrix
/// @param i_end global tile index for both row and column identifying the end of the sub-matrix
Expand All @@ -205,10 +205,10 @@ void applyGivensRotationsToMatrixColumns(const SizeType i_begin, const SizeType
/// @pre @p mat is distributed along rows the same way as the row communicator of @p grid
/// @pre memory layout of @p mat is column major.
template <class T, Device D, class GRSender>
// TODO
void applyGivensRotationsToMatrixColumns(comm::Communicator comm_row, const comm::IndexT_MPI tag,
const SizeType i_begin, const SizeType i_end,
GRSender&& rots_fut, Matrix<T, D>& mat) {
void applyGivensRotationsToMatrixColumns(comm::CommunicatorGrid& grid,
common::Pipeline<comm::Communicator>& comm_row_chain,
const comm::IndexT_MPI tag, const SizeType i_begin,
const SizeType i_end, GRSender&& rots_fut, Matrix<T, D>& mat) {
// Note:
// a column index may be paired to more than one other index, this may lead to a race
// condition if parallelized trivially. Current implementation is serial.
Expand All @@ -217,9 +217,10 @@ void applyGivensRotationsToMatrixColumns(comm::Communicator comm_row, const comm
namespace tt = pika::this_thread::experimental;
namespace di = dlaf::internal;

DLAF_ASSERT_HEAVY(comm_row.size() == mat.commGridSize().cols(), comm_row.size(),
mat.commGridSize().cols());
DLAF_ASSERT_HEAVY(comm_row.rank() == mat.rankIndex().col(), comm_row.rank(), mat.rankIndex().col());
DLAF_ASSERT_HEAVY(grid.rowCommunicator().size() == mat.commGridSize().cols(),
grid.rowCommunicator().size(), mat.commGridSize().cols());
DLAF_ASSERT_HEAVY(grid.rowCommunicator().rank() == mat.rankIndex().col(),
grid.rowCommunicator().rank(), mat.rankIndex().col());

const matrix::Distribution& dist = mat.distribution();

Expand Down Expand Up @@ -253,9 +254,9 @@ void applyGivensRotationsToMatrixColumns(comm::Communicator comm_row, const comm
const matrix::Distribution dist_sub({range_size, range_size}, dist.blockSize(), dist.commGridSize(),
dist.rankIndex(), dist.rankGlobalTile({i_begin, i_begin}));

auto givens_rots_fn = [comm_row, tag, dist_sub, mb](std::vector<GivensRotation<T>> rots,
std::vector<matrix::Tile<T, D>> tiles,
std::vector<matrix::Tile<T, D>> all_ws) {
auto givens_rots_fn = [&comm_row_chain, tag, dist_sub,
mb](std::vector<GivensRotation<T>> rots, std::vector<matrix::Tile<T, D>> tiles,
std::vector<matrix::Tile<T, D>> all_ws) mutable {
// Note:
// It would have been enough to just get the first tile from the beginning, and it would have
// worked anyway (thanks to the fact that panel has its own memorychunk and the first tile would
Expand Down Expand Up @@ -314,10 +315,10 @@ void applyGivensRotationsToMatrixColumns(comm::Communicator comm_row, const comm
// Note:
// These communications use raw pointers, so correct lifetime management of related tiles
// is up to the caller.
comm_checkpoints.emplace_back(wrapper::scheduleSendCol<D, T>(comm_row, rank_partner, tag,
col_send, m));
comm_checkpoints.emplace_back(wrapper::scheduleRecvCol<D, T>(comm_row, rank_partner, tag,
col_recv, m));
comm_checkpoints.emplace_back(wrapper::scheduleSendCol<D, T>(comm_row_chain.read(), rank_partner,
tag, col_send, m));
comm_checkpoints.emplace_back(wrapper::scheduleRecvCol<D, T>(comm_row_chain.read(), rank_partner,
tag, col_recv, m));
}

// Note:
Expand Down
3 changes: 2 additions & 1 deletion test/unit/eigensolver/test_tridiag_solver_rot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ void testApplyGivenRotations(comm::CommunicatorGrid& grid, const SizeType m, con

{
matrix::MatrixMirror<T, D, Device::CPU> mat(mat_h);
applyGivensRotationsToMatrixColumns(grid.rowCommunicator(), tag, idx_begin, idx_end, ex::just(rots),
auto comm_row_chain = grid.row_communicator_pipeline();
applyGivensRotationsToMatrixColumns(grid, comm_row_chain, tag, idx_begin, idx_end, ex::just(rots),
mat.get());
}

Expand Down

0 comments on commit 787dfef

Please sign in to comment.