Skip to content

Commit 845451d

Browse files
ezhulenevtensorflower-gardener
authored andcommitted
[xla:cpu] Migrate AllReduce to unified collectives API
PiperOrigin-RevId: 711530846
1 parent 95de515 commit 845451d

16 files changed

+197
-88
lines changed

third_party/xla/xla/backends/cpu/collectives/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ cc_library(
2525
"//xla/core/collectives",
2626
"//xla/core/collectives:collectives_registry",
2727
"//xla/core/collectives:communicator",
28+
"//xla/service:collective_ops_utils",
2829
"@com_google_absl//absl/log",
2930
"@com_google_absl//absl/log:check",
3031
"@com_google_absl//absl/status",
3132
"@com_google_absl//absl/status:statusor",
33+
"@com_google_absl//absl/time",
3234
"@local_tsl//tsl/platform:casts",
3335
],
3436
)

third_party/xla/xla/backends/cpu/collectives/cpu_collectives.cc

+24
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ limitations under the License.
1818
#include "absl/log/check.h"
1919
#include "absl/log/log.h"
2020
#include "absl/status/statusor.h"
21+
#include "absl/time/time.h"
2122
#include "xla/core/collectives/collectives.h"
2223
#include "xla/core/collectives/collectives_registry.h"
24+
#include "xla/core/collectives/communicator.h"
25+
#include "xla/service/collective_ops_utils.h"
26+
#include "xla/util.h"
2327
#include "tsl/platform/casts.h"
2428

2529
namespace xla::cpu {
@@ -36,4 +40,24 @@ CpuCollectives* CpuCollectives::Default() {
3640
LOG(FATAL) << "Unsupported collectives implementation for CPU";
3741
}
3842

43+
absl::StatusOr<const CpuCollectives::Device*> CpuCollectives::TryCast(
44+
const Collectives::Device* device) {
45+
if (auto* cpu_device = tsl::down_cast<const Device*>(device)) {
46+
return cpu_device;
47+
}
48+
return InvalidArgument("Collectives device is not a CPU device");
49+
}
50+
51+
absl::StatusOr<const CpuCollectives::Executor*> CpuCollectives::TryCast(
52+
const Communicator::Executor* executor) {
53+
if (auto* cpu_executor = tsl::down_cast<const Executor*>(executor)) {
54+
return cpu_executor;
55+
}
56+
return InvalidArgument("Collectives executor is not a CPU executor");
57+
}
58+
59+
CpuCollectives::Executor::Executor(RendezvousKey rendezvous_key,
60+
absl::Duration timeout)
61+
: rendezvous_key_(rendezvous_key), timeout_(timeout) {}
62+
3963
} // namespace xla::cpu

third_party/xla/xla/backends/cpu/collectives/cpu_collectives.h

+21-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ limitations under the License.
1616
#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_
1717
#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_
1818

19+
#include "absl/status/statusor.h"
20+
#include "absl/time/time.h"
1921
#include "xla/core/collectives/collectives.h"
2022
#include "xla/core/collectives/communicator.h"
23+
#include "xla/service/collective_ops_utils.h"
2124
#include "xla/xla_data.pb.h"
2225

2326
namespace xla::cpu {
@@ -33,10 +36,27 @@ class CpuCollectives : public Collectives {
3336
Device() = default;
3437
};
3538

39+
// Executor allows CPU collectives clients to pass additional information to
40+
// the collectives implementation.
3641
class Executor : public Communicator::Executor {
3742
public:
38-
Executor() = default;
43+
Executor(RendezvousKey rendezvous_key, absl::Duration timeout);
44+
45+
const RendezvousKey& rendezvous_key() const { return rendezvous_key_; }
46+
const absl::Duration& timeout() const { return timeout_; }
47+
48+
private:
49+
RendezvousKey rendezvous_key_;
50+
absl::Duration timeout_;
3951
};
52+
53+
// Tries to cast a Collectives::Device to a CpuCollectives::Device.
54+
static absl::StatusOr<const Device*> TryCast(
55+
const Collectives::Device* device);
56+
57+
// Tries to cast a Communicator::Executor to a CpuCollectives::Executor.
58+
static absl::StatusOr<const Executor*> TryCast(
59+
const Communicator::Executor* executor);
4060
};
4161

4262
} // namespace xla::cpu

third_party/xla/xla/backends/cpu/runtime/BUILD

