From 246b85474e19ba1829743804fe9fd362fff885f6 Mon Sep 17 00:00:00 2001 From: Frode Aarstad Date: Mon, 11 Nov 2024 14:32:23 +0100 Subject: [PATCH] Move everserver config to ServerConfig --- src/everest/bin/everest_script.py | 4 +- src/everest/bin/kill_script.py | 6 +- src/everest/bin/monitor_script.py | 4 +- src/everest/config/everest_config.py | 75 +------------------ src/everest/config/server_config.py | 75 ++++++++++++++++++- src/everest/detached/__init__.py | 30 ++++---- src/everest/detached/jobs/everserver.py | 10 ++- .../entry_points/test_everest_entry.py | 24 +++--- tests/everest/test_detached.py | 21 ++++-- tests/everest/test_everserver.py | 4 +- tests/everest/test_logging.py | 4 +- tests/everest/test_util.py | 4 +- 12 files changed, 135 insertions(+), 126 deletions(-) diff --git a/src/everest/bin/everest_script.py b/src/everest/bin/everest_script.py index f00137e69a5..535d62da809 100755 --- a/src/everest/bin/everest_script.py +++ b/src/everest/bin/everest_script.py @@ -9,7 +9,7 @@ from ert.config import ErtConfig from ert.storage import open_storage -from everest.config import EverestConfig +from everest.config import EverestConfig, ServerConfig from everest.detached import ( ServerStatus, everserver_status, @@ -84,7 +84,7 @@ def run_everest(options): logger = logging.getLogger("everest_main") server_state = everserver_status(options.config) - if server_is_running(*options.config.server_context): + if server_is_running(*ServerConfig.get_server_context(options.config.output_dir)): config_file = options.config.config_file print( "An optimization is currently running.\n" diff --git a/src/everest/bin/kill_script.py b/src/everest/bin/kill_script.py index c9d0c6453e7..1717f108a3b 100755 --- a/src/everest/bin/kill_script.py +++ b/src/everest/bin/kill_script.py @@ -10,7 +10,7 @@ import traceback from functools import partial -from everest.config import EverestConfig +from everest.config import EverestConfig, ServerConfig from everest.detached import server_is_running, stop_server, wait_for_server_to_stop from everest.util import version_info @@ -70,7 +70,9 @@ def _handle_keyboard_interrupt(signal, frame, after=False): def kill_everest(options): - if not server_is_running(*options.config.server_context): + if not server_is_running( + *ServerConfig.get_server_context(options.config.output_dir) + ): print("Server is not running.") return diff --git a/src/everest/bin/monitor_script.py b/src/everest/bin/monitor_script.py index 310f4bd80ca..a3187f24fa5 100755 --- a/src/everest/bin/monitor_script.py +++ b/src/everest/bin/monitor_script.py @@ -6,7 +6,7 @@ import threading from functools import partial -from everest.config import EverestConfig +from everest.config import EverestConfig, ServerConfig from everest.detached import ServerStatus, everserver_status, server_is_running from .utils import ( @@ -63,7 +63,7 @@ def monitor_everest(options): config: EverestConfig = options.config server_state = everserver_status(options.config) - if server_is_running(*config.server_context): + if server_is_running(*ServerConfig.get_server_context(config.output_dir)): run_detached_monitor(config, show_all_jobs=options.show_all_jobs) server_state = everserver_status(config) if server_state["status"] == ServerStatus.failed: diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index eefc600c70a..ad271b5678c 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -1,4 +1,3 @@ -import json import logging import os import shutil @@ -11,7 +10,6 @@ Literal, Optional, Protocol, - Tuple, no_type_check, ) @@ -45,14 +43,9 @@ from ..config_file_loader import yaml_file_to_substituted_config_dict from ..strings import ( - CERTIFICATE_DIR, DEFAULT_OUTPUT_DIR, - DETACHED_NODE_DIR, - HOSTFILE_NAME, OPTIMIZATION_LOG_DIR, OPTIMIZATION_OUTPUT_DIR, - SERVER_STATUS, - SESSION_DIR, STORAGE_DIR, ) from .control_config import ControlConfig @@ -605,7 +598,7 @@ def config_file(self) -> Optional[str]: return None @property - def output_dir(self) -> Optional[str]: + def output_dir(self) -> str: assert self.environment is not None path = self.environment.output_folder @@ -655,67 +648,6 @@ def storage_dir(self): def log_dir(self): return self._get_output_subdirectory(OPTIMIZATION_LOG_DIR) - @property - def detached_node_dir(self): - return self._get_output_subdirectory(DETACHED_NODE_DIR) - - @property - def session_dir(self): - """Return path to the session directory containing information about the - certificates and host information""" - return os.path.join(self.detached_node_dir, SESSION_DIR) - - @property - def certificate_dir(self): - """Return the path to certificate folder""" - return os.path.join(self.session_dir, CERTIFICATE_DIR) - - def get_server_url(self, server_info=None): - """Return the url of the server. - - If server_info are given, the url is generated using that info. Otherwise - server information are retrieved from the hostfile - """ - if server_info is None: - server_info = self.server_info - - url = f"https://{server_info['host']}:{server_info['port']}" - return url - - @property - def hostfile_path(self): - return os.path.join(self.session_dir, HOSTFILE_NAME) - - @property - def server_info(self): - """Load server information from the hostfile""" - host_file_path = self.hostfile_path - try: - with open(host_file_path, "r", encoding="utf-8") as f: - json_string = f.read() - - data = json.loads(json_string) - if set(data.keys()) != {"host", "port", "cert", "auth"}: - raise RuntimeError("Malformed hostfile") - return data - except FileNotFoundError: - # No host file - return {"host": None, "port": None, "cert": None, "auth": None} - - @property - def server_context(self) -> Tuple[str, str, Tuple[str, str]]: - """Returns a tuple with - - url of the server - - path to the .cert file - - password for the certificate file - """ - - return ( - self.get_server_url(self.server_info), - self.server_info[CERTIFICATE_DIR], - ("username", self.server_info["auth"]), - ) - @property def export_path(self): """Returns the export file path. If not file name is provide the default @@ -738,11 +670,6 @@ def export_path(self): default_export_file = f"{os.path.splitext(self.config_file)[0]}.csv" return os.path.join(full_file_path, default_export_file) - @property - def everserver_status_path(self): - """Returns path to the everest server status file""" - return os.path.join(self.session_dir, SERVER_STATUS) - def to_dict(self) -> dict: the_dict = self.model_dump(exclude_none=True) diff --git a/src/everest/config/server_config.py b/src/everest/config/server_config.py index 4ee3a00e34f..3fadbc6fb05 100644 --- a/src/everest/config/server_config.py +++ b/src/everest/config/server_config.py @@ -1,7 +1,16 @@ -from typing import Literal, Optional +import json +import os +from typing import Literal, Optional, Tuple from pydantic import BaseModel, ConfigDict, Field +from ..strings import ( + CERTIFICATE_DIR, + DETACHED_NODE_DIR, + HOSTFILE_NAME, + SERVER_STATUS, + SESSION_DIR, +) from .has_ert_queue_options import HasErtQueueOptions @@ -41,3 +50,67 @@ class ServerConfig(BaseModel, HasErtQueueOptions): # type: ignore model_config = ConfigDict( extra="forbid", ) + + @staticmethod + def get_server_url(output_dir: str) -> str: + """Return the url of the server. + + If server_info are given, the url is generated using that info. Otherwise + server information are retrieved from the hostfile + """ + server_info = ServerConfig.get_server_info(output_dir) + return f"https://{server_info['host']}:{server_info['port']}" + + @staticmethod + def get_server_context(output_dir: str) -> Tuple[str, bool, Tuple[str, str]]: + """Returns a tuple with + - url of the server + - path to the .cert file + - password for the certificate file + """ + server_info = ServerConfig.get_server_info(output_dir) + return ( + ServerConfig.get_server_url(output_dir), + server_info[CERTIFICATE_DIR], + ("username", server_info["auth"]), + ) + + @staticmethod + def get_server_info(output_dir: str) -> dict: + """Load server information from the hostfile""" + host_file_path = ServerConfig.get_hostfile_path(output_dir) + try: + with open(host_file_path, "r", encoding="utf-8") as f: + json_string = f.read() + + data = json.loads(json_string) + if set(data.keys()) != {"host", "port", "cert", "auth"}: + raise RuntimeError("Malformed hostfile") + return data + except FileNotFoundError: + # No host file + return {"host": None, "port": None, "cert": None, "auth": None} + + @staticmethod + def get_detached_node_dir(output_dir: str) -> str: + return os.path.join(os.path.abspath(output_dir), DETACHED_NODE_DIR) + + @staticmethod + def get_hostfile_path(output_dir: str) -> str: + return os.path.join(ServerConfig.get_session_dir(output_dir), HOSTFILE_NAME) + + @staticmethod + def get_session_dir(output_dir: str) -> str: + """Return path to the session directory containing information about the + certificates and host information""" + return os.path.join(ServerConfig.get_detached_node_dir(output_dir), SESSION_DIR) + + @staticmethod + def get_everserver_status_path(output_dir: str) -> str: + """Returns path to the everest server status file""" + return os.path.join(ServerConfig.get_session_dir(output_dir), SERVER_STATUS) + + @staticmethod + def get_certificate_dir(output_dir: str) -> str: + """Return the path to certificate folder""" + return os.path.join(ServerConfig.get_session_dir(output_dir), CERTIFICATE_DIR) diff --git a/src/everest/detached/__init__.py b/src/everest/detached/__init__.py index c882fef63c4..1f64e655eb9 100644 --- a/src/everest/detached/__init__.py +++ b/src/everest/detached/__init__.py @@ -16,7 +16,7 @@ from ert import BatchContext, BatchSimulator, JobState from ert.config import ErtConfig, QueueSystem -from everest.config import EverestConfig +from everest.config import EverestConfig, ServerConfig from everest.config_keys import ConfigKeys as CK from everest.strings import ( EVEREST, @@ -59,7 +59,9 @@ def start_server(config: EverestConfig, ert_config: ErtConfig, storage): """ Start an Everest server running the optimization defined in the config """ - if server_is_running(*config.server_context): # better safe than sorry + if server_is_running( + *ServerConfig.get_server_context(config.output_dir) + ): # better safe than sorry return log_dir = config.log_dir @@ -143,7 +145,7 @@ def stop_server(config: EverestConfig, retries: int = 5): """ for retry in range(retries): try: - url, cert, auth = config.server_context + url, cert, auth = ServerConfig.get_server_context(config.output_dir) stop_endpoint = "/".join([url, STOP_ENDPOINT]) response = requests.post( stop_endpoint, @@ -174,7 +176,7 @@ def wait_for_server( Raise an exception when the timeout is reached. """ - if not server_is_running(*config.server_context): + if not server_is_running(*ServerConfig.get_server_context(config.output_dir)): sleep_time_increment = float(timeout) / (2**_HTTP_REQUEST_RETRY - 1) for retry_count in range(_HTTP_REQUEST_RETRY): # Failure may occur before contact with the server is established: @@ -218,11 +220,11 @@ def wait_for_server( sleep_time = sleep_time_increment * (2**retry_count) time.sleep(sleep_time) - if server_is_running(*config.server_context): + if server_is_running(*ServerConfig.get_server_context(config.output_dir)): return # If number of retries reached and server is not running - throw exception - if not server_is_running(*config.server_context): + if not server_is_running(*ServerConfig.get_server_context(config.output_dir)): raise RuntimeError("Failed to start server within configured timeout.") @@ -264,16 +266,18 @@ def wait_for_server_to_stop(config: EverestConfig, timeout): Raise an exception when the timeout is reached. """ - if server_is_running(*config.server_context): + if server_is_running(*ServerConfig.get_server_context(config.output_dir)): sleep_time_increment = float(timeout) / (2**_HTTP_REQUEST_RETRY - 1) for retry_count in range(_HTTP_REQUEST_RETRY): sleep_time = sleep_time_increment * (2**retry_count) time.sleep(sleep_time) - if not server_is_running(*config.server_context): + if not server_is_running( + *ServerConfig.get_server_context(config.output_dir) + ): return # If number of retries reached and server still running - throw exception - if server_is_running(*config.server_context): + if server_is_running(*ServerConfig.get_server_context(config.output_dir)): raise Exception("Failed to stop server within configured timeout.") @@ -310,7 +314,7 @@ def start_monitor(config: EverestConfig, callback, polling_interval=5): Monitoring stops when the server stops answering. It can also be interrupted by returning True from the callback """ - url, cert, auth = config.server_context + url, cert, auth = ServerConfig.get_server_context(config.output_dir) sim_endpoint = "/".join([url, SIM_PROGRESS_ENDPOINT]) opt_endpoint = "/".join([url, OPT_PROGRESS_ENDPOINT]) @@ -448,7 +452,7 @@ def generate_everserver_ert_config(config: EverestConfig, debug_mode: bool = Fal site_config = ErtConfig.read_site_config() abs_everest_config = os.path.join(config.config_directory, config.config_file) - detached_node_dir = config.detached_node_dir + detached_node_dir = ServerConfig.get_detached_node_dir(config.output_dir) simulation_path = os.path.join(detached_node_dir, SIMULATION_DIR) queue_system = _find_res_queue_system(config) arg_list = ["--config-file", abs_everest_config] @@ -532,7 +536,7 @@ def update_everserver_status( ): """Update the everest server status with new status information""" new_status = {"status": status, "message": message} - path = config.everserver_status_path + path = ServerConfig.get_everserver_status_path(config.output_dir) if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) with open(path, "w", encoding="utf-8") as outfile: @@ -560,7 +564,7 @@ def everserver_status(config: EverestConfig): 'message': None } """ - path = config.everserver_status_path + path = ServerConfig.get_everserver_status_path(config.output_dir) if os.path.exists(path): with open(path, "r", encoding="utf-8") as f: return json.load(f, object_hook=ServerStatusEncoder.decode) diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 1aeb9124a0e..6f7e3145f28 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -23,7 +23,7 @@ from ert.ensemble_evaluator import EvaluatorServerConfig from ert.run_models.everest_run_model import EverestRunModel from everest import export_to_csv, export_with_progress -from everest.config import EverestConfig +from everest.config import EverestConfig, ServerConfig from everest.detached import ServerStatus, get_opt_status, update_everserver_status from everest.export import check_for_errors from everest.simulator import JOB_FAILURE @@ -169,7 +169,9 @@ def _find_open_port(host, lower, upper): def _write_hostfile(config: EverestConfig, host, port, cert, auth): - host_file_path = config.hostfile_path + # host_file_path = config.hostfile_path + host_file_path = ServerConfig.get_hostfile_path(config.output_dir) + if not os.path.exists(os.path.dirname(host_file_path)): os.makedirs(os.path.dirname(host_file_path)) data = { @@ -185,7 +187,7 @@ def _write_hostfile(config: EverestConfig, host, port, cert, auth): def _configure_loggers(config: EverestConfig): - detached_node_dir = config.detached_node_dir + detached_node_dir = ServerConfig.get_detached_node_dir(config.output_dir) everest_logs_dir = config.log_dir configure_logger( @@ -414,7 +416,7 @@ def _generate_certificate(config: EverestConfig): ) # Write certificate and key to disk - cert_folder = config.certificate_dir + cert_folder = ServerConfig.get_certificate_dir(config.output_dir) makedirs_if_needed(cert_folder) cert_path = os.path.join(cert_folder, cert_name + ".crt") with open(cert_path, "wb") as f: diff --git a/tests/everest/entry_points/test_everest_entry.py b/tests/everest/entry_points/test_everest_entry.py index a5693af2788..53a47d63f27 100644 --- a/tests/everest/entry_points/test_everest_entry.py +++ b/tests/everest/entry_points/test_everest_entry.py @@ -1,7 +1,7 @@ import logging import os from functools import partial -from unittest.mock import PropertyMock, patch +from unittest.mock import patch import pytest @@ -9,7 +9,7 @@ from everest.bin.everest_script import everest_entry from everest.bin.kill_script import kill_entry from everest.bin.monitor_script import monitor_entry -from everest.config import EverestConfig +from everest.config import EverestConfig, ServerConfig from everest.detached import ( SIM_PROGRESS_ENDPOINT, ServerStatus, @@ -288,9 +288,8 @@ def test_everest_entry_monitor_no_run( @patch("everest.bin.everest_script.start_server") @patch("everest.detached._query_server", side_effect=query_server_mock) @patch.object( - EverestConfig, - "server_context", - new_callable=PropertyMock, + ServerConfig, + "get_server_context", return_value=("localhost", "", ""), ) @patch("everest.detached.get_opt_status", return_value={}) @@ -323,9 +322,8 @@ def test_everest_entry_show_all_jobs( @patch("everest.bin.everest_script.start_server") @patch("everest.detached._query_server", side_effect=query_server_mock) @patch.object( - EverestConfig, - "server_context", - new_callable=PropertyMock, + ServerConfig, + "get_server_context", return_value=("localhost", "", ""), ) @patch("everest.detached.get_opt_status", return_value={}) @@ -360,9 +358,8 @@ def test_everest_entry_no_show_all_jobs( @patch("everest.bin.monitor_script.server_is_running", return_value=True) @patch("everest.detached._query_server", side_effect=query_server_mock) @patch.object( - EverestConfig, - "server_context", - new_callable=PropertyMock, + ServerConfig, + "get_server_context", return_value=("localhost", "", ""), ) @patch("everest.detached.get_opt_status", return_value={}) @@ -392,9 +389,8 @@ def test_monitor_entry_show_all_jobs( @patch("everest.bin.monitor_script.server_is_running", return_value=True) @patch("everest.detached._query_server", side_effect=query_server_mock) @patch.object( - EverestConfig, - "server_context", - new_callable=PropertyMock, + ServerConfig, + "get_server_context", return_value=("localhost", "", ""), ) @patch("everest.detached.get_opt_status", return_value={}) diff --git a/tests/everest/test_detached.py b/tests/everest/test_detached.py index 5955ce8ac04..cdb4df6d2fb 100644 --- a/tests/everest/test_detached.py +++ b/tests/everest/test_detached.py @@ -10,8 +10,7 @@ from ert import JobState from ert.config import ErtConfig, QueueSystem from ert.storage import open_storage -from everest.config import EverestConfig -from everest.config.server_config import ServerConfig +from everest.config import EverestConfig, ServerConfig from everest.config.simulator_config import SimulatorConfig from everest.config_keys import ConfigKeys as CK from everest.detached import ( @@ -84,7 +83,7 @@ def test_https_requests(copy_math_func_test_data_to_tmp): server_status = everserver_status(everest_config) assert ServerStatus.running == server_status["status"] - url, cert, auth = everest_config.server_context + url, cert, auth = ServerConfig.get_server_context(everest_config.output_dir) result = requests.get(url, verify=cert, auth=auth, proxies=PROXY) assert result.status_code == 200 # Request has succeeded @@ -95,7 +94,7 @@ def test_https_requests(copy_math_func_test_data_to_tmp): response.raise_for_status() # Test request with wrong password fails - url, cert, _ = everest_config.server_context + url, cert, _ = ServerConfig.get_server_context(everest_config.output_dir) usr = "admin" password = "wrong_password" with pytest.raises(Exception): # noqa B017 @@ -103,7 +102,9 @@ def test_https_requests(copy_math_func_test_data_to_tmp): result.raise_for_status() # Test stopping server - assert server_is_running(*everest_config.server_context) + assert server_is_running( + *ServerConfig.get_server_context(everest_config.output_dir) + ) if stop_server(everest_config): wait_for_server_to_stop(everest_config, 60) @@ -115,7 +116,9 @@ def test_https_requests(copy_math_func_test_data_to_tmp): ServerStatus.stopped, ServerStatus.completed, ] - assert not server_is_running(*everest_config.server_context) + assert not server_is_running( + *ServerConfig.get_server_context(everest_config.output_dir) + ) else: context_stop_and_wait() server_status = everserver_status(everest_config) @@ -126,11 +129,13 @@ def test_server_status(copy_math_func_test_data_to_tmp): config = EverestConfig.load_file("config_minimal.yml") # Check status file does not exist before initial status update - assert not os.path.exists(config.everserver_status_path) + assert not os.path.exists( + ServerConfig.get_everserver_status_path(config.output_dir) + ) update_everserver_status(config, ServerStatus.starting) # Check status file exists after initial status update - assert os.path.exists(config.everserver_status_path) + assert os.path.exists(ServerConfig.get_everserver_status_path(config.output_dir)) # Check we can read the server status from disk status = everserver_status(config) diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index c5bd3f50293..db3d81e7d95 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -6,7 +6,7 @@ from ropt.enums import OptimizerExitCode from seba_sqlite.snapshot import SebaSnapshot -from everest.config import EverestConfig +from everest.config import EverestConfig, ServerConfig from everest.detached import ServerStatus, everserver_status from everest.detached.jobs import everserver from everest.simulator import JOB_FAILURE, JOB_SUCCESS @@ -72,7 +72,7 @@ def test_hostfile_storage(copy_math_func_test_data_to_tmp): "auth": "1234", } everserver._write_hostfile(config, **expected_result) - result = config.server_info + result = ServerConfig.get_server_info(config.output_dir) assert result == expected_result diff --git a/tests/everest/test_logging.py b/tests/everest/test_logging.py index 6675bcb6323..764859d43d3 100644 --- a/tests/everest/test_logging.py +++ b/tests/everest/test_logging.py @@ -4,7 +4,7 @@ from ert.config import ErtConfig from ert.storage import open_storage -from everest.config import EverestConfig +from everest.config import EverestConfig, ServerConfig from everest.detached import ( context_stop_and_wait, generate_everserver_ert_config, @@ -47,7 +47,7 @@ def test_logging_setup(copy_math_func_test_data_to_tmp): everest_logs_dir_path = everest_config.log_dir - detached_node_dir = everest_config.detached_node_dir + detached_node_dir = ServerConfig.get_detached_node_dir(everest_config.output_dir) endpoint_log_path = os.path.join(detached_node_dir, "endpoint.log") everest_log_path = os.path.join(everest_logs_dir_path, "everest.log") diff --git a/tests/everest/test_util.py b/tests/everest/test_util.py index 72937e88697..edfbf58e191 100644 --- a/tests/everest/test_util.py +++ b/tests/everest/test_util.py @@ -6,7 +6,7 @@ from everest import util from everest.bin.utils import report_on_previous_run -from everest.config import EverestConfig +from everest.config import EverestConfig, ServerConfig from everest.config.everest_config import get_system_installed_jobs from everest.config_keys import ConfigKeys from everest.detached import ServerStatus @@ -128,7 +128,7 @@ def test_get_everserver_status_path(copy_math_func_test_data_to_tmp): session_path = os.path.join( cwd, "everest_output", "detached_node_output", ".session" ) - path = config.everserver_status_path + path = ServerConfig.get_everserver_status_path(config.output_dir) expected_path = os.path.join(session_path, SERVER_STATUS) assert path == expected_path