Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Nov 12, 2024
1 parent 9b0640d commit 8c9c835
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 297 deletions.
2 changes: 1 addition & 1 deletion src/everest/bin/everest_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def run_everest(options):

makedirs_if_needed(options.config.output_dir, roll_if_exists=True)
server_config = generate_everserver_config(options.config)
await start_server(options.config, server_config)
await start_server(options.config, server_config, options.debug)
print("Waiting for server ...")
wait_for_server(options.config, timeout=600)
print("Everest server found!")
Expand Down
28 changes: 8 additions & 20 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Literal,
Optional,
Protocol,
Self,
Tuple,
no_type_check,
)
Expand Down Expand Up @@ -107,14 +106,12 @@ class HasName(Protocol):

class EverestConfig(BaseModelWithPropertySupport): # type: ignore
controls: Annotated[List[ControlConfig], AfterValidator(unique_items)] = Field(
default_factory=list,
description="""Defines a list of controls.
Controls should have unique names each control defines
a group of control variables
""",
)
objective_functions: List[ObjectiveFunctionConfig] = Field(
default_factory=list,
description="List of objective function specifications",
)
optimization: Optional[OptimizationConfig] = Field(
Expand Down Expand Up @@ -220,24 +217,9 @@ class EverestConfig(BaseModelWithPropertySupport): # type: ignore
default=None,
description="Settings to control the exports of a optimization run by everest.",
)
config_path: Path = Field(default=Path.cwd())
config_path: Path = Field()
model_config = ConfigDict(extra="forbid")

@model_validator(mode="after")
def check_defaults(self) -> Self:
if self.server is None:
self.server = ServerConfig(queue_system=self.simulator.queue_system)
if self.simulator.queue_system != self.server.queue_system:
# Check something
pass
return self

#
# @model_validator(mode="after")
# def validate_queue_systems(self): # pylint: disable=E0213
# # if self.server.queue_system
# pass

@model_validator(mode="after")
def validate_install_job_sources(self): # pylint: disable=E0213
model = self.model
Expand Down Expand Up @@ -775,7 +757,13 @@ def with_defaults(cls, **kwargs):
Creates an Everest config with default values. Useful for initializing a config
without having to provide empty defaults.
"""
return EverestConfig.model_validate({**kwargs})
defaults = {
"controls": [],
"objective_functions": [],
"config_path": ".",
}

return EverestConfig.model_validate({**defaults, **kwargs})

@staticmethod
def lint_config_dict(config: dict) -> List["ErrorDetails"]:
Expand Down
4 changes: 2 additions & 2 deletions src/everest/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ class ServerConfig(BaseModel, HasErtQueueOptions): # type: ignore
""",
) # Corresponds to queue name
exclude_host: Optional[str] = Field(
None,
"",
description="""Comma separated list of nodes that should be
excluded from the slurm run""",
)
include_host: Optional[str] = Field(
None,
"",
description="""Comma separated list of nodes that
should be included in the slurm run""",
)
Expand Down
4 changes: 2 additions & 2 deletions src/everest/config/simulator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ class SimulatorConfig(BaseModel, HasErtQueueOptions, extra="forbid"): # type: i
"needs to be deleted.",
)
exclude_host: Optional[str] = Field(
None,
"",
description="""Comma separated list of nodes that should be
excluded from the slurm run.""",
)
include_host: Optional[str] = Field(
None,
"",
description="""Comma separated list of nodes that
should be included in the slurm run""",
)
Expand Down
111 changes: 35 additions & 76 deletions src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from seba_sqlite.snapshot import SebaSnapshot

from ert.config import QueueSystem
from ert.config.queue_config import QueueOptions
from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
QueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.scheduler import create_driver
from ert.scheduler.driver import FailedSubmit
from ert.scheduler.event import StartedEvent
Expand Down Expand Up @@ -56,7 +62,9 @@
_context = None


async def start_server(config: EverestConfig, queue_options: QueueOptions) -> None:
async def start_server(
config: EverestConfig, queue_options: QueueOptions, debug: bool = False
) -> None:
"""
Start an Everest server running the optimization defined in the config
"""
Expand Down Expand Up @@ -88,7 +96,10 @@ async def start_server(config: EverestConfig, queue_options: QueueOptions) -> No

