Skip to content

Commit

Permalink
Merge pull request #1488 from Jaseci-Labs/sub_class_symtab_fix
Browse files Browse the repository at this point in the history
[Issue #1345 fix] Sub class symtab fix
  • Loading branch information
marsninja authored Jan 6, 2025
2 parents d049a8c + 2c2cd14 commit 577af09
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 5 deletions.
8 changes: 5 additions & 3 deletions jac/jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,14 +697,16 @@ def unparse(self) -> str:
def get_href_path(node: AstNode) -> str:
"""Return the full path of the module that contains this node."""
parent = node.find_parent_of_type(Module)
mod_list = []
if isinstance(node, Module):
mod_list: list[Module | Architype] = []
if isinstance(node, (Module, Architype)):
mod_list.append(node)
while parent is not None:
mod_list.append(parent)
parent = parent.find_parent_of_type(Module)
mod_list.reverse()
return ".".join(p.name for p in mod_list)
return ".".join(
p.name if isinstance(p, Module) else p.name.sym_name for p in mod_list
)


class GlobalVars(ElementStmt, AstAccessNode):
Expand Down
103 changes: 103 additions & 0 deletions jac/jaclang/compiler/passes/main/inheritance_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Pass used to add the inherited symbols for architypes."""

from __future__ import annotations

from typing import Optional

import jaclang.compiler.absyntree as ast
from jaclang.compiler.passes import Pass
from jaclang.compiler.symtable import Symbol, SymbolTable
from jaclang.settings import settings


class InheritancePass(Pass):
"""Add inherited abilities in the target symbol tables."""

def __debug_print(self, msg: str) -> None:
if settings.inherit_pass_debug:
self.log_info("[PyImportPass] " + msg)

def __lookup(self, name: str, sym_table: SymbolTable) -> Optional[Symbol]:
symbol = sym_table.lookup(name)
if symbol is None:
# Check if the needed symbol in builtins
builtins_symtable = self.ir.sym_tab.find_scope("builtins")
assert builtins_symtable is not None
symbol = builtins_symtable.lookup(name)
return symbol

def enter_architype(self, node: ast.Architype) -> None:
"""Fill architype symbol tables with abilities from parent architypes."""
if node.base_classes is None:
return

for item in node.base_classes.items:
# The assumption is that the base class can only be a name node
# or an atom trailer only.
assert isinstance(item, (ast.Name, ast.AtomTrailer))

# In case of name node, then get the symbol table that contains
# the current class and lookup for that name after that use the
# symbol to get the symbol table of the base class
if isinstance(item, ast.Name):
assert node.sym_tab.parent is not None
base_class_symbol = self.__lookup(item.sym_name, node.sym_tab.parent)
if base_class_symbol is None:
msg = "Missing symbol for base class "
msg += f"{ast.Module.get_href_path(item)}.{item.sym_name}"
msg += f" needed for {ast.Module.get_href_path(node)}"
self.__debug_print(msg)
continue
base_class_symbol_table = base_class_symbol.fetch_sym_tab
if (
base_class_symbol_table is None
and base_class_symbol.defn[0]
.parent_of_type(ast.Module)
.py_info.is_raised_from_py
):
msg = "Missing symbol table for python base class "
msg += f"{ast.Module.get_href_path(item)}.{item.sym_name}"
msg += f" needed for {ast.Module.get_href_path(node)}"
self.__debug_print(msg)
continue
assert base_class_symbol_table is not None
node.sym_tab.inherit_sym_tab(base_class_symbol_table)

# In case of atom trailer, unwind it and use each name node to
# as the code above to lookup for the base class
elif isinstance(item, ast.AtomTrailer):
current_sym_table = node.sym_tab.parent
not_found: bool = False
assert current_sym_table is not None
for name in item.as_attr_list:
sym = self.__lookup(name.sym_name, current_sym_table)
if sym is None:
msg = "Missing symbol for base class "
msg += f"{ast.Module.get_href_path(name)}.{name.sym_name}"
msg += f" needed for {ast.Module.get_href_path(node)}"
self.__debug_print(msg)
not_found = True
break
current_sym_table = sym.fetch_sym_tab

# In case of python nodes, the base class may not be
# raised so ignore these classes for now
# TODO Do we need to import these classes?
if (
sym.defn[0].parent_of_type(ast.Module).py_info.is_raised_from_py
and current_sym_table is None
):
msg = "Missing symbol table for python base class "
msg += f"{ast.Module.get_href_path(name)}.{name.sym_name}"
msg += f" needed for {ast.Module.get_href_path(node)}"
self.__debug_print(msg)
not_found = True
break

assert current_sym_table is not None

if not_found:
continue

assert current_sym_table is not None
node.sym_tab.inherit_sym_tab(current_sym_table)
2 changes: 2 additions & 0 deletions jac/jaclang/compiler/passes/main/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .registry_pass import RegistryPass # noqa: I100
from .access_modifier_pass import AccessCheckPass # noqa: I100
from .py_collect_dep_pass import PyCollectDepsPass # noqa: I100
from .inheritance_pass import InheritancePass # noqa: I100

py_code_gen = [
SubNodeTabPass,
Expand All @@ -38,6 +39,7 @@
PyCollectDepsPass,
PyImportPass,
DefUsePass,
InheritancePass,
FuseTypeInfoPass,
AccessCheckPass,
]
Expand Down
4 changes: 2 additions & 2 deletions jac/jaclang/langserve/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def test_completion(self) -> None:
"doubleinner",
"apply_red",
],
8,
11,
),
(
lspt.Position(65, 23),
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_completion(self) -> None:
"doubleinner",
"apply_red",
],
8,
11,
),
(
lspt.Position(73, 22),
Expand Down
1 change: 1 addition & 0 deletions jac/jaclang/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Settings:
collect_py_dep_debug: bool = False
print_py_raised_ast: bool = False
py_import_pass_debug: bool = False
inherit_pass_debug: bool = False

# Compiler configuration
disable_mtllm: bool = False
Expand Down
11 changes: 11 additions & 0 deletions jac/jaclang/tests/fixtures/base_class1.jac
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import:py test_py;

class B :test_py.A: {}

with entry {
a = test_py.A();
b = B();

a.start();
b.start();
}
11 changes: 11 additions & 0 deletions jac/jaclang/tests/fixtures/base_class2.jac
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import:py from test_py { A }

class B :A: {}

with entry {
a = A();
b = B();

a.start();
b.start();
}
12 changes: 12 additions & 0 deletions jac/jaclang/tests/fixtures/test_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Test file for subclass issue."""

p = 5
g = 6


class A:
"""Dummy class to test the base class issue."""

def start(self) -> int:
"""Return 0."""
return 0
38 changes: 38 additions & 0 deletions jac/jaclang/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,44 @@ def test_builtins_loading(self) -> None:
r"13\:12 \- 13\:18.*Name - append - .*SymbolPath: builtins_test.builtins.list.append",
)

