Skip to content

Commit

Permalink
Bugfix of ssh passphrase (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
tornede authored Mar 8, 2024
1 parent 5b9ef5c commit d0f0a92
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/source/usage/database_credential_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ The following example shows how to connect to a database server using an SSH ser
server: example.mysqlserver.com (address from ssh server)
address: ssh_hostname (either name/ip address of the ssh server or a name from you local ssh config file)
port: optional_ssh_port (default: 22)
passphrase: passphrase
ssh_private_key_password: passphrase
remote_address: optional_mysql_server_address (default: 127.0.0.1)
remote_port: optional_mysql_server_port (default: 3306)
local_address: optional_local_address (default: 127.0.0.1)
local_port: optional_local_port (default: 3306)
.. note::
Note that we do not support further parameters for the SSH connection, such as explicitly setting the private key file. To use these, you have to adapt your local ssh config file.
Note that we do not support further parameters for the SSH connection, such as explicitly setting the private key file. To use these, you have to adapt your local ssh config file.
6 changes: 3 additions & 3 deletions py_experimenter/database_connector_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_ssh_tunnel(self, logger: Logger):
parameters = dict(credentials["Ssh"])
ssh_address_or_host = parameters["address"]
ssh_address_or_host_port = parameters["port"] if "port" in parameters else 22
ssh_keypass = parameters["ssh_keypass"] if "ssh_keypass" in parameters else None
ssh_private_key_password = parameters["ssh_private_key_password"] if "ssh_private_key_password" in parameters else None
remote_bind_address = parameters["remote_address"] if "remote_address" in parameters else "127.0.0.1"
remote_bind_address_port = parameters["remote_port"] if "remote_port" in parameters else 3306
local_bind_address = parameters["local_address"] if "local_address" in parameters else "127.0.0.1"
Expand All @@ -40,7 +40,7 @@ def get_ssh_tunnel(self, logger: Logger):
try:
tunnel = sshtunnel.SSHTunnelForwarder(
ssh_address_or_host=(ssh_address_or_host, ssh_address_or_host_port),
ssh_pkey=ssh_keypass,
ssh_private_key_password=ssh_private_key_password,
remote_bind_address=(remote_bind_address, remote_bind_address_port),
local_bind_address=(local_bind_address, local_bind_address_port),
logger=logger,
Expand Down Expand Up @@ -179,4 +179,4 @@ def _get_column_names_from_entries(entries):

self.execute(cursor, f"SHOW COLUMNS FROM {self.database_configuration.table_name}")
column_names = _get_column_names_from_entries(self.fetchall(cursor))
return column_names
return column_names
4 changes: 4 additions & 0 deletions py_experimenter/experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,13 @@ def unpause_experiment(self, experiment_id: int, experiment_function: Callable)
:param experiment_function: _description_ The experiment function to use to continue the given experiment
:type experiment_function: Callable
"""
self._write_codecarbon_config()

keyfield_dict, _ = self.db_connector.pull_paused_experiment(experiment_id)
self._execute_experiment(experiment_id, keyfield_dict, experiment_function)

self._delete_codecarbon_config()

def _worker(self, experiment_function: Callable[[Dict, Dict, ResultProcessor], None], random_order: bool) -> None:
"""
Worker that repeatedly pulls open experiments from the database table and executes them.
Expand Down

0 comments on commit d0f0a92

Please sign in to comment.