@@ -1028,5 +1028,55 @@ ENTRY %main (Arg_0: f32[10,10], Arg_1: f32[10,10]) -> f32[10,10] {
1028
1028
EXPECT_FALSE (changed);
1029
1029
}
1030
1030
1031
// HLO module text shared by the two scatter tests that follow. The entry
// computation's root is a single scatter whose combiner (to_apply) is
// %scatter_max: an f32 maximum of its two scalar parameters followed by a
// convert to bf16 and a convert back to f32.
+ static constexpr absl::string_view kScatterModuleString = R"(
1032
+ HloModule module
1033
+
1034
+ %scatter_max (param0: f32[], param1: f32[]) -> f32[] {
1035
+ %lhs = f32[] parameter(0)
1036
+ %rhs = f32[] parameter(1)
1037
+ %maximum.1 = f32[] maximum(f32[] lhs, f32[] rhs)
1038
+ %convert.8 = bf16[] convert(f32[] maximum.1)
1039
+ ROOT %convert.9 = f32[] convert(bf16[] convert.8)
1040
+ }
1041
+
1042
+ ENTRY %main (arg0: f32[13,5,10,62], arg1: s32[3,1], arg2: f32[3,1,5,10,62])
1043
+ -> f32[13,5,10,62] {
1044
+ %arg0 = f32[13,5,10,62]{3,2,1,0} parameter(0)
1045
+ %arg1 = s32[3,1]{1,0} parameter(1)
1046
+ %arg2 = f32[3,1,5,10,62]{4,3,2,1,0} parameter(2)
1047
+ ROOT %scatter.2 = f32[13,5,10,62]{3,2,1,0} scatter(
1048
+ f32[13,5,10,62]{3,2,1,0} %arg0,
1049
+ s32[3,1]{1,0} %arg1,
1050
+ f32[3,1,5,10,62]{4,3,2,1,0} %arg2),
1051
+ update_window_dims={1,2,3,4}, inserted_window_dims={},
1052
+ scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=scatter_max
1053
+ }
1054
+ )" ;
1055
+
1056
// Runs CpuInstructionFusion on kScatterModuleString with both the thunk
// runtime and the fusion emitters debug options enabled, and expects the pass
// to report that it changed nothing.
// NOTE(review): per the test name, the intent is that the scatter combiner
// computation is left alone in this configuration -- confirm against the pass.
+ TEST_F (InstructionFusionTest, SkipScatterComputationsIfFusionEmitters) {
1057
+ auto mod_config = GetModuleConfigForTest ();
1058
+ auto debug_options = GetDebugOptionsForTest ();
1059
+ debug_options.set_xla_cpu_use_thunk_runtime (true );
1060
+ debug_options.set_xla_cpu_use_fusion_emitters (true );
1061
+ mod_config.set_debug_options (debug_options);  // Install the modified options into the module config.
1062
+ TF_ASSERT_OK_AND_ASSIGN (auto module, ParseAndReturnVerifiedModule (
1063
+ kScatterModuleString , mod_config));
1064
+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
1065
+ CpuInstructionFusion ().Run (module.get ()));
1066
+ EXPECT_FALSE (changed);  // The pass must leave the module untouched.
1067
+ }
1068
+
1069
// Counterpart of the fusion-emitters test: runs CpuInstructionFusion on the
// same kScatterModuleString with the fusion emitters debug option disabled,
// and expects the pass to report a change.
// Unlike its sibling, this test does not touch xla_cpu_use_thunk_runtime; that
// option keeps whatever value GetDebugOptionsForTest() returns.
+ TEST_F (InstructionFusionTest, NoSkipScatterComputationsIfNoFusionEmitters) {
1070
+ auto mod_config = GetModuleConfigForTest ();
1071
+ auto debug_options = GetDebugOptionsForTest ();
1072
+ debug_options.set_xla_cpu_use_fusion_emitters (false );
1073
+ mod_config.set_debug_options (debug_options);  // Install the modified options into the module config.
1074
+ TF_ASSERT_OK_AND_ASSIGN (auto module, ParseAndReturnVerifiedModule (
1075
+ kScatterModuleString , mod_config));
1076
+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
1077
+ CpuInstructionFusion ().Run (module.get ()));
1078
+ EXPECT_TRUE (changed);  // The pass is expected to modify the module here.
1079
+ }
1080
+
1031
1081
} // namespace
1032
1082
} // namespace xla::cpu
0 commit comments