Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/alby/trisolver-dist-opt-step3' into comm-grid-round-robin
Browse files Browse the repository at this point in the history
  • Loading branch information
msimberg committed Dec 12, 2023
2 parents e65a26d + 4eb08c6 commit 3353914
Show file tree
Hide file tree
Showing 8 changed files with 585 additions and 63 deletions.
204 changes: 144 additions & 60 deletions include/dlaf/eigensolver/tridiag_solver/merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ auto stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* typ
//
// @return k number of non-deflated eigenvectors
// @return k_local number of local non-deflated eigenvectors
// @return n_udl tuple with global indices for [first_dense, last_dense, last_lower]
template <class T>
auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub, const ColType* types,
const T* evals, SizeType* perm_sorted,
Expand Down Expand Up @@ -501,7 +502,6 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub
index_sorted_coltype[to_sizet(jjj_el)] = jj_el;
}

// TODO manage edge cases
std::array<SizeType, 3> n_udl = [&]() {
SizeType first_dense;
for (first_dense = 0; first_dense < n; ++first_dense) {
Expand All @@ -511,14 +511,11 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub
break;
}

SizeType last_dense;
for (last_dense = n - 1; last_dense >= 0; --last_dense) {
const SizeType initial_el = index_sorted_coltype[to_sizet(last_dense)];
const ColType coltype = types[to_sizet(initial_el)];
if (ColType::LowerHalf != coltype && ColType::Deflated != coltype)
break;
}

// Note:
// Eigenvectors will be sorted according index_sorted_coltype, i.e. local sort by coltype.
// Since it is a local order, it is legit if deflated are globally interlaced with other column
// types. However, GEMM will be able to skip just the last global contiguous group of deflated
// eigenvectors, but not the ones interlaced with others.
SizeType last_lower;
for (last_lower = n - 1; last_lower >= 0; --last_lower) {
const SizeType initial_el = index_sorted_coltype[to_sizet(last_lower)];
Expand All @@ -527,6 +524,14 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub
break;
}

SizeType last_dense;
for (last_dense = last_lower; last_dense >= 0; --last_dense) {
const SizeType initial_el = index_sorted_coltype[to_sizet(last_dense)];
const ColType coltype = types[to_sizet(initial_el)];
if (ColType::LowerHalf != coltype && ColType::Deflated != coltype)
break;
}

return std::array<SizeType, 3>{first_dense, last_dense + 1, last_lower + 1};
}();

Expand Down Expand Up @@ -1654,6 +1659,123 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S
}));
}

