Skip to content

Commit

Permalink
chore: suggesed changes made. created a new method reconnect to conne…
Browse files Browse the repository at this point in the history
…ct to new URI
  • Loading branch information
Aviksaikat committed Aug 8, 2024
1 parent be08aca commit 820a230
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 30 deletions.
40 changes: 21 additions & 19 deletions ape_infura/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,26 @@ def __init__(self):

class Infura(Web3Provider, UpstreamProvider):
network_uris: dict[tuple[str, str], str] = {}
api_keys: list[str] = []
api_keys: set[str] = set()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.load_api_keys()

def load_api_keys(self):
self.api_keys = []
self.api_keys = set()
for env_var_name in _ENVIRONMENT_VARIABLE_NAMES:
if env_var := os.environ.get(env_var_name):
self.api_keys.extend([key.strip() for key in env_var.split(",")])
self.api_keys.update(set(key.strip() for key in env_var.split(",")))

if not self.api_keys:
raise MissingProjectKeyError()

def get_random_api_key(self) -> str:
return random.choice(self.api_keys)
def __get_random_api_key(self) -> str:
"""
Get a random api key a private method. As self.api_keys are unhashable so have to typecast into list to make it hashable
"""
return random.choice(list(self.api_keys))

@property
def uri(self) -> str:
Expand All @@ -60,20 +63,7 @@ def uri(self) -> str:
if (ecosystem_name, network_name) in self.network_uris:
return self.network_uris[(ecosystem_name, network_name)]

key = self.get_random_api_key()

prefix = f"{ecosystem_name}-" if ecosystem_name != "ethereum" else ""
network_uri = f"https://{prefix}{network_name}.infura.io/v3/{key}"
self.network_uris[(ecosystem_name, network_name)] = network_uri
return network_uri

def get_new_uri(self) -> str:
"""
To generate a new URI with a new API key. Added to keep backwards compatibity
"""
key = self.get_random_api_key()
ecosystem_name = self.network.ecosystem.name
network_name = self.network.name
key = self.__get_random_api_key()

prefix = f"{ecosystem_name}-" if ecosystem_name != "ethereum" else ""
network_uri = f"https://{prefix}{network_name}.infura.io/v3/{key}"
Expand Down Expand Up @@ -115,6 +105,18 @@ def connect(self):
def disconnect(self):
self._web3 = None

def reconnect(self):
"""
Disconnect the connectned API.
Refresh the API keys from environment variable.
Make the self.network_uris empty otherwise the old network_uri will be returned.
Connect again.
"""
self.disconnect()
self.load_api_keys()
self.network_uris = {}
self.connect()

def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError:
txn = kwargs.get("txn")
if not hasattr(exception, "args") or not len(exception.args):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"test": [ # `test` GitHub Action jobs uses this
"pytest>=6.0", # Core testing package
"pytest-xdist", # Multi-process runner
"pytest-mock",
"pytest-mock", # Mocking framework
"pytest-cov", # Coverage analyzer plugin
"hypothesis>=6.2.0,<7.0", # Strategy-based fuzzer
"ape-arbitrum", # For integration testing
Expand Down
12 changes: 2 additions & 10 deletions tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,13 @@ def test_load_single_and_multiple_api_keys(provider, mocker):
assert "single_key2" in provider.api_keys


def test_random_api_key_selection(provider, mocker):
mocker.patch.dict(os.environ, {"WEB3_INFURA_PROJECT_ID": "key1,key2,key3,key4,key5"})
provider.load_api_keys()
selected_keys = set()
for _ in range(50): # Run multiple times to ensure randomness
selected_keys.add(provider.get_random_api_key())
assert len(selected_keys) > 1 # Ensure we're getting different keys


def test_uri_with_random_api_key(provider, mocker):
# mocker.patch.dict(os.environ, {"WEB3_INFURA_PROJECT_ID": "key1, key2, key3, key4, key5, key6"})
provider.load_api_keys()
uris = set()
for _ in range(100): # Generate multiple URIs
uri = provider.get_new_uri() # Use get_new_uri method to get a URI
provider.reconnect() # connect to a new URI
uri = provider.uri
uris.add(uri)
assert uri.startswith("https")
assert "/v3" in uri
Expand Down

0 comments on commit 820a230

Please sign in to comment.