From defbc9c279b8b5e08b7e46360a2dad5a2527a35b Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Sat, 7 Oct 2023 10:07:31 +0800
Subject: [PATCH] fix: restore frontend exports and implement frontend classes
 for more operators
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 src/04frontend/include/frontend/operator.h    |  17 ++-
 src/05onnx/src/operators.cpp                  |  10 +-
 src/05onnx/src/operators/arithmetic.cpp.bak   | 104 ------------------
 ...ization.cpp.bak => batch_normalization.cc} |  34 ++++--
 .../src/operators/batch_normalization.hh      |  25 +++++
 .../src/operators/{cast.cpp.bak => cast.cc}   |  29 +++--
 src/05onnx/src/operators/cast.hh              |  25 +++++
 src/05onnx/src/operators/common.cpp           |   6 +
 src/05onnx/src/operators/common.h             |   4 +
 src/05onnx/src/operators/compair.cc           |  71 ++++++++++++
 src/05onnx/src/operators/compair.cpp.bak      |  58 ----------
 src/05onnx/src/operators/compair.hh           |  34 ++++++
 .../operators/{binary.cc => simple_binary.cc} |  75 +++++++------
 .../operators/{binary.hh => simple_binary.hh} |  16 +--
 src/07python_ffi/src/main.cpp                 |  16 +--
 15 files changed, 284 insertions(+), 240 deletions(-)
 delete mode 100644 src/05onnx/src/operators/arithmetic.cpp.bak
 rename src/05onnx/src/operators/{batch_normalization.cpp.bak => batch_normalization.cc} (54%)
 create mode 100644 src/05onnx/src/operators/batch_normalization.hh
 rename src/05onnx/src/operators/{cast.cpp.bak => cast.cc} (82%)
 create mode 100644 src/05onnx/src/operators/cast.hh
 create mode 100644 src/05onnx/src/operators/compair.cc
 delete mode 100644 src/05onnx/src/operators/compair.cpp.bak
 create mode 100644 src/05onnx/src/operators/compair.hh
 rename src/05onnx/src/operators/{binary.cc => simple_binary.cc} (68%)
 rename src/05onnx/src/operators/{binary.hh => simple_binary.hh} (60%)

