Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify the Metrics classes #9

Merged
merged 12 commits into from
Sep 5, 2023
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ will install all available configurations:

```shell
git clone https://github.com/KRM7/gapp.git
sudo gapp/build/install.sh
sudo bash ./gapp/build/install.sh
```

Once the library is installed, you can import it using `find_package` and then link
Expand Down
10 changes: 2 additions & 8 deletions docs/metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,14 @@ are intended to be used for multi-objective optimization problems.

If you want to track something that doesn't have a metric already implemented
for it, it's possible to implement your own metrics. Metrics must be derived
from the `Monitor` class, and implement 3 methods: `initialize`, `update`, and
`value_at`:
from the `Monitor` class, and implement the `initialize`, and `update` methods:

```cpp
// The second type parameter of Monitor is the type used to store the
// gathered metrics. This will be used as the type of the data_ field.
class MyMetric : public metrics::Monitor<MyMetric, std::vector<double>>
{
public:
// Returns the value of the metric in the given generation.
// Note that this method is not virtual.
double value_at(size_t generation) const noexcept { return data_[generation]; }
private:
// Initialize the metric. Called at the start of a run.
// (optional) Initialize the metric. Called at the start of a run.
void initialize(const GaInfo& ga) override { data_.clear(); }

// Update the metric with a new value from the current generation.
Expand Down
4 changes: 1 addition & 3 deletions examples/9_metrics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ using namespace gapp;

struct MyMetric : public metrics::Monitor<MyMetric, std::vector<double>>
{
double value_at(size_t generation) const noexcept { return data_[generation]; }
void initialize(const GaInfo&) override { data_.clear(); }
void update(const GaInfo& ga) override { data_.push_back(ga.fitness_matrix()[0][0]); }
};

Expand All @@ -29,6 +27,6 @@ int main()
std::cout << std::format("Generation {}\t| {:.6f}\n", gen + 1, metric[gen]);
}

const auto* hypervol = GA.get_metric_if<metrics::AutoHypervolume>(); // untracked metric
[[maybe_unused]] const auto* hypervol = GA.get_metric_if<metrics::AutoHypervolume>(); // untracked metric
assert(hypervol == nullptr);
}
10 changes: 6 additions & 4 deletions src/algorithm/nd_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace gapp::algorithm::dtl
{
GAPP_ASSERT(std::distance(current, last) >= 0);

return std::find_if(current, last, compose(&FrontInfo::rank, detail::not_equal_to(current->rank)));
return std::ranges::find_if(current, last, detail::not_equal_to(current->rank), &FrontInfo::rank);
}

std::vector<ParetoFrontsRange> paretoFrontBounds(ParetoFronts& pareto_fronts)
Expand Down Expand Up @@ -65,8 +65,7 @@ namespace gapp::algorithm::dtl
auto last_in = first + popsize;

if (last_in == last) return { last, last };

auto partial_front_first = std::find_if(first, last_in, compose(&FrontInfo::rank, detail::equal_to(last_in->rank)));
auto partial_front_first = std::ranges::find_if(first, last_in, detail::equal_to(last_in->rank), &FrontInfo::rank);
auto partial_front_last = nextFrontBegin(std::prev(last_in), last);

return { partial_front_first, partial_front_last };
Expand Down Expand Up @@ -205,7 +204,10 @@ namespace gapp::algorithm::dtl

std::for_each(last_idx, indices.rend(), [&](size_t col) noexcept
{
if (dmat(row, col).load(std::memory_order_relaxed) == MAXIMAL) dmat(row, col).store(NONMAXIMAL, std::memory_order_relaxed);
if (dmat(row, col).load(std::memory_order_relaxed) == MAXIMAL)
{
dmat(row, col).store(NONMAXIMAL, std::memory_order_relaxed);
}
});
});
});
Expand Down
19 changes: 9 additions & 10 deletions src/algorithm/nsga3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "../utility/math.hpp"
#include "../utility/rng.hpp"
#include "../utility/utility.hpp"
#include "../utility/cone_tree.hpp"
#include <algorithm>
#include <execution>
#include <functional>
Expand Down Expand Up @@ -98,7 +97,7 @@ namespace gapp::algorithm
};

RefLineGenerator ref_generator_;
detail::ConeTree ref_lines_;
std::vector<Point> ref_lines_;

std::vector<CandidateInfo> sol_info_;
std::vector<size_t> niche_counts_;
Expand Down Expand Up @@ -200,7 +199,7 @@ namespace gapp::algorithm
for (size_t dim = 0; dim < ideal_point_.size(); dim++)
{
auto weights = weightVector(ideal_point_.size(), dim);
auto ASFi = [&](const auto& fvec) noexcept { return ASF(ideal_point_, weights, fvec); };
auto ASFi = [&](const auto& fvec) { return ASF(ideal_point_, weights, fvec); };

std::vector<double> chebysev_distances(popsize + extreme_points_.size());

Expand Down Expand Up @@ -244,10 +243,12 @@ namespace gapp::algorithm
{
const FitnessVector fnorm = normalizeFitnessVec(first[sol.idx], ideal_point_, nadir_point_);

const auto best = ref_lines_.findBestMatch(fnorm);
auto idistance = [&](const auto& line) { return std::inner_product(fnorm.begin(), fnorm.end(), line.begin(), 0.0); };

sol_info_[sol.idx].ref_idx = size_t(best.elem - ref_lines_.begin());
sol_info_[sol.idx].ref_dist = math::perpendicularDistanceSq(*best.elem, fnorm);
auto closest = detail::max_element(ref_lines_.begin(), ref_lines_.end(), idistance);

sol_info_[sol.idx].ref_idx = std::distance(ref_lines_.begin(), closest);
sol_info_[sol.idx].ref_dist = math::perpendicularDistanceSq(*closest, fnorm);
});
}

Expand Down Expand Up @@ -349,8 +350,7 @@ namespace gapp::algorithm
pimpl_->ideal_point_ = detail::maxFitness(fitness_matrix.begin(), fitness_matrix.end());
pimpl_->extreme_points_ = {};

auto ref_lines = pimpl_->generateReferencePoints(ga.num_objectives(), ga.population_size());
pimpl_->ref_lines_ = detail::ConeTree{ ref_lines };
pimpl_->ref_lines_ = pimpl_->generateReferencePoints(ga.num_objectives(), ga.population_size());
pimpl_->niche_counts_.resize(pimpl_->ref_lines_.size());

auto pfronts = nonDominatedSort(fitness_matrix.begin(), fitness_matrix.end());
Expand Down Expand Up @@ -421,8 +421,7 @@ namespace gapp::algorithm

std::vector<size_t> NSGA3::optimalSolutionsImpl(const GaInfo&) const
{
return detail::find_indices(pimpl_->sol_info_,
detail::compose(&Impl::CandidateInfo::rank, detail::equal_to(0_sz)));
return detail::find_indices(pimpl_->sol_info_, [](const Impl::CandidateInfo& sol) { return sol.rank == 0; });
}

} // namespace gapp::algorithm
13 changes: 12 additions & 1 deletion src/core/ga_base.impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ namespace gapp

metrics_.initialize(*this);
metrics_.update(*this);

if (end_of_generation_callback_) end_of_generation_callback_(*this);
}

template<typename T>
Expand Down Expand Up @@ -481,7 +483,16 @@ namespace gapp
auto optimal_pop = algorithm_->optimalSolutions(*this, pop);

optimal_sols = detail::mergeParetoSets(std::move(optimal_sols), std::move(optimal_pop));
detail::erase_duplicates(optimal_sols);

/* Duplicate elements are removed from optimal_sols using exact comparison
* of the chromosomes in order to avoid issues with using a non-transitive
* comparison function for std::sort and std::unique. */
auto chrom_eq = [](const auto& lhs, const auto& rhs) { return lhs.chromosome == rhs.chromosome; };
auto chrom_less = [](const auto& lhs, const auto& rhs) { return lhs.chromosome < rhs.chromosome; };

std::sort(optimal_sols.begin(), optimal_sols.end(), chrom_less);
auto last = std::unique(optimal_sols.begin(), optimal_sols.end(), chrom_eq);
optimal_sols.erase(last, optimal_sols.end());
}

template<typename T>
Expand Down
21 changes: 0 additions & 21 deletions src/metrics/distribution_metrics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,6 @@ namespace gapp::metrics
{
using math::Point;

std::span<const double> NadirPoint::value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < data_.size());

return data_[generation];
}

void NadirPoint::initialize(const GaInfo& ga)
{
data_.clear();
Expand All @@ -39,13 +32,6 @@ namespace gapp::metrics
ref_point_(std::move(ref_point))
{}

double Hypervolume::value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < data_.size());

return data_[generation];
}

void Hypervolume::initialize(const GaInfo& ga)
{
GAPP_ASSERT(ref_point_.size() == ga.num_objectives());
Expand All @@ -63,13 +49,6 @@ namespace gapp::metrics
}


double AutoHypervolume::value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < data_.size());

return data_[generation];
}

void AutoHypervolume::initialize(const GaInfo& ga)
{
data_.clear();
Expand Down
6 changes: 0 additions & 6 deletions src/metrics/distribution_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ namespace gapp::metrics
*/
class NadirPoint final : public Monitor<NadirPoint, FitnessMatrix>
{
public:
std::span<const double> value_at(size_t generation) const noexcept;
private:
void initialize(const GaInfo& ga) override;
void update(const GaInfo& ga) override;
};
Expand Down Expand Up @@ -50,8 +47,6 @@ namespace gapp::metrics
/** @returns The reference point used for computing the hypervolumes. */
const FitnessVector& ref_point() const noexcept { return ref_point_; }

double value_at(size_t generation) const noexcept;

private:
void initialize(const GaInfo& ga) override;
void update(const GaInfo& ga) override;
Expand Down Expand Up @@ -79,7 +74,6 @@ namespace gapp::metrics
/** @returns The reference point used for computing the hypervolumes. */
const FitnessVector& ref_point() const noexcept { return worst_point_; }

double value_at(size_t generation) const noexcept;
private:
void initialize(const GaInfo& ga) override;
void update(const GaInfo& ga) override;
Expand Down
7 changes: 0 additions & 7 deletions src/metrics/fitness_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@ namespace gapp::metrics
template<typename Derived>
class FitnessMonitor : public Monitor<Derived, FitnessMatrix>
{
public:
std::span<const double> value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < this->data_.size());

return this->data_[generation];
}
private:
void initialize(const GaInfo& ga) override
{
Expand Down
5 changes: 2 additions & 3 deletions src/metrics/metric_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace gapp::detail
MetricSet::MetricSet(Metrics... metrics)
{
metrics_.reserve(sizeof...(metrics));
(metrics_.push_back(std::make_unique<Metrics>(std::move(metrics))), ...);
( metrics_.push_back(std::make_unique<Metrics>(std::move(metrics))), ... );
}

template<typename Metric>
Expand All @@ -60,8 +60,7 @@ namespace gapp::detail
{
auto found = std::find_if(metrics_.begin(), metrics_.end(), [](const auto& metric) { return metric->type_id() == detail::type_id<Metric>; });

if (found == metrics_.end()) return nullptr;
return static_cast<const Metric*>(found->get());
return found != metrics_.end() ? static_cast<Metric*>(found->get()) : nullptr;
}

} // namespace gapp::detail
Expand Down
7 changes: 0 additions & 7 deletions src/metrics/misc_metrics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,4 @@ namespace gapp::metrics
data_.push_back(sum_ - old_sum);
}

size_t FitnessEvaluations::value_at(size_t generation) const noexcept
{
GAPP_ASSERT(generation < data_.size());

return data_[generation];
}

} // namespace gapp::metrics
3 changes: 0 additions & 3 deletions src/metrics/misc_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ namespace gapp::metrics
/** Record the number of fitness function evaluations performed in each generation. */
class FitnessEvaluations final : public Monitor<FitnessEvaluations, std::vector<size_t>>
{
public:
size_t value_at(size_t generation) const noexcept;
private:
void initialize(const GaInfo& ga) override;
void update(const GaInfo& ga) override;

Expand Down
33 changes: 11 additions & 22 deletions src/metrics/monitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "monitor_base.hpp"
#include "../utility/type_id.hpp"
#include "../utility/type_traits.hpp"
#include "../utility/concepts.hpp"
#include "../utility/utility.hpp"
#include <vector>
#include <type_traits>
#include <cstddef>
Expand All @@ -17,26 +19,24 @@ namespace gapp::metrics
* Metrics can be used to track certain attributes of the %GAs
* in every generation throughout a run.
*
* New metrics should be derived from this class, and they should implement the following 3 methods:
* New metrics should be derived from this class, and they should implement the following methods:
*
* - initialize : Used to initialize the monitor at the start of a run.
* - update : Used to update the monitored attribute once every generation.
* - value_at : Returns the value of the monitored metric in a specific generation
(must be implemented as a public, non-virtual function).
* - initialize (optional) : Used to initialize the monitor at the start of a run.
* - update : Used to update the monitored attribute once every generation.
*
* @note The monitor doesn't have access to any encoding specific information about a %GA
* by default, so no such information can be tracked.
*
* @tparam Derived The type of the derived class.
* @tparam MetricData The type of the container used to store the monitored metrics (eg. std::vector).
*/
template<typename Derived, typename MetricData>
template<typename Derived, detail::IndexableContainer MetricData>
class Monitor : public MonitorBase
{
public:
/** @returns The value of the tracked metric in the specified @p generation. */
[[nodiscard]]
constexpr decltype(auto) operator[](size_t generation) const { return derived().value_at(generation); }
constexpr auto operator[](size_t generation) const noexcept { GAPP_ASSERT(generation < data_.size()); return data_[generation]; }

/** @returns The data collected by the monitor throughout the run. */
[[nodiscard]]
Expand All @@ -47,28 +47,17 @@ namespace gapp::metrics
constexpr size_t size() const noexcept { return data_.size(); }

/** @returns An iterator to the first element of the metric data. */
constexpr auto begin() const { return data_.begin(); }
constexpr auto begin() const noexcept { return data_.begin(); }

/** @returns An iterator to one past the last element of the metric data. */
constexpr auto end() const { return data_.end(); }
constexpr auto end() const noexcept { return data_.end(); }

void initialize(const GaInfo&) override { data_.clear(); }

protected:
MetricData data_;

constexpr Monitor() noexcept
{
static_assert(!std::is_abstract_v<Derived>,
"The Derived class should implement all the virtual functions of Monitor.");
static_assert(detail::is_derived_from_spec_of_v<Derived, Monitor>,
"The first type parameter of the Monitor class must be the derived class.");
static_assert(std::is_invocable_v<decltype(&Derived::value_at), const Derived&, size_t>,
"Classes derived from Monitor must implement a public 'value_at(size_t) const' function.");
}

size_t type_id() const noexcept final { return detail::type_id<Derived>; }

private:
constexpr const Derived& derived() const noexcept { return static_cast<const Derived&>(*this); }
};

} // namespace gapp::metrics
Expand Down
Loading
Loading