start factoring out gemm
albestro committed Dec 12, 2023
1 parent e0b8b9e commit 4f31507
Showing 1 changed file with 60 additions and 47 deletions.
107 changes: 60 additions & 47 deletions include/dlaf/eigensolver/tridiag_solver/merge.h
@@ -1650,6 +1650,64 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S
}));
}

template <Backend B, class T, Device D, class KLcSender, class UDLSenders>
void multiplyEigenvectors(const matrix::Distribution& dist_sub,
                          common::Pipeline<comm::Communicator>& row_task_chain,
                          common::Pipeline<comm::Communicator>& col_task_chain,
                          const GlobalElementIndex sub_offset, const SizeType n, const SizeType n_upper,
                          const SizeType n_lower, Matrix<T, D>& e0, Matrix<T, D>& e1, Matrix<T, D>& e2,
                          KLcSender&& k_lc, UDLSenders&& n_udl) {
  namespace ex = pika::execution::experimental;

  ex::start_detached(
      ex::when_all(std::forward<KLcSender>(k_lc), std::forward<UDLSenders>(n_udl), row_task_chain(),
                   col_task_chain()) |
      ex::transfer(dlaf::internal::getBackendScheduler<Backend::MC>()) |
      ex::then([dist_sub, sub_offset, n, n_upper, n_lower, e0 = e0.subPipeline(),
                e1 = e1.subPipelineConst(),
                e2 = e2.subPipelineConst()](const SizeType k_lc, const std::array<SizeType, 3>& n_udl,
                                            auto&& row_comm_wrapper, auto&& col_comm_wrapper) mutable {
        using dlaf::matrix::internal::MatrixRef;

        common::Pipeline<comm::Communicator> sub_comm_row(row_comm_wrapper.get());
        common::Pipeline<comm::Communicator> sub_comm_col(col_comm_wrapper.get());

        const auto [a, b, c] = n_udl;

        using GEMM = dlaf::multiplication::internal::General<B, D, T>;
        {
          MatrixRef<const T, D> e1_sub(e1, {sub_offset, {n_upper, b}});
          MatrixRef<const T, D> e2_sub(e2, {sub_offset, {b, c}});
          MatrixRef<T, D> e0_sub(e0, {sub_offset, {n_upper, c}});

          GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
        }

        {
          MatrixRef<const T, D> e1_sub(e1, {{sub_offset.row() + n_upper, sub_offset.col() + a},
                                            {n_lower, c - a}});
          MatrixRef<const T, D> e2_sub(e2, {{sub_offset.row() + a, sub_offset.col()}, {c - a, c}});
          MatrixRef<T, D> e0_sub(e0, {{sub_offset.row() + n_upper, sub_offset.col()}, {n_lower, c}});

          GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
        }

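        // copy deflated from e1 to e0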
        if (k_lc < dist_sub.localSize().cols()) {
          const SizeType k = dist_sub.globalElementFromLocalElement<Coord::Col>(k_lc);
          const matrix::internal::SubMatrixSpec deflated_submat{{sub_offset.row(), sub_offset.col() + k},
                                                                {n, n - k}};
          MatrixRef<T, D> sub_e0(e0, deflated_submat);
          MatrixRef<const T, D> sub_e1(e1, deflated_submat);

          copy(sub_e1, sub_e0);
        }

        namespace tt = pika::this_thread::experimental;
        tt::sync_wait(sub_comm_row());
        tt::sync_wait(sub_comm_col());
      }));
}
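
For orientation, this is the block structure that the two callNN invocations above appear to exploit. It is read off the sub-matrix offsets rather than stated anywhere in the diff, so the labels U_1, U_2, L_1, L_2 and the zero pattern are an assumption. With (a, b, c) = n_udl and a <= b <= c, the upper n_upper rows of e1 are taken to be nonzero only in columns [0, b) and the lower n_lower rows only in columns [a, c):

  E_1 = \begin{pmatrix} U_1 & U_2 & 0 \\ 0 & L_1 & L_2 \end{pmatrix}
  % column widths: a, b - a, c - b; row heights: n_upper, n_lower

so the product over the leading c columns decomposes into exactly the two smaller GEMMs performed above:

  E_0 = E_1 E_2 = \begin{pmatrix} (U_1 \;\; U_2)\, E_2[0{:}b,\; 0{:}c] \\ (L_1 \;\; L_2)\, E_2[a{:}c,\; 0{:}c] \end{pmatrix}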

// Distributed version of the tridiagonal solver on CPUs
template <Backend B, class T, Device D, class RhoSender>
void mergeDistSubproblems(comm::CommunicatorGrid grid,
@@ -1781,53 +1839,8 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid,
  //
  // The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as
  // prepared for the deflated system.
  ex::start_detached(
      ex::when_all(std::move(k_lc), std::move(n_udl), row_task_chain(), col_task_chain()) |
      ex::transfer(dlaf::internal::getBackendScheduler<Backend::MC>()) |
      ex::then([dist_sub, sub_offset, n, n_upper, n_lower, e0 = ws.e0.subPipeline(),
                e1 = ws.e1.subPipelineConst(), e2 = ws.e2.subPipelineConst()](
                   const SizeType k_lc, const std::array<SizeType, 3>& n_udl, auto&& row_comm_wrapper,
                   auto&& col_comm_wrapper) mutable {
        using dlaf::matrix::internal::MatrixRef;

        common::Pipeline<comm::Communicator> sub_comm_row(row_comm_wrapper.get());
        common::Pipeline<comm::Communicator> sub_comm_col(col_comm_wrapper.get());

        const auto [a, b, c] = n_udl;

        using GEMM = dlaf::multiplication::internal::General<B, D, T>;
        {
          MatrixRef<const T, D> e1_sub(e1, {sub_offset, {n_upper, b}});
          MatrixRef<const T, D> e2_sub(e2, {sub_offset, {b, c}});
          MatrixRef<T, D> e0_sub(e0, {sub_offset, {n_upper, c}});

          GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
        }

        {
          MatrixRef<const T, D> e1_sub(e1, {{sub_offset.row() + n_upper, sub_offset.col() + a},
                                            {n_lower, c - a}});
          MatrixRef<const T, D> e2_sub(e2, {{sub_offset.row() + a, sub_offset.col()}, {c - a, c}});
          MatrixRef<T, D> e0_sub(e0, {{sub_offset.row() + n_upper, sub_offset.col()}, {n_lower, c}});

          GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
        }

        // copy deflated from e1 to e0
        if (k_lc < dist_sub.localSize().cols()) {
          const SizeType k = dist_sub.globalElementFromLocalElement<Coord::Col>(k_lc);
          const matrix::internal::SubMatrixSpec deflated_submat{{sub_offset.row(), sub_offset.col() + k},
                                                                {n, n - k}};
          MatrixRef<T, D> sub_e0(e0, deflated_submat);
          MatrixRef<const T, D> sub_e1(e1, deflated_submat);

          copy(sub_e1, sub_e0);
        }

        namespace tt = pika::this_thread::experimental;
        tt::sync_wait(sub_comm_row());
        tt::sync_wait(sub_comm_col());
      }));
  multiplyEigenvectors<B>(dist_sub, row_task_chain, col_task_chain, sub_offset, n, n_upper, n_lower,
                          ws.e0, ws.e1, ws.e2, std::move(k_lc), std::move(n_udl));

  // Step #4: Final permutation to sort eigenvalues and eigenvectors
  //
