Skip to content

Commit

Permalink
VM: add support for multiple types in binary expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
mrunix00 committed Jul 7, 2024
1 parent a7c3a40 commit 5fdee67
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 43 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ project(SPL)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_compile_options(-O3)
add_compile_options(-O0)

file(GLOB_RECURSE TEST_SOURCES tests/*.cpp)
file(GLOB_RECURSE SOURCES src/*.cpp)
Expand Down
103 changes: 102 additions & 1 deletion include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,105 @@ static inline Variable::Type varTypeConvert(AbstractSyntaxTree *ast) {
#define VAR_CASE(OP, TYPE) \
case Variable::Type::TYPE: \
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::OP##TYPE}); \
break;
break;

static Variable::Type deduceType(Program &program, Segment &segment, AbstractSyntaxTree *ast) {
switch (ast->nodeType) {
case AbstractSyntaxTree::Type::Node: {
auto token = dynamic_cast<Node *>(ast)->token;
switch (token.type) {
case Number: {
try {
std::stoi(token.value);
return Variable::Type::I32;
} catch (std::out_of_range &) {
goto long64;
}
long64:
try {
std::stol(token.value);
return Variable::Type::I64;
} catch (std::exception &) {
throw std::runtime_error("Invalid number: " + token.value);
}
}
case Identifier: {
if (segment.find_local(token.value) != -1)
return segment.locals[token.value].type;
if (program.find_global(token.value) != -1)
return program.segments[0].locals[token.value].type;
throw std::runtime_error("Identifier not found: " + token.value);
}
default:
throw std::runtime_error("Invalid type: " + token.value);
}
}
case AbstractSyntaxTree::Type::UnaryExpression: {
auto unary = dynamic_cast<UnaryExpression *>(ast);
return deduceType(program, segment, unary->expression);
}
case AbstractSyntaxTree::Type::BinaryExpression: {
auto binary = dynamic_cast<BinaryExpression *>(ast);
auto left = deduceType(program, segment, binary->left);
auto right = deduceType(program, segment, binary->right);
if ((left == Variable::Type::I32 || left == Variable::Type::I64) &&
(right == Variable::Type::I32 || right == Variable::Type::I64))
return left == Variable::Type::I64 || right == Variable::Type::I64 ? Variable::Type::I64
: Variable::Type::I32;
if (left != right)
throw std::runtime_error("Type mismatch");
return left;
}
case AbstractSyntaxTree::Type::FunctionCall: {
// TODO: add support for multiple return types
auto call = dynamic_cast<FunctionCall *>(ast);
auto function = program.find_function(program.segments[segment.id], call->identifier.token.value);
if (function == -1)
throw std::runtime_error("Function not found: " + call->identifier.token.value);
return Variable::Type::I32;
}
default:
throw std::runtime_error("Invalid type: " + ast->typeStr);
}
}

enum class GenericInstruction {
Add,
Sub,
Mul,
Div,
Mod,
Equal,
Less,
Greater,
GreaterEqual,
LessEqual,
NotEqual
};

#define TYPE_CASE(INS) \
case GenericInstruction::INS: { \
switch (type) { \
case Variable::Type::I32: \
return {Instruction::InstructionType::INS##I32}; \
case Variable::Type::I64: \
return {Instruction::InstructionType::INS##I64}; \
default: \
throw std::runtime_error("[getInstructionWithType] Invalid type"); \
} \
}
static inline Instruction getInstructionWithType(GenericInstruction instruction, Variable::Type type) {
switch (instruction) {
TYPE_CASE(Add)
TYPE_CASE(Sub)
TYPE_CASE(Mul)
TYPE_CASE(Div)
TYPE_CASE(Mod)
TYPE_CASE(Equal)
TYPE_CASE(Less)
TYPE_CASE(Greater)
TYPE_CASE(GreaterEqual)
TYPE_CASE(LessEqual)
TYPE_CASE(NotEqual)
}
}
92 changes: 51 additions & 41 deletions src/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,27 @@ bool Node::operator==(const AbstractSyntaxTree &other) const {
void Node::compile(Program &program, Segment &segment) const {
switch (token.type) {
case Number: {
try {
segment.instructions.push_back(
Instruction{
.type = Instruction::InstructionType::LoadI32,
.params = {.i32 = std::stoi(token.value)},
});
return;
} catch (std::out_of_range &) {
goto long64;
auto type = deduceType(program, segment, (AbstractSyntaxTree *) this);
switch (type) {
case Variable::Type::I32:
return segment.instructions.push_back(
Instruction{
.type = Instruction::InstructionType::LoadI32,
.params = {.i32 = std::stoi(token.value)},
});
case Variable::Type::I64:
return segment.instructions.push_back(
Instruction{
.type = Instruction::InstructionType::LoadI64,
.params = {.i64 = std::stol(token.value)},
});
default:
throw std::runtime_error("[Node::compile] Invalid type: " + token.value);
}
long64:
segment.instructions.push_back(
Instruction{
.type = Instruction::InstructionType::LoadI64,
.params = {.i64 = std::stol(token.value)},
});
} break;
}
case Identifier: {
emitLoad(program, segment, token.value);
} break;
return emitLoad(program, segment, token.value);
}
default:
throw std::runtime_error("[Node::compile] This should not be accessed!");
}
Expand All @@ -60,50 +61,59 @@ bool BinaryExpression::operator==(const AbstractSyntaxTree &other) const {
op == otherBinaryExpression.op;
}

// TODO: Add support for other types
void BinaryExpression::compile(Program &program, Segment &segment) const {
if (op.type == Assign) {
right->compile(program, segment);
emitStore(program, segment, dynamic_cast<Node &>(*left).token.value);
return;
}
auto leftType = deduceType(program, segment, left);
auto rightType = deduceType(program, segment, right);
auto finalType = deduceType(program, segment, (AbstractSyntaxTree*) this);

left->compile(program, segment);
if (leftType != finalType) {
switch (leftType) {
case Variable::Type::I32:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::ConvertI32toI64});
break;
default:
throw std::runtime_error("[BinaryExpression::compile] Invalid type: " + left->typeStr);
}
}
right->compile(program, segment);
if (rightType != finalType) {
switch (rightType) {
case Variable::Type::I32:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::ConvertI32toI64});
break;
default:
throw std::runtime_error("[BinaryExpression::compile] Invalid type: "+ right->typeStr);
}
}
switch (op.type) {
case Plus:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::AddI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Add, finalType));
case Minus:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::SubI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Sub, finalType));
case Multiply:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::MulI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Mul, finalType));
case Divide:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::DivI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Div, finalType));
case Modulo:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::ModI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Mod, finalType));
case Greater:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::GreaterI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Greater, finalType));
case Less:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::LessI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Less, finalType));
case GreaterEqual:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::GreaterEqualI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::GreaterEqual, finalType));
case LessEqual:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::LessEqualI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::LessEqual, finalType));
case Equal:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::EqualI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Equal, finalType));
case NotEqual:
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::NotEqualI32});
break;
return segment.instructions.push_back(getInstructionWithType(GenericInstruction::NotEqual, finalType));
default:
throw std::runtime_error("[BinaryExpression::compile] Invalid operator: " + op.value);
}
Expand Down
10 changes: 10 additions & 0 deletions tests/vm_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ TEST(VM, SimpleI64VariableDeclaration) {
ASSERT_EQ(*static_cast<int64_t *>(vm.topStack(sizeof(int64_t))), 42);
}

TEST(VM, BinaryExpressionWithMultipleTypes) {
const char *input = "define a : i64 = 42;"
"define b : i32 = 42;"
"a + b;";
VM vm;
auto program = compile(input);
vm.run(program);
ASSERT_EQ(*static_cast<int64_t *>(vm.topStack(sizeof(int64_t))), 84);
}

TEST(VM, SimpleVariableAssignment) {
const char *input = "define a : i32 = 42; a = 43; a;";
VM vm;
Expand Down

0 comments on commit 5fdee67

Please sign in to comment.