Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksimShagov committed May 13, 2024
1 parent e32cc1b commit 4e07f17
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 69 deletions.
4 changes: 2 additions & 2 deletions compiler/include/compiler/optree/adaptors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ struct FunctionOp : Adaptor {
struct FunctionCallOp : Adaptor {
OPTREE_ADAPTOR_HELPER(Adaptor, "FunctionCall")

void init(const std::string &name, const Type::Ptr &resultType, const std::vector<Value::Ptr> &arguments);
void init(const FunctionOp &callee, const std::vector<Value::Ptr> &arguments);
void init(const std::string &name, const Type::Ptr &resultType, const std::vector<Value::Ptr> &arguments = {});
void init(const FunctionOp &callee, const std::vector<Value::Ptr> &arguments = {});

OPTREE_ADAPTOR_ATTRIBUTE(name, setName, std::string, 0)
OPTREE_ADAPTOR_RESULT(result, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "compiler/optree/adaptors.hpp"
#include "compiler/optree/operation.hpp"
#include "compiler/utils/helpers.hpp"
#include "compiler/utils/language.hpp"

#include "optimizer/opt_builder.hpp"
#include "optimizer/transform.hpp"
Expand Down Expand Up @@ -37,8 +38,8 @@ struct EraseUnusedFunctions : public Transform<ModuleOp> {
if (funcOp)
getInnerFunctionCallNames(moduleChild, funcOp.name(), edges);
}
auto &mainFunctions = edges["main"];
std::unordered_set<std::string> usedFunctions = {"main"};
auto &mainFunctions = edges[utils::language::funcMain];
std::unordered_set<std::string> usedFunctions = {utils::language::funcMain};
std::deque<std::string> queue(mainFunctions.begin(), mainFunctions.end());
while (!queue.empty()) {
auto &name = queue.front();
Expand All @@ -50,7 +51,7 @@ struct EraseUnusedFunctions : public Transform<ModuleOp> {
queue.pop_front();
}

for (const auto &op : utils::advanceEarly(op->body.begin(), op->body.end())) {
for (const auto &op : utils::advanceEarly(op->body)) {
if (!usedFunctions.contains(op->as<FunctionOp>().name())) {
builder.erase(op);
}
Expand Down
126 changes: 62 additions & 64 deletions compiler/tests/backend/optree/optimizer/erase_unused_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include <gtest/gtest.h>

#include <utility>

#include "compiler/backend/optree/optimizer/optimizer.hpp"
#include "compiler/backend/optree/optimizer/transform_factories.hpp"
#include "compiler/optree/adaptors.hpp"
Expand Down Expand Up @@ -29,18 +27,18 @@ TEST_F(EraseUnusedFunctionsTest, can_run_on_empty_optree) {
TEST_F(EraseUnusedFunctionsTest, can_remove_unused_function) {
{
auto &&[m, v] = getActual();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("unused", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[2], 0).inward(v[3], 1).withBody();
v[4] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[4] = m.opInit<ConstantOp>(m.tI64, 123);
v[5] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[2], v[3]);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("test3", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
m.opInit<FunctionOp>("test3", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[0], v[0]);
v[2] = m.opInit<ArithCastOp>(ArithCastOpKind::IntToFloat, m.tF64, v[1]);
v[3] = m.opInit<LogicBinaryOp>(LogicBinOpKind::LessEqualI, v[0], v[1]);
Expand All @@ -49,13 +47,13 @@ TEST_F(EraseUnusedFunctionsTest, can_remove_unused_function) {
}
{
auto &&[m, v] = getExpected();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("test3", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
m.opInit<FunctionOp>("test3", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[0], v[0]);
v[2] = m.opInit<ArithCastOp>(ArithCastOpKind::IntToFloat, m.tF64, v[1]);
v[3] = m.opInit<LogicBinaryOp>(LogicBinOpKind::LessEqualI, v[0], v[1]);
Expand All @@ -70,31 +68,31 @@ TEST_F(EraseUnusedFunctionsTest, can_remove_unused_function) {
TEST_F(EraseUnusedFunctionsTest, can_remove_unused_function_with_recursion) {
{
auto &&[m, v] = getActual();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("unused", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[2], 0).inward(v[3], 1).withBody();
v[4] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[4] = m.opInit<ConstantOp>(m.tI64, 123);
v[5] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[2], v[3]);
v[0] = m.opInit<FunctionCallOp>("unused", m.tNone, std::vector<Value::Ptr>());
v[0] = m.opInit<FunctionCallOp>("unused", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("test3", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("test3", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
}
{
auto &&[m, v] = getExpected();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("test3", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("test3", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
}
Expand All @@ -106,37 +104,37 @@ TEST_F(EraseUnusedFunctionsTest, can_remove_unused_function_with_recursion) {
TEST_F(EraseUnusedFunctionsTest, can_remove_unused_function_calls_each_other) {
{
auto &&[m, v] = getActual();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("unused1", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[2], 0).inward(v[3], 1).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[2], v[3]);
v[2] = m.opInit<FunctionCallOp>("unused2", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("unused2", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("unused2", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[2], 0).inward(v[3], 1).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[2], v[3]);
v[2] = m.opInit<FunctionCallOp>("unused1", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("unused1", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("test3", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("test3", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
}
{
auto &&[m, v] = getExpected();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<ConstantOp>(m.tI64, 123);
v[1] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("test3", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("test3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("test3", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("test3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
}
Expand All @@ -148,37 +146,37 @@ TEST_F(EraseUnusedFunctionsTest, can_remove_unused_function_calls_each_other) {
TEST_F(EraseUnusedFunctionsTest, can_keep_complex_used_functions) {
{
auto &&[m, v] = getActual();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("used1", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("used1", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used1", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody();
v[2] = m.opInit<FunctionCallOp>("used2", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used2", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used2", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[2], 0).inward(v[3], 1).withBody();
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used3", m.tFunc({}, m.tNone)).withBody();
m.opInit<FunctionOp>("used3", m.tFunc(m.tNone)).withBody();
m.opInit<ReturnOp>();
m.endBody();
}
{
auto &&[m, v] = getExpected();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("used1", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("used1", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used1", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody();
v[2] = m.opInit<FunctionCallOp>("used2", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used2", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used2", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[2], 0).inward(v[3], 1).withBody();
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used3", m.tFunc({}, m.tNone)).withBody();
m.opInit<FunctionOp>("used3", m.tFunc(m.tNone)).withBody();
m.opInit<ReturnOp>();
}

Expand All @@ -189,46 +187,46 @@ TEST_F(EraseUnusedFunctionsTest, can_keep_complex_used_functions) {
TEST_F(EraseUnusedFunctionsTest, can_keep_several_used_functions) {
{
auto &&[m, v] = getActual();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("used1", m.tNone, std::vector<Value::Ptr>());
v[1] = m.opInit<FunctionCallOp>("used2", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("used1", m.tNone);
v[1] = m.opInit<FunctionCallOp>("used2", m.tNone);
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("unused", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[2], 0).inward(v[3], 1).withBody();
v[4] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[4] = m.opInit<ConstantOp>(m.tI64, 123);
v[5] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[2], v[3]);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used1", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody();
v[2] = m.opInit<FunctionCallOp>("used2", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used2", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used2", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[2], 0).inward(v[3], 1).withBody();
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used3", m.tFunc({}, m.tNone)).withBody();
m.opInit<FunctionOp>("used3", m.tFunc(m.tNone)).withBody();
m.opInit<ReturnOp>();
m.endBody();
}
{
auto &&[m, v] = getExpected();
m.opInit<FunctionOp>("main", m.tFunc({}, m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("used1", m.tNone, std::vector<Value::Ptr>());
v[1] = m.opInit<FunctionCallOp>("used2", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone, std::vector<Value::Ptr>());
m.opInit<FunctionOp>("main", m.tFunc(m.tNone)).withBody();
v[0] = m.opInit<FunctionCallOp>("used1", m.tNone);
v[1] = m.opInit<FunctionCallOp>("used2", m.tNone);
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used1", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody();
v[2] = m.opInit<FunctionCallOp>("used2", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used2", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used2", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[2], 0).inward(v[3], 1).withBody();
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone, std::vector<Value::Ptr>());
v[2] = m.opInit<FunctionCallOp>("used3", m.tNone);
m.opInit<ReturnOp>();
m.endBody();
m.opInit<FunctionOp>("used3", m.tFunc({}, m.tNone)).withBody();
m.opInit<FunctionOp>("used3", m.tFunc(m.tNone)).withBody();
m.opInit<ReturnOp>();
}

Expand Down

0 comments on commit 4e07f17

Please sign in to comment.