diff --git a/tests/functional/codegen/storage_variables/test_getters.py b/tests/functional/codegen/storage_variables/test_getters.py index 5169bd300d..e581d71223 100644 --- a/tests/functional/codegen/storage_variables/test_getters.py +++ b/tests/functional/codegen/storage_variables/test_getters.py @@ -1,3 +1,9 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import OverflowException, TypeMismatch + + def test_state_accessor(get_contract): state_accessor = """ y: HashMap[int128, int128] @@ -98,3 +104,32 @@ def __init__(): if item["type"] == "constructor": continue assert item["stateMutability"] == "view" + + +@pytest.mark.parametrize( + "typ,index,expected_error", + [ + ("uint256", "-1", TypeMismatch), + ("uint256", "0-1", TypeMismatch), + ("uint256", "0-1+1", TypeMismatch), + ("uint256", "2**256", OverflowException), + ("uint256", "2**256 // 2", OverflowException), + ("uint256", "2 * 2**255", OverflowException), + ("int256", "-2**255", TypeMismatch), + ("int256", "-2**256", OverflowException), + ("int256", "2**255", TypeMismatch), + ("int256", "2**256 - 5", OverflowException), + ("int256", "2 * 2**254", TypeMismatch), + ("int8", "*".join(["2"] * 7), TypeMismatch), + ], +) +def test_hashmap_index_checks(typ, index, expected_error): + code = f""" +m: HashMap[{typ}, uint256] + +@external +def foo(): + self.m[{index}] = 2 + """ + with pytest.raises(expected_error): + compile_code(code) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 5b20ef773a..26c6a4ef9f 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -914,14 +914,14 @@ def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: else: base_type = get_exact_type_from_node(node.value) - # get the correct type for the index, it might - # not be exactly base_type.key_type - # note: index_type is validated in types_from_Subscript - index_types = get_possible_types_from_node(node.slice) - index_type = index_types.pop() + if isinstance(base_type, HashMapT): + index_type = base_type.key_type + else: + # Arrays allow most int types as index: Take the least specific + index_type = get_possible_types_from_node(node.slice).pop() - self.visit(node.slice, index_type) self.visit(node.value, base_type) + self.visit(node.slice, index_type) def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None: if isinstance(typ, TYPE_T):