From 40d30c281a571f58ab4c4d6695f5b01d75cc0653 Mon Sep 17 00:00:00 2001 From: KRM7 <70973547+KRM7@users.noreply.github.com> Date: Sun, 20 Aug 2023 11:33:30 +0200 Subject: [PATCH] simplify the inteface of the metric classes --- docs/metrics.md | 10 ++------- examples/9_metrics.cpp | 2 -- src/metrics/distribution_metrics.cpp | 21 ------------------ src/metrics/distribution_metrics.hpp | 6 ----- src/metrics/fitness_metrics.hpp | 7 ------ src/metrics/metric_set.hpp | 5 ++--- src/metrics/misc_metrics.cpp | 7 ------ src/metrics/misc_metrics.hpp | 3 --- src/metrics/monitor.hpp | 33 ++++++++++------------------ test/unit/metrics.cpp | 6 ++--- 10 files changed, 18 insertions(+), 82 deletions(-) diff --git a/docs/metrics.md b/docs/metrics.md index 86597b86..8ba1b315 100644 --- a/docs/metrics.md +++ b/docs/metrics.md @@ -64,20 +64,14 @@ are intended to be used for multi-objective optimization problems. If you want to track something that doesn't have a metric already implemented for it, it's possible to implement your own metrics. Metrics must be derived -from the `Monitor` class, and implement 3 methods: `initialize`, `update`, and -`value_at`: +from the `Monitor` class, and implement the `initialize`, and `update` methods: ```cpp // The second type parameter of Monitor is the type used to store the // gathered metrics. This will be used as the type of the data_ field. class MyMetric : public metrics::Monitor> { -public: - // Returns the value of the metric in the given generation. - // Note that this method is not virtual. - double value_at(size_t generation) const noexcept { return data_[generation]; } -private: - // Initialize the metric. Called at the start of a run. + // (optional) Initialize the metric. Called at the start of a run. void initialize(const GaInfo& ga) override { data_.clear(); } // Update the metric with a new value from the current generation. diff --git a/examples/9_metrics.cpp b/examples/9_metrics.cpp index a9eb9913..b38d6775 100644 --- a/examples/9_metrics.cpp +++ b/examples/9_metrics.cpp @@ -10,8 +10,6 @@ using namespace gapp; struct MyMetric : public metrics::Monitor> { - double value_at(size_t generation) const noexcept { return data_[generation]; } - void initialize(const GaInfo&) override { data_.clear(); } void update(const GaInfo& ga) override { data_.push_back(ga.fitness_matrix()[0][0]); } }; diff --git a/src/metrics/distribution_metrics.cpp b/src/metrics/distribution_metrics.cpp index a8f8495a..07c72ab4 100644 --- a/src/metrics/distribution_metrics.cpp +++ b/src/metrics/distribution_metrics.cpp @@ -13,13 +13,6 @@ namespace gapp::metrics { using math::Point; - std::span NadirPoint::value_at(size_t generation) const noexcept - { - GAPP_ASSERT(generation < data_.size()); - - return data_[generation]; - } - void NadirPoint::initialize(const GaInfo& ga) { data_.clear(); @@ -39,13 +32,6 @@ namespace gapp::metrics ref_point_(std::move(ref_point)) {} - double Hypervolume::value_at(size_t generation) const noexcept - { - GAPP_ASSERT(generation < data_.size()); - - return data_[generation]; - } - void Hypervolume::initialize(const GaInfo& ga) { GAPP_ASSERT(ref_point_.size() == ga.num_objectives()); @@ -63,13 +49,6 @@ namespace gapp::metrics } - double AutoHypervolume::value_at(size_t generation) const noexcept - { - GAPP_ASSERT(generation < data_.size()); - - return data_[generation]; - } - void AutoHypervolume::initialize(const GaInfo& ga) { data_.clear(); diff --git a/src/metrics/distribution_metrics.hpp b/src/metrics/distribution_metrics.hpp index 312efb9b..7a51d36e 100644 --- a/src/metrics/distribution_metrics.hpp +++ b/src/metrics/distribution_metrics.hpp @@ -16,9 +16,6 @@ namespace gapp::metrics */ class NadirPoint final : public Monitor { - public: - std::span value_at(size_t generation) const noexcept; - private: void initialize(const GaInfo& ga) override; void update(const GaInfo& ga) override; }; @@ -50,8 +47,6 @@ namespace gapp::metrics /** @returns The reference point used for computing the hypervolumes. */ const FitnessVector& ref_point() const noexcept { return ref_point_; } - double value_at(size_t generation) const noexcept; - private: void initialize(const GaInfo& ga) override; void update(const GaInfo& ga) override; @@ -79,7 +74,6 @@ namespace gapp::metrics /** @returns The reference point used for computing the hypervolumes. */ const FitnessVector& ref_point() const noexcept { return worst_point_; } - double value_at(size_t generation) const noexcept; private: void initialize(const GaInfo& ga) override; void update(const GaInfo& ga) override; diff --git a/src/metrics/fitness_metrics.hpp b/src/metrics/fitness_metrics.hpp index a5a405e0..32de34ce 100644 --- a/src/metrics/fitness_metrics.hpp +++ b/src/metrics/fitness_metrics.hpp @@ -16,13 +16,6 @@ namespace gapp::metrics template class FitnessMonitor : public Monitor { - public: - std::span value_at(size_t generation) const noexcept - { - GAPP_ASSERT(generation < this->data_.size()); - - return this->data_[generation]; - } private: void initialize(const GaInfo& ga) override { diff --git a/src/metrics/metric_set.hpp b/src/metrics/metric_set.hpp index 075d75c2..9732737a 100644 --- a/src/metrics/metric_set.hpp +++ b/src/metrics/metric_set.hpp @@ -51,7 +51,7 @@ namespace gapp::detail MetricSet::MetricSet(Metrics... metrics) { metrics_.reserve(sizeof...(metrics)); - (metrics_.push_back(std::make_unique(std::move(metrics))), ...); + ( metrics_.push_back(std::make_unique(std::move(metrics))), ... ); } template @@ -60,8 +60,7 @@ namespace gapp::detail { auto found = std::find_if(metrics_.begin(), metrics_.end(), [](const auto& metric) { return metric->type_id() == detail::type_id; }); - if (found == metrics_.end()) return nullptr; - return static_cast(found->get()); + return found != metrics_.end() ? static_cast(found->get()) : nullptr; } } // namespace gapp::detail diff --git a/src/metrics/misc_metrics.cpp b/src/metrics/misc_metrics.cpp index 25f662fa..f3c6145f 100644 --- a/src/metrics/misc_metrics.cpp +++ b/src/metrics/misc_metrics.cpp @@ -22,11 +22,4 @@ namespace gapp::metrics data_.push_back(sum_ - old_sum); } - size_t FitnessEvaluations::value_at(size_t generation) const noexcept - { - GAPP_ASSERT(generation < data_.size()); - - return data_[generation]; - } - } // namespace gapp::metrics diff --git a/src/metrics/misc_metrics.hpp b/src/metrics/misc_metrics.hpp index 7aab1fa5..b640ca93 100644 --- a/src/metrics/misc_metrics.hpp +++ b/src/metrics/misc_metrics.hpp @@ -12,9 +12,6 @@ namespace gapp::metrics /** Record the number of fitness function evaluations performed in each generation. */ class FitnessEvaluations final : public Monitor> { - public: - size_t value_at(size_t generation) const noexcept; - private: void initialize(const GaInfo& ga) override; void update(const GaInfo& ga) override; diff --git a/src/metrics/monitor.hpp b/src/metrics/monitor.hpp index 3575977a..d1be4544 100644 --- a/src/metrics/monitor.hpp +++ b/src/metrics/monitor.hpp @@ -6,6 +6,8 @@ #include "monitor_base.hpp" #include "../utility/type_id.hpp" #include "../utility/type_traits.hpp" +#include "../utility/concepts.hpp" +#include "../utility/utility.hpp" #include #include #include @@ -17,12 +19,10 @@ namespace gapp::metrics * Metrics can be used to track certain attributes of the %GAs * in every generation throughout a run. * - * New metrics should be derived from this class, and they should implement the following 3 methods: + * New metrics should be derived from this class, and they should implement the following methods: * - * - initialize : Used to initialize the monitor at the start of a run. - * - update : Used to update the monitored attribute once every generation. - * - value_at : Returns the value of the monitored metric in a specific generation - (must be implemented as a public, non-virtual function). + * - initialize (optional) : Used to initialize the monitor at the start of a run. + * - update : Used to update the monitored attribute once every generation. * * @note The monitor doesn't have access to any encoding specific information about a %GA * by default, so no such information can be tracked. @@ -30,13 +30,13 @@ namespace gapp::metrics * @tparam Derived The type of the derived class. * @tparam MetricData The type of the container used to store the monitored metrics (eg. std::vector). */ - template + template class Monitor : public MonitorBase { public: /** @returns The value of the tracked metric in the specified @p generation. */ [[nodiscard]] - constexpr decltype(auto) operator[](size_t generation) const { return derived().value_at(generation); } + constexpr auto operator[](size_t generation) const noexcept { GAPP_ASSERT(generation < data_.size()); return data_[generation]; } /** @returns The data collected by the monitor throughout the run. */ [[nodiscard]] @@ -47,28 +47,17 @@ namespace gapp::metrics constexpr size_t size() const noexcept { return data_.size(); } /** @returns An iterator to the first element of the metric data. */ - constexpr auto begin() const { return data_.begin(); } + constexpr auto begin() const noexcept { return data_.begin(); } /** @returns An iterator to one past the last element of the metric data. */ - constexpr auto end() const { return data_.end(); } + constexpr auto end() const noexcept { return data_.end(); } + + void initialize(const GaInfo&) override { data_.clear(); } protected: MetricData data_; - constexpr Monitor() noexcept - { - static_assert(!std::is_abstract_v, - "The Derived class should implement all the virtual functions of Monitor."); - static_assert(detail::is_derived_from_spec_of_v, - "The first type parameter of the Monitor class must be the derived class."); - static_assert(std::is_invocable_v, - "Classes derived from Monitor must implement a public 'value_at(size_t) const' function."); - } - size_t type_id() const noexcept final { return detail::type_id; } - - private: - constexpr const Derived& derived() const noexcept { return static_cast(*this); } }; } // namespace gapp::metrics diff --git a/test/unit/metrics.cpp b/test/unit/metrics.cpp index cf96d2dc..b00e7026 100644 --- a/test/unit/metrics.cpp +++ b/test/unit/metrics.cpp @@ -31,9 +31,9 @@ TEMPLATE_TEST_CASE("fitness_metrics", "[metrics]", FitnessMin, FitnessMax, Fitne REQUIRE(metric.data().size() == num_gen); REQUIRE(std::all_of(metric.begin(), metric.end(), detail::is_size(num_obj))); - REQUIRE(metric.value_at(4).size() == num_obj); + REQUIRE(metric[4].size() == num_obj); - const auto& val = metric.value_at(7); + const auto& val = metric[7]; REQUIRE(std::all_of(val.begin(), val.end(), detail::equal_to(0.0))); } @@ -48,7 +48,7 @@ TEST_CASE("nadir_point_metric", "[metrics]") REQUIRE(metric.size() == num_gen); REQUIRE(std::all_of(metric.begin(), metric.end(), detail::is_size(num_obj))); - const auto& val = metric.value_at(5); + const auto& val = metric[5]; REQUIRE(std::all_of(val.begin(), val.end(), detail::equal_to(0.0))); }