diff --git a/include/dlaf/eigensolver/tridiag_solver/impl.h b/include/dlaf/eigensolver/tridiag_solver/impl.h index 04191b86e7..f1f6a44d20 100644 --- a/include/dlaf/eigensolver/tridiag_solver/impl.h +++ b/include/dlaf/eigensolver/tridiag_solver/impl.h @@ -216,7 +216,8 @@ void TridiagSolver::call(Matrix& tridiag, Matrix& Matrix(vec_size, vec_tile_size), // z0 Matrix(vec_size, vec_tile_size), // z1 Matrix(vec_size, vec_tile_size), // i2 - Matrix(vec_size, vec_tile_size)}; // i5 + Matrix(vec_size, vec_tile_size), // i5 + Matrix(vec_size, vec_tile_size)}; // i6 WorkSpaceHost ws_h{Matrix(vec_size, vec_tile_size), // d0 Matrix(vec_size, vec_tile_size), // c @@ -373,7 +374,8 @@ void TridiagSolver::call(comm::CommunicatorGrid grid, Matrix(dist_evals), // z0 Matrix(dist_evals), // z1 Matrix(dist_evals), // i2 - Matrix(dist_evals)}; // i5 + Matrix(dist_evals), // i5 + Matrix(dist_evals)}; // i6 WorkSpaceHost ws_h{Matrix(dist_evals), // d0 Matrix(dist_evals), // c @@ -384,7 +386,8 @@ void TridiagSolver::call(comm::CommunicatorGrid grid, Matrix ws_hm{initMirrorMatrix(ws.e0), initMirrorMatrix(ws.e2), initMirrorMatrix(ws.d1), initMirrorMatrix(ws.z0), - initMirrorMatrix(ws.z1), initMirrorMatrix(ws.i2)}; + initMirrorMatrix(ws.z1), initMirrorMatrix(ws.i2), + initMirrorMatrix(ws.i5), initMirrorMatrix(ws.i6)}; // Set `ws.e0` to `zero` (needed for Given's rotation to make sure no random values are picked up) matrix::util::set0(pika::execution::thread_priority::normal, ws.e0); @@ -418,12 +421,12 @@ void TridiagSolver::call(comm::CommunicatorGrid grid, Matrix(grid, row_task_chain, 0, n, - ws_hm.i2, ws_hm.e0, ws_hm.e2); + ws_h.i1, ws_hm.e0, ws_hm.e2); copy(ws_hm.e2, evecs); } diff --git a/include/dlaf/eigensolver/tridiag_solver/merge.h b/include/dlaf/eigensolver/tridiag_solver/merge.h index 1fb110525d..2d31647ac5 100644 --- a/include/dlaf/eigensolver/tridiag_solver/merge.h +++ b/include/dlaf/eigensolver/tridiag_solver/merge.h @@ -120,6 +120,7 @@ struct WorkSpace { Matrix i2; Matrix i5; + Matrix i6; }; template @@ -161,6 +162,8 @@ struct DistWorkSpaceHostMirror { HostMirrorMatrix z1; HostMirrorMatrix i2; + HostMirrorMatrix i5; + HostMirrorMatrix i6; }; template @@ -255,7 +258,7 @@ auto calcTolerance(const SizeType i_begin, const SizeType i_end, Matrix initial. -// -// The permutation will allow to keep the mapping between sorted eigenvalues and unsorted eigenvectors, -// which is useful since eigenvectors are more expensive to permuted, so we can keep them in their initial order. -// -// @param n number of eigenvalues -// @param c_ptr array[n] containing the column type of each eigenvector after deflation (initial order) -// @param evals_ptr array[n] of eigenvalues sorted as in_ptr -// @param in_ptr array[n] representing permutation current -> initial (i.e. 
evals[i] -> c_ptr[in_ptr[i]]) -// @param out_ptr array[n] permutation (sorted non-deflated | sorted deflated) -> initial -// -// @return k number of non-deflated eigenvectors -template -SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* c_ptr, - const T* evals_ptr, const SizeType* in_ptr, - SizeType* out_ptr) { - // Get the number of non-deflated entries - SizeType k = 0; - for (SizeType i = 0; i < n; ++i) { - if (c_ptr[i] != ColType::Deflated) - ++k; - } - - // Create the permutation (sorted non-deflated | sorted deflated) -> initial - // Note: - // Since during deflation, eigenvalues related to deflated eigenvectors, might not be sorted anymore, - // this step also take care of sorting eigenvalues (actually just their related index) by their ascending value. - SizeType i1 = 0; // index of non-deflated values in out - SizeType i2 = k; // index of deflated values - for (SizeType i = 0; i < n; ++i) { - const SizeType ii = in_ptr[i]; - - // non-deflated are untouched, just squeeze them at the beginning as they appear - if (c_ptr[ii] != ColType::Deflated) { - out_ptr[i1] = ii; - ++i1; - } - // deflated are the ones that can have been moved "out-of-order" by deflation... - // ... so each time insert it in the right place based on eigenvalue value - else { - const T a = evals_ptr[ii]; - - SizeType j = i2; - // shift to right all greater values (shift just indices) - for (; j > k; --j) { - const T b = evals_ptr[out_ptr[j - 1]]; - if (a > b) { - break; - } - out_ptr[j] = out_ptr[j - 1]; - } - // and insert the current index in the empty place, such that eigenvalues are sorted. - out_ptr[j] = ii; - ++i2; - } - } - return k; -} - // This function returns number of non-deflated eigenvectors and a tuple with number of upper, dense // and lower non-deflated eigenvectors, together with two permutations: // - @p index_sorted (sort(non-deflated)|sort(deflated) -> initial. 
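For reference, the (sorted non-deflated | sorted deflated) -> initial permutation built in the hunks below can be condensed into a small standalone routine; this is a minimal sketch with illustrative names, assuming a boolean deflation mask in place of ColType, and is not part of the patch:

#include <cstddef>
#include <vector>

// Builds out = (non-deflated in encounter order | deflated insertion-sorted by eigenvalue)
// and returns k, the number of non-deflated entries, mirroring the loop in the hunks below.
template <class T>
std::size_t stable_partition_sketch(const std::vector<bool>& deflated, const std::vector<T>& evals,
                                    const std::vector<std::size_t>& in, std::vector<std::size_t>& out) {
  const std::size_t n = in.size();
  std::size_t k = 0;
  for (std::size_t i = 0; i < n; ++i)
    if (!deflated[i])
      ++k;

  std::size_t i1 = 0;  // next slot in the non-deflated front
  std::size_t i2 = k;  // next slot in the deflated tail
  for (std::size_t i = 0; i < n; ++i) {
    const std::size_t ii = in[i];
    if (!deflated[ii]) {
      out[i1++] = ii;  // keep non-deflated entries in the order they appear
    }
    else {
      // shift to the right all deflated entries with a greater eigenvalue...
      std::size_t j = i2;
      for (; j > k; --j) {
        if (evals[ii] > evals[out[j - 1]])
          break;
        out[j] = out[j - 1];
      }
      out[j] = ii;  // ...and insert, so the deflated tail stays sorted
      ++i2;
    }
  }
  return k;
}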
@@ -364,16 +307,16 @@ auto stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* typ std::array offsets{0, 0, 0, 0}; std::for_each(types, types + n, [&offsets](const auto& coltype) { if (coltype != ColType::Deflated) - offsets[1 + coltype_index(coltype)]++; + offsets[1 + sortOrder(coltype)]++; }); - std::array n_udl{offsets[1 + coltype_index(ColType::UpperHalf)], - offsets[1 + coltype_index(ColType::Dense)], - offsets[1 + coltype_index(ColType::LowerHalf)]}; + std::array n_udl{offsets[1 + sortOrder(ColType::UpperHalf)], + offsets[1 + sortOrder(ColType::Dense)], + offsets[1 + sortOrder(ColType::LowerHalf)]}; std::partial_sum(offsets.cbegin(), offsets.cend(), offsets.begin()); - const SizeType k = to_SizeType(offsets[coltype_index(ColType::Deflated)]); + const SizeType k = to_SizeType(offsets[sortOrder(ColType::Deflated)]); // Create the permutation (sorted non-deflated | sorted deflated) -> initial // Note: @@ -418,7 +361,7 @@ auto stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* typ for (SizeType j = 0; j < n; ++j) { const ColType& coltype = types[to_sizet(j)]; if (coltype != ColType::Deflated) { - auto& index_for_coltype = offsets[coltype_index(coltype)]; + auto& index_for_coltype = offsets[sortOrder(coltype)]; index_sorted_coltype[index_for_coltype] = j; ++index_for_coltype; } @@ -428,58 +371,260 @@ auto stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* typ return std::tuple(k, std::move(n_udl)); } +// This function returns the number of global non-deflated eigenvectors, together with two permutations: +// - @p index_sorted_coltype (sort(upper)|sort(dense)|sort(lower)|sort(deflated)) -> initial +// - @p index_sorted (sort(non-deflated)|sort(deflated)) -> initial. +// +// Both permutations are represented using global indices, but: +// - @p index_sorted sorts "globally", i.e. considering all evecs across ranks +// - @p index_sorted_coltype sorts "locally", i.e. considering just evecs from the same rank +// +// In addition, even though all ranks have the full permutation, it is important to highlight that, +// thanks to how it is built (i.e. rank-independent permutations), @p index_sorted_coltype can be used as +// if it were distributed, since each "local" tile contains global indices that are valid +// just on the related rank. +// +// rank | 0 | 1 | 0 | +// initial | 0U| 1L| 2X| 3U| 4U| 5X| 6L| 7L| 8X| +// index_sorted | 3U| 0U| 4U| 6L| 7L| 1L| 2X| 5X| 8X| -> UUULLLXXX +// index_sorted_col_type | 0U| 6L| 7L| 3U| 4U| 5X| 1L| 2X| 8X| +// +// index_sorted_col_type can be used "locally": +// on rank0 | 0U| 6L| 7L| --| --| --| 1L| 2X| 8X| -> ULLLXX +// on rank1 | --| --| --| 3U| 4U| 5X| --| --| --| -> UUX +// +// where U: Upper, D: Dense, L: Lower, X: Deflated +// +// The permutations allow keeping the mapping between sorted eigenvalues and unsorted eigenvectors, +// which is useful since eigenvectors are more expensive to permute, so we can keep them in their +// initial order. +// +// @param n number of eigenvalues +// @param types array[n] column type of each eigenvector after deflation (initial order) +// @param evals array[n] of eigenvalues sorted as perm_sorted +// @param perm_sorted array[n] current -> initial (i.e. 
evals[i] -> types[perm_sorted[i]]) +// @param index_sorted array[n] global(sort(non-deflated)|sort(deflated)) -> initial +// @param index_sorted_coltype array[n] local(sort(upper)|sort(dense)|sort(lower)|sort(deflated)) -> initial +// +// @return k number of non-deflated eigenvectors +// @return k_lc number of local non-deflated eigenvectors template -auto stablePartitionIndexForDeflation(const SizeType i_begin, const SizeType i_end, - Matrix& c, - Matrix& evals, - Matrix& in, - Matrix& out) { +auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub, const SizeType n, + const ColType* types, const T* evals, SizeType* perm_sorted, + SizeType* index_sorted, SizeType* index_sorted_coltype, + SizeType* i4, SizeType* i6) { + // Note: + // (in) types + // column type of the initial indexing + // (in) perm_sorted + // initial <-- sorted by ascending eigenvalue + // (out) index_sorted + // initial <-- sorted by ascending eigenvalue in two groups (non-deflated | deflated) + // (out) index_sorted_coltype + // initial <-- sorted by ascending eigenvalue in four groups (upper | dense | lower | deflated) + const SizeType k = std::count_if(types, types + n, + [](const ColType coltype) { return ColType::Deflated != coltype; }); + + // Create the permutation (sorted non-deflated | sorted deflated) -> initial + // Note: + // Since during deflation the eigenvalues related to deflated eigenvectors might not be sorted anymore, + // this step also takes care of sorting the eigenvalues (actually just their related indices) by their ascending value. + SizeType i1 = 0; // index of non-deflated values in out + SizeType i2 = k; // index of deflated values + for (SizeType i = 0; i < n; ++i) { + const SizeType ii = perm_sorted[i]; + + // non-deflated are untouched, just squeeze them at the beginning as they appear + if (types[ii] != ColType::Deflated) { + index_sorted[i1] = ii; + ++i1; + } + // deflated ones may have been moved "out-of-order" by deflation... + // ... so insert each of them in the right place based on its eigenvalue + else { + const T a = evals[ii]; + + SizeType j = i2; + // shift to the right all greater values (shift just indices) + for (; j > k; --j) { + const T b = evals[index_sorted[j - 1]]; + if (a > b) { + break; + } + index_sorted[j] = index_sorted[j - 1]; + } + // and insert the current index in the empty place, such that eigenvalues are sorted. + index_sorted[j] = ii; + ++i2; + } + } + + // Create the permutation (sort(upper)|sort(dense)|sort(lower)|sort(deflated)) -> initial + // Note: + // index_sorted is used as "reference" in order to deal with deflated vectors in the right sorted order. + // In this way, non-deflated ones are also visited in sorted order, which is not a requirement, + // but it does not hurt either. 
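+ // Note: as a worked example (illustrative, based on the diagram in the documentation above): + // with ranks |0|1|0| and coltypes 0U 1L 2X 3U 4U 5X 6L 7L 8X, rank 0 owns columns {0,1,2,6,7,8} and + // rank 1 owns {3,4,5}. Counting non-deflated entries per rank and applying the partial sum below + // yields offsets = {0,1,1,4} on rank 0 (1 upper, 0 dense, 3 lower) and offsets = {0,2,2,2} on + // rank 1 (2 upper), so k_lc = 4 on rank 0 and k_lc = 2 on rank 1.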
+ const SizeType nperms = dist_sub.size().cols(); + + using offsets_t = std::array; + std::vector offsets(to_sizet(dist_sub.commGridSize().cols()), {0, 0, 0, 0}); + + for (SizeType j_el = 0; j_el < nperms; ++j_el) { + const SizeType jj_el = index_sorted[to_sizet(j_el)]; + const ColType coltype = types[to_sizet(jj_el)]; + + const comm::IndexT_MPI rank = dist_sub.rankGlobalElement(jj_el); + offsets_t& rank_offsets = offsets[to_sizet(rank)]; + + if (coltype != ColType::Deflated) + ++rank_offsets[1 + sortOrder(coltype)]; + } + std::for_each(offsets.begin(), offsets.end(), [](offsets_t& rank_offsets) { + std::partial_sum(rank_offsets.cbegin(), rank_offsets.cend(), rank_offsets.begin()); + }); + + const SizeType k_lc = + to_SizeType(offsets[to_sizet(dist_sub.rankIndex().col())][sortOrder(ColType::Deflated)]); + + for (SizeType j_el = 0; j_el < nperms; ++j_el) { + const SizeType jj_el = index_sorted[to_sizet(j_el)]; + const ColType coltype = types[to_sizet(jj_el)]; + + const comm::IndexT_MPI rank = dist_sub.rankGlobalElement(jj_el); + offsets_t& rank_offsets = offsets[to_sizet(rank)]; + + const SizeType jjj_el_lc = to_SizeType(rank_offsets[sortOrder(coltype)]++); + const SizeType jjj_el = + dist_sub.template globalElementFromLocalElementAndRank(rank, jjj_el_lc); + + index_sorted_coltype[to_sizet(jjj_el)] = jj_el; + } + + // TODO manage edge cases + std::array n_udl = [&]() { + SizeType first_dense; + for (first_dense = 0; first_dense < n; ++first_dense) { + const SizeType initial_el = index_sorted_coltype[to_sizet(first_dense)]; + const ColType coltype = types[to_sizet(initial_el)]; + if (ColType::UpperHalf != coltype) + break; + } + + SizeType last_dense; + for (last_dense = n - 1; last_dense >= 0; --last_dense) { + const SizeType initial_el = index_sorted_coltype[to_sizet(last_dense)]; + const ColType coltype = types[to_sizet(initial_el)]; + if (ColType::LowerHalf != coltype && ColType::Deflated != coltype) + break; + } + + SizeType last_lower; + for (last_lower = n - 1; last_lower >= 0; --last_lower) { + const SizeType initial_el = index_sorted_coltype[to_sizet(last_lower)]; + const ColType coltype = types[to_sizet(initial_el)]; + if (ColType::Deflated != coltype) + break; + } + + return std::array{first_dense, last_dense + 1, last_lower + 1}; + }(); + + // invertIndex i3->i2 + // i3 (in) : initial <--- deflated + // i2 (out) : initial ---> deflated + for (SizeType i = 0; i < n; ++i) + perm_sorted[index_sorted[i]] = i; + + // compose i5*i2 (!i3) -> i4 + // i5 (in) : initial <--- sort by coltype + // i2 (in) : deflated <--- initial + // i4 (out) : deflated <--- sort by col type + for (SizeType i = 0; i < n; ++i) + i4[i] = perm_sorted[index_sorted_coltype[i]]; + + // create i6 using i5 and i4 for deflated + for (SizeType j_el = 0, jnd_el = 0; j_el < nperms; ++j_el) { + const SizeType jj_el = index_sorted_coltype[to_sizet(j_el)]; + const ColType coltype = types[to_sizet(jj_el)]; + + if (ColType::Deflated != coltype) { + i6[j_el] = jnd_el; + ++jnd_el; + } + else { + i6[j_el] = i4[j_el]; + } + } + + return std::tuple(k, k_lc, n_udl); +} + +template +auto stablePartitionIndexForDeflation( + const SizeType i_begin, const SizeType i_end, Matrix& c, + Matrix& evals, Matrix& in, + Matrix& out, Matrix& out_by_coltype) { namespace ex = pika::execution::experimental; namespace di = dlaf::internal; const SizeType n = problemSize(i_begin, i_end, in.distribution()); - auto part_fn = [n](const auto& c_tiles_futs, const auto& evals_tiles_fut, const auto& in_tiles_futs, - const auto& out_tiles) { + auto part_fn 
= [n](const auto& c_tiles_futs, const auto& evals_tiles_futs, const auto& in_tiles_futs, + const auto& out_tiles, const auto& out_coltype_tiles) { const TileElementIndex zero_idx(0, 0); const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx); - const T* evals_ptr = evals_tiles_fut[0].get().ptr(zero_idx); + const T* evals_ptr = evals_tiles_futs[0].get().ptr(zero_idx); const SizeType* in_ptr = in_tiles_futs[0].get().ptr(zero_idx); SizeType* out_ptr = out_tiles[0].ptr(zero_idx); + SizeType* out_coltype_ptr = out_coltype_tiles[0].ptr(zero_idx); - return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr); + return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr, out_coltype_ptr); }; TileCollector tc{i_begin, i_end}; return ex::when_all(ex::when_all_vector(tc.read(c)), ex::when_all_vector(tc.read(evals)), - ex::when_all_vector(tc.read(in)), ex::when_all_vector(tc.readwrite(out))) | + ex::when_all_vector(tc.read(in)), ex::when_all_vector(tc.readwrite(out)), + ex::when_all_vector(tc.readwrite(out_by_coltype))) | di::transform(di::Policy(), std::move(part_fn)); } template auto stablePartitionIndexForDeflation( - const SizeType i_begin, const SizeType i_end, Matrix& c, - Matrix& evals, Matrix& in, - Matrix& out, Matrix& out_by_coltype) { + const matrix::Distribution& dist_evecs, const SizeType i_begin, const SizeType i_end, + Matrix& c, Matrix& evals, + Matrix& in, Matrix& out, + Matrix& out_by_coltype, Matrix& i4, + Matrix& i6) { namespace ex = pika::execution::experimental; namespace di = dlaf::internal; const SizeType n = problemSize(i_begin, i_end, in.distribution()); - auto part_fn = [n](const auto& c_tiles_futs, const auto& evals_tiles_futs, const auto& in_tiles_futs, - const auto& out_tiles, const auto& out_coltype_tiles) { + + const matrix::Distribution dist_evecs_sub( + dist_evecs, + matrix::SubDistributionSpec{dist_evecs.globalElementIndex({i_begin, i_begin}, {0, 0}), {n, n}}); + + auto part_fn = [n, dist_evecs_sub](const auto& c_tiles_futs, const auto& evals_tiles_futs, + const auto& in_tiles_futs, const auto& out_tiles, + const auto& out_coltype_tiles, const auto& i4_tiles, + const auto& i6_tiles) { const TileElementIndex zero_idx(0, 0); const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx); const T* evals_ptr = evals_tiles_futs[0].get().ptr(zero_idx); - const SizeType* in_ptr = in_tiles_futs[0].get().ptr(zero_idx); + SizeType* in_ptr = in_tiles_futs[0].ptr(zero_idx); SizeType* out_ptr = out_tiles[0].ptr(zero_idx); SizeType* out_coltype_ptr = out_coltype_tiles[0].ptr(zero_idx); + SizeType* i4_ptr = i4_tiles[0].ptr(zero_idx); + SizeType* i6_ptr = i6_tiles[0].ptr(zero_idx); - return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr, out_coltype_ptr); + return stablePartitionIndexForDeflationArrays(dist_evecs_sub, n, c_ptr, evals_ptr, in_ptr, out_ptr, + out_coltype_ptr, i4_ptr, i6_ptr); }; TileCollector tc{i_begin, i_end}; return ex::when_all(ex::when_all_vector(tc.read(c)), ex::when_all_vector(tc.read(evals)), - ex::when_all_vector(tc.read(in)), ex::when_all_vector(tc.readwrite(out)), - ex::when_all_vector(tc.readwrite(out_by_coltype))) | + ex::when_all_vector(tc.readwrite(in)), ex::when_all_vector(tc.readwrite(out)), + ex::when_all_vector(tc.readwrite(out_by_coltype)), + ex::when_all_vector(tc.readwrite(i4)), ex::when_all_vector(tc.readwrite(i6))) | di::transform(di::Policy(), std::move(part_fn)); } @@ -912,9 +1057,9 @@ void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const Size 
const SizeType k, std::array n_udl) mutable { using dlaf::matrix::internal::MatrixRef; - const SizeType n_uh = to_SizeType(n_udl[coltype_index(ColType::UpperHalf)]); - const SizeType n_de = to_SizeType(n_udl[coltype_index(ColType::Dense)]); - const SizeType n_lh = to_SizeType(n_udl[coltype_index(ColType::LowerHalf)]); + const SizeType n_uh = to_SizeType(n_udl[sortOrder(ColType::UpperHalf)]); + const SizeType n_de = to_SizeType(n_udl[sortOrder(ColType::Dense)]); + const SizeType n_lh = to_SizeType(n_udl[sortOrder(ColType::LowerHalf)]); const SizeType a = n_uh + n_de; const SizeType b = n_de + n_lh; @@ -992,13 +1137,13 @@ void assembleDistZVec(comm::CommunicatorGrid grid, common::Pipeline +template void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const SizeType i_begin, - const SizeType i_end, const LocalTileIndex ij_begin_lc, - const LocalTileSize sz_loc_tiles, KSender&& k, RhoSender&& rho, + const SizeType i_end, KSender&& k, KLcSender&& k_lc, RhoSender&& rho, Matrix& d, Matrix& z, - Matrix& evals, Matrix& i2, - Matrix& evecs) { + Matrix& evals, Matrix& i4, + Matrix& i6, + Matrix& i2, Matrix& evecs) { namespace ex = pika::execution::experimental; namespace di = dlaf::internal; namespace tt = pika::this_thread::experimental; @@ -1009,17 +1154,9 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S const SizeType n = problemSize(i_begin, i_end, dist); - const SizeType m_subm_el_lc = [=]() { - const auto i_loc_begin = ij_begin_lc.row(); - const auto i_loc_end = ij_begin_lc.row() + sz_loc_tiles.rows(); - return dist.localElementDistanceFromLocalTile(i_loc_begin, i_loc_end); - }(); - - const SizeType n_subm_el_lc = [=]() { - const auto i_loc_begin = ij_begin_lc.col(); - const auto i_loc_end = ij_begin_lc.col() + sz_loc_tiles.cols(); - return dist.localElementDistanceFromLocalTile(i_loc_begin, i_loc_end); - }(); + const matrix::Distribution dist_sub( + dist, {{i_begin * dist.blockSize().rows(), i_begin * dist.blockSize().cols()}, + dist.globalTileElementDistance({i_begin, i_begin}, {i_end, i_end})}); auto bcast_evals = [i_begin, i_end, dist](common::Pipeline& row_comm_chain, @@ -1054,386 +1191,374 @@ void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const S comm, req)); }; - // Note: at least two column of tiles per-worker, in the range [1, getTridiagRank1NWorkers()] - const std::size_t nthreads = [nrtiles = sz_loc_tiles.cols()]() { - const std::size_t min_workers = 1; - const std::size_t available_workers = getTridiagRank1NWorkers(); - const std::size_t ideal_workers = util::ceilDiv(to_sizet(nrtiles), to_sizet(2)); - return std::clamp(ideal_workers, min_workers, available_workers); - }(); - + const auto hp_scheduler = di::getBackendScheduler(pika::execution::thread_priority::high); ex::start_detached( - ex::when_all(ex::just(std::make_unique>(nthreads)), - std::forward(row_comm), std::forward(col_comm), - std::forward(k), std::forward(rho), + ex::when_all(std::forward(row_comm), std::forward(col_comm), + std::forward(k), std::forward(k_lc), std::forward(rho), ex::when_all_vector(tc.read(d)), ex::when_all_vector(tc.readwrite(z)), - ex::when_all_vector(tc.readwrite(evals)), ex::when_all_vector(tc.read(i2)), + ex::when_all_vector(tc.readwrite(evals)), ex::when_all_vector(tc.read(i4)), + ex::when_all_vector(tc.read(i6)), ex::when_all_vector(tc.read(i2)), ex::when_all_vector(tc.readwrite(evecs)), // additional workspaces ex::just(std::vector>()), ex::just(memory::MemoryView())) | - 
ex::transfer(di::getBackendScheduler(pika::execution::thread_priority::high)) | - ex::bulk(nthreads, [nthreads, n, n_subm_el_lc, m_subm_el_lc, i_begin, ij_begin_lc, sz_loc_tiles, - dist, bcast_evals, all_reduce_in_place]( - const std::size_t thread_idx, auto& barrier_ptr, auto& row_comm_wrapper, - auto& col_comm_wrapper, const auto& k, const auto& rho, - const auto& d_tiles_futs, auto& z_tiles, const auto& eval_tiles, - const auto& i2_tile_arr, const auto& evec_tiles, auto& ws_cols, - auto& ws_row) { - using dlaf::comm::internal::transformMPI; - - common::Pipeline row_comm_chain(row_comm_wrapper.get()); - const dlaf::comm::Communicator& col_comm = col_comm_wrapper.get(); - - const auto barrier_busy_wait = getTridiagRank1BarrierBusyWait(); - const std::size_t batch_size = - std::max(2, util::ceilDiv(to_sizet(sz_loc_tiles.cols()), nthreads)); - const SizeType begin = to_SizeType(thread_idx * batch_size); - const SizeType end = std::min(to_SizeType((thread_idx + 1) * batch_size), sz_loc_tiles.cols()); - - // STEP 0a: Fill ones for deflated Eigenvectors. (single-thread) - // Note: this step is completely independent from the rest, but it is small and it is going - // to be dropped soon. - // Note: use last threads that in principle should have less work to do - if (thread_idx == nthreads - 1) { - // just if there are deflated eigenvectors - if (k < n) { - const GlobalElementSize origin_el(i_begin * dist.blockSize().rows(), - i_begin * dist.blockSize().cols()); - const SizeType* i2_perm = i2_tile_arr[0].get().ptr(); - - for (SizeType i_subm_el = 0; i_subm_el < n; ++i_subm_el) { - const SizeType j_subm_el = i2_perm[i_subm_el]; - - // if it is a deflated vector - if (j_subm_el >= k) { - const GlobalElementIndex ij_el(origin_el.rows() + i_subm_el, - origin_el.cols() + j_subm_el); - const GlobalTileIndex ij = dist.globalTileIndex(ij_el); - - if (dist.rankIndex() == dist.rankGlobalTile(ij)) { - const LocalTileIndex ij_lc = dist.localTileIndex(ij); - const SizeType linear_subm_lc = - (ij_lc.row() - ij_begin_lc.row()) + - (ij_lc.col() - ij_begin_lc.col()) * sz_loc_tiles.rows(); - const TileElementIndex ij_el_tl = dist.tileElementIndex(ij_el); - evec_tiles[to_sizet(linear_subm_lc)](ij_el_tl) = T{1}; - } - } - } - } - } - - // STEP 0b: Initialize workspaces (single-thread) - if (thread_idx == 0) { - // Note: - // - nthreads are used for both LAED4 and weight calculation (one per worker thread) - // - last one is used for reducing weights from all workers - ws_cols.reserve(nthreads + 1); - - // Note: - // Considering that - // - LAED4 requires working on k elements - // - Weight computation requires working on m_subm_el_lc - // - // and they are needed at two steps that cannot happen in parallel, we opted for allocating - // the workspace with the highest requirement of memory, and reuse them for both steps. 
- const SizeType max_size = std::max(k, m_subm_el_lc); - for (std::size_t i = 0; i < nthreads; ++i) - ws_cols.emplace_back(max_size); - ws_cols.emplace_back(m_subm_el_lc); - - ws_row = memory::MemoryView(n_subm_el_lc); - std::fill_n(ws_row(), n_subm_el_lc, 0); - } - - // Note: we have to wait that LAED4 workspaces are ready to be used - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - const T* d_ptr = d_tiles_futs[0].get().ptr(); - const T* z_ptr = z_tiles[0].ptr(); - - // STEP 1: LAED4 (multi-thread) - { - common::internal::SingleThreadedBlasScope single; - - T* eval_ptr = eval_tiles[0].ptr(); - T* delta_ptr = ws_cols[thread_idx](); - - for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) { - const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc); - const SizeType j = dist.globalTileFromLocalTile(j_lc); - const SizeType n_subm_el = dist.globalTileElementDistance(i_begin, j); - - // Skip columns that are in the deflation zone - if (n_subm_el >= k) - break; - - const SizeType n_el_tl = std::min(dist.tileSize(j), k - n_subm_el); - for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) { - const SizeType j_el = n_subm_el + j_el_tl; - - // Solve the deflated rank-1 problem - T& eigenval = eval_ptr[to_sizet(j_el)]; - lapack::laed4(to_int(k), to_int(j_el), d_ptr, z_ptr, delta_ptr, rho, &eigenval); - - // copy the parts from delta stored on this rank - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType linear_subm_lc = i_subm_lc + to_SizeType(j_subm_lc) * sz_loc_tiles.rows(); - auto& evec_tile = evec_tiles[to_sizet(linear_subm_lc)]; - - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType m_subm_el = dist.globalTileElementDistance(i_begin, i); - - const SizeType i_subm = i - i_begin; - const auto& i2_perm = i2_tile_arr[to_sizet(i_subm)].get(); - - const SizeType m_el_tl = std::min(dist.tileSize(i), n - m_subm_el); - for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { - const SizeType jj_subm_el = i2_perm({i_el_tl, 0}); - if (jj_subm_el < k) - evec_tile({i_el_tl, j_el_tl}) = delta_ptr[jj_subm_el]; - } - } - } - } - } - - // Note: This barrier ensures that LAED4 finished, so from now on values are available - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - // STEP 2: Broadcast evals - - // Note: this ensures that evals broadcasting finishes before bulk releases resources - struct sync_wait_on_exit_t { - ex::unique_any_sender<> sender_; - - ~sync_wait_on_exit_t() { - if (sender_) - tt::sync_wait(std::move(sender_)); - } - } bcast_barrier; - - if (thread_idx == 0) - bcast_barrier.sender_ = bcast_evals(row_comm_chain, eval_tiles); - - // Note: laed4 handles k <= 2 cases differently - if (k <= 2) - return; - - // STEP 2 Compute weights (multi-thread) - auto& q = evec_tiles; - T* w = ws_cols[thread_idx](); - - // STEP 2a: copy diagonal from q -> w (or just initialize with 1) - if (thread_idx == 0) { - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType i_subm_el = dist.globalTileElementDistance(i_begin, i); - const SizeType m_subm_el_lc = - dist.localElementDistanceFromLocalTile(ij_begin_lc.row(), i_lc); - const auto& i2 = i2_tile_arr[to_sizet(i - i_begin)].get(); - - const SizeType m_el_tl = std::min(dist.tileSize(i), n - i_subm_el); - for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { - 
const SizeType i_subm_el_lc = m_subm_el_lc + i_el_tl; - - const SizeType jj_subm_el = i2({i_el_tl, 0}); - const SizeType n_el = dist.globalTileElementDistance(0, i_begin); - const SizeType jj_el = n_el + jj_subm_el; - const SizeType jj = dist.globalTileFromGlobalElement(jj_el); - - if (dist.rankGlobalTile(jj) == dist.rankIndex().col()) { - const SizeType jj_lc = dist.localTileFromGlobalTile(jj); - const SizeType jj_subm_lc = jj_lc - ij_begin_lc.col(); - const SizeType jj_el_tl = dist.tileElementFromGlobalElement(jj_el); - - const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * jj_subm_lc; - - w[i_subm_el_lc] = q[to_sizet(linear_subm_lc)]({i_el_tl, jj_el_tl}); - } - else { - w[i_subm_el_lc] = T(1); - } - } - } - } - else { // other workers - std::fill_n(w, m_subm_el_lc, T(1)); - } - - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - // STEP 2b: compute weights - for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) { - const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc); - const SizeType j = dist.globalTileFromLocalTile(j_lc); - const SizeType n_subm_el = dist.globalTileElementDistance(i_begin, j); - - // Skip columns that are in the deflation zone - if (n_subm_el >= k) - break; - - const SizeType n_el_tl = std::min(dist.tileSize(j), k - n_subm_el); - for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) { - const SizeType j_subm_el = n_subm_el + j_el_tl; - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType m_subm_el = dist.globalTileElementDistance(i_begin, i); - - auto& i2_perm = i2_tile_arr[to_sizet(i - i_begin)].get(); - - const SizeType m_el_tl = std::min(dist.tileSize(i), n - m_subm_el); - for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { - const SizeType ii_subm_el = i2_perm({i_el_tl, 0}); - - // deflated zone - if (ii_subm_el >= k) - continue; - - // diagonal - if (ii_subm_el == j_subm_el) - continue; - - const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc; - const SizeType i_subm_el_lc = i_subm_lc * dist.blockSize().rows() + i_el_tl; - - w[i_subm_el_lc] *= q[to_sizet(linear_subm_lc)]({i_el_tl, j_el_tl}) / - (d_ptr[to_sizet(ii_subm_el)] - d_ptr[to_sizet(j_subm_el)]); - } - } - } - } - - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - // STEP 2c: reduce, then finalize computation with sign and square root (single-thread) - if (thread_idx == 0) { - // local reduction from all bulk workers - for (SizeType i = 0; i < m_subm_el_lc; ++i) { - for (std::size_t tidx = 1; tidx < nthreads; ++tidx) { - const T* w_partial = ws_cols[tidx](); - w[i] *= w_partial[i]; - } - } - - tt::sync_wait(ex::when_all(row_comm_chain(), - ex::just(MPI_PROD, common::make_data(w, m_subm_el_lc))) | - transformMPI(all_reduce_in_place)); - - T* weights = ws_cols[nthreads](); - for (SizeType i_subm_el_lc = 0; i_subm_el_lc < m_subm_el_lc; ++i_subm_el_lc) { - const SizeType i_subm_lc = i_subm_el_lc / dist.blockSize().rows(); - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType i_subm = i - i_begin; - const SizeType i_subm_el = - i_subm * dist.blockSize().rows() + i_subm_el_lc % dist.blockSize().rows(); - - const auto* i2_perm = i2_tile_arr[0].get().ptr(); - const SizeType ii_subm_el = i2_perm[i_subm_el]; - weights[to_sizet(i_subm_el_lc)] = - std::copysign(std::sqrt(-w[i_subm_el_lc]), z_ptr[to_sizet(ii_subm_el)]); - 
} - } - - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - // STEP 3: Compute eigenvectors of the modified rank-1 modification (normalize) (multi-thread) - - // STEP 3a: Form evecs using weights vector and compute (local) sum of squares - { - common::internal::SingleThreadedBlasScope single; - - const T* w = ws_cols[nthreads](); - T* sum_squares = ws_row(); - - for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) { - const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc); - const SizeType j = dist.globalTileFromLocalTile(j_lc); - const SizeType n_subm_el = dist.globalTileElementDistance(i_begin, j); - - // Skip columns that are in the deflation zone - if (n_subm_el >= k) - break; - - const SizeType n_el_tl = std::min(dist.tileSize(j), k - n_subm_el); - for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) { - const SizeType j_subm_el_lc = j_subm_lc * dist.blockSize().cols() + j_el_tl; - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType m_subm_el = dist.globalTileElementDistance(i_begin, i); - - const SizeType i_subm = i - i_begin; - const auto& i2_perm = i2_tile_arr[to_sizet(i_subm)].get(); - - const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc; - const auto& q_tile = q[to_sizet(linear_subm_lc)]; - - const SizeType m_el_tl = std::min(dist.tileSize(i), n - m_subm_el); - for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { - const SizeType ii_subm_el = i2_perm({i_el_tl, 0}); - - const SizeType i_subm_el_lc = i_subm_lc * dist.blockSize().rows() + i_el_tl; - if (ii_subm_el >= k) - q_tile({i_el_tl, j_el_tl}) = 0; - else - q_tile({i_el_tl, j_el_tl}) = w[i_subm_el_lc] / q_tile({i_el_tl, j_el_tl}); - } - - sum_squares[j_subm_el_lc] += - blas::dot(m_el_tl, q_tile.ptr({0, j_el_tl}), 1, q_tile.ptr({0, j_el_tl}), 1); - } - } - } - } - - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - // STEP 3b: Reduce to get the sum of all squares on all ranks - if (thread_idx == 0) - tt::sync_wait(ex::just(std::cref(col_comm), MPI_SUM, - common::make_data(ws_row(), n_subm_el_lc)) | - transformMPI(all_reduce_in_place)); - - barrier_ptr->arrive_and_wait(barrier_busy_wait); - - // STEP 3c: Normalize (compute norm of each column and scale column vector) - { - common::internal::SingleThreadedBlasScope single; - - const T* sum_squares = ws_row(); - - for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) { - const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc); - const SizeType j = dist.globalTileFromLocalTile(j_lc); - const SizeType n_subm_el = dist.globalTileElementDistance(i_begin, j); - - // Skip columns that are in the deflation zone - if (n_subm_el >= k) - break; - - const SizeType n_el_tl = std::min(dist.tileSize(j), k - n_subm_el); - for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) { - const SizeType j_subm_el_lc = j_subm_lc * dist.blockSize().cols() + j_el_tl; - const T vec_norm = std::sqrt(sum_squares[j_subm_el_lc]); - - for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) { - const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc; - const SizeType i_lc = ij_begin_lc.row() + i_subm_lc; - const SizeType i = dist.globalTileFromLocalTile(i_lc); - const SizeType m_subm_el = dist.globalTileElementDistance(i_begin, i); - - const SizeType m_el_tl = std::min(dist.tileSize(i), n - m_subm_el); - blas::scal(m_el_tl, 1 / vec_norm, 
q[to_sizet(linear_subm_lc)].ptr({0, j_el_tl}), 1); - } - } - } - } + ex::transfer(hp_scheduler) | + ex::let_value([n, dist_sub, bcast_evals, all_reduce_in_place, hp_scheduler]( + auto& row_comm_wrapper, auto& col_comm_wrapper, const SizeType k, + const SizeType k_lc, const auto& rho, const auto& d_tiles_futs, auto& z_tiles, + const auto& eval_tiles, const auto& i4_tiles_arr, const auto& i6_tiles_arr, + const auto& i2_tiles_arr, const auto& evec_tiles, auto& ws_cols, auto& ws_row) { + using pika::execution::thread_priority; + + const std::size_t nthreads = [dist_sub, k_lc] { + const std::size_t workload = to_sizet(dist_sub.localSize().rows() * k_lc); + const std::size_t workload_unit = 2 * to_sizet(dist_sub.blockSize().linear_size()); + + const std::size_t min_workers = 1; + const std::size_t available_workers = getTridiagRank1NWorkers(); + + const std::size_t ideal_workers = util::ceilDiv(to_sizet(workload), workload_unit); + return std::clamp(ideal_workers, min_workers, available_workers); + }(); + + return ex::just(std::make_unique>(nthreads)) | ex::transfer(hp_scheduler) | + ex::bulk(nthreads, [&row_comm_wrapper, &col_comm_wrapper, k, k_lc, &rho, &d_tiles_futs, + &z_tiles, &eval_tiles, &i4_tiles_arr, &i6_tiles_arr, &i2_tiles_arr, + &evec_tiles, &ws_cols, &ws_row, nthreads, n, dist_sub, bcast_evals, + all_reduce_in_place](const std::size_t thread_idx, + auto& barrier_ptr) { + using dlaf::comm::internal::transformMPI; + + common::Pipeline row_comm_chain(row_comm_wrapper.get()); + const dlaf::comm::Communicator& col_comm = col_comm_wrapper.get(); + + const SizeType m_lc = dist_sub.localNrTiles().rows(); + const SizeType m_el_lc = dist_sub.localSize().rows(); + const SizeType n_el_lc = dist_sub.localSize().cols(); + + const auto barrier_busy_wait = getTridiagRank1BarrierBusyWait(); + + const SizeType* i4 = i4_tiles_arr[0].get().ptr(); + const SizeType* i2 = i2_tiles_arr[0].get().ptr(); + const SizeType* i6 = i6_tiles_arr[0].get().ptr(); + + // STEP 0a: Permute eigenvalues for deflated eigenvectors (single-thread) + // Note: use last threads that in principle should have less work to do + if (k < n && thread_idx == nthreads - 1) { + const T* eval_initial_ptr = d_tiles_futs[0].get().ptr(); + T* eval_ptr = eval_tiles[0].ptr(); + + for (SizeType jeg_el_lc = k_lc; jeg_el_lc < n_el_lc; ++jeg_el_lc) { + const SizeType jeg_el = + dist_sub.globalElementFromLocalElement(jeg_el_lc); + eval_ptr[jeg_el] = eval_initial_ptr[i6[jeg_el]]; + } + } + + const std::size_t batch_size = util::ceilDiv(to_sizet(k_lc), nthreads); + const SizeType begin = to_SizeType(thread_idx * batch_size); + const SizeType end = std::min(to_SizeType(thread_idx * batch_size + batch_size), k_lc); + + // // at least two tiles (in columns) + // const std::size_t batch_size = + // std::max(2 * to_sizet(dist_sub.blockSize().cols()), + // util::ceilDiv(to_sizet(k_lc), nthreads)); + // const SizeType begin = to_SizeType(thread_idx * batch_size); + // const SizeType end = std::min(to_SizeType((thread_idx + 1) * batch_size), k_lc); + + // STEP 0b: Initialize workspaces (single-thread) + if (thread_idx == 0) { + // Note: + // - nthreads are used for both LAED4 and weight calculation (one per worker thread) + // - last one is used for reducing weights from all workers + ws_cols.reserve(nthreads + 1); + + // Note: + // Considering that + // - LAED4 requires working on k elements + // - Weight computation requires working on m_el_lc + // + // and they are needed at two steps that cannot happen in parallel, we opted for allocating + // the 
workspaces with the highest memory requirement, reusing them for both steps. + const SizeType max_size = std::max(k, m_el_lc); + for (std::size_t i = 0; i < nthreads; ++i) + ws_cols.emplace_back(max_size); + ws_cols.emplace_back(m_el_lc); + + ws_row = memory::MemoryView(n_el_lc); + std::fill_n(ws_row(), n_el_lc, 0); + } + + // Note: we have to wait until the LAED4 workspaces are ready to be used + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + const T* d_ptr = d_tiles_futs[0].get().ptr(); + const T* z_ptr = z_tiles[0].ptr(); + + // STEP 1: LAED4 (multi-thread) + { + common::internal::SingleThreadedBlasScope single; // TODO needed also for laed? + + T* eval_ptr = eval_tiles[0].ptr(); + T* delta_ptr = ws_cols[thread_idx](); + + for (SizeType jeg_el_lc = begin; jeg_el_lc < end; ++jeg_el_lc) { + const SizeType jeg_el = + dist_sub.globalElementFromLocalElement(jeg_el_lc); + const SizeType jeg_lc = dist_sub.localTileFromLocalElement(jeg_el_lc); + + // Solve the deflated rank-1 problem + // Note: + // it solves considering the order of the original fully sorted non-deflated (i3), + // but it stores the result in extended global order (as eigenvectors are stored in E1) + const SizeType js_el = i6[jeg_el]; + T& eigenval = eval_ptr[to_sizet(jeg_el)]; // eval is in compact rank layout + lapack::laed4(to_signed(k), to_signed(js_el), d_ptr, z_ptr, + delta_ptr, rho, &eigenval); + + // Now the laed4 result has to be copied to the right spot + const SizeType jeg_el_tl = + dist_sub.tileElementFromGlobalElement(jeg_el); + + for (SizeType i_lc = 0; i_lc < m_lc; ++i_lc) { + const SizeType i = dist_sub.globalTileFromLocalTile(i_lc); + const SizeType m_el_tl = dist_sub.tileSize(i); + const SizeType linear_lc = dist_sub.localTileLinearIndex({i_lc, jeg_lc}); + const auto& evec = evec_tiles[to_sizet(linear_lc)]; + for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { + const SizeType i_el = + dist_sub.globalElementFromLocalTileAndTileElement(i_lc, + i_el_tl); + DLAF_ASSERT_HEAVY(i_el < n, i_el, n); + const SizeType is_el = i4[i_el]; + + // just non-deflated ones, because deflated ones have already been set to 0 + if (is_el < k) + evec({i_el_tl, jeg_el_tl}) = delta_ptr[is_el]; + } + } + } + } + // Note: This barrier ensures that LAED4 finished, so from now on values are available + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + // STEP 2: Broadcast evals + + // Note: this ensures that evals broadcasting finishes before bulk releases resources + struct sync_wait_on_exit_t { + ex::unique_any_sender<> sender_; + + ~sync_wait_on_exit_t() { + if (sender_) + tt::sync_wait(std::move(sender_)); + } + } bcast_barrier; + + if (thread_idx == 0) + bcast_barrier.sender_ = bcast_evals(row_comm_chain, eval_tiles); + + // Note: laed4 handles k <= 2 cases differently + if (k <= 2) + return; + + // STEP 2: Compute weights (multi-thread) + auto& q = evec_tiles; + T* w = ws_cols[thread_idx](); + + // STEP 2a: copy diagonal from q -> w (or just initialize with 1) + // Note: + // Loop over the compact rank layout (= expanded global) up to k_el_lc; the index k_el_lc has to + // be converted to the global element k_el, so it can be used with the permutations. + // During the switch from the col axis to the row axis we must keep the matching between eigenvectors. + if (thread_idx == 0) { + for (SizeType ieg_el_lc = 0; ieg_el_lc < m_el_lc; ++ieg_el_lc) { + const SizeType ieg_el = + dist_sub.globalElementFromLocalElement(ieg_el_lc); + const SizeType is_el = i4[ieg_el]; + + if (is_el >= k) { + w[ieg_el_lc] = T{0}; + continue; + } + + const SizeType js_el = is_el; + const SizeType 
jeg_el = i2[js_el]; + + const GlobalElementIndex ijeg_subm_el(ieg_el, jeg_el); + + if (dist_sub.rankIndex().col() == dist_sub.rankGlobalElement(jeg_el)) { + const SizeType linear_subm_lc = dist_sub.localTileLinearIndex( + {dist_sub.localTileFromLocalElement(ieg_el_lc), + dist_sub.localTileFromGlobalElement(jeg_el)}); + const TileElementIndex ij_tl = dist_sub.tileElementIndex(ijeg_subm_el); + w[ieg_el_lc] = q[to_sizet(linear_subm_lc)](ij_tl); + } + else { + w[ieg_el_lc] = T{1}; + } + } + } + else { // other workers + std::fill_n(w, m_el_lc, T(1)); + } + + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + // STEP 2b: compute weights + { + for (SizeType jeg_el_lc = begin; jeg_el_lc < end; ++jeg_el_lc) { + const SizeType jeg_el = + dist_sub.globalElementFromLocalElement(jeg_el_lc); + const SizeType jeg_lc = dist_sub.localTileFromGlobalElement(jeg_el); + const SizeType js_el = i6[jeg_el]; + const T delta_j = d_ptr[to_sizet(js_el)]; + + const SizeType jeg_el_tl = + dist_sub.tileElementFromLocalElement(jeg_el_lc); + + for (SizeType i_lc = 0; i_lc < m_lc; ++i_lc) { + const SizeType i = dist_sub.globalTileFromLocalTile(i_lc); + const SizeType m_el_tl = dist_sub.tileSize(i); + const SizeType linear_lc = dist_sub.localTileLinearIndex({i_lc, jeg_lc}); + const auto& q_tile = q[to_sizet(linear_lc)]; + + for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { + const SizeType i_el = + dist_sub.globalElementFromGlobalTileAndTileElement(i, i_el_tl); + DLAF_ASSERT_HEAVY(i_el < n, i_el, n); + const SizeType is_el = i4[i_el]; + + // skip if deflated + if (is_el >= k) + continue; + + // skip if originally it was on the diagonal + if (is_el == js_el) + continue; + + const SizeType ieg_el_lc = + dist_sub.localElementFromLocalTileAndTileElement(i_lc, i_el_tl); + const TileElementIndex ij_tl(i_el_tl, jeg_el_tl); + + w[ieg_el_lc] *= q_tile(ij_tl) / (d_ptr[to_sizet(is_el)] - delta_j); + } + } + } + } + + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + // STEP 2c: reduce, then finalize computation with sign and square root (single-thread) + if (thread_idx == 0) { + // local reduction from all bulk workers + for (SizeType i = 0; i < m_el_lc; ++i) { + for (std::size_t tidx = 1; tidx < nthreads; ++tidx) { + const T* w_partial = ws_cols[tidx](); + w[i] *= w_partial[i]; + } + } + + tt::sync_wait(ex::when_all(row_comm_chain(), + ex::just(MPI_PROD, common::make_data(w, m_el_lc))) | + transformMPI(all_reduce_in_place)); + + // TODO check all weights < 0 (!= 0 otherwise q elements are set to zero and then norm = 0 => nan) + + T* weights = ws_cols[nthreads](); + // TODO this can be limited to k_lc + for (SizeType i_el_lc = 0; i_el_lc < m_el_lc; ++i_el_lc) { + const SizeType i_el = dist_sub.globalElementFromLocalElement(i_el_lc); + const SizeType ii_el = i4[i_el]; + weights[to_sizet(i_el_lc)] = + std::copysign(std::sqrt(-w[i_el_lc]), z_ptr[to_sizet(ii_el)]); + } + } + + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + // STEP 3: Compute eigenvectors of the modified rank-1 modification (normalize) (multi-thread) + + // STEP 3a: Form evecs using weights vector and compute (local) sum of squares + { + common::internal::SingleThreadedBlasScope single; + + const T* w = ws_cols[nthreads](); + T* sum_squares = ws_row(); + + for (SizeType jeg_el_lc = begin; jeg_el_lc < end; ++jeg_el_lc) { + const SizeType jeg_lc = dist_sub.localTileFromLocalElement(jeg_el_lc); + const SizeType jeg_el_tl = + dist_sub.tileElementFromLocalElement(jeg_el_lc); + + for (SizeType i_lc = 0; i_lc < dist_sub.localNrTiles().rows(); ++i_lc) { + 
const SizeType i = dist_sub.globalTileFromLocalTile(i_lc); + const SizeType m_el_tl = dist_sub.tileSize(i); + const SizeType linear_lc = dist_sub.localTileLinearIndex({i_lc, jeg_lc}); + const auto& q_tile = q[to_sizet(linear_lc)]; + + for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) { + const SizeType i_el = + dist_sub.globalElementFromGlobalTileAndTileElement(i, i_el_tl); + + DLAF_ASSERT_HEAVY(i_el < n, i_el, n); + const SizeType is_el = i4[i_el]; + + // it is a deflated row, skip it (it should be already 0) + if (is_el >= k) + continue; + + const SizeType ieg_el_lc = + dist_sub.localElementFromLocalTileAndTileElement(i_lc, i_el_tl); + const TileElementIndex ijeg_el_tl(i_el_tl, jeg_el_tl); + + q_tile(ijeg_el_tl) = w[ieg_el_lc] / q_tile(ijeg_el_tl); + } + + const T* partial_evec = q_tile.ptr({0, jeg_el_tl}); + sum_squares[jeg_el_lc] += blas::dot(m_el_tl, partial_evec, 1, partial_evec, 1); + } + } + } + + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + // STEP 3b: Reduce to get the sum of all squares on all ranks + if (thread_idx == 0) { + // TODO it can be limited to k_lc + tt::sync_wait(ex::just(std::cref(col_comm), MPI_SUM, + common::make_data(ws_row(), n_el_lc)) | + transformMPI(all_reduce_in_place)); + } + + barrier_ptr->arrive_and_wait(barrier_busy_wait); + + // STEP 3c: Normalize (compute norm of each column and scale column vector) + { + common::internal::SingleThreadedBlasScope single; + + const T* sum_squares = ws_row(); + + for (SizeType jeg_el_lc = begin; jeg_el_lc < end; ++jeg_el_lc) { + const SizeType jeg_lc = dist_sub.localTileFromLocalElement(jeg_el_lc); + const SizeType jeg_el_tl = + dist_sub.tileElementFromLocalElement(jeg_el_lc); + + const T vec_norm = std::sqrt(sum_squares[jeg_el_lc]); + + for (SizeType i_lc = 0; i_lc < m_lc; ++i_lc) { + const LocalTileIndex ijeg_lc(i_lc, jeg_lc); + const SizeType ijeg_linear = dist_sub.localTileLinearIndex(ijeg_lc); + + T* partial_evec = q[to_sizet(ijeg_linear)].ptr({0, jeg_el_tl}); + + const SizeType i = dist_sub.globalTileFromLocalTile(i_lc); + const SizeType m_el_tl = dist_sub.tileSize(i); + blas::scal(m_el_tl, 1 / vec_norm, partial_evec, 1); + } + } + } + }); })); } @@ -1448,17 +1573,24 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, DistWorkSpaceHostMirror& ws_hm) { namespace ex = pika::execution::experimental; - const matrix::Distribution& dist_evecs = ws.e0.distribution(); + const matrix::Distribution& dist = ws.e0.distribution(); + + const GlobalElementIndex sub_offset{i_begin * dist.blockSize().rows(), + i_begin * dist.blockSize().cols()}; + const matrix::Distribution dist_sub( + dist, {sub_offset, dist.globalTileElementDistance({i_begin, i_begin}, {i_end, i_end})}); // Calculate the size of the upper subproblem - const SizeType n1 = dist_evecs.globalTileElementDistance(i_begin, i_split); + const SizeType n = dist.globalTileElementDistance(i_begin, i_end); + const SizeType n_upper = dist.globalTileElementDistance(i_begin, i_split); + const SizeType n_lower = dist.globalTileElementDistance(i_split, i_end); // The local size of the subproblem const GlobalTileIndex idx_gl_begin(i_begin, i_begin); - const LocalTileIndex idx_loc_begin{dist_evecs.nextLocalTileFromGlobalTile(i_begin), - dist_evecs.nextLocalTileFromGlobalTile(i_begin)}; - const LocalTileIndex idx_loc_end{dist_evecs.nextLocalTileFromGlobalTile(i_end), - dist_evecs.nextLocalTileFromGlobalTile(i_end)}; + const LocalTileIndex idx_loc_begin{dist.nextLocalTileFromGlobalTile(i_begin), + dist.nextLocalTileFromGlobalTile(i_begin)}; + const LocalTileIndex 
idx_loc_end{dist.nextLocalTileFromGlobalTile(i_end), + dist.nextLocalTileFromGlobalTile(i_end)}; const LocalTileSize sz_loc_tiles = idx_loc_end - idx_loc_begin; const LocalTileIndex idx_begin_tiles_vec(i_begin, 0); const LocalTileSize sz_tiles_vec(i_end - i_begin, 1); @@ -1476,6 +1608,17 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // Initialize the column types vector `c` initColTypes(i_begin, i_split, i_end, ws_h.c); + // Initialize i1 as identity just for single tile sub-problems + if (i_split == i_begin + 1) { + initIndex(i_begin, i_split, ws_h.i1); + } + if (i_split + 1 == i_end) { + initIndex(i_split, i_end, ws_h.i1); + } + + // Update indices of second sub-problem + addIndex(i_split, i_end, n_upper, ws_h.i1); + // Step #1 // // i1 (out) : initial <--- initial (identity map) @@ -1484,14 +1627,7 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // - deflate `d`, `z` and `c` // - apply Givens rotations to `Q` - `evecs` // - if (i_split == i_begin + 1) { - initIndex(i_begin, i_split, ws_h.i1); - } - if (i_split + 1 == i_end) { - initIndex(i_split, i_end, ws_h.i1); - } - addIndex(i_split, i_end, n1, ws_h.i1); - sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2); + sortIndex(i_begin, i_end, ex::just(n_upper), ws_h.d0, ws_h.i1, ws_hm.i2); auto rots = applyDeflation(i_begin, i_end, scaled_rho, std::move(tol), ws_hm.i2, ws_h.d0, ws_hm.z0, ws_h.c); @@ -1502,8 +1638,6 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, const comm::IndexT_MPI tag = to_int(i_split); applyGivensRotationsToMatrixColumns(grid.rowCommunicator(), tag, i_begin, i_end, std::move(rots), ws.e0); - // Placeholder for rearranging the eigenvectors: (local permutation) - copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws.e1); // Step #2 // @@ -1514,40 +1648,100 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`. // - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented) // - auto k = - stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_h.d0, ws_hm.i2, ws_h.i3) | ex::split(); + auto [k_unique, k_lc_unique, n_udl] = + ex::split_tuple(stablePartitionIndexForDeflation(dist, i_begin, i_end, ws_h.c, ws_h.d0, ws_hm.i2, + ws_h.i3, ws_hm.i5, ws_h.i4, ws_hm.i6)); + // from now on i2 is the inverse of i6 + invertIndex(i_begin, i_end, ws_hm.i6, ws_hm.i2); + + auto k = ex::split(std::move(k_unique)); + auto k_lc = ex::split(std::move(k_lc_unique)); + + // Reorder Eigenvectors + if constexpr (Backend::MC == B) { + copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i5, ws.i5); + dlaf::permutations::internal::permuteJustLocal(i_begin, i_end, ws.i5, ws.e0, + ws.e1); + } + else { + // TODO remove this branch. 
It exists just because GPU permuteJustLocal is not implemented yet + copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws_hm.e0); + dlaf::permutations::internal::permuteJustLocal( + i_begin, i_end, ws_hm.i5, ws_hm.e0, ws_hm.e2); + copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e1); + } + + // Reorder Eigenvalues applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1); applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1); - copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.d1, ws_h.d0); - - // - // i3 (in) : initial <--- deflated - // i2 (out) : initial ---> deflated - // - invertIndex(i_begin, i_end, ws_h.i3, ws_hm.i2); // Note: here ws_hm.z0 is used as a contiguous buffer for the laed4 call matrix::util::set0(pika::execution::thread_priority::normal, idx_loc_begin, sz_loc_tiles, ws_hm.e2); - solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, idx_loc_begin, sz_loc_tiles, - k, std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_hm.i2, ws_hm.e2); + solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, k, k_lc, + std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_h.i4, ws_hm.i6, ws_hm.i2, + ws_hm.e2); + copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2); // Step #3: Eigenvectors of the tridiagonal system: Q * U // // The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as // prepared for the deflated system. - copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2); - dlaf::multiplication::internal::generalSubMatrix(grid, row_task_chain, col_task_chain, - i_begin, i_end, T(1), ws.e1, ws.e2, T(0), - ws.e0); + ex::start_detached( + ex::when_all(std::move(k_lc), std::move(n_udl), row_task_chain(), col_task_chain()) | + ex::transfer(dlaf::internal::getBackendScheduler()) | + ex::then([dist_sub, sub_offset, n, n_upper, n_lower, e0 = ws.e0.subPipeline(), + e1 = ws.e1.subPipelineConst(), e2 = ws.e2.subPipelineConst()]( + const SizeType k_lc, const std::array& n_udl, auto&& row_comm_wrapper, + auto&& col_comm_wrapper) mutable { + using dlaf::matrix::internal::MatrixRef; + + common::Pipeline sub_comm_row(row_comm_wrapper.get()); + common::Pipeline sub_comm_col(col_comm_wrapper.get()); + + const auto [a, b, c] = n_udl; + + using GEMM = dlaf::multiplication::internal::GeneralSub; + { + MatrixRef e1_sub(e1, {sub_offset, {n_upper, b}}); + MatrixRef e2_sub(e2, {sub_offset, {b, c}}); + MatrixRef e0_sub(e0, {sub_offset, {n_upper, c}}); + + GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub); + } + + { + MatrixRef e1_sub(e1, {{sub_offset.row() + n_upper, sub_offset.col() + a}, + {n_lower, c - a}}); + MatrixRef e2_sub(e2, {{sub_offset.row() + a, sub_offset.col()}, {c - a, c}}); + MatrixRef e0_sub(e0, {{sub_offset.row() + n_upper, sub_offset.col()}, {n_lower, c}}); + + GEMM::callNN(sub_comm_row, sub_comm_col, T(1), e1_sub, e2_sub, T(0), e0_sub); + } + + // copy deflated from e1 to e0 + if (k_lc < dist_sub.localSize().cols()) { + const SizeType k = dist_sub.globalElementFromLocalElement(k_lc); + const matrix::internal::SubMatrixSpec deflated_submat{{sub_offset.row(), sub_offset.col() + k}, + {n, n - k}}; + MatrixRef sub_e0(e0, deflated_submat); + MatrixRef sub_e1(e1, deflated_submat); + + copy(sub_e1, sub_e0); + } + + namespace tt = pika::this_thread::experimental; + tt::sync_wait(sub_comm_row()); + tt::sync_wait(sub_comm_col()); + })); // Step #4: Final permutation to sort eigenvalues and eigenvectors // // i1 (in) : deflated <--- deflated (identity map) // i2 (out) : deflated <--- post_sorted // - 
initIndex(i_begin, i_end, ws_h.i1); - sortIndex(i_begin, i_end, std::move(k), ws_h.d0, ws_h.i1, ws_hm.i2); - copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i2, ws_h.i1); + + // TODO merge sort + sortIndex(i_begin, i_end, std::move(k), ws_h.d0, ws_hm.i2, ws_h.i1); } } diff --git a/include/dlaf/matrix/copy.h b/include/dlaf/matrix/copy.h index b2df806b91..426fdcdf8d 100644 --- a/include/dlaf/matrix/copy.h +++ b/include/dlaf/matrix/copy.h @@ -86,10 +86,10 @@ void copy(MatrixRef& src, MatrixRef& dst) { // TODO assert same size src == dest const dlaf::internal::Policy> policy; - for (SizeType j = 0; j < src.nrTiles().cols(); ++j) { - for (SizeType i = 0; i < src.nrTiles().rows(); ++i) { - ex::start_detached(ex::when_all(src.read(GlobalTileIndex{i, j}), - dst.readwrite(GlobalTileIndex{i, j})) | + for (SizeType j = 0; j < src.distribution().localNrTiles().cols(); ++j) { + for (SizeType i = 0; i < src.distribution().localNrTiles().rows(); ++i) { + ex::start_detached(ex::when_all(src.read(LocalTileIndex{i, j}), + dst.readwrite(LocalTileIndex{i, j})) | matrix::copy(policy)); } } diff --git a/include/dlaf/multiplication/general/api.h b/include/dlaf/multiplication/general/api.h index 5d3ebe743e..cd02f5982e 100644 --- a/include/dlaf/multiplication/general/api.h +++ b/include/dlaf/multiplication/general/api.h @@ -36,8 +36,10 @@ struct GeneralSub { Matrix& mat_b, const T beta, Matrix& mat_c); // Note: internal helper - static void callNN(const blas::Op opA, const blas::Op opB, const T alpha, MatrixRef& mat_a, - MatrixRef& mat_b, const T beta, MatrixRef& mat_c); + static void callNN(common::Pipeline& row_task_chain, + common::Pipeline& col_task_chain, const T alpha, + MatrixRef& mat_a, MatrixRef& mat_b, const T beta, + MatrixRef& mat_c); }; // ETI diff --git a/include/dlaf/multiplication/general/impl.h b/include/dlaf/multiplication/general/impl.h index b03b512ff8..1e433ad2c0 100644 --- a/include/dlaf/multiplication/general/impl.h +++ b/include/dlaf/multiplication/general/impl.h @@ -10,6 +10,8 @@ #pragma once +#include + #include #include #include @@ -24,6 +26,7 @@ #include #include #include +#include namespace dlaf::multiplication { namespace internal { @@ -188,5 +191,86 @@ void GeneralSub::callNN(common::Pipeline& row_task_ panelB.reset(); } } + +template +void GeneralSub::callNN(common::Pipeline& row_task_chain, + common::Pipeline& col_task_chain, const T alpha, + MatrixRef& mat_a, MatrixRef& mat_b, + const T beta, MatrixRef& mat_c) { + namespace ex = pika::execution::experimental; + + // TODO assert equal distribution? 
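+ // Note: this follows a SUMMA-like scheme restricted to the NoTrans/NoTrans case: for each global + // index k, the ranks owning the k-th column of A and the k-th row of B fill a column and a row + // panel, the panels are broadcast along the grid rows and columns respectively, and each rank + // updates its local tiles of C with a GEMM (beta is applied just at the first step, k == 0).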
diff --git a/include/dlaf/matrix/copy.h b/include/dlaf/matrix/copy.h
index b2df806b91..426fdcdf8d 100644
--- a/include/dlaf/matrix/copy.h
+++ b/include/dlaf/matrix/copy.h
@@ -86,10 +86,10 @@ void copy(MatrixRef<const T, Source>& src, MatrixRef<T, Destination>& dst) {
   // TODO assert same size src == dest
   const dlaf::internal::Policy<matrix::internal::CopyBackend_v<Source, Destination>> policy;
 
-  for (SizeType j = 0; j < src.nrTiles().cols(); ++j) {
-    for (SizeType i = 0; i < src.nrTiles().rows(); ++i) {
-      ex::start_detached(ex::when_all(src.read(GlobalTileIndex{i, j}),
-                                      dst.readwrite(GlobalTileIndex{i, j})) |
+  for (SizeType j = 0; j < src.distribution().localNrTiles().cols(); ++j) {
+    for (SizeType i = 0; i < src.distribution().localNrTiles().rows(); ++i) {
+      ex::start_detached(ex::when_all(src.read(LocalTileIndex{i, j}),
+                                      dst.readwrite(LocalTileIndex{i, j})) |
                          matrix::copy(policy));
     }
   }
diff --git a/include/dlaf/multiplication/general/api.h b/include/dlaf/multiplication/general/api.h
index 5d3ebe743e..cd02f5982e 100644
--- a/include/dlaf/multiplication/general/api.h
+++ b/include/dlaf/multiplication/general/api.h
@@ -36,8 +36,10 @@ struct GeneralSub {
                      Matrix<const T, D>& mat_b, const T beta, Matrix<T, D>& mat_c);
 
   // Note: internal helper
-  static void callNN(const blas::Op opA, const blas::Op opB, const T alpha, MatrixRef<const T, D>& mat_a,
-                     MatrixRef<const T, D>& mat_b, const T beta, MatrixRef<T, D>& mat_c);
+  static void callNN(common::Pipeline<comm::Communicator>& row_task_chain,
+                     common::Pipeline<comm::Communicator>& col_task_chain, const T alpha,
+                     MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b, const T beta,
+                     MatrixRef<T, D>& mat_c);
 };
 
 // ETI
diff --git a/include/dlaf/multiplication/general/impl.h b/include/dlaf/multiplication/general/impl.h
index b03b512ff8..1e433ad2c0 100644
--- a/include/dlaf/multiplication/general/impl.h
+++ b/include/dlaf/multiplication/general/impl.h
@@ -10,6 +10,8 @@
 
 #pragma once
 
+#include <utility>
+
 #include <pika/execution.hpp>
 #include <dlaf/blas/tile.h>
 #include <dlaf/common/round_robin.h>
@@ -24,6 +26,7 @@
 #include <dlaf/matrix/distribution.h>
 #include <dlaf/matrix/panel.h>
 #include <dlaf/multiplication/general/api.h>
+#include <dlaf/matrix/matrix_ref.h>
 
 namespace dlaf::multiplication {
 namespace internal {
@@ -188,5 +191,86 @@ void GeneralSub<B, D, T>::callNN(common::Pipeline<comm::Communicator>& row_task_
     panelB.reset();
   }
 }
+
+template <Backend B, Device D, class T>
+void GeneralSub<B, D, T>::callNN(common::Pipeline<comm::Communicator>& row_task_chain,
+                                 common::Pipeline<comm::Communicator>& col_task_chain, const T alpha,
+                                 MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b,
+                                 const T beta, MatrixRef<T, D>& mat_c) {
+  namespace ex = pika::execution::experimental;
+
+  // TODO assert equal distribution?
+  DLAF_ASSERT(dlaf::matrix::multipliable_sizes(mat_a.size(), mat_b.size(), mat_c.size(),
+                                               blas::Op::NoTrans, blas::Op::NoTrans),
+              mat_a.size(), mat_b.size(), mat_c.size());
+
+  if (mat_c.size().isEmpty())
+    return;
+
+  const matrix::Distribution& dist_a = mat_a.distribution();
+  const matrix::Distribution& dist_b = mat_b.distribution();
+  const matrix::Distribution& dist_c = mat_c.distribution();
+  const auto rank = dist_c.rankIndex();
+
+  constexpr std::size_t n_workspaces = 2;
+  common::RoundRobin<matrix::Panel<Coord::Col, T, D>> panelsA(n_workspaces, dist_c);
+  common::RoundRobin<matrix::Panel<Coord::Row, T, D>> panelsB(n_workspaces, dist_c);
+
+  DLAF_ASSERT_HEAVY(mat_a.nrTiles().cols() == mat_b.nrTiles().rows(), mat_a.nrTiles(), mat_b.nrTiles());
+
+  // This loops over the global indices for k, because every rank has to participate in communication
+  for (SizeType k = 0; k < mat_a.nrTiles().cols(); ++k) {
+    auto& panelA = panelsA.nextResource();
+    auto& panelB = panelsB.nextResource();
+
+    if (k == 0 || k == mat_a.nrTiles().cols() - 1) {
+      DLAF_ASSERT_HEAVY(dist_a.tileSize<Coord::Col>(k) == dist_b.tileSize<Coord::Row>(k),
+                        dist_a.tileSize<Coord::Col>(k), dist_b.tileSize<Coord::Row>(k));
+      const SizeType kSize = dist_a.tileSize<Coord::Col>(k);
+      panelA.setWidth(kSize);
+      panelB.setHeight(kSize);
+    }
+
+    // Setup the column workspace for the root ranks, i.e. the ones in the current col
+    const auto rank_k_col = dist_a.rankGlobalTile<Coord::Col>(k);
+    if (rank_k_col == rank.col()) {
+      const auto k_local = dist_a.template localTileFromGlobalTile<Coord::Col>(k);
+      for (SizeType i = 0; i < dist_c.localNrTiles().rows(); ++i) {
+        const LocalTileIndex ik(i, k_local);
+        panelA.setTile(ik, mat_a.read(ik));
+      }
+    }
+    // Setup the row workspace for the root ranks, i.e. the ones in the current row
+    const auto rank_k_row = dist_b.rankGlobalTile<Coord::Row>(k);
+    if (rank_k_row == rank.row()) {
+      const auto k_local = dist_b.template localTileFromGlobalTile<Coord::Row>(k);
+      for (SizeType j = 0; j < dist_c.localNrTiles().cols(); ++j) {
+        const LocalTileIndex kj(k_local, j);
+        panelB.setTile(kj, mat_b.read(kj));
+      }
+    }
+
+    // Broadcast both column and row panel from root to others (row-wise and col-wise, respectively)
+    broadcast(rank_k_col, panelA, row_task_chain);
+    broadcast(rank_k_row, panelB, col_task_chain);
+
+    // This is the core loop where the k step performs the update over the entire local matrix using
+    // the col and row workspaces.
+    // Everything needed for the update is available locally thanks to previous broadcasts.
+    for (SizeType i = 0; i < dist_c.localNrTiles().rows(); ++i) {
+      for (SizeType j = 0; j < dist_c.localNrTiles().cols(); ++j) {
+        const LocalTileIndex ij(i, j);
+
+        ex::start_detached(dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::NoTrans, alpha,
+                                                       panelA.read(ij), panelB.read(ij),
+                                                       k == 0 ? beta : T(1), mat_c.readwrite(ij)) |
+                           tile::gemm(dlaf::internal::Policy<B>()));
+      }
+    }
+
+    panelA.reset();
+    panelB.reset();
+  }
+}
 }
 }
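
The new MatrixRef overload of callNN above is a broadcast-based (SUMMA-like) distributed GEMM: for every k, the ranks owning tile column k of A broadcast their panel along their process row, the ranks owning tile row k of B broadcast along their process column, and each rank then updates all of its local C tiles. One k-step of that pattern, sketched on bare MPI with flat buffers (illustrative names, not DLA-Future's panel/pipeline machinery):

#include <mpi.h>

// root_a: rank within row_comm owning column k of A; root_b: rank within col_comm
// owning row k of B. panel_a is m_loc x k_sz, panel_b is k_sz x n_loc, both row-major
// and valid on their respective roots before the call.
void gemm_step_k(MPI_Comm row_comm, MPI_Comm col_comm, int root_a, int root_b,
                 double* panel_a, int len_a, double* panel_b, int len_b, double* c_local,
                 int m_loc, int n_loc, int k_sz, double alpha, double beta_k) {
  // Row-wise broadcast: every rank in this process row gets A's column panel.
  MPI_Bcast(panel_a, len_a, MPI_DOUBLE, root_a, row_comm);
  // Column-wise broadcast: every rank in this process column gets B's row panel.
  MPI_Bcast(panel_b, len_b, MPI_DOUBLE, root_b, col_comm);

  // Local update: C = alpha * panel_a * panel_b + beta_k * C. The loop above passes
  // beta only at k == 0 and T(1) afterwards, so repeated steps accumulate into C.
  for (int i = 0; i < m_loc; ++i)
    for (int j = 0; j < n_loc; ++j) {
      double acc = 0.0;
      for (int l = 0; l < k_sz; ++l)
        acc += panel_a[i * k_sz + l] * panel_b[l * n_loc + j];
      c_local[i * n_loc + j] = alpha * acc + beta_k * c_local[i * n_loc + j];
    }
}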
diff --git a/include/dlaf/permutations/general/impl.h b/include/dlaf/permutations/general/impl.h
index 84483b9b72..28dfd24e9e 100644
--- a/include/dlaf/permutations/general/impl.h
+++ b/include/dlaf/permutations/general/impl.h
@@ -29,6 +29,7 @@
 #include <dlaf/matrix/copy_tile.h>
 #include <dlaf/matrix/distribution.h>
 #include <dlaf/matrix/matrix.h>
+#include <dlaf/matrix/matrix_ref.h>
 #include <dlaf/sender/policy.h>
 #include <dlaf/sender/transform.h>
 #include <dlaf/types.h>
@@ -192,6 +193,79 @@ void Permutations<B, D, T, C>::call(const SizeType i_begin, const SizeType i_end
   }
 }
 
+template <class T, Device D, Coord C>
+void permuteJustLocal(const SizeType i_begin, const SizeType i_end, Matrix<const SizeType, D>& perms,
+                      Matrix<T, D>& mat_in, Matrix<T, D>& mat_out) {
+  static_assert(C == Coord::Col, "Just column permutation");
+
+  namespace ut = matrix::util;
+  namespace ex = pika::execution::experimental;
+  namespace di = dlaf::internal;
+
+  using matrix::internal::MatrixRef;
+  using matrix::internal::SubMatrixSpec;
+
+  if (i_begin == i_end)
+    return;
+
+  const matrix::Distribution& distr = mat_in.distribution();
+
+  const SubMatrixSpec sub_spec{
+      distr.globalElementIndex({i_begin, i_begin}, {0, 0}),
+      distr.globalTileElementDistance({i_begin, i_begin}, {i_end, i_end}),
+  };
+  MatrixRef<T, D> mat_sub_in(mat_in, sub_spec);
+  MatrixRef<T, D> mat_sub_out(mat_out, sub_spec);
+
+  const matrix::Distribution& dist_sub = mat_sub_in.distribution();
+
+  const SizeType ntiles = i_end - i_begin;
+  const auto perms_range = common::iterate_range2d(LocalTileIndex(i_begin, 0), LocalTileSize(ntiles, 1));
+  const auto mat_range = common::iterate_range2d(dist_sub.localNrTiles());
+  auto sender = ex::when_all(ex::when_all_vector(matrix::selectRead(perms, std::move(perms_range))),
+                             ex::when_all_vector(matrix::selectRead(mat_sub_in, mat_range)),
+                             ex::when_all_vector(matrix::select(mat_sub_out, mat_range)));
+
+  auto permute_fn = [dist_sub](const auto& perm_tiles_futs, const auto& mat_in_tiles,
+                               const auto& mat_out_tiles, auto&&...) {
+    const SizeType* perm_ptr = perm_tiles_futs[0].get().ptr();
+
+    const SizeType nperms_lc = dist_sub.localSize().cols();
+    if constexpr (D == Device::CPU) {
+      for (SizeType j_el_lc = 0; j_el_lc < nperms_lc; ++j_el_lc) {
+        const SizeType j_el = dist_sub.globalElementFromLocalElement<Coord::Col>(j_el_lc);
+        const SizeType jj_el = perm_ptr[to_sizet(j_el)];
+
+        const SizeType j_lc = dist_sub.localTileFromLocalElement<Coord::Col>(j_el_lc);
+        const SizeType j_el_tl = dist_sub.tileElementFromLocalElement<Coord::Col>(j_el_lc);
+
+        const SizeType jj_lc = dist_sub.localTileFromGlobalElement<Coord::Col>(jj_el);
+        const SizeType jj_el_tl = dist_sub.tileElementFromGlobalElement<Coord::Col>(jj_el);
+
+        for (SizeType i_lc = 0; i_lc < dist_sub.localNrTiles().rows(); ++i_lc) {
+          const std::size_t j_lc_linear = to_sizet(dist_sub.localTileLinearIndex({i_lc, j_lc}));
+          const std::size_t jj_lc_linear = to_sizet(dist_sub.localTileLinearIndex({i_lc, jj_lc}));
+
+          const auto& tile_in = mat_in_tiles[jj_lc_linear].get();
+          auto& tile_out = mat_out_tiles[j_lc_linear];
+
+          DLAF_ASSERT_HEAVY(tile_in.size().rows() == tile_out.size().rows(), tile_in.size(),
+                            tile_out.size());
+          const TileElementSize region(tile_in.size().rows(), 1);
+          const TileElementIndex sub_in(0, jj_el_tl);
+          const TileElementIndex sub_out(0, j_el_tl);
+
+          dlaf::tile::lacpy(region, sub_in, tile_in, sub_out, tile_out);
+        }
+      }
+    }
+    else {
+      // TODO GPU
+    }
+  };
+  ex::start_detached(di::transform(di::Policy<DefaultBackend_v<D>>(), std::move(permute_fn),
+                                   std::move(sender)));
+}
+
 template <class T, Device D>
 auto whenAllReadWriteTilesArray(LocalTileIndex begin, LocalTileIndex end, Matrix<T, D>& matrix) {
   const LocalTileSize sz{end.row() - begin.row(), end.col() - begin.col()};
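
Stripped of tiling and distribution, the gather performed by permuteJustLocal above is a plain column permutation, out[:, j] = in[:, perm[j]]; the tiled version additionally translates j and perm[j] between global and local numbering and issues one lacpy per tile row. A standalone column-major sketch (hypothetical helper, not the DLA-Future API):

#include <cstddef>
#include <vector>

// out[:, j] = in[:, perm[j]] on an m x n column-major buffer; element (i, j) is at j * m + i.
void permute_columns(std::size_t m, std::size_t n, const std::vector<std::size_t>& perm,
                     const std::vector<double>& in, std::vector<double>& out) {
  for (std::size_t j = 0; j < n; ++j) {
    const std::size_t jj = perm[j];  // source column, as in jj_el = perm_ptr[j_el]
    for (std::size_t i = 0; i < m; ++i)
      out[j * m + i] = in[jj * m + i];
  }
}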
diff --git a/test/unit/eigensolver/test_tridiag_solver_merge.cpp b/test/unit/eigensolver/test_tridiag_solver_merge.cpp
index 5701003973..76a765402e 100644
--- a/test/unit/eigensolver/test_tridiag_solver_merge.cpp
+++ b/test/unit/eigensolver/test_tridiag_solver_merge.cpp
@@ -92,6 +92,7 @@ TEST(StablePartitionIndexOnDeflated, FullRange) {
   Matrix<double, Device::CPU> vals(sz, bk);
   Matrix<SizeType, Device::CPU> in(sz, bk);
   Matrix<SizeType, Device::CPU> out(sz, bk);
+  Matrix<SizeType, Device::CPU> out_by_type(sz, bk);
 
   // Note:
   //   UpperHalf -> u
@@ -125,7 +126,8 @@
 
   const SizeType i_begin = 0;
   const SizeType i_end = 4;
-  auto k = stablePartitionIndexForDeflation(i_begin, i_end, c, vals, in, out);
+  auto [k, n_udl] = stablePartitionIndexForDeflation(i_begin, i_end, c, vals, in, out, out_by_type) |
+                    ex::split_tuple();
 
   // | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial
   // | l| f|d1|d2| u| u| l| f|d3| l| c_arr
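
The updated test consumes stablePartitionIndexForDeflation through ex::split_tuple, which turns a sender of a tuple into a tuple of senders whose values can be chained or awaited independently. A minimal sketch of that pattern with pika, using stand-in values rather than the real computation:

#include <array>
#include <cstdint>
#include <tuple>
#include <utility>

#include <pika/execution.hpp>
#include <pika/init.hpp>

namespace ex = pika::execution::experimental;
namespace tt = pika::this_thread::experimental;

int pika_main(int, char**) {
  // Stand-in for the (k, n_udl) pair produced by the deflation step.
  auto work = ex::just(std::tuple(std::int64_t{6}, std::array<std::int64_t, 3>{2, 4, 6}));

  // sender<tuple<K, A>> -> tuple<sender<K>, sender<A>>, as in `| ex::split_tuple()` above.
  auto [k_snd, n_udl_snd] = ex::split_tuple(std::move(work));

  const auto k = tt::sync_wait(std::move(k_snd));
  const auto n_udl = tt::sync_wait(std::move(n_udl_snd));
  (void)k;
  (void)n_udl;

  pika::finalize();
  return 0;
}

int main(int argc, char* argv[]) {
  return pika::init(pika_main, argc, argv);
}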