From 613d4f5205b5db8063ea4281b3013814e4e7f063 Mon Sep 17 00:00:00 2001 From: Lukas Fehring Date: Mon, 15 Apr 2024 16:22:11 +0200 Subject: [PATCH] Update ssh handling --- py_experimenter/config.py | 17 ++++++++++++----- py_experimenter/database_connector_mysql.py | 17 +++++++---------- py_experimenter/experimenter.py | 14 +++++++------- test/test_database_connector.py | 7 +++---- 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/py_experimenter/config.py b/py_experimenter/config.py index 1daac4fa..700a28d5 100644 --- a/py_experimenter/config.py +++ b/py_experimenter/config.py @@ -9,11 +9,7 @@ from omegaconf import DictConfig, ListConfig, OmegaConf from py_experimenter import utils -from py_experimenter.exceptions import ( - InvalidColumnError, - InvalidConfigError, - InvalidLogtableError, -) +from py_experimenter.exceptions import InvalidColumnError, InvalidConfigError, InvalidLogtableError class Cfg(ABC): @@ -45,6 +41,7 @@ class DatabaseCfg(Cfg): def __init__( self, provider: str, + use_ssh_tunnel: bool, database_name: str, table_name: str, result_timestamps: bool, @@ -58,6 +55,8 @@ def __init__( :param provider: Database Provider; either `sqlite` or `mysql` :type provider: str + :param use_ssh_tunnel: Whether to use an SSH tunnel to connect to the database + :type use_ssh_tunnel: bool :param database_name: Name of the database :type database_name: str :param table_name: Name of the table @@ -71,6 +70,7 @@ def __init__( :type logtables: Dict[str, Dict[str,str]] """ self.provider = provider + self.use_ssh_tunnel = use_ssh_tunnel self.database_name = database_name self.table_name = table_name self.result_timestamps = result_timestamps @@ -85,6 +85,8 @@ def extract_config(config: OmegaConf, logger: logging.Logger) -> Tuple["Database database_config = config["PY_EXPERIMENTER"]["Database"] table_config = database_config["table"] provider = database_config["provider"] + # Optional use_ssh_tunnel + use_ssh_tunnel = database_config["use_ssh"] if "use_ssh" in database_config else False database_name = database_config["database"] table_name = database_config["table"]["name"] @@ -97,6 +99,7 @@ def extract_config(config: OmegaConf, logger: logging.Logger) -> Tuple["Database return DatabaseCfg( provider, + use_ssh_tunnel, database_name, table_name, result_timestamps, @@ -208,6 +211,9 @@ def valid(self) -> bool: if self.provider not in ["sqlite", "mysql"]: self.logger.error("Database provider must be either sqlite or mysql") return False + if self.use_ssh_tunnel not in [True, False]: + self.logger.error("Use SSH tunnel must be a boolean.") + return False if not isinstance(self.database_name, str): self.logger.error("Database name must be a string") return False @@ -372,6 +378,7 @@ def extract_config(config_path: str, logger: logging.Logger) -> "PyExperimenterC def valid(self) -> bool: if not isinstance(self.n_jobs, int) and self.n_jobs > 0: self.logger.error("n_jobs must be a positive integer") + return False if not (self.database_configuration.valid() and self.custom_configuration.valid() and self.codecarbon_configuration.valid()): self.logger.error("Database configuration invalid") return False diff --git a/py_experimenter/database_connector_mysql.py b/py_experimenter/database_connector_mysql.py index 834d814b..43e891e9 100644 --- a/py_experimenter/database_connector_mysql.py +++ b/py_experimenter/database_connector_mysql.py @@ -7,21 +7,17 @@ from omegaconf import OmegaConf from pymysql import Error, connect +from py_experimenter.config import DatabaseCfg from py_experimenter.database_connector import DatabaseConnector -from py_experimenter.exceptions import ( - DatabaseConnectionError, - DatabaseCreationError, - SshTunnelError, -) +from py_experimenter.exceptions import DatabaseConnectionError, DatabaseCreationError, SshTunnelError class DatabaseConnectorMYSQL(DatabaseConnector): _prepared_statement_placeholder = "%s" - def __init__(self, database_configuration: OmegaConf, use_codecarbon: bool, credential_path: str, use_ssh_tunnel: bool, logger: Logger): + def __init__(self, database_configuration: DatabaseCfg, use_codecarbon: bool, credential_path: str, logger: Logger): self.credential_path = credential_path - self.use_ssh_tunnel = use_ssh_tunnel - if self.use_ssh_tunnel: + if database_configuration.use_ssh_tunnel: self.start_ssh_tunnel(logger) super().__init__(database_configuration, use_codecarbon, logger) @@ -55,7 +51,6 @@ def get_ssh_tunnel(self, logger: Logger): except DatabaseConnectionError as err: logger.error(err) raise SshTunnelError("Error when creating SSH tunnel! Check the credentials file.") - def start_ssh_tunnel(self, logger: Logger): tunnel = self.get_ssh_tunnel(logger) @@ -64,6 +59,8 @@ def start_ssh_tunnel(self, logger: Logger): tunnel.start() def close_ssh_tunnel(self): + if not self.database_configuration.use_ssh_tunnel: + self.logger.warning("Attempt to close SSH tunnel, but ssh tunnel is not used.") tunnel = self.get_ssh_tunnel(self.logger) if tunnel is not None: tunnel.stop(force=True) @@ -108,7 +105,7 @@ def _get_database_credentials(self): try: credential_config = OmegaConf.load(self.credential_path) database_configuration = credential_config["CREDENTIALS"]["Database"] - if self.use_ssh_tunnel: + if self.database_configuration.use_ssh_tunnel: server_address = credential_config["CREDENTIALS"]["Connection"]["Ssh"]["server"] else: server_address = credential_config["CREDENTIALS"]["Connection"]["Standard"]["server"] diff --git a/py_experimenter/experimenter.py b/py_experimenter/experimenter.py index 914cc272..1da4f922 100644 --- a/py_experimenter/experimenter.py +++ b/py_experimenter/experimenter.py @@ -12,10 +12,7 @@ from py_experimenter.config import PyExperimenterCfg from py_experimenter.database_connector_lite import DatabaseConnectorLITE from py_experimenter.database_connector_mysql import DatabaseConnectorMYSQL -from py_experimenter.exceptions import ( - InvalidConfigError, - NoExperimentsLeftException, -) +from py_experimenter.exceptions import InvalidConfigError, NoExperimentsLeftException from py_experimenter.experiment_status import ExperimentStatus from py_experimenter.result_processor import ResultProcessor @@ -107,7 +104,10 @@ def __init__( raise InvalidConfigError("Invalid configuration") self.database_credential_file_path = database_credential_file_path - self.use_ssh_tunnel = use_ssh_tunnel + + # If use_ssh_tunnel is not None, the decision is based on the given kwarg + if use_ssh_tunnel is not None: + self.config.database_configuration.use_ssh_tunnel = use_ssh_tunnel if table_name is not None: self.config.database_configuration.table_name = table_name @@ -121,7 +121,7 @@ def __init__( self.db_connector = DatabaseConnectorLITE(self.config.database_configuration, self.use_codecarbon, self.logger) elif self.config.database_configuration.provider == "mysql": self.db_connector = DatabaseConnectorMYSQL( - self.config.database_configuration, self.use_codecarbon, database_credential_file_path, use_ssh_tunnel, self.logger + self.config.database_configuration, self.use_codecarbon, database_credential_file_path, self.logger ) else: raise ValueError("The provider indicated in the config file is not supported") @@ -132,7 +132,7 @@ def close_ssh(self) -> None: """ Closes the ssh tunnel if it is used. """ - if self.config.database_configuration.provider == "mysql" and self.use_ssh_tunnel: + if self.config.database_configuration.provider == "mysql": self.db_connector.close_ssh_tunnel() else: self.logger.warning("No ssh tunnel to close") diff --git a/test/test_database_connector.py b/test/test_database_connector.py index 84068f44..74677f80 100644 --- a/test/test_database_connector.py +++ b/test/test_database_connector.py @@ -133,7 +133,6 @@ def test_fill_table( experiment_configuration, False, CREDENTIAL_PATH, - False, logger=logger, ) @@ -176,7 +175,7 @@ def test_delete_experiments_with_condition( config = OmegaConf.load(CONFIG_PATH) experiment_configuration = DatabaseCfg.extract_config(config, logger) - database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, use_ssh_tunnel=True, logger=logger) + database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, logger=logger) database_connector._delete_experiments_with_condition(f'WHERE status = "{ExperimentStatus.CREATED.value}"') @@ -222,7 +221,7 @@ def test_get_experiments_with_condition( config = OmegaConf.load(CONFIG_PATH) experiment_configuration = DatabaseCfg.extract_config(config, logger) - database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, False, logger=logger) + database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, logger=logger) database_connector._get_experiments_with_condition(f'WHERE status = "{ExperimentStatus.CREATED.value}"') @@ -251,7 +250,7 @@ def test_delete_table( config = OmegaConf.load(CONFIG_PATH) experiment_configuration = DatabaseCfg.extract_config(config, logger) - database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, False, logger=logger) + database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, logger=logger) database_connector.delete_table()