Skip to content

Commit

Permalink
Merge pull request #70 from chainbound/feat/connection-state
Browse files Browse the repository at this point in the history
feat: `ConnectionState` abstraction
  • Loading branch information
mempirate authored Jan 30, 2024
2 parents 7d7f6dc + 5e5bbdd commit 65605e4
Show file tree
Hide file tree
Showing 15 changed files with 409 additions and 209 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions msg-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ repository.workspace = true
[dependencies]
futures.workspace = true
tokio.workspace = true
tokio-util.workspace = true
59 changes: 52 additions & 7 deletions msg-common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
use futures::future::BoxFuture;
use std::{
pin::Pin,
task::{Context, Poll},
time::SystemTime,
};

use futures::{Sink, SinkExt, Stream};
use tokio::sync::mpsc::{
self,
error::{SendError, TryRecvError, TrySendError},
Receiver, Sender,
error::{TryRecvError, TrySendError},
Receiver,
};
use tokio_util::sync::{PollSendError, PollSender};

pub mod task;

Expand Down Expand Up @@ -36,8 +40,10 @@ pub mod constants {

/// A bounded, bi-directional channel for sending and receiving messages.
/// Relies on Tokio's [`mpsc`] channel.
///
/// Channel also implements the [`Stream`] and [`Sink`] traits for convenience.
pub struct Channel<S, R> {
tx: Sender<S>,
tx: PollSender<S>,
rx: Receiver<R>,
}

Expand All @@ -49,14 +55,21 @@ pub struct Channel<S, R> {
/// the tuple can be used to send messages of type `S` and receive messages of
/// type `R`. The second channel can be used to send messages of type `R` and
/// receive messages of type `S`.
pub fn channel<S, R>(tx_buffer: usize, rx_buffer: usize) -> (Channel<S, R>, Channel<R, S>) {
pub fn channel<S, R>(tx_buffer: usize, rx_buffer: usize) -> (Channel<S, R>, Channel<R, S>)
where
S: Send,
R: Send,
{
let (tx1, rx1) = mpsc::channel(tx_buffer);
let (tx2, rx2) = mpsc::channel(rx_buffer);

let tx1 = PollSender::new(tx1);
let tx2 = PollSender::new(tx2);

(Channel { tx: tx1, rx: rx2 }, Channel { tx: tx2, rx: rx1 })
}

impl<S, R> Channel<S, R> {
impl<S: Send + 'static, R> Channel<S, R> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
Expand All @@ -66,7 +79,7 @@ impl<S, R> Channel<S, R> {
/// value of `Ok` does not mean that the data will be received. It is
/// possible for the corresponding receiver to hang up immediately after
/// this function returns `Ok`.
pub async fn send(&mut self, msg: S) -> Result<(), SendError<S>> {
pub async fn send(&mut self, msg: S) -> Result<(), PollSendError<S>> {
self.tx.send(msg).await
}

Expand All @@ -77,7 +90,11 @@ impl<S, R> Channel<S, R> {
/// with [`send`], this function has two failure cases instead of one (one for
/// disconnection, one for a full buffer).
pub fn try_send(&mut self, msg: S) -> Result<(), TrySendError<S>> {
self.tx.try_send(msg)
if let Some(tx) = self.tx.get_ref() {
tx.try_send(msg)
} else {
Err(TrySendError::Closed(msg))
}
}

/// Receives the next value for this receiver.
Expand Down Expand Up @@ -135,3 +152,31 @@ impl<S, R> Channel<S, R> {
self.rx.poll_recv(cx)
}
}

impl<S, R> Stream for Channel<S, R> {
type Item = R;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
}
}

impl<S: Send + 'static, R> Sink<S> for Channel<S, R> {
type Error = PollSendError<S>;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_ready_unpin(cx)
}

fn start_send(mut self: Pin<&mut Self>, item: S) -> Result<(), Self::Error> {
self.tx.start_send_unpin(item)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_flush_unpin(cx)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_close_unpin(cx)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ use futures::{FutureExt, Stream};
use std::{pin::Pin, task::Poll, time::Duration};
use tokio::time::sleep;

/// Helper trait alias for backoff streams.
/// We define any stream that yields `Duration`s as a backoff
pub trait Backoff: Stream<Item = Duration> + Unpin {}

/// Blanket implementation of `Backoff` for any stream that yields `Duration`s.
impl<T> Backoff for T where T: Stream<Item = Duration> + Unpin {}

/// A stream that yields exponentially increasing backoff durations.
pub struct ExponentialBackoff {
/// Current number of retries.
Expand Down
5 changes: 5 additions & 0 deletions msg-socket/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub mod state;
pub use state::ConnectionState;

pub mod backoff;
pub use backoff::{Backoff, ExponentialBackoff};
34 changes: 34 additions & 0 deletions msg-socket/src/connection/state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::net::SocketAddr;

use super::Backoff;

/// Abstraction to represent the state of a connection.
///
/// * `C` is the channel type, which is used to send and receive generic messages.
/// * `B` is the backoff type, used to control the backoff state for inactive connections.
pub enum ConnectionState<C, B> {
Active {
/// Channel to control the underlying connection. This is used to send
/// and receive any kind of message in any direction.
channel: C,
},
Inactive {
addr: SocketAddr,
/// The current backoff state for inactive connections.
backoff: B,
},
}

impl<C, B: Backoff> ConnectionState<C, B> {
/// Returns `true` if the connection is active.
#[allow(unused)]
pub fn is_active(&self) -> bool {
matches!(self, Self::Active { .. })
}

/// Returns `true` if the connection is inactive.
#[allow(unused)]
pub fn is_inactive(&self) -> bool {
matches!(self, Self::Inactive { .. })
}
}
2 changes: 1 addition & 1 deletion msg-socket/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod rep;
mod req;
mod sub;

mod backoff;
mod connection;

use bytes::Bytes;
pub use pubs::{PubError, PubOptions, PubSocket};
Expand Down
1 change: 0 additions & 1 deletion msg-socket/src/pub/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ where
tokio::spawn(backend);

self.local_addr = Some(local_addr);
// self.to_driver = Some(to_driver);
self.to_sessions_bcast = Some(to_sessions_bcast);

Ok(())
Expand Down
8 changes: 5 additions & 3 deletions msg-socket/src/rep/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,15 @@ mod tests {
let _ = tracing_subscriber::fmt::try_init();
let mut rep = RepSocket::with_options(Tcp::default(), RepOptions::default().max_clients(1));
rep.bind("127.0.0.1:0").await.unwrap();
let addr = rep.local_addr().unwrap();

let mut req1 = ReqSocket::new(Tcp::default());
req1.connect(rep.local_addr().unwrap()).await.unwrap();
req1.connect(addr).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(rep.stats().active_clients(), 1);

let mut req2 = ReqSocket::new(Tcp::default());
req2.connect(rep.local_addr().unwrap()).await.unwrap();

req2.connect(addr).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(rep.stats().active_clients(), 1);
}
Expand Down
Loading

0 comments on commit 65605e4

Please sign in to comment.