diff --git a/rivetc/src/__init__.py b/rivetc/src/__init__.py index b95f013ee..4a173e1b6 100644 --- a/rivetc/src/__init__.py +++ b/rivetc/src/__init__.py @@ -57,28 +57,61 @@ def __init__(self, args): self.checker = checker.Checker(self) self.codegen = codegen.Codegen(self) + def run(self): + # if we are compiling the `core` module, avoid autoloading it + if self.prefs.mod_name != "core": + self.parsed_files += self.load_module("core", "core", "", token.NO_POS) + + self.load_root_module() + self.import_modules() + + if not self.prefs.check_syntax: + self.vlog("registering symbols...") + self.register.walk_files(self.source_files) + if report.ERRORS > 0: + self.abort() + self.vlog("resolving symbols...") + self.resolver.resolve_files(self.source_files) + if report.ERRORS > 0: + self.abort() + self.vlog("checking files...") + self.checker.check_files(self.source_files) + if report.ERRORS > 0: + self.abort() + if not self.prefs.check: + self.vlog("generating RIR...") + self.codegen.gen_source_files(self.source_files) + if report.ERRORS > 0: + self.abort() + def import_modules(self): for sf in self.parsed_files: - for decl in sf.decls: - if isinstance(decl, ast.ImportDecl): - mod = self.load_module_files( - decl.path, decl.alias, sf.file, decl.pos - ) - if mod.found: - if mod_sym_ := self.universe.find(mod.full_name): - mod_sym = mod_sym_ # module already imported - else: - mod_sym = sym.Mod(False, mod.full_name) - self.universe.add(mod_sym) - self.parsed_files += parser.Parser(self).parse_mod( - mod_sym, mod.files - ) - decl.alias = mod.alias - decl.mod_sym = mod_sym + self.import_modules_from_decls(sf, sf.decls) self.resolve_deps() if report.ERRORS > 0: self.abort() + def import_modules_from_decls(self, sf, decls): + for decl in decls: + if isinstance(decl, ast.ImportDecl): + mod = self.load_module_files( + decl.path, decl.alias, sf.file, decl.pos + ) + if mod.found: + if mod_sym_ := self.universe.find(mod.full_name): + mod_sym = mod_sym_ # module already imported + else: + mod_sym = sym.Mod(False, mod.full_name) + self.universe.add(mod_sym) + self.parsed_files += parser.Parser(self).parse_mod( + mod_sym, mod.files + ) + decl.alias = mod.alias + decl.mod_sym = mod_sym + elif isinstance(decl, ast.ComptimeIf): + ct_decls = self.evalue_comptime_if(decl) + self.import_modules_from_decls(sf, ct_decls) + def resolve_deps(self): g = self.import_graph() g_resolved = g.resolve() @@ -123,33 +156,6 @@ def import_graph(self): g.add(fp.sym.name, deps) return g - def run(self): - # if we are compiling the `core` module, avoid autoloading it - if self.prefs.mod_name != "core": - self.parsed_files += self.load_module("core", "core", "", token.NO_POS) - - self.load_root_module() - self.import_modules() - - if not self.prefs.check_syntax: - self.vlog("registering symbols...") - self.register.walk_files(self.source_files) - if report.ERRORS > 0: - self.abort() - self.vlog("resolving symbols...") - self.resolver.resolve_files(self.source_files) - if report.ERRORS > 0: - self.abort() - self.vlog("checking files...") - self.checker.check_files(self.source_files) - if report.ERRORS > 0: - self.abort() - if not self.prefs.check: - self.vlog("generating RIR...") - self.codegen.gen_source_files(self.source_files) - if report.ERRORS > 0: - self.abort() - def load_root_module(self): if path.isdir(self.prefs.input): files = self.filter_files( @@ -450,7 +456,46 @@ def type_symbol_size(self, sy): sy.align = align return size, align - def evalue_pp_symbol(self, name, pos): + def evalue_comptime_if(self, comptime_if): + if comptime_if.branch_idx != None: + return comptime_if.branches[comptime_if.branch_idx].nodes + for i, branch in enumerate(comptime_if.branches): + if branch.is_else and comptime_if.branch_idx == None: + comptime_if.branch_idx = i + elif cond := self.evalue_comptime_condition(branch.cond): + if cond: + comptime_if.branch_idx = i + if comptime_if.branch_idx != None: + return comptime_if.branches[comptime_if.branch_idx].nodes + return [] + + def evalue_comptime_condition(self, cond): + if isinstance(cond, ast.ParExpr): + return self.evalue_comptime_condition(cond.expr) + elif isinstance(cond, ast.Ident): + return self.evalue_comptime_ident(cond.name, cond.pos) + elif isinstance(cond, ast.UnaryExpr) and cond.op == token.Kind.Bang: + if val := self.evalue_comptime_condition(cond.right): + return not val + else: + return None + elif isinstance(cond, ast.BinaryExpr) and binary.op in [token.Kind.LogicalAnd, token.Kind.LogicalOr]: + left = self.evalue_comptime_condition(binary.left) + if left != None: + if binary.op == token.Kind.LogicalOr and left: + return True + right = self.evalue_comptime_condition(binary.right) + if right != None: + if binary.op == token.Kind.LogicalAnd: + return left and right + return right + return None + return None + else: + report.error("invalid comptime condition", cond.pos) + return None + + def evalue_comptime_ident(self, name, pos): # operating systems if name in ("_LINUX_", "_WINDOWS_"): return self.prefs.target_os.equals_to_string(name) diff --git a/rivetc/src/ast.py b/rivetc/src/ast.py index 25498984a..0a4ba97de 100644 --- a/rivetc/src/ast.py +++ b/rivetc/src/ast.py @@ -30,6 +30,22 @@ def __repr__(self): def __str__(self): return self.__repr__() +class ComptimeIf: + def __init__(self, branches, has_else, pos): + self.branches = branches + self.branch_idx = None + self.has_else = has_else + self.pos = pos + self.typ = None + +class ComptimeIfBranch: + def __init__(self, cond, is_else, nodes, pos): + self.cond=cond + self.is_else=is_else + self.nodes=nodes + self.pos=pos + self.typ=None + # Used in variable decls/stmts and guard exprs class ObjDecl: def __init__(self, is_mut, is_ref, name, has_typ, typ, level, pos): diff --git a/rivetc/src/checker.py b/rivetc/src/checker.py index bb6d5e334..3b5440949 100644 --- a/rivetc/src/checker.py +++ b/rivetc/src/checker.py @@ -71,7 +71,9 @@ def check_decls(self, decls): def check_decl(self, decl): old_sym = self.sym - if isinstance(decl, ast.ExternDecl): + if isinstance(decl, ast.ComptimeIf): + self.check_decls(self.comp.evalue_comptime_if(decl)) + elif isinstance(decl, ast.ExternDecl): self.check_decls(decl.decls) elif isinstance(decl, ast.ConstDecl): if decl.has_typ: @@ -212,7 +214,9 @@ def check_stmts(self, stmts): self.check_stmt(stmt) def check_stmt(self, stmt): - if isinstance(stmt, ast.VarDeclStmt): + if isinstance(stmt, ast.ComptimeIf): + self.check_stmts(self.comp.evalue_comptime_if(stmt)) + elif isinstance(stmt, ast.VarDeclStmt): old_expected_type = self.expected_type if len(stmt.lefts) == 1: if stmt.lefts[0].has_typ: @@ -309,6 +313,9 @@ def check_stmt(self, stmt): def check_expr(self, expr): if isinstance(expr, ast.EmptyExpr): pass # error raised in `Resolver` + elif isinstance(expr, ast.ComptimeIf): + expr.typ = self.check_expr(self.comp.evalue_comptime_if(expr)[0]) + return expr.typ elif isinstance(expr, ast.TypeNode): return expr.typ elif isinstance(expr, ast.AssignExpr): diff --git a/rivetc/src/codegen/__init__.py b/rivetc/src/codegen/__init__.py index 816f5499f..4e08bef1b 100644 --- a/rivetc/src/codegen/__init__.py +++ b/rivetc/src/codegen/__init__.py @@ -250,7 +250,9 @@ def gen_decls(self, decls): def gen_decl(self, decl): self.cur_func_defer_stmts = [] - if isinstance(decl, ast.ExternDecl): + if isinstance(decl, ast.ComptimeIf): + self.gen_decls(self.comp.evalue_comptime_if(decl)) + elif isinstance(decl, ast.ExternDecl): if decl.abi != sym.ABI.Rivet: self.gen_decls(decl.decls) elif isinstance(decl, ast.VarDecl): @@ -389,7 +391,9 @@ def gen_stmts(self, stmts): self.gen_stmt(stmt) def gen_stmt(self, stmt): - if isinstance(stmt, ast.ForStmt): + if isinstance(stmt, ast.ComptimeIf): + self.gen_stmts(self.comp.evalue_comptime_if(stmt)) + elif isinstance(stmt, ast.ForStmt): old_loop_scope = self.loop_scope self.loop_scope = stmt.scope old_while_continue_expr = self.while_continue_expr @@ -673,7 +677,9 @@ def gen_expr_with_cast(self, expected_typ_, expr, custom_tmp = None): return res_expr def gen_expr(self, expr, custom_tmp = None): - if isinstance(expr, ast.ParExpr): + if isinstance(expr, ast.ComptimeIf): + self.gen_expr(self.comp.evalue_comptime_if(expr)[0]) + elif isinstance(expr, ast.ParExpr): return self.gen_expr(expr.expr) elif isinstance(expr, ast.NoneLiteral): return ir.NoneLit(ir.VOID_PTR_T) diff --git a/rivetc/src/lexer.py b/rivetc/src/lexer.py index 6ded30b79..904a40396 100644 --- a/rivetc/src/lexer.py +++ b/rivetc/src/lexer.py @@ -840,5 +840,5 @@ def pp_symbol(self): elif ident == "false": defined = False else: - defined = self.comp.evalue_pp_symbol(ident, pos) + defined = self.comp.evalue_comptime_ident(ident, pos) return defined diff --git a/rivetc/src/parser.py b/rivetc/src/parser.py index 360b0426f..efcf881a9 100644 --- a/rivetc/src/parser.py +++ b/rivetc/src/parser.py @@ -97,6 +97,51 @@ def close_scope(self): self.scope.parent.childrens.append(self.scope) self.scope = self.scope.parent + # ---- comptime ------------------ + def parse_nodes(self, level): + if level == 0: # decl + decls = [] + if self.accept(Kind.Lbrace): + while self.tok.kind != Kind.Rbrace: + decls.append(self.parse_decl()) + self.expect(Kind.Rbrace) + else: + decls.append(self.parse_decl()) + return decls + elif level == 1: # stmts + stmts = [] + if self.accept(Kind.Lbrace): + while self.tok.kind != Kind.Rbrace: + stmts.append(self.parse_stmt()) + self.expect(Kind.Rbrace) + else: + stmts.append(self.parse_stmt()) + return stmts + return [self.parse_expr()] + + def parse_comptime_if(self, level): + branches = [] + has_else = False + pos = self.tok.pos + while self.tok.kind in (Kind.KwIf, Kind.KwElse): + if self.accept(Kind.KwElse) and self.tok.kind != Kind.KwIf: + branches.append( + ast.ComptimeIfBranch( + self.empty_expr(), self.parse_nodes(level), True, + Kind.KwElse + ) + ) + has_else = True + break + self.expect(Kind.KwIf) + cond = self.parse_expr() + branches.append( + ast.ComptimeIfBranch(cond,False, self.parse_nodes(level), cond.pos) + ) + if self.tok.kind != Kind.KwElse: + break + return ast.ComptimeIf(branches, has_else, pos) + # ---- declarations -------------- def parse_doc_comment(self): pos = self.tok.pos @@ -169,7 +214,13 @@ def parse_decl(self): ) is_public = self.is_public() pos = self.tok.pos - if self.accept(Kind.KwImport): + if self.accept(Kind.KwComptime): + if self.tok.kind == Kind.KwIf: + return self.parse_comptime_if(0) + else: + report.error("invalid comptime construction", self.tok.pos) + return + elif self.accept(Kind.KwImport): path = self.parse_import_path() alias = "" import_list = [] @@ -580,7 +631,13 @@ def decl_operator_is_used(self): return False def parse_stmt(self): - if self.accept(Kind.KwWhile): + if self.accept(Kind.KwComptime): + if self.tok.kind == Kind.KwIf: + return self.parse_comptime_if(1) + else: + report.error("invalid comptime construction", self.tok.pos) + return + elif self.accept(Kind.KwWhile): pos = self.prev_tok.pos is_inf = False continue_expr = self.empty_expr() @@ -814,7 +871,13 @@ def parse_unary_expr(self): def parse_primary_expr(self): expr = self.empty_expr() - if self.tok.kind in [ + if self.accept(Kind.KwComptime): + if self.tok.kind == Kind.KwIf: + expr = self.parse_comptime_if(3) + else: + report.error("invalid comptime construction", self.tok.pos) + return + elif self.tok.kind in [ Kind.KwTrue, Kind.KwFalse, Kind.Char, Kind.Number, Kind.String, Kind.KwNone, Kind.KwSelf, Kind.KwSelfTy ]: diff --git a/rivetc/src/register.py b/rivetc/src/register.py index 9fc6424ff..2b26d57ad 100644 --- a/rivetc/src/register.py +++ b/rivetc/src/register.py @@ -27,7 +27,9 @@ def walk_decls(self, decls): for decl in decls: old_abi = self.abi old_sym = self.sym - if isinstance(decl, ast.ImportDecl): + if isinstance(decl, ast.ComptimeIf): + self.walk_decls(self.comp.evalue_comptime_if(decl)) + elif isinstance(decl, ast.ImportDecl): if len(decl.import_list) == 0: if decl.is_public: try: diff --git a/rivetc/src/resolver.py b/rivetc/src/resolver.py index 16fee5ca7..b88499c81 100644 --- a/rivetc/src/resolver.py +++ b/rivetc/src/resolver.py @@ -27,7 +27,9 @@ def resolve_decls(self, decls): for decl in decls: old_sym = self.sym old_self_sym = self.self_sym - if isinstance(decl, ast.ExternDecl): + if isinstance(decl, ast.ComptimeIf): + self.resolve_decls(self.comp.evalue_comptime_if(decl)) + elif isinstance(decl, ast.ExternDecl): self.resolve_decls(decl.decls) elif isinstance(decl, ast.ConstDecl): self.resolve_type(decl.typ) @@ -152,16 +154,20 @@ def resolve_decls(self, decls): if arg.has_def_expr: self.resolve_expr(arg.def_expr) self.resolve_type(decl.ret_typ) - for stmt in decl.stmts: - self.resolve_stmt(stmt) + self.resolve_stmts(decl.stmts) elif isinstance(decl, ast.TestDecl): - for stmt in decl.stmts: - self.resolve_stmt(stmt) + self.resolve_stmts(decl.stmt) self.sym = old_sym self.self_sym = old_self_sym + def resolve_stmts(self, stmts): + for stmt in stmts: + self.resolve_stmt(stmt) + def resolve_stmt(self, stmt): - if isinstance(stmt, ast.VarDeclStmt): + if isinstance(stmt, ast.ComptimeIf): + self.resolve_stmts(self.comp.evalue_comptime_if(stmt)) + elif isinstance(stmt, ast.VarDeclStmt): for v in stmt.lefts: if v.has_typ: self.resolve_type(v.typ) @@ -212,6 +218,8 @@ def resolve_expr(self, expr): if isinstance(expr, ast.EmptyExpr): report.error("empty expression found", expr.pos) report.note("unexpected bug, please, report it") + elif isinstance(expr, ast.ComptimeIf): + self.resolve_expr(self.comp.evalue_comptime_if(expr)[0]) elif isinstance(expr, ast.TypeNode): self.resolve_type(expr.typ) elif isinstance(expr, ast.AssignExpr): @@ -315,8 +323,7 @@ def resolve_expr(self, expr): elif isinstance(expr, ast.ThrowExpr): self.resolve_expr(expr.expr) elif isinstance(expr, ast.Block): - for stmt in expr.stmts: - self.resolve_stmt(stmt) + self.resolve_stmts(expr.stmts) if expr.is_expr: self.resolve_expr(expr.expr) elif isinstance(expr, ast.IfExpr):