From 820a230f228ba32861cb5dae5fbd1e2c971e9ca3 Mon Sep 17 00:00:00 2001 From: Aviksaikat Date: Thu, 8 Aug 2024 13:15:28 +0530 Subject: [PATCH] chore: suggesed changes made. created a new method reconnect to connect to new URI --- ape_infura/provider.py | 40 +++++++++++++++++++++------------------- setup.py | 2 +- tests/test_provider.py | 12 ++---------- 3 files changed, 24 insertions(+), 30 deletions(-) diff --git a/ape_infura/provider.py b/ape_infura/provider.py index b502b7f..b185e5a 100644 --- a/ape_infura/provider.py +++ b/ape_infura/provider.py @@ -35,23 +35,26 @@ def __init__(self): class Infura(Web3Provider, UpstreamProvider): network_uris: dict[tuple[str, str], str] = {} - api_keys: list[str] = [] + api_keys: set[str] = set() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.load_api_keys() def load_api_keys(self): - self.api_keys = [] + self.api_keys = set() for env_var_name in _ENVIRONMENT_VARIABLE_NAMES: if env_var := os.environ.get(env_var_name): - self.api_keys.extend([key.strip() for key in env_var.split(",")]) + self.api_keys.update(set(key.strip() for key in env_var.split(","))) if not self.api_keys: raise MissingProjectKeyError() - def get_random_api_key(self) -> str: - return random.choice(self.api_keys) + def __get_random_api_key(self) -> str: + """ + Get a random api key a private method. As self.api_keys are unhashable so have to typecast into list to make it hashable + """ + return random.choice(list(self.api_keys)) @property def uri(self) -> str: @@ -60,20 +63,7 @@ def uri(self) -> str: if (ecosystem_name, network_name) in self.network_uris: return self.network_uris[(ecosystem_name, network_name)] - key = self.get_random_api_key() - - prefix = f"{ecosystem_name}-" if ecosystem_name != "ethereum" else "" - network_uri = f"https://{prefix}{network_name}.infura.io/v3/{key}" - self.network_uris[(ecosystem_name, network_name)] = network_uri - return network_uri - - def get_new_uri(self) -> str: - """ - To generate a new URI with a new API key. Added to keep backwards compatibity - """ - key = self.get_random_api_key() - ecosystem_name = self.network.ecosystem.name - network_name = self.network.name + key = self.__get_random_api_key() prefix = f"{ecosystem_name}-" if ecosystem_name != "ethereum" else "" network_uri = f"https://{prefix}{network_name}.infura.io/v3/{key}" @@ -115,6 +105,18 @@ def connect(self): def disconnect(self): self._web3 = None + def reconnect(self): + """ + Disconnect the connectned API. + Refresh the API keys from environment variable. + Make the self.network_uris empty otherwise the old network_uri will be returned. + Connect again. + """ + self.disconnect() + self.load_api_keys() + self.network_uris = {} + self.connect() + def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError: txn = kwargs.get("txn") if not hasattr(exception, "args") or not len(exception.args): diff --git a/setup.py b/setup.py index 4627048..ad8c31d 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ "test": [ # `test` GitHub Action jobs uses this "pytest>=6.0", # Core testing package "pytest-xdist", # Multi-process runner - "pytest-mock", + "pytest-mock", # Mocking framework "pytest-cov", # Coverage analyzer plugin "hypothesis>=6.2.0,<7.0", # Strategy-based fuzzer "ape-arbitrum", # For integration testing diff --git a/tests/test_provider.py b/tests/test_provider.py index 6d67ffd..f7d8034 100644 --- a/tests/test_provider.py +++ b/tests/test_provider.py @@ -61,21 +61,13 @@ def test_load_single_and_multiple_api_keys(provider, mocker): assert "single_key2" in provider.api_keys -def test_random_api_key_selection(provider, mocker): - mocker.patch.dict(os.environ, {"WEB3_INFURA_PROJECT_ID": "key1,key2,key3,key4,key5"}) - provider.load_api_keys() - selected_keys = set() - for _ in range(50): # Run multiple times to ensure randomness - selected_keys.add(provider.get_random_api_key()) - assert len(selected_keys) > 1 # Ensure we're getting different keys - - def test_uri_with_random_api_key(provider, mocker): # mocker.patch.dict(os.environ, {"WEB3_INFURA_PROJECT_ID": "key1, key2, key3, key4, key5, key6"}) provider.load_api_keys() uris = set() for _ in range(100): # Generate multiple URIs - uri = provider.get_new_uri() # Use get_new_uri method to get a URI + provider.reconnect() # connect to a new URI + uri = provider.uri uris.add(uri) assert uri.startswith("https") assert "/v3" in uri