From bc6281116c74019ed1f018c7dfce07ead6426452 Mon Sep 17 00:00:00 2001 From: Juliya Smith Date: Mon, 19 Aug 2024 12:36:57 -0500 Subject: [PATCH] fix: issues with api keys --- ape_infura/provider.py | 29 +++++++++++++++++------------ tests/test_provider.py | 39 ++++++++++++++++++++++++++++++--------- 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/ape_infura/provider.py b/ape_infura/provider.py index 37f681e..ca5515c 100644 --- a/ape_infura/provider.py +++ b/ape_infura/provider.py @@ -1,5 +1,6 @@ import os import random +from functools import cached_property from typing import Optional from ape.api import UpstreamProvider @@ -35,26 +36,30 @@ def __init__(self): class Infura(Web3Provider, UpstreamProvider): network_uris: dict[tuple[str, str], str] = {} - api_keys: set[str] = set() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.load_api_keys() - def load_api_keys(self): - self.api_keys = set() + def __get_random_api_key(self) -> str: + """ + Get a random api key a private method. + """ + if keys := self._api_keys: + return random.choice(list(keys)) + + raise MissingProjectKeyError() + + @cached_property + def _api_keys(self) -> set[str]: + api_keys = set() for env_var_name in _ENVIRONMENT_VARIABLE_NAMES: if env_var := os.environ.get(env_var_name): - self.api_keys.update(set(key.strip() for key in env_var.split(","))) + api_keys.update(set(key.strip() for key in env_var.split(","))) - if not self.api_keys: + if not api_keys: raise MissingProjectKeyError() - def __get_random_api_key(self) -> str: - """ - Get a random api key a private method. - """ - return random.choice(list(self.api_keys)) + return api_keys @property def uri(self) -> str: @@ -109,7 +114,7 @@ def disconnect(self): Make the self.network_uris empty otherwise the old network_uri will be returned. """ self._web3 = None - self.load_api_keys() + (self.__dict__ or {}).pop("_api_keys", None) self.network_uris = {} def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError: diff --git a/tests/test_provider.py b/tests/test_provider.py index cd9f2aa..bddbe28 100644 --- a/tests/test_provider.py +++ b/tests/test_provider.py @@ -36,18 +36,26 @@ def test_infura_ws(provider): def test_load_multiple_api_keys(provider, mocker): + original_env = os.environ.copy() mocker.patch.dict( os.environ, {"WEB3_INFURA_PROJECT_ID": "key1,key2,key3", "WEB3_INFURA_API_KEY": "key4,key5,key6"}, ) - provider.load_api_keys() # As there will be API keys in the ENV as well - assert len(provider.api_keys) == 6 - assert "key1" in provider.api_keys - assert "key6" in provider.api_keys + provider.disconnect() + assert len(provider._api_keys) == 6 + assert "key1" in provider._api_keys + assert "key6" in provider._api_keys + + os.environ.clear() + os.environ.update(original_env) + + # Disconnect so key isn't cached. + provider.disconnect() def test_load_single_and_multiple_api_keys(provider, mocker): + original_env = os.environ.copy() mocker.patch.dict( os.environ, { @@ -55,15 +63,22 @@ def test_load_single_and_multiple_api_keys(provider, mocker): "WEB3_INFURA_API_KEY": "single_key2", }, ) - provider.load_api_keys() - assert len(provider.api_keys) == 2 - assert "single_key1" in provider.api_keys - assert "single_key2" in provider.api_keys + provider.disconnect() + assert len(provider._api_keys) == 2 + assert "single_key1" in provider._api_keys + assert "single_key2" in provider._api_keys + + os.environ.clear() + os.environ.update(original_env) + + # Disconnect so key isn't cached. + provider.disconnect() def test_uri_with_random_api_key(provider, mocker): + original_env = os.environ.copy() mocker.patch.dict(os.environ, {"WEB3_INFURA_PROJECT_ID": "key1, key2, key3, key4, key5, key6"}) - provider.load_api_keys() + uris = set() for _ in range(100): # Generate multiple URIs provider.disconnect() # connect to a new URI @@ -72,3 +87,9 @@ def test_uri_with_random_api_key(provider, mocker): assert uri.startswith("https") assert "/v3" in uri assert len(uris) > 1 # Ensure we're getting different URIs with different + + os.environ.clear() + os.environ.update(original_env) + + # Disconnect so key isn't cached. + provider.disconnect()