Skip to content

Commit

Permalink
fix(frontend): 填充边信息时不能接受输入边不存在
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Jan 31, 2024
1 parent 35dc6c8 commit 27a8ad6
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 22 deletions.
6 changes: 3 additions & 3 deletions src/00common/include/common/error_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ namespace refactor {
std::abort()

#ifndef DISABLE_ASSERT
#define ASSERT(CONDITION, F, ...) \
{ \
if (!(CONDITION)) RUNTIME_ERROR(fmt::format("Assertion: " #F, ##__VA_ARGS__)); \
#define ASSERT(CONDITION, F, ...) \
{ \
if (!(CONDITION)) RUNTIME_ERROR(fmt::format("Assertion: " F, ##__VA_ARGS__)); \
}
#else
#define ASSERT(CONDITION, F)
Expand Down
24 changes: 9 additions & 15 deletions src/06frontend/src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,27 +98,21 @@ namespace refactor::frontend {
auto const startTime = high_resolution_clock::now();
// 拓扑遍历
for (auto [nodeIdx, inputs, outputs] : _internal.topology) {
auto unknownEdge = false, inputChanged = false;
for (auto i : inputs) {
auto const &input = _internal.edges[i].tensor;
if (!input) {// 有入边未知
unknownEdge = true;
break;
}
auto checked = edgeChanged[2 * i]; // NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
auto changed = edgeChanged[2 * i + 1];// NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
auto inputChanged = false;
for (auto i : range0_(inputs.size())) {
auto j = inputs[i];
auto const &input = _internal.edges[j].tensor;
ASSERT(input, "The {}th input of \"{}\" is nullptr", i, _internal.nodes[nodeIdx].name);
auto checked = edgeChanged[2 * j]; // NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
auto changed = edgeChanged[2 * j + 1];// NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
if (!checked) {
checked = true;
if (changed = _edgeSnapshot[i] != *input) {
_edgeSnapshot[i] = input->snapshot();
if (changed = _edgeSnapshot[j] != *input) {
_edgeSnapshot[j] = input->snapshot();
}
}
inputChanged |= changed;
}
// 有入边未知,跳过节点
if (unknownEdge) {
continue;
}
if (!inputChanged && std::all_of(outputs.begin(), outputs.end(),
[this](auto i) { return _internal.edges[i].tensor; })) {
// 入边未发生变化,且出边已推导
Expand Down
6 changes: 3 additions & 3 deletions src/07onnx/src/operators/split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ namespace refactor::onnx {
numOutputs(numOutputs_) {}

auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
auto axis = attributes.getOrInsert( "axis", {0}).int_();
auto numOutputs = attributes.getOrInsert( "num_outputs", {0}).int_();
auto axis = attributes.getOrInsert("axis", {0}).int_();
auto numOutputs = attributes.getOrInsert("num_outputs", {0}).int_();
return OpBox(std::make_unique<Op>(axis, numOutputs));
}
auto Op::typeId() -> size_t {
Expand Down Expand Up @@ -45,7 +45,7 @@ namespace refactor::onnx {
ans[i] = Tensor::share(input.dataType, input.shape, dependencies);
ans[i]->shape[axis_] = DimExpr(each);
} else {
ASSERT(i == numOutputs - 1, ERROR_MSG("Split error"));
ASSERT(i == numOutputs - 1, "Split error");
ans[i] = Tensor::share(input.dataType, input.shape, dependencies);
ans[i]->shape[axis_] = DimExpr(total);
}
Expand Down
2 changes: 1 addition & 1 deletion src/07onnx/src/operators/tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace refactor::onnx {
return Err(InferError(ERROR_MSG("repeats not support")));
}
EXPECT_VAL(repeats.shape[0], repeatsSize)
ASSERT(repeatsSize == rank, ERROR_MSG("repeats size error"));
ASSERT(repeatsSize == rank, "repeats size error");

auto repeats_ = repeats.data->get<int64_t>();
Shape output(rank, DimExpr(1));
Expand Down

0 comments on commit 27a8ad6

Please sign in to comment.