Skip to content

Commit 3352841

Browse files
Record device time measurements in PJRT stream executor client. Set device type to the platform that the client is running on.
PiperOrigin-RevId: 725859515
1 parent d6be12c commit 3352841

11 files changed

+345
-20
lines changed

xla/pjrt/BUILD

+6-1
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,9 @@ cc_library(
479479

480480
cc_library(
481481
name = "pjrt_stream_executor_client",
482-
srcs = ["pjrt_stream_executor_client.cc"],
482+
srcs = [
483+
"pjrt_stream_executor_client.cc",
484+
],
483485
hdrs = ["pjrt_stream_executor_client.h"],
484486
visibility = internal_visibility(["//xla:friends"]),
485487
deps = [
@@ -511,6 +513,7 @@ cc_library(
511513
"//xla/hlo/builder:xla_computation",
512514
"//xla/hlo/ir:hlo",
513515
"//xla/pjrt/distributed:protocol_proto_cc",
516+
"//xla/pjrt/profiling:device_time_measurement",
514517
"//xla/service:compiler",
515518
"//xla/service:computation_layout",
516519
"//xla/service:computation_placer",
@@ -572,6 +575,7 @@ xla_cc_test(
572575
"//xla/client:local_client",
573576
"//xla/hlo/builder:xla_builder",
574577
"//xla/hlo/testlib:test",
578+
"//xla/pjrt/profiling:device_time_measurement",
575579
"//xla/service:cpu_plugin",
576580
"//xla/service:platform_util",
577581
"//xla/stream_executor:platform",
@@ -582,6 +586,7 @@ xla_cc_test(
582586
"@com_google_absl//absl/status",
583587
"@com_google_absl//absl/status:statusor",
584588
"@com_google_absl//absl/synchronization",
589+
"@com_google_absl//absl/time",
585590
"@com_google_googletest//:gtest_main",
586591
],
587592
)

xla/pjrt/gpu/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ xla_cc_test(
190190
"//xla/pjrt/distributed:client",
191191
"//xla/pjrt/distributed:in_memory_key_value_store",
192192
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
193+
"//xla/pjrt/profiling:device_time_measurement",
194+
"//xla/pjrt/profiling/test_util:mock_device_time_measurement",
193195
"//xla/service:gpu_plugin",
194196
"//xla/service:platform_util",
195197
"//xla/stream_executor:device_memory",

xla/pjrt/gpu/se_gpu_pjrt_client_test.cc

+62
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ limitations under the License.
6363
#include "xla/pjrt/pjrt_future.h"
6464
#include "xla/pjrt/pjrt_stream_executor_client.h"
6565
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
66+
#include "xla/pjrt/profiling/device_time_measurement.h"
67+
#include "xla/pjrt/profiling/test_util/mock_device_time_measurement.h"
6668
#include "xla/service/platform_util.h"
6769
#include "xla/shape.h"
6870
#include "xla/shape_util.h"
@@ -1875,6 +1877,66 @@ TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) {
18751877
EXPECT_NE(layouts[1]->ToString(), "{2,1,0}");
18761878
}
18771879

1880+
// Same test as SendRecvChunked, but additionally checks that a non-zero GPU device time is recorded for the execution.
1881+
TEST(StreamExecutorGpuClientTest, NonZeroGPUDeviceTimeMeasurement) {
1882+
TF_ASSERT_OK_AND_ASSIGN(auto client,
1883+
GetStreamExecutorGpuClient(GpuClientOptions()));
1884+
1885+
TF_ASSERT_OK_AND_ASSIGN(auto executable,
1886+
CompileExecutable(kProgram, *client));
1887+
1888+
std::array<float, 2> sent_value = {0.0f, 0.0f};
1889+
1890+
// Send buffer to host: the callback copies the first two floats of the received chunk into sent_value.
1891+
SendCallback send_callback = {
1892+
/*channel_id=*/1, [&](const PjRtTransferMetadata& m, PjRtChunk chunk,
1893+
int64_t total_size_in_bytes, bool done) {
1894+
float* data = reinterpret_cast<float*>(chunk.data());
1895+
sent_value[0] = data[0];
1896+
sent_value[1] = data[1];
1897+
return absl::OkStatus();
1898+
}};
1899+
1900+
// Recv buffer from host: the callback pushes two single-float chunks (5.0f, then 6.0f) through the CopyToDeviceStream.
1901+
RecvCallback recv_callback = {
1902+
/*channel_id=*/2, [&](const PjRtTransferMetadata& m,
1903+
std::unique_ptr<CopyToDeviceStream> stream) {
1904+
auto chunk0 = PjRtChunk::AllocateDefault(sizeof(float));
1905+
*reinterpret_cast<float*>(chunk0.data()) = 5.0f;
1906+
TF_CHECK_OK(stream->AddChunk(std::move(chunk0)).Await());
1907+
1908+
auto chunk1 = PjRtChunk::AllocateDefault(sizeof(float));
1909+
*reinterpret_cast<float*>(chunk1.data()) = 6.0f;
1910+
TF_CHECK_OK(stream->AddChunk(std::move(chunk1)).Await());
1911+
1912+
return absl::OkStatus();
1913+
}};
1914+
1915+
// Callbacks for point-to-point communication ops.
1916+
std::vector<std::vector<SendCallback>> send_callbacks = {{send_callback}};
1917+
std::vector<std::vector<RecvCallback>> recv_callbacks = {{recv_callback}};
1918+
1919+
ExecuteOptions opts;
1920+
opts.send_callbacks = send_callbacks;
1921+
opts.recv_callbacks = recv_callbacks;
1922+
1923+
// Test non-zero GPU device time measurement. NOTE(review): measurement0 is created before Execute(), presumably so the execution is attributed to it — confirm.
1924+
auto measurement0 = CreateDeviceTimeMeasurement();
1925+
auto result = executable->Execute(/*argument_handles=*/{{}}, opts);
1926+
1927+
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<xla::Literal> result_literal,
1928+
ExtractSingleResult(result));
1929+
EXPECT_EQ(sent_value[0], 2.0f);
1930+
EXPECT_EQ(sent_value[1], 3.0f);
1931+
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<float>({5.0f, 6.0f}),
1932+
*result_literal));
1933+
1934+
// Check measurement after execution completes: the total duration recorded for kGpu must be strictly greater than zero.
1935+
EXPECT_GT(
1936+
measurement0->GetTotalDuration(DeviceTimeMeasurement::DeviceType::kGpu),
1937+
absl::ZeroDuration());
1938+
}
1939+
18781940
struct ShardedAutotuningTestInfo {
18791941
bool use_xla_computation;
18801942
int num_active_nodes;

xla/pjrt/pjrt_stream_executor_client.cc

+34
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ limitations under the License.
114114
#include "xla/pjrt/pjrt_compiler.h"
115115
#include "xla/pjrt/pjrt_executable.h"
116116
#include "xla/pjrt/pjrt_future.h"
117+
#include "xla/pjrt/profiling/device_time_measurement.h"
117118
#include "xla/pjrt/semaphore.h"
118119
#include "xla/pjrt/tracked_device_buffer.h"
119120
#include "xla/pjrt/transpose.h"
@@ -2843,6 +2844,18 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
28432844
device_state->compute_semaphore().ScopedAcquire(1));
28442845
}
28452846

