Skip to content

Commit 4113e8d

Browse files
mrrytensorflower-gardener
authored andcommitted
[MSA] Microoptimizations in AsynchronousCopyResource.
Based on a profiling the memory-space assignment algorithm, this change makes two small optimizations to `AsynchronousCopyResource`: * Pass a pre-reserved `std::vector<std::pair<int64_t, float>>` instead of an `absl::flat_hash_map<int64_t, float>` to capture the changes to `delays`, because we do not need random access to the map, and a vector is faster to resize than a hash map. * Cache the raw data pointers from `std::vector<float>` to avoid the overhead of bounds and null checking in the hardened `std::vector` implementation. * Replace the simple functions in `time_utils.cc` with inline implementations in `time_utils.h`: since these boil down to adding or subtracting `1`, the resulting code will be smaller and more efficient (and less likely to spill FP registers to the stack). * Refactor the inner-loop that writes `delay_changes` so that the floating-point operations are not separated by a data-dependent call, and we can keep more `float`s in registers. PiperOrigin-RevId: 719112419
1 parent aa26c7d commit 4113e8d

File tree

6 files changed

+65
-65
lines changed

6 files changed

+65
-65
lines changed

third_party/xla/xla/service/BUILD

-1
Original file line numberDiff line numberDiff line change
@@ -6335,7 +6335,6 @@ cc_library(
63356335

63366336
cc_library(
63376337
name = "time_utils",
6338-
srcs = ["time_utils.cc"],
63396338
hdrs = ["time_utils.h"],
63406339
deps = [],
63416340
)

third_party/xla/xla/service/memory_space_assignment/BUILD

+2-3
Original file line numberDiff line numberDiff line change
@@ -588,18 +588,17 @@ cc_library(
588588
"@com_google_absl//absl/container:btree",
589589
"@com_google_absl//absl/container:flat_hash_map",
590590
"@com_google_absl//absl/container:flat_hash_set",
591+
"@com_google_absl//absl/container:inlined_vector",
591592
"@com_google_absl//absl/functional:any_invocable",
592593
"@com_google_absl//absl/log",
593594
"@com_google_absl//absl/log:check",
595+
"@com_google_absl//absl/memory",
594596
"@com_google_absl//absl/status",
595597
"@com_google_absl//absl/status:statusor",
596598
"@com_google_absl//absl/strings",
597599
"@com_google_absl//absl/strings:str_format",
598600
"@com_google_absl//absl/strings:string_view",
599601
"@com_google_absl//absl/types:span",
600-
"@local_tsl//tsl/platform:logging",
601-
"@local_tsl//tsl/platform:status",
602-
"@local_tsl//tsl/platform:statusor",
603602
],
604603
)
605604

third_party/xla/xla/service/memory_space_assignment/algorithm.cc

