Skip to content

Commit 4d61953

Browse files
authored
implement mutex (#7)
* implement mutex * add lock guard
1 parent bc123c9 commit 4d61953

File tree

3 files changed

+244
-32
lines changed

3 files changed

+244
-32
lines changed

README.md

+24-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# cosched
22

3-
A simple c++20 header-only coroutine scheduler.
3+
A simple c++20 header-only coroutine scheduler that supports parallel task, timer, latch, and mutex.
44

5-
Let's start from several examples.
5+
It is not intend for a high performance coroutine framework that can be used in a real production environment.
6+
My goal is to implement all basic functions using the fewest possible lines of code.
67

78
# Example
89

@@ -107,6 +108,27 @@ int main() {
107108
}
108109
```
109110
111+
## Async mutex
112+
113+
Cosched supports coroutine mutex lock. Unlike a normal mutex, which blocks threads, the coroutine mutex only blocks the coroutines, allowing the worker thread to continue executing other tasks.
114+
```c++
115+
std::vector<int> v;
116+
coro::async_mutex mu;
117+
118+
auto push_task = [&]() -> coro::task<> {
119+
// create a lock guard type (same as the std::unique_lock).
120+
coro::async_lock l = co_await coro::async_lock::make_lock(mu);
121+
assert(l.owns_lock());
122+
std::cout << "push back task begin, timestamp="
123+
<< std::chrono::duration_cast<std::chrono::milliseconds>(
124+
std::chrono::steady_clock::now().time_since_epoch())
125+
.count()
126+
<< '\n';
127+
v.push_back(v.size());
128+
co_await coro::this_scheduler::sleep_for(10ms);
129+
};
130+
```
131+
110132
# Key Design
111133

112134
In this chapter I will introduce how this tiny scheduler works in behind. It involves the core concepts of the c++20 coroutine.

cosched.hpp

+178-28
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#pragma once
22
#include <atomic>
3-
#include <cassert>
43
#include <chrono>
54
#include <concepts>
65
#include <condition_variable>
76
#include <coroutine>
87
#include <cstddef>
8+
#include <cstdint>
99
#include <future>
1010
#include <memory>
1111
#include <mutex>
@@ -52,8 +52,10 @@ enum class task_type {
5252
async,
5353
};
5454

55-
class latch;
55+
class async_latch;
5656
class timer;
57+
class async_mutex;
58+
class async_lock;
5759

5860
namespace this_scheduler {
5961
inline always_awaiter yield;
@@ -72,23 +74,48 @@ struct coro_mutex_context {
7274
std::mutex mu;
7375
};
7476

75-
template <class Pred>
77+
struct mono {
78+
void operator()() const noexcept {}
79+
};
80+
81+
template <class Pred, class RetFn = mono>
7682
requires requires(Pred f, std::coroutine_handle<> h) {
7783
{ f(h) } -> std::same_as<bool>;
7884
}
7985
class condition_awaiter : public std::suspend_always {
8086
public:
81-
explicit condition_awaiter(Pred f) noexcept : f_(std::move(f)) {}
87+
explicit condition_awaiter(Pred f, RetFn r = {}) noexcept
88+
: f_(std::move(f)), r_(std::move(r)) {}
8289

8390
bool await_suspend(std::coroutine_handle<> h) const { return f_(h); }
8491

92+
std::invoke_result_t<RetFn> await_resume() {
93+
if constexpr (!std::is_same_v<void, std::invoke_result_t<RetFn>>) {
94+
return r_();
95+
}
96+
}
97+
8598
private:
8699
Pred f_;
100+
RetFn r_;
87101
};
88102

89103
template <class Tp>
90104
condition_awaiter(Tp) -> condition_awaiter<Tp>;
91105

106+
template <class Tp, class Rp>
107+
condition_awaiter(Tp, Rp) -> condition_awaiter<Tp, Rp>;
108+
109+
template <class Tp>
110+
struct enable_condition_awaiter_transform : std::false_type {};
111+
112+
template <class RetFn = mono>
113+
struct async_lock_token {
114+
inline auto wait_this(static_thread_pool*);
115+
async_mutex* mu;
116+
RetFn ret;
117+
};
118+
92119
template <class Tp>
93120
struct final_awaiter;
94121

@@ -114,11 +141,17 @@ class promise_base : public std::promise<Tp> {
114141

115142
inline final_awaiter<Tp> final_suspend() noexcept;
116143

117-
template <class NotAllowedAwaiter>
118-
void await_transform(NotAllowedAwaiter) = delete;
144+
template <class Up>
145+
requires enable_condition_awaiter_transform<Up>::value auto await_transform(
146+
Up&& u) noexcept {
147+
return condition_awaiter(u.wait_this(shared_ctx_->scheduler));
148+
}
119149

120-
inline auto await_transform(latch&) noexcept;
121-
inline auto await_transform(const timer&) noexcept;
150+
template <class Rp>
151+
auto await_transform(async_lock_token<Rp> att) noexcept {
152+
return condition_awaiter(att.wait_this(shared_ctx_->scheduler),
153+
std::move(att.ret));
154+
}
122155

123156
always_awaiter await_transform(always_awaiter) noexcept {
124157
return always_awaiter{shared_ctx_->scheduler};
@@ -401,8 +434,9 @@ class static_thread_pool {
401434
template <class Tp>
402435
friend struct details_::final_awaiter;
403436

404-
friend class latch;
437+
friend class async_latch;
405438
friend class timer;
439+
friend class async_mutex;
406440

407441
template <class Tp>
408442
void schedule(std::coroutine_handle<Tp> handle) {
@@ -572,11 +606,128 @@ details_::promise_base<Tp>::final_suspend() noexcept {
572606
return {this};
573607
}
574608

575-
class latch {
609+
class async_mutex {
610+
static constexpr uint64_t kMuLocked = 0x1;
611+
static constexpr uint64_t kMuWait = 0x2;
612+
576613
public:
577-
explicit latch(std::ptrdiff_t countdown) : countdown_(countdown) {}
578-
latch(const latch&) = delete;
579-
latch& operator=(const latch&) = delete;
614+
async_mutex() : state_(0) {}
615+
async_mutex(const async_mutex&) = delete;
616+
async_mutex& operator=(const async_mutex&) = delete;
617+
618+
details_::async_lock_token<> lock() { return {this}; }
619+
620+
void unlock() {
621+
uint64_t s = state_.load(std::memory_order_relaxed);
622+
if ((s & kMuWait) || !state_.compare_exchange_strong(s, s & ~kMuLocked)) {
623+
unlock_slow();
624+
}
625+
}
626+
627+
private:
628+
template <class Rp>
629+
friend struct details_::async_lock_token;
630+
631+
auto wait_this(static_thread_pool* scheduler) {
632+
return [scheduler, this](std::coroutine_handle<> this_coroutine) -> bool {
633+
uint64_t s = state_.load(std::memory_order_relaxed);
634+
if ((s & (kMuLocked | kMuWait)) ||
635+
!state_.compare_exchange_strong(s, s | kMuLocked,
636+
std::memory_order_acq_rel)) {
637+
return lock_slow(scheduler, this_coroutine);
638+
}
639+
return false;
640+
};
641+
}
642+
643+
bool lock_slow(static_thread_pool* scheduler,
644+
std::coroutine_handle<> this_coroutine) {
645+
std::unique_lock l(wait_ctx_.mu);
646+
state_.fetch_or(kMuWait, std::memory_order_relaxed);
647+
uint64_t s = state_.load(std::memory_order_relaxed);
648+
if (!(s & kMuLocked)) {
649+
state_.store((s | kMuLocked) & ~kMuWait, std::memory_order_relaxed);
650+
return false;
651+
}
652+
wait_ctx_.scheduler = scheduler;
653+
wait_ctx_.wait_ques.push_back(this_coroutine);
654+
return true;
655+
}
656+
657+
void unlock_slow() {
658+
std::unique_lock l(wait_ctx_.mu);
659+
uint64_t s = state_.load(std::memory_order_relaxed);
660+
if (wait_ctx_.wait_ques.empty()) {
661+
state_.store(s & ~kMuLocked, std::memory_order_relaxed);
662+
return;
663+
}
664+
auto h = wait_ctx_.wait_ques.front();
665+
wait_ctx_.wait_ques.pop_front();
666+
wait_ctx_.scheduler->schedule(h);
667+
}
668+
669+
std::atomic<uint64_t> state_;
670+
details_::coro_mutex_context wait_ctx_;
671+
};
672+
673+
template <class RetFn>
674+
auto details_::async_lock_token<RetFn>::wait_this(
675+
static_thread_pool* scheduler) {
676+
return mu->wait_this(scheduler);
677+
}
678+
679+
class async_lock {
680+
public:
681+
static auto make_lock(async_mutex& mu) {
682+
auto create_lock = [mu = &mu] { return async_lock(*mu); };
683+
return details_::async_lock_token<decltype(create_lock)>(
684+
&mu, std::move(create_lock));
685+
}
686+
687+
async_lock(async_mutex& mu, std::defer_lock_t) noexcept
688+
: mu_(&mu), owns_(false) {}
689+
690+
async_lock(async_mutex& mu, std::adopt_lock_t) noexcept
691+
: mu_(&mu), owns_(true) {}
692+
693+
details_::async_lock_token<> lock() {
694+
owns_ = true;
695+
return {mu_};
696+
}
697+
698+
async_lock(async_lock&& r) noexcept { *this = std::move(r); }
699+
700+
async_lock& operator=(async_lock&& r) noexcept {
701+
mu_ = r.mu_;
702+
owns_ = r.owns_;
703+
r.mu_ = nullptr;
704+
r.owns_ = false;
705+
return *this;
706+
}
707+
708+
void unlock() {
709+
mu_->unlock();
710+
owns_ = false;
711+
}
712+
713+
bool owns_lock() const noexcept { return owns_; }
714+
715+
~async_lock() {
716+
if (owns_ && mu_) mu_->unlock();
717+
}
718+
719+
private:
720+
async_lock(async_mutex& mu) : mu_(&mu), owns_(true) {}
721+
722+
async_mutex* mu_;
723+
bool owns_;
724+
};
725+
726+
class async_latch {
727+
public:
728+
explicit async_latch(std::ptrdiff_t countdown) : countdown_(countdown) {}
729+
async_latch(const async_latch&) = delete;
730+
async_latch& operator=(const async_latch&) = delete;
580731

581732
void count_down(std::ptrdiff_t n = 1) {
582733
auto before = countdown_.fetch_sub(n, std::memory_order_acq_rel);
@@ -594,13 +745,13 @@ class latch {
594745
friend class details_::promise_base;
595746

596747
auto wait_this(static_thread_pool* scheduler) {
597-
return [scheduler, this](std::coroutine_handle<> h) -> bool {
748+
return [scheduler, this](std::coroutine_handle<> this_coroutine) -> bool {
598749
std::unique_lock l(wait_ctx_.mu);
599750
if (countdown_.load(std::memory_order_acquire) <= 0) {
600751
return false;
601752
}
602753
wait_ctx_.scheduler = scheduler;
603-
wait_ctx_.wait_ques.push_back(h);
754+
wait_ctx_.wait_ques.push_back(this_coroutine);
604755
return true;
605756
};
606757
}
@@ -609,6 +760,11 @@ class latch {
609760
details_::coro_mutex_context wait_ctx_;
610761
};
611762

763+
namespace details_ {
764+
template <>
765+
struct enable_condition_awaiter_transform<async_latch&> : std::true_type {};
766+
} // namespace details_
767+
612768
class timer {
613769
public:
614770
using duration = details_::time_manager::duration;
@@ -623,35 +779,29 @@ class timer {
623779
friend class details_::promise_base;
624780

625781
auto wait_this(static_thread_pool* scheduler) const {
626-
return [scheduler, this](std::coroutine_handle<> h) -> bool {
782+
return [scheduler, this](std::coroutine_handle<> this_coroutine) -> bool {
627783
if (!scheduler) {
628784
std::this_thread::sleep_for(t_);
629785
return false;
630786
}
631-
scheduler->schedule_timer(h, t_);
787+
scheduler->schedule_timer(this_coroutine, t_);
632788
return true;
633789
};
634790
}
635791

636792
duration t_;
637793
};
638794

795+
namespace details_ {
796+
template <>
797+
struct enable_condition_awaiter_transform<timer> : std::true_type {};
798+
} // namespace details_
799+
639800
template <class Rep, class Period>
640801
timer this_scheduler::sleep_for(const std::chrono::duration<Rep, Period>& rel) {
641802
return timer(rel);
642803
}
643804

644-
template <class Tp>
645-
inline auto details_::promise_base<Tp>::await_transform(latch& l) noexcept {
646-
return condition_awaiter(l.wait_this(shared_ctx_->scheduler));
647-
}
648-
649-
template <class Tp>
650-
inline auto details_::promise_base<Tp>::await_transform(
651-
const timer& t) noexcept {
652-
return condition_awaiter(t.wait_this(shared_ctx_->scheduler));
653-
}
654-
655805
template <class Tp>
656806
template <class Up>
657807
inline async_awaiter<Up> details_::promise_base<Tp>::await_transform(

0 commit comments

Comments
 (0)