Skip to content

Commit

Permalink
update comparisons for operator<=>
Browse files Browse the repository at this point in the history
  • Loading branch information
KRM7 committed Aug 26, 2023
1 parent 9b1fecb commit 16ca60e
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 99 deletions.
9 changes: 8 additions & 1 deletion src/core/ga_base.impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,14 @@ namespace gapp
auto optimal_pop = algorithm_->optimalSolutions(*this, pop);

optimal_sols = detail::mergeParetoSets(std::move(optimal_sols), std::move(optimal_pop));
detail::erase_duplicates(optimal_sols);

/* Duplicate elements are removed from optimal_sols using exact comparison
* of the chromosomes in order to avoid issues with using a non-transitive
* comparison function for std::sort and std::unique. */
std::ranges::sort(optimal_sols, std::less{}, &Candidate<T>::chromosome);
auto chrom_eq = [](const auto& lhs, const auto& rhs) { return lhs.chromosome == rhs.chromosome; };
auto last = std::unique(optimal_sols.begin(), optimal_sols.end(), chrom_eq);
optimal_sols.erase(last, optimal_sols.end());
}

template<typename T>
Expand Down
60 changes: 4 additions & 56 deletions src/population/candidate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,14 @@ namespace gapp
template<typename T>
using CandidatePair = std::pair<Candidate<T>, Candidate<T>>;

/* Candidates are considered equal if their chromosomes are the same. */

