diff --git a/src/commands/start_api.py b/src/commands/start_api.py index cdea782e..21ad75b0 100644 --- a/src/commands/start_api.py +++ b/src/commands/start_api.py @@ -21,7 +21,7 @@ LOG_PLAIN, settings, ) -from src.validators.typings import ValidatorsRegistrationMode +from src.validators.typings import RelayerTypes, ValidatorsRegistrationMode logger = logging.getLogger(__name__) @@ -167,6 +167,17 @@ envvar='LOG_LEVEL', help='The log level.', ) +@click.option( + '--relayer-type', + type=click.Choice( + [RelayerTypes.DEFAULT, RelayerTypes.DVT], + case_sensitive=False, + ), + default=RelayerTypes.DEFAULT, + help='Relayer type.', + prompt='Enter the relayer type', + envvar='RELAYER_TYPE', +) @click.option( '--relayer-endpoint', type=str, @@ -196,6 +207,7 @@ def start_api( hot_wallet_password_file: str | None, max_fee_per_gas_gwei: int, database_dir: str | None, + relayer_type: str, relayer_endpoint: str, ) -> None: vault_config = VaultConfig(vault, Path(data_dir)) @@ -225,6 +237,7 @@ def start_api( database_dir=database_dir, log_level=log_level, log_format=log_format, + relayer_type=relayer_type, relayer_endpoint=relayer_endpoint, validators_registration_mode=validators_registration_mode, ) diff --git a/src/config/settings.py b/src/config/settings.py index b8a87555..cf01bbc0 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -1,6 +1,6 @@ from pathlib import Path -from decouple import Choices, Csv +from decouple import Csv from decouple import config as decouple_config from web3 import Web3 from web3.types import ChecksumAddress @@ -76,6 +76,7 @@ class Settings(metaclass=Singleton): sentry_environment: str pool_size: int | None + relayer_type: str relayer_endpoint: str relayer_timeout: int validators_registration_mode: ValidatorsRegistrationMode @@ -125,6 +126,7 @@ def set( log_level: str | None = None, log_format: str | None = None, pool_size: int | None = None, + relayer_type: str = RelayerTypes.DEFAULT, relayer_endpoint: str | None = None, validators_registration_mode: ValidatorsRegistrationMode = ValidatorsRegistrationMode.AUTO, min_validators_registration: int = DEFAULT_MIN_VALIDATORS_REGISTRATION, @@ -241,6 +243,7 @@ def set( self.consensus_retry_timeout = decouple_config( 'CONSENSUS_RETRY_TIMEOUT', default=120, cast=int ) + self.relayer_type = relayer_type self.relayer_endpoint = relayer_endpoint or '' self.relayer_timeout = decouple_config('RELAYER_TIMEOUT', default=10, cast=int) @@ -300,14 +303,3 @@ def is_genesis_vault(self) -> bool: LOG_JSON = 'json' LOG_FORMATS = [LOG_PLAIN, LOG_JSON] LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S' - -RELAYER_TYPE: str = decouple_config( - 'RELAYER_TYPE', - default=RelayerTypes.DEFAULT, - cast=Choices( - [ - RelayerTypes.DEFAULT, - RelayerTypes.DVT, - ] - ), -) diff --git a/src/validators/relayer.py b/src/validators/relayer.py index 5f097d10..57cff027 100644 --- a/src/validators/relayer.py +++ b/src/validators/relayer.py @@ -9,7 +9,7 @@ from sw_utils.common import urljoin from web3 import Web3 -from src.config.settings import RELAYER_TYPE, settings +from src.config.settings import settings from src.validators.exceptions import MissingDepositDataValidatorsException from src.validators.execution import ( get_validators_from_deposit_data, @@ -178,7 +178,7 @@ async def _get_validators_from_dvt_relayer( def create_relayer_adapter() -> RelayerAdapter: - if RELAYER_TYPE == RelayerTypes.DVT: + if settings.relayer_type == RelayerTypes.DVT: dvt_relayer = DvtRelayerClient() deposit_data = load_deposit_data(settings.vault, settings.deposit_data_file) return RelayerAdapter(dvt_relayer, deposit_data)