@@ -63,6 +63,8 @@ limitations under the License.
63
63
#include " xla/pjrt/pjrt_future.h"
64
64
#include " xla/pjrt/pjrt_stream_executor_client.h"
65
65
#include " xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
66
+ #include " xla/pjrt/profiling/device_time_measurement.h"
67
+ #include " xla/pjrt/profiling/test_util/mock_device_time_measurement.h"
66
68
#include " xla/service/platform_util.h"
67
69
#include " xla/shape.h"
68
70
#include " xla/shape_util.h"
@@ -1875,6 +1877,66 @@ TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) {
1875
1877
EXPECT_NE (layouts[1 ]->ToString (), " {2,1,0}" );
1876
1878
}
1877
1879
1880
+ // Same test as SendRecvChunked, but check GPU device time measurement.
1881
+ TEST (StreamExecutorGpuClientTest, NonZeroGPUDeviceTimeMeasurement) {
1882
+ TF_ASSERT_OK_AND_ASSIGN (auto client,
1883
+ GetStreamExecutorGpuClient (GpuClientOptions ()));
1884
+
1885
+ TF_ASSERT_OK_AND_ASSIGN (auto executable,
1886
+ CompileExecutable (kProgram , *client));
1887
+
1888
+ std::array<float , 2 > sent_value = {0 .0f , 0 .0f };
1889
+
1890
+ // Send buffer to host.
1891
+ SendCallback send_callback = {
1892
+ /* channel_id=*/ 1 , [&](const PjRtTransferMetadata& m, PjRtChunk chunk,
1893
+ int64_t total_size_in_bytes, bool done) {
1894
+ float * data = reinterpret_cast <float *>(chunk.data ());
1895
+ sent_value[0 ] = data[0 ];
1896
+ sent_value[1 ] = data[1 ];
1897
+ return absl::OkStatus ();
1898
+ }};
1899
+
1900
+ // Recv buffer from host.
1901
+ RecvCallback recv_callback = {
1902
+ /* channel_id=*/ 2 , [&](const PjRtTransferMetadata& m,
1903
+ std::unique_ptr<CopyToDeviceStream> stream) {
1904
+ auto chunk0 = PjRtChunk::AllocateDefault (sizeof (float ));
1905
+ *reinterpret_cast <float *>(chunk0.data ()) = 5 .0f ;
1906
+ TF_CHECK_OK (stream->AddChunk (std::move (chunk0)).Await ());
1907
+
1908
+ auto chunk1 = PjRtChunk::AllocateDefault (sizeof (float ));
1909
+ *reinterpret_cast <float *>(chunk1.data ()) = 6 .0f ;
1910
+ TF_CHECK_OK (stream->AddChunk (std::move (chunk1)).Await ());
1911
+
1912
+ return absl::OkStatus ();
1913
+ }};
1914
+
1915
+ // Callbacks for point-to-point communication ops.
1916
+ std::vector<std::vector<SendCallback>> send_callbacks = {{send_callback}};
1917
+ std::vector<std::vector<RecvCallback>> recv_callbacks = {{recv_callback}};
1918
+
1919
+ ExecuteOptions opts;
1920
+ opts.send_callbacks = send_callbacks;
1921
+ opts.recv_callbacks = recv_callbacks;
1922
+
1923
+ // Test non-zero GPU device time measurement.
1924
+ auto measurement0 = CreateDeviceTimeMeasurement ();
1925
+ auto result = executable->Execute (/* argument_handles=*/ {{}}, opts);
1926
+
1927
+ TF_ASSERT_OK_AND_ASSIGN (std::shared_ptr<xla::Literal> result_literal,
1928
+ ExtractSingleResult (result));
1929
+ EXPECT_EQ (sent_value[0 ], 2 .0f );
1930
+ EXPECT_EQ (sent_value[1 ], 3 .0f );
1931
+ EXPECT_TRUE (LiteralTestUtil::Equal (LiteralUtil::CreateR1<float >({5 .0f , 6 .0f }),
1932
+ *result_literal));
1933
+
1934
+ // Check measurement after execution completes.
1935
+ EXPECT_GT (
1936
+ measurement0->GetTotalDuration (DeviceTimeMeasurement::DeviceType::kGpu ),
1937
+ absl::ZeroDuration ());
1938
+ }
1939
+
1878
1940
struct ShardedAutotuningTestInfo {
1879
1941
bool use_xla_computation;
1880
1942
int num_active_nodes;
0 commit comments