Skip to content

Commit

Permalink
hl:lowertypes: Prune old code, use newer helpers that do type convers…
Browse files Browse the repository at this point in the history
…ion.
  • Loading branch information
lkorenc committed Nov 14, 2023
1 parent 9a6acbd commit 9a975c0
Showing 1 changed file with 20 additions and 137 deletions.
157 changes: 20 additions & 137 deletions lib/vast/Dialect/HighLevel/Transforms/HLLowerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ VAST_UNRELAX_WARNINGS

#include "vast/Conversion/TypeConverters/DataLayout.hpp"
#include "vast/Conversion/TypeConverters/HLToStd.hpp"
#include "vast/Conversion/TypeConverters/TypeConverter.hpp"
#include "vast/Conversion/TypeConverters/TypeConvertingPattern.hpp"

#include <algorithm>
#include <iostream>
Expand All @@ -36,160 +36,43 @@ namespace vast::hl
{
using type_converter_t = conv::tc::HLToStd;

struct LowerHighLevelOpType : mlir::ConversionPattern
{
using Base = mlir::ConversionPattern;
using Base::Base;

LowerHighLevelOpType(type_converter_t &tc, mcontext_t *mctx)
: Base(tc, mlir::Pattern::MatchAnyOpTypeTag{}, 1, mctx)
{}

template< typename attrs_list >
maybe_attr_t high_level_typed_attr_conversion(mlir::Attribute attr) const {
using attr_t = typename attrs_list::head;
using rest_t = typename attrs_list::tail;

if (auto typed = mlir::dyn_cast< attr_t >(attr)) {
if constexpr (std::same_as< attr_t, core::VoidAttr>) {
return Maybe(typed.getType())
.and_then([&] (auto type) {
return getTypeConverter()->convertType(type);
})
.and_then([&] (auto type) {
return core::VoidAttr::get(type.getContext(), type);
})
.template take_wrapped< maybe_attr_t >();
} else {
return Maybe(typed.getType())
.and_then([&] (auto type) {
return getTypeConverter()->convertType(type);
})
.and_then([&] (auto type) {
return attr_t::get(type, typed.getValue());
})
.template take_wrapped< maybe_attr_t >();
}
}

if constexpr (attrs_list::size != 1) {
return high_level_typed_attr_conversion< rest_t >(attr);
} else {
return std::nullopt;
}
}

auto convert_high_level_typed_attr() const {
return [&] (mlir::Attribute attr) {
return high_level_typed_attr_conversion< core::typed_attrs >(attr);
};
}
namespace pattern {

logical_result matchAndRewrite(
operation op, llvm::ArrayRef< mlir_value > ops,
conversion_rewriter &rewriter
) const override {
if (mlir::isa< FuncOp >(op)) {
return mlir::failure();
}
struct lower_type : conv::tc::type_converting_pattern< type_converter_t >
{
using parent = conv::tc::type_converting_pattern< type_converter_t >;

auto &tc = static_cast< type_converter_t & >(*getTypeConverter());

mlir::SmallVector< mlir_type > rty;
auto status = tc.convertTypes(op->getResultTypes(), rty);
// TODO(lukas): How to use `llvm::formatv` with `operation `?
VAST_CHECK(mlir::succeeded(status), "Was not able to type convert.");

// We just change type, no need to copy everything
auto lower_op = [&]() {
for (std::size_t i = 0; i < rty.size(); ++i) {
op->getResult(i).setType(rty[i]);
}

mlir::AttrTypeReplacer replacer;
replacer.addReplacement(conv::tc::convert_type_attr(tc));
replacer.addReplacement(conv::tc::convert_data_layout_attrs(tc));
replacer.addReplacement(convert_high_level_typed_attr());
replacer.recursivelyReplaceElementsIn(op, true /* replace attrs */);
};
// It has to be done in one "transaction".
rewriter.updateRootInPlace(op, lower_op);

return mlir::success();
}
};
lower_type(type_converter_t &tc, mcontext_t *mctx) : parent(tc, *mctx) {}

struct LowerFuncOpType : mlir::OpConversionPattern< FuncOp >
{
using Base = mlir::OpConversionPattern< FuncOp >;
using Base::Base;

using Base::getTypeConverter;

// As the reference how to lower functions, the `StandardToLLVM`
// conversion is used.
//
// But basically we need to copy the function with the converted
// function type -> copy body -> fix arguments of the entry region.
logical_result matchAndRewrite(
FuncOp fn, OpAdaptor adaptor, conversion_rewriter &rewriter
) const override {
auto fty = adaptor.getFunctionType();
auto &tc = static_cast< type_converter_t & >(*getTypeConverter());

conv::tc::signature_conversion_t sigconvert(fty.getNumInputs());
if (mlir::failed(tc.convertSignatureArgs(fty.getInputs(), sigconvert))) {
return mlir::failure();
logical_result matchAndRewrite(
operation op, mlir::ArrayRef< mlir::Value > ops,
conversion_rewriter &rewriter
) const override {
if (auto func_op = mlir::dyn_cast< hl::FuncOp >(op))
return replace(func_op, ops, rewriter);
return replace(op, ops, rewriter);
}

llvm::SmallVector< mlir_type, 1 > results;
if (mlir::failed(tc.convertTypes(fty.getResults(), results))) {
return mlir::failure();
}
};

auto params = sigconvert.getConvertedTypes();

auto new_type = core::FunctionType::get(
rewriter.getContext(), params, results, fty.isVarArg()
);

// TODO deal with function attribute types

rewriter.updateRootInPlace(fn, [&] {
fn.setType(new_type);
for (auto [ty, param] : llvm::zip(params, fn.getBody().getArguments())) {
param.setType(ty);
}
});

return mlir::success();
}
};
} // namespace pattern

struct HLLowerTypesPass : HLLowerTypesBase< HLLowerTypesPass >
{
void runOnOperation() override {
auto op = this->getOperation();
auto &mctx = this->getContext();

const auto &dl_analysis = this->getAnalysis< mlir::DataLayoutAnalysis >();
type_converter_t type_converter(dl_analysis.getAtOrAbove(op), mctx);

mlir::ConversionTarget trg(mctx);
// We want to check *everything* for presence of hl type
// that can be lowered.
auto is_legal = [](operation op)
{
auto is_hl = [](mlir_type t) -> bool { return isHighLevelType(t); };

return !has_type_somewhere(op, is_hl);
};
auto is_legal = type_converter.get_is_type_conversion_legal();
trg.markUnknownOpDynamicallyLegal(is_legal);

mlir::RewritePatternSet patterns(&mctx);
const auto &dl_analysis = this->getAnalysis< mlir::DataLayoutAnalysis >();
type_converter_t type_converter(dl_analysis.getAtOrAbove(op), mctx);

patterns.add< LowerHighLevelOpType, LowerFuncOpType >(
type_converter, patterns.getContext()
);
patterns.add< pattern::lower_type >(type_converter, patterns.getContext());

if (mlir::failed(mlir::applyPartialConversion(op, trg, std::move(patterns)))) {
return signalPassFailure();
Expand Down

0 comments on commit 9a975c0

Please sign in to comment.