// Computes E0 = E1 . E2 for the merged sub-problem located at @p sub_offset with distribution
// @p dist_sub, splitting the product into two distributed GEMMs (upper and lower part) and a
// final local copy of the deflated eigenvector columns.
//
// The work is scheduled asynchronously: it starts once both @p k_lc and @p n_udl senders have
// produced their values, and the scheduling itself runs on a high-priority MC backend thread.
//
// @param sub_offset     global element offset of the sub-problem inside the full matrices
// @param dist_sub       distribution of the sub-problem
// @param row_task_chain row communicator pipeline (a sub-pipeline is passed to the GEMMs)
// @param col_task_chain column communicator pipeline (a sub-pipeline is passed to the GEMMs)
// @param n_upper        number of rows of the upper half of the sub-problem
// @param n_lower        number of rows of the lower half of the sub-problem
// @param e0             output matrix E0 (result of the multiplication)
// @param e1             input matrix E1 (eigenvectors, read-only here)
// @param e2             input matrix E2 (second factor, read-only here)
// @param k_lc           sender of the number of local non-deflated eigenvectors
// @param n_udl          sender of the global indices {a, b, c} = [first_dense, last_dense,
//                       last_lower] that partition the columns by type (see diagram below)
template <Backend B, class T, Device D, class KLcSender, class UDLSenders>
void multiplyEigenvectors(const GlobalElementIndex sub_offset, const matrix::Distribution& dist_sub,
                          comm::CommunicatorPipeline<comm::CommunicatorType::Row>& row_task_chain,
                          comm::CommunicatorPipeline<comm::CommunicatorType::Col>& col_task_chain,
                          const SizeType n_upper, const SizeType n_lower, Matrix<T, D>& e0,
                          Matrix<T, D>& e1, Matrix<T, D>& e2, KLcSender&& k_lc, UDLSenders&& n_udl) {
  // Note:
  // This function computes E0 = E1 . E2
  //
  // where E1 is the matrix with eigenvectors and it looks like this
  //
  //               ┌──────────┐ k
  //               │    b     │ │
  //
  //      ┌──  ┌───┬──────┬─┬────┐
  //      │    │UUU│DDDDDD│ │XXXX│
  //      │    │UUU│DDDDDD│ │XXXX│
  //  n_upper  │UUU│DDDDDD│ │XXXX│
  //      │    │UUU│DDDDDD│ │XXXX│
  //      │    │UUU│DDDDDD│ │XXXX│
  //      ├──  ├───┼──────┼─┤XXXX│
  //      │    │   │DDDDDD│L│XXXX│
  //  n_lower  │   │DDDDDD│L│XXXX│
  //      │    │   │DDDDDD│L│XXXX│
  //      └──  └───┴──────┴─┴────┘
  //           │ a │
  //           └───┘
  //           │     c      │
  //           └────────────┘
  //
  // Where (a, b, c) are the values from n_udl
  //
  // Note:
  // E1 matrix does not have all deflated values at the end, indeed part of them are "interlaced" with
  // others. The GEMM will perform anyway a computation for deflated eigenvectors (which are zeroed out)
  // while the copy step will be performed at "local" level, so even interlaced ones will get copied
  // in the right spot.
  //
  // The multiplication is performed in two different steps in order to skip zero blocks of the
  // matrix, created by the grouping of eigenvectors of different lengths (UPPER, DENSE and LOWER).
  //
  // 1. GEMM1 = TL . TOP
  // 2. GEMM2 = BR . BOTTOM
  // 3. copy DEFLATED
  //
  //      ┌────────────┬────┐
  //      │            │    │
  //      │            │    │
  //      │   T O P    │    │
  //      │            │    │
  //      │            │    │
  //      ├────────────┤    │
  //      │            │    │
  //      │            │    │
  //      │B O T T O M │    │
  //      │            │    │
  //      └────────────┴────┘
  //
  //      ┌──────────┬─┬────┐    ┌────────────┬────┐
  //      │          │0│    │    │            │    │
  //      │          │0│  D │    │            │    │
  //      │    TL    │0│  E │    │   GEMM 1   │  C │
  //      │          │0│  F │    │            │    │
  //      │          │0│  L │    │            │  O │
  //      ├───┬──────┴─┤  A │    ├────────────┤    │
  //      │000│        │  T │    │            │  P │
  //      │000│        │  E │    │            │    │
  //      │000│   BR   │  D │    │   GEMM 2   │  Y │
  //      │000│        │    │    │            │    │
  //      └───┴────────┴────┘    └────────────┴────┘

  namespace ex = pika::execution::experimental;
  using pika::execution::thread_priority;

  // Fire-and-forget: the continuation owns sub-pipelines of the matrices and communicators, so the
  // caller does not have to keep this task alive.
  ex::start_detached(
      ex::when_all(std::forward<KLcSender>(k_lc), std::forward<UDLSenders>(n_udl)) |
      ex::transfer(dlaf::internal::getBackendScheduler<Backend::MC>(thread_priority::high)) |
      ex::then([dist_sub, sub_offset, n_upper, n_lower, e0 = e0.subPipeline(),
                e1 = e1.subPipelineConst(), e2 = e2.subPipelineConst(),
                sub_comm_row = row_task_chain.sub_pipeline(),
                sub_comm_col = col_task_chain.sub_pipeline()](
                   const SizeType k_lc, const std::array<SizeType, 3>& n_udl) mutable {
        using dlaf::matrix::internal::MatrixRef;

        const SizeType n = dist_sub.size().cols();
        const auto [a, b, c] = n_udl;

        using GEMM = dlaf::multiplication::internal::General<B, D, T>;
        {
          // GEMM 1: upper rows of E0 = TL (n_upper x b) . TOP (b x c), skipping the
          // all-zero LOWER columns of E1.
          MatrixRef<const T, D> e1_sub(e1, {sub_offset, {n_upper, b}});
          MatrixRef<const T, D> e2_sub(e2, {sub_offset, {b, c}});
          MatrixRef<T, D> e0_sub(e0, {sub_offset, {n_upper, c}});

          GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
        }

        {
          // GEMM 2: lower rows of E0 = BR (n_lower x (c - a)) . BOTTOM ((c - a) x c),
          // skipping the all-zero UPPER columns of E1 (the first `a` columns).
          MatrixRef<const T, D> e1_sub(e1, {{sub_offset.row() + n_upper, sub_offset.col() + a},
                                            {n_lower, c - a}});
          MatrixRef<const T, D> e2_sub(e2, {{sub_offset.row() + a, sub_offset.col()}, {c - a, c}});
          MatrixRef<T, D> e0_sub(e0, {{sub_offset.row() + n_upper, sub_offset.col()}, {n_lower, c}});

          GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
        }

        // Step 3: copy the (locally stored) deflated eigenvector columns straight from E1 to E0;
        // they were not produced by the GEMMs above. Skipped when this rank holds no deflated
        // columns (k_lc == local column count).
        if (k_lc < dist_sub.local_size().cols()) {
          const SizeType k = dist_sub.global_element_from_local_element<Coord::Col>(k_lc);
          const matrix::internal::SubMatrixSpec deflated_submat{{sub_offset.row(), sub_offset.col() + k},
                                                               {n, n - k}};
          MatrixRef<T, D> sub_e0(e0, deflated_submat);
          MatrixRef<const T, D> sub_e1(e1, deflated_submat);

          copy(sub_e1, sub_e0);
        }
      }));
}

