From 38cfd1ede86f147dd06b3fe765f8b2913f8406d4 Mon Sep 17 00:00:00 2001
From: Alberto Invernizzi
Date: Fri, 27 Oct 2023 17:07:24 +0200
Subject: [PATCH 01/16] add wrapper and new more generic implementation

---
 include/dlaf/multiplication/general.h      | 24 +++++++
 include/dlaf/multiplication/general/api.h  |  4 ++
 include/dlaf/multiplication/general/impl.h | 76 ++++++++++++++++++++++
 3 files changed, 104 insertions(+)

diff --git a/include/dlaf/multiplication/general.h b/include/dlaf/multiplication/general.h
index 6eaa3e1f22..40f88355ac 100644
--- a/include/dlaf/multiplication/general.h
+++ b/include/dlaf/multiplication/general.h
@@ -64,6 +64,30 @@ void generalMatrix(const blas::Op opA, const blas::Op opB, const T alpha, Matrix
   DLAF_UNIMPLEMENTED(opA, opB);
 }
 
+template
+void generalMatrix([[maybe_unused]] comm::CommunicatorGrid grid,
+                   common::Pipeline& row_task_chain,
+                   common::Pipeline& col_task_chain, const SizeType a,
+                   const SizeType b, const T alpha, MatrixRef& mat_a,
+                   MatrixRef& mat_b, const T beta, MatrixRef& mat_c) {
+  DLAF_ASSERT(equal_process_grid(mat_a, grid), mat_a, grid);
+  DLAF_ASSERT(equal_process_grid(mat_b, grid), mat_b, grid);
+  DLAF_ASSERT(equal_process_grid(mat_c, grid), mat_c, grid);
+
+  using matrix::multipliable_sizes;
+  DLAF_ASSERT(multipliable_sizes(mat_a.size(), mat_b.size(), mat_c.size()),
+              "Multiplication incompatible matrix sizes.", mat_a.size(), mat_b.size(), mat_c.size());
+  DLAF_ASSERT(multipliable_sizes(mat_a.blockSize(), mat_b.blockSize(), mat_c.blockSize()),
+              "Multiplication incompatible tile sizes.");
+  DLAF_ASSERT(mat_c.size().isEmpty() || multipliable_sizes(mat_a.distribution().tileSize({0, 0}),
+                                                           mat_b.distribution().tileSize({0, 0}),
+                                                           mat_c.distribution().tileSize({0, 0})),
+              "Multiplication incompatible tile sizes in first row/col. "
" + "(Are you using a matrix with offset not aligned with tile?)"); + + internal::General::callNN(row_task_chain, col_task_chain, alpha, mat_a, mat_b, beta, mat_c); +} + /// General sub-matrix multiplication implementation on local memory, computing /// C[a:b][a:b] = alpha * opA(A[a:b][a:b]) * opB(B[a:b][a:b]) + beta * C[a:b][a:b] /// where [a:b] is the range of tiles starting from tile index @p a to tile index @p b (excluded) diff --git a/include/dlaf/multiplication/general/api.h b/include/dlaf/multiplication/general/api.h index 3bb418cac4..058c1ba8de 100644 --- a/include/dlaf/multiplication/general/api.h +++ b/include/dlaf/multiplication/general/api.h @@ -25,6 +25,10 @@ template struct General { static void callNN(const T alpha, MatrixRef& mat_a, MatrixRef& mat_b, const T beta, MatrixRef& mat_c); + static void callNN(common::Pipeline& row_task_chain, + common::Pipeline& col_task_chain, const T alpha, + MatrixRef& mat_a, MatrixRef& mat_b, const T beta, + MatrixRef& mat_c); }; template diff --git a/include/dlaf/multiplication/general/impl.h b/include/dlaf/multiplication/general/impl.h index 34abe4adef..2c6178bdf2 100644 --- a/include/dlaf/multiplication/general/impl.h +++ b/include/dlaf/multiplication/general/impl.h @@ -60,6 +60,82 @@ void General::callNN(const T alpha, MatrixRef& mat_a, Matri } } +template +void General::callNN(common::Pipeline& row_task_chain, + common::Pipeline& col_task_chain, const T alpha, + MatrixRef& mat_a, MatrixRef& mat_b, const T beta, + MatrixRef& mat_c) { + namespace ex = pika::execution::experimental; + + if (mat_c.size().isEmpty()) + return; + + const matrix::Distribution& dist_a = mat_a.distribution(); + const matrix::Distribution& dist_b = mat_b.distribution(); + const matrix::Distribution& dist_c = mat_c.distribution(); + const auto rank = dist_c.rankIndex(); + + constexpr std::size_t n_workspaces = 2; + common::RoundRobin> panelsA(n_workspaces, dist_c); + common::RoundRobin> panelsB(n_workspaces, dist_c); + + DLAF_ASSERT_HEAVY(mat_a.nrTiles().cols() == mat_b.nrTiles().rows(), mat_a.nrTiles(), mat_b.nrTiles()); + + // This loops over the global indices for k, because every rank has to participate in communication + for (SizeType k = 0; k < mat_a.nrTiles().cols(); ++k) { + auto& panelA = panelsA.nextResource(); + auto& panelB = panelsB.nextResource(); + + if (k == 0 || k == mat_a.nrTiles().cols() - 1) { + DLAF_ASSERT_HEAVY(dist_a.tileSize(k) == dist_b.tileSize(k), + dist_a.tileSize(k), dist_b.tileSize(k)); + const SizeType kSize = dist_a.tileSize(k); + panelA.setWidth(kSize); + panelB.setHeight(kSize); + } + + // Setup the column workspace for the root ranks, i.e. the ones in the current col + const auto rank_k_col = dist_a.rankGlobalTile(k); + if (rank_k_col == rank.col()) { + const auto k_local = dist_a.template localTileFromGlobalTile(k); + for (SizeType i = 0; i < dist_c.localNrTiles().rows(); ++i) { + const LocalTileIndex ik(i, k_local); + panelA.setTile(ik, mat_a.read(ik)); + } + } + // Setup the row workspace for the root ranks, i.e. 
+    const auto rank_k_row = dist_b.rankGlobalTile(k);
+    if (rank_k_row == rank.row()) {
+      const auto k_local = dist_b.template localTileFromGlobalTile(k);
+      for (SizeType j = 0; j < dist_c.localNrTiles().cols(); ++j) {
+        const LocalTileIndex kj(k_local, j);
+        panelB.setTile(kj, mat_b.read(kj));
+      }
+    }
+
+    // Broadcast both column and row panel from root to others (row-wise and col-wise, respectively)
+    broadcast(rank_k_col, panelA, row_task_chain);
+    broadcast(rank_k_row, panelB, col_task_chain);
+
+    // This is the core loop where the k step performs the update over the entire local matrix using
+    // the col and row workspaces.
+    // Everything needed for the update is available locally thanks to previous broadcasts.
+    for (SizeType i = 0; i < dist_c.localNrTiles().rows(); ++i) {
+      for (SizeType j = 0; j < dist_c.localNrTiles().cols(); ++j) {
+        const LocalTileIndex ij(i, j);
+
+        ex::start_detached(dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::NoTrans, alpha,
+                                                       panelA.read(ij), panelB.read(ij),
+                                                       k == 0 ? beta : T(1), mat_c.readwrite(ij)) |
+                           tile::gemm(dlaf::internal::Policy()));
+      }
+    }
+
+    panelA.reset();
+    panelB.reset();
+  }
+}
+
 template
 void GeneralSub::callNN(const SizeType idx_begin, const SizeType idx_end, const blas::Op opA,
                         const blas::Op opB, const T alpha, Matrix& mat_a,

From 865d6e783d4924ec485e68a40c172e8dd6c5e85f Mon Sep 17 00:00:00 2001
From: Alberto Invernizzi
Date: Fri, 27 Oct 2023 17:07:57 +0200
Subject: [PATCH 02/16] WIP: workaround for making it work with both Matrix and MatrixRef

---
 include/dlaf/util_matrix.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/include/dlaf/util_matrix.h b/include/dlaf/util_matrix.h
index 2c07fef1a6..b8bb845bab 100644
--- a/include/dlaf/util_matrix.h
+++ b/include/dlaf/util_matrix.h
@@ -74,8 +74,8 @@ bool local_matrix(const MatrixLike& m) noexcept {
 }
 
 /// Returns true if the matrix is distributed on the communication grid.
-template
-bool equal_process_grid(const Matrix& m, const comm::CommunicatorGrid& g) noexcept {
+template
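
Note on the algorithm in the impl.h hunk of the first patch: each k step broadcasts the k-th column panel of A along process rows and the k-th row panel of B along process columns, after which every rank updates all of its local tiles of C (a stationary-C, SUMMA-like scheme). The serial sketch below, in plain standard C++ with purely illustrative names and no DLA-Future types (assuming row-major storage), shows the same update pattern at element level:

    #include <cstddef>
    #include <vector>

    // Reference of the stationary-C pattern: for every k, the k-th column of A and
    // the k-th row of B update the whole of C; beta is applied only on the first
    // step, mirroring the `k == 0 ? beta : T(1)` factor used in the patch.
    template <class T>
    void reference_gemm(std::size_t m, std::size_t n, std::size_t p, T alpha,
                        const std::vector<T>& A, const std::vector<T>& B, T beta,
                        std::vector<T>& C) {
      for (std::size_t k = 0; k < p; ++k) {      // one "panel" step per k
        for (std::size_t i = 0; i < m; ++i) {    // rows of C (local tiles in the distributed code)
          for (std::size_t j = 0; j < n; ++j) {  // columns of C (local tiles in the distributed code)
            const T prev = (k == 0) ? beta * C[i * n + j] : C[i * n + j];
            C[i * n + j] = prev + alpha * A[i * p + k] * B[k * n + j];
          }
        }
      }
    }

In the distributed version, the i/j loops run only over dist_c.localNrTiles(), and the entries of the A/B panels arrive through the row_task_chain/col_task_chain broadcasts rather than being addressed directly.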