@@ -164,9 +164,6 @@ TEST(ShardingConversionsTest, VerifyIncorrectShardings) {
164
164
ShardingParam too_many_slices{/* dim_shards=*/ {2 , 2 },
165
165
{/* permutation=*/ {0 }, /* axis_sizes=*/ {2 }}};
166
166
EXPECT_FALSE (too_many_slices.verify ().ok ());
167
- ShardingParam cannot_distribute_slices{
168
- /* dim_shards=*/ {1 , 2 }, {/* permutation=*/ {0 , 1 }, /* axis_sizes=*/ {3 , 2 }}};
169
- EXPECT_FALSE (cannot_distribute_slices.verify ().ok ());
170
167
ShardingParam incorrect_permutation{
171
168
/* dim_shards=*/ {4 , 1 },
172
169
{/* permutation=*/ {0 , 1 , 1 }, /* axis_sizes=*/ {2 , 2 , 2 }}};
@@ -197,10 +194,7 @@ TEST_P(HloShardingToShardingParamTest, HloShardingToShardingParam) {
197
194
TF_ASSERT_OK_AND_ASSIGN (
198
195
auto sharding_param,
199
196
ToShardingParam (param.hlo_sharding , param.rank , param.num_devices ));
200
- // We cannot verify sharding param because we're losing info about the
201
- // axis_size during these conversions. While strictly some ShardingParam
202
- // are invalid because they have more dims than axis, in practice this is not
203
- // a problem because we can still correctly map the shards to the devices.
197
+ EXPECT_TRUE (sharding_param.verify ().ok ());
204
198
TF_ASSERT_OK_AND_ASSIGN (auto actual_hlo_sharding,
205
199
ToHloSharding (sharding_param));
206
200
EXPECT_EQ (param.hlo_sharding , actual_hlo_sharding);
0 commit comments