Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
vla5924 committed Apr 24, 2024
1 parent 1bfee8a commit 48d51cb
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 17 deletions.
11 changes: 11 additions & 0 deletions compiler/include/compiler/backend/optree/optimizer/opt_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,28 @@ class OptBuilder : public Builder {
Callback onInsert;
Callback onUpdate;
Callback onErase;

Notifier() = default;
Notifier(const Notifier &) = default;
Notifier(Notifier &&) = default;
~Notifier() = default;

Notifier(const Callback &onInsert, const Callback &onUpdate, const Callback &onErase)
: onInsert(onInsert), onUpdate(onUpdate), onErase(onErase){};
};

OptBuilder(const Notifier &notifier) : Builder(), notifier(notifier){};
OptBuilder(const OptBuilder &) = delete;
OptBuilder(OptBuilder &&) = default;
~OptBuilder() override = default;

using Builder::insert;

void insert(const Operation::Ptr &op) override;
Operation::Ptr clone(const Operation::Ptr &op);
void erase(const Operation::Ptr &op);
void update(const Operation::Ptr &op, const std::function<void()> &actor);
void replace(const Operation::Ptr &op, const Operation::Ptr &newOp);

private:
const Notifier &notifier;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace optree {
namespace optimizer {

BaseTransform::Ptr createEraseUnusedOps();
BaseTransform::Ptr createFoldConstants();

} // namespace optimizer
} // namespace optree
12 changes: 12 additions & 0 deletions compiler/include/compiler/optree/base_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ struct Adaptor {
return op.operator bool();
}

operator const Operation::Ptr &() const {
return op;
}

operator Operation::Ptr() const {
return op;
}

Operation *const operator->() const {
return op.get();
}

const utils::SourceRef &ref() const {
return op->ref;
}
Expand Down
7 changes: 7 additions & 0 deletions compiler/include/compiler/optree/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,11 @@ Type::Ptr deduceTargetCastType(const Type::Ptr &outType, const Type::Ptr &inType
ArithCastOp insertNumericCastOp(const Type::Ptr &resultType, const Value::Ptr &value, Builder &builder,
utils::SourceRef &ref);

template <typename AdaptorType>
AdaptorType getValueOwnerAs(const Value::Ptr &value) {
if (value->owner.expired())
return {};
return value->owner.lock()->as<AdaptorType>();
}

} // namespace optree
34 changes: 30 additions & 4 deletions compiler/lib/backend/optree/optimizer/opt_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,34 @@
using namespace optree;
using namespace optree::optimizer;

namespace {

void notifyInsertRecursively(const Operation::Ptr &op, const OptBuilder::Notifier &notifier) {
for (const auto &nestedOp : op->body) {
notifyInsertRecursively(nestedOp, notifier);
notifier.onInsert(nestedOp);
}
}

} // namespace

void OptBuilder::insert(const Operation::Ptr &op) {
Builder::insert(op);
notifier.onInsert(op);
}

Operation::Ptr OptBuilder::clone(const Operation::Ptr &op) {
// TODO
notifier.onInsert(op);
auto newOp = op->clone();
notifyInsertRecursively(op, notifier);
insert(op);
return op;
}

void OptBuilder::erase(const Operation::Ptr &op) {
if (op->parent)
setInsertPointAfter(op);
for (const auto &nestedOp : op->body)
erase(nestedOp);
for (auto it = op->body.rbegin(); it != op->body.rend(); ++it)
erase(*it);
op->erase();
notifier.onErase(op);
}
Expand All @@ -32,3 +44,17 @@ void OptBuilder::update(const Operation::Ptr &op, const std::function<void()> &a
actor();
notifier.onUpdate(op);
}

void OptBuilder::replace(const Operation::Ptr &op, const Operation::Ptr &newOp) {
for (auto oldResultIt = op->results.begin(), newResultIt = newOp->results.begin();
oldResultIt != op->results.end() && newResultIt != newOp->results.end(); ++oldResultIt, ++newResultIt) {
const Value::Ptr &oldResult = *oldResultIt;
const Value::Ptr &newResult = *newResultIt;
for (const auto &use : oldResult->uses) {
auto user = use.lock();
update(user, [&] { user->operand(use.operandNumber) = newResult; });
}
newResult->uses.splice_after(newResult->uses.before_begin(), oldResult->uses);
}
erase(op);
}
65 changes: 53 additions & 12 deletions compiler/lib/backend/optree/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
using namespace optree;
using namespace optree::optimizer;

