Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace websockets with zmq #9173

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
137 changes: 43 additions & 94 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
import asyncio
import logging
import ssl
from typing import Any, AnyStr, Optional, Union
import time
from typing import Any, Optional, Union

import zmq
from typing_extensions import Self
from websockets.client import WebSocketClientProtocol, connect
from websockets.datastructures import Headers
from websockets.exceptions import (
ConnectionClosedError,
ConnectionClosedOK,
InvalidHandshake,
InvalidURI,
)

from _ert.async_utils import new_event_loop

logger = logging.getLogger(__name__)

Expand All @@ -35,18 +25,8 @@
return self

def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
if self.websocket is not None:
self.loop.run_until_complete(self.websocket.close())
self.loop.close()

async def __aenter__(self) -> "Client":
return self

async def __aexit__(
self, exc_type: Any, exc_value: Any, exc_traceback: Any
) -> None:
if self.websocket is not None:
await self.websocket.close()
self.socket.close()
self.context.term()

def __init__(
self,
Expand All @@ -60,79 +40,48 @@
max_retries = self.DEFAULT_MAX_RETRIES
if timeout_multiplier is None:
timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER
if url is None:
raise ValueError("url was None")
self.url = url
self.token = token
self._extra_headers = Headers()

# Set up ZeroMQ context and socket
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUSH)
# LINGER=0: discard unsent messages when the socket is closed instead of blocking
self.socket.setsockopt(zmq.LINGER, 0)
# if server is not ready yet, no message is sent
self.socket.setsockopt(zmq.IMMEDIATE, 1)
if token is not None:
self._extra_headers["token"] = token

# Mimics the behavior of the ssl argument when connection to
# websockets. If none is specified it will deduce based on the url,
# if True it will enforce TLS, and if you want to use self signed
# certificates you need to pass an ssl_context with the certificate
# loaded.
self._ssl_context: Optional[Union[bool, ssl.SSLContext]] = None
if cert is not None:
self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self._ssl_context.load_verify_locations(cadata=cert)
elif url.startswith("wss"):
self._ssl_context = True
client_public, client_secret = zmq.curve_keypair()
self.socket.curve_secretkey = client_secret
self.socket.curve_publickey = client_public
self.socket.curve_serverkey = token.encode("utf-8")
self.socket.connect(url)

self._max_retries = max_retries
self._timeout_multiplier = timeout_multiplier
self.websocket: Optional[WebSocketClientProtocol] = None
self.loop = new_event_loop()

async def get_websocket(self) -> WebSocketClientProtocol:
return await connect(
self.url,
ssl=self._ssl_context,
extra_headers=self._extra_headers,
open_timeout=self.CONNECTION_TIMEOUT,
ping_timeout=self.CONNECTION_TIMEOUT,
ping_interval=self.CONNECTION_TIMEOUT,
close_timeout=self.CONNECTION_TIMEOUT,
)

async def _send(self, msg: AnyStr) -> None:
for retry in range(self._max_retries + 1):

self.reconnect()

Check failure on line 63 in src/_ert/forward_model_runner/client.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "reconnect" in typed context

def reconnect(self):

Check failure on line 65 in src/_ert/forward_model_runner/client.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a return type annotation
"""Connect to the server, retrying with a linearly increasing delay."""
retries = self._max_retries
while retries > 0:
try:
if self.websocket is None:
self.websocket = await self.get_websocket()
await self.websocket.send(msg)
return
except ConnectionClosedOK as exception:
_error_msg = (
f"Connection closed received from the server {self.url}! "
f" Exception from {type(exception)}: {exception!s}"
)
raise ClientConnectionClosedOK(_error_msg) from exception
except (
InvalidHandshake,
InvalidURI,
OSError,
asyncio.TimeoutError,
) as exception:
if retry == self._max_retries:
_error_msg = (
f"Not able to establish the "
f"websocket connection {self.url}! Max retries reached!"
" Check for firewall issues."
f" Exception from {type(exception)}: {exception!s}"
)
raise ClientConnectionError(_error_msg) from exception
except ConnectionClosedError as exception:
if retry == self._max_retries:
_error_msg = (
f"Not been able to send the event"
f" to {self.url}! Max retries reached!"
f" Exception from {type(exception)}: {exception!s}"
)
raise ClientConnectionError(_error_msg) from exception
await asyncio.sleep(0.2 + self._timeout_multiplier * retry)
self.websocket = None

