Skip to content

Commit

Permalink
Merge pull request #19 from alexdovzhanyn/funtion-invocation-codegen
Browse files Browse the repository at this point in the history
Function Invocation Codegen
  • Loading branch information
alexdovzhanyn authored Jul 30, 2024
2 parents d3c0d41 + 7511609 commit e534590
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 78 deletions.
132 changes: 114 additions & 18 deletions src/compiler/CodeGen.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <iostream>
#include <memory>
#include "binaryen-c.h"
#include "compiler/Compiler.hpp"
#include "lexer/Lexemes.hpp"
#include "StandardLibrary.hpp"
#include "parser/ast/ASTNodeList.hpp"
Expand All @@ -26,6 +27,8 @@ namespace Theta {
}

BinaryenExpressionRef CodeGen::generate(shared_ptr<ASTNode> node, BinaryenModuleRef &module) {
if (node->hasOwnScope()) scope.enterScope();

if (node->getNodeType() == ASTNode::SOURCE) {
generateSource(dynamic_pointer_cast<SourceNode>(node), module);
} else if (node->getNodeType() == ASTNode::CAPSULE) {
Expand All @@ -34,6 +37,10 @@ namespace Theta {
return generateBlock(dynamic_pointer_cast<ASTNodeList>(node), module);
} else if (node->getNodeType() == ASTNode::RETURN) {
return generateReturn(dynamic_pointer_cast<ReturnNode>(node), module);
} else if (node->getNodeType() == ASTNode::FUNCTION_INVOCATION) {
return generateFunctionInvocation(dynamic_pointer_cast<FunctionInvocationNode>(node), module);
} else if (node->getNodeType() == ASTNode::IDENTIFIER) {
return generateIdentifier(dynamic_pointer_cast<IdentifierNode>(node), module);
} else if (node->getNodeType() == ASTNode::BINARY_OPERATION) {
return generateBinaryOperation(dynamic_pointer_cast<BinaryOperationNode>(node), module);
} else if (node->getNodeType() == ASTNode::UNARY_OPERATION) {
Expand All @@ -46,41 +53,82 @@ namespace Theta {
return generateBooleanLiteral(dynamic_pointer_cast<LiteralNode>(node), module);
}

if (node->hasOwnScope()) scope.exitScope();

return nullptr;
}

/**
 * @brief Generates and exports a function for every function-typed assignment in a capsule.
 *
 * The capsule's elements are first passed to hoistCapsuleElements. Each element that is
 * an assignment whose resolved type is DataTypes::FUNCTION is then handed to
 * generateFunctionDeclaration with addToExports set to true.
 *
 * @param capsuleNode The capsule node; its value is cast to an ASTNodeList of elements.
 * @param module The Binaryen module passed through to generateFunctionDeclaration.
 * @return BinaryenExpressionRef Always nullptr; this function produces no expression of its own.
 */
BinaryenExpressionRef CodeGen::generateCapsule(shared_ptr<CapsuleNode> capsuleNode, BinaryenModuleRef &module) {
    vector<shared_ptr<ASTNode>> capsuleElements = dynamic_pointer_cast<ASTNodeList>(capsuleNode->getValue())->getElements();

    hoistCapsuleElements(capsuleElements);

    for (const auto &elem : capsuleElements) {
        // Only assignments can bind a function to a name.
        if (elem->getNodeType() != ASTNode::ASSIGNMENT) continue;

        // Read the resolved type only once we know the element is an assignment.
        string elemType = dynamic_pointer_cast<TypeDeclarationNode>(elem->getResolvedType())->getType();
        if (elemType != DataTypes::FUNCTION) continue;

        shared_ptr<IdentifierNode> identNode = dynamic_pointer_cast<IdentifierNode>(elem->getLeft());

        generateFunctionDeclaration(
            identNode->getIdentifier(),
            dynamic_pointer_cast<FunctionDeclarationNode>(elem->getRight()),
            module,
            true
        );
    }

    // Flowing off the end of a non-void function is undefined behaviour, so return explicitly.
    return nullptr;
}

/**
 * @brief Adds a function to the Binaryen module for a function declaration node.
 *
 * Opens a scope for the duration of the function, records each parameter in that scope
 * together with its positional index, builds the Binaryen parameter type from the
 * parameters' declared types, generates the body, and registers the function under its
 * qualified identifier.
 *
 * @param identifier The base name the function was assigned to.
 * @param fnDeclNode The function declaration providing parameters, definition and resolved type.
 * @param module The Binaryen module that receives the function.
 * @param addToExports When true, the function is also exported under its qualified identifier.
 * @return BinaryenExpressionRef Always nullptr; the function is added to the module as a side effect.
 */
BinaryenExpressionRef CodeGen::generateFunctionDeclaration(
    string identifier,
    shared_ptr<FunctionDeclarationNode> fnDeclNode,
    BinaryenModuleRef &module,
    bool addToExports
) {
    scope.enterScope();

    vector<shared_ptr<ASTNode>> parameters = fnDeclNode->getParameters()->getElements();
    const int totalParams = static_cast<int>(parameters.size());

    BinaryenType parameterType = BinaryenTypeNone();

    if (totalParams > 0) {
        // A vector owns the temporary type list, so nothing is leaked once the
        // tuple type has been created from it.
        vector<BinaryenType> types;
        types.reserve(parameters.size());

        for (int i = 0; i < totalParams; i++) {
            shared_ptr<IdentifierNode> identNode = dynamic_pointer_cast<IdentifierNode>(parameters.at(i));

            // The parameter's position is stored as its mapped Binaryen index so that
            // generateIdentifier can later read it back via getMappedBinaryenIndex.
            identNode->setMappedBinaryenIndex(i);
            scope.insert(identNode->getIdentifier(), identNode);

            types.push_back(getBinaryenTypeFromTypeDeclaration(
                dynamic_pointer_cast<TypeDeclarationNode>(parameters.at(i)->getValue())
            ));
        }

        parameterType = BinaryenTypeCreate(types.data(), totalParams);
    }

    string functionName = Compiler::getQualifiedFunctionIdentifier(
        identifier,
        dynamic_pointer_cast<ASTNode>(fnDeclNode)
    );

    // Generate the body while the parameter scope is still open.
    BinaryenExpressionRef body = generate(fnDeclNode->getDefinition(), module);

    BinaryenAddFunction(
        module,
        functionName.c_str(),
        parameterType,
        getBinaryenTypeFromTypeDeclaration(dynamic_pointer_cast<TypeDeclarationNode>(fnDeclNode->getResolvedType()->getValue())),
        nullptr,
        0,
        body
    );

    if (addToExports) {
        BinaryenAddFunctionExport(module, functionName.c_str(), functionName.c_str());
    }

    scope.exitScope();

    // Flowing off the end of a non-void function is undefined behaviour, so return explicitly.
    return nullptr;
}

BinaryenExpressionRef CodeGen::generateBlock(shared_ptr<ASTNodeList> blockNode, BinaryenModuleRef &module) {
Expand All @@ -103,6 +151,37 @@ namespace Theta {
return BinaryenReturn(module, generate(returnNode->getValue(), module));
}

/**
 * @brief Generates a Binaryen call expression for a function invocation.
 *
 * The call target is the qualified identifier computed from the invoked name and the
 * invocation node. Each argument is generated in order and passed as an operand.
 *
 * @param funcInvNode The invocation node providing the callee identifier and arguments.
 * @param module The Binaryen module the call expression is created in.
 * @return BinaryenExpressionRef The call expression, typed with the invocation's resolved type.
 */
BinaryenExpressionRef CodeGen::generateFunctionInvocation(shared_ptr<FunctionInvocationNode> funcInvNode, BinaryenModuleRef &module) {
    vector<shared_ptr<ASTNode>> parameters = funcInvNode->getParameters()->getElements();

    string funcName = Compiler::getQualifiedFunctionIdentifier(
        dynamic_pointer_cast<IdentifierNode>(funcInvNode->getIdentifier())->getIdentifier(),
        funcInvNode
    );

    // A vector owns the operand list, so it is released when this function returns
    // instead of being leaked by a raw new[].
    vector<BinaryenExpressionRef> arguments;
    arguments.reserve(parameters.size());

    for (const auto &parameter : parameters) {
        arguments.push_back(generate(parameter, module));
    }

    return BinaryenCall(
        module,
        funcName.c_str(),
        arguments.data(),
        arguments.size(),
        getBinaryenTypeFromTypeDeclaration(dynamic_pointer_cast<TypeDeclarationNode>(funcInvNode->getResolvedType()))
    );
}

/**
 * @brief Generates a local.get expression for an identifier found in the current scope.
 *
 * @param identNode The identifier being read.
 * @param module The Binaryen module the expression is created in.
 * @return BinaryenExpressionRef A local get using the scoped node's mapped Binaryen index
 *         and the identifier's resolved type.
 * @throws runtime_error If the scope lookup yields no node for the identifier.
 */
BinaryenExpressionRef CodeGen::generateIdentifier(shared_ptr<IdentifierNode> identNode, BinaryenModuleRef &module) {
    shared_ptr<ASTNode> identInScope = scope.lookup(identNode->getIdentifier());

    // Fail with a clear error instead of dereferencing a null pointer below.
    if (!identInScope) {
        throw runtime_error("Could not find identifier in scope: " + identNode->getIdentifier());
    }

    return BinaryenLocalGet(
        module,
        identInScope->getMappedBinaryenIndex(),
        getBinaryenTypeFromTypeDeclaration(dynamic_pointer_cast<TypeDeclarationNode>(identNode->getResolvedType()))
    );
}

BinaryenExpressionRef CodeGen::generateBinaryOperation(shared_ptr<BinaryOperationNode> binOpNode, BinaryenModuleRef &module) {
if (binOpNode->getOperator() == Lexemes::EXPONENT) {
return generateExponentOperation(binOpNode, module);
Expand All @@ -113,12 +192,13 @@ namespace Theta {
BinaryenExpressionRef binaryenLeft = generate(binOpNode->getLeft(), module);
BinaryenExpressionRef binaryenRight = generate(binOpNode->getRight(), module);

cout << binOpNode->toJSON() << endl;

if (!binaryenLeft || !binaryenRight) {
throw runtime_error("Invalid operand types for binary operation");
}

// TODO: This won't work if we have nested operations on either side
if (binOpNode->getLeft()->getNodeType() == ASTNode::STRING_LITERAL) {
if (dynamic_pointer_cast<TypeDeclarationNode>(binOpNode->getResolvedType())->getType() == DataTypes::STRING) {
return BinaryenStringConcat(
module,
binaryenLeft,
Expand Down Expand Up @@ -243,4 +323,20 @@ namespace Theta {
if (typeDeclaration->getType() == DataTypes::STRING) return BinaryenTypeStringref();
if (typeDeclaration->getType() == DataTypes::BOOLEAN) return BinaryenTypeInt32();
}

/**
 * @brief Opens a new scope and binds every given element into it.
 *
 * The scope entered here is not exited by this function.
 *
 * @param elements The elements to pass, in order, to bindIdentifierToScope.
 */
void CodeGen::hoistCapsuleElements(vector<shared_ptr<ASTNode>> elements) {
    scope.enterScope();

    for (size_t idx = 0; idx < elements.size(); ++idx) {
        bindIdentifierToScope(elements[idx]);
    }
}

/**
 * @brief Inserts the right-hand side of a node into the current scope, keyed by its left-hand identifier.
 *
 * When the right-hand side is a function declaration, the key is the qualified function
 * identifier rather than the bare name.
 *
 * @param ast The node whose left side is cast to an IdentifierNode and whose right side is stored.
 */
void CodeGen::bindIdentifierToScope(shared_ptr<ASTNode> ast) {
    string scopeKey = dynamic_pointer_cast<IdentifierNode>(ast->getLeft())->getIdentifier();
    shared_ptr<ASTNode> boundValue = ast->getRight();

    const bool bindsFunction = boundValue->getNodeType() == ASTNode::FUNCTION_DECLARATION;
    if (bindsFunction) {
        scopeKey = Compiler::getQualifiedFunctionIdentifier(scopeKey, boundValue);
    }

    scope.insert(scopeKey, boundValue);
}
}
36 changes: 24 additions & 12 deletions src/compiler/CodeGen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
#include "../parser/ast/UnaryOperationNode.hpp"
#include "../parser/ast/LiteralNode.hpp"
#include "../parser/ast/SourceNode.hpp"
#include "compiler/SymbolTableStack.hpp"
#include "parser/ast/ASTNodeList.hpp"
#include "parser/ast/CapsuleNode.hpp"
#include "parser/ast/FunctionDeclarationNode.hpp"
#include "parser/ast/IdentifierNode.hpp"
#include "parser/ast/ReturnNode.hpp"
#include "parser/ast/TypeDeclarationNode.hpp"
#include "parser/ast/FunctionInvocationNode.hpp"
#include <binaryen-c.h>

using namespace std;
Expand All @@ -19,21 +23,29 @@ namespace Theta {
public:
// using GenerateResult = std::variant<BinaryenExpressionRef, BinaryenLiteral, int>;

static BinaryenModuleRef generateWasmFromAST(shared_ptr<ASTNode> ast);
static BinaryenExpressionRef generate(shared_ptr<ASTNode> node, BinaryenModuleRef &module);
static BinaryenExpressionRef generateCapsule(shared_ptr<CapsuleNode> node, BinaryenModuleRef &module);
static BinaryenExpressionRef generateBlock(shared_ptr<ASTNodeList> node, BinaryenModuleRef &module);
static BinaryenExpressionRef generateReturn(shared_ptr<ReturnNode> node, BinaryenModuleRef &module);
static BinaryenExpressionRef generateBinaryOperation(shared_ptr<BinaryOperationNode> node, BinaryenModuleRef &module);
static BinaryenExpressionRef generateUnaryOperation(shared_ptr<UnaryOperationNode> node, BinaryenModuleRef &module);
static BinaryenExpressionRef generateNumberLiteral(shared_ptr<LiteralNode> node, BinaryenModuleRef &module);
static BinaryenExpressionRef generateStringLiteral(shared_ptr<LiteralNode> node, BinaryenModuleRef &module);
static BinaryenExpressionRef generateBooleanLiteral(shared_ptr<LiteralNode> node, BinaryenModuleRef &module);
static BinaryenExpressionRef generateExponentOperation(shared_ptr<BinaryOperationNode> node, BinaryenModuleRef &module);
static void generateSource(shared_ptr<SourceNode> node, BinaryenModuleRef &module);
BinaryenModuleRef generateWasmFromAST(shared_ptr<ASTNode> ast);
BinaryenExpressionRef generate(shared_ptr<ASTNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateCapsule(shared_ptr<CapsuleNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateBlock(shared_ptr<ASTNodeList> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateReturn(shared_ptr<ReturnNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateFunctionDeclaration(string identifier, shared_ptr<FunctionDeclarationNode> node, BinaryenModuleRef &module, bool addToExports = false);
BinaryenExpressionRef generateFunctionInvocation(shared_ptr<FunctionInvocationNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateIdentifier(shared_ptr<IdentifierNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateBinaryOperation(shared_ptr<BinaryOperationNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateUnaryOperation(shared_ptr<UnaryOperationNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateNumberLiteral(shared_ptr<LiteralNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateStringLiteral(shared_ptr<LiteralNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateBooleanLiteral(shared_ptr<LiteralNode> node, BinaryenModuleRef &module);
BinaryenExpressionRef generateExponentOperation(shared_ptr<BinaryOperationNode> node, BinaryenModuleRef &module);
void generateSource(shared_ptr<SourceNode> node, BinaryenModuleRef &module);

private:
SymbolTableStack scope;

static BinaryenOp getBinaryenOpFromBinOpNode(shared_ptr<BinaryOperationNode> node);
static BinaryenType getBinaryenTypeFromTypeDeclaration(shared_ptr<TypeDeclarationNode> node);

void hoistCapsuleElements(vector<shared_ptr<ASTNode>> elements);
void bindIdentifierToScope(shared_ptr<ASTNode> ast);
};
}
32 changes: 30 additions & 2 deletions src/compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace Theta {

if (!isTypeValid) return;

BinaryenModuleRef module = CodeGen::generateWasmFromAST(programAST);
CodeGen codeGen;
BinaryenModuleRef module = codeGen.generateWasmFromAST(programAST);

if (isEmitWAT) {
cout << "Generated WAT for \"" + entrypoint + "\":" << endl;
Expand All @@ -57,7 +58,8 @@ namespace Theta {

if (!isTypeValid) return ast;

BinaryenModuleRef module = CodeGen::generateWasmFromAST(ast);
CodeGen codeGen;
BinaryenModuleRef module = codeGen.generateWasmFromAST(ast);

cout << "-> " + ast->toJSON() << endl;
cout << "-> ";
Expand Down Expand Up @@ -214,4 +216,30 @@ namespace Theta {
cout << "Could not parse AST for file " + fileName << endl;
}
}

/**
 * Builds an identifier of the form <name><parameter count><type of each parameter...>.
 *
 * For a function declaration node the types come from each parameter's value; for any
 * other node (treated as a function invocation) they come from each argument's
 * resolved type.
 */
string Compiler::getQualifiedFunctionIdentifier(string variableName, shared_ptr<ASTNode> node) {
    const bool isDeclaration = node->getNodeType() == ASTNode::FUNCTION_DECLARATION;

    vector<shared_ptr<ASTNode>> params;
    if (isDeclaration) {
        params = dynamic_pointer_cast<FunctionDeclarationNode>(node)->getParameters()->getElements();
    } else {
        params = dynamic_pointer_cast<FunctionInvocationNode>(node)->getParameters()->getElements();
    }

    string functionIdentifier = variableName + to_string(params.size());

    for (const auto &param : params) {
        shared_ptr<TypeDeclarationNode> paramType;

        if (isDeclaration) {
            paramType = dynamic_pointer_cast<TypeDeclarationNode>(param->getValue());
        } else {
            paramType = dynamic_pointer_cast<TypeDeclarationNode>(param->getResolvedType());
        }

        functionIdentifier += paramType->getType();
    }

    return functionIdentifier;
}
}
10 changes: 10 additions & 0 deletions src/compiler/Compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ namespace Theta {
*/
bool optimizeAST(shared_ptr<ASTNode> &ast, bool silenceErrors = false);


/**
* @brief Generates a unique function identifier based on the function's name, its parameter count, and its parameter types, to handle overloading.
*
* @param variableName The base name of the function.
* @param node A function declaration node or a function invocation node. For a declaration the declared parameter types are used; for an invocation the resolved types of the arguments are used.
* @return string The unique identifier for the function.
*/
static string getQualifiedFunctionIdentifier(string variableName, shared_ptr<ASTNode> node);

shared_ptr<map<string, string>> filesByCapsuleName;
private:
/**
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/optimization/LiteralInlinerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void LiteralInlinerPass::bindIdentifierToScope(shared_ptr<ASTNode> &ast, SymbolT
string identifier = dynamic_pointer_cast<IdentifierNode>(ast->getLeft())->getIdentifier();

if (ast->getRight()->getNodeType() == ASTNode::FUNCTION_DECLARATION) {
string uniqueFuncIdentifier = getDeterministicFunctionIdentifier(identifier, ast->getRight());
string uniqueFuncIdentifier = Compiler::getQualifiedFunctionIdentifier(identifier, ast->getRight());

shared_ptr<ASTNode> existingFuncIdentifierInScope = scope.lookup(uniqueFuncIdentifier);

Expand Down
27 changes: 0 additions & 27 deletions src/compiler/optimization/OptimizationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,30 +88,3 @@ shared_ptr<ASTNode> OptimizationPass::lookupInScope(string identifierName) {

return foindHoisted;
}


/**
 * Builds an identifier of the form <name><parameter count><type of each parameter...>.
 *
 * Declaration nodes contribute the type held in each parameter's value; every other
 * node is treated as an invocation and contributes each argument's resolved type.
 */
string OptimizationPass::getDeterministicFunctionIdentifier(string variableName, shared_ptr<ASTNode> node) {
    const bool fromDeclaration = node->getNodeType() == ASTNode::FUNCTION_DECLARATION;

    vector<shared_ptr<ASTNode>> params;
    if (fromDeclaration) {
        params = dynamic_pointer_cast<FunctionDeclarationNode>(node)->getParameters()->getElements();
    } else {
        params = dynamic_pointer_cast<FunctionInvocationNode>(node)->getParameters()->getElements();
    }

    string functionIdentifier = variableName + to_string(params.size());

    for (size_t idx = 0; idx < params.size(); ++idx) {
        shared_ptr<TypeDeclarationNode> paramType;

        if (fromDeclaration) {
            paramType = dynamic_pointer_cast<TypeDeclarationNode>(params.at(idx)->getValue());
        } else {
            paramType = dynamic_pointer_cast<TypeDeclarationNode>(params.at(idx)->getResolvedType());
        }

        functionIdentifier += paramType->getType();
    }

    return functionIdentifier;
}
9 changes: 0 additions & 9 deletions src/compiler/optimization/OptimizationPass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,6 @@ namespace Theta {
*/
shared_ptr<ASTNode> lookupInScope(string identifier);

/**
* @brief Generates a unique function identifier based on the function's name, its parameter count, and its parameter types, to handle overloading.
*
* @param variableName The base name of the function.
* @param node A function declaration node or a function invocation node. For a declaration the declared parameter types are used; for an invocation the resolved types of the arguments are used.
* @return string The unique identifier for the function.
*/
string getDeterministicFunctionIdentifier(string variableName, shared_ptr<ASTNode> node);

private:
/**
* @brief Pure virtual function to be implemented by derived classes for performing specific optimizations on the AST.
Expand Down
Loading

0 comments on commit e534590

Please sign in to comment.