fix: restore frontend exports and implement frontend classes for more operators
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Oct 7, 2023
1 parent 02a7e9f commit defbc9c
Showing 15 changed files with 284 additions and 240 deletions.
17 changes: 14 additions & 3 deletions src/04frontend/include/frontend/operator.h
@@ -34,7 +34,8 @@ namespace refactor::frontend {
using Attributes = std::unordered_map<std::string, Attribute>;

class Operator;
using OpBox = std::unique_ptr<Operator>;
class OpBox;

class Operator {
public:
virtual size_t opTypeId() const = 0;
@@ -54,12 +55,22 @@
}
};

class OpBox {
std::unique_ptr<Operator> op;

public:
explicit OpBox(std::unique_ptr<Operator> ptr) : op(std::move(ptr)) {}

Operator *operator->() { return op.get(); }
Operator const *operator->() const { return op.get(); }
};

struct Node {
OpBox op;
std::string name;

template<class T> T *opAs() const {
return dynamic_cast<T *>(op.get());
template<class T> T const *opAs() const {
return dynamic_cast<T *>(op.operator->());
}
};
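
For context, here is a minimal standalone sketch of the OpBox wrapper and Node::opAs pattern introduced above, using simplified stand-in types; the real classes live in frontend/operator.h, and the Cast operator here is purely illustrative.

#include <iostream>
#include <memory>
#include <string>

struct Operator {
    virtual ~Operator() = default;
    virtual std::string opTypeName() const = 0;
};

// Thin owning wrapper so the frontend can export a real class
// instead of a bare std::unique_ptr alias.
class OpBox {
    std::unique_ptr<Operator> op;

public:
    explicit OpBox(std::unique_ptr<Operator> ptr) : op(std::move(ptr)) {}
    Operator *operator->() { return op.get(); }
    Operator const *operator->() const { return op.get(); }
};

struct Cast final : Operator {
    std::string opTypeName() const override { return "onnx::Cast"; }
};

struct Node {
    OpBox op;
    std::string name;

    // Downcast to a concrete operator type; returns nullptr on mismatch.
    template<class T> T const *opAs() const {
        return dynamic_cast<T const *>(op.operator->());
    }
};

int main() {
    Node node{OpBox(std::make_unique<Cast>()), "cast_0"};
    std::cout << node.op->opTypeName() << '\n'; // generic access through OpBox
    if (auto cast = node.opAs<Cast>()) {        // typed access through opAs
        std::cout << "node holds " << cast->opTypeName() << '\n';
    }
}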

10 changes: 5 additions & 5 deletions src/05onnx/src/operators.cpp
@@ -2,16 +2,16 @@
#include "frontend/operator.h"
#include "operators/common.h"

#include "operators/binary.hh"
#include "operators/simple_binary.hh"

namespace refactor::onnx {

void register_() {
// clang-format off
Operator::register_<Binary>("onnx::Add");
Operator::register_<Binary>("onnx::Sub");
Operator::register_<Binary>("onnx::Mul");
Operator::register_<Binary>("onnx::Div");
Operator::register_<SimpleBinary>("onnx::Add");
Operator::register_<SimpleBinary>("onnx::Sub");
Operator::register_<SimpleBinary>("onnx::Mul");
Operator::register_<SimpleBinary>("onnx::Div");
// clang-format on
}
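
These register_ calls imply a name-to-builder registry behind Operator::register_. The sketch below is a hypothetical, self-contained reconstruction of that pattern for illustration only; the project's actual registration API may differ.

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Operator {
    virtual ~Operator() = default;
    virtual std::string opTypeName() const = 0;
};

// Map from ONNX operator name to a factory producing the frontend class.
using Builder = std::function<std::unique_ptr<Operator>()>;

std::unordered_map<std::string, Builder> &registry() {
    static std::unordered_map<std::string, Builder> r;
    return r;
}

template<class T>
void registerOp(std::string name) {
    registry().emplace(std::move(name),
                       []() -> std::unique_ptr<Operator> { return std::make_unique<T>(); });
}

struct SimpleBinary final : Operator {
    std::string opTypeName() const override { return "onnx::SimpleBinary"; }
};

int main() {
    registerOp<SimpleBinary>("onnx::Add");
    registerOp<SimpleBinary>("onnx::Sub");
    auto op = registry().at("onnx::Add")(); // look up by ONNX name, then build
    std::cout << op->opTypeName() << '\n';
}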

104 changes: 0 additions & 104 deletions src/05onnx/src/operators/arithmetic.cpp.bak

This file was deleted.

src/05onnx/src/operators/batch_normalization.cpp
@@ -1,15 +1,30 @@
#include "computation/operators/batch_normalization.h"
#include "batch_normalization.hh"
#include "common.h"
#include <numeric>

namespace refactor::onnx {
using namespace common;
using Op = BatchNormalization;

InferResult inferBatchNormalization(Operator const &op, TensorRefs inputs, InferOptions const &) {
if (op.attribute("training_mode", {0}).int_() != 0) {
return Err(InferError(ERROR_MSG("training_mode is not supported")));
}
Op::BatchNormalization(bool trainingMode_)
: Operator(), trainingMode(trainingMode_) {}

auto Op::build(std::string_view, Attributes attributes) -> OpBox {
auto trainingMode = defaultOr(attributes, "training_mode", {0}).int_() != 0;
return OpBox(std::make_unique<Op>(trainingMode));
}
auto Op::typeId() -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto Op::opTypeId() const -> size_t { return typeId(); }
auto Op::opTypeName() const -> std::string_view { return "onnx::BatchNormalization"; }

auto Op::infer(
TensorRefs inputs,
InferOptions const &options) const -> InferResult {
EXPECT_SIZE(5)

auto const &x = inputs[0];
@@ -32,16 +47,11 @@ namespace refactor::onnx {

return Ok(Tensors{Tensor::share(x.dataType, x.shape, extractDependency(inputs))});
}

LowerOperator lowerBatchNormalization(Operator const &, TensorRefs inputs) {
using namespace computation;

auto Op::lower(TensorRefs inputs) const -> LowerOperator {
using Op_ = computation::BatchNormalization;
decltype(LowerOperator::inputs) inputs_(inputs.size());
std::iota(inputs_.begin(), inputs_.end(), 0);
return {
std::make_shared<BatchNormalization>(),
std::move(inputs_),
};
return {std::make_shared<Op_>(), std::move(inputs_)};
}

}// namespace refactor::onnx
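
The typeId implementation above relies on the address of a function-local static being distinct for each operator class. A standalone illustration of that trick (the function names here are invented for the example):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>

// Each function owns its own local static, so its address is unique per "type".
std::size_t typeIdA() {
    static uint8_t ID = 1;
    return reinterpret_cast<std::size_t>(&ID);
}
std::size_t typeIdB() {
    static uint8_t ID = 1;
    return reinterpret_cast<std::size_t>(&ID);
}

int main() {
    assert(typeIdA() == typeIdA()); // stable for the same type
    assert(typeIdA() != typeIdB()); // distinct across types
    std::cout << typeIdA() << ' ' << typeIdB() << '\n';
}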
25 changes: 25 additions & 0 deletions src/05onnx/src/operators/batch_normalization.hh
@@ -0,0 +1,25 @@
#ifndef ONNX_BATCH_NORMALIZATION_HH
#define ONNX_BATCH_NORMALIZATION_HH

#include "frontend/operator.h"

namespace refactor::onnx {
using namespace frontend;

struct BatchNormalization final : public Operator {
bool trainingMode;

explicit BatchNormalization(bool);

static OpBox build(std::string_view, Attributes);
static size_t typeId();

size_t opTypeId() const final;
std::string_view opTypeName() const final;
InferResult infer(TensorRefs, InferOptions const &) const final;
LowerOperator lower(TensorRefs) const final;
};

}// namespace refactor::onnx

#endif// ONNX_BATCH_NORMALIZATION_HH
src/05onnx/src/operators/cast.cpp
@@ -1,10 +1,27 @@
#include "computation/operators/cast.h"
#include "cast.hh"
#include "common.h"
#include "common/natural.h"
#include <execution>

namespace refactor::onnx {
using namespace common;
using Op = Cast;

Op::Cast(DataType to_)
: Operator(), to(to_) {}

auto Op::build(std::string_view, Attributes attributes) -> OpBox {
auto to = *DataType::parse(attributes.at("to").int_());
return OpBox(std::make_unique<Op>(to));
}
auto Op::typeId() -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto Op::opTypeId() const -> size_t { return typeId(); }
auto Op::opTypeName() const -> std::string_view { return "onnx::Cast"; }

template<class TS, class TD>
void castData(void const *src, void *dst, size_t size) {
@@ -13,11 +30,10 @@ namespace refactor::onnx {
std::transform(std::execution::unseq, src_, src_ + size, dst_, [](auto x) { return static_cast<TD>(x); });
}

InferResult inferCast(Operator const &op, TensorRefs inputs, InferOptions const &options) {
auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
EXPECT_SIZE(1)

auto const &input = inputs[0];
auto to = *DataType::parse(op.attribute("to").int_());
auto ans = Tensor::share(to, input.shape, extractDependency(inputs));
if (!options.shouldCalculate(inputs, {*ans})) {
return Ok(Tensors{std::move(ans)});
@@ -103,10 +119,9 @@ namespace refactor::onnx {
return Ok(Tensors{std::move(ans)});
}

LowerOperator lowerCast(Operator const &op, TensorRefs) {
using namespace computation;

auto to = *DataType::parse(op.attribute("to").int_());
return {std::make_shared<Cast>(to), {0}};
auto Op::lower(TensorRefs) const -> LowerOperator {
using Op_ = computation::Cast;
return {std::make_shared<Op_>(to), {0}};
}

}// namespace refactor::onnx
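
For reference, the castData helper can be exercised on its own; a minimal sketch assuming a toolchain where the C++20 std::execution::unseq policy is usable:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <execution>
#include <iostream>
#include <vector>

// Same shape as the castData template above: reinterpret the raw buffers
// and convert element-wise with an unsequenced std::transform.
template<class TS, class TD>
void castData(void const *src, void *dst, std::size_t size) {
    auto src_ = reinterpret_cast<TS const *>(src);
    auto dst_ = reinterpret_cast<TD *>(dst);
    std::transform(std::execution::unseq, src_, src_ + size, dst_,
                   [](auto x) { return static_cast<TD>(x); });
}

int main() {
    std::vector<float> src{1.5f, 2.25f, -3.75f};
    std::vector<int32_t> dst(src.size());
    castData<float, int32_t>(src.data(), dst.data(), src.size());
    for (auto v : dst) std::cout << v << ' '; // prints: 1 2 -3
    std::cout << '\n';
}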
25 changes: 25 additions & 0 deletions src/05onnx/src/operators/cast.hh
@@ -0,0 +1,25 @@
#ifndef ONNX_CAST_HH
#define ONNX_CAST_HH

#include "frontend/operator.h"

namespace refactor::onnx {
using namespace frontend;

struct Cast final : public Operator {
common::DataType to;

explicit Cast(common::DataType);

static OpBox build(std::string_view, Attributes);
static size_t typeId();

size_t opTypeId() const final;
std::string_view opTypeName() const final;
InferResult infer(TensorRefs, InferOptions const &) const final;
LowerOperator lower(TensorRefs) const final;
};

}// namespace refactor::onnx

#endif// ONNX_CAST_HH
6 changes: 6 additions & 0 deletions src/05onnx/src/operators/common.cpp
@@ -101,4 +101,10 @@ namespace refactor::onnx {
});
return Ok(std::move(ans));
}

Attribute defaultOr(Attributes &attrs, std::string const &name, Attribute defaultValue) {
auto iter = attrs.find(name);
return iter == attrs.end() ? defaultValue : std::move(iter->second);
}

}// namespace refactor::onnx
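
A quick standalone check of defaultOr's fallback behavior, with Attribute simplified to a bare integer for illustration (the real Attribute type is richer than this):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

using Attribute = int64_t;
using Attributes = std::unordered_map<std::string, Attribute>;

// Return the named attribute if present, otherwise the supplied default.
Attribute defaultOr(Attributes &attrs, std::string const &name, Attribute defaultValue) {
    auto iter = attrs.find(name);
    return iter == attrs.end() ? defaultValue : std::move(iter->second);
}

int main() {
    Attributes attrs{{"training_mode", 1}};
    std::cout << defaultOr(attrs, "training_mode", 0) << '\n';  // present -> 1
    std::cout << defaultOr(attrs, "epsilon_is_set", 0) << '\n'; // missing -> default 0
}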
4 changes: 4 additions & 0 deletions src/05onnx/src/operators/common.h
@@ -36,6 +36,10 @@ namespace refactor::onnx {
OptionalInts const &pads,
OptionalInts const &strides);

Attribute defaultOr(Attributes &attrs,
std::string const &name,
Attribute defaultValue);

#define EXPECT_SIZE(N) \
if (inputs.size() != (N)) { \
return Err(InferError(ERROR_MSG("Input size error"))); \