namespace {

class OperationSet {
std::vector<Operation::Ptr> data;
std::unordered_map<const Operation *, size_t> positions;
Expand Down Expand Up @@ -63,44 +65,83 @@ class OperationSet {
}
};

namespace {
class MutationTracker {
Operation *const trackedOp;
bool updatedTag;
bool erasedTag;

public:
MutationTracker(const MutationTracker &) = delete;
MutationTracker(MutationTracker &&) = delete;
~MutationTracker() = default;

explicit MutationTracker(const Operation::Ptr &trackedOp)
: trackedOp(trackedOp.get()), updatedTag(false), erasedTag(false){};

bool updated() const {
return updatedTag;
}
bool erased() const {
return erasedTag;
}

void raiseUpdated(const Operation::Ptr &op) {
if (op.get() == trackedOp)
updatedTag = true;
}
void raiseErased(const Operation::Ptr &op) {
if (op.get() == trackedOp)
erasedTag = true;
}
};

void pushToSet(const Operation::Ptr &root, OperationSet &ops) {
for (const auto &op : root->body)
pushToSet(op, ops);
ops.push(root);
}

} // namespace

Optimizer::Optimizer() : iterLimit(100U) {
transforms.emplace_back(createEraseUnusedOps());
}

void Optimizer::process(Program &program) const {
OperationSet ops;
bool mutated = false;
OptBuilder::Notifier makeNotifier(OperationSet &ops, bool &mutated, MutationTracker &tracker) {
OptBuilder::Notifier notifier;
notifier.onInsert = [&ops, &mutated](const Operation::Ptr &op) {
ops.push(op);
mutated = true;
};
notifier.onUpdate = [&ops, &mutated](const Operation::Ptr &op) {
notifier.onUpdate = [&ops, &mutated, &tracker](const Operation::Ptr &op) {
ops.push(op);
mutated = true;
tracker.raiseUpdated(op);
};
notifier.onErase = [&ops, &mutated](const Operation::Ptr &op) {
notifier.onErase = [&ops, &mutated, &tracker](const Operation::Ptr &op) {
ops.erase(op);
mutated = true;
tracker.raiseErased(op);
};
return notifier;
}

} // namespace

Optimizer::Optimizer() : iterLimit(100U) {
transforms.emplace_back(createEraseUnusedOps());
transforms.emplace_back(createFoldConstants());
}

void Optimizer::process(Program &program) const {
OperationSet ops;
bool mutated = false;
size_t iter = 0;
do {
mutated = false;
ops.clear();
pushToSet(program.root, ops);
while (!ops.empty()) {
Operation::Ptr op = ops.pop();
MutationTracker tracker(op);
auto notifier = makeNotifier(ops, mutated, tracker);
for (const auto &transform : transforms) {
if (tracker.erased())
break;
if (!transform->canRun(op))
continue;
OptBuilder builder(notifier);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "optimizer/transform.hpp"

#include <cstdint>
#include <memory>

#include "compiler/optree/adaptors.hpp"
#include "compiler/optree/definitions.hpp"
#include "compiler/optree/helpers.hpp"
#include "compiler/optree/operation.hpp"
#include "compiler/optree/types.hpp"

#include "optimizer/opt_builder.hpp"

using namespace optree;
using namespace optree::optimizer;

namespace {

struct FoldConstants : public Transform<ArithBinaryOp, ArithCastOp, LogicBinaryOp, LogicUnaryOp> {
using Transform::Transform;

static void foldArithBinaryOp(const ArithBinaryOp &op, OptBuilder &builder) {
auto lhsOp = getValueOwnerAs<ConstantOp>(op.lhs());
auto rhsOp = getValueOwnerAs<ConstantOp>(op.rhs());
if (!lhsOp || !rhsOp)
return;
auto type = op.result()->type;
if (type->is<IntegerType>()) {
int64_t folded = 0;
int64_t lhs = lhsOp.value().as<int64_t>();
int64_t rhs = rhsOp.value().as<int64_t>();
switch (op.kind()) {
case ArithBinOpKind::AddI:
folded = lhs + rhs;
break;
case ArithBinOpKind::MulI:
folded = lhs * rhs;
break;
default:
folded = -1; // TODO: extend
}
auto newOp = builder.insert<ConstantOp>(op.ref(), type, folded);
builder.replace(op, newOp);
return;
}
// TODO: support other types
}

void run(const Operation::Ptr &op, OptBuilder &builder) const override {
if (op->is<ArithBinaryOp>())
foldArithBinaryOp(op->as<ArithBinaryOp>(), builder);
}
};

} // namespace

namespace optree {
namespace optimizer {

BaseTransform::Ptr createFoldConstants() {
return std::make_shared<FoldConstants>();
}

} // namespace optimizer
} // namespace optree
2 changes: 1 addition & 1 deletion compiler/lib/optree/operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ Operation::SpecId Operation::getUnknownSpecId() {
}

void Operation::addOperand(const Value::Ptr &value) {
operands.emplace_back(value);
addUse(value, operands.size());
operands.emplace_back(value);
}

void Operation::insertOperand(size_t operandNumber, const Value::Ptr &value) {
Expand Down

0 comments on commit 48d51cb

Please sign in to comment.