/**
* Comparison operators based on the chromosomes of the candidates.
* The comparisons are not transitive if T is a floating-point type.
*/
template<typename T>
bool operator==(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept;

template<typename T>
bool operator!=(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept;

/* Lexicographical comparison operators based on the chromosomes of the candidates. */

template<typename T>
bool operator<(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept;

template<typename T>
bool operator<=(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept;

template<typename T>
bool operator>(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept;

template<typename T>
bool operator>=(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept;

/* Hash function for the candidates. */
template<detail::Hashable T>
Expand Down Expand Up @@ -198,45 +185,6 @@ namespace gapp
}
}

template<typename T>
bool operator!=(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept
{
return !(lhs == rhs);
}

template<typename T>
bool operator<(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept
{
if constexpr (std::is_floating_point_v<T>)
{
return std::lexicographical_compare(lhs.chromosome.begin(), lhs.chromosome.end(),
rhs.chromosome.begin(), rhs.chromosome.end(),
math::floatIsLess<T>);
}
else
{
return lhs.chromosome < rhs.chromosome;
}
}

template<typename T>
bool operator>=(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept
{
return !(lhs < rhs);
}

template<typename T>
bool operator>(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept
{
return rhs < lhs;
}

template<typename T>
bool operator<=(const Candidate<T>& lhs, const Candidate<T>& rhs) noexcept
{
return !(rhs < lhs);
}

template<detail::Hashable T>
size_t CandidateHasher<T>::operator()(const Candidate<T>& candidate) const noexcept
{
Expand Down
37 changes: 19 additions & 18 deletions src/utility/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <vector>
#include <span>
#include <atomic>
#include <compare>
#include <concepts>
#include <limits>
#include <cstdint>
Expand Down Expand Up @@ -101,33 +102,33 @@ namespace gapp::math
};


/* Comparison function for floating point numbers. Returns -1 if (lhs < rhs), +1 if (lhs > rhs), and 0 if (lhs == rhs). */
/* Three-way comparison function for floating point numbers. */
template<std::floating_point T>
constexpr std::int8_t floatCompare(T lhs, T rhs) noexcept;
std::weak_ordering floatCompare(T lhs, T rhs) noexcept;

/* Equality comparison for floating point numbers. Returns true if lhs is approximately equal to rhs. */
template<std::floating_point T>
constexpr bool floatIsEqual(T lhs, T rhs) noexcept;
bool floatIsEqual(T lhs, T rhs) noexcept;

/* Less than comparison for floating point numbers. Returns true if lhs is definitely less than rhs. */
template<std::floating_point T>
constexpr bool floatIsLess(T lhs, T rhs) noexcept;
bool floatIsLess(T lhs, T rhs) noexcept;

/* Less than comparison for fp numbers. Assumes that lhs is not greater than rhs. */
template<std::floating_point T>
constexpr bool floatIsLessAssumeNotGreater(T lhs, T rhs) noexcept;
bool floatIsLessAssumeNotGreater(T lhs, T rhs) noexcept;

/* Greater than comparison for floating point numbers. Returns true if lhs is definitely greater than rhs. */
template<std::floating_point T>
constexpr bool floatIsGreater(T lhs, T rhs) noexcept;
bool floatIsGreater(T lhs, T rhs) noexcept;

/* Less than or equal to comparison for floating point numbers. Returns true if lhs is less than or approximately equal to rhs. */
template<std::floating_point T>
constexpr bool floatIsLessEq(T lhs, T rhs) noexcept;
bool floatIsLessEq(T lhs, T rhs) noexcept;

/* Greater than or equal to comparison for floating point numbers. Returns true if lhs is greater than or approximately equal to rhs. */
template<std::floating_point T>
constexpr bool floatIsGreaterEq(T lhs, T rhs) noexcept;
bool floatIsGreaterEq(T lhs, T rhs) noexcept;

/* Equality comparison for fp vectors. Returns true if the elements of the ranges are approximately equal. */
template<std::floating_point T>
Expand Down Expand Up @@ -182,21 +183,21 @@ namespace gapp::math
namespace gapp::math
{
template<std::floating_point T>
constexpr std::int8_t floatCompare(T lhs, T rhs) noexcept
std::weak_ordering floatCompare(T lhs, T rhs) noexcept
{
GAPP_ASSERT(!std::isnan(lhs) && !std::isnan(rhs));

const T diff = lhs - rhs;
const T scale = std::min(std::max(std::abs(lhs), std::abs(rhs)), std::numeric_limits<T>::max());
const T tol = std::max(Tolerances::rel<T>(scale), Tolerances::abs<T>());

if (diff > tol) return 1; // lhs < rhs
if (diff < -tol) return -1; // lhs > rhs
return 0; // lhs == rhs
if (diff > tol) return std::weak_ordering::greater;
if (diff < -tol) return std::weak_ordering::less;
return std::weak_ordering::equivalent;
}

template<std::floating_point T>
constexpr bool floatIsEqual(T lhs, T rhs) noexcept
bool floatIsEqual(T lhs, T rhs) noexcept
{
const T scale = std::max(std::abs(lhs), std::abs(rhs));

Expand All @@ -206,7 +207,7 @@ namespace gapp::math
}

template<std::floating_point T>
constexpr bool floatIsLess(T lhs, T rhs) noexcept
bool floatIsLess(T lhs, T rhs) noexcept
{
const T scale = std::max(std::abs(lhs), std::abs(rhs));

Expand All @@ -216,7 +217,7 @@ namespace gapp::math
}

template<std::floating_point T>
constexpr bool floatIsLessAssumeNotGreater(T lhs, T rhs) noexcept
bool floatIsLessAssumeNotGreater(T lhs, T rhs) noexcept
{
const T scale = std::abs(rhs);

Expand All @@ -226,7 +227,7 @@ namespace gapp::math
}

template<std::floating_point T>
constexpr bool floatIsGreater(T lhs, T rhs) noexcept
bool floatIsGreater(T lhs, T rhs) noexcept
{
const T scale = std::max(std::abs(lhs), std::abs(rhs));

Expand All @@ -236,13 +237,13 @@ namespace gapp::math
}

template<std::floating_point T>
constexpr bool floatIsLessEq(T lhs, T rhs) noexcept
bool floatIsLessEq(T lhs, T rhs) noexcept
{
return !floatIsGreater(lhs, rhs);
}

template<std::floating_point T>
constexpr bool floatIsGreaterEq(T lhs, T rhs) noexcept
bool floatIsGreaterEq(T lhs, T rhs) noexcept
{
return !floatIsLess(lhs, rhs);
}
Expand Down
48 changes: 24 additions & 24 deletions test/unit/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,38 +150,38 @@ TEST_CASE("fp_compare", "[math]")
ScopedTolerances _(abs, rel);
INFO("Relative tolerance eps: " << rel << ", absolute tolerance: " << abs);

REQUIRE(floatCompare(0.0, 0.0) == 0);
REQUIRE(floatCompare(0.0, -0.0) == 0);
REQUIRE(floatCompare(-0.0, 0.0) == 0);
REQUIRE(floatCompare(-0.0, -0.0) == 0);
REQUIRE((floatCompare(0.0, 0.0) == 0));
REQUIRE((floatCompare(0.0, -0.0) == 0));
REQUIRE((floatCompare(-0.0, 0.0) == 0));
REQUIRE((floatCompare(-0.0, -0.0) == 0));

REQUIRE(floatCompare(4.0, 4.0) == 0);
REQUIRE(floatCompare(0.0, 4.0) < 0);
REQUIRE(floatCompare(4.0, 0.0) > 0);
REQUIRE((floatCompare(4.0, 4.0) == 0));
REQUIRE((floatCompare(0.0, 4.0) < 0));
REQUIRE((floatCompare(4.0, 0.0) > 0));

REQUIRE(floatCompare(SMALL, SMALL) == 0);
REQUIRE(floatCompare(BIG, BIG) == 0);
REQUIRE(floatCompare(INF, INF) == 0);
REQUIRE((floatCompare(SMALL, SMALL) == 0));
REQUIRE((floatCompare(BIG, BIG) == 0));
REQUIRE((floatCompare(INF, INF) == 0));

REQUIRE(floatCompare(-INF, INF) < 0);
REQUIRE(floatCompare(INF, -INF) > 0);
REQUIRE(floatCompare(INF, INF) == 0);
REQUIRE(floatCompare(-INF, -INF) == 0);
REQUIRE((floatCompare(-INF, INF) < 0));
REQUIRE((floatCompare(INF, -INF) > 0));
REQUIRE((floatCompare(INF, INF) == 0));
REQUIRE((floatCompare(-INF, -INF) == 0));

REQUIRE(floatCompare(0.0, INF) < 0);
REQUIRE(floatCompare(INF, 0.0) > 0);
REQUIRE((floatCompare(0.0, INF) < 0));
REQUIRE((floatCompare(INF, 0.0) > 0));

REQUIRE(floatCompare(0.0, BIG) < 0);
REQUIRE(floatCompare(BIG, 0.0) > 0);
REQUIRE((floatCompare(0.0, BIG) < 0));
REQUIRE((floatCompare(BIG, 0.0) > 0));

REQUIRE(floatCompare(SMALL, BIG) < 0);
REQUIRE(floatCompare(BIG, SMALL) > 0);
REQUIRE((floatCompare(SMALL, BIG) < 0));
REQUIRE((floatCompare(BIG, SMALL) > 0));

REQUIRE(floatCompare(SMALL, INF) < 0);
REQUIRE(floatCompare(INF, SMALL) > 0);
REQUIRE((floatCompare(SMALL, INF) < 0));
REQUIRE((floatCompare(INF, SMALL) > 0));

REQUIRE(floatCompare(BIG, INF) < 0);
REQUIRE(floatCompare(INF, BIG) > 0);
REQUIRE((floatCompare(BIG, INF) < 0));
REQUIRE((floatCompare(INF, BIG) > 0));
}

SECTION("is_less_not_greater")
Expand Down

0 comments on commit 16ca60e

Please sign in to comment.