From 2672a0952dd54bc6b161d4837be4636802aee9db Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Wed, 13 Nov 2024 15:11:57 +0100 Subject: [PATCH] fix tests --- src/_ert/forward_model_runner/cli.py | 15 +- .../forward_model_runner/reporting/event.py | 16 +- .../reporting/statemachine.py | 2 +- .../test_event_reporter.py | 190 ++++++++++-------- .../forward_model_runner/test_job_dispatch.py | 67 ++---- tests/ert/utils.py | 49 ++++- 6 files changed, 189 insertions(+), 150 deletions(-) diff --git a/src/_ert/forward_model_runner/cli.py b/src/_ert/forward_model_runner/cli.py index 5ea70f4a594..7353f9f17a0 100644 --- a/src/_ert/forward_model_runner/cli.py +++ b/src/_ert/forward_model_runner/cli.py @@ -5,7 +5,6 @@ import os import signal import sys -import time import typing from datetime import datetime @@ -72,21 +71,21 @@ def _setup_logging(directory: str = "logs"): JOBS_JSON_RETRY_TIME = 30 -def _wait_for_retry(): - time.sleep(JOBS_JSON_RETRY_TIME) +async def _wait_for_retry(): + await asyncio.sleep(JOBS_JSON_RETRY_TIME) -def _read_jobs_file(retry=True): +async def _read_jobs_file(retry=True): try: - with open(JOBS_FILE, "r", encoding="utf-8") as json_file: + with open(JOBS_FILE, "r", encoding="utf-8") as json_file: # noqa: ASYNC230 return json.load(json_file) except json.JSONDecodeError as e: raise IOError("Job Runner cli failed to load JSON-file.") from e except FileNotFoundError as e: if retry: logger.error(f"Could not find file {JOBS_FILE}, retrying") - _wait_for_retry() - return _read_jobs_file(retry=False) + await _wait_for_retry() + return await _read_jobs_file(retry=False) else: raise e @@ -119,7 +118,7 @@ async def main(args): # Make sure that logging is setup _after_ we have moved to the runpath directory _setup_logging() - jobs_data = _read_jobs_file() + jobs_data = await _read_jobs_file() experiment_id = jobs_data.get("experiment_id") ens_id = jobs_data.get("ens_id") diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 998148d930a..a0bfb14f94a 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -76,13 +76,15 @@ def __init__(self, evaluator_url, token=None, cert_path=None): self._ens_id = None self._real_id = None self._event_queue: asyncio.Queue[events.Event | EventSentinel] = asyncio.Queue() - # self._event_publisher_thread = ErtThread(target=self._event_publisher) self._timeout_timestamp = None self._timestamp_lock = threading.Lock() # seconds to timeout the reporter the thread after Finish() was received self._reporter_timeout = 60 - self._running = True self._event_publishing_task = asyncio.create_task(self.async_event_publisher()) + self._event_publisher_ready = asyncio.Event() + + async def join(self) -> None: + await self._event_publishing_task async def async_event_publisher(self): logger.debug("Publishing event.") @@ -91,8 +93,9 @@ async def async_event_publisher(self): token=self._token, cert=self._cert, ) as client: + self._event_publisher_ready.set() event = None - while self._running: + while True: with self._timestamp_lock: if ( self._timeout_timestamp is not None @@ -103,14 +106,17 @@ async def async_event_publisher(self): if event is None: # if we successfully sent the event we can proceed # to next one + print("GETTING MORE EVENTS!") event = await self._event_queue.get() if event is self._sentinel: + self._event_queue.task_done() print("NEW EVENT WAS SENTINEL :))") - return + break try: await client.send(event_to_json(event)) self._event_queue.task_done() event = None + print("Sent event :)") except ClientConnectionError as exception: # Possible intermittent failure, we retry sending the event logger.error(str(exception)) @@ -122,9 +128,11 @@ async def async_event_publisher(self): break async def report(self, msg): + await self._event_publisher_ready.wait() await self._statemachine.transition(msg) async def _dump_event(self, event: events.Event): + print(f"DUMPED EVENT {type(event)=}") logger.debug(f'Schedule "{type(event)}" for delivery') await self._event_queue.put(event) diff --git a/src/_ert/forward_model_runner/reporting/statemachine.py b/src/_ert/forward_model_runner/reporting/statemachine.py index 4f949662ed1..61ab517fd81 100644 --- a/src/_ert/forward_model_runner/reporting/statemachine.py +++ b/src/_ert/forward_model_runner/reporting/statemachine.py @@ -59,6 +59,6 @@ async def transition(self, message: Message): f"Illegal transition {self._state} -> {new_state} for {message}, " f"expected to transition into {self._transitions[self._state]}" ) - + print(f"TRANSITIONING STATE W/{message=}") await self._handler[new_state](message) self._state = new_state diff --git a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py index d7dad85f0e8..84ae95af645 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py +++ b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py @@ -1,6 +1,6 @@ +import asyncio import os import sys -import time from unittest.mock import patch import pytest @@ -27,17 +27,10 @@ Start, ) from _ert.forward_model_runner.reporting.statemachine import TransitionError -from tests.ert.utils import _mock_ws_thread +from tests.ert.utils import _mock_ws_task, async_wait_until -def _wait_until(condition, timeout, fail_msg): - start = time.time() - while not condition(): - assert start + timeout > time.time(), fail_msg - time.sleep(0.1) - - -def test_report_with_successful_start_message_argument(unused_tcp_port): +async def test_report_with_successful_start_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -45,10 +38,12 @@ def test_report_with_successful_start_message_argument(unused_tcp_port): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Start(fmstep1)) - reporter.report(Finish()) + + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Start(fmstep1)) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 1 event = event_from_json(lines[0]) @@ -58,9 +53,10 @@ def test_report_with_successful_start_message_argument(unused_tcp_port): assert event.fm_step == "0" assert os.path.basename(event.std_out) == "stdout" assert os.path.basename(event.std_err) == "stderr" + reporter._event_publishing_task.cancel() -def test_report_with_failed_start_message_argument(unused_tcp_port): +async def test_report_with_failed_start_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -70,13 +66,13 @@ def test_report_with_failed_start_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) msg = Start(fmstep1).with_error("massive_failure") - - reporter.report(msg) - reporter.report(Finish()) + await reporter.report(msg) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 2 event = event_from_json(lines[1]) @@ -84,7 +80,7 @@ def test_report_with_failed_start_message_argument(unused_tcp_port): assert event.error_msg == "massive_failure" -def test_report_with_successful_exit_message_argument(unused_tcp_port): +async def test_report_with_successful_exit_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -93,17 +89,18 @@ def test_report_with_successful_exit_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Exited(fmstep1, 0)) - reporter.report(Finish().with_error("failed")) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Exited(fmstep1, 0)) + await reporter.report(Finish().with_error("failed")) + await reporter.join() assert len(lines) == 1 event = event_from_json(lines[0]) assert type(event) is ForwardModelStepSuccess -def test_report_with_failed_exit_message_argument(unused_tcp_port): +async def test_report_with_failed_exit_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -112,10 +109,11 @@ def test_report_with_failed_exit_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Exited(fmstep1, 1).with_error("massive_failure")) - reporter.report(Finish()) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Exited(fmstep1, 1).with_error("massive_failure")) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 1 event = event_from_json(lines[0]) @@ -123,7 +121,7 @@ def test_report_with_failed_exit_message_argument(unused_tcp_port): assert event.error_msg == "massive_failure" -def test_report_with_running_message_argument(unused_tcp_port): +async def test_report_with_running_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -132,10 +130,11 @@ def test_report_with_running_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Finish()) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 1 event = event_from_json(lines[0]) @@ -144,7 +143,7 @@ def test_report_with_running_message_argument(unused_tcp_port): assert event.current_memory_usage == 10 -def test_report_only_job_running_for_successful_run(unused_tcp_port): +async def test_report_only_job_running_for_successful_run(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -153,15 +152,16 @@ def test_report_only_job_running_for_successful_run(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Finish()) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 1 -def test_report_with_failed_finish_message_argument(unused_tcp_port): +async def test_report_with_failed_finish_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -170,29 +170,32 @@ def test_report_with_failed_finish_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Finish().with_error("massive_failure")) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + await reporter.report(Finish().with_error("massive_failure")) + await reporter.join() assert len(lines) == 1 -def test_report_inconsistent_events(unused_tcp_port): +async def test_report_inconsistent_events(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines), pytest.raises( - TransitionError, - match=r"Illegal transition None -> \(MessageType,\)", - ): - reporter.report(Finish()) + async with _mock_ws_task(host, unused_tcp_port, lines): + with pytest.raises( + TransitionError, + match=r"Illegal transition None -> \(MessageType,\)", + ): + await reporter.report(Finish()) + reporter.cancel() @pytest.mark.integration_test -def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): +async def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): # this is to show when the reporter fails ert won't crash nor # staying hanging but instead finishes up the job; # see reporter._event_publisher_thread.join() @@ -201,8 +204,8 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): # which then sets _timeout_timestamp=None mock_send_retry_time = 2 - def mock_send(msg): - time.sleep(mock_send_retry_time) + async def mock_send(msg): + await asyncio.sleep(mock_send_retry_time) raise ClientConnectionError("Sending failed!") host = "localhost" @@ -213,18 +216,23 @@ def mock_send(msg): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + async with _mock_ws_task(host, unused_tcp_port, lines): with patch( "_ert.forward_model_runner.client.Client.send", lambda x, y: mock_send(y) ): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10)) + ) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10)) + ) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10)) + ) # set _stop_timestamp - reporter.report(Finish()) - if reporter._event_publisher_thread.is_alive(): - reporter._event_publisher_thread.join() + await reporter.report(Finish()) + await reporter.join() # set _stop_timestamp to None only when timer stopped assert reporter._timeout_timestamp is None assert len(lines) == 0, "expected 0 Job running messages" @@ -235,7 +243,7 @@ def mock_send(msg): @pytest.mark.skipif( sys.platform.startswith("darwin"), reason="Performance can be flaky" ) -def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): +async def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): # this is to show when the reporter fails but reconnects # reporter still manages to send events and completes fine # see assert reporter._timeout_timestamp is not None @@ -243,27 +251,33 @@ def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): # it finished succesfully mock_send_retry_time = 0.1 - def send_func(msg): - time.sleep(mock_send_retry_time) + async def send_func(msg): + await asyncio.sleep(mock_send_retry_time) raise ClientConnectionError("Sending failed!") host = "localhost" url = f"ws://{host}:{unused_tcp_port}" - reporter = Event(evaluator_url=url) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + async with _mock_ws_task(host, unused_tcp_port, lines): with patch("_ert.forward_model_runner.client.Client.send") as patched_send: + reporter = Event(evaluator_url=url) patched_send.side_effect = send_func - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10))) + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10)) + ) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10)) + ) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10)) + ) - _wait_until( + await async_wait_until( condition=lambda: patched_send.call_count == 3, timeout=10, fail_msg="10 seconds should be sufficient to send three events", @@ -271,23 +285,22 @@ def send_func(msg): # reconnect and continue sending events # set _stop_timestamp - reporter.report(Finish()) - if reporter._event_publisher_thread.is_alive(): - reporter._event_publisher_thread.join() + await reporter.report(Finish()) + await reporter.join() # set _stop_timestamp was not set to None since the reporter finished on time assert reporter._timeout_timestamp is not None assert len(lines) == 3, "expected 3 Job running messages" @pytest.mark.integration_test -def test_report_with_closed_received_exiting_gracefully(unused_tcp_port): +async def test_report_with_closed_received_exiting_gracefully(unused_tcp_port): # Whenever the receiver end closes the connection, a ConnectionClosedOK is raised # The reporter should exit the publisher thread gracefully and not send any # more events mock_send_retry_time = 3 - def mock_send(msg): - time.sleep(mock_send_retry_time) + async def mock_send(msg): + await asyncio.sleep(mock_send_retry_time) raise ClientConnectionClosedOK("Connection Closed") host = "localhost" @@ -297,13 +310,13 @@ def mock_send(msg): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) # sleep until both Running events have been received - _wait_until( + await async_wait_until( condition=lambda: len(lines) == 2, timeout=10, fail_msg="Should not take 10 seconds to send two events", @@ -312,15 +325,16 @@ def mock_send(msg): with patch( "_ert.forward_model_runner.client.Client.send", lambda x, y: mock_send(y) ): - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10))) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10)) + ) # Make sure the publisher thread exits because it got # ClientConnectionClosedOK. If it hangs it could indicate that the # exception is not caught/handled correctly - if reporter._event_publisher_thread.is_alive(): - reporter._event_publisher_thread.join() + await reporter.join() - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=400, rss=10))) - reporter.report(Finish()) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=400, rss=10))) + await reporter.report(Finish()) # set _stop_timestamp was not set to None since the reporter finished on time assert reporter._timeout_timestamp is not None diff --git a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py index 474ff102785..36e970cd7a1 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py +++ b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py @@ -1,17 +1,15 @@ from __future__ import annotations +import asyncio import glob import importlib import json import os -import signal import stat import subprocess import sys from subprocess import Popen from textwrap import dedent -from threading import Lock -from unittest.mock import mock_open, patch import pandas as pd import psutil @@ -21,17 +19,13 @@ from _ert.forward_model_runner.cli import JOBS_FILE, _setup_reporters, main from _ert.forward_model_runner.forward_model_step import killed_by_oom from _ert.forward_model_runner.reporting import Event, Interactive -from _ert.forward_model_runner.reporting.message import Finish, Init -from _ert.threading import ErtThread -from tests.ert.utils import _mock_ws_thread, wait_until - -from .test_event_reporter import _wait_until +from tests.ert.utils import _mock_ws_task, async_wait_until, wait_until @pytest.mark.usefixtures("use_tmpdir") -def test_terminate_steps(): +async def test_terminate_steps(): # Executes itself recursively and sleeps for 100 seconds - with open("dummy_executable", "w", encoding="utf-8") as f: + with open("dummy_executable", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write( """#!/usr/bin/env python import sys, os, time @@ -73,11 +67,11 @@ def test_terminate_steps(): "ert_pid": "", } - with open(JOBS_FILE, "w", encoding="utf-8") as f: + with open(JOBS_FILE, "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write(json.dumps(step_list)) # macOS doesn't provide /usr/bin/setsid, so we roll our own - with open("setsid", "w", encoding="utf-8") as f: + with open("setsid", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write( dedent( """\ @@ -95,7 +89,7 @@ def test_terminate_steps(): "_ert.forward_model_runner.job_dispatch" ).origin # (we wait for the process below) - job_dispatch_process = Popen( + job_dispatch_process = Popen( # noqa: ASYNC220 [ os.getcwd() + "/setsid", sys.executable, @@ -113,7 +107,8 @@ def test_terminate_steps(): wait_until(lambda: len(p.children(recursive=True)) == 0) - os.wait() # allow os to clean up zombie processes + # allow os to clean up zombie processes + os.wait() # noqa: ASYNC222 @pytest.mark.usefixtures("use_tmpdir") @@ -294,9 +289,12 @@ def test_missing_directory_exits(tmp_path): main(["script.py", str(tmp_path / "non_existent")]) -def test_retry_of_jobs_json_file_read(unused_tcp_port, tmp_path, monkeypatch, caplog): - lock = Lock() - lock.acquire() +async def test_retry_of_jobs_json_file_read( + unused_tcp_port, tmp_path, monkeypatch, caplog +): + lock = asyncio.Lock() + await lock.acquire() + monkeypatch.setattr(_ert.forward_model_runner.cli, "_wait_for_retry", lock.acquire) jobs_json = json.dumps( { @@ -306,17 +304,19 @@ def test_retry_of_jobs_json_file_read(unused_tcp_port, tmp_path, monkeypatch, ca } ) - with _mock_ws_thread("localhost", unused_tcp_port, []): - thread = ErtThread(target=main, args=[["script.py", str(tmp_path)]]) - thread.start() - _wait_until( + async with _mock_ws_task("localhost", unused_tcp_port, []): + fm_runner_task = asyncio.create_task(main(["script.py", str(tmp_path)])) + + await async_wait_until( lambda: f"Could not find file {JOBS_FILE}, retrying" in caplog.text, 2, "Did not get expected log message from missing jobs.json", ) (tmp_path / JOBS_FILE).write_text(jobs_json) + await asyncio.sleep(0) lock.release() - thread.join() + + await fm_runner_task @pytest.mark.parametrize( @@ -339,29 +339,6 @@ def test_setup_reporters(is_interactive_run, ens_id): assert any(isinstance(r, Interactive) for r in reporters) -@pytest.mark.usefixtures("use_tmpdir") -def test_job_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port): - host = "localhost" - port = unused_tcp_port - jobs_json = json.dumps({"ens_id": "_id_", "dispatch_url": f"ws://localhost:{port}"}) - - with patch("_ert.forward_model_runner.cli.os.killpg") as mock_killpg, patch( - "_ert.forward_model_runner.cli.os.getpgid" - ) as mock_getpgid, patch( - "_ert.forward_model_runner.cli.open", new=mock_open(read_data=jobs_json) - ), patch("_ert.forward_model_runner.cli.ForwardModelRunner") as mock_runner: - mock_runner.return_value.run.return_value = [ - Init([], 0, 0), - Finish().with_error("overall bad run"), - ] - mock_getpgid.return_value = 17 - - with _mock_ws_thread(host, port, []): - main(["script.py"]) - - mock_killpg.assert_called_with(17, signal.SIGKILL) - - @pytest.mark.skipif(sys.platform.startswith("darwin"), reason="No oom_score on MacOS") def test_killed_by_oom(tmp_path, monkeypatch): """Test out-of-memory detection for pid and descendants based diff --git a/tests/ert/utils.py b/tests/ert/utils.py index 732f816f8cd..31e2dfb25b7 100644 --- a/tests/ert/utils.py +++ b/tests/ert/utils.py @@ -61,6 +61,16 @@ def wait_until(func, interval=0.5, timeout=30): ) +async def async_wait_until(condition, timeout, fail_msg, interval=0.1): + t = 0 + while t < timeout: + await asyncio.sleep(interval) + if condition(): + return + t += interval + raise AssertionError(fail_msg) + + def _mock_ws(host, port, messages, delay_startup=0): loop = asyncio.new_event_loop() done = loop.create_future() @@ -70,6 +80,7 @@ async def _handler(websocket, path): msg = await websocket.recv() messages.append(msg) if msg == "stop": + print("SHOULD STOP!") done.set_result(None) break @@ -82,8 +93,24 @@ async def _run_server(): loop.close() -@contextlib.contextmanager -def _mock_ws_thread(host, port, messages): +async def _mock_ws_async(host, port, messages, delay_startup=0): + done = asyncio.Future() + + async def _handler(websocket, path): + while True: + msg = await websocket.recv() + messages.append(msg) + if msg == "stop": + done.set_result(None) + break + + await asyncio.sleep(delay_startup) + async with websockets.server.serve(_handler, host, port): + await done + + +@contextlib.asynccontextmanager +async def _mock_ws_thread(host, port, messages): mock_ws_thread = ErtThread( target=partial(_mock_ws, messages=messages), args=( @@ -97,12 +124,26 @@ def _mock_ws_thread(host, port, messages): # Make sure to join the thread even if an exception occurs finally: url = f"ws://{host}:{port}" - with Client(url) as client: - client.send("stop") + async with Client(url) as client: + await client.send("stop") mock_ws_thread.join() messages.pop() +@contextlib.asynccontextmanager +async def _mock_ws_task(host, port, messages): + mock_ws_task = asyncio.create_task(_mock_ws_async(host, port, messages)) + try: + yield + # Make sure to join the thread even if an exception occurs + finally: + url = f"ws://{host}:{port}" + async with Client(url) as client: + await client.send("stop") + await mock_ws_task + messages.pop() + + async def poll(driver: Driver, expected: set[int], *, started=None, finished=None): """Poll driver until expected realisations finish