Skip to content

Commit 31e7e36

Browse files
Updates the BasicStringArray class to use absl::Cord as the element type.
Before this change, it was using absl::string_view and switching to absl::Cord allows both the IFRT client and its users more flexibility as well as opportunities for optimizations by allowing the strings to be: either included inline, or be readonly views of existing strings (say, in a numpy array or a tensor). PiperOrigin-RevId: 689556652
1 parent 6c0ce17 commit 31e7e36

File tree

6 files changed

+89
-116
lines changed

6 files changed

+89
-116
lines changed

xla/python/pjrt_ifrt/BUILD

+2-1
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ cc_library(
332332
"@com_google_absl//absl/status",
333333
"@com_google_absl//absl/status:statusor",
334334
"@com_google_absl//absl/strings",
335+
"@com_google_absl//absl/strings:cord",
335336
"@com_google_absl//absl/strings:str_format",
336-
"@com_google_absl//absl/strings:string_view",
337337
"@com_google_absl//absl/synchronization",
338338
"@com_google_absl//absl/types:span",
339339
"@llvm-project//llvm:Support",
@@ -357,6 +357,7 @@ xla_cc_test(
357357
"@com_google_absl//absl/log",
358358
"@com_google_absl//absl/status",
359359
"@com_google_absl//absl/strings",
360+
"@com_google_absl//absl/strings:cord",
360361
"@com_google_absl//absl/strings:string_view",
361362
"@com_google_absl//absl/synchronization",
362363
"@com_google_absl//absl/types:span",

xla/python/pjrt_ifrt/basic_string_array.cc

+42-55
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ limitations under the License.
2828
#include "absl/status/statusor.h"
2929
#include "absl/strings/str_cat.h"
3030
#include "absl/strings/str_format.h"
31-
#include "absl/strings/string_view.h"
3231
#include "absl/synchronization/mutex.h"
3332
#include "absl/types/span.h"
3433
#include "xla/pjrt/pjrt_layout.h"
@@ -48,7 +47,7 @@ limitations under the License.
4847
// DisassembleIntoSingleDeviceArrays, Reshard, FullyReplicatedShard,
4948
// CopyToHostBuffer and AssembleFromSingleDeviceArrays share a common pattern
5049
// that waits for the source array(s) buffers to become ready and then copies
51-
// the data into a new array's buffer backing store. Factor out the common
50+
// the data into a new array's buffer. Factor out the common
5251
// pattern into a helper function.
5352

5453
namespace xla {
@@ -104,7 +103,7 @@ absl::StatusOr<tsl::RCReference<BasicStringArray>> BasicStringArray::Create(
104103
auto ready_future = Future<>(ready_promise);
105104

106105
// Buffers when the become ready must be consistent with the sharding. For
107-
// instance, Buffers.size() (the number of per-shard spans of string_views)
106+
// instance, Buffers.size() (the number of per-shard spans of absl::Cords)
108107
// and the devices in the sharding that was used to create an array must
109108
// match. If they do not, the array's ready future and buffers future should
110109
// become ready with an appropriate error status.
@@ -216,66 +215,62 @@ BasicStringArray::DisassembleIntoSingleDeviceArrays(
216215
// For each single device array we are going to pre-make:
217216
// (1) a Promise-Future pair for passing the buffers,
218217
//
219-
// (2) a Per-shard buffer backing store and the corresponding
220-
// on-done-with-buffer callback.
218+
// (2) a Per-shard data store and the corresponding on-done-with-buffer
219+
// callback.
221220
//
222221
// (3) shape and sharding by disassembing the source array's sharding.
223222
//
224223
// The Futures, the on-done-with-host-buffer callbacks, shapes and shardings
225-
// are used to make the arrays. The promises and the buffer backing stores
224+
// are used to make the arrays. The promises and the per-shard stores
226225
// are passed onto the OnReady callback that populates them when the buffers
227226
// of the source array become ready.
228227
std::vector<Promise<Buffers>> buffer_promises;
229228
buffer_promises.reserve(num_shards);
230229
std::vector<Future<Buffers>> buffer_futures;
231230
buffer_futures.reserve(num_shards);
232231

233-
struct PerShardBufferBackingStore { // Data (strings) for a single shard.
234-
void CopyFrom(absl::Span<const absl::string_view> input_buffer) {
232+
struct PerShardStringStore { // Data (strings) for a single shard.
233+
void CopyFrom(absl::Span<const absl::Cord> input_buffer) {
235234
strings.reserve(input_buffer.size());
236-
string_views.reserve(input_buffer.size());
237-
for (absl::string_view buf : input_buffer) {
238-
strings.push_back(std::string(buf.data(), buf.size()));
239-
string_views.push_back(strings.back());
235+
for (const auto& input_string : input_buffer) {
236+
strings.push_back(input_string);
240237
}
241238
}
242-
std::vector<std::string> strings;
243-
std::vector<absl::string_view> string_views;
239+
std::vector<absl::Cord> strings;
244240
};
245-
std::vector<std::shared_ptr<PerShardBufferBackingStore>>
246-
per_shard_buffer_backing_stores;
247-
per_shard_buffer_backing_stores.reserve(num_shards);
241+
242+
std::vector<std::shared_ptr<PerShardStringStore>> per_shard_strings;
243+
per_shard_strings.reserve(num_shards);
248244
std::vector<OnDoneWithBuffer> on_done_with_buffer_callbacks;
249245
on_done_with_buffer_callbacks.reserve(num_shards);
250246

251247
for (int i = 0; i < num_shards; ++i) {
252248
buffer_promises.push_back(Future<Buffers>::CreatePromise());
253249
buffer_futures.push_back(Future<Buffers>(buffer_promises.back()));
254250

255-
auto backing_store = std::make_shared<PerShardBufferBackingStore>();
256-
per_shard_buffer_backing_stores.push_back(backing_store);
251+
auto current_shard_strings = std::make_shared<PerShardStringStore>();
252+
per_shard_strings.push_back(current_shard_strings);
257253
on_done_with_buffer_callbacks.push_back(
258-
[backing_store = std::move(backing_store)]() {});
254+
[data = std::move(current_shard_strings)]() {});
259255
}
260256

261-
// Copy each of the per-shard data into the its per-shard buffer backing
262-
// store, make a Buffers object and set the corresponding promise.
257+
// When the buffers become ready, copy each of the per-shard data into the
258+
// buffer of the corresponding single-device array.
263259
buffers_.OnReady([buffer_promises = std::move(buffer_promises),
264-
per_shard_buffer_backing_stores =
265-
std::move(per_shard_buffer_backing_stores)](
260+
per_shard_data = std::move(per_shard_strings)](
266261
absl::StatusOr<Buffers> buffers) mutable {
267262
if (!buffers.ok()) {
268263
for (auto& promise : buffer_promises) {
269264
promise.Set(buffers.status());
270265
}
271-
per_shard_buffer_backing_stores.clear();
266+
per_shard_data.clear();
272267
return;
273268
}
274269
auto num_shards = buffers->size();
275270
for (int i = 0; i < num_shards; ++i) {
276-
per_shard_buffer_backing_stores[i]->CopyFrom((*buffers)[i]);
271+
per_shard_data[i]->CopyFrom((*buffers)[i]);
277272
Buffers buffers;
278-
buffers.push_back(per_shard_buffer_backing_stores[i]->string_views);
273+
buffers.push_back(absl::MakeConstSpan(per_shard_data[i]->strings));
279274
buffer_promises[i].Set(std::move(buffers));
280275
}
281276
});
@@ -325,29 +320,24 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Copy(
325320
sharding_->devices()->size()));
326321
}
327322

328-
struct BufferBackingStore {
329-
void AddShardData(absl::Span<const absl::string_view> input_buffer) {
323+
struct StringStore {
324+
void AddShardData(absl::Span<const absl::Cord> input_buffer) {
330325
auto& shard_strings = strings.emplace_back();
331326
shard_strings.reserve(input_buffer.size());
332327

333-
auto& shard_string_views = string_views.emplace_back();
334-
shard_string_views.reserve(input_buffer.size());
335-
336-
for (absl::string_view buf : input_buffer) {
337-
shard_strings.push_back(std::string(buf.data(), buf.size()));
338-
shard_string_views.push_back(shard_strings.back());
328+
for (const auto& input_string : input_buffer) {
329+
shard_strings.push_back(input_string);
339330
}
340331
}
341-
std::vector<std::vector<std::string>> strings;
342-
std::vector<std::vector<absl::string_view>> string_views;
332+
std::vector<std::vector<absl::Cord>> strings;
343333
};
344334

345-
auto backing_store = std::make_shared<BufferBackingStore>();
346-
auto on_done_with_buffer = [backing_store]() {};
335+
auto string_store = std::make_shared<StringStore>();
336+
auto on_done_with_buffer = [string_store]() {};
347337
auto buffers_promise = Future<Buffers>::CreatePromise();
348338
auto buffers_future = Future<Buffers>(buffers_promise);
349339

350-
auto copier = [backing_store = std::move(backing_store),
340+
auto copier = [string_store = std::move(string_store),
351341
buffers_promise = std::move(buffers_promise)](
352342
absl::StatusOr<Buffers> input_buffers) mutable {
353343
if (!input_buffers.ok()) {
@@ -357,8 +347,8 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Copy(
357347
Buffers buffers;
358348
buffers.reserve(input_buffers->size());
359349
for (auto& input_buffer : *input_buffers) {
360-
backing_store->AddShardData(input_buffer);
361-
buffers.push_back(backing_store->string_views.back());
350+
string_store->AddShardData(input_buffer);
351+
buffers.push_back(string_store->strings.back());
362352
}
363353
buffers_promise.Set(std::move(buffers));
364354
};
@@ -384,25 +374,22 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::FullyReplicatedShard(
384374
if (!sharding_->IsFullyReplicated()) {
385375
return absl::FailedPreconditionError("This array is not fully replicated");
386376
}
387-
struct BufferBackingStore { // Data (strings) for a single shard.
388-
void CopyFrom(absl::Span<const absl::string_view> input_buffer) {
377+
struct StringStore { // Data (strings) for a single shard.
378+
void CopyFrom(absl::Span<const absl::Cord> input_buffer) {
389379
strings.reserve(input_buffer.size());
390-
string_views.reserve(input_buffer.size());
391-
for (absl::string_view buf : input_buffer) {
392-
strings.push_back(std::string(buf.data(), buf.size()));
393-
string_views.push_back(strings.back());
380+
for (const auto& input_strings : input_buffer) {
381+
strings.push_back(input_strings);
394382
}
395383
}
396-
std::vector<std::string> strings;
397-
std::vector<absl::string_view> string_views;
384+
std::vector<absl::Cord> strings;
398385
};
399386

400-
auto backing_store = std::make_shared<BufferBackingStore>();
401-
auto on_done_with_buffer = [backing_store]() {};
387+
auto string_store = std::make_shared<StringStore>();
388+
auto on_done_with_buffer = [string_store]() {};
402389
auto buffers_promise = Future<Buffers>::CreatePromise();
403390
auto buffers_future = Future<Buffers>(buffers_promise);
404391

405-
auto copier = [backing_store = std::move(backing_store),
392+
auto copier = [string_store = std::move(string_store),
406393
buffers_promise = std::move(buffers_promise)](
407394
absl::StatusOr<Buffers> input_buffers) mutable {
408395
if (!input_buffers.ok()) {
@@ -414,10 +401,10 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::FullyReplicatedShard(
414401
// were run when the source array's buffers became ready would have
415402
// ensured that the input_buffers have at least one shard's worth of data.
416403
auto& input_buffer = (*input_buffers)[0];
417-
backing_store->CopyFrom(input_buffer);
404+
string_store->CopyFrom(input_buffer);
418405

419406
Buffers buffers;
420-
buffers.push_back(backing_store->string_views);
407+
buffers.push_back(string_store->strings);
421408
buffers_promise.Set(std::move(buffers));
422409
};
423410
buffers_.OnReady(std::move(copier));

xla/python/pjrt_ifrt/basic_string_array.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ limitations under the License.
2828
#include "absl/container/inlined_vector.h"
2929
#include "absl/hash/hash.h"
3030
#include "absl/log/check.h"
31-
#include "absl/strings/string_view.h"
31+
#include "absl/strings/cord.h"
3232
#include "absl/synchronization/mutex.h"
3333
#include "absl/types/span.h"
3434
#include "llvm/Support/ExtensibleRTTI.h"
@@ -71,7 +71,7 @@ class BasicStringArray final
7171
: public llvm::RTTIExtends<BasicStringArray, Array> {
7272
public:
7373
// Must be in dense major to minor order.
74-
using Buffer = absl::Span<const absl::string_view>;
74+
using Buffer = absl::Span<const absl::Cord>;
7575

7676
// One Buffer per shard.
7777
static constexpr int kBuffersInlineSize = 1;
@@ -82,7 +82,7 @@ class BasicStringArray final
8282
using OnDoneWithBuffer = std::function<void()>;
8383

8484
// General array construction. The `buffers` and their elements
85-
// (absl::string_views) must live until the `on_done_with_buffer` is called.
85+
// (absl::Cords) must live until the `on_done_with_buffer` is called.
8686
// The number and order of buffers must match the number and order of devices
8787
// in `sharding`.
8888
static absl::StatusOr<tsl::RCReference<BasicStringArray>> Create(

xla/python/pjrt_ifrt/basic_string_array_test.cc

+29-34
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include <gtest/gtest.h>
2828
#include "absl/log/log.h"
2929
#include "absl/status/status.h"
30+
#include "absl/strings/cord.h"
3031
#include "absl/strings/str_cat.h"
3132
#include "absl/strings/string_view.h"
3233
#include "absl/synchronization/notification.h"
@@ -84,21 +85,15 @@ std::pair<BasicStringArray::Buffers, BasicStringArray::OnDoneWithBuffer>
8485
MakeBuffersAndOnDoneWithBuffer(
8586
absl::Span<const absl::string_view> input_strings) {
8687
BasicStringArray::Buffers buffers;
87-
auto string_holder = std::make_shared<std::vector<std::string>>();
88-
string_holder->reserve(input_strings.size());
89-
auto string_view_holder = std::make_shared<std::vector<absl::string_view>>();
90-
string_view_holder->reserve(input_strings.size());
91-
for (const auto str : input_strings) {
92-
string_holder->push_back(std::string(str));
88+
auto strings = std::make_shared<std::vector<absl::Cord>>();
89+
strings->reserve(input_strings.size());
90+
for (const auto input_str : input_strings) {
91+
strings->push_back(absl::Cord(input_str));
9392
}
94-
for (const auto& str : *string_holder) {
95-
string_view_holder->push_back(absl::string_view(str));
96-
}
97-
buffers.push_back(*string_view_holder);
93+
buffers.push_back(*strings);
9894

9995
BasicStringArray::OnDoneWithBuffer on_done_with_buffer =
100-
[string_holder = std::move(string_holder),
101-
string_view_holder = std::move(string_view_holder)]() {};
96+
[strings = std::move(strings)]() {};
10297

10398
return std::make_pair(std::move(buffers), std::move(on_done_with_buffer));
10499
}
@@ -175,7 +170,7 @@ TEST(BasicStringArrayLayoutTest, Equality) {
175170
TEST(BasicStringArrayTest, CreateSuccess) {
176171
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
177172
BasicStringArray::Buffers buffers;
178-
buffers.push_back({"abc", "def"});
173+
buffers.push_back({absl::Cord("abc"), absl::Cord("def")});
179174

180175
// This test implicitly tests that the on_done_with_buffer can be a nullptr,
181176
// and that the destruction of the BasicStringArray object completes
@@ -197,7 +192,7 @@ TEST(BasicStringArrayTest, Destruction) {
197192
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
198193

199194
BasicStringArray::Buffers buffers;
200-
buffers.push_back({"abc", "def"});
195+
buffers.push_back({absl::Cord("abc"), absl::Cord("def")});
201196

202197
absl::Notification on_done_with_buffer_called;
203198
BasicStringArray::OnDoneWithBuffer on_done_with_buffer =
@@ -228,10 +223,10 @@ TEST(BasicStringArrayTest, InvalidBuffersAreHandledCorrectly) {
228223
ASSERT_GE(devices.size(), 1);
229224

230225
// Make a BasicStringArray::Buffer with two shards.
231-
auto shard0_data = std::make_shared<std::vector<absl::string_view>>();
232-
shard0_data->push_back("abc");
233-
auto shard1_data = std::make_shared<std::vector<absl::string_view>>();
234-
shard1_data->push_back("def");
226+
auto shard0_data = std::make_shared<std::vector<absl::Cord>>();
227+
shard0_data->push_back(absl::Cord("abc"));
228+
auto shard1_data = std::make_shared<std::vector<absl::Cord>>();
229+
shard1_data->push_back(absl::Cord("def"));
235230
BasicStringArray::Buffers buffers;
236231
buffers.push_back(*shard0_data);
237232
buffers.push_back(*shard1_data);
@@ -260,7 +255,7 @@ TEST(BasicStringArrayTest, InvalidBuffersAreHandledCorrectly) {
260255
TEST(BasicStringArrayTest, Delete) {
261256
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
262257
BasicStringArray::Buffers buffers;
263-
buffers.push_back({"abc", "def"});
258+
buffers.push_back({absl::Cord("abc"), absl::Cord("def")});
264259
absl::Notification on_done_with_buffer_called;
265260
BasicStringArray::OnDoneWithBuffer on_done_with_buffer =
266261
[&on_done_with_buffer_called]() { on_done_with_buffer_called.Notify(); };
@@ -294,7 +289,7 @@ TEST(GetReadyFutureTest, SuccessCase) {
294289

295290
// Make the buffers future ready asynchronously.
296291
BasicStringArray::Buffers buffers;
297-
buffers.push_back({"abc", "def"});
292+
buffers.push_back({absl::Cord("abc"), absl::Cord("def")});
298293
tsl::Env::Default()->SchedClosure([&]() { promise.Set(buffers); });
299294
TF_EXPECT_OK(ready_future.Await());
300295
}
@@ -326,11 +321,11 @@ TEST(MakeArrayFromHostBufferTest, SuccessCase) {
326321
std::shared_ptr<const Sharding> sharding =
327322
SingleDeviceSharding::Create(device, MemoryKind());
328323

329-
auto string_views = std::make_shared<std::vector<absl::string_view>>();
330-
string_views->push_back("abc");
331-
string_views->push_back("def");
332-
const void* data = string_views->data();
333-
auto on_done_with_host_buffer = [string_views = std::move(string_views)]() {};
324+
auto strings = std::make_shared<std::vector<absl::Cord>>();
325+
strings->push_back(absl::Cord("abc"));
326+
strings->push_back(absl::Cord("def"));
327+
const void* data = strings->data();
328+
auto on_done_with_host_buffer = [strings = std::move(strings)]() {};
334329

335330
TF_ASSERT_OK(client->MakeArrayFromHostBuffer(
336331
data, DType(DType::kString), shape,
@@ -345,11 +340,11 @@ TEST(MakeArrayFromHostBufferTest, FailureCases) {
345340
Device* device = client->addressable_devices().at(0);
346341
std::shared_ptr<const Sharding> single_device_sharding =
347342
SingleDeviceSharding::Create(device, MemoryKind());
348-
auto string_views = std::make_shared<std::vector<absl::string_view>>();
349-
string_views->push_back("abc");
350-
string_views->push_back("def");
351-
const void* data = string_views->data();
352-
auto on_done_with_host_buffer = [string_views = std::move(string_views)]() {};
343+
auto strings = std::make_shared<std::vector<absl::Cord>>();
344+
strings->push_back(absl::Cord("abc"));
345+
strings->push_back(absl::Cord("def"));
346+
const void* data = strings->data();
347+
auto on_done_with_host_buffer = [strings = std::move(strings)]() {};
353348

354349
// MakeArrayFromHostBuffer should check and fail if `byte_strides` in not
355350
// nullopt.
@@ -398,12 +393,12 @@ absl::StatusOr<tsl::RCReference<Array>> MakeSingleDeviceStringTestArray(
398393
std::shared_ptr<const Sharding> sharding =
399394
SingleDeviceSharding::Create(device, MemoryKind());
400395

401-
auto string_views = std::make_shared<std::vector<absl::string_view>>();
396+
auto strings = std::make_shared<std::vector<absl::Cord>>();
402397
for (const auto& content : contents) {
403-
string_views->push_back(content);
398+
strings->push_back(absl::Cord(content));
404399
}
405-
const void* data = string_views->data();
406-
auto on_done_with_host_buffer = [string_views = std::move(string_views)]() {};
400+
const void* data = strings->data();
401+
auto on_done_with_host_buffer = [strings = std::move(strings)]() {};
407402

408403
return client->MakeArrayFromHostBuffer(
409404
data, DType(DType::kString), shape,

0 commit comments

Comments
 (0)