Enable Gemm (distributed) to be used with MatrixRef #1022

Merged · Dec 1, 2023 · 16 commits (changes shown from 14 commits)
27 changes: 27 additions & 0 deletions include/dlaf/multiplication/general.h
@@ -64,6 +64,33 @@ void generalMatrix(const blas::Op opA, const blas::Op opB, const T alpha, Matrix
DLAF_UNIMPLEMENTED(opA, opB);
}

/// General sub-matrix distributed multiplication, computing
/// C = alpha * A * B + beta * C
///
/// @param mat_a contains the input matrix A.
/// @param mat_b contains the input matrix B.
/// @param mat_c On entry it contains the input matrix C. On exit, the matrix tiles in the range are
/// overwritten with the result, while the others are left untouched.
///
/// @pre @p mat_a, @p mat_b and @p mat_c are distributed on the same process grid
Collaborator:

What do you mean?
What I read -> same distribution
what should be -> same grid

Collaborator (author):

see 92bd701

/// @pre multipliable_sizes(mat_a.size(), mat_b.size(), mat_c.size(), opA, opB)
/// @pre multipliable_sizes(mat_a.tile_size(), mat_b.tile_size(), mat_c.tile_size(), opA, opB)
/// @pre multipliable_sizes(mat_a.tile_size_of({0, 0}), mat_b.tile_size_of({0, 0}),
/// mat_c.tile_size_of({0, 0}), opA, opB)
albestro marked this conversation as resolved.
template <Backend B, Device D, class T>
void generalMatrix(common::Pipeline<comm::Communicator>& row_task_chain,
common::Pipeline<comm::Communicator>& col_task_chain, const T alpha,
MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b, const T beta,
MatrixRef<T, D>& mat_c) {
DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_a), mat_c, mat_a);
DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_b), mat_c, mat_b);

DLAF_ASSERT_HEAVY(matrix::multipliable(mat_a, mat_b, mat_c, blas::Op::NoTrans, blas::Op::NoTrans),
mat_a, mat_b, mat_c);

internal::General<B, D, T>::callNN(row_task_chain, col_task_chain, alpha, mat_a, mat_b, beta, mat_c);
}

/// General sub-matrix multiplication implementation on local memory, computing
/// C[a:b][a:b] = alpha * opA(A[a:b][a:b]) * opB(B[a:b][a:b]) + beta * C[a:b][a:b]
/// where [a:b] is the range of tiles starting from tile index @p a to tile index @p b (excluded)
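For context, a minimal usage sketch of the new distributed overload, modeled on the unit test added in this PR. This is an editor's sketch, not part of the diff: the communicator grid `grid`, the matrices `mat_a`/`mat_b`/`mat_c`, the template parameters `B`/`D`/`T`, and the sub-matrix specs `spec_a`/`spec_b`/`spec_c` are assumed to already exist (`spec_*` are hypothetical placeholders).

```cpp
using dlaf::matrix::internal::MatrixRef;

// One communicator pipeline per grid direction, as in the unit test below.
common::Pipeline<comm::Communicator> row_chain(grid.rowCommunicator());
common::Pipeline<comm::Communicator> col_chain(grid.colCommunicator());

// Views over sub-ranges of the full matrices.
MatrixRef<const T, D> sub_a(mat_a, spec_a);
MatrixRef<const T, D> sub_b(mat_b, spec_b);
MatrixRef<T, D> sub_c(mat_c, spec_c);

// sub_c = alpha * sub_a * sub_b + beta * sub_c; only NoTrans/NoTrans is implemented.
multiplication::internal::generalMatrix<B>(row_chain, col_chain, alpha, sub_a, sub_b, beta, sub_c);
```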
4 changes: 4 additions & 0 deletions include/dlaf/multiplication/general/api.h
@@ -25,6 +25,10 @@ template <Backend B, Device D, class T>
struct General {
static void callNN(const T alpha, MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b,
const T beta, MatrixRef<T, D>& mat_c);
static void callNN(common::Pipeline<comm::Communicator>& row_task_chain,
common::Pipeline<comm::Communicator>& col_task_chain, const T alpha,
MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b, const T beta,
MatrixRef<T, D>& mat_c);
};

template <Backend B, Device D, class T>
88 changes: 88 additions & 0 deletions include/dlaf/multiplication/general/impl.h
@@ -60,6 +60,94 @@ void General<B, D, T>::callNN(const T alpha, MatrixRef<const T, D>& mat_a, Matri
}
}

template <Backend B, Device D, class T>
void General<B, D, T>::callNN(common::Pipeline<comm::Communicator>& row_task_chain,
common::Pipeline<comm::Communicator>& col_task_chain, const T alpha,
MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b, const T beta,
MatrixRef<T, D>& mat_c) {
namespace ex = pika::execution::experimental;

if (mat_c.size().isEmpty())
return;
albestro marked this conversation as resolved.

const matrix::Distribution& dist_a = mat_a.distribution();
const matrix::Distribution& dist_b = mat_b.distribution();
const matrix::Distribution& dist_c = mat_c.distribution();
const auto rank = dist_c.rank_index();

if (mat_a.nr_tiles().cols() == 0) {
// Note: if beta == 1, we optimize by not even scheduling anything
if (beta != T(1)) {
for (SizeType j = 0; j < mat_c.distribution().local_nr_tiles().cols(); ++j)
for (SizeType i = 0; i < mat_c.distribution().local_nr_tiles().rows(); ++i)
ex::start_detached(dlaf::internal::whenAllLift(beta, mat_c.readwrite(LocalTileIndex(i, j))) |
tile::scal(dlaf::internal::Policy<B>()));
}
return;
}

constexpr std::size_t n_workspaces = 2;
common::RoundRobin<matrix::Panel<Coord::Col, T, D>> panelsA(n_workspaces, dist_a);
common::RoundRobin<matrix::Panel<Coord::Row, T, D>> panelsB(n_workspaces, dist_b);

DLAF_ASSERT_HEAVY(mat_a.nr_tiles().cols() == mat_b.nr_tiles().rows(), mat_a.nr_tiles(),
mat_b.nr_tiles());

// This loop runs over the global indices of k, because every rank has to participate in the communication
for (SizeType k = 0; k < mat_a.nr_tiles().cols(); ++k) {
auto& panelA = panelsA.nextResource();
auto& panelB = panelsB.nextResource();

if (k == 0 || k == mat_a.nr_tiles().cols() - 1) {
DLAF_ASSERT_HEAVY(dist_a.tile_size_of<Coord::Col>(k) == dist_b.tile_size_of<Coord::Row>(k),
dist_a.tile_size_of<Coord::Col>(k), dist_b.tile_size_of<Coord::Row>(k));
const SizeType kSize = dist_a.tile_size_of<Coord::Col>(k);
panelA.setWidth(kSize);
panelB.setHeight(kSize);
}

// Set up the column workspace for the root ranks, i.e. those in the current column
const auto rank_k_col = dist_a.rank_global_tile<Coord::Col>(k);
if (rank_k_col == rank.col()) {
const auto k_local = dist_a.local_tile_from_global_tile<Coord::Col>(k);
for (SizeType i = 0; i < dist_c.local_nr_tiles().rows(); ++i) {
const LocalTileIndex ik(i, k_local);
panelA.setTile(ik, mat_a.read(ik));
}
}
// Set up the row workspace for the root ranks, i.e. those in the current row
const auto rank_k_row = dist_b.rank_global_tile<Coord::Row>(k);
if (rank_k_row == rank.row()) {
const auto k_local = dist_b.local_tile_from_global_tile<Coord::Row>(k);
for (SizeType j = 0; j < dist_c.local_nr_tiles().cols(); ++j) {
const LocalTileIndex kj(k_local, j);
panelB.setTile(kj, mat_b.read(kj));
}
}

// Broadcast both column and row panel from root to others (row-wise and col-wise, respectively)
broadcast(rank_k_col, panelA, row_task_chain);
broadcast(rank_k_row, panelB, col_task_chain);

// This is the core loop: at step k the update is applied over the entire local matrix using
// the column and row workspaces.
// Everything needed for the update is available locally thanks to the previous broadcasts.
for (SizeType i = 0; i < dist_c.local_nr_tiles().rows(); ++i) {
for (SizeType j = 0; j < dist_c.local_nr_tiles().cols(); ++j) {
const LocalTileIndex ij(i, j);

ex::start_detached(dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::NoTrans, alpha,
panelA.read(ij), panelB.read(ij),
k == 0 ? beta : T(1), mat_c.readwrite(ij)) |
tile::gemm(dlaf::internal::Policy<B>()));
}
}

panelA.reset();
panelB.reset();
}
}

template <Backend B, Device D, class T>
void GeneralSub<B, D, T>::callNN(const SizeType idx_begin, const SizeType idx_end, const blas::Op opA,
const blas::Op opB, const T alpha, Matrix<const T, D>& mat_a,
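As a cross-check of the k-loop above: the arithmetic it distributes is the plain accumulate-per-step GEMM below. This is an editor's serial sketch, not code from the PR; note how `beta` is consumed only at `k == 0`, matching the `k == 0 ? beta : T(1)` argument passed to `tile::gemm`.

```cpp
#include <cstddef>
#include <vector>

// Serial reference: step k accumulates column k of A times row k of B into C;
// beta scales the original C on the first step only.
void gemm_reference(std::size_t M, std::size_t N, std::size_t K, double alpha,
                    const std::vector<std::vector<double>>& A,
                    const std::vector<std::vector<double>>& B, double beta,
                    std::vector<std::vector<double>>& C) {
  for (std::size_t k = 0; k < K; ++k) {
    const double b_eff = (k == 0) ? beta : 1.0;
    for (std::size_t i = 0; i < M; ++i)
      for (std::size_t j = 0; j < N; ++j)
        C[i][j] = alpha * A[i][k] * B[k][j] + b_eff * C[i][j];
  }
}
```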
6 changes: 3 additions & 3 deletions include/dlaf/util_matrix.h
@@ -74,12 +74,12 @@ bool local_matrix(const MatrixLike<const T, D>& m) noexcept {
}

/// Returns true if the matrix is distributed on the communication grid.
template <class T, Device D>
bool equal_process_grid(const Matrix<const T, D>& m, const comm::CommunicatorGrid& g) noexcept {
template <template <class, Device> class MatrixLike, class T, Device D>
bool equal_process_grid(const MatrixLike<const T, D>& m, const comm::CommunicatorGrid& g) noexcept {
return m.commGridSize() == g.size() && m.rankIndex() == g.rank();
}

/// Returns true if the matrix is distributed on the communication grid.
/// Returns true if the two matrices are distributed on the same grid.
template <template <class, Device> class MatrixLikeA, template <class, Device> class MatrixLikeB,
class T, Device D1, Device D2>
bool same_process_grid(const MatrixLikeA<const T, D1>& a, const MatrixLikeB<const T, D2>& b) noexcept {
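The switch from a concrete `Matrix` parameter to a `MatrixLike` template template parameter is what lets the new `DLAF_ASSERT`s in `generalMatrix` accept `MatrixRef` arguments. A hedged sketch of the effect (construction elided, variable names hypothetical):

```cpp
// Editor's sketch: grid checks now accept any MatrixLike, so Matrix and
// MatrixRef can be mixed freely.
Matrix<const double, Device::CPU> a = /* ... */;
MatrixRef<const double, Device::CPU> a_sub = /* view over a sub-range of a */;

const bool on_grid = matrix::equal_process_grid(a_sub, grid);  // was Matrix-only before
const bool same = matrix::same_process_grid(a, a_sub);         // mixed kinds compare fine
```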
135 changes: 135 additions & 0 deletions test/unit/multiplication/test_multiplication_general.cpp
@@ -13,6 +13,8 @@
#include <dlaf/blas/enum_output.h>
#include <dlaf/common/assert.h>
#include <dlaf/common/index2d.h>
#include <dlaf/common/pipeline.h>
#include <dlaf/communication/communicator.h>
#include <dlaf/communication/communicator_grid.h>
#include <dlaf/matrix/index.h>
#include <dlaf/matrix/matrix.h>
@@ -40,6 +42,10 @@ template <class T>
struct GeneralMultiplicationTestMC : public ::testing::Test {};
TYPED_TEST_SUITE(GeneralMultiplicationTestMC, MatrixElementTypes);

template <class T>
struct GeneralMultiplicationDistTestMC : public TestWithCommGrids {};
TYPED_TEST_SUITE(GeneralMultiplicationDistTestMC, MatrixElementTypes);

template <class T>
struct GeneralSubMultiplicationTestMC : public ::testing::Test {};
TYPED_TEST_SUITE(GeneralSubMultiplicationTestMC, MatrixElementTypes);
@@ -53,6 +59,10 @@ template <class T>
struct GeneralMultiplicationTestGPU : public ::testing::Test {};
TYPED_TEST_SUITE(GeneralMultiplicationTestGPU, MatrixElementTypes);

template <class T>
struct GeneralMultiplicationDistTestGPU : public TestWithCommGrids {};
TYPED_TEST_SUITE(GeneralMultiplicationDistTestGPU, MatrixElementTypes);

template <class T>
struct GeneralSubMultiplicationTestGPU : public ::testing::Test {};
TYPED_TEST_SUITE(GeneralSubMultiplicationTestGPU, MatrixElementTypes);
@@ -136,6 +146,78 @@ void testGeneralMultiplication(const T alpha, const T beta, const GemmConfig& co
2 * (mat_ah.size().cols() + 1) * TypeUtilities<T>::error);
}

template <class T, Backend B, Device D>
void testGeneralMultiplication(const T alpha, const T beta, const GemmConfig& config,
comm::CommunicatorGrid& grid) {
using dlaf::matrix::internal::MatrixRef;

common::Pipeline<comm::Communicator> mpi_row_chain(grid.rowCommunicator());
common::Pipeline<comm::Communicator> mpi_col_chain(grid.colCommunicator());

const TileElementSize blocksize_a(config.mb, config.kb);
const TileElementSize blocksize_b(config.kb, config.nb);
const TileElementSize blocksize_c(config.mb, config.nb);

const comm::Index2D src_rank_c(std::max(0, grid.size().rows() - 1),
std::min(1, grid.size().cols() - 1));
const matrix::Distribution dist_c(config.full_c(), blocksize_c, grid.size(), grid.rank(), src_rank_c);

const comm::IndexT_MPI rank_aligned_row =
align_sub_rank_index<Coord::Row>(dist_c, config.sub_c().origin, blocksize_a,
config.sub_a().origin);
const comm::IndexT_MPI rank_aligned_col =
align_sub_rank_index<Coord::Col>(dist_c, config.sub_c().origin, blocksize_b,
config.sub_b().origin);

// Note:
// GEMM(NoTrans, NoTrans) requires:
// - A to be rank-aligned with C along rows
// - B to be rank-aligned with C along columns
// (see the sketch after this test helper for the alignment arithmetic)
const comm::Index2D src_rank_a{rank_aligned_row, 0};
const comm::Index2D src_rank_b{0, rank_aligned_col};

const matrix::Distribution dist_a(config.full_a(), blocksize_a, grid.size(), grid.rank(), src_rank_a);
const matrix::Distribution dist_b(config.full_b(), blocksize_b, grid.size(), grid.rank(), src_rank_b);

auto setMatrix = [&](auto&& elSetter, matrix::Distribution dist) {
Matrix<T, Device::CPU> matrix(std::move(dist));
dlaf::matrix::util::set(matrix, elSetter);
return matrix;
};

auto [subValuesA, subValuesB, subValuesC, subValuesResult] =
matrix::test::getMatrixMatrixMultiplication<GlobalElementIndex, T>(config.opA, config.opB,
config.k, alpha, beta);

const auto fullValuesA = mix_values(config.sub_a(), subValuesA, [](auto) { return T(-99); });
const auto fullValuesB = mix_values(config.sub_b(), subValuesB, [](auto) { return T(-99); });
const auto fullValuesC = mix_values(config.sub_c(), subValuesC, [](auto) { return T(-99); });

Matrix<const T, Device::CPU> mat_ah = setMatrix(fullValuesA, dist_a);
Matrix<const T, Device::CPU> mat_bh = setMatrix(fullValuesB, dist_b);
Matrix<T, Device::CPU> mat_ch = setMatrix(fullValuesC, dist_c);

{
MatrixMirror<const T, D, Device::CPU> mat_a(mat_ah);
MatrixMirror<const T, D, Device::CPU> mat_b(mat_bh);
MatrixMirror<T, D, Device::CPU> mat_c(mat_ch);

MatrixRef<const T, D> mat_sub_a(mat_a.get(), config.sub_a());
MatrixRef<const T, D> mat_sub_b(mat_b.get(), config.sub_b());
MatrixRef<T, D> mat_sub_c(mat_c.get(), config.sub_c());

// Note: currently only the NoTrans/NoTrans case is implemented
ASSERT_EQ(config.opA, blas::Op::NoTrans);
ASSERT_EQ(config.opB, blas::Op::NoTrans);
multiplication::internal::generalMatrix<B>(mpi_row_chain, mpi_col_chain, alpha, mat_sub_a, mat_sub_b,
beta, mat_sub_c);
}

const auto fullValuesResult = mix_values(config.sub_c(), subValuesResult, fullValuesC);
CHECK_MATRIX_NEAR(fullValuesResult, mat_ch, 2 * (mat_ah.size().cols() + 1) * TypeUtilities<T>::error,
2 * (mat_ah.size().cols() + 1) * TypeUtilities<T>::error);
}
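On the rank-alignment note above: with a block-cyclic distribution, the process row owning global tile row `i` is `(src_row + i) % grid_rows`, so aligning sub-A with sub-C amounts to solving for A's source rank. A hypothetical sketch of that computation (the actual `align_sub_rank_index` helper may differ):

```cpp
// Hypothetical sketch: choose A's source process row so that the first tile row
// of sub-A is owned by the same process row that owns the first tile row of sub-C.
// Assumes block-cyclic ownership: owner_row = (src_row + global_tile_row) % grid_rows.
int align_rank_row(int owner_row_of_sub_c, int first_tile_row_of_sub_a, int grid_rows) {
  const int r = (owner_row_of_sub_c - first_tile_row_of_sub_a) % grid_rows;
  return (r + grid_rows) % grid_rows;  // normalize into [0, grid_rows)
}
```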

std::vector<GemmConfig> gemm_configs = {
// empty matrices
{blas::Op::NoTrans, blas::Op::NoTrans, 0, 0, 7, 3, 6, 2},
@@ -164,6 +246,7 @@ std::vector<GemmConfig> sub_gemm_configs = {
{blas::Op::NoTrans, blas::Op::NoTrans, 8, 8, 11, 10, 9, 13, {{2, 1}}, {{1, 1}}, {{0, 0}}},
// multi-tile
{blas::Op::NoTrans, blas::Op::NoTrans, 12, 20, 11, 3, 4, 5, {{7, 1}}, {{11, 10}}, {{4, 2}}},
{blas::Op::NoTrans, blas::Op::NoTrans, 12, 20, 11, 3, 4, 5, {{6, 10}}, {{5, 8}}, {{9, 12}}},
};

TYPED_TEST(GeneralMultiplicationTestMC, CorrectnessLocal) {
@@ -311,6 +394,32 @@ void testGeneralSubMultiplication(comm::CommunicatorGrid grid, const SizeType a,
2 * (mat_ah.size().cols() + 1) * TypeUtilities<T>::error);
}

TYPED_TEST(GeneralMultiplicationDistTestMC, CorrectnessDistributed) {
constexpr TypeParam alpha = TypeUtilities<TypeParam>::element(-1.3, .5);
constexpr TypeParam beta = TypeUtilities<TypeParam>::element(-2.6, .7);

for (auto comm_grid : this->commGrids()) {
for (const GemmConfig& test_config : gemm_configs) {
testGeneralMultiplication<TypeParam, Backend::MC, Device::CPU>(alpha, beta, test_config,
comm_grid);
pika::wait();
}
}
}

TYPED_TEST(GeneralMultiplicationDistTestMC, CorrectnessDistributedSub) {
constexpr TypeParam alpha = TypeUtilities<TypeParam>::element(-1.3, .5);
constexpr TypeParam beta = TypeUtilities<TypeParam>::element(-2.6, .7);

for (auto comm_grid : this->commGrids()) {
for (const GemmConfig& test_config : sub_gemm_configs) {
testGeneralMultiplication<TypeParam, Backend::MC, Device::CPU>(alpha, beta, test_config,
comm_grid);
pika::wait();
}
}
}

TYPED_TEST(GeneralSubMultiplicationDistTestMC, CorrectnessDistributed) {
for (auto comm_grid : this->commGrids()) {
for (const auto& [m, mb, a, b] : sizes) {
@@ -324,6 +433,32 @@ TYPED_TEST(GeneralSubMultiplicationDistTestMC, CorrectnessDistributed) {
}

#ifdef DLAF_WITH_GPU
TYPED_TEST(GeneralMultiplicationDistTestGPU, CorrectnessDistributed) {
constexpr TypeParam alpha = TypeUtilities<TypeParam>::element(-1.3, .5);
constexpr TypeParam beta = TypeUtilities<TypeParam>::element(-2.6, .7);

for (auto comm_grid : this->commGrids()) {
for (const GemmConfig& test_config : gemm_configs) {
testGeneralMultiplication<TypeParam, Backend::GPU, Device::GPU>(alpha, beta, test_config,
comm_grid);
pika::wait();
}
}
}

TYPED_TEST(GeneralMultiplicationDistTestGPU, CorrectnessDistributedSub) {
constexpr TypeParam alpha = TypeUtilities<TypeParam>::element(-1.3, .5);
constexpr TypeParam beta = TypeUtilities<TypeParam>::element(-2.6, .7);

for (auto comm_grid : this->commGrids()) {
for (const GemmConfig& test_config : sub_gemm_configs) {
testGeneralMultiplication<TypeParam, Backend::GPU, Device::GPU>(alpha, beta, test_config,
comm_grid);
pika::wait();
}
}
}

TYPED_TEST(GeneralSubMultiplicationDistTestGPU, CorrectnessDistributed) {
for (auto comm_grid : this->commGrids()) {
for (const auto& [m, mb, a, b] : sizes) {