Skip to content

Commit

Permalink
conv:irstollvm: Improve codestyle.
Browse files Browse the repository at this point in the history
  • Loading branch information
lkorenc committed Nov 2, 2023
1 parent 4df412b commit 11ea05c
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 138 deletions.
49 changes: 16 additions & 33 deletions lib/vast/Conversion/Common/Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
#include "vast/Conversion/Common/Patterns.hpp"
#include "vast/Conversion/TypeConverters/LLVMTypeConverter.hpp"

namespace vast::conv::irstollvm
{
namespace vast::conv::irstollvm {
// I would consider to just use the entire namespace, everything
// has (unfortunately) prefixed name with `LLVM` anyway.
namespace LLVM = mlir::LLVM;
Expand Down Expand Up @@ -48,14 +47,11 @@ namespace vast::conv::irstollvm
}

auto mk_index(auto loc, std::size_t idx, auto &rewriter) const
-> mlir::LLVM::ConstantOp
{
-> mlir::LLVM::ConstantOp {
auto index_type = convert(rewriter.getIndexType());
return rewriter.template create< mlir::LLVM::ConstantOp >(
loc,
index_type,
rewriter.getIntegerAttr(index_type, idx));

loc, index_type, rewriter.getIntegerAttr(index_type, idx)
);
}
};

Expand All @@ -67,9 +63,8 @@ namespace vast::conv::irstollvm

using adaptor_t = typename src_t::Adaptor;

