@@ -976,5 +976,57 @@ ENTRY main {
976
976
HloOpcode::kConstant , HloOpcode::kAdd , HloOpcode::kAdd });
977
977
}
978
978
979
+ TEST_F (InstructionFusionTest, SkipCustomFusions) {
980
+ absl::string_view module_string = R"(
981
+ HloModule module
982
+
983
+ %fused_computation (param_0: f32[10,10], param_1: f32[10,10]) -> f32[10,10] {
984
+ %param_0 = f32[10,10]{1,0} parameter(0)
985
+ %param_1 = f32[10,10]{1,0} parameter(1)
986
+ %add = f32[10,10]{1,0} add(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
987
+ %subtract = f32[10,10]{1,0} subtract(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
988
+ ROOT %multiply = f32[10,10]{1,0} multiply(f32[10,10]{1,0} %add, f32[10,10]{1,0} %subtract)
989
+ }
990
+
991
+ ENTRY %main (Arg_0: f32[10,10], Arg_1: f32[10,10]) -> f32[10,10] {
992
+ %Arg_0 = f32[10,10]{1,0} parameter(0), metadata={op_name="x"}
993
+ %Arg_1 = f32[10,10]{1,0} parameter(1), metadata={op_name="y"}
994
+ ROOT %subtract_multiply_fusion = f32[10,10]{1,0} fusion(f32[10,10]{1,0} %Arg_0, f32[10,10]{1,0} %Arg_1), kind=kCustom, calls=%fused_computation
995
+ }
996
+ )" ;
997
+
998
+ TF_ASSERT_OK_AND_ASSIGN (auto module,
999
+ ParseAndReturnVerifiedModule (module_string));
1000
+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
1001
+ CpuInstructionFusion ().Run (module.get ()));
1002
+ EXPECT_FALSE (changed);
1003
+ }
1004
+
1005
+ TEST_F (InstructionFusionTest, SkipComputationsAttachedToCustomCalls) {
1006
+ absl::string_view module_string = R"(
1007
+ HloModule module
1008
+
1009
+ %custom_computation (param_0: f32[10,10], param_1: f32[10,10]) -> f32[10,10] {
1010
+ %param_0 = f32[10,10]{1,0} parameter(0)
1011
+ %param_1 = f32[10,10]{1,0} parameter(1)
1012
+ %add = f32[10,10]{1,0} add(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
1013
+ %subtract = f32[10,10]{1,0} subtract(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
1014
+ ROOT %multiply = f32[10,10]{1,0} multiply(f32[10,10]{1,0} %add, f32[10,10]{1,0} %subtract)
1015
+ }
1016
+
1017
+ ENTRY %main (Arg_0: f32[10,10], Arg_1: f32[10,10]) -> f32[10,10] {
1018
+ %Arg_0 = f32[10,10]{1,0} parameter(0), metadata={op_name="x"}
1019
+ %Arg_1 = f32[10,10]{1,0} parameter(1), metadata={op_name="y"}
1020
+ ROOT %custom_call = f32[10,10]{1,0} custom-call(f32[10,10]{1,0} %Arg_0, f32[10,10]{1,0} %Arg_1), custom_call_target="target", called_computations={%custom_computation}
1021
+ }
1022
+ )" ;
1023
+
1024
+ TF_ASSERT_OK_AND_ASSIGN (auto module,
1025
+ ParseAndReturnVerifiedModule (module_string));
1026
+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
1027
+ CpuInstructionFusion ().Run (module.get ()));
1028
+ EXPECT_FALSE (changed);
1029
+ }
1030
+
979
1031
} // namespace
980
1032
} // namespace xla::cpu
0 commit comments