Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: support flags from imported interfaces #4253

Merged
merged 20 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions vyper/semantics/analysis/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.module import InterfaceT


def constant_fold(module_ast: vy_ast.Module):
Expand Down Expand Up @@ -105,12 +106,15 @@ def visit_Attribute(self, node) -> vy_ast.ExprNode:
# not super type-safe but we don't care. just catch AttributeErrors
# and move on
try:
module_t = namespace[value.id].module_t

ns_member = namespace[value.id]
module_t = ns_member if isinstance(ns_member, InterfaceT) else ns_member.module_t
for module_name in path:
module_t = module_t.members[module_name].module_t

varinfo = module_t.get_member(node.attr, node)
varinfo = (
module_t.get_type_member(node.attr, node)
if isinstance(module_t, InterfaceT)
else module_t.get_member(node.attr, node)
)

return varinfo.decl_node.value.get_folded_value()
except (VyperException, AttributeError, KeyError):
Expand Down
91 changes: 64 additions & 27 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
StructureException,
UnfoldableNode,
)
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.base import Modifiability, VarInfo
Fixed Show fixed Hide fixed
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
Expand All @@ -21,7 +21,7 @@
from vyper.semantics.types.base import TYPE_T, VyperType, is_type_t
from vyper.semantics.types.function import ContractFunctionT
from vyper.semantics.types.primitives import AddressT
from vyper.semantics.types.user import EventT, StructT, _UserType
from vyper.semantics.types.user import EventT, FlagT, StructT, _UserType
Fixed Show fixed Hide fixed

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.types.user
begins an import cycle.
from vyper.utils import OrderedSet, sha256sum

if TYPE_CHECKING:
Expand All @@ -45,28 +45,41 @@
functions: dict,
events: dict,
structs: dict,
flags: dict,
constants: dict,
) -> None:
validate_unique_method_ids(list(functions.values()))

members = functions | events | structs
public_constants = {
k: varinfo for k, varinfo in constants.items() if varinfo.decl_node.is_public
}

members = functions | events | structs | flags | constants

# sanity check: by construction, there should be no duplicates.
assert len(members) == len(functions) + len(events) + len(structs)
assert len(members) == len(functions) + len(events) + len(structs) + len(flags) + len(
constants
) - len(public_constants)

super().__init__(functions)

self._helper = VyperType(events | structs)
self._helper = VyperType(events | structs | flags | constants)
self._id = _id
self._helper._id = _id
self.functions = functions
self.events = events
self.structs = structs
self.flags = flags
self.constants = constants

self.decl_node = decl_node

def get_type_member(self, attr, node):
# get an event or struct from this interface
return TYPE_T(self._helper.get_member(attr, node))
# get an event, struct or constant from this interface
type_member = self._helper.get_member(attr, node)
if isinstance(type_member, (EventT, FlagT, StructT)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe

Suggested change
if isinstance(type_member, (EventT, FlagT, StructT)):
if isinstance(type_member, VyperType):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or UserType might be better

return TYPE_T(type_member)
return type_member

Check warning on line 82 in vyper/semantics/types/module.py

View check run for this annotation

Codecov / codecov/patch

vyper/semantics/types/module.py#L82

Added line #L82 was not covered by tests

@property
def getter_signature(self):
Expand Down Expand Up @@ -159,35 +172,57 @@
interface_name: str,
decl_node: Optional[vy_ast.VyperNode],
function_list: list[tuple[str, ContractFunctionT]],
event_list: list[tuple[str, EventT]],
struct_list: list[tuple[str, StructT]],
event_list: Optional[list[tuple[str, EventT]]] = None,
struct_list: Optional[list[tuple[str, StructT]]] = None,
flag_list: Optional[list[tuple[str, FlagT]]] = None,
constant_list: Optional[list[tuple[str, VarInfo]]] = None,
) -> "InterfaceT":
functions = {}
events = {}
structs = {}
flags = {}
constants = {}

seen_items: dict = {}

def _mark_seen(name, item):
if name in seen_items:
prev = seen_items[name]
if (
isinstance(prev, ContractFunctionT)
and isinstance(item, VarInfo)
and item.decl_node.is_public
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, why is this needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this, a public constant would result in a duplicate because it would be in constants and its public getter would be in functions.

return
msg = f"multiple functions or events named '{name}'!"
prev_decl = seen_items[name].decl_node
raise NamespaceCollision(msg, item.decl_node, prev_decl=prev_decl)
raise NamespaceCollision(msg, item.decl_node, prev_decl=prev.decl_node)
seen_items[name] = item

for name, function in function_list:
_mark_seen(name, function)
functions[name] = function

for name, event in event_list:
_mark_seen(name, event)
events[name] = event
if event_list:
for name, event in event_list:
_mark_seen(name, event)
events[name] = event

if struct_list:
for name, struct in struct_list:
_mark_seen(name, struct)
structs[name] = struct

if flag_list:
for name, flag in flag_list:
_mark_seen(name, flag)
flags[name] = flag

for name, struct in struct_list:
_mark_seen(name, struct)
structs[name] = struct
if constant_list:
for name, constant in constant_list:
_mark_seen(name, constant)
constants[name] = constant

return cls(interface_name, decl_node, functions, events, structs)
return cls(interface_name, decl_node, functions, events, structs, flags, constants)

@classmethod
def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT":
Expand All @@ -214,8 +249,7 @@
for item in [i for i in abi if i.get("type") == "event"]:
events.append((item["name"], EventT.from_abi(item)))

structs: list = [] # no structs in json ABI (as of yet)
return cls._from_lists(name, None, functions, events, structs)
return cls._from_lists(name, None, functions, events)

@classmethod
def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT":
Expand Down Expand Up @@ -247,8 +281,15 @@
# these are accessible via import, but they do not show up
# in the ABI json
structs = [(node.name, node._metadata["struct_type"]) for node in module_t.struct_defs]

return cls._from_lists(module_t._id, module_t.decl_node, funcs, events, structs)
flags = [(node.name, node._metadata["flag_type"]) for node in module_t.flag_defs]
constants = [
(node.target.id, node.target._metadata["varinfo"])
for node in module_t.variable_decls
if node.is_constant
]
return cls._from_lists(
module_t._id, module_t.decl_node, funcs, events, structs, flags, constants
)

@classmethod
def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT":
Expand All @@ -265,11 +306,7 @@
)
functions.append((func_ast.name, ContractFunctionT.from_InterfaceDef(func_ast)))

# no structs or events in InterfaceDefs
events: list = []
structs: list = []

return cls._from_lists(node.name, node, functions, events, structs)
return cls._from_lists(node.name, node, functions)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much cleaner now, thanks



# Datatype to store all module information.
Expand Down
Loading