def send(self, msg: AnyStr) -> None:
self.loop.run_until_complete(self._send(msg))
self.socket.connect(self.url)
break
except zmq.ZMQError as e:
logger.warning(f"Failed to connect to {self.url}: {e}")
retries -= 1
if retries == 0:
raise e
# Linear backoff: the delay grows by timeout_multiplier with each failed attempt
sleep_time = self._timeout_multiplier * (self._max_retries - retries)
time.sleep(sleep_time)

def send(self, message):

Check failure on line 81 in src/_ert/forward_model_runner/client.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a type annotation
try:
self.socket.send_multipart([b"dispatch", message.encode("utf-8")])
except zmq.ZMQError as e:
logger.warning(f"Failed to send message: {e}")
self.reconnect()

Check failure on line 86 in src/_ert/forward_model_runner/client.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "reconnect" in typed context
self.socket.send_multipart([b"dispatch", message.encode("utf-8")])
25 changes: 10 additions & 15 deletions src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,8 @@
from ert.run_arg import RunArg
from ert.scheduler import Scheduler, create_driver

from ._wait_for_evaluator import wait_for_evaluator
from .config import EvaluatorServerConfig
from .snapshot import (
EnsembleSnapshot,
FMStepSnapshot,
RealizationSnapshot,
)
from .snapshot import EnsembleSnapshot, FMStepSnapshot, RealizationSnapshot
from .state import (
ENSEMBLE_STATE_CANCELLED,
ENSEMBLE_STATE_FAILED,
Expand Down Expand Up @@ -204,8 +199,8 @@
cert: Optional[Union[str, bytes]] = None,
retries: int = 10,
) -> None:
async with Client(url, token, cert, max_retries=retries) as client:
await client._send(event_to_json(event))
with Client(url, token, cert, max_retries=retries) as client:
client.send(event_to_json(event))

Check failure on line 203 in src/ert/ensemble_evaluator/_ensemble.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "send" in typed context

def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]:
def event_builder(status: str) -> Event:
Expand All @@ -230,16 +225,16 @@
ce_unary_send_method_name,
partialmethod(
self.__class__.send_event,
self._config.dispatch_uri,
self._config.get_connection_info().push_pull_uri,
token=self._config.token,
cert=self._config.cert,
),
)
await wait_for_evaluator(
base_url=self._config.url,
token=self._config.token,
cert=self._config.cert,
)
# await wait_for_evaluator(
# base_url=self._config.url,
# token=self._config.token,
# cert=self._config.cert,
# )
await self._evaluate_inner(
event_unary_send=getattr(self, ce_unary_send_method_name),
scheduler_queue=scheduler_queue,
Expand Down Expand Up @@ -282,7 +277,7 @@
max_running=self._queue_config.max_running,
submit_sleep=self._queue_config.submit_sleep,
ens_id=self.id_,
ee_uri=self._config.dispatch_uri,
ee_uri=self._config.get_connection_info().push_pull_uri,
ee_cert=self._config.cert,
ee_token=self._config.token,
)
Expand Down
25 changes: 19 additions & 6 deletions src/ert/ensemble_evaluator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from datetime import datetime, timedelta
from typing import Optional

import zmq
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
Expand Down Expand Up @@ -129,13 +130,22 @@ def __init__(
custom_host: typing.Optional[str] = None,
) -> None:
self._socket_handle = find_available_socket(
custom_range=custom_port_range, custom_host=custom_host
custom_range=custom_port_range,
custom_host=custom_host,
will_close_then_reopen_socket=True,
)
host, port = self._socket_handle.getsockname()
self.protocol = "wss" if generate_cert else "ws"
self.url = f"{self.protocol}://{host}:{port}"
self.client_uri = f"{self.url}/client"
self.dispatch_uri = f"{self.url}/dispatch"
self.host = host
self.pub_sub_port = port

self._socket_handle = find_available_socket(
custom_range=custom_port_range,
custom_host=custom_host,
will_close_then_reopen_socket=True,
)
host, port = self._socket_handle.getsockname()
self.push_pull_port = port

if generate_cert:
cert, key, pw = _generate_certificate(host)
else:
Expand All @@ -145,13 +155,16 @@ def __init__(
self._key_pw = pw

self.token = _generate_authentication() if use_token else None
self.server_public_key, self.server_secret_key = zmq.curve_keypair()
self.token = self.server_public_key.decode("utf-8")

def get_socket(self) -> socket.socket:
return self._socket_handle.dup()

def get_connection_info(self) -> EvaluatorConnectionInfo:
return EvaluatorConnectionInfo(
self.url,
f"tcp://{self.host}:{self.push_pull_port}",
f"tcp://{self.host}:{self.pub_sub_port}",
self.cert,
self.token,
)
Expand Down
Loading
Loading