From ad73045723bd06fe15c00a7fb6879f88a1b89dee Mon Sep 17 00:00:00 2001 From: Doc CI Action Date: Wed, 19 Jul 2023 11:42:25 +0000 Subject: [PATCH] Doc TridiagSolver (local): embed row permutation in rank1 solver (#936) --- ...r_2tridiag__solver_2kernels_8h_source.html | 117 +- master/merge_8h_source.html | 1677 +++++++++-------- 2 files changed, 890 insertions(+), 904 deletions(-) diff --git a/master/eigensolver_2tridiag__solver_2kernels_8h_source.html b/master/eigensolver_2tridiag__solver_2kernels_8h_source.html index a88b347355..8815b26cd6 100644 --- a/master/eigensolver_2tridiag__solver_2kernels_8h_source.html +++ b/master/eigensolver_2tridiag__solver_2kernels_8h_source.html @@ -344,86 +344,49 @@
273  di::transformDetach(di::Policy<DefaultBackend_v<D>>(), initIndexTile_o, std::move(sender));
274 }
275 
-
276 template <class T>
-
277 void setUnitDiagonal(const SizeType& k, const SizeType& tile_begin,
-
278  const matrix::Tile<T, Device::CPU>& tile);
-
279 
-
280 #define DLAF_CPU_SET_UNIT_DIAGONAL_ETI(kword, Type) \
-
281  kword template void setUnitDiagonal(const SizeType& k, const SizeType& tile_begin, \
-
282  const matrix::Tile<Type, Device::CPU>& tile)
-
283 
-
284 DLAF_CPU_SET_UNIT_DIAGONAL_ETI(extern, float);
-
285 DLAF_CPU_SET_UNIT_DIAGONAL_ETI(extern, double);
+
276 #ifdef DLAF_WITH_GPU
+
277 
+
278 // Returns the number of non-deflated entries
+
279 void stablePartitionIndexOnDevice(SizeType n, const ColType* c_ptr, const SizeType* in_ptr,
+
280  SizeType* out_ptr, SizeType* host_k_ptr, SizeType* device_k_ptr,
+
281  whip::stream_t stream);
+
282 
+
283 template <class T>
+
284 void mergeIndicesOnDevice(const SizeType* begin_ptr, const SizeType* split_ptr, const SizeType* end_ptr,
+
285  SizeType* out_ptr, const T* v_ptr, whip::stream_t stream);
286 
-
287 #ifdef DLAF_WITH_GPU
-
288 template <class T>
-
289 void setUnitDiagonal(const SizeType& k, const SizeType& tile_begin,
-
290  const matrix::Tile<T, Device::GPU>& tile, whip::stream_t stream);
+
287 #define DLAF_CUDA_MERGE_INDICES_ETI(kword, Type) \
+
288  kword template void mergeIndicesOnDevice(const SizeType* begin_ptr, const SizeType* split_ptr, \
+
289  const SizeType* end_ptr, SizeType* out_ptr, \
+
290  const Type* v_ptr, whip::stream_t stream)
291 
-
292 #define DLAF_GPU_SET_UNIT_DIAGONAL_ETI(kword, Type) \
-
293  kword template void setUnitDiagonal(const SizeType& k, const SizeType& tile_begin, \
-
294  const matrix::Tile<Type, Device::GPU>& tile, \
-
295  whip::stream_t stream)
-
296 
-
297 DLAF_GPU_SET_UNIT_DIAGONAL_ETI(extern, float);
-
298 DLAF_GPU_SET_UNIT_DIAGONAL_ETI(extern, double);
-
299 
-
300 #endif
+
292 DLAF_CUDA_MERGE_INDICES_ETI(extern, float);
+
293 DLAF_CUDA_MERGE_INDICES_ETI(extern, double);
+
294 
+
295 template <class T>
+
296 void applyIndexOnDevice(SizeType len, const SizeType* index, const T* in, T* out, whip::stream_t stream);
+
297 
+
298 #define DLAF_CUDA_APPLY_INDEX_ETI(kword, Type) \
+
299  kword template void applyIndexOnDevice(SizeType len, const SizeType* index, const Type* in, \
+
300  Type* out, whip::stream_t stream)
301 
-
302 DLAF_MAKE_CALLABLE_OBJECT(setUnitDiagonal);
-
303 
-
304 template <Device D, class KSender, class TileSender>
-
305 void setUnitDiagonalAsync(KSender&& k, SizeType tile_begin, TileSender&& tile) {
-
306  namespace di = dlaf::internal;
-
307  auto sender = di::whenAllLift(std::forward<KSender>(k), tile_begin, std::forward<TileSender>(tile));
-
308  di::transformDetach(di::Policy<DefaultBackend_v<D>>(), setUnitDiagonal_o, std::move(sender));
-
309 }
-
310 
-
311 // ---------------------------
-
312 
-
313 #ifdef DLAF_WITH_GPU
-
314 
-
315 // Returns the number of non-deflated entries
-
316 void stablePartitionIndexOnDevice(SizeType n, const ColType* c_ptr, const SizeType* in_ptr,
-
317  SizeType* out_ptr, SizeType* host_k_ptr, SizeType* device_k_ptr,
-
318  whip::stream_t stream);
-
319 
-
320 template <class T>
-
321 void mergeIndicesOnDevice(const SizeType* begin_ptr, const SizeType* split_ptr, const SizeType* end_ptr,
-
322  SizeType* out_ptr, const T* v_ptr, whip::stream_t stream);
-
323 
-
324 #define DLAF_CUDA_MERGE_INDICES_ETI(kword, Type) \
-
325  kword template void mergeIndicesOnDevice(const SizeType* begin_ptr, const SizeType* split_ptr, \
-
326  const SizeType* end_ptr, SizeType* out_ptr, \
-
327  const Type* v_ptr, whip::stream_t stream)
-
328 
-
329 DLAF_CUDA_MERGE_INDICES_ETI(extern, float);
-
330 DLAF_CUDA_MERGE_INDICES_ETI(extern, double);
-
331 
-
332 template <class T>
-
333 void applyIndexOnDevice(SizeType len, const SizeType* index, const T* in, T* out, whip::stream_t stream);
-
334 
-
335 #define DLAF_CUDA_APPLY_INDEX_ETI(kword, Type) \
-
336  kword template void applyIndexOnDevice(SizeType len, const SizeType* index, const Type* in, \
-
337  Type* out, whip::stream_t stream)
-
338 
-
339 DLAF_CUDA_APPLY_INDEX_ETI(extern, float);
-
340 DLAF_CUDA_APPLY_INDEX_ETI(extern, double);
-
341 
-
342 void invertIndexOnDevice(SizeType len, const SizeType* in, SizeType* out, whip::stream_t stream);
-
343 
-
344 template <class T>
-
345 void givensRotationOnDevice(SizeType len, T* x, T* y, T c, T s, whip::stream_t stream);
-
346 
-
347 #define DLAF_GIVENS_ROT_ETI(kword, Type) \
-
348  kword template void givensRotationOnDevice(SizeType len, Type* x, Type* y, Type c, Type s, \
-
349  whip::stream_t stream)
-
350 
-
351 DLAF_GIVENS_ROT_ETI(extern, float);
-
352 DLAF_GIVENS_ROT_ETI(extern, double);
-
353 
-
354 #endif
-
355 }
+
302 DLAF_CUDA_APPLY_INDEX_ETI(extern, float);
+
303 DLAF_CUDA_APPLY_INDEX_ETI(extern, double);
+
304 
+
305 void invertIndexOnDevice(SizeType len, const SizeType* in, SizeType* out, whip::stream_t stream);
+
306 
+
307 template <class T>
+
308 void givensRotationOnDevice(SizeType len, T* x, T* y, T c, T s, whip::stream_t stream);
+
309 
+
310 #define DLAF_GIVENS_ROT_ETI(kword, Type) \
+
311  kword template void givensRotationOnDevice(SizeType len, Type* x, Type* y, Type c, Type s, \
+
312  whip::stream_t stream)
+
313 
+
314 DLAF_GIVENS_ROT_ETI(extern, float);
+
315 DLAF_GIVENS_ROT_ETI(extern, double);
+
316 
+
317 #endif
+
318 }
#define DLAF_MAKE_CALLABLE_OBJECT(fname)
Definition: callable_object.h:26
diff --git a/master/merge_8h_source.html b/master/merge_8h_source.html index 82f3951b55..b0f1773fd0 100644 --- a/master/merge_8h_source.html +++ b/master/merge_8h_source.html @@ -508,874 +508,897 @@
437 template <class T, class KSender, class RhoSender>
438 void solveRank1Problem(const SizeType i_begin, const SizeType i_end, KSender&& k, RhoSender&& rho,
439  Matrix<const T, Device::CPU>& d, Matrix<T, Device::CPU>& z,
-
440  Matrix<T, Device::CPU>& evals, Matrix<T, Device::CPU>& evecs) {
-
441  namespace ex = pika::execution::experimental;
-
442  namespace di = dlaf::internal;
-
443 
-
444  const SizeType n = problemSize(i_begin, i_end, evals.distribution());
-
445  const SizeType nb = evals.distribution().blockSize().rows();
-
446 
-
447  TileCollector tc{i_begin, i_end};
-
448 
-
449  // Note: at least two column of tiles per-worker, in the range [1, getTridiagRank1NWorkers()]
-
450  const std::size_t nthreads = [nrtiles = (i_end - i_begin)]() {
-
451  const std::size_t min_workers = 1;
-
452  const std::size_t available_workers = getTridiagRank1NWorkers();
-
453  const std::size_t ideal_workers = util::ceilDiv(to_sizet(nrtiles), to_sizet(2));
-
454  return std::clamp(ideal_workers, min_workers, available_workers);
-
455  }();
-
456 
-
457  ex::start_detached(
-
458  ex::when_all(ex::just(std::make_unique<pika::barrier<>>(nthreads)), std::forward<KSender>(k),
-
459  std::forward<RhoSender>(rho), ex::when_all_vector(tc.read(d)),
-
460  ex::when_all_vector(tc.readwrite(z)), ex::when_all_vector(tc.readwrite(evals)),
-
461  ex::when_all_vector(tc.readwrite(evecs)),
-
462  ex::just(std::vector<memory::MemoryView<T, Device::CPU>>())) |
-
463  ex::transfer(di::getBackendScheduler<Backend::MC>(pika::execution::thread_priority::high)) |
-
464  ex::bulk(nthreads, [nthreads, n, nb](std::size_t thread_idx, auto& barrier_ptr, auto& k, auto& rho,
-
465  auto& d_tiles_futs, auto& z_tiles, auto& eval_tiles,
-
466  auto& evec_tiles, auto& ws_vecs) {
-
467  const matrix::Distribution distr(LocalElementSize(n, n), TileElementSize(nb, nb));
-
468 
-
469  const auto barrier_busy_wait = getTridiagRank1BarrierBusyWait();
-
470  const std::size_t batch_size = util::ceilDiv(to_sizet(k), nthreads);
-
471  const std::size_t begin = thread_idx * batch_size;
-
472  const std::size_t end = std::min(thread_idx * batch_size + batch_size, to_sizet(k));
-
473 
-
474  // STEP 0: Initialize workspaces (single-thread)
-
475  if (thread_idx == 0) {
-
476  ws_vecs.reserve(nthreads);
-
477  for (std::size_t i = 0; i < nthreads; ++i)
-
478  ws_vecs.emplace_back(to_sizet(k));
-
479  }
-
480 
-
481  barrier_ptr->arrive_and_wait(barrier_busy_wait);
-
482 
-
483  // STEP 1: LAED4 (multi-thread)
-
484  const T* d_ptr = d_tiles_futs[0].get().ptr();
-
485  const T* z_ptr = z_tiles[0].ptr();
-
486 
-
487  {
-
488  common::internal::SingleThreadedBlasScope single;
-
489 
-
490  T* eval_ptr = eval_tiles[0].ptr();
-
491 
-
492  for (std::size_t i = begin; i < end; ++i) {
-
493  T& eigenval = eval_ptr[i];
-
494 
-
495  const SizeType i_tile = distr.globalTileLinearIndex(GlobalElementIndex(0, to_SizeType(i)));
-
496  const SizeType i_col = distr.tileElementFromGlobalElement<Coord::Col>(to_SizeType(i));
-
497  T* delta = evec_tiles[to_sizet(i_tile)].ptr(TileElementIndex(0, i_col));
-
498 
-
499  lapack::laed4(to_int(k), to_int(i), d_ptr, z_ptr, delta, rho, &eigenval);
-
500  }
-
501 
-
502  // Note: for in-place row permutation implementation: The rows should be permuted for the k=2 case as well.
-
503 
-
504  // Note: laed4 handles k <= 2 cases differently
-
505  if (k <= 2)
-
506  return;
-
507  }
+
440  Matrix<T, Device::CPU>& evals, Matrix<const SizeType, Device::CPU>& i2,
+
441  Matrix<T, Device::CPU>& evecs) {
+
442  namespace ex = pika::execution::experimental;
+
443  namespace di = dlaf::internal;
+
444 
+
445  const SizeType n = problemSize(i_begin, i_end, evals.distribution());
+
446  const SizeType nb = evals.distribution().blockSize().rows();
+
447 
+
448  TileCollector tc{i_begin, i_end};
+
449 
+
450  // Note: at least two column of tiles per-worker, in the range [1, getTridiagRank1NWorkers()]
+
451  const std::size_t nthreads = [nrtiles = (i_end - i_begin)]() {
+
452  const std::size_t min_workers = 1;
+
453  const std::size_t available_workers = getTridiagRank1NWorkers();
+
454  const std::size_t ideal_workers = util::ceilDiv(to_sizet(nrtiles), to_sizet(2));
+
455  return std::clamp(ideal_workers, min_workers, available_workers);
+
456  }();
+
457 
+
458  ex::start_detached(
+
459  ex::when_all(ex::just(std::make_unique<pika::barrier<>>(nthreads)), std::forward<KSender>(k),
+
460  std::forward<RhoSender>(rho), ex::when_all_vector(tc.read(d)),
+
461  ex::when_all_vector(tc.readwrite(z)), ex::when_all_vector(tc.readwrite(evals)),
+
462  ex::when_all_vector(tc.read(i2)), ex::when_all_vector(tc.readwrite(evecs)),
+
463  ex::just(std::vector<memory::MemoryView<T, Device::CPU>>())) |
+
464  ex::transfer(di::getBackendScheduler<Backend::MC>(pika::execution::thread_priority::high)) |
+
465  ex::bulk(nthreads, [nthreads, n, nb](std::size_t thread_idx, auto& barrier_ptr, auto& k, auto& rho,
+
466  auto& d_tiles_futs, auto& z_tiles, auto& eval_tiles,
+
467  const auto& i2_tile_arr, auto& evec_tiles, auto& ws_vecs) {
+
468  const matrix::Distribution distr(LocalElementSize(n, n), TileElementSize(nb, nb));
+
469 
+
470  const SizeType* i2_perm = i2_tile_arr[0].get().ptr();
+
471 
+
472  const auto barrier_busy_wait = getTridiagRank1BarrierBusyWait();
+
473  const std::size_t batch_size = util::ceilDiv(to_sizet(k), nthreads);
+
474  const std::size_t begin = thread_idx * batch_size;
+
475  const std::size_t end = std::min(thread_idx * batch_size + batch_size, to_sizet(k));
+
476 
+
477  // STEP 0a: Fill ones for deflated Eigenvectors. (single-thread)
+
478  // Note: this step is completely independent from the rest, but it is small and it is going
+
479  // to be dropped soon.
+
480  // Note: use last thread that in principle should have less work to do
+
481  if (thread_idx == nthreads - 1) {
+
482  for (SizeType i = 0; i < n; ++i) {
+
483  const SizeType j = i2_perm[to_sizet(i)];
+
484 
+
485  // if it is deflated
+
486  if (j >= k) {
+
487  const GlobalElementIndex ij(i, j);
+
488  const auto linear_ij = distr.globalTileLinearIndex(ij);
+
489  const auto ij_el = distr.tileElementIndex(ij);
+
490 
+
491  evec_tiles[to_sizet(linear_ij)](ij_el) = 1;
+
492  }
+
493  }
+
494  }
+
495 
+
496  // STEP 0b: Initialize workspaces (single-thread)
+
497  if (thread_idx == 0) {
+
498  ws_vecs.reserve(nthreads);
+
499  for (std::size_t i = 0; i < nthreads; ++i)
+
500  ws_vecs.emplace_back(to_sizet(k));
+
501  }
+
502 
+
503  barrier_ptr->arrive_and_wait(barrier_busy_wait);
+
504 
+
505  // STEP 1: LAED4 (multi-thread)
+
506  const T* d_ptr = d_tiles_futs[0].get().ptr();
+
507  const T* z_ptr = z_tiles[0].ptr();
508 
-
509  barrier_ptr->arrive_and_wait(barrier_busy_wait);
-
510 
-
511  // STEP 2a Compute weights (multi-thread)
-
512  auto& q = evec_tiles;
-
513  T* w = ws_vecs[thread_idx]();
-
514 
-
515  // - copy diagonal from q -> w (or just initialize with 1)
-
516  if (thread_idx == 0) {
-
517  for (auto i = 0; i < k; ++i) {
-
518  const GlobalElementIndex kk(i, i);
-
519  const auto diag_tile = distr.globalTileLinearIndex(kk);
-
520  const auto diag_element = distr.tileElementIndex(kk);
-
521 
-
522  w[i] = q[to_sizet(diag_tile)](diag_element);
-
523  }
-
524  }
-
525  else {
-
526  std::fill_n(w, k, T(1));
-
527  }
-
528 
-
529  // - compute productorial
-
530  auto compute_w = [&](const GlobalElementIndex ij) {
-
531  const auto q_tile = distr.globalTileLinearIndex(ij);
-
532  const auto q_ij = distr.tileElementIndex(ij);
+
509  {
+
510  common::internal::SingleThreadedBlasScope single;
+
511 
+
512  T* eval_ptr = eval_tiles[0].ptr();
+
513 
+
514  for (std::size_t i = begin; i < end; ++i) {
+
515  T& eigenval = eval_ptr[i];
+
516 
+
517  const SizeType i_tile = distr.globalTileLinearIndex(GlobalElementIndex(0, to_SizeType(i)));
+
518  const SizeType i_col = distr.tileElementFromGlobalElement<Coord::Col>(to_SizeType(i));
+
519  T* delta = evec_tiles[to_sizet(i_tile)].ptr(TileElementIndex(0, i_col));
+
520 
+
521  lapack::laed4(to_int(k), to_int(i), d_ptr, z_ptr, delta, rho, &eigenval);
+
522  }
+
523 
+
524  // Note: laed4 handles k <= 2 cases differently
+
525  if (k <= 2) {
+
526  // Note: The rows should be permuted for the k=2 case as well.
+
527  if (k == 2) {
+
528  T* ws = ws_vecs[thread_idx]();
+
529  for (SizeType j = to_SizeType(begin); j < to_SizeType(end); ++j) {
+
530  const SizeType j_tile = distr.globalTileLinearIndex(GlobalElementIndex(0, j));
+
531  const SizeType j_col = distr.tileElementFromGlobalElement<Coord::Col>(j);
+
532  T* evec = evec_tiles[to_sizet(j_tile)].ptr(TileElementIndex(0, j_col));
533 
-
534  const SizeType i = ij.row();
-
535  const SizeType j = ij.col();
-
536 
-
537  w[i] *= q[to_sizet(q_tile)](q_ij) / (d_ptr[to_sizet(i)] - d_ptr[to_sizet(j)]);
-
538  };
-
539 
-
540  for (auto j = to_SizeType(begin); j < to_SizeType(end); ++j) {
-
541  for (auto i = 0; i < j; ++i)
-
542  compute_w({i, j});
-
543 
-
544  for (auto i = j + 1; i < k; ++i)
-
545  compute_w({i, j});
-
546  }
-
547 
+
534  std::copy(evec, evec + k, ws);
+
535  std::fill_n(evec, k, 0); // by default "deflated"
+
536  for (SizeType i = 0; i < n; ++i) {
+
537  const SizeType ii = i2_perm[i];
+
538  if (ii < k)
+
539  evec[i] = ws[ii];
+
540  }
+
541  }
+
542  }
+
543  return;
+
544  }
+
545  }
+
546 
+
547  // Note: This barrier ensures that LAED4 finished, so from now on values are available
548  barrier_ptr->arrive_and_wait(barrier_busy_wait);
549 
-
550  // STEP 2B: reduce, then finalize computation with sign and square root (single-thread)
-
551  if (thread_idx == 0) {
-
552  for (int i = 0; i < k; ++i) {
-
553  for (std::size_t tidx = 1; tidx < nthreads; ++tidx) {
-
554  const T* w_partial = ws_vecs[tidx]();
-
555  w[i] *= w_partial[i];
-
556  }
-
557  z_tiles[0].ptr()[i] = std::copysign(std::sqrt(-w[i]), z_ptr[to_sizet(i)]);
-
558  }
-
559  }
+
550  // STEP 2a Compute weights (multi-thread)
+
551  auto& q = evec_tiles;
+
552  T* w = ws_vecs[thread_idx]();
+
553 
+
554  // - copy diagonal from q -> w (or just initialize with 1)
+
555  if (thread_idx == 0) {
+
556  for (SizeType i = 0; i < k; ++i) {
+
557  const GlobalElementIndex kk(i, i);
+
558  const auto diag_tile = distr.globalTileLinearIndex(kk);
+
559  const auto diag_element = distr.tileElementIndex(kk);
560 
-
561  barrier_ptr->arrive_and_wait(barrier_busy_wait);
-
562 
-
563  // STEP 3: Compute eigenvectors of the modified rank-1 modification (normalize) (multi-thread)
-
564  {
-
565  common::internal::SingleThreadedBlasScope single;
-
566 
-
567  const T* w = z_ptr;
-
568  T* s = ws_vecs[thread_idx]();
-
569 
-
570  for (auto j = to_SizeType(begin); j < to_SizeType(end); ++j) {
-
571  for (int i = 0; i < k; ++i) {
-
572  const auto q_tile = distr.globalTileLinearIndex({i, j});
-
573  const auto q_ij = distr.tileElementIndex({i, j});
-
574 
-
575  s[i] = w[i] / q[to_sizet(q_tile)](q_ij);
-
576  }
-
577 
-
578  const T vec_norm = blas::nrm2(k, s, 1);
-
579 
-
580  for (auto i = 0; i < k; ++i) {
-
581  const auto q_tile = distr.globalTileLinearIndex({i, j});
-
582  const auto q_ij = distr.tileElementIndex({i, j});
-
583 
-
584  q[to_sizet(q_tile)](q_ij) = s[i] / vec_norm;
-
585  }
-
586  }
-
587  }
-
588  }));
-
589 }
-
590 
-
591 template <class T, Device D, class KSender>
-
592 void setUnitDiag(const SizeType i_begin, const SizeType i_end, KSender&& k, Matrix<T, D>& mat) {
-
593  // Iterate over diagonal tiles
-
594  const matrix::Distribution& distr = mat.distribution();
-
595  for (SizeType i_tile = i_begin; i_tile < i_end; ++i_tile) {
-
596  const SizeType tile_begin = distr.globalTileElementDistance<Coord::Row>(i_begin, i_tile);
-
597 
-
598  setUnitDiagonalAsync<D>(k, tile_begin, mat.readwrite(GlobalTileIndex(i_tile, i_tile)));
-
599  }
-
600 }
+
561  w[i] = q[to_sizet(diag_tile)](diag_element);
+
562  }
+
563  }
+
564  else {
+
565  std::fill_n(w, k, T(1));
+
566  }
+
567 
+
568  // - compute productorial
+
569  auto compute_w = [&](const GlobalElementIndex ij) {
+
570  const auto q_tile = distr.globalTileLinearIndex(ij);
+
571  const auto q_ij = distr.tileElementIndex(ij);
+
572 
+
573  const SizeType i = ij.row();
+
574  const SizeType j = ij.col();
+
575 
+
576  w[i] *= q[to_sizet(q_tile)](q_ij) / (d_ptr[to_sizet(i)] - d_ptr[to_sizet(j)]);
+
577  };
+
578 
+
579  for (SizeType j = to_SizeType(begin); j < to_SizeType(end); ++j) {
+
580  for (SizeType i = 0; i < j; ++i)
+
581  compute_w({i, j});
+
582 
+
583  for (SizeType i = j + 1; i < k; ++i)
+
584  compute_w({i, j});
+
585  }
+
586 
+
587  barrier_ptr->arrive_and_wait(barrier_busy_wait);
+
588 
+
589  // STEP 2B: reduce, then finalize computation with sign and square root (single-thread)
+
590  if (thread_idx == 0) {
+
591  for (SizeType i = 0; i < k; ++i) {
+
592  for (std::size_t tidx = 1; tidx < nthreads; ++tidx) {
+
593  const T* w_partial = ws_vecs[tidx]();
+
594  w[i] *= w_partial[i];
+
595  }
+
596  z_tiles[0].ptr()[i] = std::copysign(std::sqrt(-w[i]), z_ptr[to_sizet(i)]);
+
597  }
+
598  }
+
599 
+
600  barrier_ptr->arrive_and_wait(barrier_busy_wait);
601 
-
602 template <Backend B, Device D, class T, class RhoSender>
-
603 void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const SizeType i_end,
-
604  RhoSender&& rho, WorkSpace<T, D>& ws, WorkSpaceHost<T>& ws_h,
-
605  WorkSpaceHostMirror<T, D>& ws_hm) {
-
606  namespace ex = pika::execution::experimental;
-
607 
-
608  const GlobalTileIndex idx_gl_begin(i_begin, i_begin);
-
609  const LocalTileIndex idx_loc_begin(i_begin, i_begin);
-
610  const SizeType nrtiles = i_end - i_begin;
-
611  const LocalTileSize sz_loc_tiles(nrtiles, nrtiles);
-
612 
-
613  const LocalTileIndex idx_begin_tiles_vec(i_begin, 0);
-
614  const LocalTileSize sz_tiles_vec(nrtiles, 1);
-
615 
-
616  // Calculate the size of the upper subproblem
-
617  const SizeType n1 = problemSize(i_begin, i_split, ws.e0.distribution());
+
602  // STEP 3: Compute eigenvectors of the modified rank-1 modification (normalize) (multi-thread)
+
603  {
+
604  common::internal::SingleThreadedBlasScope single;
+
605 
+
606  const T* w = z_ptr;
+
607  T* s = ws_vecs[thread_idx]();
+
608 
+
609  for (SizeType j = to_SizeType(begin); j < to_SizeType(end); ++j) {
+
610  for (SizeType i = 0; i < k; ++i) {
+
611  const auto q_tile = distr.globalTileLinearIndex({i, j});
+
612  const auto q_ij = distr.tileElementIndex({i, j});
+
613 
+
614  s[i] = w[i] / q[to_sizet(q_tile)](q_ij);
+
615  }
+
616 
+
617  const T vec_norm = blas::nrm2(k, s, 1);
618 
-
619  // Assemble the rank-1 update vector `z` from the last row of Q1 and the first row of Q2
-
620  assembleZVec(i_begin, i_split, i_end, rho, ws.e0, ws.z0);
-
621  copy(idx_begin_tiles_vec, sz_tiles_vec, ws.z0, ws_hm.z0);
-
622 
-
623  // Double `rho` to account for the normalization of `z` and make sure `rho > 0` for the root solver laed4
-
624  auto scaled_rho = scaleRho(std::move(rho)) | ex::split();
-
625 
-
626  // Calculate the tolerance used for deflation
-
627  auto tol = calcTolerance(i_begin, i_end, ws_h.d0, ws_hm.z0);
-
628 
-
629  // Initialize the column types vector `c`
-
630  initColTypes(i_begin, i_split, i_end, ws_h.c);
-
631 
-
632  // Step #1
-
633  //
-
634  // i1 (out) : initial <--- initial (identity map)
-
635  // i2 (out) : initial <--- pre_sorted
-
636  //
-
637  // - deflate `d`, `z` and `c`
-
638  // - apply Givens rotations to `Q` - `evecs`
-
639  //
-
640  if (i_split == i_begin + 1) {
-
641  initIndex(i_begin, i_split, ws_h.i1);
-
642  }
-
643  if (i_split + 1 == i_end) {
-
644  initIndex(i_split, i_end, ws_h.i1);
-
645  }
-
646  addIndex(i_split, i_end, n1, ws_h.i1);
-
647  sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2);
-
648 
-
649  auto rots =
-
650  applyDeflation(i_begin, i_end, scaled_rho, std::move(tol), ws_hm.i2, ws_h.d0, ws_hm.z0, ws_h.c);
-
651 
-
652  // ---
-
653 
-
654  applyGivensRotationsToMatrixColumns(i_begin, i_end, std::move(rots), ws.e0);
-
655  // Placeholder for rearranging the eigenvectors: (local permutation)
-
656  copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws.e1);
+
619  for (SizeType i = 0; i < n; ++i) {
+
620  const SizeType ii = i2_perm[i];
+
621  const auto q_tile = distr.globalTileLinearIndex({i, j});
+
622  const auto q_ij = distr.tileElementIndex({i, j});
+
623 
+
624  if (ii < k)
+
625  q[to_sizet(q_tile)](q_ij) = s[ii] / vec_norm;
+
626  else
+
627  q[to_sizet(q_tile)](q_ij) = 0;
+
628  }
+
629  }
+
630  }
+
631  }));
+
632 }
+
633 
+
634 template <Backend B, Device D, class T, class RhoSender>
+
635 void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const SizeType i_end,
+
636  RhoSender&& rho, WorkSpace<T, D>& ws, WorkSpaceHost<T>& ws_h,
+
637  WorkSpaceHostMirror<T, D>& ws_hm) {
+
638  namespace ex = pika::execution::experimental;
+
639 
+
640  const GlobalTileIndex idx_gl_begin(i_begin, i_begin);
+
641  const LocalTileIndex idx_loc_begin(i_begin, i_begin);
+
642  const SizeType nrtiles = i_end - i_begin;
+
643  const LocalTileSize sz_loc_tiles(nrtiles, nrtiles);
+
644 
+
645  const LocalTileIndex idx_begin_tiles_vec(i_begin, 0);
+
646  const LocalTileSize sz_tiles_vec(nrtiles, 1);
+
647 
+
648  // Calculate the size of the upper subproblem
+
649  const SizeType n1 = problemSize(i_begin, i_split, ws.e0.distribution());
+
650 
+
651  // Assemble the rank-1 update vector `z` from the last row of Q1 and the first row of Q2
+
652  assembleZVec(i_begin, i_split, i_end, rho, ws.e0, ws.z0);
+
653  copy(idx_begin_tiles_vec, sz_tiles_vec, ws.z0, ws_hm.z0);
+
654 
+
655  // Double `rho` to account for the normalization of `z` and make sure `rho > 0` for the root solver laed4
+
656  auto scaled_rho = scaleRho(std::move(rho)) | ex::split();
657 
-
658  // Step #2
-
659  //
-
660  // i2 (in) : initial <--- pre_sorted
-
661  // i3 (out) : initial <--- deflated
-
662  //
-
663  // - reorder `d0 -> d1`, `z0 -> z1`, using `i3` such that deflated entries are at the bottom.
-
664  // - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`.
-
665  // - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented)
-
666  //
-
667  auto k = stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_hm.i2, ws_h.i3) | ex::split();
-
668 
-
669  applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1);
-
670  applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1);
-
671  copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.d1, ws_h.d0);
-
672 
-
673  //
-
674  // i3 (in) : initial <--- deflated
-
675  // i2 (out) : initial ---> deflated
-
676  //
-
677  invertIndex(i_begin, i_end, ws_h.i3, ws_hm.i2);
-
678 
-
679  // Note:
-
680  // This is neeeded to set to zero elements of e2 outside of the k by k top-left part.
-
681  // The input is not required to be zero for solveRank1Problem.
-
682  matrix::util::set0<Backend::MC>(pika::execution::thread_priority::normal, idx_loc_begin, sz_loc_tiles,
-
683  ws_hm.e2);
-
684  solveRank1Problem(i_begin, i_end, k, scaled_rho, ws_hm.d1, ws_hm.z1, ws_h.d0, ws_hm.e2);
+
658  // Calculate the tolerance used for deflation
+
659  auto tol = calcTolerance(i_begin, i_end, ws_h.d0, ws_hm.z0);
+
660 
+
661  // Initialize the column types vector `c`
+
662  initColTypes(i_begin, i_split, i_end, ws_h.c);
+
663 
+
664  // Step #1
+
665  //
+
666  // i1 (out) : initial <--- initial (identity map)
+
667  // i2 (out) : initial <--- pre_sorted
+
668  //
+
669  // - deflate `d`, `z` and `c`
+
670  // - apply Givens rotations to `Q` - `evecs`
+
671  //
+
672  if (i_split == i_begin + 1) {
+
673  initIndex(i_begin, i_split, ws_h.i1);
+
674  }
+
675  if (i_split + 1 == i_end) {
+
676  initIndex(i_split, i_end, ws_h.i1);
+
677  }
+
678  addIndex(i_split, i_end, n1, ws_h.i1);
+
679  sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2);
+
680 
+
681  auto rots =
+
682  applyDeflation(i_begin, i_end, scaled_rho, std::move(tol), ws_hm.i2, ws_h.d0, ws_hm.z0, ws_h.c);
+
683 
+
684  // ---
685 
-
686  copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2);
-
687 
-
688  setUnitDiag(i_begin, i_end, k, ws.e2);
+
686  applyGivensRotationsToMatrixColumns(i_begin, i_end, std::move(rots), ws.e0);
+
687  // Placeholder for rearranging the eigenvectors: (local permutation)
+
688  copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws.e1);
689 
-
690  // Step #3: Eigenvectors of the tridiagonal system: Q * U
+
690  // Step #2
691  //
-
692  // The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as
-
693  // prepared for the deflated system.
+
692  // i2 (in) : initial <--- pre_sorted
+
693  // i3 (out) : initial <--- deflated
694  //
-
695  copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i2, ws.i2);
-
696  // The following permutation will be removed in the future.
-
697  // (The copy is needed to simplify the removal)
-
698  dlaf::permutations::permute<B, D, T, Coord::Row>(i_begin, i_end, ws.i2, ws.e2, ws.e0);
-
699  copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws.e2);
-
700  dlaf::multiplication::generalSubMatrix<B, D, T>(i_begin, i_end, blas::Op::NoTrans, blas::Op::NoTrans,
-
701  T(1), ws.e1, ws.e2, T(0), ws.e0);
-
702 
-
703  // Step #4: Final permutation to sort eigenvalues and eigenvectors
-
704  //
-
705  // i1 (in) : deflated <--- deflated (identity map)
-
706  // i2 (out) : deflated <--- post_sorted
-
707  //
-
708  initIndex(i_begin, i_end, ws_h.i1);
-
709  sortIndex(i_begin, i_end, std::move(k), ws_h.d0, ws_h.i1, ws_hm.i2);
-
710  copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i2, ws_h.i1);
-
711 }
-
712 
-
713 // The bottom row of Q1 and the top row of Q2. The bottom row of Q1 is negated if `rho < 0`.
-
714 //
-
715 // Note that the norm of `z` is sqrt(2) because it is a concatination of two normalized vectors. Hence
-
716 // to normalize `z` we have to divide by sqrt(2).
-
717 template <class T, Device D, class RhoSender>
-
718 void assembleDistZVec(comm::CommunicatorGrid grid, common::Pipeline<comm::Communicator>& full_task_chain,
-
719  const SizeType i_begin, const SizeType i_split, const SizeType i_end,
-
720  RhoSender&& rho, Matrix<const T, D>& evecs, Matrix<T, D>& z) {
-
721  namespace ex = pika::execution::experimental;
-
722 
-
723  const matrix::Distribution& dist = evecs.distribution();
-
724  comm::Index2D this_rank = dist.rankIndex();
+
695  // - reorder `d0 -> d1`, `z0 -> z1`, using `i3` such that deflated entries are at the bottom.
+
696  // - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`.
+
697  // - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented)
+
698  //
+
699  auto k = stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_hm.i2, ws_h.i3) | ex::split();
+
700 
+
701  applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1);
+
702  applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1);
+
703  copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.d1, ws_h.d0);
+
704 
+
705  //
+
706  // i3 (in) : initial <--- deflated
+
707  // i2 (out) : initial ---> deflated
+
708  //
+
709  invertIndex(i_begin, i_end, ws_h.i3, ws_hm.i2);
+
710 
+
711  // Note:
+
712  // This is neeeded to set to zero elements of e2 outside of the k by k top-left part.
+
713  // The input is not required to be zero for solveRank1Problem.
+
714  matrix::util::set0<Backend::MC>(pika::execution::thread_priority::normal, idx_loc_begin, sz_loc_tiles,
+
715  ws_hm.e2);
+
716  solveRank1Problem(i_begin, i_end, k, scaled_rho, ws_hm.d1, ws_hm.z1, ws_h.d0, ws_hm.i2, ws_hm.e2);
+
717  copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2);
+
718 
+
719  // Step #3: Eigenvectors of the tridiagonal system: Q * U
+
720  //
+
721  // The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as
+
722  // prepared for the deflated system.
+
723  dlaf::multiplication::generalSubMatrix<B, D, T>(i_begin, i_end, blas::Op::NoTrans, blas::Op::NoTrans,
+
724  T(1), ws.e1, ws.e2, T(0), ws.e0);
725 
-
726  // Iterate over tiles of Q1 and Q2 around the split row `i_split`.
-
727  for (SizeType i = i_begin; i < i_end; ++i) {
-
728  // True if tile is in Q1
-
729  bool top_tile = i < i_split;
-
730  // Move to the row below `i_split` for `Q2`
-
731  const SizeType evecs_row = i_split - ((top_tile) ? 1 : 0);
-
732  const GlobalTileIndex idx_evecs(evecs_row, i);
-
733  const GlobalTileIndex z_idx(i, 0);
-
734 
-
735  // Copy the last row of a `Q1` tile or the first row of a `Q2` tile into a column vector `z` tile
-
736  comm::Index2D evecs_tile_rank = dist.rankGlobalTile(idx_evecs);
-
737  if (evecs_tile_rank == this_rank) {
-
738  // Copy the row into the column vector `z`
-
739  assembleRank1UpdateVectorTileAsync<T, D>(top_tile, rho, evecs.read(idx_evecs), z.readwrite(z_idx));
-
740  ex::start_detached(comm::scheduleSendBcast(full_task_chain(), z.read(z_idx)));
-
741  }
-
742  else {
-
743  const comm::IndexT_MPI root_rank = grid.rankFullCommunicator(evecs_tile_rank);
-
744  ex::start_detached(comm::scheduleRecvBcast(full_task_chain(), root_rank, z.readwrite(z_idx)));
-
745  }
-
746  }
-
747 }
+
726  // Step #4: Final permutation to sort eigenvalues and eigenvectors
+
727  //
+
728  // i1 (in) : deflated <--- deflated (identity map)
+
729  // i2 (out) : deflated <--- post_sorted
+
730  //
+
731  initIndex(i_begin, i_end, ws_h.i1);
+
732  sortIndex(i_begin, i_end, std::move(k), ws_h.d0, ws_h.i1, ws_hm.i2);
+
733  copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i2, ws_h.i1);
+
734 }
+
735 
+
736 // The bottom row of Q1 and the top row of Q2. The bottom row of Q1 is negated if `rho < 0`.
+
737 //
+
738 // Note that the norm of `z` is sqrt(2) because it is a concatination of two normalized vectors. Hence
+
739 // to normalize `z` we have to divide by sqrt(2).
+
740 template <class T, Device D, class RhoSender>
+
741 void assembleDistZVec(comm::CommunicatorGrid grid, common::Pipeline<comm::Communicator>& full_task_chain,
+
742  const SizeType i_begin, const SizeType i_split, const SizeType i_end,
+
743  RhoSender&& rho, Matrix<const T, D>& evecs, Matrix<T, D>& z) {
+
744  namespace ex = pika::execution::experimental;
+
745 
+
746  const matrix::Distribution& dist = evecs.distribution();
+
747  comm::Index2D this_rank = dist.rankIndex();
748 
-
749 template <class T, class CommSender, class KSender, class RhoSender>
-
750 void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const SizeType i_begin,
-
751  const SizeType i_end, const LocalTileIndex ij_begin_lc,
-
752  const LocalTileSize sz_loc_tiles, KSender&& k, RhoSender&& rho,
-
753  Matrix<const T, Device::CPU>& d, Matrix<T, Device::CPU>& z,
-
754  Matrix<T, Device::CPU>& evals, Matrix<const SizeType, Device::CPU>& i2,
-
755  Matrix<T, Device::CPU>& evecs) {
-
756  namespace ex = pika::execution::experimental;
-
757  namespace di = dlaf::internal;
-
758  namespace tt = pika::this_thread::experimental;
-
759 
-
760  const matrix::Distribution& dist = evecs.distribution();
-
761 
-
762  TileCollector tc{i_begin, i_end};
-
763 
-
764  const SizeType n = problemSize(i_begin, i_end, dist);
-
765 
-
766  const SizeType m_subm_el_lc = [=]() {
-
767  const auto i_loc_begin = ij_begin_lc.row();
-
768  const auto i_loc_end = ij_begin_lc.row() + sz_loc_tiles.rows();
-
769  return dist.localElementDistanceFromLocalTile<Coord::Row>(i_loc_begin, i_loc_end);
-
770  }();
+
749  // Iterate over tiles of Q1 and Q2 around the split row `i_split`.
+
750  for (SizeType i = i_begin; i < i_end; ++i) {
+
751  // True if tile is in Q1
+
752  bool top_tile = i < i_split;
+
753  // Move to the row below `i_split` for `Q2`
+
754  const SizeType evecs_row = i_split - ((top_tile) ? 1 : 0);
+
755  const GlobalTileIndex idx_evecs(evecs_row, i);
+
756  const GlobalTileIndex z_idx(i, 0);
+
757 
+
758  // Copy the last row of a `Q1` tile or the first row of a `Q2` tile into a column vector `z` tile
+
759  comm::Index2D evecs_tile_rank = dist.rankGlobalTile(idx_evecs);
+
760  if (evecs_tile_rank == this_rank) {
+
761  // Copy the row into the column vector `z`
+
762  assembleRank1UpdateVectorTileAsync<T, D>(top_tile, rho, evecs.read(idx_evecs), z.readwrite(z_idx));
+
763  ex::start_detached(comm::scheduleSendBcast(full_task_chain(), z.read(z_idx)));
+
764  }
+
765  else {
+
766  const comm::IndexT_MPI root_rank = grid.rankFullCommunicator(evecs_tile_rank);
+
767  ex::start_detached(comm::scheduleRecvBcast(full_task_chain(), root_rank, z.readwrite(z_idx)));
+
768  }
+
769  }
+
770 }
771 
-
772  const SizeType n_subm_el_lc = [=]() {
-
773  const auto i_loc_begin = ij_begin_lc.col();
-
774  const auto i_loc_end = ij_begin_lc.col() + sz_loc_tiles.cols();
-
775  return dist.localElementDistanceFromLocalTile<Coord::Col>(i_loc_begin, i_loc_end);
-
776  }();
-
777 
-
778  auto bcast_evals = [i_begin, i_end,
-
779  dist](common::Pipeline<comm::Communicator>& row_comm_chain,
-
780  const std::vector<matrix::Tile<T, Device::CPU>>& eval_tiles) {
-
781  using dlaf::comm::internal::sendBcast_o;
-
782  using dlaf::comm::internal::recvBcast_o;
-
783 
-
784  const comm::Index2D this_rank = dist.rankIndex();
-
785 
-
786  std::vector<ex::unique_any_sender<>> comms;
-
787  comms.reserve(to_sizet(i_end - i_begin));
+
772 template <class T, class CommSender, class KSender, class RhoSender>
+
773 void solveRank1ProblemDist(CommSender&& row_comm, CommSender&& col_comm, const SizeType i_begin,
+
774  const SizeType i_end, const LocalTileIndex ij_begin_lc,
+
775  const LocalTileSize sz_loc_tiles, KSender&& k, RhoSender&& rho,
+
776  Matrix<const T, Device::CPU>& d, Matrix<T, Device::CPU>& z,
+
777  Matrix<T, Device::CPU>& evals, Matrix<const SizeType, Device::CPU>& i2,
+
778  Matrix<T, Device::CPU>& evecs) {
+
779  namespace ex = pika::execution::experimental;
+
780  namespace di = dlaf::internal;
+
781  namespace tt = pika::this_thread::experimental;
+
782 
+
783  const matrix::Distribution& dist = evecs.distribution();
+
784 
+
785  TileCollector tc{i_begin, i_end};
+
786 
+
787  const SizeType n = problemSize(i_begin, i_end, dist);
788 
-
789  for (SizeType i = i_begin; i < i_end; ++i) {
-
790  const comm::IndexT_MPI evecs_tile_rank = dist.rankGlobalTile<Coord::Col>(i);
-
791  auto& tile = eval_tiles[to_sizet(i - i_begin)];
-
792 
-
793  if (evecs_tile_rank == this_rank.col())
-
794  comms.emplace_back(ex::when_all(row_comm_chain(), ex::just(std::cref(tile))) |
-
795  transformMPI(sendBcast_o));
-
796  else
-
797  comms.emplace_back(ex::when_all(row_comm_chain(), ex::just(evecs_tile_rank, std::cref(tile))) |
-
798  transformMPI(recvBcast_o));
-
799  }
+
789  const SizeType m_subm_el_lc = [=]() {
+
790  const auto i_loc_begin = ij_begin_lc.row();
+
791  const auto i_loc_end = ij_begin_lc.row() + sz_loc_tiles.rows();
+
792  return dist.localElementDistanceFromLocalTile<Coord::Row>(i_loc_begin, i_loc_end);
+
793  }();
+
794 
+
795  const SizeType n_subm_el_lc = [=]() {
+
796  const auto i_loc_begin = ij_begin_lc.col();
+
797  const auto i_loc_end = ij_begin_lc.col() + sz_loc_tiles.cols();
+
798  return dist.localElementDistanceFromLocalTile<Coord::Col>(i_loc_begin, i_loc_end);
+
799  }();
800 
-
801  return ex::ensure_started(ex::when_all_vector(std::move(comms)));
-
802  };
-
803 
-
804  auto all_reduce_in_place = [](const dlaf::comm::Communicator& comm, MPI_Op reduce_op, const auto& data,
-
805  MPI_Request* req) {
-
806  auto msg = comm::make_message(data);
-
807  DLAF_MPI_CHECK_ERROR(MPI_Iallreduce(MPI_IN_PLACE, msg.data(), msg.count(), msg.mpi_type(), reduce_op,
-
808  comm, req));
-
809  };
-
810 
-
811  // Note: at least two column of tiles per-worker, in the range [1, getTridiagRank1NWorkers()]
-
812  const std::size_t nthreads = [nrtiles = sz_loc_tiles.cols()]() {
-
813  const std::size_t min_workers = 1;
-
814  const std::size_t available_workers = getTridiagRank1NWorkers();
-
815  const std::size_t ideal_workers = util::ceilDiv(to_sizet(nrtiles), to_sizet(2));
-
816  return std::clamp(ideal_workers, min_workers, available_workers);
-
817  }();
-
818 
-
819  ex::start_detached(
-
820  ex::when_all(ex::just(std::make_unique<pika::barrier<>>(nthreads)),
-
821  std::forward<CommSender>(row_comm), std::forward<CommSender>(col_comm),
-
822  std::forward<KSender>(k), std::forward<RhoSender>(rho),
-
823  ex::when_all_vector(tc.read(d)), ex::when_all_vector(tc.readwrite(z)),
-
824  ex::when_all_vector(tc.readwrite(evals)), ex::when_all_vector(tc.read(i2)),
-
825  ex::when_all_vector(tc.readwrite(evecs)),
-
826  // additional workspaces
-
827  ex::just(std::vector<memory::MemoryView<T, Device::CPU>>()),
-
828  ex::just(memory::MemoryView<T, Device::CPU>())) |
-
829  ex::transfer(di::getBackendScheduler<Backend::MC>(pika::execution::thread_priority::high)) |
-
830  ex::bulk(nthreads, [nthreads, n, n_subm_el_lc, m_subm_el_lc, i_begin, ij_begin_lc, sz_loc_tiles,
-
831  dist, bcast_evals, all_reduce_in_place](
-
832  const std::size_t thread_idx, auto& barrier_ptr, auto& row_comm_wrapper,
-
833  auto& col_comm_wrapper, const auto& k, const auto& rho,
-
834  const auto& d_tiles_futs, auto& z_tiles, const auto& eval_tiles,
-
835  const auto& i2_tile_arr, const auto& evec_tiles, auto& ws_cols,
-
836  auto& ws_row) {
-
837  using dlaf::comm::internal::transformMPI;
-
838 
-
839  common::Pipeline<comm::Communicator> row_comm_chain(row_comm_wrapper.get());
-
840  const dlaf::comm::Communicator& col_comm = col_comm_wrapper.get();
+
801  auto bcast_evals = [i_begin, i_end,
+
802  dist](common::Pipeline<comm::Communicator>& row_comm_chain,
+
803  const std::vector<matrix::Tile<T, Device::CPU>>& eval_tiles) {
+
804  using dlaf::comm::internal::sendBcast_o;
+
805  using dlaf::comm::internal::recvBcast_o;
+
806 
+
807  const comm::Index2D this_rank = dist.rankIndex();
+
808 
+
809  std::vector<ex::unique_any_sender<>> comms;
+
810  comms.reserve(to_sizet(i_end - i_begin));
+
811 
+
812  for (SizeType i = i_begin; i < i_end; ++i) {
+
813  const comm::IndexT_MPI evecs_tile_rank = dist.rankGlobalTile<Coord::Col>(i);
+
814  auto& tile = eval_tiles[to_sizet(i - i_begin)];
+
815 
+
816  if (evecs_tile_rank == this_rank.col())
+
817  comms.emplace_back(ex::when_all(row_comm_chain(), ex::just(std::cref(tile))) |
+
818  transformMPI(sendBcast_o));
+
819  else
+
820  comms.emplace_back(ex::when_all(row_comm_chain(), ex::just(evecs_tile_rank, std::cref(tile))) |
+
821  transformMPI(recvBcast_o));
+
822  }
+
823 
+
824  return ex::ensure_started(ex::when_all_vector(std::move(comms)));
+
825  };
+
826 
+
827  auto all_reduce_in_place = [](const dlaf::comm::Communicator& comm, MPI_Op reduce_op, const auto& data,
+
828  MPI_Request* req) {
+
829  auto msg = comm::make_message(data);
+
830  DLAF_MPI_CHECK_ERROR(MPI_Iallreduce(MPI_IN_PLACE, msg.data(), msg.count(), msg.mpi_type(), reduce_op,
+
831  comm, req));
+
832  };
+
833 
+
834  // Note: at least two column of tiles per-worker, in the range [1, getTridiagRank1NWorkers()]
+
835  const std::size_t nthreads = [nrtiles = sz_loc_tiles.cols()]() {
+
836  const std::size_t min_workers = 1;
+
837  const std::size_t available_workers = getTridiagRank1NWorkers();
+
838  const std::size_t ideal_workers = util::ceilDiv(to_sizet(nrtiles), to_sizet(2));
+
839  return std::clamp(ideal_workers, min_workers, available_workers);
+
840  }();
841 
-
842  const auto barrier_busy_wait = getTridiagRank1BarrierBusyWait();
-
843  const std::size_t batch_size =
-
844  std::max<std::size_t>(2, util::ceilDiv(to_sizet(sz_loc_tiles.cols()), nthreads));
-
845  const SizeType begin = to_SizeType(thread_idx * batch_size);
-
846  const SizeType end = std::min(to_SizeType((thread_idx + 1) * batch_size), sz_loc_tiles.cols());
-
847 
-
848  // STEP 0a: Fill ones for deflated Eigenvectors. (single-thread)
-
849  // Note: this step is completely independent from the rest, but it is small and it is going
-
850  // to be dropped soon.
-
851  // Note: use last threads that in principle should have less work to do
-
852  if (thread_idx == nthreads - 1) {
-
853  // just if there are deflated eigenvectors
-
854  if (k < n) {
-
855  const GlobalElementSize origin_el(i_begin * dist.blockSize().rows(),
-
856  i_begin * dist.blockSize().cols());
-
857  const SizeType* i2_perm = i2_tile_arr[0].get().ptr();
-
858 
-
859  for (SizeType i_subm_el = 0; i_subm_el < n; ++i_subm_el) {
-
860  const SizeType j_subm_el = i2_perm[i_subm_el];
+
842  ex::start_detached(
+
843  ex::when_all(ex::just(std::make_unique<pika::barrier<>>(nthreads)),
+
844  std::forward<CommSender>(row_comm), std::forward<CommSender>(col_comm),
+
845  std::forward<KSender>(k), std::forward<RhoSender>(rho),
+
846  ex::when_all_vector(tc.read(d)), ex::when_all_vector(tc.readwrite(z)),
+
847  ex::when_all_vector(tc.readwrite(evals)), ex::when_all_vector(tc.read(i2)),
+
848  ex::when_all_vector(tc.readwrite(evecs)),
+
849  // additional workspaces
+
850  ex::just(std::vector<memory::MemoryView<T, Device::CPU>>()),
+
851  ex::just(memory::MemoryView<T, Device::CPU>())) |
+
852  ex::transfer(di::getBackendScheduler<Backend::MC>(pika::execution::thread_priority::high)) |
+
853  ex::bulk(nthreads, [nthreads, n, n_subm_el_lc, m_subm_el_lc, i_begin, ij_begin_lc, sz_loc_tiles,
+
854  dist, bcast_evals, all_reduce_in_place](
+
855  const std::size_t thread_idx, auto& barrier_ptr, auto& row_comm_wrapper,
+
856  auto& col_comm_wrapper, const auto& k, const auto& rho,
+
857  const auto& d_tiles_futs, auto& z_tiles, const auto& eval_tiles,
+
858  const auto& i2_tile_arr, const auto& evec_tiles, auto& ws_cols,
+
859  auto& ws_row) {
+
860  using dlaf::comm::internal::transformMPI;
861 
-
862  // if it is a deflated vector
-
863  if (j_subm_el >= k) {
-
864  const GlobalElementIndex ij_el(origin_el.rows() + i_subm_el,
-
865  origin_el.cols() + j_subm_el);
-
866  const GlobalTileIndex ij = dist.globalTileIndex(ij_el);
-
867 
-
868  if (dist.rankIndex() == dist.rankGlobalTile(ij)) {
-
869  const LocalTileIndex ij_lc = dist.localTileIndex(ij);
-
870  const SizeType linear_subm_lc =
-
871  (ij_lc.row() - ij_begin_lc.row()) +
-
872  (ij_lc.col() - ij_begin_lc.col()) * sz_loc_tiles.rows();
-
873  const TileElementIndex ij_el_tl = dist.tileElementIndex(ij_el);
-
874  evec_tiles[to_sizet(linear_subm_lc)](ij_el_tl) = T{1};
-
875  }
-
876  }
-
877  }
-
878  }
-
879  }
-
880 
-
881  // STEP 0b: Initialize workspaces (single-thread)
-
882  if (thread_idx == 0) {
-
883  // Note:
-
884  // - nthreads are used for both LAED4 and weight calculation (one per worker thread)
-
885  // - last one is used for reducing weights from all workers
-
886  ws_cols.reserve(nthreads + 1);
-
887 
-
888  // Note:
-
889  // Considering that
-
890  // - LAED4 requires working on k elements
-
891  // - Weight computaiton requires working on m_subm_el_lc
-
892  //
-
893  // and they are needed at two steps that cannot happen in parallel, we opted for allocating the
-
894  // workspace with the highest requirement of memory, and reuse them for both steps.
-
895  const SizeType max_size = std::max(k, m_subm_el_lc);
-
896  for (std::size_t i = 0; i < nthreads; ++i)
-
897  ws_cols.emplace_back(max_size);
-
898  ws_cols.emplace_back(m_subm_el_lc);
-
899 
-
900  ws_row = memory::MemoryView<T, Device::CPU>(n_subm_el_lc);
-
901  std::fill_n(ws_row(), n_subm_el_lc, 0);
+
862  common::Pipeline<comm::Communicator> row_comm_chain(row_comm_wrapper.get());
+
863  const dlaf::comm::Communicator& col_comm = col_comm_wrapper.get();
+
864 
+
865  const auto barrier_busy_wait = getTridiagRank1BarrierBusyWait();
+
866  const std::size_t batch_size =
+
867  std::max<std::size_t>(2, util::ceilDiv(to_sizet(sz_loc_tiles.cols()), nthreads));
+
868  const SizeType begin = to_SizeType(thread_idx * batch_size);
+
869  const SizeType end = std::min(to_SizeType((thread_idx + 1) * batch_size), sz_loc_tiles.cols());
+
870 
+
871  // STEP 0a: Fill ones for deflated Eigenvectors. (single-thread)
+
872  // Note: this step is completely independent from the rest, but it is small and it is going
+
873  // to be dropped soon.
+
874  // Note: use last threads that in principle should have less work to do
+
875  if (thread_idx == nthreads - 1) {
+
876  // just if there are deflated eigenvectors
+
877  if (k < n) {
+
878  const GlobalElementSize origin_el(i_begin * dist.blockSize().rows(),
+
879  i_begin * dist.blockSize().cols());
+
880  const SizeType* i2_perm = i2_tile_arr[0].get().ptr();
+
881 
+
882  for (SizeType i_subm_el = 0; i_subm_el < n; ++i_subm_el) {
+
883  const SizeType j_subm_el = i2_perm[i_subm_el];
+
884 
+
885  // if it is a deflated vector
+
886  if (j_subm_el >= k) {
+
887  const GlobalElementIndex ij_el(origin_el.rows() + i_subm_el,
+
888  origin_el.cols() + j_subm_el);
+
889  const GlobalTileIndex ij = dist.globalTileIndex(ij_el);
+
890 
+
891  if (dist.rankIndex() == dist.rankGlobalTile(ij)) {
+
892  const LocalTileIndex ij_lc = dist.localTileIndex(ij);
+
893  const SizeType linear_subm_lc =
+
894  (ij_lc.row() - ij_begin_lc.row()) +
+
895  (ij_lc.col() - ij_begin_lc.col()) * sz_loc_tiles.rows();
+
896  const TileElementIndex ij_el_tl = dist.tileElementIndex(ij_el);
+
897  evec_tiles[to_sizet(linear_subm_lc)](ij_el_tl) = T{1};
+
898  }
+
899  }
+
900  }
+
901  }
902  }
903 
-
904  // Note: we have to wait that LAED4 workspaces are ready to be used
-
905  barrier_ptr->arrive_and_wait(barrier_busy_wait);
-
906 
-
907  const T* d_ptr = d_tiles_futs[0].get().ptr();
-
908  const T* z_ptr = z_tiles[0].ptr();
-
909 
-
910  // STEP 1: LAED4 (multi-thread)
-
911  {
-
912  common::internal::SingleThreadedBlasScope single;
-
913 
-
914  T* eval_ptr = eval_tiles[0].ptr();
-
915  T* delta_ptr = ws_cols[thread_idx]();
-
916 
-
917  for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) {
-
918  const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc);
-
919  const SizeType j = dist.globalTileFromLocalTile<Coord::Col>(j_lc);
-
920  const SizeType n_subm_el = dist.globalTileElementDistance<Coord::Col>(i_begin, j);
-
921 
-
922  // Skip columns that are in the deflation zone
-
923  if (n_subm_el >= k)
-
924  break;
-
925 
-
926  const SizeType n_el_tl = std::min(dist.tileSize<Coord::Col>(j), k - n_subm_el);
-
927  for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) {
-
928  const SizeType j_el = n_subm_el + j_el_tl;
+
904  // STEP 0b: Initialize workspaces (single-thread)
+
905  if (thread_idx == 0) {
+
906  // Note:
+
907  // - nthreads are used for both LAED4 and weight calculation (one per worker thread)
+
908  // - last one is used for reducing weights from all workers
+
909  ws_cols.reserve(nthreads + 1);
+
910 
+
911  // Note:
+
912  // Considering that
+
913  // - LAED4 requires working on k elements
+
914  // - Weight computation requires working on m_subm_el_lc
+
915  //
+
916  // and they are needed at two steps that cannot happen in parallel, we opted for allocating the
+
917  // workspace with the highest requirement of memory, and reuse them for both steps.
+
918  const SizeType max_size = std::max(k, m_subm_el_lc);
+
919  for (std::size_t i = 0; i < nthreads; ++i)
+
920  ws_cols.emplace_back(max_size);
+
921  ws_cols.emplace_back(m_subm_el_lc);
+
922 
+
923  ws_row = memory::MemoryView<T, Device::CPU>(n_subm_el_lc);
+
924  std::fill_n(ws_row(), n_subm_el_lc, 0);
+
925  }
+
926 
+
927  // Note: we have to wait that LAED4 workspaces are ready to be used
+
928  barrier_ptr->arrive_and_wait(barrier_busy_wait);
929 
-
930  // Solve the deflated rank-1 problem
-
931  T& eigenval = eval_ptr[to_sizet(j_el)];
-
932  lapack::laed4(to_int(k), to_int(j_el), d_ptr, z_ptr, delta_ptr, rho, &eigenval);
-
933 
-
934  // copy the parts from delta stored on this rank
-
935  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
-
936  const SizeType linear_subm_lc = i_subm_lc + to_SizeType(j_subm_lc) * sz_loc_tiles.rows();
-
937  auto& evec_tile = evec_tiles[to_sizet(linear_subm_lc)];
-
938 
-
939  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
-
940  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
-
941  const SizeType m_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
-
942 
-
943  const SizeType i_subm = i - i_begin;
-
944  const auto& i2_perm = i2_tile_arr[to_sizet(i_subm)].get();
-
945 
-
946  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - m_subm_el);
-
947  for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) {
-
948  const SizeType jj_subm_el = i2_perm({i_el_tl, 0});
-
949  if (jj_subm_el < k)
-
950  evec_tile({i_el_tl, j_el_tl}) = delta_ptr[jj_subm_el];
-
951  }
-
952  }
-
953  }
-
954  }
-
955  }
+
930  const T* d_ptr = d_tiles_futs[0].get().ptr();
+
931  const T* z_ptr = z_tiles[0].ptr();
+
932 
+
933  // STEP 1: LAED4 (multi-thread)
+
934  {
+
935  common::internal::SingleThreadedBlasScope single;
+
936 
+
937  T* eval_ptr = eval_tiles[0].ptr();
+
938  T* delta_ptr = ws_cols[thread_idx]();
+
939 
+
940  for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) {
+
941  const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc);
+
942  const SizeType j = dist.globalTileFromLocalTile<Coord::Col>(j_lc);
+
943  const SizeType n_subm_el = dist.globalTileElementDistance<Coord::Col>(i_begin, j);
+
944 
+
945  // Skip columns that are in the deflation zone
+
946  if (n_subm_el >= k)
+
947  break;
+
948 
+
949  const SizeType n_el_tl = std::min(dist.tileSize<Coord::Col>(j), k - n_subm_el);
+
950  for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) {
+
951  const SizeType j_el = n_subm_el + j_el_tl;
+
952 
+
953  // Solve the deflated rank-1 problem
+
954  T& eigenval = eval_ptr[to_sizet(j_el)];
+
955  lapack::laed4(to_int(k), to_int(j_el), d_ptr, z_ptr, delta_ptr, rho, &eigenval);
956 
-
957  // Note: This barrier ensures that LAED4 finished, so from now on values are available
-
958  barrier_ptr->arrive_and_wait(barrier_busy_wait);
-
959 
-
960  // STEP 2: Broadcast evals
+
957  // copy the parts from delta stored on this rank
+
958  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
+
959  const SizeType linear_subm_lc = i_subm_lc + to_SizeType(j_subm_lc) * sz_loc_tiles.rows();
+
960  auto& evec_tile = evec_tiles[to_sizet(linear_subm_lc)];
961 
-
962  // Note: this ensures that evals broadcasting finishes before bulk releases resources
-
963  struct sync_wait_on_exit_t {
-
964  ex::unique_any_sender<> sender_;
+
962  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
+
963  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
+
964  const SizeType m_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
965 
-
966  ~sync_wait_on_exit_t() {
-
967  if (sender_)
-
968  tt::sync_wait(std::move(sender_));
-
969  }
-
970  } bcast_barrier;
-
971 
-
972  if (thread_idx == 0)
-
973  bcast_barrier.sender_ = bcast_evals(row_comm_chain, eval_tiles);
-
974 
-
975  // Note: laed4 handles k <= 2 cases differently
-
976  if (k <= 2)
-
977  return;
-
978 
-
979  // STEP 2 Compute weights (multi-thread)
-
980  auto& q = evec_tiles;
-
981  T* w = ws_cols[thread_idx]();
+
966  const SizeType i_subm = i - i_begin;
+
967  const auto& i2_perm = i2_tile_arr[to_sizet(i_subm)].get();
+
968 
+
969  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - m_subm_el);
+
970  for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) {
+
971  const SizeType jj_subm_el = i2_perm({i_el_tl, 0});
+
972  if (jj_subm_el < k)
+
973  evec_tile({i_el_tl, j_el_tl}) = delta_ptr[jj_subm_el];
+
974  }
+
975  }
+
976  }
+
977  }
+
978  }
+
979 
+
980  // Note: This barrier ensures that LAED4 finished, so from now on values are available
+
981  barrier_ptr->arrive_and_wait(barrier_busy_wait);
982 
-
983  // STEP 2a: copy diagonal from q -> w (or just initialize with 1)
-
984  if (thread_idx == 0) {
-
985  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
-
986  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
-
987  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
-
988  const SizeType i_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
-
989  const SizeType m_subm_el_lc =
-
990  dist.localElementDistanceFromLocalTile<Coord::Row>(ij_begin_lc.row(), i_lc);
-
991  const auto& i2 = i2_tile_arr[to_sizet(i - i_begin)].get();
-
992 
-
993  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - i_subm_el);
-
994  for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) {
-
995  const SizeType i_subm_el_lc = m_subm_el_lc + i_el_tl;
-
996 
-
997  const SizeType jj_subm_el = i2({i_el_tl, 0});
-
998  const SizeType n_el = dist.globalTileElementDistance<Coord::Col>(0, i_begin);
-
999  const SizeType jj_el = n_el + jj_subm_el;
-
1000  const SizeType jj = dist.globalTileFromGlobalElement<Coord::Col>(jj_el);
+
983  // STEP 2: Broadcast evals
+
984 
+
985  // Note: this ensures that evals broadcasting finishes before bulk releases resources
+
986  struct sync_wait_on_exit_t {
+
987  ex::unique_any_sender<> sender_;
+
988 
+
989  ~sync_wait_on_exit_t() {
+
990  if (sender_)
+
991  tt::sync_wait(std::move(sender_));
+
992  }
+
993  } bcast_barrier;
+
994 
+
995  if (thread_idx == 0)
+
996  bcast_barrier.sender_ = bcast_evals(row_comm_chain, eval_tiles);
+
997 
+
998  // Note: laed4 handles k <= 2 cases differently
+
999  if (k <= 2)
+
1000  return;
1001 
-
1002  if (dist.rankGlobalTile<Coord::Col>(jj) == dist.rankIndex().col()) {
-
1003  const SizeType jj_lc = dist.localTileFromGlobalTile<Coord::Col>(jj);
-
1004  const SizeType jj_subm_lc = jj_lc - ij_begin_lc.col();
-
1005  const SizeType jj_el_tl = dist.tileElementFromGlobalElement<Coord::Col>(jj_el);
-
1006 
-
1007  const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * jj_subm_lc;
-
1008 
-
1009  w[i_subm_el_lc] = q[to_sizet(linear_subm_lc)]({i_el_tl, jj_el_tl});
-
1010  }
-
1011  else {
-
1012  w[i_subm_el_lc] = T(1);
-
1013  }
-
1014  }
-
1015  }
-
1016  }
-
1017  else { // other workers
-
1018  std::fill_n(w, m_subm_el_lc, T(1));
-
1019  }
-
1020 
-
1021  barrier_ptr->arrive_and_wait(barrier_busy_wait);
-
1022 
-
1023  // STEP 2b: compute weights
-
1024  for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) {
-
1025  const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc);
-
1026  const SizeType j = dist.globalTileFromLocalTile<Coord::Col>(j_lc);
-
1027  const SizeType n_subm_el = dist.globalTileElementDistance<Coord::Col>(i_begin, j);
-
1028 
-
1029  // Skip columns that are in the deflation zone
-
1030  if (n_subm_el >= k)
-
1031  break;
-
1032 
-
1033  const SizeType n_el_tl = std::min(dist.tileSize<Coord::Col>(j), k - n_subm_el);
-
1034  for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) {
-
1035  const SizeType j_subm_el = n_subm_el + j_el_tl;
-
1036  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
-
1037  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
-
1038  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
-
1039  const SizeType m_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
-
1040 
-
1041  auto& i2_perm = i2_tile_arr[to_sizet(i - i_begin)].get();
-
1042 
-
1043  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - m_subm_el);
-
1044  for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) {
-
1045  const SizeType ii_subm_el = i2_perm({i_el_tl, 0});
-
1046 
-
1047  // deflated zone
-
1048  if (ii_subm_el >= k)
-
1049  continue;
-
1050 
-
1051  // diagonal
-
1052  if (ii_subm_el == j_subm_el)
-
1053  continue;
-
1054 
-
1055  const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc;
-
1056  const SizeType i_subm_el_lc = i_subm_lc * dist.blockSize().rows() + i_el_tl;
-
1057 
-
1058  w[i_subm_el_lc] *= q[to_sizet(linear_subm_lc)]({i_el_tl, j_el_tl}) /
-
1059  (d_ptr[to_sizet(ii_subm_el)] - d_ptr[to_sizet(j_subm_el)]);
-
1060  }
-
1061  }
-
1062  }
-
1063  }
-
1064 
-
1065  barrier_ptr->arrive_and_wait(barrier_busy_wait);
-
1066 
-
1067  // STEP 2c: reduce, then finalize computation with sign and square root (single-thread)
-
1068  if (thread_idx == 0) {
-
1069  // local reduction from all bulk workers
-
1070  for (int i = 0; i < m_subm_el_lc; ++i) {
-
1071  for (std::size_t tidx = 1; tidx < nthreads; ++tidx) {
-
1072  const T* w_partial = ws_cols[tidx]();
-
1073  w[i] *= w_partial[i];
-
1074  }
-
1075  }
-
1076 
-
1077  tt::sync_wait(ex::when_all(row_comm_chain(),
-
1078  ex::just(MPI_PROD, common::make_data(w, m_subm_el_lc))) |
-
1079  transformMPI(all_reduce_in_place));
+
1002  // STEP 2 Compute weights (multi-thread)
+
1003  auto& q = evec_tiles;
+
1004  T* w = ws_cols[thread_idx]();
+
1005 
+
1006  // STEP 2a: copy diagonal from q -> w (or just initialize with 1)
+
1007  if (thread_idx == 0) {
+
1008  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
+
1009  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
+
1010  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
+
1011  const SizeType i_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
+
1012  const SizeType m_subm_el_lc =
+
1013  dist.localElementDistanceFromLocalTile<Coord::Row>(ij_begin_lc.row(), i_lc);
+
1014  const auto& i2 = i2_tile_arr[to_sizet(i - i_begin)].get();
+
1015 
+
1016  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - i_subm_el);
+
1017  for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) {
+
1018  const SizeType i_subm_el_lc = m_subm_el_lc + i_el_tl;
+
1019 
+
1020  const SizeType jj_subm_el = i2({i_el_tl, 0});
+
1021  const SizeType n_el = dist.globalTileElementDistance<Coord::Col>(0, i_begin);
+
1022  const SizeType jj_el = n_el + jj_subm_el;
+
1023  const SizeType jj = dist.globalTileFromGlobalElement<Coord::Col>(jj_el);
+
1024 
+
1025  if (dist.rankGlobalTile<Coord::Col>(jj) == dist.rankIndex().col()) {
+
1026  const SizeType jj_lc = dist.localTileFromGlobalTile<Coord::Col>(jj);
+
1027  const SizeType jj_subm_lc = jj_lc - ij_begin_lc.col();
+
1028  const SizeType jj_el_tl = dist.tileElementFromGlobalElement<Coord::Col>(jj_el);
+
1029 
+
1030  const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * jj_subm_lc;
+
1031 
+
1032  w[i_subm_el_lc] = q[to_sizet(linear_subm_lc)]({i_el_tl, jj_el_tl});
+
1033  }
+
1034  else {
+
1035  w[i_subm_el_lc] = T(1);
+
1036  }
+
1037  }
+
1038  }
+
1039  }
+
1040  else { // other workers
+
1041  std::fill_n(w, m_subm_el_lc, T(1));
+
1042  }
+
1043 
+
1044  barrier_ptr->arrive_and_wait(barrier_busy_wait);
+
1045 
+
1046  // STEP 2b: compute weights
+
1047  for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) {
+
1048  const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc);
+
1049  const SizeType j = dist.globalTileFromLocalTile<Coord::Col>(j_lc);
+
1050  const SizeType n_subm_el = dist.globalTileElementDistance<Coord::Col>(i_begin, j);
+
1051 
+
1052  // Skip columns that are in the deflation zone
+
1053  if (n_subm_el >= k)
+
1054  break;
+
1055 
+
1056  const SizeType n_el_tl = std::min(dist.tileSize<Coord::Col>(j), k - n_subm_el);
+
1057  for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) {
+
1058  const SizeType j_subm_el = n_subm_el + j_el_tl;
+
1059  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
+
1060  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
+
1061  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
+
1062  const SizeType m_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
+
1063 
+
1064  auto& i2_perm = i2_tile_arr[to_sizet(i - i_begin)].get();
+
1065 
+
1066  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - m_subm_el);
+
1067  for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) {
+
1068  const SizeType ii_subm_el = i2_perm({i_el_tl, 0});
+
1069 
+
1070  // deflated zone
+
1071  if (ii_subm_el >= k)
+
1072  continue;
+
1073 
+
1074  // diagonal
+
1075  if (ii_subm_el == j_subm_el)
+
1076  continue;
+
1077 
+
1078  const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc;
+
1079  const SizeType i_subm_el_lc = i_subm_lc * dist.blockSize().rows() + i_el_tl;
1080 
-
1081  T* weights = ws_cols[nthreads]();
-
1082  for (int i_subm_el_lc = 0; i_subm_el_lc < m_subm_el_lc; ++i_subm_el_lc) {
-
1083  const SizeType i_subm_lc = i_subm_el_lc / dist.blockSize().rows();
-
1084  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
-
1085  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
-
1086  const SizeType i_subm = i - i_begin;
-
1087  const SizeType i_subm_el =
-
1088  i_subm * dist.blockSize().rows() + i_subm_el_lc % dist.blockSize().rows();
+
1081  w[i_subm_el_lc] *= q[to_sizet(linear_subm_lc)]({i_el_tl, j_el_tl}) /
+
1082  (d_ptr[to_sizet(ii_subm_el)] - d_ptr[to_sizet(j_subm_el)]);
+
1083  }
+
1084  }
+
1085  }
+
1086  }
+
1087 
+
1088  barrier_ptr->arrive_and_wait(barrier_busy_wait);
1089 
-
1090  const auto* i2_perm = i2_tile_arr[0].get().ptr();
-
1091  const SizeType ii_subm_el = i2_perm[i_subm_el];
-
1092  weights[to_sizet(i_subm_el_lc)] =
-
1093  std::copysign(std::sqrt(-w[i_subm_el_lc]), z_ptr[to_sizet(ii_subm_el)]);
-
1094  }
-
1095  }
-
1096 
-
1097  barrier_ptr->arrive_and_wait(barrier_busy_wait);
-
1098 
-
1099  // STEP 3: Compute eigenvectors of the modified rank-1 modification (normalize) (multi-thread)
-
1100 
-
1101  // STEP 3a: Form evecs using weights vector and compute (local) sum of squares
-
1102  {
-
1103  common::internal::SingleThreadedBlasScope single;
-
1104 
-
1105  const T* w = ws_cols[nthreads]();
-
1106  T* sum_squares = ws_row();
-
1107 
-
1108  for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) {
-
1109  const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc);
-
1110  const SizeType j = dist.globalTileFromLocalTile<Coord::Col>(j_lc);
-
1111  const SizeType n_subm_el = dist.globalTileElementDistance<Coord::Col>(i_begin, j);
+
1090  // STEP 2c: reduce, then finalize computation with sign and square root (single-thread)
+
1091  if (thread_idx == 0) {
+
1092  // local reduction from all bulk workers
+
1093  for (SizeType i = 0; i < m_subm_el_lc; ++i) {
+
1094  for (std::size_t tidx = 1; tidx < nthreads; ++tidx) {
+
1095  const T* w_partial = ws_cols[tidx]();
+
1096  w[i] *= w_partial[i];
+
1097  }
+
1098  }
+
1099 
+
1100  tt::sync_wait(ex::when_all(row_comm_chain(),
+
1101  ex::just(MPI_PROD, common::make_data(w, m_subm_el_lc))) |
+
1102  transformMPI(all_reduce_in_place));
+
1103 
+
1104  T* weights = ws_cols[nthreads]();
+
1105  for (SizeType i_subm_el_lc = 0; i_subm_el_lc < m_subm_el_lc; ++i_subm_el_lc) {
+
1106  const SizeType i_subm_lc = i_subm_el_lc / dist.blockSize().rows();
+
1107  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
+
1108  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
+
1109  const SizeType i_subm = i - i_begin;
+
1110  const SizeType i_subm_el =
+
1111  i_subm * dist.blockSize().rows() + i_subm_el_lc % dist.blockSize().rows();
1112 
-
1113  // Skip columns that are in the deflation zone
-
1114  if (n_subm_el >= k)
-
1115  break;
-
1116 
-
1117  const SizeType n_el_tl = std::min(dist.tileSize<Coord::Col>(j), k - n_subm_el);
-
1118  for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) {
-
1119  const SizeType j_subm_el_lc = j_subm_lc * dist.blockSize().cols() + j_el_tl;
-
1120  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
-
1121  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
-
1122  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
-
1123  const SizeType m_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
-
1124 
-
1125  const SizeType i_subm = i - i_begin;
-
1126  const auto& i2_perm = i2_tile_arr[to_sizet(i_subm)].get();
+
1113  const auto* i2_perm = i2_tile_arr[0].get().ptr();
+
1114  const SizeType ii_subm_el = i2_perm[i_subm_el];
+
1115  weights[to_sizet(i_subm_el_lc)] =
+
1116  std::copysign(std::sqrt(-w[i_subm_el_lc]), z_ptr[to_sizet(ii_subm_el)]);
+
1117  }
+
1118  }
+
1119 
+
1120  barrier_ptr->arrive_and_wait(barrier_busy_wait);
+
1121 
+
1122  // STEP 3: Compute eigenvectors of the modified rank-1 modification (normalize) (multi-thread)
+
1123 
+
1124  // STEP 3a: Form evecs using weights vector and compute (local) sum of squares
+
1125  {
+
1126  common::internal::SingleThreadedBlasScope single;
1127 
-
1128  const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc;
-
1129  const auto& q_tile = q[to_sizet(linear_subm_lc)];
+
1128  const T* w = ws_cols[nthreads]();
+
1129  T* sum_squares = ws_row();
1130 
-
1131  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - m_subm_el);
-
1132  for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) {
-
1133  const SizeType ii_subm_el = i2_perm({i_el_tl, 0});
-
1134 
-
1135  const SizeType i_subm_el_lc = i_subm_lc * dist.blockSize().rows() + i_el_tl;
-
1136  if (ii_subm_el >= k)
-
1137  q_tile({i_el_tl, j_el_tl}) = 0;
-
1138  else
-
1139  q_tile({i_el_tl, j_el_tl}) = w[i_subm_el_lc] / q_tile({i_el_tl, j_el_tl});
-
1140  }
-
1141 
-
1142  sum_squares[j_subm_el_lc] +=
-
1143  blas::dot(m_el_tl, q_tile.ptr({0, j_el_tl}), 1, q_tile.ptr({0, j_el_tl}), 1);
-
1144  }
-
1145  }
-
1146  }
-
1147  }
-
1148 
-
1149  barrier_ptr->arrive_and_wait(barrier_busy_wait);
+
1131  for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) {
+
1132  const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc);
+
1133  const SizeType j = dist.globalTileFromLocalTile<Coord::Col>(j_lc);
+
1134  const SizeType n_subm_el = dist.globalTileElementDistance<Coord::Col>(i_begin, j);
+
1135 
+
1136  // Skip columns that are in the deflation zone
+
1137  if (n_subm_el >= k)
+
1138  break;
+
1139 
+
1140  const SizeType n_el_tl = std::min(dist.tileSize<Coord::Col>(j), k - n_subm_el);
+
1141  for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) {
+
1142  const SizeType j_subm_el_lc = j_subm_lc * dist.blockSize().cols() + j_el_tl;
+
1143  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
+
1144  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
+
1145  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
+
1146  const SizeType m_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
+
1147 
+
1148  const SizeType i_subm = i - i_begin;
+
1149  const auto& i2_perm = i2_tile_arr[to_sizet(i_subm)].get();
1150 
-
1151  // STEP 3b: Reduce to get the sum of all squares on all ranks
-
1152  if (thread_idx == 0)
-
1153  tt::sync_wait(ex::just(std::cref(col_comm), MPI_SUM,
-
1154  common::make_data(ws_row(), n_subm_el_lc)) |
-
1155  transformMPI(all_reduce_in_place));
-
1156 
-
1157  barrier_ptr->arrive_and_wait(barrier_busy_wait);
-
1158 
-
1159  // STEP 3c: Normalize (compute norm of each column and scale column vector)
-
1160  {
-
1161  common::internal::SingleThreadedBlasScope single;
-
1162 
-
1163  const T* sum_squares = ws_row();
+
1151  const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc;
+
1152  const auto& q_tile = q[to_sizet(linear_subm_lc)];
+
1153 
+
1154  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - m_subm_el);
+
1155  for (SizeType i_el_tl = 0; i_el_tl < m_el_tl; ++i_el_tl) {
+
1156  const SizeType ii_subm_el = i2_perm({i_el_tl, 0});
+
1157 
+
1158  const SizeType i_subm_el_lc = i_subm_lc * dist.blockSize().rows() + i_el_tl;
+
1159  if (ii_subm_el >= k)
+
1160  q_tile({i_el_tl, j_el_tl}) = 0;
+
1161  else
+
1162  q_tile({i_el_tl, j_el_tl}) = w[i_subm_el_lc] / q_tile({i_el_tl, j_el_tl});
+
1163  }
1164 
-
1165  for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) {
-
1166  const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc);
-
1167  const SizeType j = dist.globalTileFromLocalTile<Coord::Col>(j_lc);
-
1168  const SizeType n_subm_el = dist.globalTileElementDistance<Coord::Col>(i_begin, j);
-
1169 
-
1170  // Skip columns that are in the deflation zone
-
1171  if (n_subm_el >= k)
-
1172  break;
+
1165  sum_squares[j_subm_el_lc] +=
+
1166  blas::dot(m_el_tl, q_tile.ptr({0, j_el_tl}), 1, q_tile.ptr({0, j_el_tl}), 1);
+
1167  }
+
1168  }
+
1169  }
+
1170  }
+
1171 
+
1172  barrier_ptr->arrive_and_wait(barrier_busy_wait);
1173 
-
1174  const SizeType n_el_tl = std::min(dist.tileSize<Coord::Col>(j), k - n_subm_el);
-
1175  for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) {
-
1176  const SizeType j_subm_el_lc = j_subm_lc * dist.blockSize().cols() + j_el_tl;
-
1177  const T vec_norm = std::sqrt(sum_squares[j_subm_el_lc]);
-
1178 
-
1179  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
-
1180  const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc;
-
1181  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
-
1182  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
-
1183  const SizeType m_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
-
1184 
-
1185  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - m_subm_el);
-
1186  blas::scal(m_el_tl, 1 / vec_norm, q[to_sizet(linear_subm_lc)].ptr({0, j_el_tl}), 1);
-
1187  }
-
1188  }
-
1189  }
-
1190  }
-
1191  }));
-
1192 }
-
1193 
-
1194 // Distributed version of the tridiagonal solver on CPUs
-
1195 template <Backend B, class T, Device D, class RhoSender>
-
1196 void mergeDistSubproblems(comm::CommunicatorGrid grid,
-
1197  common::Pipeline<comm::Communicator>& full_task_chain,
-
1198  common::Pipeline<comm::Communicator>& row_task_chain,
-
1199  common::Pipeline<comm::Communicator>& col_task_chain, const SizeType i_begin,
-
1200  const SizeType i_split, const SizeType i_end, RhoSender&& rho,
-
1201  WorkSpace<T, D>& ws, WorkSpaceHost<T>& ws_h,
-
1202  DistWorkSpaceHostMirror<T, D>& ws_hm) {
-
1203  namespace ex = pika::execution::experimental;
-
1204 
-
1205  const matrix::Distribution& dist_evecs = ws.e0.distribution();
-
1206 
-
1207  // Calculate the size of the upper subproblem
-
1208  const SizeType n1 = dist_evecs.globalTileElementDistance<Coord::Row>(i_begin, i_split);
-
1209 
-
1210  // The local size of the subproblem
-
1211  const GlobalTileIndex idx_gl_begin(i_begin, i_begin);
-
1212  const LocalTileIndex idx_loc_begin{dist_evecs.nextLocalTileFromGlobalTile<Coord::Row>(i_begin),
-
1213  dist_evecs.nextLocalTileFromGlobalTile<Coord::Col>(i_begin)};
-
1214  const LocalTileIndex idx_loc_end{dist_evecs.nextLocalTileFromGlobalTile<Coord::Row>(i_end),
-
1215  dist_evecs.nextLocalTileFromGlobalTile<Coord::Col>(i_end)};
-
1216  const LocalTileSize sz_loc_tiles = idx_loc_end - idx_loc_begin;
-
1217  const LocalTileIndex idx_begin_tiles_vec(i_begin, 0);
-
1218  const LocalTileSize sz_tiles_vec(i_end - i_begin, 1);
-
1219 
-
1220  // Assemble the rank-1 update vector `z` from the last row of Q1 and the first row of Q2
-
1221  assembleDistZVec(grid, full_task_chain, i_begin, i_split, i_end, rho, ws.e0, ws.z0);
-
1222  copy(idx_begin_tiles_vec, sz_tiles_vec, ws.z0, ws_hm.z0);
-
1223 
-
1224  // Double `rho` to account for the normalization of `z` and make sure `rho > 0` for the root solver laed4
-
1225  auto scaled_rho = scaleRho(std::move(rho)) | ex::split();
-
1226 
-
1227  // Calculate the tolerance used for deflation
-
1228  auto tol = calcTolerance(i_begin, i_end, ws_h.d0, ws_hm.z0);
+
1174  // STEP 3b: Reduce to get the sum of all squares on all ranks
+
1175  if (thread_idx == 0)
+
1176  tt::sync_wait(ex::just(std::cref(col_comm), MPI_SUM,
+
1177  common::make_data(ws_row(), n_subm_el_lc)) |
+
1178  transformMPI(all_reduce_in_place));
+
1179 
+
1180  barrier_ptr->arrive_and_wait(barrier_busy_wait);
+
1181 
+
1182  // STEP 3c: Normalize (compute norm of each column and scale column vector)
+
1183  {
+
1184  common::internal::SingleThreadedBlasScope single;
+
1185 
+
1186  const T* sum_squares = ws_row();
+
1187 
+
1188  for (SizeType j_subm_lc = begin; j_subm_lc < end; ++j_subm_lc) {
+
1189  const SizeType j_lc = ij_begin_lc.col() + to_SizeType(j_subm_lc);
+
1190  const SizeType j = dist.globalTileFromLocalTile<Coord::Col>(j_lc);
+
1191  const SizeType n_subm_el = dist.globalTileElementDistance<Coord::Col>(i_begin, j);
+
1192 
+
1193  // Skip columns that are in the deflation zone
+
1194  if (n_subm_el >= k)
+
1195  break;
+
1196 
+
1197  const SizeType n_el_tl = std::min(dist.tileSize<Coord::Col>(j), k - n_subm_el);
+
1198  for (SizeType j_el_tl = 0; j_el_tl < n_el_tl; ++j_el_tl) {
+
1199  const SizeType j_subm_el_lc = j_subm_lc * dist.blockSize().cols() + j_el_tl;
+
1200  const T vec_norm = std::sqrt(sum_squares[j_subm_el_lc]);
+
1201 
+
1202  for (SizeType i_subm_lc = 0; i_subm_lc < sz_loc_tiles.rows(); ++i_subm_lc) {
+
1203  const SizeType linear_subm_lc = i_subm_lc + sz_loc_tiles.rows() * j_subm_lc;
+
1204  const SizeType i_lc = ij_begin_lc.row() + i_subm_lc;
+
1205  const SizeType i = dist.globalTileFromLocalTile<Coord::Row>(i_lc);
+
1206  const SizeType m_subm_el = dist.globalTileElementDistance<Coord::Row>(i_begin, i);
+
1207 
+
1208  const SizeType m_el_tl = std::min(dist.tileSize<Coord::Row>(i), n - m_subm_el);
+
1209  blas::scal(m_el_tl, 1 / vec_norm, q[to_sizet(linear_subm_lc)].ptr({0, j_el_tl}), 1);
+
1210  }
+
1211  }
+
1212  }
+
1213  }
+
1214  }));
+
1215 }
+
1216 
+
1217 // Distributed version of the tridiagonal solver on CPUs
+
1218 template <Backend B, class T, Device D, class RhoSender>
+
1219 void mergeDistSubproblems(comm::CommunicatorGrid grid,
+
1220  common::Pipeline<comm::Communicator>& full_task_chain,
+
1221  common::Pipeline<comm::Communicator>& row_task_chain,
+
1222  common::Pipeline<comm::Communicator>& col_task_chain, const SizeType i_begin,
+
1223  const SizeType i_split, const SizeType i_end, RhoSender&& rho,
+
1224  WorkSpace<T, D>& ws, WorkSpaceHost<T>& ws_h,
+
1225  DistWorkSpaceHostMirror<T, D>& ws_hm) {
+
1226  namespace ex = pika::execution::experimental;
+
1227 
+
1228  const matrix::Distribution& dist_evecs = ws.e0.distribution();
1229 
-
1230  // Initialize the column types vector `c`
-
1231  initColTypes(i_begin, i_split, i_end, ws_h.c);
+
1230  // Calculate the size of the upper subproblem
+
1231  const SizeType n1 = dist_evecs.globalTileElementDistance<Coord::Row>(i_begin, i_split);
1232 
-
1233  // Step #1
-
1234  //
-
1235  // i1 (out) : initial <--- initial (identity map)
-
1236  // i2 (out) : initial <--- pre_sorted
-
1237  //
-
1238  // - deflate `d`, `z` and `c`
-
1239  // - apply Givens rotations to `Q` - `evecs`
-
1240  //
-
1241  if (i_split == i_begin + 1) {
-
1242  initIndex(i_begin, i_split, ws_h.i1);
-
1243  }
-
1244  if (i_split + 1 == i_end) {
-
1245  initIndex(i_split, i_end, ws_h.i1);
-
1246  }
-
1247  addIndex(i_split, i_end, n1, ws_h.i1);
-
1248  sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2);
+
1233  // The local size of the subproblem
+
1234  const GlobalTileIndex idx_gl_begin(i_begin, i_begin);
+
1235  const LocalTileIndex idx_loc_begin{dist_evecs.nextLocalTileFromGlobalTile<Coord::Row>(i_begin),
+
1236  dist_evecs.nextLocalTileFromGlobalTile<Coord::Col>(i_begin)};
+
1237  const LocalTileIndex idx_loc_end{dist_evecs.nextLocalTileFromGlobalTile<Coord::Row>(i_end),
+
1238  dist_evecs.nextLocalTileFromGlobalTile<Coord::Col>(i_end)};
+
1239  const LocalTileSize sz_loc_tiles = idx_loc_end - idx_loc_begin;
+
1240  const LocalTileIndex idx_begin_tiles_vec(i_begin, 0);
+
1241  const LocalTileSize sz_tiles_vec(i_end - i_begin, 1);
+
1242 
+
1243  // Assemble the rank-1 update vector `z` from the last row of Q1 and the first row of Q2
+
1244  assembleDistZVec(grid, full_task_chain, i_begin, i_split, i_end, rho, ws.e0, ws.z0);
+
1245  copy(idx_begin_tiles_vec, sz_tiles_vec, ws.z0, ws_hm.z0);
+
1246 
+
1247  // Double `rho` to account for the normalization of `z` and make sure `rho > 0` for the root solver laed4
+
1248  auto scaled_rho = scaleRho(std::move(rho)) | ex::split();
1249 
-
1250  auto rots =
-
1251  applyDeflation(i_begin, i_end, scaled_rho, std::move(tol), ws_hm.i2, ws_h.d0, ws_hm.z0, ws_h.c);
+
1250  // Calculate the tolerance used for deflation
+
1251  auto tol = calcTolerance(i_begin, i_end, ws_h.d0, ws_hm.z0);
1252 
-
1253  // ---
-
1254 
-
1255  // Make sure Isend/Irecv messages don't match between calls by providing a unique `tag`
-
1256  //
-
1257  // Note: i_split is unique
-
1258  const comm::IndexT_MPI tag = to_int(i_split);
-
1259  applyGivensRotationsToMatrixColumns(grid.rowCommunicator(), tag, i_begin, i_end, std::move(rots),
-
1260  ws.e0);
-
1261  // Placeholder for rearranging the eigenvectors: (local permutation)
-
1262  copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws.e1);
-
1263 
-
1264  // Step #2
-
1265  //
-
1266  // i2 (in) : initial <--- pre_sorted
-
1267  // i3 (out) : initial <--- deflated
-
1268  //
-
1269  // - reorder `d0 -> d1`, `z0 -> z1`, using `i3` such that deflated entries are at the bottom.
-
1270  // - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`.
-
1271  // - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented)
-
1272  //
-
1273  auto k = stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_hm.i2, ws_h.i3) | ex::split();
-
1274  applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1);
-
1275  applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1);
-
1276  copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.d1, ws_h.d0);
+
1253  // Initialize the column types vector `c`
+
1254  initColTypes(i_begin, i_split, i_end, ws_h.c);
+
1255 
+
1256  // Step #1
+
1257  //
+
1258  // i1 (out) : initial <--- initial (identity map)
+
1259  // i2 (out) : initial <--- pre_sorted
+
1260  //
+
1261  // - deflate `d`, `z` and `c`
+
1262  // - apply Givens rotations to `Q` - `evecs`
+
1263  //
+
1264  if (i_split == i_begin + 1) {
+
1265  initIndex(i_begin, i_split, ws_h.i1);
+
1266  }
+
1267  if (i_split + 1 == i_end) {
+
1268  initIndex(i_split, i_end, ws_h.i1);
+
1269  }
+
1270  addIndex(i_split, i_end, n1, ws_h.i1);
+
1271  sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2);
+
1272 
+
1273  auto rots =
+
1274  applyDeflation(i_begin, i_end, scaled_rho, std::move(tol), ws_hm.i2, ws_h.d0, ws_hm.z0, ws_h.c);
+
1275 
+
1276  // ---
1277 
-
1278  //
-
1279  // i3 (in) : initial <--- deflated
-
1280  // i2 (out) : initial ---> deflated
-
1281  //
-
1282  invertIndex(i_begin, i_end, ws_h.i3, ws_hm.i2);
-
1283 
-
1284  // Note: here ws_hm.z0 is used as a contiguous buffer for the laed4 call
-
1285  matrix::util::set0<Backend::MC>(pika::execution::thread_priority::normal, idx_loc_begin, sz_loc_tiles,
-
1286  ws_hm.e2);
-
1287  solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, idx_loc_begin, sz_loc_tiles,
-
1288  k, std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_hm.i2, ws_hm.e2);
-
1289 
-
1290  // Step #3: Eigenvectors of the tridiagonal system: Q * U
+
1278  // Make sure Isend/Irecv messages don't match between calls by providing a unique `tag`
+
1279  //
+
1280  // Note: i_split is unique
+
1281  const comm::IndexT_MPI tag = to_int(i_split);
+
1282  applyGivensRotationsToMatrixColumns(grid.rowCommunicator(), tag, i_begin, i_end, std::move(rots),
+
1283  ws.e0);
+
1284  // Placeholder for rearranging the eigenvectors: (local permutation)
+
1285  copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws.e1);
+
1286 
+
1287  // Step #2
+
1288  //
+
1289  // i2 (in) : initial <--- pre_sorted
+
1290  // i3 (out) : initial <--- deflated
1291  //
-
1292  // The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as
-
1293  // prepared for the deflated system.
-
1294  copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2);
-
1295  dlaf::multiplication::generalSubMatrix<B, D, T>(grid, row_task_chain, col_task_chain, i_begin, i_end,
-
1296  T(1), ws.e1, ws.e2, T(0), ws.e0);
-
1297 
-
1298  // Step #4: Final permutation to sort eigenvalues and eigenvectors
-
1299  //
-
1300  // i1 (in) : deflated <--- deflated (identity map)
-
1301  // i2 (out) : deflated <--- post_sorted
-
1302  //
-
1303  initIndex(i_begin, i_end, ws_h.i1);
-
1304  sortIndex(i_begin, i_end, std::move(k), ws_h.d0, ws_h.i1, ws_hm.i2);
-
1305  copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i2, ws_h.i1);
-
1306 }
-
1307 }
+
1292  // - reorder `d0 -> d1`, `z0 -> z1`, using `i3` such that deflated entries are at the bottom.
+
1293  // - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`.
+
1294  // - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented)
+
1295  //
+
1296  auto k = stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_hm.i2, ws_h.i3) | ex::split();
+
1297  applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1);
+
1298  applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1);
+
1299  copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.d1, ws_h.d0);
+
1300 
+
1301  //
+
1302  // i3 (in) : initial <--- deflated
+
1303  // i2 (out) : initial ---> deflated
+
1304  //
+
1305  invertIndex(i_begin, i_end, ws_h.i3, ws_hm.i2);
+
1306 
+
1307  // Note: here ws_hm.z0 is used as a contiguous buffer for the laed4 call
+
1308  matrix::util::set0<Backend::MC>(pika::execution::thread_priority::normal, idx_loc_begin, sz_loc_tiles,
+
1309  ws_hm.e2);
+
1310  solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, idx_loc_begin, sz_loc_tiles,
+
1311  k, std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_hm.i2, ws_hm.e2);
+
1312 
+
1313  // Step #3: Eigenvectors of the tridiagonal system: Q * U
+
1314  //
+
1315  // The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as
+
1316  // prepared for the deflated system.
+
1317  copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2);
+
1318  dlaf::multiplication::generalSubMatrix<B, D, T>(grid, row_task_chain, col_task_chain, i_begin, i_end,
+
1319  T(1), ws.e1, ws.e2, T(0), ws.e0);
+
1320 
+
1321  // Step #4: Final permutation to sort eigenvalues and eigenvectors
+
1322  //
+
1323  // i1 (in) : deflated <--- deflated (identity map)
+
1324  // i2 (out) : deflated <--- post_sorted
+
1325  //
+
1326  initIndex(i_begin, i_end, ws_h.i1);
+
1327  sortIndex(i_begin, i_end, std::move(k), ws_h.d0, ws_h.i1, ws_hm.i2);
+
1328  copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i2, ws_h.i1);
+
1329 }
+
1330 }
Definition: communicator.h:40