+40-16
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ limitations under the License.
3636
#include "absl/algorithm/container.h"
3737
#include "absl/container/flat_hash_map.h"
3838
#include "absl/container/flat_hash_set.h"
39+
#include "absl/container/inlined_vector.h"
3940
#include "absl/functional/any_invocable.h"
4041
#include "absl/log/check.h"
4142
#include "absl/log/log.h"
@@ -3129,8 +3130,16 @@ bool AsynchronousCopyOrdering::ViolatesOrdering(int64_t exclusive_start_time,
31293130

31303131
bool AsynchronousCopyResource::ConsumeResource(
31313132
int64_t exclusive_start_time, int64_t end_time, float resource,
3132-
absl::flat_hash_map<int64_t, float>* delay_change_map,
3133+
std::vector<std::pair<int64_t, float>>* delay_changes,
31333134
float resource_to_free) {
3135+
// Cache the pointers to the arrays to avoid the overhead of `operator[]`
3136+
// size checks in hardened libc++.
3137+
//
3138+
// NOTE: Do not modify the vectors `initial_resources_` or `delay_` in this
3139+
// function, otherwise the pointers will become dangling.
3140+
float* initial_resources_ptr = initial_resources_.data();
3141+
float* delay_ptr = delay_.data();
3142+
31343143
std::list<AsynchronousCopy>::iterator current_copy = async_copies_.end();
31353144
// In order to propagate the resource to the next scheduled copy, we iterate
31363145
// over the copies in start time order until we either find enough free
@@ -3160,7 +3169,8 @@ bool AsynchronousCopyResource::ConsumeResource(
31603169
// this copy would have to be delayed because of an earlier copy that wasn't
31613170
// finished when this copy starts.
31623171
if (current_copy == async_copies_.end()) {
3163-
resource += delay_[ExclusiveToInclusiveStartTime(exclusive_start_time)];
3172+
resource +=
3173+
delay_ptr[ExclusiveToInclusiveStartTime(exclusive_start_time)];
31643174
}
31653175

31663176
// Find the copy that is right after this one. If there are leftover
@@ -3186,7 +3196,7 @@ bool AsynchronousCopyResource::ConsumeResource(
31863196
time < end_time && resource != 0; ++time) {
31873197
// Iterate over the logical times that this copy spans. Note that the
31883198
// start and end time ranges are exclusive.
3189-
float used_resource = std::min(resource, initial_resources_[time]);
3199+
float used_resource = std::min(resource, initial_resources_ptr[time]);
31903200
if (next_copy != async_copies_.end() &&
31913201
next_copy->exclusive_start_time ==
31923202
InclusiveToExclusiveStartTime(time)) {
@@ -3199,15 +3209,17 @@ bool AsynchronousCopyResource::ConsumeResource(
31993209
if (!delay_for_next_copy.has_value()) {
32003210
// Update the delay_ vector and resource_freed variable with the amount
32013211
// that was freed when removing the copy.
3212+
float old_delay = delay_ptr[time];
32023213
float old_resource =
3203-
std::max(0.0f, initial_resources_[time] - delay_[time]);
3204-
if (delay_change_map) {
3205-
delay_change_map->emplace(time, delay_[time]);
3206-
}
3207-
delay_[time] = std::max(0.0f, resource - resource_to_free);
3214+
std::max(0.0f, initial_resources_ptr[time] - old_delay);
3215+
float new_delay = std::max(0.0f, resource - resource_to_free);
32083216
float new_resource =
3209-
std::max(0.0f, initial_resources_[time] - delay_[time]);
3217+
std::max(0.0f, initial_resources_ptr[time] - new_delay);
32103218
resource_freed += std::max(0.0f, new_resource - old_resource);
3219+
delay_ptr[time] = new_delay;
3220+
if (delay_changes) {
3221+
delay_changes->emplace_back(time, old_delay);
3222+
}
32113223
}
32123224
// Update the resource with the used amount in this logical time.
32133225
resource -= used_resource;
@@ -3303,7 +3315,7 @@ void AsynchronousCopyResource::RemoveCopy(
33033315
copy_it->exclusive_start_time);
33043316
CHECK(ConsumeResource(copy_it->exclusive_start_time, copy_it->end_time,
33053317
/*resource=*/0,
3306-
/*delay_change_map=*/nullptr,
3318+
/*delay_changes=*/nullptr,
33073319
/*resource_to_free=*/copy_it->resource));
33083320
// If the copy to be removed is the value pointed by async_copy_time_map_, we
33093321
// make the next copy with the same start time to be pointed by
@@ -3325,24 +3337,36 @@ void AsynchronousCopyResource::RemoveCopy(
33253337
bool AsynchronousCopyResource::HasEnoughResource(int64_t exclusive_start_time,
33263338
int64_t end_time,
33273339
float resource) {
3328-
absl::flat_hash_map<int64_t, float> delay_changes;
3340+
std::vector<std::pair<int64_t, float>> delay_changes;
3341+
delay_changes.reserve(delay_.size());
33293342
bool result =
33303343
ConsumeResource(exclusive_start_time, end_time, resource, &delay_changes);
3331-
for (const auto& change_pair : delay_changes) {
3332-
delay_[change_pair.first] = change_pair.second;
3344+
// Apply the delay changes in reverse order. This ensures that the original
3345+
// value of each delay is restored.
3346+
if (!delay_changes.empty()) {
3347+
for (int64_t i = delay_changes.size() - 1; i >= 0; --i) {
3348+
const auto& [time, delay] = delay_changes[i];
3349+
delay_[time] = delay;
3350+
}
33333351
}
33343352
return result;
33353353
}
33363354

33373355
bool AsynchronousCopyResource::HasEnoughResourceMultiCheck(
33383356
const std::vector<ResourceSpec>& specs) {
3339-
absl::flat_hash_map<int64_t, float> delay_changes;
3357+
std::vector<std::pair<int64_t, float>> delay_changes;
3358+
delay_changes.reserve(delay_.size());
33403359
bool result = absl::c_all_of(specs, [&](const ResourceSpec& spec) {
33413360
return ConsumeResource(spec.exclusive_start_time, spec.end_time,
33423361
spec.resource, &delay_changes);
33433362
});
3344-
for (const auto& change_pair : delay_changes) {
3345-
delay_[change_pair.first] = change_pair.second;
3363+
// Apply the delay changes in reverse order. This ensures that the original
3364+
// value of each delay is restored.
3365+
if (!delay_changes.empty()) {
3366+
for (int64_t i = delay_changes.size() - 1; i >= 0; --i) {
3367+
const auto& [time, delay] = delay_changes[i];
3368+
delay_[time] = delay;
3369+
}
33463370
}
33473371
return result;
33483372
}

third_party/xla/xla/service/memory_space_assignment/algorithm.h

+7-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALGORITHM_H_
1818

1919
#include <algorithm>
20+
#include <cstddef>
2021
#include <cstdint>
2122
#include <list>
2223
#include <map>
@@ -35,6 +36,8 @@ limitations under the License.
3536
#endif
3637
#include "absl/container/flat_hash_map.h"
3738
#include "absl/container/flat_hash_set.h"
39+
#include "absl/container/inlined_vector.h"
40+
#include "absl/memory/memory.h"
3841
#include "absl/status/status.h"
3942
#include "absl/status/statusor.h"
4043
#include "absl/strings/string_view.h"
@@ -218,12 +221,13 @@ class AsynchronousCopyResource {
218221

219222
private:
220223
// Internal helper method to implement adding/removing/checking resources.
221-
// ConsumeResource() may modify delay_. If delay_change_map is not null,
224+
// ConsumeResource() may modify delay_. If delay_changes is not null,
222225
// for any change to delay_[i], {i, delay_[i]} will be added to
223-
// delay_change_map, allowing callers to undo any modifications.
226+
// delay_changes, allowing callers to undo any modifications by iterating over
227+
// the vector in reverse order.
224228
bool ConsumeResource(
225229
int64_t exclusive_start_time, int64_t end_time, float resource,
226-
absl::flat_hash_map<int64_t, float>* delay_change_map = nullptr,
230+
std::vector<std::pair<int64_t, float>>* delay_changes = nullptr,
227231
float resource_to_free = 0.0);
228232

229233
// Same as the public RemoveCopy except it works on the async_copies_

third_party/xla/xla/service/time_utils.cc

-38
This file was deleted.

third_party/xla/xla/service/time_utils.h

+16-4
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,22 @@ limitations under the License.
2121
namespace xla {
2222

2323
// Convert between inclusive/exclusive start/end times.
24-
int64_t ExclusiveToInclusiveStartTime(int64_t exclusive_time);
25-
int64_t InclusiveToExclusiveStartTime(int64_t inclusive_time);
26-
int64_t ExclusiveToInclusiveEndTime(int64_t exclusive_time);
27-
int64_t InclusiveToExclusiveEndTime(int64_t inclusive_time);
24+
25+
inline int64_t ExclusiveToInclusiveStartTime(int64_t exclusive_time) {
26+
return exclusive_time + 1;
27+
}
28+
29+
inline int64_t InclusiveToExclusiveStartTime(int64_t inclusive_time) {
30+
return inclusive_time - 1;
31+
}
32+
33+
inline int64_t ExclusiveToInclusiveEndTime(int64_t exclusive_time) {
34+
return exclusive_time - 1;
35+
}
36+
37+
inline int64_t InclusiveToExclusiveEndTime(int64_t inclusive_time) {
38+
return inclusive_time + 1;
39+
}
2840

2941
} // namespace xla
3042

0 commit comments

Comments
 (0)