Skip to content

Commit

Permalink
New client-side retry design
Browse files Browse the repository at this point in the history
Implements connection retries at the socket level, making it easier to do authentication & re-authentication.

This will however break retries for `ReqSocket`, which should be dealt with in a follow-up PR. The main change would be to move connection establishment inside of the driver instead of the socket handle and follow the same pattern as in `SubSocket`.

Closes #51
Closes #52
  • Loading branch information
mempirate authored Jan 19, 2024
2 parents a492a6e + 13d9d94 commit a79810c
Show file tree
Hide file tree
Showing 20 changed files with 459 additions and 1,136 deletions.
4 changes: 1 addition & 3 deletions book/src/usage/authentication.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ async fn main() {
ReqOptions::default().auth_token(Bytes::from("client1")),
);

// The auth token needs to be set on the transport layer for now.
let tcp_config = tcp::Config::default().auth_token(Bytes::from("client1"));
let mut req = ReqSocket::new(Tcp::new(tcp_config));
...
}
```

Expand Down
74 changes: 74 additions & 0 deletions msg-socket/src/backoff/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use futures::{FutureExt, Stream};
use std::{pin::Pin, task::Poll, time::Duration};
use tokio::time::sleep;

/// A stream that yields exponentially increasing backoff durations.
pub struct ExponentialBackoff {
/// Current number of retries.
retry_count: usize,
/// Maximum number of retries before closing the stream.
max_retries: usize,
/// The current backoff duration.
backoff: Duration,
/// The current backoff timeout, if any.
/// We need the timeout to be pinned (`Sleep` is not `Unpin`)
timeout: Option<Pin<Box<tokio::time::Sleep>>>,
}

impl ExponentialBackoff {
pub fn new(initial: Duration, max_retries: usize) -> Self {
Self {
retry_count: 0,
max_retries,
backoff: initial,
timeout: None,
}
}

/// (Re)-set the timeout to the current backoff duration.
fn reset_timeout(&mut self) {
self.timeout = Some(Box::pin(sleep(self.backoff)));
}
}

impl Stream for ExponentialBackoff {
type Item = Duration;

/// Polls the exponential backoff stream. Returns `Poll::Ready` with the current backoff
/// duration if the backoff timeout has elapsed, otherwise returns `Poll::Pending`.
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

loop {
if let Some(ref mut timeout) = this.timeout {
if timeout.poll_unpin(cx).is_ready() {
// Timeout has elapsed, so reset the timeout and double the backoff
this.backoff *= 2;
this.retry_count += 1;

// Close the stream
if this.retry_count >= this.max_retries {
return Poll::Ready(None);
}

this.reset_timeout();

// Wake up the task to poll the timeout again
cx.waker().wake_by_ref();

// Return the current backoff duration
return Poll::Ready(Some(this.backoff));
} else {
// Timeout has not elapsed, so return pending
return Poll::Pending;
}
} else {
// Set initial timeout
this.reset_timeout();
}
}
}
}
2 changes: 2 additions & 0 deletions msg-socket/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ mod rep;
mod req;
mod sub;

mod backoff;

use bytes::Bytes;
pub use pubs::{PubError, PubOptions, PubSocket};
pub use rep::*;
Expand Down
95 changes: 76 additions & 19 deletions msg-socket/src/pub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ mod tests {
use std::time::Duration;

use futures::StreamExt;
use msg_transport::tcp::{self, Tcp};
use msg_transport::{quic::Quic, tcp::Tcp};
use msg_wire::compression::GzipCompressor;

use crate::{Authenticator, SubOptions, SubSocket};
Expand Down Expand Up @@ -211,14 +211,43 @@ mod tests {
}

#[tokio::test]
async fn pubsub_auth() {
async fn pubsub_auth_tcp() {
let _ = tracing_subscriber::fmt::try_init();

let mut pub_socket = PubSocket::new(Tcp::default()).with_auth(Auth);

let mut sub_socket = SubSocket::with_options(
Tcp::new(tcp::Config::default().auth_token(Bytes::from("hello"))),
SubOptions::default(),
Tcp::default(),
SubOptions::default().auth_token(Bytes::from("client1")),
);

pub_socket.bind("0.0.0.0:0".parse().unwrap()).await.unwrap();
let addr = pub_socket.local_addr().unwrap();

sub_socket.connect(addr).await.unwrap();
sub_socket.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;

pub_socket
.publish("HELLO".to_string(), "WORLD".into())
.await
.unwrap();

let msg = sub_socket.next().await.unwrap();
tracing::info!("Received message: {:?}", msg);
assert_eq!("HELLO", msg.topic());
assert_eq!("WORLD", msg.payload());
}

#[tokio::test]
async fn pubsub_auth_quic() {
let _ = tracing_subscriber::fmt::try_init();

let mut pub_socket = PubSocket::new(Quic::default()).with_auth(Auth);

let mut sub_socket = SubSocket::with_options(
Quic::default(),
SubOptions::default().auth_token(Bytes::from("client1")),
);

pub_socket.bind("0.0.0.0:0".parse().unwrap()).await.unwrap();
Expand All @@ -245,11 +274,9 @@ mod tests {

let mut pub_socket = PubSocket::new(Tcp::default());

let mut sub1 =
SubSocket::<Tcp>::new(Tcp::new(tcp::Config::default().blocking_connect(true)));
let mut sub1 = SubSocket::new(Tcp::default());

let mut sub2 =
SubSocket::<Tcp>::new(Tcp::new(tcp::Config::default().blocking_connect(true)));
let mut sub2 = SubSocket::new(Tcp::default());

pub_socket.bind("0.0.0.0:0".parse().unwrap()).await.unwrap();
let addr = pub_socket.local_addr().unwrap();
Expand Down Expand Up @@ -280,14 +307,11 @@ mod tests {
async fn pubsub_many_compressed() {
let _ = tracing_subscriber::fmt::try_init();

let mut pub_socket =
PubSocket::<Tcp>::new(Tcp::default()).with_compressor(GzipCompressor::new(6));
let mut pub_socket = PubSocket::new(Tcp::default()).with_compressor(GzipCompressor::new(6));

let mut sub1 =
SubSocket::<Tcp>::new(Tcp::new(tcp::Config::default().blocking_connect(true)));
let mut sub1 = SubSocket::new(Tcp::default());

let mut sub2 =
SubSocket::<Tcp>::new(Tcp::new(tcp::Config::default().blocking_connect(true)));
let mut sub2 = SubSocket::new(Tcp::default());

pub_socket.bind("0.0.0.0:0".parse().unwrap()).await.unwrap();
let addr = pub_socket.local_addr().unwrap();
Expand Down Expand Up @@ -317,26 +341,59 @@ mod tests {
}

#[tokio::test]
async fn pubsub_durable() {
async fn pubsub_durable_tcp() {
let _ = tracing_subscriber::fmt::try_init();

let mut pub_socket = PubSocket::<Tcp>::new(Tcp::default());
let mut pub_socket = PubSocket::new(Tcp::default());

let mut sub_socket = SubSocket::<Tcp>::new(Tcp::default());
let mut sub_socket = SubSocket::new(Tcp::default());

// Try to connect and subscribe before the publisher is up
sub_socket
.connect("0.0.0.0:6662".parse().unwrap())
.await
.unwrap();
sub_socket.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(1000)).await;
tokio::time::sleep(Duration::from_millis(500)).await;

pub_socket
.bind("0.0.0.0:6662".parse().unwrap())
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
tokio::time::sleep(Duration::from_millis(2000)).await;

pub_socket
.publish("HELLO".to_string(), Bytes::from("WORLD"))
.await
.unwrap();

let msg = sub_socket.next().await.unwrap();
tracing::info!("Received message: {:?}", msg);
assert_eq!("HELLO", msg.topic());
assert_eq!("WORLD", msg.payload());
}

#[tokio::test]
async fn pubsub_durable_quic() {
let _ = tracing_subscriber::fmt::try_init();

let mut pub_socket = PubSocket::new(Quic::default());

let mut sub_socket = SubSocket::new(Quic::default());

// Try to connect and subscribe before the publisher is up
sub_socket
.connect("0.0.0.0:6662".parse().unwrap())
.await
.unwrap();
sub_socket.subscribe("HELLO".to_string()).await.unwrap();
tokio::time::sleep(Duration::from_millis(500)).await;

pub_socket
.bind("0.0.0.0:6662".parse().unwrap())
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(2000)).await;

pub_socket
.publish("HELLO".to_string(), Bytes::from("WORLD"))
Expand Down
30 changes: 19 additions & 11 deletions msg-socket/src/rep/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ mod tests {
use std::time::Duration;

use futures::StreamExt;
use msg_transport::tcp::{self, Tcp};
use msg_transport::tcp::Tcp;
use rand::Rng;

use crate::{req::ReqSocket, Authenticator};
use crate::{req::ReqSocket, Authenticator, ReqOptions};

use super::*;

Expand All @@ -89,7 +89,7 @@ mod tests {
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_req_rep_simple() {
async fn reqrep_simple() {
let _ = tracing_subscriber::fmt::try_init();
let mut rep = RepSocket::new(Tcp::default());
rep.bind(localhost()).await.unwrap();
Expand Down Expand Up @@ -125,21 +125,28 @@ mod tests {
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_req_rep_durable() {
async fn reqrep_durable() {
let _ = tracing_subscriber::fmt::try_init();
let random_port = rand::random::<u16>() + 10000;
let addr = format!("0.0.0.0:{}", random_port);

// Initialize the request socket (client side) with a transport
let mut req = ReqSocket::new(Tcp::default());
let endpoint = addr.parse().unwrap();
// Try to connect even through the server isn't up yet
req.connect(addr.parse().unwrap()).await.unwrap();
let connection_attempt = tokio::spawn(async move {
req.connect(endpoint).await.unwrap();

req
});

// Wait a moment to start the server
tokio::time::sleep(Duration::from_secs(1)).await;
tokio::time::sleep(Duration::from_millis(500)).await;
let mut rep = RepSocket::new(Tcp::default());
rep.bind(addr.parse().unwrap()).await.unwrap();

let req = connection_attempt.await.unwrap();

tokio::spawn(async move {
// Receive the request and respond with "world"
// RepSocket implements `Stream`
Expand All @@ -153,7 +160,7 @@ mod tests {
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_req_rep_auth() {
async fn reqrep_auth() {
struct Auth;

impl Authenticator for Auth {
Expand All @@ -168,9 +175,10 @@ mod tests {
rep.bind(localhost()).await.unwrap();

// Initialize socket with a client ID. This will implicitly enable authentication.
let mut req = ReqSocket::new(Tcp::new(
tcp::Config::default().auth_token(Bytes::from("REQ")),
));
let mut req = ReqSocket::with_options(
Tcp::default(),
ReqOptions::default().auth_token(Bytes::from("REQ")),
);

req.connect(rep.local_addr().unwrap()).await.unwrap();

Expand Down Expand Up @@ -204,7 +212,7 @@ mod tests {
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_rep_max_connections() {
async fn rep_max_connections() {
let _ = tracing_subscriber::fmt::try_init();
let mut rep = RepSocket::with_options(Tcp::default(), RepOptions::default().max_clients(1));
rep.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
Expand Down
8 changes: 8 additions & 0 deletions msg-socket/src/req/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub enum Command {

#[derive(Debug, Clone)]
pub struct ReqOptions {
auth_token: Option<Bytes>,
/// Timeout duration for requests.
timeout: std::time::Duration,
/// Wether to block on initial connection to the target.
Expand All @@ -57,6 +58,12 @@ pub struct ReqOptions {
}

impl ReqOptions {
/// Sets the authentication token for the socket.
pub fn auth_token(mut self, auth_token: Bytes) -> Self {
self.auth_token = Some(auth_token);
self
}

/// Sets the timeout for the socket.
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
Expand Down Expand Up @@ -100,6 +107,7 @@ impl ReqOptions {
impl Default for ReqOptions {
fn default() -> Self {
Self {
auth_token: None,
timeout: std::time::Duration::from_secs(5),
blocking_connect: true,
backoff_duration: Duration::from_millis(200),
Expand Down
Loading

0 comments on commit a79810c

Please sign in to comment.