Skip to content

Commit 7156818

Browse files
mehrdadkhaniGoogle-ML-Automation
authored andcommitted
[XLA:TPU] Fixes a bug in synchronous memory call ops conversion to async in memory space assignment. The bug showed up when the consumer of a replaced sync instruction is consumed multiple times by the same instruction.
The issue was that the algorithm was not using the correct allocation sequence when dealing with such repeated consumer instructions. This CL fixes the bug by using the correct allocation sequence. PiperOrigin-RevId: 691265175
1 parent e0c2421 commit 7156818

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

xla/service/memory_space_assignment/algorithm.cc

+3-2
Original file line numberDiff line numberDiff line change
@@ -4335,8 +4335,9 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
43354335
// consumed multiple times by the same instruction. We can just find the
43364336
// previous allocation and use that allocation.
43374337
if (request.inclusive_start_time == request.end_time) {
4338-
Allocation* allocation =
4339-
GetLiveAllocationAt(*allocation_sequence, request.end_time);
4338+
Allocation* allocation = GetLiveAllocationAt(
4339+
*request.allocation_value_to_update->mutable_allocation_sequence(),
4340+
request.end_time);
43404341
CHECK_NE(allocation, nullptr);
43414342
allocation->AddUse(request.use->hlo_use);
43424343
return Result::kSuccess;

xla/service/memory_space_assignment/memory_space_assignment_test.cc

+44
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,50 @@ ENTRY %entry (p0.2: f32[10,2,3], p1: f32[10,2,3], p2: pred[]) -> f32[10,2,3] {
11711171
tuple->operand(1)));
11721172
}
11731173

1174+
// Added for b/376344953 that was introduced when we tried to
1175+
// convert a sync copy that was used by a conditional into an async copy.
1176+
TEST_F(MemorySpaceAssignmentTest, ConditionalCopyReplacement) {
1177+
absl::string_view hlo_string = R"(
1178+
HloModule CondAllocation, is_scheduled=true
1179+
1180+
true_computation {
1181+
p0 = (f32[3]{0}) parameter(0)
1182+
gte = f32[3]{0} get-tuple-element(p0), index=0
1183+
ROOT neg1 = f32[3]{0} negate(gte)
1184+
}
1185+
1186+
false_computation {
1187+
p0 = (f32[3]{0}) parameter(0)
1188+
gte = f32[3]{0} get-tuple-element(p0), index=0
1189+
ROOT neg2 = f32[3]{0} negate(gte)
1190+
}
1191+
1192+
ENTRY entry {
1193+
p0_main = f32[3]{0} parameter(0)
1194+
p1 = pred[] parameter(1)
1195+
copy = f32[3]{0} copy(p0_main)
1196+
tuple = (f32[3]{0}) tuple(copy)
1197+
ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
1198+
}
1199+
)";
1200+
TF_ASSERT_OK_AND_ASSIGN(auto module,
1201+
ParseAndReturnVerifiedModule(hlo_string));
1202+
Options options = DefaultMemorySpaceOptions();
1203+
options.enable_sync_copy_replacement = true;
1204+
AssignMemorySpace(module.get(), options);
1205+
auto conditional =
1206+
module->GetComputationWithName("entry")->GetInstructionWithName(
1207+
"conditional");
1208+
CHECK_NE(conditional, nullptr);
1209+
auto p0 = module->GetComputationWithName("entry")->GetInstructionWithName(
1210+
"p0_main");
1211+
CHECK_NE(p0, nullptr);
1212+
auto copy = conditional->operand(1)->operand(0);
1213+
CHECK_NE(copy, nullptr);
1214+
EXPECT_THAT(copy,
1215+
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0));
1216+
}
1217+
11741218
TEST_F(MemorySpaceAssignmentTest, AlwaysSpillJitPrefetchTest) {
11751219
// The negate chain is long enough for asynchronous copy to be inserted
11761220
// between p1 and add.

0 commit comments

Comments
 (0)