Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generic AA_BUNDLER_PROVIDER. Small refactor #10

Merged
merged 4 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 66 additions & 34 deletions src/ethproto/aa_bundler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import random
from collections import defaultdict
from enum import Enum
from threading import local
from warnings import warn

from environs import Env
Expand All @@ -22,6 +24,7 @@
AA_BUNDLER_GAS_LIMIT_FACTOR = env.float("AA_BUNDLER_GAS_LIMIT_FACTOR", 1)
AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR = env.float("AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR", 1)
AA_BUNDLER_BASE_GAS_PRICE_FACTOR = env.float("AA_BUNDLER_BASE_GAS_PRICE_FACTOR", 1)
AA_BUNDLER_VERIFICATION_GAS_FACTOR = env.float("AA_BUNDLER_VERIFICATION_GAS_FACTOR", 1)

NonceMode = Enum(
"NonceMode",
Expand Down Expand Up @@ -51,8 +54,8 @@
}
]

NONCE_CACHE = {}
RANDOM_NONCE_KEY = None
NONCE_CACHE = defaultdict(lambda: 0)
RANDOM_NONCE_KEY = local()


def pack_two(a, b):
Expand All @@ -69,6 +72,10 @@ def _to_uint(x):
raise RuntimeError(f"Invalid int value {x}")


def apply_factor(x, factor):
return int(_to_uint(x) * factor)