mlir::LogicalResult matchAndRewrite(
src_t op, adaptor_t ops, conversion_rewriter &rewriter
) const override {
auto matchAndRewrite(src_t op, adaptor_t ops, conversion_rewriter &rewriter)
-> const logical_result override {
auto target_ty = this->type_converter().convert_type_to_type(op.getType());
auto new_op = rewriter.create< trg_t >(op.getLoc(), *target_ty, ops.getOperands());
rewriter.replaceOp(op, new_op);
Expand All @@ -86,9 +81,8 @@ namespace vast::conv::irstollvm

using adaptor_t = typename src_t::Adaptor;

mlir::LogicalResult matchAndRewrite(
src_t op, adaptor_t ops, conversion_rewriter &rewriter
) const override {
auto matchAndRewrite(src_t op, adaptor_t ops, conversion_rewriter &rewriter)
-> const logical_result override {
rewriter.replaceOp(op, ops.getOperands());
return mlir::success();
}
Expand All @@ -102,38 +96,27 @@ namespace vast::conv::irstollvm

using adaptor_t = typename src_t::Adaptor;

mlir::LogicalResult matchAndRewrite(
src_t op, adaptor_t ops, conversion_rewriter &rewriter
) const override {
auto matchAndRewrite(src_t op, adaptor_t ops, conversion_rewriter &rewriter)
-> const logical_result override {
rewriter.eraseOp(op);
return mlir::success();
}

static void legalize(auto &trg) { trg.template addIllegalOp< src_t >(); }
};

static auto get_is_illegal(auto &tc)
{
return [&](mlir_type type)
{
return !tc.isLegal(type);
};
static auto get_is_illegal(auto &tc) {
return [&](mlir_type type) { return !tc.isLegal(type); };
}

template< typename T, typename type_converter >
auto get_has_only_legal_types(type_converter &tc)
{
return [&](T op) -> bool
{
return !has_type_somewhere(op, get_is_illegal(tc));
};
auto get_has_only_legal_types(type_converter &tc) {
return [&](T op) -> bool { return !has_type_somewhere(op, get_is_illegal(tc)); };
}

template< typename T, typename type_converter >
auto get_has_legal_return_type(type_converter &tc)
{
return [&](T op) -> bool
{
auto get_has_legal_return_type(type_converter &tc) {
return [&](T op) -> bool {
return !contains_subtype(op.getResult().getType(), get_is_illegal(tc));
};
}
Expand Down
167 changes: 62 additions & 105 deletions lib/vast/Conversion/Common/IRsToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@

VAST_RELAX_WARNINGS
#include <mlir/Analysis/DataLayoutAnalysis.h>
#include <mlir/Conversion/LLVMCommon/Pattern.h>
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
#include <mlir/Conversion/LLVMCommon/Pattern.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>

#include <llvm/ADT/APFloat.h>
VAST_UNRELAX_WARNINGS

#include "../PassesDetails.hpp"

#include "vast/Dialect/HighLevel/HighLevelAttributes.hpp"
#include "vast/Dialect/HighLevel/HighLevelTypes.hpp"
#include "vast/Dialect/HighLevel/HighLevelOps.hpp"
#include "vast/Dialect/HighLevel/HighLevelTypes.hpp"

#include "vast/Dialect/Core/CoreAttributes.hpp"

Expand All @@ -27,10 +27,10 @@ VAST_UNRELAX_WARNINGS
#include "vast/Dialect/Core/CoreOps.hpp"
#include "vast/Dialect/Core/TypeTraits.hpp"

#include "vast/Util/Common.hpp"
#include "vast/Util/Symbols.hpp"
#include "vast/Util/Terminator.hpp"
#include "vast/Util/TypeList.hpp"
#include "vast/Util/Common.hpp"

#include "vast/Conversion/Common/Passes.hpp"
#include "vast/Conversion/TypeConverters/LLVMTypeConverter.hpp"
Expand All @@ -53,15 +53,12 @@ namespace vast::conv::irstollvm
using op_t = ll::StructGEPOp;

logical_result matchAndRewrite(
op_t op, typename op_t::Adaptor ops,
conversion_rewriter &rewriter) const override
{
std::vector< mlir::LLVM::GEPArg > indices { 0ul, ops.getIdx() };
op_t op, typename op_t::Adaptor ops, conversion_rewriter &rewriter
) const override {
std::vector< mlir::LLVM::GEPArg > indices{ 0ul, ops.getIdx() };
auto gep = rewriter.create< mlir::LLVM::GEPOp >(
op.getLoc(),
convert(op.getType()),
ops.getRecord(),
indices);
op.getLoc(), convert(op.getType()), ops.getRecord(), indices
);

rewriter.replaceOp(op, gep);
return mlir::success();
Expand All @@ -75,84 +72,65 @@ namespace vast::conv::irstollvm

using op_t = ll::Extract;

std::size_t to_number(mlir::TypedAttr attr) const
{
std::size_t to_number(mlir::TypedAttr attr) const {
auto int_attr = mlir::dyn_cast< mlir::IntegerAttr >(attr);
VAST_CHECK(int_attr, "Cannot convert {0} to `mlir::IntegerAttr`.", attr);

return int_attr.getUInt();
}

bool is_consistent(op_t op) const
{
auto size = to_number(op.getTo()) - to_number(op.getFrom()) + 1;
const auto &dl = this->type_converter().getDataLayoutAnalysis()
->getAtOrAbove(op);
bool is_consistent(op_t op) const {
auto size = to_number(op.getTo()) - to_number(op.getFrom()) + 1;
const auto &dl = this->type_converter().getDataLayoutAnalysis()->getAtOrAbove(op);
auto target_bw = dl.getTypeSizeInBits(convert(op.getType()));

return target_bw != size;
}

logical_result matchAndRewrite(
op_t op, typename op_t::Adaptor ops,
conversion_rewriter &rewriter) const override
{
op_t op, typename op_t::Adaptor ops, conversion_rewriter &rewriter
) const override {
auto loc = op.getLoc();

auto value = [&]() -> mlir::Value
{
auto value = [&]() -> mlir::Value {
auto arg = ops.getArg();
if (auto ptr = mlir::dyn_cast< mlir::LLVM::LLVMPointerType >(arg.getType()))
{
if (auto ptr = mlir::dyn_cast< mlir::LLVM::LLVMPointerType >(arg.getType())) {
return rewriter.create< mlir::LLVM::LoadOp >(
op.getLoc(),
ptr.getElementType(),
arg);
op.getLoc(), ptr.getElementType(), arg
);
}
return arg;
}();

auto i8_type = mlir::IntegerType::get(getContext(), 8);

auto extract = [&](auto from, auto pos) -> mlir::Value
{
auto extract = [&](auto from, auto pos) -> mlir::Value {
auto shift = rewriter.create< mlir::LLVM::LShrOp >(
loc,
value,
iN(rewriter, loc, value.getType(), from));
auto trunc = rewriter.create< mlir::LLVM::TruncOp >(
loc,
i8_type,
shift);
auto zext = rewriter.create< mlir::LLVM::ZExtOp >(
loc,
convert(op.getType()),
trunc);

if (pos == 0)
loc, value, iN(rewriter, loc, value.getType(), from)
);
auto trunc = rewriter.create< mlir::LLVM::TruncOp >(loc, i8_type, shift);
auto zext =
rewriter.create< mlir::LLVM::ZExtOp >(loc, convert(op.getType()), trunc);

if (pos == 0) {
return zext;
}

return rewriter.create< mlir::LLVM::ShlOp >(
loc,
convert(op.getType()),
zext,
iN(rewriter, loc, convert(op.getType()), pos));

loc, convert(op.getType()), zext,
iN(rewriter, loc, convert(op.getType()), pos)
);
};

mlir::Value head = iN(rewriter, loc, convert(op.getType()), 0);
// TODO(conv:abi): It may be possible we don't need this in the end and plain
// `shift & trunc` will work. I am leaving it here for now as it
// seems to work.
for (std::size_t i = 0; i < op.size() / 8; ++i)
{
for (std::size_t i = 0; i < op.size() / 8; ++i) {
auto offset = op.from() + i * 8;
auto byte = extract(offset, i * 8);
head = rewriter.create< mlir::LLVM::OrOp >(
loc,
convert(op.getType()),
byte,
head);
auto byte = extract(offset, i * 8);
head =
rewriter.create< mlir::LLVM::OrOp >(loc, convert(op.getType()), byte, head);
}
rewriter.replaceOp(op, { head });
return mlir::success();
Expand All @@ -166,44 +144,33 @@ namespace vast::conv::irstollvm

using op_t = ll::Concat;

std::size_t bw(operation op) const
{
std::size_t bw(operation op) const {
VAST_ASSERT(op->getNumResults() == 1);
const auto &dl = this->type_converter().getDataLayoutAnalysis()
->getAtOrAbove(op);
const auto &dl = this->type_converter().getDataLayoutAnalysis()->getAtOrAbove(op);
return dl.getTypeSizeInBits(convert(op->getResult(0).getType()));
}

logical_result matchAndRewrite(
op_t op, typename op_t::Adaptor ops,
conversion_rewriter &rewriter) const override
{
op_t op, typename op_t::Adaptor ops, conversion_rewriter &rewriter
) const override {
auto loc = op.getLoc();

auto resize = [&](auto w) -> mlir::Value
{
auto resize = [&](auto w) -> mlir::Value {
auto trg_type = convert(op.getType());
if (w.getType() == trg_type)
if (w.getType() == trg_type) {
return w;
return rewriter.create< mlir::LLVM::ZExtOp >(
loc,
trg_type,
w);
}
return rewriter.create< mlir::LLVM::ZExtOp >(loc, trg_type, w);
};
mlir::Value head = resize(ops.getOperands()[0]);

std::size_t start = bw(ops.getOperands()[0].getDefiningOp());
for (std::size_t i = 1; i < ops.getOperands().size(); ++i)
{
auto full = resize(ops.getOperands()[i]);
for (std::size_t i = 1; i < ops.getOperands().size(); ++i) {
auto full = resize(ops.getOperands()[i]);
auto shifted = rewriter.create< mlir::LLVM::ShlOp >(
loc,
full,
mk_index(loc, start, rewriter));
head = rewriter.create< mlir::LLVM::OrOp >(
loc,
head,
shifted->getResult(0));
loc, full, mk_index(loc, start, rewriter)
);
head = rewriter.create< mlir::LLVM::OrOp >(loc, head, shifted->getResult(0));

start += bw(ops.getOperands()[i].getDefiningOp());
}
Expand All @@ -226,9 +193,8 @@ namespace vast::conv::irstollvm
using base::base;

logical_result matchAndRewrite(
Op unit_op, typename Op::Adaptor ops,
conversion_rewriter &rewriter) const override
{
Op unit_op, typename Op::Adaptor ops, conversion_rewriter &rewriter
) const override {
auto parent = unit_op.getBody().getParentRegion();
rewriter.inlineRegionBefore(unit_op.getBody(), *parent, parent->end());

Expand All @@ -239,7 +205,6 @@ namespace vast::conv::irstollvm
rewriter.eraseOp(unit_op);
return logical_result::success();
}

};

template< typename op_t >
Expand All @@ -250,37 +215,29 @@ namespace vast::conv::irstollvm

using Op = op_t;

mlir::Block *start_block(Op op) const override
{
return &*op.getBody().begin();
}
mlir::Block *start_block(Op op) const override { return &*op.getBody().begin(); }

auto matchAndRewrite(Op op, typename Op::Adaptor ops,
conversion_rewriter &rewriter) const
-> logical_result override
{
auto
matchAndRewrite(Op op, typename Op::Adaptor ops, conversion_rewriter &rewriter) const
-> logical_result override {
// If we do not have any branching inside, we can just "inline"
// the op.
if (op.getBody().hasOneBlock())
if (op.getBody().hasOneBlock()) {
return base::handle_singleblock(op, ops, rewriter);
}

return base::handle_multiblock(op, ops, rewriter);
}
};

using label_stmt = hl_scopelike< hl::LabelStmt >;
using scope_op = hl_scopelike< core::ScopeOp >;
using scope_op = hl_scopelike< core::ScopeOp >;

using label_patterns = util::type_list<
erase_pattern< hl::LabelDeclOp >,
label_stmt
>;
using label_patterns = util::type_list< erase_pattern< hl::LabelDeclOp >, label_stmt >;

// TODO(conv): Figure out if these can be somehow unified.
using inline_region_from_op_conversions = util::type_list<
inline_region_from_op< hl::TranslationUnitOp >,
scope_op
>;
using inline_region_from_op_conversions =
util::type_list< inline_region_from_op< hl::TranslationUnitOp >, scope_op >;

struct subscript : base_pattern< hl::SubscriptOp >
{
Expand Down

0 comments on commit 11ea05c

Please sign in to comment.