Skip to content

Commit

Permalink
bootstrap: support basic comptime if
Browse files Browse the repository at this point in the history
  • Loading branch information
StunxFS committed Dec 25, 2023
1 parent ea38392 commit 454bc2a
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 62 deletions.
133 changes: 89 additions & 44 deletions rivetc/src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,28 +57,61 @@ def __init__(self, args):
self.checker = checker.Checker(self)
self.codegen = codegen.Codegen(self)

def run(self):
# if we are compiling the `core` module, avoid autoloading it
if self.prefs.mod_name != "core":
self.parsed_files += self.load_module("core", "core", "", token.NO_POS)

self.load_root_module()
self.import_modules()

if not self.prefs.check_syntax:
self.vlog("registering symbols...")
self.register.walk_files(self.source_files)
if report.ERRORS > 0:
self.abort()
self.vlog("resolving symbols...")
self.resolver.resolve_files(self.source_files)
if report.ERRORS > 0:
self.abort()
self.vlog("checking files...")
self.checker.check_files(self.source_files)
if report.ERRORS > 0:
self.abort()
if not self.prefs.check:
self.vlog("generating RIR...")
self.codegen.gen_source_files(self.source_files)
if report.ERRORS > 0:
self.abort()

def import_modules(self):
for sf in self.parsed_files:
for decl in sf.decls:
if isinstance(decl, ast.ImportDecl):
mod = self.load_module_files(
decl.path, decl.alias, sf.file, decl.pos
)
if mod.found:
if mod_sym_ := self.universe.find(mod.full_name):
mod_sym = mod_sym_ # module already imported
else:
mod_sym = sym.Mod(False, mod.full_name)
self.universe.add(mod_sym)
self.parsed_files += parser.Parser(self).parse_mod(
mod_sym, mod.files
)
decl.alias = mod.alias
decl.mod_sym = mod_sym
self.import_modules_from_decls(sf, sf.decls)
self.resolve_deps()
if report.ERRORS > 0:
self.abort()

def import_modules_from_decls(self, sf, decls):
for decl in decls:
if isinstance(decl, ast.ImportDecl):
mod = self.load_module_files(
decl.path, decl.alias, sf.file, decl.pos
)
if mod.found:
if mod_sym_ := self.universe.find(mod.full_name):
mod_sym = mod_sym_ # module already imported
else:
mod_sym = sym.Mod(False, mod.full_name)
self.universe.add(mod_sym)
self.parsed_files += parser.Parser(self).parse_mod(
mod_sym, mod.files
)
decl.alias = mod.alias
decl.mod_sym = mod_sym
elif isinstance(decl, ast.ComptimeIf):
ct_decls = self.evalue_comptime_if(decl)
self.import_modules_from_decls(sf, ct_decls)

def resolve_deps(self):
g = self.import_graph()
g_resolved = g.resolve()
Expand Down Expand Up @@ -123,33 +156,6 @@ def import_graph(self):
g.add(fp.sym.name, deps)
return g

def run(self):
# if we are compiling the `core` module, avoid autoloading it
if self.prefs.mod_name != "core":
self.parsed_files += self.load_module("core", "core", "", token.NO_POS)

self.load_root_module()
self.import_modules()

if not self.prefs.check_syntax:
self.vlog("registering symbols...")
self.register.walk_files(self.source_files)
if report.ERRORS > 0:
self.abort()
self.vlog("resolving symbols...")
self.resolver.resolve_files(self.source_files)
if report.ERRORS > 0:
self.abort()
self.vlog("checking files...")
self.checker.check_files(self.source_files)
if report.ERRORS > 0:
self.abort()
if not self.prefs.check:
self.vlog("generating RIR...")
self.codegen.gen_source_files(self.source_files)
if report.ERRORS > 0:
self.abort()