def pack_user_operation(user_operation):
# https://github.com/eth-infinitism/account-abstraction/blob/develop/contracts/interfaces/PackedUserOperation.sol
return {
Expand Down Expand Up @@ -134,10 +141,9 @@ def fetch_nonce(w3, account, entry_point, nonce_key):


def get_random_nonce_key():
global RANDOM_NONCE_KEY
if RANDOM_NONCE_KEY is None:
RANDOM_NONCE_KEY = random.randint(1, 2**192 - 1)
return RANDOM_NONCE_KEY
if getattr(RANDOM_NONCE_KEY, "key", None) is None:
RANDOM_NONCE_KEY.key = random.randint(1, 2**192 - 1)
return RANDOM_NONCE_KEY.key


def get_nonce_and_key(w3, tx, nonce_mode, entry_point=AA_BUNDLER_ENTRYPOINT, fetch=False):
Expand All @@ -153,20 +159,25 @@ def get_nonce_and_key(w3, tx, nonce_mode, entry_point=AA_BUNDLER_ENTRYPOINT, fet
if nonce is None:
if fetch or nonce_mode == NonceMode.FIXED_KEY_FETCH_ALWAYS:
nonce = fetch_nonce(w3, get_sender(tx), entry_point, nonce_key)
elif nonce_key not in NONCE_CACHE:
nonce = 0
else:
nonce = NONCE_CACHE[nonce_key]
return nonce_key, nonce


def handle_response_error(resp, w3, tx, retry_nonce):
def consume_nonce(nonce_key, nonce):
NONCE_CACHE[nonce_key] = max(NONCE_CACHE[nonce_key], nonce + 1)


def check_nonce_error(resp, retry_nonce):
"""Returns the next nonce if resp contains a nonce error and retries weren't exhausted
Raises RevertError otherwise
"""
if "AA25" in resp["error"]["message"] and AA_BUNDLER_MAX_GETNONCE_RETRIES > 0:
# Retry fetching the nonce
if retry_nonce == AA_BUNDLER_MAX_GETNONCE_RETRIES:
raise RevertError(resp["error"]["message"])
warn(f'{resp["error"]["message"]} error, I will retry fetching the nonce')
return send_transaction(w3, tx, retry_nonce=(retry_nonce or 0) + 1)
return (retry_nonce or 0) + 1
else:
raise RevertError(resp["error"]["message"])

Expand All @@ -185,7 +196,7 @@ def get_sender(tx):
return tx["from"]


def send_transaction(w3, tx, retry_nonce=None):
def build_user_operation(w3, tx, retry_nonce=None):
nonce_key, nonce = get_nonce_and_key(
w3, tx, AA_BUNDLER_NONCE_MODE, entry_point=AA_BUNDLER_ENTRYPOINT, fetch=retry_nonce is not None
)
Expand All @@ -210,44 +221,65 @@ def send_transaction(w3, tx, retry_nonce=None):
"eth_estimateUserOperationGas", [user_operation, AA_BUNDLER_ENTRYPOINT]
)
if "error" in resp:
return handle_response_error(resp, w3, tx, retry_nonce)
next_nonce = check_nonce_error(resp, retry_nonce)
return build_user_operation(w3, tx, retry_nonce=next_nonce)

user_operation.update(resp["result"])

resp = w3.provider.make_request("rundler_maxPriorityFeePerGas", [])
if "error" in resp:
raise RevertError(resp["error"]["message"])
max_priority_fee_per_gas = int(_to_uint(resp["result"]) * AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR)
user_operation["maxPriorityFeePerGas"] = hex(max_priority_fee_per_gas)
user_operation["maxFeePerGas"] = hex(max_priority_fee_per_gas + get_base_fee(w3))
user_operation["callGasLimit"] = hex(
int(_to_uint(user_operation["callGasLimit"]) * AA_BUNDLER_GAS_LIMIT_FACTOR)
)
elif AA_BUNDLER_PROVIDER == "gelato":
user_operation.update(
{
"preVerificationGas": "0x00",
"callGasLimit": "0x00",
"verificationGasLimit": "0x00",
"maxFeePerGas": "0x00",
"maxPriorityFeePerGas": "0x00",
}
user_operation["maxPriorityFeePerGas"] = resp["result"]
user_operation["maxFeePerGas"] = hex(int(resp["result"], 16) + get_base_fee(w3))

elif AA_BUNDLER_PROVIDER == "generic":
resp = w3.provider.make_request(
"eth_estimateUserOperationGas", [user_operation, AA_BUNDLER_ENTRYPOINT]
)
user_operation["signature"] = add_0x_prefix(
sign_user_operation(
AA_BUNDLER_EXECUTOR_PK, user_operation, tx["chainId"], AA_BUNDLER_ENTRYPOINT
).hex()
if "error" in resp:
next_nonce = check_nonce_error(resp, retry_nonce)
return build_user_operation(w3, tx, retry_nonce=next_nonce)

user_operation.update(resp["result"])

else:
warn(f"Unknown AA_BUNDLER_PROVIDER: {AA_BUNDLER_PROVIDER}")

# Apply increase factors
user_operation["verificationGasLimit"] = hex(
apply_factor(user_operation["verificationGasLimit"], AA_BUNDLER_VERIFICATION_GAS_FACTOR)
)
if "maxPriorityFeePerGas" in user_operation:
user_operation["maxPriorityFeePerGas"] = hex(
apply_factor(user_operation["maxPriorityFeePerGas"], AA_BUNDLER_PRIORITY_GAS_PRICE_FACTOR)
)
if "callGasLimit" in user_operation:
user_operation["callGasLimit"] = hex(
apply_factor(user_operation["callGasLimit"], AA_BUNDLER_GAS_LIMIT_FACTOR)
)

# Remove paymaster related fields
user_operation.pop("paymaster", None)
user_operation.pop("paymasterData", None)
user_operation.pop("paymasterVerificationGasLimit", None)
user_operation.pop("paymasterPostOpGasLimit", None)

# Consume the nonce, even if the userop may fail later
consume_nonce(nonce_key, nonce)
gnpar marked this conversation as resolved.
Show resolved Hide resolved

return user_operation


def send_transaction(w3, tx, retry_nonce=None):
user_operation = build_user_operation(w3, tx, retry_nonce)
user_operation["signature"] = add_0x_prefix(
sign_user_operation(
AA_BUNDLER_EXECUTOR_PK, user_operation, tx["chainId"], AA_BUNDLER_ENTRYPOINT
).hex()
)
resp = w3.provider.make_request("eth_sendUserOperation", [user_operation, AA_BUNDLER_ENTRYPOINT])
if "error" in resp:
return handle_response_error(resp, w3, tx, retry_nonce)
next_nonce = check_nonce_error(resp, retry_nonce)
return send_transaction(w3, tx, retry_nonce=next_nonce)

# Store nonce in the cache, so next time uses a new nonce
NONCE_CACHE[nonce_key] = nonce + 1
return {"userOpHash": resp["result"]}
47 changes: 45 additions & 2 deletions tests/test_aa_bundler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from queue import Queue
from threading import Event, Thread
from unittest.mock import MagicMock, patch

from hexbytes import HexBytes
Expand Down Expand Up @@ -155,8 +157,8 @@ def test_get_nonce_random_key_mode(fetch_nonce_mock, randint_mock):
fetch_nonce_mock.assert_not_called()
randint_mock.assert_called_with(1, 2**192 - 1)
randint_mock.reset_mock()
assert aa_bundler.RANDOM_NONCE_KEY == 444
aa_bundler.RANDOM_NONCE_KEY = None # cleanup
assert aa_bundler.RANDOM_NONCE_KEY.key == 444
aa_bundler.RANDOM_NONCE_KEY.key = None # cleanup


@patch.object(aa_bundler.random, "randint")
Expand Down Expand Up @@ -241,3 +243,44 @@ def make_request(method, params):
get_base_fee_mock.assert_called_once_with(w3)
assert aa_bundler.NONCE_CACHE[0] == 1
assert ret == {"userOpHash": "0xa950a17ca1ed83e974fb1aa227360a007cb65f566518af117ffdbb04d8d2d524"}


def test_random_key_nonces_are_thread_safe():
queue = Queue()
event = Event()

def worker():
event.wait() # Get all threads running at the same time
nonce_key, nonce = aa_bundler.get_nonce_and_key(
FAIL_IF_USED,
{"from": TEST_SENDER},
nonce_mode=aa_bundler.NonceMode.RANDOM_KEY,
)
aa_bundler.consume_nonce(nonce_key, nonce)
queue.put(
aa_bundler.get_nonce_and_key(
FAIL_IF_USED,
{"from": TEST_SENDER},
nonce_mode=aa_bundler.NonceMode.RANDOM_KEY,
)
)

threads = [Thread(target=worker) for _ in range(15)]
for thread in threads:
thread.start()

# Fire all threads at once
event.set()
for thread in threads:
thread.join()

nonces = {}

while not queue.empty():
nonce_key, nonce = queue.get_nowait()
# Each thread got a different key
assert nonce_key not in nonces
nonces[nonce_key] = nonce

# All nonces are the same
assert all(nonce == 1 for nonce in nonces.values())
Loading