From 4f0f5010dd6bfe1e9df915b9c8de6f4004e92a3d Mon Sep 17 00:00:00 2001 From: Gunnar <13799935+GunnarMorrigan@users.noreply.github.com> Date: Fri, 3 Mar 2023 00:50:01 +0100 Subject: [PATCH] Retransmission and rework crate (#6) * Work to combine network and handler Moving to a combined network and handler because it is required for retransmit * Should not work but local cargo does not detect? * Very large rework to make retransmission possible Add retranmission and tests * Small lib change * Adjusted documentation * Adjusted documentation * Bumped version number to 1.5 * Bumped tokio version due to tokio issue readhalf unsplit * Handle connack packet directly * Cargo clippy and fmt * Fix issue in state with connack Vec would be cleared but would depend on items beting Some or None instead. * Cargo fmt and clippy * Adjusted code coverage generation * Make tests and code coverage sequential --- .github/workflows/rust.yml | 41 +- Cargo.lock | 241 +++++---- Cargo.toml | 10 +- README.md | 174 ++++-- rustfmt.toml | 7 +- src/available_packet_ids.rs | 22 +- src/client.rs | 206 ++------ src/connect_options.rs | 2 +- src/connections/mod.rs | 8 +- src/connections/smol.rs | 172 ++++++ src/connections/smol_stream.rs | 172 ------ src/connections/tokio.rs | 172 ++++++ src/connections/tokio_stream.rs | 170 ------ src/error.rs | 69 +-- src/lib.rs | 540 +++++++++---------- src/mqtt_handler.rs | 912 ++++++++++++-------------------- src/network/mod.rs | 2 + src/network/smol.rs | 194 +++++++ src/network/tokio.rs | 195 +++++++ src/packets/auth.rs | 22 +- src/packets/connack.rs | 120 ++--- src/packets/connect.rs | 191 ++----- src/packets/disconnect.rs | 49 +- src/packets/mod.rs | 128 ++--- src/packets/puback.rs | 60 +-- src/packets/pubcomp.rs | 60 +-- src/packets/publish.rs | 55 +- src/packets/pubrec.rs | 60 +-- src/packets/pubrel.rs | 54 +- src/packets/reason_codes.rs | 58 +- src/packets/suback.rs | 21 +- src/packets/subscribe.rs | 63 +-- src/packets/unsuback.rs | 37 +- src/packets/unsubscribe.rs | 40 +- src/smol_network.rs | 162 ------ src/state.rs | 162 +++++- src/tests/connection_tests.rs | 99 ---- src/tests/mod.rs | 1 - src/tests/test_bytes.rs | 65 +-- src/tests/test_packets.rs | 22 +- src/tests/tls.rs | 28 +- src/tokio_network.rs | 153 ------ src/util/mod.rs | 2 +- 43 files changed, 2171 insertions(+), 2850 deletions(-) create mode 100644 src/connections/smol.rs delete mode 100644 src/connections/smol_stream.rs create mode 100644 src/connections/tokio.rs delete mode 100644 src/connections/tokio_stream.rs create mode 100644 src/network/mod.rs create mode 100644 src/network/smol.rs create mode 100644 src/network/tokio.rs delete mode 100644 src/smol_network.rs delete mode 100644 src/tests/connection_tests.rs delete mode 100644 src/tokio_network.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 3c84797..89be431 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -27,7 +27,7 @@ jobs: # run clippy to verify we have no warnings - run: cargo fetch - name: cargo clippy - run: cargo clippy --all-targets --all-features -- -D warnings + run: cargo clippy --all-targets -- -D warnings test: name: Test @@ -47,20 +47,35 @@ jobs: - uses: actions/checkout@v3 - uses: EmbarkStudios/cargo-deny-action@v1 - codecov: - name: Generate code coverage + coverage: + name: Coverage runs-on: ubuntu-latest needs: test + permissions: + issues: write + steps: - - uses: actions/checkout@v2 - - uses: hecrj/setup-rust-action@v1 - # TODO: we don't use caching here because it's unclear if it will cause - # the coverage to get less accurate (this is the case for some coverage - # tools, although possibly not tarpaulin?) - - name: Run cargo-tarpaulin - uses: actions-rs/tarpaulin@v0.1 + - name: Checkout + uses: actions/checkout@v3 with: - args: '--all-features' + persist-credentials: false + + - name: Install toolchain + uses: dtolnay/rust-toolchain@nightly - - name: Upload to codecov.io - uses: codecov/codecov-action@v1 + - name: Install cargo-llvm-cov + run: | + curl -LsSf https://github.com/taiki-e/cargo-llvm-cov/releases/latest/download/cargo-llvm-cov-x86_64-unknown-linux-gnu.tar.gz \ + | tar xzf - -C ~/.cargo/bin + - name: Generate coverage report + run: | + cargo llvm-cov clean --workspace + cargo llvm-cov test -p mqrstt --no-report --all-features -- --test-threads=1 + cargo llvm-cov report --lcov > lcov.txt + env: + RUSTFLAGS: --cfg __ui_tests + + - name: Upload coverage report + uses: codecov/codecov-action@v3 + with: + files: ./lcov.txt diff --git a/Cargo.lock b/Cargo.lock index d92b294..09881c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,7 +56,7 @@ dependencies = [ "slab", "socket2", "waker-fn", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -105,7 +105,7 @@ dependencies = [ "futures-lite", "libc", "signal-hook", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -127,9 +127,9 @@ checksum = "7a40729d2133846d9ed0ea60a8b9541bccddab49cd30f0715a1da672fe9a2524" [[package]] name = "async-trait" -version = "0.1.61" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "705339e0e4a9690e2908d2b3d049d85682cf19fbd5782494498fbf7003a6a282" +checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2" dependencies = [ "proc-macro2", "quote", @@ -138,9 +138,9 @@ dependencies = [ [[package]] name = "atomic-waker" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "065374052e7df7ee4047b1160cca5e1467a12351a40b3da123c870ba0b8eda2a" +checksum = "debc29dde2e69f9e47506b525f639ed42300fc014a3e007832592448fa8e4599" [[package]] name = "autocfg" @@ -150,9 +150,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.13.1" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" [[package]] name = "blocking" @@ -170,21 +170,21 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.11.1" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" [[package]] name = "bytes" -version = "1.2.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" [[package]] name = "cc" -version = "1.0.77" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9f73505338f7d905b19d18738976aae232eb46b8efc15554ffc56deb5d9ebe4" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" [[package]] name = "cfg-if" @@ -194,9 +194,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "concurrent-queue" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd7bef69dc86e3c610e4e7aed41035e2a7ed12e72dd7530f61327a6579a4390b" +checksum = "c278839b831783b70278b14df4d45e1beb1aad306c07bb796637de9a0e323e8e" dependencies = [ "crossbeam-utils", ] @@ -234,18 +234,18 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "fastrand" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a407cfaa3385c4ae6b23e84623d48c2798d06e3e6a1878f7f59f17b3f86499" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" dependencies = [ "instant", ] [[package]] name = "futures" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0" +checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84" dependencies = [ "futures-channel", "futures-core", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed" +checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" dependencies = [ "futures-core", "futures-sink", @@ -268,15 +268,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac" +checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" [[package]] name = "futures-executor" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2" +checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e" dependencies = [ "futures-core", "futures-task", @@ -285,9 +285,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" +checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" [[package]] name = "futures-lite" @@ -306,9 +306,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d" +checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" dependencies = [ "proc-macro2", "quote", @@ -317,15 +317,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9" +checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" [[package]] name = "futures-task" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea" +checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" [[package]] name = "futures-timer" @@ -335,9 +335,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6" +checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" dependencies = [ "futures-channel", "futures-core", @@ -353,9 +353,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.1.19" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" dependencies = [ "libc", ] @@ -371,9 +371,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49409df3e3bf0856b916e2ceaca09ee28e6871cf7d9ce97a692cacfdb2a25a47" +checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" dependencies = [ "wasm-bindgen", ] @@ -386,9 +386,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.137" +version = "0.2.139" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" +checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" [[package]] name = "log" @@ -416,19 +416,19 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "mio" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" +checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" dependencies = [ "libc", "log", "wasi", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] name = "mqrstt" -version = "0.1.4" +version = "0.1.5" dependencies = [ "async-channel", "async-mutex", @@ -461,9 +461,9 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6058e64324c71e02bc2b150e4f3bc8286db6c83092132ffa3f6b1eab0f9def5" +checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" dependencies = [ "hermit-abi", "libc", @@ -471,9 +471,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.16.0" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" +checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" [[package]] name = "output_vt100" @@ -510,16 +510,16 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "polling" -version = "2.5.1" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "166ca89eb77fd403230b9c156612965a81e094ec6ec3aa13663d4c8b113fa748" +checksum = "22122d5ec4f9fe1b3916419b76be1e80bcb93f618d071d2edf841b137b2a2bd6" dependencies = [ "autocfg", "cfg-if", "libc", "log", "wepoll-ffi", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -536,27 +536,27 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.47" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.21" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" dependencies = [ "proc-macro2", ] [[package]] name = "regex" -version = "1.7.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a" +checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" dependencies = [ "regex-syntax", ] @@ -628,9 +628,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.20.7" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c" +checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" dependencies = [ "log", "ring", @@ -640,9 +640,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0864aeff53f8c05aa08d86e5ef839d3dfcf07aeba2db32f12db0ef716e87bd55" +checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ "base64", ] @@ -674,9 +674,9 @@ dependencies = [ [[package]] name = "signal-hook" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a253b5e89e2698464fc26b545c9edceb338e18a89effeeecfea192c3025be29d" +checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9" dependencies = [ "libc", "signal-hook-registry", @@ -684,9 +684,9 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" dependencies = [ "libc", ] @@ -741,9 +741,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "syn" -version = "1.0.103" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d" +checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" dependencies = [ "proc-macro2", "quote", @@ -752,18 +752,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" +checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" +checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" dependencies = [ "proc-macro2", "quote", @@ -772,18 +772,19 @@ dependencies = [ [[package]] name = "thread_local" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" dependencies = [ + "cfg-if", "once_cell", ] [[package]] name = "tokio" -version = "1.24.1" +version = "1.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d9f76183f91ecfb55e1d7d5602bd1d979e38a3a522fe900241cf195624d67ae" +checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64" dependencies = [ "autocfg", "bytes", @@ -794,14 +795,14 @@ dependencies = [ "pin-project-lite", "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] name = "tokio-macros" -version = "1.8.0" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9724f9a975fb987ef7a3cd9be0350edcbe130698af5b8f7a631e23d42d052484" +checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" dependencies = [ "proc-macro2", "quote", @@ -883,9 +884,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3" +checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" [[package]] name = "untrusted" @@ -913,9 +914,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268" +checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -923,9 +924,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8ffb332579b0557b52d268b91feab8df3615f265d5270fec2a8c95b17c1142" +checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" dependencies = [ "bumpalo", "log", @@ -938,9 +939,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "052be0f94026e6cbc75cdefc9bae13fd6052cdcaf532fa6c45e7ae33a1e6c810" +checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -948,9 +949,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" +checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", @@ -961,15 +962,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" +checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" [[package]] name = "web-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcda906d8be16e728fd5adc5b729afad4e444e106ab28cd1c7256e54fa61510f" +checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97" dependencies = [ "js-sys", "wasm-bindgen", @@ -1031,47 +1032,71 @@ dependencies = [ "windows_x86_64_msvc", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" +checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" [[package]] name = "windows_aarch64_msvc" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" +checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" [[package]] name = "windows_i686_gnu" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" +checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" [[package]] name = "windows_i686_msvc" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" +checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" [[package]] name = "windows_x86_64_gnu" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" +checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" +checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" [[package]] name = "windows_x86_64_msvc" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" +checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" [[package]] name = "yansi" diff --git a/Cargo.toml b/Cargo.toml index d5c739a..b28bf50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mqrstt" -version = "0.1.4" +version = "0.1.5" homepage = "https://github.com/GunnarMorrigan/mqrstt" repository = "https://github.com/GunnarMorrigan/mqrstt" documentation = "https://docs.rs/mqrstt" @@ -14,7 +14,7 @@ description = "Pure rust MQTTv5 client implementation for Smol, Tokio and soon s # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["smol", "tokio"] +default = ["tokio", "smol"] tokio = ["dep:tokio"] smol = ["dep:smol"] # quic = ["dep:quinn"] @@ -36,17 +36,19 @@ async-trait = "0.1.61" # quinn = {version = "0.9.0", optional = true } # tokio feature flag -tokio = { version = "1.24.1", features = ["macros", "io-util", "net", "time"], optional = true } +tokio = { version = "1.26.0", features = ["macros", "io-util", "net", "time"], optional = true } # smol feature flag smol = { version = "1.3.0", optional = true } [dev-dependencies] pretty_assertions = "1.3.0" -tokio = { version = "1.24.1", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] } +tokio = { version = "1.26.0", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] } + smol = { version = "1.3.0" } tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } + rustls = { version = "0.20.7" } rustls-pemfile = { version = "1.0.1" } webpki = { version = "0.22.0" } diff --git a/README.md b/README.md index 4037c90..9dad40c 100644 --- a/README.md +++ b/README.md @@ -9,30 +9,49 @@ `mqrstt` is an MQTTv5 client implementation that allows for the smol and tokio runtimes. In the future we will also support a sync implementation. +Because this crate aims to be runtime agnostic the user is required to provide their own data stream. +The stream has to implement the smol or tokio [`AsyncReadExt`] and [`AsyncWriteExt`] traits. + -## Examples +## Features + - MQTT v5 + - Retransmission + - Runtime agnostic + - Lean + - Keep alive depends on actual communication + + ### To do + - Enforce size of outbound messages (e.g. Publish) + - Sync API + - More testing + - More documentation + - Remove logging calls or move all to test flag -You want to reconnect (with a new stream) after the network encountered an error or a disconnect took place! +## Examples + ### Notes: + - Your handler should not wait too long + - Create a new connection when an error or disconnect is encountered + - Handlers only get incoming packets ### Smol example: ```rust use mqrstt::{ - AsyncClient, + MqttClient, ConnectOptions, new_smol, packets::{self, Packet}, - AsyncEventHandlerMut, HandlerStatus, NetworkStatus, + AsyncEventHandler, NetworkStatus, }; use async_trait::async_trait; use bytes::Bytes; pub struct PingPong { - pub client: AsyncClient, + pub client: MqttClient, } #[async_trait] -impl AsyncEventHandlerMut for PingPong { +impl AsyncEventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: &packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet) -> () { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { @@ -56,89 +75,129 @@ impl AsyncEventHandlerMut for PingPong { } } smol::block_on(async { - let options = ConnectOptions::new("mqrsttExample".to_string()); - let (mut network, mut handler, client) = new_smol(options); + let options = ConnectOptions::new("mqrsttSmolExample".to_string()); + let (mut network, client) = new_smol(options); let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)) .await .unwrap(); network.connect(stream).await.unwrap(); + + // This subscribe is only processed when we run the network client.subscribe("mqrstt").await.unwrap(); + let mut pingpong = PingPong { client: client.clone(), }; - let (n, h, t) = futures::join!( + let (n, t) = futures::join!( async { loop { - return match network.run().await { + return match network.poll(&mut pingpong).await { Ok(NetworkStatus::Active) => continue, otherwise => otherwise, }; } }, async { - loop { - return match handler.handle_mut(&mut pingpong).await { - Ok(HandlerStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - smol::Timer::after(std::time::Duration::from_secs(60)).await; + smol::Timer::after(std::time::Duration::from_secs(30)).await; client.disconnect().await.unwrap(); } ); assert!(n.is_ok()); - assert!(h.is_ok()); }); ``` - ### Tokio example: ```rust -let options = ConnectOptions::new("TokioTcpPingPong".to_string()); - -let (mut network, mut handler, client) = new_tokio(options); -let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)) - .await - .unwrap(); +use mqrstt::{ + MqttClient, + ConnectOptions, + new_tokio, + packets::{self, Packet}, + AsyncEventHandler, NetworkStatus, +}; +use tokio::time::Duration; +use async_trait::async_trait; +use bytes::Bytes; -network.connect(stream).await.unwrap(); +pub struct PingPong { + pub client: MqttClient, +} +#[async_trait] +impl AsyncEventHandler for PingPong { + // Handlers only get INCOMING packets. This can change later. + async fn handle(&mut self, event: packets::Packet) -> () { + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + if payload.to_lowercase().contains("ping") { + self.client + .publish( + p.qos, + p.retain, + p.topic.clone(), + Bytes::from_static(b"pong"), + ) + .await + .unwrap(); + println!("Received Ping, Send pong!"); + } + } + }, + Packet::ConnAck(_) => { println!("Connected!") }, + _ => (), + } + } +} -client.subscribe("mqrstt").await.unwrap(); +#[tokio::main] +async fn main() { + let options = ConnectOptions::new("TokioTcpPingPongExample".to_string()); + + let (mut network, client) = new_tokio(options); + + let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)) + .await + .unwrap(); + + network.connect(stream).await.unwrap(); + + client.subscribe("mqrstt").await.unwrap(); + + let mut pingpong = PingPong { + client: client.clone(), + }; -let mut pingpong = PingPong { - client: client.clone(), -}; + let (n, _) = tokio::join!( + async { + loop { + return match network.poll(&mut pingpong).await { + Ok(NetworkStatus::Active) => continue, + otherwise => otherwise, + }; + } + }, + async { + tokio::time::sleep(Duration::from_secs(30)).await; + client.disconnect().await.unwrap(); + } + ); +} -let (n, h, _) = tokio::join!( - async { - loop { - return match network.run().await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - loop { - return match handler.handle_mut(&mut pingpong).await { - Ok(HandlerStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - tokio::time::sleep(Duration::from_secs(60)).await; - client.disconnect().await.unwrap(); - } -); ``` -## Important notes: - - Handlers only get incoming packets. +## FAQ + - Why are there no implementations for TLS connections? + Many examples of creating TLS streams in rust exist with the crates [`async-rustls`](https://crates.io/crates/async-rustls) and [`tokio-rustls`](https://crates.io/crates/tokio-rustls). The focus of this crate is `MQTTv5` and providing a runtime free choice. + +- What are the advantages over [`rumqttc`](https://crates.io/crates/rumqttc)? + - Handling of messages by user before acknowledgement. + - Ping req depending on communication + - No `rumqttc` packet id collision errors (It is not possible with `rumqtts`). + - Runtime agnositc + - Mqtt version 5 support + - Please ask :) ## Size With the smol runtime you can create very small binaries. A simple PingPong smol TCP client can be had for 550\~KB and with TLS you are looking at 1.5\~ MB using the following flags. This makes `mqrstt` extremely usefull for embedded devices! :) @@ -151,7 +210,6 @@ strip = true ``` ## License - Licensed under * Mozilla Public License, Version 2.0, [(MPL-2.0)](https://choosealicense.com/licenses/mpl-2.0/) diff --git a/rustfmt.toml b/rustfmt.toml index e672bfc..c9ce889 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,3 +1,4 @@ -unstable_features = true -brace_style = "PreferSameLine" -control_brace_style = "ClosingNextLine" \ No newline at end of file +# unstable_features = true +# brace_style = "PreferSameLine" +# control_brace_style = "ClosingNextLine" +max_width = 200 \ No newline at end of file diff --git a/src/available_packet_ids.rs b/src/available_packet_ids.rs index a6c2d80..a6ad750 100644 --- a/src/available_packet_ids.rs +++ b/src/available_packet_ids.rs @@ -1,7 +1,7 @@ -use async_channel::{Receiver, Sender}; +use async_channel::{Receiver, Sender, TrySendError}; use tracing::error; -use crate::error::MqttError; +use crate::error::HandlerError; #[derive(Debug, Clone)] pub struct AvailablePacketIds { @@ -36,18 +36,20 @@ impl AvailablePacketIds { // } // } - pub async fn mark_available(&self, pkid: u16) -> Result<(), MqttError> { - match self.sender.send(pkid).await { + pub fn mark_available(&self, pkid: u16) -> Result<(), HandlerError> { + match self.sender.try_send(pkid) { Ok(_) => { Ok(()) // debug!("Marked packet id as available: {}", pkid); } - Err(err) => { - error!( - "Encountered an error while marking an packet id as available. Error: {}", - err - ); - Err(MqttError::PacketIdError(err.0)) + Err(TrySendError::Closed(pkid)) => { + error!("Packet Id channel was closed"); + Err(HandlerError::PacketIdError(pkid)) + } + Err(TrySendError::Full(_)) => { + // There can never be more than the predetermined number of packet ids. + // Meaning that they then all fit in the channel + unreachable!() } } } diff --git a/src/client.rs b/src/client.rs index c6eb42f..5a6dbaf 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,92 +6,47 @@ use crate::{ error::ClientError, packets::{ reason_codes::DisconnectReasonCode, - Packet, QoS, {Disconnect, DisconnectProperties}, {Publish, PublishProperties}, - {Subscribe, SubscribeProperties, Subscription}, - {Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, + Packet, QoS, {Disconnect, DisconnectProperties}, {Publish, PublishProperties}, {Subscribe, SubscribeProperties, Subscription}, {Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, }, }; #[derive(Debug, Clone)] -pub struct AsyncClient { - // Provides this client with an available packet id or waits on it. +pub struct MqttClient { + /// Provides this client with an available packet id or waits on it. available_packet_ids: Receiver, - // Sends Publish, Subscribe, Unsubscribe to the event handler to handle later. - client_to_handler_s: Sender, - - // Sends Publish with QoS 0 + /// Sends Publish, Subscribe, Unsubscribe to the event handler to handle later. to_network_s: Sender, } -impl AsyncClient { - pub fn new( - available_packet_ids: Receiver, - client_to_handler_s: Sender, - to_network_s: Sender, - ) -> Self { - Self { - available_packet_ids, - client_to_handler_s, - to_network_s, - } +impl MqttClient { + pub fn new(available_packet_ids: Receiver, to_network_s: Sender) -> Self { + Self { available_packet_ids, to_network_s } } - pub async fn subscribe>( - &self, - into_subscribtions: A, - ) -> Result<(), ClientError> { - let pkid = self - .available_packet_ids - .recv() - .await - .map_err(|_| ClientError::NoHandler)?; + pub async fn subscribe>(&self, into_subscribtions: A) -> Result<(), ClientError> { + let pkid = self.available_packet_ids.recv().await.map_err(|_| ClientError::NoNetwork)?; let subscription: Subscription = into_subscribtions.into(); let sub = Subscribe::new(pkid, subscription.0); - self.client_to_handler_s - .send(Packet::Subscribe(sub)) - .await - .map_err(|_| ClientError::NoHandler)?; + self.to_network_s.send(Packet::Subscribe(sub)).await.map_err(|_| ClientError::NoNetwork)?; Ok(()) } - pub async fn subscribe_with_properties>( - &self, - properties: SubscribeProperties, - into_sub: S, - ) -> Result<(), ClientError> { - let pkid = self - .available_packet_ids - .recv() - .await - .map_err(|_| ClientError::NoHandler)?; + pub async fn subscribe_with_properties>(&self, properties: SubscribeProperties, into_sub: S) -> Result<(), ClientError> { + let pkid = self.available_packet_ids.recv().await.map_err(|_| ClientError::NoNetwork)?; let sub = Subscribe { packet_identifier: pkid, properties, topics: into_sub.into().0, }; - self.client_to_handler_s - .send(Packet::Subscribe(sub)) - .await - .map_err(|_| ClientError::NoHandler)?; + self.to_network_s.send(Packet::Subscribe(sub)).await.map_err(|_| ClientError::NoNetwork)?; Ok(()) } - pub async fn publish>( - &self, - qos: QoS, - retain: bool, - topic: String, - payload: P, - ) -> Result<(), ClientError> { + pub async fn publish>(&self, qos: QoS, retain: bool, topic: String, payload: P) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, - _ => Some( - self.available_packet_ids - .recv() - .await - .map_err(|_| ClientError::NoHandler)?, - ), + _ => Some(self.available_packet_ids.recv().await.map_err(|_| ClientError::NoNetwork)?), }; info!("Published message with ID: {:?}", pkid); let publish = Publish { @@ -103,44 +58,15 @@ impl AsyncClient { publish_properties: PublishProperties::default(), payload: payload.into(), }; - if qos == QoS::AtMostOnce { - self.to_network_s - .send(Packet::Publish(publish)) - .await - .map_err(|_| ClientError::NoHandler)?; - info!( - "Published message into network_packet_sender. len {}", - self.to_network_s.len() - ); - } else { - self.client_to_handler_s - .send(Packet::Publish(publish)) - .await - .map_err(|_| ClientError::NoHandler)?; - info!( - "Published message into handler_packet_sender: len {}", - self.client_to_handler_s.len() - ); - } + self.to_network_s.send(Packet::Publish(publish)).await.map_err(|_| ClientError::NoNetwork)?; + info!("Published message into handler_packet_sender: len {}", self.to_network_s.len()); Ok(()) } - pub async fn publish_with_properties>( - &self, - qos: QoS, - retain: bool, - topic: String, - payload: P, - properties: PublishProperties, - ) -> Result<(), ClientError> { + pub async fn publish_with_properties>(&self, qos: QoS, retain: bool, topic: String, payload: P, properties: PublishProperties) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, - _ => Some( - self.available_packet_ids - .recv() - .await - .map_err(|_| ClientError::NoHandler)?, - ), + _ => Some(self.available_packet_ids.recv().await.map_err(|_| ClientError::NoNetwork)?), }; let publish = Publish { dup: false, @@ -151,60 +77,29 @@ impl AsyncClient { publish_properties: properties, payload: payload.into(), }; - if qos == QoS::AtMostOnce { - self.to_network_s - .send(Packet::Publish(publish)) - .await - .map_err(|_| ClientError::NoHandler)?; - } else { - self.client_to_handler_s - .send(Packet::Publish(publish)) - .await - .map_err(|_| ClientError::NoHandler)?; - } + self.to_network_s.send(Packet::Publish(publish)).await.map_err(|_| ClientError::NoNetwork)?; Ok(()) } - pub async fn unsubscribe>( - &self, - into_topics: T, - ) -> Result<(), ClientError> { - let pkid = self - .available_packet_ids - .recv() - .await - .map_err(|_| ClientError::NoHandler)?; + pub async fn unsubscribe>(&self, into_topics: T) -> Result<(), ClientError> { + let pkid = self.available_packet_ids.recv().await.map_err(|_| ClientError::NoNetwork)?; let unsub = Unsubscribe { packet_identifier: pkid, properties: UnsubscribeProperties::default(), topics: into_topics.into().0, }; - self.client_to_handler_s - .send(Packet::Unsubscribe(unsub)) - .await - .map_err(|_| ClientError::NoHandler)?; + self.to_network_s.send(Packet::Unsubscribe(unsub)).await.map_err(|_| ClientError::NoNetwork)?; Ok(()) } - pub async fn unsubscribe_with_properties>( - &self, - into_topics: T, - properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { - let pkid = self - .available_packet_ids - .recv() - .await - .map_err(|_| ClientError::NoHandler)?; + pub async fn unsubscribe_with_properties>(&self, into_topics: T, properties: UnsubscribeProperties) -> Result<(), ClientError> { + let pkid = self.available_packet_ids.recv().await.map_err(|_| ClientError::NoNetwork)?; let unsub = Unsubscribe { packet_identifier: pkid, properties, topics: into_topics.into().0, }; - self.client_to_handler_s - .send(Packet::Unsubscribe(unsub)) - .await - .map_err(|_| ClientError::NoHandler)?; + self.to_network_s.send(Packet::Unsubscribe(unsub)).await.map_err(|_| ClientError::NoNetwork)?; Ok(()) } @@ -213,26 +108,13 @@ impl AsyncClient { reason_code: DisconnectReasonCode::NormalDisconnection, properties: DisconnectProperties::default(), }; - self.client_to_handler_s - .send(Packet::Disconnect(disconnect)) - .await - .map_err(|_| ClientError::NoHandler)?; + self.to_network_s.send(Packet::Disconnect(disconnect)).await.map_err(|_| ClientError::NoNetwork)?; Ok(()) } - pub async fn disconnect_with_properties( - &self, - reason_code: DisconnectReasonCode, - properties: DisconnectProperties, - ) -> Result<(), ClientError> { - let disconnect = Disconnect { - reason_code, - properties, - }; - self.client_to_handler_s - .send(Packet::Disconnect(disconnect)) - .await - .map_err(|_| ClientError::NoHandler)?; + pub async fn disconnect_with_properties(&self, reason_code: DisconnectReasonCode, properties: DisconnectProperties) -> Result<(), ClientError> { + let disconnect = Disconnect { reason_code, properties }; + self.to_network_s.send(Packet::Disconnect(disconnect)).await.map_err(|_| ClientError::NoNetwork)?; Ok(()) } } @@ -241,14 +123,11 @@ impl AsyncClient { mod tests { use async_channel::Receiver; - use crate::packets::{ - reason_codes::DisconnectReasonCode, DisconnectProperties, Packet, PacketType, - UnsubscribeProperties, - }; + use crate::packets::{reason_codes::DisconnectReasonCode, DisconnectProperties, Packet, PacketType, UnsubscribeProperties}; - use super::AsyncClient; + use super::MqttClient; - fn create_new_test_client() -> (AsyncClient, Receiver, Receiver) { + fn create_new_test_client() -> (MqttClient, Receiver, Receiver) { let (s, r) = async_channel::bounded(100); for i in 1..=100 { @@ -256,9 +135,9 @@ mod tests { } let (client_to_handler_s, client_to_handler_r) = async_channel::bounded(100); - let (to_network_s, to_network_r) = async_channel::bounded(100); + let (_, to_network_r) = async_channel::bounded(100); - let client = AsyncClient::new(r, client_to_handler_s, to_network_s); + let client = MqttClient::new(r, client_to_handler_s); (client, client_to_handler_r, to_network_r) } @@ -270,10 +149,7 @@ mod tests { let mut prop = UnsubscribeProperties::default(); prop.user_properties = vec![("A".to_string(), "B".to_string())]; - client - .unsubscribe_with_properties("Topic", prop.clone()) - .await - .unwrap(); + client.unsubscribe_with_properties("Topic", prop.clone()).await.unwrap(); let unsubscribe = client_to_handler_r.recv().await.unwrap(); assert_eq!(PacketType::Unsubscribe, unsubscribe.packet_type()); @@ -307,10 +183,7 @@ mod tests { #[tokio::test] async fn disconnect_with_properties_test() { let (client, client_to_handler_r, _) = create_new_test_client(); - client - .disconnect_with_properties(DisconnectReasonCode::KeepAliveTimeout, Default::default()) - .await - .unwrap(); + client.disconnect_with_properties(DisconnectReasonCode::KeepAliveTimeout, Default::default()).await.unwrap(); let disconnect = client_to_handler_r.recv().await.unwrap(); assert_eq!(PacketType::Disconnect, disconnect.packet_type()); @@ -329,10 +202,7 @@ mod tests { let mut properties = DisconnectProperties::default(); properties.reason_string = Some("TestString".to_string()); - client - .disconnect_with_properties(DisconnectReasonCode::KeepAliveTimeout, properties.clone()) - .await - .unwrap(); + client.disconnect_with_properties(DisconnectReasonCode::KeepAliveTimeout, properties.clone()).await.unwrap(); let disconnect = client_to_handler_r.recv().await.unwrap(); assert_eq!(PacketType::Disconnect, disconnect.packet_type()); diff --git a/src/connect_options.rs b/src/connect_options.rs index b290603..133173b 100644 --- a/src/connect_options.rs +++ b/src/connect_options.rs @@ -71,4 +71,4 @@ impl ConnectOptions { pub fn receive_maximum(&self) -> u16 { self.receive_maximum.unwrap_or(RECEIVE_MAXIMUM_DEFAULT) } -} \ No newline at end of file +} diff --git a/src/connections/mod.rs b/src/connections/mod.rs index ba06aad..66bfead 100644 --- a/src/connections/mod.rs +++ b/src/connections/mod.rs @@ -1,9 +1,9 @@ #[cfg(all(feature = "quic"))] pub mod quic; -#[cfg(feature = "smol")] -pub mod smol_stream; -#[cfg(feature = "tokio")] -pub mod tokio_stream; +// #[cfg(feature = "smol")] +pub mod smol; +// #[cfg(feature = "tokio")] +pub mod tokio; use crate::connect_options::ConnectOptions; use crate::packets::Connect; diff --git a/src/connections/smol.rs b/src/connections/smol.rs new file mode 100644 index 0000000..a3dd4ea --- /dev/null +++ b/src/connections/smol.rs @@ -0,0 +1,172 @@ +use std::io::{self, Error, ErrorKind}; + +use bytes::{Buf, BytesMut}; + +// use futures::{AsyncReadExt, AsyncWriteExt}; +use smol::io::{AsyncReadExt, AsyncWriteExt}; + +use tracing::trace; + +use crate::packets::{ + error::ReadBytes, + reason_codes::ConnAckReasonCode, + {FixedHeader, Packet, PacketType}, +}; +use crate::{connect_options::ConnectOptions, connections::create_connect_from_options, error::ConnectionError}; + +#[derive(Debug)] +pub struct Stream { + pub stream: S, + + /// Input buffer + const_buffer: [u8; 1000], + + /// Write buffer + read_buffer: BytesMut, + + /// Write buffer + write_buffer: BytesMut, +} + +impl Stream { + pub async fn parse_message(&mut self) -> Result> { + let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; + + if header.remaining_length + header_length > self.read_buffer.len() { + return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); + } + + self.read_buffer.advance(header_length); + + let buf = self.read_buffer.split_to(header.remaining_length); + Ok(Packet::read(header, buf.into())?) + } + + pub async fn parse_messages(&mut self, incoming_packet_buffer: &mut Vec) -> Result<(), ReadBytes> { + loop { + if self.read_buffer.is_empty() { + return Ok(()); + } + let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; + + if header.remaining_length + header_length > self.read_buffer.len() { + return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); + } + + self.read_buffer.advance(header_length); + + let buf = self.read_buffer.split_to(header.remaining_length); + let read_packet = Packet::read(header, buf.into())?; + tracing::trace!("Read packet from network {}", read_packet); + let packet_type = read_packet.packet_type(); + incoming_packet_buffer.push(read_packet); + + if packet_type == PacketType::Disconnect { + return Ok(()); + } + } + } +} + +impl Stream +where + S: smol::io::AsyncRead + smol::io::AsyncWrite + Sized + Unpin, +{ + pub async fn connect(options: &ConnectOptions, stream: S) -> Result<(Self, Packet), ConnectionError> { + let mut s = Self { + stream, + const_buffer: [0; 1000], + read_buffer: BytesMut::new(), + write_buffer: BytesMut::new(), + }; + + let connect = create_connect_from_options(options); + + s.write(&connect).await?; + + let packet = s.read().await?; + if let Packet::ConnAck(con) = packet { + if con.reason_code == ConnAckReasonCode::Success { + Ok((s, Packet::ConnAck(con))) + } else { + Err(ConnectionError::ConnectionRefused(con.reason_code)) + } + } else { + Err(ConnectionError::NotConnAck(packet)) + } + } + + pub async fn read(&mut self) -> io::Result { + loop { + let (header, header_length) = match FixedHeader::read_fixed_header(self.read_buffer.iter()) { + Ok(header) => header, + Err(ReadBytes::InsufficientBytes(required_len)) => { + self.read_required_bytes(required_len).await?; + continue; + } + Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), + }; + + self.read_buffer.advance(header_length); + + if header.remaining_length > self.read_buffer.len() { + self.read_required_bytes(header.remaining_length - self.read_buffer.len()).await?; + } + + let buf = self.read_buffer.split_to(header.remaining_length); + + return Packet::read(header, buf.into()).map_err(|err| Error::new(ErrorKind::InvalidData, err)); + } + } + + pub async fn read_bytes(&mut self) -> io::Result { + let read = self.stream.read(&mut self.const_buffer).await?; + if read == 0 { + Err(io::Error::new(io::ErrorKind::ConnectionReset, "Connection reset by peer")) + } else { + self.read_buffer.extend_from_slice(&self.const_buffer[0..read]); + Ok(read) + } + } + + /// Reads more than 'required' bytes to frame a packet into self.read buffer + pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { + let mut total_read = 0; + + loop { + let read = self.read_bytes().await?; + total_read += read; + if total_read >= required { + return Ok(total_read); + } + } + } + + pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { + packet.write(&mut self.write_buffer)?; + trace!("Sending packet {}", packet); + + self.stream.write_all(&self.write_buffer[..]).await?; + self.stream.flush().await?; + self.write_buffer.clear(); + Ok(()) + } + + pub async fn write_all(&mut self, packets: &mut Vec) -> Result<(), ConnectionError> { + let writes = packets.drain(0..).map(|packet| { + packet.write(&mut self.write_buffer)?; + trace!("Sending packet {}", packet); + + Ok::<(), ConnectionError>(()) + }); + + for write in writes { + write?; + } + + self.stream.write_all(&self.write_buffer[..]).await?; + self.stream.flush().await?; + self.write_buffer.clear(); + Ok(()) + } +} diff --git a/src/connections/smol_stream.rs b/src/connections/smol_stream.rs deleted file mode 100644 index b5ee24e..0000000 --- a/src/connections/smol_stream.rs +++ /dev/null @@ -1,172 +0,0 @@ -use std::io::{self, Error, ErrorKind}; - -use bytes::{Buf, BytesMut}; -use smol::io::{AsyncReadExt, AsyncWriteExt}; - -use futures::{AsyncRead, AsyncWrite}; - -use tracing::trace; - -use crate::packets::{ - error::ReadBytes, - reason_codes::ConnAckReasonCode, - {FixedHeader, Packet, PacketType}, -}; -use crate::{ - connect_options::ConnectOptions, connections::create_connect_from_options, - error::ConnectionError, -}; - -#[derive(Debug)] -pub struct SmolStream { - pub stream: S, - - /// Input buffer - const_buffer: [u8; 1000], - /// Buffered reads - buffer: BytesMut, -} - -impl SmolStream -where - S: AsyncRead + AsyncWrite + Sized + Unpin, -{ - pub async fn connect( - options: &ConnectOptions, - stream: S, - ) -> Result<(Self, Packet), ConnectionError> { - let mut s = Self { - stream, - const_buffer: [0; 1000], - buffer: BytesMut::new(), - }; - - let mut buf_out = BytesMut::new(); - - create_connect_from_options(options).write(&mut buf_out)?; - - s.write_buffer(&mut buf_out).await?; - - let packet = s.read().await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - Ok((s, Packet::ConnAck(con))) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } - } else { - Err(ConnectionError::NotConnAck(packet)) - } - } - - pub async fn parse_messages( - &mut self, - incoming_packet_sender: &async_channel::Sender, - ) -> Result, ReadBytes> { - let mut ret_packet_type = None; - loop { - if self.buffer.is_empty() { - return Ok(ret_packet_type); - } - let (header, header_length) = FixedHeader::read_fixed_header(self.buffer.iter())?; - - if header.remaining_length > self.buffer.len() { - return Err(ReadBytes::InsufficientBytes( - header.remaining_length - self.buffer.len(), - )); - } - - self.buffer.advance(header_length); - - let buf = self.buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - tracing::trace!("Read packet from network {}", read_packet); - let packet_type = read_packet.packet_type(); - incoming_packet_sender.send(read_packet).await?; - - match packet_type { - PacketType::Disconnect => return Ok(Some(PacketType::Disconnect)), - PacketType::PingResp => return Ok(Some(PacketType::PingResp)), - packet_type => ret_packet_type = Some(packet_type), - } - } - } - - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_required_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_required_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()) - .map_err(|err| Error::new(ErrorKind::InvalidData, err)); - } - } - - pub async fn read_bytes(&mut self) -> io::Result { - let read = self.stream.read(&mut self.const_buffer).await?; - if 0 == read { - if self.buffer.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "Connection reset by peer", - )) - } - } else { - self.buffer.extend_from_slice(&self.const_buffer[0..read]); - Ok(read) - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - - loop { - let read = self.read_bytes().await?; - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } - - pub async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> { - if buffer.is_empty() { - return Ok(()); - } - - self.stream.write_all(&buffer[..]).await?; - buffer.clear(); - Ok(()) - } - - pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { - packet.write(&mut self.buffer)?; - trace!("Sending packet {}", packet); - - self.stream.write_all(&self.buffer[..]).await?; - self.stream.flush().await?; - self.buffer.clear(); - Ok(()) - } -} diff --git a/src/connections/tokio.rs b/src/connections/tokio.rs new file mode 100644 index 0000000..86896a0 --- /dev/null +++ b/src/connections/tokio.rs @@ -0,0 +1,172 @@ +use std::io::{self, Error, ErrorKind}; + +use bytes::{Buf, BytesMut}; + +#[cfg(feature = "tokio")] +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +use tracing::trace; + +use crate::packets::{ + error::ReadBytes, + reason_codes::ConnAckReasonCode, + {FixedHeader, Packet, PacketType}, +}; +use crate::{connect_options::ConnectOptions, connections::create_connect_from_options, error::ConnectionError}; + +#[derive(Debug)] +pub struct Stream { + pub stream: S, + + /// Input buffer + const_buffer: [u8; 1000], + + /// Write buffer + read_buffer: BytesMut, + + /// Write buffer + write_buffer: BytesMut, +} + +impl Stream { + pub async fn parse_message(&mut self) -> Result> { + let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; + + if header.remaining_length + header_length > self.read_buffer.len() { + return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); + } + + self.read_buffer.advance(header_length); + + let buf = self.read_buffer.split_to(header.remaining_length); + Ok(Packet::read(header, buf.into())?) + } + + pub async fn parse_messages(&mut self, incoming_packet_buffer: &mut Vec) -> Result<(), ReadBytes> { + loop { + if self.read_buffer.is_empty() { + return Ok(()); + } + let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; + + if header.remaining_length + header_length > self.read_buffer.len() { + return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); + } + + self.read_buffer.advance(header_length); + + let buf = self.read_buffer.split_to(header.remaining_length); + let read_packet = Packet::read(header, buf.into())?; + tracing::trace!("Read packet from network {}", read_packet); + let packet_type = read_packet.packet_type(); + incoming_packet_buffer.push(read_packet); + + if packet_type == PacketType::Disconnect { + return Ok(()); + } + } + } +} + +impl Stream +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin, +{ + pub async fn connect(options: &ConnectOptions, stream: S) -> Result<(Self, Packet), ConnectionError> { + let mut s = Self { + stream, + const_buffer: [0; 1000], + read_buffer: BytesMut::new(), + write_buffer: BytesMut::new(), + }; + + let connect = create_connect_from_options(options); + + s.write(&connect).await?; + + let packet = s.read().await?; + if let Packet::ConnAck(con) = packet { + if con.reason_code == ConnAckReasonCode::Success { + Ok((s, Packet::ConnAck(con))) + } else { + Err(ConnectionError::ConnectionRefused(con.reason_code)) + } + } else { + Err(ConnectionError::NotConnAck(packet)) + } + } + + pub async fn read(&mut self) -> io::Result { + loop { + let (header, header_length) = match FixedHeader::read_fixed_header(self.read_buffer.iter()) { + Ok(header) => header, + Err(ReadBytes::InsufficientBytes(required_len)) => { + self.read_required_bytes(required_len).await?; + continue; + } + Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), + }; + + self.read_buffer.advance(header_length); + + if header.remaining_length > self.read_buffer.len() { + self.read_required_bytes(header.remaining_length - self.read_buffer.len()).await?; + } + + let buf = self.read_buffer.split_to(header.remaining_length); + + return Packet::read(header, buf.into()).map_err(|err| Error::new(ErrorKind::InvalidData, err)); + } + } + + pub async fn read_bytes(&mut self) -> io::Result { + let read = self.stream.read(&mut self.const_buffer).await?; + if read == 0 { + Err(io::Error::new(io::ErrorKind::ConnectionReset, "Connection reset by peer")) + } else { + self.read_buffer.extend_from_slice(&self.const_buffer[0..read]); + Ok(read) + } + } + + /// Reads more than 'required' bytes to frame a packet into self.read buffer + pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { + let mut total_read = 0; + + loop { + let read = self.read_bytes().await?; + total_read += read; + if total_read >= required { + return Ok(total_read); + } + } + } + + pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { + packet.write(&mut self.write_buffer)?; + trace!("Sending packet {}", packet); + + self.stream.write_all(&self.write_buffer[..]).await?; + self.stream.flush().await?; + self.write_buffer.clear(); + Ok(()) + } + + pub async fn write_all(&mut self, packets: &mut Vec) -> Result<(), ConnectionError> { + let writes = packets.drain(0..).map(|packet| { + packet.write(&mut self.write_buffer)?; + trace!("Sending packet {}", packet); + + Ok::<(), ConnectionError>(()) + }); + + for write in writes { + write?; + } + + self.stream.write_all(&self.write_buffer[..]).await?; + self.stream.flush().await?; + self.write_buffer.clear(); + Ok(()) + } +} diff --git a/src/connections/tokio_stream.rs b/src/connections/tokio_stream.rs deleted file mode 100644 index 099c24f..0000000 --- a/src/connections/tokio_stream.rs +++ /dev/null @@ -1,170 +0,0 @@ -use std::io::{self, Error, ErrorKind}; - -use bytes::{Buf, BytesMut}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - -use tracing::trace; - -use crate::packets::{ - error::ReadBytes, - reason_codes::ConnAckReasonCode, - {FixedHeader, Packet, PacketType}, -}; -use crate::{ - connect_options::ConnectOptions, connections::create_connect_from_options, - error::ConnectionError, -}; - -#[derive(Debug)] -pub struct TokioStream { - pub stream: S, - - /// Input buffer - const_buffer: [u8; 1000], - /// Buffered reads - buffer: BytesMut, -} - -impl TokioStream -where - S: AsyncRead + AsyncWrite + Sized + Unpin, -{ - pub async fn connect( - options: &ConnectOptions, - stream: S, - ) -> Result<(Self, Packet), ConnectionError> { - let mut s = Self { - stream, - const_buffer: [0; 1000], - buffer: BytesMut::new(), - }; - - let mut buf_out = BytesMut::new(); - - create_connect_from_options(options).write(&mut buf_out)?; - - s.write_buffer(&mut buf_out).await?; - - let packet = s.read().await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - Ok((s, Packet::ConnAck(con))) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } - } else { - Err(ConnectionError::NotConnAck(packet)) - } - } - - pub async fn parse_messages( - &mut self, - incoming_packet_sender: &async_channel::Sender, - ) -> Result, ReadBytes> { - let mut ret_packet_type = None; - loop { - if self.buffer.is_empty() { - return Ok(ret_packet_type); - } - let (header, header_length) = FixedHeader::read_fixed_header(self.buffer.iter())?; - - if header.remaining_length > self.buffer.len() { - return Err(ReadBytes::InsufficientBytes( - header.remaining_length - self.buffer.len(), - )); - } - - self.buffer.advance(header_length); - - let buf = self.buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - tracing::trace!("Read packet from network {}", read_packet); - let packet_type = read_packet.packet_type(); - incoming_packet_sender.send(read_packet).await?; - - match packet_type { - PacketType::Disconnect => return Ok(Some(PacketType::Disconnect)), - PacketType::PingResp => return Ok(Some(PacketType::PingResp)), - packet_type => ret_packet_type = Some(packet_type), - } - } - } - - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_required_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_required_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()) - .map_err(|err| Error::new(ErrorKind::InvalidData, err)); - } - } - - pub async fn read_bytes(&mut self) -> io::Result { - let read = self.stream.read(&mut self.const_buffer).await?; - if 0 == read { - if self.buffer.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "Connection reset by peer", - )) - } - } else { - self.buffer.extend_from_slice(&self.const_buffer[0..read]); - Ok(read) - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - - loop { - let read = self.read_bytes().await?; - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } - - pub async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> { - if buffer.is_empty() { - return Ok(()); - } - - self.stream.write_all(&buffer[..]).await?; - buffer.clear(); - Ok(()) - } - - pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { - packet.write(&mut self.buffer)?; - trace!("Sending packet {}", packet); - - self.stream.write_all(&self.buffer[..]).await?; - self.stream.flush().await?; - self.buffer.clear(); - Ok(()) - } -} diff --git a/src/error.rs b/src/error.rs index 299f8ae..e3ffac7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,38 +8,6 @@ use crate::packets::{ {Packet, PacketType}, }; -/// Errors that the [`mqrstt::EventHandler`] can emit -#[derive(Debug, Clone, thiserror::Error)] -pub enum MqttError { - #[error("Missing Packet ID")] - MissingPacketId, - - #[error("The incoming channel between network and handler is closed")] - IncomingNetworkChannelClosed, - - #[error("The outgoing channel between handler and network is closed: {0}")] - OutgoingNetworkChannelClosed(#[from] SendError), - - #[error("Channel between client and handler closed")] - ClientChannelClosed, - - #[error("Packet Id error, pkid: {0}")] - PacketIdError(u16), - - #[error("Received unsolicited ack pkid: {0}")] - Unsolicited(u16, PacketType), -} - -/// Errors producable by the [`mqrstt::AsyncClient`] -#[derive(Debug, Clone, thiserror::Error)] -pub enum ClientError { - #[error("One of more of the internal handler channels are closed")] - NoHandler, - - #[error("Internal network channel is closed")] - NoNetwork, -} - /// Critical errors that can happen during the operation of the entire client #[derive(Debug, thiserror::Error)] pub enum ConnectionError { @@ -66,6 +34,41 @@ pub enum ConnectionError { #[error("Expected ConnAck packet, received: {0:?}")] NotConnAck(Packet), + + #[error("The handler encountered an error")] + HandlerError(#[from] HandlerError), +} + +/// Errors that the [`mqrstt::MqttHandler`] can emit +#[derive(Debug, Clone, thiserror::Error)] +pub enum HandlerError { + #[error("Missing Packet ID")] + MissingPacketId, + + #[error("The incoming channel between network and handler is closed")] + IncomingNetworkChannelClosed, + + #[error("The outgoing channel between handler and network is closed: {0}")] + OutgoingNetworkChannelClosed(#[from] SendError), + + #[error("Channel between client and handler closed")] + ClientChannelClosed, + + #[error("Packet Id error, pkid: {0}")] + PacketIdError(u16), + + #[error("Packet collision error. packet ID: {0}")] + PacketIdCollision(u16), + + #[error("Received unsolicited ack pkid: {0}")] + Unsolicited(u16, PacketType), +} + +/// Errors producable by the [`mqrstt::AsyncClient`] +#[derive(Debug, Clone, thiserror::Error)] +pub enum ClientError { + #[error("Internal network channel is closed")] + NoNetwork, } impl From> for ReadBytes { @@ -87,4 +90,4 @@ impl From> for ReadBytes { fn from(value: SendError) -> Self { ReadBytes::Err(value.into()) } -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index bf57b31..0b3d025 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,49 +1,35 @@ -//! A pure rust MQTT client which strives to be as efficient as possible. -//! This crate strives to provide an ergonomic API and design that fits Rust. -//! -//! There are three parts to the design of the MQTT client. The network, the event handler and the client. -//! -//! - The network - which simply reads and forms packets from the network. -//! - The event handler - which makes sure that the MQTT protocol is followed. -//! By providing a custom handler messages are handled before they are acked, meaning that they are always handled. -//! - The client - which is used to send messages from different places. -//! -//! To Do: -//! - Rebroadcast unacked packets -//! - Enforce size of outbound messages (e.g. Publish) -//! - Sync API -//! - More testing -//! - More documentation -//! - Remove logging calls or move all to test flag -//! -//! A few questions still remain: -//! - This crate uses async channels to perform communication across its parts. Is there a better approach? -//! These channels do allow the user to decouple the network, handlers, and clients very easily. -//! - MPL-2.0 vs MIT OR APACHE 2.0 license? [poll](https://github.com/GunnarMorrigan/mqrstt/discussions/2) -//! - The handler currently only gets INCOMING packets +//! A pure rust MQTT client which strives to be easy to use and efficient. +//! Providing both async and sync options. //! +//! Because this crate aims to be runtime agnostic the user is required to provide their own data stream. +//! The stream has to implement the smol or tokio [`AsyncReadExt`] and [`AsyncWriteExt`] traits. //! +//! Notes: +//! ---------------------------- +//! - Your handler should not wait too long +//! - Create a new connection when an error or disconnect is encountered +//! - Handlers only get incoming packets //! -//! You want to reconnect (with a new stream) after the network encountered an error or a disconnect took place! //! //! Smol example: -//! ``` +//! ---------------------------- +//! ```rust //! use mqrstt::{ -//! AsyncClient, +//! MqttClient, //! ConnectOptions, //! new_smol, //! packets::{self, Packet}, -//! AsyncEventHandlerMut, HandlerStatus, NetworkStatus, +//! AsyncEventHandler, NetworkStatus, //! }; //! use async_trait::async_trait; //! use bytes::Bytes; //! pub struct PingPong { -//! pub client: AsyncClient, +//! pub client: MqttClient, //! } //! #[async_trait] -//! impl AsyncEventHandlerMut for PingPong { +//! impl AsyncEventHandler for PingPong { //! // Handlers only get INCOMING packets. This can change later. -//! async fn handle(&mut self, event: &packets::Packet) -> () { +//! async fn handle(&mut self, event: packets::Packet) -> () { //! match event { //! Packet::Publish(p) => { //! if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { @@ -67,123 +53,141 @@ //! } //! } //! smol::block_on(async { -//! let options = ConnectOptions::new("mqrsttExample".to_string()); -//! let (mut network, mut handler, client) = new_smol(options); +//! let options = ConnectOptions::new("mqrsttSmolExample".to_string()); +//! let (mut network, client) = new_smol(options); //! let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)) //! .await //! .unwrap(); //! network.connect(stream).await.unwrap(); +//! +//! // This subscribe is only processed when we run the network //! client.subscribe("mqrstt").await.unwrap(); +//! //! let mut pingpong = PingPong { //! client: client.clone(), //! }; -//! let (n, h, t) = futures::join!( +//! let (n, t) = futures::join!( //! async { //! loop { -//! return match network.run().await { +//! return match network.poll(&mut pingpong).await { //! Ok(NetworkStatus::Active) => continue, //! otherwise => otherwise, //! }; //! } //! }, //! async { -//! loop { -//! return match handler.handle_mut(&mut pingpong).await { -//! Ok(HandlerStatus::Active) => continue, -//! otherwise => otherwise, -//! }; -//! } -//! }, -//! async { //! smol::Timer::after(std::time::Duration::from_secs(30)).await; //! client.disconnect().await.unwrap(); //! } //! ); //! assert!(n.is_ok()); -//! assert!(h.is_ok()); //! }); //! ``` //! //! //! Tokio example: //! ---------------------------- -//! ```ignore -//! let options = ConnectOptions::new("TokioTcpPingPongExample".to_string()); -//! -//! let (mut network, mut handler, client) = new_tokio(options); -//! -//! let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)) -//! .await -//! .unwrap(); -//! -//! network.connect(stream).await.unwrap(); -//! -//! client.subscribe("mqrstt").await.unwrap(); +//! ```rust //! -//! let mut pingpong = PingPong { -//! client: client.clone(), +//! use mqrstt::{ +//! MqttClient, +//! ConnectOptions, +//! new_tokio, +//! packets::{self, Packet}, +//! AsyncEventHandler, NetworkStatus, //! }; +//! use tokio::time::Duration; +//! use async_trait::async_trait; +//! use bytes::Bytes; //! -//! let (n, h, _) = tokio::join!( -//! async { -//! loop { -//! return match network.run().await { -//! Ok(NetworkStatus::Active) => continue, -//! otherwise => otherwise, -//! }; -//! } -//! }, -//! async { -//! loop { -//! return match handler.handle_mut(&mut pingpong).await { -//! Ok(HandlerStatus::Active) => continue, -//! otherwise => otherwise, -//! }; +//! pub struct PingPong { +//! pub client: MqttClient, +//! } +//! #[async_trait] +//! impl AsyncEventHandler for PingPong { +//! // Handlers only get INCOMING packets. This can change later. +//! async fn handle(&mut self, event: packets::Packet) -> () { +//! match event { +//! Packet::Publish(p) => { +//! if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { +//! if payload.to_lowercase().contains("ping") { +//! self.client +//! .publish( +//! p.qos, +//! p.retain, +//! p.topic.clone(), +//! Bytes::from_static(b"pong"), +//! ) +//! .await +//! .unwrap(); +//! println!("Received Ping, Send pong!"); +//! } +//! } +//! }, +//! Packet::ConnAck(_) => { println!("Connected!") }, +//! _ => (), //! } -//! }, -//! async { -//! tokio::time::sleep(Duration::from_secs(30)).await; -//! client.disconnect().await.unwrap(); //! } -//! ); +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let options = ConnectOptions::new("TokioTcpPingPongExample".to_string()); +//! +//! let (mut network, client) = new_tokio(options); +//! +//! let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)) +//! .await +//! .unwrap(); +//! +//! network.connect(stream).await.unwrap(); +//! +//! client.subscribe("mqrstt").await.unwrap(); +//! +//! let mut pingpong = PingPong { +//! client: client.clone(), +//! }; +//! +//! let (n, _) = tokio::join!( +//! async { +//! loop { +//! return match network.poll(&mut pingpong).await { +//! Ok(NetworkStatus::Active) => continue, +//! otherwise => otherwise, +//! }; +//! } +//! }, +//! async { +//! tokio::time::sleep(Duration::from_secs(30)).await; +//! client.disconnect().await.unwrap(); +//! } +//! ); +//! assert!(n.is_ok()); +//! } +//! //! ``` -use packets::Packet; -use smol_network::SmolNetwork; - mod available_packet_ids; mod client; mod connect_options; pub mod connections; pub mod error; mod mqtt_handler; +mod network; pub mod packets; - -#[cfg(feature = "smol")] -mod smol_network; -#[cfg(feature = "smol")] -pub use smol_network::*; - -#[cfg(feature = "tokio")] -mod tokio_network; -#[cfg(feature = "tokio")] -pub use tokio_network::*; - pub mod state; - mod util; -pub use client::AsyncClient; +pub use client::MqttClient; pub use connect_options::ConnectOptions; pub use mqtt_handler::MqttHandler; - - +use packets::Packet; #[cfg(test)] pub mod tests; -/// [`NetworkStatus`] Represents status of the Network object. -/// It is returned when the run handle returns from performing an operation. +// /// [`NetworkStatus`] Represents status of the Network object. +// /// It is returned when the run handle returns from performing an operation. #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub enum NetworkStatus { Active, @@ -192,135 +196,91 @@ pub enum NetworkStatus { NoPingResp, } -/// [`HandlerStatus`] Represents status of the Network object. -/// It is returned when the run handle returns from performing an operation. -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] -pub enum HandlerStatus { - Active, - IncomingDisconnect, - OutgoingDisconnect, -} +// #[cfg(all(feature = "smol", feature = "tokio"))] +// compile_error!("Both smol and tokio runtimes not supported at once."); /// Handlers are used to deal with packets before they are further processed (acked) /// This guarantees that the end user has handlded the packet. /// Trait for async mutable access to handler. /// Usefull when you have a single handler -#[async_trait::async_trait] -pub trait AsyncEventHandlerMut { - async fn handle(&mut self, event: &Packet); -} - #[async_trait::async_trait] pub trait AsyncEventHandler { - async fn handle(&self, event: &Packet); -} - -pub trait EventHandlerMut { - fn handle(&mut self, event: &Packet); + async fn handle(&mut self, event: Packet); } -pub trait EventHandler { - fn handle(&self, event: &Packet); -} +// pub trait EventHandler { +// fn handle(&mut self, event: Packet); +// } -#[cfg(feature = "smol")] /// Creates the needed components to run the MQTT client using a stream that implements [`smol::io::AsyncReadExt`] and [`smol::io::AsyncWriteExt`] -pub fn new_smol(options: ConnectOptions) -> (SmolNetwork, MqttHandler, AsyncClient) +#[cfg(feature = "smol")] +pub fn new_smol(options: ConnectOptions) -> (network::smol::Network, MqttClient) where S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, { - let receive_maximum = options.receive_maximum(); - let (to_network_s, to_network_r) = async_channel::bounded(100); - let (network_to_handler_s, network_to_handler_r) = async_channel::bounded(100); - let (client_to_handler_s, client_to_handler_r) = - async_channel::bounded(receive_maximum as usize); - let (handler, packet_ids) = MqttHandler::new( - &options, - network_to_handler_r, - to_network_s.clone(), - client_to_handler_r, - ); + let (handler, packet_ids) = MqttHandler::new(&options); - let network = SmolNetwork::::new(options, network_to_handler_s, to_network_r); + let network = network::smol::Network::::new(options, handler, to_network_r); - let client = AsyncClient::new(packet_ids, client_to_handler_s, to_network_s); + let client = MqttClient::new(packet_ids, to_network_s); - (network, handler, client) + (network, client) } -#[cfg(feature = "tokio")] /// Creates the needed components to run the MQTT client using a stream that implements [`tokio::io::AsyncReadExt`] and [`tokio::io::AsyncWriteExt`] -pub fn new_tokio( - options: ConnectOptions, -) -> ( - tokio_network::TokioNetwork, - MqttHandler, - AsyncClient, -) +#[cfg(feature = "tokio")] +pub fn new_tokio(options: ConnectOptions) -> (network::tokio::Network, MqttClient) where S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin, { - let receive_maximum = options.receive_maximum(); - let (to_network_s, to_network_r) = async_channel::bounded(100); - let (network_to_handler_s, network_to_handler_r) = async_channel::bounded(100); - let (client_to_handler_s, client_to_handler_r) = - async_channel::bounded(receive_maximum as usize); - let (handler, packet_ids) = MqttHandler::new( - &options, - network_to_handler_r, - to_network_s.clone(), - client_to_handler_r, - ); + let (mqtt_handler, apkid) = MqttHandler::new(&options); - let network = - tokio_network::TokioNetwork::::new(options, network_to_handler_s, to_network_r); + let network = network::tokio::Network::new(options, mqtt_handler, to_network_r); - let client = AsyncClient::new(packet_ids, client_to_handler_s, to_network_s); + let client = MqttClient::new(apkid, to_network_s); - (network, handler, client) + (network, client) } #[cfg(test)] mod lib_test { use std::time::Duration; + #[cfg(feature = "tokio")] + use crate::new_tokio; + use crate::{ - AsyncClient, - ConnectOptions, - new_smol, new_tokio, - packets::{self, Packet, QoS}, + new_smol, + // new_smol, + packets::{self, Packet}, tests::tls::tests::simple_rust_tls, - AsyncEventHandlerMut, HandlerStatus, NetworkStatus, + AsyncEventHandler, + ConnectOptions, + MqttClient, + NetworkStatus, }; use async_trait::async_trait; use bytes::Bytes; + use packets::QoS; use rustls::ServerName; pub struct PingPong { - pub client: AsyncClient, + pub client: MqttClient, } #[async_trait] - impl AsyncEventHandlerMut for PingPong { - async fn handle(&mut self, event: &packets::Packet) -> () { + impl AsyncEventHandler for PingPong { + async fn handle(&mut self, event: packets::Packet) -> () { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { if payload.to_lowercase().contains("ping") { - self.client - .publish( - p.qos, - p.retain, - p.topic.clone(), - Bytes::from_static(b"pong"), - ) - .await - .unwrap(); - println!("Received Ping, Send pong!"); + self.client.publish(p.qos, p.retain, p.topic.clone(), Bytes::from_static(b"pong")).await.unwrap(); + // println!("Received Ping, Send pong!"); } } } @@ -340,44 +300,38 @@ mod lib_test { let address = "broker.emqx.io"; let port = 1883; - let (mut network, mut handler, client) = new_smol(options); + let (mut network, client) = new_smol(options); - let stream = smol::net::TcpStream::connect((address, port)) - .await - .unwrap(); + let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); network.connect(stream).await.unwrap(); client.subscribe("mqrstt").await.unwrap(); - let mut pingpong = PingPong { - client: client.clone(), - }; + let mut pingpong = PingPong { client: client.clone() }; - let (n, h, _) = futures::join!( + let (n, _) = futures::join!( async { loop { - return match network.run().await { + return match network.poll(&mut pingpong).await { Ok(NetworkStatus::Active) => continue, otherwise => otherwise, }; } }, async { - loop { - return match handler.handle_mut(&mut pingpong).await { - Ok(HandlerStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - smol::Timer::after(std::time::Duration::from_secs(30)).await; + client.publish(QoS::ExactlyOnce, false, "mqrstt".to_string(), b"ping".repeat(500)).await.unwrap(); + client.publish(QoS::AtMostOnce, true, "mqrstt".to_string(), b"ping".to_vec()).await.unwrap(); + client.publish(QoS::AtLeastOnce, false, "mqrstt".to_string(), b"ping".to_vec()).await.unwrap(); + client.publish(QoS::ExactlyOnce, false, "mqrstt".to_string(), b"ping".repeat(500)).await.unwrap(); + + smol::Timer::after(std::time::Duration::from_secs(20)).await; + client.unsubscribe("mqrstt").await.unwrap(); + smol::Timer::after(std::time::Duration::from_secs(5)).await; client.disconnect().await.unwrap(); } ); assert!(n.is_ok()); - assert!(h.is_ok()); }); } @@ -389,89 +343,64 @@ mod lib_test { let address = "broker.emqx.io"; let port = 8883; - let (mut network, mut handler, client) = new_smol(options); + let (mut network, client) = new_smol(options); - let arc_client_config = - simple_rust_tls(crate::tests::resources::EMQX_CERT.to_vec(), None, None).unwrap(); + let arc_client_config = simple_rust_tls(crate::tests::resources::EMQX_CERT.to_vec(), None, None).unwrap(); let domain = ServerName::try_from(address).unwrap(); let connector = async_rustls::TlsConnector::from(arc_client_config); - let stream = smol::net::TcpStream::connect((address, port)) - .await - .unwrap(); + let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let connection = connector.connect(domain, stream).await.unwrap(); network.connect(connection).await.unwrap(); client.subscribe("mqrstt").await.unwrap(); - let mut pingpong = PingPong { - client: client.clone(), - }; + let mut pingpong = PingPong { client: client.clone() }; - let (n, h, _) = futures::join!( + let (n, _) = futures::join!( async { loop { - return match network.run().await { + return match network.poll(&mut pingpong).await { Ok(NetworkStatus::Active) => continue, otherwise => otherwise, }; } }, - async { - loop { - return match handler.handle_mut(&mut pingpong).await { - Ok(HandlerStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, async { smol::Timer::after(std::time::Duration::from_secs(30)).await; client.disconnect().await.unwrap(); } ); assert!(n.is_ok()); - assert!(h.is_ok()); }); } + #[cfg(feature = "tokio")] #[tokio::test] async fn test_tokio_tcp() { let options = ConnectOptions::new("TokioTcpPingPong".to_string()); - let (mut network, mut handler, client) = new_tokio(options); + let (mut network, client) = new_tokio(options); - let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)) - .await - .unwrap(); + let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); network.connect(stream).await.unwrap(); client.subscribe("mqrstt").await.unwrap(); - let mut pingpong = PingPong { - client: client.clone(), - }; + let mut pingpong = PingPong { client: client.clone() }; - let (n, h, _) = tokio::join!( + let (n, _) = tokio::join!( async { loop { - return match network.run().await { + return match network.poll(&mut pingpong).await { Ok(NetworkStatus::Active) => continue, otherwise => otherwise, }; } }, - async { - loop { - return match handler.handle_mut(&mut pingpong).await { - Ok(HandlerStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, async { client.publish(QoS::ExactlyOnce, false, "mqrstt".to_string(), b"ping".repeat(500)).await.unwrap(); client.publish(QoS::AtMostOnce, true, "mqrstt".to_string(), b"ping".to_vec()).await.unwrap(); @@ -484,13 +413,13 @@ mod lib_test { client.disconnect().await.unwrap(); } ); + let n = dbg!(n); assert!(n.is_ok()); - assert!(h.is_ok()); assert_eq!(NetworkStatus::OutgoingDisconnect, n.unwrap()); - assert_eq!(HandlerStatus::OutgoingDisconnect, h.unwrap()); } + #[cfg(feature = "tokio")] #[tokio::test] async fn test_tokio_tls() { let options = ConnectOptions::new("TokioTlsPingPong".to_string()); @@ -498,63 +427,144 @@ mod lib_test { let address = "broker.emqx.io"; let port = 8883; - let (mut network, mut handler, client) = new_tokio(options); + let (mut network, client) = new_tokio(options); - let arc_client_config = - simple_rust_tls(crate::tests::resources::EMQX_CERT.to_vec(), None, None).unwrap(); + let arc_client_config = simple_rust_tls(crate::tests::resources::EMQX_CERT.to_vec(), None, None).unwrap(); let domain = ServerName::try_from(address).unwrap(); let connector = tokio_rustls::TlsConnector::from(arc_client_config); - let stream = tokio::net::TcpStream::connect((address, port)) - .await - .unwrap(); + let stream = tokio::net::TcpStream::connect((address, port)).await.unwrap(); let connection = connector.connect(domain, stream).await.unwrap(); network.connect(connection).await.unwrap(); client.subscribe("mqrstt").await.unwrap(); - let mut pingpong = PingPong { - client: client.clone(), - }; + let mut pingpong = PingPong { client: client.clone() }; - let (n, h, _) = tokio::join!( + let (n, _) = tokio::join!( async { loop { - return match network.run().await { - Ok(NetworkStatus::IncomingDisconnect) => { - Ok(NetworkStatus::IncomingDisconnect) - } - Ok(NetworkStatus::OutgoingDisconnect) => { - Ok(NetworkStatus::OutgoingDisconnect) - } + return match network.poll(&mut pingpong).await { + Ok(NetworkStatus::IncomingDisconnect) => Ok(NetworkStatus::IncomingDisconnect), + Ok(NetworkStatus::OutgoingDisconnect) => Ok(NetworkStatus::OutgoingDisconnect), Ok(NetworkStatus::NoPingResp) => Ok(NetworkStatus::NoPingResp), Ok(NetworkStatus::Active) => continue, Err(a) => Err(a), }; } }, - async { - loop { - return match handler.handle_mut(&mut pingpong).await { - Ok(HandlerStatus::IncomingDisconnect) => { - Ok(NetworkStatus::IncomingDisconnect) - } - Ok(HandlerStatus::OutgoingDisconnect) => { - Ok(NetworkStatus::OutgoingDisconnect) - } - Ok(HandlerStatus::Active) => continue, - Err(a) => Err(a), - }; - } - }, async { tokio::time::sleep(Duration::from_secs(30)).await; client.disconnect().await.unwrap(); } ); assert!(n.is_ok()); - assert!(h.is_ok()); + assert_eq!(NetworkStatus::OutgoingDisconnect, n.unwrap()); + } + + pub struct PingResp { + pub client: MqttClient, + pub ping_resp_received: u64, + } + + impl PingResp { + pub fn new(client: MqttClient) -> Self { + Self { client, ping_resp_received: 0 } + } + } + + #[async_trait] + impl AsyncEventHandler for PingResp { + async fn handle(&mut self, event: packets::Packet) -> () { + use Packet::*; + match event { + PingResp => { + self.ping_resp_received += 1; + } + _ => (), + } + } + } + + #[cfg(feature = "tokio")] + #[tokio::test] + async fn test_tokio_ping_req() { + let options = ConnectOptions::new("TokioTcpPingReqTest".to_string()); + + let (mut network, client) = new_tokio(options); + + let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + + network.connect(stream).await.unwrap(); + + let mut pingresp = PingResp::new(client.clone()); + + let futs = tokio::task::spawn(async { + tokio::join!( + async move { + loop { + match network.poll(&mut pingresp).await { + Ok(NetworkStatus::Active) => continue, + Ok(NetworkStatus::OutgoingDisconnect) => return Ok(pingresp), + Ok(NetworkStatus::NoPingResp) => panic!(), + Ok(NetworkStatus::IncomingDisconnect) => panic!(), + Err(err) => return Err(err), + } + } + }, + async move { + smol::Timer::after(std::time::Duration::from_secs(125)).await; + client.disconnect().await.unwrap(); + } + ) + }); + + tokio::time::sleep(Duration::new(125, 0)).await; + + let (n, _) = futs.await.unwrap(); + assert!(n.is_ok()); + let pingresp = n.unwrap(); + assert_eq!(2, pingresp.ping_resp_received); + } + + #[test] + fn test_smol_ping_req() { + smol::block_on(async { + let options = ConnectOptions::new("SmolTcpPingReq".to_string()); + + let address = "broker.emqx.io"; + let port = 1883; + + let (mut network, client) = new_smol(options); + + let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); + + network.connect(stream).await.unwrap(); + + let mut pingresp = PingResp::new(client.clone()); + + let (n, _) = futures::join!( + async { + loop { + match network.poll(&mut pingresp).await { + Ok(NetworkStatus::Active) => continue, + Ok(NetworkStatus::OutgoingDisconnect) => return Ok(pingresp), + Ok(NetworkStatus::NoPingResp) => panic!(), + Ok(NetworkStatus::IncomingDisconnect) => panic!(), + Err(err) => return Err(err), + } + } + }, + async { + smol::Timer::after(std::time::Duration::from_secs(125)).await; + client.disconnect().await.unwrap(); + } + ); + assert!(n.is_ok()); + let pingreq = n.unwrap(); + assert_eq!(2, pingreq.ping_resp_received); + }); } } diff --git a/src/mqtt_handler.rs b/src/mqtt_handler.rs index 0b44480..037abf9 100644 --- a/src/mqtt_handler.rs +++ b/src/mqtt_handler.rs @@ -1,7 +1,6 @@ use crate::connect_options::ConnectOptions; -use crate::error::MqttError; -use crate::packets::reason_codes::{PubAckReasonCode, PubRecReasonCode}; -use crate::packets::Disconnect; +use crate::error::HandlerError; +use crate::packets::reason_codes::{ConnAckReasonCode, PubAckReasonCode, PubRecReasonCode}; use crate::packets::PubComp; use crate::packets::PubRec; use crate::packets::PubRel; @@ -11,15 +10,13 @@ use crate::packets::SubAck; use crate::packets::Subscribe; use crate::packets::UnsubAck; use crate::packets::Unsubscribe; +use crate::packets::{ConnAck, Disconnect}; use crate::packets::{Packet, PacketType}; use crate::packets::{PubAck, PubAckProperties}; use crate::state::State; -use crate::{AsyncEventHandler, AsyncEventHandlerMut, HandlerStatus}; -use futures::FutureExt; - -use async_channel::{Receiver, Sender}; -use tracing::error; +use async_channel::Receiver; +use tracing::{error, info, warn}; #[cfg(test)] use tracing::debug; @@ -27,316 +24,167 @@ use tracing::debug; /// Eventloop with all the state of a connection pub struct MqttHandler { state: State, - - network_receiver: Receiver, - - network_sender: Sender, - - client_to_handler_r: Receiver, - - disconnect: bool, + clean_start: bool, } /// [`MqttHandler`] is used to properly handle incoming and outgoing packets according to the MQTT specifications. /// Only the incoming messages are shown to the user via the user provided handler. impl MqttHandler { - pub(crate) fn new( - options: &ConnectOptions, - network_receiver: Receiver, - network_sender: Sender, - client_to_handler_r: Receiver, - ) -> (Self, Receiver) { + pub(crate) fn new(options: &ConnectOptions) -> (Self, Receiver) { let (state, packet_id_channel) = State::new(options.receive_maximum()); - let task = MqttHandler { + let handler = MqttHandler { state, - network_receiver, - network_sender, - - client_to_handler_r, - - disconnect: false, + clean_start: options.clean_start, }; - (task, packet_id_channel) + (handler, packet_id_channel) } - // pub fn sync_handle(&self, handler: &mut H) -> Result<(), MqttError> { - // match self.network_receiver.try_recv() { - // Ok(event) => { - // handler.handle(&event); - // } - // Err(err) => { - // if err.is_closed() { - // return Err(MqttError::IncomingNetworkChannelClosed); - // } - // } - // } - // match self.client_to_handler_r.try_recv() { - // Ok(_) => {} - // Err(err) => { - // if err.is_closed() { - // return Err(MqttError::IncomingNetworkChannelClosed); - // } - // } - // } - // Ok(()) - // } - - pub async fn handle(&mut self, handler: &H) -> Result - where - H: AsyncEventHandler, - { - futures::select! { - incoming = self.network_receiver.recv().fuse() => { - match incoming { - Ok(event) => { - // debug!("Event Handler, handling incoming packet: {}", event); - handler.handle(&event).await; - self.handle_incoming_packet(event).await?; - } - Err(_) => return Err(MqttError::IncomingNetworkChannelClosed), - } - if self.disconnect { - self.disconnect = true; - return Ok(HandlerStatus::IncomingDisconnect); - } - }, - outgoing = self.client_to_handler_r.recv().fuse() => { - match outgoing { - Ok(event) => { - // debug!("Event Handler, handling outgoing packet: {}", event); - self.handle_outgoing_packet(event).await? - } - Err(_) => return Err(MqttError::ClientChannelClosed), - } - if self.disconnect { - self.disconnect = true; - return Ok(HandlerStatus::OutgoingDisconnect); - } - } - } - Ok(HandlerStatus::Active) - } - - pub async fn handle_mut(&mut self, handler: &mut H) -> Result - where - H: AsyncEventHandlerMut, - { - futures::select! { - incoming = self.network_receiver.recv().fuse() => { - match incoming { - Ok(event) => { - // debug!("Event Handler, handling incoming packet: {}", event); - handler.handle(&event).await; - self.handle_incoming_packet(event).await?; - } - Err(_) => return Err(MqttError::IncomingNetworkChannelClosed), - } - if self.disconnect { - self.disconnect = true; - return Ok(HandlerStatus::IncomingDisconnect); - } - }, - outgoing = self.client_to_handler_r.recv().fuse() => { - match outgoing { - Ok(event) => { - // debug!("Event Handler, handling outgoing packet: {}", event); - self.handle_outgoing_packet(event).await? - } - Err(_) => return Err(MqttError::ClientChannelClosed), - } - if self.disconnect { - self.disconnect = true; - return Ok(HandlerStatus::OutgoingDisconnect); - } - } - } - Ok(HandlerStatus::Active) - } - - async fn handle_incoming_packet(&mut self, packet: Packet) -> Result<(), MqttError> { + /// This function handles the incoming packet `packet` depending on the packet type. + /// Any packets that are produced as a response to the incoming packet are added to the outgoing_packet_buffer. + /// + /// # Return value + /// This function returns either an error or an indication wether the users handler needs to be called on this packet. + /// In some cases (retransmitted Publish packets) the users handler should not be called to avoid duplicate delivery. + /// true is returned if the users handler should be called + /// false otherwise + pub async fn handle_incoming_packet(&mut self, packet: &Packet, outgoing_packet_buffer: &mut Vec) -> Result { match packet { - Packet::Publish(publish) => self.handle_incoming_publish(&publish).await?, - Packet::PubAck(puback) => self.handle_incoming_puback(&puback).await?, - Packet::PubRec(pubrec) => self.handle_incoming_pubrec(&pubrec).await?, - Packet::PubRel(pubrel) => self.handle_incoming_pubrel(&pubrel).await?, - Packet::PubComp(pubcomp) => self.handle_incoming_pubcomp(&pubcomp).await?, - Packet::SubAck(suback) => self.handle_incoming_suback(suback).await?, - Packet::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback).await?, - Packet::PingResp => (), - Packet::ConnAck(_) => (), - Packet::Disconnect(_) => { - self.disconnect = true; - } + Packet::Publish(publish) => return self.handle_incoming_publish(publish, outgoing_packet_buffer).await, + Packet::PubAck(puback) => self.handle_incoming_puback(puback, outgoing_packet_buffer).await?, + Packet::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec, outgoing_packet_buffer).await?, + Packet::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel, outgoing_packet_buffer).await?, + Packet::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp, outgoing_packet_buffer).await?, + Packet::SubAck(suback) => self.handle_incoming_suback(suback, outgoing_packet_buffer).await?, + Packet::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback, outgoing_packet_buffer).await?, + Packet::ConnAck(connack) => self.handle_incoming_connack(connack, outgoing_packet_buffer).await?, a => unreachable!("Should not receive {}", a), }; - Ok(()) + Ok(false) } - async fn handle_incoming_publish(&mut self, publish: &Publish) -> Result<(), MqttError> { + async fn handle_incoming_publish(&mut self, publish: &Publish, outgoing_packet_buffer: &mut Vec) -> Result { match publish.qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => Ok(true), QoS::AtLeastOnce => { let puback = PubAck { - packet_identifier: publish - .packet_identifier - .ok_or(MqttError::MissingPacketId)?, + packet_identifier: publish.packet_identifier.ok_or(HandlerError::MissingPacketId)?, reason_code: PubAckReasonCode::Success, properties: PubAckProperties::default(), }; - self.network_sender.send(Packet::PubAck(puback)).await?; - Ok(()) + outgoing_packet_buffer.push(Packet::PubAck(puback)); + Ok(true) } QoS::ExactlyOnce => { - let pkid = publish - .packet_identifier - .ok_or(MqttError::MissingPacketId)?; - if !self.state.incoming_pub.insert(pkid) && !publish.dup { - error!( - "Received publish with an packet ID ({}) that is in use and the packet was not a duplicate", - pkid, - ); - } + let mut should_client_handle = true; + let pkid = publish.packet_identifier.ok_or(HandlerError::MissingPacketId)?; - let pubrec = PubRec::new(pkid); - self.network_sender.send(Packet::PubRec(pubrec)).await?; - - Ok(()) + if !self.state.add_incoming_pub(pkid) && !publish.dup { + warn!("Received publish with an packet ID ({}) that is in use and the packet was not a duplicate", pkid,); + should_client_handle = false; + } + outgoing_packet_buffer.push(Packet::PubRec(PubRec::new(pkid))); + Ok(should_client_handle) } } } - async fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), MqttError> { - if self - .state - .outgoing_pub - .remove(&puback.packet_identifier) - .is_some() - { + async fn handle_incoming_puback(&mut self, puback: &PubAck, _: &mut Vec) -> Result<(), HandlerError> { + if self.state.remove_outgoing_pub(puback.packet_identifier).is_some() { #[cfg(test)] - debug!( - "Publish {:?} has been acknowledged", - puback.packet_identifier - ); - self.state - .apkid - .mark_available(puback.packet_identifier) - .await?; + debug!("Publish {:?} has been acknowledged", puback.packet_identifier); + self.state.make_pkid_available(puback.packet_identifier)?; Ok(()) } else { - error!( - "Publish {:?} was not found, while receiving a PubAck for it", - puback.packet_identifier, - ); - Err(MqttError::Unsolicited( - puback.packet_identifier, - PacketType::PubAck, - )) + error!("Publish {:?} was not found, while receiving a PubAck for it", puback.packet_identifier,); + Err(HandlerError::Unsolicited(puback.packet_identifier, PacketType::PubAck)) } } - async fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), MqttError> { - match self.state.outgoing_pub.remove(&pubrec.packet_identifier) { + async fn handle_incoming_pubrec(&mut self, pubrec: &PubRec, outgoing_packet_buffer: &mut Vec) -> Result<(), HandlerError> { + match self.state.remove_outgoing_pub(pubrec.packet_identifier) { Some(_) => match pubrec.reason_code { PubRecReasonCode::Success | PubRecReasonCode::NoMatchingSubscribers => { let pubrel = PubRel::new(pubrec.packet_identifier); - self.state.outgoing_rel.insert(pubrec.packet_identifier); - self.network_sender.send(Packet::PubRel(pubrel)).await?; + + self.state.add_outgoing_rel(pubrec.packet_identifier); #[cfg(test)] debug!("Publish {:?} has been PubReced", pubrec.packet_identifier); + + outgoing_packet_buffer.push(Packet::PubRel(pubrel)); Ok(()) } _ => Ok(()), }, None => { - error!( - "Publish {} was not found, while receiving a PubRec for it", - pubrec.packet_identifier, - ); - Err(MqttError::Unsolicited( - pubrec.packet_identifier, - PacketType::PubRec, - )) + error!("Publish {} was not found, while receiving a PubRec for it", pubrec.packet_identifier,); + Err(HandlerError::Unsolicited(pubrec.packet_identifier, PacketType::PubRec)) } } } - async fn handle_incoming_pubrel(&self, pubrel: &PubRel) -> Result<(), MqttError> { - let pubcomp = PubComp::new(pubrel.packet_identifier); - self.network_sender.send(Packet::PubComp(pubcomp)).await?; - Ok(()) + async fn handle_incoming_pubrel(&mut self, pubrel: &PubRel, outgoing_packet_buffer: &mut Vec) -> Result<(), HandlerError> { + if self.state.remove_incoming_pub(pubrel.packet_identifier) { + let pubcomp = PubComp::new(pubrel.packet_identifier); + outgoing_packet_buffer.push(Packet::PubComp(pubcomp)); + Ok(()) + } else { + Err(HandlerError::Unsolicited(pubrel.packet_identifier, PacketType::PubRel)) + } } - async fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), MqttError> { - if self.state.outgoing_rel.remove(&pubcomp.packet_identifier) { - self.state - .apkid - .mark_available(pubcomp.packet_identifier) - .await?; + async fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp, _: &mut Vec) -> Result<(), HandlerError> { + if self.state.remove_outgoing_rel(&pubcomp.packet_identifier) { + self.state.make_pkid_available(pubcomp.packet_identifier)?; Ok(()) } else { - error!( - "PubRel {} was not found, while receiving a PubComp for it", - pubcomp.packet_identifier, - ); - Err(MqttError::Unsolicited( - pubcomp.packet_identifier, - PacketType::PubComp, - )) + error!("PubRel {} was not found, while receiving a PubComp for it", pubcomp.packet_identifier,); + Err(HandlerError::Unsolicited(pubcomp.packet_identifier, PacketType::PubComp)) } } - async fn handle_incoming_suback(&mut self, suback: SubAck) -> Result<(), MqttError> { - if self - .state - .outgoing_sub - .remove(&suback.packet_identifier) - .is_some() - { - self.state - .apkid - .mark_available(suback.packet_identifier) - .await?; + async fn handle_incoming_suback(&mut self, suback: &SubAck, _: &mut Vec) -> Result<(), HandlerError> { + if self.state.remove_outgoing_sub(suback.packet_identifier) { + self.state.make_pkid_available(suback.packet_identifier)?; Ok(()) } else { - error!( - "Sub {} was not found, while receiving a SubAck for it", - suback.packet_identifier, - ); - Err(MqttError::Unsolicited( - suback.packet_identifier, - PacketType::SubAck, - )) + error!("Sub {} was not found, while receiving a SubAck for it", suback.packet_identifier,); + Err(HandlerError::Unsolicited(suback.packet_identifier, PacketType::SubAck)) } } - async fn handle_incoming_unsuback(&mut self, unsuback: UnsubAck) -> Result<(), MqttError> { - if self - .state - .outgoing_unsub - .remove(&unsuback.packet_identifier) - .is_some() - { - self.state - .apkid - .mark_available(unsuback.packet_identifier) - .await?; + async fn handle_incoming_unsuback(&mut self, unsuback: &UnsubAck, _: &mut Vec) -> Result<(), HandlerError> { + if self.state.remove_outgoing_unsub(unsuback.packet_identifier) { + self.state.make_pkid_available(unsuback.packet_identifier)?; Ok(()) } else { - error!( - "Unsub {} was not found, while receiving a unsuback for it", - unsuback.packet_identifier, - ); - Err(MqttError::Unsolicited( - unsuback.packet_identifier, - PacketType::UnsubAck, - )) + error!("Unsub {} was not found, while receiving a unsuback for it", unsuback.packet_identifier,); + Err(HandlerError::Unsolicited(unsuback.packet_identifier, PacketType::UnsubAck)) } } - async fn handle_outgoing_packet(&mut self, packet: Packet) -> Result<(), MqttError> { + async fn handle_incoming_connack(&mut self, packet: &ConnAck, outgoing_packet_buffer: &mut Vec) -> Result<(), HandlerError> { + if packet.reason_code == ConnAckReasonCode::Success { + let retransmission = packet.connack_flags.session_present && !self.clean_start; + let (freeable_ids, mut republish) = self.state.reset(retransmission); + + for i in freeable_ids { + self.state.make_pkid_available(i)?; + } + outgoing_packet_buffer.append(&mut republish); + } + Ok(()) + } + + // async fn handle_incoming_disconnect(&mut self, packet: Disconnect) -> Result<(), MqttError> { + // self.disconnect = true; + // Ok(()) + // } + + pub async fn handle_outgoing_packet(&mut self, packet: Packet) -> Result<(), HandlerError> { + info!("Handling outgoing packet {}", packet); match packet { Packet::Publish(publish) => self.handle_outgoing_publish(publish).await, Packet::Subscribe(sub) => self.handle_outgoing_subscribe(sub).await, @@ -346,412 +194,258 @@ impl MqttHandler { } } - async fn handle_outgoing_publish(&mut self, publish: Publish) -> Result<(), MqttError> { + async fn handle_outgoing_publish(&mut self, publish: Publish) -> Result<(), HandlerError> { + let id = publish.packet_identifier; match publish.qos { - QoS::AtMostOnce => { - self.network_sender.send(Packet::Publish(publish)).await?; - } - QoS::AtLeastOnce => { - self.network_sender - .send(Packet::Publish(publish.clone())) - .await?; - if let Some(pub_collision) = self - .state - .outgoing_pub - .insert(publish.packet_identifier.unwrap(), publish) - { - error!( - "Encountered a colliding packet ID ({:?}) in a publish QoS 1 packet", - pub_collision.packet_identifier, - ) + QoS::AtMostOnce => Ok(()), + QoS::AtLeastOnce => match self.state.add_outgoing_pub(publish.packet_identifier.unwrap(), publish) { + Ok(_) => Ok(()), + Err(err) => { + error!("Encountered a colliding packet ID ({:?}) in a publish QoS 1 packet", id,); + Err(err) } - } - QoS::ExactlyOnce => { - self.network_sender - .send(Packet::Publish(publish.clone())) - .await?; - if let Some(pub_collision) = self - .state - .outgoing_pub - .insert(publish.packet_identifier.unwrap(), publish) - { - error!( - "Encountered a colliding packet ID ({:?}) in a publish QoS 2 packet", - pub_collision.packet_identifier, - ) + }, + QoS::ExactlyOnce => match self.state.add_outgoing_pub(publish.packet_identifier.unwrap(), publish) { + Ok(_) => Ok(()), + Err(err) => { + error!("Encountered a colliding packet ID ({:?}) in a publish QoS 2 packet", id,); + Err(err) } - } + }, } - Ok(()) } - async fn handle_outgoing_subscribe(&mut self, sub: Subscribe) -> Result<(), MqttError> { - if self - .state - .outgoing_sub - .insert(sub.packet_identifier, sub.clone()) - .is_some() - { - error!( - "Encountered a colliding packet ID ({}) in a subscribe packet", - sub.packet_identifier, - ) - } else { - self.network_sender.send(Packet::Subscribe(sub)).await?; + async fn handle_outgoing_subscribe(&mut self, sub: Subscribe) -> Result<(), HandlerError> { + info!("handling outgoing subscribe with ID: {}", sub.packet_identifier); + if !self.state.add_outgoing_sub(sub.packet_identifier) { + error!("Encountered a colliding packet ID ({}) in a subscribe packet\n {:?}", sub.packet_identifier, sub,); } Ok(()) } - async fn handle_outgoing_unsubscribe(&mut self, unsub: Unsubscribe) -> Result<(), MqttError> { - if self - .state - .outgoing_unsub - .insert(unsub.packet_identifier, unsub.clone()) - .is_some() - { - error!( - "Encountered a colliding packet ID ({}) in a unsubscribe packet", - unsub.packet_identifier, - ) - } else { - self.network_sender.send(Packet::Unsubscribe(unsub)).await?; + async fn handle_outgoing_unsubscribe(&mut self, unsub: Unsubscribe) -> Result<(), HandlerError> { + if !self.state.add_outgoing_unsub(unsub.packet_identifier) { + error!("Encountered a colliding packet ID ({}) in a unsubscribe packet", unsub.packet_identifier,); } Ok(()) } - async fn handle_outgoing_disconnect( - &mut self, - disconnect: Disconnect, - ) -> Result<(), MqttError> { - // self.atomic_disconnect.store(true, Ordering::Release); - self.disconnect = true; - self.network_sender - .send(Packet::Disconnect(disconnect)) - .await?; + async fn handle_outgoing_disconnect(&mut self, _: Disconnect) -> Result<(), HandlerError> { Ok(()) } } #[cfg(test)] mod handler_tests { - use std::{time::Duration}; - - use async_channel::{Receiver, Sender}; + use async_channel::Receiver; use crate::{ - ConnectOptions, - MqttHandler, packets::{ - reason_codes::{ - PubCompReasonCode, PubRecReasonCode, PubRelReasonCode, SubAckReasonCode, - }, - QoS, {Packet, PacketType}, {PubComp, PubCompProperties}, {PubRec, PubRecProperties}, - {PubRel, PubRelProperties}, {SubAck, SubAckProperties}, + reason_codes::{PubCompReasonCode, PubRecReasonCode, PubRelReasonCode, SubAckReasonCode, UnsubAckReasonCode}, + Packet, QoS, UnsubAck, UnsubAckProperties, {PubComp, PubCompProperties}, {PubRec, PubRecProperties}, {PubRel, PubRelProperties}, {SubAck, SubAckProperties}, }, - tests::test_packets::{ - create_disconnect_packet, create_puback_packet, create_publish_packet, - create_subscribe_packet, - }, - AsyncEventHandlerMut, + tests::test_packets::{create_connack_packet, create_puback_packet, create_publish_packet, create_subscribe_packet, create_unsubscribe_packet}, + AsyncEventHandler, ConnectOptions, MqttHandler, }; pub struct Nop {} #[async_trait::async_trait] - impl AsyncEventHandlerMut for Nop { - async fn handle(&mut self, _event: &Packet) {} + impl AsyncEventHandler for Nop { + async fn handle(&mut self, _event: Packet) {} } - fn handler() -> ( - MqttHandler, - Receiver, - Sender, - Sender, - ) { - let opt = ConnectOptions::new("test123123".to_string()); - - let (to_network_s, to_network_r) = async_channel::bounded(100); - let (network_to_handler_s, network_to_handler_r) = async_channel::bounded(100); - let (client_to_handler_s, client_to_handler_r) = async_channel::bounded(100); - - let (handler, _apkid) = MqttHandler::new( - &opt, - network_to_handler_r, - to_network_s, - client_to_handler_r, - ); - ( - handler, - to_network_r, - network_to_handler_s, - client_to_handler_s, - ) + fn handler(clean_start: bool) -> (MqttHandler, Receiver) { + let mut opt = ConnectOptions::new("test123123".to_string()); + opt.receive_maximum = Some(100); + opt.clean_start = clean_start; + + MqttHandler::new(&opt) } - #[tokio::test(flavor = "multi_thread")] + #[tokio::test] async fn outgoing_publish_qos_0() { - let mut nop = Nop {}; + let (mut handler, _apkid) = handler(false); - let (mut handler, to_network_r, _network_to_handler_s, client_to_handler_s) = handler(); - - let handler_task = tokio::task::spawn(async move { - let _ = loop { - match handler.handle_mut(&mut nop).await { - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); let pub_packet = create_publish_packet(QoS::AtMostOnce, false, false, None); - client_to_handler_s.send(pub_packet.clone()).await.unwrap(); - - let packet = to_network_r.recv().await.unwrap(); - - assert_eq!(packet, pub_packet); - - // If we drop the client to handler channel the handler will stop executing and we can inspect its internals. - drop(client_to_handler_s); - - let handler = handler_task.await.unwrap(); + handler.handle_outgoing_packet(pub_packet).await.unwrap(); - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); + assert!(handler.state.incoming_pub().is_empty()); + assert_eq!(handler.state.outgoing_pub_order().len(), 0); + assert_eq!(handler.state.outgoing_pub().len(), 100); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); } - #[tokio::test(flavor = "multi_thread")] + #[tokio::test] async fn outgoing_publish_qos_1() { - pub struct TestPubQoS1 { - stage: StagePubQoS1, - } - pub enum StagePubQoS1 { - PubAck, - Done, - } - impl TestPubQoS1 { - fn new() -> Self { - TestPubQoS1 { - stage: StagePubQoS1::PubAck, - } - } - } - #[async_trait::async_trait] - impl AsyncEventHandlerMut for TestPubQoS1 { - async fn handle(&mut self, event: &Packet) { - match self.stage { - StagePubQoS1::PubAck => { - assert_eq!(event.packet_type(), PacketType::PubAck); - self.stage = StagePubQoS1::Done; - } - StagePubQoS1::Done => (), - } - } - } - - let mut nop = TestPubQoS1::new(); + let (mut handler, apkid) = handler(false); + let mut resp_vec = Vec::new(); - let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); - - let handler_task = tokio::task::spawn(async move { - // Ignore the error that this will return - let _ = loop { - match handler.handle_mut(&mut nop).await { - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); - let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); + let pkid = apkid.recv().await.unwrap(); + let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(pkid)); - client_to_handler_s.send(pub_packet.clone()).await.unwrap(); + handler.handle_outgoing_packet(pub_packet.clone()).await.unwrap(); - let publish = to_network_r.recv().await.unwrap(); + assert!(handler.state.incoming_pub().is_empty()); + assert_eq!(handler.state.outgoing_pub_order().len(), 1); + assert_eq!(Packet::Publish(handler.state.outgoing_pub()[pkid as usize - 1].clone().unwrap()), pub_packet); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); - assert_eq!(pub_packet, publish); + let puback = create_puback_packet(pkid); + handler.handle_incoming_packet(&puback, &mut resp_vec).await.unwrap(); - let puback = create_puback_packet(1); + assert!(resp_vec.is_empty()); + assert!(handler.state.incoming_pub().is_empty()); + assert!(handler.state.outgoing_pub()[pkid as usize - 1].is_none()); + assert!(handler.state.outgoing_pub_order().is_empty()); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); + } - network_to_handler_s.send(puback).await.unwrap(); + #[tokio::test] + async fn incoming_publish_qos_1() { + let (mut handler, _apkid) = handler(false); + let mut resp_vec = Vec::new(); - tokio::time::sleep(Duration::new(5, 0)).await; + let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); - // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals. - drop(client_to_handler_s); - drop(network_to_handler_s); + handler.handle_incoming_packet(&pub_packet, &mut resp_vec).await.unwrap(); - let handler = handler_task.await.unwrap(); + assert!(handler.state.incoming_pub().is_empty()); + assert!(handler.state.outgoing_pub_order().is_empty()); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); + assert_eq!(1, resp_vec.len()); + let expected_puback = create_puback_packet(1); + assert_eq!(expected_puback, resp_vec.pop().unwrap()); } - #[tokio::test(flavor = "multi_thread")] - async fn incoming_publish_qos_1() { - let mut nop = Nop {}; + #[tokio::test] + async fn outgoing_publish_qos_2() { + let (mut handler, apkid) = handler(false); + let mut resp_vec = Vec::new(); - let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); + let pkid = apkid.recv().await.unwrap(); + let pub_packet = create_publish_packet(QoS::ExactlyOnce, false, false, Some(pkid)); - let handler_task = tokio::task::spawn(async move { - // Ignore the error that this will return - let _ = loop { - match handler.handle_mut(&mut nop).await { - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); - let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); + handler.handle_outgoing_packet(pub_packet.clone()).await.unwrap(); - network_to_handler_s.send(pub_packet.clone()).await.unwrap(); + assert!(handler.state.incoming_pub().is_empty()); + assert_eq!(pub_packet.clone(), Packet::Publish(handler.state.outgoing_pub()[pkid as usize - 1].clone().unwrap())); + assert_eq!(1, handler.state.outgoing_pub_order().len()); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); - let puback = to_network_r.recv().await.unwrap(); + let pubrec = Packet::PubRec(PubRec { + packet_identifier: pkid, + reason_code: PubRecReasonCode::Success, + properties: PubRecProperties::default(), + }); - assert_eq!(PacketType::PubAck, puback.packet_type()); + handler.handle_incoming_packet(&pubrec, &mut resp_vec).await.unwrap(); - let expected_puback = create_puback_packet(1); + assert!(handler.state.incoming_pub().is_empty()); + assert!(handler.state.outgoing_pub()[pkid as usize - 1].clone().is_none()); + assert_eq!(0, handler.state.outgoing_pub_order().len()); + assert_eq!(1, handler.state.outgoing_rel().len()); + assert_eq!(1, *handler.state.outgoing_rel().first().unwrap()); + assert!(handler.state.outgoing_sub().is_empty()); - assert_eq!(expected_puback, puback); + let expected_pubrel = Packet::PubRel(PubRel { + packet_identifier: pkid, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }); + assert_eq!(expected_pubrel, resp_vec.pop().unwrap()); - // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals. - drop(client_to_handler_s); - drop(network_to_handler_s); + let pubcomp = Packet::PubComp(PubComp { + packet_identifier: pkid, + reason_code: PubCompReasonCode::Success, + properties: PubCompProperties::default(), + }); - let handler = handler_task.await.unwrap(); + handler.handle_incoming_packet(&pubcomp, &mut resp_vec).await.unwrap(); - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); + assert!(resp_vec.is_empty()); + assert!(handler.state.incoming_pub().is_empty()); + assert!(handler.state.outgoing_pub()[pkid as usize - 1].clone().is_none()); + assert!(handler.state.outgoing_pub_order().is_empty()); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); } - #[tokio::test(flavor = "multi_thread")] - async fn outgoing_publish_qos_2() { - pub struct TestPubQoS2 { - stage: StagePubQoS2, - client_to_handler_s: Sender, - } - pub enum StagePubQoS2 { - PubRec, - PubComp, - Done, - } - impl TestPubQoS2 { - fn new(client_to_handler_s: Sender) -> Self { - TestPubQoS2 { - stage: StagePubQoS2::PubRec, - client_to_handler_s, - } - } - } - #[async_trait::async_trait] - impl AsyncEventHandlerMut for TestPubQoS2 { - async fn handle(&mut self, event: &Packet) { - match self.stage { - StagePubQoS2::PubRec => { - assert_eq!(event.packet_type(), PacketType::PubRec); - self.stage = StagePubQoS2::PubComp; - } - StagePubQoS2::PubComp => { - assert_eq!(event.packet_type(), PacketType::PubComp); - self.stage = StagePubQoS2::Done; - self.client_to_handler_s - .send(create_disconnect_packet()) - .await - .unwrap(); - } - StagePubQoS2::Done => (), - } - } - } - - let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); - - let mut nop = TestPubQoS2::new(client_to_handler_s.clone()); - - let handler_task = tokio::task::spawn(async move { - let _ = loop { - match handler.handle_mut(&mut nop).await { - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); - let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); + #[tokio::test] + async fn incoming_publish_qos_2() { + let (mut handler, apkid) = handler(false); + let mut resp_vec = Vec::new(); - client_to_handler_s.send(pub_packet.clone()).await.unwrap(); + let pkid = apkid.recv().await.unwrap(); + let pub_packet = create_publish_packet(QoS::ExactlyOnce, false, false, Some(pkid)); - let publish = to_network_r.recv().await.unwrap(); + handler.handle_incoming_packet(&pub_packet, &mut resp_vec).await.unwrap(); - assert_eq!(pub_packet, publish); + assert_eq!(1, handler.state.incoming_pub().len()); + assert!(handler.state.outgoing_pub_order().is_empty()); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); + assert!(handler.state.outgoing_unsub().is_empty()); let pubrec = Packet::PubRec(PubRec { - packet_identifier: 1, + packet_identifier: pkid, reason_code: PubRecReasonCode::Success, properties: PubRecProperties::default(), }); - network_to_handler_s.send(pubrec).await.unwrap(); - - let packet = to_network_r.recv().await.unwrap(); + assert_eq!(pubrec, resp_vec.pop().unwrap()); + assert!(resp_vec.is_empty()); + assert_eq!(1, handler.state.incoming_pub().len()); + assert_eq!(0, handler.state.outgoing_pub_order().len()); + assert_eq!(0, handler.state.outgoing_rel().len()); + assert!(handler.state.outgoing_sub().is_empty()); + assert!(handler.state.outgoing_unsub().is_empty()); - let expected_pubrel = Packet::PubRel(PubRel { - packet_identifier: 1, + let pubrel = Packet::PubRel(PubRel { + packet_identifier: pkid, reason_code: PubRelReasonCode::Success, properties: PubRelProperties::default(), }); - assert_eq!(expected_pubrel, packet); + handler.handle_incoming_packet(&pubrel, &mut resp_vec).await.unwrap(); let pubcomp = Packet::PubComp(PubComp { - packet_identifier: 1, + packet_identifier: pkid, reason_code: PubCompReasonCode::Success, properties: PubCompProperties::default(), }); - network_to_handler_s.send(pubcomp).await.unwrap(); - - drop(client_to_handler_s); - drop(network_to_handler_s); - - let handler = handler_task.await.unwrap(); - - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); + assert_eq!(pubcomp, resp_vec.pop().unwrap()); + assert!(resp_vec.is_empty()); + assert!(handler.state.incoming_pub().is_empty()); + assert!(handler.state.outgoing_pub_order().is_empty()); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); } - #[tokio::test(flavor = "multi_thread")] + #[tokio::test] async fn outgoing_subscribe() { - let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); + let (mut handler, apkid) = handler(false); + let mut resp_vec = Vec::new(); - let mut nop = Nop {}; + let pkid = apkid.recv().await.unwrap(); - let handler_task = tokio::task::spawn(async move { - // Ignore the error that this will return - let _ = loop { - match handler.handle_mut(&mut nop).await { - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); - - let sub_packet = create_subscribe_packet(1); + let sub_packet = create_subscribe_packet(pkid); - client_to_handler_s.send(sub_packet.clone()).await.unwrap(); + handler.handle_outgoing_packet(sub_packet).await.unwrap(); - let sub_result = to_network_r.recv().await.unwrap(); - - assert_eq!(sub_packet, sub_result); + assert!(resp_vec.is_empty()); + assert!(handler.state.incoming_pub().is_empty()); + assert!(handler.state.outgoing_pub_order().is_empty()); + assert!(handler.state.outgoing_rel().is_empty()); + assert_eq!(1, handler.state.outgoing_sub().len()); + assert!(handler.state.outgoing_unsub().is_empty()); let suback = Packet::SubAck(SubAck { packet_identifier: 1, @@ -759,17 +453,87 @@ mod handler_tests { properties: SubAckProperties::default(), }); - network_to_handler_s.send(suback).await.unwrap(); + handler.handle_incoming_packet(&suback, &mut resp_vec).await.unwrap(); + + assert!(handler.state.incoming_pub().is_empty()); + assert!(handler.state.outgoing_pub_order().is_empty()); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); + assert!(handler.state.outgoing_unsub().is_empty()); + } + + #[tokio::test] + async fn outgoing_unsubscribe() { + let (mut handler, apkid) = handler(false); + let mut resp_vec = Vec::new(); + + let pkid = apkid.recv().await.unwrap(); + + let unsub_packet = create_unsubscribe_packet(pkid); - tokio::time::sleep(Duration::new(2, 0)).await; + handler.handle_outgoing_packet(unsub_packet).await.unwrap(); - drop(client_to_handler_s); - drop(network_to_handler_s); + assert!(resp_vec.is_empty()); + assert!(handler.state.incoming_pub().is_empty()); + assert!(handler.state.outgoing_pub_order().is_empty()); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); + assert_eq!(1, handler.state.outgoing_unsub().len()); - let handler = handler_task.await.unwrap(); - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); + let unsuback = Packet::UnsubAck(UnsubAck { + packet_identifier: 1, + reason_codes: vec![UnsubAckReasonCode::Success], + properties: UnsubAckProperties::default(), + }); + + handler.handle_incoming_packet(&unsuback, &mut resp_vec).await.unwrap(); + + assert!(handler.state.incoming_pub().is_empty()); + assert!(handler.state.outgoing_pub_order().is_empty()); + assert!(handler.state.outgoing_rel().is_empty()); + assert!(handler.state.outgoing_sub().is_empty()); + assert!(handler.state.outgoing_unsub().is_empty()); + } + + #[tokio::test] + async fn retransmit_test_1() { + let (mut handler, apkid) = handler(false); + let mut stored_published_packets = Vec::new(); + let mut resp_vec = Vec::new(); + + let pkid = apkid.recv().await.unwrap(); + let pub1 = create_publish_packet(QoS::AtLeastOnce, false, false, Some(pkid)); + stored_published_packets.push(pub1.clone()); + handler.handle_outgoing_packet(pub1).await.unwrap(); + + let pkid = apkid.recv().await.unwrap(); + let pub1 = create_publish_packet(QoS::ExactlyOnce, false, false, Some(pkid)); + stored_published_packets.push(pub1.clone()); + handler.handle_outgoing_packet(pub1).await.unwrap(); + + let pub1 = create_publish_packet(QoS::AtMostOnce, false, false, None); + handler.handle_outgoing_packet(pub1).await.unwrap(); + + let connack = create_connack_packet(true); + + handler.handle_incoming_packet(&connack, &mut resp_vec).await.unwrap(); + + assert_eq!(stored_published_packets.len(), resp_vec.len()); + for i in 0..stored_published_packets.len() { + let expected_pub = stored_published_packets.get(i).unwrap(); + let res_pub = resp_vec.get(i).unwrap(); + match (expected_pub, res_pub) { + (Packet::Publish(expected), Packet::Publish(res)) => { + assert!(res.dup); + assert_eq!(expected.qos, res.qos); + assert_eq!(expected.retain, res.retain); + assert_eq!(expected.topic, res.topic); + assert_eq!(expected.packet_identifier, res.packet_identifier); + assert_eq!(expected.publish_properties, res.publish_properties); + assert_eq!(expected.payload, res.payload); + } + (_, _) => panic!(), + } + } } } diff --git a/src/network/mod.rs b/src/network/mod.rs new file mode 100644 index 0000000..24785ac --- /dev/null +++ b/src/network/mod.rs @@ -0,0 +1,2 @@ +pub mod smol; +pub mod tokio; diff --git a/src/network/smol.rs b/src/network/smol.rs new file mode 100644 index 0000000..8d21ca8 --- /dev/null +++ b/src/network/smol.rs @@ -0,0 +1,194 @@ +use async_channel::Receiver; + +use futures::FutureExt; + +use std::time::{Duration, Instant}; + +use crate::connections::smol::Stream; + +use crate::connect_options::ConnectOptions; +use crate::error::ConnectionError; +use crate::packets::error::ReadBytes; +use crate::packets::reason_codes::DisconnectReasonCode; +use crate::packets::{Disconnect, Packet, PacketType}; +use crate::{AsyncEventHandler, MqttHandler, NetworkStatus}; + +/// [`Network`] reads and writes to the network based on tokios [`AsyncReadExt`] [`AsyncWriteExt`]. +/// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. +/// The most import thing to remember is that you have to provide a new stream after the previous has failed. +/// (i.e. you need to reconnect after any expected or unexpected disconnect). +pub struct Network { + network: Option>, + + /// Options of the current mqtt connection + options: ConnectOptions, + + last_network_action: Instant, + await_pingresp: Option, + perform_keep_alive: bool, + + mqtt_handler: MqttHandler, + outgoing_packet_buffer: Vec, + incoming_packet_buffer: Vec, + + to_network_r: Receiver, +} + +impl Network { + pub fn new(options: ConnectOptions, mqtt_handler: MqttHandler, to_network_r: Receiver) -> Self { + Self { + network: None, + + options, + + last_network_action: Instant::now(), + await_pingresp: None, + perform_keep_alive: true, + + mqtt_handler, + outgoing_packet_buffer: Vec::new(), + incoming_packet_buffer: Vec::new(), + + to_network_r, + } + } +} + +impl Network +where + S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, +{ + pub async fn connect(&mut self, stream: S) -> Result<(), ConnectionError> { + let (network, connack) = Stream::connect(&self.options, stream).await?; + + self.network = Some(network); + + self.last_network_action = Instant::now(); + if self.options.keep_alive_interval_s == 0 { + self.perform_keep_alive = false; + } + + self.mqtt_handler.handle_incoming_packet(&connack, &mut self.outgoing_packet_buffer).await?; + + Ok(()) + } + + pub async fn poll(&mut self, handler: &mut H) -> Result + where + H: AsyncEventHandler, + { + if self.network.is_none() { + return Err(ConnectionError::NoNetwork); + } + + match self.smol_select(handler).await { + Ok(NetworkStatus::Active) => Ok(NetworkStatus::Active), + otherwise => { + self.network = None; + self.await_pingresp = None; + self.outgoing_packet_buffer.clear(); + self.incoming_packet_buffer.clear(); + + otherwise + } + } + } + + async fn smol_select(&mut self, handler: &mut H) -> Result + where + H: AsyncEventHandler, + { + let Network { + network, + options: _, + last_network_action, + await_pingresp, + perform_keep_alive, + mqtt_handler, + outgoing_packet_buffer, + incoming_packet_buffer, + to_network_r, + } = self; + + let sleep; + if !(*perform_keep_alive) { + sleep = Duration::new(3600, 0); + } else if let Some(instant) = await_pingresp { + sleep = *instant + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now(); + } else { + sleep = *last_network_action + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now(); + } + + if let Some(stream) = network { + futures::select! { + _ = stream.read_bytes().fuse() => { + match stream.parse_messages(incoming_packet_buffer).await { + Err(ReadBytes::Err(err)) => return Err(err), + Err(ReadBytes::InsufficientBytes(_)) => return Ok(NetworkStatus::Active), + Ok(_) => (), + } + + for packet in incoming_packet_buffer.drain(0..){ + use Packet::*; + match packet{ + PingResp => { + handler.handle(packet).await; + *await_pingresp = None; + }, + Disconnect(_) => { + handler.handle(packet).await; + return Ok(NetworkStatus::IncomingDisconnect); + } + packet => { + if mqtt_handler.handle_incoming_packet(&packet, outgoing_packet_buffer).await?{ + handler.handle(packet).await; + } + } + } + } + + stream.write_all(outgoing_packet_buffer).await?; + *last_network_action = Instant::now(); + + Ok(NetworkStatus::Active) + }, + outgoing = to_network_r.recv().fuse() => { + let packet = outgoing?; + stream.write(&packet).await?; + let mut disconnect = false; + + if packet.packet_type() == PacketType::Disconnect{ + disconnect = true; + } + + mqtt_handler.handle_outgoing_packet(packet).await?; + *last_network_action = Instant::now(); + + + if disconnect{ + Ok(NetworkStatus::OutgoingDisconnect) + } + else{ + Ok(NetworkStatus::Active) + } + }, + _ = smol::Timer::after(sleep).fuse() => { + if await_pingresp.is_none() && *perform_keep_alive{ + let packet = Packet::PingReq; + stream.write(&packet).await?; + *last_network_action = Instant::now(); + *await_pingresp = Some(Instant::now()); + Ok(NetworkStatus::Active) + } + else{ + let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; + stream.write(&Packet::Disconnect(disconnect)).await?; + Ok(NetworkStatus::NoPingResp) + } + }, + } + } else { + Err(ConnectionError::NoNetwork) + } + } +} diff --git a/src/network/tokio.rs b/src/network/tokio.rs new file mode 100644 index 0000000..326e3a9 --- /dev/null +++ b/src/network/tokio.rs @@ -0,0 +1,195 @@ +use async_channel::Receiver; + +use tracing::debug; + +use std::time::{Duration, Instant}; + +use crate::connections::tokio::Stream; + +use crate::connect_options::ConnectOptions; +use crate::error::ConnectionError; +use crate::packets::error::ReadBytes; +use crate::packets::reason_codes::DisconnectReasonCode; +use crate::packets::{Disconnect, Packet, PacketType}; +use crate::{AsyncEventHandler, MqttHandler, NetworkStatus}; + +/// [`Network`] reads and writes to the network based on tokios [`AsyncReadExt`] [`AsyncWriteExt`]. +/// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. +/// The most import thing to remember is that you have to provide a new stream after the previous has failed. +/// (i.e. you need to reconnect after any expected or unexpected disconnect). +pub struct Network { + network: Option>, + + /// Options of the current mqtt connection + options: ConnectOptions, + + last_network_action: Instant, + await_pingresp: Option, + perform_keep_alive: bool, + + mqtt_handler: MqttHandler, + outgoing_packet_buffer: Vec, + incoming_packet_buffer: Vec, + + to_network_r: Receiver, +} + +impl Network { + pub fn new(options: ConnectOptions, mqtt_handler: MqttHandler, to_network_r: Receiver) -> Self { + Self { + network: None, + + options, + + last_network_action: Instant::now(), + await_pingresp: None, + perform_keep_alive: true, + + mqtt_handler, + outgoing_packet_buffer: Vec::new(), + incoming_packet_buffer: Vec::new(), + + to_network_r, + } + } +} + +/// Tokio impl +#[cfg(feature = "tokio")] +impl Network +where + S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin, +{ + /// There is a possibility that when a disconnect and connect (aka reconnect) occurs that there are + /// still messages in the channel that were supposed to be send on the previous connection. + pub async fn connect(&mut self, stream: S) -> Result<(), ConnectionError> { + let (network, connack) = Stream::connect(&self.options, stream).await?; + + self.network = Some(network); + + self.last_network_action = Instant::now(); + if self.options.keep_alive_interval_s == 0 { + self.perform_keep_alive = false; + } + + self.mqtt_handler.handle_incoming_packet(&connack, &mut self.outgoing_packet_buffer).await?; + + Ok(()) + } + + pub async fn poll(&mut self, handler: &mut H) -> Result + where + H: AsyncEventHandler, + { + if self.network.is_none() { + return Err(ConnectionError::NoNetwork); + } + + match self.tokio_select(handler).await { + Ok(NetworkStatus::Active) => Ok(NetworkStatus::Active), + otherwise => { + self.network = None; + self.await_pingresp = None; + self.outgoing_packet_buffer.clear(); + self.incoming_packet_buffer.clear(); + + otherwise + } + } + } + + async fn tokio_select(&mut self, handler: &mut H) -> Result + where + H: AsyncEventHandler, + { + let Network { + network, + options: _, + last_network_action, + await_pingresp, + perform_keep_alive, + mqtt_handler, + outgoing_packet_buffer, + incoming_packet_buffer, + to_network_r, + } = self; + + let sleep; + if let Some(instant) = await_pingresp { + sleep = *instant + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now(); + } else { + sleep = *last_network_action + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now(); + } + + if let Some(stream) = network { + debug!("Select!"); + tokio::select! { + _ = stream.read_bytes() => { + match stream.parse_messages(incoming_packet_buffer).await { + Err(ReadBytes::Err(err)) => return Err(err), + Err(ReadBytes::InsufficientBytes(_)) => return Ok(NetworkStatus::Active), + Ok(_) => (), + } + + for packet in incoming_packet_buffer.drain(0..){ + use Packet::*; + match packet{ + PingResp => { + handler.handle(packet).await; + *await_pingresp = None; + }, + Disconnect(_) => { + handler.handle(packet).await; + return Ok(NetworkStatus::IncomingDisconnect); + } + packet => { + if mqtt_handler.handle_incoming_packet(&packet, outgoing_packet_buffer).await?{ + handler.handle(packet).await; + } + } + } + } + + stream.write_all(outgoing_packet_buffer).await?; + *last_network_action = Instant::now(); + + Ok(NetworkStatus::Active) + }, + outgoing = to_network_r.recv() => { + let packet = outgoing?; + stream.write(&packet).await?; + let mut disconnect = false; + + if packet.packet_type() == PacketType::Disconnect{ + disconnect = true; + } + + mqtt_handler.handle_outgoing_packet(packet).await?; + *last_network_action = Instant::now(); + + + if disconnect{ + Ok(NetworkStatus::OutgoingDisconnect) + } + else{ + Ok(NetworkStatus::Active) + } + }, + _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { + let packet = Packet::PingReq; + stream.write(&packet).await?; + *last_network_action = Instant::now(); + *await_pingresp = Some(Instant::now()); + Ok(NetworkStatus::Active) + }, + _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { + let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; + stream.write(&Packet::Disconnect(disconnect)).await?; + Ok(NetworkStatus::NoPingResp) + } + } + } else { + Err(ConnectionError::NoNetwork) + } + } +} diff --git a/src/packets/auth.rs b/src/packets/auth.rs index 86d56ba..cc80935 100644 --- a/src/packets/auth.rs +++ b/src/packets/auth.rs @@ -19,10 +19,7 @@ impl VariableHeaderRead for Auth { let reason_code = AuthReasonCode::read(&mut buf)?; let properties = AuthProperties::read(&mut buf)?; - Ok(Self { - reason_code, - properties, - }) + Ok(Self { reason_code, properties }) } } @@ -77,29 +74,20 @@ impl MqttRead for AuthProperties { match PropertyType::read(&mut property_data)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::SessionExpiryInterval, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); } properties.reason_string = Some(String::read(&mut property_data)?); } - PropertyType::UserProperty => properties.user_properties.push(( - String::read(&mut property_data)?, - String::read(&mut property_data)?, - )), + PropertyType::UserProperty => properties.user_properties.push((String::read(&mut property_data)?, String::read(&mut property_data)?)), PropertyType::AuthenticationMethod => { if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::AuthenticationMethod, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); } properties.authentication_method = Some(String::read(&mut property_data)?); } PropertyType::AuthenticationData => { if properties.authentication_data.is_empty() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::AuthenticationData, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); } properties.authentication_data = Bytes::read(&mut property_data)?; } diff --git a/src/packets/connack.rs b/src/packets/connack.rs index 8e51cf7..c5e3fd0 100644 --- a/src/packets/connack.rs +++ b/src/packets/connack.rs @@ -1,13 +1,13 @@ use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, VariableHeaderRead, MqttWrite}, + mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead}, read_variable_integer, reason_codes::ConnAckReasonCode, PacketType, PropertyType, QoS, }; -use bytes::{Buf, Bytes, BufMut}; +use bytes::{Buf, BufMut, Bytes}; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct ConnAck { /// 3.2.2.1 Connect Acknowledge Flags pub connack_flags: ConnAckFlags, @@ -23,11 +23,7 @@ pub struct ConnAck { impl VariableHeaderRead for ConnAck { fn read(_: u8, header_len: usize, mut buf: bytes::Bytes) -> Result { if header_len > buf.len() { - return Err(DeserializeError::InsufficientData( - "ConnAck".to_string(), - buf.len(), - header_len, - )); + return Err(DeserializeError::InsufficientData("ConnAck".to_string(), buf.len(), header_len)); } let connack_flags = ConnAckFlags::read(&mut buf)?; @@ -121,11 +117,7 @@ impl MqttRead for ConnAckProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "ConnAckProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("ConnAckProperties".to_string(), buf.len(), len)); } let mut property_data = buf.split_to(len); @@ -135,134 +127,98 @@ impl MqttRead for ConnAckProperties { match property { PropertyType::SessionExpiryInterval => { if properties.session_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::SessionExpiryInterval, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); } properties.session_expiry_interval = Some(u32::read(&mut property_data)?); } PropertyType::ReceiveMaximum => { if properties.receive_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ReceiveMaximum, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); } properties.receive_maximum = Some(u16::read(&mut property_data)?); } PropertyType::MaximumQos => { if properties.maximum_qos.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::MaximumQos, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumQos)); } properties.maximum_qos = Some(QoS::read(&mut property_data)?); } PropertyType::RetainAvailable => { if properties.retain_available.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::RetainAvailable, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable)); } properties.retain_available = Some(bool::read(&mut property_data)?); } PropertyType::MaximumPacketSize => { if properties.maximum_packet_size.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::MaximumPacketSize, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); } properties.maximum_packet_size = Some(u32::read(&mut property_data)?); } PropertyType::AssignedClientIdentifier => { if properties.assigned_client_id.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::AssignedClientIdentifier, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier)); } properties.assigned_client_id = Some(String::read(&mut property_data)?); } PropertyType::TopicAliasMaximum => { if properties.topic_alias_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::TopicAliasMaximum, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); } properties.topic_alias_maximum = Some(u16::read(&mut property_data)?); } PropertyType::ReasonString => { if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ReasonString, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); } properties.reason_string = Some(String::read(&mut property_data)?); } - PropertyType::UserProperty => properties.user_properties.push(( - String::read(&mut property_data)?, - String::read(&mut property_data)?, - )), + PropertyType::UserProperty => properties.user_properties.push((String::read(&mut property_data)?, String::read(&mut property_data)?)), PropertyType::WildcardSubscriptionAvailable => { if properties.wildcards_available.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::WildcardSubscriptionAvailable, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::WildcardSubscriptionAvailable)); } properties.wildcards_available = Some(bool::read(&mut property_data)?); } PropertyType::SubscriptionIdentifierAvailable => { if properties.subscription_ids_available.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::SubscriptionIdentifierAvailable, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifierAvailable)); } properties.subscription_ids_available = Some(bool::read(&mut property_data)?); } PropertyType::SharedSubscriptionAvailable => { if properties.shared_subscription_available.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::SharedSubscriptionAvailable, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::SharedSubscriptionAvailable)); } - properties.shared_subscription_available = - Some(bool::read(&mut property_data)?); + properties.shared_subscription_available = Some(bool::read(&mut property_data)?); } PropertyType::ServerKeepAlive => { if properties.server_keep_alive.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ServerKeepAlive, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive)); } properties.server_keep_alive = Some(u16::read(&mut property_data)?); } PropertyType::ResponseInformation => { if properties.response_info.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ResponseInformation, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseInformation)); } properties.response_info = Some(String::read(&mut property_data)?); } PropertyType::ServerReference => { if properties.server_reference.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ServerReference, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); } properties.server_reference = Some(String::read(&mut property_data)?); } PropertyType::AuthenticationMethod => { if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::AuthenticationMethod, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); } properties.authentication_method = Some(String::read(&mut property_data)?); } PropertyType::AuthenticationData => { if properties.authentication_data.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::AuthenticationData, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); } properties.authentication_data = Some(Bytes::read(&mut property_data)?); } @@ -279,33 +235,21 @@ impl MqttRead for ConnAckProperties { } } -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct ConnAckFlags{ +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +pub struct ConnAckFlags { pub session_present: bool, } -impl Default for ConnAckFlags { - fn default() -> Self { - Self { - session_present: false, - } - } -} - impl MqttRead for ConnAckFlags { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "ConnAckFlags".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("ConnAckFlags".to_string(), 0, 1)); } let byte = buf.get_u8(); - Ok(Self { - session_present: (byte & 0b00000001) != 0, + Ok(Self { + session_present: (byte & 0b00000001) != 0, }) } } @@ -324,6 +268,7 @@ mod tests { use crate::packets::{ connack::{ConnAck, ConnAckProperties}, mqtt_traits::{MqttRead, VariableHeaderRead}, + reason_codes::ConnAckReasonCode, }; #[test] @@ -338,7 +283,8 @@ mod tests { buf.extend_from_slice(packet); let c = ConnAck::read(0, packet.len(), buf.into()).unwrap(); - dbg!(c); + assert_eq!(ConnAckReasonCode::Success, c.reason_code); + assert_eq!(ConnAckProperties::default(), c.connack_properties); } #[test] @@ -356,9 +302,7 @@ mod tests { 2, // QoS 2 Exactly Once 34, // Topic Alias Max = 255 0, 255, 31, // Reason String = 'Houston we have got a problem' - 0, 29, b'H', b'o', b'u', b's', b't', b'o', b'n', b' ', b'w', b'e', b' ', b'h', b'a', - b'v', b'e', b' ', b'g', b'o', b't', b' ', b'a', b' ', b'p', b'r', b'o', b'b', b'l', - b'e', b'm', + 0, 29, b'H', b'o', b'u', b's', b't', b'o', b'n', b' ', b'w', b'e', b' ', b'h', b'a', b'v', b'e', b' ', b'g', b'o', b't', b' ', b'a', b' ', b'p', b'r', b'o', b'b', b'l', b'e', b'm', ]; buf.extend_from_slice(packet); diff --git a/src/packets/connect.rs b/src/packets/connect.rs index 9288d72..0700b20 100644 --- a/src/packets/connect.rs +++ b/src/packets/connect.rs @@ -3,8 +3,7 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use super::{ error::{DeserializeError, SerializeError}, mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, - ProtocolVersion, QoS, WireLength, + read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, ProtocolVersion, QoS, WireLength, }; /// Variable connect header: @@ -99,9 +98,7 @@ impl Default for Connect { impl VariableHeaderRead for Connect { fn read(_: u8, _: usize, mut buf: Bytes) -> Result { if String::read(&mut buf)? != "MQTT" { - return Err(DeserializeError::MalformedPacketWithInfo( - "Protocol not MQTT".to_string(), - )); + return Err(DeserializeError::MalformedPacketWithInfo("Protocol not MQTT".to_string())); } let protocol_version = ProtocolVersion::read(&mut buf)?; @@ -121,16 +118,8 @@ impl VariableHeaderRead for Connect { last_will = Some(LastWill::read(connect_flags.will_qos, retain, &mut buf)?); } - let username = if connect_flags.username { - Some(String::read(&mut buf)?) - } else { - None - }; - let password = if connect_flags.password { - Some(String::read(&mut buf)?) - } else { - None - }; + let username = if connect_flags.username { Some(String::read(&mut buf)?) } else { None }; + let password = if connect_flags.password { Some(String::read(&mut buf)?) } else { None }; let connect = Connect { protocol_version, @@ -153,9 +142,11 @@ impl VariableHeaderWrite for Connect { self.protocol_version.write(buf)?; - let mut connect_flags = ConnectFlags::default(); + let mut connect_flags = ConnectFlags { + clean_start: self.clean_start, + ..Default::default() + }; - connect_flags.clean_start = self.clean_start; if let Some(last_will) = &self.last_will { connect_flags.will_flag = true; connect_flags.will_retain = last_will.retain; @@ -214,7 +205,7 @@ impl WireLength for Connect { /// ║ ║ User Name ║ Password ║ Will Retain ║ Will QoS ║ Will Flag ║ Clean Start ║ Reserved ║ /// ╚═════â•İ═══════════â•İ══════════â•İ═════════════â•İ══════════â•İ═══════════â•İ═════════════â•İ══════════╝ #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct ConnectFlags{ +pub struct ConnectFlags { pub clean_start: bool, pub will_flag: bool, pub will_qos: QoS, @@ -223,27 +214,25 @@ pub struct ConnectFlags{ pub username: bool, } -impl ConnectFlags{ - pub fn from_u8(value: u8) -> Result{ - Ok( - Self{ - clean_start: ((value & 0b00000010) >> 1) != 0, - will_flag: ((value & 0b00000100) >> 2) != 0, - will_qos: QoS::from_u8((value & 0b00011000) >> 3)?, - will_retain: ((value & 0b00100000) >> 5) != 0, - password: ((value & 0b01000000) >> 6) != 0, - username: ((value & 0b10000000) >> 7) != 0, - } - ) +impl ConnectFlags { + pub fn from_u8(value: u8) -> Result { + Ok(Self { + clean_start: ((value & 0b00000010) >> 1) != 0, + will_flag: ((value & 0b00000100) >> 2) != 0, + will_qos: QoS::from_u8((value & 0b00011000) >> 3)?, + will_retain: ((value & 0b00100000) >> 5) != 0, + password: ((value & 0b01000000) >> 6) != 0, + username: ((value & 0b10000000) >> 7) != 0, + }) } - pub fn into_u8(&self) -> Result{ + pub fn into_u8(&self) -> Result { let byte = ((self.clean_start as u8) << 1) - | ((self.will_flag as u8) << 2) - | (self.will_qos.into_u8() << 3) - | ((self.will_retain as u8) << 5) - | ((self.password as u8) << 6) - | ((self.username as u8) << 7); + | ((self.will_flag as u8) << 2) + | (self.will_qos.into_u8() << 3) + | ((self.will_retain as u8) << 5) + | ((self.password as u8) << 6) + | ((self.username as u8) << 7); Ok(byte) } } @@ -264,11 +253,7 @@ impl Default for ConnectFlags { impl MqttRead for ConnectFlags { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "ConnectFlags".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("ConnectFlags".to_string(), 0, 1)); } let byte = buf.get_u8(); @@ -388,11 +373,7 @@ impl MqttRead for ConnectProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "ConnectProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("ConnectProperties".to_string(), buf.len(), len)); } let mut property_data = buf.split_to(len); @@ -401,69 +382,50 @@ impl MqttRead for ConnectProperties { match PropertyType::read(&mut property_data)? { PropertyType::SessionExpiryInterval => { if properties.session_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::SessionExpiryInterval, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); } properties.session_expiry_interval = Some(property_data.get_u32()); } PropertyType::ReceiveMaximum => { if properties.receive_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ReceiveMaximum, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); } properties.receive_maximum = Some(property_data.get_u16()); } PropertyType::MaximumPacketSize => { if properties.maximum_packet_size.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::MaximumPacketSize, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); } properties.maximum_packet_size = Some(property_data.get_u32()); } PropertyType::TopicAliasMaximum => { if properties.topic_alias_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::TopicAliasMaximum, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); } properties.topic_alias_maximum = Some(property_data.get_u16()); } PropertyType::RequestResponseInformation => { if properties.request_response_information.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::RequestResponseInformation, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::RequestResponseInformation)); } properties.request_response_information = Some(property_data.get_u8()); } PropertyType::RequestProblemInformation => { if properties.request_problem_information.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::RequestProblemInformation, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::RequestProblemInformation)); } properties.request_problem_information = Some(property_data.get_u8()); } - PropertyType::UserProperty => properties.user_properties.push(( - String::read(&mut property_data)?, - String::read(&mut property_data)?, - )), + PropertyType::UserProperty => properties.user_properties.push((String::read(&mut property_data)?, String::read(&mut property_data)?)), PropertyType::AuthenticationMethod => { if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::AuthenticationMethod, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); } properties.authentication_method = Some(String::read(&mut property_data)?); } PropertyType::AuthenticationData => { if properties.authentication_data.is_empty() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::AuthenticationData, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); } properties.authentication_data = Bytes::read(&mut property_data)?; } @@ -475,11 +437,8 @@ impl MqttRead for ConnectProperties { } } - if !properties.authentication_data.is_empty() && properties.authentication_method.is_none() - { - return Err(DeserializeError::MalformedPacketWithInfo( - "Authentication data is not empty while authentication method is".to_string(), - )); + if !properties.authentication_data.is_empty() && properties.authentication_method.is_none() { + return Err(DeserializeError::MalformedPacketWithInfo("Authentication data is not empty while authentication method is".to_string())); } Ok(properties) @@ -540,12 +499,7 @@ pub struct LastWill { } impl LastWill { - pub fn new, P: Into>>( - qos: QoS, - retain: bool, - topic: T, - payload: P, - ) -> LastWill { + pub fn new, P: Into>>(qos: QoS, retain: bool, topic: T, payload: P) -> LastWill { Self { qos, retain, @@ -582,10 +536,7 @@ impl WireLength for LastWill { fn wire_len(&self) -> usize { let property_len = self.last_will_properties.wire_len(); - self.topic.wire_len() - + self.payload.wire_len() - + variable_integer_len(property_len) - + property_len + self.topic.wire_len() + self.payload.wire_len() + variable_integer_len(property_len) + property_len } } @@ -615,11 +566,7 @@ impl MqttRead for LastWillProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "LastWillProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("LastWillProperties".to_string(), buf.len(), len)); } let mut property_data = buf.split_to(len); @@ -628,56 +575,41 @@ impl MqttRead for LastWillProperties { match PropertyType::read(&mut property_data)? { PropertyType::WillDelayInterval => { if properties.delay_interval.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::WillDelayInterval, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::WillDelayInterval)); } properties.delay_interval = Some(u32::read(&mut property_data)?); } PropertyType::PayloadFormatIndicator => { if properties.payload_format_indicator.is_none() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::PayloadFormatIndicator, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); } properties.payload_format_indicator = Some(u8::read(&mut property_data)?); } PropertyType::MessageExpiryInterval => { if properties.message_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::MessageExpiryInterval, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); } properties.message_expiry_interval = Some(u32::read(&mut property_data)?); } PropertyType::ContentType => { if properties.content_type.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ContentType, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); } properties.content_type = Some(String::read(&mut property_data)?); } PropertyType::ResponseTopic => { if properties.response_topic.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ResponseTopic, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); } properties.response_topic = Some(String::read(&mut property_data)?); } PropertyType::CorrelationData => { if properties.correlation_data.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::CorrelationData, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); } properties.correlation_data = Some(Bytes::read(&mut property_data)?); } - PropertyType::UserProperty => properties.user_properties.push(( - String::read(&mut property_data)?, - String::read(&mut property_data)?, - )), + PropertyType::UserProperty => properties.user_properties.push((String::read(&mut property_data)?, String::read(&mut property_data)?)), e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)), } @@ -743,18 +675,9 @@ impl WireLength for LastWillProperties { len += 5; } // +1 for the property type - len += self - .content_type - .as_ref() - .map_or_else(|| 0, |s| s.wire_len() + 1); - len += self - .response_topic - .as_ref() - .map_or_else(|| 0, |s| s.wire_len() + 1); - len += self - .correlation_data - .as_ref() - .map_or_else(|| 0, |b| b.wire_len() + 1); + len += self.content_type.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); + len += self.response_topic.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); + len += self.correlation_data.as_ref().map_or_else(|| 0, |b| b.wire_len() + 1); for (key, value) in &self.user_properties { len += key.wire_len() + value.wire_len() + 1; } @@ -770,7 +693,7 @@ mod tests { QoS, }; - use super::{Connect, LastWill, ConnectFlags}; + use super::{Connect, ConnectFlags, LastWill}; #[test] fn read_connect() { @@ -920,15 +843,11 @@ mod tests { #[test] fn read_and_write_connect2() { let _packet = [ - 0x10, 0x1d, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, - 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, - 0x65, 0x73, 0x74, + 0x10, 0x1d, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, ]; let data = [ - 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, - 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, - 0x74, + 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, ]; let mut buf = bytes::BytesMut::new(); @@ -965,7 +884,7 @@ mod tests { } #[test] - fn connect_flag(){ + fn connect_flag() { let byte = 0b1100_1110; let flags = ConnectFlags::from_u8(byte).unwrap(); assert_eq!(byte, flags.into_u8().unwrap()); diff --git a/src/packets/disconnect.rs b/src/packets/disconnect.rs index 1f08009..4019ba3 100644 --- a/src/packets/disconnect.rs +++ b/src/packets/disconnect.rs @@ -24,11 +24,7 @@ impl Default for Disconnect { } impl VariableHeaderRead for Disconnect { - fn read( - _: u8, - remaining_length: usize, - mut buf: bytes::Bytes, - ) -> Result { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { let reason_code; let properties; if remaining_length == 0 { @@ -39,17 +35,12 @@ impl VariableHeaderRead for Disconnect { properties = DisconnectProperties::read(&mut buf)?; } - Ok(Self { - reason_code, - properties, - }) + Ok(Self { reason_code, properties }) } } impl VariableHeaderWrite for Disconnect { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - if self.reason_code != DisconnectReasonCode::NormalDisconnection - || self.properties.wire_len() != 0 - { + if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { self.reason_code.write(buf)?; self.properties.write(buf)?; } @@ -58,9 +49,7 @@ impl VariableHeaderWrite for Disconnect { } impl WireLength for Disconnect { fn wire_len(&self) -> usize { - if self.reason_code != DisconnectReasonCode::NormalDisconnection - || self.properties.wire_len() != 0 - { + if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { let property_len = self.properties.wire_len(); // reasoncode, length of property length, property length 1 + variable_integer_len(property_len) + property_len @@ -86,11 +75,7 @@ impl MqttRead for DisconnectProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "DisconnectProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("DisconnectProperties".to_string(), buf.len(), len)); } let mut property_data = buf.split_to(len); @@ -99,38 +84,24 @@ impl MqttRead for DisconnectProperties { match PropertyType::from_u8(u8::read(&mut property_data)?)? { PropertyType::SessionExpiryInterval => { if properties.session_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::SessionExpiryInterval, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); } properties.session_expiry_interval = Some(u32::read(&mut property_data)?); } PropertyType::ReasonString => { if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ReasonString, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); } properties.reason_string = Some(String::read(&mut property_data)?); } PropertyType::ServerReference => { if properties.server_reference.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ServerReference, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); } properties.server_reference = Some(String::read(&mut property_data)?); } - PropertyType::UserProperty => properties.user_properties.push(( - String::read(&mut property_data)?, - String::read(&mut property_data)?, - )), - e => { - return Err(DeserializeError::UnexpectedProperty( - e, - PacketType::Disconnect, - )) - } + PropertyType::UserProperty => properties.user_properties.push((String::read(&mut property_data)?, String::read(&mut property_data)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Disconnect)), } if property_data.is_empty() { diff --git a/src/packets/mod.rs b/src/packets/mod.rs index faf5f1a..f50d5cb 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -163,11 +163,7 @@ impl MqttRead for Bytes { let len = buf.get_u16() as usize; if len > buf.len() { - return Err(DeserializeError::InsufficientData( - "Bytes".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("Bytes".to_string(), buf.len(), len)); } Ok(buf.split_to(len)) @@ -228,11 +224,7 @@ impl MqttRead for u8 { impl MqttRead for u16 { fn read(buf: &mut Bytes) -> Result { if buf.len() < 2 { - return Err(DeserializeError::InsufficientData( - "u16".to_string(), - buf.len(), - 2, - )); + return Err(DeserializeError::InsufficientData("u16".to_string(), buf.len(), 2)); } Ok(buf.get_u16()) } @@ -241,19 +233,13 @@ impl MqttRead for u16 { impl MqttRead for u32 { fn read(buf: &mut Bytes) -> Result { if buf.len() < 4 { - return Err(DeserializeError::InsufficientData( - "u32".to_string(), - buf.len(), - 4, - )); + return Err(DeserializeError::InsufficientData("u32".to_string(), buf.len(), 4)); } Ok(buf.get_u32()) } } -pub fn read_fixed_header_rem_len( - mut buf: Iter, -) -> Result<(usize, usize), ReadBytes> { +pub fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usize), ReadBytes> { let mut integer = 0; let mut length = 0; @@ -359,11 +345,7 @@ pub enum PropertyType { impl MqttRead for PropertyType { fn read(buf: &mut Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "PropertyType".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("PropertyType".to_string(), 0, 1)); } match buf.get_u8() { @@ -631,51 +613,21 @@ impl Packet { pub fn read(header: FixedHeader, buf: Bytes) -> Result { let packet = match header.packet_type { - PacketType::Connect => { - Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?) - } - PacketType::ConnAck => { - Packet::ConnAck(ConnAck::read(header.flags, header.remaining_length, buf)?) - } - PacketType::Publish => { - Packet::Publish(Publish::read(header.flags, header.remaining_length, buf)?) - } - PacketType::PubAck => { - Packet::PubAck(PubAck::read(header.flags, header.remaining_length, buf)?) - } - PacketType::PubRec => { - Packet::PubRec(PubRec::read(header.flags, header.remaining_length, buf)?) - } - PacketType::PubRel => { - Packet::PubRel(PubRel::read(header.flags, header.remaining_length, buf)?) - } - PacketType::PubComp => { - Packet::PubComp(PubComp::read(header.flags, header.remaining_length, buf)?) - } - PacketType::Subscribe => { - Packet::Subscribe(Subscribe::read(header.flags, header.remaining_length, buf)?) - } - PacketType::SubAck => { - Packet::SubAck(SubAck::read(header.flags, header.remaining_length, buf)?) - } - PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read( - header.flags, - header.remaining_length, - buf, - )?), - PacketType::UnsubAck => { - Packet::UnsubAck(UnsubAck::read(header.flags, header.remaining_length, buf)?) - } + PacketType::Connect => Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?), + PacketType::ConnAck => Packet::ConnAck(ConnAck::read(header.flags, header.remaining_length, buf)?), + PacketType::Publish => Packet::Publish(Publish::read(header.flags, header.remaining_length, buf)?), + PacketType::PubAck => Packet::PubAck(PubAck::read(header.flags, header.remaining_length, buf)?), + PacketType::PubRec => Packet::PubRec(PubRec::read(header.flags, header.remaining_length, buf)?), + PacketType::PubRel => Packet::PubRel(PubRel::read(header.flags, header.remaining_length, buf)?), + PacketType::PubComp => Packet::PubComp(PubComp::read(header.flags, header.remaining_length, buf)?), + PacketType::Subscribe => Packet::Subscribe(Subscribe::read(header.flags, header.remaining_length, buf)?), + PacketType::SubAck => Packet::SubAck(SubAck::read(header.flags, header.remaining_length, buf)?), + PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(header.flags, header.remaining_length, buf)?), + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(header.flags, header.remaining_length, buf)?), PacketType::PingReq => Packet::PingReq, PacketType::PingResp => Packet::PingResp, - PacketType::Disconnect => Packet::Disconnect(Disconnect::read( - header.flags, - header.remaining_length, - buf, - )?), - PacketType::Auth => { - Packet::Auth(Auth::read(header.flags, header.remaining_length, buf)?) - } + PacketType::Disconnect => Packet::Disconnect(Disconnect::read(header.flags, header.remaining_length, buf)?), + PacketType::Auth => Packet::Auth(Auth::read(header.flags, header.remaining_length, buf)?), }; Ok(packet) } @@ -683,9 +635,7 @@ impl Packet { pub fn read_from_buffer(buffer: &mut BytesMut) -> Result> { let (header, header_length) = FixedHeader::read_fixed_header(buffer.iter())?; if header.remaining_length + header_length > buffer.len() { - return Err(ReadBytes::InsufficientBytes( - header.remaining_length + header_length - buffer.len(), - )); + return Err(ReadBytes::InsufficientBytes(header.remaining_length + header_length - buffer.len())); } buffer.advance(header_length); @@ -697,10 +647,18 @@ impl Packet { impl Display for Packet { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self{ - Packet::Connect(c) => write!(f, "Connect(version: {:?}, clean: {}, username: {:?}, password: {:?}, keep_alive: {}, client_id: {})", c.protocol_version, c.clean_start, c.username, c.password, c.keep_alive, c.client_id), + match self { + Packet::Connect(c) => write!( + f, + "Connect(version: {:?}, clean: {}, username: {:?}, password: {:?}, keep_alive: {}, client_id: {})", + c.protocol_version, c.clean_start, c.username, c.password, c.keep_alive, c.client_id + ), Packet::ConnAck(c) => write!(f, "ConnAck(session:{:?}, reason code{:?})", c.connack_flags, c.reason_code), - Packet::Publish(p) => write!(f, "Publish(topic: {}, qos: {:?}, dup: {:?}, retain: {:?}, packet id: {:?})", &p.topic, p.qos, p.dup, p.retain, p.packet_identifier), + Packet::Publish(p) => write!( + f, + "Publish(topic: {}, qos: {:?}, dup: {:?}, retain: {:?}, packet id: {:?})", + &p.topic, p.qos, p.dup, p.retain, p.packet_identifier + ), Packet::PubAck(p) => write!(f, "PubAck(id:{:?}, reason code: {:?})", p.packet_identifier, p.reason_code), Packet::PubRec(p) => write!(f, "PubRec(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code), Packet::PubRel(p) => write!(f, "PubRel(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code), @@ -736,9 +694,7 @@ pub struct FixedHeader { } impl FixedHeader { - pub fn read_fixed_header( - mut header: Iter, - ) -> Result<(Self, usize), ReadBytes> { + pub fn read_fixed_header(mut header: Iter) -> Result<(Self, usize), ReadBytes> { if header.len() < 2 { return Err(ReadBytes::InsufficientBytes(2 - header.len())); } @@ -746,20 +702,12 @@ impl FixedHeader { let mut header_length = 1; let first_byte = header.next().unwrap(); - let (packet_type, flags) = - PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?; + let (packet_type, flags) = PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?; let (remaining_length, length) = read_fixed_header_rem_len(header)?; header_length += length; - Ok(( - Self { - packet_type, - flags, - remaining_length, - }, - header_length, - )) + Ok((Self { packet_type, flags, remaining_length }, header_length)) } } @@ -821,8 +769,7 @@ mod tests { #[test] fn test_connack_read() { let connack = [ - 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01, - 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01, + 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01, 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01, ]; let mut buf = BytesMut::new(); buf.extend(connack); @@ -832,7 +779,7 @@ mod tests { let res = res.unwrap(); let expected = ConnAck { - connack_flags: ConnAckFlags{ session_present: true }, + connack_flags: ConnAckFlags { session_present: true }, reason_code: ConnAckReasonCode::Success, connack_properties: ConnAckProperties { session_expiry_interval: None, @@ -918,9 +865,8 @@ mod tests { #[test] fn test_publish_read() { let packet = [ - 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74, - 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01, - 0x01, 0x09, 0x00, 0x04, 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01, + 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01, 0x01, 0x09, 0x00, + 0x04, 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01, ]; let mut buf = BytesMut::new(); diff --git a/src/packets/puback.rs b/src/packets/puback.rs index 83ebe81..380fc2a 100644 --- a/src/packets/puback.rs +++ b/src/packets/puback.rs @@ -16,11 +16,7 @@ pub struct PubAck { } impl VariableHeaderRead for PubAck { - fn read( - _: u8, - remaining_length: usize, - mut buf: bytes::Bytes, - ) -> Result { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { return Ok(Self { @@ -31,11 +27,7 @@ impl VariableHeaderRead for PubAck { } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData( - "PubAck".to_string(), - buf.len(), - 4, - )); + return Err(DeserializeError::InsufficientData("PubAck".to_string(), buf.len(), 4)); } let packet_identifier = u16::read(&mut buf)?; @@ -54,14 +46,9 @@ impl VariableHeaderWrite for PubAck { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); - if self.reason_code == PubAckReasonCode::Success - && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + if self.reason_code == PubAckReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { // nothing here - } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { self.reason_code.write(buf)?; } else { self.reason_code.write(buf)?; @@ -73,15 +60,10 @@ impl VariableHeaderWrite for PubAck { impl WireLength for PubAck { fn wire_len(&self) -> usize { - if self.reason_code == PubAckReasonCode::Success - && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + if self.reason_code == PubAckReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { // Only pkid 2 - } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { // pkid and reason code 3 } else { @@ -112,11 +94,7 @@ impl MqttRead for PubAckProperties { return Ok(Self::default()); } if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "PubAckProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("PubAckProperties".to_string(), buf.len(), len)); } let mut properties = PubAckProperties::default(); @@ -125,15 +103,11 @@ impl MqttRead for PubAckProperties { match PropertyType::from_u8(u8::read(buf)?)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ReasonString, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); } properties.reason_string = Some(String::read(buf)?); } - PropertyType::UserProperty => properties - .user_properties - .push((String::read(buf)?, String::read(buf)?)), + PropertyType::UserProperty => properties.user_properties.push((String::read(buf)?, String::read(buf)?)), e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubAck)), } if buf.is_empty() { @@ -264,20 +238,12 @@ mod tests { #[test] fn test_properties() { let mut properties_data = BytesMut::new(); - PropertyType::ReasonString - .write(&mut properties_data) - .unwrap(); - "reason string, test 1-2-3." - .write(&mut properties_data) - .unwrap(); - PropertyType::UserProperty - .write(&mut properties_data) - .unwrap(); + PropertyType::ReasonString.write(&mut properties_data).unwrap(); + "reason string, test 1-2-3.".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); "This is the key".write(&mut properties_data).unwrap(); "This is the value".write(&mut properties_data).unwrap(); - PropertyType::UserProperty - .write(&mut properties_data) - .unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); "Another thingy".write(&mut properties_data).unwrap(); "The thingy".write(&mut properties_data).unwrap(); diff --git a/src/packets/pubcomp.rs b/src/packets/pubcomp.rs index 8281df2..1d1a63e 100644 --- a/src/packets/pubcomp.rs +++ b/src/packets/pubcomp.rs @@ -26,11 +26,7 @@ impl PubComp { } impl VariableHeaderRead for PubComp { - fn read( - _: u8, - remaining_length: usize, - mut buf: bytes::Bytes, - ) -> Result { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { return Ok(Self { @@ -41,11 +37,7 @@ impl VariableHeaderRead for PubComp { } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData( - "PubComp".to_string(), - buf.len(), - 4, - )); + return Err(DeserializeError::InsufficientData("PubComp".to_string(), buf.len(), 4)); } let packet_identifier = u16::read(&mut buf)?; @@ -64,14 +56,9 @@ impl VariableHeaderWrite for PubComp { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); - if self.reason_code == PubCompReasonCode::Success - && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { // nothing here - } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { self.reason_code.write(buf)?; } else { self.reason_code.write(buf)?; @@ -83,14 +70,9 @@ impl VariableHeaderWrite for PubComp { impl WireLength for PubComp { fn wire_len(&self) -> usize { - if self.reason_code == PubCompReasonCode::Success - && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 2 - } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 3 } else { 2 + 1 + self.properties.wire_len() @@ -118,11 +100,7 @@ impl MqttRead for PubCompProperties { return Ok(Self::default()); } if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "PubCompProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("PubCompProperties".to_string(), buf.len(), len)); } let mut properties = PubCompProperties::default(); @@ -131,15 +109,11 @@ impl MqttRead for PubCompProperties { match PropertyType::from_u8(u8::read(buf)?)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ReasonString, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); } properties.reason_string = Some(String::read(buf)?); } - PropertyType::UserProperty => properties - .user_properties - .push((String::read(buf)?, String::read(buf)?)), + PropertyType::UserProperty => properties.user_properties.push((String::read(buf)?, String::read(buf)?)), e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubComp)), } if buf.is_empty() { @@ -268,20 +242,12 @@ mod tests { #[test] fn test_properties() { let mut properties_data = BytesMut::new(); - PropertyType::ReasonString - .write(&mut properties_data) - .unwrap(); - "reason string, test 1-2-3." - .write(&mut properties_data) - .unwrap(); - PropertyType::UserProperty - .write(&mut properties_data) - .unwrap(); + PropertyType::ReasonString.write(&mut properties_data).unwrap(); + "reason string, test 1-2-3.".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); "This is the key".write(&mut properties_data).unwrap(); "This is the value".write(&mut properties_data).unwrap(); - PropertyType::UserProperty - .write(&mut properties_data) - .unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); "Another thingy".write(&mut properties_data).unwrap(); "The thingy".write(&mut properties_data).unwrap(); diff --git a/src/packets/publish.rs b/src/packets/publish.rs index e463822..691b270 100644 --- a/src/packets/publish.rs +++ b/src/packets/publish.rs @@ -1,12 +1,9 @@ use bytes::{BufMut, Bytes}; -use super::mqtt_traits::{ - MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength, -}; +use super::mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}; use super::{ error::{DeserializeError, SerializeError}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, - QoS, + read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -34,14 +31,7 @@ pub struct Publish { } impl Publish { - pub fn new( - qos: QoS, - retain: bool, - topic: String, - packet_identifier: Option, - publish_properties: PublishProperties, - payload: Bytes, - ) -> Self { + pub fn new(qos: QoS, retain: bool, topic: String, packet_identifier: Option, publish_properties: PublishProperties, payload: Bytes) -> Self { Self { dup: false, qos, @@ -154,11 +144,7 @@ impl MqttRead for PublishProperties { if len == 0 { return Ok(Self::default()); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "PublishProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("PublishProperties".to_string(), buf.len(), len)); } let mut property_data = buf.split_to(len); @@ -169,58 +155,41 @@ impl MqttRead for PublishProperties { match PropertyType::from_u8(u8::read(&mut property_data)?)? { PropertyType::PayloadFormatIndicator => { if properties.payload_format_indicator.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::PayloadFormatIndicator, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); } properties.payload_format_indicator = Some(u8::read(&mut property_data)?); } PropertyType::MessageExpiryInterval => { if properties.message_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::MessageExpiryInterval, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); } properties.message_expiry_interval = Some(u32::read(&mut property_data)?); } PropertyType::TopicAlias => { if properties.topic_alias.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::TopicAlias, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAlias)); } properties.topic_alias = Some(u16::read(&mut property_data)?); } PropertyType::ResponseTopic => { if properties.response_topic.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ResponseTopic, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); } properties.response_topic = Some(String::read(&mut property_data)?); } PropertyType::CorrelationData => { if properties.correlation_data.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::CorrelationData, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); } properties.correlation_data = Some(Bytes::read(&mut property_data)?); } PropertyType::SubscriptionIdentifier => { - properties - .subscription_identifier - .push(read_variable_integer(&mut property_data)?.0); + properties.subscription_identifier.push(read_variable_integer(&mut property_data)?.0); } - PropertyType::UserProperty => properties.user_properties.push(( - String::read(&mut property_data)?, - String::read(&mut property_data)?, - )), + PropertyType::UserProperty => properties.user_properties.push((String::read(&mut property_data)?, String::read(&mut property_data)?)), PropertyType::ContentType => { if properties.content_type.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ContentType, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); } properties.content_type = Some(String::read(&mut property_data)?); } diff --git a/src/packets/pubrec.rs b/src/packets/pubrec.rs index c5e52b4..947d8a0 100644 --- a/src/packets/pubrec.rs +++ b/src/packets/pubrec.rs @@ -25,11 +25,7 @@ impl PubRec { } impl VariableHeaderRead for PubRec { - fn read( - _: u8, - remaining_length: usize, - mut buf: bytes::Bytes, - ) -> Result { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { return Ok(Self { @@ -40,11 +36,7 @@ impl VariableHeaderRead for PubRec { } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData( - "PubRec".to_string(), - buf.len(), - 4, - )); + return Err(DeserializeError::InsufficientData("PubRec".to_string(), buf.len(), 4)); } let packet_identifier = u16::read(&mut buf)?; @@ -63,14 +55,9 @@ impl VariableHeaderWrite for PubRec { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); - if self.reason_code == PubRecReasonCode::Success - && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + if self.reason_code == PubRecReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { // nothing here - } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { self.reason_code.write(buf)?; } else { self.reason_code.write(buf)?; @@ -82,14 +69,9 @@ impl VariableHeaderWrite for PubRec { impl WireLength for PubRec { fn wire_len(&self) -> usize { - if self.reason_code == PubRecReasonCode::Success - && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + if self.reason_code == PubRecReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 2 - } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 3 } else { 2 + 1 + self.properties.wire_len() @@ -117,11 +99,7 @@ impl MqttRead for PubRecProperties { return Ok(Self::default()); } if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "PubRecProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("PubRecProperties".to_string(), buf.len(), len)); } let mut properties = PubRecProperties::default(); @@ -130,15 +108,11 @@ impl MqttRead for PubRecProperties { match PropertyType::from_u8(u8::read(buf)?)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ReasonString, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); } properties.reason_string = Some(String::read(buf)?); } - PropertyType::UserProperty => properties - .user_properties - .push((String::read(buf)?, String::read(buf)?)), + PropertyType::UserProperty => properties.user_properties.push((String::read(buf)?, String::read(buf)?)), e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRec)), } if buf.is_empty() { @@ -269,20 +243,12 @@ mod tests { #[test] fn test_properties() { let mut properties_data = BytesMut::new(); - PropertyType::ReasonString - .write(&mut properties_data) - .unwrap(); - "reason string, test 1-2-3." - .write(&mut properties_data) - .unwrap(); - PropertyType::UserProperty - .write(&mut properties_data) - .unwrap(); + PropertyType::ReasonString.write(&mut properties_data).unwrap(); + "reason string, test 1-2-3.".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); "This is the key".write(&mut properties_data).unwrap(); "This is the value".write(&mut properties_data).unwrap(); - PropertyType::UserProperty - .write(&mut properties_data) - .unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); "Another thingy".write(&mut properties_data).unwrap(); "The thingy".write(&mut properties_data).unwrap(); diff --git a/src/packets/pubrel.rs b/src/packets/pubrel.rs index f4953d0..50007a8 100644 --- a/src/packets/pubrel.rs +++ b/src/packets/pubrel.rs @@ -26,11 +26,7 @@ impl PubRel { } impl VariableHeaderRead for PubRel { - fn read( - _: u8, - remaining_length: usize, - mut buf: bytes::Bytes, - ) -> Result { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { Ok(Self { @@ -58,14 +54,9 @@ impl VariableHeaderWrite for PubRel { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); - if self.reason_code == PubRelReasonCode::Success - && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { // Nothing here - } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { self.reason_code.write(buf)?; } else { self.reason_code.write(buf)?; @@ -77,14 +68,9 @@ impl VariableHeaderWrite for PubRel { impl WireLength for PubRel { fn wire_len(&self) -> usize { - if self.reason_code == PubRelReasonCode::Success - && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 2 - } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() - { + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 3 } else { 2 + 1 + self.properties.wire_len() @@ -112,11 +98,7 @@ impl MqttRead for PubRelProperties { return Ok(Self::default()); } if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "PubRelProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("PubRelProperties".to_string(), buf.len(), len)); } let mut properties = PubRelProperties::default(); @@ -125,15 +107,11 @@ impl MqttRead for PubRelProperties { match PropertyType::from_u8(u8::read(buf)?)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty( - PropertyType::ReasonString, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); } properties.reason_string = Some(String::read(buf)?); } - PropertyType::UserProperty => properties - .user_properties - .push((String::read(buf)?, String::read(buf)?)), + PropertyType::UserProperty => properties.user_properties.push((String::read(buf)?, String::read(buf)?)), e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRel)), } if buf.is_empty() { @@ -292,20 +270,12 @@ mod tests { #[test] fn test_properties() { let mut properties_data = BytesMut::new(); - PropertyType::ReasonString - .write(&mut properties_data) - .unwrap(); - "reason string, test 1-2-3." - .write(&mut properties_data) - .unwrap(); - PropertyType::UserProperty - .write(&mut properties_data) - .unwrap(); + PropertyType::ReasonString.write(&mut properties_data).unwrap(); + "reason string, test 1-2-3.".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); "This is the key".write(&mut properties_data).unwrap(); "This is the value".write(&mut properties_data).unwrap(); - PropertyType::UserProperty - .write(&mut properties_data) - .unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); "Another thingy".write(&mut properties_data).unwrap(); "The thingy".write(&mut properties_data).unwrap(); diff --git a/src/packets/reason_codes.rs b/src/packets/reason_codes.rs index 73395fd..e5a0c7c 100644 --- a/src/packets/reason_codes.rs +++ b/src/packets/reason_codes.rs @@ -3,9 +3,11 @@ use bytes::{Buf, BufMut}; use super::error::DeserializeError; use super::mqtt_traits::{MqttRead, MqttWrite}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ConnAckReasonCode { + #[default] Success, + UnspecifiedError, MalformedPacket, ProtocolError, @@ -32,11 +34,7 @@ pub enum ConnAckReasonCode { impl MqttRead for ConnAckReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "ConAckReasonCode".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("ConAckReasonCode".to_string(), 0, 1)); } match buf.get_u8() { @@ -110,11 +108,7 @@ pub enum AuthReasonCode { impl MqttRead for AuthReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "AuthReasonCode".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("AuthReasonCode".to_string(), 0, 1)); } match buf.get_u8() { @@ -176,11 +170,7 @@ pub enum DisconnectReasonCode { impl MqttRead for DisconnectReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "DisconnectReasonCode".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("DisconnectReasonCode".to_string(), 0, 1)); } match buf.get_u8() { @@ -274,11 +264,7 @@ pub enum PubAckReasonCode { impl MqttRead for PubAckReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "PubAckReasonCode".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("PubAckReasonCode".to_string(), 0, 1)); } match buf.get_u8() { @@ -325,11 +311,7 @@ pub enum PubCompReasonCode { impl MqttRead for PubCompReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "PubCompReasonCode".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("PubCompReasonCode".to_string(), 0, 1)); } match buf.get_u8() { @@ -368,11 +350,7 @@ pub enum PubRecReasonCode { impl MqttRead for PubRecReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "PubRecReasonCode".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("PubRecReasonCode".to_string(), 0, 1)); } match buf.get_u8() { @@ -418,11 +396,7 @@ pub enum PubRelReasonCode { impl MqttRead for PubRelReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "PubRelReasonCode".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("PubRelReasonCode".to_string(), 0, 1)); } match buf.get_u8() { @@ -464,11 +438,7 @@ pub enum SubAckReasonCode { impl MqttRead for SubAckReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "SubAckReasonCode".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("SubAckReasonCode".to_string(), 0, 1)); } match buf.get_u8() { @@ -525,11 +495,7 @@ pub enum UnsubAckReasonCode { impl MqttRead for UnsubAckReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "UnsubAckReasonCode".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("UnsubAckReasonCode".to_string(), 0, 1)); } match buf.get_u8() { diff --git a/src/packets/suback.rs b/src/packets/suback.rs index f86fa40..f182e5e 100644 --- a/src/packets/suback.rs +++ b/src/packets/suback.rs @@ -19,11 +19,7 @@ pub struct SubAck { } impl VariableHeaderRead for SubAck { - fn read( - _: u8, - _: usize, - mut buf: bytes::Bytes, - ) -> Result { + fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = SubAckProperties::read(&mut buf)?; let mut reason_codes = vec![]; @@ -78,11 +74,7 @@ impl MqttRead for SubAckProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "SubAckProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("SubAckProperties".to_string(), buf.len(), len)); } let mut properties_data = buf.split_to(len); @@ -95,16 +87,11 @@ impl MqttRead for SubAckProperties { properties.subscription_id = Some(subscription_id); } else { - return Err(DeserializeError::DuplicateProperty( - PropertyType::SubscriptionIdentifier, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); } } PropertyType::UserProperty => { - properties.user_properties.push(( - String::read(&mut properties_data)?, - String::read(&mut properties_data)?, - )); + properties.user_properties.push((String::read(&mut properties_data)?, String::read(&mut properties_data)?)); } e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::SubAck)), } diff --git a/src/packets/subscribe.rs b/src/packets/subscribe.rs index 1e48d4b..826a8e8 100644 --- a/src/packets/subscribe.rs +++ b/src/packets/subscribe.rs @@ -1,8 +1,7 @@ use super::{ error::DeserializeError, mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, - QoS, + read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, }; use bytes::{Buf, BufMut}; @@ -24,11 +23,7 @@ impl Subscribe { } impl VariableHeaderRead for Subscribe { - fn read( - _: u8, - _: usize, - mut buf: bytes::Bytes, - ) -> Result { + fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = SubscribeProperties::read(&mut buf)?; let mut topics = vec![]; @@ -97,11 +92,7 @@ impl MqttRead for SubscribeProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "SubscribeProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("SubscribeProperties".to_string(), buf.len(), len)); } let mut properties_data = buf.split_to(len); @@ -114,23 +105,13 @@ impl MqttRead for SubscribeProperties { properties.subscription_id = Some(subscription_id); } else { - return Err(DeserializeError::DuplicateProperty( - PropertyType::SubscriptionIdentifier, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); } } PropertyType::UserProperty => { - properties.user_properties.push(( - String::read(&mut properties_data)?, - String::read(&mut properties_data)?, - )); - } - e => { - return Err(DeserializeError::UnexpectedProperty( - e, - PacketType::Subscribe, - )) + properties.user_properties.push((String::read(&mut properties_data)?, String::read(&mut properties_data)?)); } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Subscribe)), } if properties_data.is_empty() { @@ -192,11 +173,7 @@ impl Default for SubscriptionOptions { impl MqttRead for SubscriptionOptions { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData( - "SubscriptionOptions".to_string(), - 0, - 1, - )); + return Err(DeserializeError::InsufficientData("SubscriptionOptions".to_string(), 0, 1)); } let byte = buf.get_u8(); @@ -219,10 +196,7 @@ impl MqttRead for SubscriptionOptions { impl MqttWrite for SubscriptionOptions { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let byte = (self.retain_handling.into_u8() << 4) - | ((self.retain_as_publish as u8) << 3) - | ((self.no_local as u8) << 2) - | self.qos.into_u8(); + let byte = (self.retain_handling.into_u8() << 4) | ((self.retain_as_publish as u8) << 3) | ((self.no_local as u8) << 2) | self.qos.into_u8(); buf.put_u8(byte); Ok(()) @@ -282,12 +256,7 @@ impl From<&str> for Subscription { impl From<&[&str]> for Subscription { fn from(value: &[&str]) -> Self { - Self( - value - .iter() - .map(|topic| (topic.to_string(), SubscriptionOptions::default())) - .collect(), - ) + Self(value.iter().map(|topic| (topic.to_string(), SubscriptionOptions::default())).collect()) } } @@ -307,15 +276,12 @@ mod tests { #[test] fn test_read_write_subscribe() { let _entire_sub_packet = [ - 0x82, 0x1e, 0x35, 0xd6, 0x02, 0x0b, 0x01, 0x00, 0x16, 0x73, 0x75, 0x62, 0x73, 0x63, - 0x72, 0x69, 0x62, 0x65, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x74, - 0x65, 0x73, 0x74, 0x15, + 0x82, 0x1e, 0x35, 0xd6, 0x02, 0x0b, 0x01, 0x00, 0x16, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x74, 0x65, 0x73, 0x74, + 0x15, ]; let sub_data = [ - 0x35, 0xd6, 0x02, 0x0b, 0x01, 0x00, 0x16, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, - 0x62, 0x65, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x74, 0x65, 0x73, - 0x74, 0x15, + 0x35, 0xd6, 0x02, 0x0b, 0x01, 0x00, 0x16, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x15, ]; // let sub_data = &[0x35, 0xd6, 0x02, 0x0b, 0x01, 0x00, 0x16, 0x73, @@ -345,10 +311,7 @@ mod tests { #[test] fn test_write() { - let expected_bytes = [ - 0x82, 0x0e, 0x00, 0x01, 0x00, 0x00, 0x08, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, - 0x33, 0x00, - ]; + let expected_bytes = [0x82, 0x0e, 0x00, 0x01, 0x00, 0x00, 0x08, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x00]; let sub = Subscribe { packet_identifier: 1, diff --git a/src/packets/unsuback.rs b/src/packets/unsuback.rs index da1aa1b..ff49e50 100644 --- a/src/packets/unsuback.rs +++ b/src/packets/unsuback.rs @@ -1,13 +1,8 @@ use bytes::BufMut; use super::error::{DeserializeError, SerializeError}; -use super::mqtt_traits::{ - MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength, -}; -use super::{ - read_variable_integer, reason_codes::UnsubAckReasonCode, write_variable_integer, PacketType, - PropertyType, -}; +use super::mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}; +use super::{read_variable_integer, reason_codes::UnsubAckReasonCode, write_variable_integer, PacketType, PropertyType}; #[derive(Debug, Default, PartialEq, Eq, Clone)] pub struct UnsubAck { @@ -17,11 +12,7 @@ pub struct UnsubAck { } impl VariableHeaderRead for UnsubAck { - fn read( - _: u8, - _: usize, - mut buf: bytes::Bytes, - ) -> Result { + fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = UnsubAckProperties::read(&mut buf)?; let mut reason_codes = vec![]; @@ -72,11 +63,7 @@ impl MqttRead for UnsubAckProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "UnsubAckProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("UnsubAckProperties".to_string(), buf.len(), len)); } let mut properties_data = buf.split_to(len); @@ -87,23 +74,13 @@ impl MqttRead for UnsubAckProperties { if properties.reason_string.is_none() { properties.reason_string = Some(String::read(&mut properties_data)?); } else { - return Err(DeserializeError::DuplicateProperty( - PropertyType::SubscriptionIdentifier, - )); + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); } } PropertyType::UserProperty => { - properties.user_properties.push(( - String::read(&mut properties_data)?, - String::read(&mut properties_data)?, - )); - } - e => { - return Err(DeserializeError::UnexpectedProperty( - e, - PacketType::UnsubAck, - )) + properties.user_properties.push((String::read(&mut properties_data)?, String::read(&mut properties_data)?)); } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::UnsubAck)), } if buf.is_empty() { diff --git a/src/packets/unsubscribe.rs b/src/packets/unsubscribe.rs index 662ad4c..b963fd4 100644 --- a/src/packets/unsubscribe.rs +++ b/src/packets/unsubscribe.rs @@ -12,12 +12,18 @@ pub struct Unsubscribe { pub topics: Vec, } +impl Unsubscribe { + pub fn new(packet_identifier: u16, topics: Vec) -> Self { + Self { + packet_identifier, + properties: UnsubscribeProperties::default(), + topics, + } + } +} + impl VariableHeaderRead for Unsubscribe { - fn read( - _: u8, - _: usize, - mut buf: bytes::Bytes, - ) -> Result { + fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = UnsubscribeProperties::read(&mut buf)?; let mut topics = vec![]; @@ -54,8 +60,7 @@ impl VariableHeaderWrite for Unsubscribe { impl WireLength for Unsubscribe { fn wire_len(&self) -> usize { - let mut len = - 2 + variable_integer_len(self.properties.wire_len()) + self.properties.wire_len(); + let mut len = 2 + variable_integer_len(self.properties.wire_len()) + self.properties.wire_len(); for topic in &self.topics { len += topic.wire_len(); } @@ -77,11 +82,7 @@ impl MqttRead for UnsubscribeProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData( - "UnsubscribeProperties".to_string(), - buf.len(), - len, - )); + return Err(DeserializeError::InsufficientData("UnsubscribeProperties".to_string(), buf.len(), len)); } let mut properties_data = buf.split_to(len); @@ -89,17 +90,9 @@ impl MqttRead for UnsubscribeProperties { loop { match PropertyType::read(&mut properties_data)? { PropertyType::UserProperty => { - properties.user_properties.push(( - String::read(&mut properties_data)?, - String::read(&mut properties_data)?, - )); - } - e => { - return Err(DeserializeError::UnexpectedProperty( - e, - PacketType::Unsubscribe, - )) + properties.user_properties.push((String::read(&mut properties_data)?, String::read(&mut properties_data)?)); } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Unsubscribe)), } if properties_data.is_empty() { @@ -190,8 +183,7 @@ mod tests { #[test] fn read_write_unsubscribe() { let unsubscribe_packet = &[ - 0x35, 0xd7, 0x00, 0x00, 0x16, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, - 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x74, 0x65, 0x73, 0x74, + 0x35, 0xd7, 0x00, 0x00, 0x16, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x74, 0x65, 0x73, 0x74, ]; let mut bufmut = BytesMut::new(); diff --git a/src/smol_network.rs b/src/smol_network.rs deleted file mode 100644 index ea5f6e1..0000000 --- a/src/smol_network.rs +++ /dev/null @@ -1,162 +0,0 @@ -use async_channel::{Receiver, Sender}; - -use futures::{select, FutureExt}; -use smol::io::{AsyncReadExt, AsyncWriteExt}; - -use std::time::{Duration, Instant}; - -use crate::connect_options::ConnectOptions; -use crate::connections::smol_stream::SmolStream; -use crate::error::ConnectionError; -use crate::packets::error::ReadBytes; -use crate::packets::reason_codes::DisconnectReasonCode; -use crate::packets::{Disconnect, Packet, PacketType}; -use crate::NetworkStatus; - -/// [`SmolNetwork`] reads and writes to the network based on futures [`AsyncReadExt`] [`AsyncWriteExt`]. -/// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. -/// The most import thing to remember is that you have to provide a new stream after the previous failed. -/// (i.e. you need to reconnect after any expected or unexpected disconnect). -pub struct SmolNetwork { - network: Option>, - - /// Options of the current mqtt connection - options: ConnectOptions, - - last_network_action: Instant, - await_pingresp: Option, - perform_keep_alive: bool, - - network_to_handler_s: Sender, - - to_network_r: Receiver, -} - -impl SmolNetwork -where - S: AsyncReadExt + AsyncWriteExt + Sized + Unpin, -{ - pub fn new( - options: ConnectOptions, - network_to_handler_s: Sender, - to_network_r: Receiver, - ) -> Self { - Self { - network: None, - - options, - - last_network_action: Instant::now(), - await_pingresp: None, - perform_keep_alive: true, - - network_to_handler_s, - to_network_r, - } - } - - pub fn reset(&mut self) { - self.network = None; - } - - pub async fn connect(&mut self, stream: S) -> Result<(), ConnectionError> { - let (network, connack) = SmolStream::connect(&self.options, stream).await?; - - self.network = Some(network); - self.network_to_handler_s.send(connack).await?; - self.last_network_action = Instant::now(); - if self.options.keep_alive_interval_s == 0 { - self.perform_keep_alive = false; - } - Ok(()) - } - - pub async fn run(&mut self) -> Result { - if self.network.is_none() { - return Err(ConnectionError::NoNetwork); - } - - match self.select().await { - Ok(NetworkStatus::Active) => Ok(NetworkStatus::Active), - otherwise => { - self.reset(); - otherwise - } - } - } - - async fn select(&mut self) -> Result { - if self.network.is_none() { - return Err(ConnectionError::NoNetwork); - } - - let SmolNetwork { - network, - options: _, - last_network_action, - await_pingresp, - perform_keep_alive, - network_to_handler_s, - to_network_r, - } = self; - - let sleep; - if !(*perform_keep_alive) { - sleep = Duration::new(3600, 0); - } else if let Some(instant) = await_pingresp { - sleep = - *instant + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now(); - } else { - sleep = *last_network_action + Duration::from_secs(self.options.keep_alive_interval_s) - - Instant::now(); - } - - if let Some(stream) = network { - loop { - select! { - _ = stream.read_bytes().fuse() => { - match stream.parse_messages(network_to_handler_s).await { - Err(ReadBytes::Err(err)) => return Err(err), - Err(ReadBytes::InsufficientBytes(_)) => continue, - Ok(Some(PacketType::PingResp)) => { - *await_pingresp = None; - return Ok(NetworkStatus::Active) - }, - Ok(Some(PacketType::Disconnect)) => { - return Ok(NetworkStatus::IncomingDisconnect) - }, - Ok(_) => { - return Ok(NetworkStatus::Active) - } - }; - }, - outgoing = to_network_r.recv().fuse() => { - let packet = outgoing?; - stream.write(&packet).await?; - *last_network_action = Instant::now(); - if packet.packet_type() == PacketType::Disconnect{ - return Ok(NetworkStatus::OutgoingDisconnect); - } - return Ok(NetworkStatus::Active); - }, - _ = smol::Timer::after(sleep).fuse() => { - if await_pingresp.is_none() && *perform_keep_alive{ - let packet = Packet::PingReq; - stream.write(&packet).await?; - *last_network_action = Instant::now(); - *await_pingresp = Some(Instant::now()); - return Ok(NetworkStatus::Active); - } - else{ - let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; - stream.write(&Packet::Disconnect(disconnect)).await?; - return Ok(NetworkStatus::NoPingResp); - } - }, - } - } - } else { - Err(ConnectionError::NoNetwork) - } - } -} diff --git a/src/state.rs b/src/state.rs index 890dc05..419c1f9 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,29 +1,32 @@ -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::{BTreeSet, VecDeque}; use async_channel::Receiver; use crate::{ available_packet_ids::AvailablePacketIds, - packets::{Publish, Subscribe, Unsubscribe}, + error::HandlerError, + packets::{Packet, PubRel, Publish}, }; #[derive(Debug)] /// [`State`] keeps track of the outgoing and incoming messages on which actions needs to be taken. /// In the future this will be adjusted to rebroadcast packets that have not been acked and thus need to be rebroadcast. pub struct State { - pub(crate) apkid: AvailablePacketIds, + apkid: AvailablePacketIds, /// Outgoing Subcribe requests which aren't acked yet - pub(crate) outgoing_sub: BTreeMap, + outgoing_sub: BTreeSet, /// Outgoing Unsubcribe requests which aren't acked yet - pub(crate) outgoing_unsub: BTreeMap, + outgoing_unsub: BTreeSet, + /// Outgoing QoS 1, 2 publishes which aren't acked yet - pub(crate) outgoing_pub: BTreeMap, + outgoing_pub: Vec>, + outgoing_pub_order: VecDeque, /// Packet ids of released QoS 2 publishes - pub(crate) outgoing_rel: BTreeSet, + outgoing_rel: BTreeSet, - /// Packets on incoming QoS 2 publishes - pub(crate) incoming_pub: BTreeSet, + /// Packet IDs of packets that arrive with QoS 2 + incoming_pub: BTreeSet, } impl State { @@ -32,13 +35,148 @@ impl State { let state = Self { apkid, - outgoing_sub: BTreeMap::new(), - outgoing_unsub: BTreeMap::new(), - outgoing_pub: BTreeMap::new(), + + // make everything an option. We do not want to use vec::remove because it will shift everything right of the element to the left. + // Which because we ussually remove the oldest (most left) items first there will be a lot of shifting! + // If we just swap in place with None than we should be good. + outgoing_sub: BTreeSet::new(), + outgoing_unsub: BTreeSet::new(), + outgoing_pub: vec![None; receive_maximum as usize], + outgoing_pub_order: VecDeque::new(), outgoing_rel: BTreeSet::new(), incoming_pub: BTreeSet::new(), }; (state, r) } + + pub fn make_pkid_available(&mut self, pkid: u16) -> Result<(), HandlerError> { + self.apkid.mark_available(pkid) + } + + pub fn add_incoming_pub(&mut self, pkid: u16) -> bool { + self.incoming_pub.insert(pkid) + } + + /// Returns whether the packett id was present. + pub fn remove_incoming_pub(&mut self, pkid: u16) -> bool { + self.incoming_pub.remove(&pkid) + } + + pub fn add_outgoing_pub(&mut self, pkid: u16, publish: Publish) -> Result<(), HandlerError> { + let current_pub = self.outgoing_pub[(pkid - 1) as usize].take(); + self.outgoing_pub[(pkid - 1) as usize] = Some(publish); + + if current_pub.is_some() { + Err(HandlerError::PacketIdCollision(pkid)) + } else { + self.outgoing_pub_order.push_back(pkid); + Ok(()) + } + } + + pub fn remove_outgoing_pub(&mut self, pkid: u16) -> Option { + for (index, id) in self.outgoing_pub_order.iter().enumerate() { + if pkid == *id { + self.outgoing_pub_order.remove(index); + break; + } + } + + self.outgoing_pub[pkid as usize - 1].take() + } + + pub fn add_outgoing_rel(&mut self, pkid: u16) -> bool { + self.outgoing_rel.insert(pkid) + } + + /// Returns whether the packett id was present. + pub fn remove_outgoing_rel(&mut self, pkid: &u16) -> bool { + self.outgoing_rel.remove(pkid) + } + + pub fn add_outgoing_sub(&mut self, pkid: u16) -> bool { + self.outgoing_sub.insert(pkid) + } + + /// Returns whether the packett id was present. + pub fn remove_outgoing_sub(&mut self, pkid: u16) -> bool { + self.outgoing_sub.remove(&pkid) + } + + pub fn add_outgoing_unsub(&mut self, pkid: u16) -> bool { + self.outgoing_unsub.insert(pkid) + } + + /// Returns whether the packett id was present. + pub fn remove_outgoing_unsub(&mut self, pkid: u16) -> bool { + self.outgoing_unsub.remove(&pkid) + } + + /// Returns the identifiers that are in use but can be freed + pub fn reset(&mut self, retransmission: bool) -> (Vec, Vec) { + let State { + apkid: _, + outgoing_sub, + outgoing_unsub, + outgoing_pub, + outgoing_pub_order, + outgoing_rel, + incoming_pub, + } = self; + + let mut freeable_ids = Vec::::with_capacity(outgoing_sub.len() + outgoing_unsub.len()); + // let mut freeable_ids = outgoing_sub.iter().chain(outgoing_unsub.iter()).collect::>(); + let mut retransmit = Vec::with_capacity(outgoing_pub_order.len()); + + freeable_ids.extend(outgoing_sub.iter()); + freeable_ids.extend(outgoing_unsub.iter()); + + if retransmission { + for i in outgoing_pub_order { + let mut packet = outgoing_pub[(*i - 1) as usize].clone().unwrap(); + packet.dup = true; + retransmit.push(Packet::Publish(packet)); + } + + for &rel in outgoing_rel.iter() { + retransmit.push(Packet::PubRel(PubRel::new(rel))); + } + } else { + freeable_ids.extend(outgoing_pub_order.iter()); + + *outgoing_pub = vec![None; outgoing_pub.len()]; + outgoing_pub_order.clear(); + outgoing_rel.clear(); + } + + outgoing_sub.clear(); + outgoing_unsub.clear(); + + incoming_pub.clear(); + + (freeable_ids, retransmit) + } +} + +#[cfg(test)] +impl State { + pub fn outgoing_sub(&mut self) -> &mut BTreeSet { + &mut self.outgoing_sub + } + pub fn outgoing_unsub(&mut self) -> &mut BTreeSet { + &mut self.outgoing_unsub + } + pub fn outgoing_pub(&mut self) -> &mut Vec> { + &mut self.outgoing_pub + } + pub fn outgoing_pub_order(&mut self) -> &mut VecDeque { + &mut self.outgoing_pub_order + } + pub fn outgoing_rel(&mut self) -> &mut BTreeSet { + &mut self.outgoing_rel + } + pub fn incoming_pub(&mut self) -> &mut BTreeSet { + &mut self.incoming_pub + } } diff --git a/src/tests/connection_tests.rs b/src/tests/connection_tests.rs deleted file mode 100644 index 7be6728..0000000 --- a/src/tests/connection_tests.rs +++ /dev/null @@ -1,99 +0,0 @@ -// #[cfg(feature = "tokio")] -// #[cfg(test)] -// mod tokio_e2e { - -// use futures_concurrency::future::Join; - -// use tracing::Level; -// use tracing_subscriber::FmtSubscriber; - -// use crate::tests::stages::Nop; -// use crate::{ -// connect_options::ConnectOptions, create_tokio_tcp, error::ClientError, -// event_handler::EventHandler, packets::QoS, -// }; - -// use crate::tests::stages::qos_2::TestPubQoS2; - -// #[tokio::test(flavor = "multi_thread", worker_threads = 4)] -// async fn test_pub_qos_2() { -// let filter = tracing_subscriber::filter::EnvFilter::new("none,mqrstt=trace"); - -// let subscriber = FmtSubscriber::builder() -// .with_env_filter(filter) -// .with_max_level(Level::TRACE) -// .with_line_number(true) -// .finish(); - -// tracing::subscriber::set_global_default(subscriber) -// .expect("setting default subscriber failed"); - -// let opt = ConnectOptions::new_with_tls_config( -// "broker.emqx.io".to_string(), -// 1883, -// "test123123".to_string(), -// None, -// ); -// // let opt = ConnectOptions::new("127.0.0.1".to_string(), 1883, "test123123".to_string(), None); - -// let (mut mqtt_network, handler, client) = create_tokio_tcp(opt); - -// let network = tokio::task::spawn(async move { dbg!(mqtt_network.run().await) }); - -// let event_handler = tokio::task::spawn(async move { -// let mut custom_handler = Nop{}; -// dbg!(handler.handle(&mut custom_handler).await) -// }); - -// let sender = tokio::task::spawn(async move { -// client.subscribe("mqrstt").await.unwrap(); -// client -// .publish(QoS::ExactlyOnce, false, "test".to_string(), "123456789") -// .await?; - -// let lol = smol::future::pending::>(); -// lol.await -// }); - -// dbg!((network, event_handler, sender).join().await); -// } -// } - -// #[cfg(all(feature = "smol", feature = "smol-rustls"))] -// #[cfg(test)] -// mod smol_rustls_e2e { - -// use tracing::Level; -// use tracing_subscriber::FmtSubscriber; - -// use crate::{ -// connect_options::ConnectOptions, connections::transport::RustlsConfig, create_smol_rustls, -// tests::resources::EMQX_CERT, -// }; - -// #[test] -// fn test_pub_tcp_qos_2() { -// let filter = tracing_subscriber::filter::EnvFilter::new("none,mqrstt=trace"); - -// let subscriber = FmtSubscriber::builder() -// .with_env_filter(filter) -// .with_max_level(Level::TRACE) -// .with_line_number(true) -// .finish(); - -// tracing::subscriber::set_global_default(subscriber) -// .expect("setting default subscriber failed"); - -// let config = RustlsConfig::Simple { -// ca: EMQX_CERT.to_vec(), -// alpn: None, -// client_auth: None, -// }; - -// let opt = ConnectOptions::new("broker.emqx.io".to_string(), 8883, "test123123".to_string()); - -// let (mut mqtt_network, _handler, _client) = create_smol_rustls(opt, config); - -// smol::block_on(mqtt_network.run()).unwrap() -// } -// } diff --git a/src/tests/mod.rs b/src/tests/mod.rs index eb56cef..eaa5dc5 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,4 +1,3 @@ -mod connection_tests; pub mod resources; mod test_bytes; pub mod test_packets; diff --git a/src/tests/test_bytes.rs b/src/tests/test_bytes.rs index a65e1ac..1af5a7b 100644 --- a/src/tests/test_bytes.rs +++ b/src/tests/test_bytes.rs @@ -6,12 +6,9 @@ use crate::packets::{mqtt_traits::WireLength, Packet}; fn publish_packet() -> Vec { const PUBLISH_BYTES: [u8; 79] = [ - 0x35, 0x4d, 0x00, 0x1a, 0x63, 0x75, 0x2f, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x2f, 0x72, 0x65, - 0x70, 0x6c, 0x79, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x74, 0x65, 0x73, 0x74, - 0x06, 0x47, 0x29, 0x23, 0x00, 0x01, 0x09, 0x00, 0x0b, 0x43, 0x6f, 0x72, 0x72, 0x65, 0x6c, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x26, 0x00, 0x01, 0x41, 0x00, 0x01, 0x42, 0x26, 0x00, 0x01, - 0x43, 0x00, 0x01, 0x44, 0x03, 0x00, 0x07, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x68, - 0x65, 0x6c, 0x6c, 0x6f, + 0x35, 0x4d, 0x00, 0x1a, 0x63, 0x75, 0x2f, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x2f, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x74, 0x65, 0x73, 0x74, 0x06, 0x47, + 0x29, 0x23, 0x00, 0x01, 0x09, 0x00, 0x0b, 0x43, 0x6f, 0x72, 0x72, 0x65, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x26, 0x00, 0x01, 0x41, 0x00, 0x01, 0x42, 0x26, 0x00, 0x01, 0x43, 0x00, 0x01, 0x44, + 0x03, 0x00, 0x07, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x68, 0x65, 0x6c, 0x6c, 0x6f, ]; PUBLISH_BYTES.to_vec() @@ -19,12 +16,9 @@ fn publish_packet() -> Vec { #[fixture] fn publish_packet_2() -> Vec { const PUBLISH_BYTES: [u8; 76] = [ - 0x34, 0x4a, 0x00, 0x1a, 0x63, 0x75, 0x2f, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x2f, 0x72, 0x65, - 0x70, 0x6c, 0x79, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x74, 0x65, 0x73, 0x74, - 0x00, 0x02, 0x26, 0x03, 0x00, 0x07, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x09, 0x00, - 0x0b, 0x43, 0x6f, 0x72, 0x72, 0x65, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x26, 0x00, 0x01, - 0x41, 0x00, 0x01, 0x42, 0x26, 0x00, 0x01, 0x43, 0x00, 0x01, 0x44, 0x68, 0x65, 0x6c, 0x6c, - 0x6f, + 0x34, 0x4a, 0x00, 0x1a, 0x63, 0x75, 0x2f, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x2f, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x74, 0x65, 0x73, 0x74, 0x00, 0x02, + 0x26, 0x03, 0x00, 0x07, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x09, 0x00, 0x0b, 0x43, 0x6f, 0x72, 0x72, 0x65, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x26, 0x00, 0x01, 0x41, 0x00, 0x01, 0x42, + 0x26, 0x00, 0x01, 0x43, 0x00, 0x01, 0x44, 0x68, 0x65, 0x6c, 0x6c, 0x6f, ]; PUBLISH_BYTES.to_vec() @@ -32,36 +26,24 @@ fn publish_packet_2() -> Vec { fn connect_packet() -> Vec { vec![ - 0x10, 0xbe, 0x02, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0xc6, 0x00, 0x3c, 0x2a, 0x11, - 0x00, 0x00, 0x00, 0x00, 0x21, 0x00, 0x7b, 0x27, 0x00, 0x00, 0x1f, 0x40, 0x22, 0x09, 0xc4, - 0x19, 0x01, 0x17, 0x01, 0x26, 0x00, 0x01, 0x41, 0x00, 0x01, 0x42, 0x26, 0x00, 0x02, 0x43, - 0x44, 0x00, 0x08, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x00, 0x0e, 0x6d, 0x71, - 0x74, 0x74, 0x78, 0x5f, 0x36, 0x37, 0x62, 0x65, 0x32, 0x33, 0x38, 0x34, 0x06, 0x03, 0x00, - 0x00, 0x08, 0x00, 0x00, 0x00, 0x12, 0x61, 0x6e, 0x6f, 0x74, 0x68, 0x65, 0x72, 0x2f, 0x74, - 0x65, 0x73, 0x74, 0x2f, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x00, 0xc5, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, - 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x00, 0x08, - 0x54, 0x65, 0x73, 0x74, 0x54, 0x65, 0x73, 0x74, 0x00, 0x0b, 0x50, 0x61, 0x73, 0x73, 0x77, - 0x6f, 0x72, 0x64, 0x31, 0x32, 33, + 0x10, 0xbe, 0x02, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0xc6, 0x00, 0x3c, 0x2a, 0x11, 0x00, 0x00, 0x00, 0x00, 0x21, 0x00, 0x7b, 0x27, 0x00, 0x00, 0x1f, 0x40, 0x22, 0x09, 0xc4, 0x19, 0x01, + 0x17, 0x01, 0x26, 0x00, 0x01, 0x41, 0x00, 0x01, 0x42, 0x26, 0x00, 0x02, 0x43, 0x44, 0x00, 0x08, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x00, 0x0e, 0x6d, 0x71, 0x74, 0x74, 0x78, 0x5f, + 0x36, 0x37, 0x62, 0x65, 0x32, 0x33, 0x38, 0x34, 0x06, 0x03, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x12, 0x61, 0x6e, 0x6f, 0x74, 0x68, 0x65, 0x72, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x74, 0x6f, + 0x70, 0x69, 0x63, 0x00, 0xc5, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x00, 0x08, 0x54, 0x65, 0x73, 0x74, 0x54, 0x65, 0x73, 0x74, 0x00, 0x0b, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x31, 0x32, + 33, ] } pub fn subscribe_packet() -> Vec { vec![ - 0x82, 0x22, 0x82, 0x02, 0x02, 0x0b, 0x7b, 0x00, 0x1a, 0x63, 0x75, 0x2f, 0x39, 0x2e, 0x30, - 0x2e, 0x31, 0x2f, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, - 0x74, 0x74, 0x65, 0x73, 0x74, 0x2e, + 0x82, 0x22, 0x82, 0x02, 0x02, 0x0b, 0x7b, 0x00, 0x1a, 0x63, 0x75, 0x2f, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x2f, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x72, 0x74, 0x74, + 0x65, 0x73, 0x74, 0x2e, ] } @@ -91,14 +73,7 @@ fn test_connect() { if let Packet::Connect(p) = &packet { assert_eq!(42, p.connect_properties.wire_len()); - assert_eq!( - 6, - p.last_will - .as_ref() - .unwrap() - .last_will_properties - .wire_len() - ); + assert_eq!(6, p.last_will.as_ref().unwrap().last_will_properties.wire_len()); } assert_eq!(bytes.len(), write_buffer.len()); diff --git a/src/tests/test_packets.rs b/src/tests/test_packets.rs index 7aeba26..2a3ac96 100644 --- a/src/tests/test_packets.rs +++ b/src/tests/test_packets.rs @@ -4,8 +4,7 @@ use rstest::*; use crate::packets::{ reason_codes::{DisconnectReasonCode, PubAckReasonCode}, - Disconnect, DisconnectProperties, Packet, PubAck, PubAckProperties, Publish, PublishProperties, - QoS, Subscribe, Subscription, + ConnAck, Disconnect, DisconnectProperties, Packet, PubAck, PubAckProperties, Publish, PublishProperties, QoS, Subscribe, Subscription, Unsubscribe, }; fn publish_packet_1() -> Packet { @@ -96,12 +95,12 @@ pub fn create_subscribe_packet(packet_identifier: u16) -> Packet { Packet::Subscribe(sub) } -pub fn create_publish_packet( - qos: QoS, - dup: bool, - retain: bool, - packet_identifier: Option, -) -> Packet { +pub fn create_unsubscribe_packet(packet_identifier: u16) -> Packet { + let sub = Unsubscribe::new(packet_identifier, vec!["test/topic".to_string()]); + Packet::Unsubscribe(sub) +} + +pub fn create_publish_packet(qos: QoS, dup: bool, retain: bool, packet_identifier: Option) -> Packet { Packet::Publish(Publish { dup, qos, @@ -130,6 +129,13 @@ pub fn create_puback_packet(packet_identifier: u16) -> Packet { }) } +pub fn create_connack_packet(session_present: bool) -> Packet { + let mut connack = ConnAck::default(); + connack.connack_flags.session_present = session_present; + + Packet::ConnAck(connack) +} + pub fn create_disconnect_packet() -> Packet { Packet::Disconnect(Disconnect { reason_code: DisconnectReasonCode::NormalDisconnection, diff --git a/src/tests/tls.rs b/src/tests/tls.rs index 855af2d..1bc3153 100644 --- a/src/tests/tls.rs +++ b/src/tests/tls.rs @@ -13,22 +13,14 @@ pub mod tests { ECC(Vec), } - pub fn simple_rust_tls( - ca: Vec, - alpn: Option>>, - client_auth: Option<(Vec, PrivateKey)>, - ) -> Result, Error> { + pub fn simple_rust_tls(ca: Vec, alpn: Option>>, client_auth: Option<(Vec, PrivateKey)>) -> Result, Error> { let mut root_cert_store = RootCertStore::empty(); let ca_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca))).unwrap(); let trust_anchors = ca_certs.iter().map_while(|cert| { if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { - Some(OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - )) + Some(OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)) } else { None } @@ -37,27 +29,19 @@ pub mod tests { assert!(!root_cert_store.is_empty()); - let config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store); + let config = ClientConfig::builder().with_safe_defaults().with_root_certificates(root_cert_store); let mut config = match client_auth { Some((client_cert_info, client_private_info)) => { let read_private_keys = match client_private_info { - PrivateKey::RSA(rsa) => { - rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa))) - } - PrivateKey::ECC(ecc) => { - rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc))) - } + PrivateKey::RSA(rsa) => rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa))), + PrivateKey::ECC(ecc) => rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc))), } .unwrap(); let key = read_private_keys.into_iter().next().unwrap(); - let client_certs = - rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info))) - .unwrap(); + let client_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info))).unwrap(); let client_cert_chain = client_certs.into_iter().map(Certificate).collect(); config.with_single_cert(client_cert_chain, rustls::PrivateKey(key))? diff --git a/src/tokio_network.rs b/src/tokio_network.rs deleted file mode 100644 index c630b91..0000000 --- a/src/tokio_network.rs +++ /dev/null @@ -1,153 +0,0 @@ -use async_channel::{Receiver, Sender}; - -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - -use std::time::{Duration, Instant}; - -use crate::connect_options::ConnectOptions; -use crate::connections::tokio_stream::TokioStream; -use crate::error::ConnectionError; -use crate::packets::error::ReadBytes; -use crate::packets::reason_codes::DisconnectReasonCode; -use crate::packets::{Disconnect, Packet, PacketType}; -use crate::NetworkStatus; - -/// [`TokioNetwork`] reads and writes to the network based on tokios [`AsyncReadExt`] [`AsyncWriteExt`]. -/// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. -/// The most import thing to remember is that you have to provide a new stream after the previous has failed. -/// (i.e. you need to reconnect after any expected or unexpected disconnect). -pub struct TokioNetwork { - network: Option>, - - /// Options of the current mqtt connection - options: ConnectOptions, - - last_network_action: Instant, - await_pingresp: Option, - perform_keep_alive: bool, - - network_to_handler_s: Sender, - - to_network_r: Receiver, -} - -impl TokioNetwork -where - S: AsyncReadExt + AsyncWriteExt + Sized + Unpin, -{ - pub fn new( - options: ConnectOptions, - network_to_handler_s: Sender, - to_network_r: Receiver, - ) -> Self { - Self { - network: None, - - options, - - last_network_action: Instant::now(), - await_pingresp: None, - perform_keep_alive: true, - - network_to_handler_s, - to_network_r, - } - } - - pub fn reset(&mut self) { - self.network = None; - } - - pub async fn connect(&mut self, stream: S) -> Result<(), ConnectionError> { - let (network, connack) = TokioStream::connect(&self.options, stream).await?; - - self.network = Some(network); - self.network_to_handler_s.send(connack).await?; - self.last_network_action = Instant::now(); - if self.options.keep_alive_interval_s == 0 { - self.perform_keep_alive = false; - } - Ok(()) - } - - pub async fn run(&mut self) -> Result { - if self.network.is_none() { - return Err(ConnectionError::NoNetwork); - } - - match self.select().await { - Ok(NetworkStatus::Active) => Ok(NetworkStatus::Active), - otherwise => { - self.reset(); - otherwise - } - } - } - - async fn select(&mut self) -> Result { - let TokioNetwork { - network, - options: _, - last_network_action, - await_pingresp, - perform_keep_alive, - network_to_handler_s, - to_network_r, - } = self; - - let sleep; - if let Some(instant) = await_pingresp { - sleep = - *instant + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now(); - } else { - sleep = *last_network_action + Duration::from_secs(self.options.keep_alive_interval_s) - - Instant::now(); - } - - if let Some(stream) = network { - loop { - tokio::select! { - _ = stream.read_bytes() => { - match stream.parse_messages(network_to_handler_s).await { - Err(ReadBytes::Err(err)) => return Err(err), - Err(ReadBytes::InsufficientBytes(_)) => continue, - Ok(Some(PacketType::PingResp)) => { - *await_pingresp = None; - return Ok(NetworkStatus::Active) - }, - Ok(Some(PacketType::Disconnect)) => { - return Ok(NetworkStatus::IncomingDisconnect) - }, - Ok(_) => { - return Ok(NetworkStatus::Active) - } - }; - }, - outgoing = to_network_r.recv() => { - let packet = outgoing?; - stream.write(&packet).await?; - *last_network_action = Instant::now(); - if packet.packet_type() == PacketType::Disconnect{ - return Ok(NetworkStatus::OutgoingDisconnect); - } - return Ok(NetworkStatus::Active); - }, - _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { - let packet = Packet::PingReq; - stream.write(&packet).await?; - *last_network_action = Instant::now(); - *await_pingresp = Some(Instant::now()); - return Ok(NetworkStatus::Active); - }, - _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { - let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; - stream.write(&Packet::Disconnect(disconnect)).await?; - return Ok(NetworkStatus::NoPingResp); - } - } - } - } else { - Err(ConnectionError::NoNetwork) - } - } -} diff --git a/src/util/mod.rs b/src/util/mod.rs index 987b09f..0b6d2c4 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1 +1 @@ -pub mod constants; \ No newline at end of file +pub mod constants;