Skip to content

Commit

Permalink
Update ssh handling
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasFehring authored and tornede committed Apr 17, 2024
1 parent bfdbdd7 commit 613d4f5
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 26 deletions.
17 changes: 12 additions & 5 deletions py_experimenter/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
from omegaconf import DictConfig, ListConfig, OmegaConf

from py_experimenter import utils
from py_experimenter.exceptions import (
InvalidColumnError,
InvalidConfigError,
InvalidLogtableError,
)
from py_experimenter.exceptions import InvalidColumnError, InvalidConfigError, InvalidLogtableError


class Cfg(ABC):
Expand Down Expand Up @@ -45,6 +41,7 @@ class DatabaseCfg(Cfg):
def __init__(
self,
provider: str,
use_ssh_tunnel: bool,
database_name: str,
table_name: str,
result_timestamps: bool,
Expand All @@ -58,6 +55,8 @@ def __init__(
:param provider: Database Provider; either `sqlite` or `mysql`
:type provider: str
:param use_ssh_tunnel: Whether to use an SSH tunnel to connect to the database
:type use_ssh_tunnel: bool
:param database_name: Name of the database
:type database_name: str
:param table_name: Name of the table
Expand All @@ -71,6 +70,7 @@ def __init__(
:type logtables: Dict[str, Dict[str,str]]
"""
self.provider = provider
self.use_ssh_tunnel = use_ssh_tunnel
self.database_name = database_name
self.table_name = table_name
self.result_timestamps = result_timestamps
Expand All @@ -85,6 +85,8 @@ def extract_config(config: OmegaConf, logger: logging.Logger) -> Tuple["Database
database_config = config["PY_EXPERIMENTER"]["Database"]
table_config = database_config["table"]
provider = database_config["provider"]
# Optional use_ssh_tunnel
use_ssh_tunnel = database_config["use_ssh"] if "use_ssh" in database_config else False
database_name = database_config["database"]
table_name = database_config["table"]["name"]

Expand All @@ -97,6 +99,7 @@ def extract_config(config: OmegaConf, logger: logging.Logger) -> Tuple["Database

return DatabaseCfg(
provider,
use_ssh_tunnel,
database_name,
table_name,
result_timestamps,
Expand Down Expand Up @@ -208,6 +211,9 @@ def valid(self) -> bool:
if self.provider not in ["sqlite", "mysql"]:
self.logger.error("Database provider must be either sqlite or mysql")
return False
if self.use_ssh_tunnel not in [True, False]:
self.logger.error("Use SSH tunnel must be a boolean.")
return False
if not isinstance(self.database_name, str):
self.logger.error("Database name must be a string")
return False
Expand Down Expand Up @@ -372,6 +378,7 @@ def extract_config(config_path: str, logger: logging.Logger) -> "PyExperimenterC
def valid(self) -> bool:
if not isinstance(self.n_jobs, int) and self.n_jobs > 0:
self.logger.error("n_jobs must be a positive integer")
return False
if not (self.database_configuration.valid() and self.custom_configuration.valid() and self.codecarbon_configuration.valid()):
self.logger.error("Database configuration invalid")
return False
Expand Down
17 changes: 7 additions & 10 deletions py_experimenter/database_connector_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@
from omegaconf import OmegaConf
from pymysql import Error, connect

from py_experimenter.config import DatabaseCfg
from py_experimenter.database_connector import DatabaseConnector
from py_experimenter.exceptions import (
DatabaseConnectionError,
DatabaseCreationError,
SshTunnelError,
)
from py_experimenter.exceptions import DatabaseConnectionError, DatabaseCreationError, SshTunnelError


class DatabaseConnectorMYSQL(DatabaseConnector):
_prepared_statement_placeholder = "%s"

def __init__(self, database_configuration: OmegaConf, use_codecarbon: bool, credential_path: str, use_ssh_tunnel: bool, logger: Logger):
def __init__(self, database_configuration: DatabaseCfg, use_codecarbon: bool, credential_path: str, logger: Logger):
self.credential_path = credential_path
self.use_ssh_tunnel = use_ssh_tunnel
if self.use_ssh_tunnel:
if database_configuration.use_ssh_tunnel:
self.start_ssh_tunnel(logger)
super().__init__(database_configuration, use_codecarbon, logger)

Expand Down Expand Up @@ -55,7 +51,6 @@ def get_ssh_tunnel(self, logger: Logger):
except DatabaseConnectionError as err:
logger.error(err)
raise SshTunnelError("Error when creating SSH tunnel! Check the credentials file.")


def start_ssh_tunnel(self, logger: Logger):
tunnel = self.get_ssh_tunnel(logger)
Expand All @@ -64,6 +59,8 @@ def start_ssh_tunnel(self, logger: Logger):
tunnel.start()

def close_ssh_tunnel(self):
if not self.database_configuration.use_ssh_tunnel:
self.logger.warning("Attempt to close SSH tunnel, but ssh tunnel is not used.")
tunnel = self.get_ssh_tunnel(self.logger)
if tunnel is not None:
tunnel.stop(force=True)
Expand Down Expand Up @@ -108,7 +105,7 @@ def _get_database_credentials(self):
try:
credential_config = OmegaConf.load(self.credential_path)
database_configuration = credential_config["CREDENTIALS"]["Database"]
if self.use_ssh_tunnel:
if self.database_configuration.use_ssh_tunnel:
server_address = credential_config["CREDENTIALS"]["Connection"]["Ssh"]["server"]
else:
server_address = credential_config["CREDENTIALS"]["Connection"]["Standard"]["server"]
Expand Down
14 changes: 7 additions & 7 deletions py_experimenter/experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from py_experimenter.config import PyExperimenterCfg
from py_experimenter.database_connector_lite import DatabaseConnectorLITE
from py_experimenter.database_connector_mysql import DatabaseConnectorMYSQL
from py_experimenter.exceptions import (
InvalidConfigError,
NoExperimentsLeftException,
)
from py_experimenter.exceptions import InvalidConfigError, NoExperimentsLeftException
from py_experimenter.experiment_status import ExperimentStatus
from py_experimenter.result_processor import ResultProcessor

Expand Down Expand Up @@ -107,7 +104,10 @@ def __init__(
raise InvalidConfigError("Invalid configuration")

self.database_credential_file_path = database_credential_file_path
self.use_ssh_tunnel = use_ssh_tunnel

# If use_ssh_tunnel is not None, the decision is based on the given kwarg
if use_ssh_tunnel is not None:
self.config.database_configuration.use_ssh_tunnel = use_ssh_tunnel

if table_name is not None:
self.config.database_configuration.table_name = table_name
Expand All @@ -121,7 +121,7 @@ def __init__(
self.db_connector = DatabaseConnectorLITE(self.config.database_configuration, self.use_codecarbon, self.logger)
elif self.config.database_configuration.provider == "mysql":
self.db_connector = DatabaseConnectorMYSQL(
self.config.database_configuration, self.use_codecarbon, database_credential_file_path, use_ssh_tunnel, self.logger
self.config.database_configuration, self.use_codecarbon, database_credential_file_path, self.logger
)
else:
raise ValueError("The provider indicated in the config file is not supported")
Expand All @@ -132,7 +132,7 @@ def close_ssh(self) -> None:
"""
Closes the ssh tunnel if it is used.
"""
if self.config.database_configuration.provider == "mysql" and self.use_ssh_tunnel:
if self.config.database_configuration.provider == "mysql":
self.db_connector.close_ssh_tunnel()
else:
self.logger.warning("No ssh tunnel to close")
Expand Down
7 changes: 3 additions & 4 deletions test/test_database_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def test_fill_table(
experiment_configuration,
False,
CREDENTIAL_PATH,
False,
logger=logger,
)

Expand Down Expand Up @@ -176,7 +175,7 @@ def test_delete_experiments_with_condition(

config = OmegaConf.load(CONFIG_PATH)
experiment_configuration = DatabaseCfg.extract_config(config, logger)
database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, use_ssh_tunnel=True, logger=logger)
database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, logger=logger)

database_connector._delete_experiments_with_condition(f'WHERE status = "{ExperimentStatus.CREATED.value}"')

Expand Down Expand Up @@ -222,7 +221,7 @@ def test_get_experiments_with_condition(

config = OmegaConf.load(CONFIG_PATH)
experiment_configuration = DatabaseCfg.extract_config(config, logger)
database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, False, logger=logger)
database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, logger=logger)

database_connector._get_experiments_with_condition(f'WHERE status = "{ExperimentStatus.CREATED.value}"')

Expand Down Expand Up @@ -251,7 +250,7 @@ def test_delete_table(

config = OmegaConf.load(CONFIG_PATH)
experiment_configuration = DatabaseCfg.extract_config(config, logger)
database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, False, logger=logger)
database_connector = DatabaseConnectorMYSQL(experiment_configuration, False, CREDENTIAL_PATH, logger=logger)

database_connector.delete_table()

Expand Down

0 comments on commit 613d4f5

Please sign in to comment.