diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 2566ca005f8..5618dd355d6 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -1,19 +1,9 @@ -import asyncio import logging -import ssl -from typing import Any, AnyStr, Optional, Union +import time +from typing import Any, Optional, Union +import zmq from typing_extensions import Self -from websockets.client import WebSocketClientProtocol, connect -from websockets.datastructures import Headers -from websockets.exceptions import ( - ConnectionClosedError, - ConnectionClosedOK, - InvalidHandshake, - InvalidURI, -) - -from _ert.async_utils import new_event_loop logger = logging.getLogger(__name__) @@ -35,18 +25,8 @@ def __enter__(self) -> Self: return self def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None: - if self.websocket is not None: - self.loop.run_until_complete(self.websocket.close()) - self.loop.close() - - async def __aenter__(self) -> "Client": - return self - - async def __aexit__( - self, exc_type: Any, exc_value: Any, exc_traceback: Any - ) -> None: - if self.websocket is not None: - await self.websocket.close() + self.socket.close() + self.context.term() def __init__( self, @@ -60,79 +40,48 @@ def __init__( max_retries = self.DEFAULT_MAX_RETRIES if timeout_multiplier is None: timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER - if url is None: - raise ValueError("url was None") self.url = url self.token = token - self._extra_headers = Headers() + + # Set up ZeroMQ context and socket + self.context = zmq.Context() + self.socket = self.context.socket(zmq.PUSH) + # reduce backlog + self.socket.setsockopt(zmq.LINGER, 0) + # if server is not ready yet, no message is sent + self.socket.setsockopt(zmq.IMMEDIATE, 1) if token is not None: - self._extra_headers["token"] = token - - # Mimics the behavior of the ssl argument when connection to - # websockets. If none is specified it will deduce based on the url, - # if True it will enforce TLS, and if you want to use self signed - # certificates you need to pass an ssl_context with the certificate - # loaded. - self._ssl_context: Optional[Union[bool, ssl.SSLContext]] = None - if cert is not None: - self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - self._ssl_context.load_verify_locations(cadata=cert) - elif url.startswith("wss"): - self._ssl_context = True + client_public, client_secret = zmq.curve_keypair() + self.socket.curve_secretkey = client_secret + self.socket.curve_publickey = client_public + self.socket.curve_serverkey = token.encode("utf-8") + self.socket.connect(url) self._max_retries = max_retries self._timeout_multiplier = timeout_multiplier - self.websocket: Optional[WebSocketClientProtocol] = None - self.loop = new_event_loop() - - async def get_websocket(self) -> WebSocketClientProtocol: - return await connect( - self.url, - ssl=self._ssl_context, - extra_headers=self._extra_headers, - open_timeout=self.CONNECTION_TIMEOUT, - ping_timeout=self.CONNECTION_TIMEOUT, - ping_interval=self.CONNECTION_TIMEOUT, - close_timeout=self.CONNECTION_TIMEOUT, - ) - - async def _send(self, msg: AnyStr) -> None: - for retry in range(self._max_retries + 1): + + self.reconnect() + + def reconnect(self): + """Connect to the server with exponential backoff.""" + retries = self._max_retries + while retries > 0: try: - if self.websocket is None: - self.websocket = await self.get_websocket() - await self.websocket.send(msg) - return - except ConnectionClosedOK as exception: - _error_msg = ( - f"Connection closed received from the server {self.url}! " - f" Exception from {type(exception)}: {exception!s}" - ) - raise ClientConnectionClosedOK(_error_msg) from exception - except ( - InvalidHandshake, - InvalidURI, - OSError, - asyncio.TimeoutError, - ) as exception: - if retry == self._max_retries: - _error_msg = ( - f"Not able to establish the " - f"websocket connection {self.url}! Max retries reached!" - " Check for firewall issues." - f" Exception from {type(exception)}: {exception!s}" - ) - raise ClientConnectionError(_error_msg) from exception - except ConnectionClosedError as exception: - if retry == self._max_retries: - _error_msg = ( - f"Not been able to send the event" - f" to {self.url}! Max retries reached!" - f" Exception from {type(exception)}: {exception!s}" - ) - raise ClientConnectionError(_error_msg) from exception - await asyncio.sleep(0.2 + self._timeout_multiplier * retry) - self.websocket = None - - def send(self, msg: AnyStr) -> None: - self.loop.run_until_complete(self._send(msg)) + self.socket.connect(self.url) + break + except zmq.ZMQError as e: + logger.warning(f"Failed to connect to {self.url}: {e}") + retries -= 1 + if retries == 0: + raise e + # Exponential backoff + sleep_time = self._timeout_multiplier * (self._max_retries - retries) + time.sleep(sleep_time) + + def send(self, message): + try: + self.socket.send_multipart([b"dispatch", message.encode("utf-8")]) + except zmq.ZMQError as e: + logger.warning(f"Failed to send message: {e}") + self.reconnect() + self.socket.send_multipart([b"dispatch", message.encode("utf-8")]) diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index ecc1d5c81d5..e682563dc9f 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -31,13 +31,8 @@ from ert.run_arg import RunArg from ert.scheduler import Scheduler, create_driver -from ._wait_for_evaluator import wait_for_evaluator from .config import EvaluatorServerConfig -from .snapshot import ( - EnsembleSnapshot, - FMStepSnapshot, - RealizationSnapshot, -) +from .snapshot import EnsembleSnapshot, FMStepSnapshot, RealizationSnapshot from .state import ( ENSEMBLE_STATE_CANCELLED, ENSEMBLE_STATE_FAILED, @@ -204,8 +199,8 @@ async def send_event( cert: Optional[Union[str, bytes]] = None, retries: int = 10, ) -> None: - async with Client(url, token, cert, max_retries=retries) as client: - await client._send(event_to_json(event)) + with Client(url, token, cert, max_retries=retries) as client: + client.send(event_to_json(event)) def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]: def event_builder(status: str) -> Event: @@ -230,16 +225,16 @@ async def evaluate( ce_unary_send_method_name, partialmethod( self.__class__.send_event, - self._config.dispatch_uri, + self._config.get_connection_info().push_pull_uri, token=self._config.token, cert=self._config.cert, ), ) - await wait_for_evaluator( - base_url=self._config.url, - token=self._config.token, - cert=self._config.cert, - ) + # await wait_for_evaluator( + # base_url=self._config.url, + # token=self._config.token, + # cert=self._config.cert, + # ) await self._evaluate_inner( event_unary_send=getattr(self, ce_unary_send_method_name), scheduler_queue=scheduler_queue, @@ -282,7 +277,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches max_running=self._queue_config.max_running, submit_sleep=self._queue_config.submit_sleep, ens_id=self.id_, - ee_uri=self._config.dispatch_uri, + ee_uri=self._config.get_connection_info().push_pull_uri, ee_cert=self._config.cert, ee_token=self._config.token, ) diff --git a/src/ert/ensemble_evaluator/config.py b/src/ert/ensemble_evaluator/config.py index 79c127cccdb..0a97db4679f 100644 --- a/src/ert/ensemble_evaluator/config.py +++ b/src/ert/ensemble_evaluator/config.py @@ -11,6 +11,7 @@ from datetime import datetime, timedelta from typing import Optional +import zmq from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization @@ -129,13 +130,22 @@ def __init__( custom_host: typing.Optional[str] = None, ) -> None: self._socket_handle = find_available_socket( - custom_range=custom_port_range, custom_host=custom_host + custom_range=custom_port_range, + custom_host=custom_host, + will_close_then_reopen_socket=True, ) host, port = self._socket_handle.getsockname() - self.protocol = "wss" if generate_cert else "ws" - self.url = f"{self.protocol}://{host}:{port}" - self.client_uri = f"{self.url}/client" - self.dispatch_uri = f"{self.url}/dispatch" + self.host = host + self.pub_sub_port = port + + self._socket_handle = find_available_socket( + custom_range=custom_port_range, + custom_host=custom_host, + will_close_then_reopen_socket=True, + ) + host, port = self._socket_handle.getsockname() + self.push_pull_port = port + if generate_cert: cert, key, pw = _generate_certificate(host) else: @@ -145,13 +155,16 @@ def __init__( self._key_pw = pw self.token = _generate_authentication() if use_token else None + self.server_public_key, self.server_secret_key = zmq.curve_keypair() + self.token = self.server_public_key.decode("utf-8") def get_socket(self) -> socket.socket: return self._socket_handle.dup() def get_connection_info(self) -> EvaluatorConnectionInfo: return EvaluatorConnectionInfo( - self.url, + f"tcp://{self.host}:{self.push_pull_port}", + f"tcp://{self.host}:{self.pub_sub_port}", self.cert, self.token, ) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 3855ec85cac..fcf2a34edf6 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -1,16 +1,14 @@ +from __future__ import annotations + import asyncio import datetime import logging import traceback -from contextlib import asynccontextmanager, contextmanager -from http import HTTPStatus from typing import ( Any, - AsyncIterator, Awaitable, Callable, Dict, - Generator, Iterable, List, Optional, @@ -22,11 +20,7 @@ get_args, ) -import websockets -from pydantic_core._pydantic_core import ValidationError -from websockets.datastructures import Headers, HeadersLike -from websockets.exceptions import ConnectionClosedError -from websockets.server import WebSocketServerProtocol +import zmq.asyncio from _ert.events import ( EESnapshot, @@ -70,15 +64,11 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._loop: Optional[asyncio.AbstractEventLoop] = None - self._clients: Set[WebSocketServerProtocol] = set() - self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue() - self._events: asyncio.Queue[Event] = asyncio.Queue() self._events_to_send: asyncio.Queue[Event] = asyncio.Queue() self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue() self._ee_tasks: List[asyncio.Task[None]] = [] - self._server_started: asyncio.Event = asyncio.Event() self._server_done: asyncio.Event = asyncio.Event() # batching section @@ -88,14 +78,39 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._max_batch_size: int = 500 self._batching_interval: int = 2 self._complete_batch: asyncio.Event = asyncio.Event() + self._zmq_context: zmq.asyncio.Context | None = None + self._clients_connected: asyncio.Queue[None] = asyncio.Queue() + self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue() + + async def _initialize_zmq(self) -> None: + self._zmq_context = zmq.asyncio.Context() # type: ignore + try: + self._pull_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL) + self._pull_socket.curve_secretkey = self._config.server_secret_key + self._pull_socket.curve_publickey = self._config.server_public_key + self._pull_socket.curve_server = True + self._pull_socket.setsockopt(zmq.LINGER, 0) + self._pull_socket.setsockopt(zmq.RCVHWM, 10000) + self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") + + self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket( + zmq.XPUB + ) + + self._publisher_socket.curve_secretkey = self._config.server_secret_key + self._publisher_socket.curve_publickey = self._config.server_public_key + self._publisher_socket.curve_server = True + self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}") + + except zmq.error.ZMQError as e: + logger.error(f"ZMQ error: {e}") + raise + logger.info("ZMQ initialized") async def _publisher(self) -> None: while True: event = await self._events_to_send.get() - await asyncio.gather( - *[client.send(event_to_json(event)) for client in self._clients], - return_exceptions=True, - ) + await self._publisher_socket.send_string(event_to_json(event)) self._events_to_send.task_done() async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None: @@ -206,139 +221,90 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None: def ensemble(self) -> Ensemble: return self._ensemble - @contextmanager - def store_client( - self, websocket: WebSocketServerProtocol - ) -> Generator[None, None, None]: - self._clients.add(websocket) - yield - self._clients.remove(websocket) - - async def handle_client(self, websocket: WebSocketServerProtocol) -> None: - with self.store_client(websocket): - current_snapshot_dict = self._ensemble.snapshot.to_dict() - event: Event = EESnapshot( - snapshot=current_snapshot_dict, ensemble=self.ensemble.id_ - ) - await websocket.send(event_to_json(event)) - - async for raw_msg in websocket: - event = event_from_json(raw_msg) - logger.debug(f"got message from client: {event}") - if type(event) is EEUserCancel: - logger.debug(f"Client {websocket.remote_address} asked to cancel.") - self._signal_cancel() - - elif type(event) is EEUserDone: - logger.debug(f"Client {websocket.remote_address} signalled done.") - self.stop() - - @asynccontextmanager - async def count_dispatcher(self) -> AsyncIterator[None]: - await self._dispatchers_connected.put(None) - yield - await self._dispatchers_connected.get() - self._dispatchers_connected.task_done() - - async def handle_dispatch(self, websocket: WebSocketServerProtocol) -> None: - async with self.count_dispatcher(): + async def listen_for_clients(self) -> None: + while True: try: - async for raw_msg in websocket: - try: - event = dispatch_event_from_json(raw_msg) - if event.ensemble != self.ensemble.id_: - logger.info( - "Got event from evaluator " - f"{event.ensemble}. " - f"Ignoring since I am {self.ensemble.id_}" - ) - continue - if type(event) is ForwardModelStepChecksum: - await self.forward_checksum(event) - else: - await self._events.put(event) - except ValidationError as ex: - logger.warning( - "cannot handle event - " - f"closing connection to dispatcher: {ex}" - ) - await websocket.close( - code=1011, reason=f"failed handling message {raw_msg!r}" + raw_msg = await self._publisher_socket.recv() + # this would inform all subscribers about the snapshot + if raw_msg[0] == 1: + await self._clients_connected.put(None) + current_snapshot_dict = self._ensemble.snapshot.to_dict() + event: Event = EESnapshot( + snapshot=current_snapshot_dict, ensemble=self.ensemble.id_ + ) + await self._publisher_socket.send_string(event_to_json(event)) + + elif raw_msg[0] == 0: + await self._clients_connected.get() + self._clients_connected.task_done() + except zmq.error.ZMQError as e: + if e.errno == zmq.ENOTSOCK: + logger.warning( + "Evaluator publisher closed, no new clients accepted" + ) + else: + logger.error(f"Unexpected error when connecting new clients: {e}") + return + except asyncio.CancelledError: + return + + async def listen_for_messages(self) -> None: + while True: + try: + sender, raw_msg = await self._pull_socket.recv_multipart() + sender = sender.decode("utf-8") + raw_msg = raw_msg.decode("utf-8") + if sender == "client": + event = event_from_json(raw_msg) + if type(event) is EEUserCancel: + logger.debug("Client asked to cancel.") + self._signal_cancel() + elif type(event) is EEUserDone: + logger.debug("Client signalled done.") + self.stop() + elif sender == "dispatch": + event = dispatch_event_from_json(raw_msg) + if event.ensemble != self.ensemble.id_: + logger.info( + "Got event from evaluator " + f"{event.ensemble}. " + f"Ignoring since I am {self.ensemble.id_}" ) - return - - if type(event) in [EnsembleSucceeded, EnsembleFailed]: - return - except ConnectionClosedError as connection_error: - # Dispatchers may close the connection abruptly in the case of - # * flaky network (then the dispatcher will try to reconnect) - # * job being killed due to MAX_RUNTIME - # * job being killed by user - logger.error( - f"a dispatcher abruptly closed a websocket: {connection_error!s}" - ) + continue + if type(event) is ForwardModelStepChecksum: + await self.forward_checksum(event) + else: + await self._events.put(event) + # if type(event) in [EnsembleSucceeded, EnsembleFailed]: + # return + else: + logger.info(f"Connection attempt to unknown sender: {sender}.") + except zmq.error.ZMQError as e: + if e.errno == zmq.ENOTSOCK: + logger.warning( + "Evaluator receiver closed, no new messages are received" + ) + else: + logger.error(f"Unexpected error when listening to messages: {e}") + except asyncio.CancelledError: + return async def forward_checksum(self, event: Event) -> None: # clients still need to receive events via ws await self._events_to_send.put(event) await self._manifest_queue.put(event) - async def connection_handler(self, websocket: WebSocketServerProtocol) -> None: - path = websocket.path - elements = path.split("/") - if elements[1] == "client": - await self.handle_client(websocket) - elif elements[1] == "dispatch": - await self.handle_dispatch(websocket) - else: - logger.info(f"Connection attempt to unknown path: {path}.") - - async def process_request( - self, path: str, request_headers: Headers - ) -> Optional[Tuple[HTTPStatus, HeadersLike, bytes]]: - if request_headers.get("token") != self._config.token: - return HTTPStatus.UNAUTHORIZED, {}, b"" - if path == "/healthcheck": - return HTTPStatus.OK, {}, b"" - return None - async def _server(self) -> None: - async with websockets.serve( - self.connection_handler, - sock=self._config.get_socket(), - ssl=self._config.get_server_ssl_context(), - process_request=self.process_request, - max_queue=None, - max_size=2**26, - ping_timeout=60, - ping_interval=60, - close_timeout=60, - ) as server: - self._server_started.set() - await self._server_done.wait() - server.close(close_connections=False) - if self._dispatchers_connected is not None: - logger.debug( - f"Got done signal. {self._dispatchers_connected.qsize()} " - "dispatchers to disconnect..." - ) - try: # Wait for dispatchers to disconnect - await asyncio.wait_for( - self._dispatchers_connected.join(), timeout=20 - ) - except asyncio.TimeoutError: - logger.debug("Timed out waiting for dispatchers to disconnect") - else: - logger.debug("Got done signal. No dispatchers connected") - - logger.debug("Sending termination-message to clients...") - - await self._events.join() - await self._complete_batch.wait() - await self._batch_processing_queue.join() - event = EETerminated(ensemble=self._ensemble.id_) - await self._events_to_send.put(event) - await self._events_to_send.join() + await self._server_done.wait() + await self._events.join() + await self._complete_batch.wait() + await self._batch_processing_queue.join() + event = EETerminated(ensemble=self._ensemble.id_) + await self._events_to_send.put(event) + await self._events_to_send.join() + await self._clients_connected.join() + self._pull_socket.close() + self._publisher_socket.close() logger.debug("Async server exiting.") def stop(self) -> None: @@ -365,6 +331,7 @@ async def _start_running(self) -> None: if not self._config: raise ValueError("no config for evaluator") self._loop = asyncio.get_running_loop() + await self._initialize_zmq() self._ee_tasks = [ asyncio.create_task(self._server(), name="server_task"), asyncio.create_task( @@ -372,9 +339,9 @@ async def _start_running(self) -> None: ), asyncio.create_task(self._process_event_buffer(), name="processing_task"), asyncio.create_task(self._publisher(), name="publisher_task"), + asyncio.create_task(self.listen_for_messages(), name="listener_task"), + asyncio.create_task(self.listen_for_clients(), name="client_task"), ] - # now we wait for the server to actually start - await self._server_started.wait() self._ee_tasks.append( asyncio.create_task( @@ -426,7 +393,11 @@ async def _monitor_and_handle_tasks(self) -> None: if stop_timeout_task: stop_timeout_task.cancel() return - elif task.get_name() == "ensemble_task": + elif task.get_name() in [ + "ensemble_task", + "listener_task", + "client_task", + ]: stop_timeout_task = asyncio.create_task( self._wait_for_stopped_server() ) diff --git a/src/ert/ensemble_evaluator/evaluator_connection_info.py b/src/ert/ensemble_evaluator/evaluator_connection_info.py index bd48e08e4a1..399f13da1d8 100644 --- a/src/ert/ensemble_evaluator/evaluator_connection_info.py +++ b/src/ert/ensemble_evaluator/evaluator_connection_info.py @@ -6,18 +6,7 @@ class EvaluatorConnectionInfo: """Read only server-info""" - url: str + push_pull_uri: str + pub_sub_uri: str cert: Optional[Union[str, bytes]] = None token: Optional[str] = None - - @property - def dispatch_uri(self) -> str: - return f"{self.url}/dispatch" - - @property - def client_uri(self) -> str: - return f"{self.url}/client" - - @property - def result_uri(self) -> str: - return f"{self.url}/result" diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 93bc2ec5e1e..c2dd458c2ea 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -1,12 +1,12 @@ +from __future__ import annotations + import asyncio import logging import ssl import uuid from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Optional, Union -from aiohttp import ClientError -from websockets import ConnectionClosed, Headers, WebSocketClientProtocol -from websockets.client import connect +import zmq.asyncio from _ert.events import ( EETerminated, @@ -16,7 +16,6 @@ event_from_json, event_to_json, ) -from ert.ensemble_evaluator._wait_for_evaluator import wait_for_evaluator if TYPE_CHECKING: from ert.ensemble_evaluator.evaluator_connection_info import EvaluatorConnectionInfo @@ -36,11 +35,13 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None: self._ee_con_info = ee_con_info self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0] self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue() - self._connection: Optional[WebSocketClientProtocol] = None self._receiver_task: Optional[asyncio.Task[None]] = None self._connected: asyncio.Event = asyncio.Event() self._connection_timeout: float = 120.0 self._receiver_timeout: float = 60.0 + self._zmq_context = zmq.asyncio.Context() # type: ignore + # self._listen_socket: zmq.asyncio.Socket | None = None + # self._push_socket: zmq.asyncio.Socket | None = None async def __aenter__(self) -> "Monitor": self._receiver_task = asyncio.create_task(self._receiver()) @@ -57,6 +58,9 @@ async def __aenter__(self) -> "Monitor": async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self._receiver_task: + if self._listen_socket and self._push_socket: + self._listen_socket.close() + self._push_socket.close() if not self._receiver_task.done(): self._receiver_task.cancel() # we are done and not interested in errors when cancelling @@ -65,27 +69,24 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None return_exceptions=True, ) - if self._connection: - await self._connection.close() - async def signal_cancel(self) -> None: - if not self._connection: - return await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} asking server to cancel...") cancel_event = EEUserCancel(monitor=self._id) - await self._connection.send(event_to_json(cancel_event)) + await self._push_socket.send_multipart( + [b"client", event_to_json(cancel_event).encode("utf-8")] + ) logger.debug(f"monitor-{self._id} asked server to cancel") async def signal_done(self) -> None: - if not self._connection: - return await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} informing server monitor is done...") done_event = EEUserDone(monitor=self._id) - await self._connection.send(event_to_json(done_event)) + await self._push_socket.send_multipart( + [b"client", event_to_json(done_event).encode("utf-8")] + ) logger.debug(f"monitor-{self._id} informed server monitor is done") async def track( @@ -124,36 +125,35 @@ async def _receiver(self) -> None: if self._ee_con_info.cert: tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) tls.load_verify_locations(cadata=self._ee_con_info.cert) - headers = Headers() - if self._ee_con_info.token: - headers["token"] = self._ee_con_info.token - - await wait_for_evaluator( - base_url=self._ee_con_info.url, - token=self._ee_con_info.token, - cert=self._ee_con_info.cert, - timeout=5, - ) - async for conn in connect( - self._ee_con_info.client_uri, - ssl=tls, - extra_headers=headers, - max_size=2**26, - max_queue=500, - open_timeout=5, - ping_timeout=60, - ping_interval=60, - close_timeout=60, - ): + + self._listen_socket = self._zmq_context.socket(zmq.SUB) + self._push_socket = self._zmq_context.socket(zmq.PUSH) + + if self._ee_con_info.token is not None: + client_public, client_secret = zmq.curve_keypair() + self._listen_socket.curve_secretkey = client_secret + self._listen_socket.curve_publickey = client_public + self._listen_socket.curve_serverkey = self._ee_con_info.token.encode( + "utf-8" + ) + self._push_socket.curve_secretkey = client_secret + self._push_socket.curve_publickey = client_public + self._push_socket.curve_serverkey = self._ee_con_info.token.encode("utf-8") + + self._listen_socket.connect(self._ee_con_info.pub_sub_uri) + self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "") + + self._push_socket.connect(self._ee_con_info.push_pull_uri) + self._connected.set() + + while True: try: - self._connection = conn - self._connected.set() - async for raw_msg in self._connection: - event = event_from_json(raw_msg) - await self._event_queue.put(event) - except (ConnectionRefusedError, ConnectionClosed, ClientError) as exc: - self._connection = None - self._connected.clear() + raw_msg = await self._listen_socket.recv_string() + event = event_from_json(raw_msg) + await self._event_queue.put(event) + except zmq.ZMQError as exc: + # Handle disconnection or other ZMQ errors (reconnect or log) logger.debug( - f"Monitor connection to EnsembleEvaluator went down, reconnecting: {exc}" + f"ZeroMQ connection to EnsembleEvaluator went down, reconnecting: {exc}" ) + await asyncio.sleep(1) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index b000d5fa491..557c737c939 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -26,12 +26,7 @@ import numpy as np -from _ert.events import ( - EESnapshot, - EESnapshotUpdate, - EETerminated, - Event, -) +from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event from ert.analysis import ( AnalysisEvent, AnalysisStatusEvent, @@ -507,7 +502,6 @@ async def run_monitor( event, iteration, ) - if event.snapshot.get(STATUS) in [ ENSEMBLE_STATE_STOPPED, ENSEMBLE_STATE_FAILED, diff --git a/src/ert/shared/net_utils.py b/src/ert/shared/net_utils.py index 66c12aef6c9..2cf467481ac 100644 --- a/src/ert/shared/net_utils.py +++ b/src/ert/shared/net_utils.py @@ -111,6 +111,7 @@ def _bind_socket( if will_close_then_reopen_socket: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) else: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)