1
1
#pragma once
2
2
#include < atomic>
3
- #include < cassert>
4
3
#include < chrono>
5
4
#include < concepts>
6
5
#include < condition_variable>
7
6
#include < coroutine>
8
7
#include < cstddef>
8
+ #include < cstdint>
9
9
#include < future>
10
10
#include < memory>
11
11
#include < mutex>
@@ -52,8 +52,10 @@ enum class task_type {
52
52
async,
53
53
};
54
54
55
- class latch ;
55
+ class async_latch ;
56
56
class timer ;
57
+ class async_mutex ;
58
+ class async_lock ;
57
59
58
60
namespace this_scheduler {
59
61
inline always_awaiter yield;
@@ -72,23 +74,48 @@ struct coro_mutex_context {
72
74
std::mutex mu;
73
75
};
74
76
75
- template <class Pred >
77
+ struct mono {
78
+ void operator ()() const noexcept {}
79
+ };
80
+
81
+ template <class Pred , class RetFn = mono>
76
82
requires requires (Pred f, std::coroutine_handle<> h) {
77
83
{ f (h) } -> std::same_as<bool >;
78
84
}
79
85
class condition_awaiter : public std ::suspend_always {
80
86
public:
81
- explicit condition_awaiter (Pred f) noexcept : f_(std::move(f)) {}
87
+ explicit condition_awaiter (Pred f, RetFn r = {}) noexcept
88
+ : f_(std::move(f)), r_(std::move(r)) {}
82
89
83
90
bool await_suspend (std::coroutine_handle<> h) const { return f_ (h); }
84
91
92
+ std::invoke_result_t <RetFn> await_resume () {
93
+ if constexpr (!std::is_same_v<void , std::invoke_result_t <RetFn>>) {
94
+ return r_ ();
95
+ }
96
+ }
97
+
85
98
private:
86
99
Pred f_;
100
+ RetFn r_;
87
101
};
88
102
89
103
template <class Tp >
90
104
condition_awaiter (Tp) -> condition_awaiter<Tp>;
91
105
106
+ template <class Tp , class Rp >
107
+ condition_awaiter (Tp, Rp) -> condition_awaiter<Tp, Rp>;
108
+
109
+ template <class Tp >
110
+ struct enable_condition_awaiter_transform : std::false_type {};
111
+
112
+ template <class RetFn = mono>
113
+ struct async_lock_token {
114
+ inline auto wait_this (static_thread_pool*);
115
+ async_mutex* mu;
116
+ RetFn ret;
117
+ };
118
+
92
119
template <class Tp >
93
120
struct final_awaiter ;
94
121
@@ -114,11 +141,17 @@ class promise_base : public std::promise<Tp> {
114
141
115
142
inline final_awaiter<Tp> final_suspend () noexcept ;
116
143
117
- template <class NotAllowedAwaiter >
118
- void await_transform (NotAllowedAwaiter) = delete;
144
+ template <class Up >
145
+ requires enable_condition_awaiter_transform<Up>::value auto await_transform (
146
+ Up&& u) noexcept {
147
+ return condition_awaiter (u.wait_this (shared_ctx_->scheduler ));
148
+ }
119
149
120
- inline auto await_transform (latch&) noexcept ;
121
- inline auto await_transform (const timer&) noexcept ;
150
+ template <class Rp >
151
+ auto await_transform (async_lock_token<Rp> att) noexcept {
152
+ return condition_awaiter (att.wait_this (shared_ctx_->scheduler ),
153
+ std::move (att.ret ));
154
+ }
122
155
123
156
always_awaiter await_transform (always_awaiter) noexcept {
124
157
return always_awaiter{shared_ctx_->scheduler };
@@ -401,8 +434,9 @@ class static_thread_pool {
401
434
template <class Tp >
402
435
friend struct details_ ::final_awaiter;
403
436
404
- friend class latch ;
437
+ friend class async_latch ;
405
438
friend class timer ;
439
+ friend class async_mutex ;
406
440
407
441
template <class Tp >
408
442
void schedule (std::coroutine_handle<Tp> handle) {
@@ -572,11 +606,128 @@ details_::promise_base<Tp>::final_suspend() noexcept {
572
606
return {this };
573
607
}
574
608
575
- class latch {
609
+ class async_mutex {
610
+ static constexpr uint64_t kMuLocked = 0x1 ;
611
+ static constexpr uint64_t kMuWait = 0x2 ;
612
+
576
613
public:
577
- explicit latch (std::ptrdiff_t countdown) : countdown_(countdown) {}
578
- latch (const latch&) = delete ;
579
- latch& operator =(const latch&) = delete ;
614
+ async_mutex () : state_(0 ) {}
615
+ async_mutex (const async_mutex&) = delete ;
616
+ async_mutex& operator =(const async_mutex&) = delete ;
617
+
618
+ details_::async_lock_token<> lock () { return {this }; }
619
+
620
+ void unlock () {
621
+ uint64_t s = state_.load (std::memory_order_relaxed);
622
+ if ((s & kMuWait ) || !state_.compare_exchange_strong (s, s & ~kMuLocked )) {
623
+ unlock_slow ();
624
+ }
625
+ }
626
+
627
+ private:
628
+ template <class Rp >
629
+ friend struct details_ ::async_lock_token;
630
+
631
+ auto wait_this (static_thread_pool* scheduler) {
632
+ return [scheduler, this ](std::coroutine_handle<> this_coroutine) -> bool {
633
+ uint64_t s = state_.load (std::memory_order_relaxed);
634
+ if ((s & (kMuLocked | kMuWait )) ||
635
+ !state_.compare_exchange_strong (s, s | kMuLocked ,
636
+ std::memory_order_acq_rel)) {
637
+ return lock_slow (scheduler, this_coroutine);
638
+ }
639
+ return false ;
640
+ };
641
+ }
642
+
643
+ bool lock_slow (static_thread_pool* scheduler,
644
+ std::coroutine_handle<> this_coroutine) {
645
+ std::unique_lock l (wait_ctx_.mu );
646
+ state_.fetch_or (kMuWait , std::memory_order_relaxed);
647
+ uint64_t s = state_.load (std::memory_order_relaxed);
648
+ if (!(s & kMuLocked )) {
649
+ state_.store ((s | kMuLocked ) & ~kMuWait , std::memory_order_relaxed);
650
+ return false ;
651
+ }
652
+ wait_ctx_.scheduler = scheduler;
653
+ wait_ctx_.wait_ques .push_back (this_coroutine);
654
+ return true ;
655
+ }
656
+
657
+ void unlock_slow () {
658
+ std::unique_lock l (wait_ctx_.mu );
659
+ uint64_t s = state_.load (std::memory_order_relaxed);
660
+ if (wait_ctx_.wait_ques .empty ()) {
661
+ state_.store (s & ~kMuLocked , std::memory_order_relaxed);
662
+ return ;
663
+ }
664
+ auto h = wait_ctx_.wait_ques .front ();
665
+ wait_ctx_.wait_ques .pop_front ();
666
+ wait_ctx_.scheduler ->schedule (h);
667
+ }
668
+
669
+ std::atomic<uint64_t > state_;
670
+ details_::coro_mutex_context wait_ctx_;
671
+ };
672
+
673
+ template <class RetFn >
674
+ auto details_::async_lock_token<RetFn>::wait_this(
675
+ static_thread_pool* scheduler) {
676
+ return mu->wait_this (scheduler);
677
+ }
678
+
679
+ class async_lock {
680
+ public:
681
+ static auto make_lock (async_mutex& mu) {
682
+ auto create_lock = [mu = &mu] { return async_lock (*mu); };
683
+ return details_::async_lock_token<decltype (create_lock)>(
684
+ &mu, std::move (create_lock));
685
+ }
686
+
687
+ async_lock (async_mutex& mu, std::defer_lock_t ) noexcept
688
+ : mu_(&mu), owns_(false ) {}
689
+
690
+ async_lock (async_mutex& mu, std::adopt_lock_t ) noexcept
691
+ : mu_(&mu), owns_(true ) {}
692
+
693
+ details_::async_lock_token<> lock () {
694
+ owns_ = true ;
695
+ return {mu_};
696
+ }
697
+
698
+ async_lock (async_lock&& r) noexcept { *this = std::move (r); }
699
+
700
+ async_lock& operator =(async_lock&& r) noexcept {
701
+ mu_ = r.mu_ ;
702
+ owns_ = r.owns_ ;
703
+ r.mu_ = nullptr ;
704
+ r.owns_ = false ;
705
+ return *this ;
706
+ }
707
+
708
+ void unlock () {
709
+ mu_->unlock ();
710
+ owns_ = false ;
711
+ }
712
+
713
+ bool owns_lock () const noexcept { return owns_; }
714
+
715
+ ~async_lock () {
716
+ if (owns_ && mu_) mu_->unlock ();
717
+ }
718
+
719
+ private:
720
+ async_lock (async_mutex& mu) : mu_(&mu), owns_(true ) {}
721
+
722
+ async_mutex* mu_;
723
+ bool owns_;
724
+ };
725
+
726
+ class async_latch {
727
+ public:
728
+ explicit async_latch (std::ptrdiff_t countdown) : countdown_(countdown) {}
729
+ async_latch (const async_latch&) = delete ;
730
+ async_latch& operator =(const async_latch&) = delete ;
580
731
581
732
void count_down (std::ptrdiff_t n = 1 ) {
582
733
auto before = countdown_.fetch_sub (n, std::memory_order_acq_rel);
@@ -594,13 +745,13 @@ class latch {
594
745
friend class details_ ::promise_base;
595
746
596
747
auto wait_this (static_thread_pool* scheduler) {
597
- return [scheduler, this ](std::coroutine_handle<> h ) -> bool {
748
+ return [scheduler, this ](std::coroutine_handle<> this_coroutine ) -> bool {
598
749
std::unique_lock l (wait_ctx_.mu );
599
750
if (countdown_.load (std::memory_order_acquire) <= 0 ) {
600
751
return false ;
601
752
}
602
753
wait_ctx_.scheduler = scheduler;
603
- wait_ctx_.wait_ques .push_back (h );
754
+ wait_ctx_.wait_ques .push_back (this_coroutine );
604
755
return true ;
605
756
};
606
757
}
@@ -609,6 +760,11 @@ class latch {
609
760
details_::coro_mutex_context wait_ctx_;
610
761
};
611
762
763
+ namespace details_ {
764
+ template <>
765
+ struct enable_condition_awaiter_transform <async_latch&> : std::true_type {};
766
+ } // namespace details_
767
+
612
768
class timer {
613
769
public:
614
770
using duration = details_::time_manager::duration;
@@ -623,35 +779,29 @@ class timer {
623
779
friend class details_ ::promise_base;
624
780
625
781
auto wait_this (static_thread_pool* scheduler) const {
626
- return [scheduler, this ](std::coroutine_handle<> h ) -> bool {
782
+ return [scheduler, this ](std::coroutine_handle<> this_coroutine ) -> bool {
627
783
if (!scheduler) {
628
784
std::this_thread::sleep_for (t_);
629
785
return false ;
630
786
}
631
- scheduler->schedule_timer (h , t_);
787
+ scheduler->schedule_timer (this_coroutine , t_);
632
788
return true ;
633
789
};
634
790
}
635
791
636
792
duration t_;
637
793
};
638
794
795
+ namespace details_ {
796
+ template <>
797
+ struct enable_condition_awaiter_transform <timer> : std::true_type {};
798
+ } // namespace details_
799
+
639
800
template <class Rep , class Period >
640
801
timer this_scheduler::sleep_for (const std::chrono::duration<Rep, Period>& rel) {
641
802
return timer (rel);
642
803
}
643
804
644
- template <class Tp >
645
- inline auto details_::promise_base<Tp>::await_transform(latch& l) noexcept {
646
- return condition_awaiter (l.wait_this (shared_ctx_->scheduler ));
647
- }
648
-
649
- template <class Tp >
650
- inline auto details_::promise_base<Tp>::await_transform(
651
- const timer& t) noexcept {
652
- return condition_awaiter (t.wait_this (shared_ctx_->scheduler ));
653
- }
654
-
655
805
template <class Tp >
656
806
template <class Up >
657
807
inline async_awaiter<Up> details_::promise_base<Tp>::await_transform(
0 commit comments