diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 07da3c0ace..c9d7248809 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -1,3 +1,4 @@ +import copy import json from vyper import compiler @@ -216,24 +217,27 @@ def foo(): input_bundle = make_input_bundle({"lib1.vy": lib1, "main.vy": main}) lib1_file = input_bundle.load_file("lib1.vy") - out = compiler.compile_from_file_input( + lib1_out = compiler.compile_from_file_input( lib1_file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"] ) - lib1_ast = out["annotated_ast_dict"]["ast"] + + lib1_ast = copy.deepcopy(lib1_out["annotated_ast_dict"]["ast"]) lib1_sha256sum = lib1_ast.pop("source_sha256sum") assert lib1_sha256sum == lib1_file.sha256sum to_strip = NODE_SRC_ATTRIBUTES + ("resolved_path", "variable_reads", "variable_writes") _strip_source_annotations(lib1_ast, to_strip=to_strip) main_file = input_bundle.load_file("main.vy") - out = compiler.compile_from_file_input( + main_out = compiler.compile_from_file_input( main_file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"] ) - main_ast = out["annotated_ast_dict"]["ast"] + main_ast = main_out["annotated_ast_dict"]["ast"] main_sha256sum = main_ast.pop("source_sha256sum") assert main_sha256sum == main_file.sha256sum _strip_source_annotations(main_ast, to_strip=to_strip) + assert main_out["annotated_ast_dict"]["imports"][0] == lib1_out["annotated_ast_dict"]["ast"] + # TODO: would be nice to refactor this into bunch of small test cases assert main_ast == { "ast_type": "Module", @@ -1776,3 +1780,49 @@ def qux2(): }, } ] + + +def test_annotated_ast_export_recursion(make_input_bundle): + sources = { + "main.vy": """ +import lib1 + +@external +def foo(): + lib1.foo() + """, + "lib1.vy": """ +import lib2 + +def foo(): + lib2.foo() + """, + "lib2.vy": """ +def foo(): + pass + """, + } + + input_bundle = make_input_bundle(sources) + + def compile_and_get_ast(file_name): + file = input_bundle.load_file(file_name) + output = compiler.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=["annotated_ast_dict"] + ) + return output["annotated_ast_dict"] + + lib1_ast = compile_and_get_ast("lib1.vy")["ast"] + lib2_ast = compile_and_get_ast("lib2.vy")["ast"] + main_out = compile_and_get_ast("main.vy") + + lib1_import_ast = main_out["imports"][1] + lib2_import_ast = main_out["imports"][0] + + # path is once virtual, once libX.vy + # type contains name which is based on path + keys = [s for s in lib1_import_ast.keys() if s not in {"path", "type"}] + + for key in keys: + assert lib1_ast[key] == lib1_import_ast[key] + assert lib2_ast[key] == lib2_import_ast[key] diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 991edeca6e..d3c721dbfb 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -949,6 +949,12 @@ def validate(self): class Ellipsis(Constant): __slots__ = () + def to_dict(self): + ast_dict = super().to_dict() + # python ast ellipsis() is not json serializable; use a string + ast_dict["value"] = self.node_source_code + return ast_dict + class Dict(ExprNode): __slots__ = ("keys", "values") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 58c7d0b2e4..3e3a9a62b2 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -70,6 +70,7 @@ class TopLevel(VyperNode): class Module(TopLevel): path: str = ... resolved_path: str = ... + source_id: int = ... def namespace(self) -> Any: ... # context manager class FunctionDef(TopLevel): diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 09d299b90d..d04b677b3e 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -3,7 +3,8 @@ from collections import deque from pathlib import PurePath -from vyper.ast import ast_to_dict +import vyper.ast as vy_ast +from vyper.ast.utils import ast_to_dict from vyper.codegen.ir_node import IRnode from vyper.compiler.output_bundle import SolcJSONWriter, VyperArchiveWriter from vyper.compiler.phases import CompilerData @@ -11,7 +12,9 @@ from vyper.evm import opcodes from vyper.exceptions import VyperException from vyper.ir import compile_ir +from vyper.semantics.analysis.base import ModuleInfo from vyper.semantics.types.function import FunctionVisibility, StateMutability +from vyper.semantics.types.module import InterfaceT from vyper.typing import StorageLayout from vyper.utils import vyper_warn from vyper.warnings import ContractSizeLimitWarning @@ -26,9 +29,32 @@ def build_ast_dict(compiler_data: CompilerData) -> dict: def build_annotated_ast_dict(compiler_data: CompilerData) -> dict: + module_t = compiler_data.annotated_vyper_module._metadata["type"] + # get all reachable imports including recursion + imported_module_infos = module_t.reachable_imports + unique_modules: dict[str, vy_ast.Module] = {} + for info in imported_module_infos: + if isinstance(info.typ, InterfaceT): + ast = info.typ.decl_node + if ast is None: # json abi + continue + else: + assert isinstance(info.typ, ModuleInfo) + ast = info.typ.module_t._module + + assert isinstance(ast, vy_ast.Module) # help mypy + # use resolved_path for uniqueness, since Module objects can actually + # come from multiple InputBundles (particularly builtin interfaces), + # so source_id is not guaranteed to be unique. + if ast.resolved_path in unique_modules: + # sanity check -- objects must be identical + assert unique_modules[ast.resolved_path] is ast + unique_modules[ast.resolved_path] = ast + annotated_ast_dict = { "contract_name": str(compiler_data.contract_path), "ast": ast_to_dict(compiler_data.annotated_vyper_module), + "imports": [ast_to_dict(ast) for ast in unique_modules.values()], } return annotated_ast_dict