Skip to content

Commit

Permalink
Merge pull request #10 from gnarvaja/bundler-generic
Browse files Browse the repository at this point in the history
Add generic AA_BUNDLER_PROVIDER. Small refactor
  • Loading branch information
gnpar authored Oct 14, 2024
2 parents 26606d5 + bff80a7 commit 804a7b1
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 36 deletions.
100 changes: 66 additions & 34 deletions src/ethproto/aa_bundler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import random
from collections import defaultdict
from enum import Enum
from threading import local
from warnings import warn

from environs import Env
Expand All @@ -22,6 +24,7 @@
AA_BUNDLER_GAS_LIMIT_FACTOR = env.float("AA_BUNDLER_GAS_LIMIT_FACTOR", 1)
AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR = env.float("AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR", 1)
AA_BUNDLER_BASE_GAS_PRICE_FACTOR = env.float("AA_BUNDLER_BASE_GAS_PRICE_FACTOR", 1)
AA_BUNDLER_VERIFICATION_GAS_FACTOR = env.float("AA_BUNDLER_VERIFICATION_GAS_FACTOR", 1)

NonceMode = Enum(
"NonceMode",
Expand Down Expand Up @@ -51,8 +54,8 @@
}
]

NONCE_CACHE = {}
RANDOM_NONCE_KEY = None
NONCE_CACHE = defaultdict(lambda: 0)
RANDOM_NONCE_KEY = local()


def pack_two(a, b):
Expand All @@ -69,6 +72,10 @@ def _to_uint(x):
raise RuntimeError(f"Invalid int value {x}")


def apply_factor(x, factor):
return int(_to_uint(x) * factor)


