Skip to content

Commit 4c66c8c

Browse files
fix[lang]: recursion in uses analysis for nonreentrant functions (vyperlang#3971)
this commit fixes `uses` analysis for nonreentrant functions, which are called recursively. a partial fix for this was applied in cb94068, but it missed the case where a nonreentrant function is deep in the call tree.
1 parent fb55f4c commit 4c66c8c

File tree

2 files changed

+62
-24
lines changed

2 files changed

+62
-24
lines changed

tests/functional/syntax/modules/test_initializers.py

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,52 +1300,86 @@ def foo():
13001300
assert e.value._hint == "try importing lib1 first"
13011301

13021302

1303-
def test_nonreentrant_exports(make_input_bundle):
1303+
@pytest.fixture
1304+
def nonreentrant_library_bundle(make_input_bundle):
1305+
# test simple case
13041306
lib1 = """
13051307
# lib1.vy
1306-
@external
1308+
@internal
13071309
@nonreentrant
13081310
def bar():
13091311
pass
1312+
1313+
# lib1.vy
1314+
@external
1315+
@nonreentrant
1316+
def ext_bar():
1317+
pass
13101318
"""
1311-
main = """
1319+
# test case with recursion
1320+
lib2 = """
1321+
@internal
1322+
def bar():
1323+
self.baz()
1324+
1325+
@external
1326+
def ext_bar():
1327+
self.baz()
1328+
1329+
@nonreentrant
1330+
@internal
1331+
def baz():
1332+
return
1333+
"""
1334+
# test case with nested recursion
1335+
lib3 = """
13121336
import lib1
1337+
uses: lib1
13131338
1314-
exports: lib1.bar # line 4
1339+
@internal
1340+
def bar():
1341+
lib1.bar()
1342+
1343+
@external
1344+
def ext_bar():
1345+
lib1.bar()
1346+
"""
1347+
1348+
return make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3})
1349+
1350+
1351+
@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3"))
1352+
def test_nonreentrant_exports(nonreentrant_library_bundle, lib):
1353+
main = f"""
1354+
import {lib}
1355+
1356+
exports: {lib}.ext_bar # line 4
13151357
13161358
@external
13171359
def foo():
13181360
pass
13191361
"""
1320-
input_bundle = make_input_bundle({"lib1.vy": lib1})
13211362
with pytest.raises(ImmutableViolation) as e:
1322-
compile_code(main, input_bundle=input_bundle)
1323-
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE
1324-
hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
1363+
compile_code(main, input_bundle=nonreentrant_library_bundle)
1364+
assert e.value._message == f"Cannot access `{lib}` state!" + NONREENTRANT_NOTE
1365+
hint = f"add `uses: {lib}` or `initializes: {lib}` as a top-level statement to your contract"
13251366
assert e.value._hint == hint
13261367
assert e.value.annotations[0].lineno == 4
13271368

13281369

1329-
def test_internal_nonreentrant_import(make_input_bundle):
1330-
lib1 = """
1331-
# lib1.vy
1332-
@internal
1333-
@nonreentrant
1334-
def bar():
1335-
pass
1336-
"""
1337-
main = """
1338-
import lib1
1370+
@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3"))
1371+
def test_internal_nonreentrant_import(nonreentrant_library_bundle, lib):
1372+
main = f"""
1373+
import {lib}
13391374
13401375
@external
13411376
def foo():
1342-
lib1.bar() # line 6
1377+
{lib}.bar() # line 6
13431378
"""
1344-
input_bundle = make_input_bundle({"lib1.vy": lib1})
13451379
with pytest.raises(ImmutableViolation) as e:
1346-
compile_code(main, input_bundle=input_bundle)
1347-
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE
1380+
compile_code(main, input_bundle=nonreentrant_library_bundle)
1381+
assert e.value._message == f"Cannot access `{lib}` state!" + NONREENTRANT_NOTE
13481382

1349-
hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
1383+
hint = f"add `uses: {lib}` or `initializes: {lib}` as a top-level statement to your contract"
13501384
assert e.value._hint == hint
13511385
assert e.value.annotations[0].lineno == 6

vyper/semantics/types/function.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,11 @@ def get_variable_accesses(self):
165165
return self._variable_reads | self._variable_writes
166166

167167
def uses_state(self):
168-
return self.nonreentrant or uses_state(self.get_variable_accesses())
168+
return (
169+
self.nonreentrant
170+
or uses_state(self.get_variable_accesses())
171+
or any(f.nonreentrant for f in self.reachable_internal_functions)
172+
)
169173

170174
def get_used_modules(self):
171175
# _used_modules is populated during analysis

0 commit comments

Comments
 (0)