From 165fa9c4a95ed4541ee8556c4dfdc7da5fa414f3 Mon Sep 17 00:00:00 2001
From: Jonathan Karlsen
Date: Wed, 13 Nov 2024 07:57:25 +0100
Subject: [PATCH] Refactor forwardmodelrunner to be async

This should help the forward model runner shut down more gracefully and
remove some of the errors we are seeing in the logs.
---
 src/_ert/forward_model_runner/cli.py          | 50 ++++++++----
 src/_ert/forward_model_runner/client.py       |  5 +-
 src/_ert/forward_model_runner/job_dispatch.py |  8 +-
 .../forward_model_runner/reporting/base.py    |  6 +-
 .../forward_model_runner/reporting/event.py   | 60 +++++++-------
 .../forward_model_runner/reporting/file.py    |  5 +-
 .../reporting/interactive.py                  |  9 ++-
 .../reporting/statemachine.py                 | 10 ++-
 src/_ert/forward_model_runner/runner.py       | 81 ++++++++++---------
 src/ert/ensemble_evaluator/_ensemble.py       |  2 +-
 10 files changed, 136 insertions(+), 100 deletions(-)

diff --git a/src/_ert/forward_model_runner/cli.py b/src/_ert/forward_model_runner/cli.py
index a41b0ca4b16..5ea70f4a594 100644
--- a/src/_ert/forward_model_runner/cli.py
+++ b/src/_ert/forward_model_runner/cli.py
@@ -1,4 +1,5 @@
 import argparse
+import asyncio
 import json
 import logging
 import os
@@ -90,7 +91,7 @@ def _read_jobs_file(retry=True):
         raise e
 
 
-def main(args):
+async def main(args):
     parser = argparse.ArgumentParser(
         description=(
             "Run all the jobs specified in jobs.json, "
@@ -137,19 +138,38 @@ def main(args):
     )
 
     job_runner = ForwardModelRunner(jobs_data)
+    job_task = asyncio.create_task(_main(job_runner, parsed_args, reporters))
 
-    for job_status in job_runner.run(parsed_args.job):
-        logger.info(f"Job status: {job_status}")
+    def handle_sigterm(*args, **kwargs):
+        nonlocal reporters, job_task
+        job_task.cancel()
         for reporter in reporters:
-            try:
-                reporter.report(job_status)
-            except OSError as oserror:
-                print(
-                    f"job_dispatch failed due to {oserror}. Stopping and cleaning up."
-                )
-                pgid = os.getpgid(os.getpid())
-                os.killpg(pgid, signal.SIGKILL)
-
-        if isinstance(job_status, Finish) and not job_status.success():
-            pgid = os.getpgid(os.getpid())
-            os.killpg(pgid, signal.SIGKILL)
+            reporter.cancel()
+
+    asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, handle_sigterm)
+    await job_task
+
+
+async def _main(
+    job_runner: ForwardModelRunner,
+    parsed_args,
+    reporters: typing.Sequence[reporting.Reporter],
+):
+    try:
+        async for job_status in job_runner.run(parsed_args.job):
+            logger.info(f"Job status: {job_status}")
+
+            for reporter in reporters:
+                try:
+                    await reporter.report(job_status)
+                    await asyncio.sleep(0)
+                except OSError as oserror:
+                    print(
+                        f"job_dispatch failed due to {oserror}. Stopping and cleaning up."
+ ) + return + + if isinstance(job_status, Finish) and not job_status.success(): + return + except asyncio.CancelledError: + pass diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 2566ca005f8..455ac2624cd 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -96,7 +96,7 @@ async def get_websocket(self) -> WebSocketClientProtocol: close_timeout=self.CONNECTION_TIMEOUT, ) - async def _send(self, msg: AnyStr) -> None: + async def send(self, msg: AnyStr) -> None: for retry in range(self._max_retries + 1): try: if self.websocket is None: @@ -133,6 +133,3 @@ async def _send(self, msg: AnyStr) -> None: raise ClientConnectionError(_error_msg) from exception await asyncio.sleep(0.2 + self._timeout_multiplier * retry) self.websocket = None - - def send(self, msg: AnyStr) -> None: - self.loop.run_until_complete(self._send(msg)) diff --git a/src/_ert/forward_model_runner/job_dispatch.py b/src/_ert/forward_model_runner/job_dispatch.py index ccd1e5044c2..889c8b57507 100644 --- a/src/_ert/forward_model_runner/job_dispatch.py +++ b/src/_ert/forward_model_runner/job_dispatch.py @@ -1,3 +1,4 @@ +import asyncio import os import signal import sys @@ -13,12 +14,7 @@ def sigterm_handler(_signo, _stack_frame): def main(): os.nice(19) signal.signal(signal.SIGTERM, sigterm_handler) - try: - job_runner_main(sys.argv) - except Exception as e: - pgid = os.getpgid(os.getpid()) - os.killpg(pgid, signal.SIGTERM) - raise e + asyncio.run(job_runner_main(sys.argv)) if __name__ == "__main__": diff --git a/src/_ert/forward_model_runner/reporting/base.py b/src/_ert/forward_model_runner/reporting/base.py index 5b7dd1e3dc8..65e0e54d825 100644 --- a/src/_ert/forward_model_runner/reporting/base.py +++ b/src/_ert/forward_model_runner/reporting/base.py @@ -5,5 +5,9 @@ class Reporter(ABC): @abstractmethod - def report(self, msg: Message): + async def report(self, msg: Message): """Report a message.""" + + @abstractmethod + def cancel(self): + """Safely shut down the reporter""" diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 8bf13dee238..998148d930a 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -1,7 +1,7 @@ from __future__ import annotations +import asyncio import logging -import queue import threading from datetime import datetime, timedelta from pathlib import Path @@ -32,7 +32,6 @@ Start, ) from _ert.forward_model_runner.reporting.statemachine import StateMachine -from _ert.threading import ErtThread logger = logging.getLogger(__name__) @@ -76,22 +75,24 @@ def __init__(self, evaluator_url, token=None, cert_path=None): self._ens_id = None self._real_id = None - self._event_queue: queue.Queue[events.Event | EventSentinel] = queue.Queue() - self._event_publisher_thread = ErtThread(target=self._event_publisher) + self._event_queue: asyncio.Queue[events.Event | EventSentinel] = asyncio.Queue() + # self._event_publisher_thread = ErtThread(target=self._event_publisher) self._timeout_timestamp = None self._timestamp_lock = threading.Lock() # seconds to timeout the reporter the thread after Finish() was received self._reporter_timeout = 60 + self._running = True + self._event_publishing_task = asyncio.create_task(self.async_event_publisher()) - def _event_publisher(self): + async def async_event_publisher(self): logger.debug("Publishing event.") - with Client( + async with Client( 
             url=self._evaluator_url,
             token=self._token,
             cert=self._cert,
         ) as client:
             event = None
-            while True:
+            while self._running:
                 with self._timestamp_lock:
                     if (
                         self._timeout_timestamp is not None
@@ -102,11 +103,13 @@ def _event_publisher(self):
                 if event is None:
                     # if we successfully sent the event we can proceed
                     # to next one
-                    event = self._event_queue.get()
+                    event = await self._event_queue.get()
                     if event is self._sentinel:
-                        break
+                        logger.debug("Got sentinel, stopping event publisher")
+                        return
                 try:
-                    client.send(event_to_json(event))
+                    await client.send(event_to_json(event))
+                    self._event_queue.task_done()
                     event = None
                 except ClientConnectionError as exception:
                     # Possible intermittent failure, we retry sending the event
@@ -115,21 +118,21 @@ def _event_publisher(self):
                     # The receiving end has closed the connection, we stop
                     # sending events
                     logger.debug(str(exception))
+                    self._event_queue.task_done()
                     break
 
-    def report(self, msg):
-        self._statemachine.transition(msg)
+    async def report(self, msg):
+        await self._statemachine.transition(msg)
 
-    def _dump_event(self, event: events.Event):
+    async def _dump_event(self, event: events.Event):
         logger.debug(f'Schedule "{type(event)}" for delivery')
-        self._event_queue.put(event)
+        await self._event_queue.put(event)
 
-    def _init_handler(self, msg: Init):
+    async def _init_handler(self, msg: Init):
         self._ens_id = str(msg.ens_id)
         self._real_id = str(msg.real_id)
-        self._event_publisher_thread.start()
 
-    def _job_handler(self, msg: Union[Start, Running, Exited]):
+    async def _job_handler(self, msg: Union[Start, Running, Exited]):
         assert msg.job
         job_name = msg.job.name()
         job_msg = {
@@ -144,16 +147,16 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
                 std_out=str(Path(msg.job.std_out).resolve()),
                 std_err=str(Path(msg.job.std_err).resolve()),
             )
-            self._dump_event(event)
+            await self._dump_event(event)
             if not msg.success():
                 logger.error(f"Job {job_name} FAILED to start")
                 event = ForwardModelStepFailure(**job_msg, error_msg=msg.error_message)
-                self._dump_event(event)
+                await self._dump_event(event)
 
         elif isinstance(msg, Exited):
             if msg.success():
                 logger.debug(f"Job {job_name} exited successfully")
-                self._dump_event(ForwardModelStepSuccess(**job_msg))
+                await self._dump_event(ForwardModelStepSuccess(**job_msg))
             else:
                 logger.error(
                     _JOB_EXIT_FAILED_STRING.format(
@@ -165,7 +168,7 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
                 event = ForwardModelStepFailure(
                     **job_msg, exit_code=msg.exit_code, error_msg=msg.error_message
                 )
-                self._dump_event(event)
+                await self._dump_event(event)
 
         elif isinstance(msg, Running):
             logger.debug(f"{job_name} job is running")
@@ -175,21 +178,22 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
                 current_memory_usage=msg.memory_status.rss,
                 cpu_seconds=msg.memory_status.cpu_seconds,
             )
-            self._dump_event(event)
+            await self._dump_event(event)
 
-    def _finished_handler(self, _):
-        self._event_queue.put(Event._sentinel)
+    async def _finished_handler(self, _):
+        await self._event_queue.put(Event._sentinel)
         with self._timestamp_lock:
             self._timeout_timestamp = datetime.now() + timedelta(
                 seconds=self._reporter_timeout
             )
-        if self._event_publisher_thread.is_alive():
-            self._event_publisher_thread.join()
 
-    def _checksum_handler(self, msg: Checksum):
+    async def _checksum_handler(self, msg: Checksum):
         fm_checksum = ForwardModelStepChecksum(
             ensemble=self._ens_id,
             real=self._real_id,
             checksums={msg.run_path: msg.data},
         )
-        self._dump_event(fm_checksum)
+        await self._dump_event(fm_checksum)
+
+    def cancel(self):
self._event_publishing_task.cancel() diff --git a/src/_ert/forward_model_runner/reporting/file.py b/src/_ert/forward_model_runner/reporting/file.py index e6e601fe0f2..4d3675d6191 100644 --- a/src/_ert/forward_model_runner/reporting/file.py +++ b/src/_ert/forward_model_runner/reporting/file.py @@ -39,7 +39,7 @@ def __init__(self): self.status_dict = {} self.node = socket.gethostname() - def report(self, msg: Message): + async def report(self, msg: Message): fm_step_status = {} if msg.job: @@ -217,3 +217,6 @@ def _dump_ok_file(): def _dump_status_json(self): with open(STATUS_json, "wb") as fp: fp.write(orjson.dumps(self.status_dict, option=orjson.OPT_INDENT_2)) + + def cancel(self): + pass diff --git a/src/_ert/forward_model_runner/reporting/interactive.py b/src/_ert/forward_model_runner/reporting/interactive.py index fd489c78378..1759db01a35 100644 --- a/src/_ert/forward_model_runner/reporting/interactive.py +++ b/src/_ert/forward_model_runner/reporting/interactive.py @@ -11,7 +11,7 @@ class Interactive(Reporter): @staticmethod - def _report(msg: Message) -> Optional[str]: + async def _report(msg: Message) -> Optional[str]: if not isinstance(msg, (Start, Finish)): return None if isinstance(msg, Finish): @@ -26,7 +26,10 @@ def _report(msg: Message) -> Optional[str]: ) return f"Running job: {msg.job.name()} ... " - def report(self, msg: Message): - _msg = self._report(msg) + async def report(self, msg: Message): + _msg = await self._report(msg) if _msg is not None: print(_msg) + + def cancel(self): + pass diff --git a/src/_ert/forward_model_runner/reporting/statemachine.py b/src/_ert/forward_model_runner/reporting/statemachine.py index 4d749414e4d..4f949662ed1 100644 --- a/src/_ert/forward_model_runner/reporting/statemachine.py +++ b/src/_ert/forward_model_runner/reporting/statemachine.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Dict, Tuple, Type +from typing import Awaitable, Callable, Dict, Tuple, Type from _ert.forward_model_runner.reporting.message import ( Checksum, @@ -35,13 +35,15 @@ def __init__(self) -> None: self._state = None def add_handler( - self, states: Tuple[Type[Message], ...], handler: Callable[[Message], None] + self, + states: Tuple[Type[Message], ...], + handler: Callable[[Message], Awaitable[None]], ) -> None: if states in self._handler: raise ValueError(f"{states} already handled by {self._handler[states]}") self._handler[states] = handler - def transition(self, message: Message): + async def transition(self, message: Message): new_state = None for state in self._handler: if isinstance(message, state): @@ -58,5 +60,5 @@ def transition(self, message: Message): f"expected to transition into {self._transitions[self._state]}" ) - self._handler[new_state](message) + await self._handler[new_state](message) self._state = new_state diff --git a/src/_ert/forward_model_runner/runner.py b/src/_ert/forward_model_runner/runner.py index bd304f3c7d3..38fde86bb89 100644 --- a/src/_ert/forward_model_runner/runner.py +++ b/src/_ert/forward_model_runner/runner.py @@ -1,3 +1,4 @@ +import asyncio import hashlib import json import os @@ -49,46 +50,52 @@ def _populate_checksums(self, manifest): info["error"] = f"Expected file {path} not created by forward model!" 
return manifest - def run(self, names_of_steps_to_run: List[str]): - if not names_of_steps_to_run: - step_queue = self.steps - else: - step_queue = [ - step for step in self.steps if step.name() in names_of_steps_to_run - ] - init_message = Init( - step_queue, - self.simulation_id, - self.ert_pid, - self.ens_id, - self.real_id, - self.experiment_id, - ) - - unused = set(names_of_steps_to_run) - {step.name() for step in step_queue} - if unused: - init_message.with_error( - f"{unused} does not exist. " - f"Available forward_model steps: {[step.name() for step in self.steps]}" + async def run(self, names_of_steps_to_run: List[str]): + try: + if not names_of_steps_to_run: + step_queue = self.steps + else: + step_queue = [ + step for step in self.steps if step.name() in names_of_steps_to_run + ] + init_message = Init( + step_queue, + self.simulation_id, + self.ert_pid, + self.ens_id, + self.real_id, + self.experiment_id, ) - yield init_message - return - else: - yield init_message - for step in step_queue: - for status_update in step.run(): - yield status_update - if not status_update.success(): - yield Checksum(checksum_dict={}, run_path=os.getcwd()) - yield Finish().with_error( - "Not all forward model steps completed successfully." - ) - return + unused = set(names_of_steps_to_run) - {step.name() for step in step_queue} + if unused: + init_message.with_error( + f"{unused} does not exist. " + f"Available forward_model steps: {[step.name() for step in self.steps]}" + ) + yield init_message + return + + yield init_message + for step in step_queue: + for status_update in step.run(): + yield status_update + if not status_update.success(): + yield Checksum(checksum_dict={}, run_path=os.getcwd()) + yield Finish().with_error( + "Not all forward model steps completed successfully." + ) + return - checksum_dict = self._populate_checksums(self._read_manifest()) - yield Checksum(checksum_dict=checksum_dict, run_path=os.getcwd()) - yield Finish() + checksum_dict = self._populate_checksums(self._read_manifest()) + yield Checksum(checksum_dict=checksum_dict, run_path=os.getcwd()) + yield Finish() + except asyncio.CancelledError: + yield Checksum(checksum_dict={}, run_path=os.getcwd()) + yield Finish().with_error( + "Not all forward model steps completed successfully." + ) + return def _set_environment(self): if self.global_environment: diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index ecc1d5c81d5..51ba388d3a8 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -205,7 +205,7 @@ async def send_event( retries: int = 10, ) -> None: async with Client(url, token, cert, max_retries=retries) as client: - await client._send(event_to_json(event)) + await client.send(event_to_json(event)) def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]: def event_builder(status: str) -> Event:
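
Note on the shutdown flow introduced in cli.py: the runner's work is wrapped in an
asyncio task, a SIGTERM handler registered on the running loop cancels that task
(and the reporters), and the cancelled coroutine catches asyncio.CancelledError so
it can emit a final status instead of being killed mid-step. The sketch below is a
minimal, self-contained illustration of that pattern, not ert code; the names
`worker` and `handle_sigterm` and the one-second sleep are placeholders, and
`loop.add_signal_handler` assumes a Unix event loop.

# Minimal sketch (placeholder names, not ert code) of the cancellation-based
# shutdown used in cli.py: SIGTERM cancels the main task, and the worker catches
# asyncio.CancelledError so it can report a final status before exiting.
import asyncio
import signal


async def worker() -> None:
    try:
        while True:
            await asyncio.sleep(1)  # stand-in for running forward model steps
    except asyncio.CancelledError:
        # graceful shutdown: emit a final status instead of dying mid-step
        print("worker cancelled, reporting final status")


async def main() -> None:
    task = asyncio.create_task(worker())

    def handle_sigterm() -> None:
        task.cancel()

    # Unix-only: register the handler on the running loop, as cli.py does
    asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, handle_sigterm)
    try:
        await task
    except asyncio.CancelledError:
        pass


if __name__ == "__main__":
    asyncio.run(main())

Catching the cancellation inside the coroutine is what lets the runner yield a
final Checksum/Finish and avoids the process-group SIGKILL the old code relied on.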