// Distributed version of the tridiagonal solver on CPUs
template <Backend B, class T, Device D, class RhoSender>
void mergeDistSubproblems(comm::CommunicatorPipeline<comm::CommunicatorType::Full>& full_task_chain,
Expand All @@ -1668,8 +1790,8 @@ void mergeDistSubproblems(comm::CommunicatorPipeline<comm::CommunicatorType::Ful

const matrix::Distribution& dist = ws.e0.distribution();

const GlobalElementIndex sub_offset{i_begin * dist.blockSize().rows(),
i_begin * dist.blockSize().cols()};
const GlobalElementIndex sub_offset{i_begin * dist.tile_size().rows(),
i_begin * dist.tile_size().cols()};
const matrix::Distribution dist_sub(
dist, {sub_offset,
{
Expand All @@ -1678,16 +1800,15 @@ void mergeDistSubproblems(comm::CommunicatorPipeline<comm::CommunicatorType::Ful
}});

// Calculate the size of the upper subproblem
const SizeType n = dist.globalTileElementDistance<Coord::Row>(i_begin, i_end);
const SizeType n_upper = dist.globalTileElementDistance<Coord::Row>(i_begin, i_split);
const SizeType n_lower = dist.globalTileElementDistance<Coord::Row>(i_split, i_end);
const SizeType n_upper = global_tile_element_distance<Coord::Row>(dist, i_begin, i_split);
const SizeType n_lower = global_tile_element_distance<Coord::Row>(dist, i_split, i_end);

// The local size of the subproblem
const GlobalTileIndex idx_gl_begin(i_begin, i_begin);
const LocalTileIndex idx_loc_begin{dist.nextLocalTileFromGlobalTile<Coord::Row>(i_begin),
dist.nextLocalTileFromGlobalTile<Coord::Col>(i_begin)};
const LocalTileIndex idx_loc_end{dist.nextLocalTileFromGlobalTile<Coord::Row>(i_end),
dist.nextLocalTileFromGlobalTile<Coord::Col>(i_end)};
const LocalTileIndex idx_loc_begin{dist.next_local_tile_from_global_tile<Coord::Row>(i_begin),
dist.next_local_tile_from_global_tile<Coord::Col>(i_begin)};
const LocalTileIndex idx_loc_end{dist.next_local_tile_from_global_tile<Coord::Row>(i_end),
dist.next_local_tile_from_global_tile<Coord::Col>(i_end)};
const LocalTileSize sz_loc_tiles = idx_loc_end - idx_loc_begin;
const LocalTileIndex idx_begin_tiles_vec(i_begin, 0);
const LocalTileSize sz_tiles_vec(i_end - i_begin, 1);
Expand Down Expand Up @@ -1772,7 +1893,9 @@ void mergeDistSubproblems(comm::CommunicatorPipeline<comm::CommunicatorType::Ful
applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1);
applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1);

// Note: here ws_hm.z0 is used as a contiguous buffer for the laed4 call
// Note:
// set0 is required because deflated eigenvectors rows won't be touched in rank1 and so they will be
// neutral when used in GEMM (copy will take care of them later)
matrix::util::set0<Backend::MC>(thread_priority::normal, idx_loc_begin, sz_loc_tiles, ws_hm.e2);
solveRank1ProblemDist(row_task_chain.exclusive(), col_task_chain.exclusive(), i_begin, i_end, k, k_lc,
std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_h.i4, ws_hm.i6, ws_hm.i2,
Expand All @@ -1783,47 +1906,8 @@ void mergeDistSubproblems(comm::CommunicatorPipeline<comm::CommunicatorType::Ful
//
// The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as
// prepared for the deflated system.
ex::start_detached(
ex::when_all(std::move(k_lc), std::move(n_udl)) |
ex::transfer(dlaf::internal::getBackendScheduler<Backend::MC>()) |
ex::then([dist_sub, sub_offset, n, n_upper, n_lower, e0 = ws.e0.subPipeline(),
e1 = ws.e1.subPipelineConst(), e2 = ws.e2.subPipelineConst(),
sub_comm_row = row_task_chain.sub_pipeline(),
sub_comm_col = col_task_chain.sub_pipeline()](
const SizeType k_lc, const std::array<SizeType, 3>& n_udl) mutable {
using dlaf::matrix::internal::MatrixRef;

const auto [a, b, c] = n_udl;

using GEMM = dlaf::multiplication::internal::General<B, D, T>;
{
MatrixRef<const T, D> e1_sub(e1, {sub_offset, {n_upper, b}});
MatrixRef<const T, D> e2_sub(e2, {sub_offset, {b, c}});
MatrixRef<T, D> e0_sub(e0, {sub_offset, {n_upper, c}});

GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
}

{
MatrixRef<const T, D> e1_sub(e1, {{sub_offset.row() + n_upper, sub_offset.col() + a},
{n_lower, c - a}});
MatrixRef<const T, D> e2_sub(e2, {{sub_offset.row() + a, sub_offset.col()}, {c - a, c}});
MatrixRef<T, D> e0_sub(e0, {{sub_offset.row() + n_upper, sub_offset.col()}, {n_lower, c}});

GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
}

// copy deflated from e1 to e0
if (k_lc < dist_sub.localSize().cols()) {
const SizeType k = dist_sub.globalElementFromLocalElement<Coord::Col>(k_lc);
const matrix::internal::SubMatrixSpec deflated_submat{{sub_offset.row(), sub_offset.col() + k},
{n, n - k}};
MatrixRef<T, D> sub_e0(e0, deflated_submat);
MatrixRef<const T, D> sub_e1(e1, deflated_submat);

copy(sub_e1, sub_e0);
}
}));
multiplyEigenvectors<B>(sub_offset, dist_sub, row_task_chain, col_task_chain, n_upper, n_lower, ws.e0,
ws.e1, ws.e2, std::move(k_lc), std::move(n_udl));

// Step #4: Final permutation to sort eigenvalues and eigenvectors
//
Expand Down
2 changes: 2 additions & 0 deletions miniapp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ DLAF_addMiniapp(miniapp_bt_reduction_to_band SOURCES miniapp_bt_reduction_to_ban

DLAF_addMiniapp(miniapp_triangular_solver SOURCES miniapp_triangular_solver.cpp)

DLAF_addMiniapp(miniapp_triangular_multiplication SOURCES miniapp_triangular_multiplication.cpp)

DLAF_addMiniapp(miniapp_eigensolver SOURCES miniapp_eigensolver.cpp)

DLAF_addMiniapp(miniapp_gen_eigensolver SOURCES miniapp_gen_eigensolver.cpp)
Expand Down
Loading

0 comments on commit 3353914

Please sign in to comment.