From 45b08ba6e52c4c976a08d5375303da5f05dd98dd Mon Sep 17 00:00:00 2001 From: Maksim Vlasov Date: Sun, 12 May 2024 21:18:01 +0300 Subject: [PATCH] Solidify native types and add type deduction for attributes (#155) Initialize ConstantOp using *a template function* instead of multiple overloads (therefore it should be checked to hold one of supported native types during the optree verification stage). Introduce new helpers: `typeOneOf`, `canHoldAlternative` for variants, `advanceEarly` overload for ranges. --- compiler/include/compiler/optree/adaptors.hpp | 10 +++--- .../include/compiler/optree/attribute.hpp | 31 ++++++++++++++++--- .../include/compiler/optree/declarative.hpp | 13 +------- compiler/include/compiler/optree/types.hpp | 14 +++++++++ compiler/include/compiler/utils/helpers.hpp | 20 ++++++++++++ .../optree_to_llvmir/llvmir_generator.cpp | 6 ++-- compiler/lib/frontend/converter/converter.cpp | 11 ++++--- compiler/lib/optree/adaptors.cpp | 21 ------------- compiler/lib/optree/attribute.cpp | 16 +++++----- .../optree/optimizer/erase_unused_ops.cpp | 6 ++-- compiler/tests/optree/declarative.cpp | 14 ++++----- 11 files changed, 94 insertions(+), 68 deletions(-) diff --git a/compiler/include/compiler/optree/adaptors.hpp b/compiler/include/compiler/optree/adaptors.hpp index 5eff2c52..cf22931f 100644 --- a/compiler/include/compiler/optree/adaptors.hpp +++ b/compiler/include/compiler/optree/adaptors.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -71,10 +70,11 @@ struct ReturnOp : Adaptor { struct ConstantOp : Adaptor { OPTREE_ADAPTOR_HELPER(Adaptor, "Constant") - void init(const Type::Ptr &type, int64_t value); - void init(const Type::Ptr &type, bool value); - void init(const Type::Ptr &type, double value); - void init(const Type::Ptr &type, const std::string &value); + template + void init(const Type::Ptr &type, const T &value) { + op->results.emplace_back(Value::make(type, op)); + op->addAttr(value); + } OPTREE_ADAPTOR_ATTRIBUTE_OPAQUE(value, 0) OPTREE_ADAPTOR_RESULT(result, 0) diff --git a/compiler/include/compiler/optree/attribute.hpp b/compiler/include/compiler/optree/attribute.hpp index 1f9155b5..e0ec078a 100644 --- a/compiler/include/compiler/optree/attribute.hpp +++ b/compiler/include/compiler/optree/attribute.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -8,13 +7,19 @@ #include "compiler/optree/definitions.hpp" #include "compiler/optree/types.hpp" +#include "compiler/utils/helpers.hpp" namespace optree { struct Attribute { - std::variant - storage; + using Storage = std::variant< + // + std::monostate, NativeInt, NativeBool, NativeFloat, NativeStr, Type::Ptr, ArithBinOpKind, ArithCastOpKind, + LogicBinOpKind, LogicUnaryOpKind + // + >; + + Storage storage; Attribute() = default; Attribute(const Attribute &) = default; @@ -25,7 +30,18 @@ struct Attribute { Attribute &operator=(Attribute &&) = default; template - explicit Attribute(const VariantType &value) : storage(value){}; + explicit Attribute(const VariantType &value) { + if constexpr (Attribute::canHold()) + set(value); + else if constexpr (std::is_integral_v) + set(static_cast(value)); + else if constexpr (std::is_floating_point_v) + set(static_cast(value)); + else if constexpr (std::is_constructible_v) + set(std::string(value)); + else + throw std::bad_variant_access(); + } template bool is() const noexcept { @@ -72,6 +88,11 @@ struct Attribute { } void dump(std::ostream &stream) const; + + template + static constexpr bool canHold() { + return utils::canHoldAlternative; + } }; } // namespace optree diff --git a/compiler/include/compiler/optree/declarative.hpp b/compiler/include/compiler/optree/declarative.hpp index 26980152..ffcec02d 100644 --- a/compiler/include/compiler/optree/declarative.hpp +++ b/compiler/include/compiler/optree/declarative.hpp @@ -1,11 +1,9 @@ #pragma once #include -#include #include #include #include -#include #include #include @@ -132,16 +130,7 @@ class DeclarativeModule { template DeclarativeModule &attr(const T &value) { - if constexpr (std::is_same_v, bool>) - current->addAttr(value); - else if constexpr (std::is_same_v, const char *>) - current->addAttr(std::string(value)); - else if constexpr (std::is_integral_v) - current->addAttr(static_cast(value)); - else if constexpr (std::is_floating_point_v) - current->addAttr(static_cast(value)); - else - current->addAttr(value); + current->addAttr(value); return *this; } diff --git a/compiler/include/compiler/optree/types.hpp b/compiler/include/compiler/optree/types.hpp index ea4ade78..2fd72394 100644 --- a/compiler/include/compiler/optree/types.hpp +++ b/compiler/include/compiler/optree/types.hpp @@ -1,7 +1,9 @@ #pragma once +#include #include #include +#include #include #include @@ -65,6 +67,7 @@ struct NoneType : public Type { struct IntegerType : public Type { using Ptr = std::shared_ptr; + using NativeType = int64_t; const unsigned width; @@ -79,6 +82,7 @@ struct IntegerType : public Type { struct BoolType : public IntegerType { using Ptr = std::shared_ptr; + using NativeType = bool; static constinit const unsigned intWidth = 8U; @@ -90,6 +94,7 @@ struct BoolType : public IntegerType { struct FloatType : public Type { using Ptr = std::shared_ptr; + using NativeType = double; const unsigned width; @@ -104,6 +109,7 @@ struct FloatType : public Type { struct StrType : public Type { using Ptr = std::shared_ptr; + using NativeType = std::string; const unsigned charWidth; @@ -167,4 +173,12 @@ struct TypeStorage { static StrType::Ptr strType(unsigned charWidth = 8U); }; +template +using NativeType = typename ConcreteType::NativeType; + +using NativeInt = NativeType; +using NativeBool = NativeType; +using NativeFloat = NativeType; +using NativeStr = NativeType; + } // namespace optree diff --git a/compiler/include/compiler/utils/helpers.hpp b/compiler/include/compiler/utils/helpers.hpp index e36bab25..02371c14 100644 --- a/compiler/include/compiler/utils/helpers.hpp +++ b/compiler/include/compiler/utils/helpers.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #if defined(_MSC_VER) && !defined(__clang__) // MSVC #define COMPILER_UNREACHABLE(MESSAGE) \ @@ -143,8 +144,21 @@ class ZippedRanges { } }; +template +struct CanHoldAlternative; + +template +struct CanHoldAlternative> + : std::disjunction...> {}; + } // namespace detail +template +constexpr bool typeOneOf = std::disjunction_v...>; + +template +constexpr bool canHoldAlternative = detail::CanHoldAlternative, VariantType>::value; + template void interleave(const Range &values, const UnaryPred &printValue, const NullaryPred &printSep) { if (std::empty(values)) @@ -167,6 +181,12 @@ auto advanceEarly(Iterator begin, Iterator end) { return detail::AdvanceEarlyRange(begin, end); } +template +auto advanceEarly(Range &&range) { + using Iterator = decltype(std::begin(range)); + return detail::AdvanceEarlyRange(std::begin(range), std::end(range)); +} + template auto reversed(Range &&range) { return detail::ReversedRange(range); diff --git a/compiler/lib/codegen/optree_to_llvmir/llvmir_generator.cpp b/compiler/lib/codegen/optree_to_llvmir/llvmir_generator.cpp index 01039d7f..d34545cb 100644 --- a/compiler/lib/codegen/optree_to_llvmir/llvmir_generator.cpp +++ b/compiler/lib/codegen/optree_to_llvmir/llvmir_generator.cpp @@ -216,14 +216,14 @@ void LLVMIRGenerator::visit(const ConstantOp &op) { if (type->is()) return result(llvm::ConstantInt::get(convertType(type), op.value().as())); if (type->is()) { - auto num = op.value().as(); + auto num = static_cast(op.value().as()); auto *value = llvm::ConstantInt::get(convertType(type), reinterpret_cast(num), /*IsSigned*/ true); return result(value); } if (type->is()) - return result(llvm::ConstantFP::get(convertType(type), op.value().as())); + return result(llvm::ConstantFP::get(convertType(type), static_cast(op.value().as()))); if (type->is()) - return result(getGlobalString(op.value().as())); + return result(getGlobalString(op.value().as())); COMPILER_UNREACHABLE("unexpected result type in ConstantOp"); } diff --git a/compiler/lib/frontend/converter/converter.cpp b/compiler/lib/frontend/converter/converter.cpp index 0d56af68..c3db9963 100644 --- a/compiler/lib/frontend/converter/converter.cpp +++ b/compiler/lib/frontend/converter/converter.cpp @@ -227,20 +227,23 @@ Value::Ptr visitExpression(const Node::Ptr &node, ConverterContext &ctx) { } Value::Ptr visitIntegerLiteralValue(const Node::Ptr &node, ConverterContext &ctx) { - auto value = static_cast(node->intNum()); + auto value = static_cast(node->intNum()); return ctx.insert(node->ref, TypeStorage::integerType(), value).result(); } Value::Ptr visitBooleanLiteralValue(const Node::Ptr &node, ConverterContext &ctx) { - return ctx.insert(node->ref, TypeStorage::boolType(), node->boolean()).result(); + auto value = static_cast(node->boolean()); + return ctx.insert(node->ref, TypeStorage::boolType(), value).result(); } Value::Ptr visitFloatingPointLiteralValue(const Node::Ptr &node, ConverterContext &ctx) { - return ctx.insert(node->ref, TypeStorage::floatType(), node->fpNum()).result(); + auto value = static_cast(node->fpNum()); + return ctx.insert(node->ref, TypeStorage::floatType(), value).result(); } Value::Ptr visitStringLiteralValue(const Node::Ptr &node, ConverterContext &ctx) { - return ctx.insert(node->ref, TypeStorage::strType(), node->str()).result(); + auto value = static_cast(node->str()); + return ctx.insert(node->ref, TypeStorage::strType(), value).result(); } Value::Ptr visitBinaryOperation(const Node::Ptr &node, ConverterContext &ctx) { diff --git a/compiler/lib/optree/adaptors.cpp b/compiler/lib/optree/adaptors.cpp index 4d3bf1dc..e8a47e37 100644 --- a/compiler/lib/optree/adaptors.cpp +++ b/compiler/lib/optree/adaptors.cpp @@ -1,6 +1,5 @@ #include "adaptors.hpp" -#include #include #include @@ -47,26 +46,6 @@ Value::Ptr ConditionOp::terminator() const { return op->body.back()->result(0); } -void ConstantOp::init(const Type::Ptr &type, int64_t value) { - op->results.emplace_back(Value::make(type, op)); - op->addAttr(value); -} - -void ConstantOp::init(const Type::Ptr &type, bool value) { - op->results.emplace_back(Value::make(type, op)); - op->addAttr(value); -} - -void ConstantOp::init(const Type::Ptr &type, double value) { - op->results.emplace_back(Value::make(type, op)); - op->addAttr(value); -} - -void ConstantOp::init(const Type::Ptr &type, const std::string &value) { - op->results.emplace_back(Value::make(type, op)); - op->addAttr(value); -} - void ElseOp::init() { } diff --git a/compiler/lib/optree/attribute.cpp b/compiler/lib/optree/attribute.cpp index bb6a1224..34fb58f5 100644 --- a/compiler/lib/optree/attribute.cpp +++ b/compiler/lib/optree/attribute.cpp @@ -15,20 +15,20 @@ void Attribute::dump(std::ostream &stream) const { stream << "empty"; return; } - if (is()) { - stream << "int64_t : " << as(); + if (is()) { + stream << "int : " << as(); return; } - if (is()) { - stream << "double : " << as(); + if (is()) { + stream << "float : " << as(); return; } - if (is()) { - stream << "bool : " << as(); + if (is()) { + stream << "bool : " << as(); return; } - if (is()) { - stream << "string : " << as(); + if (is()) { + stream << "str : " << as(); return; } if (is()) { diff --git a/compiler/tests/backend/optree/optimizer/erase_unused_ops.cpp b/compiler/tests/backend/optree/optimizer/erase_unused_ops.cpp index 79e820d1..148ce61f 100644 --- a/compiler/tests/backend/optree/optimizer/erase_unused_ops.cpp +++ b/compiler/tests/backend/optree/optimizer/erase_unused_ops.cpp @@ -30,7 +30,7 @@ TEST_F(EraseUnusedOpsTest, can_erase_unused_ops) { { auto &&[m, v] = getActual(); m.opInit("test", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody(); - v[2] = m.opInit(m.tI64, int64_t(123)); + v[2] = m.opInit(m.tI64, 123); v[3] = m.opInit(ArithBinOpKind::AddI, v[0], v[1]); v[4] = m.opInit(ArithCastOpKind::IntToFloat, m.tF64, v[0]); v[5] = m.opInit(LogicBinOpKind::LessEqualI, v[0], v[1]); @@ -52,7 +52,7 @@ TEST_F(EraseUnusedOpsTest, can_erase_chain_of_unused_ops) { { auto &&[m, v] = getActual(); m.opInit("test", m.tFunc({m.tI64}, m.tNone)).inward(v[0], 0).withBody(); - v[1] = m.opInit(m.tI64, int64_t(123)); + v[1] = m.opInit(m.tI64, 123); v[2] = m.opInit(ArithBinOpKind::AddI, v[0], v[1]); v[3] = m.opInit(ArithCastOpKind::IntToFloat, m.tF64, v[2]); m.opInit(); @@ -72,7 +72,7 @@ TEST_F(EraseUnusedOpsTest, can_keep_used_ops) { auto &&[m, v] = getActual(); m.opInit("test", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody(); - v[2] = m.opInit(m.tI64, int64_t(123)); + v[2] = m.opInit(m.tI64, 123); v[3] = m.opInit(ArithBinOpKind::AddI, v[0], v[1]); v[4] = m.opInit(ArithCastOpKind::IntToFloat, m.tF64, v[0]); v[5] = m.opInit(LogicBinOpKind::LessEqualI, v[0], v[1]); diff --git a/compiler/tests/optree/declarative.cpp b/compiler/tests/optree/declarative.cpp index 2cd69321..cca32e83 100644 --- a/compiler/tests/optree/declarative.cpp +++ b/compiler/tests/optree/declarative.cpp @@ -37,9 +37,9 @@ TEST_F(DeclarativeTest, can_insert_function_with_body) { m.endBody(); // clang-format on assertDump("Module () -> ()\n" - " Function {string : myfunc, Type : func((int(64), float(64)) -> none)} () -> () [#0 : int(64), #1 : " + " Function {str : myfunc, Type : func((int(64), float(64)) -> none)} () -> () [#0 : int(64), #1 : " "float(64)]\n" - " Constant {int64_t : 123} () -> (#2 : int(64))\n" + " Constant {int : 123} () -> (#2 : int(64))\n" " Allocate () -> (#3 : ptr(int(64)))\n" " ArithBinary {ArithBinOpKind : 1} (#2 : int(64), #1 : float(64)) -> (#4 : int(64))\n" " Store (#3 : ptr(int(64)), #4 : int(64)) -> ()\n" @@ -49,7 +49,7 @@ TEST_F(DeclarativeTest, can_insert_function_with_body) { TEST_F(DeclarativeTest, can_insert_with_adapted_init) { // clang-format off m.opInit("myfunc", m.tFunc({m.tI64, m.tF64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody(); - v[2] = m.opInit(m.tI64, int64_t(456L)); + v[2] = m.opInit(m.tI64, 456L); v[3] = m.opInit(m.tPtr(m.tI64)); v[4] = m.opInit(ArithBinOpKind::AddI, v[2], v[1]); m.opInit(v[3], v[4]); @@ -57,9 +57,9 @@ TEST_F(DeclarativeTest, can_insert_with_adapted_init) { m.endBody(); // clang-format on assertDump("Module () -> ()\n" - " Function {string : myfunc, Type : func((int(64), float(64)) -> none)} () -> () [#0 : int(64), #1 : " + " Function {str : myfunc, Type : func((int(64), float(64)) -> none)} () -> () [#0 : int(64), #1 : " "float(64)]\n" - " Constant {int64_t : 456} () -> (#2 : int(64))\n" + " Constant {int : 456} () -> (#2 : int(64))\n" " Allocate () -> (#3 : ptr(int(64)))\n" " ArithBinary {ArithBinOpKind : 1} (#2 : int(64), #1 : float(64)) -> (#4 : int(64))\n" " Store (#3 : ptr(int(64)), #4 : int(64)) -> ()\n" @@ -87,8 +87,8 @@ TEST_F(DeclarativeTest, can_insert_nested_operations) { m.endBody(); // clang-format on assertDump("Module () -> ()\n" - " Function {string : myfunc, Type : func((float(64)) -> none)} () -> () [#0 : float(64)]\n" - " Constant {double : 7.89} () -> (#1 : float(64))\n" + " Function {str : myfunc, Type : func((float(64)) -> none)} () -> () [#0 : float(64)]\n" + " Constant {float : 7.89} () -> (#1 : float(64))\n" " Allocate () -> (#2 : ptr(float(64)))\n" " LogicBinary {LogicBinOpKind : 12} (#0 : float(64), #1 : float(64)) -> (#3 : int(8))\n" " If (#3 : int(8)) -> ()\n"