From 2626feeec2b6ed2347999d6c44a47972d14e8bc9 Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Tue, 8 Oct 2024 10:28:06 +0200 Subject: [PATCH 1/3] feat: bump MSRV + minor refactoring --- Cargo.lock | 12 ++--- Cargo.toml | 89 ++++++++++++++++++++++++++++++++- README.md | 2 +- book/src/usage/faq.md | 2 +- justfile | 14 ++++++ msg-common/src/channel.rs | 18 +++---- msg-sim/src/dummynet.rs | 4 +- msg-socket/src/lib.rs | 21 +++++--- msg-socket/src/pub/driver.rs | 3 +- msg-socket/src/pub/mod.rs | 31 +++++++----- msg-socket/src/pub/session.rs | 15 ++++-- msg-socket/src/pub/socket.rs | 13 ++--- msg-socket/src/pub/stats.rs | 2 - msg-socket/src/pub/trie.rs | 6 +++ msg-socket/src/rep/driver.rs | 19 ++++--- msg-socket/src/rep/mod.rs | 14 +++--- msg-socket/src/rep/socket.rs | 19 ++++--- msg-socket/src/rep/stats.rs | 4 +- msg-socket/src/req/driver.rs | 18 ++++--- msg-socket/src/req/mod.rs | 15 ++++-- msg-socket/src/req/socket.rs | 6 +-- msg-socket/src/sub/mod.rs | 10 ++-- msg-socket/src/sub/socket.rs | 45 +++++++---------- msg-socket/src/sub/stream.rs | 11 ++-- msg-transport/src/ipc/mod.rs | 8 +-- msg-transport/src/lib.rs | 9 ++-- msg-wire/src/compression/mod.rs | 9 ++-- msg-wire/src/reqrep.rs | 3 ++ msg/Cargo.toml | 10 ++-- 29 files changed, 282 insertions(+), 150 deletions(-) create mode 100644 justfile diff --git a/Cargo.lock b/Cargo.lock index f190edd..52d229f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -713,7 +713,7 @@ dependencies = [ [[package]] name = "msg" -version = "0.1.2" +version = "0.1.3" dependencies = [ "bytes", "criterion", @@ -733,7 +733,7 @@ dependencies = [ [[package]] name = "msg-common" -version = "0.1.2" +version = "0.1.3" dependencies = [ "futures", "tokio", @@ -742,14 +742,14 @@ dependencies = [ [[package]] name = "msg-sim" -version = "0.1.2" +version = "0.1.3" dependencies = [ "pnet", ] [[package]] name = "msg-socket" -version = "0.1.2" +version = "0.1.3" dependencies = [ "bytes", "futures", @@ -770,7 +770,7 @@ dependencies = [ [[package]] name = "msg-transport" -version = "0.1.2" +version = "0.1.3" dependencies = [ "async-trait", "futures", @@ -786,7 +786,7 @@ dependencies = [ [[package]] name = "msg-wire" -version = "0.1.2" +version = "0.1.3" dependencies = [ "bytes", "flate2", diff --git a/Cargo.toml b/Cargo.toml index 35960fc..5fe2e46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,9 @@ members = [ resolver = "2" [workspace.package] -version = "0.1.2" +version = "0.1.3" edition = "2021" -rust-version = "1.70" +rust-version = "1.75" license = "MIT" description = "A flexible and lightweight messaging library for distributed systems" authors = ["Jonas Bostoen", "Nicolas Racchi"] @@ -72,3 +72,88 @@ opt-level = 3 [profile.debug-maxperf] inherits = "maxperf" debug = true + +[workspace.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[workspace.lints.rustdoc] +all = "warn" + +[workspace.lints.rust] +missing_debug_implementations = "warn" +missing_docs = "warn" +rust-2018-idioms = { level = "deny", priority = -1 } +unreachable-pub = "warn" +unused-must-use = "deny" + +[workspace.lints.clippy] +# These are some of clippy's nursery (i.e., experimental) lints that we like. +# By default, nursery lints are allowed. Some of the lints below have made good +# suggestions which we fixed. The others didn't have any findings, so we can +# assume they don't have that many false positives. Let's enable them to +# prevent future problems. +branches_sharing_code = "warn" +clear_with_drain = "warn" +derive_partial_eq_without_eq = "warn" +doc_markdown = "warn" +empty_line_after_doc_comments = "warn" +empty_line_after_outer_attr = "warn" +enum_glob_use = "warn" +equatable_if_let = "warn" +explicit_into_iter_loop = "warn" +explicit_iter_loop = "warn" +flat_map_option = "warn" +imprecise_flops = "warn" +iter_on_empty_collections = "warn" +iter_on_single_items = "warn" +iter_with_drain = "warn" +iter_without_into_iter = "warn" +large_stack_frames = "warn" +manual_assert = "warn" +manual_clamp = "warn" +manual_string_new = "warn" +match_same_arms = "warn" +missing_const_for_fn = "warn" +mutex_integer = "warn" +naive_bytecount = "warn" +needless_bitwise_bool = "warn" +needless_continue = "warn" +needless_pass_by_ref_mut = "warn" +nonstandard_macro_braces = "warn" +or_fun_call = "warn" +path_buf_push_overwrite = "warn" +read_zero_byte_vec = "warn" +redundant_clone = "warn" +single_char_pattern = "warn" +string_lit_as_bytes = "warn" +suboptimal_flops = "warn" +suspicious_operation_groupings = "warn" +trailing_empty_array = "warn" +trait_duplication_in_bounds = "warn" +transmute_undefined_repr = "warn" +trivial_regex = "warn" +tuple_array_conversions = "warn" +type_repetition_in_bounds = "warn" +uninhabited_references = "warn" +unnecessary_struct_initialization = "warn" +unused_peekable = "warn" +unused_rounding = "warn" +use_self = "warn" +useless_let_if_seq = "warn" +zero_sized_map_values = "warn" + +# These are nursery lints which have findings. Allow them for now. Some are not +# quite mature enough for use in our codebase and some we don't really want. +# Explicitly listing should make it easier to fix in the future. +as_ptr_cast_mut = "allow" +cognitive_complexity = "allow" +collection_is_never_read = "allow" +debug_assert_with_mut_call = "allow" +fallible_impl_from = "allow" +future_not_send = "allow" +needless_collect = "allow" +non_send_fields_in_send_ty = "allow" +redundant_pub_crate = "allow" +significant_drop_in_scrutinee = "allow" +significant_drop_tightening = "allow" diff --git a/README.md b/README.md index 34e98e7..8afc392 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ The 📖 [MSG-RS Book][book] contains detailed information on how to use the lib ## MSRV -The minimum supported Rust version is 1.70. +The minimum supported Rust version is 1.75. ## Contributions & Bug Reports diff --git a/book/src/usage/faq.md b/book/src/usage/faq.md index 51aa63f..5701f38 100644 --- a/book/src/usage/faq.md +++ b/book/src/usage/faq.md @@ -15,4 +15,4 @@ Until then, we recommend using the git dependency as shown in the [Getting start ## What is the minimum supported Rust version (MSRV)? -MSG-RS currently supports Rust 1.70 or later. +MSG-RS currently supports Rust 1.75 or later. diff --git a/justfile b/justfile new file mode 100644 index 0000000..5ced516 --- /dev/null +++ b/justfile @@ -0,0 +1,14 @@ +default: check doc fmt + +check: + cargo check --workspace --all-features --all-targets + +doc: + cargo doc --workspace --all-features --no-deps --document-private-items + + +fmt: + cargo +nightly fmt --all -- --check + +test: + cargo nextest run --workspace --retries 3 \ No newline at end of file diff --git a/msg-common/src/channel.rs b/msg-common/src/channel.rs index 03ccc74..aac05da 100644 --- a/msg-common/src/channel.rs +++ b/msg-common/src/channel.rs @@ -56,11 +56,11 @@ impl Channel { self.tx.send(msg).await } - /// Attempts to immediately send a message on this [`Sender`] + /// Attempts to immediately send a message on this channel. /// - /// This method differs from [`send`] by returning immediately if the channel's + /// This method differs from `send` by returning immediately if the channel's /// buffer is full or no receiver is waiting to acquire some data. Compared - /// with [`send`], this function has two failure cases instead of one (one for + /// with `send`, this function has two failure cases instead of one (one for /// disconnection, one for a full buffer). pub fn try_send(&mut self, msg: S) -> Result<(), TrySendError> { if let Some(tx) = self.tx.get_ref() { @@ -70,17 +70,17 @@ impl Channel { } } - /// Receives the next value for this receiver. + /// Receives the next value for this channel. /// /// This method returns `None` if the channel has been closed and there are /// no remaining messages in the channel's buffer. This indicates that no /// further values can ever be received from this `Receiver`. The channel is - /// closed when all senders have been dropped, or when [`close`] is called. + /// closed when all senders have been dropped, or when `close` is called. /// /// If there are no messages in the channel's buffer, but the channel has /// not yet been closed, this method will sleep until a message is sent or - /// the channel is closed. Note that if [`close`] is called, but there are - /// still outstanding [`Permits`] from before it was closed, the channel is + /// the channel is closed. Note that if `close` is called, but there are + /// still outstanding `Permits` from before it was closed, the channel is /// not considered closed by `recv` until the permits are released. pub async fn recv(&mut self) -> Option { self.rx.recv().await @@ -89,10 +89,10 @@ impl Channel { /// Tries to receive the next value for this receiver. /// /// This method returns the [`Empty`](TryRecvError::Empty) error if the channel is currently - /// empty, but there are still outstanding [senders] or [permits]. + /// empty, but there are still outstanding senders or permits. /// /// This method returns the [`Disconnected`](TryRecvError::Disconnected) error if the channel is - /// currently empty, and there are no outstanding [senders] or [permits]. + /// currently empty, and there are no outstanding senders or permits. /// /// Unlike the [`poll_recv`](Self::poll_recv) method, this method will never return an /// [`Empty`](TryRecvError::Empty) error spuriously. diff --git a/msg-sim/src/dummynet.rs b/msg-sim/src/dummynet.rs index b9a3e8c..1125c7d 100644 --- a/msg-sim/src/dummynet.rs +++ b/msg-sim/src/dummynet.rs @@ -271,13 +271,13 @@ fn get_loopback_name() -> String { } /// Assert that the given status is successful, otherwise return an error with the given message. -/// The type of the error will be `io::ErrorKind::Other`. +/// The type of the error will be [`io::ErrorKind::Other`]. fn assert_status(status: ExitStatus, error: E) -> io::Result<()> where E: Into>, { if !status.success() { - return Err(io::Error::new(io::ErrorKind::Other, error)); + return Err(io::Error::other(error)); } Ok(()) diff --git a/msg-socket/src/lib.rs b/msg-socket/src/lib.rs index 1705d11..fa2e3bb 100644 --- a/msg-socket/src/lib.rs +++ b/msg-socket/src/lib.rs @@ -2,24 +2,31 @@ #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![cfg_attr(not(test), warn(unused_crate_dependencies))] -use msg_transport::Address; +use bytes::Bytes; use tokio::io::{AsyncRead, AsyncWrite}; +use msg_transport::Address; + #[path = "pub/mod.rs"] mod pubs; +pub use pubs::{PubError, PubOptions, PubSocket}; + mod rep; +pub use rep::*; + mod req; +pub use req::*; + mod sub; +pub use sub::*; mod connection; pub use connection::*; -use bytes::Bytes; -pub use pubs::{PubError, PubOptions, PubSocket}; -pub use rep::*; -pub use req::*; -pub use sub::*; +/// The default buffer size for a socket. +const DEFAULT_BUFFER_SIZE: usize = 1024; +/// A request Identifier. pub struct RequestId(u32); impl RequestId { @@ -36,10 +43,12 @@ impl RequestId { } } +/// An interface for authenticating clients, given their ID. pub trait Authenticator: Send + Sync + Unpin + 'static { fn authenticate(&self, id: &Bytes) -> bool; } +/// The result of an authentication attempt. pub(crate) struct AuthResult { id: Bytes, addr: A, diff --git a/msg-socket/src/pub/driver.rs b/msg-socket/src/pub/driver.rs index 8f8c8ea..73b050e 100644 --- a/msg-socket/src/pub/driver.rs +++ b/msg-socket/src/pub/driver.rs @@ -17,7 +17,8 @@ use crate::{AuthResult, Authenticator}; use msg_transport::{Address, PeerAddress, Transport}; use msg_wire::{auth, pubsub}; -#[allow(clippy::type_complexity)] +/// The driver for the publisher socket. This is responsible for accepting incoming connections, +/// authenticating them, and spawning new [`SubscriberSession`]s for each connection. pub(crate) struct PubDriver, A: Address> { /// Session ID counter. pub(super) id_counter: u32, diff --git a/msg-socket/src/pub/mod.rs b/msg-socket/src/pub/mod.rs index 0c16925..75a8e13 100644 --- a/msg-socket/src/pub/mod.rs +++ b/msg-socket/src/pub/mod.rs @@ -1,19 +1,28 @@ -use bytes::Bytes; use std::io; + +use bytes::Bytes; use thiserror::Error; mod driver; -use msg_wire::{ - compression::{CompressionType, Compressor}, - pubsub, -}; + mod session; + mod socket; -mod stats; -mod trie; pub use socket::*; + +mod stats; use stats::SocketStats; +mod trie; + +use msg_wire::{ + compression::{CompressionType, Compressor}, + pubsub, +}; + +/// The default buffer size for the socket. +const DEFAULT_BUFFER_SIZE: usize = 1024; + #[derive(Debug, Error)] pub enum PubError { #[error("IO error: {0:?}")] @@ -28,10 +37,8 @@ pub enum PubError { TopicExists, #[error("Unknown topic: {0}")] UnknownTopic(String), - #[error("Topic closed")] - TopicClosed, - #[error("Transport error: {0:?}")] - Transport(#[from] Box), + #[error("Could not connect to any valid endpoints")] + NoValidEndpoints, } #[derive(Debug)] @@ -55,7 +62,7 @@ impl Default for PubOptions { fn default() -> Self { Self { max_clients: None, - session_buffer_size: 1024, + session_buffer_size: DEFAULT_BUFFER_SIZE, flush_interval: Some(std::time::Duration::from_micros(50)), backpressure_boundary: 8192, min_compress_size: 8192, diff --git a/msg-socket/src/pub/session.rs b/msg-socket/src/pub/session.rs index f000276..2912549 100644 --- a/msg-socket/src/pub/session.rs +++ b/msg-socket/src/pub/session.rs @@ -14,6 +14,9 @@ use tracing::{debug, error, trace, warn}; use super::{trie::PrefixTrie, PubMessage, SocketState}; use msg_wire::pubsub; +/// A subscriber session. This struct represents a single subscriber session, which is a +/// connection to a subscriber. This struct is responsible for handling incoming and outgoing +/// messages, as well as managing the connection state. pub(super) struct SubscriberSession { /// The sequence number of this session. pub(super) seq: u32, @@ -36,6 +39,7 @@ pub(super) struct SubscriberSession { } impl SubscriberSession { + /// Handles outgoing messages to the socket. #[inline] fn on_outgoing(&mut self, msg: PubMessage) { // Check if the message matches the topic filter @@ -50,16 +54,17 @@ impl SubscriberSession { } } + /// Handles incoming messages from the socket. #[inline] fn on_incoming(&mut self, msg: pubsub::Message) { // The only incoming messages we should have are control messages. match msg_to_control(&msg) { ControlMsg::Subscribe(topic) => { - debug!("Subscribing to topic {:?}", topic); + debug!("Subscribing to topic {}", topic); self.topic_filter.insert(&topic) } ControlMsg::Unsubscribe(topic) => { - debug!("Unsubscribing from topic {:?}", topic); + debug!("Unsubscribing from topic {}", topic); self.topic_filter.remove(&topic) } ControlMsg::Close => { @@ -68,6 +73,7 @@ impl SubscriberSession { } } + /// Checks if the connection should be flushed. #[inline] fn should_flush(&mut self, cx: &mut Context<'_>) -> bool { if self.should_flush { @@ -118,7 +124,10 @@ fn msg_to_control(msg: &pubsub::Message) -> ControlMsg { ControlMsg::Close } } else { - warn!("Unkown control message topic, closing session: {:?}", msg.topic()); + warn!( + "Unkown control message topic, closing session: {}", + String::from_utf8_lossy(msg.topic()) + ); ControlMsg::Close } } diff --git a/msg-socket/src/pub/socket.rs b/msg-socket/src/pub/socket.rs index 3b53439..ed7e0af 100644 --- a/msg-socket/src/pub/socket.rs +++ b/msg-socket/src/pub/socket.rs @@ -1,4 +1,4 @@ -use std::{io, net::SocketAddr, path::PathBuf, sync::Arc}; +use std::{net::SocketAddr, path::PathBuf, sync::Arc}; use bytes::Bytes; use futures::stream::FuturesUnordered; @@ -120,10 +120,7 @@ where } let Some(local_addr) = transport.local_addr() else { - return Err(PubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not bind to any valid address", - ))); + return Err(PubError::NoValidEndpoints); }; debug!("Listening on {:?}", local_addr); @@ -149,8 +146,7 @@ where /// Publishes a message to the given topic. If the topic doesn't exist, this is a no-op. pub async fn publish(&self, topic: impl Into, message: Bytes) -> Result<(), PubError> { - let topic = topic.into(); - let mut msg = PubMessage::new(topic, message); + let mut msg = PubMessage::new(topic.into(), message); // We compress here since that way we only have to do it once. // Compression is only done if the message is larger than the @@ -159,8 +155,7 @@ where if len_before > self.options.min_compress_size { if let Some(ref compressor) = self.compressor { msg.compress(compressor.as_ref())?; - - trace!("Compressed message from {} to {} bytes", len_before, msg.payload().len(),); + trace!("Compressed message from {} to {} bytes", len_before, msg.payload().len()); } } diff --git a/msg-socket/src/pub/stats.rs b/msg-socket/src/pub/stats.rs index 0fd37a7..b5ea3bf 100644 --- a/msg-socket/src/pub/stats.rs +++ b/msg-socket/src/pub/stats.rs @@ -8,8 +8,6 @@ pub struct SocketStats { bytes_tx: AtomicUsize, /// Total number of active request clients active_clients: AtomicUsize, - // / Total number of dropped messages due to a slow consumer - // dropped_messages: AtomicUsize, } impl SocketStats { diff --git a/msg-socket/src/pub/trie.rs b/msg-socket/src/pub/trie.rs index 13bb830..ae9c9be 100644 --- a/msg-socket/src/pub/trie.rs +++ b/msg-socket/src/pub/trie.rs @@ -2,6 +2,7 @@ use std::collections::hash_map::Entry; use rustc_hash::FxHashMap; +/// A node in the prefix trie. struct Node { children: FxHashMap, catch_all: bool, @@ -14,6 +15,11 @@ impl Node { } } +/// A prefix trie for matching topics. +/// +/// This trie is used to match topics in a NATS-like system. It supports wildcards: +/// - `*` matches a single token. +/// - `>` matches one or more tokens. pub(super) struct PrefixTrie { root: Node, } diff --git a/msg-socket/src/rep/driver.rs b/msg-socket/src/rep/driver.rs index a1596f4..1372a18 100644 --- a/msg-socket/src/rep/driver.rs +++ b/msg-socket/src/rep/driver.rs @@ -17,7 +17,7 @@ use tokio_stream::{StreamMap, StreamNotifyClose}; use tokio_util::codec::Framed; use tracing::{debug, error, info, trace, warn}; -use crate::{rep::SocketState, AuthResult, Authenticator, PubError, RepOptions, Request}; +use crate::{rep::SocketState, AuthResult, Authenticator, RepOptions, Request}; use msg_transport::{Address, PeerAddress, Transport}; use msg_wire::{ @@ -26,6 +26,8 @@ use msg_wire::{ reqrep, }; +use super::RepError; + pub(crate) struct PeerState { pending_requests: FuturesUnordered, conn: Framed, @@ -57,7 +59,7 @@ pub(crate) struct RepDriver, A: Address> { /// A set of pending incoming connections, represented by [`Transport::Accept`]. pub(super) conn_tasks: FuturesUnordered, /// A joinset of authentication tasks. - pub(crate) auth_tasks: JoinSet, PubError>>, + pub(crate) auth_tasks: JoinSet, RepError>>, } impl Future for RepDriver @@ -65,7 +67,7 @@ where T: Transport + Unpin + 'static, A: Address, { - type Output = Result<(), PubError>; + type Output = Result<(), RepError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); @@ -198,12 +200,13 @@ where let mut conn = Framed::new(io, auth::Codec::new_server()); debug!("Waiting for auth"); + // Wait for the response let auth = conn .next() .await - .ok_or(PubError::SocketClosed)? - .map_err(|e| PubError::Auth(e.to_string()))?; + .ok_or(RepError::SocketClosed)? + .map_err(|e| RepError::Auth(e.to_string()))?; debug!("Auth received: {:?}", auth); @@ -211,7 +214,7 @@ where conn.send(auth::Message::Reject).await?; conn.flush().await?; conn.close().await?; - return Err(PubError::Auth("Invalid auth message".to_string())); + return Err(RepError::Auth("Invalid auth message".to_string())); }; // If authentication fails, send a reject message and close the connection @@ -219,7 +222,7 @@ where conn.send(auth::Message::Reject).await?; conn.flush().await?; conn.close().await?; - return Err(PubError::Auth("Authentication failed".to_string())); + return Err(RepError::Auth("Authentication failed".to_string())); } // Send ack @@ -248,7 +251,7 @@ where } impl Stream for PeerState { - type Item = Result, PubError>; + type Item = Result, RepError>; /// Advances the state of the peer. fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs index 9067311..773f608 100644 --- a/msg-socket/src/rep/mod.rs +++ b/msg-socket/src/rep/mod.rs @@ -9,10 +9,9 @@ mod stats; pub use socket::*; use stats::SocketStats; -const DEFAULT_BUFFER_SIZE: usize = 1024; - +/// Errors that can occur when using a reply socket. #[derive(Debug, Error)] -pub enum PubError { +pub enum RepError { #[error("IO error: {0:?}")] Io(#[from] std::io::Error), #[error("Wire protocol error: {0:?}")] @@ -21,10 +20,11 @@ pub enum PubError { Auth(String), #[error("Socket closed")] SocketClosed, - #[error("Transport error: {0:?}")] - Transport(#[from] Box), + #[error("Could not connect to any valid endpoints")] + NoValidEndpoints, } +/// The reply socket options. pub struct RepOptions { /// The maximum number of concurrent clients. max_clients: Option, @@ -82,8 +82,8 @@ impl Request { } /// Responds to the request. - pub fn respond(self, response: Bytes) -> Result<(), PubError> { - self.response.send(response).map_err(|_| PubError::SocketClosed) + pub fn respond(self, response: Bytes) -> Result<(), RepError> { + self.response.send(response).map_err(|_| RepError::SocketClosed) } } diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs index bc7f2b4..767ebb6 100644 --- a/msg-socket/src/rep/socket.rs +++ b/msg-socket/src/rep/socket.rs @@ -1,5 +1,4 @@ use std::{ - io, net::SocketAddr, path::PathBuf, pin::Pin, @@ -17,8 +16,8 @@ use tokio_stream::StreamMap; use tracing::{debug, warn}; use crate::{ - rep::{driver::RepDriver, SocketState, SocketStats, DEFAULT_BUFFER_SIZE}, - Authenticator, PubError, RepOptions, Request, + rep::{driver::RepDriver, RepError, SocketState, SocketStats}, + Authenticator, RepOptions, Request, DEFAULT_BUFFER_SIZE, }; use msg_transport::{Address, Transport}; @@ -48,7 +47,8 @@ impl RepSocket where T: Transport + Send + Unpin + 'static, { - pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> { + /// Binds the socket to the given socket address. + pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), RepError> { let addrs = lookup_host(addr).await?; self.try_bind(addrs.collect()).await } @@ -58,7 +58,8 @@ impl RepSocket where T: Transport + Send + Unpin + 'static, { - pub async fn bind(&mut self, path: impl Into) -> Result<(), PubError> { + /// Binds the socket to the given path. + pub async fn bind(&mut self, path: impl Into) -> Result<(), RepError> { let addr = path.into().clone(); self.try_bind(vec![addr]).await } @@ -100,7 +101,7 @@ where } /// Binds the socket to the given address. This spawns the socket driver task. - pub async fn try_bind(&mut self, addresses: Vec) -> Result<(), PubError> { + pub async fn try_bind(&mut self, addresses: Vec) -> Result<(), RepError> { let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE); let mut transport = self.transport.take().expect("Transport has been moved already"); @@ -116,10 +117,7 @@ where } let Some(local_addr) = transport.local_addr() else { - return Err(PubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not bind to any valid address", - ))); + return Err(RepError::NoValidEndpoints); }; debug!("Listening on {:?}", local_addr); @@ -144,6 +142,7 @@ where Ok(()) } + /// Returns the statistics for this socket. pub fn stats(&self) -> &SocketStats { &self.state.stats } diff --git a/msg-socket/src/rep/stats.rs b/msg-socket/src/rep/stats.rs index e3249d3..41e23b0 100644 --- a/msg-socket/src/rep/stats.rs +++ b/msg-socket/src/rep/stats.rs @@ -1,7 +1,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; -/// Statistics for a reply socket. These are shared between the driver task -/// and the socket. +/// Statistics for a reply socket. +/// These are shared between the driver task and the socket. #[derive(Debug, Default)] pub struct SocketStats { /// Total bytes sent diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index 05d133c..b793ff4 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -37,7 +37,6 @@ type ConnectionCtl = ConnectionState, Expone /// the the socket forward. pub(crate) struct ReqDriver, A: Address> { /// Options shared with the socket. - #[allow(unused)] pub(crate) options: Arc, /// State shared with the socket. pub(crate) socket_state: Arc, @@ -139,6 +138,7 @@ where fn on_message(&mut self, msg: reqrep::Message) { if let Some(pending) = self.pending_requests.remove(&msg.id()) { let rtt = pending.start.elapsed().as_micros() as usize; + let size = msg.size(); let compression_type = msg.header().compression_type(); let mut payload = msg.into_payload(); @@ -148,9 +148,7 @@ where Ok(decompressed) => payload = decompressed, Err(e) => { error!(err = ?e, "Failed to decompress response payload"); - let _ = pending.sender.send(Err(ReqError::Wire(reqrep::Error::Io( - io::Error::new(io::ErrorKind::Other, "Failed to decompress response"), - )))); + let _ = pending.sender.send(Err(ReqError::Wire(reqrep::Error::Decompression))); return; } } @@ -169,6 +167,7 @@ where Command::Send { mut message, response } => { let start = std::time::Instant::now(); + // Compress the message if it's larger than the minimum size let len_before = message.payload().len(); if len_before > self.options.min_compress_size { if let Some(ref compressor) = self.compressor { @@ -193,8 +192,11 @@ where } } + /// Check for request timeouts and notify the sender if any requests have timed out. + /// This is done periodically by the driver. fn check_timeouts(&mut self) { let now = Instant::now(); + let timed_out_ids = self .pending_requests .iter() @@ -214,6 +216,7 @@ where } } + /// Check if the connection should be flushed. #[inline] fn should_flush(&mut self, cx: &mut Context<'_>) -> bool { if self.should_flush { @@ -233,6 +236,8 @@ where } } + /// Reset the connection state to inactive, so that it will be re-tried. + /// This is done when the connection is closed or an error occurs. #[inline] fn reset_connection(&mut self) { self.conn_state = ConnectionState::Inactive { @@ -316,10 +321,7 @@ where } Poll::Ready(Some(Err(err))) => { if let reqrep::Error::Io(e) = err { - error!(err = ?e, "Socket error"); - if e.kind() == std::io::ErrorKind::Other { - error!(err = ?e, "Other error"); - } + error!(err = ?e, "Socket wire error"); } // set the connection to inactive, so that it will be re-tried diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs index e53bc25..5f50026 100644 --- a/msg-socket/src/req/mod.rs +++ b/msg-socket/src/req/mod.rs @@ -16,30 +16,36 @@ pub use socket::*; use self::stats::SocketStats; +/// The default buffer size for the socket. const DEFAULT_BUFFER_SIZE: usize = 1024; +/// Errors that can occur when using a request socket. #[derive(Debug, Error)] pub enum ReqError { #[error("IO error: {0:?}")] Io(#[from] std::io::Error), - #[error("Authentication error: {0:?}")] - Auth(String), #[error("Wire protocol error: {0:?}")] Wire(#[from] reqrep::Error), + #[error("Authentication error: {0}")] + Auth(String), #[error("Socket closed")] SocketClosed, - #[error("Transport error: {0:?}")] - Transport(#[from] Box), #[error("Request timed out")] Timeout, + #[error("Could not connect to any valid endpoints")] + NoValidEndpoints, } +/// Commands that can be sent to the request socket driver. pub enum Command { + /// Send a request message and wait for a response. Send { message: ReqMessage, response: oneshot::Sender> }, } +/// The request socket options. #[derive(Debug, Clone)] pub struct ReqOptions { + /// Optional authentication token. auth_token: Option, /// Timeout duration for requests. timeout: std::time::Duration, @@ -139,7 +145,6 @@ pub struct ReqMessage { payload: Bytes, } -#[allow(unused)] impl ReqMessage { pub fn new(payload: Bytes) -> Self { Self { diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 75c39eb..54ac5c5 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use rustc_hash::FxHashMap; -use std::{io, marker::PhantomData, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; +use std::{marker::PhantomData, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; use tokio::{ net::{lookup_host, ToSocketAddrs}, sync::{mpsc, oneshot}, @@ -40,9 +40,7 @@ where /// Connects to the target address with the default options. pub async fn connect(&mut self, addr: impl ToSocketAddrs) -> Result<(), ReqError> { let mut addrs = lookup_host(addr).await?; - let endpoint = addrs.next().ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidInput, "could not find any valid address") - })?; + let endpoint = addrs.next().ok_or_else(|| ReqError::NoValidEndpoints)?; self.try_connect(endpoint).await } diff --git a/msg-socket/src/sub/mod.rs b/msg-socket/src/sub/mod.rs index ccf867e..7129511 100644 --- a/msg-socket/src/sub/mod.rs +++ b/msg-socket/src/sub/mod.rs @@ -19,13 +19,13 @@ mod stream; use msg_transport::Address; use msg_wire::pubsub; -const DEFAULT_BUFFER_SIZE: usize = 1024; +use crate::DEFAULT_BUFFER_SIZE; #[derive(Debug, Error)] pub enum SubError { #[error("IO error: {0:?}")] Io(#[from] std::io::Error), - #[error("Authentication error: {0:?}")] + #[error("Authentication error: {0}")] Auth(String), #[error("Wire protocol error: {0:?}")] Wire(#[from] pubsub::Error), @@ -33,8 +33,10 @@ pub enum SubError { SocketClosed, #[error("Command channel full")] ChannelFull, - #[error("Transport error: {0:?}")] - Transport(#[from] Box), + #[error("Could not find any valid endpoints")] + NoValidEndpoints, + #[error("Reserved topic 'MSG' cannot be used")] + ReservedTopic, } #[derive(Debug)] diff --git a/msg-socket/src/sub/socket.rs b/msg-socket/src/sub/socket.rs index d001b14..173e708 100644 --- a/msg-socket/src/sub/socket.rs +++ b/msg-socket/src/sub/socket.rs @@ -1,6 +1,5 @@ use std::{ collections::HashSet, - io, net::{IpAddr, Ipv4Addr, SocketAddr}, path::PathBuf, pin::Pin, @@ -23,6 +22,7 @@ use super::{ DEFAULT_BUFFER_SIZE, }; +/// A subscriber socket. This socket implements [`Stream`] and yields incoming [`PubMessage`]s. pub struct SubSocket, A: Address> { /// Command channel to the socket driver. to_driver: mpsc::Sender>, @@ -46,10 +46,7 @@ where /// Connects to the given endpoint asynchronously. pub async fn connect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> { let mut addrs = lookup_host(endpoint).await?; - let mut endpoint = addrs.next().ok_or(SubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - )))?; + let mut endpoint = addrs.next().ok_or(SubError::NoValidEndpoints)?; // Some transport implementations (e.g. Quinn) can't dial an unspecified // IP address, so replace it with localhost. @@ -64,12 +61,7 @@ where /// Attempts to connect to the given endpoint immediately. pub fn try_connect(&mut self, endpoint: impl Into) -> Result<(), SubError> { let addr = endpoint.into(); - let mut endpoint: SocketAddr = addr.parse().map_err(|_| { - SubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - )) - })?; + let mut endpoint: SocketAddr = addr.parse().map_err(|_| SubError::NoValidEndpoints)?; // Some transport implementations (e.g. Quinn) can't dial an unspecified // IP address, so replace it with localhost. @@ -84,10 +76,7 @@ where /// Disconnects from the given endpoint asynchronously. pub async fn disconnect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> { let mut addrs = lookup_host(endpoint).await?; - let endpoint = addrs.next().ok_or(SubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - )))?; + let endpoint = addrs.next().ok_or(SubError::NoValidEndpoints)?; self.disconnect_inner(endpoint).await } @@ -95,12 +84,7 @@ where /// Attempts to disconnect from the given endpoint immediately. pub fn try_disconnect(&mut self, endpoint: impl Into) -> Result<(), SubError> { let endpoint = endpoint.into(); - let endpoint: SocketAddr = endpoint.parse().map_err(|_| { - SubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - )) - })?; + let endpoint: SocketAddr = endpoint.parse().map_err(|_| SubError::NoValidEndpoints)?; self.try_disconnect_inner(endpoint) } @@ -136,11 +120,12 @@ where T: Transport + Send + Sync + Unpin + 'static, A: Address, { - #[allow(clippy::new_without_default)] + /// Creates a new subscriber socket with the default [`SubOptions`]. pub fn new(transport: T) -> Self { Self::with_options(transport, SubOptions::default()) } + /// Creates a new subscriber socket with the given transport and options. pub fn with_options(transport: T, options: SubOptions) -> Self { let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); let (to_socket, from_driver) = mpsc::channel(options.ingress_buffer_size); @@ -208,7 +193,9 @@ where self.ensure_active_driver(); let topic = topic.into(); - assert!(!topic.starts_with("MSG"), "MSG is a reserved topic"); + if topic.starts_with("MSG") { + return Err(SubError::ReservedTopic); + } self.send_command(Command::Subscribe { topic }).await?; @@ -220,7 +207,9 @@ where self.ensure_active_driver(); let topic = topic.into(); - assert!(!topic.starts_with("MSG"), "MSG is a reserved topic"); + if topic.starts_with("MSG") { + return Err(SubError::ReservedTopic); + } self.try_send_command(Command::Subscribe { topic })?; @@ -232,7 +221,9 @@ where self.ensure_active_driver(); let topic = topic.into(); - assert!(!topic.starts_with("MSG"), "MSG is a reserved topic"); + if topic.starts_with("MSG") { + return Err(SubError::ReservedTopic); + } self.send_command(Command::Unsubscribe { topic }).await?; @@ -244,7 +235,9 @@ where self.ensure_active_driver(); let topic = topic.into(); - assert!(!topic.starts_with("MSG"), "MSG is a reserved topic"); + if topic.starts_with("MSG") { + return Err(SubError::ReservedTopic); + } self.try_send_command(Command::Unsubscribe { topic })?; diff --git a/msg-socket/src/sub/stream.rs b/msg-socket/src/sub/stream.rs index 2792403..6acac75 100644 --- a/msg-socket/src/sub/stream.rs +++ b/msg-socket/src/sub/stream.rs @@ -1,14 +1,16 @@ -use bytes::Bytes; -use futures::{SinkExt, Stream, StreamExt}; use std::{ pin::Pin, task::{ready, Context, Poll}, }; + +use bytes::Bytes; +use futures::{SinkExt, Stream, StreamExt}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::Framed; use tracing::{debug, trace}; use super::SubError; + use msg_wire::pubsub; /// Wraps a framed connection to a publisher and exposes all the PUBSUB specific methods. @@ -19,8 +21,7 @@ pub(super) struct PublisherStream { impl PublisherStream { /// Queues a message to be sent to the publisher. If the connection - /// is ready, this will register the waker - /// and flush on the next poll. + /// is ready, this will register the waker and flush on the next poll. pub fn poll_send( &mut self, cx: &mut Context<'_>, @@ -48,6 +49,7 @@ impl From> for Pub } } +/// A message received from a stream. pub(super) struct TopicMessage { pub timestamp: u64, pub compression_type: u8, @@ -58,7 +60,6 @@ pub(super) struct TopicMessage { impl Stream for PublisherStream { type Item = Result; - #[inline] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); diff --git a/msg-transport/src/ipc/mod.rs b/msg-transport/src/ipc/mod.rs index 1a72cb8..40b1b24 100644 --- a/msg-transport/src/ipc/mod.rs +++ b/msg-transport/src/ipc/mod.rs @@ -109,10 +109,10 @@ impl Transport for Ipc { if addr.exists() { debug!("Socket file already exists. Attempting to remove."); if let Err(e) = std::fs::remove_file(&addr) { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("Failed to remove existing socket file, {:?}", e), - )); + return Err(io::Error::other(format!( + "Failed to remove existing socket file, {:?}", + e + ))); } } diff --git a/msg-transport/src/lib.rs b/msg-transport/src/lib.rs index f817ab0..346ccee 100644 --- a/msg-transport/src/lib.rs +++ b/msg-transport/src/lib.rs @@ -45,12 +45,11 @@ pub trait Transport { /// An error that occurred when setting up the connection. type Error: std::error::Error + From + Send + Sync; - /// A pending [`Transport::Output`] for an outbound connection, - /// obtained when calling [`Transport::connect`]. + /// A pending output for an outbound connection, obtained when calling [`Transport::connect`]. type Connect: Future> + Send; - /// A pending [`Transport::Output`] for an inbound connection, - /// obtained when calling [`Transport::poll_accept`]. + /// A pending output for an inbound connection, obtained when calling + /// [`Transport::poll_accept`]. type Accept: Future> + Send + Unpin; /// Returns the local address this transport is bound to (if it is bound). @@ -64,7 +63,7 @@ pub trait Transport { fn connect(&mut self, addr: A) -> Self::Connect; /// Poll for incoming connections. If an inbound connection is received, a future representing - /// a pending inbound connection is returned. The future will resolve to [`Transport::Output`]. + /// a pending inbound connection is returned. The future will resolve to [`Transport::Accept`]. fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll; } diff --git a/msg-wire/src/compression/mod.rs b/msg-wire/src/compression/mod.rs index 5934524..8a32bfc 100644 --- a/msg-wire/src/compression/mod.rs +++ b/msg-wire/src/compression/mod.rs @@ -2,12 +2,15 @@ use bytes::Bytes; use std::io; mod gzip; -mod lz4; -mod snappy; -mod zstd; pub use gzip::*; + +mod lz4; pub use lz4::*; + +mod snappy; pub use snappy::*; + +mod zstd; pub use zstd::*; /// The possible compression type used for a message. diff --git a/msg-wire/src/reqrep.rs b/msg-wire/src/reqrep.rs index 77c4337..2de3f24 100644 --- a/msg-wire/src/reqrep.rs +++ b/msg-wire/src/reqrep.rs @@ -11,10 +11,13 @@ pub enum Error { Io(#[from] std::io::Error), #[error("Invalid wire ID: {0}")] WireId(u8), + #[error("Failed to decompress message")] + Decompression, } #[derive(Debug, Clone)] pub struct Message { + /// The message header. header: Header, /// The message payload. payload: Bytes, diff --git a/msg/Cargo.toml b/msg/Cargo.toml index c2f42cd..8a2e005 100644 --- a/msg/Cargo.toml +++ b/msg/Cargo.toml @@ -16,15 +16,11 @@ msg-transport.workspace = true msg-wire.workspace = true tokio.workspace = true -bytes.workspace = true tokio-stream.workspace = true [dev-dependencies] -# benchmarking +bytes.workspace = true tracing-subscriber = "0.3" -# Add jemalloc for extra perf on Linux systems. -[target.'cfg(all(not(windows), not(target_env = "musl")))'.dependencies] -jemallocator = { version = "0.5.0", features = ["profiling"] } divan = "0.1" futures.workspace = true tracing.workspace = true @@ -32,6 +28,10 @@ rand.workspace = true criterion.workspace = true pprof.workspace = true +# Add jemalloc for extra perf on Linux systems. +[target.'cfg(all(not(windows), not(target_env = "musl")))'.dependencies] +jemallocator = { version = "0.5.0", features = ["profiling"] } + [[bench]] name = "reqrep" harness = false From 86046f5ad61ee19e3e11577d72aee1de6d996b12 Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Wed, 9 Oct 2024 18:21:41 +0200 Subject: [PATCH 2/3] chore: update actions --- .github/workflows/ci.yml | 32 ++++++++++++-------------------- justfile | 6 ++++-- msg-transport/src/lib.rs | 2 +- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1b8bbe9..e558911 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,14 +8,12 @@ jobs: timeout-minutes: 20 steps: - name: Checkout sources - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Rust toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: toolchain: nightly - profile: minimal - override: true - - uses: Swatinem/rust-cache@v1 + - uses: Swatinem/rust-cache@v2 with: cache-on-failure: true - name: cargo test @@ -28,15 +26,13 @@ jobs: timeout-minutes: 20 steps: - name: Checkout sources - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Rust toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: toolchain: nightly - profile: minimal components: rustfmt, clippy - override: true - - uses: Swatinem/rust-cache@v1 + - uses: Swatinem/rust-cache@v2 with: cache-on-failure: true - name: cargo fmt @@ -50,14 +46,12 @@ jobs: continue-on-error: true steps: - name: Checkout sources - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Rust toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: toolchain: nightly - profile: minimal - override: true - - uses: Swatinem/rust-cache@v1 + - uses: Swatinem/rust-cache@v2 with: cache-on-failure: true - name: build @@ -71,14 +65,12 @@ jobs: continue-on-error: true steps: - name: Checkout sources - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Rust toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: toolchain: nightly - profile: minimal - override: true - - uses: Swatinem/rust-cache@v1 + - uses: Swatinem/rust-cache@v2 with: cache-on-failure: true - name: doclint diff --git a/justfile b/justfile index 5ced516..2e27e87 100644 --- a/justfile +++ b/justfile @@ -1,4 +1,4 @@ -default: check doc fmt +default: check doc fmt clippy check: cargo check --workspace --all-features --all-targets @@ -6,9 +6,11 @@ check: doc: cargo doc --workspace --all-features --no-deps --document-private-items +clippy: + cargo +nightly clippy --all --all-features -- -D warnings fmt: cargo +nightly fmt --all -- --check test: - cargo nextest run --workspace --retries 3 \ No newline at end of file + cargo nextest run --workspace --retries 3 diff --git a/msg-transport/src/lib.rs b/msg-transport/src/lib.rs index 346ccee..8f47e6c 100644 --- a/msg-transport/src/lib.rs +++ b/msg-transport/src/lib.rs @@ -88,7 +88,7 @@ impl<'a, T, A> Acceptor<'a, T, A> { } } -impl<'a, T, A> Future for Acceptor<'a, T, A> +impl Future for Acceptor<'_, T, A> where T: Transport + Unpin, A: Address, From ed012792c642712ebac26d8d7ca534d586cc153b Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Wed, 9 Oct 2024 20:08:20 +0200 Subject: [PATCH 3/3] nit: lint --- msg-socket/src/req/socket.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 54ac5c5..d68be0a 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -40,7 +40,7 @@ where /// Connects to the target address with the default options. pub async fn connect(&mut self, addr: impl ToSocketAddrs) -> Result<(), ReqError> { let mut addrs = lookup_host(addr).await?; - let endpoint = addrs.next().ok_or_else(|| ReqError::NoValidEndpoints)?; + let endpoint = addrs.next().ok_or(ReqError::NoValidEndpoints)?; self.try_connect(endpoint).await }