Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Module Import Logic #1482

Merged
merged 10 commits into from
Jan 26, 2025
14 changes: 12 additions & 2 deletions jac/jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,8 +1013,9 @@ def resolve_relative_path(self, target_item: Optional[str] = None) -> str:
target = self.dot_path_str
if target_item:
target += f".{target_item}"
base_path = os.path.dirname(self.loc.mod_path)
base_path = base_path if base_path else os.getcwd()
base_path = (
os.getenv("JACPATH") or os.path.dirname(self.loc.mod_path) or os.getcwd()
)
parts = target.split(".")
traversal_levels = self.level - 1 if self.level > 0 else 0
actual_parts = parts[traversal_levels:]
Expand All @@ -1026,6 +1027,15 @@ def resolve_relative_path(self, target_item: Optional[str] = None) -> str:
if os.path.exists(relative_path + ".jac")
else relative_path
)
jacpath = os.getenv("JACPATH")
if not os.path.exists(relative_path) and jacpath:
name_to_find = actual_parts[-1] + ".jac"

# Walk through the single path in JACPATH
for root, _, files in os.walk(jacpath):
if name_to_find in files:
relative_path = os.path.join(root, name_to_find)
break
return relative_path

def normalize(self, deep: bool = False) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion jac/jaclang/compiler/passes/main/import_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def import_jac_mod_from_file(self, target: str) -> ast.Module | None:
self.warnings_had += mod_pass.warnings_had
mod = mod_pass.ir
except Exception as e:
logger.info(e)
logger.error(e)
mod = None
if isinstance(mod, ast.Module):
self.import_table[target] = mod
Expand Down
44 changes: 44 additions & 0 deletions jac/jaclang/compiler/tests/test_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,47 @@ def test_jac_py_import_auto(self) -> None:
"{SomeObj(a=10): 'check'} [MyObj(apple=5, banana=7), MyObj(apple=5, banana=7)]",
stdout_value,
)

def test_import_with_jacpath(self) -> None:
"""Test module import using JACPATH."""
# Set up a temporary JACPATH environment variable
import os
import tempfile

jacpath_dir = tempfile.TemporaryDirectory()
os.environ["JACPATH"] = jacpath_dir.name

# Create a mock Jac file in the JACPATH directory
module_name = "test_module"
jac_file_path = os.path.join(jacpath_dir.name, f"{module_name}.jac")
with open(jac_file_path, "w") as f:
f.write(
"""
with entry {
"Hello from JACPATH!" :> print;
}
"""
)

# Capture the output
captured_output = io.StringIO()
sys.stdout = captured_output

try:
JacMachine(self.fixture_abs_path(__file__)).attach_program(
JacProgram(mod_bundle=None, bytecode=None, sem_ir=None)
)
jac_import(module_name, base_path=__file__)
cli.run(jac_file_path)

# Reset stdout and get the output
sys.stdout = sys.__stdout__
stdout_value = captured_output.getvalue()

self.assertIn("Hello from JACPATH!", stdout_value)

finally:
captured_output.close()
JacMachine.detach()
os.environ.pop("JACPATH", None)
jacpath_dir.cleanup()
29 changes: 26 additions & 3 deletions jac/jaclang/runtimelib/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,6 @@ def load_jac_mod_as_item(
return getattr(new_module, name, new_module)
except ImportError as e:
logger.error(dump_traceback(e))
# logger.error(
# f"Failed to load {name} from {jac_file_path} in {module.__name__}: {str(e)}"
# )
return None


Expand Down Expand Up @@ -319,6 +316,32 @@ def run_import(
"""Run the import process for Jac modules."""
unique_loaded_items: list[types.ModuleType] = []
module = None
# Gather all possible search paths
jacpaths = os.environ.get("JACPATH", "")
search_paths = [spec.caller_dir]
if jacpaths:
for p in jacpaths.split(os.pathsep):
p = p.strip()
if p and p not in search_paths:
search_paths.append(p)

# Attempt to locate the module file or directory
found_path = None
target_path_components = spec.target.split(".")
for search_path in search_paths:
candidate = os.path.join(search_path, "/".join(target_path_components))
# Check if the candidate is a directory or a .jac file
if (os.path.isdir(candidate)) or (os.path.isfile(candidate + ".jac")):
found_path = candidate
break

# If a suitable path was found, update spec.full_target; otherwise, raise an error
if found_path:
spec.full_target = os.path.abspath(found_path)
else:
raise ImportError(
f"Unable to locate module '{spec.target}' in {search_paths}"
)
if os.path.isfile(spec.full_target + ".jac"):
module_name = self.get_sys_mod_name(spec.full_target + ".jac")
module_name = spec.override_name if spec.override_name else module_name
Expand Down
4 changes: 2 additions & 2 deletions jac/jaclang/runtimelib/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, base_path: str = "") -> None:
self.loaded_modules: dict[str, types.ModuleType] = {}
if not base_path:
base_path = os.getcwd()
# Ensure the base_path is a list rather than a string
self.base_path = base_path
self.base_path_dir = (
os.path.dirname(base_path)
Expand Down Expand Up @@ -306,12 +307,11 @@ def get_bytecode(
return marshal.load(f)

result = compile_jac(full_target, cache_result=cachable)
if result.errors_had or not result.ir.gen.py_bytecode:
if result.errors_had:
for alrt in result.errors_had:
# We're not logging here, it already gets logged as the errors were added to the errors_had list.
# Regardless of the logging, this needs to be sent to the end user, so we'll printing it to stderr.
logger.error(alrt.pretty_print())
return None
if result.ir.gen.py_bytecode is not None:
return marshal.loads(result.ir.gen.py_bytecode)
else:
Expand Down
Loading