diff --git a/py_experimenter/database_connector.py b/py_experimenter/database_connector.py index 83ccf6ba..c9df6d82 100644 --- a/py_experimenter/database_connector.py +++ b/py_experimenter/database_connector.py @@ -2,7 +2,7 @@ import logging from functools import reduce from operator import concat -from typing import Dict, List, Optional, Tuple, Union, Any +from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd @@ -13,8 +13,8 @@ DatabaseConnectionError, EmptyFillDatabaseCallError, NoExperimentsLeftException, + NoPausedExperimentsException, TableHasWrongStructureError, - NoPausedExperimentsException ) from py_experimenter.experiment_status import ExperimentStatus @@ -214,7 +214,7 @@ def _select_open_experiments_from_db(self, connection, cursor, random_order: boo time = utils.get_timestamp_representation() - self.execute(cursor, f"SELECT id FROM {self.database_configuration.table_name} WHERE status = 'created' ORDER BY {order_by} LIMIT 1;") + self.execute(cursor, self._get_pull_experiment_query(order_by)) experiment_id = self.fetchall(cursor)[0][0] self.execute( cursor, diff --git a/test/test_run_experiments/test_run_mysql_experiment.py b/test/test_run_experiments/test_run_mysql_experiment.py index 5cb4f517..8cb472b4 100644 --- a/test/test_run_experiments/test_run_mysql_experiment.py +++ b/test/test_run_experiments/test_run_mysql_experiment.py @@ -102,6 +102,26 @@ def test_mysql_shh(): experimenter.db_connector.close_connection(connection) experimenter.close_ssh() +def test_no_experiment_double_execution(): + experiment_configuration_file_path = os.path.join("test", "test_run_experiments", "test_run_mysql_experiment_config.yml") + logging.basicConfig(level=logging.DEBUG) + experimenter = PyExperimenter(experiment_configuration_file_path=experiment_configuration_file_path, use_codecarbon=False, use_ssh_tunnel=False) + try: + experimenter.delete_table() + except ProgrammingError as e: + logging.warning(e) + experimenter.fill_table_from_config() + + # At most 30 experiments should be executed. If the experiment is executed twice, there should be less then 30 entries + experimenter.execute(own_function, max_experiments=30, n_jobs=5) + + connection = experimenter.db_connector.connect() + cursor = experimenter.db_connector.cursor(connection) + cursor.execute(f"SELECT * FROM {experimenter.db_connector.database_configuration.table_name} WHERE status = 'done'") + entries = cursor.fetchall() + + # If the experiment is executed twice, there should be less then 30 entries + assert len(entries) == 30 def error_function(keyfields: dict, result_processor: ResultProcessor, custom_fields: dict): raise Exception("Error with weird symbos '@#$%&/\()=") @@ -141,7 +161,6 @@ def test_run_error_experiment(): ]: assert message in entries[0][11] - def own_function_raising_errors(keyfields: dict, result_processor: ResultProcessor, custom_fields: dict): error_code = keyfields["error_code"] @@ -206,4 +225,4 @@ def test_boolean_in_table(): assert table["value"].dtype == int assert (table["value"] == [1, 0]).all() assert (table["given_bool"] == [1, 0]).all() - assert (table["status"] == ["done", "done"]).all() + assert (table["status"] == ["done", "done"]).all() \ No newline at end of file