Skip to content

Commit 7c82c46

Browse files
cota authored and tensorflower-gardener committed
[xla:cpu] fusion emitters: do not fuse inside computations called from scatter
The scatter fusion emitter won't be able to handle fusions inside computations called from the scatter instruction. Pave the way for the scatter fusion emitter by forbidding those fusions. PiperOrigin-RevId: 735969859
1 parent df24a90 commit 7c82c46

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc

+13
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ void CpuInstructionFusion::ComputeInstructionsToSkip(
8484
const auto computations_list =
8585
module->MakeComputationPostOrder(execution_threads);
8686
instructions_to_skip_.clear();
87+
const bool is_fusion_emitters =
88+
module->config().debug_options().xla_cpu_use_thunk_runtime() &&
89+
module->config().debug_options().xla_cpu_use_fusion_emitters();
8790
for (auto* computation : computations_list) {
8891
for (auto* instruction : computation->MakeInstructionPostOrder()) {
8992
if (instruction->IsCustomFusion() ||
@@ -96,6 +99,16 @@ void CpuInstructionFusion::ComputeInstructionsToSkip(
9699
for (HloInstruction* instr :
97100
callable->called_computation()->instructions())
98101
instructions_to_skip_.insert(instr);
102+
} else if (is_fusion_emitters &&
103+
instruction->opcode() == HloOpcode::kScatter) {
104+
// Disallow fusions in the called computation (e.g. reduction)
105+
// of a scatter "fusion"; the fusion emitter can't handle them.
106+
auto* scatter = Cast<HloScatterInstruction>(instruction);
107+
for (const auto* computation : scatter->called_computations()) {
108+
for (const auto* instr : computation->instructions()) {
109+
instructions_to_skip_.insert(instr);
110+
}
111+
}
99112
}
100113
}
101114
}

third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc

+50
Original file line numberDiff line numberDiff line change
@@ -1028,5 +1028,55 @@ ENTRY %main (Arg_0: f32[10,10], Arg_1: f32[10,10]) -> f32[10,10] {
10281028
EXPECT_FALSE(changed);
10291029
}
10301030

1031+
// HLO text fixture used by the scatter tests in this file. The entry
// computation consists of three parameters feeding a single scatter; the
// scatter's to_apply computation (scatter_max) takes the maximum of its two
// f32 operands and converts the result to bf16 and back to f32.
static constexpr absl::string_view kScatterModuleString = R"(
1032+
HloModule module
1033+
1034+
%scatter_max (param0: f32[], param1: f32[]) -> f32[] {
1035+
%lhs = f32[] parameter(0)
1036+
%rhs = f32[] parameter(1)
1037+
%maximum.1 = f32[] maximum(f32[] lhs, f32[] rhs)
1038+
%convert.8 = bf16[] convert(f32[] maximum.1)
1039+
ROOT %convert.9 = f32[] convert(bf16[] convert.8)
1040+
}
1041+
1042+
ENTRY %main (arg0: f32[13,5,10,62], arg1: s32[3,1], arg2: f32[3,1,5,10,62])
1043+
-> f32[13,5,10,62] {
1044+
%arg0 = f32[13,5,10,62]{3,2,1,0} parameter(0)
1045+
%arg1 = s32[3,1]{1,0} parameter(1)
1046+
%arg2 = f32[3,1,5,10,62]{4,3,2,1,0} parameter(2)
1047+
ROOT %scatter.2 = f32[13,5,10,62]{3,2,1,0} scatter(
1048+
f32[13,5,10,62]{3,2,1,0} %arg0,
1049+
s32[3,1]{1,0} %arg1,
1050+
f32[3,1,5,10,62]{4,3,2,1,0} %arg2),
1051+
update_window_dims={1,2,3,4}, inserted_window_dims={},
1052+
scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=scatter_max
1053+
}
1054+
)";
1055+
1056+
// Enables both the thunk runtime and the fusion emitters, runs
// CpuInstructionFusion on the scatter fixture, and expects the pass to report
// that it changed nothing.
TEST_F(InstructionFusionTest, SkipScatterComputationsIfFusionEmitters) {
1057+
auto mod_config = GetModuleConfigForTest();
1058+
auto debug_options = GetDebugOptionsForTest();
1059+
// Both flags are set explicitly; the pass only applies the scatter skip when
// the two are enabled together.
debug_options.set_xla_cpu_use_thunk_runtime(true);
1060+
debug_options.set_xla_cpu_use_fusion_emitters(true);
1061+
mod_config.set_debug_options(debug_options);
1062+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
1063+
kScatterModuleString, mod_config));
1064+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
1065+
CpuInstructionFusion().Run(module.get()));
1066+
// No fusion is expected: instructions inside the computation called by the
// scatter are placed on the pass's skip list in this configuration.
EXPECT_FALSE(changed);
1067+
}
1068+
1069+
// Counterpart to the fusion-emitters-enabled test: explicitly disables the
// fusion emitters, runs CpuInstructionFusion on the same scatter fixture, and
// expects the pass to report a change.
TEST_F(InstructionFusionTest, NoSkipScatterComputationsIfNoFusionEmitters) {
1070+
auto mod_config = GetModuleConfigForTest();
1071+
auto debug_options = GetDebugOptionsForTest();
1072+
// Only the fusion-emitters flag is overridden here; the thunk-runtime flag
// keeps whatever GetDebugOptionsForTest() returned.
debug_options.set_xla_cpu_use_fusion_emitters(false);
1073+
mod_config.set_debug_options(debug_options);
1074+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
1075+
kScatterModuleString, mod_config));
1076+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
1077+
CpuInstructionFusion().Run(module.get()));
1078+
// NOTE(review): presumably the change comes from fusing instructions inside
// scatter_max (e.g. the maximum/convert chain) — confirm against the pass.
EXPECT_TRUE(changed);
1079+
}
1080+
10311081
} // namespace
10321082
} // namespace xla::cpu

0 commit comments

Comments
 (0)