Skip to content

Commit

Permalink
Refactor forwardmodelrunner to be async
Browse files Browse the repository at this point in the history
This should help the forward model runner shutting down more gracefully, and removing some of the errors we are seeing in the logs.
  • Loading branch information
jonathan-eq committed Nov 13, 2024
1 parent 03cfa25 commit 165fa9c
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 100 deletions.
50 changes: 35 additions & 15 deletions src/_ert/forward_model_runner/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import asyncio
import json
import logging
import os
Expand Down Expand Up @@ -90,7 +91,7 @@ def _read_jobs_file(retry=True):
raise e


def main(args):
async def main(args):
parser = argparse.ArgumentParser(
description=(
"Run all the jobs specified in jobs.json, "
Expand Down Expand Up @@ -137,19 +138,38 @@ def main(args):
)

job_runner = ForwardModelRunner(jobs_data)
job_task = asyncio.create_task(_main(job_runner, parsed_args, reporters))

for job_status in job_runner.run(parsed_args.job):
logger.info(f"Job status: {job_status}")
def handle_sigterm(*args, **kwargs):
nonlocal reporters, job_task
job_task.cancel()
for reporter in reporters:
try:
reporter.report(job_status)
except OSError as oserror:
print(
f"job_dispatch failed due to {oserror}. Stopping and cleaning up."
)
pgid = os.getpgid(os.getpid())
os.killpg(pgid, signal.SIGKILL)

if isinstance(job_status, Finish) and not job_status.success():
pgid = os.getpgid(os.getpid())
os.killpg(pgid, signal.SIGKILL)
reporter.cancel()

asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, handle_sigterm)
await job_task


async def _main(
job_runner: ForwardModelRunner,
parsed_args,
reporters: typing.Sequence[reporting.Reporter],
):
try:
async for job_status in job_runner.run(parsed_args.job):
logger.info(f"Job status: {job_status}")

for reporter in reporters:
try:
await reporter.report(job_status)
await asyncio.sleep(0)
except OSError as oserror:
print(
f"job_dispatch failed due to {oserror}. Stopping and cleaning up."
)
return

if isinstance(job_status, Finish) and not job_status.success():
return
except asyncio.CancelledError:
pass
5 changes: 1 addition & 4 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def get_websocket(self) -> WebSocketClientProtocol:
close_timeout=self.CONNECTION_TIMEOUT,
)

async def _send(self, msg: AnyStr) -> None:
async def send(self, msg: AnyStr) -> None:
for retry in range(self._max_retries + 1):
try:
if self.websocket is None:
Expand Down Expand Up @@ -133,6 +133,3 @@ async def _send(self, msg: AnyStr) -> None:
raise ClientConnectionError(_error_msg) from exception
await asyncio.sleep(0.2 + self._timeout_multiplier * retry)
self.websocket = None

def send(self, msg: AnyStr) -> None:
self.loop.run_until_complete(self._send(msg))
8 changes: 2 additions & 6 deletions src/_ert/forward_model_runner/job_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import signal
import sys
Expand All @@ -13,12 +14,7 @@ def sigterm_handler(_signo, _stack_frame):
def main():
os.nice(19)
signal.signal(signal.SIGTERM, sigterm_handler)
try:
job_runner_main(sys.argv)
except Exception as e:
pgid = os.getpgid(os.getpid())
os.killpg(pgid, signal.SIGTERM)
raise e
asyncio.run(job_runner_main(sys.argv))


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion src/_ert/forward_model_runner/reporting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@

class Reporter(ABC):
@abstractmethod
def report(self, msg: Message):
async def report(self, msg: Message):
"""Report a message."""

@abstractmethod
def cancel(self):
"""Safely shut down the reporter"""
60 changes: 32 additions & 28 deletions src/_ert/forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
import logging
import queue
import threading
from datetime import datetime, timedelta
from pathlib import Path
Expand Down Expand Up @@ -32,7 +32,6 @@
Start,
)
from _ert.forward_model_runner.reporting.statemachine import StateMachine
from _ert.threading import ErtThread

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,22 +75,24 @@ def __init__(self, evaluator_url, token=None, cert_path=None):

self._ens_id = None
self._real_id = None
self._event_queue: queue.Queue[events.Event | EventSentinel] = queue.Queue()
self._event_publisher_thread = ErtThread(target=self._event_publisher)
self._event_queue: asyncio.Queue[events.Event | EventSentinel] = asyncio.Queue()
# self._event_publisher_thread = ErtThread(target=self._event_publisher)
self._timeout_timestamp = None
self._timestamp_lock = threading.Lock()
# seconds to timeout the reporter the thread after Finish() was received
self._reporter_timeout = 60
self._running = True
self._event_publishing_task = asyncio.create_task(self.async_event_publisher())

