Skip to content

Commit 892ba95

Browse files
Do not fuse instructions inside custom fusions/calls
PiperOrigin-RevId: 721595219
1 parent a48d834 commit 892ba95

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc

+33
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,43 @@ bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
7474
(CanBeOutputFused(consumer->operand(0), consumer) ||
7575
CanBeOutputFused(consumer->operand(1), consumer));
7676
}
77+
7778
} // namespace
7879

80+
void CpuInstructionFusion::ComputeInstructionsToSkip(
81+
HloModule* module,
82+
const absl::flat_hash_set<absl::string_view>& execution_threads) {
83+
const auto computations_list =
84+
module->MakeComputationPostOrder(execution_threads);
85+
instructions_to_skip_.clear();
86+
for (auto* computation : computations_list) {
87+
for (auto* instruction : computation->MakeInstructionPostOrder()) {
88+
if (instruction->IsCustomFusion() ||
89+
instruction->opcode() == HloOpcode::kCustomCall) {
90+
HloCallableInstruction* callable =
91+
Cast<HloCallableInstruction>(instruction);
92+
if (callable->called_computations().empty()) {
93+
continue;
94+
}
95+
for (HloInstruction* instr :
96+
callable->called_computation()->instructions())
97+
instructions_to_skip_.insert(instr);
98+
}
99+
}
100+
}
101+
}
102+
103+
bool CpuInstructionFusion::ShouldSkip(const HloInstruction* inst) const {
104+
return instructions_to_skip_.contains(inst);
105+
}
106+
79107
FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
80108
int64_t operand_index) {
109+
if (ShouldSkip(consumer)) {
110+
return FusionDecision::Forbid(
111+
"Don't fuse instructions from custom fusions/calls");
112+
}
113+
81114
HloInstruction* producer = consumer->mutable_operand(operand_index);
82115
VLOG(2) << "Considering for fusion: operand " << operand_index << " of "
83116
<< consumer->ToString();

third_party/xla/xla/service/cpu/cpu_instruction_fusion.h

+8
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class CpuInstructionFusion : public InstructionFusion {
4040
const absl::flat_hash_set<absl::string_view>&
4141
execution_threads) override {
4242
fusion_node_evaluations_.clear();
43+
ComputeInstructionsToSkip(module, execution_threads);
4344
return InstructionFusion::Run(module, execution_threads);
4445
}
4546

@@ -62,10 +63,17 @@ class CpuInstructionFusion : public InstructionFusion {
6263
// Returns if a constant is large enough to be considered a large constant.
6364
bool IsLargeConstant(const HloInstruction* constant) const;
6465

66+
bool ShouldSkip(const HloInstruction* inst) const;
67+
void ComputeInstructionsToSkip(
68+
HloModule* module,
69+
const absl::flat_hash_set<absl::string_view>& execution_threads);
70+
6571
// Keep track of the number of times each instruction inside a fusion node is
6672
// indexed with different index vectors.
6773
absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>
6874
fusion_node_evaluations_;
75+
76+
absl::flat_hash_set<const HloInstruction*> instructions_to_skip_;
6977
};
7078

7179
} // namespace cpu

third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc

+52
Original file line numberDiff line numberDiff line change
@@ -976,5 +976,57 @@ ENTRY main {
976976
HloOpcode::kConstant, HloOpcode::kAdd, HloOpcode::kAdd});
977977
}
978978

979+
TEST_F(InstructionFusionTest, SkipCustomFusions) {
980+
absl::string_view module_string = R"(
981+
HloModule module
982+
983+
%fused_computation (param_0: f32[10,10], param_1: f32[10,10]) -> f32[10,10] {
984+
%param_0 = f32[10,10]{1,0} parameter(0)
985+
%param_1 = f32[10,10]{1,0} parameter(1)
986+
%add = f32[10,10]{1,0} add(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
987+
%subtract = f32[10,10]{1,0} subtract(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
988+
ROOT %multiply = f32[10,10]{1,0} multiply(f32[10,10]{1,0} %add, f32[10,10]{1,0} %subtract)
989+
}
990+
991+
ENTRY %main (Arg_0: f32[10,10], Arg_1: f32[10,10]) -> f32[10,10] {
992+
%Arg_0 = f32[10,10]{1,0} parameter(0), metadata={op_name="x"}
993+
%Arg_1 = f32[10,10]{1,0} parameter(1), metadata={op_name="y"}
994+
ROOT %subtract_multiply_fusion = f32[10,10]{1,0} fusion(f32[10,10]{1,0} %Arg_0, f32[10,10]{1,0} %Arg_1), kind=kCustom, calls=%fused_computation
995+
}
996+
)";
997+
998+
TF_ASSERT_OK_AND_ASSIGN(auto module,
999+
ParseAndReturnVerifiedModule(module_string));
1000+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
1001+
CpuInstructionFusion().Run(module.get()));
1002+
EXPECT_FALSE(changed);
1003+
}
1004+
1005+
TEST_F(InstructionFusionTest, SkipComputationsAttachedToCustomCalls) {
1006+
absl::string_view module_string = R"(
1007+
HloModule module
1008+
1009+
%custom_computation (param_0: f32[10,10], param_1: f32[10,10]) -> f32[10,10] {
1010+
%param_0 = f32[10,10]{1,0} parameter(0)
1011+
%param_1 = f32[10,10]{1,0} parameter(1)
1012+
%add = f32[10,10]{1,0} add(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
1013+
%subtract = f32[10,10]{1,0} subtract(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
1014+
ROOT %multiply = f32[10,10]{1,0} multiply(f32[10,10]{1,0} %add, f32[10,10]{1,0} %subtract)
1015+
}
1016+
1017+
ENTRY %main (Arg_0: f32[10,10], Arg_1: f32[10,10]) -> f32[10,10] {
1018+
%Arg_0 = f32[10,10]{1,0} parameter(0), metadata={op_name="x"}
1019+
%Arg_1 = f32[10,10]{1,0} parameter(1), metadata={op_name="y"}
1020+
ROOT %custom_call = f32[10,10]{1,0} custom-call(f32[10,10]{1,0} %Arg_0, f32[10,10]{1,0} %Arg_1), custom_call_target="target", called_computations={%custom_computation}
1021+
}
1022+
)";
1023+
1024+
TF_ASSERT_OK_AND_ASSIGN(auto module,
1025+
ParseAndReturnVerifiedModule(module_string));
1026+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
1027+
CpuInstructionFusion().Run(module.get()));
1028+
EXPECT_FALSE(changed);
1029+
}
1030+
9791031
} // namespace
9801032
} // namespace xla::cpu

0 commit comments

Comments
 (0)