Skip to content

Commit

Permalink
Merge pull request #66 from chainbound/feat/reqrep/compression
Browse files Browse the repository at this point in the history
feat: reqrep compression
  • Loading branch information
mempirate authored Jan 26, 2024
2 parents 7a4d594 + 1436582 commit 7d7f6dc
Show file tree
Hide file tree
Showing 13 changed files with 384 additions and 79 deletions.
1 change: 1 addition & 0 deletions book/src/links.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
[new-issue]: https://github.com/chainbound/msg-rs/issues/new
[license]: https://github.com/chainbound/msg-rs/blob/main/LICENSE
[contributing]: https://github.com/chainbound/msg-rs/tree/main/CONTRIBUTING.md
[examples]: https://github.com/chainbound/msg-rs/tree/main/msg/examples

<!-- Chainbound links -->

Expand Down
44 changes: 40 additions & 4 deletions book/src/usage/compression.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,45 @@ MSG-RS supports message compression out of the box, and it is very easy to use.
Compression is most useful in scenarios where you are sending large messages over the network.
It can also help reduce the amount of bandwidth used by your application.

In MSG, compression is handled by the socket type.
In MSG, compression is handled by the socket type:

Here is an example of setting up a pub/sub socket with compression:
- [Request/Response](#requestresponse)
- [Publish/Subscribe](#publishsubscribe)

---

## Request/Response

You can also find a complete example in [msg/examples/reqrep_compression.rs][examples].

```rust
use msg::{compression::GzipCompressor, ReqSocket, RepSocket, Tcp};

#[tokio::main]
async fn main() {
// Initialize the reply socket (server side) with a transport
let mut rep = RepSocket::new(Tcp::default())
// Enable Gzip compression (compression level 6).
.with_compressor(GzipCompressor::new(6));

rep.bind("0.0.0.0:4444").await.unwrap();

// Initialize the request socket (client side) with a transport
let mut req = ReqSocket::new(Tcp::default())
// Enable Gzip compression (compression level 6).
// The request and response sockets *don't have to*
// use the same compression algorithm or level.
.with_compressor(GzipCompressor::new(6));

req.connect("0.0.0.0:4444").await.unwrap();

// ...
}
```

## Publish/Subscribe

You can also find a complete example in [msg/examples/pubsub_compression.rs][examples].

```rust
use msg::{compression::GzipCompressor, PubSocket, SubSocket, Tcp};
Expand All @@ -21,13 +57,13 @@ async fn main() {
.with_compressor(GzipCompressor::new(6));

// Configure the subscribers with options
let mut sub1 = SubSocket::new(Tcp::default());
let mut sub_socket = SubSocket::new(Tcp::default());

// ...
}
```

By looking at this, you might be wondering: "how does the subscriber know that the
By looking at this example, you might be wondering: "how does the subscriber know that the
publisher is compressing messages, if the subscriber is not configured with Gzip compression?"

The answer is that in MSG, compression is defined by the publisher for each message that is sent.
Expand Down
58 changes: 51 additions & 7 deletions msg-socket/src/rep/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ use tracing::{debug, error, info, warn};

use crate::{rep::SocketState, AuthResult, Authenticator, PubError, RepOptions, Request};
use msg_transport::{PeerAddress, Transport};
use msg_wire::{auth, reqrep};
use msg_wire::{
auth,
compression::{try_decompress_payload, Compressor},
reqrep,
};

pub(crate) struct PeerState<T: AsyncRead + AsyncWrite> {
pending_requests: FuturesUnordered<PendingRequest>,
Expand All @@ -28,6 +32,7 @@ pub(crate) struct PeerState<T: AsyncRead + AsyncWrite> {
egress_queue: VecDeque<reqrep::Message>,
state: Arc<SocketState>,
should_flush: bool,
compressor: Option<Arc<dyn Compressor>>,
}

pub(crate) struct RepDriver<T: Transport> {
Expand All @@ -44,6 +49,9 @@ pub(crate) struct RepDriver<T: Transport> {
pub(crate) to_socket: mpsc::Sender<Request>,
/// Optional connection authenticator.
pub(crate) auth: Option<Arc<dyn Authenticator>>,
/// Optional message compressor. This is shared with the socket to keep
/// the API consistent with other socket types (e.g. `PubSocket`)
pub(crate) compressor: Option<Arc<dyn Compressor>>,
/// A set of pending incoming connections, represented by [`Transport::Accept`].
pub(super) conn_tasks: FuturesUnordered<T::Accept>,
/// A joinset of authentication tasks.
Expand All @@ -62,9 +70,21 @@ where
loop {
if let Poll::Ready(Some((peer, msg))) = this.peer_states.poll_next_unpin(cx) {
match msg {
Some(Ok(request)) => {
Some(Ok(mut request)) => {
debug!("Received request from peer {}", peer);
this.state.stats.increment_rx(request.msg().len());

let size = request.msg().len();

// decompress the payload
match try_decompress_payload(request.compression_type, request.msg) {
Ok(decompressed) => request.msg = decompressed,
Err(e) => {
error!("Failed to decompress message: {:?}", e);
continue;
}
}

this.state.stats.increment_rx(size);
let _ = this.to_socket.try_send(request);
}
Some(Err(e)) => {
Expand Down Expand Up @@ -94,6 +114,7 @@ where
egress_queue: VecDeque::with_capacity(128),
state: Arc::clone(&this.state),
should_flush: false,
compressor: this.compressor.clone(),
}),
);
}
Expand Down Expand Up @@ -216,6 +237,7 @@ where
egress_queue: VecDeque::with_capacity(128),
state: Arc::clone(&self.state),
should_flush: false,
compressor: self.compressor.clone(),
}),
);
}
Expand Down Expand Up @@ -262,10 +284,32 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Stream for PeerState<T> {
}

// Then we check for completed requests, and push them onto the egress queue.
if let Poll::Ready(Some(Some((id, payload)))) =
if let Poll::Ready(Some(Some((id, mut payload)))) =
this.pending_requests.poll_next_unpin(cx)
{
let msg = reqrep::Message::new(id, payload);
let mut compression_type = 0;
let len_before = payload.len();
if let Some(ref compressor) = this.compressor {
match compressor.compress(&payload) {
Ok(compressed) => {
payload = compressed;
compression_type = compressor.compression_type() as u8;
}
Err(e) => {
tracing::error!("Failed to compress message: {:?}", e);
continue;
}
}

tracing::debug!(
"Compressed message {} from {} to {} bytes",
id,
len_before,
payload.len()
)
}

let msg = reqrep::Message::new(id, compression_type, payload);
this.egress_queue.push_back(msg);

continue;
Expand All @@ -276,19 +320,19 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Stream for PeerState<T> {
Poll::Ready(Some(result)) => {
tracing::trace!("Received message from peer {}: {:?}", this.addr, result);
let msg = result?;
let msg_id = msg.id();

let (tx, rx) = oneshot::channel();

// Add the pending request to the list
this.pending_requests.push(PendingRequest {
msg_id,
msg_id: msg.id(),
response: rx,
});

let request = Request {
source: this.addr,
response: tx,
compression_type: msg.header().compression_type(),
msg: msg.into_payload(),
};

Expand Down
52 changes: 49 additions & 3 deletions msg-socket/src/rep/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,19 @@ pub enum PubError {
Transport(#[from] Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Default)]
/// Configuration options for a reply socket.
pub struct RepOptions {
/// The maximum number of concurrent clients.
max_clients: Option<usize>,
/// The minimum payload size for compression. Payloads smaller than this
/// value are meant to be sent uncompressed (presumably measured in bytes).
// NOTE(review): the visible reply-driver code compresses every response
// whenever a compressor is configured and does not appear to read this
// threshold — confirm that it is actually enforced somewhere.
min_compress_size: usize,
}

impl Default for RepOptions {
    /// Builds the default reply-socket options: no limit on the number of
    /// clients and a compression threshold of 8192.
    fn default() -> Self {
        let max_clients = None;
        let min_compress_size = 8192;

        Self {
            max_clients,
            min_compress_size,
        }
    }
}

impl RepOptions {
Expand All @@ -37,6 +46,13 @@ impl RepOptions {
self.max_clients = Some(max_clients);
self
}

/// Sets the minimum payload size for compression.
///
/// Payloads smaller than `min_compress_size` will not be compressed.
/// Consumes the options and returns them with the new threshold applied,
/// so calls can be chained builder-style.
pub fn min_compress_size(self, min_compress_size: usize) -> Self {
    let mut options = self;
    options.min_compress_size = min_compress_size;
    options
}
}

/// The request socket state, shared between the backend task and the socket.
Expand All @@ -45,11 +61,15 @@ pub(crate) struct SocketState {
pub(crate) stats: SocketStats,
}

/// A request received by the socket. It contains the source address, the message,
/// and a oneshot channel to respond to the request.
/// A request received by the socket.
pub struct Request {
/// The source address of the request.
source: SocketAddr,
/// The compression type used for the request payload, as read from the
/// header of the incoming wire message. The driver passes it to
/// `try_decompress_payload` to decompress `msg` before the request is
/// handed to the socket.
compression_type: u8,
/// The oneshot channel to respond to the request.
response: oneshot::Sender<Bytes>,
/// The message payload.
msg: Bytes,
}

Expand Down Expand Up @@ -78,6 +98,7 @@ mod tests {

use futures::StreamExt;
use msg_transport::tcp::Tcp;
use msg_wire::compression::{GzipCompressor, SnappyCompressor};
use rand::Rng;

use crate::{req::ReqSocket, Authenticator, ReqOptions};
Expand Down Expand Up @@ -226,4 +247,29 @@ mod tests {
tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(rep.stats().active_clients(), 1);
}

// End-to-end request/reply round trip with a compressor configured on both
// sockets. The reply socket uses Snappy while the request socket uses Gzip,
// so the two sides are deliberately set up with different algorithms.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_basic_reqrep_with_compression() {
// `min_compress_size(0)` drops the size threshold to zero; presumably this
// makes the tiny payloads used below eligible for compression — confirm
// against the driver implementation.
let mut rep =
RepSocket::with_options(Tcp::default(), RepOptions::default().min_compress_size(0))
.with_compressor(SnappyCompressor);

rep.bind("0.0.0.0:4445").await.unwrap();

let mut req =
ReqSocket::with_options(Tcp::default(), ReqOptions::default().min_compress_size(0))
.with_compressor(GzipCompressor::new(6));

req.connect("0.0.0.0:4445").await.unwrap();

// Server side: take the first incoming request, check its payload and
// answer it. The join handle is dropped, so the task runs detached.
tokio::spawn(async move {
let req = rep.next().await.unwrap();

assert_eq!(req.msg(), &Bytes::from("hello"));
req.respond(Bytes::from("world")).unwrap();
});

// Client side: send the request and check the reply payload.
let res: Bytes = req.request(Bytes::from("hello")).await.unwrap();
assert_eq!(res, Bytes::from("world"));
}
}
11 changes: 11 additions & 0 deletions msg-socket/src/rep/socket.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use futures::{stream::FuturesUnordered, Stream};
use msg_wire::compression::Compressor;
use std::{
io,
net::SocketAddr,
Expand Down Expand Up @@ -36,6 +37,8 @@ pub struct RepSocket<T: Transport> {
auth: Option<Arc<dyn Authenticator>>,
/// The local address this socket is bound to.
local_addr: Option<SocketAddr>,
/// Optional message compressor.
compressor: Option<Arc<dyn Compressor>>,
}

impl<T> RepSocket<T>
Expand All @@ -56,6 +59,7 @@ where
options: Arc::new(options),
state: Arc::new(SocketState::default()),
auth: None,
compressor: None,
}
}

Expand All @@ -65,6 +69,12 @@ where
self
}

/// Sets the message compressor for this socket.
///
/// The compressor is wrapped in an `Arc` and stored as a trait object.
/// Consumes the socket and returns it so calls can be chained
/// builder-style.
pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
    let shared: Arc<dyn Compressor> = Arc::new(compressor);
    self.compressor = Some(shared);
    self
}

/// Binds the socket to the given address. This spawns the socket driver task.
pub async fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> Result<(), PubError> {
let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE);
Expand Down Expand Up @@ -103,6 +113,7 @@ where
auth: self.auth.take(),
auth_tasks: JoinSet::new(),
conn_tasks: FuturesUnordered::new(),
compressor: self.compressor.take(),
};

tokio::spawn(backend);
Expand Down
Loading

0 comments on commit 7d7f6dc

Please sign in to comment.