From 98fff6d0bdd8ccd6a4fad227cd8a2b94cc4556b5 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Mon, 1 Jan 2024 14:24:20 +0100 Subject: [PATCH] 0.3.0-alpha.1 --- Cargo.lock | 295 ++------- Cargo.toml | 28 +- benches/bench_main.rs | 13 +- benches/benchmarks/mod.rs | 87 ++- benches/benchmarks/tokio.rs | 353 +++++++++++ benches/benchmarks/tokio_concurrent.rs | 89 --- src/client.rs | 80 ++- src/error.rs | 2 +- src/event_handlers.rs | 192 ++++++ src/lib.rs | 822 ++++++++++--------------- src/mqtt_handler.rs | 9 +- src/packets/suback.rs | 1 - src/packets/subscribe.rs | 23 + src/packets/unsubscribe.rs | 5 + src/smol/network.rs | 92 ++- src/smol/stream.rs | 31 +- src/tokio/mod.rs | 99 ++- src/tokio/network.rs | 252 +++++--- src/tokio/stream/read_half.rs | 36 -- src/tokio/stream/write_half.rs | 23 - 20 files changed, 1464 insertions(+), 1068 deletions(-) create mode 100644 benches/benchmarks/tokio.rs delete mode 100644 benches/benchmarks/tokio_concurrent.rs create mode 100644 src/event_handlers.rs diff --git a/Cargo.lock b/Cargo.lock index ad1ffbb..05f82dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,17 +38,6 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" -[[package]] -name = "async-channel" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" -dependencies = [ - "concurrent-queue", - "event-listener 2.5.3", - "futures-core", -] - [[package]] name = "async-channel" version = "2.1.1" @@ -71,41 +60,20 @@ dependencies = [ "async-lock 3.2.0", "async-task", "concurrent-queue", - "fastrand 2.0.1", - "futures-lite 2.1.0", + "fastrand", + "futures-lite", "slab", ] [[package]] name = "async-fs" -version = "1.6.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "279cf904654eeebfa37ac9bb1598880884924aab82e290aa65c9e77a0e142e06" +checksum = "dd1f344136bad34df1f83a47f3fd7f2ab85d75cb8a940af4ccf6d482a84ea01b" dependencies = [ - "async-lock 2.8.0", - "autocfg", + "async-lock 3.2.0", "blocking", - "futures-lite 1.13.0", -] - -[[package]] -name = "async-io" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" -dependencies = [ - "async-lock 2.8.0", - "autocfg", - "cfg-if", - "concurrent-queue", - "futures-lite 1.13.0", - "log", - "parking", - "polling 2.8.0", - "rustix 0.37.27", - "slab", - "socket2 0.4.10", - "waker-fn", + "futures-lite", ] [[package]] @@ -118,10 +86,10 @@ dependencies = [ "cfg-if", "concurrent-queue", "futures-io", - "futures-lite 2.1.0", + "futures-lite", "parking", - "polling 3.3.1", - "rustix 0.38.28", + "polling", + "rustix", "slab", "tracing", "windows-sys 0.52.0", @@ -149,30 +117,31 @@ dependencies = [ [[package]] name = "async-net" -version = "1.8.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0434b1ed18ce1cf5769b8ac540e33f01fa9471058b5e89da9e06f3c882a8c12f" +checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7" dependencies = [ - "async-io 1.13.0", + "async-io", "blocking", - "futures-lite 1.13.0", + "futures-lite", ] [[package]] name = "async-process" -version = "1.8.1" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea6438ba0a08d81529c69b36700fa2f95837bfe3e776ab39cde9c14d9149da88" +checksum = "15c1cd5d253ecac3d3cf15e390fd96bd92a13b1d14497d81abf077304794fb04" dependencies = [ - "async-io 1.13.0", - "async-lock 2.8.0", + "async-channel", + "async-io", + "async-lock 3.2.0", "async-signal", "blocking", "cfg-if", - "event-listener 3.1.0", - "futures-lite 1.13.0", - "rustix 0.38.28", - "windows-sys 0.48.0", + "event-listener 4.0.1", + "futures-lite", + "rustix", + "windows-sys 0.52.0", ] [[package]] @@ -191,13 +160,13 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e47d90f65a225c4527103a8d747001fc56e375203592b25ad103e1ca13124c5" dependencies = [ - "async-io 2.2.2", + "async-io", "async-lock 2.8.0", "atomic-waker", "cfg-if", "futures-core", "futures-io", - "rustix 0.38.28", + "rustix", "signal-hook-registry", "slab", "windows-sys 0.48.0", @@ -242,12 +211,6 @@ version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - [[package]] name = "bitflags" version = "2.4.1" @@ -260,12 +223,12 @@ version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" dependencies = [ - "async-channel 2.1.1", + "async-channel", "async-lock 3.2.0", "async-task", - "fastrand 2.0.1", + "fastrand", "futures-io", - "futures-lite 2.1.0", + "futures-lite", "piper", "tracing", ] @@ -434,12 +397,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "diff" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" - [[package]] name = "either" version = "1.9.0" @@ -462,17 +419,6 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" -[[package]] -name = "event-listener" -version = "3.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d93877bcde0eb80ca09131a08d23f0a5c18a620b01db137dba666d18cd9b30c2" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - [[package]] name = "event-listener" version = "4.0.1" @@ -494,15 +440,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "fastrand" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" -dependencies = [ - "instant", -] - [[package]] name = "fastrand" version = "2.0.1" @@ -511,9 +448,9 @@ checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" [[package]] name = "futures" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", @@ -526,9 +463,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -536,15 +473,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" dependencies = [ "futures-core", "futures-task", @@ -553,24 +490,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" - -[[package]] -name = "futures-lite" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" -dependencies = [ - "fastrand 1.9.0", - "futures-core", - "futures-io", - "memchr", - "parking", - "pin-project-lite", - "waker-fn", -] +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-lite" @@ -578,7 +500,7 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aeee267a1883f7ebef3700f262d2d54de95dfaf38189015a74fdc4e0c7ad8143" dependencies = [ - "fastrand 2.0.1", + "fastrand", "futures-core", "futures-io", "parking", @@ -587,9 +509,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", @@ -598,15 +520,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-timer" @@ -616,9 +538,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -667,26 +589,6 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "io-lifetimes" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.48.0", -] - [[package]] name = "is-terminal" version = "0.4.9" @@ -694,7 +596,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi", - "rustix 0.38.28", + "rustix", "windows-sys 0.48.0", ] @@ -734,12 +636,6 @@ version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" -[[package]] -name = "linux-raw-sys" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" - [[package]] name = "linux-raw-sys" version = "0.4.12" @@ -798,14 +694,13 @@ dependencies = [ [[package]] name = "mqrstt" -version = "0.2.2" +version = "0.3.0-alpha.1" dependencies = [ - "async-channel 2.1.1", + "async-channel", "async-rustls", "bytes", "criterion", "futures", - "pretty_assertions", "rand", "rstest", "rustls", @@ -900,7 +795,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "668d31b1c4eba19242f2088b2bf3316b82ca31082a8335764db4e083db7485d4" dependencies = [ "atomic-waker", - "fastrand 2.0.1", + "fastrand", "futures-io", ] @@ -932,22 +827,6 @@ dependencies = [ "plotters-backend", ] -[[package]] -name = "polling" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b2d323e8ca7996b3e23126511a523f7e62924d93ecd5ae73b333815b0eb3dce" -dependencies = [ - "autocfg", - "bitflags 1.3.2", - "cfg-if", - "concurrent-queue", - "libc", - "log", - "pin-project-lite", - "windows-sys 0.48.0", -] - [[package]] name = "polling" version = "3.3.1" @@ -957,7 +836,7 @@ dependencies = [ "cfg-if", "concurrent-queue", "pin-project-lite", - "rustix 0.38.28", + "rustix", "tracing", "windows-sys 0.52.0", ] @@ -968,16 +847,6 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" -[[package]] -name = "pretty_assertions" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af7cee1a6c8a5b9208b3cb1061f10c0cb689087b3d8ce85fb9d2dd7a29b6ba66" -dependencies = [ - "diff", - "yansi", -] - [[package]] name = "proc-macro2" version = "1.0.70" @@ -1154,30 +1023,16 @@ dependencies = [ "semver", ] -[[package]] -name = "rustix" -version = "0.37.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea8ca367a3a01fe35e6943c400addf443c0f57670e6ec51196f71a4b8762dd2" -dependencies = [ - "bitflags 1.3.2", - "errno", - "io-lifetimes", - "libc", - "linux-raw-sys 0.3.8", - "windows-sys 0.48.0", -] - [[package]] name = "rustix" version = "0.38.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" dependencies = [ - "bitflags 2.4.1", + "bitflags", "errno", "libc", - "linux-raw-sys 0.4.12", + "linux-raw-sys", "windows-sys 0.52.0", ] @@ -1309,29 +1164,19 @@ checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "smol" -version = "1.3.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13f2b548cd8447f8de0fdf1c592929f70f4fc7039a05e47404b0d096ec6987a1" +checksum = "e635339259e51ef85ac7aa29a1cd991b957047507288697a690e80ab97d07cad" dependencies = [ - "async-channel 1.9.0", + "async-channel", "async-executor", "async-fs", - "async-io 1.13.0", - "async-lock 2.8.0", + "async-io", + "async-lock 3.2.0", "async-net", "async-process", "blocking", - "futures-lite 1.13.0", -] - -[[package]] -name = "socket2" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" -dependencies = [ - "libc", - "winapi", + "futures-lite", ] [[package]] @@ -1363,18 +1208,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.51" +version = "1.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f11c217e1416d6f036b870f14e0413d480dbf28edbee1f877abaf0206af43bb7" +checksum = "b2cd5904763bad08ad5513ddbb12cf2ae273ca53fa9f68e843e236ec6dfccc09" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.51" +version = "1.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" +checksum = "3dcf4a824cce0aeacd6f38ae6f24234c8e80d68632338ebaa1443b5df9e29e19" dependencies = [ "proc-macro2", "quote", @@ -1413,7 +1258,7 @@ dependencies = [ "mio", "num_cpus", "pin-project-lite", - "socket2 0.5.5", + "socket2", "tokio-macros", "windows-sys 0.48.0", ] @@ -1518,12 +1363,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" -[[package]] -name = "waker-fn" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3c4517f54858c779bbcbf228f4fca63d121bf85fbecb2dc578cdf4a39395690" - [[package]] name = "walkdir" version = "2.4.0" @@ -1776,9 +1615,3 @@ name = "windows_x86_64_msvc" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" - -[[package]] -name = "yansi" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" diff --git a/Cargo.toml b/Cargo.toml index 6c69b37..1494a6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mqrstt" -version = "0.2.2" +version = "0.3.0-alpha.1" homepage = "https://github.com/GunnarMorrigan/mqrstt" repository = "https://github.com/GunnarMorrigan/mqrstt" documentation = "https://docs.rs/mqrstt" @@ -9,50 +9,46 @@ readme = "README.md" edition = "2021" license = "MPL-2.0" keywords = [ "MQTT", "IoT", "MQTTv5", "messaging", "client" ] -description = "Pure rust MQTTv5 client implementation for sync and async (Smol & Tokio)" +description = "Pure rust MQTTv5 client implementation Smol and Tokio" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["smol", "tokio_mqtt", "concurrent_tokio"] -concurrent_tokio = ["dep:tokio", "tokio/rt"] -tokio_mqtt = ["dep:tokio"] +default = ["smol", "tokio", "tokio_concurrent"] +tokio_concurrent = ["dep:tokio", "tokio/rt"] +tokio = ["dep:tokio"] smol = ["dep:smol"] -sync = [] logs = ["dep:tracing"] -# quic = ["dep:quinn"] [dependencies] # Packets bytes = "1.5.0" # Errors -thiserror = "1.0.49" -tracing = { version = "0.1.39", optional = true } +thiserror = "1.0.53" +tracing = { version = "0.1.40", optional = true } async-channel = "2.1.1" #async-mutex = "1.4.0" -futures = { version = "0.3.28", default-features = false, features = ["std", "async-await"] } +futures = { version = "0.3.30", default-features = false, features = ["std", "async-await"] } # quic feature flag # quinn = {version = "0.9.0", optional = true } # tokio feature flag -tokio = { version = "1.33.0", features = ["macros", "io-util", "net", "time"], optional = true } +tokio = { version = "1.35.1", features = ["macros", "io-util", "net", "time"], optional = true } # smol feature flag -smol = { version = "1.3.0", optional = true } +smol = { version = "2.0.0", optional = true } [dev-dependencies] criterion = {version="0.5.1", features=["async_tokio"]} -tracing-subscriber = {version = "0.3.16", features = ["env-filter"]} +tracing-subscriber = {version = "0.3.18", features = ["env-filter"]} -smol = { version = "1.3.0" } +smol = { version = "2.0.0" } tokio = { version = "1.33.0", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] } -pretty_assertions = "1.3.0" - rustls = { version = "0.21.7" } rustls-pemfile = { version = "1.0.3" } webpki = { version = "0.22.4" } diff --git a/benches/bench_main.rs b/benches/bench_main.rs index 8858246..9d7f5e4 100644 --- a/benches/bench_main.rs +++ b/benches/bench_main.rs @@ -3,15 +3,6 @@ use criterion::criterion_main; mod benchmarks; criterion_main! { - benchmarks::tokio_concurrent::tokio, - // benchmarks::external_process::benches, - // benchmarks::iter_with_large_drop::benches, - // benchmarks::iter_with_large_setup::benches, - // benchmarks::iter_with_setup::benches, - // benchmarks::with_inputs::benches, - // benchmarks::special_characters::benches, - // benchmarks::measurement_overhead::benches, - // benchmarks::custom_measurement::benches, - // benchmarks::sampling_mode::benches, - // benchmarks::async_measurement_overhead::benches, + benchmarks::tokio::tokio_concurrent, + benchmarks::tokio::tokio_synchronous, } diff --git a/benches/benchmarks/mod.rs b/benches/benchmarks/mod.rs index c2ed860..1d169c1 100644 --- a/benches/benchmarks/mod.rs +++ b/benches/benchmarks/mod.rs @@ -1,7 +1,7 @@ use bytes::{BufMut, Bytes, BytesMut}; use mqrstt::packets::{Disconnect, Packet, Publish}; -pub mod tokio_concurrent; +pub mod tokio; fn fill_stuff(buffer: &mut BytesMut, publ_count: usize, publ_size: usize) { empty_connect(buffer); @@ -65,8 +65,91 @@ fn very_large_publish(id: u16, repeat: usize) -> Packet { topic: "BlaBla".into(), packet_identifier: Some(id), publish_properties: Default::default(), - payload: Bytes::from_iter([0u8, 1u8, 2, 3, 4].repeat(repeat)), + payload: Bytes::from_iter("ping".repeat(repeat).into_bytes()), }; Packet::Publish(publ) } + + +mod test_handlers{ + use std::{sync::{atomic::AtomicU16, Arc}, ops::AddAssign, time::Duration}; + + use bytes::Bytes; + use mqrstt::{AsyncEventHandler, packets::{self, Packet}, MqttClient, AsyncEventHandlerMut}; + + pub struct PingPong { + pub client: MqttClient, + pub number: Arc, + } + + impl PingPong{ + pub fn new(client: MqttClient) -> Self { + Self { + client, + number: Arc::new(AtomicU16::new(0)), + } + } + } + + impl AsyncEventHandler for PingPong { + async fn handle(&self, event: packets::Packet) -> () { + self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + let max_len = payload.len().min(10); + let a = &payload[0..max_len]; + if payload.to_lowercase().contains("ping") { + self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); + } + } + } + Packet::ConnAck(_) => (), + _ => (), + } + } + } + + impl AsyncEventHandlerMut for PingPong { + async fn handle(&mut self, event: packets::Packet) -> () { + self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + let max_len = payload.len().min(10); + let a = &payload[0..max_len]; + if payload.to_lowercase().contains("ping") { + self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); + } + } + } + Packet::ConnAck(_) => (), + _ => (), + } + } + } + + pub struct SimpleDelay{ + delay: Duration, + } + + impl SimpleDelay{ + pub fn new(delay: Duration) -> Self{ + Self { + delay, + } + } + } + + impl AsyncEventHandler for SimpleDelay { + fn handle(&self, _: Packet) -> impl futures::prelude::Future + Send + Sync { + tokio::time::sleep(self.delay) + } + } + impl AsyncEventHandlerMut for SimpleDelay{ + fn handle(&mut self, _: Packet) -> impl futures::prelude::Future + Send + Sync { + tokio::time::sleep(self.delay) + } + } +} \ No newline at end of file diff --git a/benches/benchmarks/tokio.rs b/benches/benchmarks/tokio.rs new file mode 100644 index 0000000..336f756 --- /dev/null +++ b/benches/benchmarks/tokio.rs @@ -0,0 +1,353 @@ +use std::{ + hint::black_box, + io::{Cursor, Write}, + time::Duration, net::SocketAddr, sync::Arc, +}; + +use bytes::BytesMut; +use criterion::{criterion_group, Criterion}; +use mqrstt::{NetworkBuilder, ConnectOptions, NetworkStatus}; +use tokio::net::TcpStream; + +use crate::benchmarks::test_handlers::{PingPong, SimpleDelay}; + +use super::fill_stuff; + +struct ReadWriteTester<'a> { + read: Cursor<&'a [u8]>, + write: Vec, +} + +impl<'a> ReadWriteTester<'a> { + pub fn new(read: &'a [u8]) -> Self { + Self { + read: Cursor::new(read), + write: Vec::new(), + } + } +} + +impl<'a> tokio::io::AsyncRead for ReadWriteTester<'a> { + fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> std::task::Poll> { + tokio::io::AsyncRead::poll_read(std::pin::Pin::new(&mut self.read), cx, buf) + } +} + +impl<'a> tokio::io::AsyncWrite for ReadWriteTester<'a> { + fn poll_write(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>, _buf: &[u8]) -> std::task::Poll> { + todo!() + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { + todo!() + } + + fn poll_shutdown(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { + todo!() + } +} + +fn tokio_setup() -> (TcpStream, std::net::TcpStream, SocketAddr) { + let mut buffer = BytesMut::new(); + + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let tcp_stream = std::net::TcpStream::connect(addr).unwrap(); + + let (mut server, _addr) = listener.accept().unwrap(); + + fill_stuff(&mut buffer, 100, 5_000_000); + + server.write_all(&buffer.to_vec()).unwrap(); + + let tcp_stream = tokio::net::TcpStream::from_std(tcp_stream).unwrap(); + (tcp_stream, server, _addr) +} + +fn tokio_concurrent_benchmarks(c: &mut Criterion) { + let mut group = c.benchmark_group("Tokio concurrent read, write and handling"); + group.sample_size(30); + group.measurement_time(Duration::from_secs(120)); + + group.bench_function("tokio_bench_concurrent_read_write_and_handling_NOP", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_with_setup( + tokio_setup, + |(tcp_stream, server, addr)| async move { + let _server_box = black_box(server); + let _addr = black_box(addr); + + let options = ConnectOptions::new("test"); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_concurrent_network(); + + let _server_box = black_box(client); + + network.connect(tcp_stream, &mut ()).await.unwrap(); + let (read, write) = network.split(()).unwrap(); + + let _network_box = black_box(network); + + let read_handle = tokio::task::spawn(read.run()); + let write_handle = tokio::task::spawn(write.run()); + + let (read_res, write_res) = tokio::join!(read_handle, write_handle); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_some()); + assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); + }, + ) + }); + group.bench_function("tokio_bench_concurrent_read_write_and_handling_PingPong", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_with_setup( + tokio_setup, + |(tcp_stream, server, addr)| async move { + let options = ConnectOptions::new("test"); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_concurrent_network(); + + let mut pingpong = Arc::new(PingPong::new(client.clone())); + + network.connect(tcp_stream, &mut pingpong).await.unwrap(); + let (read, write) = network.split(pingpong.clone()).unwrap(); + + let read_handle = tokio::task::spawn(read.run()); + let write_handle = tokio::task::spawn(write.run()); + + let (read_res, write_res) = futures::join!(read_handle, write_handle); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_some()); + assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); + assert_eq!(102, pingpong.number.load(std::sync::atomic::Ordering::SeqCst)); + + let _server_box = black_box(client.clone()); + let _server_box = black_box(server); + let _addr_box = black_box(addr); + let _network_box = black_box(network); + }, + ) + }); + group.bench_function("tokio_bench_concurrent_read_write_and_handling_100ms_Delay", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_with_setup( + tokio_setup, + |(tcp_stream, server, addr)| async move { + let _server_box = black_box(server); + let _addr = black_box(addr); + + let options = ConnectOptions::new("test"); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_concurrent_network(); + + let _server_box = black_box(client); + + let mut handler = Arc::new(SimpleDelay::new(Duration::from_millis(100))); + + network.connect(tcp_stream, &mut handler).await.unwrap(); + let (read, write) = network.split(handler).unwrap(); + + + let read_handle = tokio::task::spawn(read.run()); + let write_handle = tokio::task::spawn(write.run()); + + let (read_res, write_res) = tokio::join!(read_handle, write_handle); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_some()); + assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); + + let _network_box = black_box(network); + }, + ) + }); + + group.bench_function("tokio_bench_concurrent_read_write", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_with_setup( + tokio_setup, + |(tcp_stream, server, addr)| async move { + let _server_box = black_box(server); + let _addr = black_box(addr); + + let options = ConnectOptions::new("test"); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); + + let _server_box = black_box(client); + + network.connect(tcp_stream, &mut ()).await.unwrap(); + + let (read, write) = network.split(()).unwrap(); + + let read_handle = tokio::task::spawn(read.run()); + let write_handle = tokio::task::spawn(write.run()); + + let (read_res, write_res) = tokio::join!(read_handle, write_handle); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_some()); + assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); + } + ) + }); + group.bench_function("tokio_bench_concurrent_read_write_PingPong", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_with_setup( + tokio_setup, + |(tcp_stream, server, addr)| async move { + let options = ConnectOptions::new("test"); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); + + let mut pingpong = PingPong::new(client.clone()); + + let num_packets_received = pingpong.number.clone(); + + network.connect(tcp_stream, &mut pingpong).await.unwrap(); + let (read, write) = network.split(pingpong).unwrap(); + + let read_handle = tokio::task::spawn(read.run()); + let write_handle = tokio::task::spawn(write.run()); + + let (read_res, write_res) = futures::join!(read_handle, write_handle); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_some()); + assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); + assert_eq!(102, num_packets_received.load(std::sync::atomic::Ordering::SeqCst)); + + let _server_box = black_box(client.clone()); + let _server_box = black_box(server); + let _addr_box = black_box(addr); + let _network_box = black_box(network); + }, + ) + }); + group.bench_function("tokio_bench_concurrent_read_write_100ms_Delay", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_with_setup( + tokio_setup, + |(tcp_stream, server, addr)| async move { + let _server_box = black_box(server); + let _addr = black_box(addr); + + let options = ConnectOptions::new("test"); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); + + let _server_box = black_box(client); + + let mut handler = SimpleDelay::new(Duration::from_millis(100)); + + network.connect(tcp_stream, &mut handler).await.unwrap(); + let (read, write) = network.split(handler).unwrap(); + + + let read_handle = tokio::task::spawn(read.run()); + let write_handle = tokio::task::spawn(write.run()); + + let (read_res, write_res) = futures::join!(read_handle, write_handle); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_ok()); + let read_res = read_res.unwrap(); + assert!(read_res.is_some()); + assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); + + let _network_box = black_box(network); + }, + ) + }); +} + + +fn tokio_synchronous_benchmarks(c: &mut Criterion){ + let mut group = c.benchmark_group("Tokio sequential"); + group.sample_size(30); + group.measurement_time(Duration::from_secs(120)); + + group.bench_function("tokio_bench_sync_read_write", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_with_setup( + tokio_setup, + |(tcp_stream, server, addr)| async move { + let _server_box = black_box(server); + let _addr = black_box(addr); + + let options = ConnectOptions::new("test"); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); + + let _server_box = black_box(client); + + network.connect(tcp_stream, &mut ()).await.unwrap(); + + let network_res = network.run(&mut ()).await; + + assert!(network_res.is_ok()); + let network_res = network_res.unwrap(); + assert_eq!(network_res, NetworkStatus::IncomingDisconnect); + } + ) + }); + group.bench_function("tokio_bench_sync_read_write_PingPong", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_with_setup( + tokio_setup, + |(tcp_stream, server, addr)| async move { + let _server_box = black_box(server); + let _addr = black_box(addr); + + let options = ConnectOptions::new("test"); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); + + let mut pingpong = PingPong::new(client.clone()); + + let _server_box = black_box(client); + + network.connect(tcp_stream, &mut pingpong).await.unwrap(); + + let network_res = network.run(&mut pingpong).await; + + assert!(network_res.is_ok()); + let network_res = network_res.unwrap(); + assert_eq!(network_res, NetworkStatus::IncomingDisconnect); + } + ) + }); + group.bench_function("tokio_bench_sync_read_write_100ms_Delay", |b| { + let runtime = tokio::runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_with_setup( + tokio_setup, + |(tcp_stream, server, addr)| async move { + let _server_box = black_box(server); + let _addr = black_box(addr); + + let options = ConnectOptions::new("test"); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); + + let mut handler = SimpleDelay::new(Duration::from_millis(100)); + + let _server_box = black_box(client); + + network.connect(tcp_stream, &mut handler).await.unwrap(); + + let network_res = network.run(&mut handler).await; + + assert!(network_res.is_ok()); + let network_res = network_res.unwrap(); + assert_eq!(network_res, NetworkStatus::IncomingDisconnect); + } + ) + }); +} + +criterion_group!(tokio_concurrent, tokio_concurrent_benchmarks); +criterion_group!(tokio_synchronous, tokio_synchronous_benchmarks); diff --git a/benches/benchmarks/tokio_concurrent.rs b/benches/benchmarks/tokio_concurrent.rs deleted file mode 100644 index 9000a5c..0000000 --- a/benches/benchmarks/tokio_concurrent.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::{ - hint::black_box, - io::{Cursor, Write}, - time::Duration, -}; - -use bytes::BytesMut; -use criterion::{criterion_group, Criterion}; -use mqrstt::{new_tokio, ConnectOptions}; - -use super::fill_stuff; - -struct ReadWriteTester<'a> { - read: Cursor<&'a [u8]>, - write: Vec, -} - -impl<'a> ReadWriteTester<'a> { - pub fn new(read: &'a [u8]) -> Self { - Self { - read: Cursor::new(read), - write: Vec::new(), - } - } -} - -impl<'a> tokio::io::AsyncRead for ReadWriteTester<'a> { - fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> std::task::Poll> { - tokio::io::AsyncRead::poll_read(std::pin::Pin::new(&mut self.read), cx, buf) - } -} - -impl<'a> tokio::io::AsyncWrite for ReadWriteTester<'a> { - fn poll_write(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>, _buf: &[u8]) -> std::task::Poll> { - todo!() - } - - fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { - todo!() - } - - fn poll_shutdown(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { - todo!() - } -} - -fn tokio_concurrent(c: &mut Criterion) { - let mut group = c.benchmark_group("Tokio concurrent throughput test"); - group.sample_size(10); - group.measurement_time(Duration::from_secs(20)); - - group.bench_function("tokio_bench_concurrent_read_write", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup( - || { - let mut buffer = BytesMut::new(); - - // :0 tells the OS to pick an open port. - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - - let tcp_stream = std::net::TcpStream::connect(addr).unwrap(); - - let (mut server, _addr) = listener.accept().unwrap(); - - fill_stuff(&mut buffer, 100, 5_000_000); - - server.write_all(&buffer.to_vec()).unwrap(); - - let tcp_stream = tokio::net::TcpStream::from_std(tcp_stream).unwrap(); - (tcp_stream, server, _addr) - }, - |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, _) = new_tokio(options); - - network.connect(tcp_stream, ()).await.unwrap(); - - // network.run().await.unwrap(); - todo!() - }, - ) - }); -} - -criterion_group!(tokio, tokio_concurrent); diff --git a/src/client.rs b/src/client.rs index 53c9bef..e09c533 100644 --- a/src/client.rs +++ b/src/client.rs @@ -26,7 +26,7 @@ pub struct MqttClient { } impl MqttClient { - pub fn new(available_packet_ids_r: Receiver, to_network_s: Sender, max_packet_size: usize) -> Self { + pub(crate) fn new(available_packet_ids_r: Receiver, to_network_s: Sender, max_packet_size: usize) -> Self { Self { available_packet_ids_r, to_network_s, @@ -44,8 +44,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let options = mqrstt::ConnectOptions::new("example_id"); - /// # let (network, mqtt_client) = mqrstt::new_smol::(options); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_sequential_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -89,8 +88,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let options = mqrstt::ConnectOptions::new("example_id"); - /// # let (network, mqtt_client) = mqrstt::new_smol::(options); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_sequential_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -160,8 +158,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let options = mqrstt::ConnectOptions::new("example_id"); - /// # let (network, mqtt_client) = mqrstt::new_smol::(options); + /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -215,8 +212,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let options = mqrstt::ConnectOptions::new("example_id"); - /// # let (network, mqtt_client) = mqrstt::new_smol::(options); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -271,6 +267,7 @@ impl MqttClient { /// mqtt_client.publish_with_properties("test/topic", QoS::AtMostOnce, true, payload, properties).await; /// /// # }); + /// # let _network = std::hint::black_box(network); /// ``` pub async fn publish_with_properties, P: Into>(&self, topic: T, qos: QoS, retain: bool, payload: P, properties: PublishProperties) -> Result<(), ClientError> { let pkid = match qos { @@ -299,8 +296,7 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let options = mqrstt::ConnectOptions::new("example_id"); - /// # let (network, mqtt_client) = mqrstt::new_smol::(options); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { /// /// // Unsubscribe from a single topic specified as a string: @@ -324,6 +320,7 @@ impl MqttClient { /// mqtt_client.unsubscribe(topics.as_slice()).await; /// /// # }); + /// # let _network = std::hint::black_box(network); /// ``` pub async fn unsubscribe>(&self, into_topics: T) -> Result<(), ClientError> { let pkid = self.available_packet_ids_r.recv().await.map_err(|_| ClientError::NoNetworkChannel)?; @@ -346,8 +343,7 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let options = mqrstt::ConnectOptions::new("example_id"); - /// # let (network, mqtt_client) = mqrstt::new_smol::(options); + /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::UnsubscribeProperties; @@ -416,13 +412,13 @@ impl MqttClient { /// # Example /// /// ``` - /// # let options = mqrstt::ConnectOptions::new("example_id"); - /// # let (network, mqtt_client) = mqrstt::new_smol::(options); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { /// /// mqtt_client.disconnect().await.unwrap(); /// /// # }); + /// # let _network = std::hint::black_box(network); /// ``` pub async fn disconnect(&self) -> Result<(), ClientError> { let disconnect = Disconnect { @@ -442,8 +438,7 @@ impl MqttClient { /// # Example /// /// ``` - /// # let options = mqrstt::ConnectOptions::new("example_id"); - /// # let (network, mqtt_client) = mqrstt::new_smol::(options); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::DisconnectProperties; @@ -457,6 +452,7 @@ impl MqttClient { /// mqtt_client.disconnect_with_properties(DisconnectReasonCode::NormalDisconnection, properties).await.unwrap(); /// /// # }); + /// # let _network = std::hint::black_box(network); /// ``` pub async fn disconnect_with_properties(&self, reason_code: DisconnectReasonCode, properties: DisconnectProperties) -> Result<(), ClientError> { let disconnect = Disconnect { reason_code, properties }; @@ -910,6 +906,56 @@ mod tests { (client, client_to_handler_r, to_network_r) } + #[tokio::test] + async fn test_subscribe() { + let (mqtt_client, _client_to_handler_r, _) = create_new_test_client(); + + // retain_handling: RetainHandling::ZERO, retain_as_publish: false, no_local: false, qos: QoS::AtMostOnce, + let _ = mqtt_client.subscribe("test/topic").await; + + // retain_handling: RetainHandling::ZERO, retain_as_publish: false, no_local: false, qos: QoS::ExactlyOnce, + let _ = mqtt_client.subscribe(("test/topic", QoS::ExactlyOnce)).await; + + let vec = vec![("test/topic1", QoS::ExactlyOnce), ("test/topic2", QoS::AtMostOnce)]; + let _ = mqtt_client.subscribe(vec).await; + + let vec = [("test/topic1", QoS::ExactlyOnce), ("test/topic2", QoS::AtLeastOnce)]; + let _ = mqtt_client.subscribe(vec.as_slice()).await; + + let sub_options = crate::packets::SubscriptionOptions{ + retain_handling: crate::packets::RetainHandling::TWO, + retain_as_publish: false, + no_local: false, + qos: QoS::AtLeastOnce, + }; + let _ = mqtt_client.subscribe(("final/test/topic", sub_options)).await; + } + + #[tokio::test] + async fn test_unsubscribe() { + let (mqtt_client, _client_to_handler_r, _) = create_new_test_client(); + + // Unsubscribe from a single topic specified as a string: + let topic = "test/topic"; + let _ = mqtt_client.unsubscribe(topic).await; + + // Unsubscribe from multiple topics specified as an array of string slices: + let topics = ["test/topic1", "test/topic2"]; + let _ = mqtt_client.unsubscribe(topics.as_slice()).await; + + // Unsubscribe from a single topic specified as a String: + let topic = String::from("test/topic"); + let _ = mqtt_client.unsubscribe(topic).await; + + // Unsubscribe from multiple topics specified as a Vec: + let topics = vec![String::from("test/topic1"), String::from("test/topic2")]; + let _ = mqtt_client.unsubscribe(topics).await; + + // Unsubscribe from multiple topics specified as an array of String: + let topics = &[String::from("test/topic1"), String::from("test/topic2")]; + let _ = mqtt_client.unsubscribe(topics.as_slice()).await; + } + #[tokio::test] async fn publish_with_just_right_topic_len() { let (client, _client_to_handler_r, _) = create_new_test_client(); diff --git a/src/error.rs b/src/error.rs index 4c316e8..47606f3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -38,7 +38,7 @@ pub enum ConnectionError { #[error("Handler Error: {0:?}")] HandlerError(#[from] HandlerError), - #[cfg(feature = "concurrent_tokio")] + #[cfg(feature = "tokio_concurrent")] #[error("Join error")] JoinError(#[from] tokio::task::JoinError), } diff --git a/src/event_handlers.rs b/src/event_handlers.rs new file mode 100644 index 0000000..6be3232 --- /dev/null +++ b/src/event_handlers.rs @@ -0,0 +1,192 @@ +use std::sync::Arc; + +use futures::Future; + +use crate::packets::Packet; + +/// Handlers are used to deal with packets before they are further processed (acked) +/// This guarantees that the end user has handlded the packet. +/// Trait for async mutable access to handler. +/// Usefull when you have a single handler + +/// This trait can be used types which +pub trait AsyncEventHandler { + fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync; +} +impl AsyncEventHandler for &T where T: AsyncEventHandler { + #[inline] + fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync{ + AsyncEventHandler::handle(*self, incoming_packet) + } +} +impl AsyncEventHandler for Arc where T: AsyncEventHandler { + #[inline] + fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync{ + ::handle(&self, incoming_packet) + } +} +impl AsyncEventHandler for () { + fn handle(&self, _: Packet) -> impl Future + Send + Sync { + async {} + } +} + +pub trait AsyncEventHandlerMut { + fn handle(&mut self, incoming_packet: Packet) -> impl Future + Send + Sync; +} + +impl AsyncEventHandlerMut for () { + fn handle(&mut self, _: Packet) -> impl Future + Send + Sync { + async {} + } +} + + +pub trait EventHandler { + fn handle(&mut self, incoming_packet: Packet); +} + +impl EventHandler for () { + fn handle(&mut self, _: Packet) {} +} + +pub mod example_handlers{ + use std::{sync::atomic::AtomicU16, ops::AddAssign}; + + use bytes::Bytes; + + use crate::{AsyncEventHandlerMut, packets::{Packet, self}, EventHandler, MqttClient, AsyncEventHandler}; + + /// Most basic no op handler + /// This handler performs no operations on incoming messages. + pub struct NOP {} + + impl AsyncEventHandlerMut for NOP { + async fn handle(&mut self, _: Packet) {} + } + + impl EventHandler for NOP { + fn handle(&mut self, _: Packet) {} + } + + pub struct PingResp { + pub client: MqttClient, + pub ping_resp_received: AtomicU16, + } + + impl PingResp { + pub fn new(client: MqttClient) -> Self { + Self { + client, + ping_resp_received: AtomicU16::new(0), + } + } + } + + impl AsyncEventHandlerMut for PingResp { + async fn handle(&mut self, event: packets::Packet) -> () { + use Packet::*; + if event == PingResp { + self.ping_resp_received.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + println!("Received packet: {}", event); + } + } + + impl EventHandler for PingResp { + fn handle(&mut self, event: Packet) { + use Packet::*; + if event == PingResp { + self.ping_resp_received.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + println!("Received packet: {}", event); + } + } + + pub struct PingPong { + pub client: MqttClient, + pub number: AtomicU16, + } + + impl PingPong{ + pub fn new(client: MqttClient) -> Self { + Self { + client, + number: AtomicU16::new(0), + } + } + } + + impl AsyncEventHandler for PingPong { + async fn handle(&self, event: packets::Packet) -> () { + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + let max_len = payload.len().min(10); + let a = &payload[0..max_len]; + if payload.to_lowercase().contains("ping") { + self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); + println!("Received publish payload: {}", a); + + if !p.retain{ + self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + + println!("DBG: \n {}", &Packet::Publish(p)); + } + } + } + Packet::ConnAck(_) => { + println!("Connected!") + } + _ => (), + } + } + } + + impl AsyncEventHandlerMut for PingPong { + async fn handle(&mut self, event: packets::Packet) -> () { + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + let max_len = payload.len().min(10); + let a = &payload[0..max_len]; + if payload.to_lowercase().contains("ping") { + self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); + println!("Received publish payload: {}", a); + + if !p.retain{ + self.number.get_mut().add_assign(1); + } + + println!("DBG: \n {}", &Packet::Publish(p)); + } + } + } + Packet::ConnAck(_) => { + println!("Connected!") + } + _ => (), + } + } + } + + #[cfg(feature = "sync")] + impl EventHandler for PingPong { + fn handle(&mut self, event: Packet) { + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + if payload.to_lowercase().contains("ping") { + self.client.publish_blocking(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).unwrap(); + } + } + } + Packet::ConnAck(_) => { + println!("Connected!") + } + _ => (), + } + } + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 3d87a8b..a862ec1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,23 +33,23 @@ //! ```rust //! use mqrstt::{ //! MqttClient, -//! NOP, +//! example_handlers::NOP, //! ConnectOptions, -//! new_smol, //! packets::{self, Packet}, //! AsyncEventHandler, -//! smol::NetworkStatus, +//! NetworkStatus, +//! NetworkBuilder, //! }; //! //! smol::block_on(async { -//! let options = ConnectOptions::new("mqrsttSmolExample"); -//! //! // Construct a no op handler //! let mut nop = NOP{}; //! //! // In normal operations you would want to loop this connection //! // To reconnect after a disconnect or error -//! let (mut network, client) = new_smol(options); +//! let (mut network, client) = NetworkBuilder +//! ::new_from_client_id("mqrsttSmolExample") +//! .smol_sequential_network(); //! let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)) //! .await //! .unwrap(); @@ -59,14 +59,7 @@ //! client.subscribe("mqrstt").await.unwrap(); //! //! let (n, t) = futures::join!( -//! async { -//! loop { -//! return match network.poll(&mut nop).await { -//! Ok(NetworkStatus::Active) => continue, -//! otherwise => otherwise, -//! }; -//! } -//! }, +//! network.run(&mut nop), //! async { //! smol::Timer::after(std::time::Duration::from_secs(30)).await; //! client.disconnect().await.unwrap(); @@ -82,19 +75,20 @@ //! ```rust //! use mqrstt::{ //! MqttClient, -//! NOP, +//! example_handlers::NOP, //! ConnectOptions, -//! new_tokio, //! packets::{self, Packet}, //! AsyncEventHandler, -//! tokio::NetworkStatus, +//! NetworkStatus, +//! NetworkBuilder, //! }; //! use tokio::time::Duration; //! //! #[tokio::main] //! async fn main() { -//! let options = ConnectOptions::new("TokioTcpPingPongExample"); -//! let (mut network, client) = new_tokio(options); +//! let (mut network, client) = NetworkBuilder +//! ::new_from_client_id("TokioTcpPingPongExample") +//! .tokio_sequential_network(); //! //! // Construct a no op handler //! let mut nop = NOP{}; @@ -108,15 +102,8 @@ //! //! client.subscribe("mqrstt").await.unwrap(); //! -//! let (n, _) = tokio::join!( -//! async { -//! loop { -//! return match network.poll(&mut nop).await { -//! Ok(NetworkStatus::Active) => continue, -//! otherwise => otherwise, -//! }; -//! } -//! }, +//! let (n, _) = futures::join!( +//! network.run(&mut nop), //! async { //! tokio::time::sleep(Duration::from_secs(30)).await; //! client.disconnect().await.unwrap(); @@ -126,59 +113,60 @@ //! } //! ``` //! -//! Sync example: -//! ---------------------------- -//! ```rust -//! use mqrstt::{ -//! MqttClient, -//! NOP, -//! ConnectOptions, -//! new_sync, -//! packets::{self, Packet}, -//! EventHandler, -//! sync::NetworkStatus, -//! }; -//! use std::net::TcpStream; -//! -//! let mut client_id: String = "SyncTcppingrespTestExample".to_string(); -//! let options = ConnectOptions::new(client_id); -//! -//! let address = "broker.emqx.io"; -//! let port = 1883; -//! -//! let (mut network, client) = new_sync(options); -//! -//! // Construct a no op handler -//! let mut nop = NOP{}; -//! -//! // In normal operations you would want to loop connect -//! // To reconnect after a disconnect or error -//! let stream = TcpStream::connect((address, port)).unwrap(); -//! // IMPORTANT: Set nonblocking to true! No progression will be made when stream reads block! -//! stream.set_nonblocking(true).unwrap(); -//! network.connect(stream, &mut nop).unwrap(); -//! -//! let res_join_handle = std::thread::spawn(move || -//! loop { -//! match network.poll(&mut nop) { -//! Ok(NetworkStatus::ActivePending) => { -//! std::thread::sleep(std::time::Duration::from_millis(100)); -//! }, -//! Ok(NetworkStatus::ActiveReady) => { -//! std::thread::sleep(std::time::Duration::from_millis(100)); -//! }, -//! otherwise => return otherwise, -//! } -//! } -//! ); -//! -//! std::thread::sleep(std::time::Duration::from_secs(30)); -//! client.disconnect_blocking().unwrap(); -//! let join_res = res_join_handle.join(); -//! assert!(join_res.is_ok()); -//! let res = join_res.unwrap(); -//! assert!(res.is_ok()); -//! ``` +// //! Sync example: +// //! ---------------------------- +// //! ```rust +// //! use mqrstt::{ +// //! MqttClient, +// //! example_handlers::NOP, +// //! ConnectOptions, +// //! packets::{self, Packet}, +// //! EventHandler, +// //! sync::NetworkStatus, +// //! }; +// //! use std::net::TcpStream; +// //! +// //! let mut client_id: String = "SyncTcppingrespTestExample".to_string(); +// //! let options = ConnectOptions::new(client_id); +// //! +// //! let address = "broker.emqx.io"; +// //! let port = 1883; +// //! +// //! let (mut network, client) = new_sync(options); +// //! +// //! // Construct a no op handler +// //! let mut nop = NOP{}; +// //! +// //! // In normal operations you would want to loop connect +// //! // To reconnect after a disconnect or error +// //! let stream = TcpStream::connect((address, port)).unwrap(); +// //! // IMPORTANT: Set nonblocking to true! No progression will be made when stream reads block! +// //! stream.set_nonblocking(true).unwrap(); +// //! network.connect(stream, &mut nop).unwrap(); +// //! +// //! let res_join_handle = std::thread::spawn(move || +// //! loop { +// //! match network.poll(&mut nop) { +// //! Ok(NetworkStatus::ActivePending) => { +// //! std::thread::sleep(std::time::Duration::from_millis(100)); +// //! }, +// //! Ok(NetworkStatus::ActiveReady) => { +// //! std::thread::sleep(std::time::Duration::from_millis(100)); +// //! }, +// //! otherwise => return otherwise, +// //! } +// //! } +// //! ); +// //! +// //! std::thread::sleep(std::time::Duration::from_secs(30)); +// //! client.disconnect_blocking().unwrap(); +// //! let join_res = res_join_handle.join(); +// //! assert!(join_res.is_ok()); +// //! let res = join_res.unwrap(); +// //! assert!(res.is_ok()); +// //! ``` + +const CHANNEL_SIZE: usize = 100; mod available_packet_ids; mod client; @@ -186,24 +174,23 @@ mod connect_options; mod mqtt_handler; mod util; -#[cfg(any(feature = "tokio", feature = "concurrent_tokio"))] +#[cfg(any(feature = "tokio", feature = "tokio_concurrent"))] pub mod tokio; #[cfg(feature = "smol")] pub mod smol; -#[cfg(feature = "sync")] -pub mod sync; pub mod error; pub mod packets; -pub mod state; +mod state; +mod event_handlers; +use std::marker::PhantomData; -use std::sync::Arc; +pub use event_handlers::*; pub use client::MqttClient; pub use connect_options::ConnectOptions; -use futures::Future; -pub use mqtt_handler::StateHandler; -use packets::{Packet}; +use mqtt_handler::StateHandler; +use available_packet_ids::AvailablePacketIds; #[cfg(test)] pub mod tests; @@ -220,102 +207,127 @@ pub enum NetworkStatus { KeepAliveTimeout, } -/// Handlers are used to deal with packets before they are further processed (acked) -/// This guarantees that the end user has handlded the packet. -/// Trait for async mutable access to handler. -/// Usefull when you have a single handler -pub trait AsyncEventHandler { - fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync; +#[derive(Debug)] +pub struct NetworkBuilder{ + handler: PhantomData, + stream: PhantomData, + options: ConnectOptions, } -impl AsyncEventHandler for Arc -where - T: AsyncEventHandler, -{ - fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync { - T::handle(&self, incoming_packet) +impl NetworkBuilder { + #[inline] + pub const fn new_from_options(options: ConnectOptions) -> Self{ + Self{ + handler: PhantomData, + stream: PhantomData, + options + } } -} - -impl AsyncEventHandler for () { - fn handle(&self, _: Packet) -> impl Future + Send + Sync { - async {} + #[inline] + pub fn new_from_client_id>(client_id: C) -> Self{ + let options = ConnectOptions::new(client_id); + Self{ + handler: PhantomData, + stream: PhantomData, + options + } } } -pub trait EventHandler { - fn handle(&mut self, incoming_packet: Packet); -} - -impl EventHandler for () { - fn handle(&mut self, _: Packet) {} +impl NetworkBuilder +where + H: AsyncEventHandlerMut, + S: ::tokio::io::AsyncReadExt + ::tokio::io::AsyncWriteExt + Sized + Unpin, +{ + /// Creates the needed components to run the MQTT client using a stream that implements [`tokio::io::AsyncReadExt`] and [`tokio::io::AsyncWriteExt`] + /// This network is supposed to be ran on a single task/thread. The read and write operations happen one after the other. + /// This approach does not give the most speed in terms of reading and writing but provides a simple and easy to use client with low overhead for low throughput clients. + /// + /// For more throughput: [`NetworkBuilder::tokio_concurrent_network`] + /// + /// # Example + /// ``` + /// use mqrstt::ConnectOptions; + /// + /// let options = ConnectOptions::new("ExampleClient"); + /// let (mut network, client) = mqrstt::NetworkBuilder::<(), tokio::net::TcpStream> + /// ::new_from_options(options) + /// .tokio_sequential_network(); + /// ``` + pub fn tokio_sequential_network(self) -> (tokio::Network, MqttClient) where H: AsyncEventHandlerMut { + let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); + + let (apkids, apkids_r) = AvailablePacketIds::new(self.options.send_maximum()); + + let max_packet_size = self.options.maximum_packet_size(); + + let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); + + let network = tokio::Network::new(self.options, to_network_r, apkids); + + (network, client) + } } -/// Most basic no op handler -/// This handler performs no operations on incoming messages. -pub struct NOP {} -impl AsyncEventHandler for NOP { - async fn handle(&self, _: Packet) {} -} - -impl EventHandler for NOP { - fn handle(&mut self, _: Packet) {} +impl NetworkBuilder +where + H: AsyncEventHandler, + S: ::tokio::io::AsyncReadExt + ::tokio::io::AsyncWriteExt + Sized + Unpin, +{ + #[cfg(feature = "tokio_concurrent")] + /// Creates the needed components to run the MQTT client using a stream that implements [`tokio::io::AsyncReadExt`] and [`tokio::io::AsyncWriteExt`] + /// # Example + /// + /// ``` + /// use mqrstt::ConnectOptions; + /// + /// let options = ConnectOptions::new("ExampleClient"); + /// let (mut network, client) = mqrstt::NetworkBuilder::<(), tokio::net::TcpStream> + /// ::new_from_options(options) + /// .tokio_concurrent_network(); + /// ``` + pub fn tokio_concurrent_network(self) -> (tokio::Network, MqttClient) { + let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); + + let (apkids, apkids_r) = AvailablePacketIds::new(self.options.send_maximum()); + + let max_packet_size = self.options.maximum_packet_size(); + + let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); + + let network = tokio::Network::new(self.options, to_network_r, apkids); + + (network, client) + } } #[cfg(feature = "smol")] -/// Creates the needed components to run the MQTT client using a stream that implements [`smol::io::AsyncReadExt`] and [`smol::io::AsyncWriteExt`] -/// ``` -/// use mqrstt::ConnectOptions; -/// -/// let options = ConnectOptions::new("ExampleClient"); -/// let (network, client) = mqrstt::new_tokio::(options); -/// ``` -pub fn new_smol(options: ConnectOptions) -> (smol::Network, MqttClient) -where +impl NetworkBuilder +where + H: AsyncEventHandlerMut, S: ::smol::io::AsyncReadExt + ::smol::io::AsyncWriteExt + Sized + Unpin, { - let (to_network_s, to_network_r) = async_channel::bounded(100); - - let (apkids, apkids_r) = available_packet_ids::AvailablePacketIds::new(options.send_maximum()); - - let max_packet_size = options.maximum_packet_size(); - - let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); - - let network = smol::Network::::new(options, to_network_r, apkids); - - (network, client) -} - -/// Creates the needed components to run the MQTT client using a stream that implements [`tokio::io::AsyncReadExt`] and [`tokio::io::AsyncWriteExt`] -#[cfg(feature = "concurrent_tokio")] -/// # Example -/// -/// ``` -/// use mqrstt::ConnectOptions; -/// -/// let options = ConnectOptions::new("ExampleClient"); -/// let (network, client) = mqrstt::new_tokio::(options); -/// ``` -pub fn new_tokio(options: ConnectOptions) -> (tokio::Network, MqttClient) -where - H: AsyncEventHandler + Clone + Send + Sync, - S: ::tokio::io::AsyncReadExt + ::tokio::io::AsyncWriteExt + Sized + Unpin, -{ - use available_packet_ids::AvailablePacketIds; - - let (to_network_s, to_network_r) = async_channel::bounded(100); - - let (apkids, apkids_r) = AvailablePacketIds::new(options.send_maximum()); - - let max_packet_size = options.maximum_packet_size(); - - let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); - - let network = tokio::Network::new(options, to_network_r, apkids); + /// Creates the needed components to run the MQTT client using a stream that implements [`smol::io::AsyncReadExt`] and [`smol::io::AsyncWriteExt`] + /// ``` + /// let (mut network, client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream> + /// ::new_from_client_id("ExampleClient") + /// .smol_sequential_network(); + /// ``` + pub fn smol_sequential_network(self) -> (smol::Network, MqttClient) { + let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); + + let (apkids, apkids_r) = available_packet_ids::AvailablePacketIds::new(self.options.send_maximum()); + + let max_packet_size = self.options.maximum_packet_size(); + + let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); + + let network = smol::Network::::new(self.options, to_network_r, apkids); + + (network, client) + } - (network, client) } #[cfg(feature = "sync")] @@ -351,137 +363,24 @@ where } #[cfg(test)] -mod lib_test { - use std::{ - net::TcpStream, - sync::{atomic::AtomicU16, Arc}, - thread::{self}, - time::Duration, - }; - - #[cfg(feature = "concurrent_tokio")] - use crate::new_tokio; +fn random_chars() -> String { + rand::Rng::sample_iter(rand::thread_rng(), &rand::distributions::Alphanumeric).take(7).map(char::from).collect() +} - #[cfg(feature = "smol")] - use crate::new_smol; +#[cfg(feature = "smol")] +#[cfg(test)] +mod smol_lib_test { - #[cfg(feature = "sync")] - use crate::new_sync; + use std::time::Duration; use rand::Rng; - use crate::{ - packets::{self, Packet}, - AsyncEventHandler, ConnectOptions, EventHandler, MqttClient, - }; - use bytes::Bytes; - use packets::QoS; - - pub struct PingPong { - pub client: MqttClient, - pub number: AtomicU16, - } - - impl PingPong{ - pub fn new(client: MqttClient) -> Self { - Self { - client, - number: AtomicU16::new(0), - } - } - } + use crate::{ConnectOptions, example_handlers::PingPong, packets::QoS, NetworkBuilder}; - #[cfg(any(feature = "smol", feature = "tokio", feature = "concurrent_tokio"))] - impl AsyncEventHandler for PingPong { - async fn handle(&self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - let max_len = payload.len().min(10); - let a = &payload[0..max_len]; - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); - println!("Received publish payload: {}", a); - println!("DBG: \n {}", &Packet::Publish(p)); - - // println!("Received Ping, Send pong!"); - self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } - } - #[cfg(feature = "sync")] - impl EventHandler for PingPong { - fn handle(&mut self, event: Packet) { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish_blocking(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).unwrap(); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } - } - - #[cfg(feature = "sync")] - #[test] - fn test_sync_tcp() { - let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); - client_id += "_SyncTcpPingPong"; - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 1883; - - // IMPORTANT: Set nonblocking to true! Blocking on reads will happen! - let stream = TcpStream::connect((address, port)).unwrap(); - stream.set_read_timeout(Some(Duration::from_millis(500))).unwrap(); - // stream.set_nonblocking(true).unwrap(); - - let (mut network, client) = new_sync(options); - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(stream, &mut pingpong).unwrap(); - - client.subscribe_blocking("mqrstt").unwrap(); - - let res_join_handle = thread::spawn(move || - network.poll(&mut pingpong).unwrap() - ); - - client.publish_blocking("mqrstt".to_string(), QoS::ExactlyOnce, false, b"ping".repeat(500)).unwrap(); - client.publish_blocking("mqrstt".to_string(), QoS::AtMostOnce, true, b"ping".to_vec()).unwrap(); - client.publish_blocking("mqrstt".to_string(), QoS::AtLeastOnce, false, b"ping".to_vec()).unwrap(); - client.publish_blocking("mqrstt".to_string(), QoS::ExactlyOnce, false, b"ping".repeat(500)).unwrap(); - - std::thread::sleep(std::time::Duration::from_secs(20)); - client.unsubscribe_blocking("mqrstt").unwrap(); - std::thread::sleep(std::time::Duration::from_secs(5)); - client.disconnect_blocking().unwrap(); - println!("Disconnect queued"); - - let wrapped_res = res_join_handle.join(); - assert!(wrapped_res.is_ok()); - let res = dbg!(wrapped_res.unwrap()); - // assert!(res.is_ok()); - } - - #[cfg(feature = "smol")] #[test] fn test_smol_tcp() { + smol::block_on(async { let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); client_id += "_SmolTcpPingPong"; @@ -490,7 +389,7 @@ mod lib_test { let address = "broker.emqx.io"; let port = 1883; - let (mut network, client) = new_smol(options); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingpong = PingPong::new(client.clone()); @@ -519,30 +418,130 @@ mod lib_test { }); } - #[cfg(feature = "concurrent_tokio")] + #[test] + fn test_smol_ping_req() { + smol::block_on(async { + let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); + client_id += "_SmolTcppingrespTest"; + let mut options = ConnectOptions::new(client_id); + options.set_keep_alive_interval(Duration::from_secs(5)); + + let sleep_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; + + let address = "broker.emqx.io"; + let port = 1883; + + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); + + let mut pingresp = crate::example_handlers::PingResp::new(client.clone()); + + network.connect(stream, &mut pingresp).await.unwrap(); + + let (n, _) = futures::join!( + async { + match network.run(&mut pingresp).await { + Ok(crate::NetworkStatus::OutgoingDisconnect) => return Ok(pingresp), + Ok(crate::NetworkStatus::KeepAliveTimeout) => panic!(), + Ok(crate::NetworkStatus::IncomingDisconnect) => panic!(), + Err(err) => return Err(err), + } + }, + async { + smol::Timer::after(sleep_duration).await; + client.disconnect().await.unwrap(); + } + ); + assert!(n.is_ok()); + let pingresp = n.unwrap(); + assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); + }); + } + + #[cfg(all(target_family = "windows"))] + #[test] + fn test_close_write_tcp_stream_smol() { + use crate::error::ConnectionError; + use std::io::ErrorKind; + + smol::block_on(async { + let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); + client_id += "_SmolTcppingrespTest"; + let options = ConnectOptions::new(client_id); + + let address = "127.0.0.1"; + let port = 2001; + + let (n, _) = futures::join!( + async { + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); + let mut pingresp = crate::example_handlers::PingResp::new(client.clone()); + network.connect(stream, &mut pingresp).await + }, + async { + let listener = smol::net::TcpListener::bind((address, port)).await.unwrap(); + let (stream, _) = listener.accept().await.unwrap(); + smol::Timer::after(std::time::Duration::from_secs(10)).await; + stream.shutdown(std::net::Shutdown::Write).unwrap(); + } + ); + if let ConnectionError::Io(err) = n.unwrap_err() { + assert_eq!(ErrorKind::ConnectionReset, err.kind()); + assert_eq!("Connection reset by peer".to_string(), err.to_string()); + } else { + panic!(); + } + }); + } +} + +#[cfg(feature = "tokio_concurrent")] +#[cfg(test)] +mod tokio_lib_test { + use crate::example_handlers::PingPong; + + use crate::packets::QoS; + + use std::{ + sync::Arc, + time::Duration, + }; + + use crate:: + ConnectOptions + ; + + + #[cfg(feature = "tokio_concurrent")] #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn test_tokio_tcp() { - let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); - client_id += "_TokioTcpPingPong"; - let options = ConnectOptions::new(client_id); + use std::hint::black_box; + + use crate::NetworkBuilder; - let (mut network, client) = new_tokio(options); + let client_id: String = crate::random_chars() + "_TokioTcpPingPong"; - let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + let (mut network, client) = NetworkBuilder + ::new_from_client_id(client_id) + .tokio_concurrent_network(); + + let stream = tokio::net::TcpStream::connect(("azurewe1576.azureexternal.dnvgl.com", 1883)).await.unwrap(); + // let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); - let pingpong = Arc::new(PingPong::new(client.clone())); + let mut pingpong = Arc::new(PingPong::new(client.clone())); - network.connect(stream, pingpong.clone()).await.unwrap(); + network.connect(stream, &mut pingpong).await.unwrap(); client.subscribe(("mqrstt", QoS::ExactlyOnce)).await.unwrap(); - let (read, write) = network.read_write_tasks().unwrap(); + let (read, write) = network.split(pingpong.clone()).unwrap(); let read_handle = tokio::task::spawn(read.run()); let write_handle = tokio::task::spawn(write.run()); let (read_result, write_result, _) = tokio::join!( - read_handle, + read_handle, write_handle, async { client.publish("mqrstt".to_string(), QoS::ExactlyOnce, false, b"ping".repeat(500)).await.unwrap(); @@ -561,142 +560,67 @@ mod lib_test { assert!(write_result.is_ok()); assert_eq!(crate::NetworkStatus::OutgoingDisconnect, write_result.unwrap().unwrap()); assert_eq!(4, pingpong.number.load(std::sync::atomic::Ordering::SeqCst)); + let _ = black_box(read_result); } - pub struct PingResp { - pub client: MqttClient, - pub ping_resp_received: AtomicU16, - } - - impl PingResp { - pub fn new(client: MqttClient) -> Self { - Self { - client, - ping_resp_received: AtomicU16::new(0), - } - } - } - - impl AsyncEventHandler for PingResp { - async fn handle(&self, event: packets::Packet) -> () { - use Packet::*; - if event == PingResp { - self.ping_resp_received.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - } - println!("Received packet: {}", event); - } - } - - impl EventHandler for PingResp { - fn handle(&mut self, event: Packet) { - use Packet::*; - if event == PingResp { - self.ping_resp_received.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - } - println!("Received packet: {}", event); - } - } - - #[cfg(feature = "sync")] - #[test] - fn test_sync_ping_req() { - let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); - client_id += "_SyncTcppingrespTest"; - let mut options = ConnectOptions::new(client_id); - options.set_keep_alive_interval(Duration::from_secs(5)); - - let sleep_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; - - let address = "broker.emqx.io"; - let port = 1883; + - let (mut network, client) = new_sync(options); + // #[cfg(feature = "tokio_concurrent")] + // #[tokio::test] + // async fn test_tokio_ping_req() { + // let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); + // client_id += "_TokioTcppingrespTest"; + // let mut options = ConnectOptions::new(client_id); + // let keep_alive_interval = 5; + // options.set_keep_alive_interval(Duration::from_secs(keep_alive_interval)); - // IMPORTANT: Set nonblocking to true! Blocking on reads will happen! - let stream = TcpStream::connect((address, port)).unwrap(); - stream.set_nonblocking(true).unwrap(); + // let wait_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; - let mut pingresp = PingResp::new(client.clone()); + // let (mut network, client) = new_tokio(options); - network.connect(stream, &mut pingresp).unwrap(); + // let stream = tokio::net::TcpStream::connect(("azurewe1576.azureexternal.dnvgl.com", 1883)).await.unwrap(); - let res_join_handle = thread::spawn(move || loop { - match network.poll(&mut pingresp) { - Ok(crate::NetworkStatus::OutgoingDisconnect) => return Ok(pingresp), - Ok(crate::NetworkStatus::KeepAliveTimeout) => panic!(), - Ok(crate::NetworkStatus::IncomingDisconnect) => panic!(), - Err(err) => return Err(err), - } - }); + // let pingresp = Arc::new(crate::test_handlers::PingResp::new(client.clone())); - std::thread::sleep(sleep_duration); - client.disconnect_blocking().unwrap(); - let join_res = res_join_handle.join(); - assert!(join_res.is_ok()); + // network.connect(stream, &mut pingresp).await.unwrap(); - let res = join_res.unwrap(); - assert!(res.is_ok()); - let pingresp = res.unwrap(); - assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); - } - - #[cfg(feature = "concurrent_tokio")] - #[tokio::test] - async fn test_tokio_ping_req() { - let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); - client_id += "_TokioTcppingrespTest"; - let mut options = ConnectOptions::new(client_id); - let keep_alive_interval = 5; - options.set_keep_alive_interval(Duration::from_secs(keep_alive_interval)); + // let (read, write) = network.split(pingresp.clone()).unwrap(); - let wait_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; - - let (mut network, client) = new_tokio(options); - - let stream = tokio::net::TcpStream::connect(("azurewe1576.azureexternal.dnvgl.com", 1883)).await.unwrap(); - - let pingresp = Arc::new(PingResp::new(client.clone())); - - network.connect(stream, pingresp.clone()).await.unwrap(); - - let (read, write) = network.read_write_tasks().unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); + // let read_handle = tokio::task::spawn(read.run()); + // let write_handle = tokio::task::spawn(write.run()); - tokio::time::sleep(wait_duration).await; - client.disconnect().await.unwrap(); + // tokio::time::sleep(wait_duration).await; + // client.disconnect().await.unwrap(); - tokio::time::sleep(Duration::from_secs(1)).await; + // tokio::time::sleep(Duration::from_secs(1)).await; - let (read_result, write_result) = tokio::join!(read_handle, write_handle); - let (read_result, write_result) = (read_result.unwrap(), write_result.unwrap()); - assert!(write_result.is_ok()); - assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); - } + // let (read_result, write_result) = tokio::join!(read_handle, write_handle); + // let (read_result, write_result) = (read_result.unwrap(), write_result.unwrap()); + // assert!(write_result.is_ok()); + // assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); + // } #[cfg(all(feature = "tokio", target_family = "windows"))] #[tokio::test] async fn test_close_write_tcp_stream_tokio() { - use crate::error::ConnectionError; + use crate::{error::ConnectionError, NetworkBuilder}; use core::panic; use std::io::ErrorKind; let address = ("127.0.0.1", 2000); - let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); - client_id += "_TokioTcppingrespTest"; + let client_id: String = crate::random_chars() + "_TokioTcppingrespTest"; let options = ConnectOptions::new(client_id); - + let (n, _) = tokio::join!( async move { - let (mut network, client) = new_tokio(options); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); let stream = tokio::net::TcpStream::connect(address).await.unwrap(); - let pingresp = Arc::new(PingResp::new(client.clone())); + let mut pingresp = crate::example_handlers::PingResp::new(client.clone()); - network.connect(stream, pingresp).await + network.connect(stream, &mut pingresp).await }, async move { let listener = smol::net::TcpListener::bind(address).await.unwrap(); @@ -713,82 +637,4 @@ mod lib_test { panic!(); } } - - #[cfg(feature = "smol")] - #[test] - fn test_smol_ping_req() { - smol::block_on(async { - let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); - client_id += "_SmolTcppingrespTest"; - let mut options = ConnectOptions::new(client_id); - options.set_keep_alive_interval(Duration::from_secs(5)); - - let sleep_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; - - let address = "broker.emqx.io"; - let port = 1883; - - let (mut network, client) = new_smol(options); - let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); - - let mut pingresp = PingResp::new(client.clone()); - - network.connect(stream, &mut pingresp).await.unwrap(); - - let (n, _) = futures::join!( - async { - match network.run(&mut pingresp).await { - Ok(crate::NetworkStatus::OutgoingDisconnect) => return Ok(pingresp), - Ok(crate::NetworkStatus::KeepAliveTimeout) => panic!(), - Ok(crate::NetworkStatus::IncomingDisconnect) => panic!(), - Err(err) => return Err(err), - } - }, - async { - smol::Timer::after(sleep_duration).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); - let pingresp = n.unwrap(); - assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); - }); - } - - #[cfg(all(feature = "smol", target_family = "windows"))] - #[test] - fn test_close_write_tcp_stream_smol() { - use crate::error::ConnectionError; - use std::io::ErrorKind; - - smol::block_on(async { - let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); - client_id += "_SmolTcppingrespTest"; - let options = ConnectOptions::new(client_id); - - let address = "127.0.0.1"; - let port = 2001; - - let (n, _) = futures::join!( - async { - let (mut network, client) = new_smol(options); - let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); - let mut pingresp = PingResp::new(client.clone()); - network.connect(stream, &mut pingresp).await - }, - async { - let listener = smol::net::TcpListener::bind((address, port)).await.unwrap(); - let (stream, _) = listener.accept().await.unwrap(); - smol::Timer::after(std::time::Duration::from_secs(10)).await; - stream.shutdown(std::net::Shutdown::Write).unwrap(); - } - ); - if let ConnectionError::Io(err) = n.unwrap_err() { - assert_eq!(ErrorKind::ConnectionReset, err.kind()); - assert_eq!("Connection reset by peer".to_string(), err.to_string()); - } else { - panic!(); - } - }); - } } diff --git a/src/mqtt_handler.rs b/src/mqtt_handler.rs index 0dbc189..544c6fa 100644 --- a/src/mqtt_handler.rs +++ b/src/mqtt_handler.rs @@ -254,15 +254,8 @@ mod handler_tests { Packet, QoS, UnsubAck, UnsubAckProperties, {PubComp, PubCompProperties}, {PubRec, PubRecProperties}, {PubRel, PubRelProperties}, {SubAck, SubAckProperties}, }, tests::test_packets::{create_connack_packet, create_puback_packet, create_publish_packet, create_subscribe_packet, create_unsubscribe_packet}, - AsyncEventHandler, ConnectOptions, StateHandler, + ConnectOptions, StateHandler, }; - - pub struct Nop {} - - impl AsyncEventHandler for Nop { - async fn handle(&self, _event: Packet) {} - } - fn handler(clean_start: bool) -> (StateHandler, Receiver) { let (apkids, apkids_r) = AvailablePacketIds::new(100); diff --git a/src/packets/suback.rs b/src/packets/suback.rs index 5812ed4..3f8caa2 100644 --- a/src/packets/suback.rs +++ b/src/packets/suback.rs @@ -136,7 +136,6 @@ impl WireLength for SubAckProperties { #[cfg(test)] mod test { use bytes::BytesMut; - use pretty_assertions::assert_eq; use super::SubAck; use crate::packets::mqtt_traits::{VariableHeaderRead, VariableHeaderWrite}; diff --git a/src/packets/subscribe.rs b/src/packets/subscribe.rs index 596865f..80d994f 100644 --- a/src/packets/subscribe.rs +++ b/src/packets/subscribe.rs @@ -273,6 +273,15 @@ impl IntoSingleSubscription for Box { (value, SubscriptionOptions::default()) } } +impl IntoSingleSubscription for &(T, QoS) +where + T: AsRef, +{ + #[inline] + fn into((topic, qos): Self) -> (Box, SubscriptionOptions) { + (Box::from(topic.as_ref()), SubscriptionOptions { qos: *qos, ..Default::default() }) + } +} impl IntoSingleSubscription for (T, QoS) where T: AsRef, @@ -282,6 +291,15 @@ where (Box::from(topic.as_ref()), SubscriptionOptions { qos, ..Default::default() }) } } +impl IntoSingleSubscription for &(T, SubscriptionOptions) +where + T: AsRef, +{ + #[inline] + fn into((topic, sub): Self) -> (Box, SubscriptionOptions) { + (Box::from(topic.as_ref()), *sub) + } +} impl IntoSingleSubscription for (T, SubscriptionOptions) where T: AsRef, @@ -310,6 +328,11 @@ impl_subscription!(&str); impl_subscription!(&String); impl_subscription!(String); impl_subscription!(Box); +impl From<&(&str, QoS)> for Subscription { + fn from(value: &(&str, QoS)) -> Self { + Self(vec![IntoSingleSubscription::into(value)]) + } +} impl From<(T, QoS)> for Subscription where (T, QoS): IntoSingleSubscription, diff --git a/src/packets/unsubscribe.rs b/src/packets/unsubscribe.rs index 3cb73a5..8db6f5b 100644 --- a/src/packets/unsubscribe.rs +++ b/src/packets/unsubscribe.rs @@ -214,6 +214,11 @@ where Self(value.iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) } } +impl From<&[&str]> for UnsubscribeTopics{ + fn from(value: &[&str]) -> Self { + Self(value.iter().map(|val| IntoUnsubscribeTopic::into(*val)).collect()) + } +} // -------------------- Iterators -------------------- impl FromIterator for UnsubscribeTopics where diff --git a/src/smol/network.rs b/src/smol/network.rs index 5ba64d8..5bd8ab6 100644 --- a/src/smol/network.rs +++ b/src/smol/network.rs @@ -2,6 +2,7 @@ use async_channel::Receiver; use futures::FutureExt; +use std::marker::PhantomData; use std::time::{Duration, Instant}; use crate::available_packet_ids::AvailablePacketIds; @@ -11,7 +12,7 @@ use crate::packets::error::ReadBytes; use crate::packets::reason_codes::DisconnectReasonCode; use crate::packets::{Disconnect, Packet, PacketType}; use crate::NetworkStatus; -use crate::{AsyncEventHandler, StateHandler}; +use crate::{AsyncEventHandlerMut, StateHandler}; use super::stream::Stream; @@ -19,7 +20,8 @@ use super::stream::Stream; /// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. /// The most import thing to remember is that you have to provide a new stream after the previous has failed. /// (i.e. you need to reconnect after any expected or unexpected disconnect). -pub struct Network { +pub struct Network { + handler: PhantomData, network: Option>, /// Options of the current mqtt connection @@ -32,15 +34,15 @@ pub struct Network { state_handler: StateHandler, outgoing_packet_buffer: Vec, - incoming_packet_buffer: Vec, to_network_r: Receiver, } -impl Network { +impl Network { pub fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { let state_handler = StateHandler::new(&options, apkids); Self { + handler: PhantomData, network: None, keep_alive_interval: options.keep_alive_interval, @@ -52,22 +54,19 @@ impl Network { state_handler, outgoing_packet_buffer: Vec::new(), - incoming_packet_buffer: Vec::new(), to_network_r, } } } -impl Network +impl Network where + H: AsyncEventHandlerMut, S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, { /// Initializes an MQTT connection with the provided configuration an stream - pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> - where - H: AsyncEventHandler, - { + pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError>{ let (mut network, conn_ack) = Stream::connect(&self.options, stream).await?; self.last_network_action = Instant::now(); @@ -100,10 +99,7 @@ where /// /// In all other cases the network is unusable anymore. /// The stream will be dropped and the internal buffers will be cleared. - pub async fn run(&mut self, handler: &mut H) -> Result - where - H: AsyncEventHandler, - { + pub async fn run(&mut self, handler: &mut H) -> Result { if self.network.is_none() { return Err(ConnectionError::NoNetwork); } @@ -114,7 +110,6 @@ where self.network = None; self.await_pingresp = None; self.outgoing_packet_buffer.clear(); - self.incoming_packet_buffer.clear(); // This is safe as inside the Ok it is not possible to have a None due to the above Ok(None) pattern. return otherwise.map(|ok| ok.unwrap()); @@ -123,11 +118,9 @@ where } } - async fn smol_select(&mut self, handler: &mut H) -> Result, ConnectionError> - where - H: AsyncEventHandler, - { + async fn smol_select(&mut self, handler: &mut H) -> Result, ConnectionError> { let Network { + handler: _, network, options: _, keep_alive_interval, @@ -136,7 +129,6 @@ where perform_keep_alive, state_handler, outgoing_packet_buffer, - incoming_packet_buffer, to_network_r, } = self; @@ -153,52 +145,46 @@ where futures::select! { res = stream.read_bytes().fuse() => { res?; - match stream.parse_messages(incoming_packet_buffer).await { + match stream.parse_message().await { Err(ReadBytes::Err(err)) => return Err(err), Err(ReadBytes::InsufficientBytes(_)) => return Ok(None), - Ok(_) => (), - } - - for packet in incoming_packet_buffer.drain(0..){ - use Packet::*; - match packet{ - PingResp => { - handler.handle(packet).await; - *await_pingresp = None; - }, - Disconnect(_) => { - handler.handle(packet).await; - return Ok(Some(NetworkStatus::IncomingDisconnect)); - } - packet => { - match state_handler.handle_incoming_packet(&packet)? { - (maybe_reply_packet, true) => { - if let Some(reply_packet) = maybe_reply_packet { + Ok(packet) => { + match packet{ + Packet::PingResp => { + handler.handle(packet).await; + *await_pingresp = None; + }, + Packet::Disconnect(_) => { + handler.handle(packet).await; + return Ok(Some(NetworkStatus::IncomingDisconnect)); + } + packet => { + match state_handler.handle_incoming_packet(&packet)? { + (maybe_reply_packet, true) => { + handler.handle(packet).await; + if let Some(reply_packet) = maybe_reply_packet { + outgoing_packet_buffer.push(reply_packet); + } + }, + (Some(reply_packet), false) => { outgoing_packet_buffer.push(reply_packet); - } - }, - (Some(reply_packet), false) => { - outgoing_packet_buffer.push(reply_packet); - }, - (None, false) => (), + }, + (None, false) => (), + } } } - } + stream.write_all(outgoing_packet_buffer).await?; + *last_network_action = Instant::now(); + }, } - stream.write_all(outgoing_packet_buffer).await?; - *last_network_action = Instant::now(); - Ok(None) }, outgoing = to_network_r.recv().fuse() => { let packet = outgoing?; stream.write(&packet).await?; - let mut disconnect = false; - if packet.packet_type() == PacketType::Disconnect{ - disconnect = true; - } + let disconnect = packet.packet_type() == PacketType::Disconnect; state_handler.handle_outgoing_packet(packet)?; *last_network_action = Instant::now(); diff --git a/src/smol/stream.rs b/src/smol/stream.rs index 71ca70e..be8a72b 100644 --- a/src/smol/stream.rs +++ b/src/smol/stream.rs @@ -30,31 +30,22 @@ pub struct Stream { } impl Stream { - pub async fn parse_messages(&mut self, incoming_packet_buffer: &mut Vec) -> Result<(), ReadBytes> { - loop { - if self.read_buffer.is_empty() { - return Ok(()); - } - let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; + pub async fn parse_message(&mut self) -> Result> { + let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; - if header.remaining_length + header_length > self.read_buffer.len() { - return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); - } + if header.remaining_length + header_length > self.read_buffer.len() { + return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); + } - self.read_buffer.advance(header_length); + self.read_buffer.advance(header_length); - let buf = self.read_buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; + let buf = self.read_buffer.split_to(header.remaining_length); + let read_packet = Packet::read(header, buf.into())?; - #[cfg(feature = "logs")] - trace!("Read packet from network {}", read_packet); - let packet_type = read_packet.packet_type(); - incoming_packet_buffer.push(read_packet); + #[cfg(feature = "logs")] + trace!("Read packet from network {}", read_packet); - if packet_type == PacketType::Disconnect { - return Ok(()); - } - } + Ok(read_packet) } } diff --git a/src/tokio/mod.rs b/src/tokio/mod.rs index e9b22a8..0520267 100644 --- a/src/tokio/mod.rs +++ b/src/tokio/mod.rs @@ -1,5 +1,100 @@ -mod network; +mod stream; + +pub(crate) mod network; +use futures::Future; pub use network::Network; +pub use network::NetworkReader; +pub use network::NetworkWriter; -mod stream; +use crate::AsyncEventHandler; +use crate::AsyncEventHandlerMut; +use crate::error::ConnectionError; +use crate::packets::Packet; + +/// This empty struct is used to indicate the handling of messages goes via a mutable handler. +/// Only a single mutable reference can exist at once. +/// Thus this kind is not for concurrent message handling but for concurrent TCP read and write operations. +pub struct SequentialHandler; + +/// This empty struct is used to indicate a (tokio) task based handling of messages. +/// Per incoming message a task is spawned to call the handler. +/// +/// This kind of handler is used for both concurrent message handling and concurrent TCP read and write operations. +pub struct ConcurrentHandler; + + +trait HandlerExt: Sized{ + /// Should call the handler in the fashion of the handler. + /// (e.g. spawn a task if or await the handle call) + fn call_handler(handler: &mut H, incoming_packet: Packet) -> impl Future + Send; + + /// Should call the handler and await it + fn call_handler_await(handler: &mut H, incoming_packet: Packet) -> impl Future + Send; + + /// Should call the handler in the fashion of the handler. + /// (e.g. spawn a task if or await the handle call) + /// The reply (e.g. an ACK) to the original packet is only send when the handle call has completed + fn call_handler_with_reply( + network: &mut NetworkReader, + incoming_packet: Packet, + reply_packet: Option + ) -> impl Future> + Send + where + S: Send + ; + +} +impl HandlerExt for SequentialHandler { + #[inline] + fn call_handler(handler: &mut H, incoming_packet: Packet) -> impl Future + Send{ + handler.handle(incoming_packet) + } + #[inline] + fn call_handler_await(handler: &mut H, incoming_packet: Packet) -> impl Future + Send{ + handler.handle(incoming_packet) + } + fn call_handler_with_reply(network: &mut NetworkReader, incoming_packet: Packet, reply_packet: Option) -> impl Future> + Send + where + S: Send, + { + async{ + network.handler.handle(incoming_packet).await; + if let Some(reply_packet) = reply_packet { + network.to_writer_s.send(reply_packet).await?; + } + Ok(()) + } + } +} +impl HandlerExt for ConcurrentHandler { + fn call_handler(handler: &mut H, incoming_packet: Packet) -> impl Future + Send{ + let handler_clone = handler.clone(); + tokio::spawn(async move { + handler_clone.handle(incoming_packet).await; + }); + std::future::ready(()) + } + #[inline] + fn call_handler_await(handler: &mut H, incoming_packet: Packet) -> impl Future + Send{ + handler.handle(incoming_packet) + } + + fn call_handler_with_reply(network: &mut NetworkReader, incoming_packet: Packet, reply_packet: Option) -> impl Future> + Send + where + S: Send, + { + let handler_clone = network.handler.clone(); + let write_channel_clone = network.to_writer_s.clone(); + + network.join_set.spawn(async move { + handler_clone.handle(incoming_packet).await; + if let Some(reply_packet) = reply_packet { + write_channel_clone.send(reply_packet).await?; + } + Ok(()) + }); + + std::future::ready(Ok(())) + } +} \ No newline at end of file diff --git a/src/tokio/network.rs b/src/tokio/network.rs index d1bacd3..cc8d16f 100644 --- a/src/tokio/network.rs +++ b/src/tokio/network.rs @@ -1,8 +1,8 @@ use async_channel::{Receiver, Sender}; -use futures::Future; -use tokio::join; +use tokio::task::JoinSet; +use std::marker::PhantomData; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -14,8 +14,9 @@ use crate::packets::error::ReadBytes; use crate::packets::reason_codes::DisconnectReasonCode; use crate::packets::{Disconnect, Packet, PacketType}; -use crate::{AsyncEventHandler, StateHandler, NetworkStatus}; +use crate::{AsyncEventHandlerMut, StateHandler, NetworkStatus}; +use super::{SequentialHandler, HandlerExt}; use super::stream::read_half::ReadStream; use super::stream::write_half::WriteStream; use super::stream::Stream; @@ -26,32 +27,27 @@ use super::stream::Stream; /// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. /// The most import thing to remember is that you have to provide a new stream after the previous has failed. /// (i.e. you need to reconnect after any expected or unexpected disconnect). -pub struct Network { - handler: Option, - +pub struct Network { + handler_helper: PhantomData, + handler: PhantomData, network: Option>, /// Options of the current mqtt connection options: ConnectOptions, - last_network_action: Instant, - - await_pingresp_atomic: Arc, perform_keep_alive: bool, - state_handler: Arc, - to_network_r: Receiver, } -impl Network { +impl Network { pub fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { Self { - handler: None, + handler_helper: PhantomData, + handler: PhantomData, network: None, last_network_action: Instant::now(), - await_pingresp_atomic: Arc::new(AtomicBool::new(false)), perform_keep_alive: true, state_handler: Arc::new(StateHandler::new(&options, apkids)), @@ -64,13 +60,13 @@ impl Network { } /// Tokio impl -impl Network +impl Network where - H: AsyncEventHandler + Clone + Send + Sync + 'static, + N: HandlerExt, S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, { /// Initializes an MQTT connection with the provided configuration an stream - pub async fn connect(&mut self, stream: S, handler: H) -> Result<(), ConnectionError> { + pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { let (mut network, conn_ack) = Stream::connect(&self.options, stream).await?; self.last_network_action = Instant::now(); @@ -82,23 +78,148 @@ where } let packets = self.state_handler.handle_incoming_connack(&conn_ack)?; - handler.handle(Packet::ConnAck(conn_ack)).await; + N::call_handler_await(handler, Packet::ConnAck(conn_ack)).await; if let Some(mut packets) = packets { network.write_all(&mut packets).await?; self.last_network_action = Instant::now(); } self.network = Some(network); - self.handler = Some(handler); Ok(()) } +} + +impl Network +where + H: AsyncEventHandlerMut, + SequentialHandler: HandlerExt, + S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, +{ + + /// A single call to run will perform one of three tasks: + /// - Read from the stream and parse the bytes to packets for the user to handle + /// - Write user packets to stream + /// - Perform keepalive if necessary + /// + /// In all other cases the network is unusable anymore. + /// The stream will be dropped and the internal buffers will be cleared. + pub async fn run(&mut self, handler: &mut H) -> Result { + if self.network.is_none() { + return Err(ConnectionError::NoNetwork); + } + + + match self.tokio_select(handler).await { + otherwise => { + self.network = None; + + otherwise + } + } + } + + async fn tokio_select(&mut self, handler: &mut H) -> Result { + let Network { + network, + options, + last_network_action, + perform_keep_alive, + to_network_r, + handler_helper: _, + handler: _, + state_handler, + } = self; + + let mut await_pingresp = None; + let mut outgoing_packet_buffer = Vec::new(); + + loop { + let sleep; + if let Some(instant) = await_pingresp { + sleep = instant + options.get_keep_alive_interval() - Instant::now(); + } else { + sleep = *last_network_action + options.get_keep_alive_interval() - Instant::now(); + } + + if let Some(stream) = network { + tokio::select! { + res = stream.read_bytes() => { + res?; + loop{ + let packet = match stream.parse_message().await { + Err(ReadBytes::Err(err)) => return Err(err), + Err(ReadBytes::InsufficientBytes(_)) => break, + Ok(packet) => packet, + }; + match packet{ + Packet::PingResp => { + SequentialHandler::call_handler_await(handler, packet).await; + await_pingresp = None; + }, + Packet::Disconnect(_) => { + SequentialHandler::call_handler_await(handler, packet).await; + return Ok(NetworkStatus::IncomingDisconnect); + } + packet => { + match state_handler.handle_incoming_packet(&packet)? { + (maybe_reply_packet, true) => { + SequentialHandler::call_handler_await(handler, packet).await; + if let Some(reply_packet) = maybe_reply_packet { + outgoing_packet_buffer.push(reply_packet); + } + }, + (Some(reply_packet), false) => { + outgoing_packet_buffer.push(reply_packet); + }, + (None, false) => (), + } + } + } + stream.write_all(&mut outgoing_packet_buffer).await?; + *last_network_action = Instant::now(); + } + }, + outgoing = to_network_r.recv() => { + let packet = outgoing?; + stream.write(&packet).await?; + let disconnect = packet.packet_type() == PacketType::Disconnect; + state_handler.handle_outgoing_packet(packet)?; + *last_network_action = Instant::now(); - #[cfg(feature = "concurrent_tokio")] + + if disconnect{ + return Ok(NetworkStatus::OutgoingDisconnect); + } + }, + _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { + let packet = Packet::PingReq; + stream.write(&packet).await?; + *last_network_action = Instant::now(); + await_pingresp = Some(Instant::now()); + }, + _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { + let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; + stream.write(&Packet::Disconnect(disconnect)).await?; + return Ok(NetworkStatus::KeepAliveTimeout); + } + } + } else { + return Err(ConnectionError::NoNetwork); + } + } + } +} + + +impl Network +where + S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, +{ /// Creates both read and write tasks to run this them in parallel. /// If you want to run concurrently (not parallel) the [`Self::run`] method is a better aproach! - pub fn read_write_tasks(&mut self) -> Result<(NetworkReader, NetworkWriter), ConnectionError> { + pub fn split(&mut self, handler: H) -> Result<(NetworkReader, NetworkWriter), ConnectionError> { if self.network.is_none() { return Err(ConnectionError::NoNetwork)?; } @@ -108,14 +229,17 @@ where let (read_stream, write_stream) = network.split(); let run_signal = Arc::new(AtomicBool::new(true)); let (to_writer_s, to_writer_r) = async_channel::bounded(100); + let await_pingresp_atomic = Arc::new(AtomicBool::new(false)); let read_network = NetworkReader { run_signal: run_signal.clone(), - handler: self.handler.as_ref().unwrap().clone(), + handler_helper: PhantomData, + handler: handler, read_stream, - await_pingresp_atomic: self.await_pingresp_atomic.clone(), + await_pingresp_atomic: await_pingresp_atomic.clone(), state_handler: self.state_handler.clone(), to_writer_s, + join_set: JoinSet::new(), }; let write_network = NetworkWriter { @@ -123,7 +247,7 @@ where write_stream, keep_alive_interval: self.options.keep_alive_interval, last_network_action: self.last_network_action, - await_pingresp_bool: self.await_pingresp_atomic.clone(), + await_pingresp_bool: await_pingresp_atomic.clone(), await_pingresp_time: None, perform_keep_alive: self.perform_keep_alive, state_handler: self.state_handler.clone(), @@ -136,50 +260,41 @@ where None => Err(ConnectionError::NoNetwork), } } - - /// A single call to run will perform one of three tasks: - /// - Read from the stream and parse the bytes to packets for the user to handle - /// - Write user packets to stream - /// - Perform keepalive if necessary - /// - /// This function can produce an indication of the state of the network or an error. - /// When the network is still active (i.e. stream is not closed and no disconnect packet has been processed) the network will return [`NetworkStatus::Active`] - /// - /// In all other cases the network is unusable anymore. - /// The stream will be dropped and the internal buffers will be cleared. - pub fn run() { - - } - } -#[cfg(feature = "concurrent_tokio")] -pub struct NetworkReader { - run_signal: Arc, - - handler: H, - read_stream: ReadStream, - await_pingresp_atomic: Arc, - state_handler: Arc, - to_writer_s: Sender, +#[cfg(feature = "tokio_concurrent")] +pub struct NetworkReader { + pub(crate) run_signal: Arc, + + pub(crate) handler_helper: PhantomData, + pub handler: H, + + pub(crate) read_stream: ReadStream, + pub(crate) await_pingresp_atomic: Arc, + pub(crate) state_handler: Arc, + pub(crate) to_writer_s: Sender, + pub(crate) join_set: JoinSet>, } -#[cfg(feature = "concurrent_tokio")] -impl NetworkReader +#[cfg(feature = "tokio_concurrent")] +impl NetworkReader where - H: AsyncEventHandler + Clone + Send + Sync + 'static, + N: HandlerExt, S: tokio::io::AsyncReadExt + Sized + Unpin + Send + 'static, { - /// Runs the read half of the concurrent read & write tokio client. + /// Runs the read half with a [`AsyncEventHandlerMut`]. /// Continuously loops until disconnect or error. /// /// # Return - /// - Ok(None) in the case that the write task requested shutdown. - /// - Ok(Some(reason)) in the case that this task initiates a shutdown. - /// - Err in the case of IO, or protocol errors. + /// - Ok(None) in the case that the write task requested shutdown. + /// - Ok(Some(reason)) in the case that this task initiates a shutdown. + /// - Err in the case of IO, or protocol errors. pub async fn run(mut self) -> Result, ConnectionError> { let ret = self.read().await; self.run_signal.store(false, std::sync::atomic::Ordering::Release); + while let Some(_) = self.join_set.join_next().await { + () + } ret } async fn read(&mut self) -> Result, ConnectionError> { @@ -196,7 +311,7 @@ where match packet { Packet::PingResp => { - self.handler.handle(packet).await; + N::call_handler(&mut self.handler, packet).await; #[cfg(feature = "logs")] if !self.await_pingresp_atomic.fetch_and(false, std::sync::atomic::Ordering::SeqCst) { tracing::warn!("Received PingResp but did not expect it"); @@ -205,7 +320,7 @@ where self.await_pingresp_atomic.store(false, std::sync::atomic::Ordering::SeqCst); } Packet::Disconnect(_) => { - self.handler.handle(packet).await; + N::call_handler(&mut self.handler, packet).await; return Ok(Some(NetworkStatus::IncomingDisconnect)); } Packet::ConnAck(conn_ack) => { @@ -214,14 +329,11 @@ where self.to_writer_s.send(packet).await?; } } - self.handler.handle(Packet::ConnAck(conn_ack)).await; + N::call_handler(&mut self.handler, Packet::ConnAck(conn_ack)).await; } packet => match self.state_handler.handle_incoming_packet(&packet)? { (maybe_reply_packet, true) => { - self.handler.handle(packet).await; - if let Some(reply_packet) = maybe_reply_packet { - let _ = self.to_writer_s.send(reply_packet).await?; - } + N::call_handler_with_reply(self, packet, maybe_reply_packet).await?; } (Some(reply_packet), false) => { self.to_writer_s.send(reply_packet).await?; @@ -235,7 +347,7 @@ where } } -#[cfg(feature = "concurrent_tokio")] +#[cfg(feature = "tokio_concurrent")] pub struct NetworkWriter { run_signal: Arc, @@ -254,18 +366,18 @@ pub struct NetworkWriter { to_network_r: Receiver, } -#[cfg(feature = "concurrent_tokio")] +#[cfg(feature = "tokio_concurrent")] impl NetworkWriter where S: tokio::io::AsyncWriteExt + Sized + Unpin, -{ +{ /// Runs the write half of the concurrent read & write tokio client /// Continuously loops until disconnect or error. /// /// # Return - /// - Ok(None) in the case that the read task requested shutdown - /// - Ok(Some(reason)) in the case that this task initiates a shutdown - /// - Err in the case of IO, or protocol errors. + /// - Ok(None) in the case that the read task requested shutdown + /// - Ok(Some(reason)) in the case that this task initiates a shutdown + /// - Err in the case of IO, or protocol errors. pub async fn run(mut self) -> Result, ConnectionError> { let ret = self.write().await; self.run_signal.store(false, std::sync::atomic::Ordering::Release); @@ -283,13 +395,13 @@ where } else { sleep = self.last_network_action + self.keep_alive_interval - Instant::now(); } - + ; tokio::select! { outgoing = self.to_network_r.recv() => { let packet = outgoing?; self.write_stream.write(&packet).await?; - let disconnect = if packet.packet_type() == PacketType::Disconnect { true } else { false }; + let disconnect = packet.packet_type() == PacketType::Disconnect; self.state_handler.handle_outgoing_packet(packet)?; self.last_network_action = Instant::now(); diff --git a/src/tokio/stream/read_half.rs b/src/tokio/stream/read_half.rs index 01843c6..00b862f 100644 --- a/src/tokio/stream/read_half.rs +++ b/src/tokio/stream/read_half.rs @@ -45,29 +45,6 @@ where Ok(read_packet) } - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.read_buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_required_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(io::Error::new(io::ErrorKind::InvalidData, err)), - }; - - if header_length + header.remaining_length > self.read_buffer.len() { - self.read_required_bytes(header.remaining_length - self.read_buffer.len()).await?; - } - - self.read_buffer.advance(header_length); - - let buf = self.read_buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)); - } - } - pub async fn read_bytes(&mut self) -> io::Result { let read = self.stream.read(&mut self.const_buffer).await?; if read == 0 { @@ -77,17 +54,4 @@ where Ok(read) } } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - - loop { - let read = self.read_bytes().await?; - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } } diff --git a/src/tokio/stream/write_half.rs b/src/tokio/stream/write_half.rs index 58b784b..99cbcca 100644 --- a/src/tokio/stream/write_half.rs +++ b/src/tokio/stream/write_half.rs @@ -32,27 +32,4 @@ where self.write_buffer.clear(); Ok(()) } - - pub async fn write_all(&mut self, packets: &mut I) -> Result<(), ConnectionError> - where - I: Iterator, - { - let writes = packets.map(|packet| { - packet.write(&mut self.write_buffer)?; - - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); - - Ok::<(), ConnectionError>(()) - }); - - for write in writes { - write?; - } - - self.stream.write_all(&self.write_buffer[..]).await?; - self.stream.flush().await?; - self.write_buffer.clear(); - Ok(()) - } }