From b7c96a22219515414e18b710d3bfe1e1ebc36a2d Mon Sep 17 00:00:00 2001 From: xjules Date: Mon, 11 Nov 2024 15:30:45 +0100 Subject: [PATCH 01/16] Replace client with zmq push --- src/_ert/forward_model_runner/client.py | 139 ++++++++---------------- 1 file changed, 44 insertions(+), 95 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 2566ca005f8..c417067d2bb 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -1,19 +1,9 @@ -import asyncio import logging -import ssl -from typing import Any, AnyStr, Optional, Union +import time +from typing import Any, Optional, Union +import zmq from typing_extensions import Self -from websockets.client import WebSocketClientProtocol, connect -from websockets.datastructures import Headers -from websockets.exceptions import ( - ConnectionClosedError, - ConnectionClosedOK, - InvalidHandshake, - InvalidURI, -) - -from _ert.async_utils import new_event_loop logger = logging.getLogger(__name__) @@ -35,18 +25,8 @@ def __enter__(self) -> Self: return self def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None: - if self.websocket is not None: - self.loop.run_until_complete(self.websocket.close()) - self.loop.close() - - async def __aenter__(self) -> "Client": - return self - - async def __aexit__( - self, exc_type: Any, exc_value: Any, exc_traceback: Any - ) -> None: - if self.websocket is not None: - await self.websocket.close() + self.socket.close() + self.context.term() def __init__( self, @@ -60,79 +40,48 @@ def __init__( max_retries = self.DEFAULT_MAX_RETRIES if timeout_multiplier is None: timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER - if url is None: - raise ValueError("url was None") self.url = url self.token = token - self._extra_headers = Headers() - if token is not None: - self._extra_headers["token"] = token - - # Mimics the behavior of the ssl argument when connection to - # websockets. If none is specified it will deduce based on the url, - # if True it will enforce TLS, and if you want to use self signed - # certificates you need to pass an ssl_context with the certificate - # loaded. - self._ssl_context: Optional[Union[bool, ssl.SSLContext]] = None - if cert is not None: - self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - self._ssl_context.load_verify_locations(cadata=cert) - elif url.startswith("wss"): - self._ssl_context = True + + # Set up ZeroMQ context and socket + self.context = zmq.Context() + self.socket = self.context.socket(zmq.PUSH) + + if cert: + client_public, client_secret = zmq.curve_keypair() + server_public = cert + + self.socket.curve_secretkey = client_secret + self.socket.curve_publickey = client_public + self.socket.curve_serverkey = server_public self._max_retries = max_retries self._timeout_multiplier = timeout_multiplier - self.websocket: Optional[WebSocketClientProtocol] = None - self.loop = new_event_loop() - - async def get_websocket(self) -> WebSocketClientProtocol: - return await connect( - self.url, - ssl=self._ssl_context, - extra_headers=self._extra_headers, - open_timeout=self.CONNECTION_TIMEOUT, - ping_timeout=self.CONNECTION_TIMEOUT, - ping_interval=self.CONNECTION_TIMEOUT, - close_timeout=self.CONNECTION_TIMEOUT, - ) - - async def _send(self, msg: AnyStr) -> None: - for retry in range(self._max_retries + 1): + + self.reconnect() + + def reconnect(self): + """Connect to the server with exponential backoff.""" + retries = self._max_retries + while retries > 0: try: - if self.websocket is None: - self.websocket = await self.get_websocket() - await self.websocket.send(msg) - return - except ConnectionClosedOK as exception: - _error_msg = ( - f"Connection closed received from the server {self.url}! " - f" Exception from {type(exception)}: {exception!s}" - ) - raise ClientConnectionClosedOK(_error_msg) from exception - except ( - InvalidHandshake, - InvalidURI, - OSError, - asyncio.TimeoutError, - ) as exception: - if retry == self._max_retries: - _error_msg = ( - f"Not able to establish the " - f"websocket connection {self.url}! Max retries reached!" - " Check for firewall issues." - f" Exception from {type(exception)}: {exception!s}" - ) - raise ClientConnectionError(_error_msg) from exception - except ConnectionClosedError as exception: - if retry == self._max_retries: - _error_msg = ( - f"Not been able to send the event" - f" to {self.url}! Max retries reached!" - f" Exception from {type(exception)}: {exception!s}" - ) - raise ClientConnectionError(_error_msg) from exception - await asyncio.sleep(0.2 + self._timeout_multiplier * retry) - self.websocket = None - - def send(self, msg: AnyStr) -> None: - self.loop.run_until_complete(self._send(msg)) + self.socket.connect(self.url) + break + except zmq.ZMQError as e: + print(f"Failed to connect to {self.url}: {e}") + retries -= 1 + if retries == 0: + raise e + # Exponential backoff + sleep_time = self._timeout_multiplier * (self._max_retries - retries) + time.sleep(sleep_time) + + def push(self, message): + try: + if self.token: + message = f"{self.token}:{message}" + self.socket.send_string(message) + except zmq.ZMQError as e: + print(f"Failed to send message: {e}") + self.reconnect() + self.socket.send_string(message) From 519b9affd5a8eb70da842983c1f6a62bdb4f7513 Mon Sep 17 00:00:00 2001 From: xjules Date: Tue, 12 Nov 2024 09:15:07 +0100 Subject: [PATCH 02/16] WIP: replace websockets with zmq in ensemble evaluator --- src/ert/ensemble_evaluator/evaluator.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 3855ec85cac..f0759eec215 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -23,6 +23,7 @@ ) import websockets +import zmq.asyncio from pydantic_core._pydantic_core import ValidationError from websockets.datastructures import Headers, HeadersLike from websockets.exceptions import ConnectionClosedError @@ -89,13 +90,19 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._batching_interval: int = 2 self._complete_batch: asyncio.Event = asyncio.Event() + async def _initialize_zmq(self) -> None: + self._zmq_context = zmq.asyncio.Context() + self._receiver_socket = self._zmq_context.socket(zmq.PULL) + self._publisher_socket = self._zmq_context.socket(zmq.PUB) + async def _publisher(self) -> None: while True: event = await self._events_to_send.get() - await asyncio.gather( - *[client.send(event_to_json(event)) for client in self._clients], - return_exceptions=True, - ) + # await asyncio.gather( + # *[client.send(event_to_json(event)) for client in self._clients], + # return_exceptions=True, + # ) + self._publisher_socket.send_json(event_to_json(event)) self._events_to_send.task_done() async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None: From 954bc020815401fc18117daf893847b2dc8fb13a Mon Sep 17 00:00:00 2001 From: xjules Date: Tue, 12 Nov 2024 15:17:57 +0100 Subject: [PATCH 03/16] WIP: evaluator -> zmq --- src/ert/ensemble_evaluator/evaluator.py | 182 ++++++------------------ 1 file changed, 43 insertions(+), 139 deletions(-) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index f0759eec215..05f390d4326 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -1,16 +1,14 @@ +from __future__ import annotations + import asyncio import datetime import logging import traceback -from contextlib import asynccontextmanager, contextmanager -from http import HTTPStatus from typing import ( Any, - AsyncIterator, Awaitable, Callable, Dict, - Generator, Iterable, List, Optional, @@ -22,15 +20,9 @@ get_args, ) -import websockets import zmq.asyncio -from pydantic_core._pydantic_core import ValidationError -from websockets.datastructures import Headers, HeadersLike -from websockets.exceptions import ConnectionClosedError -from websockets.server import WebSocketServerProtocol from _ert.events import ( - EESnapshot, EESnapshotUpdate, EETerminated, EEUserCancel, @@ -71,7 +63,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._loop: Optional[asyncio.AbstractEventLoop] = None - self._clients: Set[WebSocketServerProtocol] = set() self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue() self._events: asyncio.Queue[Event] = asyncio.Queue() @@ -89,19 +80,16 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._max_batch_size: int = 500 self._batching_interval: int = 2 self._complete_batch: asyncio.Event = asyncio.Event() + self._zmq_context: zmq.asyncio.Context | None = None async def _initialize_zmq(self) -> None: - self._zmq_context = zmq.asyncio.Context() - self._receiver_socket = self._zmq_context.socket(zmq.PULL) - self._publisher_socket = self._zmq_context.socket(zmq.PUB) + self._zmq_context = zmq.asyncio.Context() # type: ignore + self._listen_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL) + self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUB) async def _publisher(self) -> None: while True: event = await self._events_to_send.get() - # await asyncio.gather( - # *[client.send(event_to_json(event)) for client in self._clients], - # return_exceptions=True, - # ) self._publisher_socket.send_json(event_to_json(event)) self._events_to_send.task_done() @@ -213,139 +201,54 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None: def ensemble(self) -> Ensemble: return self._ensemble - @contextmanager - def store_client( - self, websocket: WebSocketServerProtocol - ) -> Generator[None, None, None]: - self._clients.add(websocket) - yield - self._clients.remove(websocket) - - async def handle_client(self, websocket: WebSocketServerProtocol) -> None: - with self.store_client(websocket): - current_snapshot_dict = self._ensemble.snapshot.to_dict() - event: Event = EESnapshot( - snapshot=current_snapshot_dict, ensemble=self.ensemble.id_ - ) - await websocket.send(event_to_json(event)) - - async for raw_msg in websocket: + async def listen_for_messages(self) -> None: + while True: + sender, raw_msg = await self._listen_socket.recv_multipart() + sender = sender.decode("utf-8") + if sender == "client": event = event_from_json(raw_msg) - logger.debug(f"got message from client: {event}") if type(event) is EEUserCancel: - logger.debug(f"Client {websocket.remote_address} asked to cancel.") + logger.debug("Client asked to cancel.") self._signal_cancel() - elif type(event) is EEUserDone: - logger.debug(f"Client {websocket.remote_address} signalled done.") + logger.debug("Client signalled done.") self.stop() - - @asynccontextmanager - async def count_dispatcher(self) -> AsyncIterator[None]: - await self._dispatchers_connected.put(None) - yield - await self._dispatchers_connected.get() - self._dispatchers_connected.task_done() - - async def handle_dispatch(self, websocket: WebSocketServerProtocol) -> None: - async with self.count_dispatcher(): - try: - async for raw_msg in websocket: - try: - event = dispatch_event_from_json(raw_msg) - if event.ensemble != self.ensemble.id_: - logger.info( - "Got event from evaluator " - f"{event.ensemble}. " - f"Ignoring since I am {self.ensemble.id_}" - ) - continue - if type(event) is ForwardModelStepChecksum: - await self.forward_checksum(event) - else: - await self._events.put(event) - except ValidationError as ex: - logger.warning( - "cannot handle event - " - f"closing connection to dispatcher: {ex}" - ) - await websocket.close( - code=1011, reason=f"failed handling message {raw_msg!r}" - ) - return - - if type(event) in [EnsembleSucceeded, EnsembleFailed]: - return - except ConnectionClosedError as connection_error: - # Dispatchers may close the connection abruptly in the case of - # * flaky network (then the dispatcher will try to reconnect) - # * job being killed due to MAX_RUNTIME - # * job being killed by user - logger.error( - f"a dispatcher abruptly closed a websocket: {connection_error!s}" - ) + elif sender == "dispatch": + event = dispatch_event_from_json(raw_msg) + if event.ensemble != self.ensemble.id_: + logger.info( + "Got event from evaluator " + f"{event.ensemble}. " + f"Ignoring since I am {self.ensemble.id_}" + ) + continue + if type(event) is ForwardModelStepChecksum: + await self.forward_checksum(event) + else: + await self._events.put(event) + if type(event) in [EnsembleSucceeded, EnsembleFailed]: + return + else: + logger.info(f"Connection attempt to unknown sender: {sender}.") async def forward_checksum(self, event: Event) -> None: # clients still need to receive events via ws await self._events_to_send.put(event) await self._manifest_queue.put(event) - async def connection_handler(self, websocket: WebSocketServerProtocol) -> None: - path = websocket.path - elements = path.split("/") - if elements[1] == "client": - await self.handle_client(websocket) - elif elements[1] == "dispatch": - await self.handle_dispatch(websocket) - else: - logger.info(f"Connection attempt to unknown path: {path}.") - - async def process_request( - self, path: str, request_headers: Headers - ) -> Optional[Tuple[HTTPStatus, HeadersLike, bytes]]: - if request_headers.get("token") != self._config.token: - return HTTPStatus.UNAUTHORIZED, {}, b"" - if path == "/healthcheck": - return HTTPStatus.OK, {}, b"" - return None - async def _server(self) -> None: - async with websockets.serve( - self.connection_handler, - sock=self._config.get_socket(), - ssl=self._config.get_server_ssl_context(), - process_request=self.process_request, - max_queue=None, - max_size=2**26, - ping_timeout=60, - ping_interval=60, - close_timeout=60, - ) as server: - self._server_started.set() - await self._server_done.wait() - server.close(close_connections=False) - if self._dispatchers_connected is not None: - logger.debug( - f"Got done signal. {self._dispatchers_connected.qsize()} " - "dispatchers to disconnect..." - ) - try: # Wait for dispatchers to disconnect - await asyncio.wait_for( - self._dispatchers_connected.join(), timeout=20 - ) - except asyncio.TimeoutError: - logger.debug("Timed out waiting for dispatchers to disconnect") - else: - logger.debug("Got done signal. No dispatchers connected") - - logger.debug("Sending termination-message to clients...") - - await self._events.join() - await self._complete_batch.wait() - await self._batch_processing_queue.join() - event = EETerminated(ensemble=self._ensemble.id_) - await self._events_to_send.put(event) - await self._events_to_send.join() + await self._initialize_zmq() + self._server_started.set() + await self._server_done.wait() + + await self._events.join() + await self._complete_batch.wait() + await self._batch_processing_queue.join() + event = EETerminated(ensemble=self._ensemble.id_) + await self._events_to_send.put(event) + await self._events_to_send.join() + self._listen_socket.close() + self._publisher_socket.close() logger.debug("Async server exiting.") def stop(self) -> None: @@ -379,6 +282,7 @@ async def _start_running(self) -> None: ), asyncio.create_task(self._process_event_buffer(), name="processing_task"), asyncio.create_task(self._publisher(), name="publisher_task"), + asyncio.create_task(self.listen_for_messages(), name="listener_task"), ] # now we wait for the server to actually start await self._server_started.wait() From 30ce4b918489dec2a411b519e71087b866f86ab3 Mon Sep 17 00:00:00 2001 From: xjules Date: Tue, 12 Nov 2024 15:41:31 +0100 Subject: [PATCH 04/16] Replace websockets with zmq in monitor --- src/ert/ensemble_evaluator/monitor.py | 67 ++++++++++----------------- 1 file changed, 24 insertions(+), 43 deletions(-) diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 93bc2ec5e1e..d78d2252d90 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -4,9 +4,7 @@ import uuid from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Optional, Union -from aiohttp import ClientError -from websockets import ConnectionClosed, Headers, WebSocketClientProtocol -from websockets.client import connect +import zmq.asyncio from _ert.events import ( EETerminated, @@ -16,7 +14,6 @@ event_from_json, event_to_json, ) -from ert.ensemble_evaluator._wait_for_evaluator import wait_for_evaluator if TYPE_CHECKING: from ert.ensemble_evaluator.evaluator_connection_info import EvaluatorConnectionInfo @@ -36,11 +33,16 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None: self._ee_con_info = ee_con_info self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0] self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue() - self._connection: Optional[WebSocketClientProtocol] = None self._receiver_task: Optional[asyncio.Task[None]] = None self._connected: asyncio.Event = asyncio.Event() self._connection_timeout: float = 120.0 self._receiver_timeout: float = 60.0 + self._zmq_context = zmq.asyncio.Context() # type: ignore + self._listen_socket: zmq.asyncio.Socket = self._zmq_context.socket( + zmq.SUBSCRIBE + ) + self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "") + self._push_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUSH) async def __aenter__(self) -> "Monitor": self._receiver_task = asyncio.create_task(self._receiver()) @@ -65,27 +67,27 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None return_exceptions=True, ) - if self._connection: - await self._connection.close() + self._listen_socket.close() + self._push_socket.close() async def signal_cancel(self) -> None: - if not self._connection: - return await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} asking server to cancel...") cancel_event = EEUserCancel(monitor=self._id) - await self._connection.send(event_to_json(cancel_event)) + await self._push_socket.send_multipart( + [b"client", event_to_json(cancel_event).encode()] + ) logger.debug(f"monitor-{self._id} asked server to cancel") async def signal_done(self) -> None: - if not self._connection: - return await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} informing server monitor is done...") done_event = EEUserDone(monitor=self._id) - await self._connection.send(event_to_json(done_event)) + await self._push_socket.send_multipart( + [b"client", event_to_json(done_event).encode()] + ) logger.debug(f"monitor-{self._id} informed server monitor is done") async def track( @@ -124,36 +126,15 @@ async def _receiver(self) -> None: if self._ee_con_info.cert: tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) tls.load_verify_locations(cadata=self._ee_con_info.cert) - headers = Headers() - if self._ee_con_info.token: - headers["token"] = self._ee_con_info.token - - await wait_for_evaluator( - base_url=self._ee_con_info.url, - token=self._ee_con_info.token, - cert=self._ee_con_info.cert, - timeout=5, - ) - async for conn in connect( - self._ee_con_info.client_uri, - ssl=tls, - extra_headers=headers, - max_size=2**26, - max_queue=500, - open_timeout=5, - ping_timeout=60, - ping_interval=60, - close_timeout=60, - ): + + while True: try: - self._connection = conn - self._connected.set() - async for raw_msg in self._connection: - event = event_from_json(raw_msg) - await self._event_queue.put(event) - except (ConnectionRefusedError, ConnectionClosed, ClientError) as exc: - self._connection = None - self._connected.clear() + raw_msg = await self._listen_socket.recv_string() + event = event_from_json(raw_msg) + await self._event_queue.put(event) + except (zmq.ZMQError, asyncio.CancelledError) as exc: + # Handle disconnection or other ZMQ errors (reconnect or log) logger.debug( - f"Monitor connection to EnsembleEvaluator went down, reconnecting: {exc}" + f"ZeroMQ connection to EnsembleEvaluator went down, reconnecting: {exc}" ) + await asyncio.sleep(1) From c592a7d99351ff6113ccf9127fc32de7fd92dffa Mon Sep 17 00:00:00 2001 From: xjules Date: Tue, 12 Nov 2024 15:48:27 +0100 Subject: [PATCH 05/16] Update client --- src/_ert/forward_model_runner/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index c417067d2bb..9987a266734 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -68,7 +68,7 @@ def reconnect(self): self.socket.connect(self.url) break except zmq.ZMQError as e: - print(f"Failed to connect to {self.url}: {e}") + logger.warning(f"Failed to connect to {self.url}: {e}") retries -= 1 if retries == 0: raise e @@ -80,8 +80,8 @@ def push(self, message): try: if self.token: message = f"{self.token}:{message}" - self.socket.send_string(message) + self.socket.send_multipart([b"dispatch", message.encode()]) except zmq.ZMQError as e: - print(f"Failed to send message: {e}") + logger.warning(f"Failed to send message: {e}") self.reconnect() - self.socket.send_string(message) + self.socket.send_multipart([b"dispatch", message.encode()]) From 6ddf210dcfd5c91f034c02dce7aff947000f8867 Mon Sep 17 00:00:00 2001 From: xjules Date: Wed, 13 Nov 2024 09:46:50 +0100 Subject: [PATCH 06/16] Update EvaluatorServerConfig to contain zmq connection info --- src/_ert/forward_model_runner/client.py | 3 +++ src/ert/ensemble_evaluator/_ensemble.py | 8 ++------ src/ert/ensemble_evaluator/config.py | 17 ++++++++++++----- src/ert/ensemble_evaluator/evaluator.py | 8 +++++--- .../evaluator_connection_info.py | 15 ++------------- src/ert/ensemble_evaluator/monitor.py | 2 ++ 6 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 9987a266734..553c15d2e5c 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -46,6 +46,9 @@ def __init__( # Set up ZeroMQ context and socket self.context = zmq.Context() self.socket = self.context.socket(zmq.PUSH) + # self.socket.setsockopt(zmq.LINGER, 0) + # self.socket.setsockopt(zmq.SNDTIMEO, self.CONNECTION_TIMEOUT * 1000) + self.socket.connect(url) if cert: client_public, client_secret = zmq.curve_keypair() diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index ecc1d5c81d5..f752ded6601 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -33,11 +33,7 @@ from ._wait_for_evaluator import wait_for_evaluator from .config import EvaluatorServerConfig -from .snapshot import ( - EnsembleSnapshot, - FMStepSnapshot, - RealizationSnapshot, -) +from .snapshot import EnsembleSnapshot, FMStepSnapshot, RealizationSnapshot from .state import ( ENSEMBLE_STATE_CANCELLED, ENSEMBLE_STATE_FAILED, @@ -282,7 +278,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches max_running=self._queue_config.max_running, submit_sleep=self._queue_config.submit_sleep, ens_id=self.id_, - ee_uri=self._config.dispatch_uri, + ee_uri=self._config.get_connection_info().push_pull_uri, ee_cert=self._config.cert, ee_token=self._config.token, ) diff --git a/src/ert/ensemble_evaluator/config.py b/src/ert/ensemble_evaluator/config.py index 79c127cccdb..77cba269f87 100644 --- a/src/ert/ensemble_evaluator/config.py +++ b/src/ert/ensemble_evaluator/config.py @@ -132,10 +132,16 @@ def __init__( custom_range=custom_port_range, custom_host=custom_host ) host, port = self._socket_handle.getsockname() - self.protocol = "wss" if generate_cert else "ws" - self.url = f"{self.protocol}://{host}:{port}" - self.client_uri = f"{self.url}/client" - self.dispatch_uri = f"{self.url}/dispatch" + self.host = host + self.pub_sub_port = port + host, port = self._socket_handle.getsockname() + self.push_pull_port = port + + # self.protocol = "wss" if generate_cert else "ws" + # self.url = f"{self.protocol}://{host}:{port}" + # self.client_uri = f"{self.url}/client" + # self.dispatch_uri = f"{self.url}/dispatch" + if generate_cert: cert, key, pw = _generate_certificate(host) else: @@ -151,7 +157,8 @@ def get_socket(self) -> socket.socket: def get_connection_info(self) -> EvaluatorConnectionInfo: return EvaluatorConnectionInfo( - self.url, + f"tcp://{self.host}:{self.push_pull_port}", + f"tcp://{self.host}:{self.pub_sub_port}", self.cert, self.token, ) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 05f390d4326..8a77038a20a 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -84,8 +84,10 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): async def _initialize_zmq(self) -> None: self._zmq_context = zmq.asyncio.Context() # type: ignore - self._listen_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL) + self._pull_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL) + self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUB) + self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}") async def _publisher(self) -> None: while True: @@ -203,7 +205,7 @@ def ensemble(self) -> Ensemble: async def listen_for_messages(self) -> None: while True: - sender, raw_msg = await self._listen_socket.recv_multipart() + sender, raw_msg = await self._pull_socket.recv_multipart() sender = sender.decode("utf-8") if sender == "client": event = event_from_json(raw_msg) @@ -247,7 +249,7 @@ async def _server(self) -> None: event = EETerminated(ensemble=self._ensemble.id_) await self._events_to_send.put(event) await self._events_to_send.join() - self._listen_socket.close() + self._pull_socket.close() self._publisher_socket.close() logger.debug("Async server exiting.") diff --git a/src/ert/ensemble_evaluator/evaluator_connection_info.py b/src/ert/ensemble_evaluator/evaluator_connection_info.py index bd48e08e4a1..399f13da1d8 100644 --- a/src/ert/ensemble_evaluator/evaluator_connection_info.py +++ b/src/ert/ensemble_evaluator/evaluator_connection_info.py @@ -6,18 +6,7 @@ class EvaluatorConnectionInfo: """Read only server-info""" - url: str + push_pull_uri: str + pub_sub_uri: str cert: Optional[Union[str, bytes]] = None token: Optional[str] = None - - @property - def dispatch_uri(self) -> str: - return f"{self.url}/dispatch" - - @property - def client_uri(self) -> str: - return f"{self.url}/client" - - @property - def result_uri(self) -> str: - return f"{self.url}/result" diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index d78d2252d90..6c162d80822 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -127,6 +127,8 @@ async def _receiver(self) -> None: tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) tls.load_verify_locations(cadata=self._ee_con_info.cert) + self._listen_socket.connect(self._ee_con_info.pub_sub_uri) + self._push_socket.connect(self._ee_con_info.push_pull_uri) while True: try: raw_msg = await self._listen_socket.recv_string() From c469200743778502b09b2983f962461de05afef9 Mon Sep 17 00:00:00 2001 From: xjules Date: Wed, 13 Nov 2024 10:31:30 +0100 Subject: [PATCH 07/16] Update monitor with proper initialization of sockets --- src/ert/ensemble_evaluator/monitor.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 6c162d80822..9c9315dbc9b 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import logging import ssl @@ -38,11 +40,8 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None: self._connection_timeout: float = 120.0 self._receiver_timeout: float = 60.0 self._zmq_context = zmq.asyncio.Context() # type: ignore - self._listen_socket: zmq.asyncio.Socket = self._zmq_context.socket( - zmq.SUBSCRIBE - ) - self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "") - self._push_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUSH) + self._listen_socket: zmq.asyncio.Socket | None = None + self._push_socket: zmq.asyncio.Socket | None = None async def __aenter__(self) -> "Monitor": self._receiver_task = asyncio.create_task(self._receiver()) @@ -127,7 +126,11 @@ async def _receiver(self) -> None: tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) tls.load_verify_locations(cadata=self._ee_con_info.cert) + self._listen_socket = self._zmq_context.socket(zmq.SUB) self._listen_socket.connect(self._ee_con_info.pub_sub_uri) + self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "") + + self._push_socket = self._zmq_context.socket(zmq.PUSH) self._push_socket.connect(self._ee_con_info.push_pull_uri) while True: try: From 5ebe5d54435940dc198962d98f627630a6840cbf Mon Sep 17 00:00:00 2001 From: xjules Date: Wed, 13 Nov 2024 11:36:46 +0100 Subject: [PATCH 08/16] Updates to client --- src/_ert/forward_model_runner/client.py | 22 +++++++++++----------- src/ert/ensemble_evaluator/evaluator.py | 4 ++++ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 553c15d2e5c..baf8c5054f6 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -50,13 +50,13 @@ def __init__( # self.socket.setsockopt(zmq.SNDTIMEO, self.CONNECTION_TIMEOUT * 1000) self.socket.connect(url) - if cert: - client_public, client_secret = zmq.curve_keypair() - server_public = cert + # if cert: + # client_public, client_secret = zmq.curve_keypair() + # server_public = cert - self.socket.curve_secretkey = client_secret - self.socket.curve_publickey = client_public - self.socket.curve_serverkey = server_public + # self.socket.curve_secretkey = client_secret + # self.socket.curve_publickey = client_public + # self.socket.curve_serverkey = server_public self._max_retries = max_retries self._timeout_multiplier = timeout_multiplier @@ -79,12 +79,12 @@ def reconnect(self): sleep_time = self._timeout_multiplier * (self._max_retries - retries) time.sleep(sleep_time) - def push(self, message): + def send(self, message): try: - if self.token: - message = f"{self.token}:{message}" - self.socket.send_multipart([b"dispatch", message.encode()]) + # if self.token: + # message = f"{self.token}:{message}" + self.socket.send_multipart([b"dispatch", message.encode("utf-8")]) except zmq.ZMQError as e: logger.warning(f"Failed to send message: {e}") self.reconnect() - self.socket.send_multipart([b"dispatch", message.encode()]) + self.socket.send_multipart([b"dispatch", message.encode("utf-8")]) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 8a77038a20a..7f87dc9b04f 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -88,6 +88,7 @@ async def _initialize_zmq(self) -> None: self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUB) self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}") + print("ZMQ initialized") async def _publisher(self) -> None: while True: @@ -207,7 +208,9 @@ async def listen_for_messages(self) -> None: while True: sender, raw_msg = await self._pull_socket.recv_multipart() sender = sender.decode("utf-8") + raw_msg = raw_msg.decode("utf-8") if sender == "client": + print(f"Got client {raw_msg=}") event = event_from_json(raw_msg) if type(event) is EEUserCancel: logger.debug("Client asked to cancel.") @@ -216,6 +219,7 @@ async def listen_for_messages(self) -> None: logger.debug("Client signalled done.") self.stop() elif sender == "dispatch": + print(f"Got dispatch {raw_msg=}") event = dispatch_event_from_json(raw_msg) if event.ensemble != self.ensemble.id_: logger.info( From c3bb5250d89f4ee9c9551a1dc778fd01984c31df Mon Sep 17 00:00:00 2001 From: xjules Date: Wed, 13 Nov 2024 14:16:51 +0100 Subject: [PATCH 09/16] Don't use wait_for_evaluator --- src/ert/ensemble_evaluator/_ensemble.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index f752ded6601..e682563dc9f 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -31,7 +31,6 @@ from ert.run_arg import RunArg from ert.scheduler import Scheduler, create_driver -from ._wait_for_evaluator import wait_for_evaluator from .config import EvaluatorServerConfig from .snapshot import EnsembleSnapshot, FMStepSnapshot, RealizationSnapshot from .state import ( @@ -200,8 +199,8 @@ async def send_event( cert: Optional[Union[str, bytes]] = None, retries: int = 10, ) -> None: - async with Client(url, token, cert, max_retries=retries) as client: - await client._send(event_to_json(event)) + with Client(url, token, cert, max_retries=retries) as client: + client.send(event_to_json(event)) def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]: def event_builder(status: str) -> Event: @@ -226,16 +225,16 @@ async def evaluate( ce_unary_send_method_name, partialmethod( self.__class__.send_event, - self._config.dispatch_uri, + self._config.get_connection_info().push_pull_uri, token=self._config.token, cert=self._config.cert, ), ) - await wait_for_evaluator( - base_url=self._config.url, - token=self._config.token, - cert=self._config.cert, - ) + # await wait_for_evaluator( + # base_url=self._config.url, + # token=self._config.token, + # cert=self._config.cert, + # ) await self._evaluate_inner( event_unary_send=getattr(self, ce_unary_send_method_name), scheduler_queue=scheduler_queue, From 6c5b21b07491904ddf95befe21495f71ae453fdf Mon Sep 17 00:00:00 2001 From: xjules Date: Wed, 13 Nov 2024 15:03:48 +0100 Subject: [PATCH 10/16] Remove server_started --- src/ert/ensemble_evaluator/evaluator.py | 9 ++------- src/ert/ensemble_evaluator/monitor.py | 10 +++++----- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 7f87dc9b04f..f9ced9206c0 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -70,7 +70,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue() self._ee_tasks: List[asyncio.Task[None]] = [] - self._server_started: asyncio.Event = asyncio.Event() self._server_done: asyncio.Event = asyncio.Event() # batching section @@ -88,7 +87,7 @@ async def _initialize_zmq(self) -> None: self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUB) self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}") - print("ZMQ initialized") + logger.error("ZMQ initialized") async def _publisher(self) -> None: while True: @@ -243,10 +242,7 @@ async def forward_checksum(self, event: Event) -> None: await self._manifest_queue.put(event) async def _server(self) -> None: - await self._initialize_zmq() - self._server_started.set() await self._server_done.wait() - await self._events.join() await self._complete_batch.wait() await self._batch_processing_queue.join() @@ -281,6 +277,7 @@ async def _start_running(self) -> None: if not self._config: raise ValueError("no config for evaluator") self._loop = asyncio.get_running_loop() + await self._initialize_zmq() self._ee_tasks = [ asyncio.create_task(self._server(), name="server_task"), asyncio.create_task( @@ -290,8 +287,6 @@ async def _start_running(self) -> None: asyncio.create_task(self._publisher(), name="publisher_task"), asyncio.create_task(self.listen_for_messages(), name="listener_task"), ] - # now we wait for the server to actually start - await self._server_started.wait() self._ee_tasks.append( asyncio.create_task( diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 9c9315dbc9b..00e5dcd4c0b 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -40,8 +40,8 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None: self._connection_timeout: float = 120.0 self._receiver_timeout: float = 60.0 self._zmq_context = zmq.asyncio.Context() # type: ignore - self._listen_socket: zmq.asyncio.Socket | None = None - self._push_socket: zmq.asyncio.Socket | None = None + # self._listen_socket: zmq.asyncio.Socket | None = None + # self._push_socket: zmq.asyncio.Socket | None = None async def __aenter__(self) -> "Monitor": self._receiver_task = asyncio.create_task(self._receiver()) @@ -58,6 +58,9 @@ async def __aenter__(self) -> "Monitor": async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self._receiver_task: + if self._listen_socket and self._push_socket: + self._listen_socket.close() + self._push_socket.close() if not self._receiver_task.done(): self._receiver_task.cancel() # we are done and not interested in errors when cancelling @@ -66,9 +69,6 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None return_exceptions=True, ) - self._listen_socket.close() - self._push_socket.close() - async def signal_cancel(self) -> None: await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} asking server to cancel...") From 932c425e55034ec0a69c7886ef76d38aadd7e7ca Mon Sep 17 00:00:00 2001 From: xjules Date: Wed, 13 Nov 2024 16:44:01 +0100 Subject: [PATCH 11/16] Make find_available ports work with zmq --- src/ert/ensemble_evaluator/config.py | 15 +++++++++------ src/ert/ensemble_evaluator/evaluator.py | 16 +++++++++++----- src/ert/ensemble_evaluator/monitor.py | 5 +++++ src/ert/shared/net_utils.py | 1 + 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/ert/ensemble_evaluator/config.py b/src/ert/ensemble_evaluator/config.py index 77cba269f87..4d8c6b18fb9 100644 --- a/src/ert/ensemble_evaluator/config.py +++ b/src/ert/ensemble_evaluator/config.py @@ -129,19 +129,22 @@ def __init__( custom_host: typing.Optional[str] = None, ) -> None: self._socket_handle = find_available_socket( - custom_range=custom_port_range, custom_host=custom_host + custom_range=custom_port_range, + custom_host=custom_host, + will_close_then_reopen_socket=True, ) host, port = self._socket_handle.getsockname() self.host = host self.pub_sub_port = port + + self._socket_handle = find_available_socket( + custom_range=custom_port_range, + custom_host=custom_host, + will_close_then_reopen_socket=True, + ) host, port = self._socket_handle.getsockname() self.push_pull_port = port - # self.protocol = "wss" if generate_cert else "ws" - # self.url = f"{self.protocol}://{host}:{port}" - # self.client_uri = f"{self.url}/client" - # self.dispatch_uri = f"{self.url}/dispatch" - if generate_cert: cert, key, pw = _generate_certificate(host) else: diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index f9ced9206c0..4864afb48b4 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -83,16 +83,22 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): async def _initialize_zmq(self) -> None: self._zmq_context = zmq.asyncio.Context() # type: ignore - self._pull_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL) - self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") - self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PUB) - self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}") + try: + self._pull_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL) + self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") + self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket( + zmq.PUB + ) + self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}") + except zmq.error.ZMQError as e: + logger.error(f"ZMQ error: {e}") + raise logger.error("ZMQ initialized") async def _publisher(self) -> None: while True: event = await self._events_to_send.get() - self._publisher_socket.send_json(event_to_json(event)) + await self._publisher_socket.send_string(event_to_json(event)) self._events_to_send.task_done() async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None: diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 00e5dcd4c0b..946a4b99204 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json import logging import ssl import uuid @@ -132,10 +133,14 @@ async def _receiver(self) -> None: self._push_socket = self._zmq_context.socket(zmq.PUSH) self._push_socket.connect(self._ee_con_info.push_pull_uri) + self._connected.set() + while True: try: raw_msg = await self._listen_socket.recv_string() + raw_msg = json.loads(raw_msg) event = event_from_json(raw_msg) + print(f"monitor-{self._id} received event: {event}") await self._event_queue.put(event) except (zmq.ZMQError, asyncio.CancelledError) as exc: # Handle disconnection or other ZMQ errors (reconnect or log) diff --git a/src/ert/shared/net_utils.py b/src/ert/shared/net_utils.py index 66c12aef6c9..2cf467481ac 100644 --- a/src/ert/shared/net_utils.py +++ b/src/ert/shared/net_utils.py @@ -111,6 +111,7 @@ def _bind_socket( if will_close_then_reopen_socket: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) else: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0) From 2d53e4ecdf49c2ea2fdd4e9e267d10f1db6a4f14 Mon Sep 17 00:00:00 2001 From: xjules Date: Thu, 14 Nov 2024 10:29:51 +0100 Subject: [PATCH 12/16] WIP: full snapshot udpate needs to be sent to client on connection --- src/ert/ensemble_evaluator/evaluator.py | 22 +++++++++++++++++++--- src/ert/ensemble_evaluator/monitor.py | 13 ++++++------- src/ert/run_models/base_run_model.py | 8 +------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 4864afb48b4..59d42521ab9 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -23,6 +23,7 @@ import zmq.asyncio from _ert.events import ( + EESnapshot, EESnapshotUpdate, EETerminated, EEUserCancel, @@ -209,6 +210,21 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None: def ensemble(self) -> Ensemble: return self._ensemble + async def listen_for_clients(self) -> None: + while True: + event = await self._publisher_socket.recv() + # TODO change to router-dealer as this would inform all subscribers about the snapshot + if event[0] == 1: + print("Subscriber connected") + current_snapshot_dict = self._ensemble.snapshot.to_dict() + event: Event = EESnapshot( + snapshot=current_snapshot_dict, ensemble=self.ensemble.id_ + ) + await self._publisher_socket.send_string(event_to_json(event)) + + elif event[0] == 0: + print("Subscriber disconnected") + async def listen_for_messages(self) -> None: while True: sender, raw_msg = await self._pull_socket.recv_multipart() @@ -224,8 +240,8 @@ async def listen_for_messages(self) -> None: logger.debug("Client signalled done.") self.stop() elif sender == "dispatch": - print(f"Got dispatch {raw_msg=}") event = dispatch_event_from_json(raw_msg) + # print(f"Got dispatch {event=}") if event.ensemble != self.ensemble.id_: logger.info( "Got event from evaluator " @@ -237,8 +253,8 @@ async def listen_for_messages(self) -> None: await self.forward_checksum(event) else: await self._events.put(event) - if type(event) in [EnsembleSucceeded, EnsembleFailed]: - return + # if type(event) in [EnsembleSucceeded, EnsembleFailed]: + # return else: logger.info(f"Connection attempt to unknown sender: {sender}.") diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 946a4b99204..96ba11a5742 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import json import logging import ssl import uuid @@ -76,7 +75,7 @@ async def signal_cancel(self) -> None: cancel_event = EEUserCancel(monitor=self._id) await self._push_socket.send_multipart( - [b"client", event_to_json(cancel_event).encode()] + [b"client", event_to_json(cancel_event).encode("utf-8")] ) logger.debug(f"monitor-{self._id} asked server to cancel") @@ -86,7 +85,7 @@ async def signal_done(self) -> None: done_event = EEUserDone(monitor=self._id) await self._push_socket.send_multipart( - [b"client", event_to_json(done_event).encode()] + [b"client", event_to_json(done_event).encode("utf-8")] ) logger.debug(f"monitor-{self._id} informed server monitor is done") @@ -127,7 +126,7 @@ async def _receiver(self) -> None: tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) tls.load_verify_locations(cadata=self._ee_con_info.cert) - self._listen_socket = self._zmq_context.socket(zmq.SUB) + self._listen_socket = self._zmq_context.socket(zmq.XSUB) self._listen_socket.connect(self._ee_con_info.pub_sub_uri) self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "") @@ -138,11 +137,11 @@ async def _receiver(self) -> None: while True: try: raw_msg = await self._listen_socket.recv_string() - raw_msg = json.loads(raw_msg) + # print(f"monitor-{self._id} received msg: {raw_msg}") event = event_from_json(raw_msg) - print(f"monitor-{self._id} received event: {event}") + # print(f"monitor-{self._id} received event: {event}") await self._event_queue.put(event) - except (zmq.ZMQError, asyncio.CancelledError) as exc: + except zmq.ZMQError as exc: # Handle disconnection or other ZMQ errors (reconnect or log) logger.debug( f"ZeroMQ connection to EnsembleEvaluator went down, reconnecting: {exc}" diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index b000d5fa491..557c737c939 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -26,12 +26,7 @@ import numpy as np -from _ert.events import ( - EESnapshot, - EESnapshotUpdate, - EETerminated, - Event, -) +from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event from ert.analysis import ( AnalysisEvent, AnalysisStatusEvent, @@ -507,7 +502,6 @@ async def run_monitor( event, iteration, ) - if event.snapshot.get(STATUS) in [ ENSEMBLE_STATE_STOPPED, ENSEMBLE_STATE_FAILED, From 0a25ff4696311a54a1047480d2c4c89fdb4bb27d Mon Sep 17 00:00:00 2001 From: xjules Date: Thu, 14 Nov 2024 13:35:18 +0100 Subject: [PATCH 13/16] Fixing staff... --- src/_ert/forward_model_runner/client.py | 2 - src/ert/ensemble_evaluator/evaluator.py | 116 +++++++++++++++--------- src/ert/ensemble_evaluator/monitor.py | 4 +- 3 files changed, 73 insertions(+), 49 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index baf8c5054f6..13b73e7b374 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -81,8 +81,6 @@ def reconnect(self): def send(self, message): try: - # if self.token: - # message = f"{self.token}:{message}" self.socket.send_multipart([b"dispatch", message.encode("utf-8")]) except zmq.ZMQError as e: logger.warning(f"Failed to send message: {e}") diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 59d42521ab9..3f40db159ba 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -64,8 +64,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._loop: Optional[asyncio.AbstractEventLoop] = None - self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue() - self._events: asyncio.Queue[Event] = asyncio.Queue() self._events_to_send: asyncio.Queue[Event] = asyncio.Queue() self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue() @@ -81,6 +79,8 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._batching_interval: int = 2 self._complete_batch: asyncio.Event = asyncio.Event() self._zmq_context: zmq.asyncio.Context | None = None + self._clients_connected: asyncio.Queue[None] = asyncio.Queue() + self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue() async def _initialize_zmq(self) -> None: self._zmq_context = zmq.asyncio.Context() # type: ignore @@ -88,13 +88,13 @@ async def _initialize_zmq(self) -> None: self._pull_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL) self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket( - zmq.PUB + zmq.XPUB ) self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}") except zmq.error.ZMQError as e: logger.error(f"ZMQ error: {e}") raise - logger.error("ZMQ initialized") + logger.info("ZMQ initialized") async def _publisher(self) -> None: while True: @@ -212,51 +212,71 @@ def ensemble(self) -> Ensemble: async def listen_for_clients(self) -> None: while True: - event = await self._publisher_socket.recv() - # TODO change to router-dealer as this would inform all subscribers about the snapshot - if event[0] == 1: - print("Subscriber connected") - current_snapshot_dict = self._ensemble.snapshot.to_dict() - event: Event = EESnapshot( - snapshot=current_snapshot_dict, ensemble=self.ensemble.id_ - ) - await self._publisher_socket.send_string(event_to_json(event)) - - elif event[0] == 0: - print("Subscriber disconnected") + try: + raw_msg = await self._publisher_socket.recv() + # this would inform all subscribers about the snapshot + if raw_msg[0] == 1: + await self._clients_connected.put(None) + current_snapshot_dict = self._ensemble.snapshot.to_dict() + event: Event = EESnapshot( + snapshot=current_snapshot_dict, ensemble=self.ensemble.id_ + ) + await self._publisher_socket.send_string(event_to_json(event)) + + elif raw_msg[0] == 0: + await self._clients_connected.get() + self._clients_connected.task_done() + except zmq.error.ZMQError as e: + if e.errno == zmq.ENOTSOCK: + logger.warning( + "Evaluator publisher closed, no new clients accepted" + ) + else: + logger.error(f"Unexpected error when connecting new clients: {e}") + return + except asyncio.CancelledError: + return async def listen_for_messages(self) -> None: while True: - sender, raw_msg = await self._pull_socket.recv_multipart() - sender = sender.decode("utf-8") - raw_msg = raw_msg.decode("utf-8") - if sender == "client": - print(f"Got client {raw_msg=}") - event = event_from_json(raw_msg) - if type(event) is EEUserCancel: - logger.debug("Client asked to cancel.") - self._signal_cancel() - elif type(event) is EEUserDone: - logger.debug("Client signalled done.") - self.stop() - elif sender == "dispatch": - event = dispatch_event_from_json(raw_msg) - # print(f"Got dispatch {event=}") - if event.ensemble != self.ensemble.id_: - logger.info( - "Got event from evaluator " - f"{event.ensemble}. " - f"Ignoring since I am {self.ensemble.id_}" + try: + sender, raw_msg = await self._pull_socket.recv_multipart() + sender = sender.decode("utf-8") + raw_msg = raw_msg.decode("utf-8") + if sender == "client": + event = event_from_json(raw_msg) + if type(event) is EEUserCancel: + logger.debug("Client asked to cancel.") + self._signal_cancel() + elif type(event) is EEUserDone: + logger.debug("Client signalled done.") + self.stop() + elif sender == "dispatch": + event = dispatch_event_from_json(raw_msg) + if event.ensemble != self.ensemble.id_: + logger.info( + "Got event from evaluator " + f"{event.ensemble}. " + f"Ignoring since I am {self.ensemble.id_}" + ) + continue + if type(event) is ForwardModelStepChecksum: + await self.forward_checksum(event) + else: + await self._events.put(event) + # if type(event) in [EnsembleSucceeded, EnsembleFailed]: + # return + else: + logger.info(f"Connection attempt to unknown sender: {sender}.") + except zmq.error.ZMQError as e: + if e.errno == zmq.ENOTSOCK: + logger.warning( + "Evaluator receiver closed, no new messages are received" ) - continue - if type(event) is ForwardModelStepChecksum: - await self.forward_checksum(event) else: - await self._events.put(event) - # if type(event) in [EnsembleSucceeded, EnsembleFailed]: - # return - else: - logger.info(f"Connection attempt to unknown sender: {sender}.") + logger.error(f"Unexpected error when listening to messages: {e}") + except asyncio.CancelledError: + return async def forward_checksum(self, event: Event) -> None: # clients still need to receive events via ws @@ -271,6 +291,7 @@ async def _server(self) -> None: event = EETerminated(ensemble=self._ensemble.id_) await self._events_to_send.put(event) await self._events_to_send.join() + await self._clients_connected.join() self._pull_socket.close() self._publisher_socket.close() logger.debug("Async server exiting.") @@ -308,6 +329,7 @@ async def _start_running(self) -> None: asyncio.create_task(self._process_event_buffer(), name="processing_task"), asyncio.create_task(self._publisher(), name="publisher_task"), asyncio.create_task(self.listen_for_messages(), name="listener_task"), + asyncio.create_task(self.listen_for_clients(), name="client_task"), ] self._ee_tasks.append( @@ -360,7 +382,11 @@ async def _monitor_and_handle_tasks(self) -> None: if stop_timeout_task: stop_timeout_task.cancel() return - elif task.get_name() == "ensemble_task": + elif task.get_name() in [ + "ensemble_task", + "listener_task", + "client_task", + ]: stop_timeout_task = asyncio.create_task( self._wait_for_stopped_server() ) diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 96ba11a5742..cc1e426bcaa 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -126,7 +126,7 @@ async def _receiver(self) -> None: tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) tls.load_verify_locations(cadata=self._ee_con_info.cert) - self._listen_socket = self._zmq_context.socket(zmq.XSUB) + self._listen_socket = self._zmq_context.socket(zmq.SUB) self._listen_socket.connect(self._ee_con_info.pub_sub_uri) self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "") @@ -139,7 +139,7 @@ async def _receiver(self) -> None: raw_msg = await self._listen_socket.recv_string() # print(f"monitor-{self._id} received msg: {raw_msg}") event = event_from_json(raw_msg) - # print(f"monitor-{self._id} received event: {event}") + print(f"monitor-{self._id} received event: {event}") await self._event_queue.put(event) except zmq.ZMQError as exc: # Handle disconnection or other ZMQ errors (reconnect or log) From 2201ae2c8c2289ad157cb50a098e60a3f6652501 Mon Sep 17 00:00:00 2001 From: xjules Date: Thu, 14 Nov 2024 14:38:14 +0100 Subject: [PATCH 14/16] Settup encryption with curve --- src/_ert/forward_model_runner/client.py | 15 +++++---------- src/ert/ensemble_evaluator/config.py | 3 +++ src/ert/ensemble_evaluator/evaluator.py | 9 +++++++++ src/ert/ensemble_evaluator/monitor.py | 16 +++++++++++++--- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 13b73e7b374..7fdb1a17a63 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -46,18 +46,13 @@ def __init__( # Set up ZeroMQ context and socket self.context = zmq.Context() self.socket = self.context.socket(zmq.PUSH) - # self.socket.setsockopt(zmq.LINGER, 0) - # self.socket.setsockopt(zmq.SNDTIMEO, self.CONNECTION_TIMEOUT * 1000) + if token is not None: + client_public, client_secret = zmq.curve_keypair() + self.socket.curve_secretkey = client_secret + self.socket.curve_publickey = client_public + self.socket.curve_serverkey = token.encode("utf-8") self.socket.connect(url) - # if cert: - # client_public, client_secret = zmq.curve_keypair() - # server_public = cert - - # self.socket.curve_secretkey = client_secret - # self.socket.curve_publickey = client_public - # self.socket.curve_serverkey = server_public - self._max_retries = max_retries self._timeout_multiplier = timeout_multiplier diff --git a/src/ert/ensemble_evaluator/config.py b/src/ert/ensemble_evaluator/config.py index 4d8c6b18fb9..0a97db4679f 100644 --- a/src/ert/ensemble_evaluator/config.py +++ b/src/ert/ensemble_evaluator/config.py @@ -11,6 +11,7 @@ from datetime import datetime, timedelta from typing import Optional +import zmq from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization @@ -154,6 +155,8 @@ def __init__( self._key_pw = pw self.token = _generate_authentication() if use_token else None + self.server_public_key, self.server_secret_key = zmq.curve_keypair() + self.token = self.server_public_key.decode("utf-8") def get_socket(self) -> socket.socket: return self._socket_handle.dup() diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 3f40db159ba..5a8e6f983e4 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -86,11 +86,20 @@ async def _initialize_zmq(self) -> None: self._zmq_context = zmq.asyncio.Context() # type: ignore try: self._pull_socket: zmq.asyncio.Socket = self._zmq_context.socket(zmq.PULL) + self._pull_socket.curve_secretkey = self._config.server_secret_key + self._pull_socket.curve_publickey = self._config.server_public_key + self._pull_socket.curve_server = True self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") + self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket( zmq.XPUB ) + + self._publisher_socket.curve_secretkey = self._config.server_secret_key + self._publisher_socket.curve_publickey = self._config.server_public_key + self._publisher_socket.curve_server = True self._publisher_socket.bind(f"tcp://*:{self._config.pub_sub_port}") + except zmq.error.ZMQError as e: logger.error(f"ZMQ error: {e}") raise diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index cc1e426bcaa..c2dd458c2ea 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -127,19 +127,29 @@ async def _receiver(self) -> None: tls.load_verify_locations(cadata=self._ee_con_info.cert) self._listen_socket = self._zmq_context.socket(zmq.SUB) + self._push_socket = self._zmq_context.socket(zmq.PUSH) + + if self._ee_con_info.token is not None: + client_public, client_secret = zmq.curve_keypair() + self._listen_socket.curve_secretkey = client_secret + self._listen_socket.curve_publickey = client_public + self._listen_socket.curve_serverkey = self._ee_con_info.token.encode( + "utf-8" + ) + self._push_socket.curve_secretkey = client_secret + self._push_socket.curve_publickey = client_public + self._push_socket.curve_serverkey = self._ee_con_info.token.encode("utf-8") + self._listen_socket.connect(self._ee_con_info.pub_sub_uri) self._listen_socket.setsockopt_string(zmq.SUBSCRIBE, "") - self._push_socket = self._zmq_context.socket(zmq.PUSH) self._push_socket.connect(self._ee_con_info.push_pull_uri) self._connected.set() while True: try: raw_msg = await self._listen_socket.recv_string() - # print(f"monitor-{self._id} received msg: {raw_msg}") event = event_from_json(raw_msg) - print(f"monitor-{self._id} received event: {event}") await self._event_queue.put(event) except zmq.ZMQError as exc: # Handle disconnection or other ZMQ errors (reconnect or log) From 7580b35484757621861042ee1e94b75bbe7ca6e4 Mon Sep 17 00:00:00 2001 From: xjules Date: Thu, 14 Nov 2024 16:06:51 +0100 Subject: [PATCH 15/16] Adjust push-pull sockets to perform faster --- src/_ert/forward_model_runner/client.py | 4 ++++ src/ert/ensemble_evaluator/evaluator.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 7fdb1a17a63..5618dd355d6 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -46,6 +46,10 @@ def __init__( # Set up ZeroMQ context and socket self.context = zmq.Context() self.socket = self.context.socket(zmq.PUSH) + # reduce backlog + self.socket.setsockopt(zmq.LINGER, 0) + # if server is not ready yet, no message is sent + self.socket.setsockopt(zmq.IMMEDIATE, 1) if token is not None: client_public, client_secret = zmq.curve_keypair() self.socket.curve_secretkey = client_secret diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 5a8e6f983e4..fcf2a34edf6 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -89,6 +89,8 @@ async def _initialize_zmq(self) -> None: self._pull_socket.curve_secretkey = self._config.server_secret_key self._pull_socket.curve_publickey = self._config.server_public_key self._pull_socket.curve_server = True + self._pull_socket.setsockopt(zmq.LINGER, 0) + self._pull_socket.setsockopt(zmq.RCVHWM, 10000) self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket( From 561dd56bae42947bbd207bd94b3fb5e44f48aca8 Mon Sep 17 00:00:00 2001 From: xjules Date: Fri, 15 Nov 2024 16:57:36 +0100 Subject: [PATCH 16/16] Add event batching on dispatcher side --- src/_ert/forward_model_runner/client.py | 23 ++++++-- .../forward_model_runner/reporting/event.py | 39 ++++++++------ src/ert/ensemble_evaluator/evaluator.py | 52 ++++++++++--------- 3 files changed, 68 insertions(+), 46 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 5618dd355d6..7d57c1cdf07 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import time from typing import Any, Optional, Union @@ -47,9 +49,12 @@ def __init__( self.context = zmq.Context() self.socket = self.context.socket(zmq.PUSH) # reduce backlog - self.socket.setsockopt(zmq.LINGER, 0) + # self.socket.setsockopt(zmq.LINGER, 0) # if server is not ready yet, no message is sent - self.socket.setsockopt(zmq.IMMEDIATE, 1) + # self.socket.setsockopt(zmq.IMMEDIATE, 1) + # self.socket.setsockopt(zmq.SNDHWM, 10000) + # self.socket.setsockopt(zmq.SNDHWM, 1) + # self.socket.setsockopt(zmq.RCVHWM, 1) if token is not None: client_public, client_secret = zmq.curve_keypair() self.socket.curve_secretkey = client_secret @@ -78,10 +83,18 @@ def reconnect(self): sleep_time = self._timeout_multiplier * (self._max_retries - retries) time.sleep(sleep_time) - def send(self, message): + def send(self, messages: str | list[str]) -> None: + if isinstance(messages, str): + messages = [messages] try: - self.socket.send_multipart([b"dispatch", message.encode("utf-8")]) + logger.debug(f"sending messages: {messages}") + self.socket.send_multipart( + [b"dispatch"] + [message.encode("utf-8") for message in messages] + ) + logger.debug("sending messages: success") except zmq.ZMQError as e: logger.warning(f"Failed to send message: {e}") self.reconnect() - self.socket.send_multipart([b"dispatch", message.encode("utf-8")]) + self.socket.send_multipart( + [b"dispatch"] + [message.encode("utf-8") for message in messages] + ) diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 8bf13dee238..11ef56374b6 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -3,6 +3,7 @@ import logging import queue import threading +import time from datetime import datetime, timedelta from pathlib import Path from typing import Final, Union @@ -18,8 +19,6 @@ ) from _ert.forward_model_runner.client import ( Client, - ClientConnectionClosedOK, - ClientConnectionError, ) from _ert.forward_model_runner.reporting.base import Reporter from _ert.forward_model_runner.reporting.message import ( @@ -90,7 +89,8 @@ def _event_publisher(self): token=self._token, cert=self._cert, ) as client: - event = None + events = [] + last_sent_time = time.time() while True: with self._timestamp_lock: if ( @@ -99,23 +99,28 @@ def _event_publisher(self): ): self._timeout_timestamp = None break - if event is None: - # if we successfully sent the event we can proceed - # to next one + + try: event = self._event_queue.get() + logger.debug(f"Got event for zmq: {event}") if event is self._sentinel: + if events: + logger.debug(f"Got event class for zmq: {events}") + client.send(events) + events.clear() break - try: - client.send(event_to_json(event)) - event = None - except ClientConnectionError as exception: - # Possible intermittent failure, we retry sending the event - logger.error(str(exception)) - except ClientConnectionClosedOK as exception: - # The receiving end has closed the connection, we stop - # sending events - logger.debug(str(exception)) - break + events.append(event_to_json(event)) + + current_time = time.time() + if current_time - last_sent_time >= 2: + if events: + logger.debug(f"Got event class for zmq: {events}") + client.send(events) + events.clear() + last_sent_time = current_time + except Exception as e: + logger.error(f"Failed to send event: {e}") + raise def report(self, msg): self._statemachine.transition(msg) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index fcf2a34edf6..6573e22ccad 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -89,8 +89,9 @@ async def _initialize_zmq(self) -> None: self._pull_socket.curve_secretkey = self._config.server_secret_key self._pull_socket.curve_publickey = self._config.server_public_key self._pull_socket.curve_server = True - self._pull_socket.setsockopt(zmq.LINGER, 0) - self._pull_socket.setsockopt(zmq.RCVHWM, 10000) + # self._pull_socket.setsockopt(zmq.LINGER, 0) + # self._pull_socket.setsockopt(zmq.SNDHWM, 1) + # self._pull_socket.setsockopt(zmq.RCVHWM, 1) self._pull_socket.bind(f"tcp://*:{self._config.push_pull_port}") self._publisher_socket: zmq.asyncio.Socket = self._zmq_context.socket( @@ -251,30 +252,33 @@ async def listen_for_clients(self) -> None: async def listen_for_messages(self) -> None: while True: try: - sender, raw_msg = await self._pull_socket.recv_multipart() - sender = sender.decode("utf-8") - raw_msg = raw_msg.decode("utf-8") + frames = await self._pull_socket.recv_multipart() + sender = frames[0].decode("utf-8") if sender == "client": - event = event_from_json(raw_msg) - if type(event) is EEUserCancel: - logger.debug("Client asked to cancel.") - self._signal_cancel() - elif type(event) is EEUserDone: - logger.debug("Client signalled done.") - self.stop() + for frame in frames[1:]: + raw_msg = frame.decode("utf-8") + event = event_from_json(raw_msg) + if type(event) is EEUserCancel: + logger.debug("Client asked to cancel.") + self._signal_cancel() + elif type(event) is EEUserDone: + logger.debug("Client signalled done.") + self.stop() elif sender == "dispatch": - event = dispatch_event_from_json(raw_msg) - if event.ensemble != self.ensemble.id_: - logger.info( - "Got event from evaluator " - f"{event.ensemble}. " - f"Ignoring since I am {self.ensemble.id_}" - ) - continue - if type(event) is ForwardModelStepChecksum: - await self.forward_checksum(event) - else: - await self._events.put(event) + for frame in frames[1:]: + raw_msg = frame.decode("utf-8") + event = dispatch_event_from_json(raw_msg) + if event.ensemble != self.ensemble.id_: + logger.info( + "Got event from evaluator " + f"{event.ensemble}. " + f"Ignoring since I am {self.ensemble.id_}" + ) + continue + if type(event) is ForwardModelStepChecksum: + await self.forward_checksum(event) + else: + await self._events.put(event) # if type(event) in [EnsembleSucceeded, EnsembleFailed]: # return else: