From 6ac11b97a9e4a97de547aca63cf4b1a151a3c379 Mon Sep 17 00:00:00 2001 From: Jose Mendoza <56417208+StunxFS@users.noreply.github.com> Date: Wed, 27 Dec 2023 17:55:01 +0000 Subject: [PATCH] register: fix bug --- rivetc/src/__init__.py | 20 ++++++++++++++------ rivetc/src/register.py | 14 +++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/rivetc/src/__init__.py b/rivetc/src/__init__.py index e5dee3255..0186f8ae0 100644 --- a/rivetc/src/__init__.py +++ b/rivetc/src/__init__.py @@ -105,6 +105,7 @@ def import_module(self, sf, decl): if len(decl.subimports) > 0: for subimport in decl.subimports: self.import_module(sf, subimport) + subimport.id=id(subimport.mod_sym) return mod = self.load_module_files( decl.path, decl.alias, sf.file, decl.pos @@ -156,14 +157,21 @@ def import_graph(self): deps.append("core") for d in fp.decls: if isinstance(d, ast.ImportDecl): - if not d.mod_sym: - continue # module not found - if d.mod_sym.name == fp.sym.name: - report.error("import cycle detected", d.pos) - continue - deps.append(d.mod_sym.name) + if len(d.subimports) > 0: + for subimport in d.subimports: + self.import_graph_mod(subimport, deps, fp) + else: + self.import_graph_mod(d, deps, fp) g.add(fp.sym.name, deps) return g + + def import_graph_mod(self, d, deps, fp): + if not d.mod_sym: + return # module not found + if d.mod_sym.name == fp.sym.name: + report.error("import cycle detected", d.pos) + return + deps.append(d.mod_sym.name) def load_root_module(self): if path.isdir(self.prefs.input): diff --git a/rivetc/src/register.py b/rivetc/src/register.py index 5935e1e15..ded9a2f08 100644 --- a/rivetc/src/register.py +++ b/rivetc/src/register.py @@ -200,11 +200,17 @@ def walk_decls(self, decls): report.error(e.args[0], decl.name_pos) self.abi = old_abi self.sym = old_sym - + def walk_import_decl(self, decl): if len(decl.subimports) > 0: for subimport in decl.subimports: self.walk_import_decl(subimport) + elif decl.glob: + for symbol in decl.mod_sym.syms: + if not symbol.is_public: + continue + self.check_imported_symbol(symbol, decl.pos) + self.source_file.imported_symbols[symbol.name] = symbol elif len(decl.import_list) == 0: if decl.is_public: try: @@ -218,12 +224,6 @@ def walk_import_decl(self, decl): else: self.source_file.imported_symbols[decl.alias ] = decl.mod_sym - elif decl.glob: - for symbol in decl.mod_sym.syms: - if not symbol.is_public: - continue - self.check_imported_symbol(symbol, decl.pos) - self.source_file.imported_symbols[symbol.name] = symbol else: for import_info in decl.import_list: if import_info.name == "self":