From 16ca60e41235d08a9968d816f82b63288a49df8e Mon Sep 17 00:00:00 2001 From: KRM7 <70973547+KRM7@users.noreply.github.com> Date: Tue, 22 Aug 2023 01:19:11 +0200 Subject: [PATCH] update comparisons for operator<=> --- src/core/ga_base.impl.hpp | 9 +++++- src/population/candidate.hpp | 60 +++--------------------------------- src/utility/math.hpp | 37 +++++++++++----------- test/unit/math.cpp | 48 ++++++++++++++--------------- 4 files changed, 55 insertions(+), 99 deletions(-) diff --git a/src/core/ga_base.impl.hpp b/src/core/ga_base.impl.hpp index 31266af0..f8c0e5fe 100644 --- a/src/core/ga_base.impl.hpp +++ b/src/core/ga_base.impl.hpp @@ -481,7 +481,14 @@ namespace gapp auto optimal_pop = algorithm_->optimalSolutions(*this, pop); optimal_sols = detail::mergeParetoSets(std::move(optimal_sols), std::move(optimal_pop)); - detail::erase_duplicates(optimal_sols); + + /* Duplicate elements are removed from optimal_sols using exact comparison + * of the chromosomes in order to avoid issues with using a non-transitive + * comparison function for std::sort and std::unique. */ + std::ranges::sort(optimal_sols, std::less{}, &Candidate::chromosome); + auto chrom_eq = [](const auto& lhs, const auto& rhs) { return lhs.chromosome == rhs.chromosome; }; + auto last = std::unique(optimal_sols.begin(), optimal_sols.end(), chrom_eq); + optimal_sols.erase(last, optimal_sols.end()); } template diff --git a/src/population/candidate.hpp b/src/population/candidate.hpp index 5b7fa7f0..5a772f59 100644 --- a/src/population/candidate.hpp +++ b/src/population/candidate.hpp @@ -137,27 +137,14 @@ namespace gapp template using CandidatePair = std::pair, Candidate>; - /* Candidates are considered equal if their chromosomes are the same. */ + /** + * Comparison operators based on the chromosomes of the candidates. + * The comparisons are not transitive if T is a floating-point type. + */ template bool operator==(const Candidate& lhs, const Candidate& rhs) noexcept; - template - bool operator!=(const Candidate& lhs, const Candidate& rhs) noexcept; - - /* Lexicographical comparison operators based on the chromosomes of the candidates. */ - - template - bool operator<(const Candidate& lhs, const Candidate& rhs) noexcept; - - template - bool operator<=(const Candidate& lhs, const Candidate& rhs) noexcept; - - template - bool operator>(const Candidate& lhs, const Candidate& rhs) noexcept; - - template - bool operator>=(const Candidate& lhs, const Candidate& rhs) noexcept; /* Hash function for the candidates. */ template @@ -198,45 +185,6 @@ namespace gapp } } - template - bool operator!=(const Candidate& lhs, const Candidate& rhs) noexcept - { - return !(lhs == rhs); - } - - template - bool operator<(const Candidate& lhs, const Candidate& rhs) noexcept - { - if constexpr (std::is_floating_point_v) - { - return std::lexicographical_compare(lhs.chromosome.begin(), lhs.chromosome.end(), - rhs.chromosome.begin(), rhs.chromosome.end(), - math::floatIsLess); - } - else - { - return lhs.chromosome < rhs.chromosome; - } - } - - template - bool operator>=(const Candidate& lhs, const Candidate& rhs) noexcept - { - return !(lhs < rhs); - } - - template - bool operator>(const Candidate& lhs, const Candidate& rhs) noexcept - { - return rhs < lhs; - } - - template - bool operator<=(const Candidate& lhs, const Candidate& rhs) noexcept - { - return !(rhs < lhs); - } - template size_t CandidateHasher::operator()(const Candidate& candidate) const noexcept { diff --git a/src/utility/math.hpp b/src/utility/math.hpp index 97035b3f..fa41595e 100644 --- a/src/utility/math.hpp +++ b/src/utility/math.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -101,33 +102,33 @@ namespace gapp::math }; - /* Comparison function for floating point numbers. Returns -1 if (lhs < rhs), +1 if (lhs > rhs), and 0 if (lhs == rhs). */ + /* Three-way comparison function for floating point numbers. */ template - constexpr std::int8_t floatCompare(T lhs, T rhs) noexcept; + std::weak_ordering floatCompare(T lhs, T rhs) noexcept; /* Equality comparison for floating point numbers. Returns true if lhs is approximately equal to rhs. */ template - constexpr bool floatIsEqual(T lhs, T rhs) noexcept; + bool floatIsEqual(T lhs, T rhs) noexcept; /* Less than comparison for floating point numbers. Returns true if lhs is definitely less than rhs. */ template - constexpr bool floatIsLess(T lhs, T rhs) noexcept; + bool floatIsLess(T lhs, T rhs) noexcept; /* Less than comparison for fp numbers. Assumes that lhs is not greater than rhs. */ template - constexpr bool floatIsLessAssumeNotGreater(T lhs, T rhs) noexcept; + bool floatIsLessAssumeNotGreater(T lhs, T rhs) noexcept; /* Greater than comparison for floating point numbers. Returns true if lhs is definitely greater than rhs. */ template - constexpr bool floatIsGreater(T lhs, T rhs) noexcept; + bool floatIsGreater(T lhs, T rhs) noexcept; /* Less than or equal to comparison for floating point numbers. Returns true if lhs is less than or approximately equal to rhs. */ template - constexpr bool floatIsLessEq(T lhs, T rhs) noexcept; + bool floatIsLessEq(T lhs, T rhs) noexcept; /* Greater than or equal to comparison for floating point numbers. Returns true if lhs is greater than or approximately equal to rhs. */ template - constexpr bool floatIsGreaterEq(T lhs, T rhs) noexcept; + bool floatIsGreaterEq(T lhs, T rhs) noexcept; /* Equality comparison for fp vectors. Returns true if the elements of the ranges are approximately equal. */ template @@ -182,7 +183,7 @@ namespace gapp::math namespace gapp::math { template - constexpr std::int8_t floatCompare(T lhs, T rhs) noexcept + std::weak_ordering floatCompare(T lhs, T rhs) noexcept { GAPP_ASSERT(!std::isnan(lhs) && !std::isnan(rhs)); @@ -190,13 +191,13 @@ namespace gapp::math const T scale = std::min(std::max(std::abs(lhs), std::abs(rhs)), std::numeric_limits::max()); const T tol = std::max(Tolerances::rel(scale), Tolerances::abs()); - if (diff > tol) return 1; // lhs < rhs - if (diff < -tol) return -1; // lhs > rhs - return 0; // lhs == rhs + if (diff > tol) return std::weak_ordering::greater; + if (diff < -tol) return std::weak_ordering::less; + return std::weak_ordering::equivalent; } template - constexpr bool floatIsEqual(T lhs, T rhs) noexcept + bool floatIsEqual(T lhs, T rhs) noexcept { const T scale = std::max(std::abs(lhs), std::abs(rhs)); @@ -206,7 +207,7 @@ namespace gapp::math } template - constexpr bool floatIsLess(T lhs, T rhs) noexcept + bool floatIsLess(T lhs, T rhs) noexcept { const T scale = std::max(std::abs(lhs), std::abs(rhs)); @@ -216,7 +217,7 @@ namespace gapp::math } template - constexpr bool floatIsLessAssumeNotGreater(T lhs, T rhs) noexcept + bool floatIsLessAssumeNotGreater(T lhs, T rhs) noexcept { const T scale = std::abs(rhs); @@ -226,7 +227,7 @@ namespace gapp::math } template - constexpr bool floatIsGreater(T lhs, T rhs) noexcept + bool floatIsGreater(T lhs, T rhs) noexcept { const T scale = std::max(std::abs(lhs), std::abs(rhs)); @@ -236,13 +237,13 @@ namespace gapp::math } template - constexpr bool floatIsLessEq(T lhs, T rhs) noexcept + bool floatIsLessEq(T lhs, T rhs) noexcept { return !floatIsGreater(lhs, rhs); } template - constexpr bool floatIsGreaterEq(T lhs, T rhs) noexcept + bool floatIsGreaterEq(T lhs, T rhs) noexcept { return !floatIsLess(lhs, rhs); } diff --git a/test/unit/math.cpp b/test/unit/math.cpp index 28eb8780..e02189de 100644 --- a/test/unit/math.cpp +++ b/test/unit/math.cpp @@ -150,38 +150,38 @@ TEST_CASE("fp_compare", "[math]") ScopedTolerances _(abs, rel); INFO("Relative tolerance eps: " << rel << ", absolute tolerance: " << abs); - REQUIRE(floatCompare(0.0, 0.0) == 0); - REQUIRE(floatCompare(0.0, -0.0) == 0); - REQUIRE(floatCompare(-0.0, 0.0) == 0); - REQUIRE(floatCompare(-0.0, -0.0) == 0); + REQUIRE((floatCompare(0.0, 0.0) == 0)); + REQUIRE((floatCompare(0.0, -0.0) == 0)); + REQUIRE((floatCompare(-0.0, 0.0) == 0)); + REQUIRE((floatCompare(-0.0, -0.0) == 0)); - REQUIRE(floatCompare(4.0, 4.0) == 0); - REQUIRE(floatCompare(0.0, 4.0) < 0); - REQUIRE(floatCompare(4.0, 0.0) > 0); + REQUIRE((floatCompare(4.0, 4.0) == 0)); + REQUIRE((floatCompare(0.0, 4.0) < 0)); + REQUIRE((floatCompare(4.0, 0.0) > 0)); - REQUIRE(floatCompare(SMALL, SMALL) == 0); - REQUIRE(floatCompare(BIG, BIG) == 0); - REQUIRE(floatCompare(INF, INF) == 0); + REQUIRE((floatCompare(SMALL, SMALL) == 0)); + REQUIRE((floatCompare(BIG, BIG) == 0)); + REQUIRE((floatCompare(INF, INF) == 0)); - REQUIRE(floatCompare(-INF, INF) < 0); - REQUIRE(floatCompare(INF, -INF) > 0); - REQUIRE(floatCompare(INF, INF) == 0); - REQUIRE(floatCompare(-INF, -INF) == 0); + REQUIRE((floatCompare(-INF, INF) < 0)); + REQUIRE((floatCompare(INF, -INF) > 0)); + REQUIRE((floatCompare(INF, INF) == 0)); + REQUIRE((floatCompare(-INF, -INF) == 0)); - REQUIRE(floatCompare(0.0, INF) < 0); - REQUIRE(floatCompare(INF, 0.0) > 0); + REQUIRE((floatCompare(0.0, INF) < 0)); + REQUIRE((floatCompare(INF, 0.0) > 0)); - REQUIRE(floatCompare(0.0, BIG) < 0); - REQUIRE(floatCompare(BIG, 0.0) > 0); + REQUIRE((floatCompare(0.0, BIG) < 0)); + REQUIRE((floatCompare(BIG, 0.0) > 0)); - REQUIRE(floatCompare(SMALL, BIG) < 0); - REQUIRE(floatCompare(BIG, SMALL) > 0); + REQUIRE((floatCompare(SMALL, BIG) < 0)); + REQUIRE((floatCompare(BIG, SMALL) > 0)); - REQUIRE(floatCompare(SMALL, INF) < 0); - REQUIRE(floatCompare(INF, SMALL) > 0); + REQUIRE((floatCompare(SMALL, INF) < 0)); + REQUIRE((floatCompare(INF, SMALL) > 0)); - REQUIRE(floatCompare(BIG, INF) < 0); - REQUIRE(floatCompare(INF, BIG) > 0); + REQUIRE((floatCompare(BIG, INF) < 0)); + REQUIRE((floatCompare(INF, BIG) > 0)); } SECTION("is_less_not_greater")