@@ -28,7 +28,6 @@ limitations under the License.
28
28
#include " absl/status/statusor.h"
29
29
#include " absl/strings/str_cat.h"
30
30
#include " absl/strings/str_format.h"
31
- #include " absl/strings/string_view.h"
32
31
#include " absl/synchronization/mutex.h"
33
32
#include " absl/types/span.h"
34
33
#include " xla/pjrt/pjrt_layout.h"
@@ -48,7 +47,7 @@ limitations under the License.
48
47
// DisassembleIntoSingleDeviceArrays, Reshard, FullyReplicatedShard,
49
48
// CopyToHostBuffer and AssembleFromSingleDeviceArrays share a common pattern
50
49
// that waits for the source array(s) buffers to become ready and then copies
51
- // the data into a new array's buffer backing store . Factor out the common
50
+ // the data into a new array's buffer. Factor out the common
52
51
// pattern into a helper function.
53
52
54
53
namespace xla {
@@ -104,7 +103,7 @@ absl::StatusOr<tsl::RCReference<BasicStringArray>> BasicStringArray::Create(
104
103
auto ready_future = Future<>(ready_promise);
105
104
106
105
// Buffers when the become ready must be consistent with the sharding. For
107
- // instance, Buffers.size() (the number of per-shard spans of string_views )
106
+ // instance, Buffers.size() (the number of per-shard spans of absl::Cords )
108
107
// and the devices in the sharding that was used to create an array must
109
108
// match. If they do not, the array's ready future and buffers future should
110
109
// become ready with an appropriate error status.
@@ -216,66 +215,62 @@ BasicStringArray::DisassembleIntoSingleDeviceArrays(
216
215
// For each single device array we are going to pre-make:
217
216
// (1) a Promise-Future pair for passing the buffers,
218
217
//
219
- // (2) a Per-shard buffer backing store and the corresponding
220
- // on-done-with-buffer callback.
218
+ // (2) a Per-shard data store and the corresponding on-done-with-buffer
219
+ // callback.
221
220
//
222
221
// (3) shape and sharding by disassembing the source array's sharding.
223
222
//
224
223
// The Futures, the on-done-with-host-buffer callbacks, shapes and shardings
225
- // are used to make the arrays. The promises and the buffer backing stores
224
+ // are used to make the arrays. The promises and the per-shard stores
226
225
// are passed onto the OnReady callback that populates them when the buffers
227
226
// of the source array become ready.
228
227
std::vector<Promise<Buffers>> buffer_promises;
229
228
buffer_promises.reserve (num_shards);
230
229
std::vector<Future<Buffers>> buffer_futures;
231
230
buffer_futures.reserve (num_shards);
232
231
233
- struct PerShardBufferBackingStore { // Data (strings) for a single shard.
234
- void CopyFrom (absl::Span<const absl::string_view > input_buffer) {
232
+ struct PerShardStringStore { // Data (strings) for a single shard.
233
+ void CopyFrom (absl::Span<const absl::Cord > input_buffer) {
235
234
strings.reserve (input_buffer.size ());
236
- string_views.reserve (input_buffer.size ());
237
- for (absl::string_view buf : input_buffer) {
238
- strings.push_back (std::string (buf.data (), buf.size ()));
239
- string_views.push_back (strings.back ());
235
+ for (const auto & input_string : input_buffer) {
236
+ strings.push_back (input_string);
240
237
}
241
238
}
242
- std::vector<std::string> strings;
243
- std::vector<absl::string_view> string_views;
239
+ std::vector<absl::Cord> strings;
244
240
};
245
- std::vector<std::shared_ptr<PerShardBufferBackingStore>>
246
- per_shard_buffer_backing_stores ;
247
- per_shard_buffer_backing_stores .reserve (num_shards);
241
+
242
+ std::vector<std::shared_ptr<PerShardStringStore>> per_shard_strings ;
243
+ per_shard_strings .reserve (num_shards);
248
244
std::vector<OnDoneWithBuffer> on_done_with_buffer_callbacks;
249
245
on_done_with_buffer_callbacks.reserve (num_shards);
250
246
251
247
for (int i = 0 ; i < num_shards; ++i) {
252
248
buffer_promises.push_back (Future<Buffers>::CreatePromise ());
253
249
buffer_futures.push_back (Future<Buffers>(buffer_promises.back ()));
254
250
255
- auto backing_store = std::make_shared<PerShardBufferBackingStore >();
256
- per_shard_buffer_backing_stores .push_back (backing_store );
251
+ auto current_shard_strings = std::make_shared<PerShardStringStore >();
252
+ per_shard_strings .push_back (current_shard_strings );
257
253
on_done_with_buffer_callbacks.push_back (
258
- [backing_store = std::move (backing_store )]() {});
254
+ [data = std::move (current_shard_strings )]() {});
259
255
}
260
256
261
- // Copy each of the per-shard data into the its per-shard buffer backing
262
- // store, make a Buffers object and set the corresponding promise .
257
+ // When the buffers become ready, copy each of the per-shard data into the
258
+ // buffer of the corresponding single-device array .
263
259
buffers_.OnReady ([buffer_promises = std::move (buffer_promises),
264
- per_shard_buffer_backing_stores =
265
- std::move (per_shard_buffer_backing_stores)](
260
+ per_shard_data = std::move (per_shard_strings)](
266
261
absl::StatusOr<Buffers> buffers) mutable {
267
262
if (!buffers.ok ()) {
268
263
for (auto & promise : buffer_promises) {
269
264
promise.Set (buffers.status ());
270
265
}
271
- per_shard_buffer_backing_stores .clear ();
266
+ per_shard_data .clear ();
272
267
return ;
273
268
}
274
269
auto num_shards = buffers->size ();
275
270
for (int i = 0 ; i < num_shards; ++i) {
276
- per_shard_buffer_backing_stores [i]->CopyFrom ((*buffers)[i]);
271
+ per_shard_data [i]->CopyFrom ((*buffers)[i]);
277
272
Buffers buffers;
278
- buffers.push_back (per_shard_buffer_backing_stores [i]->string_views );
273
+ buffers.push_back (absl::MakeConstSpan (per_shard_data [i]->strings ) );
279
274
buffer_promises[i].Set (std::move (buffers));
280
275
}
281
276
});
@@ -325,29 +320,24 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Copy(
325
320
sharding_->devices ()->size ()));
326
321
}
327
322
328
- struct BufferBackingStore {
329
- void AddShardData (absl::Span<const absl::string_view > input_buffer) {
323
+ struct StringStore {
324
+ void AddShardData (absl::Span<const absl::Cord > input_buffer) {
330
325
auto & shard_strings = strings.emplace_back ();
331
326
shard_strings.reserve (input_buffer.size ());
332
327
333
- auto & shard_string_views = string_views.emplace_back ();
334
- shard_string_views.reserve (input_buffer.size ());
335
-
336
- for (absl::string_view buf : input_buffer) {
337
- shard_strings.push_back (std::string (buf.data (), buf.size ()));
338
- shard_string_views.push_back (shard_strings.back ());
328
+ for (const auto & input_string : input_buffer) {
329
+ shard_strings.push_back (input_string);
339
330
}
340
331
}
341
- std::vector<std::vector<std::string>> strings;
342
- std::vector<std::vector<absl::string_view>> string_views;
332
+ std::vector<std::vector<absl::Cord>> strings;
343
333
};
344
334
345
- auto backing_store = std::make_shared<BufferBackingStore >();
346
- auto on_done_with_buffer = [backing_store ]() {};
335
+ auto string_store = std::make_shared<StringStore >();
336
+ auto on_done_with_buffer = [string_store ]() {};
347
337
auto buffers_promise = Future<Buffers>::CreatePromise ();
348
338
auto buffers_future = Future<Buffers>(buffers_promise);
349
339
350
- auto copier = [backing_store = std::move (backing_store ),
340
+ auto copier = [string_store = std::move (string_store ),
351
341
buffers_promise = std::move (buffers_promise)](
352
342
absl::StatusOr<Buffers> input_buffers) mutable {
353
343
if (!input_buffers.ok ()) {
@@ -357,8 +347,8 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Copy(
357
347
Buffers buffers;
358
348
buffers.reserve (input_buffers->size ());
359
349
for (auto & input_buffer : *input_buffers) {
360
- backing_store ->AddShardData (input_buffer);
361
- buffers.push_back (backing_store-> string_views .back ());
350
+ string_store ->AddShardData (input_buffer);
351
+ buffers.push_back (string_store-> strings .back ());
362
352
}
363
353
buffers_promise.Set (std::move (buffers));
364
354
};
@@ -384,25 +374,22 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::FullyReplicatedShard(
384
374
if (!sharding_->IsFullyReplicated ()) {
385
375
return absl::FailedPreconditionError (" This array is not fully replicated" );
386
376
}
387
- struct BufferBackingStore { // Data (strings) for a single shard.
388
- void CopyFrom (absl::Span<const absl::string_view > input_buffer) {
377
+ struct StringStore { // Data (strings) for a single shard.
378
+ void CopyFrom (absl::Span<const absl::Cord > input_buffer) {
389
379
strings.reserve (input_buffer.size ());
390
- string_views.reserve (input_buffer.size ());
391
- for (absl::string_view buf : input_buffer) {
392
- strings.push_back (std::string (buf.data (), buf.size ()));
393
- string_views.push_back (strings.back ());
380
+ for (const auto & input_strings : input_buffer) {
381
+ strings.push_back (input_strings);
394
382
}
395
383
}
396
- std::vector<std::string> strings;
397
- std::vector<absl::string_view> string_views;
384
+ std::vector<absl::Cord> strings;
398
385
};
399
386
400
- auto backing_store = std::make_shared<BufferBackingStore >();
401
- auto on_done_with_buffer = [backing_store ]() {};
387
+ auto string_store = std::make_shared<StringStore >();
388
+ auto on_done_with_buffer = [string_store ]() {};
402
389
auto buffers_promise = Future<Buffers>::CreatePromise ();
403
390
auto buffers_future = Future<Buffers>(buffers_promise);
404
391
405
- auto copier = [backing_store = std::move (backing_store ),
392
+ auto copier = [string_store = std::move (string_store ),
406
393
buffers_promise = std::move (buffers_promise)](
407
394
absl::StatusOr<Buffers> input_buffers) mutable {
408
395
if (!input_buffers.ok ()) {
@@ -414,10 +401,10 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::FullyReplicatedShard(
414
401
// were run when the source array's buffers became ready would have
415
402
// ensured that the input_buffers have at least one shard's worth of data.
416
403
auto & input_buffer = (*input_buffers)[0 ];
417
- backing_store ->CopyFrom (input_buffer);
404
+ string_store ->CopyFrom (input_buffer);
418
405
419
406
Buffers buffers;
420
- buffers.push_back (backing_store-> string_views );
407
+ buffers.push_back (string_store-> strings );
421
408
buffers_promise.Set (std::move (buffers));
422
409
};
423
410
buffers_.OnReady (std::move (copier));
0 commit comments