diff --git a/src/04frontend/include/frontend/operator.h b/src/04frontend/include/frontend/operator.h
index 326d1dd3..6225dcbd 100644
--- a/src/04frontend/include/frontend/operator.h
+++ b/src/04frontend/include/frontend/operator.h
@@ -34,7 +34,8 @@ namespace refactor::frontend {
     using Attributes = std::unordered_map;
 
     class Operator;
-    using OpBox = std::unique_ptr;
+    class OpBox;
+
     class Operator {
     public:
         virtual size_t opTypeId() const = 0;
@@ -54,12 +55,22 @@ namespace refactor::frontend {
         }
     };
 
+    class OpBox {
+        std::unique_ptr op;
+
+    public:
+        explicit OpBox(std::unique_ptr ptr) : op(std::move(ptr)) {}
+
+        Operator *operator->() { return op.get(); }
+        Operator const *operator->() const { return op.get(); }
+    };
+
     struct Node {
         OpBox op;
         std::string name;
 
-        template T *opAs() const {
-            return dynamic_cast(op.get());
+        template T const *opAs() const {
+            return dynamic_cast(op.operator->());
         }
     };
 
diff --git a/src/05onnx/src/operators.cpp b/src/05onnx/src/operators.cpp
index 97fdaccd..f5bf9b43 100644
--- a/src/05onnx/src/operators.cpp
+++ b/src/05onnx/src/operators.cpp
@@ -2,16 +2,16 @@
 #include "frontend/operator.h"
 #include "operators/common.h"
-#include "operators/binary.hh"
+#include "operators/simple_binary.hh"
 
 namespace refactor::onnx {
 
     void register_() {
         // clang-format off
-        Operator::register_("onnx::Add");
-        Operator::register_("onnx::Sub");
-        Operator::register_("onnx::Mul");
-        Operator::register_("onnx::Div");
+        Operator::register_("onnx::Add");
+        Operator::register_("onnx::Sub");
+        Operator::register_("onnx::Mul");
+        Operator::register_("onnx::Div");
         // clang-format on
     }
 
diff --git a/src/05onnx/src/operators/arithmetic.cpp.bak b/src/05onnx/src/operators/arithmetic.cpp.bak
deleted file mode 100644
index 2a268d12..00000000
--- a/src/05onnx/src/operators/arithmetic.cpp.bak
+++ /dev/null
@@ -1,104 +0,0 @@
-#include "common.h"
-#include "common/range.h"
-#include "computation/operators/simple_binary.h"
-
-namespace refactor::onnx {
-    using namespace common;
-
-    enum class Ty {
-        Add,
-        Sub,
-        Mul,
-        Div
-    };
-
-    template
-    void calculate(Ty ty, void *dst, void const *a, void const *b) {
-        using T_ = typename primitive_t::type;
-        auto a_ = *reinterpret_cast(a);
-        auto b_ = *reinterpret_cast(b);
-        auto dst_ = reinterpret_cast(dst);
-        switch (ty) {
-            case Ty::Add:
-                *dst_ = a_ + b_;
-                break;
-            case Ty::Sub:
-                *dst_ = a_ - b_;
-                break;
-            case Ty::Mul:
-                *dst_ = a_ * b_;
-                break;
-            case Ty::Div:
-                *dst_ = a_ / b_;
-                break;
-            default:
-                UNREACHABLE();
-        }
-    }
-
-    InferResult inferArithmetic(Operator const &op, TensorRefs inputs, InferOptions const &options) {
-        EXPECT_SIZE(2)
-
-        auto const &a = inputs[0];
-        auto const &b = inputs[1];
-        auto dataType = a.dataType;
-        if (!dataType.isNumberic() || b.dataType != dataType) {
-            return Err(InferError(ERROR_MSG("Data type not support")));
-        }
-
-        MULTIDIR_BROADCAST((ShapeRefs{a.shape, b.shape}))
-        auto ans = Tensor::share(dataType, std::move(output), extractDependency(inputs));
-        if (!options.shouldCalculate(inputs, {*ans})) {
-            return Ok(Tensors{std::move(ans)});
-        }
-
-        auto eleSize = dataType.size();
-        auto dst = reinterpret_cast(ans->malloc());
-        for (auto i : range0_(ans->elementsSize())) {
-            auto ty = op.opType.is("onnx::Add")   ? Ty::Add
-                      : op.opType.is("onnx::Sub") ? Ty::Sub
-                      : op.opType.is("onnx::Mul") ? Ty::Mul
-                      : op.opType.is("onnx::Div") ? Ty::Div
-                                                  : UNREACHABLEX(Ty, "");
-            auto indices = locateN(ans->shape, i);
-            auto a_ = locate1(a, indices),
-                 b_ = locate1(b, indices);
-            auto dst_ = dst + i * eleSize;
-            //-------------------------------------
-#define CASE(T)                      \
-    case DataType::T:                \
-        calculate(ty, dst_, a_, b_); \
-        break
-            //-------------------------------------
-            switch (dataType.internal) {
-                CASE(F32);
-                CASE(F64);
-                CASE(I32);
-                CASE(I64);
-                CASE(I8);
-                CASE(I16);
-                CASE(U8);
-                CASE(U16);
-                CASE(U32);
-                CASE(U64);
-                default:
-                    ans->free();
-                    break;
-            }
-        }
-        return Ok(Tensors{std::move(ans)});
-    }
-
-    LowerOperator lowerArithmetic(Operator const &op, TensorRefs) {
-        using namespace computation;
-
-        auto type = op.opType.is("onnx::Add")   ? SimpleBinaryType::Add
-                    : op.opType.is("onnx::Sub") ? SimpleBinaryType::Sub
-                    : op.opType.is("onnx::Mul") ? SimpleBinaryType::Mul
-                    : op.opType.is("onnx::Div") ? SimpleBinaryType::Div
-                                                : UNREACHABLEX(SimpleBinaryType,
-                                                               "{} not support",
-                                                               op.opType.name());
-        return {std::make_shared(type), {0, 1}};
-    }
-}// namespace refactor::onnx
diff --git a/src/05onnx/src/operators/batch_normalization.cpp.bak b/src/05onnx/src/operators/batch_normalization.cc
similarity index 54%
rename from src/05onnx/src/operators/batch_normalization.cpp.bak
rename to src/05onnx/src/operators/batch_normalization.cc
index a61b8fe8..6287341e 100644
--- a/src/05onnx/src/operators/batch_normalization.cpp.bak
+++ b/src/05onnx/src/operators/batch_normalization.cc
@@ -1,15 +1,30 @@
 #include "computation/operators/batch_normalization.h"
+#include "batch_normalization.hh"
 #include "common.h"
 #include 
 
 namespace refactor::onnx {
     using namespace common;
+    using Op = BatchNormalization;
 
-    InferResult inferBatchNormalization(Operator const &op, TensorRefs inputs, InferOptions const &) {
-        if (op.attribute("training_mode", {0}).int_() != 0) {
-            return Err(InferError(ERROR_MSG("training_mode is not supported")));
-        }
+    Op::BatchNormalization(bool trainingMode_)
+        : Operator(), trainingMode(trainingMode_) {}
+
+    auto Op::build(std::string_view, Attributes attributes) -> OpBox {
+        auto trainingMode = defaultOr(attributes, "training_mode", {0}).int_() != 0;
+        return OpBox(std::make_unique(trainingMode));
+    }
+    auto Op::typeId() -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast(&ID);
+    }
+    auto Op::opTypeId() const -> size_t { return typeId(); }
+    auto Op::opTypeName() const -> std::string_view { return "onnx::Op"; }
+
+    auto Op::infer(
+        TensorRefs inputs,
+        InferOptions const &options) const -> InferResult {
         EXPECT_SIZE(5)
 
         auto const &x = inputs[0];
@@ -32,16 +47,11 @@ namespace refactor::onnx {
 
         return Ok(Tensors{Tensor::share(x.dataType, x.shape, extractDependency(inputs))});
     }
-
-    LowerOperator lowerBatchNormalization(Operator const &, TensorRefs inputs) {
-        using namespace computation;
-
+    auto Op::lower(TensorRefs inputs) const -> LowerOperator {
+        using Op_ = computation::BatchNormalization;
         decltype(LowerOperator::inputs) inputs_(inputs.size());
         std::iota(inputs_.begin(), inputs_.end(), 0);
-        return {
-            std::make_shared(),
-            std::move(inputs_),
-        };
+        return {std::make_shared(), std::move(inputs_)};
     }
 }// namespace refactor::onnx
diff --git a/src/05onnx/src/operators/batch_normalization.hh b/src/05onnx/src/operators/batch_normalization.hh
new file mode 100644
index 00000000..351db511
--- /dev/null
+++ b/src/05onnx/src/operators/batch_normalization.hh
@@ -0,0 +1,25 @@
+#ifndef ONNX_BATCH_NORMALIZATION_HH
+#define ONNX_BATCH_NORMALIZATION_HH
+
+#include "frontend/operator.h"
+
+namespace refactor::onnx {
+    using namespace frontend;
+
+    struct BatchNormalization final : public Operator {
+        bool trainingMode;
+
+        explicit BatchNormalization(bool);
+
+        static OpBox build(std::string_view, Attributes);
+        static size_t typeId();
+
+        size_t opTypeId() const final;
+        std::string_view opTypeName() const final;
+        InferResult infer(TensorRefs, InferOptions const &) const final;
+        LowerOperator lower(TensorRefs) const final;
+    };
+
+}// namespace refactor::onnx
+
+#endif// ONNX_BATCH_NORMALIZATION_HH
diff --git a/src/05onnx/src/operators/cast.cpp.bak b/src/05onnx/src/operators/cast.cc
similarity index 82%
rename from src/05onnx/src/operators/cast.cpp.bak
rename to src/05onnx/src/operators/cast.cc
index 32047ab6..07a225df 100644
--- a/src/05onnx/src/operators/cast.cpp.bak
+++ b/src/05onnx/src/operators/cast.cc
@@ -1,10 +1,27 @@
 #include "computation/operators/cast.h"
+#include "cast.hh"
#include "common.h" #include "common/natural.h" #include namespace refactor::onnx { using namespace common; + using Op = Cast; + + Op::Cast(DataType to_) + : Operator(), to(to_) {} + + auto Op::build(std::string_view, Attributes attributes) -> OpBox { + auto to = *DataType::parse(attributes.at("to").int_()); + return OpBox(std::make_unique(to)); + } + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "onnx::Cast"; } template void castData(void const *src, void *dst, size_t size) { @@ -13,11 +30,10 @@ namespace refactor::onnx { std::transform(std::execution::unseq, src_, src_ + size, dst_, [](auto x) { return static_cast(x); }); } - InferResult inferCast(Operator const &op, TensorRefs inputs, InferOptions const &options) { + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { EXPECT_SIZE(1) auto const &input = inputs[0]; - auto to = *DataType::parse(op.attribute("to").int_()); auto ans = Tensor::share(to, input.shape, extractDependency(inputs)); if (!options.shouldCalculate(inputs, {*ans})) { return Ok(Tensors{std::move(ans)}); @@ -103,10 +119,9 @@ namespace refactor::onnx { return Ok(Tensors{std::move(ans)}); } - LowerOperator lowerCast(Operator const &op, TensorRefs) { - using namespace computation; - - auto to = *DataType::parse(op.attribute("to").int_()); - return {std::make_shared(to), {0}}; + auto Op::lower(TensorRefs) const -> LowerOperator { + using Op_ = computation::Cast; + return {std::make_shared(to), {0}}; } + }// namespace refactor::onnx diff --git a/src/05onnx/src/operators/cast.hh b/src/05onnx/src/operators/cast.hh new file mode 100644 index 00000000..6c7c5e2a --- /dev/null +++ b/src/05onnx/src/operators/cast.hh @@ -0,0 +1,25 @@ +#ifndef ONNX_CAST_HH +#define ONNX_CAST_HH + +#include "frontend/operator.h" + +namespace refactor::onnx { + using namespace frontend; + + struct Cast final : public Operator { + common::DataType to; + + explicit Cast(common::DataType); + + static OpBox build(std::string_view, Attributes); + static size_t typeId(); + + size_t opTypeId() const final; + std::string_view opTypeName() const final; + InferResult infer(TensorRefs, InferOptions const &) const final; + LowerOperator lower(TensorRefs) const final; + }; + +}// namespace refactor::onnx + +#endif// ONNX_CAST_HH diff --git a/src/05onnx/src/operators/common.cpp b/src/05onnx/src/operators/common.cpp index 80f7eee8..8e52c165 100644 --- a/src/05onnx/src/operators/common.cpp +++ b/src/05onnx/src/operators/common.cpp @@ -101,4 +101,10 @@ namespace refactor::onnx { }); return Ok(std::move(ans)); } + + Attribute defaultOr(Attributes &attrs, std::string const &name, Attribute defaultValue) { + auto iter = attrs.find(name); + return iter == attrs.end() ? 
+        return iter == attrs.end() ? defaultValue : std::move(iter->second);
+    }
+
 }// namespace refactor::onnx
diff --git a/src/05onnx/src/operators/common.h b/src/05onnx/src/operators/common.h
index ad9533f0..eb975bac 100644
--- a/src/05onnx/src/operators/common.h
+++ b/src/05onnx/src/operators/common.h
@@ -36,6 +36,10 @@ namespace refactor::onnx {
                                  OptionalInts const &pads,
                                  OptionalInts const &strides);
 
+    Attribute defaultOr(Attributes &attrs,
+                        std::string const &name,
+                        Attribute defaultValue);
+
 #define EXPECT_SIZE(N)                                              \
     if (inputs.size() != (N)) {                                     \
         return Err(InferError(ERROR_MSG("Input size error")));      \
diff --git a/src/05onnx/src/operators/compair.cc b/src/05onnx/src/operators/compair.cc
new file mode 100644
index 00000000..e16b0b0b
--- /dev/null
+++ b/src/05onnx/src/operators/compair.cc
@@ -0,0 +1,71 @@
+#include "computation/operators/compair.h"
+#include "common.h"
+#include "common/range.h"
+#include "compair.hh"
+
+namespace refactor::onnx {
+    using namespace common;
+    using Op = Compair;
+
+    Op::Compair(CompairType type_)
+        : Operator(), type(type_) {}
+
+    auto Op::build(std::string_view, Attributes);
+    auto Op::typeId(CompairType);
+
+    auto Op::opTypeId() const final;
+    auto Op::opTypeName() const final;
+    auto Op::infer(TensorRefs, InferOptions const &) const final;
+    auto Op::lower(TensorRefs) const final;
+
+    // InferResult inferCompair(Operator const &op, TensorRefs inputs, InferOptions const &options) {
+    //     EXPECT_SIZE(2)
+
+    //     auto const &a = inputs[0];
+    //     auto const &b = inputs[1];
+    //     if (a.dataType != b.dataType) {
+    //         return Err(InferError(ERROR_MSG("Input data type not support")));
+    //     }
+
+    //     MULTIDIR_BROADCAST((ShapeRefs{a.shape, b.shape}))
+    //     auto ans = Tensor::share(DataType::Bool, std::move(output), extractDependency(inputs));
+    //     if (!options.shouldCalculate(inputs, {*ans}) || a.dataType != DataType::I64) {// TODO: support other data type
+    //         return Ok(Tensors{std::move(ans)});
+    //     }
+
+    //     auto dst = reinterpret_cast(ans->malloc());
+    //     for (auto i : range0_(ans->elementsSize())) {
+    //         auto indices = locateN(ans->shape, i);
+    //         auto a_ = *reinterpret_cast(locate1(a, indices)),
+    //              b_ = *reinterpret_cast(locate1(b, indices));
+    //         if (op.opType.is("onnx::Equal")) {
+    //             dst[i] = a_ == b_;
+    //         } else if (op.opType.is("onnx::Greater")) {
+    //             dst[i] = a_ > b_;
+    //         } else if (op.opType.is("onnx::GreaterOrEqual")) {
+    //             dst[i] = a_ >= b_;
+    //         } else if (op.opType.is("onnx::Less")) {
+    //             dst[i] = a_ < b_;
+    //         } else if (op.opType.is("onnx::LessOrEqual")) {
+    //             dst[i] = a_ <= b_;
+    //         } else {
+    //             return Err(InferError(ERROR_MSG("OpType not support")));
+    //         }
+    //     }
+    //     return Ok(Tensors{std::move(ans)});
+    // }
+
+    // LowerOperator lowerCompair(Operator const &op, TensorRefs) {
+    //     using namespace computation;
+
+    //     auto type = op.opType.is("onnx::Equal")            ? CompairType::EQ
+    //                 : op.opType.is("onnx::Greater")        ? CompairType::GT
+    //                 : op.opType.is("onnx::GreaterOrEqual") ? CompairType::GE
+    //                 : op.opType.is("onnx::Less")           ? CompairType::LT
+    //                 : op.opType.is("onnx::LessOrEqual")   ? CompairType::LE
+    //                                                        : UNREACHABLEX(CompairType,
+    //                                                                       "{} not support",
+    //                                                                       op.opType.name());
+    //     return {std::make_shared(type), {0, 1}};
+    // }
+}// namespace refactor::onnx
diff --git a/src/05onnx/src/operators/compair.cpp.bak b/src/05onnx/src/operators/compair.cpp.bak
deleted file mode 100644
index 61c85496..00000000
--- a/src/05onnx/src/operators/compair.cpp.bak
+++ /dev/null
@@ -1,58 +0,0 @@
-#include "computation/operators/compair.h"
-#include "common.h"
-#include "common/range.h"
-
-namespace refactor::onnx {
-    using namespace common;
-
-    InferResult inferCompair(Operator const &op, TensorRefs inputs, InferOptions const &options) {
-        EXPECT_SIZE(2)
-
-        auto const &a = inputs[0];
-        auto const &b = inputs[1];
-        if (a.dataType != b.dataType) {
-            return Err(InferError(ERROR_MSG("Input data type not support")));
-        }
-
-        MULTIDIR_BROADCAST((ShapeRefs{a.shape, b.shape}))
-        auto ans = Tensor::share(DataType::Bool, std::move(output), extractDependency(inputs));
-        if (!options.shouldCalculate(inputs, {*ans}) || a.dataType != DataType::I64) {// TODO: support other data type
-            return Ok(Tensors{std::move(ans)});
-        }
-
-        auto dst = reinterpret_cast(ans->malloc());
-        for (auto i : range0_(ans->elementsSize())) {
-            auto indices = locateN(ans->shape, i);
-            auto a_ = *reinterpret_cast(locate1(a, indices)),
-                 b_ = *reinterpret_cast(locate1(b, indices));
-            if (op.opType.is("onnx::Equal")) {
-                dst[i] = a_ == b_;
-            } else if (op.opType.is("onnx::Greater")) {
-                dst[i] = a_ > b_;
-            } else if (op.opType.is("onnx::GreaterOrEqual")) {
-                dst[i] = a_ >= b_;
-            } else if (op.opType.is("onnx::Less")) {
-                dst[i] = a_ < b_;
-            } else if (op.opType.is("onnx::LessOrEqual")) {
-                dst[i] = a_ <= b_;
-            } else {
-                return Err(InferError(ERROR_MSG("OpType not support")));
-            }
-        }
-        return Ok(Tensors{std::move(ans)});
-    }
-
-    LowerOperator lowerCompair(Operator const &op, TensorRefs) {
-        using namespace computation;
-
-        auto type = op.opType.is("onnx::Equal")            ? CompairType::EQ
-                    : op.opType.is("onnx::Greater")        ? CompairType::GT
-                    : op.opType.is("onnx::GreaterOrEqual") ? CompairType::GE
-                    : op.opType.is("onnx::Less")           ? CompairType::LT
-                    : op.opType.is("onnx::LessOrEqual")   ? CompairType::LE
-                                                           : UNREACHABLEX(CompairType,
-                                                                          "{} not support",
-                                                                          op.opType.name());
-        return {std::make_shared(type), {0, 1}};
-    }
-}// namespace refactor::onnx
diff --git a/src/05onnx/src/operators/compair.hh b/src/05onnx/src/operators/compair.hh
new file mode 100644
index 00000000..bc927e75
--- /dev/null
+++ b/src/05onnx/src/operators/compair.hh
@@ -0,0 +1,34 @@
+#ifndef ONNX_COMPAIR_HH
+#define ONNX_COMPAIR_HH
+
+#include "frontend/operator.h"
+
+namespace refactor::onnx {
+    using namespace frontend;
+
+    enum class CompairType {
+        EQ,
+        NE,
+        LT,
+        LE,
+        GT,
+        GE,
+    };
+
+    struct Compair final : public Operator {
+        CompairType type;
+
+        explicit Compair(CompairType);
+
+        static OpBox build(std::string_view, Attributes);
+        static size_t typeId(CompairType);
+
+        size_t opTypeId() const final;
+        std::string_view opTypeName() const final;
+        InferResult infer(TensorRefs, InferOptions const &) const final;
+        LowerOperator lower(TensorRefs) const final;
+    };
+
+}// namespace refactor::onnx
+
+#endif// ONNX_COMPAIR_HH
diff --git a/src/05onnx/src/operators/binary.cc b/src/05onnx/src/operators/simple_binary.cc
similarity index 68%
rename from src/05onnx/src/operators/binary.cc
rename to src/05onnx/src/operators/simple_binary.cc
index 8a064bab..ba9ad61d 100644
--- a/src/05onnx/src/operators/binary.cc
+++ b/src/05onnx/src/operators/simple_binary.cc
@@ -1,4 +1,4 @@
-#include "binary.hh"
+#include "simple_binary.hh"
 #include "common.h"
 #include "common/error_handler.h"
 #include "common/range.h"
@@ -6,42 +6,45 @@
 namespace refactor::onnx {
     using namespace common;
+    using Op = SimpleBinary;
+    using Ty = SimpleBinaryType;
 
-    Binary::Binary(BinaryType type_) : type(type_) {}
+    Op::SimpleBinary(Ty type_)
+        : Operator(), type(type_) {}
 
-    auto Binary::build(std::string_view opType, Attributes attributes) -> OpBox {
+    auto Op::build(std::string_view opType, Attributes attributes) -> OpBox {
         ASSERT(attributes.empty(), "Simple binary operator should not have attributes");
         if (opType == "onnx::Add") {
-            return std::make_unique(BinaryType::Add);
+            return OpBox(std::make_unique(Ty::Add));
         }
         if (opType == "onnx::Sub") {
-            return std::make_unique(BinaryType::Sub);
+            return OpBox(std::make_unique(Ty::Sub));
         }
         if (opType == "onnx::Mul") {
-            return std::make_unique(BinaryType::Mul);
+            return OpBox(std::make_unique(Ty::Mul));
        }
        if (opType == "onnx::Div") {
-            return std::make_unique(BinaryType::Div);
+            return OpBox(std::make_unique(Ty::Div));
        }
        UNREACHABLEX(void, "Unsupported binary operator: {}", opType);
     }
 
-    auto Binary::typeId(BinaryType type) -> size_t {
+    auto Op::typeId(Ty type) -> size_t {
         switch (type) {
-            case BinaryType::Add: {
+            case Ty::Add: {
                 static uint8_t ID = 1;
                 return reinterpret_cast(&ID);
             }
-            case BinaryType::Sub: {
+            case Ty::Sub: {
                 static uint8_t ID = 2;
                 return reinterpret_cast(&ID);
             }
-            case BinaryType::Mul: {
+            case Ty::Mul: {
                 static uint8_t ID = 3;
                 return reinterpret_cast(&ID);
             }
-            case BinaryType::Div: {
+            case Ty::Div: {
                 static uint8_t ID = 4;
                 return reinterpret_cast(&ID);
             }
@@ -50,16 +53,16 @@ namespace refactor::onnx {
         }
     }
 
-    auto Binary::opTypeId() const -> size_t { return typeId(type); }
-    auto Binary::opTypeName() const -> std::string_view {
+    auto Op::opTypeId() const -> size_t { return typeId(type); }
+    auto Op::opTypeName() const -> std::string_view {
         switch (type) {
-            case BinaryType::Add:
+            case Ty::Add:
                 return "onnx::Add";
-            case BinaryType::Sub:
+            case Ty::Sub:
                 return "onnx::Sub";
-            case BinaryType::Mul:
+            case Ty::Mul:
                 return "onnx::Mul";
-            case BinaryType::Div:
+            case Ty::Div:
return "onnx::Div"; default: UNREACHABLE(); @@ -67,22 +70,22 @@ namespace refactor::onnx { } template - void calculate(BinaryType ty, void *dst, void const *a, void const *b) { + void calculate(Ty ty, void *dst, void const *a, void const *b) { using T_ = typename primitive_t::type; auto a_ = *reinterpret_cast(a); auto b_ = *reinterpret_cast(b); auto dst_ = reinterpret_cast(dst); switch (ty) { - case BinaryType::Add: + case Ty::Add: *dst_ = a_ + b_; break; - case BinaryType::Sub: + case Ty::Sub: *dst_ = a_ - b_; break; - case BinaryType::Mul: + case Ty::Mul: *dst_ = a_ * b_; break; - case BinaryType::Div: + case Ty::Div: *dst_ = a_ / b_; break; default: @@ -90,7 +93,7 @@ namespace refactor::onnx { } } - auto Binary::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { EXPECT_SIZE(2) auto const &a = inputs[0]; @@ -138,25 +141,27 @@ namespace refactor::onnx { return Ok(Tensors{std::move(ans)}); } - auto Binary::lower(TensorRefs) const -> LowerOperator { - computation::SimpleBinaryType type_; + auto Op::lower(TensorRefs) const -> LowerOperator { + using Op_ = computation::SimpleBinary; + using Ty_ = computation::SimpleBinaryType; + Ty_ type_; switch (type) { - case BinaryType::Add: - type_ = computation::SimpleBinaryType::Add; + case Ty::Add: + type_ = Ty_::Add; break; - case BinaryType::Sub: - type_ = computation::SimpleBinaryType::Sub; + case Ty::Sub: + type_ = Ty_::Sub; break; - case BinaryType::Mul: - type_ = computation::SimpleBinaryType::Mul; + case Ty::Mul: + type_ = Ty_::Mul; break; - case BinaryType::Div: - type_ = computation::SimpleBinaryType::Div; + case Ty::Div: + type_ = Ty_::Div; break; default: break; } - return {std::make_shared(type_), {0, 1}}; + return {std::make_shared(type_), {0, 1}}; } }// namespace refactor::onnx diff --git a/src/05onnx/src/operators/binary.hh b/src/05onnx/src/operators/simple_binary.hh similarity index 60% rename from src/05onnx/src/operators/binary.hh rename to src/05onnx/src/operators/simple_binary.hh index 3a7f2667..7d89e51f 100644 --- a/src/05onnx/src/operators/binary.hh +++ b/src/05onnx/src/operators/simple_binary.hh @@ -1,25 +1,25 @@ -#ifndef ONNX_BINARY_HH -#define ONNX_BINARY_HH +#ifndef ONNX_SIMPLE_BINARY_HH +#define ONNX_SIMPLE_BINARY_HH #include "frontend/operator.h" namespace refactor::onnx { using namespace frontend; - enum class BinaryType { + enum class SimpleBinaryType { Add, Sub, Mul, Div, }; - struct Binary final : public Operator { - BinaryType type; + struct SimpleBinary final : public Operator { + SimpleBinaryType type; - explicit Binary(BinaryType); + explicit SimpleBinary(SimpleBinaryType); static OpBox build(std::string_view, Attributes); - static size_t typeId(BinaryType); + static size_t typeId(SimpleBinaryType); size_t opTypeId() const final; std::string_view opTypeName() const final; @@ -29,4 +29,4 @@ namespace refactor::onnx { }// namespace refactor::onnx -#endif// ONNX_BINARY_HH +#endif// ONNX_SIMPLE_BINARY_HH diff --git a/src/07python_ffi/src/main.cpp b/src/07python_ffi/src/main.cpp index bc9f581f..ecc4c833 100644 --- a/src/07python_ffi/src/main.cpp +++ b/src/07python_ffi/src/main.cpp @@ -15,14 +15,14 @@ namespace refactor::python_ffi { // clang-format off py::class_ >(m, "Tensor" ); - py::class_>(m, "Operator" ); - - // m .def("config_log" , &configLog , return_::automatic ) - // .def("_make_operator" , &makeOp , return_::move ) - // .def("_make_tensor" , &makeTensor , return_::move ) - // .def("_make_data" , 
&makeTensorWithData , return_::move ) - // .def("_make_data_ex" , &makeTensorWithExternalData , return_::move ) - // .def("_make_compiler" , &makeCompiler , return_::move ); + py::class_ >(m, "Operator" ); + + m .def("config_log" , &configLog , return_::automatic ) + .def("_make_operator" , &makeOp , return_::move ) + .def("_make_tensor" , &makeTensor , return_::move ) + .def("_make_data" , &makeTensorWithData , return_::move ) + .def("_make_data_ex" , &makeTensorWithExternalData , return_::move ) + .def("_make_compiler" , &makeCompiler , return_::move ); py::class_ >(m, "Compiler" ) .def("substitute" , &Compiler::substitute , return_::automatic )