merge-squashed trisolver dist change to reduce gemm cost
albestro committed Dec 4, 2023
1 parent fffad13 commit 62c1cbe
Showing 1 changed file with 104 additions and 37 deletions.
141 changes: 104 additions & 37 deletions include/dlaf/eigensolver/tridiag_solver/merge.h
@@ -506,6 +506,35 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub
index_sorted_coltype[to_sizet(jjj_el)] = jj_el;
}

+  // TODO manage edge cases
+  std::array<SizeType, 3> n_udl = [&]() {
+    SizeType first_dense;
+    for (first_dense = 0; first_dense < n; ++first_dense) {
+      const SizeType initial_el = index_sorted_coltype[to_sizet(first_dense)];
+      const ColType coltype = types[to_sizet(initial_el)];
+      if (ColType::UpperHalf != coltype)
+        break;
+    }
+
+    SizeType last_dense;
+    for (last_dense = n - 1; last_dense >= 0; --last_dense) {
+      const SizeType initial_el = index_sorted_coltype[to_sizet(last_dense)];
+      const ColType coltype = types[to_sizet(initial_el)];
+      if (ColType::LowerHalf != coltype && ColType::Deflated != coltype)
+        break;
+    }
+
+    SizeType last_lower;
+    for (last_lower = n - 1; last_lower >= 0; --last_lower) {
+      const SizeType initial_el = index_sorted_coltype[to_sizet(last_lower)];
+      const ColType coltype = types[to_sizet(initial_el)];
+      if (ColType::Deflated != coltype)
+        break;
+    }
+
+    return std::array<SizeType, 3>{first_dense, last_dense + 1, last_lower + 1};
+  }();
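  // Illustrative note: index_sorted_coltype groups the columns, stably, as
  //
  //   [ UpperHalf | Dense | LowerHalf | Deflated ]
  //   0           u       u+d         u+d+l      n
  //
  // so n_udl = {u, u+d, u+d+l}. For a toy sequence
  //   {Upper, Upper, Dense, Dense, Lower, Deflated}   (n = 6)
  // the three loops above give first_dense = 2, last_dense = 3, last_lower = 4,
  // hence n_udl = {2, 4, 5}.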

// invertIndex i3->i2
// i3 (in) : initial <--- deflated
// i2 (out) : initial ---> deflated
Expand Down Expand Up @@ -533,7 +562,7 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub
}
}

-  return std::tuple(k, k_lc);
+  return std::tuple(k, k_lc, n_udl);
}

template <class T>
@@ -1309,9 +1338,7 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S
const SizeType* i2 = i2_tiles_arr[0].get().ptr();
const SizeType* i6 = i6_tiles_arr[0].get().ptr();

-        // STEP 0a: Fill ones for deflated Eigenvectors and copy related Eigenvalues (single-thread)
-        // Note: this step is completely independent from the rest, but it is small and it is going
-        // to be dropped soon.
+        // STEP 0a: Permute eigenvalues for deflated eigenvectors (single-thread)
        // Note: use the last thread, which in principle should have less work to do
if (k < n && thread_idx == nthreads - 1) {
const T* eval_initial_ptr = d_tiles_futs[0].get().ptr();
@@ -1320,23 +1347,6 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S
for (SizeType jeg_el_lc = k_lc; jeg_el_lc < n_el_lc; ++jeg_el_lc) {
const SizeType jeg_el =
dist_sub.globalElementFromLocalElement<Coord::Col>(jeg_el_lc);

-            const SizeType ieg_el = jeg_el;
-
-            if (dist_sub.rankIndex().row() == dist_sub.rankGlobalElement<Coord::Row>(ieg_el)) {
-              const SizeType ieg_el_lc =
-                  dist_sub.localElementFromGlobalElement<Coord::Row>(ieg_el);
-              const LocalTileIndex
-                  ieg_lc{dist_sub.localTileFromLocalElement<Coord::Row>(ieg_el_lc),
-                         dist_sub.localTileFromLocalElement<Coord::Col>(jeg_el_lc)};
-              const SizeType linear_lc = dist_sub.localTileLinearIndex(ieg_lc);
-              const TileElementIndex
-                  ijeg_el_tl{dist_sub.tileElementFromLocalElement<Coord::Row>(ieg_el_lc),
-                             dist_sub.tileElementFromLocalElement<Coord::Col>(jeg_el_lc)};
-
-              evec_tiles[to_sizet(linear_lc)](ijeg_el_tl) = T{1};
-            }

eval_ptr[jeg_el] = eval_initial_ptr[i6[jeg_el]];
}
}
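      // Illustrative note on STEP 0a (semantics inferred from the loop above):
      // i6 maps each position in the deflation-sorted order to the position the
      // column had before deflation, so for the trailing deflated columns the
      // eigenvalues are permuted to match their (unchanged) eigenvectors:
      //   eval[j] = eval_initial[i6[j]]   for j in [k, n)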
@@ -1660,19 +1670,31 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid,
WorkSpace<T, D>& ws, WorkSpaceHost<T>& ws_h,
DistWorkSpaceHostMirror<T, D>& ws_hm) {
namespace ex = pika::execution::experimental;
+  using matrix::internal::distribution::global_tile_element_distance;
using pika::execution::thread_priority;

-  const matrix::Distribution& dist_evecs = ws.e0.distribution();
+  const matrix::Distribution& dist = ws.e0.distribution();

+  const GlobalElementIndex sub_offset{i_begin * dist.blockSize().rows(),
+                                      i_begin * dist.blockSize().cols()};
+  const matrix::Distribution dist_sub(
+      dist, {sub_offset,
+             {
+                 global_tile_element_distance<Coord::Row>(dist, i_begin, i_end),
+                 global_tile_element_distance<Coord::Col>(dist, i_begin, i_end),
+             }});

  // Calculate the sizes of the merged problem and of the upper and lower subproblems
-  const SizeType n1 = dist_evecs.globalTileElementDistance<Coord::Row>(i_begin, i_split);
+  const SizeType n = dist.globalTileElementDistance<Coord::Row>(i_begin, i_end);
+  const SizeType n_upper = dist.globalTileElementDistance<Coord::Row>(i_begin, i_split);
+  const SizeType n_lower = dist.globalTileElementDistance<Coord::Row>(i_split, i_end);
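  // Illustrative note: the merged subproblem spans the square element range
  // covered by tiles [i_begin, i_end), anchored at sub_offset, and its size
  // satisfies n == n_upper + n_lower.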

// The local size of the subproblem
const GlobalTileIndex idx_gl_begin(i_begin, i_begin);
-  const LocalTileIndex idx_loc_begin{dist_evecs.nextLocalTileFromGlobalTile<Coord::Row>(i_begin),
-                                     dist_evecs.nextLocalTileFromGlobalTile<Coord::Col>(i_begin)};
-  const LocalTileIndex idx_loc_end{dist_evecs.nextLocalTileFromGlobalTile<Coord::Row>(i_end),
-                                   dist_evecs.nextLocalTileFromGlobalTile<Coord::Col>(i_end)};
+  const LocalTileIndex idx_loc_begin{dist.nextLocalTileFromGlobalTile<Coord::Row>(i_begin),
+                                     dist.nextLocalTileFromGlobalTile<Coord::Col>(i_begin)};
+  const LocalTileIndex idx_loc_end{dist.nextLocalTileFromGlobalTile<Coord::Row>(i_end),
+                                   dist.nextLocalTileFromGlobalTile<Coord::Col>(i_end)};
const LocalTileSize sz_loc_tiles = idx_loc_end - idx_loc_begin;
const LocalTileIndex idx_begin_tiles_vec(i_begin, 0);
const LocalTileSize sz_tiles_vec(i_end - i_begin, 1);
@@ -1699,7 +1721,7 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid,
}

// Update indices of second sub-problem
-  addIndex(i_split, i_end, n1, ws_h.i1);
+  addIndex(i_split, i_end, n_upper, ws_h.i1);

// Step #1
//
@@ -1709,7 +1731,7 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid,
// - deflate `d`, `z` and `c`
// - apply Givens rotations to `Q` - `evecs`
//
-  sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2);
+  sortIndex(i_begin, i_end, ex::just(n_upper), ws_h.d0, ws_h.i1, ws_hm.i2);

auto rots =
applyDeflation(i_begin, i_end, scaled_rho, std::move(tol), ws_hm.i2, ws_h.d0, ws_hm.z0, ws_h.c);
@@ -1730,13 +1752,14 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid,
// - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`.
// - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented)
//
-  auto [k_unique, k_lc] =
-      ex::split_tuple(stablePartitionIndexForDeflation(dist_evecs, i_begin, i_end, ws_h.c, ws_h.d0,
-                                                       ws_hm.i2, ws_h.i3, ws_hm.i5, ws_h.i4, ws_hm.i6));
+  auto [k_unique, k_lc_unique, n_udl] =
+      ex::split_tuple(stablePartitionIndexForDeflation(dist, i_begin, i_end, ws_h.c, ws_h.d0, ws_hm.i2,
+                                                       ws_h.i3, ws_hm.i5, ws_h.i4, ws_hm.i6));
// from now on i2 is the inverse of i6
invertIndex(i_begin, i_end, ws_hm.i6, ws_hm.i2);

auto k = ex::split(std::move(k_unique));
+  auto k_lc = ex::split(std::move(k_lc_unique));
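  // Illustrative note: ex::split() makes each sender shareable; k_lc is consumed
  // both by solveRank1ProblemDist below and by the GEMM/copy step of Step #3.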

// Reorder Eigenvectors
if constexpr (Backend::MC == B) {
@@ -1758,18 +1781,62 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid,

// Note: here ws_hm.z0 is used as a contiguous buffer for the laed4 call
matrix::util::set0<Backend::MC>(thread_priority::normal, idx_loc_begin, sz_loc_tiles, ws_hm.e2);
-  solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, k, std::move(k_lc),
+  solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, k, k_lc,
std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_h.i4, ws_hm.i6, ws_hm.i2,
ws_hm.e2);
-  copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2);

// Step #3: Eigenvectors of the tridiagonal system: Q * U
//
// The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as
// prepared for the deflated system.
+  copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2);
-  dlaf::multiplication::internal::generalSubMatrix<B, D, T>(grid, row_task_chain, col_task_chain,
-                                                            i_begin, i_end, T(1), ws.e1, ws.e2, T(0),
-                                                            ws.e0);
+  ex::start_detached(
+      ex::when_all(std::move(k_lc), std::move(n_udl), row_task_chain(), col_task_chain()) |
+      ex::transfer(dlaf::internal::getBackendScheduler<Backend::MC>()) |
+      ex::then([dist_sub, sub_offset, n, n_upper, n_lower, e0 = ws.e0.subPipeline(),
+                e1 = ws.e1.subPipelineConst(), e2 = ws.e2.subPipelineConst()](
+                   const SizeType k_lc, const std::array<SizeType, 3>& n_udl, auto&& row_comm_wrapper,
+                   auto&& col_comm_wrapper) mutable {
+        using dlaf::matrix::internal::MatrixRef;
+
+        common::Pipeline<comm::Communicator> sub_comm_row(row_comm_wrapper.get());
+        common::Pipeline<comm::Communicator> sub_comm_col(col_comm_wrapper.get());
+
+        const auto [a, b, c] = n_udl;
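        // Illustrative note on the cost reduction (layout inferred from the
        // MatrixRef sub-ranges below, with {a, b, c} = n_udl and columns sorted
        // as [UpperHalf | Dense | LowerHalf | Deflated]):
        //
        //              e1 (Q)               e2 (U)
        //   n_upper [ Q0  Q1  0  ]       b [ U0 ]   (rows 0..b)
        //   n_lower [ 0   Q2  Q3 ]     c-a [ U1 ]   (rows a..c)
        //             a   b   c  n
        //
        // Rows of Q from the upper subproblem are zero on LowerHalf and Deflated
        // columns, and rows from the lower subproblem are zero on UpperHalf and
        // Deflated columns, so instead of one full (n x n) * (n x n) GEMM it is
        // enough to compute
        //   e0[0..n_upper, 0..c) = e1[0..n_upper, 0..b) * e2[0..b, 0..c)
        //   e0[n_upper..n, 0..c) = e1[n_upper..n, a..c) * e2[a..c, 0..c)
        // which is the GEMM-cost reduction named in the commit title.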

+        using GEMM = dlaf::multiplication::internal::General<B, D, T>;
+        {
+          MatrixRef<const T, D> e1_sub(e1, {sub_offset, {n_upper, b}});
+          MatrixRef<const T, D> e2_sub(e2, {sub_offset, {b, c}});
+          MatrixRef<T, D> e0_sub(e0, {sub_offset, {n_upper, c}});
+
+          GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
+        }
+
+        {
+          MatrixRef<const T, D> e1_sub(e1, {{sub_offset.row() + n_upper, sub_offset.col() + a},
+                                            {n_lower, c - a}});
+          MatrixRef<const T, D> e2_sub(e2, {{sub_offset.row() + a, sub_offset.col()}, {c - a, c}});
+          MatrixRef<T, D> e0_sub(e0, {{sub_offset.row() + n_upper, sub_offset.col()}, {n_lower, c}});
+
+          GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub);
+        }

+        // copy deflated from e1 to e0
+        if (k_lc < dist_sub.localSize().cols()) {
+          const SizeType k = dist_sub.globalElementFromLocalElement<Coord::Col>(k_lc);
+          const matrix::internal::SubMatrixSpec deflated_submat{{sub_offset.row(), sub_offset.col() + k},
+                                                                {n, n - k}};
+          MatrixRef<T, D> sub_e0(e0, deflated_submat);
+          MatrixRef<const T, D> sub_e1(e1, deflated_submat);
+
+          copy(sub_e1, sub_e0);
+        }
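        // Illustrative note: deflated eigenvectors pass through the rank-1 update
        // unchanged, so the trailing columns [k, n) of the merged result are taken
        // verbatim from e1 instead of being produced by the GEMMs above.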

+        namespace tt = pika::this_thread::experimental;
+        tt::sync_wait(sub_comm_row());
+        tt::sync_wait(sub_comm_col());
+      }));

// Step #4: Final permutation to sort eigenvalues and eigenvectors
//

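A minimal, self-contained sketch of the idea behind this change (the ColType enum, helper names, and the toy sizes below are illustrative assumptions, not DLA-Future API): it recomputes the n_udl boundaries for a small deflation-sorted type sequence and compares the multiply-add count of a single full GEMM against the two sub-GEMMs used above.

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

using SizeType = std::int64_t;
enum class ColType { UpperHalf, Dense, LowerHalf, Deflated };

// Boundaries {u, u+d, u+d+l} of the deflation-sorted column types, mirroring
// the three loops added in this commit.
std::array<SizeType, 3> n_udl_of(const std::vector<ColType>& types) {
  const SizeType n = static_cast<SizeType>(types.size());
  SizeType first_dense = 0;
  while (first_dense < n && types[first_dense] == ColType::UpperHalf)
    ++first_dense;
  SizeType last_dense = n - 1;
  while (last_dense >= 0 &&
         (types[last_dense] == ColType::LowerHalf || types[last_dense] == ColType::Deflated))
    --last_dense;
  SizeType last_lower = n - 1;
  while (last_lower >= 0 && types[last_lower] == ColType::Deflated)
    --last_lower;
  return {first_dense, last_dense + 1, last_lower + 1};
}

int main() {
  using CT = ColType;
  const std::vector<ColType> types{CT::UpperHalf, CT::UpperHalf, CT::Dense,
                                   CT::Dense,     CT::LowerHalf, CT::Deflated};
  const auto [a, b, c] = n_udl_of(types);
  const SizeType n = 6, n_upper = 3, n_lower = 3;  // toy split of the merged problem

  // Full GEMM costs n * n * n multiply-adds; the split GEMMs skip the
  // structural zero blocks and the deflated columns.
  const SizeType full = n * n * n;
  const SizeType split = n_upper * b * c + n_lower * (c - a) * c;
  std::cout << "n_udl = {" << a << ", " << b << ", " << c << "}\n"
            << "flops: full " << full << " vs split " << split << "\n";
}

Compiled with any C++17 compiler, this prints n_udl = {2, 4, 5} and flops: full 216 vs split 105.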