Skip to content

Commit

Permalink
Merge pull request #12 from gnarvaja/port-aa-bundler-fixes-to-1.1
Browse files Browse the repository at this point in the history
Port 1.2.x aa-bundler fixes to 1.1.x branch
  • Loading branch information
gnarvaja authored Nov 5, 2024
2 parents 3cb7146 + 1091672 commit fb7529b
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 52 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: Tests

on:
push:
branches: ["main"]
branches: ["main", "v1.1.x"]
pull_request:
branches: ["main"]
branches: ["main", "v1.1.x"]

jobs:
build:
Expand Down
102 changes: 67 additions & 35 deletions src/ethproto/aa_bundler.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import random
from warnings import warn
from collections import defaultdict
from enum import Enum
import requests
from threading import local
from warnings import warn

from environs import Env
from eth_abi import encode
from eth_account import Account
from eth_account.messages import encode_defunct
from hexbytes import HexBytes
from web3 import Web3
from web3.constants import ADDRESS_ZERO
from .contracts import RevertError

from .contracts import RevertError

env = Env()

Expand All @@ -21,6 +23,7 @@
AA_BUNDLER_GAS_LIMIT_FACTOR = env.float("AA_BUNDLER_GAS_LIMIT_FACTOR", 1)
AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR = env.float("AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR", 1)
AA_BUNDLER_BASE_GAS_PRICE_FACTOR = env.float("AA_BUNDLER_BASE_GAS_PRICE_FACTOR", 1)
AA_BUNDLER_VERIFICATION_GAS_FACTOR = env.float("AA_BUNDLER_VERIFICATION_GAS_FACTOR", 1)

NonceMode = Enum(
"NonceMode",
Expand Down Expand Up @@ -50,8 +53,8 @@
}
]

NONCE_CACHE = {}
RANDOM_NONCE_KEY = None
NONCE_CACHE = defaultdict(lambda: 0)
RANDOM_NONCE_KEY = local()


def pack_two(a, b):
Expand All @@ -68,6 +71,10 @@ def _to_uint(x):
raise RuntimeError(f"Invalid int value {x}")


def apply_factor(x, factor):
return int(_to_uint(x) * factor)


