From 477b2e2e0920b06ee0648031e697371c3b22949c Mon Sep 17 00:00:00 2001 From: Alberto Invernizzi Date: Fri, 15 Sep 2023 12:20:50 +0200 Subject: [PATCH] merge-squashed: propedeuthic changes towards gemm cost reduction make rank1 work just on non-deflated (single-threaded) --- .../dlaf/eigensolver/tridiag_solver/impl.h | 12 +- .../dlaf/eigensolver/tridiag_solver/merge.h | 539 +++++++++--------- 2 files changed, 269 insertions(+), 282 deletions(-) diff --git a/include/dlaf/eigensolver/tridiag_solver/impl.h b/include/dlaf/eigensolver/tridiag_solver/impl.h index d12cf7cd8c..9ab2217d1c 100644 --- a/include/dlaf/eigensolver/tridiag_solver/impl.h +++ b/include/dlaf/eigensolver/tridiag_solver/impl.h @@ -219,7 +219,8 @@ void TridiagSolver::call(Matrix& tridiag, Matrix& Matrix(vec_size, vec_tile_size), // z0 Matrix(vec_size, vec_tile_size), // z1 Matrix(vec_size, vec_tile_size), // i2 - Matrix(vec_size, vec_tile_size)}; // i5 + Matrix(vec_size, vec_tile_size), // i5 + Matrix(vec_size, vec_tile_size)}; // i6 WorkSpaceHost ws_h{Matrix(vec_size, vec_tile_size), // d0 Matrix(vec_size, vec_tile_size), // c @@ -380,7 +381,8 @@ void TridiagSolver::call(comm::CommunicatorGrid grid, Matrix(dist_evals), // z0 Matrix(dist_evals), // z1 Matrix(dist_evals), // i2 - Matrix(dist_evals)}; // i5 + Matrix(dist_evals), // i5 + Matrix(dist_evals)}; // i6 WorkSpaceHost ws_h{Matrix(dist_evals), // d0 Matrix(dist_evals), // c @@ -392,7 +394,7 @@ void TridiagSolver::call(comm::CommunicatorGrid grid, Matrix ws_hm{initMirrorMatrix(ws.e0), initMirrorMatrix(ws.e2), initMirrorMatrix(ws.d1), initMirrorMatrix(ws.z0), initMirrorMatrix(ws.z1), initMirrorMatrix(ws.i2), - initMirrorMatrix(ws.i5)}; + initMirrorMatrix(ws.i5), initMirrorMatrix(ws.i6)}; // Set `ws.e0` to `zero` (needed for Given's rotation to make sure no random values are picked up) matrix::util::set0(thread_priority::normal, ws.e0); @@ -426,12 +428,12 @@ void TridiagSolver::call(comm::CommunicatorGrid grid, Matrix(grid, row_task_chain, 0, n, - ws_hm.i2, ws_hm.e0, ws_hm.e2); + ws_h.i1, ws_hm.e0, ws_hm.e2); copy(ws_hm.e2, evecs); } diff --git a/include/dlaf/eigensolver/tridiag_solver/merge.h b/include/dlaf/eigensolver/tridiag_solver/merge.h index 31fe2dd8f2..fbfcd5c1ca 100644 --- a/include/dlaf/eigensolver/tridiag_solver/merge.h +++ b/include/dlaf/eigensolver/tridiag_solver/merge.h @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -120,6 +121,7 @@ struct WorkSpace { Matrix i2; Matrix i5; + Matrix i6; }; template @@ -162,6 +164,7 @@ struct DistWorkSpaceHostMirror { HostMirrorMatrix i2; HostMirrorMatrix i5; + HostMirrorMatrix i6; }; template @@ -409,11 +412,12 @@ auto stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* typ // @param index_sorted_coltype array[n] local(sort(upper)|sort(dense)|sort(lower)|sort(deflated))) -> initial // // @return k number of non-deflated eigenvectors +// @return k_local number of local non-deflated eigenvectors template -SizeType stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub, const SizeType n, - const ColType* types, const T* evals, - const SizeType* perm_sorted, SizeType* index_sorted, - SizeType* index_sorted_coltype) { +auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub, const SizeType n, + const ColType* types, const T* evals, SizeType* perm_sorted, + SizeType* index_sorted, SizeType* index_sorted_coltype, + SizeType* i4, SizeType* i6) { // Note: // (in) types // column type of the initial indexing @@ -484,6 +488,9 @@ SizeType 
stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist std::partial_sum(rank_offsets.cbegin(), rank_offsets.cend(), rank_offsets.begin()); }); + const SizeType k_lc = + to_SizeType(offsets[to_sizet(dist_sub.rankIndex().col())][ev_sort_order(ColType::Deflated)]); + for (SizeType j_el = 0; j_el < nperms; ++j_el) { const SizeType jj_el = index_sorted[to_sizet(j_el)]; const ColType coltype = types[to_sizet(jj_el)]; @@ -499,7 +506,34 @@ SizeType stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist index_sorted_coltype[to_sizet(jjj_el)] = jj_el; } - return k; + // invertIndex i3->i2 + // i3 (in) : initial <--- deflated + // i2 (out) : initial ---> deflated + for (SizeType i = 0; i < n; ++i) + perm_sorted[index_sorted[i]] = i; + + // compose i5*i2 (!i3) -> i4 + // i5 (in) : initial <--- sort by coltype + // i2 (in) : deflated <--- initial + // i4 (out) : deflated <--- sort by col type + for (SizeType i = 0; i < n; ++i) + i4[i] = perm_sorted[index_sorted_coltype[i]]; + + // create i6 using i5 and i4 for deflated + for (SizeType j_el = 0, jnd_el = 0; j_el < nperms; ++j_el) { + const SizeType jj_el = index_sorted_coltype[to_sizet(j_el)]; + const ColType coltype = types[to_sizet(jj_el)]; + + if (ColType::Deflated != coltype) { + i6[j_el] = jnd_el; + ++jnd_el; + } + else { + i6[j_el] = i4[j_el]; + } + } + + return std::tuple(k, k_lc); } template @@ -532,12 +566,12 @@ auto stablePartitionIndexForDeflation( } template -auto stablePartitionIndexForDeflation(const matrix::Distribution& dist_evecs, const SizeType i_begin, - const SizeType i_end, Matrix& c, - Matrix& evals, - Matrix& in, - Matrix& out, - Matrix& out_by_coltype) { +auto stablePartitionIndexForDeflation( + const matrix::Distribution& dist_evecs, const SizeType i_begin, const SizeType i_end, + Matrix& c, Matrix& evals, + Matrix& in, Matrix& out, + Matrix& out_by_coltype, Matrix& i4, + Matrix& i6) { namespace ex = pika::execution::experimental; namespace di = dlaf::internal; using pika::execution::thread_stacksize; @@ -550,22 +584,26 @@ auto stablePartitionIndexForDeflation(const matrix::Distribution& dist_evecs, co auto part_fn = [n, dist_evecs_sub](const auto& c_tiles_futs, const auto& evals_tiles_futs, const auto& in_tiles_futs, const auto& out_tiles, - const auto& out_coltype_tiles) { + const auto& out_coltype_tiles, const auto& i4_tiles, + const auto& i6_tiles) { const TileElementIndex zero_idx(0, 0); const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx); const T* evals_ptr = evals_tiles_futs[0].get().ptr(zero_idx); - const SizeType* in_ptr = in_tiles_futs[0].get().ptr(zero_idx); + SizeType* in_ptr = in_tiles_futs[0].ptr(zero_idx); SizeType* out_ptr = out_tiles[0].ptr(zero_idx); SizeType* out_coltype_ptr = out_coltype_tiles[0].ptr(zero_idx); + SizeType* i4_ptr = i4_tiles[0].ptr(zero_idx); + SizeType* i6_ptr = i6_tiles[0].ptr(zero_idx); return stablePartitionIndexForDeflationArrays(dist_evecs_sub, n, c_ptr, evals_ptr, in_ptr, out_ptr, - out_coltype_ptr); + out_coltype_ptr, i4_ptr, i6_ptr); }; TileCollector tc{i_begin, i_end}; return ex::when_all(ex::when_all_vector(tc.read(c)), ex::when_all_vector(tc.read(evals)), - ex::when_all_vector(tc.read(in)), ex::when_all_vector(tc.readwrite(out)), - ex::when_all_vector(tc.readwrite(out_by_coltype))) | + ex::when_all_vector(tc.readwrite(in)), ex::when_all_vector(tc.readwrite(out)), + ex::when_all_vector(tc.readwrite(out_by_coltype)), + ex::when_all_vector(tc.readwrite(i4)), ex::when_all_vector(tc.readwrite(i6))) | 
di::transform(di::Policy(thread_stacksize::nostack), std::move(part_fn)); } @@ -1163,13 +1201,13 @@ struct ScopedSenderWait { } }; -template +template void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const SizeType i_begin, - const SizeType i_end, const LocalTileIndex ij_begin_lc, - const LocalTileSize sz_loc_tiles, KSender&& k, RhoSender&& rho, + const SizeType i_end, KSender&& k, KLcSender&& k_lc, RhoSender&& rho, Matrix& d, Matrix& z, - Matrix& evals, Matrix& i2, - Matrix& evecs) { + Matrix& evals, Matrix& i4, + Matrix& i6, + Matrix& i2, Matrix& evecs) { namespace ex = pika::execution::experimental; namespace di = dlaf::internal; namespace tt = pika::this_thread::experimental; @@ -1181,17 +1219,11 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S const SizeType n = problemSize(i_begin, i_end, dist); - const SizeType m_subm_el_lc = [=]() { - const auto i_loc_begin = ij_begin_lc.row(); - const auto i_loc_end = ij_begin_lc.row() + sz_loc_tiles.rows(); - return dist.localElementDistanceFromLocalTile(i_loc_begin, i_loc_end); - }(); - - const SizeType n_subm_el_lc = [=]() { - const auto i_loc_begin = ij_begin_lc.col(); - const auto i_loc_end = ij_begin_lc.col() + sz_loc_tiles.cols(); - return dist.localElementDistanceFromLocalTile(i_loc_begin, i_loc_end); - }(); + using dlaf::matrix::internal::distribution::global_tile_element_distance; + const matrix::Distribution dist_sub( + dist, {{i_begin * dist.blockSize().rows(), i_begin * dist.blockSize().cols()}, + {global_tile_element_distance(dist, i_begin, i_end), + global_tile_element_distance(dist, i_begin, i_end)}}); auto bcast_evals = [i_begin, i_end, dist](common::Pipeline& row_comm_chain, @@ -1227,7 +1259,7 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S }; // Note: at least two column of tiles per-worker, in the range [1, getTridiagRank1NWorkers()] - const std::size_t nthreads = [nrtiles = sz_loc_tiles.cols()]() { + const std::size_t nthreads = [nrtiles = dist_sub.localNrTiles().cols()]() { const std::size_t min_workers = 1; const std::size_t available_workers = getTridiagRank1NWorkers(); const std::size_t ideal_workers = util::ceilDiv(to_sizet(nrtiles), to_sizet(2)); @@ -1237,62 +1269,68 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S ex::start_detached( ex::when_all(ex::just(std::make_unique>(nthreads)), std::forward(row_comm), std::forward(col_comm), - std::forward(k), std::forward(rho), + std::forward(k), std::forward(k_lc), std::forward(rho), ex::when_all_vector(tc.read(d)), ex::when_all_vector(tc.readwrite(z)), - ex::when_all_vector(tc.readwrite(evals)), ex::when_all_vector(tc.read(i2)), + ex::when_all_vector(tc.readwrite(evals)), ex::when_all_vector(tc.read(i4)), + ex::when_all_vector(tc.read(i6)), ex::when_all_vector(tc.read(i2)), ex::when_all_vector(tc.readwrite(evecs)), // additional workspaces ex::just(std::vector>()), ex::just(memory::MemoryView())) | ex::transfer(di::getBackendScheduler(thread_priority::high)) | - ex::bulk(nthreads, [nthreads, n, n_subm_el_lc, m_subm_el_lc, i_begin, ij_begin_lc, sz_loc_tiles, - dist, bcast_evals, all_reduce_in_place]( + ex::bulk(nthreads, [nthreads, n, dist_sub, bcast_evals, all_reduce_in_place]( const std::size_t thread_idx, auto& barrier_ptr, auto& row_comm_wrapper, - auto& col_comm_wrapper, const auto& k, const auto& rho, - const auto& d_tiles_futs, auto& z_tiles, const auto& eval_tiles, - const auto& i2_tile_arr, const auto& evec_tiles, auto& ws_cols, + auto& 
col_comm_wrapper, const SizeType k, const SizeType k_lc, + const auto& rho, const auto& d_tiles_futs, auto& z_tiles, + const auto& eval_tiles, const auto& i4_tiles_arr, const auto& i6_tiles_arr, + const auto& i2_tiles_arr, const auto& evec_tiles, auto& ws_cols, auto& ws_row) { using dlaf::comm::internal::transformMPI; common::Pipeline row_comm_chain(row_comm_wrapper.get()); const dlaf::comm::Communicator& col_comm = col_comm_wrapper.get(); + const SizeType m_el_lc = dist_sub.localSize().rows(); + const SizeType n_el_lc = dist_sub.localSize().cols(); + const auto barrier_busy_wait = getTridiagRank1BarrierBusyWait(); - const std::size_t batch_size = - std::max(2, util::ceilDiv(to_sizet(sz_loc_tiles.cols()), nthreads)); - const SizeType begin = to_SizeType(thread_idx * batch_size); - const SizeType end = std::min(to_SizeType((thread_idx + 1) * batch_size), sz_loc_tiles.cols()); - // STEP 0a: Fill ones for deflated Eigenvectors. (single-thread) + // const std::size_t batch_size = + // std::max(2, util::ceilDiv(to_sizet(dist_sub.localNrTiles().cols()), nthreads)); + // const SizeType begin = to_SizeType(thread_idx * batch_size); + // const SizeType end = + // std::min(to_SizeType((thread_idx + 1) * batch_size), dist_sub.localNrTiles().cols()); + + const SizeType* i4 = i4_tiles_arr[0].get().ptr(); + const SizeType* i2 = i2_tiles_arr[0].get().ptr(); + const SizeType* i6 = i6_tiles_arr[0].get().ptr(); + + // STEP 0a: Fill ones for deflated Eigenvectors and copy related Eigenvalues (single-thread) // Note: this step is completely independent from the rest, but it is small and it is going // to be dropped soon. // Note: use last threads that in principle should have less work to do - if (thread_idx == nthreads - 1) { - // just if there are deflated eigenvectors - if (k < n) { - const GlobalElementSize origin_el(i_begin * dist.blockSize().rows(), - i_begin * dist.blockSize().cols()); - const SizeType* i2_perm = i2_tile_arr[0].get().ptr(); - - for (SizeType i_subm_el = 0; i_subm_el < n; ++i_subm_el) { - const SizeType j_subm_el = i2_perm[i_subm_el]; - - // if it is a deflated vector - if (j_subm_el >= k) { - const GlobalElementIndex ij_el(origin_el.rows() + i_subm_el, - origin_el.cols() + j_subm_el); - const GlobalTileIndex ij = dist.globalTileIndex(ij_el); - - if (dist.rankIndex() == dist.rankGlobalTile(ij)) { - const LocalTileIndex ij_lc = dist.localTileIndex(ij); - const SizeType linear_subm_lc = - (ij_lc.row() - ij_begin_lc.row()) + - (ij_lc.col() - ij_begin_lc.col()) * sz_loc_tiles.rows(); - const TileElementIndex ij_el_tl = dist.tileElementIndex(ij_el); - evec_tiles[to_sizet(linear_subm_lc)](ij_el_tl) = T{1}; - } - } + if (k < n && thread_idx == nthreads - 1) { + const T* eval_initial_ptr = d_tiles_futs[0].get().ptr(); + T* eval_ptr = eval_tiles[0].ptr(); + + for (SizeType jeg_el_lc = k_lc; jeg_el_lc < n_el_lc; ++jeg_el_lc) { + const SizeType jeg_el = dist_sub.globalElementFromLocalElement(jeg_el_lc); + + const SizeType ieg_el = jeg_el; + + if (dist_sub.rankIndex().row() == dist_sub.rankGlobalElement(ieg_el)) { + const SizeType ieg_el_lc = dist_sub.localElementFromGlobalElement(ieg_el); + const LocalTileIndex ieg_lc{dist_sub.localTileFromLocalElement(ieg_el_lc), + dist_sub.localTileFromLocalElement(jeg_el_lc)}; + const SizeType linear_lc = dist_sub.localTileLinearIndex(ieg_lc); + const TileElementIndex + ijeg_el_tl{dist_sub.tileElementFromLocalElement(ieg_el_lc), + dist_sub.tileElementFromLocalElement(jeg_el_lc)}; + + evec_tiles[to_sizet(linear_lc)](ijeg_el_tl) = T{1}; } + + eval_ptr[jeg_el] = 
eval_initial_ptr[i6[jeg_el]]; } } @@ -1306,17 +1344,17 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S // Note: // Considering that // - LAED4 requires working on k elements - // - Weight computation requires working on m_subm_el_lc + // - Weight computation requires working on m_el_lc // // and they are needed at two steps that cannot happen in parallel, we opted for allocating // the workspace with the highest requirement of memory, and reuse them for both steps. - const SizeType max_size = std::max(k, m_subm_el_lc); + const SizeType max_size = std::max(k, m_el_lc); for (std::size_t i = 0; i < nthreads; ++i) ws_cols.emplace_back(max_size); - ws_cols.emplace_back(m_subm_el_lc); + ws_cols.emplace_back(m_el_lc); - ws_row = memory::MemoryView(n_subm_el_lc); - std::fill_n(ws_row(), n_subm_el_lc, 0); + ws_row = memory::MemoryView(n_el_lc); + std::fill_n(ws_row(), n_el_lc, 0); } // Note: we have to wait that LAED4 workspaces are ready to be used @@ -1326,52 +1364,43 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S const T* z_ptr = z_tiles[0].ptr(); // STEP 1: LAED4 (multi-thread) - { - common::internal::SingleThreadedBlasScope single; + if (thread_idx == 0) { // TODO make it multi-threaded over multiple workers + common::internal::SingleThreadedBlasScope single; // TODO needed also for laed? T* eval_ptr = eval_tiles[0].ptr(); T* delta_ptr = ws_cols[thread_idx](); - for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) { - const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc); - const SizeType j = dist.globalTileFromLocalTile(j_lc); - const SizeType n_subm_el = dist.globalTileElementDistance(i_begin, j); - - // Skip columns that are in the deflation zone - if (n_subm_el >= k) - break; - - const SizeType n_el_tl = std::min(dist.tileSize(j), k - n_subm_el); - for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) { - const SizeType j_el = n_subm_el + j_el_tl; - - // Solve the deflated rank-1 problem - T& eigenval = eval_ptr[to_sizet(j_el)]; - lapack::laed4(to_int(k), to_int(j_el), d_ptr, z_ptr, delta_ptr, rho, &eigenval); - - // copy the parts from delta stored on this rank - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType linear_subm_lc = i_subm_lc + to_SizeType(j_subm_lc) * sz_loc_tiles.rows(); - auto& evec_tile = evec_tiles[to_sizet(linear_subm_lc)]; - - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType m_subm_el = dist.globalTileElementDistance(i_begin, i); - - const SizeType i_subm = i - i_begin; - const auto& i2_perm = i2_tile_arr[to_sizet(i_subm)].get(); - - const SizeType m_el_tl = std::min(dist.tileSize(i), n - m_subm_el); - for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { - const SizeType jj_subm_el = i2_perm({i_el_tl, 0}); - if (jj_subm_el < k) - evec_tile({i_el_tl, j_el_tl}) = delta_ptr[jj_subm_el]; - } + for (SizeType jeg_el_lc = 0; jeg_el_lc < k_lc; ++jeg_el_lc) { + const SizeType jeg_el = dist_sub.globalElementFromLocalElement(jeg_el_lc); + const SizeType jeg_lc = dist_sub.localTileFromLocalElement(jeg_el_lc); + + // Solve the deflated rank-1 problem + // Note: + // it solves considering the order in the original fully sorted non-deflated (i3) + // but it stores it in extended global (as eigenvectors are stored in E1) + const SizeType js_el = i6[jeg_el]; + T& eigenval = eval_ptr[to_sizet(jeg_el)]; // eval is in compact rank layout + lapack::laed4(to_signed(k), 
to_signed(js_el), d_ptr, z_ptr, delta_ptr, rho, + &eigenval); + + // Now laed4 result has to be copied in the right spot + const SizeType jeg_el_tl = dist_sub.tileElementFromGlobalElement(jeg_el); + + for (SizeType i_el_lc = 0; i_el_lc < m_el_lc; ++i_el_lc) { + const SizeType i_el = dist_sub.globalElementFromLocalElement(i_el_lc); + const SizeType is_el = i4[i_el]; + + // just non-deflated, because deflated have been already set to 0 + if (is_el < k) { + const SizeType i_lc = dist_sub.localTileFromLocalElement(i_el_lc); + const SizeType linear_lc = dist_sub.localTileLinearIndex({i_lc, jeg_lc}); + const auto& evec = evec_tiles[to_sizet(linear_lc)]; + const SizeType i_el_tl = dist_sub.tileElementFromLocalElement(i_el_lc); + evec({i_el_tl, jeg_el_tl}) = delta_ptr[is_el]; } } } } - // Note: This barrier ensures that LAED4 finished, so from now on values are available barrier_ptr->arrive_and_wait(barrier_busy_wait); @@ -1391,83 +1420,69 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S T* w = ws_cols[thread_idx](); // STEP 2a: copy diagonal from q -> w (or just initialize with 1) + // Note: + // Loop over compact rank (=expanded global) up to k_el_lc + // index on k_el_lc has to be converted to global element on k_el, so it can be used with + // permutations + // during the switch from col axis to row axis we must keep the matching between eigenvectors if (thread_idx == 0) { - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType i_subm_el = dist.globalTileElementDistance(i_begin, i); - const SizeType m_subm_el_lc = - dist.localElementDistanceFromLocalTile(ij_begin_lc.row(), i_lc); - const auto& i2 = i2_tile_arr[to_sizet(i - i_begin)].get(); - - const SizeType m_el_tl = std::min(dist.tileSize(i), n - i_subm_el); - for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { - const SizeType i_subm_el_lc = m_subm_el_lc + i_el_tl; - - const SizeType jj_subm_el = i2({i_el_tl, 0}); - const SizeType n_el = dist.globalTileElementDistance(0, i_begin); - const SizeType jj_el = n_el + jj_subm_el; - const SizeType jj = dist.globalTileFromGlobalElement(jj_el); - - if (dist.rankGlobalTile(jj) == dist.rankIndex().col()) { - const SizeType jj_lc = dist.localTileFromGlobalTile(jj); - const SizeType jj_subm_lc = jj_lc - ij_begin_lc.col(); - const SizeType jj_el_tl = dist.tileElementFromGlobalElement(jj_el); - - const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * jj_subm_lc; - - w[i_subm_el_lc] = q[to_sizet(linear_subm_lc)]({i_el_tl, jj_el_tl}); - } - else { - w[i_subm_el_lc] = T(1); - } + for (SizeType ieg_el_lc = 0; ieg_el_lc < m_el_lc; ++ieg_el_lc) { + const SizeType ieg_el = dist_sub.globalElementFromLocalElement(ieg_el_lc); + const SizeType is_el = i4[ieg_el]; + + if (is_el >= k) { + w[ieg_el_lc] = T{0}; + continue; + } + + const SizeType js_el = is_el; + const SizeType jeg_el = i2[js_el]; + + const GlobalElementIndex ijeg_subm_el(ieg_el, jeg_el); + + if (dist_sub.rankIndex().col() == dist_sub.rankGlobalElement(jeg_el)) { + const SizeType linear_subm_lc = dist_sub.localTileLinearIndex( + {dist_sub.localTileFromLocalElement(ieg_el_lc), + dist_sub.localTileFromGlobalElement(jeg_el)}); + const TileElementIndex ij_tl = dist_sub.tileElementIndex(ijeg_subm_el); + w[ieg_el_lc] = q[to_sizet(linear_subm_lc)](ij_tl); + } + else { + w[ieg_el_lc] = T{1}; } } } else { // other workers - std::fill_n(w, m_subm_el_lc, T(1)); + 
std::fill_n(w, m_el_lc, T(1)); } barrier_ptr->arrive_and_wait(barrier_busy_wait); // STEP 2b: compute weights - for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) { - const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc); - const SizeType j = dist.globalTileFromLocalTile(j_lc); - const SizeType n_subm_el = dist.globalTileElementDistance(i_begin, j); - - // Skip columns that are in the deflation zone - if (n_subm_el >= k) - break; - - const SizeType n_el_tl = std::min(dist.tileSize(j), k - n_subm_el); - for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) { - const SizeType j_subm_el = n_subm_el + j_el_tl; - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType m_subm_el = dist.globalTileElementDistance(i_begin, i); - - auto& i2_perm = i2_tile_arr[to_sizet(i - i_begin)].get(); - - const SizeType m_el_tl = std::min(dist.tileSize(i), n - m_subm_el); - for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { - const SizeType ii_subm_el = i2_perm({i_el_tl, 0}); - - // deflated zone - if (ii_subm_el >= k) - continue; - - // diagonal - if (ii_subm_el == j_subm_el) - continue; - - const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc; - const SizeType i_subm_el_lc = i_subm_lc * dist.blockSize().rows() + i_el_tl; - - w[i_subm_el_lc] *= q[to_sizet(linear_subm_lc)]({i_el_tl, j_el_tl}) / - (d_ptr[to_sizet(ii_subm_el)] - d_ptr[to_sizet(j_subm_el)]); - } + if (thread_idx == 0) { // TODO make it multithreaded again + for (SizeType jeg_el_lc = 0; jeg_el_lc < k_lc; ++jeg_el_lc) { + const SizeType jeg_el = dist_sub.globalElementFromLocalElement(jeg_el_lc); + const SizeType js_el = i6[jeg_el]; + + for (SizeType ieg_el_lc = 0; ieg_el_lc < m_el_lc; ++ieg_el_lc) { + const SizeType ieg_el = dist_sub.globalElementFromLocalElement(ieg_el_lc); + const SizeType is_el = i4[ieg_el]; + + // skip if deflated + if (is_el >= k) + continue; + + // skip if originally it was on the diagonal + if (is_el == js_el) + continue; + + const SizeType linear_lc = dist_sub.localTileLinearIndex( + {dist_sub.localTileFromLocalElement(ieg_el_lc), + dist_sub.localTileFromLocalElement(jeg_el_lc)}); + const TileElementIndex ij_tl = dist_sub.tileElementIndex({ieg_el, jeg_el}); + + w[ieg_el_lc] *= + q[to_sizet(linear_lc)](ij_tl) / (d_ptr[to_sizet(is_el)] - d_ptr[to_sizet(js_el)]); } } } @@ -1477,7 +1492,7 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S // STEP 2c: reduce, then finalize computation with sign and square root (single-thread) if (thread_idx == 0) { // local reduction from all bulk workers - for (SizeType i = 0; i < m_subm_el_lc; ++i) { + for (SizeType i = 0; i < m_el_lc; ++i) { for (std::size_t tidx = 1; tidx < nthreads; ++tidx) { const T* w_partial = ws_cols[tidx](); w[i] *= w_partial[i]; @@ -1485,22 +1500,17 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S } tt::sync_wait(ex::when_all(row_comm_chain(), - ex::just(MPI_PROD, common::make_data(w, m_subm_el_lc))) | + ex::just(MPI_PROD, common::make_data(w, m_el_lc))) | transformMPI(all_reduce_in_place)); + // TODO check all weights < 0 (!= 0 otherwise q elements are set to zero and then nomr = 0 => nan) + T* weights = ws_cols[nthreads](); - for (SizeType i_subm_el_lc = 0; i_subm_el_lc < m_subm_el_lc; ++i_subm_el_lc) { - const SizeType i_subm_lc = i_subm_el_lc / dist.blockSize().rows(); - const SizeType i_lc = 
ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType i_subm = i - i_begin; - const SizeType i_subm_el = - i_subm * dist.blockSize().rows() + i_subm_el_lc % dist.blockSize().rows(); - - const auto* i2_perm = i2_tile_arr[0].get().ptr(); - const SizeType ii_subm_el = i2_perm[i_subm_el]; - weights[to_sizet(i_subm_el_lc)] = - std::copysign(std::sqrt(-w[i_subm_el_lc]), z_ptr[to_sizet(ii_subm_el)]); + // TODO this can be limited to k_lc + for (SizeType i_el_lc = 0; i_el_lc < m_el_lc; ++i_el_lc) { + const SizeType i_el = dist_sub.globalElementFromLocalElement(i_el_lc); + const SizeType ii_el = i4[i_el]; + weights[to_sizet(i_el_lc)] = std::copysign(std::sqrt(-w[i_el_lc]), z_ptr[to_sizet(ii_el)]); } } @@ -1509,49 +1519,43 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S // STEP 3: Compute eigenvectors of the modified rank-1 modification (normalize) (multi-thread) // STEP 3a: Form evecs using weights vector and compute (local) sum of squares - { + if (thread_idx == 0) { // TODO make it multithreaded again common::internal::SingleThreadedBlasScope single; const T* w = ws_cols[nthreads](); T* sum_squares = ws_row(); - for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) { - const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc); - const SizeType j = dist.globalTileFromLocalTile(j_lc); - const SizeType n_subm_el = dist.globalTileElementDistance(i_begin, j); - - // Skip columns that are in the deflation zone - if (n_subm_el >= k) - break; - - const SizeType n_el_tl = std::min(dist.tileSize(j), k - n_subm_el); - for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) { - const SizeType j_subm_el_lc = j_subm_lc * dist.blockSize().cols() + j_el_tl; - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType m_subm_el = dist.globalTileElementDistance(i_begin, i); - - const SizeType i_subm = i - i_begin; - const auto& i2_perm = i2_tile_arr[to_sizet(i_subm)].get(); - - const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc; - const auto& q_tile = q[to_sizet(linear_subm_lc)]; - - const SizeType m_el_tl = std::min(dist.tileSize(i), n - m_subm_el); - for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { - const SizeType ii_subm_el = i2_perm({i_el_tl, 0}); - - const SizeType i_subm_el_lc = i_subm_lc * dist.blockSize().rows() + i_el_tl; - if (ii_subm_el >= k) - q_tile({i_el_tl, j_el_tl}) = 0; - else - q_tile({i_el_tl, j_el_tl}) = w[i_subm_el_lc] / q_tile({i_el_tl, j_el_tl}); - } + for (SizeType jeg_el_lc = 0; jeg_el_lc < k_lc; ++jeg_el_lc) { + const SizeType jeg_lc = dist_sub.localTileFromLocalElement(jeg_el_lc); + const SizeType jeg_el_tl = dist_sub.tileElementFromLocalElement(jeg_el_lc); - sum_squares[j_subm_el_lc] += - blas::dot(m_el_tl, q_tile.ptr({0, j_el_tl}), 1, q_tile.ptr({0, j_el_tl}), 1); - } + for (SizeType ieg_el_lc = 0; ieg_el_lc < m_el_lc; ++ieg_el_lc) { + const SizeType ieg_el = dist_sub.globalElementFromLocalElement(ieg_el_lc); + const SizeType is_el = i4[ieg_el]; + + // it is a deflated row, skip it (it should be already 0) + if (is_el >= k) + continue; + + const LocalTileIndex ijeg_lc(dist_sub.localTileFromLocalElement(ieg_el_lc), + jeg_lc); + const SizeType ijeg_linear = dist_sub.localTileLinearIndex(ijeg_lc); + const TileElementIndex ijeg_el_tl( + dist_sub.tileElementFromLocalElement(ieg_el_lc), jeg_el_tl); + + const 
auto& q_tile = q[to_sizet(ijeg_linear)]; + + q_tile(ijeg_el_tl) = w[ieg_el_lc] / q_tile(ijeg_el_tl); + } + + // column-major once the full column has been updated, compute the sum of squares (for norm) + for (SizeType i_lc = 0; i_lc < dist_sub.localNrTiles().rows(); ++i_lc) { + const LocalTileIndex ijeg_lc(i_lc, jeg_lc); + const SizeType ijeg_linear = dist_sub.localTileLinearIndex(ijeg_lc); + const T* partial_evec = q[to_sizet(ijeg_linear)].ptr({0, jeg_el_tl}); + const SizeType i = dist_sub.globalTileFromLocalTile(i_lc); + const SizeType m_el_tl = dist_sub.tileSize(i); + sum_squares[jeg_el_lc] += blas::dot(m_el_tl, partial_evec, 1, partial_evec, 1); } } } @@ -1560,41 +1564,33 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S // STEP 3b: Reduce to get the sum of all squares on all ranks if (thread_idx == 0) - tt::sync_wait(ex::just(std::cref(col_comm), MPI_SUM, - common::make_data(ws_row(), n_subm_el_lc)) | + // TODO it can be limited to k_lc + tt::sync_wait(ex::just(std::cref(col_comm), MPI_SUM, common::make_data(ws_row(), n_el_lc)) | transformMPI(all_reduce_in_place)); barrier_ptr->arrive_and_wait(barrier_busy_wait); // STEP 3c: Normalize (compute norm of each column and scale column vector) - { + if (thread_idx == 0) { // TODO make it multithreaded again common::internal::SingleThreadedBlasScope single; const T* sum_squares = ws_row(); - for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) { - const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc); - const SizeType j = dist.globalTileFromLocalTile(j_lc); - const SizeType n_subm_el = dist.globalTileElementDistance(i_begin, j); + for (SizeType jeg_el_lc = 0; jeg_el_lc < k_lc; ++jeg_el_lc) { + const SizeType jeg_lc = dist_sub.localTileFromLocalElement(jeg_el_lc); + const SizeType jeg_el_tl = dist_sub.tileElementFromLocalElement(jeg_el_lc); - // Skip columns that are in the deflation zone - if (n_subm_el >= k) - break; + const T vec_norm = std::sqrt(sum_squares[jeg_el_lc]); - const SizeType n_el_tl = std::min(dist.tileSize(j), k - n_subm_el); - for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) { - const SizeType j_subm_el_lc = j_subm_lc * dist.blockSize().cols() + j_el_tl; - const T vec_norm = std::sqrt(sum_squares[j_subm_el_lc]); + for (SizeType i_lc = 0; i_lc < dist_sub.localNrTiles().rows(); ++i_lc) { + const LocalTileIndex ijeg_lc(i_lc, jeg_lc); + const SizeType ijeg_linear = dist_sub.localTileLinearIndex(ijeg_lc); - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc; - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType m_subm_el = dist.globalTileElementDistance(i_begin, i); + T* partial_evec = q[to_sizet(ijeg_linear)].ptr({0, jeg_el_tl}); - const SizeType m_el_tl = std::min(dist.tileSize(i), n - m_subm_el); - blas::scal(m_el_tl, 1 / vec_norm, q[to_sizet(linear_subm_lc)].ptr({0, j_el_tl}), 1); - } + const SizeType i = dist_sub.globalTileFromLocalTile(i_lc); + const SizeType m_el_tl = dist_sub.tileSize(i); + blas::scal(m_el_tl, 1 / vec_norm, partial_evec, 1); } } } @@ -1681,8 +1677,13 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`. 
// - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented) // - auto k = ex::split(stablePartitionIndexForDeflation(dist_evecs, i_begin, i_end, ws_h.c, ws_h.d0, - ws_hm.i2, ws_h.i3, ws_hm.i5)); + auto [k_unique, k_lc] = + ex::split_tuple(stablePartitionIndexForDeflation(dist_evecs, i_begin, i_end, ws_h.c, ws_h.d0, + ws_hm.i2, ws_h.i3, ws_hm.i5, ws_h.i4, ws_hm.i6)); + // from now on i2 is the inverse of i6 + invertIndex(i_begin, i_end, ws_hm.i6, ws_hm.i2); + + auto k = ex::split(std::move(k_unique)); // Reorder Eigenvectors if constexpr (Backend::MC == B) { @@ -1701,28 +1702,12 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // Reorder Eigenvalues applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1); applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1); - copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.d1, ws_h.d0); - - // - // i3 (in) : initial <--- deflated - // i2 (out) : initial ---> deflated - // - invertIndex(i_begin, i_end, ws_h.i3, ws_hm.i2); - - // - // i5 (in) : initial <--- sort by coltype - // i2 (in) : deflated <--- initial - // i4 (out) : deflated <--- sort by col type - // - // TODO This is propedeutic to work in rank1 solver with columns sorted by type, so that they are - // well-shaped for an optimized gemm, but still keeping track of where the actual position sorted by - // eigenvalues is. - applyIndex(i_begin, i_end, ws_hm.i5, ws_hm.i2, ws_h.i4); // Note: here ws_hm.z0 is used as a contiguous buffer for the laed4 call matrix::util::set0(thread_priority::normal, idx_loc_begin, sz_loc_tiles, ws_hm.e2); - solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, idx_loc_begin, sz_loc_tiles, - k, std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_h.i4, ws_hm.e2); + solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, k, std::move(k_lc), + std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_h.i4, ws_hm.i6, ws_hm.i2, + ws_hm.e2); // Step #3: Eigenvectors of the tridiagonal system: Q * U // @@ -1738,8 +1723,8 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // i1 (in) : deflated <--- deflated (identity map) // i2 (out) : deflated <--- post_sorted // - initIndex(i_begin, i_end, ws_h.i1); - sortIndex(i_begin, i_end, std::move(k), ws_h.d0, ws_h.i1, ws_hm.i2); - copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i2, ws_h.i1); + + // TODO merge sort + sortIndex(i_begin, i_end, std::move(k), ws_h.d0, ws_hm.i2, ws_h.i1); } }
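
Note on the index maps introduced above: i2 is rebuilt as the inverse of i3, i4 = i2 * i5 maps the coltype-sorted layout to the eigenvalue-sorted (deflated-last) layout, and i6 gives each coltype-sorted column either a compact 0..k-1 id (non-deflated) or its i4 position (deflated). The standalone sketch below exercises just these three loops on hypothetical toy permutations (plain std::vector in place of DLA-Future matrices; the values of i3, i5, k and the is_deflated flags are made up for illustration and are not taken from the patch).

    // Sketch of the index-map composition done in stablePartitionIndexForDeflationArrays,
    // on toy data. Names follow the i2/i3/i4/i5/i6 comments in the patch.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      using SizeType = std::ptrdiff_t;
      const SizeType n = 6;  // merged problem size (hypothetical)
      const SizeType k = 4;  // number of non-deflated eigenvectors (hypothetical)

      // Hypothetical column types, indexed by initial position: true = deflated.
      const std::vector<bool> is_deflated{false, true, false, true, false, false};

      // i3 : initial <--- deflated   (non-deflated sorted by eigenvalue, deflated last)
      const std::vector<SizeType> i3{5, 2, 0, 4, 3, 1};
      // i5 : initial <--- sort by coltype   (upper | dense | lower | deflated)
      const std::vector<SizeType> i5{2, 0, 5, 4, 1, 3};

      // invertIndex i3 -> i2 : initial ---> deflated
      std::vector<SizeType> i2(n);
      for (SizeType i = 0; i < n; ++i)
        i2[i3[i]] = i;

      // compose i4 = i2 * i5 : deflated <--- sort by coltype
      std::vector<SizeType> i4(n);
      for (SizeType i = 0; i < n; ++i)
        i4[i] = i2[i5[i]];

      // i6 : non-deflated columns get a compact 0..k-1 numbering in coltype order,
      //      deflated columns keep their position in the deflated ordering (from i4)
      std::vector<SizeType> i6(n);
      for (SizeType j = 0, jnd = 0; j < n; ++j) {
        if (!is_deflated[i5[j]])
          i6[j] = jnd++;
        else
          i6[j] = i4[j];
      }

      for (SizeType j = 0; j < n; ++j)
        std::cout << "coltype-sorted col " << j << ": i4=" << i4[j] << " i6=" << i6[j] << '\n';
    }

With these toy inputs the non-deflated columns are numbered 0..3 in coltype order, while the deflated tail of i6 comes out as {5, 4}, i.e. the two deflated columns appear in the opposite relative order with respect to the eigenvalue-sorted (i3) layout.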