2847+
auto start_time_ns = std::make_shared<uint64_t>();
2848+
std::optional<uint64_t> key = xla::GetDeviceTimeMeasurementKey();
2849+
// Record the start time of the execution by placing a callback on the stream
2850+
// directly before the execution. If this callback is added, another callback
2851+
// will be added directly after the execution to record the elapsed device
2852+
// time.
2853+
if (key.has_value()) {
2854+
TF_RETURN_IF_ERROR(device_state->ThenExecuteCallback(
2855+
device_state->compute_stream(), [start_time_ns]() {
2856+
*start_time_ns = tsl::Env::Default()->NowNanos();
2857+
}));
2858+
}
28462859
absl::StatusOr<ExecutionOutput> result_buffer_or_status =
28472860
executables_[executable_idx]->RunAsync(std::move(execution_inputs),
28482861
run_options);
@@ -2854,6 +2867,27 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
28542867
return result_buffer_or_status.status();
28552868
}
28562869

2870+
// Add a callback on the stream to record the elapsed device time of the
2871+
// executable execution.
2872+
//
2873+
// Do not place other callbacks between the callback recording the start time
2874+
// and this callback because their execution time will incorrectly count
2875+
// toward device execution time.
2876+
//
2877+
// This callback is only added if there is a valid key to guarantee that
2878+
// either both or none of the device time measurement callbacks are added to
2879+
// the stream, and to avoid needing a mutex.
2880+
if (key.has_value()) {
2881+
TF_RETURN_IF_ERROR(device_state->ThenExecuteCallback(
2882+
device_state->compute_stream(),
2883+
[key, start_time_ns,
2884+
device_type = GetDeviceType(client_->platform_id())]() {
2885+
auto elapsed = absl::FromUnixNanos(tsl::Env::Default()->NowNanos()) -
2886+
absl::FromUnixNanos(*start_time_ns);
2887+
xla::RecordDeviceTimeMeasurement(*key, elapsed, device_type);
2888+
}));
2889+
}
2890+
28572891
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
28582892
ExecutionOutput& execution_output = result_buffer_or_status.value();
28592893
// If we used a transient tuple for the arguments we donated its root table

