Skip to content

Commit dd1f8d3

Browse files
committed
fix: resolvedType of binaryOperation should be dependent upon whether the operator returns a boolean or not
1 parent a65dba3 commit dd1f8d3

File tree

5 files changed

+132
-16
lines changed

5 files changed

+132
-16
lines changed

src/compiler/CodeGen.cpp

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1+
#include <iostream>
2+
#include <memory>
3+
#include "binaryen-c.h"
4+
#include "lexer/Lexemes.hpp"
5+
#include "StandardLibrary.hpp"
6+
#include "parser/ast/ASTNodeList.hpp"
7+
#include "parser/ast/FunctionDeclarationNode.hpp"
8+
#include "parser/ast/IdentifierNode.hpp"
9+
#include "parser/ast/TypeDeclarationNode.hpp"
110
#include "CodeGen.hpp"
11+
#include "DataTypes.hpp"
212

313
namespace Theta {
414
BinaryenModuleRef CodeGen::generateWasmFromAST(shared_ptr<ASTNode> ast) {
@@ -10,12 +20,20 @@ namespace Theta {
1020

1121
generate(ast, module);
1222

23+
BinaryenModuleAutoDrop(module);
24+
1325
return module;
1426
}
1527

1628
BinaryenExpressionRef CodeGen::generate(shared_ptr<ASTNode> node, BinaryenModuleRef &module) {
1729
if (node->getNodeType() == ASTNode::SOURCE) {
1830
generateSource(dynamic_pointer_cast<SourceNode>(node), module);
31+
} else if (node->getNodeType() == ASTNode::CAPSULE) {
32+
return generateCapsule(dynamic_pointer_cast<CapsuleNode>(node), module);
33+
} else if (node->getNodeType() == ASTNode::BLOCK) {
34+
return generateBlock(dynamic_pointer_cast<ASTNodeList>(node), module);
35+
} else if (node->getNodeType() == ASTNode::RETURN) {
36+
return generateReturn(dynamic_pointer_cast<ReturnNode>(node), module);
1937
} else if (node->getNodeType() == ASTNode::BINARY_OPERATION) {
2038
return generateBinaryOperation(dynamic_pointer_cast<BinaryOperationNode>(node), module);
2139
} else if (node->getNodeType() == ASTNode::UNARY_OPERATION) {
@@ -31,6 +49,60 @@ namespace Theta {
3149
return nullptr;
3250
}
3351

52+
BinaryenExpressionRef CodeGen::generateCapsule(shared_ptr<CapsuleNode> capsuleNode, BinaryenModuleRef &module) {
53+
vector<shared_ptr<ASTNode>> capsuleElements = dynamic_pointer_cast<ASTNodeList>(capsuleNode->getValue())->getElements();
54+
55+
for (auto elem : capsuleElements) {
56+
string elemType = dynamic_pointer_cast<TypeDeclarationNode>(elem->getResolvedType())->getType();
57+
if (elem->getNodeType() == ASTNode::ASSIGNMENT) {
58+
shared_ptr<IdentifierNode> identNode = dynamic_pointer_cast<IdentifierNode>(elem->getLeft());
59+
60+
if (elemType == DataTypes::FUNCTION) {
61+
shared_ptr<FunctionDeclarationNode> fnDeclNode = dynamic_pointer_cast<FunctionDeclarationNode>(elem->getRight());
62+
63+
64+
string functionName = capsuleNode->getName() + "." + identNode->getIdentifier();
65+
66+
cout << "it is" << functionName.c_str() << " " << functionName.c_str() << endl;
67+
68+
BinaryenExpressionRef body = generate(fnDeclNode->getDefinition(), module);
69+
70+
BinaryenFunctionRef fn = BinaryenAddFunction(
71+
module,
72+
functionName.c_str(),
73+
BinaryenTypeNone(),
74+
getBinaryenTypeFromTypeDeclaration(dynamic_pointer_cast<TypeDeclarationNode>(fnDeclNode->getResolvedType()->getValue())),
75+
NULL,
76+
0,
77+
body
78+
);
79+
80+
BinaryenAddFunctionExport(module, functionName.c_str(), functionName.c_str());
81+
}
82+
}
83+
}
84+
}
85+
86+
BinaryenExpressionRef CodeGen::generateBlock(shared_ptr<ASTNodeList> blockNode, BinaryenModuleRef &module) {
87+
BinaryenExpressionRef* blockExpressions = new BinaryenExpressionRef[blockNode->getElements().size()];
88+
89+
for (int i = 0; i < blockNode->getElements().size(); i++) {
90+
blockExpressions[i] = generate(blockNode->getElements().at(i), module);
91+
}
92+
93+
return BinaryenBlock(
94+
module,
95+
NULL,
96+
blockExpressions,
97+
blockNode->getElements().size(),
98+
BinaryenTypeNone()
99+
);
100+
}
101+
102+
BinaryenExpressionRef CodeGen::generateReturn(shared_ptr<ReturnNode> returnNode, BinaryenModuleRef &module) {
103+
return BinaryenReturn(module, generate(returnNode->getValue(), module));
104+
}
105+
34106
BinaryenExpressionRef CodeGen::generateBinaryOperation(shared_ptr<BinaryOperationNode> binOpNode, BinaryenModuleRef &module) {
35107
if (binOpNode->getOperator() == Lexemes::EXPONENT) {
36108
return generateExponentOperation(binOpNode, module);
@@ -122,24 +194,24 @@ namespace Theta {
122194
BinaryenExpressionRef body = generate(sourceNode->getValue(), module);
123195

124196
if (!body) {
125-
throw std::runtime_error("Invalid body type for source node");
197+
throw runtime_error("Invalid body type for source node");
126198
}
127199

200+
shared_ptr<TypeDeclarationNode> returnType = dynamic_pointer_cast<TypeDeclarationNode>(sourceNode->getValue()->getResolvedType());
201+
128202
BinaryenFunctionRef mainFn = BinaryenAddFunction(
129203
module,
130204
"main",
131205
BinaryenTypeNone(),
132-
// BinaryenTypeStringref(),
133-
// BinaryenTypeInt64(),
134-
BinaryenTypeInt32(),
206+
getBinaryenTypeFromTypeDeclaration(returnType),
135207
NULL,
136208
0,
137209
body
138210
);
139211

140212
BinaryenAddFunctionExport(module, "main", "main");
141213
} else {
142-
// generate(sourceNode->getValue(), module);
214+
generate(sourceNode->getValue(), module);
143215
}
144216
}
145217

@@ -151,7 +223,24 @@ namespace Theta {
151223
if (op == Lexemes::TIMES) return BinaryenMulInt64();
152224
if (op == Lexemes::MODULO) return BinaryenRemSInt64();
153225

226+
string resolvedType = dynamic_pointer_cast<TypeDeclarationNode>(binOpNode->getLeft()->getResolvedType())->getType();
227+
228+
if (op == Lexemes::EQUALITY && resolvedType == DataTypes::NUMBER) return BinaryenEqInt64();
229+
if (op == Lexemes::EQUALITY && resolvedType == DataTypes::BOOLEAN) return BinaryenEqInt32();
230+
if (op == Lexemes::EQUALITY && resolvedType == DataTypes::STRING) return BinaryenStringEqEqual();
231+
if (op == Lexemes::INEQUALITY && resolvedType == DataTypes::NUMBER) return BinaryenNeInt64();
232+
if (op == Lexemes::INEQUALITY && resolvedType == DataTypes::BOOLEAN) return BinaryenEqInt32();
233+
if (op == Lexemes::INEQUALITY && resolvedType == DataTypes::STRING) return BinaryenStringEqEqual(); // FIXME: This is a stub
234+
if (op == Lexemes::LT && resolvedType == DataTypes::NUMBER) return BinaryenLtSInt64();
235+
if (op == Lexemes::GT && resolvedType == DataTypes::NUMBER) return BinaryenGtSInt64();
236+
154237

155238
// if (op == "**") return
156239
}
240+
241+
BinaryenType CodeGen::getBinaryenTypeFromTypeDeclaration(shared_ptr<TypeDeclarationNode> typeDeclaration) {
242+
if (typeDeclaration->getType() == DataTypes::NUMBER) return BinaryenTypeInt64();
243+
if (typeDeclaration->getType() == DataTypes::STRING) return BinaryenTypeStringref();
244+
if (typeDeclaration->getType() == DataTypes::BOOLEAN) return BinaryenTypeInt32();
245+
}
157246
}

src/compiler/CodeGen.hpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
11
#pragma once
22

3-
#include <vector>
4-
#include <deque>
5-
#include <string>
6-
#include <map>
7-
#include <fstream>
8-
#include <iostream>
93
#include <memory>
10-
#include <filesystem>
114
#include "../parser/ast/ASTNode.hpp"
125
#include "../parser/ast/BinaryOperationNode.hpp"
136
#include "../parser/ast/UnaryOperationNode.hpp"
147
#include "../parser/ast/LiteralNode.hpp"
158
#include "../parser/ast/SourceNode.hpp"
16-
#include "../lexer/Lexemes.hpp"
17-
#include "StandardLibrary.hpp"
9+
#include "parser/ast/ASTNodeList.hpp"
10+
#include "parser/ast/CapsuleNode.hpp"
11+
#include "parser/ast/ReturnNode.hpp"
12+
#include "parser/ast/TypeDeclarationNode.hpp"
1813
#include <binaryen-c.h>
1914

2015
using namespace std;
@@ -26,6 +21,9 @@ namespace Theta {
2621

2722
static BinaryenModuleRef generateWasmFromAST(shared_ptr<ASTNode> ast);
2823
static BinaryenExpressionRef generate(shared_ptr<ASTNode> node, BinaryenModuleRef &module);
24+
static BinaryenExpressionRef generateCapsule(shared_ptr<CapsuleNode> node, BinaryenModuleRef &module);
25+
static BinaryenExpressionRef generateBlock(shared_ptr<ASTNodeList> node, BinaryenModuleRef &module);
26+
static BinaryenExpressionRef generateReturn(shared_ptr<ReturnNode> node, BinaryenModuleRef &module);
2927
static BinaryenExpressionRef generateBinaryOperation(shared_ptr<BinaryOperationNode> node, BinaryenModuleRef &module);
3028
static BinaryenExpressionRef generateUnaryOperation(shared_ptr<UnaryOperationNode> node, BinaryenModuleRef &module);
3129
static BinaryenExpressionRef generateNumberLiteral(shared_ptr<LiteralNode> node, BinaryenModuleRef &module);
@@ -36,5 +34,6 @@ namespace Theta {
3634

3735
private:
3836
static BinaryenOp getBinaryenOpFromBinOpNode(shared_ptr<BinaryOperationNode> node);
37+
static BinaryenType getBinaryenTypeFromTypeDeclaration(shared_ptr<TypeDeclarationNode> node);
3938
};
4039
}

src/compiler/TypeChecker.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "parser/ast/TupleNode.hpp"
2121
#include "parser/ast/TypeDeclarationNode.hpp"
2222
#include "parser/ast/IdentifierNode.hpp"
23+
#include "lexer/Lexemes.hpp"
2324

2425
using namespace std;
2526

@@ -214,7 +215,11 @@ namespace Theta {
214215
return false;
215216
}
216217

217-
node->setResolvedType(node->getLeft()->getResolvedType());
218+
if (isBooleanOperator(node->getOperator())) {
219+
node->setResolvedType(make_shared<TypeDeclarationNode>(DataTypes::BOOLEAN));
220+
} else {
221+
node->setResolvedType(node->getLeft()->getResolvedType());
222+
}
218223

219224
return true;
220225
}
@@ -800,6 +805,21 @@ namespace Theta {
800805
return find(LANGUAGE_DATATYPES.begin(), LANGUAGE_DATATYPES.end(), type) != LANGUAGE_DATATYPES.end();
801806
}
802807

808+
bool TypeChecker::isBooleanOperator(string op) {
809+
array<string, 9> BOOLEAN_OPERATORS = {
810+
Lexemes::EQUALITY,
811+
Lexemes::INEQUALITY,
812+
Lexemes::LT,
813+
Lexemes::LTEQ,
814+
Lexemes::GT,
815+
Lexemes::GTEQ,
816+
Lexemes::AND,
817+
Lexemes::OR
818+
};
819+
820+
return find(BOOLEAN_OPERATORS.begin(), BOOLEAN_OPERATORS.end(), op) != BOOLEAN_OPERATORS.end();
821+
}
822+
803823
shared_ptr<TypeDeclarationNode> TypeChecker::makeVariadicType(vector<shared_ptr<TypeDeclarationNode>> types) {
804824
shared_ptr<TypeDeclarationNode> variadicTypeNode = make_shared<TypeDeclarationNode>(DataTypes::VARIADIC);
805825

src/compiler/TypeChecker.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,15 @@ namespace Theta {
257257
*/
258258
static bool isLanguageDataType(string type);
259259

260+
/**
261+
* @brief Checks if a given operator returns a boolean.
262+
*
263+
* @param op The operator to check.
264+
* @return true If the operator returns a boolean.
265+
* @return false Otherwise.
266+
*/
267+
static bool isBooleanOperator(string op);
268+
260269
/**
261270
* @brief Finds all AST nodes of a specific type within the tree rooted at a given node.
262271
*

src/compiler/optimization/LiteralInlinerPass.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "parser/ast/SymbolNode.hpp"
1010
#include "parser/ast/TypeDeclarationNode.hpp"
1111
#include <memory>
12-
#include <iostream>
1312

1413
using namespace Theta;
1514

0 commit comments

Comments
 (0)