Skip to content

Commit

Permalink
drop grid from signature
Browse files Browse the repository at this point in the history
since we can get the same information directly from matrix
  • Loading branch information
albestro committed Nov 28, 2023
1 parent 2eaee33 commit 4f9514d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 13 deletions.
14 changes: 4 additions & 10 deletions include/dlaf/multiplication/general.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,28 +68,22 @@ void generalMatrix(const blas::Op opA, const blas::Op opB, const T alpha, Matrix
/// C = alpha * A * B + beta * C
///
/// @param mat_a contains the input matrix A.
/// @pre @p mat_a is distributed according to @p grid
///
/// @param mat_b contains the input matrix B.
/// @pre @p mat_b is distributed according to @p grid
///
/// @param mat_c On entry it contains the input matrix C. On exit matrix tiles in the range will be
/// overwritten with the result, while others are left untouched.
/// @pre @p mat_c is distributed according to @p grid
///
/// @pre @p mat_a, @p mat_b and @p mat_c are distributed the same way,
/// @pre multipliable_sizes(mat_a.size(), mat_b.size(), mat_c.size(), opA, opB)
/// @pre multipliable_sizes(mat_a.tile_size(), mat_b.tile_size(), mat_c.tile_size(), opA, opB)
/// @pre multipliable_sizes(mat_a.tile_size_of({0, 0}), mat_b.tile_size_of({0, 0}),
/// mat_c.tile_size_of({0, 0}), opA, opB)
template <Backend B, Device D, class T>
void generalMatrix([[maybe_unused]] comm::CommunicatorGrid grid,
common::Pipeline<comm::Communicator>& row_task_chain,
void generalMatrix(common::Pipeline<comm::Communicator>& row_task_chain,
common::Pipeline<comm::Communicator>& col_task_chain, const T alpha,
MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b, const T beta,
MatrixRef<T, D>& mat_c) {
DLAF_ASSERT(matrix::equal_process_grid(mat_a, grid), mat_a, grid);
DLAF_ASSERT(matrix::equal_process_grid(mat_b, grid), mat_a, grid);
DLAF_ASSERT(matrix::equal_process_grid(mat_c, grid), mat_a, grid);
DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_a), mat_c, mat_b);
DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_b), mat_c, mat_b);

DLAF_ASSERT_HEAVY(matrix::multipliable(mat_a, mat_b, mat_c, blas::Op::NoTrans, blas::Op::NoTrans),
mat_a, mat_b, mat_c);
Expand Down
2 changes: 1 addition & 1 deletion include/dlaf/util_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ bool equal_process_grid(const MatrixLike<const T, D>& m, const comm::Communicato
return m.commGridSize() == g.size() && m.rankIndex() == g.rank();
}

/// Returns true if the matrix is distributed on the communication grid.
/// Returns true if the two matrices are distributed on the same grid
template <template <class, Device> class MatrixLikeA, template <class, Device> class MatrixLikeB,
class T, Device D>
bool same_process_grid(const MatrixLikeA<const T, D>& a, const MatrixLikeB<const T, D>& b) noexcept {
Expand Down
4 changes: 2 additions & 2 deletions test/unit/multiplication/test_multiplication_general.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ void testGeneralMultiplication(const T alpha, const T beta, const GemmConfig& co
// Note: currently it is implemented just the NoTrans/NoTrans case
ASSERT_EQ(config.opA, blas::Op::NoTrans);
ASSERT_EQ(config.opB, blas::Op::NoTrans);
multiplication::internal::generalMatrix<B>(grid, mpi_row_chain, mpi_col_chain, alpha, mat_sub_a,
mat_sub_b, beta, mat_sub_c);
multiplication::internal::generalMatrix<B>(mpi_row_chain, mpi_col_chain, alpha, mat_sub_a, mat_sub_b,
beta, mat_sub_c);
}

const auto fullValuesResult = mix_values(config.sub_c(), subValuesResult, fullValuesC);
Expand Down

0 comments on commit 4f9514d

Please sign in to comment.