Skip to content

Commit

Permalink
[IFRT] Extend Array assembly/disassembly operations to distinguish be…
Browse files Browse the repository at this point in the history
…tween addressable-shard and all-shard processing

IFRT's assembly/disassembly operations
(`Client::AssembleArrayFromSingleDeviceArray`,
`Array::DisassembleIntoSingleDeviceArrays`, and related methods in `Sharding`)
treated all shards equally, without distinguishing whether each shard's device
is addressable. This had practical problems:

* When the user only has single-device arrays for addressable devices, and
assembles them into a multi-shard array, the user is forced to use a `Sharding`
that only contains addressable devices. However, with SPMD, it is common to use
a `Sharding` that can express both addressable/non-addressable shards (e.g.,
`HloSharding`).

* When the user has a multi-shard array that spans both
addressable/non-addressable devices, disassembling the array into single-device
arrays would create a single-device array with no addressable devices, which is
often not well supported in the user code because the user code sometimes makes
a strong assumption that any array contains at least one addressable device.

On the other hand, making assembly/disassembly handle only addressable shards
is not future-proof. An MPMD setup (where not all inputs use a single device
mesh) can see an array with no addressable devices. Thus, changing
assembly/disassembly semantics to handle only addressable shards would be too
restrictive.

To resolve this single-device array addressability issue, this change makes it
explicit whether only addressable shards will be processed or all shards will
be processed in assembly/disassembly operations.

This change extends `Client::AssembleArrayFromSingleDeviceArray` and
`Array::DisassembleIntoSingleDeviceArrays` methods to be able to handle
addressable shards only.

Subsequent changes will update the IFRT user code to request only addressable
devices.

PiperOrigin-RevId: 686577669
  • Loading branch information
hyeontaek authored and Google-ML-Automation committed Oct 16, 2024
1 parent 6244044 commit 1b13e93
Show file tree
Hide file tree
Showing 27 changed files with 369 additions and 71 deletions.
10 changes: 7 additions & 3 deletions xla/python/ifrt/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
Expand Down Expand Up @@ -81,8 +78,15 @@ class Array : public llvm::RTTIExtends<Array, Value> {

// Breaks an array up into per-device arrays. This is the elimination
// counterpart of `Client::AssembleArrayFromSingleDeviceArrays()`.
// TODO(hyeontaek): Replace this API with the version that takes
// `SingleDeviceShardSemantics`.
virtual absl::StatusOr<std::vector<tsl::RCReference<Array>>>
DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) = 0;
virtual absl::StatusOr<std::vector<tsl::RCReference<Array>>>
DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) = 0;

// Returns a shard of an Array which is fully replicated. This is an
// optimization so that instead of disassembling into all the shards when
// the Array is fully replicated, we can just get 1 shard out and create an
Expand Down
50 changes: 31 additions & 19 deletions xla/python/ifrt/array_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferReplicated) {

TF_ASSERT_OK_AND_ASSIGN(auto single_device_arrays,
array->DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics::kAlwaysCopy));
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));
ASSERT_EQ(single_device_arrays.size(), devices.size());
for (int i = 0; i < single_device_arrays.size(); ++i) {
EXPECT_THAT(single_device_arrays[i]->sharding().devices()->devices(),
Expand Down Expand Up @@ -383,7 +384,8 @@ TEST(ArrayImplTest, AssembleArray) {
auto assembled_array,
client->AssembleArrayFromSingleDeviceArrays(
assembled_shape, assembled_sharding, absl::MakeSpan(arrays),
ArrayCopySemantics::kAlwaysCopy));
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));

EXPECT_EQ(assembled_array->dtype(), dtype);
EXPECT_EQ(assembled_array->shape(), assembled_shape);
Expand Down Expand Up @@ -438,11 +440,14 @@ TEST(ArrayImplTest, AssembleAndDisassembleArray) {
auto assembled_array,
client->AssembleArrayFromSingleDeviceArrays(
assembled_shape, assembled_sharding, absl::MakeSpan(arrays),
ArrayCopySemantics::kAlwaysCopy));
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));

