From 3199eb5a7ac0459f42650e4bd2f6ddc59205fd4a Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Sun, 4 Feb 2024 15:17:17 +0800 Subject: [PATCH] =?UTF-8?q?feat(kernel):=20=E6=B7=BB=E5=8A=A0exp=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../include/kernel/collectors/simple_unary.h | 1 + src/04kernel/src/collectors/simple_unary.cc | 1 + .../src/kernels/simple_unary/cpu_kernel.cc | 8 ++++++++ .../src/kernels/simple_unary/cuda_kernel.cc | 19 ++++++++++++++++--- .../test/kernels/simple_unary/test_cpu.cpp | 1 + .../test/kernels/simple_unary/test_cuda.cpp | 1 + .../src/operators/simple_unary.cc | 6 ++++++ src/07onnx/src/operators.cpp | 1 + src/07onnx/src/operators/simple_unary.cc | 9 ++++++++- src/07onnx/src/operators/simple_unary.hh | 1 + src/07onnx/test/test_simple_unary.cpp | 13 +++++++++++++ 11 files changed, 57 insertions(+), 4 deletions(-) diff --git a/src/04kernel/include/kernel/collectors/simple_unary.h b/src/04kernel/include/kernel/collectors/simple_unary.h index 913c0095..f19b0358 100644 --- a/src/04kernel/include/kernel/collectors/simple_unary.h +++ b/src/04kernel/include/kernel/collectors/simple_unary.h @@ -26,6 +26,7 @@ namespace refactor::kernel { Neg, Not, HardSwish, + Exp, }; std::string_view unaryName(SimpleUnaryType type); diff --git a/src/04kernel/src/collectors/simple_unary.cc b/src/04kernel/src/collectors/simple_unary.cc index 51a334c9..378e341a 100644 --- a/src/04kernel/src/collectors/simple_unary.cc +++ b/src/04kernel/src/collectors/simple_unary.cc @@ -29,6 +29,7 @@ namespace refactor::kernel { CASE(Sqrt); CASE(Sigmoid); CASE(Erf); + CASE(Exp); CASE(Neg); CASE(Not); CASE(HardSwish); diff --git a/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc b/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc index 8b6a0d13..fad4777b 100644 --- a/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc +++ b/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc @@ -20,6 +20,7 @@ namespace refactor::kernel { Op::Neg, Op::Erf, Op::HardSwish, + Op::Exp, }; return supportedOp.contains(op) && a.dataType.isCpuNumberic() ? std::make_unique(op, a.dataType, a.elementsSize()) @@ -185,6 +186,13 @@ namespace refactor::kernel { default: UNREACHABLE(); } + case Op::Exp: + switch (dataType) { + CASE(std::exp, F32); + CASE(std::exp, F64); + default: + UNREACHABLE(); + } default: UNREACHABLE(); } diff --git a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc index e883374a..effd4ae5 100644 --- a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc +++ b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc @@ -17,9 +17,17 @@ namespace refactor::kernel { auto K::build(Op op, Tensor const &a) noexcept -> KernelBox { static const std::unordered_set - supportedOp{Op::Abs, Op::Relu, Op::Sqrt, - Op::Sigmoid, Op::Tanh, Op::Neg, - Op::Erf, Op::HardSwish}; + supportedOp{ + Op::Abs, + Op::Relu, + Op::Sqrt, + Op::Sigmoid, + Op::Tanh, + Op::Neg, + Op::Erf, + Op::HardSwish, + Op::Exp, + }; #ifndef USE_CUDA return nullptr; #endif @@ -159,6 +167,11 @@ extern "C" __global__ void kernel( {__(Op::HardSwish, DT::FP16), "x * __hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, hrcp(__float2half(6.f)) * x + hrcp(__float2half(2.f))))"}, {__(Op::HardSwish, DT::F64 ), "x * fmax(0.0, fmin(1.0, fma(1.0/6.0, x, 0.5)))"}, + {__(Op::Exp, DT::F32 ), "expf(x)"}, + {__(Op::Exp, DT::F64 ), "exp(x)"}, + {__(Op::Exp, DT::FP16), "hexp(x)"}, + {__(Op::Exp, DT::BF16), "hexp(x)"}, + }; // clang-format on diff --git a/src/04kernel/test/kernels/simple_unary/test_cpu.cpp b/src/04kernel/test/kernels/simple_unary/test_cpu.cpp index 47249281..e4d3de4f 100644 --- a/src/04kernel/test/kernels/simple_unary/test_cpu.cpp +++ b/src/04kernel/test/kernels/simple_unary/test_cpu.cpp @@ -59,4 +59,5 @@ TEST(kernel, SimpleUnaryCpu) { testOp(SimpleUnaryType::Erf, std::erf); testOpWithData(SimpleUnaryType::HardSwish, VecFloat{0.000000, 0.666667, 1.666667, 3.000000, 4.000000, 5.000000}); + testOpWithData(SimpleUnaryType::Exp, VecFloat{1.000000, 2.718282, 7.389056, 20.085537, 54.598148, 148.413162}); } diff --git a/src/04kernel/test/kernels/simple_unary/test_cuda.cpp b/src/04kernel/test/kernels/simple_unary/test_cuda.cpp index 72ebff72..c604b678 100644 --- a/src/04kernel/test/kernels/simple_unary/test_cuda.cpp +++ b/src/04kernel/test/kernels/simple_unary/test_cuda.cpp @@ -53,6 +53,7 @@ TEST(kernel, SimpleUnaryCuda) { testOp(SimpleUnaryType::Tanh); testOp(SimpleUnaryType::Erf); testOp(SimpleUnaryType::HardSwish); + testOp(SimpleUnaryType::Exp); } #endif diff --git a/src/05computation/src/operators/simple_unary.cc b/src/05computation/src/operators/simple_unary.cc index d43aa5ac..a37fe1c7 100644 --- a/src/05computation/src/operators/simple_unary.cc +++ b/src/05computation/src/operators/simple_unary.cc @@ -85,6 +85,10 @@ namespace refactor::computation { static uint8_t ID = 20; return reinterpret_cast(&ID); } + case SimpleUnaryType::Exp: { + static uint8_t ID = 21; + return reinterpret_cast(&ID); + } default: UNREACHABLE(); } @@ -134,6 +138,8 @@ namespace refactor::computation { return "Not"; case SimpleUnaryType::HardSwish: return "HardSwish"; + case SimpleUnaryType::Exp: + return "Exp"; default: UNREACHABLE(); } diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp index c871100d..0981f720 100644 --- a/src/07onnx/src/operators.cpp +++ b/src/07onnx/src/operators.cpp @@ -120,6 +120,7 @@ namespace refactor::onnx { REGISTER(Neg , SimpleUnary ); REGISTER(Identity , SimpleUnary ); REGISTER(HardSwish , SimpleUnary ); + REGISTER(Exp , SimpleUnary ); REGISTER(Slice , Slice ); REGISTER(Softmax , Softmax ); REGISTER(Split , Split ); diff --git a/src/07onnx/src/operators/simple_unary.cc b/src/07onnx/src/operators/simple_unary.cc index 8ce5e14f..9192c916 100644 --- a/src/07onnx/src/operators/simple_unary.cc +++ b/src/07onnx/src/operators/simple_unary.cc @@ -38,6 +38,7 @@ namespace refactor::onnx { opType == "onnx::Neg" ? Ty::Neg : opType == "onnx::Identity"? Ty::Identity: opType == "onnx::HardSwish" ? Ty::HardSwish : + opType == "onnx::Exp" ? Ty::Exp : UNREACHABLEX(Ty, "Unsupported unary operator: {}", opType); // clang-format on @@ -134,6 +135,10 @@ namespace refactor::onnx { static uint8_t ID = 22; return reinterpret_cast(&ID); } + case Ty::Exp: { + static uint8_t ID = 23; + return reinterpret_cast(&ID); + } default: UNREACHABLE(); } @@ -165,6 +170,7 @@ namespace refactor::onnx { case Ty::Neg : return "onnx::Neg"; case Ty::Identity : return "onnx::Identity"; case Ty::HardSwish : return "onnx::HardSwish"; + case Ty::Exp : return "onnx::Exp"; default: UNREACHABLE(); } // clang-format on @@ -194,7 +200,7 @@ namespace refactor::onnx { Ty::Cos, Ty::Cosh, Ty::Sin, Ty::Sinh, Ty::Tan, Ty::HardSwish}, - {Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log}, + {Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log, Ty::Exp}, {Ty::Neg}, {Ty::Identity}}; if (SET[0].contains(type)) { @@ -294,6 +300,7 @@ namespace refactor::onnx { case Ty::Neg : type_ = Ty_::Neg ; break; case Ty::Identity : return std::make_unique(); case Ty::HardSwish : type_ = Ty_::HardSwish ; break; + case Ty::Exp : type_ = Ty_::Exp ; break; default: UNREACHABLE(); } // clang-format on diff --git a/src/07onnx/src/operators/simple_unary.hh b/src/07onnx/src/operators/simple_unary.hh index 746a1775..93e3f116 100644 --- a/src/07onnx/src/operators/simple_unary.hh +++ b/src/07onnx/src/operators/simple_unary.hh @@ -17,6 +17,7 @@ namespace refactor::onnx { Cos, Cosh, Erf, + Exp, HardSwish, Identity, Log, diff --git a/src/07onnx/test/test_simple_unary.cpp b/src/07onnx/test/test_simple_unary.cpp index bd1e7959..5f9cb691 100644 --- a/src/07onnx/test/test_simple_unary.cpp +++ b/src/07onnx/test/test_simple_unary.cpp @@ -36,4 +36,17 @@ TEST(infer, SimpleUnary) { ASSERT_EQ(y->dataType, DataType::F32); ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)})); } + { + // Exp Test + auto edges = Edges{ + {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""}}; + count_t inputs[]{0}; + auto infered = SimpleUnary(SimpleUnaryType::Exp).infer(TensorRefs(edges, inputs), {true}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::F32); + ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)})); + } }