driver = create_driver(queue_options)
try:
await driver.submit(0, "everserver", "--config-file", config.config_file)
args = ["--config-file", config.config_file]
if debug:
args.append("--debug")
await driver.submit(0, "everserver", *args)
except FailedSubmit as err:
raise ValueError(f"Failed to submit Everserver with error: {err}") from err
status = await driver.event_queue.get()
Expand Down Expand Up @@ -290,68 +301,6 @@ def start_monitor(config: EverestConfig, callback, polling_interval=5):
}


def _add_simulator_defaults(
options,
config: EverestConfig,
queue_options: List[Tuple[str, str]],
queue_system: Literal["LSF", "SLURM"],
):
simulator_options = (
config.simulator.extract_ert_queue_options(
queue_system=queue_system, everest_to_ert_key_tuples=queue_options
)
if config.simulator is not None
else []
)

option_names = [option[1] for option in options]
simulator_option_names = [option[1] for option in simulator_options]
options.extend(
simulator_options[simulator_option_names.index(res_key)]
for _, res_key in queue_options
if res_key not in option_names and res_key in simulator_option_names
)
return options


def _generate_queue_options(
config: EverestConfig,
queue_options: List[Tuple[str, str]],
res_queue_name: str, # Literal["LSF_QUEUE", "PARTITION"]?
queue_system: Literal["LSF", "SLURM", "TORQUE"],
):
queue_name_simulator = (
config.simulator.name if config.simulator is not None else None
)

queue_name = config.server.name if config.server is not None else None

if queue_name is None:
queue_name = queue_name_simulator

options = (
config.server.extract_ert_queue_options(
queue_system=queue_system, everest_to_ert_key_tuples=queue_options
)
if config.server is not None
else [(queue_system, "MAX_RUNNING", 1)]
)

if queue_name:
options.append(
(
queue_system,
res_queue_name,
queue_name,
),
)
# Inherit the include/exclude_host from the simulator config entry, if necessary.
# Currently this is only used by the slurm driver.
if queue_system == "SLURM":
options = _add_simulator_defaults(options, config, queue_options, queue_system)
return options


def _find_res_queue_system(config: EverestConfig):
queue_system_simulator: Literal["lsf", "local", "slurm", "torque"] = "local"
if config.simulator is not None:
Expand All @@ -372,20 +321,30 @@ def _find_res_queue_system(config: EverestConfig):
return QueueSystem(queue_system.upper())


def generate_everserver_config(config: EverestConfig, debug_mode: bool = False):
def generate_everserver_config(config: EverestConfig):
queue_system = _find_res_queue_system(config)

queue_options = None
if queue_system in _QUEUE_SYSTEMS:
queue_options = _generate_queue_options(
config,
_QUEUE_SYSTEMS[queue_system]["options"],
_QUEUE_SYSTEMS[queue_system]["name"],
queue_system,
ever_queue_config = config.server if config.server is not None else config.simulator

if queue_system == QueueSystem.LSF:
queue = LsfQueueOptions(
lsf_queue=ever_queue_config.name,
lsf_resource=ever_queue_config.options,
)
queue_options = {val[1]: val[2] for val in queue_options}
queue_options = {} if not queue_options else queue_options
return QueueOptions.create_queue_options(queue_system, queue_options, True)
elif queue_system == QueueSystem.SLURM:
queue = SlurmQueueOptions(
exclude_host=ever_queue_config.exclude_host,
include_host=ever_queue_config.include_host,
partition=ever_queue_config.name,
)
elif queue_system == QueueSystem.TORQUE:
queue = TorqueQueueOptions()
elif queue_system == QueueSystem.LOCAL:
queue = LocalQueueOptions()
else:
raise ValueError(f"Unknown queue system: {queue_system}")
queue.max_running = 1
return queue


def _query_server(cert, auth, endpoint):
Expand Down
Loading

0 comments on commit 8c9c835

Please sign in to comment.