Skip to content

Commit

Permalink
VM: add support for different function types
Browse files Browse the repository at this point in the history
  • Loading branch information
mrunix00 committed Jul 16, 2024
1 parent b41129e commit d3c91a8
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 90 deletions.
26 changes: 13 additions & 13 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
size_t id; \
bool isLocal; \
Instruction instruction; \
Variable::Type type; \
VariableType *type; \
if (segment.find_local(identifier) != -1) { \
isLocal = true; \
id = segment.find_local(identifier); \
Expand All @@ -21,18 +21,18 @@
} else { \
throw std::runtime_error("[Node::compile] Identifier not found: " + identifier); \
} \
switch (type) { \
case Variable::Type::I32: \
switch (type->type) { \
case VariableType::Type::I32: \
instruction.type = isLocal ? Instruction::InstructionType::OPERATION##LocalI32 \
: Instruction::InstructionType::OPERATION##GlobalI32; \
instruction.params.index = id; \
break; \
case Variable::Type::I64: \
case VariableType::Type::I64: \
instruction.type = isLocal ? Instruction::InstructionType::OPERATION##LocalI64 \
: Instruction::InstructionType::OPERATION##GlobalI64; \
instruction.params.index = id; \
break; \
case Variable::Type::U32: \
case VariableType::Type::U32: \
instruction.type = isLocal ? Instruction::InstructionType::OPERATION##LocalU32 \
: Instruction::InstructionType::OPERATION##GlobalU32; \
instruction.params.index = id; \
Expand Down Expand Up @@ -69,7 +69,7 @@ inline uint32_t convert<uint32_t>(const std::string &value) {
} else { \
((Node *) value.value())->compile(program, segment); \
} \
segment.declare_variable(identifier.token.value, Variable::Type::TYPE); \
segment.declare_variable(identifier.token.value, new VariableType(VariableType::Type::TYPE)); \
segment.instructions.push_back({ \
.type = segment.id == 0 \
? Instruction::InstructionType::StoreGlobal##TYPE \
Expand All @@ -79,7 +79,7 @@ inline uint32_t convert<uint32_t>(const std::string &value) {
} break;

#define VAR_CASE(OP, TYPE) \
case Variable::Type::TYPE: \
case VariableType::Type::TYPE: \
segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::OP##TYPE}); \
break;

Expand All @@ -99,9 +99,9 @@ enum class GenericInstruction {
};

void assert(bool condition, const char *message);
Variable::Type varTypeConvert(AbstractSyntaxTree *ast);
Variable::Type deduceType(Program &program, Segment &segment, AbstractSyntaxTree *ast);
Instruction getInstructionWithType(GenericInstruction instruction, Variable::Type type);
Instruction emitLoad(Variable::Type, const Token &token);
void typeCast(std::vector<Instruction> &instructions, Variable::Type from, Variable::Type to);
size_t sizeOfType(Variable::Type type);
VariableType::Type varTypeConvert(AbstractSyntaxTree *ast);
VariableType::Type deduceType(Program &program, Segment &segment, AbstractSyntaxTree *ast);
Instruction getInstructionWithType(GenericInstruction instruction, VariableType::Type type);
Instruction emitLoad(VariableType::Type, const Token &token);
void typeCast(std::vector<Instruction> &instructions, VariableType::Type from, VariableType::Type to);
size_t sizeOfType(VariableType::Type type);
34 changes: 25 additions & 9 deletions include/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Instruction {
Expand Down Expand Up @@ -80,35 +81,50 @@ struct Instruction {
} params{};
};

struct Variable {
std::string name;
enum class Type {
struct VariableType {
enum Type {
Invalid = 0,
I32,
I64,
U32,
Function
} type;
size_t index;
size_t size;
explicit VariableType(Type type) : type(type){};
};

struct FunctionType : public VariableType {
VariableType *returnType;
std::vector<VariableType *> arguments;
FunctionType(VariableType *returnType, std::vector<VariableType *> arguments)
: VariableType(Function), returnType(returnType), arguments(std::move(arguments)){};
};

struct Variable {
std::string name;
VariableType *type{};
size_t index{};
size_t size{};
Variable() = default;
Variable(std::string name, VariableType *type, size_t index, size_t size)
: name(std::move(name)), type(type), index(index), size(size){};
};

struct Segment {
std::vector<Instruction> instructions;
std::unordered_map<std::string, Variable> locals;
std::unordered_map<std::string, size_t> functions;
std::unordered_map<std::string, Variable> functions;
size_t locals_capacity;
size_t id{};
size_t find_local(const std::string &identifier);
void declare_variable(const std::string &name, Variable::Type type);
void declare_function(const std::string &name, size_t index);
void declare_variable(const std::string &name, VariableType *varType);
void declare_function(const std::string &name, VariableType *funcType, size_t index);
};

struct Program {
std::vector<Segment> segments;
Program();
size_t find_global(const std::string &identifier);
size_t find_function(const Segment &segment, const std::string &identifier);
Variable find_function(const Segment &segment, const std::string &identifier);
};

struct StackFrame {
Expand Down
38 changes: 24 additions & 14 deletions src/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void Declaration::compile(Program &program, Segment &segment) const {
if (value.has_value()) {
auto varType = deduceType(program, segment, value.value());
value.value()->compile(program, segment);
segment.declare_variable(identifier.token.value, varType);
segment.declare_variable(identifier.token.value, new VariableType(varType));
emitStore(program, segment, identifier.token.value);
} else {
throw std::runtime_error("[Declaration::compile] Cannot deduce the variable type!");
Expand All @@ -140,15 +140,19 @@ void Declaration::compile(Program &program, Segment &segment) const {
case AbstractSyntaxTree::Type::FunctionDeclaration: {
auto functionDeclaration = (FunctionDeclaration *) type.value();
auto newSegment = Segment{.id = program.segments.size()};
segment.declare_function(identifier.token.value, program.segments.size());
segment.declare_variable(identifier.token.value, Variable::Type::Function);
auto returnType = new VariableType(varTypeConvert(functionDeclaration->returnType));
auto arguments = std::vector<VariableType *>();
for (auto arg: functionDeclaration->arguments)
arguments.push_back(new VariableType(varTypeConvert(arg->type.value())));
segment.declare_function(identifier.token.value,
new FunctionType(returnType, arguments),
program.segments.size());
for (auto argument: functionDeclaration->arguments) {
newSegment.locals[argument->identifier.token.value] = {
.name = argument->identifier.token.value,
.type = deduceType(program, segment, argument),
.index = newSegment.locals.size(),
.size = sizeOfType(varTypeConvert(functionDeclaration->returnType))
};
newSegment.locals[argument->identifier.token.value] = Variable(
argument->identifier.token.value,
new VariableType(deduceType(program, segment, argument)),
newSegment.locals.size(),
sizeOfType(varTypeConvert(functionDeclaration->returnType)));
newSegment.locals_capacity += sizeOfType(deduceType(program, segment, argument));
}
value.value()->compile(program, newSegment);
Expand Down Expand Up @@ -253,14 +257,20 @@ bool FunctionCall::operator==(const AbstractSyntaxTree &other) const {
return identifier == otherFunctionCall.identifier;
}
void FunctionCall::compile(Program &program, Segment &segment) const {
for (auto &argument: arguments) {
auto function = program.find_function(segment, identifier.token.value);
auto functionType = (FunctionType *) function.type;
for (int i = 0; i < arguments.size(); i++) {
auto argument = arguments[i];
auto definedArgument = functionType->arguments[i];
if (deduceType(program, segment, argument) != definedArgument->type)
throw std::runtime_error("[FunctionCall::compile] Argument type mismatch!");
argument->compile(program, segment);
}

segment.instructions.push_back(
Instruction{
.type = Instruction::InstructionType::Call,
.params = {.index = program.find_function(segment, identifier.token.value)},
.params = {.index = function.index},
});
}

Expand Down Expand Up @@ -344,7 +354,7 @@ void UnaryExpression::compile(Program &program, Segment &segment) const {
if (node->token.type != Identifier)
throw std::runtime_error("[UnaryExpression::compile] Invalid expression varType!");

Variable::Type varType;
VariableType *varType;
if (segment.find_local(node->token.value) != -1) {
varType = segment.locals[node->token.value].type;
} else if (program.find_global(node->token.value) != -1) {
Expand All @@ -356,7 +366,7 @@ void UnaryExpression::compile(Program &program, Segment &segment) const {
emitLoad(program, segment, node->token.value);
switch (op.type) {
case Increment:
switch (varType) {
switch (varType->type) {
VAR_CASE(Increment, U32)
VAR_CASE(Increment, I32)
VAR_CASE(Increment, I64)
Expand All @@ -365,7 +375,7 @@ void UnaryExpression::compile(Program &program, Segment &segment) const {
}
break;
case Decrement:
switch (varType) {
switch (varType->type) {
VAR_CASE(Decrement, U32)
VAR_CASE(Decrement, I32)
VAR_CASE(Decrement, I64)
Expand Down
Loading

0 comments on commit d3c91a8

Please sign in to comment.