def load_root_module(self):
if path.isdir(self.prefs.input):
files = self.filter_files(
Expand Down Expand Up @@ -450,7 +456,46 @@ def type_symbol_size(self, sy):
sy.align = align
return size, align

def evalue_pp_symbol(self, name, pos):
def evalue_comptime_if(self, comptime_if):
if comptime_if.branch_idx != None:
return comptime_if.branches[comptime_if.branch_idx].nodes
for i, branch in enumerate(comptime_if.branches):
if branch.is_else and comptime_if.branch_idx == None:
comptime_if.branch_idx = i
elif cond := self.evalue_comptime_condition(branch.cond):
if cond:
comptime_if.branch_idx = i
if comptime_if.branch_idx != None:
return comptime_if.branches[comptime_if.branch_idx].nodes
return []

def evalue_comptime_condition(self, cond):
if isinstance(cond, ast.ParExpr):
return self.evalue_comptime_condition(cond.expr)
elif isinstance(cond, ast.Ident):
return self.evalue_comptime_ident(cond.name, cond.pos)
elif isinstance(cond, ast.UnaryExpr) and cond.op == token.Kind.Bang:
if val := self.evalue_comptime_condition(cond.right):
return not val
else:
return None
elif isinstance(cond, ast.BinaryExpr) and binary.op in [token.Kind.LogicalAnd, token.Kind.LogicalOr]:
left = self.evalue_comptime_condition(binary.left)
if left != None:
if binary.op == token.Kind.LogicalOr and left:
return True
right = self.evalue_comptime_condition(binary.right)
if right != None:
if binary.op == token.Kind.LogicalAnd:
return left and right
return right
return None
return None
else:
report.error("invalid comptime condition", cond.pos)
return None

def evalue_comptime_ident(self, name, pos):
# operating systems
if name in ("_LINUX_", "_WINDOWS_"):
return self.prefs.target_os.equals_to_string(name)
Expand Down
16 changes: 16 additions & 0 deletions rivetc/src/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@ def __repr__(self):
def __str__(self):
return self.__repr__()

class ComptimeIf:
def __init__(self, branches, has_else, pos):
self.branches = branches
self.branch_idx = None
self.has_else = has_else
self.pos = pos
self.typ = None

class ComptimeIfBranch:
def __init__(self, cond, is_else, nodes, pos):
self.cond=cond
self.is_else=is_else
self.nodes=nodes
self.pos=pos
self.typ=None

# Used in variable decls/stmts and guard exprs
class ObjDecl:
def __init__(self, is_mut, is_ref, name, has_typ, typ, level, pos):
Expand Down
11 changes: 9 additions & 2 deletions rivetc/src/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def check_decls(self, decls):

def check_decl(self, decl):
old_sym = self.sym
if isinstance(decl, ast.ExternDecl):
if isinstance(decl, ast.ComptimeIf):
self.check_decls(self.comp.evalue_comptime_if(decl))
elif isinstance(decl, ast.ExternDecl):
self.check_decls(decl.decls)
elif isinstance(decl, ast.ConstDecl):
if decl.has_typ:
Expand Down Expand Up @@ -212,7 +214,9 @@ def check_stmts(self, stmts):
self.check_stmt(stmt)

def check_stmt(self, stmt):
if isinstance(stmt, ast.VarDeclStmt):
if isinstance(stmt, ast.ComptimeIf):
self.check_stmts(self.comp.evalue_comptime_if(stmt))
elif isinstance(stmt, ast.VarDeclStmt):
old_expected_type = self.expected_type
if len(stmt.lefts) == 1:
if stmt.lefts[0].has_typ:
Expand Down Expand Up @@ -309,6 +313,9 @@ def check_stmt(self, stmt):
def check_expr(self, expr):
if isinstance(expr, ast.EmptyExpr):
pass # error raised in `Resolver`
elif isinstance(expr, ast.ComptimeIf):
expr.typ = self.check_expr(self.comp.evalue_comptime_if(expr)[0])
return expr.typ
elif isinstance(expr, ast.TypeNode):
return expr.typ
elif isinstance(expr, ast.AssignExpr):
Expand Down
12 changes: 9 additions & 3 deletions rivetc/src/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ def gen_decls(self, decls):

def gen_decl(self, decl):
self.cur_func_defer_stmts = []
if isinstance(decl, ast.ExternDecl):
if isinstance(decl, ast.ComptimeIf):
self.gen_decls(self.comp.evalue_comptime_if(decl))
elif isinstance(decl, ast.ExternDecl):
if decl.abi != sym.ABI.Rivet:
self.gen_decls(decl.decls)
elif isinstance(decl, ast.VarDecl):
Expand Down Expand Up @@ -389,7 +391,9 @@ def gen_stmts(self, stmts):
self.gen_stmt(stmt)

def gen_stmt(self, stmt):
if isinstance(stmt, ast.ForStmt):
if isinstance(stmt, ast.ComptimeIf):
self.gen_stmts(self.comp.evalue_comptime_if(stmt))
elif isinstance(stmt, ast.ForStmt):
old_loop_scope = self.loop_scope
self.loop_scope = stmt.scope
old_while_continue_expr = self.while_continue_expr
Expand Down Expand Up @@ -673,7 +677,9 @@ def gen_expr_with_cast(self, expected_typ_, expr, custom_tmp = None):
return res_expr

def gen_expr(self, expr, custom_tmp = None):
if isinstance(expr, ast.ParExpr):
if isinstance(expr, ast.ComptimeIf):
self.gen_expr(self.comp.evalue_comptime_if(expr)[0])
elif isinstance(expr, ast.ParExpr):
return self.gen_expr(expr.expr)
elif isinstance(expr, ast.NoneLiteral):
return ir.NoneLit(ir.VOID_PTR_T)
Expand Down
2 changes: 1 addition & 1 deletion rivetc/src/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,5 +840,5 @@ def pp_symbol(self):
elif ident == "false":
defined = False
else:
defined = self.comp.evalue_pp_symbol(ident, pos)
defined = self.comp.evalue_comptime_ident(ident, pos)
return defined
69 changes: 66 additions & 3 deletions rivetc/src/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,51 @@ def close_scope(self):
self.scope.parent.childrens.append(self.scope)
self.scope = self.scope.parent

# ---- comptime ------------------
def parse_nodes(self, level):
if level == 0: # decl
decls = []
if self.accept(Kind.Lbrace):
while self.tok.kind != Kind.Rbrace:
decls.append(self.parse_decl())
self.expect(Kind.Rbrace)
else:
decls.append(self.parse_decl())
return decls
elif level == 1: # stmts
stmts = []
if self.accept(Kind.Lbrace):
while self.tok.kind != Kind.Rbrace:
stmts.append(self.parse_stmt())
self.expect(Kind.Rbrace)
else:
stmts.append(self.parse_stmt())
return stmts
return [self.parse_expr()]

def parse_comptime_if(self, level):
branches = []
has_else = False
pos = self.tok.pos
while self.tok.kind in (Kind.KwIf, Kind.KwElse):
if self.accept(Kind.KwElse) and self.tok.kind != Kind.KwIf:
branches.append(
ast.ComptimeIfBranch(
self.empty_expr(), self.parse_nodes(level), True,
Kind.KwElse
)
)
has_else = True
break
self.expect(Kind.KwIf)
cond = self.parse_expr()
branches.append(
ast.ComptimeIfBranch(cond,False, self.parse_nodes(level), cond.pos)
)
if self.tok.kind != Kind.KwElse:
break
return ast.ComptimeIf(branches, has_else, pos)

# ---- declarations --------------
def parse_doc_comment(self):
pos = self.tok.pos
Expand Down Expand Up @@ -169,7 +214,13 @@ def parse_decl(self):
)
is_public = self.is_public()
pos = self.tok.pos
if self.accept(Kind.KwImport):
if self.accept(Kind.KwComptime):
if self.tok.kind == Kind.KwIf:
return self.parse_comptime_if(0)
else:
report.error("invalid comptime construction", self.tok.pos)
return
elif self.accept(Kind.KwImport):
path = self.parse_import_path()
alias = ""
import_list = []
Expand Down Expand Up @@ -580,7 +631,13 @@ def decl_operator_is_used(self):
return False