+4
Original file line numberDiff line numberDiff line change
@@ -458,13 +458,17 @@ cc_library(
458458
"//xla:status_macros",
459459
"//xla:util",
460460
"//xla:xla_data_proto_cc",
461+
"//xla/backends/cpu/collectives:cpu_collectives",
461462
"//xla/runtime:buffer_use",
462463
"//xla/service:buffer_assignment",
463464
"//xla/service:collective_ops_utils",
464465
"//xla/service/cpu:collectives_interface",
465466
"//xla/tsl/concurrency:async_value",
467+
"//xla/tsl/platform:errors",
466468
"@com_google_absl//absl/algorithm:container",
467469
"@com_google_absl//absl/container:inlined_vector",
470+
"@com_google_absl//absl/log",
471+
"@com_google_absl//absl/log:check",
468472
"@com_google_absl//absl/memory",
469473
"@com_google_absl//absl/status",
470474
"@com_google_absl//absl/status:statusor",

third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc

+7-6
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ limitations under the License.
2121
#include <utility>
2222

2323
#include "absl/container/inlined_vector.h"
24+
#include "absl/log/check.h"
25+
#include "absl/log/log.h"
2426
#include "absl/memory/memory.h"
2527
#include "absl/status/status.h"
2628
#include "absl/status/statusor.h"
2729
#include "absl/strings/str_format.h"
28-
#include "absl/types/span.h"
30+
#include "xla/backends/cpu/collectives/cpu_collectives.h"
2931
#include "xla/backends/cpu/runtime/collective_thunk.h"
3032
#include "xla/backends/cpu/runtime/thunk.h"
3133
#include "xla/primitive_util.h"
@@ -35,9 +37,8 @@ limitations under the License.
3537
#include "xla/shape.h"
3638
#include "xla/shape_util.h"
3739
#include "xla/tsl/concurrency/async_value_ref.h"
40+
#include "xla/tsl/platform/errors.h"
3841
#include "xla/util.h"
39-
#include "tsl/platform/errors.h"
40-
#include "tsl/platform/logging.h"
4142
#include "tsl/platform/statusor.h"
4243
#include "tsl/profiler/lib/traceme.h"
4344

@@ -102,12 +103,12 @@ tsl::AsyncValueRef<AllReduceThunk::ExecuteEvent> AllReduceThunk::Execute(
102103
return ExecuteWithCommunicator(
103104
params.collective_params,
104105
[&](const RendezvousKey& key, CollectivesCommunicator& comm) {
106+
CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());
105107
for (int32_t i = 0; i < data.source.size(); ++i) {
106108
const Shape& shape = destination_shape(i);
107109
TF_RETURN_IF_ERROR(comm.AllReduce(
108-
key, reduction_kind_, shape.element_type(),
109-
ShapeUtil::ElementsIn(shape), data.source[i].opaque(),
110-
data.destination[i].opaque(), DefaultCollectiveTimeout()));
110+
data.source[i], data.destination[i], shape.element_type(),
111+
ShapeUtil::ElementsIn(shape), reduction_kind_, executor));
111112
}
112113
return absl::OkStatus();
113114
});

third_party/xla/xla/pjrt/cpu/BUILD

