Skip to content

Commit

Permalink
Merge branch 'master' into fix/unused_var_analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper authored Sep 26, 2024
2 parents 3b2ece2 + d60d31f commit c93d13c
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
17 changes: 17 additions & 0 deletions tests/unit/ast/nodes/test_fold_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,20 @@ def test_compare_type_mismatch(op):
old_node = vyper_ast.body[0].value
with pytest.raises(UnfoldableNode):
old_node.get_folded_value()


@pytest.mark.parametrize("op", ["==", "!="])
def test_compare_eq_bytes(get_contract, op):
left, right = "0xA1AAB33F", "0xa1aab33f"
source = f"""
@external
def foo(a: bytes4, b: bytes4) -> bool:
return a {op} b
"""
contract = get_contract(source)

vyper_ast = parse_and_fold(f"{left} {op} {right}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()

assert contract.foo(left, right) == new_node.value
7 changes: 5 additions & 2 deletions vyper/semantics/analysis/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,11 @@ def visit_Compare(self, node):
raise UnfoldableNode(
f"Invalid literal types for {node.op.description} comparison", node
)

value = node.op._op(left.value, right.value)
lvalue, rvalue = left.value, right.value
if isinstance(left, vy_ast.Hex):
# Hex values are str, convert to be case-unsensitive.
lvalue, rvalue = lvalue.lower(), rvalue.lower()
value = node.op._op(lvalue, rvalue)
return vy_ast.NameConstant.from_node(node, value=value)

def visit_List(self, node) -> vy_ast.ExprNode:
Expand Down
18 changes: 10 additions & 8 deletions vyper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ class OrderedSet(Generic[_T]):
"""

def __init__(self, iterable=None):
self._data = dict()
if iterable is not None:
self.update(iterable)
if iterable is None:
self._data = dict()
else:
self._data = dict.fromkeys(iterable)

def __repr__(self):
keys = ", ".join(repr(k) for k in self)
Expand Down Expand Up @@ -57,6 +58,7 @@ def pop(self):
def add(self, item: _T) -> None:
self._data[item] = None

# NOTE to refactor: duplicate of self.update()
def addmany(self, iterable):
for item in iterable:
self._data[item] = None
Expand Down Expand Up @@ -109,11 +111,11 @@ def intersection(cls, *sets):
if len(sets) == 0:
raise ValueError("undefined: intersection of no sets")

ret = sets[0].copy()
for e in sets[0]:
if any(e not in s for s in sets[1:]):
ret.remove(e)
return ret
tmp = sets[0]._data.keys()
for s in sets[1:]:
tmp &= s._data.keys()

return cls(tmp)


class StringEnum(enum.Enum):
Expand Down

0 comments on commit c93d13c

Please sign in to comment.