Skip to content

Commit 8edbe17

Browse files
reedwm and copybara-github
authored and committed
Fix broken GemmRewriteTest.BF16GemmCodeGen on Hopper.
This was broken by openxla@4e09e73. That change altered the optimized HLO so that it now contains a native BF16 multiply, which means the FileCheck pattern had to be updated. Because the multiply is now done in BF16 instead of FP32, the result is less precise, so the comparison tolerance must also be loosened (from 1e-5 to 1e-4). PiperOrigin-RevId: 616380852
1 parent 6de79c2 commit 8edbe17

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

xla/service/gpu/tests/gemm_rewrite_test.cc

+25-11
Original file line number | Diff line number | Diff line change
@@ -257,18 +257,32 @@ ENTRY bf16gemm {
257257
}
258258
)";
259259

260-
MatchOptimizedHlo(hlo_text, R"(
261-
; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
262-
; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
263-
; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
264-
; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
265-
; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
266-
; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0)
267-
; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
268-
; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
269-
)");
260+
if (CudaOrRocmCheck(9, 0, Switch::False)) {
261+
// The Hopper optimized HLO has a BF16 multiply instruction since Hopper has
262+
// native BF16 multiply support.
263+
MatchOptimizedHlo(hlo_text, R"(
264+
; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
265+
; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
266+
; CHECK: [[INSTR_2:%[^ ]+]] = bf16[3]{0} multiply([[P0]], [[P1]])
267+
; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[INSTR_2]])
268+
; CHECK: [[INSTR_4:%[^ ]+]] = f32[] constant(0)
269+
; CHECK: [[INSTR_5:%[^ ]+]] = f32[] reduce([[INSTR_3]], [[INSTR_4]]), dimensions={0}, to_apply=[[INSTR_6:%[^ ]+]]
270+
; CHECK: ROOT [[INSTR_7:%[^ ]+]] = bf16[] convert([[INSTR_5]])
271+
)");
272+
} else {
273+
MatchOptimizedHlo(hlo_text, R"(
274+
; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1)
275+
; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]])
276+
; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0)
277+
; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]])
278+
; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]])
279+
; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0)
280+
; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]]
281+
; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]])
282+
)");
283+
}
270284

271-
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
285+
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-4}));
272286
}
273287

274288
TEST_F(GemmRewriteTest, BF16Transpose) {

0 commit comments

Comments
 (0)