Skip to content

Commit

Permalink
improved selection implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
KRM7 committed Oct 28, 2024
1 parent a16966c commit 43b7178
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 11 deletions.
1 change: 0 additions & 1 deletion src/algorithm/reference_lines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "../utility/math.hpp"
#include "../utility/utility.hpp"
#include <algorithm>
#include <execution>
#include <numeric>
#include <functional>
#include <iterator>
Expand Down
24 changes: 23 additions & 1 deletion src/utility/algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,28 @@ namespace gapp::detail
return indices;
}

template<std::random_access_iterator Iter, typename T, typename Comp = std::less<std::iter_value_t<Iter>>>
requires std::strict_weak_order<Comp, std::iter_reference_t<Iter>, const T&>
Iter lower_bound(Iter first, Iter last, const T& value, Comp&& comp = {})
{
GAPP_ASSERT(std::distance(first, last) >= 0);
GAPP_ASSERT(std::is_sorted(first, last, comp));

std::size_t length = last - first;
while (length > 32)
{
length = length / 2;
first = std::invoke(comp, first[length], value) ? first + length : first;
}

for (; first != last; ++first)
{
if (!std::invoke(comp, *first, value)) break;
}

return first;
}

template<std::forward_iterator Iter, typename F = std::identity>
requires std::invocable<F&, std::iter_reference_t<Iter>>
constexpr Iter max_element(Iter first, Iter last, F&& transform = {})
Expand Down Expand Up @@ -336,4 +358,4 @@ namespace gapp::detail

} // namespace gapp::detail

#endif // !GAPP_UTILITY_ALGORITHM_HPP
#endif // !GAPP_UTILITY_ALGORITHM_HPP
14 changes: 5 additions & 9 deletions src/utility/rng.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
/* Copyright (c) 2022 Krisztián Rugási. Subject to the MIT License. */

#ifndef GA_UTILITY_RNG_HPP
#define GA_UTILITY_RNG_HPP
#ifndef GAPP_UTILITY_RNG_HPP
#define GAPP_UTILITY_RNG_HPP

#include "algorithm.hpp"
#include "functional.hpp"
#include "distribution.hpp"
#include "small_vector.hpp"
Expand Down Expand Up @@ -530,14 +531,9 @@ namespace gapp::rng

const RealType threshold = rng::randomReal<RealType>(0.0, cdf.back()); // use cdf.back() in case it's not exactly 1.0

if (cdf.size() < 128)
{
return std::distance(cdf.begin(), std::find_if(cdf.begin(), cdf.end(), detail::greater_eq_than(threshold)));
}

return std::distance(cdf.begin(), std::lower_bound(cdf.begin(), cdf.end(), threshold));
return std::distance(cdf.begin(), detail::lower_bound(cdf.begin(), cdf.end(), threshold));
}

} // namespace gapp::rng

#endif // !GA_UTILITY_RNG_HPP
#endif // !GAPP_UTILITY_RNG_HPP
29 changes: 29 additions & 0 deletions test/benchmark/lower_bound.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/benchmark/catch_benchmark.hpp>
#include "utility/algorithm.hpp"
#include "utility/rng.hpp"
#include <algorithm>

using namespace gapp;

std::vector<double> randomVector(std::size_t size)
{
std::vector vec(size, 0.0);
for (double& elem : vec) elem = rng::randomReal();
std::sort(vec.begin(), vec.end());
return vec;
}


TEST_CASE("binary_search", "[benchmark]")
{
const auto vlen = GENERATE(100, 1000, 10000);
WARN("Number of elements: " << vlen);

auto v = randomVector(vlen);

BENCHMARK("std::find_if") { return std::find_if(v.begin(), v.end(), detail::greater_eq_than(rng::randomReal())); };
BENCHMARK("std::lower_bound") { return std::lower_bound(v.begin(), v.end(), rng::randomReal()); };
BENCHMARK("detail::lower_bound") { return detail::lower_bound(v.begin(), v.end(), rng::randomReal()); };
}
26 changes: 26 additions & 0 deletions test/unit/algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,32 @@ TEST_CASE("partial_argsort", "[algorithm]")
}
}

TEST_CASE("lower_bound" "[algorithm]")
{
small_vector nums(500, 0.0);
std::iota(nums.begin(), nums.end(), 0.0);

REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), 0.0) == 0.0);
REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), 1.0) == 1.0);
REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), 2.0) == 2.0);
REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), 120.0) == 120.0);
REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), 499.0) == 499.0);

REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), 1.9) == 2.0);
REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), 1.1) == 2.0);
REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), 332.7) == 333.0);

REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), -1.0) == 0.0);
REQUIRE(*detail::lower_bound(nums.begin(), nums.end(), -100.0) == 0.0);

REQUIRE(detail::lower_bound(nums.begin(), nums.end(), 499.1) == nums.end());
REQUIRE(detail::lower_bound(nums.begin(), nums.end(), 10000.0) == nums.end());

REQUIRE(*detail::lower_bound(nums.rbegin(), nums.rend(), 1.1, std::greater<>{}) == 1.0);

REQUIRE(detail::lower_bound(nums.begin(), nums.begin(), 33.0) == nums.begin());
}

TEST_CASE("max_element", "[algorithm]")
{
const small_vector nums = { 4.0, 0.0, 2.0, 5.0, 1.0 };
Expand Down

0 comments on commit 43b7178

Please sign in to comment.