Skip to content

Commit

Permalink
feat: add support for API secret (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Nov 26, 2024
1 parent 589a728 commit 626b932
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 10 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ export WEB3_INFURA_PROJECT_ID=MY_API_TOKEN
export WEB3_INFURA_PROJECT_ID=MY_API_TOKEN1, MY_API_TOKEN2
```

Additionally, if your app requires an API secret as well, use either of the following environment variables:

- WEB3_INFURA_PROJECT_ID
- WEB3_INFURA_API_KEY

And each request will use the secret as a form of authentication.

To use the Infura provider plugin in most commands, set it via the `--network` option:

```bash
Expand Down
43 changes: 37 additions & 6 deletions ape_infura/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from ape.api import UpstreamProvider
from ape.exceptions import ContractLogicError, ProviderError, VirtualMachineError
from ape_ethereum.provider import Web3Provider
from requests import Session
from web3 import HTTPProvider, Web3
from web3.exceptions import ContractLogicError as Web3ContractLogicError
from web3.exceptions import ExtraDataLengthError
from web3.gas_strategies.rpc import rpc_gas_price_strategy
from web3.middleware import geth_poa_middleware
from web3.middleware.validation import MAX_EXTRADATA_LENGTH

_ENVIRONMENT_VARIABLE_NAMES = ("WEB3_INFURA_PROJECT_ID", "WEB3_INFURA_API_KEY")
_API_KEY_ENVIRONMENT_VARIABLE_NAMES = ("WEB3_INFURA_PROJECT_ID", "WEB3_INFURA_API_KEY")
_API_SECRET_ENVIRONMENT_VARIABLE_NAMES = ("WEB3_INFURA_PROJECT_SECRET", "WEB3_INFURA_API_SECRET")

# NOTE: https://docs.infura.io/learn/websockets#supported-networks
_WEBSOCKET_CAPABLE_NETWORKS = {
"arbitrum": ("mainnet", "sepolia"),
Expand All @@ -38,10 +41,26 @@ class InfuraProviderError(ProviderError):

class MissingProjectKeyError(InfuraProviderError):
def __init__(self):
env_var_str = ", ".join([f"${n}" for n in _ENVIRONMENT_VARIABLE_NAMES])
env_var_str = ", ".join([f"${n}" for n in _API_KEY_ENVIRONMENT_VARIABLE_NAMES])
super().__init__(f"Must set one of {env_var_str}")


def _get_api_key_secret() -> Optional[str]:
for name in _API_SECRET_ENVIRONMENT_VARIABLE_NAMES:
if secret := os.environ.get(name):
return secret

return None


def _get_session() -> Session:
session = Session()
if api_secret := _get_api_key_secret():
session.auth = ("", api_secret)

return session


class Infura(Web3Provider, UpstreamProvider):
network_uris: dict[tuple[str, str], str] = {}

Expand All @@ -60,7 +79,7 @@ def __get_random_api_key(self) -> str:
@cached_property
def _api_keys(self) -> set[str]:
api_keys = set()
for env_var_name in _ENVIRONMENT_VARIABLE_NAMES:
for env_var_name in _API_KEY_ENVIRONMENT_VARIABLE_NAMES:
if env_var := os.environ.get(env_var_name):
api_keys.update(set(key.strip() for key in env_var.split(",")))

Expand All @@ -78,8 +97,15 @@ def uri(self) -> str:

key = self.__get_random_api_key()

prefix = f"{ecosystem_name}-" if ecosystem_name != "ethereum" else ""
network_uri = f"https://{prefix}{network_name}.infura.io/v3/{key}"
if ecosystem_name == "bsc" and "opbnb" in network_name:
sub_network = network_name.split("-")[-1] if "-" in network_name else "mainnet"
prefix = f"opbnb-{sub_network}"
else:
prefix = f"{ecosystem_name}-" if ecosystem_name != "ethereum" else ""
prefix = f"{prefix}{network_name}"

network_uri = f"https://{prefix}.infura.io/v3/{key}"

self.network_uris[(ecosystem_name, network_name)] = network_uri
return network_uri

Expand All @@ -104,7 +130,9 @@ def connection_str(self) -> str:
return self.uri

def connect(self):
self._web3 = _create_web3(HTTPProvider(self.uri))
session = _get_session()
http_provider = HTTPProvider(self.uri, session=session)
self._web3 = _create_web3(http_provider)

if self._needs_poa_middleware:
self._web3.middleware_onion.inject(geth_poa_middleware, layer=0)
Expand All @@ -130,6 +158,9 @@ def _needs_poa_middleware(self) -> bool:
block = self.web3.eth.get_block(block_id) # type: ignore
except ExtraDataLengthError:
return True
except Exception:
# Some nodes are "light" and may not find earliest blocks.
continue
else:
if (
"proofOfAuthorityData" in block
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from ape_infura import NETWORKS

NETWORK_SKIPS = ("starknet",)


@pytest.fixture
def accounts():
Expand All @@ -20,7 +22,9 @@ def networks():


# NOTE: Using a `str` as param for better pytest test-case name generation.
@pytest.fixture(params=[f"{e}:{n}" for e, values in NETWORKS.items() for n in values])
@pytest.fixture(
params=[f"{e}:{n}" for e, values in NETWORKS.items() if e not in NETWORK_SKIPS for n in values]
)
def provider(networks, request):
ecosystem, network = request.param.split(":")
ecosystem_cls = networks.get_ecosystem(ecosystem)
Expand Down
25 changes: 22 additions & 3 deletions tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,30 @@
from web3.exceptions import ExtraDataLengthError
from web3.middleware import geth_poa_middleware

from ape_infura.provider import _WEBSOCKET_CAPABLE_NETWORKS, Infura
from ape_infura.provider import _WEBSOCKET_CAPABLE_NETWORKS, Infura, _get_session


def test_infura_http(provider):
ecosystem = provider.network.ecosystem.name
network = provider.network.name

if network in ("opbnb-testnet",):
pytest.skip("This network is weird and has missing trie node errors")

assert isinstance(provider, Infura)
assert provider.http_uri.startswith("https")
assert provider.get_balance(ZERO_ADDRESS) > 0
assert provider.get_block(0)
ecosystem_uri = "" if ecosystem == "ethereum" else f"{ecosystem}-"
assert f"https://{ecosystem_uri}{network}.infura.io/v3/" in provider.uri
if "opbnb" in network:
expected = (
"https://opbnb-mainnet.infura.io/v3/"
if network == "opbnb"
else f"https://{network}.infura.io/v3/"
)
else:
expected = f"https://{ecosystem_uri}{network}.infura.io/v3/"

assert expected in provider.uri


def test_infura_ws(provider):
Expand Down Expand Up @@ -107,3 +119,10 @@ def test_dynamic_poa_check(mocker):
patch.return_value = mock_web3
infura.connect()
mock_web3.middleware_onion.inject.assert_called_once_with(geth_poa_middleware, layer=0)


def test_api_secret():
os.environ["WEB3_INFURA_PROJECT_SECRET"] = "123"
session = _get_session()
assert session.auth == ("", "123")
del os.environ["WEB3_INFURA_PROJECT_SECRET"]

0 comments on commit 626b932

Please sign in to comment.