Skip to content

Commit

Permalink
matrix implementation improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
KRM7 committed Feb 7, 2024
1 parent ed6b9c0 commit aa039a0
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 116 deletions.
2 changes: 1 addition & 1 deletion src/algorithm/reference_lines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ namespace gapp::algorithm::reflines

FitnessMatrix pickSparseSubset(size_t dim, size_t num_points, RefLineGenerator generator, Positive<size_t> k)
{
if (num_points == 0) return {};
if (dim * num_points == 0) return {};

FitnessMatrix candidate_points = generator(dim, k * num_points);

Expand Down
93 changes: 31 additions & 62 deletions src/utility/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace gapp::detail
friend class MatrixRowBase<RowRef, Matrix>;
friend class MatrixRowBase<ConstRowRef, const Matrix>;

using value_type = RowRef;
using value_type = std::vector<T, A>;
using allocator_type = A;
using storage_type = std::vector<T, A>;

Expand All @@ -54,18 +54,18 @@ namespace gapp::detail

/* Iterators */

class ConstRowIterator : public stable_iterator_base<ConstRowIterator, const Matrix, ConstRowRef, ConstRowRef, detail::proxy_ptr<RowRef>, difference_type>
class ConstRowIterator : public stable_iterator_base<ConstRowIterator, const Matrix, value_type, ConstRowRef, detail::proxy_ptr<ConstRowRef>, difference_type>
{
public:
using my_base_ = stable_iterator_base<ConstRowIterator, const Matrix, ConstRowRef, ConstRowRef, detail::proxy_ptr<RowRef>, difference_type>;
using my_base_ = stable_iterator_base<ConstRowIterator, const Matrix, value_type, ConstRowRef, detail::proxy_ptr<ConstRowRef>, difference_type>;
using my_base_::my_base_;
using typename my_base_::iterator_category;
};

class RowIterator : public stable_iterator_base<RowIterator, Matrix, RowRef, RowRef, detail::proxy_ptr<RowRef>, difference_type>
class RowIterator : public stable_iterator_base<RowIterator, Matrix, value_type, RowRef, detail::proxy_ptr<RowRef>, difference_type>
{
public:
using my_base_ = stable_iterator_base<RowIterator, Matrix, RowRef, RowRef, detail::proxy_ptr<RowRef>, difference_type>;
using my_base_ = stable_iterator_base<RowIterator, Matrix, value_type, RowRef, detail::proxy_ptr<RowRef>, difference_type>;
using my_base_::my_base_;
using typename my_base_::iterator_category;

Expand All @@ -78,10 +78,10 @@ namespace gapp::detail
using const_reverse_iterator = std::reverse_iterator<ConstRowIterator>;

constexpr iterator begin() noexcept { return iterator(*this, 0); }
constexpr iterator end() noexcept { return iterator(*this, nrows_); }
constexpr iterator end() noexcept { return iterator(*this, nrows()); }

constexpr const_iterator begin() const noexcept { return const_iterator(*this, 0); }
constexpr const_iterator end() const noexcept { return const_iterator(*this, nrows_); }
constexpr const_iterator end() const noexcept { return const_iterator(*this, nrows()); }

/* Special member functions */

Expand All @@ -93,15 +93,15 @@ namespace gapp::detail
~Matrix() = default;

constexpr explicit Matrix(const A& alloc) :
data_(alloc), nrows_(0), ncols_(0)
data_(alloc), ncols_(0)
{}

constexpr Matrix(size_type nrows, size_type ncols) :
data_(nrows * ncols), nrows_(nrows), ncols_(ncols)
data_(nrows * ncols), ncols_(ncols)
{}

constexpr Matrix(size_type nrows, size_type ncols, const T& init, const A& alloc = A()) :
data_(nrows * ncols, init, alloc), nrows_(nrows), ncols_(ncols)
data_(nrows * ncols, init, alloc), ncols_(ncols)
{}

constexpr Matrix(std::initializer_list<std::initializer_list<T>> mat);
Expand All @@ -112,21 +112,21 @@ namespace gapp::detail

constexpr RowRef operator[](size_type row) noexcept
{
GAPP_ASSERT(row < nrows_, "Row index out of bounds.");
GAPP_ASSERT(row < nrows(), "Row index out of bounds.");

return RowRef(*this, row);
}

constexpr ConstRowRef operator[](size_type row) const noexcept
{
GAPP_ASSERT(row < nrows_, "Row index out of bounds.");
GAPP_ASSERT(row < nrows(), "Row index out of bounds.");

return ConstRowRef(*this, row);
}

constexpr storage_type column(size_type col_idx) const
{
GAPP_ASSERT(col_idx < ncols_, "Col index out of bounds.");
GAPP_ASSERT(col_idx < ncols(), "Col index out of bounds.");

storage_type col;
col.reserve(nrows());
Expand All @@ -141,18 +141,18 @@ namespace gapp::detail

constexpr element_reference operator()(size_type row, size_type col) noexcept
{
GAPP_ASSERT(row < nrows_, "Row index out of bounds.");
GAPP_ASSERT(col < ncols_, "Col index out of bounds.");
GAPP_ASSERT(row < nrows(), "Row index out of bounds.");
GAPP_ASSERT(col < ncols(), "Col index out of bounds.");

return data_[row * ncols_ + col];
return data_[row * ncols() + col];
}

constexpr const_element_reference operator()(size_type row, size_type col) const noexcept
{
GAPP_ASSERT(row < nrows_, "Row index out of bounds.");
GAPP_ASSERT(col < ncols_, "Col index out of bounds.");
GAPP_ASSERT(row < nrows(), "Row index out of bounds.");
GAPP_ASSERT(col < ncols(), "Col index out of bounds.");

return data_[row * ncols_ + col];
return data_[row * ncols() + col];
}

constexpr RowRef front() noexcept { return *begin(); }
Expand All @@ -169,13 +169,10 @@ namespace gapp::detail
constexpr void append_row(std::span<const T> row);
constexpr void pop_back() noexcept { resize(nrows() - 1, ncols()); } // NOLINT(*exception-escape)

constexpr iterator erase(const_iterator row);
constexpr iterator erase(const_iterator first, const_iterator last);

/* Size / capacity */

constexpr size_type size() const noexcept { return nrows_; } /* For the bounds checking in stable_iterator */
constexpr size_type nrows() const noexcept { return nrows_; }
constexpr size_type size() const noexcept { return nrows(); }
constexpr size_type nrows() const noexcept { return empty() ? 0 : (data_.size() / ncols_); }
constexpr size_type ncols() const noexcept { return ncols_; }
constexpr bool empty() const noexcept { return data_.empty(); }

Expand All @@ -184,35 +181,32 @@ namespace gapp::detail
constexpr void resize(size_type nrows, size_type ncols, const T& val = {})
{
data_.resize(nrows * ncols, val);
nrows_ = nrows;
ncols_ = ncols;
}

constexpr void clear() noexcept
{
data_.clear();
nrows_ = 0;
ncols_ = 0;
}

/* Other */

friend constexpr bool operator==(const Matrix& lhs, const Matrix& rhs)
{
return (lhs.empty() && rhs.empty()) ||
(lhs.nrows_ == rhs.nrows_ && lhs.ncols_ == rhs.ncols_ &&
(lhs.nrows() == rhs.nrows() && lhs.ncols() == rhs.ncols() &&
lhs.data_ == rhs.data_);
}

constexpr void swap(Matrix& other) noexcept
{
std::swap(data_, other.data_);
std::swap(nrows_, other.nrows_);
std::swap(ncols_, other.ncols_);
}

private:
storage_type data_;
size_type nrows_ = 0;
size_type ncols_ = 0;
};

Expand Down Expand Up @@ -264,13 +258,11 @@ namespace gapp::detail
}

constexpr size_type ncols() const noexcept { return mat_->ncols(); }

template<typename alloc_type>
explicit operator std::vector<value_type, alloc_type>() const { return std::vector<value_type, alloc_type>(begin(), end()); }

/* implicit */ operator std::span<copy_const_t<MatrixType, value_type>>() const { return { begin(), end() }; }
template<typename alloc_type>
operator std::vector<value_type, alloc_type>() const { return std::vector<value_type, alloc_type>(begin(), end()); }

/* Comparison operators */
operator std::span<copy_const_t<MatrixType, value_type>>() const { return { begin(), end() }; }

constexpr friend bool operator==(const Derived& lhs, const Derived& rhs)
{
Expand Down Expand Up @@ -406,16 +398,15 @@ namespace gapp::detail
/* MATRIX IMPLEMENTATION */

template<typename T, typename A>
constexpr Matrix<T, A>::Matrix(std::initializer_list<std::initializer_list<T>> mat) :
data_(), nrows_(mat.size()), ncols_(0)
constexpr Matrix<T, A>::Matrix(std::initializer_list<std::initializer_list<T>> mat)
{
if (mat.size() == 0) return;

ncols_ = mat.begin()->size();

GAPP_ASSERT(std::all_of(mat.begin(), mat.end(), detail::is_size(ncols_)), "Unequal row sizes in the input matrix.");

data_.reserve(nrows_ * ncols_);
data_.reserve(mat.size() * ncols_);
for (auto& row : mat)
{
for (auto& entry : row)
Expand All @@ -430,39 +421,17 @@ namespace gapp::detail
{
if (std::distance(first, last) <= 0) return;

nrows_ = std::distance(first, last);
ncols_ = first->size();
data_ = storage_type(first->begin(), first->begin() + nrows_ * ncols_);
data_ = storage_type(first->begin(), first->begin() + std::distance(first, last) * ncols_);
}

template<typename T, typename A>
constexpr void Matrix<T, A>::append_row(std::span<const T> row)
{
GAPP_ASSERT(row.size() == ncols_ || nrows_ == 0, "Can't insert row with different column count.");
GAPP_ASSERT(row.size() == ncols() || nrows() == 0, "Can't insert row with different column count.");

data_.insert(data_.end(), row.begin(), row.end());
if (nrows_ == 0) ncols_ = row.size();
nrows_++;
}

template<typename T, typename A>
constexpr auto Matrix<T, A>::erase(const_iterator row) -> iterator
{
const auto last_removed = data_.erase(row->begin(), row->end());
--nrows_;

return begin() + std::distance(data_.begin(), last_removed) / ncols_;
}

template<typename T, typename A>
constexpr auto Matrix<T, A>::erase(const_iterator first, const_iterator last) -> iterator
{
const auto data_last = last != end() ? last->begin() : data_.end(); // can't dereference last iter

const auto last_removed = data_.erase(first->begin(), data_last);
nrows_ -= std::distance(first, last);

return begin() + std::distance(data_.begin(), last_removed) / ncols_;
ncols_ = row.size();
}

} // namespace gapp::detail
Expand Down
66 changes: 19 additions & 47 deletions test/unit/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,14 @@ TEST_CASE("matrix", "[matrix]")

SECTION("iterator types")
{
using Row = Matrix<double>::RowRef;
using ConstRow = Matrix<double>::ConstRowRef;

using Iterator = Matrix<double>::iterator;
using ConstIterator = Matrix<double>::const_iterator;

STATIC_REQUIRE(std::is_same_v<Iterator::value_type, Row>);
STATIC_REQUIRE(std::is_same_v<Iterator::reference, Row>);
STATIC_REQUIRE(std::is_same_v<Iterator::value_type, Matrix<double>::value_type>);
STATIC_REQUIRE(std::is_same_v<Iterator::reference, Matrix<double>::reference>);

STATIC_REQUIRE(std::is_same_v<ConstIterator::value_type, ConstRow>);
STATIC_REQUIRE(std::is_same_v<ConstIterator::reference, ConstRow>);
STATIC_REQUIRE(std::is_same_v<ConstIterator::value_type, Matrix<double>::value_type>);
STATIC_REQUIRE(std::is_same_v<ConstIterator::reference, Matrix<double>::const_reference>);
}

Matrix<int> mat1;
Expand Down Expand Up @@ -128,30 +125,6 @@ TEST_CASE("matrix", "[matrix]")
REQUIRE(mat1 == Matrix{ { 2, 3, 4, 1 } });
}

SECTION("erase rows")
{
// erase single row
const auto last1 = mat2.erase(mat2.begin());

REQUIRE(mat2.nrows() == 1);
REQUIRE(mat2 == Matrix{ { 4, 5, 6 } });
REQUIRE(last1 == mat2.begin());

const auto last2 = mat2.erase(mat2.begin());

REQUIRE(mat2.nrows() == 0);
REQUIRE(mat2 == mat1);
REQUIRE(last2 == mat2.end());

// erase multiple rows
mat2 = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };

const auto last3 = mat2.erase(mat2.begin() + 1, mat2.end());

REQUIRE(mat2 == Matrix{ { 1, 2, 3 } });
REQUIRE(last3 == mat2.end());
}

SECTION("swap")
{
using std::swap;
Expand Down Expand Up @@ -262,29 +235,28 @@ TEST_CASE("matrix_rows", "[matrix]")
REQUIRE(mat == Matrix{ { 1, 2, 3 }, { 1, 2, 3 }, { 7, 8, 9 } });
REQUIRE(vec == std::vector{ 4, 5, 6 });
}

SECTION("row swaps temp")
{
Matrix<int>::value_type temp = row1;
row1 = row2;
row2 = temp;

REQUIRE(mat == Matrix{ { 4, 5, 6 }, { 1, 2, 3 }, { 7, 8, 9 } });
}
}

TEST_CASE("matrix_algorithms", "[matrix]")
{
Matrix mat1 = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
Matrix mat2 = { { 37, 40, 13 }, { 14, 4, 0 }, { 8, -1, 9 } };

for (auto row : mat1)
{
for (int& entry : row)
{
entry = 0;
}
}

const bool all_zero = std::all_of(mat1.cbegin(), mat1.cend(), [](auto row)
{
return std::all_of(row.begin(), row.end(), equal_to(0));
});

REQUIRE(all_zero);

std::copy(mat2.begin(), mat2.end(), mat1.begin());

REQUIRE(std::equal(mat1.begin(), mat1.end(), mat2.cbegin(), mat2.cend()));

std::sort(mat1.begin(), mat1.end(), [](const auto& lhs, const auto& rhs) { return lhs[0] < rhs[0]; });
REQUIRE(mat1 == Matrix{ { 8, -1, 9 }, { 14, 4, 0 }, { 37, 40, 13 } });

std::reverse(mat1.begin(), mat1.end());
REQUIRE(mat1 == Matrix{ { 37, 40, 13 }, { 14, 4, 0 }, { 8, -1, 9 } });
}
8 changes: 2 additions & 6 deletions test/unit/reference_lines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using namespace gapp::algorithm::reflines;
TEMPLATE_TEST_CASE_SIG("reference_lines", "[pareto_front]", ((auto F), F), quasirandomSimplexPointsMirror, quasirandomSimplexPointsSort, quasirandomSimplexPointsRoot, quasirandomSimplexPointsLog)
{
const size_t num_points = GENERATE(0, 1, 10);
const size_t dim = GENERATE(0, 1, 2, 3, 100);
const size_t dim = GENERATE(1, 2, 3, 100);

INFO("Dimensions: " << dim);

Expand All @@ -30,8 +30,6 @@ TEMPLATE_TEST_CASE_SIG("reference_lines", "[pareto_front]", ((auto F), F), quasi
REQUIRE(points.size() == num_points);
REQUIRE(std::all_of(points.begin(), points.end(), is_size(dim)));

if (dim == 0) return;

for (const auto& point : points)
{
CAPTURE(point);
Expand All @@ -43,7 +41,7 @@ TEMPLATE_TEST_CASE_SIG("reference_lines", "[pareto_front]", ((auto F), F), quasi
TEMPLATE_TEST_CASE_SIG("reference_lines_subset", "[pareto_front]", ((auto F), F), quasirandomSimplexPointsMirror, quasirandomSimplexPointsSort, quasirandomSimplexPointsRoot, quasirandomSimplexPointsLog)
{
const size_t num_points = GENERATE(0, 1, 10);
const size_t dim = GENERATE(0, 1, 2, 3, 100);
const size_t dim = GENERATE(1, 2, 3, 100);

INFO("Dimensions: " << dim);

Expand All @@ -52,8 +50,6 @@ TEMPLATE_TEST_CASE_SIG("reference_lines_subset", "[pareto_front]", ((auto F), F)
REQUIRE(points.size() == num_points);
REQUIRE(std::all_of(points.begin(), points.end(), is_size(dim)));

if (dim == 0) return;

for (const auto& point : points)
{
CAPTURE(point);
Expand Down

0 comments on commit aa039a0

Please sign in to comment.