def test_sub_class_symbol_table_fix_1(self) -> None:
"""Testing for print AstTool."""
from jaclang.settings import settings

settings.ast_symbol_info_detailed = True
captured_output = io.StringIO()
sys.stdout = captured_output

cli.tool("ir", ["ast", f"{self.fixture_abs_path('base_class1.jac')}"])

sys.stdout = sys.__stdout__
stdout_value = captured_output.getvalue()
settings.ast_symbol_info_detailed = False

self.assertRegex(
stdout_value,
r"10:7 - 10:12.*Name - start - Type.*SymbolPath: base_class1.B.start",
)

def test_sub_class_symbol_table_fix_2(self) -> None:
"""Testing for print AstTool."""
from jaclang.settings import settings

settings.ast_symbol_info_detailed = True
captured_output = io.StringIO()
sys.stdout = captured_output

cli.tool("ir", ["ast", f"{self.fixture_abs_path('base_class2.jac')}"])

sys.stdout = sys.__stdout__
stdout_value = captured_output.getvalue()
settings.ast_symbol_info_detailed = False

self.assertRegex(
stdout_value,
r"10:7 - 10:12.*Name - start - Type.*SymbolPath: base_class2.B.start",
)

def test_expr_types(self) -> None:
"""Testing for print AstTool."""
captured_output = io.StringIO()
Expand Down

0 comments on commit 577af09

Please sign in to comment.