now retiled matrix is just a simple matrix + fix const correctness
albestro committed Jul 19, 2023
1 parent 26afc1d commit 50f89dc
Showing 2 changed files with 16 additions and 14 deletions.
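In short: a retiled view is no longer a dedicated RetiledMatrix type but a plain Matrix obtained through the sub-pipeline API, which also lets the source side be passed as const. A minimal before/after sketch mirroring the lines changed in copy.h below (namespaces omitted, as inside copy.h; scale_factor_src/dst are the LocalTileSize values computed in the diff):

    // Before this commit: a separate wrapper type (its header is no longer included).
    //   RetiledMatrix<T, Source> src_retiled(src, scale_factor_src);
    //   RetiledMatrix<T, Destination> dst_retiled(dst, scale_factor_dst);

    // After this commit: the retiled view is itself a Matrix taken from a sub-pipeline,
    // with a const overload for read-only access to the source.
    Matrix<const T, Source> src_retiled = src.retiledSubPipelineConst(scale_factor_src);
    Matrix<T, Destination> dst_retiled = dst.retiledSubPipeline(scale_factor_dst);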
15 changes: 7 additions & 8 deletions include/dlaf/matrix/copy.h
@@ -20,7 +20,6 @@
#include <dlaf/communication/kernels/p2p.h>
#include <dlaf/matrix/copy_tile.h>
#include <dlaf/matrix/matrix.h>
#include <dlaf/matrix/retiled_matrix.h>
#include <dlaf/types.h>
#include <dlaf/util_matrix.h>

@@ -86,18 +85,18 @@ void copy(Matrix<const T, Source>& source, Matrix<T, Destination>& dest) {
/// @pre nb = min(src.blockSize().cols(), dst.blockSize().cols())
/// src.blockSize().cols() % nb == 0
/// dst.blockSize().cols() % nb == 0
/// @pre src has equal tile and block sizes.
/// @pre dst has equal tile and block sizes.
template <class T, Device Source, Device Destination>
void copy(Matrix<T, Source>& src, // TODO this should be const
Matrix<T, Destination>& dst, comm::CommunicatorGrid grid) {
void copy(Matrix<const T, Source>& src, Matrix<T, Destination>& dst, comm::CommunicatorGrid grid) {
namespace ex = pika::execution::experimental;

DLAF_ASSERT_MODERATE(equal_size(src, dst), src.size(), dst.size());
DLAF_ASSERT_MODERATE(equal_process_grid(src, grid), src.commGridSize(), grid.size());
DLAF_ASSERT_MODERATE(equal_process_grid(dst, grid), dst.commGridSize(), grid.size());

// TODO Currently multiple tile per blocks cannot be tested, as Matrix does not support it yet.
DLAF_ASSERT_MODERATE(src.baseTileSize() == src.blockSize(), src.baseTileSize(), src.blockSize());
DLAF_ASSERT_MODERATE(dst.baseTileSize() == dst.blockSize(), dst.baseTileSize(), dst.blockSize());
DLAF_ASSERT_MODERATE(single_tile_per_block(src), src);
DLAF_ASSERT_MODERATE(single_tile_per_block(dst), dst);

// Note:
// From an algorithmic point of view it would be better to reason in terms of block instead of tiles,
@@ -120,8 +119,8 @@ void copy(Matrix<T, Source>& src, // TODO this should be const
const LocalTileSize scale_factor_src{tile_size_src.rows() / mb, tile_size_src.cols() / nb};
const LocalTileSize scale_factor_dst{tile_size_dst.rows() / mb, tile_size_dst.cols() / nb};

RetiledMatrix<T, Source> src_retiled(src, scale_factor_src); // TODO this should be const
RetiledMatrix<T, Destination> dst_retiled(dst, scale_factor_dst);
Matrix<const T, Source> src_retiled = src.retiledSubPipelineConst(scale_factor_src);
Matrix<T, Destination> dst_retiled = dst.retiledSubPipeline(scale_factor_dst);

const comm::Index2D rank = grid.rank();
common::Pipeline<comm::Communicator> comm_pipeline(grid.fullCommunicator().clone());
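To make the retiling step above concrete: with src.blockSize() = {10, 10} and dst.blockSize() = {5, 10} (and mb defined on the rows analogously to nb, in the context not shown here), mb = 5 and nb = 10, so scale_factor_src = {2, 1} and scale_factor_dst = {1, 1}; both matrices are then viewed with a common 5 x 10 tile size for the tile-by-tile copy.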
15 changes: 9 additions & 6 deletions test/unit/matrix/test_matrix.cpp
@@ -1840,7 +1840,7 @@ struct TestReshuffling {
const TileElementSize dst_tilesize;
};
std::vector<TestReshuffling> sizes_reshuffling_tests{
TestReshuffling{{10, 10}, {3, 3}, {3, 3}}, // no-reshuffling
TestReshuffling{{10, 10}, {3, 3}, {3, 3}}, // same shape
TestReshuffling{{10, 5}, {5, 10}, {10, 2}}, // x2 | /5
TestReshuffling{{26, 13}, {10, 3}, {5, 6}}, // /2 | x2
};
@@ -1856,14 +1856,17 @@ void testReshuffling(const TestReshuffling& config, CommunicatorGrid grid) {
std::min(grid.size().cols() - 1, dlaf::util::ceilDiv(grid.size().cols(), 2)));
matrix::Distribution dist_dst(size, dst_tilesize, grid.size(), grid.rank(), origin_rank_dst);

matrix::Matrix<T, Device::CPU> src_host(dist_src); // TODO this should be const
matrix::Matrix<T, Device::CPU> dst_host(dist_dst);

auto fixedValues = [](const GlobalElementIndex index) { return T(index.row() * 1000 + index.col()); };
matrix::util::set(src_host, fixedValues);

matrix::Matrix<const T, Device::CPU> src_host = [dist_src, fixedValues]() {
matrix::Matrix<T, Device::CPU> src_host(dist_src);
matrix::util::set(src_host, fixedValues);
return src_host;
}();
matrix::Matrix<T, Device::CPU> dst_host(dist_dst);

{
matrix::MatrixMirror<T, Source, Device::CPU> src(src_host);
matrix::MatrixMirror<const T, Source, Device::CPU> src(src_host);
matrix::MatrixMirror<T, Destination, Device::CPU> dst(dst_host);
matrix::copy(src.get(), dst.get(), grid);
}
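The test now builds the source as a matrix of const elements through an immediately invoked lambda: the mutable matrix is created and filled inside the lambda and the result is moved into a Matrix<const T>, so everything after initialization can only read it. The same idiom in plain standard C++ (a generic sketch, not tied to DLA-Future):

    #include <vector>

    // Initialize a const object through an immediately invoked lambda:
    // all mutation happens inside the lambda; the result is read-only afterwards.
    const std::vector<int> squares = [] {
      std::vector<int> v(10);
      for (int i = 0; i < 10; ++i)
        v[i] = i * i;
      return v;
    }();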
