Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move everserver config to ServerConfig #9170

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/everest/bin/everest_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ert.config import ErtConfig
from ert.storage import open_storage
from everest.config import EverestConfig
from everest.config import EverestConfig, ServerConfig
from everest.detached import (
ServerStatus,
everserver_status,
Expand Down Expand Up @@ -84,7 +84,7 @@ def run_everest(options):
logger = logging.getLogger("everest_main")
server_state = everserver_status(options.config)

if server_is_running(*options.config.server_context):
if server_is_running(*ServerConfig.get_server_context(options.config.output_dir)):
config_file = options.config.config_file
print(
"An optimization is currently running.\n"
Expand Down
6 changes: 4 additions & 2 deletions src/everest/bin/kill_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import traceback
from functools import partial

from everest.config import EverestConfig
from everest.config import EverestConfig, ServerConfig
from everest.detached import server_is_running, stop_server, wait_for_server_to_stop
from everest.util import version_info

Expand Down Expand Up @@ -70,7 +70,9 @@ def _handle_keyboard_interrupt(signal, frame, after=False):


def kill_everest(options):
if not server_is_running(*options.config.server_context):
if not server_is_running(
*ServerConfig.get_server_context(options.config.output_dir)
):
print("Server is not running.")
return

Expand Down
4 changes: 2 additions & 2 deletions src/everest/bin/monitor_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import threading
from functools import partial

from everest.config import EverestConfig
from everest.config import EverestConfig, ServerConfig
from everest.detached import ServerStatus, everserver_status, server_is_running

from .utils import (
Expand Down Expand Up @@ -63,7 +63,7 @@ def monitor_everest(options):
config: EverestConfig = options.config
server_state = everserver_status(options.config)

if server_is_running(*config.server_context):
if server_is_running(*ServerConfig.get_server_context(config.output_dir)):
run_detached_monitor(config, show_all_jobs=options.show_all_jobs)
server_state = everserver_status(config)
if server_state["status"] == ServerStatus.failed:
Expand Down
75 changes: 1 addition & 74 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
import os
import shutil
Expand All @@ -11,7 +10,6 @@
Literal,
Optional,
Protocol,
Tuple,
no_type_check,
)

Expand Down Expand Up @@ -45,14 +43,9 @@

from ..config_file_loader import yaml_file_to_substituted_config_dict
from ..strings import (
CERTIFICATE_DIR,
DEFAULT_OUTPUT_DIR,
DETACHED_NODE_DIR,
HOSTFILE_NAME,
OPTIMIZATION_LOG_DIR,
OPTIMIZATION_OUTPUT_DIR,
SERVER_STATUS,
SESSION_DIR,
STORAGE_DIR,
)
from .control_config import ControlConfig
Expand Down Expand Up @@ -605,7 +598,7 @@ def config_file(self) -> Optional[str]:
return None

@property
def output_dir(self) -> Optional[str]:
def output_dir(self) -> str:
assert self.environment is not None
path = self.environment.output_folder

Expand Down Expand Up @@ -655,67 +648,6 @@ def storage_dir(self):
def log_dir(self):
return self._get_output_subdirectory(OPTIMIZATION_LOG_DIR)

@property
def detached_node_dir(self):
return self._get_output_subdirectory(DETACHED_NODE_DIR)

@property
def session_dir(self):
"""Return path to the session directory containing information about the
certificates and host information"""
return os.path.join(self.detached_node_dir, SESSION_DIR)

@property
def certificate_dir(self):
"""Return the path to certificate folder"""
return os.path.join(self.session_dir, CERTIFICATE_DIR)

def get_server_url(self, server_info=None):
"""Return the url of the server.

If server_info are given, the url is generated using that info. Otherwise
server information are retrieved from the hostfile
"""
if server_info is None:
server_info = self.server_info

url = f"https://{server_info['host']}:{server_info['port']}"
return url

@property
def hostfile_path(self):
return os.path.join(self.session_dir, HOSTFILE_NAME)

@property
def server_info(self):
"""Load server information from the hostfile"""
host_file_path = self.hostfile_path
try:
with open(host_file_path, "r", encoding="utf-8") as f:
json_string = f.read()

data = json.loads(json_string)
if set(data.keys()) != {"host", "port", "cert", "auth"}:
raise RuntimeError("Malformed hostfile")
return data
except FileNotFoundError:
# No host file
return {"host": None, "port": None, "cert": None, "auth": None}

@property
def server_context(self) -> Tuple[str, str, Tuple[str, str]]:
"""Returns a tuple with
- url of the server
- path to the .cert file
- password for the certificate file
"""

return (
self.get_server_url(self.server_info),
self.server_info[CERTIFICATE_DIR],
("username", self.server_info["auth"]),
)

@property
def export_path(self):
"""Returns the export file path. If not file name is provide the default
Expand All @@ -738,11 +670,6 @@ def export_path(self):
default_export_file = f"{os.path.splitext(self.config_file)[0]}.csv"
return os.path.join(full_file_path, default_export_file)

@property
def everserver_status_path(self):
"""Returns path to the everest server status file"""
return os.path.join(self.session_dir, SERVER_STATUS)

def to_dict(self) -> dict:
the_dict = self.model_dump(exclude_none=True)

Expand Down
75 changes: 74 additions & 1 deletion src/everest/config/server_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from typing import Literal, Optional
import json
import os
from typing import Literal, Optional, Tuple

from pydantic import BaseModel, ConfigDict, Field

from ..strings import (
CERTIFICATE_DIR,
DETACHED_NODE_DIR,
HOSTFILE_NAME,
SERVER_STATUS,
SESSION_DIR,
)
from .has_ert_queue_options import HasErtQueueOptions


Expand Down Expand Up @@ -41,3 +50,67 @@ class ServerConfig(BaseModel, HasErtQueueOptions): # type: ignore
model_config = ConfigDict(
extra="forbid",
)

@staticmethod
def get_server_url(output_dir: str) -> str:
"""Return the url of the server.

If server_info are given, the url is generated using that info. Otherwise
server information are retrieved from the hostfile
"""
server_info = ServerConfig.get_server_info(output_dir)
return f"https://{server_info['host']}:{server_info['port']}"

@staticmethod
def get_server_context(output_dir: str) -> Tuple[str, bool, Tuple[str, str]]:
"""Returns a tuple with
- url of the server
- path to the .cert file
- password for the certificate file
"""
server_info = ServerConfig.get_server_info(output_dir)
return (
ServerConfig.get_server_url(output_dir),
server_info[CERTIFICATE_DIR],
("username", server_info["auth"]),
)

@staticmethod
def get_server_info(output_dir: str) -> dict:
"""Load server information from the hostfile"""
host_file_path = ServerConfig.get_hostfile_path(output_dir)
try:
with open(host_file_path, "r", encoding="utf-8") as f:
json_string = f.read()

data = json.loads(json_string)
if set(data.keys()) != {"host", "port", "cert", "auth"}:
raise RuntimeError("Malformed hostfile")
return data
except FileNotFoundError:
# No host file
return {"host": None, "port": None, "cert": None, "auth": None}

@staticmethod
def get_detached_node_dir(output_dir: str) -> str:
return os.path.join(os.path.abspath(output_dir), DETACHED_NODE_DIR)

@staticmethod
def get_hostfile_path(output_dir: str) -> str:
return os.path.join(ServerConfig.get_session_dir(output_dir), HOSTFILE_NAME)

@staticmethod
def get_session_dir(output_dir: str) -> str:
"""Return path to the session directory containing information about the
certificates and host information"""
return os.path.join(ServerConfig.get_detached_node_dir(output_dir), SESSION_DIR)

@staticmethod
def get_everserver_status_path(output_dir: str) -> str:
"""Returns path to the everest server status file"""
return os.path.join(ServerConfig.get_session_dir(output_dir), SERVER_STATUS)

@staticmethod
def get_certificate_dir(output_dir: str) -> str:
"""Return the path to certificate folder"""
return os.path.join(ServerConfig.get_session_dir(output_dir), CERTIFICATE_DIR)
30 changes: 17 additions & 13 deletions src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ert import BatchContext, BatchSimulator, JobState
from ert.config import ErtConfig, QueueSystem
from everest.config import EverestConfig
from everest.config import EverestConfig, ServerConfig
from everest.config_keys import ConfigKeys as CK
from everest.strings import (
EVEREST,
Expand Down Expand Up @@ -59,7 +59,9 @@ def start_server(config: EverestConfig, ert_config: ErtConfig, storage):
"""
Start an Everest server running the optimization defined in the config
"""
if server_is_running(*config.server_context): # better safe than sorry
if server_is_running(
*ServerConfig.get_server_context(config.output_dir)
): # better safe than sorry
return

log_dir = config.log_dir
Expand Down Expand Up @@ -143,7 +145,7 @@ def stop_server(config: EverestConfig, retries: int = 5):
"""
for retry in range(retries):
try:
url, cert, auth = config.server_context
url, cert, auth = ServerConfig.get_server_context(config.output_dir)
stop_endpoint = "/".join([url, STOP_ENDPOINT])
response = requests.post(
stop_endpoint,
Expand Down Expand Up @@ -174,7 +176,7 @@ def wait_for_server(

Raise an exception when the timeout is reached.
"""
if not server_is_running(*config.server_context):
if not server_is_running(*ServerConfig.get_server_context(config.output_dir)):
sleep_time_increment = float(timeout) / (2**_HTTP_REQUEST_RETRY - 1)
for retry_count in range(_HTTP_REQUEST_RETRY):
# Failure may occur before contact with the server is established:
Expand Down Expand Up @@ -218,11 +220,11 @@ def wait_for_server(

sleep_time = sleep_time_increment * (2**retry_count)
time.sleep(sleep_time)
if server_is_running(*config.server_context):
if server_is_running(*ServerConfig.get_server_context(config.output_dir)):
return

# If number of retries reached and server is not running - throw exception
if not server_is_running(*config.server_context):
if not server_is_running(*ServerConfig.get_server_context(config.output_dir)):
raise RuntimeError("Failed to start server within configured timeout.")


Expand Down Expand Up @@ -264,16 +266,18 @@ def wait_for_server_to_stop(config: EverestConfig, timeout):

Raise an exception when the timeout is reached.
"""
if server_is_running(*config.server_context):
if server_is_running(*ServerConfig.get_server_context(config.output_dir)):
sleep_time_increment = float(timeout) / (2**_HTTP_REQUEST_RETRY - 1)
for retry_count in range(_HTTP_REQUEST_RETRY):
sleep_time = sleep_time_increment * (2**retry_count)
time.sleep(sleep_time)
if not server_is_running(*config.server_context):
if not server_is_running(
*ServerConfig.get_server_context(config.output_dir)
):
return

# If number of retries reached and server still running - throw exception
if server_is_running(*config.server_context):
if server_is_running(*ServerConfig.get_server_context(config.output_dir)):
raise Exception("Failed to stop server within configured timeout.")


Expand Down Expand Up @@ -310,7 +314,7 @@ def start_monitor(config: EverestConfig, callback, polling_interval=5):
Monitoring stops when the server stops answering. It can also be
interrupted by returning True from the callback
"""
url, cert, auth = config.server_context
url, cert, auth = ServerConfig.get_server_context(config.output_dir)
sim_endpoint = "/".join([url, SIM_PROGRESS_ENDPOINT])
opt_endpoint = "/".join([url, OPT_PROGRESS_ENDPOINT])

Expand Down Expand Up @@ -448,7 +452,7 @@ def generate_everserver_ert_config(config: EverestConfig, debug_mode: bool = Fal

site_config = ErtConfig.read_site_config()
abs_everest_config = os.path.join(config.config_directory, config.config_file)
detached_node_dir = config.detached_node_dir
detached_node_dir = ServerConfig.get_detached_node_dir(config.output_dir)
simulation_path = os.path.join(detached_node_dir, SIMULATION_DIR)
queue_system = _find_res_queue_system(config)
arg_list = ["--config-file", abs_everest_config]
Expand Down Expand Up @@ -532,7 +536,7 @@ def update_everserver_status(
):
"""Update the everest server status with new status information"""
new_status = {"status": status, "message": message}
path = config.everserver_status_path
path = ServerConfig.get_everserver_status_path(config.output_dir)
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
with open(path, "w", encoding="utf-8") as outfile:
Expand Down Expand Up @@ -560,7 +564,7 @@ def everserver_status(config: EverestConfig):
'message': None
}
"""
path = config.everserver_status_path
path = ServerConfig.get_everserver_status_path(config.output_dir)
if os.path.exists(path):
with open(path, "r", encoding="utf-8") as f:
return json.load(f, object_hook=ServerStatusEncoder.decode)
Expand Down
Loading
Loading