def parse_stmt(self):
if self.accept(Kind.KwWhile):
if self.accept(Kind.KwComptime):
if self.tok.kind == Kind.KwIf:
return self.parse_comptime_if(1)
else:
report.error("invalid comptime construction", self.tok.pos)
return
elif self.accept(Kind.KwWhile):
pos = self.prev_tok.pos
is_inf = False
continue_expr = self.empty_expr()
Expand Down Expand Up @@ -814,7 +871,13 @@ def parse_unary_expr(self):

def parse_primary_expr(self):
expr = self.empty_expr()
if self.tok.kind in [
if self.accept(Kind.KwComptime):
if self.tok.kind == Kind.KwIf:
expr = self.parse_comptime_if(3)
else:
report.error("invalid comptime construction", self.tok.pos)
return
elif self.tok.kind in [
Kind.KwTrue, Kind.KwFalse, Kind.Char, Kind.Number, Kind.String,
Kind.KwNone, Kind.KwSelf, Kind.KwSelfTy
]:
Expand Down
4 changes: 3 additions & 1 deletion rivetc/src/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def walk_decls(self, decls):
for decl in decls:
old_abi = self.abi
old_sym = self.sym
if isinstance(decl, ast.ImportDecl):
if isinstance(decl, ast.ComptimeIf):
self.walk_decls(self.comp.evalue_comptime_if(decl))
elif isinstance(decl, ast.ImportDecl):
if len(decl.import_list) == 0:
if decl.is_public:
try:
Expand Down
Loading

0 comments on commit 454bc2a

Please sign in to comment.