xla/pjrt/profiling/BUILD

+16-11
Original file line numberDiff line numberDiff line change
@@ -28,34 +28,39 @@ exports_files(
2828
)
2929

3030
cc_library(
31-
name = "device_time_measurement",
31+
name = "no_op_device_time_measurement",
32+
srcs = [
33+
"device_time_measurement.h",
34+
"no_op_device_time_measurement.cc",
35+
"no_op_device_time_measurement.h",
36+
],
3237
# copybara:uncomment_begin(google-only)
3338
# compatible_with = ["//buildenv/target:non_prod"],
3439
# copybara:uncomment_end
35-
textual_hdrs = ["device_time_measurement.h"],
3640
deps = [
37-
# copybara:comment_begin(oss-only)
38-
":no_op_device_time_measurement",
39-
# copybara:comment_end
40-
# copybara:uncomment_begin(google-only)
41-
# "//learning/brain/google/runtime:device_runtime_profiling",
42-
# copybara:uncomment_end
41+
"//xla/pjrt:pjrt_compiler",
4342
"@com_google_absl//absl/container:flat_hash_map",
4443
"@com_google_absl//absl/synchronization",
4544
"@com_google_absl//absl/time",
4645
],
4746
)
4847

4948
cc_library(
50-
name = "no_op_device_time_measurement",
51-
srcs = ["device_time_measurement.h"],
52-
hdrs = ["no_op_device_time_measurement.h"],
49+
name = "device_time_measurement",
5350
# copybara:uncomment_begin(google-only)
5451
# compatible_with = ["//buildenv/target:non_prod"],
5552
# copybara:uncomment_end
53+
textual_hdrs = ["device_time_measurement.h"],
5654
deps = [
55+
# copybara:comment_begin(oss-only)
56+
":no_op_device_time_measurement",
57+
# copybara:comment_end
58+
# copybara:uncomment_begin(google-only)
59+
# "//learning/brain/google/runtime:device_runtime_profiling",
60+
# copybara:uncomment_end
5761
"@com_google_absl//absl/container:flat_hash_map",
5862
"@com_google_absl//absl/synchronization",
5963
"@com_google_absl//absl/time",
64+
"//xla/pjrt:pjrt_compiler",
6065
],
6166
)

xla/pjrt/profiling/device_time_measurement.h

+14
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
#include "absl/container/flat_hash_map.h"
2424
#include "absl/synchronization/mutex.h"
2525
#include "absl/time/time.h"
26+
#include "xla/pjrt/pjrt_compiler.h"
2627

