Skip to content

Commit

Permalink
merge-squashed: propaedeutic changes towards gemm cost reduction
Browse files Browse the repository at this point in the history
make rank1 work just on non-deflated (single-threaded)
  • Loading branch information
albestro committed Dec 4, 2023
1 parent 51109c9 commit 477b2e2
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 282 deletions.
12 changes: 7 additions & 5 deletions include/dlaf/eigensolver/tridiag_solver/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ void TridiagSolver<B, D, T>::call(Matrix<T, Device::CPU>& tridiag, Matrix<T, D>&
Matrix<T, D>(vec_size, vec_tile_size), // z0
Matrix<T, D>(vec_size, vec_tile_size), // z1
Matrix<SizeType, D>(vec_size, vec_tile_size), // i2
Matrix<SizeType, D>(vec_size, vec_tile_size)}; // i5
Matrix<SizeType, D>(vec_size, vec_tile_size), // i5
Matrix<SizeType, D>(vec_size, vec_tile_size)}; // i6

WorkSpaceHost<T> ws_h{Matrix<T, Device::CPU>(vec_size, vec_tile_size), // d0
Matrix<ColType, Device::CPU>(vec_size, vec_tile_size), // c
Expand Down Expand Up @@ -380,7 +381,8 @@ void TridiagSolver<B, D, T>::call(comm::CommunicatorGrid grid, Matrix<T, Device:
Matrix<T, D>(dist_evals), // z0
Matrix<T, D>(dist_evals), // z1
Matrix<SizeType, D>(dist_evals), // i2
Matrix<SizeType, D>(dist_evals)}; // i5
Matrix<SizeType, D>(dist_evals), // i5
Matrix<SizeType, D>(dist_evals)}; // i6

WorkSpaceHost<T> ws_h{Matrix<T, Device::CPU>(dist_evals), // d0
Matrix<ColType, Device::CPU>(dist_evals), // c
Expand All @@ -392,7 +394,7 @@ void TridiagSolver<B, D, T>::call(comm::CommunicatorGrid grid, Matrix<T, Device:
DistWorkSpaceHostMirror<T, D> ws_hm{initMirrorMatrix(ws.e0), initMirrorMatrix(ws.e2),
initMirrorMatrix(ws.d1), initMirrorMatrix(ws.z0),
initMirrorMatrix(ws.z1), initMirrorMatrix(ws.i2),
initMirrorMatrix(ws.i5)};
initMirrorMatrix(ws.i5), initMirrorMatrix(ws.i6)};

// Set `ws.e0` to zero (needed for the Givens rotations, to make sure no random values are picked up)
matrix::util::set0<B, T, D>(thread_priority::normal, ws.e0);
Expand Down Expand Up @@ -426,12 +428,12 @@ void TridiagSolver<B, D, T>::call(comm::CommunicatorGrid grid, Matrix<T, Device:
copy(ws.e0, ws_hm.e0);

// Note: ws_hm.d1 is the mirror of ws.d1 which is evals
applyIndex(0, n, ws_hm.i2, ws_h.d0, ws_hm.d1);
applyIndex(0, n, ws_h.i1, ws_h.d0, ws_hm.d1);
copy(ws_hm.d1, evals);

// Note: ws_hm.e2 is the mirror of ws.e2 which is evecs
dlaf::permutations::permute<Backend::MC, Device::CPU, T, Coord::Col>(grid, row_task_chain, 0, n,
ws_hm.i2, ws_hm.e0, ws_hm.e2);
ws_h.i1, ws_hm.e0, ws_hm.e2);
copy(ws_hm.e2, evecs);
}

Expand Down
Loading

0 comments on commit 477b2e2

Please sign in to comment.