diff --git a/Cargo.lock b/Cargo.lock index 05c89fa..0b5a802 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -728,6 +728,7 @@ version = "0.1.1" dependencies = [ "futures", "tokio", + "tokio-util", ] [[package]] diff --git a/msg-common/Cargo.toml b/msg-common/Cargo.toml index 4faa305..b138826 100644 --- a/msg-common/Cargo.toml +++ b/msg-common/Cargo.toml @@ -15,3 +15,4 @@ repository.workspace = true [dependencies] futures.workspace = true tokio.workspace = true +tokio-util.workspace = true diff --git a/msg-common/src/lib.rs b/msg-common/src/lib.rs index 06e797c..1f2aedf 100644 --- a/msg-common/src/lib.rs +++ b/msg-common/src/lib.rs @@ -1,13 +1,17 @@ use futures::future::BoxFuture; use std::{ + pin::Pin, task::{Context, Poll}, time::SystemTime, }; + +use futures::{Sink, SinkExt, Stream}; use tokio::sync::mpsc::{ self, - error::{SendError, TryRecvError, TrySendError}, - Receiver, Sender, + error::{TryRecvError, TrySendError}, + Receiver, }; +use tokio_util::sync::{PollSendError, PollSender}; pub mod task; @@ -36,8 +40,10 @@ pub mod constants { /// A bounded, bi-directional channel for sending and receiving messages. /// Relies on Tokio's [`mpsc`] channel. +/// +/// Channel also implements the [`Stream`] and [`Sink`] traits for convenience. pub struct Channel { - tx: Sender, + tx: PollSender, rx: Receiver, } @@ -49,14 +55,21 @@ pub struct Channel { /// the tuple can be used to send messages of type `S` and receive messages of /// type `R`. The second channel can be used to send messages of type `R` and /// receive messages of type `S`. -pub fn channel(tx_buffer: usize, rx_buffer: usize) -> (Channel, Channel) { +pub fn channel(tx_buffer: usize, rx_buffer: usize) -> (Channel, Channel) +where + S: Send, + R: Send, +{ let (tx1, rx1) = mpsc::channel(tx_buffer); let (tx2, rx2) = mpsc::channel(rx_buffer); + let tx1 = PollSender::new(tx1); + let tx2 = PollSender::new(tx2); + (Channel { tx: tx1, rx: rx2 }, Channel { tx: tx2, rx: rx1 }) } -impl Channel { +impl Channel { /// Sends a value, waiting until there is capacity. /// /// A successful send occurs when it is determined that the other end of the @@ -66,7 +79,7 @@ impl Channel { /// value of `Ok` does not mean that the data will be received. It is /// possible for the corresponding receiver to hang up immediately after /// this function returns `Ok`. - pub async fn send(&mut self, msg: S) -> Result<(), SendError> { + pub async fn send(&mut self, msg: S) -> Result<(), PollSendError> { self.tx.send(msg).await } @@ -77,7 +90,11 @@ impl Channel { /// with [`send`], this function has two failure cases instead of one (one for /// disconnection, one for a full buffer). pub fn try_send(&mut self, msg: S) -> Result<(), TrySendError> { - self.tx.try_send(msg) + if let Some(tx) = self.tx.get_ref() { + tx.try_send(msg) + } else { + Err(TrySendError::Closed(msg)) + } } /// Receives the next value for this receiver. @@ -135,3 +152,31 @@ impl Channel { self.rx.poll_recv(cx) } } + +impl Stream for Channel { + type Item = R; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) + } +} + +impl Sink for Channel { + type Error = PollSendError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.tx.poll_ready_unpin(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: S) -> Result<(), Self::Error> { + self.tx.start_send_unpin(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.tx.poll_flush_unpin(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.tx.poll_close_unpin(cx) + } +} diff --git a/msg-socket/src/backoff/mod.rs b/msg-socket/src/connection/backoff.rs similarity index 88% rename from msg-socket/src/backoff/mod.rs rename to msg-socket/src/connection/backoff.rs index 79e362d..4dde9a7 100644 --- a/msg-socket/src/backoff/mod.rs +++ b/msg-socket/src/connection/backoff.rs @@ -2,6 +2,13 @@ use futures::{FutureExt, Stream}; use std::{pin::Pin, task::Poll, time::Duration}; use tokio::time::sleep; +/// Helper trait alias for backoff streams. +/// We define any stream that yields `Duration`s as a backoff +pub trait Backoff: Stream + Unpin {} + +/// Blanket implementation of `Backoff` for any stream that yields `Duration`s. +impl Backoff for T where T: Stream + Unpin {} + /// A stream that yields exponentially increasing backoff durations. pub struct ExponentialBackoff { /// Current number of retries. diff --git a/msg-socket/src/connection/mod.rs b/msg-socket/src/connection/mod.rs new file mode 100644 index 0000000..e16d88b --- /dev/null +++ b/msg-socket/src/connection/mod.rs @@ -0,0 +1,5 @@ +pub mod state; +pub use state::ConnectionState; + +pub mod backoff; +pub use backoff::{Backoff, ExponentialBackoff}; diff --git a/msg-socket/src/connection/state.rs b/msg-socket/src/connection/state.rs new file mode 100644 index 0000000..de5a5f6 --- /dev/null +++ b/msg-socket/src/connection/state.rs @@ -0,0 +1,34 @@ +use std::net::SocketAddr; + +use super::Backoff; + +/// Abstraction to represent the state of a connection. +/// +/// * `C` is the channel type, which is used to send and receive generic messages. +/// * `B` is the backoff type, used to control the backoff state for inactive connections. +pub enum ConnectionState { + Active { + /// Channel to control the underlying connection. This is used to send + /// and receive any kind of message in any direction. + channel: C, + }, + Inactive { + addr: SocketAddr, + /// The current backoff state for inactive connections. + backoff: B, + }, +} + +impl ConnectionState { + /// Returns `true` if the connection is active. + #[allow(unused)] + pub fn is_active(&self) -> bool { + matches!(self, Self::Active { .. }) + } + + /// Returns `true` if the connection is inactive. + #[allow(unused)] + pub fn is_inactive(&self) -> bool { + matches!(self, Self::Inactive { .. }) + } +} diff --git a/msg-socket/src/lib.rs b/msg-socket/src/lib.rs index d3b9c45..9b387a1 100644 --- a/msg-socket/src/lib.rs +++ b/msg-socket/src/lib.rs @@ -7,7 +7,7 @@ mod rep; mod req; mod sub; -mod backoff; +mod connection; use bytes::Bytes; pub use pubs::{PubError, PubOptions, PubSocket}; diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index 7911726..d577a6c 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -12,7 +12,10 @@ use std::{ use tokio::sync::{mpsc, oneshot}; use tokio_util::codec::Framed; -use crate::req::SocketState; +use crate::{ + connection::{ConnectionState, ExponentialBackoff}, + req::SocketState, +}; use super::{Command, ReqError, ReqOptions}; use msg_wire::{ @@ -34,8 +37,9 @@ pub(crate) struct ReqDriver { pub(crate) id_counter: u32, /// Commands from the socket. pub(crate) from_socket: mpsc::Receiver, - /// The actual [`Framed`] connection with the `Req`-specific codec. - pub(crate) conn: Framed, + /// The transport controller, wrapped in a [`ConnectionState`] for backoff. + /// The [`Framed`] object can send and receive messages from the socket. + pub(crate) conn_state: ConnectionState, ExponentialBackoff>, /// The outgoing message queue. pub(crate) egress_queue: VecDeque, /// The currently pending requests, if any. Uses [`FxHashMap`] for performance. @@ -138,34 +142,70 @@ where let this = self.get_mut(); loop { + // Try to flush pending messages if this.should_flush(cx) { - if let Poll::Ready(Ok(_)) = this.conn.poll_flush_unpin(cx) { - this.should_flush = false; + if let ConnectionState::Active { ref mut channel } = this.conn_state { + if let Poll::Ready(Ok(_)) = channel.poll_flush_unpin(cx) { + this.should_flush = false; + } } + + // TODO: what to do with an inactive connection here? } - // Check for incoming messages from the socket - match this.conn.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(msg))) => { - this.on_message(msg); + match this.conn_state { + ConnectionState::Active { ref mut channel } => { + // Check for incoming messages from the socket + match channel.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(msg))) => { + this.on_message(msg); - continue; - } - Poll::Ready(Some(Err(e))) => { - if let reqrep::Error::Io(e) = e { - tracing::error!("Socket error: {:?}", e); - if e.kind() == std::io::ErrorKind::Other { - tracing::error!("Other error: {:?}", e); + continue; + } + Poll::Ready(Some(Err(e))) => { + if let reqrep::Error::Io(e) = e { + tracing::error!("Socket error: {:?}", e); + if e.kind() == std::io::ErrorKind::Other { + tracing::error!("Other error: {:?}", e); + return Poll::Ready(()); + } + } + + continue; + } + Poll::Ready(None) => { + tracing::debug!("Socket closed, shutting down backend"); return Poll::Ready(()); } + Poll::Pending => {} + } + + // Check for outgoing messages to the socket + if channel.poll_ready_unpin(cx).is_ready() { + // Drain the egress queue + if let Some(msg) = this.egress_queue.pop_front() { + // Generate the new message + let size = msg.size(); + tracing::debug!("Sending msg {}", msg.id()); + match channel.start_send_unpin(msg) { + Ok(_) => { + this.socket_state.stats.increment_tx(size); + + this.should_flush = true; + // We might be able to send more queued messages + continue; + } + Err(e) => { + tracing::error!("Failed to send message to socket: {:?}", e); + return Poll::Ready(()); + } + } + } } - continue; } - Poll::Ready(None) => { - tracing::debug!("Socket closed, shutting down backend"); - return Poll::Ready(()); + ConnectionState::Inactive { addr, ref backoff } => { + // TODO: handle backoff in case of an inactive connection } - Poll::Pending => {} } // Check for request timeouts @@ -173,28 +213,6 @@ where this.check_timeouts(); } - if this.conn.poll_ready_unpin(cx).is_ready() { - // Drain the egress queue - if let Some(msg) = this.egress_queue.pop_front() { - // Generate the new message - let size = msg.size(); - tracing::debug!("Sending msg {}", msg.id()); - match this.conn.start_send_unpin(msg) { - Ok(_) => { - this.socket_state.stats.increment_tx(size); - - this.should_flush = true; - // We might be able to send more queued messages - continue; - } - Err(e) => { - tracing::error!("Failed to send message to socket: {:?}", e); - return Poll::Ready(()); - } - } - } - } - // Check for outgoing messages from the socket handle match this.from_socket.poll_recv(cx) { Poll::Ready(Some(Command::Send { @@ -236,7 +254,13 @@ where tracing::debug!( "Socket dropped, shutting down backend and flushing connection" ); - let _ = ready!(this.conn.poll_close_unpin(cx)); + + if let ConnectionState::Active { ref mut channel } = this.conn_state { + let _ = ready!(channel.poll_close_unpin(cx)); + } + + // TODO: handle inactive connection here? + return Poll::Ready(()); } Poll::Pending => {} diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 17dacf4..204b754 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -12,7 +12,7 @@ use msg_transport::Transport; use msg_wire::{auth, reqrep}; use super::{Command, ReqDriver, ReqError, ReqOptions, DEFAULT_BUFFER_SIZE}; -use crate::backoff::ExponentialBackoff; +use crate::connection::{ConnectionState, ExponentialBackoff}; use crate::ReqMessage; use crate::{req::stats::SocketStats, req::SocketState}; @@ -122,13 +122,14 @@ where let mut framed = Framed::new(stream, reqrep::Codec::new()); framed.set_backpressure_boundary(self.options.backpressure_boundary); + let conn = ConnectionState::Active { channel: framed }; // Create the socket backend let driver: ReqDriver = ReqDriver { options: Arc::clone(&self.options), id_counter: 0, from_socket, - conn: framed, + conn_state: conn, egress_queue: VecDeque::new(), // TODO: we should limit the amount of active outgoing requests, and that should be the capacity. // If we do this, we'll never have to re-allocate. diff --git a/msg-socket/src/sub/driver.rs b/msg-socket/src/sub/driver.rs index 9571c7f..8ab2360 100644 --- a/msg-socket/src/sub/driver.rs +++ b/msg-socket/src/sub/driver.rs @@ -1,4 +1,4 @@ -use futures::{Future, SinkExt, Stream, StreamExt}; +use futures::{Future, SinkExt, StreamExt}; use rustc_hash::FxHashMap; use std::{ collections::HashSet, @@ -7,13 +7,12 @@ use std::{ pin::Pin, sync::Arc, task::{Context, Poll}, - time::Duration, }; use tokio::sync::mpsc::{self, error::TrySendError}; use tokio_util::codec::Framed; use tracing::{debug, error, info, warn}; -use crate::backoff::ExponentialBackoff; +use crate::connection::{ConnectionState, ExponentialBackoff}; use super::session::SessionCommand; use super::{ @@ -26,6 +25,10 @@ use msg_common::{channel, task::JoinMap, Channel}; use msg_transport::Transport; use msg_wire::{auth, compression::try_decompress_payload, pubsub}; +/// Publisher channel type, used to send messages to the publisher session +/// and receive messages to forward to the socket frontend. +type PubChannel = Channel; + pub(crate) struct SubDriver { /// Options shared with the socket. pub(super) options: Arc, @@ -39,29 +42,12 @@ pub(crate) struct SubDriver { pub(super) connection_tasks: JoinMap>, /// The set of subscribed topics. pub(super) subscribed_topics: HashSet, - /// All active publisher sessions for this subscriber socket. - pub(super) publishers: FxHashMap>, + /// All publisher sessions for this subscriber socket, keyed by address. + pub(super) publishers: FxHashMap>, /// Socket state. This is shared with the backend task. pub(super) state: Arc, } -/// Represents the state of a publisher. -pub(crate) enum PublisherState -where - S: Stream, -{ - Active { - /// The channel to the publisher session. - channel: Channel, - }, - Inactive { - /// The address of the publisher. - addr: SocketAddr, - /// Exponential backoff for retrying connections. - backoff: S, - }, -} - impl Future for SubDriver where T: Transport + Send + Sync + Unpin + 'static, @@ -78,12 +64,14 @@ where continue; } + // Then, poll the socket for new commands. if let Poll::Ready(Some(cmd)) = this.from_socket.poll_recv(cx) { this.on_command(cmd); continue; } + // Finally, poll the connection tasks for new connections. if let Poll::Ready(Some(Ok((addr, result)))) = this.connection_tasks.poll_join_next(cx) { match result { @@ -107,13 +95,13 @@ impl SubDriver where T: Transport + Send + Sync + 'static, { - /// De-activates a publisher by setting it to [`PublisherState::Inactive`]. This will initialize - /// the backoff stream. + /// De-activates a publisher by setting it to [`ConnectionState::Inactive`]. + /// This will initialize the backoff stream. fn reset_publisher(&mut self, addr: SocketAddr) { tracing::debug!("Resetting publisher at {addr:?}"); self.publishers.insert( addr, - PublisherState::Inactive { + ConnectionState::Inactive { addr, backoff: ExponentialBackoff::new(self.options.initial_backoff, 16), }, @@ -122,7 +110,7 @@ where /// Returns true if we're already connected to the given publisher address. fn is_connected(&self, addr: &SocketAddr) -> bool { - if let Some(PublisherState::Active { .. }) = self.publishers.get(addr) { + if self.publishers.get(addr).is_some_and(|s| s.is_active()) { return true; } @@ -141,7 +129,7 @@ where // Subscribe to the topic on all publishers for (addr, publisher_state) in self.publishers.iter_mut() { - if let PublisherState::Active { channel } = publisher_state { + if let ConnectionState::Active { channel } = publisher_state { // If the channel is closed on the other side, deactivate the publisher if let Err(TrySendError::Closed(_)) = channel.try_send(SessionCommand::Subscribe(topic.clone())) @@ -175,7 +163,7 @@ where if self.subscribed_topics.remove(&topic) { // Unsubscribe from the topic on all publishers for (addr, publisher_state) in self.publishers.iter_mut() { - if let PublisherState::Active { channel } = publisher_state { + if let ConnectionState::Active { channel } = publisher_state { // If the channel is closed on the other side, deactivate the publisher if let Err(TrySendError::Closed(_)) = channel.try_send(SessionCommand::Unsubscribe(topic.clone())) @@ -325,7 +313,7 @@ where let (driver_channel, mut publisher_channel) = channel(1024, 64); let publisher_session = - PublisherSession::new(addr, PublisherStream::new(framed), driver_channel); + PublisherSession::new(addr, PublisherStream::from(framed), driver_channel); // Get the shared session stats. let session_stats = publisher_session.stats(); @@ -344,7 +332,7 @@ where self.publishers.insert( addr, - PublisherState::Active { + ConnectionState::Active { channel: publisher_channel, }, ); @@ -367,7 +355,7 @@ where for (addr, state) in self.publishers.iter_mut() { match state { - PublisherState::Active { channel } => { + ConnectionState::Active { channel } => { match channel.poll_recv(cx) { Poll::Ready(Some(mut msg)) => { match try_decompress_payload(msg.compression_type, msg.payload) { @@ -400,7 +388,7 @@ where Poll::Pending => {} } } - PublisherState::Inactive { addr, backoff } => { + ConnectionState::Inactive { addr, backoff } => { // Poll the backoff stream if let Poll::Ready(item) = backoff.poll_next_unpin(cx) { if let Some(duration) = item { diff --git a/msg-socket/src/sub/stream.rs b/msg-socket/src/sub/stream.rs index af1d7e5..60fc450 100644 --- a/msg-socket/src/sub/stream.rs +++ b/msg-socket/src/sub/stream.rs @@ -18,11 +18,6 @@ pub(super) struct PublisherStream { } impl PublisherStream { - /// Cretes a new publisher stream from the given framed connection. - pub fn new(conn: Framed) -> Self { - Self { conn, flush: false } - } - /// Queues a message to be sent to the publisher. If the connection /// is ready, this will register the waker /// and flush on the next poll. @@ -47,6 +42,12 @@ impl PublisherStream { } } +impl From> for PublisherStream { + fn from(conn: Framed) -> Self { + Self { conn, flush: false } + } +} + pub(super) struct TopicMessage { pub timestamp: u64, pub compression_type: u8, diff --git a/msg/examples/reqrep.rs b/msg/examples/reqrep.rs index 81b238d..286facd 100644 --- a/msg/examples/reqrep.rs +++ b/msg/examples/reqrep.rs @@ -22,6 +22,15 @@ async fn main() { req.respond(Bytes::from("world")).unwrap(); }); - let res: Bytes = req.request(Bytes::from("hello")).await.unwrap(); + let res: Bytes = req.request(Bytes::from("helloooo!")).await.unwrap(); println!("Response: {:?}", res); + + // Access the socket statistics + let stats = req.stats(); + println!( + "Sent: {}B, Received: {}B | time: {}μs", + stats.bytes_tx(), + stats.bytes_rx(), + stats.rtt() + ); }