@@ -47,15 +47,17 @@ limitations under the License.
47
47
#include " gloo/transport/device.h"
48
48
#include " gloo/transport/unbound_buffer.h"
49
49
#include " gloo/types.h"
50
+ #include " xla/backends/cpu/collectives/cpu_collectives.h"
50
51
#include " xla/primitive_util.h"
51
52
#include " xla/service/collective_ops_utils.h"
52
53
#include " xla/service/cpu/collectives_interface.h"
53
54
#include " xla/service/global_device_id.h"
54
55
#include " xla/status_macros.h"
56
+ #include " xla/stream_executor/device_memory.h"
57
+ #include " xla/tsl/platform/errors.h"
58
+ #include " xla/tsl/platform/statusor.h"
55
59
#include " xla/types.h"
56
60
#include " xla/xla_data.pb.h"
57
- #include " tsl/platform/errors.h"
58
- #include " tsl/platform/logging.h"
59
61
60
62
namespace xla ::cpu {
61
63
@@ -66,14 +68,16 @@ GlooCollectivesCommunicator::~GlooCollectivesCommunicator() = default;
66
68
67
69
template <typename T>
68
70
static absl::Status SetAllReduceOptions (ReductionKind reduction_kind,
69
- const void * input_buffer,
70
- void * output_buffer,
71
+ se::DeviceMemoryBase input_buffer,
72
+ se::DeviceMemoryBase output_buffer,
71
73
size_t num_elements,
72
74
gloo::AllreduceOptions& options) {
73
- options.setInput (reinterpret_cast <T*>(const_cast <void *>(input_buffer)),
74
- num_elements);
75
- options.setOutput (reinterpret_cast <T*>(const_cast <void *>(output_buffer)),
76
- num_elements);
75
+ options.setInput (
76
+ reinterpret_cast <T*>(const_cast <void *>(input_buffer.opaque ())),
77
+ num_elements);
78
+ options.setOutput (
79
+ reinterpret_cast <T*>(const_cast <void *>(output_buffer.opaque ())),
80
+ num_elements);
77
81
78
82
using ReductionFn = void (*)(void *, const void *, const void *, size_t );
79
83
@@ -105,75 +109,77 @@ static absl::Status SetAllReduceOptions(ReductionKind reduction_kind,
105
109
}
106
110
107
111
absl::Status GlooCollectivesCommunicator::AllReduce (
108
- const RendezvousKey& key, ReductionKind reduction_kind,
109
- PrimitiveType element_type, size_t num_elements, const void * input_buffer,
110
- void * output_buffer, absl::Duration timeout) {
112
+ se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
113
+ PrimitiveType dtype, size_t count, ReductionKind reduction_kind,
114
+ const Executor& executor) {
115
+ TF_ASSIGN_OR_RETURN (auto cpu_executor, CpuCollectives::TryCast (&executor));
116
+
111
117
gloo::AllreduceOptions options (context_);
112
118
// TODO(phawkins): how to do tags?
113
119
// options.setTag(tag);
114
- switch (element_type ) {
120
+ switch (dtype ) {
115
121
case S8:
116
122
TF_RETURN_IF_ERROR (SetAllReduceOptions<int8_t >(
117
- reduction_kind, input_buffer, output_buffer, num_elements , options));
123
+ reduction_kind, send_buffer, recv_buffer, count , options));
118
124
break ;
119
125
case PRED:
120
126
case U8:
121
127
TF_RETURN_IF_ERROR (SetAllReduceOptions<uint8_t >(
122
- reduction_kind, input_buffer, output_buffer, num_elements , options));
128
+ reduction_kind, send_buffer, recv_buffer, count , options));
123
129
break ;
124
130
case S16:
125
131
TF_RETURN_IF_ERROR (SetAllReduceOptions<int16_t >(
126
- reduction_kind, input_buffer, output_buffer, num_elements , options));
132
+ reduction_kind, send_buffer, recv_buffer, count , options));
127
133
break ;
128
134
case U16:
129
135
TF_RETURN_IF_ERROR (SetAllReduceOptions<uint16_t >(
130
- reduction_kind, input_buffer, output_buffer, num_elements , options));
136
+ reduction_kind, send_buffer, recv_buffer, count , options));
131
137
break ;
132
138
case S32:
133
139
TF_RETURN_IF_ERROR (SetAllReduceOptions<int32_t >(
134
- reduction_kind, input_buffer, output_buffer, num_elements , options));
140
+ reduction_kind, send_buffer, recv_buffer, count , options));
135
141
break ;
136
142
case U32:
137
143
TF_RETURN_IF_ERROR (SetAllReduceOptions<uint32_t >(
138
- reduction_kind, input_buffer, output_buffer, num_elements , options));
144
+ reduction_kind, send_buffer, recv_buffer, count , options));
139
145
break ;
140
146
case S64:
141
147
TF_RETURN_IF_ERROR (SetAllReduceOptions<int64_t >(
142
- reduction_kind, input_buffer, output_buffer, num_elements , options));
148
+ reduction_kind, send_buffer, recv_buffer, count , options));
143
149
break ;
144
150
case U64:
145
151
TF_RETURN_IF_ERROR (SetAllReduceOptions<uint64_t >(
146
- reduction_kind, input_buffer, output_buffer, num_elements , options));
152
+ reduction_kind, send_buffer, recv_buffer, count , options));
147
153
break ;
148
154
case F16:
149
155
TF_RETURN_IF_ERROR (SetAllReduceOptions<gloo::float16>(
150
- reduction_kind, input_buffer, output_buffer, num_elements , options));
156
+ reduction_kind, send_buffer, recv_buffer, count , options));
151
157
break ;
152
158
case BF16:
153
159
TF_RETURN_IF_ERROR (SetAllReduceOptions<bfloat16>(
154
- reduction_kind, input_buffer, output_buffer, num_elements , options));
160
+ reduction_kind, send_buffer, recv_buffer, count , options));
155
161
break ;
156
162
case F32:
157
163
TF_RETURN_IF_ERROR (SetAllReduceOptions<float >(
158
- reduction_kind, input_buffer, output_buffer, num_elements , options));
164
+ reduction_kind, send_buffer, recv_buffer, count , options));
159
165
break ;
160
166
case F64:
161
167
TF_RETURN_IF_ERROR (SetAllReduceOptions<double >(
162
- reduction_kind, input_buffer, output_buffer, num_elements , options));
168
+ reduction_kind, send_buffer, recv_buffer, count , options));
163
169
break ;
164
170
case C64:
165
171
TF_RETURN_IF_ERROR (SetAllReduceOptions<std::complex<float >>(
166
- reduction_kind, input_buffer, output_buffer, num_elements , options));
172
+ reduction_kind, send_buffer, recv_buffer, count , options));
167
173
break ;
168
174
case C128:
169
175
TF_RETURN_IF_ERROR (SetAllReduceOptions<std::complex<double >>(
170
- reduction_kind, input_buffer, output_buffer, num_elements , options));
176
+ reduction_kind, send_buffer, recv_buffer, count , options));
171
177
break ;
172
178
default :
173
179
return absl::InvalidArgumentError (" Unknown datatype in allreduce" );
174
180
}
175
181
options.setAlgorithm (gloo::AllreduceOptions::Algorithm::RING);
176
- options.setTimeout (absl::ToChronoMilliseconds (timeout));
182
+ options.setTimeout (absl::ToChronoMilliseconds (cpu_executor-> timeout () ));
177
183
178
184
try {
179
185
gloo::allreduce (options);
0 commit comments