def pack_user_operation(user_operation):
# https://github.com/eth-infinitism/account-abstraction/blob/develop/contracts/interfaces/PackedUserOperation.sol
return {
Expand Down Expand Up @@ -133,10 +140,9 @@ def fetch_nonce(w3, account, entry_point, nonce_key):


def get_random_nonce_key():
global RANDOM_NONCE_KEY
if RANDOM_NONCE_KEY is None:
RANDOM_NONCE_KEY = random.randint(1, 2**192 - 1)
return RANDOM_NONCE_KEY
if getattr(RANDOM_NONCE_KEY, "key", None) is None:
RANDOM_NONCE_KEY.key = random.randint(1, 2**192 - 1)
return RANDOM_NONCE_KEY.key


def get_nonce_and_key(w3, tx, nonce_mode, entry_point=AA_BUNDLER_ENTRYPOINT, fetch=False):
Expand All @@ -152,20 +158,25 @@ def get_nonce_and_key(w3, tx, nonce_mode, entry_point=AA_BUNDLER_ENTRYPOINT, fet
if nonce is None:
if fetch or nonce_mode == NonceMode.FIXED_KEY_FETCH_ALWAYS:
nonce = fetch_nonce(w3, get_sender(tx), entry_point, nonce_key)
elif nonce_key not in NONCE_CACHE:
nonce = 0
else:
nonce = NONCE_CACHE[nonce_key]
return nonce_key, nonce


def handle_response_error(resp, w3, tx, retry_nonce):
def consume_nonce(nonce_key, nonce):
NONCE_CACHE[nonce_key] = max(NONCE_CACHE[nonce_key], nonce + 1)


def check_nonce_error(resp, retry_nonce):
"""Returns the next nonce if resp contains a nonce error and retries weren't exhausted
Raises RevertError otherwise
"""
if "AA25" in resp["error"]["message"] and AA_BUNDLER_MAX_GETNONCE_RETRIES > 0:
# Retry fetching the nonce
if retry_nonce == AA_BUNDLER_MAX_GETNONCE_RETRIES:
raise RevertError(resp["error"]["message"])
warn(f'{resp["error"]["message"]} error, I will retry fetching the nonce')
return send_transaction(w3, tx, retry_nonce=(retry_nonce or 0) + 1)
return (retry_nonce or 0) + 1
else:
raise RevertError(resp["error"]["message"])

Expand All @@ -184,7 +195,7 @@ def get_sender(tx):
return tx["from"]


def send_transaction(w3, tx, retry_nonce=None):
def build_user_operation(w3, tx, retry_nonce=None):
nonce_key, nonce = get_nonce_and_key(
w3, tx, AA_BUNDLER_NONCE_MODE, entry_point=AA_BUNDLER_ENTRYPOINT, fetch=retry_nonce is not None
)
Expand All @@ -209,42 +220,63 @@ def send_transaction(w3, tx, retry_nonce=None):
"eth_estimateUserOperationGas", [user_operation, AA_BUNDLER_ENTRYPOINT]
)
if "error" in resp:
return handle_response_error(resp, w3, tx, retry_nonce)
next_nonce = check_nonce_error(resp, retry_nonce)
return build_user_operation(w3, tx, retry_nonce=next_nonce)

user_operation.update(resp["result"])

resp = w3.provider.make_request("rundler_maxPriorityFeePerGas", [])
if "error" in resp:
raise RevertError(resp["error"]["message"])
max_priority_fee_per_gas = int(_to_uint(resp["result"]) * AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR)
user_operation["maxPriorityFeePerGas"] = hex(max_priority_fee_per_gas)
user_operation["maxFeePerGas"] = hex(max_priority_fee_per_gas + get_base_fee(w3))
user_operation["callGasLimit"] = hex(
int(_to_uint(user_operation["callGasLimit"]) * AA_BUNDLER_GAS_LIMIT_FACTOR)
)
elif AA_BUNDLER_PROVIDER == "gelato":
user_operation.update(
{
"preVerificationGas": "0x00",
"callGasLimit": "0x00",
"verificationGasLimit": "0x00",
"maxFeePerGas": "0x00",
"maxPriorityFeePerGas": "0x00",
}
user_operation["maxPriorityFeePerGas"] = resp["result"]
user_operation["maxFeePerGas"] = hex(int(resp["result"], 16) + get_base_fee(w3))

elif AA_BUNDLER_PROVIDER == "generic":
resp = w3.provider.make_request(
"eth_estimateUserOperationGas", [user_operation, AA_BUNDLER_ENTRYPOINT]
)
user_operation["signature"] = sign_user_operation(
AA_BUNDLER_EXECUTOR_PK, user_operation, tx["chainId"], AA_BUNDLER_ENTRYPOINT
if "error" in resp:
next_nonce = check_nonce_error(resp, retry_nonce)
return build_user_operation(w3, tx, retry_nonce=next_nonce)

user_operation.update(resp["result"])

else:
warn(f"Unknown AA_BUNDLER_PROVIDER: {AA_BUNDLER_PROVIDER}")

# Apply increase factors
user_operation["verificationGasLimit"] = hex(
apply_factor(user_operation["verificationGasLimit"], AA_BUNDLER_VERIFICATION_GAS_FACTOR)
)
if "maxPriorityFeePerGas" in user_operation:
user_operation["maxPriorityFeePerGas"] = hex(
apply_factor(user_operation["maxPriorityFeePerGas"], AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR)
)
if "callGasLimit" in user_operation:
user_operation["callGasLimit"] = hex(
apply_factor(user_operation["callGasLimit"], AA_BUNDLER_GAS_LIMIT_FACTOR)
)

# Remove paymaster related fields
user_operation.pop("paymaster", None)
user_operation.pop("paymasterData", None)
user_operation.pop("paymasterVerificationGasLimit", None)
user_operation.pop("paymasterPostOpGasLimit", None)

# Consume the nonce, even if the userop may fail later
consume_nonce(nonce_key, nonce)

return user_operation


def send_transaction(w3, tx, retry_nonce=None):
user_operation = build_user_operation(w3, tx, retry_nonce)
user_operation["signature"] = sign_user_operation(
AA_BUNDLER_EXECUTOR_PK, user_operation, tx["chainId"], AA_BUNDLER_ENTRYPOINT
)
resp = w3.provider.make_request("eth_sendUserOperation", [user_operation, AA_BUNDLER_ENTRYPOINT])
if "error" in resp:
return handle_response_error(resp, w3, tx, retry_nonce)
next_nonce = check_nonce_error(resp, retry_nonce)
return send_transaction(w3, tx, retry_nonce=next_nonce)

# Store nonce in the cache, so next time uses a new nonce
NONCE_CACHE[nonce_key] = nonce + 1
return {"userOpHash": resp["result"]}
8 changes: 4 additions & 4 deletions src/ethproto/w3wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def deploy(self, eth_contract, init_params, from_, **kwargs):
def get_events(self, eth_wrapper, event_name, filter_kwargs={}):
"""Returns a list of events given a filter, like this:
>>> provider.get_events(currencywrapper, "Transfer", dict(fromBlock=0))
>>> provider.get_events(currencywrapper, "Transfer", dict(from_block=0))
[AttributeDict({
'args': AttributeDict(
{'from': '0x0000000000000000000000000000000000000000',
Expand All @@ -463,8 +463,8 @@ def get_events(self, eth_wrapper, event_name, filter_kwargs={}):
"""
contract = eth_wrapper.contract
event = getattr(contract.events, event_name)
if "fromBlock" not in filter_kwargs:
filter_kwargs["fromBlock"] = self.get_first_block(eth_wrapper)
if "from_block" not in filter_kwargs:
filter_kwargs["from_block"] = self.get_first_block(eth_wrapper)
event_filter = event.create_filter(**filter_kwargs)
return event_filter.get_all_entries()

Expand All @@ -490,7 +490,7 @@ def init_eth_wrapper(self, eth_wrapper, owner, init_params, kwargs):
constructor_params, init_params = init_params
real_contract = self.construct(eth_contract, constructor_params, {"from": eth_wrapper.owner})
ERC1967Proxy = self.get_contract_factory("ERC1967Proxy")
init_data = eth_contract.encodeABI(fn_name="initialize", args=init_params)
init_data = eth_contract.encode_abi(abi_element_identifier="initialize", args=init_params)
proxy_contract = self.construct(
ERC1967Proxy,
(real_contract.address, init_data),
Expand Down
16 changes: 16 additions & 0 deletions tests/hardhat-project/contracts/EventLauncher.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.9;

contract EventLauncher {
event Event1(uint256 value);

event Event2(uint256 value);

function launchEvent1(uint256 value) public {
emit Event1(value);
}

function launchEvent2(uint256 value) public {
emit Event2(value);
}
}
62 changes: 51 additions & 11 deletions tests/test_aa_bundler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
from queue import Queue
from threading import Event, Thread
from unittest.mock import MagicMock, patch

from hexbytes import HexBytes
from ethproto import aa_bundler
from web3.constants import HASH_ZERO
from unittest.mock import MagicMock, patch

from ethproto import aa_bundler


def test_pack_two():
Expand Down Expand Up @@ -66,19 +69,15 @@ def test_hash_packed_user_operation():

def test_sign_user_operation():
signature = aa_bundler.sign_user_operation(TEST_PRIVATE_KEY, user_operation, CHAIN_ID, ENTRYPOINT)
assert (
signature
== "0xb9b872bfe4e90f4628e8ec24879a5b01045f91da8457f3ce2b417d2e5774b508261ec1147a820e75a141cb61b884a78d7e88996ceddafb9a7016cfe7a48a1f4f1b" # noqa
assert (signature == "0xb9b872bfe4e90f4628e8ec24879a5b01045f91da8457f3ce2b417d2e5774b508261ec1147a820e75a141cb61b884a78d7e88996ceddafb9a7016cfe7a48a1f4f1b" # noqa
)


def test_sign_user_operation_gas_diff():
user_operation_2 = dict(user_operation)
user_operation_2["maxPriorityFeePerGas"] -= 1
signature = aa_bundler.sign_user_operation(TEST_PRIVATE_KEY, user_operation_2, CHAIN_ID, ENTRYPOINT)
assert (
signature
== "0x8162479d2dbd18d7fe93a2f51e283021d6e4eae4f57d20cdd553042723a0b0ea690ab3903d45126b0047da08ab53dfdf86656e4f258ac4936ba96a759ccb77f61b" # noqa
assert (signature == "0x8162479d2dbd18d7fe93a2f51e283021d6e4eae4f57d20cdd553042723a0b0ea690ab3903d45126b0047da08ab53dfdf86656e4f258ac4936ba96a759ccb77f61b" # noqa
)


Expand Down Expand Up @@ -156,8 +155,8 @@ def test_get_nonce_random_key_mode(fetch_nonce_mock, randint_mock):
fetch_nonce_mock.assert_not_called()
randint_mock.assert_called_with(1, 2**192 - 1)
randint_mock.reset_mock()
assert aa_bundler.RANDOM_NONCE_KEY == 444
aa_bundler.RANDOM_NONCE_KEY = None # cleanup
assert aa_bundler.RANDOM_NONCE_KEY.key == 444
aa_bundler.RANDOM_NONCE_KEY.key = None # cleanup


@patch.object(aa_bundler.random, "randint")
Expand Down Expand Up @@ -242,3 +241,44 @@ def make_request(method, params):
get_base_fee_mock.assert_called_once_with(w3)
assert aa_bundler.NONCE_CACHE[0] == 1
assert ret == {"userOpHash": "0xa950a17ca1ed83e974fb1aa227360a007cb65f566518af117ffdbb04d8d2d524"}


def test_random_key_nonces_are_thread_safe():
queue = Queue()
event = Event()

def worker():
event.wait() # Get all threads running at the same time
nonce_key, nonce = aa_bundler.get_nonce_and_key(
FAIL_IF_USED,
{"from": TEST_SENDER},
nonce_mode=aa_bundler.NonceMode.RANDOM_KEY,
)
aa_bundler.consume_nonce(nonce_key, nonce)
queue.put(
aa_bundler.get_nonce_and_key(
FAIL_IF_USED,
{"from": TEST_SENDER},
nonce_mode=aa_bundler.NonceMode.RANDOM_KEY,
)
)

threads = [Thread(target=worker) for _ in range(15)]
for thread in threads:
thread.start()

# Fire all threads at once
event.set()
for thread in threads:
thread.join()

nonces = {}

while not queue.empty():
nonce_key, nonce = queue.get_nowait()
# Each thread got a different key
assert nonce_key not in nonces
nonces[nonce_key] = nonce

# All nonces are the same
assert all(nonce == 1 for nonce in nonces.values())
29 changes: 29 additions & 0 deletions tests/test_w3.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,32 @@ def test_wrapper_build_from_def():
assert counter.value() == 0
counter.increase()
assert counter.value() == 1


def test_get_events():
provider = wrappers.get_provider("w3")
contract_def = provider.get_contract_def("EventLauncher")
wrapper = wrappers.ETHWrapper.build_from_def(contract_def)

launcher = wrapper()

launcher.launchEvent1(1)

cutoff_block = provider.w3.eth.get_block("latest")
launcher.launchEvent2(2)
launcher.launchEvent1(3)

all_event1 = provider.get_events(launcher, "Event1", dict(from_block=0))
assert len(all_event1) == 2

first_event1_only = provider.get_events(launcher, "Event1", dict(to_block=cutoff_block.number))
assert len(first_event1_only) == 1
assert first_event1_only[0] == all_event1[0]

last_event1_only = provider.get_events(launcher, "Event1", dict(from_block=cutoff_block.number + 1))
assert len(last_event1_only) == 1
assert last_event1_only[0] == all_event1[-1]

event2 = provider.get_events(launcher, "Event2")
assert len(event2) == 1
assert event2[0].args.value == 2

0 comments on commit fb7529b

Please sign in to comment.