From a65cd9eb8742c96cd736b62b24c8760a5e687ed7 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 28 Jan 2025 11:11:29 -0800 Subject: [PATCH] low pow(x,1) to x --- csrc/codegen.cpp | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 98bd5a1624f..0ec40ef5342 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -990,32 +990,39 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { return false; } - // Only **2 and **3 are considered - if (!(exponent == 2 || exponent == 3)) { + // Only **1, **2 and **3 are considered + if (!(exponent == 1 || exponent == 2 || exponent == 3)) { return false; } auto lhs = gen(bop->lhs()); if (print_inline_) { - code_ << lhs << " * " << lhs; - if (exponent == 3) { - code_ << " * " << lhs; + if (exponent == 1) { + code_ << lhs; + } else if (exponent == 2) { + code_ << lhs << " * " << lhs; + } else if (exponent == 3) { + code_ << lhs << " * " << lhs << " * " << lhs; } } else { indent() << gen(bop->out()); if (bop->out()->isScalar()) { - code_ << " = " << lhs << " * " << lhs; - if (exponent == 3) { - code_ << " * " << lhs; + if (exponent == 1) { + code_ << " = " << lhs; + } else if (exponent == 2) { + code_ << " = " << lhs << " * " << lhs; + } else if (exponent == 3) { + code_ << " = " << lhs << " * " << lhs << " * " << lhs; } } else { code_ << "\n"; - indent() << kTab << "= " << lhs << "\n"; - indent() << kTab << "* " << lhs; - if (exponent == 3) { - code_ << "\n"; - indent() << kTab << "* " << lhs; + if (exponent == 1) { + indent() << kTab << "= " << lhs; + } else if (exponent == 2) { + indent() << kTab << "= " << lhs << "\n * " << lhs; + } else if (exponent == 3) { + indent() << kTab << "= " << lhs << "\n * " << lhs << "\n * " << lhs; } } }