From d60d31f3375db5b1f867c209316d9c29aca5d8b3 Mon Sep 17 00:00:00 2001 From: trocher Date: Thu, 26 Sep 2024 02:52:57 +0200 Subject: [PATCH] fix[lang]: fix `==` and `!=` bytesM folding (#4254) `bytesM` literals are not case-sensitive (they represent the same value no matter if the literal is lower- or upper-case), but the folding operation was case sensitive. --------- Co-authored-by: Charles Cooper --- tests/unit/ast/nodes/test_fold_compare.py | 17 +++++++++++++++++ vyper/semantics/analysis/constant_folding.py | 7 +++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/unit/ast/nodes/test_fold_compare.py b/tests/unit/ast/nodes/test_fold_compare.py index aab8ac0b2d..fd9f65a7d3 100644 --- a/tests/unit/ast/nodes/test_fold_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -110,3 +110,20 @@ def test_compare_type_mismatch(op): old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): old_node.get_folded_value() + + +@pytest.mark.parametrize("op", ["==", "!="]) +def test_compare_eq_bytes(get_contract, op): + left, right = "0xA1AAB33F", "0xa1aab33f" + source = f""" +@external +def foo(a: bytes4, b: bytes4) -> bool: + return a {op} b + """ + contract = get_contract(source) + + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value + new_node = old_node.get_folded_value() + + assert contract.foo(left, right) == new_node.value diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index 6e4166dc52..98cab0f8cb 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -180,8 +180,11 @@ def visit_Compare(self, node): raise UnfoldableNode( f"Invalid literal types for {node.op.description} comparison", node ) - - value = node.op._op(left.value, right.value) + lvalue, rvalue = left.value, right.value + if isinstance(left, vy_ast.Hex): + # Hex values are str, convert to be case-unsensitive. + lvalue, rvalue = lvalue.lower(), rvalue.lower() + value = node.op._op(lvalue, rvalue) return vy_ast.NameConstant.from_node(node, value=value) def visit_List(self, node) -> vy_ast.ExprNode: