Skip to content

Commit eac67fd

Browse files
ICGogcopybara-github
authored andcommitted
Relax verification of ShardingParam so that a dimension can be sharded over multiple axis.
PiperOrigin-RevId: 617356700
1 parent b8de936 commit eac67fd

File tree

3 files changed

+5
-14
lines changed

3 files changed

+5
-14
lines changed

xla/python/ifrt/ir/sharding_param.cc

+3-6
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,9 @@ absl::Status ShardingParam::verify() const {
161161
break;
162162
}
163163
cum_size *= minor_to_major().axis_sizes[index];
164-
if (cum_size > dim_shards()[dim_index]) {
165-
return absl::InvalidArgumentError(absl::StrCat(
166-
"Dimension #", dim_index, " of ", dim_shards()[dim_index],
167-
" shards can't be assigned to the axes"));
168-
} else if (cum_size == dim_shards()[dim_index]) {
169-
cum_size = 1;
164+
while (dim_index < dim_shards().size() &&
165+
cum_size % dim_shards()[dim_index] == 0) {
166+
cum_size /= dim_shards()[dim_index];
170167
dim_index++;
171168
}
172169
}

xla/python/ifrt/ir/tests/verify_array.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func.func @array_requires_enough_devices() {
7373
// -----
7474

7575
func.func @array_requires_shard_distributable_to_axes() {
76-
// expected-error@+2 {{Dimension #1 of 2 shards can't be assigned to the axes}}
76+
// expected-error@+2 {{Can't shard the dims 1x2 to the mesh of [0] on 3}}
7777
%0 = builtin.unrealized_conversion_cast to
7878
!ifrt.array<tensor<4x4xi32>, 1x2 to [0] on 3, [0,1,2]>
7979
return

xla/python/ifrt/support/sharding_conversions_test.cc

+1-7
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,6 @@ TEST(ShardingConversionsTest, VerifyIncorrectShardings) {
164164
ShardingParam too_many_slices{/*dim_shards=*/{2, 2},
165165
{/*permutation=*/{0}, /*axis_sizes=*/{2}}};
166166
EXPECT_FALSE(too_many_slices.verify().ok());
167-
ShardingParam cannot_distribute_slices{
168-
/*dim_shards=*/{1, 2}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{3, 2}}};
169-
EXPECT_FALSE(cannot_distribute_slices.verify().ok());
170167
ShardingParam incorrect_permutation{
171168
/*dim_shards=*/{4, 1},
172169
{/*permutation=*/{0, 1, 1}, /*axis_sizes=*/{2, 2, 2}}};
@@ -197,10 +194,7 @@ TEST_P(HloShardingToShardingParamTest, HloShardingToShardingParam) {
197194
TF_ASSERT_OK_AND_ASSIGN(
198195
auto sharding_param,
199196
ToShardingParam(param.hlo_sharding, param.rank, param.num_devices));
200-
// We cannot verify sharding param because we're losing info about the
201-
// axis_size during these conversions. While strictly some ShardingParam
202-
// are invalid because they have more dims than axis, in practice this is not
203-
// a problem because we can still correctly map the shards to the devices.
197+
EXPECT_TRUE(sharding_param.verify().ok());
204198
TF_ASSERT_OK_AND_ASSIGN(auto actual_hlo_sharding,
205199
ToHloSharding(sharding_param));
206200
EXPECT_EQ(param.hlo_sharding, actual_hlo_sharding);

0 commit comments

Comments
 (0)