Skip to content

Commit

Permalink
simplify the inteface of the metric classes
Browse files Browse the repository at this point in the history
  • Loading branch information
KRM7 committed Aug 20, 2023
1 parent 58fe627 commit 40d30c2
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 82 deletions.
10 changes: 2 additions & 8 deletions docs/metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,14 @@ are intended to be used for multi-objective optimization problems.

If you want to track something that doesn't have a metric already implemented
for it, it's possible to implement your own metrics. Metrics must be derived
from the `Monitor` class, and implement 3 methods: `initialize`, `update`, and
`value_at`:
from the `Monitor` class, and implement the `initialize`, and `update` methods:

```cpp
// The second type parameter of Monitor is the type used to store the
// gathered metrics. This will be used as the type of the data_ field.
class MyMetric : public metrics::Monitor<MyMetric, std::vector<double>>
{
public:
// Returns the value of the metric in the given generation.
// Note that this method is not virtual.
double value_at(size_t generation) const noexcept { return data_[generation]; }
private:
// Initialize the metric. Called at the start of a run.
// (optional) Initialize the metric. Called at the start of a run.
void initialize(const GaInfo& ga) override { data_.clear(); }

// Update the metric with a new value from the current generation.
Expand Down
2 changes: 0 additions & 2 deletions examples/9_metrics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ using namespace gapp;

struct MyMetric : public metrics::Monitor<MyMetric, std::vector<double>>
{
double value_at(size_t generation) const noexcept { return data_[generation]; }
void initialize(const GaInfo&) override { data_.clear(); }
void update(const GaInfo& ga) override { data_.push_back(ga.fitness_matrix()[0][0]); }
};

Expand Down
21 changes: 0 additions & 21 deletions src/metrics/distribution_metrics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,6 @@ namespace gapp::metrics
{
using math::Point;

std::span<const double> NadirPoint::value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < data_.size());

return data_[generation];
}

void NadirPoint::initialize(const GaInfo& ga)
{
data_.clear();
Expand All @@ -39,13 +32,6 @@ namespace gapp::metrics
ref_point_(std::move(ref_point))
{}

double Hypervolume::value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < data_.size());

return data_[generation];
}

void Hypervolume::initialize(const GaInfo& ga)
{
GAPP_ASSERT(ref_point_.size() == ga.num_objectives());
Expand All @@ -63,13 +49,6 @@ namespace gapp::metrics
}


double AutoHypervolume::value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < data_.size());

return data_[generation];
}

void AutoHypervolume::initialize(const GaInfo& ga)
{
data_.clear();
Expand Down
6 changes: 0 additions & 6 deletions src/metrics/distribution_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ namespace gapp::metrics
*/
class NadirPoint final : public Monitor<NadirPoint, FitnessMatrix>
{
public:
std::span<const double> value_at(size_t generation) const noexcept;
private:
void initialize(const GaInfo& ga) override;
void update(const GaInfo& ga) override;
};
Expand Down Expand Up @@ -50,8 +47,6 @@ namespace gapp::metrics
/** @returns The reference point used for computing the hypervolumes. */
const FitnessVector& ref_point() const noexcept { return ref_point_; }

double value_at(size_t generation) const noexcept;

private:
void initialize(const GaInfo& ga) override;
void update(const GaInfo& ga) override;
Expand Down Expand Up @@ -79,7 +74,6 @@ namespace gapp::metrics
/** @returns The reference point used for computing the hypervolumes. */
const FitnessVector& ref_point() const noexcept { return worst_point_; }

double value_at(size_t generation) const noexcept;
private:
void initialize(const GaInfo& ga) override;
void update(const GaInfo& ga) override;
Expand Down
7 changes: 0 additions & 7 deletions src/metrics/fitness_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@ namespace gapp::metrics
template<typename Derived>
class FitnessMonitor : public Monitor<Derived, FitnessMatrix>
{
public:
std::span<const double> value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < this->data_.size());

return this->data_[generation];
}
private:
void initialize(const GaInfo& ga) override
{
Expand Down
5 changes: 2 additions & 3 deletions src/metrics/metric_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace gapp::detail
MetricSet::MetricSet(Metrics... metrics)
{
metrics_.reserve(sizeof...(metrics));
(metrics_.push_back(std::make_unique<Metrics>(std::move(metrics))), ...);
( metrics_.push_back(std::make_unique<Metrics>(std::move(metrics))), ... );
}

template<typename Metric>
Expand All @@ -60,8 +60,7 @@ namespace gapp::detail
{
auto found = std::find_if(metrics_.begin(), metrics_.end(), [](const auto& metric) { return metric->type_id() == detail::type_id<Metric>; });

if (found == metrics_.end()) return nullptr;
return static_cast<const Metric*>(found->get());
return found != metrics_.end() ? static_cast<Metric*>(found->get()) : nullptr;
}

} // namespace gapp::detail
Expand Down
7 changes: 0 additions & 7 deletions src/metrics/misc_metrics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,4 @@ namespace gapp::metrics
data_.push_back(sum_ - old_sum);
}

size_t FitnessEvaluations::value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < data_.size());

return data_[generation];
}

} // namespace gapp::metrics
3 changes: 0 additions & 3 deletions src/metrics/misc_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ namespace gapp::metrics
/** Record the number of fitness function evaluations performed in each generation. */
class FitnessEvaluations final : public Monitor<FitnessEvaluations, std::vector<size_t>>
{
public:
size_t value_at(size_t generation) const noexcept;
private:
void initialize(const GaInfo& ga) override;
void update(const GaInfo& ga) override;

Expand Down
33 changes: 11 additions & 22 deletions src/metrics/monitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "monitor_base.hpp"
#include "../utility/type_id.hpp"
#include "../utility/type_traits.hpp"
#include "../utility/concepts.hpp"
#include "../utility/utility.hpp"
#include <vector>
#include <type_traits>
#include <cstddef>
Expand All @@ -17,26 +19,24 @@ namespace gapp::metrics
* Metrics can be used to track certain attributes of the %GAs
* in every generation throughout a run.
*
* New metrics should be derived from this class, and they should implement the following 3 methods:
* New metrics should be derived from this class, and they should implement the following methods:
*
* - initialize : Used to initialize the monitor at the start of a run.
* - update : Used to update the monitored attribute once every generation.
* - value_at : Returns the value of the monitored metric in a specific generation
(must be implemented as a public, non-virtual function).
* - initialize (optional) : Used to initialize the monitor at the start of a run.
* - update : Used to update the monitored attribute once every generation.
*
* @note The monitor doesn't have access to any encoding specific information about a %GA
* by default, so no such information can be tracked.
*
* @tparam Derived The type of the derived class.
* @tparam MetricData The type of the container used to store the monitored metrics (eg. std::vector).
*/
template<typename Derived, typename MetricData>
template<typename Derived, detail::IndexableContainer MetricData>
class Monitor : public MonitorBase
{
public:
/** @returns The value of the tracked metric in the specified @p generation. */
[[nodiscard]]
constexpr decltype(auto) operator[](size_t generation) const { return derived().value_at(generation); }
constexpr auto operator[](size_t generation) const noexcept { GAPP_ASSERT(generation < data_.size()); return data_[generation]; }

/** @returns The data collected by the monitor throughout the run. */
[[nodiscard]]
Expand All @@ -47,28 +47,17 @@ namespace gapp::metrics
constexpr size_t size() const noexcept { return data_.size(); }

/** @returns An iterator to the first element of the metric data. */
constexpr auto begin() const { return data_.begin(); }
constexpr auto begin() const noexcept { return data_.begin(); }

/** @returns An iterator to one past the last element of the metric data. */
constexpr auto end() const { return data_.end(); }
constexpr auto end() const noexcept { return data_.end(); }

void initialize(const GaInfo&) override { data_.clear(); }

protected:
MetricData data_;

constexpr Monitor() noexcept
{
static_assert(!std::is_abstract_v<Derived>,
"The Derived class should implement all the virtual functions of Monitor.");
static_assert(detail::is_derived_from_spec_of_v<Derived, Monitor>,
"The first type parameter of the Monitor class must be the derived class.");
static_assert(std::is_invocable_v<decltype(&Derived::value_at), const Derived&, size_t>,
"Classes derived from Monitor must implement a public 'value_at(size_t) const' function.");
}

size_t type_id() const noexcept final { return detail::type_id<Derived>; }

private:
constexpr const Derived& derived() const noexcept { return static_cast<const Derived&>(*this); }
};

} // namespace gapp::metrics
Expand Down
6 changes: 3 additions & 3 deletions test/unit/metrics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ TEMPLATE_TEST_CASE("fitness_metrics", "[metrics]", FitnessMin, FitnessMax, Fitne
REQUIRE(metric.data().size() == num_gen);

REQUIRE(std::all_of(metric.begin(), metric.end(), detail::is_size(num_obj)));
REQUIRE(metric.value_at(4).size() == num_obj);
REQUIRE(metric[4].size() == num_obj);

const auto& val = metric.value_at(7);
const auto& val = metric[7];
REQUIRE(std::all_of(val.begin(), val.end(), detail::equal_to(0.0)));
}

Expand All @@ -48,7 +48,7 @@ TEST_CASE("nadir_point_metric", "[metrics]")
REQUIRE(metric.size() == num_gen);
REQUIRE(std::all_of(metric.begin(), metric.end(), detail::is_size(num_obj)));

const auto& val = metric.value_at(5);
const auto& val = metric[5];
REQUIRE(std::all_of(val.begin(), val.end(), detail::equal_to(0.0)));
}

Expand Down

0 comments on commit 40d30c2

Please sign in to comment.