TF_ASSERT_OK_AND_ASSIGN(auto single_device_arrays,
assembled_array->DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics::kAlwaysCopy));
TF_ASSERT_OK_AND_ASSIGN(
auto single_device_arrays,
assembled_array->DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));

ASSERT_THAT(single_device_arrays, SizeIs(2));
EXPECT_EQ(single_device_arrays[0]->dtype(), array0->dtype());
Expand Down Expand Up @@ -479,7 +484,8 @@ TEST(ArrayImplTest, AssembleAndDisassembleSingleDeviceArray) {
TF_ASSERT_OK_AND_ASSIGN(auto assembled_array,
client->AssembleArrayFromSingleDeviceArrays(
shape, sharding, absl::MakeSpan(arrays),
ArrayCopySemantics::kAlwaysCopy));
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));

ASSERT_EQ(assembled_array->dtype(), array->dtype());
ASSERT_EQ(assembled_array->shape(), array->shape());
Expand All @@ -488,7 +494,8 @@ TEST(ArrayImplTest, AssembleAndDisassembleSingleDeviceArray) {

TF_ASSERT_OK_AND_ASSIGN(auto single_device_arrays,
assembled_array->DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics::kAlwaysCopy));
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));

ASSERT_THAT(single_device_arrays, SizeIs(1));
ASSERT_EQ(single_device_arrays[0]->dtype(), array->dtype());
Expand Down Expand Up @@ -557,18 +564,22 @@ TEST(ArrayImplTest, CopyToDifferentDevice) {
std::vector<Shape> shapes(shards.size(), shape);
std::shared_ptr<const Sharding> sharding =
ConcreteSharding::Create(devices, MemoryKind(), shape, shapes);
TF_ASSERT_OK_AND_ASSIGN(arrays.emplace_back(),
client->AssembleArrayFromSingleDeviceArrays(
shape, sharding, absl::MakeSpan(shards),
ArrayCopySemantics::kAlwaysCopy));
TF_ASSERT_OK_AND_ASSIGN(
arrays.emplace_back(),
client->AssembleArrayFromSingleDeviceArrays(
shape, sharding, absl::MakeSpan(shards),
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));
}
{
std::shared_ptr<const Sharding> sharding =
ConcreteEvenSharding::Create(devices, MemoryKind(), shape, shape);
TF_ASSERT_OK_AND_ASSIGN(arrays.emplace_back(),
client->AssembleArrayFromSingleDeviceArrays(
shape, sharding, absl::MakeSpan(shards),
ArrayCopySemantics::kAlwaysCopy));
TF_ASSERT_OK_AND_ASSIGN(
arrays.emplace_back(),
client->AssembleArrayFromSingleDeviceArrays(
shape, sharding, absl::MakeSpan(shards),
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));
}

BasicDeviceList::Devices new_devices;
Expand All @@ -589,9 +600,10 @@ TEST(ArrayImplTest, CopyToDifferentDevice) {
BasicDeviceList::Create(new_devices), MemoryKind()));
EXPECT_EQ(new_arrays[i]->sharding(), *expected_sharding);

