Skip to content

Commit

Permalink
feat(kernel): 添加exp算子
Browse files Browse the repository at this point in the history
  • Loading branch information
bitzyz committed Feb 4, 2024
1 parent d076c20 commit 3199eb5
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/04kernel/include/kernel/collectors/simple_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace refactor::kernel {
Neg,
Not,
HardSwish,
Exp,
};

std::string_view unaryName(SimpleUnaryType type);
Expand Down
1 change: 1 addition & 0 deletions src/04kernel/src/collectors/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace refactor::kernel {
CASE(Sqrt);
CASE(Sigmoid);
CASE(Erf);
CASE(Exp);
CASE(Neg);
CASE(Not);
CASE(HardSwish);
Expand Down
8 changes: 8 additions & 0 deletions src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace refactor::kernel {
Op::Neg,
Op::Erf,
Op::HardSwish,
Op::Exp,
};
return supportedOp.contains(op) && a.dataType.isCpuNumberic()
? std::make_unique<K>(op, a.dataType, a.elementsSize())
Expand Down Expand Up @@ -185,6 +186,13 @@ namespace refactor::kernel {
default:
UNREACHABLE();
}
case Op::Exp:
switch (dataType) {
CASE(std::exp, F32);
CASE(std::exp, F64);
default:
UNREACHABLE();
}
default:
UNREACHABLE();
}
Expand Down
19 changes: 16 additions & 3 deletions src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ namespace refactor::kernel {

auto K::build(Op op, Tensor const &a) noexcept -> KernelBox {
static const std::unordered_set<Op>
supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
Op::Sigmoid, Op::Tanh, Op::Neg,
Op::Erf, Op::HardSwish};
supportedOp{
Op::Abs,
Op::Relu,
Op::Sqrt,
Op::Sigmoid,
Op::Tanh,
Op::Neg,
Op::Erf,
Op::HardSwish,
Op::Exp,
};
#ifndef USE_CUDA
return nullptr;
#endif
Expand Down Expand Up @@ -159,6 +167,11 @@ extern "C" __global__ void kernel(
{__(Op::HardSwish, DT::FP16), "x * __hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, hrcp(__float2half(6.f)) * x + hrcp(__float2half(2.f))))"},
{__(Op::HardSwish, DT::F64 ), "x * fmax(0.0, fmin(1.0, fma(1.0/6.0, x, 0.5)))"},

{__(Op::Exp, DT::F32 ), "expf(x)"},
{__(Op::Exp, DT::F64 ), "exp(x)"},
{__(Op::Exp, DT::FP16), "hexp(x)"},
{__(Op::Exp, DT::BF16), "hexp(x)"},

};
// clang-format on

Expand Down
1 change: 1 addition & 0 deletions src/04kernel/test/kernels/simple_unary/test_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,5 @@ TEST(kernel, SimpleUnaryCpu) {
testOp(SimpleUnaryType::Erf, std::erf);
testOpWithData(SimpleUnaryType::HardSwish,
VecFloat{0.000000, 0.666667, 1.666667, 3.000000, 4.000000, 5.000000});
testOpWithData(SimpleUnaryType::Exp, VecFloat{1.000000, 2.718282, 7.389056, 20.085537, 54.598148, 148.413162});
}
1 change: 1 addition & 0 deletions src/04kernel/test/kernels/simple_unary/test_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ TEST(kernel, SimpleUnaryCuda) {
testOp(SimpleUnaryType::Tanh);
testOp(SimpleUnaryType::Erf);
testOp(SimpleUnaryType::HardSwish);
testOp(SimpleUnaryType::Exp);
}

#endif
6 changes: 6 additions & 0 deletions src/05computation/src/operators/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ namespace refactor::computation {
static uint8_t ID = 20;
return reinterpret_cast<size_t>(&ID);
}
case SimpleUnaryType::Exp: {
static uint8_t ID = 21;
return reinterpret_cast<size_t>(&ID);
}
default:
UNREACHABLE();
}
Expand Down Expand Up @@ -134,6 +138,8 @@ namespace refactor::computation {
return "Not";
case SimpleUnaryType::HardSwish:
return "HardSwish";
case SimpleUnaryType::Exp:
return "Exp";
default:
UNREACHABLE();
}
Expand Down
1 change: 1 addition & 0 deletions src/07onnx/src/operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ namespace refactor::onnx {
REGISTER(Neg , SimpleUnary );
REGISTER(Identity , SimpleUnary );
REGISTER(HardSwish , SimpleUnary );
REGISTER(Exp , SimpleUnary );
REGISTER(Slice , Slice );
REGISTER(Softmax , Softmax );
REGISTER(Split , Split );
Expand Down
9 changes: 8 additions & 1 deletion src/07onnx/src/operators/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace refactor::onnx {
opType == "onnx::Neg" ? Ty::Neg :
opType == "onnx::Identity"? Ty::Identity:
opType == "onnx::HardSwish" ? Ty::HardSwish :
opType == "onnx::Exp" ? Ty::Exp :
UNREACHABLEX(Ty, "Unsupported unary operator: {}", opType);
// clang-format on

Expand Down Expand Up @@ -134,6 +135,10 @@ namespace refactor::onnx {
static uint8_t ID = 22;
return reinterpret_cast<size_t>(&ID);
}
case Ty::Exp: {
static uint8_t ID = 23;
return reinterpret_cast<size_t>(&ID);
}
default:
UNREACHABLE();
}
Expand Down Expand Up @@ -165,6 +170,7 @@ namespace refactor::onnx {
case Ty::Neg : return "onnx::Neg";
case Ty::Identity : return "onnx::Identity";
case Ty::HardSwish : return "onnx::HardSwish";
case Ty::Exp : return "onnx::Exp";
default: UNREACHABLE();
}
// clang-format on
Expand Down Expand Up @@ -194,7 +200,7 @@ namespace refactor::onnx {
Ty::Cos, Ty::Cosh,
Ty::Sin, Ty::Sinh,
Ty::Tan, Ty::HardSwish},
{Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log},
{Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log, Ty::Exp},
{Ty::Neg},
{Ty::Identity}};
if (SET[0].contains(type)) {
Expand Down Expand Up @@ -294,6 +300,7 @@ namespace refactor::onnx {
case Ty::Neg : type_ = Ty_::Neg ; break;
case Ty::Identity : return std::make_unique<computation::Identity>();
case Ty::HardSwish : type_ = Ty_::HardSwish ; break;
case Ty::Exp : type_ = Ty_::Exp ; break;
default: UNREACHABLE();
}
// clang-format on
Expand Down
1 change: 1 addition & 0 deletions src/07onnx/src/operators/simple_unary.hh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace refactor::onnx {
Cos,
Cosh,
Erf,
Exp,
HardSwish,
Identity,
Log,
Expand Down
13 changes: 13 additions & 0 deletions src/07onnx/test/test_simple_unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,17 @@ TEST(infer, SimpleUnary) {
ASSERT_EQ(y->dataType, DataType::F32);
ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
}
{
// Exp Test
auto edges = Edges{
{Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""}};
count_t inputs[]{0};
auto infered = SimpleUnary(SimpleUnaryType::Exp).infer(TensorRefs(edges, inputs), {true});
ASSERT_TRUE(infered.isOk());
auto outputs = std::move(infered.unwrap());
ASSERT_EQ(outputs.size(), 1);
auto y = std::move(outputs[0]);
ASSERT_EQ(y->dataType, DataType::F32);
ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
}
}

0 comments on commit 3199eb5

Please sign in to comment.