diff --git a/compiler/include/compiler/backend/optree/optimizer/opt_builder.hpp b/compiler/include/compiler/backend/optree/optimizer/opt_builder.hpp index b8edfd7e..fdf2f678 100644 --- a/compiler/include/compiler/backend/optree/optimizer/opt_builder.hpp +++ b/compiler/include/compiler/backend/optree/optimizer/opt_builder.hpp @@ -16,6 +16,14 @@ class OptBuilder : public Builder { Callback onInsert; Callback onUpdate; Callback onErase; + + Notifier() = default; + Notifier(const Notifier &) = default; + Notifier(Notifier &&) = default; + ~Notifier() = default; + + Notifier(const Callback &onInsert, const Callback &onUpdate, const Callback &onErase) + : onInsert(onInsert), onUpdate(onUpdate), onErase(onErase){}; }; OptBuilder(const Notifier ¬ifier) : Builder(), notifier(notifier){}; @@ -23,10 +31,13 @@ class OptBuilder : public Builder { OptBuilder(OptBuilder &&) = default; ~OptBuilder() override = default; + using Builder::insert; + void insert(const Operation::Ptr &op) override; Operation::Ptr clone(const Operation::Ptr &op); void erase(const Operation::Ptr &op); void update(const Operation::Ptr &op, const std::function &actor); + void replace(const Operation::Ptr &op, const Operation::Ptr &newOp); private: const Notifier ¬ifier; diff --git a/compiler/include/compiler/backend/optree/optimizer/transform_factories.hpp b/compiler/include/compiler/backend/optree/optimizer/transform_factories.hpp index 66cdceeb..5cc1730a 100644 --- a/compiler/include/compiler/backend/optree/optimizer/transform_factories.hpp +++ b/compiler/include/compiler/backend/optree/optimizer/transform_factories.hpp @@ -6,6 +6,7 @@ namespace optree { namespace optimizer { BaseTransform::Ptr createEraseUnusedOps(); +BaseTransform::Ptr createFoldConstants(); } // namespace optimizer } // namespace optree diff --git a/compiler/include/compiler/optree/base_adaptor.hpp b/compiler/include/compiler/optree/base_adaptor.hpp index 77315e6b..549ba0b1 100644 --- a/compiler/include/compiler/optree/base_adaptor.hpp +++ b/compiler/include/compiler/optree/base_adaptor.hpp @@ -74,6 +74,18 @@ struct Adaptor { return op.operator bool(); } + operator const Operation::Ptr &() const { + return op; + } + + operator Operation::Ptr() const { + return op; + } + + Operation *const operator->() const { + return op.get(); + } + const utils::SourceRef &ref() const { return op->ref; } diff --git a/compiler/include/compiler/optree/helpers.hpp b/compiler/include/compiler/optree/helpers.hpp index 5a6a840c..14525973 100644 --- a/compiler/include/compiler/optree/helpers.hpp +++ b/compiler/include/compiler/optree/helpers.hpp @@ -14,4 +14,11 @@ Type::Ptr deduceTargetCastType(const Type::Ptr &outType, const Type::Ptr &inType ArithCastOp insertNumericCastOp(const Type::Ptr &resultType, const Value::Ptr &value, Builder &builder, utils::SourceRef &ref); +template +AdaptorType getValueOwnerAs(const Value::Ptr &value) { + if (value->owner.expired()) + return {}; + return value->owner.lock()->as(); +} + } // namespace optree diff --git a/compiler/lib/backend/optree/optimizer/opt_builder.cpp b/compiler/lib/backend/optree/optimizer/opt_builder.cpp index ff884f6d..cfe64a85 100644 --- a/compiler/lib/backend/optree/optimizer/opt_builder.cpp +++ b/compiler/lib/backend/optree/optimizer/opt_builder.cpp @@ -8,22 +8,34 @@ using namespace optree; using namespace optree::optimizer; +namespace { + +void notifyInsertRecursively(const Operation::Ptr &op, const OptBuilder::Notifier ¬ifier) { + for (const auto &nestedOp : op->body) { + notifyInsertRecursively(nestedOp, notifier); + notifier.onInsert(nestedOp); + } +} + +} // namespace + void OptBuilder::insert(const Operation::Ptr &op) { Builder::insert(op); notifier.onInsert(op); } Operation::Ptr OptBuilder::clone(const Operation::Ptr &op) { - // TODO - notifier.onInsert(op); + auto newOp = op->clone(); + notifyInsertRecursively(op, notifier); + insert(op); return op; } void OptBuilder::erase(const Operation::Ptr &op) { if (op->parent) setInsertPointAfter(op); - for (const auto &nestedOp : op->body) - erase(nestedOp); + for (auto it = op->body.rbegin(); it != op->body.rend(); ++it) + erase(*it); op->erase(); notifier.onErase(op); } @@ -32,3 +44,17 @@ void OptBuilder::update(const Operation::Ptr &op, const std::function &a actor(); notifier.onUpdate(op); } + +void OptBuilder::replace(const Operation::Ptr &op, const Operation::Ptr &newOp) { + for (auto oldResultIt = op->results.begin(), newResultIt = newOp->results.begin(); + oldResultIt != op->results.end() && newResultIt != newOp->results.end(); ++oldResultIt, ++newResultIt) { + const Value::Ptr &oldResult = *oldResultIt; + const Value::Ptr &newResult = *newResultIt; + for (const auto &use : oldResult->uses) { + auto user = use.lock(); + update(user, [&] { user->operand(use.operandNumber) = newResult; }); + } + newResult->uses.splice_after(newResult->uses.before_begin(), oldResult->uses); + } + erase(op); +} diff --git a/compiler/lib/backend/optree/optimizer/optimizer.cpp b/compiler/lib/backend/optree/optimizer/optimizer.cpp index b8068155..a403f97c 100644 --- a/compiler/lib/backend/optree/optimizer/optimizer.cpp +++ b/compiler/lib/backend/optree/optimizer/optimizer.cpp @@ -13,6 +13,8 @@ using namespace optree; using namespace optree::optimizer; +namespace { + class OperationSet { std::vector data; std::unordered_map positions; @@ -63,7 +65,35 @@ class OperationSet { } }; -namespace { +class MutationTracker { + Operation *const trackedOp; + bool updatedTag; + bool erasedTag; + + public: + MutationTracker(const MutationTracker &) = delete; + MutationTracker(MutationTracker &&) = delete; + ~MutationTracker() = default; + + explicit MutationTracker(const Operation::Ptr &trackedOp) + : trackedOp(trackedOp.get()), updatedTag(false), erasedTag(false){}; + + bool updated() const { + return updatedTag; + } + bool erased() const { + return erasedTag; + } + + void raiseUpdated(const Operation::Ptr &op) { + if (op.get() == trackedOp) + updatedTag = true; + } + void raiseErased(const Operation::Ptr &op) { + if (op.get() == trackedOp) + erasedTag = true; + } +}; void pushToSet(const Operation::Ptr &root, OperationSet &ops) { for (const auto &op : root->body) @@ -71,28 +101,35 @@ void pushToSet(const Operation::Ptr &root, OperationSet &ops) { ops.push(root); } -} // namespace - -Optimizer::Optimizer() : iterLimit(100U) { - transforms.emplace_back(createEraseUnusedOps()); -} - -void Optimizer::process(Program &program) const { - OperationSet ops; - bool mutated = false; +OptBuilder::Notifier makeNotifier(OperationSet &ops, bool &mutated, MutationTracker &tracker) { OptBuilder::Notifier notifier; notifier.onInsert = [&ops, &mutated](const Operation::Ptr &op) { ops.push(op); mutated = true; }; - notifier.onUpdate = [&ops, &mutated](const Operation::Ptr &op) { + notifier.onUpdate = [&ops, &mutated, &tracker](const Operation::Ptr &op) { ops.push(op); mutated = true; + tracker.raiseUpdated(op); }; - notifier.onErase = [&ops, &mutated](const Operation::Ptr &op) { + notifier.onErase = [&ops, &mutated, &tracker](const Operation::Ptr &op) { ops.erase(op); mutated = true; + tracker.raiseErased(op); }; + return notifier; +} + +} // namespace + +Optimizer::Optimizer() : iterLimit(100U) { + transforms.emplace_back(createEraseUnusedOps()); + transforms.emplace_back(createFoldConstants()); +} + +void Optimizer::process(Program &program) const { + OperationSet ops; + bool mutated = false; size_t iter = 0; do { mutated = false; @@ -100,7 +137,11 @@ void Optimizer::process(Program &program) const { pushToSet(program.root, ops); while (!ops.empty()) { Operation::Ptr op = ops.pop(); + MutationTracker tracker(op); + auto notifier = makeNotifier(ops, mutated, tracker); for (const auto &transform : transforms) { + if (tracker.erased()) + break; if (!transform->canRun(op)) continue; OptBuilder builder(notifier); diff --git a/compiler/lib/backend/optree/optimizer/transforms/fold_constants.cpp b/compiler/lib/backend/optree/optimizer/transforms/fold_constants.cpp new file mode 100644 index 00000000..88349930 --- /dev/null +++ b/compiler/lib/backend/optree/optimizer/transforms/fold_constants.cpp @@ -0,0 +1,65 @@ +#include "optimizer/transform.hpp" + +#include +#include + +#include "compiler/optree/adaptors.hpp" +#include "compiler/optree/definitions.hpp" +#include "compiler/optree/helpers.hpp" +#include "compiler/optree/operation.hpp" +#include "compiler/optree/types.hpp" + +#include "optimizer/opt_builder.hpp" + +using namespace optree; +using namespace optree::optimizer; + +namespace { + +struct FoldConstants : public Transform { + using Transform::Transform; + + static void foldArithBinaryOp(const ArithBinaryOp &op, OptBuilder &builder) { + auto lhsOp = getValueOwnerAs(op.lhs()); + auto rhsOp = getValueOwnerAs(op.rhs()); + if (!lhsOp || !rhsOp) + return; + auto type = op.result()->type; + if (type->is()) { + int64_t folded = 0; + int64_t lhs = lhsOp.value().as(); + int64_t rhs = rhsOp.value().as(); + switch (op.kind()) { + case ArithBinOpKind::AddI: + folded = lhs + rhs; + break; + case ArithBinOpKind::MulI: + folded = lhs * rhs; + break; + default: + folded = -1; // TODO: extend + } + auto newOp = builder.insert(op.ref(), type, folded); + builder.replace(op, newOp); + return; + } + // TODO: support other types + } + + void run(const Operation::Ptr &op, OptBuilder &builder) const override { + if (op->is()) + foldArithBinaryOp(op->as(), builder); + } +}; + +} // namespace + +namespace optree { +namespace optimizer { + +BaseTransform::Ptr createFoldConstants() { + return std::make_shared(); +} + +} // namespace optimizer +} // namespace optree diff --git a/compiler/lib/optree/operation.cpp b/compiler/lib/optree/operation.cpp index b204bcdb..3d396520 100644 --- a/compiler/lib/optree/operation.cpp +++ b/compiler/lib/optree/operation.cpp @@ -75,8 +75,8 @@ Operation::SpecId Operation::getUnknownSpecId() { } void Operation::addOperand(const Value::Ptr &value) { - operands.emplace_back(value); addUse(value, operands.size()); + operands.emplace_back(value); } void Operation::insertOperand(size_t operandNumber, const Value::Ptr &value) {