Skip to content

Commit c1bb92f

Browse files
hyeontaekGoogle-ML-Automation
authored andcommitted
[IFRT] Change tsl::RCReference<xla::ifrt::DeviceList> to xla::ifrt::DeviceListRef
IFRT will introduce a wrapper around `tsl::RCReference<xla::ifrt::DeviceList>` that is more concise and is also safer to use (e.g., equality and hash compares the dereferenced value, not the `tsl::RCReference`. This migration will happen in 3 steps: 1. Define `xla::ifrt::DeviceListRef` alias that is interchangeable with `tsl::RCReference<xla::ifrt::DeviceList>`. 2. Change `tsl::RCReference<xla::ifrt::DeviceList>` to `xla::ifrt::DeviceListRef` in IFRT API, implementations, and user code. (current step) 3. Introduce a real wrapper `xla::ifrt::DeviceListRef`, replacing the alias. PiperOrigin-RevId: 730231976
1 parent 610c85d commit c1bb92f

21 files changed

+98
-119
lines changed

xla/python/ifrt/array_impl_test_lib.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -536,8 +536,7 @@ TEST(ArrayImplTest, CopyToSameDevices) {
536536

537537
TEST(ArrayImplTest, CopyToDifferentDevice) {
538538
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
539-
tsl::RCReference<DeviceList> devices =
540-
client->MakeDeviceList(client->addressable_devices());
539+
DeviceListRef devices = client->MakeDeviceList(client->addressable_devices());
541540

542541
DType dtype(DType::kF32);
543542
Shape shape({2, 3});

xla/python/ifrt/basic_device_list.cc

+4-6
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,15 @@ namespace ifrt {
3737

3838
char BasicDeviceList::ID = 0;
3939

40-
tsl::RCReference<DeviceList> BasicDeviceList::Create(Devices devices) {
41-
return tsl::MakeRef<BasicDeviceList>(std::move(devices));
40+
DeviceListRef BasicDeviceList::Create(Devices devices) {
41+
return DeviceListRef(tsl::MakeRef<BasicDeviceList>(std::move(devices)));
4242
}
4343

44-
tsl::RCReference<DeviceList> BasicDeviceList::Create(
45-
absl::Span<Device* const> devices) {
44+
DeviceListRef BasicDeviceList::Create(absl::Span<Device* const> devices) {
4645
return Create(Devices(devices.begin(), devices.end()));
4746
}
4847

49-
tsl::RCReference<DeviceList> BasicDeviceList::Create(
50-
std::initializer_list<Device*> devices) {
48+
DeviceListRef BasicDeviceList::Create(std::initializer_list<Device*> devices) {
5149
return Create(Devices(devices.begin(), devices.end()));
5250
}
5351

xla/python/ifrt/basic_device_list.h

+4-5
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,9 @@ class BasicDeviceList : public llvm::RTTIExtends<BasicDeviceList, DeviceList> {
5454
using Devices = absl::InlinedVector<Device*, kInlineDeviceSize>;
5555

5656
// Constructor with a pre-populated `devices`.
57-
static tsl::RCReference<DeviceList> Create(Devices devices);
58-
static tsl::RCReference<DeviceList> Create(absl::Span<Device* const> devices);
59-
static tsl::RCReference<DeviceList> Create(
60-
std::initializer_list<Device*> devices);
57+
static DeviceListRef Create(Devices devices);
58+
static DeviceListRef Create(absl::Span<Device* const> devices);
59+
static DeviceListRef Create(std::initializer_list<Device*> devices);
6160

6261
~BasicDeviceList() override = default;
6362

@@ -95,7 +94,7 @@ class BasicDeviceList : public llvm::RTTIExtends<BasicDeviceList, DeviceList> {
9594
struct AddressableDeviceListCache {
9695
absl::once_flag once_flag;
9796
DeviceList* device_list = nullptr;
98-
tsl::RCReference<DeviceList> device_list_holder;
97+
DeviceListRef device_list_holder;
9998
};
10099
mutable AddressableDeviceListCache addressable_device_list_cache_;
101100

xla/python/ifrt/client.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
153153
// device.
154154
virtual absl::StatusOr<std::vector<tsl::RCReference<Array>>> CopyArrays(
155155
absl::Span<tsl::RCReference<Array>> arrays,
156-
std::optional<tsl::RCReference<DeviceList>> devices,
156+
std::optional<DeviceListRef> devices,
157157
std::optional<MemoryKind> memory_kind, ArrayCopySemantics semantics) = 0;
158158

159159
// Remaps shards across input `Array`s to create new `Array`s based on `plan`.
@@ -237,7 +237,7 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
237237
int local_hardware_id) const = 0;
238238

239239
// Creates a device list from the given list of devices.
240-
virtual tsl::RCReference<DeviceList> MakeDeviceList(
240+
virtual DeviceListRef MakeDeviceList(
241241
absl::Span<Device* const> devices) const = 0;
242242

243243
// TODO(hyeontaek): Potentially remove this method to encourage supporting
@@ -246,7 +246,7 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
246246

247247
// Returns a topology that covers the provided devices.
248248
virtual absl::StatusOr<std::shared_ptr<Topology>> GetTopologyForDevices(
249-
const tsl::RCReference<DeviceList>& devices) const = 0;
249+
const DeviceListRef& devices) const = 0;
250250

251251
// Returns the default layout on `device` with `memory_kind` for a buffer with
252252
// `dtype` and single-shard dimensions `dims`.

xla/python/ifrt/custom_call_program.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ struct CustomCallProgram
3737
// Specification for a single array. The sharding of all input and output
3838
// specs must use only the devices in `devices`.
3939
CustomCallProgram(std::string type, std::string name,
40-
absl::Cord serialized_program_text,
41-
tsl::RCReference<DeviceList> devices,
40+
absl::Cord serialized_program_text, DeviceListRef devices,
4241
std::vector<ArraySpec> input_specs,
4342
std::vector<ArraySpec> output_specs)
4443
: type(std::move(type)),
@@ -62,7 +61,7 @@ struct CustomCallProgram
6261
absl::Cord serialized_program_text;
6362

6463
// List of devices to compile and run the custom call program on.
65-
tsl::RCReference<DeviceList> devices;
64+
DeviceListRef devices;
6665

6766
// Specification for input and output arrays. The custom call program must
6867
// expect to receive input arrays and return output arrays both following the

xla/python/ifrt/custom_call_program_serdes.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class CustomCallProgramSerDes
8181
"Failed to parse serialized CustomCallProgramProto");
8282
}
8383
TF_ASSIGN_OR_RETURN(
84-
tsl::RCReference<DeviceList> devices,
84+
DeviceListRef devices,
8585
DeviceList::FromProto(deserialize_program_options->client,
8686
proto.devices()));
8787
std::vector<ArraySpec> input_specs;

xla/python/ifrt/custom_call_program_serdes_test.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class CustomCallProgramSerDesTest : public test_util::DeviceTest {};
5151
TEST_P(CustomCallProgramSerDesTest, RoundTrip) {
5252
Shape shape0({10, 20});
5353
Shape shard_shape0({5, 20});
54-
tsl::RCReference<DeviceList> devices = GetDevices({0, 1});
54+
DeviceListRef devices = GetDevices({0, 1});
5555
std::shared_ptr<const Sharding> sharding0 =
5656
ConcreteEvenSharding::Create(devices, MemoryKind(),
5757
/*shape=*/shape0,

xla/python/ifrt/device_list.cc

+2-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace ifrt {
3131

3232
char DeviceList::ID = 0;
3333

34-
absl::StatusOr<tsl::RCReference<DeviceList>> DeviceList::FromProto(
34+
absl::StatusOr<DeviceListRef> DeviceList::FromProto(
3535
xla::ifrt::Client* client, const DeviceListProto& proto) {
3636
absl::InlinedVector<Device*, 1> devices;
3737
devices.reserve(proto.device_ids_size());
@@ -52,8 +52,7 @@ DeviceListProto DeviceList::ToProto() const {
5252
return proto;
5353
}
5454

55-
std::vector<DeviceId> GetDeviceIds(
56-
const tsl::RCReference<DeviceList>& device_list) {
55+
std::vector<DeviceId> GetDeviceIds(const DeviceListRef& device_list) {
5756
std::vector<DeviceId> ids;
5857
ids.reserve(device_list->devices().size());
5958
for (const Device* device : device_list->devices()) {

xla/python/ifrt/device_list.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ class DeviceList : public tsl::ReferenceCounted<DeviceList>,
104104
using DeviceListRef = tsl::RCReference<DeviceList>;
105105

106106
// Returns the id of each device in `device_list`.
107-
std::vector<DeviceId> GetDeviceIds(
108-
const tsl::RCReference<DeviceList>& device_list);
107+
std::vector<DeviceId> GetDeviceIds(const DeviceListRef& device_list);
109108

110109
// Hash function for `DeviceList`. Assumes that every unique device has a unique
111110
// `Device` object, not duplicate `Device` objects ("d1 == d2 if d1->id() ==

xla/python/ifrt/device_test_util.cc

+4-7
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,11 @@ std::shared_ptr<MockClient> MakeDeviceTestClient(int num_devices,
132132
return it->second.get();
133133
});
134134
ON_CALL(*client, MakeDeviceList)
135-
.WillByDefault([](absl::Span<Device* const> devices)
136-
-> tsl::RCReference<DeviceList> {
135+
.WillByDefault([](absl::Span<Device* const> devices) -> DeviceListRef {
137136
return BasicDeviceList::Create(devices);
138137
});
139138
ON_CALL(*client, GetTopologyForDevices)
140-
.WillByDefault(
141-
[](const tsl::RCReference<DeviceList>&) { return nullptr; });
139+
.WillByDefault([](const DeviceListRef&) { return nullptr; });
142140
return client;
143141
}
144142

@@ -149,12 +147,11 @@ void DeviceTest::SetUp() {
149147
client_ = MakeDeviceTestClient(num_devices, num_addressable_devices);
150148
}
151149

152-
tsl::RCReference<DeviceList> DeviceTest::GetDevices(
153-
absl::Span<const int> device_indices) {
150+
DeviceListRef DeviceTest::GetDevices(absl::Span<const int> device_indices) {
154151
return test_util::GetDevices(client_.get(), device_indices).value();
155152
}
156153

157-
tsl::RCReference<DeviceList> DeviceTest::GetAddressableDevices(
154+
DeviceListRef DeviceTest::GetAddressableDevices(
158155
absl::Span<const int> device_indices) {
159156
return test_util::GetAddressableDevices(client_.get(), device_indices)
160157
.value();

xla/python/ifrt/device_test_util.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,12 @@ class DeviceTest : public testing::TestWithParam<DeviceTestParam> {
4646
// Returns `DeviceList` containing devices at given indexes (not ids) within
4747
// `client.devices()`.
4848
// REQUIRES: 0 <= device_indices[i] < num_devices
49-
tsl::RCReference<DeviceList> GetDevices(absl::Span<const int> device_indices);
49+
DeviceListRef GetDevices(absl::Span<const int> device_indices);
5050

5151
// Returns `DeviceList` containing devices at given indexes (not ids) within
5252
// `client.addressable_devices()`.
5353
// REQUIRES: 0 <= device_indices[i] < num_addressable_devices
54-
tsl::RCReference<DeviceList> GetAddressableDevices(
55-
absl::Span<const int> device_indices);
54+
DeviceListRef GetAddressableDevices(absl::Span<const int> device_indices);
5655

5756
private:
5857
std::shared_ptr<Client> client_;

xla/python/ifrt/executable.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ class LoadedExecutable
234234
// API).
235235
virtual absl::StatusOr<ExecuteResult> Execute(
236236
absl::Span<tsl::RCReference<Array>> args, const ExecuteOptions& options,
237-
std::optional<tsl::RCReference<DeviceList>> devices) = 0;
237+
std::optional<DeviceListRef> devices) = 0;
238238

239239
// Deletes the executable from the devices. The operation may be asynchronous.
240240
// The returned future will have the result of the deletion on the devices.

xla/python/ifrt/mock.cc

+4-5
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
155155
});
156156
ON_CALL(*this, CopyArrays)
157157
.WillByDefault([this](absl::Span<tsl::RCReference<Array>> arrays,
158-
std::optional<tsl::RCReference<DeviceList>> devices,
158+
std::optional<DeviceListRef> devices,
159159
std::optional<MemoryKind> memory_kind,
160160
ArrayCopySemantics semantics) {
161161
return delegated_->CopyArrays(arrays, std::move(devices), memory_kind,
@@ -229,10 +229,9 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
229229
return delegated_->GetDefaultCompiler();
230230
});
231231
ON_CALL(*this, GetTopologyForDevices)
232-
.WillByDefault(
233-
[this](const tsl::RCReference<xla::ifrt::DeviceList>& devices) {
234-
return delegated_->GetTopologyForDevices(devices);
235-
});
232+
.WillByDefault([this](const DeviceListRef& devices) {
233+
return delegated_->GetTopologyForDevices(devices);
234+
});
236235
ON_CALL(*this, GetDefaultLayout)
237236
.WillByDefault([this](xla::ifrt::DType dtype,
238237
absl::Span<const int64_t> dims,

xla/python/ifrt/mock.h

+6-7
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
148148
(final));
149149
MOCK_METHOD(absl::StatusOr<std::vector<tsl::RCReference<Array>>>, CopyArrays,
150150
(absl::Span<tsl::RCReference<Array>> arrays,
151-
std::optional<tsl::RCReference<DeviceList>> devices,
151+
std::optional<DeviceListRef> devices,
152152
std::optional<MemoryKind> memory_kind,
153153
ArrayCopySemantics semantics),
154154
(final));
@@ -180,12 +180,11 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
180180
(const, final));
181181
MOCK_METHOD(absl::StatusOr<Device*>, LookupAddressableDevice,
182182
(int local_hardware_id), (const, final));
183-
MOCK_METHOD(tsl::RCReference<DeviceList>, MakeDeviceList,
183+
MOCK_METHOD(DeviceListRef, MakeDeviceList,
184184
(absl::Span<Device* const> devices), (const));
185185
MOCK_METHOD(Compiler*, GetDefaultCompiler, (), (final));
186186
MOCK_METHOD(absl::StatusOr<std::shared_ptr<Topology>>, GetTopologyForDevices,
187-
(const tsl::RCReference<xla::ifrt::DeviceList>& devices),
188-
(const, final));
187+
(const xla::ifrt::DeviceListRef& devices), (const, final));
189188
MOCK_METHOD(absl::StatusOr<std::shared_ptr<const PjRtLayout>>,
190189
GetDefaultLayout,
191190
(xla::ifrt::DType dtype, absl::Span<const int64_t> dims,
@@ -319,7 +318,7 @@ class MockLoadedExecutable
319318
MOCK_METHOD(absl::StatusOr<ExecuteResult>, Execute,
320319
(absl::Span<tsl::RCReference<Array>> args,
321320
const ExecuteOptions& options,
322-
std::optional<tsl::RCReference<DeviceList>> devices),
321+
std::optional<DeviceListRef> devices),
323322
(final));
324323
MOCK_METHOD(Future<>, Delete, (), (final));
325324
MOCK_METHOD(bool, IsDeleted, (), (const, final));
@@ -357,7 +356,7 @@ class MockSharding : public llvm::RTTIExtends<MockSharding, Sharding> {
357356
BasicDeviceList::Create({}), MemoryKind(),
358357
/*is_fully_replicated=*/false) {}
359358

360-
MockSharding(tsl::RCReference<DeviceList> devices, MemoryKind memory_kind,
359+
MockSharding(DeviceListRef devices, MemoryKind memory_kind,
361360
bool is_fully_replicated)
362361
: llvm::RTTIExtends<MockSharding, Sharding>(devices, memory_kind,
363362
is_fully_replicated) {}
@@ -395,7 +394,7 @@ class MockSharding : public llvm::RTTIExtends<MockSharding, Sharding> {
395394
MOCK_METHOD(bool, HasSamePartitioning, (const Sharding& other),
396395
(const final));
397396
MOCK_METHOD(absl::StatusOr<std::unique_ptr<Sharding>>, WithDeviceAssignment,
398-
(std::optional<tsl::RCReference<DeviceList>> devices,
397+
(std::optional<DeviceListRef> devices,
399398
std::optional<MemoryKind> memory_kind),
400399
(const final));
401400
MOCK_METHOD(void, Hash, (absl::HashState), (const final));

xla/python/ifrt/remap_impl_test_lib.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ absl::StatusOr<ArraySpec> CreateArraySpec(Client* client,
6969
absl::Span<const int> device_indices,
7070
Shape shard_shape = Shape({2, 3}),
7171
DType dtype = DType(DType::kS32)) {
72-
TF_ASSIGN_OR_RETURN(tsl::RCReference<DeviceList> device_list,
72+
TF_ASSIGN_OR_RETURN(DeviceListRef device_list,
7373
test_util::GetAddressableDevices(client, device_indices));
7474
TF_ASSIGN_OR_RETURN(Shape shape,
7575
GetShape(device_indices.size(), shard_shape));

xla/python/ifrt/remap_plan_test.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ TEST_P(RemapPlanTest, ToFromProto) {
7070

7171
Shape shape({20, 20});
7272
Shape shard_shape({5, 20});
73-
tsl::RCReference<DeviceList> devices = GetDevices({0, 1, 2, 3});
73+
DeviceListRef devices = GetDevices({0, 1, 2, 3});
7474
std::shared_ptr<const Sharding> sharding =
7575
ConcreteEvenSharding::Create(devices, MemoryKind(), /*shape=*/shape,
7676
/*shard_shape=*/shard_shape);

0 commit comments

Comments
 (0)