Skip to content

Commit

Permalink
fix: issues with api keys
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Aug 19, 2024
1 parent 86e575f commit bc62811
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 21 deletions.
29 changes: 17 additions & 12 deletions ape_infura/provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import random
from functools import cached_property
from typing import Optional

from ape.api import UpstreamProvider
Expand Down Expand Up @@ -35,26 +36,30 @@ def __init__(self):

class Infura(Web3Provider, UpstreamProvider):
network_uris: dict[tuple[str, str], str] = {}
api_keys: set[str] = set()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.load_api_keys()

def load_api_keys(self):
self.api_keys = set()
def __get_random_api_key(self) -> str:
"""
Get a random api key a private method.
"""
if keys := self._api_keys:
return random.choice(list(keys))

raise MissingProjectKeyError()

@cached_property
def _api_keys(self) -> set[str]:
api_keys = set()
for env_var_name in _ENVIRONMENT_VARIABLE_NAMES:
if env_var := os.environ.get(env_var_name):
self.api_keys.update(set(key.strip() for key in env_var.split(",")))
api_keys.update(set(key.strip() for key in env_var.split(",")))

if not self.api_keys:
if not api_keys:
raise MissingProjectKeyError()

def __get_random_api_key(self) -> str:
"""
Get a random api key a private method.
"""
return random.choice(list(self.api_keys))
return api_keys

@property
def uri(self) -> str:
Expand Down Expand Up @@ -109,7 +114,7 @@ def disconnect(self):
Make the self.network_uris empty otherwise the old network_uri will be returned.
"""
self._web3 = None
self.load_api_keys()
(self.__dict__ or {}).pop("_api_keys", None)
self.network_uris = {}

def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError:
Expand Down
39 changes: 30 additions & 9 deletions tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,49 @@ def test_infura_ws(provider):


def test_load_multiple_api_keys(provider, mocker):
original_env = os.environ.copy()
mocker.patch.dict(
os.environ,
{"WEB3_INFURA_PROJECT_ID": "key1,key2,key3", "WEB3_INFURA_API_KEY": "key4,key5,key6"},
)
provider.load_api_keys()
# As there will be API keys in the ENV as well
assert len(provider.api_keys) == 6
assert "key1" in provider.api_keys
assert "key6" in provider.api_keys
provider.disconnect()
assert len(provider._api_keys) == 6
assert "key1" in provider._api_keys
assert "key6" in provider._api_keys

os.environ.clear()
os.environ.update(original_env)

# Disconnect so key isn't cached.
provider.disconnect()


def test_load_single_and_multiple_api_keys(provider, mocker):
original_env = os.environ.copy()
mocker.patch.dict(
os.environ,
{
"WEB3_INFURA_PROJECT_ID": "single_key1",
"WEB3_INFURA_API_KEY": "single_key2",
},
)
provider.load_api_keys()
assert len(provider.api_keys) == 2
assert "single_key1" in provider.api_keys
assert "single_key2" in provider.api_keys
provider.disconnect()
assert len(provider._api_keys) == 2
assert "single_key1" in provider._api_keys
assert "single_key2" in provider._api_keys

os.environ.clear()
os.environ.update(original_env)

# Disconnect so key isn't cached.
provider.disconnect()


def test_uri_with_random_api_key(provider, mocker):
original_env = os.environ.copy()
mocker.patch.dict(os.environ, {"WEB3_INFURA_PROJECT_ID": "key1, key2, key3, key4, key5, key6"})
provider.load_api_keys()

uris = set()
for _ in range(100): # Generate multiple URIs
provider.disconnect() # connect to a new URI
Expand All @@ -72,3 +87,9 @@ def test_uri_with_random_api_key(provider, mocker):
assert uri.startswith("https")
assert "/v3" in uri
assert len(uris) > 1 # Ensure we're getting different URIs with different

os.environ.clear()
os.environ.update(original_env)

# Disconnect so key isn't cached.
provider.disconnect()

0 comments on commit bc62811

Please sign in to comment.