From 62c1cbe07f0aa5d29944319ed9c251d646eafcac Mon Sep 17 00:00:00 2001 From: Alberto Invernizzi Date: Mon, 25 Sep 2023 10:58:15 +0200 Subject: [PATCH] merge-squashed trisolver dist change to reduce gemm cost --- .../dlaf/eigensolver/tridiag_solver/merge.h | 141 +++++++++++++----- 1 file changed, 104 insertions(+), 37 deletions(-) diff --git a/include/dlaf/eigensolver/tridiag_solver/merge.h b/include/dlaf/eigensolver/tridiag_solver/merge.h index 7ef731451e..3afec9732f 100644 --- a/include/dlaf/eigensolver/tridiag_solver/merge.h +++ b/include/dlaf/eigensolver/tridiag_solver/merge.h @@ -506,6 +506,35 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub index_sorted_coltype[to_sizet(jjj_el)] = jj_el; } + // TODO manage edge cases + std::array n_udl = [&]() { + SizeType first_dense; + for (first_dense = 0; first_dense < n; ++first_dense) { + const SizeType initial_el = index_sorted_coltype[to_sizet(first_dense)]; + const ColType coltype = types[to_sizet(initial_el)]; + if (ColType::UpperHalf != coltype) + break; + } + + SizeType last_dense; + for (last_dense = n - 1; last_dense >= 0; --last_dense) { + const SizeType initial_el = index_sorted_coltype[to_sizet(last_dense)]; + const ColType coltype = types[to_sizet(initial_el)]; + if (ColType::LowerHalf != coltype && ColType::Deflated != coltype) + break; + } + + SizeType last_lower; + for (last_lower = n - 1; last_lower >= 0; --last_lower) { + const SizeType initial_el = index_sorted_coltype[to_sizet(last_lower)]; + const ColType coltype = types[to_sizet(initial_el)]; + if (ColType::Deflated != coltype) + break; + } + + return std::array{first_dense, last_dense + 1, last_lower + 1}; + }(); + // invertIndex i3->i2 // i3 (in) : initial <--- deflated // i2 (out) : initial ---> deflated @@ -533,7 +562,7 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub } } - return std::tuple(k, k_lc); + return std::tuple(k, k_lc, n_udl); } template @@ -1309,9 +1338,7 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S const SizeType* i2 = i2_tiles_arr[0].get().ptr(); const SizeType* i6 = i6_tiles_arr[0].get().ptr(); - // STEP 0a: Fill ones for deflated Eigenvectors and copy related Eigenvalues (single-thread) - // Note: this step is completely independent from the rest, but it is small and it is going - // to be dropped soon. + // STEP 0a: Permute eigenvalues for deflated eigenvectors (single-thread) // Note: use last threads that in principle should have less work to do if (k < n && thread_idx == nthreads - 1) { const T* eval_initial_ptr = d_tiles_futs[0].get().ptr(); @@ -1320,23 +1347,6 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S for (SizeType jeg_el_lc = k_lc; jeg_el_lc < n_el_lc; ++jeg_el_lc) { const SizeType jeg_el = dist_sub.globalElementFromLocalElement(jeg_el_lc); - - const SizeType ieg_el = jeg_el; - - if (dist_sub.rankIndex().row() == dist_sub.rankGlobalElement(ieg_el)) { - const SizeType ieg_el_lc = - dist_sub.localElementFromGlobalElement(ieg_el); - const LocalTileIndex - ieg_lc{dist_sub.localTileFromLocalElement(ieg_el_lc), - dist_sub.localTileFromLocalElement(jeg_el_lc)}; - const SizeType linear_lc = dist_sub.localTileLinearIndex(ieg_lc); - const TileElementIndex - ijeg_el_tl{dist_sub.tileElementFromLocalElement(ieg_el_lc), - dist_sub.tileElementFromLocalElement(jeg_el_lc)}; - - evec_tiles[to_sizet(linear_lc)](ijeg_el_tl) = T{1}; - } - eval_ptr[jeg_el] = eval_initial_ptr[i6[jeg_el]]; } } @@ -1660,19 +1670,31 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, WorkSpace& ws, WorkSpaceHost& ws_h, DistWorkSpaceHostMirror& ws_hm) { namespace ex = pika::execution::experimental; + using matrix::internal::distribution::global_tile_element_distance; using pika::execution::thread_priority; - const matrix::Distribution& dist_evecs = ws.e0.distribution(); + const matrix::Distribution& dist = ws.e0.distribution(); + + const GlobalElementIndex sub_offset{i_begin * dist.blockSize().rows(), + i_begin * dist.blockSize().cols()}; + const matrix::Distribution dist_sub( + dist, {sub_offset, + { + global_tile_element_distance(dist, i_begin, i_end), + global_tile_element_distance(dist, i_begin, i_end), + }}); // Calculate the size of the upper subproblem - const SizeType n1 = dist_evecs.globalTileElementDistance(i_begin, i_split); + const SizeType n = dist.globalTileElementDistance(i_begin, i_end); + const SizeType n_upper = dist.globalTileElementDistance(i_begin, i_split); + const SizeType n_lower = dist.globalTileElementDistance(i_split, i_end); // The local size of the subproblem const GlobalTileIndex idx_gl_begin(i_begin, i_begin); - const LocalTileIndex idx_loc_begin{dist_evecs.nextLocalTileFromGlobalTile(i_begin), - dist_evecs.nextLocalTileFromGlobalTile(i_begin)}; - const LocalTileIndex idx_loc_end{dist_evecs.nextLocalTileFromGlobalTile(i_end), - dist_evecs.nextLocalTileFromGlobalTile(i_end)}; + const LocalTileIndex idx_loc_begin{dist.nextLocalTileFromGlobalTile(i_begin), + dist.nextLocalTileFromGlobalTile(i_begin)}; + const LocalTileIndex idx_loc_end{dist.nextLocalTileFromGlobalTile(i_end), + dist.nextLocalTileFromGlobalTile(i_end)}; const LocalTileSize sz_loc_tiles = idx_loc_end - idx_loc_begin; const LocalTileIndex idx_begin_tiles_vec(i_begin, 0); const LocalTileSize sz_tiles_vec(i_end - i_begin, 1); @@ -1699,7 +1721,7 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, } // Update indices of second sub-problem - addIndex(i_split, i_end, n1, ws_h.i1); + addIndex(i_split, i_end, n_upper, ws_h.i1); // Step #1 // @@ -1709,7 +1731,7 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // - deflate `d`, `z` and `c` // - apply Givens rotations to `Q` - `evecs` // - sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2); + sortIndex(i_begin, i_end, ex::just(n_upper), ws_h.d0, ws_h.i1, ws_hm.i2); auto rots = applyDeflation(i_begin, i_end, scaled_rho, std::move(tol), ws_hm.i2, ws_h.d0, ws_hm.z0, ws_h.c); @@ -1730,13 +1752,14 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`. // - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented) // - auto [k_unique, k_lc] = - ex::split_tuple(stablePartitionIndexForDeflation(dist_evecs, i_begin, i_end, ws_h.c, ws_h.d0, - ws_hm.i2, ws_h.i3, ws_hm.i5, ws_h.i4, ws_hm.i6)); + auto [k_unique, k_lc_unique, n_udl] = + ex::split_tuple(stablePartitionIndexForDeflation(dist, i_begin, i_end, ws_h.c, ws_h.d0, ws_hm.i2, + ws_h.i3, ws_hm.i5, ws_h.i4, ws_hm.i6)); // from now on i2 is the inverse of i6 invertIndex(i_begin, i_end, ws_hm.i6, ws_hm.i2); auto k = ex::split(std::move(k_unique)); + auto k_lc = ex::split(std::move(k_lc_unique)); // Reorder Eigenvectors if constexpr (Backend::MC == B) { @@ -1758,18 +1781,62 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // Note: here ws_hm.z0 is used as a contiguous buffer for the laed4 call matrix::util::set0(thread_priority::normal, idx_loc_begin, sz_loc_tiles, ws_hm.e2); - solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, k, std::move(k_lc), + solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, k, k_lc, std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_h.i4, ws_hm.i6, ws_hm.i2, ws_hm.e2); + copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2); // Step #3: Eigenvectors of the tridiagonal system: Q * U // // The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as // prepared for the deflated system. - copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2); - dlaf::multiplication::internal::generalSubMatrix(grid, row_task_chain, col_task_chain, - i_begin, i_end, T(1), ws.e1, ws.e2, T(0), - ws.e0); + ex::start_detached( + ex::when_all(std::move(k_lc), std::move(n_udl), row_task_chain(), col_task_chain()) | + ex::transfer(dlaf::internal::getBackendScheduler()) | + ex::then([dist_sub, sub_offset, n, n_upper, n_lower, e0 = ws.e0.subPipeline(), + e1 = ws.e1.subPipelineConst(), e2 = ws.e2.subPipelineConst()]( + const SizeType k_lc, const std::array& n_udl, auto&& row_comm_wrapper, + auto&& col_comm_wrapper) mutable { + using dlaf::matrix::internal::MatrixRef; + + common::Pipeline sub_comm_row(row_comm_wrapper.get()); + common::Pipeline sub_comm_col(col_comm_wrapper.get()); + + const auto [a, b, c] = n_udl; + + using GEMM = dlaf::multiplication::internal::General; + { + MatrixRef e1_sub(e1, {sub_offset, {n_upper, b}}); + MatrixRef e2_sub(e2, {sub_offset, {b, c}}); + MatrixRef e0_sub(e0, {sub_offset, {n_upper, c}}); + + GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub); + } + + { + MatrixRef e1_sub(e1, {{sub_offset.row() + n_upper, sub_offset.col() + a}, + {n_lower, c - a}}); + MatrixRef e2_sub(e2, {{sub_offset.row() + a, sub_offset.col()}, {c - a, c}}); + MatrixRef e0_sub(e0, {{sub_offset.row() + n_upper, sub_offset.col()}, {n_lower, c}}); + + GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub); + } + + // copy deflated from e1 to e0 + if (k_lc < dist_sub.localSize().cols()) { + const SizeType k = dist_sub.globalElementFromLocalElement(k_lc); + const matrix::internal::SubMatrixSpec deflated_submat{{sub_offset.row(), sub_offset.col() + k}, + {n, n - k}}; + MatrixRef sub_e0(e0, deflated_submat); + MatrixRef sub_e1(e1, deflated_submat); + + copy(sub_e1, sub_e0); + } + + namespace tt = pika::this_thread::experimental; + tt::sync_wait(sub_comm_row()); + tt::sync_wait(sub_comm_col()); + })); // Step #4: Final permutation to sort eigenvalues and eigenvectors //