diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 0a62312..93cc140 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -2,9 +2,9 @@ name: Tests on: push: - branches: ["main"] + branches: ["main", "v1.1.x"] pull_request: - branches: ["main"] + branches: ["main", "v1.1.x"] jobs: build: diff --git a/src/ethproto/aa_bundler.py b/src/ethproto/aa_bundler.py index fc3687c..e532d71 100644 --- a/src/ethproto/aa_bundler.py +++ b/src/ethproto/aa_bundler.py @@ -1,7 +1,9 @@ import random -from warnings import warn +from collections import defaultdict from enum import Enum -import requests +from threading import local +from warnings import warn + from environs import Env from eth_abi import encode from eth_account import Account @@ -9,8 +11,8 @@ from hexbytes import HexBytes from web3 import Web3 from web3.constants import ADDRESS_ZERO -from .contracts import RevertError +from .contracts import RevertError env = Env() @@ -21,6 +23,7 @@ AA_BUNDLER_GAS_LIMIT_FACTOR = env.float("AA_BUNDLER_GAS_LIMIT_FACTOR", 1) AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR = env.float("AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR", 1) AA_BUNDLER_BASE_GAS_PRICE_FACTOR = env.float("AA_BUNDLER_BASE_GAS_PRICE_FACTOR", 1) +AA_BUNDLER_VERIFICATION_GAS_FACTOR = env.float("AA_BUNDLER_VERIFICATION_GAS_FACTOR", 1) NonceMode = Enum( "NonceMode", @@ -50,8 +53,8 @@ } ] -NONCE_CACHE = {} -RANDOM_NONCE_KEY = None +NONCE_CACHE = defaultdict(lambda: 0) +RANDOM_NONCE_KEY = local() def pack_two(a, b): @@ -68,6 +71,10 @@ def _to_uint(x): raise RuntimeError(f"Invalid int value {x}") +def apply_factor(x, factor): + return int(_to_uint(x) * factor) + + def pack_user_operation(user_operation): # https://github.com/eth-infinitism/account-abstraction/blob/develop/contracts/interfaces/PackedUserOperation.sol return { @@ -133,10 +140,9 @@ def fetch_nonce(w3, account, entry_point, nonce_key): def get_random_nonce_key(): - global RANDOM_NONCE_KEY - if RANDOM_NONCE_KEY is None: - RANDOM_NONCE_KEY = random.randint(1, 2**192 - 1) - return RANDOM_NONCE_KEY + if getattr(RANDOM_NONCE_KEY, "key", None) is None: + RANDOM_NONCE_KEY.key = random.randint(1, 2**192 - 1) + return RANDOM_NONCE_KEY.key def get_nonce_and_key(w3, tx, nonce_mode, entry_point=AA_BUNDLER_ENTRYPOINT, fetch=False): @@ -152,20 +158,25 @@ def get_nonce_and_key(w3, tx, nonce_mode, entry_point=AA_BUNDLER_ENTRYPOINT, fet if nonce is None: if fetch or nonce_mode == NonceMode.FIXED_KEY_FETCH_ALWAYS: nonce = fetch_nonce(w3, get_sender(tx), entry_point, nonce_key) - elif nonce_key not in NONCE_CACHE: - nonce = 0 else: nonce = NONCE_CACHE[nonce_key] return nonce_key, nonce -def handle_response_error(resp, w3, tx, retry_nonce): +def consume_nonce(nonce_key, nonce): + NONCE_CACHE[nonce_key] = max(NONCE_CACHE[nonce_key], nonce + 1) + + +def check_nonce_error(resp, retry_nonce): + """Returns the next nonce if resp contains a nonce error and retries weren't exhausted + Raises RevertError otherwise + """ if "AA25" in resp["error"]["message"] and AA_BUNDLER_MAX_GETNONCE_RETRIES > 0: # Retry fetching the nonce if retry_nonce == AA_BUNDLER_MAX_GETNONCE_RETRIES: raise RevertError(resp["error"]["message"]) warn(f'{resp["error"]["message"]} error, I will retry fetching the nonce') - return send_transaction(w3, tx, retry_nonce=(retry_nonce or 0) + 1) + return (retry_nonce or 0) + 1 else: raise RevertError(resp["error"]["message"]) @@ -184,7 +195,7 @@ def get_sender(tx): return tx["from"] -def send_transaction(w3, tx, retry_nonce=None): +def build_user_operation(w3, tx, retry_nonce=None): nonce_key, nonce = get_nonce_and_key( w3, tx, AA_BUNDLER_NONCE_MODE, entry_point=AA_BUNDLER_ENTRYPOINT, fetch=retry_nonce is not None ) @@ -209,42 +220,63 @@ def send_transaction(w3, tx, retry_nonce=None): "eth_estimateUserOperationGas", [user_operation, AA_BUNDLER_ENTRYPOINT] ) if "error" in resp: - return handle_response_error(resp, w3, tx, retry_nonce) + next_nonce = check_nonce_error(resp, retry_nonce) + return build_user_operation(w3, tx, retry_nonce=next_nonce) user_operation.update(resp["result"]) resp = w3.provider.make_request("rundler_maxPriorityFeePerGas", []) if "error" in resp: raise RevertError(resp["error"]["message"]) - max_priority_fee_per_gas = int(_to_uint(resp["result"]) * AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR) - user_operation["maxPriorityFeePerGas"] = hex(max_priority_fee_per_gas) - user_operation["maxFeePerGas"] = hex(max_priority_fee_per_gas + get_base_fee(w3)) - user_operation["callGasLimit"] = hex( - int(_to_uint(user_operation["callGasLimit"]) * AA_BUNDLER_GAS_LIMIT_FACTOR) - ) - elif AA_BUNDLER_PROVIDER == "gelato": - user_operation.update( - { - "preVerificationGas": "0x00", - "callGasLimit": "0x00", - "verificationGasLimit": "0x00", - "maxFeePerGas": "0x00", - "maxPriorityFeePerGas": "0x00", - } + user_operation["maxPriorityFeePerGas"] = resp["result"] + user_operation["maxFeePerGas"] = hex(int(resp["result"], 16) + get_base_fee(w3)) + + elif AA_BUNDLER_PROVIDER == "generic": + resp = w3.provider.make_request( + "eth_estimateUserOperationGas", [user_operation, AA_BUNDLER_ENTRYPOINT] ) - user_operation["signature"] = sign_user_operation( - AA_BUNDLER_EXECUTOR_PK, user_operation, tx["chainId"], AA_BUNDLER_ENTRYPOINT + if "error" in resp: + next_nonce = check_nonce_error(resp, retry_nonce) + return build_user_operation(w3, tx, retry_nonce=next_nonce) + + user_operation.update(resp["result"]) + + else: + warn(f"Unknown AA_BUNDLER_PROVIDER: {AA_BUNDLER_PROVIDER}") + + # Apply increase factors + user_operation["verificationGasLimit"] = hex( + apply_factor(user_operation["verificationGasLimit"], AA_BUNDLER_VERIFICATION_GAS_FACTOR) ) + if "maxPriorityFeePerGas" in user_operation: + user_operation["maxPriorityFeePerGas"] = hex( + apply_factor(user_operation["maxPriorityFeePerGas"], AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR) + ) + if "callGasLimit" in user_operation: + user_operation["callGasLimit"] = hex( + apply_factor(user_operation["callGasLimit"], AA_BUNDLER_GAS_LIMIT_FACTOR) + ) + # Remove paymaster related fields user_operation.pop("paymaster", None) user_operation.pop("paymasterData", None) user_operation.pop("paymasterVerificationGasLimit", None) user_operation.pop("paymasterPostOpGasLimit", None) + # Consume the nonce, even if the userop may fail later + consume_nonce(nonce_key, nonce) + + return user_operation + + +def send_transaction(w3, tx, retry_nonce=None): + user_operation = build_user_operation(w3, tx, retry_nonce) + user_operation["signature"] = sign_user_operation( + AA_BUNDLER_EXECUTOR_PK, user_operation, tx["chainId"], AA_BUNDLER_ENTRYPOINT + ) resp = w3.provider.make_request("eth_sendUserOperation", [user_operation, AA_BUNDLER_ENTRYPOINT]) if "error" in resp: - return handle_response_error(resp, w3, tx, retry_nonce) + next_nonce = check_nonce_error(resp, retry_nonce) + return send_transaction(w3, tx, retry_nonce=next_nonce) - # Store nonce in the cache, so next time uses a new nonce - NONCE_CACHE[nonce_key] = nonce + 1 return {"userOpHash": resp["result"]} diff --git a/src/ethproto/w3wrappers.py b/src/ethproto/w3wrappers.py index f42be56..06b7ac7 100644 --- a/src/ethproto/w3wrappers.py +++ b/src/ethproto/w3wrappers.py @@ -444,7 +444,7 @@ def deploy(self, eth_contract, init_params, from_, **kwargs): def get_events(self, eth_wrapper, event_name, filter_kwargs={}): """Returns a list of events given a filter, like this: - >>> provider.get_events(currencywrapper, "Transfer", dict(fromBlock=0)) + >>> provider.get_events(currencywrapper, "Transfer", dict(from_block=0)) [AttributeDict({ 'args': AttributeDict( {'from': '0x0000000000000000000000000000000000000000', @@ -463,8 +463,8 @@ def get_events(self, eth_wrapper, event_name, filter_kwargs={}): """ contract = eth_wrapper.contract event = getattr(contract.events, event_name) - if "fromBlock" not in filter_kwargs: - filter_kwargs["fromBlock"] = self.get_first_block(eth_wrapper) + if "from_block" not in filter_kwargs: + filter_kwargs["from_block"] = self.get_first_block(eth_wrapper) event_filter = event.create_filter(**filter_kwargs) return event_filter.get_all_entries() @@ -490,7 +490,7 @@ def init_eth_wrapper(self, eth_wrapper, owner, init_params, kwargs): constructor_params, init_params = init_params real_contract = self.construct(eth_contract, constructor_params, {"from": eth_wrapper.owner}) ERC1967Proxy = self.get_contract_factory("ERC1967Proxy") - init_data = eth_contract.encodeABI(fn_name="initialize", args=init_params) + init_data = eth_contract.encode_abi(abi_element_identifier="initialize", args=init_params) proxy_contract = self.construct( ERC1967Proxy, (real_contract.address, init_data), diff --git a/tests/hardhat-project/contracts/EventLauncher.sol b/tests/hardhat-project/contracts/EventLauncher.sol new file mode 100644 index 0000000..70fdcf2 --- /dev/null +++ b/tests/hardhat-project/contracts/EventLauncher.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.9; + +contract EventLauncher { + event Event1(uint256 value); + + event Event2(uint256 value); + + function launchEvent1(uint256 value) public { + emit Event1(value); + } + + function launchEvent2(uint256 value) public { + emit Event2(value); + } +} \ No newline at end of file diff --git a/tests/test_aa_bundler.py b/tests/test_aa_bundler.py index c065b82..07022cc 100644 --- a/tests/test_aa_bundler.py +++ b/tests/test_aa_bundler.py @@ -1,8 +1,11 @@ -import os +from queue import Queue +from threading import Event, Thread +from unittest.mock import MagicMock, patch + from hexbytes import HexBytes -from ethproto import aa_bundler from web3.constants import HASH_ZERO -from unittest.mock import MagicMock, patch + +from ethproto import aa_bundler def test_pack_two(): @@ -66,9 +69,7 @@ def test_hash_packed_user_operation(): def test_sign_user_operation(): signature = aa_bundler.sign_user_operation(TEST_PRIVATE_KEY, user_operation, CHAIN_ID, ENTRYPOINT) - assert ( - signature - == "0xb9b872bfe4e90f4628e8ec24879a5b01045f91da8457f3ce2b417d2e5774b508261ec1147a820e75a141cb61b884a78d7e88996ceddafb9a7016cfe7a48a1f4f1b" # noqa + assert (signature == "0xb9b872bfe4e90f4628e8ec24879a5b01045f91da8457f3ce2b417d2e5774b508261ec1147a820e75a141cb61b884a78d7e88996ceddafb9a7016cfe7a48a1f4f1b" # noqa ) @@ -76,9 +77,7 @@ def test_sign_user_operation_gas_diff(): user_operation_2 = dict(user_operation) user_operation_2["maxPriorityFeePerGas"] -= 1 signature = aa_bundler.sign_user_operation(TEST_PRIVATE_KEY, user_operation_2, CHAIN_ID, ENTRYPOINT) - assert ( - signature - == "0x8162479d2dbd18d7fe93a2f51e283021d6e4eae4f57d20cdd553042723a0b0ea690ab3903d45126b0047da08ab53dfdf86656e4f258ac4936ba96a759ccb77f61b" # noqa + assert (signature == "0x8162479d2dbd18d7fe93a2f51e283021d6e4eae4f57d20cdd553042723a0b0ea690ab3903d45126b0047da08ab53dfdf86656e4f258ac4936ba96a759ccb77f61b" # noqa ) @@ -156,8 +155,8 @@ def test_get_nonce_random_key_mode(fetch_nonce_mock, randint_mock): fetch_nonce_mock.assert_not_called() randint_mock.assert_called_with(1, 2**192 - 1) randint_mock.reset_mock() - assert aa_bundler.RANDOM_NONCE_KEY == 444 - aa_bundler.RANDOM_NONCE_KEY = None # cleanup + assert aa_bundler.RANDOM_NONCE_KEY.key == 444 + aa_bundler.RANDOM_NONCE_KEY.key = None # cleanup @patch.object(aa_bundler.random, "randint") @@ -242,3 +241,44 @@ def make_request(method, params): get_base_fee_mock.assert_called_once_with(w3) assert aa_bundler.NONCE_CACHE[0] == 1 assert ret == {"userOpHash": "0xa950a17ca1ed83e974fb1aa227360a007cb65f566518af117ffdbb04d8d2d524"} + + +def test_random_key_nonces_are_thread_safe(): + queue = Queue() + event = Event() + + def worker(): + event.wait() # Get all threads running at the same time + nonce_key, nonce = aa_bundler.get_nonce_and_key( + FAIL_IF_USED, + {"from": TEST_SENDER}, + nonce_mode=aa_bundler.NonceMode.RANDOM_KEY, + ) + aa_bundler.consume_nonce(nonce_key, nonce) + queue.put( + aa_bundler.get_nonce_and_key( + FAIL_IF_USED, + {"from": TEST_SENDER}, + nonce_mode=aa_bundler.NonceMode.RANDOM_KEY, + ) + ) + + threads = [Thread(target=worker) for _ in range(15)] + for thread in threads: + thread.start() + + # Fire all threads at once + event.set() + for thread in threads: + thread.join() + + nonces = {} + + while not queue.empty(): + nonce_key, nonce = queue.get_nowait() + # Each thread got a different key + assert nonce_key not in nonces + nonces[nonce_key] = nonce + + # All nonces are the same + assert all(nonce == 1 for nonce in nonces.values()) diff --git a/tests/test_w3.py b/tests/test_w3.py index a19d936..a217973 100644 --- a/tests/test_w3.py +++ b/tests/test_w3.py @@ -77,3 +77,32 @@ def test_wrapper_build_from_def(): assert counter.value() == 0 counter.increase() assert counter.value() == 1 + + +def test_get_events(): + provider = wrappers.get_provider("w3") + contract_def = provider.get_contract_def("EventLauncher") + wrapper = wrappers.ETHWrapper.build_from_def(contract_def) + + launcher = wrapper() + + launcher.launchEvent1(1) + + cutoff_block = provider.w3.eth.get_block("latest") + launcher.launchEvent2(2) + launcher.launchEvent1(3) + + all_event1 = provider.get_events(launcher, "Event1", dict(from_block=0)) + assert len(all_event1) == 2 + + first_event1_only = provider.get_events(launcher, "Event1", dict(to_block=cutoff_block.number)) + assert len(first_event1_only) == 1 + assert first_event1_only[0] == all_event1[0] + + last_event1_only = provider.get_events(launcher, "Event1", dict(from_block=cutoff_block.number + 1)) + assert len(last_event1_only) == 1 + assert last_event1_only[0] == all_event1[-1] + + event2 = provider.get_events(launcher, "Event2") + assert len(event2) == 1 + assert event2[0].args.value == 2