def pack_user_operation(user_operation):
# https://github.com/eth-infinitism/account-abstraction/blob/develop/contracts/interfaces/PackedUserOperation.sol
return {
Expand Down Expand Up @@ -134,10 +141,9 @@ def fetch_nonce(w3, account, entry_point, nonce_key):


def get_random_nonce_key():
global RANDOM_NONCE_KEY
if RANDOM_NONCE_KEY is None:
RANDOM_NONCE_KEY = random.randint(1, 2**192 - 1)
return RANDOM_NONCE_KEY
if getattr(RANDOM_NONCE_KEY, "key", None) is None:
RANDOM_NONCE_KEY.key = random.randint(1, 2**192 - 1)
return RANDOM_NONCE_KEY.key


def get_nonce_and_key(w3, tx, nonce_mode, entry_point=AA_BUNDLER_ENTRYPOINT, fetch=False):
Expand All @@ -153,20 +159,25 @@ def get_nonce_and_key(w3, tx, nonce_mode, entry_point=AA_BUNDLER_ENTRYPOINT, fet
if nonce is None:
if fetch or nonce_mode == NonceMode.FIXED_KEY_FETCH_ALWAYS:
nonce = fetch_nonce(w3, get_sender(tx), entry_point, nonce_key)
elif nonce_key not in NONCE_CACHE:
nonce = 0
else:
nonce = NONCE_CACHE[nonce_key]
return nonce_key, nonce


def handle_response_error(resp, w3, tx, retry_nonce):
def consume_nonce(nonce_key, nonce):
NONCE_CACHE[nonce_key] = max(NONCE_CACHE[nonce_key], nonce + 1)


def check_nonce_error(resp, retry_nonce):
"""Returns the next nonce if resp contains a nonce error and retries weren't exhausted
Raises RevertError otherwise
"""
if "AA25" in resp["error"]["message"] and AA_BUNDLER_MAX_GETNONCE_RETRIES > 0:
# Retry fetching the nonce
if retry_nonce == AA_BUNDLER_MAX_GETNONCE_RETRIES:
raise RevertError(resp["error"]["message"])
warn(f'{resp["error"]["message"]} error, I will retry fetching the nonce')
return send_transaction(w3, tx, retry_nonce=(retry_nonce or 0) + 1)
return (retry_nonce or 0) + 1
else:
raise RevertError(resp["error"]["message"])

Expand All @@ -185,7 +196,7 @@ def get_sender(tx):
return tx["from"]


def send_transaction(w3, tx, retry_nonce=None):
def build_user_operation(w3, tx, retry_nonce=None):
nonce_key, nonce = get_nonce_and_key(
w3, tx, AA_BUNDLER_NONCE_MODE, entry_point=AA_BUNDLER_ENTRYPOINT, fetch=retry_nonce is not None
)
Expand All @@ -210,44 +221,65 @@ def send_transaction(w3, tx, retry_nonce=None):
"eth_estimateUserOperationGas", [user_operation, AA_BUNDLER_ENTRYPOINT]
)
if "error" in resp:
return handle_response_error(resp, w3, tx, retry_nonce)
next_nonce = check_nonce_error(resp, retry_nonce)
return build_user_operation(w3, tx, retry_nonce=next_nonce)

user_operation.update(resp["result"])

resp = w3.provider.make_request("rundler_maxPriorityFeePerGas", [])
if "error" in resp:
raise RevertError(resp["error"]["message"])
max_priority_fee_per_gas = int(_to_uint(resp["result"]) * AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR)
user_operation["maxPriorityFeePerGas"] = hex(max_priority_fee_per_gas)
user_operation["maxFeePerGas"] = hex(max_priority_fee_per_gas + get_base_fee(w3))
user_operation["callGasLimit"] = hex(
int(_to_uint(user_operation["callGasLimit"]) * AA_BUNDLER_GAS_LIMIT_FACTOR)
)
elif AA_BUNDLER_PROVIDER == "gelato":
user_operation.update(
{
"preVerificationGas": "0x00",
"callGasLimit": "0x00",
"verificationGasLimit": "0x00",
"maxFeePerGas": "0x00",
"maxPriorityFeePerGas": "0x00",
}
user_operation["maxPriorityFeePerGas"] = resp["result"]
user_operation["maxFeePerGas"] = hex(int(resp["result"], 16) + get_base_fee(w3))

elif AA_BUNDLER_PROVIDER == "generic":
resp = w3.provider.make_request(
"eth_estimateUserOperationGas", [user_operation, AA_BUNDLER_ENTRYPOINT]
)
user_operation["signature"] = add_0x_prefix(
sign_user_operation(
AA_BUNDLER_EXECUTOR_PK, user_operation, tx["chainId"], AA_BUNDLER_ENTRYPOINT
).hex()
if "error" in resp:
next_nonce = check_nonce_error(resp, retry_nonce)
return build_user_operation(w3, tx, retry_nonce=next_nonce)

user_operation.update(resp["result"])

else:
warn(f"Unknown AA_BUNDLER_PROVIDER: {AA_BUNDLER_PROVIDER}")

# Apply increase factors
user_operation["verificationGasLimit"] = hex(
apply_factor(user_operation["verificationGasLimit"], AA_BUNDLER_VERIFICATION_GAS_FACTOR)
)
if "maxPriorityFeePerGas" in user_operation:
user_operation["maxPriorityFeePerGas"] = hex(
apply_factor(user_operation["maxPriorityFeePerGas"], AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR)
)
if "callGasLimit" in user_operation:
user_operation["callGasLimit"] = hex(
apply_factor(user_operation["callGasLimit"], AA_BUNDLER_GAS_LIMIT_FACTOR)
)

# Remove paymaster related fields
user_operation.pop("paymaster", None)
user_operation.pop("paymasterData", None)
user_operation.pop("paymasterVerificationGasLimit", None)
user_operation.pop("paymasterPostOpGasLimit", None)

# Consume the nonce, even if the userop may fail later
consume_nonce(nonce_key, nonce)

return user_operation


def send_transaction(w3, tx, retry_nonce=None):
user_operation = build_user_operation(w3, tx, retry_nonce)
user_operation["signature"] = add_0x_prefix(
sign_user_operation(
AA_BUNDLER_EXECUTOR_PK, user_operation, tx["chainId"], AA_BUNDLER_ENTRYPOINT
).hex()
)
resp = w3.provider.make_request("eth_sendUserOperation", [user_operation, AA_BUNDLER_ENTRYPOINT])
if "error" in resp:
return handle_response_error(resp, w3, tx, retry_nonce)
next_nonce = check_nonce_error(resp, retry_nonce)
return send_transaction(w3, tx, retry_nonce=next_nonce)

# Store nonce in the cache, so next time uses a new nonce
NONCE_CACHE[nonce_key] = nonce + 1
return {"userOpHash": resp["result"]}
47 changes: 45 additions & 2 deletions tests/test_aa_bundler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from queue import Queue
from threading import Event, Thread
from unittest.mock import MagicMock, patch

from hexbytes import HexBytes
Expand Down Expand Up @@ -155,8 +157,8 @@ def test_get_nonce_random_key_mode(fetch_nonce_mock, randint_mock):
fetch_nonce_mock.assert_not_called()
randint_mock.assert_called_with(1, 2**192 - 1)
randint_mock.reset_mock()
assert aa_bundler.RANDOM_NONCE_KEY == 444
aa_bundler.RANDOM_NONCE_KEY = None # cleanup
assert aa_bundler.RANDOM_NONCE_KEY.key == 444
aa_bundler.RANDOM_NONCE_KEY.key = None # cleanup


@patch.object(aa_bundler.random, "randint")
Expand Down Expand Up @@ -241,3 +243,44 @@ def make_request(method, params):
get_base_fee_mock.assert_called_once_with(w3)
assert aa_bundler.NONCE_CACHE[0] == 1
assert ret == {"userOpHash": "0xa950a17ca1ed83e974fb1aa227360a007cb65f566518af117ffdbb04d8d2d524"}


def test_random_key_nonces_are_thread_safe():
queue = Queue()
event = Event()

def worker():
event.wait() # Get all threads running at the same time
nonce_key, nonce = aa_bundler.get_nonce_and_key(
FAIL_IF_USED,
{"from": TEST_SENDER},
nonce_mode=aa_bundler.NonceMode.RANDOM_KEY,
)
aa_bundler.consume_nonce(nonce_key, nonce)
queue.put(
aa_bundler.get_nonce_and_key(
FAIL_IF_USED,
{"from": TEST_SENDER},
nonce_mode=aa_bundler.NonceMode.RANDOM_KEY,
)
)

threads = [Thread(target=worker) for _ in range(15)]
for thread in threads:
thread.start()

# Fire all threads at once
event.set()
for thread in threads:
thread.join()

nonces = {}

while not queue.empty():
nonce_key, nonce = queue.get_nowait()
# Each thread got a different key
assert nonce_key not in nonces
nonces[nonce_key] = nonce

# All nonces are the same
assert all(nonce == 1 for nonce in nonces.values())

0 comments on commit 804a7b1

Please sign in to comment.