diff --git a/include/vast/Conversion/Parser/Passes.hpp b/include/vast/Conversion/Parser/Passes.hpp index f43230cd56..e0f3c2f5eb 100644 --- a/include/vast/Conversion/Parser/Passes.hpp +++ b/include/vast/Conversion/Parser/Passes.hpp @@ -12,6 +12,7 @@ VAST_UNRELAX_WARNINGS namespace vast { std::unique_ptr< mlir::Pass > createHLToParserPass(); + std::unique_ptr< mlir::Pass > createParserReconcileCastsPass(); std::unique_ptr< mlir::Pass > createParserSourceToSarifPass(); // Generate the code for registering passes. diff --git a/include/vast/Conversion/Parser/Passes.td b/include/vast/Conversion/Parser/Passes.td index e22588dbf8..585fb54b82 100644 --- a/include/vast/Conversion/Parser/Passes.td +++ b/include/vast/Conversion/Parser/Passes.td @@ -19,4 +19,14 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> { ]; } +def ParserReconcileCasts : Pass<"vast-parser-reconcile-casts", "core::ModuleOp"> { + let summary = "Reconcile casts in parser dialect."; + let description = [{ WIP }]; + + let constructor = "vast::createParserReconcileCastsPass()"; + let dependentDialects = [ + "vast::pr::ParserDialect" + ]; +} + #endif // VAST_CONVERSION_PARSER_PASSES_TD diff --git a/include/vast/Dialect/Parser/Ops.td b/include/vast/Dialect/Parser/Ops.td index 30b47cee05..65bb3dfdbb 100644 --- a/include/vast/Dialect/Parser/Ops.td +++ b/include/vast/Dialect/Parser/Ops.td @@ -19,6 +19,8 @@ def Parser_Source { let summary = "Source of parsed data."; + let hasFolder = 1; + let assemblyFormat = [{ $arguments attr-dict `:` functional-type($arguments, $result) }]; @@ -31,6 +33,8 @@ def Paser_Sink { let summary = "Sink of parsed data."; + let hasFolder = 1; + let assemblyFormat = [{ $arguments attr-dict `:` functional-type($arguments, $result) }]; @@ -43,6 +47,8 @@ def Parser_Parse { let summary = "Parsing operation data."; + let hasFolder = 1; + let assemblyFormat = [{ $arguments attr-dict `:` functional-type($arguments, $result) }]; @@ -69,6 +75,8 @@ def Parse_MaybeParse { let summary = "Maybe parsing operation data."; + let hasFolder = 1; + let assemblyFormat = [{ $arguments attr-dict `:` functional-type($arguments, $result) }]; @@ -76,13 +84,15 @@ def Parse_MaybeParse def Parse_Cast : Parser_Op< "cast" > - , Arguments< (ins Parser_AnyDataType:$arguments) > + , Arguments< (ins Parser_AnyDataType:$operand) > , Results< (outs Parser_AnyDataType:$result) > { let summary = "Casting operation."; + let hasFolder = 1; + let assemblyFormat = [{ - $arguments attr-dict `:` functional-type($arguments, $result) + $operand attr-dict `:` functional-type($operand, $result) }]; } diff --git a/lib/vast/Conversion/Parser/CMakeLists.txt b/lib/vast/Conversion/Parser/CMakeLists.txt index 460396bd31..735d5ee512 100644 --- a/lib/vast/Conversion/Parser/CMakeLists.txt +++ b/lib/vast/Conversion/Parser/CMakeLists.txt @@ -2,4 +2,5 @@ add_vast_conversion_library(ParserConversionPasses ToParser.cpp + ReconcileCasts.cpp ) diff --git a/lib/vast/Conversion/Parser/ReconcileCasts.cpp b/lib/vast/Conversion/Parser/ReconcileCasts.cpp new file mode 100644 index 0000000000..b2526a78ca --- /dev/null +++ b/lib/vast/Conversion/Parser/ReconcileCasts.cpp @@ -0,0 +1,81 @@ +#include "vast/Util/Warnings.hpp" + +#include "vast/Conversion/Parser/Passes.hpp" + +VAST_RELAX_WARNINGS +#include +#include +#include +VAST_UNRELAX_WARNINGS + +#include "PassesDetails.hpp" +#include "Utils.hpp" + +#include "vast/Conversion/Common/Mixins.hpp" +#include "vast/Conversion/Common/Patterns.hpp" + +#include "vast/Dialect/Parser/Ops.hpp" +#include "vast/Dialect/Parser/Types.hpp" + +namespace vast::conv { + + namespace pattern { + + struct UnrealizedCastConversion + : one_to_one_conversion_pattern< mlir::UnrealizedConversionCastOp, pr::Cast > + { + using op_t = mlir::UnrealizedConversionCastOp; + using base = one_to_one_conversion_pattern< mlir::UnrealizedConversionCastOp, pr::Cast >; + using base::base; + + using adaptor_t = typename op_t::Adaptor; + + logical_result matchAndRewrite( + op_t op, adaptor_t adaptor, conversion_rewriter &rewriter + ) const override { + if (op.getNumOperands() != 1) { + return mlir::failure(); + } + + auto src = mlir::dyn_cast< mlir::UnrealizedConversionCastOp >(op.getOperand(0).getDefiningOp()); + + if (!src || src.getNumOperands() != 1) { + return mlir::failure(); + } + + if (pr::is_parser_type(src.getOperand(0).getType())) { + rewriter.replaceOpWithNewOp< pr::Cast >(op, op.getType(0), src.getOperand(0)); + return mlir::success(); + } + + return mlir::success(); + } + + static void legalize(base_conversion_config &cfg) { + cfg.target.addLegalOp< pr::Cast >(); + } + }; + + using cast_conversions = util::type_list< UnrealizedCastConversion >; + + } // namespace pattern + + struct ParserReconcileCastsPass + : ConversionPassMixin< ParserReconcileCastsPass, ParserReconcileCastsBase > + { + using base = ConversionPassMixin< ParserReconcileCastsPass, ParserReconcileCastsBase >; + + static conversion_target create_conversion_target(mcontext_t &mctx) { + return conversion_target(mctx); + } + + static void populate_conversions(auto &cfg) { + base::populate_conversions< pattern::cast_conversions >(cfg); + } + }; + +} // namespace vast::conv + +std::unique_ptr< mlir::Pass > vast::createParserReconcileCastsPass() { + return std::make_unique< vast::conv::ParserReconcileCastsPass >(); +} diff --git a/lib/vast/Conversion/Parser/ToParser.cpp b/lib/vast/Conversion/Parser/ToParser.cpp index 76365065b4..cc4a9aa83d 100644 --- a/lib/vast/Conversion/Parser/ToParser.cpp +++ b/lib/vast/Conversion/Parser/ToParser.cpp @@ -14,6 +14,7 @@ VAST_RELAX_WARNINGS VAST_UNRELAX_WARNINGS #include "PassesDetails.hpp" +#include "Utils.hpp" #include "vast/Conversion/Common/Mixins.hpp" #include "vast/Conversion/Common/Patterns.hpp" @@ -32,29 +33,13 @@ VAST_UNRELAX_WARNINGS namespace vast::conv { - enum class data_type { data, nodata, maybedata }; - - mlir_type to_mlir_type(data_type type, mcontext_t *mctx) { - switch (type) { - case data_type::data: return pr::DataType::get(mctx); - case data_type::nodata: return pr::NoDataType::get(mctx); - case data_type::maybedata: return pr::MaybeDataType::get(mctx); - } - } - - template< typename... Ts > - auto is_one_of(mlir_type ty) { return (mlir::isa< Ts >(ty) || ...); } - - bool is_parser_type(mlir_type ty) { - return is_one_of< pr::DataType, pr::NoDataType, pr::MaybeDataType >(ty); - } enum class function_category { sink, source, parser, nonparser }; struct function_model { - data_type return_type; - std::vector< data_type > arguments; + pr::data_type return_type; + std::vector< pr::data_type > arguments; function_category category; bool is_sink() const { return category == function_category::sink; } @@ -92,7 +77,7 @@ namespace vast::conv { } // namespace vast::conv -LLVM_YAML_IS_SEQUENCE_VECTOR(vast::conv::data_type); +LLVM_YAML_IS_SEQUENCE_VECTOR(vast::pr::data_type); LLVM_YAML_IS_SEQUENCE_VECTOR(vast::conv::named_function_model); using llvm::yaml::IO; @@ -100,12 +85,12 @@ using llvm::yaml::MappingTraits; using llvm::yaml::ScalarEnumerationTraits; template<> -struct ScalarEnumerationTraits< vast::conv::data_type > +struct ScalarEnumerationTraits< vast::pr::data_type > { - static void enumeration(IO &io, vast::conv::data_type &value) { - io.enumCase(value, "data", vast::conv::data_type::data); - io.enumCase(value, "nodata", vast::conv::data_type::nodata); - io.enumCase(value, "maybedata", vast::conv::data_type::maybedata); + static void enumeration(IO &io, vast::pr::data_type &value) { + io.enumCase(value, "data", vast::pr::data_type::data); + io.enumCase(value, "nodata", vast::pr::data_type::nodata); + io.enumCase(value, "maybedata", vast::pr::data_type::maybedata); } }; @@ -254,12 +239,12 @@ namespace vast::conv { for (auto val : values) { out.push_back([&] () -> mlir_value { if (auto cast = mlir::dyn_cast< mlir::UnrealizedConversionCastOp >(val.getDefiningOp())) { - if (is_parser_type(cast.getOperand(0).getType())) { + if (pr::is_parser_type(cast.getOperand(0).getType())) { return cast.getOperand(0); } } - if (!is_parser_type(val.getType())) { + if (!pr::is_parser_type(val.getType())) { return rewriter.template create< mlir::UnrealizedConversionCastOp >( val.getLoc(), pr::MaybeDataType::get(val.getContext()), val ).getResult(0); @@ -327,7 +312,7 @@ namespace vast::conv { logical_result matchAndRewrite( op_t op, adaptor_t adaptor, conversion_rewriter &rewriter ) const override { - auto rty = to_mlir_type(data_type::nodata, rewriter.getContext()); + auto rty = to_mlir_type(pr::data_type::nodata, rewriter.getContext()); auto args = convert_value_types(adaptor.getOperands(), rty, rewriter); auto converted = rewriter.create< pr::NoParse >(op.getLoc(), rty, args); rewriter.replaceOpWithNewOp< mlir::UnrealizedConversionCastOp >( @@ -476,7 +461,7 @@ namespace vast::conv { op_t op, adaptor_t adaptor, conversion_rewriter &rewriter ) const override { auto rewrite = [&] (auto ty) { - ty = is_parser_type(ty) ? ty : pr::MaybeDataType::get(rewriter.getContext()); + ty = pr::is_parser_type(ty) ? ty : pr::MaybeDataType::get(rewriter.getContext()); auto converted = rewriter.create< pr::Ref >(op.getLoc(), ty, op.getName()); rewriter.replaceOpWithNewOp< mlir::UnrealizedConversionCastOp >( op, op.getType(), converted->getResult(0) @@ -526,7 +511,7 @@ namespace vast::conv { cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >(); cfg.target.addDynamicallyLegalOp< op_t >([](op_t op) { for (auto ty : op.getResult().getType()) { - if (!is_parser_type(ty)) { + if (!pr::is_parser_type(ty)) { return false; } } @@ -567,6 +552,24 @@ namespace vast::conv { } }; + struct AssignConversion + : one_to_one_conversion_pattern< hl::AssignOp, pr::Assign > + { + using op_t = hl::AssignOp; + using base = one_to_one_conversion_pattern< op_t, pr::Assign >; + using base::base; + + using adaptor_t = typename op_t::Adaptor; + + logical_result matchAndRewrite( + op_t op, adaptor_t adaptor, conversion_rewriter &rewriter + ) const override { + auto args = realized_operand_values(adaptor.getOperands(), rewriter); + rewriter.replaceOpWithNewOp< pr::Assign >(op, std::vector< mlir_type >(), args); + return mlir::success(); + } + }; + struct ExprConversion : parser_conversion_pattern_base< hl::ExprOp > { @@ -611,7 +614,7 @@ namespace vast::conv { logical_result matchAndRewrite( op_t op, adaptor_t adaptor, conversion_rewriter &rewriter ) const override { - auto maybe = to_mlir_type(data_type::maybedata, rewriter.getContext()); + auto maybe = to_mlir_type(pr::data_type::maybedata, rewriter.getContext()); /* auto decl = */ rewriter.create< pr::Decl >(op.getLoc(), op.getSymName(), maybe); if (auto &init_region = op.getInitializer(); !init_region.empty()) { @@ -658,6 +661,7 @@ namespace vast::conv { ToNoParse< hl::MulFOp >, ToNoParse< hl::DivFOp >, ToNoParse< hl::RemFOp >, // Other operations + AssignConversion, ExprConversion, FuncConversion, ParamConversion, diff --git a/lib/vast/Conversion/Parser/Utils.hpp b/lib/vast/Conversion/Parser/Utils.hpp new file mode 100644 index 0000000000..7e3a182039 --- /dev/null +++ b/lib/vast/Conversion/Parser/Utils.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "vast/Util/Common.hpp" + +#include "vast/Dialect/Parser/Types.hpp" + +namespace vast::pr { + + enum class data_type { data, nodata, maybedata }; + + static inline mlir_type to_mlir_type(data_type type, mcontext_t *mctx) { + switch (type) { + case data_type::data: return pr::DataType::get(mctx); + case data_type::nodata: return pr::NoDataType::get(mctx); + case data_type::maybedata: return pr::MaybeDataType::get(mctx); + } + } + + template< typename... Ts > + auto is_one_of(mlir_type ty) { return (mlir::isa< Ts >(ty) || ...); } + + static inline bool is_parser_type(mlir_type ty) { + return is_one_of< pr::DataType, pr::NoDataType, pr::MaybeDataType >(ty); + } + +} // namespace vast::pr diff --git a/lib/vast/Dialect/Parser/Ops.cpp b/lib/vast/Dialect/Parser/Ops.cpp index f9b5ef7639..3d8d312458 100644 --- a/lib/vast/Dialect/Parser/Ops.cpp +++ b/lib/vast/Dialect/Parser/Ops.cpp @@ -15,20 +15,53 @@ using namespace vast::pr; namespace vast::pr { - using fold_result_t = ::llvm::SmallVectorImpl< ::mlir::OpFoldResult >; + using fold_result = ::mlir::OpFoldResult; + using fold_results = ::llvm::SmallVectorImpl< fold_result >; - logical_result NoParse::fold(FoldAdaptor adaptor, fold_result_t &results) { - auto change = mlir::failure(); - auto op = getOperation(); + template< typename op_t > + logical_result forward_same_operation( + op_t op, auto adaptor, fold_results &results + ) { + if (op.getNumOperands() == 1 && op.getNumResults() == 1) { + if (auto operand = op.getOperand(0); mlir::isa< op_t >(operand.getDefiningOp())) { + if (operand.getType() == op->getOpResult(0).getType()) { + results.push_back(operand); + return mlir::success(); + } + } + } + + return mlir::failure(); + } + + logical_result Source::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } + + logical_result Sink::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } + + logical_result Parse::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } + + logical_result NoParse::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } + + logical_result MaybeParse::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } - for (auto [idx, operand] : llvm::reverse(llvm::enumerate(getOperands()))) { - if (auto noparse = mlir::dyn_cast< NoParse >(operand.getDefiningOp())) { - op->eraseOperand(idx); - change = mlir::success(); + fold_result Cast::fold(FoldAdaptor adaptor) { + if (auto operand = getOperand(); mlir::isa< Cast >(operand.getDefiningOp())) { + if (operand.getType() == getType()) { + return operand; } } - return change; + return {}; } } // namespace vast::pr