From b046e1b6a86084e39487e175032b73b57acf085d Mon Sep 17 00:00:00 2001 From: antazoey Date: Fri, 2 Aug 2024 14:24:23 -0500 Subject: [PATCH] perf: RPC / provider level speed optimizations (#2193) --- src/ape/api/networks.py | 1 - src/ape/api/transactions.py | 33 +++++-- src/ape/contracts/base.py | 15 +-- src/ape/managers/chain.py | 1 - src/ape/pytest/config.py | 16 ++- src/ape/pytest/coverage.py | 2 +- src/ape/pytest/fixtures.py | 40 +++----- src/ape/pytest/gas.py | 5 +- src/ape/pytest/plugin.py | 44 ++------- src/ape/pytest/runners.py | 4 +- src/ape/utils/abi.py | 4 +- src/ape/utils/misc.py | 2 +- src/ape_ethereum/ecosystem.py | 55 +++++------ src/ape_ethereum/provider.py | 98 ++++++++++++------- src/ape_ethereum/trace.py | 8 +- src/ape_node/provider.py | 4 +- src/ape_test/accounts.py | 27 +++-- src/ape_test/provider.py | 66 ++++++++++--- .../functional/test_contract_call_handler.py | 6 +- tests/functional/test_exceptions.py | 2 +- tests/functional/test_provider.py | 8 +- 21 files changed, 241 insertions(+), 200 deletions(-) diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 302f75355d..671ffd6d92 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -173,7 +173,6 @@ def serialize_transaction(self) -> bytes: Returns: bytes """ - if not self.signature: raise SignatureError("The transaction is not signed.") diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index 0bfb91321c..9cb8984ab7 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -5,7 +5,7 @@ from typing import IO, TYPE_CHECKING, Any, NoReturn, Optional, Union from eth_pydantic_types import HexBytes -from eth_utils import is_0x_prefixed, is_hex, to_int +from eth_utils import is_0x_prefixed, is_hex, to_hex, to_int from ethpm_types.abi import EventABI, MethodABI from pydantic import ConfigDict, field_validator from pydantic.fields import Field @@ -74,6 +74,20 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._raise_on_revert = raise_on_revert + @field_validator("nonce", mode="before") + @classmethod + def validate_nonce(cls, value): + if value is None or isinstance(value, int): + return value + + elif isinstance(value, str) and value.startswith("0x"): + return to_int(hexstr=value) + + elif isinstance(value, str): + return int(value) + + return to_int(value) + @field_validator("gas_limit", mode="before") @classmethod def validate_gas_limit(cls, value): @@ -161,12 +175,12 @@ def receipt(self) -> Optional["ReceiptAPI"]: """ try: - txn_hash = self.txn_hash.hex() + txn_hash = to_hex(self.txn_hash) except SignatureError: return None try: - return self.provider.get_receipt(txn_hash, required_confirmations=0, timeout=0) + return self.chain_manager.get_receipt(txn_hash) except (TransactionNotFoundError, ProviderNotConnectedError): return None @@ -355,7 +369,6 @@ def failed(self) -> bool: Ecosystem plugins override this property when their receipts are able to be failing. """ - return False @property @@ -450,6 +463,13 @@ def await_confirmations(self) -> "ReceiptAPI": Returns: :class:`~ape.api.ReceiptAPI`: The receipt that is now confirmed. """ + # perf: avoid *everything* if required_confirmations is 0, as this is likely a + # dev environment or the user doesn't care. + if self.required_confirmations == 0: + # The transaction might not yet be confirmed but + # the user is aware of this. Or, this is a development environment. + return self + try: self.raise_for_status() except TransactionError: @@ -472,11 +492,6 @@ def await_confirmations(self) -> "ReceiptAPI": if self.transaction.raise_on_revert: raise tx_err - if self.required_confirmations == 0: - # The transaction might not yet be confirmed but - # the user is aware of this. Or, this is a development environment. - return self - confirmations_occurred = self._confirmations_occurred if self.required_confirmations and confirmations_occurred >= self.required_confirmations: return self diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index b8c7bf32ff..63c42c65fa 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -554,12 +554,7 @@ def __len__(self): def __call__(self, *args: Any, **kwargs: Any) -> MockContractLog: # Create a dictionary from the positional arguments event_args: dict[Any, Any] = dict(zip((ipt.name for ipt in self.abi.inputs), args)) - - overlapping_keys = set(k for k in event_args.keys() if k is not None) & set( - k for k in kwargs.keys() if k is not None - ) - - if overlapping_keys: + if overlapping_keys := set(event_args).intersection(kwargs): raise ValueError( f"Overlapping keys found in arguments: '{', '.join(overlapping_keys)}'." ) @@ -1132,8 +1127,8 @@ def get_event_by_signature(self, signature: str) -> ContractEvent: :class:`~ape.contracts.base.ContractEvent` """ - name_from_sig = signature.split("(")[0].strip() - options = self._events_.get(name_from_sig, []) + name_from_sig = signature.partition("(")[0].strip() + options = self._events_.get(name_from_sig.strip(), []) err = ContractDataError(f"No event found with signature '{signature}'.") if not options: @@ -1157,7 +1152,7 @@ def get_error_by_signature(self, signature: str) -> type[CustomError]: :class:`~ape.exceptions.CustomError` """ - name_from_sig = signature.split("(")[0].strip() + name_from_sig = signature.partition("(")[0].strip() options = self._errors_.get(name_from_sig, []) err = ContractDataError(f"No error found with signature '{signature}'.") if not options: @@ -1605,7 +1600,7 @@ def _get_name(cc: ContractContainer) -> str: return contract elif "." in search_name: - next_node = search_name.split(".")[0] + next_node = search_name.partition(".")[0] if next_node != item: continue diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index ae6d8363ea..8ce6bec943 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -67,7 +67,6 @@ def head(self) -> BlockAPI: """ The latest block. """ - return self.provider.get_block("latest") @property diff --git a/src/ape/pytest/config.py b/src/ape/pytest/config.py index af772fb254..d21e2cf619 100644 --- a/src/ape/pytest/config.py +++ b/src/ape/pytest/config.py @@ -1,9 +1,10 @@ +from functools import cached_property from typing import Any, Optional, Union from _pytest.config import Config as PytestConfig from ape.types import ContractFunctionPath -from ape.utils import ManagerAccessMixin, cached_property +from ape.utils.basemodel import ManagerAccessMixin def _get_config_exclusions(config) -> list[ContractFunctionPath]: @@ -76,15 +77,12 @@ def gas_exclusions(self) -> list[ContractFunctionPath]: """ The combination of both CLI values and config values. """ - cli_value = self.pytest_config.getoption("--gas-exclude") - exclusions: list[ContractFunctionPath] = [] - if cli_value: - items = cli_value.split(",") - for item in items: - exclusion = ContractFunctionPath.from_str(item) - exclusions.append(exclusion) - + exclusions = ( + [ContractFunctionPath.from_str(item) for item in cli_value.split(",")] + if cli_value + else [] + ) paths = _get_config_exclusions(self.ape_test_config.gas) exclusions.extend(paths) return exclusions diff --git a/src/ape/pytest/coverage.py b/src/ape/pytest/coverage.py index d21a928377..bc4674243e 100644 --- a/src/ape/pytest/coverage.py +++ b/src/ape/pytest/coverage.py @@ -7,7 +7,7 @@ from ethpm_types.source import ContractSource from ape.logging import logger -from ape.managers import ProjectManager +from ape.managers.project import ProjectManager from ape.pytest.config import ConfigWrapper from ape.types import ( ContractFunctionPath, diff --git a/src/ape/pytest/fixtures.py b/src/ape/pytest/fixtures.py index 5d7f68f0a4..0e5d179a45 100644 --- a/src/ape/pytest/fixtures.py +++ b/src/ape/pytest/fixtures.py @@ -1,11 +1,12 @@ -import copy from collections.abc import Iterator from fnmatch import fnmatch +from functools import cached_property from typing import Optional import pytest -from ape.api import ReceiptAPI, TestAccountAPI +from ape.api.accounts import TestAccountAPI +from ape.api.transactions import ReceiptAPI from ape.exceptions import BlockNotFoundError, ChainError from ape.logging import logger from ape.managers.chain import ChainManager @@ -13,7 +14,8 @@ from ape.managers.project import ProjectManager from ape.pytest.config import ConfigWrapper from ape.types import SnapshotID -from ape.utils import ManagerAccessMixin, allow_disconnected, cached_property +from ape.utils.basemodel import ManagerAccessMixin +from ape.utils.misc import allow_disconnected class PytestApeFixtures(ManagerAccessMixin): @@ -30,9 +32,10 @@ def __init__(self, config_wrapper: ConfigWrapper, receipt_capture: "ReceiptCaptu @cached_property def _track_transactions(self) -> bool: - has_reason = self.config_wrapper.track_gas or self.config_wrapper.track_coverage return ( - self.network_manager.provider is not None and self.provider.is_connected and has_reason + self.network_manager.provider is not None + and self.provider.is_connected + and (self.config_wrapper.track_gas or self.config_wrapper.track_coverage) ) @pytest.fixture(scope="session") @@ -40,7 +43,6 @@ def accounts(self) -> list[TestAccountAPI]: """ A collection of pre-funded accounts. """ - return self.account_manager.test_accounts @pytest.fixture(scope="session") @@ -48,7 +50,6 @@ def compilers(self): """ Access compiler manager directly. """ - return self.compiler_manager @pytest.fixture(scope="session") @@ -56,7 +57,6 @@ def chain(self) -> ChainManager: """ Manipulate the blockchain, such as mine or change the pending timestamp. """ - return self.chain_manager @pytest.fixture(scope="session") @@ -64,7 +64,6 @@ def networks(self) -> NetworkManager: """ Connect to other networks in your tests. """ - return self.network_manager @pytest.fixture(scope="session") @@ -72,7 +71,6 @@ def project(self) -> ProjectManager: """ Access contract types and dependencies. """ - return self.local_project @pytest.fixture(scope="session") @@ -88,7 +86,6 @@ def _isolation(self) -> Iterator[None]: Isolation logic used to implement isolation fixtures for each pytest scope. When tracing support is available, will also assist in capturing receipts. """ - try: snapshot_id = self._snapshot() except BlockNotFoundError: @@ -174,7 +171,7 @@ def capture_range(self, start_block: int, stop_block: int): txn_hash = txn.txn_hash.hex() except Exception: # Might have been from an impersonated account. - # Those txns need to be added separatly, same as tracing calls. + # Those txns need to be added separately, same as tracing calls. # Likely, it was already accounted before this point. continue @@ -189,14 +186,14 @@ def capture(self, transaction_hash: str): if not receipt: return - if not (contract_address := (receipt.receiver or receipt.contract_address)): + elif not (contract_address := (receipt.receiver or receipt.contract_address)): return - if not (contract_type := self.chain_manager.contracts.get(contract_address)): + elif not (contract_type := self.chain_manager.contracts.get(contract_address)): # Not an invoke-transaction or a known address return - if not (source_id := (contract_type.source_id or None)): + elif not (source_id := (contract_type.source_id or None)): # Not a local or known contract type. return @@ -229,7 +226,6 @@ def _exclude_from_gas_report( Helper method to determine if a certain contract / method combination should be excluded from the gas report. """ - for exclusion in self.config_wrapper.gas_exclusions: # Default to looking at all contracts contract_pattern = exclusion.contract_name @@ -241,15 +237,3 @@ def _exclude_from_gas_report( return True return False - - -def _build_report(report: dict, contract: str, method: str, usages: list) -> dict: - new_dict = copy.deepcopy(report) - if contract not in new_dict: - new_dict[contract] = {method: usages} - elif method not in new_dict[contract]: - new_dict[contract][method] = usages - else: - new_dict[contract][method].extend(usages) - - return new_dict diff --git a/src/ape/pytest/gas.py b/src/ape/pytest/gas.py index b21c3ea07d..87f6ffca41 100644 --- a/src/ape/pytest/gas.py +++ b/src/ape/pytest/gas.py @@ -4,12 +4,11 @@ from ethpm_types.source import ContractSource from evm_trace.gas import merge_reports -from ape.api import TraceAPI +from ape.api.trace import TraceAPI from ape.pytest.config import ConfigWrapper from ape.types import AddressType, ContractFunctionPath, GasReport -from ape.utils import parse_gas_table from ape.utils.basemodel import ManagerAccessMixin -from ape.utils.trace import _exclude_gas +from ape.utils.trace import _exclude_gas, parse_gas_table class GasTracker(ManagerAccessMixin): diff --git a/src/ape/pytest/plugin.py b/src/ape/pytest/plugin.py index 92218811a9..ecb10b7e0e 100644 --- a/src/ape/pytest/plugin.py +++ b/src/ape/pytest/plugin.py @@ -1,10 +1,7 @@ import sys from pathlib import Path -import pytest - from ape.exceptions import ConfigError -from ape.logging import LogLevel, logger from ape.pytest.config import ConfigWrapper from ape.pytest.coverage import CoverageTracker from ape.pytest.fixtures import PytestApeFixtures, ReceiptCapture @@ -64,17 +61,19 @@ def add_option(*names, **kwargs): def pytest_configure(config): # Do not include ape internals in tracebacks unless explicitly asked if not config.getoption("--show-internal"): - path_str = sys.modules["ape"].__file__ - if path_str: - base_path = Path(path_str).parent.as_posix() + if path_str := sys.modules["ape"].__file__: + base_path = str(Path(path_str).parent) def is_module(v): return getattr(v, "__file__", None) and v.__file__.startswith(base_path) - modules = [v for v in sys.modules.values() if is_module(v)] - for module in modules: - if hasattr(module, "__tracebackhide__"): - setattr(module, "__tracebackhide__", True) + for module in (v for v in sys.modules.values() if is_module(v)): + # NOTE: Using try/except w/ type:ignore (over checking for attr) + # for performance reasons! + try: + module.__tracebackhide__ = True # type: ignore[attr-defined] + except AttributeError: + pass config_wrapper = ConfigWrapper(config) receipt_capture = ReceiptCapture(config_wrapper) @@ -99,28 +98,3 @@ def is_module(v): config.addinivalue_line( "markers", "use_network(choice): Run this test using the given network choice." ) - - -def pytest_load_initial_conftests(early_config): - """ - Compile contracts before loading ``conftest.py``s. - """ - capture_manager = early_config.pluginmanager.get_plugin("capturemanager") - pm = ManagerAccessMixin.local_project - - # Suspend stdout capture to display compilation data - capture_manager.suspend() - try: - pm.load_contracts() - except Exception as err: - logger.log_debug_stack_trace() - message = "Unable to load project. " - if logger.level > LogLevel.DEBUG: - message = f"{message}Use `-v DEBUG` to see more info.\n" - - err_type_name = getattr(type(err), "__name__", "Exception") - message = f"{message}Failure reason: ({err_type_name}) {err}" - raise pytest.UsageError(message) - - finally: - capture_manager.resume() diff --git a/src/ape/pytest/runners.py b/src/ape/pytest/runners.py index 132638d15d..0445e8f446 100644 --- a/src/ape/pytest/runners.py +++ b/src/ape/pytest/runners.py @@ -6,14 +6,14 @@ from _pytest._code.code import Traceback as PytestTraceback from rich import print as rich_print -from ape.api import ProviderContextManager +from ape.api.networks import ProviderContextManager from ape.logging import LogLevel from ape.pytest.config import ConfigWrapper from ape.pytest.coverage import CoverageTracker from ape.pytest.fixtures import ReceiptCapture from ape.pytest.gas import GasTracker from ape.types.coverage import CoverageReport -from ape.utils import ManagerAccessMixin +from ape.utils.basemodel import ManagerAccessMixin from ape_console._cli import console diff --git a/src/ape/utils/abi.py b/src/ape/utils/abi.py index 42d241eb93..36e329645f 100644 --- a/src/ape/utils/abi.py +++ b/src/ape/utils/abi.py @@ -159,7 +159,7 @@ def _decode( return values elif has_array_of_tuples_return: - item_type_str = str(_types[0].type).split("[")[0] + item_type_str = str(_types[0].type).partition("[")[0] data = { **_types[0].model_dump(), "type": item_type_str, @@ -179,7 +179,7 @@ def _decode( else: for output_type, value in zip(_types, values): if isinstance(value, (tuple, list)): - item_type_str = str(output_type.type).split("[")[0] + item_type_str = str(output_type.type).partition("[")[0] if item_type_str == "tuple": # Either an array of structs or nested structs. item_type_data = { diff --git a/src/ape/utils/misc.py b/src/ape/utils/misc.py index 83dd8ddb96..d0977f1334 100644 --- a/src/ape/utils/misc.py +++ b/src/ape/utils/misc.py @@ -172,7 +172,7 @@ def get_package_version(obj: Any) -> str: # Reduce module string to base package # NOTE: Assumed that string input is module name e.g. `__name__` - pkg_name = obj.split(".")[0] + pkg_name = obj.partition(".")[0] # NOTE: In case the distribution and package name differ dists = _get_distributions(pkg_name) diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 12efbe886e..9a98b615cf 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -550,7 +550,6 @@ def decode_receipt(self, data: dict) -> ReceiptAPI: status = self.conversion_manager.convert(status, int) status = TransactionStatusEnum(status) - txn_hash = None hash_key_choices = ( "hash", "txHash", @@ -559,18 +558,13 @@ def decode_receipt(self, data: dict) -> ReceiptAPI: "transactionHash", "transaction_hash", ) - for choice in hash_key_choices: - if choice in data: - txn_hash = data[choice] - break + txn_hash = next((data[choice] for choice in hash_key_choices if choice in data), None) + if txn_hash and isinstance(txn_hash, bytes): + txn_hash = txn_hash.hex() - if txn_hash: - txn_hash = txn_hash.hex() if isinstance(txn_hash, bytes) else txn_hash - - data_bytes = data.get("data", b"") + data_bytes = data.get("data") if data_bytes and isinstance(data_bytes, str): data["data"] = HexBytes(data_bytes) - elif "input" in data and isinstance(data["input"], str): data["input"] = HexBytes(data["input"]) @@ -578,17 +572,17 @@ def decode_receipt(self, data: dict) -> ReceiptAPI: if block_number is None: raise ValueError("Missing block number.") - receipt_kwargs = dict( - block_number=block_number, - contract_address=data.get("contract_address") or data.get("contractAddress"), - gas_limit=data.get("gas", data.get("gas_limit", data.get("gasLimit"))) or 0, - gas_price=data.get("gas_price", data.get("gasPrice")) or 0, - gas_used=data.get("gas_used", data.get("gasUsed")) or 0, - logs=data.get("logs", []), - status=status, - txn_hash=txn_hash, - transaction=self.create_transaction(**data), - ) + receipt_kwargs = { + "block_number": block_number, + "contract_address": data.get("contract_address", data.get("contractAddress")), + "gas_limit": data.get("gas", data.get("gas_limit", data.get("gasLimit"))) or 0, + "gas_price": data.get("gas_price", data.get("gasPrice")) or 0, + "gas_used": data.get("gas_used", data.get("gasUsed")) or 0, + "logs": data.get("logs", []), + "status": status, + "txn_hash": txn_hash, + "transaction": self.create_transaction(**data), + } receipt_cls: type[Receipt] if any( @@ -631,7 +625,10 @@ def _python_type_for_abi_type(self, abi_type: ABIType) -> Union[type, Sequence]: # NOTE: An array can be an array of tuples, so we start with an array check if str(abi_type.type).endswith("]"): # remove one layer of the potential onion of array - new_type = "[".join(str(abi_type.type).split("[")[:-1]) + abi_type_str = str(abi_type.type) + last_bracket_pos = abi_type_str.rfind("[") + new_type = abi_type_str[:last_bracket_pos] if last_bracket_pos != -1 else abi_type_str + # create a new type with the inner type of array new_abi_type = ABIType(type=new_type, **abi_type.model_dump(exclude={"type"})) # NOTE: type for static and dynamic array is a single item list @@ -952,7 +949,7 @@ def create_transaction(self, **kwargs) -> TransactionAPI: if "gas" not in tx_data: tx_data["gas"] = None - return txn_class(**tx_data) + return txn_class.model_validate(tx_data) def decode_logs(self, logs: Sequence[dict], *events: EventABI) -> Iterator["ContractLog"]: if not logs: @@ -1478,14 +1475,10 @@ def _correct_key(key: str, data: dict, alt_keys: tuple[str, ...]) -> dict: if key in data: return data - # Check for alternative. for possible_key in alt_keys: - if possible_key not in data: - continue - - # Alt found: use it. - new_data = {k: v for k, v in data.items() if k not in alt_keys} - new_data[key] = data[possible_key] - return new_data + if possible_key in data: + new_data = data.copy() + new_data[key] = new_data.pop(possible_key) + return new_data return data diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 15f4547669..19eae807af 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -233,7 +233,7 @@ def client_version(self) -> str: @property def base_fee(self) -> int: - latest_block_number = self.get_block("latest").number + latest_block_number = self._get_latest_block_rpc().get("number") if latest_block_number is None: # Possibly no blocks yet. logger.debug("Latest block has no number. Using base fee of '0'.") @@ -280,8 +280,7 @@ def _get_fee_history(self, block_number: int) -> FeeHistory: raise APINotImplementedError(str(err)) from err def _get_last_base_fee(self) -> int: - block = self.get_block("latest") - base_fee = getattr(block, "base_fee", None) + base_fee = self._get_latest_block_rpc().get("baseFeePerGas", None) if base_fee is not None: return base_fee @@ -296,8 +295,7 @@ def is_connected(self) -> bool: @property def max_gas(self) -> int: - block = self.web3.eth.get_block("latest") - return block["gasLimit"] + return int(self._get_latest_block_rpc()["gasLimit"], 16) @cached_property def supports_tracing(self) -> bool: @@ -411,12 +409,16 @@ def get_block(self, block_id: BlockID) -> BlockAPI: except Exception as err: raise BlockNotFoundError(block_id, reason=str(err)) from err - # Some nodes (like anvil) will not have a base fee if set to 0. - if "baseFeePerGas" in block_data and block_data.get("baseFeePerGas") is None: - block_data["baseFeePerGas"] = 0 - return self.network.ecosystem.decode_block(block_data) + def _get_latest_block(self) -> BlockAPI: + # perf: By-pass as much as possible since this is a common action. + data = self._get_latest_block_rpc() + return self.network.ecosystem.decode_block(data) + + def _get_latest_block_rpc(self) -> dict: + return self.make_request("eth_getBlockByNumber", ["latest", False]) + def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: return self.web3.eth.get_transaction_count(address, block_identifier=block_id) @@ -568,6 +570,7 @@ def _prepare_call(self, txn: Union[dict, TransactionAPI], **kwargs) -> list: txn_dict.pop("gasLimit", None) txn_dict.pop("maxFeePerGas", None) txn_dict.pop("maxPriorityFeePerGas", None) + txn_dict.pop("signature", None) # NOTE: Block ID is required so if given None, default to `"latest"`. block_identifier = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) or "latest" @@ -593,10 +596,12 @@ def get_receipt( timeout = ( timeout if timeout is not None else self.provider.network.transaction_acceptance_timeout ) - hex_hash = HexBytes(txn_hash) + try: - receipt_data = self.web3.eth.wait_for_transaction_receipt(hex_hash, timeout=timeout) + receipt_data = dict( + self.web3.eth.wait_for_transaction_receipt(hex_hash, timeout=timeout) + ) except TimeExhausted as err: msg_str = str(err) if f"HexBytes('{txn_hash}')" in msg_str: @@ -609,17 +614,33 @@ def get_receipt( ecosystem_config = self.network.ecosystem_config network_config: dict = ecosystem_config.get(self.network.name, {}) max_retries = network_config.get("max_get_transaction_retries", DEFAULT_MAX_RETRIES_TX) - txn = {} - for attempt in range(max_retries): - try: - txn = dict(self.web3.eth.get_transaction(HexStr(txn_hash))) - break - except TransactionNotFound: - if attempt < max_retries - 1: # if this wasn't the last attempt - time.sleep(1) # Wait for 1 second before retrying. - continue # Continue to the next iteration, effectively retrying the operation. - else: # if it was the last attempt - raise # Re-raise the last exception. + + if transaction := kwargs.get("transaction"): + # perf: If called `send_transaction()`, we should already have the data! + txn = ( + transaction + if isinstance(transaction, dict) + else transaction.model_dump(by_alias=True, mode="json") + ) + if "effectiveGasPrice" in receipt_data: + receipt_data["gasPrice"] = receipt_data["effectiveGasPrice"] + + else: + txn = {} + for attempt in range(max_retries): + try: + txn = dict(self.web3.eth.get_transaction(HexStr(txn_hash))) + break + + except TransactionNotFound: + if attempt < max_retries - 1: + # Not the last attempt. Wait and then retry. + time.sleep(0.5) + continue + + else: + # It was the last attempt - raise the exception as-is. + raise data = {"required_confirmations": required_confirmations, **txn, **receipt_data} receipt = self._create_receipt(**data) @@ -754,9 +775,7 @@ def assert_chain_activity(): while True: # The next block we want is simply 1 after the last. next_block = last.number + 1 - - head = self.get_block("latest") - + head = self._get_latest_block() try: if head.number is None or head.hash is None: raise ProviderError("Head block has no number or hash.") @@ -826,7 +845,7 @@ def poll_logs( required_confirmations = self.network.required_confirmations if stop_block is not None: - if stop_block <= (self.provider.get_block("latest").number or 0): + if stop_block <= (self._get_latest_block().number or 0): raise ValueError("'stop' argument must be in the future.") for block in self.poll_blocks(stop_block, required_confirmations, new_block_timeout): @@ -936,6 +955,7 @@ def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI: def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: vm_err = None + txn_data = None txn_hash = None try: if txn.sender is not None and txn.signature is None: @@ -976,7 +996,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: if txn.required_confirmations is not None else self.network.required_confirmations ) - txn_dict = txn.model_dump(by_alias=True, mode="json") + txn_data = txn_data or txn.model_dump(by_alias=True, mode="json") if vm_err: receipt = self._create_receipt( block_number=-1, # Not in a block. @@ -984,10 +1004,14 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: required_confirmations=required_confirmations, status=TransactionStatusEnum.FAILING, txn_hash=txn_hash, - **txn_dict, + **txn_data, ) else: - receipt = self.get_receipt(txn_hash, required_confirmations=required_confirmations) + receipt = self.get_receipt( + txn_hash, + required_confirmations=required_confirmations, + transaction=txn_data, + ) # NOTE: Ensure to cache even the failed receipts. # NOTE: Caching must happen before error enrichment. @@ -995,15 +1019,15 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: if receipt.failed: # For some reason, some nodes have issues with integer-types. - if isinstance(txn_dict.get("type"), int): - txn_dict["type"] = to_hex(txn_dict["type"]) + if isinstance(txn_data.get("type"), int): + txn_data["type"] = to_hex(txn_data["type"]) # NOTE: For some reason, some providers have issues with # `nonce`, it's not needed anyway. - txn_dict.pop("nonce", None) + txn_data.pop("nonce", None) # NOTE: Using JSON mode since used as request data. - txn_params = cast(TxParams, txn_dict) + txn_params = cast(TxParams, txn_data) # Replay txn to get revert reason try: @@ -1112,7 +1136,7 @@ def create_access_list( list[:class:`~ape_ethereum.transactions.AccessList`] """ # NOTE: Using JSON mode since used in request data. - tx_dict = transaction.model_dump(by_alias=True, mode="json", exclude=("chain_id",)) + tx_dict = transaction.model_dump(by_alias=True, mode="json", exclude={"chain_id"}) tx_dict_converted = {} for key, val in tx_dict.items(): if isinstance(val, int): @@ -1401,10 +1425,12 @@ def _complete_connect(self): self.concurrency = 32 self.block_page_size = 50_000 else: - client_name = client_version.split("/")[0] + client_name = client_version.partition("/")[0] logger.info(f"Connecting to a '{client_name}' node.") - self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy) + if not self.network.is_dev: + self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy) + # Check for chain errors, including syncing try: chain_id = self.web3.eth.chain_id diff --git a/src/ape_ethereum/trace.py b/src/ape_ethereum/trace.py index 9917634340..784933b1f0 100644 --- a/src/ape_ethereum/trace.py +++ b/src/ape_ethereum/trace.py @@ -432,12 +432,12 @@ def _discover_calltrace_approach(self) -> CallTreeNode: TA.BASIC: self._get_basic_calltree, } - reason = "" + reason_map = {} for approach, fn in approaches.items(): try: call = fn() except Exception as err: - reason = f"{err}" + reason_map[approach.name] = f"{err}" continue self._set_approach(approach) @@ -445,7 +445,8 @@ def _discover_calltrace_approach(self) -> CallTreeNode: # Not sure this would happen, as the basic-approach should # always work. - raise ProviderError(f"Unable to create CallTreeNode. Reason: {reason}") + reason_str = ", ".join(f"{k}={v}" for k, v in reason_map.items()) + raise ProviderError(f"Unable to create CallTreeNode. Reason(s): {reason_str}") def _debug_trace_transaction(self, parameters: Optional[dict] = None) -> dict: parameters = parameters or self.debug_trace_transaction_parameters @@ -477,7 +478,6 @@ def _get_basic_calltree(self) -> CallTreeNode: # Figure out the 'returndata' using 'eth_call' RPC. tx = receipt.transaction.model_copy(update={"nonce": None}) - try: return_value = self.provider.send_call(tx, block_id=receipt.block_number) except ContractLogicError: diff --git a/src/ape_node/provider.py b/src/ape_node/provider.py index 2dc19775dc..1d91f745d1 100644 --- a/src/ape_node/provider.py +++ b/src/ape_node/provider.py @@ -384,7 +384,7 @@ def disconnect(self): super().disconnect() def snapshot(self) -> SnapshotID: - return self.get_block("latest").number or 0 + return self._get_latest_block().number or 0 def restore(self, snapshot_id: SnapshotID): if isinstance(snapshot_id, int): @@ -397,7 +397,7 @@ def restore(self, snapshot_id: SnapshotID): block_number_hex_str = add_0x_prefix(HexStr(snapshot_id)) block_number_int = int(snapshot_id, 16) - current_block = self.get_block("latest").number + current_block = self._get_latest_block().number if block_number_int == current_block: # Head is already at this block. return diff --git a/src/ape_test/accounts.py b/src/ape_test/accounts.py index 307b258026..833c4a62e8 100644 --- a/src/ape_test/accounts.py +++ b/src/ape_test/accounts.py @@ -4,9 +4,11 @@ from eip712.messages import EIP712Message from eth_account import Account as EthAccount +from eth_account._utils.signing import sign_transaction_dict from eth_account.messages import SignableMessage, encode_defunct +from eth_keys.datatypes import PrivateKey # type: ignore from eth_pydantic_types import HexBytes -from eth_utils import to_bytes +from eth_utils import to_bytes, to_hex from ape.api import TestAccountAPI, TestAccountContainerAPI, TransactionAPI from ape.exceptions import ProviderNotConnectedError, SignatureError @@ -126,20 +128,29 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature] return None def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]: - # Signs anything that's given to it - # NOTE: Using JSON mode since used as request data. - tx_data = txn.model_dump(mode="json", by_alias=True) + # Signs any transaction that's given to it. + # NOTE: Using JSON mode, as only primitive types can be signed. + tx_data = txn.model_dump(mode="json", by_alias=True, exclude={"sender"}) + private_key = PrivateKey(HexBytes(self.private_key)) + # NOTE: var name `sig_r` instead of `r` to avoid clashing with pdb commands. try: - signature = EthAccount.sign_transaction(tx_data, self.private_key) + ( + sig_v, + sig_r, + sig_s, + _, + ) = sign_transaction_dict(private_key, tx_data) except TypeError as err: # Occurs when missing properties on the txn that are needed to sign. raise SignatureError(str(err)) from err + # NOTE: Using `to_bytes(hexstr=to_hex(sig_r))` instead of `to_bytes(sig_r)` as + # a performance optimization. txn.signature = TransactionSignature( - v=signature.v, - r=to_bytes(signature.r), - s=to_bytes(signature.s), + v=sig_v, + r=to_bytes(hexstr=to_hex(sig_r)), + s=to_bytes(hexstr=to_hex(sig_s)), ) return txn diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index baf13597a5..5bb9b3e6ad 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -17,8 +17,9 @@ from web3.providers.eth_tester.defaults import API_ENDPOINTS, static_return from web3.types import TxParams -from ape.api import PluginConfig, ReceiptAPI, TestProviderAPI, TransactionAPI +from ape.api import BlockAPI, PluginConfig, ReceiptAPI, TestProviderAPI, TransactionAPI from ape.exceptions import ( + APINotImplementedError, ContractLogicError, ProviderError, ProviderNotConnectedError, @@ -90,6 +91,10 @@ def auto_mine(self, value: Any) -> None: else: raise TypeError("Expecting bool-value for auto_mine setter.") + @property + def max_gas(self) -> int: + return self.evm_backend.get_block_by_number("latest")["gas_limit"] + def connect(self): if "tester" in self.__dict__: del self.__dict__["tester"] @@ -119,9 +124,7 @@ def estimate_gas_cost( estimate_gas = self.web3.eth.estimate_gas # NOTE: Using JSON mode since used as request data. - txn_dict = txn.model_dump(mode="json") - - txn_dict.pop("gas", None) + txn_dict = txn.model_dump(by_alias=True, mode="json", exclude={"gas_limit", "chain_id"}) txn_data = cast(TxParams, txn_dict) try: @@ -206,6 +209,7 @@ def send_call( data.pop("gasLimit", None) data.pop("maxFeePerGas", None) data.pop("maxPriorityFeePerGas", None) + data.pop("signature", None) tx_params = cast(TxParams, data) vm_err = None @@ -233,8 +237,11 @@ def send_call( def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: vm_err = None + txn_dict = None try: - txn_hash = self.web3.eth.send_raw_transaction(txn.serialize_transaction()).hex() + txn_hash = self.tester.ethereum_tester.send_raw_transaction( + txn.serialize_transaction().hex() + ) except (ValidationError, TransactionFailed, Web3ContractLogicError) as err: vm_err = self.get_virtual_machine_error(err, txn=txn) if txn.raise_on_revert: @@ -248,17 +255,26 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: required_confirmations=required_confirmations, error=vm_err, txn_hash=txn_hash ) else: - receipt = self.get_receipt(txn_hash, required_confirmations=required_confirmations) + txn_dict = txn_dict or txn.model_dump(mode="json") + + # Signature is typically excluded from the model fields, + # so we have to include it manually. + txn_dict["signature"] = txn.signature + + receipt = self.get_receipt( + txn_hash, required_confirmations=required_confirmations, transaction=txn_dict + ) # NOTE: Caching must happen before error enrichment. self.chain_manager.history.append(receipt) if receipt.failed: # NOTE: Using JSON mode since used as request data. - txn_dict = txn.model_dump(mode="json") + txn_dict = txn_dict or txn.model_dump(mode="json") txn_dict["nonce"] += 1 txn_params = cast(TxParams, txn_dict) + txn_dict.pop("signature", None) # Replay txn to get revert reason try: @@ -284,7 +300,7 @@ def snapshot(self) -> SnapshotID: def restore(self, snapshot_id: SnapshotID): if snapshot_id: - current_hash = self.get_block("latest").hash + current_hash = self._get_latest_block_rpc().get("hash") if current_hash != snapshot_id: try: return self.evm_backend.revert_to_snapshot(snapshot_id) @@ -292,9 +308,9 @@ def restore(self, snapshot_id: SnapshotID): raise UnknownSnapshotError(snapshot_id) def set_timestamp(self, new_timestamp: int): - current = self.get_block("pending").timestamp - if new_timestamp == current: - # Is the same, treat as a noop. + current_timestamp = self.evm_backend.get_block_by_number("pending")["timestamp"] + if new_timestamp == current_timestamp: + # no change, return immediately return try: @@ -316,13 +332,24 @@ def set_timestamp(self, new_timestamp: int): def mine(self, num_blocks: int = 1): self.evm_backend.mine_blocks(num_blocks) + def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + # perf: Using evm_backend directly instead of going through web3. + return self.evm_backend.get_balance( + HexBytes(address), block_number="latest" if block_id is None else block_id + ) + + def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + return self.evm_backend.get_nonce( + HexBytes(address), block_number="latest" if block_id is None else block_id + ) + def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: from_block = max(0, log_filter.start_block) if log_filter.stop_block is None: to_block = None else: - latest_block = self.get_block("latest").number + latest_block = self._get_latest_block_rpc().get("number") to_block = ( min(latest_block, log_filter.stop_block) if latest_block is not None @@ -354,6 +381,13 @@ def get_test_account(self, index: int) -> "TestAccountAPI": def add_account(self, private_key: str): self.evm_backend.add_account(private_key) + def _get_last_base_fee(self) -> int: + base_fee = self._get_latest_block_rpc().get("base_fee_per_gas", None) + if base_fee is not None: + return base_fee + + raise APINotImplementedError("No base fee found in block.") + def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError: if isinstance(exception, ValidationError): match = self._CANNOT_AFFORD_GAS_PATTERN.match(str(exception)) @@ -396,3 +430,11 @@ def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMa else: return VirtualMachineError(base_err=exception, **kwargs) + + def _get_latest_block(self) -> BlockAPI: + # perf: By-pass as much as possible since this is a common action. + data = self._get_latest_block_rpc() + return self.network.ecosystem.decode_block(data) + + def _get_latest_block_rpc(self) -> dict: + return self.evm_backend.get_block_by_number("latest") diff --git a/tests/functional/test_contract_call_handler.py b/tests/functional/test_contract_call_handler.py index 14069aaecc..f59b6f03ae 100644 --- a/tests/functional/test_contract_call_handler.py +++ b/tests/functional/test_contract_call_handler.py @@ -11,7 +11,8 @@ def test_struct_input( assert actual == output_from_struct_input_call -def test_call_contract_not_found(mocker, method_abi_with_struct_input): +def test_call_contract_not_found(mocker, method_abi_with_struct_input, networks): + (networks.ethereum.local.__dict__ or {}).pop("explorer", None) contract = mocker.MagicMock() contract.is_contract = False method = method_abi_with_struct_input @@ -21,7 +22,8 @@ def test_call_contract_not_found(mocker, method_abi_with_struct_input): handler() -def test_transact_contract_not_found(mocker, owner, method_abi_with_struct_input): +def test_transact_contract_not_found(mocker, owner, method_abi_with_struct_input, networks): + (networks.ethereum.local.__dict__ or {}).pop("explorer", None) contract = mocker.MagicMock() contract.is_contract = False method = method_abi_with_struct_input diff --git a/tests/functional/test_exceptions.py b/tests/functional/test_exceptions.py index cf5521be69..cd88449666 100644 --- a/tests/functional/test_exceptions.py +++ b/tests/functional/test_exceptions.py @@ -71,7 +71,7 @@ def test_deploy_address_as_address( contract = vyper_contract_container.deploy(629, sender=owner) receipt = contract.creation_metadata.receipt - data = receipt.model_dump(exclude=("transaction",)) + data = receipt.model_dump(exclude={"transaction"}) # Show when receiver is zero_address, it still picks contract address. data["transaction"] = ethereum.create_transaction(receiver=zero_address) diff --git a/tests/functional/test_provider.py b/tests/functional/test_provider.py index 136893c363..b1f4428c57 100644 --- a/tests/functional/test_provider.py +++ b/tests/functional/test_provider.py @@ -299,17 +299,20 @@ def test_no_comma_in_rpc_url(): def test_send_transaction_when_no_error_and_receipt_fails( - mock_transaction, mock_web3, eth_tester_provider, owner, vyper_contract_instance + mocker, mock_web3, mock_transaction, eth_tester_provider, owner, vyper_contract_instance ): start_web3 = eth_tester_provider._web3 eth_tester_provider._web3 = mock_web3 + mock_eth_tester = mocker.MagicMock() + original_tester = eth_tester_provider.tester + eth_tester_provider.__dict__["tester"] = mock_eth_tester try: # NOTE: Value is meaningless. tx_hash = HashBytes32.__eth_pydantic_validate__(123**36) # Sending tx "works" meaning no vm error. - mock_web3.eth.send_raw_transaction.return_value = tx_hash + mock_eth_tester.ethereum_tester.send_raw_transaction.return_value = tx_hash # Getting a receipt "works", but you get a failed one. receipt_data = { @@ -334,6 +337,7 @@ def test_send_transaction_when_no_error_and_receipt_fails( finally: eth_tester_provider._web3 = start_web3 + eth_tester_provider.__dict__["tester"] = original_tester def test_network_choice(eth_tester_provider):