Skip to content

Commit df24a90

Browse files
sizhit2tensorflower-gardener
authored andcommitted
Add RAII helper class MarkEventReadyOnExit for GpuEvent
PiperOrigin-RevId: 735962674
1 parent 516528b commit df24a90

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

third_party/xla/xla/pjrt/gpu/tfrt/gpu_event.h

+26
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#define XLA_PJRT_GPU_TFRT_GPU_EVENT_H_
1818

1919
#include <cstddef>
20+
#include <utility>
2021

2122
#include "absl/container/inlined_vector.h"
2223
#include "absl/types/span.h"
@@ -63,6 +64,31 @@ class TfrtEventSet {
6364
absl::InlinedVector<tsl::AsyncValueRef<GpuEvent>, 4> events_;
6465
};
6566

67+
// A RAII helper class used to set an AsyncValueRef<GpuEvent> to a ready state
68+
// upon destruction. In many cases in PjRt implementation, there will be
69+
// multiple return statements in the function, all of which require setting
70+
// some AsyncValueRef<GpuEvent> to be ready. This class could make such code
71+
// more robust by using setting the AsyncValue in the destructor.
72+
class MarkEventReadyOnExit {
73+
public:
74+
explicit MarkEventReadyOnExit(tsl::AsyncValueRef<GpuEvent> event)
75+
: event_(std::move(event)) {}
76+
77+
MarkEventReadyOnExit(const MarkEventReadyOnExit&) = delete;
78+
MarkEventReadyOnExit& operator=(const MarkEventReadyOnExit&) = delete;
79+
MarkEventReadyOnExit(MarkEventReadyOnExit&&) = default;
80+
MarkEventReadyOnExit& operator=(MarkEventReadyOnExit&&) = default;
81+
82+
~MarkEventReadyOnExit() {
83+
if (event_) event_.SetStateConcrete();
84+
}
85+
86+
tsl::AsyncValueRef<GpuEvent> Release() && { return std::move(event_); }
87+
88+
private:
89+
tsl::AsyncValueRef<GpuEvent> event_;
90+
};
91+
6692
} // namespace xla
6793

6894
#endif // XLA_PJRT_GPU_TFRT_GPU_EVENT_H_

third_party/xla/xla/pjrt/gpu/tfrt/gpu_event_test.cc

+13
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,18 @@ TEST(TfrtEventSetTest, ClearEvents) {
111111
EXPECT_EQ(event_set.size(), 0);
112112
}
113113

114+
TEST(MarkEventReadyOnExitTest, EventReleaseAndReadyOnExit) {
115+
tsl::AsyncValueRef<GpuEvent> event =
116+
tsl::MakeConstructedAsyncValueRef<GpuEvent>();
117+
tsl::AsyncValueRef<GpuEvent> released_event =
118+
MarkEventReadyOnExit(event).Release();
119+
EXPECT_EQ(event.GetAsyncValue(), released_event.GetAsyncValue());
120+
{
121+
MarkEventReadyOnExit ready_on_exit(event);
122+
EXPECT_FALSE(event.IsAvailable());
123+
}
124+
EXPECT_TRUE(event.IsAvailable());
125+
}
126+
114127
} // namespace
115128
} // namespace xla

0 commit comments

Comments
 (0)