diff --git a/jac/jaclang/compiler/absyntree.py b/jac/jaclang/compiler/absyntree.py index 9b91ad168a..e25ba9d81c 100644 --- a/jac/jaclang/compiler/absyntree.py +++ b/jac/jaclang/compiler/absyntree.py @@ -697,14 +697,16 @@ def unparse(self) -> str: def get_href_path(node: AstNode) -> str: """Return the full path of the module that contains this node.""" parent = node.find_parent_of_type(Module) - mod_list = [] - if isinstance(node, Module): + mod_list: list[Module | Architype] = [] + if isinstance(node, (Module, Architype)): mod_list.append(node) while parent is not None: mod_list.append(parent) parent = parent.find_parent_of_type(Module) mod_list.reverse() - return ".".join(p.name for p in mod_list) + return ".".join( + p.name if isinstance(p, Module) else p.name.sym_name for p in mod_list + ) class GlobalVars(ElementStmt, AstAccessNode): diff --git a/jac/jaclang/compiler/passes/main/inheritance_pass.py b/jac/jaclang/compiler/passes/main/inheritance_pass.py new file mode 100644 index 0000000000..8f4ba95877 --- /dev/null +++ b/jac/jaclang/compiler/passes/main/inheritance_pass.py @@ -0,0 +1,103 @@ +"""Pass used to add the inherited symbols for architypes.""" + +from __future__ import annotations + +from typing import Optional + +import jaclang.compiler.absyntree as ast +from jaclang.compiler.passes import Pass +from jaclang.compiler.symtable import Symbol, SymbolTable +from jaclang.settings import settings + + +class InheritancePass(Pass): + """Add inherited abilities in the target symbol tables.""" + + def __debug_print(self, msg: str) -> None: + if settings.inherit_pass_debug: + self.log_info("[PyImportPass] " + msg) + + def __lookup(self, name: str, sym_table: SymbolTable) -> Optional[Symbol]: + symbol = sym_table.lookup(name) + if symbol is None: + # Check if the needed symbol in builtins + builtins_symtable = self.ir.sym_tab.find_scope("builtins") + assert builtins_symtable is not None + symbol = builtins_symtable.lookup(name) + return symbol + + def enter_architype(self, node: ast.Architype) -> None: + """Fill architype symbol tables with abilities from parent architypes.""" + if node.base_classes is None: + return + + for item in node.base_classes.items: + # The assumption is that the base class can only be a name node + # or an atom trailer only. + assert isinstance(item, (ast.Name, ast.AtomTrailer)) + + # In case of name node, then get the symbol table that contains + # the current class and lookup for that name after that use the + # symbol to get the symbol table of the base class + if isinstance(item, ast.Name): + assert node.sym_tab.parent is not None + base_class_symbol = self.__lookup(item.sym_name, node.sym_tab.parent) + if base_class_symbol is None: + msg = "Missing symbol for base class " + msg += f"{ast.Module.get_href_path(item)}.{item.sym_name}" + msg += f" needed for {ast.Module.get_href_path(node)}" + self.__debug_print(msg) + continue + base_class_symbol_table = base_class_symbol.fetch_sym_tab + if ( + base_class_symbol_table is None + and base_class_symbol.defn[0] + .parent_of_type(ast.Module) + .py_info.is_raised_from_py + ): + msg = "Missing symbol table for python base class " + msg += f"{ast.Module.get_href_path(item)}.{item.sym_name}" + msg += f" needed for {ast.Module.get_href_path(node)}" + self.__debug_print(msg) + continue + assert base_class_symbol_table is not None + node.sym_tab.inherit_sym_tab(base_class_symbol_table) + + # In case of atom trailer, unwind it and use each name node to + # as the code above to lookup for the base class + elif isinstance(item, ast.AtomTrailer): + current_sym_table = node.sym_tab.parent + not_found: bool = False + assert current_sym_table is not None + for name in item.as_attr_list: + sym = self.__lookup(name.sym_name, current_sym_table) + if sym is None: + msg = "Missing symbol for base class " + msg += f"{ast.Module.get_href_path(name)}.{name.sym_name}" + msg += f" needed for {ast.Module.get_href_path(node)}" + self.__debug_print(msg) + not_found = True + break + current_sym_table = sym.fetch_sym_tab + + # In case of python nodes, the base class may not be + # raised so ignore these classes for now + # TODO Do we need to import these classes? + if ( + sym.defn[0].parent_of_type(ast.Module).py_info.is_raised_from_py + and current_sym_table is None + ): + msg = "Missing symbol table for python base class " + msg += f"{ast.Module.get_href_path(name)}.{name.sym_name}" + msg += f" needed for {ast.Module.get_href_path(node)}" + self.__debug_print(msg) + not_found = True + break + + assert current_sym_table is not None + + if not_found: + continue + + assert current_sym_table is not None + node.sym_tab.inherit_sym_tab(current_sym_table) diff --git a/jac/jaclang/compiler/passes/main/schedules.py b/jac/jaclang/compiler/passes/main/schedules.py index 0ac604fa9b..1274e43d89 100644 --- a/jac/jaclang/compiler/passes/main/schedules.py +++ b/jac/jaclang/compiler/passes/main/schedules.py @@ -20,6 +20,7 @@ from .registry_pass import RegistryPass # noqa: I100 from .access_modifier_pass import AccessCheckPass # noqa: I100 from .py_collect_dep_pass import PyCollectDepsPass # noqa: I100 +from .inheritance_pass import InheritancePass # noqa: I100 py_code_gen = [ SubNodeTabPass, @@ -38,6 +39,7 @@ PyCollectDepsPass, PyImportPass, DefUsePass, + InheritancePass, FuseTypeInfoPass, AccessCheckPass, ] diff --git a/jac/jaclang/langserve/tests/test_server.py b/jac/jaclang/langserve/tests/test_server.py index 375b8317c5..82c776ff87 100644 --- a/jac/jaclang/langserve/tests/test_server.py +++ b/jac/jaclang/langserve/tests/test_server.py @@ -330,7 +330,7 @@ def test_completion(self) -> None: "doubleinner", "apply_red", ], - 8, + 11, ), ( lspt.Position(65, 23), @@ -359,7 +359,7 @@ def test_completion(self) -> None: "doubleinner", "apply_red", ], - 8, + 11, ), ( lspt.Position(73, 22), diff --git a/jac/jaclang/settings.py b/jac/jaclang/settings.py index 2c1b1291d7..431a4218c3 100644 --- a/jac/jaclang/settings.py +++ b/jac/jaclang/settings.py @@ -17,6 +17,7 @@ class Settings: collect_py_dep_debug: bool = False print_py_raised_ast: bool = False py_import_pass_debug: bool = False + inherit_pass_debug: bool = False # Compiler configuration disable_mtllm: bool = False diff --git a/jac/jaclang/tests/fixtures/base_class1.jac b/jac/jaclang/tests/fixtures/base_class1.jac new file mode 100644 index 0000000000..288f57e8c1 --- /dev/null +++ b/jac/jaclang/tests/fixtures/base_class1.jac @@ -0,0 +1,11 @@ +import:py test_py; + +class B :test_py.A: {} + +with entry { + a = test_py.A(); + b = B(); + + a.start(); + b.start(); +} diff --git a/jac/jaclang/tests/fixtures/base_class2.jac b/jac/jaclang/tests/fixtures/base_class2.jac new file mode 100644 index 0000000000..110d822906 --- /dev/null +++ b/jac/jaclang/tests/fixtures/base_class2.jac @@ -0,0 +1,11 @@ +import:py from test_py { A } + +class B :A: {} + +with entry { + a = A(); + b = B(); + + a.start(); + b.start(); +} diff --git a/jac/jaclang/tests/fixtures/test_py.py b/jac/jaclang/tests/fixtures/test_py.py new file mode 100644 index 0000000000..ba3e4f804e --- /dev/null +++ b/jac/jaclang/tests/fixtures/test_py.py @@ -0,0 +1,12 @@ +"""Test file for subclass issue.""" + +p = 5 +g = 6 + + +class A: + """Dummy class to test the base class issue.""" + + def start(self) -> int: + """Return 0.""" + return 0 diff --git a/jac/jaclang/tests/test_cli.py b/jac/jaclang/tests/test_cli.py index 8e20720a7d..0a97391118 100644 --- a/jac/jaclang/tests/test_cli.py +++ b/jac/jaclang/tests/test_cli.py @@ -233,6 +233,44 @@ def test_builtins_loading(self) -> None: r"13\:12 \- 13\:18.*Name - append - .*SymbolPath: builtins_test.builtins.list.append", ) + def test_sub_class_symbol_table_fix_1(self) -> None: + """Testing for print AstTool.""" + from jaclang.settings import settings + + settings.ast_symbol_info_detailed = True + captured_output = io.StringIO() + sys.stdout = captured_output + + cli.tool("ir", ["ast", f"{self.fixture_abs_path('base_class1.jac')}"]) + + sys.stdout = sys.__stdout__ + stdout_value = captured_output.getvalue() + settings.ast_symbol_info_detailed = False + + self.assertRegex( + stdout_value, + r"10:7 - 10:12.*Name - start - Type.*SymbolPath: base_class1.B.start", + ) + + def test_sub_class_symbol_table_fix_2(self) -> None: + """Testing for print AstTool.""" + from jaclang.settings import settings + + settings.ast_symbol_info_detailed = True + captured_output = io.StringIO() + sys.stdout = captured_output + + cli.tool("ir", ["ast", f"{self.fixture_abs_path('base_class2.jac')}"]) + + sys.stdout = sys.__stdout__ + stdout_value = captured_output.getvalue() + settings.ast_symbol_info_detailed = False + + self.assertRegex( + stdout_value, + r"10:7 - 10:12.*Name - start - Type.*SymbolPath: base_class2.B.start", + ) + def test_expr_types(self) -> None: """Testing for print AstTool.""" captured_output = io.StringIO()