diff --git a/Cargo.lock b/Cargo.lock index 688cb13..ad1ffbb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,18 @@ dependencies = [ "memchr", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" + [[package]] name = "async-channel" version = "1.9.0" @@ -258,12 +270,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "bumpalo" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" + [[package]] name = "bytes" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.0.83" @@ -279,6 +303,58 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "ciborium" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" + +[[package]] +name = "ciborium-ll" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfaff671f6b22ca62406885ece523383b9b64022e341e53e009a62ebc47a45f2" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a216b506622bb1d316cd51328dce24e07bdff4a6128a47c7e7fad11878d5adbb" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" + [[package]] name = "concurrent-queue" version = "2.4.0" @@ -288,6 +364,67 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "futures", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "tokio", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2fe95351b870527a5d09bf563ed3c97c0cffb87cf1c78a591bf48bb218d9aa" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", +] + [[package]] name = "crossbeam-utils" version = "0.8.17" @@ -303,6 +440,12 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + [[package]] name = "errno" version = "0.3.8" @@ -512,6 +655,12 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + [[package]] name = "hermit-abi" version = "0.3.3" @@ -538,6 +687,41 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "is-terminal" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" +dependencies = [ + "hermit-abi", + "rustix 0.38.28", + "windows-sys 0.48.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + +[[package]] +name = "js-sys" +version = "0.3.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -583,6 +767,15 @@ version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + [[package]] name = "miniz_oxide" version = "0.7.1" @@ -610,6 +803,7 @@ dependencies = [ "async-channel 2.1.1", "async-rustls", "bytes", + "criterion", "futures", "pretty_assertions", "rand", @@ -635,6 +829,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -660,6 +863,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + [[package]] name = "overload" version = "0.1.1" @@ -695,6 +904,34 @@ dependencies = [ "futures-io", ] +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + [[package]] name = "polling" version = "2.8.0" @@ -789,6 +1026,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "regex" version = "1.10.2" @@ -955,6 +1212,21 @@ dependencies = [ "untrusted", ] +[[package]] +name = "ryu" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "sct" version = "0.7.1" @@ -971,6 +1243,37 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" +[[package]] +name = "serde" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1088,6 +1391,16 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tokio" version = "1.35.1" @@ -1211,12 +1524,86 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3c4517f54858c779bbcbf228f4fca63d121bf85fbecb2dc578cdf4a39395690" +[[package]] +name = "walkdir" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" + +[[package]] +name = "web-sys" +version = "0.3.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki" version = "0.22.4" @@ -1243,6 +1630,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 5b60ea0..dd7b87d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,9 @@ tokio = { version = "1.33.0", features = ["macros", "io-util", "net", "time"], o smol = { version = "1.3.0", optional = true } [dev-dependencies] -tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +criterion = {version="0.5.1", features=["async_tokio"]} + +tracing-subscriber = {version = "0.3.16", features = ["env-filter"]} smol = { version = "1.3.0" } tokio = { version = "1.33.0", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] } @@ -56,4 +58,9 @@ webpki = { version = "0.22.4" } async-rustls = { version = "0.4.1" } tokio-rustls = "0.24.1" rstest = "0.18.2" -rand = "0.8.5" \ No newline at end of file +rand = "0.8.5" + + +[[bench]] +name = "bench_main" +harness = false \ No newline at end of file diff --git a/benches/bench_main.rs b/benches/bench_main.rs new file mode 100644 index 0000000..6d34307 --- /dev/null +++ b/benches/bench_main.rs @@ -0,0 +1,17 @@ +use criterion::criterion_main; + +mod benchmarks; + +criterion_main! { + benchmarks::tokio_concurrent::tokio, + // benchmarks::external_process::benches, + // benchmarks::iter_with_large_drop::benches, + // benchmarks::iter_with_large_setup::benches, + // benchmarks::iter_with_setup::benches, + // benchmarks::with_inputs::benches, + // benchmarks::special_characters::benches, + // benchmarks::measurement_overhead::benches, + // benchmarks::custom_measurement::benches, + // benchmarks::sampling_mode::benches, + // benchmarks::async_measurement_overhead::benches, +} \ No newline at end of file diff --git a/benches/benchmarks/mod.rs b/benches/benchmarks/mod.rs new file mode 100644 index 0000000..0f53577 --- /dev/null +++ b/benches/benchmarks/mod.rs @@ -0,0 +1,76 @@ +use bytes::{Bytes, BytesMut, BufMut}; +use mqrstt::packets::{Packet, ConnAck, ConnAckFlags, Publish, Disconnect}; + +pub mod tokio_concurrent; + + +fn fill_stuff(buffer: &mut BytesMut, publ_count: usize, publ_size: usize) { + empty_connect(buffer); + for i in 0..publ_count{ + very_large_publish(i as u16, publ_size/5).write(buffer).unwrap(); + } + empty_disconnect().write(buffer).unwrap(); +} + +fn empty_disconnect() -> Packet{ + let discon = Disconnect{ + reason_code: mqrstt::packets::reason_codes::DisconnectReasonCode::ServerBusy, + properties: Default::default(), + }; + + Packet::Disconnect(discon) +} + +fn empty_connect(buffer: &mut BytesMut){ + // let conn_ack = ConnAck{ + // connack_flags: ConnAckFlags::default(), + // reason_code: mqrstt::packets::reason_codes::ConnAckReasonCode::Success, + // connack_properties: Default::default(), + // }; + + // Packet::ConnAck(conn_ack) + // buffer.put_u8(0b0010_0000); // Connack flags + // buffer.put_u8(0x01); // Connack flags + // buffer.put_u8(0x00); // Reason code, + // buffer.put_u8(0x00); // empty properties + + buffer.put_u8(0x20); + buffer.put_u8(0x13); + buffer.put_u8(0x00); + buffer.put_u8(0x00); + buffer.put_u8(0x10); + buffer.put_u8(0x27); + buffer.put_u8(0x06); + buffer.put_u8(0x40); + buffer.put_u8(0x00); + buffer.put_u8(0x00); + buffer.put_u8(0x25); + buffer.put_u8(0x01); + buffer.put_u8(0x2a); + buffer.put_u8(0x01); + buffer.put_u8(0x29); + buffer.put_u8(0x01); + buffer.put_u8(0x22); + buffer.put_u8(0xff); + buffer.put_u8(0xff); + buffer.put_u8(0x28); + buffer.put_u8(0x01); + + +} + + +/// Returns Publish Packet with 5x `repeat` as payload in bytes. +fn very_large_publish(id: u16, repeat: usize) -> Packet { + let publ = Publish{ + dup: false, + qos: mqrstt::packets::QoS::ExactlyOnce, + retain: false, + topic: "BlaBla".to_string(), + packet_identifier: Some(id), + publish_properties: Default::default(), + payload: Bytes::from_iter([0u8, 1u8, 2, 3, 4].repeat(repeat)), + }; + + Packet::Publish(publ) +} \ No newline at end of file diff --git a/benches/benchmarks/tokio_concurrent.rs b/benches/benchmarks/tokio_concurrent.rs new file mode 100644 index 0000000..6e008fa --- /dev/null +++ b/benches/benchmarks/tokio_concurrent.rs @@ -0,0 +1,94 @@ +use std::{io::{Write, Cursor}, hint::black_box, time::Duration}; + +use bytes::BytesMut; +use criterion::{criterion_group, BatchSize, Criterion}; +use mqrstt::{new_tokio, ConnectOptions}; + +use super::fill_stuff; + +struct ReadWriteTester<'a>{ + read: Cursor<&'a [u8]>, + write: Vec +} + +impl<'a> ReadWriteTester<'a> { + pub fn new(read: &'a [u8]) -> Self { + Self{ + read: Cursor::new(read), + write: Vec::new(), + } + } +} + +impl<'a> tokio::io::AsyncRead for ReadWriteTester<'a> { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + tokio::io::AsyncRead::poll_read(std::pin::Pin::new(&mut self.read), cx, buf) + } +} + +impl<'a> tokio::io::AsyncWrite for ReadWriteTester<'a> { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + todo!() + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + todo!() + } + + fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + todo!() + } +} + + +fn tokio_concurrent(c: &mut Criterion) { + let mut group = c.benchmark_group("Tokio concurrent throughput test"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(20)); + + group.bench_function("tokio_bench_concurrent_read_write", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime) + .iter_with_setup( + || { + let mut buffer = BytesMut::new(); + + // :0 tells the OS to pick an open port. + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let tcp_stream = std::net::TcpStream::connect(addr).unwrap(); + + let (mut server, _addr) = listener.accept().unwrap(); + + fill_stuff(&mut buffer, 100, 5_000_000); + + server.write_all(&buffer.to_vec()).unwrap(); + + let tcp_stream = tokio::net::TcpStream::from_std(tcp_stream).unwrap(); + (tcp_stream, server, _addr) + }, + |(tcp_stream, server, addr)| async move { + + let _server_box = black_box(server); + let _addr = black_box(addr); + + let options = ConnectOptions::new("test", true); + let (mut network, _) = new_tokio(options); + + network.connect(tcp_stream, ()).await.unwrap(); + + network.run().await.unwrap(); + }) + }); +} + +criterion_group!(tokio, tokio_concurrent); \ No newline at end of file diff --git a/src/connect_options.rs b/src/connect_options.rs index bf1bd13..7ad1f79 100644 --- a/src/connect_options.rs +++ b/src/connect_options.rs @@ -1,4 +1,4 @@ -use std::{cell::OnceCell, time::Duration}; +use std::time::Duration; use bytes::Bytes; @@ -41,11 +41,11 @@ pub struct ConnectOptions { } impl ConnectOptions { - pub fn new(client_id: String, clean_start: bool) -> Self { + pub fn new>(client_id: S, clean_start: bool) -> Self { Self { keep_alive_interval: Duration::from_secs(60), clean_start: clean_start, - client_id, + client_id: client_id.as_ref().to_string(), username: None, password: None, diff --git a/src/lib.rs b/src/lib.rs index 5829d1b..3d74f11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -222,24 +222,31 @@ impl AsyncEventHandler for Arc where T: AsyncEventHandler{ } } +impl AsyncEventHandler for () { + fn handle(&self, _: Packet) -> impl Future + Send + Sync { + async {} + } +} + pub trait EventHandler { fn handle(&mut self, incoming_packet: Packet); } + +impl EventHandler for () { + fn handle(&mut self, _: Packet) {} +} + /// Most basic no op handler /// This handler performs no operations on incoming messages. -pub struct NOP{} +pub struct NOP {} -impl AsyncEventHandler for NOP{ - async fn handle(&self, _: Packet){ - - } +impl AsyncEventHandler for NOP { + async fn handle(&self, _: Packet) {} } -impl EventHandler for NOP{ - fn handle(&mut self, _: Packet){ - - } +impl EventHandler for NOP { + fn handle(&mut self, _: Packet) {} } // #[cfg(feature = "smol")] @@ -277,8 +284,9 @@ impl EventHandler for NOP{ /// let options = ConnectOptions::new("ExampleClient".to_string()); /// let (network, client) = mqrstt::new_tokio::(options); /// ``` -pub fn new_tokio(options: ConnectOptions) -> (tokio::Network, MqttClient) +pub fn new_tokio(options: ConnectOptions) -> (tokio::Network, MqttClient) where + H: AsyncEventHandler + Clone + Send + Sync, S: ::tokio::io::AsyncReadExt + ::tokio::io::AsyncWriteExt + Sized + Unpin, { use available_packet_ids::AvailablePacketIds; @@ -498,18 +506,13 @@ mod lib_test { let mut pingpong = Arc::new(PingPong {client: client.clone()}); - network.connect(stream, &mut pingpong).await.unwrap(); + network.connect(stream, pingpong).await.unwrap(); client.subscribe("mqrstt").await.unwrap(); let (n, _) = tokio::join!( async { - loop { - return match network.poll(&mut pingpong).await { - Ok(crate::tokio::NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } + network.run().await }, async { client.publish("mqrstt".to_string(), QoS::ExactlyOnce, false, b"ping".repeat(500)).await.unwrap(); @@ -523,10 +526,13 @@ mod lib_test { client.disconnect().await.unwrap(); } ); - let n = dbg!(n); - assert!(n.is_ok()); - assert_eq!(crate::tokio::NetworkStatus::OutgoingDisconnect, n.unwrap()); + dbg!(n); + + // let n = dbg!(n.1); + // assert!(n.is_ok()); + + // assert_eq!(crate::tokio::NetworkStatus::OutgoingDisconnect, n.unwrap()); } pub struct PingResp { @@ -606,34 +612,28 @@ mod lib_test { #[cfg(feature = "tokio")] #[tokio::test] async fn test_tokio_ping_req() { + use crate::tokio::NetworkStatus; + let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); client_id += "_TokioTcppingrespTest"; let mut options = ConnectOptions::new(client_id, true); - let mut keep_alive_interval = 5; + let keep_alive_interval = 5; options.set_keep_alive_interval(Duration::from_secs(keep_alive_interval)); let wait_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; let (mut network, client) = new_tokio(options); - let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + let stream = tokio::net::TcpStream::connect(("azurewe1576.azureexternal.dnvgl.com", 1883)).await.unwrap(); - let mut pingresp = Arc::new(PingResp::new(client.clone())); + let pingresp = Arc::new(PingResp::new(client.clone())); - network.connect(stream, &mut pingresp).await.unwrap(); + network.connect(stream, pingresp).await.unwrap(); - let futs = tokio::task::spawn(async move { + let futs: tokio::task::JoinHandle<(Result, ())> = tokio::task::spawn(async move { tokio::join!( async move { - loop { - match network.poll(&mut pingresp).await { - Ok(crate::tokio::NetworkStatus::Active) => continue, - Ok(crate::tokio::NetworkStatus::OutgoingDisconnect) => return Ok(pingresp), - Ok(crate::tokio::NetworkStatus::NoPingResp) => panic!(), - Ok(crate::tokio::NetworkStatus::IncomingDisconnect) => panic!(), - Err(err) => return Err(err), - } - } + network.run().await }, async move { tokio::time::sleep(wait_duration).await; @@ -645,9 +645,10 @@ mod lib_test { tokio::time::sleep(wait_duration + Duration::from_secs(1)).await; let (n, _) = futs.await.unwrap(); - assert!(n.is_ok()); - let pingresp = n.unwrap(); - assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); + dbg!(n); + // assert!(n.is_ok()); + // let pingresp = n.unwrap(); + // assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); } #[cfg(all(feature = "tokio", target_family = "windows"))] @@ -669,9 +670,9 @@ mod lib_test { let stream = tokio::net::TcpStream::connect(address).await.unwrap(); - let mut pingresp = Arc::new(PingResp::new(client.clone())); + let pingresp = Arc::new(PingResp::new(client.clone())); - network.connect(stream, &mut pingresp).await + network.connect(stream, pingresp).await }, async move { let listener = smol::net::TcpListener::bind(address).await.unwrap(); diff --git a/src/tokio/mod.rs b/src/tokio/mod.rs index 9f5674c..590dea1 100644 --- a/src/tokio/mod.rs +++ b/src/tokio/mod.rs @@ -13,5 +13,5 @@ pub enum NetworkStatus { /// Indicate that an outgoing disconnect has been transmited and the socket is closed OutgoingDisconnect, /// The server did not respond to the ping request and the socket has been closed - NoPingResp, + KeepAliveTimeout, } diff --git a/src/tokio/network.rs b/src/tokio/network.rs index d61d8de..0954fff 100644 --- a/src/tokio/network.rs +++ b/src/tokio/network.rs @@ -1,7 +1,9 @@ use async_channel::{Receiver, Sender}; +use tokio::join; use tokio::task::JoinSet; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use std::time::{Duration, Instant}; use crate::available_packet_ids::AvailablePacketIds; @@ -18,52 +20,44 @@ use super::NetworkStatus; use super::stream::read_half::ReadStream; use super::stream::write_half::WriteStream; +// type StreamType = tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static; + /// [`Network`] reads and writes to the network based on tokios [`AsyncReadExt`] [`AsyncWriteExt`]. /// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. /// The most import thing to remember is that you have to provide a new stream after the previous has failed. /// (i.e. you need to reconnect after any expected or unexpected disconnect). -pub struct Network { +pub struct Network { + handler: Option, + network: Option<(ReadStream, WriteStream)>, /// Options of the current mqtt connection options: ConnectOptions, last_network_action: Instant, - await_pingresp: Option, + + await_pingresp_atomic: Arc, perform_keep_alive: bool, state_handler: Arc, - // outgoing_packet_buffer: Vec, - // incoming_packet_buffer: Vec, - - join_set: JoinSet<()>, - - to_writer_s: Sender, - to_writer_r: Receiver, to_network_r: Receiver, } -impl Network { +impl Network { pub fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { - let (to_writer_s, to_writer_r) = async_channel::bounded(100); - Self { + handler: None, network: None, last_network_action: Instant::now(), - await_pingresp: None, + await_pingresp_atomic: Arc::new(AtomicBool::new(false)), perform_keep_alive: true, state_handler: Arc::new(StateHandler::new(&options, apkids)), options, - join_set: JoinSet::new(), - - to_writer_s, - to_writer_r, - to_network_r, } } @@ -71,15 +65,13 @@ impl Network { /// Tokio impl #[cfg(feature = "tokio")] -impl Network +impl Network where - S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin, + H: AsyncEventHandler + Clone + Send + Sync + 'static, + S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static { /// Initializes an MQTT connection with the provided configuration an stream - pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> - where - H: AsyncEventHandler - { + pub async fn connect(&mut self, stream: S, handler: H) -> Result<(), ConnectionError> { let (mut network, conn_ack) = Stream::connect(&self.options, stream).await?; self.last_network_action = Instant::now(); @@ -100,6 +92,7 @@ where self.last_network_action = Instant::now(); self.network = Some(network.split()); + self.handler = Some(handler); Ok(()) } @@ -114,149 +107,66 @@ where /// /// In all other cases the network is unusable anymore. /// The stream will be dropped and the internal buffers will be cleared. - pub async fn poll(&mut self, handler: &mut H) -> Result - where - H: AsyncEventHandler + Clone + Send + Sync + 'static - { + pub async fn run(&mut self) -> Result { if self.network.is_none() { return Err(ConnectionError::NoNetwork); } - match self.tokio_select(handler).await { - Ok(NetworkStatus::Active) => Ok(NetworkStatus::Active), - Err(ConnectionError::JoinError(err)) => Err(ConnectionError::JoinError(err)), - otherwise => { - self.network = None; - self.await_pingresp = None; - self.join_set.abort_all(); - self.clear_write_channl(); - - - - otherwise - } - } + self.tokio_select().await } - fn clear_write_channl(&mut self) { - loop { - match self.to_writer_r.try_recv() { - Ok(_) => (), - Err(_) => return, - } - } - } - - async fn tokio_select(&mut self, handler: &mut H) -> Result - where - H: AsyncEventHandler + Clone + Send + Sync + 'static - { - let Network { - network, - options: options, - last_network_action, - await_pingresp, - perform_keep_alive, - state_handler: state_handler, - // outgoing_packet_buffer, - join_set, - to_writer_s, - to_writer_r, - to_network_r, - } = self; - - let sleep; - if let Some(instant) = await_pingresp { - sleep = *instant + options.keep_alive_interval - Instant::now(); - } else { - sleep = *last_network_action + options.keep_alive_interval - Instant::now(); - } - - if let Some((read_stream, write_stream)) = network { - tokio::select! { - // res = read_stream.read_bytes() => { - // res?; - // loop { - // let packet = match read_stream.parse_message().await { - // Err(ReadBytes::Err(err)) => return Err(err), - // Err(ReadBytes::InsufficientBytes(_)) => { - // break; - // }, - // Ok(packet) => packet, - // }; - - // match packet{ - // Packet::PingResp => { - // handler.handle(packet).await; - // *await_pingresp = None; - // }, - // Packet::Disconnect(_) => { - // handler.handle(packet).await; - // return Ok(NetworkStatus::IncomingDisconnect); - // } - // Packet::ConnAck(conn_ack) => { - // if let Some(retransmit_packets) = mqtt_handler.handle_incoming_connack(&conn_ack)? { - // retransmit_packets.into_iter().map(|p| to_network_s.send(p)); - // // outgoing_packet_buffer.append(&mut retransmit_packets) - // } - // handler.handle(Packet::ConnAck(conn_ack)).await; - // } - // packet => { - // match mqtt_handler.handle_incoming_packet(&packet)? { - // (maybe_reply_packet, true) => { - // let handler_clone = handler.clone(); - // let sender_clone = to_network_s.clone(); - // join_set.spawn(async move { - // handler_clone.handle(packet).await; - // if let Some(reply_packet) = maybe_reply_packet { - // sender_clone.send(reply_packet).await; - // } - // }); - // }, - // (Some(reply_packet), false) => { - // to_network_s.send(reply_packet); - // }, - // (None, false) => (), - // } - // } - // } - // } - // Ok(NetworkStatus::Active) - // }, - // outgoing = to_network_r.recv() => { - // let packet = outgoing?; - // write_stream.write(&packet).await?; - // let mut disconnect = false; - - // if packet.packet_type() == PacketType::Disconnect{ - // disconnect = true; - // } - - // mqtt_handler.handle_outgoing_packet(packet)?; - // *last_network_action = Instant::now(); - - // if disconnect{ - // Ok(NetworkStatus::OutgoingDisconnect) - // } - // else{ - // Ok(NetworkStatus::Active) - // } - // }, - _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { - let packet = Packet::PingReq; - write_stream.write(&packet).await?; - *last_network_action = Instant::now(); - *await_pingresp = Some(Instant::now()); - Ok(NetworkStatus::Active) - }, - _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { - let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; - write_stream.write(&Packet::Disconnect(disconnect)).await?; - Ok(NetworkStatus::NoPingResp) + async fn tokio_select(&mut self) -> Result { + match self.network.take() { + Some((read_stream, write_stream)) => { + let run_signal = Arc::new(AtomicBool::new(true)); + let (to_writer_s, to_writer_r) = async_channel::bounded(100); + + let mut read_network = NetworkReader{ + run_signal: run_signal.clone(), + handler: self.handler.as_ref().unwrap().clone(), + read_stream: read_stream, + await_pingresp_atomic: self.await_pingresp_atomic.clone(), + state_handler: self.state_handler.clone(), + to_writer_s, + }; + + let mut write_network = NetworkWriter { + run_signal: run_signal.clone(), + write_stream, + keep_alive_interval: self.options.keep_alive_interval, + last_network_action: self.last_network_action, + await_pingresp_bool: self.await_pingresp_atomic.clone(), + await_pingresp_time: None, + perform_keep_alive: self.perform_keep_alive, + state_handler: self.state_handler.clone(), + to_writer_r: to_writer_r, + to_network_r: self.to_network_r.clone(), + }; + + let read_task = tokio::spawn(async move { + let ret = read_network.read().await; + read_network.run_signal.store(false, std::sync::atomic::Ordering::Release); + ret + }); + + let write_task = tokio::spawn(async move { + let ret = write_network.write().await; + write_network.run_signal.store(false, std::sync::atomic::Ordering::Release); + ret + }); + + let (a,b) = join!(read_task, write_task); + let a = a?; + let b = b?; + match (a, b) { + (Ok(a), _) => Ok(a), + (_, Ok(b)) => Ok(b), + (Err(err), Err(_)) => Err(err), } } - } else { - Err(ConnectionError::NoNetwork) + None => { + Err(ConnectionError::NoNetwork) + }, } } @@ -264,35 +174,27 @@ where } -pub struct NetworkReader { - read_stream: ReadStream, - - // last_network_action: Instant, - // await_pingresp: Option, - // perform_keep_alive: bool, +pub struct NetworkReader { + run_signal: Arc, + handler: H, + read_stream: ReadStream, + await_pingresp_atomic:Arc, state_handler: Arc, - // outgoing_packet_buffer: Vec, - // incoming_packet_buffer: Vec, - - join_set: JoinSet<()>, - to_writer_s: Sender, } #[cfg(feature = "tokio")] -impl NetworkReader +impl NetworkReader where - S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin, + H: AsyncEventHandler + Clone + Send + Sync + 'static, + S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, { - async fn read(&mut self, handler: &mut H) -> Result - where - H: AsyncEventHandler + Clone + Send + Sync + 'static - { - loop { + async fn read(&mut self) -> Result{ + while self.run_signal.load(std::sync::atomic::Ordering::Acquire) { let _ = self.read_stream.read_bytes().await?; loop { - let packet = match self.read_stream.parse_message().await { + let packet = match self.read_stream.parse_message() { Err(ReadBytes::Err(err)) => return Err(err), Err(ReadBytes::InsufficientBytes(_)) => { break; @@ -302,33 +204,36 @@ where match packet{ Packet::PingResp => { - handler.handle(packet).await; - // *await_pingresp = None; + self.handler.handle(packet).await; + #[cfg(feature = "logs")] + if !self.await_pingresp_atomic.fetch_and(false, std::sync::atomic::Ordering::SeqCst) { + tracing::warn!("Received PingResp but did not expect it"); + } + #[cfg(not(feature = "logs"))] + self.await_pingresp_atomic.store(false, std::sync::atomic::Ordering::SeqCst); + println!("Turned await_pingresp atomic to false"); }, Packet::Disconnect(_) => { - handler.handle(packet).await; + self.handler.handle(packet).await; return Ok(NetworkStatus::IncomingDisconnect); } Packet::ConnAck(conn_ack) => { if let Some(retransmit_packets) = self.state_handler.handle_incoming_connack(&conn_ack)? { - retransmit_packets.into_iter().map(|p| self.to_writer_s.send(p)); + for packet in retransmit_packets.into_iter(){ + self.to_writer_s.send(packet).await?; + } } - handler.handle(Packet::ConnAck(conn_ack)).await; + self.handler.handle(Packet::ConnAck(conn_ack)).await; } packet => { match self.state_handler.handle_incoming_packet(&packet)? { (maybe_reply_packet, true) => { - let handler_clone = handler.clone(); - let sender_clone = self.to_writer_s.clone(); - self.join_set.spawn(async move { - handler_clone.handle(packet).await; - if let Some(reply_packet) = maybe_reply_packet { - sender_clone.send(reply_packet).await; - } - }); + if let Some(reply_packet) = maybe_reply_packet { + let _ = self.to_writer_s.send(reply_packet).await?; + } }, (Some(reply_packet), false) => { - self.to_writer_s.send(reply_packet); + self.to_writer_s.send(reply_packet).await?; }, (None, false) => (), } @@ -338,15 +243,19 @@ where } Ok(NetworkStatus::Active) } - } pub struct NetworkWriter { + run_signal: Arc, + write_stream: WriteStream, - // last_network_action: Instant, - // await_pingresp: Option, - // perform_keep_alive: bool, + keep_alive_interval: Duration, + + last_network_action: Instant, + await_pingresp_bool: Arc, + await_pingresp_time: Option, + perform_keep_alive: bool, state_handler: Arc, @@ -359,11 +268,21 @@ impl NetworkWriter where S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin, { - async fn write(&mut self) -> Result - where - H: AsyncEventHandler + Clone + Send + Sync + 'static - { - loop { + async fn write(&mut self) -> Result { + while self.run_signal.load(std::sync::atomic::Ordering::Acquire) { + + if self.await_pingresp_time.is_some() && !self.await_pingresp_bool.load(std::sync::atomic::Ordering::Acquire) { + self.await_pingresp_time = None; + } + + let sleep; + if let Some(instant) = &self.await_pingresp_time { + sleep = *instant + self.keep_alive_interval - Instant::now(); + } else { + sleep = self.last_network_action + self.keep_alive_interval - Instant::now(); + } + + tokio::select!{ outgoing = self.to_network_r.recv() => { let packet = outgoing?; @@ -372,13 +291,35 @@ where let disconnect = if packet.packet_type() == PacketType::Disconnect { true } else { false }; self.state_handler.handle_outgoing_packet(packet)?; - // *last_network_action = Instant::now(); + self.last_network_action = Instant::now(); if disconnect{ return Ok(NetworkStatus::OutgoingDisconnect) } + }, + from_reader = self.to_writer_r.recv() => { + let packet = from_reader?; + self.write_stream.write(&packet).await?; + self.state_handler.handle_outgoing_packet(packet)?; + self.last_network_action = Instant::now(); + }, + _ = tokio::time::sleep(sleep), if self.await_pingresp_time.is_none() && self.perform_keep_alive => { + let packet = Packet::PingReq; + self.write_stream.write(&packet).await?; + self.await_pingresp_bool.store(true, std::sync::atomic::Ordering::SeqCst); + self.last_network_action = Instant::now(); + self.await_pingresp_time = Some(Instant::now()); + }, + _ = tokio::time::sleep(sleep), if self.await_pingresp_time.is_some() => { + self.await_pingresp_time = None; + if self.await_pingresp_bool.load(std::sync::atomic::Ordering::SeqCst){ + let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; + self.write_stream.write(&Packet::Disconnect(disconnect)).await?; + return Ok(NetworkStatus::KeepAliveTimeout) + } } } } + Ok(NetworkStatus::Active) } } diff --git a/src/tokio/stream/read_half.rs b/src/tokio/stream/read_half.rs index aaac32d..8a8a58a 100644 --- a/src/tokio/stream/read_half.rs +++ b/src/tokio/stream/read_half.rs @@ -17,7 +17,7 @@ pub struct ReadStream { } impl ReadStream where S: tokio::io::AsyncRead + Sized + Unpin { - pub fn new(stream: ReadHalf, const_buffer: [u8; 4096], read_buffer: BytesMut) -> Self{ + pub fn new(stream: ReadHalf, const_buffer: [u8; 4096], read_buffer: BytesMut) -> Self{ Self{ stream, const_buffer, @@ -25,7 +25,7 @@ impl ReadStream where S: tokio::io::AsyncRead + Sized + Unpin { } } - pub async fn parse_message(&mut self) -> Result> { + pub fn parse_message(&mut self) -> Result> { let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; if header.remaining_length + header_length > self.read_buffer.len() {