2728
namespace xla {
2829

@@ -79,5 +80,18 @@ void RecordDeviceTimeMeasurement(
7980
uint64_t key, absl::Duration elapsed,
8081
xla::DeviceTimeMeasurement::DeviceType device_type);
8182

83+
// Helper function to convert PjRtPlatformId to
84+
// DeviceTimeMeasurement::DeviceType: CudaId(), RocmId() and SyclId() map to kGpu, TpuId() maps to kTpu, and every other platform id maps to kUnknown.
85+
inline DeviceTimeMeasurement::DeviceType GetDeviceType(
86+
PjRtPlatformId platform_id) {
87+
if (platform_id == CudaId() || platform_id == RocmId() ||
88+
platform_id == SyclId()) {
89+
return DeviceTimeMeasurement::DeviceType::kGpu;
90+
} else if (platform_id == TpuId()) {
91+
return DeviceTimeMeasurement::DeviceType::kTpu;
92+
}
93+
return DeviceTimeMeasurement::DeviceType::kUnknown;
94+
}
95+
8296
} // namespace xla
8397
#endif // XLA_PJRT_PROFILING_DEVICE_TIME_MEASUREMENT_H_
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/pjrt/profiling/no_op_device_time_measurement.h"
17+
18+
#include <cstdint>
19+
#include <memory>
20+
#include <optional>
21+
22+
#include "absl/time/time.h"
23+
#include "xla/pjrt/profiling/device_time_measurement.h"
24+
25+
namespace xla {
26+
27+
std::unique_ptr<DeviceTimeMeasurement> CreateDeviceTimeMeasurement() {
28+
return std::make_unique<NoOpDeviceTimeMeasurement>();  // Always hands back the no-op implementation.
29+
}
30+
31+
std::optional<uint64_t> GetDeviceTimeMeasurementKey() { return std::nullopt; }  // The no-op implementation never provides a key.
32+
33+
void RecordDeviceTimeMeasurement(
34+
uint64_t key, absl::Duration elapsed,
35+
xla::DeviceTimeMeasurement::DeviceType device_type) {}  // Empty body: all three arguments are ignored and nothing is recorded.
36+
37+
} // namespace xla

xla/pjrt/profiling/no_op_device_time_measurement.h

+4-8
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,13 @@ class NoOpDeviceTimeMeasurement : public DeviceTimeMeasurement {
5252
void Record(absl::Duration elapsed, DeviceType device_type) override {};
5353
};
5454

55-
inline std::unique_ptr<DeviceTimeMeasurement> CreateDeviceTimeMeasurement() {
56-
return std::make_unique<NoOpDeviceTimeMeasurement>();
57-
}
55+
std::unique_ptr<DeviceTimeMeasurement> CreateDeviceTimeMeasurement();
5856

59-
inline std::optional<uint64_t> GetDeviceTimeMeasurementKey() {
60-
return std::nullopt;
61-
}
57+
std::optional<uint64_t> GetDeviceTimeMeasurementKey();
6258

63-
inline void RecordDeviceTimeMeasurement(
59+
void RecordDeviceTimeMeasurement(
6460
uint64_t key, absl::Duration elapsed,
65-
xla::DeviceTimeMeasurement::DeviceType device_type) {}
61+
xla::DeviceTimeMeasurement::DeviceType device_type);
6662

6763
} // namespace xla
6864
#endif // XLA_PJRT_PROFILING_NO_OP_DEVICE_TIME_MEASUREMENT_H_

xla/pjrt/profiling/test_util/BUILD

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
load(
2+
"//xla/tsl:tsl.bzl",
3+
"internal_visibility",
4+
)
5+
6+
package(
7+
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
8+
default_visibility = internal_visibility([
9+
"//xla:internal",
10+
]),
11+
licenses = ["notice"],
12+
)
13+
14+
cc_library(
15+
name = "mock_device_time_measurement",
16+
testonly = True,  # Restricts this mock to test-only dependents.
17+
srcs = [
18+
"mock_device_time_measurement.cc",
19+
"//xla/pjrt/profiling:device_time_measurement.h",  # NOTE(review): header is listed in srcs rather than reached via a dep on //xla/pjrt/profiling:device_time_measurement, presumably to avoid also linking the no-op implementation — confirm.
20+
],
21+
hdrs = ["mock_device_time_measurement.h"],
22+
# copybara:uncomment_begin(google-only)
23+
# compatible_with = ["//buildenv/target:non_prod"],
24+
# copybara:uncomment_end
25+
deps = [
26+
"//xla/pjrt:pjrt_compiler",
27+
"@com_google_absl//absl/container:flat_hash_map",
28+
"@com_google_absl//absl/debugging:leak_check",
29+
"@com_google_absl//absl/synchronization",
30+
"@com_google_absl//absl/time",
31+
],
32+
)

0 commit comments

Comments
 (0)