def _event_publisher(self):
async def async_event_publisher(self):
logger.debug("Publishing event.")
with Client(
async with Client(
url=self._evaluator_url,
token=self._token,
cert=self._cert,
) as client:
event = None
while True:
while self._running:
with self._timestamp_lock:
if (
self._timeout_timestamp is not None
Expand All @@ -102,11 +103,13 @@ def _event_publisher(self):
if event is None:
# if we successfully sent the event we can proceed
# to next one
event = self._event_queue.get()
event = await self._event_queue.get()
if event is self._sentinel:
break
print("NEW EVENT WAS SENTINEL :))")
return
try:
client.send(event_to_json(event))
await client.send(event_to_json(event))
self._event_queue.task_done()
event = None
except ClientConnectionError as exception:
# Possible intermittent failure, we retry sending the event
Expand All @@ -115,21 +118,21 @@ def _event_publisher(self):
# The receiving end has closed the connection, we stop
# sending events
logger.debug(str(exception))
self._event_queue.task_done()
break

def report(self, msg):
self._statemachine.transition(msg)
async def report(self, msg):
await self._statemachine.transition(msg)

def _dump_event(self, event: events.Event):
async def _dump_event(self, event: events.Event):
logger.debug(f'Schedule "{type(event)}" for delivery')
self._event_queue.put(event)
await self._event_queue.put(event)

def _init_handler(self, msg: Init):
async def _init_handler(self, msg: Init):
self._ens_id = str(msg.ens_id)
self._real_id = str(msg.real_id)
self._event_publisher_thread.start()

def _job_handler(self, msg: Union[Start, Running, Exited]):
async def _job_handler(self, msg: Union[Start, Running, Exited]):
assert msg.job
job_name = msg.job.name()
job_msg = {
Expand All @@ -144,16 +147,16 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
std_out=str(Path(msg.job.std_out).resolve()),
std_err=str(Path(msg.job.std_err).resolve()),
)
self._dump_event(event)
await self._dump_event(event)
if not msg.success():
logger.error(f"Job {job_name} FAILED to start")
event = ForwardModelStepFailure(**job_msg, error_msg=msg.error_message)
self._dump_event(event)
await self._dump_event(event)

elif isinstance(msg, Exited):
if msg.success():
logger.debug(f"Job {job_name} exited successfully")
self._dump_event(ForwardModelStepSuccess(**job_msg))
await self._dump_event(ForwardModelStepSuccess(**job_msg))
else:
logger.error(
_JOB_EXIT_FAILED_STRING.format(
Expand All @@ -165,7 +168,7 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
event = ForwardModelStepFailure(
**job_msg, exit_code=msg.exit_code, error_msg=msg.error_message
)
self._dump_event(event)
await self._dump_event(event)

elif isinstance(msg, Running):
logger.debug(f"{job_name} job is running")
Expand All @@ -175,21 +178,22 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
current_memory_usage=msg.memory_status.rss,
cpu_seconds=msg.memory_status.cpu_seconds,
)
self._dump_event(event)
await self._dump_event(event)

def _finished_handler(self, _):
self._event_queue.put(Event._sentinel)
async def _finished_handler(self, _):
await self._event_queue.put(Event._sentinel)
with self._timestamp_lock:
self._timeout_timestamp = datetime.now() + timedelta(
seconds=self._reporter_timeout
)
if self._event_publisher_thread.is_alive():
self._event_publisher_thread.join()

def _checksum_handler(self, msg: Checksum):
async def _checksum_handler(self, msg: Checksum):
fm_checksum = ForwardModelStepChecksum(
ensemble=self._ens_id,
real=self._real_id,
checksums={msg.run_path: msg.data},
)
self._dump_event(fm_checksum)
await self._dump_event(fm_checksum)

def cancel(self):
self._event_publishing_task.cancel()
5 changes: 4 additions & 1 deletion src/_ert/forward_model_runner/reporting/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self):
self.status_dict = {}
self.node = socket.gethostname()

def report(self, msg: Message):
async def report(self, msg: Message):
fm_step_status = {}

if msg.job:
Expand Down Expand Up @@ -217,3 +217,6 @@ def _dump_ok_file():
def _dump_status_json(self):
with open(STATUS_json, "wb") as fp:
fp.write(orjson.dumps(self.status_dict, option=orjson.OPT_INDENT_2))

def cancel(self):
pass
9 changes: 6 additions & 3 deletions src/_ert/forward_model_runner/reporting/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class Interactive(Reporter):
@staticmethod
def _report(msg: Message) -> Optional[str]:
async def _report(msg: Message) -> Optional[str]:
if not isinstance(msg, (Start, Finish)):
return None
if isinstance(msg, Finish):
Expand All @@ -26,7 +26,10 @@ def _report(msg: Message) -> Optional[str]:
)
return f"Running job: {msg.job.name()} ... "

def report(self, msg: Message):
_msg = self._report(msg)
async def report(self, msg: Message):
_msg = await self._report(msg)
if _msg is not None:
print(_msg)

def cancel(self):
pass
10 changes: 6 additions & 4 deletions src/_ert/forward_model_runner/reporting/statemachine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Callable, Dict, Tuple, Type
from typing import Awaitable, Callable, Dict, Tuple, Type

from _ert.forward_model_runner.reporting.message import (
Checksum,
Expand Down Expand Up @@ -35,13 +35,15 @@ def __init__(self) -> None:
self._state = None

def add_handler(
self, states: Tuple[Type[Message], ...], handler: Callable[[Message], None]
self,
states: Tuple[Type[Message], ...],
handler: Callable[[Message], Awaitable[None]],
) -> None:
if states in self._handler:
raise ValueError(f"{states} already handled by {self._handler[states]}")
self._handler[states] = handler

def transition(self, message: Message):
async def transition(self, message: Message):
new_state = None
for state in self._handler:
if isinstance(message, state):
Expand All @@ -58,5 +60,5 @@ def transition(self, message: Message):
f"expected to transition into {self._transitions[self._state]}"
)

self._handler[new_state](message)
await self._handler[new_state](message)
self._state = new_state
Loading

0 comments on commit 165fa9c

Please sign in to comment.