Skip to content

Commit

Permalink
Solidify native types and add type deduction for attributes (#155)
Browse files Browse the repository at this point in the history
Initialize ConstantOp using *a template function* instead of multiple
overloads (therefore it should be checked to hold one of supported
native types during the optree verification stage).

Introduce new helpers: `typeOneOf`, `canHoldAlternative` for variants,
`advanceEarly` overload for ranges.
  • Loading branch information
vla5924 authored May 12, 2024
1 parent 4a3fa1e commit 45b08ba
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 68 deletions.
10 changes: 5 additions & 5 deletions compiler/include/compiler/optree/adaptors.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once

#include <cstdint>
#include <string>
#include <vector>

Expand Down Expand Up @@ -71,10 +70,11 @@ struct ReturnOp : Adaptor {
struct ConstantOp : Adaptor {
OPTREE_ADAPTOR_HELPER(Adaptor, "Constant")

void init(const Type::Ptr &type, int64_t value);
void init(const Type::Ptr &type, bool value);
void init(const Type::Ptr &type, double value);
void init(const Type::Ptr &type, const std::string &value);
template <typename T>
void init(const Type::Ptr &type, const T &value) {
op->results.emplace_back(Value::make(type, op));
op->addAttr(value);
}

OPTREE_ADAPTOR_ATTRIBUTE_OPAQUE(value, 0)
OPTREE_ADAPTOR_RESULT(result, 0)
Expand Down
31 changes: 26 additions & 5 deletions compiler/include/compiler/optree/attribute.hpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
#pragma once

#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <variant>

#include "compiler/optree/definitions.hpp"
#include "compiler/optree/types.hpp"
#include "compiler/utils/helpers.hpp"

namespace optree {

struct Attribute {
std::variant<std::monostate, int64_t, double, bool, std::string, Type::Ptr, ArithBinOpKind, ArithCastOpKind,
LogicBinOpKind, LogicUnaryOpKind>
storage;
using Storage = std::variant<
//
std::monostate, NativeInt, NativeBool, NativeFloat, NativeStr, Type::Ptr, ArithBinOpKind, ArithCastOpKind,
LogicBinOpKind, LogicUnaryOpKind
//
>;

Storage storage;

Attribute() = default;
Attribute(const Attribute &) = default;
Expand All @@ -25,7 +30,18 @@ struct Attribute {
Attribute &operator=(Attribute &&) = default;

template <typename VariantType>
explicit Attribute(const VariantType &value) : storage(value){};
explicit Attribute(const VariantType &value) {
if constexpr (Attribute::canHold<VariantType>())
set(value);
else if constexpr (std::is_integral_v<VariantType>)
set(static_cast<NativeInt>(value));
else if constexpr (std::is_floating_point_v<VariantType>)
set(static_cast<NativeFloat>(value));
else if constexpr (std::is_constructible_v<std::string, VariantType>)
set(std::string(value));
else
throw std::bad_variant_access();
}

template <typename VariantType>
bool is() const noexcept {
Expand Down Expand Up @@ -72,6 +88,11 @@ struct Attribute {
}

void dump(std::ostream &stream) const;

template <typename T>
static constexpr bool canHold() {
return utils::canHoldAlternative<T, Storage>;
}
};

} // namespace optree
13 changes: 1 addition & 12 deletions compiler/include/compiler/optree/declarative.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#pragma once

#include <cstddef>
#include <cstdint>
#include <ostream>
#include <string>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <utility>

Expand Down Expand Up @@ -132,16 +130,7 @@ class DeclarativeModule {

template <typename T>
DeclarativeModule &attr(const T &value) {
if constexpr (std::is_same_v<std::remove_cvref_t<T>, bool>)
current->addAttr(value);
else if constexpr (std::is_same_v<std::remove_cvref_t<T>, const char *>)
current->addAttr(std::string(value));
else if constexpr (std::is_integral_v<T>)
current->addAttr(static_cast<int64_t>(value));
else if constexpr (std::is_floating_point_v<T>)
current->addAttr(static_cast<double>(value));
else
current->addAttr(value);
current->addAttr(value);
return *this;
}

Expand Down
14 changes: 14 additions & 0 deletions compiler/include/compiler/optree/types.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#pragma once

#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <vector>

Expand Down Expand Up @@ -65,6 +67,7 @@ struct NoneType : public Type {

struct IntegerType : public Type {
using Ptr = std::shared_ptr<const IntegerType>;
using NativeType = int64_t;

const unsigned width;

Expand All @@ -79,6 +82,7 @@ struct IntegerType : public Type {

struct BoolType : public IntegerType {
using Ptr = std::shared_ptr<const BoolType>;
using NativeType = bool;

static constinit const unsigned intWidth = 8U;

Expand All @@ -90,6 +94,7 @@ struct BoolType : public IntegerType {

struct FloatType : public Type {
using Ptr = std::shared_ptr<const FloatType>;
using NativeType = double;

const unsigned width;

Expand All @@ -104,6 +109,7 @@ struct FloatType : public Type {

struct StrType : public Type {
using Ptr = std::shared_ptr<const StrType>;
using NativeType = std::string;

const unsigned charWidth;

Expand Down Expand Up @@ -167,4 +173,12 @@ struct TypeStorage {
static StrType::Ptr strType(unsigned charWidth = 8U);
};

template <typename ConcreteType>
using NativeType = typename ConcreteType::NativeType;

using NativeInt = NativeType<IntegerType>;
using NativeBool = NativeType<BoolType>;
using NativeFloat = NativeType<FloatType>;
using NativeStr = NativeType<StrType>;

} // namespace optree
20 changes: 20 additions & 0 deletions compiler/include/compiler/utils/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <ostream>
#include <tuple>
#include <type_traits>
#include <variant>

#if defined(_MSC_VER) && !defined(__clang__) // MSVC
#define COMPILER_UNREACHABLE(MESSAGE) \
Expand Down Expand Up @@ -143,8 +144,21 @@ class ZippedRanges {
}
};

template <typename RequiredType, typename VariantType>
struct CanHoldAlternative;

template <typename RequiredType, typename... SupportedTypes>
struct CanHoldAlternative<RequiredType, std::variant<SupportedTypes...>>
: std::disjunction<std::is_same<RequiredType, SupportedTypes>...> {};

} // namespace detail

template <typename RequiredType, typename... AllowedTypes>
constexpr bool typeOneOf = std::disjunction_v<std::is_same<RequiredType, AllowedTypes>...>;

template <typename RequiredType, typename VariantType>
constexpr bool canHoldAlternative = detail::CanHoldAlternative<std::remove_cvref_t<RequiredType>, VariantType>::value;

template <typename Range, typename UnaryPred, typename NullaryPred>
void interleave(const Range &values, const UnaryPred &printValue, const NullaryPred &printSep) {
if (std::empty(values))
Expand All @@ -167,6 +181,12 @@ auto advanceEarly(Iterator begin, Iterator end) {
return detail::AdvanceEarlyRange<Iterator>(begin, end);
}

template <typename Range>
auto advanceEarly(Range &&range) {
using Iterator = decltype(std::begin(range));
return detail::AdvanceEarlyRange<Iterator>(std::begin(range), std::end(range));
}

template <typename Range>
auto reversed(Range &&range) {
return detail::ReversedRange<Range>(range);
Expand Down
6 changes: 3 additions & 3 deletions compiler/lib/codegen/optree_to_llvmir/llvmir_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,14 @@ void LLVMIRGenerator::visit(const ConstantOp &op) {
if (type->is<BoolType>())
return result(llvm::ConstantInt::get(convertType(type), op.value().as<bool>()));
if (type->is<IntegerType>()) {
auto num = op.value().as<int64_t>();
auto num = static_cast<int64_t>(op.value().as<NativeInt>());
auto *value = llvm::ConstantInt::get(convertType(type), reinterpret_cast<uint64_t &>(num), /*IsSigned*/ true);
return result(value);
}
if (type->is<FloatType>())
return result(llvm::ConstantFP::get(convertType(type), op.value().as<double>()));
return result(llvm::ConstantFP::get(convertType(type), static_cast<double>(op.value().as<NativeFloat>())));
if (type->is<StrType>())
return result(getGlobalString(op.value().as<std::string>()));
return result(getGlobalString(op.value().as<NativeStr>()));
COMPILER_UNREACHABLE("unexpected result type in ConstantOp");
}

Expand Down
11 changes: 7 additions & 4 deletions compiler/lib/frontend/converter/converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,20 +227,23 @@ Value::Ptr visitExpression(const Node::Ptr &node, ConverterContext &ctx) {
}

Value::Ptr visitIntegerLiteralValue(const Node::Ptr &node, ConverterContext &ctx) {
auto value = static_cast<int64_t>(node->intNum());
auto value = static_cast<NativeInt>(node->intNum());
return ctx.insert<ConstantOp>(node->ref, TypeStorage::integerType(), value).result();
}

Value::Ptr visitBooleanLiteralValue(const Node::Ptr &node, ConverterContext &ctx) {
return ctx.insert<ConstantOp>(node->ref, TypeStorage::boolType(), node->boolean()).result();
auto value = static_cast<NativeBool>(node->boolean());
return ctx.insert<ConstantOp>(node->ref, TypeStorage::boolType(), value).result();
}

Value::Ptr visitFloatingPointLiteralValue(const Node::Ptr &node, ConverterContext &ctx) {
return ctx.insert<ConstantOp>(node->ref, TypeStorage::floatType(), node->fpNum()).result();
auto value = static_cast<NativeFloat>(node->fpNum());
return ctx.insert<ConstantOp>(node->ref, TypeStorage::floatType(), value).result();
}

Value::Ptr visitStringLiteralValue(const Node::Ptr &node, ConverterContext &ctx) {
return ctx.insert<ConstantOp>(node->ref, TypeStorage::strType(), node->str()).result();
auto value = static_cast<NativeStr>(node->str());
return ctx.insert<ConstantOp>(node->ref, TypeStorage::strType(), value).result();
}

Value::Ptr visitBinaryOperation(const Node::Ptr &node, ConverterContext &ctx) {
Expand Down
21 changes: 0 additions & 21 deletions compiler/lib/optree/adaptors.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "adaptors.hpp"

#include <cstdint>
#include <string>
#include <vector>

Expand Down Expand Up @@ -47,26 +46,6 @@ Value::Ptr ConditionOp::terminator() const {
return op->body.back()->result(0);
}

void ConstantOp::init(const Type::Ptr &type, int64_t value) {
op->results.emplace_back(Value::make(type, op));
op->addAttr(value);
}

void ConstantOp::init(const Type::Ptr &type, bool value) {
op->results.emplace_back(Value::make(type, op));
op->addAttr(value);
}

void ConstantOp::init(const Type::Ptr &type, double value) {
op->results.emplace_back(Value::make(type, op));
op->addAttr(value);
}

void ConstantOp::init(const Type::Ptr &type, const std::string &value) {
op->results.emplace_back(Value::make(type, op));
op->addAttr(value);
}

void ElseOp::init() {
}

Expand Down
16 changes: 8 additions & 8 deletions compiler/lib/optree/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ void Attribute::dump(std::ostream &stream) const {
stream << "empty";
return;
}
if (is<int64_t>()) {
stream << "int64_t : " << as<int64_t>();
if (is<NativeInt>()) {
stream << "int : " << as<NativeInt>();
return;
}
if (is<double>()) {
stream << "double : " << as<double>();
if (is<NativeFloat>()) {
stream << "float : " << as<NativeFloat>();
return;
}
if (is<bool>()) {
stream << "bool : " << as<bool>();
if (is<NativeBool>()) {
stream << "bool : " << as<NativeBool>();
return;
}
if (is<std::string>()) {
stream << "string : " << as<std::string>();
if (is<NativeStr>()) {
stream << "str : " << as<NativeStr>();
return;
}
if (is<Type::Ptr>()) {
Expand Down
6 changes: 3 additions & 3 deletions compiler/tests/backend/optree/optimizer/erase_unused_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ TEST_F(EraseUnusedOpsTest, can_erase_unused_ops) {
{
auto &&[m, v] = getActual();
m.opInit<FunctionOp>("test", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody();
v[2] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[2] = m.opInit<ConstantOp>(m.tI64, 123);
v[3] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[0], v[1]);
v[4] = m.opInit<ArithCastOp>(ArithCastOpKind::IntToFloat, m.tF64, v[0]);
v[5] = m.opInit<LogicBinaryOp>(LogicBinOpKind::LessEqualI, v[0], v[1]);
Expand All @@ -52,7 +52,7 @@ TEST_F(EraseUnusedOpsTest, can_erase_chain_of_unused_ops) {
{
auto &&[m, v] = getActual();
m.opInit<FunctionOp>("test", m.tFunc({m.tI64}, m.tNone)).inward(v[0], 0).withBody();
v[1] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[1] = m.opInit<ConstantOp>(m.tI64, 123);
v[2] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[0], v[1]);
v[3] = m.opInit<ArithCastOp>(ArithCastOpKind::IntToFloat, m.tF64, v[2]);
m.opInit<ReturnOp>();
Expand All @@ -72,7 +72,7 @@ TEST_F(EraseUnusedOpsTest, can_keep_used_ops) {
auto &&[m, v] = getActual();

m.opInit<FunctionOp>("test", m.tFunc({m.tI64, m.tI64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody();
v[2] = m.opInit<ConstantOp>(m.tI64, int64_t(123));
v[2] = m.opInit<ConstantOp>(m.tI64, 123);
v[3] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[0], v[1]);
v[4] = m.opInit<ArithCastOp>(ArithCastOpKind::IntToFloat, m.tF64, v[0]);
v[5] = m.opInit<LogicBinaryOp>(LogicBinOpKind::LessEqualI, v[0], v[1]);
Expand Down
14 changes: 7 additions & 7 deletions compiler/tests/optree/declarative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ TEST_F(DeclarativeTest, can_insert_function_with_body) {
m.endBody();
// clang-format on
assertDump("Module () -> ()\n"
" Function {string : myfunc, Type : func((int(64), float(64)) -> none)} () -> () [#0 : int(64), #1 : "
" Function {str : myfunc, Type : func((int(64), float(64)) -> none)} () -> () [#0 : int(64), #1 : "
"float(64)]\n"
" Constant {int64_t : 123} () -> (#2 : int(64))\n"
" Constant {int : 123} () -> (#2 : int(64))\n"
" Allocate () -> (#3 : ptr(int(64)))\n"
" ArithBinary {ArithBinOpKind : 1} (#2 : int(64), #1 : float(64)) -> (#4 : int(64))\n"
" Store (#3 : ptr(int(64)), #4 : int(64)) -> ()\n"
Expand All @@ -49,17 +49,17 @@ TEST_F(DeclarativeTest, can_insert_function_with_body) {
TEST_F(DeclarativeTest, can_insert_with_adapted_init) {
// clang-format off
m.opInit<FunctionOp>("myfunc", m.tFunc({m.tI64, m.tF64}, m.tNone)).inward(v[0], 0).inward(v[1], 1).withBody();
v[2] = m.opInit<ConstantOp>(m.tI64, int64_t(456L));
v[2] = m.opInit<ConstantOp>(m.tI64, 456L);
v[3] = m.opInit<AllocateOp>(m.tPtr(m.tI64));
v[4] = m.opInit<ArithBinaryOp>(ArithBinOpKind::AddI, v[2], v[1]);
m.opInit<StoreOp>(v[3], v[4]);
m.opInit<ReturnOp>();
m.endBody();
// clang-format on
assertDump("Module () -> ()\n"
" Function {string : myfunc, Type : func((int(64), float(64)) -> none)} () -> () [#0 : int(64), #1 : "
" Function {str : myfunc, Type : func((int(64), float(64)) -> none)} () -> () [#0 : int(64), #1 : "
"float(64)]\n"
" Constant {int64_t : 456} () -> (#2 : int(64))\n"
" Constant {int : 456} () -> (#2 : int(64))\n"
" Allocate () -> (#3 : ptr(int(64)))\n"
" ArithBinary {ArithBinOpKind : 1} (#2 : int(64), #1 : float(64)) -> (#4 : int(64))\n"
" Store (#3 : ptr(int(64)), #4 : int(64)) -> ()\n"
Expand Down Expand Up @@ -87,8 +87,8 @@ TEST_F(DeclarativeTest, can_insert_nested_operations) {
m.endBody();
// clang-format on
assertDump("Module () -> ()\n"
" Function {string : myfunc, Type : func((float(64)) -> none)} () -> () [#0 : float(64)]\n"
" Constant {double : 7.89} () -> (#1 : float(64))\n"
" Function {str : myfunc, Type : func((float(64)) -> none)} () -> () [#0 : float(64)]\n"
" Constant {float : 7.89} () -> (#1 : float(64))\n"
" Allocate () -> (#2 : ptr(float(64)))\n"
" LogicBinary {LogicBinOpKind : 12} (#0 : float(64), #1 : float(64)) -> (#3 : int(8))\n"
" If (#3 : int(8)) -> ()\n"
Expand Down

0 comments on commit 45b08ba

Please sign in to comment.