TF_ASSERT_OK_AND_ASSIGN(auto shards,
arrays[i]->DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics::kAlwaysCopy));
TF_ASSERT_OK_AND_ASSIGN(
auto shards, arrays[i]->DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics::kAlwaysCopy,
SingleDeviceShardSemantics::kAddressableShards));
for (const auto& shard : shards) {
std::vector<float> out_data(6);
auto future = shard->CopyToHostBuffer(out_data.data(),
Expand Down
8 changes: 8 additions & 0 deletions xla/python/ifrt/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,19 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
std::function<void()> on_done_with_host_buffer) = 0;

// Builds a larger array out of individual per-device shards.
// TODO(hyeontaek): Replace this API with the version that takes
// `SingleDeviceShardSemantics`.
virtual absl::StatusOr<tsl::RCReference<Array>>
AssembleArrayFromSingleDeviceArrays(
Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics semantics) = 0;
virtual absl::StatusOr<tsl::RCReference<Array>>
AssembleArrayFromSingleDeviceArrays(
Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) = 0;

// Copies the arrays to a new set of devices.
//
Expand Down
25 changes: 23 additions & 2 deletions xla/python/ifrt/mock.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ char MockHostCallback::ID = 0;
char MockLoadedHostCallback::ID = 0;
char MockSharding::ID = 0;

namespace {
// Brings GoogleMock's wildcard matcher `_` into scope so that the ON_CALL
// specifications below can select between same-named mocked overloads by
// spelling out their arity, e.g. `(_)` vs. `(_, _)`.
using ::testing::_;
}  // namespace

// LINT.IfChange(MockArrayDelegation)
MockArray::MockArray(tsl::RCReference<xla::ifrt::Array> delegated)
: delegated_(std::move(delegated)) {
Expand All @@ -76,10 +80,17 @@ MockArray::MockArray(tsl::RCReference<xla::ifrt::Array> delegated)
.WillByDefault([this]() -> absl::StatusOr<std::unique_ptr<PjRtLayout>> {
return delegated_->layout();
});
ON_CALL(*this, DisassembleIntoSingleDeviceArrays)
ON_CALL(*this, DisassembleIntoSingleDeviceArrays(_))
.WillByDefault([this](ArrayCopySemantics semantics) {
return delegated_->DisassembleIntoSingleDeviceArrays(semantics);
});
ON_CALL(*this, DisassembleIntoSingleDeviceArrays(_, _))
.WillByDefault(
[this](ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
return delegated_->DisassembleIntoSingleDeviceArrays(
array_copy_semantics, single_device_shard_semantics);
});
ON_CALL(*this, FullyReplicatedShard)
.WillByDefault([this](ArrayCopySemantics semantics) {
return delegated_->FullyReplicatedShard(semantics);
Expand Down Expand Up @@ -108,14 +119,24 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
data, dtype, std::move(shape), byte_strides, std::move(sharding),
semantics, std::move(on_done_with_host_buffer));
});
ON_CALL(*this, AssembleArrayFromSingleDeviceArrays)
ON_CALL(*this, AssembleArrayFromSingleDeviceArrays(_, _, _, _))
.WillByDefault([this](Shape shape,
std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics semantics) {
return delegated_->AssembleArrayFromSingleDeviceArrays(
std::move(shape), std::move(sharding), arrays, semantics);
});
ON_CALL(*this, AssembleArrayFromSingleDeviceArrays(_, _, _, _, _))
.WillByDefault(
[this](Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
return delegated_->AssembleArrayFromSingleDeviceArrays(
std::move(shape), std::move(sharding), arrays,
array_copy_semantics, single_device_shard_semantics);
});
ON_CALL(*this, CopyArrays)
.WillByDefault([this](absl::Span<tsl::RCReference<Array>> arrays,
std::optional<tsl::RCReference<DeviceList>> devices,
Expand Down
12 changes: 12 additions & 0 deletions xla/python/ifrt/mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ class MockArray : public llvm::RTTIExtends<MockArray, Array> {
MOCK_METHOD(absl::StatusOr<std::vector<tsl::RCReference<Array>>>,
DisassembleIntoSingleDeviceArrays, (ArrayCopySemantics semantics),
(final));
MOCK_METHOD(absl::StatusOr<std::vector<tsl::RCReference<Array>>>,
DisassembleIntoSingleDeviceArrays,
(ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics),
(final));
MOCK_METHOD(absl::StatusOr<tsl::RCReference<Array>>, FullyReplicatedShard,
(ArrayCopySemantics semantics), (final));
MOCK_METHOD(Future<>, CopyToHostBuffer,
Expand Down Expand Up @@ -120,6 +125,13 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics semantics),
(final));
MOCK_METHOD(absl::StatusOr<tsl::RCReference<Array>>,
AssembleArrayFromSingleDeviceArrays,
(Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics),
(final));
MOCK_METHOD(absl::StatusOr<std::vector<tsl::RCReference<Array>>>, CopyArrays,
(absl::Span<tsl::RCReference<Array>> arrays,
std::optional<tsl::RCReference<DeviceList>> devices,
Expand Down
1 change: 1 addition & 0 deletions xla/python/ifrt/support/sharding_conversions_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ std::shared_ptr<MockClient> MakeTestClient(int num_devices) {
for (int i = 0; i < num_devices; ++i) {
auto device = std::make_unique<MockDevice>();
ON_CALL(*device, Id).WillByDefault(Return(DeviceId(i)));
ON_CALL(*device, IsAddressable).WillByDefault(Return(true));
state->devices.push_back(device.get());
state->device_map.insert({DeviceId(i), std::move(device)});
}
Expand Down
33 changes: 30 additions & 3 deletions xla/python/ifrt_proxy/client/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,22 @@ Array::AssembleArrayFromSingleDeviceArrays(
xla::ifrt::Client* client, std::shared_ptr<RpcHelper> rpc_helper,
Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics semantics) {
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAddressableShards &&
rpc_helper->version().protocol_version() < 8) {
return absl::UnimplementedError(
"SingleDeviceShardSemantics::kAdressableShards is not supported in "
"ifrt-proxy version < 8");
}
auto req = std::make_unique<AssembleArrayFromSingleDeviceArraysRequest>();
TF_RET_CHECK(!arrays.empty());
*req->mutable_shape() = shape.ToProto();
TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), sharding->ToProto());
req->set_copy_semantics(ToArrayCopySemanticsProto(semantics));
req->set_copy_semantics(ToArrayCopySemanticsProto(array_copy_semantics));
req->set_single_device_shard_semantics(
ToSingleDeviceShardSemanticsProto(single_device_shard_semantics));
for (const tsl::RCReference<xla::ifrt::Array>& rcref : arrays) {
Array* array = llvm::dyn_cast<Array>(rcref.get());
if (array == nullptr) {
Expand Down Expand Up @@ -244,9 +254,26 @@ Array::RemapArrays(xla::ifrt::Client* client,

absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
Array::DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) {
  // Legacy overload: historically, disassembly produced one array per shard
  // regardless of device addressability, so delegate with `kAllShards` to
  // preserve that behavior for existing callers.
  constexpr SingleDeviceShardSemantics kShardSemantics =
      SingleDeviceShardSemantics::kAllShards;
  return DisassembleIntoSingleDeviceArrays(semantics, kShardSemantics);
}

absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
Array::DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAddressableShards &&
rpc_helper_->version().protocol_version() < 8) {
return absl::UnimplementedError(
"SingleDeviceShardSemantics::kAdressableShards is not supported in "
"version < 8");
}
auto req = std::make_unique<DisassembleIntoSingleDeviceArraysRequest>();
req->set_array_handle(handle_.handle);
req->set_copy_semantics(ToArrayCopySemanticsProto(semantics));
req->set_copy_semantics(ToArrayCopySemanticsProto(array_copy_semantics));
req->set_single_device_shard_semantics(
ToSingleDeviceShardSemanticsProto(single_device_shard_semantics));

TF_ASSIGN_OR_RETURN(
std::shared_ptr<DisassembleIntoSingleDeviceArraysResponse> response,
Expand Down
7 changes: 6 additions & 1 deletion xla/python/ifrt_proxy/client/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class Array final : public llvm::RTTIExtends<Array, xla::ifrt::Array> {
xla::ifrt::Client* client, std::shared_ptr<RpcHelper> rpc_helper,
Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics semantics);
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics);

// `Array::RemapArrays()` implements `Client::RemapArrays()`.
// TODO(b/261226026): Implement logic directly in client.cc.
Expand Down Expand Up @@ -118,6 +119,10 @@ class Array final : public llvm::RTTIExtends<Array, xla::ifrt::Array> {

absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override;
absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) override;

absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> FullyReplicatedShard(
xla::ifrt::ArrayCopySemantics semantics) override;
Expand Down
14 changes: 13 additions & 1 deletion xla/python/ifrt_proxy/client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,19 @@ Client::AssembleArrayFromSingleDeviceArrays(
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics semantics) {
return Array::AssembleArrayFromSingleDeviceArrays(
this, rpc_helper_, std::move(shape), sharding, arrays, semantics);
this, rpc_helper_, std::move(shape), sharding, arrays, semantics,
SingleDeviceShardSemantics::kAllShards);
}

absl::StatusOr<tsl::RCReference<xla::ifrt::Array>>
Client::AssembleArrayFromSingleDeviceArrays(
    Shape shape, std::shared_ptr<const Sharding> sharding,
    absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
    ArrayCopySemantics array_copy_semantics,
    SingleDeviceShardSemantics single_device_shard_semantics) {
  // Thin forwarder to the proxy-client Array implementation.
  // `Array::AssembleArrayFromSingleDeviceArrays` takes both `shape` and
  // `sharding` by value, so move both; moving `sharding` (instead of copying
  // it, as before) avoids an unnecessary atomic shared_ptr refcount
  // increment/decrement, matching the existing `std::move(shape)`.
  return Array::AssembleArrayFromSingleDeviceArrays(
      this, rpc_helper_, std::move(shape), std::move(sharding), arrays,
      array_copy_semantics, single_device_shard_semantics);
}

absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
Expand Down
6 changes: 6 additions & 0 deletions xla/python/ifrt_proxy/client/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics semantics) override;
absl::StatusOr<tsl::RCReference<xla::ifrt::Array>>
AssembleArrayFromSingleDeviceArrays(
Shape shape, std::shared_ptr<const Sharding> sharding,
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) override;

absl::StatusOr<std::vector<tsl::RCReference<Array>>> CopyArrays(
absl::Span<tsl::RCReference<Array>> arrays,
Expand Down
2 changes: 1 addition & 1 deletion xla/python/ifrt_proxy/client/version.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace proxy {
// LINT.IfChange
// TODO(b/296144873): Document the version upgrade policy.
inline constexpr int kClientMinVersion = 3;
inline constexpr int kClientMaxVersion = 7;
inline constexpr int kClientMaxVersion = 8;
// LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md)

} // namespace proxy
Expand Down
6 changes: 6 additions & 0 deletions xla/python/ifrt_proxy/common/VERSION.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,9 @@
* Added date: 2024-10-01.
* Changes:
* Added support for `Client::GetAllDevices()`.

## Version 8

* Added date: 2024-10-11.
* Changes:
* Added support for `SingleDeviceShardSemantics` in Array assembly and disassembly operations.
2 changes: 2 additions & 0 deletions xla/python/ifrt_proxy/common/ifrt_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ message AssembleArrayFromSingleDeviceArraysRequest {
ShardingProto sharding = 2;
repeated fixed64 single_device_array_handles = 3;
proto.ArrayCopySemantics copy_semantics = 4;
optional proto.SingleDeviceShardSemantics single_device_shard_semantics = 5;
}
message AssembleArrayFromSingleDeviceArraysResponse {
fixed64 array_handle = 1;
Expand Down Expand Up @@ -299,6 +300,7 @@ message CopyToHostBufferResponse {}
// Request to break the array identified by `array_handle` into per-device
// single-device arrays (mirrors `Array::DisassembleIntoSingleDeviceArrays`).
message DisassembleIntoSingleDeviceArraysRequest {
  fixed64 array_handle = 1;
  proto.ArrayCopySemantics copy_semantics = 2;
  // Added in IFRT proxy protocol version 8; `optional` keeps the message
  // wire-compatible with older peers that never set this field.
  // NOTE(review): when unset, the server presumably falls back to the
  // pre-version-8 behavior (all shards) — confirm against the server handler.
  optional proto.SingleDeviceShardSemantics single_device_shard_semantics = 3;
}
message DisassembleIntoSingleDeviceArraysResponse {
repeated fixed64 single_device_array_handles = 1;
Expand Down
Loading

0 comments on commit 1b13e93

Please sign in to comment.