@@ -1143,8 +1143,6 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
1143
1143
std::vector<int64_t >(halo_exchange_base_shape.rank (), 1 )));
1144
1144
}
1145
1145
1146
- std::vector<OffsetCalculation> left_halo_size_functions (base_shape_.rank ());
1147
- std::vector<OffsetCalculation> right_halo_size_functions (base_shape_.rank ());
1148
1146
// TODO(yuanzx): We are concatenating on each sharded dimension one at time,
1149
1147
// and in the second dimension (and beyond) we create halos by slicing the
1150
1148
// concat in the previous dimension, which is not optimal. We should generate
@@ -1162,18 +1160,18 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
1162
1160
// partition.
1163
1161
MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded (
1164
1162
input_shard_size, explicit_left_padding[dim], 1 );
1165
- left_halo_size_functions[dim] =
1163
+ OffsetCalculation left_halo_size_functions =
1166
1164
shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];
1167
1165
1168
1166
// Right halo.
1169
1167
MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded (
1170
1168
input_shard_size, input_shard_size + explicit_left_padding[dim], 1 );
1171
- right_halo_size_functions[dim] =
1169
+ OffsetCalculation right_halo_size_functions =
1172
1170
limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;
1173
1171
1174
1172
auto resharded = ExchangeHaloAndGetValidData (
1175
- visiting_hlo, halo_exchange_base_shape, left_halo_size_functions[dim] ,
1176
- right_halo_size_functions[dim] , explicit_left_padding[dim],
1173
+ visiting_hlo, halo_exchange_base_shape, left_halo_size_functions,
1174
+ right_halo_size_functions, explicit_left_padding[dim],
1177
1175
padded_shape.dimensions (dim), shard_shape.dimensions (dim), dim,
1178
1176
*halo_exchange_target, offsets_on_padded_shape[dim], pad_value,
1179
1177
partition_ordinals[dim], state_.collective_ops_creator ,
0 commit comments