Skip to content

Commit 8bc7fa2

Browse files
dfmGoogle-ML-Automation
authored andcommitted
[XLA:FFI] Add an FFI compatible implementation of tsl::CountDownAsyncValueRef.
This supports the common pattern of enqueuing a specific number of async tasks within an FFI handler. PiperOrigin-RevId: 722455855
1 parent 492a921 commit 8bc7fa2

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed

xla/ffi/api/ffi.h

+80
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ limitations under the License.
3232
#include <iostream>
3333
#include <limits>
3434
#include <memory>
35+
#include <mutex> // NOLINT
3536
#include <numeric>
3637
#include <optional>
3738
#include <ostream>
@@ -322,10 +323,14 @@ class ErrorOr : public Expected<T, Error> {
322323
// A promise to complete execution with a success or an error.
323324
class Promise;
324325

326+
// A promise that completes when a specific number of count downs have occurred.
327+
class CountDownPromise;
328+
325329
// A future that becomes available when a corresponding promise is completed.
326330
class Future {
327331
public:
328332
explicit Future(const Promise& promise);
333+
explicit Future(const CountDownPromise& promise);
329334

330335
Future(Future&&) = default;
331336
Future& operator=(Future&&) = default;
@@ -377,6 +382,9 @@ class Promise {
377382
public:
378383
Promise() : data_(std::make_shared<Future::Data>()) {}
379384

385+
Promise(const Promise&) = default;
386+
Promise& operator=(const Promise&) = default;
387+
380388
Promise(Promise&&) = default;
381389
Promise& operator=(Promise&&) = default;
382390

@@ -391,11 +399,83 @@ class Promise {
391399
std::shared_ptr<Future::Data> data_;
392400
};
393401

402+
// A simple implementation of `tsl::CountDownAsyncValueRef` that is compatible
403+
// with `ffi::Future`.
404+
class CountDownPromise {
405+
public:
406+
CountDownPromise() = default;
407+
408+
CountDownPromise(Promise promise, int64_t count)
409+
: state_(std::make_shared<State>(std::move(promise), count)) {
410+
assert(count > 0 && "Count must be positive");
411+
}
412+
413+
explicit CountDownPromise(int64_t count)
414+
: CountDownPromise(Promise(), count) {}
415+
416+
// Drops the count by `count` and returns true if the underlying promise
417+
// became available.
418+
bool CountDown(size_t count, const Error& error = Error::Success()) {
419+
assert(state_->count.load() >= count && "Invalid count down value");
420+
421+
if (XLA_FFI_PREDICT_FALSE(!error.success())) {
422+
const std::lock_guard<std::mutex> lock(state_->mutex);
423+
state_->is_error.store(true, std::memory_order_release);
424+
state_->error = error;
425+
}
426+
427+
bool is_complete =
428+
state_->count.fetch_sub(count, std::memory_order_acq_rel) == count;
429+
if (XLA_FFI_PREDICT_FALSE(is_complete)) {
430+
bool is_error = state_->is_error.load(std::memory_order_acquire);
431+
if (XLA_FFI_PREDICT_FALSE(is_error)) {
432+
auto take_error = [&] {
433+
const std::lock_guard<std::mutex> lock(state_->mutex);
434+
return state_->error;
435+
};
436+
state_->promise.SetError(take_error());
437+
return true;
438+
} else {
439+
state_->promise.SetAvailable();
440+
return true;
441+
}
442+
}
443+
444+
return false;
445+
}
446+
447+
// Drops the count by `1` and returns true if the underlying promise became
448+
// available.
449+
bool CountDown(Error error = Error::Success()) { return CountDown(1, error); }
450+
451+
private:
452+
friend class Future;
453+
454+
struct State {
455+
State(Promise promise, int64_t count)
456+
: promise(std::move(promise)), count(count), is_error(false) {}
457+
458+
Promise promise;
459+
std::atomic<int64_t> count;
460+
std::atomic<bool> is_error;
461+
462+
std::mutex mutex;
463+
Error error;
464+
};
465+
466+
std::shared_ptr<State> state_;
467+
468+
const Promise& AsPromise() const { return state_->promise; }
469+
};
470+
394471
inline Future::Future(const Promise& promise) : data_(promise.data_) {
395472
assert(data_.use_count() == 2 &&
396473
"Promise can be used to create at most one Future");
397474
}
398475

476+
inline Future::Future(const CountDownPromise& promise)
477+
: Future(promise.AsPromise()) {}
478+
399479
template <typename F>
400480
void Future::OnReady(F&& f) {
401481
static_assert(std::is_invocable_v<F, const std::optional<Error>&>,

xla/ffi/api/ffi_test.cc

+55
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,61 @@ TEST(FfiTest, FutureRace) {
352352
}
353353
}
354354

355+
TEST(FfiTest, CountDownSuccess) {
356+
CountDownPromise counter(2);
357+
Future future(counter);
358+
EXPECT_FALSE(counter.CountDown());
359+
EXPECT_TRUE(counter.CountDown());
360+
future.OnReady([](const std::optional<Error>& error) {
361+
EXPECT_FALSE(error.has_value());
362+
});
363+
}
364+
365+
TEST(FfiTest, CountDownError) {
366+
CountDownPromise counter(3);
367+
Future future(counter);
368+
EXPECT_FALSE(counter.CountDown());
369+
EXPECT_FALSE(counter.CountDown(Error(ErrorCode::kInternal, "Test error")));
370+
EXPECT_TRUE(counter.CountDown());
371+
future.OnReady([](const std::optional<Error>& error) {
372+
EXPECT_TRUE(error.has_value());
373+
EXPECT_THAT(error->message(), HasSubstr("Test error"));
374+
});
375+
}
376+
377+
TEST(FfiTest, CountDownSuccessFromThreadPool) {
378+
tsl::thread::ThreadPool pool(tsl::Env::Default(), "ffi-test", 2);
379+
380+
CountDownPromise counter(2);
381+
Future future(counter);
382+
383+
future.OnReady([](const std::optional<Error>& error) {
384+
EXPECT_FALSE(error.has_value());
385+
});
386+
387+
for (int64_t i = 0; i < 2; ++i) {
388+
pool.Schedule([counter]() mutable { counter.CountDown(); });
389+
}
390+
}
391+
392+
TEST(FfiTest, CountDownErrorFromThreadPool) {
393+
tsl::thread::ThreadPool pool(tsl::Env::Default(), "ffi-test", 2);
394+
395+
CountDownPromise counter(3);
396+
Future future(counter);
397+
398+
future.OnReady([](const std::optional<Error>& error) {
399+
EXPECT_TRUE(error.has_value());
400+
EXPECT_THAT(error->message(), HasSubstr("Test error"));
401+
});
402+
403+
pool.Schedule([counter]() mutable { counter.CountDown(); });
404+
pool.Schedule([counter]() mutable {
405+
counter.CountDown(Error(ErrorCode::kInternal, "Test error"));
406+
});
407+
pool.Schedule([counter]() mutable { counter.CountDown(); });
408+
}
409+
355410
TEST(FfiTest, ReturnError) {
356411
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
357412
auto call_frame = builder.Build();

0 commit comments

Comments
 (0)