diff --git a/tests/functional/codegen/modules/test_interface_imports.py b/tests/functional/codegen/modules/test_interface_imports.py index c0fae6496c..3f0f8cb010 100644 --- a/tests/functional/codegen/modules/test_interface_imports.py +++ b/tests/functional/codegen/modules/test_interface_imports.py @@ -58,3 +58,31 @@ def foo() -> bool: c = get_contract(main, input_bundle=input_bundle) assert c.foo() is True + + +def test_import_interface_flags(make_input_bundle, get_contract): + ifaces = """ +flag Foo: + BOO + MOO + POO + +interface IFoo: + def foo() -> Foo: nonpayable + """ + + contract = """ +import ifaces + +implements: ifaces + +@external +def foo() -> ifaces.Foo: + return ifaces.Foo.POO + """ + + input_bundle = make_input_bundle({"ifaces.vyi": ifaces}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 4 diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index d6cc50a2ea..dabeaf21b6 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -21,7 +21,7 @@ from vyper.semantics.types.base import TYPE_T, VyperType, is_type_t from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.primitives import AddressT -from vyper.semantics.types.user import EventT, StructT, _UserType +from vyper.semantics.types.user import EventT, FlagT, StructT, _UserType from vyper.utils import OrderedSet if TYPE_CHECKING: @@ -45,27 +45,29 @@ def __init__( functions: dict, events: dict, structs: dict, + flags: dict, ) -> None: validate_unique_method_ids(list(functions.values())) - members = functions | events | structs + members = functions | events | structs | flags # sanity check: by construction, there should be no duplicates. - assert len(members) == len(functions) + len(events) + len(structs) + assert len(members) == len(functions) + len(events) + len(structs) + len(flags) super().__init__(functions) - self._helper = VyperType(events | structs) + self._helper = VyperType(events | structs | flags) self._id = _id self._helper._id = _id self.functions = functions self.events = events self.structs = structs + self.flags = flags self.decl_node = decl_node def get_type_member(self, attr, node): - # get an event or struct from this interface + # get an event, struct or flag from this interface return TYPE_T(self._helper.get_member(attr, node)) @property @@ -159,12 +161,14 @@ def _from_lists( interface_name: str, decl_node: Optional[vy_ast.VyperNode], function_list: list[tuple[str, ContractFunctionT]], - event_list: list[tuple[str, EventT]], - struct_list: list[tuple[str, StructT]], + event_list: Optional[list[tuple[str, EventT]]] = None, + struct_list: Optional[list[tuple[str, StructT]]] = None, + flag_list: Optional[list[tuple[str, FlagT]]] = None, ) -> "InterfaceT": - functions = {} - events = {} - structs = {} + functions: dict[str, ContractFunctionT] = {} + events: dict[str, EventT] = {} + structs: dict[str, StructT] = {} + flags: dict[str, FlagT] = {} seen_items: dict = {} @@ -175,19 +179,20 @@ def _mark_seen(name, item): raise NamespaceCollision(msg, item.decl_node, prev_decl=prev_decl) seen_items[name] = item - for name, function in function_list: - _mark_seen(name, function) - functions[name] = function + def _process(dst_dict, items): + if items is None: + return - for name, event in event_list: - _mark_seen(name, event) - events[name] = event + for name, item in items: + _mark_seen(name, item) + dst_dict[name] = item - for name, struct in struct_list: - _mark_seen(name, struct) - structs[name] = struct + _process(functions, function_list) + _process(events, event_list) + _process(structs, struct_list) + _process(flags, flag_list) - return cls(interface_name, decl_node, functions, events, structs) + return cls(interface_name, decl_node, functions, events, structs, flags) @classmethod def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": @@ -214,8 +219,7 @@ def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": for item in [i for i in abi if i.get("type") == "event"]: events.append((item["name"], EventT.from_abi(item))) - structs: list = [] # no structs in json ABI (as of yet) - return cls._from_lists(name, None, functions, events, structs) + return cls._from_lists(name, None, functions, events) @classmethod def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": @@ -247,8 +251,9 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": # these are accessible via import, but they do not show up # in the ABI json structs = [(node.name, node._metadata["struct_type"]) for node in module_t.struct_defs] + flags = [(node.name, node._metadata["flag_type"]) for node in module_t.flag_defs] - return cls._from_lists(module_t._id, module_t.decl_node, funcs, events, structs) + return cls._from_lists(module_t._id, module_t.decl_node, funcs, events, structs, flags) @classmethod def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": @@ -265,11 +270,7 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": ) functions.append((func_ast.name, ContractFunctionT.from_InterfaceDef(func_ast))) - # no structs or events in InterfaceDefs - events: list = [] - structs: list = [] - - return cls._from_lists(node.name, node, functions, events, structs) + return cls._from_lists(node.name, node, functions) # Datatype to store all module information.