Skip to content

Commit 96c8197

Browse files
ZixuanJiang authored and tensorflower-gardener committed
[XLA:SPMD] Use stable sort to fix a flaky test.
PiperOrigin-RevId: 678935741
1 parent 7b650db commit 96c8197

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

third_party/xla/xla/service/spmd/spmd_partitioner.cc

+4-6
Original file line number | Diff line number | Diff line change
@@ -1143,8 +1143,6 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
11431143
std::vector<int64_t>(halo_exchange_base_shape.rank(), 1)));
11441144
}
11451145

1146-
std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
1147-
std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
11481146
// TODO(yuanzx): We are concatenating on each sharded dimension one at time,
11491147
// and in the second dimension (and beyond) we create halos by slicing the
11501148
// concat in the previous dimension, which is not optimal. We should generate
@@ -1162,18 +1160,18 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
11621160
// partition.
11631161
MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
11641162
input_shard_size, explicit_left_padding[dim], 1);
1165-
left_halo_size_functions[dim] =
1163+
OffsetCalculation left_halo_size_functions =
11661164
shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];
11671165

11681166
// Right halo.
11691167
MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
11701168
input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
1171-
right_halo_size_functions[dim] =
1169+
OffsetCalculation right_halo_size_functions =
11721170
limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;
11731171

11741172
auto resharded = ExchangeHaloAndGetValidData(
1175-
visiting_hlo, halo_exchange_base_shape, left_halo_size_functions[dim],
1176-
right_halo_size_functions[dim], explicit_left_padding[dim],
1173+
visiting_hlo, halo_exchange_base_shape, left_halo_size_functions,
1174+
right_halo_size_functions, explicit_left_padding[dim],
11771175
padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim,
11781176
*halo_exchange_target, offsets_on_padded_shape[dim], pad_value,
11791177
partition_ordinals[dim], state_.collective_ops_creator,

third_party/xla/xla/service/spmd/spmd_partitioner_util.cc

+7-7
Original file line number | Diff line number | Diff line change
@@ -956,8 +956,7 @@ HloInstruction* ExchangeHaloCompact(
956956
(i + 1) * input_shard_size + right_halo_size_function.Calculate(i);
957957
max_window_size = std::max(max_window_size, limit - start);
958958
while (next_start < limit) {
959-
halos[i].emplace_back();
960-
Halo& halo = halos[i].back();
959+
Halo& halo = halos[i].emplace_back();
961960
halo.my_index = i;
962961
halo.halo_offset = next_start - start;
963962
halo.start = next_start % input_shard_size;
@@ -1038,11 +1037,12 @@ HloInstruction* ExchangeHaloCompact(
10381037
// Sort halos that are from the same src according to halo_offset, so that
10391038
// they are more likely to have similar characteristics.
10401039
for (int64_t i = 0; i < src_to_dst.size(); ++i) {
1041-
absl::c_sort(src_to_dst[i], [&](const std::pair<int64_t, int64_t>& a,
1042-
const std::pair<int64_t, int64_t>& b) {
1043-
return halos[a.first][a.second].halo_offset <
1044-
halos[b.first][b.second].halo_offset;
1045-
});
1040+
absl::c_stable_sort(src_to_dst[i],
1041+
[&](const std::pair<int64_t, int64_t>& a,
1042+
const std::pair<int64_t, int64_t>& b) {
1043+
return halos[a.first][a.second].halo_offset <
1044+
halos[b.first][b.second].halo_offset;
1045+
});
10461046
}
10471047

10481048
// Build collective permutes with distinct src/dst values.

0 commit comments

Comments
 (0)