From bb0d6df575397ca18a0bdd1134a8d71e47ced2ba Mon Sep 17 00:00:00 2001 From: GnP Date: Thu, 10 Oct 2024 10:45:20 -0300 Subject: [PATCH] Make RANDOM_KEY nonces thread-safe --- src/ethproto/aa_bundler.py | 10 ++++---- tests/test_aa_bundler.py | 47 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/src/ethproto/aa_bundler.py b/src/ethproto/aa_bundler.py index e92cbcf..5cfcb30 100644 --- a/src/ethproto/aa_bundler.py +++ b/src/ethproto/aa_bundler.py @@ -1,6 +1,7 @@ import random from collections import defaultdict from enum import Enum +from threading import local from warnings import warn from environs import Env @@ -54,7 +55,7 @@ ] NONCE_CACHE = defaultdict(lambda: 0) -RANDOM_NONCE_KEY = None +RANDOM_NONCE_KEY = local() def pack_two(a, b): @@ -140,10 +141,9 @@ def fetch_nonce(w3, account, entry_point, nonce_key): def get_random_nonce_key(): - global RANDOM_NONCE_KEY - if RANDOM_NONCE_KEY is None: - RANDOM_NONCE_KEY = random.randint(1, 2**192 - 1) - return RANDOM_NONCE_KEY + if getattr(RANDOM_NONCE_KEY, "key", None) is None: + RANDOM_NONCE_KEY.key = random.randint(1, 2**192 - 1) + return RANDOM_NONCE_KEY.key def get_nonce_and_key(w3, tx, nonce_mode, entry_point=AA_BUNDLER_ENTRYPOINT, fetch=False): diff --git a/tests/test_aa_bundler.py b/tests/test_aa_bundler.py index fc98632..247288c 100644 --- a/tests/test_aa_bundler.py +++ b/tests/test_aa_bundler.py @@ -1,3 +1,5 @@ +from queue import Queue +from threading import Event, Thread from unittest.mock import MagicMock, patch from hexbytes import HexBytes @@ -155,8 +157,8 @@ def test_get_nonce_random_key_mode(fetch_nonce_mock, randint_mock): fetch_nonce_mock.assert_not_called() randint_mock.assert_called_with(1, 2**192 - 1) randint_mock.reset_mock() - assert aa_bundler.RANDOM_NONCE_KEY == 444 - aa_bundler.RANDOM_NONCE_KEY = None # cleanup + assert aa_bundler.RANDOM_NONCE_KEY.key == 444 + aa_bundler.RANDOM_NONCE_KEY.key = None # cleanup @patch.object(aa_bundler.random, "randint") @@ -241,3 +243,44 @@ def make_request(method, params): get_base_fee_mock.assert_called_once_with(w3) assert aa_bundler.NONCE_CACHE[0] == 1 assert ret == {"userOpHash": "0xa950a17ca1ed83e974fb1aa227360a007cb65f566518af117ffdbb04d8d2d524"} + + +def test_random_key_nonces_are_thread_safe(): + queue = Queue() + event = Event() + + def worker(): + event.wait() # Get all threads running at the same time + nonce_key, nonce = aa_bundler.get_nonce_and_key( + FAIL_IF_USED, + {"from": TEST_SENDER}, + nonce_mode=aa_bundler.NonceMode.RANDOM_KEY, + ) + aa_bundler.consume_nonce(nonce_key, nonce) + queue.put( + aa_bundler.get_nonce_and_key( + FAIL_IF_USED, + {"from": TEST_SENDER}, + nonce_mode=aa_bundler.NonceMode.RANDOM_KEY, + ) + ) + + threads = [Thread(target=worker) for _ in range(15)] + for thread in threads: + thread.start() + + # Fire all threads at once + event.set() + for thread in threads: + thread.join() + + nonces = {} + + while not queue.empty(): + nonce_key, nonce = queue.get_nowait() + # Each thread got a different key + assert nonce_key not in nonces + nonces[nonce_key] = nonce + + # All nonces are the same + assert all(nonce == 1 for nonce in nonces.values())