Skip to content

Commit

Permalink
[XLA] Fix pathological exponential compilation time in algebraic simp…
Browse files Browse the repository at this point in the history
…lifier.

PiperOrigin-RevId: 687397035
  • Loading branch information
berkinilbeyi authored and Google-ML-Automation committed Oct 18, 2024
1 parent ded5e51 commit 7e3b009
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
10 changes: 10 additions & 0 deletions xla/hlo/transforms/simplifiers/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4332,6 +4332,11 @@ absl::Status AlgebraicSimplifierVisitor::HandleMaximum(
HloInstruction *lhs, *rhs;
CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs))));

// max(x, x) -> x
if (lhs == rhs) {
return ReplaceInstruction(maximum, lhs);
}

// max(x, -inf) -> x
PrimitiveType ty = maximum->shape().element_type();
if (primitive_util::IsIntegralType(ty) ||
Expand Down Expand Up @@ -4432,6 +4437,11 @@ absl::Status AlgebraicSimplifierVisitor::HandleMinimum(
HloInstruction *lhs, *rhs;
CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs))));

// min(x, x) -> x
if (lhs == rhs) {
return ReplaceInstruction(minimum, lhs);
}

// min(x, inf) -> x
PrimitiveType ty = minimum->shape().element_type();
if (primitive_util::IsIntegralType(ty) ||
Expand Down
78 changes: 78 additions & 0 deletions xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12342,5 +12342,83 @@ TEST_F(AlgebraicSimplifierTest, BitcastBroadcastDifferentLayout) {
EXPECT_FALSE(simplifier.Run(module.get()).value());
}

TEST_F(AlgebraicSimplifierTest, TrivialMin) {
const char* kModuleStr = R"(
HloModule m
test {
a = f32[4,4] parameter(0)
ROOT %min = f32[4,4] minimum(%a, %a)
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}

TEST_F(AlgebraicSimplifierTest, TrivialMax) {
const char* kModuleStr = R"(
HloModule m
test {
a = f32[4,4] parameter(0)
ROOT %min = f32[4,4] maximum(%a, %a)
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}

TEST_F(AlgebraicSimplifierTest, PathologicalComplexity) {
// Without replacing min(x,x)->x, the algorithmic recursion complexity is
// O(2^n).
const char* kModuleStr = R"(
HloModule m
test {
a = s32[4,4] parameter(0)
b = s32[4,4] parameter(1)
%cmp0 = pred[4,4] compare(a, b), direction=GE
%c1 = f32[] constant(1)
%ones = f32[4,4] broadcast(f32[] %c1)
%c0 = f32[] constant(0)
%zeros = f32[4,4] broadcast(f32[] %c0)
%min = f32[4,4] minimum(%ones, %zeros)
%min0 = f32[4,4] minimum(%min, %min)
%min1 = f32[4,4] minimum(%min0, %min0)
%min2 = f32[4,4] minimum(%min1, %min1)
%min3 = f32[4,4] minimum(%min2, %min2)
%min4 = f32[4,4] minimum(%min3, %min3)
%min5 = f32[4,4] minimum(%min4, %min4)
%min6 = f32[4,4] minimum(%min5, %min5)
%min7 = f32[4,4] minimum(%min6, %min6)
%min8 = f32[4,4] minimum(%min7, %min7)
%min9 = f32[4,4] minimum(%min8, %min8)
%min10 = f32[4,4] minimum(%min9, %min9)
%min11 = f32[4,4] minimum(%min10, %min10)
%min12 = f32[4,4] minimum(%min11, %min11)
%min13 = f32[4,4] minimum(%min12, %min12)
%min14 = f32[4,4] minimum(%min13, %min13)
%min15 = f32[4,4] minimum(%min14, %min14)
%min16 = f32[4,4] minimum(%min15, %min15)
%min17 = f32[4,4] minimum(%min16, %min16)
%min18 = f32[4,4] minimum(%min17, %min17)
%min19 = f32[4,4] minimum(%min18, %min18)
%min20 = f32[4,4] minimum(%min19, %min19)
%min21 = f32[4,4] minimum(%min20, %min20)
%min22 = f32[4,4] minimum(%min21, %min21)
%min23 = f32[4,4] minimum(%min22, %min22)
%min24 = f32[4,4] minimum(%min23, %min23)
%min25 = f32[4,4] minimum(%min24, %min24)
%min26 = f32[4,4] minimum(%min25, %min25)
%min27 = f32[4,4] minimum(%min26, %min26)
%min28 = f32[4,4] minimum(%min27, %min27)
%min29 = f32[4,4] minimum(%min28, %min28)
ROOT %cmp1 = pred[4,4] compare(%min29, %zeros), direction=LT
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::Constant())));
}

} // namespace
} // namespace xla

0 comments on commit 7e3b009

Please sign in to comment.