@@ -32,6 +32,7 @@ limitations under the License.
32
32
#include < iostream>
33
33
#include < limits>
34
34
#include < memory>
35
+ #include < mutex> // NOLINT
35
36
#include < numeric>
36
37
#include < optional>
37
38
#include < ostream>
@@ -322,10 +323,14 @@ class ErrorOr : public Expected<T, Error> {
322
323
// A promise to complete execution with a success or an error.
323
324
class Promise ;
324
325
326
+ // A promise that completes when a specific number of count downs have occurred.
327
+ class CountDownPromise ;
328
+
325
329
// A future that becomes available when a corresponding promise is completed.
326
330
class Future {
327
331
public:
328
332
explicit Future (const Promise& promise);
333
+ explicit Future (const CountDownPromise& promise);
329
334
330
335
Future (Future&&) = default ;
331
336
Future& operator =(Future&&) = default ;
@@ -377,6 +382,9 @@ class Promise {
377
382
public:
378
383
Promise () : data_(std::make_shared<Future::Data>()) {}
379
384
385
+ Promise (const Promise&) = default ;
386
+ Promise& operator =(const Promise&) = default ;
387
+
380
388
Promise (Promise&&) = default ;
381
389
Promise& operator =(Promise&&) = default ;
382
390
@@ -391,11 +399,83 @@ class Promise {
391
399
std::shared_ptr<Future::Data> data_;
392
400
};
393
401
402
+ // A simple implementation of `tsl::CountDownAsyncValueRef` that is compatible
403
+ // with `ffi::Future`.
404
+ class CountDownPromise {
405
+ public:
406
+ CountDownPromise () = default ;
407
+
408
+ CountDownPromise (Promise promise, int64_t count)
409
+ : state_(std::make_shared<State>(std::move(promise), count)) {
410
+ assert (count > 0 && " Count must be positive" );
411
+ }
412
+
413
+ explicit CountDownPromise (int64_t count)
414
+ : CountDownPromise(Promise(), count) {}
415
+
416
+ // Drops the count by `count` and returns true if the underlying promise
417
+ // became available.
418
+ bool CountDown (size_t count, const Error& error = Error::Success()) {
419
+ assert (state_->count .load () >= count && " Invalid count down value" );
420
+
421
+ if (XLA_FFI_PREDICT_FALSE (!error.success ())) {
422
+ const std::lock_guard<std::mutex> lock (state_->mutex );
423
+ state_->is_error .store (true , std::memory_order_release);
424
+ state_->error = error;
425
+ }
426
+
427
+ bool is_complete =
428
+ state_->count .fetch_sub (count, std::memory_order_acq_rel) == count;
429
+ if (XLA_FFI_PREDICT_FALSE (is_complete)) {
430
+ bool is_error = state_->is_error .load (std::memory_order_acquire);
431
+ if (XLA_FFI_PREDICT_FALSE (is_error)) {
432
+ auto take_error = [&] {
433
+ const std::lock_guard<std::mutex> lock (state_->mutex );
434
+ return state_->error ;
435
+ };
436
+ state_->promise .SetError (take_error ());
437
+ return true ;
438
+ } else {
439
+ state_->promise .SetAvailable ();
440
+ return true ;
441
+ }
442
+ }
443
+
444
+ return false ;
445
+ }
446
+
447
+ // Drops the count by `1` and returns true if the underlying promise became
448
+ // available.
449
+ bool CountDown (Error error = Error::Success()) { return CountDown (1 , error); }
450
+
451
+ private:
452
+ friend class Future ;
453
+
454
+ struct State {
455
+ State (Promise promise, int64_t count)
456
+ : promise(std::move(promise)), count(count), is_error(false ) {}
457
+
458
+ Promise promise;
459
+ std::atomic<int64_t > count;
460
+ std::atomic<bool > is_error;
461
+
462
+ std::mutex mutex;
463
+ Error error;
464
+ };
465
+
466
+ std::shared_ptr<State> state_;
467
+
468
+ const Promise& AsPromise () const { return state_->promise ; }
469
+ };
470
+
394
471
inline Future::Future (const Promise& promise) : data_(promise.data_) {
395
472
assert (data_.use_count () == 2 &&
396
473
" Promise can be used to create at most one Future" );
397
474
}
398
475
476
+ inline Future::Future (const CountDownPromise& promise)
477
+ : Future(promise.AsPromise()) {}
478
+
399
479
template <typename F>
400
480
void Future::OnReady (F&& f) {
401
481
static_assert (std::is_invocable_v<F, const std::optional<Error>&>,
0 commit comments