+12-6
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,13 @@ cc_library(
299299
"//xla:status_macros",
300300
"//xla:types",
301301
"//xla:xla_data_proto_cc",
302+
"//xla/backends/cpu/collectives:cpu_collectives",
302303
"//xla/service:collective_ops_utils",
303304
"//xla/service:global_device_id",
304305
"//xla/service/cpu:collectives_interface",
306+
"//xla/stream_executor:device_memory",
307+
"//xla/tsl/platform:errors",
308+
"//xla/tsl/platform:statusor",
305309
"@com_google_absl//absl/base:core_headers",
306310
"@com_google_absl//absl/container:flat_hash_map",
307311
"@com_google_absl//absl/status",
@@ -325,21 +329,23 @@ xla_cc_test(
325329
":gloo_kv_store",
326330
"//xla:executable_run_options",
327331
"//xla:xla_data_proto_cc",
332+
"//xla/backends/cpu/collectives:cpu_collectives",
328333
"//xla/pjrt/distributed:in_memory_key_value_store",
329334
"//xla/pjrt/distributed:key_value_store_interface",
330335
"//xla/service:collective_ops_utils",
331336
"//xla/service:global_device_id",
332337
"//xla/service/cpu:collectives_interface",
338+
"//xla/stream_executor:device_memory",
333339
"//xla/tsl/lib/core:status_test_util",
340+
"//xla/tsl/platform:env",
341+
"//xla/tsl/platform:errors",
342+
"//xla/tsl/platform:statusor",
343+
"//xla/tsl/platform:test",
344+
"//xla/tsl/platform:test_benchmark",
345+
"//xla/tsl/platform:test_main",
334346
"@com_google_absl//absl/status:statusor",
335347
"@com_google_absl//absl/time",
336348
"@com_google_absl//absl/types:span",
337-
"@local_tsl//tsl/platform:env",
338-
"@local_tsl//tsl/platform:errors",
339-
"@local_tsl//tsl/platform:statusor",
340-
"@local_tsl//tsl/platform:test",
341-
"@local_tsl//tsl/platform:test_benchmark",
342-
"@local_tsl//tsl/platform:test_main",
343349
] + select({
344350
# Gloo's transport_tcp is not available on MacOS
345351
"//xla/tsl:macos": [

third_party/xla/xla/pjrt/cpu/gloo_collectives.cc

+33-27
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,17 @@ limitations under the License.
4747
#include "gloo/transport/device.h"
4848
#include "gloo/transport/unbound_buffer.h"
4949
#include "gloo/types.h"
50+
#include "xla/backends/cpu/collectives/cpu_collectives.h"
5051
#include "xla/primitive_util.h"
5152
#include "xla/service/collective_ops_utils.h"
5253
#include "xla/service/cpu/collectives_interface.h"
5354
#include "xla/service/global_device_id.h"
5455
#include "xla/status_macros.h"
56+
#include "xla/stream_executor/device_memory.h"
57+
#include "xla/tsl/platform/errors.h"
58+
#include "xla/tsl/platform/statusor.h"
5559
#include "xla/types.h"
5660
#include "xla/xla_data.pb.h"
57-
#include "tsl/platform/errors.h"
58-
#include "tsl/platform/logging.h"
5961

6062
namespace xla::cpu {
6163

@@ -66,14 +68,16 @@ GlooCollectivesCommunicator::~GlooCollectivesCommunicator() = default;
6668

6769
template <typename T>
6870
static absl::Status SetAllReduceOptions(ReductionKind reduction_kind,
69-
const void* input_buffer,
70-
void* output_buffer,
71+
se::DeviceMemoryBase input_buffer,
72+
se::DeviceMemoryBase output_buffer,
7173
size_t num_elements,
7274
gloo::AllreduceOptions& options) {
73-
options.setInput(reinterpret_cast<T*>(const_cast<void*>(input_buffer)),
74-
num_elements);
75-
options.setOutput(reinterpret_cast<T*>(const_cast<void*>(output_buffer)),
76-
num_elements);
75+
options.setInput(
76+
reinterpret_cast<T*>(const_cast<void*>(input_buffer.opaque())),
77+
num_elements);
78+
options.setOutput(
79+
reinterpret_cast<T*>(const_cast<void*>(output_buffer.opaque())),
80+
num_elements);
7781

7882
using ReductionFn = void (*)(void*, const void*, const void*, size_t);
7983

@@ -105,75 +109,77 @@ static absl::Status SetAllReduceOptions(ReductionKind reduction_kind,
105109
}
106110

107111
absl::Status GlooCollectivesCommunicator::AllReduce(
108-
const RendezvousKey& key, ReductionKind reduction_kind,
109-
PrimitiveType element_type, size_t num_elements, const void* input_buffer,
110-
void* output_buffer, absl::Duration timeout) {
112+
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
113+
PrimitiveType dtype, size_t count, ReductionKind reduction_kind,
114+
const Executor& executor) {
115+
TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
116+
111117
gloo::AllreduceOptions options(context_);
112118
// TODO(phawkins): how to do tags?
113119
// options.setTag(tag);
114-
switch (element_type) {
120+
switch (dtype) {
115121
case S8:
116122
TF_RETURN_IF_ERROR(SetAllReduceOptions<int8_t>(
117-
reduction_kind, input_buffer, output_buffer, num_elements, options));
123+
reduction_kind, send_buffer, recv_buffer, count, options));
118124
break;
119125
case PRED:
120126
case U8:
121127
TF_RETURN_IF_ERROR(SetAllReduceOptions<uint8_t>(
122-
reduction_kind, input_buffer, output_buffer, num_elements, options));
128+
reduction_kind, send_buffer, recv_buffer, count, options));
123129
break;
124130
case S16:
125131
TF_RETURN_IF_ERROR(SetAllReduceOptions<int16_t>(
126-
reduction_kind, input_buffer, output_buffer, num_elements, options));
132+
reduction_kind, send_buffer, recv_buffer, count, options));
127133
break;
128134
case U16:
129135
TF_RETURN_IF_ERROR(SetAllReduceOptions<uint16_t>(
130-
reduction_kind, input_buffer, output_buffer, num_elements, options));
136+
reduction_kind, send_buffer, recv_buffer, count, options));
131137
break;
132138
case S32:
133139
TF_RETURN_IF_ERROR(SetAllReduceOptions<int32_t>(
134-
reduction_kind, input_buffer, output_buffer, num_elements, options));
140+
reduction_kind, send_buffer, recv_buffer, count, options));
135141
break;
136142
case U32:
137143
TF_RETURN_IF_ERROR(SetAllReduceOptions<uint32_t>(
138-
reduction_kind, input_buffer, output_buffer, num_elements, options));
144+
reduction_kind, send_buffer, recv_buffer, count, options));
139145
break;
140146
case S64:
141147
TF_RETURN_IF_ERROR(SetAllReduceOptions<int64_t>(
142-
reduction_kind, input_buffer, output_buffer, num_elements, options));
148+
reduction_kind, send_buffer, recv_buffer, count, options));
143149
break;
144150
case U64:
145151
TF_RETURN_IF_ERROR(SetAllReduceOptions<uint64_t>(
146-
reduction_kind, input_buffer, output_buffer, num_elements, options));
152+
reduction_kind, send_buffer, recv_buffer, count, options));
147153
break;
148154
case F16:
149155
TF_RETURN_IF_ERROR(SetAllReduceOptions<gloo::float16>(
150-
reduction_kind, input_buffer, output_buffer, num_elements, options));
156+
reduction_kind, send_buffer, recv_buffer, count, options));
151157
break;
152158
case BF16:
153159
TF_RETURN_IF_ERROR(SetAllReduceOptions<bfloat16>(
154-
reduction_kind, input_buffer, output_buffer, num_elements, options));
160+
reduction_kind, send_buffer, recv_buffer, count, options));
155161
break;
156162
case F32:
157163
TF_RETURN_IF_ERROR(SetAllReduceOptions<float>(
158-
reduction_kind, input_buffer, output_buffer, num_elements, options));
164+
reduction_kind, send_buffer, recv_buffer, count, options));
159165
break;
160166
case F64:
161167
TF_RETURN_IF_ERROR(SetAllReduceOptions<double>(
162-
reduction_kind, input_buffer, output_buffer, num_elements, options));
168+
reduction_kind, send_buffer, recv_buffer, count, options));
163169
break;
164170
case C64:
165171
TF_RETURN_IF_ERROR(SetAllReduceOptions<std::complex<float>>(
166-
reduction_kind, input_buffer, output_buffer, num_elements, options));
172+
reduction_kind, send_buffer, recv_buffer, count, options));
167173
break;
168174
case C128:
169175
TF_RETURN_IF_ERROR(SetAllReduceOptions<std::complex<double>>(
170-
reduction_kind, input_buffer, output_buffer, num_elements, options));
176+
reduction_kind, send_buffer, recv_buffer, count, options));
171177
break;
172178
default:
173179
return absl::InvalidArgumentError("Unknown datatype in allreduce");
174180
}
175181
options.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING);
176-
options.setTimeout(absl::ToChronoMilliseconds(timeout));
182+
options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout()));
177183

178184
try {
179185
gloo::allreduce(options);

third_party/xla/xla/pjrt/cpu/gloo_collectives.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ class GlooCollectivesCommunicator : public CollectivesCommunicator {
4444
explicit GlooCollectivesCommunicator(std::shared_ptr<gloo::Context> context);
4545
~GlooCollectivesCommunicator() override;
4646

47-
absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind,
48-
PrimitiveType element_type, size_t num_elements,
49-
const void* input_buffer, void* output_buffer,
50-
absl::Duration timeout) override;
47+
absl::Status AllReduce(se::DeviceMemoryBase send_buffer,
48+
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
49+
size_t count, ReductionKind reduction_kind,
50+
const Executor& executor) override;
5151
absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes,
5252
std::optional<int> source_rank,
5353
absl::Span<int const> target_ranks,

0 commit comments

Comments
 (0)