Skip to content

Commit

Permalink
rcu
Browse files Browse the repository at this point in the history
  • Loading branch information
KRM7 committed Sep 3, 2023
1 parent c5cd8b0 commit d649785
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 16 deletions.
120 changes: 120 additions & 0 deletions src/utility/rcu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/* Copyright (c) 2023 Krisztián Rugási. Subject to the MIT License. */

#ifndef GA_UTILITY_RCU_HPP
#define GA_UTILITY_RCU_HPP

#include "utility.hpp"
#include <atomic>
#include <mutex>
#include <shared_mutex>
#include <vector>
#include <tuple>
#include <cstdint>
#include <cstddef>

namespace gapp::detail
{
struct default_rcu_domain_tag {};

template<typename = default_rcu_domain_tag>
class rcu_domain
{
public:
inline static void read_lock() noexcept
{
reader.epoch.store(writer_epoch.load(std::memory_order_relaxed), std::memory_order_release);
std::ignore = reader.epoch.load(std::memory_order_acquire);
}

inline static void read_unlock() noexcept
{
reader.epoch.store(NOT_READING, std::memory_order_release);
}

inline static void synchronize() noexcept
{
uint64_t current = writer_epoch.load(std::memory_order_acquire);
uint64_t target = current + 1;
writer_epoch.compare_exchange_strong(current, target, std::memory_order_release);

std::shared_lock _{ reader_list_mtx };

for (const registered_reader* reader : reader_list)
{
while (reader->epoch.load(std::memory_order_acquire) < target) { GAPP_PAUSE(); }
}
}

private:
struct registered_reader
{
registered_reader() noexcept
{
std::unique_lock _{ reader_list_mtx };
reader_list.push_back(this);
}

~registered_reader() noexcept
{
std::unique_lock _{ reader_list_mtx };
std::erase(reader_list, this);
}

std::atomic<uint64_t> epoch = NOT_READING;
};

inline static constexpr uint64_t NOT_READING = std::numeric_limits<uint64_t>::max();

GAPP_API inline static std::vector<registered_reader*> reader_list;
GAPP_API inline static std::shared_mutex reader_list_mtx;

GAPP_API alignas(128) inline static constinit std::atomic<uint64_t> writer_epoch = 0;
alignas(128) inline static thread_local registered_reader reader;
};


template<typename T, typename D = default_rcu_domain_tag>
class rcu_obj
{
public:
template<typename... Args>
constexpr rcu_obj(Args&&... args) :
data_(new T(std::forward<Args>(args)...))
{}

~rcu_obj() noexcept
{
T* ptr = data_.load(std::memory_order_consume);
rcu_domain<D>::synchronize();
delete ptr;
}

template<typename U>
rcu_obj& operator=(U&& value)
{
T* new_ptr = new T(std::forward<U>(value));
T* old_ptr = data_.exchange(new_ptr, std::memory_order_acq_rel);
rcu_domain<D>::synchronize();
delete old_ptr;

return *this;
}

T& get() const noexcept
{
return *data_.load(std::memory_order_consume);
}

T& operator*() const noexcept { return get(); }
T* operator->() const noexcept { return std::addressof(get()); }

void lock() const noexcept { rcu_domain<D>::read_lock(); }
void unlock() const noexcept { rcu_domain<D>::read_unlock(); }

private:
std::atomic<T*> data_;
};

} // namespace gapp::detail

#endif // !GA_UTILITY_RCU_HPP
42 changes: 26 additions & 16 deletions src/utility/rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "type_traits.hpp"
#include "concepts.hpp"
#include "bit.hpp"
#include "rcu.hpp"
#include <algorithm>
#include <functional>
#include <array>
Expand Down Expand Up @@ -163,7 +164,7 @@ namespace gapp::rng
return { seed_seq_gen(), seed_seq_gen() };
}

state_type state_;
alignas(128) state_type state_;
};


Expand All @@ -183,17 +184,21 @@ namespace gapp::rng
using state_type = Xoroshiro128p::state_type;

/** @return The next number of the sequence. Thread-safe. */
result_type operator()() noexcept { return std::invoke(generator_.instance); }
result_type operator()() const noexcept
{
std::scoped_lock _{ generator_.instance };
return std::invoke(*generator_.instance);
}

/**
* Set a new seed for the generator.
* Should not be called concurrently with operator().
*/
void seed(std::uint64_t seed) noexcept
/** Set a new seed for the generator. Thread-safe. */
static void seed(std::uint64_t seed)
{
std::scoped_lock _{ generator_list_mtx_ };
global_generator_.seed(seed);
for (auto* generator : generator_list_) { *generator = global_generator_.jump(); }
global_generator.seed(seed);
for (Generator* generator : generator_list)
{
*generator = global_generator.jump();
}
}

/** @returns The smallest possible value that can be generated. */
Expand All @@ -203,28 +208,30 @@ namespace gapp::rng
static constexpr result_type max() noexcept { return Xoroshiro128p::max(); }

private:
using Generator = detail::rcu_obj<Xoroshiro128p>;

struct RegisteredGenerator
{
RegisteredGenerator()
{
std::scoped_lock _{ generator_list_mtx_ };
instance = global_generator_.jump();
generator_list_.push_back(std::addressof(instance));
instance = global_generator.jump();
generator_list.push_back(std::addressof(instance));
}

~RegisteredGenerator() noexcept
{
std::scoped_lock _{ generator_list_mtx_ };
std::erase(generator_list_, std::addressof(instance));
std::erase(generator_list, std::addressof(instance));
}

Xoroshiro128p instance{ 0 };
Generator instance{ 0 };
};

GAPP_API inline static constinit Xoroshiro128p global_generator_{ GAPP_SEED };
GAPP_API inline static std::vector<Xoroshiro128p*> generator_list_;
GAPP_API inline static constinit Xoroshiro128p global_generator{ GAPP_SEED };
GAPP_API inline static std::vector<Generator*> generator_list;
GAPP_API inline static std::mutex generator_list_mtx_;
inline static thread_local RegisteredGenerator generator_;
alignas(128) inline static thread_local RegisteredGenerator generator_;
};


Expand Down Expand Up @@ -255,6 +262,7 @@ namespace gapp::rng
template<std::integral IntType = int>
IntType randomBinomial(IntType n, double p);


/** Generate a random index for a container. */
template<detail::IndexableContainer T>
size_t randomIdx(const T& cont);
Expand All @@ -263,10 +271,12 @@ namespace gapp::rng
template<std::forward_iterator Iter>
Iter randomElement(Iter first, Iter last);


/** Generate @p count unique integers from the half-open range [@p lbound, @p ubound). */
template<std::integral IntType>
std::vector<IntType> sampleUnique(IntType lbound, IntType ubound, size_t count);


/** Select an index based on a discrete CDF. */
template<std::floating_point T>
size_t sampleCdf(std::span<const T> cdf);
Expand Down

0 comments on commit d649785

Please sign in to comment.